|
3 | 3 |
|
4 | 4 | import torch
|
5 | 5 | from compressed_tensors.quantization import disable_quantization
|
| 6 | +import torch.nn.functional as F |
6 | 7 | from compressed_tensors.utils import (
|
7 | 8 | align_modules,
|
8 | 9 | get_execution_device,
|
@@ -579,9 +580,9 @@ def _compute_best_scale(
|
579 | 580 | x_mean = x_mean.view(-1).to(device)
|
580 | 581 | w_mean = w_mean.view(-1).to(device)
|
581 | 582 |
|
582 |
| - for ratio in range(n_grid): |
| 583 | + for grid_idx in range(n_grid): |
583 | 584 | # create new scales
|
584 |
| - ratio = ratio / n_grid |
| 585 | + ratio = grid_idx / n_grid |
585 | 586 |
|
586 | 587 | # NOTE: s^-1 * x is fused here, according to paper
|
587 | 588 | if self.duo_scaling:
|
@@ -616,7 +617,7 @@ def _compute_best_scale(
|
616 | 617 | int_w_outputs = self._run_samples(parent_module)
|
617 | 618 |
|
618 | 619 | # compute mean squared error (L2 norm)
|
619 |
| - loss = _compute_loss(fp16_output, int_w_output) |
| 620 | + loss = F.mse_loss(int_w_output, fp16_output).item() |
620 | 621 |
|
621 | 622 | history.append(loss)
|
622 | 623 | if loss < best_error:
|
@@ -650,18 +651,6 @@ def _assert_all_activations_consumed(self):
|
650 | 651 | raise RuntimeError("Some cached activations were not used")
|
651 | 652 |
|
652 | 653 |
|
653 |
| -@torch.no_grad() |
654 |
| -@torch.compile() |
655 |
| -def _compute_loss( |
656 |
| - fp16_output: torch.Tensor, |
657 |
| - int_w_output: torch.Tensor, |
658 |
| -) -> torch.Tensor: |
659 |
| - """ |
660 |
| - Compute MSE loss over the flattened output of all batches |
661 |
| - """ |
662 |
| - return (fp16_output - int_w_output).view(-1).float().pow(2).mean() |
663 |
| - |
664 |
| - |
665 | 654 | @torch.compile()
|
666 | 655 | def _pseudo_quantize_tensor(
|
667 | 656 | w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
|
|
0 commit comments