Skip to content

Commit b0480d2

Browse files
switch to F.mse_loss()
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 6f951b7 commit b0480d2

File tree

1 file changed

+4
-15
lines changed
  • src/llmcompressor/modifiers/awq

1 file changed

+4
-15
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
from compressed_tensors.quantization import disable_quantization
6+
import torch.nn.functional as F
67
from compressed_tensors.utils import (
78
align_modules,
89
get_execution_device,
@@ -579,9 +580,9 @@ def _compute_best_scale(
579580
x_mean = x_mean.view(-1).to(device)
580581
w_mean = w_mean.view(-1).to(device)
581582

582-
for ratio in range(n_grid):
583+
for grid_idx in range(n_grid):
583584
# create new scales
584-
ratio = ratio / n_grid
585+
ratio = grid_idx / n_grid
585586

586587
# NOTE: s^-1 * x is fused here, according to paper
587588
if self.duo_scaling:
@@ -616,7 +617,7 @@ def _compute_best_scale(
616617
int_w_outputs = self._run_samples(parent_module)
617618

618619
# compute mean squared error (L2 norm)
619-
loss = _compute_loss(fp16_output, int_w_output)
620+
loss = F.mse_loss(int_w_output, fp16_output).item()
620621

621622
history.append(loss)
622623
if loss < best_error:
@@ -650,18 +651,6 @@ def _assert_all_activations_consumed(self):
650651
raise RuntimeError("Some cached activations were not used")
651652

652653

653-
@torch.no_grad()
654-
@torch.compile()
655-
def _compute_loss(
656-
fp16_output: torch.Tensor,
657-
int_w_output: torch.Tensor,
658-
) -> torch.Tensor:
659-
"""
660-
Compute MSE loss over the flattened output of all batches
661-
"""
662-
return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
663-
664-
665654
@torch.compile()
666655
def _pseudo_quantize_tensor(
667656
w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1

0 commit comments

Comments (0)