66import unittest
77
88import torch
9- from torchcomms ._comms_ncclx import RdmaTransport
9+ from torchcomms ._comms_ncclx import RdmaMemory , RdmaTransport
1010
1111
1212class TransportTest (unittest .TestCase ):
@@ -17,28 +17,30 @@ def setUp(self):
1717 def test_construct (self ) -> None :
1818 _ = RdmaTransport (torch .device ("cuda:0" ))
1919
20- def test_bind_and_connect (self ) -> None :
21- if torch .cuda .device_count () < 2 :
22- self .skipTest (
23- f"Test requires at least 2 CUDA devices, found { torch .cuda .device_count ()} "
24- )
20+ def test_rdma_memory_from_tensor (self ) -> None :
21+ tensor = torch .arange (1024 , dtype = torch .uint8 , device = "cuda:0" )
22+ compare_tensor = torch .zeros_like (tensor , device = "cuda:1" )
2523
26- server_device = torch . device ( "cuda:0" )
27- client_device = torch . device ( "cuda:1" )
24+ tensor_mem = RdmaMemory ( tensor )
25+ compare_mem = RdmaMemory ( compare_tensor )
2826
29- server_transport = RdmaTransport ( server_device )
30- client_transport = RdmaTransport ( client_device )
27+ tensor_view = tensor_mem . to_view ( )
28+ compare_view = compare_mem . to_view ( )
3129
32- server_url = server_transport .bind ()
33- client_url = client_transport .bind ()
30+ self .assertEqual (tensor_view .size (), tensor .nbytes )
31+ self .assertAlmostEqual (tensor_view .size (), compare_view .size ())
32+
33+ def bind_and_connect (self , server : RdmaTransport , client : RdmaTransport ) -> None :
34+ server_url = server .bind ()
35+ client_url = client .bind ()
3436
3537 self .assertIsNotNone (server_url )
3638 self .assertIsNotNone (client_url )
3739 self .assertNotEqual (server_url , "" )
3840 self .assertNotEqual (client_url , "" )
3941
40- server_result = server_transport .connect (client_url )
41- client_result = client_transport .connect (server_url )
42+ server_result = server .connect (client_url )
43+ client_result = client .connect (server_url )
4244
4345 self .assertEqual (
4446 server_result , 0 , "Server connect should return commSuccess (0)"
@@ -47,8 +49,86 @@ def test_bind_and_connect(self) -> None:
4749 client_result , 0 , "Client connect should return commSuccess (0)"
4850 )
4951
50- self .assertTrue (server_transport .connected ())
51- self .assertTrue (client_transport .connected ())
52+ self .assertTrue (server .connected ())
53+ self .assertTrue (client .connected ())
54+
55+ def test_bind_and_connect (self ) -> None :
56+ if torch .cuda .device_count () < 2 :
57+ self .skipTest (
58+ f"Test requires at least 2 CUDA devices, found { torch .cuda .device_count ()} "
59+ )
60+
61+ server_device = torch .device ("cuda:0" )
62+ client_device = torch .device ("cuda:1" )
63+
64+ server_transport = RdmaTransport (server_device )
65+ client_transport = RdmaTransport (client_device )
66+
67+ self .bind_and_connect (server_transport , client_transport )
68+
    def run_send_recv(
        self,
        device1: str,
        device2: str,
    ) -> None:
        """One-way RDMA write of a 1 KiB pattern from device1 memory to device2 memory.

        Args:
            device1: Source tensor device ("cpu" or a "cuda:N" string).
            device2: Destination tensor device ("cpu" or a "cuda:N" string).

        CPU-resident tensors are still driven through a transport pinned to
        cuda:0, since RdmaTransport is constructed with a CUDA device.
        """
        # The transport always lives on a GPU even when the tensor is on CPU.
        transport_device_1 = "cuda:0" if device1 == "cpu" else device1
        transport_device_2 = "cuda:0" if device2 == "cpu" else device2
        transport1 = RdmaTransport(torch.device(transport_device_1))
        transport2 = RdmaTransport(torch.device(transport_device_2))

        self.bind_and_connect(transport1, transport2)

        # Known pattern (0..255 repeating via uint8 wraparound of arange) vs zeros,
        # so a successful write is observable as the destination matching the source.
        tensor1 = torch.arange(1024, dtype=torch.uint8, device=device1)
        tensor2 = torch.zeros_like(tensor1, device=device2)

        self.assertEqual(tensor1.nbytes, tensor2.nbytes)

        # Register both tensors for RDMA access.
        tensor1_mem = RdmaMemory(tensor1)
        tensor2_mem = RdmaMemory(tensor2)

        # One-sided write: local view pushed into the peer's remote buffer.
        # NOTE(review): data is compared immediately after write() returns, which
        # presumes write() completes synchronously — confirm against the binding.
        res = transport1.write(tensor1_mem.to_view(), tensor2_mem.to_remote_buffer())

        self.assertEqual(res, 0)
        self.assertTrue(torch.allclose(tensor1.cpu(), tensor2.cpu()))

        # Explicit teardown order: transports first, then registered memories.
        # Presumably intentional for RDMA deregistration ordering — TODO confirm.
        del transport1
        del transport2
        del tensor1_mem
        del tensor2_mem
