Message queue WIP

This commit is contained in:
Viktor Lofgren 2023-07-03 11:04:08 +02:00
parent 62cc9df206
commit 31ae71c7d6
18 changed files with 1130 additions and 0 deletions

View File

@ -0,0 +1,20 @@
CREATE TABLE PROC_MESSAGE(
ID BIGINT AUTO_INCREMENT PRIMARY KEY COMMENT 'Unique id',
RELATED_ID BIGINT NOT NULL DEFAULT -1 COMMENT 'Unique id a related message',
SENDER_INBOX VARCHAR(255) COMMENT 'Name of the sender inbox',
RECIPIENT_INBOX VARCHAR(255) NOT NULL COMMENT 'Name of the recipient inbox',
FUNCTION VARCHAR(255) NOT NULL COMMENT 'Which function to run',
PAYLOAD TEXT COMMENT 'Message to recipient',
OWNER_INSTANCE VARCHAR(255) COMMENT 'Instance UUID corresponding to the party that has claimed the message',
OWNER_TICK BIGINT DEFAULT -1 COMMENT 'Used by recipient to determine which messages it has processed',
STATE ENUM('NEW', 'ACK', 'OK', 'ERR', 'DEAD')
NOT NULL DEFAULT 'NEW' COMMENT 'Processing state',
CREATED_TIME TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) COMMENT 'Time of creation',
UPDATED_TIME TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) COMMENT 'Time of last update',
TTL INT COMMENT 'Time to live in seconds'
);

View File

@ -0,0 +1,23 @@
CREATE TABLE PROC_MESSAGE(
ID BIGINT AUTO_INCREMENT PRIMARY KEY COMMENT 'Unique id',
RELATED_ID BIGINT COMMENT 'Unique id a related message',
SENDER_INBOX VARCHAR(255) COMMENT 'Name of the sender inbox',
RECIPIENT_INBOX VARCHAR(255) NOT NULL COMMENT 'Name of the recipient inbox',
FUNCTION VARCHAR(255) NOT NULL COMMENT 'Which function to run',
PAYLOAD TEXT COMMENT 'Message to recipient',
-- These fields are used to avoid double processing of messages
-- instance marks the unique instance of the party, and the tick marks
-- the current polling iteration. Both are necessary.
OWNER_INSTANCE VARCHAR(255) COMMENT 'Instance UUID corresponding to the party that has claimed the message',
OWNER_TICK BIGINT DEFAULT -1 COMMENT 'Used by recipient to determine which messages it has processed',
STATE ENUM('NEW', 'ACK', 'OK', 'ERR', 'DEAD')
NOT NULL DEFAULT 'NEW' COMMENT 'Processing state',
CREATED_TIME TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) COMMENT 'Time of creation',
UPDATED_TIME TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) COMMENT 'Time of last update',
TTL INT COMMENT 'Time to live in seconds'
);

View File

@ -0,0 +1,48 @@
plugins {
id 'java'
}
java {
toolchain {
languageVersion.set(JavaLanguageVersion.of(17))
}
}
dependencies {
implementation project(':code:common:service-client')
implementation project(':code:common:service-discovery')
implementation project(':code:common:db')
implementation libs.lombok
annotationProcessor libs.lombok
implementation libs.spark
implementation libs.guice
implementation libs.rxjava
implementation libs.bundles.prometheus
implementation libs.bundles.slf4j
implementation libs.bucket4j
testImplementation libs.bundles.slf4j.test
implementation libs.bundles.mariadb
testImplementation libs.bundles.slf4j.test
testImplementation libs.bundles.junit
testImplementation libs.mockito
testImplementation platform('org.testcontainers:testcontainers-bom:1.17.4')
testImplementation 'org.testcontainers:mariadb:1.17.4'
testImplementation 'org.testcontainers:junit-jupiter:1.17.4'
}
test {
useJUnitPlatform()
}
task fastTests(type: Test) {
useJUnitPlatform {
excludeTags "slow"
}
}

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 18 KiB

View File

