I find myself wanting to find out programmatically what the highest-precision float type is that a particular library supports on a particular device. Concretely: PyTorch and the MPS device (their name for the GPU in an Apple M1 (and M2?)). The MPS device doesn't support `float64`, which is how I ended up wanting something that tells me the highest-precision float type available.
import torch
import array_api_compat
x = torch.tensor([1,2,3], device="mps", dtype=torch.float32)
xp = array_api_compat.get_namespace(x)
# side quest: is there a better way to get the torch namespace?
x = xp.asarray([1,2,3], device="mps", dtype=torch.float32)
# Maybe `can_cast` is the right tool?
xp.can_cast(xp.float32, xp.float64) # -> True
xp.can_cast(x, xp.float64) # -> True
Presumably the two calls to `can_cast` return `True` because PyTorch supports `float64` in general and the implementation of `can_cast` does not inspect the device of `x`? So, at least as it is currently implemented, I think `can_cast` is not the right tool for finding out whether `float64` exists and falling back to `float32` if not, or for building my own `highest_precision_float()`.
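To make the question concrete, here is a minimal sketch of what such a `highest_precision_float()` could look like if written by hand. It is a hypothetical helper, not part of `array_api_compat` or the standard. It assumes that the namespace raises an exception when asked to allocate an array with a dtype the device does not support, and that `xp.zeros` accepts `dtype` and `device` keywords.

```python
def highest_precision_float(xp, device=None):
    """Return the widest real floating dtype that `xp` can allocate on `device`.

    Hypothetical helper: it probes by allocation instead of asking the
    library, because `can_cast` does not take the device into account.
    """
    # Candidates ordered from widest to narrowest.
    for name in ("float64", "float32"):
        dtype = getattr(xp, name, None)
        if dtype is None:
            # The namespace does not define this dtype at all.
            continue
        try:
            # Allocate a one-element array on the target device.
            xp.zeros((1,), dtype=dtype, device=device)
        except Exception:
            # Assumption: an unsupported dtype/device pair raises here.
            continue
        return dtype
    raise RuntimeError(f"no supported float dtype found for device {device!r}")
```

With `xp` from the snippet above, `highest_precision_float(xp, device="mps")` would be expected to return `xp.float32`, provided PyTorch does raise when creating a `float64` tensor on MPS. Probing by allocation is a workaround and catches any exception, so it could also hide unrelated errors such as a misspelled device name.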