@@ -1,10 +1,10 @@
 import logging
 import numpy as np
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 import tensorflow as tf
 
 from mlagents.envs.timers import timed
-from mlagents.envs.brain import BrainInfo
+from mlagents.envs.brain import BrainInfo, BrainParameters
 from mlagents.trainers.models import EncoderType, LearningRateSchedule
 from mlagents.trainers.ppo.models import PPOModel
 from mlagents.trainers.tf_policy import TFPolicy
@@ -17,7 +17,14 @@
 
 
 class PPOPolicy(TFPolicy):
-    def __init__(self, seed, brain, trainer_params, is_training, load):
+    def __init__(
+        self,
+        seed: int,
+        brain: BrainParameters,
+        trainer_params: Dict[str, Any],
+        is_training: bool,
+        load: bool,
+    ):
         """
         Policy for Proximal Policy Optimization Networks.
         :param seed: Random seed.
@@ -29,8 +36,8 @@ def __init__(self, seed, brain, trainer_params, is_training, load):
         super().__init__(seed, brain, trainer_params)
 
         reward_signal_configs = trainer_params["reward_signals"]
-        self.inference_dict = {}
-        self.update_dict = {}
+        self.inference_dict: Dict[str, tf.Tensor] = {}
+        self.update_dict: Dict[str, tf.Tensor] = {}
         self.stats_name_to_update_name = {
             "Losses/Value Loss": "value_loss",
             "Losses/Policy Loss": "policy_loss",
@@ -42,6 +49,7 @@ def __init__(self, seed, brain, trainer_params, is_training, load):
         self.create_reward_signals(reward_signal_configs)
 
         with self.graph.as_default():
+            self.bc_module: Optional[BCModule] = None
            # Create pretrainer if needed
            if "pretraining" in trainer_params:
                BCModule.check_config(trainer_params["pretraining"])
@@ -52,8 +60,6 @@ def __init__(self, seed, brain, trainer_params, is_training, load):
                    default_num_epoch=trainer_params["num_epoch"],
                    **trainer_params["pretraining"],
                )
-            else:
-                self.bc_module = None
 
         if load:
             self._load_graph()
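The change that drops the `else: self.bc_module = None` branch in favor of an up-front `self.bc_module: Optional[BCModule] = None` inside the `with` block follows a common typing pattern: bind the optional attribute once, annotated with its full `Optional[...]` type, so it exists on every code path and a type checker such as mypy sees a single declared type. A minimal, self-contained sketch of that pattern, with hypothetical `Pretrainer`/`Policy` names that are not part of ml-agents:

from typing import Any, Dict, Optional


class Pretrainer:
    """Stand-in for an optional component such as BCModule."""


class Policy:
    def __init__(self, config: Dict[str, Any]):
        # Declare the optional attribute up front with its full type so it is
        # always bound, regardless of whether the branch below runs.
        self.pretrainer: Optional[Pretrainer] = None
        if "pretraining" in config:
            self.pretrainer = Pretrainer()


policy = Policy({"pretraining": {}})
assert policy.pretrainer is not None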