Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "Amazon Simple Queue Service",
"contributor": "thornhillcody",
"description": "Fix SqsAsyncBatchManager excessive batch flushing under heavy load. Fixes [#6374](https://github.com/aws/aws-sdk-java-v2/issues/6374)."
}
5 changes: 5 additions & 0 deletions services/sqs/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@
<artifactId>mockito-junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<scope>test</scope>
</dependency>

</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ public boolean contains(String batchKey) {
return batchContextMap.containsKey(batchKey);
}

public void putScheduledFlush(String batchKey, ScheduledFuture<?> scheduledFlush) {
batchContextMap.get(batchKey).putScheduledFlush(scheduledFlush);
public void cancelAndReplaceScheduledFlush(String batchKey, ScheduledFuture<?> scheduledFlush) {
batchContextMap.get(batchKey).cancelAndReplaceScheduledFlush(scheduledFlush);
}

public void forEach(BiConsumer<String, RequestBatchBuffer<RequestT, ResponseT>> action) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
@SdkInternalApi
public final class RequestBatchBuffer<RequestT, ResponseT> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we please add some test cases in RequestBatchBufferTest for this fix ?

private final Object flushLock = new Object();
private final Object scheduledFlushLock = new Object();

private final Map<String, BatchingExecutionContext<RequestT, ResponseT>> idToBatchContext;
private final int maxBatchItems;
Expand Down Expand Up @@ -144,12 +145,20 @@ private String nextBatchEntry() {
return Integer.toString(nextBatchEntry++);
}

public void putScheduledFlush(ScheduledFuture<?> scheduledFlush) {
this.scheduledFlush = scheduledFlush;
public void cancelAndReplaceScheduledFlush(ScheduledFuture<?> scheduledFlush) {
// Locking the cancellation and replacement of the scheduledFlush ensures that there is only one active.
synchronized (scheduledFlushLock) {
if (this.scheduledFlush != null) {
cancelScheduledFlush();
}
this.scheduledFlush = scheduledFlush;
}
}

public void cancelScheduledFlush() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also ad synchronization here since cancelScheduledFlush() can be independently call

    public void cancelScheduledFlush() {
        synchronized (scheduledFlushLock) {
            scheduledFlush.cancel(false);
        }
    }

scheduledFlush.cancel(false);
synchronized (scheduledFlushLock) {
scheduledFlush.cancel(false);
}
}

public Collection<CompletableFuture<ResponseT>> responses() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,9 @@ protected abstract CompletableFuture<BatchResponseT> batchAndSend(List<Identifia

private void manualFlushBuffer(String batchKey,
Map<String, BatchingExecutionContext<RequestT, ResponseT>> flushableRequests) {
requestsAndResponsesMaps.cancelScheduledFlush(batchKey);
flushBuffer(batchKey, flushableRequests);
requestsAndResponsesMaps.putScheduledFlush(batchKey,
scheduleBufferFlush(batchKey,
requestsAndResponsesMaps.cancelAndReplaceScheduledFlush(batchKey,
scheduleBufferFlush(batchKey,
sendRequestFrequency.toMillis(),
scheduledExecutor));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.services.sqs.batchmanager;

import static com.github.tomakehurst.wiremock.client.WireMock.*;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.assertj.core.api.Assertions.assertThat;

import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
import com.github.tomakehurst.wiremock.verification.LoggedRequest;
import com.google.common.util.concurrent.RateLimiter;
import java.net.URI;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.services.sqs.SqsAsyncClient;


/**
* Tests the batching efficiency of {@link SqsAsyncBatchManager} under various load scenarios.
*/
public class BatchingEfficiencyUnderLoadTest {

private static final String QUEUE_URL = "https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue";
private static final int CONCURRENT_THREADS = 50;
private static final int MAX_BATCH_SIZE = 10;
private static final int SEND_FREQUENCY_MILLIS = 5;

@RegisterExtension
static WireMockExtension wireMock = WireMockExtension.newInstance()
.options(wireMockConfig().dynamicPort())
.configureStaticDsl(true)
.build();

private SqsAsyncClient client;
private SqsAsyncBatchManager batchManager;

@BeforeEach
void setUp() {
client = SqsAsyncClient.builder()
.endpointOverride(URI.create("http://localhost:" + wireMock.getPort()))
.checksumValidationEnabled(false)
.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create("key", "secret")))
.build();

batchManager = SqsAsyncBatchManager.builder()
.client(client)
.scheduledExecutor(Executors.newScheduledThreadPool(10))
.overrideConfiguration(config -> config
.sendRequestFrequency(Duration.ofMillis(SEND_FREQUENCY_MILLIS))
.maxBatchSize(MAX_BATCH_SIZE))
.build();
}

@AfterEach
void tearDown() {
batchManager.close();
client.close();
}

/**
* Test runs heavy load and expects average batch sizes to be close to max.
*/
@Test
void sendMessage_whenHighLoadScenario_shouldEfficientlyBatchMessages() throws Exception {
int expectedBatchSize = 25; // more than double the actual max of 10
int rateLimit = 1000 / SEND_FREQUENCY_MILLIS * expectedBatchSize;
int messageCount = rateLimit * 2; // run it for 2 seconds
runThroughputTest(messageCount, rateLimit);

// Then: Verify messages were efficiently batched
List<LoggedRequest> batchRequests = findAll(postRequestedFor(anyUrl()));

// Calculate batching metrics
List<Integer> batchSizes = batchRequests.stream()
.map(req -> req.getBodyAsString().split("\"Id\"").length - 1)
.collect(Collectors.toList());

double avgBatchSize = batchSizes.stream()
.mapToInt(Integer::intValue)
.average()
.orElse(0);

double fullBatchRatio = batchSizes.stream()
.filter(size -> size >= 9)
.count() / (double) batchSizes.size();

// Assert efficient batching
assertThat(avgBatchSize)
.as("Average batch size")
.isGreaterThan(8.0);


assertThat(fullBatchRatio)
.as("Ratio of nearly full batches (9-10 messages)")
.isGreaterThan(0.8);

assertThat((double)batchRequests.size())
.as("Total batch requests for %d messages", messageCount)
.isLessThan(messageCount / 5d);
}

/**
* Test runs a load that should cause an average batch size of 5.
*/
@Test
void sendMessage_whenMediumLoadScenario_shouldCreateHalfSizeBatches() throws Exception {
int expectedBatchSize = 5;
int rateLimit = 1000 / SEND_FREQUENCY_MILLIS * expectedBatchSize;
int messageCount = rateLimit * 2; // run it for 2 seconds
runThroughputTest(messageCount, rateLimit);

// Then: Verify batches were roughly half max size
List<LoggedRequest> batchRequests = findAll(postRequestedFor(anyUrl()));

// Calculate batching metrics
List<Integer> batchSizes = batchRequests.stream()
.map(req -> req.getBodyAsString().split("\"Id\"").length - 1)
.collect(Collectors.toList());

double avgBatchSize = batchSizes.stream()
.mapToInt(Integer::intValue)
.average()
.orElse(0);

// Assert batch expected range
assertThat(avgBatchSize)
.as("Average batch size")
.isLessThan(7.0)
.isGreaterThan(3.0);

assertThat((double)batchRequests.size())
.as("Total batch requests for %d messages", messageCount)
.isLessThan(messageCount / 3d);
}

@Test
void sendMessage_whenLowLoadScenario_shouldCreateSmallBatches() throws Exception {
int expectedBatchSize = 1;
int rateLimit = 1000 / SEND_FREQUENCY_MILLIS * expectedBatchSize;
int messageCount = rateLimit * 2; // run it for 2 seconds
runThroughputTest(messageCount, rateLimit);

// Then: Verify batches were roughly half max size
List<LoggedRequest> batchRequests = findAll(postRequestedFor(anyUrl()));

// Calculate batching metrics
List<Integer> batchSizes = batchRequests.stream()
.map(req -> req.getBodyAsString().split("\"Id\"").length - 1)
.collect(Collectors.toList());

double avgBatchSize = batchSizes.stream()
.mapToInt(Integer::intValue)
.average()
.orElse(0);

// Assert batch expected range
assertThat(avgBatchSize)
.as("Average batch size")
.isLessThan(2.0);

assertThat((double)batchRequests.size())
.as("Total batch requests for %d messages", messageCount)
.isGreaterThan(messageCount * .5);
}

private void runThroughputTest(int messageCount, int rateLimit) throws InterruptedException {
// Given: SQS returns success for batch requests
stubFor(post(anyUrl())
.willReturn(aResponse()
.withStatus(200)
.withBody("{\"Successful\": []}")));

// When: Send rateLimit messages per second concurrently (using 50 threads)
ExecutorService executor = Executors.newFixedThreadPool(CONCURRENT_THREADS);

// Rate limit to spread it out over a couple seconds; enough time to make
// any orphaned scheduled flushes obvious.
RateLimiter rateLimiter = RateLimiter.create(rateLimit);

for (int i = 0; i < messageCount; i++) {
String messageBody = String.valueOf(i);
rateLimiter.acquire();
executor.execute(() -> {
try {
batchManager.sendMessage(builder ->
builder.queueUrl(QUEUE_URL)
.messageBody(messageBody));
} catch (Exception ignored) {
// Test will fail on assertions if messages aren't sent
}
});
}

executor.shutdown();
executor.awaitTermination(10, TimeUnit.SECONDS);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,12 @@ void whenMaxBufferSizeReachedThenThrowException() {
}

@Test
void whenPutScheduledFlushThenFlushIsSet() {
void whenCancelAndReplaceScheduledFlushThenFlushIsSetAndOldFlushIsCanceled() {
batchBuffer = new RequestBatchBuffer<>(scheduledFlush, 10, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize);
ScheduledFuture<?> newScheduledFlush = mock(ScheduledFuture.class);
batchBuffer.putScheduledFlush(newScheduledFlush);
batchBuffer.cancelAndReplaceScheduledFlush(newScheduledFlush);
assertNotNull(newScheduledFlush);
verify(scheduledFlush).cancel(false);
}

@Test
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a new Junit Test class testing end to end something like

import static com.github.tomakehurst.wiremock.client.WireMock.*;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.assertj.core.api.Assertions.assertThat;

import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
import com.github.tomakehurst.wiremock.verification.LoggedRequest;
import java.net.URI;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.services.sqs.SqsAsyncClient;

public class HighThroughputBatchingTest {

    private static final String QUEUE_URL = "https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue";
    private static final int MESSAGE_COUNT = 5000;
    private static final int CONCURRENT_THREADS = 50;
    private static final int MAX_BATCH_SIZE = 10;

    @RegisterExtension
    static WireMockExtension wireMock = WireMockExtension.newInstance()
                                                         .options(wireMockConfig().dynamicPort())
                                                         .configureStaticDsl(true)
                                                         .build();

    private SqsAsyncClient client;
    private SqsAsyncBatchManager batchManager;

    @BeforeEach
    void setUp() {
        client = SqsAsyncClient.builder()
                               .endpointOverride(URI.create("http://localhost:" + wireMock.getPort()))
                               .checksumValidationEnabled(false)
                               .credentialsProvider(StaticCredentialsProvider.create(
                                   AwsBasicCredentials.create("key", "secret")))
                               .build();

        batchManager = SqsAsyncBatchManager.builder()
                                           .client(client)
                                           .scheduledExecutor(Executors.newScheduledThreadPool(10))
                                           .overrideConfiguration(config -> config
                                               .sendRequestFrequency(Duration.ofMillis(10))
                                               .maxBatchSize(MAX_BATCH_SIZE))
                                           .build();
    }

    @AfterEach
    void tearDown() {
        batchManager.close();
        client.close();
    }

    @Test
    void shouldEfficientlyBatchMessagesUnderHighLoad() throws Exception {
        // Given: SQS returns success for batch requests
        stubFor(post(anyUrl())
                    .willReturn(aResponse()
                                    .withStatus(200)
                                    .withBody("{\"Successful\": []}")));

        // When: Send 5000 messages concurrently using 50 threads
        ExecutorService executor = Executors.newFixedThreadPool(CONCURRENT_THREADS);
        CountDownLatch startSignal = new CountDownLatch(1);

        for (int i = 0; i < MESSAGE_COUNT; i++) {
            final String messageBody = String.valueOf(i);
            executor.submit(() -> {
                try {
                    startSignal.await(); // Wait to start all at once
                    batchManager.sendMessage(builder ->
                                                 builder.queueUrl(QUEUE_URL)
                                                        .messageBody(messageBody));
                } catch (Exception ignored) {
                    // Test will fail on assertions if messages aren't sent
                }
            });
        }

        startSignal.countDown(); // Fire all threads simultaneously
        executor.shutdown();
        executor.awaitTermination(10, TimeUnit.SECONDS);

        // Allow batch manager to complete processing
        Thread.sleep(2000);

        // Then: Verify messages were efficiently batched
        List<LoggedRequest> batchRequests = findAll(postRequestedFor(anyUrl()));

        // Calculate batching metrics
        List<Integer> batchSizes = batchRequests.stream()
                                                .map(req -> req.getBodyAsString().split("\"Id\"").length - 1)
                                                .collect(Collectors.toList());

        double avgBatchSize = batchSizes.stream()
                                        .mapToInt(Integer::intValue)
                                        .average()
                                        .orElse(0);

        double fullBatchRatio = batchSizes.stream()
                                          .filter(size -> size >= 9)
                                          .count() / (double) batchSizes.size();

        // Assert efficient batching
        assertThat(avgBatchSize)
            .as("Average batch size")
            .isGreaterThan(8.0);

        assertThat(fullBatchRatio)
            .as("Ratio of nearly full batches (9-10 messages)")
            .isGreaterThan(0.8);

        assertThat(batchRequests.size())
            .as("Total batch requests for %d messages", MESSAGE_COUNT)
            .isLessThan(MESSAGE_COUNT / 5);
    }

   //please add more test cases with different combination of maxBatchSize(too low 1 and max 10) and sendFrequency(too long and too short  )

}

Copy link
Contributor Author

@thornhillcody thornhillcody Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used this as a base with some changes to make it work. Added a rate limiter to spread out traffic and a small batch frequency to reduce the amount of time required to run the test. Also did medium/small batch size tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[INFO] Running software.amazon.awssdk.services.sqs.batchmanager.BatchingEfficiencyUnderLoadTest
[INFO] Tests run: 3, Failures: 0, Errors: 0, Skipped: 0, Time elapsed: 6.365 s -- in software.amazon.awssdk.services.sqs.batchmanager.BatchingEfficiencyUnderLoadTest

Takes 6 seconds (2 per test) as it is.

Expand Down Expand Up @@ -188,6 +189,46 @@ void testFlushWhenCumulativePayloadExceedsMaxSize() {
}


@Test
void whenSequentialCancelAndReplaceScheduledFlushThenEachPreviousFlushIsCanceled() {
batchBuffer = new RequestBatchBuffer<>(scheduledFlush, 10, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize);

// Create a sequence of mock scheduled futures
ScheduledFuture<?> flush1 = mock(ScheduledFuture.class);
ScheduledFuture<?> flush2 = mock(ScheduledFuture.class);
ScheduledFuture<?> flush3 = mock(ScheduledFuture.class);

// First replacement - should cancel the initial scheduledFlush
batchBuffer.cancelAndReplaceScheduledFlush(flush1);
verify(scheduledFlush, times(1)).cancel(false);

// Second replacement - should cancel flush1
batchBuffer.cancelAndReplaceScheduledFlush(flush2);
verify(flush1, times(1)).cancel(false);

// Verify flush2 has not been canceled (it's the current one)
verify(flush2, never()).cancel(false);

// Verify buffer is still functional
CompletableFuture<String> response = new CompletableFuture<>();
batchBuffer.put("test-request", response);
assertEquals(1, batchBuffer.responses().size());
}

@Test
void whenCancelAndReplaceScheduledFlushWithNullInitialFlushThenNoExceptionThrown() {
// Create buffer with null initial flush
batchBuffer = new RequestBatchBuffer<>(null, 10, MAX_SEND_MESSAGE_PAYLOAD_SIZE_BYTES, maxBufferSize);

ScheduledFuture<?> newFlush = mock(ScheduledFuture.class);

// Should not throw exception when initial flush is null
assertDoesNotThrow(() -> batchBuffer.cancelAndReplaceScheduledFlush(newFlush));

// Verify newFlush is not canceled (it's the current one)
verify(newFlush, never()).cancel(false);
}

private String createLargeString(char ch, int length) {
StringBuilder sb = new StringBuilder(length);
for (int i = 0; i < length; i++) {
Expand All @@ -198,4 +239,4 @@ private String createLargeString(char ch, int length) {



}
}
Loading