Skip to content

Commit 87f6127

Browse files
tjhunter and clessig
authored
[1022] Getting WG to work on santis (#1023)
* working pytorch * changes * Fix for code to work on Alps-Santis * changes * cleanups * changes * reverting change * having issues with the latest branch on santis * changes * changes * changes * override with cpu * working for cpu * flash-attn moved to gpu * remove contstraint * simplifying * trying * working on atos * changes * macos * chanegs * cleanups * actions * actions * actions * actions * changes --------- Co-authored-by: Christian Lessig <[email protected]>
1 parent 349703d commit 87f6127

File tree

8 files changed

+811
-458
lines changed

8 files changed

+811
-458
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
- name: Type checker (pyrefly, experimental)
3232
# Do not attempt to install the default dependencies, this is much faster.
3333
# Run temporarily on a sub directory before the main restyling.
34-
run: ./scripts/actions.sh type-check-experimental || echo "::warning::typing issues found"
34+
run: ./scripts/actions.sh type-check || echo "::warning::typing issues found"
3535
pr:
3636
name: PR checks
3737
runs-on: ubuntu-latest

config/default_config.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ ae_global_num_blocks: 8
2828
ae_global_num_heads: 32
2929
ae_global_dropout_rate: 0.1
3030
ae_global_with_qk_lnorm: True
31-
ae_global_att_dense_rate: 0.2
31+
# TODO: switching to < 1 triggers triton-related issues.
32+
# See https://github.com/ecmwf/WeatherGenerator/issues/1050
33+
ae_global_att_dense_rate: 1.0
3234
ae_global_block_factor: 64
3335
ae_global_mlp_hidden_factor: 2
3436

integration_tests/streams/era5_small.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
ERA5 :
1111
type : anemoi
12-
filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6.zarr']
12+
filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
1313
loss_weight : 1.
1414
source_exclude : ['w_', 'skt', 'sp', 'tcw', 'cp', 'tp']
1515
target_exclude : ['w_', 'skt', 'sp', 'tcw', 'cp', 'tp']

pyproject.toml

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ authors = [
1010
requires-python = ">=3.12,<3.13"
1111
# TODO: split the plotting dependencies into their own dep groups, they are not required.
1212
dependencies = [
13-
'torch==2.6.0',
1413
'numpy~=2.2',
1514
'astropy_healpix~=1.1.2',
1615
'zarr~=2.17',
@@ -22,7 +21,6 @@ dependencies = [
2221
'packaging',
2322
'wheel',
2423
'psutil',
25-
"flash-attn; sys_platform == 'linux'",
2624
"polars~=1.25.2",
2725
"omegaconf~=2.3.0",
2826
"dask~=2025.5.1",
@@ -32,6 +30,7 @@ dependencies = [
3230
"weathergen-evaluate",
3331
]
3432

33+
3534
[project.urls]
3635
Homepage = "https://www.weathergenerator.eu"
3736
Documentation = "https://readthedocs.org"
@@ -66,6 +65,25 @@ dev = [
6665
]
6766

6867

68+
# Torch is listed as an optional dependency.
69+
# uv and python can only filter dependencies by platform, not by capability.
70+
# Following the recommendations from https://docs.astral.sh/uv/guides/integration/pytorch
71+
# We need to support:
72+
# x86_64: cpu (unit tests) + gpu
73+
# aarch64: gpu
74+
[project.optional-dependencies]
75+
76+
cpu = [
77+
'torch==2.6.0',
78+
]
79+
80+
gpu = [
81+
'torch==2.6.0+cu126',
82+
# flash-attn also has a torch dependency.
83+
"flash-attn",
84+
]
85+
86+
6987
[tool.black]
7088

7189
# Wide rows
@@ -125,6 +143,8 @@ ignore = [
125143
line-ending = "lf"
126144

127145

146+
147+
128148
[tool.uv]
129149
# Most work is done on a distributed filesystem, where hardlinks are not always possible.
130150
# Also, trying to resolve some permissions issue, see 44.
@@ -141,14 +161,26 @@ link-mode = "symlink"
141161
# Also, relatively recent versions are required to support workspaces.
142162
required-version = ">=0.7.0"
143163

144-
# Following the recommendations from https://docs.astral.sh/uv/guides/integration/pytorch
145-
# The current setup is:
146-
# linux == GPU + flashattention
147-
# windows == GPU
148-
# macos == CPU
164+
# The supported environments
165+
# TODO: add macos and windows (CPU only, for running tests)
166+
environments = [
167+
"sys_platform == 'linux' and platform_machine == 'aarch64'",
168+
"sys_platform == 'linux' and platform_machine == 'x86_64'",
169+
# "sys_platform == 'darwin'",
170+
]
171+
172+
# One can only have cpu or gpu.
173+
conflicts = [
174+
[
175+
{ extra = "cpu" },
176+
{ extra = "gpu" },
177+
],
178+
]
179+
180+
149181
[[tool.uv.index]]
150-
name = "pytorch-cu124"
151-
url = "https://download.pytorch.org/whl/cu124"
182+
name = "pytorch-cu126"
183+
url = "https://download.pytorch.org/whl/cu126"
152184
explicit = true
153185

154186

@@ -181,14 +213,26 @@ explicit = true
181213
[tool.uv.sources]
182214
weathergen-common = { workspace = true }
183215
weathergen-evaluate = { workspace = true }
184-
torch = [
185-
{ index = "pytorch-cu124", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
186-
{ index = "pytorch-cpu", marker = "sys_platform == 'macosx'"},
187-
]
188-
# This URL was evaluated this way:
189-
# uv run ~/WeatherGenerator-private/hpc/hpc2020/ecmwf/get-flash-atten.sh
216+
217+
190218
flash-attn = [
191-
{ url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl", marker = "sys_platform == 'linux'" },
219+
# The build of Cathal O'Brien is not compatible with the libc build on santis.
220+
# Hardcode the reference to the swiss cluster for the time being.
221+
# TODO: open issue
222+
# { url = "https://github.com/cathalobrien/get-flash-attn/releases/download/v0.1-alpha/flash_attn-2.7.4+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_aarch64.whl", marker = "sys_platform == 'linux' and platform_machine == 'aarch64'" },
223+
# This version was rebuilt locally on santis and uploaded.
224+
{ url = "https://object-store.os-api.cci1.ecmwf.int/weathergenerator-dev/wheels/flash_attn-2.7.3-cp312-cp312-linux_aarch64.whl", marker = "sys_platform == 'linux' and platform_machine == 'aarch64'" },
225+
{ url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp312-cp312-linux_x86_64.whl", marker = "sys_platform == 'linux' and platform_machine == 'x86_64'" },
226+
# { index = "pytorch-cpu", marker = "sys_platform == 'darwin'"},
227+
]
228+
229+
230+
torch = [
231+
# Explicit pin for GPU
232+
{ url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-linux_aarch64.whl", marker = 'sys_platform == "linux" and platform_machine == "aarch64"', extra="gpu" },
233+
{ url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl", marker = 'sys_platform == "linux" and platform_machine == "x86_64"', extra="gpu" },
234+
# Use the public repo for CPU versions.
235+
{ index = "pytorch-cpu", marker = "sys_platform == 'linux'", extra="cpu"},
192236
]
193237

194238
[tool.pytest.ini_options]
@@ -203,3 +247,4 @@ members = [
203247
"packages/evaluate",
204248
"packages/common"
205249
]
250+

scripts/actions.sh

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ case "$1" in
77
sync)
88
(
99
cd "$SCRIPT_DIR" || exit 1
10-
uv sync --all-packages
10+
uv sync --all-packages --extra gpu
1111
)
1212
;;
1313
lint)
@@ -34,7 +34,7 @@ case "$1" in
3434
src/ scripts/ packages/
3535
)
3636
;;
37-
type-check-experimental)
37+
type-check)
3838
(
3939
cd "$SCRIPT_DIR/packages/common" || exit 1
4040
uv run --all-packages pyrefly check
@@ -47,13 +47,15 @@ case "$1" in
4747
unit-test)
4848
(
4949
cd "$SCRIPT_DIR" || exit 1
50-
uv run pytest src/
50+
uv sync --extra cpu
51+
uv run --extra cpu pytest src/
5152
)
5253
;;
5354
integration-test)
5455
(
5556
cd "$SCRIPT_DIR" || exit 1
56-
srun uv run --offline pytest ./integration_tests/small1_test.py --verbose
57+
uv sync --offline --all-packages --extra gpu
58+
uv run --offline pytest ./integration_tests/small1_test.py --verbose -s
5759
)
5860
;;
5961
create-links)
@@ -110,7 +112,7 @@ case "$1" in
110112
)
111113
;;
112114
*)
113-
echo "Usage: $0 {sync|lint|lint-check|unit-test|integration-test|create-links|create-jupyter-kernel|jupytext-sync}"
115+
echo "Usage: $0 {sync|lint|lint-check|type-check|unit-test|integration-test|create-links|create-jupyter-kernel|jupytext-sync}"
114116
exit 1
115117
;;
116118
esac

src/weathergen/train/trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
158158
self.init(cf, devices)
159159
cf = self.cf
160160

161+
# TODO: do not define new members outside of the init!!
161162
self.device_type = torch.accelerator.current_accelerator()
162163
self.device = torch.device(f"{self.device_type}:{cf.local_rank}")
163164

@@ -676,6 +677,9 @@ def validate(self, epoch):
676677
self.dataset_val.advance()
677678

678679
def batch_to_device(self, batch):
680+
# TODO: do not define new members outside of the init!!
681+
self.device_type = torch.accelerator.current_accelerator()
682+
self.device = torch.device(f"{self.device_type}:{self.cf.local_rank}")
679683
# forecast_steps is dropped here from the batch
680684
return (
681685
[[d.to_device(self.device) for d in db] for db in batch[0]],

tests/test_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
DUMMY_STREAM_CONF = {
2525
"ERA5": {
2626
"type": "anemoi",
27-
"filenames": ["aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6.zarr"],
27+
"filenames": ["aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr"],
2828
"source": ["u_", "v_", "10u", "10v"],
2929
"target": ["10u", "10v"],
3030
"loss_weight": 1.0,

0 commit comments

Comments
 (0)