Skip to content

Commit f37e2df

Browse files
committed
Apply registry-based simulation env to all AlphaZero policies
- Updated alphazero.py, gumbel_alphazero.py, sampled_alphazero.py - All three policies now use ENV_REGISTRY instead of hardcoded imports - Simplifies adding new environments without modifying core code
1 parent adaf23b commit f37e2df

File tree

2 files changed

+32
-98
lines changed

2 files changed

+32
-98
lines changed

lzero/policy/alphazero.py

Lines changed: 16 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -337,58 +337,22 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
337337
return output
338338

339339
def _get_simulation_env(self):
340-
if self._cfg.simulation_env_id == 'tictactoe':
341-
from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
342-
if self._cfg.simulation_env_config_type == 'play_with_bot':
343-
from zoo.board_games.tictactoe.config.tictactoe_alphazero_bot_mode_config import \
344-
tictactoe_alphazero_config
345-
elif self._cfg.simulation_env_config_type == 'self_play':
346-
from zoo.board_games.tictactoe.config.tictactoe_alphazero_sp_mode_config import \
347-
tictactoe_alphazero_config
348-
else:
349-
raise NotImplementedError
350-
self.simulate_env = TicTacToeEnv(tictactoe_alphazero_config.env)
351-
352-
elif self._cfg.simulation_env_id == 'gomoku':
353-
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
354-
if self._cfg.simulation_env_config_type == 'play_with_bot':
355-
from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import gomoku_alphazero_config
356-
elif self._cfg.simulation_env_config_type == 'self_play':
357-
from zoo.board_games.gomoku.config.gomoku_alphazero_sp_mode_config import gomoku_alphazero_config
358-
else:
359-
raise NotImplementedError
360-
self.simulate_env = GomokuEnv(gomoku_alphazero_config.env)
361-
elif self._cfg.simulation_env_id == 'connect4':
362-
from zoo.board_games.connect4.envs.connect4_env import Connect4Env
363-
if self._cfg.simulation_env_config_type == 'play_with_bot':
364-
from zoo.board_games.connect4.config.connect4_alphazero_bot_mode_config import connect4_alphazero_config
365-
elif self._cfg.simulation_env_config_type == 'self_play':
366-
from zoo.board_games.connect4.config.connect4_alphazero_sp_mode_config import connect4_alphazero_config
367-
else:
368-
raise NotImplementedError
369-
self.simulate_env = Connect4Env(connect4_alphazero_config.env)
370-
elif self._cfg.simulation_env_id == 'chess':
371-
from zoo.board_games.chess.envs.chess_lightzero_env import ChessLightZeroEnv
372-
if self._cfg.simulation_env_config_type == 'play_with_bot':
373-
from zoo.board_games.chess.config.chess_alphazero_bot_mode_config import chess_alphazero_config
374-
elif self._cfg.simulation_env_config_type == 'self_play':
375-
from zoo.board_games.chess.config.chess_alphazero_sp_mode_config import chess_alphazero_config
376-
else:
377-
raise NotImplementedError
378-
self.simulate_env = ChessLightZeroEnv(chess_alphazero_config.env)
379-
elif self._cfg.simulation_env_id == 'dummy_any_game':
380-
from zoo.board_games.tictactoe.envs.dummy_any_game_env import AnyGameEnv
381-
if self._cfg.simulation_env_config_type == 'single_player_mode':
382-
from zoo.board_games.tictactoe.config.dummy_any_game_alphazero_single_player_mode_config import \
383-
dummy_any_game_alphazero_config
384-
elif self._cfg.simulation_env_config_type == 'self_play':
385-
from zoo.board_games.tictactoe.config.dummy_any_game_alphazero_self_play_mode_config import \
386-
dummy_any_game_alphazero_config
387-
else:
388-
raise NotImplementedError
389-
self.simulate_env = AnyGameEnv(dummy_any_game_alphazero_config.env)
390-
else:
391-
raise NotImplementedError
340+
"""
341+
Overview:
342+
Create simulation environment for MCTS using registry-based approach.
343+
Uses ENV_REGISTRY to instantiate environment based on simulation_env_id.
344+
"""
345+
from ding.utils import import_module, ENV_REGISTRY
346+
347+
# Import env modules to trigger registration
348+
import_names = self._cfg.create_cfg.env.get('import_names', [])
349+
import_module(import_names)
350+
351+
# Get env class from registry
352+
env_cls = ENV_REGISTRY.get(self._cfg.simulation_env_id)
353+
354+
# Create simulation env with config from main config
355+
self.simulate_env = env_cls(self._cfg.full_cfg.env)
392356

