from typing import Optional

from transformers_neuronx import constants
from transformers_neuronx import hlo
from transformers_neuronx import utils
from transformers_neuronx.config import NeuronConfig
from transformers_neuronx.constants import LAYOUT_HSB
from transformers_neuronx.hlo import mlp
from transformers_neuronx.layers import transformer, rotary, attention, attention_utils, flash_decoding
from transformers_neuronx.starcoder2.config import Starcoder2Config


class Starcoder2ForSamplingNoEmbeddingHlo:
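    """HLO graph builder for Starcoder2 token sampling on Neuron.

    Each method emits one piece of the decoder program (inputs, embedding, pre-layer
    rotary/mask setup, the per-layer attention + MLP block, and the final LM head);
    layout, GQA sharding and flash-decoding behaviour are driven by the supplied
    NeuronConfig.
    """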

    def __init__(self,
                 config: Starcoder2Config,
                 neuron_config: Optional[NeuronConfig] = None
                 ):
        self.config = config
        self.neuron_config = neuron_config
        self.n_positions = None

    @property
    def shard_over_batch(self):
        # Property access allows fallback configuration to be enabled after construction
        return (
            self.neuron_config is not None
            and self.neuron_config.group_query_attention == constants.GQA.SHARD_OVER_BATCH
        )
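        # Note: because this is a property rather than a value captured in __init__, a
        # fallback configuration pass can switch neuron_config.group_query_attention
        # (e.g. to constants.GQA.SHARD_OVER_BATCH) after construction, and later HLO
        # builds pick the change up.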

    def inputs(self, scribe, dtype, n_active_tokens, batch_size):
        tensors, dims = transformer.inputs(
            scribe, dtype, batch_size, n_active_tokens, self.config.hidden_size, self.neuron_config)

        return tensors, dims

    def embedding(self, input_ids, cache_ids, start_ids, last_token_id, embed_weight):
        dtype = getattr(input_ids.scribe, self.config.amp)
        hidden = hlo.embedding(embed_weight, input_ids, tp_degree=self.config.tp_degree, dtype=dtype)
        if self.config.hidden_size % self.config.tp_degree != 0:
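            # Assumption: the embedding weight is padded so the hidden dimension shards
            # evenly across tp_degree cores; slice back down to the true hidden size.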
            hidden = hlo.slice_along(hidden, dim=-1, limit=self.config.hidden_size, start=0)
        if self.neuron_config.attention_layout == LAYOUT_HSB:
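            # transpose210 reverses the axes, (batch, seq, hidden) -> (hidden, seq, batch),
            # to match the HSB attention layout.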
            hidden = hlo.transpose210(hidden)
        return hidden

    def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *weights):
        head_dim = self.config.attention_head_size
        pos_embed = rotary.hlo_rotary_embedding(
            hidden.dtype, int(head_dim * self.config.rotary_percentage), cache_ids,
            base=self.config.rope_theta,
            interpolation_factor=self.config.position_interpolation_factor
        )
        mask, active_mask = hlo.attention_mask(cache_ids, start_ids, self.n_positions)
        core_id = None
        if self.neuron_config.shard_over_sequence:
            core_id, *rst = weights
            n_kv_heads = self.config.num_key_value_heads if self.config.num_key_value_heads else self.config.num_attention_heads
            cores_per_kv_head = self.config.tp_degree // n_kv_heads
            self.cores_per_kv_head = cores_per_kv_head if cores_per_kv_head > 1 else self.config.tp_degree
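            # Flash decoding shards the KV cache along the sequence dimension across the
            # cores that share a KV head. Illustrative numbers: tp_degree=32 with
            # n_kv_heads=4 gives cores_per_kv_head=8, so each KV head's cache is split
            # across 8 cores and the cache ids / attention masks are remapped below
            # accordingly.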
            cache_ids, mask, active_mask = flash_decoding.convert_attn_mask_and_cache_id(
                cache_ids, core_id, self.n_positions, cores_per_kv_head=self.cores_per_kv_head)

        return hidden, last_token_id, pos_embed, cache_ids, start_ids, mask, active_mask, core_id

    def layer(
            self, hidden, last_token_id, pos_embed, cache_ids, start_ids, mask, active_mask, core_id,
            attn_k_cache, attn_v_cache,
            pre_attn_ln_weight, pre_attn_ln_bias,
            attn_q_weight, attn_q_scales, attn_q_bias,
            attn_k_weight, attn_k_scales, attn_k_bias,
            attn_v_weight, attn_v_scales, attn_v_bias,
            attn_out_weight, attn_out_scales, attn_out_bias,
            post_attn_ln_weight, post_attn_ln_bias,
            pre_mlp_ln_weight, pre_mlp_ln_bias,
            mlp_in_weight, mlp_in_scales, mlp_in_bias,
            mlp_out_weight, mlp_out_scales, mlp_out_bias,
            post_mlp_ln_weight, post_mlp_ln_bias,
    ):
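        """Emit one Starcoder2 decoder layer.

        Structure: LayerNorm -> self-attention -> residual add, then
        LayerNorm -> MLP (gelu_new activation) -> residual add. Returns the new hidden
        state together with the updated key/value caches for this layer.
        """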
        # eps = self.config.rms_norm_eps
        # is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH
        ln_hidden = hlo.layer_norm(hidden, pre_attn_ln_weight, pre_attn_ln_bias)

        attn_output, out_attn_k_cache, out_attn_v_cache = self.attention(
            ln_hidden, cache_ids, start_ids, pos_embed, mask, active_mask, core_id,
            attn_k_cache, attn_v_cache,
            attn_q_weight, attn_q_scales, attn_q_bias,
            attn_k_weight, attn_k_scales, attn_k_bias,
            attn_v_weight, attn_v_scales, attn_v_bias,
            attn_out_weight, attn_out_scales, attn_out_bias
        )
        hidden = hlo.add(attn_output, hidden)

        norm_hidden = hlo.layer_norm(hidden, pre_mlp_ln_weight, pre_mlp_ln_bias)
        mlp_hidden = mlp(
            norm_hidden,
            mlp_in_weight, mlp_in_bias, mlp_out_weight, mlp_out_bias,
            activation_function='gelu_new',  # 'gelu_pytorch_tanh'
            tp_degree=self.config.tp_degree,
            neuron_config=self.neuron_config
        )
        res_hidden = hlo.add(mlp_hidden, hidden)
        return res_hidden, out_attn_k_cache, out_attn_v_cache

    def ln_lm_head(self, hidden, last_token_id, rms_weight, unused_bias, lm_head_weight, lm_head_bias,
                   return_all_outputs=True):
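        """Apply the final norm and the LM head projection to produce logits."""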
        logits = transformer.rms_lm_head(self.config.tp_degree, hidden, last_token_id, rms_weight, lm_head_weight,
                                         lm_head_bias, return_all_outputs, eps=self.config.rms_norm_eps,
                                         neuron_config=self.neuron_config)
        return logits

    def attention(
            self,
            hidden, cache_ids, start_ids, pos_embed, mask, active_mask, core_id,
            cached_keys, cached_values,
            q_weight, q_scales, q_bias,
            k_weight, k_scales, k_bias,
            v_weight, v_scales, v_bias,
            out_weight, out_scales, out_bias,
    ):
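        """Self-attention with rotary embeddings over an explicitly managed KV cache.

        Token generation (active_mask is not None) combines scores against the cached
        K/V with scores against the newly projected K/V; context encoding runs either
        flash attention or a masked softmax over the freshly computed K/V. Both paths
        return the attention output plus the updated key/value caches.
        """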
        d_head = self.config.attention_head_size
        tp_degree = self.config.tp_degree

        # Compute the expected number of KV heads per core (needed when fused QKV is used)
        n_kv_heads_tp = None
        if self.config.num_key_value_heads is not None:
            n_head = self.config.num_attention_heads
            n_kv_head = self.config.num_key_value_heads
            _, n_kv_head_padded = utils.get_qkv_padding(n_head, n_kv_head, tp_degree, self.neuron_config)
            n_kv_heads_tp = n_kv_head_padded // tp_degree
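            # Illustrative example (assuming get_qkv_padding pads the KV heads up to a
            # multiple of tp_degree when there are fewer KV heads than cores):
            # n_kv_head=4 with tp_degree=8 would be padded to 8, giving
            # n_kv_heads_tp = 8 // 8 = 1 KV head per core alongside its query heads.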

        # Q = (hidden @ wQ) + bQ
        # K = (hidden @ wK) + bK
        # V = (hidden @ wV) + bV
        query, key, value = attention.query_key_value(
            hidden,
            q_weight, q_scales, q_bias,
            k_weight, k_scales, k_bias,
            v_weight, v_scales, v_bias,
            d_head,
            neuron_config=self.neuron_config,
            tp_degree=tp_degree,  # TODO: include tp_degree into neuron_config
            shard_over_batch=self.shard_over_batch,
            n_kv_heads_tp=n_kv_heads_tp,
        )

        # Q = Rotate(Q)
        # K = Rotate(K)
        query, key = rotary.rotate_half(query, key, pos_embed, self.config.rotary_percentage,
                                        tp_degree=tp_degree, shard_over_batch=self.shard_over_batch)

        # Q = Q / sqrt(d_head)
        query = attention.scale(query, d_head)

        # In the BSH cache layout the QKV linear projection output is still SBH for all of
        # Q/K/V, so transpose them here before interacting with the cache.
        bsh_cache_layout = False
        batch_dim = 1
        if self.neuron_config is not None:
            bsh_cache_layout = self.neuron_config.cache_layout == constants.LAYOUT_BSH
        if bsh_cache_layout:
            query, key, value = attention_utils.transpose_qkv(query, key, value)
            batch_dim = 0

        # Single Token Generation ("Prefetch"-style) and speculative forward
        if active_mask is not None:

            n_active_tokens = key.sizes[1] if bsh_cache_layout else key.sizes[0]
            if n_active_tokens > 1 and self.neuron_config and self.neuron_config.continuous_batching:
                # For speculative forward + continuous batching, slice out the samples in the
                # batch that correspond to the batch size of the speculative head
                slice_sizes = [1] * len(cached_keys.sizes)
                if cached_keys.sizes[batch_dim] == 1:
                    # Use hlo.select for batch size 1 as index select is prohibitively slow
                    # TODO: revert to hlo.index_select once it is faster P126527643
                    cached_keys_s = hlo.select(cached_keys, batch_dim, hlo.reshape(start_ids, slice_sizes),
                                               keepdim=True)
                    cached_values_s = hlo.select(cached_values, batch_dim, hlo.reshape(start_ids, slice_sizes),
                                                 keepdim=True)
                else:
                    cached_keys_s = hlo.index_select(cached_keys, batch_dim, start_ids)
                    cached_values_s = hlo.index_select(cached_values, batch_dim, start_ids)
            else:
                cached_keys_s = cached_keys
                cached_values_s = cached_values
            # Communication 1: all-gather query from cores
            if (n_active_tokens != self.n_positions) and self.neuron_config.shard_over_sequence:
                query = flash_decoding.gather_query_group(query, self.cores_per_kv_head,
                                                          self.config.num_attention_heads,
                                                          tp_degree)

            # Sp = Q @ Kp
            prior_scores = attention.score(query, cached_keys_s, n_kv_heads=self.config.num_key_value_heads,
                                           tp_degree=tp_degree, neuron_config=self.neuron_config)
            prior_scores = attention.mask(prior_scores, mask, tp_degree=tp_degree,
                                          shard_over_batch=self.shard_over_batch)

            # Sa = Q @ Ka
            active_score = attention.score(query, key, n_kv_heads=self.config.num_key_value_heads,
                                           tp_degree=tp_degree, neuron_config=self.neuron_config)
            active_score = attention.mask(active_score, active_mask, tp_degree=tp_degree,
                                          shard_over_batch=self.shard_over_batch)

            # C = softmax(Sa, Sp) @ (Va, Vp)
            if self.neuron_config.shard_over_sequence:
                dtype = query.dtype
                context = flash_decoding.context(prior_scores, active_score, cached_values_s, value, core_id,
                                                 mask, active_mask,
                                                 n_kv_heads=self.config.num_key_value_heads,
                                                 n_heads=self.config.num_attention_heads, dtype=dtype,
                                                 tp_degree=tp_degree, neuron_config=self.neuron_config,
                                                 shard_over_batch=self.shard_over_batch)
                cache_ids, value, key = flash_decoding.select_values_within_bound(cache_ids, value, key,
                                                                                  self.cores_per_kv_head,
                                                                                  core_id, dim=0)
            else:
                context = attention.context(prior_scores, active_score, cached_values_s, value,
                                            n_kv_heads=self.config.num_key_value_heads, tp_degree=tp_degree,
                                            neuron_config=self.neuron_config)

            # KCache[I], VCache[I] = K, V
            updated_keys, updated_values = attention.fused_kv_update_cache(cached_keys, cached_values, cache_ids,
                                                                           key, value, start_ids,
                                                                           neuron_config=self.neuron_config)

        # Multi-Token Context Encoding
        else:
            _, batch_size, _, _ = query.sizes
            if self.neuron_config.lhs_aligned or batch_size == 1:
                context = attention.flash_attention(query, key, value)
            else:
                # Do not use flash attention for the lhs-padded (right aligned) batch > 1 case
                # because it does not correctly take the mask into account
                context = None

            if context is None:
                # S = Q @ K
                score = attention.score(query, key, n_kv_heads=self.config.num_key_value_heads,
                                        tp_degree=tp_degree, neuron_config=self.neuron_config)
                score = attention.mask(score, mask, tp_degree=tp_degree, shard_over_batch=self.shard_over_batch)
                context = attention.context_combined(score, value, n_kv_heads=self.config.num_key_value_heads,
                                                     tp_degree=tp_degree, neuron_config=self.neuron_config)

            if self.neuron_config.shard_over_sequence:
                cache_ids, value, key = flash_decoding.select_values_within_bound(cache_ids,
                                                                                  value,
                                                                                  key,
                                                                                  self.cores_per_kv_head,
                                                                                  core_id, dim=0)
            # KCache, VCache = K, V
            if cached_keys.sizes == key.sizes:
                updated_keys, updated_values = key, value
            else:
                updated_keys, updated_values = attention.fused_kv_update_cache(cached_keys, cached_values, cache_ids,
                                                                               key, value, start_ids,
                                                                               neuron_config=self.neuron_config)

        # O = (C @ wO) + bO
        output = attention.output(context, out_weight, out_scales, out_bias, tp_degree, self.neuron_config)
        return output, updated_keys, updated_values
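

# Usage sketch (illustrative; assumes the usual transformers_neuronx pattern in which a
# companion model class constructs this builder and hands it to the decoder program;
# those call sites are not part of this file):
#
#   config = Starcoder2Config(...)
#   neuron_config = NeuronConfig(...)
#   hlo_builder = Starcoder2ForSamplingNoEmbeddingHlo(config, neuron_config)
#   # The decoder then calls hlo_builder.inputs / embedding / pre_layer / layer /
#   # ln_lm_head while tracing the HLO program.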