5 changes: 4 additions & 1 deletion torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Optional
 
 import torch
@@ -18,6 +19,8 @@
     _is_column_major,
 )
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
 def _scaled_grouped_mm(
     A: torch.Tensor,
@@ -36,8 +39,8 @@ def _scaled_grouped_mm(
             and in column-major memory layout.
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
-        use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
     """
+    logger.debug("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
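Not part of the diff: a minimal sketch of how the DEBUG logs added in this file (and in tensor.py below) can be surfaced at runtime. The logger names follow from `logging.getLogger(__name__)` in the touched modules; the `basicConfig` setup is just one reasonable configuration, not something this PR prescribes.

```python
import logging

# Install a root handler so propagated records are emitted somewhere.
# basicConfig's level applies to the root logger, not to the handler it
# creates, so child loggers set to DEBUG below will still get through.
logging.basicConfig(level=logging.INFO)

# Opt the touched modules into DEBUG so "Using scaled_grouped_mm" and the
# ScaledGroupedMMTensor dispatch/FSDP messages show up.
for name in (
    "torchao.prototype.moe_training.scaled_grouped_mm",
    "torchao.prototype.moe_training.tensor",
):
    logging.getLogger(name).setLevel(logging.DEBUG)
```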
45 changes: 29 additions & 16 deletions torchao/prototype/moe_training/tensor.py
@@ -9,13 +9,15 @@
 import torch
 import torch.utils._pytree as pytree
+from torch import nn
 from torch._prims_common import suggest_memory_format
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import MixedPrecisionPolicy
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
 logger: logging.Logger = logging.getLogger(__name__)
 
 
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
@@ -64,12 +66,14 @@ def __init__(
         tensor: torch.Tensor,
         dtype: torch.dtype,
     ):
-        self._data = tensor
+        self._data = tensor.to(dtype)
         self._dtype = dtype
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.info(f"{func.__name__}, args: {args}, kwargs: {kwargs}")
+        logger.debug(
+            f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}"
+        )
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
@@ -100,17 +104,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         if func == torch.ops.aten.detach.default:
             return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
 
-        # unwrap args and kwargs
-        dtype: Optional[torch.dtype] = None
-
-        def unwrap(t):
-            nonlocal dtype
-            if dtype is None:
-                dtype = t._dtype
-            else:
-                assert t._dtype == dtype
-            return t._data
-
+        # unwrap args/kwargs
+        unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
@@ -125,7 +120,7 @@ def unwrap(t):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, dtype),
+            lambda x: ScaledGroupedMMTensor(x, x.dtype),
             out,
         )
 
@@ -142,9 +137,20 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
             flatten_spec["_dtype"],
         )
 
-    def fsdp_pre_all_gather(self, mesh):
+    # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
+    def fsdp_pre_all_gather(
+        self,
+        mesh: DeviceMesh,
+        outer_size: torch.Size,
+        outer_stride: tuple[int, ...],
+        module: nn.Module,
+        mp_policy: MixedPrecisionPolicy,
+    ):
         all_gather_inputs = (self._data,)
         all_gather_metadata = ()
+        logger.debug(
+            f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}"
+        )
         return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
@@ -156,6 +162,13 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
+        logger.debug(
+            f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}"
+        )
+
+        if out is not None:
+            return
+
         output = ScaledGroupedMMTensor(data, param_dtype)
         inner_tensors = (data,)
         return output, inner_tensors
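Also not part of the diff: a short sketch of how the tensor.py changes behave when the subclass is exercised directly. It assumes ScaledGroupedMMTensor can be constructed standalone with (tensor, dtype), as the detach branch above does; the shapes and dtypes are arbitrary examples.

```python
import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

w = torch.randn(16, 32, dtype=torch.float32)

# The constructor now eagerly casts the wrapped data to the requested dtype
# (self._data = tensor.to(dtype)) instead of storing the original tensor.
t = ScaledGroupedMMTensor(w, torch.bfloat16)
print(t._data.dtype)  # expected after this PR: torch.bfloat16

# detach is special-cased in __torch_dispatch__ and keeps the subclass; other
# subclass-preserving ops now re-wrap each output with its own dtype (x.dtype)
# rather than a dtype collected from the unwrapped inputs.
d = t.detach()
print(type(d).__name__, d._dtype)  # expected: ScaledGroupedMMTensor torch.bfloat16
```

Under FSDP2, the updated fsdp_pre_all_gather signature (mesh, outer_size, outer_stride, module, mp_policy) follows the extension-hook pattern in the pytorch test linked in the comment above, and fsdp_post_all_gather now logs the gathered dtype and returns early when an out tensor is supplied.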