diff --git a/code/common/message-queue/build.gradle b/code/common/message-queue/build.gradle index 84ea9651..d71ca1d4 100644 --- a/code/common/message-queue/build.gradle +++ b/code/common/message-queue/build.gradle @@ -19,6 +19,7 @@ dependencies { implementation libs.spark implementation libs.guice + implementation libs.gson implementation libs.rxjava implementation libs.bundles.prometheus diff --git a/code/common/message-queue/readme.md b/code/common/message-queue/readme.md index 68ae2825..20e59642 100644 --- a/code/common/message-queue/readme.md +++ b/code/common/message-queue/readme.md @@ -1,5 +1,8 @@ # Message Queue -Implements a message queue using mariadb. +Implements resilient message queueing for the application, +as well as a finite state machine library backed by the +message queue that enables long-running tasks that outlive +the execution lifespan of the involved processes. ![Message States](msgstate.svg) \ No newline at end of file diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/MqMessage.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/MqMessage.java index 5f4c11aa..df0c4839 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mq/MqMessage.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/MqMessage.java @@ -5,6 +5,7 @@ public record MqMessage( long relatedId, String function, String payload, - MqMessageState state + MqMessageState state, + boolean expectsResponse ) { } diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqInbox.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqInbox.java index 7d94b327..00b30cad 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqInbox.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqInbox.java @@ -15,6 +15,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; public class MqInbox { private final Logger logger = LoggerFactory.getLogger(MqInbox.class); @@ -26,7 +27,7 @@ public class MqInbox { private volatile boolean run = true; - private final int pollIntervalMs = Integer.getInteger("mq.inbox.poll-interval-ms", 1000); + private final int pollIntervalMs = Integer.getInteger("mq.inbox.poll-interval-ms", 100); private final List eventSubscribers = new ArrayList<>(); private final LinkedBlockingQueue queue = new LinkedBlockingQueue<>(32); @@ -114,27 +115,52 @@ public class MqInbox { private void handleMessageWithSubscriber(MqSubscription subscriber, MqMessage msg) { - threadPool.execute(() -> { - try { - final var rsp = subscriber.handle(msg); - - sendResponse(msg, rsp.state(), rsp.message()); - } catch (Exception ex) { - logger.error("Message Queue subscriber threw exception", ex); - sendResponse(msg, MqMessageState.ERR); - } - }); + if (msg.expectsResponse()) { + threadPool.execute(() -> respondToMessage(subscriber, msg)); + } + else { + threadPool.execute(() -> acknowledgeNotification(subscriber, msg)); + } } - private void sendResponse(MqMessage msg, MqMessageState mqMessageState) { + private void respondToMessage(MqSubscription subscriber, MqMessage msg) { try { - persistence.updateMessageState(msg.msgId(), mqMessageState); + final var rsp = subscriber.onRequest(msg); + sendResponse(msg, rsp.state(), rsp.message()); + } catch (Exception ex) { + logger.error("Message Queue subscriber threw exception", ex); + sendResponse(msg, MqMessageState.ERR); + } + } + + private void acknowledgeNotification(MqSubscription subscriber, MqMessage msg) { + try { + subscriber.onNotification(msg); + updateMessageState(msg, MqMessageState.OK); + } catch (Exception ex) { + logger.error("Message Queue subscriber threw exception", ex); + updateMessageState(msg, MqMessageState.ERR); + } + } + + private void sendResponse(MqMessage msg, MqMessageState state) { + try { + persistence.updateMessageState(msg.msgId(), state); } catch (SQLException ex) { logger.error("Failed to update message state", ex); } } + private void updateMessageState(MqMessage msg, MqMessageState state) { + try { + persistence.updateMessageState(msg.msgId(), state); + } + catch (SQLException ex2) { + logger.error("Failed to update message state", ex2); + } + } + private void sendResponse(MqMessage msg, MqMessageState mqMessageState, String response) { try { persistence.sendResponse(msg.msgId(), mqMessageState, response); @@ -159,14 +185,25 @@ public class MqInbox { } private Collection pollInbox(long tick) { - try { - return persistence.pollInbox(inboxName, instanceUUID, tick); - } - catch (SQLException ex) { - logger.error("Failed to poll inbox", ex); - return List.of(); - } - } + try { + return persistence.pollInbox(inboxName, instanceUUID, tick); + } + catch (SQLException ex) { + logger.error("Failed to poll inbox", ex); + return List.of(); + } + } + + /** Retrieve the last N messages from the inbox. */ + public List replay(int lastN) { + try { + return persistence.lastNMessages(inboxName, lastN); + } + catch (SQLException ex) { + logger.error("Failed to replay inbox", ex); + return List.of(); + } + } private class MqInboxShredder implements MqSubscription { @@ -177,9 +214,14 @@ public class MqInbox { } @Override - public MqInboxResponse handle(MqMessage msg) { + public MqInboxResponse onRequest(MqMessage msg) { logger.warn("Unhandled message {}", msg.msgId()); return MqInboxResponse.err(); } + + @Override + public void onNotification(MqMessage msg) { + logger.warn("Unhandled message {}", msg.msgId()); + } } } diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqSubscription.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqSubscription.java index ce52a26b..417b7b35 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqSubscription.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqSubscription.java @@ -3,7 +3,12 @@ package nu.marginalia.mq.inbox; import nu.marginalia.mq.MqMessage; public interface MqSubscription { + /** Return true if this subscription should handle the message. */ boolean filter(MqMessage rawMessage); - MqInboxResponse handle(MqMessage msg); + /** Handle the message and return a response. */ + MqInboxResponse onRequest(MqMessage msg); + + /** Handle a message with no reply address */ + void onNotification(MqMessage msg); } diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/outbox/MqOutbox.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/outbox/MqOutbox.java index e4fa2e23..e8faa0ab 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mq/outbox/MqOutbox.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/outbox/MqOutbox.java @@ -6,7 +6,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.sql.SQLException; -import java.time.Duration; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; @@ -21,7 +20,7 @@ public class MqOutbox { private final ConcurrentHashMap pendingRequests = new ConcurrentHashMap<>(); private final ConcurrentHashMap pendingResponses = new ConcurrentHashMap<>(); - private final int pollIntervalMs = Integer.getInteger("mq.outbox.poll-interval-ms", 1000); + private final int pollIntervalMs = Integer.getInteger("mq.outbox.poll-interval-ms", 100); private final Thread pollThread; private volatile boolean run = true; @@ -103,5 +102,8 @@ public class MqOutbox { } } + public long notify(String function, String payload) throws Exception { + return persistence.sendNewMessage(inboxName, null, function, payload, null); + } } \ No newline at end of file diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/persistence/MqPersistence.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/persistence/MqPersistence.java index 92fffb51..d5356c55 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mq/persistence/MqPersistence.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/persistence/MqPersistence.java @@ -1,5 +1,6 @@ package nu.marginalia.mq.persistence; +import com.google.common.collect.Lists; import com.google.inject.Inject; import com.google.inject.Singleton; import com.zaxxer.hikari.HikariDataSource; @@ -164,7 +165,7 @@ public class MqPersistence { try (var conn = dataSource.getConnection(); var queryStmt = conn.prepareStatement(""" - SELECT ID, RELATED_ID, FUNCTION, PAYLOAD, STATE FROM PROC_MESSAGE + SELECT ID, RELATED_ID, FUNCTION, PAYLOAD, STATE, SENDER_INBOX FROM PROC_MESSAGE WHERE OWNER_INSTANCE=? AND OWNER_TICK=? """) ) { @@ -182,8 +183,9 @@ public class MqPersistence { String payload = rs.getString(4); MqMessageState state = MqMessageState.valueOf(rs.getString(5)); + boolean expectsResponse = rs.getBoolean(6); - var msg = new MqMessage(msgId, relatedId, function, payload, state); + var msg = new MqMessage(msgId, relatedId, function, payload, state, expectsResponse); messages.add(msg); } @@ -226,7 +228,7 @@ public class MqPersistence { MqMessageState state = MqMessageState.valueOf(rs.getString(5)); - var msg = new MqMessage(msgId, relatedId, function, payload, state); + var msg = new MqMessage(msgId, relatedId, function, payload, state, false); messages.add(msg); } @@ -234,4 +236,38 @@ public class MqPersistence { return messages; } } + + public List lastNMessages(String inboxName, int lastN) throws SQLException { + try (var conn = dataSource.getConnection(); + var stmt = conn.prepareStatement(""" + SELECT ID, RELATED_ID, FUNCTION, PAYLOAD, STATE, SENDER_INBOX FROM PROC_MESSAGE + WHERE RECIPIENT_INBOX = ? + ORDER BY ID DESC LIMIT ? + """)) { + + stmt.setString(1, inboxName); + stmt.setInt(2, lastN); + List messages = new ArrayList<>(lastN); + + var rs = stmt.executeQuery(); + while (rs.next()) { + long msgId = rs.getLong(1); + long relatedId = rs.getLong(2); + + String function = rs.getString(3); + String payload = rs.getString(4); + + MqMessageState state = MqMessageState.valueOf(rs.getString(5)); + boolean expectsResponse = rs.getBoolean(6); + + var msg = new MqMessage(msgId, relatedId, function, payload, state, expectsResponse); + + messages.add(msg); + } + + Lists.reverse(messages); + return messages; + } + + } } diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateFactory.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateFactory.java new file mode 100644 index 00000000..8dccde4b --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateFactory.java @@ -0,0 +1,66 @@ +package nu.marginalia.mqsm; + +import com.google.gson.Gson; +import com.google.inject.Inject; +import com.google.inject.Singleton; +import nu.marginalia.mqsm.state.MachineState; +import nu.marginalia.mqsm.state.StateTransition; + +import java.util.function.Function; +import java.util.function.Supplier; + +@Singleton +public class StateFactory { + private final Gson gson; + + @Inject + public StateFactory(Gson gson) { + this.gson = gson; + } + + public MachineState create(String name, Class param, Function logic) { + return new MachineState() { + @Override + public String name() { + return name; + } + + @Override + public StateTransition next(String message) { + return logic.apply(gson.fromJson(message, param)); + } + + @Override + public boolean isFinal() { + return false; + } + }; + } + + public MachineState create(String name, Supplier logic) { + return new MachineState() { + @Override + public String name() { + return name; + } + + @Override + public StateTransition next(String message) { + return logic.get(); + } + + @Override + public boolean isFinal() { + return false; + } + }; + } + + public StateTransition transition(String state) { + return StateTransition.to(state); + } + + public StateTransition transition(String state, Object message) { + return StateTransition.to(state, gson.toJson(message)); + } +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateMachine.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateMachine.java new file mode 100644 index 00000000..cb7d1f33 --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateMachine.java @@ -0,0 +1,176 @@ +package nu.marginalia.mqsm; + +import nu.marginalia.mq.MqMessage; +import nu.marginalia.mq.MqMessageState; +import nu.marginalia.mq.inbox.MqInbox; +import nu.marginalia.mq.inbox.MqInboxResponse; +import nu.marginalia.mq.inbox.MqSubscription; +import nu.marginalia.mq.outbox.MqOutbox; +import nu.marginalia.mq.persistence.MqPersistence; +import nu.marginalia.mqsm.state.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +/** A state machine that can be used to implement a finite state machine + * using a message queue as the persistence layer. The state machine is + * resilient to crashes and can be resumed from the last state. + */ +public class StateMachine { + private final Logger logger = LoggerFactory.getLogger(StateMachine.class); + + private final MqInbox smInbox; + private final MqOutbox smOutbox; + private final String queueName; + private MachineState state; + + private final MachineState errorState = new ErrorState(); + private final MachineState finalState = new FinalState(); + private final MachineState resumingState = new ResumingState(); + + private final Map allStates = new HashMap<>(); + + public StateMachine(MqPersistence persistence, String queueName, UUID instanceUUID) { + this.queueName = queueName; + + smInbox = new MqInbox(persistence, queueName, instanceUUID); + smOutbox = new MqOutbox(persistence, queueName, instanceUUID); + + smInbox.subscribe(new StateEventSubscription()); + + registerStates(List.of(errorState, finalState, resumingState)); + } + + /** Register the state graph */ + public void registerStates(MachineState... states) { + if (state != null) { + throw new IllegalStateException("Cannot register states after state machine has been initialized"); + } + + for (var state : states) { + allStates.put(state.name(), state); + } + } + + /** Register the state graph */ + public void registerStates(List states) { + for (var state : states) { + allStates.put(state.name(), state); + } + } + + /** Wait for the state machine to reach a final state. + * (possibly forever, halting problem and so on) + */ + public void join() throws InterruptedException { + synchronized (this) { + if (null == state) + return; + + while (!state.isFinal()) { + wait(); + } + } + } + + + /** Initialize the state machine. */ + public void init() throws Exception { + var transition = StateTransition.to("INITIAL"); + + synchronized (this) { + this.state = allStates.get(transition.state()); + notifyAll(); + } + + smInbox.start(); + smOutbox.notify(transition.state(), transition.message()); + } + + /** Resume the state machine from the last known state. */ + public void resume() throws Exception { + + if (state == null) { + var messages = smInbox.replay(1); + + if (messages.isEmpty()) { + init(); + } else { + var firstMessage = messages.get(0); + + smInbox.start(); + + logger.info("Resuming state machine from {}({})/{}", firstMessage.function(), firstMessage.payload(), firstMessage.state()); + + if (firstMessage.state() == MqMessageState.NEW) { + // The message is not acknowledged, so starting the inbox will trigger a state transition + // + // We still need to set a state here so that the join() method works + + state = resumingState; + } else { + // The message is already acknowledged, so we replay the last state + onStateTransition(firstMessage.function(), firstMessage.payload()); + } + } + } + } + + public void stop() throws InterruptedException { + smInbox.stop(); + smOutbox.stop(); + } + + private void onStateTransition(String nextState, String message) { + try { + logger.info("FSM State change in {}: {}->{}({})", + queueName, + state == null ? "[null]" : state.name(), + nextState, + message); + + synchronized (this) { + this.state = allStates.get(nextState); + notifyAll(); + } + + if (!state.isFinal()) { + var transition = state.next(message); + smOutbox.notify(transition.state(), transition.message()); + } + } + catch (Exception e) { + logger.error("Error in state machine transition", e); + setErrorState(); + } + } + + private void setErrorState() { + synchronized (this) { + state = errorState; + notifyAll(); + } + } + + private class StateEventSubscription implements MqSubscription { + + @Override + public boolean filter(MqMessage rawMessage) { + return true; + } + + @Override + public MqInboxResponse onRequest(MqMessage msg) { + return null; + } + + @Override + public void onNotification(MqMessage msg) { + onStateTransition(msg.function(), msg.payload()); + } + } +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ErrorState.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ErrorState.java new file mode 100644 index 00000000..4f1fef96 --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ErrorState.java @@ -0,0 +1,14 @@ +package nu.marginalia.mqsm.state; + +public class ErrorState implements MachineState { + @Override + public String name() { return "ERROR"; } + + @Override + public StateTransition next(String message) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isFinal() { return true; } +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/FinalState.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/FinalState.java new file mode 100644 index 00000000..5ee7d435 --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/FinalState.java @@ -0,0 +1,14 @@ +package nu.marginalia.mqsm.state; + +public class FinalState implements MachineState { + @Override + public String name() { return "END"; } + + @Override + public StateTransition next(String message) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isFinal() { return true; } +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/MachineState.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/MachineState.java new file mode 100644 index 00000000..4bba33cf --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/MachineState.java @@ -0,0 +1,8 @@ +package nu.marginalia.mqsm.state; + +public interface MachineState { + String name(); + StateTransition next(String message); + + boolean isFinal(); +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ResumingState.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ResumingState.java new file mode 100644 index 00000000..36a474e2 --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ResumingState.java @@ -0,0 +1,14 @@ +package nu.marginalia.mqsm.state; + +public class ResumingState implements MachineState { + @Override + public String name() { return "RESUMING"; } + + @Override + public StateTransition next(String message) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isFinal() { return false; } +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/StateTransition.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/StateTransition.java new file mode 100644 index 00000000..6ca5d387 --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/StateTransition.java @@ -0,0 +1,11 @@ +package nu.marginalia.mqsm.state; + +public record StateTransition(String state, String message) { + public static StateTransition to(String state) { + return new StateTransition(state, ""); + } + + public static StateTransition to(String state, String message) { + return new StateTransition(state, message); + } +} diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqMessageRow.java b/code/common/message-queue/src/test/java/nu/marginalia/mq/MqMessageRow.java similarity index 92% rename from code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqMessageRow.java rename to code/common/message-queue/src/test/java/nu/marginalia/mq/MqMessageRow.java index 933cdb62..ef12105a 100644 --- a/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqMessageRow.java +++ b/code/common/message-queue/src/test/java/nu/marginalia/mq/MqMessageRow.java @@ -1,4 +1,4 @@ -package nu.marginalia.mq.outbox; +package nu.marginalia.mq; import nu.marginalia.mq.MqMessageState; diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqTestUtil.java b/code/common/message-queue/src/test/java/nu/marginalia/mq/MqTestUtil.java similarity index 96% rename from code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqTestUtil.java rename to code/common/message-queue/src/test/java/nu/marginalia/mq/MqTestUtil.java index 3fee8b20..dcefaf1a 100644 --- a/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqTestUtil.java +++ b/code/common/message-queue/src/test/java/nu/marginalia/mq/MqTestUtil.java @@ -1,7 +1,6 @@ -package nu.marginalia.mq.outbox; +package nu.marginalia.mq; import com.zaxxer.hikari.HikariDataSource; -import nu.marginalia.mq.MqMessageState; import org.junit.jupiter.api.Assertions; import java.sql.SQLException; diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqOutboxTest.java b/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqOutboxTest.java index 789aec15..6dc51f2d 100644 --- a/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqOutboxTest.java +++ b/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqOutboxTest.java @@ -4,6 +4,7 @@ import com.zaxxer.hikari.HikariConfig; import com.zaxxer.hikari.HikariDataSource; import nu.marginalia.mq.MqMessage; import nu.marginalia.mq.MqMessageState; +import nu.marginalia.mq.MqTestUtil; import nu.marginalia.mq.inbox.MqInboxResponse; import nu.marginalia.mq.inbox.MqInbox; import nu.marginalia.mq.inbox.MqSubscription; @@ -154,9 +155,12 @@ public class MqOutboxTest { } @Override - public MqInboxResponse handle(MqMessage msg) { + public MqInboxResponse onRequest(MqMessage msg) { return MqInboxResponse.ok(response); } + + @Override + public void onNotification(MqMessage msg) { } }; } @@ -168,9 +172,12 @@ public class MqOutboxTest { } @Override - public MqInboxResponse handle(MqMessage msg) { + public MqInboxResponse onRequest(MqMessage msg) { return MqInboxResponse.ok(msg.payload()); } + + @Override + public void onNotification(MqMessage msg) {} }; } diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqPersistenceTest.java b/code/common/message-queue/src/test/java/nu/marginalia/mq/persistence/MqPersistenceTest.java similarity index 98% rename from code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqPersistenceTest.java rename to code/common/message-queue/src/test/java/nu/marginalia/mq/persistence/MqPersistenceTest.java index 590ff64b..7166531d 100644 --- a/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqPersistenceTest.java +++ b/code/common/message-queue/src/test/java/nu/marginalia/mq/persistence/MqPersistenceTest.java @@ -1,9 +1,9 @@ -package nu.marginalia.mq.outbox; +package nu.marginalia.mq.persistence; import com.zaxxer.hikari.HikariConfig; import com.zaxxer.hikari.HikariDataSource; import nu.marginalia.mq.MqMessageState; -import nu.marginalia.mq.persistence.MqPersistence; +import nu.marginalia.mq.MqTestUtil; import org.junit.jupiter.api.*; import org.testcontainers.containers.MariaDBContainer; import org.testcontainers.junit.jupiter.Container; diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineTest.java b/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineTest.java new file mode 100644 index 00000000..06cc658c --- /dev/null +++ b/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineTest.java @@ -0,0 +1,174 @@ +package nu.marginalia.mqsm; + +import com.google.gson.GsonBuilder; +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import nu.marginalia.mq.MqMessageRow; +import nu.marginalia.mq.MqMessageState; +import nu.marginalia.mq.MqTestUtil; +import nu.marginalia.mq.persistence.MqPersistence; +import org.junit.jupiter.api.*; +import org.testcontainers.containers.MariaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.List; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +@Tag("slow") +@Testcontainers +public class StateMachineTest { + @Container + static MariaDBContainer mariaDBContainer = new MariaDBContainer<>("mariadb") + .withDatabaseName("WMSA_prod") + .withUsername("wmsa") + .withPassword("wmsa") + .withInitScript("sql/current/11-message-queue.sql") + .withNetworkAliases("mariadb"); + + static HikariDataSource dataSource; + static MqPersistence persistence; + private String inboxId; + + @BeforeEach + public void setUp() { + inboxId = UUID.randomUUID().toString(); + } + @BeforeAll + public static void setUpAll() { + HikariConfig config = new HikariConfig(); + config.setJdbcUrl(mariaDBContainer.getJdbcUrl()); + config.setUsername("wmsa"); + config.setPassword("wmsa"); + + dataSource = new HikariDataSource(config); + persistence = new MqPersistence(dataSource); + } + + @AfterAll + public static void tearDownAll() { + dataSource.close(); + } + + @Test + public void testStartStopStartStop() throws Exception { + var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); + var stateFactory = new StateFactory(new GsonBuilder().create()); + + var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("GREET", "World")); + + var greet = stateFactory.create("GREET", String.class, (String message) -> { + System.out.println("Hello, " + message + "!"); + return stateFactory.transition("COUNT-TO-FIVE", 0); + }); + + var ctf = stateFactory.create("COUNT-TO-FIVE", Integer.class, (Integer count) -> { + System.out.println(count); + if (count < 5) { + return stateFactory.transition("COUNT-TO-FIVE", count + 1); + } else { + return stateFactory.transition("END"); + } + }); + + sm.registerStates(initial, greet, ctf); + + sm.init(); + + Thread.sleep(300); + sm.stop(); + + var sm2 = new StateMachine(persistence, inboxId, UUID.randomUUID()); + sm2.registerStates(initial, greet, ctf); + sm2.resume(); + sm2.join(); + sm2.stop(); + + MqTestUtil.getMessages(dataSource, inboxId).forEach(System.out::println); + } + + @Test + public void smResumeFromNew() throws Exception { + var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); + var stateFactory = new StateFactory(new GsonBuilder().create()); + + var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("A")); + var stateA = stateFactory.create("A", () -> stateFactory.transition("B")); + var stateB = stateFactory.create("B", () -> stateFactory.transition("C")); + var stateC = stateFactory.create("C", () -> stateFactory.transition("END")); + + sm.registerStates(initial, stateA, stateB, stateC); + persistence.sendNewMessage(inboxId, null,"B", "", null); + + sm.resume(); + + sm.join(); + sm.stop(); + + List states = MqTestUtil.getMessages(dataSource, inboxId) + .stream() + .peek(System.out::println) + .map(MqMessageRow::function) + .toList(); + + assertEquals(List.of("B", "C", "END"), states); + } + + @Test + public void smResumeFromAck() throws Exception { + var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); + var stateFactory = new StateFactory(new GsonBuilder().create()); + + var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("A")); + var stateA = stateFactory.create("A", () -> stateFactory.transition("B")); + var stateB = stateFactory.create("B", () -> stateFactory.transition("C")); + var stateC = stateFactory.create("C", () -> stateFactory.transition("END")); + + sm.registerStates(initial, stateA, stateB, stateC); + + long id = persistence.sendNewMessage(inboxId, null,"B", "", null); + persistence.updateMessageState(id, MqMessageState.ACK); + + sm.resume(); + + sm.join(); + sm.stop(); + + List states = MqTestUtil.getMessages(dataSource, inboxId) + .stream() + .peek(System.out::println) + .map(MqMessageRow::function) + .toList(); + + assertEquals(List.of("B", "C", "END"), states); + } + + + @Test + public void smResumeEmptyQueue() throws Exception { + var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); + var stateFactory = new StateFactory(new GsonBuilder().create()); + + var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("A")); + var stateA = stateFactory.create("A", () -> stateFactory.transition("B")); + var stateB = stateFactory.create("B", () -> stateFactory.transition("C")); + var stateC = stateFactory.create("C", () -> stateFactory.transition("END")); + + sm.registerStates(initial, stateA, stateB, stateC); + + sm.resume(); + + sm.join(); + sm.stop(); + + List states = MqTestUtil.getMessages(dataSource, inboxId) + .stream() + .peek(System.out::println) + .map(MqMessageRow::function) + .toList(); + + assertEquals(List.of("INITIAL", "A", "B", "C", "END"), states); + } +}