@ -0,0 +1,5 @@
# Message Queue
Implements a message queue using mariadb.
![Message States](msgstate.svg)

View File

@ -0,0 +1,11 @@
package nu.marginalia.mq;
public class MqException extends Exception {
public MqException(String message) {
super(message);
}
public MqException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@ -0,0 +1,10 @@
package nu.marginalia.mq;
public record MqMessage(
long msgId,
long relatedId,
String function,
String payload,
MqMessageState state
) {
}

View File

@ -0,0 +1,9 @@
package nu.marginalia.mq;
public enum MqMessageState {
NEW,
ACK,
OK,
ERR,
DEAD
}

View File

@ -0,0 +1,185 @@
package nu.marginalia.mq.inbox;
import nu.marginalia.mq.MqMessage;
import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.persistence.MqPersistence;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
import java.sql.SQLException;
import java.util.Collection;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
public class MqInbox {
private final Logger logger = LoggerFactory.getLogger(MqInbox.class);
private final String inboxName;
private final String instanceUUID;
private final ExecutorService threadPool;
private final MqPersistence persistence;
private volatile boolean run = true;
private final int pollIntervalMs = Integer.getInteger("mq.inbox.poll-interval-ms", 1000);
private final List<MqSubscription> eventSubscribers = new ArrayList<>();
private final LinkedBlockingQueue<MqMessage> queue = new LinkedBlockingQueue<>(32);
private Thread pollDbThread;
private Thread notifyThread;
public MqInbox(MqPersistence persistence,
String inboxName,
UUID instanceUUID)
{
this.threadPool = Executors.newCachedThreadPool();
this.persistence = persistence;
this.inboxName = inboxName;
this.instanceUUID = instanceUUID.toString();
}
public void subscribe(MqSubscription subscription) {
eventSubscribers.add(subscription);
}
public void start() {
run = true;
if (eventSubscribers.isEmpty()) {
logger.error("No subscribers for inbox {}, registering shredder", inboxName);
}
// Add a final handler that fails any message that is not handled
eventSubscribers.add(new MqInboxShredder());
pollDbThread = new Thread(this::pollDb, "mq-inbox-update-thread:"+inboxName);
pollDbThread.setDaemon(true);
pollDbThread.start();
notifyThread = new Thread(this::notifySubscribers, "mq-inbox-notify-thread:"+inboxName);
notifyThread.setDaemon(true);
notifyThread.start();
}
public void stop() throws InterruptedException {
if (!run)
return;
logger.info("Shutting down inbox {}", inboxName);
run = false;
pollDbThread.join();
notifyThread.join();
threadPool.shutdownNow();
while (!threadPool.awaitTermination(5, TimeUnit.SECONDS));
}
private void notifySubscribers() {
try {
while (run) {
MqMessage msg = queue.poll(pollIntervalMs, TimeUnit.MILLISECONDS);
if (msg == null)
continue;
logger.info("Notifying subscribers of message {}", msg.msgId());
boolean handled = false;
for (var eventSubscriber : eventSubscribers) {
if (eventSubscriber.filter(msg)) {
handleMessageWithSubscriber(eventSubscriber, msg);
handled = true;
break;
}
}
if (!handled) {
logger.error("No subscriber wanted to handle message {}", msg.msgId());
}
}
}
catch (InterruptedException ex) {
logger.error("MQ inbox notify thread interrupted", ex);
}
}
private void handleMessageWithSubscriber(MqSubscription subscriber, MqMessage msg) {
threadPool.execute(() -> {
try {
final var rsp = subscriber.handle(msg);
sendResponse(msg, rsp.state(), rsp.message());
} catch (Exception ex) {
logger.error("Message Queue subscriber threw exception", ex);
sendResponse(msg, MqMessageState.ERR);
}
});
}
private void sendResponse(MqMessage msg, MqMessageState mqMessageState) {
try {
persistence.updateMessageState(msg.msgId(), mqMessageState);
}
catch (SQLException ex) {
logger.error("Failed to update message state", ex);
}
}
private void sendResponse(MqMessage msg, MqMessageState mqMessageState, String response) {
try {
persistence.sendResponse(msg.msgId(), mqMessageState, response);
}
catch (SQLException ex) {
logger.error("Failed to update message state", ex);
}
}
public void pollDb() {
try {
for (long tick = 1; run; tick++) {
queue.addAll(pollInbox(tick));
TimeUnit.MILLISECONDS.sleep(pollIntervalMs);
}
}
catch (InterruptedException ex) {
logger.error("MQ inbox update thread interrupted", ex);
}
}
private Collection<MqMessage> pollInbox(long tick) {
try {
return persistence.pollInbox(inboxName, instanceUUID, tick);
}
catch (SQLException ex) {
logger.error("Failed to poll inbox", ex);
return List.of();
}
}
private class MqInboxShredder implements MqSubscription {
@Override
public boolean filter(MqMessage rawMessage) {
return true;
}
@Override
public MqInboxResponse handle(MqMessage msg) {
logger.warn("Unhandled message {}", msg.msgId());
return MqInboxResponse.err();
}
}
}

View File

@ -0,0 +1,22 @@
package nu.marginalia.mq.inbox;
import nu.marginalia.mq.MqMessageState;
public record MqInboxResponse(String message, MqMessageState state) {
public static MqInboxResponse ok(String message) {
return new MqInboxResponse(message, MqMessageState.OK);
}
public static MqInboxResponse ok() {
return new MqInboxResponse("", MqMessageState.OK);
}
public static MqInboxResponse err(String message) {
return new MqInboxResponse(message, MqMessageState.ERR);
}
public static MqInboxResponse err() {
return new MqInboxResponse("", MqMessageState.ERR);
}
}

View File

@ -0,0 +1,9 @@
package nu.marginalia.mq.inbox;
import nu.marginalia.mq.MqMessage;
public interface MqSubscription {
boolean filter(MqMessage rawMessage);
MqInboxResponse handle(MqMessage msg);
}

View File

@ -0,0 +1,107 @@
package nu.marginalia.mq.outbox;
import nu.marginalia.mq.MqMessage;
import nu.marginalia.mq.persistence.MqPersistence;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.sql.SQLException;
import java.time.Duration;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
public class MqOutbox {
private final Logger logger = LoggerFactory.getLogger(MqOutbox.class);
private final MqPersistence persistence;
private final String inboxName;
private final String replyInboxName;
private final String instanceUUID;
private final ConcurrentHashMap<Long, Long> pendingRequests = new ConcurrentHashMap<>();
private final ConcurrentHashMap<Long, MqMessage> pendingResponses = new ConcurrentHashMap<>();
private final int pollIntervalMs = Integer.getInteger("mq.outbox.poll-interval-ms", 1000);
private final Thread pollThread;
private volatile boolean run = true;
public MqOutbox(MqPersistence persistence,
String inboxName,
UUID instanceUUID) {
this.persistence = persistence;
this.inboxName = inboxName;
this.replyInboxName = "reply:" + inboxName;
this.instanceUUID = instanceUUID.toString();
pollThread = new Thread(this::poll, "mq-outbox-poll-thread:" + inboxName);
pollThread.setDaemon(true);
pollThread.start();
}
public void stop() throws InterruptedException {
if (!run)
return;
logger.info("Shutting down outbox {}", inboxName);
pendingRequests.clear();
run = false;
pollThread.join();
}
private void poll() {
try {
for (long id = 1; run; id++) {
pollDb(id);
TimeUnit.MILLISECONDS.sleep(pollIntervalMs);
}
} catch (InterruptedException ex) {
logger.error("Outbox poll thread interrupted", ex);
}
}
private void pollDb(long tick) {
if (pendingRequests.isEmpty())
return;
try {
var updates = persistence.pollReplyInbox(replyInboxName, instanceUUID, tick);
for (var message : updates) {
pendingResponses.put(message.relatedId(), message);
pendingRequests.remove(message.relatedId());
}
if (updates.isEmpty() || pendingResponses.isEmpty())
return;
logger.info("Notifying {} pending responses", pendingResponses.size());
synchronized (pendingResponses) {
pendingResponses.notifyAll();
}
}
catch (SQLException ex) {
logger.error("Failed to poll inbox", ex);
}
}
public MqMessage send(String function, String payload) throws Exception {
var id = persistence.sendNewMessage(inboxName, replyInboxName, function, payload, null);
pendingRequests.put(id, id);
synchronized (pendingResponses) {
while (!pendingResponses.containsKey(id)) {
pendingResponses.wait(100);
}
return pendingResponses.remove(id);
}
}
}

View File

@ -0,0 +1,237 @@
package nu.marginalia.mq.persistence;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.MqMessage;
import javax.annotation.Nullable;
import java.sql.SQLException;
import java.time.Duration;
import java.util.*;
@Singleton
public class MqPersistence {
private final HikariDataSource dataSource;
@Inject
public MqPersistence(HikariDataSource dataSource) {
this.dataSource = dataSource;
}
/** Flags messages as dead if they have not been set to a terminal state within a TTL after the last update. */
public int reapDeadMessages() throws SQLException {
try (var conn = dataSource.getConnection();
var stmt = conn.prepareStatement("""
UPDATE PROC_MESSAGE
SET STATE='DEAD', UPDATED_TIME=CURRENT_TIMESTAMP(6)
WHERE STATE IN ('NEW', 'ACK')
AND TTL IS NOT NULL
AND TIMESTAMPDIFF(SECOND, UPDATED_TIME, CURRENT_TIMESTAMP(6)) > TTL
""")) {
return stmt.executeUpdate();
}
}
public long sendNewMessage(String recipientInboxName,
@Nullable
String senderInboxName,
String function,
String payload,
@Nullable Duration ttl
) throws Exception {
try (var conn = dataSource.getConnection();
var stmt = conn.prepareStatement("""
INSERT INTO PROC_MESSAGE(RECIPIENT_INBOX, SENDER_INBOX, FUNCTION, PAYLOAD, TTL)
VALUES(?, ?, ?, ?, ?)
""");
var lastIdQuery = conn.prepareStatement("SELECT LAST_INSERT_ID()")) {
stmt.setString(1, recipientInboxName);
if (senderInboxName == null) stmt.setNull(2, java.sql.Types.VARCHAR);
else stmt.setString(2, senderInboxName);
stmt.setString(3, function);
stmt.setString(4, payload);
if (ttl == null) stmt.setNull(5, java.sql.Types.BIGINT);
else stmt.setLong(5, ttl.toSeconds());
stmt.executeUpdate();
var rsp = lastIdQuery.executeQuery();
if (!rsp.next()) {
throw new IllegalStateException("No last insert id");
}
return rsp.getLong(1);
}
}
public void updateMessageState(long id, MqMessageState mqMessageState) throws SQLException {
try (var conn = dataSource.getConnection();
var stmt = conn.prepareStatement("""
UPDATE PROC_MESSAGE
SET STATE=?, UPDATED_TIME=CURRENT_TIMESTAMP(6)
WHERE ID=?
""")) {
stmt.setString(1, mqMessageState.name());
stmt.setLong(2, id);
if (stmt.executeUpdate() != 1) {
throw new IllegalArgumentException("No rows updated");
}
}
}
public long sendResponse(long id, MqMessageState mqMessageState, String message) throws SQLException {
try (var conn = dataSource.getConnection()) {
conn.setAutoCommit(false);
try (var updateState = conn.prepareStatement("""
UPDATE PROC_MESSAGE
SET STATE=?, UPDATED_TIME=CURRENT_TIMESTAMP(6)
WHERE ID=?
""");
var addResponse = conn.prepareStatement("""
INSERT INTO PROC_MESSAGE(RECIPIENT_INBOX, RELATED_ID, FUNCTION, PAYLOAD)
SELECT SENDER_INBOX, ID, ?, ?
FROM PROC_MESSAGE
WHERE ID=? AND SENDER_INBOX IS NOT NULL
""");
var lastIdQuery = conn.prepareStatement("SELECT LAST_INSERT_ID()")
) {
updateState.setString(1, mqMessageState.name());
updateState.setLong(2, id);
if (updateState.executeUpdate() != 1) {
throw new IllegalArgumentException("No rows updated");
}
addResponse.setString(1, "REPLY");
addResponse.setString(2, message);
addResponse.setLong(3, id);
if (addResponse.executeUpdate() != 1) {
throw new IllegalArgumentException("No rows updated");
}
var rsp = lastIdQuery.executeQuery();
if (!rsp.next()) {
throw new IllegalStateException("No last insert id");
}
long newId = rsp.getLong(1);
conn.commit();
return newId;
} catch (SQLException|IllegalStateException|IllegalArgumentException ex) {
conn.rollback();
throw ex;
} finally {
conn.setAutoCommit(true);
}
}
}
private int markInboxMessages(String inboxName, String instanceUUID, long tick) throws SQLException {
try (var conn = dataSource.getConnection();
var updateStmt = conn.prepareStatement("""
UPDATE PROC_MESSAGE
SET OWNER_INSTANCE=?, OWNER_TICK=?, UPDATED_TIME=CURRENT_TIMESTAMP(6), STATE='ACK'
WHERE RECIPIENT_INBOX=?
AND OWNER_INSTANCE IS NULL AND STATE='NEW'
""");
) {
updateStmt.setString(1, instanceUUID);
updateStmt.setLong(2, tick);
updateStmt.setString(3, inboxName);
return updateStmt.executeUpdate();
}
}
/** Marks unclaimed messages addressed to this inbox with instanceUUID and tick,
* then returns these messages.
*/
public Collection<MqMessage> pollInbox(String inboxName, String instanceUUID, long tick) throws SQLException {
int expected = markInboxMessages(inboxName, instanceUUID, tick);
if (expected == 0) {
return Collections.emptyList();
}
try (var conn = dataSource.getConnection();
var queryStmt = conn.prepareStatement("""
SELECT ID, RELATED_ID, FUNCTION, PAYLOAD, STATE FROM PROC_MESSAGE
WHERE OWNER_INSTANCE=? AND OWNER_TICK=?
""")
) {
queryStmt.setString(1, instanceUUID);
queryStmt.setLong(2, tick);
var rs = queryStmt.executeQuery();
List<MqMessage> messages = new ArrayList<>(expected);
while (rs.next()) {
long msgId = rs.getLong(1);
long relatedId = rs.getLong(2);
String function = rs.getString(3);
String payload = rs.getString(4);
MqMessageState state = MqMessageState.valueOf(rs.getString(5));
var msg = new MqMessage(msgId, relatedId, function, payload, state);
messages.add(msg);
}
return messages;
}
}
/** Marks unclaimed messages addressed to this inbox with instanceUUID and tick,
* then returns these messages.
*/
public Collection<MqMessage> pollReplyInbox(String inboxName, String instanceUUID, long tick) throws SQLException {
int expected = markInboxMessages(inboxName, instanceUUID, tick);
if (expected == 0) {
return Collections.emptyList();
}
try (var conn = dataSource.getConnection();
var queryStmt = conn.prepareStatement("""
SELECT SELF.ID, SELF.RELATED_ID, SELF.FUNCTION, SELF.PAYLOAD, PARENT.STATE FROM PROC_MESSAGE SELF
LEFT JOIN PROC_MESSAGE PARENT ON SELF.RELATED_ID=PARENT.ID
WHERE SELF.OWNER_INSTANCE=? AND SELF.OWNER_TICK=?
""")
) {
queryStmt.setString(1, instanceUUID);
queryStmt.setLong(2, tick);
var rs = queryStmt.executeQuery();
List<MqMessage> messages = new ArrayList<>(expected);
while (rs.next()) {
long msgId = rs.getLong(1);
long relatedId = rs.getLong(2);
String function = rs.getString(3);
String payload = rs.getString(4);
MqMessageState state = MqMessageState.valueOf(rs.getString(5));
var msg = new MqMessage(msgId, relatedId, function, payload, state);
messages.add(msg);
}
return messages;
}
}
}

View File

@ -0,0 +1,21 @@
package nu.marginalia.mq.outbox;
import nu.marginalia.mq.MqMessageState;
import javax.annotation.Nullable;
public record MqMessageRow (
long id,
long relatedId,
@Nullable
String senderInbox,
String recipientInbox,
String function,
String payload,
MqMessageState state,
String ownerInstance,
long ownerTick,
long createdTime,
long updatedTime,
long ttl
) {}

View File

@ -0,0 +1,177 @@
package nu.marginalia.mq.outbox;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.mq.MqMessage;
import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.inbox.MqInboxResponse;
import nu.marginalia.mq.inbox.MqInbox;
import nu.marginalia.mq.inbox.MqSubscription;
import nu.marginalia.mq.persistence.MqPersistence;
import org.junit.jupiter.api.*;
import org.testcontainers.containers.MariaDBContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import java.util.UUID;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag("slow")
@Testcontainers
public class MqOutboxTest {
@Container
static MariaDBContainer<?> mariaDBContainer = new MariaDBContainer<>("mariadb")
.withDatabaseName("WMSA_prod")
.withUsername("wmsa")
.withPassword("wmsa")
.withInitScript("sql/current/11-message-queue.sql")
.withNetworkAliases("mariadb");
static HikariDataSource dataSource;
private String inboxId;
@BeforeEach
public void setUp() {
inboxId = UUID.randomUUID().toString();
}
@BeforeAll
public static void setUpAll() {
HikariConfig config = new HikariConfig();
config.setJdbcUrl(mariaDBContainer.getJdbcUrl());
config.setUsername("wmsa");
config.setPassword("wmsa");
dataSource = new HikariDataSource(config);
}
@AfterAll
public static void tearDownAll() {
dataSource.close();
}
@Test
public void testOpenClose() throws InterruptedException {
var outbox = new MqOutbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID());
outbox.stop();
}
@Test
public void testSend() throws Exception {
var outbox = new MqOutbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID());
Executors.newSingleThreadExecutor().submit(() -> outbox.send("test", "Hello World"));
TimeUnit.MILLISECONDS.sleep(100);
var messages = MqTestUtil.getMessages(dataSource, inboxId);
assertEquals(1, messages.size());
System.out.println(messages.get(0));
outbox.stop();
}
@Test
public void testSendAndRespond() throws Exception {
var outbox = new MqOutbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID());
var inbox = new MqInbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID());
inbox.subscribe(justRespond("Alright then"));
inbox.start();
var rsp = outbox.send("test", "Hello World");
assertEquals(MqMessageState.OK, rsp.state());
assertEquals("Alright then", rsp.payload());
var messages = MqTestUtil.getMessages(dataSource, inboxId);
assertEquals(1, messages.size());
assertEquals(MqMessageState.OK, messages.get(0).state());
outbox.stop();
inbox.stop();
}
@Test
public void testSendMultiple() throws Exception {
var outbox = new MqOutbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID());
var inbox = new MqInbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID());
inbox.subscribe(echo());
inbox.start();
var rsp1 = outbox.send("test", "one");
var rsp2 = outbox.send("test", "two");
var rsp3 = outbox.send("test", "three");
var rsp4 = outbox.send("test", "four");
Thread.sleep(500);
assertEquals(MqMessageState.OK, rsp1.state());
assertEquals("one", rsp1.payload());
assertEquals(MqMessageState.OK, rsp2.state());
assertEquals("two", rsp2.payload());
assertEquals(MqMessageState.OK, rsp3.state());
assertEquals("three", rsp3.payload());
assertEquals(MqMessageState.OK, rsp4.state());
assertEquals("four", rsp4.payload());
var messages = MqTestUtil.getMessages(dataSource, inboxId);
assertEquals(4, messages.size());
for (var message : messages) {
assertEquals(MqMessageState.OK, message.state());
}
outbox.stop();
inbox.stop();
}
@Test
public void testSendAndRespondWithErrorHandler() throws Exception {
var outbox = new MqOutbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID());
var inbox = new MqInbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID());
inbox.start();
var rsp = outbox.send("test", "Hello World");
assertEquals(MqMessageState.ERR, rsp.state());
var messages = MqTestUtil.getMessages(dataSource, inboxId);
assertEquals(1, messages.size());
assertEquals(MqMessageState.ERR, messages.get(0).state());
outbox.stop();
inbox.stop();
}
public MqSubscription justRespond(String response) {
return new MqSubscription() {
@Override
public boolean filter(MqMessage rawMessage) {
return true;
}
@Override
public MqInboxResponse handle(MqMessage msg) {
return MqInboxResponse.ok(response);
}
};
}
public MqSubscription echo() {
return new MqSubscription() {
@Override
public boolean filter(MqMessage rawMessage) {
return true;
}
@Override
public MqInboxResponse handle(MqMessage msg) {
return MqInboxResponse.ok(msg.payload());
}
};
}
}

