From 64bfa3054dcb2aa03aec2a4c88b2f345f4b35d30 Mon Sep 17 00:00:00 2001 From: Nipun Gupta Date: Fri, 18 Jul 2025 14:16:50 -0700 Subject: [PATCH] Remove torch._running_with_deploy() from fbcode, Fix exception handling for torch.ops.load_libraries (#3213) Summary: Rollback Plan: Differential Revision: D78583233 --- torchrec/distributed/comm_ops.py | 353 +++++++++--------- torchrec/distributed/dist_data.py | 4 +- torchrec/distributed/embedding.py | 4 +- torchrec/distributed/embeddingbag.py | 4 +- torchrec/distributed/model_parallel.py | 4 +- torchrec/distributed/quant_embedding.py | 4 +- .../distributed/train_pipeline/tracing.py | 7 +- .../train_pipeline/train_pipelines.py | 3 +- torchrec/modules/itep_modules.py | 2 +- torchrec/quant/embedding_modules.py | 4 +- torchrec/sparse/jagged_tensor.py | 4 +- 11 files changed, 198 insertions(+), 195 deletions(-) diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py index 4d950c7e9..09418be40 100644 --- a/torchrec/distributed/comm_ops.py +++ b/torchrec/distributed/comm_ops.py @@ -25,8 +25,8 @@ try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") -except OSError: - pass +except (OSError, RuntimeError): + from fbgemm_gpu import sparse_ops # noqa: F401, E402 # OSS @@ -54,10 +54,6 @@ def get_gradient_division() -> bool: def set_use_sync_collectives(val: bool) -> None: - if val and torch._running_with_deploy(): - raise RuntimeError( - "TorchRec sync_collectives are not supported in torch.deploy." - ) global USE_SYNC_COLLECTIVES USE_SYNC_COLLECTIVES = val @@ -2356,202 +2352,213 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]: return (None, None, myreq.dummy_tensor) -if not torch._running_with_deploy(): # noqa C901 - # Torch Library op def can not be used in Deploy - class AllToAllSingle(torch.autograd.Function): - @staticmethod - # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. - def forward( - # pyre-fixme[2]: Parameter must be annotated. - ctx, - input: Tensor, - output_split_sizes: List[int], - input_split_sizes: List[int], - group_name: str, - group_size: int, - gradient_division: bool, - ) -> Tensor: - ctx.output_split_sizes = input_split_sizes - ctx.input_split_sizes = output_split_sizes - ctx.group_name = group_name - ctx.group_size = group_size - ctx.gradient_division = gradient_division - return torch.distributed._functional_collectives.all_to_all_single( - input, output_split_sizes, input_split_sizes, group_name - ) - - @staticmethod - # pyre-ignore - def backward(ctx, grad): - grad = torch.distributed._functional_collectives.all_to_all_single( - grad, - ctx.output_split_sizes, - ctx.input_split_sizes, - ctx.group_name, - ) - if ctx.gradient_division: - grad.div_(ctx.group_size) - - return grad, None, None, None, None, None - - # torchrec::reduce_scatter_tensor - @torch.library.custom_op("torchrec::reduce_scatter_tensor", mutates_args=()) - def reduce_scatter_tensor( +class AllToAllSingle(torch.autograd.Function): + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + # pyre-fixme[2]: Parameter must be annotated. 
+ ctx, input: Tensor, - reduceOp: str, - group_size: int, + output_split_sizes: List[int], + input_split_sizes: List[int], group_name: str, - gradient_division: bool, - ) -> Tensor: - out = torch.ops._c10d_functional.reduce_scatter_tensor( - input, - reduceOp, - group_size, - group_name, - ) - return torch.ops._c10d_functional.wait_tensor(out) - - @torch.library.register_fake("torchrec::reduce_scatter_tensor") - def reduce_scatter_tensor_fake( - input: Tensor, - reduceOp: str, group_size: int, - group_name: str, gradient_division: bool, ) -> Tensor: - return torch.ops._c10d_functional.reduce_scatter_tensor( - input, - reduceOp, - group_size, - group_name, - ) - - # pyre-ignore - def reduce_scatter_tensor_setup_context(ctx, inputs, output) -> None: - _, _, group_size, group_name, gradient_division = inputs - ctx.group_size = group_size + ctx.output_split_sizes = input_split_sizes + ctx.input_split_sizes = output_split_sizes ctx.group_name = group_name + ctx.group_size = group_size ctx.gradient_division = gradient_division + return torch.distributed._functional_collectives.all_to_all_single( + input, output_split_sizes, input_split_sizes, group_name + ) + @staticmethod # pyre-ignore - def reduce_scatter_tensor_backward(ctx, grad): - # TODO(ivankobzarev): Support codecs(quantization) on backward - out = torch.ops._c10d_functional.all_gather_into_tensor( + def backward(ctx, grad): + grad = torch.distributed._functional_collectives.all_to_all_single( grad, - ctx.group_size, + ctx.output_split_sizes, + ctx.input_split_sizes, ctx.group_name, ) - grad = torch.ops._c10d_functional.wait_tensor(out) if ctx.gradient_division: grad.div_(ctx.group_size) return grad, None, None, None, None, None - torch.library.register_autograd( - "torchrec::reduce_scatter_tensor", - reduce_scatter_tensor_backward, - setup_context=reduce_scatter_tensor_setup_context, + +# torchrec::reduce_scatter_tensor +@torch.library.custom_op("torchrec::reduce_scatter_tensor", mutates_args=()) +def reduce_scatter_tensor( + input: Tensor, + reduceOp: str, + group_size: int, + group_name: str, + gradient_division: bool, +) -> Tensor: + out = torch.ops._c10d_functional.reduce_scatter_tensor( + input, + reduceOp, + group_size, + group_name, ) + return torch.ops._c10d_functional.wait_tensor(out) - # torchrec::all_gather_into_tensor - @torch.library.custom_op("torchrec::all_gather_into_tensor", mutates_args=()) - def all_gather_into_tensor( - shard: Tensor, - gather_dim: int, - group_size: int, - group_name: str, - gradient_division: bool, - ) -> Tensor: - out = torch.ops._c10d_functional.all_gather_into_tensor( - shard, group_size, group_name - ) - return torch.ops._c10d_functional.wait_tensor(out) - @torch.library.register_fake("torchrec::all_gather_into_tensor") - def all_gather_into_tensor_fake( - shard: Tensor, - gather_dim: int, - group_size: int, - group_name: str, - gradient_division: bool, - ) -> Tensor: - return torch.ops._c10d_functional.all_gather_into_tensor( - shard, group_size, group_name - ) +@torch.library.register_fake("torchrec::reduce_scatter_tensor") +def reduce_scatter_tensor_fake( + input: Tensor, + reduceOp: str, + group_size: int, + group_name: str, + gradient_division: bool, +) -> Tensor: + return torch.ops._c10d_functional.reduce_scatter_tensor( + input, + reduceOp, + group_size, + group_name, + ) - # pyre-ignore - def all_gather_into_tensor_setup_context(ctx, inputs, output) -> None: - _, gather_dim, group_size, group_name, gradient_division = inputs - ctx.group_size = group_size - ctx.group_name = group_name - 
ctx.gradient_division = gradient_division - # pyre-ignore - def all_gather_into_tensor_backward(ctx, grad): - # TODO(ivankobzarev): Support codecs(quantization) on backward - out = torch.ops._c10d_functional.reduce_scatter_tensor( - grad, - "sum", - ctx.group_size, - ctx.group_name, - ) - grad = torch.ops._c10d_functional.wait_tensor(out) - if ctx.gradient_division: - grad.div_(ctx.group_size) +# pyre-ignore +def reduce_scatter_tensor_setup_context(ctx, inputs, output) -> None: + _, _, group_size, group_name, gradient_division = inputs + ctx.group_size = group_size + ctx.group_name = group_name + ctx.gradient_division = gradient_division - return grad, None, None, None, None - torch.library.register_autograd( - "torchrec::all_gather_into_tensor", - all_gather_into_tensor_backward, - setup_context=all_gather_into_tensor_setup_context, +# pyre-ignore +def reduce_scatter_tensor_backward(ctx, grad): + # TODO(ivankobzarev): Support codecs(quantization) on backward + out = torch.ops._c10d_functional.all_gather_into_tensor( + grad, + ctx.group_size, + ctx.group_name, ) + grad = torch.ops._c10d_functional.wait_tensor(out) + if ctx.gradient_division: + grad.div_(ctx.group_size) - @torch.library.custom_op("torchrec::_split_1d_cat_2d", mutates_args=()) - def _split_1d_cat_2d_impl( - t: torch.Tensor, dim0: int, dim1s: List[int] - ) -> torch.Tensor: - torch._check_is_size(dim0) - [torch._check_is_size(dim1) for dim1 in dim1s] - splits: List[torch.Tensor] = t.split([dim0 * dim1 for dim1 in dim1s]) - return torch.cat( - [s.reshape(dim0, dim1) for s, dim1 in zip(splits, dim1s)], - dim=1, - ) + return grad, None, None, None, None, None + + +torch.library.register_autograd( + "torchrec::reduce_scatter_tensor", + reduce_scatter_tensor_backward, + setup_context=reduce_scatter_tensor_setup_context, +) - @torch.library.register_fake("torchrec::_split_1d_cat_2d") - def _split_1d_cat_2d_impl_abstract( - t: torch.Tensor, dim0: int, dim1s: List[int] - ) -> torch.Tensor: - return t.new_empty([dim0, sum(dim1s)]) - @torch.library.custom_op( - "torchrec::_split_1d_cat_2d_backward_impl", mutates_args=() +# torchrec::all_gather_into_tensor +@torch.library.custom_op("torchrec::all_gather_into_tensor", mutates_args=()) +def all_gather_into_tensor( + shard: Tensor, + gather_dim: int, + group_size: int, + group_name: str, + gradient_division: bool, +) -> Tensor: + out = torch.ops._c10d_functional.all_gather_into_tensor( + shard, group_size, group_name ) - def _split_1d_cat_2d_backward_impl( - grad: torch.Tensor, dim1s: List[int] - ) -> torch.Tensor: - splits = grad.split(dim1s, dim=1) - return torch.cat([s.reshape(-1) for s in splits], dim=0) - - @torch.library.register_fake("torchrec::_split_1d_cat_2d_backward_impl") - def _split_1d_cat_2d_backward_impl_fake( - grad: torch.Tensor, dim1s: List[int] - ) -> torch.Tensor: - return grad.new_empty([grad.numel()]) + return torch.ops._c10d_functional.wait_tensor(out) - # pyre-ignore - def _split_1d_cat_2d_backward(ctx, grad): - ret = torch.ops.torchrec._split_1d_cat_2d_backward_impl(grad, ctx.dim1s) - return ret, None, None - # pyre-ignore - def _split_1d_cat_2d_setup_context(ctx, inputs, output): - (x, dim0, dim1s) = inputs - ctx.dim1s = dim1s - - torch.library.register_autograd( - "torchrec::_split_1d_cat_2d", - _split_1d_cat_2d_backward, - setup_context=_split_1d_cat_2d_setup_context, +@torch.library.register_fake("torchrec::all_gather_into_tensor") +def all_gather_into_tensor_fake( + shard: Tensor, + gather_dim: int, + group_size: int, + group_name: str, + gradient_division: 
bool, +) -> Tensor: + return torch.ops._c10d_functional.all_gather_into_tensor( + shard, group_size, group_name ) + + +# pyre-ignore +def all_gather_into_tensor_setup_context(ctx, inputs, output) -> None: + _, gather_dim, group_size, group_name, gradient_division = inputs + ctx.group_size = group_size + ctx.group_name = group_name + ctx.gradient_division = gradient_division + + +# pyre-ignore +def all_gather_into_tensor_backward(ctx, grad): + # TODO(ivankobzarev): Support codecs(quantization) on backward + out = torch.ops._c10d_functional.reduce_scatter_tensor( + grad, + "sum", + ctx.group_size, + ctx.group_name, + ) + grad = torch.ops._c10d_functional.wait_tensor(out) + if ctx.gradient_division: + grad.div_(ctx.group_size) + + return grad, None, None, None, None + + +torch.library.register_autograd( + "torchrec::all_gather_into_tensor", + all_gather_into_tensor_backward, + setup_context=all_gather_into_tensor_setup_context, +) + + +@torch.library.custom_op("torchrec::_split_1d_cat_2d", mutates_args=()) +def _split_1d_cat_2d_impl(t: torch.Tensor, dim0: int, dim1s: List[int]) -> torch.Tensor: + torch._check_is_size(dim0) + [torch._check_is_size(dim1) for dim1 in dim1s] + splits: List[torch.Tensor] = t.split([dim0 * dim1 for dim1 in dim1s]) + return torch.cat( + [s.reshape(dim0, dim1) for s, dim1 in zip(splits, dim1s)], + dim=1, + ) + + +@torch.library.register_fake("torchrec::_split_1d_cat_2d") +def _split_1d_cat_2d_impl_abstract( + t: torch.Tensor, dim0: int, dim1s: List[int] +) -> torch.Tensor: + return t.new_empty([dim0, sum(dim1s)]) + + +@torch.library.custom_op("torchrec::_split_1d_cat_2d_backward_impl", mutates_args=()) +def _split_1d_cat_2d_backward_impl( + grad: torch.Tensor, dim1s: List[int] +) -> torch.Tensor: + splits = grad.split(dim1s, dim=1) + return torch.cat([s.reshape(-1) for s in splits], dim=0) + + +@torch.library.register_fake("torchrec::_split_1d_cat_2d_backward_impl") +def _split_1d_cat_2d_backward_impl_fake( + grad: torch.Tensor, dim1s: List[int] +) -> torch.Tensor: + return grad.new_empty([grad.numel()]) + + +# pyre-ignore +def _split_1d_cat_2d_backward(ctx, grad): + ret = torch.ops.torchrec._split_1d_cat_2d_backward_impl(grad, ctx.dim1s) + return ret, None, None + + +# pyre-ignore +def _split_1d_cat_2d_setup_context(ctx, inputs, output): + (x, dim0, dim1s) = inputs + ctx.dim1s = dim1s + + +torch.library.register_autograd( + "torchrec::_split_1d_cat_2d", + _split_1d_cat_2d_backward, + setup_context=_split_1d_cat_2d_setup_context, +) diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py index 4c66511ef..7dbe7dcc1 100644 --- a/torchrec/distributed/dist_data.py +++ b/torchrec/distributed/dist_data.py @@ -38,8 +38,8 @@ torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu" ) -except OSError: - pass +except (OSError, RuntimeError): + from fbgemm_gpu import sparse_ops # noqa: F401, E402 # OSS try: diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index ef6a67098..d20e20cfe 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -105,8 +105,8 @@ try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") -except OSError: - pass +except (OSError, RuntimeError): + from fbgemm_gpu import sparse_ops # noqa: F401, E402 logger: logging.Logger = logging.getLogger(__name__) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 
0b554cd02..2ac76b837 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -113,8 +113,8 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") -except OSError: - pass +except (OSError, RuntimeError): + from fbgemm_gpu import sparse_ops # noqa: F401, E402 def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor: diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index d09f30781..17b2f5410 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -58,8 +58,8 @@ try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") -except OSError: - pass +except (OSError, RuntimeError): + from fbgemm_gpu import sparse_ops # noqa: F401, E402 _DDP_STATE_DICT_PREFIX = "module." diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py index 94d7574f6..de219654b 100644 --- a/torchrec/distributed/quant_embedding.py +++ b/torchrec/distributed/quant_embedding.py @@ -100,8 +100,8 @@ try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") -except OSError: - pass +except (OSError, RuntimeError): + from fbgemm_gpu import sparse_ops # noqa: F401, E402 logger: logging.Logger = logging.getLogger(__name__) diff --git a/torchrec/distributed/train_pipeline/tracing.py b/torchrec/distributed/train_pipeline/tracing.py index 946348785..e829e2b69 100644 --- a/torchrec/distributed/train_pipeline/tracing.py +++ b/torchrec/distributed/train_pipeline/tracing.py @@ -13,12 +13,9 @@ import torch -if not torch._running_with_deploy(): - from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2 -else: - class FSDP2: - pass +class FSDP2: + pass from torch.distributed.fsdp import FullyShardedDataParallel as FSDP diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index 9007f55bd..195cc1115 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -80,8 +80,7 @@ except ImportError: logger.warning("torchrec_use_sync_collectives is not available") -if not torch._running_with_deploy(): - torch.ops.import_module("fbgemm_gpu.sparse_ops") +torch.ops.import_module("fbgemm_gpu.sparse_ops") # Note: doesn't make much sense but better than throwing. 
diff --git a/torchrec/modules/itep_modules.py b/torchrec/modules/itep_modules.py index 4bce8e37d..8fe7e618a 100644 --- a/torchrec/modules/itep_modules.py +++ b/torchrec/modules/itep_modules.py @@ -29,7 +29,7 @@ torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:intraining_embedding_pruning_gpu" ) -except OSError: +except (OSError, RuntimeError): pass logger: logging.Logger = logging.getLogger(__name__) diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 3e979b34d..d8b37c951 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -79,8 +79,8 @@ try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") -except OSError: - pass +except (OSError, RuntimeError): + from fbgemm_gpu import sparse_ops # noqa: F401, E402 # OSS try: diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index ebdce6acb..d6e4d0fcd 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -44,8 +44,8 @@ torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu" ) -except OSError: - pass +except (OSError, RuntimeError): + from fbgemm_gpu import sparse_ops # noqa: F401, E402 logger: logging.Logger = logging.getLogger()
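
For reference, the recurring change in this patch is the operator-library loading fallback shown below. This is a minimal standalone sketch of that pattern, not code copied verbatim from any one file in the diff; the Buck target names are the ones used in the hunks above, and the fallback assumes an OSS install of fbgemm_gpu whose sparse_ops module registers the operators as an import side effect.

import torch

try:
    # fbcode path: load the operators from Buck-built shared libraries.
    # Depending on the environment, load_library may raise OSError or
    # RuntimeError, which is why the patch broadens the except clause.
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
except (OSError, RuntimeError):
    # OSS path: importing fbgemm_gpu.sparse_ops registers the same operators
    # with the PyTorch dispatcher; the import is intentionally "unused",
    # hence the noqa in the patched files.
    from fbgemm_gpu import sparse_ops  # noqa: F401, E402

The import-based fallback is kept even though the imported name is never referenced, because the useful work (operator registration) happens when the module is imported, not when the name is used.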