@@ -459,52 +459,22 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
459459        return  output 
460460
461461    def  _get_simulation_env (self ):
462-         assert  self ._cfg .simulation_env_id  in  ['tictactoe' , 'gomoku' , 'go' ], self ._cfg .simulation_env_id 
463-         assert  self ._cfg .simulation_env_config_type  in  ['play_with_bot' , 'self_play' , 'league' ,
464-                                                         'sampled_play_with_bot' ], self ._cfg .simulation_env_config_type 
465-         if  self ._cfg .simulation_env_id  ==  'tictactoe' :
466-             from  zoo .board_games .tictactoe .envs .tictactoe_env  import  TicTacToeEnv 
467-             if  self ._cfg .simulation_env_config_type  ==  'play_with_bot' :
468-                 from  zoo .board_games .tictactoe .config .tictactoe_alphazero_bot_mode_config  import  \
469-                     tictactoe_alphazero_config 
470-             elif  self ._cfg .simulation_env_config_type  ==  'self_play' :
471-                 from  zoo .board_games .tictactoe .config .tictactoe_alphazero_sp_mode_config  import  \
472-                     tictactoe_alphazero_config 
473-             elif  self ._cfg .simulation_env_config_type  ==  'league' :
474-                 from  zoo .board_games .tictactoe .config .tictactoe_alphazero_league_config  import  \
475-                     tictactoe_alphazero_config 
476-             elif  self ._cfg .simulation_env_config_type  ==  'sampled_play_with_bot' :
477-                 from  zoo .board_games .tictactoe .config .tictactoe_sampled_alphazero_bot_mode_config  import  \
478-                     tictactoe_sampled_alphazero_config  as  tictactoe_alphazero_config 
479- 
480-             self .simulate_env  =  TicTacToeEnv (tictactoe_alphazero_config .env )
481- 
482-         elif  self ._cfg .simulation_env_id  ==  'gomoku' :
483-             from  zoo .board_games .gomoku .envs .gomoku_env  import  GomokuEnv 
484-             if  self ._cfg .simulation_env_config_type  ==  'play_with_bot' :
485-                 from  zoo .board_games .gomoku .config .gomoku_alphazero_bot_mode_config  import  gomoku_alphazero_config 
486-             elif  self ._cfg .simulation_env_config_type  ==  'self_play' :
487-                 from  zoo .board_games .gomoku .config .gomoku_alphazero_sp_mode_config  import  gomoku_alphazero_config 
488-             elif  self ._cfg .simulation_env_config_type  ==  'league' :
489-                 from  zoo .board_games .gomoku .config .gomoku_alphazero_league_config  import  gomoku_alphazero_config 
490-             elif  self ._cfg .simulation_env_config_type  ==  'sampled_play_with_bot' :
491-                 from  zoo .board_games .gomoku .config .gomoku_sampled_alphazero_bot_mode_config  import  \
492-                     gomoku_sampled_alphazero_config  as  gomoku_alphazero_config 
493- 
494-             self .simulate_env  =  GomokuEnv (gomoku_alphazero_config .env )
495-         elif  self ._cfg .simulation_env_id  ==  'go' :
496-             from  zoo .board_games .go .envs .go_env  import  GoEnv 
497-             if  self ._cfg .simulation_env_config_type  ==  'play_with_bot' :
498-                 from  zoo .board_games .go .config .go_alphazero_bot_mode_config  import  go_alphazero_config 
499-             elif  self ._cfg .simulation_env_config_type  ==  'self_play' :
500-                 from  zoo .board_games .go .config .go_alphazero_sp_mode_config  import  go_alphazero_config 
501-             elif  self ._cfg .simulation_env_config_type  ==  'league' :
502-                 from  zoo .board_games .go .config .go_alphazero_league_config  import  go_alphazero_config 
503-             elif  self ._cfg .simulation_env_config_type  ==  'sampled_play_with_bot' :
504-                 from  zoo .board_games .go .config .go_sampled_alphazero_bot_mode_config  import  \
505-                     go_sampled_alphazero_config  as  go_alphazero_config 
506- 
507-             self .simulate_env  =  GoEnv (go_alphazero_config .env )
462+         """ 
463+         Overview: 
464+             Create simulation environment for MCTS using registry-based approach. 
465+             Uses ENV_REGISTRY to instantiate environment based on simulation_env_id. 
466+         """ 
467+         from  ding .utils  import  import_module , ENV_REGISTRY 
468+ 
469+         # Import env modules to trigger registration 
470+         import_names  =  self ._cfg .create_cfg .env .get ('import_names' , [])
471+         import_module (import_names )
472+ 
473+         # Get env class from registry 
474+         env_cls  =  ENV_REGISTRY .get (self ._cfg .simulation_env_id )
475+ 
476+         # Create simulation env with config from main config 
477+         self .simulate_env  =  env_cls (self ._cfg .full_cfg .env )
508478
509479    @torch .no_grad () 
510480    def  _policy_value_func (self , environment : 'Environment' ) ->  Tuple [Dict [int , np .ndarray ], float ]:
0 commit comments