View File

@ -0,0 +1,189 @@
package nu.marginalia.mq.outbox;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.persistence.MqPersistence;
import org.junit.jupiter.api.*;
import org.testcontainers.containers.MariaDBContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import java.time.Duration;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
@Tag("slow")
@Testcontainers
public class MqPersistenceTest {
@Container
static MariaDBContainer<?> mariaDBContainer = new MariaDBContainer<>("mariadb")
.withDatabaseName("WMSA_prod")
.withUsername("wmsa")
.withPassword("wmsa")
.withInitScript("sql/current/11-message-queue.sql")
.withNetworkAliases("mariadb");
static HikariDataSource dataSource;
static MqPersistence persistence;
String recipientId;
String senderId;
@BeforeEach
public void setUp() {
senderId = UUID.randomUUID().toString();
recipientId = UUID.randomUUID().toString();
}
@BeforeAll
public static void setUpAll() {
HikariConfig config = new HikariConfig();
config.setJdbcUrl(mariaDBContainer.getJdbcUrl());
config.setUsername("wmsa");
config.setPassword("wmsa");
dataSource = new HikariDataSource(config);
persistence = new MqPersistence(dataSource);
}
@AfterAll
public static void tearDownAll() {
dataSource.close();
}
@Test
public void testReaper() throws Exception {
long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(2));
persistence.reapDeadMessages();
var messages = MqTestUtil.getMessages(dataSource, recipientId);
assertEquals(1, messages.size());
assertEquals(MqMessageState.NEW, messages.get(0).state());
TimeUnit.SECONDS.sleep(5);
persistence.reapDeadMessages();
messages = MqTestUtil.getMessages(dataSource, recipientId);
assertEquals(1, messages.size());
assertEquals(MqMessageState.DEAD, messages.get(0).state());
}
@Test
public void sendWithReplyAddress() throws Exception {
long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30));
var messages = MqTestUtil.getMessages(dataSource, recipientId);
assertEquals(1, messages.size());
var message = messages.get(0);
assertEquals(id, message.id());
assertEquals("function", message.function());
assertEquals("payload", message.payload());
assertEquals(MqMessageState.NEW, message.state());
System.out.println(message);
}
@Test
public void sendNoReplyAddress() throws Exception {
long id = persistence.sendNewMessage(recipientId, null, "function", "payload", Duration.ofSeconds(30));
var messages = MqTestUtil.getMessages(dataSource, recipientId);
assertEquals(1, messages.size());
var message = messages.get(0);
assertEquals(id, message.id());
assertNull(message.senderInbox());
assertEquals("function", message.function());
assertEquals("payload", message.payload());
assertEquals(MqMessageState.NEW, message.state());
System.out.println(message);
}
@Test
public void updateState() throws Exception {
long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30));
persistence.updateMessageState(id, MqMessageState.OK);
System.out.println(id);
var messages = MqTestUtil.getMessages(dataSource, recipientId);
assertEquals(1, messages.size());
var message = messages.get(0);
assertEquals(id, message.id());
assertEquals(MqMessageState.OK, message.state());
System.out.println(message);
}
@Test
public void testReply() throws Exception {
long request = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30));
long response = persistence.sendResponse(request, MqMessageState.OK, "response");
var sentMessages = MqTestUtil.getMessages(dataSource, recipientId);
System.out.println(sentMessages);
assertEquals(1, sentMessages.size());
var requestMessage = sentMessages.get(0);
assertEquals(request, requestMessage.id());
assertEquals(MqMessageState.OK, requestMessage.state());
var replies = MqTestUtil.getMessages(dataSource, senderId);
System.out.println(replies);
assertEquals(1, replies.size());
var responseMessage = replies.get(0);
assertEquals(response, responseMessage.id());
assertEquals(request, responseMessage.relatedId());
assertEquals(MqMessageState.NEW, responseMessage.state());
}
@Test
public void testPollInbox() throws Exception {
String instanceId = "BATMAN";
long tick = 1234L;
long id = persistence.sendNewMessage(recipientId, null,"function", "payload", Duration.ofSeconds(30));
var messagesPollFirstTime = persistence.pollInbox(recipientId, instanceId , tick);
/** CHECK POLL RESULT */
assertEquals(1, messagesPollFirstTime.size());
var firstPollMessage = messagesPollFirstTime.iterator().next();
assertEquals(id, firstPollMessage.msgId());
assertEquals("function", firstPollMessage.function());
assertEquals("payload", firstPollMessage.payload());
/** CHECK DB TABLE */
var messages = MqTestUtil.getMessages(dataSource, recipientId);
assertEquals(1, messages.size());
var message = messages.get(0);
assertEquals(id, message.id());
assertEquals("function", message.function());
assertEquals("payload", message.payload());
assertEquals(MqMessageState.ACK, message.state());
assertEquals(instanceId, message.ownerInstance());
assertEquals(tick, message.ownerTick());
/** VERIFY SECOND POLL IS EMPTY */
var messagePollSecondTime = persistence.pollInbox(recipientId, instanceId , 1);
assertEquals(0, messagePollSecondTime.size());
}
}

