Skip to content

Commit e58314d

Browse files
authored
remove daemon threads (#1483)
* daemons no more * rebase on PR#1482 * shutdown in singlethreadedmapper
1 parent 10b87d6 commit e58314d

File tree

1 file changed

+6
-11
lines changed

1 file changed

+6
-11
lines changed

torchdata/nodes/map.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,6 @@ def __init__(
166166
max_concurrent: Optional[int],
167167
snapshot_frequency: int,
168168
initial_state: Optional[Dict[str, Any]],
169-
daemonic_reading: bool,
170169
):
171170
self.source = source
172171
self.map_fn = map_fn
@@ -175,7 +174,6 @@ def __init__(
175174
self.method = method
176175
self.mp_context = mp_context
177176
self.snapshot_frequency = snapshot_frequency
178-
self.daemonic_reading = daemonic_reading
179177

180178
self._in_q: Union[queue.Queue, mp.Queue] = queue.Queue() if method == "thread" else mp_context.Queue()
181179
self._intermed_q: Union[queue.Queue, mp.Queue] = queue.Queue() if method == "thread" else mp_context.Queue()
@@ -209,7 +207,7 @@ def __init__(
209207
self._stop,
210208
),
211209
name="read_thread(target=_populate_queue)",
212-
daemon=self.daemonic_reading,
210+
daemon=False,
213211
)
214212
self._read_thread.start()
215213

@@ -249,7 +247,7 @@ def __init__(
249247
self._sort_q,
250248
self._stop,
251249
),
252-
daemon=True,
250+
daemon=False,
253251
name="sort_thread(target=_sort_worker)",
254252
)
255253
self._out_q = self._sort_q
@@ -352,7 +350,6 @@ def __init__(
352350
multiprocessing_context: Optional[str] = None,
353351
max_concurrent: Optional[int] = None,
354352
snapshot_frequency: int = 1,
355-
daemonic_reading: bool = True,
356353
):
357354
super().__init__()
358355
assert method in ["thread", "process"]
@@ -371,7 +368,6 @@ def __init__(
371368
raise ValueError(f"{max_concurrent=} should be <= {num_workers=}!")
372369
self.max_concurrent = max_concurrent
373370
self.snapshot_frequency = snapshot_frequency
374-
self.daemonic_reading = daemonic_reading
375371
self._it: Optional[Union[_InlineMapperIter[T], _ParallelMapperIter[T]]] = None
376372

377373
def reset(self, initial_state: Optional[Dict[str, Any]] = None):
@@ -399,7 +395,6 @@ def _parallel_reset(self, initial_state: Optional[Dict[str, Any]]):
399395
max_concurrent=self.max_concurrent,
400396
snapshot_frequency=self.snapshot_frequency,
401397
initial_state=initial_state,
402-
daemonic_reading=self.daemonic_reading,
403398
)
404399

405400
def next(self) -> T:
@@ -448,7 +443,6 @@ def __init__(
448443
max_concurrent: Optional[int] = None,
449444
snapshot_frequency: int = 1,
450445
prebatch: Optional[int] = None,
451-
daemonic_reading: bool = True,
452446
):
453447
super().__init__()
454448
assert method in ["thread", "process"]
@@ -462,7 +456,6 @@ def __init__(
462456
self.max_concurrent = max_concurrent
463457
self.snapshot_frequency = snapshot_frequency
464458
self.prebatch = prebatch
465-
self.daemonic_reading = daemonic_reading
466459
if prebatch is None:
467460
self.map_fn = map_fn
468461
self.source = source
@@ -481,7 +474,6 @@ def __init__(
481474
multiprocessing_context=self.multiprocessing_context,
482475
max_concurrent=self.max_concurrent,
483476
snapshot_frequency=self.snapshot_frequency,
484-
daemonic_reading=self.daemonic_reading,
485477
)
486478

487479
if self.prebatch is None:
@@ -581,7 +573,7 @@ def __init__(
581573
self._sem,
582574
self._stop_event,
583575
),
584-
daemon=True,
576+
daemon=False,
585577
name=f"worker_thread(target={self.worker.__name__})",
586578
)
587579
self._thread.start()
@@ -605,6 +597,7 @@ def __iter__(self) -> Iterator[T]:
605597
def __next__(self) -> T:
606598
while True:
607599
if self._stop_event.is_set():
600+
self._shutdown()
608601
raise StopIteration()
609602
try:
610603
item, idx = self._q.get(block=True, timeout=QUEUE_TIMEOUT)
@@ -614,11 +607,13 @@ def __next__(self) -> T:
614607
if isinstance(item, StopIteration):
615608
self._sem.release()
616609
self._stop_event.set()
610+
self._shutdown()
617611
raise item
618612
elif isinstance(item, ExceptionWrapper):
619613
if not isinstance(item, StartupExceptionWrapper):
620614
# We don't need to release for startup exceptions
621615
self._sem.release()
616+
self._shutdown()
622617
self._stop_event.set()
623618
item.reraise()
624619
else:

0 commit comments

Comments
 (0)