
Commit b2d49bd

Align Int4Tensor implementation details with the design of Float8Tensor
Summary:
Int4Tensor is the non-preshuffled version of the int4 quantized tensor: data has shape [N, K/2], and scale/zero_point have shape [K/group_size, N].
This change applies multiple fixes to Int4Tensor to align it with the design of Float8Tensor (only calling fbgemm ops):
* Added VERSION 2 for Int4WeightOnlyConfig
* Migrated op implementation and tests from #2387

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Reviewers:
Subscribers:
Tasks:
Tags:

stack-info: PR: #2687, branch: jerryzh168/stack/16
1 parent 89b3fac commit b2d49bd
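As a quick illustration of the tensor layout named in the summary, the sketch below spells out the shape relationships only; the variable names (qdata, scale, zero_point) and the concrete sizes are assumptions for illustration, not taken from the Int4Tensor implementation itself.

import torch

# Shape walk-through of the layout described above, under assumed sizes:
# a [N, K] weight packed to int4 stores two values per int8 byte -> [N, K/2],
# while scale and zero_point are kept per group along K -> [K/group_size, N].
N, K, group_size = 256, 128, 64
qdata = torch.zeros(N, K // 2, dtype=torch.int8)              # packed int4 pairs
scale = torch.zeros(K // group_size, N, dtype=torch.bfloat16)
zero_point = torch.zeros(K // group_size, N, dtype=torch.bfloat16)
assert qdata.shape == (256, 64) and scale.shape == (2, 256)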

File tree

7 files changed: +664 -366 lines changed


test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 5 additions & 218 deletions
@@ -10,15 +10,11 @@
 from typing import Tuple
 
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
-    TestCase,
     run_tests,
 )
 
-from torchao.prototype.moe_quant.utils import MoEQuantConfig
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
@@ -28,6 +24,7 @@
 )
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.utils import compute_error
+from torchao.testing.utils import TorchAOIntegrationTestCase
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
     _is_fbgemm_genai_gpu_available,
@@ -39,66 +36,6 @@
 torch._dynamo.config.cache_size_limit = 128
 
 
-class Experts(nn.Module):
-    def __init__(
-        self,
-        num_local_experts: int,
-        dim: int,
-        hidden_dim: int,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> None:
-        super().__init__()
-
-        self.num_local_experts = num_local_experts
-        self.dim = dim
-
-        self.w1: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                dim,
-                hidden_dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-        self.w2: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                hidden_dim,
-                dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-        self.w3: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                dim,
-                hidden_dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-    def forward(
-        self,
-        routed_in_egD: torch.Tensor,  # noqa: N803
-    ) -> torch.Tensor:
-        e = self.num_local_experts
-        D = self.dim
-
-        x_egD = routed_in_egD.view(e, -1, D)
-
-        middle_out_egF = F.silu(torch.bmm(x_egD, self.w1)) * torch.bmm(x_egD, self.w3)
-        out_egD = torch.bmm(middle_out_egF, self.w2)
-        out_egD = out_egD.view(-1, D)
-
-        return out_egD
-
-
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, in_features, out_features):
         super().__init__()
@@ -115,7 +52,7 @@ def forward(self, x):
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
-class TestFloat8Tensor(TestCase):
+class TestFloat8Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
@@ -338,45 +275,8 @@ def test_slice_preserves_aliasing(self, granularity):
 
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_slice_and_copy_similar_to_vllm(self, granularity):
-        # making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly
-        # the test is similar to the linked code, but with some hardcoded arguments
-        # and does not use tensor parallelism
-
-        dtype = torch.bfloat16
-        device = "cuda"
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
-        l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
-        quantize_(l, config)
-
-        # high level, we do a narrow for both param.data and the loaded_weights
-        # and do inplace copy_ to copy from the loaded_weights into param.data
-
-        # simulate loaded_weight
-        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        # making the weight different
-        dummy_l.weight = torch.nn.Parameter(
-            dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
-            requires_grad=False,
-        )
-        quantize_(dummy_l, config)
-
-        output_dim = 0
-        shard_size = 512
-        for tp_rank in [0, 1]:
-            start_idx = tp_rank * shard_size
-            param = l.weight
-            param_data = param.data
-            param_data = param_data.narrow(output_dim, start_idx, shard_size)
-            orig_value = param_data.qdata[0][0].item()
-            loaded_weight = dummy_l.weight
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
-
-            # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
-            assert orig_value != loaded_weight.qdata[0][0]
-            param_data.copy_(loaded_weight)
-            # making sure param.data is updated to loaded_weight
-            assert param_data.qdata[0][0] == loaded_weight.qdata[0][0]
-            assert param_data.scale[0] == loaded_weight.scale[0]
+        self._test_slice_and_copy_similar_to_vllm(config)
 
     @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
     def test_bmm(self):
@@ -492,122 +392,9 @@ def test_cat(self, granularity, sizes):
 
     @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
     def test_moe_weight_reshape_ops(self):
-        """This is testing the op call sequence in saving and loading quantization
-        checkpoints in llama-models for llama4
-        (https://github.com/meta-llama/llama-models/tree/main/models/llama4)
-        """
-        # only per row quantization is supported for bmm
         granularity = PerRow()
-        dtype = torch.bfloat16
-        device = "cuda"
-
-        bmm_config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
-        moe_config = MoEQuantConfig(bmm_config)
-
-        batch_size = 4
-        num_experts = 2
-        input_dim = 64
-        dim = 128
-        hidden_dim = 256
-
-        moe1 = Experts(num_experts, dim, hidden_dim, dtype, device)
-        moe2 = Experts(num_experts, dim, hidden_dim, dtype, device)
-        moe_combined = Experts(num_experts, dim, 2 * hidden_dim, dtype, device)
-        input = torch.randn(batch_size, input_dim, dim, dtype=dtype, device=device)
-
-        moes = [moe1, moe2]
-
-        for moe in moes:
-            moe(input)
-
-            def filter_fn(module, fqn):
-                return isinstance(module, Experts)
-
-            # need to transpose before quantizing
-            moe.w1 = torch.nn.Parameter(
-                moe.w1.transpose(1, 2).contiguous(), requires_grad=False
-            )
-            moe.w2 = torch.nn.Parameter(
-                moe.w2.transpose(1, 2).contiguous(), requires_grad=False
-            )
-            moe.w3 = torch.nn.Parameter(
-                moe.w3.transpose(1, 2).contiguous(), requires_grad=False
-            )
-
-            quantize_(moe, moe_config, filter_fn=filter_fn)
-
-            # make sure it runs
-            before = moe(input)
-
-            # transposing for resharding support since only 2D resharding is supported
-            new_last_dim = moe.w1.shape[-2]
-            moe.w1 = torch.nn.Parameter(
-                moe.w1.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-            )
-            new_last_dim = moe.w2.shape[-2]
-            moe.w2 = torch.nn.Parameter(
-                moe.w2.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-            )
-            new_last_dim = moe.w3.shape[-2]
-            moe.w3 = torch.nn.Parameter(
-                moe.w3.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-            )
-
-            moe.w1 = torch.nn.Parameter(
-                moe.w1.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                requires_grad=False,
-            )
-            moe.w2 = torch.nn.Parameter(
-                moe.w2.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                requires_grad=False,
-            )
-            moe.w3 = torch.nn.Parameter(
-                moe.w3.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                requires_grad=False,
-            )
-
-            # transpose again to recover the original weights
-            moe.w1 = torch.nn.Parameter(moe.w1.transpose(1, 2), requires_grad=False)
-            moe.w2 = torch.nn.Parameter(moe.w2.transpose(1, 2), requires_grad=False)
-            moe.w3 = torch.nn.Parameter(moe.w3.transpose(1, 2), requires_grad=False)
-
-            # make sure it runs
-            after = moe(input)
-
-            self.assertEqual(before, after)
-
-        state_dicts = [moe1.state_dict(), moe2.state_dict()]
-        # align the scale parameter so they can be concatenated
-        for key in ["w1", "w2", "w3"]:
-            weights = [st[key] for st in state_dicts]
-            for i in range(1, len(weights)):
-                weights[i].scale = weights[0].scale
-
-        def process_key(key: str) -> torch.Tensor:
-            tensors = [s[key] for s in state_dicts]
-            # Note: we have a hacky implementation for cat in user codebase
-            # since it is not implemented correctly before
-            if key == "w2":
-                return torch.cat(tensors, dim=-1)
-            else:
-                return torch.cat(tensors, dim=-2)
-
-        new_state_dict = {}
-        for key in ["w1", "w2", "w3"]:
-            new_state_dict[key] = process_key(key)
-
-        moe_combined.w1 = torch.nn.Parameter(
-            moe_combined.w1.transpose(1, 2), requires_grad=False
-        )
-        moe_combined.w2 = torch.nn.Parameter(
-            moe_combined.w2.transpose(1, 2), requires_grad=False
-        )
-        moe_combined.w3 = torch.nn.Parameter(
-            moe_combined.w3.transpose(1, 2), requires_grad=False
-        )
-        moe_combined.load_state_dict(new_state_dict, assign=True)
-        # make sure it runs
-        moe_combined(input)
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        self._test_moe_weight_reshape_ops(config)
 
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
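For context on how the shared helpers surfaced in this diff are meant to be reused, a minimal sketch follows; the idea that the int4 tests from the Test Plan call the same TorchAOIntegrationTestCase helper, and the Int4WeightOnlyConfig(group_size=128) arguments, are assumptions rather than contents of this commit.

import unittest
import torch
from torchao.quantization import Int4WeightOnlyConfig
from torchao.testing.utils import TorchAOIntegrationTestCase

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestInt4TensorSketch(TorchAOIntegrationTestCase):
    def test_slice_and_copy_similar_to_vllm(self):
        # Reuse the shared helper that test_float8_tensor.py now calls,
        # here with an assumed int4 weight-only config.
        config = Int4WeightOnlyConfig(group_size=128)
        self._test_slice_and_copy_similar_to_vllm(config)

if __name__ == "__main__":
    unittest.main()

The point of the refactor is that the vLLM-style slice/copy and MoE reshape behaviors are exercised through one shared code path, so a tensor subclass such as Int4Tensor only needs to pass its config into the existing helpers instead of duplicating the test bodies.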
