This whole thing (i.e., calling `out.cpu()`) is suboptimal. I think we don't need it for JAX (which returns JAX arrays, not numpy arrays), because `np.asarray` works with those, and I guess it doesn't work for torch tensors (at least not CUDA ones).
pytensor/pytensor/link/pytorch/linker.py, line 16 in 7b13a95:

```python
return out.cpu()
```
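To make that concrete, here is a minimal sketch of the interop difference, independent of PyTensor (this is standard torch/JAX behavior, not anything specific to this repo):

```python
import numpy as np
import torch

# CPU torch tensors implement __array__, so np.asarray already works:
t_cpu = torch.tensor([1.0, 2.0])
print(np.asarray(t_cpu))  # [1. 2.]

# CUDA tensors do not: np.asarray raises
#   TypeError: can't convert cuda:0 device type tensor to numpy.
#   Use Tensor.cpu() to copy the tensor to host memory first.
# which is presumably what the unconditional out.cpu() call guards against.
if torch.cuda.is_available():
    t_gpu = t_cpu.to("cuda")
    print(np.asarray(t_gpu.cpu()))  # explicit host transfer needed

# JAX arrays, by contrast, convert transparently:
#   np.asarray(jnp.ones(2))  # works, no device-transfer call required
```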
This should only be needed for updated shared variables, where we have to convert to a common type because they could be used in multiple functions with distinct backends.
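For instance, in the following hypothetical sketch (the `"PYTORCH"` mode string is an assumption about how the backend is registered), the same shared variable backs a torch-compiled and a C-compiled function, so whatever the update writes into its storage must be a type both backends understand:

```python
import numpy as np
import pytensor

x = pytensor.shared(np.zeros(3), name="x")

# Hypothetical: if f_torch left a raw torch tensor in x's storage,
# f_c would have to cope with it; hence the conversion to a common
# (numpy) type when the update is written back.
f_torch = pytensor.function([], [], updates={x: x + 1}, mode="PYTORCH")
f_c = pytensor.function([], x * 2, mode="FAST_RUN")
```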
Perhaps we should expand the TorchLinker a bit to perform the updates itself, and only force conversion when that's the case. This is already supported by `Function`.
pytensor/pytensor/compile/function/types.py, lines 1009 to 1017 in 7b13a95:

```python
if getattr(self.vm, "need_update_inputs", True):
    # Update the inputs that have an update function
    for input, storage in reversed(
        list(zip(self.maker.expanded_inputs, input_storage))
    ):
        if input.update is not None:
            storage.data = outputs.pop()
else:
    outputs = outputs[: self.n_returned_outputs]
```
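A minimal sketch of what that could look like (all names here, `TorchVM` and `update_map`, are illustrative, not existing PyTensor API): a VM that sets `need_update_inputs = False` makes `Function` take the `else` branch above and is then responsible for writing updates back itself, so it can keep torch tensors in storage and convert only when needed:

```python
# Hypothetical sketch only: TorchVM and update_map are illustrative
# names, not part of PyTensor's API.
class TorchVM:
    # Tells Function to take the `else` branch above and leave input
    # updates to the VM itself.
    need_update_inputs = False

    def __init__(self, fn, input_storage, update_map):
        self.fn = fn
        self.input_storage = input_storage
        # Maps input-storage position -> position of the output that
        # holds its updated value.
        self.update_map = update_map

    def __call__(self):
        outputs = self.fn()
        for in_idx, out_idx in self.update_map.items():
            # Keep the raw torch tensor in storage; convert to numpy
            # (the common type) only if this shared variable is also
            # used by functions compiled with other backends.
            self.input_storage[in_idx].data = outputs[out_idx]
        return outputs
```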
Originally posted by @ricardoV94 in #1032 (comment)