Message queue based state machine

This commit is contained in:
Viktor Lofgren 2023-07-04 17:42:06 +02:00
parent 31ae71c7d6
commit 2ae0b8c159
19 changed files with 610 additions and 37 deletions

View File

@ -19,6 +19,7 @@ dependencies {
implementation libs.spark
implementation libs.guice
implementation libs.gson
implementation libs.rxjava
implementation libs.bundles.prometheus

View File

@ -1,5 +1,8 @@
# Message Queue
Implements a message queue using mariadb.
Implements resilient message queueing for the application,
as well as a finite state machine library backed by the
message queue that enables long-running tasks that outlive
the execution lifespan of the involved processes.
![Message States](msgstate.svg)

View File

@ -5,6 +5,7 @@ public record MqMessage(
long relatedId,
String function,
String payload,
MqMessageState state
MqMessageState state,
boolean expectsResponse
) {
}

View File

@ -15,6 +15,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
public class MqInbox {
private final Logger logger = LoggerFactory.getLogger(MqInbox.class);
@ -26,7 +27,7 @@ public class MqInbox {
private volatile boolean run = true;
private final int pollIntervalMs = Integer.getInteger("mq.inbox.poll-interval-ms", 1000);
private final int pollIntervalMs = Integer.getInteger("mq.inbox.poll-interval-ms", 100);
private final List<MqSubscription> eventSubscribers = new ArrayList<>();
private final LinkedBlockingQueue<MqMessage> queue = new LinkedBlockingQueue<>(32);
@ -114,27 +115,52 @@ public class MqInbox {
private void handleMessageWithSubscriber(MqSubscription subscriber, MqMessage msg) {
threadPool.execute(() -> {
try {
final var rsp = subscriber.handle(msg);
sendResponse(msg, rsp.state(), rsp.message());
} catch (Exception ex) {
logger.error("Message Queue subscriber threw exception", ex);
sendResponse(msg, MqMessageState.ERR);
}
});
if (msg.expectsResponse()) {
threadPool.execute(() -> respondToMessage(subscriber, msg));
}
else {
threadPool.execute(() -> acknowledgeNotification(subscriber, msg));
}
}
private void sendResponse(MqMessage msg, MqMessageState mqMessageState) {
private void respondToMessage(MqSubscription subscriber, MqMessage msg) {
try {
persistence.updateMessageState(msg.msgId(), mqMessageState);
final var rsp = subscriber.onRequest(msg);
sendResponse(msg, rsp.state(), rsp.message());
} catch (Exception ex) {
logger.error("Message Queue subscriber threw exception", ex);
sendResponse(msg, MqMessageState.ERR);
}
}
/** Deliver a notification to the subscriber and record the outcome on the message.
 *
 * The message is marked OK when the subscriber returns normally. Any exception
 * escaping the try block is logged and the message is marked ERR instead.
 *
 * @param subscriber the subscription whose onNotification handler is invoked
 * @param msg the message being delivered; its id is used for the state update
 */
private void acknowledgeNotification(MqSubscription subscriber, MqMessage msg) {
try {
subscriber.onNotification(msg);
updateMessageState(msg, MqMessageState.OK);
} catch (Exception ex) {
// The OK update sits inside the try block as well, so an unchecked failure
// there also ends up on this path and the message is flagged ERR.
logger.error("Message Queue subscriber threw exception", ex);
updateMessageState(msg, MqMessageState.ERR);
}
}
/** Respond to a message with a bare state and no response payload.
 *
 * This amounts to a plain state update on the message row, so the work is
 * handed to updateMessageState, which also logs (rather than propagates)
 * any SQLException.
 *
 * @param msg the message being responded to
 * @param state the state to record on the message
 */
private void sendResponse(MqMessage msg, MqMessageState state) {
    updateMessageState(msg, state);
}
/** Persist a new state for the given message.
 *
 * Best effort: an SQLException from the persistence layer is logged and
 * not rethrown.
 *
 * @param msg the message whose state is updated (identified by its msgId)
 * @param state the state to store
 */
private void updateMessageState(MqMessage msg, MqMessageState state) {
    final long messageId = msg.msgId();

    try {
        persistence.updateMessageState(messageId, state);
    } catch (SQLException sqlEx) {
        logger.error("Failed to update message state", sqlEx);
    }
}
private void sendResponse(MqMessage msg, MqMessageState mqMessageState, String response) {
try {
persistence.sendResponse(msg.msgId(), mqMessageState, response);
@ -159,14 +185,25 @@ public class MqInbox {
}
private Collection<MqMessage> pollInbox(long tick) {
try {
return persistence.pollInbox(inboxName, instanceUUID, tick);
}
catch (SQLException ex) {
logger.error("Failed to poll inbox", ex);
return List.of();
}
}
try {
return persistence.pollInbox(inboxName, instanceUUID, tick);
}
catch (SQLException ex) {
logger.error("Failed to poll inbox", ex);
return List.of();
}
}
/** Fetch the most recent messages addressed to this inbox.
 *
 * @param lastN the maximum number of messages to fetch
 * @return the messages reported by the persistence layer, or an empty
 *         list if the lookup fails with an SQLException (which is logged)
 */
public List<MqMessage> replay(int lastN) {
    List<MqMessage> replayed;

    try {
        replayed = persistence.lastNMessages(inboxName, lastN);
    } catch (SQLException sqlEx) {
        logger.error("Failed to replay inbox", sqlEx);
        replayed = List.of();
    }

    return replayed;
}
private class MqInboxShredder implements MqSubscription {
@ -177,9 +214,14 @@ public class MqInbox {
}
@Override
public MqInboxResponse handle(MqMessage msg) {
public MqInboxResponse onRequest(MqMessage msg) {
logger.warn("Unhandled message {}", msg.msgId());
return MqInboxResponse.err();
}
@Override
public void onNotification(MqMessage msg) {
logger.warn("Unhandled message {}", msg.msgId());
}
}
}

View File

@ -3,7 +3,12 @@ package nu.marginalia.mq.inbox;
import nu.marginalia.mq.MqMessage;
public interface MqSubscription {
/** Return true if this subscription should handle the message. */
boolean filter(MqMessage rawMessage);
MqInboxResponse handle(MqMessage msg);
/** Handle the message and return a response. */
MqInboxResponse onRequest(MqMessage msg);
/** Handle a message with no reply address */
void onNotification(MqMessage msg);
}

View File

@ -6,7 +6,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.sql.SQLException;
import java.time.Duration;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
@ -21,7 +20,7 @@ public class MqOutbox {
private final ConcurrentHashMap<Long, Long> pendingRequests = new ConcurrentHashMap<>();
private final ConcurrentHashMap<Long, MqMessage> pendingResponses = new ConcurrentHashMap<>();
private final int pollIntervalMs = Integer.getInteger("mq.outbox.poll-interval-ms", 1000);
private final int pollIntervalMs = Integer.getInteger("mq.outbox.poll-interval-ms", 100);
private final Thread pollThread;
private volatile boolean run = true;
@ -103,5 +102,8 @@ public class MqOutbox {
}
}
/** Send a one-way message to this outbox's inbox.
 *
 * The second and fifth arguments to sendNewMessage are passed as null.
 * NOTE(review): presumably these are the sender inbox (no reply address)
 * and a time-to-live -- confirm against MqPersistence.sendNewMessage.
 *
 * @param function the function name stored on the message
 * @param payload the payload stored on the message
 * @return the value returned by MqPersistence.sendNewMessage
 *         (presumably the id of the new message -- confirm)
 * @throws Exception if the persistence layer fails to store the message
 */
public long notify(String function, String payload) throws Exception {
return persistence.sendNewMessage(inboxName, null, function, payload, null);
}
}

View File

@ -1,5 +1,6 @@
package nu.marginalia.mq.persistence;
import com.google.common.collect.Lists;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import com.zaxxer.hikari.HikariDataSource;
@ -164,7 +165,7 @@ public class MqPersistence {
try (var conn = dataSource.getConnection();
var queryStmt = conn.prepareStatement("""
SELECT ID, RELATED_ID, FUNCTION, PAYLOAD, STATE FROM PROC_MESSAGE
SELECT ID, RELATED_ID, FUNCTION, PAYLOAD, STATE, SENDER_INBOX FROM PROC_MESSAGE
WHERE OWNER_INSTANCE=? AND OWNER_TICK=?
""")
) {
@ -182,8 +183,9 @@ public class MqPersistence {
String payload = rs.getString(4);
MqMessageState state = MqMessageState.valueOf(rs.getString(5));
boolean expectsResponse = rs.getBoolean(6);
var msg = new MqMessage(msgId, relatedId, function, payload, state);
var msg = new MqMessage(msgId, relatedId, function, payload, state, expectsResponse);
messages.add(msg);
}
@ -226,7 +228,7 @@ public class MqPersistence {
MqMessageState state = MqMessageState.valueOf(rs.getString(5));
var msg = new MqMessage(msgId, relatedId, function, payload, state);
var msg = new MqMessage(msgId, relatedId, function, payload, state, false);
messages.add(msg);
}
@ -234,4 +236,38 @@ public class MqPersistence {
return messages;
}
}
/** Fetch the last N messages addressed to an inbox, oldest first.
 *
 * The query selects the N rows with the highest ID in descending order;
 * the result is then reversed so that callers see the messages in
 * ascending ID order.
 *
 * @param inboxName the recipient inbox to read from
 * @param lastN the maximum number of messages to return; must not be negative
 * @return at most lastN messages, ordered from oldest to newest
 * @throws IllegalArgumentException if lastN is negative
 * @throws SQLException if the query fails
 */
public List<MqMessage> lastNMessages(String inboxName, int lastN) throws SQLException {
    // Fail fast instead of acquiring a connection only to blow up later.
    if (lastN < 0) {
        throw new IllegalArgumentException("lastN must be non-negative, was " + lastN);
    }

    try (var conn = dataSource.getConnection();
         var stmt = conn.prepareStatement("""
                 SELECT ID, RELATED_ID, FUNCTION, PAYLOAD, STATE, SENDER_INBOX FROM PROC_MESSAGE
                 WHERE RECIPIENT_INBOX = ?
                 ORDER BY ID DESC LIMIT ?
                 """)) {
        stmt.setString(1, inboxName);
        stmt.setInt(2, lastN);

        List<MqMessage> messages = new ArrayList<>(lastN);

        try (var rs = stmt.executeQuery()) {
            while (rs.next()) {
                long msgId = rs.getLong(1);
                long relatedId = rs.getLong(2);
                String function = rs.getString(3);
                String payload = rs.getString(4);
                MqMessageState state = MqMessageState.valueOf(rs.getString(5));
                // NOTE(review): SENDER_INBOX is read through getBoolean; presumably
                // "has a sender inbox" means "expects a response" -- confirm that the
                // JDBC driver's string-to-boolean conversion matches that intent.
                boolean expectsResponse = rs.getBoolean(6);

                messages.add(new MqMessage(msgId, relatedId, function, payload, state, expectsResponse));
            }
        }

        // Lists.reverse does not modify its argument; it returns a reversed view,
        // so the view itself must be what is handed back to the caller.
        return Lists.reverse(messages);
    }
}
}

View File

@ -0,0 +1,66 @@
package nu.marginalia.mqsm;
import com.google.gson.Gson;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import nu.marginalia.mqsm.state.MachineState;
import nu.marginalia.mqsm.state.StateTransition;
import java.util.function.Function;
import java.util.function.Supplier;
/** Factory for building non-final {@link MachineState}s out of lambdas, and
 * for building the {@link StateTransition}s those lambdas return.
 *
 * Typed payloads are converted to and from JSON with the injected Gson instance.
 */
@Singleton
public class StateFactory {
    private final Gson gson;

    @Inject
    public StateFactory(Gson gson) {
        this.gson = gson;
    }

    /** Create a state whose logic receives the message payload deserialized
     * from JSON into an instance of the given class.
     *
     * @param name the name of the state
     * @param param the class the JSON payload is deserialized into
     * @param logic computes the next transition from the deserialized payload
     */
    public <T> MachineState create(String name, Class<T> param, Function<T, StateTransition> logic) {
        return nonFinalState(name, json -> logic.apply(gson.fromJson(json, param)));
    }

    /** Create a state whose logic takes no input; the message payload is ignored.
     *
     * @param name the name of the state
     * @param logic computes the next transition
     */
    public MachineState create(String name, Supplier<StateTransition> logic) {
        return nonFinalState(name, ignoredPayload -> logic.get());
    }

    /** Create a transition to the named state without a payload. */
    public StateTransition transition(String state) {
        return StateTransition.to(state);
    }

    /** Create a transition to the named state, carrying the message serialized as JSON. */
    public StateTransition transition(String state, Object message) {
        final String json = gson.toJson(message);

        return StateTransition.to(state, json);
    }

    /** Shared implementation: a named state that is never final and whose
     * next transition is computed from the raw message string. */
    private static MachineState nonFinalState(String stateName, Function<String, StateTransition> step) {
        return new MachineState() {
            @Override
            public String name() {
                return stateName;
            }

            @Override
            public StateTransition next(String message) {
                return step.apply(message);
            }

            @Override
            public boolean isFinal() {
                return false;
            }
        };
    }
}

View File

@ -0,0 +1,176 @@
package nu.marginalia.mqsm;
import nu.marginalia.mq.MqMessage;
import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.inbox.MqInbox;
import nu.marginalia.mq.inbox.MqInboxResponse;
import nu.marginalia.mq.inbox.MqSubscription;
import nu.marginalia.mq.outbox.MqOutbox;
import nu.marginalia.mq.persistence.MqPersistence;
import nu.marginalia.mqsm.state.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
/** A state machine that can be used to implement a finite state machine
 * using a message queue as the persistence layer. The state machine is
 * resilient to crashes and can be resumed from the last state.
 * <p>
 * Each transition is sent as a notification to the machine's own inbox;
 * the message function carries the name of the state and the payload
 * carries the state's input. Receiving such a notification runs the
 * state's logic, which in turn produces the next notification.
 * <p>
 * Three states are always registered: ERROR and END (both final), and
 * RESUMING (a non-final placeholder used while resuming).
 */
public class StateMachine {
    private final Logger logger = LoggerFactory.getLogger(StateMachine.class);

    private final MqInbox smInbox;
    private final MqOutbox smOutbox;
    private final String queueName;

    // Written under the monitor of 'this' (so join() can wait on changes);
    // volatile so the unsynchronized reads see the latest value.
    private volatile MachineState state;

    private final MachineState errorState = new ErrorState();
    private final MachineState finalState = new FinalState();
    private final MachineState resumingState = new ResumingState();

    private final Map<String, MachineState> allStates = new HashMap<>();

    /** Create a state machine backed by the named queue.
     *
     * @param persistence the message queue persistence layer
     * @param queueName the inbox used both to send and to receive transitions
     * @param instanceUUID the identity of this instance towards the queue
     */
    public StateMachine(MqPersistence persistence, String queueName, UUID instanceUUID) {
        this.queueName = queueName;

        smInbox = new MqInbox(persistence, queueName, instanceUUID);
        smOutbox = new MqOutbox(persistence, queueName, instanceUUID);

        smInbox.subscribe(new StateEventSubscription());

        // Use the private helper rather than the public (overridable) registerStates
        addStates(List.of(errorState, finalState, resumingState));
    }

    /** Register the state graph
     *
     * @throws IllegalStateException if the state machine has already been initialized
     */
    public void registerStates(MachineState... states) {
        registerStates(List.of(states));
    }

    /** Register the state graph
     *
     * @throws IllegalStateException if the state machine has already been initialized
     */
    public void registerStates(List<MachineState> states) {
        if (state != null) {
            throw new IllegalStateException("Cannot register states after state machine has been initialized");
        }

        addStates(states);
    }

    /** Index the given states by name, replacing any earlier state with the same name. */
    private void addStates(List<MachineState> states) {
        for (var newState : states) {
            allStates.put(newState.name(), newState);
        }
    }

    /** Wait for the state machine to reach a final state.
     * (possibly forever, halting problem and so on)
     * <p>
     * Returns immediately if the state machine has not been initialized.
     */
    public void join() throws InterruptedException {
        synchronized (this) {
            if (null == state)
                return;

            while (!state.isFinal()) {
                wait();
            }
        }
    }

    /** Initialize the state machine.
     *
     * Enters the state named INITIAL, starts the inbox and sends the first
     * transition message.
     *
     * @throws IllegalStateException if no state named INITIAL has been registered
     */
    public void init() throws Exception {
        var transition = StateTransition.to("INITIAL");

        var initialState = allStates.get(transition.state());
        if (initialState == null) {
            // Without this check the machine would silently stay uninitialized
            // while still emitting an INITIAL message.
            throw new IllegalStateException("No state named '" + transition.state() + "' has been registered");
        }

        setState(initialState);

        smInbox.start();
        smOutbox.notify(transition.state(), transition.message());
    }

    /** Resume the state machine from the last known state.
     *
     * Does nothing if the state machine already has a state. If the queue
     * holds no earlier message, this is equivalent to {@link #init()}.
     */
    public void resume() throws Exception {
        if (state != null) {
            return;
        }

        var messages = smInbox.replay(1);
        if (messages.isEmpty()) {
            init();
            return;
        }

        var lastMessage = messages.get(0);

        logger.info("Resuming state machine from {}({})/{}", lastMessage.function(), lastMessage.payload(), lastMessage.state());

        if (lastMessage.state() == MqMessageState.NEW) {
            // The message is not acknowledged, so starting the inbox will trigger a state transition
            //
            // We still need to set a state here so that the join() method works. It must be set
            // BEFORE the inbox is started: otherwise the inbox may already have driven the machine
            // to a final state, which this placeholder would overwrite, leaving join() waiting forever.
            setState(resumingState);
            smInbox.start();
        } else {
            // The message is already acknowledged, so we replay the last state
            smInbox.start();
            onStateTransition(lastMessage.function(), lastMessage.payload());
        }
    }

    /** Stop the inbox and the outbox of the state machine. */
    public void stop() throws InterruptedException {
        smInbox.stop();
        smOutbox.stop();
    }

    /** Enter the named state and, unless it is final, run its logic and send
     * the resulting transition to the queue.
     *
     * An unknown state name, or any exception thrown along the way, puts the
     * machine in the error state.
     */
    private void onStateTransition(String nextState, String message) {
        try {
            var previous = state;

            logger.info("FSM State change in {}: {}->{}({})",
                    queueName,
                    previous == null ? "[null]" : previous.name(),
                    nextState,
                    message);

            var next = allStates.get(nextState);
            if (next == null) {
                // Never publish a null state: threads blocked in join() are woken
                // on every state change and dereference the state.
                logger.error("Unknown state {} requested in state machine {}", nextState, queueName);
                setErrorState();
                return;
            }

            setState(next);

            // Work on the local reference; the field may be changed by another thread
            if (!next.isFinal()) {
                var transition = next.next(message);
                smOutbox.notify(transition.state(), transition.message());
            }
        }
        catch (Exception e) {
            logger.error("Error in state machine transition", e);
            setErrorState();
        }
    }

    private void setErrorState() {
        setState(errorState);
    }

    /** Update the current state and wake up any threads waiting in join(). */
    private void setState(MachineState newState) {
        synchronized (this) {
            state = newState;
            notifyAll();
        }
    }

    /** Feeds every notification arriving in the inbox into the state machine. */
    private class StateEventSubscription implements MqSubscription {

        @Override
        public boolean filter(MqMessage rawMessage) {
            return true;
        }

        @Override
        public MqInboxResponse onRequest(MqMessage msg) {
            // The state machine only reacts to notifications. Answer with an explicit
            // error instead of null, which the caller would have to dereference.
            return MqInboxResponse.err();
        }

        @Override
        public void onNotification(MqMessage msg) {
            onStateTransition(msg.function(), msg.payload());
        }
    }
}

View File

@ -0,0 +1,14 @@
package nu.marginalia.mqsm.state;
/** Final state named "ERROR". It has no successor: asking for the next
 * transition is an unsupported operation. */
public class ErrorState implements MachineState {
    private static final String NAME = "ERROR";

    @Override
    public String name() {
        return NAME;
    }

    @Override
    public StateTransition next(String message) {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean isFinal() {
        return true;
    }
}

View File

@ -0,0 +1,14 @@
package nu.marginalia.mqsm.state;
/** Final state named "END". It has no successor: asking for the next
 * transition is an unsupported operation. */
public class FinalState implements MachineState {
    private static final String NAME = "END";

    @Override
    public String name() {
        return NAME;
    }

    @Override
    public StateTransition next(String message) {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean isFinal() {
        return true;
    }
}

View File

@ -0,0 +1,8 @@
package nu.marginalia.mqsm.state;
/** A single state in a state machine's state graph. */
public interface MachineState {
/** The name of the state. StateMachine indexes its registered states by this
 * name and looks the state up by the function name of an incoming message. */
String name();
/** Run the logic of this state and decide where to go next.
 *
 * @param message the payload of the message that triggered the state
 * @return the transition to perform; StateMachine sends its state name and
 *         message to the queue
 */
StateTransition next(String message);
/** Whether this is an end state. StateMachine does not call next() on a final
 * state, and its join() method returns once the current state is final. */
boolean isFinal();
}

View File

@ -0,0 +1,14 @@
package nu.marginalia.mqsm.state;
/** Non-final state named "RESUMING". It has no logic of its own: asking for
 * the next transition is an unsupported operation. */
public class ResumingState implements MachineState {
    private static final String NAME = "RESUMING";

    @Override
    public String name() {
        return NAME;
    }

    @Override
    public StateTransition next(String message) {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean isFinal() {
        return false;
    }
}

View File

@ -0,0 +1,11 @@
package nu.marginalia.mqsm.state;
/** A transition to a named state, together with the message passed to that state.
 *
 * @param state the name of the state to transition to
 * @param message the message for the target state
 */
public record StateTransition(String state, String message) {

    /** Create a transition to the named state with an empty message. */
    public static StateTransition to(String state) {
        return to(state, "");
    }

    /** Create a transition to the named state with the given message. */
    public static StateTransition to(String state, String message) {
        return new StateTransition(state, message);
    }
}

View File

@ -1,4 +1,4 @@
package nu.marginalia.mq.outbox;
package nu.marginalia.mq;
import nu.marginalia.mq.MqMessageState;

View File

@ -1,7 +1,6 @@
package nu.marginalia.mq.outbox;
package nu.marginalia.mq;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.mq.MqMessageState;
import org.junit.jupiter.api.Assertions;
import java.sql.SQLException;

View File

@ -4,6 +4,7 @@ import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.mq.MqMessage;
import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.MqTestUtil;
import nu.marginalia.mq.inbox.MqInboxResponse;
import nu.marginalia.mq.inbox.MqInbox;
import nu.marginalia.mq.inbox.MqSubscription;
@ -154,9 +155,12 @@ public class MqOutboxTest {
}
@Override
public MqInboxResponse handle(MqMessage msg) {
public MqInboxResponse onRequest(MqMessage msg) {
return MqInboxResponse.ok(response);
}
@Override
public void onNotification(MqMessage msg) { }
};
}
@ -168,9 +172,12 @@ public class MqOutboxTest {
}
@Override
public MqInboxResponse handle(MqMessage msg) {
public MqInboxResponse onRequest(MqMessage msg) {
return MqInboxResponse.ok(msg.payload());
}
@Override
public void onNotification(MqMessage msg) {}
};
}

View File

@ -1,9 +1,9 @@
package nu.marginalia.mq.outbox;
package nu.marginalia.mq.persistence;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.persistence.MqPersistence;
import nu.marginalia.mq.MqTestUtil;
import org.junit.jupiter.api.*;
import org.testcontainers.containers.MariaDBContainer;
import org.testcontainers.junit.jupiter.Container;

View File

@ -0,0 +1,174 @@
package nu.marginalia.mqsm;
import com.google.gson.GsonBuilder;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.mq.MqMessageRow;
import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.MqTestUtil;
import nu.marginalia.mq.persistence.MqPersistence;
import org.junit.jupiter.api.*;
import org.testcontainers.containers.MariaDBContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import java.util.List;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.assertEquals;
/** Integration tests for StateMachine, run against a MariaDB test container.
 * Every test uses a fresh, randomly named inbox. */
@Tag("slow")
@Testcontainers
public class StateMachineTest {
    @Container
    static MariaDBContainer<?> mariaDBContainer = new MariaDBContainer<>("mariadb")
            .withDatabaseName("WMSA_prod")
            .withUsername("wmsa")
            .withPassword("wmsa")
            .withInitScript("sql/current/11-message-queue.sql")
            .withNetworkAliases("mariadb");

    static HikariDataSource dataSource;
    static MqPersistence persistence;

    private String inboxId;

    @BeforeAll
    public static void setUpAll() {
        var hikariConfig = new HikariConfig();
        hikariConfig.setJdbcUrl(mariaDBContainer.getJdbcUrl());
        hikariConfig.setUsername("wmsa");
        hikariConfig.setPassword("wmsa");

        dataSource = new HikariDataSource(hikariConfig);
        persistence = new MqPersistence(dataSource);
    }

    @AfterAll
    public static void tearDownAll() {
        dataSource.close();
    }

    @BeforeEach
    public void setUp() {
        inboxId = UUID.randomUUID().toString();
    }

    /** Start a machine, stop it part way through, and let a second machine
     * on the same inbox resume and finish the run. */
    @Test
    public void testStartStopStartStop() throws Exception {
        var firstRun = new StateMachine(persistence, inboxId, UUID.randomUUID());
        var factory = new StateFactory(new GsonBuilder().create());

        var initial = factory.create("INITIAL", () -> factory.transition("GREET", "World"));

        var greet = factory.create("GREET", String.class, name -> {
            System.out.println("Hello, " + name + "!");
            return factory.transition("COUNT-TO-FIVE", 0);
        });

        var countToFive = factory.create("COUNT-TO-FIVE", Integer.class, count -> {
            System.out.println(count);
            if (count < 5) {
                return factory.transition("COUNT-TO-FIVE", count + 1);
            } else {
                return factory.transition("END");
            }
        });

        firstRun.registerStates(initial, greet, countToFive);
        firstRun.init();

        Thread.sleep(300);

        firstRun.stop();

        var secondRun = new StateMachine(persistence, inboxId, UUID.randomUUID());
        secondRun.registerStates(initial, greet, countToFive);
        secondRun.resume();
        secondRun.join();
        secondRun.stop();

        MqTestUtil.getMessages(dataSource, inboxId).forEach(System.out::println);
    }

    /** Resume from a queue whose last message has not been acknowledged. */
    @Test
    public void smResumeFromNew() throws Exception {
        var sm = newLinearMachine();

        persistence.sendNewMessage(inboxId, null, "B", "", null);

        runToCompletion(sm);

        assertEquals(List.of("B", "C", "END"), recordedFunctions());
    }

    /** Resume from a queue whose last message has already been acknowledged. */
    @Test
    public void smResumeFromAck() throws Exception {
        var sm = newLinearMachine();

        long id = persistence.sendNewMessage(inboxId, null, "B", "", null);
        persistence.updateMessageState(id, MqMessageState.ACK);

        runToCompletion(sm);

        assertEquals(List.of("B", "C", "END"), recordedFunctions());
    }

    /** Resuming from an empty queue behaves like a fresh start. */
    @Test
    public void smResumeEmptyQueue() throws Exception {
        var sm = newLinearMachine();

        runToCompletion(sm);

        assertEquals(List.of("INITIAL", "A", "B", "C", "END"), recordedFunctions());
    }

    /** Build a state machine on the current inbox with the linear graph
     * INITIAL -> A -> B -> C -> END. */
    private StateMachine newLinearMachine() {
        var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
        var factory = new StateFactory(new GsonBuilder().create());

        sm.registerStates(
                factory.create("INITIAL", () -> factory.transition("A")),
                factory.create("A", () -> factory.transition("B")),
                factory.create("B", () -> factory.transition("C")),
                factory.create("C", () -> factory.transition("END"))
        );

        return sm;
    }

    /** Resume the machine, wait for it to reach a final state, then stop it. */
    private void runToCompletion(StateMachine sm) throws Exception {
        sm.resume();
        sm.join();
        sm.stop();
    }

    /** Print every message in the current inbox and return their function names. */
    private List<String> recordedFunctions() throws Exception {
        return MqTestUtil.getMessages(dataSource, inboxId)
                .stream()
                .peek(System.out::println)
                .map(MqMessageRow::function)
                .toList();
    }
}