diff --git a/extend_distributed.py b/extend_distributed.py
index 1f2c8a53..c6b2b9c9 100644
--- a/extend_distributed.py
+++ b/extend_distributed.py
@@ -164,7 +164,7 @@ def init_distributed(rank=-1, local_rank=-1, size=-1, use_gpu=False, backend="")
     print("Running on %d ranks using %s backend" % (my_size, backend))
     if hasattr(dist, "all_to_all_single"):
         try:
-            t = torch.zeros([4])
+            t = torch.zeros([1024])
             if use_gpu:
                 t = t.cuda()
             dist.all_to_all_single(t, t)