Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lzero/entry/train_alphazero.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def train_alphazero(
collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

# Pass full config to policy so it can access cfg.env and create_cfg for simulation env
cfg.policy.full_cfg = cfg
cfg.policy.create_cfg = create_cfg

policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# load pretrained model
Expand Down
68 changes: 16 additions & 52 deletions lzero/policy/alphazero.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,58 +337,22 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
return output

def _get_simulation_env(self):
if self._cfg.simulation_env_id == 'tictactoe':
from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.tictactoe.config.tictactoe_alphazero_bot_mode_config import \
tictactoe_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.tictactoe.config.tictactoe_alphazero_sp_mode_config import \
tictactoe_alphazero_config
else:
raise NotImplementedError
self.simulate_env = TicTacToeEnv(tictactoe_alphazero_config.env)

elif self._cfg.simulation_env_id == 'gomoku':
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import gomoku_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.gomoku.config.gomoku_alphazero_sp_mode_config import gomoku_alphazero_config
else:
raise NotImplementedError
self.simulate_env = GomokuEnv(gomoku_alphazero_config.env)
elif self._cfg.simulation_env_id == 'connect4':
from zoo.board_games.connect4.envs.connect4_env import Connect4Env
if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.connect4.config.connect4_alphazero_bot_mode_config import connect4_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.connect4.config.connect4_alphazero_sp_mode_config import connect4_alphazero_config
else:
raise NotImplementedError
self.simulate_env = Connect4Env(connect4_alphazero_config.env)
elif self._cfg.simulation_env_id == 'chess':
from zoo.board_games.chess.envs.chess_lightzero_env import ChessLightZeroEnv
if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.chess.config.chess_alphazero_bot_mode_config import chess_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.chess.config.chess_alphazero_sp_mode_config import chess_alphazero_config
else:
raise NotImplementedError
self.simulate_env = ChessLightZeroEnv(chess_alphazero_config.env)
elif self._cfg.simulation_env_id == 'dummy_any_game':
from zoo.board_games.tictactoe.envs.dummy_any_game_env import AnyGameEnv
if self._cfg.simulation_env_config_type == 'single_player_mode':
from zoo.board_games.tictactoe.config.dummy_any_game_alphazero_single_player_mode_config import \
dummy_any_game_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.tictactoe.config.dummy_any_game_alphazero_self_play_mode_config import \
dummy_any_game_alphazero_config
else:
raise NotImplementedError
self.simulate_env = AnyGameEnv(dummy_any_game_alphazero_config.env)
else:
raise NotImplementedError
"""
Overview:
Create simulation environment for MCTS using registry-based approach.
Uses ENV_REGISTRY to instantiate environment based on simulation_env_id.
"""
from ding.utils import import_module, ENV_REGISTRY

# Import env modules to trigger registration
import_names = self._cfg.create_cfg.env.get('import_names', [])
import_module(import_names)

# Get env class from registry
env_cls = ENV_REGISTRY.get(self._cfg.simulation_env_id)

# Create simulation env with config from main config
self.simulate_env = env_cls(self._cfg.full_cfg.env)

@torch.no_grad()
def _policy_value_fn(self, env: 'Env') -> Tuple[Dict[int, np.ndarray], float]: # noqa
Expand Down
48 changes: 16 additions & 32 deletions lzero/policy/gumbel_alphazero.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,38 +379,22 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
return output

def _get_simulation_env(self):
if self._cfg.simulation_env_id == 'tictactoe':
from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.tictactoe.config.tictactoe_gumbel_alphazero_bot_mode_config import \
tictactoe_gumbel_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.tictactoe.config.tictactoe_gumbel_alphazero_sp_mode_config import \
tictactoe_gumbel_alphazero_config
else:
raise NotImplementedError
self.simulate_env = TicTacToeEnv(tictactoe_gumbel_alphazero_config.env)

elif self._cfg.simulation_env_id == 'gomoku':
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.gomoku.config.gomoku_gumbel_alphazero_bot_mode_config import gomoku_gumbel_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.gomoku.config.gomoku_gumbel_alphazero_sp_mode_config import gomoku_gumbel_alphazero_config
else:
raise NotImplementedError
self.simulate_env = GomokuEnv(gomoku_gumbel_alphazero_config.env)
elif self._cfg.simulation_env_id == 'connect4':
from zoo.board_games.connect4.envs.connect4_env import Connect4Env
if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.connect4.config.connect4_gumbel_alphazero_bot_mode_config import connect4_gumbel_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.connect4.config.connect4_gumbel_alphazero_sp_mode_config import connect4_gumbel_alphazero_config
else:
raise NotImplementedError
self.simulate_env = Connect4Env(connect4_gumbel_alphazero_config.env)
else:
raise NotImplementedError
"""
Overview:
Create simulation environment for MCTS using registry-based approach.
Uses ENV_REGISTRY to instantiate environment based on simulation_env_id.
"""
from ding.utils import import_module, ENV_REGISTRY

# Import env modules to trigger registration
import_names = self._cfg.create_cfg.env.get('import_names', [])
import_module(import_names)

# Get env class from registry
env_cls = ENV_REGISTRY.get(self._cfg.simulation_env_id)

# Create simulation env with config from main config
self.simulate_env = env_cls(self._cfg.full_cfg.env)

@torch.no_grad()
def _policy_value_fn(self, env: 'Env') -> Tuple[Dict[int, np.ndarray], float]: # noqa
Expand Down
62 changes: 16 additions & 46 deletions lzero/policy/sampled_alphazero.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,52 +459,22 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
return output

def _get_simulation_env(self):
assert self._cfg.simulation_env_id in ['tictactoe', 'gomoku', 'go'], self._cfg.simulation_env_id
assert self._cfg.simulation_env_config_type in ['play_with_bot', 'self_play', 'league',
'sampled_play_with_bot'], self._cfg.simulation_env_config_type
if self._cfg.simulation_env_id == 'tictactoe':
from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.tictactoe.config.tictactoe_alphazero_bot_mode_config import \
tictactoe_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.tictactoe.config.tictactoe_alphazero_sp_mode_config import \
tictactoe_alphazero_config
elif self._cfg.simulation_env_config_type == 'league':
from zoo.board_games.tictactoe.config.tictactoe_alphazero_league_config import \
tictactoe_alphazero_config
elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot':
from zoo.board_games.tictactoe.config.tictactoe_sampled_alphazero_bot_mode_config import \
tictactoe_sampled_alphazero_config as tictactoe_alphazero_config

self.simulate_env = TicTacToeEnv(tictactoe_alphazero_config.env)

elif self._cfg.simulation_env_id == 'gomoku':
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import gomoku_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.gomoku.config.gomoku_alphazero_sp_mode_config import gomoku_alphazero_config
elif self._cfg.simulation_env_config_type == 'league':
from zoo.board_games.gomoku.config.gomoku_alphazero_league_config import gomoku_alphazero_config
elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot':
from zoo.board_games.gomoku.config.gomoku_sampled_alphazero_bot_mode_config import \
gomoku_sampled_alphazero_config as gomoku_alphazero_config

self.simulate_env = GomokuEnv(gomoku_alphazero_config.env)
elif self._cfg.simulation_env_id == 'go':
from zoo.board_games.go.envs.go_env import GoEnv
if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.go.config.go_alphazero_bot_mode_config import go_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.go.config.go_alphazero_sp_mode_config import go_alphazero_config
elif self._cfg.simulation_env_config_type == 'league':
from zoo.board_games.go.config.go_alphazero_league_config import go_alphazero_config
elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot':
from zoo.board_games.go.config.go_sampled_alphazero_bot_mode_config import \
go_sampled_alphazero_config as go_alphazero_config

self.simulate_env = GoEnv(go_alphazero_config.env)
"""
Overview:
Create simulation environment for MCTS using registry-based approach.
Uses ENV_REGISTRY to instantiate environment based on simulation_env_id.
"""
from ding.utils import import_module, ENV_REGISTRY

# Import env modules to trigger registration
import_names = self._cfg.create_cfg.env.get('import_names', [])
import_module(import_names)

# Get env class from registry
env_cls = ENV_REGISTRY.get(self._cfg.simulation_env_id)

# Create simulation env with config from main config
self.simulate_env = env_cls(self._cfg.full_cfg.env)

@torch.no_grad()
def _policy_value_func(self, environment: 'Environment') -> Tuple[Dict[int, np.ndarray], float]:
Expand Down