
Commit ae48ba5

yancanmao (Mao Yancan) authored and edoakes committed
gc collect from a gc_thread (ray-project#55838)
Ray proactively triggers gc.collect() on idle workers to release Python objects that may still hold Plasma shared memory (shm) references. In the current implementation (gc_collect() in _raylet.pyx), Ray periodically calls gc.collect() from Cython under a `with gil` block. If the Python object graph is complex (e.g., cyclic references with finalizers), gc.collect() may take a long time. Because the GIL is held for the entire collection, user code is completely frozen during this period whenever gc.collect() takes longer than the periodic interval (e.g., 10 s).

We propose decoupling GC execution from the RPC call: gc_collect in Cython should not run gc.collect() directly. Instead, it should signal an event with minimal execution time (e.g., using a threading.Event or similar). A dedicated Python GC thread consumes this event and executes gc.collect() asynchronously, with a configurable GC interval.

## Related issue number

Closes ray-project#55837

---------

Signed-off-by: Mao Yancan <[email protected]>
Co-authored-by: Mao Yancan <[email protected]>
Co-authored-by: Edward Oakes <[email protected]>
Signed-off-by: zac <[email protected]>
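In other words, the RPC-side handler is reduced to setting a flag while a daemon thread owns the actual collection. A minimal sketch of that pattern (illustrative names only; the commit's real implementation is the PythonGCThread class in python/ray/_private/gc_collect_manager.py below):

```python
import gc
import threading

gc_requested = threading.Event()

def on_gc_request():
    # Called from the RPC path: only signals, so the GIL is held for roughly
    # the cost of Event.set() instead of a full gc.collect().
    gc_requested.set()

def gc_loop():
    # Runs on a dedicated daemon thread and performs the actual collection.
    while True:
        gc_requested.wait()
        gc_requested.clear()
        gc.collect()

threading.Thread(target=gc_loop, name="gc-sketch", daemon=True).start()
```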
1 parent 4f2f054 commit ae48ba5

File tree: 10 files changed, +240 −9 lines
python/ray/_private/gc_collect_manager.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
import gc
import logging
import threading
import time
from typing import Callable, Optional

logger = logging.getLogger(__name__)


class PythonGCThread(threading.Thread):
    """A background thread that triggers Python garbage collection.

    This thread waits for GC events from CoreWorker and triggers `gc.collect()` when
    requested, ensuring that collections are spaced out by at least
    `min_interval_s` seconds."""

    def __init__(
        self, *, min_interval_s: int = 5, gc_collect_func: Optional[Callable] = None
    ):
        logger.debug("Starting Python GC thread")
        super().__init__(name="PythonGCThread", daemon=True)
        self._should_exit = False
        self._last_gc_time = float("-inf")
        self._min_gc_interval = min_interval_s
        self._gc_event = threading.Event()
        # Set the gc_collect_func for UT, defaulting to gc.collect if None
        self._gc_collect_func = gc_collect_func or gc.collect

    def trigger_gc(self) -> None:
        self._gc_event.set()

    def run(self):
        while not self._should_exit:
            self._gc_event.wait()
            self._gc_event.clear()

            if self._should_exit:
                break

            time_since_last_gc = time.monotonic() - self._last_gc_time
            if time_since_last_gc < self._min_gc_interval:
                logger.debug(
                    f"Skipping GC, only {time_since_last_gc:.2f}s since last GC"
                )
                continue

            try:
                start = time.monotonic()
                num_freed = self._gc_collect_func()
                self._last_gc_time = time.monotonic()
                if num_freed > 0:
                    logger.debug(
                        "gc.collect() freed {} refs in {} seconds".format(
                            num_freed, self._last_gc_time - start
                        )
                    )
            except Exception as e:
                logger.error(f"Error during GC: {e}")
                self._last_gc_time = time.monotonic()

    def stop(self):
        logger.debug("Stopping Python GC thread")
        self._should_exit = True
        self._gc_event.set()
        self.join()
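For reference, a minimal standalone use of the new class, mirroring the unit tests added in python/ray/tests/test_global_gc.py below (the counting function is only a stand-in so the effect is observable; inside Ray the default gc.collect is used):

```python
import gc

from ray._private.gc_collect_manager import PythonGCThread

def counting_collect():
    freed = gc.collect()
    print(f"collected, freed {freed} objects")
    return freed

gc_thread = PythonGCThread(min_interval_s=1, gc_collect_func=counting_collect)
gc_thread.start()
gc_thread.trigger_gc()  # signals the event; the collection runs on the thread
gc_thread.stop()        # sets the exit flag, wakes the thread, and joins it
```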

python/ray/_private/ray_constants.py

Lines changed: 2 additions & 0 deletions
@@ -584,3 +584,5 @@ def gcs_actor_scheduling_enabled():
 FETCH_FAIL_TIMEOUT_SECONDS = (
     env_integer("RAY_fetch_fail_timeout_milliseconds", 60000) / 1000
 )
+
+RAY_GC_MIN_COLLECT_INTERVAL = env_float("RAY_GC_MIN_COLLECT_INTERVAL_S", 5)
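The 5-second default can be overridden through the environment variable the constant reads, for example (assuming the variable is visible to the Ray worker processes, e.g. exported before starting Ray):

```python
import os

# Raise the minimum spacing between background collections to 10 seconds.
os.environ["RAY_GC_MIN_COLLECT_INTERVAL_S"] = "10"
```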

python/ray/_raylet.pxd

Lines changed: 1 addition & 0 deletions
@@ -140,6 +140,7 @@ cdef class CoreWorker:
         object _task_id_to_future_lock
         dict _task_id_to_future
         object event_loop_executor
+        object _gc_thread

     cdef unique_ptr[CAddress] _convert_python_address(self, address=*)
     cdef store_task_output(

python/ray/_raylet.pyx

Lines changed: 27 additions & 8 deletions
@@ -248,6 +248,7 @@ from ray._private.utils import DeferSigint
 from ray._private.object_ref_generator import DynamicObjectRefGenerator
 from ray.util.annotations import PublicAPI
 from ray._private.custom_types import TensorTransportEnum
+from ray._private.gc_collect_manager import PythonGCThread

 # Expose GCC & Clang macro to report
 # whether C++ optimizations were enabled during compilation.
@@ -2496,14 +2497,21 @@ cdef CRayStatus check_signals() nogil:


 cdef void gc_collect(c_bool triggered_by_global_gc) nogil:
-    with gil:
-        start = time.perf_counter()
-        num_freed = gc.collect()
-        end = time.perf_counter()
-        if num_freed > 0:
-            logger.debug(
-                "gc.collect() freed {} refs in {} seconds".format(
-                    num_freed, end - start))
+    with gil:
+        if RayConfig.instance().start_python_gc_manager_thread():
+            start = time.perf_counter()
+            worker = ray._private.worker.global_worker
+            worker.core_worker.trigger_gc()
+            end = time.perf_counter()
+            logger.debug("GC event triggered in {} seconds".format(end - start))
+        else:
+            start = time.perf_counter()
+            num_freed = gc.collect()
+            end = time.perf_counter()
+            if num_freed > 0:
+                logger.debug(
+                    "gc.collect() freed {} refs in {} seconds".format(
+                        num_freed, end - start))


 cdef c_vector[c_string] spill_objects_handler(
@@ -3054,13 +3062,21 @@ cdef class CoreWorker:
         self._task_id_to_future = {}
         self.event_loop_executor = None

+        self._gc_thread = None
+        if RayConfig.instance().start_python_gc_manager_thread():
+            self._gc_thread = PythonGCThread(min_interval_s=ray_constants.RAY_GC_MIN_COLLECT_INTERVAL)
+            self._gc_thread.start()
+
     def shutdown_driver(self):
         # If it's a worker, the core worker process should have been
         # shutdown. So we can't call
         # `CCoreWorkerProcess.GetCoreWorker().GetWorkerType()` here.
         # Instead, we use the cached `is_driver` flag to test if it's a
         # driver.
         assert self.is_driver
+        if self._gc_thread is not None:
+            self._gc_thread.stop()
+            self._gc_thread = None
         with nogil:
             CCoreWorkerProcess.Shutdown()

@@ -4719,6 +4735,9 @@ cdef class CoreWorker:

         return self.current_runtime_env

+    def trigger_gc(self):
+        self._gc_thread.trigger_gc()
+
     def get_pending_children_task_ids(self, parent_task_id: TaskID):
         cdef:
             CTaskID c_parent_task_id = parent_task_id.native()
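Rendered as plain Python, the reworked gc_collect handler above behaves roughly as follows (illustrative only; the real code is Cython, runs under `with gil`, and `use_gc_manager_thread` stands in for RayConfig.instance().start_python_gc_manager_thread()):

```python
import gc
import time

import ray

def gc_collect(triggered_by_global_gc: bool, use_gc_manager_thread: bool = True) -> None:
    if use_gc_manager_thread:
        # Fast path: wake the PythonGCThread and return immediately.
        ray._private.worker.global_worker.core_worker.trigger_gc()
    else:
        # Legacy path: run the (potentially long) collection inline.
        start = time.perf_counter()
        num_freed = gc.collect()
        if num_freed > 0:
            print(f"gc.collect() freed {num_freed} refs in {time.perf_counter() - start:.3f}s")
```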

python/ray/includes/ray_config.pxd

Lines changed: 2 additions & 0 deletions
@@ -86,3 +86,5 @@ cdef extern from "ray/common/ray_config.h" nogil:
         int maximum_gcs_destroyed_actor_cached_count() const

         c_bool record_task_actor_creation_sites() const
+
+        c_bool start_python_gc_manager_thread() const

python/ray/includes/ray_config.pxi

Lines changed: 4 additions & 0 deletions
@@ -140,3 +140,7 @@ cdef class Config:
     @staticmethod
     def maximum_gcs_destroyed_actor_cached_count():
         return RayConfig.instance().maximum_gcs_destroyed_actor_cached_count()
+
+    @staticmethod
+    def start_python_gc_manager_thread():
+        return RayConfig.instance().start_python_gc_manager_thread()

python/ray/tests/test_basic.py

Lines changed: 1 addition & 1 deletion
@@ -267,7 +267,7 @@ def get_thread_count(self):
     ray.get(actor.get_thread_count.remote())
     # Lowering these numbers in this assert should be celebrated,
     # increasing these numbers should be scrutinized
-    assert ray.get(actor.get_thread_count.remote()) in {24, 25}
+    assert ray.get(actor.get_thread_count.remote()) in {24, 25, 26}


 # https://github.com/ray-project/ray/issues/7287

python/ray/tests/test_global_gc.py

Lines changed: 134 additions & 0 deletions
@@ -2,14 +2,17 @@
 import gc
 import logging
 import sys
+import time
 import weakref
+from unittest.mock import Mock

 import numpy as np
 import pytest

 import ray
 import ray.cluster_utils
 from ray._common.test_utils import wait_for_condition
+from ray._private.gc_collect_manager import PythonGCThread
 from ray._private.internal_api import global_gc

 logger = logging.getLogger(__name__)
@@ -216,5 +219,136 @@ def f(self):
     gc.enable()


+def test_local_gc_called_once_per_interval(shutdown_only):
+    ray.init(
+        num_cpus=2,
+        _system_config={
+            "local_gc_interval_s": 1,
+            "local_gc_min_interval_s": 0,
+            "global_gc_min_interval_s": 0,
+        },
+    )
+
+    class ObjectWithCyclicRef:
+        def __init__(self):
+            self.loop = self
+
+    @ray.remote(num_cpus=1)
+    class GarbageHolder:
+        def __init__(self):
+            gc.disable()
+            self.garbage = None
+
+        def make_garbage(self):
+            x = ObjectWithCyclicRef()
+            self.garbage = weakref.ref(x)
+            return True
+
+        def has_garbage(self):
+            return self.garbage() is not None
+
+    def all_garbage_collected(local_ref):
+        return local_ref() is None and not any(
+            ray.get([a.has_garbage.remote() for a in actors])
+        )
+
+    try:
+        gc.disable()
+
+        # Round 1: first batch of garbage should be collected
+        # Local driver.
+        local_ref = weakref.ref(ObjectWithCyclicRef())
+        # Remote workers.
+        actors = [GarbageHolder.remote() for _ in range(2)]
+        ray.get([a.make_garbage.remote() for a in actors])
+
+        assert local_ref() is not None
+        assert all(ray.get([a.has_garbage.remote() for a in actors]))
+
+        wait_for_condition(
+            lambda: all_garbage_collected(local_ref),
+        )
+
+        # Round 2: second batch should NOT be collected within min_interval
+        local_ref = weakref.ref(ObjectWithCyclicRef())
+        ray.get([a.make_garbage.remote() for a in actors])
+
+        with pytest.raises(RuntimeError):
+            wait_for_condition(
+                lambda: all_garbage_collected(local_ref),
+                timeout=2.0,  # shorter than min_interval
+                retry_interval_ms=50,
+            )
+
+        # Round 3: after min_interval passes, garbage should be collected
+        wait_for_condition(
+            lambda: all_garbage_collected(local_ref),
+            timeout=10.0,
+            retry_interval_ms=50,
+        )
+
+    finally:
+        gc.enable()
+
+
+def test_gc_manager_thread_basic_functionality():
+    mock_gc_collect = Mock(return_value=10)
+
+    gc_thread = PythonGCThread(min_interval_s=1, gc_collect_func=mock_gc_collect)
+
+    try:
+        gc_thread.start()
+        assert gc_thread.is_alive()
+
+        gc_thread.trigger_gc()
+
+        wait_for_condition(lambda: mock_gc_collect.call_count == 1, timeout=2)
+
+        mock_gc_collect.assert_called_once()
+
+    finally:
+        gc_thread.stop()
+        assert not gc_thread.is_alive()
+
+
+def test_gc_manager_thread_min_interval_throttling():
+    mock_gc_collect = Mock(return_value=5)
+
+    gc_thread = PythonGCThread(min_interval_s=2, gc_collect_func=mock_gc_collect)
+
+    try:
+        gc_thread.start()
+
+        for _ in range(3):
+            gc_thread.trigger_gc()
+            time.sleep(1)
+
+        wait_for_condition(lambda: mock_gc_collect.call_count == 2, timeout=2)
+
+        assert mock_gc_collect.call_count == 2
+
+    finally:
+        gc_thread.stop()
+
+
+def test_gc_manager_thread_exception_handling():
+    mock_gc_collect = Mock(side_effect=RuntimeError("GC failed"))
+
+    gc_thread = PythonGCThread(min_interval_s=5, gc_collect_func=mock_gc_collect)
+
+    try:
+        gc_thread.start()
+
+        for _ in range(3):
+            gc_thread.trigger_gc()
+            time.sleep(0.1)
+
+        assert gc_thread.is_alive()
+        mock_gc_collect.assert_called_once()
+
+    finally:
+        gc_thread.stop()
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-sv", __file__]))

src/ray/common/ray_config_def.h

Lines changed: 3 additions & 0 deletions
@@ -945,3 +945,6 @@ RAY_CONFIG(int32_t, raylet_rpc_server_reconnect_timeout_s, 60)
 // process getting spawned. Setting to zero or less maintains the default
 // number of threads grpc will spawn.
 RAY_CONFIG(int64_t, worker_num_grpc_internal_threads, 0)
+
+// Whether to start a background thread to manage Python GC in workers.
+RAY_CONFIG(bool, start_python_gc_manager_thread, true)
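Like other RAY_CONFIG entries, the flag can be switched through its RAY_-prefixed environment variable; the worker_pool.cc change below sets exactly that variable to disable the thread for the workers spawned in that branch. For example, to opt out on the driver side (assuming the standard RAY_<config_name> env-var override convention):

```python
import os

# Disabling falls back to the old inline gc.collect() behavior.
os.environ["RAY_start_python_gc_manager_thread"] = "0"

import ray

ray.init()
```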

src/ray/raylet/worker_pool.cc

Lines changed: 1 addition & 0 deletions
@@ -461,6 +461,7 @@ WorkerPool::BuildProcessCommandArgs(const Language &language,
     // Support forking in gRPC.
     env.insert({"GRPC_ENABLE_FORK_SUPPORT", "True"});
     env.insert({"GRPC_POLL_STRATEGY", "poll"});
+    env.insert({"RAY_start_python_gc_manager_thread", "0"});
   }

   return {std::move(worker_command_args), std::move(env)};
