@@ -1222,7 +1222,7 @@ def _prepare_tp_inputs(
         num_accepted_draft_tokens = []  # per request
         # if using tree decoding, we need to store the request type and accepted path for each request,
         # which will be used to update the hidden_states_read_indices.
-        request_accepted_path = {} # per request
+        request_accepted_path = {}  # per request
 
         for request in scheduled_requests.context_requests:
             request_ids.append(request.py_request_id)
@@ -1237,7 +1237,9 @@ def _prepare_tp_inputs(
             gather_ids.append(len(input_ids) - 1)
             sequence_lengths.append(len(prompt_tokens))
             num_accepted_draft_tokens.append(len(prompt_tokens) - 1)
-            request_accepted_path[request.py_request_id] = request.py_num_accepted_draft_tokens_indices
+            request_accepted_path[
+                request.
+                py_request_id] = request.py_num_accepted_draft_tokens_indices
             prompt_lengths.append(len(prompt_tokens))
             past_seen_token_num = begin_compute
             num_cached_tokens_per_seq.append(past_seen_token_num)
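
The two hunks above only re-wrap the `request_accepted_path` bookkeeping to fit the line-length limit; behavior is unchanged. The dict maps each request id to the indices of its accepted draft-token path. A minimal runnable sketch of that mapping (the `SimpleRequest` stand-in and the literal values are illustrative, not from this PR):

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SimpleRequest:
    """Hypothetical stand-in for the request fields this diff touches."""
    py_request_id: int
    # Indices identifying which drafted tokens the target model accepted.
    py_num_accepted_draft_tokens_indices: List[int] = field(
        default_factory=list)


request_accepted_path: Dict[int, List[int]] = {}  # per request, as in the diff

for request in [SimpleRequest(7, [0, 2, 5]), SimpleRequest(9, [])]:
    request_accepted_path[
        request.py_request_id] = request.py_num_accepted_draft_tokens_indices

assert request_accepted_path == {7: [0, 2, 5], 9: []}
```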
@@ -1323,7 +1325,9 @@ def _prepare_tp_inputs(
         previous_pos_indices = []
         for request in extend_requests:
             request_ids.append(request.py_request_id)
-            request_accepted_path[request.py_request_id] = request.py_num_accepted_draft_tokens_indices
+            request_accepted_path[
+                request.
+                py_request_id] = request.py_num_accepted_draft_tokens_indices
             # the request has no previous tensor:
             # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
             # (2) a dummy request; or
@@ -1359,13 +1363,15 @@ def _prepare_tp_inputs(
                 assert spec_tree_manager is not None
                 assert num_draft_tokens == spec_tree_manager.max_total_draft_tokens
                 position_ids.extend(
-                    past_seen_token_num + spec_tree_manager.spec_dec_position_offsets[0]  # [max_total_draft_tokens + 1]
+                    past_seen_token_num +
+                    spec_tree_manager.spec_dec_position_offsets[
+                        0]  # [max_total_draft_tokens + 1]
                 )
             else:
                 position_ids.extend(
                     list(
-                        range(past_seen_token_num, past_seen_token_num + 1 +
-                              num_draft_tokens)))
+                        range(past_seen_token_num,
+                              past_seen_token_num + 1 + num_draft_tokens)))
             num_cached_tokens_per_seq.append(past_seen_token_num)
             request.cached_tokens = num_cached_tokens_per_seq[-1]
             # update batch index
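
Both branches above compute the position ids appended for a request's draft tokens; only the layout differs. With tree decoding, every drafted node takes the position `past_seen_token_num + depth`, so siblings share a position; with linear (chain) drafting, positions are simply consecutive. A sketch of the arithmetic (the offset tensor is illustrative, not read from a real `SpecTreeManager`):

```python
import torch

past_seen_token_num = 100
num_draft_tokens = 6  # == max_total_draft_tokens in the tree branch

# Tree decoding: per-node depth offsets, shape [max_total_draft_tokens + 1].
# Here: a root, two children at depth 1, three nodes at depth 2, one at 3.
spec_dec_position_offsets = torch.tensor([0, 1, 1, 2, 2, 2, 3])
tree_position_ids = past_seen_token_num + spec_dec_position_offsets
# tensor([100, 101, 101, 102, 102, 102, 103]) -- siblings share a position

# Chain decoding: one new position per drafted token.
linear_position_ids = list(
    range(past_seen_token_num, past_seen_token_num + 1 + num_draft_tokens))
# [100, 101, 102, 103, 104, 105, 106]
```

The next hunk applies the same re-wrapping to the overlap-scheduler path, where the linear branch uses `self.runtime_draft_len` instead of `num_draft_tokens`.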
@@ -1390,12 +1396,15 @@ def _prepare_tp_inputs(
                 assert spec_tree_manager is not None
                 assert num_draft_tokens == spec_tree_manager.max_total_draft_tokens
                 position_ids.extend(
-                    past_seen_token_num + spec_tree_manager.spec_dec_position_offsets[0]  # [max_total_draft_tokens + 1]
+                    past_seen_token_num +
+                    spec_tree_manager.spec_dec_position_offsets[
+                        0]  # [max_total_draft_tokens + 1]
                 )
             else:
                 position_ids.extend(
                     list(
-                        range(past_seen_token_num, past_seen_token_num + 1 +
+                        range(
+                            past_seen_token_num, past_seen_token_num + 1 +
                             self.runtime_draft_len)))
             # previous tensor
             previous_batch_indices.append(previous_batch_idx)
@@ -1433,7 +1442,9 @@ def _prepare_tp_inputs(
             sequence_lengths.append(1 + self.original_max_draft_len)
             num_accepted_draft_tokens.append(
                 request.py_num_accepted_draft_tokens)
-            request_accepted_path[request.py_request_id] = request.py_num_accepted_draft_tokens_indices
+            request_accepted_path[
+                request.
+                py_request_id] = request.py_num_accepted_draft_tokens_indices
             prompt_lengths.append(request.py_prompt_len)
             past_seen_token_num = begin_compute
             num_cached_tokens_per_seq.append(past_seen_token_num)
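
Per the comment in the first hunk, the stored accepted paths later update `hidden_states_read_indices`. That update is not shown in this PR; the sketch below is one hypothetical way such gather indices could be derived, assuming each request owns a contiguous block of `max_total_draft_tokens + 1` hidden-state rows:

```python
import torch
from typing import Dict, List


def build_read_indices(accepted_path: Dict[int, List[int]],
                       request_order: List[int],
                       block_size: int) -> torch.Tensor:
    """Hypothetical helper: flatten per-request accepted-path indices into
    global row indices for a gather over packed hidden states."""
    read_indices = []
    for slot, request_id in enumerate(request_order):
        base = slot * block_size
        # Row 0 of each block holds the root (the last verified token).
        for node_idx in [0] + accepted_path[request_id]:
            read_indices.append(base + node_idx)
    return torch.tensor(read_indices, dtype=torch.long)


indices = build_read_indices({7: [2, 5], 9: []}, [7, 9], block_size=7)
# tensor([0, 2, 5, 7]): rows 0/2/5 of request 7's block, row 0 of request 9's
```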
@@ -2241,15 +2252,14 @@ def _get_lora_params_from_requests(self,
         return lora_params
 
     @nvtx_range("_prepare_inputs")
-    def _prepare_inputs(
-            self,
-            scheduled_requests: ScheduledRequests,
-            kv_cache_manager: KVCacheManager,
-            attn_metadata: AttentionMetadata,
-            spec_metadata: Optional[SpecMetadata] = None,
-            new_tensors_device: Optional[SampleStateTensors] = None,
-            cache_indirection_buffer: Optional[torch.Tensor] = None,
-            resource_manager: Optional[ResourceManager] = None):
+    def _prepare_inputs(self,
+                        scheduled_requests: ScheduledRequests,
+                        kv_cache_manager: KVCacheManager,
+                        attn_metadata: AttentionMetadata,
+                        spec_metadata: Optional[SpecMetadata] = None,
+                        new_tensors_device: Optional[SampleStateTensors] = None,
+                        cache_indirection_buffer: Optional[torch.Tensor] = None,
+                        resource_manager: Optional[ResourceManager] = None):
         if self.mapping is not None and 'cp_type' in self.mapping.cp_config:
             cp_type = self.mapping.cp_config['cp_type']
             if CpType.STAR == cp_type:
@@ -2297,7 +2307,8 @@ def forward(
                 self.model_is_wrapped, spec_metadata.is_spec_dec_tree)
             attn_metadata.update_spec_dec_param(
                 is_spec_dec_mode, spec_metadata, spec_tree_manager,
-                self.original_max_draft_len, self.original_max_total_draft_tokens)
+                self.original_max_draft_len,
+                self.original_max_total_draft_tokens)
         else:
             spec_resource_manager = None
             spec_metadata = None
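
The call above passes both `self.original_max_draft_len` and `self.original_max_total_draft_tokens`; for tree drafting these differ, since the draft length is the tree depth while the total draft tokens is the node count. A small illustration (the tree shape is made up):

```python
# Hypothetical draft tree as child lists; the root is the last verified token.
tree = {"root": ["a", "b"], "a": ["c", "d"], "b": [], "c": [], "d": []}


def depth(node: str) -> int:
    return 0 if not tree[node] else 1 + max(depth(c) for c in tree[node])


max_draft_len = depth("root")           # 2: longest root-to-leaf path
max_total_draft_tokens = len(tree) - 1  # 4: every drafted node except the root
```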