4 changes: 3 additions & 1 deletion lzero/model/unizero_model.py
@@ -9,7 +9,9 @@
VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \
HFLanguageRepresentationNetwork
from .unizero_world_models.tokenizer import Tokenizer
from .unizero_world_models.world_model import WorldModel
# from .unizero_world_models.world_model import WorldModel
from .unizero_world_models.world_model_mamba2 import WorldModel
Collaborator


In the PR description, please briefly summarize the current status: which modules have been added, whether the pipeline runs end to end, and which TODOs remain.

from ding.utils import ENV_REGISTRY, set_pkg_seed, get_rank, get_world_size


# Use ModelRegistry to register the model. For more details about ModelRegistry, please refer to DI-engine's documentation.
116 changes: 116 additions & 0 deletions lzero/model/unizero_world_models/mamba.py
@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
import math
from dataclasses import dataclass, field
from typing import Optional, Tuple, List, Any

import torch
import torch.nn as nn
from torch.nn import functional as F
from ding.torch_utils.network import GRUGatingUnit # Keep if GRU gating is used outside Block
from einops import rearrange
from mamba_ssm import Mamba2
from mamba_ssm.utils.generation import InferenceParams

class Mamba(nn.Module):
"""
Mamba-based backbone intended as a drop-in replacement for the Transformer backbone in the UniZero architecture.

Arguments:
- config (:obj:`MambaConfig`): Configuration for the Mamba model.
"""

def __init__(self, config) -> None:
super().__init__()
self.config = config
self.embed_dim = config.embed_dim
self.drop = nn.Dropout(config.embed_pdrop)
self.blocks = nn.ModuleList()

for i in range(config.num_layers):
mamba_block = Mamba2(
d_model=config.embed_dim,
d_state=128,
d_conv=4,
expand=2,
headdim=64,
ngroups=1,
bias=False,
conv_bias=True,
chunk_size=256,
use_mem_eff_path=True,
layer_idx=i,
)
self.blocks.append(mamba_block)

self.ln_f = nn.LayerNorm(config.embed_dim)

def _get_device(self):
return self.ln_f.weight.device

def _get_dtype(self):
return self.ln_f.weight.dtype

def generate_empty_state(self,
batch_size: int,
max_seq_len: Optional[int] = None,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""
Allocate zero-initialized state tensors (conv_state, ssm_state) for every Mamba layer, used during inference.
"""
_device = self._get_device()
_dtype = self._get_dtype()
_max_seq_len = max_seq_len if max_seq_len is not None else getattr(self.config, 'max_seq_len', 2048)

all_layer_states = []
for mamba_layer in self.blocks:
conv_state, ssm_state = mamba_layer.allocate_inference_cache(
batch_size=batch_size,
max_seqlen=_max_seq_len,
dtype=_dtype
)
all_layer_states.append((conv_state.to(_device), ssm_state.to(_device)))
return all_layer_states


def forward(self, sequences: torch.Tensor, past_mamba_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
seqlen_offset: Optional[int] = 0) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
"""
Forward pass for training (full-sequence processing) or stateful inference.

Arguments:
- sequences (:obj:`torch.Tensor`): Input tensor of shape (B, L, D).
- past_mamba_states (:obj:`Optional[List[Tuple[torch.Tensor, torch.Tensor]]]`): Per-layer (conv_state, ssm_state)
caches from `generate_empty_state` or a previous call; when provided, the blocks run in stateful
inference mode and update these caches in place.
- seqlen_offset (:obj:`Optional[int]`): Number of tokens already consumed by the caches; 0 for a
full-sequence prefill, > 0 for single-step decoding (which requires L == 1).

Returns:
- Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: Output tensor of shape (B, L, D)
and the updated per-layer states (None when `past_mamba_states` is None).
"""
x = self.drop(sequences)
current_inference_params = None
if past_mamba_states is not None:
batch_size, cur_seq_len, _ = sequences.shape
current_inference_params = InferenceParams(
max_seqlen=cur_seq_len + seqlen_offset,
max_batch_size=batch_size,
seqlen_offset=seqlen_offset
)
for i in range(self.config.num_layers):
current_inference_params.key_value_memory_dict[i] = past_mamba_states[i]

for i, block in enumerate(self.blocks):
x = block(x, inference_params=current_inference_params)

x = self.ln_f(x)

updated_layer_states_list: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None
if current_inference_params is not None:
updated_layer_states_list = []
for i in range(self.config.num_layers):
updated_conv_state, updated_ssm_state = current_inference_params.key_value_memory_dict[i]
updated_layer_states_list.append((updated_conv_state, updated_ssm_state))

return x, updated_layer_states_list
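
A minimal usage sketch of this backbone, covering a prefill pass followed by a single-step decode (not part of the diff; `MambaConfig` is stood in by a SimpleNamespace carrying only the fields this file reads, and the mamba_ssm kernels assume a CUDA device):

# Sketch only: config object and dimensions are illustrative assumptions.
from types import SimpleNamespace
import torch
from lzero.model.unizero_world_models.mamba import Mamba

config = SimpleNamespace(embed_dim=256, num_layers=2, embed_pdrop=0.1, max_seq_len=2048)
model = Mamba(config).cuda().eval()

B, L = 4, 16
prefix = torch.randn(B, L, config.embed_dim, device='cuda')

with torch.no_grad():
    # Allocate per-layer (conv_state, ssm_state) caches, then prefill the prefix at offset 0.
    states = model.generate_empty_state(batch_size=B)
    out, states = model(prefix, past_mamba_states=states, seqlen_offset=0)

    # Single-token decoding step: the Mamba2 step path requires L == 1 and a positive offset.
    next_token = torch.randn(B, 1, config.embed_dim, device='cuda')
    out_step, states = model(next_token, past_mamba_states=states, seqlen_offset=L)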


7 changes: 5 additions & 2 deletions lzero/model/unizero_world_models/transformer.py
@@ -15,6 +15,7 @@
from einops import rearrange

from .kv_caching import KeysValues
from mamba_ssm import Mamba2


@dataclass
@@ -239,7 +240,8 @@ def __init__(self, config: TransformerConfig) -> None:

self.ln1 = nn.LayerNorm(config.embed_dim)
self.ln2 = nn.LayerNorm(config.embed_dim)
self.attn = SelfAttention(config)
# self.attn = SelfAttention(config)
self.attn = Mamba2(d_model=config.embed_dim, d_state=64, d_conv=4, expand=2)
self.mlp = nn.Sequential(
nn.Linear(config.embed_dim, 4 * config.embed_dim),
nn.GELU(approximate='tanh'),
@@ -261,7 +263,8 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None
Returns:
- torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
"""
x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, freqs_cis)
# x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, freqs_cis)
x_attn = self.attn(self.ln1(x))
if self.gru_gating:
x = self.gate1(x, x_attn)
x = self.gate2(x, self.mlp(self.ln2(x)))
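
For context, a shape-level sketch of the SelfAttention -> Mamba2 swap above (not part of the diff; embed_dim is an illustrative value). Note that Mamba2 keeps no KeysValues cache, so the past_keys_values, valid_context_lengths, and freqs_cis arguments of the original attention call are still accepted by Block.forward but no longer consumed:

# Shape smoke test for the swapped-in token mixer (assumes CUDA and mamba_ssm installed;
# embed_dim=256 is an assumed value, not taken from this PR).
import torch
import torch.nn as nn
from mamba_ssm import Mamba2

embed_dim = 256
ln1 = nn.LayerNorm(embed_dim).cuda()
mixer = Mamba2(d_model=embed_dim, d_state=64, d_conv=4, expand=2).cuda()

x = torch.randn(2, 32, embed_dim, device='cuda')
# Mamba2 maps (B, L, D) -> (B, L, D), so it is shape-compatible with the attention it replaces,
# but this code path reads and writes no KV cache.
x_attn = mixer(ln1(x))
assert x_attn.shape == x.shape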