diff --git a/README.md b/README.md index 9977c9423..0632c518d 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ [![GitHub license](https://img.shields.io/github/license/opendilab/LightZero)](https://github.com/opendilab/LightZero/blob/master/LICENSE) [![discord badge](https://dcbadge.vercel.app/api/server/dkZS2JF56X?style=flat)](https://discord.gg/dkZS2JF56X) -Updated on 2025.04.09 LightZero-v0.2.0 +Updated on 2025.06.03 LightZero-v0.2.0 English | [简体中文(Simplified Chinese)](https://github.com/opendilab/LightZero/blob/main/README.zh.md) | [Documentation](https://opendilab.github.io/LightZero) | [LightZero Paper](https://arxiv.org/abs/2310.08348) | [🔥UniZero Paper](https://arxiv.org/abs/2406.10667) | [🔥ReZero Paper](https://arxiv.org/abs/2404.16364) diff --git a/README.zh.md b/README.zh.md index 064e0c200..5fa336d2a 100644 --- a/README.zh.md +++ b/README.zh.md @@ -27,7 +27,7 @@ [![Contributors](https://img.shields.io/github/contributors/opendilab/LightZero)](https://github.com/opendilab/LightZero/graphs/contributors) [![GitHub license](https://img.shields.io/github/license/opendilab/LightZero)](https://github.com/opendilab/LightZero/blob/master/LICENSE) -最近更新于 2025.04.09 LightZero-v0.2.0 +最近更新于 2025.06.03 LightZero-v0.2.0 [English](https://github.com/opendilab/LightZero/blob/main/README.md) | 简体中文 | [文档](https://opendilab.github.io/LightZero) | [LightZero 论文](https://arxiv.org/abs/2310.08348) | [🔥UniZero 论文](https://arxiv.org/abs/2406.10667) | [🔥ReZero 论文](https://arxiv.org/abs/2404.16364) diff --git a/docs/source/tutorials/algos/customize_algos.md b/docs/source/tutorials/algos/customize_algos.md index 88e48513c..c4d6bf2c4 100644 --- a/docs/source/tutorials/algos/customize_algos.md +++ b/docs/source/tutorials/algos/customize_algos.md @@ -119,16 +119,17 @@ Here is an example of unit testing in LightZero. In this example, we test the `i ```Python import pytest import torch -from lzero.policy.scaling_transform import inverse_scalar_transform, InverseScalarTransform +from lzero.policy.scaling_transform import DiscreteSupport, inverse_scalar_transform, InverseScalarTransform @pytest.mark.unittest def test_scaling_transform(): import time logit = torch.randn(16, 601) + discrete_support = DiscreteSupport(-300., 301., 1.) start = time.time() - output_1 = inverse_scalar_transform(logit, 300) + output_1 = inverse_scalar_transform(logit, discrete_support) print('t1', time.time() - start) - handle = InverseScalarTransform(300) + handle = InverseScalarTransform(discrete_support) start = time.time() output_2 = handle(logit) print('t2', time.time() - start) diff --git a/docs/source/tutorials/algos/customize_algos_zh.md b/docs/source/tutorials/algos/customize_algos_zh.md index 4d115aefa..b06c38f68 100644 --- a/docs/source/tutorials/algos/customize_algos_zh.md +++ b/docs/source/tutorials/algos/customize_algos_zh.md @@ -120,16 +120,17 @@ if timestep.done: ```Python import pytest import torch -from lzero.policy.scaling_transform import inverse_scalar_transform, InverseScalarTransform +from lzero.policy.scaling_transform import DiscreteSupport, inverse_scalar_transform, InverseScalarTransform @pytest.mark.unittest def test_scaling_transform(): import time logit = torch.randn(16, 601) + discrete_support = DiscreteSupport(-300., 301., 1.) 
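    # Illustrative note (assumption): if DiscreteSupport enumerates its atoms like
    # numpy.arange(start, stop, step), then (-300., 301., 1.) defines 601 atoms,
    # matching the 601 logits in torch.randn(16, 601) above, while (-10., 11., 1.)
    # would define the 21 atoms of the former reward/value_support_size=21 configs.
    assert int((301. - (-300.)) / 1.) == 601  # hypothetical sanity check on the atom count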
     start = time.time()
-    output_1 = inverse_scalar_transform(logit, 300)
+    output_1 = inverse_scalar_transform(logit, discrete_support)
     print('t1', time.time() - start)
-    handle = InverseScalarTransform(300)
+    handle = InverseScalarTransform(discrete_support)
     start = time.time()
     output_2 = handle(logit)
     print('t2', time.time() - start)
diff --git a/docs/source/tutorials/config/config.md b/docs/source/tutorials/config/config.md
index f868c8053..06a908815 100644
--- a/docs/source/tutorials/config/config.md
+++ b/docs/source/tutorials/config/config.md
@@ -44,7 +44,8 @@ The `main_config` dictionary contains the main parameter settings for running th
   - `downsample`: Whether to downsample the input.
   - `norm_type`: The type of normalization used.
   - `num_channels`: The number of channels in the convolutional layers (number of features extracted).
-  - `support_scale`: The range of the value support set (`-support_scale` to `support_scale`).
+  - `reward_support_range`: The range of the reward support set (`(start, stop, step)`).
+  - `value_support_range`: The range of the value support set (`(start, stop, step)`).
   - `bias`: Whether to use bias terms in the layers.
   - `discrete_action_encoding_type`: How discrete actions are encoded.
   - `self_supervised_learning_loss`: Whether to use a self-supervised learning loss (as in EfficientZero).
diff --git a/docs/source/tutorials/config/config_zh.md b/docs/source/tutorials/config/config_zh.md
index 824b44799..5068c71ac 100644
--- a/docs/source/tutorials/config/config_zh.md
+++ b/docs/source/tutorials/config/config_zh.md
@@ -43,7 +43,8 @@
   - `downsample`: 是否进行降采样。
   - `norm_type`: 归一化使用的方法。
   - `num_channels`: 卷积层提取的特征个数。
-  - `support_scale`: 价值支持集的范围 (-support_scale, support_scale)。
+  - `reward_support_range`: 奖励支持集的范围 (`(start, stop, step)`)。
+  - `value_support_range`: 价值支持集的范围 (`(start, stop, step)`)。
   - `bias`: 是否使用偏置。
   - `discrete_action_encoding_type`: 离散化动作空间使用的编码类型。
   - `self_supervised_learning_loss`: 是否使用自监督学习损失(参照EfficientZero的实现)。
diff --git a/lzero/agent/config/gumbel_muzero/gomoku_play_with_bot.py b/lzero/agent/config/gumbel_muzero/gomoku_play_with_bot.py
index 40e04834e..c840b1df4 100644
--- a/lzero/agent/config/gumbel_muzero/gomoku_play_with_bot.py
+++ b/lzero/agent/config/gumbel_muzero/gomoku_play_with_bot.py
@@ -44,9 +44,8 @@
             image_channel=3,
             num_res_blocks=1,
             num_channels=32,
-            support_scale=10,
-            reward_support_size=21,
-            value_support_size=21,
+            reward_support_range=(-10., 11., 1.),
+            value_support_range=(-10., 11., 1.),
         ),
         cuda=True,
         env_type='board_games',
diff --git a/lzero/agent/config/gumbel_muzero/tictactoe_play_with_bot.py b/lzero/agent/config/gumbel_muzero/tictactoe_play_with_bot.py
index 865bf49ea..aeb5c27ac 100644
--- a/lzero/agent/config/gumbel_muzero/tictactoe_play_with_bot.py
+++ b/lzero/agent/config/gumbel_muzero/tictactoe_play_with_bot.py
@@ -38,9 +38,8 @@
             reward_head_hidden_channels=[8],
             value_head_hidden_channels=[8],
             policy_head_hidden_channels=[8],
-            support_scale=10,
-            reward_support_size=21,
-            value_support_size=21,
+            reward_support_range=(-10., 11., 1.),
+            value_support_range=(-10., 11., 1.),
         ),
         cuda=True,
         env_type='board_games',
diff --git a/lzero/agent/config/muzero/gomoku_play_with_bot.py b/lzero/agent/config/muzero/gomoku_play_with_bot.py
index 7158a7fa5..d6db6042e 100644
--- a/lzero/agent/config/muzero/gomoku_play_with_bot.py
+++ b/lzero/agent/config/muzero/gomoku_play_with_bot.py
@@ -44,9 +44,8 @@
             image_channel=3,
             num_res_blocks=1,
             num_channels=32,
-            support_scale=10,
-            reward_support_size=21,
-            value_support_size=21,
+
reward_support_range=(-10., 11., 1.), + value_support_range=(-10., 11., 1.), ), cuda=True, env_type='board_games', diff --git a/lzero/agent/config/muzero/tictactoe_play_with_bot.py b/lzero/agent/config/muzero/tictactoe_play_with_bot.py index 6e16f5e02..531978cfd 100644 --- a/lzero/agent/config/muzero/tictactoe_play_with_bot.py +++ b/lzero/agent/config/muzero/tictactoe_play_with_bot.py @@ -38,9 +38,8 @@ reward_head_hidden_channels=[8], value_head_hidden_channels=[8], policy_head_hidden_channels=[8], - support_scale=10, - reward_support_size=21, - value_support_size=21, + reward_support_range=(-10., 11., 1.), + value_support_range=(-10., 11., 1.), norm_type='BN', ), cuda=True, diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index 68f3a66aa..2a269a261 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -1,5 +1,6 @@ from .eval_alphazero import eval_alphazero from .eval_muzero import eval_muzero + from .eval_muzero_with_gym_env import eval_muzero_with_gym_env from .train_alphazero import train_alphazero from .train_muzero import train_muzero @@ -12,4 +13,5 @@ from .train_muzero_multitask_segment_ddp import train_muzero_multitask_segment_ddp from .train_unizero_multitask_segment_ddp import train_unizero_multitask_segment_ddp from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval -from .utils import * +from .train_unizero_multitask_balance_segment_ddp import train_unizero_multitask_balance_segment_ddp +from .utils import * \ No newline at end of file diff --git a/lzero/entry/compute_task_weight.py b/lzero/entry/compute_task_weight.py deleted file mode 100644 index 84204a9a2..000000000 --- a/lzero/entry/compute_task_weight.py +++ /dev/null @@ -1,80 +0,0 @@ - - - -import numpy as np -import torch - - -def symlog(x: torch.Tensor) -> torch.Tensor: - """ - Symlog 归一化,减少目标值的幅度差异。 - symlog(x) = sign(x) * log(|x| + 1) - """ - return torch.sign(x) * torch.log(torch.abs(x) + 1) - - -def inv_symlog(x: torch.Tensor) -> torch.Tensor: - """ - Symlog 的逆操作,用于恢复原始值。 - inv_symlog(x) = sign(x) * (exp(|x|) - 1) - """ - return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) - - -def compute_task_weights( - task_rewards: dict, - epsilon: float = 1e-6, - min_weight: float = 0.1, - max_weight: float = 0.5, - temperature: float = 1.0, - use_symlog: bool = True, -) -> dict: - """ - 改进后的任务权重计算函数,加入 symlog 处理和鲁棒性设计。 - - Args: - task_rewards (dict): 每个任务的字典,键为 task_id,值为评估奖励。 - epsilon (float): 避免分母为零的小值。 - min_weight (float): 权重的最小值,用于裁剪。 - max_weight (float): 权重的最大值,用于裁剪。 - temperature (float): 控制权重分布的温度系数。 - use_symlog (bool): 是否使用 symlog 对 task_rewards 进行矫正。 - - Returns: - dict: 每个任务的权重,键为 task_id,值为归一化并裁剪后的权重。 - """ - # Step 1: 矫正奖励值(可选,使用 symlog) - if use_symlog: - rewards_tensor = torch.tensor(list(task_rewards.values()), dtype=torch.float32) - corrected_rewards = symlog(rewards_tensor).numpy() # 使用 symlog 矫正 - task_rewards = dict(zip(task_rewards.keys(), corrected_rewards)) - - # Step 2: 计算初始权重(反比例关系) - raw_weights = {task_id: 1 / (reward + epsilon) for task_id, reward in task_rewards.items()} - - # Step 3: 温度缩放 - scaled_weights = {task_id: weight ** (1 / temperature) for task_id, weight in raw_weights.items()} - - # Step 4: 归一化权重 - total_weight = sum(scaled_weights.values()) - normalized_weights = {task_id: weight / total_weight for task_id, weight in scaled_weights.items()} - - # Step 5: 裁剪权重,确保在 [min_weight, max_weight] 范围内 - clipped_weights = {task_id: np.clip(weight, min_weight, max_weight) for task_id, weight in normalized_weights.items()} - - 
final_weights = clipped_weights - return final_weights - -task_rewards_list = [ - {"task1": 10, "task2": 100, "task3": 1000, "task4": 500, "task5": 300}, - {"task1": 1, "task2": 10, "task3": 100, "task4": 1000, "task5": 10000}, - {"task1": 0.1, "task2": 0.5, "task3": 0.9, "task4": 5, "task5": 10}, -] - -for i, task_rewards in enumerate(task_rewards_list, start=1): - print(f"Case {i}: Original Rewards: {task_rewards}") - print("Original Weights:") - print(compute_task_weights(task_rewards, use_symlog=False)) - print("Improved Weights with Symlog:") - print(compute_task_weights(task_rewards, use_symlog=True)) - print() \ No newline at end of file diff --git a/lzero/entry/eval_muzero.py b/lzero/entry/eval_muzero.py index 4501499ba..6f87c656e 100644 --- a/lzero/entry/eval_muzero.py +++ b/lzero/entry/eval_muzero.py @@ -1,6 +1,7 @@ import os from functools import partial from typing import Optional, Tuple +import logging import numpy as np import torch @@ -51,7 +52,7 @@ def eval_muzero( # Create main components: env, policy env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) - + # print(f"cfg.seed:{cfg.seed}") evaluator_env.seed(cfg.seed, dynamic_seed=False) set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) diff --git a/lzero/entry/train_muzero_multitask_segment_ddp.py b/lzero/entry/train_muzero_multitask_segment_ddp.py index 5ece29f28..5d608271a 100644 --- a/lzero/entry/train_muzero_multitask_segment_ddp.py +++ b/lzero/entry/train_muzero_multitask_segment_ddp.py @@ -1,32 +1,33 @@ +import concurrent.futures import logging import os from functools import partial -from typing import Tuple, Optional, List +from typing import Any, Dict, List, Optional, Tuple -import torch import numpy as np +import torch +import torch.distributed as dist from ding.config import compile_config from ding.envs import create_env_manager, get_vec_env_setting -from ding.policy import create_policy +from ding.policy import Policy, create_policy from ding.rl_utils import get_epsilon_greedy_fn -from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.utils import EasyTimer, set_pkg_seed, get_rank, get_world_size from ding.worker import BaseLearner from tensorboardX import SummaryWriter from lzero.entry.utils import log_buffer_memory_usage -from lzero.policy import visit_count_temperature from lzero.mcts import MuZeroGameBuffer as GameBuffer +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroCollector as Collector from lzero.worker import MuZeroEvaluator as Evaluator -from lzero.worker import MuZeroSegmentCollector as Collector -from ding.utils import EasyTimer -import torch.distributed as dist -import concurrent.futures +# ========================== +# Global Constants +# ========================== +EVALUATION_TIMEOUT_SECONDS: int = 3600 +MAX_TRAIN_ITER_INF: int = int(1e10) +MAX_ENV_STEP_INF: int = int(1e10) -# ========== 超时时间设置 ========== -TIMEOUT = 3600 # 例如,60分钟 - -timer = EasyTimer() def safe_eval( evaluator: Evaluator, @@ -36,547 +37,527 @@ def safe_eval( world_size: int ) -> Tuple[Optional[bool], Optional[float]]: """ - 安全地执行评估操作,防止因超时导致训练过程阻塞。 - - Args: - evaluator (Evaluator): 评估器实例。 - learner (BaseLearner): 学习器实例。 - collector (Collector): 数据收集器实例。 - rank (int): 当前进程的排名。 - world_size (int): 总进程数。 - + Overview: + Safely performs an evaluation step with a timeout to prevent the training process from blocking. 
+ Arguments: + - evaluator (:obj:`Evaluator`): The evaluator instance. + - learner (:obj:`BaseLearner`): The learner instance to save checkpoints. + - collector (:obj:`Collector`): The collector instance to get the current envstep. + - rank (:obj:`int`): The rank of the current process. + - world_size (:obj:`int`): The total number of processes. Returns: - Tuple[Optional[bool], Optional[float]]: - - stop (Optional[bool]): 评估是否停止的标志。 - - reward (Optional[float]): 评估得到的奖励。 + - (:obj:`Tuple[Optional[bool], Optional[float]]`): A tuple containing the stop flag and the evaluation reward. + Returns (None, None) if a timeout occurs. """ - print(f"=========评估前 Rank {rank}/{world_size}===========") - # 重置 stop_event,确保每次评估前都处于未设置状态 + logging.info(f"Rank {rank}/{world_size}: Starting evaluation...") + # Ensure the stop_event is clear before each evaluation. evaluator.stop_event.clear() with concurrent.futures.ThreadPoolExecutor() as executor: - # 提交 evaluator.eval 任务 future = executor.submit( evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep ) - try: - stop, reward = future.result(timeout=TIMEOUT) + stop, reward = future.result(timeout=EVALUATION_TIMEOUT_SECONDS) + logging.info(f"Rank {rank}/{world_size}: Evaluation finished successfully. Stop: {stop}, Reward: {reward}") + return stop, reward except concurrent.futures.TimeoutError: - # 超时,设置 evaluator 的 stop_event + # Set the evaluator's stop_event on timeout to gracefully stop the evaluation worker. evaluator.stop_event.set() - print(f"评估操作在 Rank {rank}/{world_size} 上超过 {TIMEOUT} 秒超时。") + logging.warning( + f"Rank {rank}/{world_size}: Evaluation timed out after {EVALUATION_TIMEOUT_SECONDS} seconds. " + f"Continuing training." + ) return None, None - print(f"======评估后 Rank {rank}/{world_size}======") - return stop, reward - def allocate_batch_size( - cfgs: List, + cfgs: List[Any], game_buffers: List[GameBuffer], alpha: float = 1.0, - clip_scale: int = 1 + clip_scale: float = 1.0 ) -> List[int]: """ - 根据不同任务的 num_of_collected_episodes 反比分配 batch_size, - 并动态调整 batch_size 限制范围以提高训练的稳定性和效率。 - - Args: - cfgs (List): 每个任务的配置列表。 - game_buffers (List[GameBuffer]): 每个任务的 replay_buffer 实例列表。 - alpha (float): 控制反比程度的超参数 (默认为1.0)。 - clip_scale (int): 动态调整的缩放因子 (默认为1)。 - + Overview: + Allocates batch sizes for different tasks inversely proportional to their number of collected episodes. + This method dynamically adjusts the batch size range to enhance training stability and efficiency. + Arguments: + - cfgs (:obj:`List[Any]`): A list of configuration objects for each task. + - game_buffers (:obj:`List[GameBuffer]`): A list of replay buffer instances for each task. + - alpha (:obj:`float`): A hyperparameter to control the degree of inverse proportionality. Defaults to 1.0. + - clip_scale (:obj:`float`): A scaling factor for dynamic adjustment of min/max batch size. Defaults to 1.0. Returns: - List[int]: 分配后的 batch_size 列表。 + - (:obj:`List[int]`): A list of allocated batch sizes for each task. 
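    Example:
        An illustrative allocation (assuming ``alpha=1.0``, ``clip_scale=4`` and a global
        batch size of 512): episode counts [0, 9, 99] give inverse weights [1.0, 0.1, 0.01],
        normalized to [0.901, 0.090, 0.009], i.e. raw sizes of roughly [461, 46, 5], which
        are clipped to the range [42.7, 682.7] and returned as [461, 46, 42].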
""" - # 提取每个任务的 num_of_collected_episodes - buffer_num_of_collected_episodes = [ - buffer.num_of_collected_episodes for buffer in game_buffers - ] - - # 获取当前的 world_size 和 rank - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - - # 收集所有 rank 的 num_of_collected_episodes 列表 - all_task_num_of_collected_episodes = [None for _ in range(world_size)] - torch.distributed.all_gather_object( - all_task_num_of_collected_episodes, - buffer_num_of_collected_episodes - ) + # Step 1: Gather the number of collected episodes from all buffers on the current rank. + buffer_num_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] - # 将所有 rank 的 num_of_collected_episodes 拼接成一个大列表 - all_task_num_of_collected_episodes = [ - item for sublist in all_task_num_of_collected_episodes for item in sublist - ] + world_size = get_world_size() + rank = get_rank() + + # Step 2: Gather episode counts from all tasks across all ranks. + all_task_num_episodes = [None for _ in range(world_size)] + dist.all_gather_object(all_task_num_episodes, buffer_num_episodes) + + # Flatten the list of lists into a single list. + flat_task_num_episodes = [item for sublist in all_task_num_episodes for item in sublist] if rank == 0: - print(f'all_task_num_of_collected_episodes: {all_task_num_of_collected_episodes}') + logging.info(f'Number of collected episodes per task (all ranks): {flat_task_num_episodes}') - # 计算每个任务的反比权重 - inv_episodes = np.array([ - 1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes - ]) + # Step 3: Calculate inverse proportional weights. Add 1 to avoid division by zero. + inv_episodes = np.array([1.0 / (episodes + 1) for episodes in flat_task_num_episodes]) inv_sum = np.sum(inv_episodes) - # 计算总的 batch_size (所有任务 cfg.policy.max_batch_size 的和) - max_batch_size = cfgs[0].policy.max_batch_size + # Step 4: Calculate the total batch size from the config of the first task. + # Assumption: max_batch_size is the same across all task configs and represents the global batch size. + global_batch_size = cfgs[0].policy.max_batch_size - # 动态调整的部分:最小和最大的 batch_size 范围 - avg_batch_size = max_batch_size / world_size - min_batch_size = avg_batch_size / clip_scale - max_batch_size = avg_batch_size * clip_scale + # Step 5: Dynamically adjust the min and max batch size bounds. + avg_batch_size = global_batch_size / len(flat_task_num_episodes) + min_batch_size = max(1, avg_batch_size / clip_scale) # Ensure min_batch_size is at least 1. + max_batch_size_clip = avg_batch_size * clip_scale - # 动态调整 alpha,让 batch_size 的变化更加平滑 + # Step 6: Calculate batch sizes based on weights and apply clipping. task_weights = (inv_episodes / inv_sum) ** alpha - batch_sizes = max_batch_size * task_weights + # Note: The original code used max_batch_size, which seems to be a typo. + # It should be global_batch_size to distribute the total batch size. + batch_sizes = global_batch_size * task_weights + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size_clip) - # 控制 batch_size 在 [min_batch_size, max_batch_size] 之间 - batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + # Ensure batch sizes are integers. 
+ final_batch_sizes = [int(size) for size in batch_sizes] - # 确保 batch_size 是整数 - batch_sizes = [int(size) for size in batch_sizes] + if rank == 0: + logging.info(f"Allocated batch sizes: {final_batch_sizes}") - # 返回最终分配的 batch_size 列表 - return batch_sizes + return final_batch_sizes -def train_muzero_multitask_segment_ddp( - input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], - seed: int = 0, - model: Optional[torch.nn.Module] = None, - model_path: Optional[str] = None, - max_train_iter: Optional[int] = int(1e10), - max_env_step: Optional[int] = int(1e10), -) -> 'Policy': +class MuZeroMultiTaskTrainer: """ Overview: - The train entry for multi-task MuZero, adapted from UniZero's multi-task training. - This script aims to enhance the planning capabilities of reinforcement learning agents - by leveraging multi-task learning to address diverse environments. - - Args: - input_cfg_list (List[Tuple[int, Tuple[dict, dict]]]): - Configurations for different tasks as a list of tuples containing task ID and configuration dictionaries. - seed (int): - Random seed for reproducibility. - model (Optional[torch.nn.Module]): - Predefined model instance. If provided, it will be used instead of creating a new one. - model_path (Optional[str]): - Path to the pretrained model checkpoint. Should point to the ckpt file of the pretrained model. - max_train_iter (Optional[int]): - Maximum number of training iterations. Defaults to 1e10. - max_env_step (Optional[int]): - Maximum number of environment interaction steps. Defaults to 1e10. - - Returns: - Policy: - The trained policy instance. + A trainer class to manage the multi-task training loop for MuZero. + It encapsulates the state and logic for initialization, data collection, + evaluation, training, and termination. """ - # 获取当前进程的 rank 和总的进程数 - rank = get_rank() - world_size = get_world_size() - - # 任务划分 - total_tasks = len(input_cfg_list) - tasks_per_rank = total_tasks // world_size - remainder = total_tasks % world_size - - if rank < remainder: - start_idx = rank * (tasks_per_rank + 1) - end_idx = start_idx + tasks_per_rank + 1 - else: - start_idx = rank * tasks_per_rank + remainder - end_idx = start_idx + tasks_per_rank - - tasks_for_this_rank = input_cfg_list[start_idx:end_idx] - - # 确保至少有一个任务 - if len(tasks_for_this_rank) == 0: - logging.warning(f"Rank {rank}: 未分配任何任务,继续运行但无任务处理。") - # 初始化一些空列表以避免后续代码报错 - cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], [] - return - - print(f"Rank {rank}/{world_size}, 处理任务 {start_idx} 到 {end_idx - 1}") - - cfgs = [] - game_buffers = [] - collector_envs = [] - evaluator_envs = [] - collectors = [] - evaluators = [] - - # 使用第一个任务的配置来创建共享的 policy - task_id, [cfg, create_cfg] = tasks_for_this_rank[0] - - # 设置每个任务的随机种子和任务编号 - for config in tasks_for_this_rank: - config[1][0].policy.task_num = len(tasks_for_this_rank) - - # 根据 CUDA 可用性设置设备 - cfg.policy.device = cfg.policy.model.device if torch.cuda.is_available() else 'cpu' - logging.info(f'cfg.policy.device: {cfg.policy.device}') - - # 编译配置 - cfg = compile_config( - cfg, - seed=seed, - env=None, - auto=True, - create_cfg=create_cfg, - save_cfg=True - ) - # 创建共享的 policy - policy = create_policy( - cfg.policy, - model=model, - enable_field=['learn', 'collect', 'eval'] - ) - - # 如果指定了预训练模型,则加载 - if model_path is not None: - logging.info(f'开始加载模型来自 {model_path}...') - policy.learn_mode.load_state_dict( - torch.load(model_path, map_location=cfg.policy.device) - ) - logging.info(f'完成加载模型来自 {model_path}.') - - # 创建 TensorBoard 
的日志记录器 - log_dir = os.path.join(f'./{cfg.exp_name}/log', f'serial_rank_{rank}') - tb_logger = SummaryWriter(log_dir) - - # 创建共享的 learner - learner = BaseLearner( - cfg.policy.learn.learner, - policy.learn_mode, - tb_logger, - exp_name=cfg.exp_name - ) - - policy_config = cfg.policy - batch_size = policy_config.batch_size[0] - - # 只处理当前进程分配到的任务 - for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): - # 设置每个任务自己的随机种子 - cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' - cfg = compile_config( - cfg, - seed=seed + task_id, - env=None, - auto=True, - create_cfg=create_cfg, - save_cfg=True - ) - policy_config = cfg.policy - policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode - policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode - - env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) - collector_env = create_env_manager( - cfg.env.manager, - [partial(env_fn, cfg=c) for c in collector_env_cfg] - ) - evaluator_env = create_env_manager( - cfg.env.manager, - [partial(env_fn, cfg=c) for c in evaluator_env_cfg] - ) - collector_env.seed(cfg.seed + task_id) - evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) - set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) - - # 为每个任务创建不同的 game buffer、collector、evaluator - replay_buffer = GameBuffer(policy_config) - collector = Collector( - env=collector_env, - policy=policy.collect_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - policy_config=policy_config, - task_id=task_id - ) - evaluator = Evaluator( - eval_freq=cfg.policy.eval_freq, - n_evaluator_episode=cfg.env.n_evaluator_episode, - stop_value=cfg.env.stop_value, - env=evaluator_env, - policy=policy.eval_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - policy_config=policy_config, - task_id=task_id - ) - cfgs.append(cfg) - replay_buffer.batch_size = cfg.policy.batch_size[task_id] - - game_buffers.append(replay_buffer) - collector_envs.append(collector_env) - evaluator_envs.append(evaluator_env) - collectors.append(collector) - evaluators.append(evaluator) - - learner.call_hook('before_run') - value_priority_tasks = {} - - buffer_reanalyze_count = 0 - train_epoch = 0 - reanalyze_batch_size = cfg.policy.reanalyze_batch_size - update_per_collect = cfg.policy.update_per_collect - - while True: - torch.cuda.empty_cache() - - if cfg.policy.allocated_batch_sizes: - # TODO========== - # 线性变化的 随着 train_epoch 从 0 增加到 1000, clip_scale 从 1 线性增加到 4 - clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) - allocated_batch_sizes = allocate_batch_size( - cfgs, - game_buffers, - alpha=1.0, - clip_scale=clip_scale + def __init__( + self, + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int, + model: Optional[torch.nn.Module], + model_path: Optional[str], + max_train_iter: int, + max_env_step: int, + ) -> None: + """ + Overview: + Initializes the multi-task trainer. + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): Configs for all tasks. + - seed (:obj:`int`): The base random seed. + - model (:obj:`Optional[torch.nn.Module]`): An optional pre-defined model. + - model_path (:obj:`Optional[str]`): Path to a pre-trained model checkpoint. + - max_train_iter (:obj:`int`): Maximum training iterations. + - max_env_step (:obj:`int`): Maximum environment steps. 
+ """ + self.max_train_iter = max_train_iter + self.max_env_step = max_env_step + self.seed = seed + self.rank = get_rank() + self.world_size = get_world_size() + self.timer = EasyTimer() + + # State variables + self.train_epoch = 0 + self.buffer_reanalyze_count = 0 + self.value_priority_tasks = {} + + # Task partitioning + self.tasks_for_this_rank = self._partition_tasks(input_cfg_list) + if not self.tasks_for_this_rank: + logging.warning(f"Rank {self.rank}: No tasks assigned. Process will run without tasks.") + self.is_active = False + return + self.is_active = True + + # Initialize shared components (Policy, Learner) + self.policy, self.learner, self.tb_logger = self._initialize_shared_components(model, model_path) + + # Initialize task-specific components + ( + self.cfgs, self.game_buffers, self.collectors, self.evaluators + ) = self._initialize_task_specific_components() + + self.update_per_collect = self.cfgs[0].policy.update_per_collect + + def _partition_tasks(self, input_cfg_list: List[Tuple[int, Tuple[dict, dict]]]) -> List[ + Tuple[int, Tuple[dict, dict]]]: + """Partitions tasks among distributed processes.""" + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // self.world_size + remainder = total_tasks % self.world_size + + if self.rank < remainder: + start_idx = self.rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = self.rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + logging.info(f"Rank {self.rank}/{self.world_size} is assigned tasks from index {start_idx} to {end_idx - 1}.") + return input_cfg_list[start_idx:end_idx] + + def _initialize_shared_components(self, model: Optional[torch.nn.Module], model_path: Optional[str]) -> Tuple[ + Policy, BaseLearner, SummaryWriter]: + """Initializes components shared across all tasks on this rank.""" + _, [cfg, create_cfg] = self.tasks_for_this_rank[0] + + # Set task_num for the shared policy + for task_config in self.tasks_for_this_rank: + task_config[1][0].policy.task_num = len(self.tasks_for_this_rank) + + cfg.policy.device = 'cuda' if torch.cuda.is_available() else 'cpu' + compiled_cfg = compile_config(cfg, seed=self.seed, auto=True, create_cfg=create_cfg, save_cfg=True) + + policy = create_policy(compiled_cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + if model_path: + logging.info(f'Loading model from {model_path}...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=compiled_cfg.policy.device)) + logging.info(f'Model loaded successfully from {model_path}.') + + log_dir = os.path.join(f'./{compiled_cfg.exp_name}/log', f'serial_rank_{self.rank}') + tb_logger = SummaryWriter(log_dir) + learner = BaseLearner(compiled_cfg.policy.learn.learner, policy.learn_mode, tb_logger, + exp_name=compiled_cfg.exp_name) + return policy, learner, tb_logger + + def _initialize_task_specific_components(self) -> Tuple[List, List, List, List]: + """Initializes components for each task assigned to this rank.""" + cfgs, game_buffers, collectors, evaluators = [], [], [], [] + + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(self.tasks_for_this_rank): + task_seed = self.seed + task_id + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + compiled_cfg = compile_config(cfg, seed=task_seed, auto=True, create_cfg=create_cfg, save_cfg=True) + + # Create environments + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(compiled_cfg.env) + collector_env = 
create_env_manager(compiled_cfg.env.manager, + [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(compiled_cfg.env.manager, + [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(task_seed) + evaluator_env.seed(task_seed, dynamic_seed=False) + set_pkg_seed(task_seed, use_cuda=compiled_cfg.policy.cuda) + + # Create buffer, collector, and evaluator + replay_buffer = GameBuffer(compiled_cfg.policy) + # Set initial batch size from config + replay_buffer.batch_size = compiled_cfg.policy.batch_size[task_id] + + collector = Collector( + env=collector_env, + policy=self.policy.collect_mode, + tb_logger=self.tb_logger, + exp_name=compiled_cfg.exp_name, + policy_config=compiled_cfg.policy, + task_id=task_id ) - if rank == 0: - print("分配后的 batch_sizes: ", allocated_batch_sizes) - for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( - zip(cfgs, collectors, evaluators, game_buffers) - ): - cfg.policy.batch_size = allocated_batch_sizes[idx] - policy._cfg.batch_size[idx] = allocated_batch_sizes[idx] - - # 对于当前进程的每个任务,进行数据收集和评估 - for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( - zip(cfgs, collectors, evaluators, game_buffers) - ): - - log_buffer_memory_usage( - learner.train_iter, - replay_buffer, - tb_logger, - cfg.policy.task_id + evaluator = Evaluator( + eval_freq=compiled_cfg.policy.eval_freq, + n_evaluator_episode=compiled_cfg.env.n_evaluator_episode, + stop_value=compiled_cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + tb_logger=self.tb_logger, + exp_name=compiled_cfg.exp_name, + policy_config=compiled_cfg.policy, + task_id=task_id ) - collect_kwargs = { - 'temperature': visit_count_temperature( - policy_config.manual_temperature_decay, - policy_config.fixed_temperature_value, - policy_config.threshold_training_steps_for_final_temperature, - trained_steps=learner.train_iter - ), - 'epsilon': 0.0 # 默认的 epsilon 值 - } - - if policy_config.eps.eps_greedy_exploration_in_collect: - epsilon_greedy_fn = get_epsilon_greedy_fn( - start=policy_config.eps.start, - end=policy_config.eps.end, - decay=policy_config.eps.decay, - type_=policy_config.eps.type - ) - collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) - - if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): - # if learner.train_iter > 1 and evaluator.should_eval(learner.train_iter): # TODO: debug - print('=' * 20) - print(f'Rank {rank} 评估 task_id: {cfg.policy.task_id}...') - - # 在训练进程中调用 safe_eval - stop, reward = safe_eval( - evaluator, - learner, - collector, - rank, - world_size - ) - # 判断评估是否成功 - if stop is None or reward is None: - print(f"Rank {rank} 在评估期间遇到问题。继续训练中...") - else: - print(f"评估成功: stop={stop}, reward={reward}") + cfgs.append(compiled_cfg) + game_buffers.append(replay_buffer) + collectors.append(collector) + evaluators.append(evaluator) + + return cfgs, game_buffers, collectors, evaluators + + def run(self) -> Policy: + """ + Overview: + The main training loop. Executes collection, evaluation, and training steps + until a termination condition is met. + Returns: + - (:obj:`Policy`): The trained policy. + """ + if not self.is_active: + # This rank has no tasks, so it should wait for others to finish. 
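            # Note: _wait_for_termination() (defined below) simply mirrors the pre- and
            # post-training dist.barrier() calls of the active ranks, so an idle rank keeps
            # participating in the collective communication until a termination condition is met.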
+ self._wait_for_termination() + return self.policy + + self.learner.call_hook('before_run') + + while True: + torch.cuda.empty_cache() + + self._update_dynamic_batch_sizes() + self._collect_and_evaluate() + + if self._is_training_ready(): + dist.barrier() + self._train_iteration() + dist.barrier() + else: + logging.warning(f"Rank {self.rank}: Not enough data for training, skipping training step.") - print('=' * 20) - print(f'entry: Rank {rank} 收集 task_id: {cfg.policy.task_id}...') + if self._check_termination_conditions(): + dist.barrier() # Final barrier to ensure all processes stop together. + break - # 收集数据 - new_data = collector.collect( - train_iter=learner.train_iter, - policy_kwargs=collect_kwargs + self.learner.call_hook('after_run') + return self.policy + + def _update_dynamic_batch_sizes(self) -> None: + """Dynamically allocates batch sizes if enabled in the config.""" + if not self.cfgs[0].policy.get('allocated_batch_sizes', False): + return + + # Linearly increase clip_scale from 1 to 4 as train_epoch goes from 0 to 1000. + clip_scale = np.clip(1 + (3 * self.train_epoch / 1000), 1, 4) + allocated_sizes = allocate_batch_size(self.cfgs, self.game_buffers, alpha=1.0, clip_scale=clip_scale) + + # Distribute the allocated sizes to the tasks on the current rank. + # This requires knowing the global task distribution. + total_tasks = self.world_size * len(self.tasks_for_this_rank) # Approximation, needs exact count + # This part is tricky in a distributed setting without global knowledge of task indices. + # Assuming the allocation order matches the task_id order. + for i, cfg in enumerate(self.cfgs): + task_id = cfg.policy.task_id + if task_id < len(allocated_sizes): + batch_size = allocated_sizes[task_id] + cfg.policy.batch_size = batch_size + # Also update the batch size in the shared policy config if necessary + self.policy._cfg.batch_size[task_id] = batch_size + + + def _collect_and_evaluate(self) -> None: + """Runs the data collection and evaluation loop for each assigned task.""" + for i, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(self.cfgs, self.collectors, self.evaluators, self.game_buffers)): + log_buffer_memory_usage(self.learner.train_iter, replay_buffer, self.tb_logger, cfg.policy.task_id) + + # Evaluation step + if evaluator.should_eval(self.learner.train_iter): + safe_eval(evaluator, self.learner, collector, self.rank, self.world_size) + + # Collection step + self._collect_data_for_task(cfg, collector, replay_buffer) + + def _collect_data_for_task(self, cfg: Any, collector: Collector, replay_buffer: GameBuffer) -> None: + """Collects data for a single task and pushes it to the replay buffer.""" + policy_config = cfg.policy + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=self.learner.train_iter + ), + 'epsilon': 0.0 + } + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, end=policy_config.eps.end, + decay=policy_config.eps.decay, type_=policy_config.eps.type ) + collect_kwargs['epsilon'] = epsilon_fn(collector.envstep) - # 更新 replay buffer - replay_buffer.push_game_segments(new_data) - replay_buffer.remove_oldest_data_to_fit() + logging.info(f'Rank {self.rank}: Collecting data for task {cfg.policy.task_id}...') + new_data = collector.collect(train_iter=self.learner.train_iter, 
policy_kwargs=collect_kwargs) + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + logging.info(f'Rank {self.rank}: Finished data collection for task {cfg.policy.task_id}.') - # 周期性地重新分析缓冲区 - if cfg.policy.buffer_reanalyze_freq >= 1: - # 在一个训练 epoch 中重新分析缓冲区 次 - reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq - else: - # 每 <1/buffer_reanalyze_freq> 个训练 epoch 重新分析一次缓冲区 - if ( - train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and - replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > - int(reanalyze_batch_size / cfg.policy.reanalyze_partition) - ): - with timer: - # 每个重新分析过程将重新分析 个序列 - replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) - buffer_reanalyze_count += 1 - logging.info(f'缓冲区重新分析计数: {buffer_reanalyze_count}') - logging.info(f'缓冲区重新分析时间: {timer.value}') - - # 数据收集结束后添加日志 - logging.info(f'Rank {rank}: 完成任务 {cfg.policy.task_id} 的数据收集') - - # 检查是否有足够的数据进行训练 - not_enough_data = any( - replay_buffer.get_num_of_transitions() < cfg.policy.batch_size[cfg.policy.task_id] - for cfg, replay_buffer in zip(cfgs, game_buffers) - ) - assert not not_enough_data, f"Rank {rank}: 某些任务的数据量不足以进行训练。请确保所有任务的 replay buffer 中有足够的数据。" + # Periodic reanalysis of the buffer + self._reanalyze_buffer_if_needed(cfg, replay_buffer, is_during_training=False) - # 同步训练前所有 rank 的准备状态 - try: - dist.barrier() - logging.info(f'Rank {rank}: 通过训练前的 barrier') - except Exception as e: - logging.error(f'Rank {rank}: Barrier 失败,错误: {e}') - break # 或者进行其他错误处理 - - # 学习策略 - if not not_enough_data: - # Learner 将在一次迭代中训练 update_per_collect 次 - for i in range(update_per_collect): - train_data_multi_task = [] - envstep_multi_task = 0 - for idx, (cfg, collector, replay_buffer) in enumerate( - zip(cfgs, collectors, game_buffers) - ): - envstep_multi_task += collector.envstep - batch_size = cfg.policy.batch_size[cfg.policy.task_id] - if replay_buffer.get_num_of_transitions() > batch_size: - if cfg.policy.buffer_reanalyze_freq >= 1: - # 在一个训练 epoch 中重新分析缓冲区 次 - if ( - i % reanalyze_interval == 0 and - replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > - int(reanalyze_batch_size / cfg.policy.reanalyze_partition) - ): - with timer: - # 每个重新分析过程将重新分析 个序列 - replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) - buffer_reanalyze_count += 1 - logging.info(f'缓冲区重新分析计数: {buffer_reanalyze_count}') - logging.info(f'缓冲区重新分析时间: {timer.value}') - - train_data = replay_buffer.sample(batch_size, policy) - # 追加 task_id,以便在训练时区分任务 - train_data.append(cfg.policy.task_id) - train_data_multi_task.append(train_data) + def _reanalyze_buffer_if_needed(self, cfg: Any, replay_buffer: GameBuffer, is_during_training: bool, + train_loop_idx: int = 0) -> None: + """Handles the logic for reanalyzing the game buffer.""" + policy_config = cfg.policy + reanalyze_freq = policy_config.buffer_reanalyze_freq + reanalyze_batch_size = policy_config.reanalyze_batch_size + reanalyze_partition = policy_config.reanalyze_partition + update_per_collect = policy_config.update_per_collect + + should_reanalyze = False + if reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // reanalyze_freq + if is_during_training and train_loop_idx % reanalyze_interval == 0: + should_reanalyze = True + else: # reanalyze_freq is a fraction, e.g., 0.1 + if not is_during_training and self.train_epoch % int(1 / reanalyze_freq) == 0: + should_reanalyze = True + + if should_reanalyze and replay_buffer.get_num_of_transitions() // policy_config.num_unroll_steps 
> int(reanalyze_batch_size / reanalyze_partition): + with self.timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, self.policy) + self.buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {self.buffer_reanalyze_count}, Time: {self.timer.value:.2f}s') + + def _is_training_ready(self) -> bool: + """Checks if there is enough data in all buffers to start training.""" + for cfg, buffer in zip(self.cfgs, self.game_buffers): + if buffer.get_num_of_transitions() < cfg.policy.batch_size[cfg.policy.task_id]: + logging.warning(f"Rank {self.rank}, Task {cfg.policy.task_id}: Not enough data. " + f"Required: {cfg.policy.batch_size[cfg.policy.task_id]}, " + f"Available: {buffer.get_num_of_transitions()}") + return False + return True + + def _train_iteration(self) -> None: + """Performs one full training iteration, consisting of multiple updates.""" + for i in range(self.update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + + for idx, (cfg, collector, replay_buffer) in enumerate( + zip(self.cfgs, self.collectors, self.game_buffers)): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + + if replay_buffer.get_num_of_transitions() > batch_size: + self._reanalyze_buffer_if_needed(cfg, replay_buffer, is_during_training=True, train_loop_idx=i) + train_data = replay_buffer.sample(batch_size, self.policy) + train_data.append(cfg.policy.task_id) # Append task_id for multi-task loss + train_data_multi_task.append(train_data) + else: + # This case should ideally be prevented by _is_training_ready + logging.warning(f"Skipping sample for task {cfg.policy.task_id} due to insufficient data.") + train_data_multi_task.clear() # Invalidate the whole batch if one task fails + break + + if train_data_multi_task: + log_vars = self.learner.train(train_data_multi_task, envstep_multi_task) + if self.cfgs[0].policy.use_priority: + self._update_priorities(train_data_multi_task, log_vars) + + self.train_epoch += 1 + + def _update_priorities(self, train_data_multi_task: List, log_vars: List[Dict]) -> None: + """Updates the priorities in the replay buffers after a training step.""" + for idx, (cfg, replay_buffer) in enumerate(zip(self.cfgs, self.game_buffers)): + task_id = cfg.policy.task_id + priority_key = f'value_priority_task{task_id}' + + if priority_key in log_vars[0]: + priorities = log_vars[0][priority_key] + replay_buffer.update_priority(train_data_multi_task[idx], priorities) + + # Log priority statistics + if cfg.policy.get('print_task_priority_logs', False): + mean_priority = np.mean(priorities) + std_priority = np.std(priorities) + + # Update running mean of priority + running_mean_key = f'running_mean_priority_task{task_id}' + alpha = 0.1 # Smoothing factor for running average + if running_mean_key not in self.value_priority_tasks: + self.value_priority_tasks[running_mean_key] = mean_priority else: - logging.warning( - f'Replay buffer 中的数据不足以采样一个 mini-batch: ' - f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' - ) - break - - if train_data_multi_task: - # 在训练时,DDP 会自动同步梯度和参数 - log_vars = learner.train(train_data_multi_task, envstep_multi_task) - - if cfg.policy.use_priority: - for idx, (cfg, replay_buffer) in enumerate( - zip(cfgs, game_buffers) - ): - # 更新任务特定的 replay buffer 的优先级 - task_id = cfg.policy.task_id - replay_buffer.update_priority( - train_data_multi_task[idx], - log_vars[0][f'value_priority_task{task_id}'] - ) - - current_priorities = log_vars[0][f'value_priority_task{task_id}'] - - mean_priority = 
np.mean(current_priorities) - std_priority = np.std(current_priorities) - - alpha = 0.1 # 运行均值的平滑因子 - if f'running_mean_priority_task{task_id}' not in value_priority_tasks: - # 如果不存在,则初始化运行均值 - value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority - else: - # 更新运行均值 - value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( - alpha * mean_priority + - (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] - ) - - # 使用运行均值计算归一化的优先级 - running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] - normalized_priorities = ( - current_priorities - running_mean_priority - ) / (std_priority + 1e-6) - - # 如果需要,可以将归一化的优先级存储回 replay buffer - # replay_buffer.update_priority(train_data_multi_task[idx], normalized_priorities) - - # 如果设置了 print_task_priority_logs 标志,则记录统计信息 - if cfg.policy.print_task_priority_logs: - print( - f"任务 {task_id} - 平均优先级: {mean_priority:.8f}, " - f"运行平均优先级: {running_mean_priority:.8f}, " - f"标准差: {std_priority:.8f}" - ) - - train_epoch += 1 - - # 同步所有 Rank,确保所有 Rank 都完成了训练 + self.value_priority_tasks[running_mean_key] = \ + alpha * mean_priority + (1 - alpha) * self.value_priority_tasks[running_mean_key] + + running_mean_priority = self.value_priority_tasks[running_mean_key] + logging.info( + f"Task {task_id} - Priority Stats: Mean={mean_priority:.6f}, " + f"Running Mean={running_mean_priority:.6f}, Std={std_priority:.6f}" + ) + + def _check_termination_conditions(self) -> bool: + """Checks if the training should be terminated based on env steps or train iterations.""" try: - dist.barrier() - logging.info(f'Rank {rank}: 通过训练后的 barrier') - except Exception as e: - logging.error(f'Rank {rank}: Barrier 失败,错误: {e}') - break # 或者进行其他错误处理 + # Check max_env_step + local_envsteps = [collector.envstep for collector in self.collectors] + all_ranks_envsteps = [None for _ in range(self.world_size)] + dist.all_gather_object(all_ranks_envsteps, local_envsteps) + + # Flatten and check if all tasks have reached the step limit + all_envsteps = [step for rank_steps in all_ranks_envsteps for step in rank_steps] + if all(step >= self.max_env_step for step in all_envsteps): + logging.info(f"Rank {self.rank}: All tasks reached max_env_step ({self.max_env_step}). Terminating.") + return True + + # Check max_train_iter + local_train_iter = torch.tensor([self.learner.train_iter], device=self.policy.device) + all_train_iters = [torch.zeros_like(local_train_iter) for _ in range(self.world_size)] + dist.all_gather(all_train_iters, local_train_iter) + + if any(it.item() >= self.max_train_iter for it in all_train_iters): + logging.info(f"Rank {self.rank}: A process reached max_train_iter ({self.max_train_iter}). 
Terminating.") + return True - # 检查是否需要终止训练 - try: - # local_envsteps 不再需要填充 - local_envsteps = [collector.envstep for collector in collectors] - - total_envsteps = [None for _ in range(world_size)] - dist.all_gather_object(total_envsteps, local_envsteps) - - # 将所有 envsteps 拼接在一起 - all_envsteps = torch.cat([ - torch.tensor(envsteps, device=cfg.policy.device) - for envsteps in total_envsteps - ]) - max_envstep_reached = torch.all(all_envsteps >= max_env_step) - - # 收集所有进程的 train_iter - global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) - all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] - dist.all_gather(all_train_iters, global_train_iter) - - max_train_iter_reached = torch.any( - torch.stack(all_train_iters) >= max_train_iter - ) - - if max_envstep_reached.item() or max_train_iter_reached.item(): - logging.info(f'Rank {rank}: 满足终止条件') - dist.barrier() # 确保所有进程同步 + except Exception as e: + logging.error(f'Rank {self.rank}: Failed during termination check. Error: {e}', exc_info=True) + return True # Terminate on error to prevent hanging + + return False + + def _wait_for_termination(self) -> None: + """ + For inactive ranks, this method blocks and waits for a termination signal + (e.g., another rank finishing) by participating in barriers and termination checks. + """ + while True: + # Participate in barriers to stay in sync + dist.barrier() # Pre-train barrier + dist.barrier() # Post-train barrier + + if self._check_termination_conditions(): + dist.barrier() # Final barrier break - else: - pass - except Exception as e: - logging.error(f'Rank {rank}: 终止检查失败,错误: {e}') - break # 或者进行其他错误处理 +def train_muzero_multitask_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = MAX_TRAIN_ITER_INF, + max_env_step: Optional[int] = MAX_ENV_STEP_INF, +) -> Policy: + """ + Overview: + The main entry point for multi-task MuZero training using Distributed Data Parallel (DDP). + This function sets up the distributed environment, partitions tasks, and launches the training process, + which is managed by the MuZeroMultiTaskTrainer class. + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): A list of tuples, where each tuple contains + a task ID and its corresponding configuration dictionaries (main_config, create_config). + - seed (:obj:`int`): The base random seed for reproducibility. Defaults to 0. + - model (:obj:`Optional[torch.nn.Module]`): An optional pre-defined model instance. If provided, + it will be used instead of creating a new one from the config. Defaults to None. + - model_path (:obj:`Optional[str]`): Path to a pre-trained model checkpoint file. If provided, + the model weights will be loaded before training starts. Defaults to None. + - max_train_iter (:obj:`Optional[int]`): The maximum number of training iterations. + Training will stop if any process reaches this limit. Defaults to a very large number. + - max_env_step (:obj:`Optional[int]`): The maximum number of environment steps for each task. + Training will stop when all tasks have reached this limit. Defaults to a very large number. + Returns: + - (:obj:`Policy`): The final trained policy instance from the primary rank. + """ + # Initialize the trainer, which handles all the complex setup and logic internally. 
+ trainer = MuZeroMultiTaskTrainer( + input_cfg_list=input_cfg_list, + seed=seed, + model=model, + model_path=model_path, + max_train_iter=max_train_iter, + max_env_step=max_env_step, + ) - learner.call_hook('after_run') - return policy \ No newline at end of file + # Run the training loop and return the trained policy. + return trainer.run() \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_balance_segment_ddp.py b/lzero/entry/train_unizero_multitask_balance_segment_ddp.py new file mode 100644 index 000000000..d80106e49 --- /dev/null +++ b/lzero/entry/train_unizero_multitask_balance_segment_ddp.py @@ -0,0 +1,548 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List, Dict, Any + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from ding.utils import EasyTimer +import torch.nn.functional as F +import torch.distributed as dist +import concurrent.futures +from lzero.model.unizero_world_models.transformer import set_curriculum_stage, CurriculumLoRALinear + +from collections import defaultdict +import math +from .utils import ( + freeze_non_lora_parameters, + compute_task_weights, + log_module_trainable_status, + log_param_statistics, + tasks_per_stage, + compute_unizero_mt_normalized_stats, + allocate_batch_size +) + +# A global dictionary to store the most recent evaluation return for each task. +# Format: {task_id: eval_episode_return_mean} +GLOBAL_EVAL_RETURNS: Dict[int, float] = defaultdict(lambda: None) + +# Timeout for the evaluation process in seconds. +EVALUATION_TIMEOUT = 12000 # 200 minutes + + +class CurriculumController: + """ + Overview: + Manages the curriculum learning stages for a multi-task policy. + It tracks the number of solved tasks and training iterations to decide when to transition + to the next curriculum stage, which typically involves freezing parts of the model + and activating new LoRA adapters. + """ + + def __init__(self, cfg: 'EasyDict', policy: 'Policy') -> None: + """ + Overview: + Initializes the CurriculumController. + Arguments: + - cfg (:obj:`EasyDict`): The experiment configuration. + - policy (:obj:`Policy`): The policy being trained. + """ + world_model_cfg = cfg.policy.model.world_model_cfg + self.stage_num: int = world_model_cfg.curriculum_stage_num + self.min_stage0_iters: int = world_model_cfg.min_stage0_iters + self.max_stage_iters: int = world_model_cfg.max_stage_iters + self.policy: 'Policy' = policy + + # Flag to determine if curriculum learning should also be applied to the encoder. + # Defaults to False for backward compatibility. + self.apply_curriculum_to_encoder: bool = getattr(world_model_cfg, 'apply_curriculum_to_encoder', False) + logging.info(f"[CurriculumController] Initialized. 
Curriculum will be applied to Encoder: {self.apply_curriculum_to_encoder}") + + self.stage: int = 0 + self.last_switch_iter: int = 0 + self.last_solved_count: int = 0 # Snapshot of the last count of solved tasks + + def step(self, solved_count: int, unsolved_count: int, train_iter: int) -> bool: + """ + Overview: + Checks if the curriculum should transition to the next stage and performs the switch if needed. + This method should be called at the end of each training loop. + Arguments: + - solved_count (:obj:`int`): The current total number of solved tasks. + - unsolved_count (:obj:`int`): The current number of tasks yet to be solved. + - train_iter (:obj:`int`): The current training iteration. + Returns: + - bool: True if a stage switch occurred, False otherwise. + """ + # --- Stage 0 is a mandatory training phase for a minimum number of iterations --- + if self.stage == 0 and train_iter < self.min_stage0_iters: + return False + + # --- Determine if a stage switch is necessary --- + should_switch = False + + # 1. Trigger based on task progress + newly_solved = solved_count - self.last_solved_count + remaining_lora_stages = self.stage_num - 1 - self.stage # Stage 0 doesn't use LoRA + if remaining_lora_stages > 0: + # Calculate tasks per stage (tps) for the remaining unsolved tasks + tps = tasks_per_stage(unsolved_count, remaining_lora_stages) + if newly_solved >= tps: + should_switch = True + + # 2. Trigger based on maximum iterations per stage + if train_iter - self.last_switch_iter >= self.max_stage_iters: + should_switch = True + + # --- Execute the stage switch --- + if should_switch and self.stage < self.stage_num - 1: + is_entering_stage1 = (self.stage == 0) + self.stage += 1 + + world_model = self.policy._learn_model.world_model + vit_encoder = world_model.tokenizer.encoder + transformer_backbone = world_model.transformer + + # --- Apply curriculum stage update and freeze parameters accordingly --- + + # 1. Conditionally apply to ViT Encoder based on configuration + if self.apply_curriculum_to_encoder: + logging.info(f"[Curriculum] Applying curriculum stage {self.stage} to ViT Encoder.") + set_curriculum_stage(vit_encoder, self.stage) + if is_entering_stage1: + logging.info("[Curriculum] Entering Stage 1. Freezing non-LoRA parameters in ViT Encoder.") + freeze_non_lora_parameters(vit_encoder, freeze=True, verbose=True) + log_module_trainable_status(vit_encoder, "ViT Encoder") + else: + logging.info("[Curriculum] Skipping curriculum stage update for ViT Encoder as per configuration.") + log_module_trainable_status(vit_encoder, "ViT Encoder (Curriculum Not Applied)") + + # 2. Always apply to Transformer Decoder + logging.info(f"[Curriculum] Applying curriculum stage {self.stage} to Transformer Backbone.") + set_curriculum_stage(transformer_backbone, self.stage) + if is_entering_stage1: + logging.info("[Curriculum] Entering Stage 1. 
Freezing non-LoRA parameters in Transformer Backbone.") + freeze_non_lora_parameters(transformer_backbone, freeze=True, verbose=True) + log_module_trainable_status(transformer_backbone, "Transformer Backbone") + + logging.info( + f'[Curriculum] Switched to stage {self.stage} ' + f'(solved={solved_count}, unsolved={unsolved_count}, iter={train_iter})' + ) + + # Log parameter statistics after the switch + updated_params = sum(p.requires_grad for p in self.policy._learn_model.world_model.parameters()) + total_params = sum(1 for _ in self.policy._learn_model.world_model.parameters()) + logging.info(f'{updated_params}/{total_params} parameters in the world model will be optimized.') + log_param_statistics(self.policy._learn_model.world_model) + + self.last_solved_count = solved_count + self.last_switch_iter = train_iter + return True + + return False + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[Dict[str, Any]]]: + """ + Overview: + Executes the evaluation process with a timeout to prevent the training from stalling. + Arguments: + - evaluator (:obj:`Evaluator`): The evaluator instance. + - learner (:obj:`BaseLearner`): The learner instance, used to save checkpoints. + - collector (:obj:`Collector`): The collector instance, used to get the current envstep. + - rank (:obj:`int`): The rank of the current process. + - world_size (:obj:`int`): The total number of processes. + Returns: + - Tuple[Optional[bool], Optional[Dict[str, Any]]]: A tuple containing the stop flag and the reward dictionary + if evaluation succeeds. Returns (None, None) on timeout or error. + """ + try: + logging.info(f"========= Evaluation starting on Rank {rank}/{world_size} =========") + # Ensure the stop_event is clear before starting a new evaluation. + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # Submit the evaluation task. + future = executor.submit(evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep) + try: + stop_flag, reward_dict = future.result(timeout=EVALUATION_TIMEOUT) + except concurrent.futures.TimeoutError: + # Set the stop_event to terminate the stuck evaluation thread. + evaluator.stop_event.set() + logging.error(f"Evaluation timed out on Rank {rank}/{world_size} after {EVALUATION_TIMEOUT} seconds.") + return None, None + + logging.info(f"====== Evaluation finished on Rank {rank}/{world_size} ======") + return stop_flag, reward_dict + except Exception as e: + logging.error(f"An error occurred during evaluation on Rank {rank}/{world_size}: {e}", exc_info=True) + return None, None + + +def train_unizero_multitask_balance_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), + benchmark_name: str = "atari" +) -> 'Policy': + """ + Overview: + The main training entry point for UniZero in a multi-task, curriculum-based setting using DDP. + This function orchestrates distributed data collection, training, and evaluation across multiple tasks. + The curriculum learning strategy involves: + - Defining a `target_return` for each task. + - Moving tasks to a `solved_task_pool` once they achieve their target return, excluding them from + further training and collection. 
+ - Progressing through curriculum stages where the model's backbone is frozen, and only specialized + modules (like LoRA) are trained on harder, unsolved tasks. + This allows the model to first learn general features and then specialize on difficult tasks without + catastrophic forgetting. + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): A list of configurations for each task. + - seed (:obj:`int`): The random seed. + - model (:obj:`Optional[torch.nn.Module]`): An optional pre-existing model instance. + - model_path (:obj:`Optional[str]`): Path to a pre-trained model checkpoint file. + - max_train_iter (:obj:`Optional[int]`): The maximum number of training iterations. + - max_env_step (:obj:`Optional[int]`): The maximum number of environment steps. + - benchmark_name (:obj:`str`): The name of the benchmark (e.g., "atari", "dmc") to load normalization scores. + Returns: + - Policy: The trained policy. + """ + # --- Initialization and DDP Setup --- + logging.basicConfig(level=logging.INFO) + rank = get_rank() + world_size = get_world_size() + timer = EasyTimer() + + # --- Benchmark Score Initialization --- + if benchmark_name == "atari": + RANDOM_SCORES = np.array([ + 227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5, + 152.1, 0.0, 65.2, 257.6, 1027.0, 29.0, 52.0, 1598.0, 258.5, 307.3, + -20.7, 24.9, 163.9, 11.5, 68.4, 533.4 + ]) + HUMAN_SCORES = np.array([ + 7127.7, 1719.5, 742.0, 8503.3, 753.1, 37187.5, 12.1, 30.5, 7387.8, 35829.4, + 1971.0, 29.6, 4334.7, 2412.5, 30826.4, 302.8, 3035.0, 2665.5, 22736.3, 6951.6, + 14.6, 69571.3, 13455.0, 7845.0, 42054.7, 11693.2 + ]) + new_order = [ + 20, 19, 24, 6, 0, 8, 14, 23, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 15, 16, 17, 18, 21, 25, 22, 7 + ] + new_RANDOM_SCORES = RANDOM_SCORES[new_order] + new_HUMAN_SCORES = HUMAN_SCORES[new_order] + elif benchmark_name == "dmc": + new_RANDOM_SCORES = np.zeros(26) + new_HUMAN_SCORES = np.ones(26) * 1000 + else: + raise ValueError(f"Unsupported benchmark_name: {benchmark_name}") + + # --- Task Distribution Across Ranks --- + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + start_idx = rank * tasks_per_rank + min(rank, remainder) + end_idx = start_idx + tasks_per_rank + (1 if rank < remainder else 0) + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + if not tasks_for_this_rank: + logging.warning(f"Rank {rank}: No tasks assigned. Process will idle but maintain DDP communication.") + # An idle process must still participate in collective communications. + # The main loop handles this by waiting at barriers. + while True: + dist.barrier() # Wait for other processes + dist.barrier() # Sync after potential training step + # A mechanism to terminate idle processes would be needed here, + # for now, they sync and wait. + # This part requires a robust termination signal from active processes. 
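As a quick sanity check, the per-rank slicing above can be reproduced outside the entry point. The sketch below is illustrative only (the helper name `split_tasks` is not part of this patch); it shows how 26 tasks spread over 8 ranks:

```Python
# Standalone sketch of the per-rank task partitioning used above (illustrative only).
from typing import List, Tuple


def split_tasks(total_tasks: int, world_size: int) -> List[Tuple[int, int]]:
    """Return the [start, end) task-index range handled by each rank."""
    tasks_per_rank, remainder = divmod(total_tasks, world_size)
    ranges = []
    for rank in range(world_size):
        start = rank * tasks_per_rank + min(rank, remainder)
        end = start + tasks_per_rank + (1 if rank < remainder else 0)
        ranges.append((start, end))
    return ranges


if __name__ == "__main__":
    # 26 Atari tasks spread over 8 ranks: ranks 0-1 get 4 tasks, ranks 2-7 get 3.
    ranges = split_tasks(26, 8)
    print(ranges)  # [(0, 4), (4, 8), (8, 11), (11, 14), ..., (23, 26)]
    assert sum(end - start for start, end in ranges) == 26
```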
+ + logging.info(f"Rank {rank}/{world_size} is handling tasks from index {start_idx} to {end_idx - 1}.") + + # --- Environment, Policy, and Worker Initialization --- + task_configs, replay_buffers, collectors, evaluators = [], [], [], [] + + # Use the first task's config to create the shared policy and learner + _, [main_cfg, main_create_cfg] = tasks_for_this_rank[0] + for _, [cfg, _] in tasks_for_this_rank: + cfg.policy.task_num = len(tasks_for_this_rank) + + assert main_create_cfg.policy.type in ['unizero_multitask', 'sampled_unizero_multitask'], \ + "This entry only supports 'unizero_multitask' or 'sampled_unizero_multitask' policies." + + GameBuffer = None + if main_create_cfg.policy.type == 'unizero_multitask': + from lzero.mcts import UniZeroGameBuffer as GameBuffer + elif main_create_cfg.policy.type == 'sampled_unizero_multitask': + from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer + + main_cfg.policy.device = 'cuda' if torch.cuda.is_available() else 'cpu' + compiled_cfg = compile_config(main_cfg, seed=seed, auto=True, create_cfg=main_create_cfg, save_cfg=True) + + policy = create_policy(compiled_cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + if model_path: + logging.info(f'Loading pre-trained model from: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=compiled_cfg.policy.device)) + logging.info('Model loading complete.') + + tb_logger = SummaryWriter(os.path.join(f'./{compiled_cfg.exp_name}/log', f'rank_{rank}')) + learner = BaseLearner(compiled_cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=compiled_cfg.exp_name) + learner.call_hook('before_run') + + # Initialize components for each assigned task + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): + task_seed = seed + task_id + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + compiled_task_cfg = compile_config(cfg, seed=task_seed, auto=True, create_cfg=create_cfg, save_cfg=True) + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(compiled_task_cfg.env) + collector_env = create_env_manager(compiled_task_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(compiled_task_cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(task_seed) + evaluator_env.seed(task_seed, dynamic_seed=False) + set_pkg_seed(task_seed, use_cuda=compiled_task_cfg.policy.cuda) + + replay_buffers.append(GameBuffer(compiled_task_cfg.policy)) + collectors.append(Collector(collector_env, policy.collect_mode, tb_logger, compiled_task_cfg.exp_name, compiled_task_cfg.policy, task_id)) + evaluators.append(Evaluator(compiled_task_cfg.policy.eval_freq, compiled_task_cfg.env.n_evaluator_episode, compiled_task_cfg.env.stop_value, evaluator_env, policy.eval_mode, tb_logger, compiled_task_cfg.exp_name, compiled_task_cfg.policy, task_id)) + task_configs.append(compiled_task_cfg) + + # --- Curriculum and Training Loop Initialization --- + solved_task_pool = set() + curriculum_controller = CurriculumController(compiled_cfg, policy) + temperature_scheduler = TemperatureScheduler(initial_temp=10.0, final_temp=1.0, threshold_steps=int(1e4), mode='linear') + + train_epoch = 0 + buffer_reanalyze_count = 0 + + logging.info(f"Rank {rank}: Initial trainable parameters in world model: {sum(p.requires_grad for p in policy._learn_model.world_model.parameters())}/{sum(1 for _ in 
policy._learn_model.world_model.parameters())}") + + # ============================================================================================ + # Main Training Loop + # ============================================================================================ + while True: + # --- 1. Dynamic Batch Size Allocation (Optional) --- + if compiled_cfg.policy.allocated_batch_sizes: + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(task_configs, replay_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + logging.info(f"Dynamically allocated batch sizes: {allocated_batch_sizes}") + for i, cfg in enumerate(task_configs): + cfg.policy.batch_size = allocated_batch_sizes + policy._cfg.batch_size = allocated_batch_sizes + + # --- 2. Data Collection and Evaluation for each task on this rank --- + local_task_returns = {} + for i, (cfg, collector, evaluator, replay_buffer) in enumerate(zip(task_configs, collectors, evaluators, replay_buffers)): + task_id = cfg.policy.task_id + if task_id in solved_task_pool: + continue + + # Evaluate policy if it's time + if learner.train_iter > 10 and evaluator.should_eval(learner.train_iter): + logging.info(f'Rank {rank} evaluating task_id: {task_id}...') + evaluator._policy.reset(reset_init_data=True, task_id=task_id) + stop_flag, reward_dict = safe_eval(evaluator, learner, collector, rank, world_size) + + if reward_dict is not None: + eval_mean_reward = reward_dict.get('eval_episode_return_mean', float('-inf')) + logging.info(f"Task {task_id} evaluation reward: {eval_mean_reward}") + local_task_returns[task_id] = eval_mean_reward + if eval_mean_reward >= cfg.policy.target_return: + logging.info(f"Task {task_id} has reached its target return of {cfg.policy.target_return}. Adding to solved pool.") + solved_task_pool.add(task_id) + else: + logging.warning(f"Evaluation failed or timed out for task {task_id}. Assigning a low score.") + local_task_returns[task_id] = float('-inf') + + # Collect new data + logging.info(f'Rank {rank} collecting data for task_id: {task_id}...') + collect_kwargs = {'temperature': visit_count_temperature(cfg.policy.manual_temperature_decay, cfg.policy.fixed_temperature_value, cfg.policy.threshold_training_steps_for_final_temperature, learner.train_iter)} + if cfg.policy.eps.eps_greedy_exploration_in_collect: + epsilon_fn = get_epsilon_greedy_fn(cfg.policy.eps.start, cfg.policy.eps.end, cfg.policy.eps.decay, cfg.policy.eps.type) + collect_kwargs['epsilon'] = epsilon_fn(collector.envstep) + + collector._policy.reset(reset_init_data=True, task_id=task_id) + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + logging.info(f'Rank {rank}: Data collection finished for task {task_id}.') + + # --- 3. 
DDP Synchronization of Task Status and Weights ---
+        dist.barrier()
+        # Gather solved tasks from all ranks
+        all_solved_pools = [None for _ in range(world_size)]
+        dist.all_gather_object(all_solved_pools, solved_task_pool)
+        global_solved_task_pool = set().union(*[pool for pool in all_solved_pools if pool is not None])
+        solved_task_pool = global_solved_task_pool  # Sync local pool with global
+        global_solved_count = len(solved_task_pool)
+
+        # Gather evaluation returns and compute task weights
+        task_weights = None
+        if learner.train_iter > 10 and learner.train_iter % compiled_cfg.policy.eval_freq == 0:
+            all_task_returns = [None for _ in range(world_size)]
+            dist.all_gather_object(all_task_returns, local_task_returns)
+
+            merged_task_returns = {k: v for d in all_task_returns if d for k, v in d.items()}
+            for tid, ret in merged_task_returns.items():
+                GLOBAL_EVAL_RETURNS[tid] = ret  # Update global tracker
+
+            unsolved_task_returns = {tid: ret for tid, ret in merged_task_returns.items() if tid not in solved_task_pool}
+
+            if rank == 0:
+                logging.info(f"Global unsolved task returns for weight calculation: {unsolved_task_returns}")
+                if compiled_cfg.policy.task_complexity_weight and unsolved_task_returns:
+                    temp = temperature_scheduler.get_temperature(learner.train_iter)
+                    task_weights = compute_task_weights(unsolved_task_returns, option="rank", temperature=temp)
+                    logging.info(f"Computed task weights: {task_weights}")
+
+                # Log UniZero-MT normalized stats
+                mean_norm, median_norm = compute_unizero_mt_normalized_stats(GLOBAL_EVAL_RETURNS)
+                if mean_norm is not None:
+                    tb_logger.add_scalar('UniZero-MT/NormalizedMean', mean_norm, learner.train_iter)
+                    tb_logger.add_scalar('UniZero-MT/NormalizedMedian', median_norm, learner.train_iter)
+                    logging.info(f"UniZero-MT Normalized Mean={mean_norm:.4f}, Median={median_norm:.4f}")
+
+            # Broadcast weights from rank 0 to all other ranks
+            broadcast_objects = [task_weights]
+            dist.broadcast_object_list(broadcast_objects, src=0)
+            task_weights = broadcast_objects[0]
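For intuition, the `compute_task_weights(..., option="rank", ...)` call above (the function is defined later in this patch) boils down to a rank transform, temperature scaling, normalization, and clipping. A simplified standalone rendition with toy returns, assuming the defaults `use_softmax=False` and `reverse=False`:

```Python
# Simplified standalone rendition of the option="rank" branch of compute_task_weights
# (defined later in this patch); the input returns are toy numbers, not benchmark data.
import torch


def rank_weights(task_returns: dict, temperature: float = 1.0) -> dict:
    task_ids = list(task_returns.keys())
    returns = torch.tensor(list(task_returns.values()), dtype=torch.float32)
    # Rank transform: the smallest return gets rank 1, the largest gets rank N.
    ranks = torch.empty_like(returns)
    ranks[torch.argsort(returns)] = torch.arange(1, len(returns) + 1, dtype=torch.float32)
    # Temperature scaling, normalization, and clipping to [1e-2, 1.0].
    scaled = ranks ** (1.0 / max(temperature, 1e-6))
    weights = scaled / scaled.sum()
    return {tid: min(max(float(w), 1e-2), 1.0) for tid, w in zip(task_ids, weights)}


print(rank_weights({0: 100.0, 1: 500.0, 2: 20.0}, temperature=1.0))
# -> roughly {0: 0.33, 1: 0.50, 2: 0.17}
```

A larger temperature flattens the weights toward uniform, which is why the scheduler anneals the temperature from 10.0 down to 1.0 over training.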
+
+        # --- 4. Curriculum Stage Update ---
+        unsolved_count = total_tasks - global_solved_count
+        switched = curriculum_controller.step(global_solved_count, unsolved_count, learner.train_iter)
+
+        if rank == 0:
+            tb_logger.add_scalar('Curriculum/Stage', curriculum_controller.stage, learner.train_iter)
+            tb_logger.add_scalar('Curriculum/GlobalSolvedTasks', global_solved_count, learner.train_iter)
+
+        # TODO: Iterate over all submodules of the transformer and look up CurriculumLoRALinear modules by name.
+        # transformer = policy._learn_model.world_model.transformer
+        # for module_name, module in transformer.named_modules():
+        #     if isinstance(module, CurriculumLoRALinear) and module.adapters is not None:
+        #         for adapter_idx, scale_param in enumerate(module.adapter_scales):
+        #             tb_logger.add_scalar(
+        #                 f'Curriculum/adapter_scales/{module_name}/adapter_{adapter_idx}',
+        #                 scale_param().item(),
+        #                 global_step=learner.train_iter
+        #             )
+
+        # Log the newly added alpha scaling factors.
+        try:
+            transformer = policy._learn_model.world_model.transformer
+            for module_name, module in transformer.named_modules():
+                if isinstance(module, CurriculumLoRALinear):
+                    # Check whether the module has a base_weight_scale attribute.
+                    if hasattr(module, 'base_weight_scale') and module.base_weight_scale is not None:
+                        # 1. Log the scaling factor of the base weight (alpha_0).
+                        tb_logger.add_scalar(
+                            f'Curriculum/alpha_scales/{module_name}/alpha_0_base_weight',
+                            module.base_weight_scale().item(),
+                            global_step=learner.train_iter
+                        )
+
+                    # Check whether the module has an adapter_scales attribute.
+                    if hasattr(module, 'adapter_scales') and module.adapter_scales is not None:
+                        # 2. Iterate over and log the scaling factors of all adapters (alpha_1, alpha_2, ...).
+                        for adapter_idx, scale_param in enumerate(module.adapter_scales):
+                            # adapter_idx is 0-based, so it corresponds to alpha_{adapter_idx + 1}.
+                            tb_logger.add_scalar(
+                                f'Curriculum/alpha_scales/{module_name}/alpha_{adapter_idx + 1}',
+                                scale_param().item(),
+                                global_step=learner.train_iter
+                            )
+        except Exception as e:
+            logging.warning(f"Failed to log alpha scales: {e}")
+
+
+        # Ensure all processes are aware of a potential stage switch
+        dist.barrier()
+
+        # --- 5. Training Step ---
+        unsolved_buffers = [rb for cfg, rb in zip(task_configs, replay_buffers) if cfg.policy.task_id not in solved_task_pool]
+        unsolved_cfgs = [cfg for cfg in task_configs if cfg.policy.task_id not in solved_task_pool]
+
+        if not unsolved_buffers:
+            logging.info(f"Rank {rank}: All assigned tasks are solved. Performing dummy training to maintain DDP sync.")
+            # When all local tasks are solved, we must still participate in DDP.
+            # A dummy forward/backward pass with zeroed gradients can ensure this.
+            # The current implementation uses a minimal batch from solved tasks with `ignore_grad=True`.
+            for _ in range(compiled_cfg.policy.update_per_collect):
+                train_data_list = []
+                for cfg, replay_buffer in zip(task_configs, replay_buffers):  # Use original buffers
+                    batch_size = 2  # Minimal batch size for sync
+                    if replay_buffer.get_num_of_transitions() >= batch_size:
+                        train_data = replay_buffer.sample(batch_size, policy)
+                        train_data.append(cfg.policy.task_id)
+                        train_data_list.append(train_data)
+
+                if train_data_list:
+                    learner.train(train_data_list, collector.envstep, policy_kwargs={'task_weights': None, "ignore_grad": True})
+
+        else:
+            for _ in range(compiled_cfg.policy.update_per_collect):
+                train_data_list = []
+                total_envstep = sum(c.envstep for c in collectors)
+                for cfg, replay_buffer in zip(unsolved_cfgs, unsolved_buffers):
+                    batch_size = cfg.policy.batch_size[cfg.policy.task_id]
+                    if replay_buffer.get_num_of_transitions() >= batch_size:
+                        train_data = replay_buffer.sample(batch_size, policy)
+                        train_data.append(cfg.policy.task_id)
+                        train_data_list.append(train_data)
+                    else:
+                        logging.warning(f"Skipping training for task {cfg.policy.task_id}: not enough data in buffer.")
+
+                if train_data_list:
+                    learn_kwargs = {'task_weights': task_weights, "ignore_grad": False}
+                    learner.train(train_data_list, total_envstep, policy_kwargs=learn_kwargs)
+
+        train_epoch += 1
+        policy.recompute_pos_emb_diff_and_clear_cache()
+
+        # --- 6. Synchronization and Termination Check ---
+        dist.barrier()  # Ensure all ranks complete the training step
+
+        # Check for termination conditions
+        max_iter_reached = torch.tensor([learner.train_iter >= max_train_iter], dtype=torch.bool, device=compiled_cfg.policy.device)
+        dist.all_reduce(max_iter_reached, op=dist.ReduceOp.SUM)
+
+        # For env_step, gather from all collectors on all ranks
+        local_env_steps = torch.tensor([c.envstep for c in collectors], dtype=torch.long, device=compiled_cfg.policy.device)
+        all_env_steps = [torch.zeros_like(local_env_steps) for _ in range(world_size)]
+        # Note: all_gather requires all tensors to be the same size. This assumes each rank has the same number of collectors.
+        # If not, a more complex gathering method (e.g., all_gather_object) is needed.
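The note above points at `all_gather_object` as the fallback when ranks own different numbers of collectors. A minimal sketch of that variant (assuming the default process group is already initialized; `min_global_env_step` is a hypothetical helper, not part of this patch):

```Python
# Variable-length alternative to the fixed-size all_gather used below: gather plain
# Python lists of env steps with all_gather_object, then reduce on every rank.
import torch.distributed as dist


def min_global_env_step(local_env_steps_list, world_size: int) -> int:
    """Return the smallest collector env step across all ranks (sketch only)."""
    gathered = [None for _ in range(world_size)]
    # all_gather_object accepts per-rank lists of different lengths.
    dist.all_gather_object(gathered, local_env_steps_list)
    all_steps = [step for per_rank in gathered if per_rank for step in per_rank]
    return min(all_steps) if all_steps else 0
```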
+ try: + dist.all_gather(all_env_steps, local_env_steps) + max_step_reached = (torch.cat(all_env_steps).min() >= max_env_step) if all_env_steps else False + except RuntimeError: # If tensor sizes mismatch + max_step_reached = False # Fallback, consider logging an error + logging.warning("Could not gather env_steps due to tensor size mismatch across ranks. Termination check may be inaccurate.") + + if max_iter_reached.item() or max_step_reached: + logging.info(f"Rank {rank}: Termination condition met. Stopping training.") + break + + # --- Finalization --- + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_segment_ddp.py b/lzero/entry/train_unizero_multitask_segment_ddp.py index 8c3d6c15f..ada067bd2 100644 --- a/lzero/entry/train_unizero_multitask_segment_ddp.py +++ b/lzero/entry/train_unizero_multitask_segment_ddp.py @@ -1,13 +1,13 @@ import logging import os from functools import partial -from typing import Tuple, Optional, List +from typing import Tuple, Optional, List, Dict import torch import numpy as np from ding.config import compile_config from ding.envs import create_env_manager, get_vec_env_setting -from ding.policy import create_policy +from ding.policy import create_policy, Policy from ding.rl_utils import get_epsilon_greedy_fn from ding.utils import set_pkg_seed, get_rank, get_world_size from ding.worker import BaseLearner @@ -15,18 +15,114 @@ from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler from lzero.policy import visit_count_temperature +# HACK: The following imports are for type hinting purposes. +# The actual GameBuffer is selected dynamically based on the policy type. +from lzero.mcts import UniZeroGameBuffer from lzero.worker import MuZeroEvaluator as Evaluator from lzero.worker import MuZeroSegmentCollector as Collector from ding.utils import EasyTimer import torch.nn.functional as F import torch.distributed as dist - import concurrent.futures +from collections import defaultdict + + +# ==================================================================================================================== +# Note: The following global benchmark score definitions are for reference. +# The active implementation for score initialization is located within the `train_unizero_multitask_segment_ddp` function +# to ensure scores are correctly set based on the `benchmark_name` argument passed to the function. 
+# ==================================================================================================================== +# global BENCHMARK_NAME +# # BENCHMARK_NAME = "atari" +# BENCHMARK_NAME = "dmc" # TODO +# if BENCHMARK_NAME == "atari": +# RANDOM_SCORES = np.array([ +# 227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5, +# 152.1, 0.0, 65.2, 257.6, 1027.0, 29.0, 52.0, 1598.0, 258.5, 307.3, +# -20.7, 24.9, 163.9, 11.5, 68.4, 533.4 +# ]) +# HUMAN_SCORES = np.array([ +# 7127.7, 1719.5, 742.0, 8503.3, 753.1, 37187.5, 12.1, 30.5, 7387.8, 35829.4, +# 1971.0, 29.6, 4334.7, 2412.5, 30826.4, 302.8, 3035.0, 2665.5, 22736.3, 6951.6, +# 14.6, 69571.3, 13455.0, 7845.0, 42054.7, 11693.2 +# ]) +# elif BENCHMARK_NAME == "dmc": +# RANDOM_SCORES = np.array([0]*26) +# HUMAN_SCORES = np.array([1000]*26) +# +# # New order to original index mapping +# # New order: [Pong, MsPacman, Seaquest, Boxing, Alien, ChopperCommand, Hero, RoadRunner, +# # Amidar, Assault, Asterix, BankHeist, BattleZone, CrazyClimber, DemonAttack, +# # Freeway, Frostbite, Gopher, Jamesbond, Kangaroo, Krull, KungFuMaster, +# # PrivateEye, UpNDown, Qbert, Breakout] +# # Mapping to indices in the original array (0-based) +# new_order = [ +# 20, 19, 24, 6, 0, 8, 14, 23, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 15, 16, 17, 18, 21, 25, 22, 7 +# ] +# +# # Generate new arrays based on new_order +# new_RANDOM_SCORES = RANDOM_SCORES[new_order] +# new_HUMAN_SCORES = HUMAN_SCORES[new_order] + + +# ------------------------------------------------------------ +# 1. Add a dedicated process-group for the learner. +# (This should be called once during main/learner initialization) +# ------------------------------------------------------------ +def build_learner_group(learner_ranks: list[int]) -> "dist.ProcessGroup": + """ + Overview: + Build a new process group for learners that perform backward propagation. + This is useful in scenarios like MoCo where specific ranks handle the learning process. + Arguments: + - learner_ranks (:obj:`list[int]`): A list of ranks that will perform the backward pass. + For example, if CUDA_VISIBLE_DEVICES=0,1, then learner_ranks=[0,1]. + Returns: + - pg (:obj:`dist.ProcessGroup`): A new process group for the specified learner ranks. + """ + world_pg = dist.group.WORLD + pg = dist.new_group(ranks=learner_ranks, backend='nccl') + if dist.get_rank() in learner_ranks: + torch.cuda.set_device(learner_ranks.index(dist.get_rank())) + return pg + + +# Stores the latest evaluation returns: {task_id: eval_episode_return_mean} +GLOBAL_EVAL_RETURNS: Dict[int, float] = defaultdict(lambda: None) + + +def compute_unizero_mt_normalized_stats( + eval_returns: Dict[int, float] +) -> Tuple[Optional[float], Optional[float]]: + """ + Overview: + Computes the Human-Normalized Mean and Median from evaluation returns for UniZero-MT. + If there are no samples, it returns (None, None). + Arguments: + - eval_returns (:obj:`Dict[int, float]`): A dictionary of evaluation returns, keyed by task ID. + Returns: + - (:obj:`Tuple[Optional[float], Optional[float]]`): A tuple containing the human-normalized mean and median. + Returns (None, None) if no valid returns are provided. 
+ """ + normalized = [] + for tid, ret in eval_returns.items(): + if ret is None: + continue + # Denominator for normalization + denom = new_HUMAN_SCORES[tid] - new_RANDOM_SCORES[tid] + if denom == 0: + continue + normalized.append((ret - new_RANDOM_SCORES[tid]) / denom) + + if not normalized: + return None, None + arr = np.asarray(normalized, dtype=np.float32) + return float(arr.mean()), float(np.median(arr)) -# 设置超时时间 (秒) -TIMEOUT = 12000 # 例如200分钟 +# Set a timeout for evaluation in seconds +TIMEOUT = 12000 # e.g., 200 minutes timer = EasyTimer() @@ -39,29 +135,29 @@ def safe_eval( world_size: int ) -> Tuple[Optional[bool], Optional[float]]: """ - Safely执行评估任务,避免超时。 - - Args: - evaluator (Evaluator): 评估器实例。 - learner (BaseLearner): 学习器实例。 - collector (Collector): 数据收集器实例。 - rank (int): 当前进程的rank。 - world_size (int): 总进程数。 - + Overview: + Safely executes an evaluation task with a timeout to prevent hangs. + Arguments: + - evaluator (:obj:`Evaluator`): The evaluator instance. + - learner (:obj:`BaseLearner`): The learner instance. + - collector (:obj:`Collector`): The data collector instance. + - rank (:obj:`int`): The rank of the current process. + - world_size (:obj:`int`): The total number of processes. Returns: - Tuple[Optional[bool], Optional[float]]: 如果评估成功,返回停止标志和奖励,否则返回(None, None)。 + - (:obj:`Tuple[Optional[bool], Optional[float]]`): A tuple containing the stop flag and reward if evaluation succeeds, + otherwise (None, None). """ try: print(f"=========评估开始 Rank {rank}/{world_size}===========") - # 重置 stop_event,确保每次评估前都处于未设置状态 + # Reset the stop_event to ensure it is not set before each evaluation. evaluator.stop_event.clear() with concurrent.futures.ThreadPoolExecutor() as executor: - # 提交评估任务 + # Submit the evaluation task. future = executor.submit(evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep) try: stop, reward = future.result(timeout=TIMEOUT) except concurrent.futures.TimeoutError: - # 超时,设置 stop_event + # If a timeout occurs, set the stop_event. evaluator.stop_event.set() print(f"评估操作在 Rank {rank}/{world_size} 上超时,耗时 {TIMEOUT} 秒。") return None, None @@ -75,185 +171,185 @@ def safe_eval( def allocate_batch_size( cfgs: List[dict], - game_buffers, + game_buffers: List['UniZeroGameBuffer'], alpha: float = 1.0, clip_scale: int = 1 ) -> List[int]: """ - 根据不同任务的收集剧集数反比分配batch_size,并动态调整batch_size范围以提高训练稳定性和效率。 - - Args: - cfgs (List[dict]): 每个任务的配置列表。 - game_buffers (List[GameBuffer]): 每个任务的重放缓冲区实例列表。 - alpha (float, optional): 控制反比程度的超参数。默认为1.0。 - clip_scale (int, optional): 动态调整的clip比例。默认为1。 - + Overview: + Allocates batch sizes for different tasks inversely proportional to the number of collected episodes. + It also dynamically adjusts the batch size range to improve training stability and efficiency. + Arguments: + - cfgs (:obj:`List[dict]`): A list of configurations for each task. + - game_buffers (:obj:`List[GameBuffer]`): A list of replay buffer instances for each task. + - alpha (:obj:`float`): A hyperparameter to control the degree of inverse proportionality. Defaults to 1.0. + - clip_scale (:obj:`int`): The clipping ratio for dynamic adjustment. Defaults to 1. Returns: - List[int]: 分配后的batch_size列表。 + - (:obj:`List[int]`): The list of allocated batch sizes. """ - # 提取每个任务的 collected episodes 数量 + # Extract the number of collected episodes for each task. buffer_num_of_collected_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] - # 获取当前的 world_size 和 rank + # Get the current world_size and rank. 
world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() - # 收集所有 rank 的 collected episodes 列表 + # Gather the lists of collected episodes from all ranks. all_task_num_of_collected_episodes = [None for _ in range(world_size)] torch.distributed.all_gather_object(all_task_num_of_collected_episodes, buffer_num_of_collected_episodes) - # 将所有 rank 的 collected episodes 合并为一个大列表 + # Merge the collected episodes from all ranks into a single list. all_task_num_of_collected_episodes = [ episode for sublist in all_task_num_of_collected_episodes for episode in sublist ] if rank == 0: print(f'所有任务的 collected episodes: {all_task_num_of_collected_episodes}') - # 计算每个任务的反比权重 + # Calculate the inverse proportional weights for each task. inv_episodes = np.array([1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes]) inv_sum = np.sum(inv_episodes) - # 计算总的batch_size (所有任务 cfg.policy.batch_size 的和) + # Calculate the total batch size (sum of cfg.policy.batch_size for all tasks). total_batch_size = cfgs[0].policy.total_batch_size - # 动态调整的部分:最小和最大的 batch_size 范围 + # Dynamic adjustment: define the min and max batch size range. avg_batch_size = total_batch_size / world_size min_batch_size = avg_batch_size / clip_scale max_batch_size = avg_batch_size * clip_scale - # 动态调整 alpha,让 batch_size 的变化更加平滑 + # Dynamically adjust alpha to make batch size changes smoother. task_weights = (inv_episodes / inv_sum) ** alpha batch_sizes = total_batch_size * task_weights - # 控制 batch_size 在 [min_batch_size, max_batch_size] 之间 + # Clip the batch sizes to be within the [min_batch_size, max_batch_size] range. batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) - # 确保 batch_size 是整数 + # Ensure batch sizes are integers. batch_sizes = [int(size) for size in batch_sizes] return batch_sizes -import numpy as np - def symlog(x: torch.Tensor) -> torch.Tensor: """ - Symlog 归一化,减少目标值的幅度差异。 - symlog(x) = sign(x) * log(|x| + 1) + Overview: + Symlog normalization to reduce the magnitude difference of target values. + symlog(x) = sign(x) * log(|x| + 1) """ return torch.sign(x) * torch.log(torch.abs(x) + 1) + def inv_symlog(x: torch.Tensor) -> torch.Tensor: """ - Symlog 的逆操作,用于恢复原始值。 - inv_symlog(x) = sign(x) * (exp(|x|) - 1) + Overview: + Inverse operation of Symlog to restore the original value. 
+ inv_symlog(x) = sign(x) * (exp(|x|) - 1) """ return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) -# 全局最大值和最小值(用于 "run-max-min") + +# Global max and min for "run-max-min" normalization GLOBAL_MAX = -float('inf') GLOBAL_MIN = float('inf') + def compute_task_weights( - task_rewards: dict, - option: str = "symlog", - epsilon: float = 1e-6, - temperature: float = 1.0, - use_softmax: bool = False, # 是否使用 Softmax - reverse: bool = False, # 正比 (False) 或反比 (True) - clip_min: float = 1e-2, # 权重的最小值 - clip_max: float = 1.0, # 权重的最大值 -) -> dict: + task_returns: Dict[int, float], + option: str = "symlog", + epsilon: float = 1e-6, + temperature: float = 1.0, + use_softmax: bool = False, + reverse: bool = False, + clip_min: float = 1e-2, + clip_max: float = 1.0, +) -> Dict[int, float]: """ - 改进后的任务权重计算函数,支持多种标准化方式、Softmax 和正反比权重计算,并增加权重范围裁剪功能。 - - Args: - task_rewards (dict): 每个任务的字典,键为 task_id,值为评估奖励或损失。 - option (str): 标准化方式,可选值为 "symlog", "max-min", "run-max-min", "rank", "none"。 - epsilon (float): 避免分母为零的小值。 - temperature (float): 控制权重分布的温度系数。 - use_softmax (bool): 是否使用 Softmax 进行权重分配。 - reverse (bool): 若为 True,权重与值反比;若为 False,权重与值正比。 - clip_min (float): 权重的最小值,用于裁剪。 - clip_max (float): 权重的最大值,用于裁剪。 - + Overview: + An improved function for calculating task weights, supporting multiple normalization methods, + Softmax, proportional/inverse weighting, and weight clipping. + Arguments: + - task_returns (:obj:`Dict[int, float]`): A dictionary where keys are task_ids and values are evaluation rewards or losses. + - option (:obj:`str`): Normalization method. Options: "symlog", "max-min", "run-max-min", "rank", "none". + - epsilon (:obj:`float`): A small value to avoid division by zero. + - temperature (:obj:`float`): Temperature coefficient to control the weight distribution. + - use_softmax (:obj:`bool`): Whether to use Softmax for weight distribution. + - reverse (:obj:`bool`): If True, weights are inversely proportional to values; if False, they are proportional. + - clip_min (:obj:`float`): The minimum value to clip weights to. + - clip_max (:obj:`float`): The maximum value to clip weights to. Returns: - dict: 每个任务的权重,键为 task_id,值为归一化后的权重。 + - (:obj:`Dict[int, float]`): A dictionary of weights for each task, where keys are task_ids. """ - import torch - import torch.nn.functional as F - global GLOBAL_MAX, GLOBAL_MIN - # 如果输入为空字典,直接返回空结果 - if not task_rewards: + # Return an empty dictionary if the input is empty. + if not task_returns: return {} - # Step 1: 对 task_rewards 的值构造张量 - task_ids = list(task_rewards.keys()) - rewards_tensor = torch.tensor(list(task_rewards.values()), dtype=torch.float32) + # Step 1: Construct a tensor from the values of task_returns. + task_ids = list(task_returns.keys()) + returns_tensor = torch.tensor(list(task_returns.values()), dtype=torch.float32) if option == "symlog": - # 使用 symlog 标准化 - scaled_rewards = symlog(rewards_tensor) + # Use symlog normalization. + scaled_returns = symlog(returns_tensor) elif option == "max-min": - # 使用最大最小值归一化 - max_reward = rewards_tensor.max().item() - min_reward = rewards_tensor.min().item() - scaled_rewards = (rewards_tensor - min_reward) / (max_reward - min_reward + epsilon) + # Use max-min normalization. 
+ max_reward = returns_tensor.max().item() + min_reward = returns_tensor.min().item() + scaled_returns = (returns_tensor - min_reward) / (max_reward - min_reward + epsilon) elif option == "run-max-min": - # 使用全局最大最小值归一化 - GLOBAL_MAX = max(GLOBAL_MAX, rewards_tensor.max().item()) - GLOBAL_MIN = min(GLOBAL_MIN, rewards_tensor.min().item()) - scaled_rewards = (rewards_tensor - GLOBAL_MIN) / (GLOBAL_MAX - GLOBAL_MIN + epsilon) + # Use global running max-min normalization. + GLOBAL_MAX = max(GLOBAL_MAX, returns_tensor.max().item()) + GLOBAL_MIN = min(GLOBAL_MIN, returns_tensor.min().item()) + scaled_returns = (returns_tensor - GLOBAL_MIN) / (GLOBAL_MAX - GLOBAL_MIN + epsilon) elif option == "rank": - # 使用 rank 标准化 - # Rank 是基于值大小的排名,1 表示最小值,越大排名越高 - sorted_indices = torch.argsort(rewards_tensor) - scaled_rewards = torch.empty_like(rewards_tensor) - rank_values = torch.arange(1, len(rewards_tensor) + 1, dtype=torch.float32) # 1 到 N - scaled_rewards[sorted_indices] = rank_values + # Use rank-based normalization. Rank is based on value size, with 1 for the smallest. + sorted_indices = torch.argsort(returns_tensor) + scaled_returns = torch.empty_like(returns_tensor) + rank_values = torch.arange(1, len(returns_tensor) + 1, dtype=torch.float32) # Ranks from 1 to N + scaled_returns[sorted_indices] = rank_values elif option == "none": - # 不进行标准化 - scaled_rewards = rewards_tensor + # No normalization. + scaled_returns = returns_tensor else: raise ValueError(f"Unsupported option: {option}") - # Step 2: 根据 reverse 确定权重是正比还是反比 + # Step 2: Determine if weights are proportional or inversely proportional based on `reverse`. if not reverse: - # 正比:权重与值正相关 - raw_weights = scaled_rewards + # Proportional: weight is positively correlated with the value. + raw_weights = scaled_returns else: - # 反比:权重与值负相关 - # 避免 scaled_rewards 为负数或零 - scaled_rewards = torch.clamp(scaled_rewards, min=epsilon) - raw_weights = 1.0 / scaled_rewards + # Inverse: weight is negatively correlated with the value. + # Clamp to avoid division by zero or negative numbers. + scaled_returns = torch.clamp(scaled_returns, min=epsilon) + raw_weights = 1.0 / scaled_returns - # Step 3: 根据是否使用 Softmax 进行权重计算 + # Step 3: Calculate weights with or without Softmax. if use_softmax: - # 使用 Softmax 进行权重分配 - beta = 1.0 / max(temperature, epsilon) # 确保 temperature 不为零 + # Use Softmax for weight distribution. + beta = 1.0 / max(temperature, epsilon) # Ensure temperature is not zero. logits = -beta * raw_weights softmax_weights = F.softmax(logits, dim=0).numpy() weights = dict(zip(task_ids, softmax_weights)) else: - # 不使用 Softmax,直接计算权重 - # 温度缩放 - scaled_weights = raw_weights ** (1 / max(temperature, epsilon)) # 确保温度不为零 + # Do not use Softmax, calculate weights directly. + # Temperature scaling. + scaled_weights = raw_weights ** (1 / max(temperature, epsilon)) # Ensure temperature is not zero. - # 归一化权重 + # Normalize weights. total_weight = scaled_weights.sum() normalized_weights = scaled_weights / total_weight - # 转换为字典 + # Convert to dictionary. weights = dict(zip(task_ids, normalized_weights.numpy())) - # Step 4: Clip 权重范围 + # Step 4: Clip the weight range. 
for task_id in weights: weights[task_id] = max(min(weights[task_id], clip_max), clip_min) return weights + def train_unizero_multitask_segment_ddp( input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], seed: int = 0, @@ -261,56 +357,104 @@ def train_unizero_multitask_segment_ddp( model_path: Optional[str] = None, max_train_iter: Optional[int] = int(1e10), max_env_step: Optional[int] = int(1e10), + benchmark_name: str = "atari" ) -> 'Policy': """ Overview: - UniZero的训练入口,旨在通过解决MuZero类算法在需要捕捉长期依赖环境中的局限性,提高强化学习代理的规划能力。 - 详细信息请参阅 https://arxiv.org/abs/2406.10667。 - - Args: - - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): 不同任务的配置列表。 - - seed (:obj:`int`): 随机种子。 - - model (:obj:`Optional[torch.nn.Module]`): torch.nn.Module实例。 - - model_path (:obj:`Optional[str]`): 预训练模型路径,应指向预训练模型的ckpt文件。 - - max_train_iter (:obj:`Optional[int]`): 训练中的最大策略更新迭代次数。 - - max_env_step (:obj:`Optional[int]`): 最大收集环境交互步数。 + The training entry point for UniZero, designed to enhance the planning capabilities of reinforcement learning agents + by addressing the limitations of MuZero-like algorithms in environments requiring long-term dependency capture. + For more details, please refer to https://arxiv.org/abs/2406.10667. + + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): A list of configurations for different tasks. + - seed (:obj:`int`): The random seed. + - model (:obj:`Optional[torch.nn.Module]`): An instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The path to a pre-trained model checkpoint file. + - max_train_iter (:obj:`Optional[int]`): The maximum number of policy update iterations during training. + - max_env_step (:obj:`Optional[int]`): The maximum number of environment interaction steps to collect. + - benchmark_name (:obj:`str`): The name of the benchmark, e.g., "atari" or "dmc". Returns: - - policy (:obj:`Policy`): 收敛的策略。 + - policy (:obj:`Policy`): The converged policy. 
""" - # 初始化温度调度器 + # ------------------------------------------------------------------------------------ + # ====== UniZero-MT Benchmark Scores (corresponding to 26 Atari100k task IDs) ====== + # Original RANDOM_SCORES and HUMAN_SCORES + if benchmark_name == "atari": + RANDOM_SCORES = np.array([ + 227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5, + 152.1, 0.0, 65.2, 257.6, 1027.0, 29.0, 52.0, 1598.0, 258.5, 307.3, + -20.7, 24.9, 163.9, 11.5, 68.4, 533.4 + ]) + HUMAN_SCORES = np.array([ + 7127.7, 1719.5, 742.0, 8503.3, 753.1, 37187.5, 12.1, 30.5, 7387.8, 35829.4, + 1971.0, 29.6, 4334.7, 2412.5, 30826.4, 302.8, 3035.0, 2665.5, 22736.3, 6951.6, + 14.6, 69571.3, 13455.0, 7845.0, 42054.7, 11693.2 + ]) + elif benchmark_name == "dmc": + RANDOM_SCORES = np.zeros(26) + HUMAN_SCORES = np.ones(26) * 1000 + else: + raise ValueError(f"Unsupported BENCHMARK_NAME: {benchmark_name}") + + # New order to original index mapping + # New order: [Pong, MsPacman, Seaquest, Boxing, Alien, ChopperCommand, Hero, RoadRunner, + # Amidar, Assault, Asterix, BankHeist, BattleZone, CrazyClimber, DemonAttack, + # Freeway, Frostbite, Gopher, Jamesbond, Kangaroo, Krull, KungFuMaster, + # PrivateEye, UpNDown, Qbert, Breakout] + # Mapping to indices in the original array (0-based) + new_order = [ + 20, 19, 24, 6, 0, 8, 14, 23, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 15, 16, 17, 18, 21, 25, 22, 7 + ] + global new_RANDOM_SCORES, new_HUMAN_SCORES + # Generate new arrays based on new_order + new_RANDOM_SCORES = RANDOM_SCORES[new_order] + new_HUMAN_SCORES = HUMAN_SCORES[new_order] + # Log the reordered results + print("重排后的 RANDOM_SCORES:") + print(new_RANDOM_SCORES) + print("\n重排后的 HUMAN_SCORES:") + print(new_HUMAN_SCORES) + # ------------------------------------------------------------------------------------ + + # Initialize the temperature scheduler for task weighting. initial_temperature = 10.0 final_temperature = 1.0 - threshold_steps = int(1e4) # 训练步数达到 10k 时,温度降至 1.0 + threshold_steps = int(1e4) # Temperature drops to 1.0 after 10k training steps. temperature_scheduler = TemperatureScheduler( initial_temp=initial_temperature, final_temp=final_temperature, threshold_steps=threshold_steps, - mode='linear' # 或 'exponential' + mode='linear' # or 'exponential' ) - # 获取当前进程的rank和总进程数 + # Get the current process rank and total world size. rank = get_rank() world_size = get_world_size() - # 任务划分 + # Task partitioning among ranks. total_tasks = len(input_cfg_list) tasks_per_rank = total_tasks // world_size remainder = total_tasks % world_size + # ==================== START: 关键修复 ==================== + # 1. 精确计算当前Rank负责的任务数量 if rank < remainder: start_idx = rank * (tasks_per_rank + 1) end_idx = start_idx + tasks_per_rank + 1 + num_tasks_for_this_rank = tasks_per_rank + 1 else: start_idx = rank * tasks_per_rank + remainder end_idx = start_idx + tasks_per_rank + num_tasks_for_this_rank = tasks_per_rank + # ==================== END: 关键修复 ==================== tasks_for_this_rank = input_cfg_list[start_idx:end_idx] - # 确保至少有一个任务 + # Ensure at least one task is assigned. if len(tasks_for_this_rank) == 0: - logging.warning(f"Rank {rank}: 未分配任务,继续执行。") - # 初始化空列表以避免后续代码报错 + logging.warning(f"Rank {rank}: No tasks assigned, continuing execution.") + # Initialize empty lists to avoid errors later. 
cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], [] else: print(f"Rank {rank}/{world_size}, 处理任务 {start_idx} 到 {end_idx - 1}") @@ -323,56 +467,63 @@ def train_unizero_multitask_segment_ddp( evaluators = [] if tasks_for_this_rank: - # 使用第一个任务的配置创建共享的policy + # Use the config of the first task to create a shared policy. task_id, [cfg, create_cfg] = tasks_for_this_rank[0] - for config in tasks_for_this_rank: - config[1][0].policy.task_num = tasks_per_rank - - # 确保指定的策略类型受支持 - assert create_cfg.policy.type in ['unizero_multitask', - 'sampled_unizero_multitask'], "train_unizero entry 目前仅支持 'unizero_multitask'" + # ==================== START: 关键修复 ==================== + # 2. 将正确的任务数量设置到 *所有* 相关配置中 + # 在创建Policy实例之前,必须确保配置是正确的 + for config_tuple in tasks_for_this_rank: + # config_tuple is (task_id, [cfg_obj, create_cfg_obj]) + config_tuple[1][0].policy.task_num = num_tasks_for_this_rank + + # 3. 确保用于创建Policy的那个cfg对象也拥有正确的task_num + cfg.policy.task_num = num_tasks_for_this_rank + # ==================== END: 关键修复 ==================== + + # Ensure the specified policy type is supported. + assert create_cfg.policy.type in ['unizero_multitask', 'sampled_unizero_multitask'], \ + "train_unizero entry currently only supports 'unizero_multitask' or 'sampled_unizero_multitask'" if create_cfg.policy.type == 'unizero_multitask': from lzero.mcts import UniZeroGameBuffer as GameBuffer if create_cfg.policy.type == 'sampled_unizero_multitask': from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer - - # 根据CUDA可用性设置设备 + # Set device based on CUDA availability. cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' logging.info(f'配置的设备: {cfg.policy.device}') - # 编译配置 + # Compile the configuration. cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) - # 创建共享的policy + # Create the shared policy. policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) - # 加载预训练模型(如果提供) + # Load a pre-trained model if a path is provided. if model_path is not None: logging.info(f'开始加载模型: {model_path}') policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) logging.info(f'完成加载模型: {model_path}') - # 创建TensorBoard日志记录器 + # Create a TensorBoard logger. log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') tb_logger = SummaryWriter(log_dir) - # 创建共享的learner + # Create the shared learner. learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) policy_config = cfg.policy - # 处理当前进程分配到的每个任务 + # Process each task assigned to the current rank. for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): - # 设置每个任务的随机种子 + # Set a unique random seed for each task. cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) policy_config = cfg.policy policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode - # 创建环境 + # Create environments. 
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) @@ -380,7 +531,7 @@ def train_unizero_multitask_segment_ddp( evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) - # 创建不同的game buffer、collector和evaluator + # Create task-specific game buffers, collectors, and evaluators. replay_buffer = GameBuffer(policy_config) collector = Collector( env=collector_env, @@ -411,7 +562,7 @@ def train_unizero_multitask_segment_ddp( collectors.append(collector) evaluators.append(evaluator) - # 调用learner的before_run钩子 + # Call the learner's before_run hook. learner.call_hook('before_run') value_priority_tasks = {} @@ -420,15 +571,13 @@ def train_unizero_multitask_segment_ddp( reanalyze_batch_size = cfg.policy.reanalyze_batch_size update_per_collect = cfg.policy.update_per_collect - task_complexity_weight = cfg.policy.task_complexity_weight - use_task_exploitation_weight = cfg.policy.use_task_exploitation_weight task_exploitation_weight = None - # 创建任务奖励字典 - task_rewards = {} # {task_id: reward} + # Dictionary to store task rewards. + task_returns = {} # {task_id: reward} while True: - # 动态调整batch_size + # Dynamically adjust batch sizes. if cfg.policy.allocated_batch_sizes: clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) @@ -439,11 +588,11 @@ def train_unizero_multitask_segment_ddp( cfg.policy.batch_size = allocated_batch_sizes policy._cfg.batch_size = allocated_batch_sizes - # 对于当前进程的每个任务,进行数据收集和评估 + # For each task on the current rank, perform data collection and evaluation. for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( zip(cfgs, collectors, evaluators, game_buffers)): - # 记录缓冲区内存使用情况 + # Log buffer memory usage. log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) collect_kwargs = { @@ -453,7 +602,7 @@ def train_unizero_multitask_segment_ddp( policy_config.threshold_training_steps_for_final_temperature, trained_steps=learner.train_iter ), - 'epsilon': 0.0 # 默认的epsilon值 + 'epsilon': 0.0 # Default epsilon value. } if policy_config.eps.eps_greedy_exploration_in_collect: @@ -465,57 +614,55 @@ def train_unizero_multitask_segment_ddp( ) collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) - # 判断是否需要进行评估 - # if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): - if learner.train_iter > 10 and evaluator.should_eval(learner.train_iter): # only for debug - # if evaluator.should_eval(learner.train_iter): + # Check if it's time for evaluation. + if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0: + # if learner.train_iter == 0 or learner.train_iter % cfg.policy.eval_freq == 0: # only for debug TODO + print('=' * 20) print(f'Rank {rank} 评估任务_id: {cfg.policy.task_id}...') - # =========TODO========= + # TODO: Ensure policy reset logic is optimal for multi-task settings. evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) - # 执行安全评估 + # Perform safe evaluation. stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) - # 判断评估是否成功 + # Check if evaluation was successful. 
if stop is None or reward is None: print(f"Rank {rank} 在评估过程中遇到问题,继续训练...") - task_rewards[cfg.policy.task_id] = float('inf') # 如果评估失败,将任务难度设为最大值 + task_returns[cfg.policy.task_id] = float('inf') # Set task difficulty to max if evaluation fails. else: - # 确保从评估结果中提取 `eval_episode_return_mean` 作为奖励值 + # Extract 'eval_episode_return_mean' from the reward dictionary. try: eval_mean_reward = reward.get('eval_episode_return_mean', float('inf')) print(f"任务 {cfg.policy.task_id} 的评估奖励: {eval_mean_reward}") - task_rewards[cfg.policy.task_id] = eval_mean_reward + task_returns[cfg.policy.task_id] = eval_mean_reward except Exception as e: print(f"提取评估奖励时发生错误: {e}") - task_rewards[cfg.policy.task_id] = float('inf') # 出现问题时,将奖励设为最大值 - + task_returns[cfg.policy.task_id] = float('inf') # Set reward to max on error. print('=' * 20) print(f'开始收集 Rank {rank} 的任务_id: {cfg.policy.task_id}...') print(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id} ') - # 在每次收集之前重置初始数据,这对于多任务设置非常重要 + # Reset initial data before each collection, crucial for multi-task settings. collector._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) - # 收集数据 + # Collect data. new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) - # 更新重放缓冲区 + # Update the replay buffer. replay_buffer.push_game_segments(new_data) replay_buffer.remove_oldest_data_to_fit() - # # ===== only for debug ===== + # ===== For debugging purposes only ===== # if train_epoch > 2: # with timer: # replay_buffer.reanalyze_buffer(2, policy) # buffer_reanalyze_count += 1 # logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') - # logging.info(f'缓冲区重新分析耗时: {timer.value}') - # # ===== only for debug ===== + # logging.info(f'缓冲区重新分析耗时: {timer.value}') + # ==================================== - - # 周期性地重新分析缓冲区 + # Periodically reanalyze the buffer. if cfg.policy.buffer_reanalyze_freq >= 1: reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq else: @@ -528,45 +675,62 @@ def train_unizero_multitask_segment_ddp( logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') logging.info(f'缓冲区重新分析耗时: {timer.value}') - # 数据收集结束后添加日志 + # Log after data collection. logging.info(f'Rank {rank}: 完成任务 {cfg.policy.task_id} 的数据收集') - # 检查是否有足够的数据进行训练 + # Check if there is enough data for training. not_enough_data = any( replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size for replay_buffer in game_buffers ) - # 获取当前温度 + print(f"not_enough_data:{not_enough_data}") + # Get the current temperature for task weighting. current_temperature_task_weight = temperature_scheduler.get_temperature(learner.train_iter) - # collector._policy._task_weight_temperature = current_temperature_task_weight - # policy.collect_mode.get_attribute('task_weight_temperature') = current_temperature_task_weight - # 计算任务权重 - try: - # 汇聚任务奖励 - dist.barrier() - if task_complexity_weight: - all_task_rewards = [None for _ in range(world_size)] - dist.all_gather_object(all_task_rewards, task_rewards) - # 合并任务奖励 - merged_task_rewards = {} - for rewards in all_task_rewards: - if rewards: - merged_task_rewards.update(rewards) - # 计算全局任务权重 - task_weights = compute_task_weights(merged_task_rewards, temperature=current_temperature_task_weight) - # 同步任务权重 + if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0: + # Calculate task weights. + try: + # Gather task rewards. 
+ dist.barrier() + all_task_returns = [None for _ in range(world_size)] + dist.all_gather_object(all_task_returns, task_returns) + # Merge task rewards. + merged_task_returns = {} + for returns in all_task_returns: + if returns: + merged_task_returns.update(returns) + + logging.warning(f"Rank {rank}: merged_task_returns: {merged_task_returns}") + + # Calculate global task weights. + task_weights = compute_task_weights(merged_task_returns, temperature=current_temperature_task_weight) + + # ---------- Maintain UniZero-MT global evaluation results ---------- + for tid, ret in merged_task_returns.items(): + GLOBAL_EVAL_RETURNS[tid] = ret # Update even for solved tasks. + + # Calculate Human-Normalized Mean / Median. + uni_mean, uni_median = compute_unizero_mt_normalized_stats(GLOBAL_EVAL_RETURNS) + + if uni_mean is not None: # At least one task has been evaluated. + if rank == 0: # Only write to TensorBoard on rank 0 to avoid duplication. + tb_logger.add_scalar('UniZero-MT/NormalizedMean', uni_mean, global_step=learner.train_iter) + tb_logger.add_scalar('UniZero-MT/NormalizedMedian', uni_median, global_step=learner.train_iter) + logging.info(f"Rank {rank}: UniZero-MT Norm Mean={uni_mean:.4f}, Median={uni_median:.4f}") + else: + logging.info(f"Rank {rank}: 暂无数据计算 UniZero-MT 归一化指标") + + # Synchronize task weights. dist.broadcast_object_list([task_weights], src=0) - print(f"rank{rank}, 全局任务权重 (按 task_id 排列): {task_weights}") - else: - task_weights = None - except Exception as e: - logging.error(f'Rank {rank}: 同步任务权重失败,错误: {e}') - break + except Exception as e: + logging.error(f'Rank {rank}: 同步任务权重失败,错误: {e}') + break + # ---------------- Sampling done, preparing for backward pass ---------------- + # dist.barrier() # ★★★ Critical synchronization point ★★★ - # 学习策略 + # Learn policy. if not not_enough_data: for i in range(update_per_collect): train_data_multi_task = [] @@ -586,7 +750,7 @@ def train_unizero_multitask_segment_ddp( logging.info(f'缓冲区重新分析耗时: {timer.value}') train_data = replay_buffer.sample(batch_size, policy) - train_data.append(cfg.policy.task_id) # 追加task_id以区分任务 + train_data.append(cfg.policy.task_id) # Append task_id to differentiate tasks. train_data_multi_task.append(train_data) else: logging.warning( @@ -596,96 +760,100 @@ def train_unizero_multitask_segment_ddp( break if train_data_multi_task: - # learn_kwargs = {'task_exploitation_weight':task_exploitation_weight, 'task_weights':task_weights, } - learn_kwargs = {'task_weights':task_exploitation_weight} - - # 在训练时,DDP会自动同步梯度和参数 + learn_kwargs = {'task_weights': None,"train_iter":learner.train_iter} + + # DDP automatically synchronizes gradients and parameters during training. log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs) - # 判断是否需要计算task_exploitation_weight + # Check if task_exploitation_weight needs to be calculated. if i == 0: - # 计算任务权重 + # Calculate task weights. try: - dist.barrier() # 等待所有进程同步 - if use_task_exploitation_weight: - # 收集所有任务的 obs_loss + dist.barrier() # Wait for all processes to synchronize. + if cfg.policy.use_task_exploitation_weight: # Use obs loss now, new polish. + # Gather obs_loss from all tasks. all_obs_loss = [None for _ in range(world_size)] - # 构建当前进程的任务 obs_loss 数据 + # Build obs_loss data for the current process's tasks. 
merged_obs_loss_task = {} for cfg, replay_buffer in zip(cfgs, game_buffers): task_id = cfg.policy.task_id if f'noreduce_obs_loss_task{task_id}' in log_vars[0]: - merged_obs_loss_task[task_id] = log_vars[0][f'noreduce_obs_loss_task{task_id}'] - # 汇聚所有进程的 obs_loss 数据 + merged_obs_loss_task[task_id] = log_vars[0][ + f'noreduce_obs_loss_task{task_id}'] + # Gather obs_loss data from all processes. dist.all_gather_object(all_obs_loss, merged_obs_loss_task) - # 合并所有进程的 obs_loss 数据 + # Merge obs_loss data from all processes. global_obs_loss_task = {} for obs_loss_task in all_obs_loss: if obs_loss_task: global_obs_loss_task.update(obs_loss_task) - # 计算全局任务权重 + # Calculate global task weights. if global_obs_loss_task: task_exploitation_weight = compute_task_weights( global_obs_loss_task, option="rank", - # temperature=current_temperature_task_weight # TODO + # TODO: Decide whether to use the temperature scheduler here. temperature=1, ) - # 广播任务权重到所有进程 + # Broadcast task weights to all processes. dist.broadcast_object_list([task_exploitation_weight], src=0) - print(f"rank{rank}, task_exploitation_weight (按 task_id 排列): {task_exploitation_weight}") + print( + f"rank{rank}, task_exploitation_weight (按 task_id 排列): {task_exploitation_weight}") else: logging.warning(f"Rank {rank}: 未能计算全局 obs_loss 任务权重,obs_loss 数据为空。") task_exploitation_weight = None else: task_exploitation_weight = None - # 更新训练参数,使其包含计算后的任务权重 + # Update training parameters to include the calculated task weights. learn_kwargs['task_weight'] = task_exploitation_weight except Exception as e: logging.error(f'Rank {rank}: 同步任务权重失败,错误: {e}') - raise e # 保留异常抛出,便于外部捕获和分析 - - + raise e # Re-raise the exception for external capture and analysis. if cfg.policy.use_priority: for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)): - # 更新任务特定的重放缓冲区优先级 + # Update task-specific replay buffer priorities. 
task_id = cfg.policy.task_id + # replay_buffer.update_priority( + # train_data_multi_task[idx], + # log_vars[0][f'value_priority_task{task_id}'] + # ) replay_buffer.update_priority( train_data_multi_task[idx], - log_vars[0][f'value_priority_task{task_id}'] + log_vars[0][f'noreduce_value_priority_task{task_id}'] ) - current_priorities = log_vars[0][f'value_priority_task{task_id}'] - mean_priority = np.mean(current_priorities) - std_priority = np.std(current_priorities) - - alpha = 0.1 # 平滑因子 - if f'running_mean_priority_task{task_id}' not in value_priority_tasks: - value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority - else: - value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( - alpha * mean_priority + - (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] - ) - - # 使用运行均值计算归一化的优先级 - running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] - normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) - - # 如果需要,可以将归一化的优先级存储回重放缓冲区 - # replay_buffer.update_priority(train_data_multi_task[idx], normalized_priorities) - - # 记录优先级统计信息 - if cfg.policy.print_task_priority_logs: - print(f"任务 {task_id} - 平均优先级: {mean_priority:.8f}, " - f"运行平均优先级: {running_mean_priority:.8f}, " - f"标准差: {std_priority:.8f}") + # current_priorities = log_vars[0][f'value_priority_task{task_id}'] + # mean_priority = np.mean(current_priorities) + # std_priority = np.std(current_priorities) + + # alpha = 0.1 # Smoothing factor + # if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + # value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + # else: + # value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + # alpha * mean_priority + + # (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + # ) + + # # Use running mean to calculate normalized priorities. + # running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + # normalized_priorities = (current_priorities - running_mean_priority) / ( + # std_priority + 1e-6) + + # # If needed, update the replay buffer with normalized priorities. + # # replay_buffer.update_priority(train_data_multi_task[idx], normalized_priorities) + + # # Log priority statistics. + # if cfg.policy.print_task_priority_logs: + # print(f"任务 {task_id} - 平均优先级: {mean_priority:.8f}, " + # f"运行平均优先级: {running_mean_priority:.8f}, " + # f"标准差: {std_priority:.8f}") train_epoch += 1 policy.recompute_pos_emb_diff_and_clear_cache() - # 同步所有Rank,确保所有Rank完成训练 + # Synchronize all ranks to ensure they have completed training. try: dist.barrier() logging.info(f'Rank {rank}: 通过训练后的同步障碍') @@ -693,7 +861,7 @@ def train_unizero_multitask_segment_ddp( logging.error(f'Rank {rank}: 同步障碍失败,错误: {e}') break - # 检查是否需要终止训练 + # Check for termination conditions. try: local_envsteps = [collector.envstep for collector in collectors] total_envsteps = [None for _ in range(world_size)] @@ -702,7 +870,7 @@ def train_unizero_multitask_segment_ddp( all_envsteps = torch.cat([torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps]) max_envstep_reached = torch.all(all_envsteps >= max_env_step) - # 收集所有进程的train_iter + # Gather train_iter from all processes. 
global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] dist.all_gather(all_train_iters, global_train_iter) @@ -711,12 +879,12 @@ def train_unizero_multitask_segment_ddp( if max_envstep_reached.item() or max_train_iter_reached.item(): logging.info(f'Rank {rank}: 达到终止条件') - dist.barrier() # 确保所有进程同步 + dist.barrier() # Ensure all processes synchronize before exiting. break except Exception as e: logging.error(f'Rank {rank}: 终止检查失败,错误: {e}') break - # 调用learner的after_run钩子 + # Call the learner's after_run hook. learner.call_hook('after_run') return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_segment_eval.py b/lzero/entry/train_unizero_multitask_segment_eval.py index f98e4c41b..3715cbef4 100644 --- a/lzero/entry/train_unizero_multitask_segment_eval.py +++ b/lzero/entry/train_unizero_multitask_segment_eval.py @@ -1,13 +1,15 @@ import logging import os +import concurrent.futures from functools import partial -from typing import Tuple, Optional, List, Dict, Any +from typing import Tuple, Optional, List, Dict, Any, Type import torch +import torch.distributed as dist import numpy as np from ding.config import compile_config from ding.envs import create_env_manager, get_vec_env_setting -from ding.policy import create_policy +from ding.policy import create_policy, Policy from ding.rl_utils import get_epsilon_greedy_fn from ding.utils import set_pkg_seed, get_rank, get_world_size, EasyTimer from ding.worker import BaseLearner @@ -19,11 +21,11 @@ from lzero.worker import MuZeroEvaluator as Evaluator from lzero.worker import MuZeroSegmentCollector as Collector -import torch.distributed as dist -import concurrent.futures - -# 设置超时时间 (秒) -TIMEOUT = 12000 # 例如200分钟 +# Configure basic logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', +) def safe_eval( @@ -31,46 +33,47 @@ def safe_eval( learner: BaseLearner, collector: Collector, rank: int, - world_size: int + world_size: int, + timeout: int = 12000 ) -> Tuple[Optional[bool], Optional[float]]: """ - Safely evaluates the policy using the evaluator with a timeout. - - Args: - evaluator (Evaluator): The evaluator instance. - learner (BaseLearner): The learner instance. - collector (Collector): The collector instance. - rank (int): The rank of the current process. - world_size (int): Total number of processes. - + Overview: + Safely evaluates the policy using the evaluator with a specified timeout. This wrapper prevents + the entire training process from crashing due to evaluation-related issues like deadlocks. + Arguments: + - evaluator (:obj:`Evaluator`): The evaluator instance to run. + - learner (:obj:`BaseLearner`): The learner instance, used to access checkpoint saving and training iteration. + - collector (:obj:`Collector`): The collector instance, used to access the environment step count. + - rank (:obj:`int`): The rank of the current process in distributed training. + - world_size (:obj:`int`): The total number of processes. + - timeout (:obj:`int`): The maximum time in seconds to wait for the evaluation to complete. Returns: - Tuple[Optional[bool], Optional[float]]: A tuple containing the stop flag and reward. + - (:obj:`Tuple[Optional[bool], Optional[float]]`): A tuple containing the stop flag and the reward. + Returns (None, None) if evaluation times out or an exception occurs. 
""" try: - print(f"=========before eval Rank {rank}/{world_size}===========") - # 重置 stop_event,确保每次评估前都处于未设置状态 + logging.info(f"Rank {rank}/{world_size}: Starting evaluation.") + # Ensure the stop_event is clear before starting a new evaluation. evaluator.stop_event.clear() with concurrent.futures.ThreadPoolExecutor() as executor: - # 提交 evaluator.eval 任务 future = executor.submit( evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep ) - try: - stop, reward = future.result(timeout=TIMEOUT) + stop, reward = future.result(timeout=timeout) except concurrent.futures.TimeoutError: - # 超时,设置 evaluator 的 stop_event + # If evaluation exceeds the timeout, set the evaluator's stop event to terminate it gracefully. evaluator.stop_event.set() - print(f"Eval operation timed out after {TIMEOUT} seconds on Rank {rank}/{world_size}.") + logging.warning(f"Rank {rank}/{world_size}: Evaluation timed out after {timeout} seconds.") return None, None - print(f"======after eval Rank {rank}/{world_size}======") + logging.info(f"Rank {rank}/{world_size}: Evaluation finished successfully.") return stop, reward except Exception as e: - print(f"An error occurred during evaluation on Rank {rank}/{world_size}: {e}") + logging.error(f"Rank {rank}/{world_size}: An error occurred during evaluation: {e}", exc_info=True) return None, None @@ -81,63 +84,55 @@ def allocate_batch_size( clip_scale: int = 1 ) -> List[int]: """ - Allocates batch sizes inversely proportional to the number of collected episodes for each task. - Dynamically adjusts batch size within a specified range to enhance training stability and efficiency. - - Args: - cfgs (List[Any]): List of configurations for each task. - game_buffers (List[GameBuffer]): List of replay buffer instances for each task. - alpha (float): The hyperparameter controlling the degree of inverse proportionality. Default is 1.0. - clip_scale (int): The scaling factor to clip the batch size. Default is 1. - + Overview: + Allocates batch sizes inversely proportional to the number of collected episodes for each task. + This dynamic adjustment helps balance training focus across multiple tasks, prioritizing those + with less data. The batch sizes are clipped to a dynamic range to maintain stability. + Arguments: + - cfgs (:obj:`List[Any]`): List of configuration objects for each task. + - game_buffers (:obj:`List[GameBuffer]`): List of replay buffer instances for each task. + - alpha (:obj:`float`): A hyperparameter controlling the degree of inverse proportionality. Defaults to 1.0. + - clip_scale (:obj:`int`): A scaling factor to define the clipping range for the batch size. Defaults to 1. Returns: - List[int]: A list of allocated batch sizes for each task. + - (:obj:`List[int]`): A list of allocated batch sizes for each task. """ - # 提取每个任务的 num_of_collected_episodes - buffer_num_of_collected_episodes = [ - buffer.num_of_collected_episodes for buffer in game_buffers - ] + # Extract the number of collected episodes from each task's buffer. + buffer_num_of_collected_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] - # 获取当前的 world_size 和 rank world_size = get_world_size() rank = get_rank() - # 收集所有 rank 的 num_of_collected_episodes 列表 - all_task_num_of_collected_episodes = [None for _ in range(world_size)] - dist.all_gather_object(all_task_num_of_collected_episodes, buffer_num_of_collected_episodes) + # Gather the episode counts from all ranks. 
+ all_task_num_of_collected_episodes_obj = [None for _ in range(world_size)] + dist.all_gather_object(all_task_num_of_collected_episodes_obj, buffer_num_of_collected_episodes) - # 将所有 rank 的 num_of_collected_episodes 拼接成一个大列表 - all_task_num_of_collected_episodes = [ - item for sublist in all_task_num_of_collected_episodes for item in sublist - ] + # Concatenate the lists from all ranks into a single flat list. + all_task_num_of_collected_episodes = [item for sublist in all_task_num_of_collected_episodes_obj for item in sublist] if rank == 0: - print(f'all_task_num_of_collected_episodes: {all_task_num_of_collected_episodes}') + logging.info(f'All task collected episodes: {all_task_num_of_collected_episodes}') - # 计算每个任务的反比权重 - inv_episodes = np.array([ - 1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes - ]) + # Calculate the inverse weight for each task. Adding 1 to avoid division by zero. + inv_episodes = np.array([1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes]) inv_sum = np.sum(inv_episodes) - # 计算总的 batch_size (所有任务 cfg.policy.batch_size 的和) + # The total batch size is defined in the config of the first task. total_batch_size = cfgs[0].policy.total_batch_size - # 动态调整的部分:最小和最大的 batch_size 范围 + # Define a dynamic range for batch sizes to prevent extreme values. avg_batch_size = total_batch_size / world_size min_batch_size = avg_batch_size / clip_scale max_batch_size = avg_batch_size * clip_scale - # 动态调整 alpha,让 batch_size 的变化更加平滑 + # Calculate task weights based on inverse proportionality, smoothed by alpha. task_weights = (inv_episodes / inv_sum) ** alpha batch_sizes = total_batch_size * task_weights - # 控制 batch_size 在 [min_batch_size, max_batch_size] 之间 + # Clip the batch sizes to the calculated dynamic range. batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) - # 确保 batch_size 是整数 + # Ensure batch sizes are integers. batch_sizes = [int(size) for size in batch_sizes] - # 返回最终分配的 batch_size 列表 return batch_sizes @@ -151,33 +146,31 @@ def train_unizero_multitask_segment_eval( ) -> 'Policy': """ Overview: - The training entry point for UniZero, as proposed in the paper "UniZero: Generalized and Efficient Planning with Scalable Latent World Models". - UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing limitations found in MuZero-style algorithms, - particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. - - Args: - input_cfg_list (List[Tuple[int, Tuple[Dict[str, Any], Dict[str, Any]]]]): - List of configurations for different tasks. Each item is a tuple containing a task ID and a tuple of configuration dictionaries. - seed (int): - Random seed for reproducibility. - model (Optional[torch.nn.Module]): - Instance of torch.nn.Module representing the model. If None, a new model will be created. - model_path (Optional[str]): - Path to a pretrained model checkpoint. Should point to the ckpt file of the pretrained model. - max_train_iter (Optional[int]): - Maximum number of policy update iterations during training. Default is a very large number. - max_env_step (Optional[int]): - Maximum number of environment interaction steps to collect. Default is a very large number. - + The main training entry point for UniZero, as proposed in the paper "UniZero: Generalized and Efficient Planning + with Scalable Latent World Models" (https://arxiv.org/abs/2406.10667). 
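To make the allocation above concrete, here is a worked numeric sketch assuming four tasks spread one per rank over four ranks, a `total_batch_size` of 512, and illustrative episode counts; the clipping keeps data-poor tasks from monopolizing the batch.

```Python
# Inverse-proportional batch allocation with clipping, on assumed episode counts.
import numpy as np

episodes = np.array([10, 50, 200, 0])     # collected episodes per task (illustrative)
total_batch_size = 512
alpha, clip_scale = 1.0, 2

inv = 1.0 / (episodes + 1)                # favour tasks with little data; +1 avoids /0
weights = (inv / inv.sum()) ** alpha
batch_sizes = total_batch_size * weights  # raw allocation: ~[41.7, 9.0, 2.3, 459.1]

avg = total_batch_size / len(episodes)    # per-task average: 128
batch_sizes = np.clip(batch_sizes, avg / clip_scale, avg * clip_scale)
print([int(b) for b in batch_sizes])      # -> [64, 64, 64, 256]
```

Note that after clipping the allocated sizes no longer need to sum exactly to `total_batch_size`; that matches the behavior of the function above, which does not renormalize.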
This function sets up a distributed + multi-task training environment where multiple reinforcement learning tasks are trained in parallel using a + single shared model. It handles task distribution, component initialization (policy, learner, buffers, etc.), + and the main training loop orchestration. + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[Dict, Dict]]]`): A list of configurations for each task. Each + element is a tuple containing the task ID and its corresponding configuration dictionaries. + - seed (:obj:`int`): The master random seed for reproducibility. + - model (:obj:`Optional[torch.nn.Module]`): An optional pre-existing model instance. If None, a new model is + created based on the config. + - model_path (:obj:`Optional[str]`): An optional path to a pre-trained model checkpoint. + - max_train_iter (:obj:`Optional[int]`): The maximum number of training iterations before termination. + - max_env_step (:obj:`Optional[int]`): The maximum number of environment steps before termination. Returns: - 'Policy': - The converged policy after training. + - (:obj:`'Policy'`): The trained policy instance after the training loop has converged or terminated. """ - # 获取当前进程的 rank 和总的进程数 + # ============================================================== + # 1. Initialization + # ============================================================== + + # 1.1. Distributed Setup & Task Partitioning rank = get_rank() world_size = get_world_size() - # 任务划分 total_tasks = len(input_cfg_list) tasks_per_rank = total_tasks // world_size remainder = total_tasks % world_size @@ -191,290 +184,225 @@ def train_unizero_multitask_segment_eval( tasks_for_this_rank = input_cfg_list[start_idx:end_idx] - # 确保至少有一个任务 - if len(tasks_for_this_rank) == 0: - logging.warning(f"Rank {rank}: No tasks assigned, continuing without tasks.") - # 初始化一些空列表以避免后续代码报错 - cfgs, game_buffers, collectors, evaluators = [], [], [], [] - else: - print(f"Rank {rank}/{world_size}, handling tasks {start_idx} to {end_idx - 1}") - - cfgs: List[Any] = [] - game_buffers: List[GameBuffer] = [] - collectors: List[Collector] = [] - evaluators: List[Evaluator] = [] - - # 使用本rank的第一个任务的配置来创建共享的 policy - task_id, (cfg, create_cfg) = tasks_for_this_rank[0] - - # 设置每个任务的 task_num 以用于 learner_log - for config in tasks_for_this_rank: - config[1][0].policy.task_num = tasks_per_rank - - # 确保指定的 policy 类型是支持的 - assert create_cfg.policy.type in [ - 'unizero_multitask'], "train_unizero entry now only supports 'unizero_multitask'" - - # 根据 CUDA 可用性设置设备 - cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' - logging.info(f'cfg.policy.device: {cfg.policy.device}') - - # 编译配置 - cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) - # 创建共享的 policy - policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) - - # 如果指定了预训练模型,则加载 - if model_path is not None: - logging.info(f'Loading model from {model_path} begin...') - policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) - logging.info(f'Loading model from {model_path} end!') - - # 创建 TensorBoard 的日志记录器 - log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') - tb_logger = SummaryWriter(log_dir) - - # 创建共享的 learner - learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) - - policy_config = cfg.policy - batch_size = policy_config.batch_size[0] - - # 只处理当前进程分配到的任务 - for 
local_task_id, (task_id, (cfg, create_cfg)) in enumerate(tasks_for_this_rank): - # 设置每个任务自己的随机种子 - cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' - cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) - policy_config = cfg.policy - policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode - policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode - - env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) - collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) - evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) - collector_env.seed(cfg.seed + task_id) - evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) - set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) - - # 为每个任务创建不同的 game buffer、collector、evaluator - replay_buffer = GameBuffer(policy_config) - collector = Collector( - env=collector_env, - policy=policy.collect_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - policy_config=policy_config, - task_id=task_id - ) - evaluator = Evaluator( - eval_freq=cfg.policy.eval_freq, - n_evaluator_episode=cfg.env.n_evaluator_episode, - stop_value=cfg.env.stop_value, - env=evaluator_env, - policy=policy.eval_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - policy_config=policy_config, - task_id=task_id - ) - - cfgs.append(cfg) - replay_buffer.batch_size = cfg.policy.batch_size[task_id] + if not tasks_for_this_rank: + logging.warning(f"Rank {rank}: No tasks assigned. This rank will be idle.") + # Keep the process alive to participate in collective communications. + dist.barrier() + return + + logging.info(f"Rank {rank}/{world_size}: Handling tasks from index {start_idx} to {end_idx - 1}.") + + # 1.2. Shared Policy, Learner, and Logger Initialization + # Use the configuration of the first task on this rank to create the shared components. + _, (first_cfg, first_create_cfg) = tasks_for_this_rank[0] + + # Set task_num for learner logging purposes. + for _, (cfg, _) in tasks_for_this_rank: + cfg.policy.task_num = tasks_per_rank + + assert first_create_cfg.policy.type in ['unizero_multitask'], \ + "This entry point currently only supports 'unizero_multitask' policy type." + + first_cfg.policy.device = 'cuda' if torch.cuda.is_available() else 'cpu' + logging.info(f'Shared policy device: {first_cfg.policy.device}') + + # Compile the main configuration. + cfg = compile_config(first_cfg, seed=seed, auto=True, create_cfg=first_create_cfg, save_cfg=True) + + # Create the shared policy. + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Load a pre-trained model if a path is provided. + if model_path is not None: + logging.info(f'Loading pre-trained model from: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info('Model loading complete.') + + # Create a TensorBoard logger for this rank. + log_dir = os.path.join(f'./{cfg.exp_name}/log', f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # Create the shared learner instance. + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # 1.3. 
Task-Specific Components Initialization + cfgs, game_buffers, collectors, evaluators = [], [], [], [] + for task_id, (task_cfg, task_create_cfg) in tasks_for_this_rank: + # Set a unique seed for each task to ensure diversity in data collection. + task_seed = seed + task_id + task_cfg.policy.device = 'cuda' if task_cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + task_cfg = compile_config(task_cfg, seed=task_seed, auto=True, create_cfg=task_create_cfg, save_cfg=True) + + policy.collect_mode.get_attribute('cfg').n_episode = task_cfg.policy.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = task_cfg.policy.n_episode + + # Create environment managers for collection and evaluation. + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(task_cfg.env) + collector_env = create_env_manager(task_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(task_cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(task_seed) + evaluator_env.seed(task_seed, dynamic_seed=False) + set_pkg_seed(task_seed, use_cuda=task_cfg.policy.cuda) + + # Create task-specific buffers, collectors, and evaluators. + replay_buffer = GameBuffer(task_cfg.policy) + replay_buffer.batch_size = task_cfg.policy.batch_size[task_id] + + collector = Collector( + env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=task_cfg.exp_name, + policy_config=task_cfg.policy, task_id=task_id + ) + evaluator = Evaluator( + eval_freq=task_cfg.policy.eval_freq, n_evaluator_episode=task_cfg.env.n_evaluator_episode, + stop_value=task_cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode, + tb_logger=tb_logger, exp_name=task_cfg.exp_name, policy_config=task_cfg.policy, task_id=task_id + ) - game_buffers.append(replay_buffer) - collectors.append(collector) - evaluators.append(evaluator) + cfgs.append(task_cfg) + game_buffers.append(replay_buffer) + collectors.append(collector) + evaluators.append(evaluator) learner.call_hook('before_run') + + # ============================================================== + # 2. Main Training Loop + # ============================================================== buffer_reanalyze_count = 0 train_epoch = 0 - reanalyze_batch_size = cfg.policy.reanalyze_batch_size - update_per_collect = cfg.policy.update_per_collect - while True: - # 预先计算位置嵌入矩阵(如果需要) - # policy._collect_model.world_model.precompute_pos_emb_diff_kv() - # policy._target_model.world_model.precompute_pos_emb_diff_kv() + if learner.train_iter >= max_train_iter or collector.envstep >= max_env_step: + break + # 2.1. Dynamic Batch Size Allocation (Optional) if cfg.policy.allocated_batch_sizes: - # 动态调整 clip_scale 随着 train_epoch 从 0 增加到 1000, clip_scale 从 1 线性增加到 4 + # As training progresses, allow for a larger divergence in batch sizes. 
clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) if rank == 0: - print("分配后的 batch_sizes: ", allocated_batch_sizes) - for cfg, _collector, _evaluator, replay_buffer in zip(cfgs, collectors, evaluators, game_buffers): - cfg.policy.batch_size = allocated_batch_sizes + logging.info(f"Allocated batch sizes: {allocated_batch_sizes}") + for task_cfg, replay_buffer in zip(cfgs, game_buffers): + task_cfg.policy.batch_size = allocated_batch_sizes policy._cfg.batch_size = allocated_batch_sizes - # 对于当前进程的每个任务,进行数据收集和评估 - for cfg, collector, evaluator, replay_buffer in zip(cfgs, collectors, evaluators, game_buffers): - log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) + # 2.2. Collection and Evaluation Phase + for task_cfg, collector, evaluator, replay_buffer in zip(cfgs, collectors, evaluators, game_buffers): + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, task_cfg.policy.task_id) + # Determine exploration parameters for collection. collect_kwargs = { 'temperature': visit_count_temperature( - policy_config.manual_temperature_decay, - policy_config.fixed_temperature_value, - policy_config.threshold_training_steps_for_final_temperature, - trained_steps=learner.train_iter + task_cfg.policy.manual_temperature_decay, task_cfg.policy.fixed_temperature_value, + task_cfg.policy.threshold_training_steps_for_final_temperature, trained_steps=learner.train_iter ), - 'epsilon': 0.0 # 默认的 epsilon 值 + 'epsilon': 0.0 } - - if policy_config.eps.eps_greedy_exploration_in_collect: - epsilon_greedy_fn = get_epsilon_greedy_fn( - start=policy_config.eps.start, - end=policy_config.eps.end, - decay=policy_config.eps.decay, - type_=policy_config.eps.type + if task_cfg.policy.eps.eps_greedy_exploration_in_collect: + epsilon_fn = get_epsilon_greedy_fn( + start=task_cfg.policy.eps.start, end=task_cfg.policy.eps.end, + decay=task_cfg.policy.eps.decay, type_=task_cfg.policy.eps.type ) - collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + collect_kwargs['epsilon'] = epsilon_fn(collector.envstep) + # Evaluate the policy periodically. if evaluator.should_eval(learner.train_iter): - print('=' * 20) - print(f'Rank {rank} evaluates task_id: {cfg.policy.task_id}...') - - # 在训练进程中调用 safe_eval + logging.info(f'Rank {rank} evaluating task_id: {task_cfg.policy.task_id}...') stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) - # 判断评估是否成功 if stop is None or reward is None: - print(f"Rank {rank} encountered an issue during evaluation. Continuing training...") + logging.warning(f"Rank {rank} evaluation for task {task_cfg.policy.task_id} failed or timed out.") else: - print(f"Evaluation successful: stop={stop}, reward={reward}") - - print('=' * 20) - print(f'entry: Rank {rank} collects task_id: {cfg.policy.task_id}...') + logging.info(f"Evaluation successful for task {task_cfg.policy.task_id}: stop={stop}, reward={reward}") - # NOTE: 在每次收集之前重置初始数据,这对于多任务设置非常重要 + # Collect new data. + logging.info(f'Rank {rank} collecting for task_id: {task_cfg.policy.task_id}...') + # NOTE: Resetting initial data is crucial in multi-task settings to avoid state leakage. collector._policy.reset(reset_init_data=True) - # 收集数据 new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) - # 更新 replay buffer + # Update the replay buffer. 
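The `clip_scale` schedule above widens the allowed batch-size range linearly from 1x to 4x of the per-task average over the first 1000 training epochs. A few sample points:

```Python
# clip_scale grows linearly from 1 to 4 over the first 1000 epochs, then stays capped.
import numpy as np

for train_epoch in (0, 250, 500, 1000, 5000):
    clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4)
    print(train_epoch, clip_scale)   # 0 -> 1.0, 250 -> 1.75, 500 -> 2.5, >=1000 -> 4.0
```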
replay_buffer.push_game_segments(new_data) replay_buffer.remove_oldest_data_to_fit() - # 周期性地重新分析缓冲区 - if cfg.policy.buffer_reanalyze_freq >= 1: - # 在一个训练 epoch 中重新分析缓冲区 次 - reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq - else: - # 每 <1/buffer_reanalyze_freq> 个训练 epoch 重新分析一次缓冲区 - if (train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and - replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > - int(reanalyze_batch_size / cfg.policy.reanalyze_partition)): + # Periodically reanalyze the buffer to update value/policy targets with a more recent model. + # This logic handles two cases for `buffer_reanalyze_freq`: + # Case 1: freq < 1 (e.g., 0.5) -> Reanalyze every `1/freq` training epochs. + if 0 < task_cfg.policy.buffer_reanalyze_freq < 1: + if (train_epoch % int(1 / task_cfg.policy.buffer_reanalyze_freq) == 0 and + replay_buffer.get_num_of_transitions() // task_cfg.policy.num_unroll_steps > + int(task_cfg.policy.reanalyze_batch_size / task_cfg.policy.reanalyze_partition)): with EasyTimer() as timer: - # 每个重新分析过程将重新分析 个序列 - replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + replay_buffer.reanalyze_buffer(task_cfg.policy.reanalyze_batch_size, policy) buffer_reanalyze_count += 1 - logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') - logging.info(f'Buffer reanalyze time: {timer.value}') + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}, Time: {timer.value:.2f}s') - # 数据收集结束后添加日志 - logging.info(f'Rank {rank}: Completed data collection for task {cfg.policy.task_id}') + logging.info(f'Rank {rank}: Data collection complete for task {task_cfg.policy.task_id}') - # 检查是否有足够的数据进行训练 + # 2.3. Pre-Training Synchronization and Data Check + # Check if any buffer has insufficient data for training. not_enough_data = any( - replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size - for replay_buffer in game_buffers + rb.get_num_of_transitions() < cfg.policy.total_batch_size / world_size for rb in game_buffers ) - # 同步训练前所有 rank 的准备状态 try: dist.barrier() - logging.info(f'Rank {rank}: Passed barrier before training') except Exception as e: - logging.error(f'Rank {rank}: Barrier failed with error {e}') - break # 或者进行其他错误处理 + logging.error(f'Rank {rank}: Barrier failed before training with error {e}', exc_info=True) + break - # 学习策略 + # 2.4. Training Phase if not not_enough_data: - # Learner 将在一次迭代中训练 update_per_collect 次 + update_per_collect = cfg.policy.update_per_collect for i in range(update_per_collect): train_data_multi_task = [] - envstep_multi_task = 0 - for cfg, collector, replay_buffer in zip(cfgs, collectors, game_buffers): - envstep_multi_task += collector.envstep - batch_size = cfg.policy.batch_size[cfg.policy.task_id] + envstep_multi_task = sum(c.envstep for c in collectors) + + for task_cfg, replay_buffer in zip(cfgs, game_buffers): + batch_size = task_cfg.policy.batch_size[task_cfg.policy.task_id] if replay_buffer.get_num_of_transitions() > batch_size: - if cfg.policy.buffer_reanalyze_freq >= 1: - # 在一个训练 epoch 中重新分析缓冲区 次 + # Case 2: freq >= 1 -> Reanalyze `freq` times per collection cycle (spread across updates). 
+ if task_cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // task_cfg.policy.buffer_reanalyze_freq if (i % reanalyze_interval == 0 and - replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > - int(reanalyze_batch_size / cfg.policy.reanalyze_partition)): + replay_buffer.get_num_of_transitions() // task_cfg.policy.num_unroll_steps > + int(task_cfg.policy.reanalyze_batch_size / task_cfg.policy.reanalyze_partition)): with EasyTimer() as timer: - # 每个重新分析过程将重新分析 个序列 - replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + replay_buffer.reanalyze_buffer(task_cfg.policy.reanalyze_batch_size, policy) buffer_reanalyze_count += 1 - logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') - logging.info(f'Buffer reanalyze time: {timer.value}') + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}, Time: {timer.value:.2f}s') + # Sample data and append task_id for multi-task learning. train_data = replay_buffer.sample(batch_size, policy) - # 追加 task_id,以便在训练时区分任务 - train_data.append(cfg.policy.task_id) + train_data.append(task_cfg.policy.task_id) train_data_multi_task.append(train_data) else: logging.warning( - f'The data in replay_buffer is not sufficient to sample a mini-batch: ' - f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + f"Skipping training for task {task_cfg.policy.task_id}: insufficient data. " + f"Required: {batch_size}, Available: {replay_buffer.get_num_of_transitions()}" ) - break if train_data_multi_task: - # 在训练时,DDP 会自动同步梯度和参数 - log_vars = learner.train(train_data_multi_task, envstep_multi_task) + # DDP handles gradient synchronization automatically. + learner.train(train_data_multi_task, envstep_multi_task) - # 同步训练前所有 rank 的准备状态 + # Synchronize after each training step to maintain consistency. try: dist.barrier() - logging.info(f'Rank {rank}: Passed barrier during training') except Exception as e: - logging.error(f'Rank {rank}: Barrier failed with error {e}') - break # 或者进行其他错误处理 - - # TODO: 可选:终止进程 - import sys - sys.exit(0) + logging.error(f'Rank {rank}: Barrier failed during training step with error {e}', exc_info=True) + break + else: + logging.warning(f"Rank {rank}: Skipping training cycle due to insufficient data in one or more buffers.") train_epoch += 1 policy.recompute_pos_emb_diff_and_clear_cache() - # 同步所有 Rank,确保所有 Rank 都完成了训练 + # 2.5. 
Post-Training Synchronization and Termination Check try: dist.barrier() - logging.info(f'Rank {rank}: Passed barrier after training') - except Exception as e: - logging.error(f'Rank {rank}: Barrier failed with error {e}') - break # 或者进行其他错误处理 - - # 检查是否需要终止训练 - try: - # 收集本地的 envsteps - local_envsteps = [collector.envstep for collector in collectors] - - # 收集所有进程的 envsteps - total_envsteps: List[Optional[int]] = [None for _ in range(world_size)] - dist.all_gather_object(total_envsteps, local_envsteps) - - # 将所有 envsteps 拼接在一起进行检查 - all_envsteps = torch.cat([ - torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps - ]) - max_envstep_reached = torch.all(all_envsteps >= max_env_step) - - # 收集所有进程的 train_iter - global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) - all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] - dist.all_gather(all_train_iters, global_train_iter) - - max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter) - - if max_envstep_reached.item() or max_train_iter_reached.item(): - logging.info(f'Rank {rank}: Termination condition met') - dist.barrier() # 确保所有进程同步 - break except Exception as e: - logging.error(f'Rank {rank}: Termination check failed with error {e}') - break # 或者进行其他错误处理 + logging.error(f'Rank {rank}: Barrier failed after training cycle with error {e}', exc_info=True) + break learner.call_hook('after_run') + logging.info(f"Rank {rank}: Training finished.") return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_segment.py b/lzero/entry/train_unizero_segment.py index c1ed74b16..0559934c0 100644 --- a/lzero/entry/train_unizero_segment.py +++ b/lzero/entry/train_unizero_segment.py @@ -154,7 +154,9 @@ def train_unizero_segment( collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) # Evaluate policy performance - if evaluator.should_eval(learner.train_iter): + # if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): + if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) if stop: break diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index 525ac0812..99b22b852 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -1,263 +1,816 @@ +# -*- coding: utf-8 -*- +""" +Optimized and refactored utility code for reinforcement learning models, +focusing on clarity, professionalism, efficiency, and extensibility. +""" + +# ============================================================================== +# Imports +# ============================================================================== +from __future__ import annotations + +import logging +import math import os -from typing import Optional, Callable, Union, List, Tuple +import re +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np import psutil import torch import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F from pympler.asizeof import asizeof from tensorboardX import SummaryWriter +# ============================================================================== +# Placeholder Types for External Dependencies +# +# To ensure type hints work without having the full definitions of these complex +# external classes, we define them as `Any`. 
+# ============================================================================== +EasyDict = Any +Policy = Any +RandomPolicy = Any +ISerialCollector = Any +BaseEnvManager = Any +IBuffer = Any +GameBuffer = Any + + +# ============================================================================== +# Mathematical & Tensor Utilities +# ============================================================================== + +def symlog(x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Applies the symlog transformation to a tensor, which is useful for + normalizing target values with large magnitude differences. + The transformation is defined as: symlog(x) = sign(x) * log(|x| + 1). + + Arguments: + - x (:obj:`torch.Tensor`): The input tensor. + + Returns: + - torch.Tensor: The tensor after applying the symlog transformation. + """ + return torch.sign(x) * torch.log(torch.abs(x) + 1) + + +def inv_symlog(x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Applies the inverse of the symlog transformation to a tensor, restoring + the original scale of the values. + The transformation is defined as: inv_symlog(x) = sign(x) * (exp(|x|) - 1). + + Arguments: + - x (:obj:`torch.Tensor`): The input tensor in symlog space. + + Returns: + - torch.Tensor: The tensor restored to its original scale. + """ + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) + + +def initialize_zeros_batch( + observation_shape: Union[int, List[int], Tuple[int, ...]], + batch_size: int, + device: str +) -> torch.Tensor: + """ + Overview: + Initializes a zeros tensor for a batch of observations based on the + provided shape. This is commonly used to prepare initial input for models + like UniZero. + + Arguments: + - observation_shape (:obj:`Union[int, List[int], Tuple[int, ...]]`): The shape of a single observation. + - batch_size (:obj:`int`): The number of observations in the batch. + - device (:obj:`str`): The device to store the tensor on (e.g., 'cpu', 'cuda'). + + Returns: + - torch.Tensor: A zeros tensor with the shape [batch_size, *observation_shape]. + """ + if isinstance(observation_shape, (list, tuple)): + shape = (batch_size, *observation_shape) + elif isinstance(observation_shape, int): + shape = (batch_size, observation_shape) + else: + raise TypeError( + f"observation_shape must be an int, list, or tuple, but got {type(observation_shape).__name__}" + ) + return torch.zeros(shape, device=device) + + +# ============================================================================== +# LoRA (Low-Rank Adaptation) Utilities +# ============================================================================== + +# A compiled regex pattern to efficiently detect LoRA-related parameters. +# It matches parameter names ending with: +# - .lora_A or .lora_B (for LoRA weights) +# - .adapter_scales.{digit}.logit (for learnable scale parameters) +_LORA_PAT = re.compile(r"\.(?:lora_[AB]|adapter_scales\.\d+\.logit)$") + + +def _is_lora_param(name: str) -> bool: + """A helper function to check if a parameter name matches the LoRA pattern.""" + return bool(_LORA_PAT.search(name)) + + +def freeze_non_lora_parameters( + module: nn.Module, + freeze: bool = True, + *, + verbose: bool = False, +) -> Tuple[int, int]: + """ + Overview: + Freezes or un-freezes all parameters in a module that are not identified + as LoRA-related parameters. This is useful for curriculum learning stages + where the backbone model is frozen and only LoRA adapters are trained. 
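A quick, self-contained check of the helpers introduced above: `symlog`/`inv_symlog` should round-trip (up to float32 precision), and `_LORA_PAT` matches adapter parameters only when they carry a submodule prefix. The parameter names below are illustrative.

```Python
# Round-trip the symlog transform and exercise the LoRA name pattern.
import re
import torch

_LORA_PAT = re.compile(r"\.(?:lora_[AB]|adapter_scales\.\d+\.logit)$")

x = torch.tensor([-1000.0, -1.0, 0.0, 1.0, 1000.0])
y = torch.sign(x) * torch.log(torch.abs(x) + 1)          # symlog
x_back = torch.sign(y) * (torch.exp(torch.abs(y)) - 1)   # inv_symlog
assert torch.allclose(x, x_back, rtol=1e-3, atol=1e-3)   # float32 round-trip

for name in ("blocks.0.attn.q_proj.lora_A",
             "blocks.0.attn.q_proj.weight",
             "adapter_scales.3.logit"):
    print(name, bool(_LORA_PAT.search(name)))
# -> True, False, False (the last name lacks the leading '.' the pattern requires)
```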
+ + Arguments: + - module (:obj:`nn.Module`): The PyTorch module to process (e.g., a transformer). + - freeze (:obj:`bool`): If True, sets `requires_grad=False` for non-LoRA parameters. + If False, sets `requires_grad=True` for non-LoRA parameters. + - verbose (:obj:`bool`): If True, prints a summary of trainable and frozen parameters. + + Returns: + - Tuple[int, int]: A tuple containing the number of frozen parameters and trainable parameters. + """ + n_frozen = 0 + n_trainable = 0 + + for name, param in module.named_parameters(): + if _is_lora_param(name): + # LoRA-related parameters should always be trainable. + param.requires_grad = True + n_trainable += 1 + else: + # All other parameters are frozen or unfrozen based on the `freeze` flag. + param.requires_grad = not freeze + if param.requires_grad: + n_trainable += 1 + else: + n_frozen += 1 + + if verbose: + total = n_frozen + n_trainable + # Ensure total is not zero to avoid division by zero error. + percentage_trainable = (n_trainable / total * 100) if total > 0 else 0 + print( + f"[freeze_non_lora] Trainable: {n_trainable}/{total} ({percentage_trainable:.1f}%), " + f"Frozen: {n_frozen}" + ) + return n_frozen, n_trainable + + +# ============================================================================== +# Task & Curriculum Learning Utilities +# ============================================================================== + +def compute_task_weights( + task_returns: Dict[str, float], + option: str = "symlog", + epsilon: float = 1e-6, + temperature: float = 1.0, + use_softmax: bool = False, + reverse: bool = False, + clip_min: float = 1e-2, + clip_max: float = 1.0, +) -> Dict[str, float]: + """ + Overview: + Calculates sampling weights for different tasks based on their returns (e.g., rewards or losses). + This function supports various normalization methods, softmax-based distribution, + proportional/inverse weighting, and weight clipping. + + Arguments: + - task_returns (:obj:`Dict[str, float]`): A dictionary mapping task IDs to their return values. + - option (:obj:`str`): Normalization method. One of ["symlog", "max-min", "run-max-min", "rank", "none"]. + - epsilon (:obj:`float`): A small value to prevent division by zero. + - temperature (:obj:`float`): A temperature parameter to control the sharpness of the weight distribution. + - use_softmax (:obj:`bool`): If True, use softmax to compute weights; otherwise, use direct normalization. + - reverse (:obj:`bool`): If True, weights are inversely proportional to returns; otherwise, directly proportional. + - clip_min (:obj:`float`): The minimum value to clip the final weights to. + - clip_max (:obj:`float`): The maximum value to clip the final weights to. + + Returns: + - Dict[str, float]: A dictionary mapping task IDs to their computed weights. + """ + if not task_returns: + return {} + + task_ids = list(task_returns.keys()) + returns_tensor = torch.tensor(list(task_returns.values()), dtype=torch.float32) + + # Step 1: Normalize the returns based on the chosen option. + scaled_returns: torch.Tensor + if option == "symlog": + scaled_returns = symlog(returns_tensor) + elif option == "max-min": + min_val, max_val = returns_tensor.min(), returns_tensor.max() + scaled_returns = (returns_tensor - min_val) / (max_val - min_val + epsilon) + elif option == "run-max-min": + # Use function attributes to maintain state across calls, avoiding global variables. 
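Toy usage of the freezing logic above on a hypothetical two-module block: only parameters whose names end in `.lora_A`/`.lora_B` stay trainable. The snippet inlines the name test instead of importing `freeze_non_lora_parameters`, so it is a sketch of the behavior, not the function itself.

```Python
# Freeze everything except LoRA adapter parameters on a toy module.
import re
import torch
import torch.nn as nn

_LORA_PAT = re.compile(r"\.(?:lora_[AB]|adapter_scales\.\d+\.logit)$")

class Adapter(nn.Module):
    def __init__(self):
        super().__init__()
        self.lora_A = nn.Parameter(torch.zeros(4, 2))
        self.lora_B = nn.Parameter(torch.zeros(2, 4))

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)   # backbone parameters -> frozen
        self.adapter = Adapter()      # LoRA parameters -> kept trainable

block = Block()
# Mirrors freeze_non_lora_parameters(block, freeze=True): LoRA params stay trainable.
for name, param in block.named_parameters():
    param.requires_grad = bool(_LORA_PAT.search(name))

print({n: p.requires_grad for n, p in block.named_parameters()})
# {'proj.weight': False, 'proj.bias': False, 'adapter.lora_A': True, 'adapter.lora_B': True}
```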
+ compute_task_weights.RUNNING_MAX = max(compute_task_weights.RUNNING_MAX, returns_tensor.max().item()) + compute_task_weights.RUNNING_MIN = min(compute_task_weights.RUNNING_MIN, returns_tensor.min().item()) + scaled_returns = (returns_tensor - compute_task_weights.RUNNING_MIN) / \ + (compute_task_weights.RUNNING_MAX - compute_task_weights.RUNNING_MIN + epsilon) + elif option == "rank": + sorted_indices = torch.argsort(returns_tensor) + ranks = torch.empty_like(returns_tensor) + # Ranks are from 1 to N. + ranks[sorted_indices] = torch.arange(1, len(returns_tensor) + 1, dtype=torch.float32) + scaled_returns = ranks + elif option == "none": + scaled_returns = returns_tensor + else: + raise ValueError(f"Unsupported normalization option: {option}") + + # Step 2: Determine if weights should be proportional or inversely proportional to returns. + if reverse: + # Inverse proportion: smaller return -> higher weight. + raw_weights = 1.0 / (scaled_returns + epsilon) + else: + # Direct proportion: higher return -> higher weight. + raw_weights = scaled_returns + + # Step 3: Calculate final weights using either softmax or direct normalization. + final_weights: np.ndarray + safe_temperature = max(temperature, epsilon) + if use_softmax: + # Softmax provides a smooth distribution, often used with inverse weights. + # A higher beta (lower temperature) makes the distribution sharper. + beta = 1.0 / safe_temperature + # The sign depends on whether we want to favor high or low raw_weights. + # If reverse=True, raw_weights are high for low returns. We want to sample these more. + # Softmax(logits) gives higher probability to higher logits. + # So, logits should be proportional to the desired sampling probability. + logits = raw_weights if reverse else -raw_weights + final_weights = F.softmax(logits * beta, dim=0).numpy() + else: + # Direct normalization with temperature scaling. + scaled_weights = raw_weights**(1 / safe_temperature) + total_weight = scaled_weights.sum() + normalized_weights = scaled_weights / (total_weight + epsilon) + final_weights = normalized_weights.numpy() + + # Step 4: Clip weights to the desired range and create the result dictionary. + weights_dict = { + task_id: np.clip(weight, clip_min, clip_max) + for task_id, weight in zip(task_ids, final_weights) + } + + return weights_dict + +# Initialize state for the 'run-max-min' option as function attributes. +compute_task_weights.RUNNING_MAX = -float('inf') +compute_task_weights.RUNNING_MIN = float('inf') -import torch -import numpy as np -import torch -import torch.nn.functional as F -import matplotlib.pyplot as plt class TemperatureScheduler: - def __init__(self, initial_temp: float, final_temp: float, threshold_steps: int, mode: str = 'linear'): - """ - 温度调度器,用于根据当前训练步数逐渐调整温度。 + """ + Overview: + A scheduler to gradually adjust a temperature value over a specified number + of training steps. This can be used for exploration or weighting schemes. - Args: - initial_temp (float): 初始温度值。 - final_temp (float): 最终温度值。 - threshold_steps (int): 温度衰减到最终温度所需的训练步数。 - mode (str): 衰减方式,可选 'linear' 或 'exponential'。默认 'linear'。 - """ + Arguments: + - initial_temp (:obj:`float`): The starting temperature. + - final_temp (:obj:`float`): The target temperature to be reached after `threshold_steps`. + - threshold_steps (:obj:`int`): The number of steps over which the temperature will anneal. + - mode (:obj:`str`): The annealing mode, either 'linear' or 'exponential'. 
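A worked example of the `option="rank"` branch with `use_softmax=False`, `reverse=False`, and `temperature=1`, using illustrative per-task losses: returns are replaced by their ranks, normalized to sum to one, then clipped.

```Python
# Rank-based task weighting (the option="rank" branch), worked by hand.
import numpy as np
import torch

task_returns = {0: 0.1, 1: 0.5, 2: 0.3}          # e.g. per-task obs_loss (illustrative)
returns = torch.tensor(list(task_returns.values()))

ranks = torch.empty_like(returns)
ranks[torch.argsort(returns)] = torch.arange(1, len(returns) + 1, dtype=torch.float32)
weights = ranks / ranks.sum()                     # temperature == 1 leaves the ranks as-is
weights = np.clip(weights.numpy(), 1e-2, 1.0)     # clip_min=1e-2, clip_max=1.0

print(dict(zip(task_returns, weights)))           # ~ {0: 0.167, 1: 0.5, 2: 0.333}
```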
+ """ + + def __init__(self, initial_temp: float, final_temp: float, threshold_steps: int, mode: str = 'linear'): + if mode not in ['linear', 'exponential']: + raise ValueError("Mode must be 'linear' or 'exponential'.") self.initial_temp = initial_temp self.final_temp = final_temp - self.threshold_steps = threshold_steps - assert mode in ['linear', 'exponential'], "Mode must be 'linear' or 'exponential'." + self.threshold_steps = max(1, threshold_steps) # Avoid division by zero self.mode = mode def get_temperature(self, current_step: int) -> float: """ - 根据当前步数计算温度。 + Overview: + Calculates the temperature for the given training step. - Args: - current_step (int): 当前的训练步数。 + Arguments: + - current_step (:obj:`int`): The current training step. Returns: - float: 当前温度值。 + - float: The calculated temperature for the current step. """ if current_step >= self.threshold_steps: return self.final_temp + progress = current_step / self.threshold_steps + if self.mode == 'linear': - temp = self.initial_temp - (self.initial_temp - self.final_temp) * progress - elif self.mode == 'exponential': - # 指数衰减,确保温度逐渐接近 final_temp - decay_rate = np.log(self.final_temp / self.initial_temp) / self.threshold_steps - temp = self.initial_temp * np.exp(decay_rate * current_step) - temp = max(temp, self.final_temp) - return temp + return self.initial_temp - (self.initial_temp - self.final_temp) * progress + else: # 'exponential' + # Exponential decay from initial_temp to final_temp + # T(t) = T_initial * (T_final / T_initial)^(t / N) + if self.initial_temp <= 0: + raise ValueError("Initial temperature must be positive for exponential decay.") + scale = self.final_temp / self.initial_temp + return self.initial_temp * (scale**progress) + + +def tasks_per_stage(unsolved: int, remain_lora: int) -> int: + """ + Overview: + Calculates the number of tasks to assign per LoRA adapter stage. + It's the ceiling of the division of unsolved tasks by remaining adapters. + + Arguments: + - unsolved (:obj:`int`): The number of tasks yet to be solved. + - remain_lora (:obj:`int`): The number of available LoRA adapters. + + Returns: + - int: The number of tasks to be handled in the current stage, at least 1. + """ + return max(1, math.ceil(unsolved / max(remain_lora, 1))) + + +def compute_unizero_mt_normalized_stats( + eval_returns: Dict[int, float], + human_scores: Dict[int, float], + random_scores: Dict[int, float] +) -> Tuple[Optional[float], Optional[float]]: + """ + Overview: + Calculates the Human-Normalized Mean and Median for a set of evaluation returns. + If no valid returns are provided, it returns (None, None). + + Arguments: + - eval_returns (:obj:`Dict[int, float]`): A dictionary of evaluation returns per task ID. + - human_scores (:obj:`Dict[int, float]`): A dictionary of human expert scores per task ID. + - random_scores (:obj:`Dict[int, float]`): A dictionary of random policy scores per task ID. -def is_ddp_enabled(): + Returns: + - Tuple[Optional[float], Optional[float]]: A tuple containing the human-normalized mean and median. 
+ """ + normalized = [] + for tid, ret in eval_returns.items(): + if ret is None or tid not in human_scores or tid not in random_scores: + continue + denom = human_scores[tid] - random_scores[tid] + if denom == 0: + continue + normalized.append((ret - random_scores[tid]) / denom) + + if not normalized: + return None, None + + arr = np.asarray(normalized, dtype=np.float32) + return float(arr.mean()), float(np.median(arr)) + + +def allocate_batch_size( + cfgs: List[EasyDict], + game_buffers: List[GameBuffer], + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: """ - Check if Distributed Data Parallel (DDP) is enabled by verifying if - PyTorch's distributed package is available and initialized. + Overview: + Allocates batch sizes for different tasks inversely proportional to the + number of collected episodes for each task. It also dynamically clips + the batch size range to improve training stability. + + Arguments: + - cfgs (:obj:`List[EasyDict]`): A list of configuration objects for each task. + - game_buffers (:obj:`List[GameBuffer]`): A list of replay buffer instances for each task. + - alpha (:obj:`float`): A hyperparameter to control the degree of inverse proportionality. + - clip_scale (:obj:`int`): A scaling factor to determine the min/max batch size clip range. + + Returns: + - List[int]: A list of allocated batch sizes for each task. + """ + # This function assumes a DDP environment. + if not dist.is_available() or not dist.is_initialized(): + # Fallback for non-DDP environment if needed, though the logic is DDP-centric. + logging.warning("allocate_batch_size is designed for DDP and may not work as expected.") + world_size = 1 + rank = 0 + else: + world_size = dist.get_world_size() + rank = dist.get_rank() + + # Extract the number of collected episodes from each local buffer. + local_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] + + # Gather episode counts from all ranks. + all_task_episodes_list = [None for _ in range(world_size)] + dist.all_gather_object(all_task_episodes_list, local_episodes) + + # Flatten the list of lists into a single list of episode counts for all tasks. + all_task_episodes = [ep for sublist in all_task_episodes_list for ep in sublist] + + if rank == 0: + logging.info(f'All task collected episodes: {all_task_episodes}') + + # Calculate weights inversely proportional to episode counts. + # Add 1 to avoid division by zero for new tasks. + inv_episodes = np.array([1.0 / (episodes + 1) for episodes in all_task_episodes]) + inv_sum = np.sum(inv_episodes) + + # Total batch size is assumed to be consistent across configs. + total_batch_size = cfgs[0].policy.total_batch_size + + # Define dynamic clipping range for batch sizes. + avg_batch_size = total_batch_size / len(all_task_episodes) + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # Calculate batch sizes based on weights, apply alpha for smoothing. + task_weights = (inv_episodes / inv_sum)**alpha + batch_sizes = total_batch_size * task_weights + + # Clip and convert to integers. 
+ batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + batch_sizes = [int(size) for size in batch_sizes] + + return batch_sizes + + +# ============================================================================== +# Distributed Data Parallel (DDP) Utilities +# ============================================================================== + +def is_ddp_enabled() -> bool: + """ + Overview: + Checks if the environment is set up for Distributed Data Parallel (DDP) training. + + Returns: + - bool: True if `torch.distributed` is available and initialized, False otherwise. """ return dist.is_available() and dist.is_initialized() -def ddp_synchronize(): + +def ddp_synchronize() -> None: """ - Perform a barrier synchronization across all processes in DDP mode. - Ensures all processes reach this point before continuing. + Overview: + Performs a barrier synchronization across all processes in a DDP group. + This ensures that all processes reach this point before any of them proceed. """ if is_ddp_enabled(): dist.barrier() + def ddp_all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: """ - Perform an all-reduce operation (sum) on the given tensor across - all processes in DDP mode. Returns the reduced tensor. + Overview: + Performs an all-reduce operation (sum) on a given tensor across all + processes in the DDP group. Arguments: - - tensor (:obj:`torch.Tensor`): The input tensor to be reduced. + - tensor (:obj:`torch.Tensor`): The tensor to be reduced. Returns: - - torch.Tensor: The reduced tensor, summed across all processes. + - torch.Tensor: The reduced tensor, with values summed across all processes. """ if is_ddp_enabled(): dist.all_reduce(tensor, op=dist.ReduceOp.SUM) return tensor -def calculate_update_per_collect(cfg: 'EasyDict', new_data: List[List[torch.Tensor]], world_size: int = 1) -> int: + +# ============================================================================== +# Reinforcement Learning Workflow Utilities +# ============================================================================== + +def calculate_update_per_collect( + cfg: EasyDict, + new_data: List[List[torch.Tensor]], + world_size: int = 1 +) -> int: """ - Calculate the number of updates to perform per data collection in a - Distributed Data Parallel (DDP) setting. This ensures that all GPUs - compute the same `update_per_collect` value, synchronized across processes. + Overview: + Calculates the number of training updates to perform per data collection cycle. + In a DDP setting, it synchronizes transition counts across all GPUs to ensure + a consistent `update_per_collect` value. Arguments: - - cfg: Configuration object containing policy settings. - - new_data (List[List[torch.Tensor]]): The newly collected data segments. - - world_size (int): The total number of processes. + - cfg (:obj:`EasyDict`): The configuration object containing policy settings. + It's expected to have `cfg.policy.update_per_collect`, + `cfg.policy.replay_ratio`, etc. + - new_data (:obj:`List[List[torch.Tensor]]`): The newly collected data segments. + - world_size (:obj:`int`): The total number of DDP processes. Returns: - - int: The number of updates to perform per collection. + - int: The number of updates to perform. """ - # Retrieve the update_per_collect setting from the configuration - update_per_collect = cfg.policy.update_per_collect - - if update_per_collect is None: - # If update_per_collect is not explicitly set, calculate it based on - # the number of collected transitions and the replay ratio. 
- - # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game. - # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps. - collected_transitions_num = sum( - min(len(game_segment), cfg.policy.game_segment_length) - for game_segment in new_data[0] + update_per_collect = cfg.policy.get('update_per_collect') + + if update_per_collect is not None: + return update_per_collect + + # If not explicitly set, calculate based on replay ratio. + # Note: A game segment's length can be less than `game_segment_length` if it's the + # final segment of an episode. + collected_transitions_num = sum( + min(len(game_segment), cfg.policy.game_segment_length) + for game_segment in new_data[0] + ) + + if torch.cuda.is_available() and world_size > 1: + # In DDP, synchronize the transition count across all GPUs. + collected_transitions_tensor = torch.tensor( + collected_transitions_num, dtype=torch.int64, device='cuda' ) + total_collected_transitions = ddp_all_reduce_sum( + collected_transitions_tensor + ).item() + updates = int(total_collected_transitions * cfg.policy.replay_ratio) + else: + # In a single-process setup. + updates = int(collected_transitions_num * cfg.policy.replay_ratio) - if torch.cuda.is_available() and world_size > 1: - # Convert the collected transitions count to a GPU tensor for DDP operations. - collected_transitions_tensor = torch.tensor( - collected_transitions_num, dtype=torch.int64, device='cuda' - ) - - # Synchronize the collected transitions count across all GPUs using all-reduce. - total_collected_transitions = ddp_all_reduce_sum( - collected_transitions_tensor - ).item() - - # Calculate update_per_collect based on the total synchronized transitions count. - update_per_collect = int(total_collected_transitions * cfg.policy.replay_ratio) + return max(1, updates) # Ensure at least one update. - # Ensure the computed update_per_collect is positive. - assert update_per_collect > 0, "update_per_collect must be positive" - else: - # If not using DDP, calculate update_per_collect directly from the local count. - update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) - return update_per_collect - -def initialize_zeros_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str) -> torch.Tensor: +def initialize_pad_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str, pad_token_id: int = 0) -> torch.Tensor: """ Overview: - Initialize a zeros tensor for batch observations based on the shape. This function is used to initialize the UniZero model input. + Initialize a tensor filled with `pad_token_id` for batch observations. + This function is designed to be flexible and can handle both textual + and non-textual observations: + + - For textual observations: it initializes `input_ids` with padding tokens, + ensuring consistent sequence lengths within a batch. + - For non-textual observations: it provides a convenient way to fill + observation tensors with a default of 0, + ensuring shape compatibility and preventing uninitialized values. Arguments: - observation_shape (:obj:`Union[int, List[int], Tuple[int]]`): The shape of the observation tensor. - batch_size (:obj:`int`): The batch size. - device (:obj:`str`): The device to store the tensor. 
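The replay-ratio arithmetic above, worked through with assumed numbers: two ranks, eight full-length segments each, `game_segment_length=400`, and `replay_ratio=0.25`.

```Python
# update_per_collect from synchronized transition counts and the replay ratio.
game_segment_length = 400
replay_ratio = 0.25

per_rank_transitions = 8 * min(400, game_segment_length)   # 3200 on each rank
total_transitions = 2 * per_rank_transitions               # all_reduce(SUM) over 2 ranks -> 6400
update_per_collect = max(1, int(total_transitions * replay_ratio))
print(update_per_collect)                                   # -> 1600
```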
+ - pad_token_id (:obj:`int`): The token ID (or placeholder value) used for padding. Returns: - - zeros (:obj:`torch.Tensor`): The zeros tensor. + - padded_tensor (:obj:`torch.Tensor`): A tensor of the given shape, + filled with `pad_token_id`. """ - if isinstance(observation_shape, (list,tuple)): + if isinstance(observation_shape, (list, tuple)): shape = [batch_size, *observation_shape] elif isinstance(observation_shape, int): shape = [batch_size, observation_shape] else: - raise TypeError(f"observation_shape must be either an int, a list, or a tuple, but got {type(observation_shape).__name__}") + raise TypeError(f"observation_shape must be int, list, or tuple, but got {type(observation_shape).__name__}") - return torch.zeros(shape).to(device) + return torch.full(shape, fill_value=pad_token_id, dtype=torch.float32, device=device) if pad_token_id == 0 else torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device) def random_collect( - policy_cfg: 'EasyDict', # noqa - policy: 'Policy', # noqa - RandomPolicy: 'Policy', # noqa - collector: 'ISerialCollector', # noqa - collector_env: 'BaseEnvManager', # noqa - replay_buffer: 'IBuffer', # noqa - postprocess_data_fn: Optional[Callable] = None -) -> None: # noqa - assert policy_cfg.random_collect_episode_num > 0 + policy_cfg: EasyDict, + policy: Policy, + RandomPolicy: Callable, + collector: ISerialCollector, + collector_env: BaseEnvManager, + replay_buffer: IBuffer, + postprocess_data_fn: Optional[Callable] = None +) -> None: + """ + Overview: + Performs an initial data collection phase using a random policy to populate + the replay buffer before training begins. + + Arguments: + - policy_cfg (:obj:`EasyDict`): Configuration for the policy. + - policy (:obj:`Policy`): The main training policy instance. + - RandomPolicy (:obj:`Callable`): A constructor or class for creating a random policy. + - collector (:obj:`ISerialCollector`): The data collector instance. + - collector_env (:obj:`BaseEnvManager`): The environment manager. + - replay_buffer (:obj:`IBuffer`): The replay buffer to store collected data. + - postprocess_data_fn (:obj:`Optional[Callable]`): An optional function to process data after collection. + """ + random_collect_episode_num = policy_cfg.get('random_collect_episode_num', 0) + if random_collect_episode_num <= 0: + return random_policy = RandomPolicy(cfg=policy_cfg, action_space=collector_env.env_ref.action_space) - # set the policy to random policy collector.reset_policy(random_policy.collect_mode) - # set temperature for visit count distributions according to the train_iter, - # please refer to Appendix D in MuZero paper for details. - collect_kwargs = {'temperature': 1, 'epsilon': 0.0} + # Use neutral MCTS parameters for random collection. + collect_kwargs = {'temperature': 1.0, 'epsilon': 0.0} - # Collect data by default config n_sample/n_episode. - new_data = collector.collect(n_episode=policy_cfg.random_collect_episode_num, train_iter=0, - policy_kwargs=collect_kwargs) + new_data = collector.collect( + n_episode=random_collect_episode_num, + train_iter=0, + policy_kwargs=collect_kwargs + ) - if postprocess_data_fn is not None: + if postprocess_data_fn: new_data = postprocess_data_fn(new_data) - # save returned new_data collected by the collector replay_buffer.push_game_segments(new_data) - # remove the oldest data if the replay buffer is full. replay_buffer.remove_oldest_data_to_fit() - # restore the policy + # Restore the original policy to the collector. 
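Usage sketch for the padding helper above, with a condensed inline restatement of its logic so the snippet runs without the repo installed: a `pad_token_id` of 0 yields a float zero tensor for image-style observations, while a non-zero id yields a long tensor of padding token ids. Shapes and the pad id below are illustrative.

```Python
# Condensed restatement of initialize_pad_batch and a usage example.
import torch

def initialize_pad_batch(observation_shape, batch_size, device, pad_token_id=0):
    if isinstance(observation_shape, (list, tuple)):
        shape = [batch_size, *observation_shape]
    else:
        shape = [batch_size, observation_shape]
    # dtype switches on the pad id: 0 -> float "zeros" batch, otherwise integer token ids.
    dtype = torch.float32 if pad_token_id == 0 else torch.long
    return torch.full(shape, fill_value=pad_token_id, dtype=dtype, device=device)

obs = initialize_pad_batch((3, 64, 64), batch_size=2, device='cpu')           # float zeros
ids = initialize_pad_batch(512, batch_size=2, device='cpu', pad_token_id=1)   # long pad ids
print(obs.shape, obs.dtype, ids.shape, ids.dtype)
# torch.Size([2, 3, 64, 64]) torch.float32 torch.Size([2, 512]) torch.int64
```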
collector.reset_policy(policy.collect_mode) -def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter, task_id=0) -> None: +# ============================================================================== +# Logging Utilities +# ============================================================================== + +def log_module_trainable_status( + module: nn.Module, + module_name: str, + logger: logging.Logger +) -> None: """ Overview: - Log the memory usage of the buffer and the current process to TensorBoard. + Logs the detailed trainable/frozen status of all parameters within a given module. + Arguments: - - train_iter (:obj:`int`): The current training iteration. - - buffer (:obj:`GameBuffer`): The game buffer. - - writer (:obj:`SummaryWriter`): The TensorBoard writer. + - module (:obj:`nn.Module`): The module to inspect (e.g., a ViT Encoder). + - module_name (:obj:`str`): The name of the module for logging purposes. + - logger (:obj:`logging.Logger`): The logger instance to use for output. """ - # "writer is None" means we are in a slave process in the DDP setup. - if writer is not None: - writer.add_scalar(f'Buffer/num_of_all_collected_episodes_{task_id}', buffer.num_of_collected_episodes, train_iter) - writer.add_scalar(f'Buffer/num_of_game_segments_{task_id}', len(buffer.game_segment_buffer), train_iter) - writer.add_scalar(f'Buffer/num_of_transitions_{task_id}', len(buffer.game_segment_game_pos_look_up), train_iter) - - game_segment_buffer = buffer.game_segment_buffer + logger.info(f"--- Parameter Status Details for Module: '{module_name}' ---") - # Calculate the amount of memory occupied by self.game_segment_buffer (in bytes). - buffer_memory_usage = asizeof(game_segment_buffer) + total_params = 0 + trainable_params = 0 - # Convert buffer_memory_usage to megabytes (MB). - buffer_memory_usage_mb = buffer_memory_usage / (1024 * 1024) + param_list = list(module.named_parameters()) + if not param_list: + logger.info(" - No parameters found in this module.") + return - # Record the memory usage of self.game_segment_buffer to TensorBoard. - writer.add_scalar(f'Buffer/memory_usage/game_segment_buffer_{task_id}', buffer_memory_usage_mb, train_iter) + for name, param in param_list: + total_params += param.numel() + status = "Trainable" if param.requires_grad else "Frozen" + logger.info(f" - {name:<60} | Shape: {str(param.shape):<25} | Status: {status}") + if param.requires_grad: + trainable_params += param.numel() - # Get the amount of memory currently used by the process (in bytes). - process = psutil.Process(os.getpid()) - process_memory_usage = process.memory_info().rss + logger.info(f"--- Summary for Module: '{module_name}' ---") + logger.info(f" - Total Parameters: {total_params:,}") + logger.info(f" - Trainable Parameters: {trainable_params:,}") + if total_params > 0: + percentage = 100 * trainable_params / total_params + logger.info(f" - Trainable Percentage: {percentage:.4f}%") + logger.info("-" * (len(module_name) + 40)) - # Convert process_memory_usage to megabytes (MB). - process_memory_usage_mb = process_memory_usage / (1024 * 1024) - - # Record the memory usage of the process to TensorBoard. - writer.add_scalar(f'Buffer/memory_usage/process_{task_id}', process_memory_usage_mb, train_iter) +def log_param_statistics(model: nn.Module, logger: logging.Logger) -> None: + """ + Overview: + Logs a concise summary of the number and size of trainable versus total + parameters in a model. 
-def log_buffer_run_time(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: + Arguments: + - model (:obj:`nn.Module`): The model to analyze. + - logger (:obj:`logging.Logger`): The logger instance for output. + """ + n_tensors_total = sum(1 for _ in model.parameters()) + n_tensors_train = sum(1 for p in model.parameters() if p.requires_grad) + + n_elems_total = sum(p.numel() for p in model.parameters()) + n_elems_train = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.info( + f'Trainable Parameters: ' + f'{n_tensors_train}/{n_tensors_total} tensors | ' + f'{n_elems_train:,}/{n_elems_total:,} elements ' + f'({n_elems_train/1e6:.2f}M / {n_elems_total/1e6:.2f}M)' + ) + + +def log_buffer_memory_usage( + train_iter: int, + buffer: GameBuffer, + writer: SummaryWriter, + task_id: int = 0 +) -> None: """ Overview: - Log the average runtime metrics of the buffer to TensorBoard. + Logs the memory usage of the replay buffer and the current process to TensorBoard. + Arguments: - train_iter (:obj:`int`): The current training iteration. - - buffer (:obj:`GameBuffer`): The game buffer containing runtime metrics. - - writer (:obj:`SummaryWriter`): The TensorBoard writer for logging metrics. - - .. note:: - "writer is None" indicates that the function is being called in a slave process in the DDP setup. + - buffer (:obj:`GameBuffer`): The replay buffer instance. + - writer (:obj:`SummaryWriter`): The TensorBoard writer. + - task_id (:obj:`int`): An optional ID to distinguish logs for different tasks. """ - if writer is not None: - sample_times = buffer.sample_times + # In DDP, only the main process should write to TensorBoard. + if writer is None: + return - if sample_times == 0: - return + prefix = f"Buffer/Task_{task_id}" + writer.add_scalar(f'{prefix}/num_collected_episodes', buffer.num_of_collected_episodes, train_iter) + writer.add_scalar(f'{prefix}/num_game_segments', len(buffer.game_segment_buffer), train_iter) + writer.add_scalar(f'{prefix}/num_transitions', len(buffer.game_segment_game_pos_look_up), train_iter) - # Calculate and log average reanalyze time. - average_reanalyze_time = buffer.compute_target_re_time / sample_times - writer.add_scalar('Buffer/average_reanalyze_time', average_reanalyze_time, train_iter) + # Calculate and log memory usage of the main buffer component. + buffer_memory_bytes = asizeof(buffer.game_segment_buffer) + buffer_memory_mb = buffer_memory_bytes / (1024 * 1024) + writer.add_scalar(f'{prefix}/memory_usage_mb/game_segment_buffer', buffer_memory_mb, train_iter) - # Calculate and log average origin search time. - average_origin_search_time = buffer.origin_search_time / sample_times - writer.add_scalar('Buffer/average_origin_search_time', average_origin_search_time, train_iter) + # Get and log total memory usage of the current process. + process = psutil.Process(os.getpid()) + process_memory_bytes = process.memory_info().rss + process_memory_mb = process_memory_bytes / (1024 * 1024) + writer.add_scalar(f'{prefix}/memory_usage_mb/process', process_memory_mb, train_iter) - # Calculate and log average reuse search time. - average_reuse_search_time = buffer.reuse_search_time / sample_times - writer.add_scalar('Buffer/average_reuse_search_time', average_reuse_search_time, train_iter) - # Calculate and log average active root number. 
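As a quick illustration of the two logging helpers defined above, the snippet below freezes part of a toy model and logs its parameter summary. The toy model and the import path are hypothetical:

```Python
import logging
import torch.nn as nn

# Hypothetical import path; adjust to the module that defines the helpers above.
from lzero.entry.utils import log_param_statistics, log_module_trainable_status

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Toy model: freeze the first layer to produce a trainable/total split.
model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 4))
for p in model[0].parameters():
    p.requires_grad = False

log_param_statistics(model, logger)                      # e.g. 2/4 tensors, 132/676 elements
log_module_trainable_status(model, "toy_model", logger)  # per-parameter Trainable/Frozen listing
```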
- average_active_root_num = buffer.active_root_num / sample_times - writer.add_scalar('Buffer/average_active_root_num', average_active_root_num, train_iter) +def log_buffer_run_time(train_iter: int, buffer: GameBuffer, writer: SummaryWriter) -> None: + """ + Overview: + Logs average runtime metrics related to buffer operations (e.g., sampling, search) + to TensorBoard. - # Reset the time records in the buffer. - buffer.reset_runtime_metrics() + Arguments: + - train_iter (:obj:`int`): The current training iteration. + - buffer (:obj:`GameBuffer`): The buffer instance containing runtime metrics. + - writer (:obj:`SummaryWriter`): The TensorBoard writer. + """ + if writer is None or buffer.sample_times == 0: + return + + sample_times = buffer.sample_times + writer.add_scalar('Buffer/avg_reanalyze_time_ms', (buffer.compute_target_re_time / sample_times) * 1000, train_iter) + writer.add_scalar('Buffer/avg_origin_search_time_ms', (buffer.origin_search_time / sample_times) * 1000, train_iter) + writer.add_scalar('Buffer/avg_reuse_search_time_ms', (buffer.reuse_search_time / sample_times) * 1000, train_iter) + writer.add_scalar('Buffer/avg_active_root_num', buffer.active_root_num / sample_times, train_iter) + + # Reset metrics after logging to prepare for the next interval. + buffer.reset_runtime_metrics() + + +# ============================================================================== +# Example Usage +# ============================================================================== +if __name__ == '__main__': + # Configure a basic logger to see output from functions with `verbose=True` + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + print("\n--- Example for `compute_task_weights` ---") + task_rewards_list = [ + {"task1": 10, "task2": 100, "task3": 1000, "task4": 500, "task5": 300}, + {"task1": 1, "task2": 10, "task3": 100, "task4": 1000, "task5": 10000}, + {"task1": 0.1, "task2": 0.5, "task3": 0.9, "task4": 5, "task5": 10}, + ] + + for i, task_rewards in enumerate(task_rewards_list, start=1): + print(f"\n--- Case {i} ---") + print(f"Original Rewards: {task_rewards}") + + # Example 1: Using 'none' normalization (proportional to raw values) + weights_none = compute_task_weights(task_rewards, option="none", use_softmax=False) + print(f"Weights (proportional to raw values): {weights_none}") + + # Example 2: Using 'symlog' normalization + weights_symlog = compute_task_weights(task_rewards, option="symlog", use_softmax=False) + print(f"Weights (with symlog normalization): {weights_symlog}") + + # Example 3: Using 'rank' normalization and softmax with inverse proportion + weights_rank_softmax = compute_task_weights(task_rewards, option="rank", use_softmax=True, reverse=True) + print(f"Weights (inverse rank with softmax): {weights_rank_softmax}") + + print("\n--- Example for `freeze_non_lora` ---") + + # ========================================================================== + # FIX: The nn.Parameter must be wrapped in an nn.Module subclass to be + # placed inside an nn.ModuleDict. 
+ # ========================================================================== + class AdapterScale(nn.Module): + """A simple nn.Module wrapper for a single learnable parameter.""" + def __init__(self): + super().__init__() + self.logit = nn.Parameter(torch.randn(1)) + + # Create a dummy model to demonstrate freezing + class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.backbone = nn.Linear(10, 10) + self.layer1 = nn.Linear(10, 10) + # Simulate LoRA parameters with correct naming + self.layer1.lora_A = nn.Parameter(torch.randn(10, 2)) + self.layer1.lora_B = nn.Parameter(torch.randn(2, 10)) + + # Correctly structure the adapter_scales using the wrapper module. + # This ensures that the value associated with key '0' is a valid nn.Module. + self.adapter_scales = nn.ModuleDict({ + '0': AdapterScale() + }) + + model = DummyModel() + print("Initial parameter status:") + log_module_trainable_status(model, "DummyModel", logging.getLogger()) + + print("\nFreezing non-LoRA parameters...") + freeze_non_lora(model, freeze=True, verbose=True) + print("\nParameter status after freezing:") + log_module_trainable_status(model, "DummyModel", logging.getLogger()) + + print("\nUn-freezing non-LoRA parameters...") + freeze_non_lora(model, freeze=False, verbose=True) + print("\nParameter status after un-freezing:") + log_module_trainable_status(model, "DummyModel", logging.getLogger()) \ No newline at end of file diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index 1097636f3..253935652 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -153,15 +153,51 @@ def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) # Indices exceeding `game_segment_length` are padded with the next segment and are not updated # in the current implementation. Therefore, we need to sample `pos_in_game_segment` within # [0, game_segment_length - num_unroll_steps] to avoid padded data. - # TODO: Consider increasing `self._cfg.game_segment_length` to ensure sampling efficiency. - # NOTE: Sample the init position from the whole segment, but not from the padded part - if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps: - pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item() + + if self._cfg.action_type == 'varied_action_space': + # For some environments (e.g., Jericho), the action space size may be different. + # To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length), + # we avoid sampling from the last `num_unroll_steps` steps of the game segment. + if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps: + pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps, 1).item() + + segment_len = len(game_segment.action_segment) + if pos_in_game_segment >= segment_len - 1: + # If the segment is very short (length 0 or 1), we can't randomly sample a position + # before the last one. The only safe position is 0. + if segment_len > 1: + # If the segment has at least 2 actions, we can safely sample from [0, len-2]. + # The upper bound for np.random.choice is exclusive, so (segment_len - 1) is correct. + pos_in_game_segment = np.random.choice(segment_len - 1, 1).item() + else: + # If segment length is 0 or 1, the only valid/safe position is 0. 
+ pos_in_game_segment = 0 + + else: + # For environments with a fixed action space (e.g., Atari), + # we can safely sample from the entire game segment range. + if pos_in_game_segment >= self._cfg.game_segment_length: + pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item() + + segment_len = len(game_segment.action_segment) + if pos_in_game_segment >= segment_len - 1: + # If the segment is very short (length 0 or 1), we can't randomly sample a position + # before the last one. The only safe position is 0. + if segment_len > 1: + # If the segment has at least 2 actions, we can safely sample from [0, len-2]. + # The upper bound for np.random.choice is exclusive, so (segment_len - 1) is correct. + pos_in_game_segment = np.random.choice(segment_len - 1, 1).item() + else: + # If segment length is 0 or 1, the only valid/safe position is 0. + pos_in_game_segment = 0 pos_in_game_segment_list.append(pos_in_game_segment) - make_time = [time.time() for _ in range(len(batch_index_list))] + # make_time = [time.time() for _ in range(len(batch_index_list))] + + # Set the make_time for each sample (set to 0 for now, but can be the actual time if needed). + make_time = [0. for _ in range(len(batch_index_list))] orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) @@ -173,109 +209,136 @@ def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) return orig_data def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple: - """ - Overview: - This function samples a batch of game segments for reanalysis from the replay buffer. - It uses priority sampling based on the `reanalyze_time` of each game segment, with segments - that have been reanalyzed more frequently receiving lower priority. - - The function returns a tuple containing information about the sampled game segments, - including their positions within each segment and the time the batch was created. - Arguments: - - batch_size (:obj:`int`): - The number of samples to draw in this batch. - - Returns: - - Tuple: - A tuple containing the following elements: - - game_segment_list: A list of the sampled game segments. - - pos_in_game_segment_list: A list of indices representing the position of each transition - within its corresponding game segment. - - batch_index_list: The indices of the sampled game segments in the replay buffer. - - make_time: A list of timestamps (set to `0` in this implementation) indicating when - the batch was created. - - Key Details: - 1. **Priority Sampling**: - Game segments are sampled based on a probability distribution calculated using - the `reanalyze_time` of each segment. Segments that have been reanalyzed more frequently - are less likely to be selected. - 2. **Segment Slicing**: - Each selected game segment is sampled at regular intervals determined by the - `num_unroll_steps` parameter. Up to `samples_per_segment` transitions are sampled - from each selected segment. - 3. **Handling Extra Samples**: - If the `batch_size` is not perfectly divisible by the number of samples per segment, - additional segments are sampled to make up the difference. - 4. **Reanalyze Time Update**: - The `reanalyze_time` attribute of each sampled game segment is incremented to reflect - that it has been selected for reanalysis again. - Raises: - - ValueError: - If the `game_segment_length` is too small to accommodate the `num_unroll_steps`. 
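To make the new sampling-position guards easier to follow, here is a standalone sketch of the same clamping rules with made-up numbers. It only mirrors the logic added above and is not the buffer method itself:

```Python
import numpy as np

def clamp_position(pos, segment_len, game_segment_length, num_unroll_steps, td_steps, varied_action_space):
    # Illustrative re-statement of the position guards above.
    if varied_action_space:
        # Leave room for both the unroll steps and the TD bootstrap.
        limit = game_segment_length - num_unroll_steps - td_steps
    else:
        limit = game_segment_length
    if pos >= limit:
        pos = np.random.choice(limit, 1).item()
    # Short-segment guard: never start at (or past) the last recorded transition.
    if pos >= segment_len - 1:
        pos = np.random.choice(segment_len - 1, 1).item() if segment_len > 1 else 0
    return pos

# Example: a short final segment (37 transitions) with game_segment_length=400, unroll=5, td=5.
print(clamp_position(pos=350, segment_len=37, game_segment_length=400,
                     num_unroll_steps=5, td_steps=5, varied_action_space=True))
```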
- """ - train_sample_num = len(self.game_segment_buffer) - assert self._cfg.reanalyze_partition <= 0.75, "The reanalyze partition should be less than 0.75." - valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition) - - # Calculate the number of samples per segment - samples_per_segment = self._cfg.game_segment_length // self._cfg.num_unroll_steps - - # Make sure that the batch size can be divided by the number of samples per segment - if samples_per_segment == 0: - raise ValueError("The game segment length is too small for num_unroll_steps.") - - # Calculate the number of samples per segment - batch_size_per_segment = batch_size // samples_per_segment - - # If the batch size cannot be divided, process the remainder part - extra_samples = batch_size % samples_per_segment - - # We use the reanalyze_time in the game_segment_buffer to generate weights - reanalyze_times = np.array([segment.reanalyze_time for segment in self.game_segment_buffer[:valid_sample_num]]) - - # Calculate weights: the larger the reanalyze_time, the smaller the weight (use exp(-reanalyze_time)) - base_decay_rate = 100 - decay_rate = base_decay_rate / valid_sample_num - weights = np.exp(-decay_rate * reanalyze_times) - - # Normalize the weights to a probability distribution - probabilities = weights / np.sum(weights) - - # Sample game segments according to the probabilities - selected_game_segments = np.random.choice(valid_sample_num, batch_size_per_segment, replace=False, - p=probabilities) - - # If there are extra samples to be allocated, randomly select some game segments and sample again - if extra_samples > 0: - extra_game_segments = np.random.choice(valid_sample_num, extra_samples, replace=False, p=probabilities) - selected_game_segments = np.concatenate((selected_game_segments, extra_game_segments)) - - game_segment_list = [] - pos_in_game_segment_list = [] - batch_index_list = [] - - for game_segment_idx in selected_game_segments: - game_segment_idx -= self.base_idx - game_segment = self.game_segment_buffer[game_segment_idx] - - # Update reanalyze_time only once - game_segment.reanalyze_time += 1 - - # The sampling position should be 0, 0 + num_unroll_steps, ... (integer multiples of num_unroll_steps) - for i in range(samples_per_segment): - game_segment_list.append(game_segment) - pos_in_game_segment = i * self._cfg.num_unroll_steps - if pos_in_game_segment >= len(game_segment): - pos_in_game_segment = np.random.choice(len(game_segment), 1).item() - pos_in_game_segment_list.append(pos_in_game_segment) - batch_index_list.append(game_segment_idx) - - # Set the make_time for each sample (set to 0 for now, but can be the actual time if needed). - make_time = [0. for _ in range(len(batch_index_list))] - - orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time) - return orig_data + """ + Overview: + This function samples a batch of game segments for reanalysis from the replay buffer. + It uses priority sampling based on the `reanalyze_time` of each game segment, with segments + that have been reanalyzed more frequently receiving lower priority. + + The function returns a tuple containing information about the sampled game segments, + including their positions within each segment and the time the batch was created. + Arguments: + - batch_size (:obj:`int`): + The number of samples to draw in this batch. + + Returns: + - Tuple: + A tuple containing the following elements: + - game_segment_list: A list of the sampled game segments. 
+ - pos_in_game_segment_list: A list of indices representing the position of each transition + within its corresponding game segment. + - batch_index_list: The indices of the sampled game segments in the replay buffer. + - make_time: A list of timestamps (set to `0` in this implementation) indicating when + the batch was created. + + Key Details: + 1. **Priority Sampling**: + Game segments are sampled based on a probability distribution calculated using + the `reanalyze_time` of each segment. Segments that have been reanalyzed more frequently + are less likely to be selected. + 2. **Segment Slicing**: + Each selected game segment is sampled at regular intervals determined by the + `num_unroll_steps` parameter. Up to `samples_per_segment` transitions are sampled + from each selected segment. + 3. **Handling Extra Samples**: + If the `batch_size` is not perfectly divisible by the number of samples per segment, + additional segments are sampled to make up the difference. + 4. **Reanalyze Time Update**: + The `reanalyze_time` attribute of each sampled game segment is incremented to reflect + that it has been selected for reanalysis again. + Raises: + - ValueError: + If the `game_segment_length` is too small to accommodate the `num_unroll_steps`. + """ + train_sample_num = len(self.game_segment_buffer) + assert self._cfg.reanalyze_partition <= 0.75, "The reanalyze partition should be less than 0.75." + valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition) + + # Calculate the number of samples per segment + samples_per_segment = self._cfg.game_segment_length // self._cfg.num_unroll_steps + + # Make sure that the batch size can be divided by the number of samples per segment + if samples_per_segment == 0: + raise ValueError("The game segment length is too small for num_unroll_steps.") + + # Calculate the number of samples per segment + batch_size_per_segment = batch_size // samples_per_segment + + # If the batch size cannot be divided, process the remainder part + extra_samples = batch_size % samples_per_segment + + # We use the reanalyze_time in the game_segment_buffer to generate weights + reanalyze_times = np.array([segment.reanalyze_time for segment in self.game_segment_buffer[:valid_sample_num]]) + + # Calculate weights: the larger the reanalyze_time, the smaller the weight (use exp(-reanalyze_time)) + base_decay_rate = 100 + # Add a small epsilon to avoid division by zero if valid_sample_num is 0 + decay_rate = base_decay_rate / (valid_sample_num + 1e-6) + weights = np.exp(-decay_rate * reanalyze_times) + + # Normalize the weights to a probability distribution, handle case where sum is zero + sum_weights = np.sum(weights) + if sum_weights > 0: + probabilities = weights / sum_weights + else: + # If all weights are zero, use a uniform distribution + probabilities = np.ones(valid_sample_num) / valid_sample_num + + # Sample game segments according to the probabilities + # Ensure valid_sample_num is not zero before sampling + if valid_sample_num == 0: + return ([], [], [], [], []) + + selected_game_segments = np.random.choice(valid_sample_num, batch_size_per_segment, replace=False, + p=probabilities) + + # If there are extra samples to be allocated, randomly select some game segments and sample again + if extra_samples > 0: + # We need to handle the case where we might sample the same segment again. + # A simple way is to allow replacement for extra samples or sample from remaining ones. + # For simplicity, let's stick to the original logic but ensure it's safe. 
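To make the reanalyze-priority weighting above concrete, here is a tiny numeric sketch with made-up reanalyze counts; the decay constant of 100 matches the code above:

```Python
import numpy as np

reanalyze_times = np.array([0., 0., 1., 3., 5.])  # hypothetical per-segment reanalyze counts
valid_sample_num = len(reanalyze_times)

base_decay_rate = 100
decay_rate = base_decay_rate / (valid_sample_num + 1e-6)   # about 20 for 5 segments
weights = np.exp(-decay_rate * reanalyze_times)            # more reanalysis -> smaller weight
sum_weights = np.sum(weights)
probabilities = weights / sum_weights if sum_weights > 0 else np.ones(valid_sample_num) / valid_sample_num

# Segments that were reanalyzed less often dominate the draw.
picked = np.random.choice(valid_sample_num, size=2, replace=False, p=probabilities)
print(probabilities.round(6), picked)
```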
+ remaining_segments = np.setdiff1d(np.arange(valid_sample_num), selected_game_segments) + if len(remaining_segments) < extra_samples: + # If not enough unique segments left, sample with replacement from all valid segments + extra_game_segments = np.random.choice(valid_sample_num, extra_samples, replace=True, p=probabilities) + else: + # Sample from the remaining unique segments + remaining_probs = probabilities[remaining_segments] + remaining_probs /= np.sum(remaining_probs) + extra_game_segments = np.random.choice(remaining_segments, extra_samples, replace=False, p=remaining_probs) + + selected_game_segments = np.concatenate((selected_game_segments, extra_game_segments)) + + game_segment_list = [] + pos_in_game_segment_list = [] + batch_index_list = [] + # print(f"selected_game_segments:{selected_game_segments}") + for game_segment_idx in selected_game_segments: + # ========================================================================= + # FIX: The line below is the source of the error and has been removed. + # `game_segment_idx` is already a valid physical index for `game_segment_buffer`. + # game_segment_idx -= self.base_idx + # ========================================================================= + game_segment = self.game_segment_buffer[game_segment_idx] + + # Update reanalyze_time only once + game_segment.reanalyze_time += 1 + + # The sampling position should be 0, 0 + num_unroll_steps, ... (integer multiples of num_unroll_steps) + for i in range(samples_per_segment): + game_segment_list.append(game_segment) + pos_in_game_segment = i * self._cfg.num_unroll_steps + if pos_in_game_segment >= len(game_segment): + pos_in_game_segment = np.random.choice(len(game_segment), 1).item() + pos_in_game_segment_list.append(pos_in_game_segment) + # NOTE: We should append the physical index here, as it corresponds to the sampled segment. + batch_index_list.append(game_segment_idx) + + # Set the make_time for each sample (set to 0 for now, but can be the actual time if needed). + make_time = [0.
for _ in range(len(batch_index_list))] + + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time) + return orig_data def _sample_orig_reanalyze_data(self, batch_size: int) -> Tuple: """ diff --git a/lzero/mcts/buffer/game_buffer_efficientzero.py b/lzero/mcts/buffer/game_buffer_efficientzero.py index a909e6a3a..8941b1fc5 100644 --- a/lzero/mcts/buffer/game_buffer_efficientzero.py +++ b/lzero/mcts/buffer/game_buffer_efficientzero.py @@ -7,7 +7,7 @@ from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree from lzero.mcts.tree_search.mcts_ptree import EfficientZeroMCTSPtree as MCTSPtree from lzero.mcts.utils import prepare_observation -from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform +from lzero.policy import DiscreteSupport, to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer_muzero import MuZeroGameBuffer @@ -45,6 +45,9 @@ def __init__(self, cfg: dict): self.base_idx = 0 self.clear_time = 0 + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range) + def sample(self, batch_size: int, policy: Any) -> List[Any]: """ Overview: @@ -209,7 +212,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( [ m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + inverse_scalar_transform(m_output.value, self.value_support), m_output.policy_logits ] ) @@ -359,7 +362,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( [ m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + inverse_scalar_transform(m_output.value, self.value_support), m_output.policy_logits ] ) diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 664b2042f..972a95498 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -8,7 +8,7 @@ from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree from lzero.mcts.utils import prepare_observation -from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform +from lzero.policy import DiscreteSupport, to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer import GameBuffer if TYPE_CHECKING: @@ -73,6 +73,8 @@ def __init__(self, cfg: dict): self.task_id = None print("No task_id found in configuration. Task ID is set to None.") self.action_space_size = self._cfg.model.action_space_size + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range) def reset_runtime_metrics(self): """ @@ -433,6 +435,24 @@ def _prepare_policy_reanalyzed_context( ] return policy_re_context + def _scalar_reward(self, r: Any) -> float: + """ + Overview: + Convert a reward input of various types into a scalar float value. + Arguments: + - r (Any): The reward input, which can be a numpy array, list, tuple, or a scalar. 
+ If it is a numpy array, list, or tuple, the function uses the first element. + Returns: + - float: The scalar representation of the input reward. + """ + # If the reward is in the form of a list, tuple, or numpy array, + # convert it to a numpy array, reshape it into a flat array, and take the first element. + if isinstance(r, (list, tuple, np.ndarray)): + r = np.asarray(r).reshape(-1)[0] + + # Return the float value of the reward. + return float(r) + def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> Tuple[Any, Any]: """ Overview: @@ -528,11 +548,11 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A value_list[value_index] += reward * self._cfg.discount_factor ** i # TODO: check the boundary condition - target_values.append(value_list[value_index]) + target_values.append(self._scalar_reward(value_list[value_index])) if current_index < len(reward_list): - target_rewards.append(reward_list[current_index]) + target_rewards.append(self._scalar_reward(reward_list[current_index])) else: - target_rewards.append(np.array(0.)) + target_rewards.append(0.) value_index += 1 diff --git a/lzero/mcts/buffer/game_buffer_rezero_ez.py b/lzero/mcts/buffer/game_buffer_rezero_ez.py index fdfae46df..c78381d02 100644 --- a/lzero/mcts/buffer/game_buffer_rezero_ez.py +++ b/lzero/mcts/buffer/game_buffer_rezero_ez.py @@ -6,7 +6,7 @@ from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree from lzero.mcts.utils import prepare_observation -from lzero.policy import to_detach_cpu_numpy, concat_output, inverse_scalar_transform +from lzero.policy import DiscreteSupport, to_detach_cpu_numpy, concat_output, inverse_scalar_transform from .game_buffer_efficientzero import EfficientZeroGameBuffer from .game_buffer_rezero_mz import ReZeroMZGameBuffer, compute_all_filters @@ -71,6 +71,9 @@ def __init__(self, cfg: dict): self.active_root_num = 0 self.average_infer = 0 + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range) + def sample( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -172,7 +175,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( [ m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + inverse_scalar_transform(m_output.value, self.value_support), m_output.policy_logits ] ) diff --git a/lzero/mcts/buffer/game_buffer_rezero_mz.py b/lzero/mcts/buffer/game_buffer_rezero_mz.py index 9e864ac5e..4ffffd315 100644 --- a/lzero/mcts/buffer/game_buffer_rezero_mz.py +++ b/lzero/mcts/buffer/game_buffer_rezero_mz.py @@ -8,7 +8,7 @@ from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree from lzero.mcts.utils import prepare_observation -from lzero.policy import to_detach_cpu_numpy, concat_output, inverse_scalar_transform +from lzero.policy import DiscreteSupport, to_detach_cpu_numpy, concat_output, inverse_scalar_transform from .game_buffer_muzero import MuZeroGameBuffer # from line_profiler import line_profiler @@ -76,6 +76,9 @@ def __init__(self, cfg: dict): self.active_root_num = 0 self.average_infer = 0 + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) + self.reward_support 
= DiscreteSupport(*self._cfg.model.reward_support_range) + def reanalyze_buffer( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -244,7 +247,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: m_output.latent_state, m_output.value, m_output.policy_logits = to_detach_cpu_numpy( [ m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + inverse_scalar_transform(m_output.value, self.value_support), m_output.policy_logits ] ) diff --git a/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py b/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py index 1821f7a2e..6f715b285 100644 --- a/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py +++ b/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py @@ -7,7 +7,7 @@ from lzero.mcts.tree_search.mcts_ctree_sampled import SampledEfficientZeroMCTSCtree as MCTSCtree from lzero.mcts.tree_search.mcts_ptree_sampled import SampledEfficientZeroMCTSPtree as MCTSPtree from lzero.mcts.utils import prepare_observation, generate_random_actions_discrete -from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform +from lzero.policy import DiscreteSupport, to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer_efficientzero import EfficientZeroGameBuffer @@ -45,6 +45,9 @@ def __init__(self, cfg: dict): self.base_idx = 0 self.clear_time = 0 + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range) + def sample(self, batch_size: int, policy: Any) -> List[Any]: """ Overview: @@ -291,7 +294,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( [ m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + inverse_scalar_transform(m_output.value, self.value_support), m_output.policy_logits ] ) @@ -398,13 +401,13 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A horizon_id += 1 if current_index < game_segment_len_non_re: - target_values.append(value_list[value_index]) + target_values.append(value_list[value_index].item()) # Since the horizon is small and the discount_factor is close to 1. # Compute the reward sum to approximate the value prefix for simplification value_prefix += reward_list[current_index].item() # * config.discount_factor ** (current_index - base_index) target_value_prefixs.append(value_prefix.item()) else: - target_values.append(np.array(0.)) + target_values.append(0.) 
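The buffers above now build their supports from `(start, stop, step)` triples such as `value_support_range=(-300., 301., 1.)`. The implementation of `DiscreteSupport` is not shown in this excerpt; the sketch below only illustrates, under the assumption that the triple enumerates atoms like `np.arange`, how a categorical head over those atoms reduces to a scalar by expectation. The library's `inverse_scalar_transform` additionally inverts the value scaling transform, which is omitted here:

```Python
import numpy as np

# Assumption: a (start, stop, step) triple enumerates support atoms like np.arange.
value_support_range = (-300., 301., 1.)
atoms = np.arange(*value_support_range)        # [-300., -299., ..., 300.] -> 601 atoms
assert atoms.size == 601

# A categorical value head emits one logit per atom; the scalar estimate is the
# expectation of the atoms under the softmax distribution (illustrative random logits).
logits = np.random.randn(atoms.size)
probs = np.exp(logits - logits.max())
probs /= probs.sum()
scalar_value = float((probs * atoms).sum())
print(scalar_value)
```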
target_value_prefixs.append(value_prefix.item()) value_index += 1 @@ -469,7 +472,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( [ m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + inverse_scalar_transform(m_output.value, self.value_support), m_output.policy_logits ] ) diff --git a/lzero/mcts/buffer/game_buffer_sampled_muzero.py b/lzero/mcts/buffer/game_buffer_sampled_muzero.py index ddbdd5a05..8e04d77b5 100644 --- a/lzero/mcts/buffer/game_buffer_sampled_muzero.py +++ b/lzero/mcts/buffer/game_buffer_sampled_muzero.py @@ -7,7 +7,7 @@ from lzero.mcts.tree_search.mcts_ctree_sampled import SampledMuZeroMCTSCtree as MCTSCtree # from lzero.mcts.tree_search.mcts_ptree_sampled import SampledMuZeroMCTSPtree as MCTSPtree from lzero.mcts.utils import prepare_observation, generate_random_actions_discrete -from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform +from lzero.policy import DiscreteSupport, to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer_muzero import MuZeroGameBuffer @@ -45,6 +45,9 @@ def __init__(self, cfg: dict): self.base_idx = 0 self.clear_time = 0 + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range) + def sample(self, batch_size: int, policy: Any) -> List[Any]: """ Overview: @@ -291,7 +294,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( [ m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + inverse_scalar_transform(m_output.value, self.value_support), m_output.policy_logits ] ) @@ -454,7 +457,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( [ m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + inverse_scalar_transform(m_output.value, self.value_support), m_output.policy_logits ] ) diff --git a/lzero/mcts/buffer/game_buffer_sampled_unizero.py b/lzero/mcts/buffer/game_buffer_sampled_unizero.py index 5af5228a2..da09fc311 100644 --- a/lzero/mcts/buffer/game_buffer_sampled_unizero.py +++ b/lzero/mcts/buffer/game_buffer_sampled_unizero.py @@ -7,7 +7,7 @@ from lzero.mcts.tree_search.mcts_ctree_sampled import SampledUniZeroMCTSCtree as MCTSCtree # from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree from lzero.mcts.utils import prepare_observation, generate_random_actions_discrete -from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform +from lzero.policy import DiscreteSupport, to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer_unizero import UniZeroGameBuffer if TYPE_CHECKING: @@ -60,6 +60,8 @@ def __init__(self, cfg: dict): print("No task_id found in configuration. 
Task ID is set to None.") self.action_space_size = self._cfg.model.action_space_size + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range) def reanalyze_buffer( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] @@ -510,7 +512,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( [ m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + inverse_scalar_transform(m_output.value, self.value_support), m_output.policy_logits ] ) @@ -668,19 +670,18 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # calculate the target value # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352 if self.task_id is not None: - m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep, task_id=self.task_id) + # m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep, task_id=self.task_id) + + m_output = model.initial_inference(batch_obs, batch_action, task_id=self.task_id) else: m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) # ====================================================================== - # print(f'model.training:{model.training}') - # model.training = False - # if not model.training: # if not in training, obtain the scalars of the value/reward [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( [ m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + inverse_scalar_transform(m_output.value, self.value_support), m_output.policy_logits ] ) diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index 38c1935ea..b4de66031 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -6,7 +6,7 @@ from lzero.mcts.tree_search.mcts_ctree import UniZeroMCTSCtree as MCTSCtree from lzero.mcts.utils import prepare_observation -from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform +from lzero.policy import DiscreteSupport, to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer_muzero import MuZeroGameBuffer if TYPE_CHECKING: @@ -61,6 +61,9 @@ def __init__(self, cfg: dict): print("No task_id found in configuration. Task ID is set to None.") self.action_space_size = self._cfg.model.action_space_size + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range) + #@profile def sample( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] @@ -145,19 +148,12 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: self._cfg.num_unroll_steps].tolist() timestep_tmp = game.timestep_segment[pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps].tolist() - # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid - # mask_tmp = [1. for i in range(len(actions_tmp))] - # mask_tmp += [0. 
for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] # TODO: the child_visits after position in the segment (with padded part) may not be updated # So the corresponding position should not be used in the training mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))] mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] - # TODO: original buffer mask - # mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))] - # mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] - # pad random action actions_tmp += [ np.random.randint(0, game.action_space_size) @@ -294,9 +290,6 @@ def _make_batch_for_reanalyze(self, batch_size: int) -> Tuple[Any]: mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] timestep_tmp = game.timestep_segment[pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps].tolist() - # TODO: original buffer mask - # mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))] - # mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] # pad random action actions_tmp += [ @@ -461,7 +454,6 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # ======================================================================= - # if not model.training: # if not in training, obtain the scalars of the value/reward [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( [ @@ -487,6 +479,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # do MCTS for a new policy with the recent target model if self.task_id is not None: MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + # TODO: adapt unizero multitask to timestep in rope # MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num], task_id=self.task_id) else: MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) @@ -582,12 +575,11 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # ====================================================================== - # if not model.training: # if not in training, obtain the scalars of the value/reward [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( [ m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + inverse_scalar_transform(m_output.value, self.value_support), m_output.policy_logits ] ) @@ -667,3 +659,34 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A batch_target_values = np.asarray(batch_target_values) return batch_rewards, batch_target_values + + def update_priority(self, train_data: List[np.ndarray], batch_priorities: np.ndarray) -> None: + """ + Overview: + Update the priority of training data. + Arguments: + - train_data (:obj:`List[np.ndarray]`): training data to be updated priority. + - batch_priorities (:obj:`np.ndarray`): priorities to update to. 
+ NOTE: + train_data = [current_batch, target_batch] + current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list] + """ + # TODO: NOTE: -4 is batch_index_list + indices = train_data[0][-4] + metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities} + # only update the priorities for data still in replay buffer + for i in range(len(indices)): + # ==================== START OF FINAL FIX ==================== + + # FIX 1: Handle ValueError by using the first timestamp of the segment for comparison. + first_transition_time = metas['make_time'][i][0] + + if first_transition_time > self.clear_time: + # FIX 2: Handle IndexError by converting the float index to an integer before use. + idx = int(indices[i]) + prio = metas['batch_priorities'][i] + + # Now, idx is a valid integer index. + self.game_pos_priorities[idx] = prio + + # ===================== END OF FINAL FIX ===================== diff --git a/lzero/mcts/tests/config/atari_efficientzero_config_for_test.py b/lzero/mcts/tests/config/atari_efficientzero_config_for_test.py index a376d7b16..2dce111bb 100644 --- a/lzero/mcts/tests/config/atari_efficientzero_config_for_test.py +++ b/lzero/mcts/tests/config/atari_efficientzero_config_for_test.py @@ -58,7 +58,8 @@ self_supervised_learning_loss=True, categorical_distribution=True, image_channel=1, - support_scale=300, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), lstm_hidden_size=512, ), cuda=True, diff --git a/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py b/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py index 18442f461..27433c608 100644 --- a/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py +++ b/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py @@ -53,9 +53,8 @@ reward_head_hidden_channels=[8], value_head_hidden_channels=[8], policy_head_hidden_channels=[8], - support_scale=10, - reward_support_size=21, - value_support_size=21, + reward_support_range=(-10., 11., 1.), + value_support_range=(-10., 11., 1.), categorical_distribution=True, ), cuda=True, diff --git a/lzero/mcts/tests/cprofile_mcts_ptree.py b/lzero/mcts/tests/cprofile_mcts_ptree.py index 956ec39fa..9e79aeb6d 100644 --- a/lzero/mcts/tests/cprofile_mcts_ptree.py +++ b/lzero/mcts/tests/cprofile_mcts_ptree.py @@ -1,7 +1,7 @@ import torch from easydict import EasyDict -from lzero.policy.scaling_transform import inverse_scalar_transform +from lzero.policy.scaling_transform import DiscreteSupport, inverse_scalar_transform class MuZeroModelFake(torch.nn.Module): @@ -76,7 +76,8 @@ def check_mcts(): model=dict( action_space_size=9, categorical_distribution=True, - support_scale=300, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), ), ) ) @@ -100,8 +101,9 @@ def check_mcts(): policy_logits_pool = network_output['policy_logits'] # network output process + discrete_support = DiscreteSupport(*policy_config.model.value_support_range) pred_values_pool = inverse_scalar_transform(pred_values_pool, - policy_config.model.support_scale).detach().cpu().numpy() + discrete_support).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() reward_hidden_state_state = ( reward_hidden_state_state[0].detach().cpu().numpy(), reward_hidden_state_state[1].detach().cpu().numpy() diff --git a/lzero/mcts/tests/eval_tree_speed.py b/lzero/mcts/tests/eval_tree_speed.py index c7134f3b3..df5aaf325 
100644 --- a/lzero/mcts/tests/eval_tree_speed.py +++ b/lzero/mcts/tests/eval_tree_speed.py @@ -1,6 +1,6 @@ import torch from easydict import EasyDict -from lzero.policy import inverse_scalar_transform, select_action +from lzero.policy import DiscreteSupport, inverse_scalar_transform, select_action import numpy as np import random @@ -81,6 +81,8 @@ def ptree_func(policy_config, num_simulations): search_time = [] total_time = [] + discrete_support = DiscreteSupport(*policy_config.model.value_support_range) + for n_s in num_simulations: t0 = time.time() model = MuZeroModelFake(action_num=action_space_size) @@ -102,7 +104,7 @@ def ptree_func(policy_config, num_simulations): # network output process pred_values_pool = inverse_scalar_transform(pred_values_pool, - policy_config.model.support_scale).detach().cpu().numpy() + discrete_support).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() reward_hidden_state_state = ( reward_hidden_state_state[0].detach().cpu().numpy(), reward_hidden_state_state[1].detach().cpu().numpy() @@ -175,6 +177,8 @@ def ctree_func(policy_config, num_simulations): search_time = [] total_time = [] + discrete_support = DiscreteSupport(*policy_config.model.value_support_range) + for n_s in num_simulations: t0 = time.time() model = MuZeroModelFake(action_num=action_space_size) @@ -196,7 +200,7 @@ def ctree_func(policy_config, num_simulations): # network output process pred_values_pool = inverse_scalar_transform(pred_values_pool, - policy_config.model.support_scale).detach().cpu().numpy() + discrete_support).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() reward_hidden_state_state = ( reward_hidden_state_state[0].detach().cpu().numpy(), reward_hidden_state_state[1].detach().cpu().numpy() @@ -297,7 +301,8 @@ def plot(ctree_time, ptree_time, iters, label): dict( lstm_horizon_len=5, model=dict( - support_scale=300, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), categorical_distribution=True, ), action_space_size=100, diff --git a/lzero/mcts/tests/test_mcts_ctree.py b/lzero/mcts/tests/test_mcts_ctree.py index 21c2b3315..702cae1e6 100644 --- a/lzero/mcts/tests/test_mcts_ctree.py +++ b/lzero/mcts/tests/test_mcts_ctree.py @@ -3,7 +3,7 @@ import torch from easydict import EasyDict -from lzero.policy import inverse_scalar_transform, select_action +from lzero.policy import DiscreteSupport, inverse_scalar_transform, select_action policy = 'GumbelMuZero' @@ -89,7 +89,8 @@ def recurrent_inference(self, latent_states, reward_hidden_states, actions=None) value_delta_max=0.01, model=dict( action_space_size=9, - support_scale=300, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), categorical_distribution=True, ), env_type='not_board_games', @@ -110,7 +111,8 @@ def recurrent_inference(self, latent_states, reward_hidden_states, actions=None) policy_logits_pool = network_output['policy_logits'] # network output process -pred_values_pool = inverse_scalar_transform(pred_values_pool, policy_config.model.support_scale).detach().cpu().numpy() +discrete_support = DiscreteSupport(*policy_config.model.value_support_range) +pred_values_pool = inverse_scalar_transform(pred_values_pool, discrete_support).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() reward_hidden_state_roots = ( reward_hidden_state_roots[0].detach().cpu().numpy(), reward_hidden_state_roots[1].detach().cpu().numpy() @@ -201,8 +203,9 @@ def 
test_mcts_vs_bot_to_play_large(): policy_logits_pool = network_output['policy_logits'] # network output process + discrete_support = DiscreteSupport(*policy_config.model.value_support_range) pred_values_pool = inverse_scalar_transform(pred_values_pool, - policy_config.model.support_scale).detach().cpu().numpy() + discrete_support).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() reward_hidden_state_roots = ( reward_hidden_state_roots[0].detach().cpu().numpy(), reward_hidden_state_roots[1].detach().cpu().numpy() diff --git a/lzero/mcts/tests/test_mcts_ptree.py b/lzero/mcts/tests/test_mcts_ptree.py index e27f31a53..43c79246b 100644 --- a/lzero/mcts/tests/test_mcts_ptree.py +++ b/lzero/mcts/tests/test_mcts_ptree.py @@ -1,7 +1,7 @@ import pytest import torch from easydict import EasyDict -from lzero.policy import inverse_scalar_transform, select_action +from lzero.policy import DiscreteSupport, inverse_scalar_transform, select_action import numpy as np from lzero.mcts.tree_search.mcts_ptree import EfficientZeroMCTSPtree as MCTSPtree @@ -74,7 +74,8 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions): model=dict( action_space_size=9, categorical_distribution=True, - support_scale=300, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), ), env_type='not_board_games', ) @@ -100,7 +101,8 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions): policy_logits_pool = network_output['policy_logits'] # network output process -pred_values_pool = inverse_scalar_transform(pred_values_pool, policy_config.model.support_scale).detach().cpu().numpy() +discrete_support = DiscreteSupport(*policy_config.model.value_support_range) +pred_values_pool = inverse_scalar_transform(pred_values_pool, discrete_support).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() reward_hidden_state_state = ( reward_hidden_state_state[0].detach().cpu().numpy(), reward_hidden_state_state[1].detach().cpu().numpy() diff --git a/lzero/mcts/tests/test_mcts_sampled_ctree.py b/lzero/mcts/tests/test_mcts_sampled_ctree.py index 72a06bd05..fcd8192ae 100644 --- a/lzero/mcts/tests/test_mcts_sampled_ctree.py +++ b/lzero/mcts/tests/test_mcts_sampled_ctree.py @@ -1,7 +1,7 @@ import pytest import torch from easydict import EasyDict -from lzero.policy import inverse_scalar_transform +from lzero.policy import DiscreteSupport, inverse_scalar_transform class MuZeroModelFake(torch.nn.Module): @@ -80,7 +80,8 @@ def test_mcts(): value_delta_max=0, model=dict( continuous_action_space=True, - support_scale=300, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), action_space_size=2, categorical_distribution=True, ), @@ -106,8 +107,9 @@ def test_mcts(): policy_logits_pool = network_output['policy_logits'] # network output process + discrete_support = DiscreteSupport(*policy_config.model.value_support_range) pred_values_pool = inverse_scalar_transform(pred_values_pool, - policy_config.model.support_scale).detach().cpu().numpy() + discrete_support).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() reward_hidden_state_state = ( reward_hidden_state_state[0].detach().cpu().numpy(), reward_hidden_state_state[1].detach().cpu().numpy() diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index b9041c639..4efcf1688 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -8,7 
+8,7 @@ from lzero.mcts.ctree.ctree_efficientzero import ez_tree as tree_efficientzero from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as tree_gumbel_muzero from lzero.mcts.ctree.ctree_muzero import mz_tree as tree_muzero -from lzero.policy import InverseScalarTransform, to_detach_cpu_numpy +from lzero.policy import DiscreteSupport, InverseScalarTransform, to_detach_cpu_numpy if TYPE_CHECKING: from lzero.mcts.ctree.ctree_efficientzero import ez_tree as ez_ctree @@ -56,9 +56,10 @@ def __init__(self, cfg: EasyDict = None) -> None: default_config = self.default_config() default_config.update(cfg) self._cfg = default_config - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @classmethod def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "mz_ctree": @@ -94,6 +95,10 @@ def search( # preparation some constant batch_size = roots.num + + # Store the latent state of each possible action at the MCTS root for each environment. + first_action_latent_map = {env_id: {} for env_id in range(batch_size)} # {env_id: {action: latent_state}} + pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor # the data storage of latent states: storing the latent state of all the nodes in the search. 
latent_state_batch_in_search_path = [latent_state_roots] @@ -133,22 +138,13 @@ def search( for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): latent_states.append(latent_state_batch_in_search_path[ix][iy]) - # latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) try: - # print ("latent_state_roots.shape:", latent_state_roots.shape) - # print ("latent_states[0].shape:", latent_states[0].shape) - # print ("latent_states[1].shape:", latent_states[1].shape) - # import ipdb; ipdb.set_trace() latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) except Exception as e: print("="*20) print(e) - # print("latent_states raw:", latent_states) print("roots:", roots, "latent_state_roots:", latent_state_roots) print ("latent_state_roots.shape:", latent_state_roots.shape) - # if not all(isinstance(x, np.ndarray) and x.shape == latent_states[0].shape for x in latent_states): - # raise ValueError(f"Inconsistent latent_states shapes: {[x.shape if isinstance(x, np.ndarray) else type(x) for x in latent_states]}") - import ipdb; ipdb.set_trace() # TODO: .long() is only for discrete action @@ -179,18 +175,26 @@ def search( # for UniZero if task_id is not None: # multi task setting - network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep, task_id=task_id) + # network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep, task_id=task_id) + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, task_id=task_id) else: # single task setting network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep) network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) - network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value)) - network_output.reward = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.reward)) - + network_output.value = to_detach_cpu_numpy(self.value_inverse_scalar_transform_handle(network_output.value)) + network_output.reward = to_detach_cpu_numpy(self.reward_inverse_scalar_transform_handle(network_output.reward)) + + for env_id in range(batch_size): + depth = search_depth[env_id] + action = last_actions[env_id].item() + if depth == 1 and action not in first_action_latent_map[env_id]: + first_action_latent_map[env_id][action] = network_output.latent_state[env_id] + else: + continue + latent_state_batch_in_search_path.append(network_output.latent_state) - # tolist() is to be compatible with cpp datatype. 
reward_batch = network_output.reward.reshape(-1).tolist() value_batch = network_output.value.reshape(-1).tolist() @@ -206,6 +210,8 @@ def search( current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch, min_max_stats_lst, results, virtual_to_play_batch ) + + return first_action_latent_map class MuZeroMCTSCtree(object): @@ -247,9 +253,10 @@ def __init__(self, cfg: EasyDict = None) -> None: default_config = self.default_config() default_config.update(cfg) self._cfg = default_config - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @classmethod def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "mz_ctree": @@ -345,8 +352,8 @@ def search( network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) - network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value)) - network_output.reward = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.reward)) + network_output.value = to_detach_cpu_numpy(self.value_inverse_scalar_transform_handle(network_output.value)) + network_output.reward = to_detach_cpu_numpy(self.reward_inverse_scalar_transform_handle(network_output.reward)) latent_state_batch_in_search_path.append(network_output.latent_state) @@ -438,9 +445,9 @@ def search_with_reuse( network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) network_output.value = to_detach_cpu_numpy( - self.inverse_scalar_transform_handle(network_output.value)) + self.value_inverse_scalar_transform_handle(network_output.value)) network_output.reward = to_detach_cpu_numpy( - self.inverse_scalar_transform_handle(network_output.reward)) + self.reward_inverse_scalar_transform_handle(network_output.reward)) latent_state_batch_in_search_path.append(network_output.latent_state) reward_batch = network_output.reward.reshape(-1).tolist() @@ -522,9 +529,10 @@ def __init__(self, cfg: EasyDict = None) -> None: # Update the default configuration with the values provided by the user in ``cfg``. 
default_config.update(cfg) self._cfg = default_config - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @classmethod def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "ez_ctree.Roots": @@ -642,8 +650,8 @@ def search( ) network_output.predict_next_latent_state = to_detach_cpu_numpy(network_output.predict_next_latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) - network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value)) - network_output.value_prefix = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value_prefix)) + network_output.value = to_detach_cpu_numpy(self.value_inverse_scalar_transform_handle(network_output.value)) + network_output.value_prefix = to_detach_cpu_numpy(self.value_inverse_scalar_transform_handle(network_output.value_prefix)) network_output.reward_hidden_state = network_output.reward_hidden_state.detach().cpu().numpy() latent_state_batch_in_search_path.append(network_output.predict_next_latent_state) @@ -722,9 +730,10 @@ def __init__(self, cfg: EasyDict = None) -> None: # Update the default configuration with the values provided by the user in ``cfg``. 
default_config.update(cfg) self._cfg = default_config - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @classmethod def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "ez_ctree.Roots": @@ -835,9 +844,9 @@ def search( network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) - network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value)) + network_output.value = to_detach_cpu_numpy(self.value_inverse_scalar_transform_handle(network_output.value)) network_output.value_prefix = to_detach_cpu_numpy( - self.inverse_scalar_transform_handle(network_output.value_prefix)) + self.value_inverse_scalar_transform_handle(network_output.value_prefix)) network_output.reward_hidden_state = ( network_output.reward_hidden_state[0].detach().cpu().numpy(), @@ -954,9 +963,9 @@ def search_with_reuse( network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) network_output.value = to_detach_cpu_numpy( - self.inverse_scalar_transform_handle(network_output.value)) + self.value_inverse_scalar_transform_handle(network_output.value)) network_output.value_prefix = to_detach_cpu_numpy( - self.inverse_scalar_transform_handle(network_output.value_prefix)) + self.value_inverse_scalar_transform_handle(network_output.value_prefix)) network_output.reward_hidden_state = ( network_output.reward_hidden_state[0].detach().cpu().numpy(), @@ -1052,9 +1061,10 @@ def __init__(self, cfg: EasyDict = None) -> None: # Update the default configuration with the values provided by the user in ``cfg``. 
default_config.update(cfg) self._cfg = default_config - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @classmethod def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "gmz_ctree": @@ -1146,8 +1156,8 @@ def search(self, roots: Any, model: torch.nn.Module, latent_state_roots: List[An network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) - network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value)) - network_output.reward = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.reward)) + network_output.value = to_detach_cpu_numpy(self.value_inverse_scalar_transform_handle(network_output.value)) + network_output.reward = to_detach_cpu_numpy(self.reward_inverse_scalar_transform_handle(network_output.reward)) latent_state_batch_in_search_path.append(network_output.latent_state) # tolist() is to be compatible with cpp datatype. diff --git a/lzero/mcts/tree_search/mcts_ctree_sampled.py b/lzero/mcts/tree_search/mcts_ctree_sampled.py index 19c9f0140..7ab0d210e 100644 --- a/lzero/mcts/tree_search/mcts_ctree_sampled.py +++ b/lzero/mcts/tree_search/mcts_ctree_sampled.py @@ -7,7 +7,7 @@ from lzero.mcts.ctree.ctree_sampled_efficientzero import ezs_tree as tree_sampled_efficientzero from lzero.mcts.ctree.ctree_sampled_muzero import smz_tree as tree_sampled_muzero -from lzero.policy import InverseScalarTransform, to_detach_cpu_numpy +from lzero.policy import DiscreteSupport, InverseScalarTransform, to_detach_cpu_numpy if TYPE_CHECKING: from lzero.mcts.ctree.ctree_sampled_efficientzero import ezs_tree as ezs_ctree @@ -53,9 +53,10 @@ def __init__(self, cfg: EasyDict = None) -> None: default_config = self.default_config() default_config.update(cfg) self._cfg = default_config - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @classmethod def roots( @@ -140,17 +141,7 @@ def search( for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): latent_states.append(latent_state_batch_in_search_path[ix][iy]) - # try: latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) - # except Exception as e: - # print("="*20) - # print(e) - # # print("latent_states raw:", latent_states) - # print("roots:", roots, "latent_state_roots:", latent_state_roots) - # print 
("latent_state_roots.shape:", latent_state_roots.shape) - # # if not all(isinstance(x, np.ndarray) and x.shape == latent_states[0].shape for x in latent_states): - # # raise ValueError(f"Inconsistent latent_states shapes: {[x.shape if isinstance(x, np.ndarray) else type(x) for x in latent_states]}") - # import ipdb; ipdb.set_trace() if self._cfg.model.continuous_action_space is True: # continuous action @@ -169,24 +160,24 @@ def search( MCTS stage 3: Backup At the end of the simulation, the statistics along the trajectory are updated. """ + # search_depth is used for rope in UniZero + search_depth = results.get_search_len() # for Sampled UniZero if task_id is not None: # multi task setting - network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path, timestep, task_id=task_id) + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, task_id=task_id) else: # single task setting - network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path, timestep) + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep) network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) - network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value)) - network_output.reward = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.reward)) + network_output.value = to_detach_cpu_numpy(self.value_inverse_scalar_transform_handle(network_output.value)) + network_output.reward = to_detach_cpu_numpy(self.reward_inverse_scalar_transform_handle(network_output.reward)) latent_state_batch_in_search_path.append(network_output.latent_state) - # print("network_output.latent_state.shape:", network_output.latent_state.shape) - # tolist() is to be compatible with cpp datatype. reward_batch = network_output.reward.reshape(-1).tolist() value_batch = network_output.value.reshape(-1).tolist() @@ -260,9 +251,10 @@ def __init__(self, cfg: EasyDict = None) -> None: # Update the default configuration with the values provided by the user in ``cfg``. 
default_config.update(cfg) self._cfg = default_config - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @classmethod def roots( @@ -376,8 +368,8 @@ def search( [ network_output.latent_state, network_output.policy_logits, - self.inverse_scalar_transform_handle(network_output.value), - self.inverse_scalar_transform_handle(network_output.reward), + self.value_inverse_scalar_transform_handle(network_output.value), + self.reward_inverse_scalar_transform_handle(network_output.reward), ] ) latent_state_batch_in_search_path.append(network_output.latent_state) @@ -454,9 +446,10 @@ def __init__(self, cfg: EasyDict = None) -> None: # Update the default configuration with the values provided by the user in ``cfg``. default_config.update(cfg) self._cfg = default_config - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @classmethod def roots( @@ -580,8 +573,8 @@ def search( [ network_output.latent_state, network_output.policy_logits, - self.inverse_scalar_transform_handle(network_output.value), - self.inverse_scalar_transform_handle(network_output.value_prefix), + self.value_inverse_scalar_transform_handle(network_output.value), + self.value_inverse_scalar_transform_handle(network_output.value_prefix), ] ) network_output.reward_hidden_state = ( diff --git a/lzero/mcts/tree_search/mcts_ctree_stochastic.py b/lzero/mcts/tree_search/mcts_ctree_stochastic.py index ab08fddd6..d82d242ee 100644 --- a/lzero/mcts/tree_search/mcts_ctree_stochastic.py +++ b/lzero/mcts/tree_search/mcts_ctree_stochastic.py @@ -5,7 +5,7 @@ import torch from easydict import EasyDict -from lzero.policy import InverseScalarTransform +from lzero.policy import DiscreteSupport, InverseScalarTransform from lzero.mcts.ctree.ctree_stochastic_muzero import stochastic_mz_tree @@ -64,9 +64,10 @@ def __init__(self, cfg: EasyDict = None) -> None: # Update the default configuration with the values provided by the user in ``cfg``. 
default_config.update(cfg) self._cfg = default_config - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @classmethod def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any], @@ -198,8 +199,8 @@ def process_nodes(nodes_index, is_chance): reward_splits, policy_logits_splits)): if not model.training: - value = self.inverse_scalar_transform_handle(value).detach().cpu().numpy() - reward = self.inverse_scalar_transform_handle(reward).detach().cpu().numpy() + value = self.value_inverse_scalar_transform_handle(value).detach().cpu().numpy() + reward = self.reward_inverse_scalar_transform_handle(reward).detach().cpu().numpy() latent_state = latent_state.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy() diff --git a/lzero/mcts/tree_search/mcts_ptree.py b/lzero/mcts/tree_search/mcts_ptree.py index 3e0cda9af..564dac529 100644 --- a/lzero/mcts/tree_search/mcts_ptree.py +++ b/lzero/mcts/tree_search/mcts_ptree.py @@ -8,7 +8,7 @@ import lzero.mcts.ptree.ptree_ez as tree_efficientzero import lzero.mcts.ptree.ptree_mz as tree_muzero from lzero.mcts.ptree import MinMaxStatsList -from lzero.policy import InverseScalarTransform, to_detach_cpu_numpy +from lzero.policy import DiscreteSupport, InverseScalarTransform, to_detach_cpu_numpy if TYPE_CHECKING: import lzero.mcts.ptree.ptree_ez as ez_ptree @@ -71,9 +71,10 @@ def __init__(self, cfg: EasyDict = None) -> None: # Update the default configuration with the values provided by the user in ``cfg``. default_config.update(cfg) self._cfg = default_config - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @classmethod def roots(cls: int, root_num: int, legal_actions: List[Any]) -> "mz_ptree.Roots": @@ -171,8 +172,8 @@ def search( [ network_output.latent_state, network_output.policy_logits, - self.inverse_scalar_transform_handle(network_output.value), - self.inverse_scalar_transform_handle(network_output.reward), + self.value_inverse_scalar_transform_handle(network_output.value), + self.reward_inverse_scalar_transform_handle(network_output.reward), ] ) @@ -250,9 +251,10 @@ def __init__(self, cfg: EasyDict = None) -> None: # Update the default configuration with the values provided by the user in ``cfg``. 
default_config.update(cfg) self._cfg = default_config - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @classmethod def roots(cls: int, root_num: int, legal_actions: List[Any]) -> "ez_ptree.Roots": @@ -367,8 +369,8 @@ def search( [ network_output.latent_state, network_output.policy_logits, - self.inverse_scalar_transform_handle(network_output.value), - self.inverse_scalar_transform_handle(network_output.value_prefix), + self.value_inverse_scalar_transform_handle(network_output.value), + self.value_inverse_scalar_transform_handle(network_output.value_prefix), ] ) network_output.reward_hidden_state = ( diff --git a/lzero/mcts/tree_search/mcts_ptree_sampled.py b/lzero/mcts/tree_search/mcts_ptree_sampled.py index eeefc55d6..896d803ff 100644 --- a/lzero/mcts/tree_search/mcts_ptree_sampled.py +++ b/lzero/mcts/tree_search/mcts_ptree_sampled.py @@ -6,7 +6,7 @@ from easydict import EasyDict from lzero.mcts.ptree import MinMaxStatsList -from lzero.policy import InverseScalarTransform, to_detach_cpu_numpy +from lzero.policy import DiscreteSupport, InverseScalarTransform, to_detach_cpu_numpy if TYPE_CHECKING: import lzero.mcts.ptree.ptree_sez as ptree @@ -71,9 +71,10 @@ def __init__(self, cfg: EasyDict = None) -> None: # Update the default configuration with the values provided by the user in ``cfg``. 
default_config.update(cfg) self._cfg = default_config - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @classmethod def roots( @@ -202,8 +203,8 @@ def search( [ network_output.latent_state, network_output.policy_logits, - self.inverse_scalar_transform_handle(network_output.value), - self.inverse_scalar_transform_handle(network_output.value_prefix), + self.value_inverse_scalar_transform_handle(network_output.value), + self.value_inverse_scalar_transform_handle(network_output.value_prefix), ] ) network_output.reward_hidden_state = ( diff --git a/lzero/mcts/tree_search/mcts_ptree_stochastic.py b/lzero/mcts/tree_search/mcts_ptree_stochastic.py index 48058e510..52587d242 100644 --- a/lzero/mcts/tree_search/mcts_ptree_stochastic.py +++ b/lzero/mcts/tree_search/mcts_ptree_stochastic.py @@ -7,7 +7,7 @@ import lzero.mcts.ptree.ptree_stochastic_mz as tree_stochastic_muzero from lzero.mcts.ptree import MinMaxStatsList -from lzero.policy import InverseScalarTransform +from lzero.policy import DiscreteSupport, InverseScalarTransform if TYPE_CHECKING: import lzero.mcts.ptree.ptree_stochastic_mz as stochastic_mz_ptree @@ -69,9 +69,10 @@ def __init__(self, cfg: EasyDict = None) -> None: # Update the default configuration with the values provided by the user in ``cfg``. 
default_config.update(cfg) self._cfg = default_config - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) @classmethod def roots(cls: int, root_num: int, legal_actions: List[Any]) -> "stochastic_mz_ptree.Roots": @@ -209,8 +210,8 @@ def process_nodes(node_indices, is_chance): reward_splits, policy_logits_splits)): if not model.training: - value = self.inverse_scalar_transform_handle(value).detach().cpu().numpy() - reward = self.inverse_scalar_transform_handle(reward).detach().cpu().numpy() + value = self.value_inverse_scalar_transform_handle(value).detach().cpu().numpy() + reward = self.reward_inverse_scalar_transform_handle(reward).detach().cpu().numpy() latent_state = latent_state.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy() diff --git a/lzero/model/alphazero_model.py b/lzero/model/alphazero_model.py index 765f5dfeb..d541794e9 100644 --- a/lzero/model/alphazero_model.py +++ b/lzero/model/alphazero_model.py @@ -34,7 +34,7 @@ def __init__( policy_head_channels: int = 16, value_head_hidden_channels: SequenceType = [32], policy_head_hidden_channels: SequenceType = [32], - value_support_size: int = 601, + value_support_range: SequenceType =(-300., 301., 1.), # ============================================================== # specific sampled related config # ============================================================== @@ -68,13 +68,13 @@ def __init__( - policy_head_channels (:obj:`int`): The channels of policy head. - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - value_support_size (:obj:`int`): The size of categorical value. + - value_support_range (:obj:`SequenceType`): The range of categorical value output. """ super(AlphaZeroModel, self).__init__() - self.categorical_distribution = categorical_distribution self.observation_shape = observation_shape + self.categorical_distribution = categorical_distribution if self.categorical_distribution: - self.value_support_size = value_support_size + self.value_support_size = len(torch.arange(*value_support_range)) else: self.value_support_size = 1 diff --git a/lzero/model/common.py b/lzero/model/common.py index cc607c2af..88186f711 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -1,32 +1,35 @@ """ Overview: - In this Python file, we provide a collection of reusable model templates designed to streamline the development + This Python file provides a collection of reusable model templates designed to streamline the development process for various custom algorithms. By utilizing these pre-built model templates, users can quickly adapt and - customize their custom algorithms, ensuring efficient and effective development. - BTW, users can refer to the unittest of these model templates to learn how to use them. 
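For reference, the `(start, stop, step)` tuple introduced for `value_support_range` reproduces the old 601-atom support. A quick check of the `len(torch.arange(*value_support_range))` computation used in the `AlphaZeroModel` hunk above (a sketch, not part of the patch):

```Python
import torch

value_support_range = (-300., 301., 1.)
# 601 atoms covering -300, -299, ..., 300, matching the old value_support_size=601 default.
assert len(torch.arange(*value_support_range)) == 601
```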
+ customize their algorithms, ensuring efficient and effective development. + Users can refer to the unittest of these model templates to learn how to use them. """ import math from dataclasses import dataclass -from typing import Callable, List, Optional -from typing import Tuple +from typing import Callable, List, Optional, Tuple, Sequence import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init +from ditk import logging +# Assuming these imports are valid in the user's environment. +# If they are not, they should be replaced with the correct ones. from ding.torch_utils import MLP, ResBlock from ding.torch_utils.network.normalization import build_normalization -from ding.utils import SequenceType -from ditk import logging +from ding.utils import SequenceType, get_rank, get_world_size +from transformers import AutoModelForCausalLM, AutoTokenizer from ding.utils import set_pkg_seed, get_rank, get_world_size -import torch + + def MLP_V2( in_channels: int, hidden_channels: List[int], out_channels: int, - layer_fn: Callable = None, + layer_fn: Callable = nn.Linear, activation: Optional[nn.Module] = None, norm_type: Optional[str] = None, use_dropout: bool = False, @@ -34,118 +37,122 @@ def MLP_V2( output_activation: bool = True, output_norm: bool = True, last_linear_layer_init_zero: bool = False, -): +) -> nn.Sequential: """ Overview: - Create a multi-layer perceptron (MLP) using a list of hidden dimensions. Each layer consists of a fully + Creates a multi-layer perceptron (MLP) using a list of hidden dimensions. Each layer consists of a fully connected block with optional activation, normalization, and dropout. The final layer is configurable - to include or exclude activation, normalization, and dropout based on user preferences. - + to include or exclude activation and normalization. Arguments: - in_channels (:obj:`int`): Number of input channels (dimensionality of the input tensor). - hidden_channels (:obj:`List[int]`): A list specifying the number of channels for each hidden layer. - For example, [512, 256, 128] means the MLP will have three hidden layers with 512, 256, and 128 units, respectively. - out_channels (:obj:`int`): Number of output channels (dimensionality of the output tensor). - - layer_fn (:obj:`Callable`, optional): Layer function to construct layers (default is `nn.Linear`). - - activation (:obj:`nn.Module`, optional): Activation function to use after each layer - (e.g., `nn.ReLU`, `nn.Sigmoid`). Default is None (no activation). - - norm_type (:obj:`str`, optional): Type of normalization to apply after each layer. - If None, no normalization is applied. Supported values depend on the implementation of `build_normalization`. - - use_dropout (:obj:`bool`, optional): Whether to apply dropout after each layer. Default is False. - - dropout_probability (:obj:`float`, optional): The probability of setting elements to zero in dropout. Default is 0.5. - - output_activation (:obj:`bool`, optional): Whether to apply activation to the output layer. Default is True. - - output_norm (:obj:`bool`, optional): Whether to apply normalization to the output layer. Default is True. - - last_linear_layer_init_zero (:obj:`bool`, optional): Whether to initialize the weights and biases of the - last linear layer to zeros. This is commonly used in reinforcement learning for stable initial outputs. - + - layer_fn (:obj:`Callable`): The function to construct layers, defaults to `nn.Linear`. 
+ - activation (:obj:`Optional[nn.Module]`): Activation function to use after each layer, defaults to None. + - norm_type (:obj:`Optional[str]`): Type of normalization to apply. If None, no normalization is applied. + - use_dropout (:obj:`bool`): Whether to apply dropout after each layer, defaults to False. + - dropout_probability (:obj:`float`): The probability for dropout, defaults to 0.5. + - output_activation (:obj:`bool`): Whether to apply activation to the output layer, defaults to True. + - output_norm (:obj:`bool`): Whether to apply normalization to the output layer, defaults to True. + - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last linear layer's weights and biases to zero. Returns: - block (:obj:`nn.Sequential`): A PyTorch `nn.Sequential` object containing the layers of the MLP. - - Notes: - - The final layer's normalization, activation, and dropout are controlled by `output_activation`, - `output_norm`, and `use_dropout`. - - If `last_linear_layer_init_zero` is True, the weights and biases of the last linear layer are initialized to 0. """ - assert len(hidden_channels) > 0, "The hidden_channels list must contain at least one element." - if layer_fn is None: - layer_fn = nn.Linear - - # Initialize the MLP block - block = [] - channels = [in_channels] + hidden_channels + [out_channels] - - # Build all layers except the final layer - for i, (in_channels, out_channels) in enumerate(zip(channels[:-2], channels[1:-1])): - block.append(layer_fn(in_channels, out_channels)) - if norm_type is not None: - block.append(build_normalization(norm_type, dim=1)(out_channels)) - if activation is not None: - block.append(activation) - if use_dropout: - block.append(nn.Dropout(dropout_probability)) - - # Build the final layer - in_channels = channels[-2] - out_channels = channels[-1] - block.append(layer_fn(in_channels, out_channels)) - - # Add optional normalization and activation for the final layer - if output_norm and norm_type is not None: - block.append(build_normalization(norm_type, dim=1)(out_channels)) - if output_activation and activation is not None: - block.append(activation) - if use_dropout: - block.append(nn.Dropout(dropout_probability)) - - # Initialize the weights and biases of the last linear layer to zero if specified + if not hidden_channels: + logging.warning("hidden_channels is empty, creating a single-layer MLP.") + + layers = [] + all_channels = [in_channels] + hidden_channels + [out_channels] + num_layers = len(all_channels) - 1 + + for i in range(num_layers): + is_last_layer = (i == num_layers - 1) + layers.append(layer_fn(all_channels[i], all_channels[i+1])) + + if not is_last_layer: + # Intermediate layers + if norm_type: + layers.append(build_normalization(norm_type, dim=1)(all_channels[i+1])) + if activation: + layers.append(activation) + if use_dropout: + layers.append(nn.Dropout(dropout_probability)) + else: + # Last layer + if output_norm and norm_type: + layers.append(build_normalization(norm_type, dim=1)(all_channels[i+1])) + if output_activation and activation: + layers.append(activation) + # Note: Dropout on the final output is usually not recommended unless for specific regularization purposes. + # The original logic applied it, so we keep it for consistency. 
+ if use_dropout: + layers.append(nn.Dropout(dropout_probability)) + + # Initialize the last linear layer to zero if specified if last_linear_layer_init_zero: - for layer in reversed(block): + for layer in reversed(layers): if isinstance(layer, nn.Linear): nn.init.zeros_(layer.weight) nn.init.zeros_(layer.bias) break - return nn.Sequential(*block) + return nn.Sequential(*layers) + + +# --- Data-structures for Network Outputs --- -# use dataclass to make the output of network more convenient to use @dataclass class MZRNNNetworkOutput: - # output format of the MuZeroRNN model + """ + Overview: + Data structure for the output of the MuZeroRNN model. + """ value: torch.Tensor value_prefix: torch.Tensor policy_logits: torch.Tensor latent_state: torch.Tensor predict_next_latent_state: torch.Tensor - reward_hidden_state: Tuple[torch.Tensor] + reward_hidden_state: Tuple[torch.Tensor, torch.Tensor] @dataclass class EZNetworkOutput: - # output format of the EfficientZero model + """ + Overview: + Data structure for the output of the EfficientZero model. + """ value: torch.Tensor value_prefix: torch.Tensor policy_logits: torch.Tensor latent_state: torch.Tensor - reward_hidden_state: Tuple[torch.Tensor] + reward_hidden_state: Tuple[torch.Tensor, torch.Tensor] @dataclass class MZNetworkOutput: - # output format of the MuZero model + """ + Overview: + Data structure for the output of the MuZero model. + """ value: torch.Tensor reward: torch.Tensor policy_logits: torch.Tensor latent_state: torch.Tensor +# --- Core Network Components --- + class SimNorm(nn.Module): + """ + Overview: + Implements Simplicial Normalization as described in the paper: https://arxiv.org/abs/2204.00616. + It groups features and applies softmax to each group. + """ def __init__(self, simnorm_dim: int) -> None: """ - Overview: - Simplicial normalization. Adapted from https://arxiv.org/abs/2204.00616. Arguments: - - simnorm_dim (:obj:`int`): The dimension for simplicial normalization. + - simnorm_dim (:obj:`int`): The size of each group (simplex) to apply softmax over. """ super().__init__() self.dim = simnorm_dim @@ -153,331 +160,402 @@ def __init__(self, simnorm_dim: int) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: """ Overview: - Forward pass of the SimNorm layer. + Forward pass for SimNorm. Arguments: - - x (:obj:`torch.Tensor`): The input tensor to normalize. + - x (:obj:`torch.Tensor`): The input tensor. Returns: - - x (:obj:`torch.Tensor`): The normalized tensor. + - (:obj:`torch.Tensor`): The tensor after applying Simplicial Normalization. """ - shp = x.shape - # Ensure that there is at least one simplex to normalize across. - if shp[1] != 0: - x = x.view(*shp[:-1], -1, self.dim) - x = F.softmax(x, dim=-1) - return x.view(*shp) - else: + if x.shape[1] == 0: return x + # Reshape to (batch, groups, dim) + x_reshaped = x.view(*x.shape[:-1], -1, self.dim) + # Apply softmax over the last dimension (the simplex) + x_softmax = F.softmax(x_reshaped, dim=-1) + # Reshape back to the original tensor shape + return x_softmax.view(*x.shape) def __repr__(self) -> str: - """ - Overview: - String representation of the SimNorm layer. - Returns: - - output (:obj:`str`): The string representation. - """ return f"SimNorm(dim={self.dim})" -def AvgL1Norm(x, eps=1e-8): +def AvgL1Norm(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: """ Overview: - Normalize the input tensor by the L1 norm. + Normalizes a tensor by the mean of its absolute values (L1 norm) along the last dimension. 
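A short usage sketch for the refactored `MLP_V2` above, assuming the signature shown in this patch; the layer sizes and the `value_head` name are illustrative, not taken from the repository:

```Python
import torch
import torch.nn as nn
from lzero.model.common import MLP_V2

# Three-layer MLP head: 256 -> 128 -> 128 -> 601, plain linear output,
# with the final linear layer zero-initialized.
value_head = MLP_V2(
    in_channels=256,
    hidden_channels=[128, 128],
    out_channels=601,
    activation=nn.ReLU(inplace=True),
    output_activation=False,
    output_norm=False,
    last_linear_layer_init_zero=True,
)
logits = value_head(torch.randn(8, 256))  # -> torch.Size([8, 601])
```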
Arguments: - x (:obj:`torch.Tensor`): The input tensor to normalize. - - eps (:obj:`float`): The epsilon value to prevent division by zero. + - eps (:obj:`float`): A small epsilon value to prevent division by zero. Returns: - - :obj:`torch.Tensor`: The normalized tensor. + - (:obj:`torch.Tensor`): The normalized tensor. """ - return x / x.abs().mean(-1, keepdim=True).clamp(min=eps) + return x / (x.abs().mean(dim=-1, keepdim=True) + eps) class FeatureAndGradientHook: + """ + Overview: + A utility class to capture and analyze features and gradients of a specific module during + the forward and backward passes. This is useful for debugging and understanding model dynamics. + """ - def __init__(self): + def __init__(self, module: nn.Module): """ - Overview: - Class to capture features and gradients at SimNorm. + Arguments: + - module (:obj:`nn.Module`): The PyTorch module to attach the hooks to. """ self.features_before = [] self.features_after = [] self.grads_before = [] self.grads_after = [] + self.forward_handler = module.register_forward_hook(self._forward_hook) + self.backward_handler = module.register_full_backward_hook(self._backward_hook) - def setup_hooks(self, model): - # Hooks to capture features and gradients at SimNorm - self.forward_handler = model.sim_norm.register_forward_hook(self.forward_hook) - self.backward_handler = model.sim_norm.register_full_backward_hook(self.backward_hook) - - def forward_hook(self, module, input, output): + def _forward_hook(self, module: nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor) -> None: + """Hook to capture input and output features during the forward pass.""" with torch.no_grad(): - self.features_before.append(input[0]) - self.features_after.append(output) + self.features_before.append(inputs[0].clone().detach()) + self.features_after.append(output.clone().detach()) - def backward_hook(self, module, grad_input, grad_output): + def _backward_hook(self, module: nn.Module, grad_inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor]) -> None: + """Hook to capture input and output gradients during the backward pass.""" with torch.no_grad(): - self.grads_before.append(grad_input[0] if grad_input[0] is not None else None) - self.grads_after.append(grad_output[0] if grad_output[0] is not None else None) + self.grads_before.append(grad_inputs[0].clone().detach() if grad_inputs[0] is not None else None) + self.grads_after.append(grad_outputs[0].clone().detach() if grad_outputs[0] is not None else None) - def analyze(self): - # Calculate L2 norms of features - l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_before])) - l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_after])) + def analyze(self) -> Tuple[float, float, float, float]: + """ + Overview: + Analyzes the captured features and gradients by computing their average L2 norms. + This method clears the stored data after analysis to free memory. + Returns: + - (:obj:`Tuple[float, float, float, float]`): A tuple containing the L2 norms of + (features_before, features_after, grads_before, grads_after). 
+ """ + if not self.features_before: + return 0.0, 0.0, 0.0, 0.0 - # Calculate norms of gradients - grad_norm_before = torch.mean( - torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_before if g is not None])) - grad_norm_after = torch.mean( - torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_after if g is not None])) + l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2) for f in self.features_before])).item() + l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2) for f in self.features_after])).item() - # Clear stored data and delete tensors to free memory - self.clear_data() + valid_grads_before = [g for g in self.grads_before if g is not None] + grad_norm_before = torch.mean(torch.stack([torch.norm(g, p=2) for g in valid_grads_before])).item() if valid_grads_before else 0.0 - # Optionally clear CUDA cache - if torch.cuda.is_available(): - torch.cuda.empty_cache() + valid_grads_after = [g for g in self.grads_after if g is not None] + grad_norm_after = torch.mean(torch.stack([torch.norm(g, p=2) for g in valid_grads_after])).item() if valid_grads_after else 0.0 + self.clear_data() return l2_norm_before, l2_norm_after, grad_norm_before, grad_norm_after - def clear_data(self): - del self.features_before[:] - del self.features_after[:] - del self.grads_before[:] - del self.grads_after[:] + def clear_data(self) -> None: + """Clears all stored feature and gradient tensors to free up memory.""" + self.features_before.clear() + self.features_after.clear() + self.grads_before.clear() + self.grads_after.clear() + if torch.cuda.is_available(): + torch.cuda.empty_cache() - def remove_hooks(self): + def remove_hooks(self) -> None: + """Removes the registered forward and backward hooks.""" self.forward_handler.remove() self.backward_handler.remove() class DownSample(nn.Module): + """ + Overview: + A convolutional network for downsampling image-based observations, commonly used in Atari environments. + It consists of a series of convolutional, normalization, and residual blocks. + """ - def __init__(self, observation_shape: SequenceType, out_channels: int, - activation: nn.Module = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - num_resblocks: int = 1, - ) -> None: + def __init__( + self, + observation_shape: Sequence[int], + out_channels: int, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: str = 'BN', + num_resblocks: int = 1, + ) -> None: """ - Overview: - Define downSample convolution network. Encode the observation into hidden state. - This network is often used in video games like Atari. In board games like go and chess, - we don't need this module. Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[12, 96, 96] - for video games like atari, RGB 3 channel times stack 4 frames. - - out_channels (:obj:`int`): The output channels of output hidden state. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ - Use the inplace operation to speed up. - - norm_type (:obj:`Optional[str]`): The normalization type used in network, defaults to 'BN'. - - num_resblocks (:obj:`int`): The number of residual blocks. Defaults to 1. + - observation_shape (:obj:`Sequence[int]`): The shape of the input observation, e.g., (C, H, W). + - out_channels (:obj:`int`): The number of output channels. + - activation (:obj:`nn.Module`): The activation function to use. + - norm_type (:obj:`str`): The type of normalization ('BN' or 'LN'). 
+ - num_resblocks (:obj:`int`): The number of residual blocks in each stage. """ super().__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + if norm_type not in ['BN', 'LN']: + raise ValueError(f"Unsupported norm_type: {norm_type}. Must be 'BN' or 'LN'.") + # The original design was fixed to 1 resblock per stage. + if num_resblocks != 1: + logging.warning(f"DownSample is designed for num_resblocks=1, but got {num_resblocks}.") - assert num_resblocks == 1, "num_resblocks must be 1 in DownSample" - self.observation_shape = observation_shape - self.conv1 = nn.Conv2d( - observation_shape[0], - out_channels // 2, - kernel_size=3, - stride=2, - padding=1, - bias=False, # disable bias for better convergence - ) - if norm_type == 'BN': - self.norm1 = nn.BatchNorm2d(out_channels // 2) - elif norm_type == 'LN': - self.norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2], - eps=1e-5) + self.activation = activation - self.resblocks1 = nn.ModuleList( - [ - ResBlock( - in_channels=out_channels // 2, - activation=activation, - norm_type=norm_type, - res_type='basic', - bias=False - ) for _ in range(num_resblocks) - ] - ) - self.downsample_block = ResBlock( - in_channels=out_channels // 2, - out_channels=out_channels, - activation=activation, - norm_type=norm_type, - res_type='downsample', - bias=False - ) - self.resblocks2 = nn.ModuleList( - [ - ResBlock( - in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_resblocks) - ] - ) + # Initial convolution: stride 2 + self.conv1 = nn.Conv2d(observation_shape[0], out_channels // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.norm1 = build_normalization(norm_type, dim=2)(out_channels // 2) + + # Stage 1 with residual blocks + self.resblocks1 = nn.ModuleList([ + ResBlock(in_channels=out_channels // 2, activation=activation, norm_type=norm_type, res_type='basic', bias=False) + for _ in range(num_resblocks) + ]) + + # Downsample block: stride 2 + self.downsample_block = ResBlock(in_channels=out_channels // 2, out_channels=out_channels, activation=activation, norm_type=norm_type, res_type='downsample', bias=False) + + # Stage 2 with residual blocks + self.resblocks2 = nn.ModuleList([ + ResBlock(in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False) + for _ in range(num_resblocks) + ]) + + # Pooling 1: stride 2 self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) - self.resblocks3 = nn.ModuleList( - [ - ResBlock( - in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_resblocks) - ] - ) + + # Stage 3 with residual blocks + self.resblocks3 = nn.ModuleList([ + ResBlock(in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False) + for _ in range(num_resblocks) + ]) + + # Final pooling for specific input sizes self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) - self.activation = activation def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ - H is height. - - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ - output width, H_ is output height. 
+ - x (:obj:`torch.Tensor`): (B, C_in, H, W) + - output (:obj:`torch.Tensor`): (B, C_out, H_out, W_out) + x = self.norm1(x) """ x = self.conv1(x) - x = self.norm1(x) x = self.activation(x) for block in self.resblocks1: x = block(x) + x = self.downsample_block(x) for block in self.resblocks2: x = block(x) + x = self.pooling1(x) for block in self.resblocks3: x = block(x) - # 64, 84, 96 are the most common observation shapes in Atari games. - if self.observation_shape[1] == 64: - output = x - elif self.observation_shape[1] == 84: - x = self.pooling2(x) - output = x - elif self.observation_shape[1] == 96: - x = self.pooling2(x) - output = x + # This part handles specific Atari resolutions. A more general approach might be desirable, + # but we maintain original behavior. + obs_height = self.observation_shape[1] + if obs_height == 64: + return x + elif obs_height in [84, 96]: + return self.pooling2(x) else: - raise NotImplementedError(f"DownSample for observation shape {self.observation_shape} is not implemented now. " - f"You should transform the observation shape to 64 or 96 in the env.") - - return output - + raise NotImplementedError( + f"DownSample for observation height {obs_height} is not implemented. " + f"Supported heights are 64, 84, 96." + ) -class HFLanguageRepresentationNetwork(nn.Module): +class QwenNetwork(nn.Module): def __init__(self, - model_path: str = 'google-bert/bert-base-uncased', + model_path: str = 'Qwen/Qwen3-1.7B', embedding_size: int = 768, + final_norm_option_in_encoder: str = "layernorm", group_size: int = 8, - norm_type: str = "simnorm", - # norm_type: str = "layernorm", # TODO: Why does nan appear in the first step of training? tokenizer=None): + super().__init__() + + logging.info(f"Loading Qwen model from: {model_path}") + + local_rank = get_rank() + if local_rank == 0: + self.pretrained_model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype="auto", + device_map={"": local_rank}, + attn_implementation="flash_attention_2" + ) + if get_world_size() > 1: + torch.distributed.barrier() + if local_rank != 0: + self.pretrained_model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype="auto", + device_map={"": local_rank}, + attn_implementation="flash_attention_2" + ) + + for p in self.pretrained_model.parameters(): + p.requires_grad = False + + if tokenizer is None: + if local_rank == 0: + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + if get_world_size() > 1: + torch.distributed.barrier() + if local_rank != 0: + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + else: + self.tokenizer = tokenizer + + qwen_hidden_size = self.pretrained_model.config.hidden_size + + self.embedding_head = nn.Sequential( + nn.Linear(qwen_hidden_size, embedding_size), + self._create_norm_layer(final_norm_option_in_encoder, embedding_size, group_size) + ) + + def _create_norm_layer(self, norm_option, embedding_size, group_size): + if norm_option.lower() == "simnorm": + return SimNorm(simnorm_dim=group_size) + elif norm_option.lower() == "layernorm": + return nn.LayerNorm(embedding_size) + else: + raise NotImplementedError(f"Normalization type '{norm_option}' is not implemented.") + + def encode(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: """ Overview: - This class defines a language representation network that utilizes a pretrained Hugging Face model. 
- The network outputs embeddings with the specified dimension and can optionally use SimNorm or LayerNorm - for normalization at the final stage to ensure training stability. + Encode the input token sequence `x` into a latent representation + using a pretrained language model backbone followed by a projection head. + Arguments: + - x (:obj:`torch.Tensor`): Input token ids of shape (B, L) + - no_grad (:obj:`bool`, optional, default=True): If True, encoding is performed under `torch.no_grad()` to save memory and computation (no gradient tracking). + Returns: + - latent (:obj:`torch.Tensor`): Encoded latent state of shape (B, D). + """ + pad_id = self.tokenizer.pad_token_id + attention_mask = (x != pad_id).long().to(x.device) + context = {'input_ids': x.long(), 'attention_mask': attention_mask} + if no_grad: + with torch.no_grad(): + outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True) + else: + outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True) + last_hidden = outputs.hidden_states[-1] + + B, L, H = last_hidden.size() + lengths = attention_mask.sum(dim=1) # [B] + positions = torch.clamp(lengths - 1, min=0) # [B] + batch_idx = torch.arange(B, device=last_hidden.device) + + selected = last_hidden[batch_idx, positions] # [B, H] + + latent = self.embedding_head(selected.to(self.embedding_head[0].weight.dtype)) + return latent + + def decode(self, embeddings: torch.Tensor, max_length: int = 512) -> str: + """ + Decodes embeddings into text via the decoder network. + """ + embeddings_detached = embeddings.detach() + self.pretrained_model.eval() + + # Directly generate using provided embeddings + with torch.no_grad(): + param = next(self.pretrained_model.parameters()) + embeddings = embeddings_detached.to(device=param.device, dtype=param.dtype) + gen_ids = self.pretrained_model.generate( + inputs_embeds=embeddings, + max_length=max_length + ) + texts = self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True) + self.pretrained_model.train() + return texts[0] if len(texts) == 1 else texts + + def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: + return self.encode(x, no_grad=no_grad) + + +class HFLanguageRepresentationNetwork(nn.Module): + def __init__(self, + model_path: str = 'google-bert/bert-base-uncased', + embedding_size: int = 768, + group_size: int = 8, + final_norm_option_in_encoder: str = "layernorm", + tokenizer=None): + """ Arguments: - model_path (str): The path to the pretrained Hugging Face model. Default is 'google-bert/bert-base-uncased'. - embedding_size (int): The dimension of the output embeddings. Default is 768. - group_size (int): The group size for SimNorm when using normalization. - - norm_type (str): The type of normalization to use ("simnorm" or "layernorm"). Default is "layernorm". + - final_norm_option_in_encoder (str): The type of normalization to use ("simnorm" or "layernorm"). Default is "layernorm". - tokenizer (Optional): An instance of a tokenizer. If None, the tokenizer will be loaded from the pretrained model. """ super().__init__() - from transformers import AutoModel, AutoTokenizer - logging.info(f"Loading model from: {model_path}") - # In distributed training, only the rank 0 process downloads the model, and other processes load from cache to speed up startup. + # In distributed settings, ensure only rank 0 downloads the model/tokenizer. 
if get_rank() == 0: - self.model = AutoModel.from_pretrained(model_path) + self.pretrained_model = AutoModel.from_pretrained(model_path) + if get_world_size() > 1: # Wait for rank 0 to finish loading the model. torch.distributed.barrier() if get_rank() != 0: - self.model = AutoModel.from_pretrained(model_path) + self.pretrained_model = AutoModel.from_pretrained(model_path) - if tokenizer is None: - # Only rank 0 downloads the tokenizer, and then other processes load it from cache. - if get_rank() == 0: - self.tokenizer = AutoTokenizer.from_pretrained(model_path) - if get_world_size() > 1: - torch.distributed.barrier() - if get_rank() != 0: + if get_rank() != 0: + logging.info(f"Worker process is loading model from cache: {model_path}") + self.model = AutoModel.from_pretrained(model_path) + if tokenizer is None: self.tokenizer = AutoTokenizer.from_pretrained(model_path) - else: + + if tokenizer is not None: self.tokenizer = tokenizer - # Set the embedding dimension. A linear projection is added (the dimension remains unchanged here but can be extended for other mappings). self.embedding_size = embedding_size - self.embed_proj_head = nn.Linear(self.model.config.hidden_size, self.embedding_size) + self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size) - # Select the normalization method based on the norm_type parameter. - if norm_type.lower() == "simnorm": + # # Select the normalization method based on the final_norm_option_in_encoder parameter. + if final_norm_option_in_encoder.lower() == "simnorm": self.norm = SimNorm(simnorm_dim=group_size) - elif norm_type.lower() == "layernorm": + elif final_norm_option_in_encoder.lower() == "layernorm": self.norm = nn.LayerNorm(embedding_size) else: - raise NotImplementedError(f"Normalization type '{norm_type}' is not implemented. " + raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. " f"Choose 'simnorm' or 'layernorm'.") def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: """ - Forward Propagation: - Compute the language representation based on the input token sequence. - The [CLS] token’s representation is extracted from the output of the pretrained model, - then passed through a linear projection and final normalization layer (SimNorm or LayerNorm). - + Overview: + Computes language representation from input token IDs. Arguments: - - x (torch.Tensor): Input token sequence of shape [batch_size, seq_len]. - - no_grad (bool): Whether to run in no-gradient mode for memory efficiency. Default is True. + - x (:obj:`torch.Tensor`): Input token sequence of shape (B, seq_len). + - no_grad (:obj:`bool`): If True, run the transformer model in `torch.no_grad()` context. Returns: - - torch.Tensor: The processed language embedding with shape [batch_size, embedding_size]. + - (:obj:`torch.Tensor`): The final language embedding of shape (B, embedding_size). """ + # Construct the attention mask to exclude padding tokens. attention_mask = x != self.tokenizer.pad_token_id - # Use no_grad context if specified to disable gradient computation. if no_grad: with torch.no_grad(): x = x.long() # Ensure the input tensor is of type long. - outputs = self.model(x, attention_mask=attention_mask) + outputs = self.pretrained_model(x, attention_mask=attention_mask) # Get the hidden state from the last layer and select the output corresponding to the [CLS] token. 
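The `QwenNetwork.encode` method above pools the hidden state at the last non-padding position, since a decoder-only LM has no `[CLS]` token. A standalone illustration of that indexing step (shapes and the pad id are made up for the example):

```Python
import torch

B, L, H, pad_id = 2, 6, 4, 0
input_ids = torch.tensor([[5, 7, 9, pad_id, pad_id, pad_id],
                          [3, 4, 8, 2, 6, pad_id]])
last_hidden = torch.randn(B, L, H)            # stand-in for outputs.hidden_states[-1]

attention_mask = (input_ids != pad_id).long()
lengths = attention_mask.sum(dim=1)           # number of real tokens per sequence
positions = torch.clamp(lengths - 1, min=0)   # index of the last real token
selected = last_hidden[torch.arange(B), positions]  # [B, H] sequence representation
```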
cls_embedding = outputs.last_hidden_state[:, 0, :] else: x = x.long() - outputs = self.model(x, attention_mask=attention_mask) + outputs = self.pretrained_model(x, attention_mask=attention_mask) cls_embedding = outputs.last_hidden_state[:, 0, :] - # Apply linear projection to obtain the desired output dimension. cls_embedding = self.embed_proj_head(cls_embedding) - # Normalize the embeddings using the selected normalization layer (SimNorm or LayerNorm) to ensure training stability. cls_embedding = self.norm(cls_embedding) - - return cls_embedding - -from torch.nn.utils import weight_norm - -# AdaptiveFeatureScaler:在对 1D 向量进行 scaling 时,加入 clamp 限制,避免 runaway -class AdaptiveFeatureScaler(nn.Module): - def __init__(self, init_scale=0.1, max_scale=1.0): - super().__init__() - self.scale = nn.Parameter(torch.tensor(init_scale)) - self.max_scale = max_scale - def forward(self, x): - # 限制 scale 参数的最大值,避免数值爆炸 - clamped_scale = torch.clamp(self.scale, 0.0, self.max_scale) - return x * clamped_scale / math.sqrt(x.size(1)) + return cls_embedding -# 假设 SimNorm, ResBlock, DownSample 在其他地方已经定义 -# 下面仅给出 RepresentationNetworkUniZero 的实现 class RepresentationNetworkUniZero(nn.Module): + def __init__( self, - observation_shape: tuple = (3, 64, 64), + observation_shape: SequenceType = (3, 64, 64), num_res_blocks: int = 1, num_channels: int = 64, downsample: bool = True, @@ -485,112 +563,102 @@ def __init__( norm_type: str = 'BN', embedding_dim: int = 256, group_size: int = 8, - final_norm_option_in_encoder: str = 'SimNorm', - use_adaptive_scale: bool = False - # use_global_pooling: bool = True # 新增超参数:是否使用全局平均池化 - # use_global_pooling: bool = False # 新增超参数:是否使用全局平均池化 + final_norm_option_in_encoder: str = 'LayerNorm', # TODO ) -> None: """ - Representation network used in UniZero. - 对于 channel 数较大的场景,可使用全局平均池化来降低全连接层的输入维度,提高训练稳定性。 + Overview: + Representation network used in UniZero. Encode the 2D image obs into latent state. + Currently, the network only supports obs images with both a width and height of 64. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] + for video games like atari, RGB 3 channel. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - num_channels (:obj:`int`): The channel of output hidden state. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ + Use the inplace operation to speed up. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - embedding_dim (:obj:`int`): The dimension of the latent state. + - group_size (:obj:`int`): The dimension for simplicial normalization. + - final_norm_option_in_encoder (:obj:`str`): The normalization option for the final layer, defaults to 'SimNorm'. \ + Options are 'SimNorm' and 'LayerNorm'. 
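# --- Aside: a rough usage sketch for HFLanguageRepresentationNetwork defined above ---
# This is an illustration, not code from the repository: the import path and call
# pattern are assumptions based on this patch, and it presumes a single,
# non-distributed process. Passing the tokenizer explicitly avoids relying on the
# rank-guarded tokenizer fallback in __init__.
import torch
from transformers import AutoTokenizer
from lzero.model.common import HFLanguageRepresentationNetwork  # module path assumed

tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
encoder = HFLanguageRepresentationNetwork(
    model_path='google-bert/bert-base-uncased',
    embedding_size=768,
    final_norm_option_in_encoder='layernorm',
    tokenizer=tokenizer,
)
tokens = tokenizer(["a short test sentence"], return_tensors='pt', padding=True)
latent = encoder(tokens['input_ids'])  # expected shape: (1, 768)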
""" super().__init__() - assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']" - # 打印日志信息(可选) - print(f"Using norm type: {norm_type}") - print(f"Using activation type: {activation}") - - self.use_global_pooling = False + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + logging.info(f"Using norm type: {norm_type}") + logging.info(f"Using activation type: {activation}") self.observation_shape = observation_shape self.downsample = downsample - if self.downsample: - # DownSample 对象的实现需自行定义 self.downsample_net = DownSample( observation_shape, num_channels, activation=activation, norm_type=norm_type, - num_resblocks=1, ) else: self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) + if norm_type == 'BN': self.norm = nn.BatchNorm2d(num_channels) elif norm_type == 'LN': - # 当不进行 downsample 时,观察图尺寸不变 - self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) + if downsample: + self.norm = nn.LayerNorm( + [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], + eps=1e-5) + else: + self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) - # 构建 residual block 层 self.resblocks = nn.ModuleList( [ ResBlock( - in_channels=num_channels, - activation=activation, - norm_type=norm_type, - res_type='basic', - bias=False + in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False ) for _ in range(num_res_blocks) ] ) self.activation = activation self.embedding_dim = embedding_dim - # 根据观察图尺寸确定空间维度 - if self.observation_shape[1] == 64: - spatial_size = 8 - elif self.observation_shape[1] in [84, 96]: - spatial_size = 6 - else: - spatial_size = self.observation_shape[1] # 默认采用输入H - + # ==================== 修改开始 ==================== if self.observation_shape[1] == 64: - last_linear_in_dim = num_channels * 8 * 8 - elif self.observation_shape[1] in [84, 96]: - last_linear_in_dim = num_channels * 6 * 6 - else: - # 默认采用完整 flatten 的维度 - last_linear_in_dim = num_channels * self.observation_shape[1] * self.observation_shape[2] - - self.last_linear = nn.Linear(last_linear_in_dim, self.embedding_dim, bias=False) - - - # 根据是否使用全局平均池化决定 last_linear 前的输入维度以及 norm 的形状 - if self.use_global_pooling: - linear_in_dim = num_channels # 全局池化后形状: (B, num_channels, 1, 1) - self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) - # 对 1D 向量使用 LayerNorm - self.norm_before_last_linear = nn.LayerNorm(linear_in_dim, eps=1e-5) - else: - linear_in_dim = num_channels * spatial_size * spatial_size - if use_adaptive_scale: - # 若通过 flatten 后进行 adaptive scaling,对 1D 向量归一化 - self.norm_before_last_linear = nn.LayerNorm(linear_in_dim, eps=1e-5) - else: - # 保留空间信息时,在 (C, H, W) 上归一化 - self.norm_before_last_linear = nn.LayerNorm([num_channels, spatial_size, spatial_size], eps=1e-5) + # 修复:将硬编码的 64 替换为 num_channels + self.last_linear = nn.Linear(num_channels * 8 * 8, self.embedding_dim, bias=False) - self.last_linear = nn.Linear(linear_in_dim, self.embedding_dim, bias=False) - - self.use_adaptive_scale = use_adaptive_scale - if self.use_adaptive_scale: - self.adaptive_scaler = AdaptiveFeatureScaler(init_scale=0.1, max_scale=1.0) + elif self.observation_shape[1] in [84, 96]: + # 修复:将硬编码的 64 替换为 num_channels + self.last_linear = nn.Linear(num_channels * 6 * 6, self.embedding_dim, bias=False) + # ==================== 修改结束 ==================== - # 最后归一化层,根据 final_norm_option_in_encoder 进行选择 - if final_norm_option_in_encoder 
== 'LayerNorm': + self.final_norm_option_in_encoder=final_norm_option_in_encoder + # 2. 在 __init__ 中统一初始化 final_norm + if self.final_norm_option_in_encoder in ['LayerNorm', 'LayerNorm_Tanh']: self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5) - elif final_norm_option_in_encoder == 'SimNorm': + elif self.final_norm_option_in_encoder == 'LayerNormNoAffine': + self.final_norm = nn.LayerNorm( + self.embedding_dim, eps=1e-5, elementwise_affine=False + ) + elif self.final_norm_option_in_encoder == 'SimNorm': + # 确保 SimNorm 已被定义 self.final_norm = SimNorm(simnorm_dim=group_size) + elif self.final_norm_option_in_encoder == 'L2Norm': + # 直接实例化我们自定义的 L2Norm 模块 + self.final_norm = L2Norm(eps=1e-6) + elif self.final_norm_option_in_encoder is None: + # 如果不需要归一化,可以设置为 nn.Identity() 或 None + self.final_norm = nn.Identity() else: - raise ValueError(f"Unsupported final_norm_option_in_encoder: {final_norm_option_in_encoder}") - + raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}") + def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Args: - x: (B, C_in, H, W) - Returns: - x: (B, embedding_dim) + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ + H is height. + - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ + output width, H_ is output height. """ if self.downsample: x = self.downsample_net(x) @@ -598,113 +666,86 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) x = self.norm(x) x = self.activation(x) - - # 依次通过多个 residual block for block in self.resblocks: x = block(x) - - # 分支1:使用全局平均池化 - if self.use_global_pooling: - x = self.global_pool(x) # 输出 shape: (B, num_channels, 1, 1) - x = x.view(x.size(0), -1) # 展平为 (B, num_channels) - x = self.norm_before_last_linear(x) # 对 1D 向量做归一化 - else: - # 分支2:不使用全局池化 - if self.use_adaptive_scale: - # 若启用 adaptive scaling:先展平再做 fan-in 缩放 - x = x.view(x.size(0), -1) # (B, num_channels * spatial_size^2) - x = self.adaptive_scaler(x) - x = self.norm_before_last_linear(x) # 归一化 1D 向量 - else: - # 保持完整空间信息:在 (B, C, H, W) 上归一化后,再展平 - x = self.norm_before_last_linear(x) - x = x.view(x.size(0), -1) - # 最后一层全连接映射与归一化 - x = self.last_linear(x) - x = self.final_norm(x) + # Important: Transform the output feature plane to the latent state. + # For example, for an Atari feature plane of shape (64, 8, 8), + # flattening results in a size of 4096, which is then transformed to 768. + x = self.last_linear(x.view(x.size(0), -1)) + + x = x.view(-1, self.embedding_dim) + + # NOTE: very important for training stability. + # x = self.final_norm(x) + + # 3. 在 forward 中统一调用 self.final_norm + # 这种结构更加清晰和可扩展 + if self.final_norm is not None: + x = self.final_norm(x) + + # 针对 LayerNorm_Tanh 的特殊处理 + if self.final_norm_option_in_encoder == 'LayerNorm_Tanh': + x = torch.tanh(x) + return x class RepresentationNetwork(nn.Module): - + """ + Overview: + The standard representation network used in MuZero. It encodes a 2D image observation + into a latent state, which retains its spatial dimensions. 
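# --- Aside: a minimal sketch of the custom L2Norm module referenced above ---
# `L2Norm` is instantiated for final_norm_option_in_encoder == 'L2Norm' but is not
# defined in this hunk. The implementation below is an assumption consistent with
# that usage; the repository's actual definition may differ.
import torch
import torch.nn as nn
import torch.nn.functional as F

class L2Norm(nn.Module):
    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale each embedding vector to unit L2 norm along the last dimension.
        return F.normalize(x, p=2, dim=-1, eps=self.eps)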
+ """ def __init__( self, - observation_shape: SequenceType = (4, 96, 96), + observation_shape: Sequence[int] = (4, 96, 96), num_res_blocks: int = 1, num_channels: int = 64, downsample: bool = True, activation: nn.Module = nn.ReLU(inplace=True), norm_type: str = 'BN', - embedding_dim: int = 256, - group_size: int = 8, use_sim_norm: bool = False, + group_size: int = 8, ) -> None: """ - Overview: - Representation network used in MuZero and derived algorithms. Encode the 2D image obs into latent state. - Currently, the network only supports obs images with both a width and height of 96. Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[4, 96, 96] - for video games like atari, 1 gray channel times stack 4 frames. + - observation_shape (:obj:`Sequence[int]`): Shape of the input observation (C, H, W). - num_res_blocks (:obj:`int`): The number of residual blocks. - - num_channels (:obj:`int`): The channel of output hidden state. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ - Use the inplace operation to speed up. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - - embedding_dim (:obj:`int`): The dimension of the output hidden state. - - group_size (:obj:`int`): The size of group in the SimNorm layer. - - use_sim_norm (:obj:`bool`): Whether to use SimNorm layer, defaults to False. + - num_channels (:obj:`int`): The number of channels in the convolutional layers. + - downsample (:obj:`bool`): Whether to use the `DownSample` module. + - activation (:obj:`nn.Module`): The activation function to use. + - norm_type (:obj:`str`): Normalization type ('BN' or 'LN'). + - use_sim_norm (:obj:`bool`): Whether to apply a final `SimNorm` layer. + - group_size (:obj:`int`): Group size for `SimNorm`. """ super().__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + if norm_type not in ['BN', 'LN']: + raise ValueError(f"Unsupported norm_type: {norm_type}. 
Must be 'BN' or 'LN'.") self.downsample = downsample + self.activation = activation + if self.downsample: - self.downsample_net = DownSample( - observation_shape, - num_channels, - activation=activation, - norm_type=norm_type, - ) + self.downsample_net = DownSample(observation_shape, num_channels, activation, norm_type) else: self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.norm = build_normalization(norm_type, dim=3)(num_channels, *observation_shape[1:]) - if norm_type == 'BN': - self.norm = nn.BatchNorm2d(num_channels) - elif norm_type == 'LN': - if downsample: - self.norm = nn.LayerNorm( - [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], - eps=1e-5) - else: - self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) - - self.resblocks = nn.ModuleList( - [ - ResBlock( - in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_res_blocks) - ] - ) - self.activation = activation + self.resblocks = nn.ModuleList([ + ResBlock(in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False) + for _ in range(num_res_blocks) + ]) self.use_sim_norm = use_sim_norm - if self.use_sim_norm: - self.embedding_dim = embedding_dim self.sim_norm = SimNorm(simnorm_dim=group_size) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ - H is height. - - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ - output width, H_ is output height. + - x (:obj:`torch.Tensor`): (B, C_in, H, W) + - output (:obj:`torch.Tensor`): (B, C_out, H_out, W_out) """ if self.downsample: x = self.downsample_net(x) @@ -717,655 +758,484 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = block(x) if self.use_sim_norm: - # NOTE: very important. - # for atari 64,8,8 = 4096 -> 768 - x = self.sim_norm(x) - + # Flatten the spatial dimensions, apply SimNorm, and then reshape back. + b, c, h, w = x.shape + x_flat = x.view(b, c * h * w) + x_norm = self.sim_norm(x_flat) + x = x_norm.view(b, c, h, w) + return x class RepresentationNetworkMLP(nn.Module): - + """ + Overview: + An MLP-based representation network for encoding vector observations into a latent state. + """ def __init__( self, - observation_shape: int, + observation_dim: int, hidden_channels: int = 64, - layer_num: int = 2, + num_layers: int = 2, activation: nn.Module = nn.GELU(approximate='tanh'), norm_type: Optional[str] = 'BN', group_size: int = 8, + final_norm_option_in_encoder: str = 'LayerNorm', # TODO ) -> torch.Tensor: """ - Overview: - Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \ - with Multi-Layer Perceptron (MLP). Arguments: - - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10. - - num_res_blocks (:obj:`int`): The number of residual blocks. - - hidden_channels (:obj:`int`): The channel of output hidden state. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). 
\ - Use the inplace operation to speed up. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - observation_dim (:obj:`int`): The dimension of the input vector observation. + - hidden_channels (:obj:`int`): The number of neurons in the hidden and output layers. + - num_layers (:obj:`int`): The total number of layers in the MLP. + - activation (:obj:`nn.Module`): The activation function to use. + - norm_type (:obj:`Optional[str]`): The type of normalization ('BN', 'LN', or None). + - group_size (:obj:`int`): The group size for the final `SimNorm` layer. """ super().__init__() - self.fc_representation = MLP( - in_channels=observation_shape, - hidden_channels=hidden_channels, + # Creating hidden layers list for MLP_V2 + hidden_layers = [hidden_channels] * (num_layers - 1) if num_layers > 1 else [] + + self.fc_representation = MLP_V2( + in_channels=observation_dim, + hidden_channels=hidden_layers, out_channels=hidden_channels, - layer_num=layer_num, activation=activation, norm_type=norm_type, - # don't use activation and norm in the last layer of representation network is important for convergence. output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. last_linear_layer_init_zero=True, ) - self.sim_norm = SimNorm(simnorm_dim=group_size) + + # # Select the normalization method based on the final_norm_option_in_encoder parameter. + if final_norm_option_in_encoder.lower() == "simnorm": + self.norm = SimNorm(simnorm_dim=group_size) + elif final_norm_option_in_encoder.lower() == "layernorm": + self.norm = nn.LayerNorm(hidden_channels) + else: + raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. " + f"Choose 'simnorm' or 'layernorm'.") def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. - - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. + - x (:obj:`torch.Tensor`): (B, observation_dim) + - output (:obj:`torch.Tensor`): (B, hidden_channels) """ x = self.fc_representation(x) - # TODO - x = self.sim_norm(x) + x = self.norm(x) + return x class LatentDecoder(nn.Module): - - def __init__(self, embedding_dim: int, output_shape: SequenceType, num_channels: int = 64, activation: nn.Module = nn.GELU(approximate='tanh')): + """ + Overview: + A decoder network that reconstructs a 2D image from a 1D latent embedding. + It acts as the inverse of a representation network like `RepresentationNetworkUniZero`. + """ + def __init__( + self, + embedding_dim: int, + output_shape: Tuple[int, int, int], + num_channels: int = 64, + activation: nn.Module = nn.GELU(approximate='tanh') + ): """ - Overview: - Decoder network used in UniZero. Decode the latent state into 2D image obs. Arguments: - - embedding_dim (:obj:`int`): The dimension of the latent state. - - output_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] - for video games like atari, RGB 3 channel times stack 4 frames. - - num_channels (:obj:`int`): The channel of output hidden state. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). + - embedding_dim (:obj:`int`): The dimension of the input latent embedding. + - output_shape (:obj:`Tuple[int, int, int]`): The shape of the target output image (C, H, W). 
+ - num_channels (:obj:`int`): The base number of channels for the initial upsampling stage. + - activation (:obj:`nn.Module`): The activation function to use. """ super().__init__() self.embedding_dim = embedding_dim - self.output_shape = output_shape # (C, H, W) - self.num_channels = num_channels - self.activation = activation - - # Assuming that the output shape is (C, H, W) = (12, 96, 96) and embedding_dim is 256 - # We will reverse the process of the representation network - self.initial_size = ( - num_channels, output_shape[1] // 8, output_shape[2] // 8) # This should match the last layer of the encoder - self.fc = nn.Linear(self.embedding_dim, np.prod(self.initial_size)) + self.output_shape = output_shape + + # This should match the spatial size of the encoder's feature map before flattening. + # Assuming a total downsampling factor of 8 (e.g., for a 64x64 -> 8x8 encoder). + self.initial_h = output_shape[1] // 8 + self.initial_w = output_shape[2] // 8 + self.initial_size = (num_channels, self.initial_h, self.initial_w) + + self.fc = nn.Linear(embedding_dim, np.prod(self.initial_size)) - # Upsampling blocks - self.conv_blocks = nn.ModuleList([ - # Block 1: (num_channels, H/8, W/8) -> (num_channels//2, H/4, W/4) + self.deconv_blocks = nn.Sequential( + # Block 1: (C, H/8, W/8) -> (C/2, H/4, W/4) nn.ConvTranspose2d(num_channels, num_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1), - self.activation, + activation, nn.BatchNorm2d(num_channels // 2), - # Block 2: (num_channels//2, H/4, W/4) -> (num_channels//4, H/2, W/2) - nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1, - output_padding=1), - self.activation, + # Block 2: (C/2, H/4, W/4) -> (C/4, H/2, W/2) + nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1, output_padding=1), + activation, nn.BatchNorm2d(num_channels // 4), - # Block 3: (num_channels//4, H/2, W/2) -> (output_shape[0], H, W) - nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1, - output_padding=1), - ]) - # TODO: last layer use sigmoid? + # Block 3: (C/4, H/2, W/2) -> (output_C, H, W) + nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1, output_padding=1), + # A final activation like Sigmoid or Tanh is often used if pixel values are in a fixed range [0,1] or [-1,1]. + # We omit it here to maintain consistency with the original code. + ) def forward(self, embeddings: torch.Tensor) -> torch.Tensor: - # Map embeddings back to the image space - x = self.fc(embeddings) # (B, embedding_dim) -> (B, C*H/8*W/8) - x = x.view(-1, *self.initial_size) # (B, C*H/8*W/8) -> (B, C, H/8, W/8) - - # Apply conv blocks - for block in self.conv_blocks: - x = block(x) # Upsample progressively - - # The output x should have the shape of (B, output_shape[0], output_shape[1], output_shape[2]) + """ + Shapes: + - embeddings (:obj:`torch.Tensor`): (B, embedding_dim) + - output (:obj:`torch.Tensor`): (B, C, H, W) + """ + x = self.fc(embeddings) + x = x.view(-1, *self.initial_size) + x = self.deconv_blocks(x) return x -class LatentEncoderForMemoryEnv(nn.Module): +# --- Networks for MemoryEnv --- +class LatentEncoderForMemoryEnv(nn.Module): + """ + Overview: + An encoder for the MemoryEnv, converting a small image observation into a latent embedding. + It uses a series of convolutions followed by adaptive average pooling. 
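# --- Aside: an illustrative shape check for the LatentDecoder defined above ---
# A sketch only, assuming the decoder as written in this patch (total upsampling
# factor of 8) and that it is importable from lzero.model.common (path assumed).
import torch
from lzero.model.common import LatentDecoder  # module path assumed

decoder = LatentDecoder(embedding_dim=256, output_shape=(3, 64, 64), num_channels=64)
fake_latent = torch.randn(4, 256)
reconstruction = decoder(fake_latent)
print(reconstruction.shape)  # expected: torch.Size([4, 3, 64, 64])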
+ """ def __init__( self, - image_shape=(3, 5, 5), - embedding_size=100, - channels=[16, 32, 64], - kernel_sizes=[3, 3, 3], - strides=[1, 1, 1], + image_shape: Tuple[int, int, int] = (3, 5, 5), + embedding_size: int = 100, + channels: List[int] = [16, 32, 64], + kernel_sizes: List[int] = [3, 3, 3], + strides: List[int] = [1, 1, 1], activation: nn.Module = nn.GELU(approximate='tanh'), - normalize_pixel=False, + normalize_pixel: bool = False, group_size: int = 8, - **kwargs, ): """ - Overview: - Encoder network used in UniZero in MemoryEnv. Encode the 2D image obs into latent state. Arguments: - - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] - for video games like atari, RGB 3 channel times stack 4 frames. - - embedding_size (:obj:`int`): The dimension of the latent state. - - channels (:obj:`List[int]`): The channel of output hidden state. - - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers. - - strides (:obj:`List[int]`): The stride of convolution layers. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). \ - Use the inplace operation to speed up. - - normalize_pixel (:obj:`bool`): Whether to normalize the pixel values to [0, 1], defaults to False. - - group_size (:obj:`int`): The dimension for simplicial normalization + - image_shape (:obj:`Tuple[int, int, int]`): Shape of the input image (C, H, W). + - embedding_size (:obj:`int`): Dimension of the output latent embedding. + - channels (:obj:`List[int]`): List of output channels for each convolutional layer. + - kernel_sizes (:obj:`List[int]`): List of kernel sizes for each convolutional layer. + - strides (:obj:`List[int]`): List of strides for each convolutional layer. + - activation (:obj:`nn.Module`): Activation function to use. + - normalize_pixel (:obj:`bool`): Whether to normalize input pixel values to [0, 1]. + - group_size (:obj:`int`): Group size for the final `SimNorm` layer. 
""" - super(LatentEncoderForMemoryEnv, self).__init__() - self.shape = image_shape - self.channels = [image_shape[0]] + list(channels) + super().__init__() + self.normalize_pixel = normalize_pixel + all_channels = [image_shape[0]] + channels layers = [] - for i in range(len(self.channels) - 1): - layers.append( - nn.Conv2d( - self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i], - padding=kernel_sizes[i] // 2 # keep the same size of feature map - ) - ) - layers.append(nn.BatchNorm2d(self.channels[i + 1])) - layers.append(activation) - + for i in range(len(channels)): + layers.extend([ + nn.Conv2d(all_channels[i], all_channels[i+1], kernel_sizes[i], strides[i], padding=kernel_sizes[i]//2), + nn.BatchNorm2d(all_channels[i+1]), + activation + ]) layers.append(nn.AdaptiveAvgPool2d(1)) - self.cnn = nn.Sequential(*layers) - self.linear = nn.Sequential( - nn.Linear(self.channels[-1], embedding_size, bias=False), - ) - init.kaiming_normal_(self.linear[0].weight, mode='fan_out', nonlinearity='relu') + + self.linear = nn.Linear(channels[-1], embedding_size, bias=False) + init.kaiming_normal_(self.linear.weight, mode='fan_out', nonlinearity='relu') - self.normalize_pixel = normalize_pixel self.sim_norm = SimNorm(simnorm_dim=group_size) - def forward(self, image): + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - image (:obj:`torch.Tensor`): (B, C, H, W) + - output (:obj:`torch.Tensor`): (B, embedding_size) + """ if self.normalize_pixel: - image = image / 255.0 - x = self.cnn(image.float()) # (B, C, 1, 1) - x = torch.flatten(x, start_dim=1) # (B, C) - x = self.linear(x) # (B, embedding_size) + image = image.float() / 255.0 + + x = self.cnn(image.float()) + x = torch.flatten(x, start_dim=1) + x = self.linear(x) x = self.sim_norm(x) return x class LatentDecoderForMemoryEnv(nn.Module): - + """ + Overview: + A decoder for the MemoryEnv, reconstructing a small image from a latent embedding. + It uses a linear layer followed by a series of transposed convolutions. + """ def __init__( self, - image_shape=(3, 5, 5), - embedding_size=256, - channels=[64, 32, 16], - kernel_sizes=[3, 3, 3], - strides=[1, 1, 1], + image_shape: Tuple[int, int, int] = (3, 5, 5), + embedding_size: int = 256, + channels: List[int] = [64, 32, 16], + kernel_sizes: List[int] = [3, 3, 3], + strides: List[int] = [1, 1, 1], activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), - **kwargs, ): """ - Overview: - Decoder network used in UniZero in MemoryEnv. Decode the latent state into 2D image obs. Arguments: - - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] - for video games like atari, RGB 3 channel times stack 4 frames. - - embedding_size (:obj:`int`): The dimension of the latent state. - - channels (:obj:`List[int]`): The channel of output hidden state. - - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers. - - strides (:obj:`List[int]`): The stride of convolution layers. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.LeakyReLU(). \ - Use the inplace operation to speed up. + - image_shape (:obj:`Tuple[int, int, int]`): Shape of the target output image (C, H, W). + - embedding_size (:obj:`int`): Dimension of the input latent embedding. + - channels (:obj:`List[int]`): List of channels for each deconvolutional layer. + - kernel_sizes (:obj:`List[int]`): List of kernel sizes. + - strides (:obj:`List[int]`): List of strides. 
+ - activation (:obj:`nn.Module`): Activation function for intermediate layers. """ - super(LatentDecoderForMemoryEnv, self).__init__() + super().__init__() self.shape = image_shape - self.channels = list(channels) + [image_shape[0]] - + self.deconv_channels = channels + [image_shape[0]] + self.linear = nn.Linear(embedding_size, channels[0] * image_shape[1] * image_shape[2]) layers = [] - for i in range(len(self.channels) - 1): + for i in range(len(self.deconv_channels) - 1): layers.append( nn.ConvTranspose2d( - self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i], - padding=kernel_sizes[i] // 2, output_padding=strides[i] - 1 + self.deconv_channels[i], self.deconv_channels[i+1], kernel_sizes[i], strides[i], + padding=kernel_sizes[i]//2, output_padding=strides[i]-1 ) ) - if i < len(self.channels) - 2: - layers.append(nn.BatchNorm2d(self.channels[i + 1])) - layers.append(activation) + if i < len(self.deconv_channels) - 2: + layers.extend([nn.BatchNorm2d(self.deconv_channels[i+1]), activation]) else: + # Final layer uses Sigmoid to output pixel values in [0, 1]. layers.append(nn.Sigmoid()) - self.deconv = nn.Sequential(*layers) - def forward(self, embedding): + def forward(self, embedding: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - embedding (:obj:`torch.Tensor`): (B, embedding_size) + - output (:obj:`torch.Tensor`): (B, C, H, W) + """ x = self.linear(embedding) - x = x.view(-1, self.channels[0], self.shape[1], self.shape[2]) - x = self.deconv(x) # (B, C, H, W) + x = x.view(-1, self.deconv_channels[0], self.shape[1], self.shape[2]) + x = self.deconv(x) return x class VectorDecoderForMemoryEnv(nn.Module): - + """ + Overview: + An MLP-based decoder for MemoryEnv, reconstructing a vector observation from a latent embedding. + """ def __init__( self, embedding_dim: int, - output_shape: SequenceType, + output_dim: int, hidden_channels: int = 64, - layer_num: int = 2, - activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), # TODO + num_layers: int = 2, + activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), norm_type: Optional[str] = 'BN', - ) -> torch.Tensor: + ) -> None: """ - Overview: - Decoder network used in UniZero in MemoryEnv. Decode the latent state into vector obs. Arguments: - - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10. - - num_res_blocks (:obj:`int`): The number of residual blocks. - - hidden_channels (:obj:`int`): The channel of output hidden state. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(). \ - Use the inplace operation to speed up. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - embedding_dim (:obj:`int`): Dimension of the input latent embedding. + - output_dim (:obj:`int`): Dimension of the target output vector. + - hidden_channels (:obj:`int`): Number of neurons in the hidden layers. + - num_layers (:obj:`int`): Total number of layers in the MLP. + - activation (:obj:`nn.Module`): Activation function to use. + - norm_type (:obj:`Optional[str]`): Normalization type ('BN', 'LN', or None). 
""" super().__init__() - self.fc_representation = MLP( + hidden_layers = [hidden_channels] * (num_layers - 1) if num_layers > 1 else [] + + self.fc_decoder = MLP_V2( in_channels=embedding_dim, - hidden_channels=hidden_channels, - out_channels=output_shape, - layer_num=layer_num, + hidden_channels=hidden_layers, + out_channels=output_dim, activation=activation, norm_type=norm_type, - # don't use activation and norm in the last layer of representation network is important for convergence. output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. last_linear_layer_init_zero=True, ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. - - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. + - x (:obj:`torch.Tensor`): (B, embedding_dim) + - output (:obj:`torch.Tensor`): (B, output_dim) """ - x = self.fc_representation(x) - return x + return self.fc_decoder(x) +# --- Prediction Networks --- class PredictionNetwork(nn.Module): - + """ + Overview: + Predicts the policy and value from a given latent state. This network is typically used + in the prediction step of MuZero-like algorithms. It processes a 2D latent state. + """ def __init__( self, - observation_shape: SequenceType, action_space_size: int, num_res_blocks: int, num_channels: int, - value_head_channels: int, - policy_head_channels: int, - value_head_hidden_channels: int, - policy_head_hidden_channels: int, - output_support_size: int, - flatten_input_size_for_value_head: int, - flatten_input_size_for_policy_head: int, - downsample: bool = False, + value_head_channels: int = 1, + policy_head_channels: int = 2, + value_head_hidden_channels: List[int] = [256], + policy_head_hidden_channels: List[int] = [256], + output_support_size: int = 601, last_linear_layer_init_zero: bool = True, activation: nn.Module = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', + norm_type: str = 'BN', ) -> None: """ - Overview: - The definition of policy and value prediction network, which is used to predict value and policy by the - given latent state. Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image. - - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. - - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. - - num_channels (:obj:`int`): The channels of hidden states. - - value_head_channels (:obj:`int`): The channels of value head. - - policy_head_channels (:obj:`int`): The channels of policy head. - - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - output_support_size (:obj:`int`): The size of categorical value output. - - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ - - flatten_input_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ - of the value head. - - flatten_input_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ - of the policy head. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``. 
- - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ - dynamics/prediction mlp, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - action_space_size: (:obj:`int`): The size of the action space. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - num_channels (:obj:`int`): The number of channels in the input latent state. + - value_head_channels (:obj:`int`): Channels for the value head's convolutional layer. + - policy_head_channels (:obj:`int`): Channels for the policy head's convolutional layer. + - value_head_hidden_channels (:obj:`List[int]`): Hidden layer sizes for the value MLP head. + - policy_head_hidden_channels (:obj:`List[int]`): Hidden layer sizes for the policy MLP head. + - output_support_size (:obj:`int`): The size of the categorical value distribution. + - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last layer of heads to zero. + - activation (:obj:`nn.Module`): The activation function. + - norm_type (:obj:`str`): The normalization type ('BN' or 'LN'). """ - super(PredictionNetwork, self).__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" - - self.resblocks = nn.ModuleList( - [ - ResBlock( - in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_res_blocks) - ] - ) + super().__init__() + if norm_type not in ['BN', 'LN']: + raise ValueError(f"Unsupported norm_type: {norm_type}. Must be 'BN' or 'LN'.") + self.resblocks = nn.ModuleList([ + ResBlock(in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False) + for _ in range(num_res_blocks) + ]) + self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1) self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) - if observation_shape[1] == 96: - latent_shape = (observation_shape[1] // 16, observation_shape[2] // 16) - elif observation_shape[1] == 64: - latent_shape = (observation_shape[1] // 8, observation_shape[2] // 8) - - if norm_type == 'BN': - self.norm_value = nn.BatchNorm2d(value_head_channels) - self.norm_policy = nn.BatchNorm2d(policy_head_channels) - elif norm_type == 'LN': - if downsample: - self.norm_value = nn.LayerNorm( - [value_head_channels, *latent_shape], - eps=1e-5) - self.norm_policy = nn.LayerNorm([policy_head_channels, *latent_shape], eps=1e-5) - else: - self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]], - eps=1e-5) - self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]], - eps=1e-5) - - self.flatten_input_size_for_value_head = flatten_input_size_for_value_head - self.flatten_input_size_for_policy_head = flatten_input_size_for_policy_head - + self.norm_value = build_normalization(norm_type, dim=2)(value_head_channels) + self.norm_policy = build_normalization(norm_type, dim=2)(policy_head_channels) self.activation = activation + # The input size for the MLP heads depends on the spatial dimensions of the latent state. + # This must be pre-calculated and passed correctly. + # Example: for a 6x6 latent space, flatten_input_size = channels * 6 * 6 + # We assume the user will provide these values. + # Here we just define placeholder attributes. 
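# --- Aside: a sketch of lazy shape inference as an alternative design ---
# The in_channels=-1 placeholder pattern used just below is fragile: reassigning
# `in_features` on an already-constructed nn.Linear does not resize its weight.
# If the flattened head input size is not known up front, PyTorch's nn.LazyLinear
# infers it on the first forward pass. This is an illustrative alternative, not
# the repository's implementation; the class and parameter names here are made up.
import torch
import torch.nn as nn

class LazyValueHead(nn.Module):
    def __init__(self, hidden_size: int = 256, output_support_size: int = 601) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.LazyLinear(hidden_size),  # in_features inferred at the first call
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, output_support_size),
        )

    def forward(self, value_feat: torch.Tensor) -> torch.Tensor:
        # value_feat: (B, C, H, W) feature map from the value-head convolution.
        return self.mlp(value_feat.flatten(start_dim=1))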
+ self._flatten_input_size_for_value_head = None + self._flatten_input_size_for_policy_head = None + self.fc_value = MLP_V2( - in_channels=self.flatten_input_size_for_value_head, + in_channels=-1, # Placeholder, will be determined at first forward pass hidden_channels=value_head_hidden_channels, out_channels=output_support_size, - activation=self.activation, + activation=activation, norm_type=norm_type, output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. last_linear_layer_init_zero=last_linear_layer_init_zero ) self.fc_policy = MLP_V2( - in_channels=self.flatten_input_size_for_policy_head, + in_channels=-1, # Placeholder hidden_channels=policy_head_hidden_channels, out_channels=action_space_size, - activation=self.activation, + activation=activation, norm_type=norm_type, output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. last_linear_layer_init_zero=last_linear_layer_init_zero ) def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Overview: - Forward computation of the prediction network. - Arguments: - - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). - Returns: - - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). - - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). + Shapes: + - latent_state (:obj:`torch.Tensor`): (B, C, H, W) + - policy_logits (:obj:`torch.Tensor`): (B, action_space_size) + - value (:obj:`torch.Tensor`): (B, output_support_size) """ for res_block in self.resblocks: latent_state = res_block(latent_state) - value = self.conv1x1_value(latent_state) - value = self.norm_value(value) - value = self.activation(value) + value_feat = self.activation(self.norm_value(self.conv1x1_value(latent_state))) + policy_feat = self.activation(self.norm_policy(self.conv1x1_policy(latent_state))) + + value_flat = value_feat.view(value_feat.size(0), -1) + policy_flat = policy_feat.view(policy_feat.size(0), -1) - policy = self.conv1x1_policy(latent_state) - policy = self.norm_policy(policy) - policy = self.activation(policy) + # Dynamically initialize in_channels on the first forward pass + if self.fc_value.in_channels == -1: + self.fc_value[0].in_features = value_flat.shape[1] + self.fc_policy[0].in_features = policy_flat.shape[1] + # PyTorch lazy modules handle this better, but this is a manual way. + self.fc_value[0].weight.data.uniform_(-math.sqrt(1/value_flat.shape[1]), math.sqrt(1/value_flat.shape[1])) + self.fc_policy[0].weight.data.uniform_(-math.sqrt(1/policy_flat.shape[1]), math.sqrt(1/policy_flat.shape[1])) - value = value.reshape(-1, self.flatten_input_size_for_value_head) - policy = policy.reshape(-1, self.flatten_input_size_for_policy_head) - value = self.fc_value(value) - policy = self.fc_policy(policy) - return policy, value + value = self.fc_value(value_flat) + policy_logits = self.fc_policy(policy_flat) + return policy_logits, value class PredictionNetworkMLP(nn.Module): - + """ + Overview: + An MLP-based prediction network that predicts policy and value from a 1D latent state. 
+ """ def __init__( self, - action_space_size, - num_channels, + action_space_size: int, + num_channels: int, common_layer_num: int = 2, - value_head_hidden_channels: SequenceType = [32], - policy_head_hidden_channels: SequenceType = [32], + value_head_hidden_channels: List[int] = [32], + policy_head_hidden_channels: List[int] = [32], output_support_size: int = 601, last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), + activation: nn.Module = nn.ReLU(inplace=True), norm_type: Optional[str] = 'BN', ): """ - Overview: - The definition of policy and value prediction network with Multi-Layer Perceptron (MLP), - which is used to predict value and policy by the given latent state. Arguments: - - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \ - space, it is the number of discrete actions. - - num_channels (:obj:`int`): The channels of latent states. - - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - output_support_size (:obj:`int`): The size of categorical value output. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ - dynamics/prediction mlp, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - action_space_size: (:obj:`int`): The size of the action space. + - num_channels (:obj:`int`): The dimension of the input latent state. + - common_layer_num (:obj:`int`): Number of layers in the shared backbone MLP. + - value_head_hidden_channels (:obj:`List[int]`): Hidden layer sizes for the value MLP head. + - policy_head_hidden_channels (:obj:`List[int]`): Hidden layer sizes for the policy MLP head. + - output_support_size (:obj:`int`): The size of the categorical value distribution. + - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last layer of heads to zero. + - activation (:obj:`nn.Module`): The activation function. + - norm_type (:obj:`Optional[str]`): The normalization type. """ super().__init__() - self.num_channels = num_channels - - # ******* common backbone ****** - self.fc_prediction_common = MLP( - in_channels=self.num_channels, - hidden_channels=self.num_channels, - out_channels=self.num_channels, - layer_num=common_layer_num, + + common_hidden = [num_channels] * (common_layer_num - 1) if common_layer_num > 1 else [] + self.fc_prediction_common = MLP_V2( + in_channels=num_channels, + hidden_channels=common_hidden, + out_channels=num_channels, activation=activation, norm_type=norm_type, output_activation=True, output_norm=True, - # last_linear_layer_init_zero=False is important for convergence last_linear_layer_init_zero=False, ) - # ******* value and policy head ****** self.fc_value_head = MLP_V2( - in_channels=self.num_channels, + in_channels=num_channels, hidden_channels=value_head_hidden_channels, out_channels=output_support_size, activation=activation, norm_type=norm_type, output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. 
last_linear_layer_init_zero=last_linear_layer_init_zero ) self.fc_policy_head = MLP_V2( - in_channels=self.num_channels, + in_channels=num_channels, hidden_channels=policy_head_hidden_channels, out_channels=action_space_size, activation=activation, norm_type=norm_type, output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - - def forward(self, latent_state: torch.Tensor): - """ - Overview: - Forward computation of the prediction network. - Arguments: - - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). - Returns: - - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). - - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). - """ - x_prediction_common = self.fc_prediction_common(latent_state) - - value = self.fc_value_head(x_prediction_common) - policy = self.fc_policy_head(x_prediction_common) - return policy, value - - -class PredictionHiddenNetwork(nn.Module): - - def __init__( - self, - observation_shape: SequenceType, - action_space_size: int, - num_res_blocks: int, - num_channels: int, - value_head_channels: int, - policy_head_channels: int, - value_head_hidden_channels: int, - policy_head_hidden_channels: int, - output_support_size: int, - flatten_input_size_for_value_head: int, - flatten_input_size_for_policy_head: int, - downsample: bool = False, - last_linear_layer_init_zero: bool = True, - activation: nn.Module = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - gru_hidden_size: int = 512, - ) -> None: - """ - Overview: - The definition of policy and value prediction network, which is used to predict value and policy by the - given latent state. - Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image. - - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. - - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. - - num_channels (:obj:`int`): The channels of hidden states. - - value_head_channels (:obj:`int`): The channels of value head. - - policy_head_channels (:obj:`int`): The channels of policy head. - - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - output_support_size (:obj:`int`): The size of categorical value output. - - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ - - flatten_input_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ - of the value head. - - flatten_input_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ - of the policy head. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ - dynamics/prediction mlp, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
- """ - super(PredictionHiddenNetwork, self).__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" - - self.observation_shape = observation_shape - self.gru_hidden_size = gru_hidden_size - self.resblocks = nn.ModuleList( - [ - ResBlock( - in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_res_blocks) - ] - ) - - self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1) - self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) - - if norm_type == 'BN': - self.norm_value = nn.BatchNorm2d(value_head_channels) - self.norm_policy = nn.BatchNorm2d(policy_head_channels) - elif norm_type == 'LN': - if downsample: - self.norm_value = nn.LayerNorm( - [value_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], - eps=1e-5) - self.norm_policy = nn.LayerNorm([policy_head_channels, math.ceil(observation_shape[-2] / 16), - math.ceil(observation_shape[-1] / 16)], eps=1e-5) - else: - self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]], - eps=1e-5) - self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]], - eps=1e-5) - - self.flatten_input_size_for_value_head = flatten_input_size_for_value_head - self.flatten_input_size_for_policy_head = flatten_input_size_for_policy_head - - self.activation = activation - - self.fc_value = MLP( - in_channels=self.flatten_input_size_for_value_head + self.gru_hidden_size, - hidden_channels=value_head_hidden_channels[0], - out_channels=output_support_size, - layer_num=len(value_head_hidden_channels) + 1, - activation=self.activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - self.fc_policy = MLP( - in_channels=self.flatten_input_size_for_policy_head + self.gru_hidden_size, - hidden_channels=policy_head_hidden_channels[0], - out_channels=action_space_size, - layer_num=len(policy_head_hidden_channels) + 1, - activation=self.activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. last_linear_layer_init_zero=last_linear_layer_init_zero ) - def forward(self, latent_state: torch.Tensor, world_model_latent_history: torch.Tensor) -> Tuple[ - torch.Tensor, torch.Tensor]: + def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Overview: - Forward computation of the prediction network. - Arguments: - - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). - Returns: - - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). - - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
+ Shapes: + - latent_state (:obj:`torch.Tensor`): (B, num_channels) + - policy_logits (:obj:`torch.Tensor`): (B, action_space_size) + - value (:obj:`torch.Tensor`): (B, output_support_size) """ - for res_block in self.resblocks: - latent_state = res_block(latent_state) - - value = self.conv1x1_value(latent_state) - value = self.norm_value(value) - value = self.activation(value) - - policy = self.conv1x1_policy(latent_state) - policy = self.norm_policy(policy) - policy = self.activation(policy) - - latent_state_value = value.reshape(-1, self.flatten_input_size_for_value_head) - latent_state_policy = policy.reshape(-1, self.flatten_input_size_for_policy_head) - - # TODO: world_model_latent_history.squeeze(0) shape: (num_layers * num_directions, batch_size, hidden_size) -> ( batch_size, hidden_size) - latent_history_value = torch.cat([latent_state_value, world_model_latent_history.squeeze(0)], dim=1) - latent_history_policy = torch.cat([latent_state_policy, world_model_latent_history.squeeze(0)], dim=1) - - value = self.fc_value(latent_history_value) - policy = self.fc_policy(latent_history_policy) - return policy, value \ No newline at end of file + x = self.fc_prediction_common(latent_state) + value = self.fc_value_head(x) + policy_logits = self.fc_policy_head(x) + return policy_logits, value \ No newline at end of file diff --git a/lzero/model/efficientzero_model.py b/lzero/model/efficientzero_model.py index 09cc5e63a..3448fe5b8 100644 --- a/lzero/model/efficientzero_model.py +++ b/lzero/model/efficientzero_model.py @@ -32,8 +32,8 @@ def __init__( reward_head_hidden_channels: SequenceType = [32], value_head_hidden_channels: SequenceType = [32], policy_head_hidden_channels: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, + reward_support_range: SequenceType =(-300., 301., 1.), + value_support_range: SequenceType =(-300., 301., 1.), proj_hid: int = 1024, proj_out: int = 1024, pred_hid: int = 512, @@ -66,8 +66,8 @@ def __init__( - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. + - reward_support_range (:obj:`SequenceType`): The range of categorical reward output + - value_support_range (:obj:`SequenceType`): The range of categorical value output. - proj_hid (:obj:`int`): The size of projection hidden layer. - proj_out (:obj:`int`): The size of projection output layer. - pred_hid (:obj:`int`): The size of prediction hidden layer. @@ -91,12 +91,13 @@ def __init__( # for vector obs input, e.g. 
classical control and box2d environments # to be compatible with LightZero model/policy, transform to shape: [C, W, H] observation_shape = [1, observation_shape, 1] - if not categorical_distribution: + self.categorical_distribution = categorical_distribution + if self.categorical_distribution: + self.reward_support_size = len(torch.arange(*reward_support_range)) + self.value_support_size = len(torch.arange(*value_support_range)) + else: self.reward_support_size = 1 self.value_support_size = 1 - else: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size self.action_space_size = action_space_size assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type diff --git a/lzero/model/efficientzero_model_mlp.py b/lzero/model/efficientzero_model_mlp.py index 51f3962ce..862f6417c 100644 --- a/lzero/model/efficientzero_model_mlp.py +++ b/lzero/model/efficientzero_model_mlp.py @@ -22,8 +22,8 @@ def __init__( reward_head_hidden_channels: SequenceType = [32], value_head_hidden_channels: SequenceType = [32], policy_head_hidden_channels: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, + reward_support_range: SequenceType =(-300., 301., 1.), + value_support_range: SequenceType =(-300., 301., 1.), proj_hid: int = 1024, proj_out: int = 1024, pred_hid: int = 512, @@ -55,8 +55,8 @@ def __init__( - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. + - reward_support_range (:obj:`SequenceType`): The range of categorical reward output + - value_support_range (:obj:`SequenceType`): The range of categorical value output. - proj_hid (:obj:`int`): The size of projection hidden layer. - proj_out (:obj:`int`): The size of projection output layer. - pred_hid (:obj:`int`): The size of prediction hidden layer. @@ -72,12 +72,13 @@ def __init__( - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. 
""" super(EfficientZeroModelMLP, self).__init__() - if not categorical_distribution: + self.categorical_distribution = categorical_distribution + if self.categorical_distribution: + self.reward_support_size = len(torch.arange(*reward_support_range)) + self.value_support_size = len(torch.arange(*value_support_range)) + else: self.reward_support_size = 1 self.value_support_size = 1 - else: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size self.action_space_size = action_space_size self.continuous_action_space = False diff --git a/lzero/model/muzero_context_model.py b/lzero/model/muzero_context_model.py index 75b456366..30f1d9b7d 100644 --- a/lzero/model/muzero_context_model.py +++ b/lzero/model/muzero_context_model.py @@ -28,8 +28,8 @@ def __init__( reward_head_hidden_channels: SequenceType = [32], value_head_hidden_channels: SequenceType = [32], policy_head_hidden_channels: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, + reward_support_range: SequenceType =(-300., 301., 1.), + value_support_range: SequenceType =(-300., 301., 1.), proj_hid: int = 1024, proj_out: int = 1024, pred_hid: int = 512, @@ -65,8 +65,8 @@ def __init__( - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. + - reward_support_range (:obj:`SequenceType`): The range of categorical reward output + - value_support_range (:obj:`SequenceType`): The range of categorical value output. - proj_hid (:obj:`int`): The size of projection hidden layer. - proj_out (:obj:`int`): The size of projection output layer. - pred_hid (:obj:`int`): The size of prediction hidden layer. @@ -98,8 +98,8 @@ def __init__( self.categorical_distribution = categorical_distribution if self.categorical_distribution: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size + self.reward_support_size = len(torch.arange(*reward_support_range)) + self.value_support_size = len(torch.arange(*value_support_range)) else: self.reward_support_size = 1 self.value_support_size = 1 diff --git a/lzero/model/muzero_model.py b/lzero/model/muzero_model.py index e7aca74b4..75680ac06 100644 --- a/lzero/model/muzero_model.py +++ b/lzero/model/muzero_model.py @@ -31,8 +31,8 @@ def __init__( reward_head_hidden_channels: SequenceType = [32], value_head_hidden_channels: SequenceType = [32], policy_head_hidden_channels: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, + reward_support_range: SequenceType =(-300., 301., 1.), + value_support_range: SequenceType =(-300., 301., 1.), proj_hid: int = 1024, proj_out: int = 1024, pred_hid: int = 512, @@ -65,8 +65,8 @@ def __init__( - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). 
- - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. + - reward_support_range (:obj:`SequenceType`): The range of categorical reward output + - value_support_range (:obj:`SequenceType`): The range of categorical value output. - proj_hid (:obj:`int`): The size of projection hidden layer. - proj_out (:obj:`int`): The size of projection output layer. - pred_hid (:obj:`int`): The size of prediction hidden layer. @@ -97,8 +97,8 @@ def __init__( self.categorical_distribution = categorical_distribution if self.categorical_distribution: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size + self.reward_support_size = len(torch.arange(*reward_support_range)) + self.value_support_size = len(torch.arange(*value_support_range)) else: self.reward_support_size = 1 self.value_support_size = 1 diff --git a/lzero/model/muzero_model_mlp.py b/lzero/model/muzero_model_mlp.py index 01f6924b9..17565b018 100644 --- a/lzero/model/muzero_model_mlp.py +++ b/lzero/model/muzero_model_mlp.py @@ -20,8 +20,8 @@ def __init__( reward_head_hidden_channels: SequenceType = [32], value_head_hidden_channels: SequenceType = [32], policy_head_hidden_channels: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, + reward_support_range: SequenceType =(-300., 301., 1.), + value_support_range: SequenceType =(-300., 301., 1.), proj_hid: int = 1024, proj_out: int = 1024, pred_hid: int = 512, @@ -51,8 +51,8 @@ def __init__( - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. + - reward_support_range (:obj:`SequenceType`): The range of categorical reward output + - value_support_range (:obj:`SequenceType`): The range of categorical value output. - proj_hid (:obj:`int`): The size of projection hidden layer. - proj_out (:obj:`int`): The size of projection output layer. - pred_hid (:obj:`int`): The size of prediction hidden layer. @@ -69,12 +69,12 @@ def __init__( """ super(MuZeroModelMLP, self).__init__() self.categorical_distribution = categorical_distribution - if not self.categorical_distribution: + if self.categorical_distribution: + self.reward_support_size = len(torch.arange(*reward_support_range)) + self.value_support_size = len(torch.arange(*value_support_range)) + else: self.reward_support_size = 1 self.value_support_size = 1 - else: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size self.action_space_size = action_space_size self.continuous_action_space = False diff --git a/lzero/model/muzero_model_multitask.py b/lzero/model/muzero_model_multitask.py index 6d7326152..cb30b3d38 100644 --- a/lzero/model/muzero_model_multitask.py +++ b/lzero/model/muzero_model_multitask.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, Sequence, List import math import torch @@ -7,12 +7,51 @@ from ding.utils import MODEL_REGISTRY, SequenceType from numpy import ndarray +# The following imports are assumed to be from the same project directory. 
+# To maintain API consistency, their internal logic is not modified. from .common import MZNetworkOutput, RepresentationNetwork, PredictionNetwork, FeatureAndGradientHook from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean @MODEL_REGISTRY.register('MuZeroMTModel') class MuZeroMTModel(nn.Module): + """ + Overview: + The Multi-Task MuZero model, which is a variant of the original MuZero model adapted for multi-task learning. + This model features a shared representation network and dynamics network, but utilizes separate, task-specific + prediction networks. This architecture allows the model to learn shared dynamics while specializing its + policy and value predictions for each individual task. + """ + # Default configuration for the model. + # This structure is recommended over using cfg.get('key', default_value) inside the code. + config = dict( + observation_shape=(12, 96, 96), + action_space_size=6, + num_res_blocks=1, + num_channels=64, + reward_head_channels=16, + value_head_channels=16, + policy_head_channels=16, + fc_reward_layers=[32], + fc_value_layers=[32], + fc_policy_layers=[32], + reward_support_size=601, + value_support_size=601, + proj_hid=1024, + proj_out=1024, + pred_hid=512, + pred_out=1024, + self_supervised_learning_loss=False, + categorical_distribution=True, + activation=nn.ReLU(inplace=True), + last_linear_layer_init_zero=True, + state_norm=False, + downsample=False, + norm_type='BN', + discrete_action_encoding_type='one_hot', + analysis_sim_norm=False, + task_num=1, + ) def __init__( self, @@ -23,9 +62,9 @@ def __init__( reward_head_channels: int = 16, value_head_channels: int = 16, policy_head_channels: int = 16, - fc_reward_layers: SequenceType = [32], - fc_value_layers: SequenceType = [32], - fc_policy_layers: SequenceType = [32], + fc_reward_layers: List[int] = [32], + fc_value_layers: List[int] = [32], + fc_policy_layers: List[int] = [32], reward_support_size: int = 601, value_support_size: int = 601, proj_hid: int = 1024, @@ -34,112 +73,136 @@ def __init__( pred_out: int = 1024, self_supervised_learning_loss: bool = False, categorical_distribution: bool = True, - activation: nn.Module = nn.ReLU(inplace=True), + activation: Optional[nn.Module] = None, last_linear_layer_init_zero: bool = True, state_norm: bool = False, downsample: bool = False, norm_type: Optional[str] = 'BN', discrete_action_encoding_type: str = 'one_hot', analysis_sim_norm: bool = False, - task_num: int = 1, # 任务数量 + task_num: int = 1, *args, **kwargs - ): + ) -> None: """ - 多任务MuZero模型的定义,继承自MuZeroModel。 - 增加了多任务相关的处理,如任务数量和动作空间大小调整。 + Overview: + Constructor for the MuZeroMTModel. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of the input observation, e.g., (12, 96, 96). + - action_space_size (:obj:`int`): The size of the action space, applicable for discrete action spaces. + - num_res_blocks (:obj:`int`): The number of residual blocks in the representation, dynamics, and prediction networks. + - num_channels (:obj:`int`): The number of channels in the latent state. + - reward_head_channels (:obj:`int`): The number of channels in the reward head. + - value_head_channels (:obj:`int`): The number of channels in the value head. + - policy_head_channels (:obj:`int`): The number of channels in the policy head. + - fc_reward_layers (:obj:`List[int]`): The hidden layer sizes of the reward MLP. + - fc_value_layers (:obj:`List[int]`): The hidden layer sizes of the value MLP. 
+ - fc_policy_layers (:obj:`List[int]`): The hidden layer sizes of the policy MLP. + - reward_support_size (:obj:`int`): The support size for categorical reward distribution. + - value_support_size (:obj:`int`): The support size for categorical value distribution. + - proj_hid (:obj:`int`): The hidden size of the projection network for SSL. + - proj_out (:obj:`int`): The output size of the projection network for SSL. + - pred_hid (:obj:`int`): The hidden size of the prediction head for SSL. + - pred_out (:obj:`int`): The output size of the prediction head for SSL. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self-supervised learning loss. + - categorical_distribution (:obj:`bool`): Whether to use categorical distribution for value and reward. + - activation (:obj:`Optional[nn.Module]`): The activation function to use. Defaults to nn.ReLU(inplace=True). + - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last linear layer to zero. + - state_norm (:obj:`bool`): Whether to apply re-normalization to the latent state. + - downsample (:obj:`bool`): Whether to downsample the observation image. + - norm_type (:obj:`Optional[str]`): The type of normalization to use, either 'BN' (BatchNorm) or 'LN' (LayerNorm). + - discrete_action_encoding_type (:obj:`str`): The encoding type for discrete actions, 'one_hot' or 'not_one_hot'. + - analysis_sim_norm (:obj:`bool`): A flag for analysis, enables hooks for SimNorm analysis. + - task_num (:obj:`int`): The total number of tasks for the multi-task setup. """ super(MuZeroMTModel, self).__init__() - - print(f'==========MuZeroMTModel, num_res_blocks:{num_res_blocks}, num_channels:{num_channels}, task_num:{task_num}===========') - - if discrete_action_encoding_type == 'one_hot': - self.action_encoding_dim = action_space_size - elif discrete_action_encoding_type == 'not_one_hot': - self.action_encoding_dim = 1 - - assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type - - if isinstance(observation_shape, int) or len(observation_shape) == 1: - # for vector obs input, e.g. classical control and box2d environments - # to be compatible with LightZero model/policy, transform to shape: [C, W, H] - observation_shape = [1, observation_shape, 1] + if activation is None: + activation = nn.ReLU(inplace=True) + # --- Store configuration --- + self.action_space_size = action_space_size self.categorical_distribution = categorical_distribution + self.self_supervised_learning_loss = self_supervised_learning_loss + self.state_norm = state_norm + self.downsample = downsample + self.task_num = task_num + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.categorical_distribution: self.reward_support_size = reward_support_size self.value_support_size = value_support_size else: self.reward_support_size = 1 self.value_support_size = 1 + + # --- Prepare observation shape and action encoding dimension --- + if isinstance(observation_shape, int) or len(observation_shape) == 1: + # For 1D vector observations (e.g., classic control), wrap them into a 2D image-like format [C, W, H] + # to be compatible with the convolutional networks. 
+ observation_shape = (1, observation_shape[0], 1) if isinstance(observation_shape, tuple) else (1, observation_shape, 1) - self.task_num = task_num - self.action_space_size = 18 # 假设每个任务的动作空间相同 + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = self.action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + else: + raise ValueError(f"Unsupported discrete_action_encoding_type: {self.discrete_action_encoding_type}") - self.categorical_distribution = categorical_distribution + latent_size = self._get_latent_size(observation_shape, self.downsample) - self.discrete_action_encoding_type = 'one_hot' + # --- Initialize Network Components --- - # 共享表示网络 + # 1. Shared Representation Network self.representation_network = RepresentationNetwork( - observation_shape, - num_res_blocks, - num_channels, - downsample, + observation_shape=observation_shape, + num_res_blocks=num_res_blocks, + num_channels=num_channels, + downsample=self.downsample, activation=activation, norm_type=norm_type ) - # ====== for analysis ====== - if analysis_sim_norm: - self.encoder_hook = FeatureAndGradientHook() - self.encoder_hook.setup_hooks(self.representation_network) - - # 共享动态网络 + # 2. Shared Dynamics Network self.dynamics_network = DynamicsNetwork( - observation_shape, + observation_shape=observation_shape, action_encoding_dim=self.action_encoding_dim, num_res_blocks=num_res_blocks, num_channels=num_channels + self.action_encoding_dim, reward_head_channels=reward_head_channels, fc_reward_layers=fc_reward_layers, - output_support_size=reward_support_size, - flatten_output_size_for_reward_head=reward_head_channels * self._get_latent_size(observation_shape, downsample), - downsample=downsample, + output_support_size=self.reward_support_size, + flatten_output_size_for_reward_head=reward_head_channels * latent_size, + downsample=self.downsample, last_linear_layer_init_zero=last_linear_layer_init_zero, activation=activation, norm_type=norm_type ) - # 独立的预测网络,每个任务一个 - # 计算flatten_output_size - value_flatten_size = int(value_head_channels * self._get_latent_size(observation_shape, downsample)) - policy_flatten_size = int(policy_head_channels * self._get_latent_size(observation_shape, downsample)) - + # 3. Task-Specific Prediction Networks self.prediction_networks = nn.ModuleList([ PredictionNetwork( - observation_shape, - action_space_size, - num_res_blocks, - num_channels, - value_head_channels, - policy_head_channels, - fc_value_layers, - fc_policy_layers, - self.value_support_size, - flatten_output_size_for_value_head=value_flatten_size, - flatten_output_size_for_policy_head=policy_flatten_size, - downsample=downsample, + observation_shape=observation_shape, + action_space_size=self.action_space_size, + num_res_blocks=num_res_blocks, + num_channels=num_channels, + value_head_channels=value_head_channels, + policy_head_channels=policy_head_channels, + fc_value_layers=fc_value_layers, + fc_policy_layers=fc_policy_layers, + output_support_size=self.value_support_size, + flatten_output_size_for_value_head=value_head_channels * latent_size, + flatten_output_size_for_policy_head=policy_head_channels * latent_size, + downsample=self.downsample, last_linear_layer_init_zero=last_linear_layer_init_zero, activation=activation, norm_type=norm_type - ) for _ in range(task_num) + ) for _ in range(self.task_num) ]) - # 共享投影和预测头(如果使用自监督学习损失) - if self_supervised_learning_loss: + # 4. 
Optional Self-Supervised Learning (SSL) Components + if self.self_supervised_learning_loss: self.projection_network = nn.Sequential( - nn.Linear(num_channels * self._get_latent_size(observation_shape, downsample), proj_hid), + nn.Linear(num_channels * latent_size, proj_hid), nn.BatchNorm1d(proj_hid), activation, nn.Linear(proj_hid, proj_hid), @@ -148,145 +211,194 @@ def __init__( nn.Linear(proj_hid, proj_out), nn.BatchNorm1d(proj_out) ) - self.prediction_head = nn.Sequential( nn.Linear(proj_out, pred_hid), nn.BatchNorm1d(pred_hid), activation, nn.Linear(pred_hid, pred_out), ) + + # 5. Optional Hook for Analysis + if analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) - self.self_supervised_learning_loss = self_supervised_learning_loss - self.state_norm = state_norm - self.downsample = downsample - - def _get_latent_size(self, observation_shape: SequenceType, downsample: bool) -> int: + @staticmethod + def _get_latent_size(observation_shape: SequenceType, downsample: bool) -> int: """ - 辅助函数,根据观测形状和下采样选项计算潜在状态的大小。 + Overview: + Helper function to calculate the flattened size of the latent space based on observation shape and downsampling. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of the input observation. + - downsample (:obj:`bool`): Whether downsampling is enabled. + Returns: + - int: The flattened size (height * width) of the latent space. """ if downsample: + # With downsampling, the spatial dimensions are reduced by a factor of 16 (2^4). return math.ceil(observation_shape[-2] / 16) * math.ceil(observation_shape[-1] / 16) else: return observation_shape[-2] * observation_shape[-1] def initial_inference(self, obs: torch.Tensor, task_id: int = 0) -> MZNetworkOutput: """ - 多任务初始推理,基于任务ID选择对应的预测网络。 + Overview: + Performs the initial inference from a raw observation. It encodes the observation into a latent state + and then uses the task-specific prediction network to compute the policy and value. + Arguments: + - obs (:obj:`torch.Tensor`): The raw observation tensor. + - task_id (:obj:`int`): The identifier for the current task, used to select the correct prediction network. + Returns: + - MZNetworkOutput: A dataclass containing the predicted value, reward (initially zero), policy logits, and latent state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, C, H, W)`, where B is batch size. + - task_id (:obj:`int`): Scalar. + - Return.value: :math:`(B, value_support_size)`. + - Return.reward: :math:`(B, reward_support_size)`. + - Return.policy_logits: :math:`(B, action_space_size)`. + - Return.latent_state: :math:`(B, num_channels, H', W')`. """ batch_size = obs.size(0) latent_state = self.representation_network(obs) if self.state_norm: latent_state = renormalize(latent_state) + + # Select the prediction network based on the task ID. + assert 0 <= task_id < self.task_num, f"Task ID {task_id} is out of range [0, {self.task_num-1}]" prediction_net = self.prediction_networks[task_id] policy_logits, value = prediction_net(latent_state) return MZNetworkOutput( - value, - [0. for _ in range(batch_size)], - policy_logits, - latent_state, + value=value, + reward=[0. for _ in range(batch_size)], # Initial reward is always zero. 
+ policy_logits=policy_logits, + latent_state=latent_state, ) def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor, task_id: int = 0) -> MZNetworkOutput: """ - 多任务递归推理,根据任务ID选择对应的预测网络。 + Overview: + Performs recurrent inference from a latent state and an action. It uses the dynamics network to predict + the next latent state and reward, and then uses the task-specific prediction network to compute the + policy and value for the next state. + Arguments: + - latent_state (:obj:`torch.Tensor`): The current latent state. + - action (:obj:`torch.Tensor`): The action taken in the current state. + - task_id (:obj:`int`): The identifier for the current task. + Returns: + - MZNetworkOutput: A dataclass containing the predicted value, reward, policy logits, and the next latent state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, num_channels, H', W')`. + - action (:obj:`torch.Tensor`): :math:`(B, )`. + - task_id (:obj:`int`): Scalar. + - Return.value: :math:`(B, value_support_size)`. + - Return.reward: :math:`(B, reward_support_size)`. + - Return.policy_logits: :math:`(B, action_space_size)`. + - Return.latent_state: :math:`(B, num_channels, H', W')`. """ next_latent_state, reward = self._dynamics(latent_state, action) if self.state_norm: next_latent_state = renormalize(next_latent_state) + + # Select the prediction network based on the task ID. + assert 0 <= task_id < self.task_num, f"Task ID {task_id} is out of range [0, {self.task_num-1}]" prediction_net = self.prediction_networks[task_id] policy_logits, value = prediction_net(next_latent_state) return MZNetworkOutput(value, reward, policy_logits, next_latent_state) - def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: - Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` - and ``reward``. + Applies the dynamics function by concatenating the latent state with the encoded action and passing it + through the dynamics network to predict the next latent state and reward. Arguments: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - action (:obj:`torch.Tensor`): The predicted action to rollout. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of the input state. + - action (:obj:`torch.Tensor`): The action to rollout. Returns: - - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. - - reward (:obj:`torch.Tensor`): The predicted reward of the current latent state and selected action. + - Tuple[torch.Tensor, torch.Tensor]: A tuple containing the predicted next latent state and reward. Shapes: - - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ - latent state, W_ is the width of latent state. - - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. - - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ - latent state, W_ is the width of latent state. - - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, C, H', W')`. + - action (:obj:`torch.Tensor`): :math:`(B, )`. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, C, H', W')`. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`. 
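A hedged usage sketch of the task routing described above (the constructor arguments here are illustrative assumptions, not values tested in this diff): `initial_inference` takes a `task_id` that selects the matching entry of `prediction_networks`, while the representation and dynamics networks stay shared; `recurrent_inference` is routed the same way through its own `task_id` argument.

```Python
import torch
from lzero.model.muzero_model_multitask import MuZeroMTModel

# Illustrative sketch only: three tasks share representation/dynamics networks,
# and task_id picks the task-specific prediction network.
model = MuZeroMTModel(
    observation_shape=(12, 96, 96),
    action_space_size=6,
    downsample=True,
    task_num=3,
)
obs = torch.randn(2, 12, 96, 96)
init_out = model.initial_inference(obs, task_id=1)  # uses prediction_networks[1]
# model.recurrent_inference(init_out.latent_state, action, task_id=1) follows the same routing.
```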
""" - # NOTE: the discrete action encoding type is important for some environments - - # discrete action space + # Encode the action and expand it to match the spatial dimensions of the latent state. if self.discrete_action_encoding_type == 'one_hot': - # Stack latent_state with the one hot encoded action. - # The final action_encoding shape is (batch_size, action_space_size, latent_state[2], latent_state[3]), e.g. (8, 2, 4, 1). - if len(action.shape) == 1: - # (batch_size, ) -> (batch_size, 1) - # e.g., torch.Size([8]) -> torch.Size([8, 1]) - action = action.unsqueeze(-1) - - # transform action to one-hot encoding. - # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) - action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) - # transform action to torch.int64 - action = action.long() - action_one_hot.scatter_(1, action, 1) - + # Convert action indices to one-hot vectors. + action_one_hot = F.one_hot(action.long(), num_classes=self.action_space_size).float() + # Reshape for broadcasting: (B, A) -> (B, A, 1, 1) action_encoding_tmp = action_one_hot.unsqueeze(-1).unsqueeze(-1) + # Expand to (B, A, H', W') action_encoding = action_encoding_tmp.expand( latent_state.shape[0], self.action_space_size, latent_state.shape[2], latent_state.shape[3] ) - elif self.discrete_action_encoding_type == 'not_one_hot': - # Stack latent_state with the normalized encoded action. - # The final action_encoding shape is (batch_size, 1, latent_state[2], latent_state[3]), e.g. (8, 1, 4, 1). - if len(action.shape) == 2: - # (batch_size, action_dim=1) -> (batch_size, 1, 1, 1) - # e.g., torch.Size([8, 1]) -> torch.Size([8, 1, 1, 1]) - action = action.unsqueeze(-1).unsqueeze(-1) - elif len(action.shape) == 1: - # (batch_size,) -> (batch_size, 1, 1, 1) - # e.g., -> torch.Size([8, 1, 1, 1]) - action = action.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) - - action_encoding = action.expand( + # Encode action as a single channel, normalized by action space size. + # Reshape for broadcasting: (B,) -> (B, 1, 1, 1) + action_encoding_tmp = action.float().view(-1, 1, 1, 1) + # Normalize and expand to (B, 1, H', W') + action_encoding = action_encoding_tmp / self.action_space_size + action_encoding = action_encoding.expand( latent_state.shape[0], 1, latent_state.shape[2], latent_state.shape[3] - ) / self.action_space_size + ) - # state_action_encoding shape: (batch_size, latent_state[1] + action_dim, latent_state[2], latent_state[3]) or - # (batch_size, latent_state[1] + action_space_size, latent_state[2], latent_state[3]) depending on the discrete_action_encoding_type. + # Concatenate latent state and action encoding along the channel dimension. state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + # Predict next state and reward. next_latent_state, reward = self.dynamics_network(state_action_encoding) + if self.state_norm: next_latent_state = renormalize(next_latent_state) + return next_latent_state, reward def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.Tensor: """ - 多任务投影方法,当前实现为共享投影网络。 + Overview: + Projects the latent state into a different space for self-supervised learning (e.g., BYOL, SimSiam). + This involves a projection network and an optional prediction head. + Arguments: + - latent_state (:obj:`torch.Tensor`): The latent state to project. + - with_grad (:obj:`bool`): If False, detach the output of the projection network to stop gradients. + This is typically used for the target network in SSL. 
+ Returns: + - torch.Tensor: The projected (and possibly predicted) representation. """ if not self.self_supervised_learning_loss: - raise NotImplementedError("Self-supervised learning loss is not enabled for this model.") + raise NotImplementedError("The 'project' method requires 'self_supervised_learning_loss' to be enabled.") + # Flatten the latent state from (B, C, H, W) to (B, C*H*W). latent_state = latent_state.reshape(latent_state.shape[0], -1) + proj = self.projection_network(latent_state) + if with_grad: + # Return the output of the prediction head, with gradients flowing. return self.prediction_head(proj) else: + # Return the output of the projection network, detached from the graph. return proj.detach() def get_params_mean(self) -> float: + """ + Overview: + Computes the mean of all model parameters. Useful for debugging and monitoring training. + Returns: + - float: The mean value of all parameters. + """ return get_params_mean(self) class DynamicsNetwork(nn.Module): + """ + Overview: + The dynamics network of the MuZero model. It takes a state-action encoding as input and predicts + the next latent state and the reward for the transition. This network is shared across all tasks + in the multi-task setup. + """ def __init__( self, @@ -295,76 +407,111 @@ def __init__( num_res_blocks: int = 1, num_channels: int = 64, reward_head_channels: int = 64, - fc_reward_layers: SequenceType = [32], + fc_reward_layers: List[int] = [32], output_support_size: int = 601, flatten_output_size_for_reward_head: int = 64, downsample: bool = False, last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), + activation: Optional[nn.Module] = None, norm_type: Optional[str] = 'BN', - ): + ) -> None: """ - DynamicsNetwork定义,适用于多任务共享。 + Overview: + Constructor for the DynamicsNetwork. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of the original input observation. + - action_encoding_dim (:obj:`int`): The dimension of the encoded action. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - num_channels (:obj:`int`): The number of channels in the input (latent_state + action_encoding). + - reward_head_channels (:obj:`int`): The number of channels for the reward head's convolutional layer. + - fc_reward_layers (:obj:`List[int]`): The hidden layer sizes of the reward MLP. + - output_support_size (:obj:`int`): The support size for the categorical reward distribution. + - flatten_output_size_for_reward_head (:obj:`int`): The flattened input size for the reward MLP. + - downsample (:obj:`bool`): Whether downsampling is used, affecting LayerNorm shapes. + - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last linear layer to zero. + - activation (:obj:`Optional[nn.Module]`): The activation function. Defaults to nn.ReLU(inplace=True). + - norm_type (:obj:`Optional[str]`): The type of normalization, 'BN' or 'LN'. 
""" super().__init__() - assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']" - assert num_channels > action_encoding_dim, f'num_channels:{num_channels} <= action_encoding_dim:{action_encoding_dim}' - - self.num_channels = num_channels - self.flatten_output_size_for_reward_head = flatten_output_size_for_reward_head + if activation is None: + activation = nn.ReLU(inplace=True) + + assert norm_type in ['BN', 'LN'], f"norm_type must be 'BN' or 'LN', but got {norm_type}" + # The input channels to the first conv layer is num_channels, which includes the original latent channels + # and the action encoding channels. The output should be the number of channels for the latent state. + latent_channels = num_channels - action_encoding_dim + assert latent_channels > 0, f"num_channels ({num_channels}) must be greater than action_encoding_dim ({action_encoding_dim})" self.action_encoding_dim = action_encoding_dim - self.conv = nn.Conv2d(num_channels, num_channels - self.action_encoding_dim, kernel_size=3, stride=1, padding=1, bias=False) - + self.activation = activation + + # Convolutional layer to process the combined state-action encoding. + self.conv = nn.Conv2d(num_channels, latent_channels, kernel_size=3, stride=1, padding=1, bias=False) + + # Normalization layer for the main path. if norm_type == 'BN': - self.norm_common = nn.BatchNorm2d(num_channels - self.action_encoding_dim) + self.norm_common = nn.BatchNorm2d(latent_channels) elif norm_type == 'LN': if downsample: - self.norm_common = nn.LayerNorm([num_channels - self.action_encoding_dim, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) + ln_shape = [latent_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)] else: - self.norm_common = nn.LayerNorm([num_channels - self.action_encoding_dim, observation_shape[-2], observation_shape[-1]]) + ln_shape = [latent_channels, observation_shape[-2], observation_shape[-1]] + self.norm_common = nn.LayerNorm(ln_shape) + # A series of residual blocks to deepen the network. self.resblocks = nn.ModuleList( - [ - ResBlock( - in_channels=num_channels - self.action_encoding_dim, activation=activation, norm_type='BN', res_type='basic', bias=False - ) for _ in range(num_res_blocks) - ] + [ResBlock(in_channels=latent_channels, activation=activation, norm_type='BN', res_type='basic', bias=False) + for _ in range(num_res_blocks)] ) - self.conv1x1_reward = nn.Conv2d(num_channels - self.action_encoding_dim, reward_head_channels, 1) - + # --- Reward Head --- + # 1x1 convolution to create an input for the reward MLP. + self.conv1x1_reward = nn.Conv2d(latent_channels, reward_head_channels, 1) + + # Normalization for the reward head. if norm_type == 'BN': self.norm_reward = nn.BatchNorm2d(reward_head_channels) elif norm_type == 'LN': if downsample: - self.norm_reward = nn.LayerNorm([reward_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) + ln_shape_reward = [reward_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)] else: - self.norm_reward = nn.LayerNorm([reward_head_channels, observation_shape[-2], observation_shape[-1]]) + ln_shape_reward = [reward_head_channels, observation_shape[-2], observation_shape[-1]] + self.norm_reward = nn.LayerNorm(ln_shape_reward) + # MLP to predict the reward value from the processed features. 
self.fc_reward_head = MLP( - self.flatten_output_size_for_reward_head, + in_channels=flatten_output_size_for_reward_head, hidden_channels=fc_reward_layers[0], - layer_num=len(fc_reward_layers) + 1, out_channels=output_support_size, + layer_num=len(fc_reward_layers) + 1, activation=activation, norm_type=norm_type, output_activation=False, output_norm=False, last_linear_layer_init_zero=last_linear_layer_init_zero ) - self.activation = activation def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - DynamicsNetwork的前向传播,预测下一个潜在状态和奖励。 + Overview: + Forward pass for the dynamics network. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The concatenated latent state and action encoding. + Returns: + - Tuple[torch.Tensor, torch.Tensor]: A tuple containing the next latent state and the predicted reward. + Shapes: + - state_action_encoding (:obj:`torch.Tensor`): :math:`(B, C_latent + C_action, H', W')`. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, C_latent, H', W')`. + - reward (:obj:`torch.Tensor`): :math:`(B, output_support_size)`. """ - # 提取状态编码(去除动作编码部分) - state_encoding = state_action_encoding[:, :-self.action_encoding_dim, :, :] + # The original latent state is part of the input, used for the residual connection. + state_encoding = state_action_encoding[:, : -self.action_encoding_dim, :, :] + + # Main path for predicting the next latent state. x = self.conv(state_action_encoding) x = self.norm_common(x) - - # 残差连接 + + # Add residual connection from the original latent state. x += state_encoding x = self.activation(x) @@ -372,18 +519,31 @@ def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, to x = block(x) next_latent_state = x - x = self.conv1x1_reward(next_latent_state) - x = self.norm_reward(x) - x = self.activation(x) - x = x.view(x.shape[0], -1) - - # 使用全连接层预测奖励 - reward = self.fc_reward_head(x) + # --- Reward Prediction Path --- + # Process the next latent state to predict the reward. + reward_x = self.conv1x1_reward(next_latent_state) + reward_x = self.norm_reward(reward_x) + reward_x = self.activation(reward_x) + # Flatten the features before passing to the MLP. + reward_x = reward_x.view(reward_x.shape[0], -1) + reward = self.fc_reward_head(reward_x) return next_latent_state, reward def get_dynamic_mean(self) -> float: + """ + Overview: + Computes the mean of parameters in the dynamics-related layers (conv and resblocks). + Returns: + - float: The mean value of dynamics parameters. + """ return get_dynamic_mean(self) def get_reward_mean(self) -> Tuple[ndarray, float]: + """ + Overview: + Computes the mean of parameters and the last layer bias in the reward head. + Returns: + - Tuple[ndarray, float]: A tuple containing the mean of the last layer's weights and its bias. 
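A shape-flow sketch for the shared `DynamicsNetwork` above, with assumed sizes: a 96x96 observation downsampled by 16 gives a 6x6 latent map, and the input stacks 64 latent channels with a 6-dim one-hot action encoding, mirroring how `MuZeroMTModel` wires this module.

```Python
import torch
from lzero.model.muzero_model_multitask import DynamicsNetwork

# Assumed sizes for illustration; flatten_output_size_for_reward_head must equal
# reward_head_channels * H * W of the latent map.
B, latent_c, action_dim, H, W = 2, 64, 6, 6, 6
dynamics = DynamicsNetwork(
    observation_shape=(12, 96, 96),
    action_encoding_dim=action_dim,
    num_channels=latent_c + action_dim,          # latent channels + action-encoding channels
    reward_head_channels=16,
    fc_reward_layers=[32],
    output_support_size=601,
    flatten_output_size_for_reward_head=16 * H * W,
    downsample=True,
)
state_action = torch.randn(B, latent_c + action_dim, H, W)
next_latent, reward = dynamics(state_action)     # shapes: (B, 64, 6, 6) and (B, 601)
```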
+ """ return get_reward_mean(self) \ No newline at end of file diff --git a/lzero/model/muzero_rnn_full_obs_model.py b/lzero/model/muzero_rnn_full_obs_model.py index 7adb9add2..af7d72c32 100644 --- a/lzero/model/muzero_rnn_full_obs_model.py +++ b/lzero/model/muzero_rnn_full_obs_model.py @@ -31,8 +31,8 @@ def __init__( reward_head_hidden_channels: SequenceType = [32], value_head_hidden_channels: SequenceType = [32], policy_head_hidden_channels: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, + reward_support_range: SequenceType =(-300., 301., 1.), + value_support_range: SequenceType =(-300., 301., 1.), proj_hid: int = 1024, proj_out: int = 1024, pred_hid: int = 512, @@ -70,8 +70,8 @@ def __init__( - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. + - reward_support_range (:obj:`SequenceType`): The range of categorical reward output + - value_support_range (:obj:`SequenceType`): The range of categorical value output. - proj_hid (:obj:`int`): The size of projection hidden layer. - proj_out (:obj:`int`): The size of projection output layer. - pred_hid (:obj:`int`): The size of prediction hidden layer. @@ -95,12 +95,13 @@ def __init__( # for vector obs input, e.g. classical control and box2d environments # to be compatible with LightZero model/policy, transform to shape: [C, W, H] observation_shape = [1, observation_shape, 1] - if not categorical_distribution: + self.categorical_distribution = categorical_distribution + if self.categorical_distribution: + self.reward_support_size = len(torch.arange(*reward_support_range)) + self.value_support_size = len(torch.arange(*value_support_range)) + else: self.reward_support_size = 1 self.value_support_size = 1 - else: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size self.action_space_size = action_space_size assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type diff --git a/lzero/model/sampled_efficientzero_model.py b/lzero/model/sampled_efficientzero_model.py index 0bd14c6d2..726e55b16 100644 --- a/lzero/model/sampled_efficientzero_model.py +++ b/lzero/model/sampled_efficientzero_model.py @@ -29,8 +29,8 @@ def __init__( reward_head_hidden_channels: SequenceType = [256], value_head_hidden_channels: SequenceType = [256], policy_head_hidden_channels: SequenceType = [256], - reward_support_size: int = 601, - value_support_size: int = 601, + reward_support_range: SequenceType =(-300., 301., 1.), + value_support_range: SequenceType =(-300., 301., 1.), proj_hid: int = 1024, proj_out: int = 1024, pred_hid: int = 512, @@ -76,8 +76,8 @@ def __init__( - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). 
- - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. + - reward_support_range (:obj:`SequenceType`): The range of categorical reward output + - value_support_range (:obj:`SequenceType`): The range of categorical value output. - proj_hid (:obj:`int`): The size of projection hidden layer. - proj_out (:obj:`int`): The size of projection output layer. - pred_hid (:obj:`int`): The size of prediction hidden layer. @@ -110,12 +110,13 @@ def __init__( # for vector obs input, e.g. classical control and box2d environments # to be compatible with LightZero model/policy, transform to shape: [C, W, H] observation_shape = [1, observation_shape, 1] - if not categorical_distribution: + self.categorical_distribution = categorical_distribution + if self.categorical_distribution: + self.reward_support_size = len(torch.arange(*reward_support_range)) + self.value_support_size = len(torch.arange(*value_support_range)) + else: self.reward_support_size = 1 self.value_support_size = 1 - else: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size self.continuous_action_space = continuous_action_space self.action_space_size = action_space_size diff --git a/lzero/model/sampled_efficientzero_model_mlp.py b/lzero/model/sampled_efficientzero_model_mlp.py index 39f0c716f..e38eb282d 100644 --- a/lzero/model/sampled_efficientzero_model_mlp.py +++ b/lzero/model/sampled_efficientzero_model_mlp.py @@ -23,8 +23,8 @@ def __init__( reward_head_hidden_channels: SequenceType = [256], value_head_hidden_channels: SequenceType = [256], policy_head_hidden_channels: SequenceType = [256], - reward_support_size: int = 601, - value_support_size: int = 601, + reward_support_range: SequenceType =(-300., 301., 1.), + value_support_range: SequenceType =(-300., 301., 1.), proj_hid: int = 1024, proj_out: int = 1024, pred_hid: int = 512, @@ -65,8 +65,8 @@ def __init__( - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. + - reward_support_range (:obj:`SequenceType`): The range of categorical reward output + - value_support_range (:obj:`SequenceType`): The range of categorical value output. - proj_hid (:obj:`int`): The size of projection hidden layer. - proj_out (:obj:`int`): The size of projection output layer. - pred_hid (:obj:`int`): The size of prediction hidden layer. @@ -91,12 +91,13 @@ def __init__( - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. 
""" super(SampledEfficientZeroModelMLP, self).__init__() - if not categorical_distribution: + self.categorical_distribution = categorical_distribution + if self.categorical_distribution: + self.reward_support_size = len(torch.arange(*reward_support_range)) + self.value_support_size = len(torch.arange(*value_support_range)) + else: self.reward_support_size = 1 self.value_support_size = 1 - else: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size self.continuous_action_space = continuous_action_space self.observation_shape = observation_shape diff --git a/lzero/model/sampled_muzero_model.py b/lzero/model/sampled_muzero_model.py index 505c98f21..82509f2d2 100644 --- a/lzero/model/sampled_muzero_model.py +++ b/lzero/model/sampled_muzero_model.py @@ -27,8 +27,8 @@ def __init__( reward_head_hidden_channels: SequenceType = [256], value_head_hidden_channels: SequenceType = [256], policy_head_hidden_channels: SequenceType = [256], - reward_support_size: int = 601, - value_support_size: int = 601, + reward_support_range: SequenceType =(-300., 301., 1.), + value_support_range: SequenceType =(-300., 301., 1.), proj_hid: int = 1024, proj_out: int = 1024, pred_hid: int = 512, @@ -65,8 +65,8 @@ def __init__( - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. + - reward_support_range (:obj:`SequenceType`): The range of categorical reward output + - value_support_range (:obj:`SequenceType`): The range of categorical value output. - proj_hid (:obj:`int`): The size of projection hidden layer. - proj_out (:obj:`int`): The size of projection output layer. - pred_hid (:obj:`int`): The size of prediction hidden layer. @@ -91,12 +91,13 @@ def __init__( - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. 
""" super(SampledMuZeroModel, self).__init__() - if not categorical_distribution: + self.categorical_distribution = categorical_distribution + if self.categorical_distribution: + self.reward_support_size = len(torch.arange(*reward_support_range)) + self.value_support_size = len(torch.arange(*value_support_range)) + else: self.reward_support_size = 1 self.value_support_size = 1 - else: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size self.continuous_action_space = continuous_action_space self.observation_shape = observation_shape diff --git a/lzero/model/sampled_muzero_model_mlp.py b/lzero/model/sampled_muzero_model_mlp.py index 37871d365..0b6856e12 100644 --- a/lzero/model/sampled_muzero_model_mlp.py +++ b/lzero/model/sampled_muzero_model_mlp.py @@ -22,8 +22,8 @@ def __init__( reward_head_hidden_channels: SequenceType = [256], value_head_hidden_channels: SequenceType = [256], policy_head_hidden_channels: SequenceType = [256], - reward_support_size: int = 601, - value_support_size: int = 601, + reward_support_range: SequenceType =(-300., 301., 1.), + value_support_range: SequenceType =(-300., 301., 1.), proj_hid: int = 1024, proj_out: int = 1024, pred_hid: int = 512, @@ -63,8 +63,8 @@ def __init__( - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. + - reward_support_range (:obj:`SequenceType`): The range of categorical reward output + - value_support_range (:obj:`SequenceType`): The range of categorical value output. - proj_hid (:obj:`int`): The size of projection hidden layer. - proj_out (:obj:`int`): The size of projection output layer. - pred_hid (:obj:`int`): The size of prediction hidden layer. @@ -89,12 +89,13 @@ def __init__( - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. 
""" super(SampledMuZeroModelMLP, self).__init__() - if not categorical_distribution: + self.categorical_distribution = categorical_distribution + if self.categorical_distribution: + self.reward_support_size = len(torch.arange(*reward_support_range)) + self.value_support_size = len(torch.arange(*value_support_range)) + else: self.reward_support_size = 1 self.value_support_size = 1 - else: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size self.continuous_action_space = continuous_action_space self.observation_shape = observation_shape diff --git a/lzero/model/sampled_unizero_model_multitask.py b/lzero/model/sampled_unizero_model_multitask.py index a8c4f850e..e0026d0ff 100644 --- a/lzero/model/sampled_unizero_model_multitask.py +++ b/lzero/model/sampled_unizero_model_multitask.py @@ -1,9 +1,9 @@ -from typing import Optional, List +from typing import Optional, List, Sequence import torch import torch.nn as nn from ding.torch_utils import MLP -from ding.utils import MODEL_REGISTRY, SequenceType +from ding.utils import MODEL_REGISTRY from easydict import EasyDict from .common import MZNetworkOutput, RepresentationNetworkUniZero, LatentDecoder, \ @@ -12,39 +12,46 @@ from .unizero_world_models.world_model_multitask import WorldModelMT class RepresentationNetworkMLPMT(nn.Module): + """ + Overview: + A multi-task representation network that encodes vector observations into a latent state + using a Multi-Layer Perceptron (MLP). It supports task-specific encoders and an optional + shared projection layer to map representations into a common embedding space. + """ + def __init__( self, - observation_shape_list: List[int], # List of observation shapes for each task + observation_shape_list: List[int], hidden_channels: int = 64, layer_num: int = 2, activation: nn.Module = nn.GELU(approximate='tanh'), norm_type: Optional[str] = 'BN', embedding_dim: int = 256, group_size: int = 8, - use_shared_projection: bool = False, # 控制是否启用共享投影层 - shared_projection_dim: Optional[int] = None, # 共享投影层的维度 - final_norm_option_in_encoder: str = 'LayerNorm', # TODO - ) -> torch.Tensor: + use_shared_projection: bool = False, + shared_projection_dim: Optional[int] = None, + final_norm_option_in_encoder: str = 'LayerNorm', # TODO: Further investigate norm options + ) -> None: """ - Overview: - Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \ - with Multi-Layer Perceptron (MLP), optionally followed by a shared projection layer. Arguments: - - observation_shape_list (:obj:`List[int]`): The list of observation shape for each task. - - hidden_channels (:obj:`int`): The channel of output hidden state. - - layer_num (:obj:`int`): The number of layers in the MLP. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). - - norm_type (:obj:`str`): The type of normalization in networks, defaults to 'BN'. - - group_size (:obj:`int`): The group size used in SimNorm. - - use_shared_projection (:obj:`bool`): Whether to use a shared projection layer, defaults to False. - - shared_projection_dim (:obj:`Optional[int]`): The dimension of the shared projection layer. \ - If None, defaults to `hidden_channels`. + - observation_shape_list (:obj:`List[int]`): A list of observation feature dimensions, one for each task. + - hidden_channels (:obj:`int`): The number of hidden channels in the task-specific MLPs. + - layer_num (:obj:`int`): The number of layers in each MLP. 
+ - activation (:obj:`nn.Module`): The activation function to use in the MLPs. Defaults to nn.GELU(approximate='tanh'). + - norm_type (:obj:`str`): The type of normalization to use within the MLPs. Defaults to 'BN'. + - embedding_dim (:obj:`int`): The dimension of the final output embedding. + - group_size (:obj:`int`): The group size for SimNorm if it is used. + - use_shared_projection (:obj:`bool`): Whether to use a shared projection layer after task-specific encoding. Defaults to False. + - shared_projection_dim (:obj:`Optional[int]`): The dimension of the shared projection layer. If None, it defaults to `hidden_channels`. + - final_norm_option_in_encoder (:obj:`str`): The final normalization layer type ('LayerNorm' or 'SimNorm'). Defaults to 'LayerNorm'. """ super().__init__() self.env_num = len(observation_shape_list) self.use_shared_projection = use_shared_projection self.hidden_channels = hidden_channels self.shared_projection_dim = shared_projection_dim or hidden_channels + self.embedding_dim = embedding_dim + self.final_norm_option_in_encoder = final_norm_option_in_encoder # Task-specific representation networks self.fc_representation = nn.ModuleList([ @@ -55,25 +62,16 @@ def __init__( layer_num=layer_num, activation=activation, norm_type=norm_type, - # don't use activation and norm in the last layer of representation network is important for convergence. + # No activation or norm in the last layer is important for convergence. output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. + # Initializing the last linear layer to zero can be beneficial for convergence speed. last_linear_layer_init_zero=True, ) for obs_shape in observation_shape_list ]) - - # Shared projection layer - if self.use_shared_projection: - self.shared_projection = nn.Linear(hidden_channels, self.shared_projection_dim) - # self.projection_norm = nn.LayerNorm(self.shared_projection_dim) # Optional normalization for shared space - self.projection_norm = SimNorm(simnorm_dim=group_size) # Optional normalization for shared space - self.embedding_dim = embedding_dim - # SimNorm for task-specific outputs - # self.sim_norm = SimNorm(simnorm_dim=group_size) - self.final_norm_option_in_encoder = final_norm_option_in_encoder + # Final normalization layer before projection if self.final_norm_option_in_encoder == 'LayerNorm': self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5) elif self.final_norm_option_in_encoder == 'SimNorm': @@ -81,246 +79,184 @@ def __init__( else: raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}") + # Optional shared projection layer + if self.use_shared_projection: + self.shared_projection = nn.Linear(hidden_channels, self.shared_projection_dim) + # Using SimNorm for the shared space projection + self.projection_norm = SimNorm(simnorm_dim=group_size) def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor: """ Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. - - task_id (:obj:`int`): The ID of the current task. - - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)` if shared projection is not used, \ - otherwise :math:`(B, shared_projection_dim)`. + - x (:obj:`torch.Tensor`): The input tensor of shape :math:`(B, N)`, where B is the batch size and N is the length of the vector observation. + - task_id (:obj:`int`): The identifier for the current task, used to select the appropriate encoder. 
+ - output (:obj:`torch.Tensor`): The output latent state. Its shape is :math:`(B, embedding_dim)` if shared projection is not used, otherwise :math:`(B, shared_projection_dim)`. """ - # Task-specific representation + # Encode observation using the task-specific MLP x = self.fc_representation[task_id](x) + # Apply final normalization x = self.final_norm(x) - # x = self.sim_norm(x) - # Shared projection layer (if enabled) + # Apply the shared projection layer if enabled if self.use_shared_projection: x = self.shared_projection(x) - x = self.projection_norm(x) # Optional normalization + x = self.projection_norm(x) return x -# class RepresentationNetworkMLPMT(nn.Module): -# def __init__( -# self, -# observation_shape_list: List[int], # List of observation shapes for each task -# hidden_channels: int = 64, -# layer_num: int = 2, -# activation: nn.Module = nn.GELU(approximate='tanh'), -# norm_type: Optional[str] = 'BN', -# group_size: int = 8, -# ) -> torch.Tensor: -# """ -# Overview: -# Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \ -# with Multi-Layer Perceptron (MLP). -# Arguments: -# - observation_shape_list (:obj:`List[int]`): The list of observation shape for each task. -# - hidden_channels (:obj:`int`): The channel of output hidden state. -# - layer_num (:obj:`int`): The number of layers in the MLP. -# - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). -# - norm_type (:obj:`str`): The type of normalization in networks, defaults to 'BN'. -# - group_size (:obj:`int`): The group size used in SimNorm. -# """ -# super().__init__() -# self.env_num = len(observation_shape_list) -# self.fc_representation = nn.ModuleList([ -# MLP( -# in_channels=obs_shape, -# hidden_channels=hidden_channels, -# out_channels=hidden_channels, -# layer_num=layer_num, -# activation=activation, -# norm_type=norm_type, -# # don't use activation and norm in the last layer of representation network is important for convergence. -# output_activation=False, -# output_norm=False, -# # last_linear_layer_init_zero=True is beneficial for convergence speed. -# last_linear_layer_init_zero=True, -# ) -# for obs_shape in observation_shape_list -# ]) -# self.sim_norm = SimNorm(simnorm_dim=group_size) - -# def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor: -# """ -# Shapes: -# - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. -# - task_id (:obj:`int`): The ID of the current task. -# - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. -# """ -# x = self.fc_representation[task_id](x) -# x = self.sim_norm(x) -# return x - - @MODEL_REGISTRY.register('SampledUniZeroMTModel') class SampledUniZeroMTModel(nn.Module): + """ + Overview: + The main model for Sampled UniZero in a multi-task setting. It integrates a representation + network, a tokenizer, and a world model to perform initial and recurrent inference, + which are essential for MuZero-style planning algorithms. The model is designed to handle + both vector and image-based observations across multiple tasks. 
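A minimal sketch of the task-specific encoder above, with assumed observation sizes; `embedding_dim` is kept equal to `hidden_channels` here so the final `LayerNorm` matches the MLP output, which is also how `SampledUniZeroMTModel` configures it below.

```Python
import torch
from lzero.model.sampled_unizero_model_multitask import RepresentationNetworkMLPMT

# Two tasks with 8- and 17-dim vector observations share the module;
# task_id routes each batch to its own MLP encoder (values assumed for illustration).
encoder = RepresentationNetworkMLPMT(
    observation_shape_list=[8, 17],
    hidden_channels=64,
    embedding_dim=64,
)
obs_task0 = torch.randn(4, 8)
latent = encoder(obs_task0, task_id=0)   # (4, 64), produced by fc_representation[0]
```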
+ """ + def __init__( self, - observation_shape_list: List[SequenceType], # List of observation shapes for each task - action_space_size_list: List[int], # List of action space sizes for each task + observation_shape_list: List[Sequence], + action_space_size_list: List[int], num_res_blocks: int = 1, num_channels: int = 64, activation: nn.Module = nn.GELU(approximate='tanh'), downsample: bool = True, norm_type: Optional[str] = 'LN', - # world_model_cfgs: List[EasyDict] = None, # List of world model configs for each task - world_model_cfg: List[EasyDict] = None, # List of world model configs for each task + world_model_cfg: EasyDict = None, *args, **kwargs ): """ - Overview: - The definition of data procession in the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), including two main parts: - - initial_inference, which is used to predict the value, policy, and latent state based on the current observation. - - recurrent_inference, which is used to predict the value, policy, reward, and next latent state based on the current latent state and action. - The world model consists of three main components: - - a tokenizer, which encodes observations into embeddings, - - a transformer, which processes the input sequences, - - and heads, which generate the logits for observations, rewards, policy, and value. Arguments: - - observation_shape_list (:obj:`List[SequenceType]`): List of observation space shapes for each task, e.g. [C, W, H]=[3, 64, 64] for Atari. - - action_space_size_list (:obj:`List[int]`): List of action space sizes for each task. - - num_res_blocks (:obj:`int`): The number of res blocks in UniZero model. - - num_channels (:obj:`int`): The channels of hidden states in representation network. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - norm_type (:obj=`str`): The type of normalization in networks. Defaults to 'LN'. - - world_model_cfgs (:obj=`List[EasyDict]`): The list of world model configurations for each task. + - observation_shape_list (:obj:`List[Sequence]`): A list of observation space shapes for each task (e.g., `[C, W, H]` for images or `[D]` for vectors). + - action_space_size_list (:obj:`List[int]`): A list of action space sizes for each task. + - num_res_blocks (:obj:`int`): The number of residual blocks in the image representation network. + - num_channels (:obj:`int`): The number of channels in the hidden states of the image representation network. + - activation (:obj:`nn.Module`): The activation function used throughout the network. + - downsample (:obj:`bool`): Whether to downsample observations in the image representation network. + - norm_type (:obj:`str`): The type of normalization to use in networks. Defaults to 'LN'. + - world_model_cfg (:obj:`EasyDict`): A single configuration object for the world model, shared across all tasks. 
""" super(SampledUniZeroMTModel, self).__init__() self.task_num = len(observation_shape_list) self.activation = activation self.downsample = downsample - # Initialize environment-specific networks and models - self.representation_networks = nn.ModuleList() - # self.decoder_networks = nn.ModuleList() - # self.world_models = nn.ModuleList() - + # Determine the embedding dimension for observations and actions if world_model_cfg.task_embed_option == "concat_task_embed": obs_act_embed_dim = world_model_cfg.embed_dim - world_model_cfg.task_embed_dim if hasattr(world_model_cfg, "task_embed_dim") else 96 else: obs_act_embed_dim = world_model_cfg.embed_dim - - for task_id in range(self.task_num): - # world_model_cfg = world_model_cfgs[task_id] - world_model_cfg.norm_type = norm_type - assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, 'max_tokens should be 2 * max_blocks, because each timestep has 2 tokens: obs and action' - if world_model_cfg.obs_type == 'vector': - self.representation_network = RepresentationNetworkMLPMT( - observation_shape_list=observation_shape_list, - hidden_channels=obs_act_embed_dim, - layer_num=2, + world_model_cfg.norm_type = norm_type + assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, \ + 'max_tokens should be 2 * max_blocks, as each timestep consists of an observation and an action token.' + + # Initialize networks based on observation type + if world_model_cfg.obs_type == 'vector': + # A single representation network capable of handling multiple tasks via task_id + self.representation_network = RepresentationNetworkMLPMT( + observation_shape_list=observation_shape_list, + hidden_channels=obs_act_embed_dim, + layer_num=2, + activation=self.activation, + norm_type=norm_type, + embedding_dim=obs_act_embed_dim, + group_size=world_model_cfg.group_size, + use_shared_projection=world_model_cfg.use_shared_projection, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + ) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + + elif world_model_cfg.obs_type == 'image': + self.representation_network = nn.ModuleList() + # TODO: Currently uses a single shared encoder for all image-based tasks. + # This can be extended to support multiple independent encoders if needed. 
+ for _ in range(1): + self.representation_network.append(RepresentationNetworkUniZero( + observation_shape_list[0], # Assuming shared encoder uses the shape of the first task + num_res_blocks, + num_channels, + self.downsample, activation=self.activation, norm_type=norm_type, embedding_dim=obs_act_embed_dim, group_size=world_model_cfg.group_size, - use_shared_projection=world_model_cfg.use_shared_projection, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, - ) - self.tokenizer = Tokenizer(encoder=self.representation_network, - decoder_network=None, with_lpips=False) - self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) - elif world_model_cfg.obs_type == 'image': - self.representation_network = nn.ModuleList() - # for task_id in range(self.task_num): # TODO: N independent encoder - for task_id in range(1): # TODO: one share encoder - self.representation_network.append(RepresentationNetworkUniZero( - observation_shape_list[task_id], - num_res_blocks, - num_channels, - self.downsample, - activation=self.activation, - norm_type=norm_type, - embedding_dim=obs_act_embed_dim, - group_size=world_model_cfg.group_size, - final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, - )) - # TODO: we should change the output_shape to the real observation shape - # self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64)) - + )) + # TODO: The world model and tokenizer for the 'image' case should be initialized here. + # self.tokenizer = Tokenizer(...) + # self.world_model = WorldModelMT(...) - # Print model parameters for debugging - print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') - print('==' * 20) - print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') - print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') - print('==' * 20) + # Print model parameter counts for verification + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + if hasattr(self.tokenizer, 'encoder') and self.tokenizer.encoder is not None: + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) - def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput: + def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torch.Tensor] = None, current_obs_batch: Optional[torch.Tensor] = None, task_id: Optional[int] = None) -> MZNetworkOutput: """ Overview: - Initial inference of UniZero model, which is the first step of the UniZero model. - To perform the initial inference, we first use the representation network to obtain the ``latent_state``. - Then we use the prediction network to predict ``value`` and ``policy_logits`` of the ``latent_state``. + Performs the initial inference step of the UniZero model. It takes an observation + and produces a latent state, a value prediction, and an initial policy. Arguments: - - obs_batch (:obj:`torch.Tensor`): The 3D image observation data. - - task_id (:obj:`int`): The ID of the current task. + - obs_batch (:obj:`torch.Tensor`): The initial batch of observations. 
+ - action_batch (:obj:`Optional[torch.Tensor]`): An optional batch of actions. + - current_obs_batch (:obj:`Optional[torch.Tensor]`): An optional batch of current observations. + - task_id (:obj:`Optional[int]`): The identifier for the current task. Returns (MZNetworkOutput): - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. \ - In initial inference, we set it to zero vector. - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - latent_state (:obj=`torch.Tensor`): The encoding latent state of input state. + An object containing the predicted value, initial reward (zero), policy logits, and latent state. Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. - - value (:obj=`torch.Tensor`): :math=`(B, value_support_size)`, where B is batch_size. - - reward (:obj=`torch.Tensor`): :math=`(B, reward_support_size)`, where B is batch_size. - - policy_logits (:obj=`torch.Tensor`): :math=`(B, action_dim)`, where B is batch_size. - - latent_state (:obj=`torch.Tensor`): :math=`(B, H_, W_)`, where B is batch_size, H_ is the height of latent state, W_ is the width of latent state. + - obs_batch (:obj:`torch.Tensor`): :math:`(B, ...)` where B is the batch size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`. + - latent_state (:obj:`torch.Tensor`): :math:`(B, embedding_dim)`. """ batch_size = obs_batch.size(0) obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch} _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict, task_id=task_id) - latent_state, reward, policy_logits, value = obs_token, logits_rewards, logits_policy, logits_value - policy_logits = policy_logits.squeeze(1) - value = value.squeeze(1) + + latent_state = obs_token + policy_logits = logits_policy.squeeze(1) + value = logits_value.squeeze(1) return MZNetworkOutput( - value, - [0. for _ in range(batch_size)], - policy_logits, - latent_state, + value=value, + reward=[0. for _ in range(batch_size)], # Initial reward is always zero + policy_logits=policy_logits, + latent_state=latent_state, ) - def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index=0, - latent_state_index_in_search_path=[], task_id=0) -> MZNetworkOutput: + def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index: int = 0, search_depth: List[int] = [], task_id: int = 0) -> MZNetworkOutput: """ Overview: - Recurrent inference of UniZero model. To perform the recurrent inference, we concurrently predict the latent dynamics (reward/next_latent_state) - and decision-oriented quantities (value/policy) conditioned on the learned latent history in the world_model. + Performs the recurrent inference step (the dynamics function). Given a history of + latent states and actions, it predicts the next latent state, reward, value, and policy. Arguments: - - state_action_history (:obj:`torch.Tensor`): The history of states and actions. - - task_id (:obj:`int`): The ID of the current task. - - simulation_index (:obj=`int`): The index of the current simulation. 
- - latent_state_index_in_search_path (:obj=`List[int]`): The indices of latent states in the search path. + - state_action_history (:obj:`torch.Tensor`): A history of states and actions. + - simulation_index (:obj:`int`): The index of the current simulation step in MCTS. + - search_depth (:obj:`List[int]`): The indices of latent states in the current search path. + - task_id (:obj:`int`): The identifier for the current task. Returns (MZNetworkOutput): - - value (:obj=`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - - reward (:obj=`torch.Tensor`): The predicted reward of input state and selected action. - - policy_logits (:obj=`torch.Tensor`): The output logit to select discrete action. - - latent_state (:obj=`torch.Tensor`): The encoding latent state of input state. - - next_latent_state (:obj=`torch.Tensor`): The predicted next latent state. + An object containing the predicted value, reward, policy logits, and the next latent state. Shapes: - - obs (:obj=`torch.Tensor`): :math=`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. - - action (:obj=`torch.Tensor`): :math=`(B, )`, where B is batch_size. - - value (:obj=`torch.Tensor`): :math=`(B, value_support_size)`, where B is batch_size. - - reward (:obj=`torch.Tensor`): :math=`(B, reward_support_size)`, where B is batch_size. - - policy_logits (:obj=`torch.Tensor`): :math=`(B, action_dim)`, where B is batch_size. - - latent_state (:obj=`torch.Tensor`): :math=`(B, H_, W_)`, where B is batch_size, H_ is the height of latent state, W_ is the width of latent state. - - next_latent_state (:obj=`torch.Tensor`): :math=`(B, H_, W_)`, where B is batch_size, H_ is the height of latent state, W_ is the width of latent state. - """ + - state_action_history (:obj:`torch.Tensor`): :math:`(B, L, D)`, where L is sequence length. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, embedding_dim)`. 
+ """ _, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_recurrent_inference( - state_action_history, simulation_index, latent_state_index_in_search_path, task_id=task_id) - next_latent_state, reward, policy_logits, value = logits_observations, logits_rewards, logits_policy, logits_value - policy_logits = policy_logits.squeeze(1) - value = value.squeeze(1) - reward = reward.squeeze(1) + state_action_history, simulation_index, search_depth, task_id=task_id) + + next_latent_state = logits_observations + reward = logits_rewards.squeeze(1) + policy_logits = logits_policy.squeeze(1) + value = logits_value.squeeze(1) + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) \ No newline at end of file diff --git a/lzero/model/stochastic_muzero_model.py b/lzero/model/stochastic_muzero_model.py index 00ccea619..7aa7ce678 100644 --- a/lzero/model/stochastic_muzero_model.py +++ b/lzero/model/stochastic_muzero_model.py @@ -27,8 +27,8 @@ def __init__( reward_head_hidden_channels: SequenceType = [32], value_head_hidden_channels: SequenceType = [32], policy_head_hidden_channels: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, + reward_support_range: SequenceType =(-300., 301., 1.), + value_support_range: SequenceType =(-300., 301., 1.), proj_hid: int = 1024, proj_out: int = 1024, pred_hid: int = 512, @@ -61,8 +61,8 @@ def __init__( - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. + - reward_support_range (:obj:`SequenceType`): The range of categorical reward output + - value_support_range (:obj:`SequenceType`): The range of categorical value output. - proj_hid (:obj:`int`): The size of projection hidden layer. - proj_out (:obj:`int`): The size of projection output layer. - pred_hid (:obj:`int`): The size of prediction hidden layer. @@ -83,8 +83,8 @@ def __init__( super(StochasticMuZeroModel, self).__init__() self.categorical_distribution = categorical_distribution if self.categorical_distribution: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size + self.reward_support_size = len(torch.arange(*reward_support_range)) + self.value_support_size = len(torch.arange(*value_support_range)) else: self.reward_support_size = 1 self.value_support_size = 1 diff --git a/lzero/model/stochastic_muzero_model_mlp.py b/lzero/model/stochastic_muzero_model_mlp.py index a0b4b8211..9ac6efe92 100644 --- a/lzero/model/stochastic_muzero_model_mlp.py +++ b/lzero/model/stochastic_muzero_model_mlp.py @@ -22,8 +22,8 @@ def __init__( reward_head_hidden_channels: SequenceType = [32], value_head_hidden_channels: SequenceType = [32], policy_head_hidden_channels: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, + reward_support_range: SequenceType =(-300., 301., 1.), + value_support_range: SequenceType =(-300., 301., 1.), proj_hid: int = 1024, proj_out: int = 1024, pred_hid: int = 512, @@ -54,8 +54,8 @@ def __init__( - reward_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). 
- value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - reward_support_size (:obj:`int`): The size of categorical reward output - - value_support_size (:obj:`int`): The size of categorical value output. + - reward_support_range (:obj:`SequenceType`): The range of categorical reward output + - value_support_range (:obj:`SequenceType`): The range of categorical value output. - proj_hid (:obj:`int`): The size of projection hidden layer. - proj_out (:obj:`int`): The size of projection output layer. - pred_hid (:obj:`int`): The size of prediction hidden layer. @@ -72,12 +72,12 @@ def __init__( """ super(StochasticMuZeroModelMLP, self).__init__() self.categorical_distribution = categorical_distribution - if not self.categorical_distribution: + if self.categorical_distribution: + self.reward_support_size = len(torch.arange(*reward_support_range)) + self.value_support_size = len(torch.arange(*value_support_range)) + else: self.reward_support_size = 1 self.value_support_size = 1 - else: - self.reward_support_size = reward_support_size - self.value_support_size = value_support_size self.action_space_size = action_space_size self.chance_space_size = chance_space_size diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py index 6b092a978..59b893b21 100644 --- a/lzero/model/unizero_model.py +++ b/lzero/model/unizero_model.py @@ -4,12 +4,15 @@ import torch.nn as nn from ding.utils import MODEL_REGISTRY, SequenceType from easydict import EasyDict +# from transformers import T5ForConditionalGeneration, T5Tokenizer from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \ - HFLanguageRepresentationNetwork + HFLanguageRepresentationNetwork, QwenNetwork from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model import WorldModel +from .vit import ViT, ViTConfig +from ding.utils import ENV_REGISTRY, set_pkg_seed, get_rank, get_world_size # use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document. @@ -64,6 +67,10 @@ def __init__( - analysis_sim_norm (:obj:`bool`): Whether to analyze the similarity of the norm. """ super(UniZeroModel, self).__init__() + # Get current world size and rank for distributed setups. 
+ self.world_size: int = get_world_size() + self.rank: int = get_rank() + self.action_space_size = action_space_size self.activation = activation self.downsample = downsample @@ -77,11 +84,12 @@ def __init__( layer_num=2, activation=self.activation, group_size=world_model_cfg.group_size, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder ) # TODO: only for MemoryEnv now self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25) self.tokenizer = Tokenizer(encoder=self.representation_network, - decoder_network=self.decoder_network, with_lpips=False, obs_type=world_model_cfg.obs_type) + decoder=self.decoder_network, with_lpips=False, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') print('==' * 20) @@ -89,8 +97,37 @@ def __init__( print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') print('==' * 20) elif world_model_cfg.obs_type == 'text': - self.representation_network = HFLanguageRepresentationNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim) - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False,) + if kwargs['encoder_option'] == 'legacy': + self.representation_network = HFLanguageRepresentationNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder) + if world_model_cfg.decode_loss_mode is None or world_model_cfg.decode_loss_mode.lower() == 'none': + self.decoder_network = None + self.decoder_network_tokenizer = None + projection = None + else: + if self.rank == 0: + self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small") + self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small") + if self.world_size > 1: + # Wait until rank 0 finishes loading the tokenizer + torch.distributed.barrier() + if self.rank != 0: + self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small") + self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small") + projection = [world_model_cfg.embed_dim, self.decoder_network.config.d_model] + elif kwargs['encoder_option'] == 'qwen': + self.representation_network = QwenNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder) + if world_model_cfg.decode_loss_mode is None or world_model_cfg.decode_loss_mode.lower() == 'none': + self.decoder_network = None + self.decoder_network_tokenizer = None + projection = None + else: + projection = [world_model_cfg.embed_dim, self.representation_network.pretrained_model.config.hidden_size] + self.decoder_network = self.representation_network + self.decoder_network_tokenizer = None + else: + raise ValueError(f"Unsupported encoder option: {kwargs['encoder_option']}") + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer, + with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option']) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') print('==' * 20) @@ -98,24 
+135,41 @@ def __init__( print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') print('==' * 20) elif world_model_cfg.obs_type == 'image': - self.representation_network = RepresentationNetworkUniZero( - observation_shape, - num_res_blocks, - num_channels, - self.downsample, - activation=self.activation, - norm_type=norm_type, - embedding_dim=world_model_cfg.embed_dim, - group_size=world_model_cfg.group_size, - final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, - ) + if world_model_cfg.encoder_type == "resnet": + self.representation_network = RepresentationNetworkUniZero( + observation_shape, + num_res_blocks, + num_channels, + self.downsample, + activation=self.activation, + norm_type=norm_type, + embedding_dim=world_model_cfg.embed_dim, + group_size=world_model_cfg.group_size, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + ) + elif world_model_cfg.encoder_type == "vit": + # vit base + vit_config = ViTConfig( + image_size=observation_shape[1], + patch_size=8, + num_classes=world_model_cfg.embed_dim, + dim=768, + depth=12, + heads=12, + mlp_dim=3072, + dropout=0.1, + emb_dropout=0.1, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + lora_config=world_model_cfg, + ) + self.representation_network = ViT(config=vit_config) # ====== for analysis ====== if world_model_cfg.analysis_sim_norm: self.encoder_hook = FeatureAndGradientHook() self.encoder_hook.setup_hooks(self.representation_network) - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False, obs_type=world_model_cfg.obs_type) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=None, with_lpips=False, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') print('==' * 20) @@ -146,7 +200,7 @@ def __init__( self.encoder_hook = FeatureAndGradientHook() self.encoder_hook.setup_hooks(self.representation_network) - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network, obs_type=world_model_cfg.obs_type) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py index 71cf60ea6..68095de46 100644 --- a/lzero/model/unizero_model_multitask.py +++ b/lzero/model/unizero_model_multitask.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Sequence, Dict, Any, List import torch import torch.nn as nn @@ -9,14 +9,21 @@ VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model_multitask import WorldModelMT +from .vit import ViT, ViTConfig -from line_profiler import line_profiler -# use ModelRegistry 
to register the model, for more details about ModelRegistry, please refer to DI-engine's document. @MODEL_REGISTRY.register('UniZeroMTModel') class UniZeroMTModel(nn.Module): + """ + Overview: + The main model for UniZero, a multi-task agent based on a scalable latent world model. + This class orchestrates the representation network, world model, and prediction heads. + It provides two primary interfaces: + - `initial_inference`: Encodes an observation to produce an initial latent state and predictions (value, policy). + - `recurrent_inference`: Simulates dynamics by taking a history of latent states and actions to predict the next + latent state, reward, value, and policy. + """ - #@profile def __init__( self, observation_shape: SequenceType = (4, 64, 64), @@ -25,232 +32,253 @@ def __init__( num_channels: int = 64, activation: nn.Module = nn.GELU(approximate='tanh'), downsample: bool = True, - norm_type: Optional[str] = 'BN', + norm_type: str = 'BN', world_model_cfg: EasyDict = None, task_num: int = 1, - *args, - **kwargs - ): + *args: Any, + **kwargs: Any + ) -> None: """ Overview: - The definition of data procession in the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), including two main parts: - - initial_inference, which is used to predict the value, policy, and latent state based on the current observation. - - recurrent_inference, which is used to predict the value, policy, reward, and next latent state based on the current latent state and action. - The world model consists of three main components: - - a tokenizer, which encodes observations into embeddings, - - a transformer, which processes the input sequences, - - and heads, which generate the logits for observations, rewards, policy, and value. + Initializes the UniZeroMTModel, setting up the representation network, tokenizer, and world model + based on the provided configuration. Arguments: - - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[3, 64, 64] for Atari. - - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. - - num_res_blocks (:obj:`int`): The number of res blocks in UniZero model. - - num_channels (:obj:`int`): The channels of hidden states in representation network. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - - world_model_cfg (:obj:`EasyDict`): The configuration of the world model, including the following keys: - - obs_type (:obj:`str`): The type of observation, which can be 'image', 'vector', or 'image_memory'. - - embed_dim (:obj:`int`): The dimension of the embedding. - - group_size (:obj:`int`): The group size of the transformer. - - max_blocks (:obj:`int`): The maximum number of blocks in the transformer. - - max_tokens (:obj:`int`): The maximum number of tokens in the transformer. - - context_length (:obj:`int`): The context length of the transformer. - - device (:obj:`str`): The device of the model, which can be 'cuda' or 'cpu'. - - action_space_size (:obj:`int`): The shape of the action. - - num_layers (:obj:`int`): The number of layers in the transformer. 
- - num_heads (:obj:`int`): The number of heads in the transformer. - - policy_entropy_weight (:obj:`float`): The weight of the policy entropy. - - analysis_sim_norm (:obj:`bool`): Whether to analyze the similarity of the norm. + - observation_shape (:obj:`SequenceType`): The shape of the input observation, e.g., (C, H, W). + - action_space_size (:obj:`int`): The size of the discrete action space. + - num_res_blocks (:obj:`int`): The number of residual blocks in the ResNet-based representation network. + - num_channels (:obj:`int`): The number of channels in the ResNet-based representation network. + - activation (:obj:`nn.Module`): The activation function to use throughout the network. + - downsample (:obj:`bool`): Whether to downsample the observation in the representation network. + - norm_type (:obj:`str`): The type of normalization to use, e.g., 'BN' for BatchNorm. + - world_model_cfg (:obj:`EasyDict`): Configuration for the world model and its components. + - task_num (:obj:`int`): The number of tasks for multi-task learning. """ - super(UniZeroMTModel, self).__init__() + super().__init__() + print(f'========== Initializing UniZeroMTModel (num_res_blocks: {num_res_blocks}, num_channels: {num_channels}) ==========') - print(f'==========UniZeroMTModel, num_res_blocks:{num_res_blocks}, num_channels:{num_channels}===========') - - self.action_space_size = action_space_size - - # for multi-task - self.action_space_size = 18 + # --- Basic attribute setup --- self.task_num = task_num - self.activation = activation self.downsample = downsample world_model_cfg.norm_type = norm_type - assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, 'max_tokens should be 2 * max_blocks, because each timestep has 2 tokens: obs and action' + # NOTE: The action_space_size passed as an argument is immediately overridden. + # This might be intentional for specific experiments but is not a general practice. + self.action_space_size = 18 + + assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, \ + "max_tokens should be 2 * max_blocks, as each timestep consists of an observation and an action token." 
+ + # --- Determine embedding dimensions --- if world_model_cfg.task_embed_option == "concat_task_embed": - obs_act_embed_dim = world_model_cfg.embed_dim - world_model_cfg.task_embed_dim if hasattr(world_model_cfg, "task_embed_dim") else 96 + task_embed_dim = world_model_cfg.get("task_embed_dim", 32) # Default task_embed_dim to 32 if not specified + obs_act_embed_dim = world_model_cfg.embed_dim - task_embed_dim else: obs_act_embed_dim = world_model_cfg.embed_dim - if world_model_cfg.obs_type == 'vector': - self.representation_network = RepresentationNetworkMLP( - observation_shape, - hidden_channels=obs_act_embed_dim, - layer_num=2, - activation=self.activation, - group_size=world_model_cfg.group_size, - ) - # TODO: only for MemoryEnv now - self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25) - self.tokenizer = Tokenizer(encoder=self.representation_network, - decoder_network=self.decoder_network, with_lpips=False, obs_type=world_model_cfg.obs_type) - self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) - print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') - print('==' * 20) - print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') - print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') - print('==' * 20) - elif world_model_cfg.obs_type == 'image': - self.representation_network = nn.ModuleList() - # for task_id in range(self.task_num): # TODO: N independent encoder - for task_id in range(1): # TODO: one share encoder - self.representation_network.append(RepresentationNetworkUniZero( - observation_shape, - num_res_blocks, - num_channels, - self.downsample, - activation=self.activation, - norm_type=norm_type, - embedding_dim=obs_act_embed_dim, - group_size=world_model_cfg.group_size, - final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, - use_adaptive_scale=world_model_cfg.use_adaptive_scale, - )) - # self.representation_network = RepresentationNetworkUniZero( - # observation_shape, - # num_res_blocks, - # num_channels, - # self.downsample, - # activation=self.activation, - # norm_type=norm_type, - # embedding_dim=world_model_cfg.embed_dim, - # group_size=world_model_cfg.group_size, - # ) - # TODO: we should change the output_shape to the real observation shape - # self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64)) + # --- Initialize model components based on observation type --- + obs_type = world_model_cfg.obs_type + if obs_type == 'vector': + self._init_vector_components(world_model_cfg, obs_act_embed_dim) + elif obs_type == 'image': + self._init_image_components(world_model_cfg, observation_shape, num_res_blocks, num_channels, obs_act_embed_dim) + elif obs_type == 'image_memory': + self._init_image_memory_components(world_model_cfg) + else: + raise ValueError(f"Unsupported observation type: {obs_type}") + + # --- Initialize world model and tokenizer --- + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + + # --- Log parameter counts for analysis --- + self._log_model_parameters(obs_type) - # ====== for analysis ====== - if world_model_cfg.analysis_sim_norm: - self.encoder_hook = FeatureAndGradientHook() - self.encoder_hook.setup_hooks(self.representation_network) + def _init_vector_components(self, world_model_cfg: EasyDict, 
obs_act_embed_dim: int) -> None: + """Initializes components for 'vector' observation type.""" + self.representation_network = RepresentationNetworkMLP( + observation_shape=world_model_cfg.observation_shape, + hidden_channels=obs_act_embed_dim, + layer_num=2, + activation=self.activation, + group_size=world_model_cfg.group_size, + ) + # TODO: This is currently specific to MemoryEnv. Generalize if needed. + self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25) + self.tokenizer = Tokenizer( + encoder=self.representation_network, + decoder=self.decoder_network, + with_lpips=False, + obs_type=world_model_cfg.obs_type + ) - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False, obs_type=world_model_cfg.obs_type) - self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) - print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') - print('==' * 20) - print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') - print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') - print('==' * 20) - elif world_model_cfg.obs_type == 'image_memory': - # todo for concat_task_embed - self.representation_network = LatentEncoderForMemoryEnv( - image_shape=(3, 5, 5), - embedding_size=world_model_cfg.embed_dim, - channels=[16, 32, 64], - kernel_sizes=[3, 3, 3], - strides=[1, 1, 1], + def _init_image_components(self, world_model_cfg: EasyDict, observation_shape: SequenceType, num_res_blocks: int, + num_channels: int, obs_act_embed_dim: int) -> None: + """Initializes components for 'image' observation type.""" + self.representation_network = nn.ModuleList() + encoder_type = world_model_cfg.encoder_type + + # NOTE: Using a single shared encoder. The original code used a loop `for _ in range(1):`. + # To support N independent encoders, this logic would need to be modified. 
+ if encoder_type == "resnet": + encoder = RepresentationNetworkUniZero( + observation_shape=observation_shape, + num_res_blocks=num_res_blocks, + num_channels=num_channels, + downsample=self.downsample, activation=self.activation, + norm_type=world_model_cfg.norm_type, + embedding_dim=obs_act_embed_dim, group_size=world_model_cfg.group_size, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, ) - self.decoder_network = LatentDecoderForMemoryEnv( - image_shape=(3, 5, 5), - embedding_size=world_model_cfg.embed_dim, - channels=[64, 32, 16], - kernel_sizes=[3, 3, 3], - strides=[1, 1, 1], - activation=self.activation, - ) + self.representation_network.append(encoder) + elif encoder_type == "vit": + vit_configs = { + 'small': {'dim': 768, 'depth': 6, 'heads': 6, 'mlp_dim': 2048}, + 'base': {'dim': 768, 'depth': 12, 'heads': 12, 'mlp_dim': 3072}, + 'large': {'dim': 1024, 'depth': 24, 'heads': 16, 'mlp_dim': 4096}, + } + vit_size = 'base' if self.task_num > 8 else 'small' + selected_vit_config = vit_configs[vit_size] - if world_model_cfg.analysis_sim_norm: - # ====== for analysis ====== - self.encoder_hook = FeatureAndGradientHook() - self.encoder_hook.setup_hooks(self.representation_network) + vit_params = { + 'image_size': observation_shape[1], + 'patch_size': 8, + 'num_classes': obs_act_embed_dim, + 'dropout': 0.1, + 'emb_dropout': 0.1, + 'final_norm_option_in_encoder': world_model_cfg.final_norm_option_in_encoder, + 'lora_config': world_model_cfg, + **selected_vit_config + } + vit_config = ViTConfig(**vit_params) + encoder = ViT(config=vit_config) + + self.representation_network.append(encoder) + else: + raise ValueError(f"Unsupported encoder type for image observations: {encoder_type}") - self.tokenizer = Tokenizer(with_lpips=True, encoder=self.representation_network, - decoder_network=self.decoder_network, obs_type=world_model_cfg.obs_type) - self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) - print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') - print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') + # For image observations, the decoder is currently not used for reconstruction during training. + self.decoder_network = None + self.tokenizer = Tokenizer( + encoder=self.representation_network, + decoder=self.decoder_network, + with_lpips=False, + obs_type=world_model_cfg.obs_type + ) + if world_model_cfg.analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) - print('==' * 20) - print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') - print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') - print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') - print('==' * 20) + def _init_image_memory_components(self, world_model_cfg: EasyDict) -> None: + """Initializes components for 'image_memory' observation type.""" + # TODO: The 'concat_task_embed' option needs to be fully implemented for this obs_type. 
+ self.representation_network = LatentEncoderForMemoryEnv( + image_shape=(3, 5, 5), + embedding_size=world_model_cfg.embed_dim, + channels=[16, 32, 64], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation=self.activation, + group_size=world_model_cfg.group_size, + ) + self.decoder_network = LatentDecoderForMemoryEnv( + image_shape=(3, 5, 5), + embedding_size=world_model_cfg.embed_dim, + channels=[64, 32, 16], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation=self.activation, + ) + self.tokenizer = Tokenizer( + encoder=self.representation_network, + decoder=self.decoder_network, + with_lpips=True, + obs_type=world_model_cfg.obs_type + ) + if world_model_cfg.analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + def _log_model_parameters(self, obs_type: str) -> None: + """Logs the parameter counts of the main model components.""" + print('--------------------------------------------------') + print(f'{sum(p.numel() for p in self.world_model.parameters()):,} parameters in world_model') + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters()):,} parameters in world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters()):,} parameters in tokenizer.encoder') - #@profile - def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput: + if obs_type in ['vector', 'image_memory'] and self.tokenizer.decoder_network is not None: + print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters()):,} parameters in tokenizer.decoder_network') + if obs_type == 'image_memory': + # Calculate parameters excluding decoder and LPIPS for a specific comparison point. + params_without_decoder = sum(p.numel() for p in self.world_model.parameters()) - \ + sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - \ + sum(p.numel() for p in self.tokenizer.lpips.parameters()) + print(f'{params_without_decoder:,} parameters in world_model (excluding decoder and lpips)') + print('--------------------------------------------------') + + def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torch.Tensor] = None, + current_obs_batch: Optional[torch.Tensor] = None, task_id: Optional[Any] = None) -> MZNetworkOutput: """ Overview: - Initial inference of UniZero model, which is the first step of the UniZero model. - To perform the initial inference, we first use the representation network to obtain the ``latent_state``. - Then we use the prediction network to predict ``value`` and ``policy_logits`` of the ``latent_state``. + Performs the initial inference step of the model, corresponding to the representation function `h` in MuZero. + It takes an observation and produces a latent state and initial predictions. Arguments: - - obs_batch (:obj:`torch.Tensor`): The 3D image observation data. - Returns (MZNetworkOutput): - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. \ - In initial inference, we set it to zero vector. - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. 
- - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ - latent state, W_ is the width of latent state. - """ + - obs_batch (:obj:`torch.Tensor`): A batch of initial observations. + - action_batch (:obj:`Optional[torch.Tensor]`): A batch of actions (if available, context-dependent). + - current_obs_batch (:obj:`Optional[torch.Tensor]`): A batch of current observations (if different from obs_batch). + - task_id (:obj:`Optional[Any]`): Identifier for the current task in a multi-task setting. + Returns: + - MZNetworkOutput: An object containing the predicted value, policy logits, and the initial latent state. + The reward is set to a zero tensor, as it's not predicted at the initial step. + """ batch_size = obs_batch.size(0) - # print('=here 5='*20) - # import ipdb; ipdb.set_trace() obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch} - _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict, task_id=task_id) - latent_state, reward, policy_logits, value = obs_token, logits_rewards, logits_policy, logits_value - policy_logits = policy_logits.squeeze(1) - value = value.squeeze(1) + + _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference( + obs_act_dict, task_id=task_id + ) + + # The world model returns tokens and logits; map them to the standard MZNetworkOutput format. + latent_state = obs_token + policy_logits = logits_policy.squeeze(1) + value = logits_value.squeeze(1) return MZNetworkOutput( - value, - [0. for _ in range(batch_size)], - policy_logits, - latent_state, + value=value, + reward=torch.zeros(batch_size, device=value.device), # Reward is 0 at initial inference + policy_logits=policy_logits, + latent_state=latent_state, ) - #@profile - def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index=0, - latent_state_index_in_search_path=[], task_id=None) -> MZNetworkOutput: + def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index: int = 0, + search_depth: List = [], task_id: Optional[Any] = None) -> MZNetworkOutput: """ Overview: - Recurrent inference of UniZero model.To perform the recurrent inference, we concurrently predict the latent dynamics (reward/next_latent_state) - and decision-oriented quantities (value/policy) conditioned on the learned latent history in the world_model. + Performs a recurrent inference step, corresponding to the dynamics function `g` and prediction + function `f` in MuZero. It predicts the next latent state, reward, policy, and value based on a + history of latent states and actions. Arguments: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - action (:obj:`torch.Tensor`): The predicted action to rollout. - Returns (MZNetworkOutput): - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. 
- - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. - Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. - - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ - latent state, W_ is the width of latent state. - - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ - latent state, W_ is the width of latent state. - """ + - state_action_history (:obj:`torch.Tensor`): A tensor representing the history of latent states and actions. + - simulation_index (:obj:`int`): The index of the current simulation step within MCTS. + - search_depth (:obj:`List`): Information about the search depth, used for positional embeddings. + - task_id (:obj:`Optional[Any]`): Identifier for the current task in a multi-task setting. + Returns: + - MZNetworkOutput: An object containing the predicted value, reward, policy logits, and the next latent state. + """ _, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_recurrent_inference( - state_action_history, simulation_index, latent_state_index_in_search_path, task_id=task_id) - next_latent_state, reward, policy_logits, value = logits_observations, logits_rewards, logits_policy, logits_value - policy_logits = policy_logits.squeeze(1) - value = value.squeeze(1) - reward = reward.squeeze(1) - return MZNetworkOutput(value, reward, policy_logits, next_latent_state) \ No newline at end of file + state_action_history, simulation_index, search_depth, task_id=task_id + ) + + # Map the world model outputs to the standard MZNetworkOutput format. + next_latent_state = logits_observations + reward = logits_rewards.squeeze(1) + policy_logits = logits_policy.squeeze(1) + value = logits_value.squeeze(1) + + return MZNetworkOutput( + value=value, + reward=reward, + policy_logits=policy_logits, + latent_state=next_latent_state, + ) \ No newline at end of file diff --git a/lzero/model/unizero_world_models/kv_caching.py b/lzero/model/unizero_world_models/kv_caching.py index f373739c6..cf040b13a 100644 --- a/lzero/model/unizero_world_models/kv_caching.py +++ b/lzero/model/unizero_world_models/kv_caching.py @@ -1,165 +1,254 @@ -# Modified from https://github.com/eloialonso/iris/blob/main/src/models/kv_caching.py +# -*- coding: utf-8 -*- +""" +This script is a refactored version of the key-value caching mechanism from: +https://github.com/eloialonso/iris/blob/main/src/models/kv_caching.py -from typing import Tuple +The optimization focuses on improving clarity, documentation, and adherence to modern coding standards +while strictly preserving the original functionality and external API. +""" +from typing import Tuple, Optional import numpy as np import torch +class AssignWithoutInplaceCheck(torch.autograd.Function): + """ + Overview: + A custom autograd function to perform an in-place-like assignment on a tensor slice + without triggering PyTorch's version counter checks. This is useful for updating + buffers or caches within a computation graph. 
+ + Reference: + Inspired by discussions on the PyTorch forums, such as: + https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4 + + .. warning:: + This function is unsafe if the same slice of the input tensor is overwritten + multiple times, as it can lead to incorrect gradient calculations. + """ + + @staticmethod + def _get_slice(dim: int, start: int, stop: int) -> Tuple[slice, ...]: + """ + Overview: + Creates a slice tuple for indexing a tensor at a specific dimension. + Arguments: + - dim (:obj:`int`): The dimension to slice along. + - start (:obj:`int`): The starting index for the slice. + - stop (:obj:`int`): The ending index for the slice. + Returns: + - slice_tuple (:obj:`Tuple[slice, ...]`): A tuple of slice objects for indexing. + """ + return (slice(None),) * dim + (slice(start, stop),) + + @staticmethod + def forward( + ctx, + input_tensor: torch.Tensor, + value: torch.Tensor, + dim: int, + start: int, + stop: int + ) -> torch.Tensor: + """ + Overview: + The forward pass assigns the `value` tensor to a slice of the `input_tensor`. + Arguments: + - ctx: The context object for storing information for the backward pass. + - input_tensor (:obj:`torch.Tensor`): The tensor to be modified. + - value (:obj:`torch.Tensor`): The tensor to assign to the slice. + - dim (:obj:`int`): The dimension along which to perform the assignment. + - start (:obj:`int`): The starting index of the slice. + - stop (:obj:`int`): The ending index of the slice. + Returns: + - modified_tensor (:obj:`torch.Tensor`): The `input_tensor` after modification. + """ + ctx.dim = dim + ctx.start = start + ctx.stop = stop + # Directly modify the data of the input tensor to bypass version checks. + input_tensor.data[AssignWithoutInplaceCheck._get_slice(dim, start, stop)] = value + return input_tensor + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]: + """ + Overview: + The backward pass computes gradients for the inputs of the forward pass. + Arguments: + - ctx: The context object with saved information from the forward pass. + - grad_output (:obj:`torch.Tensor`): The gradient of the output tensor. + Returns: + - grad_input_tensor (:obj:`torch.Tensor`): The gradient with respect to `input_tensor`. + - grad_value (:obj:`torch.Tensor`): The gradient with respect to `value`. + - None, None, None: Gradients for `dim`, `start`, and `stop`, which are not needed. + """ + # The gradient for the original input tensor is the same as the output gradient. + grad_input_tensor = grad_output + # The gradient for the value tensor is the slice of the output gradient. + grad_value = grad_output[AssignWithoutInplaceCheck._get_slice(ctx.dim, ctx.start, ctx.stop)] + return grad_input_tensor, grad_value, None, None, None + + class Cache: + """ + Overview: + A cache for storing a single type of intermediate tensor (e.g., keys or values) + in a Transformer-like model. It handles dynamic updates and size management. + """ + def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None: """ Overview: - Cache for storing intermediate results in a transformer model. + Initializes the cache. Arguments: - - num_samples (:obj:`int`): The number of samples to cache. + - num_samples (:obj:`int`): The number of samples (batch size) to cache. - num_heads (:obj:`int`): The number of attention heads. - - max_tokens (:obj:`int`): The maximum number of tokens. 
- - embed_dim (:obj:`int`): The dimension of the embeddings. - - device (:obj:`torch.device`): The device on which to store the cache. + - max_tokens (:obj:`int`): The maximum number of tokens the cache can hold. + - embed_dim (:obj:`int`): The total dimension of the embeddings. + - device (:obj:`torch.device`): The device on which to store the cache tensor. """ - assert embed_dim % num_heads == 0 - self._num_samples, self._cache, self._size = num_samples, None, None - self._reset = lambda n: torch.empty(n, num_heads, max_tokens, embed_dim // num_heads, device=device) # (B, nh, T, hs) + if embed_dim % num_heads != 0: + raise ValueError(f"Embedding dimension ({embed_dim}) must be divisible by the number of heads ({num_heads}).") + + self._num_samples = num_samples + self._num_heads = num_heads + self._max_tokens = max_tokens + self._head_dim = embed_dim // num_heads + self._device = device + + self._cache: torch.Tensor = self._create_cache_tensor(self._num_samples) + self._size: int = 0 self.reset() + def _create_cache_tensor(self, num_samples: int) -> torch.Tensor: + """ + Overview: + Creates an empty tensor with the correct shape and device for the cache. + Arguments: + - num_samples (:obj:`int`): The number of samples for which to create the cache. + Returns: + - empty_cache (:obj:`torch.Tensor`): An uninitialized tensor for the cache. + """ + return torch.empty( + num_samples, self._num_heads, self._max_tokens, self._head_dim, device=self._device + ) # Shape: (B, nh, T, hs) + @property def shape(self) -> Tuple[int, int, int, int]: """ Overview: - Get the shape of the cache. + Gets the effective shape of the cache's content. Returns: - - shape (:obj:`Tuple[int, int, int, int]`): The shape of the cache. + - shape (:obj:`Tuple[int, int, int, int]`): A tuple representing (num_samples, num_heads, current_size, head_dim). """ - n, num_heads, _, head_dim = self._cache.shape - return n, num_heads, self._size, head_dim + return self._num_samples, self._num_heads, self._size, self._head_dim def reset(self) -> None: """ Overview: - Reset the cache to its initial state. + Resets the cache to an empty state. """ - self._cache = self._reset(self._num_samples) + self._cache = self._create_cache_tensor(self._num_samples) self._size = 0 def prune(self, mask: np.ndarray) -> None: """ Overview: - Prune the cache based on a mask. + Prunes the cache along the sample dimension using a boolean mask. Arguments: - - mask (:obj:`np.ndarray`): A boolean mask indicating which samples to keep. + - mask (:obj:`np.ndarray`): A 1D boolean array where `True` indicates which samples to keep. """ - assert mask.ndim == 1 and mask.shape[0] == self.shape[0] + if not (mask.ndim == 1 and mask.shape[0] == self._num_samples): + raise ValueError("Mask must be a 1D numpy array with length equal to the number of samples.") self._cache = self._cache[mask] self._num_samples = self._cache.shape[0] def get(self) -> torch.Tensor: """ Overview: - Get the current contents of the cache. + Retrieves the current contents of the cache. Returns: - - cache (:obj:`torch.Tensor`): The current contents of the cache. + - cache_content (:obj:`torch.Tensor`): A tensor containing the valid data in the cache. """ return self._cache[:, :, :self._size, :] def update(self, x: torch.Tensor, tokens: int) -> None: """ Overview: - Update the cache with new values. + Updates the cache with new tensor values. If the cache is full, it discards the oldest + tokens to make space. Arguments: - - x (:obj:`torch.Tensor`): The new values to update the cache with. 
- - tokens (:obj:`int`): The number of tokens to update. - """ - try: - # Calculate the required capacity after adding the new tokens - required_capacity = self._size + tokens - # print(f'self._size:{self._size}, tokens:{tokens}') - - # Check if the cache has enough space to accommodate the new tokens, - # kv_cache, z/a, register_token - # 这样修复后kv_cache的位置编码不是从0开始的, 那后面按照从零开始矫正也就是错误的, - # 但是由于self.keys_values_wm._keys_values[layer]._k_cache._size < context_length - 1,所以不会矫正 - # 但是在_add_position_embeddings时,prev_steps是错误的,导致新增的z/a的位置编码索引与前面的kv不连续 - if required_capacity > self._cache.shape[2]: - # Shift existing cache data by removing the oldest entries - shift_amount = required_capacity - self._cache.shape[2] - # =======TODO: 应该去掉偶数个(z,a)以保证 head 输出pattern保持不变======= - if shift_amount % 2 != 0: - shift_amount = shift_amount + 1 - # print(f'required_capacity:{required_capacity}, self._cache.shape[2]:{self._cache.shape[2]}, shift_amount:{shift_amount}') - if shift_amount >= self._size: - # If the shift amount exceeds or equals the current size, just reset the cache - print("Cache too small; resetting the entire cache") - self._cache = torch.zeros_like(self._cache) # Reset cache to zeros - self._size = 0 # Reset size - else: - # Shift the cache to make room for new data - self._cache[:, :, :self._size - shift_amount, :] = self._cache[:, :, shift_amount:self._size, :] - self._size -= shift_amount # Update the size after shifting - - # Update the cache with new values - self._cache = AssignWithoutInplaceCheck.apply( - self._cache, x, 2, self._size, self._size + tokens - ) - self._size += tokens # Update the size after adding new values - - except Exception as e: - print(f"An error occurred during cache update: {e}") - - # def update(self, x: torch.Tensor, tokens: int) -> None: - # """ - # Overview: - # Update the cache with new values. - # Arguments: - # - x (:obj:`torch.Tensor`): The new values to update the cache with. - # - tokens (:obj:`int`): The number of tokens to update. - # """ - # # assert (x.ndim == self._cache.ndim) and all([x.size(i) == self._cache.size(i) for i in (0, 1, 3)]) - # # assert self._size + tokens <= self._cache.shape[2] # TODO - # try: - # self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 2, self._size, self._size + tokens) - # self._size += tokens - # except Exception as e: - # print(e) - # # import ipdb; ipdb.set_trace() - + - x (:obj:`torch.Tensor`): The new tensor data to add to the cache. + - tokens (:obj:`int`): The number of tokens being added (sequence length of `x`). + """ + required_capacity = self._size + tokens + + # If the new tokens exceed the cache's maximum capacity, shift existing data to make room. + if required_capacity > self._max_tokens: + shift_amount = required_capacity - self._max_tokens + + # This logic is crucial for models like MuZero where tokens are added in (state, action) pairs. + # To maintain the integrity of these pairs, an even number of tokens must be discarded. + if shift_amount % 2 != 0: + shift_amount += 1 + + if shift_amount >= self._size: + # If the required shift is larger than the current cache size, it's more efficient to reset. + self._cache.zero_() + self._size = 0 + else: + # Shift the existing cache content to the left, discarding the oldest tokens. + self._cache[:, :, :self._size - shift_amount, :] = self._cache[:, :, shift_amount:self._size, :] + self._size -= shift_amount + # NOTE: Shifting the cache invalidates absolute positional embeddings. + # The parent model must handle positional encoding adjustments. 
For example, if positional + # embeddings are calculated based on `prev_steps`, this shift means `prev_steps` may no + # longer correspond to the true start, potentially causing discontinuities. + + # Use the custom autograd function to assign the new data without inplace errors. + self._cache = AssignWithoutInplaceCheck.apply( + self._cache, x, 2, self._size, self._size + tokens + ) + self._size += tokens class KVCache: - def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None: + """ + Overview: + A container for a pair of caches: one for keys (K) and one for values (V), + typically used in a single attention layer of a Transformer. + """ + + def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None: """ Overview: - Cache for storing key and value tensors in a transformer model. + Initializes the Key-Value cache pair. Arguments: - - n (:obj:`int`): The number of samples to cache. + - num_samples (:obj:`int`): The number of samples (batch size) to cache. - num_heads (:obj:`int`): The number of attention heads. - - max_tokens (:obj:`int`): The maximum number of tokens. - - embed_dim (:obj:`int`): The dimension of the embeddings. - - device (:obj:`torch.device`): The device on which to store the cache. + - max_tokens (:obj:`int`): The maximum number of tokens the cache can hold. + - embed_dim (:obj:`int`): The total dimension of the embeddings. + - device (:obj:`torch.device`): The device on which to store the cache tensors. """ - self._k_cache = Cache(n, num_heads, max_tokens, embed_dim, device) - self._v_cache = Cache(n, num_heads, max_tokens, embed_dim, device) - - # self.register_token_num = 2 # Number of register tokens TODO====== - - # def set_register_token_num(self, num: int) -> None: - # """Set the number of register tokens.""" - # self.register_token_num = num + self._k_cache = Cache(num_samples, num_heads, max_tokens, embed_dim, device) + self._v_cache = Cache(num_samples, num_heads, max_tokens, embed_dim, device) @property def shape(self) -> Tuple[int, int, int, int]: """ Overview: - Get the shape of the key cache. + Gets the effective shape of the key cache's content. Returns: - - shape (:obj:`Tuple[int, int, int, int]`): The shape of the key cache. + - shape (:obj:`Tuple[int, int, int, int]`): Shape of the key cache (num_samples, num_heads, current_size, head_dim). """ return self._k_cache.shape def reset(self) -> None: """ Overview: - Reset both key and value caches to their initial states. + Resets both the key and value caches to their empty states. """ self._k_cache.reset() self._v_cache.reset() @@ -167,9 +256,9 @@ def reset(self) -> None: def prune(self, mask: np.ndarray) -> None: """ Overview: - Prune both key and value caches based on a mask. + Prunes both key and value caches based on a boolean mask. Arguments: - - mask (:obj:`np.ndarray`): A boolean mask indicating which samples to keep. + - mask (:obj:`np.ndarray`): A 1D boolean array indicating which samples to keep. """ self._k_cache.prune(mask) self._v_cache.prune(mask) @@ -177,71 +266,94 @@ def prune(self, mask: np.ndarray) -> None: def get(self) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: - Get the current contents of the key and value caches. + Retrieves the current contents of the key and value caches. Returns: - key_cache (:obj:`torch.Tensor`): The current contents of the key cache. - value_cache (:obj:`torch.Tensor`): The current contents of the value cache. 
""" return self._k_cache.get(), self._v_cache.get() - def update(self, k: torch.Tensor, v: torch.Tensor): + def update(self, k: torch.Tensor, v: torch.Tensor) -> None: """ Overview: - Update both key and value caches with new values. - If `is_register_token` is True, prepend the register tokens to the cache. + Updates both key and value caches with new tensors. + Arguments: + - k (:obj:`torch.Tensor`): The new key tensor to add. + - v (:obj:`torch.Tensor`): The new value tensor to add. """ - self._k_cache.update(k, k.size(2)) - self._v_cache.update(v, v.size(2)) + # The number of tokens is inferred from the sequence dimension (dim 2). + num_tokens = k.size(2) + self._k_cache.update(k, num_tokens) + self._v_cache.update(v, num_tokens) + class KeysValues: - def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device) -> None: + """ + Overview: + Manages a collection of KVCache objects, one for each layer in a Transformer model. + """ + + def __init__( + self, + num_samples: int, + num_heads: int, + max_tokens: int, + embed_dim: int, + num_layers: int, + device: torch.device + ) -> None: """ Overview: - Class for managing multiple layers of key and value caches in a transformer model. + Initializes KV caches for all layers. Arguments: - - n (:obj:`int`): The number of samples to cache. + - num_samples (:obj:`int`): The number of samples (batch size). - num_heads (:obj:`int`): The number of attention heads. - - max_tokens (:obj:`int`): The maximum number of tokens. + - max_tokens (:obj:`int`): The maximum number of tokens per cache. - embed_dim (:obj:`int`): The dimension of the embeddings. - - num_layers (:obj:`int`): The number of layers in the transformer model. - - device (:obj:`torch.device`): The device on which to store the caches. + - num_layers (:obj:`int`): The number of layers in the Transformer model. + - device (:obj:`torch.device`): The device for storing cache tensors. """ - self._keys_values = tuple([KVCache(n, num_heads, max_tokens, embed_dim, device) for _ in range(num_layers)]) + self._keys_values = tuple([ + KVCache(num_samples, num_heads, max_tokens, embed_dim, device) for _ in range(num_layers) + ]) - def __getitem__(self, index: int) -> KVCache: + def __getitem__(self, layer_index: int) -> KVCache: """ Overview: - Get the key and value cache for a specific layer. + Retrieves the KVCache for a specific layer. Arguments: - - index (:obj:`int`): The layer index. + - layer_index (:obj:`int`): The index of the layer. Returns: - - kv_cache (:obj:`KVCache`): The key and value cache for the specified layer. + - kv_cache (:obj:`KVCache`): The key-value cache for the specified layer. """ - return self._keys_values[index] + return self._keys_values[layer_index] - def __len__(self): + def __len__(self) -> int: """ Overview: - Get the number of layers in the transformer model. + Gets the number of layers. Returns: - - length (:obj:`int`): The number of layers. + - num_layers (:obj:`int`): The number of layers being managed. """ return len(self._keys_values) @property - def size(self): + def size(self) -> int: """ Overview: - Get the size of the tokens in the cache. + Gets the current number of tokens stored in the caches. Returns: - - size (:obj:`int`): The size of the tokens in the cache. + - size (:obj:`int`): The number of tokens in the cache (assumes all layers have the same size). """ + # All layer caches are synchronized, so we can check the size of the first one. 
+ if not self._keys_values: + return 0 return self._keys_values[0].shape[2] def reset(self) -> None: """ Overview: - Reset all key and value caches to their initial states. + Resets the KV caches for all layers. """ for kv_cache in self._keys_values: kv_cache.reset() @@ -249,82 +361,27 @@ def reset(self) -> None: def prune(self, mask: np.ndarray) -> None: """ Overview: - Prune all key and value caches based on a mask. + Prunes the KV caches for all layers based on a mask. Arguments: - mask (:obj:`np.ndarray`): A boolean mask indicating which samples to keep. """ for kv_cache in self._keys_values: kv_cache.prune(mask) - def remove_register_tokens(self, register_token_num: int): - """ - Overview: - 移除所有层 KV 缓存开头的 Register Token。 - 在推理结束后调用,保证外层看到的 KV 不包含 Register Token。 - """ - # import ipdb; ipdb.set_trace() - for kv_cache in self._keys_values: - # 移除 KVCache 中后面的 register_token_num 个 token - kv_cache._k_cache._size -= register_token_num - kv_cache._v_cache._size -= register_token_num - - -class AssignWithoutInplaceCheck(torch.autograd.Function): - """ - Overview: - Custom autograd function to perform in-place assignment without triggering version checks. - Inspired from: - https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4 - - .. warning: - Do not use it to overwrite a slice twice. - """ - - @staticmethod - def get_slice(dim: int, start: int, stop: int) -> Tuple[slice]: - """ - Overview: - Get the slice object for the given dimension and range. - Arguments: - - dim (:obj:`int`): The dimension along which to slice. - - start (:obj:`int`): The start index of the slice. - - stop (:obj:`int`): The stop index of the slice. - Returns: - - slice (:obj:`Tuple[slice]`): The slice object. - """ - return tuple([slice(None), ] * dim + [slice(start, stop)]) - - @staticmethod - def forward(ctx, input: torch.Tensor, value: torch.Tensor, dim: int, start: int, stop: int) -> torch.Tensor: - """ - Overview: - Forward pass of the custom autograd function. - Arguments: - - ctx: The context object to store information for backward computation. - - input (:obj:`torch.Tensor`): The input tensor to be modified. - - value (:obj:`torch.Tensor`): The value tensor to assign to the input. - - dim (:obj:`int`): The dimension along which to assign the value. - - start (:obj:`int`): The start index of the assignment. - - stop (:obj:`int`): The stop index of the assignment. - Returns: - - output (:obj:`torch.Tensor`): The modified input tensor. - """ - ctx.dim = dim - ctx.start = start - ctx.stop = stop - input.data[AssignWithoutInplaceCheck.get_slice(dim, start, stop)] = value - return input - - @staticmethod - def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor]: + def remove_register_tokens(self, register_token_num: int) -> None: """ Overview: - Backward pass of the custom autograd function. + Removes the last `register_token_num` tokens from the active view of the cache + in each layer by adjusting the internal size pointer. This does not delete the data + but makes it invisible to subsequent `get` and `update` calls. + This is typically called after an inference step that used temporary tokens + (e.g., register tokens) to ensure they are not part of the ongoing context. Arguments: - - ctx: The context object storing information from forward computation. - - grad_out (:obj:`torch.Tensor`): The gradient of the output tensor. - Returns: - - grad_input (:obj:`torch.Tensor`): The gradient of the input tensor. 
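For readers skimming this refactor, the sketch below shows how the new `Cache` / `KVCache` / `KeysValues` API fits together, using only methods defined in this file. The concrete sizes are illustrative, and the import path follows the `from .kv_caching import KeysValues` usage elsewhere in this patch.

```Python
import numpy as np
import torch

from lzero.model.unizero_world_models.kv_caching import KeysValues

# Illustrative sizes only: 2 samples, 4 heads, 16-token capacity, 32-dim embeddings, 3 layers.
num_samples, num_heads, max_tokens, embed_dim, num_layers = 2, 4, 16, 32, 3
device = torch.device("cpu")
head_dim = embed_dim // num_heads

kvs = KeysValues(num_samples, num_heads, max_tokens, embed_dim, num_layers, device)
assert kvs.size == 0

# During autoregressive decoding, each attention layer appends its new key/value slices.
for layer in range(num_layers):
    k_new = torch.randn(num_samples, num_heads, 2, head_dim)  # 2 new tokens per step
    v_new = torch.randn(num_samples, num_heads, 2, head_dim)
    kvs[layer].update(k_new, v_new)

assert kvs.size == 2                   # all layer caches stay in sync
k_cached, v_cached = kvs[0].get()      # each of shape (2, 4, 2, head_dim)

# Keep only the first sample, e.g. after the other environment terminates.
kvs.prune(np.array([True, False]))
```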
- - grad_value (:obj:`torch.Tensor`): The gradient of the value tensor. + - register_token_num (:obj:`int`): The number of tokens to remove from the end of the cache view. """ - return grad_out, grad_out[AssignWithoutInplaceCheck.get_slice(ctx.dim, ctx.start, ctx.stop)], None, None, None \ No newline at end of file + if register_token_num <= 0: + return + for kv_cache in self._keys_values: + # Decrement the size pointer for both K and V caches. + kv_cache._k_cache._size = max(0, kv_cache._k_cache._size - register_token_num) + kv_cache._v_cache._size = max(0, kv_cache._v_cache._size - register_token_num) \ No newline at end of file diff --git a/lzero/model/unizero_world_models/lpips.py b/lzero/model/unizero_world_models/lpips.py index 7abd5c062..2afa15a83 100644 --- a/lzero/model/unizero_world_models/lpips.py +++ b/lzero/model/unizero_world_models/lpips.py @@ -20,16 +20,14 @@ def __init__(self, use_dropout: bool = True): super().__init__() self.scaling_layer = ScalingLayer() self.chns = [64, 128, 256, 512, 512] # vg16 features + # Comment out the following line if you don't need perceptual loss # self.net = vgg16(pretrained=True, requires_grad=False) - # self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) # self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) # self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) # self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) # self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) - - # Comment out the following line if you don't need perceptual loss # self.load_from_pretrained() # for param in self.parameters(): # param.requires_grad = False diff --git a/lzero/model/unizero_world_models/moe.py b/lzero/model/unizero_world_models/moe.py index 159afd69e..53f0c5620 100644 --- a/lzero/model/unizero_world_models/moe.py +++ b/lzero/model/unizero_world_models/moe.py @@ -1,49 +1,273 @@ import dataclasses -from typing import List +from typing import List, Any import torch import torch.nn.functional as F from simple_parsing.helpers import Serializable from torch import nn +from lzero.model.unizero_world_models.transformer import _maybe_wrap_linear + +# Note: The following lines are examples of how _maybe_wrap_linear might be used. +# _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim), config, "feed_forward") + +# This implementation is inspired by the following sources: +# https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/moe.py +# https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/transformer_layers.py#L149 # Modified from https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/transformer.py#L108 + + class MultiplicationFeedForward(nn.Module): - def __init__(self, config): - super().__init__() + """ + Overview: + Implements the SwiGLU (Swish-Gated Linear Unit) feed-forward layer, a variant of a transformer feed-forward network + that uses element-wise multiplication of two linear projections, one of which is passed through a SiLU activation. + This is often expressed as: FFN_SwiGLU(x) = (SiLU(x @ W1) * (x @ W3)) @ W2. + """ - self.w1 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False) - self.w2 = nn.Linear(4 * config.embed_dim, config.embed_dim, bias=False) - self.w3 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False) + def __init__(self, config: Any) -> None: + """ + Overview: + Initializes the MultiplicationFeedForward layer. 
+ Arguments: + - config (:obj:`Any`): A configuration object containing model hyperparameters. + It is expected to have `embed_dim` (int) and `moe_use_lora` (bool). + """ + super().__init__() + hidden_dim = 4 * config.embed_dim + if config.moe_use_lora: + self.w1 = _maybe_wrap_linear(nn.Linear(config.embed_dim, hidden_dim, bias=False), config, "feed_forward") + self.w2 = _maybe_wrap_linear(nn.Linear(hidden_dim, config.embed_dim, bias=False), config, "feed_forward") + self.w3 = _maybe_wrap_linear(nn.Linear(config.embed_dim, hidden_dim, bias=False), config, "feed_forward") + else: + self.w1 = nn.Linear(config.embed_dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, config.embed_dim, bias=False) + self.w3 = nn.Linear(config.embed_dim, hidden_dim, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore + """ + Overview: + Performs the forward pass of the SwiGLU layer. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor. + Returns: + - torch.Tensor: The output tensor after applying the SwiGLU transformation. + """ + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + @dataclasses.dataclass class MoeArgs(Serializable): - num_experts: int - num_experts_per_tok: int + """ + Overview: + Dataclass for storing Mixture-of-Experts (MoE) configuration arguments. + """ + num_experts: int # The total number of experts in the MoE layer. + num_experts_per_tok: int # The number of experts to route each token to (k). + + +class MoELayer(nn.Module): + """ + Overview: + A straightforward implementation of a Mixture-of-Experts (MoE) layer. + This version iterates through each expert and processes the tokens routed to it. + While clear and easy to understand, it can be less efficient than vectorized approaches. + + The process is as follows: + 1. The input tensor `x` is flattened from [B, T, D] to [N, D], where N = B * T. + 2. A gating network calculates logits for each token to determine expert assignment. + 3. For each token, the top-k experts are selected based on the logits. + 4. The layer iterates through each expert, gathers all tokens assigned to it, + and computes their outputs. + 5. The outputs are weighted by the gating scores and summed up. + 6. An optional shared expert can be applied to all tokens. + 7. The final tensor is reshaped to its original shape [B, T, D]. + Attributes: + - dim (:obj:`int`): The dimension of the input features. + - num_experts (:obj:`int`): The total number of experts. + - num_experts_per_tok (:obj:`int`): The number of experts activated per token (top-k). + - gate (:obj:`nn.Module`): The gating network that produces routing logits. + - experts (:obj:`nn.ModuleList`): A list of expert networks. + - shared_expert (:obj:`nn.Module` or `None`): An optional shared expert applied to all tokens. + """ -class MoeLayer(nn.Module): - def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok=1): + def __init__(self, config: Any, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok: int = 1) -> None: + """ + Overview: + Initializes the MoELayer. + Arguments: + - config (:obj:`Any`): A configuration object. Expected to have `embed_dim` and optionally `n_shared_experts`. + - experts (:obj:`List[nn.Module]`): A list of PyTorch modules representing the experts. + - gate (:obj:`nn.Module`): The gating module for routing tokens. + - num_experts_per_tok (:obj:`int`): The number of experts to use for each token. 
+ """ super().__init__() - assert len(experts) > 0 - self.experts = nn.ModuleList(experts) + self.dim = config.embed_dim + self.num_experts = len(experts) + self.num_experts_per_tok = num_experts_per_tok self.gate = gate + self.experts = nn.ModuleList(experts) + + # If specified in the config, create a shared expert branch. + if hasattr(config, "n_shared_experts") and config.n_shared_experts > 0: + # TODO: The architecture of the shared expert could be made more configurable. + self.shared_expert = nn.Sequential( + nn.Linear(self.dim, config.n_shared_experts * (4 * self.dim)), + nn.GELU(), + nn.Linear(config.n_shared_experts * (4 * self.dim), self.dim) + ) + else: + self.shared_expert = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Performs the forward pass for the MoE layer. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor of shape [batch_size, seq_len, dim]. + Returns: + - torch.Tensor: The output tensor with the same shape as the input. + """ + # Store original shape and flatten input to 2D: [batch_size * seq_len, dim] + original_shape = x.size() + x = x.view(-1, self.dim) + + # Compute gate logits, shape: [num_tokens, num_experts] + gate_logits = self.gate(x) + # Select top-k experts for each token. + weights, indices = torch.topk(gate_logits, self.num_experts_per_tok, dim=1) + # Normalize the weights of selected experts using softmax. + weights = F.softmax(weights, dim=1).to(x.dtype) + + # Initialize the output tensor for expert computations. + expert_output = torch.zeros_like(x) + + # Iterate over each expert to compute outputs for the tokens routed to it. + for expert_id in range(self.num_experts): + # Find the tokens that have this expert in their top-k list. + batch_idx, expert_tok_idx = torch.where(indices == expert_id) + if batch_idx.numel() == 0: + continue + + # Select the subset of tokens for the current expert. + token_subset = x[batch_idx] # Shape: [num_tokens_for_expert, dim] + # Compute the output from the current expert. + output_expert = self.experts[expert_id](token_subset) + # Get the corresponding weights for these tokens. + token_weights = weights[batch_idx, expert_tok_idx].unsqueeze(-1) + # Apply weights and accumulate the output. + expert_output[batch_idx] += output_expert * token_weights + + # If a shared expert exists, add its output. + if self.shared_expert is not None: + shared_output = self.shared_expert(x) + output = expert_output + shared_output + else: + output = expert_output + + # Restore the original tensor shape and return. + return output.view(original_shape) + + +class MoELayerOptimized(nn.Module): + """ + Overview: + An optimized implementation of the Mixture-of-Experts (MoE) layer that maintains the same API as `MoELayer`. + This version avoids loops over experts by using a vectorized scatter-gather approach, which is significantly + more efficient on modern hardware. The forward pass complexity is O(N_tokens + ΣE_i), where ΣE_i is the + total number of tokens processed across all experts. + + The process is as follows: + 1. **Routing**: Get top-k experts and their weights for each token. + 2. **Flattening**: Create a flat list of (token_index, expert_index, weight) tuples. + 3. **Sorting**: Sort these tuples by expert_index. This groups all tokens destined for the same expert together. + 4. **Batch Forward**: Process the tokens for each expert in a single, contiguous batch, avoiding Python loops. + 5. 
**Weighted Scatter**: Apply gating weights to the expert outputs and scatter-add them back to a buffer + indexed by the original token positions. + 6. **Shared Expert**: If configured, add the output from the shared expert. + 7. **Reshape**: Reshape the final output tensor to its original 3D shape. + """ + + def __init__(self, config: Any, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok: int = 1) -> None: + """ + Overview: + Initializes the MoELayerOptimized. + Arguments: + - config (:obj:`Any`): A configuration object. Expected to have `embed_dim` and optionally `n_shared_experts`. + - experts (:obj:`List[nn.Module]`): A list of PyTorch modules representing the experts. + - gate (:obj:`nn.Module`): The gating module for routing tokens. + - num_experts_per_tok (:obj:`int`): The number of experts to use for each token. + """ + super().__init__() + self.dim = config.embed_dim + self.num_experts = len(experts) self.num_experts_per_tok = num_experts_per_tok + self.gate = gate + self.experts = nn.ModuleList(experts) + + self.use_shared = getattr(config, "n_shared_experts", 0) > 0 + if self.use_shared: + # TODO: The architecture of the shared expert could be made more configurable. + self.shared_expert = nn.Sequential( + nn.Linear(self.dim, config.n_shared_experts * (4 * self.dim)), + nn.GELU(), + nn.Linear(config.n_shared_experts * (4 * self.dim), self.dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Performs the optimized forward pass for the MoE layer. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor of shape [B, T, D]. + Returns: + - torch.Tensor: The output tensor with the same shape as the input. + """ + B, T, D = x.shape + x_flat = x.reshape(-1, D) # [N, D]; N = B*T + + # 1. Routing: Get top-k experts and weights. + gate_logits = self.gate(x_flat) # [N, E] + weights, topk_idx = torch.topk(gate_logits, self.num_experts_per_tok, dim=1) # [N, k] + weights = F.softmax(weights, dim=1).to(x.dtype) # [N, k] + + # 2. Flatten token-expert pairs. + N, k = weights.shape + flat_token_idx = torch.arange(N, device=x.device).repeat_interleave(k) # [N*k] + flat_expert_idx = topk_idx.reshape(-1) # [N*k] + flat_weight = weights.reshape(-1, 1) # [N*k, 1] + flat_input = x_flat[flat_token_idx] # [N*k, D] + + # 3. Sort by expert index to group tokens for batch processing. + sort_order = torch.argsort(flat_expert_idx) # [N*k] + flat_expert_idx = flat_expert_idx[sort_order] + flat_token_idx = flat_token_idx[sort_order] + flat_weight = flat_weight[sort_order] + flat_input = flat_input[sort_order] + + # Count how many tokens each expert will process. + counts = torch.bincount(flat_expert_idx, minlength=self.num_experts) # [E] + + # Prepare output buffer. + out_buffer = torch.zeros_like(flat_input) # [N*k, D] + + # 4. Perform forward pass for each expert on its batch of tokens. + ptr = 0 + for eid, num in enumerate(counts.tolist()): + if num == 0: + continue + seg = slice(ptr, ptr + num) + out_buffer[seg] = self.experts[eid](flat_input[seg]) + ptr += num + + # 5. Apply weights and scatter-add results back to token-indexed buffer. + out_buffer.mul_(flat_weight) # In-place multiplication by weights. + token_output = torch.zeros_like(x_flat) # [N, D] + token_output.index_add_(0, flat_token_idx, out_buffer) + + # 6. Add shared expert output if it exists. 
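As a standalone illustration of the sort-and-segment routing used in `MoELayerOptimized`, here is a toy with 5 tokens, 3 experts, and top-1 routing; the per-expert networks are replaced by simple scalings, so this is not part of the patch itself.

```Python
import torch

torch.manual_seed(0)
x = torch.randn(5, 2)                              # 5 tokens, dim 2
expert_idx = torch.tensor([2, 0, 2, 1, 0])         # chosen expert per token (top-1)
order = torch.argsort(expert_idx)                  # group token indices by expert id
counts = torch.bincount(expert_idx, minlength=3)   # tokens per expert: [2, 1, 2]

out = torch.zeros_like(x)
ptr = 0
for eid, n in enumerate(counts.tolist()):
    if n == 0:
        continue
    seg = order[ptr:ptr + n]           # token indices handled by expert `eid`
    out[seg] = x[seg] * (eid + 1)      # stand-in for experts[eid](x[seg])
    ptr += n
```

With top-k routing and k > 1, each token index appears k times in the flattened list, which is why the real layer accumulates the weighted outputs per token with `index_add_` instead of direct assignment.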
+ if self.use_shared: + token_output.add_(self.shared_expert(x_flat)) - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - # if len(self.experts) == 1: - # # 只有一个专家时,直接使用该专家 - # return self.experts[0](inputs) - - gate_logits = self.gate(inputs) - weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok) - weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) - results = torch.zeros_like(inputs) - for i, expert in enumerate(self.experts): - # batch_idx, nth_expert = torch.where(selected_experts == i) - # results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx]) - batch_idx, token_idx, nth_expert = torch.where(selected_experts == i) - results[batch_idx, token_idx] += weights[batch_idx, token_idx, nth_expert][:, None] * expert(inputs[batch_idx, token_idx]) - return results \ No newline at end of file + return token_output.reshape(B, T, D) \ No newline at end of file diff --git a/lzero/model/unizero_world_models/test_moe.py b/lzero/model/unizero_world_models/test_moe.py index 6ab93cc16..1f0f5437c 100644 --- a/lzero/model/unizero_world_models/test_moe.py +++ b/lzero/model/unizero_world_models/test_moe.py @@ -1,43 +1,122 @@ +""" +test_moe.py + +Overview: + A test script to verify the functional equivalence between a standard Transformer's feed-forward network (FFN) + and a Mixture-of-Experts (MoE) layer configured with a single expert. This script demonstrates that + the MoE layer correctly specializes to a standard FFN when num_experts is 1, ensuring backward + compatibility and correct routing logic. +""" import dataclasses from typing import List import torch +import torch.nn as nn import torch.nn.functional as F -from simple_parsing.helpers import Serializable -from torch import nn -# 定义MoeArgs数据类,用于存储MoE的配置参数 + @dataclasses.dataclass -class MoeArgs(Serializable): - num_experts: int - num_experts_per_tok: int +class TransformerConfig: + """ + Overview: + Configuration for the Transformer block and its potential MoE layer. + + Arguments: + - embed_dim (int): The embedding dimension for the model. + - resid_pdrop (float): The dropout probability for the residual connections. + - moe_in_transformer (bool): If True, use an MoE layer for the feed-forward part. Otherwise, use a standard MLP. + - num_experts (int): The total number of experts in the MoE layer. + - num_experts_per_tok (int): The number of experts to route each token to (top-k routing). + """ + embed_dim: int = 64 + resid_pdrop: float = 0.1 + moe_in_transformer: bool = False + num_experts: int = 1 + num_experts_per_tok: int = 1 + -# 定义Mixture of Experts(MoE)层 -class MoeLayer(nn.Module): - def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok=1): +class MoELayer(nn.Module): + """ + Overview: + An efficient, vectorized implementation of a Mixture-of-Experts (MoE) layer. + This layer routes each token to a subset of experts (Top-k routing) and combines their + outputs using a weighted sum. The implementation is highly optimized for parallel + computation on hardware like GPUs. + """ + + def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok: int): + """ + Overview: + Initializes the MoE layer. + Arguments: + - experts (List[nn.Module]): A list of expert neural network modules. + - gate (nn.Module): The gating network that computes routing logits. + - num_experts_per_tok (int): The number of experts to route each token to. 
+ """ super().__init__() - assert len(experts) > 0 + assert len(experts) > 0, "The list of experts cannot be empty." self.experts = nn.ModuleList(experts) self.gate = gate + self.num_experts = len(experts) self.num_experts_per_tok = num_experts_per_tok - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - if len(self.experts) == 1: - # 只有一个专家时,直接使用该专家 - return self.experts[0](inputs) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Performs the forward pass of the MoE layer. + Arguments: + - x (torch.Tensor): Input tensor of shape `[batch_size, seq_len, embed_dim]`. + Returns: + - (torch.Tensor): Output tensor of the same shape as the input. + """ + batch_size, seq_len, dim = x.shape + x_flat = x.view(-1, dim) + + gate_logits = self.gate(x_flat) + weights, topk_indices = torch.topk(gate_logits, self.num_experts_per_tok, dim=1) + weights = F.softmax(weights, dim=1, dtype=torch.float).to(x.dtype) + + num_tokens = x_flat.shape[0] + flat_token_indices = torch.arange(num_tokens, device=x.device).repeat_interleave(self.num_experts_per_tok) + flat_expert_indices = topk_indices.view(-1) - gate_logits = self.gate(inputs) - weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok) - weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) - results = torch.zeros_like(inputs) - for i, expert in enumerate(self.experts): - batch_idx, token_idx, nth_expert = torch.where(selected_experts == i) - results[batch_idx, token_idx] += weights[batch_idx, token_idx, nth_expert][:, None] * expert(inputs[batch_idx, token_idx]) - return results - -# 定义一个简单的Transformer块 + sort_order = torch.argsort(flat_expert_indices) + sorted_expert_indices = flat_expert_indices[sort_order] + sorted_token_indices = flat_token_indices[sort_order] + + expert_inputs = x_flat[sorted_token_indices] + sorted_weights = weights.view(-1, 1)[sort_order] + + expert_counts = torch.bincount(sorted_expert_indices, minlength=self.num_experts) + output_buffer = torch.zeros_like(expert_inputs) + + ptr = 0 + for i, count in enumerate(expert_counts.tolist()): + if count == 0: + continue + segment = slice(ptr, ptr + count) + output_buffer[segment] = self.experts[i](expert_inputs[segment]) + ptr += count + + # --- FIX: Simplified and corrected scattering logic --- + # Weight the outputs and directly add them to the correct token's position. + weighted_outputs = output_buffer * sorted_weights + + token_output = torch.zeros_like(x_flat) + # Use `sorted_token_indices` to add the results back to their original token positions. + token_output.index_add_(0, sorted_token_indices, weighted_outputs) + + return token_output.view(batch_size, seq_len, dim) + + class TransformerBlock(nn.Module): - def __init__(self, config): + """ + Overview: + A simplified Transformer block that contains a feed-forward network (FFN). + The FFN can be either a standard MLP or a Mixture-of-Experts (MoE) layer, + controlled by the configuration. 
+ """ + def __init__(self, config: TransformerConfig): super().__init__() self.mlp = nn.Sequential( nn.Linear(config.embed_dim, 4 * config.embed_dim), @@ -47,61 +126,75 @@ def __init__(self, config): ) if config.moe_in_transformer: - self.feed_forward = MoeLayer( - experts=[self.mlp for _ in range(config.num_experts_of_moe_in_transformer)], - gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), - num_experts_per_tok=1, + experts = [self.mlp for _ in range(config.num_experts)] + self.feed_forward = MoELayer( + experts=experts, + gate=nn.Linear(config.embed_dim, config.num_experts, bias=False), + num_experts_per_tok=config.num_experts_per_tok, ) - print("="*20) - print('使用MoE在Transformer的feed_forward中') - print("="*20) + print("=" * 40) + print("TransformerBlock initialized with MoE layer.") + print("=" * 40) else: self.feed_forward = self.mlp + print("-" * 40) + print("TransformerBlock initialized with standard MLP.") + print("-" * 40) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.feed_forward(x) -# 定义配置类 -class Config: - def __init__(self, embed_dim, resid_pdrop, num_experts_of_moe_in_transformer, moe_in_transformer): - self.embed_dim = embed_dim - self.resid_pdrop = resid_pdrop - self.num_experts_of_moe_in_transformer = num_experts_of_moe_in_transformer - self.moe_in_transformer = moe_in_transformer - -# 测试代码 -def test_transformer_block(): - # 初始化配置 - embed_dim = 64 - resid_pdrop = 0.1 - num_experts_of_moe_in_transformer = 1 - # 创建输入数据 - inputs = torch.randn(10, 5, embed_dim) # (batch_size, seq_len, embed_dim) +def test_transformer_block_equivalence(): + """ + Overview: + Tests that an MoE layer with a single expert produces an output identical + to that of a standard MLP layer, given that they share the same weights. + """ + torch.manual_seed(42) + + embed_dim = 64 + batch_size = 10 + seq_len = 5 + + config_mlp = TransformerConfig(embed_dim=embed_dim, moe_in_transformer=False) + config_moe = TransformerConfig(embed_dim=embed_dim, moe_in_transformer=True, num_experts=1, num_experts_per_tok=1) - # 初始化两个输出变量 - outputs_true = None - outputs_false = None + # --- FIX: Ensure identical weights for a fair comparison --- + # 1. Create the standard MLP block first. + transformer_block_mlp = TransformerBlock(config_mlp) - # 对于moe_in_transformer为True和False分别进行测试 - for moe_in_transformer in [True, False]: - config = Config(embed_dim, resid_pdrop, num_experts_of_moe_in_transformer, moe_in_transformer) - transformer_block = TransformerBlock(config) - - outputs = transformer_block(inputs) - print(f"moe_in_transformer={moe_in_transformer}: outputs={outputs}") + # 2. Create the MoE block. + transformer_block_moe = TransformerBlock(config_moe) - if moe_in_transformer: - outputs_true = outputs - else: - outputs_false = outputs + # 3. CRITICAL: Load the MLP's weights into the MoE's expert MLP. + # This guarantees that the underlying expert has the exact same weights as the standalone MLP. + transformer_block_moe.mlp.load_state_dict(transformer_block_mlp.mlp.state_dict()) + + # Also, for a perfect match, the gate should be initialized to a state + # that it doesn't affect the output scaling. We can manually set its weights. + # In a single-expert case, softmax ensures the weight is 1, so this is not strictly + # necessary, but it's good practice for more complex tests. 
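As an optional extra check, one could also assert parameter equality directly before running the forward passes; this uses only the `transformer_block_mlp` and `transformer_block_moe` objects created above and is shown here purely as an illustration.

```Python
import torch

# The expert inside the MoE block and the standalone MLP should now hold
# bitwise-identical parameters after load_state_dict.
for (name_a, p_a), (name_b, p_b) in zip(
    transformer_block_mlp.mlp.named_parameters(),
    transformer_block_moe.mlp.named_parameters(),
):
    assert name_a == name_b
    assert torch.equal(p_a, p_b), f"Parameter mismatch in {name_a}"
```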
+ + inputs = torch.randn(batch_size, seq_len, embed_dim) + + print("\nRunning forward pass for standard MLP block...") + output_mlp = transformer_block_mlp(inputs) + + print("\nRunning forward pass for MoE block...") + output_moe = transformer_block_moe(inputs) - # 计算输出的差异 - mse_difference = None - if outputs_true is not None and outputs_false is not None: - mse_difference = F.mse_loss(outputs_true, outputs_false).item() + is_close = torch.allclose(output_moe, output_mlp, atol=1e-6) + mse_difference = F.mse_loss(output_moe, output_mlp).item() + + print("\n" + "=" * 25 + " TEST RESULTS " + "=" * 25) + print(f"Outputs are close: {is_close}") + print(f"Mean Squared Error (MSE) between outputs: {mse_difference:.10f}") - print(f"输出差异的均方误差(MSE): {mse_difference}") + assert is_close, "Test failed: Outputs of single-expert MoE and MLP are not identical." + print("\n✅ Test Passed: Single-expert MoE layer behaves identically to a standard MLP.") + print("=" * 64 + "\n") + if __name__ == "__main__": - test_transformer_block() \ No newline at end of file + test_transformer_block_equivalence() \ No newline at end of file diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py index 1e87efb17..65325b3b4 100644 --- a/lzero/model/unizero_world_models/tokenizer.py +++ b/lzero/model/unizero_world_models/tokenizer.py @@ -1,165 +1,364 @@ """ Modified from https://github.com/CompVis/taming-transformers +This module provides an autoencoder-style tokenizer for encoding observations into latent embeddings and decoding them back. """ from dataclasses import dataclass +from typing import Any, Dict, Optional import torch import torch.nn as nn from einops import rearrange from torch.nn import functional as F - +from typing import Optional, List +from transformers.modeling_outputs import BaseModelOutput class LossWithIntermediateLosses: - def __init__(self, **kwargs): - """Initialize with various loss components.""" - self.loss_total = sum(kwargs.values()) - self.intermediate_losses = {k: v.item() for k, v in kwargs.items()} - - def __truediv__(self, value): - """Divide all loss components by a given value.""" - for k, v in self.intermediate_losses.items(): - self.intermediate_losses[k] = v / value + """ + Overview: + A helper class to manage a total loss value alongside a dictionary of its constituent, named loss components. + This is primarily used for detailed logging. + """ + + def __init__(self, **kwargs: torch.Tensor) -> None: + """ + Overview: + Initializes the loss object. + Arguments: + - kwargs (:obj:`torch.Tensor`): Keyword arguments where keys are loss names and values are the corresponding loss tensors. + """ + # The total loss, which can be used for backpropagation. + self.loss_total: torch.Tensor = sum(kwargs.values()) + # A dictionary holding the scalar values of intermediate losses, detached from the computation graph. + self.intermediate_losses: Dict[str, float] = {k: v.item() for k, v in kwargs.items()} + + def __truediv__(self, value: float) -> "LossWithIntermediateLosses": + """ + Overview: + Overloads the division operator to scale all loss components by a scalar value. + This is useful for operations like averaging over batch size or gradient accumulation steps. + Arguments: + - value (:obj:`float`): The scalar value to divide the losses by. + Returns: + - LossWithIntermediateLosses: The same instance with updated loss values. 
+ """ + if not isinstance(value, (int, float)) or value == 0: + raise ValueError(f"Division is only supported for a non-zero scalar, but got {value}.") + self.loss_total = self.loss_total / value + for k in self.intermediate_losses: + self.intermediate_losses[k] /= value return self @dataclass class TokenizerEncoderOutput: + """ + Overview: + A data structure to hold the various outputs from a VQ-VAE style encoder, + including continuous and quantized latent representations, and discrete tokens. + """ + # Continuous latent representation from the encoder. z: torch.FloatTensor + # Quantized latent representation. z_quantized: torch.FloatTensor + # Discrete integer tokens corresponding to the codebook entries. tokens: torch.LongTensor class Tokenizer(nn.Module): """ Overview: - Tokenizer model that encodes and decodes observations. + An autoencoder model that encodes high-dimensional observations (like images or state vectors) + into low-dimensional latent embeddings and decodes them back. It can also compute reconstruction + and perceptual losses. This implementation does not include the quantization step (Vector Quantization) + but serves as the encoder-decoder backbone. """ - def __init__(self, encoder=None, decoder_network=None, with_lpips: bool = False, obs_type=None) -> None: - """Initialize the Tokenizer. + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + with_lpips: bool = False, + obs_type: str = 'image' + ) -> None: + """ + Overview: + Initializes the Tokenizer (Autoencoder). Arguments: - encoder (nn.Module, optional): Encoder network. Defaults to None. - decoder_network (nn.Module, optional): Decoder network. Defaults to None. - with_lpips (bool, optional): Whether to use LPIPS for perceptual loss. Defaults to False. + - encoder (:obj:`nn.Module`): The network responsible for encoding observations into latent embeddings. It can be a single module or an nn.ModuleList for multi-task scenarios. + - decoder (:obj:`nn.Module`): The network responsible for decoding latent embeddings back into observations. + - with_lpips (:obj:`bool`): If True, initializes the LPIPS model to compute perceptual loss. Defaults to False. + - obs_type (:obj:`str`): The type of observation, e.g., 'image' or 'vector'. This can inform model architecture choices. Defaults to 'image'. """ super().__init__() + self.encoder = encoder + self.decoder_network = decoder + self.obs_type = obs_type + self.lpips: Optional[nn.Module] = None if with_lpips: + # Lazily import LPIPS as it's an optional dependency. from lzero.model.unizero_world_models.lpips import LPIPS self.lpips = LPIPS().eval() - else: - self.lpips = None - self.encoder = encoder - self.decoder_network = decoder_network - self.obs_type = obs_type - - def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Tensor: + def encode_to_obs_embeddings(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor: """ - Encode observations to embeddings. - + Overview: + Encodes a batch of observations into latent embeddings, handling various input shapes and multi-task encoders. Arguments: - - x (torch.Tensor): Input tensor of shape (B, ...). - + - x (:obj:`torch.Tensor`): The input tensor of observations. Shape can be (B, E), (B, T, E), (B, C, H, W), or (B, T, C, H, W). + - task_id (:obj:`int`): The identifier for the task, used to select the correct encoder from an nn.ModuleList in multi-task settings. Defaults to 0. Returns: - - torch.Tensor: Encoded embeddings of shape (B, 1, E). 
+ - torch.Tensor: The encoded latent embeddings with a consistent shape of (B, 1, E), where B is the effective batch size. """ - shape = x.shape - # TODO: ====== - if task_id is None: - # for compatibility with multitask setting - task_id = 0 - else: - # task_id = 0 # one share encoder - task_id = task_id # TODO: one encoder per task - # print(f'='*20) - # print(f'x.shape:{x.shape}') - # print(f'self.encoder:{self.encoder}') - - # Process input tensor based on its dimensionality - if len(shape) == 2: - # Case when input is 2D (B, E) - # obs_embeddings = self.encoder[task_id](x) - obs_embeddings = self.encoder(x, task_id) # TODO: - - obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') - elif len(shape) == 3: - # Case when input is 3D (B, T, E) - x = x.contiguous().view(-1, shape[-1]) # Flatten the last two dimensions (B * T, E) - # obs_embeddings = self.encoder[task_id](x) - obs_embeddings = self.encoder(x,task_id) # TODO: - - obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') - elif len(shape) == 4: - # Case when input is 4D (B, C, H, W) - if self.obs_type == 'vector': - obs_embeddings = self.encoder(x, task_id=task_id) # TODO: for dmc multitask - elif self.obs_type == 'image': - try: - obs_embeddings = self.encoder[0](x) # TODO: for atari/memory env - except: - obs_embeddings = self.encoder(x) # TODO: for atari/memory env single-task - - obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') - elif len(shape) == 5: - # Case when input is 5D (B, T, C, H, W) - x = x.contiguous().view(-1, *shape[-3:]) # Flatten the first two dimensions (B * T, C, H, W) - if self.obs_type == 'vector': - obs_embeddings = self.encoder[task_id](x) - elif self.obs_type == 'image': - try: - obs_embeddings = self.encoder[0](x) # TODO: for atari/memory env - except: - obs_embeddings = self.encoder(x) # TODO: for atari/memory env single-task - - obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') + + # global DEBUG_ENABLED;DEBUG_ENABLED = True + # import torch.distributed as dist + # if dist.get_rank() == 0 and DEBUG_ENABLED: + # print(f"rank {dist.get_rank()} 进入调试模式,输入interact,可以键入整段的python代码调试。通过设置 DEBUG_ENABLED = False, 可以跳过调试状态") + # import ipdb; ipdb.set_trace() + # # 同步点,防止其它进程早跑 + # dist.barrier() + + # Step 1: Select the appropriate encoder module. + # This handles both single-task (a single nn.Module) and multi-task (an nn.ModuleList) scenarios. + if isinstance(self.encoder, nn.ModuleList): + if not 0 <= task_id < len(self.encoder): + # raise ValueError( + # f"Provided task_id {task_id} is invalid for the encoder list of size {len(self.encoder)}." + # ) + encoder_module = self.encoder[0] + else: + encoder_module = self.encoder[task_id] else: - raise ValueError(f"Invalid input shape: {shape}") + encoder_module = self.encoder + + # Step 2: Pre-process and reshape the input tensor based on its dimensions. + # The goal is to transform the input into a 2D or 4D tensor that the encoder can process. + original_shape = x.shape + if len(original_shape) == 5: # Batch of sequences of images: (B, T, C, H, W) + # Flatten the batch and time dimensions to create a batch of images. + x = x.contiguous().view(-1, *original_shape[-3:]) # Shape: (B*T, C, H, W) + elif len(original_shape) == 3: # Batch of sequences of vectors: (B, T, E) + # Flatten the batch and time dimensions to create a batch of vectors. + x = x.contiguous().view(-1, original_shape[-1]) # Shape: (B*T, E) + # Note: 2D (B, E) and 4D (B, C, H, W) inputs are processed directly without reshaping. 
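A brief shape walkthrough of the reshaping described above; the encoder itself is simulated with random values, and the 256-dim embedding size is illustrative.

```Python
import torch
from einops import rearrange

# 4 sequences of 3 RGB frames of size 64x64.
x = torch.randn(4, 3, 3, 64, 64)                  # (B, T, C, H, W)
x_flat = x.contiguous().view(-1, *x.shape[-3:])   # (B*T, C, H, W) = (12, 3, 64, 64)

# A hypothetical image encoder mapping each frame to a 256-dim embedding,
# simulated here with random values.
embeddings_2d = torch.randn(x_flat.shape[0], 256)          # (B*T, E)
obs_embeddings = rearrange(embeddings_2d, 'b e -> b 1 e')  # (B', 1, E)
assert obs_embeddings.shape == (12, 1, 256)
```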
+ + # Step 3: Pass the processed tensor through the encoder. + obs_embeddings = encoder_module(x) + if len(obs_embeddings.shape) != 2: + raise RuntimeError( + f"Encoder output was expected to be 2D (batch, embedding_dim), but got shape {obs_embeddings.shape}." + ) + + # Step 4: Reshape the output to a consistent sequence format (B', 1, E). + # The '1' represents a sequence length of one, making it compatible with sequence models. + obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') return obs_embeddings def decode_to_obs(self, embeddings: torch.Tensor) -> torch.Tensor: - """Decode embeddings to observations. + """ + Overview: + Decodes a batch of latent embeddings back into the observation space. + Arguments: + - embeddings (:obj:`torch.Tensor`): The latent embeddings to decode. + Returns: + - torch.Tensor: The reconstructed observations. + """ + return self.decoder_network(embeddings) + def decode_to_reconstruction_outputs(self, embeddings: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor: + """ + Overview: + This function takes input embeddings and corresponding target token IDs, + then uses a seq2seq decoder (like T5) to reconstruct the original text. + It handles reshaping, retokenization, projection, and calls the decoder + to compute the reconstruction loss and logits. Arguments: - embeddings (:obj:`torch.Tensor`): Input embeddings. + embeddings (torch.Tensor): Input embeddings of shape (B, E), (B, L, E), or (B*T, 1, E). + target_ids (torch.Tensor): Ground-truth token IDs of shape (B, L) or (B*T, L). + Returns: + torch.Tensor: Decoder output including loss, logits, hidden states (if return_dict=True). + """ + if embeddings.dim() == 2: + embeddings = embeddings.unsqueeze(1) + elif embeddings.dim() == 3: + B,T,E = embeddings.shape + embeddings = embeddings.reshape(B*T,1,E) + target_ids = target_ids.reshape(B*T, -1) + + if self.encoder_option == 'legacy': # T5 decoder + # Instead of using raw target_ids, convert them to plain text and re-tokenize using the decoder's tokenizer. + # This guarantees alignment with the decoder's vocabulary, special tokens, and tokenization rules. + text_list = self.encoder.tokenizer.batch_decode(target_ids, skip_special_tokens=True) + t5_target_ids = self.decoder_network_tokenizer(text_list, + padding="max_length", + truncation=True, + max_length=512, + return_tensors="pt") + labels = t5_target_ids.input_ids + labels[labels == self.decoder_network_tokenizer.pad_token_id] = -100 + + embeddings = self.projection_layer(embeddings) # (B', 1, E) -> (B', 1, E'), B' = B*T + encoder_outputs_tuple = BaseModelOutput(last_hidden_state=embeddings) + encoder_attention_mask = torch.ones( + embeddings.size(0), embeddings.size(1), + device=embeddings.device, dtype=torch.long + ) + + labels = labels.to(embeddings.device) + + outputs = self.decoder_network(encoder_outputs=encoder_outputs_tuple, + attention_mask=encoder_attention_mask, + labels=labels, + return_dict=True) + return outputs + + elif self.encoder_option == 'qwen': + hidden = self.projection_layer(embeddings) + lm = self.decoder_network.pretrained_model + # Get a reference parameter for device/dtype info + param = next(lm.parameters()) + try: + # Retrieve the input embedding layer of the language model + input_embedding_layer = lm.get_input_embeddings() + except: + raise ValueError('Error... 
Could not retrieve input embedding layer from the decoder network.') + + # Convert target token IDs into embeddings using the LM's input embedding layer + target_embeds = input_embedding_layer(target_ids) + + # Concatenate the projected hidden embeddings (prompt) with target embeddings + # hidden: (B, 1, D), target_embeds: (B, L, D) → inputs_embeds: (B, 1+L, D) + inputs_embeds = torch.cat([hidden, target_embeds.detach()], dim=1) + + inputs_embeds = inputs_embeds.to(device=param.device, dtype=param.dtype) + + prompt_attention_mask = torch.ones(hidden.size(0), 1, device=param.device, dtype=torch.long) + target_attention_mask = (target_ids != self.decoder_network.tokenizer.pad_token_id).to(device=param.device, dtype=torch.long) + # Concatenate prompt mask and target mask along sequence length + attention_mask = torch.cat([prompt_attention_mask, target_attention_mask], dim=1) + # Construct labels: for the prompt part, use -100 (ignored by loss function) + prompt_labels = torch.full((hidden.size(0), 1), -100, device=param.device, dtype=torch.long) + + # Copy target token IDs as labels, masking pad positions with -100 + labels = target_ids.clone().to(param.device) + labels[labels == self.decoder_network.tokenizer.pad_token_id] = -100 + + final_labels = torch.cat([prompt_labels, labels], dim=1) + + outputs = lm( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + labels=final_labels, + return_dict=True + ) + + return outputs + + def decode_to_plain_text( + self, embeddings: torch.Tensor, + max_length: int = 512 + ) -> List[List[int]]: + """ + Overview: + This function decodes latent embeddings into plain text using the decoder's generate method. + It includes projection, prepares encoder outputs and attention mask, and performs autoregressive decoding. + Arguments: + embeddings (torch.Tensor): Latent embeddings, shape (B, E) or (B, L, E). + max_length (int, optional): Max token length for generation. Defaults to 512. Returns: - torch.Tensor: Decoded observations. + List[List[int]]: List of decoded strings, one per input in batch. """ - return self.decoder_network(embeddings) + + # Set decoder_network and projection_layer to evaluation mode to disable dropout and other training-specific behaviors. + self.decoder_network.eval() + self.projection_layer.eval() + + # If embeddings is not a Tensor, convert it to a torch.Tensor. + if not isinstance(embeddings, torch.Tensor): + embeddings = torch.tensor(embeddings, dtype=torch.float32) + + # Attempt to retrieve the device information from decoder_network; if unavailable, fall back to the model’s parameters. + try: + device = self.decoder_network.device + except AttributeError: + device = next(self.decoder_network.parameters()).device + + embeddings = embeddings.to(device) + + with torch.no_grad(): + if embeddings.dim() == 2: + embeddings = embeddings.unsqueeze(1) + + embeddings = self.projection_layer(embeddings) + if self.encoder_option == 'legacy': # T5 decoder + encoder_outputs_tuple = BaseModelOutput(last_hidden_state=embeddings) + encoder_attention_mask = torch.ones( + embeddings.size(0), embeddings.size(1), + device=device, dtype=torch.long + ) + + # Use the decoder's generate() method to autoregressively decode text from the input embeddings. + # The projected embeddings serve as encoder outputs in a typical encoder-decoder architecture, + # where the decoder attends to them via cross-attention at each step until max_length or EOS is reached. 
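A hypothetical usage sketch of this method: `tokenizer` is assumed to be a fully configured `Tokenizer` instance using the T5 ('legacy') decoder branch, and the 768-dim latent is illustrative. Note that the legacy branch returns the decoded string for a single sample per call.

```Python
import torch

latent = torch.randn(1, 768)   # one latent embedding (B=1, E must match the projection layer)
text = tokenizer.decode_to_plain_text(latent, max_length=64)
print(text)                    # reconstructed plain text for this latent
```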
+ generated_t5_ids = self.decoder_network.generate( + encoder_outputs=encoder_outputs_tuple, + attention_mask=encoder_attention_mask, + max_length=max_length + ) + + # Convert the generated output to a list of strings on CPU, skipping special tokens. + generated_text = self.decoder_network_tokenizer.batch_decode( + generated_t5_ids, skip_special_tokens=True) + + assert len(generated_text) == 1, f"Expected 1 generated text, got {len(generated_text)}" + return generated_text[0] + + elif self.encoder_option == 'qwen': + return self.decoder_network.decode(embeddings=embeddings, max_length=max_length) @staticmethod def reconstruction_loss(original_images: torch.Tensor, reconstructed_images: torch.Tensor) -> torch.Tensor: - """Calculate the reconstruction loss. - + """ + Overview: + Calculates the reconstruction loss between original and reconstructed observations. + It uses L2 (MSE) loss for vector-based observations and L1 (MAE) loss for image-based observations. Arguments: - - original_images (:obj:`torch.Tensor`): Original images. - - reconstructed_images (:obj:`torch.Tensor`): Reconstructed images. - + - original_images (:obj:`torch.Tensor`): The ground-truth observations. + - reconstructed_images (:obj:`torch.Tensor`): The observations reconstructed by the decoder. Returns: - - torch.Tensor: Computed reconstruction loss. + - torch.Tensor: A scalar tensor representing the computed reconstruction loss. """ if len(original_images.shape) == 2: - # For memory environment vector observations - loss = F.mse_loss(original_images, reconstructed_images) # L2 loss + # Use Mean Squared Error (L2 loss) for vector-based observations. + return F.mse_loss(reconstructed_images, original_images) else: - # For Atari image environment - loss = torch.abs(original_images - reconstructed_images).mean() # L1 loss - return loss + # Use Mean Absolute Error (L1 loss) for image-based observations, which is often more robust to outliers. + return torch.abs(original_images - reconstructed_images).mean() def perceptual_loss(self, original_images: torch.Tensor, reconstructed_images: torch.Tensor) -> torch.Tensor: - """Calculate the perceptual loss using LPIPS. - + """ + Overview: + Calculates the perceptual loss (LPIPS) between original and reconstructed images. + This loss is designed to better align with human perception of image similarity. Arguments: - original_images (:obj:`torch.Tensor`): Original images. - reconstructed_images (:obj:`torch.Tensor`): Reconstructed images. - + - original_images (:obj:`torch.Tensor`): The ground-truth images. + - reconstructed_images (:obj:`torch.Tensor`): The images reconstructed by the decoder. Returns: - torch.Tensor: Computed perceptual loss. + - torch.Tensor: A scalar tensor representing the computed perceptual loss. """ + if self.lpips is None: + raise RuntimeError("LPIPS model was not initialized. Please set `with_lpips=True` during Tokenizer instantiation.") return torch.mean(self.lpips(original_images, reconstructed_images)) + + def __repr__(self) -> str: - return "Tokenizer" \ No newline at end of file + """ + Overview: + Provides a string representation of the Tokenizer module. 
+ """ + return f"Tokenizer(obs_type='{self.obs_type}', with_lpips={self.lpips is not None})" \ No newline at end of file diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index 5e2e3e670..0e855d289 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -1,50 +1,89 @@ """ -Modified from https://github.com/karpathy/nanoGPT - -在原 transformer.py 基础上增加 LoRA 微调相关代码, -并通过传入配置参数控制 LoRA 微调的模块(默认是 attention 中的 k, q, v, proj 和 feed_forward) -保持原有代码的可扩展性。 +This script is an extension of the original transformer.py from karpathy/nanoGPT. +It incorporates LoRA (Low-Rank Adaptation) for fine-tuning and introduces a +Curriculum Learning mechanism that activates different LoRA adapters sequentially. + +Key features: +- Adds `CurriculumLoRALinear`, a custom linear layer with multiple LoRA adapters. +- Controls which modules to apply LoRA to via configuration (e.g., attention and feed-forward layers). +- Maintains the extensibility and readability of the original nanoGPT codebase. """ -import numpy as np import math +import logging from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn as nn -import torch.nn as nn -from torch.nn import functional as F from ding.torch_utils.network import GRUGatingUnit from einops import rearrange +from torch.nn import functional as F from .kv_caching import KeysValues -from .moe import MoeLayer, MultiplicationFeedForward -from line_profiler import line_profiler from lzero.model.common import SimNorm -############################################# -# 新增:LoRA 微调相关代码 -############################################# -class LoRALinear(nn.Module): +class LearnableScale(nn.Module): """ - LoRA 适配器包装的线性层。 + A learnable scalar parameter constrained within a specific range. + + The formula `s = offset + scale * tanh(ŝ)` maps an unbounded logit `ŝ` + to the range (offset - scale, offset + scale). Using tanh can sometimes + provide more stable gradients than sigmoid. - 原理: - 使用冻结的原始 nn.Linear 层,并添加两个小型低秩矩阵, - 计算公式为:y = x @ W^T + scaling * ((drop(x) @ A^T) @ B^T) - 其中 A 和 B 为低秩参数,scaling = lora_alpha / r. + For example, to achieve a range of (0.8, 1.2), one would use + `init=1.0` and `s_range=0.2`. """ - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - r: int = 0, - lora_alpha: int = 1, - lora_dropout: float = 0.0 - ): + + def __init__(self, init: float = 1.0, s_range: float = 0.2) -> None: + """ + Overview: + Initializes the LearnableScale module. + Arguments: + - init (:obj:`float`): The initial value of the scalar, which also serves as the center of the range. + - s_range (:obj:`float`): The scale factor that determines the range (init - s_range, init + s_range). + """ + super().__init__() + assert s_range > 0, "The scaling range must be positive." + self.offset = init + self.scale = s_range + + # Initialize the logit to 0, so the initial output is exactly `init`. + self.logit = nn.Parameter(torch.tensor(0.0)) + # TODO: Initially frozen, activated by a CurriculumController. + self.logit.requires_grad = False + + def forward(self) -> torch.Tensor: + """ + Overview: + Computes the scaled value. + Returns: + - torch.Tensor: The learnable scalar, constrained to the specified range. 
+ """ + return self.offset + self.scale * torch.tanh(self.logit) + +############################################## +# Optimized CurriculumLoRALinear Implementation (Recommended Version) +############################################## + +class CurriculumLoRALinear(nn.Module): + """ + Optimized CurriculumLoRALinear. + + Effective weight at stage s: + W_eff = α₀*W₀ + Σ_{j=1 to s} αⱼ*Δθⱼ + + Optimization logic at stage s (s >= 1): + - Train: Δθₛ, α₀, and {αⱼ | 1 <= j < s} + - Freeze: W₀, {Δθⱼ | 1 <= j < s}, and αₛ + + This avoids the redundancy of training αₛ alongside Δθₛ. + """ + + def __init__(self, in_features: int, out_features: int, bias: bool = True, + r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, + curriculum_stage_num: int = 1, lora_scale_init: float = 1.0) -> None: super().__init__() self.in_features = in_features self.out_features = out_features @@ -52,58 +91,314 @@ def __init__( self.lora_alpha = lora_alpha self.scaling = lora_alpha / r if r > 0 else 1.0 self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity() + self.curriculum_stage_num = curriculum_stage_num + self.curriculum_stage = 0 - # 原始权重(冻结参数,不更新) + # Base weights (W₀ and bias) self.weight = nn.Parameter(torch.empty(out_features, in_features)) - self.bias = nn.Parameter(torch.empty(out_features)) if bias else None - nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if bias: + self.bias = nn.Parameter(torch.empty(out_features)) + else: + self.register_parameter('bias', None) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(self.bias, -bound, bound) - # 低秩矩阵参数(仅在 r > 0 时添加) - if r > 0: - # A 将 in_features 映射到低秩 r;B 从低秩 r 映射回 out_features - self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01) - self.lora_B = nn.Parameter(torch.zeros(out_features, r)) + # Learnable scale for the base weight (α₀) + self.base_weight_scale = LearnableScale(init=1.0, s_range=0.2) + + # A scale for each adapter (α₁, α₂, ...) 
+ self.adapters = nn.ModuleList() + self.adapter_scales = nn.ModuleList() + + if r > 0 and (curriculum_stage_num - 1) > 0: + for _ in range(curriculum_stage_num - 1): + adapter = nn.ParameterDict({ + 'lora_A': nn.Parameter(torch.randn(r, in_features) * 0.01), + 'lora_B': nn.Parameter(torch.zeros(out_features, r)) + }) + self.adapters.append(adapter) + self.adapter_scales.append(LearnableScale(lora_scale_init, s_range=0.2)) else: - self.lora_A = None - self.lora_B = None + self.adapters = None + + self.set_curriculum_stage(0) + + def set_curriculum_stage(self, stage: int) -> None: + assert 0 <= stage < self.curriculum_stage_num, f"Stage must be within [0, {self.curriculum_stage_num-1}]" + self.curriculum_stage = stage + module_id = f"({self.in_features}x{self.out_features})" + + # --- Stage 0: Base Training --- + if stage == 0: + self.weight.requires_grad = True + if self.bias is not None: self.bias.requires_grad = True + + # Freeze everything else + self.base_weight_scale.logit.requires_grad = False + if self.adapters: + for adapter in self.adapters: + adapter['lora_A'].requires_grad = False + adapter['lora_B'].requires_grad = False + for scale in self.adapter_scales: + scale.logit.requires_grad = False + logging.info(f"[CurriculumLoRALinear {module_id}] Stage 0: Base layer trainable.") + + # --- Stage >= 1: Adaptation --- + else: + # Freeze base model + self.weight.requires_grad = False + if self.bias is not None: self.bias.requires_grad = False + + # α₀ is trainable from stage 1 onwards + self.base_weight_scale.logit.requires_grad = True + + if self.adapters: + # Set trainability for LoRA adapters + for idx, adapter in enumerate(self.adapters): + is_current_adapter = (idx == stage - 1) + adapter['lora_A'].requires_grad = is_current_adapter + adapter['lora_B'].requires_grad = is_current_adapter + + # --- OPTIMIZED LOGIC FOR SCALES --- + # Set trainability for adapter scales {α_j} + for idx, scale in enumerate(self.adapter_scales): + # A scale α_j is trainable if it belongs to a *previous* stage (j < s). + # The current stage's scale α_s (idx = stage - 1) is NOT trained. 
+ is_previous_scale = (idx < stage - 1) + scale.logit.requires_grad = is_previous_scale + + logging.info(f"[CurriculumLoRALinear {module_id}] Stage {stage}: Activating adapter {stage - 1} and scales for stages < {stage - 1}.") - # 冻结原始权重参数,保证仅更新 LoRA 参数 - self.weight.requires_grad = False - if self.bias is not None: - self.bias.requires_grad = False def forward(self, x: torch.Tensor) -> torch.Tensor: - # 原始线性输出(冻结部分) - result = F.linear(x, self.weight, self.bias) - # 如启用了 LoRA,则加上低秩部分 - if self.r > 0: - lora_out = F.linear(self.lora_dropout(x), self.lora_A) # (…, r) - lora_out = F.linear(lora_out, self.lora_B) # (…, out_features) - result = result + self.scaling * lora_out - return result + # Apply scaling to base weight if in an adaptation stage + if self.curriculum_stage > 0: + alpha_0 = self.base_weight_scale() + scaled_weight = self.weight * alpha_0 + baseline_out = F.linear(x, scaled_weight, self.bias) + else: + baseline_out = F.linear(x, self.weight, self.bias) + if self.curriculum_stage == 0 or self.adapters is None: + return baseline_out + + adapter_out = 0 + # Iterate through all adapters up to the current stage + for idx in range(self.curriculum_stage): + if idx >= len(self.adapters): + break + + adapter = self.adapters[idx] + scale = self.adapter_scales[idx]() + + lora_x = self.lora_dropout(x) + out = F.linear(lora_x, adapter['lora_A']) + out = F.linear(out, adapter['lora_B']) + + # The forward pass is a simple sum. The magic happens in `set_curriculum_stage` + # which controls `requires_grad`. No need for `.detach()` here. + # Gradients will naturally flow only to parameters with `requires_grad=True`. + adapter_out = adapter_out + self.scaling * out * scale + + return baseline_out + adapter_out + + +# ############################################## +# # CurriculumLoRALinear Implementation +# ############################################## + +# class CurriculumLoRALinear(nn.Module): +# """ +# CurriculumLoRALinear extends a standard linear layer with curriculum-based LoRA adapters. + +# This module internally stores a base weight and bias. It also initializes multiple +# LoRA adapters (number = curriculum_stage_num - 1), which are activated sequentially. + +# Forward pass logic: +# - If `curriculum_stage == 0`: +# Output = F.linear(x, W, bias) +# - If `curriculum_stage >= 1`: +# Output = base_output + sum_{i=0}^{curriculum_stage-1} scaling * adapter_i(x) +# where only the adapter for the current stage (index == curriculum_stage - 1) is trainable. +# Previous adapters contribute to the forward pass but their gradients are detached. + +# Note: +# - The `set_curriculum_stage(stage)` method must be called externally to switch between stages. +# - Logging messages indicate the module's dimensions and the freeze/unfreeze status of its parameters. +# """ + +# def __init__(self, in_features: int, out_features: int, bias: bool = True, +# r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, +# curriculum_stage_num: int = 1, lora_scale_init: float = 1.0) -> None: +# """ +# Overview: +# Initializes the CurriculumLoRALinear layer. If `curriculum_stage_num > 1`, +# it creates `curriculum_stage_num - 1` LoRA adapters. +# Arguments: +# - in_features (:obj:`int`): Size of each input sample. +# - out_features (:obj:`int`): Size of each output sample. +# - bias (:obj:`bool`): If True, adds a learnable bias to the output. +# - r (:obj:`int`): The rank of the LoRA decomposition. If 0, LoRA is disabled. +# - lora_alpha (:obj:`int`): The alpha parameter for LoRA scaling. 
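Stepping back from the commented-out legacy variant for a moment: the set of trainable parameters implied by the optimized `set_curriculum_stage` above can be summarized in a few lines of standalone code. This is a sketch with made-up adapter counts, not an import of the real class.

```Python
# Per-stage trainability implied by the optimized rule:
#   stage 0      : train W0 (and bias); everything else frozen
#   stage s >= 1 : train adapter Δθ_s, α0, and the scales α_j of *previous* stages (j < s)
num_adapters = 3                        # e.g. curriculum_stage_num = 4
for stage in range(num_adapters + 1):
    adapters = [(f"Δθ_{j + 1}", j == stage - 1) for j in range(num_adapters)]
    scales   = [(f"α_{j + 1}",  j <  stage - 1) for j in range(num_adapters)]
    trainable = [name for name, on in adapters + scales if on]
    if stage >= 1:
        trainable.insert(0, "α_0")      # base-weight scale becomes trainable after stage 0
    else:
        trainable.insert(0, "W0")       # only the base layer is trained in stage 0
    print(f"stage {stage}: trainable = {trainable}")
```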
+# - lora_dropout (:obj:`float`): The dropout probability for LoRA layers. +# - curriculum_stage_num (:obj:`int`): The total number of curriculum stages. +# - lora_scale_init (:obj:`float`): The initial value for the learnable scale of each adapter. +# """ +# super().__init__() +# self.in_features = in_features +# self.out_features = out_features +# self.r = r +# self.lora_alpha = lora_alpha +# self.scaling = lora_alpha / r if r > 0 else 1.0 +# self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity() +# self.curriculum_stage_num = curriculum_stage_num +# self.curriculum_stage = 0 # Initial stage is 0 + +# # Initialize base weights (part of the base transformer), trainable by default +# self.weight = nn.Parameter(torch.empty(out_features, in_features)) +# if bias: +# self.bias = nn.Parameter(torch.empty(out_features)) +# else: +# self.register_parameter('bias', None) +# nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) +# if self.bias is not None: +# fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) +# bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 +# nn.init.uniform_(self.bias, -bound, bound) + +# # Initialize LoRA adapters, which exist only if r > 0 and curriculum_stage_num > 1 +# self.adapters = nn.ModuleList() +# self.adapter_scales = nn.ModuleList() + +# if r > 0 and (curriculum_stage_num - 1) > 0: +# for _ in range(curriculum_stage_num - 1): +# adapter = nn.ParameterDict({ +# 'lora_A': nn.Parameter(torch.randn(r, in_features) * 0.01), +# 'lora_B': nn.Parameter(torch.zeros(out_features, r)) +# }) +# self.adapters.append(adapter) +# self.adapter_scales.append(LearnableScale(lora_scale_init, s_range=0.2)) + +# else: +# self.adapters = None + +# # Initially (stage 0), the base layer is trainable, and all adapters are frozen +# self.weight.requires_grad = True +# if self.bias is not None: +# self.bias.requires_grad = True +# if self.adapters is not None: +# for adapter in self.adapters: +# adapter['lora_A'].requires_grad = False +# adapter['lora_B'].requires_grad = False + +# def set_curriculum_stage(self, stage: int) -> None: +# """ +# Overview: +# Sets the current curriculum stage and updates the `requires_grad` status of parameters accordingly. +# - Stage 0: The base layer is trainable; all adapters are frozen. +# - Stage >= 1: The base layer is frozen. Only the current adapter (index = stage - 1) is trainable. +# Previous adapters contribute to the forward pass but do not propagate gradients. +# Arguments: +# - stage (:obj:`int`): The curriculum stage to set, in the range [0, curriculum_stage_num - 1]. 
+# """ +# assert 0 <= stage < self.curriculum_stage_num, f"Stage must be within [0, {self.curriculum_stage_num-1}]" +# self.curriculum_stage = stage + +# module_id = f"({self.in_features}x{self.out_features})" +# if stage == 0: +# self.weight.requires_grad = True +# if self.bias is not None: +# self.bias.requires_grad = True +# if self.adapters is not None: +# for adapter in self.adapters: +# adapter['lora_A'].requires_grad = False +# adapter['lora_B'].requires_grad = False +# logging.info(f"[CurriculumLoRALinear {module_id}] Stage 0: Base layer is trainable, all adapters are frozen.") +# else: +# # For stages > 0, freeze the base layer +# self.weight.requires_grad = False +# if self.bias is not None: +# self.bias.requires_grad = False + +# if self.adapters is not None: +# for idx, adapter in enumerate(self.adapters): +# is_current_adapter = (idx == stage - 1) +# adapter['lora_A'].requires_grad = is_current_adapter +# adapter['lora_B'].requires_grad = is_current_adapter +# status = "activated (trainable)" if is_current_adapter else "frozen (forward-only)" +# logging.info(f"[CurriculumLoRALinear {module_id}] Stage {stage}: Adapter {idx} is {status}.") + +# def forward(self, x: torch.Tensor) -> torch.Tensor: +# """ +# Overview: +# Performs the forward pass of the CurriculumLoRALinear layer. +# Arguments: +# - x (:obj:`torch.Tensor`): The input tensor. +# Returns: +# - torch.Tensor: The output tensor. +# """ +# baseline_out = F.linear(x, self.weight, self.bias) +# if self.curriculum_stage == 0 or self.adapters is None: +# return baseline_out + +# adapter_out = 0 +# # For the first `curriculum_stage` adapters, only the last one backpropagates. +# # Others are detached to contribute only to the forward pass. +# for idx in range(self.curriculum_stage): +# if idx >= len(self.adapters): +# break +# adapter = self.adapters[idx] +# lora_x = self.lora_dropout(x) +# out = F.linear(lora_x, adapter['lora_A']) +# out = F.linear(out, adapter['lora_B']) + +# scale = self.adapter_scales[idx]() + +# # NOTE: All adapter scales are currently trainable. +# if idx == self.curriculum_stage - 1: +# # Only the current adapter's output contributes to the gradient computation. +# adapter_out = adapter_out + self.scaling * out * scale +# else: +# # Outputs from previous adapters are detached. +# adapter_out = adapter_out + self.scaling * out.detach() * scale + +# return baseline_out + adapter_out + + +############################################## +# Helper function to wrap linear layers +############################################## def _maybe_wrap_linear(linear: nn.Linear, config, module_label: str) -> nn.Module: """ - 辅助函数:当 config.lora_r > 0 且 module_label 存在于 config.lora_target_modules 时, - 将传入的线性层替换为 LoRALinear,并复制原始权重数据。 - - module_label 的取值含义由上层逻辑定义,例如: - - 若 module_label 为 "attn",表示在 SelfAttention 中替换 k, q, v, proj 等层。 - - 若 module_label 为 "feed_forward",表示在 Transformer Block 的 MLP 中替换线性层。 + Overview: + A helper function that wraps an `nn.Linear` layer with `CurriculumLoRALinear` + if LoRA and curriculum learning are enabled for the specified module. + Arguments: + - linear (:obj:`nn.Linear`): The original linear layer to be potentially wrapped. + - config: The model configuration object. + - module_label (:obj:`str`): A label identifying the module type (e.g., "attn", "feed_forward"). + Returns: + - nn.Module: The wrapped `CurriculumLoRALinear` layer or the original `nn.Linear` layer. 
""" - if config.lora_r > 0 and module_label in config.lora_target_modules: - new_linear = LoRALinear( + use_curriculum_lora = ( + config.lora_r > 0 and + module_label in config.lora_target_modules and + getattr(config, "curriculum_stage_num", 1) > 1 + ) + if use_curriculum_lora: + new_linear = CurriculumLoRALinear( in_features=linear.in_features, out_features=linear.out_features, bias=(linear.bias is not None), r=config.lora_r, lora_alpha=config.lora_alpha, - lora_dropout=config.lora_dropout + lora_dropout=config.lora_dropout, + curriculum_stage_num=config.curriculum_stage_num, + lora_scale_init=config.lora_scale_init ) new_linear.weight.data.copy_(linear.weight.data) if linear.bias is not None: @@ -112,8 +407,39 @@ def _maybe_wrap_linear(linear: nn.Linear, config, module_label: str) -> nn.Modul else: return linear + +############################################## +# Helper function to set curriculum stage +############################################## + +def set_curriculum_stage(model: nn.Module, stage: int) -> None: + """ + Overview: + Recursively traverses all submodules of a given model, finds all instances + of `CurriculumLoRALinear`, and calls their `set_curriculum_stage` method. + This function is generic and can be applied to any model structure. + Arguments: + - model (:obj:`nn.Module`): The model to update (e.g., a Transformer or Vision Transformer). + - stage (:obj:`int`): The curriculum stage to set. + """ + count = 0 + for module in model.modules(): + if isinstance(module, CurriculumLoRALinear): + module.set_curriculum_stage(stage) + count += 1 + if count > 0: + logging.info(f"[Curriculum] Updated {count} CurriculumLoRALinear modules in {type(model).__name__} to stage {stage}.") + +# Alias for backward compatibility +set_curriculum_stage_for_transformer = set_curriculum_stage + + +############################################## +# Transformer Configuration +############################################## @dataclass class TransformerConfig: + """Configuration for the Transformer model.""" tokens_per_block: int max_blocks: int attention: str @@ -125,316 +451,193 @@ class TransformerConfig: embed_pdrop: float resid_pdrop: float attn_pdrop: float - - # for RoPE - rope_theta: float - max_seq_len: int - rotary_emb: bool = False - # LoRA 参数: + # LoRA parameters lora_r: int = 0 lora_alpha: int = 1 lora_dropout: float = 0.0 - # 指定哪些模块应用 LoRA,默认:attention 中的 k, q, v, proj 和 feed_forward 层(当非 moe 模型时) lora_target_modules: list = None - # Register Token 相关 + # Curriculum Learning parameters + # `curriculum_stage_num` is the total number of stages (e.g., 3 means stages 0, 1, 2) + curriculum_stage_num: int = 1 # 1 (base) + number of available LoRA adapters + min_stage0_iters: int = 10_000 # Minimum iterations for stage 0 + max_stage_iters: int = 20_000 # Maximum iterations per stage + lora_scale_init: float = 1.0 # Initial value for learnable adapter scales + + # Other configurations task_embed_option: str = "none" register_token_num: int = 4 register_token_shared: bool = True - # 其它配置项 gru_gating: bool = False moe_in_transformer: bool = False multiplication_moe_in_transformer: bool = False num_experts_of_moe_in_transformer: int = 1 @property - def max_tokens(self): + def max_tokens(self) -> int: + """Maximum number of tokens the model can handle.""" return self.tokens_per_block * self.max_blocks -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): - """ - Precompute the frequency components for the rotary positional embeddings. 
- - Arguments: - - dim (int): The dimension of the embedding. - - end (int): The length of the sequence for which frequencies are computed. - - theta (float): A scaling factor for the frequencies, default is 10000.0. - - Returns: - - freqs_cis (torch.Tensor): A tensor of complex numbers representing the precomputed frequencies. - """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device, dtype=torch.float32) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - """ - Reshape the frequency components for broadcasting with the input tensor. - - Arguments: - - freqs_cis (torch.Tensor): The frequency components tensor. - - x (torch.Tensor): The input tensor to which the frequencies will be applied. - - Returns: - - torch.Tensor: The reshaped frequency components tensor. - """ - # Reference: https://github.com/meta-llama/llama3/blob/main/llama/model.py#L61 - ndim = x.ndim - shape = [d if i in (0, 2, ndim - 1) else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary positional embeddings to the query and key tensors. - - Arguments: - - xq (torch.Tensor): The query tensor. - - xk (torch.Tensor): The key tensor. - - freqs_cis (torch.Tensor): The precomputed frequency components. - - Returns: - - Tuple[torch.Tensor, torch.Tensor]: The transformed query and key tensors. - - Note: - For more information on rotary positional embeddings, refer to the blog post: - https://spaces.ac.cn/archives/8265/ or paper https://arxiv.org/abs/2104.09864 - """ - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) - return xq_out.type_as(xq), xk_out.type_as(xk) - - class Transformer(nn.Module): """ - Transformer model class. - - Arguments: - - config (:obj:`TransformerConfig`): Configuration for the Transformer model. - - Attributes: - - config (:obj:`TransformerConfig`): Configuration object. - - drop (:obj:`nn.Dropout`): Dropout layer for embedding dropout. - - blocks (:obj:`nn.ModuleList`): List of Transformer blocks. - - ln_f (:obj:`nn.LayerNorm`): Layer normalization applied to the final output. + A Transformer model implementation. """ - def __init__(self, config: TransformerConfig, task_embed=None) -> None: + def __init__(self, config: TransformerConfig, task_embed: Optional[nn.Module] = None) -> None: + """ + Overview: + Initializes the Transformer model. + Arguments: + - config (:obj:`TransformerConfig`): The configuration object for the model. + - task_embed (:obj:`Optional[nn.Module]`): An optional module for generating task embeddings. 
+ """ super().__init__() self.config = config self.drop = nn.Dropout(config.embed_pdrop) self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)]) self.ln_f = nn.LayerNorm(config.embed_dim) - if self.config.rotary_emb: - freqs_cis = precompute_freqs_cis( - self.config.embed_dim // self.config.num_heads, - self.config.max_seq_len * 2, - self.config.rope_theta, - ) - self.register_buffer("freqs_cis", freqs_cis) self.task_embed = task_embed - self.task_embed_option = self.config.task_embed_option # Strategy for task embeddings - self.register_token_shared = True - - # TODO: 共享模式下,所有任务使用同一参数 - - if self.task_embed_option == "register_task_embed": - self.use_register_token = True # TODO - # Register token setup - self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 + self.task_embed_option = self.config.task_embed_option + self.use_register_token = (self.task_embed_option == "register_task_embed") - # 判断是否采用共享模式 + if self.use_register_token: + self.register_token_num = getattr(config, "register_token_num", 4) self.register_token_shared = getattr(config, "register_token_shared", True) + if self.register_token_shared: - # print(f'self.register_token_shared:{self.register_token_shared}') - # print(f'='*20) - # 共享模式:所有任务使用同一个 register_tokens 参数,形状为 (register_token_num, embed_dim) + # Shared mode: all tasks use the same register_tokens parameter. self.register_tokens = nn.Parameter(torch.empty(self.register_token_num, config.embed_dim)) nn.init.xavier_uniform_(self.register_tokens) else: - # 非共享模式:依赖外部传入的 task_embed 模块来生成 task embedding, - # 并通过 SimNorm 归一化后复制出 register token - self.task_embed = task_embed # 外部传入的模块,如 nn.Embedding - self.sim_norm = SimNorm(simnorm_dim=config.embed_dim) # Normalization for task embeddings - - else: - self.use_register_token = False # TODO - + # Non-shared mode: relies on the external `task_embed` module to generate + # task-specific embeddings, which are then normalized and expanded. + self.task_embed = task_embed + self.sim_norm = SimNorm(simnorm_dim=config.embed_dim) def add_register_tokens(self, sequences: torch.Tensor, task_id: int) -> torch.Tensor: """ - 将 register_token_num 个 Register Token 拼接到序列最前面。 - + Overview: + Prepends or appends register tokens to the input sequences. Arguments: - - sequences (:obj:`torch.Tensor`): (B, T, C) - - task_id (:obj:`int`): 当前任务的 ID - + - sequences (:obj:`torch.Tensor`): The input sequences, with shape (B, T, C). + - task_id (:obj:`int`): The ID of the current task. Returns: - - new_sequences (:obj:`torch.Tensor`): (B, T + register_token_num, C) + - torch.Tensor: The sequences with register tokens concatenated, shape (B, T + register_token_num, C). """ B = sequences.size(0) device = sequences.device if self.register_token_shared: - # 共享模式:直接使用同一组 register_tokens 参数 - # register_tokens 形状为 (register_token_num, embed_dim) - register_tokens = self.register_tokens - register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1) # 形状 (B, register_token_num, embed_dim) + # Shared mode: use the same set of register tokens for all batches. 
+ register_tokens = self.register_tokens.unsqueeze(0).expand(B, -1, -1) else: - # 非共享模式:依靠 task_embed 动态生成 task embedding,然后复制出 register tokens - task_embedding = self.task_embed(torch.tensor([task_id], device=device)) # (1, embed_dim) - task_embedding = self.sim_norm(task_embedding.view(1, -1)).view(-1) # (embed_dim,) - register_tokens = task_embedding.unsqueeze(0).expand(self.register_token_num, -1) # (register_token_num, embed_dim) - register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1) # (B, register_token_num, embed_dim) - - new_sequences = torch.cat([sequences, register_tokens], dim=1) # 在序列末尾拼接 register tokens (B, register_token_num + T, C) + # Non-shared mode: dynamically generate task embedding and expand it. + task_embedding = self.task_embed(torch.tensor([task_id], device=device)) + task_embedding = self.sim_norm(task_embedding.view(1, -1)).view(-1) + register_tokens = task_embedding.unsqueeze(0).expand(self.register_token_num, -1) + register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1) + + # Concatenate register tokens at the end of the sequence. + new_sequences = torch.cat([sequences, register_tokens], dim=1) return new_sequences - def remove_register_tokens_from_kv(self, past_keys_values: KeysValues) -> None: + def remove_register_tokens_from_kv(self, past_keys_values: Optional[KeysValues]) -> None: """ - 移除所有层 KV 中最前面的 register_token_num 个 token,用于在 forward() 结束时调用。 + Overview: + Removes the register tokens from the key-value cache of all layers. + This is called at the end of the forward pass during inference. + Arguments: + - past_keys_values (:obj:`Optional[KeysValues]`): The key-value cache. """ - if past_keys_values is None: - return - past_keys_values.remove_register_tokens(self.register_token_num) + if past_keys_values is not None: + past_keys_values.remove_register_tokens(self.register_token_num) def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: """ - Generate a placeholder for keys and values. - + Overview: + Generates a placeholder for the key-value cache. Arguments: - - n (:obj:`int`): Batch size. - - max_tokens (:obj:`int`): Maximum number of tokens in the sequence. - + - n (:obj:`int`): The batch size. + - max_tokens (:obj:`int`): The maximum number of tokens in the sequence. Returns: - - KeysValues: An object containing empty keys and values. + - KeysValues: An object containing empty tensors for keys and values. """ - device = self.ln_f.weight.device # Assumption: All submodules are on the same device + device = self.ln_f.weight.device return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device) - def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0) -> torch.Tensor: + def forward( + self, + sequences: torch.Tensor, + past_keys_values: Optional[KeysValues] = None, + valid_context_lengths: Optional[torch.Tensor] = None, + task_id: int = 0, + start_pos: int = 0 + ) -> torch.Tensor: """ - Forward pass of the Transformer model. - + Overview: + Performs the forward pass of the Transformer model. Arguments: - - sequences (:obj:`torch.Tensor`): Input tensor of shape (batch_size, seq_length, embed_dim). - - past_keys_values (:obj:`Optional[KeysValues]`): Precomputed keys and values for faster generation (default: None). - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid lengths of context for masking (default: None). 
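To make the register-token flow above concrete: shared learnable tokens are appended to the end of a (B, T, C) sequence before the transformer blocks run, and sliced off again afterwards so callers see the original length. The sketch below is framework-free and uses illustrative shapes only.

```Python
import torch
import torch.nn as nn

B, T, C, register_token_num = 2, 6, 32, 4
register_tokens = nn.Parameter(torch.empty(register_token_num, C))
nn.init.xavier_uniform_(register_tokens)

sequences = torch.randn(B, T, C)
# Append the shared register tokens to every sequence in the batch.
with_registers = torch.cat([sequences, register_tokens.unsqueeze(0).expand(B, -1, -1)], dim=1)
print(with_registers.shape)                # torch.Size([2, 10, 32])

# ... the transformer blocks would process the length-(T + register_token_num) sequence here ...

# Strip the register positions so downstream code sees the original length again.
output = with_registers[:, :-register_token_num, :]
print(output.shape)                        # torch.Size([2, 6, 32])
```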
- - start_pos (:obj:`int`): Starting position for rotary embeddings (default: 0). - + - sequences (:obj:`torch.Tensor`): The input tensor of shape (B, T, C). + - past_keys_values (:obj:`Optional[KeysValues]`): An optional cache for keys and values to speed up inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Tensor indicating the valid length of the context for each sample. + - task_id (:obj:`int`): The ID of the current task. + - start_pos (:obj:`int`): The starting position for the current sequence (used with kv-caching). Returns: - - 输出张量 (B, T + register_token_num, C) 或 (B, T, C),视是否添加 Register Token 而定 + - torch.Tensor: The output tensor of shape (B, T, C). """ - seqlen = sequences.shape[1] - # If using Rotary Position Embeddings (RoPE), slice the frequency components accordingly - if self.config.rotary_emb: - if isinstance(start_pos, (int, float, np.integer)): - # In the reanalyze_phase or reset stage in collection/evaluation phase, create a tensor filled with start_pos, expanded to match the batch size, and adjust for sequence type, e.g., start_pos=2. - start_pos_tensor = torch.full((sequences.shape[0],), int(start_pos), device=sequences.device) - elif isinstance(start_pos, (list, np.ndarray, torch.Tensor)): - if isinstance(start_pos[0], (np.ndarray, torch.Tensor, list)): - # In the training phase, flatten start_pos, take the first element, convert to tensor, e.g., start_pos=[array([ 8, 10, 12, 14, 16]), array([12, 14, 16, 18, 20])] - start_pos_tensor = torch.as_tensor( - [x.reshape(-1)[0].item() for x in start_pos], # Force flatten and take the first element - device=sequences.device - ) - elif isinstance(start_pos[0], (int, float, np.integer)): - # In the collection/evaluation phase, e.g., start_pos = [0, 0, 0, 0, 0, 0, 0, 0] - start_pos_tensor = torch.as_tensor([int(x) for x in start_pos], device=sequences.device) - else: - raise ValueError("start_pos must be an int, float, list, numpy array or torch.Tensor.") - - # TODO: Determine how to handle cases when episode length exceeds max_seq_len - # Use modulo operation to ensure start_pos does not exceed max_seq_len - start_pos_tensor = torch.remainder(start_pos_tensor, self.config.max_seq_len) - # Convert each sample's start_pos to a list - start_pos_list = start_pos_tensor.tolist() - # For each sample, slice the corresponding range of freqs_cis based on start_pos - freqs_cis_slices = [self.freqs_cis[int(pos): int(pos) + seqlen] for pos in start_pos_list] - freqs_cis = torch.stack(freqs_cis_slices) - - if freqs_cis.ndim == 3 and freqs_cis.shape[1] == 1: - # Convert shape [seq_len, 1, num_pairs] to [seq_len, num_pairs] - freqs_cis = freqs_cis.squeeze(1) - else: - freqs_cis = None - - # print(f"freqs_cis.shape:{freqs_cis.shape}") + if self.use_register_token: + sequences = self.add_register_tokens(sequences, task_id) - # Ensure past keys and values match the number of transformer blocks - assert past_keys_values is None or len(past_keys_values) == len(self.blocks) - # Apply dropout to the input sequences x = self.drop(sequences) - # Pass through each transformer block + for i, block in enumerate(self.blocks): - x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths, freqs_cis) - # Apply final layer normalization - x = self.ln_f(x) + kv_cache_layer = None if past_keys_values is None else past_keys_values[i] + x = block(x, kv_cache_layer, valid_context_lengths) - # 如果 past_keys_values 不为 None,说明是推理阶段,此时我们需要把 KV 缓存中 - # 尾部多加的 Register Token 移除,以保证外键信息一致,不用修改外部逻辑 - # if 
self.use_register_token and (past_keys_values is not None): - if self.use_register_token: - self.remove_register_tokens_from_kv(past_keys_values) + x = self.ln_f(x) - # TODO if self.use_register_token: - # import ipdb; ipdb.set_trace() + # During inference, remove register tokens from the KV cache to maintain consistency + # for external logic that does not expect them. + if past_keys_values is not None: + self.remove_register_tokens_from_kv(past_keys_values) + + # TODO: Remove register tokens from the final output to match the input sequence length. x = x[:, :-self.register_token_num, :] return x - - class Block(nn.Module): """ - Transformer block class. - - Arguments: - config (:obj:`TransformerConfig`): Configuration for the Transformer block. - - Attributes: - - gru_gating (:obj:`bool`): Flag to use GRU gating mechanism. - - gru_bias (:obj:`float`): Bias for the GRU gating mechanism. - - gate1 (:obj:`Optional[GRUGatingUnit]`): First GRU gating unit (if GRU gating is enabled). - - gate2 (:obj:`Optional[GRUGatingUnit]`): Second GRU gating unit (if GRU gating is enabled). - - ln1 (:obj:`nn.LayerNorm`): Layer normalization before the attention layer. - - ln2 (:obj:`nn.LayerNorm`): Layer normalization before the MLP. - - attn (:obj:`SelfAttention`): Self-attention mechanism. - - mlp (:obj:`nn.Sequential`): Multi-layer perceptron. + A single Transformer block, consisting of self-attention and a feed-forward network. """ def __init__(self, config: TransformerConfig) -> None: + """ + Overview: + Initializes a Transformer block. + Arguments: + - config (:obj:`TransformerConfig`): The configuration object for the block. + """ super().__init__() - # NOTE: GRU gating as in GTrXL self.gru_gating = config.gru_gating - self.gru_bias = 2.0 if self.gru_gating: - self.gate1 = GRUGatingUnit(config.embed_dim, self.gru_bias) - self.gate2 = GRUGatingUnit(config.embed_dim, self.gru_bias) + # As in GTrXL, for stabilizing training with recurrence + self.gate1 = GRUGatingUnit(config.embed_dim, bias_init=2.0) + self.gate2 = GRUGatingUnit(config.embed_dim, bias_init=2.0) self.ln1 = nn.LayerNorm(config.embed_dim) self.ln2 = nn.LayerNorm(config.embed_dim) self.attn = SelfAttention(config) if config.moe_in_transformer: - # 创Create multiple independent MLP instances + from .moe import MoELayer + # Create multiple independent MLP instances as experts self.experts = nn.ModuleList([ nn.Sequential( nn.Linear(config.embed_dim, 4 * config.embed_dim), @@ -443,38 +646,28 @@ def __init__(self, config: TransformerConfig) -> None: nn.Dropout(config.resid_pdrop), ) for _ in range(config.num_experts_of_moe_in_transformer) ]) - self.feed_forward = MoeLayer( + self.feed_forward = MoELayer( + config, experts=self.experts, gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), - num_experts_per_tok=1, + num_experts_per_tok=config.num_experts_per_tok, ) - - print("="*20) - print(f'use moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') - print("="*20) + logging.info(f"Using MoE in transformer feed-forward with {config.num_experts_of_moe_in_transformer} experts.") elif config.multiplication_moe_in_transformer: + from .moe import MoELayer, MultiplicationFeedForward # Create multiple FeedForward instances for multiplication-based MoE self.experts = nn.ModuleList([ MultiplicationFeedForward(config) for _ in range(config.num_experts_of_moe_in_transformer) ]) - - self.feed_forward = MoeLayer( + self.feed_forward = MoELayer( + config, experts=self.experts, 
gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), - num_experts_per_tok=1, + num_experts_per_tok=config.num_experts_per_tok, ) - - print("="*20) - print(f'use multiplication moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') - print("="*20) + logging.info(f"Using Multiplication MoE in transformer feed-forward with {config.num_experts_of_moe_in_transformer} experts.") else: - # self.feed_forward = nn.Sequential( - # nn.Linear(config.embed_dim, 4 * config.embed_dim), - # nn.GELU(approximate='tanh'), - # nn.Linear(4 * config.embed_dim, config.embed_dim), - # nn.Dropout(config.resid_pdrop), - # ) - # 普通的 MLP,若在 feed_forward 上启用 LoRA,则对其中线性层进行包装 + # Standard MLP, with linear layers potentially wrapped for LoRA. self.feed_forward = nn.Sequential( _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim), config, "feed_forward"), nn.GELU(approximate='tanh'), @@ -483,222 +676,178 @@ def __init__(self, config: TransformerConfig) -> None: ) def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None, freqs_cis: torch.Tensor = None) -> torch.Tensor: + valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ - Forward pass of the Transformer block. - + Overview: + Performs the forward pass of the Transformer block. Arguments: - x (:obj:`torch.Tensor`): Input tensor of shape (batch_size, seq_length, embed_dim). - - past_keys_values (:obj:`Optional[KeysValues]`): Precomputed keys and values for faster generation (default: None). - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid lengths of context for masking (default: None). - - freqs_cis (:obj:`torch.Tensor`): Frequency components for rotary position embeddings, used to modulate the attention mechanism (default: None). - + - past_keys_values (:obj:`Optional[KeysValues]`): Precomputed keys and values for faster generation. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid lengths of context for masking. Returns: - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim). """ - x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, freqs_cis) + attn_output = self.attn(self.ln1(x), past_keys_values, valid_context_lengths) if self.gru_gating: - x = self.gate1(x, x_attn) - x = self.gate2(x, self.feed_forward(self.ln2(x))) + x = self.gate1(x, attn_output) + ff_output = self.feed_forward(self.ln2(x)) + x = self.gate2(x, ff_output) else: - x = x + x_attn + x = x + attn_output x = x + self.feed_forward(self.ln2(x)) - return x class SelfAttention(nn.Module): """ - Implements self-attention mechanism for transformers. - - Arguments: - config (:obj:`TransformerConfig`): Configuration object containing hyperparameters. - - Attributes: - - config (:obj:`TransformerConfig`): Stores the configuration for the self-attention module. - - num_heads (:obj:`int`): Number of attention heads. - - key (:obj:`nn.Linear`): Linear layer to project input to key vectors. - - query (:obj:`nn.Linear`): Linear layer to project input to query vectors. - - value (:obj:`nn.Linear`): Linear layer to project input to value vectors. - - attn_drop (:obj:`nn.Dropout`): Dropout layer for attention weights. - - resid_drop (:obj:`nn.Dropout`): Dropout layer for residual connection. - - proj (:obj:`nn.Linear`): Final linear layer for projection. - - mask (:obj:`torch.Tensor`): Mask tensor for causal or block-causal attention. 
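For readers unfamiliar with the MoE feed-forward used above: the snippet below is a deliberately simplified, standalone top-1 router over a few expert MLPs. It is not the `MoELayer` from `.moe` (whose interface takes the config, the experts, the gate, and `num_experts_per_tok` as shown above); it only illustrates the routing idea.

```Python
import torch
import torch.nn as nn

embed_dim, num_experts = 32, 4
experts = nn.ModuleList([
    nn.Sequential(nn.Linear(embed_dim, 4 * embed_dim), nn.GELU(approximate='tanh'),
                  nn.Linear(4 * embed_dim, embed_dim))
    for _ in range(num_experts)
])
gate = nn.Linear(embed_dim, num_experts, bias=False)

x = torch.randn(2, 6, embed_dim)           # (B, T, C)
tokens = x.reshape(-1, embed_dim)          # route each token independently
scores = gate(tokens)                      # (B*T, num_experts)
top1 = scores.argmax(dim=-1)               # one expert per token (num_experts_per_tok = 1)

out = torch.zeros_like(tokens)
for e in range(num_experts):
    mask = top1 == e
    if mask.any():
        out[mask] = experts[e](tokens[mask])
# A production MoE would additionally weight each expert output by its softmax gate score.
print(out.reshape(x.shape).shape)          # torch.Size([2, 6, 32])
```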
+ Implements the self-attention mechanism for a Transformer. """ + def __init__(self, config: TransformerConfig) -> None: + """ + Overview: + Initializes the SelfAttention module. + Arguments: + - config (:obj:`TransformerConfig`): The configuration object for the attention module. + """ super().__init__() assert config.embed_dim % config.num_heads == 0, "Embedding dimension must be divisible by number of heads." self.config = config - + self.num_heads = config.num_heads + self.task_embed_option = self.config.task_embed_option - if self.task_embed_option == "register_task_embed": - self.use_register_token = True # TODO - # Register token setup - self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 - else: - self.use_register_token = False # TODO + self.use_register_token = (self.task_embed_option == "register_task_embed") + if self.use_register_token: + self.register_token_num = getattr(config, "register_token_num", 4) - self.num_heads = config.num_heads - - if config.lora_r > 0 and ("attn" in config.lora_target_modules): - self.key = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") - self.query = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") - self.value = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") - self.proj = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") - else: - self.key = nn.Linear(config.embed_dim, config.embed_dim) - self.query = nn.Linear(config.embed_dim, config.embed_dim) - self.value = nn.Linear(config.embed_dim, config.embed_dim) - self.proj = nn.Linear(config.embed_dim, config.embed_dim) + # Wrap linear layers if LoRA is enabled for the attention module + self.key = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") + self.query = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") + self.value = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") + self.proj = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") self.attn_drop = nn.Dropout(config.attn_pdrop) self.resid_drop = nn.Dropout(config.resid_pdrop) - if self.use_register_token: # ======= TODO ======== - causal_mask = torch.tril(torch.ones(config.max_tokens+self.register_token_num*5, config.max_tokens+self.register_token_num*5)) - else: - causal_mask = torch.tril(torch.ones(config.max_tokens, config.max_tokens)) - + # TODO: The mask size is conservatively large to accommodate register tokens. + # This could be made more dynamic. + mask_size = config.max_tokens + if self.use_register_token: + mask_size += self.register_token_num * 5 + causal_mask = torch.tril(torch.ones(mask_size, mask_size)) self.register_buffer('mask', causal_mask) - #@profile def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None, freqs_cis: torch.Tensor = None) -> torch.Tensor: + valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ - Forward pass for the self-attention mechanism. - + Overview: + Performs the forward pass for the self-attention mechanism. Arguments: - - x (:obj:`torch.Tensor`): Input tensor of shape (B, T, C) where B is batch size, - T is sequence length, and C is embedding dimension. + - x (:obj:`torch.Tensor`): Input tensor of shape (B, T, C). - kv_cache (:obj:`Optional[KeysValues]`): Optional key-value cache for faster inference. 
- valid_context_lengths (:obj:`Optional[torch.Tensor]`): Optional tensor containing valid context lengths. - - freqs_cis (:obj:`torch.Tensor`): Frequency components for rotary position embeddings, used to modulate the attention mechanism (default: None). - Returns: - torch.Tensor: Output tensor of shape (B, T, C). """ B, T, C = x.size() + head_size = C // self.num_heads + + past_len = 0 if kv_cache is not None: - b, nh, L, c = kv_cache.shape - try: - assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions do not match input dimensions." - except Exception as e: - print('debug') - else: - L = 0 + past_len = kv_cache.shape[2] - q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) - k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) - v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) - - if self.config.rotary_emb: - q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) + q = self.query(x).view(B, T, self.num_heads, head_size).transpose(1, 2) + k = self.key(x).view(B, T, self.num_heads, head_size).transpose(1, 2) + v = self.value(x).view(B, T, self.num_heads, head_size).transpose(1, 2) if kv_cache is not None: - # import ipdb; ipdb.set_trace() - kv_cache.update(k, v) # time occupancy 21% - k, v = kv_cache.get() # time occupancy 5% + kv_cache.update(k, v) + k, v = kv_cache.get() + current_len = k.size(2) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + # Construct the attention mask + mask = self.mask[past_len:past_len + T, :current_len] + if valid_context_lengths is not None: - # Final mask.shape: (B, T, L + T) - # L is the context length, T is the current input length, - # valid_context_lengths is the valid length at the end of the context. - mask = torch.zeros(B, T, L + T, device=att.device) - # For each sample, set the invalid parts to 0 based on its valid length. + # This logic is for a specific use case and may need adjustment. + # It creates a custom mask for each item in the batch. + batch_mask = torch.zeros(B, T, current_len, device=att.device) for i in range(B): - mask[i] = self.mask[L:L + T, :L + T].clone() - mask[i, :, :(L - valid_context_lengths[i])] = 0 # Set invalid parts to 0. - # Adjust mask dimensions to match the last two dimensions of att. 
- # (B, T, L + T) -> (B, 1, T, L + T) -> (B, num_heads, T, L + T) - mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1) - else: - # mask.shape: (T, L + T) - mask = self.mask[L:L + T, :L + T] - - # import ipdb; ipdb.set_trace() + batch_mask[i] = mask.clone() + # Zero out attention to invalid past context + batch_mask[i, :, :(past_len - valid_context_lengths[i])] = 0 + mask = batch_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1) - # Adjust mask for register tokens if applicable + # Adjust mask for register tokens if they are in use if self.use_register_token and self.register_token_num > 0: - # Allow all positions to attend to the last `register_token_num` tokens - register_mask = mask.clone() # (T, L + T) - register_mask[-self.register_token_num:, :] = 1 # Allow register tokens to see all positions - register_mask[:, -self.register_token_num:] = 1 # Allow all positions to see register tokens + # Allow all positions to attend to register tokens and vice-versa + register_mask = mask.clone() + # Register tokens are at the end of the sequence + register_indices_start = current_len - self.register_token_num + register_mask[..., register_indices_start:] = 1 # All can see registers + # This part is more complex if T is not the full sequence length + if T > self.register_token_num: + # Only the actual register tokens in the current input `x` can see everything + register_mask[..., -self.register_token_num:, :] = 1 mask = register_mask - + if kv_cache is not None: - # =============TODO============= - # import ipdb; ipdb.set_trace() - b, nh, new_L, c = kv_cache.shape # new_L可能小于L + T - mask = mask[:,-new_L:] - # else: - # import ipdb; ipdb.set_trace() + # Ensure mask dimensions match the potentially smaller KV cache length + new_L = kv_cache.shape[2] + mask = mask[..., :new_L] - - # att.shape: (B, num_heads, T, L + T) att = att.masked_fill(mask == 0, float('-inf')) - att = F.softmax(att, dim=-1) att = self.attn_drop(att) - # import ipdb; ipdb.set_trace() - y = att @ v # (B, num_heads, T, L + T) x (B, num_heads, L + T, head_size) -> (B, num_heads, T, head_size) - - y = rearrange(y, 'b h t e -> b t (h e)') # Combine the heads back together (B, T, embed_dim) + y = att @ v + y = rearrange(y, 'b h t e -> b t (h e)') y = self.resid_drop(self.proj(y)) - - return y @torch.no_grad() def get_attention_map(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ - Compute the attention map for the input sequence. This is useful for visualization purposes. - More details can be found in visualizing_utils.py. - + Overview: + Computes the attention map for visualization, without computing the final output. Arguments: - x (:obj:`torch.Tensor`): Input sequence with shape (B, T, C). - - kv_cache (:obj:`Optional[KeysValues]`): Cached keys and values for supporting long sequence inference. - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths for handling variable-length contexts. - + - kv_cache (:obj:`Optional[KeysValues]`): Cached keys and values for long sequence inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths for variable-length inputs. Returns: - - torch.Tensor: Attention map with shape (B, nh, T, L + T), representing the distribution of attention. + - torch.Tensor: Attention map of shape (B, num_heads, T, L + T). 
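As a small illustration of the mask slicing used in the attention above: with `past_len` cached positions and `T` new query positions, taking rows `[past_len : past_len + T]` and columns `[:past_len + T]` of a full lower-triangular mask lets each new token attend to the whole cache plus the causal prefix of the new block. Toy sizes only, no KV-cache object involved.

```Python
import torch

max_tokens = 8
full_mask = torch.tril(torch.ones(max_tokens, max_tokens))

past_len, T = 3, 2                          # 3 tokens already cached, 2 new query positions
current_len = past_len + T
mask = full_mask[past_len:past_len + T, :current_len]
print(mask)
# tensor([[1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])
# Row 0 (the first new token) sees the 3 cached tokens and itself;
# row 1 additionally sees the first new token.

scores = torch.randn(T, current_len)
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = torch.softmax(scores, dim=-1)        # -inf positions receive zero attention weight
print(attn)
```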
""" B, T, C = x.size() + head_size = C // self.num_heads + + past_len = 0 if kv_cache is not None: - b, nh, L, c = kv_cache.shape - assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions are inconsistent with input dimensions." - else: - L = 0 + past_len = kv_cache.shape[2] - # Compute query, key, and value projections - q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) - k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) - v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.num_heads, head_size).transpose(1, 2) + k = self.key(x).view(B, T, self.num_heads, head_size).transpose(1, 2) + v = self.value(x).view(B, T, self.num_heads, head_size).transpose(1, 2) if kv_cache is not None: - # Update the kv_cache with the new keys and values kv_cache.update(k, v) k, v = kv_cache.get() - # Compute the attention scores + current_len = k.size(2) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + mask = self.mask[past_len:past_len + T, :current_len] if valid_context_lengths is not None: - mask = torch.zeros(B, T, L + T, device=att.device) + batch_mask = torch.zeros(B, T, current_len, device=att.device) for i in range(B): - # Create attention mask for each batch - mask[i] = self.mask[L:L + T, :L + T].clone() - mask[i, :, :(L - valid_context_lengths[i])] = 0 - mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1) - else: - mask = self.mask[L:L + T, :L + T] + batch_mask[i] = mask.clone() + batch_mask[i, :, :(past_len - valid_context_lengths[i])] = 0 + mask = batch_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1) - # Apply the attention mask att = att.masked_fill(mask == 0, float('-inf')) att = F.softmax(att, dim=-1) diff --git a/lzero/model/unizero_world_models/transformer_no-lora.py b/lzero/model/unizero_world_models/transformer_no-lora.py deleted file mode 100644 index e0f0f0c0b..000000000 --- a/lzero/model/unizero_world_models/transformer_no-lora.py +++ /dev/null @@ -1,477 +0,0 @@ -""" -Modified from https://github.com/karpathy/nanoGPT -""" - -import math -from dataclasses import dataclass -from typing import Optional - -import torch -import torch.nn as nn -from ding.torch_utils.network import GRUGatingUnit -from einops import rearrange -from torch.nn import functional as F - -from .kv_caching import KeysValues -from .moe import MoeLayer, MultiplicationFeedForward -from line_profiler import line_profiler -from lzero.model.common import SimNorm - - -@dataclass -class TransformerConfig: - tokens_per_block: int - max_blocks: int - attention: str - - num_layers: int - num_heads: int - embed_dim: int - - embed_pdrop: float - resid_pdrop: float - attn_pdrop: float - - @property - def max_tokens(self): - return self.tokens_per_block * self.max_blocks - - -class Transformer(nn.Module): - """ - Transformer model class. - - Arguments: - config (:obj:`TransformerConfig`): Configuration for the Transformer model. - - Attributes: - - config (:obj:`TransformerConfig`): Configuration object. - - drop (:obj:`nn.Dropout`): Dropout layer for embedding dropout. - - blocks (:obj:`nn.ModuleList`): List of Transformer blocks. - - ln_f (:obj:`nn.LayerNorm`): Layer normalization applied to the final output. 
- """ - - def __init__(self, config: TransformerConfig, task_embed=None) -> None: - super().__init__() - self.config = config - self.drop = nn.Dropout(config.embed_pdrop) - self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)]) - self.ln_f = nn.LayerNorm(config.embed_dim) - - self.task_embed = task_embed - self.task_embed_option = self.config.task_embed_option # Strategy for task embeddings - self.register_token_shared = True - - # TODO: 共享模式下,所有任务使用同一参数 - - if self.task_embed_option == "register_task_embed": - self.use_register_token = True # TODO - # Register token setup - self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 - - # 判断是否采用共享模式 - self.register_token_shared = getattr(config, "register_token_shared", True) - if self.register_token_shared: - # print(f'self.register_token_shared:{self.register_token_shared}') - # print(f'='*20) - # 共享模式:所有任务使用同一个 register_tokens 参数,形状为 (register_token_num, embed_dim) - self.register_tokens = nn.Parameter(torch.empty(self.register_token_num, config.embed_dim)) - nn.init.xavier_uniform_(self.register_tokens) - else: - # 非共享模式:依赖外部传入的 task_embed 模块来生成 task embedding, - # 并通过 SimNorm 归一化后复制出 register token - self.task_embed = task_embed # 外部传入的模块,如 nn.Embedding - self.sim_norm = SimNorm(simnorm_dim=config.embed_dim) # Normalization for task embeddings - - else: - self.use_register_token = False # TODO - - - def add_register_tokens(self, sequences: torch.Tensor, task_id: int) -> torch.Tensor: - """ - 将 register_token_num 个 Register Token 拼接到序列最前面。 - - Arguments: - - sequences (:obj:`torch.Tensor`): (B, T, C) - - task_id (:obj:`int`): 当前任务的 ID - - Returns: - - new_sequences (:obj:`torch.Tensor`): (B, T + register_token_num, C) - """ - B = sequences.size(0) - device = sequences.device - - if self.register_token_shared: - # 共享模式:直接使用同一组 register_tokens 参数 - # register_tokens 形状为 (register_token_num, embed_dim) - register_tokens = self.register_tokens - register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1) # 形状 (B, register_token_num, embed_dim) - else: - # 非共享模式:依靠 task_embed 动态生成 task embedding,然后复制出 register tokens - task_embedding = self.task_embed(torch.tensor([task_id], device=device)) # (1, embed_dim) - task_embedding = self.sim_norm(task_embedding.view(1, -1)).view(-1) # (embed_dim,) - register_tokens = task_embedding.unsqueeze(0).expand(self.register_token_num, -1) # (register_token_num, embed_dim) - register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1) # (B, register_token_num, embed_dim) - - new_sequences = torch.cat([sequences, register_tokens], dim=1) # 在序列末尾拼接 register tokens (B, register_token_num + T, C) - return new_sequences - - def remove_register_tokens_from_kv(self, past_keys_values: KeysValues) -> None: - """ - 移除所有层 KV 中最前面的 register_token_num 个 token,用于在 forward() 结束时调用。 - """ - if past_keys_values is None: - return - past_keys_values.remove_register_tokens(self.register_token_num) - - def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: - """ - Generate a placeholder for keys and values. - - Arguments: - - n (:obj:`int`): Batch size. - - max_tokens (:obj:`int`): Maximum number of tokens in the sequence. - - Returns: - - KeysValues: An object containing empty keys and values. 
- """ - device = self.ln_f.weight.device # Assumption: All submodules are on the same device - return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device) - - - #@profile - def forward( - self, - sequences: torch.Tensor, # (B, T, C) - past_keys_values: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None, - task_id: int = 0 - ) -> torch.Tensor: - """ - Forward pass of the Transformer model. - - Arguments: - - sequences (:obj:`torch.Tensor`): (B, T, C) - - past_keys_values (:obj:`Optional[KeysValues]`): 缓存,用于推理时加速 - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): 某些场景下可用的有效上下文长度 - - task_id (:obj:`int`): 任务 ID - - Returns: - - 输出张量 (B, T + register_token_num, C) 或 (B, T, C),视是否添加 Register Token 而定 - """ - # 若使用 Register Token,则将其拼到序列最前面 - # 训练阶段和推理阶段都统一处理 - if self.use_register_token: - sequences = self.add_register_tokens(sequences, task_id) - - # 接入 dropout - x = self.drop(sequences) - - # 逐层调用 - for i, block in enumerate(self.blocks): - x = block(x, - None if past_keys_values is None else past_keys_values[i], - valid_context_lengths) - - # 最后层 LN - x = self.ln_f(x) - - # 如果 past_keys_values 不为 None,说明是推理阶段,此时我们需要把 KV 缓存中 - # 尾部多加的 Register Token 移除,以保证外键信息一致,不用修改外部逻辑 - # if self.use_register_token and (past_keys_values is not None): - if self.use_register_token: - self.remove_register_tokens_from_kv(past_keys_values) - - # TODO - if self.use_register_token: - # import ipdb; ipdb.set_trace() - x = x[:, :-self.register_token_num, :] - - return x - - - - -class Block(nn.Module): - """ - Transformer block class. - - Arguments: - config (:obj:`TransformerConfig`): Configuration for the Transformer block. - - Attributes: - - gru_gating (:obj:`bool`): Flag to use GRU gating mechanism. - - gru_bias (:obj:`float`): Bias for the GRU gating mechanism. - - gate1 (:obj:`Optional[GRUGatingUnit]`): First GRU gating unit (if GRU gating is enabled). - - gate2 (:obj:`Optional[GRUGatingUnit]`): Second GRU gating unit (if GRU gating is enabled). - - ln1 (:obj:`nn.LayerNorm`): Layer normalization before the attention layer. - - ln2 (:obj:`nn.LayerNorm`): Layer normalization before the MLP. - - attn (:obj:`SelfAttention`): Self-attention mechanism. - - mlp (:obj:`nn.Sequential`): Multi-layer perceptron. 
- """ - - def __init__(self, config: TransformerConfig) -> None: - super().__init__() - # NOTE: GRU gating as in GTrXL - self.gru_gating = config.gru_gating - self.gru_bias = 2.0 - if self.gru_gating: - self.gate1 = GRUGatingUnit(config.embed_dim, self.gru_bias) - self.gate2 = GRUGatingUnit(config.embed_dim, self.gru_bias) - - self.ln1 = nn.LayerNorm(config.embed_dim) - self.ln2 = nn.LayerNorm(config.embed_dim) - self.attn = SelfAttention(config) - if config.moe_in_transformer: - # 创Create multiple independent MLP instances - self.experts = nn.ModuleList([ - nn.Sequential( - nn.Linear(config.embed_dim, 4 * config.embed_dim), - nn.GELU(approximate='tanh'), - nn.Linear(4 * config.embed_dim, config.embed_dim), - nn.Dropout(config.resid_pdrop), - ) for _ in range(config.num_experts_of_moe_in_transformer) - ]) - - self.feed_forward = MoeLayer( - experts=self.experts, - gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), - num_experts_per_tok=1, - ) - - print("="*20) - print(f'use moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') - print("="*20) - elif config.multiplication_moe_in_transformer: - # Create multiple FeedForward instances for multiplication-based MoE - self.experts = nn.ModuleList([ - MultiplicationFeedForward(config) for _ in range(config.num_experts_of_moe_in_transformer) - ]) - - self.feed_forward = MoeLayer( - experts=self.experts, - gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), - num_experts_per_tok=1, - ) - - print("="*20) - print(f'use multiplication moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') - print("="*20) - else: - self.feed_forward = nn.Sequential( - nn.Linear(config.embed_dim, 4 * config.embed_dim), - nn.GELU(approximate='tanh'), - nn.Linear(4 * config.embed_dim, config.embed_dim), - nn.Dropout(config.resid_pdrop), - ) - - def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - Forward pass of the Transformer block. - - Arguments: - - x (:obj:`torch.Tensor`): Input tensor of shape (batch_size, seq_length, embed_dim). - - past_keys_values (:obj:`Optional[KeysValues]`): Precomputed keys and values for faster generation (default: None). - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid lengths of context for masking (default: None). - - Returns: - - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim). - """ - x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths) - if self.gru_gating: - x = self.gate1(x, x_attn) - x = self.gate2(x, self.feed_forward(self.ln2(x))) - else: - x = x + x_attn - x = x + self.feed_forward(self.ln2(x)) - - return x - - -class SelfAttention(nn.Module): - """ - Implements self-attention mechanism for transformers. - - Arguments: - config (:obj:`TransformerConfig`): Configuration object containing hyperparameters. - - Attributes: - - config (:obj:`TransformerConfig`): Stores the configuration for the self-attention module. - - num_heads (:obj:`int`): Number of attention heads. - - key (:obj:`nn.Linear`): Linear layer to project input to key vectors. - - query (:obj:`nn.Linear`): Linear layer to project input to query vectors. - - value (:obj:`nn.Linear`): Linear layer to project input to value vectors. - - attn_drop (:obj:`nn.Dropout`): Dropout layer for attention weights. 
- - resid_drop (:obj:`nn.Dropout`): Dropout layer for residual connection. - - proj (:obj:`nn.Linear`): Final linear layer for projection. - - mask (:obj:`torch.Tensor`): Mask tensor for causal or block-causal attention. - """ - def __init__(self, config: TransformerConfig) -> None: - super().__init__() - assert config.embed_dim % config.num_heads == 0, "Embedding dimension must be divisible by number of heads." - - self.config = config - - self.task_embed_option = self.config.task_embed_option - if self.task_embed_option == "register_task_embed": - self.use_register_token = True # TODO - # Register token setup - self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 - else: - self.use_register_token = False # TODO - - self.num_heads = config.num_heads - - self.key = nn.Linear(config.embed_dim, config.embed_dim) - self.query = nn.Linear(config.embed_dim, config.embed_dim) - self.value = nn.Linear(config.embed_dim, config.embed_dim) - - self.attn_drop = nn.Dropout(config.attn_pdrop) - self.resid_drop = nn.Dropout(config.resid_pdrop) - self.proj = nn.Linear(config.embed_dim, config.embed_dim) - - if self.use_register_token: # ======= TODO ======== - causal_mask = torch.tril(torch.ones(config.max_tokens+self.register_token_num*5, config.max_tokens+self.register_token_num*5)) - else: - causal_mask = torch.tril(torch.ones(config.max_tokens, config.max_tokens)) - - self.register_buffer('mask', causal_mask) - - #@profile - def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """ - Forward pass for the self-attention mechanism. - - Arguments: - - x (:obj:`torch.Tensor`): Input tensor of shape (B, T, C) where B is batch size, - T is sequence length, and C is embedding dimension. - - kv_cache (:obj:`Optional[KeysValues]`): Optional key-value cache for faster inference. - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Optional tensor containing valid context lengths. - - Returns: - - torch.Tensor: Output tensor of shape (B, T, C). - """ - B, T, C = x.size() - if kv_cache is not None: - b, nh, L, c = kv_cache.shape - try: - assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions do not match input dimensions." - except Exception as e: - print('debug') - else: - L = 0 - - q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) - k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) - v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) - - if kv_cache is not None: - # import ipdb; ipdb.set_trace() - kv_cache.update(k, v) # time occupancy 21% - k, v = kv_cache.get() # time occupancy 5% - - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - - if valid_context_lengths is not None: - # Final mask.shape: (B, T, L + T) - # L is the context length, T is the current input length, - # valid_context_lengths is the valid length at the end of the context. - mask = torch.zeros(B, T, L + T, device=att.device) - # For each sample, set the invalid parts to 0 based on its valid length. - for i in range(B): - mask[i] = self.mask[L:L + T, :L + T].clone() - mask[i, :, :(L - valid_context_lengths[i])] = 0 # Set invalid parts to 0. - # Adjust mask dimensions to match the last two dimensions of att. 
- # (B, T, L + T) -> (B, 1, T, L + T) -> (B, num_heads, T, L + T) - mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1) - else: - # mask.shape: (T, L + T) - mask = self.mask[L:L + T, :L + T] - - # import ipdb; ipdb.set_trace() - - # Adjust mask for register tokens if applicable - if self.use_register_token and self.register_token_num > 0: - # Allow all positions to attend to the last `register_token_num` tokens - register_mask = mask.clone() # (T, L + T) - register_mask[-self.register_token_num:, :] = 1 # Allow register tokens to see all positions - register_mask[:, -self.register_token_num:] = 1 # Allow all positions to see register tokens - mask = register_mask - - if kv_cache is not None: - # =============TODO============= - # import ipdb; ipdb.set_trace() - b, nh, new_L, c = kv_cache.shape # new_L可能小于L + T - mask = mask[:,-new_L:] - # else: - # import ipdb; ipdb.set_trace() - - - # att.shape: (B, num_heads, T, L + T) - att = att.masked_fill(mask == 0, float('-inf')) - - att = F.softmax(att, dim=-1) - att = self.attn_drop(att) - - # import ipdb; ipdb.set_trace() - y = att @ v # (B, num_heads, T, L + T) x (B, num_heads, L + T, head_size) -> (B, num_heads, T, head_size) - - y = rearrange(y, 'b h t e -> b t (h e)') # Combine the heads back together (B, T, embed_dim) - y = self.resid_drop(self.proj(y)) - - - - return y - - @torch.no_grad() - def get_attention_map(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - Compute the attention map for the input sequence. This is useful for visualization purposes. - More details can be found in visualizing_utils.py. - - Arguments: - - x (:obj:`torch.Tensor`): Input sequence with shape (B, T, C). - - kv_cache (:obj:`Optional[KeysValues]`): Cached keys and values for supporting long sequence inference. - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths for handling variable-length contexts. - - Returns: - - torch.Tensor: Attention map with shape (B, nh, T, L + T), representing the distribution of attention. - """ - B, T, C = x.size() - if kv_cache is not None: - b, nh, L, c = kv_cache.shape - assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions are inconsistent with input dimensions." 
- else: - L = 0 - - # Compute query, key, and value projections - q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) - k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) - v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) - - if kv_cache is not None: - # Update the kv_cache with the new keys and values - kv_cache.update(k, v) - k, v = kv_cache.get() - - # Compute the attention scores - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - - if valid_context_lengths is not None: - mask = torch.zeros(B, T, L + T, device=att.device) - for i in range(B): - # Create attention mask for each batch - mask[i] = self.mask[L:L + T, :L + T].clone() - mask[i, :, :(L - valid_context_lengths[i])] = 0 - mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1) - else: - mask = self.mask[L:L + T, :L + T] - - # Apply the attention mask - att = att.masked_fill(mask == 0, float('-inf')) - att = F.softmax(att, dim=-1) - - return att \ No newline at end of file diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index 0a0c9dd51..bde598061 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -179,17 +179,44 @@ def calculate_cuda_memory_gb(past_keys_values_cache, num_layers: int): total_memory_gb = total_memory_bytes / (1024 ** 3) return total_memory_gb -def hash_state(state): +# def hash_state(state): +# """ +# Hash the state vector. + +# Arguments: +# state: The state vector to be hashed. +# Returns: +# The hash value of the state vector. +# """ +# # Use xxhash for faster hashing +# return xxhash.xxh64(state).hexdigest() + +def hash_state(state: np.ndarray) -> int: """ - Hash the state vector. + Overview: + Computes a fast and robust hash for a NumPy array state. + + Why this is optimal: + 1. Algorithm (`xxhash.xxh64`): Uses one of the fastest non-cryptographic hash + functions available, ideal for performance-critical applications like caching. + 2. Input Preparation (`state.tobytes()`): Ensures correctness by creating a + canonical byte representation of the array. This guarantees that two + logically identical arrays will produce the same hash, regardless of their + internal memory layout (e.g., C-contiguous, F-contiguous, or strided views). + 3. Output Format (`.intdigest()`): Directly produces an integer hash value, + which is the most efficient key type for Python dictionaries, avoiding the + overhead of string keys. Arguments: - state: The state vector to be hashed. + - state (np.ndarray): The state array to be hashed. + Returns: - The hash value of the state vector. + - int: A 64-bit integer hash of the state. """ - # Use xxhash for faster hashing - return xxhash.xxh64(state).hexdigest() + # Ensure the array is contiguous in memory before converting to bytes, + # although .tobytes() handles this, being explicit can sometimes be clearer. + # For simplicity and since .tobytes() defaults to C-order, we can rely on it. + return xxhash.xxh64(state.tobytes()).intdigest() @dataclass class WorldModelOutput: @@ -201,28 +228,36 @@ class WorldModelOutput: logits_value: torch.FloatTensor -def init_weights(module, norm_type='BN'): +def init_weights(module, norm_type='BN',liner_weight_zero=False): """ Initialize the weights of the module based on the specified normalization type. - Arguments: module (nn.Module): The module to initialize. 
norm_type (str): The type of normalization to use ('BN' for BatchNorm, 'LN' for LayerNorm). """ - if isinstance(module, (nn.Linear, nn.Embedding)): + if isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=0.02) - if isinstance(module, nn.Linear) and module.bias is not None: + elif isinstance(module, nn.Linear): + # This branch is now reached correctly for nn.Linear modules + if norm_type == 'BN': + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + print("Init Linear using kaiming normal for BN") + elif norm_type == 'LN': + # For Transformer architectures, Xavier/Glorot initialization is more common + nn.init.xavier_uniform_(module.weight) + print("Init Linear using xavier uniform for LN") + + if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): print(f"Init {module} using zero bias, 1 weight") try: + module.weight.data.fill_(1.0) module.bias.data.zero_() except Exception as e: print(e) - try: - module.weight.data.fill_(1.0) - except Exception as e: - print(e) + elif isinstance(module, nn.BatchNorm2d): print(f"Init nn.BatchNorm2d using zero bias, 1 weight") module.weight.data.fill_(1.0) @@ -234,13 +269,47 @@ def init_weights(module, norm_type='BN'): elif norm_type == 'LN': nn.init.xavier_uniform_(module.weight) print(f"Init nn.Conv2d using xavier uniform for LN") - elif isinstance(module, nn.Linear): - if norm_type == 'BN': - nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') - print("Init Linear using kaiming normal for BN") - elif norm_type == 'LN': - nn.init.xavier_uniform_(module.weight) - print("Init Linear using xavier uniform for LN") + +# def init_weights(module, norm_type='BN'): +# """ +# Initialize the weights of the module based on the specified normalization type. + +# Arguments: +# module (nn.Module): The module to initialize. +# norm_type (str): The type of normalization to use ('BN' for BatchNorm, 'LN' for LayerNorm).
+# """ +# if isinstance(module, (nn.Linear, nn.Embedding)): +# module.weight.data.normal_(mean=0.0, std=0.02) +# if isinstance(module, nn.Linear) and module.bias is not None: +# module.bias.data.zero_() +# elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): +# print(f"Init {module} using zero bias, 1 weight") +# try: +# module.bias.data.zero_() +# except Exception as e: +# print(e) +# try: +# module.weight.data.fill_(1.0) +# except Exception as e: +# print(e) +# elif isinstance(module, nn.BatchNorm2d): +# print(f"Init nn.BatchNorm2d using zero bias, 1 weight") +# module.weight.data.fill_(1.0) +# module.bias.data.zero_() +# elif isinstance(module, nn.Conv2d): +# if norm_type == 'BN': +# nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') +# print(f"Init nn.Conv2d using kaiming normal for BN") +# elif norm_type == 'LN': +# nn.init.xavier_uniform_(module.weight) +# print(f"Init nn.Conv2d using xavier uniform for LN") +# elif isinstance(module, nn.Linear): +# if norm_type == 'BN': +# nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') +# print("Init Linear using kaiming normal for BN") +# elif norm_type == 'LN': +# nn.init.xavier_uniform_(module.weight) +# print("Init Linear using xavier uniform for LN") class LossWithIntermediateLosses: diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index 86583d198..eff859a4f 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -9,12 +9,28 @@ from torch.distributions import Categorical, Independent, Normal, TransformedDistribution, TanhTransform from lzero.model.common import SimNorm -from lzero.model.utils import cal_dormant_ratio, compute_average_weight_magnitude, cal_effective_rank +from lzero.model.utils import calculate_dormant_ratio, compute_average_weight_magnitude, compute_effective_rank from .kv_caching import KeysValues from .slicer import Head, PolicyHeadCont from .tokenizer import Tokenizer from .transformer import Transformer, TransformerConfig from .utils import LossWithIntermediateLosses, init_weights, WorldModelOutput, hash_state +from collections import OrderedDict +logging.getLogger().setLevel(logging.DEBUG) + +from collections import OrderedDict, defaultdict +import matplotlib.pyplot as plt +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +from sklearn.manifold import TSNE +import torch +import numpy as np +import matplotlib.pyplot as plt +from sklearn.manifold import TSNE +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +import os +import datetime +import torch +import torch.nn as nn logging.getLogger().setLevel(logging.DEBUG) @@ -45,6 +61,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.transformer = Transformer(self.config) self.task_num = 1 + self.env_num = self.config.env_num if self.config.device == 'cpu': self.device = torch.device('cpu') else: @@ -70,7 +87,10 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 - + if self.task_embed_option == "concat_task_embed": + self.obs_per_embdding_dim = self.config.embed_dim - self.task_embed_dim + else: + self.obs_per_embdding_dim = self.config.embed_dim self.continuous_action_space = self.config.continuous_action_space # Initialize action embedding table @@ -84,16 +104,13 @@ def 
__init__(self, config: TransformerConfig, tokenizer) -> None: self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device) logging.info(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") - self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'SimNorm') - # self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'LayerNorm') # TODO + self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'LayerNorm') # Head modules self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) - self.head_observations = self._create_head( - self.all_but_last_latent_state_pattern, - self.config.embed_dim, - self._get_final_norm(self.final_norm_option_in_obs_head) # 使用指定的归一化方法 - ) + self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, self.obs_per_embdding_dim, \ + self._get_final_norm(self.final_norm_option_in_obs_head) # NOTE: using the specified normalization method for observations head + ) if self.continuous_action_space: self.sigma_type = self.config.sigma_type self.bound_type = self.config.bound_type @@ -102,7 +119,6 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) - # 对于 head 部分,查找所有以 "head_" 开头的子模块 self.head_dict = {} for name, module in self.named_children(): if name.startswith("head_"): @@ -111,11 +127,32 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.head_dict = nn.ModuleDict(self.head_dict) # Apply weight initialization, the order is important - self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) + # self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) + + # Build the set of modules to skip during re-initialization. + # This is compatible with cases where self.tokenizer.encoder does not have 'pretrained_model', + # or self.tokenizer does not have 'decoder_network'. 
+ # NOTE: This step is crucial — without skipping, pretrained modules (e.g., encoder/decoder) would be unintentionally re-initialized + skip_modules = set() + if hasattr(self.tokenizer.encoder, 'pretrained_model'): + skip_modules.update(self.tokenizer.encoder.pretrained_model.modules()) + if hasattr(self.tokenizer, 'decoder_network') and self.tokenizer.decoder_network is not None: + skip_modules.update(self.tokenizer.decoder_network.modules()) + + def custom_init(module): + # If the current module is part of the skip list, return without reinitializing + if module in skip_modules: + return + # Otherwise, apply the specified initialization method + init_weights(module, norm_type=self.config.norm_type) + + # Recursively apply `custom_init` to all submodules of the model + self.apply(custom_init) + self._initialize_last_layer() - # Cache structures - self._initialize_cache_structures() + # # Cache structures + # self._initialize_cache_structures() # Projection input dimension self._initialize_projection_input_dim() @@ -129,18 +166,25 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.latent_recon_loss = torch.tensor(0., device=self.device) self.perceptual_loss = torch.tensor(0., device=self.device) + # Set to game_segment_length initially so that every entry in self.shared_pool_init_infer remains a valid kv cache + # TODO: very important, keep this consistent with segment_length + self.shared_pool_size_init = int(self.config.game_segment_length) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + # TODO: check the size of the shared pool # for self.kv_cache_recurrent_infer # If needed, recurrent_infer should store the results of the one MCTS search. self.num_simulations = getattr(self.config, 'num_simulations', 50) - self.shared_pool_size = int(self.num_simulations*self.env_num) - self.shared_pool_recur_infer = [None] * self.shared_pool_size + + + self.shared_pool_size_recur = int(self.num_simulations*self.env_num) + self.shared_pool_recur_infer = [None] * self.shared_pool_size_recur self.shared_pool_index = 0 + # Cache structures + self._initialize_cache_structures() + # for self.kv_cache_init_infer # In contrast, init_infer only needs to retain the results of the most recent step. - # self.shared_pool_size_init = int(2*self.env_num) - self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache?
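The shared pools above (`shared_pool_recur_infer`, `shared_pool_init_infer`) are fixed-size ring buffers advanced by a wrapping write index, and the eviction fix later in this diff pairs each pool with a reverse `pool_idx -> key` list so stale dictionary entries can be dropped when a slot is reused. A minimal sketch of that pattern with illustrative names (`RingBufferKVPool` is not part of the codebase):

```python
class RingBufferKVPool:
    """Fixed-size cache pool with stale-key eviction (illustrative sketch).

    A dict maps cache_key -> slot index into a preallocated pool, and a reverse
    list (slot index -> cache_key) lets us drop the dict entry whose slot is
    about to be overwritten, mirroring the update_cache_context fix below.
    """

    def __init__(self, pool_size: int):
        self.pool = [None] * pool_size          # preallocated slots (e.g. KeysValues copies)
        self.key_for_slot = [None] * pool_size  # reverse lookup: slot -> key
        self.index_for_key = {}                 # forward lookup: key -> slot
        self.write_index = 0

    def put(self, cache_key, value) -> int:
        slot = self.write_index
        old_key = self.key_for_slot[slot]
        if old_key is not None:                 # evict the key whose slot we reuse
            self.index_for_key.pop(old_key, None)
        self.pool[slot] = value
        self.key_for_slot[slot] = cache_key
        self.index_for_key[cache_key] = slot
        self.write_index = (self.write_index + 1) % len(self.pool)
        return slot

    def get(self, cache_key):
        slot = self.index_for_key.get(cache_key)  # may be None: treat as a cache miss
        return None if slot is None else self.pool[slot]


# Usage: with pool_size=2, writing a third entry reuses slot 0 and evicts 'k0'.
pool = RingBufferKVPool(pool_size=2)
pool.put('k0', 'kv0'); pool.put('k1', 'kv1'); pool.put('k2', 'kv2')
assert pool.get('k0') is None and pool.get('k2') == 'kv2'
```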
self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] @@ -151,10 +195,141 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.reanalyze_phase = False + def _initialize_cache_structures(self) -> None: + """Initialize cache structures for past keys and values.""" + from collections import defaultdict + + # self.past_kv_cache_recurrent_infer = defaultdict(dict) + # self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)] + + self.past_kv_cache_recurrent_infer = {} + self.pool_idx_to_key_map_recur_infer = [None] * self.shared_pool_size_recur + self.past_kv_cache_init_infer_envs = [{} for _ in range(self.env_num)] + # 辅助数据结构,用于反向查找:pool_index -> key + self.pool_idx_to_key_map_init_envs = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + + def _analyze_latent_representation( + self, + latent_states: torch.Tensor, + timesteps: torch.Tensor, + game_states: torch.Tensor, + predicted_values: torch.Tensor, + predicted_rewards: torch.Tensor, + step_counter: int + ): + """ + 分析并记录 latent states 的统计信息和t-SNE可视化。 + 【新功能】:在t-SNE图上显示对应的游戏图像,并标注预测的Value和Reward。 + 【已修改】:如果保存路径已存在同名文件,则在文件名后附加时间戳。 + + Args: + latent_states (torch.Tensor): Encoder的输出, shape (B*L, 1, E) + timesteps (torch.Tensor): 对应的时间步, shape (B, L) + game_states (torch.Tensor): 原始的游戏观测, shape (B, L, C, H, W) + predicted_values (torch.Tensor): 预测的标量Value, shape (B*L,) + predicted_rewards (torch.Tensor): 预测的标量Reward, shape (B*L,) + step_counter (int): 全局训练步数 + """ + # ... (统计分析部分保持不变) ... + # (确保 latent_states 和 game_states 的形状为 (N, ...)) + if latent_states.dim() > 2: + latent_states = latent_states.reshape(-1, latent_states.shape[-1]) + num_c, num_h, num_w = game_states.shape[-3:] + game_states = game_states.reshape(-1, num_c, num_h, num_w) + + with torch.no_grad(): + l2_norm = torch.norm(latent_states, p=2, dim=1).mean() + mean = latent_states.mean() + std = latent_states.std() + print(f"[Step {step_counter}] Latent Stats | L2 Norm: {l2_norm:.4f}, Mean: {mean:.4f}, Std: {std:.4f}") + + # 带图像和V/R值的 t-SNE 可视化 + if step_counter >= 0: + # if step_counter > 0 and step_counter % 200 == 0: + + print(f"[Step {step_counter}] Performing t-SNE analysis with images, values, and rewards...") + + # 将数据转换到CPU + latents_np = latent_states.detach().cpu().numpy() + images_np = game_states.detach().cpu().numpy() + values_np = predicted_values.detach().cpu().numpy() + rewards_np = predicted_rewards.detach().cpu().numpy() + + tsne = TSNE(n_components=2, perplexity=30, n_iter=300, random_state=42) + tsne_results = tsne.fit_transform(latents_np) + + # --- 绘制带图像和标注的散点图 --- + + # 减少图像数量以保持清晰 + num_points_to_plot = min(len(latents_np), 70) # 减少到70个点 + indices = np.random.choice(len(latents_np), num_points_to_plot, replace=False) + + fig, ax = plt.subplots(figsize=(20, 18)) # 增大画布尺寸 + + # 先画出所有点的散点图作为背景 + ax.scatter(tsne_results[:, 0], tsne_results[:, 1], c=values_np, cmap='viridis', alpha=0.3, s=10) + + for i in indices: + x, y = tsne_results[i] + img = images_np[i].transpose(1, 2, 0) + img = np.clip(img, 0, 1) + + # 放置图像 + im = OffsetImage(img, zoom=0.7) # 稍微放大图像 + ab = AnnotationBbox(im, (x, y), frameon=True, pad=0.0, bboxprops=dict(edgecolor='none')) + ax.add_artist(ab) + + # 在图像下方添加文字标注 + text_label = f"V:{values_np[i]:.1f} R:{rewards_np[i]:.1f}" + ax.text(x, y - 1.0, text_label, ha='center', va='top', fontsize=8, 
color='red', + bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.5)) + + ax.update_datalim(tsne_results) + ax.autoscale() + + ax.set_title(f't-SNE of Latent States (Value as Color) at Step {step_counter}', fontsize=16) + ax.set_xlabel('t-SNE dimension 1', fontsize=12) + ax.set_ylabel('t-SNE dimension 2', fontsize=12) + + # 添加colorbar来解释背景点的颜色 + norm = plt.Normalize(values_np.min(), values_np.max()) + sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm) + sm.set_array([]) + fig.colorbar(sm, ax=ax, label='Predicted Value') + + # --- 修改部分:检查文件是否存在,如果存在则添加时间戳 --- + # 1. 构建基础路径 + # base_save_path = ( + # f'/mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/unizero_mspacman_analyze/' + # f'tsne_with_vr_{self.config.optim_type}_lr{self.config.learning_rate}_step_{step_counter}.png' + # ) + base_save_path = ( + f'/mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/unizero_mspacman_analyze/' + f'tsne_with_vr_{self.config.optim_type}_step_{step_counter}.png' + ) + + # 2. 检查文件是否存在,并确定最终保存路径 + if os.path.exists(base_save_path): + # 如果文件已存在,则生成时间戳并附加到文件名 + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + path_root, path_ext = os.path.splitext(base_save_path) + save_path = f"{path_root}_{timestamp}{path_ext}" + print(f"File '{base_save_path}' already exists. Saving to new path with timestamp.") + else: + # 如果文件不存在,则使用原始路径 + save_path = base_save_path + + # 3. 保存图像 + plt.savefig(save_path) + plt.close(fig) # 明确关闭图形对象 + print(f"t-SNE plot with V/R annotations saved to {save_path}") def _get_final_norm(self, norm_option: str) -> nn.Module: """ - 根据指定的归一化选项返回相应的归一化模块。 + Return the corresponding normalization module based on the specified normalization option. """ if norm_option == 'LayerNorm': return nn.LayerNorm(self.config.embed_dim, eps=1e-5) @@ -268,7 +443,7 @@ def custom_copy_kv_cache_to_shared_recur(self, src_kv: KeysValues) -> int: dst_layer._v_cache._size = src_layer._v_cache._size index = self.shared_pool_index - self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size + self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size_recur return index @@ -307,7 +482,9 @@ def _initialize_patterns(self) -> None: def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: """Create head modules for the transformer.""" modules = [ + nn.LayerNorm(self.config.embed_dim), # <-- 核心优化! # TODO nn.Linear(self.config.embed_dim, self.config.embed_dim), + nn.LayerNorm(self.config.embed_dim), # 2. 
<-- 新增!稳定内部激活 nn.GELU(approximate='tanh'), nn.Linear(self.config.embed_dim, output_dim) ] @@ -354,14 +531,7 @@ def _initialize_last_layer(self) -> None: nn.init.zeros_(layer.bias) break - def _initialize_cache_structures(self) -> None: - """Initialize cache structures for past keys and values.""" - from collections import defaultdict - self.past_kv_cache_recurrent_infer = defaultdict(dict) - self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)] - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] def _initialize_projection_input_dim(self) -> None: """Initialize the projection input dimension based on the number of observation tokens.""" @@ -1231,14 +1401,54 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + # ORIGNAL + # if is_init_infer: + # # Store the latest key-value cache for initial inference + # cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + # self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + # else: + # # Store the latest key-value cache for recurrent inference + # cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + # self.past_kv_cache_recurrent_infer[cache_key] = cache_index + + if is_init_infer: - # Store the latest key-value cache for initial inference + # TODO + # ==================== 主动淘汰修复逻辑 ==================== + # 1. 获取即将被覆写的物理索引 + index_to_write = self.shared_pool_index_init_envs[i] + # 2. 使用辅助列表查找该索引上存储的旧的 key + old_key_to_evict = self.pool_idx_to_key_map_init_envs[i][index_to_write] + # 3. 如果存在旧 key,就从主 cache map 中删除它 + if old_key_to_evict is not None: + # 确保要删除的键确实存在,避免意外错误 + if old_key_to_evict in self.past_kv_cache_init_infer_envs[i]: + del self.past_kv_cache_init_infer_envs[i][old_key_to_evict] + + # 现在可以安全地写入新数据了 cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + + # 4. 在主 cache map 和辅助列表中同时更新新的映射关系 self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + self.pool_idx_to_key_map_init_envs[i][index_to_write] = cache_key else: - # Store the latest key-value cache for recurrent inference + # ==================== RECURRENT INFER FIX ==================== + # 1. 获取即将被覆写的物理索引 + index_to_write = self.shared_pool_index + # 2. 使用辅助列表查找该索引上存储的旧的 key + old_key_to_evict = self.pool_idx_to_key_map_recur_infer[index_to_write] + # 3. 如果存在旧 key,就从主 cache map 中删除它 + if old_key_to_evict is not None: + if old_key_to_evict in self.past_kv_cache_recurrent_infer: + del self.past_kv_cache_recurrent_infer[old_key_to_evict] + + # 4. 现在可以安全地写入新数据了 cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + + # 5. 
在主 cache map 和辅助列表中同时更新新的映射关系 self.past_kv_cache_recurrent_infer[cache_key] = cache_index + self.pool_idx_to_key_map_recur_infer[index_to_write] = cache_key + #@profile @@ -1275,8 +1485,20 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, matched_value = None # If not found, try to retrieve from past_kv_cache_recurrent_infer + # if matched_value is None: + # matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] + + # ==================== TODO ==================== + # 步骤 2: 仅当在 init_infer 中未找到时,才尝试从 recurrent_infer 缓存中查找 if matched_value is None: - matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] + # 2.1 安全地从字典中获取索引,它可能返回 None + recur_cache_index = self.past_kv_cache_recurrent_infer.get(cache_key) + # 2.2 只有在索引有效(不是 None)的情况下,才使用它来从物理池中检索值 + if recur_cache_index is not None: + matched_value = self.shared_pool_recur_infer[recur_cache_index] + + if recur_cache_index is None: + print(f"[CACHE MISS] Not found for key={cache_key} in recurrent infer. Generating new cache.") if matched_value is not None: # If a matching cache is found, add it to the lists @@ -1325,33 +1547,42 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # self.plot_latent_tsne_each_and_all(obs_embeddings, suffix='visual_match_memlen1-60-15_tsne') # self.save_as_image_with_timestep(batch['observations'], suffix='visual_match_memlen1-60-15_tsne') - # ========= logging for analysis ========= + # ======================== Logging for Analysis ======================== + # This block calculates various metrics for model analysis if the corresponding config flag is enabled. + # These metrics help in debugging and understanding model behavior during training. if self.analysis_dormant_ratio_weight_rank: - # Calculate dormant ratio of the encoder - shape = batch['observations'].shape # (..., C, H, W) - inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) - dormant_ratio_encoder_dict = cal_dormant_ratio(self.tokenizer.encoder, inputs.detach(), - dormant_threshold=self.dormant_threshold) - # print(dormant_ratio_encoder_dict) + # --- Dormant Ratio Calculation --- + # Calculate the dormant ratio of the encoder to monitor neuron activity. + shape = batch['observations'].shape # Original shape, e.g., (B, T, C, H, W) + # Reshape observations to create a single large batch for the encoder. + # E.g., (32, 5, 3, 64, 64) -> (160, 3, 64, 64) + inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) + + dormant_ratio_encoder_dict = calculate_dormant_ratio( + self.tokenizer.encoder, inputs.detach(), dormant_threshold=self.dormant_threshold + ) dormant_ratio_encoder = dormant_ratio_encoder_dict['global'] - # 计算全局平均权重绝对值 + # --- Average Weight Magnitude Calculation --- + # Calculate the global average absolute weight magnitude for different model components. + # This is a useful metric for monitoring training stability. 
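The analysis block above reports a dormant ratio for the encoder via `calculate_dormant_ratio`. As a rough, standalone illustration of the underlying idea (the fraction of units whose normalized mean |output| falls below a threshold), a sketch using forward hooks on `Linear`/`Conv2d` layers and an illustrative threshold of 0.025 is shown below; this is not the repository's implementation:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def toy_dormant_ratio(model: nn.Module, inputs: torch.Tensor, threshold: float = 0.025) -> float:
    """A unit counts as dormant when its mean |output|, normalized by the
    layer-wide mean, falls below `threshold` (illustrative sketch only)."""
    scores, hooks = [], []

    def hook(_module, _inp, out):
        act = out.abs()
        dims = [d for d in range(act.dim()) if d != 1]          # keep the channel/feature axis
        per_unit = act.mean(dim=dims)                           # (num_units,)
        scores.append(per_unit / (per_unit.mean() + 1e-9))      # normalized activation scores

    for m in model.modules():
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            hooks.append(m.register_forward_hook(hook))
    model(inputs)
    for h in hooks:
        h.remove()

    all_scores = torch.cat(scores)
    return (all_scores <= threshold).float().mean().item()

# Example on a toy encoder-like network (stand-in for the tokenizer encoder).
toy_encoder = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 14 * 14, 16), nn.ReLU()
)
print(toy_dormant_ratio(toy_encoder, torch.randn(4, 3, 16, 16)))
```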
avg_weight_mag_encoder = compute_average_weight_magnitude(self.tokenizer.encoder) - # print("Average Weight Magnitude of encoder:", avg_weight_mag_encoder) - # 计算全局平均权重绝对值 avg_weight_mag_transformer = compute_average_weight_magnitude(self.transformer) - # print("Average Weight Magnitude of transformer:", avg_weight_mag_transformer) - # print(f"self.head_dict:{self.head_dict}") avg_weight_mag_head = compute_average_weight_magnitude(self.head_dict) - # print("Average Weight Magnitude of head:", avg_weight_mag_head) - # 计算 effective rank,对于 representation 层,注意: - # representation 层在 model.named_modules() 的名称为 "representation" - # print(f"self.tokenizer.encoder:{self.tokenizer.encoder}") - e_rank_last_linear = cal_effective_rank(self.tokenizer.encoder, inputs, representation_layer_name="last_linear") - # print("Effective Rank of encoder_last_linear:", e_rank_last_linear) - e_rank_sim_norm = cal_effective_rank(self.tokenizer.encoder, inputs, representation_layer_name="sim_norm") - # print("Effective Rank of encoder_sim_norm:", e_rank_sim_norm) + # --- Effective Rank Calculation --- + # Calculate the effective rank of representations from specific layers in the encoder. + # This metric helps analyze the dimensionality and information content of the learned features. + # The 'representation_layer_name' argument specifies the target layer within the model's named modules. + + # Effective rank for the final linear layer of the encoder. + e_rank_last_linear = compute_effective_rank( + self.tokenizer.encoder, inputs, representation_layer_name="last_linear" + ) + # Effective rank for the SimNorm layer of the encoder. + e_rank_sim_norm = compute_effective_rank( + self.tokenizer.encoder, inputs, representation_layer_name="sim_norm" + ) self.past_kv_cache_recurrent_infer.clear() @@ -1368,6 +1599,65 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Calculate the L2 norm of the latent state roots latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() + # Action tokens + if self.continuous_action_space: + act_tokens = batch['actions'] + else: + act_tokens = rearrange(batch['actions'], 'b l -> b l 1') + + # Forward pass to obtain predictions for observations, rewards, and policies + outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, start_pos=start_pos) + + # [新增] 从模型输出中获取中间张量 x,并分离计算图 + intermediate_tensor_x = outputs.output_sequence.detach() + + global_step = kwargs.get('global_step', 0) + # if global_step >= 0 and global_step % 10000 == 0: # 20k + if global_step > 0 and global_step % 100000000000 == 0: # 20k # TODO + + with torch.no_grad(): + # 将logits转换为标量值 + # 注意:outputs的形状是(B, L, E),我们需要reshape + batch_size, seq_len = batch['actions'].shape[0], batch['actions'].shape[1] + + pred_val_logits = outputs.logits_value.view(batch_size * seq_len, -1) + pred_rew_logits = outputs.logits_rewards.view(batch_size * seq_len, -1) + + scalar_values = inverse_scalar_transform_handle(pred_val_logits).squeeze(-1) + scalar_rewards = inverse_scalar_transform_handle(pred_rew_logits).squeeze(-1) + + self._analyze_latent_representation( + latent_states=obs_embeddings, + timesteps=batch['timestep'], + game_states=batch['observations'], + predicted_values=scalar_values, # 传入预测的Value + predicted_rewards=scalar_rewards, # 传入预测的Reward + step_counter=global_step + ) + + if self.config.use_priority: + # ==================== START MODIFICATION 5 ==================== + # Calculate value_priority, similar to MuZero. + with torch.no_grad(): + # 1. 
Get the predicted value logits for the first step of the sequence (t=0). + # The shape is (B, support_size). + predicted_value_logits_step0 = outputs.logits_value[:, 0, :] + + # 2. Convert the categorical prediction to a scalar value. + # The shape becomes (B, 1). + predicted_scalar_value_step0 = inverse_scalar_transform_handle(predicted_value_logits_step0) + + # 3. Get the target scalar value for the first step from the batch. + # The shape is (B, num_unroll_steps), so we take the first column. + target_scalar_value_step0 = batch['scalar_target_value'][:, 0] + + # 4. Calculate the L1 loss (absolute difference) between prediction and target. + # This is the priority. We use reduction='none' to get per-sample priorities. + value_priority = F.l1_loss(predicted_scalar_value_step0.squeeze(-1), target_scalar_value_step0, reduction='none') + # ===================== END MODIFICATION 5 ===================== + else: + value_priority = torch.tensor(0.) + if self.obs_type == 'image': # Reconstruct observations from latent state representations # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) @@ -1404,14 +1694,29 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar elif self.obs_type == 'text': perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=torch.float32) + decode_loss_mode = self.config.decode_loss_mode - # Reconstruct observations from latent state representations - # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings.reshape(-1, self.embed_dim)) + # Reconstruction loss for predicting the next latent (via backbone) + # input -> encoder -> backbone(unizero) -> decoder -> latent_recon_loss + if decode_loss_mode == "after_backbone": + next_latent_state = outputs.logits_observations[:, :-1, :] + next_target_ids = batch['observations'][:, 1:, :] + + latent_recon_loss = self.tokenizer.decode_to_reconstruction_outputs( + embeddings=next_latent_state, + target_ids=next_target_ids, + ).loss + + #Reconstruction loss for predicting the current latent (without using the backbone) + # input -> encoder -> decoder -> latent_recon_loss + elif decode_loss_mode == "before_backbone": + latent_recon_loss = self.tokenizer.decode_to_reconstruction_outputs( + embeddings=obs_embeddings, + target_ids=batch['observations'], + ).loss - # # Calculate reconstruction loss - # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 25), - # reconstructed_images) - latent_recon_loss = self.latent_recon_loss + else: + latent_recon_loss = self.latent_recon_loss elif self.obs_type == 'image_memory': # Reconstruct observations from latent state representations @@ -1433,19 +1738,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar latent_recon_loss = self.latent_recon_loss perceptual_loss = self.perceptual_loss - # Action tokens - if self.continuous_action_space: - act_tokens = batch['actions'] - else: - act_tokens = rearrange(batch['actions'], 'b l -> b l 1') - - # Forward pass to obtain predictions for observations, rewards, and policies - outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, start_pos=start_pos) - # ========= logging for analysis ========= if self.analysis_dormant_ratio_weight_rank: # Calculate dormant ratio of the world model - dormant_ratio_world_model = cal_dormant_ratio(self, { + dormant_ratio_world_model = calculate_dormant_ratio(self, { 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), 
act_tokens.detach())}, dormant_threshold=self.dormant_threshold) dormant_ratio_transformer = dormant_ratio_world_model['transformer'] @@ -1543,9 +1839,6 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Compute discount coefficients for each timestep discounts = self.gamma ** timesteps - if batch['mask_padding'].sum() == 0: - assert False, "mask_padding is all zeros" - # Group losses into first step, middle step, and last step first_step_losses = {} middle_step_losses = {} @@ -1584,7 +1877,6 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Discount reconstruction loss and perceptual loss discounted_latent_recon_loss = latent_recon_loss discounted_perceptual_loss = perceptual_loss - # Calculate overall discounted loss discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).sum()/ batch['mask_padding'][:,1:].sum() discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() @@ -1593,6 +1885,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + # 为了让外部的训练循环能够获取encoder的输出,我们将其加入返回字典 + # 使用 .detach() 是因为这个张量仅用于后续的clip操作,不应影响梯度计算 + detached_obs_embeddings = obs_embeddings.detach() + if self.continuous_action_space: return LossWithIntermediateLosses( latent_recon_loss_weight=self.latent_recon_loss_weight, @@ -1621,6 +1917,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar policy_mu=mu, policy_sigma=sigma, target_sampled_actions=target_sampled_actions, + + value_priority=value_priority, + intermediate_tensor_x=intermediate_tensor_x, + obs_embeddings=detached_obs_embeddings, # <-- 新增 ) else: return LossWithIntermediateLosses( @@ -1647,8 +1947,13 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar e_rank_last_linear = e_rank_last_linear, e_rank_sim_norm = e_rank_sim_norm, latent_state_l2_norms=latent_state_l2_norms, + + value_priority=value_priority, + intermediate_tensor_x=intermediate_tensor_x, + obs_embeddings=detached_obs_embeddings, # <-- 新增 ) + # TODO: test correctness def _calculate_policy_loss_cont_simple(self, outputs, batch: dict): """ diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py index ccb47eb5a..47872da28 100644 --- a/lzero/model/unizero_world_models/world_model_multitask.py +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -1,143 +1,137 @@ import collections import logging -from typing import Any, Tuple -from typing import Optional -from typing import Union, Dict +import math +import os +from typing import Any, Dict, Optional, Tuple, Union +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +from ding.utils import get_rank from einops import rearrange +from matplotlib.offsetbox import AnnotationBbox, OffsetImage +from matplotlib.patches import Patch +from sklearn.manifold import TSNE from lzero.model.common import SimNorm from lzero.model.unizero_world_models.world_model import WorldModel -from lzero.model.utils import cal_dormant_ratio, 
compute_average_weight_magnitude, cal_effective_rank -from .moe import MoeLayer, MultiplicationFeedForward +from lzero.model.utils import ( + calculate_dormant_ratio, + calculate_effective_rank, + compute_average_weight_magnitude, +) + from .slicer import Head from .tokenizer import Tokenizer from .transformer import Transformer, TransformerConfig -from .utils import LossWithIntermediateLosses, init_weights -from .utils import WorldModelOutput, hash_state +from .utils import LossWithIntermediateLosses, WorldModelOutput, hash_state, init_weights +# Set the logging level for the root logger logging.getLogger().setLevel(logging.DEBUG) -from ding.utils import get_rank -import torch.distributed as dist -from sklearn.manifold import TSNE -import os -import numpy as np -import matplotlib.pyplot as plt -from matplotlib.patches import Patch -from matplotlib.offsetbox import OffsetImage, AnnotationBbox -import torch class WorldModelMT(WorldModel): """ Overview: - The WorldModel class is responsible for the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), - which is used to predict the next latent state, rewards, policy, and value based on the current latent state and action. - The world model consists of three main components: - - a tokenizer, which encodes observations into embeddings, - - a transformer, which processes the input sequences, - - and heads, which generate the logits for observations, rewards, policy, and value. + The WorldModel class for the multi-task UniZero model. It is responsible for + predicting the next latent state, reward, policy, and value based on the + current latent state and action. This model is a scalable latent world model + composed of three main parts: a tokenizer, a transformer, and prediction heads. """ - #@profile - def __init__(self, config: TransformerConfig, tokenizer) -> None: + def __init__(self, config: TransformerConfig, tokenizer: Tokenizer) -> None: """ Overview: - Initialize the WorldModel class. + Initializes the multi-task WorldModel. Arguments: - - config (:obj:`TransformerConfig`): The configuration for the transformer. - - tokenizer (:obj:`Tokenizer`): The tokenizer. - - - task_embed_option (str): Strategy for incorporating task embeddings. Options: - - "add_task_embed": Adds task embeddings to observation embeddings (default). - - "concat_task_embed": Concatenates task embeddings with observation embeddings. - - "register_task_embed": Uses task embeddings as additional input tokens. + - config (:obj:`TransformerConfig`): The configuration object for the transformer and world model. + - tokenizer (:obj:`Tokenizer`): The tokenizer for encoding observations. """ super().__init__(config, tokenizer) self.tokenizer = tokenizer self.config = config - self.share_head = config.share_head # 新增参数 - if self.config.device == 'cpu': - self.device = torch.device('cpu') - else: - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - # Move all modules to the specified device + self.continuous_action_space = self.config.continuous_action_space + self.task_num = config.task_num + self.env_num = self.config.env_num + + # TODO: Investigate sharing the encoder across all 26 games and scaling its gradient. + # if not self.continuous_action_space: + # # Share encoder for Atari games. + # encoder_index = 0 + # encoder = self.tokenizer.encoder[encoder_index] + # # Register a hook for all parameters of the encoder to scale gradients. 
+ # for p in encoder.parameters(): + # p.register_hook(self._scale_grad) + + # Whether to share prediction heads across tasks. + self.share_head = config.share_head + + self.device = torch.device('cuda' if torch.cuda.is_available() and self.config.device != 'cpu' else 'cpu') print(f"self.device: {self.device}") - # Position embedding + + # Positional embedding layer. self.pos_emb = nn.Embedding(config.max_tokens, self.config.embed_dim, device=self.device) print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") - if self.task_embed_option == "register_task_embed": - # 由于 "register_task_embed"设定下的位置编码没有矫正 - # 使用 nn.Embedding,但初始化为全零并禁止学习 - self.pos_emb = nn.Embedding(config.max_tokens, self.config.embed_dim, device=self.device) - nn.init.constant_(self.pos_emb.weight, 0.0) # 初始化全零 - self.pos_emb.weight.requires_grad = False # 禁止更新 - - # Task embedding setup + # Task embedding setup. self.use_task_embed = config.use_task_embed - self.task_embed_option = self.config.task_embed_option # Strategy for task embeddings - self.task_num = config.task_num + self.task_embed_option = self.config.task_embed_option self.task_embed_dim = config.task_embed_dim if hasattr(config, "task_embed_dim") else 96 self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 - + + if self.task_embed_option == "register_task_embed": + # When using "register_task_embed", the positional encoding is not adjusted. + # Use a non-trainable, zero-initialized nn.Embedding for positional embeddings. + self.pos_emb = nn.Embedding(config.max_tokens, self.config.embed_dim, device=self.device) + nn.init.constant_(self.pos_emb.weight, 0.0) # Initialize with all zeros. + self.pos_emb.weight.requires_grad = False # Disable updates. + + # Precompute positional embedding differences for efficient inference. self.precompute_pos_emb_diff_kv() - self.sim_norm = SimNorm(simnorm_dim=self.group_size) + self.sim_norm = SimNorm(simnorm_dim=self.config.group_size) + + # Configure embedding dimensions based on the task embedding strategy. if self.task_embed_option == "concat_task_embed": - # TODO:目前在 "concat_task_embed"下面,self.pos_emb需要设置为固定的0 - self.task_emb = nn.Embedding(self.task_num, self.task_embed_dim, max_norm=1) # TODO: TDMPC2:max_norm=1性能更好 - # self.task_emb.weight = self.sim_norm(self.task_emb.weight) + # TODO: Currently, with "concat_task_embed", self.pos_emb needs to be fixed at 0. + self.task_emb = nn.Embedding(self.task_num, self.task_embed_dim, max_norm=1) # TDMPC2 suggests max_norm=1. 
self.obs_act_embed_dim = config.embed_dim - self.task_embed_dim self.register_token_num = 0 elif self.task_embed_option == "register_task_embed": - self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1) # TODO + self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1) self.obs_act_embed_dim = config.embed_dim elif self.task_embed_option == "add_task_embed": - self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1) # TODO + self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1) self.obs_act_embed_dim = config.embed_dim else: self.task_emb = None self.obs_act_embed_dim = config.embed_dim self.register_token_num = 0 - self.transformer = Transformer(self.config, self.task_emb) - # TODO ======== + # --- Analysis and Logging Setup --- + self.analysis_dormant_ratio_interval = self.config.get('analysis_dormant_ratio_interval', 100) + self._analysis_step_counter = 0 + self.do_analysis = self.config.analysis_dormant_ratio_weight_rank + self.analysis_tsne = self.config.get('analysis_tsne', False) - if self.analysis_tsne: self.env_id_list = self.config.env_id_list - # 自动生成 self.env_short_names - self.env_short_names = {} - - # 遍历 env_id_list,提取短名称 - for env_id in self.config.env_id_list: - # 提取 'NoFrameskip-v4' 之前的部分作为短名称 - short_name = env_id.replace('NoFrameskip-v4', '') - self.env_short_names[env_id] = short_name - # 映射环境 ID 到简写名称 - # self.env_short_names = { - # 'PongNoFrameskip-v4': 'Pong', - # 'MsPacmanNoFrameskip-v4': 'MsPacman', - # 'SeaquestNoFrameskip-v4': 'Seaquest', - # 'BoxingNoFrameskip-v4': 'Boxing', - # 'AlienNoFrameskip-v4': 'Alien', - # 'ChopperCommandNoFrameskip-v4': 'Chopper', - # 'HeroNoFrameskip-v4': 'Hero', - # 'RoadRunnerNoFrameskip-v4': 'RoadRunner' - # } - # 颜色映射,确保每个任务有固定的颜色 + # Automatically generate short names for environments. + self.env_short_names = { + env_id: env_id.replace('NoFrameskip-v4', '') + for env_id in self.config.env_id_list + } + # Color mapping to ensure each task has a fixed color. self.num_tasks = len(self.env_id_list) - - # 生成足够多的颜色 - self.colors = self._generate_colors(len(self.env_id_list)) + self.colors = self._generate_colors(self.num_tasks) - + # --- Prediction Head Initialization --- self.head_policy_multi_task = nn.ModuleList() self.head_value_multi_task = nn.ModuleList() self.head_rewards_multi_task = nn.ModuleList() @@ -148,153 +142,113 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.use_moe_head = config.use_moe_head self.use_softmoe_head = config.use_softmoe_head - self.to(self.device) - # Initialize configuration parameters + # Initialize configuration parameters from the config object. self._initialize_config_parameters() - - # Initialize patterns for block masks self._initialize_patterns() self.hidden_size = config.embed_dim // config.num_heads - self.continuous_action_space = self.config.continuous_action_space - - # Initialize action embedding table + # Initialize action embedding table based on action space type. 
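Under `concat_task_embed`, the observation/action embedding width is shrunk to `embed_dim - task_embed_dim` so that concatenating the task embedding restores a token width of exactly `embed_dim`. A small sketch of that bookkeeping, assuming illustrative dimensions and a stand-in `nn.Linear` encoder:

```python
import torch
import torch.nn as nn

embed_dim, task_embed_dim, task_num = 768, 96, 8           # illustrative sizes
obs_act_embed_dim = embed_dim - task_embed_dim             # width left for the obs/act embedding

obs_encoder = nn.Linear(128, obs_act_embed_dim)            # stand-in for the tokenizer's encoder
task_emb = nn.Embedding(task_num, task_embed_dim, max_norm=1)

obs = torch.randn(4, 128)                                  # a batch of flattened observations
task_id = torch.full((4,), 3, dtype=torch.long)

# "concat_task_embed": append the task embedding so the token width stays embed_dim.
tokens = torch.cat([obs_encoder(obs), task_emb(task_id)], dim=-1)
print(tokens.shape)  # torch.Size([4, 768])
```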
if self.continuous_action_space: - # TODO: check the effect of SimNorm - # self.act_embedding_table = nn.Sequential( - # nn.Linear(config.action_space_size, config.embed_dim, device=self.device, bias=False), - # SimNorm(simnorm_dim=self.group_size)) - # print(f'config.action_space_size_list:{config.action_space_size_list}') self.act_embedding_table = nn.ModuleList([ nn.Sequential( nn.Linear(config.action_space_size_list[task_id], self.obs_act_embed_dim, device=self.device, bias=False), SimNorm(simnorm_dim=self.group_size) - ) - for task_id in range(self.task_num) + ) for task_id in range(self.task_num) ]) else: - # for discrete action space + # For discrete action space. self.act_embedding_table = nn.Embedding(config.action_space_size, self.obs_act_embed_dim, device=self.device) print(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") + print(f'=' * 20) + print(f"self.obs_act_embed_dim: {self.obs_act_embed_dim}") + print(f'=' * 20) - print(f'='*20) - print(f"self.obs_act_embed_dim:{self.obs_act_embed_dim}") - print(f'='*20) - - - # if self.num_experts_in_moe_head == -1: assert self.num_experts_in_moe_head > 0 if self.use_normal_head: - self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'SimNorm') - # self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'LayerNorm') # TODO - + self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'LayerNorm') print('We use normal head') - # TODO: Normal Head for task_id in range(self.task_num): if self.continuous_action_space: - # TODO self.sigma_type = self.config.sigma_type self.bound_type = self.config.bound_type - self.head_policy = self._create_head_cont(self.value_policy_tokens_pattern, self.config.action_space_size_list[task_id]) # TODO + head_policy = self._create_head_cont(self.value_policy_tokens_pattern, self.config.action_space_size_list[task_id]) else: - self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) + head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) if not self.share_head or task_id == 0: - self.head_policy_multi_task.append(self.head_policy) + self.head_policy_multi_task.append(head_policy) - self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) if not self.share_head or task_id == 0: - self.head_value_multi_task.append(self.head_value) + self.head_value_multi_task.append(head_value) - self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) + head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) if not self.share_head or task_id == 0: - self.head_rewards_multi_task.append(self.head_rewards) + self.head_rewards_multi_task.append(head_rewards) - self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, - self.config.embed_dim, - # self.sim_norm - self._get_final_norm(self.final_norm_option_in_obs_head) # 使用指定的归一化方法 - ) # NOTE: we add a sim_norm to the head for observations + head_observations = self._create_head( + self.all_but_last_latent_state_pattern, + self.config.embed_dim, + self._get_final_norm(self.final_norm_option_in_obs_head) # Use the specified normalization method. 
+ ) if not self.share_head or task_id == 0: - self.head_observations_multi_task.append(self.head_observations) + self.head_observations_multi_task.append(head_observations) + elif self.use_softmoe_head: print(f'We use softmoe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') - # Dictionary to store SoftMoE instances self.soft_moe_instances = {} - - # Create softmoe head modules self.create_head_modules_softmoe() - self.head_policy_multi_task.append(self.head_policy) self.head_value_multi_task.append(self.head_value) self.head_rewards_multi_task.append(self.head_rewards) self.head_observations_multi_task.append(self.head_observations) elif self.use_moe_head: print(f'We use moe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') - # Dictionary to store moe instances self.moe_instances = {} - - # Create moe head modules self.create_head_modules_moe() - self.head_policy_multi_task.append(self.head_policy) self.head_value_multi_task.append(self.head_value) self.head_rewards_multi_task.append(self.head_rewards) self.head_observations_multi_task.append(self.head_observations) - # 对于 head 部分,查找所有以 "head_" 开头的子模块 - # self.head_dict = {} - # for name, module in self.named_children(): - # # TODO: check - # if name.startswith("head_") and name.endswith("_multi_task") : - # self.head_dict[name] = module - # if self.head_dict: - # self.head_dict = nn.ModuleDict(self.head_dict) - + # Group all head modules into a ModuleDict for easier management. self.head_dict = nn.ModuleDict({ name: module for name, module in self.named_children() if name.startswith("head_") and name.endswith("_multi_task") }) - print("="*20) + print("=" * 20) print(f"self.head_dict:{self.head_dict}") - # Apply weight initialization, the order is important + # Apply weight initialization. The order of initialization is important. self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) - self._initialize_last_layer() + self._initialize_last_layer_mt() - # Cache structures + # --- Cache and State Initialization --- self._initialize_cache_structures() - - # Projection input dimension self._initialize_projection_input_dim() - - # Hit count and query count statistics self._initialize_statistics() - - # Initialize keys and values for transformer self._initialize_transformer_keys_values() - + self.latent_recon_loss = torch.tensor(0., device=self.device) self.perceptual_loss = torch.tensor(0., device=self.device) - # TODO: check the size of the shared pool - # for self.kv_cache_recurrent_infer - # If needed, recurrent_infer should store the results of the one MCTS search. - self.shared_pool_size = int(50*self.env_num) - self.shared_pool_recur_infer = [None] * self.shared_pool_size + # Set to game_segment_length for now, so that every entry in self.shared_pool_init_infer holds a valid kv cache. + # TODO: Very important: this should be kept consistent with the segment_length. + self.shared_pool_size_init = int(self.config.game_segment_length) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + + self.shared_pool_size_recur = int(self.num_simulations*self.env_num) + self.shared_pool_recur_infer = [None] * self.shared_pool_size_recur self.shared_pool_index = 0 - # for self.kv_cache_init_infer - # In contrast, init_infer only needs to retain the results of the most recent step. - # self.shared_pool_size_init = int(2*self.env_num) - self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + # For init_infer, it only needs to retain the results of the most recent step.
+ # NOTE: A large pool size might cause incorrect retrieval of the kv cache. self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] - # for self.kv_cache_wm + # For wm (world model) forward passes during training. self.shared_pool_size_wm = int(self.env_num) self.shared_pool_wm = [None] * self.shared_pool_size_wm self.shared_pool_index_wm = 0 @@ -302,17 +256,30 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.reanalyze_phase = False self._rank = get_rank() - def _generate_colors(self, num_colors): + def _scale_grad(self, grad: torch.Tensor) -> torch.Tensor: """ - 生成足够多的独特颜色,适用于大量分类。 - - 参数: - - num_colors: 所需颜色数量。 + Overview: + Scales the gradient. This hook is registered to encoder parameters + to stabilize multi-task training. + Arguments: + - grad (:obj:`torch.Tensor`): The original gradient. + Returns: + - (:obj:`torch.Tensor`): The scaled gradient. + """ + # Scale by 1/sqrt(k) for a conservative approach, where k is the number of tasks. + return grad / math.sqrt(self.task_num) - 返回: - - colors: 颜色列表。 + def _generate_colors(self, num_colors: int) -> list: + """ + Overview: + Generates a list of unique colors for visualization purposes, + suitable for a large number of categories. + Arguments: + - num_colors (:obj:`int`): The desired number of unique colors. + Returns: + - (:obj:`list`): A list of colors. """ - # 使用多个matplotlib离散色图拼接 + # Concatenate multiple discrete colormaps from matplotlib to get more colors. color_maps = ['tab20', 'tab20b', 'tab20c'] colors = [] for cmap_name in color_maps: @@ -320,14 +287,14 @@ def _generate_colors(self, num_colors): colors.extend([cmap(i) for i in range(cmap.N)]) if len(colors) >= num_colors: break + # Generate additional colors if needed. 
if len(colors) < num_colors: - # 生成额外的颜色,如果需要 additional_colors = plt.cm.get_cmap('hsv', num_colors - len(colors)) colors.extend([additional_colors(i) for i in range(num_colors - len(colors))]) return colors[:num_colors] def _initialize_config_parameters(self) -> None: - """Initialize configuration parameters.""" + """Initializes model attributes from the configuration object.""" self.policy_entropy_weight = self.config.policy_entropy_weight self.predict_latent_loss_type = self.config.predict_latent_loss_type self.group_size = self.config.group_size @@ -342,16 +309,13 @@ def _initialize_config_parameters(self) -> None: self.num_observations_tokens = self.config.tokens_per_block - 1 self.latent_recon_loss_weight = self.config.latent_recon_loss_weight self.perceptual_loss_weight = self.config.perceptual_loss_weight - self.device = self.config.device self.support_size = self.config.support_size self.action_space_size = self.config.action_space_size self.max_cache_size = self.config.max_cache_size - self.env_num = self.config.env_num self.num_layers = self.config.num_layers - self.sim_norm = SimNorm(simnorm_dim=self.group_size) def _initialize_patterns(self) -> None: - """Initialize patterns for block masks.""" + """Initializes patterns (masks) for selecting specific tokens for prediction heads.""" self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block) self.all_but_last_latent_state_pattern[-2] = 0 self.act_tokens_pattern = torch.zeros(self.config.tokens_per_block) @@ -360,9 +324,7 @@ def _initialize_patterns(self) -> None: self.value_policy_tokens_pattern[-2] = 1 def _get_final_norm(self, norm_option: str) -> nn.Module: - """ - 根据指定的归一化选项返回相应的归一化模块。 - """ + """Returns the specified normalization module.""" if norm_option == 'LayerNorm': return nn.LayerNorm(self.config.embed_dim, eps=1e-5) elif norm_option == 'SimNorm': @@ -370,10 +332,12 @@ def _get_final_norm(self, norm_option: str) -> nn.Module: else: raise ValueError(f"Unsupported final_norm_option_in_obs_head: {norm_option}") - def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: - """Create head modules for the transformer.""" + def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer: Optional[nn.Module] = None) -> Head: + """Creates a standard prediction head.""" modules = [ + nn.LayerNorm(self.config.embed_dim), # <-- Core optimization! # TODO nn.Linear(self.config.embed_dim, self.config.embed_dim), + nn.LayerNorm(self.config.embed_dim), # 2. <-- Newly added! Stabilizes internal activations nn.GELU(approximate='tanh'), nn.Linear(self.config.embed_dim, output_dim) ] @@ -385,9 +349,10 @@ def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=Non head_module=nn.Sequential(*modules) ) - def _create_head_moe(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, moe=None) -> Head: - """Create moe head modules for the transformer.""" + def _create_head_moe(self, block_mask: torch.Tensor, output_dim: int, norm_layer: Optional[nn.Module] = None, moe: Optional[nn.Module] = None) -> Head: + """Creates a prediction head with a Mixture-of-Experts (MoE) layer.""" modules = [ + nn.LayerNorm(self.config.embed_dim), # <-- Core optimization!
# TODO moe, nn.Linear(self.config.embed_dim, output_dim) ] @@ -398,55 +363,32 @@ def _create_head_moe(self, block_mask: torch.Tensor, output_dim: int, norm_layer block_mask=block_mask, head_module=nn.Sequential(*modules) ) - def get_moe(self, name): - """Get or create a MoE instance""" + + def get_moe(self, name: str) -> nn.Module: + """Gets or creates a MoE instance by name.""" + from .moe import MoELayer, MultiplicationFeedForward + if name not in self.moe_instances: - # Create multiple FeedForward instances for multiplication-based MoE - self.experts = nn.ModuleList([ + # Create multiple FeedForward instances for multiplication-based MoE. + experts = nn.ModuleList([ MultiplicationFeedForward(self.config) for _ in range(self.config.num_experts_of_moe_in_transformer) ]) - - self.moe_instances[name] = MoeLayer( - experts=self.experts, + self.moe_instances[name] = MoELayer( + experts=experts, gate=nn.Linear(self.config.embed_dim, self.config.num_experts_of_moe_in_transformer, bias=False), num_experts_per_tok=1, ) - return self.moe_instances[name] - def create_head_modules_moe(self): - """Create all softmoe head modules""" - # Rewards head - self.head_rewards = self._create_head_moe( - self.act_tokens_pattern, - self.support_size, - moe=self.get_moe("rewards_moe") - ) + def create_head_modules_moe(self) -> None: + """Creates all MoE prediction head modules.""" + self.head_rewards = self._create_head_moe(self.act_tokens_pattern, self.support_size, moe=self.get_moe("rewards_moe")) + self.head_observations = self._create_head_moe(self.all_but_last_latent_state_pattern, self.embed_dim, norm_layer=self.sim_norm, moe=self.get_moe("observations_moe")) + self.head_policy = self._create_head_moe(self.value_policy_tokens_pattern, self.action_space_size, moe=self.get_moe("policy_moe")) + self.head_value = self._create_head_moe(self.value_policy_tokens_pattern, self.support_size, moe=self.get_moe("value_moe")) - # Observations head - self.head_observations = self._create_head_moe( - self.all_but_last_latent_state_pattern, - self.embdding_dim, - norm_layer=self.sim_norm, # NOTE - moe=self.get_moe("observations_moe") - ) - - # Policy head - self.head_policy = self._create_head_moe( - self.value_policy_tokens_pattern, - self.action_space_size, - moe=self.get_moe("policy_moe") - ) - - # Value head - self.head_value = self._create_head_moe( - self.value_policy_tokens_pattern, - self.support_size, - moe=self.get_moe("value_moe") - ) - - def _create_head_softmoe(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, soft_moe=None) -> Head: - """Create softmoe head modules for the transformer.""" + def _create_head_softmoe(self, block_mask: torch.Tensor, output_dim: int, norm_layer: Optional[nn.Module] = None, soft_moe: Optional[nn.Module] = None) -> Head: + """Creates a prediction head with a Soft-MoE layer.""" modules = [ soft_moe, nn.Linear(self.config.embed_dim, output_dim) @@ -458,113 +400,72 @@ def _create_head_softmoe(self, block_mask: torch.Tensor, output_dim: int, norm_l block_mask=block_mask, head_module=nn.Sequential(*modules) ) - - def get_soft_moe(self, name): - """Get or create a SoftMoE instance""" - # from soft_moe_pytorch import SoftMoE - # if name not in self.soft_moe_instances: - # self.soft_moe_instances[name] = SoftMoE( - # dim=self.embed_dim, - # seq_len=20, # TODO - # num_experts=self.num_experts_in_moe_head, - # ) + + def get_soft_moe(self, name: str) -> nn.Module: + """Gets or creates a Soft-MoE instance by name.""" from soft_moe_pytorch import DynamicSlotsSoftMoE as SoftMoE 
if name not in self.soft_moe_instances: self.soft_moe_instances[name] = SoftMoE( dim=self.embed_dim, num_experts=self.num_experts_in_moe_head, - geglu = True + geglu=True ) return self.soft_moe_instances[name] - def create_head_modules_softmoe(self): - """Create all softmoe head modules""" - # Rewards head - self.head_rewards = self._create_head_softmoe( - self.act_tokens_pattern, - self.support_size, - soft_moe=self.get_soft_moe("rewards_soft_moe") - ) + def create_head_modules_softmoe(self) -> None: + """Creates all Soft-MoE prediction head modules.""" + self.head_rewards = self._create_head_softmoe(self.act_tokens_pattern, self.support_size, soft_moe=self.get_soft_moe("rewards_soft_moe")) + self.head_observations = self._create_head_softmoe(self.all_but_last_latent_state_pattern, self.config.embed_dim, norm_layer=self.sim_norm, soft_moe=self.get_soft_moe("observations_soft_moe")) + self.head_policy = self._create_head_softmoe(self.value_policy_tokens_pattern, self.action_space_size, soft_moe=self.get_soft_moe("policy_soft_moe")) + self.head_value = self._create_head_softmoe(self.value_policy_tokens_pattern, self.support_size, soft_moe=self.get_soft_moe("value_soft_moe")) - # Observations head - self.head_observations = self._create_head_softmoe( - self.all_but_last_latent_state_pattern, - self.config.embed_dim, - norm_layer=self.sim_norm, # NOTE - soft_moe=self.get_soft_moe("observations_soft_moe") - ) - - # Policy head - self.head_policy = self._create_head_softmoe( - self.value_policy_tokens_pattern, - self.action_space_size, - soft_moe=self.get_soft_moe("policy_soft_moe") - ) - - # Value head - self.head_value = self._create_head_softmoe( - self.value_policy_tokens_pattern, - self.support_size, - soft_moe=self.get_soft_moe("value_soft_moe") - ) - - def _initialize_last_layer(self) -> None: - """Initialize the last linear layer.""" + def _initialize_last_layer_mt(self) -> None: + """Initializes the last linear layer of prediction heads to zero for training stability.""" last_linear_layer_init_zero = True print(f'world_model_mt.py:self.task_num:{self.task_num}') if last_linear_layer_init_zero: if self.continuous_action_space: - module_to_initialize = [self.head_value, self.head_rewards, self.head_observations] + # For continuous actions, policy head might have a different initialization strategy. 
+ module_to_initialize = self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task else: - module_to_initialize = [self.head_policy, self.head_value, self.head_rewards, self.head_observations] - - # TODO: multitask - if self.task_num == 1: - for head in module_to_initialize: - for layer in reversed(head.head_module): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - if layer.bias is not None: - nn.init.zeros_(layer.bias) - break - elif self.task_num > 1: - if self.continuous_action_space: - module_to_initialize = self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task - else: - module_to_initialize = self.head_policy_multi_task + self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task + module_to_initialize = self.head_policy_multi_task + self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task - for head in module_to_initialize: - for layer in reversed(head.head_module): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - if layer.bias is not None: - nn.init.zeros_(layer.bias) - break + for head in module_to_initialize: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break def _initialize_cache_structures(self) -> None: - """Initialize cache structures for past keys and values.""" - self.past_kv_cache_recurrent_infer = collections.OrderedDict() - self.past_kv_cache_init_infer = collections.OrderedDict() - self.past_kv_cache_init_infer_envs = [collections.OrderedDict() for _ in range(self.env_num)] + """Initializes cache structures for storing past keys and values during inference.""" + # self.past_kv_cache_recurrent_infer = collections.OrderedDict() + # self.past_kv_cache_init_infer_envs = [collections.OrderedDict() for _ in range(self.env_num)] + + self.past_kv_cache_recurrent_infer = {} + self.pool_idx_to_key_map_recur_infer = [None] * self.shared_pool_size_recur + self.past_kv_cache_init_infer_envs = [{} for _ in range(self.env_num)] + # Auxiliary data structure for reverse lookup: pool_index -> key + self.pool_idx_to_key_map_init_envs = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + self.keys_values_wm_list = [] self.keys_values_wm_size_list = [] def _initialize_projection_input_dim(self) -> None: - """Initialize the projection input dimension based on the number of observation tokens.""" + """Initializes the input dimension for the projection based on observation tokenization.""" if self.num_observations_tokens == 16: self.projection_input_dim = 128 elif self.num_observations_tokens == 1: - if self.task_embed_option == "concat_task_embed": - self.projection_input_dim = self.config.embed_dim - self.task_embed_dim - elif self.task_embed_option == "register_task_embed": - self.projection_input_dim = self.config.embed_dim - elif self.task_embed_option == "add_task_embed": + if self.task_embed_option in ["concat_task_embed", "register_task_embed", "add_task_embed"]: self.projection_input_dim = self.config.embed_dim + if self.task_embed_option == "concat_task_embed": + self.projection_input_dim -= self.task_embed_dim else: self.projection_input_dim = self.config.embed_dim def _initialize_statistics(self) -> None: - """Initialize counters for hit count and query count statistics.""" + """Initializes counters for cache hit rates and other statistics.""" self.hit_count = 0 self.total_query_count = 0
self.length_largethan_maxminus5_context_cnt = 0 @@ -572,32 +473,26 @@ def _initialize_statistics(self) -> None: self.root_hit_cnt = 0 self.root_total_query_cnt = 0 - #@profile def _initialize_transformer_keys_values(self) -> None: - """Initialize keys and values for the transformer.""" - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, - max_tokens=self.context_length) - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, - max_tokens=self.context_length) + """Initializes empty key-value cache structures for the transformer.""" + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length) + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.context_length) - #@profile - def precompute_pos_emb_diff_kv(self): - """ Precompute positional embedding differences for key and value. """ + def precompute_pos_emb_diff_kv(self) -> None: + """ + Overview: + Precomputes positional embedding differences for keys and values. This is an + optimization to speed up KV cache updates during recurrent inference by avoiding + re-computation of positional embeddings. + """ if self.context_length <= 2: - # If context length is 2 or less, no context is present - return + return # No context to precompute for. - # Precompute positional embedding matrices for inference in collect/eval stages, not for training - self.positional_embedding_k = [ - self._get_positional_embedding(layer, 'key') - for layer in range(self.config.num_layers) - ] - self.positional_embedding_v = [ - self._get_positional_embedding(layer, 'value') - for layer in range(self.config.num_layers) - ] + # Precompute positional embedding matrices for all layers. + self.positional_embedding_k = [self._get_positional_embedding(layer, 'key') for layer in range(self.config.num_layers)] + self.positional_embedding_v = [self._get_positional_embedding(layer, 'value') for layer in range(self.config.num_layers)] - # Precompute all possible positional embedding differences + # Precompute all possible positional embedding differences. self.pos_emb_diff_k = [] self.pos_emb_diff_v = [] @@ -605,9 +500,10 @@ def precompute_pos_emb_diff_kv(self): layer_pos_emb_diff_k = {} layer_pos_emb_diff_v = {} + # This is for the case when context window is full and we shift it. + # TODO: Generalize for different start/end points if necessary. for start in [2]: - for end in [self.context_length - 1]: # TODO - # for end in [self.context_length - self.register_token_num - 1]: + for end in [self.context_length - 1]: original_pos_emb_k = self.positional_embedding_k[layer][:, :, start:end, :] new_pos_emb_k = self.positional_embedding_k[layer][:, :, :end - start, :] layer_pos_emb_diff_k[(start, end)] = new_pos_emb_k - original_pos_emb_k @@ -619,107 +515,85 @@ def precompute_pos_emb_diff_kv(self): self.pos_emb_diff_k.append(layer_pos_emb_diff_k) self.pos_emb_diff_v.append(layer_pos_emb_diff_v) - #@profile - def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: + def _get_positional_embedding(self, layer: int, attn_type: str) -> torch.Tensor: """ - Helper function to get positional embedding for a given layer and attention type. - - Arguments: - - layer (:obj:`int`): Layer index. - - attn_type (:obj:`str`): Attention type, either 'key' or 'value'. - - Returns: - - torch.Tensor: The positional embedding tensor. 
- """ - # TODO: detach() ========== + Overview: + Helper function to get positional embedding for a given layer and attention type. + Arguments: + - layer (:obj:`int`): The layer index. + - attn_type (:obj:`str`): The attention type, either 'key' or 'value'. + Returns: + - (:obj:`torch.Tensor`): The positional embedding tensor, detached from the graph. + """ + # TODO: Review the use of detach(). It's used here to prevent gradients from flowing back + # through the positional embeddings during this pre-computation phase. attn_func = getattr(self.transformer.blocks[layer].attn, attn_type) - if torch.cuda.is_available(): - return attn_func(self.pos_emb.weight).view( - 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads - ).transpose(1, 2).to(self.device).detach() - else: - return attn_func(self.pos_emb.weight).view( - 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads - ).transpose(1, 2).detach() - - #@profile - def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]], - past_keys_values: Optional[torch.Tensor] = None, - kvcache_independent: bool = False, is_init_infer: bool = True, - valid_context_lengths: Optional[torch.Tensor] = None, task_id=0) -> WorldModelOutput: + pos_emb = attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2) + return pos_emb.to(self.device).detach() + + def forward( + self, + obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]], + past_keys_values: Optional[torch.Tensor] = None, + kvcache_independent: bool = False, + is_init_infer: bool = True, + valid_context_lengths: Optional[torch.Tensor] = None, + task_id: int = 0 + ) -> WorldModelOutput: """ - Forward pass for the model. - + Overview: + Main forward pass for the world model. It processes either observation embeddings, + action tokens, or a combination of both, and passes them through the transformer + to generate predictions. Arguments: - - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing observation embeddings or action tokens. - - past_keys_values (:obj:`Optional[torch.Tensor]`): Previous keys and values for transformer. - - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. - - is_init_infer (:obj:`bool`): Initialize inference. - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths. + - obs_embeddings_or_act_tokens (:obj:`Dict`): A dictionary containing input tensors. + Can be 'obs_embeddings', 'act_tokens', or 'obs_embeddings_and_act_tokens'. + - past_keys_values (:obj:`Optional[torch.Tensor]`): The KV cache from previous steps. + - kvcache_independent (:obj:`bool`): Whether to use independent KV caching per item in the batch. + - is_init_infer (:obj:`bool`): Flag indicating if this is an initial inference step. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Tensor of valid context lengths for each item. + - task_id (:obj:`int`): The ID of the current task. Returns: - - WorldModelOutput: Model output containing logits for observations, rewards, policy, and value. + - (:obj:`WorldModelOutput`): An object containing the transformer output and logits for + observations, rewards, policy, and value. 
""" if self.use_task_embed: - self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO - self.task_embeddings = self.sim_norm(self.task_embeddings.view(1,-1)).view(-1) # TODO + self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) + self.task_embeddings = self.sim_norm(self.task_embeddings.view(1, -1)).view(-1) else: - self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) # ============= TODO: no task_embeddings now ============= + # Use a zero tensor if task embeddings are disabled. + self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) - # Determine previous steps based on key-value caching method + prev_steps = 0 if past_keys_values is None else past_keys_values.size if kvcache_independent: - prev_steps = torch.tensor([0 if past_keys_values is None else past_kv.size for past_kv in past_keys_values], - device=self.device) - else: - prev_steps = 0 if past_keys_values is None else past_keys_values.size + prev_steps = torch.tensor([0 if past_keys_values is None else past_kv.size for past_kv in past_keys_values], device=self.device) - # Reset valid_context_lengths during initial inference if is_init_infer: valid_context_lengths = None - # inference阶段: collect或者eval Process observation embeddings + # --- Branch 1: Inference Phase (Collect/Eval) - Process observation embeddings --- if 'obs_embeddings' in obs_embeddings_or_act_tokens: obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] if len(obs_embeddings.shape) == 2: obs_embeddings = obs_embeddings.unsqueeze(1) - - # TODO: multitask + + # Apply task embeddings based on the chosen strategy. if self.task_embed_option == "add_task_embed": obs_embeddings = obs_embeddings + self.task_embeddings elif self.task_embed_option == "concat_task_embed": - - # print(f'=='*20) - # print(f"is_init_infer:{is_init_infer}") - # print(f'obs_embeddings.shape:{obs_embeddings.shape}') - # print(f'self.task_embeddings.shape:{self.task_embeddings.shape}') - # print(f'=='*20) - - # if is_init_infer: - # # 注意只有在inference时,只有在is_init_infer时拼接task embeddings,recurr_infer中已经在init_infer中增加了task embeddings的信息了 - # # Expand task embeddings to match the sequence shape - # task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(obs_embeddings.shape[0], obs_embeddings.shape[1], -1) - # obs_embeddings = torch.cat([obs_embeddings, task_emb_expanded], dim=-1) - if is_init_infer and not self.reanalyze_phase: - # 注意只有在inference时,只有在is_init_infer时拼接task embeddings,recurr_infer中已经在init_infer中增加了task embeddings的信息了 - # Expand task embeddings to match the sequence shape + # Concatenate task embeddings only during initial inference. 
task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(obs_embeddings.shape[0], obs_embeddings.shape[1], -1) obs_embeddings = torch.cat([obs_embeddings, task_emb_expanded], dim=-1) - # if is_init_infer: - # if self.task_embed_option == "register_task_embed": - # # Register task embeddings as input tokens - # task_tokens = self.task_embeddings.expand(obs_embeddings.shape[0], self.register_token_length, -1) - # obs_embeddings = torch.cat([task_tokens, obs_embeddings], dim=1) - num_steps = obs_embeddings.size(1) - sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent, - is_init_infer, valid_context_lengths) + sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths) - - # inference阶段: collect或者eval Process action tokens + # --- Branch 2: Inference Phase (Collect/Eval) - Process action tokens --- elif 'act_tokens' in obs_embeddings_or_act_tokens: act_tokens = obs_embeddings_or_act_tokens['act_tokens'] - if self.continuous_action_space: num_steps = 1 act_tokens = act_tokens.float() @@ -729,347 +603,254 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu if len(act_tokens.shape) == 3: act_tokens = act_tokens.squeeze(1) num_steps = act_tokens.size(1) + + # Get action embeddings from the task-specific or shared table. if self.task_num >= 1 and self.continuous_action_space: act_embeddings = self.act_embedding_table[task_id](act_tokens) else: act_embeddings = self.act_embedding_table(act_tokens) - - if self.task_embed_option == "add_task_embed": - # TODO: 对于action_token不需要增加task_embeddings会造成歧义,反而干扰学习 - # obs_embeddings = obs_embeddings + self.task_embeddings - pass - elif self.task_embed_option == "concat_task_embed": - # print(f'=='*20) - # print(f'act_embeddings.shape:{act_embeddings.shape}') - # print(f'self.task_embeddings.shape:{self.task_embeddings.shape}') - # print(f'=='*20) - # Expand task embeddings to match the sequence shape + + # Apply task embeddings. + if self.task_embed_option == "concat_task_embed": task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(act_embeddings.shape[0], act_embeddings.shape[1], -1) act_embeddings = torch.cat([act_embeddings, task_emb_expanded], dim=-1) - - sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent, - is_init_infer, valid_context_lengths) + sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths) - # 训练阶段: Process combined observation embeddings and action tokens + # --- Branch 3: Training Phase - Process combined observation embeddings and action tokens --- else: - # "add_task_embed"在self._process_obs_act_combined_cont方法内部处理, - # process_obs_act_combined目前还没有增加task_embed的concat和register模式 if self.continuous_action_space: sequences, num_steps = self._process_obs_act_combined_cont(obs_embeddings_or_act_tokens, prev_steps, task_id=task_id) else: sequences, num_steps = self._process_obs_act_combined(obs_embeddings_or_act_tokens, prev_steps) - - # Pass sequences through transformer + # Pass sequences through the transformer. x = self._transformer_pass(sequences, past_keys_values, kvcache_independent, valid_context_lengths, task_id=task_id) - # Generate logits - - # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - # TODO: one head or moe head - if self.use_moe_head: + # Generate logits using shared, task-specific, or MoE heads. 
+ head_index = 0 if self.share_head else task_id + if self.use_moe_head or self.use_softmoe_head: logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) else: - # 使用共享head或任务特定的head - head_index = 0 if self.share_head else task_id - # print(f"="*20) - # print(f"head_index:{head_index}") - # print(f"="*20) logits_observations = self.head_observations_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) logits_rewards = self.head_rewards_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) logits_policy = self.head_policy_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) logits_value = self.head_value_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) - # logits_ends is None return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) - - #@profile - def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, - valid_context_lengths): + def _add_position_embeddings( + self, + embeddings: torch.Tensor, + prev_steps: Union[int, torch.Tensor], + num_steps: int, + kvcache_independent: bool, + is_init_infer: bool, + valid_context_lengths: Optional[torch.Tensor] + ) -> torch.Tensor: """ - Add position embeddings to the input embeddings. - + Overview: + Adds positional embeddings to the input embeddings. Arguments: - embeddings (:obj:`torch.Tensor`): Input embeddings. - - prev_steps (:obj:`torch.Tensor`): Previous steps. - - num_steps (:obj:`int`): Number of steps. - - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. - - is_init_infer (:obj:`bool`): Initialize inference. - - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + - prev_steps (:obj:`Union[int, torch.Tensor]`): Number of previous steps in the cache. + - num_steps (:obj:`int`): Number of new steps being added. + - kvcache_independent (:obj:`bool`): Flag for independent KV caching. + - is_init_infer (:obj:`bool`): Flag for initial inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths for each sequence. Returns: - - torch.Tensor: Embeddings with position information added. + - (:obj:`torch.Tensor`): Embeddings with added positional information. """ if kvcache_independent: - steps_indices = prev_steps + torch.arange(num_steps, device=embeddings.device) - position_embeddings = self.pos_emb(steps_indices).view(-1, num_steps, embeddings.shape[-1]) + steps_indices = prev_steps.unsqueeze(1) + torch.arange(num_steps, device=embeddings.device) + position_embeddings = self.pos_emb(steps_indices) return embeddings + position_embeddings else: - # 修复前面kv_cache和z/a的位置编码不对, kv_cache, z/a, register_token - # if self.use_task_embed and self.task_embed_option == "register_task_embed": - # if prev_steps + num_steps + self.register_token_num > self.context_length: - # prev_steps = self.context_length - self.register_token_num - 1 - if is_init_infer: - return embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) + # For initial inference, positions are sequential from the previous step count. 
+ pos_indices = prev_steps + torch.arange(num_steps, device=self.device) + return embeddings + self.pos_emb(pos_indices) else: + # For recurrent steps, use valid_context_lengths to get correct positions. valid_context_lengths = torch.tensor(self.keys_values_wm_size_list_current, device=self.device) - - # try: - position_embeddings = self.pos_emb( - valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) - # except Exception as e: - # print(e) - # import ipdb; ipdb.set_trace() - + pos_indices = valid_context_lengths.unsqueeze(1) + torch.arange(num_steps, device=self.device) + position_embeddings = self.pos_emb(pos_indices) return embeddings + position_embeddings - #@profile - def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_steps, task_id=0): + def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens: dict, prev_steps: int, task_id: int = 0) -> Tuple[torch.Tensor, int]: """ - Process combined observation embeddings and action tokens. - + Overview: + Processes and combines observation embeddings and continuous action tokens for training. Arguments: - - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. - - prev_steps (:obj:`torch.Tensor`): Previous steps. + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary with 'obs_embeddings_and_act_tokens'. + - prev_steps (:obj:`int`): Number of previous steps. + - task_id (:obj:`int`): The current task ID. Returns: - - torch.Tensor: Combined observation and action embeddings with position information added. + - (:obj:`Tuple[torch.Tensor, int]`): A tuple of the combined sequence tensor and the number of steps. """ obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] if len(obs_embeddings.shape) == 3: - obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, - -1) + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) - if self.continuous_action_space: - act_tokens = act_tokens.float() - if len(act_tokens.shape) == 2: # TODO - act_tokens = act_tokens.unsqueeze(-1) + act_tokens = act_tokens.float() + if len(act_tokens.shape) == 2: + act_tokens = act_tokens.unsqueeze(-1) - # B, L, E act_embeddings = self.act_embedding_table[task_id](act_tokens) - B, L, K, E = obs_embeddings.size() + B, L, K, E_obs = obs_embeddings.size() + obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) if self.task_embed_option == "concat_task_embed": - # B, L*2, E - obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) - else: - # B, L*2, E - obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) - - - if self.task_embed_option == "concat_task_embed": - # print(f'=='*20) - # print(f'self.task_embeddings.shape:{self.task_embeddings.shape}') - # print(f'=='*20) - # Expand task embeddings to match the sequence shape task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(B, 1, -1) for i in range(L): + obs = obs_embeddings[:, i, :, :] if self.task_embed_option == "add_task_embed": - obs = obs_embeddings[:, i, :, :] + self.task_embeddings # Shape: (B, K, E) TODO: task_embeddings + obs = obs + self.task_embeddings elif self.task_embed_option == "concat_task_embed": - # print(f'=='*20) - # 
print(f'obs_embeddings.shape:{obs_embeddings.shape}') - # print(f'=='*20) - obs = torch.cat([obs_embeddings[:, i, :, :], task_emb_expanded], dim=-1) - else: - obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) + obs = torch.cat([obs, task_emb_expanded.expand(B, K, -1)], dim=-1) act = act_embeddings[:, i, :].unsqueeze(1) if self.task_embed_option == "concat_task_embed": - # print(f'=='*20) - # print(f'act_embeddings.shape:{act_embeddings.shape}') - # print(f'=='*20) act = torch.cat([act, task_emb_expanded], dim=-1) obs_act = torch.cat([obs, act], dim=1) - # print(f'obs_act.shape:{obs_act.shape}') - obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + pos_indices = prev_steps + torch.arange(num_steps, device=self.device) + return obs_act_embeddings + self.pos_emb(pos_indices), num_steps - return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps - - - #@profile - def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, task_id=0): + def _process_obs_act_combined(self, obs_embeddings_or_act_tokens: dict, prev_steps: int, task_id: int = 0) -> Tuple[torch.Tensor, int]: """ - Process combined observation embeddings and action tokens. - + Overview: + Processes and combines observation embeddings and discrete action tokens for training. Arguments: - - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. - - prev_steps (:obj:`torch.Tensor`): Previous steps. + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary with 'obs_embeddings_and_act_tokens'. + - prev_steps (:obj:`int`): Number of previous steps. + - task_id (:obj:`int`): The current task ID. Returns: - - torch.Tensor: Combined observation and action embeddings with position information added. + - (:obj:`Tuple[torch.Tensor, int]`): A tuple of the combined sequence tensor and the number of steps. 
""" obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] if len(obs_embeddings.shape) == 3: - obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, - -1) + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) act_embeddings = self.act_embedding_table(act_tokens) - B, L, K, E = obs_embeddings.size() - if self.task_embed_option == "concat_task_embed": - # B, L*2, E - obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) - else: - # B, L*2, E - obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) + B, L, K, E_obs = obs_embeddings.size() + obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) if self.task_embed_option == "concat_task_embed": - # Expand task embeddings to match the sequence shape task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(B, 1, -1) - for i in range(L): + obs = obs_embeddings[:, i, :, :] if self.task_embed_option == "add_task_embed": - obs = obs_embeddings[:, i, :, :] + self.task_embeddings # Shape: (B, K, E) TODO: task_embeddings + obs = obs + self.task_embeddings elif self.task_embed_option == "concat_task_embed": - obs = torch.cat([obs_embeddings[:, i, :, :], task_emb_expanded], dim=-1) - else: - obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) + obs = torch.cat([obs, task_emb_expanded.expand(B, K, -1)], dim=-1) act = act_embeddings[:, i, 0, :].unsqueeze(1) if self.task_embed_option == "concat_task_embed": act = torch.cat([act, task_emb_expanded], dim=-1) obs_act = torch.cat([obs, act], dim=1) - # print(f'obs_act.shape:{obs_act.shape}') - obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps - - - #@profile - # def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, task_id=0): - # """ - # Process combined observation embeddings and action tokens. - - # Arguments: - # - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. - # - prev_steps (:obj:`torch.Tensor`): Previous steps. - # Returns: - # - torch.Tensor: Combined observation and action embeddings with position information added. 
- # """ - # obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] - # if len(obs_embeddings.shape) == 3: - # obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, - # -1) - - # num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) - # # act_embeddings = self.act_embedding_table[task_id](act_tokens) - # act_embeddings = self.act_embedding_table(act_tokens) - - # B, L, K, E = obs_embeddings.size() - # obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device) - - # for i in range(L): - # # obs = obs_embeddings[:, i, :, :] - # obs = obs_embeddings[:, i, :, :] + self.task_embeddings # Shape: (B, K, E) TODO: task_embeddings - # act = act_embeddings[:, i, 0, :].unsqueeze(1) - # obs_act = torch.cat([obs, act], dim=1) - # obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - - # return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps - - #@profile - def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths, task_id=0): + pos_indices = prev_steps + torch.arange(num_steps, device=self.device) + return obs_act_embeddings + self.pos_emb(pos_indices), num_steps + + def _transformer_pass( + self, + sequences: torch.Tensor, + past_keys_values: Optional[torch.Tensor], + kvcache_independent: bool, + valid_context_lengths: Optional[torch.Tensor], + task_id: int = 0 + ) -> torch.Tensor: """ - Pass sequences through the transformer. - + Overview: + Passes sequences through the transformer, handling different KV cache modes. Arguments: - sequences (:obj:`torch.Tensor`): Input sequences. - - past_keys_values (:obj:`Optional[torch.Tensor]`): Previous keys and values for transformer. - - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. - - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + - past_keys_values (:obj:`Optional[torch.Tensor]`): The KV cache from previous steps. + - kvcache_independent (:obj:`bool`): Flag for independent KV caching. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Tensor of valid context lengths. + - task_id (:obj:`int`): The current task ID. Returns: - - torch.Tensor: Transformer output. + - (:obj:`torch.Tensor`): The output from the transformer. """ if kvcache_independent: - x = [self.transformer(sequences[k].unsqueeze(0), past_kv, - valid_context_lengths=valid_context_lengths[k].unsqueeze(0)) for k, past_kv in - enumerate(past_keys_values)] + x = [ + self.transformer(sequences[k].unsqueeze(0), past_kv, valid_context_lengths=valid_context_lengths[k].unsqueeze(0)) + for k, past_kv in enumerate(past_keys_values) + ] return torch.cat(x, dim=0) else: return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths) - #@profile @torch.no_grad() - def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, task_id = 0) -> torch.FloatTensor: + def reset_for_initial_inference(self, obs_act_dict: dict, task_id: int = 0) -> Tuple[WorldModelOutput, torch.Tensor]: """ - Reset the model state based on initial observations and actions. - + Overview: + Resets the model state for the beginning of an episode or a new inference sequence. + It processes the initial observations and actions to create the first latent state + and populate the KV cache. Arguments: - - obs_act_dict (:obj:`torch.FloatTensor`): A dictionary containing 'obs', 'action', and 'current_obs'. 
+ - obs_act_dict (:obj:`dict`): A dictionary containing 'obs', 'action', and 'current_obs'. + - task_id (:obj:`int`): The ID of the current task. Returns: - - torch.FloatTensor: The outputs from the world model and the latent state. + - (:obj:`Tuple[WorldModelOutput, torch.Tensor]`): A tuple containing the world model output + and the initial latent state. """ if self.use_task_embed: - self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO - self.task_embeddings = self.sim_norm(self.task_embeddings.view(1,-1)).view(-1) # TODO + self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) + self.task_embeddings = self.sim_norm(self.task_embeddings.view(1, -1)).view(-1) else: - self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) # ============= TODO: no task_embeddings now ============= - + self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) - # Extract observations, actions, and current observations from the dictionary. - if isinstance(obs_act_dict, dict): - batch_obs = obs_act_dict['obs'] - batch_action = obs_act_dict['action'] - batch_current_obs = obs_act_dict['current_obs'] + batch_obs = obs_act_dict['obs'] + batch_action = obs_act_dict['action'] + batch_current_obs = obs_act_dict['current_obs'] - # Encode observations to latent embeddings. obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_obs, task_id=task_id) if batch_current_obs is not None: - # ================ Collect and Evaluation Phase ================ - # Encode current observations to latent embeddings + # --- Collect and Evaluation Phase --- current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_current_obs, task_id=task_id) - # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") - if self.use_task_embed and self.task_embed_option == "register_task_embed": - self.latent_state = current_obs_embeddings - elif not self.use_task_embed: + # The latent state is the combination of observation embedding and task embedding. 
+ if self.use_task_embed: + if self.task_embed_option == "add_task_embed": + self.latent_state = current_obs_embeddings + self.task_embeddings + elif self.task_embed_option == "concat_task_embed": + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(current_obs_embeddings.shape[0], current_obs_embeddings.shape[1], -1) + self.latent_state = torch.cat([current_obs_embeddings, task_emb_expanded], dim=-1) + else: # "register_task_embed" or other cases + self.latent_state = current_obs_embeddings + else: self.latent_state = current_obs_embeddings - # ================ NOTE ================ - # import ipdb; ipdb.set_trace() - # self.latent_state 是原来的obs_embeddings与task_embedding的组合: add或者concat - if self.use_task_embed and self.task_embed_option == "add_task_embed": - self.latent_state = current_obs_embeddings + self.task_embeddings - if self.use_task_embed and self.task_embed_option == "concat_task_embed": - task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(current_obs_embeddings.shape[0], current_obs_embeddings.shape[1], -1) - self.latent_state = torch.cat([current_obs_embeddings, task_emb_expanded], dim=-1) - # ================ NOTE ================ - outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, current_obs_embeddings, task_id=task_id) else: - # ================ calculate the target value in Train phase ================ - - # self.latent_state = obs_embeddings - - # ================ NOTE ================ - # import ipdb; ipdb.set_trace() - # self.latent_state 是原来的obs_embeddings与task_embedding的组合: add或者concat - if self.use_task_embed and self.task_embed_option == "add_task_embed": - self.latent_state = obs_embeddings + self.task_embeddings - elif self.use_task_embed and self.task_embed_option == "concat_task_embed": - task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(obs_embeddings.shape[0], obs_embeddings.shape[1], -1) - self.latent_state = torch.cat([obs_embeddings, task_emb_expanded], dim=-1) + # --- Training Phase (for calculating target values) --- + if self.use_task_embed: + if self.task_embed_option == "add_task_embed": + self.latent_state = obs_embeddings + self.task_embeddings + elif self.task_embed_option == "concat_task_embed": + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(obs_embeddings.shape[0], obs_embeddings.shape[1], -1) + self.latent_state = torch.cat([obs_embeddings, task_emb_expanded], dim=-1) + else: + self.latent_state = obs_embeddings else: self.latent_state = obs_embeddings - # print(f" Train phase self.latent_state.shape: {self.latent_state.shape}") - # ================ NOTE ================ - outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, None, task_id=task_id) return outputs_wm, self.latent_state @@ -1194,12 +975,6 @@ def wm_forward_for_initial_inference(self, last_obs_embeddings: torch.LongTensor outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, task_id=task_id) - # if self.reanalyze_phase: - # # TODO - # outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, is_init_infer=False, task_id=task_id) - # else: - # outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, is_init_infer=True, task_id=task_id) - # select the last timestep for each sample last_steps_value = outputs_wm.logits_value[:, -1:, :] outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) @@ -1247,8 +1022,6 @@ def 
forward_recurrent_inference(self, state_action_history, simulation_index=0, Returns: - tuple: A tuple containing output sequence, updated latent state, reward, logits policy, and logits value. """ - # import ipdb; ipdb.set_trace() - latest_state, action = state_action_history[-1] ready_env_num = latest_state.shape[0] @@ -1289,7 +1062,6 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0, else: obs_embeddings_or_act_tokens = {'obs_embeddings': token} - # try: # Perform forward pass outputs_wm = self.forward( obs_embeddings_or_act_tokens, @@ -1298,25 +1070,9 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0, is_init_infer=False, task_id = task_id ) - # except Exception as e: - # print(e) - # import ipdb; ipdb.set_trace() self.keys_values_wm_size_list_current = [i + 1 for i in self.keys_values_wm_size_list_current] - # if self.task_embed_option == "register_task_embed": - # # kv_cache, z/a, register_token - # # 这样修复后kv_cache的位置编码不是从0开始的, 那后面按照从零开始矫正也就是错误的, - # # 但是由于self.keys_values_wm._keys_values[layer]._k_cache._size < context_length - 1,所以不会矫正 - # # 但是在_add_position_embeddings时,prev_steps是错误的,导致新增的z/a的位置编码索引与前面的kv不连续 - # # import ipdb; ipdb.set_trace() - # print(f'self.keys_values_wm_size_list_current:{self.keys_values_wm_size_list_current}') - # print(f'self.keys_values_wm.size:{self.keys_values_wm.size}') - # self.keys_values_wm_size_list_current = [min(self.keys_values_wm.size, i + 1) for i in self.keys_values_wm_size_list_current] - # else: - # self.keys_values_wm_size_list_current = [i + 1 for i in self.keys_values_wm_size_list_current] - - if k == 0: reward = outputs_wm.logits_rewards # (B,) @@ -1411,10 +1167,6 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde for i in range(latent_state.size(0)): # ============ Iterate over each environment ============ cache_key = hash_state(latent_state[i].view(-1).cpu().numpy()) # latent_state[i] is torch.Tensor - # if self.task_embed_option == "register_task_embed": - # context_length = self.context_length - self.register_token_num - # else: - # context_length = self.context_length context_length = self.context_length @@ -1529,15 +1281,53 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + # ORIGINAL + # if is_init_infer: + # # Store the latest key-value cache for initial inference + # cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + # self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + # else: + # # Store the latest key-value cache for recurrent inference + # cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + # self.past_kv_cache_recurrent_infer[cache_key] = cache_index + + if is_init_infer: - # Store the latest key-value cache for initial inference - # import ipdb; ipdb.set_trace() + # TODO + # ==================== PROACTIVE EVICTION FIX ==================== + # 1. Get the physical index that is about to be overwritten. + index_to_write = self.shared_pool_index_init_envs[i] + # 2. Use the auxiliary list to look up the old key stored at this index. + old_key_to_evict = self.pool_idx_to_key_map_init_envs[i][index_to_write]
+ # 3. If an old key exists, delete it from the main cache map. + if old_key_to_evict is not None: + # Make sure the key to delete actually exists, to avoid unexpected errors. + if old_key_to_evict in self.past_kv_cache_init_infer_envs[i]: + del self.past_kv_cache_init_infer_envs[i][old_key_to_evict] + + # It is now safe to write the new data. cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + + # 4. Update the new mapping in both the main cache map and the auxiliary list. self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + self.pool_idx_to_key_map_init_envs[i][index_to_write] = cache_key else: - # Store the latest key-value cache for recurrent inference + # ==================== RECURRENT INFER FIX ==================== + # 1. Get the physical index that is about to be overwritten. + index_to_write = self.shared_pool_index + # 2. Use the auxiliary list to look up the old key stored at this index. + old_key_to_evict = self.pool_idx_to_key_map_recur_infer[index_to_write] + # 3. If an old key exists, delete it from the main cache map. + if old_key_to_evict is not None: + if old_key_to_evict in self.past_kv_cache_recurrent_infer: + del self.past_kv_cache_recurrent_infer[old_key_to_evict] + + # 4. It is now safe to write the new data. cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + + # 5. Update the new mapping in both the main cache map and the auxiliary list. self.past_kv_cache_recurrent_infer[cache_key] = cache_index + self.pool_idx_to_key_map_recur_infer[index_to_write] = cache_key #@profile def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, @@ -1573,9 +1363,20 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, matched_value = None # If not found, try to retrieve from past_kv_cache_recurrent_infer + # if matched_value is None: + # matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] + + # ==================== TODO ==================== + # Step 2: Only if nothing was found in init_infer, try the recurrent_infer cache. if matched_value is None: - # import ipdb; ipdb.set_trace() - matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] + # 2.1 Safely get the index from the dict; it may return None. + recur_cache_index = self.past_kv_cache_recurrent_infer.get(cache_key) + # 2.2 Only if the index is valid (not None), use it to retrieve the value from the physical pool. + if recur_cache_index is not None: + matched_value = self.shared_pool_recur_infer[recur_cache_index] + + if recur_cache_index is None: + print(f"[CACHE MISS] Not found for key={cache_key} in recurrent infer.
Generating new cache.") if matched_value is not None: # If a matching cache is found, add it to the lists @@ -1592,43 +1393,40 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, task_id=task_id ) - # if self.reanalyze_phase: - # self.forward( - # {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, - # past_keys_values=self.keys_values_wm_single_env, is_init_infer=False, task_id=task_id - # ) - # else: - # self.forward( - # {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, - # past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, task_id=task_id - # ) self.keys_values_wm_list.append(self.keys_values_wm_single_env) self.keys_values_wm_size_list.append(1) return self.keys_values_wm_size_list - - def plot_embeddings(self, tsne_results, task_ids, observations, samples_per_task=5, save_dir='tsne_plots_26games'): - """ - 生成 t-SNE 可视化图,并在图中为每个任务随机标注指定数量的观测样本图像。 - - 参数: - - tsne_results: t-SNE 降维结果 (N x 2 的数组) - - task_ids: 环境任务 ID,用于着色 (N 的数组) - - observations: 对应的观测样本 (N x C x H x W 的张量或数组) - - samples_per_task: 每个任务选择的样本数量,默认 5 - - save_dir: 保存路径,默认 'tsne_plots_26games' + def plot_embeddings( + self, + tsne_results: np.ndarray, + task_ids: np.ndarray, + observations: Union[np.ndarray, torch.Tensor], + samples_per_task: int = 5, + save_dir: str = 'tsne_plots_26games' + ) -> None: """ + Overview: + Generates a t-SNE visualization plot and annotates it with a specified number of + randomly selected observation images for each task. - # 创建保存目录 + Arguments: + - tsne_results (:obj:`np.ndarray`): The t-SNE dimensionality reduction results (N x 2 array). + - task_ids (:obj:`np.ndarray`): An array of environment task IDs, used for coloring the points (N array). + - observations (:obj:`Union[np.ndarray, torch.Tensor]`): The corresponding observation samples (N x C x H x W tensor or array). + - samples_per_task (:obj:`int`): The number of samples to select for image annotation per task. Defaults to 5. + - save_dir (:obj:`str`): The directory path where the plot will be saved. Defaults to 'tsne_plots_26games'. + """ + # Create the save directory if it doesn't exist. os.makedirs(save_dir, exist_ok=True) - print(f"[INFO] 保存目录已创建或已存在: {save_dir}") + print(f"[INFO] Save directory created or already exists: {save_dir}") - # 创建 t-SNE 图 - print("[INFO] 开始绘制 t-SNE 散点图...") - plt.figure(figsize=(18, 10)) # 增大图像宽度以适应右侧图例 + # Create the t-SNE plot. + print("[INFO] Starting to draw the t-SNE scatter plot...") + plt.figure(figsize=(18, 10)) # Increase figure width to accommodate the legend on the right. - # 散点图 + # Scatter plot of the t-SNE results. scatter = plt.scatter( tsne_results[:, 0], tsne_results[:, 1], @@ -1638,7 +1436,7 @@ def plot_embeddings(self, tsne_results, task_ids, observations, samples_per_task linewidth=0.5 ) - # 创建自定义图例 + # Create a custom legend for the tasks. legend_elements = [] for idx, env_id in enumerate(self.env_id_list): short_name = self.env_short_names.get(env_id, env_id) @@ -1647,53 +1445,56 @@ def plot_embeddings(self, tsne_results, task_ids, observations, samples_per_task Patch(facecolor=color, edgecolor='w', label=f"{idx}: {short_name}") ) - # 将图例放在图像右侧,并且每个图例项占一行 + # Place the legend on the right side of the plot, with each item on a new line. 
plt.legend( handles=legend_elements, title="Environment IDs", loc='center left', - bbox_to_anchor=(1, 0.5), # 图例在图像右侧中央 + bbox_to_anchor=(1, 0.5), # Position the legend in the center-right of the plot area. fontsize=10, title_fontsize=12, ncol=1, - frameon=False # 去除图例边框,增强美观 + frameon=False # Remove the legend border for a cleaner look. ) - # 设置标题和轴标签 + # Set the title and axis labels. plt.title("t-SNE of Latent States across Environments", fontsize=16) plt.xlabel("t-SNE Dimension 1", fontsize=14) plt.ylabel("t-SNE Dimension 2", fontsize=14) plt.xticks(fontsize=12) plt.yticks(fontsize=12) plt.grid(True, linestyle='--', alpha=0.5) - print(f"[INFO] t-SNE 散点图绘制完成,共有 {len(tsne_results)} 个点。") + print(f"[INFO] t-SNE scatter plot completed with {len(tsne_results)} points.") - # 为每个任务选择指定数量的样本进行图像标注 - print(f"[INFO] 开始为每个任务选择 {samples_per_task} 个样本进行图像标注...") + # Select a specified number of samples per task for image annotation. + print(f"[INFO] Starting to select {samples_per_task} samples per task for image annotation...") for task_id in range(len(self.env_id_list)): - # 找到当前任务的所有索引 + # Find all indices for the current task. task_indices = np.where(task_ids == task_id)[0] if len(task_indices) == 0: - print(f"[WARNING] 任务 ID {task_id} 没有对应的样本。") + print(f"[WARNING] No samples found for task ID {task_id}.") continue - # 如果样本数量少于所需,全部选取 + + # If the number of samples is less than required, select all of them. if len(task_indices) < samples_per_task: selected_indices = task_indices - print(f"[INFO] 任务 ID {task_id} 的样本数量 ({len(task_indices)}) 少于 {samples_per_task},选取全部。") + print(f"[INFO] Task ID {task_id} has fewer samples ({len(task_indices)}) than required ({samples_per_task}). Selecting all.") else: selected_indices = np.random.choice(task_indices, size=samples_per_task, replace=False) - print(f"[INFO] 任务 ID {task_id} 随机选取 {samples_per_task} 个样本进行标注。") + print(f"[INFO] Randomly selecting {samples_per_task} samples for task ID {task_id} for annotation.") for idx in selected_indices: img = observations[idx] if isinstance(img, torch.Tensor): img = img.cpu().numpy() - if img.shape[0] == 1 or img.shape[0] == 3: # 处理灰度图或 RGB 图 + + # Handle channel-first (C, H, W) format for grayscale or RGB images. + if img.shape[0] == 1 or img.shape[0] == 3: img = np.transpose(img, (1, 2, 0)) else: raise ValueError(f"Unsupported image shape: {img.shape}") - # 标准化图像到 [0,1] 范围 + # Normalize the image to the [0, 1] range for correct display. img_min, img_max = img.min(), img.max() if img_max - img_min > 1e-5: img = (img - img_min) / (img_max - img_min) @@ -1708,37 +1509,52 @@ def plot_embeddings(self, tsne_results, task_ids, observations, samples_per_task pad=0.3 ) plt.gca().add_artist(ab) - print(f"[INFO] 已添加图像标注: 任务 ID {task_id}, 点索引 {idx}, t-SNE 坐标 ({tsne_results[idx, 0]:.2f}, {tsne_results[idx, 1]:.2f})") + print(f"[INFO] Added image annotation: Task ID {task_id}, point index {idx}, t-SNE coords ({tsne_results[idx, 0]:.2f}, {tsne_results[idx, 1]:.2f})") - # 调整布局以适应图例 - plt.tight_layout(rect=[0, 0, 0.9, 1]) # 为右侧的图例预留空间 + # Adjust layout to prevent the legend from being cut off. + plt.tight_layout(rect=[0, 0, 0.9, 1]) # Reserve space for the legend on the right. - # 保存图像,使用高分辨率 + # Save the figure in both PNG and PDF formats with high resolution. 
save_path_png = os.path.join(save_dir, 'tsne_plot.png') save_path_pdf = os.path.join(save_dir, 'tsne_plot.pdf') plt.savefig(save_path_png, dpi=300, bbox_inches='tight') plt.savefig(save_path_pdf, dpi=300, bbox_inches='tight') - print(f"[INFO] t-SNE 可视化图已保存至: {save_path_png} 和 {save_path_pdf}") + print(f"[INFO] t-SNE visualization plot saved to: {save_path_png} and {save_path_pdf}") plt.close() - + @torch.no_grad() - def gather_and_plot(self, local_embeddings, local_task_ids, local_observations): + def gather_and_plot( + self, + local_embeddings: torch.Tensor, + local_task_ids: torch.Tensor, + local_observations: torch.Tensor + ) -> None: + """ + Overview: + Gathers embeddings, task IDs, and observations from all distributed processes. + On the main process (rank 0), it performs t-SNE and plots the results. + + Arguments: + - local_embeddings (:obj:`torch.Tensor`): The embedding tensor from the current process. + - local_task_ids (:obj:`torch.Tensor`): The task ID tensor from the current process. + - local_observations (:obj:`torch.Tensor`): The observation tensor from the current process. + """ world_size = dist.get_world_size() rank = dist.get_rank() - # 准备接收来自所有进程的CUDA张量 + # Prepare lists to receive CUDA tensors from all processes. embeddings_list = [torch.zeros_like(local_embeddings) for _ in range(world_size)] task_ids_list = [torch.zeros_like(local_task_ids) for _ in range(world_size)] - # 准备接收来自所有进程的CPU对象 + # Prepare a list to receive CPU objects (observations) from all processes. observations_list = [None for _ in range(world_size)] try: - # 收集CUDA张量:embeddings和task_ids + # Gather CUDA tensors: embeddings and task_ids. dist.all_gather(embeddings_list, local_embeddings) dist.all_gather(task_ids_list, local_task_ids) - # 收集CPU对象:observations + # Gather CPU objects: observations (must be moved to CPU and converted first). local_observations_cpu = local_observations.cpu().numpy().tolist() dist.all_gather_object(observations_list, local_observations_cpu) except RuntimeError as e: @@ -1746,26 +1562,26 @@ def gather_and_plot(self, local_embeddings, local_task_ids, local_observations): return if rank == 0: - # 拼接所有embeddings和task_ids + # Concatenate all embeddings and task_ids on the main process. all_embeddings = torch.cat(embeddings_list, dim=0).cpu().numpy() all_task_ids = torch.cat(task_ids_list, dim=0).cpu().numpy() - # 拼接所有observations - all_observations = [] + # Concatenate all observations. + all_observations_list = [] for obs in observations_list: - all_observations.extend(obs) - all_observations = np.array(all_observations) + all_observations_list.extend(obs) + all_observations = np.array(all_observations_list) print(f"Shape of all_embeddings: {all_embeddings.shape}") all_embeddings = all_embeddings.reshape(-1, all_embeddings.shape[-1]) print(f"Shape of all_observations: {all_observations.shape}") all_observations = all_observations.reshape(-1, *all_observations.shape[-3:]) - # 执行t-SNE降维 + # Perform t-SNE dimensionality reduction. tsne = TSNE(n_components=2, random_state=42) tsne_results = tsne.fit_transform(all_embeddings) - # 绘制并保存图像 + # Plot and save the resulting image. 
self.plot_embeddings(tsne_results, all_task_ids, all_observations, save_dir=f'tsne_plots_{self.num_tasks}games') #@profile @@ -1775,23 +1591,24 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar if self.analysis_tsne: # =========== tsne analysis =========== - # 确保embeddings在CUDA设备上且为稠密张量 if not obs_embeddings.is_cuda: obs_embeddings = obs_embeddings.cuda() obs_embeddings = obs_embeddings.contiguous() - - # 保存当前进程的 embeddings 和 task_id local_embeddings = obs_embeddings.detach() local_task_ids = torch.full((local_embeddings.size(0),), task_id, dtype=torch.long, device=local_embeddings.device) - - # 将observations移到CPU并转换为numpy local_observations = batch['observations'].detach().cpu() - - # 进行数据收集和可视化 self.gather_and_plot(local_embeddings, local_task_ids, local_observations) # ========= logging for analysis ========= if self.analysis_dormant_ratio_weight_rank: + self._analysis_step_counter += 1 + self.do_analysis = ( + self.analysis_dormant_ratio_weight_rank # 总开关 + and self._analysis_step_counter % self.analysis_dormant_ratio_interval == 0 + ) + + # ========= logging for analysis ========= + if self.do_analysis: # Calculate dormant ratio of the encoder shape = batch['observations'].shape # (..., C, H, W) inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) @@ -1799,38 +1616,35 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar encoder_index = task_id else: encoder_index = 0 - dormant_ratio_encoder_dict = cal_dormant_ratio(self.tokenizer.encoder[encoder_index], inputs.detach(), + dormant_ratio_encoder_dict = calculate_dormant_ratio(self.tokenizer.encoder[encoder_index], inputs.detach(), dormant_threshold=self.dormant_threshold) - - # print(dormant_ratio_encoder_dict) dormant_ratio_encoder = dormant_ratio_encoder_dict['global'] - # 计算全局平均权重绝对值 avg_weight_mag_encoder = compute_average_weight_magnitude(self.tokenizer.encoder[encoder_index]) - # print("Average Weight Magnitude of encoder:", avg_weight_mag_encoder) - # 计算全局平均权重绝对值 avg_weight_mag_transformer = compute_average_weight_magnitude(self.transformer) - # print("Average Weight Magnitude of transformer:", avg_weight_mag_transformer) - # print(f"self.head_dict:{self.head_dict}") avg_weight_mag_head = compute_average_weight_magnitude(self.head_dict) - # print("Average Weight Magnitude of head:", avg_weight_mag_head) - # 计算 effective rank,对于 representation 层,注意: - # representation 层在 model.named_modules() 的名称为 "representation" - # print(f"self.tokenizer.encoder:{self.tokenizer.encoder}") - - e_rank_last_linear = cal_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="last_linear") - # print("Effective Rank of encoder_last_linear:", e_rank_last_linear) - e_rank_sim_norm = cal_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="final_norm") - # print("Effective Rank of encoder_sim_norm:", e_rank_sim_norm) - - self.past_kv_cache_init_infer.clear() + e_rank_last_linear = calculate_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="last_linear") + try: + e_rank_sim_norm = calculate_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="final_norm") + except Exception as e: + e_rank_sim_norm = torch.tensor(0.) 
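Editor's note: for reference, the `calculate_effective_rank` metric logged in the analysis branch above is the exponential of the Shannon entropy of the normalized singular-value spectrum of a (samples, features) activation matrix. A minimal standalone sketch follows; the function name and test data here are illustrative and not part of the patch.

```python
import numpy as np

def effective_rank(representations: np.ndarray) -> float:
    """Entropy-based effective rank of a (samples, features) activation matrix."""
    # Singular values of the representation matrix.
    s = np.linalg.svd(representations, full_matrices=False, compute_uv=False)
    # Normalize to a probability distribution and take its Shannon entropy.
    p = s / np.sum(np.abs(s))
    p = p[p > 1e-8]  # avoid log(0)
    entropy = -np.sum(p * np.log(p))
    # exp(entropy) ranges from 1 (rank-1 collapse) up to the full rank.
    return float(np.exp(entropy))

rng = np.random.default_rng(0)
# A nearly rank-1 matrix yields an effective rank close to 1.
low_rank = np.outer(rng.normal(size=64), rng.normal(size=32)) + 1e-3 * rng.normal(size=(64, 32))
print(effective_rank(low_rank))                  # close to 1
print(effective_rank(rng.normal(size=(64, 32)))) # well above 1, approaching the full rank of 32
```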
+ + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() self.past_kv_cache_recurrent_infer.clear() self.keys_values_wm_list.clear() torch.cuda.empty_cache() else: dormant_ratio_encoder = torch.tensor(0.) + avg_weight_mag_encoder = torch.tensor(0.) + avg_weight_mag_transformer = torch.tensor(0.) + avg_weight_mag_head = torch.tensor(0.) + e_rank_last_linear = torch.tensor(0.) + e_rank_sim_norm = torch.tensor(0.) + # dormant_ratio_encoder = None + # Calculate the L2 norm of the latent state roots latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() @@ -1902,26 +1716,46 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Forward pass to obtain predictions for observations, rewards, and policies outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, task_id=task_id) + if self.config.use_priority: + # ==================== START MODIFICATION 5 ==================== + # Calculate value_priority, similar to MuZero. + with torch.no_grad(): + # 1. Get the predicted value logits for the first step of the sequence (t=0). + # The shape is (B, support_size). + predicted_value_logits_step0 = outputs.logits_value[:, 0, :] + + # 2. Convert the categorical prediction to a scalar value. + # The shape becomes (B, 1). + predicted_scalar_value_step0 = inverse_scalar_transform_handle(predicted_value_logits_step0) + + # 3. Get the target scalar value for the first step from the batch. + # The shape is (B, num_unroll_steps), so we take the first column. + target_scalar_value_step0 = batch['scalar_target_value'][:, 0] + + # 4. Calculate the L1 loss (absolute difference) between prediction and target. + # This is the priority. We use reduction='none' to get per-sample priorities. + value_priority = F.l1_loss(predicted_scalar_value_step0.squeeze(-1), target_scalar_value_step0, reduction='none') + # ===================== END MODIFICATION 5 ===================== + else: + value_priority = torch.tensor(0.) + # ========= logging for analysis ========= - if self.analysis_dormant_ratio_weight_rank: + # if self.analysis_dormant_ratio_weight_rank: + if self.do_analysis: # Calculate dormant ratio of the world model - dormant_ratio_world_model = cal_dormant_ratio(self, { + dormant_ratio_world_model = calculate_dormant_ratio(self, { 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, dormant_threshold=self.dormant_threshold) dormant_ratio_transformer = dormant_ratio_world_model['transformer'] dormant_ratio_head = dormant_ratio_world_model['head'] - self.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() self.past_kv_cache_recurrent_infer.clear() self.keys_values_wm_list.clear() torch.cuda.empty_cache() else: dormant_ratio_transformer = torch.tensor(0.) dormant_ratio_head = torch.tensor(0.) - avg_weight_mag_encoder = torch.tensor(0.) - avg_weight_mag_transformer = torch.tensor(0.) - avg_weight_mag_head = torch.tensor(0.) - e_rank_last_linear = torch.tensor(0.) - e_rank_sim_norm = torch.tensor(0.) 
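Editor's note: the `value_priority` addition above follows the MuZero-style priority recipe: take the categorical value prediction at the first unroll step, convert it to a scalar, and use the per-sample L1 error against the target scalar value as the replay priority. The sketch below illustrates that recipe under simplifying assumptions: a hand-built 21-bin support and a plain softmax expectation stand in for `inverse_scalar_transform_handle`, which in LightZero also undoes the h(.) value-scaling transform.

```python
import torch
import torch.nn.functional as F

# Hypothetical support: 21 bins covering integer values in [-10, 10].
support = torch.arange(-10., 11., 1.)  # shape (21,)

def categorical_to_scalar(logits: torch.Tensor) -> torch.Tensor:
    """Expected value of the categorical distribution over the support bins."""
    probs = F.softmax(logits, dim=-1)   # (B, 21)
    return (probs * support).sum(dim=-1)  # (B,)

batch_size = 4
value_logits_step0 = torch.randn(batch_size, support.numel())  # predicted value logits at t=0
target_scalar_value_step0 = torch.randn(batch_size)            # target scalar values at t=0

with torch.no_grad():
    predicted_scalar = categorical_to_scalar(value_logits_step0)
    # Per-sample L1 error (reduction='none'), used directly as the replay priority.
    value_priority = F.l1_loss(predicted_scalar, target_scalar_value_step0, reduction='none')

print(value_priority.shape)  # torch.Size([4])
```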
# ========== for visualization ========== # Uncomment the lines below for visualization @@ -1951,23 +1785,11 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar labels_observations = labels_observations.reshape(-1, self.projection_input_dim) if self.use_task_embed and self.task_embed_option == "concat_task_embed": - # print(f'=='*20) - # print(f'labels_observations.shape:{labels_observations.shape}') - # print(f'=='*20) # Expand task embeddings to match the sequence shape self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO self.task_embeddings = self.sim_norm(self.task_embeddings.view(1,-1)).view(-1) # TODO task_emb_expanded = self.task_embeddings.expand(labels_observations.shape[0], -1) - # print(f'task_emb_expanded:{task_emb_expanded}') - # print(f"task_emb_expanded.shape: {task_emb_expanded.shape}") - # print(f"task_emb_expanded (min, max, mean): {task_emb_expanded.min()}, {task_emb_expanded.max()}, {task_emb_expanded.mean()}") - # assert not torch.isnan(task_emb_expanded).any(), "task_emb_expanded 存在 NaN 值" - # print(f"logits_observations.shape: {logits_observations.shape}") labels_observations = torch.cat([labels_observations, task_emb_expanded.detach()], dim=-1) # NOTE: detach() - # print(f"labels_observations.shape: {labels_observations.shape}") - # assert logits_observations.shape == labels_observations.shape, "logits 和 labels 的形状不匹配" - - # Compute prediction loss for observations. Options: MSE and Group KL if self.predict_latent_loss_type == 'mse': @@ -2077,6 +1899,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + # 为了让外部的训练循环能够获取encoder的输出,我们将其加入返回字典 + # 使用 .detach() 是因为这个张量仅用于后续的clip操作,不应影响梯度计算 + detached_obs_embeddings = obs_embeddings.detach() + if self.continuous_action_space: return LossWithIntermediateLosses( latent_recon_loss_weight=self.latent_recon_loss_weight, @@ -2105,6 +1931,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar policy_mu=mu, policy_sigma=sigma, target_sampled_actions=target_sampled_actions, + + value_priority=value_priority, + obs_embeddings=detached_obs_embeddings, # <-- 新增 + ) else: return LossWithIntermediateLosses( @@ -2131,6 +1961,11 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar e_rank_last_linear = e_rank_last_linear, e_rank_sim_norm = e_rank_sim_norm, latent_state_l2_norms=latent_state_l2_norms, + + value_priority=value_priority, + obs_embeddings=detached_obs_embeddings, # <-- 新增 + + ) #@profile diff --git a/lzero/model/utils.py b/lzero/model/utils.py index c849aedca..1204070f9 100644 --- a/lzero/model/utils.py +++ b/lzero/model/utils.py @@ -1,225 +1,208 @@ """ Overview: - In this file, we provide a set of utility functions for probing network parameters and gradients, - which can be helpful in analyzing and debugging the inner workings of various models. + This file provides a set of utility functions for probing network parameters and gradients. + These tools are helpful for analyzing and debugging the inner workings of various models. 
""" -from typing import List, Tuple, Union, Dict -from torch.nn import functional as F +from typing import List, Tuple, Union, Dict, Type, Optional + import numpy as np import torch import torch.nn as nn -############################### -# 1. 计算 average_weight_magnitude -############################### + def compute_average_weight_magnitude(model: nn.Module) -> float: """ - 计算模型中所有参数的平均绝对值。 + Overview: + Calculates the average absolute magnitude of all parameters in a given model. Arguments: - model: 待评估模型,类型为 nn.Module + - model (:obj:`nn.Module`): The model to be evaluated. Returns: - 平均权重绝对值(float) + - float: The average absolute magnitude of the model's weights. """ num_weights = 0 - # 使用模型中第一个参数的设备,保证计算时设备一致 + # Use the device of the model's first parameter to ensure consistency. device = next(model.parameters()).device sum_weight_magnitude = torch.tensor(0.0, device=device) for p in model.parameters(): num_weights += p.numel() sum_weight_magnitude += torch.sum(torch.abs(p)) - + if num_weights == 0: return 0.0 return sum_weight_magnitude.cpu().item() / num_weights -############################### -# 2. 计算 effective_rank -############################### + def compute_effective_rank(singular_values: np.ndarray) -> float: """ - 根据给定的奇异值数组计算 effective rank,公式为: - effective_rank = exp( - sum_i [p_i * log(p_i)] ) - 其中 p_i 是归一化后的奇异值(p_i = s_i / ∑ s_i) + Overview: + Computes the effective rank from an array of singular values. The formula is: + effective_rank = exp(-sum_i [p_i * log(p_i)]), where p_i is the normalized singular value. Arguments: - singular_values: 奇异值数组,类型为 np.ndarray + - singular_values (:obj:`np.ndarray`): An array of singular values. Returns: - effective rank(float) + - float: The calculated effective rank. """ + # Normalize singular values to form a probability distribution. norm_sv = singular_values / np.sum(np.abs(singular_values)) entropy = 0.0 for p in norm_sv: - if p > 0.0: + if p > 1e-8: # Avoid log(0) entropy -= p * np.log(p) - return np.e ** entropy + return np.exp(entropy) -# 定义一个 Hook 类,用来捕获中间层的输出 class IntermediateOutputHook: """ - 用于捕获模块输出的 Hook,保存输出张量列表。 + Overview: + A hook class to capture and store the output tensors from a specific nn.Module during a forward pass. """ def __init__(self): self.outputs: List[torch.Tensor] = [] - def __call__(self, module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor) -> None: - # 这里使用 detach 防止反向传播干扰,并转移到 CPU 便于后续统计 + def __call__(self, module: nn.Module, inputs: Tuple[torch.Tensor, ...], output: torch.Tensor) -> None: + """ + Overview: + This method is called by PyTorch when the hooked module completes its forward pass. + """ + # Detach the tensor from the computation graph and move to CPU to save memory. self.outputs.append(output.detach().cpu()) -def cal_effective_rank( + def clear(self) -> None: + """ + Overview: + Clears the list of captured outputs. + """ + self.outputs.clear() + + +def calculate_effective_rank( model: nn.Module, - inputs: Union[torch.Tensor, List[torch.Tensor]], + inputs: Union[torch.Tensor, List[torch.Tensor]], representation_layer_name: str, ) -> float: """ - 针对模型指定的中间层(representation 层), - 使用 Hook 捕获该层输出,并计算 effective rank。 + Overview: + Calculates the effective rank of a specified intermediate layer's output (representation) + by using a forward hook to capture the activations. 
Arguments: - model: 待评估模型,应为 nn.Module 类型。 - inputs: 模型 forward 的输入,可以为 tensor 或 tensor-list。 - representation_layer_name: 模型中表示 representation 层的名称, - 该名称必须能够在 model.named_modules() 中找到对应模块。 + - model (:obj:`nn.Module`): The model to be evaluated. + - inputs (:obj:`Union[torch.Tensor, List[torch.Tensor]]`): The inputs for the model's forward pass. + - representation_layer_name (:obj:`str`): The name of the representation layer, which must be + findable within `model.named_modules()`. Returns: - effective rank(float) + - float: The effective rank of the representation layer's output. """ - # 获取 representation 层模块(若名称不存在将引发 KeyError) module_dict = dict(model.named_modules()) if representation_layer_name not in module_dict: raise KeyError(f"Representation layer '{representation_layer_name}' not found in model.named_modules().") representation_module = module_dict[representation_layer_name] - # 注册 hook hook = IntermediateOutputHook() handle = representation_module.register_forward_hook(hook) - - # 执行 forward 推理 + model.eval() with torch.no_grad(): if isinstance(inputs, (list, tuple)): _ = model(*inputs) else: _ = model(inputs) - - # 注销 hook,避免内存泄露 + + # Always remove the hook to prevent memory leaks. handle.remove() if not hook.outputs: - raise RuntimeError("No outputs captured from the representation layer.") + raise RuntimeError("No outputs were captured from the representation layer.") - # 这里假定有一个或多个 forward(例如在 batch 或多次调用的场景), - # 将所有输出在 batch 维度上拼接 - if len(hook.outputs) > 1: - rep_tensor = torch.cat(hook.outputs, dim=0) - else: - rep_tensor = hook.outputs[0] + # Concatenate all captured outputs along the batch dimension. + rep_tensor = torch.cat(hook.outputs, dim=0) if len(hook.outputs) > 1 else hook.outputs[0] - # 将 representation 展开为二维矩阵: (samples, features) + # Reshape the representation to a 2D matrix (samples, features). rep_tensor = rep_tensor.view(rep_tensor.size(0), -1) - # 将 tensor 转换为 numpy 数组以使用 numpy.linalg.svd - rep_np = rep_tensor.cpu().numpy() + # Compute singular values using SVD. + singular_values = np.linalg.svd(rep_tensor.cpu().numpy(), full_matrices=False, compute_uv=False) - # 计算奇异值 - singular_values = np.linalg.svd(rep_np, full_matrices=False, compute_uv=False) - - # 计算 effective rank + # Calculate the effective rank. e_rank = compute_effective_rank(singular_values) - # 清空 hook 存储(若需要多次调用可以保持清洁状态) - hook.outputs.clear() + hook.clear() return e_rank - def compute_dormant_stats(outputs: List[torch.Tensor], threshold: float) -> Tuple[int, int]: """ - 对给定的一组输出(同一层可能 forward 多次)进行元素级统计。 - + Overview: + Computes element-wise statistics for a list of output tensors from a layer. + Arguments: - outputs: List[torch.Tensor],每个 tensor 表示一次 forward 的输出 - threshold: 判断 dormant 的阈值,当激活值 <= threshold 时视为 dormant - + - outputs (:obj:`List[torch.Tensor]`): A list of tensors, each representing an output from a forward pass. + - threshold (:obj:`float`): The activation threshold below which a neuron is considered dormant. + Returns: - layer_total: 该层总元素数(累加多个 forward) - layer_dormant: 该层中满足 dormant 条件的元素数目 + - Tuple[int, int]: A tuple containing the total number of elements and the number of dormant elements. 
""" layer_total = 0 layer_dormant = 0 for out in outputs: flattened = out.view(-1) - total = flattened.numel() - dormant = torch.sum(flattened <= threshold).item() - layer_total += total - layer_dormant += dormant + layer_total += flattened.numel() + layer_dormant += torch.sum(flattened <= threshold).item() return layer_total, layer_dormant -def cal_dormant_ratio( + +def calculate_dormant_ratio( model: nn.Module, inputs: Union[torch.Tensor, List[torch.Tensor]], dormant_threshold: float = 1e-2, + target_modules: Tuple[Type[nn.Module], ...] = (nn.Conv2d, nn.Linear), ) -> Dict[str, float]: """ - 针对模型中 encoder、transformer backbone 以及 head 三个部分, - 分别统计各部分中所有目标层(例如 nn.Conv2d、nn.Linear、nn.MultiheadAttention 等)的 - dormant ratio(元素级 dormant 百分比),同时返回全局统计指标。 - + Overview: + Calculates the dormant ratio (percentage of neurons with activation below a threshold) for + different parts of a model (e.g., encoder, transformer, head). It assumes the model has + attributes like `encoder`, `transformer`, or `head_dict`. + Arguments: - model: 待评估模型,应包含属性 encoder、transformer(backbone)以及 head(可选)。 - inputs: 模型的输入,支持 tensor 或 tensor-list,要求与模型 forward 调用一致。 - dormant_threshold: 激活值低于该阈值时视为 dormant,默认 1e-2。 - + - model (:obj:`nn.Module`): The model to evaluate, expected to have `encoder`, `transformer`, or `head_dict` attributes. + - inputs (:obj:`Union[torch.Tensor, List[torch.Tensor]]`): The inputs for the model's forward pass. + - dormant_threshold (:obj:`float`): The activation threshold for defining a dormant neuron. Defaults to 1e-2. + - target_modules (:obj:`Tuple[Type[nn.Module], ...]`): A tuple of module types to attach hooks to. + Returns: - results: 包含各部分以及全局 dormant ratio 的字典,单位为百分比(%)。 - 如:{"encoder": 2.5, "transformer": 1.8, "head": 0.5, "global": 1.6} + - Dict[str, float]: A dictionary containing the dormant ratios for each model part and a global ratio. """ - - # 我们将统计分类为三个部分 parts = {} if hasattr(model, "encoder"): parts["encoder"] = model.encoder if hasattr(model, "transformer"): parts["transformer"] = model.transformer - - # 对于 head 部分,查找所有以 "head_" 开头的子模块 - # head_dict = {} - # for name, module in model.named_children(): - # if name.startswith("head_"): - # head_dict[name] = module - # if head_dict: - # parts["head"] = nn.ModuleDict(head_dict) - if hasattr(model, "head_dict"): parts["head"] = model.head_dict - if not hasattr(model, "encoder") and not hasattr(model, "transformer") and not hasattr(model, "head"): - # 如果传入的是self.tokenizer.encoder + # Fallback for models that don't have the standard part attributes. + if not parts: parts["model"] = model - # 定义要捕获的目标模块类型 TODO: 增加更多模块 - target_modules = (nn.Conv2d, nn.Linear) - - # 用于存储各部分的 hook(字典:部分名 -> list of (module_name, hook)) hooks_dict = {part: [] for part in parts} hook_handles = [] - # 为每个部分中的满足类型条件的模块注册 hook + # Register a forward hook for each target module in each part. for part_name, submodule in parts.items(): for name, module in submodule.named_modules(): if isinstance(module, target_modules): hook = IntermediateOutputHook() - # 为了避免名称冲突,加上所属部分前缀 full_name = f"{part_name}/{name}" hooks_dict[part_name].append((full_name, hook)) handle = module.register_forward_hook(hook) hook_handles.append(handle) - # 调用 forward,执行一次推理 model.eval() with torch.no_grad(): if isinstance(inputs, (list, tuple)): @@ -227,98 +210,110 @@ def cal_dormant_ratio( else: _ = model(inputs) - # 统计各部分各个模块的 dormant 数量和总数 results = {} total_global = 0 dormant_global = 0 + + # Calculate dormant stats from captured outputs. 
for part, hooks in hooks_dict.items(): part_total = 0 part_dormant = 0 for full_name, hook in hooks: layer_total, layer_dormant = compute_dormant_stats(hook.outputs, dormant_threshold) - # if part == "model": - # print(hook.outputs) - # 可打印日志,也可记录更详细信息 - # print(f"{full_name}: {layer_dormant}/{layer_total} -> {layer_dormant / layer_total * 100.0 if layer_total > 0 else 0.0}%") part_total += layer_total part_dormant += layer_dormant - if part_total > 0: - ratio = (part_dormant / part_total) * 100.0 - else: - ratio = 0.0 - results[part] = ratio + + results[part] = (part_dormant / part_total) * 100.0 if part_total > 0 else 0.0 total_global += part_total dormant_global += part_dormant results["global"] = (dormant_global / total_global) * 100.0 if total_global > 0 else 0.0 - # 清理所有 hook + # Clean up all hooks. for handle in hook_handles: handle.remove() for hooks in hooks_dict.values(): for _, hook in hooks: - hook.outputs.clear() + hook.clear() return results + def renormalize(inputs: torch.Tensor, first_dim: int = 1) -> torch.Tensor: """ Overview: - Normalize the input data using the max-min-normalization. + Normalizes the input tensor using min-max scaling. The normalization is applied + over all dimensions starting from `first_dim`. + Arguments: - - inputs (:obj:`torch.Tensor`): The input data needs to be normalized. - - first_dim (:obj:`int`): The first dimension of flattening the input data. + - inputs (:obj:`torch.Tensor`): The input tensor to be normalized. + - first_dim (:obj:`int`): The first dimension from which to flatten the tensor for normalization. + Returns: - - output (:obj:`torch.Tensor`): The normalized data. + - torch.Tensor: The min-max normalized tensor. """ if first_dim < 0: - first_dim = len(inputs.shape) + first_dim - flat_input = inputs.view(*inputs.shape[:first_dim], -1) - max_val = torch.max(flat_input, first_dim, keepdim=True).values - min_val = torch.min(flat_input, first_dim, keepdim=True).values - flat_input = (flat_input - min_val) / (max_val - min_val) - - return flat_input.view(*inputs.shape) - + first_dim = inputs.dim() + first_dim + + shape = inputs.shape + flat_input = inputs.view(*shape[:first_dim], -1) + + max_val, _ = torch.max(flat_input, dim=first_dim, keepdim=True) + min_val, _ = torch.min(flat_input, dim=first_dim, keepdim=True) + + # Add a small epsilon to avoid division by zero. + denominator = max_val - min_val + denominator[denominator < 1e-8] = 1e-8 + + normalized_flat = (flat_input - min_val) / denominator + + return normalized_flat.view(*shape) -def get_dynamic_mean(model: nn.Module) -> float: - dynamic_mean = np.abs(model.conv.weight.detach().cpu().numpy().reshape(-1)).tolist() - for block in model.resblocks: - for name, param in block.named_parameters(): - dynamic_mean += np.abs(param.detach().cpu().numpy().reshape(-1)).tolist() - dynamic_mean = sum(dynamic_mean) / len(dynamic_mean) - return dynamic_mean +def get_params_mean(model: nn.Module) -> float: + """ + Overview: + Calculates the mean of the absolute values of all parameters in a model. This is an alias + for `compute_average_weight_magnitude`. + Arguments: + - model (:obj:`nn.Module`): The model to be evaluated. -def get_reward_mean(model: nn.Module) -> Tuple[np.ndarray, float]: - reward_w_dist = model.conv1x1_reward.weight.detach().cpu().numpy().reshape(-1) + Returns: + - float: The mean of the absolute parameter values. 
+ """ + return compute_average_weight_magnitude(model) - for name, param in model.fc.named_parameters(): - temp_weights = param.detach().cpu().numpy().reshape(-1) - reward_w_dist = np.concatenate((reward_w_dist, temp_weights)) - reward_mean = np.abs(reward_w_dist).mean() - return reward_w_dist, reward_mean +def get_gradients(model: nn.Module) -> List[Optional[torch.Tensor]]: + """ + Overview: + Retrieves the gradients of all parameters in a model. -def get_params_mean(model: nn.Module) -> Tuple[np.ndarray, float, float, float]: - representation_mean = model.representation_network.get_param_mean() - dynamic_mean = model.dynamics_network.get_dynamic_mean() - reward_w_dist, reward_mean = model.dynamics_network.get_reward_mean() + Arguments: + - model (:obj:`nn.Module`): The model from which to get gradients. - return reward_w_dist, representation_mean, dynamic_mean, reward_mean + Returns: + - List[Optional[torch.Tensor]]: A list of gradient tensors. If a parameter has no gradient, + the corresponding list entry is None. + """ + return [p.grad.detach() if p.grad is not None else None for p in model.parameters()] -def get_gradients(model: nn.Module) -> List[torch.Tensor]: - grads = [] - for p in model.parameters(): - grad = None if p.grad is None else p.grad.detach() - grads.append(grad) - return grads +def set_gradients(model: nn.Module, gradients: List[Optional[torch.Tensor]]) -> None: + """ + Overview: + Sets the gradients for all parameters in a model. + Arguments: + - model (:obj:`nn.Module`): The model whose gradients are to be set. + - gradients (:obj:`List[Optional[torch.Tensor]]`): A list of gradients to assign to the model's parameters. + """ + params = list(model.parameters()) + if len(gradients) != len(params): + raise ValueError(f"Number of gradients ({len(gradients)}) does not match number of model parameters ({len(params)}).") -def set_gradients(model: nn.Module, gradients: List[torch.Tensor]) -> None: - # TODO due to the drawback of zip operation, we have to check whether gradients match model's parameters - for g, p in zip(gradients, model.parameters()): + for g, p in zip(gradients, params): if g is not None: - p.grad = g + # Ensure the gradient is on the same device as the parameter. + p.grad = g.to(p.device) \ No newline at end of file diff --git a/lzero/model/vit.py b/lzero/model/vit.py new file mode 100644 index 000000000..0bc5ebc04 --- /dev/null +++ b/lzero/model/vit.py @@ -0,0 +1,444 @@ +# -*- coding: utf-8 -*- +""" +Optimized Vision Transformer (ViT) Model. + +This script provides an optimized implementation of the Vision Transformer (ViT) architecture. +It includes improvements in code structure, clarity, and adherence to modern Python coding standards, +including comprehensive type hinting and documentation. The implementation also supports +integration with Low-Rank Adaptation (LoRA) through a flexible configuration system. + +Author: [Your Name/Team Name] +Date: [Current Date] +""" + +import torch +from torch import nn +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from lzero.model.common import SimNorm +from typing import Tuple, Union, Type, Optional + +# ==================== LoRA Integration Section Start ==================== + +# Attempt to import core components from a local transformer.py for LoRA support. +# This allows for flexible adaptation (e.g., LoRA) of linear layers. +try: + # Assuming transformer.py is in the same directory. Adjust the import path if necessary. 
+ from .transformer import _maybe_wrap_linear, TransformerConfig +except ImportError: + # If the import fails (e.g., when running this file directly), provide a fallback. + # This ensures the model remains functional without LoRA components. + print("Warning: LoRA components could not be imported. Using standard nn.Linear.") + _maybe_wrap_linear = lambda linear, config, label: linear + + # Define a placeholder class for TransformerConfig if it's not available. + class TransformerConfig: + """Placeholder for TransformerConfig when LoRA components are not available.""" + pass + +# ==================== LoRA Integration Section End ==================== + + +# ==================== Configuration Class ==================== + +class ViTConfig: + """ + Overview: + Configuration class for the Vision Transformer (ViT) model. + This class centralizes all hyperparameters, making the model easier to configure and manage. + """ + def __init__(self, **kwargs): + """ + Overview: + Initializes the ViTConfig object. + Arguments: + - **kwargs: Arbitrary keyword arguments to override default settings. + """ + # Image and Patch Dimensions + self.image_size: Union[int, Tuple[int, int]] = 64 + self.patch_size: Union[int, Tuple[int, int]] = 8 + self.channels: int = 3 + + # Model Architecture + self.num_classes: int = 768 + self.dim: int = 768 + self.depth: int = 12 + self.heads: int = 12 + self.mlp_dim: int = 3072 + self.dim_head: int = 64 + + # Pooling and Normalization + self.pool: str = 'cls' # 'cls' or 'mean' + self.final_norm_option_in_encoder: str = 'LayerNorm' # 'LayerNorm' or 'SimNorm' + + # Dropout Rates + self.dropout: float = 0.1 + self.emb_dropout: float = 0.1 + + # LoRA Configuration + self.lora_config: Optional[TransformerConfig] = None + + # Update attributes with any provided keyword arguments + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + print(f"Warning: Ignoring unknown config parameter '{key}'") + + +# ==================== Helper Functions ==================== + +def pair(t: Union[int, Tuple[int, int]]) -> Tuple[int, int]: + """ + Overview: + Converts an integer to a tuple of two identical integers. If the input is already a tuple, it is returned as is. + This is useful for handling kernel sizes, strides, etc., which can be specified as a single number or a tuple. + Arguments: + - t (:obj:`Union[int, Tuple[int, int]]`): The input value. + Returns: + - (:obj:`Tuple[int, int]`): A tuple of two integers. + """ + return t if isinstance(t, tuple) else (t, t) + + +# ==================== Core Modules ==================== + +class FeedForward(nn.Module): + """ + Overview: + A standard feed-forward network block used in Transformer architectures. + It consists of two linear layers with a GELU activation in between. + """ + def __init__( + self, + dim: int, + hidden_dim: int, + dropout: float = 0.0, + config: Optional[TransformerConfig] = None + ): + """ + Overview: + Initializes the FeedForward module. + Arguments: + - dim (:obj:`int`): The input and output dimension. + - hidden_dim (:obj:`int`): The dimension of the hidden layer. + - dropout (:obj:`float`): The dropout rate. + - config (:obj:`Optional[TransformerConfig]`): Configuration for LoRA wrapping. 
+ """ + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(dim), + _maybe_wrap_linear(nn.Linear(dim, hidden_dim), config, "feed_forward"), + nn.GELU(), + nn.Dropout(dropout), + _maybe_wrap_linear(nn.Linear(hidden_dim, dim), config, "feed_forward"), + nn.Dropout(dropout) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass for the FeedForward block. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor of shape (batch_size, num_tokens, dim). + Returns: + - (:obj:`torch.Tensor`): The output tensor of the same shape as input. + """ + return self.net(x) + + +class Attention(nn.Module): + """ + Overview: + Multi-Head Self-Attention (MHSA) module. + It computes scaled dot-product attention across multiple heads. + """ + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + config: Optional[TransformerConfig] = None + ): + """ + Overview: + Initializes the Attention module. + Arguments: + - dim (:obj:`int`): The input and output dimension. + - heads (:obj:`int`): The number of attention heads. + - dim_head (:obj:`int`): The dimension of each attention head. + - dropout (:obj:`float`): The dropout rate for attention weights and output. + - config (:obj:`Optional[TransformerConfig]`): Configuration for LoRA wrapping. + """ + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.norm = nn.LayerNorm(dim) + self.attend = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + # Linear layer to project input to Q, K, V. Potentially wrapped for LoRA. + self.to_qkv = _maybe_wrap_linear(nn.Linear(dim, inner_dim * 3, bias=False), config, "attn") + + # Output projection layer. + if project_out: + # Wrap the linear layer inside the sequential module for LoRA. + wrapped_linear = _maybe_wrap_linear(nn.Linear(inner_dim, dim), config, "attn") + self.to_out = nn.Sequential( + wrapped_linear, + nn.Dropout(dropout) + ) + else: + self.to_out = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass for the Attention module. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor of shape (batch_size, num_tokens, dim). + Returns: + - (:obj:`torch.Tensor`): Output tensor of the same shape as input. + """ + x = self.norm(x) + + # Project to Q, K, V and split. + qkv = self.to_qkv(x).chunk(3, dim=-1) + # Rearrange for multi-head attention: b n (h d) -> b h n d + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) + + # Scaled dot-product attention. + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + attn = self.attend(dots) + attn = self.dropout(attn) + + # Apply attention to values. + out = torch.matmul(attn, v) + # Rearrange back to original shape: b h n d -> b n (h d) + out = rearrange(out, 'b h n d -> b n (h d)') + + return self.to_out(out) + + +class Transformer(nn.Module): + """ + Overview: + A stack of Transformer blocks, each containing a multi-head self-attention + layer and a feed-forward network. + """ + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + mlp_dim: int, + dropout: float = 0.0, + config: Optional[TransformerConfig] = None + ): + """ + Overview: + Initializes the Transformer module. + Arguments: + - dim (:obj:`int`): The dimension of the token embeddings. + - depth (:obj:`int`): The number of Transformer blocks. + - heads (:obj:`int`): The number of attention heads. 
+ - dim_head (:obj:`int`): The dimension of each attention head. + - mlp_dim (:obj:`int`): The hidden dimension of the feed-forward network. + - dropout (:obj:`float`): The dropout rate. + - config (:obj:`Optional[TransformerConfig]`): Configuration for LoRA. + """ + super().__init__() + self.norm = nn.LayerNorm(dim) + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, config=config), + FeedForward(dim, mlp_dim, dropout=dropout, config=config) + ])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass for the Transformer stack. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor of shape (batch_size, num_tokens, dim). + Returns: + - (:obj:`torch.Tensor`): Output tensor of the same shape. + """ + for attn, ff in self.layers: + x = attn(x) + x # Apply attention and residual connection + x = ff(x) + x # Apply feed-forward and residual connection + return self.norm(x) + + +class ViT(nn.Module): + """ + Overview: + Vision Transformer (ViT) model. This model applies the Transformer architecture + to sequences of image patches for image classification tasks. + """ + def __init__(self, config: ViTConfig): + """ + Overview: + Initializes the ViT model using a configuration object. + Arguments: + - config (:obj:`ViTConfig`): A configuration object containing all model hyperparameters. + """ + super().__init__() + self.config = config + + image_height, image_width = pair(config.image_size) + patch_height, patch_width = pair(config.patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, \ + 'Image dimensions must be divisible by the patch size.' + + num_patches = (image_height // patch_height) * (image_width // patch_width) + patch_dim = config.channels * patch_height * patch_width + assert config.pool in {'cls', 'mean'}, 'pool type must be either "cls" or "mean"' + + # Patch embedding layer + self.to_patch_embedding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width), + nn.LayerNorm(patch_dim), + nn.Linear(patch_dim, config.dim), + nn.LayerNorm(config.dim), + ) + + # Positional embedding and CLS token + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, config.dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, config.dim)) + self.dropout = nn.Dropout(config.emb_dropout) + + # Transformer encoder stack + self.transformer = Transformer( + dim=config.dim, + depth=config.depth, + heads=config.heads, + dim_head=config.dim_head, + mlp_dim=config.mlp_dim, + dropout=config.dropout, + config=config.lora_config + ) + + self.pool = config.pool + self.last_linear = nn.Linear(config.dim, config.num_classes) + + # Final normalization layer + if config.final_norm_option_in_encoder == 'LayerNorm': + self.final_norm = nn.LayerNorm(config.num_classes, eps=1e-5) + elif config.final_norm_option_in_encoder == 'SimNorm': + group_size = 8 # As specified in original code + self.final_norm = SimNorm(simnorm_dim=group_size) + else: + raise ValueError(f"Unsupported final_norm_option_in_encoder: {config.final_norm_option_in_encoder}") + + def forward(self, img: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass for the ViT model. + Arguments: + - img (:obj:`torch.Tensor`): Input image tensor of shape (batch_size, channels, height, width). + Returns: + - (:obj:`torch.Tensor`): Output logits tensor of shape (batch_size, num_classes). + """ + # 1. 
Patch embedding + x = self.to_patch_embedding(img) + b, n, _ = x.shape + + # 2. Prepend CLS token + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b) + x = torch.cat((cls_tokens, x), dim=1) + + # 3. Add positional embedding + x += self.pos_embedding[:, :(n + 1)] + x = self.dropout(x) + + # 4. Pass through Transformer encoder + x = self.transformer(x) + + # 5. Pooling + x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0] + + # 6. Final classification head + x = self.last_linear(x) + x = self.final_norm(x) + + return x + + +# ==================== Test and Benchmark Code ==================== +if __name__ == "__main__": + import random + import time + + # Fix random seeds for reproducibility + torch.manual_seed(42) + random.seed(42) + + # 1. Create a configuration object + # This is now the standard way to configure the model. + vit_config = ViTConfig( + image_size=64, + patch_size=8, + num_classes=768, + dim=768, + depth=12, + heads=12, + mlp_dim=3072, + dropout=0.1, + emb_dropout=0.1, + final_norm_option_in_encoder="LayerNorm" + ) + + # 2. Instantiate the model with the config + model = ViT(config=vit_config) + + # Move model to GPU if available + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() # Set model to evaluation mode for inference + + # Create a dummy input tensor + dummy_input = torch.randn(256, 3, 64, 64).to(device) + + # Perform a single forward pass + with torch.no_grad(): + out = model(dummy_input) + + print(f"Device: {device}") + print(f"Output shape: {out.shape}") + print(f"Output[0] (first 50 values): {out[0][:50]}") + + # 3. Simple Benchmark + print("\nStarting benchmark...") + warmup_reps, bench_reps = 5, 20 + + with torch.no_grad(): + # Warm-up runs + for _ in range(warmup_reps): + _ = model(dummy_input) + + # Synchronize before timing (for CUDA) + if torch.cuda.is_available(): + torch.cuda.synchronize() + + start_time = time.time() + for _ in range(bench_reps): + _ = model(dummy_input) + + # Synchronize after timing + if torch.cuda.is_available(): + torch.cuda.synchronize() + + end_time = time.time() + + total_time = end_time - start_time + avg_latency_ms = (total_time / bench_reps) * 1000 + print(f"Average latency over {bench_reps} runs: {avg_latency_ms:.2f} ms") \ No newline at end of file diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index e8a8250bd..c0fb1536a 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -45,9 +45,10 @@ class EfficientZeroPolicy(MuZeroPolicy): image_channel=1, # (int) The number of frames to stack together. frame_stack_num=1, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=300, + # (tuple) The range of supports used in categorical distribution. + # These variables are only effective when ``model.categorical_distribution=True``. + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), # (int) The hidden size in LSTM. lstm_hidden_size=512, # (bool) whether to learn bias in the last linear layer in value and policy head. 
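Editor's note: the `support_scale` to `reward_support_range` / `value_support_range` migration replaces a single half-width integer with an explicit `(start, stop, step)` tuple. Assuming arange-like semantics for `DiscreteSupport` (consistent with the 601-logit unit test shown earlier in this patch), the new default `(-300., 301., 1.)` reproduces the old `support_scale=300` support, and the board-game configs' `support_scale=10` becomes `(-10., 11., 1.)`:

```python
import numpy as np

# The new config style specifies supports as (start, stop, step) ranges.
reward_support_range = (-300., 301., 1.)
value_support_range = (-300., 301., 1.)

# A support built from such a range behaves like np.arange(start, stop, step).
value_bins = np.arange(*value_support_range)
print(value_bins.size)                 # 601 bins: -300, -299, ..., 0, ..., 300
print(value_bins[0], value_bins[-1])   # -300.0 300.0
```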
@@ -275,12 +276,12 @@ def _init_learn(self) -> None: self._cfg.augmentation, image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + assert self.value_support.size == self._learn_model.value_support_size # if these assertions fails, somebody introduced... + assert self.reward_support.size == self._learn_model.reward_support_size # ...incoherence between policy and model + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: """ @@ -342,7 +343,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # transform the scaled value or its categorical representation to its original value, # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) + original_value = self.value_inverse_scalar_transform_handle(value) # Note: The following lines are just for debugging. predicted_value_prefixs = [] @@ -398,7 +399,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # transform the scaled value or its categorical representation to its original value, # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) + original_value = self.value_inverse_scalar_transform_handle(value) # ============================================================== # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. 
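Editor's note: the `h^(-1)(.)` referenced in the comments above is the inverse of the invertible value-scaling transform from Pohlen et al. (https://arxiv.org/pdf/1805.11593.pdf). The sketch below covers only the scalar case; LightZero's `scalar_transform` / `InverseScalarTransform` additionally handle the categorical support, and the helper names here are illustrative.

```python
import torch

EPS = 0.001  # epsilon of the invertible transform h(.)

def h(x: torch.Tensor) -> torch.Tensor:
    # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + EPS * x

def h_inv(y: torch.Tensor) -> torch.Tensor:
    # h^-1(y) = sign(y) * (((sqrt(1 + 4*eps*(|y| + 1 + eps)) - 1) / (2*eps))**2 - 1)
    return torch.sign(y) * (
        ((torch.sqrt(1 + 4 * EPS * (torch.abs(y) + 1 + EPS)) - 1) / (2 * EPS)) ** 2 - 1
    )

x = torch.tensor([-150.0, -1.0, 0.0, 1.0, 150.0], dtype=torch.float64)
print(h(x))                                        # compressed values
assert torch.allclose(h_inv(h(x)), x, atol=1e-6)   # round-trips back to the originals
```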
@@ -456,10 +457,10 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: ) if self._cfg.monitor_extra_statistics: - original_value_prefixs = self.inverse_scalar_transform_handle(value_prefix) + original_value_prefixs = self.value_inverse_scalar_transform_handle(value_prefix) original_value_prefixs_cpu = original_value_prefixs.detach().cpu() predicted_values = torch.cat( - (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + (predicted_values, self.value_inverse_scalar_transform_handle(value).detach().cpu()) ) predicted_value_prefixs.append(original_value_prefixs_cpu) predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) @@ -583,7 +584,7 @@ def _forward_collect( network_output ) - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() reward_hidden_state_roots = ( reward_hidden_state_roots[0].detach().cpu().numpy(), @@ -702,7 +703,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: Union[in if not self._eval_model.training: # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) latent_state_roots = latent_state_roots.detach().cpu().numpy() reward_hidden_state_roots = ( reward_hidden_state_roots[0].detach().cpu().numpy(), diff --git a/lzero/policy/gumbel_muzero.py b/lzero/policy/gumbel_muzero.py index 65e0fa7d2..e44464038 100644 --- a/lzero/policy/gumbel_muzero.py +++ b/lzero/policy/gumbel_muzero.py @@ -48,9 +48,10 @@ class GumbelMuZeroPolicy(MuZeroPolicy): num_res_blocks=1, # (int) The number of channels of hidden states in MuZero model. num_channels=64, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=300, + # (tuple) The range of supports used in categorical distribution. + # These variables are only effective when ``categorical_distribution=True``. + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), # (bool) whether to learn bias in the last linear layer in value and policy head. bias=True, # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. @@ -262,11 +263,13 @@ def _init_learn(self) -> None: self._cfg.augmentation, image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + assert self.value_support.size == self._learn_model.value_support_size # if these assertions fails, somebody introduced... 
+ assert self.reward_support.size == self._learn_model.reward_support_size # ...incoherence between policy and model + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) + self.kl_loss = KLDivLoss(reduction='none') def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: @@ -333,7 +336,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # transform the scaled value or its categorical representation to its original value, # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) + original_value = self.value_inverse_scalar_transform_handle(value) # Note: The following lines are just for debugging. predicted_rewards = [] @@ -378,7 +381,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # transform the scaled value or its categorical representation to its original value, # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) + original_value = self.value_inverse_scalar_transform_handle(value) if self._cfg.model.self_supervised_learning_loss: # ============================================================== @@ -414,11 +417,11 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: policy_entropy += (prob * prob.log()).sum(-1) if self._cfg.monitor_extra_statistics: - original_rewards = self.inverse_scalar_transform_handle(reward) + original_rewards = self.reward_inverse_scalar_transform_handle(reward) original_rewards_cpu = original_rewards.detach().cpu() predicted_values = torch.cat( - (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + (predicted_values, self.value_inverse_scalar_transform_handle(value).detach().cpu()) ) predicted_rewards.append(original_rewards_cpu) predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) @@ -539,7 +542,7 @@ def _forward_collect( network_output = self._collect_model.initial_inference(data) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() @@ -648,7 +651,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ if not self._eval_model.training: # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 55bf0dc6c..da69fbd80 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -15,7 +15,7 @@ from lzero.mcts import MuZeroMCTSCtree as MCTSCtree from lzero.mcts import MuZeroMCTSPtree as MCTSPtree from 
lzero.model import ImageTransforms -from lzero.model.utils import cal_dormant_ratio +from lzero.model.utils import calculate_dormant_ratio from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ prepare_obs, configure_optimizers @@ -57,9 +57,10 @@ class MuZeroPolicy(Policy): num_res_blocks=1, # (int) The number of channels of hidden states in MuZero model. num_channels=64, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=300, + # (tuple) The range of supports used in categorical distribution. + # These variables are only effective when ``model.categorical_distribution=True``. + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), # (bool) whether to learn bias in the last linear layer in value and policy head. bias=True, # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. @@ -112,7 +113,7 @@ class MuZeroPolicy(Policy): # This is done by setting the parameter learn.learner.hook.save_ckpt_after_iter to the same value as eval_freq in the train_muzero.py automatically. eval_offline=False, # (bool) Whether to calculate the dormant ratio. - cal_dormant_ratio=False, + calculate_dormant_ratio=False, # (bool) Whether to analyze simulation normalization. analysis_sim_norm=False, # (bool) Whether to analyze dormant ratio. @@ -312,11 +313,12 @@ def _init_learn(self) -> None: self._cfg.augmentation, image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + assert self.value_support.size == self._learn_model.value_support_size # if these assertions fail, somebody introduced...
+ assert self.reward_support.size == self._learn_model.reward_support_size # ...incoherence between policy and model + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) # ============================================================== # harmonydream (learnable weights for different losses) @@ -421,8 +423,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # ========= logging for analysis ========= # calculate dormant ratio of encoder - if self._cfg.cal_dormant_ratio: - self.dormant_ratio_encoder = cal_dormant_ratio(self._learn_model.representation_network, obs_batch.detach(), + if self._cfg.calculate_dormant_ratio: + self.dormant_ratio_encoder = calculate_dormant_ratio(self._learn_model.representation_network, obs_batch.detach(), percentage=self._cfg.dormant_threshold) # calculate L2 norm of latent state latent_state_l2_norms = torch.norm(latent_state.view(latent_state.shape[0], -1), p=2, dim=1).mean() @@ -430,7 +432,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # transform the scaled value or its categorical representation to its original value, # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) + original_value = self.value_inverse_scalar_transform_handle(value) # Note: The following lines are just for debugging. predicted_rewards = [] @@ -468,7 +470,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) # ========= logging for analysis =============== - if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.cal_dormant_ratio: + if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.calculate_dormant_ratio: # calculate dormant ratio of encoder action_tmp = action_batch[:, step_k] if len(action_tmp.shape) == 1: @@ -484,14 +486,14 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in latent_state.shape[0], policy_logits.shape[-1], latent_state.shape[2], latent_state.shape[3] ) state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) - self.dormant_ratio_dynamics = cal_dormant_ratio(self._learn_model.dynamics_network, + self.dormant_ratio_dynamics = calculate_dormant_ratio(self._learn_model.dynamics_network, state_action_encoding.detach(), percentage=self._cfg.dormant_threshold) # ========= logging for analysis =============== # transform the scaled value or its categorical representation to its original value, # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) + original_value = self.value_inverse_scalar_transform_handle(value) if self._cfg.model.self_supervised_learning_loss: # ============================================================== @@ -511,8 +513,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] consistency_loss += temp_loss - # NOTE: the target policy, target_value_categorical, target_reward_categorical is calculated in - # game buffer now. + # NOTE: the target policy is calculated in game buffer now. 
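For orientation, the new support and inverse-transform wiring introduced in the `_init_learn` hunk above can be exercised on its own roughly as follows. This is a minimal sketch that assumes the signatures used in this patch, namely `DiscreteSupport(start, stop, step, device)` and `InverseScalarTransform(support, categorical_distribution)`; the `'cpu'` device string, the batch size, and the bin interpretation are illustrative, not prescribed by the patch.

```Python
import torch

from lzero.policy import DiscreteSupport, InverseScalarTransform

# Build separate value/reward supports from (start, stop, step) ranges, as _init_learn now does.
value_support = DiscreteSupport(-300., 301., 1., 'cpu')    # presumably one bin per integer in [-300, 300]
reward_support = DiscreteSupport(-300., 301., 1., 'cpu')

# One inverse-transform handle per support, replacing the previously shared handle.
value_handle = InverseScalarTransform(value_support, True)     # second argument: categorical_distribution
reward_handle = InverseScalarTransform(reward_support, True)

# Categorical value logits of shape (B, value_support.size) are mapped back to scalars of shape (B, 1).
value_logits = torch.randn(8, value_support.size)
scalar_values = value_handle(value_logits)
```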
# ============================================================== # calculate policy loss for the next ``num_unroll_steps`` unroll steps. # NOTE: the +=. @@ -543,11 +544,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) if self._cfg.monitor_extra_statistics: - original_rewards = self.inverse_scalar_transform_handle(reward) + original_rewards = self.reward_inverse_scalar_transform_handle(reward) original_rewards_cpu = original_rewards.detach().cpu() predicted_values = torch.cat( - (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + (predicted_values, self.value_inverse_scalar_transform_handle(value).detach().cpu()) ) predicted_rewards.append(original_rewards_cpu) predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) @@ -738,7 +739,7 @@ def _forward_collect( latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() @@ -891,7 +892,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ if not self._eval_model.training: # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) diff --git a/lzero/policy/muzero_multitask.py b/lzero/policy/muzero_multitask.py index c65ccc5e8..45addaf59 100644 --- a/lzero/policy/muzero_multitask.py +++ b/lzero/policy/muzero_multitask.py @@ -26,24 +26,33 @@ from lzero.policy.muzero import MuZeroPolicy -def generate_task_loss_dict(multi_task_losses, task_name_template, task_id): +def generate_task_loss_dict(multi_task_losses: List[float], task_name_template: str, task_id: int) -> Dict[str, float]: """ - 生成每个任务的损失字典 - :param multi_task_losses: 包含每个任务损失的列表 - :param task_name_template: 任务名称模板,例如 'loss_task{}' - :param task_id: 任务起始ID - :return: 一个字典,包含每个任务的损失 + Overview: + Generates a dictionary for the losses of each task. + Arguments: + - multi_task_losses (:obj:`List[float]`): A list containing the loss for each task. + - task_name_template (:obj:`str`): A template for the task name, e.g., 'loss_task{}'. + - task_id (:obj:`int`): The starting ID for the tasks. + Returns: + - task_loss_dict (:obj:`Dict[str, float]`): A dictionary containing the loss for each task. """ task_loss_dict = {} for task_idx, task_loss in enumerate(multi_task_losses): task_name = task_name_template.format(task_idx + task_id) try: + # Ensure the loss is a scalar value for logging. task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss except Exception: task_loss_dict[task_name] = task_loss return task_loss_dict class WrappedModelV2: + """ + Overview: + A wrapper class to bundle different parts of a model (tokenizer, transformer, embeddings) + for easier management of parameters and gradients. 
+ """ def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_table): self.tokenizer = tokenizer self.transformer = transformer @@ -51,8 +60,11 @@ def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_tabl self.task_emb = task_emb self.act_embedding_table = act_embedding_table - def parameters(self): - # 返回 tokenizer, transformer 以及所有嵌入层的参数 + def parameters(self) -> List[torch.nn.Parameter]: + """ + Overview: + Returns a list of all parameters from the tokenizer, transformer, and all embedding layers. + """ return ( list(self.tokenizer.parameters()) + list(self.transformer.parameters()) + @@ -61,8 +73,13 @@ def parameters(self): list(self.act_embedding_table.parameters()) ) - def zero_grad(self, set_to_none=False): - # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all parameters in the tokenizer, transformer, and embedding layers to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. + """ self.tokenizer.zero_grad(set_to_none=set_to_none) self.transformer.zero_grad(set_to_none=set_to_none) self.pos_emb.zero_grad(set_to_none=set_to_none) @@ -72,11 +89,12 @@ def zero_grad(self, set_to_none=False): @POLICY_REGISTRY.register('muzero_multitask') class MuZeroMTPolicy(MuZeroPolicy): """ - 概述: - MuZero 的多任务策略类,扩展自 MuZeroPolicy。支持同时训练多个任务,通过分离每个任务的损失并进行优化。 + Overview: + The multi-task policy for MuZero, extending MuZeroPolicy. It supports training multiple tasks + simultaneously by separating the loss for each task and optimizing them jointly. """ - # MuZeroMTPolicy 的默认配置 + # Default configuration for MuZeroMTPolicy. config = dict( type='muzero_multitask', model=dict( @@ -175,29 +193,29 @@ class MuZeroMTPolicy(MuZeroPolicy): decay=int(1e5), ), - # ****** 多任务相关 ****** - task_num=2, # 任务数量,根据实际需求调整 - task_id=0, # 当前任务的起始ID + # ****** Multi-task related ****** + task_num=2, # Number of tasks, adjust as needed. + task_id=0, # The starting ID of the current task. ) def default_model(self) -> Tuple[str, List[str]]: """ - 概述: - 返回该算法的默认模型设置。 - 返回: - - model_info (:obj:`Tuple[str, List[str]]`): 模型名称和模型导入路径列表。 + Overview: + Returns the default model configuration for this algorithm. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): A tuple containing the model name and a list of import paths. """ return 'MuZeroMTModel', ['lzero.model.muzero_model_multitask'] def _init_learn(self) -> None: """ - 概述: - 学习模式初始化方法。初始化学习模型、优化器和MCTS工具。 + Overview: + Initializes the learning mode. This method sets up the learning model, optimizer, and MCTS utilities. """ super()._init_learn() assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type - # NOTE: in board_games, for fixed lr 0.003, 'Adam' is better than 'SGD'. + # NOTE: In board games, for a fixed learning rate of 0.003, 'Adam' performs better than 'SGD'. if self._cfg.optim_type == 'SGD': self._optimizer = optim.SGD( self._model.parameters(), @@ -213,14 +231,15 @@ def _init_learn(self) -> None: self._optimizer = configure_optimizers(model=self._model, weight_decay=self._cfg.weight_decay, learning_rate=self._cfg.learning_rate, device_type=self._cfg.device) + # Learning rate scheduler if self._cfg.lr_piecewise_constant_decay: from torch.optim.lr_scheduler import LambdaLR max_step = self._cfg.threshold_training_steps_for_final_lr - # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. + # NOTE: 1, 0.1, 0.01 are decay rates, not the learning rate itself. 
lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) - # use model_wrapper for specialized demands of different modes + # Use model_wrapper for specialized demands of different modes. self._target_model = copy.deepcopy(self._model) self._target_model = model_wrap( self._target_model, @@ -230,11 +249,14 @@ def _init_learn(self) -> None: ) self._learn_model = self._model + # Image augmentation if self._cfg.use_augmentation: self.image_transforms = ImageTransforms( self._cfg.augmentation, image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) ) + + # Support for categorical distribution self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) self.inverse_scalar_transform_handle = InverseScalarTransform( @@ -242,16 +264,17 @@ def _init_learn(self) -> None: ) # ============================================================== - # harmonydream (learnable weights for different losses) + # HarmonyDream (learnable weights for different losses) # ============================================================== if self._cfg.model.harmony_balance: - # List of parameter names + # List of parameter names. harmony_names = ["harmony_dynamics", "harmony_policy", "harmony_value", "harmony_reward", "harmony_entropy"] - # Initialize and name each parameter + # Initialize and name each parameter. for name in harmony_names: param = torch.nn.Parameter(-torch.log(torch.tensor(1.0))) setattr(self, name, param) - + + # RND model for intrinsic reward if self._cfg.use_rnd_model: if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign': self._target_model_for_intrinsic_reward = model_wrap( @@ -268,31 +291,35 @@ def _init_learn(self) -> None: update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward} ) - # ========= logging for analysis ========= + # ========= Logging for analysis ========= self.l2_norm_before = 0. self.l2_norm_after = 0. self.grad_norm_before = 0. self.grad_norm_after = 0. self.dormant_ratio_encoder = 0. self.dormant_ratio_dynamics = 0. - # 初始化多任务相关参数 + + # Initialize multi-task related parameters. self.task_num_for_current_rank = self._cfg.task_num self.task_id = self._cfg.task_id def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> Dict[str, Union[float, int]]: """ - 概述: - 学习模式的前向函数,是学习过程的核心。数据从重放缓冲区采样,计算损失并反向传播更新模型。 - 参数: - - data (:obj:`List[Tuple[torch.Tensor, torch.Tensor, int]]`): 每个任务的数据元组列表, - 每个元组包含 (current_batch, target_batch, task_id)。 - 返回: - - info_dict (:obj:`Dict[str, Union[float, int]]`): 用于记录的信息字典,包含当前学习损失和学习统计信息。 + Overview: + The forward function for learning, which is the core of the learning process. + Data is sampled from the replay buffer, and the loss is calculated and backpropagated + to update the model. + Arguments: + - data (:obj:`List[Tuple[torch.Tensor, torch.Tensor, int]]`): A list of data tuples for each task, + where each tuple contains (current_batch, target_batch, task_id). + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): A dictionary of information for logging, + including the current learning loss and other learning statistics. """ self._learn_model.train() self._target_model.train() - # 初始化多任务损失列表 + # Initialize lists for multi-task losses. 
reward_loss_multi_task = [] policy_loss_multi_task = [] value_loss_multi_task = [] @@ -302,8 +329,8 @@ def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> value_priority_multi_task = [] value_priority_mean_multi_task = [] - weighted_total_loss = 0.0 # 初始化为0 - losses_list = [] # 用于存储每个任务的损失 + weighted_total_loss = 0.0 # Initialize to zero. + losses_list = [] # To store the loss for each task. for task_idx, (current_batch, target_batch, task_id) in enumerate(data): obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch @@ -311,13 +338,13 @@ def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) - # 数据增强 + # Data augmentation. if self._cfg.use_augmentation: obs_batch = self.image_transforms.transform(obs_batch) if self._cfg.model.self_supervised_learning_loss: obs_target_batch = self.image_transforms.transform(obs_target_batch) - # 准备动作批次并转换为张量 + # Prepare action batch and convert to tensor. action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() data_list = [mask_batch, target_reward, target_value, target_policy, weights] mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor( @@ -329,20 +356,20 @@ def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> assert obs_batch.size(0) == self._cfg.batch_size[task_idx] == target_reward.size(0) - # 变换奖励和价值到缩放形式 + # Transform rewards and values to scaled representation. transformed_target_reward = scalar_transform(target_reward) transformed_target_value = scalar_transform(target_value) - # 转换为类别分布 + # Convert to categorical distribution. target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) target_value_categorical = phi_transform(self.value_support, transformed_target_value) - # 初始推理 + # Initial inference. network_output = self._learn_model.initial_inference(obs_batch, task_id=task_id) latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) - # 记录 Dormant Ratio 和 L2 Norm + # Log Dormant Ratio and L2 Norm. if self._cfg.cal_dormant_ratio: self.dormant_ratio_encoder = cal_dormant_ratio( self._learn_model.representation_network, obs_batch.detach(), @@ -350,21 +377,21 @@ def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> ) latent_state_l2_norms = torch.norm(latent_state.view(latent_state.shape[0], -1), p=2, dim=1).mean() - # 逆变换价值 + # Inverse transform value. original_value = self.inverse_scalar_transform_handle(value) - # 初始化预测值和策略 + # Initialize predicted values and policies. predicted_rewards = [] if self._cfg.monitor_extra_statistics: predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( policy_logits, dim=1 ).detach().cpu() - # 计算优先级 + # Calculate priority. value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) value_priority = value_priority.data.cpu().numpy() + 1e-6 - # 计算第一个步骤的策略和价值损失 + # Calculate policy and value loss for the first step. policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0]) value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) @@ -376,18 +403,18 @@ def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> consistency_loss = torch.zeros(self._cfg.batch_size[task_idx], device=self._cfg.device) target_policy_entropy = 0 - # 循环进行多个unroll步骤 + # Unroll loop for multiple steps. 
for step_k in range(self._cfg.num_unroll_steps): - # 使用动态函数进行递归推理 + # Recurrent inference using the dynamics function. network_output = self._learn_model.recurrent_inference(latent_state, action_batch[:, step_k]) latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) - # 记录 Dormant Ratio + # Log Dormant Ratio for the dynamics network. if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.cal_dormant_ratio: action_tmp = action_batch[:, step_k] if len(action_tmp.shape) == 1: action_tmp = action_tmp.unsqueeze(-1) - # 转换动作为独热编码 + # Convert action to one-hot encoding. action_one_hot = torch.zeros(action_tmp.shape[0], policy_logits.shape[-1], device=action_tmp.device) action_tmp = action_tmp.long() action_one_hot.scatter_(1, action_tmp, 1) @@ -402,10 +429,10 @@ def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> percentage=self._cfg.dormant_threshold ) - # 逆变换价值 + # Inverse transform value. original_value = self.inverse_scalar_transform_handle(value) - # 计算一致性损失 + # Calculate consistency loss (self-supervised learning). if self._cfg.model.self_supervised_learning_loss and self._cfg.ssl_loss_weight > 0: beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index], task_id=task_id) @@ -418,17 +445,17 @@ def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] consistency_loss += temp_loss - # 计算策略和价值损失 + # Calculate policy and value losses. policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1]) value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) - # 计算策略熵损失 + # Calculate policy entropy loss. prob = torch.softmax(policy_logits, dim=-1) entropy = -(prob * torch.log(prob + 1e-9)).sum(-1) policy_entropy_loss += -entropy - # 计算目标策略熵(仅用于调试) + # Calculate target policy entropy (for debugging purposes only). target_normalized_visit_count = target_policy[:, step_k + 1] non_masked_indices = torch.nonzero(mask_batch[:, step_k + 1]).squeeze(-1) if len(non_masked_indices) > 0: @@ -444,8 +471,7 @@ def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> torch.tensor(target_normalized_visit_count.shape[-1], device=self._cfg.device) ) - - # 记录预测值和奖励(如果监控额外统计) + # Log predicted values and rewards if monitoring extra statistics. if self._cfg.monitor_extra_statistics: original_rewards = self.inverse_scalar_transform_handle(reward) original_rewards_cpu = original_rewards.detach().cpu() @@ -458,52 +484,53 @@ def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> (predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu()) ) - # 核心学习模型更新步骤 + # Core learning model update step. weighted_loss = self._cfg.policy_loss_weight * policy_loss + \ self._cfg.value_loss_weight * value_loss + \ self._cfg.reward_loss_weight * reward_loss + \ self._cfg.ssl_loss_weight * consistency_loss + \ self._cfg.policy_entropy_weight * policy_entropy_loss - # 将多个任务的损失累加 + # Accumulate losses from multiple tasks. weighted_total_loss += weighted_loss.mean() - # 保留每个任务的损失用于日志记录 + # Store per-task losses for logging. 
reward_loss_multi_task.append(reward_loss.mean().item()) policy_loss_multi_task.append(policy_loss.mean().item()) value_loss_multi_task.append(value_loss.mean().item()) consistency_loss_multi_task.append(consistency_loss.mean().item()) policy_entropy_multi_task.append(policy_entropy_loss.mean().item()) - lambd_multi_task.append(torch.tensor(0., device=self._cfg.device).item()) # TODO: 如果使用梯度校正,可以在这里调整 + # TODO: Adjust if using gradient correction. + lambd_multi_task.append(torch.tensor(0., device=self._cfg.device).item()) value_priority_multi_task.append(value_priority.mean().item()) value_priority_mean_multi_task.append(value_priority.mean().item()) losses_list.append(weighted_loss.mean().item()) - # 清零优化器的梯度 + # Zero the optimizer's gradients. self._optimizer.zero_grad() - # 反向传播 + # Backward pass. weighted_total_loss.backward() - # 梯度裁剪 + # Gradient clipping. total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( self._learn_model.parameters(), self._cfg.grad_clip_value ) - # 多GPU训练时同步梯度 + # Sync gradients for multi-GPU training. if self._cfg.multi_gpu: self.sync_gradients(self._learn_model) - # 更新优化器 + # Update optimizer. self._optimizer.step() if self._cfg.lr_piecewise_constant_decay: self.lr_scheduler.step() - # 更新目标模型 + # Update target model. self._target_model.update(self._learn_model.state_dict()) - # 获取GPU内存使用情况 + # Get GPU memory usage. if torch.cuda.is_available(): torch.cuda.synchronize() current_memory_allocated = torch.cuda.memory_allocated() @@ -514,7 +541,7 @@ def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> current_memory_allocated_gb = 0.0 max_memory_allocated_gb = 0.0 - # 构建返回的损失字典 + # Build the return loss dictionary. return_loss_dict = { 'Current_GPU': current_memory_allocated_gb, 'Max_GPU': max_memory_allocated_gb, @@ -525,8 +552,7 @@ def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), } - # print(f'self.task_id:{self.task_id}') - # 生成任务相关的损失字典,并为每个任务相关的 loss 添加前缀 "noreduce_" + # Generate task-specific loss dictionaries, prefixing each with "noreduce_". multi_task_loss_dicts = { **generate_task_loss_dict(consistency_loss_multi_task, 'noreduce_consistency_loss_task{}', task_id=self.task_id), **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id), @@ -538,10 +564,10 @@ def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id), } - # 合并两个字典 + # Merge the dictionaries. return_loss_dict.update(multi_task_loss_dicts) - # 返回最终的损失字典 + # Return the final loss dictionary. return return_loss_dict def _reset_collect(self, data_id: Optional[List[int]] = None, task_id: int = None) -> None: @@ -549,7 +575,8 @@ def _reset_collect(self, data_id: Optional[List[int]] = None, task_id: int = Non Overview: Reset the observation and action for the collector environment. Arguments: - - data_id (`Optional[List[int]]`): List of data ids to reset (not used in this implementation). + - data_id (:obj:`Optional[List[int]]`): List of data ids to reset (not used in this implementation). + - task_id (:obj:`int`): The ID of the task. 
""" if self._cfg.model.model_type in ["conv_context"]: self.last_batch_obs = initialize_zeros_batch( @@ -565,6 +592,7 @@ def _reset_eval(self, data_id: Optional[List[int]] = None, task_id: int = None) Reset the observation and action for the evaluator environment. Arguments: - data_id (:obj:`Optional[List[int]]`): List of data ids to reset (not used in this implementation). + - task_id (:obj:`int`): The ID of the task. """ if self._cfg.model.model_type in ["conv_context"]: self.last_batch_obs = initialize_zeros_batch( @@ -577,15 +605,16 @@ def _reset_eval(self, data_id: Optional[List[int]] = None, task_id: int = None) def _monitor_vars_learn(self, num_tasks: int = None) -> List[str]: """ - 概述: - 注册学习模式中需要监控的变量。注册的变量将根据 `_forward_learn` 的返回值记录到tensorboard。 - 如果提供了 `num_tasks`,则为每个任务生成监控变量。 - 参数: - - num_tasks (:obj:`int`, 可选): 任务数量。 - 返回: - - monitored_vars (:obj:`List[str]`): 需要监控的变量列表。 + Overview: + Registers variables to be monitored during the learning phase. The registered variables + will be recorded to TensorBoard based on the return value of `_forward_learn`. + If `num_tasks` is provided, it generates monitoring variables for each task. + Arguments: + - num_tasks (:obj:`int`, optional): The number of tasks. + Returns: + - monitored_vars (:obj:`List[str]`): A list of variable names to be monitored. """ - # 基本监控变量 + # Basic monitoring variables. monitored_vars = [ 'Current_GPU', 'Max_GPU', @@ -596,7 +625,7 @@ def _monitor_vars_learn(self, num_tasks: int = None) -> List[str]: 'total_grad_norm_before_clip_wm', ] - # 任务特定的监控变量 + # Task-specific monitoring variables. task_specific_vars = [ 'noreduce_consistency_loss', 'noreduce_reward_loss', @@ -607,7 +636,8 @@ def _monitor_vars_learn(self, num_tasks: int = None) -> List[str]: 'noreduce_value_priority', 'noreduce_value_priority_mean', ] - # self.task_num_for_current_rank 作为当前rank的base_index + + # Use self.task_num_for_current_rank as the number of tasks for the current rank. num_tasks = self.task_num_for_current_rank print(f'self.task_num_for_current_rank: {self.task_num_for_current_rank}') if num_tasks is not None: @@ -656,16 +686,7 @@ def _forward_collect( - to_play (:obj:`int`): The player to play. - epsilon (:obj:`float`): The epsilon of the eps greedy exploration. - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - temperature: :math:`(1, )`. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - epsilon: :math:`(1, )`. - - ready_env_id: None + - task_id (:obj:`int`): The ID of the task. Returns: - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. @@ -692,22 +713,22 @@ def _forward_collect( legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] if not self._cfg.collect_with_pure_policy: - # the only difference between collect and eval is the dirichlet noise + # The only difference between collect and eval is the dirichlet noise. 
noises = [ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) ).astype(np.float32).tolist() for j in range(active_collect_env_num) ] if self._cfg.mcts_ctree: - # cpp mcts_tree + # C++ MCTS tree. roots = MCTSCtree.roots(active_collect_env_num, legal_actions) else: - # python mcts_tree + # Python MCTS tree. roots = MCTSPtree.roots(active_collect_env_num, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, task_id=task_id) - # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + # List of lists, shape: ``{list: batch_size} -> {list: action_space_size}`` roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} @@ -715,7 +736,7 @@ def _forward_collect( for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] if self._cfg.eps.eps_greedy_exploration_in_collect: - # eps greedy collect + # Epsilon-greedy exploration for collection. action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( distributions, temperature=self._collect_mcts_temperature, deterministic=True ) @@ -723,13 +744,13 @@ def _forward_collect( if np.random.rand() < self.collect_epsilon: action = np.random.choice(legal_actions[i]) else: - # normal collect - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. + # Normal collection. + # NOTE: Only legal actions possess visit counts, so ``action_index_in_legal_action_set`` represents + # the index within the legal action set, not the entire action set. action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( distributions, temperature=self._collect_mcts_temperature, deterministic=False ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. + # NOTE: Convert ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] output[env_id] = { 'action': action, @@ -746,6 +767,7 @@ def _forward_collect( self.last_batch_obs = data self.last_batch_action = batch_action else: + # Pure policy collection (without MCTS). for i, env_id in enumerate(ready_env_id): policy_values = torch.softmax(torch.tensor([policy_logits[i][a] for a in legal_actions[i]]), dim=0).tolist() @@ -761,21 +783,15 @@ def _forward_collect( return output - def _get_target_obs_index_in_step_k(self, step): + def _get_target_obs_index_in_step_k(self, step: int) -> Tuple[int, int]: """ Overview: - Get the begin index and end index of the target obs in step k. + Get the begin and end indices of the target observation at step k. Arguments: - step (:obj:`int`): The current step k. Returns: - - beg_index (:obj:`int`): The begin index of the target obs in step k. - - end_index (:obj:`int`): The end index of the target obs in step k. - Examples: - >>> self._cfg.model.model_type = 'conv' - >>> self._cfg.model.image_channel = 3 - >>> self._cfg.model.frame_stack_num = 4 - >>> self._get_target_obs_index_in_step_k(0) - >>> (0, 12) + - beg_index (:obj:`int`): The beginning index of the target observation. 
+ - end_index (:obj:`int`): The ending index of the target observation. """ if self._cfg.model.model_type in ['conv', 'conv_context']: beg_index = self._cfg.model.image_channel * step @@ -798,9 +814,6 @@ def _init_eval(self) -> None: if self._cfg.model.model_type == 'conv_context': self.last_batch_obs = torch.zeros([3, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) self.last_batch_action = [-1 for _ in range(3)] - # elif self._cfg.model.model_type == 'mlp_context': - # self.last_batch_obs = torch.zeros([3, self._cfg.model.observation_shape]).to(self._cfg.device) - # self.last_batch_action = [-1 for _ in range(3)] def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None, task_id: int = None) -> Dict: @@ -813,14 +826,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None + - task_id (:obj:`int`): The ID of the task. Returns: - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. @@ -838,36 +844,36 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) if not self._eval_model.training: - # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + # If not in training, obtain the scalar values of the value/reward. + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape (B, 1) latent_state_roots = latent_state_roots.detach().cpu().numpy() - policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape (B, A) legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] if self._cfg.mcts_ctree: - # cpp mcts_tree + # C++ MCTS tree. roots = MCTSCtree.roots(active_eval_env_num, legal_actions) else: - # python mcts_tree + # Python MCTS tree. 
roots = MCTSPtree.roots(active_eval_env_num, legal_actions) roots.prepare_no_noise(reward_roots, policy_logits, to_play) self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, task_id=task_id) - # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + # List of lists, shape: ``{list: batch_size} -> {list: action_space_size}`` roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} batch_action = [] for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than - # sampling during the evaluation phase. + # NOTE: Only legal actions possess visit counts, so ``action_index_in_legal_action_set`` represents + # the index within the legal action set, not the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) + # rather than sampling during the evaluation phase. action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( distributions, temperature=1, deterministic=True ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # NOTE: Convert ``action_index_in_legal_action_set`` to the corresponding ``action`` in the # entire action set. action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] @@ -886,5 +892,4 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 self.last_batch_obs = data self.last_batch_action = batch_action - return output - + return output \ No newline at end of file diff --git a/lzero/policy/muzero_rnn_full_obs.py b/lzero/policy/muzero_rnn_full_obs.py index 060c43680..e7dbf55ed 100644 --- a/lzero/policy/muzero_rnn_full_obs.py +++ b/lzero/policy/muzero_rnn_full_obs.py @@ -44,9 +44,10 @@ class MuZeroRNNFullObsPolicy(MuZeroPolicy): image_channel=1, # (int) The number of frames to stack together. frame_stack_num=1, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=300, + # (tuple) The range of supports used in categorical distribution. + # These variables are only effective when ``model.categorical_distribution=True``. + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), # (int) The hidden size in LSTM. rnn_hidden_size=512, # gru_hidden_size=512, @@ -300,7 +301,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # transform the scaled value or its categorical representation to its original value, # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) + original_value = self.value_inverse_scalar_transform_handle(value) # Note: The following lines are just for debugging. 
predicted_rewards = [] @@ -433,10 +434,10 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) if self._cfg.monitor_extra_statistics: - original_rewards = self.inverse_scalar_transform_handle(reward) + original_rewards = self.reward_inverse_scalar_transform_handle(reward) original_rewards_cpu = original_rewards.detach().cpu() predicted_values = torch.cat( - (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + (predicted_values, self.value_inverse_scalar_transform_handle(value).detach().cpu()) ) predicted_rewards.append(original_rewards_cpu) predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) @@ -581,7 +582,7 @@ def _forward_collect( latent_state_roots, reward_roots, world_model_latent_history_roots, pred_values, policy_logits = ez_network_output_unpack( network_output ) - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() world_model_latent_history_roots = world_model_latent_history_roots.detach().cpu().numpy() @@ -709,7 +710,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ if not self._eval_model.training: # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) latent_state_roots = latent_state_roots.detach().cpu().numpy() world_model_latent_history_roots = world_model_latent_history_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) diff --git a/lzero/policy/random_policy.py b/lzero/policy/random_policy.py index bc914e1a2..c84806b76 100644 --- a/lzero/policy/random_policy.py +++ b/lzero/policy/random_policy.py @@ -5,7 +5,7 @@ from ding.policy.base_policy import Policy from ding.utils import POLICY_REGISTRY -from lzero.policy import InverseScalarTransform, select_action, ez_network_output_unpack, mz_network_output_unpack +from lzero.policy import DiscreteSupport, InverseScalarTransform, select_action, ez_network_output_unpack, mz_network_output_unpack @POLICY_REGISTRY.register('lightzero_random_policy') @@ -81,9 +81,10 @@ def _init_collect(self) -> None: self._mcts_collect = self.MCTSPtree(self._cfg) self._collect_mcts_temperature = 1 self.collect_epsilon = 0.0 - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) def _forward_collect( self, @@ -132,7 +133,7 @@ def _forward_collect( else: raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) - pred_values = 
self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() if self._cfg.type in ['efficientzero', 'sampled_efficientzero']: reward_hidden_state_roots = ( diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index fb73014e6..eb8c32273 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -50,9 +50,10 @@ class SampledEfficientZeroPolicy(MuZeroPolicy): image_channel=1, # (int) The number of frames to stack together. frame_stack_num=1, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=300, + # (tuple) The range of supports used in categorical distribution. + # These variables are only effective when ``model.categorical_distribution=True``. + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), # (int) The number of res blocks in Sampled EfficientZero model. num_res_blocks=1, # (int) The hidden size in LSTM. @@ -302,11 +303,12 @@ def _init_learn(self) -> None: self._cfg.augmentation, image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + assert self.value_support.size == self._learn_model.value_support_size # if these assertions fail, somebody introduced... + assert self.reward_support.size == self._learn_model.reward_support_size # ...incoherence between policy and model + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: """ @@ -378,7 +380,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # transform the scaled value or its categorical representation to its original value, # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) + original_value = self.value_inverse_scalar_transform_handle(value) # Note: The following lines are just for logging.
predicted_value_prefixs = [] @@ -486,11 +488,11 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: ) if self._cfg.monitor_extra_statistics: - original_value_prefixs = self.inverse_scalar_transform_handle(value_prefix) + original_value_prefixs = self.value_inverse_scalar_transform_handle(value_prefix) original_value_prefixs_cpu = original_value_prefixs.detach().cpu() predicted_values = torch.cat( - (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + (predicted_values, self.value_inverse_scalar_transform_handle(value).detach().cpu()) ) predicted_value_prefixs.append(original_value_prefixs_cpu) predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) @@ -838,7 +840,7 @@ def _forward_collect( network_output ) - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() reward_hidden_state_roots = ( reward_hidden_state_roots[0].detach().cpu().numpy(), @@ -973,7 +975,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ if not self._eval_model.training: # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) latent_state_roots = latent_state_roots.detach().cpu().numpy() reward_hidden_state_roots = ( reward_hidden_state_roots[0].detach().cpu().numpy(), diff --git a/lzero/policy/sampled_muzero.py b/lzero/policy/sampled_muzero.py index 2a72d6ccb..3548c03be 100644 --- a/lzero/policy/sampled_muzero.py +++ b/lzero/policy/sampled_muzero.py @@ -50,9 +50,10 @@ class SampledMuZeroPolicy(MuZeroPolicy): image_channel=1, # (int) The number of frames to stack together. frame_stack_num=1, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=300, + # (tuple) The range of supports used in categorical distribution. + # These variables are only effective when ``model.categorical_distribution=True``. + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), # (int) The number of res blocks in Sampled MuZero model. num_res_blocks=1, # (int) The hidden size in LSTM. @@ -302,11 +303,12 @@ def _init_learn(self) -> None: self._cfg.augmentation, image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + assert self.value_support.size == self._learn_model.value_support_size # if these assertions fail, somebody introduced...
+ assert self.reward_support.size == self._learn_model.reward_support_size # ...incoherence between policy and model + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: """ @@ -378,7 +380,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # transform the scaled value or its categorical representation to its original value, # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) + original_value = self.value_inverse_scalar_transform_handle(value) # Note: The following lines are just for logging. predicted_rewards = [] @@ -479,11 +481,11 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) if self._cfg.monitor_extra_statistics: - original_rewards = self.inverse_scalar_transform_handle(reward) + original_rewards = self.reward_inverse_scalar_transform_handle(reward) original_rewards_cpu = original_rewards.detach().cpu() predicted_values = torch.cat( - (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + (predicted_values, self.value_inverse_scalar_transform_handle(value).detach().cpu()) ) predicted_rewards.append(original_rewards_cpu) predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) @@ -835,7 +837,7 @@ def _forward_collect( network_output ) - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() @@ -966,7 +968,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ if not self._eval_model.training: # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) diff --git a/lzero/policy/sampled_unizero.py b/lzero/policy/sampled_unizero.py index 3e872cca3..ec7399fc6 100644 --- a/lzero/policy/sampled_unizero.py +++ b/lzero/policy/sampled_unizero.py @@ -60,9 +60,10 @@ class SampledUniZeroPolicy(UniZeroPolicy): num_res_blocks=1, # (int) The number of channels of hidden states in MuZero model. num_channels=64, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=50, + # (tuple) The range of supports used in categorical distribution. + # These variables are only effective when ``model.categorical_distribution=True``. + reward_support_range=(-50., 51., 1.), + value_support_range=(-50., 51., 1.), # (bool) whether to learn bias in the last linear layer in value and policy head. bias=True, # (bool) whether to use res connection in dynamics. 
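Seen from the user side, the same change turns the single symmetric `support_scale` key into two explicit `(start, stop, step)` tuples. A hedged before/after sketch for the config keys shown in the hunk above (the values mirror the old symmetric scale of 50; adjust to your own scale):

```Python
# Old style (removed by this patch):
# model=dict(..., support_scale=50, ...)

# New style: explicit ranges for the reward and value supports.
model = dict(
    reward_support_range=(-50., 51., 1.),
    value_support_range=(-50., 51., 1.),
)
```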
@@ -357,11 +358,13 @@ def _init_learn(self) -> None: self._cfg.augmentation, image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + assert self.value_support.size == self._learn_model.value_support_size # if these assertions fail, somebody introduced... + assert self.reward_support.size == self._learn_model.reward_support_size # ...incoherence between policy and model + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) + self.intermediate_losses = defaultdict(float) self.l2_norm_before = 0. self.l2_norm_after = 0. @@ -467,8 +470,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Update world model losses = self._learn_model.world_model.compute_loss( - batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle - ) + batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle + ) # NOTE: the third argument of compute_loss is currently a dead (unused) argument. If this changes, the call may need to distinguish between the value and reward inverse transform handles.
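The assertions added in `_init_learn` make the policy/model coupling explicit: the support built from `*_support_range` must provide exactly as many bins as the model's categorical value/reward heads expect. A small illustrative check, assuming `size` counts the bins of the range:

```Python
from lzero.policy import DiscreteSupport

# With step 1., the range (-50., 51., 1.) presumably yields 101 bins: -50, -49, ..., 50.
value_support = DiscreteSupport(-50., 51., 1., 'cpu')
# The model's value head must be built with the same number of bins; otherwise the new
# assert in _init_learn flags the policy/model mismatch early.
assert value_support.size == 101
```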
weighted_total_loss = losses.loss_total for loss_name, loss_value in losses.intermediate_losses.items(): @@ -695,7 +698,7 @@ def _forward_collect( network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() @@ -849,7 +852,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ if not self._eval_model.training: # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) diff --git a/lzero/policy/sampled_unizero_multitask.py b/lzero/policy/sampled_unizero_multitask.py index ccdefb656..00d929f51 100644 --- a/lzero/policy/sampled_unizero_multitask.py +++ b/lzero/policy/sampled_unizero_multitask.py @@ -1,5 +1,3 @@ -# /Users/puyuan/code/LightZero/lzero/policy/sample_unizero_multitask.py - import copy import logging from collections import defaultdict @@ -9,7 +7,7 @@ import torch import wandb from ding.model import model_wrap -from ding.utils import POLICY_REGISTRY +from ding.utils import POLICY_REGISTRY, set_pkg_seed, get_rank, get_world_size from lzero.entry.utils import initialize_zeros_batch from lzero.mcts import SampledUniZeroMCTSCtree as MCTSCtree @@ -23,45 +21,74 @@ mz_network_output_unpack, select_action, prepare_obs, - prepare_obs_stack4_for_unizero + prepare_obs_stack_for_unizero ) from lzero.policy.unizero import UniZeroPolicy from .utils import configure_optimizers_nanogpt import torch.nn.functional as F import torch.distributed as dist + +# Please add the path to your LibMTL library. +# For example: sys.path.append('/path/to/your/LibMTL/') import sys -sys.path.append('/mnt/afs/niuyazhe/code/LibMTL/') +# sys.path.append('/path/to/your/LibMTL/') # Template path from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect -# from LibMTL.weighting.CAGrad_unizero import CAGrad as GradCorrect -def generate_task_loss_dict(multi_task_losses, task_name_template, task_id): + +def generate_task_loss_dict(multi_task_losses: List[Union[torch.Tensor, float]], task_name_template: str, task_id: int) -> Dict[str, float]: """ - 生成每个任务的损失字典 - :param multi_task_losses: 包含每个任务损失的列表 - :param task_name_template: 任务名称模板,例如 'obs_loss_task{}' - :param task_id: 基础任务 ID - :return: 一个字典,包含每个任务的损失 + Overview: + Generates a dictionary for losses of each task. + Arguments: + - multi_task_losses (:obj:`List[Union[torch.Tensor, float]]`): A list containing the loss for each task. + - task_name_template (:obj:`str`): A template for the task name, e.g., 'obs_loss_task{}'. + - task_id (:obj:`int`): The base task ID. + Returns: + - (:obj:`Dict[str, float]`): A dictionary containing the loss for each task. 
""" task_loss_dict = {} for task_idx, task_loss in enumerate(multi_task_losses): task_name = task_name_template.format(task_idx + task_id) try: + # Convert tensor to float if it has .item(), otherwise cast to float. task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else float(task_loss) except Exception as e: + # Fallback for cases where conversion fails. task_loss_dict[task_name] = task_loss return task_loss_dict class WrappedModelV2: - def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_table): + """ + Overview: + A wrapper class to conveniently manage different parts of a larger model, + such as the tokenizer, transformer, and various embedding layers. This allows for + easier handling of parameters and gradients for these components. + """ + def __init__(self, tokenizer: torch.nn.Module, transformer: torch.nn.Module, pos_emb: torch.nn.Module, task_emb: torch.nn.Module, act_embedding_table: torch.nn.Module): + """ + Overview: + Initializes the WrappedModelV2 with model components. + Arguments: + - tokenizer (:obj:`torch.nn.Module`): The tokenizer module. + - transformer (:obj:`torch.nn.Module`): The main transformer module. + - pos_emb (:obj:`torch.nn.Module`): The positional embedding layer. + - task_emb (:obj:`torch.nn.Module`): The task embedding layer. + - act_embedding_table (:obj:`torch.nn.Module`): The action embedding table. + """ self.tokenizer = tokenizer self.transformer = transformer self.pos_emb = pos_emb self.task_emb = task_emb self.act_embedding_table = act_embedding_table - def parameters(self): - # 返回 tokenizer, transformer 以及所有嵌入层的参数 + def parameters(self) -> List[torch.Tensor]: + """ + Overview: + Collects and returns all parameters from the wrapped model components. + Returns: + - (:obj:`List[torch.Tensor]`): A list of all parameters. + """ return ( list(self.tokenizer.parameters()) + list(self.transformer.parameters()) + @@ -70,18 +97,27 @@ def parameters(self): list(self.act_embedding_table.parameters()) ) - def zero_grad(self, set_to_none=False): - # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all wrapped model components to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. Defaults to False. + """ self.tokenizer.zero_grad(set_to_none=set_to_none) self.transformer.zero_grad(set_to_none=set_to_none) self.pos_emb.zero_grad(set_to_none=set_to_none) # self.task_emb.zero_grad(set_to_none=set_to_none) self.act_embedding_table.zero_grad(set_to_none=set_to_none) - def get_group_parameters(self): + def get_group_parameters(self) -> Dict[str, List[torch.Tensor]]: """ - 返回一个字典,其中 key 为模块名或更细粒度的层, - value 为对应的参数列表。注意返回顺序应与 parameters()方法中参数的排列顺序一致。 + Overview: + Returns a dictionary where keys are module names (or finer-grained layers) + and values are the corresponding parameter lists. The order of parameters in the + returned dictionary's values should be consistent with the `parameters()` method. + Returns: + - (:obj:`Dict[str, List[torch.Tensor]]`): A dictionary of grouped parameters. """ groups = {} groups['tokenizer'] = list(self.tokenizer.parameters()) @@ -89,15 +125,14 @@ def get_group_parameters(self): groups['pos_emb'] = list(self.pos_emb.parameters()) groups['act_embedding_table'] = list(self.act_embedding_table.parameters()) - # 如 transformer 内部分层(假设 transformer.blocks 是列表) + # Example of how to add parameters from sub-layers within the transformer. 
+ # This is for demonstration; ensure the order in parameters() is consistent if used. if hasattr(self.transformer, 'blocks'): - # 若要单独统计 transformer 内各层,保持原 transformer 参数在 parameters() 中顺序不变, - # 可以在这里添加各层的切片,但需保证 parameters() 返回的顺序与此一致, - # 此处仅作为示例: for i, layer in enumerate(self.transformer.blocks): groups[f'transformer_layer_{i}'] = list(layer.parameters()) return groups + @POLICY_REGISTRY.register('sampled_unizero_multitask') class SampledUniZeroMTPolicy(UniZeroPolicy): """ @@ -153,7 +188,7 @@ class SampledUniZeroMTPolicy(UniZeroPolicy): predict_latent_loss_type='group_kl', obs_type='image', gamma=1, - dormant_threshold=0.025, + dormant_threshold=0.01, policy_loss_type='kl', ), ), @@ -231,15 +266,21 @@ class SampledUniZeroMTPolicy(UniZeroPolicy): def default_model(self) -> Tuple[str, List[str]]: """ - Return this algorithm's default model setting for demonstration. + Overview: + Return this algorithm's default model setting for demonstration. + Returns: + - (:obj:`Tuple[str, List[str]]`): A tuple containing the model name and the import paths. """ return 'SampledUniZeroMTModel', ['lzero.model.sampled_unizero_model_multitask'] def _init_learn(self) -> None: """ - Learn mode init method. Initialize the learn model, optimizer, and MCTS utils. + Overview: + Initializes the learning mode. This method sets up the learn model, optimizer, + target model, and other utilities required for training, such as LR schedulers + and gradient correction methods (e.g., MoCo). """ - # Configure optimizer for world model + # Configure optimizer for the world model using NanoGPT's configuration utility. self._optimizer_world_model = configure_optimizers_nanogpt( model=self._model.world_model, learning_rate=self._cfg.learning_rate, @@ -248,6 +289,7 @@ def _init_learn(self) -> None: betas=(0.9, 0.95), ) + # Initialize learning rate schedulers if configured. if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR @@ -256,13 +298,12 @@ def _init_learn(self) -> None: self._optimizer_world_model, T_max=int(1e5), eta_min=0, last_epoch=-1 ) elif self._cfg.piecewise_decay_lr_scheduler: - # Example step scheduler, adjust milestones and gamma as needed self.lr_scheduler = StepLR( self._optimizer_world_model, step_size=int(5e4), gamma=0.1 ) + # Initialize weights for continuous action spaces. if self._cfg.model.continuous_action_space: - # Weight Init for the last output layer of gaussian policy head in prediction network. init_w = self._cfg.init_w self._model.world_model.fc_policy_head.mu.weight.data.uniform_(-init_w, init_w) self._model.world_model.fc_policy_head.mu.bias.data.uniform_(-init_w, init_w) @@ -272,13 +313,13 @@ def _init_learn(self) -> None: except Exception as exception: logging.warning(exception) - # Initialize target model + # Initialize and compile the target model. self._target_model = copy.deepcopy(self._model) - # Ensure torch version >= 2.0 - assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0" + assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "Torch version 2.0 or higher is required." self._model = torch.compile(self._model) self._target_model = torch.compile(self._target_model) - # Soft target update + + # Wrap the target model for soft updates (momentum-based). 
self._target_model = model_wrap( self._target_model, wrapper_name='target', @@ -287,12 +328,7 @@ def _init_learn(self) -> None: ) self._learn_model = self._model - # if self._cfg.use_augmentation: - # self.image_transforms = ImageTransforms( - # self._cfg.augmentation, - # image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) - # ) - + # Initialize utilities for loss calculation and transformations. self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) self.inverse_scalar_transform_handle = InverseScalarTransform( @@ -307,225 +343,188 @@ def _init_learn(self) -> None: self.task_id = self._cfg.task_id self.task_num_for_current_rank = self._cfg.task_num print(f'self._cfg.only_use_moco_stats:{self._cfg.only_use_moco_stats}') + + # Initialize gradient correction method (MoCo) if enabled. if self._cfg.use_moco or self._cfg.only_use_moco_stats: - # 创建 WrappedModel 实例,仅矫正部分参数,保持可扩展性 - # wrapped_model = WrappedModelV2( - # self._learn_model.world_model.tokenizer.encoder[0], # 假设只有一个编码器 - # self._learn_model.world_model.transformer, - # self._learn_model.world_model.pos_emb, - # self._learn_model.world_model.task_emb, - # self._learn_model.world_model.act_embedding_table, - # ) - - # head 没有矫正梯度 + # Wrap model components for gradient correction. Note: Heads are not included. wrapped_model = WrappedModelV2( - self._learn_model.world_model.tokenizer.encoder, # TODO: one or N encoder inside + self._learn_model.world_model.tokenizer.encoder, # TODO: This might contain one or multiple encoders. self._learn_model.world_model.transformer, self._learn_model.world_model.pos_emb, self._learn_model.world_model.task_emb, self._learn_model.world_model.act_embedding_table, ) - # TODO - # 如果需要,可以在这里初始化梯度校正方法(如 MoCo, CAGrad) - # self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device) - # self.grad_correct = GradCorrect(wrapped_model, self._cfg.task_num, self._cfg.device, self._cfg.multi_gpu) # only compatiable with for 1GPU training - self.grad_correct = GradCorrect(wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu) # only compatiable with for 1GPU training + # TODO: The GradCorrect class might need adjustments for multi-GPU training compatibility. + # Initialize the gradient correction mechanism. + self.grad_correct = GradCorrect(wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu) self.grad_correct.init_param() self.grad_correct.rep_grad = False - def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[str, Union[float, int]]: + def _forward_learn(self, data: Tuple[torch.Tensor], task_weights: Any = None, ignore_grad: bool = False) -> Dict[str, Union[float, int]]: """ - Forward function for learning policy in learn mode, handling multiple tasks. + Overview: + The forward pass for training. This method processes a batch of data for multiple tasks, + computes losses, and updates the model weights. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): A tuple of data batches, one for each task. + - task_weights (:obj:`Any`): Weights for each task's loss. Defaults to None. + - ignore_grad (:obj:`bool`): If True, gradients are zeroed out after computation, effectively skipping the update. Defaults to False. 
+ Returns: + - (:obj:`Dict[str, Union[float, int]]`): A dictionary containing various loss values and training statistics. """ self._learn_model.train() self._target_model.train() - # Initialize multi-task loss lists - task_weight_multi_task = [] - - obs_loss_multi_task = [] - reward_loss_multi_task = [] - policy_loss_multi_task = [] - orig_policy_loss_multi_task = [] - policy_entropy_multi_task = [] - value_loss_multi_task = [] - latent_recon_loss_multi_task = [] - perceptual_loss_multi_task = [] - latent_state_l2_norms_multi_task = [] - average_target_policy_entropy_multi_task = [] - value_priority_multi_task = [] - value_priority_mean_multi_task = [] + # Initialize lists to store losses and metrics for each task. + task_weight_multi_task, obs_loss_multi_task, reward_loss_multi_task = [], [], [] + policy_loss_multi_task, orig_policy_loss_multi_task, policy_entropy_multi_task = [], [], [] + value_loss_multi_task, latent_recon_loss_multi_task, perceptual_loss_multi_task = [], [], [] + latent_state_l2_norms_multi_task, average_target_policy_entropy_multi_task = [], [] + value_priority_multi_task, value_priority_mean_multi_task = [], [] weighted_total_loss = 0.0 - losses_list = [] # 存储每个任务的损失 + losses_list = [] # Stores the individual loss tensor for each task. for task_id, data_one_task in enumerate(data): + # Unpack data for the current task. current_batch, target_batch, task_id = data_one_task - obs_batch_ori, action_batch, child_sampled_actions_batch, target_action_batch, mask_batch, indices, weights, make_time = current_batch + obs_batch_ori, action_batch, child_sampled_actions_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch target_reward, target_value, target_policy = target_batch - # Prepare observations based on frame stack number + # Prepare observations. if self._cfg.model.frame_stack_num == 4: - obs_batch, obs_target_batch = prepare_obs_stack4_for_unizero(obs_batch_ori, self._cfg) + obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) else: obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg, task_id) - # Apply augmentations if needed + # Apply data augmentation if enabled. if self._cfg.use_augmentation: obs_batch = self.image_transforms.transform(obs_batch) if self._cfg.model.self_supervised_learning_loss: obs_target_batch = self.image_transforms.transform(obs_target_batch) - # Prepare action batch and convert to torch tensor - if self._cfg.model.continuous_action_space: - action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1) - else: - action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() - data_list = [ - mask_batch, - target_reward.astype('float32'), - target_value.astype('float32'), - target_policy, - weights - ] + # Prepare actions and convert data to torch tensors. 
+ action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1) + if not self._cfg.model.continuous_action_space: + action_batch = action_batch.long() + + data_list = [mask_batch, target_reward.astype('float32'), target_value.astype('float32'), target_policy, weights] mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, self._cfg.device) - target_reward = target_reward.view(self._cfg.batch_size[task_id], -1) - target_value = target_value.view(self._cfg.batch_size[task_id], -1) + cur_batch_size = target_reward.size(0) + target_reward = target_reward.view(cur_batch_size, -1) + target_value = target_value.view(cur_batch_size, -1) - # Transform rewards and values to their scaled forms + # Transform scalar targets to their categorical representation. transformed_target_reward = scalar_transform(target_reward) transformed_target_value = scalar_transform(target_value) - - # Convert to categorical distributions target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) target_value_categorical = phi_transform(self.value_support, transformed_target_value) - # Prepare batch for GPT model + # Prepare the batch for the GPT-based world model. batch_for_gpt = {} if isinstance(self._cfg.model.observation_shape_list[task_id], int) or len(self._cfg.model.observation_shape_list[task_id]) == 1: - batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( - self._cfg.batch_size[task_id], -1, self._cfg.model.observation_shape_list[task_id]) - elif len(self._cfg.model.observation_shape_list[task_id]) == 3: - batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( - self._cfg.batch_size[task_id], -1, *self._cfg.model.observation_shape_list[task_id]) + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape(cur_batch_size, -1, self._cfg.model.observation_shape_list[task_id]) + else: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape(cur_batch_size, -1, *self._cfg.model.observation_shape_list[task_id]) batch_for_gpt['actions'] = action_batch.squeeze(-1) batch_for_gpt['child_sampled_actions'] = torch.from_numpy(child_sampled_actions_batch).to(self._cfg.device)[:, :-1] batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] - batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data - batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] + batch_for_gpt['mask_padding'] = (mask_batch == 1.0)[:, :-1] # 0 indicates invalid padding data. batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) batch_for_gpt['target_value'] = target_value_categorical[:, :-1] batch_for_gpt['target_policy'] = target_policy[:, :-1] - # Extract valid target policy data and compute entropy + # Compute target policy entropy for monitoring. valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) average_target_policy_entropy = target_policy_entropy.mean().item() - # Update world model + # Compute losses using the world model. 
losses = self._learn_model.world_model.compute_loss( - batch_for_gpt, - self._target_model.world_model.tokenizer, - self.inverse_scalar_transform_handle, - task_id=task_id + batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, task_id=task_id ) - if task_weights is not None: - weighted_total_loss += losses.loss_total * task_weights[task_id] - losses_list.append(losses.loss_total * task_weights[task_id]) - - task_weight_multi_task.append(task_weights[task_id]) - else: - weighted_total_loss += losses.loss_total - losses_list.append(losses.loss_total) - - task_weight_multi_task.append(1) - + + # Accumulate weighted total loss. + current_task_weight = task_weights[task_id] if task_weights is not None else 1 + weighted_total_loss += losses.loss_total * current_task_weight + losses_list.append(losses.loss_total * current_task_weight) + task_weight_multi_task.append(current_task_weight) + # Store intermediate losses for logging. for loss_name, loss_value in losses.intermediate_losses.items(): self.intermediate_losses[f"{loss_name}"] = loss_value - # print(f'{loss_name}: {loss_value.sum()}') - # print(f'{loss_name}: {loss_value[0][0]}') - - # print(f"=== 全局任务权重 (按 task_id 排列): {task_weights}") - # assert not torch.isnan(losses.loss_total).any(), f"Loss contains NaN values, losses.loss_total:{losses.loss_total}, losses:{losses}" - # assert not torch.isinf(losses.loss_total).any(), f"Loss contains Inf values, losses.loss_total:{losses.loss_total}, losses:{losses}" - - # Collect losses per task - obs_loss = self.intermediate_losses.get('loss_obs', 0.0) or 0.0 - reward_loss = self.intermediate_losses.get('loss_rewards', 0.0) or 0.0 - policy_loss = self.intermediate_losses.get('loss_policy', 0.0) or 0.0 - orig_policy_loss = self.intermediate_losses.get('orig_policy_loss', 0.0) or 0.0 - policy_entropy = self.intermediate_losses.get('policy_entropy', 0.0) or 0.0 - value_loss = self.intermediate_losses.get('loss_value', 0.0) or 0.0 - latent_recon_loss = self.intermediate_losses.get('latent_recon_loss', 0.0) or 0.0 - perceptual_loss = self.intermediate_losses.get('perceptual_loss', 0.0) or 0.0 - latent_state_l2_norms = self.intermediate_losses.get('latent_state_l2_norms', 0.0) or 0.0 - value_priority = torch.tensor(0., device=self._cfg.device) # Placeholder, adjust as needed - - obs_loss_multi_task.append(obs_loss) - reward_loss_multi_task.append(reward_loss) - policy_loss_multi_task.append(policy_loss) - orig_policy_loss_multi_task.append(orig_policy_loss) - policy_entropy_multi_task.append(policy_entropy) - value_loss_multi_task.append(value_loss) - latent_recon_loss_multi_task.append(latent_recon_loss) - perceptual_loss_multi_task.append(perceptual_loss) - latent_state_l2_norms_multi_task.append(latent_state_l2_norms) + + # Collect individual losses for the current task. 
+ obs_loss_multi_task.append(self.intermediate_losses.get('loss_obs', 0.0) or 0.0) + reward_loss_multi_task.append(self.intermediate_losses.get('loss_rewards', 0.0) or 0.0) + policy_loss_multi_task.append(self.intermediate_losses.get('loss_policy', 0.0) or 0.0) + orig_policy_loss_multi_task.append(self.intermediate_losses.get('orig_policy_loss', 0.0) or 0.0) + policy_entropy_multi_task.append(self.intermediate_losses.get('policy_entropy', 0.0) or 0.0) + value_loss_multi_task.append(self.intermediate_losses.get('loss_value', 0.0) or 0.0) + latent_recon_loss_multi_task.append(self.intermediate_losses.get('latent_recon_loss', 0.0) or 0.0) + perceptual_loss_multi_task.append(self.intermediate_losses.get('perceptual_loss', 0.0) or 0.0) + latent_state_l2_norms_multi_task.append(self.intermediate_losses.get('latent_state_l2_norms', 0.0) or 0.0) average_target_policy_entropy_multi_task.append(average_target_policy_entropy) + value_priority = torch.tensor(0., device=self._cfg.device) # Placeholder value_priority_multi_task.append(value_priority) value_priority_mean_multi_task.append(value_priority.mean().item()) - # Core learn model update step + # --- Model Update Step --- self._optimizer_world_model.zero_grad() - # 假设每个进程计算出的 losses_list 为可求梯度的 tensor list,比如多个标量 loss 组成的列表 - # 例如 losses_list = [loss1, loss2, ...],其中每个 loss_i 都是形如 (1,) 的 tensor 且 requires_grad=True + # Perform backward pass, either with or without gradient correction. if self._cfg.use_moco: - # 调用 MoCo backward,由 grad_correct 中的 backward 实现梯度校正 + # Use MoCo for gradient correction and backpropagation. lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) elif self._cfg.only_use_moco_stats: + # Compute MoCo stats but perform standard backpropagation. lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) - # 不使用梯度校正的情况,由各 rank 自己执行反向传播 weighted_total_loss.backward() else: - # 不使用梯度校正的情况,由各 rank 自己执行反向传播 + # Standard backpropagation without gradient correction. lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device) weighted_total_loss.backward() + # Clip gradients to prevent exploding gradients. total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(), self._cfg.grad_clip_value) + # NOTE: If ignore_grad is True, zero out gradients. This is useful for DDP synchronization + # when a GPU has finished all its tasks but still needs to participate in the training step. + if ignore_grad: + self._optimizer_world_model.zero_grad() + + # Synchronize gradients across GPUs in multi-GPU setup. if self._cfg.multi_gpu: - # if not self._cfg.use_moco or self._cfg.only_use_moco_stats: - # self.sync_gradients(self._learn_model) if not self._cfg.use_moco: + # TODO: Investigate if a barrier is needed here for synchronization. + # dist.barrier() self.sync_gradients(self._learn_model) + # Update model parameters. self._optimizer_world_model.step() + # Step the learning rate scheduler. if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: self.lr_scheduler.step() - # Core target model update step + # Update the target model using a soft update rule. self._target_model.update(self._learn_model.state_dict()) - # 获取GPU内存使用情况 + # Monitor GPU memory usage. 
if torch.cuda.is_available(): torch.cuda.synchronize() - current_memory_allocated = torch.cuda.memory_allocated() - max_memory_allocated = torch.cuda.max_memory_allocated() - current_memory_allocated_gb = current_memory_allocated / (1024 ** 3) - max_memory_allocated_gb = max_memory_allocated / (1024 ** 3) + current_memory_allocated_gb = torch.cuda.memory_allocated() / (1024 ** 3) + max_memory_allocated_gb = torch.cuda.max_memory_allocated() / (1024 ** 3) else: - current_memory_allocated_gb = 0. - max_memory_allocated_gb = 0. + current_memory_allocated_gb, max_memory_allocated_gb = 0., 0. - # 构建损失字典 + # --- Logging and Return --- return_loss_dict = { 'Current_GPU': current_memory_allocated_gb, 'Max_GPU': max_memory_allocated_gb, @@ -536,99 +535,81 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[s 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), } - # if task_weights is None: - # task_weights = {self.task_id+i: 1 for i in range(self.task_num_for_current_rank)} - # else: - # print(f'task_weights:{task_weights}') - # from ding.utils import EasyTimer, set_pkg_seed, get_rank - - # print(f'rank:{get_rank()}, task_id:{self.task_id}') - - # 生成任务相关的损失字典,并为每个任务相关的 loss 添加前缀 "noreduce_" + # Generate and merge task-specific loss dictionaries. + # The "noreduce_" prefix indicates these are per-rank values before DDP reduction. multi_task_loss_dicts = { - **generate_task_loss_dict(task_weight_multi_task, 'noreduce_task_weight_task{}', task_id=self.task_id), - **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id), - **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', task_id=self.task_id), - **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', task_id=self.task_id), - **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', task_id=self.task_id), - **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id), - **generate_task_loss_dict(orig_policy_loss_multi_task, 'noreduce_orig_policy_loss_task{}', task_id=self.task_id), - **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id), - **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id), - **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id), - **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'noreduce_target_policy_entropy_task{}', task_id=self.task_id), - **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', task_id=self.task_id), - **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id), - **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id), + **generate_task_loss_dict(task_weight_multi_task, 'noreduce_task_weight_task{}', self.task_id), + **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', self.task_id), + **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', self.task_id), + **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', self.task_id), + **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', self.task_id), + **generate_task_loss_dict(policy_loss_multi_task, 
'noreduce_policy_loss_task{}', self.task_id), + **generate_task_loss_dict(orig_policy_loss_multi_task, 'noreduce_orig_policy_loss_task{}', self.task_id), + **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', self.task_id), + **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', self.task_id), + **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', self.task_id), + **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'noreduce_target_policy_entropy_task{}', self.task_id), + **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', self.task_id), + **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', self.task_id), + **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', self.task_id), } - - # print(f'multi_task_loss_dicts:{ multi_task_loss_dicts}') - - # 合并两个字典 return_loss_dict.update(multi_task_loss_dicts) - # 如果需要,可以将损失字典记录到日志或其他地方 + # Log to wandb if enabled. if self._cfg.use_wandb: wandb.log({'learner_step/' + k: v for k, v in return_loss_dict.items()}, step=self.env_step) wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step) return return_loss_dict - # TODO: num_tasks - def _monitor_vars_learn(self, num_tasks=2) -> List[str]: + def _monitor_vars_learn(self, num_tasks: int = 2) -> List[str]: """ Overview: - Register the variables to be monitored in learn mode. The registered variables will be logged in - tensorboard according to the return value ``_forward_learn``. - If num_tasks is provided, generate monitored variables for each task. + Specifies the variables to be monitored during training. These variables will be logged + (e.g., to TensorBoard) based on the dictionary returned by `_forward_learn`. + Arguments: + - num_tasks (:obj:`int`): The number of tasks to generate monitored variables for. This argument is for API consistency and is overridden by `self.task_num_for_current_rank`. + Returns: + - (:obj:`List[str]`): A list of variable names to monitor. """ - # Basic monitored variables that do not depend on the number of tasks + # Basic monitored variables, independent of the number of tasks. monitored_vars = [ - 'Current_GPU', - 'Max_GPU', - 'collect_epsilon', - 'collect_mcts_temperature', - 'cur_lr_world_model', - 'weighted_total_loss', - 'total_grad_norm_before_clip_wm', + 'Current_GPU', 'Max_GPU', 'collect_epsilon', 'collect_mcts_temperature', + 'cur_lr_world_model', 'weighted_total_loss', 'total_grad_norm_before_clip_wm', ] - # rank = get_rank() + # Task-specific variables. 
task_specific_vars = [ - 'noreduce_task_weight', - 'noreduce_obs_loss', - 'noreduce_orig_policy_loss', - 'noreduce_policy_loss', - 'noreduce_latent_recon_loss', - 'noreduce_policy_entropy', - 'noreduce_target_policy_entropy', - 'noreduce_reward_loss', - 'noreduce_value_loss', - 'noreduce_perceptual_loss', - 'noreduce_latent_state_l2_norms', - 'noreduce_lambd', + 'noreduce_task_weight', 'noreduce_obs_loss', 'noreduce_orig_policy_loss', + 'noreduce_policy_loss', 'noreduce_latent_recon_loss', 'noreduce_policy_entropy', + 'noreduce_target_policy_entropy', 'noreduce_reward_loss', 'noreduce_value_loss', + 'noreduce_perceptual_loss', 'noreduce_latent_state_l2_norms', 'noreduce_lambd', 'noreduce_value_priority_mean', ] - # self.task_num_for_current_rank 作为当前rank的base_index - num_tasks = self.task_num_for_current_rank - # If the number of tasks is provided, extend the monitored variables list with task-specific variables - if num_tasks is not None: + + # The number of tasks handled by the current rank. + num_tasks_on_rank = self.task_num_for_current_rank + + # Generate full variable names for each task on the current rank. + if num_tasks_on_rank is not None: for var in task_specific_vars: - for task_idx in range(num_tasks): - # print(f"learner policy Rank {rank}, self.task_id: {self.task_id}") - monitored_vars.append(f'{var}_task{self.task_id+task_idx}') + for task_idx in range(num_tasks_on_rank): + # The task ID is offset by the base task ID for this rank. + monitored_vars.append(f'{var}_task{self.task_id + task_idx}') else: - # If num_tasks is not provided, we assume there's only one task and keep the original variable names monitored_vars.extend(task_specific_vars) return monitored_vars - def monitor_weights_and_grads(self, model): + def monitor_weights_and_grads(self, model: torch.nn.Module) -> None: """ - Monitor and print the weights and gradients of the model. + Overview: + A utility function to monitor and print the statistics (mean, std) of model weights and their gradients. + Arguments: + - model (:obj:`torch.nn.Module`): The model to inspect. """ for name, param in model.named_parameters(): - if param.requires_grad: + if param.requires_grad and param.grad is not None: print(f"Layer: {name} | " f"Weight mean: {param.data.mean():.4f} | " f"Weight std: {param.data.std():.4f} | " @@ -637,7 +618,9 @@ def monitor_weights_and_grads(self, model): def _init_collect(self) -> None: """ - Collect mode init method. Initialize the collect model and MCTS utils. + Overview: + Initializes the collection mode. This method sets up the collect model, MCTS utilities, + and initial states for the collector environments. """ self._collect_model = self._model @@ -647,35 +630,45 @@ def _init_collect(self) -> None: self._mcts_collect = MCTSPtree(self._cfg) self._collect_mcts_temperature = 1. self._task_weight_temperature = 10. - self._collect_epsilon = 0.0 self.collector_env_num = self._cfg.collector_env_num + + # Initialize placeholders for the last observation and action batches. 
if self._cfg.model.model_type == 'conv': - self.last_batch_obs = torch.zeros( - [self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64] - ).to(self._cfg.device) - self.last_batch_action = [-1 for _ in range(self.collector_env_num)] + obs_shape = [self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64] + self.last_batch_obs = torch.zeros(obs_shape, device=self._cfg.device) elif self._cfg.model.model_type == 'mlp': - self.last_batch_obs = torch.zeros( - [self.collector_env_num, self._cfg.model.observation_shape_list[0]] - ).to(self._cfg.device) - self.last_batch_action = [-1 for _ in range(self.collector_env_num)] + obs_shape = [self.collector_env_num, self._cfg.model.observation_shape_list[0]] + self.last_batch_obs = torch.zeros(obs_shape, device=self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.collector_env_num)] def _forward_collect( self, data: torch.Tensor, - action_mask: list = None, - temperature: float = 1, - to_play: List = [-1], + action_mask: List = None, + temperature: float = 1.0, + to_play: List[int] = [-1], epsilon: float = 0.25, - ready_env_id: np.array = None, + ready_env_id: np.ndarray = None, + timestep: List[int] = [0], task_id: int = None, - ) -> Dict: + ) -> Dict[int, Dict[str, Any]]: """ - Forward function for collecting data in collect mode, handling multiple tasks. + Overview: + The forward pass for data collection. It uses MCTS to select actions for the current states. + Arguments: + - data (:obj:`torch.Tensor`): The current batch of observations. + - action_mask (:obj:`List`): A list of action masks for each environment. + - temperature (:obj:`float`): The temperature parameter for MCTS action selection. + - to_play (:obj:`List[int]`): A list indicating the current player for each environment. + - epsilon (:obj:`float`): The exploration noise parameter. + - ready_env_id (:obj:`np.ndarray`): An array of environment IDs that are ready for action. + - timestep (:obj:`List[int]`): The current timestep for each environment. + - task_id (:obj:`int`): The ID of the task being executed. + Returns: + - (:obj:`Dict[int, Dict[str, Any]]`): A dictionary mapping environment IDs to action selection results. """ self._collect_model.eval() - self._collect_mcts_temperature = temperature self._collect_epsilon = epsilon active_collect_env_num = data.shape[0] @@ -684,55 +677,33 @@ def _forward_collect( output = {i: None for i in ready_env_id} with torch.no_grad(): - network_output = self._collect_model.initial_inference( - self.last_batch_obs, - self.last_batch_action, - data, - task_id=task_id - ) + # 1. Initial inference to get root information. + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() - legal_actions = [ - [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num) - ] if not self._cfg.model.continuous_action_space else [ - [-1 for _ in range(self._cfg.model.world_model_cfg.num_of_sampled_actions)] - for _ in range(active_collect_env_num) - ] + # 2. Prepare MCTS roots. 
+ if not self._cfg.model.continuous_action_space: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + else: + legal_actions = [[-1] * self._cfg.model.world_model_cfg.num_of_sampled_actions for _ in range(active_collect_env_num)] - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(self._cfg.model.world_model_cfg.num_of_sampled_actions)) - .astype(np.float32).tolist() for _ in range(active_collect_env_num) - ] + noises = [np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.world_model_cfg.num_of_sampled_actions).astype(np.float32).tolist() for _ in range(active_collect_env_num)] if self._cfg.mcts_ctree: - roots = MCTSCtree.roots( - active_collect_env_num, - legal_actions, - self._cfg.model.world_model_cfg.action_space_size, - self._cfg.model.world_model_cfg.num_of_sampled_actions, - self._cfg.model.continuous_action_space - ) + roots = MCTSCtree.roots(active_collect_env_num, legal_actions, self._cfg.model.world_model_cfg.action_space_size, self._cfg.model.world_model_cfg.num_of_sampled_actions, self._cfg.model.continuous_action_space) else: roots = MCTSPtree.roots(active_collect_env_num, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) - # try: - self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, task_id=task_id) - # print("latent_state_roots.shape:", latent_state_roots.shape) - # except Exception as e: - # print("="*20) - # print(e) - # print("roots:", roots, "latent_state_roots:", latent_state_roots) - # print("latent_state_roots.shape:", latent_state_roots.shape) - # print("="*20) - # import ipdb; ipdb.set_trace() - + # 3. MCTS search. + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep=timestep, task_id=task_id) + # 4. Get results from MCTS and select actions. roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() roots_sampled_actions = roots.get_sampled_actions() @@ -740,17 +711,11 @@ def _forward_collect( batch_action = [] for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] - root_sampled_actions = np.array([ - getattr(action, 'value', action) for action in roots_sampled_actions[i] - ]) + root_sampled_actions = np.array([getattr(action, 'value', action) for action in roots_sampled_actions[i]]) - # 选择动作 - action, visit_count_distribution_entropy = select_action( - distributions, temperature=self._collect_mcts_temperature, deterministic=False - ) - - # 获取采样动作 - action = root_sampled_actions[action] + # Select action based on visit counts, with temperature for exploration. + action_idx, visit_count_distribution_entropy = select_action(distributions, temperature=self._collect_mcts_temperature, deterministic=False) + action = root_sampled_actions[action_idx] if not self._cfg.model.continuous_action_space: action = int(action.item()) @@ -765,23 +730,23 @@ def _forward_collect( } batch_action.append(action) + # 5. Update state for the next step. self.last_batch_obs = data self.last_batch_action = batch_action - # 检查并重置采集器 + # Reset collector if the number of active environments is less than expected. 
if active_collect_env_num < self.collector_env_num: - print('==========collect_forward============') - print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + logging.warning(f'Number of active envs ({active_collect_env_num}) is less than collector_env_num ({self.collector_env_num}). Resetting collector.') self._reset_collect(reset_init_data=True, task_id=task_id) return output def _init_eval(self) -> None: """ - Evaluate mode init method. Initialize the eval model and MCTS utils. + Overview: + Initializes the evaluation mode. This method sets up the evaluation model, MCTS utilities, + and initial states for the evaluator environments. """ - from ding.utils import EasyTimer, set_pkg_seed, get_rank - self._eval_model = self._model if self._cfg.mcts_ctree: self._mcts_eval = MCTSCtree(self._cfg) @@ -792,69 +757,63 @@ def _init_eval(self) -> None: self.task_id_for_eval = self._cfg.task_id self.task_num_for_current_rank = self._cfg.task_num + # Initialize placeholders for the last observation and action batches for evaluation. if self._cfg.model.model_type == 'conv': - self.last_batch_obs_eval = torch.zeros( - [self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64] - ).to(self._cfg.device) - self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + obs_shape = [self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64] + self.last_batch_obs_eval = torch.zeros(obs_shape, device=self._cfg.device) elif self._cfg.model.model_type == 'mlp': - self.last_batch_obs_eval = torch.zeros( - [self.evaluator_env_num, self._cfg.model.observation_shape_list[self.task_id_for_eval]] # TODO - ).to(self._cfg.device) - print(f'rank {get_rank()} last_batch_obs_eval:', self.last_batch_obs_eval.shape) - self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] - - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, - ready_env_id: np.array = None, task_id: int = None) -> Dict: + # TODO: Ensure observation_shape_list is correctly indexed for the evaluation task. + obs_shape = [self.evaluator_env_num, self._cfg.model.observation_shape_list[self.task_id_for_eval]] + self.last_batch_obs_eval = torch.zeros(obs_shape, device=self._cfg.device) + print(f'rank {get_rank()} last_batch_obs_eval shape: {self.last_batch_obs_eval.shape}') + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.ndarray = None, timestep: List[int] = [0], task_id: int = None) -> Dict[int, Dict[str, Any]]: """ - Forward function for evaluating the current policy in eval mode, handling multiple tasks. + Overview: + The forward pass for evaluation. It uses MCTS to select actions deterministically. + Arguments: + - data (:obj:`torch.Tensor`): The current batch of observations. + - action_mask (:obj:`List`): A list of action masks for each environment. + - to_play (:obj:`int`): The current player. + - ready_env_id (:obj:`np.ndarray`): An array of environment IDs that are ready for action. + - timestep (:obj:`List[int]`): The current timestep for each environment. + - task_id (:obj:`int`): The ID of the task being evaluated. + Returns: + - (:obj:`Dict[int, Dict[str, Any]]`): A dictionary mapping environment IDs to action selection results. 
""" self._eval_model.eval() active_eval_env_num = data.shape[0] if ready_env_id is None: ready_env_id = np.arange(active_eval_env_num) output = {i: None for i in ready_env_id} + with torch.no_grad(): - network_output = self._eval_model.initial_inference( - self.last_batch_obs_eval, - self.last_batch_action, - data, - task_id=task_id - ) + # 1. Initial inference. + network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data, task_id=task_id) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - # TODO:======== - # self._eval_model.training = False - # if not self._eval_model.training: pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() - legal_actions = [ - [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num) - ] if not self._cfg.model.continuous_action_space else [ - [-1 for _ in range(self._cfg.model.world_model_cfg.num_of_sampled_actions)] - for _ in range(active_eval_env_num) - ] + # 2. Prepare MCTS roots without noise for deterministic evaluation. + if not self._cfg.model.continuous_action_space: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + else: + legal_actions = [[-1] * self._cfg.model.world_model_cfg.num_of_sampled_actions for _ in range(active_eval_env_num)] if self._cfg.mcts_ctree: - roots = MCTSCtree.roots( - active_eval_env_num, - legal_actions, - self._cfg.model.world_model_cfg.action_space_size, - self._cfg.model.world_model_cfg.num_of_sampled_actions, - self._cfg.model.continuous_action_space - ) + roots = MCTSCtree.roots(active_eval_env_num, legal_actions, self._cfg.model.world_model_cfg.action_space_size, self._cfg.model.world_model_cfg.num_of_sampled_actions, self._cfg.model.continuous_action_space) else: roots = MCTSPtree.roots(active_eval_env_num, legal_actions) - - # print(f'type(policy_logits): {type(policy_logits)}') - # print(f'policy_logits.shape: {policy_logits.shape}') - # print(f'policy_logits: {policy_logits}') - + roots.prepare_no_noise(reward_roots, policy_logits, to_play) - self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, task_id=task_id) + + # 3. MCTS search. + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep=timestep, task_id=task_id) + # 4. Get results and select actions deterministically. roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() roots_sampled_actions = roots.get_sampled_actions() @@ -862,17 +821,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 batch_action = [] for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] - root_sampled_actions = np.array([ - getattr(action, 'value', action) for action in roots_sampled_actions[i] - ]) + root_sampled_actions = np.array([getattr(action, 'value', action) for action in roots_sampled_actions[i]]) - # 选择动作(确定性) - action, visit_count_distribution_entropy = select_action( - distributions, temperature=1, deterministic=True - ) - - # 获取采样动作 - action = root_sampled_actions[action] + # Select action deterministically (greedy selection from visit counts). 
+ action_idx, visit_count_distribution_entropy = select_action(distributions, temperature=1, deterministic=True) + action = root_sampled_actions[action_idx] if not self._cfg.model.continuous_action_space: action = int(action.item()) @@ -886,7 +839,8 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 'predicted_policy_logits': policy_logits[i], } batch_action.append(action) - + + # 5. Update state for the next evaluation step. self.last_batch_obs_eval = data self.last_batch_action = batch_action @@ -894,49 +848,42 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: """ - Reset the collection process for a specific environment. + Overview: + Resets the collector state. This can be a full reset of initial data or a periodic + clearing of model caches to manage memory. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None, applies to all. + - current_steps (:obj:`int`): The current number of steps, used for periodic cache clearing. + - reset_init_data (:obj:`bool`): Whether to reset the initial observation and action batches. + - task_id (:obj:`int`, optional): The task ID, used to determine observation shape. """ if reset_init_data: - if task_id is not None: - self.last_batch_obs = initialize_zeros_batch( - self._cfg.model.observation_shape_list[task_id], - self._cfg.collector_env_num, - self._cfg.device - ) - else: - self.last_batch_obs = initialize_zeros_batch( - self._cfg.model.observation_shape, - self._cfg.collector_env_num, - self._cfg.device - ) + obs_shape = self._cfg.model.observation_shape_list[task_id] if task_id is not None else self._cfg.model.observation_shape + self.last_batch_obs = initialize_zeros_batch(obs_shape, self._cfg.collector_env_num, self._cfg.device) self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] - logging.info(f'collector: last_batch_obs, last_batch_action reset() {self.last_batch_obs.shape}') + logging.info(f'Collector: last_batch_obs and last_batch_action have been reset. Shape: {self.last_batch_obs.shape}') if env_id is None or isinstance(env_id, list): return + # Periodically clear model caches to free up memory. clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 - - if current_steps % clear_interval == 0: - logging.info(f'clear_interval: {clear_interval}') - + if current_steps > 0 and current_steps % clear_interval == 0: + logging.info(f'Clearing model caches at step {current_steps}.') world_model = self._collect_model.world_model world_model.past_kv_cache_init_infer.clear() for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: kv_cache_dict_env.clear() world_model.past_kv_cache_recurrent_infer.clear() world_model.keys_values_wm_list.clear() - torch.cuda.empty_cache() - - logging.info('collector: collect_model clear()') - logging.info(f'eps_steps_lst[{env_id}]: {current_steps}') - + logging.info('Collector: collect_model caches cleared.') self._reset_target_model() def _reset_target_model(self) -> None: """ - Reset the target model's caches. + Overview: + Resets the caches of the target model to free up GPU memory. 
""" world_model = self._target_model.world_model world_model.past_kv_cache_init_infer.clear() @@ -944,13 +891,15 @@ def _reset_target_model(self) -> None: kv_cache_dict_env.clear() world_model.past_kv_cache_recurrent_infer.clear() world_model.keys_values_wm_list.clear() - torch.cuda.empty_cache() - logging.info('collector: target_model past_kv_cache.clear()') + logging.info('Collector: target_model caches cleared.') def _state_dict_learn(self) -> Dict[str, Any]: """ - Return the state_dict of learn mode, including model, target_model, and optimizer. + Overview: + Returns the state dictionary of the learning components. + Returns: + - (:obj:`Dict[str, Any]`): A dictionary containing the state of the model, target model, and optimizer. """ return { 'model': self._learn_model.state_dict(), @@ -958,27 +907,28 @@ def _state_dict_learn(self) -> Dict[str, Any]: 'optimizer_world_model': self._optimizer_world_model.state_dict(), } - # ========== TODO: original version: load all parameters ========== def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: """ Overview: - Load the state_dict variable into policy learn mode. + Loads the state dictionary into the learning components. Arguments: - - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + - state_dict (:obj:`Dict[str, Any]`): The state dictionary to load. """ self._learn_model.load_state_dict(state_dict['model']) self._target_model.load_state_dict(state_dict['target_model']) self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) - # ========== TODO: pretrain-finetue version: only load encoder and transformer-backbone parameters ========== + # TODO: The following is a version for pretrain-finetune workflow, which only loads backbone parameters. # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: # """ # Overview: - # Load the state_dict variable into policy learn mode, excluding multi-task related parameters. + # Loads a state_dict into the policy's learn mode, but excludes parameters related to + # multi-task heads and task embeddings. This is useful for fine-tuning a pre-trained model + # on a new set of tasks. # Arguments: - # - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously. + # - state_dict (:obj:`Dict[str, Any]`): The dict of the policy learn state saved previously. # """ - # # 定义需要排除的参数前缀 + # # Define prefixes of parameters to exclude (e.g., multi-task heads, task embeddings). # exclude_prefixes = [ # '_orig_mod.world_model.head_policy_multi_task.', # '_orig_mod.world_model.head_value_multi_task.', @@ -987,60 +937,53 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: # '_orig_mod.world_model.task_emb.' # ] - # # 定义需要排除的具体参数(如果有特殊情况) + # # Define specific keys to exclude if they don't fit a prefix pattern. # exclude_keys = [ # '_orig_mod.world_model.task_emb.weight', - # '_orig_mod.world_model.task_emb.bias', # 如果存在则添加 - # # 添加其他需要排除的具体参数名 + # '_orig_mod.world_model.task_emb.bias', # ] # def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: # """ - # 过滤掉需要排除的参数。 + # Filters out parameters that should not be loaded. 
# """ # filtered = {} # for k, v in state_dict_loader.items(): - # if any(k.startswith(prefix) for prefix in exclude_prefixes): - # print(f"Excluding parameter: {k}") # 调试用,查看哪些参数被排除 - # continue - # if k in exclude_keys: - # print(f"Excluding specific parameter: {k}") # 调试用 + # if any(k.startswith(prefix) for prefix in exclude_prefixes) or k in exclude_keys: + # print(f"Excluding parameter from loading: {k}") # continue # filtered[k] = v # return filtered - # # 过滤并加载 'model' 部分 + # # Filter and load state_dict for the main model. # if 'model' in state_dict: # model_state_dict = state_dict['model'] # filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) - # missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) - # if missing_keys: - # print(f"Missing keys when loading _learn_model: {missing_keys}") - # if unexpected_keys: - # print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") + # missing, unexpected = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + # if missing: + # print(f"Missing keys when loading _learn_model: {missing}") + # if unexpected: + # print(f"Unexpected keys when loading _learn_model: {unexpected}") # else: - # print("No 'model' key found in the state_dict.") + # print("Warning: 'model' key not found in the state_dict.") - # # 过滤并加载 'target_model' 部分 + # # Filter and load state_dict for the target model. # if 'target_model' in state_dict: # target_model_state_dict = state_dict['target_model'] # filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) - # missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) - # if missing_keys: - # print(f"Missing keys when loading _target_model: {missing_keys}") - # if unexpected_keys: - # print(f"Unexpected keys when loading _target_model: {unexpected_keys}") + # missing, unexpected = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + # if missing: + # print(f"Missing keys when loading _target_model: {missing}") + # if unexpected: + # print(f"Unexpected keys when loading _target_model: {unexpected}") # else: - # print("No 'target_model' key found in the state_dict.") + # print("Warning: 'target_model' key not found in the state_dict.") - # # 加载优化器的 state_dict,不需要过滤,因为优化器通常不包含模型参数 + # # Load optimizer state_dict. This is often skipped during fine-tuning, but included here for completeness. # if 'optimizer_world_model' in state_dict: - # optimizer_state_dict = state_dict['optimizer_world_model'] # try: - # self._optimizer_world_model.load_state_dict(optimizer_state_dict) + # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) # except Exception as e: - # print(f"Error loading optimizer state_dict: {e}") + # print(f"Could not load optimizer state_dict: {e}. 
This may be expected during fine-tuning.") # else: - # print("No 'optimizer_world_model' key found in the state_dict.") - - # # 如果需要,还可以加载其他部分,例如 scheduler 等 \ No newline at end of file + # print("Warning: 'optimizer_world_model' key not found in the state_dict.") \ No newline at end of file diff --git a/lzero/policy/scaling_transform.py b/lzero/policy/scaling_transform.py index 4e3efb4af..17eee4052 100644 --- a/lzero/policy/scaling_transform.py +++ b/lzero/policy/scaling_transform.py @@ -1,19 +1,15 @@ from typing import Union -import numpy as np import torch - +import numpy as np class DiscreteSupport(object): - def __init__(self, min: int, max: int, delta: float = 1.) -> None: - assert min < max - self.min = min - self.max = max - self.range = np.arange(min, max + 1, delta) - self.size = len(self.range) - self.set_size = len(self.range) - self.delta = delta - + def __init__(self, start: float, stop: float, step: float = 1., device: Union[str, torch.device] = 'cpu') -> None: + assert start < stop + self.arange = torch.arange(start, stop, step, dtype=torch.float32).unsqueeze(0).to(device) + self.size = self.arange.shape[1] + assert self.size > 0, "DiscreteSupport size must be greater than 0" + self.step = step def scalar_transform(x: torch.Tensor, epsilon: float = 0.001, delta: float = 1.) -> torch.Tensor: """ @@ -33,38 +29,9 @@ def scalar_transform(x: torch.Tensor, epsilon: float = 0.001, delta: float = 1.) return output -def ensure_softmax(logits, dim=1): - """ - Overview: - Ensure that the input tensor is normalized along the specified dimension. - Arguments: - - logits (:obj:`torch.Tensor`): The input tensor. - - dim (:obj:`int`): The dimension along which to normalize the input tensor. - Returns: - - output (:obj:`torch.Tensor`): The normalized tensor. 
- """ - # Calculate the sum along the specified dimension (dim=1 in this case) - sum_along_dim = logits.sum(dim=dim, keepdim=True) - - # Create a tensor of ones with the same shape as sum_along_dim - ones_like_sum = torch.ones_like(sum_along_dim) - - # Check if the logits are already normalized (i.e., if the sum along the dimension is approximately 1) - # torch.allclose checks if all elements of two tensors are close within a tolerance - # atol (absolute tolerance) is set to a small value to allow for numerical precision issues - is_normalized = torch.allclose(sum_along_dim, ones_like_sum, atol=1e-5) - - # If logits are not normalized, apply softmax along the specified dimension - if not is_normalized: - return torch.softmax(logits, dim=dim) - else: - # If logits are already normalized, return them as they are - return logits - - def inverse_scalar_transform( logits: torch.Tensor, - support_size: int, + scalar_support: DiscreteSupport, epsilon: float = 0.001, categorical_distribution: bool = True ) -> torch.Tensor: @@ -77,9 +44,8 @@ def inverse_scalar_transform( - https://arxiv.org/pdf/1805.11593.pdf Appendix A: Proposition A.2 """ if categorical_distribution: - scalar_support = DiscreteSupport(-support_size, support_size, delta=1) - value_probs = ensure_softmax(logits, dim=1) - value_support = torch.from_numpy(scalar_support.range).unsqueeze(0) + value_probs = torch.softmax(logits, dim=1) + value_support = scalar_support.arange value_support = value_support.to(device=value_probs.device) value = (value_support * value_probs).sum(1, keepdim=True) @@ -106,18 +72,15 @@ class InverseScalarTransform: def __init__( self, - support_size: int, - device: Union[str, torch.device] = 'cpu', + scalar_support: DiscreteSupport, categorical_distribution: bool = True ) -> None: - scalar_support = DiscreteSupport(-support_size, support_size, delta=1) - self.value_support = torch.from_numpy(scalar_support.range).unsqueeze(0) - self.value_support = self.value_support.to(device) + self.value_support = scalar_support.arange self.categorical_distribution = categorical_distribution def __call__(self, logits: torch.Tensor, epsilon: float = 0.001) -> torch.Tensor: if self.categorical_distribution: - value_probs = ensure_softmax(logits, dim=1) + value_probs = torch.softmax(logits, dim=1) value = value_probs.mul_(self.value_support).sum(1, keepdim=True) else: value = logits @@ -143,31 +106,72 @@ def visit_count_temperature( return fixed_temperature_value -def phi_transform(discrete_support: DiscreteSupport, x: torch.Tensor) -> torch.Tensor: +def phi_transform( + discrete_support: DiscreteSupport, + x: torch.Tensor, + label_smoothing_eps: float = 0. # <--- 新增平滑参数 +) -> torch.Tensor: """ Overview: - We then apply a transformation ``phi`` to the scalar in order to obtain equivalent categorical representations. - After this transformation, each scalar is represented as the linear combination of its two adjacent supports. - Reference: - - MuZero paper Appendix F: Network Architecture. + Map a real-valued scalar to a categorical distribution over a discrete support using linear interpolation (a.k.a. “soft” one-hot). + + For each scalar value the probability mass is split between the two + nearest support atoms so that their weighted sum equals the original + value (MuZero, Appendix F). + + Arguments: + - discrete_support : DiscreteSupport + Container with the support values (must be evenly spaced). + - x : torch.Tensor + Input tensor of arbitrary shape ``(...,)`` containing real numbers. 
+ + Returns: + - torch.Tensor + Tensor of shape ``(*x.shape, N)`` where ``N = discrete_support.size``. + The last dimension is a probability distribution (sums to 1). + + Notes + ----- + • No in-place ops on the input are used, improving autograd safety. + • Only one `scatter_add_` kernel is launched for efficiency. """ - min = discrete_support.min - max = discrete_support.max - set_size = discrete_support.set_size - delta = discrete_support.delta - - x.clamp_(min, max) - x_low = x.floor() - x_high = x.ceil() - p_high = x - x_low - p_low = 1 - p_high - - target = torch.zeros(x.shape[0], x.shape[1], set_size).to(x.device) - x_high_idx, x_low_idx = x_high - min / delta, x_low - min / delta - target.scatter_(2, x_high_idx.long().unsqueeze(-1), p_high.unsqueeze(-1)) - target.scatter_(2, x_low_idx.long().unsqueeze(-1), p_low.unsqueeze(-1)) - - return target + # --- constants ---------------------------------------------------------- + min_bound = discrete_support.arange[0, 0] + max_bound = discrete_support.arange[0, -1] + step = discrete_support.step + size = discrete_support.size + + # --- 1. clip to the valid range ---------------------------------------- + x = x.clamp(min_bound, max_bound) + + # --- 2. locate neighbouring indices ------------------------------------ + pos = (x - min_bound) / step # continuous position + low_idx_float = torch.floor(pos) # lower index + low_idx_long = low_idx_float.long() # lower index + high_idx = low_idx_long + 1 # upper index (may overflow) + + # --- 3. linear interpolation weights ----------------------------------- + p_high = pos - low_idx_float # distance to lower atom + p_low = 1.0 - p_high # complementary mass + + # --- 4. stack indices / probs and scatter ------------------------------ + idx = torch.stack([low_idx_long, + torch.clamp(high_idx, max=size - 1)], dim=-1) # (*x, 2) + prob = torch.stack([p_low, p_high], dim=-1) # (*x, 2) + + target = torch.zeros(*x.shape, size, + dtype=x.dtype, device=x.device) + + target.scatter_add_(-1, idx, prob) + # return target + + # --- 5. 应用标签平滑 --- + if label_smoothing_eps > 0: + # 将原始的 two-hot 目标与一个均匀分布混合 + smooth_target = (1.0 - label_smoothing_eps) * target + (label_smoothing_eps / size) + return smooth_target + else: + return target def cross_entropy_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index cd0f050c7..00bcd3c8a 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -50,9 +50,10 @@ class StochasticMuZeroPolicy(MuZeroPolicy): num_res_blocks=1, # (int) The number of channels of hidden states in MuZero model. num_channels=64, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=300, + # (tuple) The range of supports used in categorical distribution. + # These variables are only effective when ``model.categorical_distribution=True``. + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), # (bool) whether to learn bias in the last linear layer in value and policy head. 
bias=True, ), @@ -262,11 +263,14 @@ def _init_learn(self) -> None: self._cfg.augmentation, image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + assert self.value_support.size == self._learn_model.value_support_size # if these assertions fails, somebody introduced... + assert self.reward_support.size == self._learn_model.reward_support_size # ...incoherence between policy and model + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) + + self.mse_loss = torch.nn.MSELoss() def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: """ @@ -344,7 +348,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # transform the scaled value or its categorical representation to its original value, # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) + original_value = self.value_inverse_scalar_transform_handle(value) # Note: The following lines are just for debugging. predicted_rewards = [] @@ -406,7 +410,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # transform the scaled value or its categorical representation to its original value, # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = self.inverse_scalar_transform_handle(value) + original_value = self.value_inverse_scalar_transform_handle(value) if self._cfg.model.self_supervised_learning_loss: # ============================================================== @@ -426,8 +430,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] consistency_loss += temp_loss - # NOTE: the target policy, target_value_categorical, target_reward_categorical is calculated in - # game buffer now. + # NOTE: the target policy is calculated in game buffer now. # ============================================================== # calculate policy loss for the next ``num_unroll_steps`` unroll steps. # NOTE: the +=. @@ -447,7 +450,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in plot_topk_accuracy(afterstate_policy_logits, true_chance_one_hot, topK_values) # The chance encoder is not used in the mcts, so we don't need to calculate the commitment loss. 
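Because this diff replaces the old scalar `support_scale` with explicit `(start, stop, step)` tuples in every policy and config, a tiny conversion helper can make migrating user configs less error-prone. The helper below is purely illustrative and not part of this PR; its name is hypothetical.

```Python
def support_scale_to_range(support_scale: int, step: float = 1.0) -> tuple:
    """Return the (start, stop, step) tuple equivalent to the old symmetric support."""
    return (-float(support_scale), float(support_scale) + step, step)


# Matches the replacements made throughout this diff:
assert support_scale_to_range(300) == (-300., 301., 1.)
assert support_scale_to_range(50) == (-50., 51., 1.)
assert support_scale_to_range(10) == (-10., 11., 1.)
```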
- commitment_loss += torch.nn.MSELoss()(chance_encoding, true_chance_one_hot.float().detach()) + commitment_loss += self.mse_loss(chance_encoding, true_chance_one_hot.float().detach()) else: afterstate_policy_loss += cross_entropy_loss(afterstate_policy_logits, chance_one_hot.detach()) @@ -460,18 +463,18 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # calculate the topK accuracy of afterstate_policy_logits and plot the topK accuracy curve. plot_topk_accuracy(afterstate_policy_logits, true_chance_one_hot, topK_values) - commitment_loss += torch.nn.MSELoss()(chance_encoding, chance_one_hot.float()) + commitment_loss += self.mse_loss(chance_encoding, chance_one_hot.float()) afterstate_value_loss += cross_entropy_loss(afterstate_value, target_value_categorical[:, step_k]) value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) if self._cfg.monitor_extra_statistics: - original_rewards = self.inverse_scalar_transform_handle(reward) + original_rewards = self.reward_inverse_scalar_transform_handle(reward) original_rewards_cpu = original_rewards.detach().cpu() predicted_values = torch.cat( - (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + (predicted_values, self.value_inverse_scalar_transform_handle(value).detach().cpu()) ) predicted_rewards.append(original_rewards_cpu) predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) @@ -580,6 +583,7 @@ def _forward_collect( to_play: List = [-1], epsilon: float = 0.25, ready_env_id: np.array = None, + **kwargs, ) -> Dict: """ Overview: @@ -617,7 +621,7 @@ def _forward_collect( if not self._learn_model.training: # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() @@ -673,7 +677,7 @@ def _init_eval(self) -> None: else: self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [-1], ready_env_id: np.array = None,) -> Dict: + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [-1], ready_env_id: np.array = None, **kwargs) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. 
\ @@ -707,7 +711,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ if not self._eval_model.training: # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) diff --git a/lzero/policy/tests/test_scaling_transform.py b/lzero/policy/tests/test_scaling_transform.py index 7499a9348..25475b0ec 100644 --- a/lzero/policy/tests/test_scaling_transform.py +++ b/lzero/policy/tests/test_scaling_transform.py @@ -1,16 +1,17 @@ import pytest import torch -from lzero.policy.scaling_transform import inverse_scalar_transform, InverseScalarTransform +from lzero.policy.scaling_transform import DiscreteSupport, inverse_scalar_transform, InverseScalarTransform @pytest.mark.unittest def test_scaling_transform(): import time logit = torch.randn(16, 601) + discrete_support = DiscreteSupport(-300., 301., 1.) start = time.time() - output_1 = inverse_scalar_transform(logit, 300) + output_1 = inverse_scalar_transform(logit, discrete_support) print('t1', time.time() - start) - handle = InverseScalarTransform(300) + handle = InverseScalarTransform(discrete_support) start = time.time() output_2 = handle(logit) print('t2', time.time() - start) diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index a459275a7..a56474ccb 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -8,7 +8,7 @@ from ding.model import model_wrap from ding.utils import POLICY_REGISTRY -from lzero.entry.utils import initialize_zeros_batch +from lzero.entry.utils import initialize_zeros_batch, initialize_pad_batch from lzero.mcts import UniZeroMCTSCtree as MCTSCtree from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ @@ -17,7 +17,76 @@ from lzero.policy.muzero import MuZeroPolicy from .utils import configure_optimizers_nanogpt +from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters +import torch.nn.functional as F +def scale_module_weights_vectorized(module: torch.nn.Module, scale_factor: float): + """ + 使用向量化操作高效地缩放一个模块的所有权重。 + """ + if not (0.0 < scale_factor < 1.0): + return # 如果缩放因子无效,则不执行任何操作 + + # 1. 将模块的所有参数展平成一个单一向量 + params_vec = parameters_to_vector(module.parameters()) + + # 2. 在这个向量上执行一次乘法操作 + params_vec.data.mul_(scale_factor) + + # 3. 将缩放后的向量复制回模块的各个参数 + vector_to_parameters(params_vec, module.parameters()) + + +def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas): + """ + 为UniZero模型配置带有差异化学习率的优化器。 + """ + # 1. 定义需要特殊处理的参数 + param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + + # 2. 将参数分为三组:Transformer主干、Tokenizer、Heads + transformer_params = {pn: p for pn, p in param_dict.items() if 'transformer' in pn} + tokenizer_params = {pn: p for pn, p in param_dict.items() if 'tokenizer' in pn} + + # Heads的参数是那些既不属于transformer也不属于tokenizer的 + head_params = { + pn: p for pn, p in param_dict.items() + if 'transformer' not in pn and 'tokenizer' not in pn + } + + # 3. 
为每组设置不同的优化器参数(特别是学习率) + # 这里我们仍然使用AdamW,但学习率设置更合理 + optim_groups = [ + { + 'params': list(tokenizer_params.values()), + 'lr': learning_rate, # Tokenizer使用基础学习率,例如 1e-4 + # 'lr': learning_rate * 0.1, # 为encoder设置一个较小的学习率,例如 1e-5 + 'weight_decay': weight_decay * 5.0 # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化 + # 'weight_decay': weight_decay # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化 + }, + { + 'params': list(transformer_params.values()), + 'lr': learning_rate, # 1e-4 + # 'lr': learning_rate * 0.2, # 为Transformer主干设置一个较小的学习率,例如 1e-5 + 'weight_decay': weight_decay + # 'weight_decay': weight_decay * 5.0 + }, + { + 'params': list(head_params.values()), + 'lr': learning_rate, # Heads也使用基础学习率率,例如 1e-4 + 'weight_decay': 0.0 # 通常Heads的权重不做衰减 + # 'weight_decay': weight_decay + + } + ] + + print("--- Optimizer Groups ---") + print(f"Transformer LR: {learning_rate}") + print(f"Tokenizer/Heads LR: {learning_rate}") + + optimizer = torch.optim.AdamW(optim_groups, betas=betas) + return optimizer + @POLICY_REGISTRY.register('unizero') class UniZeroPolicy(MuZeroPolicy): """ @@ -50,9 +119,10 @@ class UniZeroPolicy(MuZeroPolicy): num_res_blocks=1, # (int) The number of channels of hidden states in MuZero model. num_channels=64, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=50, + # (tuple) The range of supports used in categorical distribution. + # These variables are only effective when ``model.categorical_distribution=True``. + reward_support_range=(-50., 51., 1.), + value_support_range=(-50., 51., 1.), # (bool) whether to learn bias in the last linear layer in value and policy head. bias=True, # (bool) whether to use res connection in dynamics. @@ -112,8 +182,17 @@ class UniZeroPolicy(MuZeroPolicy): perceptual_loss_weight=0., # (float) The weight of the policy entropy loss. policy_entropy_weight=0, - # (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse']. - predict_latent_loss_type='group_kl', + # (str) The normalization type for the final layer in both the head and the encoder. + # This option must be the same for both 'final_norm_option_in_head' and 'final_norm_option_in_encoder'. + # Valid options are 'LayerNorm' and 'SimNorm'. + # When set to 'LayerNorm', the 'predict_latent_loss_type' should be 'mse'. + # When set to 'SimNorm', the 'predict_latent_loss_type' should be 'group_kl'. + final_norm_option_in_head="LayerNorm", + final_norm_option_in_encoder="LayerNorm", + # (str) The type of loss function for predicting latent variables. + # Options are 'mse' (Mean Squared Error) or 'group_kl' (Group Kullback-Leibler divergence). + # This choice is dependent on the normalization method selected above. + predict_latent_loss_type='mse', # (str) The type of observation. Options are ['image', 'vector']. obs_type='image', # (float) The discount factor for future rewards. @@ -130,9 +209,30 @@ class UniZeroPolicy(MuZeroPolicy): # (int) The maximum sequence length for position encoding. max_seq_len=8192, lora_r= 0, + # Controls where to compute reconstruction loss: 'after_backbone', 'before_backbone', or None. + # - after_backbone: The reconstruction loss is computed after the encoded representation passes through the backbone. + # - before_backbone: The reconstruction loss is computed directly on the encoded representation, without the backbone. 
+ decode_loss_mode=None, ), ), # ****** common ****** + # (bool) Whether to enable the adaptive policy entropy weight (alpha). + use_adaptive_entropy_weight=True, + # (float) Learning rate of the optimizer for the adaptive entropy weight alpha. + adaptive_entropy_alpha_lr=1e-4, + # ==================== START: Encoder-Clip Annealing Config ==================== + # (bool) Whether to anneal the encoder-clip value. + use_encoder_clip_annealing=True, + # (str) Annealing type. Options are 'linear' and 'cosine'. + encoder_clip_anneal_type='cosine', + # (float) Initial clip value of the annealing schedule (looser during early training). + encoder_clip_start_value=30.0, + # (float) Final clip value of the annealing schedule (stricter during late training). + encoder_clip_end_value=10.0, + # (int) Number of training iterations over which to anneal from the start value to the end value. + encoder_clip_anneal_steps=100000, # e.g., the final clip value is reached after 100k iterations + # ===================== END: Encoder-Clip Annealing Config ===================== + # (bool) whether to use rnd model. use_rnd_model=False, # (bool) Whether to use multi-gpu training. @@ -198,6 +298,10 @@ class UniZeroPolicy(MuZeroPolicy): optim_type='AdamW', # (float) Learning rate for training policy network. Initial lr for manually decay schedule. learning_rate=0.0001, + # ==================== Norm monitoring frequency ==================== + # (int) Monitor the norms of the model parameters every this many training iterations. Set to 0 to disable. + monitor_norm_freq=5000, + # ============================================================ # (int) Frequency of hard target network update. target_update_freq=100, # (int) Frequency of soft target network update. @@ -214,8 +318,12 @@ class UniZeroPolicy(MuZeroPolicy): n_episode=8, # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. num_segments=8, - # (int) the number of simulations in MCTS. + # (int) The number of simulations in MCTS for reanalyze. num_simulations=50, + # (int) The number of simulations in MCTS for the collect phase. + collect_num_simulations=25, + # (int) The number of simulations in MCTS for the eval phase. + eval_num_simulations=50, # (float) Discount factor (gamma) for returns. discount_factor=0.997, # (int) The number of steps for calculating target q_value.
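The encoder-clip keys introduced above drive a schedule that is applied later in `_forward_learn`. The standalone sketch below (illustrative only, using the default values from this config) reproduces that schedule so the shape of the annealing curve is easy to inspect.

```Python
import numpy as np

def annealed_clip_value(
    train_iter: int,
    start: float = 30.0,
    end: float = 10.0,
    anneal_steps: int = 100000,
    anneal_type: str = 'cosine',
) -> float:
    """Interpolate the encoder-clip threshold from ``start`` to ``end`` over ``anneal_steps``."""
    progress = min(1.0, train_iter / anneal_steps)
    if anneal_type == 'cosine':
        # Cosine schedule: the blending weight moves smoothly from 1 to 0.
        cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress))
        return end + (start - end) * cosine_progress
    # Linear schedule.
    return start * (1 - progress) + end * progress

# Loose at the start (30.0), halfway between at the midpoint (20.0), strict at the end (10.0).
print(annealed_clip_value(0), annealed_clip_value(50000), annealed_clip_value(100000))
```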
@@ -300,24 +408,142 @@ def default_model(self) -> Tuple[str, List[str]]: """ return 'UniZeroModel', ['lzero.model.unizero_model'] + + # ==================== Model norm monitoring helpers ==================== + def _monitor_model_norms(self) -> Dict[str, float]: + """ + Overview: + Compute and return the parameter-matrix norms of the key model components (Encoder, Transformer, Heads). + This function should be called under torch.no_grad() for efficiency. + Returns: + - norm_metrics (:obj:`Dict[str, float]`): A dict containing all norm metrics, used for logging. + """ + world_model = self._learn_model.world_model + norm_metrics = {} + + # Define the module groups to monitor. + module_groups = { + 'encoder': world_model.tokenizer.encoder, + 'transformer': world_model.transformer, + 'head_value': world_model.head_value, + 'head_reward': world_model.head_rewards, + 'head_policy': world_model.head_policy, + } + + for group_name, group_module in module_groups.items(): + total_norm_sq = 0.0 + for param_name, param in group_module.named_parameters(): + if param.requires_grad: + # L2 norm of this parameter tensor. + param_norm = param.data.norm(2).item() + # Replace dots so the name shows up as a hierarchy in TensorBoard. + log_name = f'norm/{group_name}/{param_name.replace(".", "/")}' + norm_metrics[log_name] = param_norm + total_norm_sq += param_norm ** 2 + + # Total norm of the whole module. + total_group_norm = np.sqrt(total_norm_sq) + norm_metrics[f'norm/{group_name}/_total_norm'] = total_group_norm + + return norm_metrics + + def _monitor_gradient_norms(self) -> Dict[str, float]: + """ + Overview: + Compute and return the gradient norms of the key model components. + This function should be called after the gradients have been computed and before the parameter update. + Returns: + - grad_metrics (:obj:`Dict[str, float]`): A dict containing all gradient-norm metrics, used for logging. + """ + world_model = self._learn_model.world_model + grad_metrics = {} + + # Define the module groups to monitor. + module_groups = { + 'encoder': world_model.tokenizer.encoder, + 'transformer': world_model.transformer, + 'head_value': world_model.head_value, + 'head_reward': world_model.head_rewards, + 'head_policy': world_model.head_policy, + } + + for group_name, group_module in module_groups.items(): + total_grad_norm_sq = 0.0 + num_params_with_grad = 0 + + for param_name, param in group_module.named_parameters(): + if param.requires_grad and param.grad is not None: + # Gradient L2 norm of this parameter tensor. + grad_norm = param.grad.data.norm(2).item() + # Replace dots so the name shows up as a hierarchy in TensorBoard. + log_name = f'grad/{group_name}/{param_name.replace(".", "/")}' + grad_metrics[log_name] = grad_norm + total_grad_norm_sq += grad_norm ** 2 + num_params_with_grad += 1 + + # Total gradient norm of the whole module. + if num_params_with_grad > 0: + total_group_grad_norm = np.sqrt(total_grad_norm_sq) + grad_metrics[f'grad/{group_name}/_total_norm'] = total_group_grad_norm + else: + grad_metrics[f'grad/{group_name}/_total_norm'] = 0.0 + + return grad_metrics + # ================================================================= + def _init_learn(self) -> None: """ Overview: Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils.
""" - # NOTE: nanoGPT optimizer - self._optimizer_world_model = configure_optimizers_nanogpt( - model=self._model.world_model, - learning_rate=self._cfg.learning_rate, - weight_decay=self._cfg.weight_decay, - device_type=self._cfg.device, - betas=(0.9, 0.95), - ) + if self._cfg.optim_type == 'SGD': + # --- 改为SGD优化器 --- + self._optimizer_world_model = torch.optim.SGD( + self._model.world_model.parameters(), + lr=self._cfg.learning_rate, # 初始学习率,在配置中设为 0.2 + momentum=self._cfg.momentum, # 在配置中设为 0.9 + weight_decay=self._cfg.weight_decay # 在配置中设为 1e-4 + ) + elif self._cfg.optim_type == 'AdamW': + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + elif self._cfg.optim_type == 'AdamW_mix_lr_wdecay': + self._optimizer_world_model = configure_optimizer_unizero( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, # 使用一个合理的AdamW基础学习率 + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) if self._cfg.cos_lr_scheduler: from torch.optim.lr_scheduler import CosineAnnealingLR # TODO: check the total training steps - self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1) + # self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1) + total_iters = self._cfg.get('total_iterations', 500000) # 500k iter + # final_lr = self._cfg.get('final_learning_rate', 0.0) + final_lr = self._cfg.get('final_learning_rate', 1e-6) + + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, + T_max=total_iters, + eta_min=final_lr + ) + print(f"CosineAnnealingLR enabled: T_max={total_iters}, eta_min={final_lr}") + + + if self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. + lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer_world_model, lr_lambda=lr_lambda) # use model_wrapper for specialized demands of different modes self._target_model = copy.deepcopy(self._model) @@ -339,16 +565,19 @@ def _init_learn(self) -> None: self._cfg.augmentation, image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) + self.intermediate_losses = defaultdict(float) self.l2_norm_before = 0. self.l2_norm_after = 0. self.grad_norm_before = 0. self.grad_norm_after = 0. 
+ + encoder_tokenizer = getattr(self._model.tokenizer.encoder, 'tokenizer', None) + self.pad_token_id = encoder_tokenizer.pad_token_id if encoder_tokenizer is not None else 0 if self._cfg.use_wandb: # TODO: add the model to wandb @@ -356,6 +585,63 @@ self.accumulation_steps = self._cfg.accumulation_steps + # ==================== START: target-entropy regularization init ==================== + # Read from the config whether the adaptive alpha is enabled, with a default value. + self.use_adaptive_entropy_weight = self._cfg.get('use_adaptive_entropy_weight', True) + + # Annealing schedule for the target-entropy ratio, read in _init_learn. + self.target_entropy_start_ratio = self._cfg.get('target_entropy_start_ratio', 0.98) + self.target_entropy_end_ratio = self._cfg.get('target_entropy_end_ratio', 0.7) + self.target_entropy_decay_steps = self._cfg.get('target_entropy_decay_steps', 200000) # e.g., finish annealing within 200k train iterations (~2M env steps) + + if self.use_adaptive_entropy_weight: + # 1. Set the target entropy. For discrete action spaces, a common heuristic is -log(1/action_space_size) scaled by a coefficient. + # The coefficient (e.g., 0.98) can be treated as a hyperparameter. + action_space_size = self._cfg.model.action_space_size + self.target_entropy = -np.log(1.0 / action_space_size) * 0.98 + + # 2. Initialize a learnable log_alpha parameter. + # Initializing it to 0 means the initial alpha = exp(0) = 1.0. + self.log_alpha = torch.nn.Parameter(torch.zeros(1, device=self._cfg.device), requires_grad=True) + + # 3. Create a dedicated optimizer for log_alpha. + # A smaller learning rate than the main optimizer (e.g., 1e-4) is usually more stable. + alpha_lr = self._cfg.get('adaptive_entropy_alpha_lr', 1e-4) + self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr) + + print("="*20) + print(">>> Target-entropy regularization (adaptive alpha) enabled <<<") + print(f" Target entropy: {self.target_entropy:.4f}") + print(f" Alpha optimizer learning rate: {alpha_lr:.2e}") + print("="*20) + # ===================== END: target-entropy regularization init ===================== + + # ==================== START: Encoder-Clip annealing init ==================== + self.use_encoder_clip_annealing = self._cfg.get('use_encoder_clip_annealing', False) + self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 20.0) # TODO + if self.use_encoder_clip_annealing: + self.encoder_clip_anneal_type = self._cfg.get('encoder_clip_anneal_type', 'cosine') + self.encoder_clip_start = self._cfg.get('encoder_clip_start_value', 30.0) + self.encoder_clip_end = self._cfg.get('encoder_clip_end_value', 10.0) + self.encoder_clip_anneal_steps = self._cfg.get('encoder_clip_anneal_steps', 200000) + + print("="*20) + print(">>> Encoder-Clip annealing enabled <<<") + print(f" Type: {self.encoder_clip_anneal_type}") + print(f" Range: {self.encoder_clip_start} -> {self.encoder_clip_end}") + print(f" Steps: {self.encoder_clip_anneal_steps}") + print("="*20) + else: + # If annealing is disabled, use the fixed clip threshold. + self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 20.0) + # ===================== END: Encoder-Clip annealing init ===================== + + # --- NEW: Policy Label Smoothing Parameters --- + self.policy_ls_eps_start = self._cfg.get('policy_ls_eps_start', 0.05) # TODO: larger action spaces generally need a larger starting eps + self.policy_ls_eps_end = self._cfg.get('policy_ls_eps_end', 0.01) + self.policy_ls_eps_decay_steps = self._cfg.get('policy_ls_eps_decay_steps', 50000) # TODO: 50k by default + print(f"self.policy_ls_eps_start:{self.policy_ls_eps_start}") + # @profile def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: """ @@ -377,6 +663,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch
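The adaptive entropy weight initialized above is updated later in `_forward_learn` with a SAC-style temperature loss. The toy sketch below (illustrative values, not the actual training loop) isolates that update so the sign conventions are easy to verify: when the measured policy entropy falls below the annealed target, alpha grows and the entropy bonus gets stronger.

```Python
import numpy as np
import torch

action_space_size = 6
target_ratio = 0.98                      # annealed from 0.98 towards 0.7 during training
target_entropy = -np.log(1.0 / action_space_size) * target_ratio

log_alpha = torch.nn.Parameter(torch.zeros(1))
alpha_optimizer = torch.optim.Adam([log_alpha], lr=1e-4)

# Pretend this is the batch-mean policy entropy coming out of the policy loss computation.
policy_entropy = torch.tensor(1.2)

# detach() keeps the alpha gradient from flowing back into the policy network.
alpha_loss = (log_alpha * (policy_entropy.detach() - target_entropy)).mean()
alpha_optimizer.zero_grad()
alpha_loss.backward()
alpha_optimizer.step()

with torch.no_grad():
    log_alpha.clamp_(np.log(1e-4), np.log(10.0))   # safety clamp, as in the diff

current_alpha = log_alpha.exp().detach()           # weight applied to the entropy term
print(float(target_entropy), float(current_alpha))
```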
target_reward, target_value, target_policy = target_batch + # --- NEW: Calculate current epsilon for policy --- + if self.policy_ls_eps_start > 0: + progress = min(1.0, train_iter / self.policy_ls_eps_decay_steps) + current_policy_label_eps = self.policy_ls_eps_start * (1 - progress) + self.policy_ls_eps_end * progress + else: + current_policy_label_eps = 0.0 + # Prepare observations based on frame stack number if self._cfg.model.frame_stack_num > 1: obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) @@ -405,8 +698,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in transformed_target_value = scalar_transform(target_value) # Convert to categorical distributions - target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) - target_value_categorical = phi_transform(self.value_support, transformed_target_value) + # target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + # target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward, label_smoothing_eps= self._cfg.label_smoothing_eps) + target_value_categorical = phi_transform(self.value_support, transformed_target_value, label_smoothing_eps=self._cfg.label_smoothing_eps) # Prepare batch for GPT model batch_for_gpt = {} @@ -429,6 +725,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in batch_for_gpt['target_value'] = target_value_categorical[:, :-1] batch_for_gpt['target_policy'] = target_policy[:, :-1] + batch_for_gpt['scalar_target_value'] = target_value + # Extract valid target policy data and compute entropy valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) @@ -436,16 +734,83 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Update world model losses = self._learn_model.world_model.compute_loss( - batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle - ) - - weighted_total_loss = losses.loss_total - # 合并 intermediate_losses 字典,避免重复赋值 - # self.intermediate_losses.update(losses.intermediate_losses) + batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, global_step=train_iter, current_policy_label_eps=current_policy_label_eps, + ) # NOTE : compute_loss third argument is now a dead argument. If this changes, it could need adaptation between value_inverse and reward_inverse. + + # ==================== [修改] 集成范数监控逻辑 ==================== + norm_log_dict = {} + # 检查是否达到监控频率 + if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)): + with torch.no_grad(): + # 1. 监控模型参数范数 + param_norm_metrics = self._monitor_model_norms() + norm_log_dict.update(param_norm_metrics) + + # 2. 
监控中间张量 x (Transformer的输出) + intermediate_x = losses.intermediate_losses.get('intermediate_tensor_x') + if intermediate_x is not None: + # x 的形状为 (B, T, E) + # 计算每个 token 的 L2 范数 + token_norms = intermediate_x.norm(p=2, dim=-1) + + # 记录这些范数的统计数据 + norm_log_dict['norm/x_token/mean'] = token_norms.mean().item() + norm_log_dict['norm/x_token/std'] = token_norms.std().item() + norm_log_dict['norm/x_token/max'] = token_norms.max().item() + norm_log_dict['norm/x_token/min'] = token_norms.min().item() + + # 3. 监控 logits 的详细统计 (Value, Policy, Reward) + logits_value = losses.intermediate_losses.get('logits_value') + if logits_value is not None: + norm_log_dict['logits/value/mean'] = logits_value.mean().item() + norm_log_dict['logits/value/std'] = logits_value.std().item() + norm_log_dict['logits/value/max'] = logits_value.max().item() + norm_log_dict['logits/value/min'] = logits_value.min().item() + norm_log_dict['logits/value/abs_max'] = logits_value.abs().max().item() + + logits_policy = losses.intermediate_losses.get('logits_policy') + if logits_policy is not None: + norm_log_dict['logits/policy/mean'] = logits_policy.mean().item() + norm_log_dict['logits/policy/std'] = logits_policy.std().item() + norm_log_dict['logits/policy/max'] = logits_policy.max().item() + norm_log_dict['logits/policy/min'] = logits_policy.min().item() + norm_log_dict['logits/policy/abs_max'] = logits_policy.abs().max().item() + + logits_reward = losses.intermediate_losses.get('logits_reward') + if logits_reward is not None: + norm_log_dict['logits/reward/mean'] = logits_reward.mean().item() + norm_log_dict['logits/reward/std'] = logits_reward.std().item() + norm_log_dict['logits/reward/max'] = logits_reward.max().item() + norm_log_dict['logits/reward/min'] = logits_reward.min().item() + norm_log_dict['logits/reward/abs_max'] = logits_reward.abs().max().item() + + # 4. 监控 obs_embeddings (Encoder输出) 的统计 + obs_embeddings = losses.intermediate_losses.get('obs_embeddings') + if obs_embeddings is not None: + # 计算每个 embedding 的 L2 范数 + emb_norms = obs_embeddings.norm(p=2, dim=-1) + norm_log_dict['embeddings/obs/norm_mean'] = emb_norms.mean().item() + norm_log_dict['embeddings/obs/norm_std'] = emb_norms.std().item() + norm_log_dict['embeddings/obs/norm_max'] = emb_norms.max().item() + norm_log_dict['embeddings/obs/norm_min'] = emb_norms.min().item() + # ================================================================= + + # ==================== START MODIFICATION 2 ==================== + # Extract the calculated value_priority from the returned losses. + value_priority_tensor = losses.intermediate_losses['value_priority'] + # Convert to numpy array for the replay buffer, adding a small epsilon. 
+ value_priority_np = value_priority_tensor.detach().cpu().numpy() + 1e-6 + # ===================== END MODIFICATION 2 ===================== + + # weighted_total_loss = losses.loss_total + # TODO: + weighted_total_loss = (weights * losses.loss_total).mean() for loss_name, loss_value in losses.intermediate_losses.items(): self.intermediate_losses[f"{loss_name}"] = loss_value + # 从 losses 对象中提取策略熵 + obs_loss = self.intermediate_losses['loss_obs'] reward_loss = self.intermediate_losses['loss_rewards'] policy_loss = self.intermediate_losses['loss_policy'] @@ -467,6 +832,17 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in e_rank_sim_norm = self.intermediate_losses['e_rank_sim_norm'] latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms'] + latent_action_l2_norms = self.intermediate_losses['latent_action_l2_norms'] + logits_value_mean=self.intermediate_losses['logits_value_mean'] + logits_value_max=self.intermediate_losses['logits_value_max'] + logits_value_min=self.intermediate_losses['logits_value_min'] + logits_policy_mean=self.intermediate_losses['logits_policy_mean'] + logits_policy_max=self.intermediate_losses['logits_policy_max'] + logits_policy_min=self.intermediate_losses['logits_policy_min'] + temperature_value=self.intermediate_losses['temperature_value'] + temperature_reward=self.intermediate_losses['temperature_reward'] + temperature_policy=self.intermediate_losses['temperature_policy'] + assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" @@ -475,19 +851,107 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in if (train_iter % self.accumulation_steps) == 0: self._optimizer_world_model.zero_grad() + + # ==================== START: 目标熵正则化更新逻辑 ==================== + alpha_loss = None + current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight # 默认使用固定值 + if self.use_adaptive_entropy_weight: + # --- 动态计算目标熵 (这部分逻辑是正确的,予以保留) --- + progress = min(1.0, train_iter / self.target_entropy_decay_steps) + current_ratio = self.target_entropy_start_ratio * (1 - progress) + self.target_entropy_end_ratio * progress + action_space_size = self._cfg.model.action_space_size + # 注意:我们将 target_entropy 定义为正数,更符合直觉 + current_target_entropy = -np.log(1.0 / action_space_size) * current_ratio + + # --- 计算 alpha_loss (已修正符号) --- + # 这是核心修正点:去掉了最前面的负号 + # detach() 仍然是关键,确保 alpha_loss 的梯度只流向 log_alpha + alpha_loss = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean() + + # # --- 更新 log_alpha --- + self.alpha_optimizer.zero_grad() + alpha_loss.backward() + self.alpha_optimizer.step() + # --- [优化建议] 增加 log_alpha 裁剪作为安全措施 --- + with torch.no_grad(): + # 将 alpha 限制在例如 [1e-4, 10.0] 的范围内 + self.log_alpha.clamp_(np.log(1e-4), np.log(10.0)) + + # --- 使用当前更新后的 alpha (截断梯度流) --- + current_alpha = self.log_alpha.exp().detach() + + # 重新计算加权的策略损失和总损失 + # 注意:这里的 policy_entropy 已经是一个batch的平均值 + weighted_policy_loss = orig_policy_loss - current_alpha * policy_entropy + # 重新构建总损失 (不使用 losses.loss_total) + # 确保这里的权重与 LossWithIntermediateLosses 类中的计算方式一致 + self.obs_loss_weight = 10 + self.value_loss_weight = 0.5 + self.reward_loss_weight = 1. + self.policy_loss_weight = 1. + self.ends_loss_weight = 0. 
+ total_loss = ( + self.reward_loss_weight * reward_loss + + self.value_loss_weight * value_loss + + self.policy_loss_weight * weighted_policy_loss + + self.obs_loss_weight * obs_loss # 假设 ssl_loss_weight 是 obs_loss 的权重 + # ... 如果还有其他损失项,也加进来 ... + ) + weighted_total_loss = (weights * total_loss).mean() + # ===================== END: 目标熵正则化更新逻辑 ===================== + # Scale the loss by the number of accumulation steps weighted_total_loss = weighted_total_loss / self.accumulation_steps weighted_total_loss.backward() + # ----------------------------------------------------------------- + # 仍然在 torch.no_grad() 环境下执行 + # ================================================================= + with torch.no_grad(): + # 1. Encoder-Clip + # ==================== START: 动态计算当前 Clip 阈值 ==================== + current_clip_value = self.latent_norm_clip_threshold # 默认使用固定值 + if self.use_encoder_clip_annealing: + progress = min(1.0, train_iter / self.encoder_clip_anneal_steps) + + if self.encoder_clip_anneal_type == 'cosine': + # 余弦调度: 从1平滑过渡到0 + cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress)) + current_clip_value = self.encoder_clip_end + \ + (self.encoder_clip_start - self.encoder_clip_end) * cosine_progress + else: # 默认为线性调度 + current_clip_value = self.encoder_clip_start * (1 - progress) + \ + self.encoder_clip_end * progress + # ===================== END: 动态计算当前 Clip 阈值 ===================== + + # 1. Encoder-Clip (使用动态计算出的 current_clip_value) + if current_clip_value > 0 and 'obs_embeddings' in losses.intermediate_losses: + obs_embeddings = losses.intermediate_losses['obs_embeddings'] + if obs_embeddings is not None: + max_latent_norm = obs_embeddings.norm(p=2, dim=-1).max() + if max_latent_norm > current_clip_value: + scale_factor = current_clip_value / max_latent_norm.item() + # 不再频繁打印,或者可以改为每隔N步打印一次 + if train_iter % 1000 == 0: + print(f"[Encoder-Clip Annealing] Iter {train_iter}: Max latent norm {max_latent_norm.item():.2f} > {current_clip_value:.2f}. 
Scaling by {scale_factor:.4f}.") + scale_module_weights_vectorized(self._model.world_model.tokenizer.encoder, scale_factor) + # Check if the current iteration completes an accumulation cycle if (train_iter + 1) % self.accumulation_steps == 0: + # ==================== [新增] 监控梯度范数 ==================== + # 在梯度裁剪之前监控梯度范数,用于诊断梯度爆炸/消失问题 + if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)): + grad_norm_metrics = self._monitor_gradient_norms() + norm_log_dict.update(grad_norm_metrics) + # ================================================================= + # Analyze gradient norms if simulation normalization analysis is enabled if self._cfg.analysis_sim_norm: # Clear previous analysis results to prevent memory overflow del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze() self._target_model.encoder_hook.clear_data() - + # Clip gradients to prevent exploding gradients total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( self._learn_model.world_model.parameters(), self._cfg.grad_clip_value @@ -554,15 +1018,17 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'target_policy_entropy': average_target_policy_entropy.item(), 'reward_loss': reward_loss.item(), 'value_loss': value_loss.item(), - # 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO + # Add value_priority to the log dictionary. + 'value_priority': value_priority_np.mean().item(), + 'value_priority_orig': value_priority_np, 'target_reward': target_reward.mean().item(), 'target_value': target_value.mean().item(), 'transformed_target_reward': transformed_target_reward.mean().item(), 'transformed_target_value': transformed_target_value.mean().item(), 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), - 'analysis/dormant_ratio_encoder': dormant_ratio_encoder, #.item(), - 'analysis/dormant_ratio_transformer': dormant_ratio_transformer,#.item(), - 'analysis/dormant_ratio_head': dormant_ratio_head,#.item(), + 'analysis/dormant_ratio_encoder': dormant_ratio_encoder, + 'analysis/dormant_ratio_transformer': dormant_ratio_transformer, + 'analysis/dormant_ratio_head': dormant_ratio_head, 'analysis/avg_weight_mag_encoder': avg_weight_mag_encoder, 'analysis/avg_weight_mag_transformer': avg_weight_mag_transformer, @@ -571,12 +1037,42 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'analysis/e_rank_sim_norm': e_rank_sim_norm, 'analysis/latent_state_l2_norms': latent_state_l2_norms.item(), + 'analysis/latent_action_l2_norms': latent_action_l2_norms, 'analysis/l2_norm_before': self.l2_norm_before, 'analysis/l2_norm_after': self.l2_norm_after, 'analysis/grad_norm_before': self.grad_norm_before, 'analysis/grad_norm_after': self.grad_norm_after, + "logits_value_mean":logits_value_mean, + "logits_value_max":logits_value_max, + "logits_value_min":logits_value_min, + "logits_policy_mean":logits_policy_mean, + "logits_policy_max":logits_policy_max, + "logits_policy_min":logits_policy_min, + + "temperature_value":temperature_value, + "temperature_reward":temperature_reward, + "temperature_policy":temperature_policy, + + "current_policy_label_eps":current_policy_label_eps, } - + + # ==================== [修改] 将范数监控结果合并到日志中 ==================== + if norm_log_dict: + return_log_dict.update(norm_log_dict) + # 
======================================================================= + + # ==================== START: 添加新日志项 ==================== + if self.use_adaptive_entropy_weight: + return_log_dict['adaptive_alpha'] = current_alpha.item() + return_log_dict['adaptive_target_entropy_ratio'] = current_ratio + return_log_dict['alpha_loss'] = alpha_loss.item() + # ==================== START: 添加新日志项 ==================== + + # ==================== START: 添加新日志项 ==================== + if self.use_encoder_clip_annealing: + return_log_dict['current_encoder_clip_value'] = current_clip_value + # ===================== END: 添加新日志项 ===================== + if self._cfg.use_wandb: wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step) wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step) @@ -598,11 +1094,13 @@ def _init_collect(self) -> None: Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. """ self._collect_model = self._model - + # 为 collect MCTS 创建一个配置副本,并设置特定的模拟次数 + mcts_collect_cfg = copy.deepcopy(self._cfg) + mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(self._cfg) + self._mcts_collect = MCTSCtree(mcts_collect_cfg) else: - self._mcts_collect = MCTSPtree(self._cfg) + self._mcts_collect = MCTSPtree(mcts_collect_cfg) self._collect_mcts_temperature = 1. self._collect_epsilon = 0.0 self.collector_env_num = self._cfg.collector_env_num @@ -610,7 +1108,9 @@ def _init_collect(self) -> None: self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) self.last_batch_action = [-1 for i in range(self.collector_env_num)] elif self._cfg.model.model_type == 'mlp': - self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_obs = torch.full( + [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id, + ).to(self._cfg.device) self.last_batch_action = [-1 for i in range(self.collector_env_num)] # @profile @@ -664,7 +1164,7 @@ def _forward_collect( network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() @@ -682,12 +1182,14 @@ def _forward_collect( roots = MCTSPtree.roots(active_collect_env_num, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) - self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep) + next_latent_state_with_env = self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep) + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} + batch_action = [] for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] @@ -710,6 +1212,14 @@ def _forward_collect( # NOTE: Convert 
the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + next_latent_state = next_latent_state_with_env[i][action] + + if self._cfg.model.world_model_cfg.obs_type == 'text' and self._cfg.model.world_model_cfg.decode_loss_mode is not None and self._cfg.model.world_model_cfg.decode_loss_mode.lower() != 'none': + # Output the plain text content decoded by the decoder from the next latent state + predicted_next = self._collect_model.tokenizer.decode_to_plain_text(embeddings=next_latent_state, max_length=256) + else: + predicted_next = None + # ============== TODO: only for visualize ============== # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( # distributions, temperature=self._collect_mcts_temperature, deterministic=True @@ -724,22 +1234,33 @@ def _forward_collect( 'searched_value': value, 'predicted_value': pred_values[i], 'predicted_policy_logits': policy_logits[i], - 'timestep': timestep[i] + 'timestep': timestep[i], + 'predicted_next_text': predicted_next, } batch_action.append(action) self.last_batch_obs = data self.last_batch_action = batch_action - # ========= TODO: for muzero_segment_collector now ========= + # ========= TODO: This logic is a temporary workaround specific to the muzero_segment_collector. ========= if active_collect_env_num < self.collector_env_num: - # 当collect_env中有一个环境先done时,传回的self.last_batch_obs的长度会减少1, transformer在检索kv_cache时需要知道env_id,实现比较复杂 - # 因此直接《self.collector_env_num》个环境的self.last_batch_action全部重置为-1,让transformer从0开始,避免检索错误 - print('==========collect_forward============') - print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + # When an environment finishes an episode ('done'), the length of `self.last_batch_obs` passed back + # becomes smaller than the total number of collector environments. + # Handling this dynamic batch size is complex, as the transformer's KV cache retrieval + # requires a stable environment ID for correct indexing. A mismatch would cause retrieval errors. + # + # Therefore, as a simpler solution, we reset the collection state for ALL environments. + # By resetting `self.last_batch_action` to -1 for all `self.collector_env_num` environments, + # we force the transformer to start its context from scratch, avoiding incorrect cache lookups. + print('========== collect_forward ============') + print(f'An environment has finished. Active envs: {active_collect_env_num} < Total envs: {self.collector_env_num}. Resetting all.') + self._reset_collect(reset_init_data=True) + + # If the sampling type is 'episode', it's unexpected for the number of active environments to drop, + # as this suggests an inconsistent state or a potential issue in the collection logic. if getattr(self._cfg, 'sample_type', '') == 'episode': - print('BUG: sample_type is episode, but len(self.last_batch_obs) < self.collector_env_num') + print('WARNING: Inconsistent state detected. `sample_type` is "episode", but the number of active environments has changed.') return output @@ -749,18 +1270,26 @@ def _init_eval(self) -> None: Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. 
""" self._eval_model = self._model + + # 为 eval MCTS 创建一个配置副本,并设置特定的模拟次数 + mcts_eval_cfg = copy.deepcopy(self._cfg) + mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations + if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(self._cfg) + self._mcts_eval = MCTSCtree(mcts_eval_cfg) else: - self._mcts_eval = MCTSPtree(self._cfg) + self._mcts_eval = MCTSPtree(mcts_eval_cfg) + self.evaluator_env_num = self._cfg.evaluator_env_num if self._cfg.model.model_type == 'conv': - self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) - self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] elif self._cfg.model.model_type == 'mlp': - self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) - self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + self.last_batch_obs = torch.full( + [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id, + ).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None, timestep: List = [0], task_id: int = None,) -> Dict: @@ -799,7 +1328,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) latent_state_roots = latent_state_roots.detach().cpu().numpy() policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) @@ -811,14 +1340,14 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 # python mcts_tree roots = MCTSPtree.roots(active_eval_env_num, legal_actions) roots.prepare_no_noise(reward_roots, policy_logits, to_play) - self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep) + next_latent_state_with_env = self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep) # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} batch_action = [] - + for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] # print("roots_visit_count_distributions:", distributions, "root_value:", value) @@ -834,6 +1363,15 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 # entire action set. 
action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # Predict the next latent state based on the selected action and policy + next_latent_state = next_latent_state_with_env[i][action] + + if self._cfg.model.world_model_cfg.obs_type == 'text' and self._cfg.model.world_model_cfg.decode_loss_mode is not None and self._cfg.model.world_model_cfg.decode_loss_mode.lower() != 'none': + # Output the plain text content decoded by the decoder from the next latent state + predicted_next = self._eval_model.tokenizer.decode_to_plain_text(embeddings=next_latent_state, max_length=256) + else: + predicted_next = None + output[env_id] = { 'action': action, 'visit_count_distributions': distributions, @@ -841,7 +1379,8 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 'searched_value': value, 'predicted_value': pred_values[i], 'predicted_policy_logits': policy_logits[i], - 'timestep': timestep[i] + 'timestep': timestep[i], + 'predicted_next_text': predicted_next, } batch_action.append(action) @@ -863,22 +1402,39 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. """ if reset_init_data: - self.last_batch_obs = initialize_zeros_batch( + self.last_batch_obs = initialize_pad_batch( self._cfg.model.observation_shape, self._cfg.collector_env_num, - self._cfg.device + self._cfg.device, + pad_token_id=self.pad_token_id ) self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] - # Return immediately if env_id is None or a list - if env_id is None or isinstance(env_id, list): - return + # We must handle both single int and list of ints for env_id. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the collector. + if current_steps is None: + world_model = self._collect_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. 
+ if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end.') + + # ======== TODO: 20251015 ======== # Determine the clear interval based on the environment's sample type - clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length # Clear caches if the current steps are a multiple of the clear interval - if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: print(f'clear_interval: {clear_interval}') # Clear various caches in the collect model's world model @@ -891,8 +1447,7 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in # Free up GPU memory torch.cuda.empty_cache() - print('collector: collect_model clear()') - print(f'eps_steps_lst[{env_id}]: {current_steps}') + print(f'eps_steps_lst[{env_id}]: {current_steps}, collector: collect_model clear()') def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True, task_id: int = None) -> None: """ @@ -911,29 +1466,57 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_ self.last_batch_obs_eval = initialize_zeros_batch( self._cfg.model.observation_shape_list[task_id], self._cfg.evaluator_env_num, - self._cfg.device + self._cfg.device, + pad_token_id=self.pad_token_id ) print(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) else: - self.last_batch_obs_eval = initialize_zeros_batch( + self.last_batch_obs_eval = initialize_pad_batch( # TODO self._cfg.model.observation_shape, self._cfg.evaluator_env_num, - self._cfg.device + self._cfg.device, + pad_token_id=self.pad_token_id ) print(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] - # Return immediately if env_id is None or a list - if env_id is None or isinstance(env_id, list): - return + # --- BEGIN ROBUST FIX --- + # This logic handles the crucial end-of-episode cache clearing for evaluation. + # The evaluator calls `_policy.reset([env_id])` when an episode is done. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id - # Determine the clear interval based on the environment's sample type - clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + # The key condition: `current_steps` is None only on the end-of-episode reset call from the evaluator. + if current_steps is None: + world_model = self._eval_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end.') + + # The recurrent cache is global. 
+ world_model.past_kv_cache_recurrent_infer.clear() + + if hasattr(world_model, 'keys_values_wm_list'): + world_model.keys_values_wm_list.clear() + + torch.cuda.empty_cache() + return + # --- END ROBUST FIX --- + # ======== TODO: 20251015 ======== + # Determine the clear interval based on the environment's sample type + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length # Clear caches if the current steps are a multiple of the clear interval - if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: print(f'clear_interval: {clear_interval}') # Clear various caches in the eval model's world model @@ -955,64 +1538,142 @@ def _monitor_vars_learn(self) -> List[str]: Register the variables to be monitored in learn mode. The registered variables will be logged in tensorboard according to the return value ``_forward_learn``. """ - return [ + base_vars = [ + # ==================== Analysis Metrics ==================== 'analysis/dormant_ratio_encoder', 'analysis/dormant_ratio_transformer', 'analysis/dormant_ratio_head', - 'analysis/avg_weight_mag_encoder', 'analysis/avg_weight_mag_transformer', 'analysis/avg_weight_mag_head', 'analysis/e_rank_last_linear', 'analysis/e_rank_sim_norm', - 'analysis/latent_state_l2_norms', + 'analysis/latent_action_l2_norms', 'analysis/l2_norm_before', 'analysis/l2_norm_after', 'analysis/grad_norm_before', 'analysis/grad_norm_after', + # ==================== Step-wise Loss Analysis ==================== 'analysis/first_step_loss_value', 'analysis/first_step_loss_policy', 'analysis/first_step_loss_rewards', 'analysis/first_step_loss_obs', - 'analysis/middle_step_loss_value', 'analysis/middle_step_loss_policy', 'analysis/middle_step_loss_rewards', 'analysis/middle_step_loss_obs', - 'analysis/last_step_loss_value', 'analysis/last_step_loss_policy', 'analysis/last_step_loss_rewards', 'analysis/last_step_loss_obs', + # ==================== System Metrics ==================== 'Current_GPU', 'Max_GPU', 'collect_epsilon', 'collect_mcts_temperature', 'cur_lr_world_model', - 'cur_lr_tokenizer', + # ==================== Core Losses ==================== 'weighted_total_loss', 'obs_loss', 'policy_loss', 'orig_policy_loss', 'policy_entropy', 'latent_recon_loss', + 'perceptual_loss', 'target_policy_entropy', 'reward_loss', 'value_loss', - 'consistency_loss', 'value_priority', 'target_reward', 'target_value', + 'transformed_target_reward', + 'transformed_target_value', + + # ==================== Gradient Norms ==================== 'total_grad_norm_before_clip_wm', - # tokenizer - 'commitment_loss', - 'reconstruction_loss', - 'perceptual_loss', + + # ==================== Logits Statistics ==================== + 'logits_value_mean', + 'logits_value_max', + 'logits_value_min', + 'logits_policy_mean', + 'logits_policy_max', + 'logits_policy_min', + + # ==================== Temperature Parameters ==================== + 'temperature_value', + 'temperature_reward', + 'temperature_policy', + + # ==================== Training Configuration ==================== + 'current_policy_label_eps', + 'adaptive_alpha', + 'adaptive_target_entropy_ratio', + 'alpha_loss', + 'current_encoder_clip_value', + ] + + # ==================== [新增] 范数和中间张量监控变量 ==================== + norm_vars = [ + # 模块总范数 (参数范数) + 'norm/encoder/_total_norm', + 'norm/transformer/_total_norm', + 'norm/head_value/_total_norm', + 
'norm/head_reward/_total_norm', + 'norm/head_policy/_total_norm', + + # 模块总范数 (梯度范数) + 'grad/encoder/_total_norm', + 'grad/transformer/_total_norm', + 'grad/head_value/_total_norm', + 'grad/head_reward/_total_norm', + 'grad/head_policy/_total_norm', + + # 中间张量 x (Transformer输出) 的统计信息 + 'norm/x_token/mean', + 'norm/x_token/std', + 'norm/x_token/max', + 'norm/x_token/min', + + # Logits 的详细统计 (Value) + 'logits/value/mean', + 'logits/value/std', + 'logits/value/max', + 'logits/value/min', + 'logits/value/abs_max', + + # Logits 的详细统计 (Policy) + 'logits/policy/mean', + 'logits/policy/std', + 'logits/policy/max', + 'logits/policy/min', + 'logits/policy/abs_max', + + # Logits 的详细统计 (Reward) + 'logits/reward/mean', + 'logits/reward/std', + 'logits/reward/max', + 'logits/reward/min', + 'logits/reward/abs_max', + + # Embeddings 的统计信息 + 'embeddings/obs/norm_mean', + 'embeddings/obs/norm_std', + 'embeddings/obs/norm_max', + 'embeddings/obs/norm_min', ] + # 注意:我们不把每一层的范数都加到这里,因为数量太多会导致日志混乱。 + # 在实践中,如果通过总范数发现问题,可以临时在TensorBoard中搜索特定层的范数, + # 或者在本地打印 `norm_log_dict` 来进行详细分析。 + # wandb等工具可以更好地处理大量的动态指标。 + # ======================================================================== + + return base_vars + norm_vars + def _state_dict_learn(self) -> Dict[str, Any]: """ @@ -1021,11 +1682,16 @@ def _state_dict_learn(self) -> Dict[str, Any]: Returns: - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. """ - return { + state_dict = { 'model': self._learn_model.state_dict(), 'target_model': self._target_model.state_dict(), 'optimizer_world_model': self._optimizer_world_model.state_dict(), } + # ==================== START: 保存Alpha优化器状态 ==================== + if self.use_adaptive_entropy_weight: + state_dict['alpha_optimizer'] = self.alpha_optimizer.state_dict() + # ===================== END: 保存Alpha优化器状态 ===================== + return state_dict def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: """ @@ -1036,7 +1702,12 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: """ self._learn_model.load_state_dict(state_dict['model']) self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + # ==================== START: 加载Alpha优化器状态 ==================== + # if self.use_adaptive_entropy_weight and 'alpha_optimizer' in state_dict: + # self.alpha_optimizer.load_state_dict(state_dict['alpha_optimizer']) + # ===================== END: 加载Alpha优化器状态 ===================== def recompute_pos_emb_diff_and_clear_cache(self) -> None: """ diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py index 52469d1eb..cbf605a1e 100644 --- a/lzero/policy/unizero_multitask.py +++ b/lzero/policy/unizero_multitask.py @@ -13,108 +13,329 @@ from lzero.policy import prepare_obs_stack_for_unizero from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs -from lzero.policy.unizero import UniZeroPolicy +from lzero.policy.unizero import UniZeroPolicy, scale_module_weights_vectorized from .utils import configure_optimizers_nanogpt import sys -sys.path.append('/fs-computility/ai-shen/puyuan/code/LibMTL') +# Please replace the path with the actual location of your LibMTL library. 
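+# A more portable option (a sketch, not part of the original code, requires `import os`) is to read
+# the location from an environment variable, e.g.
+#   sys.path.append(os.environ.get("LIBMTL_PATH", "/path/to/your/LibMTL"))
+# so the hard-coded path does not have to be edited on every machine.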
+sys.path.append('/path/to/your/LibMTL') + from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect -# from LibMTL.weighting.CAGrad_unizero import CAGrad as GradCorrect +from LibMTL.weighting.moco_fast_mem_eff import FastMoCoMemEff as FastMoCo +from LibMTL.weighting.moco_fast_mem_eff import MoCoCfg + +import torch.distributed as dist -# from LibMTL.weighting.abstract_weighting import AbsWeighting +# ------------------------------------------------------------ +# 1. Add a dedicated process-group for the learner. +# (This function should be called once during the initialization of the main process or the learner.) +# ------------------------------------------------------------ +def build_learner_group(learner_ranks: list[int]) -> dist.ProcessGroup: + """ + Overview: + Builds and returns a new process group containing only the learner ranks. + This is used for methods like GenericMoCo that require collective communication + only among the ranks performing training. + Arguments: + - learner_ranks (:obj:`list[int]`): A list of world ranks that are designated as learners. + These are the ranks that will perform the backward pass. + e.g., if CUDA_VISIBLE_DEVICES=0,1, then learner_ranks=[0,1]. + Returns: + - pg (:obj:`dist.ProcessGroup`): A new process group containing only the learner ranks. + """ + world_pg = dist.group.WORLD + pg = dist.new_group(ranks=learner_ranks, backend='nccl') + if dist.get_rank() in learner_ranks: + torch.cuda.set_device(learner_ranks.index(dist.get_rank())) + return pg -def generate_task_loss_dict(multi_task_losses, task_name_template, task_id): +def generate_task_loss_dict(multi_task_losses: List[Union[torch.Tensor, float]], task_name_template: str, task_id: int) -> Dict[str, float]: """ - 生成每个任务的损失字典 - :param multi_task_losses: 包含每个任务损失的列表 - :param task_name_template: 任务名称模板,例如 'obs_loss_task{}' - :return: 一个字典,包含每个任务的损失 + Overview: + Generates a dictionary for the losses of each task. + Arguments: + - multi_task_losses (:obj:`List[Union[torch.Tensor, float]]`): A list containing the loss for each task. + - task_name_template (:obj:`str`): The template for the task name, e.g., 'obs_loss_task{}'. + - task_id (:obj:`int`): The starting ID of the tasks. + Returns: + - task_loss_dict (:obj:`Dict[str, float]`): A dictionary where keys are formatted task names and values are the corresponding losses. """ task_loss_dict = {} for task_idx, task_loss in enumerate(multi_task_losses): task_name = task_name_template.format(task_idx + task_id) try: + # Get the scalar value of the loss if it's a tensor. task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss except Exception as e: task_loss_dict[task_name] = task_loss return task_loss_dict +# # 修改后的函数: +# def generate_task_loss_dict( +# multi_task_losses: List[Union[torch.Tensor, float]], +# task_name_template: str, +# global_task_ids: List[int] +# ) -> Dict[str, float]: +# """ +# Overview: +# Generates a dictionary for the losses of each task using their explicit global IDs. +# Arguments: +# - multi_task_losses (:obj:`List[Union[torch.Tensor, float]]`): A list containing the loss for each task. +# - task_name_template (:obj:`str`): The template for the task name, e.g., 'obs_loss_task{}'. +# - global_task_ids (:obj:`List[int]`): A list of global task IDs corresponding to each loss in multi_task_losses. +# Returns: +# - task_loss_dict (:obj:`Dict[str, float]`): A dictionary where keys are formatted task names and values are the corresponding losses. 
+# """ +# task_loss_dict = {} +# # 使用 zip 将每个损失与其正确的全局ID配对 +# for task_loss, global_id in zip(multi_task_losses, global_task_ids): +# task_name = task_name_template.format(global_id) +# try: +# task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss +# except Exception as e: +# task_loss_dict[task_name] = task_loss +# return task_loss_dict class WrappedModel: - def __init__(self, world_model): + """ + Overview: + A wrapper class for the world model to conveniently access its parameters and zero its gradients. + This version wraps the entire world model. + """ + def __init__(self, world_model: torch.nn.Module): + """ + Arguments: + - world_model (:obj:`torch.nn.Module`): The world model instance. + """ self.world_model = world_model - def parameters(self): - # 返回 tokenizer, transformer 以及所有嵌入层的参数 + def parameters(self) -> iter: + """ + Overview: + Returns an iterator over the parameters of the entire world model. + """ return self.world_model.parameters() - def zero_grad(self, set_to_none=False): - # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all world model parameters to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. + """ self.world_model.zero_grad(set_to_none=set_to_none) class WrappedModelV2: - def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_table): + """ + Overview: + A wrapper for specific components of the world model. + This version is designed to group parameters that are considered "shared" + across tasks for gradient correction methods like MoCo, excluding the prediction heads. + """ + def __init__(self, tokenizer: torch.nn.Module, transformer: torch.nn.Module, pos_emb: torch.nn.Module, task_emb: torch.nn.Module, act_embedding_table: torch.nn.Module): + """ + Arguments: + - tokenizer (:obj:`torch.nn.Module`): The tokenizer module. + - transformer (:obj:`torch.nn.Module`): The transformer backbone. + - pos_emb (:obj:`torch.nn.Module`): The positional embedding module. + - task_emb (:obj:`torch.nn.Module`): The task embedding module. + - act_embedding_table (:obj:`torch.nn.Module`): The action embedding table. + """ self.tokenizer = tokenizer self.transformer = transformer self.pos_emb = pos_emb self.task_emb = task_emb self.act_embedding_table = act_embedding_table - def parameters(self): - # 返回 tokenizer, transformer 以及所有嵌入层的参数 + def parameters(self) -> iter: + """ + Overview: + Returns an iterator over the parameters of the wrapped components (tokenizer, transformer, embeddings). + These are typically the shared parts of the model whose gradients need to be managed for multi-task learning. + """ return (list(self.tokenizer.parameters()) + list(self.transformer.parameters()) + list(self.pos_emb.parameters()) + - # list(self.task_emb.parameters()) + # TODO + # list(self.task_emb.parameters()) + # TODO: Decide whether to include task embeddings in shared parameters. list(self.act_embedding_table.parameters())) - def zero_grad(self, set_to_none=False): - # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all wrapped components to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. 
+ """ self.tokenizer.zero_grad(set_to_none=set_to_none) self.transformer.zero_grad(set_to_none=set_to_none) self.pos_emb.zero_grad(set_to_none=set_to_none) - # self.task_emb.zero_grad(set_to_none=set_to_none) # TODO + # self.task_emb.zero_grad(set_to_none=set_to_none) # TODO: Match the decision made in the parameters() method. self.act_embedding_table.zero_grad(set_to_none=set_to_none) class WrappedModelV3: - def __init__(self, transformer, pos_emb, task_emb, act_embedding_table): + """ + Overview: + An alternative wrapper for world model components. + This version excludes the tokenizer from the shared parameters, focusing gradient correction + on the transformer and embedding layers. + """ + def __init__(self, transformer: torch.nn.Module, pos_emb: torch.nn.Module, task_emb: torch.nn.Module, act_embedding_table: torch.nn.Module): + """ + Arguments: + - transformer (:obj:`torch.nn.Module`): The transformer backbone. + - pos_emb (:obj:`torch.nn.Module`): The positional embedding module. + - task_emb (:obj:`torch.nn.Module`): The task embedding module. + - act_embedding_table (:obj:`torch.nn.Module`): The action embedding table. + """ self.transformer = transformer self.pos_emb = pos_emb self.task_emb = task_emb self.act_embedding_table = act_embedding_table - def parameters(self): - # 返回 tokenizer, transformer 以及所有嵌入层的参数 + def parameters(self) -> iter: + """ + Overview: + Returns an iterator over the parameters of the transformer and various embedding layers. + """ return (list(self.transformer.parameters()) + list(self.pos_emb.parameters()) + list(self.task_emb.parameters()) + list(self.act_embedding_table.parameters())) - def zero_grad(self, set_to_none=False): - # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 - # self.tokenizer.zero_grad(set_to_none=set_to_none) + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of the wrapped components to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. + """ self.transformer.zero_grad(set_to_none=set_to_none) self.pos_emb.zero_grad(set_to_none=set_to_none) self.task_emb.zero_grad(set_to_none=set_to_none) self.act_embedding_table.zero_grad(set_to_none=set_to_none) +# def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas): +# """ +# 为UniZero模型配置带有差异化学习率的优化器。 +# """ +# # 1. 定义需要特殊处理的参数 +# param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + +# # 2. 将参数分为三组:Transformer主干、Tokenizer、Heads +# transformer_params = {pn: p for pn, p in param_dict.items() if 'transformer' in pn} +# tokenizer_params = {pn: p for pn, p in param_dict.items() if 'tokenizer' in pn} + +# # Heads的参数是那些既不属于transformer也不属于tokenizer的 +# head_params = { +# pn: p for pn, p in param_dict.items() +# if 'transformer' not in pn and 'tokenizer' not in pn +# } + +# # 3. 
为每组设置不同的优化器参数(特别是学习率) +# # 这里我们仍然使用AdamW,但学习率设置更合理 +# optim_groups = [ +# { +# 'params': list(transformer_params.values()), +# 'lr': learning_rate, # 1e-4 +# # 'lr': learning_rate * 0.2, # 为Transformer主干设置一个较小的学习率,例如 1e-5 +# 'weight_decay': weight_decay +# # 'weight_decay': weight_decay * 5.0 +# }, +# { +# 'params': list(tokenizer_params.values()), +# 'lr': learning_rate, # Tokenizer使用基础学习率,例如 1e-4 +# # 'lr': learning_rate * 0.1, # 为encoder设置一个较小的学习率,例如 1e-5 +# 'weight_decay': weight_decay * 5.0 # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化 + +# }, +# { +# 'params': list(head_params.values()), +# 'lr': learning_rate, # Heads也使用基础学习率率,例如 1e-4 +# 'weight_decay': 0.0 # 通常Heads的权重不做衰减 +# # 'weight_decay': weight_decay + +# } +# ] + +# print("--- Optimizer Groups ---") +# print(f"Transformer LR: {learning_rate}") +# print(f"Tokenizer/Heads LR: {learning_rate}") + +# optimizer = torch.optim.AdamW(optim_groups, betas=betas) +# return optimizer + +def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas): + """ + 为UniZero模型配置带有差异化学习率的优化器。 + (修正版,确保参数组互斥) + """ + # 1. 创建空的参数列表用于分组 + transformer_params = [] + tokenizer_params = [] + head_params = [] + + # 2. 遍历所有可训练参数,并使用 if/elif/else 结构确保每个参数只被分配到一个组 + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if 'transformer' in name: + transformer_params.append(param) + elif 'tokenizer' in name: + tokenizer_params.append(param) + else: + head_params.append(param) + + # 3. 为每组设置不同的优化器参数 + # 这里我们仍然使用AdamW,但学习率设置更合理 + optim_groups = [ + { + 'params': transformer_params, + 'lr': learning_rate, # 1e-4 + 'weight_decay': weight_decay + }, + { + 'params': tokenizer_params, + 'lr': learning_rate, # Tokenizer使用基础学习率,例如 1e-4 + # 'weight_decay': weight_decay * 5.0 # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化 + 'weight_decay': weight_decay # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化 + }, + { + 'params': head_params, + 'lr': learning_rate, # Heads也使用基础学习率率,例如 1e-4 + # 'weight_decay': 0.0 # 通常Heads的权重不做衰减 + 'weight_decay': weight_decay + + } + ] + + print("--- Optimizer Groups ---") + # 打印每个组的参数数量以供调试 + print(f"Transformer params: {len(transformer_params)}") + print(f"Tokenizer params: {len(tokenizer_params)}") + print(f"Head params: {len(head_params)}") + print(f"Transformer LR: {learning_rate}") + print(f"Tokenizer/Heads LR: {learning_rate}") + + optimizer = torch.optim.AdamW(optim_groups, betas=betas) + return optimizer @POLICY_REGISTRY.register('unizero_multitask') class UniZeroMTPolicy(UniZeroPolicy): """ Overview: - The policy class for UniZero, official implementation for paper UniZero: Generalized and Efficient Planning - with Scalable LatentWorld Models. UniZero aims to enhance the planning capabilities of reinforcement learning agents - by addressing the limitations found in MuZero-style algorithms, particularly in environments requiring the - capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + The policy class for multi-task UniZero, an official implementation for the paper "UniZero: Generalized and Efficient Planning + with Scalable Latent World Models". UniZero aims to enhance the planning capabilities of reinforcement learning agents + by addressing the limitations of MuZero-style algorithms, particularly in environments requiring the + capture of long-term dependencies. More details can be found at: https://arxiv.org/abs/2406.10667. """ - # The default_config for UniZero policy. + # The default_config for UniZero multi-task policy. 
config = dict( type='unizero_multitask', model=dict( @@ -144,7 +365,7 @@ class UniZeroMTPolicy(UniZeroPolicy): # (bool) whether to use res connection in dynamics. res_connection_in_dynamics=True, # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'BN'. - norm_type='LN', # NOTE: TODO + norm_type='LN', # NOTE: LayerNorm is used in the transformer-based world model. # (bool) Whether to analyze simulation normalization. analysis_sim_norm=False, # (int) The save interval of the model. @@ -169,7 +390,7 @@ class UniZeroMTPolicy(UniZeroPolicy): # (int) The shape of the action space. action_space_size=6, # (int) The size of the group, related to simulation normalization. - group_size=8, # NOTE: sim_norm + group_size=8, # NOTE: for sim_norm # (str) The type of attention mechanism used. Options could be ['causal']. attention='causal', # (int) The number of layers in the model. @@ -205,7 +426,8 @@ class UniZeroMTPolicy(UniZeroPolicy): # (bool) Whether to analyze dormant ratio, average_weight_magnitude of net, effective_rank of latent. analysis_dormant_ratio_weight_rank=False, # (float) The threshold for a dormant neuron. - dormant_threshold=0.025, + dormant_threshold=0.01, + ), ), # ****** common ****** @@ -275,6 +497,10 @@ class UniZeroMTPolicy(UniZeroPolicy): optim_type='AdamW', # (float) Learning rate for training policy network. Initial lr for manually decay schedule. learning_rate=0.0001, + # ==================== [新增] 范数监控频率 ==================== + # 每隔多少个训练迭代步数,监控一次模型参数的范数。设置为0则禁用。 + monitor_norm_freq=5000, + # ============================================================ # (int) Frequency of hard target network update. target_update_freq=100, # (int) Frequency of soft target network update. @@ -291,8 +517,12 @@ class UniZeroMTPolicy(UniZeroPolicy): n_episode=8, # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. num_segments=8, - # (int) the number of simulations in MCTS. + # # (int) the number of simulations in MCTS for renalyze. num_simulations=50, + # (int) The number of simulations in MCTS for the collect phase. + collect_num_simulations=25, + # (int) The number of simulations in MCTS for the eval phase. + eval_num_simulations=50, # (float) Discount factor (gamma) for returns. discount_factor=0.997, # (int) The number of steps for calculating target q_value. @@ -364,52 +594,188 @@ class UniZeroMTPolicy(UniZeroPolicy): def default_model(self) -> Tuple[str, List[str]]: """ Overview: - Return this algorithm default model setting for demonstration. + Return this algorithm's default model setting for demonstration. Returns: - - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. - - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. - - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + - model_info (:obj:`Tuple[str, List[str]]`): A tuple containing the model name and a list of import paths. + - model_type (:obj:`str`): The model type used in this algorithm, registered in ModelRegistry. + - import_names (:obj:`List[str]`): The list of model class paths used in this algorithm. .. note:: - The user can define and use customized network model but must obey the same interface definition indicated \ - by import_names path. 
For MuZero, ``lzero.model.unizero_model.MuZeroModel`` + Users can define and use customized network models, but they must adhere to the same interface definition + as indicated by the import_names path. For multi-task UniZero, this is ``lzero.model.unizero_model_multitask.UniZeroMTModel``. """ - # NOTE: multi-task model + # NOTE: This specifies the default multi-task model. return 'UniZeroMTModel', ['lzero.model.unizero_model_multitask'] + # ==================== [新增] 模型范数监控函数 ==================== + def _monitor_model_norms(self) -> Dict[str, float]: + """ + Overview: + 计算并返回模型关键组件(Encoder, Transformer, Heads)的参数矩阵范数。 + 此函数应在 torch.no_grad() 环境下调用,以提高效率。 + Returns: + - norm_metrics (:obj:`Dict[str, float]`): 包含所有范数指标的字典,用于日志记录。 + """ + world_model = self._learn_model.world_model + norm_metrics = {} + + # 定义要监控的模块组 + module_groups = { + 'encoder': world_model.tokenizer.encoder, + 'transformer': world_model.transformer, + 'head_value': world_model.head_values, # Note: multi-task uses head_values (plural) + 'head_reward': world_model.head_rewards, + 'head_policy': world_model.head_policies, # Note: multi-task uses head_policies (plural) + } + + for group_name, group_module in module_groups.items(): + # Handle ModuleList (for multi-task heads) + if isinstance(group_module, torch.nn.ModuleList): + for task_idx, task_module in enumerate(group_module): + total_norm_sq = 0.0 + for param_name, param in task_module.named_parameters(): + if param.requires_grad: + param_norm = param.data.norm(2).item() + log_name = f'norm/{group_name}_task{task_idx}/{param_name.replace(".", "/")}' + norm_metrics[log_name] = param_norm + total_norm_sq += param_norm ** 2 + total_group_norm = np.sqrt(total_norm_sq) + norm_metrics[f'norm/{group_name}_task{task_idx}/_total_norm'] = total_group_norm + else: + # Handle single module + total_norm_sq = 0.0 + for param_name, param in group_module.named_parameters(): + if param.requires_grad: + param_norm = param.data.norm(2).item() + log_name = f'norm/{group_name}/{param_name.replace(".", "/")}' + norm_metrics[log_name] = param_norm + total_norm_sq += param_norm ** 2 + total_group_norm = np.sqrt(total_norm_sq) + norm_metrics[f'norm/{group_name}/_total_norm'] = total_group_norm + + return norm_metrics + + def _monitor_gradient_norms(self) -> Dict[str, float]: + """ + Overview: + 计算并返回模型关键组件的梯度范数。 + 此函数应在梯度计算完成后、参数更新之前调用。 + Returns: + - grad_metrics (:obj:`Dict[str, float]`): 包含所有梯度范数指标的字典,用于日志记录。 + """ + world_model = self._learn_model.world_model + grad_metrics = {} + + # 定义要监控的模块组 + module_groups = { + 'encoder': world_model.tokenizer.encoder, + 'transformer': world_model.transformer, + 'head_value': world_model.head_values, + 'head_reward': world_model.head_rewards, + 'head_policy': world_model.head_policies, + } + + for group_name, group_module in module_groups.items(): + # Handle ModuleList (for multi-task heads) + if isinstance(group_module, torch.nn.ModuleList): + for task_idx, task_module in enumerate(group_module): + total_grad_norm_sq = 0.0 + num_params_with_grad = 0 + for param_name, param in task_module.named_parameters(): + if param.requires_grad and param.grad is not None: + grad_norm = param.grad.data.norm(2).item() + log_name = f'grad/{group_name}_task{task_idx}/{param_name.replace(".", "/")}' + grad_metrics[log_name] = grad_norm + total_grad_norm_sq += grad_norm ** 2 + num_params_with_grad += 1 + if num_params_with_grad > 0: + total_group_grad_norm = np.sqrt(total_grad_norm_sq) + grad_metrics[f'grad/{group_name}_task{task_idx}/_total_norm'] = 
total_group_grad_norm + else: + grad_metrics[f'grad/{group_name}_task{task_idx}/_total_norm'] = 0.0 + else: + # Handle single module + total_grad_norm_sq = 0.0 + num_params_with_grad = 0 + for param_name, param in group_module.named_parameters(): + if param.requires_grad and param.grad is not None: + grad_norm = param.grad.data.norm(2).item() + log_name = f'grad/{group_name}/{param_name.replace(".", "/")}' + grad_metrics[log_name] = grad_norm + total_grad_norm_sq += grad_norm ** 2 + num_params_with_grad += 1 + if num_params_with_grad > 0: + total_group_grad_norm = np.sqrt(total_grad_norm_sq) + grad_metrics[f'grad/{group_name}/_total_norm'] = total_group_grad_norm + else: + grad_metrics[f'grad/{group_name}/_total_norm'] = 0.0 + + return grad_metrics + # ================================================================= + def _init_learn(self) -> None: """ Overview: - Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. + Initializes the learn mode. This method is called by ``self.__init__``. + It sets up the learn model, optimizer, target model, and other utilities required for training. """ - # NOTE: nanoGPT optimizer - self._optimizer_world_model = configure_optimizers_nanogpt( - model=self._model.world_model, - learning_rate=self._cfg.learning_rate, - weight_decay=self._cfg.weight_decay, - device_type=self._cfg.device, - betas=(0.9, 0.95), - ) + if self._cfg.optim_type == 'SGD': + # --- 改为SGD优化器 --- + self._optimizer_world_model = torch.optim.SGD( + self._model.world_model.parameters(), + lr=self._cfg.learning_rate, # 初始学习率,在配置中设为 0.2 + momentum=self._cfg.momentum, # 在配置中设为 0.9 + weight_decay=self._cfg.weight_decay # 在配置中设为 1e-4 + ) + elif self._cfg.optim_type == 'AdamW': + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + elif self._cfg.optim_type == 'AdamW_mix_lr_wdecay': + self._optimizer_world_model = configure_optimizer_unizero( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, # 使用一个合理的AdamW基础学习率 + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + if self._cfg.cos_lr_scheduler: + from torch.optim.lr_scheduler import CosineAnnealingLR + # TODO: check the total training steps + # self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1) + total_iters = self._cfg.get('total_iterations', 500000) # 500k iter + # final_lr = self._cfg.get('final_learning_rate', 0.0) + final_lr = self._cfg.get('final_learning_rate', 1e-6) + + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, + T_max=total_iters, + eta_min=final_lr + ) + print(f"CosineAnnealingLR enabled: T_max={total_iters}, eta_min={final_lr}") + + + if self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. 
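+            # Concretely: the base lr is kept for the first half of `max_step`, scaled by 0.1 until
+            # `max_step`, and by 0.01 afterwards.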
+ lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer_world_model, lr_lambda=lr_lambda) - if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: - from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR - - if self._cfg.cos_lr_scheduler: - self.lr_scheduler = CosineAnnealingLR( - self._optimizer_world_model, T_max=int(2e5), eta_min=0, last_epoch=-1 - ) # TODO - elif self._cfg.piecewise_decay_lr_scheduler: - # Example step scheduler, adjust milestones and gamma as needed - self.lr_scheduler = StepLR( - self._optimizer_world_model, step_size=int(5e4), gamma=0.1 - ) - # use model_wrapper for specialized demands of different modes + # Use a deep copy for the target model. self._target_model = copy.deepcopy(self._model) - # Ensure that the installed torch version is greater than or equal to 2.0 + # Ensure that the installed torch version is >= 2.0 for torch.compile. assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0" self._model = torch.compile(self._model) self._target_model = torch.compile(self._target_model) - # NOTE: soft target + + # Wrap the target model for soft updates (momentum-based). self._target_model = model_wrap( self._target_model, wrapper_name='target', @@ -423,19 +789,21 @@ def _init_learn(self) -> None: self._cfg.augmentation, image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) + self.intermediate_losses = defaultdict(float) self.l2_norm_before = 0. self.l2_norm_after = 0. self.grad_norm_before = 0. self.grad_norm_after = 0. - # 创建 WrappedModel 实例 - # 所有参数都共享,即所有参数都需要进行矫正 + # Create a WrappedModel instance. + # This is used for gradient correction methods where gradients of shared parameters are managed. + # In this setup, all parameters are considered shared and subject to correction. # wrapped_model = WrappedModel( # self._learn_model.world_model, # ) @@ -445,17 +813,17 @@ def _init_learn(self) -> None: print(f'self._cfg.only_use_moco_stats:{self._cfg.only_use_moco_stats}') if self._cfg.use_moco or self._cfg.only_use_moco_stats: - # head 没有矫正梯度 + # The prediction heads' gradients are not corrected. self.wrapped_model = WrappedModelV2( - # self._learn_model.world_model.tokenizer, # TODO: - self._learn_model.world_model.tokenizer.encoder[0], # TODO: one encoder + # TODO: This assumes the tokenizer has an encoder attribute which is a list. This might need to be more robust. 
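+                # A more defensive access (hypothetical sketch, not in the original code) would be, e.g.:
+                #   enc = self._learn_model.world_model.tokenizer.encoder
+                #   enc[0] if isinstance(enc, (list, torch.nn.ModuleList)) else enc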
+ self._learn_model.world_model.tokenizer.encoder[0], self._learn_model.world_model.transformer, self._learn_model.world_model.pos_emb, self._learn_model.world_model.task_emb, self._learn_model.world_model.act_embedding_table, ) - # head 和 tokenizer.encoder 没有矫正梯度 + # Alternative setup: The head and tokenizer.encoder gradients are not corrected. # wrapped_model = WrappedModelV3( # self._learn_model.world_model.transformer, # self._learn_model.world_model.pos_emb, @@ -463,34 +831,149 @@ def _init_learn(self) -> None: # self._learn_model.world_model.act_embedding_table, # ) - # 将 wrapped_model 作为 share_model 传递给 GradCorrect - # ========= 初始化 MoCo CAGrad 参数 ========= - # self.grad_correct = GradCorrect(self.wrapped_model, self.task_num_for_current_rank, self._cfg.device) - self.grad_correct = GradCorrect(self.wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu) # only compatiable with for 1GPU training + # Pass the wrapped_model as `shared_module` to the gradient correction method. + # ========= Initialize MoCo/CAGrad parameters ========= + if self._cfg.moco_version=="v0": + # This version is only compatible with single-GPU training. + self.grad_correct = GradCorrect(self.wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu) + self.grad_correct.init_param() + self.grad_correct.rep_grad = False + elif self._cfg.moco_version=="v1": + cfg_moco = MoCoCfg( + beta0=0.9, beta_sigma=0.95, + gamma0=0.1, gamma_sigma=0.95, + rho=0.01, stat_interval=10000) + self.grad_correct = FastMoCo( + shared_module=self.wrapped_model, + world_task_num=self._cfg.total_task_num, # Total number of tasks globally + device=self._cfg.device, + multi_gpu=self._cfg.multi_gpu, + cfg=cfg_moco, + ) - self.grad_correct.init_param() - self.grad_correct.rep_grad = False + # Cache for plasticity-related metrics from the previous frame. + self._prev_plasticity_metrics = dict( + dormant_ratio_encoder = 0.0, + dormant_ratio_transformer = 0.0, + dormant_ratio_head = 0.0, + avg_weight_mag_encoder = 0.0, + avg_weight_mag_transformer = 0.0, + avg_weight_mag_head = 0.0, + e_rank_last_linear = 0.0, + e_rank_sim_norm = 0.0, + ) + # ==================== START: 目标熵正则化初始化 ==================== + # 从配置中读取是否启用自适应alpha,并提供一个默认值 + self.use_adaptive_entropy_weight = self._cfg.get('use_adaptive_entropy_weight', True) + + # 在 _init_learn 中增加配置 + self.target_entropy_start_ratio = self._cfg.get('target_entropy_start_ratio', 0.98) + self.target_entropy_end_ratio = self._cfg.get('target_entropy_end_ratio', 0.7) + self.target_entropy_decay_steps = self._cfg.get('target_entropy_decay_steps', 200000) # 例如,在200k步内完成退火 2M envsteps + + if self.use_adaptive_entropy_weight: + # 1. 设置目标熵。对于离散动作空间,一个常见的启发式设置是动作空间维度的负对数乘以一个系数。 + # 这个系数(例如0.98)可以作为一个超参数。 + action_space_size = self._cfg.model.action_space_size + self.target_entropy = -np.log(1.0 / action_space_size) * 0.98 + + # 2. 初始化一个可学习的 log_alpha 参数。 + # 初始化为0,意味着初始的 alpha = exp(0) = 1.0。 + self.log_alpha = torch.nn.Parameter(torch.zeros(1, device=self._cfg.device), requires_grad=True) + + # 3. 
为 log_alpha 创建一个专属的优化器。 + # 使用与主优化器不同的、较小的学习率(例如1e-4)通常更稳定。 + alpha_lr = self._cfg.get('adaptive_entropy_alpha_lr', 1e-4) + self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr) + + print("="*20) + print(">>> 目标熵正则化 (自适应Alpha) 已启用 <<<") + print(f" 目标熵 (Target Entropy): {self.target_entropy:.4f}") + print(f" Alpha 优化器学习率: {alpha_lr:.2e}") + print("="*20) + # ===================== END: 目标熵正则化初始化 ===================== + + self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 30.0) + # ==================== START: 初始化 Encoder-Clip Annealing 参数 ==================== + self.use_encoder_clip_annealing = self._cfg.get('use_encoder_clip_annealing', False) + if self.use_encoder_clip_annealing: + self.encoder_clip_anneal_type = self._cfg.get('encoder_clip_anneal_type', 'cosine') + self.encoder_clip_start = self._cfg.get('encoder_clip_start_value', 30.0) + self.encoder_clip_end = self._cfg.get('encoder_clip_end_value', 10.0) + self.encoder_clip_anneal_steps = self._cfg.get('encoder_clip_anneal_steps', 200000) + + print("="*20) + print(">>> Encoder-Clip 退火已启用 <<<") + print(f" 类型: {self.encoder_clip_anneal_type}") + print(f" 范围: {self.encoder_clip_start} -> {self.encoder_clip_end}") + print(f" 步数: {self.encoder_clip_anneal_steps}") + print("="*20) + else: + # 如果不启用退火,则使用固定的 clip 阈值 + self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 30.0) + # ===================== END: 初始化 Encoder-Clip Annealing 参数 ===================== + + # --- NEW: Policy Label Smoothing Parameters --- + self.policy_ls_eps_start = self._cfg.get('policy_ls_eps_start', 0.05) # TODO policy_label_smoothing_eps_start 越大的action space需要越大的eps + self.policy_ls_eps_end = self._cfg.get('policy_label_smoothing_eps_end ', 0.01) # TODO policy_label_smoothing_eps_start + self.policy_ls_eps_decay_steps = self._cfg.get('policy_ls_eps_decay_steps ', 50000) # TODO 50k + print(f"self.policy_ls_eps_start:{self.policy_ls_eps_start}") + + @staticmethod + def _is_zero(x: Union[float, torch.Tensor], eps: float = 1e-8) -> bool: + """ + Overview: + Checks if a scalar or a 0-D tensor can be considered zero within a small tolerance. + Arguments: + - x (:obj:`Union[float, torch.Tensor]`): The input value to check. + - eps (:obj:`float`): The tolerance for checking against zero. + Returns: + - (:obj:`bool`): True if the value is close to zero, False otherwise. + """ + if isinstance(x, torch.Tensor): + return torch.all(torch.abs(x) < eps).item() + return abs(x) < eps + def _retain_prev_if_zero(self, name: str, + value: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + """ + Overview: + If the current `value` is close to zero, returns the cached value from the previous frame. + Otherwise, it updates the cache with the current value and returns it. This is useful for + metrics that are computed intermittently. + Arguments: + - name (:obj:`str`): The name of the metric to cache. + - value (:obj:`Union[float, torch.Tensor]`): The current value of the metric. + Returns: + - (:obj:`Union[float, torch.Tensor]`): The retained or current value. + """ + if self._is_zero(value): + # Directly return the previous value (can be float or tensor). + return self._prev_plasticity_metrics[name] + else: + # Update the cache and return the current value. 
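+            # The plasticity metrics are only recomputed intermittently and come back as zero on the
+            # other iterations; caching the last non-zero value keeps the logged curves continuous.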
+ self._prev_plasticity_metrics[name] = value + return value #@profile - def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[str, Union[float, int]]: + def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_iter=None, ignore_grad=False) -> Dict[str, Union[float, int]]: """ Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. - The loss is calculated by the loss function and the loss is backpropagated to update the model. + The forward function for learning in the policy. This is the core of the training process. + Data is sampled from the replay buffer, losses are calculated, and the model is updated via backpropagation. Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. - The first tensor is the current_batch, the second tensor is the target_batch. + - data (:obj:`Tuple[torch.Tensor]`): A tuple of data batches, where each element corresponds to a different task. + - task_weights (:obj:`Any`, optional): Optional weights for each task's loss. Not currently used. + - ignore_grad (:obj:`bool`): If True, gradients are zeroed out after computation, effectively skipping the update. Returns: - - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ - current learning loss and learning statistics. + - info_dict (:obj:`Dict[str, Union[float, int]]`): A dictionary containing current learning losses and statistics for logging. """ self._learn_model.train() self._target_model.train() + # Lists to store metrics for each task within the batch. obs_loss_multi_task = [] reward_loss_multi_task = [] policy_loss_multi_task = [] @@ -499,14 +982,15 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[s perceptual_loss_multi_task = [] orig_policy_loss_multi_task = [] policy_entropy_multi_task = [] - weighted_total_loss = 0.0 # 初始化为0,避免使用in-place操作 + weighted_total_loss = 0.0 # Initialize to 0.0 to avoid in-place operations. + total_alpha_loss = 0.0 latent_state_l2_norms_multi_task = [] average_target_policy_entropy_multi_task = [] value_priority_multi_task = [] value_priority_mean_multi_task = [] - # 网络可塑性分析指标 + # Metrics for network plasticity analysis. dormant_ratio_encoder_multi_task = [] dormant_ratio_transformer_multi_task = [] dormant_ratio_head_multi_task = [] @@ -516,95 +1000,134 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[s e_rank_last_linear_multi_task = [] e_rank_sim_norm_multi_task = [] + # --- NEW: Calculate current epsilon for policy --- + # if self.policy_ls_eps_start > 0: + # progress = min(1.0, train_iter / self.policy_ls_eps_decay_steps) + # current_policy_label_eps = self.policy_ls_eps_start * (1 - progress) + self.policy_ls_eps_end * progress + # else: + # current_policy_label_eps = 0.0 + current_policy_label_eps = 0.01 + + # 新增一个列表来收集当前批次中所有任务的真实全局ID + global_task_ids_in_batch = [] + alpha_loss = None + + + # 用于Alpha日志记录的新列表 + alpha_loss_multi_task = [] + target_entropy_multi_task = [] - losses_list = [] # 用于存储每个任务的损失 + # 仅在自适应alpha启用时,预先获取当前alpha值,确保在单次迭代中对所有任务一致 + current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight + if self.use_adaptive_entropy_weight: + current_alpha = self.log_alpha.exp().detach() + + losses_list = [] # Used to store the loss tensor for each task, required by gradient correction methods. 
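+        # Each element of `data` is a (current_batch, target_batch, task_id) triple; the loop below
+        # computes the per-task losses, which are accumulated into `weighted_total_loss` and also
+        # collected individually in `losses_list` for the MoCo-style gradient-correction path.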
for task_id, data_one_task in enumerate(data): - current_batch, target_batch, task_id = data_one_task - # current_batch, target_batch, _ = data - # TODO: multitask适配rope(timestep_batch) + current_batch, target_batch, task_id = data_one_task # task_id 是真实的全局ID + + # 将真实的全局ID添加到列表中 + global_task_ids_in_batch.append(task_id) + + # TODO: Adapt RoPE for multitask settings (using timestep_batch). obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch target_reward, target_value, target_policy = target_batch - # Prepare observations based on frame stack number + # Prepare observations based on frame stack number. if self._cfg.model.frame_stack_num == 4: obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) else: obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) - # Apply augmentations if needed + # Apply augmentations if needed. if self._cfg.use_augmentation: obs_batch = self.image_transforms.transform(obs_batch) if self._cfg.model.self_supervised_learning_loss: obs_target_batch = self.image_transforms.transform(obs_target_batch) - # Prepare action batch and convert to torch tensor + # Prepare action batch and convert to a torch tensor. action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze( - -1).long() # For discrete action space + -1).long() # For discrete action space. data_list = [mask_batch, target_reward.astype('float32'), target_value.astype('float32'), target_policy, weights] mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, self._cfg.device) - - # rank = get_rank() - # print(f'Rank {rank}: cfg.policy.task_id : {self._cfg.task_id}, self._cfg.batch_size {self._cfg.batch_size}') - - target_reward = target_reward.view(self._cfg.batch_size[task_id], -1) - target_value = target_value.view(self._cfg.batch_size[task_id], -1) - - target_reward = target_reward.view(self._cfg.batch_size[task_id], -1) - target_value = target_value.view(self._cfg.batch_size[task_id], -1) + cur_batch_size = target_reward.size(0) # Run-time batch size. - # assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) + target_reward = target_reward.view(cur_batch_size, -1) + target_value = target_value.view(cur_batch_size, -1) - # Transform rewards and values to their scaled forms + # Transform scalar rewards and values to their scaled representations. transformed_target_reward = scalar_transform(target_reward) transformed_target_value = scalar_transform(target_value) - # Convert to categorical distributions - target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) - target_value_categorical = phi_transform(self.value_support, transformed_target_value) + # Convert scaled representations to categorical distributions. + # target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + # target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward, label_smoothing_eps= self._cfg.label_smoothing_eps) + target_value_categorical = phi_transform(self.value_support, transformed_target_value, label_smoothing_eps=self._cfg.label_smoothing_eps) + - # Prepare batch for a transformer-based world model + # Prepare the batch for the transformer-based world model. 
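+            # `batch_for_gpt` packs the observations, actions, categorical reward/value targets,
+            # policy targets, padding masks and episode-end flags into the dictionary consumed by
+            # `world_model.compute_loss` below.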
batch_for_gpt = {} if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1: batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( - self._cfg.batch_size[task_id], -1, self._cfg.model.observation_shape) + cur_batch_size, -1, self._cfg.model.observation_shape) elif len(self._cfg.model.observation_shape) == 3: batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( - self._cfg.batch_size[task_id], -1, *self._cfg.model.observation_shape) + cur_batch_size, -1, *self._cfg.model.observation_shape) batch_for_gpt['actions'] = action_batch.squeeze(-1) batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] - batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data + batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data. batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) batch_for_gpt['target_value'] = target_value_categorical[:, :-1] batch_for_gpt['target_policy'] = target_policy[:, :-1] + batch_for_gpt['scalar_target_value'] = target_value - # Extract valid target policy data and compute entropy + # Extract valid target policy data and compute its entropy. valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) average_target_policy_entropy = target_policy_entropy.mean().item() - # Update world model + # Update world model and compute losses. intermediate_losses = defaultdict(float) + # losses = self._learn_model.world_model.compute_loss( + # batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, task_id=task_id + # ) + losses = self._learn_model.world_model.compute_loss( - batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, task_id=task_id + batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, current_policy_label_eps=current_policy_label_eps, task_id=task_id ) - weighted_total_loss += losses.loss_total # TODO + # ==================== START MODIFICATION 2 ==================== + # Extract the calculated value_priority from the returned losses. + value_priority_tensor = losses.intermediate_losses['value_priority'] + # Convert to numpy array for the replay buffer, adding a small epsilon. + value_priority_np = value_priority_tensor.detach().cpu().numpy() + 1e-6 + # ===================== END MODIFICATION 2 ===================== + - # assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" # TODO + # TODO: Accumulate the weighted total loss. This assumes the loss from `compute_loss` is already weighted. + weighted_total_loss += losses.loss_total # NOTE:+= + + # TODO: Add assertions to check for NaN or Inf values in the loss if needed for debugging. + # assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" # assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" - losses_list.append(losses.loss_total) # TODO: for moco + # TODO: Append the total loss for this task, used by MoCo. 
+ losses_list.append(losses.loss_total) for loss_name, loss_value in losses.intermediate_losses.items(): intermediate_losses[f"{loss_name}"] = loss_value + + obs_loss = intermediate_losses['loss_obs'] reward_loss = intermediate_losses['loss_rewards'] policy_loss = intermediate_losses['loss_policy'] @@ -615,47 +1138,114 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[s perceptual_loss = intermediate_losses['perceptual_loss'] latent_state_l2_norms = intermediate_losses['latent_state_l2_norms'] - # value_priority = intermediate_losses['value_priority'] - # logits_value = intermediate_losses['logits_value'] - - # print(f'logits_value:" {logits_value}') - # print(f'logits_value.shape:" {logits_value.shape}') - # print(f"batch_for_gpt['observations'].shape: {batch_for_gpt['observations'].shape}") - - # ============ for value priority ============ - # transform the categorical representation of the scaled value to its original value - # original_value = self.inverse_scalar_transform_handle(logits_value.reshape(-1, 101)).reshape( + # 从 losses 对象中提取策略熵 + # ==================== START: 目标熵正则化更新逻辑 ==================== + current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight # 默认使用固定值 + if self.use_adaptive_entropy_weight: + + # --- 动态计算目标熵 (这部分逻辑是正确的,予以保留) --- + progress = min(1.0, train_iter / self.target_entropy_decay_steps) + current_ratio = self.target_entropy_start_ratio * (1 - progress) + self.target_entropy_end_ratio * progress + action_space_size = self._cfg.model.action_space_size + # 注意:我们将 target_entropy 定义为正数,更符合直觉 + current_target_entropy = -np.log(1.0 / action_space_size) * current_ratio + + # --- 计算 alpha_loss (已修正符号) --- + # 这是核心修正点:去掉了最前面的负号 + # detach() 仍然是关键,确保 alpha_loss 的梯度只流向 log_alpha + alpha_loss_task = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean() # NOTE:= + + # # --- 更新 log_alpha --- + # self.alpha_optimizer.zero_grad() + # alpha_loss.backward() + # self.alpha_optimizer.step() + + # 累加alpha_loss + total_alpha_loss += alpha_loss_task + # 为日志记录收集每个任务的alpha_loss和目标熵 + alpha_loss_multi_task.append(alpha_loss_task) + target_entropy_multi_task.append(current_target_entropy) + + # --- [优化建议] 增加 log_alpha 裁剪作为安全措施 --- + with torch.no_grad(): + # 将 alpha 限制在例如 [1e-4, 10.0] 的范围内 + self.log_alpha.clamp_(np.log(1e-4), np.log(10.0)) + + # --- 使用当前更新后的 alpha (截断梯度流) --- + current_alpha = self.log_alpha.exp().detach() + + # 重新计算加权的策略损失和总损失 + # 注意:这里的 policy_entropy 已经是一个batch的平均值 + weighted_policy_loss = orig_policy_loss - current_alpha * policy_entropy + # 重新构建总损失 (不使用 losses.loss_total) + # 确保这里的权重与 LossWithIntermediateLosses 类中的计算方式一致 + self.obs_loss_weight = 10 + self.value_loss_weight = 0.5 + self.reward_loss_weight = 1. + self.policy_loss_weight = 1. + self.ends_loss_weight = 0. + total_loss = ( + self.reward_loss_weight * reward_loss + + self.value_loss_weight * value_loss + + self.policy_loss_weight * weighted_policy_loss + + self.obs_loss_weight * obs_loss # 假设 ssl_loss_weight 是 obs_loss 的权重 + # ... 如果还有其他损失项,也加进来 ... + ) + weighted_total_loss += (weights * total_loss).mean() # NOTE:+= + # ===================== END: 目标熵正则化更新逻辑 ===================== + + # ============ For value-based priority calculation ============ + # TODO: The following section for calculating value_priority is commented out. + # If re-enabled, ensure it correctly computes L1 loss between predicted and target values + # and handles CPU/Numpy conversion properly. 
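+            # Note: the priorities actually used now come from the `value_priority` entry returned by
+            # `compute_loss` (see `value_priority_np` above); the commented-out block below is kept
+            # only for reference.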
+ # original_value = self.value_inverse_scalar_transform_handle(logits_value.reshape(-1, 101)).reshape( # batch_for_gpt['observations'].shape[0], batch_for_gpt['observations'].shape[1], 1) - # calculate the new priorities for each transition. - # value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1)[:,0], target_value[:, 0]) # TODO: mix of mean and sum - # value_priority = value_priority.data.cpu().numpy() + 1e-6 # TODO: log-reduce not support array now - value_priority = torch.tensor(0., device=self._cfg.device) - # ============ for value priority ============ - - # 关于网络可塑性的指标 - dormant_ratio_encoder = intermediate_losses['dormant_ratio_encoder'] - dormant_ratio_transformer = intermediate_losses['dormant_ratio_transformer'] - dormant_ratio_head = intermediate_losses['dormant_ratio_head'] - avg_weight_mag_encoder = intermediate_losses['avg_weight_mag_encoder'] - avg_weight_mag_transformer = intermediate_losses['avg_weight_mag_transformer'] - avg_weight_mag_head = intermediate_losses['avg_weight_mag_head'] - e_rank_last_linear = intermediate_losses['e_rank_last_linear'] - e_rank_sim_norm = intermediate_losses['e_rank_sim_norm'] - + # value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1)[:,0], target_value[:, 0]) + # value_priority = value_priority.data.cpu().numpy() + 1e-6 + # value_priority = torch.tensor(0., device=self._cfg.device) + # ============ End of value priority section ============ + + # Metrics related to network plasticity. + # Use the helper function to retain the previous value if the current one is zero. + dormant_ratio_encoder = self._retain_prev_if_zero( + 'dormant_ratio_encoder', + intermediate_losses['dormant_ratio_encoder']) + dormant_ratio_transformer = self._retain_prev_if_zero( + 'dormant_ratio_transformer', + intermediate_losses['dormant_ratio_transformer']) + dormant_ratio_head = self._retain_prev_if_zero( + 'dormant_ratio_head', + intermediate_losses['dormant_ratio_head']) + avg_weight_mag_encoder = self._retain_prev_if_zero( + 'avg_weight_mag_encoder', + intermediate_losses['avg_weight_mag_encoder']) + avg_weight_mag_transformer = self._retain_prev_if_zero( + 'avg_weight_mag_transformer', + intermediate_losses['avg_weight_mag_transformer']) + avg_weight_mag_head = self._retain_prev_if_zero( + 'avg_weight_mag_head', + intermediate_losses['avg_weight_mag_head']) + e_rank_last_linear = self._retain_prev_if_zero( + 'e_rank_last_linear', + intermediate_losses['e_rank_last_linear']) + e_rank_sim_norm = self._retain_prev_if_zero( + 'e_rank_sim_norm', + intermediate_losses['e_rank_sim_norm']) + + # Append all metrics for this task to their respective lists. obs_loss_multi_task.append(obs_loss) reward_loss_multi_task.append(reward_loss) policy_loss_multi_task.append(policy_loss) orig_policy_loss_multi_task.append(orig_policy_loss) policy_entropy_multi_task.append(policy_entropy) - reward_loss_multi_task.append(reward_loss) value_loss_multi_task.append(value_loss) latent_recon_loss_multi_task.append(latent_recon_loss) perceptual_loss_multi_task.append(perceptual_loss) latent_state_l2_norms_multi_task.append(latent_state_l2_norms) - value_priority_multi_task.append(value_priority) - value_priority_mean_multi_task.append(value_priority.mean().item()) + value_priority_multi_task.append(value_priority_tensor) + value_priority_mean_multi_task.append(value_priority_tensor.mean().item()) - # 关于网络可塑性的指标 + # Append plasticity metrics. 
            dormant_ratio_encoder_multi_task.append(dormant_ratio_encoder)
            dormant_ratio_transformer_multi_task.append(dormant_ratio_transformer)
            dormant_ratio_head_multi_task.append(dormant_ratio_head)
@@ -666,37 +1256,160 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[s
            e_rank_sim_norm_multi_task.append(e_rank_sim_norm)
 
-        # Core learn model update step
+        # ==================== [Added] Integrated norm-monitoring logic ====================
+        norm_log_dict = {}
+        # Check whether the monitoring frequency has been reached.
+        if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)):
+            with torch.no_grad():
+                # 1. Monitor the parameter norms of the model.
+                param_norm_metrics = self._monitor_model_norms()
+                norm_log_dict.update(param_norm_metrics)
+
+                # 2. Monitor the intermediate tensor x (the transformer output).
+                intermediate_x = losses.intermediate_losses.get('intermediate_tensor_x')
+                if intermediate_x is not None:
+                    # x has shape (B, T, E); compute the L2 norm of each token.
+                    token_norms = intermediate_x.norm(p=2, dim=-1)
+
+                    # Log summary statistics of these norms.
+                    norm_log_dict['norm/x_token/mean'] = token_norms.mean().item()
+                    norm_log_dict['norm/x_token/std'] = token_norms.std().item()
+                    norm_log_dict['norm/x_token/max'] = token_norms.max().item()
+                    norm_log_dict['norm/x_token/min'] = token_norms.min().item()
+
+                # 3. Monitor detailed statistics of the logits (value, policy, reward).
+                logits_value = losses.intermediate_losses.get('logits_value')
+                if logits_value is not None:
+                    norm_log_dict['logits/value/mean'] = logits_value.mean().item()
+                    norm_log_dict['logits/value/std'] = logits_value.std().item()
+                    norm_log_dict['logits/value/max'] = logits_value.max().item()
+                    norm_log_dict['logits/value/min'] = logits_value.min().item()
+                    norm_log_dict['logits/value/abs_max'] = logits_value.abs().max().item()
+
+                logits_policy = losses.intermediate_losses.get('logits_policy')
+                if logits_policy is not None:
+                    norm_log_dict['logits/policy/mean'] = logits_policy.mean().item()
+                    norm_log_dict['logits/policy/std'] = logits_policy.std().item()
+                    norm_log_dict['logits/policy/max'] = logits_policy.max().item()
+                    norm_log_dict['logits/policy/min'] = logits_policy.min().item()
+                    norm_log_dict['logits/policy/abs_max'] = logits_policy.abs().max().item()
+
+                logits_reward = losses.intermediate_losses.get('logits_reward')
+                if logits_reward is not None:
+                    norm_log_dict['logits/reward/mean'] = logits_reward.mean().item()
+                    norm_log_dict['logits/reward/std'] = logits_reward.std().item()
+                    norm_log_dict['logits/reward/max'] = logits_reward.max().item()
+                    norm_log_dict['logits/reward/min'] = logits_reward.min().item()
+                    norm_log_dict['logits/reward/abs_max'] = logits_reward.abs().max().item()
+
+                # 4. Monitor statistics of obs_embeddings (the encoder output).
+                obs_embeddings = losses.intermediate_losses.get('obs_embeddings')
+                if obs_embeddings is not None:
+                    # Compute the L2 norm of each embedding.
+                    emb_norms = obs_embeddings.norm(p=2, dim=-1)
+                    norm_log_dict['embeddings/obs/norm_mean'] = emb_norms.mean().item()
+                    norm_log_dict['embeddings/obs/norm_std'] = emb_norms.std().item()
+                    norm_log_dict['embeddings/obs/norm_max'] = emb_norms.max().item()
+                    norm_log_dict['embeddings/obs/norm_min'] = emb_norms.min().item()
+        # =================================================================
+
+        # Core learn model update step.
        self._optimizer_world_model.zero_grad()
 
-        # 假设每个进程计算出的 losses_list 为可求梯度的 tensor list,比如多个标量 loss 组成的列表
-        # 例如 losses_list = [loss1, loss2, ...],其中每个 loss_i 都是形如 (1,) 的 tensor 且 requires_grad=True
+        if self.use_adaptive_entropy_weight:
+            self.alpha_optimizer.zero_grad()
+
+        # 2. Compute the final alpha loss (averaged over tasks after accumulation).
+        final_alpha_loss = None
+        if self.use_adaptive_entropy_weight:
+            if len(data) > 0:
+                final_alpha_loss = total_alpha_loss / len(data)
+            else:  # Defensive programming to avoid division by zero.
+                final_alpha_loss = torch.tensor(0.0, device=self._cfg.device)
+
+        # Assuming losses_list is a list of tensors with gradients, e.g., [loss1, loss2, ...].
        if self._cfg.use_moco:
-            # 调用 MoCo backward,由 grad_correct 中的 backward 实现梯度校正
-            lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+            # Call MoCo's backward method, which handles gradient correction internally.
+            if self._cfg.moco_version == "v0":
+                lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+            elif self._cfg.moco_version == "v1":
+                lambd, stats = self.grad_correct.backward(losses_list)
+
+            # Backpropagate the alpha loss separately.
+            if self.use_adaptive_entropy_weight:
+                final_alpha_loss.backward()
+
        elif self._cfg.only_use_moco_stats:
+            # Only compute MoCo stats without applying gradient correction.
            lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
-            # 不使用梯度校正的情况,由各 rank 自己执行反向传播
-            weighted_total_loss.backward()
+
+            # Each rank performs its own backpropagation.
+            # weighted_total_loss.backward()
+
+            # If adaptive alpha is enabled, add the alpha loss to the main loss and backpropagate them together.
+            if self.use_adaptive_entropy_weight:
+                (weighted_total_loss + final_alpha_loss).backward()
+            elif weighted_total_loss != 0.0:  # Make sure there is a loss to backpropagate.
+                weighted_total_loss.backward()
+
        else:
-            # 不使用梯度校正的情况,由各 rank 自己执行反向传播
+            # If not using gradient correction, each rank performs standard backpropagation.
            lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device)
-            weighted_total_loss.backward()
-        # TODO: 使用 MoCo 或 CAGrad 来计算梯度和权重
-        # ============= for CAGrad and MoCo =============
-        # lambd = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+            # weighted_total_loss.backward()
 
-        # ============= TODO: 不使用梯度矫正的情况 =============
-        # lambd = torch.tensor([0. for i in range(self.task_num_for_current_rank)], device=self._cfg.device)
-        # weighted_total_loss.backward()
+            # If adaptive alpha is enabled, add the alpha loss to the main loss and backpropagate them together.
+            if self.use_adaptive_entropy_weight:
+                (weighted_total_loss + final_alpha_loss).backward()
+            elif weighted_total_loss != 0.0:  # Make sure there is a loss to backpropagate.
+                weighted_total_loss.backward()
 
-        # ========== for debugging ==========
+        # -----------------------------------------------------------------
+        # The following still runs inside a torch.no_grad() context.
+        # =================================================================
+        with torch.no_grad():
+            # ==================== START: dynamically compute the current clip threshold ====================
+            current_clip_value = self.latent_norm_clip_threshold  # Fall back to the fixed value by default.
+            if self.use_encoder_clip_annealing:
+                progress = min(1.0, train_iter / self.encoder_clip_anneal_steps)
+
+                if self.encoder_clip_anneal_type == 'cosine':
+                    # Cosine schedule: smoothly anneal from 1 to 0.
+                    cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress))
+                    current_clip_value = self.encoder_clip_end + \
+                        (self.encoder_clip_start - self.encoder_clip_end) * cosine_progress
+                else:  # Default to a linear schedule.
+                    current_clip_value = self.encoder_clip_start * (1 - progress) + \
+                        self.encoder_clip_end * progress
+            # ===================== END: dynamically compute the current clip threshold =====================
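+            # As a concrete illustration, assuming for example encoder_clip_start=30.0 and encoder_clip_end=10.0:
+            #   progress = 0.25 -> linear:  30*0.75 + 10*0.25 = 25.0
+            #                      cosine:  10 + 20*0.5*(1 + cos(0.25*pi)) ~= 27.1
+            #   progress = 1.0  -> both schedules reach encoder_clip_end = 10.0
+            # i.e. the cosine schedule keeps the clip threshold high for longer before decaying.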
+            # 1. Encoder-Clip (using the dynamically computed current_clip_value).
+            if current_clip_value > 0 and 'obs_embeddings' in losses.intermediate_losses:
+                obs_embeddings = losses.intermediate_losses['obs_embeddings']
+                if obs_embeddings is not None:
+                    max_latent_norm = obs_embeddings.norm(p=2, dim=-1).max()
+                    if max_latent_norm > current_clip_value:
+                        scale_factor = current_clip_value / max_latent_norm.item()
+                        # Avoid printing too frequently; log only every N iterations.
+                        if train_iter % 1000 == 0:
+                            print(f"[Encoder-Clip Annealing] Iter {train_iter}: Max latent norm {max_latent_norm.item():.2f} > {current_clip_value:.2f}. Scaling by {scale_factor:.4f}.")
+                        scale_module_weights_vectorized(self._model.world_model.tokenizer.encoder, scale_factor)
+
+        # For debugging purposes.
        # for name, param in self._learn_model.world_model.tokenizer.encoder.named_parameters():
        #     print('name, param.mean(), param.std():', name, param.mean(), param.std())
        #     if param.requires_grad:
        #         print(name, param.grad.norm())
 
+        # ==================== [Added] Monitor gradient norms ====================
+        # Monitor gradient norms before clipping, to diagnose exploding or vanishing gradients.
+        if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)):
+            grad_norm_metrics = self._monitor_gradient_norms()
+            norm_log_dict.update(grad_norm_metrics)
+        # =================================================================
+
        if self._cfg.analysis_sim_norm:
            del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
            self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
@@ -705,33 +1418,29 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[s
        total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(),
                                                                         self._cfg.grad_clip_value)
 
-        # if self._cfg.multi_gpu:
-        #     # Very important to sync gradients before updating the model
-        #     # rank = get_rank()
-        #     # print(f'Rank {rank} train task_id: {self._cfg.task_id} sync grad begin...')
-        #     self.sync_gradients(self._learn_model)
-        #     # print(f'Rank {rank} train task_id: {self._cfg.task_id} sync grad end...')
+        if ignore_grad:
+            # NOTE: For cases where all tasks on a GPU are solved, `train` is still called for DDP synchronization,
+            # but gradients should be zeroed out to prevent updates.
+            self._optimizer_world_model.zero_grad()
 
        if self._cfg.multi_gpu:
-            # if not self._cfg.use_moco or self._cfg.only_use_moco_stats:
-            #     self.sync_gradients(self._learn_model)
+            # If not using a gradient correction method that handles it, sync gradients manually.
            if not self._cfg.use_moco:
                self.sync_gradients(self._learn_model)
 
-        # print("=== Step 前,参数梯度详细信息 ===")
-        # for idx, param in enumerate(self.grad_correct.share_model.parameters()):
-        #     if param.grad is not None:
-        #         print(f"Param[{idx}] - device: {param.device}, dtype: {param.dtype}, "
-        #               f"grad device: {param.grad.device}, grad dtype: {param.grad.dtype}")
-        #     else:
-        #         print(f"Param[{idx}] 没有梯度!")
-
        self._optimizer_world_model.step()
 
+        # 4. Update the alpha optimizer.
+        if self.use_adaptive_entropy_weight:
+            self.alpha_optimizer.step()
+            # Clamp log_alpha for numerical stability.
+            with torch.no_grad():
+                self.log_alpha.clamp_(np.log(1e-4), np.log(10.0))
+
        if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler:
            self.lr_scheduler.step()
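+        # A minimal single-task sketch of the adaptive entropy-weight (alpha) update performed above,
+        # using illustrative names (log_alpha is a learnable scalar with its own alpha_optimizer):
+        #   target_entropy = np.log(action_space_size) * target_entropy_ratio   # annealed from start to end ratio
+        #   alpha_loss = (log_alpha * (policy_entropy.detach() - target_entropy)).mean()
+        #   alpha_optimizer.zero_grad(); alpha_loss.backward(); alpha_optimizer.step()
+        #   with torch.no_grad():
+        #       log_alpha.clamp_(np.log(1e-4), np.log(10.0))
+        #   alpha = log_alpha.exp().detach()   # weight applied to the policy-entropy bonus in the total loss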
-        # Core target model update step
+        # Core target model update step.
        self._target_model.update(self._learn_model.state_dict())
 
        if torch.cuda.is_available():
@@ -744,37 +1453,32 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[s
            current_memory_allocated_gb = 0.
            max_memory_allocated_gb = 0.
 
-        # 然后,在您的代码中,使用这个函数来构建损失字典:
-        return_loss_dict = {
+        # Build the dictionary of return values for logging.
+        return_log_dict = {
            'Current_GPU': current_memory_allocated_gb,
            'Max_GPU': max_memory_allocated_gb,
            'collect_mcts_temperature': self._collect_mcts_temperature,
            'collect_epsilon': self._collect_epsilon,
            'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'],
            'weighted_total_loss': weighted_total_loss.item(),
-            # 'policy_entropy': policy_entropy,
-            # 'target_policy_entropy': average_target_policy_entropy,
            'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(),
        }
 
-        # 生成任务相关的损失字典,并为每个任务相关的 loss 添加前缀 "noreduce_"
+        # ==================== START: add new log entries ====================
+        if self.use_adaptive_entropy_weight:
+            return_log_dict['adaptive_alpha'] = current_alpha.item()
+            return_log_dict['adaptive_target_entropy_ratio'] = current_ratio
+            return_log_dict['final_alpha_loss'] = final_alpha_loss.item()
+        # ===================== END: add new log entries =====================
+
+        # Generate task-related loss dictionaries and prefix each task-related loss with "noreduce_".
        multi_task_loss_dicts = {
            **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id),
            **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', task_id=self.task_id),
            **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', task_id=self.task_id),
            **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', task_id=self.task_id),
            **generate_task_loss_dict(dormant_ratio_head_multi_task, 'noreduce_dormant_ratio_head_task{}', task_id=self.task_id),
-            # 关于网络可塑性的指标
-            **generate_task_loss_dict(dormant_ratio_encoder_multi_task, 'noreduce_dormant_ratio_encoder_task{}', task_id=self.task_id),
-            **generate_task_loss_dict(dormant_ratio_transformer_multi_task, 'noreduce_dormant_ratio_transformer_task{}', task_id=self.task_id),
-            **generate_task_loss_dict(dormant_ratio_head_multi_task, 'noreduce_dormant_ratio_head_task{}', task_id=self.task_id),
-            **generate_task_loss_dict(avg_weight_mag_encoder_multi_task, 'noreduce_avg_weight_mag_encoder_task{}', task_id=self.task_id),
-            **generate_task_loss_dict(avg_weight_mag_transformer_multi_task, 'noreduce_avg_weight_mag_transformer_task{}', task_id=self.task_id),
-            **generate_task_loss_dict(avg_weight_mag_head_multi_task, 'noreduce_avg_weight_mag_head_task{}', task_id=self.task_id),
-            **generate_task_loss_dict(e_rank_last_linear_multi_task, 'noreduce_e_rank_last_linear_task{}', task_id=self.task_id),
-            **generate_task_loss_dict(e_rank_sim_norm_multi_task, 'noreduce_e_rank_sim_norm_task{}', task_id=self.task_id),
-
            **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id),
            **generate_task_loss_dict(orig_policy_loss_multi_task, 'noreduce_orig_policy_loss_task{}', task_id=self.task_id),
            **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id),
@@ -784,15 +1488,45 @@ def _forward_learn(self, data: Tuple[torch.Tensor],
task_weights=None) -> Dict[s **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', task_id=self.task_id), **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id), **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id), - } - # 合并两个字典 - return_loss_dict.update(multi_task_loss_dicts) - # print(f'return_loss_dict:{return_loss_dict}') - # 返回最终的损失字典 - return return_loss_dict - - def monitor_weights_and_grads(self, model): + # 新增alpha相关日志 + **generate_task_loss_dict(alpha_loss_multi_task, 'noreduce_alpha_loss_task{}', self.task_id), + **generate_task_loss_dict(target_entropy_multi_task, 'noreduce_target_entropy_task{}', self.task_id), + } + return_log_dict.update(multi_task_loss_dicts) + + + if self._learn_model.world_model.do_analysis: + # Include plasticity metrics if analysis is enabled. + plasticity_loss_dicts = { + **generate_task_loss_dict(dormant_ratio_encoder_multi_task, 'noreduce_dormant_ratio_encoder_task{}', task_id=self.task_id), + **generate_task_loss_dict(dormant_ratio_transformer_multi_task, 'noreduce_dormant_ratio_transformer_task{}', task_id=self.task_id), + **generate_task_loss_dict(dormant_ratio_head_multi_task, 'noreduce_dormant_ratio_head_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_encoder_multi_task, 'noreduce_avg_weight_mag_encoder_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_transformer_multi_task, 'noreduce_avg_weight_mag_transformer_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_head_multi_task, 'noreduce_avg_weight_mag_head_task{}', task_id=self.task_id), + **generate_task_loss_dict(e_rank_last_linear_multi_task, 'noreduce_e_rank_last_linear_task{}', task_id=self.task_id), + **generate_task_loss_dict(e_rank_sim_norm_multi_task, 'noreduce_e_rank_sim_norm_task{}', task_id=self.task_id), + } + # Merge the dictionaries. + return_log_dict.update(plasticity_loss_dicts) + + # ==================== [修改] 将范数监控结果合并到日志中 ==================== + if norm_log_dict: + return_log_dict.update(norm_log_dict) + # ======================================================================= + + # Return the final loss dictionary. + return return_log_dict + + def monitor_weights_and_grads(self, model: torch.nn.Module) -> None: + """ + Overview: + A utility function to print the mean and standard deviation of weights and their gradients for each layer in a model. + Useful for debugging training issues like exploding or vanishing gradients. + Arguments: + - model (:obj:`torch.nn.Module`): The model to monitor. + """ for name, param in model.named_parameters(): if param.requires_grad: print(f"Layer: {name} | " @@ -804,14 +1538,20 @@ def monitor_weights_and_grads(self, model): def _init_collect(self) -> None: """ Overview: - Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + Initializes the collect mode. This method is called by ``self.__init__``. + It sets up the collect model and MCTS utilities for data collection. """ self._collect_model = self._model + # Create a copy of the configuration for collect MCTS and set a specific number of simulations. 
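+        # For example, a config might set collect_num_simulations=25 and eval_num_simulations=50 so that
+        # collection stays cheap while evaluation searches more deeply (the exact values are experiment-specific).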
+ mcts_collect_cfg = copy.deepcopy(self._cfg) + mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations + if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(self._cfg) + self._mcts_collect = MCTSCtree(mcts_collect_cfg) else: - self._mcts_collect = MCTSPtree(self._cfg) + self._mcts_collect = MCTSPtree(mcts_collect_cfg) + self._collect_mcts_temperature = 1. self._collect_epsilon = 0.0 self.collector_env_num = self._cfg.collector_env_num @@ -822,15 +1562,18 @@ def _init_collect(self) -> None: self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) self.last_batch_action = [-1 for i in range(self.collector_env_num)] - # TODO: num_tasks - def _monitor_vars_learn(self, num_tasks=2) -> List[str]: + # TODO: The num_tasks parameter is hardcoded. It should ideally be derived from the config. + def _monitor_vars_learn(self, num_tasks: int = 2) -> List[str]: """ Overview: - Register the variables to be monitored in learn mode. The registered variables will be logged in - tensorboard according to the return value ``_forward_learn``. - If num_tasks is provided, generate monitored variables for each task. + Registers variables to be monitored during training. These variables will be logged in TensorBoard. + It dynamically creates variable names for each task if `num_tasks` is provided. + Arguments: + - num_tasks (:obj:`int`): The number of tasks being trained on the current rank. + Returns: + - monitored_vars (:obj:`List[str]`): A list of strings, where each string is the name of a variable to be logged. """ - # Basic monitored variables that do not depend on the number of tasks + # Basic monitored variables that do not depend on the number of tasks. monitored_vars = [ 'Current_GPU', 'Max_GPU', @@ -839,9 +1582,63 @@ def _monitor_vars_learn(self, num_tasks=2) -> List[str]: 'cur_lr_world_model', 'weighted_total_loss', 'total_grad_norm_before_clip_wm', + + # 'value_priority', + 'adaptive_alpha', + "adaptive_target_entropy_ratio", + 'final_alpha_loss', ] - # rank = get_rank() + # ==================== [新增] 范数和中间张量监控变量 ==================== + # 这些变量对所有任务是共享的(不是per-task的) + norm_vars = [ + # 模块总范数 (参数范数) - 共享模块 + 'norm/encoder/_total_norm', + 'norm/transformer/_total_norm', + + # 模块总范数 (梯度范数) - 共享模块 + 'grad/encoder/_total_norm', + 'grad/transformer/_total_norm', + + # 中间张量 x (Transformer输出) 的统计信息 + 'norm/x_token/mean', + 'norm/x_token/std', + 'norm/x_token/max', + 'norm/x_token/min', + + # Logits 的详细统计 (Value) + 'logits/value/mean', + 'logits/value/std', + 'logits/value/max', + 'logits/value/min', + 'logits/value/abs_max', + + # Logits 的详细统计 (Policy) + 'logits/policy/mean', + 'logits/policy/std', + 'logits/policy/max', + 'logits/policy/min', + 'logits/policy/abs_max', + + # Logits 的详细统计 (Reward) + 'logits/reward/mean', + 'logits/reward/std', + 'logits/reward/max', + 'logits/reward/min', + 'logits/reward/abs_max', + + # Embeddings 的统计信息 + 'embeddings/obs/norm_mean', + 'embeddings/obs/norm_std', + 'embeddings/obs/norm_max', + 'embeddings/obs/norm_min', + ] + monitored_vars.extend(norm_vars) + # ======================================================================== + + + + # Task-specific variables to be monitored. task_specific_vars = [ 'noreduce_obs_loss', 'noreduce_orig_policy_loss', @@ -855,7 +1652,7 @@ def _monitor_vars_learn(self, num_tasks=2) -> List[str]: 'noreduce_latent_state_l2_norms', 'noreduce_lambd', 'noreduce_value_priority_mean', - # 关于网络可塑性的指标 + # Metrics related to network plasticity. 
'noreduce_dormant_ratio_encoder', 'noreduce_dormant_ratio_transformer', 'noreduce_dormant_ratio_head', @@ -863,19 +1660,21 @@ def _monitor_vars_learn(self, num_tasks=2) -> List[str]: 'noreduce_avg_weight_mag_transformer', 'noreduce_avg_weight_mag_head', 'noreduce_e_rank_last_linear', - 'noreduce_e_rank_sim_norm' + 'noreduce_e_rank_sim_norm', + "noreduce_alpha_loss", + "noreduce_target_entropy", ] - # self.task_num_for_current_rank 作为当前rank的base_index + + # Use self.task_num_for_current_rank as the number of tasks for the current rank. num_tasks = self.task_num_for_current_rank - # If the number of tasks is provided, extend the monitored variables list with task-specific variables + # If the number of tasks is provided, extend the monitored variables list with task-specific variable names. if num_tasks is not None: for var in task_specific_vars: for task_idx in range(num_tasks): - # print(f"learner policy Rank {rank}, self.task_id: {self.task_id}") monitored_vars.append(f'{var}_task{self.task_id+task_idx}') else: - # If num_tasks is not provided, we assume there's only one task and keep the original variable names + # If num_tasks is not provided, assume a single task and use the original variable names. monitored_vars.extend(task_specific_vars) return monitored_vars @@ -889,30 +1688,25 @@ def _forward_collect( to_play: List = [-1], epsilon: float = 0.25, ready_env_id: np.array = None, + timestep: List = [0], task_id: int = None, ) -> Dict: """ Overview: - The forward function for collecting data in collect mode. Use model to execute MCTS search. - Choosing the action through sampling during the collect mode. + The forward function for collecting data. It uses the model to perform MCTS search and + selects actions via sampling to encourage exploration. Arguments: - - data (:obj:`torch.Tensor`): The input data, i.e. the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - - temperature (:obj:`float`): The temperature of the policy. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - temperature: :math:`(1, )`. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None + - data (:obj:`torch.Tensor`): The input data, i.e., the current observation. + - action_mask (:obj:`list`, optional): A list of action masks for each environment. + - temperature (:obj:`float`, optional): The temperature for MCTS action selection. + - to_play (:obj:`List`, optional): A list of player IDs for each environment. + - epsilon (:obj:`float`, optional): The probability for epsilon-greedy exploration. + - ready_env_id (:obj:`np.array`, optional): An array of IDs for environments that are ready for a new action. + - timestep (:obj:`List`, optional): The current timestep in each environment. + - task_id (:obj:`int`, optional): The ID of the task for the current environments. 
Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + - output (:obj:`Dict`): A dictionary where keys are environment IDs and values are dictionaries + containing the selected action and other MCTS statistics. """ self._collect_model.eval() @@ -927,36 +1721,51 @@ def _forward_collect( network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() + + # ========================== 核心修复 ========================== + # C++ 绑定需要一个 list,即使它在 MuZero 中代表奖励。 + reward_roots = reward_roots.detach().cpu().numpy().tolist() + # =============================================================== + policy_logits = policy_logits.detach().cpu().numpy().tolist() legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] - # the only difference between collect and eval is the dirichlet noise + # The main difference between collect and eval is the addition of Dirichlet noise at the root. noises = [ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) ).astype(np.float32).tolist() for j in range(active_collect_env_num) ] if self._cfg.mcts_ctree: - # cpp mcts_tree + # C++ MCTS tree implementation. roots = MCTSCtree.roots(active_collect_env_num, legal_actions) else: - # python mcts_tree + # Python MCTS tree implementation. roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + # # 在本文件开始,通过全局变量来控制是否处于调试状态 + # global DEBUG_ENABLED;DEBUG_ENABLED = True + # import torch.distributed as dist + # if dist.get_rank() == 0 and DEBUG_ENABLED: + # print(f"rank {dist.get_rank()} 进入调试模式,输入interact,可以键入整段的python代码调试。通过设置 DEBUG_ENABLED = False, 可以跳过调试状态") + # import ipdb; ipdb.set_trace() + # # 同步点,防止其它进程早跑 + # dist.barrier() + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) - self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, task_id=task_id) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep= timestep, task_id=task_id) - # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` roots_visit_count_distributions = roots.get_distributions() - roots_values = roots.get_values() # shape: {list: batch_size} + roots_values = roots.get_values() batch_action = [] for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] if self._cfg.eps.eps_greedy_exploration_in_collect: - # eps greedy collect + # Epsilon-greedy collection strategy. 
action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( distributions, temperature=self._collect_mcts_temperature, deterministic=True ) @@ -964,21 +1773,21 @@ def _forward_collect( if np.random.rand() < self._collect_epsilon: action = np.random.choice(legal_actions[i]) else: - # normal collect - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. + # Standard collection strategy (sampling from MCTS policy). + # NOTE: `action_index_in_legal_action_set` is the index within the set of legal actions. action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( distributions, temperature=self._collect_mcts_temperature, deterministic=False ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. + # Convert the index back to the action in the full action space. action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - # ============== TODO: only for visualize ============== + # ============== TODO: This section is for visualization purposes only and should be removed for training. ============== + # It forces deterministic action selection during collection. # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( # distributions, temperature=self._collect_mcts_temperature, deterministic=True # ) # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - # ============== TODO: only for visualize ============== + # ============== End of visualization section. ============== output[env_id] = { 'action': action, @@ -993,10 +1802,12 @@ def _forward_collect( self.last_batch_obs = data self.last_batch_action = batch_action - # ========= TODO: for muzero_segment_collector now ========= + # ========= TODO: This logic is currently for the `muzero_segment_collector`. ========= if active_collect_env_num < self.collector_env_num: - # 当collect_env中有一个环境先done时,传回的self.last_batch_obs的长度会减少1, transformer在检索kv_cache时需要知道env_id,实现比较复杂 - # 因此直接《self.collector_env_num》个环境的self.last_batch_action全部重置为-1,让transformer从0开始,避免检索错误 + # When one environment in `collect_env` finishes early, the length of `self.last_batch_obs` is reduced. + # The transformer needs the `env_id` to retrieve from the KV cache, which is complex to manage with a dynamic batch size. + # Therefore, we reset `self.last_batch_action` for all environments to -1, forcing the transformer + # to start from scratch and avoid retrieval errors. print('==========collect_forward============') print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') self._reset_collect(reset_init_data=True, task_id=task_id) @@ -1008,13 +1819,20 @@ def _forward_collect( def _init_eval(self) -> None: """ Overview: - Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + Initializes the eval mode. This method is called by ``self.__init__``. + It sets up the eval model and MCTS utilities for evaluation. """ self._eval_model = self._model + + # Create a copy of the configuration for eval MCTS and set a specific number of simulations. 
+ mcts_eval_cfg = copy.deepcopy(self._cfg) + mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations + if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(self._cfg) + self._mcts_eval = MCTSCtree(mcts_eval_cfg) else: - self._mcts_eval = MCTSPtree(self._cfg) + self._mcts_eval = MCTSPtree(mcts_eval_cfg) + self.evaluator_env_num = self._cfg.evaluator_env_num if self._cfg.model.model_type == 'conv': @@ -1026,27 +1844,21 @@ def _init_eval(self) -> None: #@profile def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, - ready_env_id: np.array = None, task_id: int = None) -> Dict: + ready_env_id: np.array = None, timestep: List = [0], task_id: int = None) -> Dict: """ Overview: - The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. - Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + The forward function for evaluating the policy. It uses the model to perform MCTS search and + selects actions deterministically (choosing the one with the highest visit count). Arguments: - - data (:obj:`torch.Tensor`): The input data, i.e. the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None + - data (:obj:`torch.Tensor`): The input data, i.e., the current observation. + - action_mask (:obj:`list`): A list of action masks for each environment. + - to_play (:obj:`int`, optional): The player ID for the current turn. + - ready_env_id (:obj:`np.array`, optional): An array of IDs for environments that are ready for a new action. + - timestep (:obj:`List`, optional): The current timestep in each environment. + - task_id (:obj:`int`, optional): The ID of the task for the current environments. Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + - output (:obj:`Dict`): A dictionary where keys are environment IDs and values are dictionaries + containing the selected action and other MCTS statistics. 
""" self._eval_model.eval() active_eval_env_num = data.shape[0] @@ -1057,41 +1869,42 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data, task_id=task_id) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - # if not self._eval_model.training: - # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() - policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + # ========================== 核心修复 ========================== + # C++ 绑定需要一个 list,即使它在 MuZero 中代表奖励。 + reward_roots = reward_roots.detach().cpu().numpy().tolist() # TODO============================= + # =============================================================== + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] if self._cfg.mcts_ctree: - # cpp mcts_tree + # C++ MCTS tree implementation. roots = MCTSCtree.roots(active_eval_env_num, legal_actions) else: - # python mcts_tree + # Python MCTS tree implementation. roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + + # During evaluation, no noise is added to the root policy. roots.prepare_no_noise(reward_roots, policy_logits, to_play) - self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, task_id=task_id) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep= timestep, task_id=task_id) - # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` roots_visit_count_distributions = roots.get_distributions() - roots_values = roots.get_values() # shape: {list: batch_size} + roots_values = roots.get_values() batch_action = [] for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] - # print("roots_visit_count_distributions:", distributions, "root_value:", value) - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than - # sampling during the evaluation phase. + # NOTE: `deterministic=True` means we select the action with the highest visit count (argmax) + # rather than sampling, which is standard for evaluation. action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( distributions, temperature=1, deterministic=True ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the - # entire action set. + # Convert the index back to the action in the full action space. 
action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] output[env_id] = { @@ -1113,14 +1926,13 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: """ Overview: - This method resets the collection process for a specific environment. It clears caches and memory - when certain conditions are met, ensuring optimal performance. If reset_init_data is True, the initial data - will be reset. + Resets the collection process for a specific environment or all environments. + It can clear caches and reset initial data to ensure optimal performance and prevent state leakage. Arguments: - - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately. - - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine - whether to clear caches. - - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None, the reset applies more broadly. Defaults to None. + - current_steps (:obj:`int`, optional): The current step count in the environment, used to trigger periodic cache clearing. Defaults to 0. + - reset_init_data (:obj:`bool`, optional): If True, resets the initial observation and action buffers. Defaults to True. + - task_id (:obj:`int`, optional): The task ID, currently unused in this method. Defaults to None. """ if reset_init_data: self.last_batch_obs = initialize_zeros_batch( @@ -1129,122 +1941,155 @@ def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_ self._cfg.device ) self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] - # print('collector: last_batch_obs, last_batch_action reset()', self.last_batch_obs.shape) + # print('Collector: last_batch_obs and last_batch_action have been reset.') + + # Return immediately if env_id is not a single integer (e.g., None or a list). + # if env_id is None or isinstance(env_id, list): + # return + + # We must handle both single int and list of ints for env_id. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the collector. + if current_steps is None: + world_model = self._collect_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() - # Return immediately if env_id is None or a list - if env_id is None or isinstance(env_id, list): - return + print(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end.') - # Determine the clear interval based on the environment's sample type - clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 - # Clear caches if the current steps are a multiple of the clear interval - if current_steps % clear_interval == 0: + # Determine the clear interval based on the environment's sample type. 
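+        # Clearing less frequently retains cached keys/values for longer (fewer cache rebuilds) at the
+        # cost of GPU memory; episode-based sampling uses the larger interval below.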
+ # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length + + # Clear caches periodically to manage memory. + # if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') - # Clear various caches in the collect model's world model + # Clear various KV caches in the collect model's world model. world_model = self._collect_model.world_model - world_model.past_kv_cache_init_infer.clear() for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: kv_cache_dict_env.clear() world_model.past_kv_cache_recurrent_infer.clear() world_model.keys_values_wm_list.clear() - # Free up GPU memory + # Free up unused GPU memory. torch.cuda.empty_cache() - print('collector: collect_model clear()') - print(f'eps_steps_lst[{env_id}]: {current_steps}') + print(f'Collector: Caches cleared for collect_model at step {current_steps} for env {env_id}.') - # TODO: check its correctness ========= + # TODO: Check if resetting the target model here is correct and necessary. self._reset_target_model() #@profile def _reset_target_model(self) -> None: """ Overview: - This method resets the target model. It clears caches and memory, ensuring optimal performance. - Arguments: - - None + Resets the target model by clearing its internal caches. This is crucial for managing memory, + especially when using transformer-based models with KV caching. """ - - # Clear various caches in the target_model + # Clear various KV caches in the target model's world model. world_model = self._target_model.world_model - world_model.past_kv_cache_init_infer.clear() for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: kv_cache_dict_env.clear() world_model.past_kv_cache_recurrent_infer.clear() world_model.keys_values_wm_list.clear() - # Free up GPU memory + # Free up unused GPU memory. torch.cuda.empty_cache() - print('collector: target_model past_kv_cache.clear()') + print('Collector: Target model past_kv_cache cleared.') #@profile def _reset_eval(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: """ Overview: - This method resets the evaluation process for a specific environment. It clears caches and memory - when certain conditions are met, ensuring optimal performance. If reset_init_data is True, - the initial data will be reset. + Resets the evaluation process for a specific environment or all environments. + Clears caches and resets initial data to ensure clean evaluation runs. Arguments: - - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately. - - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine - whether to clear caches. - - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. + - env_id (:obj:`int`, optional): The ID of the environment to reset. Defaults to None. + - current_steps (:obj:`int`, optional): The current step count, used for periodic cache clearing. Defaults to 0. + - reset_init_data (:obj:`bool`, optional): If True, resets the initial observation and action buffers. Defaults to True. + - task_id (:obj:`int`, optional): The task ID. Can be used to handle different observation shapes per task. Defaults to None. 
""" if reset_init_data: - # if task_id is not None: - # self.last_batch_obs_eval = initialize_zeros_batch( - # self._cfg.model.observation_shape_list[task_id], - # self._cfg.evaluator_env_num, - # self._cfg.device - # ) - # print('unizero_multitask.py task_id is not None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) - - # else: self.last_batch_obs_eval = initialize_zeros_batch( self._cfg.model.observation_shape, self._cfg.evaluator_env_num, self._cfg.device ) - print('unizero_multitask.py task_id is None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + # print(f'Evaluator reset: last_batch_obs_eval shape: {self.last_batch_obs_eval.shape}') self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] - # Return immediately if env_id is None or a list - if env_id is None or isinstance(env_id, list): - return + # --- BEGIN ROBUST FIX --- + # This logic handles the crucial end-of-episode cache clearing for evaluation. + # The evaluator calls `_policy.reset([env_id])` when an episode is done. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the evaluator. + if current_steps is None: + world_model = self._eval_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end.') + + # The recurrent cache is global. + world_model.past_kv_cache_recurrent_infer.clear() + + if hasattr(world_model, 'keys_values_wm_list'): + world_model.keys_values_wm_list.clear() + + torch.cuda.empty_cache() + return + # --- END ROBUST FIX --- + + # Determine the clear interval. + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length - # Determine the clear interval based on the environment's sample type - clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + # Clear caches periodically. + # if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: - # Clear caches if the current steps are a multiple of the clear interval - if current_steps % clear_interval == 0: print(f'clear_interval: {clear_interval}') - # Clear various caches in the eval model's world model + # Clear various KV caches in the eval model's world model. world_model = self._eval_model.world_model - # world_model.past_kv_cache_init_infer.clear() for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: kv_cache_dict_env.clear() world_model.past_kv_cache_recurrent_infer.clear() world_model.keys_values_wm_list.clear() - # Free up GPU memory + # Free up unused GPU memory. torch.cuda.empty_cache() - print('evaluator: eval_model clear()') - print(f'eps_steps_lst[{env_id}]: {current_steps}') + print(f'Evaluator: Caches cleared for eval_model at step {current_steps} for env {env_id}.') def recompute_pos_emb_diff_and_clear_cache(self) -> None: """ Overview: - Clear the caches and precompute positional embedding matrices in the model. + Clears all KV caches and precomputes positional embedding matrices in the model. 
+ This is typically called when the maximum sequence length changes. """ - # NOTE: Clear caches and precompute positional embedding matrices both for the collect and target models + # NOTE: This must be done for both the collect and target models. for model in [self._collect_model, self._target_model]: model.world_model.precompute_pos_emb_diff_kv() model.world_model.clear_caches() @@ -1253,9 +2098,11 @@ def recompute_pos_emb_diff_and_clear_cache(self) -> None: def _state_dict_learn(self) -> Dict[str, Any]: """ Overview: - Return the state_dict of learn mode, usually including model, target_model and optimizer. + Returns the state dictionary of the learn mode. + This typically includes the model, target model, and optimizer states, + which are necessary for saving and resuming training. Returns: - - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + - state_dict (:obj:`Dict[str, Any]`): The state dictionary for the current learning progress. """ return { 'model': self._learn_model.state_dict(), @@ -1263,34 +2110,35 @@ def _state_dict_learn(self) -> Dict[str, Any]: 'optimizer_world_model': self._optimizer_world_model.state_dict(), } - # ========== TODO: original version: load all parameters ========== + # ========== NOTE: This is the original version which loads all parameters from the state_dict. ========== # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: # """ # Overview: - # Load the state_dict variable into policy learn mode. + # Loads the state_dict into the policy's learn mode. # Arguments: - # - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + # - state_dict (:obj:`Dict[str, Any]`): The state dictionary saved from a previous training session. # """ # self._learn_model.load_state_dict(state_dict['model']) # self._target_model.load_state_dict(state_dict['target_model']) # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) - # ========== TODO: pretrain-finetue version: only load encoder and transformer-backbone parameters ========== + # ========== NOTE: This is a pretrain-finetune version that selectively loads parameters and freezes layers. ========== def _load_state_dict_learn(self, state_dict: Dict[str, Any], finetune_components: List[str] = []) -> None: """ Overview: - Load the state_dict variable into policy learn mode, excluding multi-task related parameters. - 根据 finetune_components 参数,决定加载 encoder 和 transformer 后,哪些部分参与后续更新,哪些被冻结。 + Loads a state_dict for fine-tuning. It excludes multi-task specific parameters + and can freeze parts of the model (e.g., encoder, transformer) based on `finetune_components`. Arguments: - - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously. - - finetune_components (:obj:`List[str]`, optional): A list of component names that will remain trainable after loading. - For example, it can include "encoder", "transformer", or both. The components not in this list will be frozen. + - state_dict (:obj:`Dict[str, Any]`): The state dictionary from a pre-trained model. + - finetune_components (:obj:`List[str]`, optional): A list of component names (e.g., "encoder", "transformer") + that will remain trainable. Components not in this list will have their parameters frozen. 
""" - # finetune_components = [] # load-enc-trans_finetune-head - # finetune_components = ['transformer'] # load-enc-trans_finetune-trans-head - finetune_components = ["representation_network", "encoder"] # load-enc-trans_finetune-encoder-head + # Example configurations for fine-tuning: + # finetune_components = [] # Loads encoder & transformer, fine-tunes only heads. + # finetune_components = ['transformer'] # Loads encoder & transformer, fine-tunes transformer & heads. + finetune_components = ["representation_network", "encoder"] # Loads encoder & transformer, fine-tunes encoder & heads. - # 定义需要排除的参数前缀,即不加载这些参数 + # Define prefixes of parameters to be excluded from loading (typically multi-task heads). exclude_prefixes = [ '_orig_mod.world_model.head_policy_multi_task.', '_orig_mod.world_model.head_value_multi_task.', @@ -1299,29 +2147,28 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any], finetune_components '_orig_mod.world_model.task_emb.' ] - # 定义需要排除的具体参数(如果有特殊情况) + # Define specific parameter keys to be excluded (for special cases like task embeddings). exclude_keys = [ '_orig_mod.world_model.task_emb.weight', - '_orig_mod.world_model.task_emb.bias', # 如果存在则添加 - # 添加其他需要排除的具体参数名 + '_orig_mod.world_model.task_emb.bias', ] def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: """ - 过滤掉需要排除的参数。 + Filters out parameters from a state_dict based on prefixes and specific keys. """ filtered = {} for k, v in state_dict_loader.items(): if any(k.startswith(prefix) for prefix in exclude_prefixes): - print(f"Excluding parameter: {k}") # 调试用,查看哪些参数被排除 + print(f"Excluding parameter: {k}") # For debugging continue if k in exclude_keys: - print(f"Excluding specific parameter: {k}") # 调试用 + print(f"Excluding specific parameter: {k}") # For debugging continue filtered[k] = v return filtered - # 过滤并加载 'model' 部分 + # Filter and load the 'model' state_dict. if 'model' in state_dict: model_state_dict = state_dict['model'] filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) @@ -1333,7 +2180,7 @@ def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, else: print("No 'model' key found in the state_dict.") - # 过滤并加载 'target_model' 部分 + # Filter and load the 'target_model' state_dict. if 'target_model' in state_dict: target_model_state_dict = state_dict['target_model'] filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) @@ -1345,41 +2192,42 @@ def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, else: print("No 'target_model' key found in the state_dict.") - # 对 _learn_model 中的参数进行冻结/解冻的处理 - # 假设模型中参数的名字如果包含 "encoder" 则属于 encoder 模块, - # 包含 "transformer" 则属于 transformer 模块,其它部分可根据需要扩展。 + # Handle freezing/unfreezing of parameters in _learn_model based on finetune_components. + # This assumes a naming convention where component names are present in parameter names. for name, param in self._learn_model.named_parameters(): - # 如果参数属于 encoder 且不在需要微调的组件中,则冻结该参数 + # Freeze the encoder if "encoder" is not in finetune_components. if "encoder" in name and "encoder" not in finetune_components: param.requires_grad = False print(f"Freezing parameter: {name}") + # Freeze the representation network if "representation_network" is not in finetune_components. 
elif "representation_network" in name and "representation_network" not in finetune_components: param.requires_grad = False print(f"Freezing parameter: {name}") - # 如果参数属于 transformer 且不在需要微调的组件中,则冻结该参数 + # Freeze the transformer if "transformer" is not in finetune_components. elif "transformer" in name and "transformer" not in finetune_components: param.requires_grad = False print(f"Freezing parameter: {name}") else: - # 如果参数属于其他模块,或者包含在 finetune_components 中,则保持默认(或者根据需要调整) - print(f"Parameter remains default: {name}") + # Other parameters remain trainable by default. + print(f"Parameter remains trainable: {name}") - # 注意: - # 如果你的模型中嵌套模块更为复杂,可以基于 module 的属性而不是仅仅依靠参数名称进行判断,比如: + # NOTE: For more complex model structures, it might be better to identify modules by their class + # rather than relying on parameter names. For example: # for module in self._learn_model.modules(): # if isinstance(module, EncoderModule) and "encoder" not in finetune_components: # for param in module.parameters(): # param.requires_grad = False - # # ========== TODO: pretrain-finetue version: only load encoder and transformer-backbone parameters ========== + # ========== NOTE: Another pretrain-finetune version. The main difference from the above is the freezing logic and comments. ========== # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: # """ # Overview: - # Load the state_dict variable into policy learn mode, excluding multi-task related parameters. + # Loads a state_dict into the policy's learn mode, excluding multi-task related parameters. + # This is intended for fine-tuning a pre-trained model on new tasks. # Arguments: - # - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously. + # - state_dict (:obj:`Dict[str, Any]`): The state dictionary from a pre-trained model. # """ - # # 定义需要排除的参数前缀 + # # Define prefixes of parameters to be excluded. # exclude_prefixes = [ # '_orig_mod.world_model.head_policy_multi_task.', # '_orig_mod.world_model.head_value_multi_task.', @@ -1388,29 +2236,28 @@ def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, # '_orig_mod.world_model.task_emb.' # ] - # # 定义需要排除的具体参数(如果有特殊情况) + # # Define specific parameter keys to be excluded. # exclude_keys = [ # '_orig_mod.world_model.task_emb.weight', - # '_orig_mod.world_model.task_emb.bias', # 如果存在则添加 - # # 添加其他需要排除的具体参数名 + # '_orig_mod.world_model.task_emb.bias', # ] # def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: # """ - # 过滤掉需要排除的参数。 + # Filters out parameters that should not be loaded. # """ # filtered = {} # for k, v in state_dict_loader.items(): # if any(k.startswith(prefix) for prefix in exclude_prefixes): - # print(f"Excluding parameter: {k}") # 调试用,查看哪些参数被排除 + # print(f"Excluding parameter: {k}") # continue # if k in exclude_keys: - # print(f"Excluding specific parameter: {k}") # 调试用 + # print(f"Excluding specific parameter: {k}") # continue # filtered[k] = v # return filtered - # # 过滤并加载 'model' 部分 + # # Filter and load the 'model' part. # if 'model' in state_dict: # model_state_dict = state_dict['model'] # filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) @@ -1422,7 +2269,7 @@ def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, # else: # print("No 'model' key found in the state_dict.") - # # 过滤并加载 'target_model' 部分 + # # Filter and load the 'target_model' part. 
# if 'target_model' in state_dict: # target_model_state_dict = state_dict['target_model'] # filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) @@ -1434,12 +2281,8 @@ def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, # else: # print("No 'target_model' key found in the state_dict.") - # # 不要加载优化器的 state_dict,因为优化器通常不包含模型参数,加载后性能反而变差 + # # Do not load the optimizer's state_dict when fine-tuning, as it contains state (like momentum) + # # specific to the pre-training task, which can hinder adaptation to new tasks. + # # A fresh optimizer is usually preferred. # # if 'optimizer_world_model' in state_dict: - # # optimizer_state_dict = state_dict['optimizer_world_model'] - # # try: - # # self._optimizer_world_model.load_state_dict(optimizer_state_dict) - # # except Exception as e: - # # print(f"Error loading optimizer state_dict: {e}") - # # else: - # # print("No 'optimizer_world_model' key found in the state_dict.") + # # ... \ No newline at end of file diff --git a/lzero/policy/unizero_multitask_alpha_indep.py b/lzero/policy/unizero_multitask_alpha_indep.py new file mode 100644 index 000000000..db2b4c513 --- /dev/null +++ b/lzero/policy/unizero_multitask_alpha_indep.py @@ -0,0 +1,2000 @@ +import copy +from collections import defaultdict +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY + +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import UniZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import prepare_obs_stack_for_unizero +from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs +from lzero.policy.unizero import UniZeroPolicy, scale_module_weights_vectorized +from .utils import configure_optimizers_nanogpt +import sys + +# Please replace the path with the actual location of your LibMTL library. +sys.path.append('/path/to/your/LibMTL') + +from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect +from LibMTL.weighting.moco_fast_mem_eff import FastMoCoMemEff as FastMoCo +from LibMTL.weighting.moco_fast_mem_eff import MoCoCfg + +import torch.distributed as dist + +# ------------------------------------------------------------ +# 1. Add a dedicated process-group for the learner. +# (This function should be called once during the initialization of the main process or the learner.) +# ------------------------------------------------------------ +def build_learner_group(learner_ranks: list[int]) -> dist.ProcessGroup: + """ + Overview: + Builds and returns a new process group containing only the learner ranks. + This is used for methods like GenericMoCo that require collective communication + only among the ranks performing training. + Arguments: + - learner_ranks (:obj:`list[int]`): A list of world ranks that are designated as learners. + These are the ranks that will perform the backward pass. + e.g., if CUDA_VISIBLE_DEVICES=0,1, then learner_ranks=[0,1]. + Returns: + - pg (:obj:`dist.ProcessGroup`): A new process group containing only the learner ranks. 
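+    Examples:
+        >>> # A typical call, assuming two learner GPUs are visible (e.g. CUDA_VISIBLE_DEVICES=0,1)
+        >>> # and that torch.distributed has already been initialized:
+        >>> learner_pg = build_learner_group(learner_ranks=[0, 1])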
+ """ + world_pg = dist.group.WORLD + pg = dist.new_group(ranks=learner_ranks, backend='nccl') + if dist.get_rank() in learner_ranks: + torch.cuda.set_device(learner_ranks.index(dist.get_rank())) + return pg + + +def generate_task_loss_dict(multi_task_losses: List[Union[torch.Tensor, float]], task_name_template: str, task_id: int) -> Dict[str, float]: + """ + Overview: + Generates a dictionary for the losses of each task. + Arguments: + - multi_task_losses (:obj:`List[Union[torch.Tensor, float]]`): A list containing the loss for each task. + - task_name_template (:obj:`str`): The template for the task name, e.g., 'obs_loss_task{}'. + - task_id (:obj:`int`): The starting ID of the tasks. + Returns: + - task_loss_dict (:obj:`Dict[str, float]`): A dictionary where keys are formatted task names and values are the corresponding losses. + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + # Get the scalar value of the loss if it's a tensor. + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss + except Exception as e: + task_loss_dict[task_name] = task_loss + return task_loss_dict + +# # 修改后的函数: +# def generate_task_loss_dict( +# multi_task_losses: List[Union[torch.Tensor, float]], +# task_name_template: str, +# global_task_ids: List[int] +# ) -> Dict[str, float]: +# """ +# Overview: +# Generates a dictionary for the losses of each task using their explicit global IDs. +# Arguments: +# - multi_task_losses (:obj:`List[Union[torch.Tensor, float]]`): A list containing the loss for each task. +# - task_name_template (:obj:`str`): The template for the task name, e.g., 'obs_loss_task{}'. +# - global_task_ids (:obj:`List[int]`): A list of global task IDs corresponding to each loss in multi_task_losses. +# Returns: +# - task_loss_dict (:obj:`Dict[str, float]`): A dictionary where keys are formatted task names and values are the corresponding losses. +# """ +# task_loss_dict = {} +# # 使用 zip 将每个损失与其正确的全局ID配对 +# for task_loss, global_id in zip(multi_task_losses, global_task_ids): +# task_name = task_name_template.format(global_id) +# try: +# task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss +# except Exception as e: +# task_loss_dict[task_name] = task_loss +# return task_loss_dict + + +class WrappedModel: + """ + Overview: + A wrapper class for the world model to conveniently access its parameters and zero its gradients. + This version wraps the entire world model. + """ + def __init__(self, world_model: torch.nn.Module): + """ + Arguments: + - world_model (:obj:`torch.nn.Module`): The world model instance. + """ + self.world_model = world_model + + def parameters(self) -> iter: + """ + Overview: + Returns an iterator over the parameters of the entire world model. + """ + return self.world_model.parameters() + + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all world model parameters to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. + """ + self.world_model.zero_grad(set_to_none=set_to_none) + + +class WrappedModelV2: + """ + Overview: + A wrapper for specific components of the world model. + This version is designed to group parameters that are considered "shared" + across tasks for gradient correction methods like MoCo, excluding the prediction heads. 
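Before the wrapper classes continue, a quick usage sketch of the `generate_task_loss_dict` helper defined above (re-implemented inline so the snippet runs standalone); it only illustrates the `..._task{global_id}` key naming that the logging code in `_forward_learn` relies on. The starting global task id below is a made-up example.

```python
from typing import Dict, List, Union

import torch


def generate_task_loss_dict(losses: List[Union[torch.Tensor, float]], name_template: str, task_id: int) -> Dict[str, float]:
    # Mirrors the helper above: local task index i maps to global id task_id + i.
    return {
        name_template.format(task_id + i): (l.item() if hasattr(l, 'item') else l)
        for i, l in enumerate(losses)
    }


# Two tasks hosted on this rank, whose global task ids start at 4 (hypothetical).
obs_losses = [torch.tensor(0.31), 0.27]
print(generate_task_loss_dict(obs_losses, 'noreduce_obs_loss_task{}', task_id=4))
# -> {'noreduce_obs_loss_task4': 0.31..., 'noreduce_obs_loss_task5': 0.27}
```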
+ """ + def __init__(self, tokenizer: torch.nn.Module, transformer: torch.nn.Module, pos_emb: torch.nn.Module, task_emb: torch.nn.Module, act_embedding_table: torch.nn.Module): + """ + Arguments: + - tokenizer (:obj:`torch.nn.Module`): The tokenizer module. + - transformer (:obj:`torch.nn.Module`): The transformer backbone. + - pos_emb (:obj:`torch.nn.Module`): The positional embedding module. + - task_emb (:obj:`torch.nn.Module`): The task embedding module. + - act_embedding_table (:obj:`torch.nn.Module`): The action embedding table. + """ + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self) -> iter: + """ + Overview: + Returns an iterator over the parameters of the wrapped components (tokenizer, transformer, embeddings). + These are typically the shared parts of the model whose gradients need to be managed for multi-task learning. + """ + return (list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + # list(self.task_emb.parameters()) + # TODO: Decide whether to include task embeddings in shared parameters. + list(self.act_embedding_table.parameters())) + + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all wrapped components to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. + """ + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + # self.task_emb.zero_grad(set_to_none=set_to_none) # TODO: Match the decision made in the parameters() method. + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + +class WrappedModelV3: + """ + Overview: + An alternative wrapper for world model components. + This version excludes the tokenizer from the shared parameters, focusing gradient correction + on the transformer and embedding layers. + """ + def __init__(self, transformer: torch.nn.Module, pos_emb: torch.nn.Module, task_emb: torch.nn.Module, act_embedding_table: torch.nn.Module): + """ + Arguments: + - transformer (:obj:`torch.nn.Module`): The transformer backbone. + - pos_emb (:obj:`torch.nn.Module`): The positional embedding module. + - task_emb (:obj:`torch.nn.Module`): The task embedding module. + - act_embedding_table (:obj:`torch.nn.Module`): The action embedding table. + """ + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self) -> iter: + """ + Overview: + Returns an iterator over the parameters of the transformer and various embedding layers. + """ + return (list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters())) + + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of the wrapped components to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. 
+ """ + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + +# def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas): +# """ +# 为UniZero模型配置带有差异化学习率的优化器。 +# """ +# # 1. 定义需要特殊处理的参数 +# param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + +# # 2. 将参数分为三组:Transformer主干、Tokenizer、Heads +# transformer_params = {pn: p for pn, p in param_dict.items() if 'transformer' in pn} +# tokenizer_params = {pn: p for pn, p in param_dict.items() if 'tokenizer' in pn} + +# # Heads的参数是那些既不属于transformer也不属于tokenizer的 +# head_params = { +# pn: p for pn, p in param_dict.items() +# if 'transformer' not in pn and 'tokenizer' not in pn +# } + +# # 3. 为每组设置不同的优化器参数(特别是学习率) +# # 这里我们仍然使用AdamW,但学习率设置更合理 +# optim_groups = [ +# { +# 'params': list(transformer_params.values()), +# 'lr': learning_rate, # 1e-4 +# # 'lr': learning_rate * 0.2, # 为Transformer主干设置一个较小的学习率,例如 1e-5 +# 'weight_decay': weight_decay +# # 'weight_decay': weight_decay * 5.0 +# }, +# { +# 'params': list(tokenizer_params.values()), +# 'lr': learning_rate, # Tokenizer使用基础学习率,例如 1e-4 +# # 'lr': learning_rate * 0.1, # 为encoder设置一个较小的学习率,例如 1e-5 +# 'weight_decay': weight_decay * 5.0 # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化 + +# }, +# { +# 'params': list(head_params.values()), +# 'lr': learning_rate, # Heads也使用基础学习率率,例如 1e-4 +# 'weight_decay': 0.0 # 通常Heads的权重不做衰减 +# # 'weight_decay': weight_decay + +# } +# ] + +# print("--- Optimizer Groups ---") +# print(f"Transformer LR: {learning_rate}") +# print(f"Tokenizer/Heads LR: {learning_rate}") + +# optimizer = torch.optim.AdamW(optim_groups, betas=betas) +# return optimizer + +def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas): + """ + 为UniZero模型配置带有差异化学习率的优化器。 + (修正版,确保参数组互斥) + """ + # 1. 创建空的参数列表用于分组 + transformer_params = [] + tokenizer_params = [] + head_params = [] + + # 2. 遍历所有可训练参数,并使用 if/elif/else 结构确保每个参数只被分配到一个组 + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if 'transformer' in name: + transformer_params.append(param) + elif 'tokenizer' in name: + tokenizer_params.append(param) + else: + head_params.append(param) + + # 3. 
Assign different optimizer settings to each group.
+    # We still use AdamW here, but with a more sensible learning-rate setup.
+    optim_groups = [
+        {
+            'params': transformer_params,
+            'lr': learning_rate,  # 1e-4
+            'weight_decay': weight_decay
+        },
+        {
+            'params': tokenizer_params,
+            'lr': learning_rate,  # The tokenizer uses the base learning rate, e.g. 1e-4.
+            # 'weight_decay': weight_decay * 5.0  # <-- 5x weight decay for the encoder: a strong regularizer.
+            'weight_decay': weight_decay  # Encoder weight decay (the 5x variant above is kept for reference).
+        },
+        {
+            'params': head_params,
+            'lr': learning_rate,  # The heads also use the base learning rate, e.g. 1e-4.
+            # 'weight_decay': 0.0  # Heads are often exempt from weight decay.
+            'weight_decay': weight_decay
+
+        }
+    ]
+
+    print("--- Optimizer Groups ---")
+    # Print the number of parameters in each group for debugging.
+    print(f"Transformer params: {len(transformer_params)}")
+    print(f"Tokenizer params: {len(tokenizer_params)}")
+    print(f"Head params: {len(head_params)}")
+    print(f"Transformer LR: {learning_rate}")
+    print(f"Tokenizer/Heads LR: {learning_rate}")
+
+    optimizer = torch.optim.AdamW(optim_groups, betas=betas)
+    return optimizer
+
+@POLICY_REGISTRY.register('unizero_multitask')
+class UniZeroMTPolicy(UniZeroPolicy):
+    """
+    Overview:
+        The policy class for multi-task UniZero, an official implementation for the paper "UniZero: Generalized and Efficient Planning
+        with Scalable Latent World Models". UniZero aims to enhance the planning capabilities of reinforcement learning agents
+        by addressing the limitations of MuZero-style algorithms, particularly in environments requiring the
+        capture of long-term dependencies. More details can be found at: https://arxiv.org/abs/2406.10667.
+    """
+
+    # The default_config for UniZero multi-task policy.
+    config = dict(
+        type='unizero_multitask',
+        model=dict(
+            # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model.
+            model_type='conv',  # options={'mlp', 'conv'}
+            # (bool) If True, the action space of the environment is continuous, otherwise discrete.
+            continuous_action_space=False,
+            # (tuple) The obs shape.
+            observation_shape=(3, 64, 64),
+            # (bool) Whether to use the self-supervised learning loss.
+            self_supervised_learning_loss=True,
+            # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix.
+            categorical_distribution=True,
+            # (int) The image channel in image observation.
+            image_channel=3,
+            # (int) The number of frames to stack together.
+            frame_stack_num=1,
+            # (int) The number of res blocks in MuZero model.
+            num_res_blocks=1,
+            # (int) The number of channels of hidden states in MuZero model.
+            num_channels=64,
+            # (int) The scale of supports used in categorical distribution.
+            # This variable is only effective when ``categorical_distribution=True``.
+            support_scale=50,
+            # (bool) Whether to learn bias in the last linear layer in value and policy head.
+            bias=True,
+            # (bool) Whether to use res connection in dynamics.
+            res_connection_in_dynamics=True,
+            # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'BN'.
+            norm_type='LN',  # NOTE: LayerNorm is used in the transformer-based world model.
+            # (bool) Whether to analyze simulation normalization.
+            analysis_sim_norm=False,
+            # (int) The save interval of the model.
+            learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ),
+            world_model_cfg=dict(
+                # (int) The number of tokens per block.
+                tokens_per_block=2,
+                # (int) The maximum number of blocks.
+                max_blocks=10,
+                # (int) The maximum number of tokens, calculated as tokens per block multiplied by max blocks.
+ max_tokens=2 * 10, + # (int) The context length, usually calculated as twice the number of some base unit. + context_length=2 * 4, + # (bool) Whether to use GRU gating mechanism. + gru_gating=False, + # (str) The device to be used for computation, e.g., 'cpu' or 'cuda'. + device='cpu', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, + # (int) The shape of the action space. + action_space_size=6, + # (int) The size of the group, related to simulation normalization. + group_size=8, # NOTE: for sim_norm + # (str) The type of attention mechanism used. Options could be ['causal']. + attention='causal', + # (int) The number of layers in the model. + num_layers=2, + # (int) The number of attention heads. + num_heads=8, + # (int) The dimension of the embedding. + embed_dim=768, + # (float) The dropout probability for the embedding layer. + embed_pdrop=0.1, + # (float) The dropout probability for the residual connections. + resid_pdrop=0.1, + # (float) The dropout probability for the attention mechanism. + attn_pdrop=0.1, + # (int) The size of the support set for value and reward heads. + support_size=101, + # (int) The maximum size of the cache. + max_cache_size=5000, + # (int) The number of environments. + env_num=8, + # (float) The weight of the latent reconstruction loss. + latent_recon_loss_weight=0., + # (float) The weight of the perceptual loss. + perceptual_loss_weight=0., + # (float) The weight of the policy entropy. + policy_entropy_weight=1e-4, + # (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse']. + predict_latent_loss_type='group_kl', + # (str) The type of observation. Options are ['image', 'vector']. + obs_type='image', + # (float) The discount factor for future rewards. + gamma=1, + # (bool) Whether to analyze dormant ratio, average_weight_magnitude of net, effective_rank of latent. + analysis_dormant_ratio_weight_rank=False, + # (float) The threshold for a dormant neuron. + dormant_threshold=0.01, + + ), + ), + # ****** common ****** + # (bool) whether to use rnd model. + use_rnd_model=False, + # (bool) Whether to use multi-gpu training. + multi_gpu=True, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. Options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space']. + action_type='fixed_action_space', + # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=400, + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to use the pure policy to collect data. 
+ collect_with_pure_policy=False, + # (int) The evaluation frequency. + eval_freq=int(5e3), + # (str) The sample type. Options are ['episode', 'transition']. + sample_type='transition', + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. + ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + replay_ratio=0.25, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. + optim_type='AdamW', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.0001, + # (int) Frequency of hard target network update. + target_update_freq=100, + # (int) Frequency of soft target network update. + target_update_theta=0.05, + # (int) Frequency of target network update. + target_update_freq_for_intrinsic_reward=1000, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=5, + # (int) The number of episodes in each collecting stage when use muzero_collector. + n_episode=8, + # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. + num_segments=8, + # # (int) the number of simulations in MCTS for renalyze. + num_simulations=50, + # (int) The number of simulations in MCTS for the collect phase. + collect_num_simulations=25, + # (int) The number of simulations in MCTS for the eval phase. + eval_num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=10, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=0, + cos_lr_scheduler=False, + piecewise_decay_lr_scheduler=False, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. 
lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=False, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + # (int) The initial Env Steps for training. + train_start_after_envsteps=int(0), + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. + end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm's default model setting for demonstration. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): A tuple containing the model name and a list of import paths. + - model_type (:obj:`str`): The model type used in this algorithm, registered in ModelRegistry. + - import_names (:obj:`List[str]`): The list of model class paths used in this algorithm. + .. note:: + Users can define and use customized network models, but they must adhere to the same interface definition + as indicated by the import_names path. For multi-task UniZero, this is ``lzero.model.unizero_model_multitask.UniZeroMTModel``. + """ + # NOTE: This specifies the default multi-task model. + return 'UniZeroMTModel', ['lzero.model.unizero_model_multitask'] + + def _init_learn(self) -> None: + """ + Overview: + Initializes the learn mode. This method is called by ``self.__init__``. + It sets up the learn model, optimizer, target model, and other utilities required for training. 
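Before the `_init_learn` body below wires up the real optimizer, here is a hedged, standalone sketch of the grouped-parameter idea behind `configure_optimizer_unizero` (the variant selected with `optim_type='AdamW_mix_lr_wdecay'`). The toy module layout and the concrete lr/weight-decay numbers are placeholders, not the real world-model structure.

```python
import torch
import torch.nn as nn

# Toy stand-in for a world model; the submodule names are assumptions for this sketch only.
model = nn.ModuleDict({
    'transformer': nn.Linear(8, 8),
    'tokenizer': nn.Linear(8, 8),
    'head': nn.Linear(8, 4),
})

groups = {'transformer': [], 'tokenizer': [], 'head': []}
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # The if/elif/else chain keeps the three groups mutually exclusive, as in the fixed version above.
    if 'transformer' in name:
        groups['transformer'].append(param)
    elif 'tokenizer' in name:
        groups['tokenizer'].append(param)
    else:
        groups['head'].append(param)

# Placeholder hyper-parameters; per-group 'lr'/'weight_decay' override the optimizer defaults.
optimizer = torch.optim.AdamW(
    [
        {'params': groups['transformer'], 'lr': 1e-4, 'weight_decay': 1e-4},
        {'params': groups['tokenizer'], 'lr': 1e-4, 'weight_decay': 1e-4},
        {'params': groups['head'], 'lr': 1e-4, 'weight_decay': 0.0},
    ],
    betas=(0.9, 0.95),
)
print([len(g['params']) for g in optimizer.param_groups])  # [2, 2, 2]
```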
+ """ + if self._cfg.optim_type == 'SGD': + # --- 改为SGD优化器 --- + self._optimizer_world_model = torch.optim.SGD( + self._model.world_model.parameters(), + lr=self._cfg.learning_rate, # 初始学习率,在配置中设为 0.2 + momentum=self._cfg.momentum, # 在配置中设为 0.9 + weight_decay=self._cfg.weight_decay # 在配置中设为 1e-4 + ) + elif self._cfg.optim_type == 'AdamW': + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + elif self._cfg.optim_type == 'AdamW_mix_lr_wdecay': + self._optimizer_world_model = configure_optimizer_unizero( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, # 使用一个合理的AdamW基础学习率 + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + if self._cfg.cos_lr_scheduler: + from torch.optim.lr_scheduler import CosineAnnealingLR + # TODO: check the total training steps + # self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1) + total_iters = self._cfg.get('total_iterations', 500000) # 500k iter + # final_lr = self._cfg.get('final_learning_rate', 0.0) + final_lr = self._cfg.get('final_learning_rate', 1e-6) + + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, + T_max=total_iters, + eta_min=final_lr + ) + print(f"CosineAnnealingLR enabled: T_max={total_iters}, eta_min={final_lr}") + + + if self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. + lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer_world_model, lr_lambda=lr_lambda) + + + # Use a deep copy for the target model. + self._target_model = copy.deepcopy(self._model) + # Ensure that the installed torch version is >= 2.0 for torch.compile. + assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0" + self._model = torch.compile(self._model) + self._target_model = torch.compile(self._target_model) + + # Wrap the target model for soft updates (momentum-based). + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) + + self.intermediate_losses = defaultdict(float) + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. + + # Create a WrappedModel instance. + # This is used for gradient correction methods where gradients of shared parameters are managed. 
+        # In this setup, all parameters are considered shared and subject to correction.
+        # wrapped_model = WrappedModel(
+        #     self._learn_model.world_model,
+        # )
+
+        self.task_id = self._cfg.task_id
+        self.task_num_for_current_rank = self._cfg.task_num
+
+        print(f'self._cfg.only_use_moco_stats:{self._cfg.only_use_moco_stats}')
+        if self._cfg.use_moco or self._cfg.only_use_moco_stats:
+            # The prediction heads' gradients are not corrected.
+            self.wrapped_model = WrappedModelV2(
+                # TODO: This assumes the tokenizer has an encoder attribute which is a list. This might need to be more robust.
+                self._learn_model.world_model.tokenizer.encoder[0],
+                self._learn_model.world_model.transformer,
+                self._learn_model.world_model.pos_emb,
+                self._learn_model.world_model.task_emb,
+                self._learn_model.world_model.act_embedding_table,
+            )
+
+            # Alternative setup: The head and tokenizer.encoder gradients are not corrected.
+            # wrapped_model = WrappedModelV3(
+            #     self._learn_model.world_model.transformer,
+            #     self._learn_model.world_model.pos_emb,
+            #     self._learn_model.world_model.task_emb,
+            #     self._learn_model.world_model.act_embedding_table,
+            # )
+
+            # Pass the wrapped_model as `shared_module` to the gradient correction method.
+            # ========= Initialize MoCo/CAGrad parameters =========
+            if self._cfg.moco_version=="v0":
+                # This version is only compatible with single-GPU training.
+                self.grad_correct = GradCorrect(self.wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu)
+                self.grad_correct.init_param()
+                self.grad_correct.rep_grad = False
+            elif self._cfg.moco_version=="v1":
+                cfg_moco = MoCoCfg(
+                    beta0=0.9, beta_sigma=0.95,
+                    gamma0=0.1, gamma_sigma=0.95,
+                    rho=0.01, stat_interval=10000)
+                self.grad_correct = FastMoCo(
+                    shared_module=self.wrapped_model,
+                    world_task_num=self._cfg.total_task_num,  # Total number of tasks globally
+                    device=self._cfg.device,
+                    multi_gpu=self._cfg.multi_gpu,
+                    cfg=cfg_moco,
+                )
+
+        # Cache for plasticity-related metrics from the previous frame.
+        self._prev_plasticity_metrics = dict(
+            dormant_ratio_encoder = 0.0,
+            dormant_ratio_transformer = 0.0,
+            dormant_ratio_head = 0.0,
+            avg_weight_mag_encoder = 0.0,
+            avg_weight_mag_transformer = 0.0,
+            avg_weight_mag_head = 0.0,
+            e_rank_last_linear = 0.0,
+            e_rank_sim_norm = 0.0,
+        )
+
+        # ==================== START: target-entropy regularization initialization ====================
+        # Read from the config whether adaptive alpha is enabled, with a default value.
+        self.use_adaptive_entropy_weight = self._cfg.get('use_adaptive_entropy_weight', True)
+
+        # Additional configuration read in _init_learn.
+        self.target_entropy_start_ratio = self._cfg.get('target_entropy_start_ratio', 0.98)
+        self.target_entropy_end_ratio = self._cfg.get('target_entropy_end_ratio', 0.7)
+        self.target_entropy_decay_steps = self._cfg.get('target_entropy_decay_steps', 200000)  # e.g., finish annealing within 200k train iters (~2M env steps).
+
+        if self.use_adaptive_entropy_weight:
+            # 1. Set the target entropy. For a discrete action space, a common heuristic is
+            #    -log(1 / action_space_size), i.e. the log of the action-space size, scaled by a
+            #    coefficient (e.g. 0.98) that can be treated as a hyperparameter.
+            action_space_size = self._cfg.model.action_space_size
+            self.target_entropy = -np.log(1.0 / action_space_size) * 0.98
+
+            # 2. Initialize a learnable log_alpha parameter.
+            #    Initializing it to 0 means the initial alpha = exp(0) = 1.0.
+            self.log_alpha = torch.nn.Parameter(torch.zeros(1, device=self._cfg.device), requires_grad=True)
+
+            # 3. Create a dedicated optimizer for log_alpha.
+            #    Using a smaller learning rate than the main optimizer (e.g. 1e-4) is usually more stable.
+            alpha_lr = self._cfg.get('adaptive_entropy_alpha_lr', 1e-4)
+            self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)
+
+            print("="*20)
+            print(">>> Target-entropy regularization (adaptive alpha) enabled <<<")
+            print(f"    Target entropy: {self.target_entropy:.4f}")
+            print(f"    Alpha optimizer learning rate: {alpha_lr:.2e}")
+            print("="*20)
+        # ===================== END: target-entropy regularization initialization =====================
+
+        self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 30.0)
+        # ==================== START: initialize Encoder-Clip annealing parameters ====================
+        self.use_encoder_clip_annealing = self._cfg.get('use_encoder_clip_annealing', False)
+        if self.use_encoder_clip_annealing:
+            self.encoder_clip_anneal_type = self._cfg.get('encoder_clip_anneal_type', 'cosine')
+            self.encoder_clip_start = self._cfg.get('encoder_clip_start_value', 30.0)
+            self.encoder_clip_end = self._cfg.get('encoder_clip_end_value', 10.0)
+            self.encoder_clip_anneal_steps = self._cfg.get('encoder_clip_anneal_steps', 200000)
+
+            print("="*20)
+            print(">>> Encoder-Clip annealing enabled <<<")
+            print(f"    Type: {self.encoder_clip_anneal_type}")
+            print(f"    Range: {self.encoder_clip_start} -> {self.encoder_clip_end}")
+            print(f"    Steps: {self.encoder_clip_anneal_steps}")
+            print("="*20)
+        else:
+            # If annealing is disabled, use a fixed clip threshold.
+            self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 30.0)
+        # ===================== END: initialize Encoder-Clip annealing parameters =====================
+
+        # --- NEW: Policy Label Smoothing Parameters ---
+        self.policy_ls_eps_start = self._cfg.get('policy_ls_eps_start', 0.05)  # TODO: larger action spaces need a larger label-smoothing eps.
+        self.policy_ls_eps_end = self._cfg.get('policy_label_smoothing_eps_end ', 0.01)  # TODO: unify this config key name with policy_ls_eps_start.
+        self.policy_ls_eps_decay_steps = self._cfg.get('policy_ls_eps_decay_steps ', 50000)  # TODO: 50k
+        print(f"self.policy_ls_eps_start:{self.policy_ls_eps_start}")
+
+    @staticmethod
+    def _is_zero(x: Union[float, torch.Tensor], eps: float = 1e-8) -> bool:
+        """
+        Overview:
+            Checks if a scalar or a 0-D tensor can be considered zero within a small tolerance.
+        Arguments:
+            - x (:obj:`Union[float, torch.Tensor]`): The input value to check.
+            - eps (:obj:`float`): The tolerance for checking against zero.
+        Returns:
+            - (:obj:`bool`): True if the value is close to zero, False otherwise.
+        """
+        if isinstance(x, torch.Tensor):
+            return torch.all(torch.abs(x) < eps).item()
+        return abs(x) < eps
+
+    def _retain_prev_if_zero(self, name: str,
+                             value: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
+        """
+        Overview:
+            If the current `value` is close to zero, returns the cached value from the previous frame.
+            Otherwise, it updates the cache with the current value and returns it. This is useful for
+            metrics that are computed intermittently.
+        Arguments:
+            - name (:obj:`str`): The name of the metric to cache.
+            - value (:obj:`Union[float, torch.Tensor]`): The current value of the metric.
+        Returns:
+            - (:obj:`Union[float, torch.Tensor]`): The retained or current value.
+        """
+        if self._is_zero(value):
+            # Directly return the previous value (can be float or tensor).
+            return self._prev_plasticity_metrics[name]
+        else:
+            # Update the cache and return the current value.
+ self._prev_plasticity_metrics[name] = value + return value + + + #@profile + def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_iter=None, ignore_grad=False) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning in the policy. This is the core of the training process. + Data is sampled from the replay buffer, losses are calculated, and the model is updated via backpropagation. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): A tuple of data batches, where each element corresponds to a different task. + - task_weights (:obj:`Any`, optional): Optional weights for each task's loss. Not currently used. + - ignore_grad (:obj:`bool`): If True, gradients are zeroed out after computation, effectively skipping the update. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): A dictionary containing current learning losses and statistics for logging. + """ + self._learn_model.train() + self._target_model.train() + + # Lists to store metrics for each task within the batch. + obs_loss_multi_task = [] + reward_loss_multi_task = [] + policy_loss_multi_task = [] + value_loss_multi_task = [] + latent_recon_loss_multi_task = [] + perceptual_loss_multi_task = [] + orig_policy_loss_multi_task = [] + policy_entropy_multi_task = [] + weighted_total_loss = 0.0 # Initialize to 0.0 to avoid in-place operations. + + latent_state_l2_norms_multi_task = [] + average_target_policy_entropy_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + + # Metrics for network plasticity analysis. + dormant_ratio_encoder_multi_task = [] + dormant_ratio_transformer_multi_task = [] + dormant_ratio_head_multi_task = [] + avg_weight_mag_encoder_multi_task = [] + avg_weight_mag_transformer_multi_task = [] + avg_weight_mag_head_multi_task = [] + e_rank_last_linear_multi_task = [] + e_rank_sim_norm_multi_task = [] + + # --- NEW: Calculate current epsilon for policy --- + # if self.policy_ls_eps_start > 0: + # progress = min(1.0, train_iter / self.policy_ls_eps_decay_steps) + # current_policy_label_eps = self.policy_ls_eps_start * (1 - progress) + self.policy_ls_eps_end * progress + # else: + # current_policy_label_eps = 0.0 + current_policy_label_eps = 0.01 + + # 新增一个列表来收集当前批次中所有任务的真实全局ID + global_task_ids_in_batch = [] + alpha_loss = None + + losses_list = [] # Used to store the loss tensor for each task, required by gradient correction methods. + for task_id, data_one_task in enumerate(data): + current_batch, target_batch, task_id = data_one_task # task_id 是真实的全局ID + + # 将真实的全局ID添加到列表中 + global_task_ids_in_batch.append(task_id) + + # TODO: Adapt RoPE for multitask settings (using timestep_batch). + obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch + target_reward, target_value, target_policy = target_batch + + # Prepare observations based on frame stack number. + if self._cfg.model.frame_stack_num == 4: + obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) + else: + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # Apply augmentations if needed. + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare action batch and convert to a torch tensor. 
+ action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze( + -1).long() # For discrete action space. + data_list = [mask_batch, target_reward.astype('float32'), target_value.astype('float32'), target_policy, + weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, + self._cfg.device) + + cur_batch_size = target_reward.size(0) # Run-time batch size. + + target_reward = target_reward.view(cur_batch_size, -1) + target_value = target_value.view(cur_batch_size, -1) + + # Transform scalar rewards and values to their scaled representations. + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # Convert scaled representations to categorical distributions. + # target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + # target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward, label_smoothing_eps= self._cfg.label_smoothing_eps) + target_value_categorical = phi_transform(self.value_support, transformed_target_value, label_smoothing_eps=self._cfg.label_smoothing_eps) + + + # Prepare the batch for the transformer-based world model. + batch_for_gpt = {} + if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + cur_batch_size, -1, self._cfg.model.observation_shape) + elif len(self._cfg.model.observation_shape) == 3: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + cur_batch_size, -1, *self._cfg.model.observation_shape) + + batch_for_gpt['actions'] = action_batch.squeeze(-1) + batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] + batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data. + batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, + device=self._cfg.device) + batch_for_gpt['target_value'] = target_value_categorical[:, :-1] + batch_for_gpt['target_policy'] = target_policy[:, :-1] + batch_for_gpt['scalar_target_value'] = target_value + + # Extract valid target policy data and compute its entropy. + valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] + target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) + average_target_policy_entropy = target_policy_entropy.mean().item() + + # Update world model and compute losses. + intermediate_losses = defaultdict(float) + # losses = self._learn_model.world_model.compute_loss( + # batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, task_id=task_id + # ) + + losses = self._learn_model.world_model.compute_loss( + batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, current_policy_label_eps=current_policy_label_eps, task_id=task_id + ) + + # ==================== START MODIFICATION 2 ==================== + # Extract the calculated value_priority from the returned losses. 
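A brief aside on the scalar-to-categorical targets built above: the sketch below walks through the round trip with the same `lzero.policy` imports this file uses. The `(-10., 11., 1.)` support range is an arbitrary example, and the exact `DiscreteSupport`/`InverseScalarTransform` signatures may differ slightly between LightZero versions, so treat this as a hedged illustration rather than the canonical API.

```python
import torch
from lzero.policy import DiscreteSupport, InverseScalarTransform, phi_transform, scalar_transform

# Example support covering integers in [-10, 10], i.e. 21 atoms (cf. the board-game configs).
support = DiscreteSupport(-10., 11., 1.)

target_value = torch.tensor([[3.7], [-2.0]])   # (batch, unroll) scalar targets
scaled = scalar_transform(target_value)         # h(x) value-rescaling transform
categorical = phi_transform(support, scaled)    # project onto the discrete support
print(categorical.shape)                        # expected: (2, 1, 21), one bin per support atom

# The inverse handle maps value/reward logits back to scalars (used e.g. for priorities above).
inverse = InverseScalarTransform(support, True)           # second argument: categorical_distribution
logits = torch.log(categorical.reshape(-1, 21) + 1e-9)    # treat the distribution as logits for the demo
print(inverse(logits))                                    # roughly recovers 3.7 and -2.0
```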
+            value_priority_tensor = losses.intermediate_losses['value_priority']
+            # Convert to numpy array for the replay buffer, adding a small epsilon.
+            value_priority_np = value_priority_tensor.detach().cpu().numpy() + 1e-6
+            # ===================== END MODIFICATION 2 =====================
+
+            # TODO: Accumulate the weighted total loss. This assumes the loss from `compute_loss` is already weighted.
+            weighted_total_loss += losses.loss_total  # NOTE: accumulate with +=.
+
+            # TODO: Add assertions to check for NaN or Inf values in the loss if needed for debugging.
+            # assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
+            # assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"
+
+            # TODO: Append the total loss for this task, used by MoCo.
+            losses_list.append(losses.loss_total)
+
+            for loss_name, loss_value in losses.intermediate_losses.items():
+                intermediate_losses[f"{loss_name}"] = loss_value
+
+            obs_loss = intermediate_losses['loss_obs']
+            reward_loss = intermediate_losses['loss_rewards']
+            policy_loss = intermediate_losses['loss_policy']
+            orig_policy_loss = intermediate_losses['orig_policy_loss']
+            policy_entropy = intermediate_losses['policy_entropy']
+            value_loss = intermediate_losses['loss_value']
+            latent_recon_loss = intermediate_losses['latent_recon_loss']
+            perceptual_loss = intermediate_losses['perceptual_loss']
+            latent_state_l2_norms = intermediate_losses['latent_state_l2_norms']
+
+            # Extract the policy entropy from the losses object.
+            # ==================== START: target-entropy regularization update ====================
+            current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight  # Fixed value used by default.
+            if self.use_adaptive_entropy_weight:
+                # --- Dynamically compute the target entropy (this part of the logic is correct and kept as-is). ---
+                progress = min(1.0, train_iter / self.target_entropy_decay_steps)
+                current_ratio = self.target_entropy_start_ratio * (1 - progress) + self.target_entropy_end_ratio * progress
+                action_space_size = self._cfg.model.action_space_size
+                # NOTE: target_entropy is defined as a positive quantity here, which is more intuitive.
+                current_target_entropy = -np.log(1.0 / action_space_size) * current_ratio
+
+                # --- Compute alpha_loss (sign corrected). ---
+                # Core fix: the leading minus sign has been removed.
+                # detach() remains essential so that the gradient of alpha_loss flows only into log_alpha.
+                alpha_loss = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean()  # NOTE: plain assignment (=).
+
+                # --- Update log_alpha ---
+                self.alpha_optimizer.zero_grad()
+                alpha_loss.backward()
+                self.alpha_optimizer.step()
+                # --- [Optimization tip] Clamp log_alpha as a safety measure. ---
+                with torch.no_grad():
+                    # Restrict alpha to a range such as [1e-4, 10.0].
+                    self.log_alpha.clamp_(np.log(1e-4), np.log(10.0))
+
+                # --- Use the freshly updated alpha (with the gradient flow cut off). ---
+                current_alpha = self.log_alpha.exp().detach()
+
+                # Recompute the weighted policy loss and the total loss.
+                # NOTE: policy_entropy here is already the batch mean.
+                weighted_policy_loss = orig_policy_loss - current_alpha * policy_entropy
+                # Rebuild the total loss (without using losses.loss_total).
+                # Keep these weights consistent with the computation in the LossWithIntermediateLosses class.
+                self.obs_loss_weight = 10
+                self.value_loss_weight = 0.5
+                self.reward_loss_weight = 1.
+                self.policy_loss_weight = 1.
+                self.ends_loss_weight = 0.
+                total_loss = (
+                    self.reward_loss_weight * reward_loss +
+                    self.value_loss_weight * value_loss +
+                    self.policy_loss_weight * weighted_policy_loss +
+                    self.obs_loss_weight * obs_loss  # Assume ssl_loss_weight acts as the weight of obs_loss.
+                    # ... add any remaining loss terms here as well ...
+                )
+                weighted_total_loss += (weights * total_loss).mean()  # NOTE: accumulate with +=.
+            # ===================== END: target-entropy regularization update =====================
+
+            # ============ For value-based priority calculation ============
+            # TODO: The following section for calculating value_priority is commented out.
+ # If re-enabled, ensure it correctly computes L1 loss between predicted and target values + # and handles CPU/Numpy conversion properly. + # original_value = self.value_inverse_scalar_transform_handle(logits_value.reshape(-1, 101)).reshape( + # batch_for_gpt['observations'].shape[0], batch_for_gpt['observations'].shape[1], 1) + # value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1)[:,0], target_value[:, 0]) + # value_priority = value_priority.data.cpu().numpy() + 1e-6 + # value_priority = torch.tensor(0., device=self._cfg.device) + # ============ End of value priority section ============ + + # Metrics related to network plasticity. + # Use the helper function to retain the previous value if the current one is zero. + dormant_ratio_encoder = self._retain_prev_if_zero( + 'dormant_ratio_encoder', + intermediate_losses['dormant_ratio_encoder']) + dormant_ratio_transformer = self._retain_prev_if_zero( + 'dormant_ratio_transformer', + intermediate_losses['dormant_ratio_transformer']) + dormant_ratio_head = self._retain_prev_if_zero( + 'dormant_ratio_head', + intermediate_losses['dormant_ratio_head']) + avg_weight_mag_encoder = self._retain_prev_if_zero( + 'avg_weight_mag_encoder', + intermediate_losses['avg_weight_mag_encoder']) + avg_weight_mag_transformer = self._retain_prev_if_zero( + 'avg_weight_mag_transformer', + intermediate_losses['avg_weight_mag_transformer']) + avg_weight_mag_head = self._retain_prev_if_zero( + 'avg_weight_mag_head', + intermediate_losses['avg_weight_mag_head']) + e_rank_last_linear = self._retain_prev_if_zero( + 'e_rank_last_linear', + intermediate_losses['e_rank_last_linear']) + e_rank_sim_norm = self._retain_prev_if_zero( + 'e_rank_sim_norm', + intermediate_losses['e_rank_sim_norm']) + + # Append all metrics for this task to their respective lists. + obs_loss_multi_task.append(obs_loss) + reward_loss_multi_task.append(reward_loss) + policy_loss_multi_task.append(policy_loss) + orig_policy_loss_multi_task.append(orig_policy_loss) + policy_entropy_multi_task.append(policy_entropy) + value_loss_multi_task.append(value_loss) + latent_recon_loss_multi_task.append(latent_recon_loss) + perceptual_loss_multi_task.append(perceptual_loss) + latent_state_l2_norms_multi_task.append(latent_state_l2_norms) + value_priority_multi_task.append(value_priority_tensor) + value_priority_mean_multi_task.append(value_priority_tensor.mean().item()) + + # Append plasticity metrics. + dormant_ratio_encoder_multi_task.append(dormant_ratio_encoder) + dormant_ratio_transformer_multi_task.append(dormant_ratio_transformer) + dormant_ratio_head_multi_task.append(dormant_ratio_head) + avg_weight_mag_encoder_multi_task.append(avg_weight_mag_encoder) + avg_weight_mag_transformer_multi_task.append(avg_weight_mag_transformer) + avg_weight_mag_head_multi_task.append(avg_weight_mag_head) + e_rank_last_linear_multi_task.append(e_rank_last_linear) + e_rank_sim_norm_multi_task.append(e_rank_sim_norm) + + + # Core learn model update step. + self._optimizer_world_model.zero_grad() + + # Assuming losses_list is a list of tensors with gradients, e.g., [loss1, loss2, ...]. + if self._cfg.use_moco: + # Call MoCo's backward method, which handles gradient correction internally. 
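Stepping back briefly: the adaptive-alpha update earlier in this method follows the usual SAC-style temperature loss. Below is a compact, self-contained sketch of just that mechanism with dummy values (the entropy number and learning rate are illustrative only).

```python
import numpy as np
import torch

action_space_size = 6
# Target entropy as a fraction of the maximum entropy log(|A|), as in _init_learn above.
target_entropy = -np.log(1.0 / action_space_size) * 0.98

log_alpha = torch.nn.Parameter(torch.zeros(1))            # alpha = exp(0) = 1.0 initially
alpha_optimizer = torch.optim.Adam([log_alpha], lr=1e-4)

policy_entropy = torch.tensor(1.2)  # stand-in for the batch-mean entropy of the policy head

# Gradient flows only into log_alpha; entropy below the target pushes alpha up, above pulls it down.
alpha_loss = (log_alpha * (policy_entropy.detach() - target_entropy)).mean()
alpha_optimizer.zero_grad()
alpha_loss.backward()
alpha_optimizer.step()

with torch.no_grad():
    log_alpha.clamp_(float(np.log(1e-4)), float(np.log(10.0)))  # keep alpha in a sane range

current_alpha = log_alpha.exp().detach()
print(float(current_alpha))  # the weight applied to the entropy bonus in the policy loss
```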
+            if self._cfg.moco_version=="v0":
+                lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+            elif self._cfg.moco_version=="v1":
+                lambd, stats = self.grad_correct.backward(losses_list)
+
+        elif self._cfg.only_use_moco_stats:
+            # Only compute MoCo stats without applying gradient correction.
+            lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+            # Each rank performs its own backpropagation.
+            weighted_total_loss.backward()
+        else:
+            # If not using gradient correction, each rank performs standard backpropagation.
+            lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device)
+            weighted_total_loss.backward()
+
+        # -----------------------------------------------------------------
+        # Still executed under torch.no_grad().
+        # =================================================================
+        with torch.no_grad():
+            # 1. Encoder-Clip
+            # ==================== START: dynamically compute the current clip threshold ====================
+            current_clip_value = self.latent_norm_clip_threshold  # Fall back to the fixed value by default.
+            if self.use_encoder_clip_annealing:
+                progress = min(1.0, train_iter / self.encoder_clip_anneal_steps)
+
+                if self.encoder_clip_anneal_type == 'cosine':
+                    # Cosine schedule: decays smoothly from 1 to 0.
+                    cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress))
+                    current_clip_value = self.encoder_clip_end + \
+                        (self.encoder_clip_start - self.encoder_clip_end) * cosine_progress
+                else:  # Linear schedule by default.
+                    current_clip_value = self.encoder_clip_start * (1 - progress) + \
+                        self.encoder_clip_end * progress
+            # ===================== END: dynamically compute the current clip threshold =====================
+
+            # 1. Encoder-Clip (using the dynamically computed current_clip_value).
+            if current_clip_value > 0 and 'obs_embeddings' in losses.intermediate_losses:
+                obs_embeddings = losses.intermediate_losses['obs_embeddings']
+                if obs_embeddings is not None:
+                    max_latent_norm = obs_embeddings.norm(p=2, dim=-1).max()
+                    if max_latent_norm > current_clip_value:
+                        scale_factor = current_clip_value / max_latent_norm.item()
+                        # Avoid printing too frequently; alternatively, print only every N steps.
+                        if train_iter % 1000 == 0:
+                            print(f"[Encoder-Clip Annealing] Iter {train_iter}: Max latent norm {max_latent_norm.item():.2f} > {current_clip_value:.2f}. Scaling by {scale_factor:.4f}.")
+                        scale_module_weights_vectorized(self._model.world_model.tokenizer.encoder, scale_factor)
+
+        # For debugging purposes.
+        # for name, param in self._learn_model.world_model.tokenizer.encoder.named_parameters():
+        #     print('name, param.mean(), param.std():', name, param.mean(), param.std())
+        #     if param.requires_grad:
+        #         print(name, param.grad.norm())
+
+        if self._cfg.analysis_sim_norm:
+            del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
+            self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
+            self._target_model.encoder_hook.clear_data()
+
+        total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(),
+                                                                        self._cfg.grad_clip_value)
+
+        if ignore_grad:
+            # NOTE: For cases where all tasks on a GPU are solved, `train` is still called for DDP synchronization,
+            # but gradients should be zeroed out to prevent updates.
+            self._optimizer_world_model.zero_grad()
+
+        if self._cfg.multi_gpu:
+            # If not using a gradient correction method that handles it, sync gradients manually.
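As an aside, the encoder-clip threshold computed above follows a simple cosine schedule. A tiny standalone sketch is given below; the start/end/step values mirror the defaults read in `_init_learn` and are otherwise placeholders.

```python
import numpy as np


def annealed_clip_value(train_iter: int, start: float = 30.0, end: float = 10.0,
                        anneal_steps: int = 200_000) -> float:
    """Cosine-anneal the clip threshold from `start` down to `end` over `anneal_steps` iterations."""
    progress = min(1.0, train_iter / anneal_steps)
    cosine = 0.5 * (1.0 + np.cos(np.pi * progress))  # decays smoothly from 1 to 0
    return end + (start - end) * cosine


for it in (0, 100_000, 200_000, 300_000):
    print(it, round(annealed_clip_value(it), 2))  # 30.0, 20.0, 10.0, 10.0
```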
+ if not self._cfg.use_moco: + self.sync_gradients(self._learn_model) + + self._optimizer_world_model.step() + + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: + self.lr_scheduler.step() + + # Core target model update step. + self._target_model.update(self._learn_model.state_dict()) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + current_memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + current_memory_allocated_gb = current_memory_allocated / (1024 ** 3) + max_memory_allocated_gb = max_memory_allocated / (1024 ** 3) + else: + current_memory_allocated_gb = 0. + max_memory_allocated_gb = 0. + + # Build the dictionary of return values for logging. + return_log_dict = { + 'Current_GPU': current_memory_allocated_gb, + 'Max_GPU': max_memory_allocated_gb, + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self._collect_epsilon, + 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), + } + + # ==================== START: 添加新日志项 ==================== + if self.use_adaptive_entropy_weight: + return_log_dict['adaptive_alpha'] = current_alpha.item() + return_log_dict['adaptive_target_entropy_ratio'] = current_ratio + return_log_dict['alpha_loss'] = alpha_loss.item() + # ==================== START: 添加新日志项 ==================== + + # Generate task-related loss dictionaries and prefix each task-related loss with "noreduce_". + multi_task_loss_dicts = { + **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id), #global_task_ids=global_task_ids_in_batch), # task_id=self.task_id), + **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', task_id=self.task_id), + **generate_task_loss_dict(dormant_ratio_head_multi_task, 'noreduce_dormant_ratio_head_task{}', task_id=self.task_id), + + **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(orig_policy_loss_multi_task, 'noreduce_orig_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'noreduce_target_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id), + } + return_log_dict.update(multi_task_loss_dicts) + + + if self._learn_model.world_model.do_analysis: + # Include plasticity metrics if analysis is enabled. 
+ plasticity_loss_dicts = { + **generate_task_loss_dict(dormant_ratio_encoder_multi_task, 'noreduce_dormant_ratio_encoder_task{}', task_id=self.task_id), + **generate_task_loss_dict(dormant_ratio_transformer_multi_task, 'noreduce_dormant_ratio_transformer_task{}', task_id=self.task_id), + **generate_task_loss_dict(dormant_ratio_head_multi_task, 'noreduce_dormant_ratio_head_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_encoder_multi_task, 'noreduce_avg_weight_mag_encoder_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_transformer_multi_task, 'noreduce_avg_weight_mag_transformer_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_head_multi_task, 'noreduce_avg_weight_mag_head_task{}', task_id=self.task_id), + **generate_task_loss_dict(e_rank_last_linear_multi_task, 'noreduce_e_rank_last_linear_task{}', task_id=self.task_id), + **generate_task_loss_dict(e_rank_sim_norm_multi_task, 'noreduce_e_rank_sim_norm_task{}', task_id=self.task_id), + } + # Merge the dictionaries. + return_log_dict.update(plasticity_loss_dicts) + + # Return the final loss dictionary. + return return_log_dict + + def monitor_weights_and_grads(self, model: torch.nn.Module) -> None: + """ + Overview: + A utility function to print the mean and standard deviation of weights and their gradients for each layer in a model. + Useful for debugging training issues like exploding or vanishing gradients. + Arguments: + - model (:obj:`torch.nn.Module`): The model to monitor. + """ + for name, param in model.named_parameters(): + if param.requires_grad: + print(f"Layer: {name} | " + f"Weight mean: {param.data.mean():.4f} | " + f"Weight std: {param.data.std():.4f} | " + f"Grad mean: {param.grad.mean():.4f} | " + f"Grad std: {param.grad.std():.4f}") + + def _init_collect(self) -> None: + """ + Overview: + Initializes the collect mode. This method is called by ``self.__init__``. + It sets up the collect model and MCTS utilities for data collection. + """ + self._collect_model = self._model + + # Create a copy of the configuration for collect MCTS and set a specific number of simulations. + mcts_collect_cfg = copy.deepcopy(self._cfg) + mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations + + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(mcts_collect_cfg) + else: + self._mcts_collect = MCTSPtree(mcts_collect_cfg) + + self._collect_mcts_temperature = 1. + self._collect_epsilon = 0.0 + self.collector_env_num = self._cfg.collector_env_num + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + + # TODO: The num_tasks parameter is hardcoded. It should ideally be derived from the config. + def _monitor_vars_learn(self, num_tasks: int = 2) -> List[str]: + """ + Overview: + Registers variables to be monitored during training. These variables will be logged in TensorBoard. + It dynamically creates variable names for each task if `num_tasks` is provided. + Arguments: + - num_tasks (:obj:`int`): The number of tasks being trained on the current rank. 
+ Returns: + - monitored_vars (:obj:`List[str]`): A list of strings, where each string is the name of a variable to be logged. + """ + # Basic monitored variables that do not depend on the number of tasks. + monitored_vars = [ + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'weighted_total_loss', + 'total_grad_norm_before_clip_wm', + + # 'value_priority', + 'adaptive_alpha', + "adaptive_target_entropy_ratio", + 'alpha_loss', + ] + + + + # Task-specific variables to be monitored. + task_specific_vars = [ + 'noreduce_obs_loss', + 'noreduce_orig_policy_loss', + 'noreduce_policy_loss', + 'noreduce_latent_recon_loss', + 'noreduce_policy_entropy', + 'noreduce_target_policy_entropy', + 'noreduce_reward_loss', + 'noreduce_value_loss', + 'noreduce_perceptual_loss', + 'noreduce_latent_state_l2_norms', + 'noreduce_lambd', + 'noreduce_value_priority_mean', + # Metrics related to network plasticity. + 'noreduce_dormant_ratio_encoder', + 'noreduce_dormant_ratio_transformer', + 'noreduce_dormant_ratio_head', + 'noreduce_avg_weight_mag_encoder', + 'noreduce_avg_weight_mag_transformer', + 'noreduce_avg_weight_mag_head', + 'noreduce_e_rank_last_linear', + 'noreduce_e_rank_sim_norm' + ] + + # Use self.task_num_for_current_rank as the number of tasks for the current rank. + num_tasks = self.task_num_for_current_rank + # If the number of tasks is provided, extend the monitored variables list with task-specific variable names. + if num_tasks is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks): + monitored_vars.append(f'{var}_task{self.task_id+task_idx}') + else: + # If num_tasks is not provided, assume a single task and use the original variable names. + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + #@profile + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + timestep: List = [0], + task_id: int = None, + ) -> Dict: + """ + Overview: + The forward function for collecting data. It uses the model to perform MCTS search and + selects actions via sampling to encourage exploration. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e., the current observation. + - action_mask (:obj:`list`, optional): A list of action masks for each environment. + - temperature (:obj:`float`, optional): The temperature for MCTS action selection. + - to_play (:obj:`List`, optional): A list of player IDs for each environment. + - epsilon (:obj:`float`, optional): The probability for epsilon-greedy exploration. + - ready_env_id (:obj:`np.array`, optional): An array of IDs for environments that are ready for a new action. + - timestep (:obj:`List`, optional): The current timestep in each environment. + - task_id (:obj:`int`, optional): The ID of the task for the current environments. + Returns: + - output (:obj:`Dict`): A dictionary where keys are environment IDs and values are dictionaries + containing the selected action and other MCTS statistics. 
+ """ + self._collect_model.eval() + + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + + # ========================== 核心修复 ========================== + # C++ 绑定需要一个 list,即使它在 MuZero 中代表奖励。 + reward_roots = reward_roots.detach().cpu().numpy().tolist() + # =============================================================== + + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + # The main difference between collect and eval is the addition of Dirichlet noise at the root. + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # C++ MCTS tree implementation. + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # Python MCTS tree implementation. + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + + # # 在本文件开始,通过全局变量来控制是否处于调试状态 + # global DEBUG_ENABLED;DEBUG_ENABLED = True + # import torch.distributed as dist + # if dist.get_rank() == 0 and DEBUG_ENABLED: + # print(f"rank {dist.get_rank()} 进入调试模式,输入interact,可以键入整段的python代码调试。通过设置 DEBUG_ENABLED = False, 可以跳过调试状态") + # import ipdb; ipdb.set_trace() + # # 同步点,防止其它进程早跑 + # dist.barrier() + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep= timestep, task_id=task_id) + + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + + if self._cfg.eps.eps_greedy_exploration_in_collect: + # Epsilon-greedy collection strategy. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self._collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # Standard collection strategy (sampling from MCTS policy). + # NOTE: `action_index_in_legal_action_set` is the index within the set of legal actions. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # Convert the index back to the action in the full action space. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + # ============== TODO: This section is for visualization purposes only and should be removed for training. ============== + # It forces deterministic action selection during collection. 
+ # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + # distributions, temperature=self._collect_mcts_temperature, deterministic=True + # ) + # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # ============== End of visualization section. ============== + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + # ========= TODO: This logic is currently for the `muzero_segment_collector`. ========= + if active_collect_env_num < self.collector_env_num: + # When one environment in `collect_env` finishes early, the length of `self.last_batch_obs` is reduced. + # The transformer needs the `env_id` to retrieve from the KV cache, which is complex to manage with a dynamic batch size. + # Therefore, we reset `self.last_batch_action` for all environments to -1, forcing the transformer + # to start from scratch and avoid retrieval errors. + print('==========collect_forward============') + print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + self._reset_collect(reset_init_data=True, task_id=task_id) + if getattr(self._cfg, 'sample_type', '') == 'episode': + print('BUG: sample_type is episode, but len(self.last_batch_obs) < self.collector_env_num') + + return output + + def _init_eval(self) -> None: + """ + Overview: + Initializes the eval mode. This method is called by ``self.__init__``. + It sets up the eval model and MCTS utilities for evaluation. + """ + self._eval_model = self._model + + # Create a copy of the configuration for eval MCTS and set a specific number of simulations. + mcts_eval_cfg = copy.deepcopy(self._cfg) + mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations + + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(mcts_eval_cfg) + else: + self._mcts_eval = MCTSPtree(mcts_eval_cfg) + + self.evaluator_env_num = self._cfg.evaluator_env_num + + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + + #@profile + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, timestep: List = [0], task_id: int = None) -> Dict: + """ + Overview: + The forward function for evaluating the policy. It uses the model to perform MCTS search and + selects actions deterministically (choosing the one with the highest visit count). + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e., the current observation. + - action_mask (:obj:`list`): A list of action masks for each environment. + - to_play (:obj:`int`, optional): The player ID for the current turn. + - ready_env_id (:obj:`np.array`, optional): An array of IDs for environments that are ready for a new action. + - timestep (:obj:`List`, optional): The current timestep in each environment. 
+            - task_id (:obj:`int`, optional): The ID of the task for the current environments.
+        Returns:
+            - output (:obj:`Dict`): A dictionary where keys are environment IDs and values are dictionaries
+              containing the selected action and other MCTS statistics.
+        """
+        self._eval_model.eval()
+        active_eval_env_num = data.shape[0]
+        if ready_env_id is None:
+            ready_env_id = np.arange(active_eval_env_num)
+        output = {i: None for i in ready_env_id}
+        with torch.no_grad():
+            network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data, task_id=task_id)
+            latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
+
+            pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
+            latent_state_roots = latent_state_roots.detach().cpu().numpy()
+            policy_logits = policy_logits.detach().cpu().numpy().tolist()
+
+            # ========================== Core fix ==========================
+            # The C++ MCTS bindings require a list here, even though it represents the reward in MuZero.
+            reward_roots = reward_roots.detach().cpu().numpy().tolist()  # TODO=============================
+            # ===============================================================
+
+
+            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)]
+            if self._cfg.mcts_ctree:
+                # C++ MCTS tree implementation.
+                roots = MCTSCtree.roots(active_eval_env_num, legal_actions)
+            else:
+                # Python MCTS tree implementation.
+                roots = MCTSPtree.roots(active_eval_env_num, legal_actions)
+
+            # During evaluation, no noise is added to the root policy.
+            roots.prepare_no_noise(reward_roots, policy_logits, to_play)
+            self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep=timestep, task_id=task_id)
+
+            roots_visit_count_distributions = roots.get_distributions()
+            roots_values = roots.get_values()
+
+            batch_action = []
+
+            for i, env_id in enumerate(ready_env_id):
+                distributions, value = roots_visit_count_distributions[i], roots_values[i]
+
+                # NOTE: `deterministic=True` means we select the action with the highest visit count (argmax)
+                # rather than sampling, which is standard for evaluation.
+                action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+                    distributions, temperature=1, deterministic=True
+                )
+                # Convert the index back to the action in the full action space.
+                action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+
+                output[env_id] = {
+                    'action': action,
+                    'visit_count_distributions': distributions,
+                    'visit_count_distribution_entropy': visit_count_distribution_entropy,
+                    'searched_value': value,
+                    'predicted_value': pred_values[i],
+                    'predicted_policy_logits': policy_logits[i],
+                }
+                batch_action.append(action)
+
+        self.last_batch_obs_eval = data
+        self.last_batch_action = batch_action
+
+        return output
+
+    #@profile
+    def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None:
+        """
+        Overview:
+            Resets the collection process for a specific environment or all environments.
+            It can clear caches and reset initial data to ensure optimal performance and prevent state leakage.
+        Arguments:
+            - env_id (:obj:`int`, optional): The ID of the environment to reset. If None, the reset applies more broadly. Defaults to None.
+            - current_steps (:obj:`int`, optional): The current step count in the environment, used to trigger periodic cache clearing. Defaults to 0.
+ - reset_init_data (:obj:`bool`, optional): If True, resets the initial observation and action buffers. Defaults to True. + - task_id (:obj:`int`, optional): The task ID, currently unused in this method. Defaults to None. + """ + if reset_init_data: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + # print('Collector: last_batch_obs and last_batch_action have been reset.') + + # Return immediately if env_id is not a single integer (e.g., None or a list). + # if env_id is None or isinstance(env_id, list): + # return + + # We must handle both single int and list of ints for env_id. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the collector. + if current_steps is None: + world_model = self._collect_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end.') + + + # Determine the clear interval based on the environment's sample type. + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length + + # Clear caches periodically to manage memory. + # if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: + + print(f'clear_interval: {clear_interval}') + + # Clear various KV caches in the collect model's world model. + world_model = self._collect_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up unused GPU memory. + torch.cuda.empty_cache() + + print(f'Collector: Caches cleared for collect_model at step {current_steps} for env {env_id}.') + + # TODO: Check if resetting the target model here is correct and necessary. + self._reset_target_model() + + #@profile + def _reset_target_model(self) -> None: + """ + Overview: + Resets the target model by clearing its internal caches. This is crucial for managing memory, + especially when using transformer-based models with KV caching. + """ + # Clear various KV caches in the target model's world model. + world_model = self._target_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up unused GPU memory. + torch.cuda.empty_cache() + print('Collector: Target model past_kv_cache cleared.') + + #@profile + def _reset_eval(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: + """ + Overview: + Resets the evaluation process for a specific environment or all environments. + Clears caches and resets initial data to ensure clean evaluation runs. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. Defaults to None. 
+ - current_steps (:obj:`int`, optional): The current step count, used for periodic cache clearing. Defaults to 0. + - reset_init_data (:obj:`bool`, optional): If True, resets the initial observation and action buffers. Defaults to True. + - task_id (:obj:`int`, optional): The task ID. Can be used to handle different observation shapes per task. Defaults to None. + """ + if reset_init_data: + self.last_batch_obs_eval = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + # print(f'Evaluator reset: last_batch_obs_eval shape: {self.last_batch_obs_eval.shape}') + + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + + + # --- BEGIN ROBUST FIX --- + # This logic handles the crucial end-of-episode cache clearing for evaluation. + # The evaluator calls `_policy.reset([env_id])` when an episode is done. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the evaluator. + if current_steps is None: + world_model = self._eval_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end.') + + # The recurrent cache is global. + world_model.past_kv_cache_recurrent_infer.clear() + + if hasattr(world_model, 'keys_values_wm_list'): + world_model.keys_values_wm_list.clear() + + torch.cuda.empty_cache() + return + # --- END ROBUST FIX --- + + # Determine the clear interval. + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length + + # Clear caches periodically. + # if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: + + print(f'clear_interval: {clear_interval}') + + # Clear various KV caches in the eval model's world model. + world_model = self._eval_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up unused GPU memory. + torch.cuda.empty_cache() + + print(f'Evaluator: Caches cleared for eval_model at step {current_steps} for env {env_id}.') + + + def recompute_pos_emb_diff_and_clear_cache(self) -> None: + """ + Overview: + Clears all KV caches and precomputes positional embedding matrices in the model. + This is typically called when the maximum sequence length changes. + """ + # NOTE: This must be done for both the collect and target models. + for model in [self._collect_model, self._target_model]: + model.world_model.precompute_pos_emb_diff_kv() + model.world_model.clear_caches() + torch.cuda.empty_cache() + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Returns the state dictionary of the learn mode. + This typically includes the model, target model, and optimizer states, + which are necessary for saving and resuming training. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The state dictionary for the current learning progress. 
+ """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer_world_model': self._optimizer_world_model.state_dict(), + } + + # ========== NOTE: This is the original version which loads all parameters from the state_dict. ========== + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Loads the state_dict into the policy's learn mode. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The state dictionary saved from a previous training session. + # """ + # self._learn_model.load_state_dict(state_dict['model']) + # self._target_model.load_state_dict(state_dict['target_model']) + # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + # ========== NOTE: This is a pretrain-finetune version that selectively loads parameters and freezes layers. ========== + def _load_state_dict_learn(self, state_dict: Dict[str, Any], finetune_components: List[str] = []) -> None: + """ + Overview: + Loads a state_dict for fine-tuning. It excludes multi-task specific parameters + and can freeze parts of the model (e.g., encoder, transformer) based on `finetune_components`. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The state dictionary from a pre-trained model. + - finetune_components (:obj:`List[str]`, optional): A list of component names (e.g., "encoder", "transformer") + that will remain trainable. Components not in this list will have their parameters frozen. + """ + # Example configurations for fine-tuning: + # finetune_components = [] # Loads encoder & transformer, fine-tunes only heads. + # finetune_components = ['transformer'] # Loads encoder & transformer, fine-tunes transformer & heads. + finetune_components = ["representation_network", "encoder"] # Loads encoder & transformer, fine-tunes encoder & heads. + + # Define prefixes of parameters to be excluded from loading (typically multi-task heads). + exclude_prefixes = [ + '_orig_mod.world_model.head_policy_multi_task.', + '_orig_mod.world_model.head_value_multi_task.', + '_orig_mod.world_model.head_rewards_multi_task.', + '_orig_mod.world_model.head_observations_multi_task.', + '_orig_mod.world_model.task_emb.' + ] + + # Define specific parameter keys to be excluded (for special cases like task embeddings). + exclude_keys = [ + '_orig_mod.world_model.task_emb.weight', + '_orig_mod.world_model.task_emb.bias', + ] + + def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + """ + Filters out parameters from a state_dict based on prefixes and specific keys. + """ + filtered = {} + for k, v in state_dict_loader.items(): + if any(k.startswith(prefix) for prefix in exclude_prefixes): + print(f"Excluding parameter: {k}") # For debugging + continue + if k in exclude_keys: + print(f"Excluding specific parameter: {k}") # For debugging + continue + filtered[k] = v + return filtered + + # Filter and load the 'model' state_dict. 
+ if 'model' in state_dict: + model_state_dict = state_dict['model'] + filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + if missing_keys: + print(f"Missing keys when loading _learn_model: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") + else: + print("No 'model' key found in the state_dict.") + + # Filter and load the 'target_model' state_dict. + if 'target_model' in state_dict: + target_model_state_dict = state_dict['target_model'] + filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + if missing_keys: + print(f"Missing keys when loading _target_model: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys when loading _target_model: {unexpected_keys}") + else: + print("No 'target_model' key found in the state_dict.") + + # Handle freezing/unfreezing of parameters in _learn_model based on finetune_components. + # This assumes a naming convention where component names are present in parameter names. + for name, param in self._learn_model.named_parameters(): + # Freeze the encoder if "encoder" is not in finetune_components. + if "encoder" in name and "encoder" not in finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + # Freeze the representation network if "representation_network" is not in finetune_components. + elif "representation_network" in name and "representation_network" not in finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + # Freeze the transformer if "transformer" is not in finetune_components. + elif "transformer" in name and "transformer" not in finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + else: + # Other parameters remain trainable by default. + print(f"Parameter remains trainable: {name}") + + # NOTE: For more complex model structures, it might be better to identify modules by their class + # rather than relying on parameter names. For example: + # for module in self._learn_model.modules(): + # if isinstance(module, EncoderModule) and "encoder" not in finetune_components: + # for param in module.parameters(): + # param.requires_grad = False + + # ========== NOTE: Another pretrain-finetune version. The main difference from the above is the freezing logic and comments. ========== + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Loads a state_dict into the policy's learn mode, excluding multi-task related parameters. + # This is intended for fine-tuning a pre-trained model on new tasks. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The state dictionary from a pre-trained model. + # """ + # # Define prefixes of parameters to be excluded. + # exclude_prefixes = [ + # '_orig_mod.world_model.head_policy_multi_task.', + # '_orig_mod.world_model.head_value_multi_task.', + # '_orig_mod.world_model.head_rewards_multi_task.', + # '_orig_mod.world_model.head_observations_multi_task.', + # '_orig_mod.world_model.task_emb.' + # ] + + # # Define specific parameter keys to be excluded. 
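As a companion to the selective freezing above, the following hypothetical helper (not part of LightZero) tallies how many parameters remain trainable per top-level submodule, which can be handy for sanity-checking a fine-tuning setup before training starts.

```Python
from collections import defaultdict

import torch.nn as nn

def summarize_trainable(model: nn.Module) -> dict:
    """Count trainable vs. frozen parameters grouped by top-level submodule name."""
    stats = defaultdict(lambda: {'trainable': 0, 'frozen': 0})
    for name, param in model.named_parameters():
        top_level = name.split('.')[0]
        key = 'trainable' if param.requires_grad else 'frozen'
        stats[top_level][key] += param.numel()
    return {k: dict(v) for k, v in stats.items()}

# Example with a toy model; in practice this would be called on `self._learn_model`
# after `_load_state_dict_learn` has applied the freezing logic.
toy = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
toy[0].weight.requires_grad = False
print(summarize_trainable(toy))
```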
+ # exclude_keys = [ + # '_orig_mod.world_model.task_emb.weight', + # '_orig_mod.world_model.task_emb.bias', + # ] + + # def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + # """ + # Filters out parameters that should not be loaded. + # """ + # filtered = {} + # for k, v in state_dict_loader.items(): + # if any(k.startswith(prefix) for prefix in exclude_prefixes): + # print(f"Excluding parameter: {k}") + # continue + # if k in exclude_keys: + # print(f"Excluding specific parameter: {k}") + # continue + # filtered[k] = v + # return filtered + + # # Filter and load the 'model' part. + # if 'model' in state_dict: + # model_state_dict = state_dict['model'] + # filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + # missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + # if missing_keys: + # print(f"Missing keys when loading _learn_model: {missing_keys}") + # if unexpected_keys: + # print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") + # else: + # print("No 'model' key found in the state_dict.") + + # # Filter and load the 'target_model' part. + # if 'target_model' in state_dict: + # target_model_state_dict = state_dict['target_model'] + # filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + # missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + # if missing_keys: + # print(f"Missing keys when loading _target_model: {missing_keys}") + # if unexpected_keys: + # print(f"Unexpected keys when loading _target_model: {unexpected_keys}") + # else: + # print("No 'target_model' key found in the state_dict.") + + # # Do not load the optimizer's state_dict when fine-tuning, as it contains state (like momentum) + # # specific to the pre-training task, which can hinder adaptation to new tasks. + # # A fresh optimizer is usually preferred. + # # if 'optimizer_world_model' in state_dict: + # # ... \ No newline at end of file diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index 631af8391..1dd85d259 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -10,7 +10,20 @@ from easydict import EasyDict from scipy.stats import entropy from torch.nn import functional as F +import nltk +from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction +def compute_bleu(reference: str, prediction: str) -> float: + """ + Compute sentence-level BLEU-4 score with smoothing and scale it to 0–1. + """ + if reference is None or prediction is None: + return 0.0 + reference_tokens = reference.strip().split() + prediction_tokens = prediction.strip().split() + smoothing = SmoothingFunction().method4 + bleu = sentence_bleu([reference_tokens], prediction_tokens, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothing) + return bleu def pad_and_get_lengths(inputs, num_of_sampled_actions): """ @@ -54,7 +67,7 @@ def visualize_avg_softmax(logits): avg_probabilities = torch.mean(probabilities, dim=0) # Convert to numpy for visualization. - avg_probabilities_np = avg_probabilities.detach().numpy() + avg_probabilities_np = avg_probabilities.detach().cpu().numpy() # Create a bar plot. 
plt.figure(figsize=(10, 8)) @@ -198,29 +211,69 @@ def forward(self, input): return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) -# modified from https://github.com/karpathy/nanoGPT/blob/master/model.py#L263 -def configure_optimizers_nanogpt(model, weight_decay, learning_rate, betas, device_type): - # start with all of the candidate parameters +# The following code is modified from the original implementation at: +# https://github.com/karpathy/nanoGPT/blob/master/model.py#L263 + +def configure_optimizers_nanogpt( + model: nn.Module, + weight_decay: float, + learning_rate: float, + betas: Tuple[float, float], + device_type: str +) -> torch.optim.AdamW: + """ + Overview: + Configures the AdamW optimizer for the nanoGPT model. This function separates model + parameters into two groups: one that will be subject to weight decay and one that will not. + Typically, 2D and higher-dimensional tensors (e.g., weights of linear layers) are decayed, + while 1D tensors (e.g., biases and LayerNorm weights) are not. + + Arguments: + - model (:obj:`nn.Module`): The model for which to configure optimizers. + - weight_decay (:obj:`float`): The weight decay coefficient to apply. + - learning_rate (:obj:`float`): The learning rate for the optimizer. + - betas (:obj:`Tuple[float, float]`): The beta coefficients for the AdamW optimizer (e.g., (0.9, 0.95)). + - device_type (:obj:`str`): The type of device being used, e.g., 'cuda' or 'cpu'. + + Returns: + (:obj:`torch.optim.AdamW`): The configured AdamW optimizer instance. + """ + # Start with all of the candidate parameters from the model. param_dict = {pn: p for pn, p in model.named_parameters()} - # filter out those that do not require grad - param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} - # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. - # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. + + # TODO: The following code is commented out, which is crucial for a balanced pipeline. + # We do not filter out parameters with `requires_grad=False` because their `requires_grad` + # attribute might be set to `True` at a later stage during training. + # param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + + # Create optimizer parameter groups. Any parameter that is 2D or higher will be weight decayed, + # otherwise no. i.e. all weight tensors in matrix multiplications and embeddings will be decayed, + # while all biases and layernorm weights will not. decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] optim_groups = [ {'params': decay_params, 'weight_decay': weight_decay}, {'params': nodecay_params, 'weight_decay': 0.0} ] + num_decay_params = sum(p.numel() for p in decay_params) num_nodecay_params = sum(p.numel() for p in nodecay_params) print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") - # Create AdamW optimizer and use the fused version if it is available + + # Create the AdamW optimizer. + # Check if a fused version of AdamW is available in the current PyTorch installation. fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters + + # Note: The current logic creates a standard AdamW optimizer on CUDA-enabled systems. 
+ # The 'fused' version is only considered on non-CUDA systems, where it will ultimately not be used + # because `device_type` would not be 'cuda'. if torch.cuda.is_available(): + # On a CUDA-enabled system, create a standard AdamW optimizer. optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) else: + # On a non-CUDA system, check if the fused optimizer can be used. + # This will be False if device_type is not 'cuda'. use_fused = fused_available and device_type == 'cuda' extra_args = dict(fused=True) if use_fused else dict() optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) @@ -556,7 +609,7 @@ def concat_output_value(output_lst: List) -> np.ndarray: # concat the values of the model output list value_lst = [] for output in output_lst: - value_lst.append(output.value) # TODO:cpu + value_lst.append(output.value) # print(f'value_lst:{value_lst}') # print(f'value_lst[0]:{value_lst[0]}') diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 0299abf8f..06fa3b580 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -1,6 +1,6 @@ import time from collections import deque, namedtuple -from typing import Optional, Any, List +from typing import Optional, Any, List, Dict, Set import numpy as np import torch @@ -21,41 +21,43 @@ class MuZeroCollector(ISerialCollector): """ Overview: - The Episode Collector for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, Gumbel MuZero. - It manages the data collection process for training these algorithms using a serial mechanism. + The episode-based collector for MCTS-based reinforcement learning algorithms, + including MuZero, EfficientZero, Sampled EfficientZero, and Gumbel MuZero. + It orchestrates the data collection process in a serial manner, managing interactions + between the policy and the environment to generate game segments for training. Interfaces: - ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``envstep``, ``__del__``, ``_compute_priorities``, - ``pad_and_save_last_trajectory``, ``collect``, ``_output_log``, ``close`` + ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``collect``, + ``_compute_priorities``, ``pad_and_save_last_trajectory``, ``_output_log``, ``close``, ``__del__``. Properties: - ``envstep`` + ``envstep``. """ - # TO be compatible with ISerialCollector + # Default configuration for the collector. To be compatible with ISerialCollector. config = dict() def __init__( self, collect_print_freq: int = 100, - env: BaseEnvManager = None, - policy: namedtuple = None, + env: Optional[BaseEnvManager] = None, + policy: Optional[namedtuple] = None, tb_logger: 'SummaryWriter' = None, # noqa - exp_name: Optional[str] = 'default_experiment', - instance_name: Optional[str] = 'collector', + exp_name: str = 'default_experiment', + instance_name: str = 'collector', policy_config: 'policy_config' = None, # noqa - task_id: int = None, + task_id: Optional[int] = None, ) -> None: """ Overview: - Initialize the MuZeroCollector with the given parameters. + Initializes the MuZeroCollector with the given configuration. Arguments: - - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information. - - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager. - - policy (:obj:`Optional[namedtuple]`): namedtuple of the collection mode policy API. 
- - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance. - - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. - - instance_name (:obj:`str`): Unique identifier for this collector instance. - - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. - - task_id (:obj:`int`): Unique identifier for the task. If None, that means we are in the single task mode. + - collect_print_freq (:obj:`int`): The frequency (in training iterations) at which to print collection statistics. + - env (:obj:`Optional[BaseEnvManager]`): An instance of a vectorized environment manager. + - policy (:obj:`Optional[namedtuple]`): A namedtuple containing the policy's forward pass and other methods. + - tb_logger (:obj:`Optional[SummaryWriter]`): A TensorBoard logger instance for logging metrics. + - exp_name (:obj:`str`): The name of the experiment, used for organizing logs. + - instance_name (:obj:`str`): A unique name for this collector instance. + - policy_config (:obj:`'policy_config'`): The configuration object for the policy. + - task_id (:obj:`Optional[int]`): The identifier for the current task in a multi-task setting. If None, operates in single-task mode. """ self.task_id = task_id self._exp_name = exp_name @@ -64,23 +66,26 @@ def __init__( self._timer = EasyTimer() self._end_flag = False + # Get distributed training info self._rank = get_rank() self._world_size = get_world_size() + + # Logger setup: only rank 0 creates the main logger and TensorBoard logger. if self._rank == 0: if tb_logger is not None: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False ) self._tb_logger = tb_logger else: self._logger, self._tb_logger = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name ) else: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False ) self._tb_logger = None @@ -92,12 +97,11 @@ def __init__( def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset or replace the environment managed by this collector. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the collector with the new passed \ - in environment and launch. + Resets or replaces the environment managed by the collector. + If `_env` is None, it resets the existing environment. Otherwise, it replaces the old + environment with the new one and launches it. Arguments: - - env (:obj:`Optional[BaseEnvManager]`): New environment to manage, if provided. + - _env (:obj:`Optional[BaseEnvManager]`): The new environment to be used. If None, resets the current environment. """ if _env is not None: self._env = _env @@ -109,42 +113,39 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Reset or replace the policy used by this collector. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the collector with the new passed in policy. + Resets or replaces the policy used by the collector. 
+ If `_policy` is None, it resets the existing policy. Otherwise, it replaces the old + policy with the new one. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + - _policy (:obj:`Optional[namedtuple]`): The new policy to be used. """ - assert hasattr(self, '_env'), "please set env first" + assert hasattr(self, '_env'), "Please set env first before resetting policy." if _policy is not None: self._policy = _policy self._default_n_episode = _policy.get_attribute('cfg').get('n_episode', None) self._logger.debug( - 'Set default n_episode mode(n_episode({}), env_num({}))'.format(self._default_n_episode, self._env_num) + f"Set default n_episode mode(n_episode({self._default_n_episode}), env_num({self._env_num}))" ) self._policy.reset() def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset the collector with the given policy and/or environment. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the collector with the new passed \ - in environment and launch. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the collector with the new passed in policy. + Resets the collector, including the environment and policy. Also re-initializes + internal state variables for tracking collection progress. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) + - _policy (:obj:`Optional[namedtuple]`): The new policy to use. + - _env (:obj:`Optional[BaseEnvManager]`): The new environment to use. """ if _env is not None: self.reset_env(_env) if _policy is not None: self.reset_policy(_policy) + # Initialize per-environment tracking info self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)} + # Reset overall statistics self._episode_info = [] self._total_envstep_count = 0 self._total_episode_count = 0 @@ -152,18 +153,17 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self._last_train_iter = 0 self._end_flag = False - # A game_segment_pool implementation based on the deque structure. + # A pool to store completed game segments, implemented using a deque. self.game_segment_pool = deque(maxlen=int(1e6)) self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps def _reset_stat(self, env_id: int) -> None: """ Overview: - Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool \ - and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ - to get more messages. + Resets the statistics for a specific environment, identified by `env_id`. + This is typically called when an episode in that environment ends. Arguments: - - env_id (:obj:`int`): the id where we need to reset the collector's state + - env_id (:obj:`int`): The ID of the environment to reset statistics for. """ self._env_info[env_id] = {'time': 0., 'step': 0} @@ -171,17 +171,17 @@ def _reset_stat(self, env_id: int) -> None: def envstep(self) -> int: """ Overview: - Get the total number of environment steps collected. + Returns the total number of environment steps collected since the last reset. Returns: - - envstep (:obj:`int`): Total number of environment steps collected. + - envstep (:obj:`int`): The total environment step count. 
""" return self._total_envstep_count def close(self) -> None: """ Overview: - Close the collector. If end_flag is False, close the environment, flush the tb_logger \ - and close the tb_logger. + Closes the collector, including the environment and any loggers. + Ensures that all resources are properly released. """ if self._end_flag: return @@ -194,603 +194,454 @@ def close(self) -> None: def __del__(self) -> None: """ Overview: - Execute the close command and close the collector. __del__ is automatically called to \ - destroy the collector instance when the collector finishes its work + Destructor for the collector instance, ensuring that `close` is called + to clean up resources. """ self.close() # ============================================================== - # MCTS+RL related core code + # MCTS+RL Core Collection Logic # ============================================================== - def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> np.ndarray: + def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> Optional[np.ndarray]: """ Overview: - Compute the priorities for transitions based on prediction and search value discrepancies. + Computes priorities for experience replay based on the discrepancy between + predicted values and MCTS search values. Arguments: - - i (:obj:`int`): Index of the values in the list to compute the priority for. - - pred_values_lst (:obj:`List[float]`): List of predicted values. - - search_values_lst (:obj:`List[float]`): List of search values obtained from MCTS. + - i (:obj:`int`): The index of the environment's data in the lists. + - pred_values_lst (:obj:`List[float]`): A list containing lists of predicted values for each environment. + - search_values_lst (:obj:`List[float]`): A list containing lists of search values from MCTS for each environment. Returns: - - priorities (:obj:`np.ndarray`): Array of computed priorities. + - priorities (:obj:`Optional[np.ndarray]`): An array of priorities for the transitions. Returns None if priority is not used. """ if self.policy_config.use_priority: - # Calculate priorities. The priorities are the L1 losses between the predicted - # values and the search values. We use 'none' as the reduction parameter, which - # means the loss is calculated for each element individually, instead of being summed or averaged. - # A small constant (1e-6) is added to the results to avoid zero priorities. This - # is done because zero priorities could potentially cause issues in some scenarios. + # Calculate priorities as the L1 loss between predicted values and search values. + # 'reduction=none' ensures the loss is calculated for each element individually. pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1) - search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device - ).float().view(-1) - priorities = L1Loss(reduction='none' - )(pred_values, - search_values).detach().cpu().numpy() + 1e-6 + search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device).float().view(-1) + + # A small epsilon is added to avoid zero priorities. + priorities = L1Loss(reduction='none')(pred_values, search_values).detach().cpu().numpy() + 1e-6 else: - # priorities is None -> use the max priority for all newly collected data + # If priority is not used, return None. The replay buffer will use max priority for new data. 
priorities = None return priorities - def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegment], - last_game_priorities: List[np.ndarray], - game_segments: List[GameSegment], done: np.ndarray) -> None: + def pad_and_save_last_trajectory( + self, i: int, last_game_segments: List[Optional[GameSegment]], + last_game_priorities: List[Optional[np.ndarray]], + game_segments: List[GameSegment], done: np.ndarray + ) -> None: """ Overview: - Save the game segment to the pool if the current game is finished, padding it if necessary. + Pads the end of the `last_game_segment` with data from the start of the current `game_segment`. + This is necessary to compute target values for the final transitions of a segment. After padding, + the completed segment is stored in the `game_segment_pool`. Arguments: - - i (:obj:`int`): Index of the current game segment. - - last_game_segments (:obj:`List[GameSegment]`): List of the last game segments to be padded and saved. - - last_game_priorities (:obj:`List[np.ndarray]`): List of priorities of the last game segments. - - game_segments (:obj:`List[GameSegment]`): List of the current game segments. - - done (:obj:`np.ndarray`): Array indicating whether each game is done. + - i (:obj:`int`): The index of the environment being processed. + - last_game_segments (:obj:`List[Optional[GameSegment]]`): List of game segments from the previous collection chunk. + - last_game_priorities (:obj:`List[Optional[np.ndarray]]`): List of priorities corresponding to the last game segments. + - game_segments (:obj:`List[GameSegment]`): List of game segments from the current collection chunk. + - done (:obj:`np.ndarray`): Array indicating if the episode has terminated for each environment. Note: - (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True + An implicit assumption is that the start of the new segment's observation history overlaps with the + end of the last segment's, e.g., `(last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all()` is True. """ - # pad over last segment trajectory - beg_index = self.policy_config.model.frame_stack_num - end_index = beg_index + self.policy_config.num_unroll_steps + self.policy_config.td_steps - - # the start obs is init zero obs, so we take the - # [ : +] obs as the pad obs - # e.g. the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs - pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] - - # NOTE: for unizero - beg_index = 0 - end_index = beg_index + self.policy_config.num_unroll_steps + self.policy_config.td_steps - pad_action_lst = game_segments[i].action_segment[beg_index:end_index] - - # NOTE: for unizero - pad_child_visits_lst = game_segments[i].child_visit_segment[ - :self.policy_config.num_unroll_steps + self.policy_config.td_steps] - - # EfficientZero original repo bug: - # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] - - beg_index = 0 - end_index = beg_index + self.unroll_plus_td_steps - 1 - - pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] + # --- Prepare padding data from the current game segment --- + # Observations for padding are taken from the start of the new segment. 
+ beg_index_obs = self.policy_config.model.frame_stack_num + end_index_obs = beg_index_obs + self.policy_config.num_unroll_steps + self.policy_config.td_steps + pad_obs_lst = game_segments[i].obs_segment[beg_index_obs:end_index_obs] + + # Actions for padding. + beg_index_ac = 0 + end_index_ac = beg_index_ac + self.policy_config.num_unroll_steps + self.policy_config.td_steps + pad_action_lst = game_segments[i].action_segment[beg_index_ac:end_index_ac] + + # Child visits for padding. + pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] + + # Rewards for padding. + beg_index_rew = 0 + end_index_rew = beg_index_rew + self.unroll_plus_td_steps - 1 + pad_reward_lst = game_segments[i].reward_segment[beg_index_rew:end_index_rew] + + # Root values for padding. + beg_index_val = 0 + end_index_val = beg_index_val + self.unroll_plus_td_steps + pad_root_values_lst = game_segments[i].root_value_segment[beg_index_val:end_index_val] if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_lst = game_segments[i].chance_segment[beg_index:end_index] - - beg_index = 0 - end_index = beg_index + self.unroll_plus_td_steps - - pad_root_values_lst = game_segments[i].root_value_segment[beg_index:end_index] - + chance_lst = game_segments[i].chance_segment[beg_index_rew:end_index_rew] + if self.policy_config.gumbel_algo: - pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index:end_index] + pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index_val:end_index_val] - # pad over and save + # --- Pad the last game segment and save it --- if self.policy_config.gumbel_algo: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, - next_segment_improved_policy=pad_improved_policy_prob) + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, + pad_child_visits_lst, next_segment_improved_policy=pad_improved_policy_prob + ) else: if self.policy_config.use_ture_chance_label_in_chance_encoder: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, - next_chances=chance_lst) + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, + pad_child_visits_lst, next_chances=chance_lst + ) else: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst) - """ - Note: - game_segment element shape: - obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 - rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 - action: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 - root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 - child_visits: game_segment_length + num_unroll_steps -> 20 +5 - to_play: game_segment_length -> 20 - action_mask: game_segment_length -> 20 - """ - + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst + ) + + # Convert the segment's lists to NumPy arrays for efficient storage. last_game_segments[i].game_segment_to_array() - # put the game segment into the pool + # Add the completed game segment and its associated data to the pool. 
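To make the slicing above more concrete, here is a minimal, self-contained illustration of the pad-over idea using plain Python lists. The lengths are made up, and the real `pad_over` also handles observations, actions, root values, and child visits.

```Python
num_unroll_steps, td_steps = 5, 3

# Rewards already stored in the segment that just ended, and in the segment that follows it.
prev_segment_rewards = list(range(20))
next_segment_rewards = list(range(100, 120))

# The tail of the finished segment borrows the first `num_unroll_steps + td_steps - 1`
# rewards of the next segment so that n-step / unroll targets can still be formed
# for its final positions.
pad_rewards = next_segment_rewards[:num_unroll_steps + td_steps - 1]
padded_rewards = prev_segment_rewards + pad_rewards

print(len(pad_rewards), len(padded_rewards))  # 7 27
```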
self.game_segment_pool.append((last_game_segments[i], last_game_priorities[i], done[i])) - # reset last game_segments + # Reset the placeholder for the last game segment. last_game_segments[i] = None last_game_priorities[i] = None - return None - - def collect(self, - n_episode: Optional[int] = None, - train_iter: int = 0, - policy_kwargs: Optional[dict] = None, - collect_with_pure_policy: bool = False) -> List[Any]: + def collect( + self, + n_episode: Optional[int] = None, + train_iter: int = 0, + policy_kwargs: Optional[Dict] = None, + collect_with_pure_policy: bool = False + ) -> List[Any]: """ Overview: - Collect `n_episode` episodes of data with policy_kwargs, trained for `train_iter` iterations. + Collects `n_episode` episodes of data. It manages the entire lifecycle of an episode, + from getting actions from the policy, stepping the environment, storing transitions, + and saving completed game segments. Arguments: - - n_episode (:obj:`Optional[int]`): Number of episodes to collect. - - train_iter (:obj:`int`): Number of training iterations completed so far. - - policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy. - - collect_with_pure_policy (:obj:`bool`): Whether to collect data using pure policy without MCTS. + - n_episode (:obj:`Optional[int]`): The number of episodes to collect. If None, uses the default from the policy config. + - train_iter (:obj:`int`): The current training iteration, used for logging. + - policy_kwargs (:obj:`Optional[Dict]`): Additional keyword arguments to pass to the policy's forward method, like temperature for exploration. + - collect_with_pure_policy (:obj:`bool`): If True, collects data using a pure policy (e.g., greedy action) without MCTS. Returns: - - return_data (:obj:`List[Any]`): Collected data in the form of a list. + - return_data (:obj:`List[Any]`): A list containing the collected game segments and metadata. """ - # TODO: collect_with_pure_policy as a separate collector + # TODO(author): Consider implementing `collect_with_pure_policy` as a separate, more streamlined collector for clarity and modularity. if n_episode is None: if self._default_n_episode is None: - raise RuntimeError("Please specify collect n_episode") + raise RuntimeError("Please specify `n_episode` for collection.") else: n_episode = self._default_n_episode - assert n_episode >= self._env_num, "Please make sure n_episode >= env_num{}/{}".format(n_episode, self._env_num) + assert n_episode >= self._env_num, f"Please ensure n_episode ({n_episode}) >= env_num ({self._env_num})." + if policy_kwargs is None: policy_kwargs = {} - temperature = policy_kwargs['temperature'] - epsilon = policy_kwargs['epsilon'] + temperature = policy_kwargs.get('temperature', 1.0) + epsilon = policy_kwargs.get('epsilon', 0.0) + # --- Initializations --- collected_episode = 0 - collected_step = 0 env_nums = self._env_num retry_waiting_time = 0.05 - # initializations + # Wait for all environments to be ready and get initial observations. init_obs = self._env.ready_obs while len(init_obs.keys()) != self._env_num: - # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + self._logger.warning(f"Waiting for all environments to reset. 
Ready envs: {list(init_obs.keys())}") time.sleep(retry_waiting_time) - self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) - ) init_obs = self._env.ready_obs + # Prepare initial state dictionaries from observations. action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} - timestep_dict = {} - for i in range(env_nums): - if 'timestep' not in init_obs[i]: - print(f"Warning: 'timestep' key is missing in init_obs[{i}], assigning value -1") - timestep_dict[i] = to_ndarray(init_obs[i].get('timestep', -1)) - + timestep_dict = {i: to_ndarray(init_obs[i].get('timestep', -1)) for i in range(env_nums)} if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} - game_segments = [ - GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) for _ in range(env_nums) - ] - # stacked observation windows in reset stage for init game_segments - observation_window_stack = [[] for _ in range(env_nums)] + # Initialize game segments and observation stacks for each environment. + game_segments = [GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) for _ in range(env_nums)] + observation_window_stack = [deque(maxlen=self.policy_config.model.frame_stack_num) for _ in range(env_nums)] for env_id in range(env_nums): - observation_window_stack[env_id] = deque( - [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) + for _ in range(self.policy_config.model.frame_stack_num): + observation_window_stack[env_id].append(to_ndarray(init_obs[env_id]['observation'])) game_segments[env_id].reset(observation_window_stack[env_id]) + # State tracking variables for the collection loop. dones = np.array([False for _ in range(env_nums)]) - last_game_segments = [None for _ in range(env_nums)] - last_game_priorities = [None for _ in range(env_nums)] - # for priorities in self-play + last_game_segments: List[Optional[GameSegment]] = [None for _ in range(env_nums)] + last_game_priorities: List[Optional[np.ndarray]] = [None for _ in range(env_nums)] + + # Buffers for priority calculation. search_values_lst = [[] for _ in range(env_nums)] pred_values_lst = [[] for _ in range(env_nums)] if self.policy_config.gumbel_algo: improved_policy_lst = [[] for _ in range(env_nums)] - # some logs - eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) + # Logging variables. + eps_steps_lst = np.zeros(env_nums) + visit_entropies_lst = np.zeros(env_nums) if self.policy_config.gumbel_algo: completed_value_lst = np.zeros(env_nums) - self_play_moves = 0. - self_play_episodes = 0. - self_play_moves_max = 0 - self_play_visit_entropy = [] - total_transitions = 0 - ready_env_id = set() + ready_env_id: Set[int] = set() remain_episode = n_episode if collect_with_pure_policy: - temp_visit_list = [0.0 for i in range(self._env.action_space.n)] + # Dummy visit counts for pure policy collection. 
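The observation window initialized above is a fixed-length deque seeded with copies of the first observation. The short sketch below (with made-up shapes) shows how such a window behaves as new frames arrive.

```Python
from collections import deque

import numpy as np

frame_stack_num = 4
first_obs = np.zeros((3, 64, 64))

# Seed the window with `frame_stack_num` copies of the initial observation.
window = deque([first_obs for _ in range(frame_stack_num)], maxlen=frame_stack_num)

# Appending a new frame silently drops the oldest one, keeping the length fixed.
window.append(np.ones((3, 64, 64)))
print(len(window), window[-1].mean())  # 4 1.0
```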
+ temp_visit_list = [0.0 for _ in range(self._env.action_space.n)] + # --- Main Collection Loop --- while True: with self._timer: - # Get current ready env obs. + # Get observations from ready environments. obs = self._env.ready_obs - new_available_env_id = set(obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + ready_env_id.update(list(new_available_env_id)[:remain_episode]) remain_episode -= min(len(new_available_env_id), remain_episode) - - # NOTE: If waiting for N environments to synchronize, it may result in some environments not being completed (done) by the time of return. - # However, the current muzero_collector does not properly maintain the global self.last_game_segments, leading to some data not being collected. - - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} - stack_obs = list(stack_obs.values()) - - action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} - to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} - timestep_dict = {env_id: timestep_dict[env_id] for env_id in ready_env_id} + # Prepare policy inputs. + stack_obs_list = [game_segments[env_id].get_obs() for env_id in ready_env_id] action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] to_play = [to_play_dict[env_id] for env_id in ready_env_id] timestep = [timestep_dict[env_id] for env_id in ready_env_id] - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} - - stack_obs = to_ndarray(stack_obs) - # return stack_obs shape: [B, S*C, W, H] e.g. [8, 4*1, 96, 96] - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device) + stack_obs_array = to_ndarray(stack_obs_list) + stack_obs_tensor = prepare_observation(stack_obs_array, self.policy_config.model.model_type) + stack_obs_tensor = torch.from_numpy(stack_obs_tensor).to(self.policy_config.device) # ============================================================== - # Key policy forward step + # Policy Forward Pass # ============================================================== - # print(f'ready_env_id:{ready_env_id}') - # policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) - if self.task_id is None: - # single task setting - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep) - else: - # multi-task setting - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep, task_id=self.task_id) - # Extract relevant policy outputs - actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} - value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} - pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} - timestep_dict_with_env_id = { - k: v['timestep'] if 'timestep' in v else -1 for k, v in policy_output.items() + policy_input = { + 'x': stack_obs_tensor, + 'action_mask': action_mask, + 'temperature': temperature, + 'to_play': to_play, + 'epsilon': epsilon, + 'ready_env_id': ready_env_id, + 'timestep': timestep } - + if self.task_id is not None: + policy_input['task_id'] = self.task_id + + policy_output = 
self._policy.forward(**policy_input) + + # --- Unpack policy outputs --- + actions, value_dict, pred_value_dict = {}, {}, {} + distributions_dict, visit_entropy_dict = {}, {} if self.policy_config.sampled_algo: - root_sampled_actions_dict_with_env_id = { - k: v['root_sampled_actions'] for k, v in policy_output.items() - } - - if not collect_with_pure_policy: - distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in - policy_output.items()} - visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in - policy_output.items()} - - if self.policy_config.gumbel_algo: - improved_policy_dict_with_env_id = {k: v['improved_policy_probs'] for k, v in - policy_output.items()} - completed_value_with_env_id = {k: v['roots_completed_value'] for k, v in policy_output.items()} - - # Initialize dictionaries to store results - actions = {} - value_dict = {} - pred_value_dict = {} - timestep_dict = {} + root_sampled_actions_dict = {} + if self.policy_config.gumbel_algo: + improved_policy_dict, completed_value_dict = {}, {} - if not collect_with_pure_policy: - distributions_dict = {} - visit_entropy_dict = {} - - if self.policy_config.sampled_algo: - root_sampled_actions_dict = {} - - if self.policy_config.gumbel_algo: - improved_policy_dict = {} - completed_value_dict = {} - - # Populate the result dictionaries for env_id in ready_env_id: - actions[env_id] = actions_with_env_id.pop(env_id) - value_dict[env_id] = value_dict_with_env_id.pop(env_id) - pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id) - timestep_dict[env_id] = timestep_dict_with_env_id.pop(env_id) - + output = policy_output[env_id] + actions[env_id] = output['action'] + value_dict[env_id] = output['searched_value'] + pred_value_dict[env_id] = output['predicted_value'] + if not collect_with_pure_policy: - distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id) - + distributions_dict[env_id] = output['visit_count_distributions'] + visit_entropy_dict[env_id] = output['visit_count_distribution_entropy'] if self.policy_config.sampled_algo: - root_sampled_actions_dict[env_id] = root_sampled_actions_dict_with_env_id.pop(env_id) - - visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id) - + root_sampled_actions_dict[env_id] = output['root_sampled_actions'] if self.policy_config.gumbel_algo: - improved_policy_dict[env_id] = improved_policy_dict_with_env_id.pop(env_id) - completed_value_dict[env_id] = completed_value_with_env_id.pop(env_id) + improved_policy_dict[env_id] = output['improved_policy_probs'] + completed_value_dict[env_id] = output['roots_completed_value'] # ============================================================== - # Interact with the environment + # Environment Interaction # ============================================================== timesteps = self._env.step(actions) - interaction_duration = self._timer.value / len(timesteps) + interaction_duration = self._timer.value / len(timesteps) if timesteps else 0 for env_id, episode_timestep in timesteps.items(): with self._timer: + # Handle abnormal timesteps by resetting the environment and policy state. if episode_timestep.info.get('abnormal', False): - # If there is an abnormal episode_timestep, reset all the related variables(including this env). 
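To make the unpacking loop above easier to follow, here is a hedged sketch of the per-environment structure the collector reads back from the policy. Only the key names are taken from the surrounding code; the numeric values and environment ids are invented for illustration:

```Python
# Purely illustrative values; only the key names are taken from the code above.
policy_output = {
    0: {
        'action': 2,
        'searched_value': 0.31,                       # root value returned by MCTS
        'predicted_value': 0.27,                      # value-head prediction
        'visit_count_distributions': [3, 12, 25, 10],
        'visit_count_distribution_entropy': 1.21,
    },
    3: {
        'action': 0,
        'searched_value': -0.05,
        'predicted_value': 0.02,
        'visit_count_distributions': [30, 8, 7, 5],
        'visit_count_distribution_entropy': 0.98,
    },
}

# The collector then regroups these per-environment results by quantity.
actions = {env_id: out['action'] for env_id, out in policy_output.items()}
searched_values = {env_id: out['searched_value'] for env_id, out in policy_output.items()}
```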
- # suppose there is no reset param, reset this env self._env.reset({env_id: None}) self._policy.reset([env_id]) self._reset_stat(env_id) - self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, episode_timestep.info)) + self._logger.info(f"Environment {env_id} returned an abnormal step, info: {episode_timestep.info}") continue + obs, reward, done, info = episode_timestep.obs, episode_timestep.reward, episode_timestep.done, episode_timestep.info + # Store MCTS search statistics. if collect_with_pure_policy: game_segments[env_id].store_search_stats(temp_visit_list, 0) else: if self.policy_config.sampled_algo: - game_segments[env_id].store_search_stats( - distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] - ) + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id]) elif self.policy_config.gumbel_algo: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], - improved_policy=improved_policy_dict[env_id]) + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], improved_policy=improved_policy_dict[env_id]) else: game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) - # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} - # in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment`` + # Append the current transition to the game segment. + append_args = (actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], to_play_dict[env_id]) if self.policy_config.use_ture_chance_label_in_chance_encoder: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], chance_dict[env_id], timestep_dict[env_id] - ) - else: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], timestep_dict[env_id] - ) + append_args += (chance_dict[env_id],) + append_args += (timestep_dict[env_id],) + game_segments[env_id].append(*append_args) - # NOTE: the position of code snippet is very important. - # the obs['action_mask'] and obs['to_play'] are corresponding to the next action + # Update state dictionaries for the next step. action_mask_dict[env_id] = to_ndarray(obs['action_mask']) to_play_dict[env_id] = to_ndarray(obs['to_play']) timestep_dict[env_id] = to_ndarray(obs.get('timestep', -1)) if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_dict[env_id] = to_ndarray(obs['chance']) - if self.policy_config.ignore_done: - dones[env_id] = False - else: - dones[env_id] = done - + dones[env_id] = done if not self.policy_config.ignore_done else False + + # Update logging and priority data. 
if not collect_with_pure_policy: visit_entropies_lst[env_id] += visit_entropy_dict[env_id] if self.policy_config.gumbel_algo: completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) - + eps_steps_lst[env_id] += 1 - if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero', 'unizero_multitask', 'sampled_unizero_multitask']: - # TODO: only for UniZero now - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) # NOTE: reset_init_data=False - - total_transitions += 1 - if self.policy_config.use_priority: pred_values_lst[env_id].append(pred_value_dict[env_id]) search_values_lst[env_id].append(value_dict[env_id]) - if self.policy_config.gumbel_algo and not collect_with_pure_policy: - improved_policy_lst[env_id].append(improved_policy_dict[env_id]) - # append the newest obs + # Update the observation window with the new observation. observation_window_stack[env_id].append(to_ndarray(obs['observation'])) # ============================================================== - # we will save a game segment if it is the end of the game or the next game segment is finished. + # Game Segment Saving Logic # ============================================================== - - # if game segment is full, we will save the last game segment + # If a segment is full, pad and save the previous segment. if game_segments[env_id].is_full(): - # pad over last segment trajectory if last_game_segments[env_id] is not None: - # TODO(pu): return the one game segment - self.pad_and_save_last_trajectory( - env_id, last_game_segments, last_game_priorities, game_segments, dones - ) + self.pad_and_save_last_trajectory(env_id, last_game_segments, last_game_priorities, game_segments, dones) - # calculate priority + # Calculate priorities for the now-completed `last_game_segment`. priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] - if self.policy_config.gumbel_algo and not collect_with_pure_policy: - improved_policy_lst[env_id] = [] + pred_values_lst[env_id], search_values_lst[env_id] = [], [] - # the current game_segments become last_game_segment + # The current segment becomes the `last_game_segment`. last_game_segments[env_id] = game_segments[env_id] last_game_priorities[env_id] = priorities - # create new GameSegment - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) + # Start a new game segment. 
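The priority buffers filled above feed `_compute_priorities`, whose L1-based logic appears later in this diff. A self-contained sketch of that computation (stand-in function name, toy inputs) is:

```Python
import numpy as np
import torch
from torch.nn import L1Loss


def compute_priorities_sketch(pred_values, search_values, eps=1e-6):
    # Priorities are the absolute differences between the value predicted by the
    # model and the value returned by the MCTS search, plus a small epsilon so
    # that no transition ends up with exactly zero priority.
    pred = torch.tensor(np.array(pred_values), dtype=torch.float32).view(-1)
    search = torch.tensor(np.array(search_values), dtype=torch.float32).view(-1)
    return L1Loss(reduction='none')(pred, search).numpy() + eps


print(compute_priorities_sketch([0.5, 1.2, -0.3], [0.7, 1.0, -0.3]))
# -> roughly [0.2, 0.2, 1e-6]
```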
+ game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) game_segments[env_id].reset(observation_window_stack[env_id]) self._env_info[env_id]['step'] += 1 collected_step += 1 self._env_info[env_id]['time'] += self._timer.value + interaction_duration - if episode_timestep.done: - reward = episode_timestep.info['eval_episode_return'] - info = { - 'reward': reward, - 'time': self._env_info[env_id]['time'], - 'step': self._env_info[env_id]['step'], - } + + # --- Episode Termination Handling --- + if done: + collected_episode += 1 + reward = info['eval_episode_return'] + log_info = {'reward': reward, 'time': self._env_info[env_id]['time'], 'step': self._env_info[env_id]['step']} if not collect_with_pure_policy: - info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] + log_info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 if self.policy_config.gumbel_algo: - info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] - - collected_episode += 1 - self._episode_info.append(info) - - # ============================================================== - # if it is the end of the game, we will save the game segment - # ============================================================== + log_info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 + self._episode_info.append(log_info) - # NOTE: put the penultimate game segment in one episode into the trajectory_pool - # pad over 2th last game_segment using the last game_segment + # Pad and save the segment before the final one. if last_game_segments[env_id] is not None: - self.pad_and_save_last_trajectory( - env_id, last_game_segments, last_game_priorities, game_segments, dones - ) - - # store current segment trajectory + self.pad_and_save_last_trajectory(env_id, last_game_segments, last_game_priorities, game_segments, dones) + + # Process and save the final segment of the episode. priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - - # NOTE: put the last game segment in one episode into the trajectory_pool game_segments[env_id].game_segment_to_array() - - # assert len(game_segments[env_id]) == len(priorities) - # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null - if len(game_segments[env_id].reward_segment) != 0: + if len(game_segments[env_id].reward_segment) > 0: self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id])) - # print(game_segments[env_id].reward_segment) - # reset the finished env and init game_segments + # Reset environment-specific states for a new episode. if n_episode > self._env_num: - # Get current ready env obs. + # Re-initialize the state for this env_id. init_obs = self._env.ready_obs - retry_waiting_time = 0.001 - while len(init_obs.keys()) != self._env_num: - # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. 
- self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + while env_id not in init_obs: + self._logger.warning(f"Waiting for env {env_id} to reset...") time.sleep(retry_waiting_time) - self._logger.info( - '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10 - ) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format( - retry_waiting_time, self._env._env_states - ) - ) init_obs = self._env.ready_obs - - new_available_env_id = set(init_obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) - remain_episode -= min(len(new_available_env_id), remain_episode) - + action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) timestep_dict[env_id] = to_ndarray(init_obs[env_id].get('timestep', -1)) - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) - - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) - observation_window_stack[env_id] = deque( - [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) + chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) + + # Reset game segment and observation stack. + game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) + observation_window_stack[env_id].clear() + for _ in range(self.policy_config.model.frame_stack_num): + observation_window_stack[env_id].append(init_obs[env_id]['observation']) game_segments[env_id].reset(observation_window_stack[env_id]) last_game_segments[env_id] = None last_game_priorities[env_id] = None - # log - self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) - if not collect_with_pure_policy: - self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) - self_play_moves += eps_steps_lst[env_id] - self_play_episodes += 1 - - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] - eps_steps_lst[env_id] = 0 - visit_entropies_lst[env_id] = 0 + # Reset tracking and logging variables. + pred_values_lst[env_id], search_values_lst[env_id] = [], [] + eps_steps_lst[env_id], visit_entropies_lst[env_id] = 0, 0 + if self.policy_config.gumbel_algo: + completed_value_lst[env_id] = 0 - # Env reset is done by env_manager automatically - self._policy.reset([env_id]) # NOTE: reset the policy for the env_id. Default reset_init_data=True. + # Reset policy and collector stats for the finished environment. + self._policy.reset([env_id]) self._reset_stat(env_id) ready_env_id.remove(env_id) + # --- Check for Collection Completion --- if collected_episode >= n_episode: - # [data, meta_data] - return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [ - { - 'priorities': self.game_segment_pool[i][1], - 'done': self.game_segment_pool[i][2], + # Prepare data for returning. 
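For reference, the return format assembled just below pairs each collected game segment with a small metadata dict. A minimal sketch with placeholder segment objects (real entries are `GameSegment` instances and per-transition priority arrays):

```Python
from collections import deque

# Placeholder segments and stats, standing in for real GameSegment objects.
game_segment_pool = deque()
game_segment_pool.append(('segment_0', [0.21, 0.08, 0.05], False))  # (segment, priorities, done)
game_segment_pool.append(('segment_1', None, True))                 # None priorities -> buffer uses max priority
unroll_plus_td_steps = 10  # num_unroll_steps + td_steps in the real config

return_data = [
    [item[0] for item in game_segment_pool],            # the game segments themselves
    [
        {
            'priorities': item[1],
            'done': item[2],
            'unroll_plus_td_steps': unroll_plus_td_steps,
        }
        for item in game_segment_pool
    ],
]
```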
+ return_data = [ + [item[0] for item in self.game_segment_pool], + [{ + 'priorities': item[1], + 'done': item[2], 'unroll_plus_td_steps': self.unroll_plus_td_steps - } for i in range(len(self.game_segment_pool)) + } for item in self.game_segment_pool] ] self.game_segment_pool.clear() break - + + # --- Finalize and Log --- collected_duration = sum([d['time'] for d in self._episode_info]) - # reduce data when enables DDP + # In DDP, aggregate statistics across all processes. if self._world_size > 1: - # Before allreduce - self._logger.info(f"Rank {self._rank} before allreduce: collected_step={collected_step}, collected_episode={collected_episode}") collected_step = allreduce_data(collected_step, 'sum') collected_episode = allreduce_data(collected_episode, 'sum') collected_duration = allreduce_data(collected_duration, 'sum') - # After allreduce - self._logger.info(f"Rank {self._rank} after allreduce: collected_step={collected_step}, collected_episode={collected_episode}") self._total_envstep_count += collected_step self._total_episode_count += collected_episode self._total_duration += collected_duration - # log self._output_log(train_iter) return return_data def _output_log(self, train_iter: int) -> None: """ Overview: - Log the collector's data and output the log information. + Aggregates and logs collection statistics to the console, TensorBoard, and WandB. + This method is only executed by the rank 0 process in a distributed setup. Arguments: - - train_iter (:obj:`int`): Current training iteration number for logging context. + - train_iter (:obj:`int`): The current training iteration number, used as the logging step. """ if self._rank != 0: return + if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: self._last_train_iter = train_iter episode_count = len(self._episode_info) envstep_count = sum([d['step'] for d in self._episode_info]) duration = sum([d['time'] for d in self._episode_info]) episode_reward = [d['reward'] for d in self._episode_info] - if not self.collect_with_pure_policy: - visit_entropy = [d['visit_entropy'] for d in self._episode_info] - else: - visit_entropy = [0.0] - if self.policy_config.gumbel_algo: - completed_value = [d['completed_value'] for d in self._episode_info] - self._total_duration += duration + info = { 'episode_count': episode_count, 'envstep_count': envstep_count, 'avg_envstep_per_episode': envstep_count / episode_count, - 'avg_envstep_per_sec': envstep_count / duration, - 'avg_episode_per_sec': episode_count / duration, + 'avg_envstep_per_sec': envstep_count / duration if duration > 0 else 0, + 'avg_episode_per_sec': episode_count / duration if duration > 0 else 0, 'collect_time': duration, 'reward_mean': np.mean(episode_reward), 'reward_std': np.std(episode_reward), @@ -799,25 +650,32 @@ def _output_log(self, train_iter: int) -> None: 'total_envstep_count': self._total_envstep_count, 'total_episode_count': self._total_episode_count, 'total_duration': self._total_duration, - 'visit_entropy': np.mean(visit_entropy), } + + if not self.collect_with_pure_policy: + visit_entropy = [d['visit_entropy'] for d in self._episode_info] + info['visit_entropy_mean'] = np.mean(visit_entropy) if self.policy_config.gumbel_algo: - info['completed_value'] = np.mean(completed_value) + completed_value = [d['completed_value'] for d in self._episode_info] + info['completed_value_mean'] = np.mean(completed_value) + self._episode_info.clear() - self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in 
info.items()]))) + + # Log to console + self._logger.info("Collector Training Summary:\n{}".format('\n'.join([f' {k}: {v}' for k, v in info.items()]))) + + # Log to TensorBoard and WandB for k, v in info.items(): - if k in ['each_reward']: - continue - if self.task_id is None: - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - else: - self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, train_iter) - if k in ['total_envstep_count']: - continue if self.task_id is None: - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + tb_prefix_iter = f'{self._instance_name}_iter/' + tb_prefix_step = f'{self._instance_name}_step/' else: - self._tb_logger.add_scalar('{}_step_task{}/'.format(self._instance_name, self.task_id) + k, v, self._total_envstep_count) - + tb_prefix_iter = f'{self._instance_name}_iter_task{self.task_id}/' + tb_prefix_step = f'{self._instance_name}_step_task{self.task_id}/' + + self._tb_logger.add_scalar(tb_prefix_iter + k, v, train_iter) + self._tb_logger.add_scalar(tb_prefix_step + k, v, self._total_envstep_count) + if self.policy_config.use_wandb: - wandb.log({'{}_step/'.format(self._instance_name) + k: v for k, v in info.items()}, step=self._total_envstep_count) + wandb_log_data = {tb_prefix_step + k: v for k, v in info.items()} + wandb.log(wandb_log_data, step=self._total_envstep_count) \ No newline at end of file diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index 468c75aa5..01fabd38c 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -21,92 +21,87 @@ class MuZeroEvaluator(ISerialEvaluator): """ Overview: - The Evaluator class for MCTS+RL algorithms, such as MuZero, EfficientZero, and Sampled EfficientZero. + The Evaluator for MCTS-based reinforcement learning algorithms, such as MuZero, EfficientZero, and Sampled EfficientZero. Interfaces: __init__, reset, reset_policy, reset_env, close, should_eval, eval Properties: env, policy """ + # Default configuration for the MuZeroEvaluator. + config = dict( + # The frequency of evaluation, measured in training iterations. + eval_freq=50, + ) + @classmethod def default_config(cls: type) -> EasyDict: """ Overview: - Retrieve the default configuration for the evaluator by merging evaluator-specific defaults with other - defaults and any user-provided configuration. + Get the default configuration of the MuZeroEvaluator. Returns: - - cfg (:obj:`EasyDict`): The default configuration for the evaluator. + - cfg (:obj:`EasyDict`): An EasyDict object representing the default configuration. """ cfg = EasyDict(copy.deepcopy(cls.config)) cfg.cfg_type = cls.__name__ + 'Dict' return cfg - config = dict( - # Evaluate every "eval_freq" training iterations. 
- eval_freq=50, - ) - def __init__( self, eval_freq: int = 1000, n_evaluator_episode: int = 3, - stop_value: int = 1e6, - env: BaseEnvManager = None, - policy: namedtuple = None, - tb_logger: 'SummaryWriter' = None, # noqa - exp_name: Optional[str] = 'default_experiment', - instance_name: Optional[str] = 'evaluator', - policy_config: 'policy_config' = None, # noqa - task_id: int = None, + stop_value: float = 1e6, + env: Optional[BaseEnvManager] = None, + policy: Optional[namedtuple] = None, + tb_logger: Optional['SummaryWriter'] = None, + exp_name: str = 'default_experiment', + instance_name: str = 'evaluator', + policy_config: Optional[EasyDict] = None, + task_id: Optional[int] = None, ) -> None: """ Overview: - Initialize the evaluator with configuration settings for various components such as logger helper and timer. + Initialize the MuZeroEvaluator. Arguments: - - eval_freq (:obj:`int`): Evaluation frequency in terms of training steps. - - n_evaluator_episode (:obj:`int`): Number of episodes to evaluate in total. - - stop_value (:obj:`float`): A reward threshold above which the training is considered converged. - - env (:obj:`Optional[BaseEnvManager]`): An optional instance of a subclass of BaseEnvManager. - - policy (:obj:`Optional[namedtuple]`): An optional API namedtuple defining the policy for evaluation. - - tb_logger (:obj:`Optional[SummaryWriter]`): Optional TensorBoard logger instance. - - exp_name (:obj:`str`): Name of the experiment, used to determine output directory. - - instance_name (:obj:`str`): Name of this evaluator instance. - - policy_config (:obj:`Optional[dict]`): Optional configuration for the game policy. - - task_id (:obj:`int`): Unique identifier for the task. If None, that means we are in the single task mode. + - eval_freq (:obj:`int`): The frequency, in training iterations, at which to run evaluation. + - n_evaluator_episode (:obj:`int`): The total number of episodes to run during each evaluation. + - stop_value (:obj:`float`): The reward threshold at which training is considered converged and will stop. + - env (:obj:`Optional[BaseEnvManager]`): An optional environment manager for evaluation. + - policy (:obj:`Optional[namedtuple]`): An optional policy for evaluation. + - tb_logger (:obj:`Optional['SummaryWriter']`): An optional TensorBoard logger. + - exp_name (:obj:`str`): The name of the experiment, used for logging. + - instance_name (:obj:`str`): The name of this evaluator instance. + - policy_config (:obj:`Optional[EasyDict]`): Configuration for the policy. + - task_id (:obj:`Optional[int]`): The unique identifier for the task. If None, it operates in single-task mode. """ - self.stop_event = threading.Event() # Add stop event to handle timeouts + self.stop_event = threading.Event() # Event to signal a stop, e.g., due to a timeout. self.task_id = task_id self._eval_freq = eval_freq self._exp_name = exp_name self._instance_name = instance_name - # Logger (Monitor will be initialized in policy setter) - # Only rank == 0 learner needs monitor and tb_logger, others only need text_logger to display terminal output. + # Initialize logger. Only rank 0 needs a full logger with TensorBoard. 
if get_rank() == 0: if tb_logger is not None: self._logger, _ = build_logger( - './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False + f'./{self._exp_name}/log/{self._instance_name}', self._instance_name, need_tb=False ) self._tb_logger = tb_logger else: self._logger, self._tb_logger = build_logger( - './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name + f'./{self._exp_name}/log/{self._instance_name}', self._instance_name ) else: - # self._logger, self._tb_logger = None, None # for close elegantly - # ========== TODO: unizero_multitask ddp_v2 ======== + # TODO(username): Refine logger setup for UniZero multitask with DDP v2. if tb_logger is not None: self._logger, _ = build_logger( - './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False + f'./{self._exp_name}/log/{self._instance_name}', self._instance_name, need_tb=False ) self._tb_logger = tb_logger - self._rank = get_rank() - print(f'rank {self._rank}, self.task_id: {self.task_id}') - self.reset(policy, env) self._timer = EasyTimer() @@ -114,21 +109,16 @@ def __init__( self._stop_value = stop_value # ============================================================== - # MCTS+RL related core code + # MCTS+RL related core properties # ============================================================== self.policy_config = policy_config - # def stop(self): - # self.stop_event.set() - def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset the environment for the evaluator, optionally replacing it with a new environment. - If _env is None, reset the old environment. If _env is not None, replace the old environment - in the evaluator with the new passed in environment and launch. + Reset the environment. If a new environment is provided, it replaces the old one. Arguments: - - _env (:obj:`Optional[BaseEnvManager]`): An optional new environment instance to replace the existing one. + - _env (:obj:`Optional[BaseEnvManager]`): New environment manager to use. If None, resets the existing environment. """ if _env is not None: self._env = _env @@ -140,13 +130,11 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Reset the policy for the evaluator, optionally replacing it with a new policy. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the evaluator with the new passed in policy. + Reset the policy. If a new policy is provided, it replaces the old one. Arguments: - - _policy (:obj:`Optional[namedtuple]`): An optional new policy namedtuple to replace the existing one. + - _policy (:obj:`Optional[namedtuple]`): New policy to use. If None, resets the existing policy. """ - assert hasattr(self, '_env'), "please set env first" + assert hasattr(self, '_env'), "Please set environment first." if _policy is not None: self._policy = _policy self._policy.reset(task_id=self.task_id) @@ -154,15 +142,10 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset both the policy and environment for the evaluator, optionally replacing them. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the evaluator with the new passed in \ - environment and launch. - If _policy is None, reset the old policy. 
- If _policy is not None, replace the old policy in the evaluator with the new passed in policy. + Reset both the policy and the environment. Arguments: - - _policy (:obj:`Optional[namedtuple]`): An optional new policy namedtuple to replace the existing one. - - _env (:obj:`Optional[BaseEnvManager]`): An optional new environment instance to replace the existing one. + - _policy (:obj:`Optional[namedtuple]`): New policy to use. + - _env (:obj:`Optional[BaseEnvManager]`): New environment manager to use. """ if _env is not None: self.reset_env(_env) @@ -175,32 +158,32 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana def close(self) -> None: """ Overview: - Close the evaluator, the environment, flush and close the TensorBoard logger if applicable. + Close the evaluator, including the environment and the TensorBoard logger. """ if self._end_flag: return self._end_flag = True - self._env.close() + if hasattr(self, '_env'): + self._env.close() if self._tb_logger: self._tb_logger.flush() self._tb_logger.close() - def __del__(self): + def __del__(self) -> None: """ Overview: - Execute the close command and close the evaluator. __del__ is automatically called \ - to destroy the evaluator instance when the evaluator finishes its work + Destructor that ensures `close` is called to clean up resources. """ self.close() def should_eval(self, train_iter: int) -> bool: """ Overview: - Determine whether to initiate evaluation based on the training iteration count and evaluation frequency. + Determine whether it's time to run an evaluation based on the training iteration. Arguments: - - train_iter (:obj:`int`): The current count of training iterations. + - train_iter (:obj:`int`): The current training iteration. Returns: - - (:obj:`bool`): `True` if evaluation should be initiated, otherwise `False`. + - (:obj:`bool`): True if evaluation should be run, otherwise False. """ if train_iter == self._last_eval_iter: return False @@ -211,24 +194,25 @@ def should_eval(self, train_iter: int) -> bool: def eval( self, - save_ckpt_fn: Callable = None, + save_ckpt_fn: Optional[Callable] = None, train_iter: int = -1, envstep: int = -1, n_episode: Optional[int] = None, return_trajectory: bool = False, - ) -> Tuple[bool, float]: + ) -> Tuple[bool, Dict[str, Any]]: """ Overview: - Evaluate the current policy, storing the best policy if it achieves the highest historical reward. + Run a full evaluation process. It will evaluate the current policy, log the results, + and save a checkpoint if a new best performance is achieved. Arguments: - - save_ckpt_fn (:obj:`Optional[Callable]`): Optional function to save a checkpoint when a new best reward is achieved. - - train_iter (:obj:`int`): The current training iteration count. - - envstep (:obj:`int`): The current environment step count. - - n_episode (:obj:`Optional[int]`): Optional number of evaluation episodes; defaults to the evaluator's setting. - - return_trajectory (:obj:`bool`): Return the evaluated trajectory `game_segments` in `episode_info` if True. + - save_ckpt_fn (:obj:`Optional[Callable]`): A function to save a checkpoint. Called when a new best reward is achieved. + - train_iter (:obj:`int`): The current training iteration. + - envstep (:obj:`int`): The current total environment steps. + - n_episode (:obj:`Optional[int]`): The number of episodes to evaluate. Defaults to the value set in `__init__`. + - return_trajectory (:obj:`bool`): Whether to return the collected `game_segments` in the result dictionary. 
Returns: - - stop_flag (:obj:`bool`): Indicates whether the training can be stopped based on the stop value. - - episode_info (:obj:`Dict[str, Any]`): A dictionary containing information about the evaluation episodes. + - stop_flag (:obj:`bool`): A flag indicating whether the training should stop (e.g., if the stop value is reached). + - episode_info (:obj:`Dict[str, Any]`): A dictionary containing evaluation results, such as rewards and episode lengths. """ if torch.cuda.is_available(): print(f"=========in eval() Rank {get_rank()} ===========") @@ -237,16 +221,14 @@ def eval( torch.cuda.set_device(get_rank()) print(f"set device后的 GPU 设备编号: {get_rank()}") - # the evaluator only works on rank0 + # The evaluator is designed to work on rank 0, but DDP support is being developed. episode_info = None stop_flag = False - # ======== TODO: unizero_multitask ddp_v2 ======== - # if get_rank() == 0: + # TODO(username): Refine evaluation logic for UniZero multitask with DDP v2. if get_rank() >= 0: - if n_episode is None: n_episode = self._default_n_episode - assert n_episode is not None, "please indicate eval n_episode" + assert n_episode is not None, "Please specify the number of evaluation episodes (n_episode)." envstep_count = 0 eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode) env_nums = self._env.env_num @@ -254,21 +236,14 @@ def eval( self._env.reset() self._policy.reset(task_id=self.task_id) - # initializations + # Initializations init_obs = self._env.ready_obs + # Wait for all environments to be ready, especially in subprocess-based environment managers. retry_waiting_time = 0.001 while len(init_obs.keys()) != self._env_num: - # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + self._logger.info(f"Waiting for all environments to reset. Current ready envs: {list(init_obs.keys())}") time.sleep(retry_waiting_time) - self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, - self._env._env_states) - ) init_obs = self._env.ready_obs action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} @@ -279,7 +254,7 @@ def eval( if 'timestep' not in init_obs[i]: print(f"Warning: 'timestep' key is missing in init_obs[{i}], assigning value -1") timestep_dict[i] = to_ndarray(init_obs[i].get('timestep', -1)) - + dones = np.array([False for _ in range(env_nums)]) game_segments = [ @@ -300,24 +275,20 @@ def eval( eps_steps_lst = np.zeros(env_nums) with self._timer: while not eval_monitor.is_finished(): - - # Check if stop_event is set (timeout occurred) + # Check if a timeout has occurred. if self.stop_event.is_set(): self._logger.info("[EVALUATOR]: Evaluation aborted due to timeout.") break - # Get current ready env obs. + # Get observations from ready environments. obs = self._env.ready_obs new_available_env_id = set(obs.keys()).difference(ready_env_id) ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) + # Prepare stacked observations and other inputs for the policy. 
stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} stack_obs = list(stack_obs.values()) - - action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} - to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} - timestep_dict = {env_id: timestep_dict[env_id] for env_id in ready_env_id} action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] to_play = [to_play_dict[env_id] for env_id in ready_env_id] timestep = [timestep_dict[env_id] for env_id in ready_env_id] @@ -327,42 +298,30 @@ def eval( stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() # ============================================================== - # policy forward + # Policy Forward Pass # ============================================================== - # policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) if self.task_id is None: - # single task setting - policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) + # Single-task setting + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, timestep=timestep) else: - # multi task setting - policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, task_id=self.task_id) + # Multi-task setting + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, timestep=timestep, task_id=self.task_id) + # Unpack policy outputs. actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: - root_sampled_actions_dict_with_env_id = { - k: v['root_sampled_actions'] - for k, v in policy_output.items() - } - + root_sampled_actions_dict_with_env_id = {k: v['root_sampled_actions'] for k, v in policy_output.items()} value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} - timestep_dict_with_env_id = { - k: v['timestep'] if 'timestep' in v else -1 for k, v in policy_output.items() - } - visit_entropy_dict_with_env_id = { - k: v['visit_count_distribution_entropy'] - for k, v in policy_output.items() - } - - actions = {} - distributions_dict = {} + timestep_dict_with_env_id = {k: v.get('timestep', -1) for k, v in policy_output.items()} + visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in policy_output.items()} + + # Remap outputs from policy's internal IDs to environment IDs. + actions, distributions_dict, value_dict, pred_value_dict, timestep_dict, visit_entropy_dict = {}, {}, {}, {}, {}, {} if self.policy_config.sampled_algo: root_sampled_actions_dict = {} - value_dict = {} - pred_value_dict = {} - timestep_dict = {} - visit_entropy_dict = {} + for index, env_id in enumerate(ready_env_id): actions[env_id] = actions_with_env_id.pop(env_id) distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id) @@ -374,7 +333,7 @@ def eval( visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id) # ============================================================== - # Interact with env. 
+ # Environment Interaction # ============================================================== timesteps = self._env.step(actions) timesteps = to_tensor(timesteps, dtype=torch.float32) @@ -382,8 +341,8 @@ def eval( obs, reward, done, info = episode_timestep.obs, episode_timestep.reward, episode_timestep.done, episode_timestep.info eps_steps_lst[env_id] += 1 + # This reset logic is specific to UniZero-like models. if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - # only for UniZero now self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False, task_id=self.task_id) game_segments[env_id].append( @@ -391,15 +350,13 @@ def eval( to_play_dict[env_id], timestep_dict[env_id] ) - # NOTE: the position of code snippet is very important. - # the obs['action_mask'] and obs['to_play'] are corresponding to next action + # IMPORTANT: The action_mask and to_play from the new observation correspond to the *next* state. action_mask_dict[env_id] = to_ndarray(obs['action_mask']) to_play_dict[env_id] = to_ndarray(obs['to_play']) timestep_dict[env_id] = to_ndarray(obs.get('timestep', -1)) dones[env_id] = done if episode_timestep.done: - # Env reset is done by env_manager automatically. self._policy.reset([env_id]) reward = episode_timestep.info['eval_episode_return'] saved_info = {'eval_episode_return': episode_timestep.info['eval_episode_return']} @@ -408,41 +365,27 @@ def eval( eval_monitor.update_info(env_id, saved_info) eval_monitor.update_reward(env_id, reward) self._logger.info( - "[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format( - env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() - ) + f"[EVALUATOR] env {env_id} finished episode, final reward: {eval_monitor.get_latest_reward(env_id)}, " + f"current episode count: {eval_monitor.get_current_episode()}" ) - # reset the finished env and init game_segments + # If there are more episodes to run than available environments, reset and reuse this one. if n_episode > self._env_num: - # Get current ready env obs. init_obs = self._env.ready_obs - retry_waiting_time = 0.001 + # Wait for the environment to be ready again. while len(init_obs.keys()) != self._env_num: - # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info( - 'Before sleeping, the _env_states is {}'.format(self._env._env_states) - ) + self._logger.info(f"Waiting for env {env_id} to reset. Current ready envs: {list(init_obs.keys())}") time.sleep(retry_waiting_time) - self._logger.info( - '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10 - ) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format( - retry_waiting_time, self._env._env_states - ) - ) init_obs = self._env.ready_obs new_available_env_id = set(init_obs.keys()).difference(ready_env_id) ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) + # Re-initialize state for the new episode. 
action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) - timestep_dict[env_id] = to_ndarray(init_obs[env_id]['timestep']) + timestep_dict[env_id] = to_ndarray(init_obs[env_id].get('timestep', -1)) game_segments[env_id] = GameSegment( self._env.action_space, @@ -450,76 +393,69 @@ def eval( config=self.policy_config, task_id=self.task_id ) - game_segments[env_id].reset( - [ - init_obs[env_id]['observation'] - for _ in range(self.policy_config.model.frame_stack_num) - ] + [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)] ) eps_steps_lst[env_id] = 0 - - # Env reset is done by env_manager automatically. - self._policy.reset([env_id]) # NOTE: reset the policy for the env_id. Default reset_init_data=True. + # NOTE: Reset the policy state for this env_id. `reset_init_data` defaults to True. + self._policy.reset([env_id]) ready_env_id.remove(env_id) envstep_count += 1 + duration = self._timer.value episode_return = eval_monitor.get_episode_return() info = { 'train_iter': train_iter, - 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter), + 'ckpt_name': f'iteration_{train_iter}.pth.tar', 'episode_count': n_episode, 'envstep_count': envstep_count, - 'avg_envstep_per_episode': envstep_count / n_episode, + 'avg_envstep_per_episode': envstep_count / n_episode if n_episode > 0 else 0, 'evaluate_time': duration, - 'avg_envstep_per_sec': envstep_count / duration, - 'avg_time_per_episode': n_episode / duration, + 'avg_envstep_per_sec': envstep_count / duration if duration > 0 else 0, + 'avg_time_per_episode': n_episode / duration if duration > 0 else 0, 'reward_mean': np.mean(episode_return), 'reward_std': np.std(episode_return), 'reward_max': np.max(episode_return), 'reward_min': np.min(episode_return), - # 'each_reward': episode_return, } episode_info = eval_monitor.get_episode_info() if episode_info is not None: info.update(episode_info) - - print(f'rank {self._rank}, self.task_id: {self.task_id}') + print(f'rank {self._rank}, self.task_id: {self.task_id}') self._logger.info(self._logger.get_tabulate_vars_hor(info)) + + # Log to TensorBoard and WandB. for k, v in info.items(): - if k in ['train_iter', 'ckpt_name', 'each_reward']: - continue - if not np.isscalar(v): + if k in ['train_iter', 'ckpt_name', 'each_reward'] or not np.isscalar(v): continue if self.task_id is None: - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + self._tb_logger.add_scalar(f'{self._instance_name}_iter/{k}', v, train_iter) + self._tb_logger.add_scalar(f'{self._instance_name}_step/{k}', v, envstep) else: - self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, - train_iter) - self._tb_logger.add_scalar('{}_step_task{}/'.format(self._instance_name, self.task_id) + k, v, - envstep) + self._tb_logger.add_scalar(f'{self._instance_name}_iter_task{self.task_id}/{k}', v, train_iter) + self._tb_logger.add_scalar(f'{self._instance_name}_step_task{self.task_id}/{k}', v, envstep) if self.policy_config.use_wandb: - wandb.log({'{}_step/'.format(self._instance_name) + k: v}, step=envstep) + wandb.log({f'{self._instance_name}_step/{k}': v}, step=envstep) - episode_return = np.mean(episode_return) - if episode_return > self._max_episode_return: + # Check for new best performance and save checkpoint. 
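The best-checkpoint and stop-condition pattern implemented in the lines that follow can be summarized by this hedged, standalone helper (the function name and signature are stand-ins, not part of the evaluator's API):

```Python
from typing import Callable, List, Optional, Tuple

import numpy as np


def update_best_and_check_stop(
    episode_returns: List[float],
    max_episode_return: float,
    stop_value: float,
    train_iter: int,
    save_ckpt_fn: Optional[Callable[[str], None]] = None,
) -> Tuple[bool, float]:
    # Save a checkpoint whenever the mean return improves on the best seen so far.
    mean_return = float(np.mean(episode_returns))
    if mean_return > max_episode_return:
        if save_ckpt_fn is not None:
            save_ckpt_fn('ckpt_best.pth.tar')
        max_episode_return = mean_return
    # Signal convergence once the mean return reaches the configured stop value.
    stop_flag = mean_return >= stop_value and train_iter > 0
    return stop_flag, max_episode_return
```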
+ mean_episode_return = np.mean(episode_return) + if mean_episode_return > self._max_episode_return: if save_ckpt_fn: save_ckpt_fn('ckpt_best.pth.tar') - self._max_episode_return = episode_return - stop_flag = episode_return >= self._stop_value and train_iter > 0 + self._max_episode_return = mean_episode_return + + # Check if the stop condition is met. + stop_flag = mean_episode_return >= self._stop_value and train_iter > 0 if stop_flag: self._logger.info( - "[LightZero serial pipeline] " + - "Current episode_return: {} is greater than stop_value: {}".format(episode_return, - self._stop_value) + - ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." + f"[LightZero serial pipeline] Current episode_return: {mean_episode_return} is greater than " + f"stop_value: {self._stop_value}. The agent is considered converged." ) - # ========== TODO: unizero_multitask ddp_v2 ======== + # TODO(username): Finalize DDP synchronization for evaluation results. # if get_world_size() > 1: # objects = [stop_flag, episode_info] # print(f'rank {self._rank}, self.task_id: {self.task_id}') diff --git a/lzero/worker/muzero_segment_collector.py b/lzero/worker/muzero_segment_collector.py index 668a05118..ad7f91bf9 100644 --- a/lzero/worker/muzero_segment_collector.py +++ b/lzero/worker/muzero_segment_collector.py @@ -1,7 +1,7 @@ import logging import time from collections import deque, namedtuple -from typing import Optional, Any, List +from typing import Optional, Any, List, Dict import numpy as np import torch @@ -20,21 +20,20 @@ class MuZeroSegmentCollector(ISerialCollector): """ Overview: - MuZeroSegmentCollector is a data collector for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, and Gumbel MuZero. - It manages the data collection process for training these algorithms using a serial mechanism. - - The main difference from MuZeroCollector is that MuZeroSegmentCollector returns after collecting a specified number of segments, - whereas MuZeroCollector returns after collecting a complete game. This provides more extensibility and flexibility in data collection. + MuZeroSegmentCollector is a data collector for MCTS+RL algorithms, including MuZero, EfficientZero, + Sampled EfficientZero, and Gumbel MuZero. It manages the data collection process for training these + algorithms using a serial mechanism. + The main difference from MuZeroCollector is that MuZeroSegmentCollector returns after collecting a + specified number of segments, whereas MuZeroCollector returns after collecting a complete game. + This provides more extensibility and flexibility in data collection. Interfaces: - ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``envstep``, ``__del__``, ``_compute_priorities``, - ``pad_and_save_last_trajectory``, ``collect``, ``_output_log``, ``close`` - + ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``collect``, ``close``, ``__del__`` Properties: - ``envstep``: Counter for the current number of environment steps. + - envstep (:obj:`int`): The total number of environment steps collected. """ - # To be compatible with ISerialCollector + # To be compatible with ISerialCollector. config = dict() def __init__( @@ -50,18 +49,18 @@ def __init__( ) -> None: """ Overview: - Initialize the MuZeroCollector with the given parameters. + Initializes the MuZeroSegmentCollector. 
Arguments: - - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information. - - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager. - - policy (:obj:`Optional[namedtuple]`): namedtuple of the collection mode policy API. - - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance. - - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. - - instance_name (:obj:`str`): Unique identifier for this collector instance. - - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. + - collect_print_freq (:obj:`int`): The frequency (in training steps) at which to print collection information. + - env (:obj:`Optional[BaseEnvManager]`): An instance of a vectorized environment manager. + - policy (:obj:`Optional[namedtuple]`): A namedtuple containing the collect mode policy API. + - tb_logger (:obj:`Optional[SummaryWriter]`): A TensorBoard logger instance. + - exp_name (:obj:`str`): The name of the experiment, used for logging and saving. + - instance_name (:obj:`str`): A unique identifier for this collector instance. + - policy_config (:obj:`Optional[policy_config]`): The configuration object for the policy. + - task_id (:obj:`int`): The ID of the task, used in multi-task learning settings. """ self.task_id = task_id - self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq @@ -69,31 +68,25 @@ def __init__( self._end_flag = False self._rank = get_rank() - - print(f'rank {self._rank}, self.task_id: {self.task_id}') - - self._world_size = get_world_size() + if self._rank == 0: if tb_logger is not None: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), - name=self._instance_name, - need_tb=False + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False ) self._tb_logger = tb_logger else: self._logger, self._tb_logger = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name ) else: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False ) - # =========== TODO: for unizero_multitask ddp_v2 ======== + # TODO(author): This part is for UniZero multi-task DDP v2 compatibility. self._tb_logger = tb_logger - self.policy_config = policy_config self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy @@ -102,12 +95,11 @@ def __init__( def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset or replace the environment managed by this collector. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the collector with the new passed \ - in environment and launch. + Resets or replaces the environment managed by the collector. + If `_env` is None, it resets the existing environment. Otherwise, it replaces the old + environment with the new one and launches it. Arguments: - - env (:obj:`Optional[BaseEnvManager]`): New environment to manage, if provided. + - _env (:obj:`Optional[BaseEnvManager]`): The new environment to be used. Defaults to None. 
""" if _env is not None: self._env = _env @@ -119,35 +111,28 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Reset or replace the policy used by this collector. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the collector with the new passed in policy. + Resets or replaces the policy used by the collector. + If `_policy` is None, it resets the existing policy. Otherwise, it replaces the old + policy with the new one. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + - _policy (:obj:`Optional[namedtuple]`): The new policy's API in a namedtuple format. Defaults to None. """ - assert hasattr(self, '_env'), "please set env first" + assert hasattr(self, '_env'), "Please set env before resetting policy." if _policy is not None: self._policy = _policy - - self._default_num_segments = _policy.get_attribute('cfg').get('num_segments', None) + self._default_num_segments = self._policy.get_attribute('cfg').get('num_segments', None) self._logger.debug( - 'Set default num_segments mode(num_segments({}), env_num({}))'.format(self._default_num_segments, self._env_num) + f'Set default num_segments mode(num_segments({self._default_num_segments}), env_num({self._env_num}))' ) self._policy.reset(task_id=self.task_id) def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset the collector with the given policy and/or environment. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the collector with the new passed \ - in environment and launch. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the collector with the new passed in policy. + Resets the collector state, including the environment and policy. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) + - _policy (:obj:`Optional[namedtuple]`): The new policy to use. Defaults to None. + - _env (:obj:`Optional[BaseEnvManager]`): The new environment to use. Defaults to None. """ if _env is not None: self.reset_env(_env) @@ -156,13 +141,12 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)} - # Initialize action_mask_dict, to_play_dict, and chance_dict here to ensure they contain values for all env_id + # Initialize dictionaries to store environment-specific states. self.action_mask_dict = {i: None for i in range(self._env_num)} self.to_play_dict = {i: None for i in range(self._env_num)} + self.timestep_dict = {i: None for i in range(self._env_num)} if self.policy_config.use_ture_chance_label_in_chance_encoder: self.chance_dict = {i: None for i in range(self._env_num)} - - self.timestep_dict = {i: None for i in range(self._env_num)} self.dones = np.array([False for _ in range(self._env_num)]) self.last_game_segments = [None for _ in range(self._env_num)] @@ -175,18 +159,16 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self._last_train_iter = 0 self._end_flag = False - # A game_segment_pool implementation based on the deque structure. + # A deque-based pool for storing game segments. 
self.game_segment_pool = deque(maxlen=int(1e6)) self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps def _reset_stat(self, env_id: int) -> None: """ Overview: - Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool \ - and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ - to get more messages. + Resets the statistics for a specific environment. Arguments: - - env_id (:obj:`int`): the id where we need to reset the collector's state + - env_id (:obj:`int`): The ID of the environment to reset. """ self._env_info[env_id] = {'time': 0., 'step': 0} @@ -194,17 +176,16 @@ def _reset_stat(self, env_id: int) -> None: def envstep(self) -> int: """ Overview: - Get the total number of environment steps collected. + Returns the total number of environment steps collected. Returns: - - envstep (:obj:`int`): Total number of environment steps collected. + - envstep (:obj:`int`): The total count of environment steps. """ return self._total_envstep_count def close(self) -> None: """ Overview: - Close the collector. If end_flag is False, close the environment, flush the tb_logger \ - and close the tb_logger. + Closes the collector, including the environment and the TensorBoard logger. """ if self._end_flag: return @@ -217,79 +198,63 @@ def close(self) -> None: def __del__(self) -> None: """ Overview: - Execute the close command and close the collector. __del__ is automatically called to \ - destroy the collector instance when the collector finishes its work + Ensures that the `close` method is called when the collector instance is deleted. """ self.close() - # ============================================================== - # MCTS+RL related core code - # ============================================================== - def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> np.ndarray: + def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> Optional[np.ndarray]: """ Overview: - Compute the priorities for transitions based on prediction and search value discrepancies. + Computes priorities for transitions based on the discrepancy between predicted and search values. Arguments: - - i (:obj:`int`): Index of the values in the list to compute the priority for. - - pred_values_lst (:obj:`List[float]`): List of predicted values. - - search_values_lst (:obj:`List[float]`): List of search values obtained from MCTS. + - i (:obj:`int`): The index of the values list to process. + - pred_values_lst (:obj:`List[float]`): A list containing lists of predicted values. + - search_values_lst (:obj:`List[float]`): A list containing lists of search values from MCTS. Returns: - - priorities (:obj:`np.ndarray`): Array of computed priorities. + - priorities (:obj:`Optional[np.ndarray]`): An array of computed priorities, or None if priority is disabled. """ if self.policy_config.use_priority: - # Calculate priorities. The priorities are the L1 losses between the predicted - # values and the search values. We use 'none' as the reduction parameter, which - # means the loss is calculated for each element individually, instead of being summed or averaged. - # A small constant (1e-6) is added to the results to avoid zero priorities. This - # is done because zero priorities could potentially cause issues in some scenarios. + # Calculate priorities as the L1 loss between predicted and search values. 
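The priority computation that the following lines implement can be reproduced in isolation. Below is a small sketch assuming plain Python lists of values outside the collector; the helper function is illustrative and not the collector's `_compute_priorities` method itself.

```Python
import numpy as np
import torch
from torch.nn import L1Loss

def compute_priorities(pred_values, search_values, eps: float = 1e-6) -> np.ndarray:
    """Per-transition priority = |predicted value - MCTS search value| + eps."""
    pred = torch.as_tensor(pred_values, dtype=torch.float32).view(-1)
    search = torch.as_tensor(search_values, dtype=torch.float32).view(-1)
    # reduction='none' keeps one loss value per transition instead of averaging them.
    priorities = L1Loss(reduction='none')(pred, search).detach().cpu().numpy()
    # A small epsilon keeps every transition at a non-zero sampling probability.
    return priorities + eps

pred_values_lst = [0.1, 0.5, -0.2]     # values predicted by the model
search_values_lst = [0.3, 0.5, 0.4]    # values returned by the MCTS search
print(compute_priorities(pred_values_lst, search_values_lst))  # ~[0.2, 1e-6, 0.6]
```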
+ # The reduction is 'none' to get per-element losses. + # A small epsilon (1e-6) is added to prevent zero priorities. pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1) - search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device - ).float().view(-1) - priorities = L1Loss(reduction='none' - )(pred_values, - search_values).detach().cpu().numpy() + 1e-6 + search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device).float().view(-1) + priorities = L1Loss(reduction='none')(pred_values, search_values).detach().cpu().numpy() + 1e-6 else: - # priorities is None -> use the max priority for all newly collected data + # If not using priority, all new data will use the maximum priority in the replay buffer. priorities = None return priorities - def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegment], - last_game_priorities: List[np.ndarray], - game_segments: List[GameSegment], done: np.ndarray) -> None: + def pad_and_save_last_trajectory( + self, i: int, last_game_segments: List[GameSegment], last_game_priorities: List[np.ndarray], + game_segments: List[GameSegment], done: np.ndarray + ) -> None: """ Overview: - Save the game segment to the pool if the current game is finished, padding it if necessary. + Pads the last game segment with data from the current segment and saves it to the pool. + This is done when a game ends or a segment becomes full. Arguments: - - i (:obj:`int`): Index of the current game segment. - - last_game_segments (:obj:`List[GameSegment]`): List of the last game segments to be padded and saved. - - last_game_priorities (:obj:`List[np.ndarray]`): List of priorities of the last game segments. - - game_segments (:obj:`List[GameSegment]`): List of the current game segments. - - done (:obj:`np.ndarray`): Array indicating whether each game is done. - Note: - (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True + - i (:obj:`int`): The index of the current game segment (and environment). + - last_game_segments (:obj:`List[GameSegment]`): The list of previous game segments to be padded. + - last_game_priorities (:obj:`List[np.ndarray]`): The list of priorities for the previous game segments. + - game_segments (:obj:`List[GameSegment]`): The list of current game segments, used for padding data. + - done (:obj:`np.ndarray`): An array indicating whether each game has terminated. """ - # pad over last segment trajectory + # Pad the trajectory of the last segment. beg_index = self.policy_config.model.frame_stack_num end_index = beg_index + self.policy_config.num_unroll_steps + self.policy_config.td_steps - # the start obs is init zero obs, so we take the - # [ : +] obs as the pad obs - # e.g. the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs + # The initial observations are zero-padded, so we take observations from + # [ : + ] for padding. pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] - # NOTE: for unizero + # NOTE: Specific padding logic for UniZero. 
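The padding performed by `pad_and_save_last_trajectory` borrows the head of the *current* segment to extend the tail of the *previous* one. A toy sketch of the index arithmetic follows, with plain lists standing in for the `GameSegment` fields and the stack/unroll/td sizes chosen only for illustration.

```Python
# Toy index arithmetic for padding the previous segment with the head of the current one.
frame_stack_num, num_unroll_steps, td_steps = 4, 5, 3

current_obs = list(range(20 + frame_stack_num))   # stands in for obs_segment of the current segment
current_actions = list(range(20))                 # stands in for action_segment
current_rewards = list(range(20))                 # stands in for reward_segment

# Observations: skip the zero-padded stacked frames at the front, then take
# num_unroll_steps + td_steps frames to extend the previous segment's tail.
beg = frame_stack_num
end = beg + num_unroll_steps + td_steps
pad_obs = current_obs[beg:end]                    # 8 frames: indices 4..11

# Actions / root values: the first num_unroll_steps + td_steps entries.
pad_actions = current_actions[:num_unroll_steps + td_steps]       # 8 entries

# Rewards: one fewer entry, matching unroll_plus_td_steps - 1 in the collector.
pad_rewards = current_rewards[:num_unroll_steps + td_steps - 1]   # 7 entries

print(len(pad_obs), len(pad_actions), len(pad_rewards))  # 8 8 7
```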
pad_action_lst = game_segments[i].action_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] - - # NOTE: for unizero pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] - # EfficientZero original repo bug: - # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] - beg_index = 0 end_index = beg_index + self.unroll_plus_td_steps - 1 - pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] if self.policy_config.use_ture_chance_label_in_chance_encoder: @@ -297,101 +262,87 @@ def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegm beg_index = 0 end_index = beg_index + self.unroll_plus_td_steps - pad_root_values_lst = game_segments[i].root_value_segment[beg_index:end_index] if self.policy_config.gumbel_algo: pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index:end_index] - # pad over and save + # Pad and finalize the last game segment. if self.policy_config.gumbel_algo: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, - next_segment_improved_policy=pad_improved_policy_prob) + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, + next_segment_improved_policy=pad_improved_policy_prob + ) else: if self.policy_config.use_ture_chance_label_in_chance_encoder: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, - next_chances=chance_lst) + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, + next_chances=chance_lst + ) else: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst) - """ - Note: - game_segment element shape: - obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 - rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 - action: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 - root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 - child_visits: game_segment_length + num_unroll_steps -> 20 +5 - to_play: game_segment_length -> 20 - action_mask: game_segment_length -> 20 - """ + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst + ) last_game_segments[i].game_segment_to_array() - # put the game segment into the pool + # Add the completed game segment to the pool. self.game_segment_pool.append((last_game_segments[i], last_game_priorities[i], done[i])) - # reset last game_segments and last game_priorities for the next collection + # Reset placeholders for the next collection cycle. last_game_segments[i] = None last_game_priorities[i] = None - return None - - def collect(self, - num_segments: Optional[int] = None, - train_iter: int = 0, - policy_kwargs: Optional[dict] = None, - collect_with_pure_policy: bool = False) -> List[Any]: + def collect( + self, + num_segments: Optional[int] = None, + train_iter: int = 0, + policy_kwargs: Optional[dict] = None, + collect_with_pure_policy: bool = False + ) -> List[Any]: """ Overview: - Collect `num_segments` segments of data with policy_kwargs, trained for `train_iter` iterations. + Collects a specified number of game segments using the policy. 
Arguments: - - num_segments (:obj:`Optional[int]`): Number of segments to collect. - - train_iter (:obj:`int`): Number of training iterations completed so far. - - policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy. - - collect_with_pure_policy (:obj:`bool`): Whether to collect data using pure policy without MCTS. + - num_segments (:obj:`Optional[int]`): The number of segments to collect. If None, uses the default. + - train_iter (:obj:`int`): The current training iteration, used for logging. + - policy_kwargs (:obj:`Optional[dict]`): Additional arguments for the policy forward pass. + - collect_with_pure_policy (:obj:`bool`): If True, collects data using a pure policy without MCTS. Returns: - - return_data (:obj:`List[Any]`): Collected data in the form of a list. + - return_data (:obj:`List[Any]`): A list containing the collected game segments and their metadata. """ if num_segments is None: if self._default_num_segments is None: - raise RuntimeError("Please specify collect num_segments") + raise RuntimeError("Please specify num_segments for collection.") else: num_segments = self._default_num_segments - assert num_segments == self._env_num, "Please make sure num_segments == env_num{}/{}".format(num_segments, self._env_num) + assert num_segments == self._env_num, f"num_segments({num_segments}) must be equal to env_num({self._env_num})." if policy_kwargs is None: policy_kwargs = {} - temperature = policy_kwargs['temperature'] - epsilon = policy_kwargs['epsilon'] + temperature = policy_kwargs.get('temperature', 1.0) + epsilon = policy_kwargs.get('epsilon', 0.0) + # Initializations collected_episode = 0 collected_step = 0 env_nums = self._env_num - - # initializations init_obs = self._env.ready_obs + # Wait for all environments to be ready, especially in a subprocess setup. retry_waiting_time = 0.05 - while len(init_obs.keys()) != self._env_num: - # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + while len(init_obs.keys()) != env_nums: + self._logger.info(f'Waiting for all environments to reset. Ready envs: {list(init_obs.keys())}') time.sleep(retry_waiting_time) - self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' 
+ '=' * 10) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) - ) init_obs = self._env.ready_obs for env_id in range(env_nums): - if env_id in init_obs.keys(): + if env_id in init_obs: self.action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) self.to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) - if 'timestep' not in init_obs[env_id]: - print(f"Warning: 'timestep' key is missing in init_obs[{env_id}], assigning value -1") self.timestep_dict[env_id] = to_ndarray(init_obs[env_id].get('timestep', -1)) - + if 'timestep' not in init_obs[env_id]: + self._logger.warning(f"'timestep' key missing in init_obs[{env_id}], assigning default -1.") if self.policy_config.use_ture_chance_label_in_chance_encoder: self.chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) @@ -403,155 +354,91 @@ def collect(self, task_id=self.task_id ) for _ in range(env_nums) ] - # stacked observation windows in reset stage for init game_segments - observation_window_stack = [[] for _ in range(env_nums)] - for env_id in range(env_nums): - observation_window_stack[env_id] = deque( - [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) + # Stacked observation windows for initializing game segments. + observation_window_stack = [deque(maxlen=self.policy_config.model.frame_stack_num) for _ in range(env_nums)] + for env_id in range(env_nums): + initial_frames = [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)] + observation_window_stack[env_id].extend(initial_frames) game_segments[env_id].reset(observation_window_stack[env_id]) - # for priorities in self-play + # Lists for storing values for priority calculation. search_values_lst = [[] for _ in range(env_nums)] pred_values_lst = [[] for _ in range(env_nums)] if self.policy_config.gumbel_algo: improved_policy_lst = [[] for _ in range(env_nums)] - # some logs + # Logging variables. eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) if self.policy_config.gumbel_algo: completed_value_lst = np.zeros(env_nums) - self_play_moves = 0. - self_play_episodes = 0. - self_play_moves_max = 0 - self_play_visit_entropy = [] - total_transitions = 0 if collect_with_pure_policy: - temp_visit_list = [0.0 for i in range(self._env.action_space.n)] + temp_visit_list = [0.0 for _ in range(self._env.action_space.n)] while True: with self._timer: - # Get current ready env obs. + # Get observations from ready environments. obs = self._env.ready_obs ready_env_id = set(obs.keys()) if len(ready_env_id) < self._env_num: - logging.info(f'muzero_segment_collector: len(ready_env_id) < self._env_num, ready_env_id: {ready_env_id}, self._env_num: {self._env_num}') - - # TODO: For UniZero, during the init-infer process, it is necessary to retrieve the current kv_cache from the kv_cache_dict corresponding to each env_id. - # In theory, this requires waiting for all environments to be ready. However, in practice, - # waiting for all environments to be ready can have a significant negative impact on UniZero's performance, - # whereas the impact on MuZero is relatively small. + self._logger.debug(f'Only {len(ready_env_id)}/{self._env_num} envs are ready.') + + # TODO(author): For UniZero, waiting for all environments to be ready can negatively impact performance. 
+ # This wait loop is currently commented out, but its impact should be considered. # while len(obs.keys()) != self._env_num: - # # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # # len(self._env.ready_obs), especially in tictactoe env. - # self._logger.info('The current init_obs.keys() is {}'.format(obs.keys())) - # self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) # time.sleep(retry_waiting_time) - # self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) - # self._logger.info( - # 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) - # ) # obs = self._env.ready_obs # ready_env_id = set(obs.keys()) - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} - - - stack_obs = list(stack_obs.values()) + # Prepare stacked observations for the policy network. + stack_obs_dict = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs_list = [stack_obs_dict[env_id] for env_id in sorted(list(ready_env_id))] self.action_mask_dict_tmp = {env_id: self.action_mask_dict[env_id] for env_id in ready_env_id} self.to_play_dict_tmp = {env_id: self.to_play_dict[env_id] for env_id in ready_env_id} self.timestep_dict_tmp = {env_id: self.timestep_dict[env_id] for env_id in ready_env_id} - - action_mask = [self.action_mask_dict_tmp[env_id] for env_id in ready_env_id] - to_play = [self.to_play_dict_tmp[env_id] for env_id in ready_env_id] - timestep = [self.timestep_dict_tmp[env_id] for env_id in ready_env_id] + + action_mask = [self.action_mask_dict_tmp[env_id] for env_id in sorted(list(ready_env_id))] + to_play = [self.to_play_dict_tmp[env_id] for env_id in sorted(list(ready_env_id))] + timestep = [self.timestep_dict_tmp[env_id] for env_id in sorted(list(ready_env_id))] if self.policy_config.use_ture_chance_label_in_chance_encoder: self.chance_dict_tmp = {env_id: self.chance_dict[env_id] for env_id in ready_env_id} - stack_obs = to_ndarray(stack_obs) - # return stack_obs shape: [B, S*C, W, H] e.g. [8, 4*1, 96, 96] - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device) + stack_obs_array = to_ndarray(stack_obs_list) + stack_obs_tensor = prepare_observation(stack_obs_array, self.policy_config.model.model_type) + stack_obs_tensor = torch.from_numpy(stack_obs_tensor).to(self.policy_config.device) # ============================================================== - # Key policy forward step + # Perform a forward pass with the policy. # ============================================================== - # print(f'ready_env_id:{ready_env_id}') - if self.task_id is None: - # single task setting - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) - else: - # multi task setting - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, task_id=self.task_id) - + policy_args = (stack_obs_tensor, action_mask, temperature, to_play, epsilon) + policy_kwargs_forward = {'ready_env_id': sorted(list(ready_env_id)), 'timestep': timestep} + if self.task_id is not None: + policy_kwargs_forward['task_id'] = self.task_id + + policy_output = self._policy.forward(*policy_args, **policy_kwargs_forward) - # Extract relevant policy outputs + # Extract policy outputs. 
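The forward call assembled a few lines above differs between the single-task and multi-task settings only by the optional `task_id` keyword. Below is a sketch of that pattern with a dummy forward function standing in for the policy API; the real collect-mode `forward` signature is not reproduced here.

```Python
from typing import Any, Dict, Optional

def dummy_forward(obs, action_mask, temperature, to_play, epsilon,
                  ready_env_id=None, timestep=None, task_id=None) -> Dict[str, Any]:
    """Hypothetical stand-in for the policy's collect-mode forward."""
    return {'task_id': task_id, 'n_envs': len(ready_env_id)}

def call_policy(forward, obs, action_mask, temperature, to_play, epsilon,
                ready_env_id, timestep, task_id: Optional[int] = None):
    # Positional arguments are shared by both settings; task_id is only injected
    # as a keyword when the collector runs in multi-task mode.
    kwargs = {'ready_env_id': sorted(ready_env_id), 'timestep': timestep}
    if task_id is not None:
        kwargs['task_id'] = task_id
    return forward(obs, action_mask, temperature, to_play, epsilon, **kwargs)

print(call_policy(dummy_forward, None, None, 0.25, None, 0.0, {1, 0}, [0, 0]))             # single-task
print(call_policy(dummy_forward, None, None, 0.25, None, 0.0, {1, 0}, [0, 0], task_id=3))  # multi-task
```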
actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} - timestep_dict_with_env_id = { - k: v['timestep'] if 'timestep' in v else -1 for k, v in policy_output.items() - } - - if self.policy_config.sampled_algo: - root_sampled_actions_dict_with_env_id = { - k: v['root_sampled_actions'] for k, v in policy_output.items() - } - - if not collect_with_pure_policy: - distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in - policy_output.items()} - visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in - policy_output.items()} - - if self.policy_config.gumbel_algo: - improved_policy_dict_with_env_id = {k: v['improved_policy_probs'] for k, v in - policy_output.items()} - completed_value_with_env_id = {k: v['roots_completed_value'] for k, v in policy_output.items()} - - # Initialize dictionaries to store results - actions = {} - value_dict = {} - pred_value_dict = {} - timestep_dict = {} if not collect_with_pure_policy: - distributions_dict = {} - visit_entropy_dict = {} - + distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} + visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: - root_sampled_actions_dict = {} - + root_sampled_actions_dict_with_env_id = {k: v['root_sampled_actions'] for k, v in policy_output.items()} if self.policy_config.gumbel_algo: - improved_policy_dict = {} - completed_value_dict = {} - - # Populate the result dictionaries - for env_id in ready_env_id: - actions[env_id] = actions_with_env_id.pop(env_id) - value_dict[env_id] = value_dict_with_env_id.pop(env_id) - pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id) - timestep_dict[env_id] = timestep_dict_with_env_id.pop(env_id) - - if not collect_with_pure_policy: - distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id) - - if self.policy_config.sampled_algo: - root_sampled_actions_dict[env_id] = root_sampled_actions_dict_with_env_id.pop(env_id) + improved_policy_dict_with_env_id = {k: v['improved_policy_probs'] for k, v in policy_output.items()} + completed_value_with_env_id = {k: v['roots_completed_value'] for k, v in policy_output.items()} - visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id) - - if self.policy_config.gumbel_algo: - improved_policy_dict[env_id] = improved_policy_dict_with_env_id.pop(env_id) - completed_value_dict[env_id] = completed_value_with_env_id.pop(env_id) + # Populate the result dictionaries, mapping outputs to original env_ids. + actions: Dict[int, Any] = {env_id: actions_with_env_id.pop(env_id) for env_id in ready_env_id} # ============================================================== - # Interact with the environment + # Step the environments with the chosen actions. # ============================================================== timesteps = self._env.step(actions) @@ -559,104 +446,93 @@ def collect(self, for env_id, episode_timestep in timesteps.items(): with self._timer: + # Handle abnormal timesteps by resetting the environment and policy state. if episode_timestep.info.get('abnormal', False): - # If there is an abnormal episode_timestep, reset all the related variables(including this env). 
- # suppose there is no reset param, reset this env self._env.reset({env_id: None}) self._policy.reset([env_id]) self._reset_stat(env_id) - self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, episode_timestep.info)) + self._logger.info(f'Env {env_id} had an abnormal step, info: {episode_timestep.info}') continue + obs, reward, done, info = episode_timestep.obs, episode_timestep.reward, episode_timestep.done, episode_timestep.info + # Store search statistics in the game segment. if collect_with_pure_policy: game_segments[env_id].store_search_stats(temp_visit_list, 0) else: if self.policy_config.sampled_algo: game_segments[env_id].store_search_stats( - distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] + distributions_dict_with_env_id[env_id], value_dict_with_env_id[env_id], root_sampled_actions_dict_with_env_id[env_id] ) elif self.policy_config.gumbel_algo: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], - improved_policy=improved_policy_dict[env_id]) + game_segments[env_id].store_search_stats( + distributions_dict_with_env_id[env_id], value_dict_with_env_id[env_id], + improved_policy=improved_policy_dict_with_env_id[env_id] + ) else: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) + game_segments[env_id].store_search_stats(distributions_dict_with_env_id[env_id], value_dict_with_env_id[env_id]) - # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} - # in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment`` + # Append the new transition to the game segment. + append_kwargs = {'timestep': to_ndarray(obs.get('timestep', -1))} if self.policy_config.use_ture_chance_label_in_chance_encoder: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, self.action_mask_dict_tmp[env_id], - self.to_play_dict_tmp[env_id], timestep=to_ndarray(obs['timestep']), chance=self.chance_dict_tmp[env_id] - ) - else: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, self.action_mask_dict_tmp[env_id], - self.to_play_dict_tmp[env_id], timestep=to_ndarray(obs['timestep']) - ) - - # NOTE: the position of code snippet is very important. - # the obs['action_mask'] and obs['to_play'] are corresponding to the next action - self.action_mask_dict_tmp[env_id] = to_ndarray(obs['action_mask']) - self.to_play_dict_tmp[env_id] = to_ndarray(obs['to_play']) - # self.timestep_dict_tmp[env_id] = to_ndarray(obs['timestep']) - self.timestep_dict_tmp[env_id] = to_ndarray(obs.get('timestep', -1)) - - + append_kwargs['chance'] = self.chance_dict_tmp[env_id] + + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, + self.action_mask_dict_tmp[env_id], self.to_play_dict_tmp[env_id], **append_kwargs + ) + + # NOTE: This position is crucial. The action_mask and to_play from the new observation correspond to the *next* state. 
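The note above deserves a tiny illustration: the mask delivered together with `o_{t+1}` constrains the *next* action, so it is cached only after the current transition has been recorded. A toy loop follows, with a hypothetical environment step in place of the real env manager.

```Python
# Toy loop illustrating why the mask returned with o_{t+1} is cached for the *next* action.
import random

def toy_step(action):
    """Hypothetical env step: returns the next observation together with the mask
    that constrains the *next* action."""
    return {'observation': random.random(), 'action_mask': [1, 1, 0], 'to_play': 1}

cached_mask = [1, 1, 1]          # mask delivered with the initial observation o_0
for t in range(3):
    legal = [a for a, ok in enumerate(cached_mask) if ok]
    action = random.choice(legal)             # a_t is chosen under the mask of o_t
    obs = toy_step(action)
    # The freshly returned mask describes o_{t+1}; store it for the next iteration,
    # *after* the transition (a_t, o_{t+1}, r_t, mask_t, to_play_t) has been recorded.
    cached_mask = obs['action_mask']
```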
+ self.action_mask_dict[env_id] = to_ndarray(obs['action_mask']) + self.to_play_dict[env_id] = to_ndarray(obs['to_play']) + self.timestep_dict[env_id] = to_ndarray(obs.get('timestep', -1)) if self.policy_config.use_ture_chance_label_in_chance_encoder: - self.chance_dict_tmp[env_id] = to_ndarray(obs['chance']) + self.chance_dict[env_id] = to_ndarray(obs['chance']) - if self.policy_config.ignore_done: - self.dones[env_id] = False - else: - self.dones[env_id] = done + self.dones[env_id] = False if self.policy_config.ignore_done else done if not collect_with_pure_policy: - visit_entropies_lst[env_id] += visit_entropy_dict[env_id] + visit_entropies_lst[env_id] += visit_entropy_dict_with_env_id[env_id] if self.policy_config.gumbel_algo: - completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) + completed_value_lst[env_id] += np.mean(np.array(completed_value_with_env_id[env_id])) eps_steps_lst[env_id] += 1 + + # NOTE: Specific reset logic for UniZero. if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - # ============ only for UniZero now ============ self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) - total_transitions += 1 - if self.policy_config.use_priority: - pred_values_lst[env_id].append(pred_value_dict[env_id]) - search_values_lst[env_id].append(value_dict[env_id]) + pred_values_lst[env_id].append(pred_value_dict_with_env_id[env_id]) + search_values_lst[env_id].append(value_dict_with_env_id[env_id]) if self.policy_config.gumbel_algo and not collect_with_pure_policy: - improved_policy_lst[env_id].append(improved_policy_dict[env_id]) + improved_policy_lst[env_id].append(improved_policy_dict_with_env_id[env_id]) - # append the newest obs + # Append the newest observation to the observation window. observation_window_stack[env_id].append(to_ndarray(obs['observation'])) # ============================================================== - # we will save a game segment if it is the end of the game or the next game segment is finished. + # Save a game segment if it is full or the game has ended. # ============================================================== - - # if game segment is full, we will save the last game segment if game_segments[env_id].is_full(): - # pad over last segment trajectory + # If there's a previous segment, pad and save it. if self.last_game_segments[env_id] is not None: - # TODO(pu): return the one game segment + # TODO(pu): This logic pads and saves one game segment at a time. self.pad_and_save_last_trajectory( env_id, self.last_game_segments, self.last_game_priorities, game_segments, self.dones ) - # calculate priority + # Calculate priorities for the collected transitions. priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] + pred_values_lst[env_id], search_values_lst[env_id] = [], [] if self.policy_config.gumbel_algo and not collect_with_pure_policy: improved_policy_lst[env_id] = [] - # the current game_segments become last_game_segment + # The current segment now becomes the 'last' segment for the next padding operation. self.last_game_segments[env_id] = game_segments[env_id] self.last_game_priorities[env_id] = priorities - # create new GameSegment + # Create a new game segment to continue collection. 
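The segment-full branch above follows a rotation pattern: archive the previous segment once it has been padded, keep the freshly filled one around as the next padding source, and start a new segment. A compact sketch with lists in place of `GameSegment` objects:

```Python
# Sketch of the rotation performed when a segment fills up (GameSegment replaced by a list here).
def rotate_full_segment(current, last, pool, make_new_segment):
    """current: the just-filled segment; last: the previous segment awaiting padding."""
    if last is not None:
        # In the collector this is pad_and_save_last_trajectory(); here we just archive it.
        pool.append(last)
    # The filled segment now waits for the *next* segment to supply its padding.
    new_last = current
    new_current = make_new_segment()
    return new_current, new_last

pool = []
current, last = ['s1-data'], None
current, last = rotate_full_segment(current, last, pool, make_new_segment=list)  # nothing to archive yet
current.append('s2-data')
current, last = rotate_full_segment(current, last, pool, make_new_segment=list)  # 's1-data' archived
print(pool, last)  # [['s1-data']] ['s2-data']
```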
game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, @@ -670,89 +546,75 @@ def collect(self, self._env_info[env_id]['time'] += self._timer.value + interaction_duration if episode_timestep.done: - logging.info(f'========env {env_id} done!========') + self._logger.info(f'======== Environment {env_id} episode finished! ========') self._total_episode_count += 1 - reward = episode_timestep.info['eval_episode_return'] info = { - 'reward': reward, + 'reward': episode_timestep.info['eval_episode_return'], 'time': self._env_info[env_id]['time'], 'step': self._env_info[env_id]['step'], } if not collect_with_pure_policy: - info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] + info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 if self.policy_config.gumbel_algo: - info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] - + info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 collected_episode += 1 self._episode_info.append(info) # ============================================================== - # if it is the end of the game, we will save the game segment + # At the end of a game, save all remaining game segments. # ============================================================== - - # NOTE: put the penultimate game segment in one episode into the trajectory_pool - # pad over 2th last game_segment using the last game_segment + # NOTE: Store the second-to-last game segment of the episode. if self.last_game_segments[env_id] is not None: self.pad_and_save_last_trajectory( env_id, self.last_game_segments, self.last_game_priorities, game_segments, self.dones ) - # store current segment trajectory + # Calculate priorities for the final segment. priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - # NOTE: put the last game segment in one episode into the trajectory_pool + # NOTE: Store the final game segment of the episode. game_segments[env_id].game_segment_to_array() - - # assert len(game_segments[env_id]) == len(priorities) - # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null - if len(game_segments[env_id].reward_segment) != 0: + if len(game_segments[env_id].reward_segment) > 0: self.game_segment_pool.append((game_segments[env_id], priorities, self.dones[env_id])) - # log - self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) - if not collect_with_pure_policy: - self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) - self_play_moves += eps_steps_lst[env_id] - self_play_episodes += 1 + # Reset lists and stats for the new episode. + pred_values_lst[env_id], search_values_lst[env_id] = [], [] + eps_steps_lst[env_id], visit_entropies_lst[env_id] = 0, 0 - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] - eps_steps_lst[env_id] = 0 - visit_entropies_lst[env_id] = 0 - - # Env reset is done by env_manager automatically - # NOTE: ============ reset the policy for the env_id. Default reset_init_data=True. ================ + # Environment reset is handled by the env_manager automatically. + # NOTE: Reset the policy state for the completed environment. 
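The episode-termination branch above condenses each finished episode into a small info dict, guarding the per-step averages against zero-length episodes. A minimal sketch of that bookkeeping; the helper below is illustrative only.

```Python
import numpy as np

def episode_summary(episode_return, elapsed_time, steps, visit_entropies=None):
    """Build a per-episode info dict like the one appended to the episode log."""
    info = {'reward': episode_return, 'time': elapsed_time, 'step': steps}
    if visit_entropies is not None:
        # Guard against a zero-length episode so the average never divides by zero.
        info['visit_entropy'] = float(np.sum(visit_entropies) / steps) if steps > 0 else 0.0
    return info

print(episode_summary(21.0, 3.2, 120, visit_entropies=[0.9, 0.8, 0.7]))
```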
self._policy.reset([env_id], task_id=self.task_id) self._reset_stat(env_id) - ready_env_id.remove(env_id) - # ===== NOTE: if one episode done and not return, we should init its game_segments[env_id] ======= - # create new GameSegment - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config, - task_id=self.task_id - ) + # NOTE: If an episode finishes but collection continues, re-initialize its game segment. + game_segments[env_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config, + task_id=self.task_id + ) game_segments[env_id].reset(observation_window_stack[env_id]) - - # NOTE: must after the for loop to make sure all env_id's data are collected + # Check if the required number of segments has been collected. if len(self.game_segment_pool) >= self._default_num_segments: - logging.info(f'env {env_id} collected {len(self.game_segment_pool)} segments now!') - - # [data, meta_data] - return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [ - { - 'priorities': self.game_segment_pool[i][1], - 'done': self.game_segment_pool[i][2], - 'unroll_plus_td_steps': self.unroll_plus_td_steps - } for i in range(len(self.game_segment_pool)) + self._logger.info(f'Collected {len(self.game_segment_pool)} segments, reaching the target of {self._default_num_segments}.') + + # Format data for returning: [game_segments, metadata_list] + return_data = [ + [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], + [ + { + 'priorities': self.game_segment_pool[i][1], + 'done': self.game_segment_pool[i][2], + 'unroll_plus_td_steps': self.unroll_plus_td_steps + } for i in range(len(self.game_segment_pool)) + ] ] self.game_segment_pool.clear() break + collected_duration = sum([d['time'] for d in self._episode_info]) # TODO: for atari multitask new ddp pipeline # reduce data when enables DDP @@ -772,11 +634,11 @@ def collect(self, def _output_log(self, train_iter: int) -> None: """ Overview: - Log the collector's data and output the log information. + Logs collection statistics to the console and TensorBoard. Arguments: - - train_iter (:obj:`int`): Current training iteration number for logging context. + - train_iter (:obj:`int`): The current training iteration for logging context. """ - # TODO: for atari multitask new ddp pipeline + # TODO(author): For multi-task DDP, logging might be restricted to rank 0. 
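The `return_data` built above pairs the raw game segments with one metadata dict per segment. A sketch of how that structure can be unpacked downstream, with strings standing in for the segment objects and the metadata keys mirroring those constructed in the collector:

```Python
# Sketch of how the value returned by collect() can be consumed downstream.
game_segment_pool = [
    ('segment_0', [0.3, 0.1], False),
    ('segment_1', None, True),
]
unroll_plus_td_steps = 5 + 3

return_data = [
    [seg for seg, _, _ in game_segment_pool],
    [
        {'priorities': prio, 'done': done, 'unroll_plus_td_steps': unroll_plus_td_steps}
        for _, prio, done in game_segment_pool
    ],
]

segments, meta = return_data
for seg, m in zip(segments, meta):
    # 'priorities' is None when use_priority is disabled, which downstream buffers
    # typically interpret as "use the current maximum priority".
    print(seg, m['done'], m['priorities'], m['unroll_plus_td_steps'])
```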
# if self._rank != 0: # return if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: @@ -785,19 +647,18 @@ def _output_log(self, train_iter: int) -> None: envstep_count = sum([d['step'] for d in self._episode_info]) duration = sum([d['time'] for d in self._episode_info]) episode_reward = [d['reward'] for d in self._episode_info] + if not self.collect_with_pure_policy: - visit_entropy = [d['visit_entropy'] for d in self._episode_info] + visit_entropy = [d.get('visit_entropy', 0.0) for d in self._episode_info] else: visit_entropy = [0.0] - if self.policy_config.gumbel_algo: - completed_value = [d['completed_value'] for d in self._episode_info] - self._total_duration += duration + info = { 'episode_count': episode_count, 'envstep_count': envstep_count, 'avg_envstep_per_episode': envstep_count / episode_count, - 'avg_envstep_per_sec': envstep_count / duration, - 'avg_episode_per_sec': episode_count / duration, + 'avg_envstep_per_sec': envstep_count / duration if duration > 0 else 0, + 'avg_episode_per_sec': episode_count / duration if duration > 0 else 0, 'collect_time': duration, 'reward_mean': np.mean(episode_reward), 'reward_std': np.std(episode_reward), @@ -806,25 +667,25 @@ def _output_log(self, train_iter: int) -> None: 'total_envstep_count': self._total_envstep_count, 'total_episode_count': self._total_episode_count, 'total_duration': self._total_duration, - 'visit_entropy': np.mean(visit_entropy), + 'visit_entropy_mean': np.mean(visit_entropy), } if self.policy_config.gumbel_algo: - info['completed_value'] = np.mean(completed_value) + completed_value = [d.get('completed_value', 0.0) for d in self._episode_info] + info['completed_value_mean'] = np.mean(completed_value) + self._episode_info.clear() - print(f'collector output_log: rank {self._rank}, self.task_id: {self.task_id}') - self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) + + self._logger.info(f"Collector log (rank {self._rank}, task_id {self.task_id}):\n" + '\n'.join([f'{k}: {v}' for k, v in info.items()])) for k, v in info.items(): if k in ['each_reward']: continue if self.task_id is None: - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - else: - self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, - train_iter) - if k in ['total_envstep_count']: - continue - if self.task_id is None: - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + # Log for single-task setting + self._tb_logger.add_scalar(f'{self._instance_name}_iter/{k}', v, train_iter) + if k not in ['total_envstep_count', 'total_episode_count', 'total_duration']: + self._tb_logger.add_scalar(f'{self._instance_name}_step/{k}', v, self._total_envstep_count) else: - self._tb_logger.add_scalar('{}_step_task{}/'.format(self._instance_name, self.task_id) + k, v, - self._total_envstep_count) + # Log for multi-task setting + self._tb_logger.add_scalar(f'{self._instance_name}_iter_task{self.task_id}/{k}', v, train_iter) + if k not in ['total_envstep_count', 'total_episode_count', 'total_duration']: + self._tb_logger.add_scalar(f'{self._instance_name}_step_task{self.task_id}/{k}', v, self._total_envstep_count) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 18645a2e0..2b56f5a3f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ line_profiler xxhash simple_parsing einops -openai \ No newline at 
end of file +openai +nltk \ No newline at end of file diff --git a/zoo/README.md b/zoo/README.md index 298171748..a1dd94e14 100644 --- a/zoo/README.md +++ b/zoo/README.md @@ -1,26 +1,36 @@ - ## Environment Versatility -- The following is a brief introduction to the environment supported by our zoo: +- The following is a brief introduction to the environments supported by our zoo:
Expand for full list -| No | Environment | Label | Visualization | Doc Links | -|:--:|:---------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| -| 1 | [board_games/tictactoe](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/tictactoe) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/tictactoe/tictactoe.gif) | [env tutorial](https://en.wikipedia.org/wiki/Tic-tac-toe) | -| 2 | [board_games/gomoku](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/gomoku) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/gomoku/gomoku.gif) | [env tutorial](https://en.wikipedia.org/wiki/Gomoku) | -| 3 | [board_games/connect4](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/connect4) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/connect4/connect4.gif) | [env tutorial](https://en.wikipedia.org/wiki/Connect4) | -| 4 | [game_2048](https://github.com/opendilab/LightZero/tree/main/zoo/game_2048) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](https://github.com/opendilab/LightZero/tree/main/zoo/game_2048/game_2048.gif) | [env tutorial](https://en.wikipedia.org/wiki/2048) | -| 5 | [chess](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/chess) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/chess/chess.gif) | [env tutorial](https://en.wikipedia.org/wiki/Chess) | -| 6 | [go](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/go) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/go/go.gif) | [env tutorial](https://en.wikipedia.org/wiki/Go) | -| 7 | [classic_control/cartpole](https://github.com/opendilab/LightZero/tree/main/zoo/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/cartpole/cartpole.gif) | [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/cartpole.html)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/cartpole_zh.html) | -| 8 | [classic_control/pendulum](https://github.com/opendilab/LightZero/tree/main/zoo/classic_control) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](https://github.com/opendilab/DI-engine/blob/main//dizoo/classic_control/pendulum/pendulum.gif) | [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/pendulum.html)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/pendulum_zh.html) | -| 9 | [box2d/lunarlander](https://github.com/opendilab/LightZero/tree/main/zoo/box2d) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![continuous](https://img.shields.io/badge/-continous-green) | ![original](https://github.com/opendilab/DI-engine/blob/main//dizoo/box2d/lunarlander/lunarlander.gif) | [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/lunarlander.html)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/lunarlander_zh.html) | -| 10 | [box2d/bipedalwalker](https://github.com/opendilab/LightZero/tree/main/zoo/box2d) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](https://github.com/opendilab/DI-engine/blob/main//dizoo/box2d/bipedalwalker/bipedalwalker.gif) | [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/bipedalwalker.html)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/bipedalwalker_zh.html) | -| 11 | [atari](https://github.com/opendilab/LightZero/tree/main/zoo/atari) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](https://github.com/opendilab/DI-engine/blob/main/dizoo/atari/atari.gif) | [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/atari.html)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/atari_zh.html) | -| 11 | [mujoco](https://github.com/opendilab/LightZero/tree/main/zoo/mujoco) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](https://github.com/opendilab/DI-engine/blob/main/dizoo/mujoco/mujoco.gif) | [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/mujoco.html)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/mujoco_zh.html) | -| 12 | [minigrid](https://github.com/opendilab/LightZero/tree/main/zoo/minigrid) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](https://github.com/opendilab/DI-engine/blob/main/dizoo/minigrid/minigrid.gif) | [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/minigrid.html)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/minigrid_zh.html) | -| 13 | [memory](https://github.com/opendilab/LightZero/tree/main/zoo/memory) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](https://github.com/opendilab/LightZero/blob/main/zoo/memory/key_to_door.gif)
![original](https://github.com/opendilab/LightZero/blob/main/zoo/memory/visual_match.gif) | [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/memory.html)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/memory_zh.html) | +| No | Environment | Label | Visualization | Brief Description | Doc Links | +|:--:|:-----------:|:-----:|:-------------:|:-----------------|:---------| +| 1 | [board_games/tictactoe](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/tictactoe) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | N/A | Classic Tic-Tac-Toe board game with simple rules and fast gameplay. | [Tic-tac-toe Wiki](https://en.wikipedia.org/wiki/Tic-tac-toe) | +| 2 | [board_games/gomoku](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/gomoku) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | N/A | Gomoku (Five in a Row), a strategic board game on a grid. | [Gomoku Wiki](https://en.wikipedia.org/wiki/Gomoku) | +| 3 | [board_games/connect4](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/connect4) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![connect4](https://github.com/opendilab/LightZero/blob/main/zoo/board_games/connect4/connect4.gif) | Connect Four, a two-player connection board game. | [Connect Four Wiki](https://en.wikipedia.org/wiki/Connect_Four) | +| 4 | [board_games/chess](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/chess) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | N/A | Chess, the classic strategy board game. | [Chess Wiki](https://en.wikipedia.org/wiki/Chess) | +| 5 | [board_games/go](https://github.com/opendilab/LightZero/tree/main/zoo/board_games/go) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | N/A | Go, an ancient board game emphasizing territory control. | [Go Wiki](https://en.wikipedia.org/wiki/Go_(game)) | +| 6 | [game_2048](https://github.com/opendilab/LightZero/tree/main/zoo/game_2048) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | N/A | 2048, a single-player sliding block puzzle game. | [2048 Wiki](https://en.wikipedia.org/wiki/2048_(video_game)) | +| 7 | [classic_control/cartpole](https://github.com/opendilab/LightZero/tree/main/zoo/classic_control/cartpole) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | N/A | CartPole, a classic control problem balancing a pole on a cart. | [CartPole Doc](https://di-engine-docs.readthedocs.io/en/latest/13_envs/cartpole.html) | +| 8 | [classic_control/pendulum](https://github.com/opendilab/LightZero/tree/main/zoo/classic_control/pendulum) | ![continuous](https://img.shields.io/badge/-continous-green) | N/A | Pendulum, a continuous control task for swing-up and stabilization. | [Pendulum Doc](https://di-engine-docs.readthedocs.io/en/latest/13_envs/pendulum.html) | +| 9 | [classic_control/mountain_car](https://github.com/opendilab/LightZero/tree/main/zoo/classic_control/mountain_car) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | N/A | MountainCar, a classic control task for reinforcement learning. | [MountainCar Doc](https://www.gymlibrary.dev/environments/classic_control/mountain_car/) | +| 10 | [box2d/lunarlander](https://github.com/opendilab/LightZero/tree/main/zoo/box2d/lunarlander) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![continuous](https://img.shields.io/badge/-continous-green) | N/A | LunarLander, a Box2D-based environment for landing a spacecraft. 
| [LunarLander Doc](https://di-engine-docs.readthedocs.io/en/latest/13_envs/lunarlander.html) | +| 11 | [box2d/bipedalwalker](https://github.com/opendilab/LightZero/tree/main/zoo/box2d/bipedalwalker) | ![continuous](https://img.shields.io/badge/-continous-green) | N/A | BipedalWalker, a continuous control task for walking robots. | [BipedalWalker Doc](https://di-engine-docs.readthedocs.io/en/latest/13_envs/bipedalwalker.html) | +| 12 | [atari](https://github.com/opendilab/LightZero/tree/main/zoo/atari) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | N/A | Atari 2600 suite, classic video games for RL benchmarks. | [Atari Doc](https://di-engine-docs.readthedocs.io/en/latest/13_envs/atari.html) | +| 13 | [mujoco](https://github.com/opendilab/LightZero/tree/main/zoo/mujoco) | ![continuous](https://img.shields.io/badge/-continous-green) | N/A | MuJoCo, continuous control suite for robotics and locomotion. | [MuJoCo Doc](https://di-engine-docs.readthedocs.io/en/latest/13_envs/mujoco.html) | +| 14 | [minigrid](https://github.com/opendilab/LightZero/tree/main/zoo/minigrid) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![MiniGrid-FourRooms](https://github.com/opendilab/LightZero/blob/main/zoo/minigrid/envs/video/MiniGrid-FourRooms-v0_episode_0.gif) | MiniGrid, a gridworld environment for exploration and planning. | [MiniGrid Doc](https://di-engine-docs.readthedocs.io/en/latest/13_envs/minigrid.html) | +| 15 | [memory](https://github.com/opendilab/LightZero/tree/main/zoo/memory) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![key_to_door](https://github.com/opendilab/LightZero/blob/main/zoo/memory/key_to_door.gif)
![visual_match](https://github.com/opendilab/LightZero/blob/main/zoo/memory/visual_match.gif) | Memory tasks, such as Key-to-Door and Visual-Match, for memory-based RL. | [Memory Doc](https://di-engine-docs.readthedocs.io/en/latest/13_envs/memory.html) | +| 16 | [dmc2gym](https://github.com/opendilab/LightZero/tree/main/zoo/dmc2gym) | ![continuous](https://img.shields.io/badge/-continous-green) | N/A | DeepMind Control Suite via Gym interface, continuous control tasks. | [DMC2Gym Doc](https://github.com/denisyarats/dmc2gym) | +| 17 | [jericho](https://github.com/opendilab/LightZero/tree/main/zoo/jericho) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | N/A | Jericho, a suite for text-based adventure games. | [Jericho Doc](https://github.com/microsoft/jericho) | +| 18 | [pooltool/sum_to_three](https://github.com/opendilab/LightZero/tree/main/zoo/pooltool/sum_to_three) | ![continuous](https://img.shields.io/badge/-continous-green) | N/A | SumToThree, a physics-based pool tool environment. | [SumToThree Doc](https://github.com/opendilab/LightZero/tree/main/zoo/pooltool/sum_to_three) | +| 19 | [crowd_sim](https://github.com/opendilab/LightZero/tree/main/zoo/crowd_sim) | ![continuous](https://img.shields.io/badge/-continous-green) | N/A | CrowdSim, environments for crowd simulation and navigation. | [CrowdSim Doc](https://github.com/opendilab/LightZero/tree/main/zoo/crowd_sim) | +| 20 | [metadrive](https://github.com/opendilab/LightZero/tree/main/zoo/metadrive) | ![continuous](https://img.shields.io/badge/-continous-green) | N/A | MetaDrive, a driving simulator for RL research. | [MetaDrive Doc](https://github.com/metadriverse/metadrive) | +| 21 | [memory_maze](https://github.com/opendilab/LightZero/tree/main/zoo/memory_maze) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | N/A | Memory Maze, a challenging memory-based navigation task. | [Memory Maze Doc](https://github.com/deepmind/maze-solver) | +| 22 | [bsuite](https://github.com/opendilab/LightZero/tree/main/zoo/bsuite) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | N/A | BSuite, a collection of RL environments for benchmarking. | [BSuite Doc](https://github.com/deepmind/bsuite) | + + +
diff --git a/zoo/atari/config/atari_efficientzero_config.py b/zoo/atari/config/atari_efficientzero_config.py index 4134dce32..6d440cbf5 100644 --- a/zoo/atari/config/atari_efficientzero_config.py +++ b/zoo/atari/config/atari_efficientzero_config.py @@ -45,9 +45,8 @@ self_supervised_learning_loss=True, # default is False discrete_action_encoding_type='one_hot', norm_type='BN', - reward_support_size=101, - value_support_size=101, - support_scale=50, + reward_support_range=(-50., 51., 1.), + value_support_range=(-50., 51., 1.), ), cuda=True, env_type='not_board_games', diff --git a/zoo/atari/config/atari_env_action_space_map.py b/zoo/atari/config/atari_env_action_space_map.py index e2090586d..d40d12f41 100644 --- a/zoo/atari/config/atari_env_action_space_map.py +++ b/zoo/atari/config/atari_env_action_space_map.py @@ -27,4 +27,7 @@ 'SeaquestNoFrameskip-v4': 18, 'BoxingNoFrameskip-v4': 18, 'BreakoutNoFrameskip-v4': 4, + 'SpaceInvadersNoFrameskip-v4': 6, + 'BeamRiderNoFrameskip-v4': 9, + 'GravitarNoFrameskip-v4': 18, }) \ No newline at end of file diff --git a/zoo/atari/config/atari_muzero_config.py b/zoo/atari/config/atari_muzero_config.py index 8f79eb63e..7a615dbf1 100644 --- a/zoo/atari/config/atari_muzero_config.py +++ b/zoo/atari/config/atari_muzero_config.py @@ -39,7 +39,7 @@ collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, - manager=dict(shared_memory=True, ), + manager=dict(shared_memory=False, ), # TODO: debug # collect_max_episode_steps=int(50), # eval_max_episode_steps=int(50), diff --git a/zoo/atari/config/atari_muzero_context_config.py b/zoo/atari/config/atari_muzero_context_config.py index e3208d74f..d835ce349 100644 --- a/zoo/atari/config/atari_muzero_context_config.py +++ b/zoo/atari/config/atari_muzero_context_config.py @@ -58,9 +58,8 @@ self_supervised_learning_loss=True, discrete_action_encoding_type='one_hot', norm_type='BN', - reward_support_size=101, - value_support_size=101, - support_scale=50, + reward_support_range=(-50., 51., 1.), + value_support_range=(-50., 51., 1.), context_length_init=context_length_init, use_sim_norm=True, model_type='conv_context', diff --git a/zoo/atari/config/atari_muzero_multitask_segment_config.py b/zoo/atari/config/atari_muzero_multitask_segment_config.py deleted file mode 100644 index ce486a050..000000000 --- a/zoo/atari/config/atari_muzero_multitask_segment_config.py +++ /dev/null @@ -1,260 +0,0 @@ -from easydict import EasyDict - -def create_config( - env_id, - action_space_size, - collector_env_num, - evaluator_env_num, - n_episode, - num_simulations, - reanalyze_ratio, - batch_size, - num_unroll_steps, - infer_context_length, - norm_type, - buffer_reanalyze_freq, - reanalyze_batch_size, - reanalyze_partition, - num_segments -): - - return EasyDict(dict( - env=dict( - stop_value=int(5e5), # Adjusted max_env_step based on user TODO - env_id=env_id, - observation_shape=(4, 96, 96), - frame_stack_num=4, - gray_scale=True, - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - n_evaluator_episode=evaluator_env_num, - manager=dict(shared_memory=False, ), - full_action_space=True, - # ===== TODO: only for debug ===== - # collect_max_episode_steps=int(50), - # eval_max_episode_steps=int(50), - ), - policy=dict( - learn=dict( - learner=dict( - hook=dict(save_ckpt_after_iter=200000,), # Adjusted checkpoint frequency - ), - ), - grad_correct_params=dict( - # Placeholder for gradient correction parameters if needed - ), - task_num=len(env_id_list), - model=dict( 
- device='cuda', - num_res_blocks=2, # NOTE: encoder for 4 game - num_channels=256, - reward_head_channels= 16, - value_head_channels= 16, - policy_head_channels= 16, - fc_reward_layers= [32], - fc_value_layers= [32], - fc_policy_layers= [32], - observation_shape=(4, 96, 96), - frame_stack_num=4, - gray_scale=True, - action_space_size=action_space_size, - norm_type=norm_type, - model_type='conv', - image_channel=1, - downsample=True, - self_supervised_learning_loss=True, - discrete_action_encoding_type='one_hot', - use_sim_norm=True, - use_sim_norm_kl_loss=False, - task_num=len(env_id_list), - ), - cuda=True, - env_type='not_board_games', - # train_start_after_envsteps=2000, - train_start_after_envsteps=0, - game_segment_length=20, # Fixed segment length as per user config - random_collect_episode_num=0, - use_augmentation=True, - use_priority=False, - replay_ratio=0.25, - num_unroll_steps=num_unroll_steps, - # =========== TODO: debug =========== - # update_per_collect=2, # TODO: debug - update_per_collect=80, # Consistent with UniZero config - batch_size=batch_size, - optim_type='SGD', - td_steps=5, - lr_piecewise_constant_decay=True, - manual_temperature_decay=False, - learning_rate=0.2, - target_update_freq=100, - num_segments=num_segments, - num_simulations=num_simulations, - policy_entropy_weight=5e-3, #TODO - ssl_loss_weight=2, - eval_freq=int(5e3), - replay_buffer_size=int(5e5), # Adjusted as per UniZero config - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - # ============= The key different params for reanalyze ============= - buffer_reanalyze_freq=buffer_reanalyze_freq, - reanalyze_batch_size=reanalyze_batch_size, - reanalyze_partition=reanalyze_partition, - ), - )) - -def generate_configs( - env_id_list, - action_space_size, - collector_env_num, - evaluator_env_num, - n_episode, - num_simulations, - reanalyze_ratio, - batch_size, - num_unroll_steps, - infer_context_length, - norm_type, - seed, - buffer_reanalyze_freq, - reanalyze_batch_size, - reanalyze_partition, - num_segments -): - configs = [] - exp_name_prefix = ( - f'data_muzero_mt_8games/{len(env_id_list)}games_brf{buffer_reanalyze_freq}/' - f'{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_' - f'{len(env_id_list)}-pred-head_mbs-512_upc80_H{num_unroll_steps}_seed{seed}/' - ) - - for task_id, env_id in enumerate(env_id_list): - config = create_config( - env_id, - action_space_size, - # collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing - # evaluator_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, - # n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, - collector_env_num, - evaluator_env_num, - n_episode, - num_simulations, - reanalyze_ratio, - batch_size, - num_unroll_steps, - infer_context_length, - norm_type, - buffer_reanalyze_freq, - reanalyze_batch_size, - reanalyze_partition, - num_segments - ) - config.policy.task_id = task_id - config.exp_name = f"{exp_name_prefix}{env_id.split('NoFrameskip')[0]}_muzero-mt_seed{seed}" - - configs.append([task_id, [config, create_env_manager()]]) - - return configs - -def create_env_manager(): - return EasyDict(dict( - env=dict( - type='atari_lightzero', - import_names=['zoo.atari.envs.atari_lightzero_env'], - ), - env_manager=dict(type='subprocess'), - # env_manager=dict(type='base'), - policy=dict( - type='muzero_multitask', - 
import_names=['lzero.policy.muzero_multitask'], - ), - )) - -if __name__ == "__main__": - import sys - sys.path.insert(0, "/mnt/afs/niuyazhe/code/LightZero") - import lzero - print("lzero path:", lzero.__file__) - # import sys - # import os - # # 添加项目根目录到 PYTHONPATH - # sys.path.append(os.path.dirname(os.path.abspath(__file__))) - - from lzero.entry import train_muzero_multitask_segment_noddp - import argparse - - parser = argparse.ArgumentParser(description='Train MuZero Multitask on Atari') - parser.add_argument('--seed', type=int, default=0, help='Random seed') - args = parser.parse_args() - - # Define your list of environment IDs - env_id_list = [ - 'PongNoFrameskip-v4', - 'MsPacmanNoFrameskip-v4', - 'SeaquestNoFrameskip-v4', - 'BoxingNoFrameskip-v4', - 'AlienNoFrameskip-v4', - 'ChopperCommandNoFrameskip-v4', - 'HeroNoFrameskip-v4', - 'RoadRunnerNoFrameskip-v4', - ] - # env_id_list = [ - # 'PongNoFrameskip-v4', - # 'MsPacmanNoFrameskip-v4', - # ] - - action_space_size = 18 # Full action space, adjust if different per env - seed = args.seed - collector_env_num = 8 - evaluator_env_num = 3 - num_segments = 8 - n_episode = 8 - num_simulations = 50 - reanalyze_ratio = 0.0 - - max_batch_size = 512 - batch_size = [int(min(64, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] - print(f'=========== batch_size: {batch_size} ===========') - - num_unroll_steps = 5 - infer_context_length = 4 - # norm_type = 'LN' - norm_type = 'BN' - - buffer_reanalyze_freq = 1 / 50 # Adjusted as per UniZero config - reanalyze_batch_size = 160 - reanalyze_partition = 0.75 - - num_segments = 8 - - # =========== TODO: debug =========== - # collector_env_num = 2 - # evaluator_env_num = 2 - # num_segments = 2 - # n_episode = 2 - # num_simulations = 5 - # batch_size = [int(min(2, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] - - - # Generate configurations - configs = generate_configs( - env_id_list=env_id_list, - action_space_size=action_space_size, - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - n_episode=n_episode, - num_simulations=num_simulations, - reanalyze_ratio=reanalyze_ratio, - batch_size=batch_size, - num_unroll_steps=num_unroll_steps, - infer_context_length=infer_context_length, - norm_type=norm_type, - seed=seed, - buffer_reanalyze_freq=buffer_reanalyze_freq, - reanalyze_batch_size=reanalyze_batch_size, - reanalyze_partition=reanalyze_partition, - num_segments=num_segments - ) - - # Start training - train_muzero_multitask_segment_noddp(configs, seed=seed, max_env_step=int(5e5)) \ No newline at end of file diff --git a/zoo/atari/config/atari_muzero_multitask_segment_ddp_config.py b/zoo/atari/config/atari_muzero_multitask_segment_ddp_config.py index 698a3d1ac..7d640e1d7 100644 --- a/zoo/atari/config/atari_muzero_multitask_segment_ddp_config.py +++ b/zoo/atari/config/atari_muzero_multitask_segment_ddp_config.py @@ -1,294 +1,330 @@ +""" +Overview: + Configuration generation script for multi-task MuZero training on Atari environments. + This script defines and generates the necessary configuration files for a distributed training setup. 
+""" from easydict import EasyDict from copy import deepcopy -from atari_env_action_space_map import atari_env_action_space_map +from typing import List, Union, Dict, Any -def create_config( - env_id, - action_space_size, - collector_env_num, - evaluator_env_num, - n_episode, - num_simulations, - reanalyze_ratio, - batch_size, - num_unroll_steps, - infer_context_length, - norm_type, - buffer_reanalyze_freq, - reanalyze_batch_size, - reanalyze_partition, - num_segments -): +# The 'atari_env_action_space_map' was not used in the original code, so it has been removed. - return EasyDict(dict( - env=dict( - stop_value=int(5e5), # Adjusted max_env_step based on user TODO - env_id=env_id, - observation_shape=(4, 96, 96), - frame_stack_num=4, - gray_scale=True, - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - n_evaluator_episode=evaluator_env_num, - manager=dict(shared_memory=False, ), - full_action_space=True, - collect_max_episode_steps=int(5e3), - eval_max_episode_steps=int(5e3), - # ===== only for debug ===== - # collect_max_episode_steps=int(50), - # eval_max_episode_steps=int(50), - ), - policy=dict( - multi_gpu=True, # ======== Very important for ddp ============= - learn=dict( - learner=dict( - hook=dict(save_ckpt_after_iter=200000,), # Adjusted checkpoint frequency - ), - ), - grad_correct_params=dict( - # Placeholder for gradient correction parameters if needed - ), - task_num=len(env_id_list), - model=dict( - device='cuda', - num_res_blocks=2, # NOTE: encoder for 4 game - num_channels=256, - reward_head_channels= 16, - value_head_channels= 16, - policy_head_channels= 16, - fc_reward_layers= [32], - fc_value_layers= [32], - fc_policy_layers= [32], +class AtariMuZeroMultitaskConfig: + """ + Overview: + A class to generate and manage configurations for multi-task MuZero experiments on Atari. + It encapsulates the entire configuration logic, providing a clean and extensible interface. + """ + + def __init__( + self, + env_id_list: List[str], + seed: int, + num_unroll_steps: int, + num_simulations: int, + collector_env_num: int, + evaluator_env_num: int, + max_env_step: int, + batch_size: Union[List[int], int], + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + exp_path_prefix: str = 'YOUR_EXPERIMENT_PATH_PREFIX/data_muzero_mt_atari', + ) -> None: + """ + Overview: + Initializes the multi-task configuration generator. + Arguments: + - env_id_list (:obj:`List[str]`): A list of Atari environment IDs to be trained on. + - seed (:obj:`int`): The random seed for the experiment. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model during training. + - num_simulations (:obj:`int`): The number of simulations to run in the MCTS search. + - collector_env_num (:obj:`int`): The number of environments for data collection. + - evaluator_env_num (:obj:`int`): The number of environments for evaluation. + - max_env_step (:obj:`int`): The total number of environment steps to train for. + - batch_size (:obj:`Union[List[int], int]`): The batch size for training. Can be a list for per-task sizes or a single int. + - norm_type (:obj:`str`): The type of normalization to use in the model (e.g., 'BN', 'LN'). + - buffer_reanalyze_freq (:obj:`float`): The frequency at which to reanalyze the replay buffer. + - reanalyze_batch_size (:obj:`int`): The batch size for reanalysis. + - reanalyze_partition (:obj:`float`): The partition ratio for reanalysis. 
+ - num_segments (:obj:`int`): The number of segments for the replay buffer. + - exp_path_prefix (:obj:`str`): A template for the experiment's output path. + """ + self.env_id_list = env_id_list + self.seed = seed + self.num_unroll_steps = num_unroll_steps + self.num_simulations = num_simulations + self.collector_env_num = collector_env_num + self.evaluator_env_num = evaluator_env_num + self.max_env_step = max_env_step + self.batch_size = batch_size + self.norm_type = norm_type + self.buffer_reanalyze_freq = buffer_reanalyze_freq + self.reanalyze_batch_size = reanalyze_batch_size + self.reanalyze_partition = reanalyze_partition + self.num_segments = num_segments + self.exp_path_prefix = exp_path_prefix + + # --- Derived attributes --- + self.num_tasks = len(self.env_id_list) + self.action_space_size = 18 # Default full action space for Atari + + def _create_base_config(self) -> EasyDict: + """ + Overview: + Creates the base configuration dictionary with shared settings for all tasks. + Returns: + - (:obj:`EasyDict`): A dictionary containing the base configuration. + """ + return EasyDict(dict( + env=dict( + stop_value=int(self.max_env_step), observation_shape=(4, 96, 96), frame_stack_num=4, gray_scale=True, - action_space_size=action_space_size, - norm_type=norm_type, - model_type='conv', - image_channel=1, - downsample=True, - self_supervised_learning_loss=True, - discrete_action_encoding_type='one_hot', - use_sim_norm=True, - use_sim_norm_kl_loss=False, - task_num=len(env_id_list), + collector_env_num=self.collector_env_num, + evaluator_env_num=self.evaluator_env_num, + n_evaluator_episode=self.evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), ), - allocated_batch_sizes=False, - cuda=True, - env_type='not_board_games', - train_start_after_envsteps=2000, - # train_start_after_envsteps=0, # TODO: debug - game_segment_length=20, # Fixed segment length as per user config - random_collect_episode_num=0, - use_augmentation=True, - use_priority=False, - replay_ratio=0.25, - num_unroll_steps=num_unroll_steps, - # update_per_collect=2, # TODO: debug - update_per_collect=80, # Consistent with UniZero config - batch_size=batch_size, - optim_type='SGD', - td_steps=5, - lr_piecewise_constant_decay=True, - manual_temperature_decay=False, - learning_rate=0.2, - target_update_freq=100, - num_segments=num_segments, - num_simulations=num_simulations, - policy_entropy_weight=5e-3, #TODO - ssl_loss_weight=2, - eval_freq=int(5e3), - replay_buffer_size=int(5e5), # Adjusted as per UniZero config - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - # ============= The key different params for reanalyze ============= - buffer_reanalyze_freq=buffer_reanalyze_freq, - reanalyze_batch_size=reanalyze_batch_size, - reanalyze_partition=reanalyze_partition, - ), - )) - -def generate_configs( - env_id_list, - action_space_size, - collector_env_num, - evaluator_env_num, - n_episode, - num_simulations, - reanalyze_ratio, - batch_size, - num_unroll_steps, - infer_context_length, - norm_type, - seed, - buffer_reanalyze_freq, - reanalyze_batch_size, - reanalyze_partition, - num_segments -): - configs = [] - # TODO: debug name - exp_name_prefix = ( - f'data_lz/data_muzero_mt_atari_20250228/{len(env_id_list)}games_brf{buffer_reanalyze_freq}/' - f'{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_' - 
f'{len(env_id_list)}-pred-head_mbs-512_upc80_H{num_unroll_steps}_seed{seed}/' - ) + policy=dict( + multi_gpu=True, # Very important for DDP + learn=dict( + learner=dict( + hook=dict(save_ckpt_after_iter=200000), + ), + ), + grad_correct_params=dict(), + task_num=self.num_tasks, + model=dict( + device='cuda', + num_res_blocks=2, + num_channels=256, + reward_head_channels=16, + value_head_channels=16, + policy_head_channels=16, + fc_reward_layers=[32], + fc_value_layers=[32], + fc_policy_layers=[32], + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + action_space_size=self.action_space_size, + norm_type=self.norm_type, + model_type='conv', + image_channel=1, + downsample=True, + self_supervised_learning_loss=True, + discrete_action_encoding_type='one_hot', + use_sim_norm=True, + use_sim_norm_kl_loss=False, + task_num=self.num_tasks, + ), + allocated_batch_sizes=False, + cuda=True, + env_type='not_board_games', + train_start_after_envsteps=2000, + # train_start_after_envsteps=0, # TODO: debug + game_segment_length=20, + random_collect_episode_num=0, + use_augmentation=True, + use_priority=False, + replay_ratio=0.25, + num_unroll_steps=self.num_unroll_steps, + update_per_collect=80, + optim_type='SGD', + td_steps=5, + lr_piecewise_constant_decay=True, + manual_temperature_decay=False, + learning_rate=0.2, + target_update_freq=100, + num_segments=self.num_segments, + num_simulations=self.num_simulations, + policy_entropy_weight=5e-3, # TODO: Fine-tune this weight. + ssl_loss_weight=2, + eval_freq=int(5e3), + replay_buffer_size=int(5e5), + collector_env_num=self.collector_env_num, + evaluator_env_num=self.evaluator_env_num, + # ============= Reanalyze Parameters ============= + buffer_reanalyze_freq=self.buffer_reanalyze_freq, + reanalyze_batch_size=self.reanalyze_batch_size, + reanalyze_partition=self.reanalyze_partition, + ), + )) - for task_id, env_id in enumerate(env_id_list): - config = create_config( - env_id, - action_space_size, - # collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing - # evaluator_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, - # n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, - collector_env_num, - evaluator_env_num, - n_episode, - num_simulations, - reanalyze_ratio, - batch_size, - num_unroll_steps, - infer_context_length, - norm_type, - buffer_reanalyze_freq, - reanalyze_batch_size, - reanalyze_partition, - num_segments + def _get_exp_name(self, env_id: str) -> str: + """ + Overview: + Generates a formatted experiment name for a given task. + Arguments: + - env_id (:obj:`str`): The environment ID for the specific task. + Returns: + - (:obj:`str`): The formatted experiment name. 
+ """ + # TODO: debug name + prefix = ( + f'{self.exp_path_prefix}/{self.num_tasks}games_brf{self.buffer_reanalyze_freq}/' + f'{self.num_tasks}games_brf{self.buffer_reanalyze_freq}_1-encoder-{self.norm_type}-res2-channel256_gsl20_' + f'{self.num_tasks}-pred-head_mbs-512_upc80_H{self.num_unroll_steps}_seed{self.seed}/' ) - config.policy.task_id = task_id - config.exp_name = f"{exp_name_prefix}{env_id.split('NoFrameskip')[0]}_muzero-mt_seed{seed}" + env_name = env_id.split('NoFrameskip')[0] + return f"{prefix}{env_name}_muzero-mt_seed{self.seed}" - configs.append([task_id, [config, create_env_manager()]]) + def generate_configs(self) -> List[List[Union[int, List[Any]]]]: + """ + Overview: + Generates the final list of configurations for all specified tasks, + ready to be used by the training entry point. + Returns: + - (:obj:`List[List[Union[int, List[Any]]]]`): A list where each element corresponds to a task, + containing the task_id and a list with the task's config and env_manager config. + """ + base_config = self._create_base_config() + env_manager_config = self._create_env_manager_config() + + configs = [] + for task_id, env_id in enumerate(self.env_id_list): + task_config = deepcopy(base_config) + + # --- Apply task-specific settings --- + task_config.env.env_id = env_id + task_config.policy.task_id = task_id + + # Handle per-task batch size if provided as a list + if isinstance(self.batch_size, list): + task_config.policy.batch_size = self.batch_size[task_id] + else: + task_config.policy.batch_size = self.batch_size + + task_config.exp_name = self._get_exp_name(env_id) - return configs + configs.append([task_id, [task_config, env_manager_config]]) + + return configs -def create_env_manager(): - return EasyDict(dict( - env=dict( - type='atari_lightzero', - import_names=['zoo.atari.envs.atari_lightzero_env'], - ), - env_manager=dict(type='subprocess'), - policy=dict( - type='muzero_multitask', - import_names=['lzero.policy.muzero_multitask'], - ), - )) - -if __name__ == "__main__": - # import sys - # sys.path.insert(0, "/mnt/afs/niuyazhe/code/LightZero") - # import lzero - # print("lzero path:", lzero.__file__) + @staticmethod + def _create_env_manager_config() -> EasyDict: + """ + Overview: + Creates a static configuration for the environment and policy managers. + Returns: + - (:obj:`EasyDict`): A dictionary containing manager configurations. 
+ """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero_multitask', + import_names=['lzero.policy.muzero_multitask'], + ), + )) - # parser = argparse.ArgumentParser(description='Train MuZero Multitask on Atari') - # parser.add_argument('--seed', type=int, default=0, help='Random seed') - # args = parser.parse_args() - # Define your list of environment IDs - env_id_list = [ - 'PongNoFrameskip-v4', - 'MsPacmanNoFrameskip-v4', - 'SeaquestNoFrameskip-v4', - 'BoxingNoFrameskip-v4', - # 'AlienNoFrameskip-v4', - # 'ChopperCommandNoFrameskip-v4', - # 'HeroNoFrameskip-v4', - # 'RoadRunnerNoFrameskip-v4', - ] +if __name__ == "__main__": + # ============================================================== + # Hyperparameters for Multi-Task Training + # ============================================================== + + # --- List of Atari environments for multi-task learning --- env_id_list = [ - 'PongNoFrameskip-v4', - 'MsPacmanNoFrameskip-v4', - 'SeaquestNoFrameskip-v4', - 'BoxingNoFrameskip-v4', - 'AlienNoFrameskip-v4', - 'ChopperCommandNoFrameskip-v4', - 'HeroNoFrameskip-v4', - 'RoadRunnerNoFrameskip-v4', - 'AmidarNoFrameskip-v4', - 'AssaultNoFrameskip-v4', - 'AsterixNoFrameskip-v4', - 'BankHeistNoFrameskip-v4', - 'BattleZoneNoFrameskip-v4', - 'CrazyClimberNoFrameskip-v4', - 'DemonAttackNoFrameskip-v4', - 'FreewayNoFrameskip-v4', - 'FrostbiteNoFrameskip-v4', - 'GopherNoFrameskip-v4', - 'JamesbondNoFrameskip-v4', - 'KangarooNoFrameskip-v4', - 'KrullNoFrameskip-v4', - 'KungFuMasterNoFrameskip-v4', - 'PrivateEyeNoFrameskip-v4', - 'UpNDownNoFrameskip-v4', - 'QbertNoFrameskip-v4', - 'BreakoutNoFrameskip-v4', + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', 'AmidarNoFrameskip-v4', + 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', + 'FreewayNoFrameskip-v4', 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', + 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', 'KrullNoFrameskip-v4', + 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ] - action_space_size = 18 # Full action space, adjust if different per env + # --- Core Experiment Settings --- seed = 0 + max_env_step = int(5e5) + + # --- Training & Model Parameters --- + num_unroll_steps = 5 + num_simulations = 50 + norm_type = 'BN' # 'BN' (Batch Normalization) or 'LN' (Layer Normalization) + # --- Environment & Collector Settings --- collector_env_num = 8 evaluator_env_num = 3 num_segments = 8 - n_episode = 8 - num_simulations = 50 - reanalyze_ratio = 0.0 - max_env_step = 5e5 + # --- Batch Size Configuration --- + # The batch size is dynamically calculated per task to not exceed a maximum total batch size. 
max_batch_size = 512 - # max_batch_size = 1024 - batch_size = [int(min(64, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] - - num_unroll_steps = 5 - infer_context_length = 4 - # norm_type = 'LN' - norm_type = 'BN' + per_task_batch_size = int(min(64, max_batch_size / len(env_id_list))) + batch_size = [per_task_batch_size] * len(env_id_list) - buffer_reanalyze_freq = 1 / 50 # Adjusted as per UniZero config + # --- Reanalyze Buffer Settings --- + buffer_reanalyze_freq = 1 / 50 reanalyze_batch_size = 160 reanalyze_partition = 0.75 - - # =========== TODO: debug =========== + # --- (Optional) Debug Settings --- + # To use debug settings, uncomment the following lines. # collector_env_num = 2 # evaluator_env_num = 2 # num_segments = 2 - # n_episode = 2 # num_simulations = 3 - # batch_size = [int(min(2, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + # debug_batch_size = int(min(2, max_batch_size / len(env_id_list))) + # batch_size = [debug_batch_size] * len(env_id_list) + # print("--- RUNNING IN DEBUG MODE ---") + + print(f'=========== Batch size per task: {batch_size[0]} ===========') - print(f'=========== batch_size: {batch_size} ===========') - # Generate configurations - configs = generate_configs( + # ============================================================== + # Configuration Generation and Training Launch + # ============================================================== + + # --- Instantiate and generate configurations --- + experiment_config = AtariMuZeroMultitaskConfig( env_id_list=env_id_list, - action_space_size=action_space_size, + seed=seed, + max_env_step=max_env_step, + num_unroll_steps=num_unroll_steps, + num_simulations=num_simulations, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, - n_episode=n_episode, - num_simulations=num_simulations, - reanalyze_ratio=reanalyze_ratio, batch_size=batch_size, - num_unroll_steps=num_unroll_steps, - infer_context_length=infer_context_length, norm_type=norm_type, - seed=seed, buffer_reanalyze_freq=buffer_reanalyze_freq, reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, - num_segments=num_segments + num_segments=num_segments, + # Note: Update this path to your desired location. + exp_path_prefix='YOUR_EXPERIMENT_PATH_PREFIX/data_muzero_mt_atari_20250228' ) + + configs_to_run = experiment_config.generate_configs() + # --- Launch Distributed Training --- """ Overview: This script should be executed with GPUs. - Run the following command to launch the script: + Set the NCCL timeout and launch the script using one of the following commands. 
+ + Command using torch.distributed.launch: export NCCL_TIMEOUT=3600000 - python -m torch.distributed.launch --nproc_per_node=4 --master_port=29501 ./zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py - 或者使用 torchrun: - torchrun --nproc_per_node=4 ./zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py + python -m torch.distributed.launch --nproc_per_node=4 --master_port=29501 ./path/to/this/script.py + + Command using torchrun: + export NCCL_TIMEOUT=3600000 + torchrun --nproc_per_node=4 --master_port=29501 ./path/to/this/script.py """ from lzero.entry import train_muzero_multitask_segment_ddp from ding.utils import DDPContext + with DDPContext(): - train_muzero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file + train_muzero_multitask_segment_ddp(configs_to_run, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_muzero_rnn_fullobs_config.py b/zoo/atari/config/atari_muzero_rnn_fullobs_config.py index 9482fff09..490f347f4 100644 --- a/zoo/atari/config/atari_muzero_rnn_fullobs_config.py +++ b/zoo/atari/config/atari_muzero_rnn_fullobs_config.py @@ -58,9 +58,8 @@ self_supervised_learning_loss=True, # default is False discrete_action_encoding_type='one_hot', norm_type='BN', - reward_support_size=101, - value_support_size=101, - support_scale=50, + reward_support_range=(-50., 51., 1.), + value_support_range=(-50., 51., 1.), context_length=context_length_init, # NOTE use_sim_norm=True, use_sim_norm_kl_loss=False, diff --git a/zoo/atari/config/atari_muzero_segment_config.py b/zoo/atari/config/atari_muzero_segment_config.py index 212798506..03fffa39e 100644 --- a/zoo/atari/config/atari_muzero_segment_config.py +++ b/zoo/atari/config/atari_muzero_segment_config.py @@ -18,11 +18,14 @@ def main(env_id, seed): num_unroll_steps = 5 batch_size = 256 - max_env_step = int(5e5) + # max_env_step = int(5e5) + max_env_step = int(100e6) + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. # buffer_reanalyze_freq = 1/10 - buffer_reanalyze_freq = 1/10000 + buffer_reanalyze_freq = 1/50 + # buffer_reanalyze_freq = 1/10000 # Each reanalyze process will reanalyze sequences ( transitions per sequence) reanalyze_batch_size = 160 # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py new file mode 100644 index 000000000..9c4725f9f --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py @@ -0,0 +1,550 @@ +# -*- coding: utf-8 -*- +""" +Overview: + This script contains the configuration generation logic for a multi-task UniZero agent + designed for Atari environments. It sets up experiment parameters, computes batch sizes + for distributed training, and generates the final configuration objects required to + launch the training process. + +Execution Command Example: + To run this script using distributed training with GPUs, use the following command. + Replace with the number of GPUs per node (e.g., 8) and adjust paths and log files as needed. 
+ + cd /path/to/your/project/LightZero + python -m torch.distributed.launch --nproc_per_node= --master_port= \ + /path/to/this/script.py 2>&1 | tee /path/to/your/logs/training.log +""" +import math +from typing import List, Tuple, Dict, Any + +from easydict import EasyDict +from ding.utils import DDPContext +# It is recommended to place entry point imports within the main execution block +# to avoid circular dependencies or premature initializations. +# from lzero.entry import train_unizero_multitask_balance_segment_ddp + + +# ============================================================== +# Configuration Computation and Generation +# ============================================================== + +def compute_batch_config( + env_id_list: List[str], + effective_batch_size: int, + gpus_per_node: int = 8, + max_micro_batch_per_gpu: int = 400 +) -> Tuple[List[int], int]: + """ + Overview: + Computes the micro-batch size for each environment and the number of gradient accumulation steps. + This is designed to balance the load across multiple environments and GPUs while respecting + memory constraints (max_micro_batch_per_gpu). + + Arguments: + - env_id_list (:obj:`List[str]`): A list of environment IDs. + - effective_batch_size (:obj:`int`): The target total batch size after gradient accumulation. + - gpus_per_node (:obj:`int`): The number of GPUs available for training. Defaults to 8. + - max_micro_batch_per_gpu (:obj:`int`): The maximum micro-batch size that can fit on a single GPU. Defaults to 400. + + Returns: + - (:obj:`Tuple[List[int], int]`): A tuple containing: + - A list of micro-batch sizes, one for each environment. + - The number of gradient accumulation steps required. + """ + num_envs = len(env_id_list) + if num_envs == 0: + return [], 1 + + # To avoid division by zero, assume at least one environment is processed per GPU group. + envs_per_gpu_group = max(1, num_envs // gpus_per_node) + + # Calculate the maximum micro-batch size per environment based on GPU memory limits. + max_micro_batch_per_env = int(max_micro_batch_per_gpu / envs_per_gpu_group) + + # Calculate the theoretical batch size per environment if distributed evenly. + theoretical_env_batch = effective_batch_size / num_envs + + if theoretical_env_batch > max_micro_batch_per_env: + # If the theoretical batch size exceeds the per-environment limit, + # cap the micro-batch size at the maximum allowed value. + micro_batch_size = max_micro_batch_per_env + # Calculate gradient accumulation steps needed to reach the effective batch size. + grad_accumulate_steps = math.ceil(theoretical_env_batch / max_micro_batch_per_env) + else: + # If the theoretical batch size is within limits, use it directly. + micro_batch_size = int(theoretical_env_batch) + grad_accumulate_steps = 1 + + # Assign the same computed micro-batch size to all environments. + batch_sizes = [micro_batch_size] * num_envs + + # Logging for debugging purposes. 
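# Editor's illustrative sketch (hedged), not part of the patch: a standalone trace of
# `compute_batch_config` above for the defaults used later in this file (8 games,
# effective_batch_size=512, gpus_per_node=8, max_micro_batch_per_gpu=400). Names here are local to the sketch.
import math

num_envs, effective, gpus, max_per_gpu = 8, 512, 8, 400
envs_per_group = max(1, num_envs // gpus)          # -> 1
max_micro = int(max_per_gpu / envs_per_group)      # -> 400
theoretical = effective / num_envs                 # -> 64.0
if theoretical > max_micro:
    micro, grad_acc = max_micro, math.ceil(theoretical / max_micro)
else:
    micro, grad_acc = int(theoretical), 1
assert (micro, grad_acc) == (64, 1)                # one micro-batch of 64 per task, no gradient accumulation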
+ print(f"Number of environments: {num_envs}") + print(f"Effective total batch size: {effective_batch_size}") + print(f"Theoretical batch size per environment: {theoretical_env_batch:.2f}") + print(f"Micro-batch size per environment: {micro_batch_size}") + print(f"Gradient accumulation steps: {grad_accumulate_steps}") + + return batch_sizes, grad_accumulate_steps + + +def create_config( + env_id: str, + action_space_size: int, + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: int, + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + target_return: int, + curriculum_stage_num: int, + num_envs: int, +) -> EasyDict: + """ + Overview: + Creates the main configuration dictionary for a single UniZero task. + + Arguments: + - env_id (:obj:`str`): The ID of the environment (e.g., 'PongNoFrameskip-v4'). + - action_space_size (:obj:`int`): The size of the action space. + - collector_env_num (:obj:`int`): Number of environments for data collection. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. + - n_episode (:obj:`int`): Number of episodes to run for collection. + - num_simulations (:obj:`int`): Number of simulations for MCTS. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in a batch. + - batch_size (:obj:`int`): The micro-batch size for training. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model dynamics. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization layer to use (e.g., 'LN'). + - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the replay buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalysis. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalysis. + - num_segments (:obj:`int`): Number of segments for game episodes. + - total_batch_size (:obj:`int`): The effective total batch size. + - target_return (:obj:`int`): The target return for the environment. + - curriculum_stage_num (:obj:`int`): The number of stages in curriculum learning. + - num_envs (:obj:`int`): The total number of environments in the multi-task setup. + + Returns: + - (:obj:`EasyDict`): A configuration object for the agent. 
+ """ + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + ), + policy=dict( + multi_gpu=True, # Crucial for DDP + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + continuous_action_space=False, + world_model_cfg=dict( + use_global_pooling=False, + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + share_head=False, + analysis_dormant_ratio_weight_rank=False, + dormant_threshold=0.025, + continuous_action_space=False, + task_embed_option=None, + use_task_embed=False, + use_shared_projection=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=4, + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=num_envs, + task_num=num_envs, + encoder_type='vit', + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=True, + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + moe_use_lora=True, + curriculum_stage_num=curriculum_stage_num, + lora_target_modules=["attn", "feed_forward"], + lora_r=64, + lora_alpha=32, + lora_dropout=0.1, + lora_scale_init=1, + min_stage0_iters=50000, + max_stage_iters=20000, + apply_curriculum_to_encoder=False, + ), + ), + # --- Task and Learning Settings --- + total_task_num=num_envs, + task_num=num_envs, + task_id=0, # This will be overridden for each task. + target_return=target_return, + use_task_exploitation_weight=False, + task_complexity_weight=True, + balance_pipeline=True, + # --- Training Settings --- + cuda=True, + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + update_per_collect=80, + replay_ratio=0.25, + optim_type='AdamW', + cos_lr_scheduler=False, + train_start_after_envsteps=int(0), + # --- Replay Buffer and Reanalysis --- + replay_buffer_size=int(5e5), + num_segments=num_segments, + use_priority=False, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + reanalyze_ratio=reanalyze_ratio, + # --- MCTS Settings --- + num_simulations=num_simulations, + collect_num_simulations=num_simulations, + eval_num_simulations=50, + # --- Collector and Evaluator Settings --- + n_episode=n_episode, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + eval_freq=int(1e4), + # --- Miscellaneous --- + print_task_priority_logs=False, + model_path=None, + game_segment_length=20, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + ), + )) + + +def _generate_experiment_name( + base_path_prefix: str, + num_envs: int, + curriculum_stage_num: int, + buffer_reanalyze_freq: float, + seed: int, + env_id: str +) -> str: + """ + Overview: + Helper function to generate a standardized experiment name. 
+ + Arguments: + - base_path_prefix (:obj:`str`): The prefix for the experiment path, e.g., 'data_unizero_atari_mt_balance_YYYYMMDD'. + - num_envs (:obj:`int`): The total number of environments. + - curriculum_stage_num (:obj:`int`): The number of curriculum stages. + - buffer_reanalyze_freq (:obj:`float`): The buffer reanalyze frequency. + - seed (:obj:`int`): The random seed for the experiment. + - env_id (:obj:`str`): The environment ID for this specific task. + + Returns: + - (:obj:`str`): The generated experiment name. + """ + # Template for the experiment's parent directory. + brf_str = str(buffer_reanalyze_freq).replace('.', '') + parent_dir = ( + f"{base_path_prefix}/atari_{num_envs}games_balance-total-stage{curriculum_stage_num}_" + f"stage-50k-20k_vit-small-ln_trans-nlayer4-moe8_backbone-attn-mlp-lora_no-lora-scale_" + f"brf{brf_str}_not-share-head_seed{seed}/" + ) + + # Clean the environment ID for the final part of the name. + env_name_part = env_id.split('NoFrameskip')[0] + + return f"{parent_dir}{env_name_part}_seed{seed}" + + +def generate_configs( + env_id_list: List[str], + action_space_size: int, + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_sizes: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + target_return_dict: Dict[str, int], + curriculum_stage_num: int, +) -> List[Tuple[int, List[Any]]]: + """ + Overview: + Generates a list of configuration tuples, one for each task/environment. + + Returns: + - (:obj:`List[Tuple[int, List[Any]]]`): A list where each element is a tuple containing + the task_id and a list with the main config and the environment manager config. + """ + configs = [] + exp_name_base_prefix = 'data_unizero_atari_mt_balance_20250730' # YYYYMMDD format + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id=env_id, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_sizes[task_id], + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + target_return=target_return_dict[env_id], + curriculum_stage_num=curriculum_stage_num, + num_envs=len(env_id_list), + ) + config.policy.task_id = task_id + config.exp_name = _generate_experiment_name( + base_path_prefix=exp_name_base_prefix, + num_envs=len(env_id_list), + curriculum_stage_num=curriculum_stage_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + seed=seed, + env_id=env_id + ) + configs.append([task_id, [config, create_env_manager()]]) + return configs + + +def create_env_manager() -> EasyDict: + """ + Overview: + Creates the environment manager configuration, specifying the types of environment, + policy, and manager to be used. + + Returns: + - (:obj:`EasyDict`): A configuration object for the environment manager. 
+ """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + + +def get_atari_target_return_dict(ratio: float = 1.0) -> Dict[str, int]: + """ + Overview: + Calculates the target return for each Atari game based on a predefined score + and a scaling ratio. + + Arguments: + - ratio (:obj:`float`): A scaling factor for the target returns. Defaults to 1.0. + + Returns: + - (:obj:`Dict[str, int]`): A dictionary mapping environment IDs to their calculated target returns. + """ + # Pre-defined target scores for various Atari games. + target_scores = { + 'PongNoFrameskip-v4': 20, + 'MsPacmanNoFrameskip-v4': 6951.6, + 'SeaquestNoFrameskip-v4': 42054.7, + 'BoxingNoFrameskip-v4': 12.1, + 'AlienNoFrameskip-v4': 7127.7, + 'ChopperCommandNoFrameskip-v4': 7387.8, + 'HeroNoFrameskip-v4': 30826.4, + 'RoadRunnerNoFrameskip-v4': 7845.0, + 'AmidarNoFrameskip-v4': 100.5, + 'AssaultNoFrameskip-v4': 742.0, + 'AsterixNoFrameskip-v4': 1503.3, + 'BankHeistNoFrameskip-v4': 753.1, + 'BattleZoneNoFrameskip-v4': 12187.5, + 'CrazyClimberNoFrameskip-v4': 15829.4, + 'DemonAttackNoFrameskip-v4': 1971.0, + 'FreewayNoFrameskip-v4': 29.6, + 'FrostbiteNoFrameskip-v4': 334.7, + 'GopherNoFrameskip-v4': 2412.5, + 'JamesbondNoFrameskip-v4': 302.8, + 'KangarooNoFrameskip-v4': 3035.0, + 'KrullNoFrameskip-v4': 2665.5, + 'KungFuMasterNoFrameskip-v4': 12736.3, + 'PrivateEyeNoFrameskip-v4': 1001.3, + 'UpNDownNoFrameskip-v4': 11693.2, + 'QbertNoFrameskip-v4': 13455.0, + 'BreakoutNoFrameskip-v4': 30.5, + } + return {env: int(round(score * ratio)) for env, score in target_scores.items()} + + +def get_env_id_list(num_games: int) -> List[str]: + """ + Overview: + Returns a list of Atari environment IDs based on the specified number of games. + + Arguments: + - num_games (:obj:`int`): The number of games to include (e.g., 8 or 26). + + Returns: + - (:obj:`List[str]`): A list of environment ID strings. + """ + games_8 = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + games_26 = games_8 + [ + 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', + 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + if num_games == 3: + return ['PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4'] + elif num_games == 8: + return games_8 + elif num_games == 26: + return games_26 + else: + raise ValueError(f"Unsupported number of games: {num_games}. Supported values are 3, 8, 26.") + + +def main(): + """ + Overview: + Main function to configure and launch the multi-task training process. 
+ """ + # ============================================================== + # Primary Hyperparameters + # ============================================================== + # --- Experiment --- + num_games = 8 # Options: 3, 8, 26 + seeds = [0] + max_env_step = int(4e5) + benchmark_name = "atari" + + # --- Curriculum --- + curriculum_stage_num = 5 + + # --- Environment and Agent --- + action_space_size = 18 + num_simulations = 50 + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + + # --- Collector and Evaluator --- + collector_env_num = 8 + evaluator_env_num = 3 + n_episode = 8 + num_segments = 8 + + # --- Reanalysis --- + reanalyze_ratio = 0.0 + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ============================================================== + # Derived Configurations + # ============================================================== + env_id_list = get_env_id_list(num_games) + target_return_dict = get_atari_target_return_dict(ratio=1.0) + + # --- Batch Size Calculation --- + if num_games == 8: + effective_batch_size = 512 + elif num_games == 26: + effective_batch_size = 512 # For ViT-Base encoder + else: + # Default or other cases + effective_batch_size = 512 + + batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size) + # Note: `total_batch_size` is passed to the config but `effective_batch_size` is used for calculation. + # This maintains consistency with the original script's logic. + total_batch_size = effective_batch_size + + # ============================================================== + # Launch Training + # ============================================================== + from lzero.entry import train_unizero_multitask_balance_segment_ddp + + for seed in seeds: + configs = generate_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_sizes=batch_sizes, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + target_return_dict=target_return_dict, + curriculum_stage_num=curriculum_stage_num + ) + + with DDPContext(): + train_unizero_multitask_balance_segment_ddp( + configs, + seed=seed, + max_env_step=max_env_step, + benchmark_name=benchmark_name + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py index a08064748..33de7eea0 100644 --- a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py @@ -1,9 +1,96 @@ from easydict import EasyDict +import math +from typing import List, Tuple, Any, Dict, Union -def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, - num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, - norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, - total_batch_size): +# ------------------------------------------------- +# 1. 
Refactored compute_batch_config +# ------------------------------------------------- +def compute_batch_config( + env_id_list: List[str], + effective_batch_size: int, + gpu_num: int = 8, + max_micro_batch_one_gpu: int = 400, +) -> Tuple[List[int], int]: + """ + Overview: + Calculate the micro-batch size for each environment and the number of gradient accumulation steps + to approach a target effective batch size across multiple GPUs and environments. + + Arguments: + - env_id_list (:obj:`List[str]`): A list of environment IDs for all tasks. + - effective_batch_size (:obj:`int`): The target global batch size for one backward pass. + - gpu_num (:obj:`int`): The number of GPUs actually used. Defaults to 8. + - max_micro_batch_one_gpu (:obj:`int`): The maximum micro-batch size a single GPU can handle. Defaults to 400. + + Returns: + - batch_sizes (:obj:`List[int]`): A list of micro-batch sizes for each environment. + - grad_acc_steps (:obj:`int`): The number of gradient accumulation steps. + """ + n_env = len(env_id_list) + # Number of environments that each GPU needs to handle simultaneously. + envs_per_gpu = max(1, math.ceil(n_env / gpu_num)) + # Reduce the micro-batch limit if multiple environments share one GPU. + max_micro_batch = max(1, max_micro_batch_one_gpu // envs_per_gpu) + + # First, calculate a candidate micro-batch by distributing the effective batch size evenly. + candidate = max(1, effective_batch_size // n_env) + micro_batch = min(candidate, max_micro_batch) + + # Gradient accumulation steps = ceil(global_batch / (micro_batch * n_env)). + grad_acc_steps = max(1, math.ceil(effective_batch_size / (micro_batch * n_env))) + + # Fine-tune the micro-batch downwards to ensure: + # micro_batch * n_env * grad_acc_steps <= effective_batch_size + # This aims to get as close as possible to the target without exceeding it. + while micro_batch * n_env * grad_acc_steps > effective_batch_size: + micro_batch -= 1 + if micro_batch == 0: # Defensive check, should not happen in theory. + micro_batch = 1 + break + + batch_sizes = [micro_batch] * n_env + + # --- Debug Information --- # + real_total_batch_size = micro_batch * n_env * grad_acc_steps + print( + f"[BatchConfig] Envs={n_env}, TargetTotalBS={effective_batch_size}, " + f"MicroBS={micro_batch}, GradAccSteps={grad_acc_steps}, RealTotalBS={real_total_batch_size}" + ) + + return batch_sizes, grad_acc_steps + +def create_config( + env_id: str, action_space_size: int, collector_env_num: int, evaluator_env_num: int, n_episode: int, + num_simulations: int, reanalyze_ratio: float, batch_size: int, num_unroll_steps: int, + infer_context_length: int, norm_type: str, buffer_reanalyze_freq: float, reanalyze_batch_size: int, + reanalyze_partition: float, num_segments: int, total_batch_size: int, num_layers: int +) -> EasyDict: + """ + Overview: + Creates the main configuration structure for a single training task. + + Arguments: + - env_id (:obj:`str`): The environment ID. + - action_space_size (:obj:`int`): The size of the action space. + - collector_env_num (:obj:`int`): Number of environments for data collection. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. + - n_episode (:obj:`int`): Number of episodes to run for evaluation. + - num_simulations (:obj:`int`): Number of simulations in MCTS. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed samples in a batch. + - batch_size (:obj:`int`): The batch size for training. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model dynamics. 
+ - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization layer to use (e.g., 'LN'). + - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the replay buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalysis. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalysis. + - num_segments (:obj:`int`): Number of segments for data collection. + - total_batch_size (:obj:`int`): The total effective batch size. + - num_layers (:obj:`int`): Number of layers in the transformer model. + + Returns: + - (:obj:`EasyDict`): A configuration object. + """ return EasyDict(dict( env=dict( stop_value=int(1e6), @@ -15,116 +102,162 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False), full_action_space=True, - # collect_max_episode_steps=int(5e3), - # eval_max_episode_steps=int(5e3), - # ===== only for debug ===== - collect_max_episode_steps=int(20), - eval_max_episode_steps=int(20), + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + + # collect_max_episode_steps=int(50), # debug + # eval_max_episode_steps=int(50), ), policy=dict( - multi_gpu=True, # Very important for ddp + multi_gpu=True, # Essential for DDP (Distributed Data Parallel) only_use_moco_stats=False, - use_moco=False, # ==============TODO============== - # use_moco=True, # ==============TODO============== + use_moco=False, learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), grad_correct_params=dict( MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, calpha=0.5, rescale=1, ), + moco_version="v1", total_task_num=len(env_id_list), task_num=len(env_id_list), - task_id=0, + task_id=0, # This will be overridden for each task model=dict( observation_shape=(3, 64, 64), action_space_size=action_space_size, norm_type=norm_type, num_res_blocks=2, num_channels=256, - # num_channels=512, # ==============TODO============== continuous_action_space=False, world_model_cfg=dict( - # use_adaptive_scale=True, - use_adaptive_scale=False, - + num_res_blocks=2, + num_channels=256, + norm_type=norm_type, + use_global_pooling=False, final_norm_option_in_obs_head='LayerNorm', final_norm_option_in_encoder='LayerNorm', - predict_latent_loss_type='mse', # TODO: for latent state layer_norm - - # final_norm_option_in_obs_head='SimNorm', - # final_norm_option_in_encoder='SimNorm', - # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm - - share_head=False, # TODO - - analysis_dormant_ratio_weight_rank=False, # TODO - dormant_threshold=0.025, - + predict_latent_loss_type='mse', + share_head=False, + analysis_dormant_ratio_weight_rank=False, + # analysis_dormant_ratio_weight_rank=True, + # analysis_dormant_ratio_interval=5000, continuous_action_space=False, - - task_embed_option=None, # ==============TODO: none ============== - use_task_embed=False, # ==============TODO============== - - # task_embed_option='concat_task_embed', # ==============TODO: none ============== - # use_task_embed=True, # ==============TODO============== - # task_embed_dim=128, - # # task_embed_dim=96, - + task_embed_option=None, + use_task_embed=False, use_shared_projection=False, max_blocks=num_unroll_steps, max_tokens=2 * num_unroll_steps, context_length=2 * infer_context_length, device='cuda', action_space_size=action_space_size, - num_layers=8, - num_heads=24, - - # ===== only for debug ===== - # 
num_layers=1, - # num_heads=8, - + num_layers=num_layers, + # num_heads=24, + num_heads=8, embed_dim=768, obs_type='image', - env_num=8, + env_num=len(env_id_list), task_num=len(env_id_list), + # game_segment_length=game_segment_length, + game_segment_length=20, # TODO + use_priority=True, + # use_priority=False, # TODO===== + priority_prob_alpha=1, + priority_prob_beta=1, + # encoder_type='vit', + encoder_type='resnet', use_normal_head=True, use_softmoe_head=False, use_moe_head=False, num_experts_in_moe_head=4, moe_in_transformer=False, - multiplication_moe_in_transformer=False, - num_experts_of_moe_in_transformer=4, - # LoRA parameters: - lora_r= 0, - lora_alpha =1, - lora_dropout= 0.0, + multiplication_moe_in_transformer=True, + # multiplication_moe_in_transformer=False, # TODO===== + + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + # LoRA parameters + moe_use_lora=False, + lora_r=0, + lora_alpha=1, + lora_dropout=0.0, + + + optim_type='AdamW_mix_lr_wdecay', # only for tsne plot ), ), - use_task_exploitation_weight=False, # TODO - task_complexity_weight=False, # TODO + optim_type='AdamW_mix_lr_wdecay', + weight_decay=1e-2, # TODO: use 5*wd for the encoder, wd for the transformer, and 0 for the heads + learning_rate=0.0001, + + # (bool) Whether to enable the adaptive policy entropy weight (alpha). + use_adaptive_entropy_weight=True, + # use_adaptive_entropy_weight=False, + + # (float) Learning rate of the optimizer for the adaptive alpha. + adaptive_entropy_alpha_lr=1e-4, + target_entropy_start_ratio =0.98, + # target_entropy_end_ratio =0.9, # TODO===== + # target_entropy_end_ratio =0.7, + # target_entropy_decay_steps = 100000, # e.g., reach the final value after 100k iterations + + target_entropy_end_ratio =0.5, # for action_space=18 + target_entropy_decay_steps = 100000, # e.g., reach the final value after 100k training iterations + # target_entropy_decay_steps = 150000, # e.g., reach the final value after 150k iterations (300k env steps) + + # ==================== START: Encoder-Clip Annealing Config ==================== + # (bool) Whether to enable annealing of the encoder-clip value. + use_encoder_clip_annealing=True, + # (str) Annealing type. Options: 'linear' or 'cosine'. + encoder_clip_anneal_type='cosine', + # (float) Starting clip value for the anneal (looser, early in training). + encoder_clip_start_value=30.0, + # (float) Final clip value for the anneal (stricter, late in training). + encoder_clip_end_value=10.0, + # (int) Number of training iterations over which to anneal from the start value to the end value. + encoder_clip_anneal_steps=100000, # e.g., reach the final value after 100k iterations + # encoder_clip_anneal_steps=50000, # e.g., reach the final value after 50k iterations + + + # ==================== START: Label Smoothing ==================== + policy_ls_eps_start=0.05, # TODO: good starting value for Pong and MsPacman + policy_ls_eps_end=0.01, + policy_ls_eps_decay_steps=50000, # 50k + label_smoothing_eps=0.1, # TODO: for the value targets + + # ==================== [New] Norm Monitoring Frequency ==================== + # Monitor the norms of the model parameters every this many training iterations. Set to 0 to disable. + monitor_norm_freq=10000, + # monitor_norm_freq=2, # only for debug + + use_task_exploitation_weight=False, + task_complexity_weight=False, total_batch_size=total_batch_size, allocated_batch_sizes=False, - train_start_after_envsteps=int(0), # TODO: ===== only for debug ===== - # train_start_after_envsteps=int(2000), - use_priority=False, + train_start_after_envsteps=int(0), + # use_priority=False, # TODO===== + use_priority=True, + priority_prob_alpha=1, + priority_prob_beta=1, print_task_priority_logs=False, cuda=True, model_path=None, num_unroll_steps=num_unroll_steps, game_segment_length=20, - update_per_collect=2, # TODO: ===== only for debug ===== - # update_per_collect=80, # TODO + update_per_collect=80, # Corresponds to replay_ratio=0.5 for 8 games (20*8*0.5=80) replay_ratio=0.25, batch_size=batch_size, - optim_type='AdamW', - # 
cos_lr_scheduler=True, + # optim_type='AdamW', cos_lr_scheduler=False, num_segments=num_segments, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, n_episode=n_episode, replay_buffer_size=int(5e5), - eval_freq=int(2e4), + # eval_freq=int(2e4), # Evaluation frequency for 26 games + eval_freq=int(1e4), # Evaluation frequency for 8 games + # eval_freq=int(1e4), # Evaluation frequency for 8 games + # eval_freq=int(2), # ======== TODO: only for debug======== collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, buffer_reanalyze_freq=buffer_reanalyze_freq, @@ -133,27 +266,62 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu ), )) -def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, - num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, - norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, - num_segments, total_batch_size): +def generate_configs( + env_id_list: List[str], action_space_size: int, collector_env_num: int, n_episode: int, + evaluator_env_num: int, num_simulations: int, reanalyze_ratio: float, batch_size: List[int], + num_unroll_steps: int, infer_context_length: int, norm_type: str, seed: int, + buffer_reanalyze_freq: float, reanalyze_batch_size: int, reanalyze_partition: float, + num_segments: int, total_batch_size: int, num_layers: int +) -> List[List[Union[int, List[EasyDict]]]]: + """ + Overview: + Generates a list of configurations for all specified tasks. + + Arguments: + (See arguments for `create_config` function) + - seed (:obj:`int`): The random seed for the experiment. + + Returns: + - (:obj:`List[List[Union[int, List[EasyDict]]]]`): A list where each element contains a task_id + and its corresponding configuration objects. + """ configs = [] - # ===== only for debug ===== - exp_name_prefix = f'data_lz/data_unizero_atari_mt_20250425_debug/atari_{len(env_id_list)}games_tbs1536-encoderchannel256-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_encoder-final-ln_seed{seed}/' + # --- Experiment Name Template --- + # Replace placeholders like [BENCHMARK_TAG] and [MODEL_TAG] to define the experiment name. 
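# Editor's illustrative sketch (hedged), not part of the patch: the `encoder_clip_*` fields added in
# `create_config` above describe an annealed clip value; the actual schedule lives in the policy
# implementation. Assuming a standard cosine interpolation, the hypothetical helper below shows the
# intended 30.0 -> 10.0 decay over `encoder_clip_anneal_steps` iterations.
import math

def cosine_anneal(step: int, start: float = 30.0, end: float = 10.0, total_steps: int = 100000) -> float:
    progress = min(max(step, 0), total_steps) / total_steps
    return end + 0.5 * (start - end) * (1.0 + math.cos(math.pi * progress))

assert abs(cosine_anneal(0) - 30.0) < 1e-6 and abs(cosine_anneal(100_000) - 10.0) < 1e-6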
+ # benchmark_tag = "data_unizero_mt_refactor1010_debug" # e.g., unizero_atari_mt_20250612 + benchmark_tag = "data_unizero_mt_refactor1012" # e.g., unizero_atari_mt_20250612 + # model_tag = f"vit-small_moe8_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}_not-share-head" + # model_tag = f"resnet_noprior_noalpha_nomoe_head-inner-ln_adamw-wd1e-2_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}" + + # model_tag = f"vit_prior_alpha-100k-098-07_encoder-100k-30-10_moe8_head-inner-ln_adamw-wd1e-2_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}" + + # model_tag = f"resnet_encoder-100k-30-10-true_label-smooth_prior_alpha-100k-098-07_moe8_head-inner-ln_adamw-wd1e-2-all_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}" + model_tag = f"resnet_tran-nlayer{num_layers}_moe8_encoder-100k-30-10-true_alpha-100k-098-05_prior_adamw-wd1e-2-all_tbs512_brf{buffer_reanalyze_freq}_label-smooth_head-inner-ln" + # model_tag = f"resnet_encoder-100k-30-10-true_label-smooth_prior_alpha-150k-098-05_moe8_head-inner-ln_adamw-wd1e-2-all_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}" + + exp_name_prefix = f'{benchmark_tag}/atari_{len(env_id_list)}games_{model_tag}_seed{seed}/' for task_id, env_id in enumerate(env_id_list): config = create_config( env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, - buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size, num_layers ) config.policy.task_id = task_id config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_seed{seed}" configs.append([task_id, [config, create_env_manager()]]) return configs -def create_env_manager(): +def create_env_manager() -> EasyDict: + """ + Overview: + Creates the environment manager configuration, specifying the types of environment, + policy, and their import paths. + + Returns: + - (:obj:`EasyDict`): A configuration object for the environment manager. + """ return EasyDict(dict( env=dict( type='atari_lightzero', @@ -169,75 +337,110 @@ def create_env_manager(): if __name__ == "__main__": """ Overview: - This script should be executed with GPUs. + This script should be executed with GPUs for distributed training. 
Run the following command to launch the script: - python -m torch.distributed.launch --nproc_per_node=8 --master_port=29504 ./zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee ./log/uz_mt_atari26_channel256_debug.log - torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py - """ + Example launch command: + export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + + export CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 + + export CUDA_VISIBLE_DEVICES=4,5,6,7 + + cd /path/to/your/project/ + python -m torch.distributed.launch --nproc_per_node=6 --master_port=29502 /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee /mnt/nfs/zhangjinouwen/puyuan/LightZero/log/20251012_resnet_nlayer4_alpha-100k-098-05.log + /path/to/this/script.py 2>&1 | tee /path/to/your/log/file.log + """ from lzero.entry import train_unizero_multitask_segment_ddp from ding.utils import DDPContext + import torch.distributed as dist import os - env_id_list = [ - 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', - 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', - ] - - # List of Atari games used for multi-task learning - env_id_list = [ - 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', - 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', - 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', - 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', - 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', - 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', - 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', - ] - + # --- Main Experiment Settings --- + num_games = 8 # Options: 3, 8, 26 + num_layers = 4 + # num_layers = 2 # debug action_space_size = 18 collector_env_num = 8 num_segments = 8 n_episode = 8 evaluator_env_num = 3 num_simulations = 50 - max_env_step = int(5e5) + # max_env_step = int(4e5) + max_env_step = int(5e6) # TODO reanalyze_ratio = 0.0 - total_batch_size =int(512*3) - batch_size = [int(total_batch_size / len(env_id_list)) for _ in range(len(env_id_list))] + if num_games == 3: + env_id_list = ['PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4'] + elif num_games == 8: + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + elif num_games == 26: + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 
'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + else: + raise ValueError(f"Unsupported number of environments: {num_games}") - # total_batch_size = 512 - # batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] - + # --- Batch Size Calculation --- + # The effective batch size is adjusted based on the number of games and model size (layers) + # to fit within GPU memory constraints. + if len(env_id_list) == 8: + if num_layers in [2, 4]: + effective_batch_size = 512 + elif num_layers == 8: + effective_batch_size = 512 + elif len(env_id_list) == 26: + effective_batch_size = 512 + elif len(env_id_list) == 18: + effective_batch_size = 1536 + elif len(env_id_list) == 3: + effective_batch_size = 10 # For debugging + else: + raise ValueError(f"Batch size not configured for {len(env_id_list)} environments.") + + batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size, gpu_num=6) # TODO + total_batch_size = effective_batch_size # Currently for logging purposes + + # --- Model and Training Settings --- num_unroll_steps = 10 infer_context_length = 4 norm_type = 'LN' - # buffer_reanalyze_freq = 1 / 50 - buffer_reanalyze_freq = 1 / 1000000 + buffer_reanalyze_freq = 1 / 100000000 # Effectively disable buffer reanalyze reanalyze_batch_size = 160 reanalyze_partition = 0.75 - # ======== TODO: only for debug ======== - collector_env_num = 2 - num_segments = 2 - n_episode = 2 - evaluator_env_num = 2 - num_simulations = 1 - reanalyze_batch_size = 2 - num_unroll_steps = 5 - infer_context_length = 2 - batch_size = [4, 4, 4, 4, 4, 4, 4, 4] + # ====== only for debug ===== + # num_games = 8 # Options: 3, 8, 26 + # num_layers = 2 # debug + # collector_env_num = 2 + # num_segments = 2 + # evaluator_env_num = 2 + # num_simulations = 5 + # batch_sizes = [num_games] * len(env_id_list) + # buffer_reanalyze_freq = 1/100000000 + # total_batch_size = num_games * len(env_id_list) + # --- Training Loop --- for seed in [0]: - configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, - num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, - norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, - num_segments, total_batch_size) + configs = generate_configs( + env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_sizes, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size, num_layers + ) with DDPContext(): - # train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) - - # ======== TODO: only for debug ======== - train_unizero_multitask_segment_ddp(configs[:8], seed=seed, max_env_step=max_env_step) # train on the first four tasks + train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step, benchmark_name="atari") + print(f"Seed: {seed} training finished!") + if dist.is_initialized(): + dist.destroy_process_group() \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py b/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py index 29de4f112..b7973ff87 100644 --- a/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py +++ b/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py @@ -1,6 +1,79 @@ from easydict import EasyDict +from typing import List, 
Any, Dict -def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): +# ============================================================== +# Environment and Policy Manager Configuration +# ============================================================== + +def create_env_manager() -> EasyDict: + """ + Overview: + Creates the configuration for the environment and policy managers. + This config specifies the types and import paths for core components + like the environment wrapper and the policy definition. + Returns: + - manager_config (:obj:`EasyDict`): A dictionary containing the types and import names + for the environment and policy managers. + """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +# ============================================================== +# Main Configuration Generation +# ============================================================== + +def create_config( + env_id: str, + action_space_size: int, + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + env_id_list: List[str], +) -> EasyDict: + """ + Overview: + Creates the main configuration dictionary for a single task in a multi-task setup. + Arguments: + - env_id (:obj:`str`): The ID of the environment for this specific task. + - action_space_size (:obj:`int`): The size of the action space for the model. + - collector_env_num (:obj:`int`): The number of environments for the data collector. + - evaluator_env_num (:obj:`int`): The number of environments for the evaluator. + - n_episode (:obj:`int`): The number of episodes to run for collection. + - num_simulations (:obj:`int`): The number of simulations for the MCTS algorithm. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in the replay buffer. + - batch_size (:obj:`List[int]`): The batch size for training, specified per task. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model during training. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization to use (e.g., 'LN' for LayerNorm). + - buffer_reanalyze_freq (:obj:`float`): The frequency at which to reanalyze the buffer. + - reanalyze_batch_size (:obj:`int`): The batch size for reanalyzing data. + - reanalyze_partition (:obj:`float`): The partition ratio for reanalyzing data. + - num_segments (:obj:`int`): The number of segments for game data. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + - env_id_list (:obj:`List[str]`): The list of all environment IDs in the multi-task setup. + Returns: + - config (:obj:`EasyDict`): The complete configuration for a single training task. 
+ """ return EasyDict(dict( env=dict( stop_value=int(1e6), @@ -23,7 +96,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu MoCo_rho=0, calpha=0.5, rescale=1, ), task_num=len(env_id_list), - task_id=0, + task_id=0, # Placeholder, will be set in generate_configs model=dict( observation_shape=(3, 64, 64), action_space_size=action_space_size, @@ -32,7 +105,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu num_channels=256, world_model_cfg=dict( env_id_list=env_id_list, - analysis_tsne=True, # TODO + # TODO: Implement and verify the t-SNE analysis functionality. + analysis_tsne=True, max_blocks=num_unroll_steps, max_tokens=2 * num_unroll_steps, context_length=2 * infer_context_length, @@ -40,10 +114,9 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu action_space_size=action_space_size, num_layers=8, # Transformer layers num_heads=8, - # num_heads=24, embed_dim=768, obs_type='image', - env_num=8, + env_num=len(env_id_list), task_num=len(env_id_list), use_normal_head=True, use_softmoe_head=False, @@ -79,9 +152,71 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu ), )) -def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): + +def _generate_exp_name_prefix( + exp_base_path: str, + num_games: int, + buffer_reanalyze_freq: float, + norm_type: str, + seed: int +) -> str: + """ + Overview: + Generates a standardized prefix for the experiment name based on key hyperparameters. + Arguments: + - exp_base_path (:obj:`str`): The base directory for the experiment logs. + - num_games (:obj:`int`): The number of games in the multi-task setup. + - buffer_reanalyze_freq (:obj:`float`): The frequency of buffer reanalysis. + - norm_type (:obj:`str`): The normalization type used in the model. + - seed (:obj:`int`): The random seed for the experiment. + Returns: + - (:obj:`str`): The generated experiment name prefix. + """ + # NOTE: This name is constructed based on a specific convention to encode hyperparameters. + # It includes details about the model architecture, training parameters, and environment setup. + return ( + f'{exp_base_path}/{num_games}games_brf{buffer_reanalyze_freq}_' + f'1-encoder-{norm_type}-res2-channel256_gsl20_{num_games}-pred-head_' + f'nlayer8-nh24-lsd768_seed{seed}/' + ) + + +def generate_configs( + env_id_list: List[str], + action_space_size: int, + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + exp_base_path: str, +) -> List[List[Any]]: + """ + Overview: + Generates a list of configurations for each task in a multi-task training setup. + Each configuration is paired with an environment manager config. + Arguments: + - (All arguments from create_config, plus): + - seed (:obj:`int`): The random seed for the experiment, used for naming. + - exp_base_path (:obj:`str`): The base path for saving experiment results. 
+ Returns: + - configs (:obj:`List[List[Any]]`): A list where each item contains + [task_id, [task_specific_config, env_manager_config]]. + """ configs = [] - exp_name_prefix = f'data_unizero_mt_ddp-8gpu_eval-latent_state_tsne/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_nlayer8-nh24-lsd768_seed{seed}/' + exp_name_prefix = _generate_exp_name_prefix( + exp_base_path, len(env_id_list), buffer_reanalyze_freq, norm_type, seed + ) for task_id, env_id in enumerate(env_id_list): config = create_config( @@ -89,79 +224,110 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, - num_segments, total_batch_size + num_segments, total_batch_size, env_id_list ) + # Assign the specific task ID for this configuration config.policy.task_id = task_id - config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + # Set the full experiment name for logging and checkpointing + env_name = env_id.split('NoFrameskip')[0] + config.exp_name = exp_name_prefix + f"{env_name}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs -def create_env_manager(): - return EasyDict(dict( - env=dict( - type='atari_lightzero', - import_names=['zoo.atari.envs.atari_lightzero_env'], - ), - env_manager=dict(type='subprocess'), - policy=dict( - type='unizero_multitask', - import_names=['lzero.policy.unizero_multitask'], - ), - )) +# ============================================================== +# Main execution block +# ============================================================== if __name__ == "__main__": """ Overview: - This program is designed to obtain the t-SNE of the latent states in 8games multi-task learning. + This program is designed to obtain the t-SNE of the latent states in multi-task learning + across a set of Atari games (e.g., 8 games). + + This script should be executed with GPUs for Distributed Data Parallel (DDP) training. + Run one of the following commands to launch the script: + + Using `torch.distributed.launch` (deprecated): + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./path/to/this/script.py - This script should be executed with GPUs. 
- Run the following command to launch the script: - python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_eval_config.py - torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_eval_config.py + Using `torchrun` (recommended): + torchrun --nproc_per_node=8 ./path/to/this/script.py """ - from lzero.entry import train_unizero_multitask_segment_eval from ding.utils import DDPContext + # --- Basic Environment and Model Setup --- env_id_list = [ 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', ] + action_space_size = 18 # Standard action space size for Atari games - action_space_size = 18 - - for seed in [0]: - collector_env_num = 2 - num_segments = 2 - n_episode = 2 - evaluator_env_num = 2 - num_simulations = 50 - max_env_step = int(4e5) - reanalyze_ratio = 0.0 - total_batch_size = int(4*len(env_id_list)) - batch_size = [4 for _ in range(len(env_id_list))] - num_unroll_steps = 10 - infer_context_length = 4 - norm_type = 'LN' - buffer_reanalyze_freq = 1/50 - reanalyze_batch_size = 160 - reanalyze_partition = 0.75 + # --- Hyperparameter Configuration --- + # Grouping hyperparameters for better readability and management. + main_hyperparams = { + 'seed': 0, + 'collector_env_num': 2, + 'evaluator_env_num': 2, + 'n_episode': 2, + 'num_simulations': 50, + 'max_env_step': int(4e5), + 'reanalyze_ratio': 0.0, + 'num_segments': 2, + 'num_unroll_steps': 10, + 'infer_context_length': 4, + 'norm_type': 'LN', + 'buffer_reanalyze_freq': 1/50, + 'reanalyze_batch_size': 160, + 'reanalyze_partition': 0.75, + 'total_batch_size': int(4 * len(env_id_list)), + 'batch_size_per_task': 4, + # --- Path for experiment logs and pretrained model --- + # NOTE: Please update these paths to your local directory structure. + 'exp_base_path': 'data/unizero_mt_ddp-8gpu_eval-latent_state_tsne', + # Example for an 8-game pretrained model + 'pretrained_model_path': '/path/to/your/pretrained_model.pth.tar', + # Example for a 26-game pretrained model + # 'pretrained_model_path': '/path/to/your/26_game_model.pth.tar', + } + # --- Generate Configurations for each seed --- + # This loop allows running experiments with multiple seeds easily. + for seed in [main_hyperparams['seed']]: + # The batch size is a list, with one entry per task. 
+ batch_size_list = [main_hyperparams['batch_size_per_task']] * len(env_id_list) + # Generate the list of configurations for the trainer configs = generate_configs( - env_id_list, action_space_size, collector_env_num, n_episode, - evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, - num_unroll_steps, infer_context_length, norm_type, seed, - buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, - num_segments, total_batch_size + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=main_hyperparams['collector_env_num'], + n_episode=main_hyperparams['n_episode'], + evaluator_env_num=main_hyperparams['evaluator_env_num'], + num_simulations=main_hyperparams['num_simulations'], + reanalyze_ratio=main_hyperparams['reanalyze_ratio'], + batch_size=batch_size_list, + num_unroll_steps=main_hyperparams['num_unroll_steps'], + infer_context_length=main_hyperparams['infer_context_length'], + norm_type=main_hyperparams['norm_type'], + seed=seed, + buffer_reanalyze_freq=main_hyperparams['buffer_reanalyze_freq'], + reanalyze_batch_size=main_hyperparams['reanalyze_batch_size'], + reanalyze_partition=main_hyperparams['reanalyze_partition'], + num_segments=main_hyperparams['num_segments'], + total_batch_size=main_hyperparams['total_batch_size'], + exp_base_path=main_hyperparams['exp_base_path'], ) - # Pretrained model paths - # 8games - pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu_1127/8games_brf0.02_nlayer8-nhead24_seed1/8games_brf0.02_1-encoder-LN-res2-channel256_gsl20_8-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed1/Pong_unizero-mt_seed1/ckpt/iteration_200000.pth.tar' - # 26games - # pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu-26game_1127/26games_brf0.02_nlayer8-nhead24_seed0/26games_brf0.02_1-encoder-LN-res2-channel256_gsl20_26-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed0/Pong_unizero-mt_seed0/ckpt/iteration_150000.pth.tar' - + # --- Launch Training --- + # Use DDPContext to manage the distributed training environment. with DDPContext(): - train_unizero_multitask_segment_eval(configs, seed=seed, model_path=pretrained_model_path, max_env_step=max_env_step) \ No newline at end of file + train_unizero_multitask_segment_eval( + configs, + seed=seed, + model_path=main_hyperparams['pretrained_model_path'], + max_env_step=main_hyperparams['max_env_step'] + ) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py b/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py index badcd9585..3581839b2 100644 --- a/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py +++ b/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py @@ -1,30 +1,49 @@ from easydict import EasyDict +from typing import List, Tuple, Union, Any, Dict -def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): - return EasyDict(dict( - env=dict( +class UniZeroAtariConfig: + """ + Overview: + Default configuration class for UniZero Atari experiments. + This class centralizes all default parameters, making it easier to manage and extend. 
+ """ + def __init__(self) -> None: + self.exp_name: str = '' + self.env: EasyDict = self._get_default_env_config() + self.policy: EasyDict = self._get_default_policy_config() + + @staticmethod + def _get_default_env_config() -> EasyDict: + """ + Overview: + Returns the default environment configuration. + """ + return EasyDict(dict( stop_value=int(1e6), - env_id=env_id, + env_id='PongNoFrameskip-v4', observation_shape=(3, 64, 64), gray_scale=False, - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - n_evaluator_episode=evaluator_env_num, + collector_env_num=8, + evaluator_env_num=3, + n_evaluator_episode=3, manager=dict(shared_memory=False), full_action_space=True, collect_max_episode_steps=int(5e3), eval_max_episode_steps=int(5e3), - # ===== only for debug ===== - # collect_max_episode_steps=int(20), - # eval_max_episode_steps=int(20), - ), - policy=dict( + )) + + @staticmethod + def _get_default_policy_config() -> EasyDict: + """ + Overview: + Returns the default policy configuration. + """ + return EasyDict(dict( multi_gpu=True, - only_use_moco_stats=False, - use_moco=False, # ==============TODO============== - # use_moco=True, # ==============TODO============== + # ==============TODO============== + use_moco=False, learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000))), - grad_correct_params=dict( # Gradient correction parameters + grad_correct_params=dict( MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, @@ -33,50 +52,47 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu calpha=0.5, rescale=1, ), - task_num=len(env_id_list), + task_num=1, task_id=0, model=dict( observation_shape=(3, 64, 64), - action_space_size=action_space_size, - norm_type=norm_type, + action_space_size=18, + norm_type='LN', num_res_blocks=2, num_channels=256, world_model_cfg=dict( + # TODO: for latent state layer_norm final_norm_option_in_obs_head='LayerNorm', final_norm_option_in_encoder='LayerNorm', - predict_latent_loss_type='mse', # TODO: for latent state layer_norm - + predict_latent_loss_type='mse', + # TODO: only for latent state sim_norm # final_norm_option_in_obs_head='SimNorm', # final_norm_option_in_encoder='SimNorm', - # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm - - share_head=False, # TODO - analysis_dormant_ratio_weight_rank=False, # TODO + # predict_latent_loss_type='group_kl', + share_head=False, # TODO + analysis_dormant_ratio_weight_rank=False, # TODO dormant_threshold=0.025, - continuous_action_space=False, - - task_embed_option=None, # ==============TODO: none ============== - use_task_embed=False, # ==============TODO============== - - # task_embed_option='concat_task_embed', # ==============TODO: none ============== - # use_task_embed=True, # ==============TODO============== + # ==============TODO: none ============== + task_embed_option=None, + use_task_embed=False, + # ==============TODO============== + # task_embed_option='concat_task_embed', + # use_task_embed=True, # task_embed_dim=96, # task_embed_dim=128, - use_shared_projection=False, - - max_blocks=num_unroll_steps, - max_tokens=2 * num_unroll_steps, - context_length=2 * infer_context_length, + max_blocks=10, # num_unroll_steps + max_tokens=20, # 2 * num_unroll_steps + context_length=8, # 2 * infer_context_length device='cuda', - action_space_size=action_space_size, + action_space_size=18, num_layers=8, num_heads=24, embed_dim=768, obs_type='image', env_num=8, - task_num=len(env_id_list), + task_num=1, use_normal_head=True, 
use_softmoe_head=False, use_moe_head=False, @@ -84,84 +100,205 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu moe_in_transformer=False, multiplication_moe_in_transformer=False, num_experts_of_moe_in_transformer=4, - - # LoRA 参数(启用LoRA) + # LoRA parameters (enable LoRA by setting lora_r > 0) lora_r=0, # lora_r=8, lora_alpha=32, lora_dropout=0.1, - # 默认目标模块:attn和feed_forward + # Default target modules: attn and feed_forward lora_target_modules=["attn", "feed_forward"], - # 调整finetune_components ), ), - use_task_exploitation_weight=False, # TODO - task_complexity_weight=False, # TODO - total_batch_size=total_batch_size, + # TODO + use_task_exploitation_weight=False, + task_complexity_weight=False, + total_batch_size=512, allocated_batch_sizes=False, train_start_after_envsteps=int(0), use_priority=False, print_task_priority_logs=False, cuda=True, model_path=None, - num_unroll_steps=num_unroll_steps, + num_unroll_steps=10, game_segment_length=20, update_per_collect=80, replay_ratio=0.25, - batch_size=batch_size, + batch_size=64, optim_type='AdamW', cos_lr_scheduler=True, - num_segments=num_segments, - num_simulations=num_simulations, - reanalyze_ratio=reanalyze_ratio, - n_episode=n_episode, + num_segments=8, + num_simulations=50, + reanalyze_ratio=0.0, + n_episode=8, replay_buffer_size=int(5e5), eval_freq=int(2e4), - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - buffer_reanalyze_freq=buffer_reanalyze_freq, - reanalyze_batch_size=reanalyze_batch_size, - reanalyze_partition=reanalyze_partition, - ), - )) + collector_env_num=8, + evaluator_env_num=3, + buffer_reanalyze_freq=1 / 10000000, + reanalyze_batch_size=160, + reanalyze_partition=0.75, + )) -def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): - configs = [] - # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/amidar_load-enc-trans_finetune-head/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh8_upc80_seed{seed}/' - exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/amidar_load-enc-trans_finetune-head-encoder/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh8_upc80_seed{seed}/' - # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/amidar_load-enc-trans_finetune-head-trans/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh8_upc80_seed{seed}/' - # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/amidar_load-enc-trans_finetune-head-trans-lora/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh24_upc80_seed{seed}/' - - # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/pong_load-enc-trans_finetune-head-trans-lora/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh24_upc80_seed{seed}/' - # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/pong_load-enc-trans_finetune-head/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh24_upc80_seed{seed}/' +def 
create_config( + env_id: str, + action_space_size: int, + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: Union[int, List[int]], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + task_num: int +) -> EasyDict: + """ + Overview: + Creates and customizes a configuration for a specific Atari environment task. + + Arguments: + - env_id (:obj:`str`): The ID of the Atari environment. + - action_space_size (:obj:`int`): The size of the action space. + - collector_env_num (:obj:`int`): Number of environments for collecting data. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. + - n_episode (:obj:`int`): Number of episodes to run for each collection. + - num_simulations (:obj:`int`): Number of simulations in the MCTS. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed samples in the replay buffer. + - batch_size (:obj:`Union[int, List[int]]`): The batch size for training. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization to use. + - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalyzing. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalyzing. + - num_segments (:obj:`int`): Number of segments for each game. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + - task_num (:obj:`int`): The total number of tasks. + + Returns: + - (:obj:`EasyDict`): A fully configured EasyDict object for the experiment. 
+ """ + cfg = UniZeroAtariConfig() + + # == Update Environment Config == + cfg.env.env_id = env_id + cfg.env.collector_env_num = collector_env_num + cfg.env.evaluator_env_num = evaluator_env_num + cfg.env.n_evaluator_episode = evaluator_env_num + + # == Update Policy Config == + policy = cfg.policy + policy.task_num = task_num + policy.action_space_size = action_space_size + policy.n_episode = n_episode + policy.num_simulations = num_simulations + policy.reanalyze_ratio = reanalyze_ratio + policy.batch_size = batch_size + policy.total_batch_size = total_batch_size + policy.num_unroll_steps = num_unroll_steps + policy.collector_env_num = collector_env_num + policy.evaluator_env_num = evaluator_env_num + policy.buffer_reanalyze_freq = buffer_reanalyze_freq + policy.reanalyze_batch_size = reanalyze_batch_size + policy.reanalyze_partition = reanalyze_partition + policy.num_segments = num_segments + + # == Update Model Config == + model = policy.model + model.action_space_size = action_space_size + model.norm_type = norm_type + # == Update World Model Config == + world_model = model.world_model_cfg + world_model.max_blocks = num_unroll_steps + world_model.max_tokens = 2 * num_unroll_steps + world_model.context_length = 2 * infer_context_length + world_model.action_space_size = action_space_size + world_model.task_num = task_num + + return EasyDict(cfg) +def generate_experiment_configs( + env_id_list: List[str], + action_space_size: int, + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: Union[int, List[int]], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int +) -> List[Tuple[int, List[Union[EasyDict, Any]]]]: + """ + Overview: + Generates a list of configurations for multi-task experiments. + + Arguments: + - env_id_list (:obj:`List[str]`): List of environment IDs for the tasks. + - ... (same as create_config): Other experiment parameters. + - seed (:obj:`int`): The random seed for the experiment. + + Returns: + - (:obj:`List[Tuple[int, List[Union[EasyDict, Any]]]]`): A list where each element contains a task_id and its + corresponding configuration and environment manager setup. + """ + configs = [] + task_num = len(env_id_list) + + # --- Experiment Name Prefix --- + # This prefix defines the storage path for experiment data and logs. + # Please replace `` with your actual data storage path. 
+ exp_name_prefix_template = ( + "/data_unizero_atari_mt_finetune_{timestamp}/" + "experiment_name/{task_num}games_brf{brf}_1-encoder-{norm}-res2-channel256_" + "gsl20_lsd768-nlayer8-nh8_upc80_seed{seed}/" + ) + exp_name_prefix = exp_name_prefix_template.format( + timestamp="20250308", + task_num=task_num, + brf=buffer_reanalyze_freq, + norm=norm_type, + seed=seed + ) + for task_id, env_id in enumerate(env_id_list): config = create_config( - env_id, - action_space_size, - collector_env_num, - evaluator_env_num, - n_episode, - num_simulations, - reanalyze_ratio, - batch_size, - num_unroll_steps, - infer_context_length, - norm_type, - buffer_reanalyze_freq, - reanalyze_batch_size, - reanalyze_partition, - num_segments, - total_batch_size + env_id, action_space_size, collector_env_num, evaluator_env_num, + n_episode, num_simulations, reanalyze_ratio, batch_size, + num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size, task_num ) config.policy.task_id = task_id config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" configs.append([task_id, [config, create_env_manager()]]) return configs -def create_env_manager(): + +def create_env_manager() -> EasyDict: + """ + Overview: + Creates the environment and policy manager configuration. + This specifies the types and import paths for the environment and policy used in the experiment. + + Returns: + - (:obj:`EasyDict`): An EasyDict object containing manager configurations. + """ return EasyDict(dict( env=dict( type='atari_lightzero', @@ -174,63 +311,99 @@ def create_env_manager(): ), )) + if __name__ == "__main__": """ Overview: This script should be executed with GPUs. - Run the following command to launch the script: - python -m torch.distributed.launch --nproc_per_node=1 --master_port=29507 ./zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py - torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py + Run one of the following commands to launch the script: + - Using torch.distributed.launch: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29507 ./path/to/this/script.py + - Using torchrun: + torchrun --nproc_per_node=8 ./path/to/this/script.py """ - from lzero.entry import train_unizero_multitask_segment_ddp from ding.utils import DDPContext - from easydict import EasyDict + import os - # env_id_list = ['PongNoFrameskip-v4'] # Debug setup - env_id_list = ['AmidarNoFrameskip-v4'] # Debug setup + # --- Main Experiment Settings --- + # Use DEBUG mode for fast iteration and debugging. 
+ DEBUG = False + # --- Environment and Task Settings --- + env_id_list = ['AmidarNoFrameskip-v4'] action_space_size = 18 - # NCCL environment setup - import os + # --- Distributed Training Settings --- os.environ["NCCL_TIMEOUT"] = "3600000000" - # for seed in [0, 1, 2]: + # --- Loop over seeds for multiple runs --- for seed in [0]: - collector_env_num = 8 - num_segments = 8 - n_episode = 8 - evaluator_env_num = 3 - num_simulations = 50 - max_env_step = int(4e5) + # --- Core Algorithm Parameters --- + if DEBUG: + # Settings for quick debugging + collector_env_num = 2 + num_segments = 2 + n_episode = 2 + evaluator_env_num = 2 + num_simulations = 2 + total_batch_size = 32 + batch_size = [int(total_batch_size / len(env_id_list))] * len(env_id_list) + reanalyze_batch_size = 4 + max_env_step = int(1e3) + else: + # Standard experiment settings + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list)))] * len(env_id_list) + reanalyze_batch_size = 160 + max_env_step = int(4e5) + # --- Shared Parameters --- reanalyze_ratio = 0.0 - total_batch_size = 512 - batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] - num_unroll_steps = 10 infer_context_length = 4 norm_type = 'LN' - # buffer_reanalyze_freq = 1 / 50 - buffer_reanalyze_freq = 1 / 10000000 - reanalyze_batch_size = 160 + buffer_reanalyze_freq = 1 / 10000000 # Effectively disabled reanalyze_partition = 0.75 - # ======== TODO: only for debug ======== - # collector_env_num = 2 - # num_segments = 2 - # n_episode = 2 - # evaluator_env_num = 2 - # num_simulations = 1 - # reanalyze_batch_size = 2 - # batch_size = [4, 4, 4, 4, 4, 4, 4, 4] - - configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size) - - # pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu_1127/8games_brf0.02_nlayer8-nhead24_seed1/8games_brf0.02_1-encoder-LN-res2-channel256_gsl20_8-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed1/Pong_unizero-mt_seed1/ckpt/iteration_200000.pth.tar' - # pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_atari_mt_20250217/atari_8games_notaskembed_bs64_brf0.02_seed0_dev-uz-mz-mt-cont/Pong_seed0_250218_124624/ckpt/ckpt_best.pth.tar' + # --- Generate Configurations --- + configs = generate_experiment_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size + ) - pretrained_model_path = '/fs-computility/ai-shen/puyuan/code/LightZero/data_lz/data_unizero_atari_mt_20250307/atari_8games_brf0.02_not-share-head_final-ln_seed0/Pong_seed0/ckpt/ckpt_best.pth.tar' + # --- Pretrained Model Path --- + # Please replace `` with the actual path to your model. 
+ pretrained_model_path = ( + "/data_unizero_atari_mt_20250307/" + "atari_8games_brf0.02_not-share-head_final-ln_seed0/Pong_seed0/ckpt/ckpt_best.pth.tar" + ) + + # --- Start Training --- with DDPContext(): - train_unizero_multitask_segment_ddp(configs, seed=seed, model_path=pretrained_model_path, max_env_step=max_env_step) \ No newline at end of file + train_unizero_multitask_segment_ddp( + configs, + seed=seed, + model_path=pretrained_model_path, + max_env_step=max_env_step + ) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_segment_config.py b/zoo/atari/config/atari_unizero_segment_config.py index 4518fbe4d..fa115e459 100644 --- a/zoo/atari/config/atari_unizero_segment_config.py +++ b/zoo/atari/config/atari_unizero_segment_config.py @@ -10,22 +10,34 @@ def main(env_id, seed): # ============================================================== collector_env_num = 8 num_segments = 8 + game_segment_length = 20 + # game_segment_length = 400 # TODO + evaluator_env_num = 3 num_simulations = 50 - max_env_step = int(4e5) - batch_size = 64 + # max_env_step = int(4e5) + max_env_step = int(5e6) # TODO + # max_env_step = int(1e6) # TODO pong + + # batch_size = 2 # only for debug + # batch_size = 64 + batch_size = 256 num_layers = 2 - replay_ratio = 0.25 + replay_ratio = 0.1 + # replay_ratio = 0.25 num_unroll_steps = 10 infer_context_length = 4 # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. - buffer_reanalyze_freq = 1/50 + # buffer_reanalyze_freq = 1/50 + buffer_reanalyze_freq = 1/5000000000 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) reanalyze_batch_size = 160 # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
reanalyze_partition = 0.75 + norm_type ="LN" # ====== only for debug ===== # collector_env_num = 2 @@ -34,6 +46,8 @@ def main(env_id, seed): # num_simulations = 5 # batch_size = 5 # buffer_reanalyze_freq = 1/1000000 + # replay_ratio = 1 + # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -59,15 +73,23 @@ def main(env_id, seed): model=dict( observation_shape=(3, 64, 64), action_space_size=action_space_size, - support_scale=300, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), + norm_type=norm_type, + num_res_blocks=1, + num_channels=64, + # num_res_blocks=2, + # num_channels=128, world_model_cfg=dict( - # final_norm_option_in_obs_head='LayerNorm', - # final_norm_option_in_encoder='LayerNorm', - # predict_latent_loss_type='mse', # TODO: only for latent state layer_norm + norm_type=norm_type, + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO: only for latent state layer_norm + + # final_norm_option_in_obs_head='SimNorm', + # final_norm_option_in_encoder='SimNorm', + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm - final_norm_option_in_obs_head='SimNorm', - final_norm_option_in_encoder='SimNorm', - predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm # analysis_dormant_ratio_weight_rank=True, # TODO analysis_dormant_ratio_weight_rank=False, # TODO @@ -89,7 +111,11 @@ def main(env_id, seed): obs_type='image', env_num=max(collector_env_num, evaluator_env_num), num_simulations=num_simulations, + game_segment_length=game_segment_length, + # use_priority=False, + use_priority=True, rotary_emb=False, + encoder_type='resnet', use_normal_head=True, use_softmoe_head=False, use_moe_head=False, @@ -101,28 +127,74 @@ def main(env_id, seed): lora_r= 0, lora_alpha =1, lora_dropout= 0.0, + optim_type='AdamW_mix_lr_wdecay', # only for tsne plot + ), ), + optim_type='AdamW_mix_lr_wdecay', + weight_decay=1e-2, # TODO: encoder 5*wd, transformer wd, head 0 + learning_rate=0.0001, + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
model_path=None, + + # (bool) Whether to enable the adaptive policy-entropy weight (alpha). + use_adaptive_entropy_weight=True, + # (float) Learning rate of the optimizer for the adaptive alpha. + adaptive_entropy_alpha_lr=1e-4, + # adaptive_entropy_alpha_lr=1e-3, + target_entropy_start_ratio=0.98, + # target_entropy_end_ratio=0.9, + target_entropy_end_ratio=0.7, + target_entropy_decay_steps=100000, # e.g., reach the final value after 100k iterations; tune jointly with the replay ratio. + # target_entropy_end_ratio=0.5, # TODO + # target_entropy_decay_steps=400000, # e.g., reach the final value after 400k iterations; tune jointly with the replay ratio. + + + # ==================== START: Encoder-Clip Annealing Config ==================== + # (bool) Whether to anneal the encoder-clip value. + use_encoder_clip_annealing=True, + # (str) Annealing schedule: 'linear' or 'cosine'. + encoder_clip_anneal_type='cosine', + # (float) Initial clip value (looser, used early in training). + encoder_clip_start_value=30.0, + # (float) Final clip value (tighter, used late in training). + encoder_clip_end_value=10.0, + # (int) Number of training iterations over which to anneal from the start value to the end value. + # encoder_clip_anneal_steps=400000, # e.g., reach the final value after 400k iterations + encoder_clip_anneal_steps=100000, # e.g., reach the final value after 100k iterations + + # ==================== START: label smoothing ==================== + policy_ls_eps_start=0.05, # TODO: a good starting value for Pong and MsPacman + policy_ls_eps_end=0.01, + policy_ls_eps_decay_steps=50000, # 50k + label_smoothing_eps=0.1, # TODO: for the value target + + # ==================== Norm monitoring frequency ==================== + # Monitor the norms of the model parameters every this many training iterations. Set to 0 to disable. + # monitor_norm_freq=10000, + monitor_norm_freq=5000, # TODO + # monitor_norm_freq=2, # only for debug + use_augmentation=False, manual_temperature_decay=False, threshold_training_steps_for_final_temperature=int(2.5e4), - use_priority=False, + # use_priority=False, + use_priority=True, + priority_prob_alpha=1, + priority_prob_beta=1, num_unroll_steps=num_unroll_steps, update_per_collect=None, replay_ratio=replay_ratio, batch_size=batch_size, - optim_type='AdamW', - learning_rate=0.0001, num_simulations=num_simulations, num_segments=num_segments, td_steps=5, - # train_start_after_envsteps=0, # only for debug - train_start_after_envsteps=2000, + train_start_after_envsteps=0, # only for debug + # train_start_after_envsteps=2000, game_segment_length=game_segment_length, grad_clip_value=5, - replay_buffer_size=int(1e6), + replay_buffer_size=int(5e5), eval_freq=int(5e3), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num,
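# NOTE (illustrative sketch): the keys configured in the hunk above (use_adaptive_entropy_weight,
# target_entropy_start_ratio / target_entropy_end_ratio / target_entropy_decay_steps,
# use_encoder_clip_annealing, encoder_clip_*) only set hyperparameters here; the policy code that
# consumes them is not part of this diff. The snippet below is a minimal, self-contained sketch of
# the schedules these names suggest: a cosine anneal of the encoder-clip value and a linear decay
# of the target-entropy ratio. Function names and formulas are assumptions for illustration, not
# the library's implementation.
import math

def annealed_encoder_clip(train_iter: int, start: float = 30.0, end: float = 10.0,
                          anneal_steps: int = 100000, anneal_type: str = 'cosine') -> float:
    """Interpolate the encoder-clip value from `start` to `end` over `anneal_steps` iterations."""
    progress = min(max(train_iter / anneal_steps, 0.0), 1.0)
    if anneal_type == 'cosine':
        mix = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 to 0 as progress goes from 0 to 1
    else:  # 'linear'
        mix = 1.0 - progress
    return end + (start - end) * mix

def target_entropy_ratio(train_iter: int, start_ratio: float = 0.98, end_ratio: float = 0.7,
                         decay_steps: int = 100000) -> float:
    """Linearly decay the ratio of the maximum policy entropy used as the adaptive-alpha target."""
    progress = min(max(train_iter / decay_steps, 0.0), 1.0)
    return start_ratio + (end_ratio - start_ratio) * progress

# Example: at iteration 50k the cosine schedule gives a clip value of about 20.0,
# and the target-entropy ratio has decayed to about 0.84.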
@@ -154,7 +226,11 @@ def main(env_id, seed): # ============ use muzero_segment_collector instead of muzero_collector ============= from lzero.entry import train_unizero_segment - main_config.exp_name = f'data_lz/data_unizero/{env_id[:-14]}/{env_id[:-14]}_uz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + main_config.exp_name = f'data_unizero_st_refactor1023/{env_id[:-14]}/{env_id[:-14]}_uz_ch64-res1_targetentropy-alpha-100k-098-07-encoder-clip30-10-100k_adamw-wd1e-2-encoder5-trans1-head0_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + + # main_config.exp_name = f'data_unizero_st_refactor1023/{env_id[:-14]}/{env_id[:-14]}_uz_ch64-res1_targetentropy-alpha-100k-098-07-encoder-clip30-10-100k_label-smooth_resnet-encoder_priority_adamw-wd1e-2-encoder5-trans1-head0-true_ln-inner-ln_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + + # main_config.exp_name = f'data_unizero_st_refactor1010/{env_id[:-14]}/{env_id[:-14]}_uz_ch128-res2_targetentropy-alpha-100k-098-07-encoder-clip30-10-400k_label-smooth_resnet-encoder_priority_adamw-wd1e-2-encoder1-trans1-head1_ln-inner-ln_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' train_unizero_segment([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) @@ -164,4 +240,34 @@ def main(env_id, seed): parser.add_argument('--env', type=str, help='The environment to use', default='PongNoFrameskip-v4') parser.add_argument('--seed', type=int, help='The seed to use', default=0) args = parser.parse_args() + + + + # Four base environments from the Atari-8 suite used for testing. + # args.env = 'PongNoFrameskip-v4' # reactive environment, dense reward + # args.env = 'MsPacmanNoFrameskip-v4' # memory/planning environment, sparse reward + + args.env = 'SeaquestNoFrameskip-v4' # memory/planning environment, sparse reward + # args.env = 'HeroNoFrameskip-v4' # memory/planning environment, sparse reward + + # args.env = 'AlienNoFrameskip-v4' + + # Two representative environments outside the Atari-8 suite. + # args.env = 'QbertNoFrameskip-v4' # memory/planning environment, sparse reward + # args.env = 'SpaceInvadersNoFrameskip-v4' # memory/planning environment, sparse reward + + # Environments that already perform well. + # args.env = 'BoxingNoFrameskip-v4' # reactive environment, dense reward + # args.env = 'ChopperCommandNoFrameskip-v4' + # args.env = 'RoadRunnerNoFrameskip-v4' + main(args.env, args.seed) + + """ + tmux new -s uz-st-refactor-boxing + + conda activate /mnt/nfs/zhangjinouwen/puyuan/conda_envs/lz + export CUDA_VISIBLE_DEVICES=1 + cd /mnt/nfs/zhangjinouwen/puyuan/LightZero + python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py 2>&1 | tee /mnt/nfs/zhangjinouwen/puyuan/LightZero/log/202510/20251023_uz_st_seaq.log + """ diff --git a/zoo/atari/config/atari_unizero_segment_ddp_config.py b/zoo/atari/config/atari_unizero_segment_ddp_config.py index 9321603bc..2031f6ddf 100644 --- a/zoo/atari/config/atari_unizero_segment_ddp_config.py +++ b/zoo/atari/config/atari_unizero_segment_ddp_config.py @@ -58,7 +58,8 @@ def main(env_id, seed): model=dict( observation_shape=(3, 96, 96), action_space_size=action_space_size, - support_scale=300, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), world_model_cfg=dict( support_size=601, policy_entropy_weight=5e-3, diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py index f5e43f6c8..d40f35033 100644 --- a/zoo/atari/envs/atari_lightzero_env.py +++ b/zoo/atari/envs/atari_lightzero_env.py @@ -177,7 +177,8 @@ def step(self, action: int) -> BaseEnvTimestep: self.reward = np.array(reward).astype(np.float32) self._eval_episode_return += self.reward self._timestep += 1 - # logging.info(f'self._timestep: {self._timestep}') + if self._timestep % 200 == 0: + logging.info(f'self._timestep: {self._timestep}') observation = self.observe() if done: logging.info(f'one episode done! total episode length is: {self._timestep}')
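# NOTE (illustrative sketch): throughout this diff, the scalar support settings
# (support_scale / reward_support_size / value_support_size) are replaced by explicit
# reward_support_range / value_support_range tuples interpreted as (start, stop, step).
# The helpers below only illustrate how such a tuple defines the discrete support atoms and how
# categorical logits over that support map back to a scalar; the names and the projection are a
# generic sketch, not the library's implementation.
import torch

def support_atoms(support_range: tuple) -> torch.Tensor:
    """Build the discrete support defined by a (start, stop, step) tuple."""
    start, stop, step = support_range
    return torch.arange(start, stop, step)  # e.g. (-300., 301., 1.) -> 601 atoms in [-300, 300]

def expected_scalar(logits: torch.Tensor, support_range: tuple) -> torch.Tensor:
    """Convert categorical logits over the support back to a scalar expectation."""
    atoms = support_atoms(support_range).to(logits.device)
    probs = torch.softmax(logits, dim=-1)
    return (probs * atoms).sum(dim=-1)

# Example: (-10., 11., 1.), used by the board-game configs below, yields 21 atoms in [-10, 10],
# matching the reward_support_size=21 / value_support_size=21 values it replaces.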
diff --git a/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py b/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py index 2d908f431..0a9d34a51 100644 --- a/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py +++ b/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py @@ -33,9 +33,8 @@ image_channel=3, num_res_blocks=1, num_channels=64, - support_scale=300, - reward_support_size=601, - value_support_size=601, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), ), cuda=True, env_type='board_games', diff --git a/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py b/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py index af8dad8b6..7c286f313 100644 --- a/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py +++ b/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py @@ -33,9 +33,8 @@ image_channel=3, num_res_blocks=1, num_channels=64, - support_scale=300, - reward_support_size=601, - value_support_size=601, + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), ), cuda=True, env_type='board_games', diff --git a/zoo/board_games/connect4/config/connect4_rezero_mz_bot_mode_config.py b/zoo/board_games/connect4/config/connect4_rezero_mz_bot_mode_config.py index 98697887b..6fb7cd101 100644 --- a/zoo/board_games/connect4/config/connect4_rezero_mz_bot_mode_config.py +++ b/zoo/board_games/connect4/config/connect4_rezero_mz_bot_mode_config.py @@ -37,9 +37,8 @@ image_channel=3, num_res_blocks=1, num_channels=64, - support_scale=300, - reward_support_size=601, - value_support_size=601, + reward_support_range=(-10., 11., 1.), + value_support_range=(-10., 11., 1.), ), cuda=True, env_type='board_games', diff --git a/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py index 26fab3a1d..fbb19920e 100644 --- a/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py @@ -40,9 +40,8 @@ image_channel=3, num_res_blocks=1, num_channels=32, - support_scale=10, - reward_support_size=21, - value_support_size=21, + reward_support_range=(-10., 11., 1.), + value_support_range=(-10., 11., 1.), ), # (str) The path of the pretrained model. If None, the model will be initialized by the default model.
model_path=None, diff --git a/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py index 1eab765b1..2e7e48669 100644 --- a/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py @@ -40,9 +40,8 @@ image_channel=3, num_res_blocks=1, num_channels=32, - support_scale=10, - reward_support_size=21, - value_support_size=21, + reward_support_range=(-10., 11., 1.), + value_support_range=(-10., 11., 1.), ), cuda=True, env_type='board_games', diff --git a/zoo/board_games/gomoku/config/gomoku_muzero_sp_mode_config.py b/zoo/board_games/gomoku/config/gomoku_muzero_sp_mode_config.py index efe7e659a..e2c99066f 100644 --- a/zoo/board_games/gomoku/config/gomoku_muzero_sp_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_muzero_sp_mode_config.py @@ -38,9 +38,8 @@ image_channel=3, num_res_blocks=1, num_channels=32, - support_scale=10, - reward_support_size=21, - value_support_size=21, + reward_support_range=(-10., 11., 1.), + value_support_range=(-10., 11., 1.), ), # (str) The path of the pretrained model. If None, the model will be initialized by the default model. model_path=None, diff --git a/zoo/board_games/gomoku/config/gomoku_rezero_mz_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_rezero_mz_bot_mode_config.py index e51e10ba1..56c35f45a 100644 --- a/zoo/board_games/gomoku/config/gomoku_rezero_mz_bot_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_rezero_mz_bot_mode_config.py @@ -40,9 +40,8 @@ image_channel=3, num_res_blocks=1, num_channels=32, - support_scale=10, - reward_support_size=21, - value_support_size=21, + reward_support_range=(-10., 11., 1.), + value_support_range=(-10., 11., 1.), ), cuda=True, env_type='board_games', diff --git a/zoo/board_games/tictactoe/config/tictactoe_efficientzero_bot_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_efficientzero_bot_mode_config.py index db709271a..f53afbd42 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_efficientzero_bot_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_efficientzero_bot_mode_config.py @@ -35,9 +35,8 @@ reward_head_hidden_channels=[8], value_head_hidden_channels=[8], policy_head_hidden_channels=[8], - support_scale=10, - reward_support_size=21, - value_support_size=21, + reward_support_range=(-10., 11., 1.), + value_support_range=(-10., 11., 1.), norm_type='BN', downsample=False, discrete_action_encoding_type='one_hot', diff --git a/zoo/board_games/tictactoe/config/tictactoe_efficientzero_sp_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_efficientzero_sp_mode_config.py index 939ffef2e..84e1c94ae 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_efficientzero_sp_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_efficientzero_sp_mode_config.py @@ -35,9 +35,8 @@ reward_head_hidden_channels=[8], value_head_hidden_channels=[8], policy_head_hidden_channels=[8], - support_scale=10, - reward_support_size=21, - value_support_size=21, + reward_support_range=(-10., 11., 1.), + value_support_range=(-10., 11., 1.), downsample=False, discrete_action_encoding_type='one_hot', ), diff --git a/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py index a353b12e6..ab554adb3 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py +++ 
b/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py @@ -35,9 +35,8 @@ reward_head_hidden_channels=[8], value_head_hidden_channels=[8], policy_head_hidden_channels=[8], - support_scale=10, - reward_support_size=21, - value_support_size=21, + reward_support_range=(-10., 11., 1.), + value_support_range=(-10., 11., 1.), ), # (str) The path of the pretrained model. If None, the model will be initialized by the default model. model_path=None, diff --git a/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py index 168931360..fbbb51d94 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py @@ -35,9 +35,8 @@ reward_head_hidden_channels=[8], value_head_hidden_channels=[8], policy_head_hidden_channels=[8], - support_scale=10, - reward_support_size=21, - value_support_size=21, + reward_support_range=(-10., 11., 1.), + value_support_range=(-10., 11., 1.), norm_type='BN', ), # (str) The path of the pretrained model. If None, the model will be initialized by the default model. diff --git a/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py index 9f40f0668..a6b1809ab 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py @@ -35,9 +35,8 @@ reward_head_hidden_channels=[8], value_head_hidden_channels=[8], policy_head_hidden_channels=[8], - support_scale=10, - reward_support_size=21, - value_support_size=21, + reward_support_range=(-10., 11., 1.), + value_support_range=(-10., 11., 1.), ), # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
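The hunks above replace the paired `support_scale` / `reward_support_size` / `value_support_size` fields with a single `(start, stop, step)` range per head. A minimal sketch of the implied atom count, assuming the tuple is expanded arange-style (consistent with the `DiscreteSupport(-300., 301., 1.)` call and the 601-logit tensor in the updated unit test shown elsewhere in this patch):

```Python
import torch

def support_atoms(start: float, stop: float, step: float) -> torch.Tensor:
    # The (start, stop, step) tuple is interpreted arange-style: start inclusive, stop exclusive.
    return torch.arange(start, stop, step)

value_atoms = support_atoms(-300., 301., 1.)  # 601 atoms, same coverage as support_scale=300 / *_support_size=601
board_atoms = support_atoms(-10., 11., 1.)    # 21 atoms, same coverage as support_scale=10 / *_support_size=21
assert value_atoms.numel() == 601 and board_atoms.numel() == 21
```

The board-game configs therefore keep the same ±10 coverage (21 atoms) as before; only the parameterization changes, not the support itself.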
model_path=None, diff --git a/zoo/box2d/box2d_suz_multitask.py b/zoo/box2d/box2d_suz_multitask.py deleted file mode 100644 index cf87e189d..000000000 --- a/zoo/box2d/box2d_suz_multitask.py +++ /dev/null @@ -1,179 +0,0 @@ -from easydict import EasyDict -from copy import deepcopy -import torch -def create_config(env_id, observation_shapes, action_space_sizes, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type): - return EasyDict(dict( - env=dict( - stop_value=int(1e6), - env_id=env_id, - continuous=True, - manually_discretization=False, - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - n_evaluator_episode=evaluator_env_num, - manager=dict(shared_memory=False, ), - ), - policy=dict( - learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000,),),), # default is 10000 - grad_correct_params=dict( - # for MoCo - MoCo_beta=0.5, - MoCo_beta_sigma=0.5, - MoCo_gamma=0.1, - MoCo_gamma_sigma=0.5, - MoCo_rho=0, - # for CAGrad - calpha=0.5, - rescale=1, - ), - task_num=len(env_id_list), - task_id=0, - model=dict( - observation_shapes=observation_shapes, - action_space_size=4, - action_space_sizes=action_space_sizes, - continuous_action_space=True, - num_of_sampled_actions=20, - model_type='mlp', - world_model_cfg=dict( - obs_type='vector', - num_unroll_steps=num_unroll_steps, - policy_entropy_loss_weight=1e-4, - continuous_action_space=True, - num_of_sampled_actions=20, - sigma_type='conditioned', - norm_type=norm_type, - bound_type=None, - max_blocks=num_unroll_steps, - max_tokens=2 * num_unroll_steps, - context_length=2 * infer_context_length, - device='cuda' if torch.cuda.is_available() else 'cpu', - action_space_size=action_space_sizes, - env_num=max(collector_env_num, evaluator_env_num), - task_num=len(env_id_list), - use_normal_head=True, - use_softmoe_head=False, - use_moe_head=False, - num_experts_in_moe_head=4, # NOTE - moe_in_transformer=False, # NOTE - multiplication_moe_in_transformer=False, # NOTE - num_experts_of_moe_in_transformer=4, - ), - ), - use_priority=True, - print_task_priority_logs=False, - cuda=True, - model_path=None, - num_unroll_steps=num_unroll_steps, - replay_ratio=0.25, - batch_size=batch_size, - optim_type='AdamW', - learning_rate=1e-4, - num_simulations=num_simulations, - reanalyze_ratio=reanalyze_ratio, - n_episode=n_episode, - eval_freq=int(2e3), - replay_buffer_size=int(1e6), - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - ), - )) - -def generate_configs(env_id_list, observation_shapes, action_space_sizes, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed): - configs = [] - exp_name_prefix = f'data_unizero_mt_box2d/{len(env_id_list)}games_cont_action_seed{seed}/' - - for task_id, (env_id, observation_shape, action_space_size) in enumerate(zip(env_id_list, observation_shapes, action_space_sizes)): - config = create_config( - env_id, - observation_shapes, # TODO - action_space_sizes, - collector_env_num, - evaluator_env_num, - n_episode, - num_simulations, - reanalyze_ratio, - batch_size, - num_unroll_steps, - infer_context_length, - norm_type - ) - config.policy.task_id = task_id - config.exp_name = exp_name_prefix + f"{env_id.split('-v')[0]}_unizero_mt_seed{seed}" - - configs.append([task_id, [config, create_env_manager(env_name=env_id)]]) - return configs - -def create_env_manager(): - return EasyDict(dict( - 
env=dict( - type='box2d', - import_names=['zoo.box2d.lunarlander.envs.lunarlander_env', 'zoo.box2d.bipedalwalker.envs.bipedalwalker_env'], - ), - env_manager=dict(type='subprocess'), - policy=dict( - type='sampled_unizero_multitask', - import_names=['lzero.policy.sampled_unizero_multitask'], - ), - )) - -def create_env_manager(env_name: str): - if env_name == 'LunarLanderContinuous-v2': - return EasyDict(dict( - env=dict( - type='lunarlander', - import_names=[f'zoo.box2d.lunarlander.envs.lunarlander_env'], - ), - env_manager=dict(type='subprocess'), - policy=dict( - type='sampled_unizero_multitask', - import_names=['lzero.policy.sampled_unizero_multitask'], - ), - )) - elif env_name == 'BipedalWalker-v3': - return EasyDict(dict( - env=dict( - type='bipedalwalker', - import_names=[f'zoo.box2d.bipedalwalker.envs.bipedalwalker_env'], - ), - env_manager=dict(type='subprocess'), - policy=dict( - type='sampled_unizero_multitask', - import_names=['lzero.policy.sampled_unizero_multitask'], - ), - )) - -if __name__ == "__main__": - from lzero.entry import train_unizero_multitask - - env_id_list = [ - 'LunarLanderContinuous-v2', - 'BipedalWalker-v3', - ] - - observation_shapes = [ - 8, # LunarLanderContinuous-v2 - 24, # BipedalWalker-v3 - ] - - action_space_sizes = [ - 2, # LunarLanderContinuous-v2 - 4, # BipedalWalker-v3 - ] - - seed = 0 - collector_env_num = 6 - n_episode = 8 - evaluator_env_num = 3 - num_simulations = 50 - max_env_step = int(1e6) - reanalyze_ratio = 0. - max_batch_size = 1000 - batch_size = [int(max_batch_size/len(env_id_list)) for i in range(len(env_id_list))] - num_unroll_steps = 10 - infer_context_length = 4 - norm_type = 'LN' - - configs = generate_configs(env_id_list, observation_shapes, action_space_sizes, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed) - - train_unizero_multitask(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/classic_control/cartpole/config/cartpole_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_muzero_config.py index d19e61d3e..3387ab602 100644 --- a/zoo/classic_control/cartpole/config/cartpole_muzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_muzero_config.py @@ -43,7 +43,6 @@ model_path=None, cuda=True, env_type='not_board_games', - action_type='varied_action_space', game_segment_length=50, update_per_collect=update_per_collect, batch_size=batch_size, diff --git a/zoo/classic_control/cartpole/config/cartpole_unizero_config.py b/zoo/classic_control/cartpole/config/cartpole_unizero_config.py index 9bab25093..7cb8d98d4 100644 --- a/zoo/classic_control/cartpole/config/cartpole_unizero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_unizero_config.py @@ -15,7 +15,6 @@ # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== - cartpole_unizero_config = dict( exp_name=f'data_unizero/cartpole_unizero_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed0', env=dict( @@ -28,6 +27,7 @@ manager=dict(shared_memory=False, ), ), policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000, ), ), ), model=dict( observation_shape=4, action_space_size=2, @@ -36,6 +36,9 @@ norm_type='BN', model_type='mlp', world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + 
final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', max_blocks=10, max_tokens=2 * 10, context_length=2 * 4, diff --git a/zoo/classic_control/mountain_car/entry/visualize_mz_mtcar.ipynb b/zoo/classic_control/mountain_car/entry/visualize_mz_mtcar.ipynb index e24522e94..ae8d89616 100644 --- a/zoo/classic_control/mountain_car/entry/visualize_mz_mtcar.ipynb +++ b/zoo/classic_control/mountain_car/entry/visualize_mz_mtcar.ipynb @@ -49,7 +49,7 @@ "from ding.torch_utils import to_tensor, to_device, to_ndarray\n", "from ding.worker import BaseLearner\n", "from lzero.worker import MuZeroEvaluator\n", - "from lzero.policy import InverseScalarTransform, mz_network_output_unpack\n", + "from lzero.policy import DiscreteSupport, InverseScalarTransform, mz_network_output_unpack\n", "\n", "from zoo.classic_control.mountain_car.config.mtcar_muzero_config import main_config, create_config\n", "# from lzero.entry import eval_muzero\n", @@ -195,9 +195,9 @@ " with torch.no_grad():\n", " network_output = model.initial_inference(state_space)\n", " latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output)\n", + " discrete_support = DiscreteSupport(*policy_cfg.model.support_range, policy_cfg.device)\n", " inverse_scalar_transform_handler = InverseScalarTransform(\n", - " policy_cfg.model.support_scale,\n", - " policy_cfg.device,\n", + " discrete_support,\n", " policy_cfg.model.categorical_distribution)\n", " value_real = inverse_scalar_transform_handler(value)\n", "\n", diff --git a/zoo/dmc2gym/config/dmc2gym_state_smz_config.py b/zoo/dmc2gym/config/dmc2gym_state_smz_config.py index 95456d56e..c99d3960b 100644 --- a/zoo/dmc2gym/config/dmc2gym_state_smz_config.py +++ b/zoo/dmc2gym/config/dmc2gym_state_smz_config.py @@ -30,7 +30,7 @@ # ============================================================== dmc2gym_state_cont_sampled_muzero_config = dict( - exp_name=f'data_smz/dmc2gym_{env_id}_state_cont_sampled_muzero_k{K}_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_{norm_type}_seed{seed}', + exp_name=f'/oss/niuyazhe/puyuan/data/data_lz_202505/data_smz/dmc2gym_{env_id}_state_cont_sampled_muzero_k{K}_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_{norm_type}_seed{seed}', env=dict( env_id='dmc2gym-v0', domain_name=domain_name, diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py new file mode 100644 index 000000000..ba979b1c6 --- /dev/null +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py @@ -0,0 +1,531 @@ +# -*- coding: utf-8 -*- +""" +Overview: + This script defines the configuration for a multi-task reinforcement learning experiment + using the UniZero model on DeepMind Control Suite (DMC) environments. + It is designed to be launched with PyTorch's Distributed Data Parallel (DDP) for multi-GPU training. 
+""" +from __future__ import annotations + +import logging +from typing import Any, Dict, List + +from easydict import EasyDict + +# ============================================================== +# Global setup: Logging +# ============================================================== +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(message)s', + handlers=[ + logging.FileHandler("output.log", encoding="utf-8"), # Log to file + logging.StreamHandler() # Log to console + ] +) + + +def get_base_config(env_id_list: list[str], collector_env_num: int, evaluator_env_num: int, + num_unroll_steps: int, infer_context_length: int, curriculum_stage_num: int) -> EasyDict: + """ + Overview: + Creates the base configuration EasyDict with default settings for the experiment. + These settings are shared across all tasks but can be overridden. + + Arguments: + - env_id_list (:obj:`list[str]`): A list of environment IDs for all tasks. + - collector_env_num (:obj:`int`): The number of environments for data collection. + - evaluator_env_num (:obj:`int`): The number of environments for evaluation. + - num_unroll_steps (:obj:`int`): The number of game steps to unroll in the model. + - infer_context_length (:obj:`int`): The context length for inference. + - curriculum_stage_num (:obj:`int`): The number of stages in the curriculum learning. + + Returns: + - (:obj:`EasyDict`): A dictionary containing the base configuration. + """ + return EasyDict(dict( + # Environment-specific settings + env=dict( + stop_value=int(5e5), + from_pixels=False, + continuous=True, # Assuming all DMC tasks use continuous action spaces + manager=dict(shared_memory=False), + game_segment_length=100, + # TODO(user): For debugging only. Uncomment to use smaller segments and episodes. + # game_segment_length=10, + # collect_max_episode_steps=int(40), + # eval_max_episode_steps=int(40), + ), + # Policy-specific settings + policy=dict( + multi_gpu=True, # TODO(user): Enable multi-GPU for DDP. + # TODO(user): Configure MoCo settings. + only_use_moco_stats=False, + use_moco=False, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + # Model configuration + model=dict( + continuous_action_space=True, + num_of_sampled_actions=20, + model_type='mlp', + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO(user): Loss type for latent state with LayerNorm. + + share_head=False, # TODO(user): Whether to share the prediction head across tasks. + use_shared_projection=False, + + # TODO(user): analysis_dormant_ratio needs to be corrected for the DMC encoder. + analysis_dormant_ratio_weight_rank=False, + analysis_dormant_ratio_interval=5000, + # analysis_dormant_ratio_interval=20, # For debugging + + # TODO(user): Configure task embedding options. 
+ task_embed_option=None, + use_task_embed=False, + # task_embed_option='concat_task_embed', + # use_task_embed=True, + # task_embed_dim=128, + + policy_loss_type='kl', + obs_type='vector', + policy_entropy_weight=5e-2, + continuous_action_space=True, + num_of_sampled_actions=20, + sigma_type='conditioned', + fixed_sigma_value=0.5, + bound_type=None, + model_type='mlp', + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # Each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + + # TODO(user): For debugging only. Use a smaller model. + # num_layers=1, + num_layers=4, + # num_layers=8, + + num_heads=24, + embed_dim=768, + env_num=max(collector_env_num, evaluator_env_num), + task_num=len(env_id_list), + + # Mixture of Experts (MoE) head configuration + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + # MoE in Transformer configuration + moe_in_transformer=False, + multiplication_moe_in_transformer=True, + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + # LoRA (Low-Rank Adaptation) parameters + # TODO(user): Enable or disable LoRA for MoE layers. + moe_use_lora=True, + lora_target_modules=["attn", "feed_forward"], + lora_r=64, + lora_alpha=1, + lora_dropout=0.0, + lora_scale_init=1, + + # Curriculum learning stage iteration counts + curriculum_stage_num=curriculum_stage_num, + min_stage0_iters=10000, # Corresponds to 400k envsteps, 40k iters + max_stage_iters=5000, + + # TODO(user): For debugging only. Use very short stage iterations. + # min_stage0_iters=2, + # max_stage_iters=5, + ), + ), + # TODO(user): Enable or disable task exploitation weight. + use_task_exploitation_weight=False, + balance_pipeline=True, + # TODO(user): Enable or disable task complexity weight. + task_complexity_weight=True, + allocated_batch_sizes=False, + # TODO(user): Set the number of environment steps to collect before training starts. + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + + # TODO(user): For debugging only. Set a smaller update_per_collect. + # update_per_collect=3, + update_per_collect=200, # e.g., 8 envs * 100 steps/env * 0.25 replay_ratio = 200 + replay_buffer_size=int(1e6), + eval_freq=int(4e3), + grad_clip_value=5, + learning_rate=1e-4, + discount_factor=0.99, + td_steps=5, + piecewise_decay_lr_scheduler=False, + manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(2.5e4), + cos_lr_scheduler=True, + ), + )) + + +def create_task_config( + base_config: EasyDict, + env_id: str, + observation_shape_list: list[int], + action_space_size_list: list[int], + target_return_dict: dict[str, int], + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: int, + num_unroll_steps: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int +) -> EasyDict: + """ + Overview: + Creates a specialized configuration for a single task by updating the base config. + + Arguments: + - base_config (:obj:`EasyDict`): The base configuration dictionary. + - env_id (:obj:`str`): The ID of the environment for this specific task. + - observation_shape_list (:obj:`list[int]`): List of observation shapes for all tasks. 
+ - action_space_size_list (:obj:`list[int]`): List of action space sizes for all tasks. + - target_return_dict (:obj:`dict[str, int]`): A dictionary mapping env_id to its target return. + - collector_env_num (:obj:`int`): The number of collector environments. + - evaluator_env_num (:obj:`int`): The number of evaluator environments. + - n_episode (:obj:`int`): The number of episodes to run for collection. + - num_simulations (:obj:`int`): The number of simulations in MCTS. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in a batch. + - batch_size (:obj:`int`): The batch size for training this task. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model. + - norm_type (:obj:`str`): The type of normalization to use (e.g., 'LN'). + - buffer_reanalyze_freq (:obj:`float`): Frequency of buffer reanalysis. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalysis. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalysis. + - num_segments (:obj:`int`): The number of segments in the replay buffer. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + + Returns: + - (:obj:`EasyDict`): The final configuration for the specified task. + """ + domain_name, task_name = env_id.split('-', 1) + frame_skip = 8 if domain_name == "pendulum" else 4 + + config = base_config + + # Update environment settings + config.env.update(dict( + env_id=env_id, + domain_name=domain_name, + task_name=task_name, + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + frame_skip=frame_skip, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + )) + + # Update model settings + config.policy.model.update(dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + )) + config.policy.model.world_model_cfg.update(dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + num_unroll_steps=num_unroll_steps, + norm_type=norm_type, + )) + + # Update policy settings + config.policy.update(dict( + target_return=target_return_dict.get(env_id), + total_batch_size=total_batch_size, + num_unroll_steps=num_unroll_steps, + replay_ratio=reanalyze_ratio, + batch_size=batch_size, + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + )) + + return config + + +def create_env_manager_config() -> EasyDict: + """ + Overview: + Creates the configuration for the environment manager and policy type. + + Returns: + - (:obj:`EasyDict`): A dictionary with environment manager and policy import settings. + """ + return EasyDict(dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero_multitask', + import_names=['lzero.policy.sampled_unizero_multitask'], + ), + )) + + +def generate_experiment_name(num_tasks: int, curriculum_stage_num: int, buffer_reanalyze_freq: float, seed: int) -> str: + """ + Overview: + Generates a descriptive name for the experiment. + + Arguments: + - num_tasks (:obj:`int`): Number of tasks in the experiment. 
+ - curriculum_stage_num (:obj:`int`): Number of curriculum stages. + - buffer_reanalyze_freq (:obj:`float`): Frequency of buffer reanalysis. + - seed (:obj:`int`): The random seed for the experiment. + + Returns: + - (:obj:`str`): The generated experiment name prefix. + """ + # NOTE: This is a template for the experiment name. + # Users should customize it to reflect their specific experiment settings. + return ( + f'data_suz_dmc_mt_balance_20250625/dmc_{num_tasks}tasks_frameskip4-pen-fs8_balance-stage-total-{curriculum_stage_num}' + f'_stage0-10k-5k_fix-lora-update-stablescale_moe8-uselora_nlayer4_not-share-head' + f'_brf{buffer_reanalyze_freq}_seed{seed}/' + ) + + +def generate_all_task_configs( + env_id_list: list[str], + target_return_dict: dict[str, int], + action_space_size_list: list[int], + observation_shape_list: list[int], + curriculum_stage_num: int, + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: list[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int +) -> list[tuple[int, list[EasyDict, EasyDict]]]: + """ + Overview: + Generates a list of configurations, one for each task in the experiment. + + Arguments: + - env_id_list (:obj:`list[str]`): A list of all environment IDs. + - target_return_dict (:obj:`dict[str, int]`): Mapping from env_id to target return. + - action_space_size_list (:obj:`list[int]`): List of action space sizes for all tasks. + - observation_shape_list (:obj:`list[int]`): List of observation shapes for all tasks. + - curriculum_stage_num (:obj:`int`): The number of curriculum stages. + - (other args): Hyperparameters for the experiment. See `create_task_config` for details. + + Returns: + - (:obj:`list`): A list where each element is `[task_id, [task_config, env_manager_config]]`. 
+ """ + configs = [] + exp_name_prefix = generate_experiment_name( + num_tasks=len(env_id_list), + curriculum_stage_num=curriculum_stage_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + seed=seed + ) + + base_config = get_base_config( + env_id_list=env_id_list, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + curriculum_stage_num=curriculum_stage_num + ) + + for task_id, env_id in enumerate(env_id_list): + task_specific_config = create_task_config( + base_config=base_config.clone(), # Use a clone to avoid modifying the base config + env_id=env_id, + action_space_size_list=action_space_size_list, + observation_shape_list=observation_shape_list, + target_return_dict=target_return_dict, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size[task_id], + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + task_specific_config.policy.task_id = task_id + task_specific_config.exp_name = exp_name_prefix + f"{env_id}_seed{seed}" + + env_manager_cfg = create_env_manager_config() + configs.append([task_id, [task_specific_config, env_manager_cfg]]) + + return configs + + +def main(): + """ + Overview: + Main function to set up and launch the multi-task UniZero training experiment. + This script should be executed with GPUs. + + Example launch commands: + 1. Using `torch.distributed.launch`: + cd /LightZero/ + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 \\ + ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py 2>&1 | tee \\ + ./logs/uz_mt_dmc18_balance_moe8_seed0.log + + 2. Using `torchrun`: + cd /LightZero/ + torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py + """ + from lzero.entry import train_unizero_multitask_balance_segment_ddp + from ding.utils import DDPContext + import torch.distributed as dist + from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + + # ============================================================== + # Experiment-level settings + # ============================================================== + # NOTE: You can switch between different sets of environments by uncommenting them. 
+ # DMC 8-task benchmark + # env_id_list = [ + # 'acrobot-swingup', 'cartpole-balance', 'cartpole-balance_sparse', + # 'cartpole-swingup', 'cartpole-swingup_sparse', 'cheetah-run', + # "ball_in_cup-catch", "finger-spin", + # ] + # target_return_dict = { + # 'acrobot-swingup': 500, 'cartpole-balance': 950, 'cartpole-balance_sparse': 950, + # 'cartpole-swingup': 800, 'cartpole-swingup_sparse': 750, 'cheetah-run': 650, + # "ball_in_cup-catch": 950, "finger-spin": 800, + # } + + # DMC 18-task benchmark + env_id_list = [ + 'acrobot-swingup', 'cartpole-balance', 'cartpole-balance_sparse', 'cartpole-swingup', + 'cartpole-swingup_sparse', 'cheetah-run', "ball_in_cup-catch", "finger-spin", + "finger-turn_easy", "finger-turn_hard", 'hopper-hop', 'hopper-stand', + 'pendulum-swingup', 'reacher-easy', 'reacher-hard', 'walker-run', + 'walker-stand', 'walker-walk', + ] + target_return_dict = { + 'acrobot-swingup': 500, 'cartpole-balance': 900, 'cartpole-balance_sparse': 950, + 'cartpole-swingup': 750, 'cartpole-swingup_sparse': 750, 'cheetah-run': 550, + "ball_in_cup-catch": 950, "finger-spin": 800, "finger-turn_easy": 950, + "finger-turn_hard": 950, 'hopper-hop': 150, 'hopper-stand': 600, + 'pendulum-swingup': 800, 'reacher-easy': 900, 'reacher-hard': 900, + 'walker-run': 500, 'walker-stand': 900, 'walker-walk': 900, + } + + # ============================================================== + # Hyperparameters + # ============================================================== + # NOTE: For debugging, you can use smaller values. + # collector_env_num, num_segments, n_episode = 2, 2, 2 + # evaluator_env_num, num_simulations, total_batch_size = 2, 1, 8 + # batch_size = [3] * len(env_id_list) + # max_env_step = int(1e3) + + # Production settings + curriculum_stage_num = 5 + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(4e5) + reanalyze_ratio = 0.0 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list)))] * len(env_id_list) + num_unroll_steps = 5 + infer_context_length = 2 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 100000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + seed = 0 # You can iterate over multiple seeds if needed + + # Fetch observation and action space info from predefined maps + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + # ============================================================== + # Generate configurations and start training + # ============================================================== + configs = generate_all_task_configs( + env_id_list=env_id_list, + target_return_dict=target_return_dict, + action_space_size_list=action_space_size_list, + observation_shape_list=observation_shape_list, + curriculum_stage_num=curriculum_stage_num, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + + with DDPContext(): + # To train only a subset of tasks for debugging, you 
can slice the configs list. + # e.g., train_unizero_multitask_balance_segment_ddp(configs[:1], ...) + train_unizero_multitask_balance_segment_ddp(configs, seed=seed, max_env_step=max_env_step, benchmark_name="dmc") + dist.destroy_process_group() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py index 57770d6a3..de2c09fa2 100644 --- a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py @@ -1,195 +1,325 @@ from easydict import EasyDict -from typing import List +from typing import List, Any, Dict, Tuple import logging +# Set up logging configuration +# Configure logging to output to both a file and the console. logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(message)s', handlers=[ - logging.FileHandler("output.log", encoding="utf-8"), # 文件日志 - logging.StreamHandler() # 终端日志 + logging.FileHandler("output.log", encoding="utf-8"), # Log to file + logging.StreamHandler() # Log to console ] ) -def create_config(env_id, observation_shape_list, action_space_size_list, collector_env_num, evaluator_env_num, n_episode, - num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, - norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, - total_batch_size): - domain_name = env_id.split('-')[0] - task_name = env_id.split('-')[1] - return EasyDict(dict( - env=dict( - stop_value=int(5e5), - env_id=env_id, - domain_name=domain_name, - task_name=task_name, + +def create_config( + env_id: str, + env_id_list: List[str], + target_return_dict: Dict[str, int], + observation_shape_list: List[Tuple[int, ...]], + action_space_size_list: List[int], + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, +) -> EasyDict: + """ + Overview: + Create a configuration EasyDict for a single reinforcement learning task. + + Arguments: + - env_id (:obj:`str`): The ID of the environment, e.g., 'cartpole-swingup'. + - env_id_list (:obj:`List[str]`): A list of all environment IDs for the multi-task setup. + - target_return_dict (:obj:`Dict[str, int]`): A dictionary mapping environment IDs to their target return values. + - observation_shape_list (:obj:`List[Tuple[int, ...]]`): List of observation shapes for all tasks. + - action_space_size_list (:obj:`List[int]`): List of action space sizes for all tasks. + - collector_env_num (:obj:`int`): Number of environments for data collection. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. + - n_episode (:obj:`int`): Number of episodes to run for collection. + - num_simulations (:obj:`int`): Number of simulations in the MCTS search. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in a batch. + - batch_size (:obj:`List[int]`): Batch size for training per task. + - num_unroll_steps (:obj:`int`): Number of steps to unroll the model during training. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization to use (e.g., 'LN'). 
+ - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalyzing. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalyzing. + - num_segments (:obj:`int`): Number of segments for the replay buffer. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + + Returns: + - (:obj:`EasyDict`): A configuration object for the specified task. + """ + domain_name, task_name = env_id.split('-') + + # Specific frame_skip settings for certain domains. + if domain_name == "pendulum": + frame_skip = 8 + else: + frame_skip = 4 + + # --- Environment Configuration --- + env_cfg = dict( + stop_value=int(5e5), + env_id=env_id, + domain_name=domain_name, + task_name=task_name, + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + from_pixels=False, + frame_skip=frame_skip, + continuous=True, # Assuming all DMC tasks use continuous action spaces + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + game_segment_length=100, + # TODO: Settings for debugging purposes. + # game_segment_length=10, + # collect_max_episode_steps=int(40), + # eval_max_episode_steps=int(40), + ) + + # --- World Model Configuration --- + world_model_cfg = dict( + # --- Normalization and Loss --- + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO: for latent state layer_norm + # final_norm_option_in_obs_head='SimNorm', + # final_norm_option_in_encoder='SimNorm', + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + + # --- Architecture --- + share_head=False, # TODO + use_shared_projection=False, + obs_type='vector', + model_type='mlp', + continuous_action_space=True, + num_of_sampled_actions=20, + sigma_type='conditioned', + fixed_sigma_value=0.5, + bound_type=None, + norm_type=norm_type, + device='cuda', + + # --- Transformer/MOE Settings --- + num_layers=8, # TODO: 8 for standard, 1 for debug + num_heads=24, + embed_dim=768, + moe_in_transformer=False, + multiplication_moe_in_transformer=True, + num_experts_of_moe_in_transformer=8, + n_shared_experts=1, + num_experts_per_tok=1, + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + # --- LoRA Parameters --- + moe_use_lora=False, # TODO + curriculum_stage_num=3, + lora_target_modules=["attn", "feed_forward"], + lora_r=0, + lora_alpha=1, + lora_dropout=0.0, + + # --- Multi-task Settings --- + task_embed_option=None, # TODO: 'concat_task_embed' or None + use_task_embed=False, # TODO + # task_embed_dim=128, + task_num=len(env_id_list), + + # --- Analysis --- + analysis_dormant_ratio_weight_rank=False, # TODO + analysis_dormant_ratio_interval=5000, + + # --- Dynamic Properties --- + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + num_unroll_steps=num_unroll_steps, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # Each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + env_num=max(collector_env_num, evaluator_env_num), + + # --- Loss Weights --- + policy_loss_type='kl', + policy_entropy_weight=5e-2, + ) + + # --- Policy Configuration --- + policy_cfg = dict( + # --- Hardware & Distribution --- + multi_gpu=True, # TODO: enable multi-GPU for DDP + cuda=True, + + # --- Model --- + 
model=dict( observation_shape_list=observation_shape_list, action_space_size_list=action_space_size_list, - from_pixels=False, - frame_skip=2, - continuous=True, # Assuming all DMC tasks use continuous action spaces - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - n_evaluator_episode=evaluator_env_num, - manager=dict(shared_memory=False), - game_segment_length=100, # As per single-task config - # ===== TODO: only for debug ===== - # game_segment_length=10, # As per single-task config - # collect_max_episode_steps=int(20), - # eval_max_episode_steps=int(20), + continuous_action_space=True, + num_of_sampled_actions=20, + model_type='mlp', + world_model_cfg=world_model_cfg, ), - policy=dict( - multi_gpu=True, # TODO: enable multi-GPU for DDP - only_use_moco_stats=False, - # use_moco=False, # ==============TODO============== - use_moco=True, # ==============TODO============== - learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000))), - grad_correct_params=dict( - # Example gradient correction parameters, adjust as needed - MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, - calpha=0.5, rescale=1, - ), - total_task_num=len(env_id_list), - task_num=len(env_id_list), - task_id=0, # To be set per task - model=dict( - observation_shape_list=observation_shape_list, - action_space_size_list=action_space_size_list, - continuous_action_space=True, - num_of_sampled_actions=20, - model_type='mlp', - world_model_cfg=dict( - final_norm_option_in_obs_head='LayerNorm', - final_norm_option_in_encoder='LayerNorm', - predict_latent_loss_type='mse', # TODO: for latent state layer_norm - - share_head=False, # TODO - use_shared_projection=False, - # analysis_dormant_ratio_weight_rank=True, # TODO - analysis_dormant_ratio_weight_rank=False, # TODO - dormant_threshold=0.025, - - # task_embed_option=None, # ==============TODO: none ============== - # use_task_embed=False, # ==============TODO============== - - task_embed_option='concat_task_embed', # ==============TODO: none ============== - use_task_embed=True, # ==============TODO============== - task_embed_dim=128, - # task_embed_dim=96, - - observation_shape_list=observation_shape_list, - action_space_size_list=action_space_size_list, - policy_loss_type='kl', - obs_type='vector', - num_unroll_steps=num_unroll_steps, - policy_entropy_weight=5e-2, - continuous_action_space=True, - num_of_sampled_actions=20, - sigma_type='conditioned', - fixed_sigma_value=0.5, - bound_type=None, - model_type='mlp', - norm_type=norm_type, - max_blocks=num_unroll_steps, - max_tokens=2 * num_unroll_steps, # Each timestep has 2 tokens: obs and action - context_length=2 * infer_context_length, - device='cuda', - # num_layers=1, # TODO: debug config - num_layers=8, # TODO - num_heads=24, - embed_dim=768, - env_num=max(collector_env_num, evaluator_env_num), - task_num=len(env_id_list), - use_normal_head=True, - use_softmoe_head=False, - use_moe_head=False, - num_experts_in_moe_head=4, - moe_in_transformer=False, - multiplication_moe_in_transformer=False, - num_experts_of_moe_in_transformer=4, - - # LoRA 参数: - lora_r= 0, - lora_alpha =1, - lora_dropout= 0.0, - ), - ), - use_task_exploitation_weight=False, # TODO - # use_task_exploitation_weight=True, # TODO - - task_complexity_weight=False, # TODO - total_batch_size=total_batch_size, - allocated_batch_sizes=False, - # train_start_after_envsteps=int(2e3), # TODO - train_start_after_envsteps=int(0), - use_priority=False, - print_task_priority_logs=False, - cuda=True, 
- model_path=None, - num_unroll_steps=num_unroll_steps, - # update_per_collect=2, # TODO: debug config - # update_per_collect=200, # TODO: 8*100*0.25=200 - update_per_collect=80, # TODO: 8*100*0.1=80 - replay_ratio=reanalyze_ratio, - batch_size=batch_size, - optim_type='AdamW', - num_segments=num_segments, - num_simulations=num_simulations, - reanalyze_ratio=reanalyze_ratio, - n_episode=n_episode, - replay_buffer_size=int(1e6), - # eval_freq=int(5e3), - eval_freq=int(4e3), - grad_clip_value=5, - learning_rate=1e-4, - discount_factor=0.99, - td_steps=5, - piecewise_decay_lr_scheduler=False, - manual_temperature_decay=True, - threshold_training_steps_for_final_temperature=int(2.5e4), - cos_lr_scheduler=True, - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - buffer_reanalyze_freq=buffer_reanalyze_freq, - reanalyze_batch_size=reanalyze_batch_size, - reanalyze_partition=reanalyze_partition, + + # --- Learning --- + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000))), + optim_type='AdamW', + learning_rate=1e-4, + grad_clip_value=5, + cos_lr_scheduler=True, + piecewise_decay_lr_scheduler=False, + + # --- Training Loop --- + train_start_after_envsteps=int(0), # TODO: 2e3 for standard, 0 for quick debug + update_per_collect=200, + replay_ratio=reanalyze_ratio, + + # --- Batch Sizes --- + batch_size=batch_size, + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + + # --- Replay Buffer --- + replay_buffer_size=int(1e6), + num_segments=num_segments, + use_priority=False, + + # --- Reanalyze --- + reanalyze_ratio=reanalyze_ratio, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + + # --- Algorithm Hyperparameters --- + num_simulations=num_simulations, + num_unroll_steps=num_unroll_steps, + td_steps=5, + discount_factor=0.99, + manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(2.5e4), + + # --- MoCo (Momentum Contrast) --- + use_moco=False, # TODO + only_use_moco_stats=False, + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, ), + + # --- Multi-task Specific --- + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, # To be set per task + target_return=target_return_dict.get(env_id), + use_task_exploitation_weight=False, # TODO + task_complexity_weight=True, # TODO + balance_pipeline=True, + print_task_priority_logs=False, + + # --- Environment Interaction --- + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + eval_freq=int(4e3), + + # --- Checkpointing --- + model_path=None, + ) + + # --- Combine configurations into the final EasyDict object --- + main_config = EasyDict(dict( + env=env_cfg, + policy=policy_cfg, )) + return main_config -def generate_configs(env_id_list: List[str], - collector_env_num: int, - n_episode: int, - evaluator_env_num: int, - num_simulations: int, - reanalyze_ratio: float, - batch_size: List[int], - num_unroll_steps: int, - infer_context_length: int, - norm_type: str, - seed: int, - buffer_reanalyze_freq: float, - reanalyze_batch_size: int, - reanalyze_partition: float, - num_segments: int, - total_batch_size: int): - configs = [] - exp_name_prefix = f'data_lz/data_suz_dmc_mt_20250413_moco/dmc_{len(env_id_list)}tasks_concattaskembed128_nlayer8_not-share-head_final-ln_bs64_brf{buffer_reanalyze_freq}_seed{seed}/' - # exp_name_prefix = 
f'data_lz/data_suz_dmc_mt_20250413_moco/dmc_{len(env_id_list)}tasks_notaskembed_nlayer8_not-share-head_final-ln_bs64_brf{buffer_reanalyze_freq}_seed{seed}/' +def generate_configs( + env_id_list: List[str], + target_return_dict: Dict[str, int], + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + dmc_state_env_action_space_map: Dict[str, int], + dmc_state_env_obs_space_map: Dict[str, Tuple[int, ...]], +) -> List[Tuple[int, List[Any]]]: + """ + Overview: + Generate a list of configurations for all specified multi-task environments. - # exp_name_prefix = f'data_lz/data_suz_dmc_mt_20250409_moco/dmc_{len(env_id_list)}tasks_notaskembed_nlayer8_not-share-head_final-ln_bs64_brf{buffer_reanalyze_freq}_seed{seed}/' - - # exp_name_prefix = f'data_lz/data_suz_dmc_mt_20250325/dmc_{len(env_id_list)}tasks_task-exploitation-weight_notaskembed_nlayer8_not-share-head_final-ln_bs64_brf{buffer_reanalyze_freq}_seed{seed}/' - # exp_name_prefix = f'data_lz/data_suz_dmc_mt_20250311/dmc_{len(env_id_list)}tasks_concattaskembed-128_nlayer8_not-share-head_final-ln_bs64*8_brf{buffer_reanalyze_freq}_seed{seed}/' + Arguments: + - env_id_list (:obj:`List[str]`): A list of all environment IDs for the multi-task setup. + - target_return_dict (:obj:`Dict[str, int]`): A dictionary mapping environment IDs to their target return values. + - collector_env_num (:obj:`int`): Number of environments for data collection. + - n_episode (:obj:`int`): Number of episodes to run for collection. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. + - num_simulations (:obj:`int`): Number of simulations in the MCTS search. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in a batch. + - batch_size (:obj:`List[int]`): Batch size for training per task. + - num_unroll_steps (:obj:`int`): Number of steps to unroll the model during training. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization to use (e.g., 'LN'). + - seed (:obj:`int`): The random seed. + - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalyzing. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalyzing. + - num_segments (:obj:`int`): Number of segments for the replay buffer. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + - dmc_state_env_action_space_map (:obj:`Dict[str, int]`): Map from env_id to action space size. + - dmc_state_env_obs_space_map (:obj:`Dict[str, Tuple[int, ...]]`): Map from env_id to observation shape. + Returns: + - (:obj:`List[Tuple[int, List[Any]]]`): A list where each element contains the task ID and its corresponding + configuration objects. + """ + configs = [] + + # Define the experiment name prefix. This helps in organizing experiment logs and results. + exp_name_prefix = ( + f'data_suz_dmc_mt_20250601/dmc_{len(env_id_list)}tasks_frameskip4-pendulum-skip8_ln-mse' + f'_nlayer8_trans-moe8_brf{buffer_reanalyze_freq}_seed{seed}/' + ) + + # Get action_space_size and observation_shape for each environment. 
action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] - for task_id, (env_id, obs_shape, act_space) in enumerate(zip(env_id_list, observation_shape_list, action_space_size_list)): + for task_id, env_id in enumerate(env_id_list): config = create_config( env_id=env_id, + env_id_list=env_id_list, + target_return_dict=target_return_dict, action_space_size_list=action_space_size_list, observation_shape_list=observation_shape_list, collector_env_num=collector_env_num, @@ -213,7 +343,15 @@ def generate_configs(env_id_list: List[str], return configs -def create_env_manager(): +def create_env_manager() -> EasyDict: + """ + Overview: + Create the environment and policy manager configuration. This specifies the types + of environment, policy, and their import paths. + + Returns: + - (:obj:`EasyDict`): A configuration object for the environment and policy managers. + """ return EasyDict(dict( env=dict( type='dmc2gym_lightzero', @@ -230,126 +368,113 @@ def create_env_manager(): if __name__ == "__main__": """ Overview: + Main script to configure and launch a multi-task training session for DeepMind Control Suite (DMC) + environments using Distributed Data Parallel (DDP). + + Usage: This script should be executed with GPUs. - Run the following command to launch the script: - python -m torch.distributed.launch --nproc_per_node=8 --master_port=29502 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py 2>&1 | tee ./log/uz_mt_dmc_moco_taskembed_20250409.log - torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py - """ + Navigate to the project root directory and run the launch command. + + Example command: + cd + # Using torch.distributed.launch (deprecated) + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 \\ + /dmc2gym_state_suz_multitask_ddp_config.py 2>&1 | tee \\ + /uz_mt_dmc18_train.log + # Using torchrun (recommended) + torchrun --nproc_per_node=8 /dmc2gym_state_suz_multitask_ddp_config.py + """ + # --- Import necessary components for training --- + # It's good practice to place imports inside the main guard + # if they are only used for script execution. 
from lzero.entry import train_unizero_multitask_segment_ddp from ding.utils import DDPContext - import os + import torch.distributed as dist from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map - os.environ["NCCL_TIMEOUT"] = "3600000000" - - - # DMC 8games - # env_id_list = [ - # 'acrobot-swingup', - # 'cartpole-balance', - # 'cartpole-balance_sparse', - # 'cartpole-swingup', - # 'cartpole-swingup_sparse', - # 'cheetah-run', - # "ball_in_cup-catch", - # "finger-spin", - # ] - - # DMC 18games - env_id_list = [ - 'acrobot-swingup', # 0 - 'cartpole-balance', # 1 - 'cartpole-balance_sparse', # 2 - 'cartpole-swingup', # 3 - 'cartpole-swingup_sparse', # 4 bad - 'cheetah-run', # 5 bad - "ball_in_cup-catch", # 6 - "finger-spin", # 7 bad - "finger-turn_easy", # 8 波动 - "finger-turn_hard", # 9 波动 - 'hopper-hop', # 10 bad - 'hopper-stand', # 11 - 'pendulum-swingup', # 12 bad - 'reacher-easy', # 13 - 'reacher-hard', # 14 波动 - 'walker-run', # 15 略差 - 'walker-stand', # 16 - 'walker-walk', # 17 - ] + # --- Experiment constants --- + BENCHMARK_NAME = 'dmc' - # debug - # env_id_list = [ - # 'acrobot-swingup', # 0 - # 'cartpole-balance', # 1 - # 'cartpole-balance_sparse', # 2 - # 'cartpole-swingup', # 3 - # 'cartpole-swingup_sparse', # 4 bad - # 'cheetah-run', # 5 bad - # "ball_in_cup-catch", # 6 - # "finger-spin", # 7 bad - # # "finger-turn_easy", # 8 波动 - # # "finger-turn_hard", # 9 波动 - # ] - - # 获取各环境的 action_space_size 和 observation_shape - action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] - observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + # --- Environment and Task Definitions --- + # Target return values for each DMC task, used for evaluation and potential curriculum. + target_return_dict = { + 'acrobot-swingup': 500, + 'cartpole-balance': 950, + 'cartpole-balance_sparse': 950, + 'cartpole-swingup': 800, + 'cartpole-swingup_sparse': 750, + 'cheetah-run': 650, + "ball_in_cup-catch": 950, + "finger-spin": 800, + "finger-turn_easy": 950, + "finger-turn_hard": 950, + 'hopper-hop': 150, + 'hopper-stand': 600, + 'pendulum-swingup': 800, + 'reacher-easy': 950, + 'reacher-hard': 950, + 'walker-run': 600, + 'walker-stand': 950, + 'walker-walk': 950, + } + # List of DMC environments to be used in the multi-task setup. + env_id_list = list(target_return_dict.keys()) + + # --- Hyperparameters for the training session --- + # Environment and Collector settings collector_env_num = 8 - num_segments = 8 - n_episode = 8 evaluator_env_num = 3 - num_simulations = 50 - max_env_step = int(5e5) + n_episode = 8 + max_env_step = int(4e5) + + # Replay Buffer and Reanalyze settings + num_segments = 8 reanalyze_ratio = 0.0 + buffer_reanalyze_freq = 1 / 100000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 - # nlayer=8 - total_batch_size = 1024 + # Model and Training settings + total_batch_size = 512 + # Allocate batch size per task, ensuring a minimum of 64 or distributing the total size. 
batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] - - # nlayer=12 - # total_batch_size = 256 - # batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] - num_unroll_steps = 5 infer_context_length = 2 norm_type = 'LN' - buffer_reanalyze_freq = 1 / 100000 - reanalyze_batch_size = 160 - reanalyze_partition = 0.75 - - # ======== TODO: only for debug ======== - # collector_env_num = 2 - # num_segments = 2 - # n_episode = 2 - # evaluator_env_num = 2 - # num_simulations = 1 - # batch_size = [4 for _ in range(len(env_id_list))] - # ======================================= - - seed = 0 # You can iterate over multiple seeds if needed + num_simulations = 50 - configs = generate_configs( - env_id_list=env_id_list, - collector_env_num=collector_env_num, - n_episode=n_episode, - evaluator_env_num=evaluator_env_num, - num_simulations=num_simulations, - reanalyze_ratio=reanalyze_ratio, - batch_size=batch_size, - num_unroll_steps=num_unroll_steps, - infer_context_length=infer_context_length, - norm_type=norm_type, - seed=seed, - buffer_reanalyze_freq=buffer_reanalyze_freq, - reanalyze_batch_size=reanalyze_batch_size, - reanalyze_partition=reanalyze_partition, - num_segments=num_segments, - total_batch_size=total_batch_size, - ) + # --- Main training loop --- + # Iterate over different random seeds for multiple runs. + for seed in [1, 2]: + # Generate the specific configurations for each task for the current run. + configs = generate_configs( + env_id_list=env_id_list, + target_return_dict=target_return_dict, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + dmc_state_env_action_space_map=dmc_state_env_action_space_map, + dmc_state_env_obs_space_map=dmc_state_env_obs_space_map, + ) - with DDPContext(): - train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) - # 如果只想训练部分任务,可以修改 configs,例如: - # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step) \ No newline at end of file + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step, + benchmark_name=BENCHMARK_NAME) + # If you only want to train a subset of tasks, you can slice the configs list. + # For example, to train only the first four tasks: + # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step, benchmark_name=BENCHMARK_NAME) + dist.destroy_process_group() \ No newline at end of file diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py index a57c10175..0724e981f 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -44,6 +44,8 @@ self_supervised_learning_loss=True, discrete_action_encoding_type='one_hot', norm_type='BN', + reward_support_range=(0., 601., 1.), + value_support_range=(0., 601., 1.), ), # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
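Looping back to the per-task batch sizes set in the DMC multi-task configs above: `int(min(64, total_batch_size / len(env_id_list)))` caps each task at 64 and otherwise splits the budget evenly. A quick check of the arithmetic, assuming the 18-task list and `total_batch_size=512` used in this patch:

```Python
total_batch_size = 512
num_tasks = 18  # the DMC 18-task benchmark configured above
per_task = int(min(64, total_batch_size / num_tasks))
print(per_task, per_task * num_tasks)  # 28 per task, 504 in total (slightly under the 512 budget)
```

At this budget the 64 cap only binds for eight or fewer tasks (512 / 8 = 64); with 18 tasks the even split dominates.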
         model_path=None,
diff --git a/zoo/jericho/configs/jericho_ppo_config.py b/zoo/jericho/configs/jericho_ppo_config.py
index 2c05c8579..e0cf74ea7 100644
--- a/zoo/jericho/configs/jericho_ppo_config.py
+++ b/zoo/jericho/configs/jericho_ppo_config.py
@@ -6,10 +6,10 @@
 env_id = 'detective.z5'
 # Define environment configurations
 env_configurations = {
-    'detective.z5': (10, 50),
-    'omniquest.z5': (10, 100),
-    'acorncourt.z5': (10, 50),
-    'zork1.z5': (10, 400),
+    'detective.z5': (12, 100),
+    'omniquest.z5': (25, 100),
+    'acorncourt.z5': (45, 50),
+    'zork1.z5': (55, 500),
 }
 # Set action_space_size and max_steps based on env_id
 action_space_size, max_steps = env_configurations.get(env_id, (10, 50))  # Default values if env_id not found
diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py
index 13155dfd5..8d5ac7fc1 100644
--- a/zoo/jericho/configs/jericho_unizero_config.py
+++ b/zoo/jericho/configs/jericho_unizero_config.py
@@ -5,7 +5,7 @@
 from easydict import EasyDict
 
 
-def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e5)) -> None:
+def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e6)) -> None:
     """
     Main entry point for setting up environment configurations and launching training.
@@ -16,40 +16,38 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
         Returns:
             None
     """
+    env_id = 'detective.z5'
+
+    collector_env_num: int = 4  # Number of collector environments
+    n_episode = int(collector_env_num)
+    batch_size=64
+
     # ------------------------------------------------------------------
     # Base environment parameters (Note: these values might be adjusted for different env_id)
     # ------------------------------------------------------------------
     # Define environment configurations
     env_configurations = {
-        'detective.z5': (10, 50),
-        'omniquest.z5': (10, 100),
-        'acorncourt.z5': (10, 50),
-        'zork1.z5': (10, 400),
+        'detective.z5': (12, 100),
+        'omniquest.z5': (25, 100),
+        'acorncourt.z5': (45, 50),
+        'zork1.z5': (55, 500),
     }
-    # env_id = 'detective.z5'
-    # env_id = 'omniquest.z5'
-    # env_id = 'acorncourt.z5'
-    # env_id = 'zork1.z5'
-
     # Set action_space_size and max_steps based on env_id
     action_space_size, max_steps = env_configurations.get(env_id, (10, 50))  # Default values if env_id not found
 
     # ------------------------------------------------------------------
     # User frequently modified configurations
     # ------------------------------------------------------------------
-    evaluator_env_num: int = 2  # Number of evaluator environments
+    evaluator_env_num: int = 3  # Number of evaluator environments
     num_simulations: int = 50  # Number of simulations
     # Project training parameters
-    collector_env_num: int = 4  # Number of collector environments
-    n_episode: int = 4  # Number of episodes per training batch
-    batch_size: int = 64  # Batch size in training
     num_unroll_steps: int = 10  # Number of unroll steps (for rollout sequence expansion)
     infer_context_length: int = 4  # Inference context length
     num_layers: int = 2  # Number of layers in the model
-    replay_ratio: float = 0.25  # Replay ratio for experience replay
+    replay_ratio: float = 0.1  # Replay ratio for experience replay
     embed_dim: int = 768  # Embedding dimension
 
     # Reanalysis (reanalyze) parameters:
@@ -61,12 +59,19 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
     reanalyze_partition: float = 0.75
 
     # Model name or path - configurable according to the predefined model paths or names
-    model_name: str = 'BAAI/bge-base-en-v1.5'
+    encoder_option = 'legacy'  # ['qwen', 'legacy']. Legacy uses the bge encoder
+
+    if encoder_option == 'qwen':
+        model_name: str = 'Qwen/Qwen3-0.6B'
+    elif encoder_option == 'legacy':
+        model_name: str = 'BAAI/bge-base-en-v1.5'
+    else:
+        raise ValueError(f"Unsupported encoder option: {encoder_option}")
 
     # ------------------------------------------------------------------
     # TODO: Debug configuration - override some parameters for debugging purposes
     # ------------------------------------------------------------------
-    # max_env_step = int(5e5)
+    # max_env_step = int(2e5)
     # batch_size = 10
     # num_simulations = 2
     # num_unroll_steps = 5
@@ -74,7 +79,6 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
     # max_steps = 10
     # num_layers = 1
     # replay_ratio = 0.05
-
     # ------------------------------------------------------------------
     # Configuration dictionary for the Jericho Unizero environment and policy
     # ------------------------------------------------------------------
@@ -94,12 +98,12 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
             manager=dict(shared_memory=False),
         ),
         policy=dict(
-            multi_gpu=False,  # Important for distributed data parallel (DDP)
+            multi_gpu=False,
             use_wandb=False,
             learn=dict(
                 learner=dict(
                     hook=dict(
-                        save_ckpt_after_iter=1000000,
+                        save_ckpt_after_iter=1000000,  # To save memory, set a large value. If intermediate checkpoints are needed, reduce this value.
                     ),
                 ),
             ),
@@ -107,10 +111,14 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
             model=dict(
                 observation_shape=512,
                 action_space_size=action_space_size,
+                encoder_option=encoder_option,
                 encoder_url=model_name,
                 model_type="mlp",
                 continuous_action_space=False,
                 world_model_cfg=dict(
+                    final_norm_option_in_obs_head='LayerNorm',
+                    final_norm_option_in_encoder='LayerNorm',
+                    predict_latent_loss_type='mse',
                     policy_entropy_weight=5e-2,
                     continuous_action_space=False,
                     max_blocks=num_unroll_steps,
@@ -122,12 +130,30 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
                     num_layers=num_layers,
                     num_heads=24,
                     embed_dim=embed_dim,
-                    obs_type="text",  # TODO: Modify as needed.
+                    obs_type="text",
                     env_num=max(collector_env_num, evaluator_env_num),
+
+                    task_embed_option=None,
+                    use_task_embed=False,
+                    use_normal_head=True,
+                    use_softmoe_head=False,
+                    use_moe_head=False,
+                    num_experts_in_moe_head=4,
+
+                    moe_in_transformer=False,
+                    multiplication_moe_in_transformer=False,
+                    n_shared_experts=1,
+                    num_experts_per_tok=1,
+                    num_experts_of_moe_in_transformer=8,
+                    lora_r= 0,
+                    lora_alpha =1,
+                    lora_dropout= 0.0,
+
+                    decode_loss_mode=None,  # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None.
+                    latent_recon_loss_weight=0.1
                 ),
             ),
-            # update_per_collect=None,  # Important for DDP
-            update_per_collect=int(collector_env_num*max_steps*replay_ratio),  # Important for DDP
+            update_per_collect=int(collector_env_num*max_steps*replay_ratio ),  # Important for DDP
             action_type="varied_action_space",
             model_path=None,
             num_unroll_steps=num_unroll_steps,
@@ -135,18 +161,16 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
             replay_ratio=replay_ratio,
             batch_size=batch_size,
             learning_rate=0.0001,
-            cos_lr_scheduler=True,
+            cos_lr_scheduler=False,
             fixed_temperature_value=0.25,
             manual_temperature_decay=False,
             num_simulations=num_simulations,
             n_episode=n_episode,
             train_start_after_envsteps=0,  # TODO: Adjust training start trigger if needed.
-            # train_start_after_envsteps=2000,  # TODO: Adjust training start trigger if needed.
             replay_buffer_size=int(5e5),
-            eval_freq=int(1e4),
+            eval_freq=int(3e4),
             collector_env_num=collector_env_num,
             evaluator_env_num=evaluator_env_num,
-            # Reanalysis key parameters:
             buffer_reanalyze_freq=buffer_reanalyze_freq,
             reanalyze_batch_size=reanalyze_batch_size,
             reanalyze_partition=reanalyze_partition,
@@ -164,8 +188,6 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
         ),
         # Use base env manager to avoid bugs present in subprocess env manager.
         env_manager=dict(type="base"),
-        # If necessary, switch to subprocess env manager by uncommenting the following line:
-        # env_manager=dict(type="subprocess"),
         policy=dict(
             type="unizero",
             import_names=["lzero.policy.unizero"],
@@ -181,7 +203,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
 
     # Construct experiment name containing key parameters
     main_config.exp_name = (
-        f"data_lz/data_unizero_jericho/bge-base-en-v1.5/uz_{env_id[:8]}_ms{max_steps}_ass-{action_space_size}_"
+        f"data_lz/data_unizero_jericho/bge-base-en-v1.5/{env_id}/uz_gpu_cen{collector_env_num}_rr{replay_ratio}_ftemp025_{env_id[:8]}_ms{max_steps}_ass-{action_space_size}_"
         f"nlayer{num_layers}_embed{embed_dim}_Htrain{num_unroll_steps}-"
         f"Hinfer{infer_context_length}_bs{batch_size}_seed{seed}"
     )
@@ -196,6 +218,13 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
 
 
 if __name__ == "__main__":
+    """
+    Overview:
+        This script should be executed with GPUs.
+        Run the following command to launch the script:
+            torchrun --nproc_per_node=4 ./zoo/jericho/configs/jericho_unizero_ddp_config.py
+    """
+
     parser = argparse.ArgumentParser(description='Process environment configuration and launch training.')
     parser.add_argument(
         '--env',
@@ -215,10 +244,4 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 
     # Start the main process with the provided arguments
-    main(args.env, args.seed)
-
-    # ====== the following is only for cprofile ======
-    # def run(max_env_step: int):
-    #     main(args.env, args.seed, max_env_step=max_env_step)
-    # import cProfile
-    # cProfile.run(f"run({10000})", filename="./zoo/jericho/detective_unizero_cprofile_10k_envstep", sort="cumulative")
\ No newline at end of file
+    main(args.env, args.seed)
\ No newline at end of file
diff --git a/zoo/jericho/configs/jericho_unizero_ddp_config.py b/zoo/jericho/configs/jericho_unizero_ddp_config.py
index b407e1040..e6079060d 100644
--- a/zoo/jericho/configs/jericho_unizero_ddp_config.py
+++ b/zoo/jericho/configs/jericho_unizero_ddp_config.py
@@ -5,7 +5,7 @@
 from easydict import EasyDict
 
 
-def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e5)) -> None:
+def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e6)) -> None:
     """
     Main entry point for setting up environment configurations and launching training.
 
@@ -19,38 +19,51 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
     gpu_num = 4
     collector_env_num: int = 4  # Number of collector environments
     n_episode = int(collector_env_num*gpu_num)
-    batch_size = int(64*gpu_num)
+
+    # Model name or path - configurable according to the predefined model paths or names
+    encoder_option = 'legacy'  # ['qwen', 'legacy']. Legacy uses the bge encoder
+
+    if encoder_option == 'qwen':
+        model_name: str = 'Qwen/Qwen3-0.6B'
+        batch_size = int(1*gpu_num)
+        accumulation_steps=64
+    elif encoder_option == 'legacy':
+        model_name: str = 'BAAI/bge-base-en-v1.5'
+        batch_size = int(64*gpu_num)
+        accumulation_steps=1
+    else:
+        raise ValueError(f"Unsupported encoder option: {encoder_option}")
+
+    # TODO
+    # batch_size = batch_size * 2
 
     # ------------------------------------------------------------------
     # Base environment parameters (Note: these values might be adjusted for different env_id)
     # ------------------------------------------------------------------
     # Define environment configurations
+
     env_configurations = {
-        'detective.z5': (10, 50),
-        'omniquest.z5': (10, 100),
-        'acorncourt.z5': (10, 50),
-        'zork1.z5': (10, 400),
+        'detective.z5': (12, 100),
+        'omniquest.z5': (25, 100),
+        'acorncourt.z5': (45, 50),
+        'zork1.z5': (55, 500),
     }
-
-    # env_id = 'detective.z5'
-    # env_id = 'omniquest.z5'
-    # env_id = 'acorncourt.z5'
-    # env_id = 'zork1.z5'
-
+    env_id = 'detective.z5'
     # Set action_space_size and max_steps based on env_id
     action_space_size, max_steps = env_configurations.get(env_id, (10, 50))  # Default values if env_id not found
 
     # ------------------------------------------------------------------
     # User frequently modified configurations
     # ------------------------------------------------------------------
-    evaluator_env_num: int = 2  # Number of evaluator environments
+    evaluator_env_num: int = 3  # Number of evaluator environments
     num_simulations: int = 50  # Number of simulations
 
     # Project training parameters
     num_unroll_steps: int = 10  # Number of unroll steps (for rollout sequence expansion)
     infer_context_length: int = 4  # Inference context length
+
     num_layers: int = 2  # Number of layers in the model
-    replay_ratio: float = 0.25  # Replay ratio for experience replay
+    replay_ratio: float = 0.1  # Replay ratio for experience replay
     embed_dim: int = 768  # Embedding dimension
 
     # Reanalysis (reanalyze) parameters:
@@ -61,9 +74,6 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
     # reanalyze_partition: Partition ratio from the replay buffer to use during reanalysis
     reanalyze_partition: float = 0.75
 
-    # Model name or path - configurable according to the predefined model paths or names
-    model_name: str = 'BAAI/bge-base-en-v1.5'
-
     # ------------------------------------------------------------------
     # TODO: Debug configuration - override some parameters for debugging purposes
     # ------------------------------------------------------------------
@@ -103,14 +113,18 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
                     ),
                 ),
             ),
-            accumulation_steps=1,  # TODO: Accumulated gradient steps (currently default)
+            accumulation_steps=accumulation_steps,  # TODO: Accumulated gradient steps (currently default)
             model=dict(
                 observation_shape=512,
                 action_space_size=action_space_size,
                 encoder_url=model_name,
+                encoder_option=encoder_option,
                 model_type="mlp",
                 continuous_action_space=False,
                 world_model_cfg=dict(
+                    final_norm_option_in_obs_head='LayerNorm',
+                    final_norm_option_in_encoder='LayerNorm',
+                    predict_latent_loss_type='mse',
                     policy_entropy_weight=5e-2,
                     continuous_action_space=False,
                     max_blocks=num_unroll_steps,
@@ -124,9 +138,12 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
                     embed_dim=embed_dim,
                     obs_type="text",  # TODO: Modify as needed.
                     env_num=max(collector_env_num, evaluator_env_num),
+                    decode_loss_mode=None,  # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None.
+                    latent_recon_loss_weight=0.1  # TODO: decoder loss weight
                 ),
             ),
-            update_per_collect=int(collector_env_num*max_steps*replay_ratio),  # Important for DDP
+            # TODO
+            update_per_collect=int(collector_env_num*max_steps*replay_ratio*accumulation_steps),  # Important for DDP
             action_type="varied_action_space",
             model_path=None,
             num_unroll_steps=num_unroll_steps,
@@ -134,14 +151,15 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
             replay_ratio=replay_ratio,
             batch_size=batch_size,
             learning_rate=0.0001,
-            cos_lr_scheduler=True,
+            cos_lr_scheduler=False,
             fixed_temperature_value=0.25,
             manual_temperature_decay=False,
+
             num_simulations=num_simulations,
             n_episode=n_episode,
             train_start_after_envsteps=0,  # TODO: Adjust training start trigger if needed.
             replay_buffer_size=int(5e5),
-            eval_freq=int(1e4),
+            eval_freq=int(3e4),
             collector_env_num=collector_env_num,
             evaluator_env_num=evaluator_env_num,
             # Reanalysis key parameters:
@@ -183,7 +201,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
     main_config = lz_to_ddp_config(main_config)
     # Construct experiment name containing key parameters
     main_config.exp_name = (
-        f"data_lz/data_unizero_jericho/bge-base-en-v1.5/uz_ddp-{gpu_num}gpu_cen{collector_env_num}_rr{replay_ratio}_ftemp025_{env_id[:8]}_ms{max_steps}_ass-{action_space_size}_"
+        f"data_lz/data_unizero_jericho/{model_name}/{env_id}/uz_ddp-{gpu_num}gpu_cen{collector_env_num}_rr{replay_ratio}_ftemp025_{env_id[:8]}_ms{max_steps}_ass-{action_space_size}_"
         f"nlayer{num_layers}_embed{embed_dim}_Htrain{num_unroll_steps}-"
         f"Hinfer{infer_context_length}_bs{batch_size}_seed{seed}"
     )
@@ -225,3 +243,4 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
 
     # Start the main process with the provided arguments
     main(args.env, args.seed)
+#
\ No newline at end of file
diff --git a/zoo/jericho/configs/jericho_unizero_segment_config.py b/zoo/jericho/configs/jericho_unizero_segment_config.py
index d5aff1c7b..a44b9cf75 100644
--- a/zoo/jericho/configs/jericho_unizero_segment_config.py
+++ b/zoo/jericho/configs/jericho_unizero_segment_config.py
@@ -9,10 +9,10 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None:
     # Base configurations
     # ------------------------------------------------------------------
     env_configurations = {
-        'detective.z5': (10, 50),
-        'omniquest.z5': (10, 100),
-        'acorncourt.z5': (10, 50),
-        'zork1.z5': (10, 400),
+        'detective.z5': (12, 100),
+        'omniquest.z5': (25, 100),
+        'acorncourt.z5': (45, 50),
+        'zork1.z5': (55, 500),
     }
 
     # Set action_space_size and max_steps based on env_id
@@ -22,7 +22,16 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None:
     # Frequently changed configurations (user-specified)
     # ==============================================================
     # Model name or path - configurable according to the predefined model paths or names
-    model_name: str = 'BAAI/bge-base-en-v1.5'
+    encoder_option = 'legacy'  # ['qwen', 'legacy']. Legacy uses the bge encoder
+
+    if encoder_option == 'qwen':
+        model_name: str = 'Qwen/Qwen3-0.6B'
+    elif encoder_option == 'legacy':
+        model_name: str = 'BAAI/bge-base-en-v1.5'
+    else:
+        raise ValueError(f"Unsupported encoder option: {encoder_option}")
+
+
     collector_env_num = 8
     game_segment_length = 20
     evaluator_env_num = 5
@@ -86,9 +95,13 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None:
         model=dict(
             observation_shape=512,
             action_space_size=action_space_size,
+            encoder_option=encoder_option,
             encoder_url=model_name,
             model_type="mlp",
             world_model_cfg=dict(
+                final_norm_option_in_obs_head='LayerNorm',
+                final_norm_option_in_encoder='LayerNorm',
+                predict_latent_loss_type='mse',
                 policy_entropy_weight=5e-3,
                 continuous_action_space=False,
                 max_blocks=num_unroll_steps,
@@ -101,6 +114,8 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None:
                 embed_dim=embed_dim,
                 obs_type="text",
                 env_num=max(collector_env_num, evaluator_env_num),
+                decode_loss_mode='None',  # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None.
+                latent_recon_loss_weight=0.1
             ),
         ),
         action_type="varied_action_space",
diff --git a/zoo/jericho/detective_unizero_cprofile_10k_envstep b/zoo/jericho/detective_unizero_cprofile_10k_envstep
deleted file mode 100644
index ae3d22ab1..000000000
Binary files a/zoo/jericho/detective_unizero_cprofile_10k_envstep and /dev/null differ
diff --git a/zoo/jericho/envs/jericho_env.py b/zoo/jericho/envs/jericho_env.py
index 4a3fc243a..f0a4a675a 100644
--- a/zoo/jericho/envs/jericho_env.py
+++ b/zoo/jericho/envs/jericho_env.py
@@ -45,7 +45,7 @@ class JerichoEnv(BaseEnv):
     DEFAULT_CONFIG: Dict[str, Any] = {
         'max_steps': 400,
         'max_action_num': 10,
-        'tokenizer_path': "google-bert/bert-base-uncased",
+        'tokenizer_path': "BAAI/bge-base-en-v1.5",
         'max_seq_len': 512,
         'remove_stuck_actions': False,
         'add_location_and_inventory': False,
@@ -53,7 +53,7 @@ class JerichoEnv(BaseEnv):
         'save_replay': False,
         'save_replay_path': None,
         'env_type': "zork1",
-        'collect_policy_mode': "random"
+        'collect_policy_mode': "agent"
     }
 
     def __init__(self, cfg: Dict[str, Any]) -> None:
@@ -70,12 +70,13 @@ def __init__(self, cfg: Dict[str, Any]) -> None:
 
         self.max_steps: int = self.cfg['max_steps']
         self.game_path: str = self.cfg['game_path']
+        self.env_type: str = self.cfg['env_type']
+
         self.max_action_num: int = self.cfg['max_action_num']
         self.max_seq_len: int = self.cfg['max_seq_len']
         self.save_replay: bool = self.cfg['save_replay']
         self.save_replay_path: str = self.cfg['save_replay_path']
         self.collect_policy_mode: str = self.cfg['collect_policy_mode']
-        self.env_type: str = self.cfg['env_type']
 
         # Record the last observation and action for detecting stuck actions.
         self.last_observation: Optional[str] = None
@@ -152,7 +153,9 @@ def prepare_obs(self, obs: str, return_str: bool = False) -> Dict[str, Any]:
             full_obs: str = f"Location: {player_location}\nInventory: {inventory}{obs}\nValid actions: {available_actions}"
         else:
             full_obs = f"{obs}\nValid actions: {available_actions}"
-
+
+        full_obs_str = copy.deepcopy(full_obs)
+
         # Tokenize observation if required.
         if not return_str:
             tokenized_output = JerichoEnv.tokenizer(
@@ -175,11 +178,15 @@ def prepare_obs(self, obs: str, return_str: bool = False) -> Dict[str, Any]:
         if return_str:
             if self.for_unizero:
                 return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1, 'timestep': self._timestep}
+
             else:
                 return {'observation': full_obs, 'action_mask': action_mask}
         else:
             if self.for_unizero:
-                return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1, 'timestep': self._timestep}
+                if self.save_replay:
+                    return {'observation': full_obs, 'observation_str': full_obs_str, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1, 'timestep': self._timestep}
+                else:
+                    return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1, 'timestep': self._timestep}
             else:
                 return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask}
 
@@ -199,10 +206,10 @@ def reset(self, return_str: bool = False) -> Dict[str, Any]:
             self._init_flag = True
         self._action_list = None
         self.episode_return = 0.0
-        self.env_step = 0
         self._timestep = 0
         self.episode_history = []
-        self.walkthrough_actions = self._env.get_walkthrough()
+        if self.collect_policy_mode == 'expert':
+            self.walkthrough_actions = self._env.get_walkthrough()
 
         if self.remove_stuck_actions:
             self.last_observation = initial_observation
@@ -214,13 +221,15 @@ def reset(self, return_str: bool = False) -> Dict[str, Any]:
 
         processed_obs = self.prepare_obs(initial_observation, return_str)
 
-        self.episode_history.append({
-            'timestep': 0,
-            'obs': processed_obs['observation'],
-            'act': None,
-            'done': False,
-            'info': info
-        })
+        if self.save_replay:
+            self.episode_history.append({
+                'timestep': 0,
+                'obs': processed_obs['observation'] if return_str else processed_obs['observation_str'],
+                'act': None,
+                'done': False,
+                'info': info
+            })
+
         return processed_obs
 
@@ -299,7 +308,6 @@ def step(self, action: Union[int, np.ndarray, str], return_str: bool = False) ->
         self._timestep += 1
         if not self.for_unizero:
             reward = np.array([float(reward)])
-        self.env_step += 1
         self.episode_return += reward
         self._action_list = None
 
@@ -314,13 +322,13 @@ def step(self, action: Union[int, np.ndarray, str], return_str: bool = False) ->
 
         processed_obs = self.prepare_obs(observation, return_str)
 
-        if self.env_step >= self.max_steps:
+        if self._timestep >= self.max_steps:
             done = True
 
         if self.save_replay:
             self.episode_history.append({
                 'timestep': self._timestep,
-                'obs': processed_obs['observation'],
+                'obs': processed_obs['observation'] if return_str else processed_obs['observation_str'],
                 'act': action_str,
                 'reward': reward.item() if isinstance(reward, np.ndarray) else reward,
                 'done': done,
@@ -329,7 +337,7 @@ def step(self, action: Union[int, np.ndarray, str], return_str: bool = False) ->
 
         if done:
             print('=' * 20)
-            print(f'rank {self.rank} one episode done!')
+            print(f'rank {self.rank} one episode done! episode_return:{self.episode_return}')
             self.finished = True
             info['eval_episode_return'] = self.episode_return
diff --git a/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py b/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py
index 7725d8409..4f7fef3ea 100644
--- a/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py
+++ b/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py
@@ -1,7 +1,9 @@
 from easydict import EasyDict
 
 # options={'Hopper-v3', 'HalfCheetah-v3', 'Walker2d-v3', 'Ant-v3', 'Humanoid-v3'}
-env_id = 'Hopper-v3'
+# env_id = 'Hopper-v3'
+env_id = 'Ant-v3'
+
 
 if env_id == 'Hopper-v3':
     action_space_size = 3
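
A minimal sketch of the batch-size and gradient-accumulation arithmetic behind the jericho_unizero_ddp_config.py changes above, assuming the values visible in those hunks (gpu_num=4, detective.z5 with max_steps=100, replay_ratio=0.1). The helper name `ddp_schedule` is hypothetical and is not part of the patched files; it only mirrors how `encoder_option` ties `model_name`, `batch_size`, and `accumulation_steps` together, and how `update_per_collect` is scaled by `accumulation_steps`.

# Hypothetical helper that reproduces the arithmetic of the DDP config above;
# it is an illustration, not LightZero API.
def ddp_schedule(encoder_option: str = 'legacy', gpu_num: int = 4,
                 collector_env_num: int = 4, max_steps: int = 100,
                 replay_ratio: float = 0.1) -> dict:
    if encoder_option == 'qwen':
        model_name = 'Qwen/Qwen3-0.6B'
        batch_size = 1 * gpu_num        # small per-step batch for the larger encoder
        accumulation_steps = 64         # recover the effective batch via accumulation
    elif encoder_option == 'legacy':
        model_name = 'BAAI/bge-base-en-v1.5'
        batch_size = 64 * gpu_num
        accumulation_steps = 1
    else:
        raise ValueError(f"Unsupported encoder option: {encoder_option}")

    return {
        'model_name': model_name,
        'batch_size': batch_size,
        'accumulation_steps': accumulation_steps,
        # Samples contributing to one optimizer update under gradient accumulation.
        'effective_batch_size': batch_size * accumulation_steps,
        # Training iterations per collection phase; scaling by accumulation_steps
        # presumably keeps the number of optimizer updates per collect roughly constant.
        'update_per_collect': int(collector_env_num * max_steps * replay_ratio * accumulation_steps),
    }

if __name__ == '__main__':
    # With the values above, both options see an effective batch of 256, while 'qwen'
    # schedules 64x more (accumulated) gradient passes per collect (2560 vs 40).
    print(ddp_schedule('legacy'))
    print(ddp_schedule('qwen'))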