View File

@ -0,0 +1,52 @@
package nu.marginalia.mq.outbox;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.mq.MqMessageState;
import org.junit.jupiter.api.Assertions;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
public class MqTestUtil {
public static List<MqMessageRow> getMessages(HikariDataSource dataSource, String inbox) {
List<MqMessageRow> messages = new ArrayList<>();
try (var conn = dataSource.getConnection();
var stmt = conn.prepareStatement("""
SELECT ID, RELATED_ID,
SENDER_INBOX, RECIPIENT_INBOX,
FUNCTION, PAYLOAD,
STATE,
OWNER_INSTANCE, OWNER_TICK,
CREATED_TIME, UPDATED_TIME,
TTL
FROM PROC_MESSAGE
WHERE RECIPIENT_INBOX = ?
"""))
{
stmt.setString(1, inbox);
var rsp = stmt.executeQuery();
while (rsp.next()) {
messages.add(new MqMessageRow(
rsp.getLong("ID"),
rsp.getLong("RELATED_ID"),
rsp.getString("SENDER_INBOX"),
rsp.getString("RECIPIENT_INBOX"),
rsp.getString("FUNCTION"),
rsp.getString("PAYLOAD"),
MqMessageState.valueOf(rsp.getString("STATE")),
rsp.getString("OWNER_INSTANCE"),
rsp.getLong("OWNER_TICK"),
rsp.getTimestamp("CREATED_TIME").getTime(),
rsp.getTimestamp("UPDATED_TIME").getTime(),
rsp.getLong("TTL")
));
}
}
catch (SQLException ex) {
Assertions.fail(ex);
}
return messages;
}
}

View File

@ -48,6 +48,7 @@ include 'code:api:assistant-api'
include 'code:common:service-discovery'
include 'code:common:service-client'
include 'code:common:db'
include 'code:common:message-queue'
include 'code:common:service'
include 'code:common:config'
include 'code:common:model'