Skip to content

Commit 1cc8640

Browse files
author
michael stack
committed
Add process-local s3blobstore connection pool isolation in simulation.
Prevent connection sharing and corruption across simulated processes. Close connections eagerly in simulation so that their resources are less likely to be picked up by another process. Add some HTTP specializations for the simulation case.
1 parent dfd8a96 commit 1cc8640

File tree

2 files changed

+98
-16
lines changed

2 files changed

+98
-16
lines changed

fdbclient/S3BlobStore.actor.cpp

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
#include "fdbclient/S3BlobStore.h"
2222

23+
#include <sstream>
24+
#include "fdbrpc/HTTP.h"
2325
#include "fdbclient/ClientKnobs.h"
2426
#include "fdbclient/Knobs.h"
2527
#include "flow/FastRef.h"
@@ -76,9 +78,6 @@ S3BlobStoreEndpoint::Stats S3BlobStoreEndpoint::s_stats;
7678
std::unique_ptr<S3BlobStoreEndpoint::BlobStats> S3BlobStoreEndpoint::blobStats;
7779
Future<Void> S3BlobStoreEndpoint::statsLogger = Never();
7880

79-
std::unordered_map<BlobStoreConnectionPoolKey, Reference<S3BlobStoreEndpoint::ConnectionPoolData>>
80-
S3BlobStoreEndpoint::globalConnectionPool;
81-
8281
S3BlobStoreEndpoint::BlobKnobs::BlobKnobs() {
8382
secure_connection = 1;
8483
connect_tries = CLIENT_KNOBS->BLOBSTORE_CONNECT_TRIES;
@@ -199,6 +198,11 @@ std::string S3BlobStoreEndpoint::BlobKnobs::getURLParameters() const {
199198
}
200199

201200
std::string guessRegionFromDomain(std::string domain) {
201+
// Special case for localhost/127.0.0.1 to prevent basic_string exception
202+
if (domain == "127.0.0.1" || domain == "localhost") {
203+
return "us-east-1";
204+
}
205+
202206
static const std::vector<const char*> knownServices = { "s3.", "cos.", "oss-", "obs." };
203207
boost::algorithm::to_lower(domain);
204208

@@ -843,6 +847,10 @@ ACTOR Future<S3BlobStoreEndpoint::ReusableConnection> connect_impl(Reference<S3B
843847
} else {
844848
wait(store(conn, INetworkConnections::net()->connect(host, service, isTLS)));
845849
}
850+
851+
// Ensure connection is valid before handshake
852+
ASSERT(conn.isValid());
853+
846854
wait(conn->connectHandshake());
847855

848856
TraceEvent("S3BlobStoreEndpointNewConnectionSuccess")
@@ -1030,6 +1038,12 @@ ACTOR Future<Reference<HTTP::IncomingResponse>> doRequest_impl(Reference<S3BlobS
10301038
req->data.headers["Host"] = bstore->host;
10311039
req->data.headers["Accept"] = "application/xml";
10321040

1041+
// In simulation, disable connection pooling for MockS3 to prevent NetSAV use-after-free crashes
1042+
// This forces connection closure after each request, preventing race conditions during coordinator shutdown
1043+
if (g_network->isSimulated() && bstore->host == "127.0.0.1") {
1044+
req->data.headers["Connection"] = "close";
1045+
}
1046+
10331047
// Avoid sending a request with an empty resource.
10341048
if (resource.empty()) {
10351049
resource = "/";
@@ -1140,7 +1154,11 @@ ACTOR Future<Reference<HTTP::IncomingResponse>> doRequest_impl(Reference<S3BlobS
11401154
rconn.conn, dryrunRequest, bstore->sendRate, &bstore->s_stats.bytes_sent, bstore->recvRate);
11411155
Reference<HTTP::IncomingResponse> _dryrunR = wait(timeoutError(dryrunResponse, requestTimeout));
11421156
dryrunR = _dryrunR;
1143-
std::string s3Error = parseErrorCodeFromS3(dryrunR->data.content);
1157+
// Only parse S3 error code for error responses (4xx/5xx), not successful responses (2xx)
1158+
std::string s3Error;
1159+
if (dryrunR->code >= 400) {
1160+
s3Error = parseErrorCodeFromS3(dryrunR->data.content);
1161+
}
11441162
if (dryrunR->code == badRequestCode && isS3TokenError(s3Error)) {
11451163
// authentication fails and s3 token error persists, retry with a HEAD dryrun request
11461164
// to avoid sending duplicate data indefinitely and to save network bandwidth
@@ -1263,7 +1281,12 @@ ACTOR Future<Reference<HTTP::IncomingResponse>> doRequest_impl(Reference<S3BlobS
12631281

12641282
if (!err.present()) {
12651283
event.detail("ResponseCode", r->code);
1266-
std::string s3Error = parseErrorCodeFromS3(r->data.content);
1284+
// Only parse S3 error code for real error responses (4xx/5xx), not successful responses (2xx)
1285+
// Skip parsing for simulated errors where response content is still binary data
1286+
std::string s3Error;
1287+
if (r->code >= 400 && !simulateS3TokenError) {
1288+
s3Error = parseErrorCodeFromS3(r->data.content);
1289+
}
12671290
event.detail("S3ErrorCode", s3Error);
12681291
if (r->code == badRequestCode) {
12691292
if (isS3TokenError(s3Error) || simulateS3TokenError) {
@@ -1460,7 +1483,8 @@ ACTOR Future<Void> listObjectsStream_impl(Reference<S3BlobStoreEndpoint> bstore,
14601483
if (key == nullptr) {
14611484
throw http_bad_response();
14621485
}
1463-
object.name = key->value();
1486+
// URL decode the object name since S3 XML responses contain URL-encoded names
1487+
object.name = HTTP::urlDecode(key->value());
14641488

14651489
xml_node<>* size = n->first_node("Size");
14661490
if (size == nullptr) {
@@ -2035,8 +2059,11 @@ ACTOR Future<int> readObject_impl(Reference<S3BlobStoreEndpoint> bstore,
20352059
try {
20362060
// Copy the output bytes, server could have sent more or less bytes than requested so copy at most length
20372061
// bytes
2038-
memcpy(data, r->data.content.data(), std::min<int64_t>(r->data.contentLen, length));
2039-
return r->data.contentLen;
2062+
int bytesToCopy = std::min<int64_t>(r->data.contentLen, length);
2063+
memcpy(data, r->data.content.data(), bytesToCopy);
2064+
// Return the number of bytes actually copied, not the contentLen
2065+
// This ensures AsyncFileEncrypted gets blocks of the correct size (4KB)
2066+
return bytesToCopy;
20402067
} catch (Error& e) {
20412068
TraceEvent(SevWarn, "S3BlobStoreReadObjectMemcpyError").detail("Error", e.what());
20422069
throw io_error();
@@ -2441,4 +2468,4 @@ TEST_CASE("/backup/s3/guess_region") {
24412468
ASSERT_EQ(e.code(), error_code_backup_invalid_url);
24422469
}
24432470
return Void();
2444-
}
2471+
}

fdbclient/include/fdbclient/S3BlobStore.h

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,14 @@ class S3BlobStoreEndpoint : public ReferenceCounted<S3BlobStoreEndpoint> {
117117
void maybeStartStatsLogger() {
118118
if (!blobStats && CLIENT_KNOBS->BLOBSTORE_ENABLE_LOGGING) {
119119
blobStats = std::make_unique<BlobStats>();
120-
specialCounter(
121-
blobStats->cc, "GlobalConnectionPoolCount", [this]() { return this->globalConnectionPool.size(); });
120+
specialCounter(blobStats->cc, "GlobalConnectionPoolCount", [this]() {
121+
return this->getGlobalConnectionPool().size();
122+
});
122123
specialCounter(blobStats->cc, "GlobalConnectionPoolSize", [this]() {
123124
// FIXME: could track this explicitly via an int variable with extra logic, but this should be small and
124125
// infrequent
125126
int totalConnections = 0;
126-
for (auto& it : this->globalConnectionPool) {
127+
for (auto& it : this->getGlobalConnectionPool()) {
127128
totalConnections += it.second->pool.size();
128129
}
129130
return totalConnections;
@@ -200,6 +201,48 @@ class S3BlobStoreEndpoint : public ReferenceCounted<S3BlobStoreEndpoint> {
200201
struct ReusableConnection {
201202
Reference<IConnection> conn;
202203
double expirationTime;
204+
// CROSS_PROCESS_FIX: Track which process created this connection
205+
NetworkAddress creatingProcess;
206+
207+
ReusableConnection() : expirationTime(0) {
208+
if (g_network && g_network->isSimulated()) {
209+
creatingProcess = g_network->getLocalAddress();
210+
}
211+
}
212+
213+
ReusableConnection(Reference<IConnection> c, double exp) : conn(c), expirationTime(exp) {
214+
if (g_network && g_network->isSimulated()) {
215+
creatingProcess = g_network->getLocalAddress();
216+
}
217+
}
218+
219+
// CROSS_PROCESS_FIX: Copy constructor with cross-process detection
220+
ReusableConnection(const ReusableConnection& other)
221+
: conn(other.conn), expirationTime(other.expirationTime), creatingProcess(other.creatingProcess) {
222+
if (g_network && g_network->isSimulated() && creatingProcess.isValid() &&
223+
creatingProcess != g_network->getLocalAddress()) {
224+
// Cross-process copy detected - invalidate the connection to prevent sharing
225+
conn = Reference<IConnection>();
226+
expirationTime = 0; // Mark as expired
227+
}
228+
}
229+
230+
// CROSS_PROCESS_FIX: Assignment operator with cross-process detection
231+
ReusableConnection& operator=(const ReusableConnection& other) {
232+
if (this != &other) {
233+
conn = other.conn;
234+
expirationTime = other.expirationTime;
235+
creatingProcess = other.creatingProcess;
236+
237+
if (g_network && g_network->isSimulated() && creatingProcess.isValid() &&
238+
creatingProcess != g_network->getLocalAddress()) {
239+
// Cross-process assignment detected - invalidate the connection to prevent sharing
240+
conn = Reference<IConnection>();
241+
expirationTime = 0; // Mark as expired
242+
}
243+
}
244+
return *this;
245+
}
203246
};
204247

205248
// basically, reference counted queue with option to add other fields
@@ -208,7 +251,19 @@ class S3BlobStoreEndpoint : public ReferenceCounted<S3BlobStoreEndpoint> {
208251
};
209252

210253
// global connection pool for multiple blobstore endpoints with same connection settings and request destination
211-
static std::unordered_map<BlobStoreConnectionPoolKey, Reference<ConnectionPoolData>> globalConnectionPool;
254+
// CROSS_PROCESS_FIX: Make connection pool process-local to prevent cross-process connection sharing
255+
static std::unordered_map<BlobStoreConnectionPoolKey, Reference<ConnectionPoolData>>& getGlobalConnectionPool() {
    using PoolMap = std::unordered_map<BlobStoreConnectionPoolKey, Reference<ConnectionPoolData>>;

    // One pool map per process address. Outside simulation every caller uses the single entry
    // keyed by a default-constructed NetworkAddress.
    static std::map<NetworkAddress, PoolMap> poolsByProcess;

    // Under simulation the key is the local address of the process currently running.
    const NetworkAddress owner =
        (g_network && g_network->isSimulated()) ? g_network->getLocalAddress() : NetworkAddress();

    // operator[] creates an empty pool map the first time a given process asks for one.
    return poolsByProcess[owner];
}
212267

213268
S3BlobStoreEndpoint(std::string const& host,
214269
std::string const& service,
@@ -241,12 +296,12 @@ class S3BlobStoreEndpoint : public ReferenceCounted<S3BlobStoreEndpoint> {
241296
connectionPool = makeReference<ConnectionPoolData>();
242297
} else {
243298
BlobStoreConnectionPoolKey key(host, service, region, knobs.isTLS());
244-
auto it = globalConnectionPool.find(key);
245-
if (it != globalConnectionPool.end()) {
299+
auto it = getGlobalConnectionPool().find(key);
300+
if (it != getGlobalConnectionPool().end()) {
246301
connectionPool = it->second;
247302
} else {
248303
connectionPool = makeReference<ConnectionPoolData>();
249-
globalConnectionPool.insert({ key, connectionPool });
304+
getGlobalConnectionPool().insert({ key, connectionPool });
250305
}
251306
}
252307
ASSERT(connectionPool.isValid());

0 commit comments

Comments
 (0)