diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java index f6f17770f..e0f8aa738 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -20,6 +20,7 @@ import io.netty.util.concurrent.DefaultThreadFactory; import java.sql.SQLException; +import java.util.List; import java.util.Properties; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -180,19 +181,30 @@ public Properties getClientInfo() { @Override public void close() throws SQLException { - clientHandler.close(); - if (executorService != null) { - executorService.shutdown(); + Exception topLevelException = null; + try { + AutoCloseables.close(List.copyOf(statementMap.values())); + } catch (final Exception e) { + topLevelException = e; } - try { AutoCloseables.close(clientHandler); + if (executorService != null) { + executorService.shutdown(); + } allocator.getChildAllocators().forEach(AutoCloseables::closeNoChecked); AutoCloseables.close(allocator); - super.close(); } catch (final Exception e) { - throw AvaticaConnection.HELPER.createException(e.getMessage(), e); + if (topLevelException == null) { + topLevelException = e; + } else { + topLevelException.addSuppressed(e); + } + } + if (topLevelException != null) { + throw AvaticaConnection.HELPER.createException( + topLevelException.getMessage(), topLevelException); } } diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java index 72e4b222a..f1af2732f 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java @@ -16,6 +16,7 @@ */ package org.apache.arrow.driver.jdbc; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -26,6 +27,7 @@ import java.sql.Driver; import java.sql.DriverManager; import java.sql.SQLException; +import java.sql.Statement; import java.util.Properties; import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication; import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler; @@ -622,4 +624,40 @@ public void testJdbcDriverVersionIntegration() throws Exception { "Expected: " + expectedUserAgent + " but found: " + actualUserAgent); } } + + @Test + public void testStatementsClosedOnConnectionClose() throws Exception { + // create a connection + final Properties properties = new Properties(); + properties.put(ArrowFlightConnectionProperty.HOST.camelName(), "localhost"); + properties.put( + ArrowFlightConnectionProperty.PORT.camelName(), FLIGHT_SERVER_TEST_EXTENSION.getPort()); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), passTest); + properties.put("useEncryption", false); + + Connection connection = + DriverManager.getConnection( + "jdbc:arrow-flight-sql://" + + FLIGHT_SERVER_TEST_EXTENSION.getHost() + + ":" + + FLIGHT_SERVER_TEST_EXTENSION.getPort(), + properties); + + // create some statements + int numStatements = 3; + Statement[] statements = new Statement[numStatements]; + for (int i = 0; i < numStatements; i++) { + statements[i] = connection.createStatement(); + assertFalse(statements[i].isClosed()); + } + + // close the connection + connection.close(); + + // assert the statements are closed + for (int i = 0; i < numStatements; i++) { + assertTrue(statements[i].isClosed()); + } + } }