
Discovering supported float types per namespace and device #678

@betatim


I find myself wanting to programmatically find out the "highest precision float type" that a particular library supports on a particular device. Concretely: PyTorch and the MPS device (PyTorch's name for the GPU in an Apple M1 (and M2?)). PyTorch does not support float64 on the MPS device, which is how I ended up wanting a way to find out the highest precision float type that is available.
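For comparison: the 2023.12 revision of the array API standard defines an inspection API, `__array_namespace_info__()`, whose `dtypes(device=..., kind=...)` method returns a dict mapping dtype names to the dtypes a namespace supports on a given device. Assuming a namespace implements that API, a device-aware lookup could be sketched as follows. The function name `highest_precision_float_info` is hypothetical, and I have not verified whether any particular implementation (for example the PyTorch one in array-api-compat) actually reports the missing float64 on MPS.

```python
def highest_precision_float_info(xp, device=None):
    """Hypothetical helper: widest real floating dtype `xp` reports for `device`.

    Assumes `xp` implements the inspection API from the 2023.12 revision of the
    array API standard (`__array_namespace_info__`). Whether the library reports
    per-device support accurately is an assumption, not checked here.
    """
    info = xp.__array_namespace_info__()
    # dtypes() returns a dict such as {"float32": <dtype>, "float64": <dtype>}
    floats = info.dtypes(device=device, kind="real floating")
    # Prefer the highest precision dtype defined by the standard, then fall back.
    for name in ("float64", "float32"):
        if name in floats:
            return floats[name]
    raise ValueError(f"no real floating dtype reported for device {device!r}")
```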

```python
import torch
import array_api_compat

x = torch.tensor([1, 2, 3], device="mps", dtype=torch.float32)
# array_namespace() is the current name; get_namespace() is an older alias
xp = array_api_compat.array_namespace(x)
# side quest: is there a better way to get the torch namespace?

x = xp.asarray([1, 2, 3], device="mps", dtype=xp.float32)

# Maybe `can_cast` is the right tool?
xp.can_cast(xp.float32, xp.float64)  # -> True
xp.can_cast(x, xp.float64)  # -> True
```

Presumably both calls to can_cast return True because PyTorch supports float64 in general, and the implementation of can_cast does not look at the device of x? So, at least as it is currently implemented, I think can_cast is not the right tool for finding out whether float64 is available (and falling back to float32 if it is not). The alternative is writing my own highest_precision_float().
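A minimal sketch of such a hypothetical highest_precision_float(). It sidesteps can_cast entirely and simply tries to allocate a tiny array with each candidate dtype on the target device, float64 first. This assumes that requesting an unsupported dtype on a device raises an exception immediately (assumed here for float64 on MPS); a lazily evaluated backend might not fail until later.

```python
def highest_precision_float(xp, device=None):
    """Hypothetical helper: return xp.float64 if usable on `device`, else xp.float32.

    Probes by actually creating a one-element array, because can_cast() only
    consults dtype casting rules and does not take the device into account.
    """
    for dtype in (xp.float64, xp.float32):
        try:
            xp.asarray([0.0], dtype=dtype, device=device)
        except Exception:
            # The exception type is library specific (assumption: PyTorch's MPS
            # backend raises when asked for float64), so catch broadly here.
            continue
        return dtype
    raise ValueError(f"neither float64 nor float32 is usable on device {device!r}")
```

Usage would be along the lines of `dtype = highest_precision_float(xp, device="mps")`. Each call pays for a small allocation, so a real helper might cache the result per namespace and device.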

Metadata

Assignees: No one assigned

Labels: Duplicate (this issue or pull request already exists)
