Skip to content

Commit e4b7eb5

Browse files
authored
PYTHON-5215 Add an asyncio.Protocol implementation for KMS (#2460)
1 parent 37d327f commit e4b7eb5

File tree

7 files changed

+398
-573
lines changed

7 files changed

+398
-573
lines changed

pymongo/asynchronous/encryption.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
from pymongo.asynchronous.cursor import AsyncCursor
6565
from pymongo.asynchronous.database import AsyncDatabase
6666
from pymongo.asynchronous.mongo_client import AsyncMongoClient
67+
from pymongo.asynchronous.pool import AsyncBaseConnection
6768
from pymongo.common import CONNECT_TIMEOUT
6869
from pymongo.daemon import _spawn_daemon
6970
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
@@ -76,11 +77,11 @@
7677
ServerSelectionTimeoutError,
7778
)
7879
from pymongo.helpers_shared import _get_timeout_details
79-
from pymongo.network_layer import async_socket_sendall
80+
from pymongo.network_layer import PyMongoKMSProtocol, async_receive_kms, async_sendall
8081
from pymongo.operations import UpdateOne
8182
from pymongo.pool_options import PoolOptions
8283
from pymongo.pool_shared import (
83-
_async_configured_socket,
84+
_configured_protocol_interface,
8485
_raise_connection_failure,
8586
)
8687
from pymongo.read_concern import ReadConcern
@@ -93,10 +94,8 @@
9394
if TYPE_CHECKING:
9495
from pymongocrypt.mongocrypt import MongoCryptKmsContext
9596

96-
from pymongo.pyopenssl_context import _sslConn
9797
from pymongo.typings import _Address
9898

99-
10099
_IS_SYNC = False
101100

102101
_HTTPS_PORT = 443
@@ -111,9 +110,10 @@
111110
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
112111

113112

114-
async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
113+
async def _connect_kms(address: _Address, opts: PoolOptions) -> AsyncBaseConnection:
115114
try:
116-
return await _async_configured_socket(address, opts)
115+
interface = await _configured_protocol_interface(address, opts, PyMongoKMSProtocol)
116+
return AsyncBaseConnection(interface, opts)
117117
except Exception as exc:
118118
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
119119

@@ -198,18 +198,11 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
198198
try:
199199
conn = await _connect_kms(address, opts)
200200
try:
201-
await async_socket_sendall(conn, message)
201+
await async_sendall(conn.conn.get_conn, message)
202202
while kms_context.bytes_needed > 0:
203203
# CSOT: update timeout.
204-
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
205-
if _IS_SYNC:
206-
data = conn.recv(kms_context.bytes_needed)
207-
else:
208-
from pymongo.network_layer import ( # type: ignore[attr-defined]
209-
async_receive_data_socket,
210-
)
211-
212-
data = await async_receive_data_socket(conn, kms_context.bytes_needed)
204+
conn.set_conn_timeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
205+
data = await async_receive_kms(conn, kms_context.bytes_needed)
213206
if not data:
214207
raise OSError("KMS connection closed")
215208
kms_context.feed(data)
@@ -228,7 +221,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
228221
address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts)
229222
)
230223
finally:
231-
conn.close()
224+
await conn.close_conn(None)
232225
except MongoCryptError:
233226
raise # Propagate MongoCryptError errors directly.
234227
except Exception as exc:

pymongo/asynchronous/pool.py

Lines changed: 90 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,89 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
123123
_IS_SYNC = False
124124

125125

