# limitations under the License.

import re
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Dict, List, Optional, Tuple
+from collections.abc import Sequence

import numpy as np
import torch
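Reviewer note on the import move above: since Python 3.9 (PEP 585), collections.abc.Sequence is subscriptable in annotations and is the non-deprecated spelling of the typing alias. A minimal sketch (the first_key helper and the key names are hypothetical, not from this file):

    from collections.abc import Sequence

    def first_key(keys: Sequence[str]) -> str:
        # Subscriptable in annotations on 3.9+ (PEP 585); also usable
        # as a runtime ABC for isinstance checks.
        assert isinstance(keys, Sequence)
        return keys[0]

    print(first_key(["image_primary", "image_wrist"]))  # image_primary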
@@ -112,10 +113,10 @@ def __init__(
        self,
        use_film: bool = False,
        patch_size: int = 32,
-        kernel_sizes: Tuple[int, ...] = (3, 3, 3, 3),
-        strides: Tuple[int, ...] = (2, 2, 2, 2),
-        features: Tuple[int, ...] = (32, 96, 192, 384),
-        padding: Tuple[int, ...] = (1, 1, 1, 1),
+        kernel_sizes: tuple[int, ...] = (3, 3, 3, 3),
+        strides: tuple[int, ...] = (2, 2, 2, 2),
+        features: tuple[int, ...] = (32, 96, 192, 384),
+        padding: tuple[int, ...] = (1, 1, 1, 1),
        num_features: int = 512,
        img_norm_type: str = "default",
    ):
@@ -167,7 +168,7 @@ def __init__(
        self.film = FilmConditioning() if use_film else None

    def forward(
-        self, observations: torch.Tensor, train: bool = True, cond_var: Optional[torch.Tensor] = None
+        self, observations: torch.Tensor, train: bool = True, cond_var: torch.Tensor | None = None
    ):
        """
        Args:
@@ -212,10 +213,10 @@ class SmallStem16(SmallStem):
    def __init__(
        self,
        use_film: bool = False,
-        kernel_sizes: Tuple[int, ...] = (3, 3, 3, 3),
-        strides: Tuple[int, ...] = (2, 2, 2, 2),
-        features: Tuple[int, ...] = (32, 96, 192, 384),
-        padding: Tuple[int, ...] = (1, 1, 1, 1),
+        kernel_sizes: tuple[int, ...] = (3, 3, 3, 3),
+        strides: tuple[int, ...] = (2, 2, 2, 2),
+        features: tuple[int, ...] = (32, 96, 192, 384),
+        padding: tuple[int, ...] = (1, 1, 1, 1),
        num_features: int = 512,
        img_norm_type: str = "default",
    ):
@@ -243,7 +244,7 @@ def regex_filter(regex_keys, xs):

def generate_proper_pad_mask(
    tokens: torch.Tensor,
-    pad_mask_dict: Optional[Dict[str, torch.Tensor]],
+    pad_mask_dict: dict[str, torch.Tensor] | None,
    keys: Sequence[str],
) -> torch.Tensor:
    """Generate proper padding mask for tokens."""
@@ -286,8 +287,8 @@ def __init__(

    def forward(
        self,
-        observations: Dict[str, torch.Tensor],
-        tasks: Optional[Dict[str, torch.Tensor]] = None,
+        observations: dict[str, torch.Tensor],
+        tasks: dict[str, torch.Tensor] | None = None,
    ):
        """Forward pass through image tokenizer."""

@@ -382,7 +383,7 @@ def __init__(self, finetune_encoder: bool = False, proper_pad_mask: bool = True):
            for param in self.t5_encoder.parameters():
                param.requires_grad = False

-    def forward(self, language_input: Dict[str, torch.Tensor], tasks=None) -> TokenGroup:
+    def forward(self, language_input: dict[str, torch.Tensor], tasks=None) -> TokenGroup:
        outputs = self.t5_encoder(
            input_ids=language_input["input_ids"], attention_mask=language_input["attention_mask"]
        )
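Aside on the context lines above: the frozen-encoder branch turns off gradients parameter by parameter, so the T5 encoder is excluded from backprop. A minimal sketch with a stand-in module (the real t5_encoder is a transformers model, not shown here):

    import torch.nn as nn

    encoder = nn.Linear(4, 4)  # stand-in for the T5 encoder
    for param in encoder.parameters():
        param.requires_grad = False  # exclude these weights from gradient updates

    print(any(p.requires_grad for p in encoder.parameters()))  # False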
@@ -411,7 +412,7 @@ def forward(self, language_input: Dict[str, torch.Tensor], tasks=None) -> TokenGroup:
class TextProcessor:
    """HF Tokenizer wrapper."""

-    def __init__(self, tokenizer_name: str = "t5-base", tokenizer_kwargs: Optional[Dict] = None):
+    def __init__(self, tokenizer_name: str = "t5-base", tokenizer_kwargs: dict | None = None):
        if tokenizer_kwargs is None:
            tokenizer_kwargs = {
                "max_length": 16,
@@ -423,6 +424,6 @@ def __init__(self, tokenizer_name: str = "t5-base", tokenizer_kwargs: Optional[Dict] = None):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.tokenizer_kwargs = tokenizer_kwargs

-    def encode(self, strings: List[str]) -> Dict[str, torch.Tensor]:
+    def encode(self, strings: list[str]) -> dict[str, torch.Tensor]:
        """Encode strings to token IDs and attention masks."""
        return self.tokenizer(strings, **self.tokenizer_kwargs)
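Compatibility note on the annotation changes throughout this diff: built-in generics such as dict[str, torch.Tensor] require Python 3.9 (PEP 585), and the X | None union syntax requires Python 3.10 (PEP 604) when annotations are evaluated eagerly; on older interpreters the standard escape hatch is deferring evaluation. A minimal sketch with a hypothetical lookup helper:

    from __future__ import annotations  # defer evaluation: `int | None` then parses on 3.7+

    def lookup(cache: dict[str, int], key: str) -> int | None:
        # Under the future import these annotations are stored as strings
        # and never evaluated, so the module imports on pre-3.10 Pythons.
        return cache.get(key)

    print(lookup({"a": 1}, "b"))  # None

Note that PEP 604 unions used outside annotations (e.g. in isinstance checks) still require 3.10 regardless of the future import.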