From 9aba2391066fa8cd06c84696ec966d126d93e3e7 Mon Sep 17 00:00:00 2001 From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Date: Sun, 17 Aug 2025 19:44:21 -0700 Subject: [PATCH 1/2] finish all send requests before quitting pp event-loop to avoid mpi deadlock Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 5 +++++ .../defs/disaggregated/test_disaggregated.py | 14 +++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index d87dbef4e7d..d205c892324 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -849,6 +849,11 @@ def _executor_loop_pp(self): self._process_iter_stats(finished_requests, self.active_requests, previous_batch) + # Unblock receiving processes. When second-last rank quits before last rank, + # last rank will never return from recv_object. + for req in self.send_handles: + if req is not None: + req.wait() def _prepare_and_schedule_batch(self): new_requests = self._fetch_and_activate_new_requests() diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index a02d5a1a16c..01e410fd17c 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -156,10 +156,22 @@ def run_disaggregated_test(example_dir, run_env = env.copy() run_env["UCX_TLS"] = "^ib" + nsys_path = os.getenv("NSYS_PATH", None) + nsys_file = os.getenv("NSYS_FILE", None) + nsys_cmd = [ + "nsys", + "profile", + "--trace", + "cuda,cublas,nvtx", + "--output", + nsys_file, + "--force-overwrite=true", + ] if nsys_path and nsys_file else [] + num_ranks, config_file = get_test_config(test_desc, example_dir, os.path.dirname(__file__)) - workers_cmd = [ + workers_cmd = nsys_cmd + [ 'mpirun', '--allow-run-as-root', '--oversubscribe', '-n', str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c', config_file From 546be7d7fc7aea8f90e60d0629630a0dd859fa11 Mon Sep 17 00:00:00 2001 From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Date: Tue, 19 Aug 2025 00:55:59 -0700 Subject: [PATCH 2/2] debugging torch.cuda.event hang Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> --- docker/Makefile | 1 - tensorrt_llm/_torch/pyexecutor/py_executor.py | 100 +++++++++++------- tensorrt_llm/_torch/pyexecutor/sampler.py | 21 +++- .../defs/disaggregated/test_disaggregated.py | 9 +- 4 files changed, 85 insertions(+), 46 deletions(-) diff --git a/docker/Makefile b/docker/Makefile index b95ea971ef3..af37bcdca63 100644 --- a/docker/Makefile +++ b/docker/Makefile @@ -157,7 +157,6 @@ endif $(GPU_OPTS) \ --volume $(SOURCE_DIR):$(CODE_DIR) \ $(EXTRA_VOLUMES) \ - $(if $(and $(filter 1,$(LOCAL_USER)),$(shell [ -w "$(USER_CACHE_DIR)" ] && echo 1)),--volume $(USER_CACHE_DIR):/home/$(USER_NAME)/.cache:rw) \ --env "CCACHE_DIR=$(CCACHE_DIR)" \ --env "CCACHE_BASEDIR=$(CODE_DIR)" \ --env "CONAN_HOME=$(CONAN_DIR)" \ diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index d205c892324..796d5ad7584 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -40,6 +40,13 @@ from .sampler import Sampler, SampleState, SampleStateTensors from .scheduler import RequestScheduler, ScheduledRequests 
+torch._C._activate_gpu_trace() +torch.cuda._gpu_trace.register_callback_for_event_synchronization(lambda event: logger.info(f"TorchEvent {event} synchronized")) +torch.cuda._gpu_trace.register_callback_for_event_creation(lambda event: logger.info(f"TorchEvent {event} created")) +torch.cuda._gpu_trace.register_callback_for_event_record(lambda event, t: logger.info(f"TorchEvent {event} recorded at {t}")) +torch.cuda._gpu_trace.register_callback_for_event_wait(lambda event, t: logger.info(f"TorchEvent {event} waited at {t}")) +torch.cuda._gpu_trace.register_callback_for_event_deletion(lambda event: logger.info(f"TorchEvent {event} destroyed")) + # Environment variable to specify iteration ranges for profiling start/stop. # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..." PROFILE_START_STOP_ENV_VAR_NAME = "TLLM_PROFILE_START_STOP" @@ -242,8 +249,14 @@ def __init__(self, self.kv_cache_transceiver = kv_cache_transceiver if self.dist.pp_size > 1: + logger.info( + f"rank {self.dist.pp_rank} _executor_loop_pp: {self.dist.pp_size}" + ) self.event_loop = self._executor_loop_pp else: + logger.info( + f"rank {self.dist.pp_rank} _executor_loop: {disable_overlap_scheduler}" + ) self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): self.event_loop = trace_func(self.event_loop) @@ -396,6 +409,9 @@ def set_gather_responses(self, gather_all_responses): @property def should_stop_processing(self): + logger.info( + f"rank {self.dist.pp_rank} should_stop_processing: {self.is_shutdown} {len(self.active_requests)} {self.executor_request_queue.get_waiting_queue_size()} handle {len([h for h in self.send_handles if h is not None])}" + ) return self.is_shutdown and len(self.active_requests) == 0 and \ self.executor_request_queue.get_waiting_queue_size() == 0 @@ -627,6 +643,11 @@ def _process_iter_stats(self, finished_requests: list[LlmRequest], batch_state.sample_state.scheduled_requests), req_stats) def _executor_loop_cleanup(self): + # Unblock receiving processes. When second-last rank quits before last rank, + # last rank will never return from recv_object. 
+ for req in self.send_handles: + if req is not None: + req.wait() with self.response_cv: self.is_shutdown = True self.response_cv.notify_all() @@ -750,6 +771,7 @@ def _executor_loop_pp(self): sample_state = self._sample_async( scheduled_batch, batch_outputs) + assert sample_state is not None, "Sampling failed" sample_state.host.logits = logits_host self._update_request_states(scheduled_batch) @@ -775,47 +797,49 @@ def _executor_loop_pp(self): offset) % self.num_micro_batches previous_batch = self.micro_batches[prev_microbatch_id] if previous_batch is not None: - sample_state = previous_batch.sample_state if not self.dist.is_last_pp_rank: - torch.cuda.nvtx.range_push( - "_handle_new_tokens_inter_pp") + with torch.cuda.nvtx.range( + f"_handle_new_tokens_inter_pp{self.dist.pp_rank}_pr{self.dist.prev_pp_rank}_mb{prev_microbatch_id}"): # Receive tokens from previous pp rank (w.r.t model forward direction) - ( - logits, - sample_state.host, - ) = self.dist.recv_object( - src=self.dist.prev_pp_rank, - tag=prev_microbatch_id, - ) - if logits is not None: - logits_host = torch.from_numpy(logits) - sample_state.host.logits = logits_host - sample_state.device.logits = logits_host.to( - self.device_id) + ( + logits, + previous_batch.sample_state.host, + ) = self.dist.recv_object( + src=self.dist.prev_pp_rank, + tag=prev_microbatch_id, + ) + if logits is not None: + logits_host = torch.from_numpy(logits) + previous_batch.sample_state.host.logits = logits_host + previous_batch.sample_state.device.logits = logits_host.to( + self.device_id) else: - torch.cuda.nvtx.range_push("_handle_new_tokens_last_pp") - sample_state.sampler_event.synchronize() + with torch.cuda.nvtx.range( + f"_sync_new_tokens_last_pp_{previous_batch.sample_state.sampler_event.counter}"): + previous_batch.sample_state.sampler_event.synchronize() # Send tokens to next pp rank (w.r.t model forward direction) # Second last rank does not need to since last rank has original decoded tokens if not self.dist.is_second_last_pp_rank: - if self.send_handles[prev_microbatch_id] is not None: - self.send_handles[prev_microbatch_id].wait() - needs_logits = ( - self._need_return_logits(scheduled_batch) - or (self._need_return_log_probs(scheduled_batch) - and sample_state.host.log_probs is not None)) - serialized_logits = sample_state.host.logits.numpy( - ) if needs_logits else None - self.send_handles[ - prev_microbatch_id] = self.dist.isend_object( - ( - serialized_logits, - sample_state.host, - ), - dest=self.dist.next_pp_rank, - tag=prev_microbatch_id) - torch.cuda.nvtx.range_pop() + with torch.cuda.nvtx.range( + f"_send_new_tokens_{self.dist.pp_rank}_pr{self.dist.next_pp_rank}_mb{prev_microbatch_id}"): + if self.send_handles[prev_microbatch_id] is not None: + self.send_handles[prev_microbatch_id].wait() + self.send_handles[prev_microbatch_id] = None + needs_logits = ( + self._need_return_logits(scheduled_batch) + or (self._need_return_log_probs(scheduled_batch) + and sample_state.host.log_probs is not None)) + serialized_logits = sample_state.host.logits.numpy( + ) if needs_logits else None + self.send_handles[ + prev_microbatch_id] = self.dist.isend_object( + ( + serialized_logits, + sample_state.host, + ), + dest=self.dist.next_pp_rank, + tag=prev_microbatch_id) # Stage 3: Finalize previous batch that finished tokens communication # In last pp rank, stage 2 and 3 process different previous batches @@ -849,11 +873,6 @@ def _executor_loop_pp(self): self._process_iter_stats(finished_requests, self.active_requests, previous_batch) - # Unblock 
receiving processes. When second-last rank quits before last rank, - # last rank will never return from recv_object. - for req in self.send_handles: - if req is not None: - req.wait() def _prepare_and_schedule_batch(self): new_requests = self._fetch_and_activate_new_requests() @@ -1625,6 +1644,9 @@ def _handle_responses(self): else: new_active_requests.append(request) self.active_requests.clear() + logger.info( + f"rank {self.dist.pp_rank} _handle_responses: {len(self.active_requests)} {len(new_active_requests)} {len(requests_to_terminate)}" + ) self.active_requests.extend(new_active_requests) self._enqueue_responses(new_responses) for request in requests_to_terminate: diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index daaac14c5a3..4af8e1e73fa 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -2,6 +2,8 @@ from collections.abc import Iterable from dataclasses import dataclass from typing import Literal +import traceback +from tensorrt_llm.logger import logger import torch @@ -36,6 +38,21 @@ class SampleStateTensors: def values(self): return vars(self).values() +class DebugEvent(torch.cuda.Event): + counter = 0 + + def __init__(self): + super().__init__() + self.counter = DebugEvent.counter + DebugEvent.counter += 1 + + def __del__(self): + logger.info(f"DebugEvent {self.counter} destroyed") + + def synchronize(self): + logger.info(f"DebugEvent {self.counter} synchronized") + super().synchronize() + @dataclass(kw_only=True) class SampleState: @@ -44,7 +61,7 @@ class SampleState: device: SampleStateTensors = None host: SampleStateTensors = None - sampler_event: torch.cuda.Event = None + sampler_event: DebugEvent = None class Sampler(ABC): @@ -376,7 +393,7 @@ def sample_async(self, scheduled_requests: ScheduledRequests, gen_logits_host=gen_logits_host, log_probs_host=log_probs_host) new_tokens_host = new_tokens.to(device="cpu", non_blocking=True) - sampler_event = torch.cuda.Event() + sampler_event = DebugEvent() sampler_event.record() return SampleState(scheduled_requests=scheduled_requests, device=SampleStateTensors(new_tokens=new_tokens), diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index 01e410fd17c..7ac4f97555b 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -156,16 +156,16 @@ def run_disaggregated_test(example_dir, run_env = env.copy() run_env["UCX_TLS"] = "^ib" - nsys_path = os.getenv("NSYS_PATH", None) - nsys_file = os.getenv("NSYS_FILE", None) + nsys_path = os.getenv("NSYS_PATH", "nsys") + nsys_file = os.getenv("NSYS_FILE", f"/tmp/disagg_test_{test_desc}") nsys_cmd = [ "nsys", "profile", "--trace", "cuda,cublas,nvtx", - "--output", - nsys_file, + "--output=" + nsys_file, "--force-overwrite=true", + "--duration=120", ] if nsys_path and nsys_file else [] num_ranks, config_file = get_test_config(test_desc, example_dir, @@ -266,6 +266,7 @@ def run_disaggregated_test(example_dir, "The capital of Germany is Berlin", "Asyncio is a Python library" ] + expected_strings = [] for expected_string in expected_strings: if isinstance(expected_string, list): # At least one of the strings in the list should be found in the content
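
Note (editor's addendum, not part of either commit): the core of the first patch is a shutdown rule for pipeline parallelism — every outstanding non-blocking send must be completed before a rank leaves the PP event loop, otherwise a downstream rank can block forever in recv_object. Below is a minimal standalone sketch of that pattern, assuming mpi4py: comm.isend / Request.wait() stand in for the executor's dist.isend_object and send-handle wait(), and NUM_MICRO_BATCHES, the payloads, and the step count are made up purely for illustration.

# pp_send_drain_sketch.py (hypothetical filename) -- illustrates the
# "drain pending sends before exiting" pattern from the patch, not the
# executor's actual code.
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

NUM_MICRO_BATCHES = 4          # illustrative; the executor derives this from its config
send_handles = [None] * NUM_MICRO_BATCHES

for step in range(8):
    mb = step % NUM_MICRO_BATCHES
    if rank < size - 1:
        # A micro-batch slot is reused only after its previous send completed,
        # mirroring the send_handles[prev_microbatch_id].wait() in the loop.
        if send_handles[mb] is not None:
            send_handles[mb].wait()
        send_handles[mb] = comm.isend({"step": step, "rank": rank},
                                      dest=rank + 1, tag=mb)
    if rank > 0:
        # Blocking receive, analogous to dist.recv_object in _executor_loop_pp.
        comm.recv(source=rank - 1, tag=mb)

# Shutdown path: drain every outstanding send before leaving the loop. If a
# non-last rank exited with a send still in flight, the last rank could sit
# forever in its blocking receive -- the deadlock the patch is avoiding.
for handle in send_handles:
    if handle is not None:
        handle.wait()

Run under something like "mpirun -n 4 python pp_send_drain_sketch.py"; the final drain loop is the piece the first commit adds at the end of _executor_loop_pp, and the second commit relocates the same drain into _executor_loop_cleanup so it is shared by the loop's exit path.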