Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions projects/mock_transformers/dist_infer_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
# limitations under the License.

import init_env # noqa
from typing import List, Optional, Tuple, Union
import oneflow as flow
import oneflow as torch
import oneflow.nn as nn
from omegaconf import DictConfig
from oneflow.utils.global_view import global_mode
from transformers import AutoModelForCausalLM, AutoTokenizer
Expand All @@ -37,6 +40,66 @@ def __init__(self, *args, **kwargs):
self.q_proj = Linear(embed_dim, embed_dim, bias=bias, parallel="col", dtype=flow.float16)
self.out_proj = Linear(embed_dim, embed_dim, bias=bias, parallel="row", dtype=flow.float16)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

fallback = key_value_states is not None or output_attentions or not self.is_decoder
if fallback:
return super().forward(
hidden_states,
key_value_states,
past_key_value,
attention_mask,
layer_head_mask,
output_attentions,
)
bsz, tgt_len, _ = hidden_states.size()

query_states, key_states, value_states = flow._C.grouped_matmul_bias(
[hidden_states, hidden_states, hidden_states],
[self.q_proj.weight, self.k_proj.weight, self.v_proj.weight],
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias],
)
past_key, past_value = (past_key_value[0], past_key_value[1]) if past_key_value is not None else (None, None)
key_states, value_states = flow._C.fused_attention_concat_past_key_value(
past_key=past_key,
past_key_layout="BHMK",
past_value=past_value,
past_value_layout="BHMK",
key=key_states,
key_layout="BM(HK)",
value=value_states,
value_layout="BM(HK)",
key_head_size=self.head_dim,
)

past_key_value = (key_states, value_states)

attn_output = flow._C.fused_multi_head_attention_inference_v2(
query=query_states,
query_layout="BM(HK)",
query_head_size=self.head_dim,
key=key_states,
key_layout="BHMK",
value=value_states,
value_layout="BHMK",
attn_mask_type="causal_from_bottom_right",
)
attn_output = self.out_proj(attn_output)

return attn_output, None, past_key_value


modeling_opt.OPTAttention = LiBaiOPTAttention

Expand Down