@@ -31,7 +31,10 @@ def __init__(self, quantization_args: QuantizationArgs):
 
     @torch.no_grad()
     def forward(
-        self, observed: Tensor, g_idx: Optional[Tensor] = None
+        self,
+        observed: Tensor,
+        g_idx: Optional[Tensor] = None,
+        base_name: Optional[str] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
@@ -40,8 +43,9 @@ def forward(
         :param g_idx: optional mapping from column index to group index
         :return: tuple of scale and zero point based on last observed value
         """
+        # breakpoint()
         self.record_observed_tokens(observed)
-        return self.get_qparams(observed=observed, g_idx=g_idx)
+        return self.get_qparams(observed=observed, g_idx=g_idx, base_name=base_name)
 
     def calculate_qparams(
         self,
@@ -66,6 +70,7 @@ def get_qparams(
         self,
         observed: Optional[Tensor] = None,
         g_idx: Optional[Tensor] = None,
+        base_name: Optional[str] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Convenience function to wrap overwritten calculate_qparams
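For orientation, a minimal usage sketch of the new signature. The observer subclass name, tensor shape, and values below are assumptions for illustration, not taken from this PR; the point is only that the `base_name` passed to `forward` is threaded straight through to `get_qparams`.

    # Hypothetical usage sketch; MinMaxObserver, shapes, and values are assumptions.
    import torch

    observer = MinMaxObserver(quantization_args)  # assumed concrete Observer subclass
    k_out = torch.randn(1, 1, 8 * 128)            # e.g. [batch, seq, num_key_value_heads * head_dim]

    # base_name lets get_qparams distinguish output activations (quantized along the
    # last/hidden dim) from weights and inputs (quantized along dim 0).
    scale, zero_point = observer(k_out, base_name="output")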
@@ -123,26 +128,24 @@ def get_qparams(
                     self._zero_point[:, group_index] = zero_point.squeeze(1)
 
             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-                # assume observed is transposed, because its the output, hence use dim 0
-                # we pass in [1, 8, 2048, 128] for k_states
-                # normally per channel: (output_dim, 1) and you have as many scales as the output_dim
-                # we want 8 - num_k_head_scales? or
-                # breakpoint()
-
-                # weight --> get scales along the first dimension (output dim is first dim)
-                # weight shape (output_dim, input_dim)
-                # self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
-                # output when applied to the weight: (output_dim, 1)
-
-
-                # for outputs:
-                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 2)
-                self._scale = self._scale.squeeze(1)
-                self._zero_point = self._zero_point.squeeze(1)
-                # why is the output of self._scale: [1, 1, 1]
-
-
-
+                if base_name == "output":
+                    # the last dimension is the hidden dimension,
+                    # shape of [1, 1, num_key_value_heads * head_dim]
+                    scale, zero_point = self.get_qparams_along_dim(
+                        observed, observed.ndim - 1
+                    )
+                    self._scale = (
+                        scale.squeeze()
+                    )  # shape of [num_key_value_heads * head_dim]
+                    self._zero_point = (
+                        zero_point.squeeze()
+                    )  # shape of [num_key_value_heads * head_dim]
+                else:
+                    # weight or input: per-channel scales along dim 0
+                    # (for a weight of shape [output_dim, input_dim], dim 0 is output_dim)
+                    self._scale, self._zero_point = self.get_qparams_along_dim(
+                        observed, 0
+                    )
 
             elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
                 # use dim 1, assume observed.shape = [batch, token, hidden]
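For reference, a self-contained sketch of the per-channel shape behavior the CHANNEL branch above relies on: one scale/zero point per index of the kept dimension, with every other dimension reduced. The helper `qparams_along_dim` and the asymmetric int8 math are simplified assumptions for illustration, not the library's `get_qparams_along_dim`/`calculate_qparams`.

    # Simplified stand-in for get_qparams_along_dim; asymmetric int8 math is assumed.
    import torch

    def qparams_along_dim(observed: torch.Tensor, dim: int, qmin: int = -128, qmax: int = 127):
        # reduce over every dimension except `dim`, keeping one qparam per index of `dim`
        reduce_dims = tuple(i for i in range(observed.ndim) if i != dim)
        min_vals = torch.amin(observed, dim=reduce_dims, keepdim=True)
        max_vals = torch.amax(observed, dim=reduce_dims, keepdim=True)
        scale = (max_vals - min_vals) / float(qmax - qmin)
        zero_point = torch.round(qmin - min_vals / scale).clamp(qmin, qmax).to(torch.int8)
        return scale, zero_point

    # output activation, shape [1, 1, num_key_value_heads * head_dim]: keep the last dim
    out = torch.randn(1, 1, 8 * 128)
    scale, zp = qparams_along_dim(out, out.ndim - 1)
    print(scale.squeeze().shape)  # torch.Size([1024]), one scale per hidden channel

    # weight, shape [output_dim, input_dim]: keep dim 0, one scale per output channel
    w = torch.randn(512, 1024)
    scale, zp = qparams_along_dim(w, 0)
    print(scale.shape)            # torch.Size([512, 1])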