@@ -54,10 +54,6 @@ def get_gradient_division() -> bool:


 def set_use_sync_collectives(val: bool) -> None:
-    if val and torch._running_with_deploy():
-        raise RuntimeError(
-            "TorchRec sync_collectives are not supported in torch.deploy."
-        )

     global USE_SYNC_COLLECTIVES
     USE_SYNC_COLLECTIVES = val
@@ -2356,202 +2352,213 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
         return (None, None, myreq.dummy_tensor)


-if not torch._running_with_deploy():  # noqa C901
-    # Torch Library op def can not be used in Deploy
-    class AllToAllSingle(torch.autograd.Function):
-        @staticmethod
-        # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
-        def forward(
-            # pyre-fixme[2]: Parameter must be annotated.
-            ctx,
-            input: Tensor,
-            output_split_sizes: List[int],
-            input_split_sizes: List[int],
-            group_name: str,
-            group_size: int,
-            gradient_division: bool,
-        ) -> Tensor:
-            ctx.output_split_sizes = input_split_sizes
-            ctx.input_split_sizes = output_split_sizes
-            ctx.group_name = group_name
-            ctx.group_size = group_size
-            ctx.gradient_division = gradient_division
-            return torch.distributed._functional_collectives.all_to_all_single(
-                input, output_split_sizes, input_split_sizes, group_name
-            )
-
-        @staticmethod
-        # pyre-ignore
-        def backward(ctx, grad):
-            grad = torch.distributed._functional_collectives.all_to_all_single(
-                grad,
-                ctx.output_split_sizes,
-                ctx.input_split_sizes,
-                ctx.group_name,
-            )
-            if ctx.gradient_division:
-                grad.div_(ctx.group_size)
-
-            return grad, None, None, None, None, None
-
-    # torchrec::reduce_scatter_tensor
-    @torch.library.custom_op("torchrec::reduce_scatter_tensor", mutates_args=())
-    def reduce_scatter_tensor(
+class AllToAllSingle(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
+    def forward(
+        # pyre-fixme[2]: Parameter must be annotated.
+        ctx,
         input: Tensor,
-        reduceOp: str,
-        group_size: int,
+        output_split_sizes: List[int],
+        input_split_sizes: List[int],
         group_name: str,
-        gradient_division: bool,
-    ) -> Tensor:
-        out = torch.ops._c10d_functional.reduce_scatter_tensor(
-            input,
-            reduceOp,
-            group_size,
-            group_name,
-        )
-        return torch.ops._c10d_functional.wait_tensor(out)
-
-    @torch.library.register_fake("torchrec::reduce_scatter_tensor")
-    def reduce_scatter_tensor_fake(
-        input: Tensor,
-        reduceOp: str,
         group_size: int,
-        group_name: str,
         gradient_division: bool,
     ) -> Tensor:
-        return torch.ops._c10d_functional.reduce_scatter_tensor(
-            input,
-            reduceOp,
-            group_size,
-            group_name,
-        )
-
-    # pyre-ignore
-    def reduce_scatter_tensor_setup_context(ctx, inputs, output) -> None:
-        _, _, group_size, group_name, gradient_division = inputs
-        ctx.group_size = group_size
+        ctx.output_split_sizes = input_split_sizes
+        ctx.input_split_sizes = output_split_sizes
         ctx.group_name = group_name
+        ctx.group_size = group_size
         ctx.gradient_division = gradient_division
+        return torch.distributed._functional_collectives.all_to_all_single(
+            input, output_split_sizes, input_split_sizes, group_name
+        )

+    @staticmethod
     # pyre-ignore
-    def reduce_scatter_tensor_backward(ctx, grad):
-        # TODO(ivankobzarev): Support codecs(quantization) on backward
-        out = torch.ops._c10d_functional.all_gather_into_tensor(
+    def backward(ctx, grad):
+        grad = torch.distributed._functional_collectives.all_to_all_single(
             grad,
-            ctx.group_size,
+            ctx.output_split_sizes,
+            ctx.input_split_sizes,
             ctx.group_name,
         )
-        grad = torch.ops._c10d_functional.wait_tensor(out)
         if ctx.gradient_division:
             grad.div_(ctx.group_size)

         return grad, None, None, None, None, None

-    torch.library.register_autograd(
-        "torchrec::reduce_scatter_tensor",
-        reduce_scatter_tensor_backward,
-        setup_context=reduce_scatter_tensor_setup_context,
+
+# torchrec::reduce_scatter_tensor
+@torch.library.custom_op("torchrec::reduce_scatter_tensor", mutates_args=())
+def reduce_scatter_tensor(
+    input: Tensor,
+    reduceOp: str,
+    group_size: int,
+    group_name: str,
+    gradient_division: bool,
+) -> Tensor:
+    out = torch.ops._c10d_functional.reduce_scatter_tensor(
+        input,
+        reduceOp,
+        group_size,
+        group_name,
     )
+    return torch.ops._c10d_functional.wait_tensor(out)

-    # torchrec::all_gather_into_tensor
-    @torch.library.custom_op("torchrec::all_gather_into_tensor", mutates_args=())
-    def all_gather_into_tensor(
-        shard: Tensor,
-        gather_dim: int,
-        group_size: int,
-        group_name: str,
-        gradient_division: bool,
-    ) -> Tensor:
-        out = torch.ops._c10d_functional.all_gather_into_tensor(
-            shard, group_size, group_name
-        )
-        return torch.ops._c10d_functional.wait_tensor(out)

-    @torch.library.register_fake("torchrec::all_gather_into_tensor")
-    def all_gather_into_tensor_fake(
-        shard: Tensor,
-        gather_dim: int,
-        group_size: int,
-        group_name: str,
-        gradient_division: bool,
-    ) -> Tensor:
-        return torch.ops._c10d_functional.all_gather_into_tensor(
-            shard, group_size, group_name
-        )
+@torch.library.register_fake("torchrec::reduce_scatter_tensor")
+def reduce_scatter_tensor_fake(
+    input: Tensor,
+    reduceOp: str,
+    group_size: int,
+    group_name: str,
+    gradient_division: bool,
+) -> Tensor:
+    return torch.ops._c10d_functional.reduce_scatter_tensor(
+        input,
+        reduceOp,
+        group_size,
+        group_name,
+    )

-    # pyre-ignore
-    def all_gather_into_tensor_setup_context(ctx, inputs, output) -> None:
-        _, gather_dim, group_size, group_name, gradient_division = inputs
-        ctx.group_size = group_size
-        ctx.group_name = group_name
-        ctx.gradient_division = gradient_division

-    # pyre-ignore
-    def all_gather_into_tensor_backward(ctx, grad):
-        # TODO(ivankobzarev): Support codecs(quantization) on backward
-        out = torch.ops._c10d_functional.reduce_scatter_tensor(
-            grad,
-            "sum",
-            ctx.group_size,
-            ctx.group_name,
-        )
-        grad = torch.ops._c10d_functional.wait_tensor(out)
-        if ctx.gradient_division:
-            grad.div_(ctx.group_size)
+# pyre-ignore
+def reduce_scatter_tensor_setup_context(ctx, inputs, output) -> None:
+    _, _, group_size, group_name, gradient_division = inputs
+    ctx.group_size = group_size
+    ctx.group_name = group_name
+    ctx.gradient_division = gradient_division

-        return grad, None, None, None, None

-    torch.library.register_autograd(
-        "torchrec::all_gather_into_tensor",
-        all_gather_into_tensor_backward,
-        setup_context=all_gather_into_tensor_setup_context,
+# pyre-ignore
+def reduce_scatter_tensor_backward(ctx, grad):
+    # TODO(ivankobzarev): Support codecs(quantization) on backward
+    out = torch.ops._c10d_functional.all_gather_into_tensor(
+        grad,
+        ctx.group_size,
+        ctx.group_name,
     )
+    grad = torch.ops._c10d_functional.wait_tensor(out)
+    if ctx.gradient_division:
+        grad.div_(ctx.group_size)

-    @torch.library.custom_op("torchrec::_split_1d_cat_2d", mutates_args=())
-    def _split_1d_cat_2d_impl(
-        t: torch.Tensor, dim0: int, dim1s: List[int]
-    ) -> torch.Tensor:
-        torch._check_is_size(dim0)
-        [torch._check_is_size(dim1) for dim1 in dim1s]
-        splits: List[torch.Tensor] = t.split([dim0 * dim1 for dim1 in dim1s])
-        return torch.cat(
-            [s.reshape(dim0, dim1) for s, dim1 in zip(splits, dim1s)],
-            dim=1,
-        )
+    return grad, None, None, None, None, None

-    @torch.library.register_fake("torchrec::_split_1d_cat_2d")
-    def _split_1d_cat_2d_impl_abstract(
-        t: torch.Tensor, dim0: int, dim1s: List[int]
-    ) -> torch.Tensor:
-        return t.new_empty([dim0, sum(dim1s)])

-    @torch.library.custom_op(
-        "torchrec::_split_1d_cat_2d_backward_impl", mutates_args=()
+torch.library.register_autograd(
+    "torchrec::reduce_scatter_tensor",
+    reduce_scatter_tensor_backward,
+    setup_context=reduce_scatter_tensor_setup_context,
+)
+
+
+# torchrec::all_gather_into_tensor
+@torch.library.custom_op("torchrec::all_gather_into_tensor", mutates_args=())
+def all_gather_into_tensor(
+    shard: Tensor,
+    gather_dim: int,
+    group_size: int,
+    group_name: str,
+    gradient_division: bool,
+) -> Tensor:
+    out = torch.ops._c10d_functional.all_gather_into_tensor(
+        shard, group_size, group_name
     )
-    def _split_1d_cat_2d_backward_impl(
-        grad: torch.Tensor, dim1s: List[int]
-    ) -> torch.Tensor:
-        splits = grad.split(dim1s, dim=1)
-        return torch.cat([s.reshape(-1) for s in splits], dim=0)
-
-    @torch.library.register_fake("torchrec::_split_1d_cat_2d_backward_impl")
-    def _split_1d_cat_2d_backward_impl_fake(
-        grad: torch.Tensor, dim1s: List[int]
-    ) -> torch.Tensor:
-        return grad.new_empty([grad.numel()])
+    return torch.ops._c10d_functional.wait_tensor(out)

-    # pyre-ignore
-    def _split_1d_cat_2d_backward(ctx, grad):
-        ret = torch.ops.torchrec._split_1d_cat_2d_backward_impl(grad, ctx.dim1s)
-        return ret, None, None

-    # pyre-ignore
-    def _split_1d_cat_2d_setup_context(ctx, inputs, output):
-        (x, dim0, dim1s) = inputs
-        ctx.dim1s = dim1s
-
-    torch.library.register_autograd(
-        "torchrec::_split_1d_cat_2d",
-        _split_1d_cat_2d_backward,
-        setup_context=_split_1d_cat_2d_setup_context,
+@torch.library.register_fake("torchrec::all_gather_into_tensor")
+def all_gather_into_tensor_fake(
+    shard: Tensor,
+    gather_dim: int,
+    group_size: int,
+    group_name: str,
+    gradient_division: bool,
+) -> Tensor:
+    return torch.ops._c10d_functional.all_gather_into_tensor(
+        shard, group_size, group_name
+    )
+
+
+# pyre-ignore
+def all_gather_into_tensor_setup_context(ctx, inputs, output) -> None:
+    _, gather_dim, group_size, group_name, gradient_division = inputs
+    ctx.group_size = group_size
+    ctx.group_name = group_name
+    ctx.gradient_division = gradient_division
+
+
+# pyre-ignore
+def all_gather_into_tensor_backward(ctx, grad):
+    # TODO(ivankobzarev): Support codecs(quantization) on backward
+    out = torch.ops._c10d_functional.reduce_scatter_tensor(
+        grad,
+        "sum",
+        ctx.group_size,
+        ctx.group_name,
+    )
+    grad = torch.ops._c10d_functional.wait_tensor(out)
+    if ctx.gradient_division:
+        grad.div_(ctx.group_size)
+
+    return grad, None, None, None, None
+
+
+torch.library.register_autograd(
+    "torchrec::all_gather_into_tensor",
+    all_gather_into_tensor_backward,
+    setup_context=all_gather_into_tensor_setup_context,
+)
+
+
+@torch.library.custom_op("torchrec::_split_1d_cat_2d", mutates_args=())
+def _split_1d_cat_2d_impl(t: torch.Tensor, dim0: int, dim1s: List[int]) -> torch.Tensor:
+    torch._check_is_size(dim0)
+    [torch._check_is_size(dim1) for dim1 in dim1s]
+    splits: List[torch.Tensor] = t.split([dim0 * dim1 for dim1 in dim1s])
+    return torch.cat(
+        [s.reshape(dim0, dim1) for s, dim1 in zip(splits, dim1s)],
+        dim=1,
     )
+
+
+@torch.library.register_fake("torchrec::_split_1d_cat_2d")
+def _split_1d_cat_2d_impl_abstract(
+    t: torch.Tensor, dim0: int, dim1s: List[int]
+) -> torch.Tensor:
+    return t.new_empty([dim0, sum(dim1s)])
+
+
+@torch.library.custom_op("torchrec::_split_1d_cat_2d_backward_impl", mutates_args=())
+def _split_1d_cat_2d_backward_impl(
+    grad: torch.Tensor, dim1s: List[int]
+) -> torch.Tensor:
+    splits = grad.split(dim1s, dim=1)
+    return torch.cat([s.reshape(-1) for s in splits], dim=0)
+
+
+@torch.library.register_fake("torchrec::_split_1d_cat_2d_backward_impl")
+def _split_1d_cat_2d_backward_impl_fake(
+    grad: torch.Tensor, dim1s: List[int]
+) -> torch.Tensor:
+    return grad.new_empty([grad.numel()])
+
+
+# pyre-ignore
+def _split_1d_cat_2d_backward(ctx, grad):
+    ret = torch.ops.torchrec._split_1d_cat_2d_backward_impl(grad, ctx.dim1s)
+    return ret, None, None
+
+
+# pyre-ignore
+def _split_1d_cat_2d_setup_context(ctx, inputs, output):
+    (x, dim0, dim1s) = inputs
+    ctx.dim1s = dim1s
+
+
+torch.library.register_autograd(
+    "torchrec::_split_1d_cat_2d",
+    _split_1d_cat_2d_backward,
+    setup_context=_split_1d_cat_2d_setup_context,
+)
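
The registrations moved above all follow the same torch.library pattern: an eager kernel defined with torch.library.custom_op, a fake (meta) kernel registered via torch.library.register_fake so the op can be traced and compiled, and a backward plus setup_context attached with torch.library.register_autograd. With this change the pattern runs unconditionally at import time instead of being skipped under torch._running_with_deploy(). Below is a minimal, self-contained sketch of that pattern on a hypothetical "demo::scale" op (not part of this diff), assuming PyTorch 2.4+ where these torch.library APIs are available.

import torch


@torch.library.custom_op("demo::scale", mutates_args=())
def scale(t: torch.Tensor, factor: float) -> torch.Tensor:
    # Eager implementation of the op.
    return t * factor


@torch.library.register_fake("demo::scale")
def scale_fake(t: torch.Tensor, factor: float) -> torch.Tensor:
    # Shape/dtype propagation used by tracing and torch.compile.
    return torch.empty_like(t)


def scale_setup_context(ctx, inputs, output) -> None:
    _, factor = inputs
    ctx.factor = factor


def scale_backward(ctx, grad):
    # d(t * factor)/dt = factor; the non-Tensor input gets no gradient.
    return grad * ctx.factor, None


torch.library.register_autograd(
    "demo::scale", scale_backward, setup_context=scale_setup_context
)

x = torch.randn(4, requires_grad=True)
torch.ops.demo.scale(x, 3.0).sum().backward()
print(x.grad)  # tensor([3., 3., 3., 3.])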