From 9bdedd0d5c0efb48af3b25b6cb5a808feb4cde5a Mon Sep 17 00:00:00 2001 From: AlexandrByzov Date: Tue, 3 Jun 2025 18:20:55 +0200 Subject: [PATCH 1/6] bugfix: add support for global_ordinal, local_ordinal, world_size in xla --- src/lightning/fabric/plugins/environments/xla.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/lightning/fabric/plugins/environments/xla.py b/src/lightning/fabric/plugins/environments/xla.py index a227d2322b9a3..b8350872f22d9 100644 --- a/src/lightning/fabric/plugins/environments/xla.py +++ b/src/lightning/fabric/plugins/environments/xla.py @@ -66,6 +66,11 @@ def world_size(self) -> int: The output is cached for performance. """ + if _XLA_GREATER_EQUAL_2_1: + from torch_xla import runtime as xr + + return xr.world_size() + import torch_xla.core.xla_model as xm return xm.xrt_world_size() @@ -82,6 +87,11 @@ def global_rank(self) -> int: The output is cached for performance. """ + if _XLA_GREATER_EQUAL_2_1: + from torch_xla import runtime as xr + + return xr.global_ordinal() + import torch_xla.core.xla_model as xm return xm.get_ordinal() @@ -98,6 +108,11 @@ def local_rank(self) -> int: The output is cached for performance. """ + if _XLA_GREATER_EQUAL_2_1: + from torch_xla import runtime as xr + + return xr.local_ordinal() + import torch_xla.core.xla_model as xm return xm.get_local_ordinal() From 4e8e86c4d2dd510cc1acf955b587cac616c3058d Mon Sep 17 00:00:00 2001 From: AlexandrByzov Date: Tue, 3 Jun 2025 18:29:30 +0200 Subject: [PATCH 2/6] docs: update changelog --- src/lightning/fabric/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index f67eec28deeeb..0b5a653b68835 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - +### Fixed + +- Fix XLA strategy to add support for for global_ordinal, local_ordinal, world_size which came instead of deprecated methods ([#20852](https://github.com/Lightning-AI/pytorch-lightning/issues/20852)) --- From ee41f8214c956f65c3d6c42aec66519f6b099eb5 Mon Sep 17 00:00:00 2001 From: AlexandrByzov Date: Sun, 15 Jun 2025 16:46:56 +0200 Subject: [PATCH 3/6] feat: add tests for world_size, global_ordinal, local_ordinal --- .../plugins/environments/test_xla.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py index 7e33d5db87dd4..adbc49665a6f0 100644 --- a/tests/tests_fabric/plugins/environments/test_xla.py +++ b/tests/tests_fabric/plugins/environments/test_xla.py @@ -97,3 +97,34 @@ def test_detect(monkeypatch): monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: True) assert XLAEnvironment.detect() + + +@mock.patch.dict(os.environ, {}, clear=True) +def test_attributes_from_xla_greater_21_used(xla_available, monkeypatch): + """Test XLA environment attributes when using XLA runtime >= 2.1.""" + monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_1", True) + monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_GREATER_EQUAL_2_1", True) + + env = XLAEnvironment() + + with ( + mock.patch("torch_xla.runtime.world_size", return_value=4), + mock.patch("torch_xla.runtime.global_ordinal", return_value=2), + mock.patch("torch_xla.runtime.local_ordinal", return_value=1), + ): + env.world_size.cache_clear() + env.global_rank.cache_clear() + env.local_rank.cache_clear() + + assert env.world_size() == 4 + assert env.global_rank() == 2 + assert env.local_rank() == 1 + + env.set_world_size(100) + assert env.world_size() == 4 + + env.set_global_rank(100) + assert env.global_rank() == 2 + + env.set_local_rank(100) + assert env.local_rank() == 1 From 13cf1d02c56d0eea91e9715ebc081b362381a747 Mon Sep 17 00:00:00 2001 From: AlexandrByzov Date: Sun, 15 Jun 2025 16:51:55 +0200 Subject: [PATCH 4/6] fix: remove set local rank --- tests/tests_fabric/plugins/environments/test_xla.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py index adbc49665a6f0..b782db0cccf0b 100644 --- a/tests/tests_fabric/plugins/environments/test_xla.py +++ b/tests/tests_fabric/plugins/environments/test_xla.py @@ -125,6 +125,3 @@ def test_attributes_from_xla_greater_21_used(xla_available, monkeypatch): env.set_global_rank(100) assert env.global_rank() == 2 - - env.set_local_rank(100) - assert env.local_rank() == 1 From 85bbdce9d2ad2125c457ba5e9d89800a9497af06 Mon Sep 17 00:00:00 2001 From: Alex Byzov Date: Fri, 20 Jun 2025 11:58:44 +0200 Subject: [PATCH 5/6] Update tests/tests_fabric/plugins/environments/test_xla.py Co-authored-by: Bhimraj Yadav --- tests/tests_fabric/plugins/environments/test_xla.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py index b782db0cccf0b..4e624c7a2306c 100644 --- a/tests/tests_fabric/plugins/environments/test_xla.py +++ b/tests/tests_fabric/plugins/environments/test_xla.py @@ -100,10 +100,10 @@ def test_detect(monkeypatch): @mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) +@mock.patch("(lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) def test_attributes_from_xla_greater_21_used(xla_available, monkeypatch): """Test XLA environment attributes when using XLA runtime >= 2.1.""" - monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_1", True) - monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_GREATER_EQUAL_2_1", True) env = XLAEnvironment() From 140ac0b00475d0df81c28cf079c7e60012e17296 Mon Sep 17 00:00:00 2001 From: Alex Byzov Date: Fri, 20 Jun 2025 12:22:29 +0200 Subject: [PATCH 6/6] Update tests/tests_fabric/plugins/environments/test_xla.py Co-authored-by: Bhimraj Yadav --- tests/tests_fabric/plugins/environments/test_xla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py index 4e624c7a2306c..f6a24792a4316 100644 --- a/tests/tests_fabric/plugins/environments/test_xla.py +++ b/tests/tests_fabric/plugins/environments/test_xla.py @@ -101,7 +101,7 @@ def test_detect(monkeypatch): @mock.patch.dict(os.environ, {}, clear=True) @mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) -@mock.patch("(lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) +@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) def test_attributes_from_xla_greater_21_used(xla_available, monkeypatch): """Test XLA environment attributes when using XLA runtime >= 2.1."""