-from typing import Union
+from typing import Any, Dict, Union
 
 import torch
 from neuronx_distributed.operators.argmax import argmax as nxd_argmax
 from neuronx_distributed.operators.topk import topk as nxd_topk
 from neuronx_distributed.parallel_layers import parallel_state
 from torch_neuronx.xla_impl.ops import xla_hlo_call
 
-from neuronx_distributed_inference.models.config import NeuronConfig
+from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig
 
 
 @xla_hlo_call
@@ -18,6 +18,62 @@ def rand_like(tensor):
     return dtype[shape].Rng(minimum, maximum, distribution=1)  # Uniform distribution
 
 
+def validate_sampling_params(
+    params: torch.Tensor, on_device_sampling_config: Union[Dict[str, Any], OnDeviceSamplingConfig]
+) -> None:
+    """
+    Validates sampling parameters for language models.
+
+    Args:
+    params (torch.Tensor): Tensor of shape (batch_size, 3) containing sampling parameters
+                           in the order: top-k, top-p, temperature.
+    on_device_sampling_config (Union[Dict[str, Any], OnDeviceSamplingConfig]): On-device sampling
+                           config (or equivalent dict) providing the global_topk upper bound for top-k.
+
+    Raises:
+    ValueError: If any of the parameters are invalid.
+    """
+    if params.shape[1] != 3:
+        raise ValueError(f"Expected tensor of shape (batch_size, 3), but got {params.shape}")
+
+    # autocast params tensor to float32
+    params = params.to(torch.float32)
+
+    # Unpack parameters
+    top_k, top_p, temperature = params[:, 0], params[:, 1], params[:, 2]
+
+    if isinstance(on_device_sampling_config, OnDeviceSamplingConfig):
+        global_top_k = on_device_sampling_config.global_topk
+    else:
+        global_top_k = on_device_sampling_config["global_topk"]
+
+    # Validate top-k value range
+    valid_top_k = (top_k == -1) | ((top_k > 0) & (top_k <= global_top_k))
+    if not torch.all(valid_top_k):
+        raise ValueError(
+            f"Invalid top-k values found. top-k must be -1, or greater than 0 and less than or equal to {global_top_k=}. Found {top_k=}"
+        )
+
+    # checks if top-k values can be represented as integers
+    if not torch.equal(top_k, top_k.floor()):
+        raise ValueError(
+            f"Invalid top-k values found. top-k values must be representable as integers, but decimal parts were found. Found {top_k=}"
+        )
+
+    # Validate top-p
+    valid_top_p = (top_p > 0.0) & (top_p <= 1.0)
+    if not torch.all(valid_top_p):
+        raise ValueError(
+            f"Invalid top-p values found. top-p must be in the range (0.0, 1.0]. Found {top_p=}"
+        )
+
+    # Validate temperature
+    valid_temp = temperature > 0.0
+    if not torch.all(valid_temp):
+        raise ValueError(
+            f"Invalid temperature values found. Temperature must be strictly greater than 0.0. Found {temperature=}"
+        )
+
+
 def prepare_sampling_params(batch_size, top_k=[1], top_p=[1.0], temperature=[1.0]):
     top_k = prepare_tensor(top_k)
     top_p = prepare_tensor(top_p)
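
For reference, a minimal sketch of how the new validator might be exercised, assuming validate_sampling_params from this module is in scope and using a hand-built (batch_size, 3) params tensor with the plain-dict form of the sampling config; the global_topk limit of 256 and the sample rows are illustrative only, not values taken from the change.

import torch

# Illustrative (batch_size, 3) tensor: columns are top-k, top-p, temperature.
params = torch.tensor(
    [
        [50.0, 0.9, 0.7],   # top-k of 50 with nucleus filtering
        [-1.0, 1.0, 1.0],   # top-k disabled (-1), full distribution
    ]
)

# The validator accepts either an OnDeviceSamplingConfig or a dict; the dict
# form is used here with an assumed global_topk limit of 256.
validate_sampling_params(params, {"global_topk": 256})  # passes silently

# Out-of-range values raise ValueError, e.g. a top-p above 1.0:
bad_params = torch.tensor([[50.0, 1.5, 0.7]])
try:
    validate_sampling_params(bad_params, {"global_topk": 256})
except ValueError as err:
    print(err)  # "Invalid top-p values found. ..."

Because the validator raises on the first failing check, callers would typically run it once per batch before dispatching the params tensor to on-device sampling.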