
Commit c2a2016

Author: George Ohashi (committed)
Channel-wise FP8 quantization, attention modules
Signed-off-by: George Ohashi <[email protected]>
1 parent 76fc03d commit c2a2016

File tree: 3 files changed (+55, -48 lines)


examples/quantization_kv_cache/llama3_fp8_kv_example.py

Lines changed: 27 additions & 25 deletions

@@ -58,33 +58,35 @@ def process_and_tokenize(example):
                         strategy: channel
                         dynamic: false
                         symmetric: true
-                    targets: ['re:.*q_proj', 're:.*k_proj', 're:.*v_proj']
+                    # targets: ['re:.*q_proj', 're:.*k_proj', 're:.*v_proj']
+                    targets: ['re:.*q_proj',]
+
 """
-recipe = """
-quant_stage:
-    quant_modifiers:
-        QuantizationModifier:
-            config_groups:
-                fp8_attention_q_proj:
-                    output_activations:
-                        num_bits: 8
-                        type: float
-                        strategy: group
-                        group_size: 512
-                        dynamic: false
-                        symmetric: true
-                    targets: ['re:.*q_proj']
-                fp8_attention_kv_proj:
-                    output_activations:
-                        num_bits: 8
-                        type: float
-                        strategy: group
-                        group_size: 128
-                        dynamic: false
-                        symmetric: true
-                    targets: ['re:.*k_proj', 're:.*v_proj']
+# recipe = """
+# quant_stage:
+#     quant_modifiers:
+#         QuantizationModifier:
+#             config_groups:
+#                 fp8_attention_q_proj:
+#                     output_activations:
+#                         num_bits: 8
+#                         type: float
+#                         strategy: channel
+#                         # group_size: 512
+#                         dynamic: false
+#                         symmetric: true
+#                     targets: ['re:.*q_proj']
+#                 # fp8_attention_kv_proj:
+#                 #     output_activations:
+#                 #         num_bits: 8
+#                 #         type: float
+#                 #         strategy: group
+#                 #         group_size: 128
+#                 #         dynamic: false
+#                 #         symmetric: true
+#                 #     targets: ['re:.*k_proj', 're:.*v_proj']

-"""
+# """

 # Apply algorithms.
 oneshot(
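For orientation, below is a minimal sketch of how a channel-strategy recipe like the one kept above is typically fed to llm-compressor's oneshot entrypoint. Only the recipe structure mirrors this commit; the import path, model id, dataset name, and calibration arguments are illustrative assumptions, not part of the diff.

# Hypothetical usage sketch; import path and arguments may differ by llm-compressor version.
from llmcompressor.transformers import oneshot  # assumed import path

recipe = """
quant_stage:
    quant_modifiers:
        QuantizationModifier:
            config_groups:
                fp8_attention_q_proj:
                    output_activations:
                        num_bits: 8
                        type: float
                        strategy: channel
                        dynamic: false
                        symmetric: true
                    targets: ['re:.*q_proj',]
"""

oneshot(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # assumed model id
    dataset="open_platypus",                      # assumed calibration dataset
    recipe=recipe,
    max_seq_length=2048,                          # assumed calibration settings
    num_calibration_samples=512,
)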

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 3 additions & 1 deletion

@@ -81,7 +81,9 @@ def call_observer(module: Module, base_name: str, value: Optional[torch.Tensor]
         raise ValueError("Must provide a value to observe if not using weight observer")

     observer = getattr(module, f"{base_name}_observer")
-    updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
+    updated_scale, updated_zero_point = observer(
+        value, g_idx=g_idx, base_name=base_name
+    )

     # update scale and zero point
     update_parameter_data(module, updated_scale, f"{base_name}_scale")
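A toy sketch of the calling convention this change relies on: the calibration hook resolves "<base_name>_observer" on the module and now also forwards base_name (typically "weight", "input", or "output") so the observer can tell outputs apart from weights and inputs. The stub observer below is an illustration, not the library's real observer classes.

# Toy illustration of the updated call path; StubObserver stands in for the
# library's observers and only demonstrates the new keyword.
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Module


class StubObserver(Module):
    def forward(
        self,
        observed: Tensor,
        g_idx: Optional[Tensor] = None,
        base_name: Optional[str] = None,
    ) -> Tuple[Tensor, Tensor]:
        # outputs are [batch, seq, hidden] -> reduce to the last dim;
        # weights are [out_dim, in_dim] -> reduce to dim 0
        dim = observed.ndim - 1 if base_name == "output" else 0
        reduce_dims = tuple(d for d in range(observed.ndim) if d != dim)
        scale = observed.abs().amax(dim=reduce_dims)
        return scale, torch.zeros_like(scale, dtype=torch.int8)  # symmetric


def toy_call_observer(module: Module, base_name: str, value: Tensor):
    # mirrors the diff: look up "<base_name>_observer" and forward base_name
    observer = getattr(module, f"{base_name}_observer")
    return observer(value, g_idx=None, base_name=base_name)


attn = Module()
attn.output_observer = StubObserver()
scale, zero_point = toy_call_observer(attn, "output", torch.randn(1, 2048, 1024))
print(scale.shape)  # torch.Size([1024]) -- one scale per output channel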

src/llmcompressor/observers/base.py

Lines changed: 25 additions & 22 deletions

@@ -31,7 +31,10 @@ def __init__(self, quantization_args: QuantizationArgs):

     @torch.no_grad()
     def forward(
-        self, observed: Tensor, g_idx: Optional[Tensor] = None
+        self,
+        observed: Tensor,
+        g_idx: Optional[Tensor] = None,
+        base_name: Optional[str] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
@@ -40,8 +43,9 @@ def forward(
         :param g_idx: optional mapping from column index to group index
         :return: tuple of scale and zero point based on last observed value
         """
+        # breakpoint()
         self.record_observed_tokens(observed)
-        return self.get_qparams(observed=observed, g_idx=g_idx)
+        return self.get_qparams(observed=observed, g_idx=g_idx, base_name=base_name)

     def calculate_qparams(
         self,
@@ -66,6 +70,7 @@ def get_qparams(
         self,
         observed: Optional[Tensor] = None,
         g_idx: Optional[Tensor] = None,
+        base_name: Optional[str] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Convenience function to wrap overwritten calculate_qparams
@@ -123,26 +128,24 @@ def get_qparams(
                     self._zero_point[:, group_index] = zero_point.squeeze(1)

             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-                # assume observed is transposed, because its the output, hence use dim 0
-                # we pass in [1, 8, 2048, 128] for k_states
-                # normally per channel: (output_dim, 1) and you have as many scales as the output_dim
-                # we want 8 - num_k_head_scales? or
-                #breakpoint()
-
-                # weight --> get scales along the first dimension (output dim is first dim)
-                # weight shape (output_dim, input_dim)
-                # self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
-                # output when applied to the weight: (output_dim, 1)
-
-
-                # for outputs:
-                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 2)
-                self._scale = self._scale.squeeze(1)
-                self._zero_point = self._zero_point.squeeze(1)
-                # why is the output of self._scale: [1, 1, 1]
-
-
-
+                if base_name == "output":
+                    # the last dimension is the hidden dimension
+                    # shape of [1, 1, num_key_value_heads * head_dim]
+                    scale, zero_point = self.get_qparams_along_dim(
+                        observed, observed.ndim - 1
+                    )
+                    self._scale = (
+                        scale.squeeze()
+                    )  # shape of [num_key_value_heads * head_dim]
+                    self._zero_point = (
+                        zero_point.squeeze()
+                    )  # shape of [num_key_value_heads * head_dim]
+                else:
+                    # weight or input
+                    # assume observed is transposed, because its the output, hence use dim 0
+                    self._scale, self._zero_point = self.get_qparams_along_dim(
+                        observed, 0
+                    )

             elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
                 # use dim 1, assume the obsersed.shape = [batch, token, hidden]
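A short worked example of the shape logic in the new output branch: for a k_proj output of shape [batch, seq_len, num_key_value_heads * head_dim], reducing over every axis except the last yields one scale per hidden channel. This is a min/max sketch under an assumed symmetric FP8 (e4m3) scheme; the library's actual qparam math lives in calculate_qparams and may differ.

# Worked shape example -- assumes symmetric FP8 (e4m3) scales; mirrors only the
# dimension choice above, not the library's exact calculate_qparams logic.
import torch


def channel_scales_for_output(observed: torch.Tensor) -> torch.Tensor:
    # observed: output activations, e.g. [batch, seq_len, num_key_value_heads * head_dim]
    reduce_dims = tuple(range(observed.ndim - 1))   # every dim except the last
    max_abs = observed.abs().amax(dim=reduce_dims)  # shape: [hidden_dim]
    return max_abs / torch.finfo(torch.float8_e4m3fn).max  # one scale per channel


k_out = torch.randn(1, 2048, 8 * 128)  # 8 kv heads * head_dim 128
print(channel_scales_for_output(k_out).shape)  # torch.Size([1024])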
