
Commit a87c96f

PaliC authored and facebook-github-bot committed
Remove torch._running_with_deploy() from fbcode
Summary: As per https://fb.workplace.com/groups/pytorch.dev/permalink/1828123831099422, we can now safely remove torch._running_with_deploy(). This commit does that.

Differential Revision: D78525065
1 parent 45d5c4d commit a87c96f

File tree: 3 files changed, +181 -178 lines


torchrec/distributed/comm_ops.py

Lines changed: 178 additions & 171 deletions
@@ -54,10 +54,6 @@ def get_gradient_division() -> bool:
 
 
 def set_use_sync_collectives(val: bool) -> None:
-    if val and torch._running_with_deploy():
-        raise RuntimeError(
-            "TorchRec sync_collectives are not supported in torch.deploy."
-        )
 
     global USE_SYNC_COLLECTIVES
     USE_SYNC_COLLECTIVES = val
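
With the guard gone, toggling sync collectives behaves the same in every runtime, including the environments that used to be detected via torch._running_with_deploy(). A minimal sketch of the new behavior (it assumes only that torchrec is importable; the assertions read the module-level flag shown in the hunk above):

from torchrec.distributed import comm_ops

# Before this commit, the call below raised
# "TorchRec sync_collectives are not supported in torch.deploy."
# when running inside a torch::deploy interpreter; now it simply flips the flag.
comm_ops.set_use_sync_collectives(True)
assert comm_ops.USE_SYNC_COLLECTIVES

comm_ops.set_use_sync_collectives(False)
assert not comm_ops.USE_SYNC_COLLECTIVES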
@@ -2356,202 +2352,213 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
         return (None, None, myreq.dummy_tensor)
 
 
-if not torch._running_with_deploy():  # noqa C901
-    # Torch Library op def can not be used in Deploy

Everything that was nested under this guard is dedented one level and kept at module scope; apart from re-wrapping one decorator and one function signature that now fit on a single line, the op definitions are unchanged. The resulting module-level code (the "+" side of the hunk) is:

class AllToAllSingle(torch.autograd.Function):
    @staticmethod
    # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
    def forward(
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        input: Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
        group_name: str,
        group_size: int,
        gradient_division: bool,
    ) -> Tensor:
        ctx.output_split_sizes = input_split_sizes
        ctx.input_split_sizes = output_split_sizes
        ctx.group_name = group_name
        ctx.group_size = group_size
        ctx.gradient_division = gradient_division
        return torch.distributed._functional_collectives.all_to_all_single(
            input, output_split_sizes, input_split_sizes, group_name
        )

    @staticmethod
    # pyre-ignore
    def backward(ctx, grad):
        grad = torch.distributed._functional_collectives.all_to_all_single(
            grad,
            ctx.output_split_sizes,
            ctx.input_split_sizes,
            ctx.group_name,
        )
        if ctx.gradient_division:
            grad.div_(ctx.group_size)

        return grad, None, None, None, None, None


# torchrec::reduce_scatter_tensor
@torch.library.custom_op("torchrec::reduce_scatter_tensor", mutates_args=())
def reduce_scatter_tensor(
    input: Tensor,
    reduceOp: str,
    group_size: int,
    group_name: str,
    gradient_division: bool,
) -> Tensor:
    out = torch.ops._c10d_functional.reduce_scatter_tensor(
        input,
        reduceOp,
        group_size,
        group_name,
    )
    return torch.ops._c10d_functional.wait_tensor(out)


@torch.library.register_fake("torchrec::reduce_scatter_tensor")
def reduce_scatter_tensor_fake(
    input: Tensor,
    reduceOp: str,
    group_size: int,
    group_name: str,
    gradient_division: bool,
) -> Tensor:
    return torch.ops._c10d_functional.reduce_scatter_tensor(
        input,
        reduceOp,
        group_size,
        group_name,
    )


# pyre-ignore
def reduce_scatter_tensor_setup_context(ctx, inputs, output) -> None:
    _, _, group_size, group_name, gradient_division = inputs
    ctx.group_size = group_size
    ctx.group_name = group_name
    ctx.gradient_division = gradient_division


# pyre-ignore
def reduce_scatter_tensor_backward(ctx, grad):
    # TODO(ivankobzarev): Support codecs(quantization) on backward
    out = torch.ops._c10d_functional.all_gather_into_tensor(
        grad,
        ctx.group_size,
        ctx.group_name,
    )
    grad = torch.ops._c10d_functional.wait_tensor(out)
    if ctx.gradient_division:
        grad.div_(ctx.group_size)

    return grad, None, None, None, None, None


torch.library.register_autograd(
    "torchrec::reduce_scatter_tensor",
    reduce_scatter_tensor_backward,
    setup_context=reduce_scatter_tensor_setup_context,
)


# torchrec::all_gather_into_tensor
@torch.library.custom_op("torchrec::all_gather_into_tensor", mutates_args=())
def all_gather_into_tensor(
    shard: Tensor,
    gather_dim: int,
    group_size: int,
    group_name: str,
    gradient_division: bool,
) -> Tensor:
    out = torch.ops._c10d_functional.all_gather_into_tensor(
        shard, group_size, group_name
    )
    return torch.ops._c10d_functional.wait_tensor(out)


@torch.library.register_fake("torchrec::all_gather_into_tensor")
def all_gather_into_tensor_fake(
    shard: Tensor,
    gather_dim: int,
    group_size: int,
    group_name: str,
    gradient_division: bool,
) -> Tensor:
    return torch.ops._c10d_functional.all_gather_into_tensor(
        shard, group_size, group_name
    )


# pyre-ignore
def all_gather_into_tensor_setup_context(ctx, inputs, output) -> None:
    _, gather_dim, group_size, group_name, gradient_division = inputs
    ctx.group_size = group_size
    ctx.group_name = group_name
    ctx.gradient_division = gradient_division


# pyre-ignore
def all_gather_into_tensor_backward(ctx, grad):
    # TODO(ivankobzarev): Support codecs(quantization) on backward
    out = torch.ops._c10d_functional.reduce_scatter_tensor(
        grad,
        "sum",
        ctx.group_size,
        ctx.group_name,
    )
    grad = torch.ops._c10d_functional.wait_tensor(out)
    if ctx.gradient_division:
        grad.div_(ctx.group_size)

    return grad, None, None, None, None


torch.library.register_autograd(
    "torchrec::all_gather_into_tensor",
    all_gather_into_tensor_backward,
    setup_context=all_gather_into_tensor_setup_context,
)


@torch.library.custom_op("torchrec::_split_1d_cat_2d", mutates_args=())
def _split_1d_cat_2d_impl(t: torch.Tensor, dim0: int, dim1s: List[int]) -> torch.Tensor:
    torch._check_is_size(dim0)
    [torch._check_is_size(dim1) for dim1 in dim1s]
    splits: List[torch.Tensor] = t.split([dim0 * dim1 for dim1 in dim1s])
    return torch.cat(
        [s.reshape(dim0, dim1) for s, dim1 in zip(splits, dim1s)],
        dim=1,
    )


@torch.library.register_fake("torchrec::_split_1d_cat_2d")
def _split_1d_cat_2d_impl_abstract(
    t: torch.Tensor, dim0: int, dim1s: List[int]
) -> torch.Tensor:
    return t.new_empty([dim0, sum(dim1s)])


@torch.library.custom_op("torchrec::_split_1d_cat_2d_backward_impl", mutates_args=())
def _split_1d_cat_2d_backward_impl(
    grad: torch.Tensor, dim1s: List[int]
) -> torch.Tensor:
    splits = grad.split(dim1s, dim=1)
    return torch.cat([s.reshape(-1) for s in splits], dim=0)


@torch.library.register_fake("torchrec::_split_1d_cat_2d_backward_impl")
def _split_1d_cat_2d_backward_impl_fake(
    grad: torch.Tensor, dim1s: List[int]
) -> torch.Tensor:
    return grad.new_empty([grad.numel()])


# pyre-ignore
def _split_1d_cat_2d_backward(ctx, grad):
    ret = torch.ops.torchrec._split_1d_cat_2d_backward_impl(grad, ctx.dim1s)
    return ret, None, None


# pyre-ignore
def _split_1d_cat_2d_setup_context(ctx, inputs, output):
    (x, dim0, dim1s) = inputs
    ctx.dim1s = dim1s


torch.library.register_autograd(
    "torchrec::_split_1d_cat_2d",
    _split_1d_cat_2d_backward,
    setup_context=_split_1d_cat_2d_setup_context,
)
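
Because these definitions now run unconditionally at import time, the torchrec::* custom ops are always registered and differentiate through the autograd rules above. A small usage sketch (it assumes a torchrec build containing this change; torchrec::_split_1d_cat_2d is used because, unlike the collectives, it needs no process group and runs on CPU):

import torch
import torchrec.distributed.comm_ops  # noqa: F401 -- importing the module registers the torchrec::* ops

dim0, dim1s = 2, [3, 1]
t = torch.arange(8, dtype=torch.float32, requires_grad=True)

# Forward: split t into dim0 * dim1 chunks, reshape each to (dim0, dim1), concat on dim=1.
out = torch.ops.torchrec._split_1d_cat_2d(t, dim0, dim1s)
print(out.shape)  # torch.Size([2, 4])

# Backward dispatches to torchrec::_split_1d_cat_2d_backward_impl via the
# registered autograd rule and flattens the gradient back to t's 1-D layout.
out.sum().backward()
print(t.grad.shape)  # torch.Size([8])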
