
Commit d3d0e40

amirafzali authored and facebook-github-bot committed
add rdma tensor write (#36)
Summary:
- Support RDMA tensor transfer operations from Python. I took some liberties with the Python-facing interface; open to feedback if they should be modified.

Reviewed By: d4l3k

Differential Revision: D85887151
1 parent 6f4233a commit d3d0e40
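For orientation, here is a condensed sketch of the Python-facing flow this commit enables, adapted from the new integration test below (a single-process loopback between two GPUs; it uses only what the diff itself adds):

import torch
from torchcomms._comms_ncclx import RdmaMemory, RdmaTransport

# Source data on one GPU, zeroed destination on another.
src = torch.arange(1024, dtype=torch.uint8, device="cuda:0")
dst = torch.zeros_like(src, device="cuda:1")

# One transport per endpoint; exchange bind() URLs and connect (0 == commSuccess).
t1 = RdmaTransport(torch.device("cuda:0"))
t2 = RdmaTransport(torch.device("cuda:1"))
url1, url2 = t1.bind(), t2.bind()
assert t1.connect(url2) == 0 and t2.connect(url1) == 0

# Register both tensors, then do a one-sided write: local view -> remote buffer.
src_mem, dst_mem = RdmaMemory(src), RdmaMemory(dst)
assert t1.write(src_mem.to_view(), dst_mem.to_remote_buffer()) == 0
assert torch.allclose(src.cpu(), dst.cpu())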

File tree: 4 files changed (+153, -17 lines)


comms/torchcomms/ncclx/TorchCommNCCLXPy.cpp

Lines changed: 36 additions & 1 deletion
@@ -29,6 +29,9 @@ PYBIND11_MODULE(_comms_ncclx, m) {
   py::class_<TorchCommNCCLX, std::shared_ptr<TorchCommNCCLX>>(
       m, "TorchCommNCCLX");

+  py::class_<RdmaRemoteBuffer, std::shared_ptr<RdmaRemoteBuffer>>(
+      m, "RdmaRemoteBuffer");
+
   py::class_<RdmaTransport, std::shared_ptr<RdmaTransport>>(m, "RdmaTransport")
       // initialize a new RDMATransport using a custom init fn
       .def(py::init([](at::Device device) {
@@ -45,5 +48,37 @@ PYBIND11_MODULE(_comms_ncclx, m) {
         std::string peerUrlStr = peerUrl.cast<std::string>();
         return static_cast<int>(self.connect(peerUrlStr));
       })
-      .def("connected", &RdmaTransport::connected);
+      .def("connected", &RdmaTransport::connected)
+      .def(
+          "write",
+          [](RdmaTransport& self,
+             const RdmaMemory::View& localBuffer,
+             const RdmaRemoteBuffer& remoteBuffer) {
+            return static_cast<int>(
+                self.write(localBuffer, remoteBuffer, false).get());
+          });
+
+  py::class_<RdmaMemory::View, std::shared_ptr<RdmaMemory::View>>(
+      m, "RdmaMemoryView")
+      .def("size", &RdmaMemory::View::size);
+
+  py::class_<RdmaMemory, std::shared_ptr<RdmaMemory>>(m, "RdmaMemory")
+      .def(py::init([](const at::Tensor& tensor) {
+        TORCH_CHECK(
+            tensor.is_contiguous(),
+            "RdmaMemory currently requires a contiguous tensor");
+        // If CPU memory is passed, use device 0 for NIC discovery
+        const auto device = tensor.get_device() < 0 ? 0 : tensor.get_device();
+        return std::make_shared<RdmaMemory>(
+            tensor.data_ptr(), tensor.nbytes(), device);
+      }))
+      .def(
+          "to_view",
+          [](RdmaMemory& self) {
+            return self.createView(size_t(0), self.length());
+          })
+      .def("to_remote_buffer", [](RdmaMemory& self) {
+        return RdmaRemoteBuffer{
+            const_cast<void*>(self.data()), self.remoteKey()};
+      });
 }
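One caveat visible in the binding above: RdmaMemory TORCH_CHECKs that the tensor is contiguous and wraps the raw data_ptr() without retaining the tensor, so (reading only from this binding) the caller should materialize non-contiguous tensors first and keep the backing tensor alive while the RdmaMemory is in use. A minimal sketch under those assumptions:

import torch
from torchcomms._comms_ncclx import RdmaMemory

t = torch.arange(16, dtype=torch.uint8, device="cuda:0").reshape(4, 4).t()
assert not t.is_contiguous()   # RdmaMemory(t) would fail the TORCH_CHECK above

t_contig = t.contiguous()      # contiguous copy; keep this reference alive
mem = RdmaMemory(t_contig)     # registers [data_ptr, data_ptr + nbytes) on the tensor's device
assert mem.to_view().size() == t_contig.nbytes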

comms/torchcomms/ncclx/_comms_ncclx.pyi

Lines changed: 13 additions & 0 deletions
@@ -5,10 +5,23 @@ import torch

 class TorchCommNCCLX: ...

+class RdmaMemoryView:
+    def size(self) -> int: ...
+
+class RdmaRemoteBuffer: ...
+
+class RdmaMemory:
+    def __init__(self, tensor: torch.Tensor) -> None: ...  # pyre-ignore[11]
+    def to_view(self) -> RdmaMemoryView: ...
+    def to_remote_buffer(self) -> RdmaRemoteBuffer: ...
+
 class RdmaTransport:
     def __init__(self, device: torch.device) -> None: ...  # pyre-ignore[11]
     @staticmethod
     def supported() -> bool: ...
     def bind(self) -> bytes: ...
     def connect(self, peer_url: bytes) -> int: ...
     def connected(self) -> bool: ...
+    def write(
+        self, local_buffer: RdmaMemoryView, remote_buffer: RdmaRemoteBuffer
+    ) -> int: ...
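The stubs are enough to type small wrappers; for example, a hypothetical helper (not part of this commit) that follows the 0 == commSuccess convention used in the tests:

import torch
from torchcomms._comms_ncclx import RdmaMemory, RdmaRemoteBuffer, RdmaTransport

def rdma_push(transport: RdmaTransport, local: torch.Tensor, remote: RdmaRemoteBuffer) -> bool:
    """Hypothetical helper: one-sided write of a contiguous local tensor into a remote buffer."""
    mem = RdmaMemory(local)   # registration; `local` must be contiguous and stay alive
    return transport.write(mem.to_view(), remote) == 0   # 0 == commSuccess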

comms/torchcomms/tests/integration/py/TransportTest.py

Lines changed: 96 additions & 16 deletions
@@ -6,7 +6,7 @@
 import unittest

 import torch
-from torchcomms._comms_ncclx import RdmaTransport
+from torchcomms._comms_ncclx import RdmaMemory, RdmaTransport


 class TransportTest(unittest.TestCase):
@@ -17,28 +17,30 @@ def setUp(self):
     def test_construct(self) -> None:
         _ = RdmaTransport(torch.device("cuda:0"))

-    def test_bind_and_connect(self) -> None:
-        if torch.cuda.device_count() < 2:
-            self.skipTest(
-                f"Test requires at least 2 CUDA devices, found {torch.cuda.device_count()}"
-            )
+    def test_rdma_memory_from_tensor(self) -> None:
+        tensor = torch.arange(1024, dtype=torch.uint8, device="cuda:0")
+        compare_tensor = torch.zeros_like(tensor, device="cuda:1")

-        server_device = torch.device("cuda:0")
-        client_device = torch.device("cuda:1")
+        tensor_mem = RdmaMemory(tensor)
+        compare_mem = RdmaMemory(compare_tensor)

-        server_transport = RdmaTransport(server_device)
-        client_transport = RdmaTransport(client_device)
+        tensor_view = tensor_mem.to_view()
+        compare_view = compare_mem.to_view()

-        server_url = server_transport.bind()
-        client_url = client_transport.bind()
+        self.assertEqual(tensor_view.size(), tensor.nbytes)
+        self.assertAlmostEqual(tensor_view.size(), compare_view.size())
+
+    def bind_and_connect(self, server: RdmaTransport, client: RdmaTransport) -> None:
+        server_url = server.bind()
+        client_url = client.bind()

         self.assertIsNotNone(server_url)
         self.assertIsNotNone(client_url)
         self.assertNotEqual(server_url, "")
         self.assertNotEqual(client_url, "")

-        server_result = server_transport.connect(client_url)
-        client_result = client_transport.connect(server_url)
+        server_result = server.connect(client_url)
+        client_result = client.connect(server_url)

         self.assertEqual(
             server_result, 0, "Server connect should return commSuccess (0)"
@@ -47,8 +49,86 @@ def test_bind_and_connect(self) -> None:
             client_result, 0, "Client connect should return commSuccess (0)"
         )

-        self.assertTrue(server_transport.connected())
-        self.assertTrue(client_transport.connected())
+        self.assertTrue(server.connected())
+        self.assertTrue(client.connected())
+
+    def test_bind_and_connect(self) -> None:
+        if torch.cuda.device_count() < 2:
+            self.skipTest(
+                f"Test requires at least 2 CUDA devices, found {torch.cuda.device_count()}"
+            )
+
+        server_device = torch.device("cuda:0")
+        client_device = torch.device("cuda:1")
+
+        server_transport = RdmaTransport(server_device)
+        client_transport = RdmaTransport(client_device)
+
+        self.bind_and_connect(server_transport, client_transport)
+
+    def run_send_recv(
+        self,
+        device1: str,
+        device2: str,
+    ) -> None:
+        transport_device_1 = "cuda:0" if device1 == "cpu" else device1
+        transport_device_2 = "cuda:0" if device2 == "cpu" else device2
+        transport1 = RdmaTransport(torch.device(transport_device_1))
+        transport2 = RdmaTransport(torch.device(transport_device_2))
+
+        self.bind_and_connect(transport1, transport2)
+
+        tensor1 = torch.arange(1024, dtype=torch.uint8, device=device1)
+        tensor2 = torch.zeros_like(tensor1, device=device2)
+
+        self.assertEqual(tensor1.nbytes, tensor2.nbytes)
+
+        tensor1_mem = RdmaMemory(tensor1)
+        tensor2_mem = RdmaMemory(tensor2)
+
+        res = transport1.write(tensor1_mem.to_view(), tensor2_mem.to_remote_buffer())
+
+        self.assertEqual(res, 0)
+        self.assertTrue(torch.allclose(tensor1.cpu(), tensor2.cpu()))
+
+        del transport1
+        del transport2
+        del tensor1_mem
+        del tensor2_mem
+
+    def check_multi_gpu(self) -> None:
+        if torch.cuda.device_count() < 2:
+            self.skipTest(
+                f"Test requires at least 2 CUDA devices, found {torch.cuda.device_count()}"
+            )
+
+    def test_write_gpu_to_gpu(self) -> None:
+        self.check_multi_gpu()
+        self.run_send_recv("cuda:0", "cuda:1")
+
+    def test_write_gpu_to_gpu_2(self) -> None:
+        self.check_multi_gpu()
+        self.run_send_recv("cuda:0", "cuda:0")
+
+    def test_write_cpu_to_gpu(self) -> None:
+        self.check_multi_gpu()
+        self.run_send_recv("cpu", "cuda:1")
+
+    def test_write_cpu_to_gpu_2(self) -> None:
+        self.check_multi_gpu()
+        self.run_send_recv("cpu", "cuda:0")
+
+    def test_write_gpu_to_cpu(self) -> None:
+        self.check_multi_gpu()
+        self.run_send_recv("cuda:1", "cpu")
+
+    def test_write_gpu_to_cpu_2(self) -> None:
+        self.check_multi_gpu()
+        self.run_send_recv("cuda:0", "cpu")
+
+    def test_write_cpu_to_cpu(self) -> None:
+        self.check_multi_gpu()
+        self.run_send_recv("cpu", "cpu")


 if __name__ == "__main__" and os.environ["TEST_BACKEND"] == "ncclx":
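A brief operational note on the guard in the final context line above: the tests only execute when the environment selects the ncclx backend.

# Assumed invocation (not spelled out in this diff):
#   TEST_BACKEND=ncclx python comms/torchcomms/tests/integration/py/TransportTest.py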

comms/torchcomms/transport/RdmaTransport.h

Lines changed: 8 additions & 0 deletions
@@ -98,6 +98,14 @@ class RdmaMemory : folly::MoveOnly {
     return cudaDev_;
   }

+  size_t length() const {
+    return len_;
+  }
+
+  const void* data() const {
+    return buf_;
+  }
+
   /*
    * Check if the given buffer and length are contained within this memory
    * region.
