Tmp #200

5 changes: 2 additions & 3 deletions PyTorchSimFrontend/mlir/mlir_codegen_backend.py
@@ -628,11 +628,10 @@ def indirect_indexing(self, index_var, size, check=True):
def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index):
# In case of index expr, dimension size should be divisible by tile size
if not self.kernel_group.tile_desc.is_dim_dividable(self.ranges):
- new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges)
+ new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges, self.attempted_tile_sizes)
self.kernel_group.tile_desc.set_tile_size(new_tile_size)
self.reset("recompile")
raise mlir_common.RecompileSignal(f"Index access (tile size {self.kernel_group.tile_desc.get_tile_size()} is not divisible by {self.ranges})")

tile_size = tile_desc.get_tile_size_per_lane()
compute_vec_size = tile_desc.get_compute_vec_size()
strides = tile_desc.get_tile_stride_per_lane()
@@ -1277,7 +1276,7 @@ def convert_indirect_indexing(self, index :sympy.Expr):

# Note: In case of indirect indexing, dimensions should be divisible by tile size
if not self.kernel_group.tile_desc.is_dim_dividable(self.ranges):
- new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges)
+ new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges, self.attempted_tile_sizes)
self.kernel_group.tile_desc.set_tile_size(new_tile_size)
self.reset("recompile")
raise mlir_common.RecompileSignal(f"Indirect access (tile size {self.kernel_group.tile_desc.get_tile_size()} is not divisible by {self.ranges})")
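Both hunks follow the same pattern: adjust the tile via `adjust_tile_to_divisible` (now also passing the sizes already attempted), reset the kernel, and raise `mlir_common.RecompileSignal` so an outer driver can regenerate code with the new tiling under a retry bound (the new timing config's `codegen_autotune_max_retry: 10` hints at one). A minimal sketch of such a driver loop; `codegen_with_retry` and its arguments are illustrative names, not code from this PR:

```python
# Sketch of the retry loop implied by RecompileSignal (illustrative names;
# only RecompileSignal and attempted_tile_sizes come from the PR itself).
import mlir_common  # assumes PyTorchSimFrontend/mlir is on sys.path

def codegen_with_retry(kernel, nodes, kernel_name, max_retries=10):
    for attempt in range(max_retries):
        try:
            return kernel.codegen_nodes(nodes, kernel_name)
        except mlir_common.RecompileSignal as reason:
            # The tile descriptor was already adjusted (and the failed size
            # recorded in kernel.attempted_tile_sizes) before the raise, so
            # the next attempt compiles with a fresh tiling.
            print(f"recompile attempt {attempt + 1}: {reason}")
    raise RuntimeError("no divisible tile size found within the retry budget")
```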
24 changes: 19 additions & 5 deletions PyTorchSimFrontend/mlir/mlir_common.py
@@ -313,22 +313,32 @@ def is_dim_dividable(self, dim_sizes: list[int]) -> bool:

return all(d % t == 0 for d, t in zip(dim_sizes_cpy, self._tile_size))

- def adjust_tile_to_divisible(self, dim_sizes: list[int]) -> list[int]:
+ def adjust_tile_to_divisible(self, dim_sizes: list[int], attempted_tile_sizes) -> list[int]:
"""Adjust current tile to be divisible by given dimensions."""
if len(dim_sizes) != len(self._tile_size):
raise ValueError("dim_sizes must match the tile size dimensions")

- def _adjust_one(dim_size, tile_size):
+ dim_sizes_cpy = list(dim_sizes)
+ axis, stride = self.vmap.vlane_split_axis, self.vmap.vlane_stride
+ remain = dim_sizes_cpy[axis] % stride
+ if remain:
+ dim_sizes_cpy[axis] += stride - remain

+ def _adjust_one(dim_size, tile_size, is_split_dim, skip_size=[]):
for candidate in range(tile_size, 0, -1):
if dim_size % candidate == 0:
- return candidate
+ if is_split_dim:
+ remain = candidate % stride
+ candidate += (stride - remain) if remain else 0
+ if candidate not in skip_size:
+ return candidate
return 1

- candidate_tile_size = [_adjust_one(d, t) for d, t in zip(dim_sizes, self._tile_size)]
+ vlane_axis_skip_size = [dim[axis] for dim in attempted_tile_sizes]
+ candidate_tile_size = [_adjust_one(d, t, i==axis, vlane_axis_skip_size if i == axis else []) for i, (d, t) in enumerate(zip(dim_sizes_cpy, self._tile_size))]
for i in range(len(candidate_tile_size)):
self.tile_constraint[i].must_divide_dim = True

- axis, stride = self.vmap.vlane_split_axis, self.vmap.vlane_stride
remain = candidate_tile_size[axis] % stride

if remain:
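The hunk is cut off inside the stride fix-up, but the new per-axis search is compact: pad the dimension on the vlane-split axis up to a multiple of the lane stride, walk divisor candidates down from the current tile size, realign each candidate on that axis to the stride, and skip any size a previous compile attempt already tried. A standalone sketch of that logic, with illustrative names and a made-up example:

```python
# Per-axis tile adjustment, restated outside the class (illustrative names).
def pick_tile(dim_size, tile_size, is_split_dim, stride, skip_sizes=()):
    if is_split_dim and dim_size % stride:
        dim_size += stride - dim_size % stride          # pad dim to the lane stride
    for candidate in range(tile_size, 0, -1):
        if dim_size % candidate:
            continue                                    # candidate must divide the dim
        if is_split_dim and candidate % stride:
            candidate += stride - candidate % stride    # keep the split axis lane-aligned
        if candidate not in skip_sizes:
            return candidate                            # first size not tried before
    return 1

# Made-up example: dim 96, tile 16, stride 8. 16 divides 96 but was already
# attempted, so the search falls through to 8, the next usable divisor.
assert pick_tile(96, 16, True, 8, skip_sizes={16}) == 8
```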
@@ -609,6 +619,9 @@ def __init__(self, kernel_group, reason=None):
self.target_buffer_override = contextvars.ContextVar("Handler_compute_override", default=self.compute)
self.target_cse_override = contextvars.ContextVar("Handler_cse_override", default=self.cse)

+ # Track tile sizes attempted while compiling this kernel
+ self.attempted_tile_sizes = set()

def set_ranges(self, lengths, reduction_lengths):
if self.call_ranges:
assert self.call_ranges == tuple(lengths) + tuple(
@@ -761,6 +774,7 @@ def codegen_nodes(self, nodes, kernel_name):
# Set node range info
vars, reduction_vars = self.set_ranges(group, reduction_group)
tile_desc = self.compute_tile_size(nodes, vars, reduction_vars)
+ self.attempted_tile_sizes.add(tuple(tile_desc.get_tile_size()))
self.compute_body_loop.size = tile_desc.get_numel_per_lane()
self.compute_body_loop.step = tile_desc.get_compute_vec_size()
try:
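Each tiling the kernel actually tries is recorded as a tuple (lists are not hashable), and the vlane-axis sizes in that set become the skip list handed to `adjust_tile_to_divisible`, so a recompile cannot keep re-proposing a size that already failed. A toy illustration of the bookkeeping, with made-up sizes:

```python
# Toy illustration of the attempted-size bookkeeping; the sizes are made up,
# only the tuple/set pattern mirrors the PR.
attempted_tile_sizes = set()

first_try = [16, 16]
attempted_tile_sizes.add(tuple(first_try))     # tuple(): lists are not hashable

vlane_split_axis = 1
vlane_axis_skip = [t[vlane_split_axis] for t in attempted_tile_sizes]
assert vlane_axis_skip == [16]                 # excluded on the next adjustment
```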
30 changes: 30 additions & 0 deletions configs/systolic_ws_128x128_c2_simple_noc_tpuv3_timing.yml
@@ -0,0 +1,30 @@
num_cores: 2
core_freq_mhz: 940
core_stats_print_period_cycles: 10000
num_systolic_array_per_core: 2

vpu_num_lanes: 128
vpu_spad_size_kb_per_lane: 128
vpu_vector_length_bits: 256

dram_type: ramulator2
dram_freq_mhz: 940
dram_channels: 32
dram_req_size_byte: 32
dram_num_burst_length: 2
dram_stats_print_period_cycles: 10000
ramulator_config_path: ../configs/ramulator2_configs/HBM2_TPUv3.yaml

icnt_type: simple
icnt_latency_cycles: 10
icnt_freq_mhz: 940
icnt_injection_ports_per_core: 16

pytorchsim_functional_mode: 0
pytorchsim_timing_mode: 1

codegen_mapping_strategy: heuristic
codegen_external_mapping_file: ''
codegen_autotune_max_retry: 10
codegen_autotune_template_topk: 4
codegen_compiler_optimization: all
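For a quick sanity check before a timing run, the new config can be loaded with PyYAML and spot-checked; the key names below come from the file itself, while the checks and the run-from-repo-root path are illustrative:

```python
# Sketch: load the TPUv3-like timing config and spot-check a few fields.
import yaml

with open("configs/systolic_ws_128x128_c2_simple_noc_tpuv3_timing.yml") as f:
    cfg = yaml.safe_load(f)

assert cfg["num_cores"] == 2
assert cfg["pytorchsim_timing_mode"] == 1 and cfg["pytorchsim_functional_mode"] == 0
# Core, DRAM, and interconnect clocks are all 940 MHz in this configuration.
assert cfg["core_freq_mhz"] == cfg["dram_freq_mhz"] == cfg["icnt_freq_mhz"] == 940
print(f"{cfg['num_cores']} cores, {cfg['vpu_num_lanes']} VPU lanes, "
      f"{cfg['dram_channels']} DRAM channels")
```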
28 changes: 28 additions & 0 deletions tests/OPT/config.json
@@ -0,0 +1,28 @@
{
"_name_or_path": "opt-350m",
"activation_dropout": 0.0,
"activation_function": "relu",
"architectures": [
"OPTForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 2,
"do_layer_norm_before": false,
"dropout": 0.1,
"eos_token_id": 2,
"ffn_dim": 4096,
"hidden_size": 1024,
"init_std": 0.02,
"layerdrop": 0.0,
"max_position_embeddings": 2048,
"model_type": "opt",
"num_attention_heads": 16,
"num_hidden_layers": 24,
"pad_token_id": 1,
"prefix": "</s>",
"torch_dtype": "float16",
"transformers_version": "4.20.0.dev0",
"use_cache": true,
"vocab_size": 50272,
"word_embed_proj_dim": 512
}
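These values match the published OPT-350m configuration (hidden size 1024, 16 attention heads, 24 layers, word-embedding projection to 512), so the file can be consumed with plain `json`, or through `transformers.AutoConfig` if that library is installed; the checks below are illustrative:

```python
# Sketch: read the OPT-350m config and confirm the head geometry.
import json

with open("tests/OPT/config.json") as f:
    cfg = json.load(f)

assert cfg["hidden_size"] == 1024 and cfg["num_attention_heads"] == 16
assert cfg["hidden_size"] % cfg["num_attention_heads"] == 0   # head_dim = 64

# Optional, if transformers is available:
# from transformers import AutoConfig
# hf_cfg = AutoConfig.from_pretrained("tests/OPT")
```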
254 changes: 254 additions & 0 deletions tests/OPT/experiment_cpu.py
@@ -0,0 +1,254 @@
import torch
import torch.nn as nn

from typing import Optional

class LLM_Config:
def __init__(self,
embed_dim,
hidden_size,
num_heads,
ffn_dim,
vocab_size,
word_embed_proj_dim,
pad_token_id,
max_position_embeddings,
enable_bias,
layer_norm_elementwise_affine,
do_layer_norm_before):
self.embed_dim = embed_dim
self.hidden_size = hidden_size
self.num_heads = num_heads
self.ffn_dim = ffn_dim
self.vocab_size = vocab_size
self.word_embed_proj_dim = word_embed_proj_dim
self.pad_token_id = pad_token_id
self.max_position_embeddings = max_position_embeddings
self.enable_bias = enable_bias
self.layer_norm_elementwise_affine = layer_norm_elementwise_affine
self.do_layer_norm_before = do_layer_norm_before



class OPTLearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""

def __init__(self, num_embeddings: int, embedding_dim: int):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)

def forward(
self,
attention_mask: torch.LongTensor,
past_key_values_length: int = 0,
position_ids: Optional[torch.LongTensor] = None,
):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""

if position_ids is None:
position_ids = torch.cumsum(attention_mask, dim=1)
position_ids = (position_ids * attention_mask - 1).long()
# cut positions if `past_key_values_length` is > 0
position_ids = position_ids[:, past_key_values_length:]

return super().forward(position_ids + self.offset)

class my_opt_decoder(nn.Module):
def __init__(self, config: LLM_Config, current_seq_len, bsz):
super(my_opt_decoder, self).__init__()
self.config = config

self.head_dim = self.config.embed_dim // self.config.num_heads
self.scaling = self.head_dim**-0.5

# Embedding layers
self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.word_embed_proj_dim, self.config.pad_token_id)
self.embed_positions = OPTLearnedPositionalEmbedding(self.config.max_position_embeddings, config.hidden_size)
self.project_in = nn.Linear(self.config.word_embed_proj_dim, self.config.hidden_size, bias=False)

# KV cache buffers; the batch size comes from the constructor and the head
# count from the config, so the module also works when imported (not only
# when run as a script where bsz/num_heads happen to be module globals)
self.register_buffer(
"past_k",
torch.randn(bsz, self.config.num_heads, current_seq_len, self.head_dim)
)
self.register_buffer(
"past_v",
torch.randn(bsz, self.config.num_heads, current_seq_len, self.head_dim)
)


# QKV layers
self.k_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias)
self.v_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias)
self.q_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias)
self.o_proj = nn.Linear(self.config.embed_dim, self.config.embed_dim, bias=self.config.enable_bias)

self.self_attn_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)

# FC layers
self.activation_fn = nn.ReLU()
self.fc1 = nn.Linear(self.config.embed_dim, config.ffn_dim, bias=config.enable_bias)
self.fc2 = nn.Linear(config.ffn_dim, self.config.embed_dim, bias=config.enable_bias)
self.final_layer_norm = nn.LayerNorm(self.config.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)

# LM head
self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
self.lm_head_linear = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)

def embed(self, input_ids):
# input_ids: (bsz, seq_len)
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = self.project_in(inputs_embeds)
bsz, seq_len, _ = inputs_embeds.size()
attention_mask = (input_ids != self.config.pad_token_id).long()
position_embeds = self.embed_positions(attention_mask=attention_mask)
hidden_states = inputs_embeds + position_embeds
return hidden_states

# qkv + layer norm
def qkv(self, hidden_states):
self.residual = hidden_states
self.bsz, self.tgt_len, _ = hidden_states.size()

if self.config.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)

query_states = self.q_proj(hidden_states) # scaling is applied once inside attn(), so do not pre-scale here
query_states = query_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2)

key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
key_states = key_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(self.bsz, -1, self.config.num_heads, self.head_dim).transpose(1, 2)

return query_states, key_states, value_states

# QK^T + SV
def attn(self, query, key, value, attention_mask, scaling, dropout, **kwargs):
# KV cache update
self.past_k = torch.cat([self.past_k, key], dim=2)
self.past_v = torch.cat([self.past_v, value], dim=2)

key = self.past_k
value = self.past_v

attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False)

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output


# out-proj + layer norm
def out_proj(self, attn_output):
attn_output = attn_output.reshape(self.bsz, self.tgt_len, -1).contiguous()
attn_output = self.o_proj(attn_output)

attn_output = nn.functional.dropout(attn_output, p=0.0, training=False)
attn_output = self.residual + attn_output

# 350m applies layer norm AFTER attention
if not self.config.do_layer_norm_before:
attn_output = self.self_attn_layer_norm(attn_output)

return attn_output

# MLP + layer norm
def ffn(self, hidden_states):
hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
residual = hidden_states

# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.config.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)

hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)

hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=0.0, training=False)

hidden_states = (residual + hidden_states).view(hidden_states_shape)

# 350m applies layer norm AFTER attention
if not self.config.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)

outputs = (hidden_states,)

return outputs

def lm_head(self, outputs, logits_to_keep):
hidden_states = outputs[0]
hidden_states = self.project_out(hidden_states)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head_linear(hidden_states[:, slice_indices, :]).contiguous()

return logits

def forward(self, x):
hidden = self.embed(x)
print(f"after embed hidden shape: {hidden.shape}")
q, k, v = self.qkv(hidden)
print(f"q shape: {q.shape}, k shape: {k.shape}, v shape: {v.shape}")
attn_output = self.attn(q, k, v, None, self.scaling, 0.0)
print(f"attn_output shape: {attn_output.shape}")
attn_output = self.out_proj(attn_output)
print(f"after out_proj attn_output shape: {attn_output.shape}")
outputs = self.ffn(attn_output)
print(f"after ffn outputs[0] shape: {outputs[0].shape}")
logits = self.lm_head(outputs, 1)
print(f"lm_head logits shape: {logits.shape}")
return logits



if __name__ == "__main__":

embed_dim = 1024
hidden_size = 1024
num_heads = 16
ffn_dim = 4096
vocab_size = 50272
word_embed_proj_dim = 512
pad_token_id = 1
max_position_embeddings = 2048
enable_bias = True
layer_norm_elementwise_affine = True
do_layer_norm_before = False

bsz = 128
seq_len = 1128

config = LLM_Config(embed_dim = embed_dim,
hidden_size = hidden_size,
num_heads = num_heads,
ffn_dim = ffn_dim,
vocab_size = vocab_size,
word_embed_proj_dim = word_embed_proj_dim,
pad_token_id = pad_token_id,
max_position_embeddings = max_position_embeddings,
enable_bias = enable_bias,
layer_norm_elementwise_affine = layer_norm_elementwise_affine,
do_layer_norm_before = do_layer_norm_before)
decoder = my_opt_decoder(config, seq_len, bsz)
decoder.eval()

input = torch.randint(0, vocab_size, (bsz, 1)) # (bsz, 1): one new token per sequence (decode step)

with torch.no_grad():
decoder(input)
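The script only prints shapes, so the expected sizes for this single-token decode step are worth stating explicitly: one new token per sequence is scored against a 1128-token KV cache, yielding one logit row over the vocabulary. A sketch of the shape arithmetic under the hyperparameters above (illustrative, not part of the script):

```python
# Expected tensor shapes for the decode step driven by experiment_cpu.py.
bsz, cached_len, new_tokens = 128, 1128, 1
num_heads, hidden_size, vocab_size = 16, 1024, 50272
head_dim = hidden_size // num_heads                             # 64

q_shape = (bsz, num_heads, new_tokens, head_dim)                # query for the new token
kv_shape = (bsz, num_heads, cached_len + new_tokens, head_dim)  # cache after torch.cat
attn_weights_shape = (bsz, num_heads, new_tokens, cached_len + new_tokens)
logits_shape = (bsz, new_tokens, vocab_size)                    # lm_head output

print(q_shape, kv_shape, attn_weights_shape, logits_shape)
```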