diff --git a/google-cloud-spanner/clirr-ignored-differences.xml b/google-cloud-spanner/clirr-ignored-differences.xml
index 8e7a392302a..c690514b59e 100644
--- a/google-cloud-spanner/clirr-ignored-differences.xml
+++ b/google-cloud-spanner/clirr-ignored-differences.xml
@@ -758,4 +758,12 @@
com/google/cloud/spanner/connection/Connection
boolean isKeepTransactionAlive()
+
+
+
+ 7012
+ com/google/cloud/spanner/SpannerOptions$CloseableExecutorProvider
+ java.util.concurrent.ExecutorService getExecutorService()
+
+
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/CachedMinMaxThreadsExecutor.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/CachedMinMaxThreadsExecutor.java
new file mode 100644
index 00000000000..9d0da39d14c
--- /dev/null
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/CachedMinMaxThreadsExecutor.java
@@ -0,0 +1,110 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.spanner;
+
+import com.google.common.util.concurrent.Uninterruptibles;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.RejectedExecutionHandler;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.Nonnull;
+
+/**
+ * {@link ThreadPoolExecutor} that uses a cached thread pool with a min and a max number of threads.
+ */
+public class CachedMinMaxThreadsExecutor extends ThreadPoolExecutor {
+
+ /**
+ * Creates a new {@link ExecutorService} that uses a cached thread pool with a min and a max
+ * number of threads for the underlying thread pool.
+ */
+ public static ExecutorService newCachedMinMaxThreadPoolExecutor(
+ int numCoreThreads,
+ int maxNumThreads,
+ long keepAliveTime,
+ TimeUnit timeUnit,
+ ThreadFactory threadFactory) {
+ WorkQueue queue = new WorkQueue();
+ ThreadPoolExecutor executor =
+ new CachedMinMaxThreadsExecutor(
+ numCoreThreads, maxNumThreads, keepAliveTime, timeUnit, queue, threadFactory);
+ executor.setRejectedExecutionHandler(ForceQueuePolicy.INSTANCE);
+ queue.setThreadPoolExecutor(executor);
+
+ return executor;
+ }
+
+ /** Work queue for {@link CachedMinMaxThreadsExecutor}. */
+ private static class WorkQueue extends LinkedBlockingQueue {
+ private ThreadPoolExecutor executor;
+
+ void setThreadPoolExecutor(ThreadPoolExecutor executor) {
+ this.executor = executor;
+ }
+
+ @Override
+ public boolean offer(@Nonnull Runnable work) {
+ // Calculate the number of running tasks + queued tasks.
+ int currentTaskCount = executor.getActiveCount() + size();
+ // Accept the task if there are more threads in the pool than active tasks.
+ // Reject it if the current task count occupies all threads in the pool.
+ // This will trigger the RejectedExecutionHandler to be triggered, which again
+ // will add the work to the queue. That again will trigger the creation of more
+ // threads, as long as the thread count won't exceed the maximum thread count.
+ return currentTaskCount < executor.getPoolSize() && super.offer(work);
+ }
+ }
+
+ private static class ForceQueuePolicy implements RejectedExecutionHandler {
+ private static final ForceQueuePolicy INSTANCE = new ForceQueuePolicy();
+
+ @Override
+ public void rejectedExecution(Runnable runnable, ThreadPoolExecutor executor) {
+ Uninterruptibles.putUninterruptibly(executor.getQueue(), runnable);
+ }
+ }
+
+ private final AtomicInteger activeCount = new AtomicInteger();
+
+ private CachedMinMaxThreadsExecutor(
+ int corePoolSize,
+ int maximumPoolSize,
+ long keepAliveTime,
+ TimeUnit unit,
+ BlockingQueue workQueue,
+ ThreadFactory threadFactory) {
+ super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory);
+ }
+
+ @Override
+ public int getActiveCount() {
+ return activeCount.get();
+ }
+
+ @Override
+ protected void beforeExecute(Thread t, Runnable r) {
+ activeCount.incrementAndGet();
+ }
+
+ @Override
+ protected void afterExecute(Runnable r, Throwable t) {
+ activeCount.decrementAndGet();
+ }
+}
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java
index 5756ff64b89..20445d219b5 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java
@@ -80,6 +80,7 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
+import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadFactory;
@@ -522,6 +523,11 @@ public interface CloseableExecutorProvider extends ExecutorProvider, AutoCloseab
/** Overridden to suppress the throws declaration of the super interface. */
@Override
void close();
+
+ /** Returns a normal (non-scheduled) {@link ExecutorService}. */
+ default ExecutorService getExecutorService() {
+ return getExecutor();
+ }
}
/**
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/CachedMinMaxThreadExecutorTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/CachedMinMaxThreadExecutorTest.java
new file mode 100644
index 00000000000..d716e6557a5
--- /dev/null
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/CachedMinMaxThreadExecutorTest.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.google.common.base.Stopwatch;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class CachedMinMaxThreadExecutorTest {
+
+ @Test
+ public void testMinMaxExecutor() throws Exception {
+ String format = "test-thread-pool-%d";
+ ThreadFactory threadFactory =
+ new ThreadFactoryBuilder().setDaemon(true).setNameFormat(format).build();
+ CachedMinMaxThreadsExecutor executor =
+ (CachedMinMaxThreadsExecutor)
+ CachedMinMaxThreadsExecutor.newCachedMinMaxThreadPoolExecutor(
+ 1, 2, 1L, TimeUnit.NANOSECONDS, threadFactory);
+ CountDownLatch startLatch = new CountDownLatch(2);
+ CountDownLatch continueLatch = new CountDownLatch(1);
+ CountDownLatch finishLatch = new CountDownLatch(3);
+ Callable callable =
+ () -> {
+ startLatch.countDown();
+ continueLatch.await();
+ finishLatch.countDown();
+ return null;
+ };
+ executor.submit(callable);
+ executor.submit(callable);
+ executor.submit(callable);
+
+ // Wait until 2 of the tasks have started.
+ assertTrue(startLatch.await(1L, TimeUnit.SECONDS));
+ // Verify that we have 2 concurrent threads.
+ assertEquals(2, executor.getActiveCount());
+ assertEquals(2, executor.getPoolSize());
+ // Allow the tasks to continue.
+ continueLatch.countDown();
+ // Verify that all three tasks finish.
+ assertTrue(finishLatch.await(1L, TimeUnit.SECONDS));
+ // Verify that the max pool size was 2.
+ assertEquals(2, executor.getMaximumPoolSize());
+ // Verify that the pool scales back down to 1 core thread.
+ Stopwatch stopwatch = Stopwatch.createStarted();
+ while (stopwatch.elapsed(TimeUnit.MILLISECONDS) < 1000 && executor.getPoolSize() > 1) {
+ Thread.yield();
+ }
+ assertEquals(1, executor.getPoolSize());
+
+ executor.shutdown();
+ }
+}
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadAsyncTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadAsyncTest.java
index dbef7ce29f6..a462cfd3fb4 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadAsyncTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadAsyncTest.java
@@ -47,6 +47,7 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.SynchronousQueue;
import org.junit.After;
@@ -165,13 +166,32 @@ public void emptyReadAsync() throws Exception {
@Test
public void pointReadAsync() throws Exception {
- ApiFuture row =
- client
- .singleUse(TimestampBound.strong())
- .readRowAsync(READ_TABLE_NAME, Key.of("k1"), READ_COLUMN_NAMES);
- assertThat(row.get()).isNotNull();
- assertThat(row.get().getString(0)).isEqualTo("k1");
- assertThat(row.get().getString(1)).isEqualTo("v1");
+ int numThreads = 32;
+ ExecutorService service = Executors.newFixedThreadPool(32);
+ List> futures = new ArrayList<>(numThreads);
+ for (int i = 0; i < numThreads; i++) {
+ Future> future =
+ service.submit(
+ () -> {
+ try {
+ ApiFuture row =
+ client
+ .singleUse(TimestampBound.strong())
+ .readRowAsync(READ_TABLE_NAME, Key.of("k1"), READ_COLUMN_NAMES);
+ assertThat(row.get()).isNotNull();
+ assertThat(row.get().getString(0)).isEqualTo("k1");
+ assertThat(row.get().getString(1)).isEqualTo("v1");
+ } catch (Throwable t) {
+ throw SpannerExceptionFactory.asSpannerException(t);
+ }
+ });
+ futures.add(future);
+ }
+ service.shutdown();
+ for (Future> future : futures) {
+ future.get();
+ }
+ System.out.println("Done");
}
@Test