
Commit 88ec228

add starcoder2
1 parent c8d6bdc commit 88ec228

7 files changed (+622, -4 lines)

src/transformers_neuronx/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
 from transformers_neuronx.mistral.model import MistralForSampling
 from transformers_neuronx.mixtral.model import MixtralForSampling
 from transformers_neuronx.opt.model import OPTForSampling
+from transformers_neuronx.starcoder2.model import Starcoder2ForSampling
 
 from transformers_neuronx.modeling_auto import NeuronAutoModelForCausalLM

src/transformers_neuronx/modeling_auto.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
     "mistral": transformers_neuronx.MistralForSampling,
     "mixtral": transformers_neuronx.MixtralForSampling,
     "opt": transformers_neuronx.OPTForSampling,
+    "starcoder2": transformers_neuronx.Starcoder2ForSampling,
 }

@@ -24,6 +25,7 @@
     transformers.MistralConfig: "mistral",
     transformers.MixtralConfig: "mixtral",
     transformers.OPTConfig: "opt",
+    transformers.Starcoder2Config: "starcoder2",
 }
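Together with the new export in __init__.py above, these two table entries let NeuronAutoModelForCausalLM map a checkpoint's transformers.Starcoder2Config to the "starcoder2" model type and dispatch to Starcoder2ForSampling. A minimal usage sketch, assuming the usual transformers_neuronx from_pretrained / to_neuron / sample flow; the checkpoint name and compile settings below are illustrative, not part of this commit:

import torch
from transformers import AutoTokenizer
from transformers_neuronx import NeuronAutoModelForCausalLM

checkpoint = "bigcode/starcoder2-7b"  # illustrative checkpoint name
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# The auto class resolves Starcoder2Config -> "starcoder2" -> Starcoder2ForSampling
# through the lookup tables extended in this commit.
model = NeuronAutoModelForCausalLM.from_pretrained(
    checkpoint, batch_size=1, tp_degree=2, amp='f16', n_positions=2048,
)
model.to_neuron()  # trace, compile, and load the model onto NeuronCores

input_ids = tokenizer("def fibonacci(n):", return_tensors="pt").input_ids
with torch.inference_mode():
    generated = model.sample(input_ids, sequence_length=256, top_k=50)
print(tokenizer.decode(generated[0], skip_special_tokens=True))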

src/transformers_neuronx/starcoder2/config.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+from transformers_neuronx import utils
+
+
+class Starcoder2Config:
+    def __init__(
+            self,
+            config,
+            n_positions,
+            batch_size,
+            amp,
+            tp_degree,
+            **kwargs
+    ):
+        # Extract configs used for building HLO
+        self.intermediate_size = config.intermediate_size
+        self.hidden_size = config.hidden_size
+
+        self.attention_head_size = config.hidden_size // config.num_attention_heads
+        self.num_attention_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads if hasattr(config,
+                                                                         "num_key_value_heads") else config.num_attention_heads
+        self.num_hidden_layers = config.num_hidden_layers
+        self.vocab_size = config.vocab_size
+        self.hidden_act = config.hidden_act
+        self.bos_token_id = config.bos_token_id
+        self.eos_token_id = config.eos_token_id
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rms_norm_eps = config.norm_epsilon
+        self.rotary_percentage = getattr(config, "rotary_percentage", 1)
+        self.rope_theta = getattr(config, "rope_theta", 10000)
+        self.position_interpolation_factor = getattr(config, "position_interpolation_factor", None)
+        self.use_bias = getattr(config, "use_bias", True)
+        utils.maybe_override_attributes(self, kwargs)
+
+        # Add required Neuron configs
+        self.n_positions = n_positions
+        self.batch_size = batch_size
+        self.amp = amp
+        self.tp_degree = tp_degree
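For reference, a short sketch of how this wrapper is meant to be constructed: the first argument is the Hugging Face Starcoder2 config, the remaining arguments are the Neuron compile-time settings, and any kwarg naming an already-extracted attribute is applied afterwards by utils.maybe_override_attributes. The checkpoint name and values below are illustrative assumptions, not part of the commit:

from transformers import AutoConfig
from transformers_neuronx.starcoder2.config import Starcoder2Config

hf_config = AutoConfig.from_pretrained("bigcode/starcoder2-7b")  # illustrative checkpoint

config = Starcoder2Config(
    hf_config,
    n_positions=2048,        # maximum sequence length the HLO graphs are built for
    batch_size=1,
    amp='f16',               # on-device dtype
    tp_degree=2,             # tensor-parallel degree (number of NeuronCores)
    rope_theta=1_000_000.0,  # example kwarg override applied by maybe_override_attributes
)

assert config.attention_head_size == hf_config.hidden_size // hf_config.num_attention_heads

Note that Starcoder2 uses LayerNorm, so the epsilon is read from the checkpoint's norm_epsilon field but stored here as rms_norm_eps, matching the rms_lm_head call in the HLO builder below.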
Lines changed: 264 additions & 0 deletions
@@ -0,0 +1,264 @@
+from typing import Optional
+
+from transformers_neuronx import constants
+from transformers_neuronx import hlo
+from transformers_neuronx import utils
+from transformers_neuronx.config import NeuronConfig
+from transformers_neuronx.constants import LAYOUT_HSB, LAYOUT_BSH
+from transformers_neuronx.hlo import mlp, mlp_bsh
+from transformers_neuronx.layers import transformer, rotary, attention, attention_utils, flash_decoding
+from transformers_neuronx.starcoder2.config import Starcoder2Config
+
+
+class Starcoder2ForSamplingNoEmbeddingHlo:
+
+    def __init__(self,
+                 config: Starcoder2Config,
+                 neuron_config: Optional[NeuronConfig] = None
+                 ):
+        self.config = config
+        self.neuron_config = neuron_config
+        self.n_positions = None
+
+    @property
+    def shard_over_batch(self):
+        # Property access allows fallback configuration to be enabled after construction
+        return (
+            self.neuron_config is not None
+            and self.neuron_config.group_query_attention == constants.GQA.SHARD_OVER_BATCH
+        )
+
+    def inputs(self, scribe, dtype, n_active_tokens, batch_size):
+        tensors, dims = transformer.inputs(
+            scribe, dtype, batch_size, n_active_tokens, self.config.hidden_size, self.neuron_config)
+
+        return tensors, dims
+
+    def embedding(self, input_ids, cache_ids, start_ids, last_token_id, embed_weight):
+        dtype = getattr(input_ids.scribe, self.config.amp)
+        hidden = hlo.embedding(embed_weight, input_ids, tp_degree=self.config.tp_degree, dtype=dtype)
+        if self.config.hidden_size % self.config.tp_degree != 0:
+            hidden = hlo.slice_along(hidden, dim=-1, limit=self.config.hidden_size, start=0)
+        if self.neuron_config.attention_layout == LAYOUT_HSB:
+            hidden = hlo.transpose210(hidden)
+        return hidden
+
+    def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *weights):
+        head_dim = self.config.attention_head_size
+        pos_embed = rotary.hlo_rotary_embedding(
+            hidden.dtype, int(head_dim * self.config.rotary_percentage), cache_ids,
+            base=self.config.rope_theta,
+            interpolation_factor=self.config.position_interpolation_factor
+        )
+        mask, active_mask = hlo.attention_mask(cache_ids, start_ids, self.n_positions)
+        core_id = None
+        if self.neuron_config.shard_over_sequence:
+            core_id, *rst = weights
+            n_kv_heads = self.config.num_key_value_heads if self.config.num_key_value_heads else self.config.num_attention_heads
+            cores_per_kv_head = self.config.tp_degree // n_kv_heads
+            self.cores_per_kv_head = cores_per_kv_head if cores_per_kv_head > 1 else self.config.tp_degree
+            cache_ids, mask, active_mask = flash_decoding.convert_attn_mask_and_cache_id(cache_ids,
+                                                                                         core_id, self.n_positions,
+                                                                                         cores_per_kv_head=self.cores_per_kv_head)
+
+        return hidden, last_token_id, pos_embed, cache_ids, start_ids, mask, active_mask, core_id
+
+    def layer(
+            self, hidden, last_token_id, pos_embed, cache_ids, start_ids, mask, active_mask, core_id,
+            attn_k_cache, attn_v_cache,
+            pre_attn_ln_weight, pre_attn_ln_bias,
+            attn_q_weight, attn_q_scales, attn_q_bias,
+            attn_k_weight, attn_k_scales, attn_k_bias,
+            attn_v_weight, attn_v_scales, attn_v_bias,
+            attn_out_weight, attn_out_scales, attn_out_bias,
+            post_attn_ln_weight, post_attn_ln_bias,
+            pre_mlp_ln_weight, pre_mlp_ln_bias,
+            mlp_in_weight, mlp_in_scales, mlp_in_bias,
+            mlp_out_weight, mlp_out_scales, mlp_out_bias,
+            post_mlp_ln_weight, post_mlp_ln_bias,
+    ):
+        is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH
+        layer_norm_ = hlo.layer_norm_bsh if is_bsh else hlo.layer_norm
+        ln_hidden = layer_norm_(hidden, pre_attn_ln_weight, pre_attn_ln_bias)
+
+        attn_output, out_attn_k_cache, out_attn_v_cache = self.attention(
+            ln_hidden, cache_ids, start_ids, pos_embed, mask, active_mask, core_id,
+            attn_k_cache, attn_v_cache,
+            attn_q_weight, attn_q_scales, attn_q_bias,
+            attn_k_weight, attn_k_scales, attn_k_bias,
+            attn_v_weight, attn_v_scales, attn_v_bias,
+            attn_out_weight, attn_out_scales, attn_out_bias
+        )
+        hidden = hlo.add(attn_output, hidden)
+
+        norm_hidden = layer_norm_(hidden, pre_mlp_ln_weight, pre_mlp_ln_bias)
+        mlp_ = mlp_bsh if is_bsh else mlp
+        mlp_hidden = mlp_(
+            norm_hidden,
+            mlp_in_weight, mlp_in_bias, mlp_out_weight, mlp_out_bias,
+            in_scales=mlp_in_scales,
+            out_scales=mlp_out_scales,
+            activation_function='gelu_new',  # 'gelu_pytorch_tanh',
+            tp_degree=self.config.tp_degree,
+            neuron_config=self.neuron_config,
+        )
+        res_hidden = hlo.add(mlp_hidden, hidden)
+        return res_hidden, out_attn_k_cache, out_attn_v_cache
+
+    def ln_lm_head(self, hidden, last_token_id, rms_weight, unused_bias, lm_head_weight, lm_head_bias,
+                   return_all_outputs=True):
+        logits = transformer.rms_lm_head(self.config.tp_degree, hidden, last_token_id, rms_weight, lm_head_weight,
+                                         lm_head_bias, return_all_outputs, eps=self.config.rms_norm_eps,
+                                         neuron_config=self.neuron_config)
+        return logits
+
+    def attention(
+            self,
+            hidden, cache_ids, start_ids, pos_embed, mask, active_mask, core_id,
+            cached_keys, cached_values,
+            q_weight, q_scales, q_bias,
+            k_weight, k_scales, k_bias,
+            v_weight, v_scales, v_bias,
+            out_weight, out_scales, out_bias,
+    ):
+        d_head = self.config.attention_head_size
+        tp_degree = self.config.tp_degree
+
+        # Compute the expected number of KV heads (Used in case fused QKV is used)
+        n_kv_heads_tp = None
+        if self.config.num_key_value_heads is not None:
+            n_head = self.config.num_attention_heads
+            n_kv_head = self.config.num_key_value_heads
+            _, n_kv_head_padded = utils.get_qkv_padding(n_head, n_kv_head, tp_degree, self.neuron_config)
+            n_kv_heads_tp = n_kv_head_padded // tp_degree
+
+        # Q = (hidden @ wQ) + bQ
+        # K = (hidden @ wK) + bK
+        # V = (hidden @ wV) + bV
+        query, key, value = attention.query_key_value(
+            hidden,
+            q_weight, q_scales, q_bias,
+            k_weight, k_scales, k_bias,
+            v_weight, v_scales, v_bias,
+            d_head,
+            neuron_config=self.neuron_config,
+            tp_degree=tp_degree,  # TODO: include tp_degree into neuron_config
+            shard_over_batch=self.shard_over_batch,
+            n_kv_heads_tp=n_kv_heads_tp,
+        )
+
+        # Q = Rotate(Q)
+        # K = Rotate(K)
+        query, key = rotary.rotate_half(query, key, pos_embed, self.config.rotary_percentage,
+                                        tp_degree=tp_degree, shard_over_batch=self.shard_over_batch)
+
+        # Q = Q / sqrt(d_head)
+        query = attention.scale(query, d_head)
+
+        # In BSH cache layout, the output of QKV linear projection is still kept as SBH for all QKV.
+        bsh_cache_layout = False
+        batch_dim = 1
+        if self.neuron_config is not None:
+            bsh_cache_layout = self.neuron_config.cache_layout == constants.LAYOUT_BSH
+        if bsh_cache_layout:
+            query, key, value = attention_utils.transpose_qkv(query, key, value)
+            batch_dim = 0
+
+        # Single Token Generation ("Prefetch"-style) and speculative forward
+        if active_mask is not None:
+
+            n_active_tokens = key.sizes[1] if bsh_cache_layout else key.sizes[0]
+            if n_active_tokens > 1 and self.neuron_config and self.neuron_config.continuous_batching:
+                # For speculative forward + continuous batching, slice out samples in the batch size
+                # corresponding to the batch size of the speculative head
+                slice_sizes = [1] * len(cached_keys.sizes)
+                if cached_keys.sizes[batch_dim] == 1:
+                    # Use hlo.select for batch size 1 as index select is prohibitively slow
+                    # TODO: revert to hlo.index_select once it's faster P126527643
+                    cached_keys_s = hlo.select(cached_keys, batch_dim, hlo.reshape(start_ids, slice_sizes),
+                                               keepdim=True)
+                    cached_values_s = hlo.select(cached_values, batch_dim, hlo.reshape(start_ids, slice_sizes),
+                                                 keepdim=True)
+                else:
+                    cached_keys_s = hlo.index_select(cached_keys, batch_dim, start_ids)
+                    cached_values_s = hlo.index_select(cached_values, batch_dim, start_ids)
+            else:
+                cached_keys_s = cached_keys
+                cached_values_s = cached_values
+            # Communication 1: all-gather query from cores
+            if (n_active_tokens != self.n_positions) and self.neuron_config.shard_over_sequence:
+                query = flash_decoding.gather_query_group(query, self.cores_per_kv_head,
+                                                          self.config.num_attention_heads,
+                                                          tp_degree)
+
+            # Sp = Q @ Kp
+            prior_scores = attention.score(query, cached_keys_s, n_kv_heads=self.config.num_key_value_heads,
+                                           tp_degree=tp_degree, neuron_config=self.neuron_config)
+            prior_scores = attention.mask(prior_scores, mask, tp_degree=tp_degree,
+                                          shard_over_batch=self.shard_over_batch)
+
+            # Sa = Q @ Ka
+            active_score = attention.score(query, key, n_kv_heads=self.config.num_key_value_heads,
+                                           tp_degree=tp_degree, neuron_config=self.neuron_config)
+            active_score = attention.mask(active_score, active_mask, tp_degree=tp_degree,
+                                          shard_over_batch=self.shard_over_batch)
+
+            # C = softmax(Sa, Sp) @ (Va, Vp)
+            if self.neuron_config.shard_over_sequence:
+                dtype = query.dtype
+                context = flash_decoding.context(prior_scores, active_score, cached_values_s, value, core_id, mask,
+                                                 active_mask,
+                                                 n_kv_heads=self.config.num_key_value_heads,
+                                                 n_heads=self.config.num_attention_heads, dtype=dtype,
+                                                 tp_degree=tp_degree, neuron_config=self.neuron_config,
+                                                 shard_over_batch=self.shard_over_batch)
+                cache_ids, value, key = flash_decoding.select_values_within_bound(cache_ids, value, key,
+                                                                                  self.cores_per_kv_head, core_id,
+                                                                                  dim=0)
+
+            else:
+                context = attention.context(prior_scores, active_score, cached_values_s, value,
+                                            n_kv_heads=self.config.num_key_value_heads, tp_degree=tp_degree,
+                                            neuron_config=self.neuron_config)
+
+            # KCache[I], VCache[I] = K, V
+            updated_keys, updated_values = attention.fused_kv_update_cache(cached_keys, cached_values, cache_ids,
+                                                                           key, value, start_ids,
+                                                                           neuron_config=self.neuron_config)
+
+        # Multi-Token Context Encoding
+        else:
+            _, batch_size, _, _ = query.sizes
+            if self.neuron_config.lhs_aligned or batch_size == 1:
+                context = attention.flash_attention(query, key, value)
+            else:
+                # do not use flash attention for lhs padded (right aligned) batch > 1 case
+                # because it does not correctly take mask into account
+                context = None
+
+            if context is None:
+                # S = Q @ K
+
+                score = attention.score(query, key, n_kv_heads=self.config.num_key_value_heads,
+                                        tp_degree=tp_degree, neuron_config=self.neuron_config)
+                score = attention.mask(score, mask, tp_degree=tp_degree, shard_over_batch=self.shard_over_batch)
+                context = attention.context_combined(score, value, n_kv_heads=self.config.num_key_value_heads,
+                                                     tp_degree=tp_degree, neuron_config=self.neuron_config)
+
+            if self.neuron_config.shard_over_sequence:
+                cache_ids, value, key = flash_decoding.select_values_within_bound(cache_ids,
+                                                                                  value,
+                                                                                  key,
+                                                                                  self.cores_per_kv_head,
+                                                                                  core_id, dim=0)
+            # KCache, VCache = K, V
+            if cached_keys.sizes == key.sizes:
+                updated_keys, updated_values = key, value
+            else:
+                updated_keys, updated_values = attention.fused_kv_update_cache(cached_keys, cached_values, cache_ids,
+                                                                               key, value, start_ids,
+                                                                               neuron_config=self.neuron_config)
+
+        # O = (C @ wO) + bO
+        output = attention.output(context, out_weight, out_scales, out_bias, tp_degree, self.neuron_config)
+        return output, updated_keys, updated_values
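The comments in the single-token ("prefetch"-style) branch spell the math out: the query is scored against the cached keys (Sp = Q @ Kp) and against the freshly projected key of the active token (Sa = Q @ Ka), one softmax runs across both, and the result mixes the cached and active values. The sketch below restates that step with plain PyTorch tensors instead of the HLO builders, so the data flow is easier to follow; it assumes a single core and omits tensor parallelism, flash decoding, the BSH/HSB cache layouts, continuous-batching slicing, and the cache write, and the helper name and tensor layout are illustrative, not from this commit:

import torch

def decode_step_attention(query, cached_keys, cached_values, key, value, mask):
    # Shapes (single core, no tensor parallelism; layout is illustrative):
    #   query:         [batch, n_heads,    1,           d_head]      active token after RoPE
    #   cached_keys:   [batch, n_kv_heads, n_positions, d_head]
    #   cached_values: [batch, n_kv_heads, n_positions, d_head]
    #   key, value:    [batch, n_kv_heads, 1,           d_head]      projections of the active token
    #   mask:          [batch, 1,          1,           n_positions] True where a cache slot is valid
    n_heads, d_head = query.shape[1], query.shape[-1]
    n_kv_heads = key.shape[1]

    if n_kv_heads != n_heads:
        # Grouped-query attention: each KV head serves n_heads // n_kv_heads query heads
        repeat = n_heads // n_kv_heads
        cached_keys = cached_keys.repeat_interleave(repeat, dim=1)
        cached_values = cached_values.repeat_interleave(repeat, dim=1)
        key = key.repeat_interleave(repeat, dim=1)
        value = value.repeat_interleave(repeat, dim=1)

    query = query / d_head ** 0.5                             # Q = Q / sqrt(d_head)

    prior_scores = query @ cached_keys.transpose(-1, -2)      # Sp = Q @ Kp
    prior_scores = prior_scores.masked_fill(~mask, torch.finfo(query.dtype).min)
    active_score = query @ key.transpose(-1, -2)              # Sa = Q @ Ka

    # One softmax across the cached and the active position, then mix the cached and active values
    probs = torch.softmax(torch.cat([prior_scores, active_score], dim=-1), dim=-1)
    context = probs @ torch.cat([cached_values, value], dim=-2)   # C = softmax(Sa, Sp) @ (Va, Vp)
    return context  # goes on to the output projection: O = (C @ wO) + bO

In the HLO graph the same steps run through attention.score, attention.mask, and attention.context (or flash_decoding.context when shard-over-sequence is enabled), and attention.fused_kv_update_cache then writes the new key and value into the cache at cache_ids.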

0 commit comments
