@@ -87,13 +87,13 @@ def get_draft_tokens(
8787 self ,
8888 prefix : list [int ],
8989 request_id : int ,
90- end_id : int ,
90+ padding_id : int ,
9191 max_sequence_length : int ,
9292 ):
9393 prefix_len = len (prefix )
9494 max_draft_token_length_this_step = max_sequence_length - 1 - prefix_len
9595 if max_draft_token_length_this_step <= 0 : # No draft token is needed if the prefix is long enough
96- return [end_id ]
96+ return [padding_id ]
9797 if request_id not in self .start_index : # Extend start_index and pool for a new request
9898 self .start_index [request_id ] = 0
9999 if not self .is_public_pool :
@@ -126,7 +126,7 @@ def get_draft_tokens(
126126 pool [pattern ].add (new_match )
127127
128128 # Find match
129- draft_tokens = [end_id ] # fallback value
129+ draft_tokens = [padding_id ] # fallback value
130130 for size in range (min (self .max_matching_ngram_size , prefix_len - 1 ), 0 ,
131131 - 1 ):
132132 pattern = tuple (prefix [- size :])
@@ -194,11 +194,12 @@ def prepare_draft_tokens(
194194 draft_tokens = self .spec_resource_manager .get_draft_tokens (
195195 prefix ,
196196 request .request_id ,
197- request .py_end_id ,
198- request .py_orig_prompt_len + request .py_max_new_tokens ,
197+ padding_id = 0 ,
198+ max_sequence_length = request .py_orig_prompt_len +
199+ request .py_max_new_tokens ,
199200 )
200201 # Pad length to `self.max_draft_len`
201202 if len (draft_tokens ) > 0 :
202203 pad_length = self .max_draft_len - len (draft_tokens )
203- draft_tokens .extend ([request . py_end_id ] * pad_length )
204+ draft_tokens .extend ([0 ] * pad_length )
204205 request .py_draft_tokens = draft_tokens
0 commit comments