diff --git a/google-cloud-spanner/clirr-ignored-differences.xml b/google-cloud-spanner/clirr-ignored-differences.xml index 8e7a392302a..c690514b59e 100644 --- a/google-cloud-spanner/clirr-ignored-differences.xml +++ b/google-cloud-spanner/clirr-ignored-differences.xml @@ -758,4 +758,12 @@ com/google/cloud/spanner/connection/Connection boolean isKeepTransactionAlive() + + + + 7012 + com/google/cloud/spanner/SpannerOptions$CloseableExecutorProvider + java.util.concurrent.ExecutorService getExecutorService() + + diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/CachedMinMaxThreadsExecutor.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/CachedMinMaxThreadsExecutor.java new file mode 100644 index 00000000000..9d0da39d14c --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/CachedMinMaxThreadsExecutor.java @@ -0,0 +1,110 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.spanner; + +import com.google.common.util.concurrent.Uninterruptibles; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.RejectedExecutionHandler; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nonnull; + +/** + * {@link ThreadPoolExecutor} that uses a cached thread pool with a min and a max number of threads. + */ +public class CachedMinMaxThreadsExecutor extends ThreadPoolExecutor { + + /** + * Creates a new {@link ExecutorService} that uses a cached thread pool with a min and a max + * number of threads for the underlying thread pool. + */ + public static ExecutorService newCachedMinMaxThreadPoolExecutor( + int numCoreThreads, + int maxNumThreads, + long keepAliveTime, + TimeUnit timeUnit, + ThreadFactory threadFactory) { + WorkQueue queue = new WorkQueue(); + ThreadPoolExecutor executor = + new CachedMinMaxThreadsExecutor( + numCoreThreads, maxNumThreads, keepAliveTime, timeUnit, queue, threadFactory); + executor.setRejectedExecutionHandler(ForceQueuePolicy.INSTANCE); + queue.setThreadPoolExecutor(executor); + + return executor; + } + + /** Work queue for {@link CachedMinMaxThreadsExecutor}. */ + private static class WorkQueue extends LinkedBlockingQueue { + private ThreadPoolExecutor executor; + + void setThreadPoolExecutor(ThreadPoolExecutor executor) { + this.executor = executor; + } + + @Override + public boolean offer(@Nonnull Runnable work) { + // Calculate the number of running tasks + queued tasks. + int currentTaskCount = executor.getActiveCount() + size(); + // Accept the task if there are more threads in the pool than active tasks. + // Reject it if the current task count occupies all threads in the pool. + // This will trigger the RejectedExecutionHandler to be triggered, which again + // will add the work to the queue. That again will trigger the creation of more + // threads, as long as the thread count won't exceed the maximum thread count. + return currentTaskCount < executor.getPoolSize() && super.offer(work); + } + } + + private static class ForceQueuePolicy implements RejectedExecutionHandler { + private static final ForceQueuePolicy INSTANCE = new ForceQueuePolicy(); + + @Override + public void rejectedExecution(Runnable runnable, ThreadPoolExecutor executor) { + Uninterruptibles.putUninterruptibly(executor.getQueue(), runnable); + } + } + + private final AtomicInteger activeCount = new AtomicInteger(); + + private CachedMinMaxThreadsExecutor( + int corePoolSize, + int maximumPoolSize, + long keepAliveTime, + TimeUnit unit, + BlockingQueue workQueue, + ThreadFactory threadFactory) { + super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory); + } + + @Override + public int getActiveCount() { + return activeCount.get(); + } + + @Override + protected void beforeExecute(Thread t, Runnable r) { + activeCount.incrementAndGet(); + } + + @Override + protected void afterExecute(Runnable r, Throwable t) { + activeCount.decrementAndGet(); + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java index 5756ff64b89..20445d219b5 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java @@ -80,6 +80,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; +import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.ThreadFactory; @@ -522,6 +523,11 @@ public interface CloseableExecutorProvider extends ExecutorProvider, AutoCloseab /** Overridden to suppress the throws declaration of the super interface. */ @Override void close(); + + /** Returns a normal (non-scheduled) {@link ExecutorService}. */ + default ExecutorService getExecutorService() { + return getExecutor(); + } } /** diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/CachedMinMaxThreadExecutorTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/CachedMinMaxThreadExecutorTest.java new file mode 100644 index 00000000000..d716e6557a5 --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/CachedMinMaxThreadExecutorTest.java @@ -0,0 +1,78 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.common.base.Stopwatch; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CachedMinMaxThreadExecutorTest { + + @Test + public void testMinMaxExecutor() throws Exception { + String format = "test-thread-pool-%d"; + ThreadFactory threadFactory = + new ThreadFactoryBuilder().setDaemon(true).setNameFormat(format).build(); + CachedMinMaxThreadsExecutor executor = + (CachedMinMaxThreadsExecutor) + CachedMinMaxThreadsExecutor.newCachedMinMaxThreadPoolExecutor( + 1, 2, 1L, TimeUnit.NANOSECONDS, threadFactory); + CountDownLatch startLatch = new CountDownLatch(2); + CountDownLatch continueLatch = new CountDownLatch(1); + CountDownLatch finishLatch = new CountDownLatch(3); + Callable callable = + () -> { + startLatch.countDown(); + continueLatch.await(); + finishLatch.countDown(); + return null; + }; + executor.submit(callable); + executor.submit(callable); + executor.submit(callable); + + // Wait until 2 of the tasks have started. + assertTrue(startLatch.await(1L, TimeUnit.SECONDS)); + // Verify that we have 2 concurrent threads. + assertEquals(2, executor.getActiveCount()); + assertEquals(2, executor.getPoolSize()); + // Allow the tasks to continue. + continueLatch.countDown(); + // Verify that all three tasks finish. + assertTrue(finishLatch.await(1L, TimeUnit.SECONDS)); + // Verify that the max pool size was 2. + assertEquals(2, executor.getMaximumPoolSize()); + // Verify that the pool scales back down to 1 core thread. + Stopwatch stopwatch = Stopwatch.createStarted(); + while (stopwatch.elapsed(TimeUnit.MILLISECONDS) < 1000 && executor.getPoolSize() > 1) { + Thread.yield(); + } + assertEquals(1, executor.getPoolSize()); + + executor.shutdown(); + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadAsyncTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadAsyncTest.java index dbef7ce29f6..a462cfd3fb4 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadAsyncTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadAsyncTest.java @@ -47,6 +47,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.SynchronousQueue; import org.junit.After; @@ -165,13 +166,32 @@ public void emptyReadAsync() throws Exception { @Test public void pointReadAsync() throws Exception { - ApiFuture row = - client - .singleUse(TimestampBound.strong()) - .readRowAsync(READ_TABLE_NAME, Key.of("k1"), READ_COLUMN_NAMES); - assertThat(row.get()).isNotNull(); - assertThat(row.get().getString(0)).isEqualTo("k1"); - assertThat(row.get().getString(1)).isEqualTo("v1"); + int numThreads = 32; + ExecutorService service = Executors.newFixedThreadPool(32); + List> futures = new ArrayList<>(numThreads); + for (int i = 0; i < numThreads; i++) { + Future future = + service.submit( + () -> { + try { + ApiFuture row = + client + .singleUse(TimestampBound.strong()) + .readRowAsync(READ_TABLE_NAME, Key.of("k1"), READ_COLUMN_NAMES); + assertThat(row.get()).isNotNull(); + assertThat(row.get().getString(0)).isEqualTo("k1"); + assertThat(row.get().getString(1)).isEqualTo("v1"); + } catch (Throwable t) { + throw SpannerExceptionFactory.asSpannerException(t); + } + }); + futures.add(future); + } + service.shutdown(); + for (Future future : futures) { + future.get(); + } + System.out.println("Done"); } @Test