From 6ad672259ac562b51a684b1076f36a703e637cae Mon Sep 17 00:00:00 2001 From: Igor Lovich Date: Tue, 28 Mar 2023 09:31:15 +0200 Subject: [PATCH] GH-8582:Transactional support in PostgresSubscribableChannel --- ...PostgresChannelMessageTableSubscriber.java | 82 ++++++++++- .../channel/PostgresSubscribableChannel.java | 62 ++++++-- ...resChannelMessageTableSubscriberTests.java | 138 +++++++++++++++--- 3 files changed, 240 insertions(+), 42 deletions(-) diff --git a/spring-integration-jdbc/src/main/java/org/springframework/integration/jdbc/channel/PostgresChannelMessageTableSubscriber.java b/spring-integration-jdbc/src/main/java/org/springframework/integration/jdbc/channel/PostgresChannelMessageTableSubscriber.java index a02c1a1485e..58622ba6099 100644 --- a/spring-integration-jdbc/src/main/java/org/springframework/integration/jdbc/channel/PostgresChannelMessageTableSubscriber.java +++ b/spring-integration-jdbc/src/main/java/org/springframework/integration/jdbc/channel/PostgresChannelMessageTableSubscriber.java @@ -1,5 +1,5 @@ /* - * Copyright 2022 the original author or authors. + * Copyright 2022-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; import org.postgresql.PGNotification; @@ -59,6 +60,7 @@ * * @author Rafael Winterhalter * @author Artem Bilan + * @author Igor Lovich * * @since 6.0 */ @@ -77,6 +79,8 @@ public final class PostgresChannelMessageTableSubscriber implements SmartLifecyc private CountDownLatch latch = new CountDownLatch(0); + private boolean userProvidedExecutor = false; + private Future future = CompletableFuture.completedFuture(null); @Nullable @@ -143,12 +147,17 @@ public synchronized void start() { ExecutorService executorToUse = this.executor; if (executorToUse == null) { CustomizableThreadFactory threadFactory = - new CustomizableThreadFactory("postgres-channel-message-table-subscriber-"); + new CustomizableThreadFactory("postgres-channel-notifications-"); threadFactory.setDaemon(true); - executorToUse = Executors.newSingleThreadExecutor(threadFactory); + executorToUse = Executors.newFixedThreadPool(2, threadFactory); this.executor = executorToUse; } + else { + this.userProvidedExecutor = true; + } this.latch = new CountDownLatch(1); + + CountDownLatch startingLatch = new CountDownLatch(1); this.future = executorToUse.submit(() -> { try { while (isActive()) { @@ -166,11 +175,13 @@ public synchronized void start() { } throw ex; } - this.subscriptionsMap.values() - .forEach(subscriptions -> subscriptions.forEach(Subscription::notifyUpdate)); + this.subscriptionsMap.values().forEach(this::notifyAll); + try { this.connection = conn; while (isActive()) { + startingLatch.countDown(); + PGNotification[] notifications = conn.getNotifications(0); // Unfortunately, there is no good way of interrupting a notification // poll but by closing its connection. @@ -184,9 +195,7 @@ public synchronized void start() { if (subscriptions == null) { continue; } - for (Subscription subscription : subscriptions) { - subscription.notifyUpdate(); - } + notifyAll(subscriptions); } } } @@ -208,6 +217,29 @@ public synchronized void start() { this.latch.countDown(); } }); + + try { + if (!startingLatch.await(5, TimeUnit.SECONDS)) { + throw new IllegalStateException("Failed to start " + + PostgresChannelMessageTableSubscriber.class.getName()); + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IllegalStateException("Failed to start " + + PostgresChannelMessageTableSubscriber.class.getName(), e); + } + } + + private void notifyAll(Set subscriptions) { + subscriptions.forEach(it -> { + try { + this.executor.submit(it::notifyUpdate); + } + catch (RejectedExecutionException e) { + LOGGER.warn(e, "Executor rejected submission of notification task"); + } + }); } private boolean isActive() { @@ -232,6 +264,11 @@ public synchronized void stop() { catch (SQLException ignored) { } } + + if (!this.userProvidedExecutor) { + shutdownAndAwaitTermination(this.executor); + } + try { if (!this.latch.await(5, TimeUnit.SECONDS)) { throw new IllegalStateException("Failed to stop " @@ -242,6 +279,35 @@ public synchronized void stop() { } } + + /** + * Gracefully shutdown an executor service. Taken from @see ExecutorService javadoc + * + * @param pool The pool to shut down + */ + private void shutdownAndAwaitTermination(@Nullable ExecutorService pool) { + if (pool == null) { + return; + } + pool.shutdown(); // Disable new tasks from being submitted + try { + // Wait a while for existing tasks to terminate + if (!pool.awaitTermination(2, TimeUnit.SECONDS)) { + pool.shutdownNow(); // Cancel currently executing tasks + // Wait a while for tasks to respond to being cancelled + if (!pool.awaitTermination(2, TimeUnit.SECONDS)) { + LOGGER.warn("Unable to shutdown the executor service"); + } + } + } + catch (InterruptedException ie) { + // (Re-)Cancel if current thread also interrupted + pool.shutdownNow(); + // Preserve interrupt status + Thread.currentThread().interrupt(); + } + } + @Override public boolean isRunning() { return this.latch.getCount() > 0; diff --git a/spring-integration-jdbc/src/main/java/org/springframework/integration/jdbc/channel/PostgresSubscribableChannel.java b/spring-integration-jdbc/src/main/java/org/springframework/integration/jdbc/channel/PostgresSubscribableChannel.java index dde6bfcd345..d71fa3c0157 100644 --- a/spring-integration-jdbc/src/main/java/org/springframework/integration/jdbc/channel/PostgresSubscribableChannel.java +++ b/spring-integration-jdbc/src/main/java/org/springframework/integration/jdbc/channel/PostgresSubscribableChannel.java @@ -1,5 +1,5 @@ /* - * Copyright 2022 the original author or authors. + * Copyright 2022-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,15 +16,17 @@ package org.springframework.integration.jdbc.channel; -import java.util.concurrent.Executor; +import java.util.Optional; -import org.springframework.core.task.SimpleAsyncTaskExecutor; import org.springframework.integration.channel.AbstractSubscribableChannel; import org.springframework.integration.dispatcher.MessageDispatcher; import org.springframework.integration.dispatcher.UnicastingDispatcher; import org.springframework.integration.jdbc.store.JdbcChannelMessageStore; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandler; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.support.TransactionTemplate; import org.springframework.util.Assert; /** @@ -39,6 +41,7 @@ * * @author Rafael Winterhalter * @author Artem Bilan + * @author Igor Lovich * * @since 6.0 */ @@ -47,11 +50,15 @@ public class PostgresSubscribableChannel extends AbstractSubscribableChannel private final JdbcChannelMessageStore jdbcChannelMessageStore; + private TransactionTemplate transactionTemplate; + private final Object groupId; private final PostgresChannelMessageTableSubscriber messageTableSubscriber; - private UnicastingDispatcher dispatcher = new UnicastingDispatcher(new SimpleAsyncTaskExecutor()); + private final UnicastingDispatcher dispatcher = new UnicastingDispatcher(); + + private RetryTemplate retryTemplate = RetryTemplate.builder().maxAttempts(1).build(); /** * Create a subscribable channel for a Postgres database. @@ -70,12 +77,22 @@ public PostgresSubscribableChannel(JdbcChannelMessageStore jdbcChannelMessageSto } /** - * Set the executor to use for dispatching newly received messages. - * @param executor The executor to use. + * Sets the transaction manager to use for message processing. Each message will be processed in a + * separate transaction + * @param transactionManager The transaction manager to use */ - public void setDispatcherExecutor(Executor executor) { - Assert.notNull(executor, "An executor must be provided."); - this.dispatcher = new UnicastingDispatcher(executor); + public void setTransactionManager(PlatformTransactionManager transactionManager) { + Assert.notNull(transactionManager, "A platform transaction manager must be provided."); + this.transactionTemplate = new TransactionTemplate(transactionManager); + } + + /** + * Sets retry template to use for retries in case of exception in downstream processing + * @param retryTemplate The retry template to use + */ + public void setRetryTemplate(RetryTemplate retryTemplate) { + Assert.notNull(retryTemplate, "A retry template must be provided."); + this.retryTemplate = retryTemplate; } @Override @@ -110,10 +127,29 @@ protected boolean doSend(Message message, long timeout) { @Override public void notifyUpdate() { - Message message; - while ((message = this.jdbcChannelMessageStore.pollMessageFromGroup(this.groupId)) != null) { - this.dispatcher.dispatch(message); - } + Optional> dispatchedMessage; + + do { + if (this.transactionTemplate != null) { + dispatchedMessage = this.retryTemplate.execute(context -> + this.transactionTemplate.execute(status -> + Optional.ofNullable(this.jdbcChannelMessageStore.pollMessageFromGroup(this.groupId)) + .map(this::dispatch) + ) + ); + } + else { + dispatchedMessage = Optional.ofNullable(this.jdbcChannelMessageStore.pollMessageFromGroup(this.groupId)) + .map(message -> + this.retryTemplate.execute(context -> dispatch(message)) + ); + } + } while (dispatchedMessage.isPresent()); + } + + private Message dispatch(Message message) { + this.dispatcher.dispatch(message); + return message; } @Override diff --git a/spring-integration-jdbc/src/test/java/org/springframework/integration/jdbc/channel/PostgresChannelMessageTableSubscriberTests.java b/spring-integration-jdbc/src/test/java/org/springframework/integration/jdbc/channel/PostgresChannelMessageTableSubscriberTests.java index 68fc5ed9559..25ac16b8266 100644 --- a/spring-integration-jdbc/src/test/java/org/springframework/integration/jdbc/channel/PostgresChannelMessageTableSubscriberTests.java +++ b/spring-integration-jdbc/src/test/java/org/springframework/integration/jdbc/channel/PostgresChannelMessageTableSubscriberTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2022 the original author or authors. + * Copyright 2022-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,12 +21,16 @@ import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import javax.sql.DataSource; import org.apache.commons.dbcp2.BasicDataSource; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.postgresql.jdbc.PgConnection; import org.springframework.beans.factory.annotation.Autowired; @@ -36,18 +40,22 @@ import org.springframework.integration.config.EnableIntegration; import org.springframework.integration.jdbc.store.JdbcChannelMessageStore; import org.springframework.integration.jdbc.store.channel.PostgresChannelMessageStoreQueryProvider; +import org.springframework.jdbc.datasource.DataSourceTransactionManager; import org.springframework.jdbc.datasource.init.DataSourceInitializer; import org.springframework.jdbc.datasource.init.ResourceDatabasePopulator; import org.springframework.jdbc.datasource.init.ScriptUtils; import org.springframework.messaging.support.GenericMessage; +import org.springframework.retry.support.RetryTemplate; import org.springframework.test.annotation.DirtiesContext; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; +import org.springframework.transaction.PlatformTransactionManager; import static org.assertj.core.api.Assertions.assertThat; /** * @author Rafael Winterhalter * @author Artem Bilan + * @author Igor Lovich * * @since 6.0 */ @@ -92,10 +100,17 @@ CREATE FUNCTION INT_CHANNEL_MESSAGE_NOTIFY_FCT() @Autowired private JdbcChannelMessageStore messageStore; + @Autowired + private PlatformTransactionManager transactionManager; + private PostgresChannelMessageTableSubscriber postgresChannelMessageTableSubscriber; + private PostgresSubscribableChannel postgresSubscribableChannel; + + private String groupId; + @BeforeEach - void setUp() { + void setUp(TestInfo testInfo) { // Not initiated as a bean to allow for registrations prior and post the life cycle this.postgresChannelMessageTableSubscriber = new PostgresChannelMessageTableSubscriber( () -> DriverManager.getConnection(POSTGRES_CONTAINER.getJdbcUrl(), @@ -103,6 +118,12 @@ void setUp() { POSTGRES_CONTAINER.getPassword()) .unwrap(PgConnection.class) ); + + this.groupId = testInfo.getDisplayName(); + + this.postgresSubscribableChannel = new PostgresSubscribableChannel(messageStore, + groupId, + postgresChannelMessageTableSubscriber); } @Test @@ -111,18 +132,13 @@ public void testMessagePollMessagesAddedAfterStart() throws Exception { List payloads = new ArrayList<>(); postgresChannelMessageTableSubscriber.start(); try { - PostgresSubscribableChannel channel = new PostgresSubscribableChannel(messageStore, - "testMessagePollMessagesAddedAfterStart", - postgresChannelMessageTableSubscriber); - channel.subscribe(message -> { + postgresSubscribableChannel.subscribe(message -> { payloads.add(message.getPayload()); latch.countDown(); }); - messageStore.addMessageToGroup("testMessagePollMessagesAddedAfterStart", new GenericMessage<>("1")); - messageStore.addMessageToGroup("testMessagePollMessagesAddedAfterStart", new GenericMessage<>("2")); - assertThat(latch.await(3, TimeUnit.SECONDS)) - .as("Expected Postgres notification within 3 seconds") - .isTrue(); + messageStore.addMessageToGroup(groupId, new GenericMessage<>("1")); + messageStore.addMessageToGroup(groupId, new GenericMessage<>("2")); + waitForNotificationAndAssert(latch); } finally { postgresChannelMessageTableSubscriber.stop(); @@ -134,21 +150,16 @@ public void testMessagePollMessagesAddedAfterStart() throws Exception { public void testMessagePollMessagesAddedBeforeStart() throws InterruptedException { CountDownLatch latch = new CountDownLatch(2); List payloads = new ArrayList<>(); - PostgresSubscribableChannel channel = - new PostgresSubscribableChannel(messageStore, - "testMessagePollMessagesAddedBeforeStart", - postgresChannelMessageTableSubscriber); - channel.subscribe(message -> { + + postgresSubscribableChannel.subscribe(message -> { payloads.add(message.getPayload()); latch.countDown(); }); - messageStore.addMessageToGroup("testMessagePollMessagesAddedBeforeStart", new GenericMessage<>("1")); - messageStore.addMessageToGroup("testMessagePollMessagesAddedBeforeStart", new GenericMessage<>("2")); + messageStore.addMessageToGroup(groupId, new GenericMessage<>("1")); + messageStore.addMessageToGroup(groupId, new GenericMessage<>("2")); postgresChannelMessageTableSubscriber.start(); try { - assertThat(latch.await(3, TimeUnit.SECONDS)) - .as("Expected Postgres notification within 3 seconds") - .isTrue(); + waitForNotificationAndAssert(latch); } finally { postgresChannelMessageTableSubscriber.stop(); @@ -156,6 +167,86 @@ public void testMessagePollMessagesAddedBeforeStart() throws InterruptedExceptio assertThat(payloads).containsExactly("1", "2"); } + @Test + void testMessagesDispatchedInTransaction() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(2); + postgresSubscribableChannel.setTransactionManager(transactionManager); + + postgresChannelMessageTableSubscriber.start(); + try { + postgresSubscribableChannel.subscribe(message -> { + try { + throw new RuntimeException("An error has occurred"); + } + finally { + latch.countDown(); + } + }); + + messageStore.addMessageToGroup(groupId, new GenericMessage<>("1")); + messageStore.addMessageToGroup(groupId, new GenericMessage<>("2")); + + waitForNotificationAndAssert(latch); + } + finally { + postgresChannelMessageTableSubscriber.stop(); + } + + assertThat(messageStore.messageGroupSize(groupId)).isEqualTo(2); + assertThat(messageStore.pollMessageFromGroup(groupId).getPayload()).isEqualTo("1"); + assertThat(messageStore.pollMessageFromGroup(groupId).getPayload()).isEqualTo("2"); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testRetryOnErrorDuringDispatch(boolean transactionsEnabled) throws InterruptedException { + CountDownLatch latch = new CountDownLatch(2); + List payloads = new ArrayList<>(); + AtomicInteger actualTries = new AtomicInteger(); + + int maxAttempts = 2; + postgresSubscribableChannel.setRetryTemplate(RetryTemplate.builder().maxAttempts(maxAttempts).build()); + + if (transactionsEnabled) { + postgresSubscribableChannel.setTransactionManager(transactionManager); + } + + postgresChannelMessageTableSubscriber.start(); + + try { + + postgresSubscribableChannel.subscribe(message -> { + try { + //fail once + if (actualTries.getAndIncrement() == 0) { + throw new RuntimeException("An error has occurred"); + } + payloads.add(message.getPayload()); + } + finally { + latch.countDown(); + } + }); + + messageStore.addMessageToGroup(groupId, new GenericMessage<>("1")); + + waitForNotificationAndAssert(latch); + } + finally { + postgresChannelMessageTableSubscriber.stop(); + } + + assertThat(actualTries.get()).isEqualTo(maxAttempts); + assertThat(payloads).containsExactly("1"); + } + + private static void waitForNotificationAndAssert(CountDownLatch latch) throws InterruptedException { + assertThat(latch.await(3, TimeUnit.SECONDS)) + .as("Expected Postgres notification within 3 seconds") + .isTrue(); + } + + @Configuration @EnableIntegration public static class Config { @@ -181,6 +272,11 @@ DataSourceInitializer dataSourceInitializer(DataSource dataSource) { return dataSourceInitializer; } + @Bean + PlatformTransactionManager transactionManager(DataSource dataSource) { + return new DataSourceTransactionManager(dataSource); + } + @Bean public JdbcChannelMessageStore jdbcChannelMessageStore(DataSource dataSource) { JdbcChannelMessageStore messageStore = new JdbcChannelMessageStore(dataSource);