2 changes: 1 addition & 1 deletion lzero/entry/train_alphazero.py
@@ -106,7 +106,7 @@ def train_alphazero(
)

# Evaluate policy performance
- if evaluator.should_eval(learner.train_iter) and learner.train_iter > 0:
+ if evaluator.should_eval(learner.train_iter) or learner.train_iter == 0:
stop, reward = evaluator.eval(
learner.save_checkpoint,
learner.train_iter,
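The intent of this one-line change, as I read it, is to also evaluate the freshly initialized policy at `train_iter == 0` (a baseline before any updates) rather than skip iteration 0. A minimal sketch of the two trigger conditions, assuming `should_eval` fires on a fixed `eval_freq` interval (the interval-based behavior is an assumption, not taken from this diff):

```python
# Sketch: which iterations trigger evaluation under the old vs. new condition.
# Assumes an interval-based evaluator, i.e. should_eval(i) is True every eval_freq iters.
eval_freq = 200

def should_eval(train_iter: int) -> bool:
    return train_iter % eval_freq == 0

old = [i for i in range(0, 401) if should_eval(i) and i > 0]   # [200, 400]: skips iter 0
new = [i for i in range(0, 401) if should_eval(i) or i == 0]   # [0, 200, 400]: baseline at iter 0
print(old, new)
```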
1 change: 1 addition & 0 deletions lzero/mcts/ctree/ctree_alphazero/mcts_alphazero.cpp
@@ -166,6 +166,7 @@ class MCTS {
if (!init_state.is_none()) {
init_state = py::bytes(init_state.attr("tobytes")());
}

py::object katago_game_state = state_config_for_env_reset["katago_game_state"];
if (!katago_game_state.is_none()) {
katago_game_state = py::module::import("pickle").attr("dumps")(katago_game_state);
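For readers not following the C++ side: before the MCTS resets its internal simulation env, the binding converts `init_state` to raw bytes via `tobytes()` and, with this addition, pickles `katago_game_state`, so both can cross the pybind11 boundary as plain byte strings. A rough Python-side equivalent of that conversion, written only as an illustration (the helper name and dict-based interface are assumptions):

```python
import pickle
import numpy as np

def serialize_state_config(state_config: dict) -> dict:
    """Illustrative mirror of the conversion done in mcts_alphazero.cpp."""
    out = dict(state_config)
    init_state = out.get('init_state')
    if init_state is not None:
        # np.ndarray -> raw bytes, matching py::bytes(init_state.attr("tobytes")())
        out['init_state'] = np.asarray(init_state).tobytes()
    katago_game_state = out.get('katago_game_state')
    if katago_game_state is not None:
        # arbitrary Python object -> pickled bytes, matching pickle.dumps(katago_game_state)
        out['katago_game_state'] = pickle.dumps(katago_game_state)
    return out
```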
2 changes: 1 addition & 1 deletion lzero/mcts/ctree/ctree_alphazero/pybind11
Submodule pybind11 updated 286 files
2 changes: 1 addition & 1 deletion lzero/mcts/ptree/ptree_az.py
@@ -261,7 +261,7 @@ def get_next_action(
action = actions[np.argmax(action_probs)]

# Return the selected action and the output probability of each action.
- return action, action_probs
+ return action, action_probs, None

def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn: Callable) -> None:
"""
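The extra `None` keeps the Python-tree implementation signature-compatible with the C++ tree version, which, judging from the call sites in `lzero/policy/alphazero.py` below, returns a third value for the search root. A small self-contained sketch of the uniform unpacking this enables; the stub backends and their return values are purely illustrative:

```python
# Why return (action, probs, None): callers unpack a uniform 3-tuple no matter
# which MCTS backend is configured. Stub backends shown only for illustration.
def ptree_get_next_action():
    return 7, [0.1, 0.9], None         # Python tree: no root object exposed

def ctree_get_next_action():
    return 7, [0.1, 0.9], "root-node"  # C++ tree: returns its search-tree root

for backend in (ptree_get_next_action, ctree_get_next_action):
    action, mcts_probs, root = backend()   # single unpacking site, as in alphazero.py
```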
18 changes: 14 additions & 4 deletions lzero/policy/alphazero.py
@@ -251,17 +251,23 @@ def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch.Tensor]:
"""
self.collect_mcts_temperature = temperature
ready_env_id = list(obs.keys())
- init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
+ if self._cfg.simulation_env_id == 'chess':  # obs[env_id]['board'] is FEN str
+     init_state = {env_id: obs[env_id]['board'].encode() for env_id in ready_env_id}  # str → bytes
+ else:
+     init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}

# If 'katago_game_state' is in the observation of the given environment ID, it's value is used.
# If it's not present (which will raise a KeyError), None is used instead.
# This approach is taken to maintain compatibility with the handling of 'katago' related parts of 'alphazero_mcts_ctree' in Go.
katago_game_state = {env_id: obs[env_id].get('katago_game_state', None) for env_id in ready_env_id}
start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}
output = {}
self._policy_model = self._collect_model

for env_id in ready_env_id:
state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
- init_state=init_state[env_id],
+ # init_state=init_state[env_id],  # orig
+ init_state=np.frombuffer(init_state[env_id], dtype=np.int8) if self._cfg.simulation_env_id == 'chess' else init_state[env_id],
katago_policy_init=False,
katago_game_state=katago_game_state[env_id]))
action, mcts_probs, root = self._collect_mcts.get_next_action(state_config_for_simulation_env_reset, self._policy_value_fn, self.collect_mcts_temperature, True)
@@ -314,7 +320,11 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
the corresponding policy output in this timestep, including action, probs and so on.
"""
ready_env_id = list(obs.keys())
- init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
+ if self._cfg.simulation_env_id == 'chess':  # obs[env_id]['board'] is FEN str
+     init_state = {env_id: obs[env_id]['board'].encode() for env_id in ready_env_id}  # str → bytes
+ else:
+     init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}

# If 'katago_game_state' is in the observation of the given environment ID, it's value is used.
# If it's not present (which will raise a KeyError), None is used instead.
# This approach is taken to maintain compatibility with the handling of 'katago' related parts of 'alphazero_mcts_ctree' in Go.
@@ -324,7 +334,7 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
self._policy_model = self._eval_model
for env_id in ready_env_id:
state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
- init_state=init_state[env_id],
+ init_state=np.frombuffer(init_state[env_id], dtype=np.int8) if self._cfg.simulation_env_id == 'chess' else init_state[env_id],
katago_policy_init=False,
katago_game_state=katago_game_state[env_id]))
action, mcts_probs, root = self._eval_mcts.get_next_action(
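The chess-specific branches above ship the board state as bytes: the policy encodes the FEN string and wraps it in an int8 array, and the environment (see `chess_lightzero_env.reset` further down) decodes it back into a FEN before building the `chess.Board`. A minimal round-trip sketch of that convention, standalone and outside the LightZero classes:

```python
import numpy as np
import chess  # python-chess, the library used by chess_lightzero_env

fen = chess.Board().fen()  # 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1'

# Policy side (_forward_collect / _forward_eval): FEN str -> bytes -> int8 array,
# so the chess state fits the same init_state slot as array-based boards.
packed = np.frombuffer(fen.encode(), dtype=np.int8)

# Env side (reset): int8 array -> bytes -> FEN str -> chess.Board.
restored = chess.Board(packed.tobytes().decode())
assert restored.fen() == fen
```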
34 changes: 17 additions & 17 deletions zoo/board_games/chess/config/chess_alphazero_bot_mode_config.py
@@ -3,24 +3,24 @@
# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
- # collector_env_num = 8
- # n_episode = 8
- # evaluator_env_num = 5
- # num_simulations = 400
- # update_per_collect = 200
- # batch_size = 512
- # max_env_step = int(1e6)
- # mcts_ctree = False
+ collector_env_num = 8
+ n_episode = 8
+ evaluator_env_num = 5
+ num_simulations = 400
+ update_per_collect = 200
+ batch_size = 512
+ max_env_step = int(1e6)
+ mcts_ctree = False

# TODO: for debug
- collector_env_num = 2
- n_episode = 2
- evaluator_env_num = 2
- num_simulations = 4
- update_per_collect = 2
- batch_size = 2
- max_env_step = int(1e4)
- mcts_ctree = False
+ # collector_env_num = 2
+ # n_episode = 2
+ # evaluator_env_num = 2
+ # num_simulations = 4
+ # update_per_collect = 2
+ # batch_size = 2
+ # max_env_step = int(1e4)
+ # mcts_ctree = False
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
@@ -56,7 +56,7 @@
model=dict(
observation_shape=(8, 8, 20),
action_space_size=int(8 * 8 * 73),
- # TODO: for debug
+ # TODO: only for debug
num_res_blocks=1,
num_channels=1,
value_head_hidden_channels=[16],
24 changes: 12 additions & 12 deletions zoo/board_games/chess/config/chess_alphazero_sp_mode_config.py
@@ -10,22 +10,22 @@
update_per_collect = 200
batch_size = 512
max_env_step = int(1e6)
- mcts_ctree = True
+ mcts_ctree = False

# TODO: for debug
# collector_env_num = 2
# n_episode = 2
# evaluator_env_num = 2
- # num_simulations = 4
- # update_per_collect = 2
+ # num_simulations = 2
+ # update_per_collect = 1
# batch_size = 2
# max_env_step = int(1e4)
# mcts_ctree = False
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
chess_alphazero_config = dict(
- exp_name='data_az_ctree/chess_sp-mode_alphazero_seed0',
+ exp_name='data_az_ptree/chess_sp-mode_alphazero_seed0',
env=dict(
board_size=8,
battle_mode='self_play_mode',
@@ -56,14 +56,14 @@
observation_shape=(8, 8, 20),
action_space_size=int(8 * 8 * 73),
# TODO: for debug
- num_res_blocks=1,
- num_channels=1,
- value_head_hidden_channels=[16],
- policy_head_hidden_channels=[16],
- # num_res_blocks=8,
- # num_channels=256,
- # value_head_hidden_channels=[256, 256],
- # policy_head_hidden_channels=[256, 256],
+ # num_res_blocks=1,
+ # num_channels=1,
+ # value_head_hidden_channels=[16],
+ # policy_head_hidden_channels=[16],
+ num_res_blocks=8,
+ num_channels=256,
+ value_head_hidden_channels=[256, 256],
+ policy_head_hidden_channels=[256, 256],
),
cuda=True,
board_size=8,
47 changes: 27 additions & 20 deletions zoo/board_games/chess/envs/chess_lightzero_env.py
@@ -10,9 +10,8 @@
from ding.envs.env.base_env import BaseEnvTimestep
from ding.utils.registry_factory import ENV_REGISTRY
from gymnasium import spaces
- from pettingzoo.classic.chess import chess_utils

- from zoo.board_games.chess.envs.chess_env import ChessEnv
+ from pettingzoo.classic.chess import chess_utils as pz_cu


@ENV_REGISTRY.register('chess_lightzero')
@@ -50,13 +49,14 @@ def __init__(self, cfg=None):

@property
def legal_actions(self):
- return chess_utils.legal_moves(self.board)
+ return pz_cu.legal_moves(self.board)

def observe(self, agent_index):
try:
- observation = chess_utils.get_observation(self.board, agent_index).astype(float) # TODO
+ observation = pz_cu.get_observation(self.board, agent_index).astype(float) # TODO
except Exception as e:
- print('debug')
+ print(f'debug: {e}')
+ print(f"self.board:{self.board}")

# TODO:
# observation = np.dstack((observation[:, :, :7], self.board_history))
@@ -73,9 +73,12 @@ def observe(self, agent_index):
# observation[..., 13 * i : 13 * i + 6] = tmp

action_mask = np.zeros(4672, dtype=np.int8)
- action_mask[chess_utils.legal_moves(self.board)] = 1
+ action_mask[pz_cu.legal_moves(self.board)] = 1
return {'observation': observation, 'action_mask': action_mask}




def current_state(self):
"""
Overview:
@@ -101,18 +104,14 @@ def get_done_winner(self):
if result == "*":
winner = -1
else:
- winner = chess_utils.result_to_int(result)
+ winner = pz_cu.result_to_int(result)

if not done:
winner = -1

return done, winner

def reset(self, start_player_index=0, init_state=None, katago_policy_init=False, katago_game_state=None):
- if self.alphazero_mcts_ctree and init_state is not None:
-     # Convert byte string to np.ndarray
-     init_state = np.frombuffer(init_state, dtype=np.int32)

if self.scale:
self._observation_space = spaces.Dict(
{
@@ -131,13 +130,21 @@ def reset(self, start_player_index=0, init_state=None, katago_policy_init=False,
self._reward_space = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
self.start_player_index = start_player_index
self._current_player = self.players[self.start_player_index]

if init_state is not None:
- self.board = chess.Board(init_state)
+ if isinstance(init_state, np.ndarray):
+     # ndarray → bytes → str
+     fen = init_state.tobytes().decode()
+ elif isinstance(init_state, (bytes, bytearray)):
+     fen = init_state.decode()
+ else:  # init_state is str
+     fen = init_state
+ self.board = chess.Board(fen)
else:
self.board = chess.Board()

action_mask = np.zeros(4672, dtype=np.int8)
- action_mask[chess_utils.legal_moves(self.board)] = 1
+ action_mask[pz_cu.legal_moves(self.board)] = 1
# self.board_history = np.zeros((8, 8, 104), dtype=bool)

if self.battle_mode == 'play_with_bot_mode' or self.battle_mode == 'eval_mode':
@@ -259,10 +266,10 @@ def _player_step(self, action):
current_agent = self.current_player_index

# TODO: Update board history
- # next_board = chess_utils.get_observation(self.board, current_agent)
+ # next_board = pz_cu.get_observation(self.board, current_agent)
# self.board_history = np.dstack((next_board[:, :, 7:], self.board_history[:, :, :-13]))

- chosen_move = chess_utils.action_to_move(self.board, action, current_agent)
+ chosen_move = pz_cu.action_to_move(self.board, action, current_agent)
assert chosen_move in self.board.legal_moves
self.board.push(chosen_move)

@@ -271,7 +278,7 @@ def _player_step(self, action):
if result == "*":
reward = 0.
else:
- reward = chess_utils.result_to_int(result)
+ reward = pz_cu.result_to_int(result)

if self.current_player == 1:
reward = -reward
@@ -281,7 +288,7 @@ def _player_step(self, action):
info['eval_episode_return'] = reward

action_mask = np.zeros(4672, dtype=np.int8)
- action_mask[chess_utils.legal_moves(self.board)] = 1
+ action_mask[pz_cu.legal_moves(self.board)] = 1

obs = {
'observation': self.observe(self.current_player_index)['observation'],
@@ -312,14 +319,14 @@ def current_player(self, value):
self._current_player = value

def random_action(self):
- action_list = chess_utils.legal_moves(self.board)
+ action_list = pz_cu.legal_moves(self.board)
return np.random.choice(action_list)

def simulate_action(self, action):
- if action not in chess_utils.legal_moves(self.board):
+ if action not in pz_cu.legal_moves(self.board):
raise ValueError("action {0} on board {1} is not legal".format(action, self.board.fen()))
new_board = copy.deepcopy(self.board)
- new_board.push(chess_utils.action_to_move(self.board, action, self.current_player_index))
+ new_board.push(pz_cu.action_to_move(self.board, action, self.current_player_index))
if self.start_player_index == 0:
start_player_index = 1
else:
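With the widened `reset` above, the chess env accepts the initial position as a FEN string, as raw bytes, or as the int8 array produced by the policy. A standalone sketch that mirrors the new `init_state` handling, written for illustration only (the helper name is not part of the env):

```python
import numpy as np
import chess

def _normalize_init_state(init_state) -> chess.Board:
    """Standalone mirror of the init_state handling added to reset() above (illustrative)."""
    if isinstance(init_state, np.ndarray):
        fen = init_state.tobytes().decode()        # int8 array -> bytes -> FEN str
    elif isinstance(init_state, (bytes, bytearray)):
        fen = init_state.decode()                  # bytes -> FEN str
    else:
        fen = init_state                           # already a FEN str
    return chess.Board(fen)

# All three forms produce the same starting position:
fen = chess.Board().fen()
boards = [_normalize_init_state(s) for s in
          (fen, fen.encode(), np.frombuffer(fen.encode(), dtype=np.int8))]
assert all(b.fen() == fen for b in boards)
```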