
Commit cfbbdd2

add starcoder2
1 parent c8d6bdc commit cfbbdd2

6 files changed: +660 -0 lines changed


src/transformers_neuronx/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@
 from transformers_neuronx.mistral.model import MistralForSampling
 from transformers_neuronx.mixtral.model import MixtralForSampling
 from transformers_neuronx.opt.model import OPTForSampling
+from transformers_neuronx.starcoder2.model import Starcoder2ForSampling

 from transformers_neuronx.modeling_auto import NeuronAutoModelForCausalLM

src/transformers_neuronx/modeling_auto.py

Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@
     "mistral": transformers_neuronx.MistralForSampling,
     "mixtral": transformers_neuronx.MixtralForSampling,
     "opt": transformers_neuronx.OPTForSampling,
+    "starcoder2": transformers_neuronx.Starcoder2ForSampling,
 }

@@ -24,6 +25,7 @@
     transformers.MistralConfig: "mistral",
     transformers.MixtralConfig: "mixtral",
     transformers.OPTConfig: "opt",
+    transformers.Starcoder2Config: "starcoder2",
 }
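
With this mapping in place, a StarCoder2 checkpoint can be loaded through the auto class exported from the package root (or directly via Starcoder2ForSampling). A minimal sketch, assuming the from_pretrained/to_neuron convention used by the other models in this package; the checkpoint name and compile settings below are illustrative and not part of this commit:

from transformers_neuronx import NeuronAutoModelForCausalLM

# Illustrative checkpoint and settings; adjust batch_size/tp_degree/amp for your instance.
model = NeuronAutoModelForCausalLM.from_pretrained(
    "bigcode/starcoder2-7b",
    batch_size=1,
    tp_degree=2,
    amp="f16",
)
model.to_neuron()  # compile the graphs and load them onto the NeuronCores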

src/transformers_neuronx/starcoder2/config.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
from transformers_neuronx import utils


class Starcoder2Config:
    def __init__(
        self,
        config,
        n_positions,
        batch_size,
        amp,
        tp_degree,
        **kwargs
    ):
        # Extract configs used for building HLO
        self.intermediate_size = config.intermediate_size
        self.hidden_size = config.hidden_size

        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads
        self.num_hidden_layers = config.num_hidden_layers
        self.vocab_size = config.vocab_size
        self.hidden_act = config.hidden_act
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        self.max_position_embeddings = config.max_position_embeddings
        self.rms_norm_eps = config.norm_epsilon
        self.rotary_percentage = getattr(config, "rotary_percentage", 1)
        self.rope_theta = getattr(config, "rope_theta", 10000)
        self.position_interpolation_factor = getattr(config, "position_interpolation_factor", None)
        self.use_bias = getattr(config, "use_bias", True)
        utils.maybe_override_attributes(self, kwargs)

        # Add required Neuron configs
        self.n_positions = n_positions
        self.batch_size = batch_size
        self.amp = amp
        self.tp_degree = tp_degree
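
For reference, the wrapper above takes a Hugging Face StarCoder2 config object plus the Neuron compile-time settings. A minimal sketch; the checkpoint name and values are illustrative, not taken from this commit:

from transformers import AutoConfig
from transformers_neuronx.starcoder2.config import Starcoder2Config

hf_config = AutoConfig.from_pretrained("bigcode/starcoder2-7b")  # illustrative checkpoint
neuron_cfg = Starcoder2Config(
    hf_config,
    n_positions=2048,  # maximum sequence length to compile for
    batch_size=1,
    amp="f16",         # on-device data type
    tp_degree=2,       # tensor-parallel degree (number of NeuronCores)
)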
Lines changed: 261 additions & 0 deletions

@@ -0,0 +1,261 @@
from typing import Optional

from transformers_neuronx import constants
from transformers_neuronx import hlo
from transformers_neuronx import utils
from transformers_neuronx.config import NeuronConfig
from transformers_neuronx.constants import LAYOUT_HSB
from transformers_neuronx.hlo import mlp
from transformers_neuronx.layers import transformer, rotary, attention, attention_utils, flash_decoding
from transformers_neuronx.starcoder2.config import Starcoder2Config


class Starcoder2ForSamplingNoEmbeddingHlo:

    def __init__(self,
                 config: Starcoder2Config,
                 neuron_config: Optional[NeuronConfig] = None
                 ):
        self.config = config
        self.neuron_config = neuron_config
        self.n_positions = None

    @property
    def shard_over_batch(self):
        # Property access allows fallback configuration to be enabled after construction
        return (
            self.neuron_config is not None
            and self.neuron_config.group_query_attention == constants.GQA.SHARD_OVER_BATCH
        )

    def inputs(self, scribe, dtype, n_active_tokens, batch_size):
        tensors, dims = transformer.inputs(
            scribe, dtype, batch_size, n_active_tokens, self.config.hidden_size, self.neuron_config)

        return tensors, dims

    def embedding(self, input_ids, cache_ids, start_ids, last_token_id, embed_weight):
        dtype = getattr(input_ids.scribe, self.config.amp)
        hidden = hlo.embedding(embed_weight, input_ids, tp_degree=self.config.tp_degree, dtype=dtype)
        if self.config.hidden_size % self.config.tp_degree != 0:
            hidden = hlo.slice_along(hidden, dim=-1, limit=self.config.hidden_size, start=0)
        if self.neuron_config.attention_layout == LAYOUT_HSB:
            hidden = hlo.transpose210(hidden)
        return hidden

    def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *weights):
        head_dim = self.config.attention_head_size
        pos_embed = rotary.hlo_rotary_embedding(
            hidden.dtype, int(head_dim * self.config.rotary_percentage), cache_ids,
            base=self.config.rope_theta,
            interpolation_factor=self.config.position_interpolation_factor
        )
        mask, active_mask = hlo.attention_mask(cache_ids, start_ids, self.n_positions)
        core_id = None
        if self.neuron_config.shard_over_sequence:
            core_id, *rst = weights
            n_kv_heads = self.config.num_key_value_heads if self.config.num_key_value_heads else self.config.num_attention_heads
            cores_per_kv_head = self.config.tp_degree // n_kv_heads
            self.cores_per_kv_head = cores_per_kv_head if cores_per_kv_head > 1 else self.config.tp_degree
            cache_ids, mask, active_mask = flash_decoding.convert_attn_mask_and_cache_id(
                cache_ids, core_id, self.n_positions, cores_per_kv_head=self.cores_per_kv_head)

        return hidden, last_token_id, pos_embed, cache_ids, start_ids, mask, active_mask, core_id

    def layer(
            self, hidden, last_token_id, pos_embed, cache_ids, start_ids, mask, active_mask, core_id,
            attn_k_cache, attn_v_cache,
            pre_attn_ln_weight, pre_attn_ln_bias,
            attn_q_weight, attn_q_scales, attn_q_bias,
            attn_k_weight, attn_k_scales, attn_k_bias,
            attn_v_weight, attn_v_scales, attn_v_bias,
            attn_out_weight, attn_out_scales, attn_out_bias,
            post_attn_ln_weight, post_attn_ln_bias,
            pre_mlp_ln_weight, pre_mlp_ln_bias,
            mlp_in_weight, mlp_in_scales, mlp_in_bias,
            mlp_out_weight, mlp_out_scales, mlp_out_bias,
            post_mlp_ln_weight, post_mlp_ln_bias,
    ):
        # eps = self.config.rms_norm_eps
        # is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH
        ln_hidden = hlo.layer_norm(hidden, pre_attn_ln_weight, pre_attn_ln_bias)

        attn_output, out_attn_k_cache, out_attn_v_cache = self.attention(
            ln_hidden, cache_ids, start_ids, pos_embed, mask, active_mask, core_id,
            attn_k_cache, attn_v_cache,
            attn_q_weight, attn_q_scales, attn_q_bias,
            attn_k_weight, attn_k_scales, attn_k_bias,
            attn_v_weight, attn_v_scales, attn_v_bias,
            attn_out_weight, attn_out_scales, attn_out_bias
        )
        hidden = hlo.add(attn_output, hidden)

        norm_hidden = hlo.layer_norm(hidden, pre_mlp_ln_weight, pre_mlp_ln_bias)
        mlp_hidden = mlp(
            norm_hidden,
            mlp_in_weight, mlp_in_bias, mlp_out_weight, mlp_out_bias,
            activation_function='gelu_new',  # 'gelu_pytorch_tanh'
            tp_degree=self.config.tp_degree,
            neuron_config=self.neuron_config
        )
        res_hidden = hlo.add(mlp_hidden, hidden)
        return res_hidden, out_attn_k_cache, out_attn_v_cache

    def ln_lm_head(self, hidden, last_token_id, rms_weight, unused_bias, lm_head_weight, lm_head_bias,
                   return_all_outputs=True):
        logits = transformer.rms_lm_head(self.config.tp_degree, hidden, last_token_id, rms_weight, lm_head_weight,
                                         lm_head_bias, return_all_outputs, eps=self.config.rms_norm_eps,
                                         neuron_config=self.neuron_config)
        return logits

    def attention(
            self,
            hidden, cache_ids, start_ids, pos_embed, mask, active_mask, core_id,
            cached_keys, cached_values,
            q_weight, q_scales, q_bias,
            k_weight, k_scales, k_bias,
            v_weight, v_scales, v_bias,
            out_weight, out_scales, out_bias,
    ):
        d_head = self.config.attention_head_size
        tp_degree = self.config.tp_degree

        # Compute the expected number of KV heads (Used in case fused QKV is used)
        n_kv_heads_tp = None
        if self.config.num_key_value_heads is not None:
            n_head = self.config.num_attention_heads
            n_kv_head = self.config.num_key_value_heads
            _, n_kv_head_padded = utils.get_qkv_padding(n_head, n_kv_head, tp_degree, self.neuron_config)
            n_kv_heads_tp = n_kv_head_padded // tp_degree

        # Q = (hidden @ wQ) + bQ
        # K = (hidden @ wK) + bK
        # V = (hidden @ wV) + bV
        query, key, value = attention.query_key_value(
            hidden,
            q_weight, q_scales, q_bias,
            k_weight, k_scales, k_bias,
            v_weight, v_scales, v_bias,
            d_head,
            neuron_config=self.neuron_config,
            tp_degree=tp_degree,  # TODO: include tp_degree into neuron_config
            shard_over_batch=self.shard_over_batch,
            n_kv_heads_tp=n_kv_heads_tp,
        )

        # Q = Rotate(Q)
        # K = Rotate(K)
        query, key = rotary.rotate_half(query, key, pos_embed, self.config.rotary_percentage,
                                        tp_degree=tp_degree, shard_over_batch=self.shard_over_batch)

        # Q = Q / sqrt(d_head)
        query = attention.scale(query, d_head)

        # In BSH cache layout, the output of QKV linear projection is still kept as SBH for all QKV.
        bsh_cache_layout = False
        batch_dim = 1
        if self.neuron_config is not None:
            bsh_cache_layout = self.neuron_config.cache_layout == constants.LAYOUT_BSH
        if bsh_cache_layout:
            query, key, value = attention_utils.transpose_qkv(query, key, value)
            batch_dim = 0

        # Single Token Generation ("Prefetch"-style) and speculative forward
        if active_mask is not None:

            n_active_tokens = key.sizes[1] if bsh_cache_layout else key.sizes[0]
            if n_active_tokens > 1 and self.neuron_config and self.neuron_config.continuous_batching:
                # For speculative forward + continuous batching, slice out samples in the batch size
                # corresponding to the batch size of the speculative head
                slice_sizes = [1] * len(cached_keys.sizes)
                if cached_keys.sizes[batch_dim] == 1:
                    # Use hlo.select for batch size 1 as index select is prohibitively slow
                    # TODO: revert to hlo.index_select once its faster P126527643
                    cached_keys_s = hlo.select(cached_keys, batch_dim, hlo.reshape(start_ids, slice_sizes),
                                               keepdim=True)
                    cached_values_s = hlo.select(cached_values, batch_dim, hlo.reshape(start_ids, slice_sizes),
                                                 keepdim=True)
                else:
                    cached_keys_s = hlo.index_select(cached_keys, batch_dim, start_ids)
                    cached_values_s = hlo.index_select(cached_values, batch_dim, start_ids)
            else:
                cached_keys_s = cached_keys
                cached_values_s = cached_values
            # Communication 1: all-gather query from cores
            if (n_active_tokens != self.n_positions) and self.neuron_config.shard_over_sequence:
                query = flash_decoding.gather_query_group(query, self.cores_per_kv_head,
                                                          self.config.num_attention_heads,
                                                          tp_degree)

            # Sp = Q @ Kp
            prior_scores = attention.score(query, cached_keys_s, n_kv_heads=self.config.num_key_value_heads,
                                           tp_degree=tp_degree, neuron_config=self.neuron_config)
            prior_scores = attention.mask(prior_scores, mask, tp_degree=tp_degree,
                                          shard_over_batch=self.shard_over_batch)

            # Sa = Q @ Ka
            active_score = attention.score(query, key, n_kv_heads=self.config.num_key_value_heads,
                                           tp_degree=tp_degree, neuron_config=self.neuron_config)
            active_score = attention.mask(active_score, active_mask, tp_degree=tp_degree,
                                          shard_over_batch=self.shard_over_batch)

            # C = softmax(Sa, Sp) @ (Va, Vp)
            if self.neuron_config.shard_over_sequence:
                dtype = query.dtype
                context = flash_decoding.context(prior_scores, active_score, cached_values_s, value, core_id, mask,
                                                 active_mask,
                                                 n_kv_heads=self.config.num_key_value_heads,
                                                 n_heads=self.config.num_attention_heads, dtype=dtype,
                                                 tp_degree=tp_degree, neuron_config=self.neuron_config,
                                                 shard_over_batch=self.shard_over_batch)
                cache_ids, value, key = flash_decoding.select_values_within_bound(cache_ids, value, key,
                                                                                  self.cores_per_kv_head, core_id,
                                                                                  dim=0)

            else:
                context = attention.context(prior_scores, active_score, cached_values_s, value,
                                            n_kv_heads=self.config.num_key_value_heads, tp_degree=tp_degree,
                                            neuron_config=self.neuron_config)

            # KCache[I], VCache[I] = K, V
            updated_keys, updated_values = attention.fused_kv_update_cache(cached_keys, cached_values, cache_ids,
                                                                           key, value, start_ids,
                                                                           neuron_config=self.neuron_config)

        # Multi-Token Context Encoding
        else:
            _, batch_size, _, _ = query.sizes
            if self.neuron_config.lhs_aligned or batch_size == 1:
                context = attention.flash_attention(query, key, value)
            else:
                # do not use flash attention for lhs padded (right aligned) batch > 1 case
                # because it does not correctly take mask into account
                context = None

            if context is None:
                # S = Q @ K
                score = attention.score(query, key, n_kv_heads=self.config.num_key_value_heads,
                                        tp_degree=tp_degree, neuron_config=self.neuron_config)
                score = attention.mask(score, mask, tp_degree=tp_degree, shard_over_batch=self.shard_over_batch)
                context = attention.context_combined(score, value, n_kv_heads=self.config.num_key_value_heads,
                                                     tp_degree=tp_degree, neuron_config=self.neuron_config)

            if self.neuron_config.shard_over_sequence:
                cache_ids, value, key = flash_decoding.select_values_within_bound(cache_ids,
                                                                                  value,
                                                                                  key,
                                                                                  self.cores_per_kv_head,
                                                                                  core_id, dim=0)
            # KCache, VCache = K, V
            if cached_keys.sizes == key.sizes:
                updated_keys, updated_values = key, value
            else:
                updated_keys, updated_values = attention.fused_kv_update_cache(cached_keys, cached_values, cache_ids,
                                                                               key, value, start_ids,
                                                                               neuron_config=self.neuron_config)

        # O = (C @ wO) + bO
        output = attention.output(context, out_weight, out_scales, out_bias, tp_degree, self.neuron_config)
        return output, updated_keys, updated_values
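
To make the dataflow of layer() and attention() above easier to follow, here is a plain-Python reference of the same structure (pre-attention LayerNorm, rotary self-attention, residual add, pre-MLP LayerNorm, GELU MLP, residual add). The callables are stand-ins for the HLO helpers; this sketch is explanatory only and is not part of the commit:

def decoder_layer_reference(hidden, ln1, self_attention, ln2, gelu_mlp):
    # hidden -> LayerNorm -> self-attention (rotary Q/K, scaled scores) -> residual add
    hidden = hidden + self_attention(ln1(hidden))
    # hidden -> LayerNorm -> MLP with GELU activation -> residual add
    hidden = hidden + gelu_mlp(ln2(hidden))
    return hidden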

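End to end, the new class is expected to be used like the other *ForSampling models in this package. A hedged sketch; the from_pretrained/to_neuron/sample signatures, the sampling arguments, and the checkpoint name are assumptions based on those existing models and are not shown in this diff:

import torch
from transformers import AutoTokenizer
from transformers_neuronx import Starcoder2ForSampling

tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-7b")  # illustrative checkpoint
model = Starcoder2ForSampling.from_pretrained(
    "bigcode/starcoder2-7b", batch_size=1, tp_degree=2, amp="f16", n_positions=2048
)
model.to_neuron()  # build and compile the HLO graphs defined above

input_ids = tokenizer("def fibonacci(n):", return_tensors="pt").input_ids
with torch.inference_mode():
    generated = model.sample(input_ids, sequence_length=128, top_k=50)
print(tokenizer.decode(generated[0]))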