Commit 0623de2

Authored by hannanjgaws, jluntamazon, aws-bowencc, yazhom-aws and devesr-amzn
Sync internal repo to external Apr 15 2024 (#85)
* [module] Added better/faster checkpoint support with both sharded/whole checkpoints GitOrigin-RevId: 474757c3e65895084384e2e67d811f0423880fcf
* [generation-demo] Add --profile GitOrigin-RevId: 7ba35ca7839df4fd300686db3632eb23feffbd6e
* [module] Added the ability to download weights from huggingface hub repositories GitOrigin-RevId: 88661e675a31f2f1784056982717b3b15e04c922
* [automodel] Added NeuronAutoModelForCausalLM class which automatically loads architecture-specific classes GitOrigin-RevId: 18095eec9b978327e74cac2f24f872ca93c876f9
* [util] Avoid truncating tensors when padded size is less than current size. GitOrigin-RevId: 948fa370b543ac66976693e956bc024167060c8f
* [window-context] Add window context encoding GitOrigin-RevId: 2def53a09fb5cbd58e1842941be8abd8fca79988
* Add support for post-processing logits for on-device log-softmax GitOrigin-RevId: 416ad2a32cc89606872c62c366b89e51e506f861
* [generation-demo] Add torch-profile; use random input with --prompt_len GitOrigin-RevId: 62f39cfc1b857882034b307a7169bc3b2a46e292
* Fix conflict for BSH gated MLP GitOrigin-RevId: 34815f7a7261b9484e20aaf7e42f55a5de09fc87
* Fix conflict for BSH gated MLP (fix 2) GitOrigin-RevId: 8bee8b2595d88837341faad75630349f33ea55dd
* [speculation] Updated speculative generator to correctly insert last draft token into KV cache GitOrigin-RevId: 8a3eedd6ee516c641442cc1fa9c5639a491d097d
* [llama] Added support for tied embedding/head weights GitOrigin-RevId: 65dea75643e5bc041bca7bdd8677a6951e3ffccc
* [decoder] Added a warmup to all kernels to avoid unexpected initialization latency spikes GitOrigin-RevId: b62d7e7a2df4675e66354a6f56b527d0e332891f
* [hlo] Add support for bool literals in Python 3.11 GitOrigin-RevId: 07ad81981b19ccaf5c775f18b839f51690f5b2ae
* Add support for Mistral-7B-v0.2 with no sliding window GitOrigin-RevId: 205dcf4a5c8ce6e3c7c6c98e604e5ee3df509054
* [pp] Call self directly instead of self.forward to enable the forward hook GitOrigin-RevId: 908b7af05e1320a2ddded5734545b147cd7a20ba
* [module] Added support for model checkpoints using base model prefix GitOrigin-RevId: 8e8bd0d72de318d1f4cbc9859b44dc4e0d9b0514
* Fused KV cache update for deduplicating index calculation GitOrigin-RevId: 888778dfed05b873b7020b4208ed403edb87158d
* Add tags to models; this passes tags down to the Parallel kernel and prefixes context kernels with 'context' GitOrigin-RevId: cb9b6f19d2c0e33c441a2c770c8eb8a3c0a60a23
* Add warmup logic when profile gets called on each kernel GitOrigin-RevId: 8225ec846d5bcfb35741e4982b72921ce248a55b
* [decoder] Handle corner cases where the KV cache is None. GitOrigin-RevId: ac418340447a5b8b4fc8df0c46eb7e97d50befb3
* [decoder] Added prefix tags to more decoder types. Added less ambiguous tag parameter prefixes GitOrigin-RevId: 15a25fab827f0dc18c4b37d11517f9eb4c5cd875
* Set self.tag in base class. This is used by PipelineParallelProgram GitOrigin-RevId: 96ec44be9af30c88b607d8ddddb2ff5ff907ec5f
* Extend TP support of Mixtral-8x7B model from 8 to 16 and 32 and fix accuracy issue GitOrigin-RevId: 2f0bbfb4934c8396327d9579afc8d5284887fe94
* Support BSH attention layout for continuous batching GitOrigin-RevId: 9664ff667ce1baa7c7eaacb57ecbe81d75a82629
* [generation_demo] Additional flags, minor fixes GitOrigin-RevId: 6ab804d5012a5286ab2d0f612d68e6e24854ea36
* [generation_demo] Model-from-config support, minor fixes GitOrigin-RevId: 38c82a99d0f628fe4b791dcc4d51bf7d0c835303
* Require transformers>=4.36 GitOrigin-RevId: 6f8b1ef2e099d268188ae7ed3b055ea94f7cbf81
* Support on-device embedding for GPT2. Fix multi-layer model support for LLAMA and BLOOM and clean up forward function signatures. GitOrigin-RevId: e7ea681c09e9f712c41668f4cc1aa78104f467e3
* Fix HSB GroupNorm implementation GitOrigin-RevId: 372b2cca5fae0418e8a4cf346cf87133ac33ddf6
* [compile-cache] Add NEURONX_DUMP_TO_NOTEMP to dump artifacts from neuron cache GitOrigin-RevId: 140a46779a5e42806b901b3896c95251c9260010
* Fix forward call for Mixtral GitOrigin-RevId: abefd80fc726015ab133dd40425d3ba97d1ff2f3
* [Speculation] Use target model to generate leftover tokens GitOrigin-RevId: a654d7c01e43fffe9c3253850a75ea17d04aac7d
* Add block-diagonal causal mask for concatenated multi-prompt encoding GitOrigin-RevId: f806d511d0eac7bf766972bb43eed9308566b5e4
* Revert "[Speculation] Use target model to generate leftover tokens" GitOrigin-RevId: 76feacb3aa501239359e0c4939dacbdf311ca7e2
* [hlo] Support transposed attention output weight GitOrigin-RevId: 5d1d00c1773d57570248388943c7c567a5dac870
* [Speculation] Use target model to generate leftover tokens GitOrigin-RevId: 2f677787bbcea31e5738982f243183908e612a45
* [compiler] Fixed snapshot-steps functionality in executor. Fixed warmup to never snapshot GitOrigin-RevId: 343b2ce9549a9762f698fe2029ed28ad1245eb0f
* KV cache placement with slot_mapping GitOrigin-RevId: 18c217fd664e60bd7834702f64001eecfaa0d688
* Update quantization to support contraction dim GitOrigin-RevId: 397ded48a0850d98144d7a517f605d6ed3ac3691
* [mistral/mixtral/opt] Fix on-device embedding support for remaining models GitOrigin-RevId: a52299027aabeffd47f13c9a9c4b4df600e8a4c2
* Adjust profiling to pass through the number of NTFF files generated and remove tar file creation GitOrigin-RevId: 689616f8951e0e778aeae0dbdc12d04c267b6cbf
* Fuse QKV support for GQA GitOrigin-RevId: 938e3d267b77e4b5d225d214be47c190472472c7
* [Hlo] Change mmadd to dot_add and update it to support N-D tensors GitOrigin-RevId: 5a06e680dbda0ab33262127df0b4458f88d38eed
* Replace problem characters in tagged HLO; these translate directly into filenames (NEFF, NTFF) GitOrigin-RevId: fe658484f8c6cebe691ab623c45ee69e3427b5c9
* [Release 2.18] Change version to 1.0 GitOrigin-RevId: 2c948b4669ab83591925595fcbba87319971369d
* Added ntff_count_limit argument to generation_demo GitOrigin-RevId: 9649a8710ce4fe13cd22891f1e8ed983377c1cf6
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline GitOrigin-RevId: aa1051241c38be805cca604855f4531c35eda83c
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline GitOrigin-RevId: 31e2511deaf9b8aa2e4d983b1263377a74ab4cd8
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline GitOrigin-RevId: c5b1d876e387ab51e55c0b1c7a8320ab97a88777
* Fix position_offset param for OPTForSamplingNoEmbeddingHlo GitOrigin-RevId: 2aebac56ac5780fc7f88bc5f24bc7611562ef5bd
* Initial support for BSH cache layout GitOrigin-RevId: 68440f419057fbbb9befe02a5f95bbead9d4a24a
* Support BSH cache layout with BSH attention layout GitOrigin-RevId: 75e2807a57e509ef3dc9edae7c44421f911d8961
* [generation_demo] Remove dump flag GitOrigin-RevId: 23a774b75e2f8e224e511531703391ccf5ee4b10
* Re-enable prompt broadcasting in GPT2 (input batch 1, output batch N) GitOrigin-RevId: ec8772e01b121a2acd5526b9bbe3fa0bb9d334a9
* Fix return ranks for executors while using on-device sampling GitOrigin-RevId: 412113c5b10b0285fe7f3aafe82bae578b463024
* [Release 2.18] Change version to 0.10.x GitOrigin-RevId: fde8f715eca71e3ee1dac392b9e9b3527e6ee0cb
* [module] Allow safetensors checkpoint downloads to be explicitly disabled GitOrigin-RevId: fb34b5e8ace114f086728382c0f60358fb687f24
* Override attn_implementation as eager to skip the sdpa attention implementation GitOrigin-RevId: 2112b39c4a43cadc77ec551b6ed49ce33a7a72f1
* [hlo] Added primitive broadcasting. Added new operators. Added on-device speculative token selection GitOrigin-RevId: cf64e6130141d5d237cdc0277c8b6d3adc39630e
* LHS alignment for static batching (vectorize last_token_id) GitOrigin-RevId: 41dc716837dc1416eea4b60c198a39710cc153a7
* Fix cache_ids padding for Llama continuous batching and batch=1 speculative decoding GitOrigin-RevId: 7a28dcbe147e2a4400fe7ae7c8ac5cbad145e185
* Cherry-picks to 2.18 for multi-bucketing/multi-prompt for continuous batching GitOrigin-RevId: 74a3c4cdc4dc8ae7ee6a5a12a9734a084499224e
* Fix "Unit Tests - 2 Core" errors in test_neuron_auto_model.py mixtral tests GitOrigin-RevId: 8d5c47068ccbd2e87672243aedd56c46b887a8b9
* Fix generation_utils GitOrigin-RevId: fcb5254a8ecafbeef434cff5fce9075993cdd471

---------

Co-authored-by: Jonathan Lunt <[email protected]>
Co-authored-by: Bowen Chen <[email protected]>
Co-authored-by: Yuan Zhou <[email protected]>
Co-authored-by: Devesh Ratho <[email protected]>
Co-authored-by: Amer <[email protected]>
Co-authored-by: Mike Zhang <[email protected]>
Co-authored-by: Liangfu Chen <[email protected]>
Co-authored-by: Nicholas Waldron <[email protected]>
Co-authored-by: Shubham Chandak <[email protected]>
Co-authored-by: Wojciech Romaszkan <[email protected]>
Co-authored-by: Amulya Ballakur <[email protected]>
Co-authored-by: Dylan Geva <[email protected]>
Co-authored-by: Jeffrey Huynh <[email protected]>
Co-authored-by: Shashwat Srijan <[email protected]>
Co-authored-by: Prithvijit Chakrabarty <[email protected]>
1 parent 7a30b42 commit 0623de2

File tree

15 files changed: +524 -91 lines

src/transformers_neuronx/base.py

Lines changed: 77 additions & 16 deletions
@@ -73,11 +73,17 @@ def to_neuron(self):
         self.setup()

     # top level api
-    def enable_speculative_decoder(self,speculation_length:Optional[Union[List[int], int]]):
+    def enable_speculative_decoder(self, speculation_length: Optional[Union[List[int], int]], batch_sizes: Optional[Union[List[int], int]]=None):
         if isinstance(speculation_length, int):
-            speculation_length=[speculation_length]
+            speculation_length = [speculation_length]
+        if batch_sizes is None:
+            batch_sizes = self.decoder_param_set.batch_size
+        if isinstance(batch_sizes, int):
+            batch_sizes = [batch_sizes]
         for k in speculation_length:
-            self.decoder_lm_head_for_speculation[k]=self.decoder_param_set.init_speculative_decoder(unroll=self.unroll, buckets=self.token_buckets, model_obj=self, n_active_tokens=k)
+            for batch_size in batch_sizes:
+                self.decoder_lm_head_for_speculation[k, batch_size] = \
+                    self.decoder_param_set.init_speculative_decoder(unroll=self.unroll, buckets=self.token_buckets, model_obj=self, n_active_tokens=k, batch_size=batch_size)

     def enable_window_context_decoder(self, window_context_length:Optional[Union[List[int], int]], unroll):
         if isinstance(window_context_length, int):
@@ -164,7 +170,7 @@ def context(self, hidden, cache_ids, start_ids, last_token_id, *rest):
         Other arguments that are required by the model are contained in `rest`.
         """
         context_length = hidden.shape[1]
-        batch_size, = start_ids.shape
+        batch_size = start_ids.shape[0]

         if self.is_fid:
             # Fusion-In-Decoder context encoding
@@ -239,7 +245,7 @@ def context(self, hidden, cache_ids, start_ids, last_token_id, *rest):
             return logits, scores
         return logits

-    def _prepare_for_par_ctx_rhs_padding(self, input_ids):
+    def _prepare_for_par_ctx_rhs_padding(self, input_ids, cache_ids):
         """A helper to do rhs padding on prompt for parallel context encoding model
         i.e.
             input_ids = [[111, 222, 333]]
@@ -257,9 +263,12 @@ def _prepare_for_par_ctx_rhs_padding(self, input_ids):
         batch_size, context_length = input_ids.shape

         # if last_token_id not used, simply set to 0
-        last_token_id = torch.as_tensor(0, dtype=torch.int32)
+        if self.neuron_config.vectorize_last_token_id:
+            last_token_id = torch.zeros(batch_size, dtype=torch.int32)
+        else:
+            last_token_id = torch.as_tensor(0, dtype=torch.int32)
         if context_length == 1:
-            return input_ids, last_token_id
+            return input_ids, cache_ids, last_token_id

         # TODO: check context_buckets for compatibility with OPT
         if hasattr(self, "context_buckets"):
@@ -269,11 +278,38 @@ def _prepare_for_par_ctx_rhs_padding(self, input_ids):

         if estimate:
             # when context length is larger than estimate, last_token_id=estimate-1
-            last_token_id = torch.as_tensor(min(context_length - 1, estimate-1), dtype=torch.int32)
+            if self.neuron_config.vectorize_last_token_id:
+                last_token_id = cache_ids.max(dim=1).values
+            else:
+                last_token_id = torch.as_tensor(min(context_length - 1, estimate-1), dtype=torch.int32)
             if context_length < estimate:
                 input_ids = utils.pad(input_ids, 1, estimate, left=False)
+                cache_ids = self._pad_cache_ids(cache_ids, batch_size, context_length, estimate)
+
+        return input_ids, cache_ids, last_token_id

-        return input_ids, last_token_id
+    def _pad_cache_ids(self, cache_ids, batch_size, context_length, estimate):
+        if self.neuron_config.use_2d_cache_ids:
+            # TODO: fix cache_ids padding for batch speculative decoding
+            cache_ids = torch.arange(estimate, dtype=torch.long)
+            cache_ids = cache_ids.unsqueeze(0).expand(batch_size, estimate)
+        else:
+            if cache_ids is None:
+                cache_ids = torch.arange(estimate, dtype=torch.long)
+            else:
+                # Inputs: cache_ids = [16, 17], estimate = 512
+                #
+                # Process:
+                #   start_idx = 18, end_idx = 528 (= 512+16)
+                #   padded_elements = [18, 19, ..., 511, 512, 513, ..., 525, 526, 527]
+                #   cache_ids_pad = [16, 17, 18, 19, ..., 511, 512, 513, ..., 525, 526, 527]
+                #   cache_ids = [16, 17, 18, 19, ..., 511, 511, 511, ..., 511, 511, 511]
+                start_idx = cache_ids[-1].item() + 1
+                end_idx = estimate + start_idx - context_length
+                pad_elements = torch.arange(start_idx, end_idx, dtype=torch.long)
+                cache_ids_pad = torch.concat([cache_ids, pad_elements], dim=0)
+                cache_ids = torch.minimum(cache_ids_pad, torch.tensor(estimate-1, dtype=torch.long))
+        return cache_ids

     def _prepare_for_continuous_batching(self, input_ids, cache_ids=None, seq_ids=None):
         n_seqs, n_active_tokens = input_ids.shape
@@ -288,10 +324,33 @@ def _prepare_for_continuous_batching(self, input_ids, cache_ids=None, seq_ids=No
         if n_active_tokens > 1 and cache_ids.flatten()[0].item() == 0:
             # context encoding
             n_active_seqs, n_active_tokens = input_ids.shape
-            n_positions = self.context_buckets[-1]
+            continuous_batching_n_positions = bucket.find(self.context_buckets, n_active_tokens)
             assert n_active_seqs == cache_ids.shape[0], f"invalid n_active_seqs ({n_active_seqs} vs {cache_ids.shape[0]})"
-            assert n_active_tokens <= n_positions, f"invalid input prompt length ({n_active_tokens} <= {n_positions})"
-            cache_ids_pad = torch.zeros(n_active_seqs, n_positions, dtype=cache_ids.dtype, device='cpu')
+            assert n_active_tokens <= continuous_batching_n_positions, \
+                f"invalid input prompt length ({n_active_tokens} <= {continuous_batching_n_positions})"
+            cache_ids_pad = torch.zeros(n_active_seqs, continuous_batching_n_positions, dtype=cache_ids.dtype, device='cpu')
+            for seq_id in range(n_active_seqs):
+                cache_ids_pad[seq_id, :n_active_tokens] = cache_ids[seq_id, :n_active_tokens]
+            return input_ids, cache_ids_pad, seq_ids
+
+        elif n_active_tokens > 1 and cache_ids.flatten()[0].item() > 0:
+            # speculative forward
+            n_active_seqs, n_active_tokens = input_ids.shape
+            speculative_n_positions = bucket.find(self.context_buckets, n_active_tokens)
+            assert n_active_tokens <= speculative_n_positions, \
+                f"invalid input prompt length ({n_active_tokens} <= {speculative_n_positions})"
+            prompt_buckets = list(set([k for k, batch_size in self.decoder_lm_head_for_speculation.keys()]))
+            speculation_bucket = bucket.find(prompt_buckets, n_active_tokens)
+            # validate the speculative head was compiled for the given batch size
+            speculation_batches = [batch_size for (k, batch_size) in self.decoder_lm_head_for_speculation.keys()]
+            assert n_active_seqs in speculation_batches, \
+                f"invalid batch size for speculative forward ({n_active_seqs} not in {speculation_batches})"
+            # make cache ids 2d if needed and pad to match speculation bucket
+            if len(cache_ids.shape) == 1:
+                cache_ids = cache_ids.unsqueeze(0)
+            assert cache_ids.shape[0] == n_active_seqs, \
+                f"invalid n_active_seqs ({n_active_seqs} vs {cache_ids.shape[0]}) in speculative forward"
+            cache_ids_pad = torch.zeros(n_active_seqs, speculative_n_positions, dtype=cache_ids.dtype, device='cpu')
             for seq_id in range(n_active_seqs):
                 cache_ids_pad[seq_id, :n_active_tokens] = cache_ids[seq_id, :n_active_tokens]
             return input_ids, cache_ids_pad, seq_ids
@@ -311,7 +370,7 @@ def _preprocess(self, input_ids, start_ids=None, cache_ids=None):
         input_ids, cache_ids, start_ids = self._prepare_for_continuous_batching(input_ids, cache_ids, start_ids)

         # right pad the input_ids if neccessary
-        input_ids, last_token_id = self._prepare_for_par_ctx_rhs_padding(input_ids)
+        input_ids, cache_ids, last_token_id = self._prepare_for_par_ctx_rhs_padding(input_ids, cache_ids)

         # note: this context_length is after right padded
         batch_size, context_length = input_ids.shape
@@ -321,6 +380,8 @@ def _preprocess(self, input_ids, start_ids=None, cache_ids=None):

         if cache_ids is None:
             cache_ids = torch.arange(context_length, dtype=torch.int32)
+            if self.neuron_config.use_2d_cache_ids:
+                cache_ids = cache_ids.unsqueeze(0).expand(batch_size, context_length)

         if hasattr(self, "prefixed_length") and self.prefixed_length:
             cache_ids += self.prefixed_length
@@ -365,7 +426,7 @@ def _context_dynamic_batching(self, hidden, *args):
                 "input batch size ({input_batch_size}) not divisible by running batch size ({running_batch_size})"
             n_iters = input_batch_size // running_batch_size
             all_logits = []
-            cache_ids, start_ids = args[0], args[1]
+            cache_ids, start_ids, last_token_id = args[0], args[1], args[2]
             for iter_id in range(n_iters):
                 # Assuming HSB layout
                 start_idx = iter_id*running_batch_size
@@ -376,9 +437,9 @@ def _context_dynamic_batching(self, hidden, *args):
                 hidden_per_batch = hidden[:, :, start_idx:end_idx]
                 cache_ids_per_batch = cache_ids[start_idx:end_idx, :]
                 start_ids_per_batch = start_ids[start_idx:end_idx]
-                last_token_id = cache_ids_per_batch.max()
+                last_token_id_per_batch = last_token_id[start_idx:end_idx]
                 logits_per_batch = self.context(hidden_per_batch, cache_ids_per_batch,
-                                                start_ids_per_batch, last_token_id)
+                                                start_ids_per_batch, last_token_id_per_batch)
                 all_logits.append(logits_per_batch)
             logits = torch.cat(all_logits, dim=2)
         else:
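
For orientation, below is a minimal standalone sketch of the 1D cache_ids right-padding that the new _pad_cache_ids helper performs; the worked values come from the diff's own comment, and the snippet is illustrative only, not part of the commit.

import torch

def pad_cache_ids_1d(cache_ids, context_length, estimate):
    # Pad 1D cache_ids up to the selected bucket size (estimate), clamping the
    # padded positions to the last valid slot so they never index past the cache.
    start_idx = cache_ids[-1].item() + 1
    end_idx = estimate + start_idx - context_length
    pad_elements = torch.arange(start_idx, end_idx, dtype=torch.long)
    cache_ids_pad = torch.concat([cache_ids, pad_elements], dim=0)
    return torch.minimum(cache_ids_pad, torch.tensor(estimate - 1, dtype=torch.long))

# Worked example from the diff: cache_ids = [16, 17], estimate = 512
padded = pad_cache_ids_1d(torch.tensor([16, 17], dtype=torch.long), context_length=2, estimate=512)
assert padded.shape == (512,)
assert padded[0].item() == 16 and padded[-1].item() == 511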

src/transformers_neuronx/config.py

Lines changed: 10 additions & 2 deletions
@@ -93,10 +93,10 @@ def __init__(self, **kargs):
         self.cast_logits_dtype = kargs.pop('cast_logits_dtype', 'float32')
         self.fuse_qkv = kargs.pop('fuse_qkv', False)
         self.continuous_batching = kargs.pop('continuous_batching', None)
-        self.use_2d_cache_ids = kargs.pop('use_2d_cache_ids', False)
+        self.lhs_aligned = kargs.pop('use_2d_cache_ids', False) or kargs.pop('lhs_aligned', False)
         if self.continuous_batching:
             # Force using 2D cache_ids layout for continuous batching.
-            self.use_2d_cache_ids = True
+            self.lhs_aligned = True
         self.attention_layout = kargs.pop('attention_layout', constants.LAYOUT_HSB)
         self.cache_layout = kargs.pop('cache_layout', constants.LAYOUT_SBH)
         self.collectives_layout = kargs.pop('collectives_layout', constants.LAYOUT_HSB)
@@ -122,6 +122,14 @@ def __init__(self, **kargs):

         self.layer_partition = {}

+    @property
+    def use_2d_cache_ids(self):
+        return self.lhs_aligned
+
+    @property
+    def vectorize_last_token_id(self):
+        return self.lhs_aligned
+
     def is_valid_layer(self, layer_id):
         if not self.is_pp():
             return True
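
A hedged sketch of the intended aliasing: either the legacy use_2d_cache_ids flag or the new lhs_aligned flag (or enabling continuous batching) sets lhs_aligned, and both read-only properties report it. The stand-in class below mirrors the diff rather than importing the real NeuronConfig.

class _NeuronConfigSketch:
    """Minimal stand-in mirroring the diff above; not the real NeuronConfig."""
    def __init__(self, **kargs):
        self.continuous_batching = kargs.pop('continuous_batching', None)
        self.lhs_aligned = kargs.pop('use_2d_cache_ids', False) or kargs.pop('lhs_aligned', False)
        if self.continuous_batching:
            # Continuous batching forces the left-aligned (2D cache_ids) layout.
            self.lhs_aligned = True

    @property
    def use_2d_cache_ids(self):
        return self.lhs_aligned

    @property
    def vectorize_last_token_id(self):
        return self.lhs_aligned

# The legacy flag, the new flag, and continuous batching all land on the same attribute.
assert _NeuronConfigSketch(use_2d_cache_ids=True).lhs_aligned
assert _NeuronConfigSketch(lhs_aligned=True).use_2d_cache_ids
assert _NeuronConfigSketch(continuous_batching=object()).vectorize_last_token_id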

src/transformers_neuronx/decoder.py

Lines changed: 11 additions & 9 deletions
@@ -215,12 +215,13 @@ def init_token_decoder(self,unroll, buckets, model_obj):
         decoder_lm_head.add_embedding_builder(self.hlo_builder.embedding)
         return decoder_lm_head

-    def init_speculative_decoder(self, unroll, buckets, model_obj, n_active_tokens):
-        decoder_lm_head = DecoderLmHeadForSamplingNoEmbedding(
+    def init_speculative_decoder(self, unroll, buckets, model_obj, n_active_tokens, batch_size=None):
+        cls = type(self)
+        decoder_lm_head = cls(
             tp_degree=self.tp_degree,
             n_positions_list=buckets,
             n_active_tokens=n_active_tokens,
-            batch_size=self.batch_size,
+            batch_size=self.batch_size if batch_size is None else batch_size,
             attention_head_size=self.attention_head_size,
             amp=self.amp,
             num_layers=self.num_layers,
@@ -400,9 +401,9 @@ def forward_single(self, *inputs):
        etc.
        """
        _, cache_ids, start_ids, *_ = inputs
-        batch_size, = start_ids.shape
-        # In continuous batching, take largest cache_id and use the power-of-two policy to find the appropriate bucket.
-        if self.neuron_config and self.neuron_config.continuous_batching:
+        batch_size = start_ids.shape[0]
+        # With 2D cache_ids, take largest cache_id and use the power-of-two policy to find the appropriate bucket.
+        if self.neuron_config and self.neuron_config.use_2d_cache_ids:
            bucket_id = 0
            batch_size, _ = cache_ids.shape
        else:
@@ -416,7 +417,7 @@ def forward(self, *inputs):

    def forward(self, *inputs):
        hidden, cache_ids, start_ids, *_ = inputs
-        batch_size, = start_ids.shape
+        batch_size = start_ids.shape[0]
        sequence_dim, *_ = self.inputs_sdim
        sequence_length = hidden.shape[sequence_dim]
        if sequence_length == 1:
@@ -459,8 +460,9 @@ def embed_positions_ids(self, position_ids, start_ids=None, batch_size=None):
            batch_size = self.batch_size[0]
        if start_ids is None:
            return position_ids, torch.zeros([batch_size], dtype=torch.int32)
-        position_ids = position_ids.unsqueeze(0).repeat(batch_size, 1)
-        position_ids -= start_ids.unsqueeze(1)
+        if not self.neuron_config.use_2d_cache_ids:
+            position_ids = position_ids.unsqueeze(0).repeat(batch_size, 1)
+            position_ids -= start_ids.unsqueeze(1)
        position_ids.masked_fill_(position_ids < 0, 0)
        return position_ids, start_ids
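
A small illustrative sketch of the batch-size selection that forward_single now performs: with left-aligned (2D) cache_ids the batch size comes from the cache_ids tensor itself, otherwise from start_ids.shape[0]. The helper and tensor shapes below are hypothetical, shown only to make the branch explicit.

import torch

def select_batch_size(start_ids, cache_ids, use_2d_cache_ids):
    # Mirrors the forward_single change: with 2D (left-aligned) cache_ids the batch
    # size is read from cache_ids itself; otherwise start_ids provides it via shape[0].
    if use_2d_cache_ids:
        batch_size, _ = cache_ids.shape
    else:
        batch_size = start_ids.shape[0]
    return batch_size

# Hypothetical shapes for illustration only.
start_ids = torch.zeros(4, dtype=torch.int32)
cache_ids_2d = torch.zeros(4, 1, dtype=torch.int32)
assert select_batch_size(start_ids, cache_ids_2d, use_2d_cache_ids=True) == 4
assert select_batch_size(start_ids, None, use_2d_cache_ids=False) == 4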

src/transformers_neuronx/generation_utils.py

Lines changed: 27 additions & 16 deletions
@@ -22,10 +22,10 @@ def __init__(self, config, model):
        super().__init__(config)
        self.model = model
        self.config = config
-        self.cur_len = 0
+        self.cur_len = torch.zeros(1, dtype=torch.long)

    def reset_generation(self):
-        self.cur_len = 0
+        self.cur_len = torch.zeros(1, dtype=torch.long)

    def forward(self, input_ids, cache_ids, start_ids=None, output_hidden_states=False, output_attentions=False,
            attention_mask=None, return_dict=False):
@@ -69,23 +69,34 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
        if attention_mask is not None:
            _, start_ids = attention_mask.max(axis=1)

-        if self.cur_len > 0:
+        if (self.cur_len > 0).any().item():
            input_ids = input_ids[:, -1:]
-            cache_ids = torch.as_tensor([self.cur_len], dtype=torch.int32)
-
-        continuous_batching = self.model.neuron_config.continuous_batching is not None
-        if continuous_batching:
-            if self.cur_len > 0:
-                batch_size = input_ids.shape[0]
-                cache_ids = torch.as_tensor([self.cur_len]*batch_size, dtype=torch.int32).reshape(batch_size, 1)
-                start_ids = None
-            else:
-                cache_ids = torch.arange(input_ids.shape[-1]) * attention_mask
-                start_ids = torch.arange(input_ids.shape[0])

-        # no need to prepare cache_ids for parallel context encoding here as forward will pad input_ids and generate legalized cache_ids
+        if self.model.neuron_config.use_2d_cache_ids:
+            # 2D cache_ids
+            batch_size, context_length = attention_mask.shape
+            start_ids = torch.arange(input_ids.shape[0])
+            if (self.cur_len > 0).any().item():
+                # token generation (aka decoding) with 2D cache_ids
+                index_map = torch.arange(context_length).unsqueeze(0).expand(batch_size, context_length)
+                cache_ids = (index_map * attention_mask).max(dim=1).values.unsqueeze(-1)
+                self.cur_len = cache_ids.squeeze(-1)
+            else:
+                # context encoding (aka prefill) with 2D cache_ids
+                cache_ids = torch.arange(context_length) * attention_mask
+                self.cur_len = cache_ids.max(dim=1).values
+        else:
+            start_ids = None
+            if (self.cur_len > 0).any().item():
+                # token generation (aka decoding) with 1D cache_ids
+                cache_ids = self.cur_len
+                self.cur_len = cache_ids + 1
+            else:
+                # context encoding (aka prefill) with 1D cache_ids
+                batch_size, context_length = input_ids.shape
+                cache_ids = torch.arange(context_length)
+                self.cur_len = torch.tensor([context_length], dtype=torch.long)

-        self.cur_len += input_ids.shape[-1]
        model_inputs = {
            "input_ids": input_ids,
            "cache_ids": cache_ids,

src/transformers_neuronx/gpt2/model.py

Lines changed: 2 additions & 2 deletions
@@ -442,7 +442,7 @@ def forward(self, input_ids, cache_ids=None, start_ids=None):
        is_context_encode = context_length > 1
        estimate = bucket.find(self.context_buckets, context_length)

-        inputs, last_token_id = self._prepare_for_par_ctx_rhs_padding(input_ids)
+        inputs, cache_ids, last_token_id = self._prepare_for_par_ctx_rhs_padding(input_ids, cache_ids)
        batch_size, context_length = inputs.shape

        model = self.decoder_lm_head
@@ -529,7 +529,7 @@ def speculative_forward(self, input_ids, cache_ids=None, start_ids=None, specula
        if speculation_length is None:
            model=self.decoder_lm_head
        else:
-            model=self.decoder_lm_head_for_speculation[speculation_length]
+            model=self.decoder_lm_head_for_speculation[speculation_length, batch_size]

        # Compute the window starting index for specific mask patterns
        # For other patterns we pass in a default value of 0, it won't be used
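
A hedged sketch of the lookup convention this file now relies on: speculative decoder heads are keyed by (speculation_length, batch_size) rather than by speculation_length alone. The dictionary entries below are hypothetical placeholders.

# The speculative-heads dictionary is keyed by (speculation_length, batch_size).
decoder_lm_head_for_speculation = {}             # stand-in for the model attribute
decoder_lm_head_for_speculation[4, 1] = "k4_b1"  # hypothetical compiled heads
decoder_lm_head_for_speculation[4, 2] = "k4_b2"

speculation_length, batch_size = 4, 2
model = decoder_lm_head_for_speculation[speculation_length, batch_size]
assert model == "k4_b2"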
