
Commit c8d6bdc

Sync internal repo to external June 28 2024 (#90)
* [module] Added better/faster checkpoint support with both sharded/whole checkpoints GitOrigin-RevId: 474757c3e65895084384e2e67d811f0423880fcf
* [generation-demo] Add --profile GitOrigin-RevId: 7ba35ca7839df4fd300686db3632eb23feffbd6e
* [module] Added the ability to download weights from huggingface hub repositories GitOrigin-RevId: 88661e675a31f2f1784056982717b3b15e04c922
* [automodel] Added NeuronAutoModelForCausalLM class which automatically loads architecture-specific classes GitOrigin-RevId: 18095eec9b978327e74cac2f24f872ca93c876f9
* [util] Avoid truncating tensors when padded size is less than current size. GitOrigin-RevId: 948fa370b543ac66976693e956bc024167060c8f
* [window-context] Add window context encoding GitOrigin-RevId: 2def53a09fb5cbd58e1842941be8abd8fca79988
* Add support for post-processing logits for on-device log-softmax GitOrigin-RevId: 416ad2a32cc89606872c62c366b89e51e506f861
* [generation-demo] Add torch-profile; Use random input with --prompt_len GitOrigin-RevId: 62f39cfc1b857882034b307a7169bc3b2a46e292
* fix conflict for bsh gated mlp GitOrigin-RevId: 34815f7a7261b9484e20aaf7e42f55a5de09fc87
* fix conflict for bsh gated mlp fix2 GitOrigin-RevId: 8bee8b2595d88837341faad75630349f33ea55dd
* [speculation] Updated speculative generator to correctly insert last draft token into KV cache GitOrigin-RevId: 8a3eedd6ee516c641442cc1fa9c5639a491d097d
* [llama] Added support for tied embedding/head weights GitOrigin-RevId: 65dea75643e5bc041bca7bdd8677a6951e3ffccc
* [decoder] Added a warmup to all kernels to avoid unexpected initialization latency spikes GitOrigin-RevId: b62d7e7a2df4675e66354a6f56b527d0e332891f
* [hlo] add support for bool literals in Python 3.11 GitOrigin-RevId: 07ad81981b19ccaf5c775f18b839f51690f5b2ae
* Add support for Mistral-7B-v0.2 for no sliding window GitOrigin-RevId: 205dcf4a5c8ce6e3c7c6c98e604e5ee3df509054
* [pp] call self directly instead of self.forward to enable forward hook GitOrigin-RevId: 908b7af05e1320a2ddded5734545b147cd7a20ba
* [module] Added support for model checkpoints using base model prefix GitOrigin-RevId: 8e8bd0d72de318d1f4cbc9859b44dc4e0d9b0514
* Fused KV cache update for deduplicating index calculation GitOrigin-RevId: 888778dfed05b873b7020b4208ed403edb87158d
* Add tags to models, this should pass tags down to Parallel kernel, prefix context kernels with 'context' GitOrigin-RevId: cb9b6f19d2c0e33c441a2c770c8eb8a3c0a60a23
* Add warmup logic when profile gets called on each kernel GitOrigin-RevId: 8225ec846d5bcfb35741e4982b72921ce248a55b
* [decoder] Handle corner cases where the KV cache is None. GitOrigin-RevId: ac418340447a5b8b4fc8df0c46eb7e97d50befb3
* [decoder] Added prefix tags to more decoder types. Added less ambiguous tag parameter prefixes GitOrigin-RevId: 15a25fab827f0dc18c4b37d11517f9eb4c5cd875
* Set self.tag in base class. This is used by PipelineParallelProgram GitOrigin-RevId: 96ec44be9af30c88b607d8ddddb2ff5ff907ec5f
* Extend TP support of Mixtral-8x7B model from 8 to 16 and 32 and fix accuracy issue GitOrigin-RevId: 2f0bbfb4934c8396327d9579afc8d5284887fe94
* support BSH attention layout for continuous batching GitOrigin-RevId: 9664ff667ce1baa7c7eaacb57ecbe81d75a82629
* [generation_demo] additional flags, minor fixes GitOrigin-RevId: 6ab804d5012a5286ab2d0f612d68e6e24854ea36
* [generation_demo] model from config support, minor fixes GitOrigin-RevId: 38c82a99d0f628fe4b791dcc4d51bf7d0c835303
* Require transformers>=4.36 GitOrigin-RevId: 6f8b1ef2e099d268188ae7ed3b055ea94f7cbf81
* Support on-device embedding for GPT2.
  Fix multi-layer model support for LLAMA and BLOOM and clean up forward function signatures. GitOrigin-RevId: e7ea681c09e9f712c41668f4cc1aa78104f467e3
* Fixing HSB GroupNorm implementation GitOrigin-RevId: 372b2cca5fae0418e8a4cf346cf87133ac33ddf6
* [compile-cache] Add NEURONX_DUMP_TO_NOTEMP to dump artifacts from neuron cache GitOrigin-RevId: 140a46779a5e42806b901b3896c95251c9260010
* Fix forward call for Mixtral GitOrigin-RevId: abefd80fc726015ab133dd40425d3ba97d1ff2f3
* [Speculation] Use target model to generate leftover tokens GitOrigin-RevId: a654d7c01e43fffe9c3253850a75ea17d04aac7d
* add block diagonal causal mask for concatenated multi-prompt encoding GitOrigin-RevId: f806d511d0eac7bf766972bb43eed9308566b5e4
* Revert [Speculation] Use target model to generate leftover tokens GitOrigin-RevId: 76feacb3aa501239359e0c4939dacbdf311ca7e2
* [hlo] Support transposed attention output weight GitOrigin-RevId: 5d1d00c1773d57570248388943c7c567a5dac870
* [Speculation] Use target model to generate leftover tokens GitOrigin-RevId: 2f677787bbcea31e5738982f243183908e612a45
* [compiler] Fixed snapshot steps functionality in executor. Fixed warmup to never snapshot GitOrigin-RevId: 343b2ce9549a9762f698fe2029ed28ad1245eb0f
* KV cache placement with slot_mapping GitOrigin-RevId: 18c217fd664e60bd7834702f64001eecfaa0d688
* Update quantization to support contraction dim GitOrigin-RevId: 397ded48a0850d98144d7a517f605d6ed3ac3691
* [mistral/mixtral/opt] Fix on-device embedding support for remaining models GitOrigin-RevId: a52299027aabeffd47f13c9a9c4b4df600e8a4c2
* Adjust profiling to pass through number of NTFF files generated and remove tar file creation GitOrigin-RevId: 689616f8951e0e778aeae0dbdc12d04c267b6cbf
* Fuse QKV support for GQA GitOrigin-RevId: 938e3d267b77e4b5d225d214be47c190472472c7
* [Hlo] Change mmadd to dot_add and update it to support N-D tensors GitOrigin-RevId: 5a06e680dbda0ab33262127df0b4458f88d38eed
* Replace problem characters in tagged HLO.
  This directly translates to filenames (NEFF, NTFF) GitOrigin-RevId: fe658484f8c6cebe691ab623c45ee69e3427b5c9
* [Release 2.18] Change version to 1.0 GitOrigin-RevId: 2c948b4669ab83591925595fcbba87319971369d
* added ntff_count_limit argument to generation_demo GitOrigin-RevId: 9649a8710ce4fe13cd22891f1e8ed983377c1cf6
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline GitOrigin-RevId: aa1051241c38be805cca604855f4531c35eda83c
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline GitOrigin-RevId: 31e2511deaf9b8aa2e4d983b1263377a74ab4cd8
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline GitOrigin-RevId: c5b1d876e387ab51e55c0b1c7a8320ab97a88777
* Fix position_offset param for OPTForSamplingNoEmbeddingHlo GitOrigin-RevId: 2aebac56ac5780fc7f88bc5f24bc7611562ef5bd
* initial support for BSH cache layout GitOrigin-RevId: 68440f419057fbbb9befe02a5f95bbead9d4a24a
* support BSH cache layout with BSH attention layout GitOrigin-RevId: 75e2807a57e509ef3dc9edae7c44421f911d8961
* [generation_demo] remove dump flag GitOrigin-RevId: 23a774b75e2f8e224e511531703391ccf5ee4b10
* Reenable prompt broadcasting in GPT2 (input batch 1, output batch N) GitOrigin-RevId: ec8772e01b121a2acd5526b9bbe3fa0bb9d334a9
* Fix return ranks for executors while using on-device sampling GitOrigin-RevId: 412113c5b10b0285fe7f3aafe82bae578b463024
* [Release 2.18] Change version to 0.10.x GitOrigin-RevId: fde8f715eca71e3ee1dac392b9e9b3527e6ee0cb
* [module] Allow safetensors checkpoint downloads to be explicitly disabled GitOrigin-RevId: fb34b5e8ace114f086728382c0f60358fb687f24
* Override attn_implementation as eager to skip sdpa attn implementation GitOrigin-RevId: 2112b39c4a43cadc77ec551b6ed49ce33a7a72f1
* [hlo] Added primitive broadcasting. Added new operators. Added on-device speculative token selection GitOrigin-RevId: cf64e6130141d5d237cdc0277c8b6d3adc39630e
* LHS alignment for static batching (vectorize last_token_id) GitOrigin-RevId: 41dc716837dc1416eea4b60c198a39710cc153a7
* fix cache_ids padding for llama CB and batch=1 SD GitOrigin-RevId: 7a28dcbe147e2a4400fe7ae7c8ac5cbad145e185
* Cherry-picks to 2.18 for multi-bucketing/multi-prompt for continuous batching GitOrigin-RevId: 74a3c4cdc4dc8ae7ee6a5a12a9734a084499224e
* Fix "Unit Tests - 2 Core" errors in test_neuron_auto_model.py mixtral tests GitOrigin-RevId: 8d5c47068ccbd2e87672243aedd56c46b887a8b9
* fix generation_utils GitOrigin-RevId: fcb5254a8ecafbeef434cff5fce9075993cdd471
* support BSH cache layout with BSH attention layout GitOrigin-RevId: 5868197a4ed26eec2346b23acbc1c4ccf764f826
* Add the ability to return arbitrary tensors for debugging GitOrigin-RevId: fdfc3c19a864a19a578b109e95c003f23b98a93c
* [generation_demo] remove dump flag GitOrigin-RevId: 4da57bd6b09f606f5d1f0d290965ef7ff39bde2c
* fix pp due to the change of HloNeuronProgram GitOrigin-RevId: 19ea3848b7ba66bd5a216f25052001b4d150c67e
* Changes to make simple_sample work with speculative decoding GitOrigin-RevId: 20164ebedb6ef274032fea570bfa52a385e75b40
* Add QKV weight tiling for fused QKV GitOrigin-RevId: da6a9ede322d529215879b96356778269460cc3f
* Override attn_implementation as eager to skip sdpa attn implementation GitOrigin-RevId: 1705e85c4cec9c665acd74b37fc5a54b98e9b405
* Revert "fix pp due to the change of HloNeuronProgram" This reverts commit 19ea3848b7ba66bd5a216f25052001b4d150c67e.
  GitOrigin-RevId: fdbea022a128256e661bddc23a6ad709570a8f96
* Revert "Add the ability to return arbitrary tensors for debugging" This reverts commit fdfc3c19a864a19a578b109e95c003f23b98a93c. GitOrigin-RevId: 584a49ec3ac28754bfac83a203e66cbf9f7d333d
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline GitOrigin-RevId: d0d79eba5c29edbe672ac58ee250bcfa20bc62ad
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline GitOrigin-RevId: 148afab8f91d523c7436af46cd2d556546d96692
* redo fix for DecoderProgramMultiLayer constructor API change GitOrigin-RevId: 008f9c9568fdf60b2390dde2458944532bd12d5a
* Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransformers into mainline GitOrigin-RevId: 8f4d0c97e4fdbe33c3458cb13ac1a74661218b4b
* [llama] support qk norm GitOrigin-RevId: 9f395a3cecd822e83fe56653afe99b4930a22dd6
* [executor] Added the ability to support concatenated outputs using new runtime integration GitOrigin-RevId: 6c43d2ebd2be15db4830f751708c1a9184c7c6ea
* [decoder] Decoder init_xxx_decoder methods clone a copy of the current class to allow derived classes to build themselves GitOrigin-RevId: cc6a5caa9e45652c5e230e0310889fa99fbfdbc7
* Fix return ranks for executors while using on-device sampling GitOrigin-RevId: 3268b18eae910ad9a4d2e1aafd4a55dacac8ff13
* [module] Allow ties between more than 2 parameters GitOrigin-RevId: 5de3ebb808ef09f2258fda0e575155592d9da658
* [gqa] Reduce GQA replication amount under certain conditions GitOrigin-RevId: a0772a0d8cb6d755e01acb3d0e7d1b84de3eccfb
* [Release 2.18] Change version to 0.10.x GitOrigin-RevId: 66dd5f983669e3c45d90680289987f4fb44a92ce
* [module] Allow safetensors checkpoint downloads to be explicitly disabled GitOrigin-RevId: 7a496d2d13bcaed3481b59e1c332332c8dc9404a
* Expose GenerationConfig as a top level construct GitOrigin-RevId: 7ab8740ba4efd253b4b99905616b8d2d8892d607
* [decoder] Updated HLO builder 'inputs' function to only handle parameters. Moved mask creation to 'pre_layer' function GitOrigin-RevId: a6ff3f126b2fe1acaa8309fa63f6dd84fddbc6f4
* [decoder] Factored HLO graph build into more discrete functions GitOrigin-RevId: 755d790cfa127a7d4d4481f289bba8e9bf11146a
* [sampling] Removed sampling from architecture-specific classes GitOrigin-RevId: 217f6a98ceb2cdba419f0523b3ab2bbba389e5ac
* [llama/speculation] 2d kv-cache id update for speculative_forward GitOrigin-RevId: d25ac42ef25540598d04e6fa95afae643d35d1a7
* [hlo] Added primitive broadcasting. Added new operators.
  Added on-device speculative token selection GitOrigin-RevId: b4b239d015e30f9b35ad65a57eaceef672502214
* [hlo] Added tensor parallel softmax implementation GitOrigin-RevId: 42746ebd9c869cbec78e9b8b10ed7a4a38665680
* [generation] Removed sequence length restriction for on-device sampling GitOrigin-RevId: 49e6d01840a5d90689c1a323f232005c8c686e1f
* [hlo] Simplified speculative token selection API GitOrigin-RevId: be99bb4f54f3499f27410842bf32eb75dd6bc729
* [hlo] Fixed speculative token selector batched broadcast issue GitOrigin-RevId: 5fc9a2d82d9a6e2ccf6d9986aa0359b24f982a72
* [base] Make batch sizes for enable_speculative_decoder optional GitOrigin-RevId: b78200a717c44820b7b32f4d465dc74cf55772b6
* [decoder] Added a method to be able to retrieve all decoder parameters GitOrigin-RevId: 9954297014f8cab87931ffc376e4c28dbe16552b
* [program] Added a ParallelProgram model executor with a simplified/general interface GitOrigin-RevId: 337066258c251392c01be101b52c6c463107ea41
* LHS alignment for static batching (vectorize last_token_id) GitOrigin-RevId: 8a678f4b979dc59d40e92bb74de3cacde9859d16
* [hlo] legalize broadcast and batch norm lowerings GitOrigin-RevId: 761573b84de85088321a65f098b66c7a258cdb9a
* Revert "[llama] support qk norm" This reverts commit 9f395a3cecd822e83fe56653afe99b4930a22dd6. GitOrigin-RevId: 9dddbad3a8d87c4e3db92d5a597b7421ff0a9de0
* [hlo] Allow speculative token selection to pad with a specified token GitOrigin-RevId: dc4f9669d5e4b142043a4515ebab53aa6d217f33
* [snapshot] Add tag as iteration suffix. Fixed Executor snapshot iteration GitOrigin-RevId: b221a1f6d444916ed0a1f1d4f74a63fe437eb873
* [GQA] Reduce replication amount under certain conditions GitOrigin-RevId: 6eb3504e1e0c74f184f60cd1988b5f016e257967
* [debugger] Re-enable debugger. GitOrigin-RevId: ba251cd21ef4fd9a9f54d3060c7339e565fbf493
* [hlo] Added sort_with_indices & argsort GitOrigin-RevId: 59e37888ac4b2206ee4a3d00bac18a2bbb421c75
* [program] Added bucketed program and selector classes GitOrigin-RevId: 2af53bd710dead8690fb28fbad9683b06dbaee26
* [hlo] Update argmax to use built in hlo ops; update reshape to handle scalar input GitOrigin-RevId: 0c30fbec405d7bb713b9508e818297d7a9b3de27
* [generation] Added Top-P on-device sampling GitOrigin-RevId: 961d6d91496cbe574a3653d91ab6b43a0142219d
* [hlo][topk] use built in hlo functions GitOrigin-RevId: 4e494643c36ec8949e32f086f4af5fccfd31f667
* Fix overwriting of cache_ids causing accuracy issues GitOrigin-RevId: e1ad8b5da91b3a7455060aa8dbf582c113a9142b
* [hlo] Add support for custom activations in MLP GitOrigin-RevId: 7fe42367042aa68783ae3dd7d703693cfe9766c5
* Adapt recent NKI changes into NKI interface and add NKI flash attention kernel.
  SIM: https://sim.amazon.com/issues/NAPP-1895 GitOrigin-RevId: d95005cd4809e46b6774c5bc76af21dacf22259b
* remove unused line sim: https://sim.amazon.com/issues/NAPP-1895 GitOrigin-RevId: 9375aa04bf55f6d5dac4ba7f51ab8f331f93b741
* change bf16 dtype to nl.bfloat16 instead of |V2 sim: https://sim.amazon.com/issues/NAPP-1895 GitOrigin-RevId: 73c6a276eee8fa496cc7e0f247dccee1fdb57997
* change bf16 dtype to nl.bfloat16 instead of |V2 sim: https://sim.amazon.com/issues/NAPP-1895 GitOrigin-RevId: 61b5b20dc8efd0fa14bd0254021e0a806ca48dd6
* removed unnecessary flash attention kernel sim: https://sim.amazon.com/issues/NAPP-1895 GitOrigin-RevId: fc0cb2da6119b46cf35c082d3dd777e224f9a10d
* adding the wrapper sim: https://sim.amazon.com/issues/NAPP-1895 GitOrigin-RevId: e2139891727481d9684cdb4eb5ec048662633848
* Added flash attention kernel wrapper sim: https://sim.amazon.com/issues/NAPP-1895 GitOrigin-RevId: b084bcb04a6b912e2bc179f1347dfe1d609bbe44
* remove verification file sim: https://sim.amazon.com/issues/NAPP-1895 GitOrigin-RevId: 9bd2759aa5ac12090f4477678d1f2eaedccaa7e9
* [hlo] Add error handling for non-callable activation functions GitOrigin-RevId: 46fb6d7ff93dd49ccf6d5d952767f9237366154a
* [hlo] Updated speculative token selection softmax to use sharded compute GitOrigin-RevId: c74bd613988ea9373d1500461ed4c393c724d726
* [Generation] Add quantization and flash attention args GitOrigin-RevId: ee29661a5a0ed504e1b24057eb87bbba53b9c68c
* [hlo] Add ability to return token selection mask from speculative_token_selection GitOrigin-RevId: 3882aad7066bc954d08e610469428fcccde1829a
* Fix CB logit postprocessing for full batches GitOrigin-RevId: 676a12cc8ea0056a3fade4636bb50aefccf0eec6
* [base] Fixed off-by-one error in window context encoding and make unroll factor optional GitOrigin-RevId: 24850985ce6cae43972c3092ab2dbc3764b62b49
* add missing debug_tensors param GitOrigin-RevId: 585647b7f8a4f47952eeb625b0c4c1692ce3c51c
* [continuous batching] Add multi-bucketing support to continuous batching GitOrigin-RevId: ebc57784b3d9400923d11551a1c01b95260d0c94
* Revert "Fix CB logit postprocessing for full batches" This reverts commit 676a12cc8ea0056a3fade4636bb50aefccf0eec6. GitOrigin-RevId: b9a20bb2053e5b9067cad3310ece9849526acf49
* Updated BSH/HSB attention layouts to use unified logic GitOrigin-RevId: 03bdaf947c48ebf4b216eb7183d6cc388e010f10
* flash attention integration GitOrigin-RevId: fe46b47151825c9c8783adde96e2330aab9a9783
* [decoder/utils] Factored out QKV padding calculation to common utility GitOrigin-RevId: e3da5c6bfc2c7a2adbb26abedf129ff92ac7d14b
* [config] Added type annotations to NeuronConfig, added deprecation warnings for old arguments, and exposed top-level configurations GitOrigin-RevId: 2a64c500d43994bb5a8dbeb5f4d7a9b20b74e38b
* [attention] Fixed B & S transpose issue in BSH attention output GitOrigin-RevId: 57c27b49d91b965c26436f69e8ad2c88fc14a099
* [sos] refactor flash decoding related code Author: karthick gopalswamy <[email protected]> Date: 2024-04-05T01:22:10.000Z GitOrigin-RevId: 3630b8ac60247ff01d820b47b94bb3b33db2704a
* [Do Not Merge] Flash decoding changes GitOrigin-RevId: 2f5a475e7e9d91b96ea3c9d7af94617aa8b143c3
* [Do Not Merge] Flash Decoding mask and cache modification GitOrigin-RevId: c5af182cda9aa7e012e19a36527e4de6df2241cf
* [sos] sequence parallel attention GitOrigin-RevId: 93d1e82027d5896398043d97a42a99cbe5497e2e
* Revert "[Do Not Merge] Flash Decoding mask and cache modification" This reverts commit 3b59fbda15971bd690e357d412f91f121f2b9097.
  GitOrigin-RevId: e9adb2dfe13738de009d81b35124cbe635ff9768
* Revert "[Do Not Merge] Flash decoding changes" This reverts commit 8b0973a6d7540b0715d640365b02628572f3ef4d. GitOrigin-RevId: c1ffbc49f80a54b2e8382aca6270cdd416fa6ec6
* [sos] refactor flash decoding related code GitOrigin-RevId: 853274b5619bbf267b44231c15f0b2e958d11d16
* [hlo] Add standalone definition for top-p GitOrigin-RevId: 7c68c568ee3866a245f587aecd3bc91c11c18622
* [llama] Fixed accuracy issue with speculation GitOrigin-RevId: 93214daabb17d4422dc85a72a95c429ab59a40c0
* [hlo] Add masked top-k implementation GitOrigin-RevId: bb8e127fd115eda0032ead1cd3dc657337c8cda1
* [rotary] Updated functions to use hlo.py operator definitions GitOrigin-RevId: a82e7a3c6cc2132ac529df7fca42d014c5da1f3a
* [generation] Enable dynamic static shape sampling in generation.py GitOrigin-RevId: e0098078d6c8ab4f6c6b7e869e15f9ab0571350c
* add reduce-scatter to hlo GitOrigin-RevId: 04dc756760c9d4e1952ef05f3b79c735390dcdad
* Fix breaking change in flash attention kernel interface GitOrigin-RevId: 745f460b0866c4fca84e860f5107011e9ff5d2d5
* Add streamer support during on device sampling GitOrigin-RevId: 64a68d3cb951de326df37221f7490ec8ba370192
* have GPT2ForSamplingWithContextBroadcasting sample accept a streamer in mainline GitOrigin-RevId: 56582d463a5d91a14399e04406bd673d45c7916c
* fix continuous batching support for llama and mistral GitOrigin-RevId: e490fc0bf033948f57ac3f1bbdeddf3070b5bf8c
* bloom, gptj, gptneox, opt accept streamer in sample GitOrigin-RevId: f9a4eeb4b78ff7ea684f056b8b5229ffa0bccdba
* commit: c9b19fbb411f9ea0efa9eebd3f62477122752f97 Author: Joshua Hannan <[email protected]> Date: 2024-04-17T02:51:14.000Z [generation] Add the ability to do dynamic on device generation GitOrigin-RevId: 95070848da3c80b459443d743af070a05ba573b5
* [generation][Mixtral/OPT] Update Mixtral and OPT to support on device generation GitOrigin-RevId: 8ea893cd895c39340045df707f3475f4f138e15e
* [generation] Handle non-fp32 inputs for temperature GitOrigin-RevId: 7a02696de81f03f43af7ea1044c97b368eccd4fb
* [compiler] Updated shape checking to be performed on provided inputs GitOrigin-RevId: 6bb6d357418a7119bea628c8b20100c0ba1ac7af
* [config] Added equality comparison function to GenerationConfig GitOrigin-RevId: fe26c7603e5b4704239f2ed3ddc67ca25a4382d0
* [hlo] Update top-p to not perform reverse ops GitOrigin-RevId: f4783bfe7f54a60c17495d3e8bd843c1755076c9
* [hlo] Added fast implementation of cumsum operation GitOrigin-RevId: 9de129e3cb99dd8666b999e4164c3d86a47b6e02
* [hlo] Updated fast cumulative sum to always cast inputs GitOrigin-RevId: 77abc934e56ba22dc4d4861b3d1caac8ab11d0c1
* Optimize QKV padding for certain GQA models GitOrigin-RevId: a5527ab6dac38ab8fa217462f8b9b669360055b6
* [hlo] Added fast cumulative sum floating point check GitOrigin-RevId: 3b7a87101d5af41770971383726e6fc48fc323d6
* [generation] Transpose logits prior to generation GitOrigin-RevId: 806dfa933cf5ae88a6017e9ab2ef613cd3740f5e
* [hlo] Added decomposed speculation functions GitOrigin-RevId: 752dd8d1fa5afaf215afce18ee032ad2f79f374b
* [generation] Add the ability to specify a global top-k GitOrigin-RevId: fe9606d00bb58b3265096cb03870f32a451eb7ac
* [llama/speculation] Bug fixes: speculative_forward with continuous batching GitOrigin-RevId: 702779e17d910afc3c150b3b5d329afb2357b348
* Fix mistral sliding window bug GitOrigin-RevId: e4bdc2b6027442b14bb954af8c93d1e1710a4d47
* [sos] flash decoding changes for llama model GitOrigin-RevId: a17189cf52ebba193f047a282f742c214d171559
* Optimize logit ordering
  for continuous batching GitOrigin-RevId: 5198b8a747bd0370937bf317e6925582bff9be84
* Removing old kernels due to API change SIM: https://sim.amazon.com/issues/NAPP-2811 GitOrigin-RevId: 4ae166fcb44c48ccfa4230a7b10271f522198052
* Fix pipeline by adding streamer to GPT2ForSampling GitOrigin-RevId: 0edc3242b7cef42600028e637f9ce3f8576cb581
* Fix mistral bug for multiple sampling runs GitOrigin-RevId: a70267456db853ad594e6c8038b3df65d75e0ee1
* Enabling custom RMSNorm across the board GitOrigin-RevId: f8bb611ce3350a41457c11457b0c0f2ca3141108
* Revert "Enabling custom RMSNorm across the board" This reverts commit f8bb611ce3350a41457c11457b0c0f2ca3141108. GitOrigin-RevId: 34296a5be7bb73f0c3a8c993e2047cb27135d2ea
* [Speculation] Set cache_ids if None GitOrigin-RevId: dc6422538c712b559f707c428e0d40a3b383e5ed
* flash attention layout shape fix GitOrigin-RevId: 077993869b7fa69fa58c21d1181e00728d484098
* BIR flash attention kernel + GQA broadcasting GitOrigin-RevId: 99cd0ed62100b84b2246bb76e5f19e9265f5dedc
* [debug] Enable returning all logits during context encoding GitOrigin-RevId: 0555a82a0d908bfb8f35bff965ed802b1f3e85ed
* remove start_ids=None from generate() GitOrigin-RevId: 0cca9dee080a3576c9e1c2d56391fde0a5fcbe47
* Remove internal code name GitOrigin-RevId: d33b0d642bcdc752486495b735c3c78141d9dad2
* switch back to gelu_new_legacy GitOrigin-RevId: a17892ccab44fc81655f587c7d1e42f5adaad97e
* Use collectives_layout=BSH as default and add deprecation warning GitOrigin-RevId: 6ea9f56919adcca9a87937b72981ede68c63a6ae
* Rufus 2.19 Cherry picks GitOrigin-RevId: c829df72d1f3ea77fdd484481bd39bf464cfe229
* Revert "Use collectives_layout=BSH as default and add deprecation warning" This reverts commit 33b62dc57a682ed64fa58ab7df9d9b90fef6d163. GitOrigin-RevId: bbd4c5aeb7a39146788917a3f706706b21f93ce5
* Do not use flash attention for batch > 1 left padded GitOrigin-RevId: 2855c3be1efe3d223c11beac1395447e711166ba
* Cherry-pick on-device embedding changes: [generation] Fix on-device generation without filtering GitOrigin-RevId: d2dfddb3fe6c38b2b25817990d9360d0c056d8c0
* Enabling custom RMSNorm across the board GitOrigin-RevId: 13059d1660d367712346b3ba0ab682c7826a2878

---------

Co-authored-by: Jonathan Lunt <[email protected]>
Co-authored-by: Bowen Chen <[email protected]>
Co-authored-by: Yuan Zhou <[email protected]>
Co-authored-by: Devesh Ratho <[email protected]>
Co-authored-by: Amer <[email protected]>
Co-authored-by: Mike Zhang <[email protected]>
Co-authored-by: Liangfu Chen <[email protected]>
Co-authored-by: Nicholas Waldron <[email protected]>
Co-authored-by: Shubham Chandak <[email protected]>
Co-authored-by: Wojciech Romaszkan <[email protected]>
Co-authored-by: Amulya Ballakur <[email protected]>
Co-authored-by: Dylan Geva <[email protected]>
Co-authored-by: Jeffrey Huynh <[email protected]>
Co-authored-by: Shashwat Srijan <[email protected]>
Co-authored-by: Prithvijit Chakrabarty <[email protected]>
Co-authored-by: Haichen Li <[email protected]>
Co-authored-by: Patrick Lange <[email protected]>
Co-authored-by: Hesam Ilati <[email protected]>
Co-authored-by: Tyler Osterberg <[email protected]>
Co-authored-by: karthick gopalswamy <[email protected]>
Co-authored-by: Faqin Zhong <[email protected]>
Co-authored-by: Yishan McNabb <[email protected]>
Co-authored-by: Akhil Raj Azhikodan <[email protected]>
Co-authored-by: yichi <[email protected]>
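For orientation, a minimal usage sketch of the NeuronAutoModelForCausalLM auto-class and Hugging Face Hub checkpoint loading introduced by this sync. It assumes the usual transformers-neuronx from_pretrained / to_neuron / sample flow; the checkpoint name, tp_degree, amp, and sampling settings are illustrative placeholders rather than values taken from this commit.

import torch
from transformers import AutoTokenizer
from transformers_neuronx import NeuronAutoModelForCausalLM

name = "meta-llama/Llama-2-7b-hf"              # placeholder checkpoint
model = NeuronAutoModelForCausalLM.from_pretrained(
    name, batch_size=1, tp_degree=2, amp="f16",  # illustrative compile-time settings
)
model.to_neuron()                               # compile and load weights onto NeuronCores

tokenizer = AutoTokenizer.from_pretrained(name)
input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids
with torch.inference_mode():
    generated = model.sample(input_ids, sequence_length=128, top_k=50)
print(tokenizer.decode(generated[0], skip_special_tokens=True))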
1 parent 0623de2 commit c8d6bdc

39 files changed: +2881 / -1780 lines changed

src/transformers_neuronx/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -14,8 +14,10 @@
 # ==============================================================================
 from transformers_neuronx.version import __version__
 
-from transformers_neuronx.config import NeuronConfig, QuantizationConfig, ContinuousBatchingConfig
-from transformers_neuronx.constants import GQA
+
+from transformers_neuronx.constants import GQA, Layout
+from transformers_neuronx.sparse_attn_utils import SparseAttnConfig
+from transformers_neuronx.config import NeuronConfig, QuantizationConfig, ContinuousBatchingConfig, GenerationConfig
 from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter
 
 from transformers_neuronx.bloom.model import BloomForSampling

src/transformers_neuronx/activations.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 import math
 
 def gelu_new(hidden):
-    return hidden.dtype[hidden.sizes].CustomCall(hidden, custom_call_target="AwsNeuronGelu")
+    return hidden.dtype[hidden.sizes].CustomCall(hidden, custom_call_target="AwsNeuronGeluApprxTanh")
 
 def gelu_new_legacy(hidden):
     dtype = hidden.dtype
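The one-line change above swaps the GELU custom-call target from AwsNeuronGelu to AwsNeuronGeluApprxTanh. For reference, a small sketch of the standard tanh-approximation GELU that "gelu_new" conventionally denotes; whether the renamed device kernel matches this formula bit-for-bit is an assumption.

import math
import torch

def gelu_tanh_approx(x: torch.Tensor) -> torch.Tensor:
    # Standard "gelu_new" tanh approximation:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

x = torch.linspace(-3, 3, 7)
print(gelu_tanh_approx(x))
print(torch.nn.functional.gelu(x, approximate="tanh"))  # reference; should match closely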

src/transformers_neuronx/base.py

Lines changed: 64 additions & 33 deletions
@@ -24,6 +24,7 @@
 from transformers_neuronx import module
 from transformers_neuronx.compiler import ParallelKernel
 from transformers_neuronx.constants import LAYOUT_BSH
+from transformers_neuronx.config import GenerationConfig
 from concurrent.futures import ProcessPoolExecutor
 
 
@@ -85,10 +86,12 @@ def enable_speculative_decoder(self, speculation_length: Optional[Union[List[int
                 self.decoder_lm_head_for_speculation[k, batch_size] = \
                     self.decoder_param_set.init_speculative_decoder(unroll=self.unroll, buckets=self.token_buckets, model_obj=self, n_active_tokens=k, batch_size=batch_size)
 
-    def enable_window_context_decoder(self, window_context_length:Optional[Union[List[int], int]], unroll):
+    def enable_window_context_decoder(self, window_context_length:Optional[Union[List[int], int]], unroll: Optional[int] = None):
         if isinstance(window_context_length, int):
             window_context_length=[window_context_length]
         self.window_context_buckets = bucket.context_sizes(window_context_length, self.token_buckets)
+        if unroll is None:
+            unroll = self.decoder_param_set.num_layers
         for k in self.window_context_buckets:
             self.decoder_lm_head_for_window_context[k]=self.decoder_param_set.init_window_context_decoder(unroll=unroll, buckets=self.token_buckets, model_obj=self, n_active_tokens=k)
 
@@ -172,6 +175,8 @@ def context(self, hidden, cache_ids, start_ids, last_token_id, *rest):
         context_length = hidden.shape[1]
         batch_size = start_ids.shape[0]
 
+        all_logits = []  # Collect all logits if neuron_config.output_all_logits is True
+
         if self.is_fid:
             # Fusion-In-Decoder context encoding
             fused_context_length = hidden.shape[1]
@@ -181,7 +186,6 @@
 
         estimate = bucket.find(self.context_buckets, context_length)
 
-
         if estimate is not None:
             hidden_context = hidden
             cache_context = cache_ids
@@ -208,11 +212,11 @@
                 logits, scores = model(hidden_context, cache_context, start_ids, last_token_id, *rest)
             else:
                 logits = model(hidden_context, cache_context, start_ids, last_token_id, *rest)
-
-
+            if self.neuron_config.output_all_logits:
+                all_logits.append(logits[:, :last_token_id + 1, :])
 
         # process the leftovers context
-        while current < context_length - 1:
+        while current < context_length:
             # find the optimal "window"
             estimate = None
             if hasattr(self, "window_context_buckets"):
@@ -225,18 +229,25 @@
                 cache_ids = torch.as_tensor([i], dtype=torch.int32)
                 hidden_slice = hidden[:, i:i+1].contiguous()
                 logits = self.decoder_lm_head(hidden_slice, cache_ids, start_ids, last_token_id, *rest)
+                if self.neuron_config.output_all_logits:
+                    all_logits.append(logits)
                 break
 
             hidden_slice = hidden[:, current:current+estimate].contiguous()
             cache_ids = torch.as_tensor([i for i in range(current, current+estimate)], dtype=torch.int32)
-            last_token_id = torch.as_tensor(estimate - 1)
+            last_token_id = torch.as_tensor([estimate - 1])
             if self.neuron_config.log_softmax_scores:
                 logits, scores = self.decoder_lm_head_for_window_context[estimate](hidden_slice, cache_ids, start_ids, last_token_id, *rest)
             else:
                 logits = self.decoder_lm_head_for_window_context[estimate](hidden_slice, cache_ids, start_ids, last_token_id, *rest)
+            if self.neuron_config.output_all_logits:
+                all_logits.append(logits)
 
             current += estimate
 
+        if all_logits:
+            logits = torch.cat(all_logits, dim=1)
+
         if self.is_fid:
             logits[:] = float('-inf')
             logits[self.bos_token_id] = 1.0
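When neuron_config.output_all_logits is enabled, the loop above appends each window's logits and concatenates them along the sequence axis once the leftover context is consumed. A toy, stand-alone illustration of that bookkeeping, with random tensors standing in for decoder outputs:

import torch

vocab, batch = 32, 1
context_length, window = 10, 4
all_logits = []
current = 0
while current < context_length:
    estimate = min(window, context_length - current)      # window chosen per iteration
    window_logits = torch.randn(batch, estimate, vocab)   # stand-in for decoder output
    all_logits.append(window_logits)
    current += estimate

logits = torch.cat(all_logits, dim=1)
print(logits.shape)  # torch.Size([1, 10, 32]): one logit row per input position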
@@ -266,12 +277,16 @@ def _prepare_for_par_ctx_rhs_padding(self, input_ids, cache_ids):
         if self.neuron_config.vectorize_last_token_id:
             last_token_id = torch.zeros(batch_size, dtype=torch.int32)
         else:
-            last_token_id = torch.as_tensor(0, dtype=torch.int32)
+            last_token_id = torch.as_tensor([0], dtype=torch.int32)
         if context_length == 1:
             return input_ids, cache_ids, last_token_id
 
         # TODO: check context_buckets for compatibility with OPT
-        if hasattr(self, "context_buckets"):
+        if cache_ids is not None and cache_ids.flatten()[0].item() > 0:
+            # speculative forward: n_active_tokens > 1 and cache_ids start from position > 0
+            speculation_buckets = list(set([k for k, batch_size in self.decoder_lm_head_for_speculation.keys()]))
+            estimate = bucket.find(speculation_buckets, context_length)
+        elif hasattr(self, "context_buckets"):
             estimate = bucket.find(self.context_buckets, context_length)
         else:
             estimate = self.context_length_estimate
@@ -281,7 +296,7 @@ def _prepare_for_par_ctx_rhs_padding(self, input_ids, cache_ids):
         if self.neuron_config.vectorize_last_token_id:
             last_token_id = cache_ids.max(dim=1).values
         else:
-            last_token_id = torch.as_tensor(min(context_length - 1, estimate-1), dtype=torch.int32)
+            last_token_id = torch.as_tensor([min(context_length - 1, estimate-1)], dtype=torch.int32)
         if context_length < estimate:
             input_ids = utils.pad(input_ids, 1, estimate, left=False)
             cache_ids = self._pad_cache_ids(cache_ids, batch_size, context_length, estimate)
@@ -291,11 +306,15 @@ def _prepare_for_par_ctx_rhs_padding(self, input_ids, cache_ids):
     def _pad_cache_ids(self, cache_ids, batch_size, context_length, estimate):
         if self.neuron_config.use_2d_cache_ids:
             # TODO: fix cache_ids padding for batch speculative decoding
-            cache_ids = torch.arange(estimate, dtype=torch.long)
+            # for now, use cache_ids without change for speculative_forward
+            is_speculative_forward = cache_ids.flatten()[0].item() > 0
+            if is_speculative_forward:
+                return cache_ids
+            cache_ids = torch.arange(estimate, dtype=torch.int32)
             cache_ids = cache_ids.unsqueeze(0).expand(batch_size, estimate)
         else:
             if cache_ids is None:
-                cache_ids = torch.arange(estimate, dtype=torch.long)
+                cache_ids = torch.arange(estimate, dtype=torch.int32)
             else:
                 # Inputs: cache_ids = [16, 17], estimate = 512
                 #
@@ -306,9 +325,9 @@ def _pad_cache_ids(self, cache_ids, batch_size, context_length, estimate):
                 # cache_ids = [16, 17, 18, 19, ..., 511, 511, 511, ..., 511, 511, 511]
                 start_idx = cache_ids[-1].item() + 1
                 end_idx = estimate + start_idx - context_length
-                pad_elements = torch.arange(start_idx, end_idx, dtype=torch.long)
+                pad_elements = torch.arange(start_idx, end_idx, dtype=torch.int32)
                 cache_ids_pad = torch.concat([cache_ids, pad_elements], dim=0)
-                cache_ids = torch.minimum(cache_ids_pad, torch.tensor(estimate-1, dtype=torch.long))
+                cache_ids = torch.minimum(cache_ids_pad, torch.tensor(estimate-1, dtype=torch.int32))
         return cache_ids
 
     def _prepare_for_continuous_batching(self, input_ids, cache_ids=None, seq_ids=None):
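The worked comment above (cache_ids = [16, 17], estimate = 512) describes the 1D padding path. A stand-alone sketch of the same arange/concat/minimum sequence, so the resulting pattern can be printed and inspected; clamping to estimate - 1 keeps the padded positions from running past the last cache slot.

import torch

def pad_cache_ids_1d(cache_ids: torch.Tensor, context_length: int, estimate: int) -> torch.Tensor:
    # Mirrors the non-2D branch above: extend cache_ids up to `estimate` entries,
    # then clamp everything to estimate - 1 so padding writes land on the final slot.
    start_idx = cache_ids[-1].item() + 1
    end_idx = estimate + start_idx - context_length
    pad_elements = torch.arange(start_idx, end_idx, dtype=torch.int32)
    cache_ids_pad = torch.concat([cache_ids, pad_elements], dim=0)
    return torch.minimum(cache_ids_pad, torch.tensor(estimate - 1, dtype=torch.int32))

cache_ids = torch.tensor([16, 17], dtype=torch.int32)
padded = pad_cache_ids_1d(cache_ids, context_length=2, estimate=512)
print(padded[:4], padded[-4:])  # [16, 17, 18, 19] ... [511, 511, 511, 511]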
@@ -350,20 +369,24 @@ def _prepare_for_continuous_batching(self, input_ids, cache_ids=None, seq_ids=No
                 cache_ids = cache_ids.unsqueeze(0)
             assert cache_ids.shape[0] == n_active_seqs, \
                 f"invalid n_active_seqs ({n_active_seqs} vs {cache_ids.shape[0]}) in speculative forward"
-            cache_ids_pad = torch.zeros(n_active_seqs, speculative_n_positions, dtype=cache_ids.dtype, device='cpu')
+            # pad cache IDs with max(n_positions) - 1
+            # unlike context encoding, padding with 0
+            # during speculative_forward will contaminate kv-cache history
+            cache_ids_pad = torch.full((n_active_seqs, speculation_bucket), max(self.context_buckets) - 1, dtype=cache_ids.dtype, device="cpu")
             for seq_id in range(n_active_seqs):
                 cache_ids_pad[seq_id, :n_active_tokens] = cache_ids[seq_id, :n_active_tokens]
             return input_ids, cache_ids_pad, seq_ids
 
         # token generation
-        full_input_ids = torch.zeros(batch_size, 1, dtype=input_ids.dtype, device="cpu")
-        full_cache_ids = torch.zeros(batch_size, 1, dtype=cache_ids.dtype, device="cpu")
+        full_input_ids = torch.zeros(batch_size, 1, dtype=torch.int32)
+        full_cache_ids = torch.zeros(batch_size, 1, dtype=torch.int32)
+        full_seq_ids = torch.arange(batch_size, dtype=torch.int32)
        for idx, seq_id in enumerate(seq_ids.flatten()):
             seq_id = seq_id.item()
             full_input_ids[seq_id, :] = input_ids[idx, :]
             full_cache_ids[seq_id, :] = cache_ids[idx, :]
 
-        return full_input_ids, full_cache_ids, seq_ids
+        return full_input_ids, full_cache_ids, full_seq_ids
 
     def _preprocess(self, input_ids, start_ids=None, cache_ids=None):
         # enable dynamic batch size feature for continuous batching
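The token-generation branch above scatters each incoming row into the dense batch slot given by its sequence ID and now also returns a dense int32 arange of sequence IDs. A minimal stand-alone illustration of that scatter:

import torch

def scatter_by_seq_id(input_ids, cache_ids, seq_ids, batch_size):
    # Place each incoming row at the slot named by its sequence ID, as in the
    # token-generation branch of _prepare_for_continuous_batching above.
    full_input_ids = torch.zeros(batch_size, 1, dtype=torch.int32)
    full_cache_ids = torch.zeros(batch_size, 1, dtype=torch.int32)
    full_seq_ids = torch.arange(batch_size, dtype=torch.int32)
    for idx, seq_id in enumerate(seq_ids.flatten()):
        seq_id = seq_id.item()
        full_input_ids[seq_id, :] = input_ids[idx, :]
        full_cache_ids[seq_id, :] = cache_ids[idx, :]
    return full_input_ids, full_cache_ids, full_seq_ids

# Two active sequences (IDs 2 and 0) inside a running batch of 4:
inp = torch.tensor([[11], [22]], dtype=torch.int32)
cid = torch.tensor([[5], [9]], dtype=torch.int32)
sid = torch.tensor([2, 0], dtype=torch.int32)
print(scatter_by_seq_id(inp, cid, sid, batch_size=4))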
@@ -389,7 +412,8 @@ def _preprocess(self, input_ids, start_ids=None, cache_ids=None):
         return input_ids, cache_ids, start_ids, last_token_id
 
     def _postprocess(self, logits, start_ids=None):
-        if start_ids is None:
+
+        if start_ids is None or (self.neuron_config.output_all_logits and logits.shape[1] > 1):
             return logits
 
         running_batch_size, n_embed = logits.shape
@@ -400,24 +424,25 @@
             return logits
 
         # token generation (aka decoding)
-        seq_ids = start_ids.flatten().tolist()
-        assert input_batch_size == len(seq_ids), f"expected seq_ids to be {input_batch_size} in length, but seq_ids={seq_ids}"
-        new_logits = torch.zeros(input_batch_size, n_embed, dtype=logits.dtype, device=logits.device)
-        for idx, seq_id in enumerate(seq_ids):
-            new_logits[idx, :] = logits[seq_id, :]
+        seq_ids = start_ids.flatten()
+        if torch.equal(seq_ids, torch.arange(input_batch_size)):
+            logits = logits[:input_batch_size]
+        else:
+            logits = logits[seq_ids]
 
-        return new_logits
+        return logits
 
     def _cast_logits(self, logits):
         # Cast logits to float32 or the dtype specified in the neuron config
         logits_dtype = torch.float32
         if self.neuron_config:
-            logits_dtype = getattr(torch, self.neuron_config.cast_logits_dtype)
+            if self.neuron_config.cast_logits_dtype is not None:
+                logits_dtype = getattr(torch, self.neuron_config.cast_logits_dtype)
         return logits.to(logits_dtype)
 
     def _context_dynamic_batching(self, hidden, *args):
         is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH
-        input_batch_size = hidden.shape[0] if is_bsh else hidden.shape[2]
+        input_batch_size = hidden.shape[0] if is_bsh or self.neuron_config.on_device_embedding else hidden.shape[2]
         assert hasattr(self, "context_batch_sizes"), f"{type(self)} doesn't support dynamic batching."
 
         running_batch_size = self.context_batch_sizes[-1]
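The rewritten _postprocess above replaces the per-row copy loop with a single index-select on the flattened seq_ids (with a fast path when they already equal 0..N-1). A quick stand-alone check that the two are equivalent:

import torch

running_batch_size, n_embed = 4, 8
logits = torch.randn(running_batch_size, n_embed)
start_ids = torch.tensor([2, 0, 3])          # three live sequences, out of order
input_batch_size = start_ids.shape[0]

# old behaviour: explicit copy loop
seq_ids = start_ids.flatten().tolist()
old = torch.zeros(input_batch_size, n_embed, dtype=logits.dtype)
for idx, seq_id in enumerate(seq_ids):
    old[idx, :] = logits[seq_id, :]

# new behaviour: direct indexing
new = logits[start_ids.flatten()]

print(torch.equal(old, new))  # True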
@@ -428,20 +453,19 @@
             all_logits = []
             cache_ids, start_ids, last_token_id = args[0], args[1], args[2]
             for iter_id in range(n_iters):
-                # Assuming HSB layout
                 start_idx = iter_id*running_batch_size
                 end_idx = (iter_id+1)*running_batch_size
-                if is_bsh:
-                    hidden_per_batch = hidden[start_idx:end_idx, :, :]
+                if is_bsh or self.neuron_config.on_device_embedding:
+                    hidden_per_batch = hidden[start_idx:end_idx, ...]
                 else:
-                    hidden_per_batch = hidden[:, :, start_idx:end_idx]
+                    hidden_per_batch = hidden[..., start_idx:end_idx]
                 cache_ids_per_batch = cache_ids[start_idx:end_idx, :]
                 start_ids_per_batch = start_ids[start_idx:end_idx]
                 last_token_id_per_batch = last_token_id[start_idx:end_idx]
                 logits_per_batch = self.context(hidden_per_batch, cache_ids_per_batch,
                                                 start_ids_per_batch, last_token_id_per_batch)
                 all_logits.append(logits_per_batch)
-            logits = torch.cat(all_logits, dim=2)
+            logits = torch.cat(all_logits, dim=-1)
         else:
             assert input_batch_size == running_batch_size, \
                 "input batch size ({input_batch_size}) not equal to running batch size ({running_batch_size})"
@@ -464,8 +488,11 @@ def _forward(self, hidden, *args):
             return logits
 
         logits = self._cast_logits(logits)
-        logits = logits[:self.config.vocab_size, -1, :]
-        logits = logits.transpose(0, 1)
+        if self.neuron_config.output_all_logits and context_length > 1:
+            logits = logits.permute(2, 1, 0)
+        else:
+            logits = logits[:self.config.vocab_size, -1, :]
+            logits = logits.transpose(0, 1)
         return logits
 
 
@@ -506,6 +533,10 @@ def profile(self, profile_dir, ntff_count_limit):
             if isinstance(kernel, ParallelKernel):
                 kernel.profile(profile_dir, ntff_count_limit)
 
+    def update_generation_config(self, generation_config: GenerationConfig):
+        self.decoder_lm_head.update_generation_config(generation_config)
+
+
 # Base class for all "Serializable Objects"
 class NeuronBaseSerializer:
 

src/transformers_neuronx/bloom/hlo.py

Lines changed: 9 additions & 19 deletions
@@ -14,23 +14,23 @@
 # ==============================================================================
 from transformers_neuronx import hlo
 from transformers_neuronx.constants import LAYOUT_BSH
-from transformers_neuronx.layers import transformer, alibi, generation
+from transformers_neuronx.layers import transformer, alibi, attention
 from transformers_neuronx.bloom.config import BloomConfig
 
 class BloomForSamplingNoEmbeddingHlo:
 
     def __init__(self, config: BloomConfig, neuron_config=None):
         self.config = config
         self.neuron_config = neuron_config
+        self.n_positions = None
 
-    def inputs(self, scribe, dtype, n_positions, n_active_tokens, batch_size):
-        hidden, cache_ids, start_ids, last_token_id, dims = transformer.inputs(
+    def inputs(self, scribe, dtype, n_active_tokens, batch_size):
+        tensors, dims = transformer.inputs(
             scribe, dtype, batch_size, n_active_tokens, self.config.hidden_size, self.neuron_config
         )
-        mask, active_mask = hlo.attention_mask(cache_ids, start_ids, n_positions)
-        return (hidden, last_token_id, cache_ids, mask, active_mask), dims
+        return tensors, dims
 
-    def embedding(self, input_ids, last_token_id, cache_ids, mask, active_mask, slopes, word_embeddings, ln_weight, ln_bias):
+    def embedding(self, input_ids, cache_ids, start_ids, last_token_id, slopes, word_embeddings, ln_weight, ln_bias):
         dtype = getattr(input_ids.scribe, self.config.amp)
         hidden = hlo.embedding(word_embeddings, input_ids, tp_degree=self.config.tp_degree, dtype=dtype)
         if self.config.hidden_size % self.config.tp_degree != 0:
@@ -41,8 +41,9 @@ def embedding(self, input_ids, last_token_id, cache_ids, mask, active_mask, slop
         return hlo.layer_norm_bsh(hidden, ln_weight, ln_bias) if is_bsh \
             else hlo.layer_norm(hidden, ln_weight, ln_bias)
 
-    def pre_layer(self, hidden, last_token_id, cache_ids, mask, active_mask, *pre_layer_weights):
+    def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *pre_layer_weights):
         slopes, *rest = pre_layer_weights
+        mask, active_mask = hlo.attention_mask(cache_ids, start_ids, self.n_positions)
         prior_alibi, active_alibi = alibi.alibi(slopes, mask, active_mask)
         return hidden, last_token_id, cache_ids, mask, active_mask, prior_alibi, active_alibi
 
@@ -87,14 +88,9 @@ def layer(self, hidden, last_token_id, cache_ids, mask, active_mask, prior_alibi
         hidden = dtype[hidden.sizes].Add(mlp_hidden, hidden)
         return hidden, out_attn_k_cache, out_attn_v_cache
 
-    def ln_lm_head(self, hidden, last_token_id, ln_f_weight, ln_f_bias, lm_head_weight, lm_head_bias, logits_indices, return_all_outputs=True):
+    def ln_lm_head(self, hidden, last_token_id, ln_f_weight, ln_f_bias, lm_head_weight, lm_head_bias, return_all_outputs=True):
         logits = transformer.ln_lm_head(self.config.tp_degree, hidden, last_token_id, ln_f_weight, ln_f_bias, lm_head_weight,
                                         lm_head_bias, return_all_outputs, neuron_config=self.neuron_config)
-        if self.neuron_config.on_device_generation is not None:
-            return generation.generate(logits, logits_indices,
-                                       config=self.neuron_config.on_device_generation,
-                                       tp_degree=self.config.tp_degree,
-                                       eos_token_id=self.config.eos_token_id)
         return logits
 
     def attention(self,
@@ -111,12 +107,6 @@ def attention(self,
         dtype = hidden.dtype
         d_head = self.config.hidden_size // self.config.n_head
 
-        is_bsh = neuron_config and neuron_config.attention_layout == LAYOUT_BSH
-        if is_bsh:
-            import transformers_neuronx.layers.attention as attention
-        else:
-            import transformers_neuronx.layers.attention_hsb as attention
-
         # Q = (hidden @ wQ) + bQ
         # K = (hidden @ wK) + bK
         # V = (hidden @ wV) + bV