126-
class AsyncConnection:
126+
class AsyncBaseConnection:
127+
"""A base connection object for server and kms connections."""
128+
129+
def __init__(self, conn: AsyncNetworkingInterface, opts: PoolOptions):
130+
self.conn = conn
131+
self.socket_checker: SocketChecker = SocketChecker()
132+
self.cancel_context: _CancellationContext = _CancellationContext()
133+
self.is_sdam = False
134+
self.closed = False
135+
self.last_timeout: float | None = None
136+
self.more_to_come = False
137+
self.opts = opts
138+
self.max_wire_version = -1
139+
140+
def set_conn_timeout(self, timeout: Optional[float]) -> None:
141+
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
142+
if timeout == self.last_timeout:
143+
return
144+
self.last_timeout = timeout
145+
self.conn.get_conn.settimeout(timeout)
146+
147+
def apply_timeout(
148+
self, client: AsyncMongoClient[Any], cmd: Optional[MutableMapping[str, Any]]
149+
) -> Optional[float]:
150+
# CSOT: use remaining timeout when set.
151+
timeout = _csot.remaining()
152+
if timeout is None:
153+
# Reset the socket timeout unless we're performing a streaming monitor check.
154+
if not self.more_to_come:
155+
self.set_conn_timeout(self.opts.socket_timeout)
156+
return None
157+
# RTT validation.
158+
rtt = _csot.get_rtt()
159+
if rtt is None:
160+
rtt = self.connect_rtt
161+
max_time_ms = timeout - rtt
162+
if max_time_ms < 0:
163+
timeout_details = _get_timeout_details(self.opts)
164+
formatted = format_timeout_details(timeout_details)
165+
# CSOT: raise an error without running the command since we know it will time out.
166+
errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}"
167+
if self.max_wire_version != -1:
168+
raise ExecutionTimeout(
169+
errmsg,
170+
50,
171+
{"ok": 0, "errmsg": errmsg, "code": 50},
172+
self.max_wire_version,
173+
)
174+
else:
175+
raise TimeoutError(errmsg)
176+
if cmd is not None:
177+
cmd["maxTimeMS"] = int(max_time_ms * 1000)
178+
self.set_conn_timeout(timeout)
179+
return timeout
180+
181+
async def close_conn(self, reason: Optional[str]) -> None:
182+
"""Close this connection with a reason."""
183+
if self.closed:
184+
return
185+
await self._close_conn()
186+
187+
async def _close_conn(self) -> None:
188+
"""Close this connection."""
189+
if self.closed:
190+
return
191+
self.closed = True
192+
self.cancel_context.cancel()
193+
# Note: We catch exceptions to avoid spurious errors on interpreter
194+
# shutdown.
195+
try:
196+
await self.conn.close()
197+
except Exception: # noqa: S110
198+
pass
199+
200+
def conn_closed(self) -> bool:
201+
"""Return True if we know socket has been closed, False otherwise."""
202+
if _IS_SYNC:
203+
return self.socket_checker.socket_closed(self.conn.get_conn)
204+
else:
205+
return self.conn.is_closing()
206+
207+
208+
class AsyncConnection(AsyncBaseConnection):
127209
"""Store a connection with some metadata.
128210
129211
:param conn: a raw connection object
@@ -141,29 +223,27 @@ def __init__(
141223
id: int,
142224
is_sdam: bool,
143225
):
226+
super().__init__(conn, pool.opts)
144227
self.pool_ref = weakref.ref(pool)
145-
self.conn = conn
146-
self.address = address
147-
self.id = id
228+
self.address: tuple[str, int] = address
229+
self.id: int = id
148230
self.is_sdam = is_sdam
149-
self.closed = False
150231
self.last_checkin_time = time.monotonic()
151232
self.performed_handshake = False
152233
self.is_writable: bool = False
153234
self.max_wire_version = MAX_WIRE_VERSION
154-
self.max_bson_size = MAX_BSON_SIZE
155-
self.max_message_size = MAX_MESSAGE_SIZE
156-
self.max_write_batch_size = MAX_WRITE_BATCH_SIZE
235+
self.max_bson_size: int = MAX_BSON_SIZE
236+
self.max_message_size: int = MAX_MESSAGE_SIZE
237+
self.max_write_batch_size: int = MAX_WRITE_BATCH_SIZE
157238
self.supports_sessions = False
158239
self.hello_ok: bool = False
159-
self.is_mongos = False
240+
self.is_mongos: bool = False
160241
self.op_msg_enabled = False
161242
self.listeners = pool.opts._event_listeners
162243
self.enabled_for_cmap = pool.enabled_for_cmap
163244
self.enabled_for_logging = pool.enabled_for_logging
164245
self.compression_settings = pool.opts._compression_settings
165246
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
166-
self.socket_checker: SocketChecker = SocketChecker()
167247
self.oidc_token_gen_id: Optional[int] = None
168248
# Support for mechanism negotiation on the initial handshake.
169249
self.negotiated_mechs: Optional[list[str]] = None
@@ -174,9 +254,6 @@ def __init__(
174254
self.pool_gen = pool.gen
175255
self.generation = self.pool_gen.get_overall()
176256
self.ready = False
177-
self.cancel_context: _CancellationContext = _CancellationContext()
178-
self.opts = pool.opts
179-
self.more_to_come: bool = False
180257
# For load balancer support.
181258
self.service_id: Optional[ObjectId] = None
182259
self.server_connection_id: Optional[int] = None
@@ -192,44 +269,6 @@ def __init__(
192269
# For gossiping $clusterTime from the connection handshake to the client.
193270
self._cluster_time = None
194271

195-
def set_conn_timeout(self, timeout: Optional[float]) -> None:
196-
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
197-
if timeout == self.last_timeout:
198-
return
199-
self.last_timeout = timeout
200-
self.conn.get_conn.settimeout(timeout)
201-
202-
def apply_timeout(
203-
self, client: AsyncMongoClient[Any], cmd: Optional[MutableMapping[str, Any]]
204-
) -> Optional[float]:
205-
# CSOT: use remaining timeout when set.
206-
timeout = _csot.remaining()
207-
if timeout is None:
208-
# Reset the socket timeout unless we're performing a streaming monitor check.
209-
if not self.more_to_come:
210-
self.set_conn_timeout(self.opts.socket_timeout)
211-
return None
212-
# RTT validation.
213-
rtt = _csot.get_rtt()
214-
if rtt is None:
215-
rtt = self.connect_rtt
216-
max_time_ms = timeout - rtt
217-
if max_time_ms < 0:
218-
timeout_details = _get_timeout_details(self.opts)
219-
formatted = format_timeout_details(timeout_details)
220-
# CSOT: raise an error without running the command since we know it will time out.
221-
errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}"
222-
raise ExecutionTimeout(
223-
errmsg,
224-
50,
225-
{"ok": 0, "errmsg": errmsg, "code": 50},
226-
self.max_wire_version,
227-
)
228-
if cmd is not None:
229-
cmd["maxTimeMS"] = int(max_time_ms * 1000)
230-
self.set_conn_timeout(timeout)
231-
return timeout
232-
233272
def pin_txn(self) -> None:
234273
self.pinned_txn = True
235274
assert not self.pinned_cursor
@@ -573,26 +612,6 @@ async def close_conn(self, reason: Optional[str]) -> None:
573612
error=reason,
574613
)
575614

576-
async def _close_conn(self) -> None:
577-
"""Close this connection."""
578-
if self.closed:
579-
return
580-
self.closed = True
581-
self.cancel_context.cancel()
582-
# Note: We catch exceptions to avoid spurious errors on interpreter
583-
# shutdown.
584-
try:
585-
await self.conn.close()
586-
except Exception: # noqa: S110
587-
pass
588-
589-
def conn_closed(self) -> bool:
590-
"""Return True if we know socket has been closed, False otherwise."""
591-
if _IS_SYNC:
592-
return self.socket_checker.socket_closed(self.conn.get_conn)
593-
else:
594-
return self.conn.is_closing()
595-
596615
def send_cluster_time(
597616
self,
598617
command: MutableMapping[str, Any],

0 commit comments

Comments
 (0)