@@ -10,7 +10,6 @@ authors = [
 requires-python = ">=3.12,<3.13"
 # TODO: split the plotting dependencies into their own dep groups; they are not required.
 dependencies = [
-    'torch==2.6.0',
     'numpy~=2.2',
     'astropy_healpix~=1.1.2',
     'zarr~=2.17',
@@ -22,7 +21,6 @@ dependencies = [
     'packaging',
     'wheel',
     'psutil',
-    "flash-attn; sys_platform == 'linux'",
     "polars~=1.25.2",
     "omegaconf~=2.3.0",
     "dask~=2025.5.1",
@@ -32,6 +30,7 @@ dependencies = [
     "weathergen-evaluate",
 ]
 
+
 [project.urls]
 Homepage = "https://www.weathergenerator.eu"
 Documentation = "https://readthedocs.org"
@@ -66,6 +65,25 @@ dev = [
 ]
 
 
+# Torch is provided via optional dependencies.
+# uv and python can only filter dependencies by platform, not by capability.
+# Following the recommendations from https://docs.astral.sh/uv/guides/integration/pytorch,
+# we need to support:
+#   x86_64: cpu (unit tests) + gpu
+#   aarch64: gpu
+[project.optional-dependencies]
+
+cpu = [
+    'torch==2.6.0',
+]
+
+gpu = [
+    'torch==2.6.0+cu126',
+    # flash-attn also has a torch dependency.
+    "flash-attn",
+]
+
+
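+# Usage sketch, assuming the extras above resolve as intended (commands follow
+# the uv docs; exact behaviour may vary by uv version):
+#   uv sync --extra cpu    # CPU-only install, e.g. for unit tests
+#   uv sync --extra gpu    # CUDA 12.6 install on the clusters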
 [tool.black]
 
 # Wide rows
@@ -125,6 +143,8 @@ ignore = [
 line-ending = "lf"
 
 
+
+
 [tool.uv]
 # Most work is done on a distributed filesystem, where hardlinks are not always possible.
 # Also, trying to resolve some permissions issue, see 44.
@@ -141,14 +161,26 @@ link-mode = "symlink"
 # Also, relatively recent versions are required to support workspaces.
 required-version = ">=0.7.0"
 
-# Following the recommendations from https://docs.astral.sh/uv/guides/integration/pytorch
-# The current setup is:
-#   linux == GPU + flashattention
-#   windows == GPU
-#   macos == CPU
+# The supported environments.
+# TODO: add macos and windows (CPU only, for running tests)
+environments = [
+    "sys_platform == 'linux' and platform_machine == 'aarch64'",
+    "sys_platform == 'linux' and platform_machine == 'x86_64'",
+    # "sys_platform == 'darwin'",
+]
+
+# Only one of the cpu and gpu extras can be enabled at a time.
+conflicts = [
+    [
+        { extra = "cpu" },
+        { extra = "gpu" },
+    ],
+]
+
+
 [[tool.uv.index]]
-name = "pytorch-cu124"
-url = "https://download.pytorch.org/whl/cu124"
+name = "pytorch-cu126"
+url = "https://download.pytorch.org/whl/cu126"
 explicit = true
 
 
@@ -181,14 +213,26 @@ explicit = true
 [tool.uv.sources]
 weathergen-common = { workspace = true }
 weathergen-evaluate = { workspace = true }
-torch = [
-    { index = "pytorch-cu124", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
-    { index = "pytorch-cpu", marker = "sys_platform == 'macosx'" },
-]
-# This URL was evaluated this way:
-#   uv run ~/WeatherGenerator-private/hpc/hpc2020/ecmwf/get-flash-atten.sh
+
+
 flash-attn = [
-    { url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl", marker = "sys_platform == 'linux'" },
+    # Cathal O'Brien's build is not compatible with the libc version on santis.
+    # Hardcode the reference to the Swiss cluster for the time being.
+    # TODO: open issue
+    # { url = "https://github.com/cathalobrien/get-flash-attn/releases/download/v0.1-alpha/flash_attn-2.7.4+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_aarch64.whl", marker = "sys_platform == 'linux' and platform_machine == 'aarch64'" },
+    # This version was rebuilt locally on santis and uploaded.
+    { url = "https://object-store.os-api.cci1.ecmwf.int/weathergenerator-dev/wheels/flash_attn-2.7.3-cp312-cp312-linux_aarch64.whl", marker = "sys_platform == 'linux' and platform_machine == 'aarch64'" },
+    { url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp312-cp312-linux_x86_64.whl", marker = "sys_platform == 'linux' and platform_machine == 'x86_64'" },
+    # { index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
+]
+
+
+torch = [
+    # Explicit pins for the GPU wheels.
+    { url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-linux_aarch64.whl", marker = 'sys_platform == "linux" and platform_machine == "aarch64"', extra = "gpu" },
+    { url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl", marker = 'sys_platform == "linux" and platform_machine == "x86_64"', extra = "gpu" },
+    # Use the public pytorch-cpu index for CPU versions.
+    { index = "pytorch-cpu", marker = "sys_platform == 'linux'", extra = "cpu" },
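+    # Note: %2B in the URLs above is the URL-encoded '+' of the local version
+    # tag (torch-2.6.0+cu126). Since both pytorch indexes are declared with
+    # explicit = true, uv only consults them for packages pinned to them here.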
 ]
 
 [tool.pytest.ini_options]
@@ -203,3 +247,4 @@ members = [
     "packages/evaluate",
     "packages/common"
 ]
+