
Commit 7a30b42

hannanjgaws, jluntamazon, aws-bowencc, yazhom-aws, devesr-amzn authored
Sync internal repo to external Mar 29 2024 (#81)
Updates for Neuron SDK 2.18.0

---------

Co-authored-by: Jonathan Lunt <[email protected]>
Co-authored-by: Bowen Chen <[email protected]>
Co-authored-by: Yuan Zhou <[email protected]>
Co-authored-by: Devesh Ratho <[email protected]>
Co-authored-by: Amer <[email protected]>
Co-authored-by: Mike Zhang <[email protected]>
Co-authored-by: Liangfu Chen <[email protected]>
Co-authored-by: Nicholas Waldron <[email protected]>
Co-authored-by: Shubham Chandak <[email protected]>
Co-authored-by: Wojciech Romaszkan <[email protected]>
Co-authored-by: Amulya Ballakur <[email protected]>
Co-authored-by: Dylan Geva <[email protected]>
Co-authored-by: Jeffrey Huynh <[email protected]>
Co-authored-by: Shashwat Srijan <[email protected]>
1 parent 8b01968 commit 7a30b42


43 files changed: +2323 additions, -906 deletions

setup.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def get_version():
         'accelerate',
         'safetensors',
         'torch-neuronx',
-        'transformers',
+        'transformers>=4.36',
     ],
     python_requires='>=3.7',
     package_dir={'': 'src'},
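The only functional change in setup.py is the new floor on the transformers dependency. A minimal sanity check that an existing environment already satisfies that floor; this snippet is illustrative and not part of the commit:

# Illustrative only: confirm the installed transformers release meets the
# '>=4.36' minimum introduced by this commit.
import transformers
from packaging import version

assert version.parse(transformers.__version__) >= version.parse("4.36"), (
    f"transformers {transformers.__version__} is older than the 4.36 minimum "
    "required by transformers-neuronx as of Neuron SDK 2.18.0"
)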

src/transformers_neuronx/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -16,5 +16,17 @@
 
 from transformers_neuronx.config import NeuronConfig, QuantizationConfig, ContinuousBatchingConfig
 from transformers_neuronx.constants import GQA
+from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter
+
+from transformers_neuronx.bloom.model import BloomForSampling
+from transformers_neuronx.llama.model import LlamaForSampling
+from transformers_neuronx.gpt2.model import GPT2ForSamplingWithContextBroadcasting
+from transformers_neuronx.gptneox.model import GPTNeoXForSampling
+from transformers_neuronx.gptj.model import GPTJForSampling
+from transformers_neuronx.mistral.model import MistralForSampling
+from transformers_neuronx.mixtral.model import MixtralForSampling
+from transformers_neuronx.opt.model import OPTForSampling
+
+from transformers_neuronx.modeling_auto import NeuronAutoModelForCausalLM
 
 from . import testing
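With these re-exports, the sampling model classes and NeuronAutoModelForCausalLM can be imported from the package root instead of their individual submodules. A minimal usage sketch, assuming the usual from_pretrained / to_neuron flow; the checkpoint path and the tp_degree/amp values are placeholders:

# Sketch only: 'path/to/model' and the parallelism/precision settings are
# placeholder values, not recommendations.
from transformers_neuronx import LlamaForSampling

model = LlamaForSampling.from_pretrained('path/to/model', batch_size=1,
                                         tp_degree=2, amp='f16')
model.to_neuron()  # compile and load the model onto NeuronCores

The same import style applies to the other classes added above (BloomForSampling, MistralForSampling, MixtralForSampling, and so on), which previously had to be imported from their model submodules.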

src/transformers_neuronx/base.py

Lines changed: 64 additions & 25 deletions
@@ -16,8 +16,9 @@
 import os
 import torch
 import logging
-from typing import Optional, Union, List
 import hashlib
+import warnings
+from typing import Optional, Union, List
 from transformers_neuronx import bucket
 from transformers_neuronx import utils
 from transformers_neuronx import module
@@ -39,7 +40,7 @@ def save(self, directory):
     def load(self, directory):
         assert self.serialization_enabled(), 'serialization is not enabled for this model'
         self._compiled_artifacts_directory = directory
-
+
     # top level api
     def compile(self, parallel_degree=None):
         kernels = self._get_all_kernels()
@@ -78,6 +79,13 @@ def enable_speculative_decoder(self,speculation_length:Optional[Union[List[int],
         for k in speculation_length:
             self.decoder_lm_head_for_speculation[k]=self.decoder_param_set.init_speculative_decoder(unroll=self.unroll, buckets=self.token_buckets, model_obj=self, n_active_tokens=k)
 
+    def enable_window_context_decoder(self, window_context_length:Optional[Union[List[int], int]], unroll):
+        if isinstance(window_context_length, int):
+            window_context_length=[window_context_length]
+        self.window_context_buckets = bucket.context_sizes(window_context_length, self.token_buckets)
+        for k in self.window_context_buckets:
+            self.decoder_lm_head_for_window_context[k]=self.decoder_param_set.init_window_context_decoder(unroll=unroll, buckets=self.token_buckets, model_obj=self, n_active_tokens=k)
+
     def is_compiled(self):
         # First check if the kernels have neffs already
         try:
@@ -141,7 +149,7 @@ def register_for_serialization(self, nbs_obj):
     def reset(self):
         self.decoder_lm_head.reset()
 
-    def context(self, hidden, cache_ids, start_ids, last_token_id, *rest, neuron_config=None):
+    def context(self, hidden, cache_ids, start_ids, last_token_id, *rest):
         """A helper to process context (prompt)
         1) if there is available context encoding model (infered from self.context_buckets)
         - when context_length >= estimate, slice the context up to estimate,
@@ -190,17 +198,45 @@ def context(self, hidden, cache_ids, start_ids, last_token_id, *rest, neuron_config=None):
 
         if current == estimate:
             model = self.decoder_lm_head_for_context[estimate, batch_size]
-            logits = model(hidden_context, cache_context, start_ids, last_token_id, *rest, neuron_config=neuron_config)
+            if self.neuron_config.log_softmax_scores:
+                logits, scores = model(hidden_context, cache_context, start_ids, last_token_id, *rest)
+            else:
+                logits = model(hidden_context, cache_context, start_ids, last_token_id, *rest)
+
+
+
+        # process the leftovers context
+        while current < context_length - 1:
+            # find the optimal "window"
+            estimate = None
+            if hasattr(self, "window_context_buckets"):
+                estimate = bucket.find(self.window_context_buckets, context_length - current)
+
+            # when the leftovers is smaller than estimate, fall back to single token generation
+            # TODO: can we pad?
+            if estimate is None or context_length - current < estimate:
+                for i in range(current, context_length):
+                    cache_ids = torch.as_tensor([i], dtype=torch.int32)
+                    hidden_slice = hidden[:, i:i+1].contiguous()
+                    logits = self.decoder_lm_head(hidden_slice, cache_ids, start_ids, last_token_id, *rest)
+                break
+
+            hidden_slice = hidden[:, current:current+estimate].contiguous()
+            cache_ids = torch.as_tensor([i for i in range(current, current+estimate)], dtype=torch.int32)
+            last_token_id = torch.as_tensor(estimate - 1)
+            if self.neuron_config.log_softmax_scores:
+                logits, scores = self.decoder_lm_head_for_window_context[estimate](hidden_slice, cache_ids, start_ids, last_token_id, *rest)
+            else:
+                logits = self.decoder_lm_head_for_window_context[estimate](hidden_slice, cache_ids, start_ids, last_token_id, *rest)
 
-        for i in range(current, context_length):
-            cache_ids = torch.as_tensor([i], dtype=torch.int32)
-            hidden_slice = hidden[:, i:i+1].contiguous()
-            logits = self.decoder_lm_head(hidden_slice, cache_ids, start_ids, last_token_id, *rest, neuron_config=neuron_config)
+            current += estimate
 
         if self.is_fid:
             logits[:] = float('-inf')
             logits[self.bos_token_id] = 1.0
 
+        if self.neuron_config.log_softmax_scores:
+            return logits, scores
         return logits
 
     def _prepare_for_par_ctx_rhs_padding(self, input_ids):
@@ -318,9 +354,9 @@ def _cast_logits(self, logits):
         logits_dtype = getattr(torch, self.neuron_config.cast_logits_dtype)
         return logits.to(logits_dtype)
 
-    def _context_dynamic_batching(self, hidden, *args, neuron_config=None):
-        # Taking HSB layout
-        _, context_length, input_batch_size = hidden.shape
+    def _context_dynamic_batching(self, hidden, *args):
+        is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH
+        input_batch_size = hidden.shape[0] if is_bsh else hidden.shape[2]
         assert hasattr(self, "context_batch_sizes"), f"{type(self)} doesn't support dynamic batching."
 
         running_batch_size = self.context_batch_sizes[-1]
@@ -334,34 +370,37 @@ def _context_dynamic_batching(self, hidden, *args, neuron_config=None):
                 # Assuming HSB layout
                 start_idx = iter_id*running_batch_size
                 end_idx = (iter_id+1)*running_batch_size
-                hidden_per_batch = hidden[:, :, start_idx:end_idx]
+                if is_bsh:
+                    hidden_per_batch = hidden[start_idx:end_idx, :, :]
+                else:
+                    hidden_per_batch = hidden[:, :, start_idx:end_idx]
                 cache_ids_per_batch = cache_ids[start_idx:end_idx, :]
                 start_ids_per_batch = start_ids[start_idx:end_idx]
                 last_token_id = cache_ids_per_batch.max()
                 logits_per_batch = self.context(hidden_per_batch, cache_ids_per_batch,
-                                                start_ids_per_batch, last_token_id, neuron_config=neuron_config)
+                                                start_ids_per_batch, last_token_id)
                 all_logits.append(logits_per_batch)
             logits = torch.cat(all_logits, dim=2)
         else:
             assert input_batch_size == running_batch_size, \
                 "input batch size ({input_batch_size}) not equal to running batch size ({running_batch_size})"
-            logits = self.context(hidden, *args, neuron_config=neuron_config)
+            logits = self.context(hidden, *args)
         return logits
 
-    def _forward(self, hidden, *args, neuron_config=None):
-        # Taking HSB layout
+    def _forward(self, hidden, *args):
         _, context_length, *_ = hidden.shape
-        if not self.neuron_config.on_device_embedding:
-            hidden = hidden.transpose(0, -1).contiguous()
 
         if context_length > 1:
             continuous_batching = self.neuron_config and self.neuron_config.continuous_batching
             if continuous_batching:
-                logits = self._context_dynamic_batching(hidden, *args, neuron_config=neuron_config)
+                logits = self._context_dynamic_batching(hidden, *args)
             else:
-                logits = self.context(hidden, *args, neuron_config=neuron_config)
+                logits = self.context(hidden, *args)
         else:
-            logits = self.decoder_lm_head(hidden, *args, neuron_config=neuron_config)
+            logits = self.decoder_lm_head(hidden, *args)
+
+        if self.neuron_config.on_device_generation:
+            return logits
 
         logits = self._cast_logits(logits)
         logits = logits[:self.config.vocab_size, -1, :]
@@ -378,7 +417,7 @@ def pp_forward(self, *args, **kwargs):
         if self.neuron_config.rank_id == 0:
             broad_cast_objects = [args, kwargs]
             dist.broadcast_object_list(broad_cast_objects, src=0, device=torch.device("cpu"))
-            res = self.forward(*args, **kwargs)
+            res = self(*args, **kwargs)
             return res
         else:
             # if non-host, fall back to a for loop
@@ -394,17 +433,17 @@ def pp_forward(self, *args, **kwargs):
                 # it is now naturally handled in forward call
                 dist.broadcast_object_list(broad_cast_objects, src=0, device=torch.device("cpu"))
                 args, kwargs = broad_cast_objects
-                self.forward(*args, **kwargs)
+                self(*args, **kwargs)
 
     def serialization_enabled(self):
         return getattr(self, 'nbs_objs', None) is not None
 
-    def profile(self, profile_dir):
+    def profile(self, profile_dir, ntff_count_limit):
         kernels = self._get_all_kernels()
 
         for kernel in kernels:
             if isinstance(kernel, ParallelKernel):
-                kernel.profile(profile_dir)
+                kernel.profile(profile_dir, ntff_count_limit)
 
 # Base class for all "Serializable Objects"
 class NeuronBaseSerializer:
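The main addition to base.py is window-context decoding: when a prompt runs past the initial context-encoding bucket, the leftover tokens are consumed in fixed-size windows chosen from self.window_context_buckets, falling back to token-by-token decoding when no window fits. A simplified, self-contained sketch of that scheduling idea; pick_window is a stand-in for bucket.find, and the bucket sizes and prompt length below are made-up example values, not the library code:

# Simplified illustration of the leftover-window scheduling used in context().
from typing import List, Optional

def pick_window(window_buckets: List[int], leftover: int) -> Optional[int]:
    # Choose the largest configured window that still fits the leftover tokens.
    fitting = [w for w in window_buckets if w <= leftover]
    return max(fitting) if fitting else None

def schedule(context_length: int, current: int, window_buckets: List[int]) -> List[int]:
    # Return the sequence of chunk sizes that would be fed to the
    # window-context decoder (>1) or the single-token decoder (==1).
    chunks = []
    while current < context_length:
        window = pick_window(window_buckets, context_length - current)
        if window is None:
            chunks.extend([1] * (context_length - current))  # token-by-token tail
            break
        chunks.append(window)
        current += window
    return chunks

print(schedule(context_length=1000, current=512, window_buckets=[128, 256]))
# -> [256, 128] followed by 104 single-token (1) entries

The other notable signature change is that context(), _context_dynamic_batching() and _forward() no longer thread a neuron_config argument through each call; they read self.neuron_config directly, which is what allows the new BSH/HSB-aware batch slicing and the on-device-generation early return.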

src/transformers_neuronx/bloom/hlo.py

Lines changed: 11 additions & 8 deletions
@@ -14,7 +14,7 @@
 # ==============================================================================
 from transformers_neuronx import hlo
 from transformers_neuronx.constants import LAYOUT_BSH
-from transformers_neuronx.layers import transformer, alibi
+from transformers_neuronx.layers import transformer, alibi, generation
 from transformers_neuronx.bloom.config import BloomConfig
 
 class BloomForSamplingNoEmbeddingHlo:
@@ -30,7 +30,7 @@ def inputs(self, scribe, dtype, n_positions, n_active_tokens, batch_size):
         mask, active_mask = hlo.attention_mask(cache_ids, start_ids, n_positions)
         return (hidden, last_token_id, cache_ids, mask, active_mask), dims
 
-    def embedding(self, input_ids, word_embeddings, ln_weight, ln_bias):
+    def embedding(self, input_ids, last_token_id, cache_ids, mask, active_mask, slopes, word_embeddings, ln_weight, ln_bias):
         dtype = getattr(input_ids.scribe, self.config.amp)
         hidden = hlo.embedding(word_embeddings, input_ids, tp_degree=self.config.tp_degree, dtype=dtype)
         if self.config.hidden_size % self.config.tp_degree != 0:
@@ -44,10 +44,6 @@ def embedding(self, input_ids, word_embeddings, ln_weight, ln_bias):
     def pre_layer(self, hidden, last_token_id, cache_ids, mask, active_mask, *pre_layer_weights):
         slopes, *rest = pre_layer_weights
         prior_alibi, active_alibi = alibi.alibi(slopes, mask, active_mask)
-
-        if self.neuron_config.on_device_embedding:
-            hidden = self.embedding(hidden, *rest)
-
         return hidden, last_token_id, cache_ids, mask, active_mask, prior_alibi, active_alibi
 
     def layer(self, hidden, last_token_id, cache_ids, mask, active_mask, prior_alibi, active_alibi, attn_k_cache, attn_v_cache,
@@ -91,8 +87,15 @@ def layer(self, hidden, last_token_id, cache_ids, mask, active_mask, prior_alibi
         hidden = dtype[hidden.sizes].Add(mlp_hidden, hidden)
         return hidden, out_attn_k_cache, out_attn_v_cache
 
-    def ln_lm_head(self, hidden, last_token_id, ln_f_weight, ln_f_bias, lm_head_weight, lm_head_bias, return_all_outputs=True):
-        return transformer.ln_lm_head(self.config.tp_degree, hidden, last_token_id, ln_f_weight, ln_f_bias, lm_head_weight, lm_head_bias, return_all_outputs, neuron_config=self.neuron_config)
+    def ln_lm_head(self, hidden, last_token_id, ln_f_weight, ln_f_bias, lm_head_weight, lm_head_bias, logits_indices, return_all_outputs=True):
+        logits = transformer.ln_lm_head(self.config.tp_degree, hidden, last_token_id, ln_f_weight, ln_f_bias, lm_head_weight,
+                                        lm_head_bias, return_all_outputs, neuron_config=self.neuron_config)
+        if self.neuron_config.on_device_generation is not None:
+            return generation.generate(logits, logits_indices,
+                                       config=self.neuron_config.on_device_generation,
+                                       tp_degree=self.config.tp_degree,
+                                       eos_token_id=self.config.eos_token_id)
+        return logits
 
     def attention(self,
             hidden, cache_ids, mask, active_mask, prior_alibi, active_alibi,
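The ln_lm_head change routes logits through generation.generate when on-device generation is configured, so token selection runs inside the compiled graph rather than on the host. A hedged configuration sketch; the GenerationConfig import and its keyword arguments are assumptions about the 2.18.0 API rather than confirmed by this diff, and the model path and sampling values are placeholders:

# Sketch only: exact GenerationConfig fields may differ from the released API.
from transformers_neuronx import BloomForSampling, NeuronConfig
from transformers_neuronx.config import GenerationConfig

neuron_config = NeuronConfig(
    on_device_generation=GenerationConfig(max_length=128, top_k=50, do_sample=True),
)
model = BloomForSampling.from_pretrained('path/to/bloom', tp_degree=2, amp='f16',
                                         neuron_config=neuron_config)
model.to_neuron()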
