torchrec/distributed/comm_ops.py: 208 changes (185 additions, 23 deletions)
@@ -7,14 +7,15 @@

# pyre-strict

from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple, TypeVar
from typing import Any, List, Optional, Tuple, TypeVar, Union

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives

from torch import Tensor
from torch.autograd import Function
from torch.autograd.profiler import record_function
@@ -80,6 +81,45 @@ def torchrec_use_sync_collectives():
"""


class Comm(ABC):
"""
Interface for communication primitives.
A primitive primarily needs to handle three tasks:
1. **Memory allocation strategy**:
- Associate each call with a temporary buffer (flexible, simple)
- Reuse a persistent buffer (efficient, complex)
2. **Memory location**:
- NCCL memory pool
- Regular CUDA caching allocator
3. **Communication execution**:
- Actual communication primitive implementation
See `All2AllSingle` and `DefaultAll2AllSingle` for a concrete example.
"""

@abstractmethod
def allocate(
self,
size: Sequence[Union[int, torch.SymInt]],
*,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
"""
Allocate memory for communication buffers.
Args:
size (Sequence[Union[int, torch.SymInt]]): size of the tensor buffer
dtype (torch.dtype): dtype of the tensor buffer
device (torch.device): which device to allocate the tensor onto
Returns:
torch.Tensor: Allocated tensor buffer
Example:
```python
tensor = comm.allocate([1024, 512], dtype=torch.float32, device="cuda:0")
```
"""
...
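
To make the allocation contract above concrete, here is a minimal sketch of a `Comm` subclass that follows the "reuse a persistent buffer" strategy described in the docstring. It is an editorial illustration, not part of this diff; the class name `CachedBufferComm` and its caching policy are assumptions.

```python
from typing import Dict, Sequence, Tuple, Union

import torch

from torchrec.distributed.comm_ops import Comm  # interface introduced in this diff


class CachedBufferComm(Comm):
    """Illustrative allocator that reuses one buffer per (shape, dtype, device)."""

    def __init__(self) -> None:
        self._cache: Dict[
            Tuple[Tuple[int, ...], torch.dtype, torch.device], torch.Tensor
        ] = {}

    def allocate(
        self,
        size: Sequence[Union[int, torch.SymInt]],
        *,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        key = (tuple(int(s) for s in size), dtype, torch.device(device))
        buf = self._cache.get(key)
        if buf is None:
            # First request for this key: fall back to the regular caching allocator.
            buf = torch.empty(key[0], dtype=dtype, device=key[2])
            self._cache[key] = buf
        # Callers must not reuse a buffer while an async collective on it is in flight.
        return buf
```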


class Request(Awaitable[W]):
"""
Defines a collective operation request for a process group on a tensor.
@@ -334,6 +374,83 @@ def _get_split_lengths_by_len(
return (my_len, splits)


class All2AllSingle(Comm):
"""
Interface for all-to-all single tensor communication.
This primitive performs an all-to-all scatter/gather in which each rank sends
data to and receives data from every other rank in a single operation.
"""

@abstractmethod
def __call__(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
output_split_sizes: Optional[list[int]] = None,
input_split_sizes: Optional[list[int]] = None,
group: Optional[dist.ProcessGroup] = None,
async_op: bool = False,
) -> Optional[dist.Work]:
"""
Execute all-to-all single operation.
Args:
output_tensor: Pre-allocated output tensor
input_tensor: Input tensor to scatter
output_split_sizes: Sizes for splitting the output (if None, split evenly)
input_split_sizes: Sizes for splitting the input (if None, split evenly)
group: Process group (if None, uses default group)
async_op: Whether to perform asynchronously
Returns:
Optional work handle if async_op=True, None otherwise
"""
...


class DefaultAllocMixin:
"""
Mixin providing default tensor allocation using PyTorch's standard allocator.
"""

def allocate(
self,
size: Sequence[Union[int, torch.SymInt]],
*,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
"""Allocate tensor using torch.empty with standard CUDA caching allocator."""
return torch.empty(size, dtype=dtype, device=device)


class DefaultAll2AllSingle(DefaultAllocMixin, All2AllSingle):
"""
Default implementation of all-to-all single using PyTorch's distributed primitives.

This implementation uses:
- Standard CUDA caching allocator for memory allocation
- PyTorch's dist.all_to_all_single for communication
"""

def __call__(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
output_split_sizes: Optional[list[int]] = None,
input_split_sizes: Optional[list[int]] = None,
group: Optional[dist.ProcessGroup] = None,
async_op: bool = False,
) -> Optional[dist.Work]:
"""Execute all-to-all using PyTorch's native implementation."""
return dist.all_to_all_single(
output_tensor,
input_tensor,
output_split_sizes,
input_split_sizes,
group=group,
async_op=async_op,
)
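
As a sketch of the full plug-in path (again an editorial illustration, not part of this diff): a custom `All2AllSingle` that keeps the default allocation behavior through `DefaultAllocMixin` but wraps `dist.all_to_all_single`, for example to log payload sizes, and is handed to `alltoall_pooled` through the new `all_to_all_single_comm` argument. The variable names in the commented usage are hypothetical.

```python
from typing import List, Optional

import torch
import torch.distributed as dist

from torchrec.distributed.comm_ops import (  # names introduced in this diff
    All2AllSingle,
    DefaultAllocMixin,
    alltoall_pooled,
)


class LoggingAll2AllSingle(DefaultAllocMixin, All2AllSingle):
    """Illustrative wrapper that logs payload sizes before delegating to dist."""

    def __call__(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        output_split_sizes: Optional[List[int]] = None,
        input_split_sizes: Optional[List[int]] = None,
        group: Optional[dist.ProcessGroup] = None,
        async_op: bool = False,
    ) -> Optional[dist.Work]:
        print(
            f"all_to_all_single: {input_tensor.numel()} -> {output_tensor.numel()} elements"
        )
        return dist.all_to_all_single(
            output_tensor,
            input_tensor,
            output_split_sizes,
            input_split_sizes,
            group=group,
            async_op=async_op,
        )


# Usage inside an initialized process group (variable names are hypothetical):
# awaitable = alltoall_pooled(
#     pooled_embs,
#     batch_size_per_rank,
#     dim_sum_per_rank,
#     all_to_all_single_comm=LoggingAll2AllSingle(),
# )
# output = awaitable.wait()
```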


def alltoall_pooled(
a2a_pooled_embs_tensor: Tensor,
batch_size_per_rank: List[int],
@@ -342,12 +459,14 @@ def alltoall_pooled(
cumsum_dim_sum_per_rank_tensor: Optional[Tensor] = None,
group: Optional[dist.ProcessGroup] = None,
codecs: Optional[QuantizedCommCodecs] = None,
all_to_all_single_comm: Optional[All2AllSingle] = None,
) -> Awaitable[Tensor]:
"""
Performs AlltoAll operation for a single pooled embedding tensor. Each process
splits the input pooled embeddings tensor based on the world size, and then scatters
the split list to all processes in the group. Then concatenates the received tensors
from all processes in the group and returns a single output tensor.
Performs AlltoAll operation for a single pooled embedding tensor.

Each process splits the input pooled embeddings tensor based on the world size,
and then scatters the split list to all processes in the group. Then concatenates
the received tensors from all processes in the group and returns a single output tensor.

Args:
a2a_pooled_embs_tensor (Tensor): input pooled embeddings. Must be pooled
Expand All @@ -367,7 +486,19 @@ def alltoall_pooled(
codecs (Optional[QuantizedCommCodecs]): quantized communication codecs.
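all_to_all_single_comm (Optional[All2AllSingle]): communication primitive used to
run the all-to-all; if None, `DefaultAll2AllSingle` is used.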

Returns:
Awaitable[Tensor]: async work handle (`Awaitable`), which can be `wait()` later to get the resulting tensor.
Awaitable[Tensor]: Async work handle which can be `wait()`-ed later to
get the resulting tensor.

Example:
```python
# Using default implementation
result = alltoall_pooled(embeddings, batch_sizes, dim_sums)
output = result.wait()
# Using custom implementation
custom_comm = MyCustomAll2All()
result = alltoall_pooled(embeddings, batch_sizes, dim_sums,
all_to_all_single_comm=custom_comm)
```

.. warning::
`alltoall_pooled` is experimental and subject to change.
Expand All @@ -391,7 +522,15 @@ def alltoall_pooled(
return NoWait(all2all_pooled_sync(group, a2ai, a2a_pooled_embs_tensor))

myreq = Request(group, device=a2a_pooled_embs_tensor.device)
All2All_Pooled_Req.apply(group, myreq, a2ai, a2a_pooled_embs_tensor)
if all_to_all_single_comm is None:
all_to_all_single_comm = DefaultAll2AllSingle()
All2All_Pooled_Req.apply(
group,
myreq,
a2ai,
a2a_pooled_embs_tensor,
all_to_all_single_comm,
)
return myreq


@@ -476,6 +615,7 @@ def variable_batch_alltoall_pooled(
emb_dim_per_rank_per_feature: List[List[int]],
group: Optional[dist.ProcessGroup] = None,
codecs: Optional[QuantizedCommCodecs] = None,
all_to_all_single_comm: Optional[All2AllSingle] = None,
) -> Awaitable[Tensor]:

if group is None:
@@ -497,7 +637,11 @@ def variable_batch_alltoall_pooled(
)

myreq = Request(group, device=a2a_pooled_embs_tensor.device)
Variable_Batch_All2All_Pooled_Req.apply(group, myreq, a2ai, a2a_pooled_embs_tensor)
if all_to_all_single_comm is None:
all_to_all_single_comm = DefaultAll2AllSingle()
Variable_Batch_All2All_Pooled_Req.apply(
group, myreq, a2ai, a2a_pooled_embs_tensor, all_to_all_single_comm
)
return myreq


@@ -1138,6 +1282,7 @@ def forward(
myreq: Request[Tensor],
a2ai: All2AllPooledInfo,
input_embeddings: Tensor,
all_to_all_single_comm: All2AllSingle,
) -> Tensor:
my_rank = dist.get_rank(pg)
(B_global, D_local_sum) = input_embeddings.shape
@@ -1191,16 +1336,24 @@ def forward(
input_split_sizes = [D_local_sum * B_rank for B_rank in batch_size_per_rank]
qcomm_ctx = None

sharded_output_embeddings = torch.empty(
sum(output_split_sizes),
sharded_output_embeddings = all_to_all_single_comm.allocate(
(sum(output_split_sizes),),
dtype=sharded_input_embeddings.dtype,
device=sharded_input_embeddings.device,
)
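# The input is staged below into a buffer that also comes from the communication
# primitive's allocator (possibly, e.g., NCCL-registered memory), so the backend can
# operate on memory it controls; the extra copy_ pays for that staging.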

sharded_input_embeddings_registered = all_to_all_single_comm.allocate(
sharded_input_embeddings.shape,
dtype=sharded_input_embeddings.dtype,
device=sharded_input_embeddings.device,
)

sharded_input_embeddings_registered.copy_(sharded_input_embeddings)

with record_function("## All2All_Pooled_fwd ##"):
req = dist.all_to_all_single(
output=sharded_output_embeddings,
input=sharded_input_embeddings,
req = all_to_all_single_comm(
output_tensor=sharded_output_embeddings,
input_tensor=sharded_input_embeddings_registered,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=pg,
@@ -1218,7 +1371,7 @@ def forward(

@staticmethod
# pyre-fixme[2]: Parameter must be annotated.
def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
def backward(ctx, *unused) -> Tuple[None, None, None, Tensor, None]:
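# The trailing None in the return annotation corresponds to the new
# all_to_all_single_comm forward argument, which does not receive a gradient.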
pg = ctx.pg
my_rank = dist.get_rank(pg)
myreq = ctx.myreq
@@ -1241,7 +1394,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
grad_input.div_(dist.get_world_size(ctx.pg))
myreq.tensor = None
myreq.dummy_tensor = None
return (None, None, None, grad_input)
return (None, None, None, grad_input, None)


class All2All_Pooled_Wait(Function):
@@ -1385,6 +1538,7 @@ def forward(
myreq: Request[Tensor],
a2ai: VariableBatchAll2AllPooledInfo,
input_embeddings: Tensor,
all_to_all_single_comm: All2AllSingle,
) -> Tensor:
my_rank = dist.get_rank(pg)

@@ -1438,16 +1592,24 @@ def forward(
for split in input_split_sizes
]

sharded_output_embeddings = torch.empty(
sum(output_split_sizes),
sharded_output_embeddings = all_to_all_single_comm.allocate(
(sum(output_split_sizes),),
dtype=sharded_input_embeddings.dtype,
device=sharded_input_embeddings.device,
)
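# The input is staged below into a buffer that also comes from the communication
# primitive's allocator (possibly, e.g., NCCL-registered memory), so the backend can
# operate on memory it controls; the extra copy_ pays for that staging.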

sharded_input_embeddings_registered = all_to_all_single_comm.allocate(
sharded_input_embeddings.shape,
dtype=sharded_input_embeddings.dtype,
device=sharded_input_embeddings.device,
)

sharded_input_embeddings_registered.copy_(sharded_input_embeddings)

with record_function("## Variable_Batch_All2All_Pooled_fwd ##"):
req = dist.all_to_all_single(
output=sharded_output_embeddings,
input=sharded_input_embeddings,
req = all_to_all_single_comm(
output_tensor=sharded_output_embeddings,
input_tensor=sharded_input_embeddings_registered,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=pg,
@@ -1466,7 +1628,7 @@ def forward(
@staticmethod
# pyre-fixme[2]: Parameter must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
def backward(ctx, *unused) -> Tuple[None, None, None, Tensor, None]:
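# As above, the trailing None matches the new all_to_all_single_comm forward
# argument, which does not receive a gradient.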
myreq = ctx.myreq
a2ai = myreq.a2ai
assert myreq.req is not None
@@ -1486,7 +1648,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
grad_input.div_(dist.get_world_size(ctx.pg))
myreq.tensor = None
myreq.dummy_tensor = None
return (None, None, None, grad_input)
return (None, None, None, grad_input, None)


class Variable_Batch_All2All_Pooled_Wait(Function):