Skip to content

Commit b4fc1a2

Browse files
authored
[1081][Evaluation] Use parent ruff rules (#1177)
* use ruff settings from parent * fix code checks * check fixes 2nd round * reformat to line length
1 parent 01500a3 commit b4fc1a2

File tree

12 files changed

+231
-182
lines changed

12 files changed

+231
-182
lines changed

packages/evaluate/pyproject.toml

Lines changed: 1 addition & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -28,58 +28,7 @@ export = "weathergen.evaluate.export_inference:export"
2828

2929
# The linting configuration
3030
[tool.ruff]
31-
32-
# Wide rows
33-
line-length = 100
34-
35-
[tool.ruff.lint]
36-
# All disabled until the code is formatted.
37-
select = [
38-
# pycodestyle
39-
"E",
40-
# Pyflakes
41-
"F",
42-
# pyupgrade
43-
"UP",
44-
# flake8-bugbear
45-
"B",
46-
# flake8-simplify
47-
"SIM",
48-
# isort
49-
"I",
50-
# Banned imports
51-
"TID",
52-
# Naming conventions
53-
"N",
54-
# print
55-
"T201"
56-
]
57-
58-
# These rules are sensible and should be enabled at a later stage.
59-
ignore = [
60-
# "B006",
61-
"B011",
62-
"UP008",
63-
"SIM117",
64-
"SIM118",
65-
"SIM102",
66-
"SIM401",
67-
"E501", # to be removed
68-
"E721",
69-
# To ignore, not relevant for us
70-
"SIM108", # in case additional norm layer supports are added in future
71-
"N817", # we use heavy acronyms, e.g., allowing 'import LongModuleName as LMN' (LMN is accepted)
72-
"E731", # overly restrictive and less readable code
73-
"N812", # prevents us following the convention for importing torch.nn.functional as F
74-
]
75-
76-
[tool.ruff.lint.flake8-tidy-imports.banned-api]
77-
"numpy.ndarray".msg = "Do not use 'ndarray' to describe a numpy array type, it is a function. Use numpy.typing.NDArray or numpy.typing.NDArray[np.float32] for example"
78-
79-
[tool.ruff.format]
80-
# Use Unix `\n` line endings for all files
81-
line-ending = "lf"
82-
31+
extend = "../../pyproject.toml"
8332

8433
[tool.pyrefly]
8534
project-includes = ["src/"]

packages/evaluate/src/weathergen/evaluate/clim_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def match_climatology_time(target_datetime: pd.Timestamp, clim_data: xr.Dataset)
5252
# To Do: leap years and other edge cases
5353
if len(matching_indices) == 0:
5454
_logger.warning(
55-
f"No matching climatology time found for {target_datetime} (DOY: {target_doy}, Hour: {target_hour})"
55+
f"No matching climatology time found for {target_datetime} (DOY: {target_doy}, "
56+
f"Hour: {target_hour})"
5657
f"Please check that climatology data and stream input data filenames match."
5758
)
5859
return None
@@ -156,8 +157,8 @@ def align_clim_data(
156157
if np.any(unmatched_mask):
157158
n_unmatched = np.sum(unmatched_mask)
158159
raise ValueError(
159-
f"Found {n_unmatched} target coordinates with no matching climatology coordinates. "
160-
f"This will cause incorrect ACC calculations. "
160+
f"Found {n_unmatched} target coordinates with no matching climatology "
161+
f"coordinates. This will cause incorrect ACC calculations. "
161162
f"Check coordinate alignment between target and climatology data."
162163
)
163164
# Cache the computed indices and target coords
@@ -175,8 +176,10 @@ def align_clim_data(
175176
except (ValueError, IndexError) as e:
176177
raise ValueError(
177178
f"Failed to align climatology data with target data for ACC calculation. "
178-
f"This error typically occurs when the number of points per sample varies between samples. "
179-
f"ACC metric is currently only supported for forecasting data with constant points per sample. "
179+
f"This error typically occurs when the number of points per sample varies "
180+
f"between samples. "
181+
f"ACC metric is currently only supported for forecasting data with constant "
182+
f"points per sample. "
180183
f"Please ensure all samples have the same spatial coverage and grid points. "
181184
f"Original error: {e}"
182185
) from e

packages/evaluate/src/weathergen/evaluate/derived_channels.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@ def __init__(
2121
Initializes the DeriveChannels class with necessary configurations for channel derivation.
2222
2323
Args:
24-
available_channels (np.array): an array of all available channel names in the datasets (target or pred).
24+
available_channels (np.array): an array of all available channel names
25+
in the datasets (target or pred).
2526
channels (list): A list of channels of interest to be evaluated and/or plotted.
26-
stream_cfg (dict): A dictionary containing the stream configuration settings for evaluation and plottings.
27+
stream_cfg (dict): A dictionary containing the stream configuration settings for
28+
evaluation and plotting.
2729
2830
Returns:
2931
None
@@ -147,6 +149,7 @@ def get_derived_channels(
147149
)
148150
else:
149151
_logger.debug(
150-
f"Calculation of {tag} is skipped because it is included in the available channels..."
152+
f"Calculation of {tag} is skipped because it is included "
153+
"in the available channels..."
151154
)
152155
return data_tars, data_preds, self.channels

packages/evaluate/src/weathergen/evaluate/export_inference.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
# weathergen-common = { path = "../../../../../packages/common" }
1111
# weathergen = { path = "../../../../../" }
1212
# ///
13-
## Example USAGE: uv run export --run-id grwnhykd --stream ERA5 --output-dir /p/home/jusers/owens1/jureca/WeatherGen/test_output1 --format netcdf --type prediction target --fsteps 1 --samples 1
13+
## Example USAGE: uv run export --run-id grwnhykd --stream ERA5 \
14+
## --output-dir /p/home/jusers/owens1/jureca/WeatherGen/test_output1 \
15+
## --format netcdf --type prediction target --fsteps 1 --samples 1
1416
import argparse
1517
import logging
1618
import re
@@ -66,9 +68,11 @@ def find_pl(all_variables: list) -> tuple[dict[str, list[str]], list[int]]:
6668
"""
6769
Find all the pressure levels for each variable using regex and returns a dictionary
6870
mapping variable names to their corresponding pressure levels.
71+
6972
Parameters
7073
----------
7174
all_variables : list of variable names with pressure levels (e.g.,'q_500','t_2m').
75+
7276
Returns
7377
-------
7478
A tuple containing:
@@ -333,14 +337,17 @@ def output_filename(
333337
forecast_ref_time: np.datetime64,
334338
) -> Path:
335339
"""
336-
Generate output filename based on prefix (should refer to type e.g. pred/targ), run_id, sample index, output directory, format and forecast_ref_time.
340+
Generate output filename based on prefix (should refer to type e.g. pred/targ), run_id, sample
341+
index, output directory, format and forecast_ref_time.
342+
337343
Parameters
338344
----------
339345
prefix : Prefix for file name (e.g., 'pred' or 'targ').
340346
run_id : Run ID to include in the filename.
341347
output_dir : Directory to save the output file.
342348
output_format : Output file format (currently only 'netcdf' supported).
343349
forecast_ref_time : Forecast reference time to include in the filename.
350+
344351
Returns
345352
-------
346353
Full path to the output file.
@@ -358,9 +365,11 @@ def output_filename(
358365
def get_data_worker(args: tuple) -> xr.DataArray:
359366
"""
360367
Worker function to retrieve data for a single sample and forecast step.
368+
361369
Parameters
362370
----------
363371
args : Tuple containing (sample, fstep, run_id, stream, type).
372+
364373
Returns
365374
-------
366375
xarray DataArray for the specified sample and forecast step.
@@ -397,18 +406,30 @@ def get_data(
397406
398407
Parameters
399408
----------
400-
run_id : Run ID to identify the Zarr store.
401-
samples : Sample to process
402-
stream : Stream name to retrieve data for (e.g., 'ERA5').
403-
type : Type of data to retrieve ('target' or 'prediction').
404-
fsteps : List of forecast steps to retrieve. If None, retrieves all available forecast steps.
405-
channels : List of channels to retrieve. If None, retrieves all available channels.
406-
n_processes : Number of parallel processes to use for data retrieval.
407-
ecpoch : Epoch number to identify the Zarr store.
408-
rank : Rank number to identify the Zarr store.
409-
output_dir : Directory to save the NetCDF files.
410-
output_format : Output file format (currently only 'netcdf' supported).
411-
config : Loaded config for cf_parser function.
409+
run_id : str
410+
Run ID to identify the Zarr store.
411+
samples : list
412+
Samples to process.
413+
stream : str
414+
Stream name to retrieve data for (e.g., 'ERA5').
415+
dtype : str
416+
Type of data to retrieve ('target' or 'prediction').
417+
fsteps : list
418+
List of forecast steps to retrieve. If None, retrieves all available forecast steps.
419+
channels : list
420+
List of channels to retrieve. If None, retrieves all available channels.
421+
n_processes : int
422+
Number of parallel processes to use for data retrieval.
423+
epoch : int
424+
Epoch number to identify the Zarr store.
425+
rank : int
426+
Rank number to identify the Zarr store.
427+
output_dir : str
428+
Directory to save the NetCDF files.
429+
output_format : str
430+
Output file format (currently only 'netcdf' supported).
431+
config : OmegaConf
432+
Loaded config for cf_parser function.
412433
"""
413434
if dtype not in ["target", "prediction"]:
414435
raise ValueError(f"Invalid type: {dtype}. Must be 'target' or 'prediction'.")
@@ -451,7 +472,8 @@ def get_data(
451472
f"{list(set(channels) - set(existing_channels))}. Skipping them."
452473
)
453474
result = result.sel(channel=existing_channels)
454-
# reshape result - use adaptive function to handle both regular and Gaussian grids
475+
# reshape result - use adaptive function to handle both regular and Gaussian
476+
# grids
455477
result = reshape_dataset_adaptive(result)
456478
da_fs.append(result)
457479

@@ -484,12 +506,14 @@ def save_sample_to_netcdf(
484506
) -> None:
485507
"""
486508
Uses list of pred/target xarray DataArrays to save one sample to a NetCDF file.
509+
487510
Parameters
488511
----------
489512
type_str : str
490513
Type of data ('pred' or 'targ') to include in the filename.
491514
dict_sample_all_steps : dict
492-
Dictionary where keys is sample index and values is a list of xarray DataArrays for all the forecast steps
515+
Dictionary where each key is a sample index and each value is a list of xarray DataArrays
516+
for all the forecast steps
493517
fstep_hours : np.timedelta64
494518
Time difference between forecast steps (e.g., 6 hours).
495519
run_id : str
@@ -595,7 +619,8 @@ def parse_args(args: list) -> argparse.Namespace:
595619
type=int,
596620
nargs="+",
597621
default=None,
598-
help="List of forecast steps to retrieve (e.g. 1 2 3). If not provided, retrieves all available forecast steps.",
622+
help="List of forecast steps to retrieve (e.g. 1 2 3). "
623+
"If not provided, retrieves all available forecast steps.",
599624
)
600625

601626
parser.add_argument(
@@ -611,7 +636,8 @@ def parse_args(args: list) -> argparse.Namespace:
611636
type=str,
612637
nargs="+",
613638
default=None,
614-
help="List of channels to retrieve (e.g., 'q_500 t_2m'). If not provided, retrieves all available channels.",
639+
help="List of channels to retrieve (e.g., 'q_500 t_2m'). "
640+
"If not provided, retrieves all available channels.",
615641
)
616642

617643
parser.add_argument(

0 commit comments

Comments
 (0)