Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 89 additions & 29 deletions src/weathergen/datasets/data_reader_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,82 @@ def normalize_coords(self, coords: NDArray[DType]) -> NDArray[DType]:

return coords

def _normalize(
self,
data: NDArray[DType],
idx: list[int],
mean: dict[int, float],
stdev: dict[int, float],
name: str,
) -> NDArray[DType]:
"""
Helper function to normalize data

Parameters
----------
data :
data to be normalized
idx :
indices of channels to be normalized
mean :
mean values for channels
stdev :
standard deviation values for channels
name :
name of the data (for error messages)

Returns
-------
Normalized data
"""
# assert data.shape[-1] == len(idx), f"incorrect number of {name} channels"
if data.shape[-1] != len(idx):
raise ValueError(
f"incorrect number of {name} channels: expected {len(idx)}, got {data.shape[-1]}"
)
for i, ch in enumerate(idx):
data[..., i] = (data[..., i] - mean[ch]) / stdev[ch]

return data

def _denormalize(
self,
data: NDArray[DType],
idx: list[int],
mean: dict[int, float],
stdev: dict[int, float],
name: str,
) -> NDArray[DType]:
"""
Helper function to denormalize data

Parameters
----------
data :
data to be denormalized
idx :
indices of channels to be denormalized
mean :
mean values for channels
stdev :
standard deviation values for channels
name :
name of the data (for error messages)

Returns
-------
Denormalized data
"""
# assert data.shape[-1] == len(idx), f"incorrect number of {name} channels"
if data.shape[-1] != len(idx):
raise ValueError(
f"incorrect number of {name} channels: expected {len(idx)}, got {data.shape[-1]}"
)
for i, ch in enumerate(idx):
data[..., i] = (data[..., i] * stdev[ch]) + mean[ch]

return data

def normalize_geoinfos(self, geoinfos: NDArray[DType]) -> NDArray[DType]:
"""
Normalize geoinfos
Expand All @@ -501,75 +577,59 @@ def normalize_source_channels(self, source: NDArray[DType]) -> NDArray[DType]:

Parameters
----------
data :
source :
data to be normalized

Returns
-------
Normalized data
Normalized source data
"""
assert source.shape[-1] == len(self.source_idx), "incorrect number of source channels"
for i, ch in enumerate(self.source_idx):
source[..., i] = (source[..., i] - self.mean[ch]) / self.stdev[ch]

return source
return self._normalize(source, self.source_idx, self.mean, self.stdev, "source")

def normalize_target_channels(self, target: NDArray[DType]) -> NDArray[DType]:
"""
Normalize target channels

Parameters
----------
data :
target :
data to be normalized

Returns
-------
Normalized data
Normalized target data
"""
assert target.shape[-1] == len(self.target_idx), "incorrect number of target channels"
for i, ch in enumerate(self.target_idx):
target[..., i] = (target[..., i] - self.mean[ch]) / self.stdev[ch]

return target
return self._normalize(target, self.target_idx, self.mean, self.stdev, "target")

def denormalize_source_channels(self, source: NDArray[DType]) -> NDArray[DType]:
"""
Denormalize source channels

Parameters
----------
data :
source :
data to be denormalized

Returns
-------
Denormalized data
Denormalized source data
"""
assert source.shape[-1] == len(self.source_idx), "incorrect number of source channels"
for i, ch in enumerate(self.source_idx):
source[..., i] = (source[..., i] * self.stdev[ch]) + self.mean[ch]
return self._denormalize(source, self.source_idx, self.mean, self.stdev, "source")

return source

def denormalize_target_channels(self, data: NDArray[DType]) -> NDArray[DType]:
def denormalize_target_channels(self, target: NDArray[DType]) -> NDArray[DType]:
"""
Denormalize target channels

Parameters
----------
data :
target :
data to be denormalized (target or pred)

Returns
-------
Denormalized data
Denormalized target data
"""
assert data.shape[-1] == len(self.target_idx), "incorrect number of target channels"
for i, ch in enumerate(self.target_idx):
data[..., i] = (data[..., i] * self.stdev[ch]) + self.mean[ch]

return data
return self._denormalize(target, self.target_idx, self.mean, self.stdev, "target")


class DataReaderTimestep(DataReaderBase):
Expand Down