Skip to content

Commit 7af2531

Browse files
kernhanda and Lisa Ong authored
Fixes call to __call__ method for CallableFunc subclasses (#98)
Co-authored-by: Lisa Ong <[email protected]>
1 parent 9cefcd7 commit 7af2531

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

hatlib/cuda_loader.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -205,11 +205,11 @@ def cleanup_runtime(self, benchmark: bool, working_dir: str):
205205
cuda.cuCtxDestroy(self.context)
206206

207207
def init_main(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
208-
self.func_info.verify(args[0])
208+
self.func_info.verify(args[0] if benchmark else args)
209209
self.device_mem = allocate_cuda_mem(self.func_info.arguments)
210210

211211
if not benchmark:
212-
transfer_mem_host_to_cuda(device_args=self.device_mem, host_args=args[0], arg_infos=self.func_info.arguments)
212+
transfer_mem_host_to_cuda(device_args=self.device_mem, host_args=args, arg_infos=self.func_info.arguments)
213213

214214
self.ptrs = device_args_to_ptr_list(self.device_mem)
215215

hatlib/host_loader.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -112,24 +112,28 @@ def cleanup_runtime(self, benchmark: bool, working_dir: str):
112112

113113

114114
def init_main(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
115-
self.func_info.verify(args[0])
115+
self.func_info.verify(args[0] if benchmark else args)
116116

117117
for _ in range(warmup_iters):
118-
for arg in args:
119-
self.timer_func(*arg, self.timing_arg_val)
118+
if benchmark:
119+
for arg in args:
120+
self.timer_func(*arg, self.timing_arg_val)
121+
else:
122+
self.timer_func(*args, self.timing_arg_val)
120123

121124
def main(self, benchmark: bool, iters=1, batch_size=1, min_time_in_sec=0, args=[]) -> Tuple[float, float]:
122125
batch_timings_ms: List[float] = []
123126
i = 0
124-
i_max = len(args)
127+
i_max = len(args) if benchmark else 1
125128
iterations = 1
126129
min_time_in_ms = min_time_in_sec * 1000
127130

128131
while sum(batch_timings_ms) < min_time_in_ms and len(batch_timings_ms) < batch_size:
129132
self.timing_arg_val.value = np.zeros((1,))
130133

131134
for _ in range(iters):
132-
self.timer_func(*args[i], self.timing_arg_val)
135+
func_args = args[i] if benchmark else args
136+
self.timer_func(*func_args, self.timing_arg_val)
133137
i = iterations % i_max
134138
iterations += 1
135139

hatlib/rocm_loader.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -123,11 +123,11 @@ def cleanup_runtime(self, benchmark: bool, working_dir: str):
123123
pass
124124

125125
def init_main(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
126-
self.func_info.verify(args[0])
126+
self.func_info.verify(args[0] if benchmark else args)
127127
self.device_mem = allocate_rocm_mem(benchmark, self.func_info.arguments, device_id)
128128

129129
if not benchmark:
130-
transfer_mem_host_to_rocm(device_args=self.device_mem, host_args=args[0], arg_infos=self.func_info.arguments)
130+
transfer_mem_host_to_rocm(device_args=self.device_mem, host_args=args, arg_infos=self.func_info.arguments)
131131

132132
class DataStruct(ctypes.Structure):
133133
_fields_ = [(f"arg{i}", ctypes.c_void_p) for i in range(len(self.func_info.arguments))]

0 commit comments

Comments
 (0)