Commit c8d6bdc
Sync internal repo to external June 28 2024 (#90)
* [module] Added better/faster checkpoint support with both sharded/whole checkpoints
GitOrigin-RevId: 474757c3e65895084384e2e67d811f0423880fcf
* [generation-demo] Add --profile
GitOrigin-RevId: 7ba35ca7839df4fd300686db3632eb23feffbd6e
* [module] Added the ability to download weights from huggingface hub repositories
GitOrigin-RevId: 88661e675a31f2f1784056982717b3b15e04c922
* [automodel] Added NeuronAutoModelForCausalLM class which automatically loads architecture-specific classes
GitOrigin-RevId: 18095eec9b978327e74cac2f24f872ca93c876f9
* [util] Avoid truncating tensors when padded size is less than current size.
GitOrigin-RevId: 948fa370b543ac66976693e956bc024167060c8f
* [window-context] Add window context encoding
GitOrigin-RevId: 2def53a09fb5cbd58e1842941be8abd8fca79988
* Add support for post-processing logits for on-device log-softmax
GitOrigin-RevId: 416ad2a32cc89606872c62c366b89e51e506f861
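For context, the log-softmax post-processing named above can be sketched as follows. This is a minimal numpy illustration of the math only, not the actual on-device HLO kernel:

```python
import numpy as np

def log_softmax(logits, axis=-1):
    # Subtract the max for numerical stability, then normalize in log space.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

logits = np.array([[2.0, 1.0, 0.1]])
log_probs = log_softmax(logits)
# exp(log_probs) sums to 1 along the vocab axis.
```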
* [generation-demo] Add torch-profile; Use random input with --prompt_len
GitOrigin-RevId: 62f39cfc1b857882034b307a7169bc3b2a46e292
* fix conflict for bsh gated mlp
GitOrigin-RevId: 34815f7a7261b9484e20aaf7e42f55a5de09fc87
* fix conflict for bsh gated mlp fix2
GitOrigin-RevId: 8bee8b2595d88837341faad75630349f33ea55dd
* [speculation] Updated speculative generator to correctly insert last draft token into KV cache
GitOrigin-RevId: 8a3eedd6ee516c641442cc1fa9c5639a491d097d
* [llama] Added support for tied embedding/head weights
GitOrigin-RevId: 65dea75643e5bc041bca7bdd8677a6951e3ffccc
* [decoder] Added a warmup to all kernels to avoid unexpected initialization latency spikes
GitOrigin-RevId: b62d7e7a2df4675e66354a6f56b527d0e332891f
* [hlo] add support for bool literals in Python 3.11
GitOrigin-RevId: 07ad81981b19ccaf5c775f18b839f51690f5b2ae
* Add support for Mistral-7B-v0.2 for no sliding window
GitOrigin-RevId: 205dcf4a5c8ce6e3c7c6c98e604e5ee3df509054
* [pp] call self directly instead of self.forward to enable forward hook
GitOrigin-RevId: 908b7af05e1320a2ddded5734545b147cd7a20ba
* [module] Added support for model checkpoints using base model prefix
GitOrigin-RevId: 8e8bd0d72de318d1f4cbc9859b44dc4e0d9b0514
* Fused KV cache update for deduplicating index calculation
GitOrigin-RevId: 888778dfed05b873b7020b4208ed403edb87158d
* Add tags to models; this passes tags down to the Parallel kernel and prefixes context kernels with 'context'
GitOrigin-RevId: cb9b6f19d2c0e33c441a2c770c8eb8a3c0a60a23
* Add warmup logic when profile gets called on each kernel
GitOrigin-RevId: 8225ec846d5bcfb35741e4982b72921ce248a55b
* [decoder] Handle corner cases where the KV cache is None.
GitOrigin-RevId: ac418340447a5b8b4fc8df0c46eb7e97d50befb3
* [decoder] Added prefix tags to more decoder types. Added less ambiguous tag parameter prefixes
GitOrigin-RevId: 15a25fab827f0dc18c4b37d11517f9eb4c5cd875
* Set self.tag in base class. This is used by PipelineParallelProgram
GitOrigin-RevId: 96ec44be9af30c88b607d8ddddb2ff5ff907ec5f
* Extend TP support of Mixtral-8x7B model from 8 to 16 and 32 and fix accuracy issue
GitOrigin-RevId: 2f0bbfb4934c8396327d9579afc8d5284887fe94
* support BSH attention layout for continuous batching
GitOrigin-RevId: 9664ff667ce1baa7c7eaacb57ecbe81d75a82629
* [generation_demo] additional flags, minor fixes
GitOrigin-RevId: 6ab804d5012a5286ab2d0f612d68e6e24854ea36
* [generation_demo] model from config support, minor fixes
GitOrigin-RevId: 38c82a99d0f628fe4b791dcc4d51bf7d0c835303
* Require transformers>=4.36
GitOrigin-RevId: 6f8b1ef2e099d268188ae7ed3b055ea94f7cbf81
* Support on-device embedding for GPT2. Fix multi-layer model support for
LLAMA and BLOOM and clean up forward function signatures.
GitOrigin-RevId: e7ea681c09e9f712c41668f4cc1aa78104f467e3
* Fixing HSB GroupNorm implementation
GitOrigin-RevId: 372b2cca5fae0418e8a4cf346cf87133ac33ddf6
* [compile-cache] Add NEURONX_DUMP_TO_NOTEMP to dump artifacts from neuron cache
GitOrigin-RevId: 140a46779a5e42806b901b3896c95251c9260010
* Fix forward call for Mixtral
GitOrigin-RevId: abefd80fc726015ab133dd40425d3ba97d1ff2f3
* [Speculation] Use target model to generate leftover tokens
GitOrigin-RevId: a654d7c01e43fffe9c3253850a75ea17d04aac7d
* add block diagonal causal mask for concatenated multi-prompt encoding
GitOrigin-RevId: f806d511d0eac7bf766972bb43eed9308566b5e4
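The block-diagonal causal mask described above lets several prompts share one concatenated sequence while each token attends only within its own prompt. A minimal numpy sketch of the mask construction (illustrative only, not the HLO lowering):

```python
import numpy as np

def block_diagonal_causal_mask(prompt_lens):
    """Causal mask for prompts concatenated into one sequence: token i may
    attend to token j only when both belong to the same prompt and j <= i."""
    total = sum(prompt_lens)
    mask = np.zeros((total, total), dtype=bool)
    start = 0
    for n in prompt_lens:
        # Causal (lower-triangular) block for one prompt.
        mask[start:start + n, start:start + n] = np.tril(np.ones((n, n), dtype=bool))
        start += n
    return mask

mask = block_diagonal_causal_mask([2, 3])
```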
* Revert [Speculation] Use target model to generate leftover tokens
GitOrigin-RevId: 76feacb3aa501239359e0c4939dacbdf311ca7e2
* [hlo] Support transposed attention output weight
GitOrigin-RevId: 5d1d00c1773d57570248388943c7c567a5dac870
* [Speculation] Use target model to generate leftover tokens
GitOrigin-RevId: 2f677787bbcea31e5738982f243183908e612a45
* [compiler] Fixed snapshot steps functionality in executor. Fixed warmup to never snapshot
GitOrigin-RevId: 343b2ce9549a9762f698fe2029ed28ad1245eb0f
* KV cache placement with slot_mapping
GitOrigin-RevId: 18c217fd664e60bd7834702f64001eecfaa0d688
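KV cache placement with a slot mapping scatters each new token's key/value vector into an arbitrary cache slot rather than a contiguous range. A rough numpy sketch of the idea; the real cache layout and kernel differ, and all shapes here are assumptions:

```python
import numpy as np

# Flat pool of cache slots; slot_mapping says which slot each new
# token's key/value vector should be written to.
num_slots, head_dim = 8, 4
kv_cache = np.zeros((num_slots, head_dim))

new_kv = np.arange(3 * head_dim, dtype=float).reshape(3, head_dim)
slot_mapping = np.array([5, 0, 2])  # destination slot per token

kv_cache[slot_mapping] = new_kv  # scatter into the pool
```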
* Update quantization to support contraction dim
GitOrigin-RevId: 397ded48a0850d98144d7a517f605d6ed3ac3691
* [mistral/mixtral/opt] Fix on-device embedding support for remaining models
GitOrigin-RevId: a52299027aabeffd47f13c9a9c4b4df600e8a4c2
* Adjust profiling to pass through number of NTFF files generated and remove tar file creation
GitOrigin-RevId: 689616f8951e0e778aeae0dbdc12d04c267b6cbf
* Fuse QKV support for GQA
GitOrigin-RevId: 938e3d267b77e4b5d225d214be47c190472472c7
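Fusing QKV for GQA means a single matmul produces Q, K, and V together, with unequal split sizes because K/V use fewer heads than Q. An illustrative sketch; the shapes and variable names are assumptions, not the library's API:

```python
import numpy as np

hidden, n_q_heads, n_kv_heads, head_dim = 16, 4, 2, 4
rng = np.random.default_rng(0)
wq = rng.standard_normal((hidden, n_q_heads * head_dim))
wk = rng.standard_normal((hidden, n_kv_heads * head_dim))
wv = rng.standard_normal((hidden, n_kv_heads * head_dim))

# Concatenate projection weights so one matmul yields all three outputs.
w_fused = np.concatenate([wq, wk, wv], axis=1)

x = rng.standard_normal((3, hidden))
qkv = x @ w_fused
# Split at the Q boundary and the Q+K boundary (K/V are smaller under GQA).
q, k, v = np.split(qkv, [n_q_heads * head_dim,
                         n_q_heads * head_dim + n_kv_heads * head_dim], axis=1)
```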
* [Hlo] Change mmadd to dot_add and update it to support N-D tensors
GitOrigin-RevId: 5a06e680dbda0ab33262127df0b4458f88d38eed
* Replace problem characters in tagged HLO. This directly translates to filenames (NEFF, NTFF)
GitOrigin-RevId: fe658484f8c6cebe691ab623c45ee69e3427b5c9
* [Release 2.18] Change version to 1.0
GitOrigin-RevId: 2c948b4669ab83591925595fcbba87319971369d
* added ntff_count_limit argument to generation_demo
GitOrigin-RevId: 9649a8710ce4fe13cd22891f1e8ed983377c1cf6
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline
GitOrigin-RevId: aa1051241c38be805cca604855f4531c35eda83c
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline
GitOrigin-RevId: 31e2511deaf9b8aa2e4d983b1263377a74ab4cd8
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline
GitOrigin-RevId: c5b1d876e387ab51e55c0b1c7a8320ab97a88777
* Fix position_offset param for OPTForSamplingNoEmbeddingHlo
GitOrigin-RevId: 2aebac56ac5780fc7f88bc5f24bc7611562ef5bd
* initial support for BSH cache layout
GitOrigin-RevId: 68440f419057fbbb9befe02a5f95bbead9d4a24a
* support BSH cache layout with BSH attention layout
GitOrigin-RevId: 75e2807a57e509ef3dc9edae7c44421f911d8961
* [generation_demo] remove dump flag
GitOrigin-RevId: 23a774b75e2f8e224e511531703391ccf5ee4b10
* Reenable prompt broadcasting in GPT2 (input batch 1, output batch N)
GitOrigin-RevId: ec8772e01b121a2acd5526b9bbe3fa0bb9d334a9
* Fix return ranks for executors while using on-device sampling
GitOrigin-RevId: 412113c5b10b0285fe7f3aafe82bae578b463024
* [Release 2.18] Change version to 0.10.x
GitOrigin-RevId: fde8f715eca71e3ee1dac392b9e9b3527e6ee0cb
* [module] Allow safetensors checkpoint downloads to be explicitly disabled
GitOrigin-RevId: fb34b5e8ace114f086728382c0f60358fb687f24
* Override attn_implementation as eager to skip sdpa attn implementation
GitOrigin-RevId: 2112b39c4a43cadc77ec551b6ed49ce33a7a72f1
* [hlo] Added primitive broadcasting. Added new operators. Added on-device speculative token selection
GitOrigin-RevId: cf64e6130141d5d237cdc0277c8b6d3adc39630e
* LHS alignment for static batching (vectorize last_token_id)
GitOrigin-RevId: 41dc716837dc1416eea4b60c198a39710cc153a7
* fix cache_ids padding for llama CB and batch=1 SD
GitOrigin-RevId: 7a28dcbe147e2a4400fe7ae7c8ac5cbad145e185
* Cherry-picks to 2.18 for multi-bucketing/multi-prompt for continuous batching
GitOrigin-RevId: 74a3c4cdc4dc8ae7ee6a5a12a9734a084499224e
* Fix "Unit Tests - 2 Core" errors in test_neuron_auto_model.py mixtral tests
GitOrigin-RevId: 8d5c47068ccbd2e87672243aedd56c46b887a8b9
* fix generation_utils
GitOrigin-RevId: fcb5254a8ecafbeef434cff5fce9075993cdd471
* support BSH cache layout with BSH attention layout
GitOrigin-RevId: 5868197a4ed26eec2346b23acbc1c4ccf764f826
* Add the ability to return arbitrary tensors for debugging
GitOrigin-RevId: fdfc3c19a864a19a578b109e95c003f23b98a93c
* [generation_demo] remove dump flag
GitOrigin-RevId: 4da57bd6b09f606f5d1f0d290965ef7ff39bde2c
* fix pp due to the change of HloNeuronProgram
GitOrigin-RevId: 19ea3848b7ba66bd5a216f25052001b4d150c67e
* Changes to make simple_sample work with speculative decoding
GitOrigin-RevId: 20164ebedb6ef274032fea570bfa52a385e75b40
* Add QKV weight tiling for fused QKV
GitOrigin-RevId: da6a9ede322d529215879b96356778269460cc3f
* Override attn_implementation as eager to skip sdpa attn implementation
GitOrigin-RevId: 1705e85c4cec9c665acd74b37fc5a54b98e9b405
* Revert "fix pp due to the change of HloNeuronProgram"
This reverts commit 19ea3848b7ba66bd5a216f25052001b4d150c67e.
GitOrigin-RevId: fdbea022a128256e661bddc23a6ad709570a8f96
* Revert "Add the ability to return arbitrary tensors for debugging"
This reverts commit fdfc3c19a864a19a578b109e95c003f23b98a93c.
GitOrigin-RevId: 584a49ec3ac28754bfac83a203e66cbf9f7d333d
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline
GitOrigin-RevId: d0d79eba5c29edbe672ac58ee250bcfa20bc62ad
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline
GitOrigin-RevId: 148afab8f91d523c7436af46cd2d556546d96692
* redo fix for DecoderProgramMultiLayer constructor API change
GitOrigin-RevId: 008f9c9568fdf60b2390dde2458944532bd12d5a
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline
GitOrigin-RevId: 8f4d0c97e4fdbe33c3458cb13ac1a74661218b4b
* [llama] support qk norm
GitOrigin-RevId: 9f395a3cecd822e83fe56653afe99b4930a22dd6
* [executor] Added the ability to support concatenated outputs using new runtime integration
GitOrigin-RevId: 6c43d2ebd2be15db4830f751708c1a9184c7c6ea
* [decoder] Decoder init_xxx_decoder methods clone a copy of the current class to allow derived classes to build themselves
GitOrigin-RevId: cc6a5caa9e45652c5e230e0310889fa99fbfdbc7
* Fix return ranks for executors while using on-device sampling
GitOrigin-RevId: 3268b18eae910ad9a4d2e1aafd4a55dacac8ff13
* [module] Allow ties between more than 2 parameters
GitOrigin-RevId: 5de3ebb808ef09f2258fda0e575155592d9da658
* [gqa] Reduce GQA replication amount under certain conditions
GitOrigin-RevId: a0772a0d8cb6d755e01acb3d0e7d1b84de3eccfb
* [Release 2.18] Change version to 0.10.x
GitOrigin-RevId: 66dd5f983669e3c45d90680289987f4fb44a92ce
* [module] Allow safetensors checkpoint downloads to be explicitly disabled
GitOrigin-RevId: 7a496d2d13bcaed3481b59e1c332332c8dc9404a
* Expose GenerationConfig as a top level construct
GitOrigin-RevId: 7ab8740ba4efd253b4b99905616b8d2d8892d607
* [decoder] Updated HLO builder 'inputs' function to only handle parameters. Moved mask creation to 'pre_layer' function
GitOrigin-RevId: a6ff3f126b2fe1acaa8309fa63f6dd84fddbc6f4
* [decoder] Factored HLO graph build into more discrete functions
GitOrigin-RevId: 755d790cfa127a7d4d4481f289bba8e9bf11146a
* [sampling] Removed sampling from architecture-specific classes
GitOrigin-RevId: 217f6a98ceb2cdba419f0523b3ab2bbba389e5ac
* [llama/speculation] 2d kv-cache id update for speculative_forward
GitOrigin-RevId: d25ac42ef25540598d04e6fa95afae643d35d1a7
* [hlo] Added primitive broadcasting. Added new operators. Added on-device speculative token selection
GitOrigin-RevId: b4b239d015e30f9b35ad65a57eaceef672502214
* [hlo] Added tensor parallel softmax implementation
GitOrigin-RevId: 42746ebd9c869cbec78e9b8b10ed7a4a38665680
* [generation] Removed sequence length restriction for on-device sampling
GitOrigin-RevId: 49e6d01840a5d90689c1a323f232005c8c686e1f
* [hlo] Simplified speculative token selection API
GitOrigin-RevId: be99bb4f54f3499f27410842bf32eb75dd6bc729
* [hlo] Fixed speculative token selector batched broadcast issue
GitOrigin-RevId: 5fc9a2d82d9a6e2ccf6d9986aa0359b24f982a72
* [base] Make batch sizes for enable_speculative_decoder optional
GitOrigin-RevId: b78200a717c44820b7b32f4d465dc74cf55772b6
* [decoder] Added a method to be able to retrieve all decoder parameters
GitOrigin-RevId: 9954297014f8cab87931ffc376e4c28dbe16552b
* [program] Added a ParallelProgram model executor with a simplified/general interface
GitOrigin-RevId: 337066258c251392c01be101b52c6c463107ea41
* LHS alignment for static batching (vectorize last_token_id)
GitOrigin-RevId: 8a678f4b979dc59d40e92bb74de3cacde9859d16
* [hlo] legalize broadcast and batch norm lowerings
GitOrigin-RevId: 761573b84de85088321a65f098b66c7a258cdb9a
* Revert "[llama] support qk norm"
This reverts commit 9f395a3cecd822e83fe56653afe99b4930a22dd6.
GitOrigin-RevId: 9dddbad3a8d87c4e3db92d5a597b7421ff0a9de0
* [hlo] Allow speculative token selection to pad with a specified token
GitOrigin-RevId: dc4f9669d5e4b142043a4515ebab53aa6d217f33
* [snapshot] Add tag as iteration suffix. Fixed Executor snapshot iteration
GitOrigin-RevId: b221a1f6d444916ed0a1f1d4f74a63fe437eb873
* [GQA] Reduce replication amount under certain conditions
GitOrigin-RevId: 6eb3504e1e0c74f184f60cd1988b5f016e257967
* [debugger] Re-enable debugger.
GitOrigin-RevId: ba251cd21ef4fd9a9f54d3060c7339e565fbf493
* [hlo] Added sort_with_indices & argsort
GitOrigin-RevId: 59e37888ac4b2206ee4a3d00bac18a2bbb421c75
* [program] Added bucketed program and selector classes
GitOrigin-RevId: 2af53bd710dead8690fb28fbad9683b06dbaee26
* [hlo] Update argmax to use built in hlo ops; update reshape to handle scalar input
GitOrigin-RevId: 0c30fbec405d7bb713b9508e818297d7a9b3de27
* [generation] Added Top-P on-device sampling
GitOrigin-RevId: 961d6d91496cbe574a3653d91ab6b43a0142219d
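Top-P (nucleus) sampling keeps the smallest set of tokens whose cumulative probability reaches p, then renormalizes. A host-side numpy sketch of the selection logic only; the commit implements this on device:

```python
import numpy as np

def top_p_filter(probs, p=0.9):
    """Zero out tokens outside the smallest set whose cumulative
    probability reaches p, then renormalize."""
    order = np.argsort(probs)[::-1]          # tokens by descending probability
    cumulative = np.cumsum(probs[order])
    # Keep every token up to and including the first that crosses p.
    cutoff = np.searchsorted(cumulative, p) + 1
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()

probs = np.array([0.5, 0.3, 0.15, 0.05])
filtered = top_p_filter(probs, p=0.75)
```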
* [hlo][topk] use built in hlo functions
GitOrigin-RevId: 4e494643c36ec8949e32f086f4af5fccfd31f667
* Fix overwriting of cache_ids causing accuracy issues
GitOrigin-RevId: e1ad8b5da91b3a7455060aa8dbf582c113a9142b
* [hlo] Add support for custom activations in MLP
GitOrigin-RevId: 7fe42367042aa68783ae3dd7d703693cfe9766c5
* Adapt recent NKI changes into NKI interface and add NKI flash attention
kernel.
SIM: https://sim.amazon.com/issues/NAPP-1895
GitOrigin-RevId: d95005cd4809e46b6774c5bc76af21dacf22259b
* remove unused line
sim: https://sim.amazon.com/issues/NAPP-1895
GitOrigin-RevId: 9375aa04bf55f6d5dac4ba7f51ab8f331f93b741
* change bf16 dtype to nl.bfloat16 instead of |V2
sim: https://sim.amazon.com/issues/NAPP-1895
GitOrigin-RevId: 73c6a276eee8fa496cc7e0f247dccee1fdb57997
* change bf16 dtype to nl.bfloat16 instead of |V2
sim: https://sim.amazon.com/issues/NAPP-1895
GitOrigin-RevId: 61b5b20dc8efd0fa14bd0254021e0a806ca48dd6
* removed unnecessary flash attention kernel
sim: https://sim.amazon.com/issues/NAPP-1895
GitOrigin-RevId: fc0cb2da6119b46cf35c082d3dd777e224f9a10d
* adding the wrapper
sim: https://sim.amazon.com/issues/NAPP-1895
GitOrigin-RevId: e2139891727481d9684cdb4eb5ec048662633848
* Added flash attention kernel wrapper
sim: https://sim.amazon.com/issues/NAPP-1895
GitOrigin-RevId: b084bcb04a6b912e2bc179f1347dfe1d609bbe44
* remove verification file
sim: https://sim.amazon.com/issues/NAPP-1895
GitOrigin-RevId: 9bd2759aa5ac12090f4477678d1f2eaedccaa7e9
* [hlo] Add error handling for non-callable activation functions
GitOrigin-RevId: 46fb6d7ff93dd49ccf6d5d952767f9237366154a
* [hlo] Updated speculative token selection softmax to use sharded compute
GitOrigin-RevId: c74bd613988ea9373d1500461ed4c393c724d726
* [Generation] Add quantization and flash attention args
GitOrigin-RevId: ee29661a5a0ed504e1b24057eb87bbba53b9c68c
* [hlo] Add ability to return token selection mask from speculative_token_selection
GitOrigin-RevId: 3882aad7066bc954d08e610469428fcccde1829a
* Fix CB logit postprocessing for full batches
GitOrigin-RevId: 676a12cc8ea0056a3fade4636bb50aefccf0eec6
* [base] Fixed off-by-one error in window context encoding and make unroll factor optional
GitOrigin-RevId: 24850985ce6cae43972c3092ab2dbc3764b62b49
* add missing debug_tensors param
GitOrigin-RevId: 585647b7f8a4f47952eeb625b0c4c1692ce3c51c
* [continuous batching] Add multi-bucketing support to continuous batching
GitOrigin-RevId: ebc57784b3d9400923d11551a1c01b95260d0c94
* Revert "Fix CB logit postprocessing for full batches"
This reverts commit 676a12cc8ea0056a3fade4636bb50aefccf0eec6.
GitOrigin-RevId: b9a20bb2053e5b9067cad3310ece9849526acf49
* Updated BSH/HSB attention layouts to use unified logic
GitOrigin-RevId: 03bdaf947c48ebf4b216eb7183d6cc388e010f10
* flash attention integration
GitOrigin-RevId: fe46b47151825c9c8783adde96e2330aab9a9783
* [decoder/utils] Factored out QKV padding calculation to common utility
GitOrigin-RevId: e3da5c6bfc2c7a2adbb26abedf129ff92ac7d14b
* [config] Added type annotations to NeuronConfig, added deprecation warnings for old arguments, and exposed top-level configurations
GitOrigin-RevId: 2a64c500d43994bb5a8dbeb5f4d7a9b20b74e38b
* [attention] Fixed B & S transpose issue in BSH attention output
GitOrigin-RevId: 57c27b49d91b965c26436f69e8ad2c88fc14a099
* [sos] refactor flash decoding related code
Author: karthick gopalswamy <[email protected]>
Date: 2024-04-05T01:22:10.000Z
GitOrigin-RevId: 3630b8ac60247ff01d820b47b94bb3b33db2704a
* [Do Not Merge] Flash decoding changes
GitOrigin-RevId: 2f5a475e7e9d91b96ea3c9d7af94617aa8b143c3
* [Do Not Merge] Flash Decoding mask and cache modification
GitOrigin-RevId: c5af182cda9aa7e012e19a36527e4de6df2241cf
* [sos] sequence parallel attention
GitOrigin-RevId: 93d1e82027d5896398043d97a42a99cbe5497e2e
* Revert "[Do Not Merge] Flash Decoding mask and cache modification"
This reverts commit 3b59fbda15971bd690e357d412f91f121f2b9097.
GitOrigin-RevId: e9adb2dfe13738de009d81b35124cbe635ff9768
* Revert "[Do Not Merge] Flash decoding changes"
This reverts commit 8b0973a6d7540b0715d640365b02628572f3ef4d.
GitOrigin-RevId: c1ffbc49f80a54b2e8382aca6270cdd416fa6ec6
* [sos] refactor flash decoding related code
GitOrigin-RevId: 853274b5619bbf267b44231c15f0b2e958d11d16
* [hlo] Add standalone definition for top-p
GitOrigin-RevId: 7c68c568ee3866a245f587aecd3bc91c11c18622
* [llama] Fixed accuracy issue with speculation
GitOrigin-RevId: 93214daabb17d4422dc85a72a95c429ab59a40c0
* [hlo] Add masked top-k implementation
GitOrigin-RevId: bb8e127fd115eda0032ead1cd3dc657337c8cda1
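A masked top-k selects the k largest scores while ignoring masked-out positions. An illustrative numpy sketch of the technique (not the HLO implementation; names are assumptions):

```python
import numpy as np

def masked_top_k(scores, mask, k):
    """Top-k over scores, ignoring positions where mask is False."""
    masked = np.where(mask, scores, -np.inf)   # masked positions can never win
    idx = np.argpartition(masked, -k)[-k:]     # unordered k largest
    idx = idx[np.argsort(masked[idx])[::-1]]   # sort descending
    return masked[idx], idx

scores = np.array([0.1, 0.9, 0.4, 0.7])
mask = np.array([True, False, True, True])
values, indices = masked_top_k(scores, mask, k=2)
```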
* [rotary] Updated functions to use hlo.py operator definitions
GitOrigin-RevId: a82e7a3c6cc2132ac529df7fca42d014c5da1f3a
* [generation] Enable dynamic static shape sampling in generation.py
GitOrigin-RevId: e0098078d6c8ab4f6c6b7e869e15f9ab0571350c
* add reduce-scatter to hlo
GitOrigin-RevId: 04dc756760c9d4e1952ef05f3b79c735390dcdad
* Fix breaking change in flash attention kernel interface
GitOrigin-RevId: 745f460b0866c4fca84e860f5107011e9ff5d2d5
* Add streamer support during on device sampling
GitOrigin-RevId: 64a68d3cb951de326df37221f7490ec8ba370192
* have GPT2ForSamplingWithContextBroadcasting sample accept a streamer in mainline
GitOrigin-RevId: 56582d463a5d91a14399e04406bd673d45c7916c
* fix continuous batching support for llama and mistral
GitOrigin-RevId: e490fc0bf033948f57ac3f1bbdeddf3070b5bf8c
* bloom, gptj, gptneox, opt accept streamer in sample
GitOrigin-RevId: f9a4eeb4b78ff7ea684f056b8b5229ffa0bccdba
* commit: c9b19fbb411f9ea0efa9eebd3f62477122752f97
Author: Joshua Hannan <[email protected]>
Date: 2024-04-17T02:51:14.000Z
[generation] Add the ability to do dynamic on device generation
GitOrigin-RevId: 95070848da3c80b459443d743af070a05ba573b5
* [generation][Mixtral/OPT] Update Mixtral and OPT to support on device generation
GitOrigin-RevId: 8ea893cd895c39340045df707f3475f4f138e15e
* [generation] Handle non-fp32 inputs for temperature
GitOrigin-RevId: 7a02696de81f03f43af7ea1044c97b368eccd4fb
* [compiler] Updated shape checking to be performed on provided inputs
GitOrigin-RevId: 6bb6d357418a7119bea628c8b20100c0ba1ac7af
* [config] Added equality comparison function to GenerationConfig
GitOrigin-RevId: fe26c7603e5b4704239f2ed3ddc67ca25a4382d0
* [hlo] Update top-p to not perform reverse ops
GitOrigin-RevId: f4783bfe7f54a60c17495d3e8bd843c1755076c9
* [hlo] Added fast implementation of cumsum operation
GitOrigin-RevId: 9de129e3cb99dd8666b999e4164c3d86a47b6e02
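One common way to make cumulative sum fast on matmul-oriented accelerators is to express it as multiplication by a triangular ones matrix; whether this matches the commit's implementation is an assumption. A numpy sketch:

```python
import numpy as np

def matmul_cumsum(x):
    """Cumulative sum expressed as a matmul with a lower-triangular ones
    matrix -- often faster than a serial scan on hardware built around
    large matrix multiplies, at the cost of extra FLOPs."""
    n = x.shape[-1]
    tri = np.tril(np.ones((n, n), dtype=x.dtype))
    # (x @ tri.T)[j] = sum of x[i] for i <= j
    return x @ tri.T

x = np.array([1.0, 2.0, 3.0, 4.0])
```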
* [hlo] Updated fast cumulative sum to always cast inputs
GitOrigin-RevId: 77abc934e56ba22dc4d4861b3d1caac8ab11d0c1
* Optimize QKV padding for certain GQA models
GitOrigin-RevId: a5527ab6dac38ab8fa217462f8b9b669360055b6
* [hlo] Added fast cumulative sum floating point check
GitOrigin-RevId: 3b7a87101d5af41770971383726e6fc48fc323d6
* [generation] Transpose logits prior to generation
GitOrigin-RevId: 806dfa933cf5ae88a6017e9ab2ef613cd3740f5e
* [hlo] Added decomposed speculation functions
GitOrigin-RevId: 752dd8d1fa5afaf215afce18ee032ad2f79f374b
* [generation] Add the ability to specify a global top-k
GitOrigin-RevId: fe9606d00bb58b3265096cb03870f32a451eb7ac
* [llama/speculation] Bug fixes: speculative_forward with continuous batching
GitOrigin-RevId: 702779e17d910afc3c150b3b5d329afb2357b348
* Fix mistral sliding window bug
GitOrigin-RevId: e4bdc2b6027442b14bb954af8c93d1e1710a4d47
* [sos] flash decoding changes for llama model
GitOrigin-RevId: a17189cf52ebba193f047a282f742c214d171559
* Optimize logit ordering for continuous batching
GitOrigin-RevId: 5198b8a747bd0370937bf317e6925582bff9be84
* Removing old kernels due to API change
SIM: https://sim.amazon.com/issues/NAPP-2811
GitOrigin-RevId: 4ae166fcb44c48ccfa4230a7b10271f522198052
* Fix pipeline by adding streamer to GPT2ForSampling
GitOrigin-RevId: 0edc3242b7cef42600028e637f9ce3f8576cb581
* Fix mistral bug for multiple sampling runs
GitOrigin-RevId: a70267456db853ad594e6c8038b3df65d75e0ee1
* Enabling custom RMSNorm across the board
GitOrigin-RevId: f8bb611ce3350a41457c11457b0c0f2ca3141108
* Revert "Enabling custom RMSNorm across the board"
This reverts commit f8bb611ce3350a41457c11457b0c0f2ca3141108.
GitOrigin-RevId: 34296a5be7bb73f0c3a8c993e2047cb27135d2ea
* [Speculation] Set cache_ids if None
GitOrigin-RevId: dc6422538c712b559f707c428e0d40a3b383e5ed
* flash attention layout shape fix
GitOrigin-RevId: 077993869b7fa69fa58c21d1181e00728d484098
* BIR flash attention kernel + GQA broadcasting
GitOrigin-RevId: 99cd0ed62100b84b2246bb76e5f19e9265f5dedc
* [debug] Enable returning all logits during context encoding
GitOrigin-RevId: 0555a82a0d908bfb8f35bff965ed802b1f3e85ed
* remove start_ids=None from generate()
GitOrigin-RevId: 0cca9dee080a3576c9e1c2d56391fde0a5fcbe47
* Remove internal code name
GitOrigin-RevId: d33b0d642bcdc752486495b735c3c78141d9dad2
* switch back to gelu_new_legacy
GitOrigin-RevId: a17892ccab44fc81655f587c7d1e42f5adaad97e
* Use collectives_layout=BSH as default and add deprecation warning
GitOrigin-RevId: 6ea9f56919adcca9a87937b72981ede68c63a6ae
* Rufus 2.19 Cherry picks
GitOrigin-RevId: c829df72d1f3ea77fdd484481bd39bf464cfe229
* Revert "Use collectives_layout=BSH as default and add deprecation warning"
This reverts commit 33b62dc57a682ed64fa58ab7df9d9b90fef6d163.
GitOrigin-RevId: bbd4c5aeb7a39146788917a3f706706b21f93ce5
* Do not use flash attention for batch > 1 left padded
GitOrigin-RevId: 2855c3be1efe3d223c11beac1395447e711166ba
* Cherry-pick on-device embedding changes:
[generation] Fix on-device generation without filtering
GitOrigin-RevId: d2dfddb3fe6c38b2b25817990d9360d0c056d8c0
* Enabling custom RMSNorm across the board
GitOrigin-RevId: 13059d1660d367712346b3ba0ab682c7826a2878
---------
Co-authored-by: Jonathan Lunt <[email protected]>
Co-authored-by: Bowen Chen <[email protected]>
Co-authored-by: Yuan Zhou <[email protected]>
Co-authored-by: Devesh Ratho <[email protected]>
Co-authored-by: Amer <[email protected]>
Co-authored-by: Mike Zhang <[email protected]>
Co-authored-by: Liangfu Chen <[email protected]>
Co-authored-by: Nicholas Waldron <[email protected]>
Co-authored-by: Shubham Chandak <[email protected]>
Co-authored-by: Wojciech Romaszkan <[email protected]>
Co-authored-by: Amulya Ballakur <[email protected]>
Co-authored-by: Dylan Geva <[email protected]>
Co-authored-by: Jeffrey Huynh <[email protected]>
Co-authored-by: Shashwat Srijan <[email protected]>
Co-authored-by: Prithvijit Chakrabarty <[email protected]>
Co-authored-by: Haichen Li <[email protected]>
Co-authored-by: Patrick Lange <[email protected]>
Co-authored-by: Hesam Ilati <[email protected]>
Co-authored-by: Tyler Osterberg <[email protected]>
Co-authored-by: karthick gopalswamy <[email protected]>
Co-authored-by: Faqin Zhong <[email protected]>
Co-authored-by: Yishan McNabb <[email protected]>
Co-authored-by: Akhil Raj Azhikodan <[email protected]>
Co-authored-by: yichi <[email protected]>

Parent: 0623de2
File tree (39 files changed: +2881, -1780 lines)

- src/transformers_neuronx
  - bloom
  - gpt2
  - gptj
  - gptneox
  - layers
  - llama
  - mistral
  - mixtral
  - nki
  - kernels
  - opt