From 44c0287a338d3e6f1d741582e73176cae9b64d07 Mon Sep 17 00:00:00 2001
From: puyuan
Date: Fri, 23 May 2025 18:14:31 +0800
Subject: [PATCH 1/2] fix(pu): fix chess reset bug when using alphazero ctree

---
 lzero/entry/train_alphazero.py                |  2 +-
 .../ctree/ctree_alphazero/mcts_alphazero.cpp  |  1 +
 lzero/mcts/ctree/ctree_alphazero/pybind11     |  2 +-
 lzero/mcts/ptree/ptree_az.py                  |  2 +-
 lzero/policy/alphazero.py                     | 18 ++++++++++++++----
 .../config/chess_alphazero_sp_mode_config.py  | 16 +++++++++-------
 .../chess/envs/chess_lightzero_env.py         | 16 +++++++++++-----
 7 files changed, 38 insertions(+), 19 deletions(-)

diff --git a/lzero/entry/train_alphazero.py b/lzero/entry/train_alphazero.py
index bb49f0cb6..44aa2d11c 100644
--- a/lzero/entry/train_alphazero.py
+++ b/lzero/entry/train_alphazero.py
@@ -106,7 +106,7 @@ def train_alphazero(
         )

         # Evaluate policy performance
-        if evaluator.should_eval(learner.train_iter) and learner.train_iter > 0:
+        if evaluator.should_eval(learner.train_iter) or learner.train_iter == 0:
             stop, reward = evaluator.eval(
                 learner.save_checkpoint,
                 learner.train_iter,
diff --git a/lzero/mcts/ctree/ctree_alphazero/mcts_alphazero.cpp b/lzero/mcts/ctree/ctree_alphazero/mcts_alphazero.cpp
index 0892c3041..926fc6413 100644
--- a/lzero/mcts/ctree/ctree_alphazero/mcts_alphazero.cpp
+++ b/lzero/mcts/ctree/ctree_alphazero/mcts_alphazero.cpp
@@ -166,6 +166,7 @@ class MCTS {
         if (!init_state.is_none()) {
             init_state = py::bytes(init_state.attr("tobytes")());
         }
+
         py::object katago_game_state = state_config_for_env_reset["katago_game_state"];
         if (!katago_game_state.is_none()) {
             katago_game_state = py::module::import("pickle").attr("dumps")(katago_game_state);
diff --git a/lzero/mcts/ctree/ctree_alphazero/pybind11 b/lzero/mcts/ctree/ctree_alphazero/pybind11
index f2606930b..98bd78f06 160000
--- a/lzero/mcts/ctree/ctree_alphazero/pybind11
+++ b/lzero/mcts/ctree/ctree_alphazero/pybind11
@@ -1 +1 @@
-Subproject commit f2606930bf5d1140daecc6bc2aea2baf4b58f7ff
+Subproject commit 98bd78f063b2f30570740030cb2d13b2a62a062c
diff --git a/lzero/mcts/ptree/ptree_az.py b/lzero/mcts/ptree/ptree_az.py
index 58143c481..2dc89249c 100644
--- a/lzero/mcts/ptree/ptree_az.py
+++ b/lzero/mcts/ptree/ptree_az.py
@@ -261,7 +261,7 @@ def get_next_action(
             action = actions[np.argmax(action_probs)]

        # Return the selected action and the output probability of each action.
-        return action, action_probs
+        return action, action_probs, None

    def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn: Callable) -> None:
        """
diff --git a/lzero/policy/alphazero.py b/lzero/policy/alphazero.py
index e87c3bd5f..078ac1531 100644
--- a/lzero/policy/alphazero.py
+++ b/lzero/policy/alphazero.py
@@ -251,7 +251,11 @@ def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch
        """
        self.collect_mcts_temperature = temperature
        ready_env_id = list(obs.keys())
-        init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
+        if self._cfg.simulation_env_id == 'chess':  # obs[env_id]['board'] is a FEN str
+            init_state = {env_id: obs[env_id]['board'].encode() for env_id in ready_env_id}  # str → bytes
+        else:
+            init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
+
        # If 'katago_game_state' is in the observation of the given environment ID, it's value is used.
        # If it's not present (which will raise a KeyError), None is used instead.
        # This approach is taken to maintain compatibility with the handling of 'katago' related parts of 'alphazero_mcts_ctree' in Go.
@@ -259,9 +263,11 @@ def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch
        start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}
        output = {}
        self._policy_model = self._collect_model
+
        for env_id in ready_env_id:
            state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
-                                                                  init_state=init_state[env_id],
+                                                                  # init_state=init_state[env_id],  # orig
+                                                                  init_state=np.frombuffer(init_state[env_id], dtype=np.int8) if self._cfg.simulation_env_id == 'chess' else init_state[env_id],
                                                                   katago_policy_init=False,
                                                                   katago_game_state=katago_game_state[env_id]))
            action, mcts_probs, root = self._collect_mcts.get_next_action(state_config_for_simulation_env_reset, self._policy_value_fn, self.collect_mcts_temperature, True)
@@ -314,7 +320,11 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
                the corresponding policy output in this timestep, including action, probs and so on.
        """
        ready_env_id = list(obs.keys())
-        init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
+        if self._cfg.simulation_env_id == 'chess':  # obs[env_id]['board'] is a FEN str
+            init_state = {env_id: obs[env_id]['board'].encode() for env_id in ready_env_id}  # str → bytes
+        else:
+            init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
+
        # If 'katago_game_state' is in the observation of the given environment ID, it's value is used.
        # If it's not present (which will raise a KeyError), None is used instead.
        # This approach is taken to maintain compatibility with the handling of 'katago' related parts of 'alphazero_mcts_ctree' in Go.
@@ -324,7 +334,7 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
        self._policy_model = self._eval_model
        for env_id in ready_env_id:
            state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
-                                                                  init_state=init_state[env_id],
+                                                                  init_state=np.frombuffer(init_state[env_id], dtype=np.int8) if self._cfg.simulation_env_id == 'chess' else init_state[env_id],
                                                                   katago_policy_init=False,
                                                                   katago_game_state=katago_game_state[env_id]))
            action, mcts_probs, root = self._eval_mcts.get_next_action(
diff --git a/zoo/board_games/chess/config/chess_alphazero_sp_mode_config.py b/zoo/board_games/chess/config/chess_alphazero_sp_mode_config.py
index 557b6e508..7b6b67132 100644
--- a/zoo/board_games/chess/config/chess_alphazero_sp_mode_config.py
+++ b/zoo/board_games/chess/config/chess_alphazero_sp_mode_config.py
@@ -11,15 +11,17 @@
 batch_size = 512
 max_env_step = int(1e6)
 mcts_ctree = True
+# mcts_ctree = False
+
 # TODO: for debug
-# collector_env_num = 2
-# n_episode = 2
-# evaluator_env_num = 2
-# num_simulations = 4
-# update_per_collect = 2
-# batch_size = 2
-# max_env_step = int(1e4)
+collector_env_num = 2
+n_episode = 2
+evaluator_env_num = 2
+num_simulations = 4
+update_per_collect = 2
+batch_size = 2
+max_env_step = int(1e4)
 # mcts_ctree = False
 # ==============================================================
 # end of the most frequently changed config specified by the user
 # ==============================================================
diff --git a/zoo/board_games/chess/envs/chess_lightzero_env.py b/zoo/board_games/chess/envs/chess_lightzero_env.py
index 401c643d1..3b89183cf 100644
--- a/zoo/board_games/chess/envs/chess_lightzero_env.py
+++ b/zoo/board_games/chess/envs/chess_lightzero_env.py
@@ -57,6 +57,8 @@ def observe(self, agent_index):
            observation = chess_utils.get_observation(self.board, agent_index).astype(float)  # TODO
        except Exception as e:
            print('debug')
+            print(f"self.board:{self.board}")
+
        # TODO:
        # observation = np.dstack((observation[:, :, :7], self.board_history))

@@ -109,10 +111,6 @@ def get_done_winner(self):
        return done, winner

    def reset(self, start_player_index=0, init_state=None, katago_policy_init=False, katago_game_state=None):
-        if self.alphazero_mcts_ctree and init_state is not None:
-            # Convert byte string to np.ndarray
-            init_state = np.frombuffer(init_state, dtype=np.int32)
-
        if self.scale:
            self._observation_space = spaces.Dict(
                {
@@ -131,8 +129,16 @@ def reset(self, start_player_index=0, init_state=None, katago_policy_init=False,
            self._reward_space = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
        self.start_player_index = start_player_index
        self._current_player = self.players[self.start_player_index]
+
        if init_state is not None:
-            self.board = chess.Board(init_state)
+            if isinstance(init_state, np.ndarray):
+                # ndarray → bytes → str
+                fen = init_state.tobytes().decode()
+            elif isinstance(init_state, (bytes, bytearray)):
+                fen = init_state.decode()
+            else:  # init_state is str
+                fen = init_state
+            self.board = chess.Board(fen)
        else:
            self.board = chess.Board()

From 6649e851c9ca627ed963a164b3c43c611f7845cb Mon Sep 17 00:00:00 2001
From: puyuan
Date: Fri, 23 May 2025 18:51:38 +0800
Subject: [PATCH 2/2] polish(pu): polish chess config

---
 .../config/chess_alphazero_bot_mode_config.py | 34 +++++++++---------
 .../config/chess_alphazero_sp_mode_config.py  | 36 +++++++++----------
 .../chess/envs/chess_lightzero_env.py         | 33 ++++++++---------
 3 files changed, 51 insertions(+), 52 deletions(-)

diff --git a/zoo/board_games/chess/config/chess_alphazero_bot_mode_config.py b/zoo/board_games/chess/config/chess_alphazero_bot_mode_config.py
index 25d130747..ac659caf0 100644
--- a/zoo/board_games/chess/config/chess_alphazero_bot_mode_config.py
+++ b/zoo/board_games/chess/config/chess_alphazero_bot_mode_config.py
@@ -3,24 +3,24 @@
 # ==============================================================
 # begin of the most frequently changed config specified by the user
 # ==============================================================
-# collector_env_num = 8
-# n_episode = 8
-# evaluator_env_num = 5
-# num_simulations = 400
-# update_per_collect = 200
-# batch_size = 512
-# max_env_step = int(1e6)
-# mcts_ctree = False
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 5
+num_simulations = 400
+update_per_collect = 200
+batch_size = 512
+max_env_step = int(1e6)
+mcts_ctree = False

 # TODO: for debug
-collector_env_num = 2
-n_episode = 2
-evaluator_env_num = 2
-num_simulations = 4
-update_per_collect = 2
-batch_size = 2
-max_env_step = int(1e4)
-mcts_ctree = False
+# collector_env_num = 2
+# n_episode = 2
+# evaluator_env_num = 2
+# num_simulations = 4
+# update_per_collect = 2
+# batch_size = 2
+# max_env_step = int(1e4)
+# mcts_ctree = False
 # ==============================================================
 # end of the most frequently changed config specified by the user
 # ==============================================================
@@ -56,7 +56,7 @@
        model=dict(
            observation_shape=(8, 8, 20),
            action_space_size=int(8 * 8 * 73),
-            # TODO: for debug
+            # TODO: only for debug
            num_res_blocks=1,
            num_channels=1,
            value_head_hidden_channels=[16],
diff --git a/zoo/board_games/chess/config/chess_alphazero_sp_mode_config.py b/zoo/board_games/chess/config/chess_alphazero_sp_mode_config.py
index 7b6b67132..cee7a6f99 100644
--- a/zoo/board_games/chess/config/chess_alphazero_sp_mode_config.py
+++ b/zoo/board_games/chess/config/chess_alphazero_sp_mode_config.py
@@ -10,24 +10,22 @@
 update_per_collect = 200
 batch_size = 512
 max_env_step = int(1e6)
-mcts_ctree = True
-# mcts_ctree = False
-
+mcts_ctree = False
 # TODO: for debug
-collector_env_num = 2
-n_episode = 2
-evaluator_env_num = 2
-num_simulations = 4
-update_per_collect = 2
-batch_size = 2
-max_env_step = int(1e4)
+# collector_env_num = 2
+# n_episode = 2
+# evaluator_env_num = 2
+# num_simulations = 2
+# update_per_collect = 1
+# batch_size = 2
+# max_env_step = int(1e4)
 # mcts_ctree = False
 # ==============================================================
 # end of the most frequently changed config specified by the user
 # ==============================================================

 chess_alphazero_config = dict(
-    exp_name='data_az_ctree/chess_sp-mode_alphazero_seed0',
+    exp_name='data_az_ptree/chess_sp-mode_alphazero_seed0',
     env=dict(
         board_size=8,
         battle_mode='self_play_mode',
@@ -58,14 +56,14 @@
            observation_shape=(8, 8, 20),
            action_space_size=int(8 * 8 * 73),
            # TODO: for debug
-            num_res_blocks=1,
-            num_channels=1,
-            value_head_hidden_channels=[16],
-            policy_head_hidden_channels=[16],
-            # num_res_blocks=8,
-            # num_channels=256,
-            # value_head_hidden_channels=[256, 256],
-            # policy_head_hidden_channels=[256, 256],
+            # num_res_blocks=1,
+            # num_channels=1,
+            # value_head_hidden_channels=[16],
+            # policy_head_hidden_channels=[16],
+            num_res_blocks=8,
+            num_channels=256,
+            value_head_hidden_channels=[256, 256],
+            policy_head_hidden_channels=[256, 256],
        ),
        cuda=True,
        board_size=8,
diff --git a/zoo/board_games/chess/envs/chess_lightzero_env.py b/zoo/board_games/chess/envs/chess_lightzero_env.py
index 3b89183cf..90d34560c 100644
--- a/zoo/board_games/chess/envs/chess_lightzero_env.py
+++ b/zoo/board_games/chess/envs/chess_lightzero_env.py
@@ -10,9 +10,8 @@
 from ding.envs.env.base_env import BaseEnvTimestep
 from ding.utils.registry_factory import ENV_REGISTRY
 from gymnasium import spaces
-from pettingzoo.classic.chess import chess_utils
-
 from zoo.board_games.chess.envs.chess_env import ChessEnv
+from pettingzoo.classic.chess import chess_utils as pz_cu


 @ENV_REGISTRY.register('chess_lightzero')
@@ -50,16 +49,15 @@ def __init__(self, cfg=None):

    @property
    def legal_actions(self):
-        return chess_utils.legal_moves(self.board)
+        return pz_cu.legal_moves(self.board)

    def observe(self, agent_index):
        try:
-            observation = chess_utils.get_observation(self.board, agent_index).astype(float)  # TODO
+            observation = pz_cu.get_observation(self.board, agent_index).astype(float)  # TODO
        except Exception as e:
-            print('debug')
+            print(f'debug: {e}')
            print(f"self.board:{self.board}")
-
        # TODO:
        # observation = np.dstack((observation[:, :, :7], self.board_history))

        # We need to swap the white 6 channels with black 6 channels
@@ -75,9 +73,12 @@
            # observation[..., 13 * i : 13 * i + 6] = tmp

        action_mask = np.zeros(4672, dtype=np.int8)
-        action_mask[chess_utils.legal_moves(self.board)] = 1
+        action_mask[pz_cu.legal_moves(self.board)] = 1

        return {'observation': observation, 'action_mask': action_mask}
+
+
+
    def current_state(self):
        """
        Overview:
@@ -103,7 +104,7 @@ def get_done_winner(self):
        if result == "*":
            winner = -1
        else:
-            winner = chess_utils.result_to_int(result)
+            winner = pz_cu.result_to_int(result)

        if not done:
            winner = -1
@@ -143,7 +144,7 @@ def reset(self, start_player_index=0, init_state=None, katago_policy_init=False,
            self.board = chess.Board()

        action_mask = np.zeros(4672, dtype=np.int8)
-        action_mask[chess_utils.legal_moves(self.board)] = 1
+        action_mask[pz_cu.legal_moves(self.board)] = 1

        # self.board_history = np.zeros((8, 8, 104), dtype=bool)
        if self.battle_mode == 'play_with_bot_mode' or self.battle_mode == 'eval_mode':
@@ -265,10 +266,10 @@ def _player_step(self, action):
        current_agent = self.current_player_index

        # TODO: Update board history
-        # next_board = chess_utils.get_observation(self.board, current_agent)
+        # next_board = pz_cu.get_observation(self.board, current_agent)
        # self.board_history = np.dstack((next_board[:, :, 7:], self.board_history[:, :, :-13]))

-        chosen_move = chess_utils.action_to_move(self.board, action, current_agent)
+        chosen_move = pz_cu.action_to_move(self.board, action, current_agent)
        assert chosen_move in self.board.legal_moves
        self.board.push(chosen_move)

@@ -277,7 +278,7 @@ def _player_step(self, action):
        if result == "*":
            reward = 0.
        else:
-            reward = chess_utils.result_to_int(result)
+            reward = pz_cu.result_to_int(result)

        if self.current_player == 1:
            reward = -reward
@@ -287,7 +288,7 @@ def _player_step(self, action):
            info['eval_episode_return'] = reward

        action_mask = np.zeros(4672, dtype=np.int8)
-        action_mask[chess_utils.legal_moves(self.board)] = 1
+        action_mask[pz_cu.legal_moves(self.board)] = 1

        obs = {
            'observation': self.observe(self.current_player_index)['observation'],
@@ -318,14 +319,14 @@ def current_player(self, value):
        self._current_player = value

    def random_action(self):
-        action_list = chess_utils.legal_moves(self.board)
+        action_list = pz_cu.legal_moves(self.board)
        return np.random.choice(action_list)

    def simulate_action(self, action):
-        if action not in chess_utils.legal_moves(self.board):
+        if action not in pz_cu.legal_moves(self.board):
            raise ValueError("action {0} on board {1} is not legal".format(action, self.board.fen()))
        new_board = copy.deepcopy(self.board)
-        new_board.push(chess_utils.action_to_move(self.board, action, self.current_player_index))
+        new_board.push(pz_cu.action_to_move(self.board, action, self.current_player_index))
        if self.start_player_index == 0:
            start_player_index = 1
        else:
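
The fix in these two patches hinges on one serialization round trip: the policy encodes the chess env's FEN string to bytes before handing it to the C++ MCTS (which only forwards raw bytes via tobytes()/py::bytes), and ChessLightZeroEnv.reset decodes it back into a FEN string before rebuilding the board. The snippet below is a minimal illustrative sketch of that round trip; it assumes only the numpy and python-chess packages and is not code taken from the repository.

import chess
import numpy as np

# Policy side (_forward_collect / _forward_eval): FEN str -> bytes -> np.int8 view.
fen = chess.Board().fen()
init_state_bytes = fen.encode()
init_state_array = np.frombuffer(init_state_bytes, dtype=np.int8)

# The C++ MCTS side only forwards the raw bytes:
# py::bytes(init_state.attr("tobytes")()) in mcts_alphazero.cpp.

# Env side (ChessLightZeroEnv.reset): ndarray -> bytes -> FEN str -> board.
restored_fen = init_state_array.tobytes().decode()
assert restored_fen == fen
board = chess.Board(restored_fen)

The int8 dtype keeps a one-byte-per-character view of the FEN text, so tobytes().decode() recovers the exact string on the env side.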