@@ -391,6 +391,58 @@ def append_eagle3(tokens: torch.Tensor, model_outputs):
         d2t = model_outputs["d2t"][tokens]
         tokens += d2t
 
+    @staticmethod
+    def _apply_embedding_bias(
+            logits: torch.Tensor,
+            requests: list[LlmRequest],
+            steps_per_request: list[int] = None) -> torch.Tensor:
+        """Apply embedding bias (aka logit bias) to logits.
+        If steps_per_request is None, assumes 1 step per request (non-batched path).
+        """
+        # Collect biases and their associated data
+        bias_list = []
+        bias_data = []  # Either indices (fast path) or steps (batched path)
+
+        for i, req in enumerate(requests):
+            bias = req._py_embedding_bias_1d
+            if bias is not None:
+                bias_list.append(bias)
+                bias_data.append(i if steps_per_request is
+                                 None else steps_per_request[i])
+
+        if not bias_list:
+            return logits
+
+        bias_tensor = torch.stack(bias_list).to(logits.device,
+                                                non_blocking=True)
+        logits = logits.clone()
+
+        if steps_per_request is None:
+            # Fast path: direct indexing
+            indices = torch.tensor(bias_data, device=logits.device)
+            logits[indices] += bias_tensor
+        else:
+            # Batched path: expand biases and use boolean mask
+            expanded_biases = torch.repeat_interleave(bias_tensor,
+                                                      torch.tensor(
+                                                          bias_data,
+                                                          device=logits.device),
+                                                      dim=0)
+
+            mask = torch.zeros(sum(steps_per_request),
+                               dtype=torch.bool,
+                               device=logits.device)
+            offset = 0
+            for i, req in enumerate(requests):
+                steps = steps_per_request[i]
+                if req._py_embedding_bias_1d is not None:
+                    mask[offset:offset + steps] = True
+                offset += steps
+
+            logits[mask] += expanded_biases
+
+        return logits
+
     def _process_requests(self,
                           requests: list[LlmRequest],
                           model_outputs: dict[str, torch.Tensor],
@@ -411,6 +463,7 @@ def _process_requests(self,
 
         if fast_path:
             logits = raw_logits[:len(requests)]
+            logits = self._apply_embedding_bias(logits, requests)
             next_tokens = torch.argmax(logits, dim=-1)
             self.append_eagle3(next_tokens, model_outputs)
             int_next_tokens = next_tokens.to(torch.int, non_blocking=True)
@@ -430,17 +483,29 @@ def _process_requests(self,
 
         if batched_strategy is not None:
             logits = raw_logits[:sum_steps]
+            # Collect steps per request for batched strategy
+            steps_per_request = [
+                1 + len(req.py_draft_tokens) for req in requests
+            ]
+            logits = self._apply_embedding_bias(logits, requests,
+                                                steps_per_request)
             batched_next_tokens, batched_softmax = sample(
                 batched_strategy, logits)
             self.append_eagle3(batched_next_tokens, model_outputs)
 
         offset = 0
-        for strategy, slot, steps in zip(strategies, seq_slots, num_steps):
+        for i, (strategy, slot,
+                steps) in enumerate(zip(strategies, seq_slots, num_steps)):
             input_slice = slice(offset, offset + steps)
             logits = raw_logits[input_slice]
+
+            req = requests[i]
+
             if batched_next_tokens is None:
+                logits = self._apply_embedding_bias(logits, [req])
                 next_tokens, softmax = sample(strategy, logits)
             else:
+                # Batched processing already applied bias, just use the results
                 next_tokens = batched_next_tokens[input_slice]
                 softmax = batched_softmax[input_slice]
             current_slice = slice(0, steps), slot, beam
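For context, a minimal standalone sketch of the two paths that _apply_embedding_bias distinguishes: the fast path adds each request's bias to a single logits row selected by index, while the batched path repeats each bias over that request's decoding steps and applies it through a boolean row mask. The toy vocab size, step counts, and the bias_by_request dict below are illustrative assumptions using plain tensors, not the PR's LlmRequest API.

# Toy sketch: requests 0 and 2 carry a bias, request 1 does not.
import torch

vocab = 8
bias_by_request = {0: torch.full((vocab,), 2.0), 2: torch.full((vocab,), -1.0)}
bias_tensor = torch.stack([bias_by_request[i] for i in sorted(bias_by_request)])

# Fast path: one logits row per request, biased rows selected by index.
logits = torch.zeros(3, vocab)
indices = torch.tensor(sorted(bias_by_request))
fast = logits.clone()
fast[indices] += bias_tensor  # rows 0 and 2 shifted, row 1 untouched

# Batched path: each request owns 1 + len(draft_tokens) consecutive rows.
steps_per_request = [2, 1, 3]  # request 0 -> rows 0-1, request 1 -> row 2, request 2 -> rows 3-5
logits = torch.zeros(sum(steps_per_request), vocab)
repeats = torch.tensor([steps_per_request[i] for i in sorted(bias_by_request)])
expanded = torch.repeat_interleave(bias_tensor, repeats, dim=0)
mask = torch.zeros(sum(steps_per_request), dtype=torch.bool)
offset = 0
for i, steps in enumerate(steps_per_request):
    if i in bias_by_request:
        mask[offset:offset + steps] = True
    offset += steps
batched = logits.clone()
batched[mask] += expanded  # rows 0-1 and 3-5 shifted, row 2 untouched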