@@ -123,7 +123,89 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
123
123
_IS_SYNC = False
124
124
125
125
126
- class AsyncConnection :
126
+ class AsyncBaseConnection :
127
+ """A base connection object for server and kms connections."""
128
+
129
+ def __init__ (self , conn : AsyncNetworkingInterface , opts : PoolOptions ):
130
+ self .conn = conn
131
+ self .socket_checker : SocketChecker = SocketChecker ()
132
+ self .cancel_context : _CancellationContext = _CancellationContext ()
133
+ self .is_sdam = False
134
+ self .closed = False
135
+ self .last_timeout : float | None = None
136
+ self .more_to_come = False
137
+ self .opts = opts
138
+ self .max_wire_version = - 1
139
+
140
+ def set_conn_timeout (self , timeout : Optional [float ]) -> None :
141
+ """Cache last timeout to avoid duplicate calls to conn.settimeout."""
142
+ if timeout == self .last_timeout :
143
+ return
144
+ self .last_timeout = timeout
145
+ self .conn .get_conn .settimeout (timeout )
146
+
147
+ def apply_timeout (
148
+ self , client : AsyncMongoClient [Any ], cmd : Optional [MutableMapping [str , Any ]]
149
+ ) -> Optional [float ]:
150
+ # CSOT: use remaining timeout when set.
151
+ timeout = _csot .remaining ()
152
+ if timeout is None :
153
+ # Reset the socket timeout unless we're performing a streaming monitor check.
154
+ if not self .more_to_come :
155
+ self .set_conn_timeout (self .opts .socket_timeout )
156
+ return None
157
+ # RTT validation.
158
+ rtt = _csot .get_rtt ()
159
+ if rtt is None :
160
+ rtt = self .connect_rtt
161
+ max_time_ms = timeout - rtt
162
+ if max_time_ms < 0 :
163
+ timeout_details = _get_timeout_details (self .opts )
164
+ formatted = format_timeout_details (timeout_details )
165
+ # CSOT: raise an error without running the command since we know it will time out.
166
+ errmsg = f"operation would exceed time limit, remaining timeout:{ timeout :.5f} <= network round trip time:{ rtt :.5f} { formatted } "
167
+ if self .max_wire_version != - 1 :
168
+ raise ExecutionTimeout (
169
+ errmsg ,
170
+ 50 ,
171
+ {"ok" : 0 , "errmsg" : errmsg , "code" : 50 },
172
+ self .max_wire_version ,
173
+ )
174
+ else :
175
+ raise TimeoutError (errmsg )
176
+ if cmd is not None :
177
+ cmd ["maxTimeMS" ] = int (max_time_ms * 1000 )
178
+ self .set_conn_timeout (timeout )
179
+ return timeout
180
+
181
+ async def close_conn (self , reason : Optional [str ]) -> None :
182
+ """Close this connection with a reason."""
183
+ if self .closed :
184
+ return
185
+ await self ._close_conn ()
186
+
187
+ async def _close_conn (self ) -> None :
188
+ """Close this connection."""
189
+ if self .closed :
190
+ return
191
+ self .closed = True
192
+ self .cancel_context .cancel ()
193
+ # Note: We catch exceptions to avoid spurious errors on interpreter
194
+ # shutdown.
195
+ try :
196
+ await self .conn .close ()
197
+ except Exception : # noqa: S110
198
+ pass
199
+
200
+ def conn_closed (self ) -> bool :
201
+ """Return True if we know socket has been closed, False otherwise."""
202
+ if _IS_SYNC :
203
+ return self .socket_checker .socket_closed (self .conn .get_conn )
204
+ else :
205
+ return self .conn .is_closing ()
206
+
207
+
208
+ class AsyncConnection (AsyncBaseConnection ):
127
209
"""Store a connection with some metadata.
128
210
129
211
:param conn: a raw connection object
@@ -141,29 +223,27 @@ def __init__(
141
223
id : int ,
142
224
is_sdam : bool ,
143
225
):
226
+ super ().__init__ (conn , pool .opts )
144
227
self .pool_ref = weakref .ref (pool )
145
- self .conn = conn
146
- self .address = address
147
- self .id = id
228
+ self .address : tuple [str , int ] = address
229
+ self .id : int = id
148
230
self .is_sdam = is_sdam
149
- self .closed = False
150
231
self .last_checkin_time = time .monotonic ()
151
232
self .performed_handshake = False
152
233
self .is_writable : bool = False
153
234
self .max_wire_version = MAX_WIRE_VERSION
154
- self .max_bson_size = MAX_BSON_SIZE
155
- self .max_message_size = MAX_MESSAGE_SIZE
156
- self .max_write_batch_size = MAX_WRITE_BATCH_SIZE
235
+ self .max_bson_size : int = MAX_BSON_SIZE
236
+ self .max_message_size : int = MAX_MESSAGE_SIZE
237
+ self .max_write_batch_size : int = MAX_WRITE_BATCH_SIZE
157
238
self .supports_sessions = False
158
239
self .hello_ok : bool = False
159
- self .is_mongos = False
240
+ self .is_mongos : bool = False
160
241
self .op_msg_enabled = False
161
242
self .listeners = pool .opts ._event_listeners
162
243
self .enabled_for_cmap = pool .enabled_for_cmap
163
244
self .enabled_for_logging = pool .enabled_for_logging
164
245
self .compression_settings = pool .opts ._compression_settings
165
246
self .compression_context : Union [SnappyContext , ZlibContext , ZstdContext , None ] = None
166
- self .socket_checker : SocketChecker = SocketChecker ()
167
247
self .oidc_token_gen_id : Optional [int ] = None
168
248
# Support for mechanism negotiation on the initial handshake.
169
249
self .negotiated_mechs : Optional [list [str ]] = None
@@ -174,9 +254,6 @@ def __init__(
174
254
self .pool_gen = pool .gen
175
255
self .generation = self .pool_gen .get_overall ()
176
256
self .ready = False
177
- self .cancel_context : _CancellationContext = _CancellationContext ()
178
- self .opts = pool .opts
179
- self .more_to_come : bool = False
180
257
# For load balancer support.
181
258
self .service_id : Optional [ObjectId ] = None
182
259
self .server_connection_id : Optional [int ] = None
@@ -192,44 +269,6 @@ def __init__(
192
269
# For gossiping $clusterTime from the connection handshake to the client.
193
270
self ._cluster_time = None
194
271
195
- def set_conn_timeout (self , timeout : Optional [float ]) -> None :
196
- """Cache last timeout to avoid duplicate calls to conn.settimeout."""
197
- if timeout == self .last_timeout :
198
- return
199
- self .last_timeout = timeout
200
- self .conn .get_conn .settimeout (timeout )
201
-
202
- def apply_timeout (
203
- self , client : AsyncMongoClient [Any ], cmd : Optional [MutableMapping [str , Any ]]
204
- ) -> Optional [float ]:
205
- # CSOT: use remaining timeout when set.
206
- timeout = _csot .remaining ()
207
- if timeout is None :
208
- # Reset the socket timeout unless we're performing a streaming monitor check.
209
- if not self .more_to_come :
210
- self .set_conn_timeout (self .opts .socket_timeout )
211
- return None
212
- # RTT validation.
213
- rtt = _csot .get_rtt ()
214
- if rtt is None :
215
- rtt = self .connect_rtt
216
- max_time_ms = timeout - rtt
217
- if max_time_ms < 0 :
218
- timeout_details = _get_timeout_details (self .opts )
219
- formatted = format_timeout_details (timeout_details )
220
- # CSOT: raise an error without running the command since we know it will time out.
221
- errmsg = f"operation would exceed time limit, remaining timeout:{ timeout :.5f} <= network round trip time:{ rtt :.5f} { formatted } "
222
- raise ExecutionTimeout (
223
- errmsg ,
224
- 50 ,
225
- {"ok" : 0 , "errmsg" : errmsg , "code" : 50 },
226
- self .max_wire_version ,
227
- )
228
- if cmd is not None :
229
- cmd ["maxTimeMS" ] = int (max_time_ms * 1000 )
230
- self .set_conn_timeout (timeout )
231
- return timeout
232
-
233
272
def pin_txn (self ) -> None :
234
273
self .pinned_txn = True
235
274
assert not self .pinned_cursor
@@ -573,26 +612,6 @@ async def close_conn(self, reason: Optional[str]) -> None:
573
612
error = reason ,
574
613
)
575
614
576
- async def _close_conn (self ) -> None :
577
- """Close this connection."""
578
- if self .closed :
579
- return
580
- self .closed = True
581
- self .cancel_context .cancel ()
582
- # Note: We catch exceptions to avoid spurious errors on interpreter
583
- # shutdown.
584
- try :
585
- await self .conn .close ()
586
- except Exception : # noqa: S110
587
- pass
588
-
589
- def conn_closed (self ) -> bool :
590
- """Return True if we know socket has been closed, False otherwise."""
591
- if _IS_SYNC :
592
- return self .socket_checker .socket_closed (self .conn .get_conn )
593
- else :
594
- return self .conn .is_closing ()
595
-
596
615
def send_cluster_time (
597
616
self ,
598
617
command : MutableMapping [str , Any ],
0 commit comments