Skip to content

Commit 64e84f5

Browse files
committed
Extract implementation of gh-46574
1 parent 08b3cc9 commit 64e84f5

14 files changed

+295
-98
lines changed

cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,11 @@ add_arrow_test(odbc_spi_impl_test
162162
accessors/time_array_accessor_test.cc
163163
accessors/timestamp_array_accessor_test.cc
164164
flight_sql_connection_test.cc
165+
flight_sql_stream_chunk_buffer_test.cc
165166
parse_table_types_test.cc
166167
json_converter_test.cc
167168
record_batch_transformer_test.cc
168169
util_test.cc
169170
EXTRA_LINK_LIBS
170-
arrow_odbc_spi_impl)
171+
arrow_odbc_spi_impl
172+
arrow_flight_testing_shared)

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ class NoOpClientAuthHandler : public ClientAuthHandler {
4444
NoOpClientAuthHandler() {}
4545

4646
Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override {
47-
// Write a blank string. The server should ignore this and just accept any Handshake
48-
// request.
47+
// The server should ignore this and just accept any Handshake
48+
// request. Some servers do not allow authentication with no handshakes.
4949
return outgoing->Write(std::string());
5050
}
5151

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ inline std::string GetCerts() { return ""; }
100100
#endif
101101

102102
const std::set<std::string_view, CaseInsensitiveComparator> BUILT_IN_PROPERTIES = {
103+
FlightSqlConnection::DRIVER,
104+
FlightSqlConnection::DSN,
103105
FlightSqlConnection::HOST,
104106
FlightSqlConnection::PORT,
105107
FlightSqlConnection::USER,
@@ -153,14 +155,14 @@ void FlightSqlConnection::Connect(const ConnPropertyMap& properties,
153155
auto flight_ssl_configs = LoadFlightSslConfigs(properties);
154156

155157
Location location = BuildLocation(properties, missing_attr, flight_ssl_configs);
156-
FlightClientOptions client_options =
158+
client_options_ =
157159
BuildFlightClientOptions(properties, missing_attr, flight_ssl_configs);
158160

159161
const std::shared_ptr<ClientMiddlewareFactory>& cookie_factory = GetCookieFactory();
160-
client_options.middleware.push_back(cookie_factory);
162+
client_options_.middleware.push_back(cookie_factory);
161163

162164
std::unique_ptr<FlightClient> flight_client;
163-
ThrowIfNotOK(FlightClient::Connect(location, client_options).Value(&flight_client));
165+
ThrowIfNotOK(FlightClient::Connect(location, client_options_).Value(&flight_client));
164166

165167
std::unique_ptr<FlightSqlAuthMethod> auth_method =
166168
FlightSqlAuthMethod::FromProperties(flight_client, properties);
@@ -364,7 +366,7 @@ void FlightSqlConnection::Close() {
364366

365367
std::shared_ptr<Statement> FlightSqlConnection::CreateStatement() {
366368
return std::shared_ptr<Statement>(new FlightSqlStatement(
367-
diagnostics_, *sql_client_, call_options_, metadata_settings_));
369+
diagnostics_, *sql_client_, client_options_, call_options_, metadata_settings_));
368370
}
369371

370372
bool FlightSqlConnection::SetAttribute(Connection::AttributeId attribute,
@@ -410,7 +412,7 @@ FlightSqlConnection::FlightSqlConnection(OdbcVersion odbc_version,
410412
const std::string& driver_version)
411413
: diagnostics_("Apache Arrow", "Flight SQL", odbc_version),
412414
odbc_version_(odbc_version),
413-
info_(call_options_, sql_client_, driver_version),
415+
info_(client_options_, call_options_, sql_client_, driver_version),
414416
closed_(true) {
415417
attribute_[CONNECTION_DEAD] = static_cast<uint32_t>(SQL_TRUE);
416418
attribute_[LOGIN_TIMEOUT] = static_cast<uint32_t>(0);

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@
2929
namespace arrow::flight::sql::odbc {
3030

3131
FlightSqlResultSet::FlightSqlResultSet(
32-
FlightSqlClient& flight_sql_client, const FlightCallOptions& call_options,
33-
const std::shared_ptr<FlightInfo>& flight_info,
32+
FlightSqlClient& flight_sql_client, const FlightClientOptions& client_options,
33+
const FlightCallOptions& call_options, const std::shared_ptr<FlightInfo>& flight_info,
3434
const std::shared_ptr<RecordBatchTransformer>& transformer, Diagnostics& diagnostics,
3535
const MetadataSettings& metadata_settings)
3636
: metadata_settings_(metadata_settings),
37-
chunk_buffer_(flight_sql_client, call_options, flight_info,
37+
chunk_buffer_(flight_sql_client, client_options, call_options, flight_info,
3838
metadata_settings_.chunk_buffer_capacity),
3939
transformer_(transformer),
4040
metadata_(transformer

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class FlightSqlResultSet : public ResultSet {
5151
~FlightSqlResultSet() override;
5252

5353
FlightSqlResultSet(FlightSqlClient& flight_sql_client,
54+
const FlightClientOptions& client_options,
5455
const FlightCallOptions& call_options,
5556
const std::shared_ptr<FlightInfo>& flight_info,
5657
const std::shared_ptr<RecordBatchTransformer>& transformer,

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ using util::ThrowIfNotOK;
4141

4242
namespace {
4343

44-
void ClosePreparedStatementIfAny(std::shared_ptr<PreparedStatement>& prepared_statement) {
44+
void ClosePreparedStatementIfAny(std::shared_ptr<PreparedStatement>& prepared_statement,
45+
const FlightCallOptions& options) {
4546
if (prepared_statement != nullptr) {
46-
ThrowIfNotOK(prepared_statement->Close());
47+
ThrowIfNotOK(prepared_statement->Close(options));
4748
prepared_statement.reset();
4849
}
4950
}
@@ -52,11 +53,13 @@ void ClosePreparedStatementIfAny(std::shared_ptr<PreparedStatement>& prepared_st
5253

5354
FlightSqlStatement::FlightSqlStatement(const Diagnostics& diagnostics,
5455
FlightSqlClient& sql_client,
56+
FlightClientOptions client_options,
5557
FlightCallOptions call_options,
5658
const MetadataSettings& metadata_settings)
5759
: diagnostics_("Apache Arrow", diagnostics.GetDataSourceComponent(),
5860
diagnostics.GetOdbcVersion()),
5961
sql_client_(sql_client),
62+
client_options_(std::move(client_options)),
6063
call_options_(std::move(call_options)),
6164
metadata_settings_(metadata_settings) {
6265
attribute_[METADATA_ID] = static_cast<size_t>(SQL_FALSE);
@@ -97,7 +100,7 @@ boost::optional<Statement::Attribute> FlightSqlStatement::GetAttribute(
97100

98101
boost::optional<std::shared_ptr<ResultSetMetadata>> FlightSqlStatement::Prepare(
99102
const std::string& query) {
100-
ClosePreparedStatementIfAny(prepared_statement_);
103+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
101104

102105
Result<std::shared_ptr<PreparedStatement>> result =
103106
sql_client_.Prepare(call_options_, query);
@@ -111,27 +114,30 @@ boost::optional<std::shared_ptr<ResultSetMetadata>> FlightSqlStatement::Prepare(
111114
}
112115

113116
bool FlightSqlStatement::ExecutePrepared() {
117+
// GH-47990 TODO: use DCHECK instead of assert
114118
assert(prepared_statement_.get() != nullptr);
115119

116-
Result<std::shared_ptr<FlightInfo>> result = prepared_statement_->Execute();
120+
Result<std::shared_ptr<FlightInfo>> result =
121+
prepared_statement_->Execute(call_options_);
122+
117123
ThrowIfNotOK(result.status());
118124

119125
current_result_set_ = std::make_shared<FlightSqlResultSet>(
120-
sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_,
121-
metadata_settings_);
126+
sql_client_, client_options_, call_options_, result.ValueOrDie(), nullptr,
127+
diagnostics_, metadata_settings_);
122128

123129
return true;
124130
}
125131

126132
bool FlightSqlStatement::Execute(const std::string& query) {
127-
ClosePreparedStatementIfAny(prepared_statement_);
133+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
128134

129135
Result<std::shared_ptr<FlightInfo>> result = sql_client_.Execute(call_options_, query);
130136
ThrowIfNotOK(result.status());
131137

132138
current_result_set_ = std::make_shared<FlightSqlResultSet>(
133-
sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_,
134-
metadata_settings_);
139+
sql_client_, client_options_, call_options_, result.ValueOrDie(), nullptr,
140+
diagnostics_, metadata_settings_);
135141

136142
return true;
137143
}
@@ -146,33 +152,35 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetTables(
146152
const std::string* catalog_name, const std::string* schema_name,
147153
const std::string* table_name, const std::string* table_type,
148154
const ColumnNames& column_names) {
149-
ClosePreparedStatementIfAny(prepared_statement_);
155+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
150156

151157
std::vector<std::string> table_types;
152158

153159
if ((catalog_name && *catalog_name == "%") && (schema_name && schema_name->empty()) &&
154160
(table_name && table_name->empty())) {
155-
current_result_set_ = GetTablesForSQLAllCatalogs(
156-
column_names, call_options_, sql_client_, diagnostics_, metadata_settings_);
161+
current_result_set_ =
162+
GetTablesForSQLAllCatalogs(column_names, client_options_, call_options_,
163+
sql_client_, diagnostics_, metadata_settings_);
157164
} else if ((catalog_name && catalog_name->empty()) &&
158165
(schema_name && *schema_name == "%") &&
159166
(table_name && table_name->empty())) {
160-
current_result_set_ =
161-
GetTablesForSQLAllDbSchemas(column_names, call_options_, sql_client_, schema_name,
162-
diagnostics_, metadata_settings_);
167+
current_result_set_ = GetTablesForSQLAllDbSchemas(
168+
column_names, client_options_, call_options_, sql_client_, schema_name,
169+
diagnostics_, metadata_settings_);
163170
} else if ((catalog_name && catalog_name->empty()) &&
164171
(schema_name && schema_name->empty()) &&
165172
(table_name && table_name->empty()) && (table_type && *table_type == "%")) {
166-
current_result_set_ = GetTablesForSQLAllTableTypes(
167-
column_names, call_options_, sql_client_, diagnostics_, metadata_settings_);
173+
current_result_set_ =
174+
GetTablesForSQLAllTableTypes(column_names, client_options_, call_options_,
175+
sql_client_, diagnostics_, metadata_settings_);
168176
} else {
169177
if (table_type) {
170178
ParseTableTypes(*table_type, table_types);
171179
}
172180

173181
current_result_set_ = GetTablesForGenericUse(
174-
column_names, call_options_, sql_client_, catalog_name, schema_name, table_name,
175-
table_types, diagnostics_, metadata_settings_);
182+
column_names, client_options_, call_options_, sql_client_, catalog_name,
183+
schema_name, table_name, table_types, diagnostics_, metadata_settings_);
176184
}
177185

178186
return current_result_set_;
@@ -199,7 +207,7 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetTables_V3(
199207
std::shared_ptr<ResultSet> FlightSqlStatement::GetColumns_V2(
200208
const std::string* catalog_name, const std::string* schema_name,
201209
const std::string* table_name, const std::string* column_name) {
202-
ClosePreparedStatementIfAny(prepared_statement_);
210+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
203211

204212
Result<std::shared_ptr<FlightInfo>> result = sql_client_.GetTables(
205213
call_options_, catalog_name, schema_name, table_name, true, nullptr);
@@ -210,17 +218,17 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetColumns_V2(
210218
auto transformer = std::make_shared<GetColumns_Transformer>(
211219
metadata_settings_, OdbcVersion::V_2, column_name);
212220

213-
current_result_set_ =
214-
std::make_shared<FlightSqlResultSet>(sql_client_, call_options_, flight_info,
215-
transformer, diagnostics_, metadata_settings_);
221+
current_result_set_ = std::make_shared<FlightSqlResultSet>(
222+
sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_,
223+
metadata_settings_);
216224

217225
return current_result_set_;
218226
}
219227

220228
std::shared_ptr<ResultSet> FlightSqlStatement::GetColumns_V3(
221229
const std::string* catalog_name, const std::string* schema_name,
222230
const std::string* table_name, const std::string* column_name) {
223-
ClosePreparedStatementIfAny(prepared_statement_);
231+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
224232

225233
Result<std::shared_ptr<FlightInfo>> result = sql_client_.GetTables(
226234
call_options_, catalog_name, schema_name, table_name, true, nullptr);
@@ -231,15 +239,15 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetColumns_V3(
231239
auto transformer = std::make_shared<GetColumns_Transformer>(
232240
metadata_settings_, OdbcVersion::V_3, column_name);
233241

234-
current_result_set_ =
235-
std::make_shared<FlightSqlResultSet>(sql_client_, call_options_, flight_info,
236-
transformer, diagnostics_, metadata_settings_);
242+
current_result_set_ = std::make_shared<FlightSqlResultSet>(
243+
sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_,
244+
metadata_settings_);
237245

238246
return current_result_set_;
239247
}
240248

241249
std::shared_ptr<ResultSet> FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) {
242-
ClosePreparedStatementIfAny(prepared_statement_);
250+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
243251

244252
Result<std::shared_ptr<FlightInfo>> result = sql_client_.GetXdbcTypeInfo(call_options_);
245253
ThrowIfNotOK(result.status());
@@ -249,15 +257,15 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetTypeInfo_V2(int16_t data_type)
249257
auto transformer = std::make_shared<GetTypeInfoTransformer>(
250258
metadata_settings_, OdbcVersion::V_2, data_type);
251259

252-
current_result_set_ =
253-
std::make_shared<FlightSqlResultSet>(sql_client_, call_options_, flight_info,
254-
transformer, diagnostics_, metadata_settings_);
260+
current_result_set_ = std::make_shared<FlightSqlResultSet>(
261+
sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_,
262+
metadata_settings_);
255263

256264
return current_result_set_;
257265
}
258266

259267
std::shared_ptr<ResultSet> FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) {
260-
ClosePreparedStatementIfAny(prepared_statement_);
268+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
261269

262270
Result<std::shared_ptr<FlightInfo>> result = sql_client_.GetXdbcTypeInfo(call_options_);
263271
ThrowIfNotOK(result.status());
@@ -267,9 +275,9 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetTypeInfo_V3(int16_t data_type)
267275
auto transformer = std::make_shared<GetTypeInfoTransformer>(
268276
metadata_settings_, OdbcVersion::V_3, data_type);
269277

270-
current_result_set_ =
271-
std::make_shared<FlightSqlResultSet>(sql_client_, call_options_, flight_info,
272-
transformer, diagnostics_, metadata_settings_);
278+
current_result_set_ = std::make_shared<FlightSqlResultSet>(
279+
sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_,
280+
metadata_settings_);
273281

274282
return current_result_set_;
275283
}

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class FlightSqlStatement : public Statement {
3232
private:
3333
Diagnostics diagnostics_;
3434
std::map<StatementAttributeId, Attribute> attribute_;
35+
FlightClientOptions client_options_;
3536
FlightCallOptions call_options_;
3637
FlightSqlClient& sql_client_;
3738
std::shared_ptr<ResultSet> current_result_set_;
@@ -46,7 +47,7 @@ class FlightSqlStatement : public Statement {
4647

4748
public:
4849
FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client,
49-
FlightCallOptions call_options,
50+
FlightClientOptions client_options, FlightCallOptions call_options,
5051
const MetadataSettings& metadata_settings);
5152

5253
bool SetAttribute(StatementAttributeId attribute, const Attribute& value) override;

0 commit comments

Comments
 (0)