diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt index b232577ee37..c94312fcd22 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt @@ -162,9 +162,11 @@ add_arrow_test(odbc_spi_impl_test accessors/time_array_accessor_test.cc accessors/timestamp_array_accessor_test.cc flight_sql_connection_test.cc + flight_sql_stream_chunk_buffer_test.cc parse_table_types_test.cc json_converter_test.cc record_batch_transformer_test.cc util_test.cc EXTRA_LINK_LIBS - arrow_odbc_spi_impl) + arrow_odbc_spi_impl + arrow_flight_testing_shared) diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc index bdf7f71589c..b0090a8cf74 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc @@ -44,8 +44,8 @@ class NoOpClientAuthHandler : public ClientAuthHandler { NoOpClientAuthHandler() {} Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override { - // Write a blank string. The server should ignore this and just accept any Handshake - // request. + // The server should ignore this and just accept any Handshake + // request. Some servers do not allow authentication with no handshakes. return outgoing->Write(std::string()); } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc index 479a72f3fea..d8bde79f6ee 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc @@ -100,6 +100,8 @@ inline std::string GetCerts() { return ""; } #endif const std::set BUILT_IN_PROPERTIES = { + FlightSqlConnection::DRIVER, + FlightSqlConnection::DSN, FlightSqlConnection::HOST, FlightSqlConnection::PORT, FlightSqlConnection::USER, @@ -153,14 +155,14 @@ void FlightSqlConnection::Connect(const ConnPropertyMap& properties, auto flight_ssl_configs = LoadFlightSslConfigs(properties); Location location = BuildLocation(properties, missing_attr, flight_ssl_configs); - FlightClientOptions client_options = + client_options_ = BuildFlightClientOptions(properties, missing_attr, flight_ssl_configs); const std::shared_ptr& cookie_factory = GetCookieFactory(); - client_options.middleware.push_back(cookie_factory); + client_options_.middleware.push_back(cookie_factory); std::unique_ptr flight_client; - ThrowIfNotOK(FlightClient::Connect(location, client_options).Value(&flight_client)); + ThrowIfNotOK(FlightClient::Connect(location, client_options_).Value(&flight_client)); std::unique_ptr auth_method = FlightSqlAuthMethod::FromProperties(flight_client, properties); @@ -364,7 +366,7 @@ void FlightSqlConnection::Close() { std::shared_ptr FlightSqlConnection::CreateStatement() { return std::shared_ptr(new FlightSqlStatement( - diagnostics_, *sql_client_, call_options_, metadata_settings_)); + diagnostics_, *sql_client_, client_options_, call_options_, metadata_settings_)); } bool FlightSqlConnection::SetAttribute(Connection::AttributeId attribute, @@ -410,7 +412,7 @@ FlightSqlConnection::FlightSqlConnection(OdbcVersion odbc_version, const std::string& driver_version) : diagnostics_("Apache Arrow", "Flight SQL", odbc_version), odbc_version_(odbc_version), - info_(call_options_, sql_client_, driver_version), + info_(client_options_, call_options_, sql_client_, driver_version), closed_(true) { attribute_[CONNECTION_DEAD] = static_cast(SQL_TRUE); attribute_[LOGIN_TIMEOUT] = static_cast(0); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc index 19149b3c48d..80967b9f200 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc @@ -29,12 +29,12 @@ namespace arrow::flight::sql::odbc { FlightSqlResultSet::FlightSqlResultSet( - FlightSqlClient& flight_sql_client, const FlightCallOptions& call_options, - const std::shared_ptr& flight_info, + FlightSqlClient& flight_sql_client, const FlightClientOptions& client_options, + const FlightCallOptions& call_options, const std::shared_ptr& flight_info, const std::shared_ptr& transformer, Diagnostics& diagnostics, const MetadataSettings& metadata_settings) : metadata_settings_(metadata_settings), - chunk_buffer_(flight_sql_client, call_options, flight_info, + chunk_buffer_(flight_sql_client, client_options, call_options, flight_info, metadata_settings_.chunk_buffer_capacity), transformer_(transformer), metadata_(transformer diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h index 6083b332824..ac2ae80e010 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h @@ -51,6 +51,7 @@ class FlightSqlResultSet : public ResultSet { ~FlightSqlResultSet() override; FlightSqlResultSet(FlightSqlClient& flight_sql_client, + const FlightClientOptions& client_options, const FlightCallOptions& call_options, const std::shared_ptr& flight_info, const std::shared_ptr& transformer, diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc index 30eb1fdf44a..785a04c7b0e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc @@ -41,9 +41,10 @@ using util::ThrowIfNotOK; namespace { -void ClosePreparedStatementIfAny(std::shared_ptr& prepared_statement) { +void ClosePreparedStatementIfAny(std::shared_ptr& prepared_statement, + const FlightCallOptions& options) { if (prepared_statement != nullptr) { - ThrowIfNotOK(prepared_statement->Close()); + ThrowIfNotOK(prepared_statement->Close(options)); prepared_statement.reset(); } } @@ -52,11 +53,13 @@ void ClosePreparedStatementIfAny(std::shared_ptr& prepared_st FlightSqlStatement::FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client, + FlightClientOptions client_options, FlightCallOptions call_options, const MetadataSettings& metadata_settings) : diagnostics_("Apache Arrow", diagnostics.GetDataSourceComponent(), diagnostics.GetOdbcVersion()), sql_client_(sql_client), + client_options_(std::move(client_options)), call_options_(std::move(call_options)), metadata_settings_(metadata_settings) { attribute_[METADATA_ID] = static_cast(SQL_FALSE); @@ -97,7 +100,7 @@ boost::optional FlightSqlStatement::GetAttribute( boost::optional> FlightSqlStatement::Prepare( const std::string& query) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.Prepare(call_options_, query); @@ -111,27 +114,30 @@ boost::optional> FlightSqlStatement::Prepare( } bool FlightSqlStatement::ExecutePrepared() { + // GH-47990 TODO: use DCHECK instead of assert assert(prepared_statement_.get() != nullptr); - Result> result = prepared_statement_->Execute(); + Result> result = + prepared_statement_->Execute(call_options_); + ThrowIfNotOK(result.status()); current_result_set_ = std::make_shared( - sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_, - metadata_settings_); + sql_client_, client_options_, call_options_, result.ValueOrDie(), nullptr, + diagnostics_, metadata_settings_); return true; } bool FlightSqlStatement::Execute(const std::string& query) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.Execute(call_options_, query); ThrowIfNotOK(result.status()); current_result_set_ = std::make_shared( - sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_, - metadata_settings_); + sql_client_, client_options_, call_options_, result.ValueOrDie(), nullptr, + diagnostics_, metadata_settings_); return true; } @@ -146,33 +152,35 @@ std::shared_ptr FlightSqlStatement::GetTables( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* table_type, const ColumnNames& column_names) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); std::vector table_types; if ((catalog_name && *catalog_name == "%") && (schema_name && schema_name->empty()) && (table_name && table_name->empty())) { - current_result_set_ = GetTablesForSQLAllCatalogs( - column_names, call_options_, sql_client_, diagnostics_, metadata_settings_); + current_result_set_ = + GetTablesForSQLAllCatalogs(column_names, client_options_, call_options_, + sql_client_, diagnostics_, metadata_settings_); } else if ((catalog_name && catalog_name->empty()) && (schema_name && *schema_name == "%") && (table_name && table_name->empty())) { - current_result_set_ = - GetTablesForSQLAllDbSchemas(column_names, call_options_, sql_client_, schema_name, - diagnostics_, metadata_settings_); + current_result_set_ = GetTablesForSQLAllDbSchemas( + column_names, client_options_, call_options_, sql_client_, schema_name, + diagnostics_, metadata_settings_); } else if ((catalog_name && catalog_name->empty()) && (schema_name && schema_name->empty()) && (table_name && table_name->empty()) && (table_type && *table_type == "%")) { - current_result_set_ = GetTablesForSQLAllTableTypes( - column_names, call_options_, sql_client_, diagnostics_, metadata_settings_); + current_result_set_ = + GetTablesForSQLAllTableTypes(column_names, client_options_, call_options_, + sql_client_, diagnostics_, metadata_settings_); } else { if (table_type) { ParseTableTypes(*table_type, table_types); } current_result_set_ = GetTablesForGenericUse( - column_names, call_options_, sql_client_, catalog_name, schema_name, table_name, - table_types, diagnostics_, metadata_settings_); + column_names, client_options_, call_options_, sql_client_, catalog_name, + schema_name, table_name, table_types, diagnostics_, metadata_settings_); } return current_result_set_; @@ -199,7 +207,7 @@ std::shared_ptr FlightSqlStatement::GetTables_V3( std::shared_ptr FlightSqlStatement::GetColumns_V2( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* column_name) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetTables( call_options_, catalog_name, schema_name, table_name, true, nullptr); @@ -210,9 +218,9 @@ std::shared_ptr FlightSqlStatement::GetColumns_V2( auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_2, column_name); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } @@ -220,7 +228,7 @@ std::shared_ptr FlightSqlStatement::GetColumns_V2( std::shared_ptr FlightSqlStatement::GetColumns_V3( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* column_name) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetTables( call_options_, catalog_name, schema_name, table_name, true, nullptr); @@ -231,15 +239,15 @@ std::shared_ptr FlightSqlStatement::GetColumns_V3( auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_3, column_name); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } std::shared_ptr FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetXdbcTypeInfo(call_options_); ThrowIfNotOK(result.status()); @@ -249,15 +257,15 @@ std::shared_ptr FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_2, data_type); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } std::shared_ptr FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetXdbcTypeInfo(call_options_); ThrowIfNotOK(result.status()); @@ -267,9 +275,9 @@ std::shared_ptr FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_3, data_type); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h index 36dc245c1d7..3593b2f774d 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h @@ -32,6 +32,7 @@ class FlightSqlStatement : public Statement { private: Diagnostics diagnostics_; std::map attribute_; + FlightClientOptions client_options_; FlightCallOptions call_options_; FlightSqlClient& sql_client_; std::shared_ptr current_result_set_; @@ -46,7 +47,7 @@ class FlightSqlStatement : public Statement { public: FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client, - FlightCallOptions call_options, + FlightClientOptions client_options, FlightCallOptions call_options, const MetadataSettings& metadata_settings); bool SetAttribute(StatementAttributeId attribute, const Attribute& value) override; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc index 2a6bb8970b4..f50cc4cf2f9 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc @@ -69,9 +69,9 @@ void ParseTableTypes(const std::string& table_type, } std::shared_ptr GetTablesForSQLAllCatalogs( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings) { + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetCatalogs(call_options); std::shared_ptr schema; @@ -89,13 +89,15 @@ std::shared_ptr GetTablesForSQLAllCatalogs( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForSQLAllDbSchemas( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* schema_name, Diagnostics& diagnostics, + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* schema_name, Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetDbSchemas(call_options, nullptr, schema_name); @@ -115,14 +117,15 @@ std::shared_ptr GetTablesForSQLAllDbSchemas( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForSQLAllTableTypes( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings) { + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetTableTypes(call_options); std::shared_ptr schema; @@ -140,16 +143,17 @@ std::shared_ptr GetTablesForSQLAllTableTypes( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForGenericUse( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* catalog_name, - const std::string* schema_name, const std::string* table_name, - const std::vector& table_types, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings) { + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* catalog_name, const std::string* schema_name, + const std::string* table_name, const std::vector& table_types, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetTables( call_options, catalog_name, schema_name, table_name, false, &table_types); @@ -168,8 +172,9 @@ std::shared_ptr GetTablesForGenericUse( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h index 31abab91cb5..0c3ad10f97b 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h @@ -40,25 +40,26 @@ void ParseTableTypes(const std::string& table_type, std::vector& table_types); std::shared_ptr GetTablesForSQLAllCatalogs( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings); + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings); std::shared_ptr GetTablesForSQLAllDbSchemas( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* schema_name, Diagnostics& diagnostics, + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* schema_name, Diagnostics& diagnostics, const MetadataSettings& metadata_settings); std::shared_ptr GetTablesForSQLAllTableTypes( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings); + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings); std::shared_ptr GetTablesForGenericUse( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* catalog_name, - const std::string* schema_name, const std::string* table_name, - const std::vector& table_types, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings); + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* catalog_name, const std::string* schema_name, + const std::string* table_name, const std::vector& table_types, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings); } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc index 25bf04ea507..fa1a4d9f064 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc @@ -20,37 +20,67 @@ namespace arrow::flight::sql::odbc { -using arrow::Result; - FlightStreamChunkBuffer::FlightStreamChunkBuffer( - FlightSqlClient& flight_sql_client, const FlightCallOptions& call_options, - const std::shared_ptr& flight_info, size_t queue_capacity) + FlightSqlClient& flight_sql_client, const FlightClientOptions& client_options, + const FlightCallOptions& call_options, const std::shared_ptr& flight_info, + size_t queue_capacity) : queue_(queue_capacity) { - // FIXME: Endpoint iteration should consider endpoints may be at different hosts for (const auto& endpoint : flight_info->endpoints()) { const Ticket& ticket = endpoint.ticket; - auto result = flight_sql_client.DoGet(call_options, ticket); + arrow::Result> result; + std::shared_ptr temp_flight_sql_client; + auto endpoint_locations = endpoint.locations; + if (endpoint_locations.empty()) { + // list of Locations needs to be empty to proceed + result = flight_sql_client.DoGet(call_options, ticket); + } else { + // If it is non-empty, the driver should create a FlightSqlClient to connect to one + // of the specified Locations directly. + + // GH-47117: Currently a new FlightClient will be made for each partition that + // returns a non-empty Location, which is then disposed of. It may be better to + // cache clients because a server may report the same Locations. It would also be + // good to identify when the reported Location is the same as the original + // connection's Location and skip creating a FlightClient in that scenario. + + std::unique_ptr temp_flight_client; + util::ThrowIfNotOK(FlightClient::Connect(endpoint_locations[0], client_options) + .Value(&temp_flight_client)); + temp_flight_sql_client.reset(new FlightSqlClient(std::move(temp_flight_client))); + + result = temp_flight_sql_client->DoGet(call_options, ticket); + } + util::ThrowIfNotOK(result.status()); std::shared_ptr stream_reader_ptr(std::move(result.ValueOrDie())); - BlockingQueue>::Supplier supplier = [=] { + BlockingQueue, + std::shared_ptr>>::Supplier supplier = [=] { auto result = stream_reader_ptr->Next(); bool is_not_ok = !result.ok(); bool is_not_empty = result.ok() && (result.ValueOrDie().data != nullptr); - return boost::make_optional(is_not_ok || is_not_empty, std::move(result)); + // If result is valid, save the temp Flight SQL Client for future stream reader + // call. temp_flight_sql_client is intentionally null if the list of endpoint + // Locations is empty. + // After all data is fetched from reader, the temp client is closed. + return boost::make_optional( + is_not_ok || is_not_empty, + std::make_pair(std::move(result), temp_flight_sql_client)); }; queue_.AddProducer(std::move(supplier)); } } bool FlightStreamChunkBuffer::GetNext(FlightStreamChunk* chunk) { - Result result; - if (!queue_.Pop(&result)) { + std::pair, std::shared_ptr> + closeable_endpoint_stream_pair; + if (!queue_.Pop(&closeable_endpoint_stream_pair)) { return false; } + Result result = closeable_endpoint_stream_pair.first; if (!result.status().ok()) { Close(); throw DriverException(result.status().message()); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h index f59336c984d..772a854eb59 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h @@ -23,11 +23,15 @@ namespace arrow::flight::sql::odbc { +using arrow::Result; + class FlightStreamChunkBuffer { - BlockingQueue> queue_; + BlockingQueue, std::shared_ptr>> + queue_; public: FlightStreamChunkBuffer(FlightSqlClient& flight_sql_client, + const FlightClientOptions& client_options, const FlightCallOptions& call_options, const std::shared_ptr& flight_info, size_t queue_capacity = 5); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc new file mode 100644 index 00000000000..18ae3d8c57e --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/array.h" + +#include "arrow/testing/gtest_util.h" + +#include "arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h" +#include "arrow/flight/sql/odbc/odbc_impl/json_converter.h" +#include "arrow/flight/test_flight_server.h" +#include "arrow/flight/test_util.h" + +#include + +namespace arrow::flight::sql::odbc { + +using arrow::Array; +using arrow::flight::FlightCallOptions; +using arrow::flight::FlightClientOptions; +using arrow::flight::FlightDescriptor; +using arrow::flight::FlightEndpoint; +using arrow::flight::Location; +using arrow::flight::Ticket; +using arrow::flight::sql::FlightSqlClient; + +class FlightStreamChunkBufferTest : public ::testing::Test { + // Sets up two mock servers for each test case. + // This is for testing endpoint iteration only. + + protected: + void SetUp() override { + // Set up server 1 + server1 = std::make_shared(); + ASSERT_OK_AND_ASSIGN(auto location1, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options1(location1); + ASSERT_OK(server1->Init(options1)); + ASSERT_OK_AND_ASSIGN(server_location1, + Location::ForGrpcTcp("localhost", server1->port())); + + // Set up server 2 + server2 = std::make_shared(); + ASSERT_OK_AND_ASSIGN(auto location2, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options2(location2); + ASSERT_OK(server2->Init(options2)); + ASSERT_OK_AND_ASSIGN(server_location2, + Location::ForGrpcTcp("localhost", server2->port())); + + // Make SQL Client that is connected to server 1 + ASSERT_OK_AND_ASSIGN(auto client, arrow::flight::FlightClient::Connect(location1)); + sql_client.reset(new FlightSqlClient(std::move(client))); + } + + void TearDown() override { + ASSERT_OK(server1->Shutdown()); + ASSERT_OK(server1->Wait()); + ASSERT_OK(server2->Shutdown()); + ASSERT_OK(server1->Wait()); + } + + public: + arrow::flight::Location server_location1; + std::shared_ptr server1; + arrow::flight::Location server_location2; + std::shared_ptr server2; + std::shared_ptr sql_client; +}; + +FlightInfo MultipleEndpointsFlightInfo(Location location1, Location location2) { + // Sever will generate random data for `ticket-ints-1` + FlightEndpoint endpoint1({Ticket{"ticket-ints-1"}, {location1}, std::nullopt, {}}); + FlightEndpoint endpoint2({Ticket{"ticket-ints-1"}, {location2}, std::nullopt, {}}); + + FlightDescriptor descr1{FlightDescriptor::PATH, "", {"examples", "ints"}}; + + auto schema1 = arrow::flight::ExampleIntSchema(); + + return arrow::flight::MakeFlightInfo(*schema1, descr1, {endpoint1, endpoint2}, 1000, + 100000, false, ""); +} + +void VerifyArraysContainIntsOnly(std::shared_ptr intArray) { + for (int64_t i = 0; i < intArray->length(); ++i) { + // null values are accepted + if (!intArray->IsNull(i)) { + auto scalar_data = intArray->GetScalar(i).ValueOrDie(); + std::string scalar_str = ConvertToJson(*scalar_data); + ASSERT_TRUE(std::all_of(scalar_str.begin(), scalar_str.end(), ::isdigit)); + } + } +} + +TEST_F(FlightStreamChunkBufferTest, TestMultipleEndpointsInt) { + FlightClientOptions client_options = FlightClientOptions::Defaults(); + FlightCallOptions options; + FlightInfo info = MultipleEndpointsFlightInfo(server_location1, server_location2); + std::shared_ptr info_ptr = std::make_shared(info); + + FlightStreamChunkBuffer chunk_buffer(*sql_client, client_options, options, info_ptr); + + FlightStreamChunk current_chunk; + + // Server returns 5 batch of results from each endpoints. + // Each batch contains 8 columns + int num_chunks = 0; + while (chunk_buffer.GetNext(¤t_chunk)) { + num_chunks++; + + int num_cols = current_chunk.data->num_columns(); + ASSERT_EQ(8, num_cols); + + for (int i = 0; i < num_cols; i++) { + auto array = current_chunk.data->column(i); + // Each array has random length + ASSERT_GT(array->length(), 0); + + VerifyArraysContainIntsOnly(array); + } + } + + // Verify 5 batches of data is returned by each of the two endpoints. + // In total 10 batches should be returned. + ASSERT_EQ(10, num_chunks); +} +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc index bf2f6b6eca2..7f6ba8042de 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc @@ -199,10 +199,14 @@ inline void SetDefaultIfMissing(std::unordered_map& } // namespace -GetInfoCache::GetInfoCache(FlightCallOptions& call_options, +GetInfoCache::GetInfoCache(FlightClientOptions& client_options, + FlightCallOptions& call_options, std::unique_ptr& client, const std::string& driver_version) - : call_options_(call_options), sql_client_(client), has_server_info_(false) { + : client_options_(client_options), + call_options_(call_options), + sql_client_(client), + has_server_info_(false) { info_[SQL_DRIVER_NAME] = "Arrow Flight ODBC Driver"; info_[SQL_DRIVER_VER] = util::ConvertToDBMSVer(driver_version); @@ -283,7 +287,8 @@ bool GetInfoCache::LoadInfoFromServer() { arrow::Result> result = sql_client_->GetSqlInfo(call_options_, {}); util::ThrowIfNotOK(result.status()); - FlightStreamChunkBuffer chunk_iter(*sql_client_, call_options_, result.ValueOrDie()); + FlightStreamChunkBuffer chunk_iter(*sql_client_, client_options_, call_options_, + result.ValueOrDie()); FlightStreamChunk chunk; bool supports_correlation_name = false; @@ -311,8 +316,8 @@ bool GetInfoCache::LoadInfoFromServer() { std::string server_name( reinterpret_cast(scalar->child_value().get())->view()); - // TODO: Consider creating different properties in GetSqlInfo. - // TODO: Investigate if SQL_SERVER_NAME should just be the host + // GH-47855 TODO: Consider creating different properties in GetSqlInfo. + // GH-47856 TODO: Investigate if SQL_SERVER_NAME should just be the host // address as well. In JDBC, FLIGHT_SQL_SERVER_NAME is only used for // the DatabaseProductName. info_[SQL_SERVER_NAME] = server_name; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h index d0e0efd159f..a1452e4b466 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h @@ -30,13 +30,15 @@ namespace arrow::flight::sql::odbc { class GetInfoCache { private: std::unordered_map info_; + FlightClientOptions& client_options_; FlightCallOptions& call_options_; std::unique_ptr& sql_client_; std::mutex mutex_; std::atomic has_server_info_; public: - GetInfoCache(FlightCallOptions& call_options, std::unique_ptr& client, + GetInfoCache(FlightClientOptions& client_options, FlightCallOptions& call_options, + std::unique_ptr& client, const std::string& driver_version); void SetProperty(uint16_t property, Connection::Info value); Connection::Info GetInfo(uint16_t info_type);