 import os
 import torch
 import logging
-from typing import Optional, Union, List
 import hashlib
+import warnings
+from typing import Optional, Union, List
 from transformers_neuronx import bucket
 from transformers_neuronx import utils
 from transformers_neuronx import module
@@ -39,7 +40,7 @@ def save(self, directory):
     def load(self, directory):
         assert self.serialization_enabled(), 'serialization is not enabled for this model'
         self._compiled_artifacts_directory = directory
-
+
     # top level api
     def compile(self, parallel_degree=None):
         kernels = self._get_all_kernels()
@@ -78,6 +79,13 @@ def enable_speculative_decoder(self,speculation_length:Optional[Union[List[int],
         for k in speculation_length:
             self.decoder_lm_head_for_speculation[k] = self.decoder_param_set.init_speculative_decoder(unroll=self.unroll, buckets=self.token_buckets, model_obj=self, n_active_tokens=k)
 
+    def enable_window_context_decoder(self, window_context_length: Optional[Union[List[int], int]], unroll):
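+        # Normalize a single window size to a list, bucket it, and compile one
+        # window-context decoder per bucket for chunked prompt encoding.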
+        if isinstance(window_context_length, int):
+            window_context_length = [window_context_length]
+        self.window_context_buckets = bucket.context_sizes(window_context_length, self.token_buckets)
+        for k in self.window_context_buckets:
+            self.decoder_lm_head_for_window_context[k] = self.decoder_param_set.init_window_context_decoder(unroll=unroll, buckets=self.token_buckets, model_obj=self, n_active_tokens=k)
+
     def is_compiled(self):
         # First check if the kernels have neffs already
         try:
@@ -141,7 +149,7 @@ def register_for_serialization(self, nbs_obj):
     def reset(self):
         self.decoder_lm_head.reset()
 
-    def context(self, hidden, cache_ids, start_ids, last_token_id, *rest, neuron_config=None):
+    def context(self, hidden, cache_ids, start_ids, last_token_id, *rest):
145153 """A helper to process context (prompt)
146154 1) if there is available context encoding model (infered from self.context_buckets)
147155 - when context_length >= estimate, slice the context up to estimate,
@@ -190,17 +198,45 @@ def context(self, hidden, cache_ids, start_ids, last_token_id, *rest, neuron_con
 
         if current == estimate:
             model = self.decoder_lm_head_for_context[estimate, batch_size]
-            logits = model(hidden_context, cache_context, start_ids, last_token_id, *rest, neuron_config=neuron_config)
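+            # When log-softmax score logging is enabled, the context model also
+            # returns the scores tensor alongside the logits.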
+            if self.neuron_config.log_softmax_scores:
+                logits, scores = model(hidden_context, cache_context, start_ids, last_token_id, *rest)
+            else:
+                logits = model(hidden_context, cache_context, start_ids, last_token_id, *rest)
+
+        # process the leftover context
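+        # Encode the remaining prompt in bucket-sized windows, falling back to
+        # token-by-token decoding when no suitable window bucket fits.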
+        while current < context_length - 1:
+            # find the optimal "window"
+            estimate = None
+            if hasattr(self, "window_context_buckets"):
+                estimate = bucket.find(self.window_context_buckets, context_length - current)
+
+            # when the leftover is smaller than the estimate, fall back to single-token generation
+            # TODO: can we pad?
+            if estimate is None or context_length - current < estimate:
+                for i in range(current, context_length):
+                    cache_ids = torch.as_tensor([i], dtype=torch.int32)
+                    hidden_slice = hidden[:, i:i+1].contiguous()
+                    logits = self.decoder_lm_head(hidden_slice, cache_ids, start_ids, last_token_id, *rest)
+                break
+
+            hidden_slice = hidden[:, current:current + estimate].contiguous()
+            cache_ids = torch.as_tensor([i for i in range(current, current + estimate)], dtype=torch.int32)
+            last_token_id = torch.as_tensor(estimate - 1)
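+            # last_token_id is window-relative: the index of the final active token in this chunk.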
+            if self.neuron_config.log_softmax_scores:
+                logits, scores = self.decoder_lm_head_for_window_context[estimate](hidden_slice, cache_ids, start_ids, last_token_id, *rest)
+            else:
+                logits = self.decoder_lm_head_for_window_context[estimate](hidden_slice, cache_ids, start_ids, last_token_id, *rest)
 
-        for i in range(current, context_length):
-            cache_ids = torch.as_tensor([i], dtype=torch.int32)
-            hidden_slice = hidden[:, i:i+1].contiguous()
-            logits = self.decoder_lm_head(hidden_slice, cache_ids, start_ids, last_token_id, *rest, neuron_config=neuron_config)
+            current += estimate
 
         if self.is_fid:
             logits[:] = float('-inf')
             logits[self.bos_token_id] = 1.0
 
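+        # Surface the collected log-softmax scores to the caller when enabled.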
+        if self.neuron_config.log_softmax_scores:
+            return logits, scores
         return logits
 
     def _prepare_for_par_ctx_rhs_padding(self, input_ids):
@@ -318,9 +354,9 @@ def _cast_logits(self, logits):
         logits_dtype = getattr(torch, self.neuron_config.cast_logits_dtype)
         return logits.to(logits_dtype)
 
-    def _context_dynamic_batching(self, hidden, *args, neuron_config=None):
-        # Taking HSB layout
-        _, context_length, input_batch_size = hidden.shape
+    def _context_dynamic_batching(self, hidden, *args):
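+        # The batch dimension is first in BSH layout and last in HSB layout.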
+        is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH
+        input_batch_size = hidden.shape[0] if is_bsh else hidden.shape[2]
         assert hasattr(self, "context_batch_sizes"), f"{type(self)} doesn't support dynamic batching."
 
         running_batch_size = self.context_batch_sizes[-1]
@@ -334,34 +370,37 @@ def _context_dynamic_batching(self, hidden, *args, neuron_config=None):
             # Assuming HSB layout
             start_idx = iter_id * running_batch_size
             end_idx = (iter_id + 1) * running_batch_size
-            hidden_per_batch = hidden[:, :, start_idx:end_idx]
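+            # Slice this iteration's sub-batch out of the hidden states according to the attention layout.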
+            if is_bsh:
+                hidden_per_batch = hidden[start_idx:end_idx, :, :]
+            else:
+                hidden_per_batch = hidden[:, :, start_idx:end_idx]
             cache_ids_per_batch = cache_ids[start_idx:end_idx, :]
             start_ids_per_batch = start_ids[start_idx:end_idx]
             last_token_id = cache_ids_per_batch.max()
             logits_per_batch = self.context(hidden_per_batch, cache_ids_per_batch,
-                                            start_ids_per_batch, last_token_id, neuron_config=neuron_config)
+                                            start_ids_per_batch, last_token_id)
             all_logits.append(logits_per_batch)
         logits = torch.cat(all_logits, dim=2)
     else:
         assert input_batch_size == running_batch_size, \
             f"input batch size ({input_batch_size}) not equal to running batch size ({running_batch_size})"
-            logits = self.context(hidden, *args, neuron_config=neuron_config)
+            logits = self.context(hidden, *args)
         return logits
 
-    def _forward(self, hidden, *args, neuron_config=None):
-        # Taking HSB layout
+    def _forward(self, hidden, *args):
         _, context_length, *_ = hidden.shape
-        if not self.neuron_config.on_device_embedding:
-            hidden = hidden.transpose(0, -1).contiguous()
 
         if context_length > 1:
             continuous_batching = self.neuron_config and self.neuron_config.continuous_batching
             if continuous_batching:
-                logits = self._context_dynamic_batching(hidden, *args, neuron_config=neuron_config)
+                logits = self._context_dynamic_batching(hidden, *args)
             else:
-                logits = self.context(hidden, *args, neuron_config=neuron_config)
+                logits = self.context(hidden, *args)
         else:
-            logits = self.decoder_lm_head(hidden, *args, neuron_config=neuron_config)
+            logits = self.decoder_lm_head(hidden, *args)
+
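+        # With on-device generation, return the device output directly; the
+        # host-side logit casting and vocab slicing below are skipped.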
+        if self.neuron_config.on_device_generation:
+            return logits
 
         logits = self._cast_logits(logits)
         logits = logits[:self.config.vocab_size, -1, :]
@@ -378,7 +417,7 @@ def pp_forward(self, *args, **kwargs):
         if self.neuron_config.rank_id == 0:
             broad_cast_objects = [args, kwargs]
             dist.broadcast_object_list(broad_cast_objects, src=0, device=torch.device("cpu"))
-            res = self.forward(*args, **kwargs)
+            res = self(*args, **kwargs)
             return res
         else:
             # if non-host, fall back to a for loop
@@ -394,17 +433,17 @@ def pp_forward(self, *args, **kwargs):
             # it is now naturally handled in forward call
             dist.broadcast_object_list(broad_cast_objects, src=0, device=torch.device("cpu"))
             args, kwargs = broad_cast_objects
-            self.forward(*args, **kwargs)
+            self(*args, **kwargs)
 
     def serialization_enabled(self):
         return getattr(self, 'nbs_objs', None) is not None
 
-    def profile(self, profile_dir):
+    def profile(self, profile_dir, ntff_count_limit):
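+        # ntff_count_limit is forwarded to each kernel, presumably to cap how
+        # many NTFF profile files a kernel may write.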
         kernels = self._get_all_kernels()
 
         for kernel in kernels:
             if isinstance(kernel, ParallelKernel):
-                kernel.profile(profile_dir)
+                kernel.profile(profile_dir, ntff_count_limit)
 
 # Base class for all "Serializable Objects"
 class NeuronBaseSerializer: