1 change: 0 additions & 1 deletion docker/Makefile
@@ -157,7 +157,6 @@ endif
$(GPU_OPTS) \
--volume $(SOURCE_DIR):$(CODE_DIR) \
$(EXTRA_VOLUMES) \
$(if $(and $(filter 1,$(LOCAL_USER)),$(shell [ -w "$(USER_CACHE_DIR)" ] && echo 1)),--volume $(USER_CACHE_DIR):/home/$(USER_NAME)/.cache:rw) \
--env "CCACHE_DIR=$(CCACHE_DIR)" \
--env "CCACHE_BASEDIR=$(CODE_DIR)" \
--env "CONAN_HOME=$(CONAN_DIR)" \
95 changes: 61 additions & 34 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -40,6 +40,13 @@
from .sampler import Sampler, SampleState, SampleStateTensors
from .scheduler import RequestScheduler, ScheduledRequests

torch._C._activate_gpu_trace()
torch.cuda._gpu_trace.register_callback_for_event_synchronization(lambda event: logger.info(f"TorchEvent {event} synchronized"))
torch.cuda._gpu_trace.register_callback_for_event_creation(lambda event: logger.info(f"TorchEvent {event} created"))
torch.cuda._gpu_trace.register_callback_for_event_record(lambda event, t: logger.info(f"TorchEvent {event} recorded at {t}"))
torch.cuda._gpu_trace.register_callback_for_event_wait(lambda event, t: logger.info(f"TorchEvent {event} waited at {t}"))
torch.cuda._gpu_trace.register_callback_for_event_deletion(lambda event: logger.info(f"TorchEvent {event} destroyed"))

Comment on lines +43 to +49

🛠️ Refactor suggestion

Guard and wrap private PyTorch GPU trace hooks; fix long lines (E501)

  • This unconditionally activates private APIs (torch._C._activate_gpu_trace, torch.cuda._gpu_trace.*) at import time. Risky and costly by default.
  • Gate with an env var and try/except; downgrade to a warning on failure.
  • Split long lines per Ruff E501.
- torch._C._activate_gpu_trace()
- torch.cuda._gpu_trace.register_callback_for_event_synchronization(lambda event: logger.info(f"TorchEvent {event} synchronized"))
- torch.cuda._gpu_trace.register_callback_for_event_creation(lambda event: logger.info(f"TorchEvent {event} created"))
- torch.cuda._gpu_trace.register_callback_for_event_record(lambda event, t: logger.info(f"TorchEvent {event} recorded at {t}"))
- torch.cuda._gpu_trace.register_callback_for_event_wait(lambda event, t: logger.info(f"TorchEvent {event} waited at {t}"))
- torch.cuda._gpu_trace.register_callback_for_event_deletion(lambda event: logger.info(f"TorchEvent {event} destroyed"))
+if os.environ.get("TLLM_TORCH_GPU_TRACE") == "1":
+    try:
+        torch._C._activate_gpu_trace()
+        torch.cuda._gpu_trace.register_callback_for_event_synchronization(
+            lambda event: logger.info(f"TorchEvent {event} synchronized"))
+        torch.cuda._gpu_trace.register_callback_for_event_creation(
+            lambda event: logger.info(f"TorchEvent {event} created"))
+        torch.cuda._gpu_trace.register_callback_for_event_record(
+            lambda event, t: logger.info(f"TorchEvent {event} recorded at {t}"))
+        torch.cuda._gpu_trace.register_callback_for_event_wait(
+            lambda event, t: logger.info(f"TorchEvent {event} waited at {t}"))
+        torch.cuda._gpu_trace.register_callback_for_event_deletion(
+            lambda event: logger.info(f"TorchEvent {event} destroyed"))
+    except Exception as e:
+        logger.warning("Failed to activate Torch GPU tracing: %s", e)
📝 Committable suggestion


Suggested change
torch._C._activate_gpu_trace()
torch.cuda._gpu_trace.register_callback_for_event_synchronization(lambda event: logger.info(f"TorchEvent {event} synchronized"))
torch.cuda._gpu_trace.register_callback_for_event_creation(lambda event: logger.info(f"TorchEvent {event} created"))
torch.cuda._gpu_trace.register_callback_for_event_record(lambda event, t: logger.info(f"TorchEvent {event} recorded at {t}"))
torch.cuda._gpu_trace.register_callback_for_event_wait(lambda event, t: logger.info(f"TorchEvent {event} waited at {t}"))
torch.cuda._gpu_trace.register_callback_for_event_deletion(lambda event: logger.info(f"TorchEvent {event} destroyed"))
if os.environ.get("TLLM_TORCH_GPU_TRACE") == "1":
try:
torch._C._activate_gpu_trace()
torch.cuda._gpu_trace.register_callback_for_event_synchronization(
lambda event: logger.info(f"TorchEvent {event} synchronized"))
torch.cuda._gpu_trace.register_callback_for_event_creation(
lambda event: logger.info(f"TorchEvent {event} created"))
torch.cuda._gpu_trace.register_callback_for_event_record(
lambda event, t: logger.info(f"TorchEvent {event} recorded at {t}"))
torch.cuda._gpu_trace.register_callback_for_event_wait(
lambda event, t: logger.info(f"TorchEvent {event} waited at {t}"))
torch.cuda._gpu_trace.register_callback_for_event_deletion(
lambda event: logger.info(f"TorchEvent {event} destroyed"))
except Exception as e:
logger.warning("Failed to activate Torch GPU tracing: %s", e)
🧰 Tools
🪛 Ruff (0.12.2)

44-44: Line too long (128 > 120)

(E501)


46-46: Line too long (125 > 120)

(E501)


47-47: Line too long (121 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/py_executor.py around lines 43 to 49, the code
unconditionally calls private PyTorch GPU trace APIs and uses very long lines;
wrap activation and registrations behind an environment variable check (e.g.,
read a specific TRITON/TORCH_TRACE env var) and perform attribute existence
checks before calling, then enclose the activation and each register_* call in a
try/except that logs a warning (logger.warning) on any Exception instead of
raising; also break the logging callback expressions into shorter lines to
satisfy line-length rules and avoid long f-strings in single lines.
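A minimal sketch of the attribute-existence guard the prompt above asks for, assuming the TLLM_TORCH_GPU_TRACE env var name from the suggestion; torch._C._activate_gpu_trace and torch.cuda._gpu_trace are private PyTorch APIs and may not exist in every build:

import os

import torch

from tensorrt_llm.logger import logger


def _maybe_enable_torch_gpu_trace() -> None:
    # Opt-in only: do nothing unless the (assumed) env var is set.
    if os.environ.get("TLLM_TORCH_GPU_TRACE") != "1":
        return
    gpu_trace = getattr(torch.cuda, "_gpu_trace", None)
    activate = getattr(torch._C, "_activate_gpu_trace", None)
    if gpu_trace is None or activate is None:
        logger.warning("Torch GPU tracing requested but the private APIs are unavailable")
        return
    try:
        activate()
        gpu_trace.register_callback_for_event_creation(
            lambda event: logger.info(f"TorchEvent {event} created"))
        gpu_trace.register_callback_for_event_record(
            lambda event, t: logger.info(f"TorchEvent {event} recorded at {t}"))
        gpu_trace.register_callback_for_event_wait(
            lambda event, t: logger.info(f"TorchEvent {event} waited at {t}"))
        gpu_trace.register_callback_for_event_synchronization(
            lambda event: logger.info(f"TorchEvent {event} synchronized"))
        gpu_trace.register_callback_for_event_deletion(
            lambda event: logger.info(f"TorchEvent {event} destroyed"))
    except Exception as e:
        # Private API surface; fail soft rather than break module import.
        logger.warning("Failed to activate Torch GPU tracing: %s", e)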

# Environment variable to specify iteration ranges for profiling start/stop.
# Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..."
PROFILE_START_STOP_ENV_VAR_NAME = "TLLM_PROFILE_START_STOP"
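For illustration, the format documented in the comment above could be parsed roughly as follows; parse_profile_ranges is a hypothetical helper, not the executor's actual implementation:

def parse_profile_ranges(spec: str) -> list[tuple[int, int]]:
    # "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..."
    # Single iterations are treated as one-iteration ranges (iter, iter).
    ranges = []
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            start, stop = part.split("-", 1)
            ranges.append((int(start), int(stop)))
        else:
            ranges.append((int(part), int(part)))
    return ranges

# Example: "10-20,35" -> [(10, 20), (35, 35)]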
@@ -242,8 +249,14 @@ def __init__(self,

self.kv_cache_transceiver = kv_cache_transceiver
if self.dist.pp_size > 1:
logger.info(
f"rank {self.dist.pp_rank} _executor_loop_pp: {self.dist.pp_size}"
)
self.event_loop = self._executor_loop_pp
else:
logger.info(
f"rank {self.dist.pp_rank} _executor_loop: {disable_overlap_scheduler}"
)
self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap
if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
self.event_loop = trace_func(self.event_loop)
@@ -396,6 +409,9 @@ def set_gather_responses(self, gather_all_responses):

@property
def should_stop_processing(self):
logger.info(
f"rank {self.dist.pp_rank} should_stop_processing: {self.is_shutdown} {len(self.active_requests)} {self.executor_request_queue.get_waiting_queue_size()} handle {len([h for h in self.send_handles if h is not None])}"
)
return self.is_shutdown and len(self.active_requests) == 0 and \
self.executor_request_queue.get_waiting_queue_size() == 0

@@ -627,6 +643,11 @@ def _process_iter_stats(self, finished_requests: list[LlmRequest],
batch_state.sample_state.scheduled_requests), req_stats)

def _executor_loop_cleanup(self):
# Unblock receiving processes. When second-last rank quits before last rank,
# last rank will never return from recv_object.
for req in self.send_handles:
if req is not None:
req.wait()
with self.response_cv:
self.is_shutdown = True
self.response_cv.notify_all()
@@ -750,6 +771,7 @@ def _executor_loop_pp(self):

sample_state = self._sample_async(
scheduled_batch, batch_outputs)
assert sample_state is not None, "Sampling failed"
sample_state.host.logits = logits_host
self._update_request_states(scheduled_batch)

@@ -775,47 +797,49 @@
offset) % self.num_micro_batches
previous_batch = self.micro_batches[prev_microbatch_id]
if previous_batch is not None:
sample_state = previous_batch.sample_state
if not self.dist.is_last_pp_rank:
torch.cuda.nvtx.range_push(
"_handle_new_tokens_inter_pp")
with torch.cuda.nvtx.range(
f"_handle_new_tokens_inter_pp{self.dist.pp_rank}_pr{self.dist.prev_pp_rank}_mb{prev_microbatch_id}"):
# Receive tokens from previous pp rank (w.r.t model forward direction)
(
logits,
sample_state.host,
) = self.dist.recv_object(
src=self.dist.prev_pp_rank,
tag=prev_microbatch_id,
)
if logits is not None:
logits_host = torch.from_numpy(logits)
sample_state.host.logits = logits_host
sample_state.device.logits = logits_host.to(
self.device_id)
(
logits,
previous_batch.sample_state.host,
) = self.dist.recv_object(
src=self.dist.prev_pp_rank,
tag=prev_microbatch_id,
)
if logits is not None:
logits_host = torch.from_numpy(logits)
previous_batch.sample_state.host.logits = logits_host
previous_batch.sample_state.device.logits = logits_host.to(
self.device_id)
else:
torch.cuda.nvtx.range_push("_handle_new_tokens_last_pp")
sample_state.sampler_event.synchronize()
with torch.cuda.nvtx.range(
f"_sync_new_tokens_last_pp_{previous_batch.sample_state.sampler_event.counter}"):
previous_batch.sample_state.sampler_event.synchronize()

# Send tokens to next pp rank (w.r.t model forward direction)
# Second last rank does not need to since last rank has original decoded tokens
if not self.dist.is_second_last_pp_rank:
if self.send_handles[prev_microbatch_id] is not None:
self.send_handles[prev_microbatch_id].wait()
needs_logits = (
self._need_return_logits(scheduled_batch)
or (self._need_return_log_probs(scheduled_batch)
and sample_state.host.log_probs is not None))
serialized_logits = sample_state.host.logits.numpy(
) if needs_logits else None
self.send_handles[
prev_microbatch_id] = self.dist.isend_object(
(
serialized_logits,
sample_state.host,
),
dest=self.dist.next_pp_rank,
tag=prev_microbatch_id)
torch.cuda.nvtx.range_pop()
with torch.cuda.nvtx.range(
f"_send_new_tokens_{self.dist.pp_rank}_pr{self.dist.next_pp_rank}_mb{prev_microbatch_id}"):
if self.send_handles[prev_microbatch_id] is not None:
self.send_handles[prev_microbatch_id].wait()
self.send_handles[prev_microbatch_id] = None
needs_logits = (
self._need_return_logits(scheduled_batch)
or (self._need_return_log_probs(scheduled_batch)
and sample_state.host.log_probs is not None))
serialized_logits = sample_state.host.logits.numpy(
) if needs_logits else None
self.send_handles[
prev_microbatch_id] = self.dist.isend_object(
(
serialized_logits,
sample_state.host,
),
dest=self.dist.next_pp_rank,
tag=prev_microbatch_id)

Comment on lines +801 to 843

🛠️ Refactor suggestion

Make last-PP NVTX naming robust; fix logits send logic to reference the correct batch

  • previous_batch.sample_state.sampler_event may be a plain torch.cuda.Event (e.g., TRTLLMSampler), so .counter may not exist.
  • needs_logits is computed from the current batch but you’re sending tokens/logits for the previous microbatch. Use previous_batch’s host state.
  • When serializing logits, use the previous batch’s host.logits, not the current sample_state.host.logits.
-                        with torch.cuda.nvtx.range(
-                            f"_sync_new_tokens_last_pp_{previous_batch.sample_state.sampler_event.counter}"):
-                            previous_batch.sample_state.sampler_event.synchronize()
+                        _ev = previous_batch.sample_state.sampler_event
+                        _ctr = getattr(_ev, "counter", -1)
+                        with torch.cuda.nvtx.range(
+                            f"_sync_new_tokens_last_pp_{_ctr}"):
+                            _ev.synchronize()
@@
-                            if self.send_handles[prev_microbatch_id] is not None:
+                            if self.send_handles[prev_microbatch_id] is not None:
                                 self.send_handles[prev_microbatch_id].wait()
                                 self.send_handles[prev_microbatch_id] = None
-                            needs_logits = (
-                                self._need_return_logits(scheduled_batch)
-                                or (self._need_return_log_probs(scheduled_batch)
-                                    and sample_state.host.log_probs is not None))
-                            serialized_logits = sample_state.host.logits.numpy(
-                            ) if needs_logits else None
+                            prev_host = previous_batch.sample_state.host
+                            needs_logits = (
+                                (getattr(prev_host, "logits", None) is not None)
+                                or (getattr(prev_host, "log_probs", None) is not None)
+                            )
+                            serialized_logits = (
+                                prev_host.logits.numpy() if needs_logits else None
+                            )
                             self.send_handles[
                                 prev_microbatch_id] = self.dist.isend_object(
                                     (
                                         serialized_logits,
-                                        sample_state.host,
+                                        prev_host,
                                     ),
                                     dest=self.dist.next_pp_rank,
                                     tag=prev_microbatch_id)
📝 Committable suggestion


Suggested change
with torch.cuda.nvtx.range(
f"_handle_new_tokens_inter_pp{self.dist.pp_rank}_pr{self.dist.prev_pp_rank}_mb{prev_microbatch_id}"):
# Receive tokens from previous pp rank (w.r.t model forward direction)
(
logits,
sample_state.host,
) = self.dist.recv_object(
src=self.dist.prev_pp_rank,
tag=prev_microbatch_id,
)
if logits is not None:
logits_host = torch.from_numpy(logits)
sample_state.host.logits = logits_host
sample_state.device.logits = logits_host.to(
self.device_id)
(
logits,
previous_batch.sample_state.host,
) = self.dist.recv_object(
src=self.dist.prev_pp_rank,
tag=prev_microbatch_id,
)
if logits is not None:
logits_host = torch.from_numpy(logits)
previous_batch.sample_state.host.logits = logits_host
previous_batch.sample_state.device.logits = logits_host.to(
self.device_id)
else:
torch.cuda.nvtx.range_push("_handle_new_tokens_last_pp")
sample_state.sampler_event.synchronize()
with torch.cuda.nvtx.range(
f"_sync_new_tokens_last_pp_{previous_batch.sample_state.sampler_event.counter}"):
previous_batch.sample_state.sampler_event.synchronize()
# Send tokens to next pp rank (w.r.t model forward direction)
# Second last rank does not need to since last rank has original decoded tokens
if not self.dist.is_second_last_pp_rank:
if self.send_handles[prev_microbatch_id] is not None:
self.send_handles[prev_microbatch_id].wait()
needs_logits = (
self._need_return_logits(scheduled_batch)
or (self._need_return_log_probs(scheduled_batch)
and sample_state.host.log_probs is not None))
serialized_logits = sample_state.host.logits.numpy(
) if needs_logits else None
self.send_handles[
prev_microbatch_id] = self.dist.isend_object(
(
serialized_logits,
sample_state.host,
),
dest=self.dist.next_pp_rank,
tag=prev_microbatch_id)
torch.cuda.nvtx.range_pop()
with torch.cuda.nvtx.range(
f"_send_new_tokens_{self.dist.pp_rank}_pr{self.dist.next_pp_rank}_mb{prev_microbatch_id}"):
if self.send_handles[prev_microbatch_id] is not None:
self.send_handles[prev_microbatch_id].wait()
self.send_handles[prev_microbatch_id] = None
needs_logits = (
self._need_return_logits(scheduled_batch)
or (self._need_return_log_probs(scheduled_batch)
and sample_state.host.log_probs is not None))
serialized_logits = sample_state.host.logits.numpy(
) if needs_logits else None
self.send_handles[
prev_microbatch_id] = self.dist.isend_object(
(
serialized_logits,
sample_state.host,
),
dest=self.dist.next_pp_rank,
tag=prev_microbatch_id)
with torch.cuda.nvtx.range(
f"_handle_new_tokens_inter_pp{self.dist.pp_rank}_pr{self.dist.prev_pp_rank}_mb{prev_microbatch_id}"):
# Receive tokens from previous pp rank (w.r.t model forward direction)
(
logits,
previous_batch.sample_state.host,
) = self.dist.recv_object(
src=self.dist.prev_pp_rank,
tag=prev_microbatch_id,
)
if logits is not None:
logits_host = torch.from_numpy(logits)
previous_batch.sample_state.host.logits = logits_host
previous_batch.sample_state.device.logits = logits_host.to(
self.device_id)
else:
_ev = previous_batch.sample_state.sampler_event
_ctr = getattr(_ev, "counter", -1)
with torch.cuda.nvtx.range(
f"_sync_new_tokens_last_pp_{_ctr}"):
_ev.synchronize()
# Send tokens to next pp rank (w.r.t model forward direction)
# Second last rank does not need to since last rank has original decoded tokens
if not self.dist.is_second_last_pp_rank:
with torch.cuda.nvtx.range(
f"_send_new_tokens_{self.dist.pp_rank}_pr{self.dist.next_pp_rank}_mb{prev_microbatch_id}"):
if self.send_handles[prev_microbatch_id] is not None:
self.send_handles[prev_microbatch_id].wait()
self.send_handles[prev_microbatch_id] = None
prev_host = previous_batch.sample_state.host
needs_logits = (
(getattr(prev_host, "logits", None) is not None)
or (getattr(prev_host, "log_probs", None) is not None)
)
serialized_logits = (
prev_host.logits.numpy() if needs_logits else None
)
self.send_handles[
prev_microbatch_id] = self.dist.isend_object(
(
serialized_logits,
prev_host,
),
dest=self.dist.next_pp_rank,
tag=prev_microbatch_id)
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/py_executor.py around lines 801 to 843, make
the last-PP NVTX naming and logits-send logic robust: replace the direct access
to previous_batch.sample_state.sampler_event.counter with a safe fallback (e.g.,
check hasattr(event, "counter") and use it, otherwise use a stable identifier
such as previous_microbatch_id or omit the counter) so plain torch.cuda.Event
instances don't break; compute needs_logits using previous_batch.sample_state
(use previous_batch.sample_state.host.log_probs when checking for log-probs)
because you are sending the previous microbatch, and serialize/send
previous_batch.sample_state.host.logits (and convert to device from
previous_batch.sample_state.device if needed) instead of using the current
sample_state.host.logits; keep existing send/handle logic but ensure you
reference previous_batch consistently for both the condition and the serialized
payload.

# Stage 3: Finalize previous batch that finished tokens communication
# In last pp rank, stage 2 and 3 process different previous batches
@@ -1620,6 +1644,9 @@ def _handle_responses(self):
else:
new_active_requests.append(request)
self.active_requests.clear()
logger.info(
f"rank {self.dist.pp_rank} _handle_responses: {len(self.active_requests)} {len(new_active_requests)} {len(requests_to_terminate)}"
)
self.active_requests.extend(new_active_requests)
self._enqueue_responses(new_responses)
for request in requests_to_terminate:
21 changes: 19 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -2,6 +2,8 @@
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Literal
import traceback
from tensorrt_llm.logger import logger

import torch

@@ -36,6 +38,21 @@ class SampleStateTensors:
def values(self):
return vars(self).values()

class DebugEvent(torch.cuda.Event):
counter = 0

def __init__(self):
super().__init__()
self.counter = DebugEvent.counter
DebugEvent.counter += 1

def __del__(self):
logger.info(f"DebugEvent {self.counter} destroyed")

def synchronize(self):
logger.info(f"DebugEvent {self.counter} synchronized")
super().synchronize()

Comment on lines +41 to +55

🛠️ Refactor suggestion

DebugEvent destructor and logging may cause shutdown-time issues; make logging safe and less chatty

  • Logging in __del__ can run during interpreter shutdown when logging infra is partially torn down.
  • Use debug level to avoid flooding logs.
  • Wrap destructor logging in try/except to avoid noisy exceptions at teardown.
 class DebugEvent(torch.cuda.Event):
     counter = 0

     def __init__(self):
         super().__init__()
         self.counter = DebugEvent.counter
         DebugEvent.counter += 1

     def __del__(self):
-        logger.info(f"DebugEvent {self.counter} destroyed")
+        try:
+            logger.debug(f"DebugEvent {self.counter} destroyed")
+        except Exception:
+            # Best-effort logging; ignore errors during interpreter shutdown.
+            pass

     def synchronize(self):
-        logger.info(f"DebugEvent {self.counter} synchronized")
+        logger.debug(f"DebugEvent {self.counter} synchronized")
         super().synchronize()
📝 Committable suggestion


Suggested change
class DebugEvent(torch.cuda.Event):
counter = 0
def __init__(self):
super().__init__()
self.counter = DebugEvent.counter
DebugEvent.counter += 1
def __del__(self):
logger.info(f"DebugEvent {self.counter} destroyed")
def synchronize(self):
logger.info(f"DebugEvent {self.counter} synchronized")
super().synchronize()
class DebugEvent(torch.cuda.Event):
counter = 0
def __init__(self):
super().__init__()
self.counter = DebugEvent.counter
DebugEvent.counter += 1
def __del__(self):
try:
logger.debug(f"DebugEvent {self.counter} destroyed")
except Exception:
# Best-effort logging; ignore errors during interpreter shutdown.
pass
def synchronize(self):
logger.debug(f"DebugEvent {self.counter} synchronized")
super().synchronize()
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/sampler.py around lines 41 to 55, the
DebugEvent __del__ and synchronize logging can cause shutdown-time issues and
are too chatty; change both logging calls to logger.debug and wrap the __del__
logging in a try/except that suppresses all exceptions (silently ignore
failures) so nothing raises during interpreter teardown, i.e., perform
logger.debug inside a try block and except Exception: pass to ensure safe,
non-throwing, low-verbosity teardown logging.


@dataclass(kw_only=True)
class SampleState:
@@ -44,7 +61,7 @@ class SampleState:
device: SampleStateTensors = None
host: SampleStateTensors = None

sampler_event: torch.cuda.Event = None
sampler_event: DebugEvent = None

⚠️ Potential issue

Type now forces DebugEvent, but TRTLLMSampler still returns torch.cuda.Event

SampleState.sampler_event is now DebugEvent, yet TRTLLMSampler (below) constructs a plain torch.cuda.Event. Callers (py_executor) access .counter, which will raise AttributeError when used with TRTLLMSampler.

Align TRTLLMSampler to also use DebugEvent.

Outside the shown range (in this same file), update TRTLLMSampler.sample_async:

@@
-        sampler_event = torch.cuda.Event()
+        sampler_event = DebugEvent()
         sampler_event.record()
📝 Committable suggestion


Suggested change
sampler_event: DebugEvent = None
sampler_event = DebugEvent()
sampler_event.record()



class Sampler(ABC):
@@ -376,7 +393,7 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
gen_logits_host=gen_logits_host,
log_probs_host=log_probs_host)
new_tokens_host = new_tokens.to(device="cpu", non_blocking=True)
sampler_event = torch.cuda.Event()
sampler_event = DebugEvent()
sampler_event.record()
return SampleState(scheduled_requests=scheduled_requests,
device=SampleStateTensors(new_tokens=new_tokens),
15 changes: 14 additions & 1 deletion tests/integration/defs/disaggregated/test_disaggregated.py
@@ -156,10 +156,22 @@ def run_disaggregated_test(example_dir,
run_env = env.copy()
run_env["UCX_TLS"] = "^ib"

nsys_path = os.getenv("NSYS_PATH", "nsys")
nsys_file = os.getenv("NSYS_FILE", f"/tmp/disagg_test_{test_desc}")
nsys_cmd = [
"nsys",
"profile",
"--trace",
"cuda,cublas,nvtx",
"--output=" + nsys_file,
"--force-overwrite=true",
"--duration=120",
] if nsys_path and nsys_file else []

Comment on lines +159 to +170

⚠️ Potential issue

NSYS profiling is always enabled and ignores NSYS_PATH; fix gating and path usage

  • NSYS_PATH is defaulted to "nsys", making the condition truthy even when unset; profiling will always run.
  • The command uses the literal "nsys" instead of nsys_path.

Gate profiling strictly on both env vars being set, and use the provided NSYS binary. Optionally make duration configurable.

-    nsys_path = os.getenv("NSYS_PATH", "nsys")
-    nsys_file = os.getenv("NSYS_FILE", f"/tmp/disagg_test_{test_desc}")
-    nsys_cmd = [
-        "nsys",
-        "profile",
-        "--trace",
-        "cuda,cublas,nvtx",
-        "--output=" + nsys_file,
-        "--force-overwrite=true",
-        "--duration=120",
-    ] if nsys_path and nsys_file else []
+    nsys_path = os.getenv("NSYS_PATH")
+    nsys_file = os.getenv("NSYS_FILE")
+    nsys_duration = os.getenv("NSYS_DURATION", "120")
+    nsys_cmd = []
+    if nsys_path and nsys_file:
+        nsys_cmd = [
+            nsys_path,
+            "profile",
+            "--trace",
+            "cuda,cublas,nvtx",
+            f"--output={nsys_file}",
+            "--force-overwrite=true",
+            f"--duration={nsys_duration}",
+        ]
📝 Committable suggestion


Suggested change
nsys_path = os.getenv("NSYS_PATH", "nsys")
nsys_file = os.getenv("NSYS_FILE", f"/tmp/disagg_test_{test_desc}")
nsys_cmd = [
"nsys",
"profile",
"--trace",
"cuda,cublas,nvtx",
"--output=" + nsys_file,
"--force-overwrite=true",
"--duration=120",
] if nsys_path and nsys_file else []
nsys_path = os.getenv("NSYS_PATH")
nsys_file = os.getenv("NSYS_FILE")
nsys_duration = os.getenv("NSYS_DURATION", "120")
nsys_cmd = []
if nsys_path and nsys_file:
nsys_cmd = [
nsys_path,
"profile",
"--trace",
"cuda,cublas,nvtx",
f"--output={nsys_file}",
"--force-overwrite=true",
f"--duration={nsys_duration}",
]
🤖 Prompt for AI Agents
In tests/integration/defs/disaggregated/test_disaggregated.py around lines 159
to 170, NSYS profiling is incorrectly always enabled because NSYS_PATH is
defaulted to "nsys" and the command uses the literal "nsys"; change the logic to
only enable profiling when both NSYS_PATH and NSYS_FILE environment variables
are explicitly set (no default for NSYS_PATH), use the resolved nsys_path
variable in the command array instead of the string "nsys", and optionally read
a NSYS_DURATION env var (with a sensible default) to make the duration
configurable.

num_ranks, config_file = get_test_config(test_desc, example_dir,
os.path.dirname(__file__))

workers_cmd = [
workers_cmd = nsys_cmd + [
'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
config_file
Expand Down Expand Up @@ -254,6 +266,7 @@ def run_disaggregated_test(example_dir,
"The capital of Germany is Berlin",
"Asyncio is a Python library"
]
expected_strings = []
for expected_string in expected_strings:
if isinstance(expected_string, list):
# At least one of the strings in the list should be found in the content