   of the generation loop at the relevant slot.
   - Regardless, it performs a step.
   - It takes the sampled tokens and places them on a 'detokenizing_queue'.
- 7. Within the detokenizing thread:
+ 7. Within the detokenizing thread (prefill and generate separately):
   - Tokens are detokenized for every 'slot' in a given set of sampled tokens.
   - When an end condition is met, the 'slot' integer is returned to the
     respective generation queue.
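To make step 7 concrete, here is a minimal, self-contained sketch of that loop. The real thread consumes `engine_api.ResultTokens` and writes into each `ActiveRequest`'s return channel; the `(slot, token)` tuples, `emit_token` helper, sentinel, and EOS check below are simplified stand-ins, not JetStream APIs.

```python
import queue

def emit_token(slot: int, token: int) -> None:
  # Stand-in for pushing a detokenized token onto the request's return channel.
  print(f"slot {slot}: token {token}")

def detokenize_loop(backlog: queue.Queue, slots: queue.Queue,
                    eos_id: int = 2) -> None:
  """Drain sampled tokens; recycle a slot once its sequence ends."""
  while True:
    item = backlog.get(block=True)
    if item is None:  # sentinel (assumption): shut the thread down
      return
    slot, token = item
    emit_token(slot, token)
    if token == eos_id:  # end condition met: return the bare slot integer
      slots.put(slot, block=False)
```

The free-slot queue here plays the role of `_generate_slots[idx]`, the same queue the generate thread pops from when it inserts a new prefill.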
@@ -210,7 +210,8 @@ class Driver:
  # Stage 4
  # This can be a list because we can pass it as an arg to generate and
  # detokenize threads. It is a list of tokens to be detokenized.
- _detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
+ _prefill_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
+ _generate_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
  _generate_slots: list[queue.Queue[int]] = []
  _active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = []
@@ -270,11 +271,11 @@ def __init__(
  # one of the generate backlogs.
  # Interleaved Mode: Max size is 1 to increase the HBM utilization
  # during generate.
- # Disaggregated Mode: Max size is 4 to allow for 2 prefills to be enqueued
- # while 1 transfer is enqueued while 1 is being transferred.
+ # Disaggregated Mode: Max size is 16 to allow up to 16 prefill results
+ # to be enqueued while one is being transferred.
  # TODO: Make queue size configurable.
  self._transfer_backlogs = [
-     queue.Queue(1 if self._interleaved_mode else 4)
+     queue.Queue(1 if self._interleaved_mode else 16)
      for i in range(len(self._prefill_engines))
  ]
  if self._metrics_collector:
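The max size matters because `queue.Queue` blocks producers once full; that blocking is the only flow control between pipeline stages. A toy illustration of the disaggregated-mode setting (the stage names and payloads are made up, not JetStream APIs):

```python
import queue
import threading
import time

transfer_backlog: queue.Queue = queue.Queue(maxsize=16)  # disaggregated-mode size

def prefill_producer() -> None:
  for i in range(32):
    # Blocks once 16 items are in flight, pausing prefill until the
    # transfer side catches up.
    transfer_backlog.put(f"prefill-{i}", block=True)

def transfer_consumer() -> None:
  while True:
    item = transfer_backlog.get()
    if item is None:  # sentinel ends the demo
      return
    time.sleep(0.01)  # stand-in for a device-to-device transfer

t = threading.Thread(target=transfer_consumer)
t.start()
prefill_producer()
transfer_backlog.put(None)
t.join()
```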
@@ -302,10 +303,11 @@ def __init__(
          functools.partial(float, backlog.qsize())
      )
  # Stage 4
- # After generation, ActiveRequests are placed on the detokenization backlog
- # for tokens to be sent into each ActiveRequest's return channel.
- # We have one of these per generate engine to simplify the logic keeping
- # track of which generation engine to replace slots on.
+ # After prefill and generation, ActiveRequests are placed on the
+ # detokenization backlog for tokens to be sent into each ActiveRequest's
+ # return channel.
+ # We have one of these per prefill / generate engine to simplify
+ # the logic keeping track of which generation engine to replace slots on.
  # This is a queue of either - tuple[int, ActiveRequest] which represents our
  # active requests, or tuple[int, sample_tokens]. We combine these into one
  # queue because it allows us to be somewhat clever with how we do
@@ -320,7 +322,16 @@ def __init__(
  # the possibility of race conditions where a slot is made live before the
  # tokens are ready and it receives tokens from a different sequence,
  # or tokens detokenized before the relevant slot is live.
- self._detokenize_backlogs = [
+
+ self._prefill_detokenize_backlogs = [
+     # No need to set maxsize: the transfer queue already provides
+     # backpressure on the prefill workload (avoids overwhelming prefill).
+     queue.Queue()
+     for _ in self._prefill_engines
+ ]
+
+ self._generate_detokenize_backlogs = [
      # We don't let detokenization accumulate more than 8 steps to avoid
      # synchronization issues.
      queue.Queue(8)
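Since each backlog mixes `tuple[int, ActiveRequest]` entries (a slot going live) with `tuple[int, sample_tokens]` entries (tokens from a step), the consumer has to discriminate the payloads. One way is to branch on the payload's type, sketched below with a stand-in `ActiveRequest`; the real `_detokenize_thread` may discriminate differently.

```python
from dataclasses import dataclass
from typing import Union

@dataclass
class ActiveRequest:  # stand-in for the orchestrator's ActiveRequest
  prompt: str

live_requests: dict[int, ActiveRequest] = {}

def handle(item: tuple[int, Union[ActiveRequest, list[int]]]) -> None:
  tag, payload = item
  if isinstance(payload, ActiveRequest):
    # (slot, ActiveRequest): slot `tag` becomes live.
    live_requests[tag] = payload
  else:
    # (timestep, sampled_tokens): detokenize tokens from step `tag`.
    print(f"step {tag}: tokens {payload}")
```

Keeping both kinds in one FIFO enforces the ordering the comment above describes: a slot's `ActiveRequest` is always dequeued before any tokens sampled into that slot.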
@@ -376,13 +387,25 @@ def __init__(
      )
      for idx in range(len(self._generate_engines))
  ]
- self.detokenize_threads = [
+ self.prefill_detokenize_threads = [
      JetThread(
          target=functools.partial(
              self._detokenize_thread,
-             idx,
+             is_prefill=True,
+             idx=idx,
+         ),
+         name=f"prefill_detokenize-{idx}",
+     )
+     for idx in range(len(self._generate_engines))
+ ]
+ self.generate_detokenize_threads = [
+     JetThread(
+         target=functools.partial(
+             self._detokenize_thread,
+             is_prefill=False,
+             idx=idx,
          ),
-         name=f"detokenize-{idx}",
+         name=f"generate_detokenize-{idx}",
      )
      for idx in range(len(self._generate_engines))
  ]
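The two thread lists differ only in the `is_prefill` keyword frozen into the target: `functools.partial` turns `_detokenize_thread(is_prefill, idx)` into the zero-argument callable a thread target needs. The same pattern with a plain `threading.Thread` standing in for `JetThread`:

```python
import functools
import threading

def _detokenize_thread(is_prefill: bool, idx: int) -> None:
  kind = "prefill" if is_prefill else "generate"
  print(f"{kind}_detokenize-{idx} running")

threads = [
    threading.Thread(
        target=functools.partial(_detokenize_thread, is_prefill=True, idx=idx),
        name=f"prefill_detokenize-{idx}",
    )
    for idx in range(2)
]
for t in threads:
  t.start()
for t in threads:
  t.join()
```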
@@ -391,7 +414,8 @@ def __init__(
          self._prefill_threads,
          self._transfer_threads,
          self._generate_threads,
-         self.detokenize_threads,
+         self.prefill_detokenize_threads,
+         self.generate_detokenize_threads,
      )
  )
  self.live = True
@@ -410,7 +434,8 @@ def stop(self):
          [self._prefill_backlog],
          self._transfer_backlogs,
          self._generate_backlogs.values(),
-         self._detokenize_backlogs,
+         self._prefill_detokenize_backlogs,
+         self._generate_detokenize_backlogs,
      )
  )
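`stop()` chains every backlog family into one flat iterable so shutdown can treat them uniformly. A sketch of the pattern, assuming sentinel-based termination (the details of JetStream's actual `stop()` may differ):

```python
import itertools
import queue

def stop(queue_groups) -> None:
  """Drain every queue, then post a sentinel so blocked threads wake up."""
  for q in itertools.chain(*queue_groups):
    while True:
      try:
        q.get_nowait()  # discard anything still buffered
      except queue.Empty:
        break
    q.put(None)  # sentinel: consumers treat None as "exit"
```

Draining before posting the sentinel matters: a bounded queue that is still full would otherwise block the `put(None)` itself.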
@@ -523,7 +548,7 @@ def _prefill_thread(self, idx: int):
  # put first token to detokenize queue
  request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)
- my_detokenize_backlog = self._detokenize_backlogs[idx]
+ my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
  request.metadata.transfer_enqueue_time = time.perf_counter()
  my_detokenize_backlog.put(
      (first_token, request, request.metadata.prefill_dequeue_time),
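Stamping `transfer_enqueue_time` with `time.perf_counter()` right before the `put` lets the consumer side turn queue residency into a latency measurement. A self-contained sketch of that technique (the metric name and payload are assumptions):

```python
import queue
import time

backlog: queue.Queue = queue.Queue()

# Producer: record the enqueue time alongside the payload.
backlog.put(("first_token", time.perf_counter()))

# Consumer: the gap between now and the stamp is time spent queued.
payload, enqueued_at = backlog.get()
queue_latency_ms = (time.perf_counter() - enqueued_at) * 1e3
print(f"{payload} waited {queue_latency_ms:.2f} ms in the backlog")
```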
@@ -619,7 +644,7 @@ def _generate_thread(self, idx: int):
  generate_engine = self._generate_engines[idx]
  my_slots = self._generate_slots[idx]
  my_generate_backlog = self._generate_backlogs[idx]
- my_detokenize_backlog = self._detokenize_backlogs[idx]
+ my_detokenize_backlog = self._generate_detokenize_backlogs[idx]

  # Keep track of what step tokens were generated at.
  generate_timestep = 0
@@ -749,12 +774,17 @@ def _generate_thread(self, idx: int):
      )
      time_of_last_generate = time.time()

- def _detokenize_thread(self, idx: int):
+ def _detokenize_thread(self, is_prefill: bool, idx: int):
    """Detokenize sampled tokens and return them to the user."""
    # One of these per generate engine.
    # For all filled my_slots, pop the sampled token onto the relevant
    # request's return channel. If it is done, place it back onto free slots.
-   my_detokenize_backlog = self._detokenize_backlogs[idx]
+
+   if is_prefill:
+     my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
+   else:
+     my_detokenize_backlog = self._generate_detokenize_backlogs[idx]
+
    my_generate_engine = self._generate_engines[idx]
    my_slots = self._generate_slots[idx]