99+ def check_multi_gpu (self ) -> None :
100+ if torch .cuda .device_count () < 2 :
101+ self .skipTest (
102+ f"Test requires at least 2 CUDA devices, found { torch .cuda .device_count ()} "
103+ )
104+
    def test_write_gpu_to_gpu(self) -> None:
        """RDMA write with source on cuda:0 and destination on cuda:1."""
        self.check_multi_gpu()
        self.run_send_recv("cuda:0", "cuda:1")
    def test_write_gpu_to_gpu_2(self) -> None:
        """RDMA write where source and destination both live on cuda:0.

        NOTE(review): this path only touches cuda:0, yet check_multi_gpu still
        requires two GPUs — possibly stricter than necessary; confirm intent.
        """
        self.check_multi_gpu()
        self.run_send_recv("cuda:0", "cuda:0")
    def test_write_cpu_to_gpu(self) -> None:
        """RDMA write from CPU memory (transport on cuda:0) into cuda:1 memory."""
        self.check_multi_gpu()
        self.run_send_recv("cpu", "cuda:1")
    def test_write_cpu_to_gpu_2(self) -> None:
        """RDMA write from CPU memory into cuda:0 memory.

        NOTE(review): only cuda:0 is used here, but check_multi_gpu still
        demands two GPUs — possibly stricter than necessary; confirm intent.
        """
        self.check_multi_gpu()
        self.run_send_recv("cpu", "cuda:0")
    def test_write_gpu_to_cpu(self) -> None:
        """RDMA write from cuda:1 memory into CPU memory (transport on cuda:0)."""
        self.check_multi_gpu()
        self.run_send_recv("cuda:1", "cpu")
    def test_write_gpu_to_cpu_2(self) -> None:
        """RDMA write from cuda:0 memory into CPU memory.

        NOTE(review): only cuda:0 is used here, but check_multi_gpu still
        demands two GPUs — possibly stricter than necessary; confirm intent.
        """
        self.check_multi_gpu()
        self.run_send_recv("cuda:0", "cpu")
    def test_write_cpu_to_cpu(self) -> None:
        """RDMA write between two CPU buffers, with both transports on cuda:0.

        NOTE(review): no tensor lives on a GPU here, yet check_multi_gpu requires
        two CUDA devices — possibly stricter than necessary; confirm intent.
        """
        self.check_multi_gpu()
        self.run_send_recv("cpu", "cpu")
52132
53133
54134if __name__ == "__main__" and os .environ ["TEST_BACKEND" ] == "ncclx" :
0 commit comments