Commit 0623de2
Sync internal repo to external Apr 15 2024 (#85)
* [module] Added better/faster checkpoint support with both sharded/whole checkpoints
GitOrigin-RevId: 474757c3e65895084384e2e67d811f0423880fcf
* [generation-demo] Add --profile
GitOrigin-RevId: 7ba35ca7839df4fd300686db3632eb23feffbd6e
* [module] Added the ability to download weights from huggingface hub repositories
GitOrigin-RevId: 88661e675a31f2f1784056982717b3b15e04c922
* [automodel] Added NeuronAutoModelForCausalLM class which automatically loads architecture-specific classes
GitOrigin-RevId: 18095eec9b978327e74cac2f24f872ca93c876f9
* [util] Avoid truncating tensors when padded size is less than current size.
GitOrigin-RevId: 948fa370b543ac66976693e956bc024167060c8f
* [window-context] Add window context encoding
GitOrigin-RevId: 2def53a09fb5cbd58e1842941be8abd8fca79988
* Add support for post-processing logits for on-device log-softmax
GitOrigin-RevId: 416ad2a32cc89606872c62c366b89e51e506f861
* [generation-demo] Add torch-profile; Use random input with --prompt_len
GitOrigin-RevId: 62f39cfc1b857882034b307a7169bc3b2a46e292
* fix conflict for bsh gated mlp
GitOrigin-RevId: 34815f7a7261b9484e20aaf7e42f55a5de09fc87
* fix conflict for bsh gated mlp fix2
GitOrigin-RevId: 8bee8b2595d88837341faad75630349f33ea55dd
* [speculation] Updated speculative generator to correctly insert last draft token into KV cache
GitOrigin-RevId: 8a3eedd6ee516c641442cc1fa9c5639a491d097d
* [llama] Added support for tied embedding/head weights
GitOrigin-RevId: 65dea75643e5bc041bca7bdd8677a6951e3ffccc
* [decoder] Added a warmup to all kernels to avoid unexpected initialization latency spikes
GitOrigin-RevId: b62d7e7a2df4675e66354a6f56b527d0e332891f
* [hlo] add support for bool literals in Python 3.11
GitOrigin-RevId: 07ad81981b19ccaf5c775f18b839f51690f5b2ae
* Add support for Mistral-7B-v0.2 for no sliding window
GitOrigin-RevId: 205dcf4a5c8ce6e3c7c6c98e604e5ee3df509054
* [pp] call self directly instead of self.forward to enable forward hooks
GitOrigin-RevId: 908b7af05e1320a2ddded5734545b147cd7a20ba
* [module] Added support for model checkpoints using base model prefix
GitOrigin-RevId: 8e8bd0d72de318d1f4cbc9859b44dc4e0d9b0514
* Fused KV cache update for deduplicating index calculation
GitOrigin-RevId: 888778dfed05b873b7020b4208ed403edb87158d
* Add tags to models; pass tags down to the Parallel kernel and prefix context kernels with 'context'
GitOrigin-RevId: cb9b6f19d2c0e33c441a2c770c8eb8a3c0a60a23
* Add warmup logic when profile gets called on each kernel
GitOrigin-RevId: 8225ec846d5bcfb35741e4982b72921ce248a55b
* [decoder] Handle corner cases where the KV cache is None.
GitOrigin-RevId: ac418340447a5b8b4fc8df0c46eb7e97d50befb3
* [decoder] Added prefix tags to more decoder types. Added less ambiguous tag parameter prefixes
GitOrigin-RevId: 15a25fab827f0dc18c4b37d11517f9eb4c5cd875
* Set self.tag in base class. This is used by PipelineParallelProgram
GitOrigin-RevId: 96ec44be9af30c88b607d8ddddb2ff5ff907ec5f
* Extend TP support of Mixtral-8x7B model from 8 to 16 and 32 and fix accuracy issue
GitOrigin-RevId: 2f0bbfb4934c8396327d9579afc8d5284887fe94
* support BSH attention layout for continuous batching
GitOrigin-RevId: 9664ff667ce1baa7c7eaacb57ecbe81d75a82629
* [generation_demo] additional flags, minor fixes
GitOrigin-RevId: 6ab804d5012a5286ab2d0f612d68e6e24854ea36
* [generation_demo] model from config support, minor fixes
GitOrigin-RevId: 38c82a99d0f628fe4b791dcc4d51bf7d0c835303
* Require transformers>=4.36
GitOrigin-RevId: 6f8b1ef2e099d268188ae7ed3b055ea94f7cbf81
* Support on-device embedding for GPT2. Fix multi-layer model support for
LLAMA and BLOOM and clean up forward function signatures.
GitOrigin-RevId: e7ea681c09e9f712c41668f4cc1aa78104f467e3
* Fixing HSB GroupNorm implementation
GitOrigin-RevId: 372b2cca5fae0418e8a4cf346cf87133ac33ddf6
* [compile-cache] Add NEURONX_DUMP_TO_NOTEMP to dump artifacts from neuron cache
GitOrigin-RevId: 140a46779a5e42806b901b3896c95251c9260010
* Fix forward call for Mixtral
GitOrigin-RevId: abefd80fc726015ab133dd40425d3ba97d1ff2f3
* [Speculation] Use target model to generate leftover tokens
GitOrigin-RevId: a654d7c01e43fffe9c3253850a75ea17d04aac7d
* add block diagonal causal mask for concatenated multi-prompt encoding
GitOrigin-RevId: f806d511d0eac7bf766972bb43eed9308566b5e4
* Revert [Speculation] Use target model to generate leftover tokens
GitOrigin-RevId: 76feacb3aa501239359e0c4939dacbdf311ca7e2
* [hlo] Support transposed attention output weight
GitOrigin-RevId: 5d1d00c1773d57570248388943c7c567a5dac870
* [Speculation] Use target model to generate leftover tokens
GitOrigin-RevId: 2f677787bbcea31e5738982f243183908e612a45
* [compiler] Fixed snapshot steps functionality in executor. Fixed warmup to never snapshot
GitOrigin-RevId: 343b2ce9549a9762f698fe2029ed28ad1245eb0f
* KV cache placement with slot_mapping
GitOrigin-RevId: 18c217fd664e60bd7834702f64001eecfaa0d688
* Update quantization to support contraction dim
GitOrigin-RevId: 397ded48a0850d98144d7a517f605d6ed3ac3691
* [mistral/mixtral/opt] Fix on-device embedding support for remaining models
GitOrigin-RevId: a52299027aabeffd47f13c9a9c4b4df600e8a4c2
* Adjust profiling to pass through number of NTFF files generated and remove tar file creation
GitOrigin-RevId: 689616f8951e0e778aeae0dbdc12d04c267b6cbf
* Fuse QKV support for GQA
GitOrigin-RevId: 938e3d267b77e4b5d225d214be47c190472472c7
* [Hlo] Change mmadd to dot_add and update it to support N-D tensors
GitOrigin-RevId: 5a06e680dbda0ab33262127df0b4458f88d38eed
* Replace problem characters in tagged HLO. This directly translates to filenames (NEFF, NTFF)
GitOrigin-RevId: fe658484f8c6cebe691ab623c45ee69e3427b5c9
* [Release 2.18] Change version to 1.0
GitOrigin-RevId: 2c948b4669ab83591925595fcbba87319971369d
* added ntff_count_limit argument to generation_demo
GitOrigin-RevId: 9649a8710ce4fe13cd22891f1e8ed983377c1cf6
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline
GitOrigin-RevId: aa1051241c38be805cca604855f4531c35eda83c
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline
GitOrigin-RevId: 31e2511deaf9b8aa2e4d983b1263377a74ab4cd8
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline
GitOrigin-RevId: c5b1d876e387ab51e55c0b1c7a8320ab97a88777
* Fix position_offset param for OPTForSamplingNoEmbeddingHlo
GitOrigin-RevId: 2aebac56ac5780fc7f88bc5f24bc7611562ef5bd
* initial support for BSH cache layout
GitOrigin-RevId: 68440f419057fbbb9befe02a5f95bbead9d4a24a
* support BSH cache layout with BSH attention layout
GitOrigin-RevId: 75e2807a57e509ef3dc9edae7c44421f911d8961
* [generation_demo] remove dump flag
GitOrigin-RevId: 23a774b75e2f8e224e511531703391ccf5ee4b10
* Reenable prompt broadcasting in GPT2 (input batch 1, output batch N)
GitOrigin-RevId: ec8772e01b121a2acd5526b9bbe3fa0bb9d334a9
* Fix return ranks for executors while using on-device sampling
GitOrigin-RevId: 412113c5b10b0285fe7f3aafe82bae578b463024
* [Release 2.18] Change version to 0.10.x
GitOrigin-RevId: fde8f715eca71e3ee1dac392b9e9b3527e6ee0cb
* [module] Allow safetensors checkpoint downloads to be explicitly disabled
GitOrigin-RevId: fb34b5e8ace114f086728382c0f60358fb687f24
* Override attn_implementation as eager to skip the sdpa attn implementation
GitOrigin-RevId: 2112b39c4a43cadc77ec551b6ed49ce33a7a72f1
* [hlo] Added primitive broadcasting. Added new operators. Added on-device speculative token selection
GitOrigin-RevId: cf64e6130141d5d237cdc0277c8b6d3adc39630e
* LHS alignment for static batching (vectorize last_token_id)
GitOrigin-RevId: 41dc716837dc1416eea4b60c198a39710cc153a7
* fix cache_ids padding for llama CB and batch=1 SD
GitOrigin-RevId: 7a28dcbe147e2a4400fe7ae7c8ac5cbad145e185
* Cherry-picks to 2.18 for multi-bucketing/multi-prompt for continuous batching
GitOrigin-RevId: 74a3c4cdc4dc8ae7ee6a5a12a9734a084499224e
* Fix "Unit Tests - 2 Core" errors in test_neuron_auto_model.py mixtral tests
GitOrigin-RevId: 8d5c47068ccbd2e87672243aedd56c46b887a8b9
* fix generation_utils
GitOrigin-RevId: fcb5254a8ecafbeef434cff5fce9075993cdd471
---------
Co-authored-by: Jonathan Lunt <[email protected]>
Co-authored-by: Bowen Chen <[email protected]>
Co-authored-by: Yuan Zhou <[email protected]>
Co-authored-by: Devesh Ratho <[email protected]>
Co-authored-by: Amer <[email protected]>
Co-authored-by: Mike Zhang <[email protected]>
Co-authored-by: Liangfu Chen <[email protected]>
Co-authored-by: Nicholas Waldron <[email protected]>
Co-authored-by: Shubham Chandak <[email protected]>
Co-authored-by: Wojciech Romaszkan <[email protected]>
Co-authored-by: Amulya Ballakur <[email protected]>
Co-authored-by: Dylan Geva <[email protected]>
Co-authored-by: Jeffrey Huynh <[email protected]>
Co-authored-by: Shashwat Srijan <[email protected]>
Co-authored-by: Prithvijit Chakrabarty <[email protected]>
1 parent 7a30b42
File tree: 15 files changed, +524 −91 lines
- src/transformers_neuronx
  - gpt2
  - layers
  - llama
  - mistral