Skip to content

Commit 862be07

Browse files
authored
[RSDK-11569] fix race when creating session (#1001)
1 parent 78c0141 commit 862be07

File tree

2 files changed

+44
-31
lines changed

2 files changed

+44
-31
lines changed

src/viam/sessions_client.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from enum import IntEnum
88
from threading import Lock, Thread
99
from typing import MutableMapping, Optional
10+
from concurrent.futures import ThreadPoolExecutor
11+
from contextlib import asynccontextmanager
1012

1113
from grpclib import Status
1214
from grpclib.client import Channel
@@ -46,6 +48,7 @@ class SessionsClient:
4648
_heartbeat_interval: Optional[timedelta]
4749
_supported: _SupportedState
4850
_thread: Optional[Thread]
51+
_pool: ThreadPoolExecutor
4952

5053
_HEARTBEAT_MONITORED_METHODS: MutableMapping[str, bool] = {}
5154

@@ -71,6 +74,7 @@ def __init__(
7174
self._heartbeat_interval = None
7275
self._supported = _SupportedState.UNKNOWN
7376
self._thread = None
77+
self._pool = ThreadPoolExecutor()
7478

7579
listen(self.channel, SendRequest, self._send_request)
7680
listen(self.channel, RecvTrailingMetadata, self._recv_trailers)
@@ -104,40 +108,45 @@ async def _recv_trailers(self, event: RecvTrailingMetadata):
104108
if event.status == Status.INVALID_ARGUMENT and event.status_message == "SESSION_EXPIRED":
105109
LOGGER.debug("Session expired")
106110
self.reset()
111+
@asynccontextmanager
async def _acquire_lock_async(self):
    """Hold ``self._lock`` for the body of an ``async with`` block.

    The blocking ``self._lock.acquire`` call is submitted to ``self._pool``
    so the event loop is not blocked while waiting for the lock. The lock is
    released when the ``async with`` body exits, whether normally or by
    raising.

    If the awaiting task is cancelled before the lock has been handed over,
    the worker thread may still obtain the lock afterwards. In that case the
    lock is released as soon as the worker finishes, so it is never left
    held without an owner.

    NOTE(review): assumes ``self._lock`` is a plain ``threading.Lock`` (it
    is released from a different thread than the one that acquired it) --
    confirm against ``__init__``.

    Raises:
        asyncio.CancelledError: re-raised if the caller is cancelled while
            waiting for the lock.
    """
    loop = asyncio.get_running_loop()
    pending = self._pool.submit(self._lock.acquire)

    def _release_when_acquired(done):
        # Runs once the orphaned worker finishes; only release if the
        # acquire actually went through.
        if not done.cancelled() and done.exception() is None:
            self._lock.release()

    try:
        await asyncio.wrap_future(pending, loop=loop)
    except asyncio.CancelledError:
        # cancel() only succeeds if the worker has not started yet. When it
        # fails, the worker is blocked in (or has completed) acquire(), so
        # hand the release over to a completion callback.
        if not pending.cancel():
            pending.add_done_callback(_release_when_acquired)
        raise

    try:
        yield
    finally:
        self._lock.release()
107119

108120
@property
109121
async def metadata(self) -> _MetadataLike:
110-
with self._lock:
122+
async with self._acquire_lock_async():
111123
if self._disabled or self._supported != _SupportedState.UNKNOWN:
112124
return self._metadata
113125

114-
request = StartSessionRequest(resume=self._current_id)
115-
try:
116-
response: StartSessionResponse = await self.client.StartSession(request)
117-
except GRPCError as error:
118-
if error.status == Status.UNIMPLEMENTED:
119-
with self._lock:
120-
self._reset()
121-
self._supported = _SupportedState.FALSE
122-
return self._metadata
123-
else:
124-
raise
125-
126-
if response is None:
127-
raise GRPCError(status=Status.INTERNAL, message="Expected response to start session")
128-
129-
if response.heartbeat_window is None:
130-
raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session")
126+
request = StartSessionRequest(resume=self._current_id)
127+
try:
128+
response: StartSessionResponse = await self.client.StartSession(request)
129+
except GRPCError as error:
130+
if error.status == Status.UNIMPLEMENTED:
131+
self._reset()
132+
self._supported = _SupportedState.FALSE
133+
return self._metadata
134+
else:
135+
raise
136+
137+
if response is None:
138+
raise GRPCError(status=Status.INTERNAL, message="Expected response to start session")
139+
140+
if response.heartbeat_window is None:
141+
raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session")
131142

132-
with self._lock:
133143
self._supported = _SupportedState.TRUE
134144
self._heartbeat_interval = response.heartbeat_window.ToTimedelta()
135145
self._current_id = response.id
136146

137-
# tick once to ensure heartbeats are supported
138-
await self._heartbeat_tick(self.client)
147+
# tick once to ensure heartbeats are supported
148+
await self._heartbeat_tick(self.client)
139149

140-
with self._lock:
141150
if self._thread is not None:
142151
self._reset()
143152
if self._supported == _SupportedState.TRUE:
@@ -156,17 +165,16 @@ async def metadata(self) -> _MetadataLike:
156165
return self._metadata
157166

158167
async def _heartbeat_tick(self, client: RobotServiceStub):
159-
with self._lock:
160-
if not self._current_id:
161-
LOGGER.debug("Failed to send heartbeat, session client reset")
162-
return
163-
request = SendSessionHeartbeatRequest(id=self._current_id)
168+
if not self._current_id:
169+
LOGGER.debug("Failed to send heartbeat, session client reset")
170+
return
171+
request = SendSessionHeartbeatRequest(id=self._current_id)
164172

165173
try:
166174
await client.SendSessionHeartbeat(request)
167175
except (GRPCError, StreamTerminatedError):
168176
LOGGER.debug("Heartbeat terminated", exc_info=True)
169-
self.reset()
177+
self._reset()
170178
else:
171179
LOGGER.debug("Sent heartbeat successfully")
172180

@@ -185,10 +193,10 @@ async def _heartbeat_process(self, wait: float):
185193
channel = await dial(address=addr, options=self._dial_options)
186194
client = RobotServiceStub(channel.channel)
187195
while True:
188-
with self._lock:
196+
async with self._acquire_lock_async():
189197
if self._supported != _SupportedState.TRUE:
190198
return
191-
await self._heartbeat_tick(client)
199+
await self._heartbeat_tick(client)
192200
await asyncio.sleep(wait)
193201

194202
@property

tests/test_sessions_client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,12 @@ async def test_sessions_heartbeat_thread_blocked():
106106
channel = await dial(address=addr, options=options)
107107

108108
client = SessionsClient(channel.channel, addr, options)
109-
assert await client.metadata == {SESSION_METADATA_KEY: MockRobot.SESSION_ID}
109+
t1 = asyncio.create_task(client.metadata)
110+
t2 = asyncio.create_task(client.metadata)
111+
112+
await asyncio.gather(t1,t2)
113+
assert t1.result() == {SESSION_METADATA_KEY: MockRobot.SESSION_ID}
114+
assert t2.result() == {SESSION_METADATA_KEY: MockRobot.SESSION_ID}
110115

111116
assert client._supported == _SupportedState.TRUE
112117
assert client._heartbeat_interval and client._heartbeat_interval.total_seconds() == MockRobot.HEARTBEAT_INTERVAL

0 commit comments

Comments
 (0)