
Commit d462ca9

add separate prefill detokenization thread (#152)

Authored by: zhihaoshan-google (Zhihao Shan)
Co-authored-by: Zhihao Shan <[email protected]>

1 parent 15e3963 · commit d462ca9

File tree: 1 file changed, +49 -19 lines


jetstream/core/orchestrator.py

Lines changed: 49 additions & 19 deletions
@@ -38,7 +38,7 @@
       of the generation loop at the relevant slot.
     - Regardless, it performs a step.
   - It takes the sampled tokens, and places them on a 'detokenizing_queue'.
-7. Within the detokenizing thread:
+7. Within the detokenizing thread (Prefill and Generate separately):
   - Tokens are detokenized for every 'slot' in a given set of sampled tokens.
   - When an end condition is met, the 'slot' integer is returned to the
     respective generation queue.
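
Note: the split this hunk documents can be pictured as two independent queue/worker pairs, one fed by prefill and one by generate, so slow per-step detokenization can never delay first-token delivery. A minimal sketch of that pattern, using illustrative names and plain threading.Thread rather than the orchestrator's JetThread:

    import queue
    import threading

    # One detokenize queue per role; sizes mirror the hunks further down.
    prefill_detok_q: queue.Queue = queue.Queue()    # unbounded
    generate_detok_q: queue.Queue = queue.Queue(8)  # bounded

    def detokenize_worker(q: queue.Queue, role: str) -> None:
      while True:
        item = q.get()
        if item is None:  # sentinel: shut down
          return
        print(f"{role} detokenized: {item}")

    threads = [
        threading.Thread(target=detokenize_worker, args=(prefill_detok_q, "prefill")),
        threading.Thread(target=detokenize_worker, args=(generate_detok_q, "generate")),
    ]
    for t in threads:
      t.start()
    prefill_detok_q.put(("first_token", "request"))
    generate_detok_q.put((0, "sampled_tokens"))
    for q in (prefill_detok_q, generate_detok_q):
      q.put(None)
    for t in threads:
      t.join()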
@@ -210,7 +210,8 @@ class Driver:
   # Stage 4
   # This can be a list because we can pass it as an arg to generate and
   # detokenize threads. It is a list of tokens to be detokenized.
-  _detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
+  _prefill_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
+  _generate_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
   _generate_slots: list[queue.Queue[int]] = []
   _active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = []
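
Note: each attribute above is a list holding one queue per engine, indexed by engine idx; queue.Queue accepts a type subscript at runtime on Python 3.9+. A tiny sketch with a stand-in token type (engine_api.ResultTokens is not imported here):

    import queue

    class ResultTokensStub:  # stand-in for engine_api.ResultTokens
      ...

    num_prefill_engines = 2
    prefill_detokenize_backlogs: list[queue.Queue[ResultTokensStub]] = [
        queue.Queue() for _ in range(num_prefill_engines)
    ]
    # The thread serving engine 0 reads only prefill_detokenize_backlogs[0].
    prefill_detokenize_backlogs[0].put(ResultTokensStub())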

@@ -270,11 +271,11 @@ def __init__(
     # one of the generate backlogs.
     # Interleaved Mode: Max size is 1 to increase the HBM utilization
     # during generate.
-    # Disaggregated Mode: Max size is 4 to allow for 2 prefills to be enqueued
-    # while 1 transfer is enqueued while 1 is being transferred.
+    # Disaggregated Mode: Max size is 16 to allow a total of 16 prefills to
+    # be enqueued while 1 is being transferred.
     # TODO: Make queue size configurable.
     self._transfer_backlogs = [
-        queue.Queue(1 if self._interleaved_mode else 4)
+        queue.Queue(1 if self._interleaved_mode else 16)
         for i in range(len(self._prefill_engines))
     ]
     if self._metrics_collector:
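
Note: the sizing comment relies on queue.Queue's built-in backpressure: put() blocks once maxsize items are waiting, which is what throttles prefill against the transfer stage. A small demo of that mechanism with toy sizes, not JetStream code:

    import queue
    import threading
    import time

    transfer_q: queue.Queue = queue.Queue(maxsize=2)

    def slow_consumer() -> None:
      while True:
        transfer_q.get()
        time.sleep(0.1)  # pretend to transfer a KV cache
        transfer_q.task_done()

    threading.Thread(target=slow_consumer, daemon=True).start()

    for i in range(6):
      t0 = time.perf_counter()
      transfer_q.put(i)  # blocks whenever 2 items are already enqueued
      print(f"put {i} after {time.perf_counter() - t0:.2f}s")
    transfer_q.join()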
@@ -302,10 +303,11 @@ def __init__(
           functools.partial(float, backlog.qsize())
       )
     # Stage 4
-    # After generation, ActiveRequests are placed on the detokenization backlog
-    # for tokens to be sent into each ActiveRequest's return channel.
-    # We have one of these per generate engine to simplify the logic keeping
-    # track of which generation engine to replace slots on.
+    # After prefill and generation, ActiveRequests are placed on the
+    # detokenization backlog for tokens to be sent into each ActiveRequest's
+    # return channel.
+    # We have one of these per prefill / generate engine to simplify
+    # the logic keeping track of which generation engine to replace slots on.
     # This is a queue of either - tuple[int, ActiveRequest] which represents our
     # active requests, or tuple[int, sample_tokens]. We combine these into one
     # queue because it allows us to be somewhat clever with how we do
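
Note: the point of combining both event kinds in one queue is that FIFO ordering totally orders them, so a slot's ActiveRequest is always dequeued before any tokens sampled for that slot. A sketch with stand-in types, not the orchestrator's real payloads:

    import queue

    class ActiveRequest:  # stand-in
      pass

    detok_q: queue.Queue = queue.Queue()
    live_requests: dict[int, ActiveRequest] = {}

    detok_q.put((0, ActiveRequest()))     # slot 0 goes live first...
    detok_q.put((0, ["tok_a", "tok_b"]))  # ...then its tokens arrive, in order

    while not detok_q.empty():
      slot, payload = detok_q.get()
      if isinstance(payload, ActiveRequest):
        live_requests[slot] = payload  # slot made live
      else:
        assert slot in live_requests   # tokens can never beat the request
        print(f"slot {slot}: {payload}")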
@@ -320,7 +322,16 @@ def __init__(
     # the possibility of race conditions where a slot is made live before the
     # tokens are ready and it receives tokens from a different sequence,
     # or tokens detokenized before the relevant slot is live.
-    self._detokenize_backlogs = [
+
+    self._prefill_detokenize_backlogs = [
+        # No need to set maxsize, as the transfer queue provides the
+        # backpressure to the prefill workload
+        # (to avoid overwhelming prefill).
+        queue.Queue()
+        for _ in self._prefill_engines
+    ]
+
+    self._generate_detokenize_backlogs = [
         # We don't let detokenization accumulate more than 8 steps to avoid
         # synchronization issues.
         queue.Queue(8)
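
Note: the two backlogs get opposite capacity policies. queue.Queue() (maxsize=0) never blocks put(), so the prefill side leans on the bounded transfer queue upstream, while the generate side caps itself at 8 steps. A quick contrast; put_nowait is used here only to make the cap visible without blocking:

    import queue

    unbounded = queue.Queue()  # prefill detokenize backlog: put() never blocks
    bounded = queue.Queue(8)   # generate detokenize backlog: at most 8 steps

    for step in range(10):
      unbounded.put(step)  # always succeeds
      try:
        bounded.put_nowait(step)
      except queue.Full:
        print(f"step {step}: generate backlog full; the generate thread would block")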
@@ -376,13 +387,25 @@ def __init__(
         )
         for idx in range(len(self._generate_engines))
     ]
-    self.detokenize_threads = [
+    self.prefill_detokenize_threads = [
         JetThread(
             target=functools.partial(
                 self._detokenize_thread,
-                idx,
+                is_prefill=True,
+                idx=idx,
+            ),
+            name=f"prefill_detokenize-{idx}",
+        )
+        for idx in range(len(self._generate_engines))
+    ]
+    self.generate_detokenize_threads = [
+        JetThread(
+            target=functools.partial(
+                self._detokenize_thread,
+                is_prefill=False,
+                idx=idx,
             ),
-            name=f"detokenize-{idx}",
+            name=f"generate_detokenize-{idx}",
         )
         for idx in range(len(self._generate_engines))
     ]
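
Note: both thread lists bind their arguments ahead of time with functools.partial, so the same _detokenize_thread body serves both roles. The same idiom with plain threading.Thread (JetThread is JetStream's Thread wrapper and is not used here):

    import functools
    import threading

    def detokenize_loop(is_prefill: bool, idx: int) -> None:
      kind = "prefill" if is_prefill else "generate"
      print(f"{kind} detokenizer for engine {idx} running")

    threads = [
        threading.Thread(
            target=functools.partial(detokenize_loop, is_prefill=True, idx=idx),
            name=f"prefill_detokenize-{idx}",
        )
        for idx in range(2)
    ]
    for t in threads:
      t.start()
      t.join()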
@@ -391,7 +414,8 @@ def __init__(
             self._prefill_threads,
             self._transfer_threads,
             self._generate_threads,
-            self.detokenize_threads,
+            self.prefill_detokenize_threads,
+            self.generate_detokenize_threads,
         )
     )
     self.live = True
@@ -410,7 +434,8 @@ def stop(self):
             [self._prefill_backlog],
             self._transfer_backlogs,
             self._generate_backlogs.values(),
-            self._detokenize_backlogs,
+            self._prefill_detokenize_backlogs,
+            self._generate_detokenize_backlogs,
         )
     )
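
Note: stop() now has to walk both detokenize backlog lists. A common way to release workers blocked on Queue.get() is to drain each queue and enqueue a sentinel; the sketch below assumes a None sentinel and toy queues, which may differ from the orchestrator's exact shutdown logic:

    import itertools
    import queue

    prefill_backlog: queue.Queue = queue.Queue()
    transfer_backlogs = [queue.Queue(16)]
    detokenize_backlogs = [queue.Queue(), queue.Queue(8)]

    for backlog in itertools.chain(
        [prefill_backlog], transfer_backlogs, detokenize_backlogs
    ):
      while True:  # drain anything still queued
        try:
          backlog.get_nowait()
        except queue.Empty:
          break
      backlog.put(None)  # hypothetical sentinel: wakes a worker blocked on get()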

@@ -523,7 +548,7 @@ def _prefill_thread(self, idx: int):
 
       # put first token to detokenize queue
       request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)
-      my_detokenize_backlog = self._detokenize_backlogs[idx]
+      my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
      request.metadata.transfer_enqueue_time = time.perf_counter()
      my_detokenize_backlog.put(
          (first_token, request, request.metadata.prefill_dequeue_time),
@@ -619,7 +644,7 @@ def _generate_thread(self, idx: int):
     generate_engine = self._generate_engines[idx]
     my_slots = self._generate_slots[idx]
     my_generate_backlog = self._generate_backlogs[idx]
-    my_detokenize_backlog = self._detokenize_backlogs[idx]
+    my_detokenize_backlog = self._generate_detokenize_backlogs[idx]
 
     # Keep track of what step tokens were generated at.
     generate_timestep = 0
@@ -749,12 +774,17 @@ def _generate_thread(self, idx: int):
       )
       time_of_last_generate = time.time()
 
-  def _detokenize_thread(self, idx: int):
+  def _detokenize_thread(self, is_prefill: bool, idx: int):
     """Detokenize sampled tokens and returns them to the user."""
     # One of these per generate engine.
     # For all filled my_slots, pop the sampled token onto the relevant
     # requests return channel. If it done, place it back onto free slots.
-    my_detokenize_backlog = self._detokenize_backlogs[idx]
+
+    if is_prefill:
+      my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
+    else:
+      my_detokenize_backlog = self._generate_detokenize_backlogs[idx]
+
     my_generate_engine = self._generate_engines[idx]
     my_slots = self._generate_slots[idx]
