diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSource.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSource.java index 64e2b705b40..cfa8aeca9a3 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSource.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSource.java @@ -41,17 +41,18 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static java.lang.System.nanoTime; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.CompletableFuture.runAsync; import static java.util.concurrent.CompletableFuture.supplyAsync; public final class JdbcPageSource implements ConnectorPageSource { private static final Logger log = Logger.get(JdbcPageSource.class); - private static final CompletableFuture UNINITIALIZED_RESULT_SET_FUTURE = CompletableFuture.completedFuture(null); private final List columnHandles; private final ReadFunction[] readFunctions; @@ -62,12 +63,11 @@ public final class JdbcPageSource private final ObjectReadFunction[] objectReadFunctions; private final JdbcClient jdbcClient; - private final ExecutorService executor; private final Connection connection; private final PreparedStatement statement; private final AtomicLong readTimeNanos = new AtomicLong(0); private final PageBuilder pageBuilder; - private CompletableFuture resultSetFuture; + private CompletableFuture resultSetFuture; @Nullable private ResultSet resultSet; private boolean finished; @@ -77,7 +77,6 @@ public final class JdbcPageSource public JdbcPageSource(JdbcClient jdbcClient, ExecutorService executor, ConnectorSession session, JdbcSplit split, BaseJdbcConnectorTableHandle table, List columnHandles) { this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null"); - this.executor = requireNonNull(executor, "executor is null"); this.columnHandles = ImmutableList.copyOf(columnHandles); readFunctions = new ReadFunction[columnHandles.size()]; @@ -133,7 +132,22 @@ else if (javaType == Slice.class) { pageBuilder = new PageBuilder(columnHandles.stream() .map(JdbcColumnHandle::getColumnType) .collect(toImmutableList())); - resultSetFuture = UNINITIALIZED_RESULT_SET_FUTURE; + resultSetFuture = supplyAsync(() -> { + long start = nanoTime(); + try { + log.debug("Executing: %s", statement); + return statement.executeQuery(); + } + catch (SQLException e) { + throw handleSqlException(e); + } + finally { + readTimeNanos.addAndGet(nanoTime() - start); + } + }, executor).thenAcceptAsync(resultSet -> { + this.resultSet = requireNonNull(resultSet, "resultSet is null"); + buildPageFromResultSet(); + }, directExecutor()); } catch (SQLException | RuntimeException e) { throw handleSqlException(e); @@ -149,41 +163,44 @@ public long getReadTimeNanos() @Override public boolean isFinished() { - return finished; + return finished && pageBuilder.isEmpty(); } @Override public SourcePage getNextSourcePage() { - verify(pageBuilder.isEmpty(), "Expected pageBuilder to be empty"); - if (finished) { + if (!resultSetFuture.isDone()) { return null; } + try { - if (resultSetFuture == UNINITIALIZED_RESULT_SET_FUTURE && resultSet == null) { - checkState(!closed, "page source is closed"); - resultSetFuture = supplyAsync(() -> { - long start = nanoTime(); - try { - log.debug("Executing: %s", statement); - return statement.executeQuery(); - } - catch (SQLException e) { - throw handleSqlException(e); - } - finally { - readTimeNanos.addAndGet(nanoTime() - start); - } - }, executor); - } - if (resultSet == null) { - if (!resultSetFuture.isDone()) { - return null; - } - resultSet = requireNonNull(getFutureValue(resultSetFuture), "resultSet is null"); - } + // throw exception + getFutureValue(resultSetFuture); checkState(!closed, "page source is closed"); + } + catch (Throwable throwable) { + throw handleSqlException(throwable); + } + + if (isFinished()) { + return null; + } + + Page page = pageBuilder.build(); + pageBuilder.reset(); + + if (!finished) { + resultSetFuture = runAsync(this::buildPageFromResultSet, directExecutor()); + } + + return SourcePage.create(page); + } + + private void buildPageFromResultSet() + { + verify(pageBuilder.isEmpty(), "Expected pageBuilder to be empty"); + try { while (!pageBuilder.isFull() && resultSet.next()) { pageBuilder.declarePosition(); completedPositions++; @@ -215,13 +232,9 @@ else if (sliceReadFunctions[i] != null) { finished = true; } } - catch (SQLException | RuntimeException e) { + catch (SQLException e) { throw handleSqlException(e); } - - Page page = pageBuilder.build(); - pageBuilder.reset(); - return SourcePage.create(page); } @Override @@ -280,12 +293,12 @@ public void close() resultSet = null; } - private RuntimeException handleSqlException(Exception e) + private RuntimeException handleSqlException(Throwable e) { try { close(); } - catch (Exception closeException) { + catch (Throwable closeException) { // Self-suppression not permitted if (e != closeException) { e.addSuppressed(closeException);