Message queue based state machine
This commit is contained in:
parent
31ae71c7d6
commit
2ae0b8c159
@ -19,6 +19,7 @@ dependencies {
|
||||
|
||||
implementation libs.spark
|
||||
implementation libs.guice
|
||||
implementation libs.gson
|
||||
implementation libs.rxjava
|
||||
|
||||
implementation libs.bundles.prometheus
|
||||
|
@ -1,5 +1,8 @@
|
||||
# Message Queue
|
||||
|
||||
Implements a message queue using mariadb.
|
||||
Implements resilient message queueing for the application,
|
||||
as well as a finite state machine library backed by the
|
||||
message queue that enables long-running tasks that outlive
|
||||
the execution lifespan of the involved processes.
|
||||
|
||||
![Message States](msgstate.svg)
|
@ -5,6 +5,7 @@ public record MqMessage(
|
||||
long relatedId,
|
||||
String function,
|
||||
String payload,
|
||||
MqMessageState state
|
||||
MqMessageState state,
|
||||
boolean expectsResponse
|
||||
) {
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public class MqInbox {
|
||||
private final Logger logger = LoggerFactory.getLogger(MqInbox.class);
|
||||
@ -26,7 +27,7 @@ public class MqInbox {
|
||||
|
||||
private volatile boolean run = true;
|
||||
|
||||
private final int pollIntervalMs = Integer.getInteger("mq.inbox.poll-interval-ms", 1000);
|
||||
private final int pollIntervalMs = Integer.getInteger("mq.inbox.poll-interval-ms", 100);
|
||||
private final List<MqSubscription> eventSubscribers = new ArrayList<>();
|
||||
private final LinkedBlockingQueue<MqMessage> queue = new LinkedBlockingQueue<>(32);
|
||||
|
||||
@ -114,27 +115,52 @@ public class MqInbox {
|
||||
|
||||
private void handleMessageWithSubscriber(MqSubscription subscriber, MqMessage msg) {
|
||||
|
||||
threadPool.execute(() -> {
|
||||
try {
|
||||
final var rsp = subscriber.handle(msg);
|
||||
|
||||
sendResponse(msg, rsp.state(), rsp.message());
|
||||
} catch (Exception ex) {
|
||||
logger.error("Message Queue subscriber threw exception", ex);
|
||||
sendResponse(msg, MqMessageState.ERR);
|
||||
}
|
||||
});
|
||||
if (msg.expectsResponse()) {
|
||||
threadPool.execute(() -> respondToMessage(subscriber, msg));
|
||||
}
|
||||
else {
|
||||
threadPool.execute(() -> acknowledgeNotification(subscriber, msg));
|
||||
}
|
||||
}
|
||||
|
||||
private void sendResponse(MqMessage msg, MqMessageState mqMessageState) {
|
||||
private void respondToMessage(MqSubscription subscriber, MqMessage msg) {
|
||||
try {
|
||||
persistence.updateMessageState(msg.msgId(), mqMessageState);
|
||||
final var rsp = subscriber.onRequest(msg);
|
||||
sendResponse(msg, rsp.state(), rsp.message());
|
||||
} catch (Exception ex) {
|
||||
logger.error("Message Queue subscriber threw exception", ex);
|
||||
sendResponse(msg, MqMessageState.ERR);
|
||||
}
|
||||
}
|
||||
|
||||
private void acknowledgeNotification(MqSubscription subscriber, MqMessage msg) {
|
||||
try {
|
||||
subscriber.onNotification(msg);
|
||||
updateMessageState(msg, MqMessageState.OK);
|
||||
} catch (Exception ex) {
|
||||
logger.error("Message Queue subscriber threw exception", ex);
|
||||
updateMessageState(msg, MqMessageState.ERR);
|
||||
}
|
||||
}
|
||||
|
||||
private void sendResponse(MqMessage msg, MqMessageState state) {
|
||||
try {
|
||||
persistence.updateMessageState(msg.msgId(), state);
|
||||
}
|
||||
catch (SQLException ex) {
|
||||
logger.error("Failed to update message state", ex);
|
||||
}
|
||||
}
|
||||
|
||||
private void updateMessageState(MqMessage msg, MqMessageState state) {
|
||||
try {
|
||||
persistence.updateMessageState(msg.msgId(), state);
|
||||
}
|
||||
catch (SQLException ex2) {
|
||||
logger.error("Failed to update message state", ex2);
|
||||
}
|
||||
}
|
||||
|
||||
private void sendResponse(MqMessage msg, MqMessageState mqMessageState, String response) {
|
||||
try {
|
||||
persistence.sendResponse(msg.msgId(), mqMessageState, response);
|
||||
@ -159,14 +185,25 @@ public class MqInbox {
|
||||
}
|
||||
|
||||
private Collection<MqMessage> pollInbox(long tick) {
|
||||
try {
|
||||
return persistence.pollInbox(inboxName, instanceUUID, tick);
|
||||
}
|
||||
catch (SQLException ex) {
|
||||
logger.error("Failed to poll inbox", ex);
|
||||
return List.of();
|
||||
}
|
||||
}
|
||||
try {
|
||||
return persistence.pollInbox(inboxName, instanceUUID, tick);
|
||||
}
|
||||
catch (SQLException ex) {
|
||||
logger.error("Failed to poll inbox", ex);
|
||||
return List.of();
|
||||
}
|
||||
}
|
||||
|
||||
/** Retrieve the last N messages from the inbox. */
|
||||
public List<MqMessage> replay(int lastN) {
|
||||
try {
|
||||
return persistence.lastNMessages(inboxName, lastN);
|
||||
}
|
||||
catch (SQLException ex) {
|
||||
logger.error("Failed to replay inbox", ex);
|
||||
return List.of();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private class MqInboxShredder implements MqSubscription {
|
||||
@ -177,9 +214,14 @@ public class MqInbox {
|
||||
}
|
||||
|
||||
@Override
|
||||
public MqInboxResponse handle(MqMessage msg) {
|
||||
public MqInboxResponse onRequest(MqMessage msg) {
|
||||
logger.warn("Unhandled message {}", msg.msgId());
|
||||
return MqInboxResponse.err();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNotification(MqMessage msg) {
|
||||
logger.warn("Unhandled message {}", msg.msgId());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3,7 +3,12 @@ package nu.marginalia.mq.inbox;
|
||||
import nu.marginalia.mq.MqMessage;
|
||||
|
||||
public interface MqSubscription {
|
||||
/** Return true if this subscription should handle the message. */
|
||||
boolean filter(MqMessage rawMessage);
|
||||
|
||||
MqInboxResponse handle(MqMessage msg);
|
||||
/** Handle the message and return a response. */
|
||||
MqInboxResponse onRequest(MqMessage msg);
|
||||
|
||||
/** Handle a message with no reply address */
|
||||
void onNotification(MqMessage msg);
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.sql.SQLException;
|
||||
import java.time.Duration;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
@ -21,7 +20,7 @@ public class MqOutbox {
|
||||
private final ConcurrentHashMap<Long, Long> pendingRequests = new ConcurrentHashMap<>();
|
||||
private final ConcurrentHashMap<Long, MqMessage> pendingResponses = new ConcurrentHashMap<>();
|
||||
|
||||
private final int pollIntervalMs = Integer.getInteger("mq.outbox.poll-interval-ms", 1000);
|
||||
private final int pollIntervalMs = Integer.getInteger("mq.outbox.poll-interval-ms", 100);
|
||||
private final Thread pollThread;
|
||||
|
||||
private volatile boolean run = true;
|
||||
@ -103,5 +102,8 @@ public class MqOutbox {
|
||||
}
|
||||
}
|
||||
|
||||
public long notify(String function, String payload) throws Exception {
|
||||
return persistence.sendNewMessage(inboxName, null, function, payload, null);
|
||||
}
|
||||
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
package nu.marginalia.mq.persistence;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.inject.Inject;
|
||||
import com.google.inject.Singleton;
|
||||
import com.zaxxer.hikari.HikariDataSource;
|
||||
@ -164,7 +165,7 @@ public class MqPersistence {
|
||||
|
||||
try (var conn = dataSource.getConnection();
|
||||
var queryStmt = conn.prepareStatement("""
|
||||
SELECT ID, RELATED_ID, FUNCTION, PAYLOAD, STATE FROM PROC_MESSAGE
|
||||
SELECT ID, RELATED_ID, FUNCTION, PAYLOAD, STATE, SENDER_INBOX FROM PROC_MESSAGE
|
||||
WHERE OWNER_INSTANCE=? AND OWNER_TICK=?
|
||||
""")
|
||||
) {
|
||||
@ -182,8 +183,9 @@ public class MqPersistence {
|
||||
String payload = rs.getString(4);
|
||||
|
||||
MqMessageState state = MqMessageState.valueOf(rs.getString(5));
|
||||
boolean expectsResponse = rs.getBoolean(6);
|
||||
|
||||
var msg = new MqMessage(msgId, relatedId, function, payload, state);
|
||||
var msg = new MqMessage(msgId, relatedId, function, payload, state, expectsResponse);
|
||||
|
||||
messages.add(msg);
|
||||
}
|
||||
@ -226,7 +228,7 @@ public class MqPersistence {
|
||||
|
||||
MqMessageState state = MqMessageState.valueOf(rs.getString(5));
|
||||
|
||||
var msg = new MqMessage(msgId, relatedId, function, payload, state);
|
||||
var msg = new MqMessage(msgId, relatedId, function, payload, state, false);
|
||||
|
||||
messages.add(msg);
|
||||
}
|
||||
@ -234,4 +236,38 @@ public class MqPersistence {
|
||||
return messages;
|
||||
}
|
||||
}
|
||||
|
||||
public List<MqMessage> lastNMessages(String inboxName, int lastN) throws SQLException {
|
||||
try (var conn = dataSource.getConnection();
|
||||
var stmt = conn.prepareStatement("""
|
||||
SELECT ID, RELATED_ID, FUNCTION, PAYLOAD, STATE, SENDER_INBOX FROM PROC_MESSAGE
|
||||
WHERE RECIPIENT_INBOX = ?
|
||||
ORDER BY ID DESC LIMIT ?
|
||||
""")) {
|
||||
|
||||
stmt.setString(1, inboxName);
|
||||
stmt.setInt(2, lastN);
|
||||
List<MqMessage> messages = new ArrayList<>(lastN);
|
||||
|
||||
var rs = stmt.executeQuery();
|
||||
while (rs.next()) {
|
||||
long msgId = rs.getLong(1);
|
||||
long relatedId = rs.getLong(2);
|
||||
|
||||
String function = rs.getString(3);
|
||||
String payload = rs.getString(4);
|
||||
|
||||
MqMessageState state = MqMessageState.valueOf(rs.getString(5));
|
||||
boolean expectsResponse = rs.getBoolean(6);
|
||||
|
||||
var msg = new MqMessage(msgId, relatedId, function, payload, state, expectsResponse);
|
||||
|
||||
messages.add(msg);
|
||||
}
|
||||
|
||||
Lists.reverse(messages);
|
||||
return messages;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,66 @@
|
||||
package nu.marginalia.mqsm;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.google.inject.Inject;
|
||||
import com.google.inject.Singleton;
|
||||
import nu.marginalia.mqsm.state.MachineState;
|
||||
import nu.marginalia.mqsm.state.StateTransition;
|
||||
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
@Singleton
|
||||
public class StateFactory {
|
||||
private final Gson gson;
|
||||
|
||||
@Inject
|
||||
public StateFactory(Gson gson) {
|
||||
this.gson = gson;
|
||||
}
|
||||
|
||||
public <T> MachineState create(String name, Class<T> param, Function<T, StateTransition> logic) {
|
||||
return new MachineState() {
|
||||
@Override
|
||||
public String name() {
|
||||
return name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public StateTransition next(String message) {
|
||||
return logic.apply(gson.fromJson(message, param));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isFinal() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public MachineState create(String name, Supplier<StateTransition> logic) {
|
||||
return new MachineState() {
|
||||
@Override
|
||||
public String name() {
|
||||
return name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public StateTransition next(String message) {
|
||||
return logic.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isFinal() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public StateTransition transition(String state) {
|
||||
return StateTransition.to(state);
|
||||
}
|
||||
|
||||
public StateTransition transition(String state, Object message) {
|
||||
return StateTransition.to(state, gson.toJson(message));
|
||||
}
|
||||
}
|
@ -0,0 +1,176 @@
|
||||
package nu.marginalia.mqsm;
|
||||
|
||||
import nu.marginalia.mq.MqMessage;
|
||||
import nu.marginalia.mq.MqMessageState;
|
||||
import nu.marginalia.mq.inbox.MqInbox;
|
||||
import nu.marginalia.mq.inbox.MqInboxResponse;
|
||||
import nu.marginalia.mq.inbox.MqSubscription;
|
||||
import nu.marginalia.mq.outbox.MqOutbox;
|
||||
import nu.marginalia.mq.persistence.MqPersistence;
|
||||
import nu.marginalia.mqsm.state.*;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
/** A state machine that can be used to implement a finite state machine
|
||||
* using a message queue as the persistence layer. The state machine is
|
||||
* resilient to crashes and can be resumed from the last state.
|
||||
*/
|
||||
public class StateMachine {
|
||||
private final Logger logger = LoggerFactory.getLogger(StateMachine.class);
|
||||
|
||||
private final MqInbox smInbox;
|
||||
private final MqOutbox smOutbox;
|
||||
private final String queueName;
|
||||
private MachineState state;
|
||||
|
||||
private final MachineState errorState = new ErrorState();
|
||||
private final MachineState finalState = new FinalState();
|
||||
private final MachineState resumingState = new ResumingState();
|
||||
|
||||
private final Map<String, MachineState> allStates = new HashMap<>();
|
||||
|
||||
public StateMachine(MqPersistence persistence, String queueName, UUID instanceUUID) {
|
||||
this.queueName = queueName;
|
||||
|
||||
smInbox = new MqInbox(persistence, queueName, instanceUUID);
|
||||
smOutbox = new MqOutbox(persistence, queueName, instanceUUID);
|
||||
|
||||
smInbox.subscribe(new StateEventSubscription());
|
||||
|
||||
registerStates(List.of(errorState, finalState, resumingState));
|
||||
}
|
||||
|
||||
/** Register the state graph */
|
||||
public void registerStates(MachineState... states) {
|
||||
if (state != null) {
|
||||
throw new IllegalStateException("Cannot register states after state machine has been initialized");
|
||||
}
|
||||
|
||||
for (var state : states) {
|
||||
allStates.put(state.name(), state);
|
||||
}
|
||||
}
|
||||
|
||||
/** Register the state graph */
|
||||
public void registerStates(List<MachineState> states) {
|
||||
for (var state : states) {
|
||||
allStates.put(state.name(), state);
|
||||
}
|
||||
}
|
||||
|
||||
/** Wait for the state machine to reach a final state.
|
||||
* (possibly forever, halting problem and so on)
|
||||
*/
|
||||
public void join() throws InterruptedException {
|
||||
synchronized (this) {
|
||||
if (null == state)
|
||||
return;
|
||||
|
||||
while (!state.isFinal()) {
|
||||
wait();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/** Initialize the state machine. */
|
||||
public void init() throws Exception {
|
||||
var transition = StateTransition.to("INITIAL");
|
||||
|
||||
synchronized (this) {
|
||||
this.state = allStates.get(transition.state());
|
||||
notifyAll();
|
||||
}
|
||||
|
||||
smInbox.start();
|
||||
smOutbox.notify(transition.state(), transition.message());
|
||||
}
|
||||
|
||||
/** Resume the state machine from the last known state. */
|
||||
public void resume() throws Exception {
|
||||
|
||||
if (state == null) {
|
||||
var messages = smInbox.replay(1);
|
||||
|
||||
if (messages.isEmpty()) {
|
||||
init();
|
||||
} else {
|
||||
var firstMessage = messages.get(0);
|
||||
|
||||
smInbox.start();
|
||||
|
||||
logger.info("Resuming state machine from {}({})/{}", firstMessage.function(), firstMessage.payload(), firstMessage.state());
|
||||
|
||||
if (firstMessage.state() == MqMessageState.NEW) {
|
||||
// The message is not acknowledged, so starting the inbox will trigger a state transition
|
||||
//
|
||||
// We still need to set a state here so that the join() method works
|
||||
|
||||
state = resumingState;
|
||||
} else {
|
||||
// The message is already acknowledged, so we replay the last state
|
||||
onStateTransition(firstMessage.function(), firstMessage.payload());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void stop() throws InterruptedException {
|
||||
smInbox.stop();
|
||||
smOutbox.stop();
|
||||
}
|
||||
|
||||
private void onStateTransition(String nextState, String message) {
|
||||
try {
|
||||
logger.info("FSM State change in {}: {}->{}({})",
|
||||
queueName,
|
||||
state == null ? "[null]" : state.name(),
|
||||
nextState,
|
||||
message);
|
||||
|
||||
synchronized (this) {
|
||||
this.state = allStates.get(nextState);
|
||||
notifyAll();
|
||||
}
|
||||
|
||||
if (!state.isFinal()) {
|
||||
var transition = state.next(message);
|
||||
smOutbox.notify(transition.state(), transition.message());
|
||||
}
|
||||
}
|
||||
catch (Exception e) {
|
||||
logger.error("Error in state machine transition", e);
|
||||
setErrorState();
|
||||
}
|
||||
}
|
||||
|
||||
private void setErrorState() {
|
||||
synchronized (this) {
|
||||
state = errorState;
|
||||
notifyAll();
|
||||
}
|
||||
}
|
||||
|
||||
private class StateEventSubscription implements MqSubscription {
|
||||
|
||||
@Override
|
||||
public boolean filter(MqMessage rawMessage) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MqInboxResponse onRequest(MqMessage msg) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNotification(MqMessage msg) {
|
||||
onStateTransition(msg.function(), msg.payload());
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
package nu.marginalia.mqsm.state;
|
||||
|
||||
public class ErrorState implements MachineState {
|
||||
@Override
|
||||
public String name() { return "ERROR"; }
|
||||
|
||||
@Override
|
||||
public StateTransition next(String message) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isFinal() { return true; }
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
package nu.marginalia.mqsm.state;
|
||||
|
||||
public class FinalState implements MachineState {
|
||||
@Override
|
||||
public String name() { return "END"; }
|
||||
|
||||
@Override
|
||||
public StateTransition next(String message) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isFinal() { return true; }
|
||||
}
|
@ -0,0 +1,8 @@
|
||||
package nu.marginalia.mqsm.state;
|
||||
|
||||
public interface MachineState {
|
||||
String name();
|
||||
StateTransition next(String message);
|
||||
|
||||
boolean isFinal();
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
package nu.marginalia.mqsm.state;
|
||||
|
||||
public class ResumingState implements MachineState {
|
||||
@Override
|
||||
public String name() { return "RESUMING"; }
|
||||
|
||||
@Override
|
||||
public StateTransition next(String message) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isFinal() { return false; }
|
||||
}
|
@ -0,0 +1,11 @@
|
||||
package nu.marginalia.mqsm.state;
|
||||
|
||||
public record StateTransition(String state, String message) {
|
||||
public static StateTransition to(String state) {
|
||||
return new StateTransition(state, "");
|
||||
}
|
||||
|
||||
public static StateTransition to(String state, String message) {
|
||||
return new StateTransition(state, message);
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package nu.marginalia.mq.outbox;
|
||||
package nu.marginalia.mq;
|
||||
|
||||
import nu.marginalia.mq.MqMessageState;
|
||||
|
@ -1,7 +1,6 @@
|
||||
package nu.marginalia.mq.outbox;
|
||||
package nu.marginalia.mq;
|
||||
|
||||
import com.zaxxer.hikari.HikariDataSource;
|
||||
import nu.marginalia.mq.MqMessageState;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
|
||||
import java.sql.SQLException;
|
@ -4,6 +4,7 @@ import com.zaxxer.hikari.HikariConfig;
|
||||
import com.zaxxer.hikari.HikariDataSource;
|
||||
import nu.marginalia.mq.MqMessage;
|
||||
import nu.marginalia.mq.MqMessageState;
|
||||
import nu.marginalia.mq.MqTestUtil;
|
||||
import nu.marginalia.mq.inbox.MqInboxResponse;
|
||||
import nu.marginalia.mq.inbox.MqInbox;
|
||||
import nu.marginalia.mq.inbox.MqSubscription;
|
||||
@ -154,9 +155,12 @@ public class MqOutboxTest {
|
||||
}
|
||||
|
||||
@Override
|
||||
public MqInboxResponse handle(MqMessage msg) {
|
||||
public MqInboxResponse onRequest(MqMessage msg) {
|
||||
return MqInboxResponse.ok(response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNotification(MqMessage msg) { }
|
||||
};
|
||||
}
|
||||
|
||||
@ -168,9 +172,12 @@ public class MqOutboxTest {
|
||||
}
|
||||
|
||||
@Override
|
||||
public MqInboxResponse handle(MqMessage msg) {
|
||||
public MqInboxResponse onRequest(MqMessage msg) {
|
||||
return MqInboxResponse.ok(msg.payload());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNotification(MqMessage msg) {}
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -1,9 +1,9 @@
|
||||
package nu.marginalia.mq.outbox;
|
||||
package nu.marginalia.mq.persistence;
|
||||
|
||||
import com.zaxxer.hikari.HikariConfig;
|
||||
import com.zaxxer.hikari.HikariDataSource;
|
||||
import nu.marginalia.mq.MqMessageState;
|
||||
import nu.marginalia.mq.persistence.MqPersistence;
|
||||
import nu.marginalia.mq.MqTestUtil;
|
||||
import org.junit.jupiter.api.*;
|
||||
import org.testcontainers.containers.MariaDBContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
@ -0,0 +1,174 @@
|
||||
package nu.marginalia.mqsm;
|
||||
|
||||
import com.google.gson.GsonBuilder;
|
||||
import com.zaxxer.hikari.HikariConfig;
|
||||
import com.zaxxer.hikari.HikariDataSource;
|
||||
import nu.marginalia.mq.MqMessageRow;
|
||||
import nu.marginalia.mq.MqMessageState;
|
||||
import nu.marginalia.mq.MqTestUtil;
|
||||
import nu.marginalia.mq.persistence.MqPersistence;
|
||||
import org.junit.jupiter.api.*;
|
||||
import org.testcontainers.containers.MariaDBContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag("slow")
|
||||
@Testcontainers
|
||||
public class StateMachineTest {
|
||||
@Container
|
||||
static MariaDBContainer<?> mariaDBContainer = new MariaDBContainer<>("mariadb")
|
||||
.withDatabaseName("WMSA_prod")
|
||||
.withUsername("wmsa")
|
||||
.withPassword("wmsa")
|
||||
.withInitScript("sql/current/11-message-queue.sql")
|
||||
.withNetworkAliases("mariadb");
|
||||
|
||||
static HikariDataSource dataSource;
|
||||
static MqPersistence persistence;
|
||||
private String inboxId;
|
||||
|
||||
@BeforeEach
|
||||
public void setUp() {
|
||||
inboxId = UUID.randomUUID().toString();
|
||||
}
|
||||
@BeforeAll
|
||||
public static void setUpAll() {
|
||||
HikariConfig config = new HikariConfig();
|
||||
config.setJdbcUrl(mariaDBContainer.getJdbcUrl());
|
||||
config.setUsername("wmsa");
|
||||
config.setPassword("wmsa");
|
||||
|
||||
dataSource = new HikariDataSource(config);
|
||||
persistence = new MqPersistence(dataSource);
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
public static void tearDownAll() {
|
||||
dataSource.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStartStopStartStop() throws Exception {
|
||||
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
|
||||
var stateFactory = new StateFactory(new GsonBuilder().create());
|
||||
|
||||
var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("GREET", "World"));
|
||||
|
||||
var greet = stateFactory.create("GREET", String.class, (String message) -> {
|
||||
System.out.println("Hello, " + message + "!");
|
||||
return stateFactory.transition("COUNT-TO-FIVE", 0);
|
||||
});
|
||||
|
||||
var ctf = stateFactory.create("COUNT-TO-FIVE", Integer.class, (Integer count) -> {
|
||||
System.out.println(count);
|
||||
if (count < 5) {
|
||||
return stateFactory.transition("COUNT-TO-FIVE", count + 1);
|
||||
} else {
|
||||
return stateFactory.transition("END");
|
||||
}
|
||||
});
|
||||
|
||||
sm.registerStates(initial, greet, ctf);
|
||||
|
||||
sm.init();
|
||||
|
||||
Thread.sleep(300);
|
||||
sm.stop();
|
||||
|
||||
var sm2 = new StateMachine(persistence, inboxId, UUID.randomUUID());
|
||||
sm2.registerStates(initial, greet, ctf);
|
||||
sm2.resume();
|
||||
sm2.join();
|
||||
sm2.stop();
|
||||
|
||||
MqTestUtil.getMessages(dataSource, inboxId).forEach(System.out::println);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void smResumeFromNew() throws Exception {
|
||||
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
|
||||
var stateFactory = new StateFactory(new GsonBuilder().create());
|
||||
|
||||
var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("A"));
|
||||
var stateA = stateFactory.create("A", () -> stateFactory.transition("B"));
|
||||
var stateB = stateFactory.create("B", () -> stateFactory.transition("C"));
|
||||
var stateC = stateFactory.create("C", () -> stateFactory.transition("END"));
|
||||
|
||||
sm.registerStates(initial, stateA, stateB, stateC);
|
||||
persistence.sendNewMessage(inboxId, null,"B", "", null);
|
||||
|
||||
sm.resume();
|
||||
|
||||
sm.join();
|
||||
sm.stop();
|
||||
|
||||
List<String> states = MqTestUtil.getMessages(dataSource, inboxId)
|
||||
.stream()
|
||||
.peek(System.out::println)
|
||||
.map(MqMessageRow::function)
|
||||
.toList();
|
||||
|
||||
assertEquals(List.of("B", "C", "END"), states);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void smResumeFromAck() throws Exception {
|
||||
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
|
||||
var stateFactory = new StateFactory(new GsonBuilder().create());
|
||||
|
||||
var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("A"));
|
||||
var stateA = stateFactory.create("A", () -> stateFactory.transition("B"));
|
||||
var stateB = stateFactory.create("B", () -> stateFactory.transition("C"));
|
||||
var stateC = stateFactory.create("C", () -> stateFactory.transition("END"));
|
||||
|
||||
sm.registerStates(initial, stateA, stateB, stateC);
|
||||
|
||||
long id = persistence.sendNewMessage(inboxId, null,"B", "", null);
|
||||
persistence.updateMessageState(id, MqMessageState.ACK);
|
||||
|
||||
sm.resume();
|
||||
|
||||
sm.join();
|
||||
sm.stop();
|
||||
|
||||
List<String> states = MqTestUtil.getMessages(dataSource, inboxId)
|
||||
.stream()
|
||||
.peek(System.out::println)
|
||||
.map(MqMessageRow::function)
|
||||
.toList();
|
||||
|
||||
assertEquals(List.of("B", "C", "END"), states);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void smResumeEmptyQueue() throws Exception {
|
||||
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
|
||||
var stateFactory = new StateFactory(new GsonBuilder().create());
|
||||
|
||||
var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("A"));
|
||||
var stateA = stateFactory.create("A", () -> stateFactory.transition("B"));
|
||||
var stateB = stateFactory.create("B", () -> stateFactory.transition("C"));
|
||||
var stateC = stateFactory.create("C", () -> stateFactory.transition("END"));
|
||||
|
||||
sm.registerStates(initial, stateA, stateB, stateC);
|
||||
|
||||
sm.resume();
|
||||
|
||||
sm.join();
|
||||
sm.stop();
|
||||
|
||||
List<String> states = MqTestUtil.getMessages(dataSource, inboxId)
|
||||
.stream()
|
||||
.peek(System.out::println)
|
||||
.map(MqMessageRow::function)
|
||||
.toList();
|
||||
|
||||
assertEquals(List.of("INITIAL", "A", "B", "C", "END"), states);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user