393357
@torch.no_grad()
394358
def _policy_value_fn(self, env: 'Env') -> Tuple[Dict[int, np.ndarray], float]: # noqa

lzero/policy/sampled_alphazero.py

Lines changed: 16 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -459,52 +459,22 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
459459
return output
460460

461461
def _get_simulation_env(self):
462-
assert self._cfg.simulation_env_id in ['tictactoe', 'gomoku', 'go'], self._cfg.simulation_env_id
463-
assert self._cfg.simulation_env_config_type in ['play_with_bot', 'self_play', 'league',
464-
'sampled_play_with_bot'], self._cfg.simulation_env_config_type
465-
if self._cfg.simulation_env_id == 'tictactoe':
466-
from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
467-
if self._cfg.simulation_env_config_type == 'play_with_bot':
468-
from zoo.board_games.tictactoe.config.tictactoe_alphazero_bot_mode_config import \
469-
tictactoe_alphazero_config
470-
elif self._cfg.simulation_env_config_type == 'self_play':
471-
from zoo.board_games.tictactoe.config.tictactoe_alphazero_sp_mode_config import \
472-
tictactoe_alphazero_config
473-
elif self._cfg.simulation_env_config_type == 'league':
474-
from zoo.board_games.tictactoe.config.tictactoe_alphazero_league_config import \
475-
tictactoe_alphazero_config
476-
elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot':
477-
from zoo.board_games.tictactoe.config.tictactoe_sampled_alphazero_bot_mode_config import \
478-
tictactoe_sampled_alphazero_config as tictactoe_alphazero_config
479-
480-
self.simulate_env = TicTacToeEnv(tictactoe_alphazero_config.env)
481-
482-
elif self._cfg.simulation_env_id == 'gomoku':
483-
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
484-
if self._cfg.simulation_env_config_type == 'play_with_bot':
485-
from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import gomoku_alphazero_config
486-
elif self._cfg.simulation_env_config_type == 'self_play':
487-
from zoo.board_games.gomoku.config.gomoku_alphazero_sp_mode_config import gomoku_alphazero_config
488-
elif self._cfg.simulation_env_config_type == 'league':
489-
from zoo.board_games.gomoku.config.gomoku_alphazero_league_config import gomoku_alphazero_config
490-
elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot':
491-
from zoo.board_games.gomoku.config.gomoku_sampled_alphazero_bot_mode_config import \
492-
gomoku_sampled_alphazero_config as gomoku_alphazero_config
493-
494-
self.simulate_env = GomokuEnv(gomoku_alphazero_config.env)
495-
elif self._cfg.simulation_env_id == 'go':
496-
from zoo.board_games.go.envs.go_env import GoEnv
497-
if self._cfg.simulation_env_config_type == 'play_with_bot':
498-
from zoo.board_games.go.config.go_alphazero_bot_mode_config import go_alphazero_config
499-
elif self._cfg.simulation_env_config_type == 'self_play':
500-
from zoo.board_games.go.config.go_alphazero_sp_mode_config import go_alphazero_config
501-
elif self._cfg.simulation_env_config_type == 'league':
502-
from zoo.board_games.go.config.go_alphazero_league_config import go_alphazero_config
503-
elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot':
504-
from zoo.board_games.go.config.go_sampled_alphazero_bot_mode_config import \
505-
go_sampled_alphazero_config as go_alphazero_config
506-
507-
self.simulate_env = GoEnv(go_alphazero_config.env)
462+
"""
463+
Overview:
464+
Create simulation environment for MCTS using registry-based approach.
465+
Uses ENV_REGISTRY to instantiate environment based on simulation_env_id.
466+
"""
467+
from ding.utils import import_module, ENV_REGISTRY
468+
469+
# Import env modules to trigger registration
470+
import_names = self._cfg.create_cfg.env.get('import_names', [])
471+
import_module(import_names)
472+
473+
# Get env class from registry
474+
env_cls = ENV_REGISTRY.get(self._cfg.simulation_env_id)
475+
476+
# Create simulation env with config from main config
477+
self.simulate_env = env_cls(self._cfg.full_cfg.env)
508478

509479
@torch.no_grad()
510480
def _policy_value_func(self, environment: 'Environment') -> Tuple[Dict[int, np.ndarray], float]:

0 commit comments

Comments
 (0)