6 changes: 6 additions & 0 deletions src/lightning/fabric/accelerators/accelerator.py
@@ -56,6 +56,12 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         """Detect if the hardware is available."""
 
+    @staticmethod
+    @abstractmethod
+    def name() -> str:
+        """The name of the accelerator."""
+
     @classmethod
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         """Register the accelerator with the registry."""
         pass
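Because `name()` is now abstract, every `Accelerator` subclass — including third-party ones — has to provide it. Below is a minimal sketch of what a downstream implementation might look like; `MyAccelerator` and its trivial method bodies are hypothetical, and the remaining abstract methods (`setup_device`, `teardown`, `parse_devices`, `get_parallel_devices`, `auto_device_count`, `is_available`) are assumed from the existing fabric base class rather than shown in this diff:

```python
from typing import Any

import torch
from typing_extensions import override

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry


class MyAccelerator(Accelerator):
    """Hypothetical third-party accelerator; only `name()` is new to the interface."""

    @override
    def setup_device(self, device: torch.device) -> None:
        pass  # no per-device setup needed in this sketch

    @override
    def teardown(self) -> None:
        pass

    @staticmethod
    @override
    def parse_devices(devices: Any) -> Any:
        return devices

    @staticmethod
    @override
    def get_parallel_devices(devices: Any) -> list[torch.device]:
        return [torch.device("cpu")]

    @staticmethod
    @override
    def auto_device_count() -> int:
        return 1

    @staticmethod
    @override
    def is_available() -> bool:
        return True

    @staticmethod
    @override
    def name() -> str:
        # Single source of truth for the registry key.
        return "my_accelerator"

    @classmethod
    @override
    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
        # Mirrors the built-in pattern in the files below: register under cls.name().
        accelerator_registry.register(cls.name(), cls, description=cls.__name__)
```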
7 changes: 6 additions & 1 deletion src/lightning/fabric/accelerators/cpu.py
@@ -62,11 +62,16 @@ def is_available() -> bool:
         """CPU is always available for execution."""
         return True
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "cpu"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
-            "cpu",
+            cls.name(),
             cls,
             description=cls.__name__,
         )
7 changes: 6 additions & 1 deletion src/lightning/fabric/accelerators/cuda.py
@@ -66,11 +66,16 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return num_cuda_devices() > 0
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "cuda"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
-            "cuda",
+            cls.name(),
             cls,
             description=cls.__name__,
         )
7 changes: 6 additions & 1 deletion src/lightning/fabric/accelerators/mps.py
@@ -74,11 +74,16 @@ def is_available() -> bool:
         mps_disabled = os.getenv("DISABLE_MPS", "0") == "1"
         return not mps_disabled and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64")
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "mps"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
-            "mps",
+            cls.name(),
             cls,
             description=cls.__name__,
         )
11 changes: 10 additions & 1 deletion src/lightning/fabric/accelerators/xla.py
@@ -93,10 +93,19 @@ def is_available() -> bool:
             # when `torch_xla` is imported but not used
             return False
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "tpu"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
-        accelerator_registry.register("tpu", cls, description=cls.__name__)
+        accelerator_registry.register(
+            cls.name(),
+            cls,
+            description=cls.__name__,
+        )
 
 
 # PJRT support requires this minimum version
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Add MPS accelerator support for mixed precision ([#21209](https://github.com/Lightning-AI/pytorch-lightning/pull/21209))
 
+- Add `name()` function to accelerator interface ([#21325](https://github.com/Lightning-AI/pytorch-lightning/pull/21325))
+
 
 ### Removed
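In user-facing terms, the registry key for each built-in accelerator is now discoverable from the class itself instead of living as a string literal inside each `register_accelerators` implementation. A quick sketch of what that enables, using only names that appear in the diffs above:

```python
from lightning.fabric.accelerators.cpu import CPUAccelerator
from lightning.fabric.accelerators.cuda import CUDAAccelerator

# The registry key and the class are now linked through one method.
assert CPUAccelerator.name() == "cpu"
assert CUDAAccelerator.name() == "cuda"
```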
9 changes: 7 additions & 2 deletions src/lightning/pytorch/accelerators/cpu.py
@@ -17,8 +17,8 @@
 from lightning_utilities.core.imports import RequirementCache
 from typing_extensions import override
 
-from lightning.fabric.accelerators import _AcceleratorRegistry
 from lightning.fabric.accelerators.cpu import _parse_cpu_cores
+from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.accelerators.accelerator import Accelerator
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -71,11 +71,16 @@ def is_available() -> bool:
         """CPU is always available for execution."""
         return True
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "cpu"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
-            "cpu",
+            cls.name(),
             cls,
             description=cls.__name__,
         )
9 changes: 7 additions & 2 deletions src/lightning/pytorch/accelerators/cuda.py
@@ -21,8 +21,8 @@
 from typing_extensions import override
 
 import lightning.pytorch as pl
-from lightning.fabric.accelerators import _AcceleratorRegistry
 from lightning.fabric.accelerators.cuda import _check_cuda_matmul_precision, _clear_cuda_memory, num_cuda_devices
+from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.utilities.device_parser import _parse_gpu_ids
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.accelerators.accelerator import Accelerator
@@ -104,11 +104,16 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return num_cuda_devices() > 0
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "cuda"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
-            "cuda",
+            cls.name(),
             cls,
             description=cls.__name__,
         )
9 changes: 7 additions & 2 deletions src/lightning/pytorch/accelerators/mps.py
@@ -16,8 +16,8 @@
 import torch
 from typing_extensions import override
 
-from lightning.fabric.accelerators import _AcceleratorRegistry
 from lightning.fabric.accelerators.mps import MPSAccelerator as _MPSAccelerator
+from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.utilities.device_parser import _parse_gpu_ids
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.accelerators.accelerator import Accelerator
@@ -78,11 +78,16 @@ def is_available() -> bool:
         """MPS is only available on a machine with the ARM-based Apple Silicon processors."""
         return _MPSAccelerator.is_available()
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "mps"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
-            "mps",
+            cls.name(),
             cls,
             description=cls.__name__,
         )
13 changes: 11 additions & 2 deletions src/lightning/pytorch/accelerators/xla.py
@@ -15,7 +15,7 @@
 
 from typing_extensions import override
 
-from lightning.fabric.accelerators import _AcceleratorRegistry
+from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.accelerators.xla import XLAAccelerator as FabricXLAAccelerator
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.accelerators.accelerator import Accelerator
@@ -49,7 +49,16 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
             "avg. peak memory (MB)": peak_memory,
         }
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "tpu"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
-        accelerator_registry.register("tpu", cls, description=cls.__name__)
+        accelerator_registry.register(
+            cls.name(),
+            cls,
+            description=cls.__name__,
+        )
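One subtlety worth flagging for reviewers: `XLAAccelerator.name()` returns `"tpu"`, the key this class was already registered under, so existing `accelerator="tpu"` string lookups keep resolving. A quick check, a sketch based only on the diff above:

```python
from lightning.pytorch.accelerators.xla import XLAAccelerator

# The XLA accelerator keeps its historical "tpu" registry key.
assert XLAAccelerator.name() == "tpu"
```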
8 changes: 8 additions & 0 deletions tests/tests_fabric/accelerators/test_registry.py
@@ -49,6 +49,10 @@ def auto_device_count():
     def is_available():
         return True
 
+    @staticmethod
+    def name():
+        return "test_accelerator"
+
 
 def test_accelerator_registry_with_new_accelerator():
     accelerator_name = "custom_accelerator"
@@ -85,6 +89,10 @@ def auto_device_count():
         def is_available():
             return True
 
+        @staticmethod
+        def name():
+            return "custom_accelerator"
+
     ACCELERATOR_REGISTRY.register(
         accelerator_name, CustomAccelerator, description=accelerator_description, param1="abc", param2=123
     )
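To see the contract these tests rely on end to end: built-in accelerators register themselves under `cls.name()`, and the global registry resolves those keys. A hedged sketch — `ACCELERATOR_REGISTRY.register` appears in the test above, while the dict-style membership check is an assumption about `_AcceleratorRegistry`'s interface:

```python
from lightning.fabric.accelerators import ACCELERATOR_REGISTRY
from lightning.fabric.accelerators.cpu import CPUAccelerator

# Built-ins register under cls.name(); the registry is assumed to behave
# like a mapping keyed by those names.
assert CPUAccelerator.name() in ACCELERATOR_REGISTRY
```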