# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import ctypes

import numpy as np

rt_path = "libcudart.so"
cuda_runtime = ctypes.cdll.LoadLibrary(rt_path)


def cuda_try(err):
    # Translate a nonzero CUDA runtime status code into a Python exception.
    if err != 0:
        cuda_runtime.cudaGetErrorString.restype = ctypes.c_char_p
        error_string = cuda_runtime.cudaGetErrorString(ctypes.c_int(err)).decode("utf-8")
        raise RuntimeError(f"cuda error code {err}: {error_string}")
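

# Usage sketch: cuda_try wraps any raw runtime call that returns a status
# code, e.g.
#
#     cuda_try(cuda_runtime.cudaDeviceSynchronize())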


class cudaIpcMemHandle_t(ctypes.Structure):
    # Opaque IPC memory handle blob; this wrapper sizes it at 128 bytes.
    _fields_ = [("internal", ctypes.c_byte * 128)]


def open_ipc_handle(ipc_handle_data, rank):
    # `rank` is unused here; kept for a signature symmetric with get_ipc_handle.
    ptr = ctypes.c_void_p()
    cudaIpcMemLazyEnablePeerAccess = ctypes.c_uint(1)
    cuda_runtime.cudaIpcOpenMemHandle.argtypes = [
        ctypes.POINTER(ctypes.c_void_p),
        cudaIpcMemHandle_t,
        ctypes.c_uint,
    ]
    if isinstance(ipc_handle_data, np.ndarray):
        if ipc_handle_data.dtype != np.uint8 or ipc_handle_data.size != 128:
            raise ValueError("ipc_handle_data must be a 128-element uint8 numpy array")
        ipc_handle_bytes = ipc_handle_data.tobytes()
    else:
        raise TypeError("ipc_handle_data must be a numpy.ndarray of dtype uint8 with 128 elements")

    # Rebuild the handle struct from the raw bytes received from the peer
    # process; create_string_buffer already zero-fills, and from_buffer shares
    # the buffer, so the memmove below populates the struct in place.
    raw_memory = ctypes.create_string_buffer(128)
    ipc_handle_struct = cudaIpcMemHandle_t.from_buffer(raw_memory)
    ctypes.memmove(raw_memory, ipc_handle_bytes, 128)

    cuda_try(
        cuda_runtime.cudaIpcOpenMemHandle(
            ctypes.byref(ptr),
            ipc_handle_struct,
            cudaIpcMemLazyEnablePeerAccess,
        )
    )

    return ptr.value


def get_ipc_handle(ptr, rank):
    # `rank` is unused here; the handle identifies an allocation owned by this process.
    ipc_handle = cudaIpcMemHandle_t()
    cuda_try(cuda_runtime.cudaIpcGetMemHandle(ctypes.byref(ipc_handle), ptr))
    return ipc_handle
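

# A minimal two-process sketch (send_to_peer/recv_from_peer are hypothetical;
# the handle bytes would really travel over a pipe, socket, or shared file):
#
#     # producer (rank 0)
#     buf = cuda_malloc(1 << 20)
#     handle = get_ipc_handle(buf, rank=0)
#     handle_bytes = np.frombuffer(bytes(handle.internal), dtype=np.uint8)
#     send_to_peer(handle_bytes)  # hypothetical transport
#
#     # consumer (rank 1)
#     peer_ptr = open_ipc_handle(recv_from_peer(), rank=1)  # hypothetical transport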


def count_devices():
    device_count = ctypes.c_int()
    cuda_try(cuda_runtime.cudaGetDeviceCount(ctypes.byref(device_count)))
    return device_count.value


def set_device(gpu_id):
    cuda_try(cuda_runtime.cudaSetDevice(gpu_id))


def get_device_id():
    device_id = ctypes.c_int()
    cuda_try(cuda_runtime.cudaGetDevice(ctypes.byref(device_id)))
    return device_id.value


def get_cu_count(device_id=None):
    # Number of multiprocessors (compute units); attribute 16 is
    # cudaDevAttrMultiProcessorCount in the CUDA runtime.
    if device_id is None:
        device_id = get_device_id()

    cudaDeviceAttributeMultiprocessorCount = 16
    cu_count = ctypes.c_int()

    cuda_try(
        cuda_runtime.cudaDeviceGetAttribute(ctypes.byref(cu_count), cudaDeviceAttributeMultiprocessorCount, device_id)
    )

    return cu_count.value
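

# Usage sketch, assuming at least one visible device:
#
#     for gpu_id in range(count_devices()):
#         set_device(gpu_id)
#         print(f"device {get_device_id()}: {get_cu_count()} CUs")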


# Starting ROCm 6.5
# def get_xcc_count(device_id=None):
#     if device_id is None:
#         device_id = get_device_id()

#     cudaDeviceAttributeNumberOfXccs = ??
#     xcc_count = ctypes.c_int()

#     cuda_try(cuda_runtime.cudaDeviceGetAttribute(
#         ctypes.byref(xcc_count),
#         cudaDeviceAttributeNumberOfXccs,
#         device_id
#     ))

#     return xcc_count.value


def get_wall_clock_rate(device_id):
    # Note: attribute 36 is cudaDevAttrMemoryClockRate in the CUDA runtime,
    # reported in kHz.
    cudaDevAttrMemoryClockRate = 36
    wall_clock_rate = ctypes.c_int()
    status = cuda_runtime.cudaDeviceGetAttribute(ctypes.byref(wall_clock_rate), cudaDevAttrMemoryClockRate, device_id)
    cuda_try(status)
    return wall_clock_rate.value
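

# Sketch: convert the kHz attribute value to MHz (assumes device 0 exists):
#
#     mhz = get_wall_clock_rate(0) / 1000.0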


def malloc_fine_grained(size):
    # Currently falls back to the default allocator path.
    return cuda_malloc(size)


def cuda_malloc(size):
    ptr = ctypes.c_void_p()
    cuda_try(cuda_runtime.cudaMalloc(ctypes.byref(ptr), size))
    return ptr


def cuda_free(ptr):
    cuda_try(cuda_runtime.cudaFree(ptr))
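

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming a visible GPU and a loadable runtime
    # library: allocate, export an IPC handle, and free.
    print(f"devices: {count_devices()}")
    set_device(0)
    buf = cuda_malloc(1 << 20)  # 1 MiB
    handle = get_ipc_handle(buf, rank=0)
    handle_bytes = np.frombuffer(bytes(handle.internal), dtype=np.uint8)
    print(f"ipc handle starts with: {handle_bytes[:8]}")
    cuda_free(buf)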