-
Notifications
You must be signed in to change notification settings - Fork 174
feature(xjy): add mamba2 as a unizero backbone option #338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
xiongjyu
wants to merge
2
commits into
opendilab:main
Choose a base branch
from
xiongjyu:dev-unizero-mamba2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| # -*- coding: utf-8 -*- | ||
| import math | ||
| from dataclasses import dataclass, field | ||
| from typing import Optional, Tuple, List, Any | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| from torch.nn import functional as F | ||
| from ding.torch_utils.network import GRUGatingUnit # Keep if GRU gating is used outside Block | ||
| from einops import rearrange | ||
| from mamba_ssm import Mamba2 | ||
| from mamba_ssm.utils.generation import InferenceParams | ||
|
|
||
| class Mamba(nn.Module): | ||
| """ | ||
| Mamba-based model potentially for UniZero architecture. | ||
| Replaces the Transformer backbone. | ||
|
|
||
| Arguments: | ||
| - config (:obj:`MambaConfig`): Configuration for the Mamba model. | ||
| """ | ||
|
|
||
| def __init__(self, config) -> None: | ||
| super().__init__() | ||
| self.config = config | ||
| self.embed_dim = config.embed_dim | ||
| self.drop = nn.Dropout(config.embed_pdrop) | ||
| self.blocks = nn.ModuleList() | ||
|
|
||
| for i in range(config.num_layers): | ||
| mamba_block = Mamba2( | ||
| d_model=config.embed_dim, | ||
| d_state=128, | ||
| d_conv=4, | ||
| expand=2, | ||
| headdim=64, | ||
| ngroups=1, | ||
| bias=False, | ||
| conv_bias=True, | ||
| chunk_size=256, | ||
| use_mem_eff_path=True, | ||
| layer_idx=i, | ||
| ) | ||
| self.blocks.append(mamba_block) | ||
|
|
||
| self.ln_f = nn.LayerNorm(config.embed_dim) | ||
|
|
||
| def _get_device(self): | ||
| return self.ln_f.weight.device | ||
|
|
||
| def _get_dtype(self): | ||
| return self.ln_f.weight.dtype | ||
|
|
||
| def generate_empty_state(self, | ||
| batch_size: int, | ||
| max_seq_len: Optional[int] = None, | ||
| ) -> List[Tuple[torch.Tensor, torch.Tensor]]: | ||
| """ | ||
| 为所有 Mamba 层分配零初始化的状态张量 (conv_state, ssm_state),用于推理。 | ||
| """ | ||
| _device = self._get_device() | ||
| _dtype = self._get_dtype() | ||
| _max_seq_len = max_tokens if max_seq_len is not None else getattr(self.config, 'max_seq_len', 2048) | ||
|
|
||
| all_layer_states = [] | ||
| for mamba_layer in self.blocks: | ||
| conv_state, ssm_state = mamba_layer.allocate_inference_cache( | ||
| batch_size=batch_size, | ||
| max_seqlen=_max_seq_len, | ||
| dtype=_dtype | ||
| ) | ||
| all_layer_states.append((conv_state.to(_device), ssm_state.to(_device))) | ||
| return all_layer_states | ||
|
|
||
|
|
||
| def forward(self, sequences: torch.Tensor, past_mamba_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, | ||
| seqlen_offset: Optional[int] = 0) -> torch.Tensor: | ||
| """ | ||
| Forward pass for training or full sequence processing. | ||
|
|
||
| Arguments: | ||
| - sequences (:obj:`torch.Tensor`): Input tensor of shape (B, L, D) or (B*L, D) if seqlen is provided. | ||
| - seqlen (:obj:`Optional[int]`): Sequence length if input is flattened (B*L, D). | ||
| - inference_params (:obj:`Optional[Any]`): If provided, indicates potential step-by-step inference mode | ||
| (though `step` is preferred for that). Mamba2 forward might use it. | ||
|
|
||
| Returns: | ||
| - torch.Tensor: Output tensor, same shape principles as input `sequences`. | ||
| """ | ||
| x = self.drop(sequences) | ||
| current_inference_params = None | ||
| if past_mamba_states is not None: | ||
| batch_size, cur_seq_len, _ = sequences.shape | ||
| current_inference_params = InferenceParams( | ||
| max_seqlen=cur_seq_len + seqlen_offset, | ||
| max_batch_size=batch_size, | ||
| seqlen_offset=seqlen_offset | ||
| ) | ||
| for i in range(self.config.num_layers): | ||
| current_inference_params.key_value_memory_dict[i] = past_mamba_states[i] | ||
|
|
||
| for i, block in enumerate(self.blocks): | ||
| x = block(x, inference_params=current_inference_params) | ||
|
|
||
| x = self.ln_f(x) | ||
|
|
||
| updated_layer_states_list: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None | ||
| if current_inference_params is not None: | ||
| updated_layer_states_list = [] | ||
| for i in range(self.config.num_layers): | ||
| updated_conv_state, updated_ssm_state = current_inference_params.key_value_memory_dict[i] | ||
| updated_layer_states_list.append((updated_conv_state, updated_ssm_state)) | ||
|
|
||
| return x, updated_layer_states_list | ||
|
|
||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
描述中简单写一下目前的情况吧,加了哪些模块,流程是否跑通,还需要哪些todo等