(mqsm) guard against spurious transitions from unexpected messages

This commit is contained in:
Viktor Lofgren 2023-07-12 22:44:05 +02:00
parent bf783dad7a
commit 6c88f00a9d
8 changed files with 90 additions and 34 deletions

View File

@ -1,6 +1,6 @@
CREATE TABLE IF NOT EXISTS MESSAGE_QUEUE (
ID BIGINT AUTO_INCREMENT PRIMARY KEY COMMENT 'Unique id',
RELATED_ID BIGINT COMMENT 'Unique id a related message',
RELATED_ID BIGINT NOT NULL DEFAULT -1 COMMENT 'Unique id a related message',
SENDER_INBOX VARCHAR(255) COMMENT 'Name of the sender inbox',
RECIPIENT_INBOX VARCHAR(255) NOT NULL COMMENT 'Name of the recipient inbox',
FUNCTION VARCHAR(255) NOT NULL COMMENT 'Which function to run',

View File

@ -1,6 +1,6 @@
CREATE TABLE IF NOT EXISTS MESSAGE_QUEUE (
ID BIGINT AUTO_INCREMENT PRIMARY KEY COMMENT 'Unique id',
RELATED_ID BIGINT COMMENT 'Unique id a related message',
RELATED_ID BIGINT NOT NULL DEFAULT -1 COMMENT 'Unique id a related message',
SENDER_INBOX VARCHAR(255) COMMENT 'Name of the sender inbox',
RECIPIENT_INBOX VARCHAR(255) NOT NULL COMMENT 'Name of the recipient inbox',
FUNCTION VARCHAR(255) NOT NULL COMMENT 'Which function to run',

View File

@ -7,7 +7,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.sql.SQLException;
import java.sql.Time;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
@ -107,7 +106,7 @@ public class MqOutbox {
* <br>
* Use waitResponse(id) or pollResponse(id) to fetch the response. */
public long sendAsync(String function, String payload) throws Exception {
var id = persistence.sendNewMessage(inboxName, replyInboxName, function, payload, null);
var id = persistence.sendNewMessage(inboxName, replyInboxName, null, function, payload, null);
pendingRequests.put(id, id);
@ -163,7 +162,13 @@ public class MqOutbox {
}
public long notify(String function, String payload) throws Exception {
return persistence.sendNewMessage(inboxName, null, function, payload, null);
return persistence.sendNewMessage(inboxName, null, null, function, payload, null);
}
public long notify(long relatedId, String function, String payload) throws Exception {
return persistence.sendNewMessage(inboxName, null, relatedId, function, payload, null);
}
public void flagAsBad(long id) throws SQLException {
persistence.updateMessageState(id, MqMessageState.ERR);
}
}

View File

@ -53,6 +53,7 @@ public class MqPersistence {
*
* @param recipientInboxName The recipient's inbox name
* @param senderInboxName (nullable) The sender's inbox name. Only needed if a reply is expected. If null, the message is not expected to be replied to.
* @param relatedMessageId (nullable) The id of the message this message is related to. If null, the message is not related to any other message.
* @param function The function to call
* @param payload The payload to send, typically JSON.
* @param ttl (nullable) The time to live of the message, in seconds. If null, the message will never set to DEAD.
@ -61,14 +62,15 @@ public class MqPersistence {
public long sendNewMessage(String recipientInboxName,
@Nullable
String senderInboxName,
Long relatedMessageId,
String function,
String payload,
@Nullable Duration ttl
) throws Exception {
try (var conn = dataSource.getConnection();
var stmt = conn.prepareStatement("""
INSERT INTO MESSAGE_QUEUE(RECIPIENT_INBOX, SENDER_INBOX, FUNCTION, PAYLOAD, TTL)
VALUES(?, ?, ?, ?, ?)
INSERT INTO MESSAGE_QUEUE(RECIPIENT_INBOX, SENDER_INBOX, RELATED_ID, FUNCTION, PAYLOAD, TTL)
VALUES(?, ?, ?, ?, ?, ?)
""");
var lastIdQuery = conn.prepareStatement("SELECT LAST_INSERT_ID()")) {
@ -77,10 +79,13 @@ public class MqPersistence {
if (senderInboxName == null) stmt.setNull(2, java.sql.Types.VARCHAR);
else stmt.setString(2, senderInboxName);
stmt.setString(3, function);
stmt.setString(4, payload);
if (ttl == null) stmt.setNull(5, java.sql.Types.BIGINT);
else stmt.setLong(5, ttl.toSeconds());
if (relatedMessageId == null) stmt.setLong(3, -1);
else stmt.setLong(3, relatedMessageId);
stmt.setString(4, function);
stmt.setString(5, payload);
if (ttl == null) stmt.setNull(6, java.sql.Types.BIGINT);
else stmt.setLong(6, ttl.toSeconds());
stmt.executeUpdate();
var rsp = lastIdQuery.executeQuery();

View File

@ -36,6 +36,14 @@ public class StateMachine {
private final Map<String, MachineState> allStates = new HashMap<>();
/* The expectedMessageId guards against spurious state changes being triggered by old messages in the queue
*
* It contains the message id of the last message that was processed, and the messages sent by the state machine to
* itself via the message queue all have relatedId set to expectedMessageId. If the state machine is unitialized or
* in a terminal state, it will accept messages with relatedIds that are equal to -1.
* */
private long expectedMessageId = -1;
public StateMachine(MqFactory messageQueueFactory,
String queueName,
UUID instanceUUID,
@ -99,7 +107,7 @@ public class StateMachine {
}
smInbox.start();
smOutbox.notify(transition.state(), transition.message());
smOutbox.notify(expectedMessageId, transition.state(), transition.message());
}
/** Initialize the state machine. */
@ -112,7 +120,7 @@ public class StateMachine {
}
smInbox.start();
smOutbox.notify(transition.state(), transition.message());
smOutbox.notify(expectedMessageId, transition.state(), transition.message());
}
/** Resume the state machine from the last known state. */
@ -133,6 +141,7 @@ public class StateMachine {
smInbox.start();
logger.info("Resuming state machine from {}({})/{}", firstMessage.function(), firstMessage.payload(), firstMessage.state());
expectedMessageId = firstMessage.relatedId();
if (firstMessage.state() == MqMessageState.NEW) {
// The message is not acknowledged, so starting the inbox will trigger a state transition
@ -141,10 +150,10 @@ public class StateMachine {
state = resumingState;
} else if (resumeState.resumeBehavior().equals(ResumeBehavior.ERROR)) {
// The message is acknowledged, but the state does not support resuming
smOutbox.notify("ERROR", "Illegal resumption from ACK'ed state " + firstMessage.function());
smOutbox.notify(expectedMessageId, "ERROR", "Illegal resumption from ACK'ed state " + firstMessage.function());
} else {
// The message is already acknowledged, so we replay the last state
onStateTransition(firstMessage.function(), firstMessage.payload());
onStateTransition(firstMessage);
}
}
@ -153,13 +162,24 @@ public class StateMachine {
smOutbox.stop();
}
private void onStateTransition(String nextState, String message) {
private void onStateTransition(MqMessage msg) {
final String nextState = msg.function();
final String data = msg.payload();
final long messageId = msg.msgId();
final long relatedId = msg.relatedId();
if (expectedMessageId != relatedId) {
// We've received a message that we didn't expect, throwing an exception will cause it to be flagged
// as an error in the message queue; the message queue will proceed
throw new IllegalStateException("Unexpected message id " + relatedId + ", expected " + expectedMessageId);
}
try {
logger.info("FSM State change in {}: {}->{}({})",
queueName,
state == null ? "[null]" : state.name(),
nextState,
message);
data);
if (!allStates.containsKey(nextState)) {
logger.error("Unknown state {}", nextState);
@ -173,8 +193,13 @@ public class StateMachine {
}
if (!state.isFinal()) {
var transition = state.next(message);
smOutbox.notify(transition.state(), transition.message());
var transition = state.next(msg.payload());
expectedMessageId = messageId;
smOutbox.notify(expectedMessageId, transition.state(), transition.message());
}
else {
expectedMessageId = -1;
}
}
catch (Exception e) {
@ -204,7 +229,7 @@ public class StateMachine {
@Override
public void onNotification(MqMessage msg) {
onStateTransition(msg.function(), msg.payload());
onStateTransition(msg);
try {
stateChangeListeners.forEach(l -> l.accept(msg.function(), msg.payload()));
}

View File

@ -57,7 +57,7 @@ public class MqPersistenceTest {
@Test
public void testReaper() throws Exception {
long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(2));
long id = persistence.sendNewMessage(recipientId, senderId, null, "function", "payload", Duration.ofSeconds(2));
persistence.reapDeadMessages();
var messages = MqTestUtil.getMessages(dataSource, recipientId);
@ -77,7 +77,7 @@ public class MqPersistenceTest {
@Test
public void sendWithReplyAddress() throws Exception {
long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30));
long id = persistence.sendNewMessage(recipientId, senderId, null, "function", "payload", Duration.ofSeconds(30));
var messages = MqTestUtil.getMessages(dataSource, recipientId);
assertEquals(1, messages.size());
@ -95,7 +95,7 @@ public class MqPersistenceTest {
@Test
public void sendNoReplyAddress() throws Exception {
long id = persistence.sendNewMessage(recipientId, null, "function", "payload", Duration.ofSeconds(30));
long id = persistence.sendNewMessage(recipientId, null, null, "function", "payload", Duration.ofSeconds(30));
var messages = MqTestUtil.getMessages(dataSource, recipientId);
assertEquals(1, messages.size());
@ -114,7 +114,7 @@ public class MqPersistenceTest {
@Test
public void updateState() throws Exception {
long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30));
long id = persistence.sendNewMessage(recipientId, senderId, null, "function", "payload", Duration.ofSeconds(30));
persistence.updateMessageState(id, MqMessageState.OK);
System.out.println(id);
@ -131,7 +131,7 @@ public class MqPersistenceTest {
@Test
public void testReply() throws Exception {
long request = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30));
long request = persistence.sendNewMessage(recipientId, senderId, null, "function", "payload", Duration.ofSeconds(30));
long response = persistence.sendResponse(request, MqMessageState.OK, "response");
var sentMessages = MqTestUtil.getMessages(dataSource, recipientId);
@ -159,7 +159,7 @@ public class MqPersistenceTest {
String instanceId = "BATMAN";
long tick = 1234L;
long id = persistence.sendNewMessage(recipientId, null,"function", "payload", Duration.ofSeconds(30));
long id = persistence.sendNewMessage(recipientId, null, null, "function", "payload", Duration.ofSeconds(30));
var messagesPollFirstTime = persistence.pollInbox(recipientId, instanceId , tick, 10);

View File

@ -81,7 +81,7 @@ public class StateMachineResumeTest {
var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory));
persistence.sendNewMessage(inboxId, null,"RESUMABLE", "", null);
persistence.sendNewMessage(inboxId, null, -1L, "RESUMABLE", "", null);
sm.resume();
@ -102,7 +102,7 @@ public class StateMachineResumeTest {
var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory));
long id = persistence.sendNewMessage(inboxId, null,"RESUMABLE", "", null);
long id = persistence.sendNewMessage(inboxId, null, -1L, "RESUMABLE", "", null);
persistence.updateMessageState(id, MqMessageState.ACK);
sm.resume();
@ -125,7 +125,7 @@ public class StateMachineResumeTest {
var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory));
persistence.sendNewMessage(inboxId, null,"NON-RESUMABLE", "", null);
persistence.sendNewMessage(inboxId, null, -1L, "NON-RESUMABLE", "", null);
sm.resume();
@ -146,7 +146,7 @@ public class StateMachineResumeTest {
var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory));
long id = persistence.sendNewMessage(inboxId, null,"NON-RESUMABLE", "", null);
long id = persistence.sendNewMessage(inboxId, null, null, "NON-RESUMABLE", "", null);
persistence.updateMessageState(id, MqMessageState.ACK);
sm.resume();

View File

@ -118,4 +118,25 @@ public class StateMachineTest {
MqTestUtil.getMessages(dataSource, inboxId).forEach(System.out::println);
}
@Test
public void testFalseTransition() throws Exception {
var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new TestGraph(stateFactory));
// Prep the queue with a message to set the state to initial,
// and an additional message to trigger the false transition back to initial
persistence.sendNewMessage(inboxId, null, null, "INITIAL", "", null);
persistence.sendNewMessage(inboxId, null, null, "INITIAL", "", null);
sm.resume();
Thread.sleep(50);
sm.join();
sm.stop();
MqTestUtil.getMessages(dataSource, inboxId).forEach(System.out::println);
}
}