Skip to content

Commit 3b4bc98

Browse files
authored
Add Float8Tensor (#2463)
Summary: * Added Float8Tensor that's using fbgemm kernels and scaled_mm: * per row activation + per row weight linear calling torch._scaled_mm op (for compatibilty with SM 8.9) * per tensor activation + per tensor weight quant linear calling torch._scaled_mm op (for compatibilty with SM 8.9) * per row activation + per row weight bmm calling torch.ops.fbgemm.f8f8bf16_rowwise_batched kernel (only works for SM 9.0+) can use batched scaled mm from torch when it's supported: pytorch/pytorch#157950 * dynamic quantization kwargs is added to the Float8Tensor directly * Added QuantizeTensorKwargs and QuantizeTensorToFloat8Kwargs to store key word args for Float8Tensor.to_float8 * Updated Float8DynamicActivationFloat8WeightConfig and Float8WeightOnlyConfig to use Float8Tensor Test Plan: python test/dtypes/test_affine_quantized_float.py python test/quantization/quantize_/workflows/float8/test_float8_tensor.py Reviewers: Subscribers: Tasks: Tags:
1 parent 23b0219 commit 3b4bc98

File tree

14 files changed

+1386
-192
lines changed

14 files changed

+1386
-192
lines changed

.github/workflows/1xH100_tests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ jobs:
2525
include:
2626
- name: H100
2727
runs-on: linux.aws.h100
28-
torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126'
28+
torch-spec: '--pre torch torchvision torchaudio fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/nightly/cu126'
2929
gpu-arch-type: "cuda"
3030
gpu-arch-version: "12.4"
3131
permissions:
3232
id-token: write
3333
contents: read
3434
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
3535
with:
36-
timeout: 60
36+
timeout: 90
3737
runner: ${{ matrix.runs-on }}
3838
gpu-arch-type: ${{ matrix.gpu-arch-type }}
3939
gpu-arch-version: ${{ matrix.gpu-arch-version }}
@@ -46,8 +46,8 @@ jobs:
4646
pip install uv
4747
pip install ${{ matrix.torch-spec }}
4848
uv pip install -r dev-requirements.txt
49-
uv pip install vllm
5049
pip install .
5150
pytest test/integration --verbose -s
5251
pytest test/dtypes/test_affine_quantized_float.py --verbose -s
52+
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py
5353
./test/float8/test_everything_single_gpu.sh

.github/workflows/1xL4_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ jobs:
4646
pip install uv
4747
pip install ${{ matrix.torch-spec }}
4848
uv pip install -r dev-requirements.txt
49-
uv pip install vllm
5049
pip install .
5150
pytest test/integration --verbose -s
5251
pytest test/dtypes/test_affine_quantized_float.py --verbose -s
5352
./test/float8/test_everything_single_gpu.sh
53+
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py

test/dtypes/test_affine_quantized_float.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,7 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
737737
Verify that float8 quantization + torch.compile results in the
738738
expected number of kernels in the GPU trace.
739739
"""
740+
torch.compiler.reset()
740741

741742
M, K, N = 128, 256, 512
742743
m = torch.nn.Sequential(

test/dtypes/test_fbgemm_fp8.py

Lines changed: 0 additions & 153 deletions
This file was deleted.

0 commit comments

Comments
 (0)