Skip to content

Enable patchwise training and prediction #135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 169 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
169 commits
Select commit Hold shift + click to select a range
131c434
stach changes
nilsleh Oct 10, 2023
3342b96
draft
nilsleh Oct 12, 2023
b7cf3fa
draft
nilsleh Oct 12, 2023
70f3783
merge main
nilsleh Oct 12, 2023
379e3b2
wrong merge
nilsleh Oct 12, 2023
85cd34b
incorporate some of the feedback
nilsleh Oct 13, 2023
be8fffd
run black
nilsleh Oct 13, 2023
3415377
merge main
nilsleh Nov 6, 2023
39dd15b
merge main
nilsleh Nov 6, 2023
876970e
layout code
nilsleh Apr 12, 2024
d1cb338
change __call__
nilsleh Apr 12, 2024
218f791
revert
nilsleh Apr 12, 2024
37fe771
type annotation
nilsleh Apr 12, 2024
fb20ccc
patch_size sampling test
nilsleh Apr 15, 2024
5bda80b
patchwise test trainer
nilsleh Apr 15, 2024
c276844
gridded window patching
Apr 19, 2024
fde7e02
adding sliding window patching function
Apr 19, 2024
195a923
loader with bboxes
nilsleh Apr 22, 2024
824df24
loader with boxes
nilsleh Apr 22, 2024
e6e1ae8
Altering kwargs to enable for-loop and change sliding function
Apr 22, 2024
e75d022
Merge branch 'patchwise_train' into msjr/patching
nilsleh Apr 22, 2024
bae0855
move logic to call
nilsleh Apr 22, 2024
a090d34
Merge branch 'main' into patchwise_train
nilsleh Apr 22, 2024
797f48e
Merge branch 'patchwise_train' into msjr/patching
nilsleh Apr 22, 2024
5291ec3
Merge pull request #1 from nilsleh/msjr/patching
nilsleh Apr 23, 2024
7b09119
typo
nilsleh Apr 23, 2024
282c2be
notebook with patchwise train
nilsleh Apr 24, 2024
dfa386d
refining stride to avoid error
Apr 24, 2024
8d46653
inference patching
Apr 27, 2024
acbad8b
predict_patches
Apr 28, 2024
3e2994e
patchwise predictions during inference and stitching
May 3, 2024
765849d
fix typo
May 3, 2024
abc9a9d
merge main
nilsleh Jun 11, 2024
7f8ef93
new cropped stitching
Jun 24, 2024
847a47c
clipped patchwise predictions, single date
Jun 26, 2024
f93fc39
correct minor errors/typos
Jun 27, 2024
07256f0
Merge pull request #2 from nilsleh/msjr-test_patching
nilsleh Jun 27, 2024
d8af314
use TODO to be uniform
davidwilby Jun 27, 2024
f3b7f12
use "stride" as in taskloader
davidwilby Jun 28, 2024
5a1766b
resolve unnormalised coordinate names
Jul 10, 2024
84d9944
Handle absent bbox and task as non-iterable
davidwilby Jul 11, 2024
aab6f1e
resolve unnormalised coordinate names
Jul 10, 2024
bda7176
use dict format for isel for variable coordinate names
davidwilby Jul 11, 2024
55bf86f
add basic test for patchwise prediction
davidwilby Jul 16, 2024
323ab46
handle patch_size and stride as floats or tuples in task loader and p…
davidwilby Jul 18, 2024
09befb3
test parameter handling and sizes in patchwise prediction
davidwilby Jul 19, 2024
c36455b
remove resolved TODO
davidwilby Jul 19, 2024
0cf143d
check patch_size and stride values in predict_patch and test
davidwilby Jul 19, 2024
8da48c1
test inference
nilsleh Aug 7, 2024
7cf556e
correct typo
Aug 8, 2024
ca6d001
resolve conflicts
Aug 8, 2024
bc862df
fix stride & patch checking
davidwilby Aug 9, 2024
61dc88e
revert previous commit
nilsleh Aug 9, 2024
f2bd5bb
fix patchwise training tests
davidwilby Aug 9, 2024
601102e
add actual training step to test_sliding_window_training
davidwilby Aug 9, 2024
64773fc
try to make printing work for task objects with bbox attribute
davidwilby Aug 9, 2024
69f0ac6
run black
davidwilby Aug 9, 2024
f5e4a8a
re-add missing code from task loader
davidwilby Aug 12, 2024
8aac7b6
Merge pull request #5 from nilsleh/dw/add_missing_loader_code
nilsleh Aug 12, 2024
c7a6172
Merge pull request #3 from nilsleh/msjr-test_patching
nilsleh Aug 12, 2024
5e29031
Commit to allow patching irrespective of whether x1 and x2 are ascend…
Aug 12, 2024
294cc47
changes to loader.py to ensure all patched tasks run left to right an…
Aug 12, 2024
4e136e3
Commit to make model agnostic to coord direction
Aug 13, 2024
6dd9b3a
Resolve conflicts primarily due to use of unnorm_name and orig_name
Aug 14, 2024
529e8c8
use more informative error message for predict_patch
davidwilby Aug 14, 2024
f37e28c
Merge branch 'patchwise_train' into dw/refactor_predict
davidwilby Aug 15, 2024
0344c2a
fix use of stride_size
davidwilby Aug 15, 2024
840838d
move patchwise parameter test to test_task_loader
davidwilby Aug 16, 2024
ceeb8ca
fix patch_size and stride for sliding window tests
davidwilby Aug 16, 2024
5fc1fe3
remove test as moved to test_task_loader
davidwilby Aug 16, 2024
1f434cc
check input parameters in task loader
davidwilby Aug 16, 2024
96edce8
For patchwise prediction, get patch_size and stride directly from task
davidwilby Aug 16, 2024
a705549
Merge pull request #6 from nilsleh/msjr-test_patching
nilsleh Aug 21, 2024
f5c015b
Merge branch 'patchwise_train' into dw/patch_size_from_task_for_predict
davidwilby Aug 21, 2024
c14d5d1
Merge pull request #7 from nilsleh/dw/patch_size_from_task_for_predict
nilsleh Aug 21, 2024
18f2e5a
raise errors instead of assert
davidwilby Aug 22, 2024
47d0998
use warning for stride > patch size
davidwilby Aug 22, 2024
df0533b
remove comment
davidwilby Aug 22, 2024
f7d57e9
raise error for stride > patch_size in prediction
davidwilby Aug 22, 2024
fed3940
alter paramaters for test
davidwilby Aug 22, 2024
b3a6dab
raise error for more than one date in predict_patch
davidwilby Aug 22, 2024
c8a38f2
black
davidwilby Aug 22, 2024
52a0cb3
Merge branch 'main' into dw/merge_main
davidwilby Aug 22, 2024
2672fec
Merge branch 'patchwise_train' into dw/refactor_predict
davidwilby Aug 22, 2024
e10d645
fix getting and checking of patch_size and stride
davidwilby Aug 22, 2024
f6f843d
fix docstrings and defaults
davidwilby Aug 22, 2024
2e5c6a8
reinstate orig_name patch clip slicing
davidwilby Aug 22, 2024
e4d5567
Merge pull request #8 from nilsleh/dw/merge_main
nilsleh Aug 22, 2024
e57f065
Merge branch 'patchwise_train' into dw/refactor_predict
davidwilby Aug 23, 2024
51d8c05
use hypothesis to expand on patchwise predict testing
davidwilby Aug 23, 2024
79afa12
Merge pull request #4 from nilsleh/dw/refactor_predict
davidwilby Aug 23, 2024
d6500f7
account for warnings/errors in patchwise task loader
davidwilby Aug 23, 2024
7c47357
allow longer test runs
davidwilby Aug 23, 2024
274902a
use patch size which relates to the normalised size
davidwilby Aug 23, 2024
60c7a14
alter docstring to reflect function
davidwilby Aug 23, 2024
cdbd73a
attempt fix for compute_x1x2_direction
davidwilby Aug 23, 2024
f1e3dfe
Merge branch 'dw/compute_coord_direction' into dw/fix_regression
davidwilby Aug 23, 2024
6f0e2e6
Merge pull request #5 from davidwilby/dw/fix_regression
davidwilby Sep 17, 2024
c0cd17e
address montonic and prediction size issues
Sep 17, 2024
d30e687
move patchwise test out of class
davidwilby Oct 11, 2024
7a100ee
Update deepsensor/model/model.py
MartinSJRogers Oct 15, 2024
23733df
Merge pull request #8 from davidwilby/montonic_errors
MartinSJRogers Oct 15, 2024
4a4276d
Move spatial slicing below gapfill sampling
Oct 31, 2024
2f0e2ba
Merge pull request #11 from davidwilby/gapfill_loop
MartinSJRogers Oct 31, 2024
e2488f3
Merge branch 'main' into patchwise_train
davidwilby Oct 31, 2024
54ee611
Merge remote-tracking branch 'origin/patchwise_train' into patchwise_…
davidwilby Oct 31, 2024
9b1f30d
lint patchwise code
davidwilby Oct 31, 2024
812e056
Update patchwise training notebook with additional descriptive text
Nov 1, 2024
3a34ed3
Merge pull request #13 from davidwilby/patchwise_linting
davidwilby Nov 5, 2024
4859f2d
rename notebook; use new tqdm notebook; other small tweaks to text
davidwilby Nov 5, 2024
53f238f
add correct output and prediction plot
davidwilby Nov 5, 2024
9e4254f
Merge branch 'patchwise_train' into update_notebook
davidwilby Nov 5, 2024
3e68ee1
Merge pull request #14 from davidwilby/update_notebook
davidwilby Nov 7, 2024
f7d5422
use python 3.8 compatible type hints
davidwilby Nov 13, 2024
527edff
rename predict_patch to predict_patchwise and fix references
davidwilby Nov 13, 2024
afac690
remove mention of contributing in error message
davidwilby Nov 13, 2024
277cdc3
refactor overlap calculation
davidwilby Nov 25, 2024
88ae024
use smaller test dataset
davidwilby Nov 25, 2024
325de6d
update docstring for predict_patchwise
davidwilby Nov 27, 2024
9d79b34
account for non-gridded data correctly
davidwilby Nov 29, 2024
4c05b77
refactor to reduce duplication; reduce floating point errors
davidwilby Nov 29, 2024
4b570ae
first attempt using merge
Nov 29, 2024
df88bc7
pass kwargs to predict; use data_processor attribute instead of arg
davidwilby Dec 2, 2024
4e028ab
correct typo
davidwilby Dec 2, 2024
657a42a
Replace combine by coords with method to infill blank prediction obje…
Dec 3, 2024
e68c01a
remove the +1 to prevent Nan lines forming
Dec 3, 2024
8b9a8ac
add some comments
davidwilby Dec 3, 2024
d620e88
Merge pull request #16 from davidwilby/refactor_sample_sliding
davidwilby Dec 3, 2024
3c1c1c8
Update deepsensor/model/model.py
MartinSJRogers Dec 3, 2024
e601109
linting
davidwilby Dec 13, 2024
a86ce31
tweak comments
davidwilby Dec 13, 2024
c7a994e
re-enable size checking in test
davidwilby Dec 13, 2024
e5b580b
rename some variables for slightly improved readability; add typehints
davidwilby Dec 20, 2024
91b83ce
Merge pull request #18 from davidwilby/replace_combineByCoords_with_m…
davidwilby Dec 20, 2024
158b6dc
Merge pull request #17 from davidwilby/predict_args
davidwilby Dec 20, 2024
747f7dd
reduce large comment block to easier to follow inline comments
davidwilby Jan 8, 2025
4f5eead
remove unused hypothesis dependency
davidwilby Jan 8, 2025
322766f
remove todo
davidwilby Jan 8, 2025
b4e9ff5
move coord direction calcuation to where it is needed
davidwilby Jan 9, 2025
572d7ec
clean up markup
Jan 9, 2025
da2f68f
Reduce repitiion and place code to determine coordinate extent in one…
Jan 10, 2025
c2f0ffe
Create DeepSensor object straight after stitching
Jan 10, 2025
9a7e743
Slightly amend some mark up text
Jan 10, 2025
e857355
Editted text for get_coordinate_extent_method
Jan 10, 2025
358b884
Edit where time is defined in stitched prediction object
Jan 10, 2025
58e9076
Reduce for loops and keep predictions as deepsensor.prediction objects
Jan 12, 2025
9943e99
Update deepsensor/model/model.py
MartinSJRogers Jan 21, 2025
53ee50f
Merge pull request #19 from davidwilby/simplify_stitching
davidwilby Jan 21, 2025
6cf0a28
Update deepsensor/model/model.py
MartinSJRogers Jan 22, 2025
1f0fb32
Merge pull request #20 from davidwilby/simplify_stitching_retain_pred…
davidwilby Jan 27, 2025
be883dc
lint
davidwilby Jan 27, 2025
b0459e8
use python 3.8 compatible typehint
davidwilby Jan 27, 2025
9765787
correct type hint
davidwilby Jan 27, 2025
2bb4d9b
move stitching to pred module
davidwilby Jan 27, 2025
8f3897f
Edit some markup text in new methods in pred module.
Jan 30, 2025
d9e4d1b
Merge pull request #22 from davidwilby/mr_move_stitching_to_pred
davidwilby Jan 30, 2025
ccd407c
fix linting
davidwilby Feb 12, 2025
c114926
Merge pull request #21 from davidwilby/dw/move_stitching_to_predict
davidwilby Feb 12, 2025
761934d
begin moving patchwise task loading into its own class
davidwilby Apr 2, 2025
ea42f69
update user guide notebook with PatchwiseTaskLoader class references
davidwilby Apr 2, 2025
648ca8a
slightly reduce duplication of TaskLoader.__call__
davidwilby Apr 3, 2025
03e6eec
rename num_samples_per_date arg to num_patch_tasks
davidwilby Apr 3, 2025
620677f
Merge branch 'main' into patchwise_train
davidwilby Apr 7, 2025
c230c35
Merge branch 'patchwise_train' into refactor_classes
davidwilby Apr 7, 2025
7587726
Merge pull request #24 from davidwilby/refactor_classes
davidwilby Apr 7, 2025
ea3987b
re-use tests for patchwise training using paramterized
davidwilby Apr 7, 2025
fc02665
use descriptive test docstring for test_patchwise_prediction
davidwilby Apr 7, 2025
c60d9b5
reduce repetition in patch strategy validation
davidwilby Apr 7, 2025
f967318
store patch_strategy as a property of Task
davidwilby Apr 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
619 changes: 601 additions & 18 deletions deepsensor/data/loader.py

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion deepsensor/data/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def __init__(self, task_dict: dict) -> None:
@classmethod
def summarise_str(cls, k, v):
"""Return string summaries for the _str__ method."""
if plum.isinstance(v, B.Numeric):
if isinstance(v, float):
return v
elif plum.isinstance(v, B.Numeric):
return v.shape
elif plum.isinstance(v, tuple):
return tuple(vi.shape for vi in v)
Expand All @@ -57,6 +59,8 @@ def summarise_repr(cls, k, v) -> str:
"""
if v is None:
return "None"
elif isinstance(v, float):
return f"{type(v).__name__}"
elif plum.isinstance(v, B.Numeric):
return f"{type(v).__name__}/{v.dtype}/{v.shape}"
if plum.isinstance(v, deepsensor.backend.nps.mask.Masked):
Expand Down
268 changes: 268 additions & 0 deletions deepsensor/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Prediction,
increase_spatial_resolution,
infer_prediction_modality_from_X_t,
stitch_clipped_predictions,
)
from deepsensor.data.task import Task

Expand Down Expand Up @@ -648,6 +649,273 @@ def unnormalise_pred_array(arr, **kwargs):

return pred

def predict_patchwise(
self,
tasks: Union[List[Task], Task],
X_t: Union[
xr.Dataset,
xr.DataArray,
pd.DataFrame,
pd.Series,
pd.Index,
np.ndarray,
],
X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
**kwargs,
) -> Prediction:
"""Predict using tasks loaded using a sliding window patching strategy. Uses the `predict` method.

.. versionadded:: 0.4.3
:py:func:`predict_patchwise()` method.

Args:
tasks (List[Task] | Task):
List of tasks containing context data. Tasks for patchwise prediction must be generated by a PatchwiseTaskLoader using the "sliding" patching strategy.
X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`):
Target locations to predict at. Can be an xarray object
containing on-grid locations or a pandas object containing off-grid locations.
X_t_mask: :class:`xarray.Dataset` | :class:`xarray.DataArray`, optional
2D mask to apply to gridded ``X_t`` (zero/False will be NaNs). Will be interpolated
to the same grid as ``X_t`` and patched in the same way. Default None (no mask).
**kwargs:
Keyword arguments as per ``predict``.

Returns:
:class:`~.model.pred.Prediction`):
A `dict`-like object mapping from target variable IDs to xarray or pandas objects
containing model predictions.
- If ``X_t`` is a pandas object, returns pandas objects
containing off-grid predictions.
- If ``X_t`` is an xarray object, returns xarray object
containing on-grid predictions.
- If ``n_samples`` == 0, returns only mean and std predictions.
- If ``n_samples`` > 0, returns mean, std and samples
predictions.

Raises:
AttributeError
If ``tasks`` are not generated using the "sliding" patching strategy of PatchwiseTaskLoader.
Errors
See `~.model.model.DeepSensorModel.predict`
"""
# Get coordinate names of original unnormalised dataset.
orig_x1_name = self.data_processor.x1_name
orig_x2_name = self.data_processor.x2_name

def get_patches_per_row(preds) -> int:
"""Calculate number of patches per row.
Required to stitch patches back together.

Args:
preds (List[class:`~.model.pred.Prediction`]):
A list of `dict`-like objects containing patchwise predictions.

Returns:
patches_per_row: int
Number of patches per row.
"""
patches_per_row = 0
vars = list(preds[0][0].data_vars)
var = vars[0]
x1_val = preds[0][0][var].coords[orig_x1_name].min()

for pred in preds:
if pred[0][var].coords[orig_x1_name].min() == x1_val:
patches_per_row = patches_per_row + 1

return patches_per_row

def get_patch_overlap(
overlap_norm, data_processor, X_t_ds, x1_ascend, x2_ascend
) -> int:
"""Calculate overlap between adjacent patches in pixels.

Parameters
----------
overlap_norm : tuple[float].
Normalised size of overlap in x1/x2.

data_processor (:class:`~.data.processor.DataProcessor`):
Used for unnormalising the coordinates of the bounding boxes of patches.

X_t_ds (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`):
Data array containing target locations to predict at.

x1_ascend : str:
Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True.

x2_ascend : str:
Boolean defining whether the x2 coords ascend (increase) from left to right, default = True.

Returns:
-------
patch_overlap : tuple (int)
Unnormalised size of overlap between adjacent patches.
"""
# Todo- check if there is simplier and more robust way to convert overlap into pixels.
# Place x1/x2 overlap values in Xarray to pass into unnormalise()
overlap_list = [0, overlap_norm[0], 0, overlap_norm[1]]
x1 = xr.DataArray([overlap_list[0], overlap_list[1]], dims="x1", name="x1")
x2 = xr.DataArray([overlap_list[2], overlap_list[3]], dims="x2", name="x2")
overlap_norm_xr = xr.Dataset(coords={"x1": x1, "x2": x2})

# Unnormalise coordinates of bounding boxes
overlap_unnorm_xr = data_processor.unnormalise(overlap_norm_xr)

unnorm_overlap_x1 = overlap_unnorm_xr.coords[orig_x1_name].values[1]
unnorm_overlap_x2 = overlap_unnorm_xr.coords[orig_x2_name].values[1]

def overlap_index(
coords: np.ndarray, ascend: bool, unnorm_overlap: float
) -> int:
"""Find size of overlap in a single coordinate direction, in units of pixels.

Parameters
----------
coords : np.ndarray

ascend : bool
Boolean defining whether coords ascend (increase) from top to bottom or left to right.

unnorm_overlap : float
The patch overlap in unnormalised coordinates.

Returns:
-------
int : The number of pixels in the overlap.
"""
pixel_coords_overlap_diffs = np.abs(coords - unnorm_overlap)
if ascend:
trim_size = np.argmin(pixel_coords_overlap_diffs) / 2
trim_size_rounded = int(
np.floor(trim_size)
) # Always round down trim slide as stitching method can handle slight overlaps
return trim_size_rounded

else:
overlap_pixel_size = np.argmin(pixel_coords_overlap_diffs)
overlap_pixel_size_rounded = np.ceil(overlap_pixel_size)
trim_size = (
(coords.size - int(overlap_pixel_size_rounded)) / 2
) # this extra step is so we get the overlap with respect to the largest value (i.e. is the number of pixels = 360, coords.size = 360)
trim_size_rounded = int(np.floor(trim_size))
return trim_size_rounded

return (
overlap_index(
X_t_ds.coords[orig_x1_name].values, x1_ascend, unnorm_overlap_x1
),
overlap_index(
X_t_ds.coords[orig_x2_name].values, x2_ascend, unnorm_overlap_x2
),
)

task_patch_strategy = tasks[0]["patch_strategy"]
if task_patch_strategy is not "sliding":
raise AttributeError(
f"For patchwise prediction, only tasks generated using a patch_strategy of 'sliding' are valid. \
This task appears to have been created using the {task_patch_strategy} strategy."
)

# load patch_size and stride from task
patch_size = tasks[0]["patch_size"]
stride = tasks[0]["stride"]

# sanitise patch_size and stride arguments
if isinstance(patch_size, float) and patch_size is not None:
patch_size = (patch_size, patch_size)

if isinstance(stride, float) and stride is not None:
stride = (stride, stride)

if stride[0] > patch_size[0] or stride[1] > patch_size[1]:
raise ValueError(
f"stride must be smaller than patch_size in the corresponding dimensions for patchwise prediction. Got: patch_size: {patch_size}, stride: {stride}"
)

# patchwise prediction does not yet support more than a single date
num_task_dates = len(set([t["time"] for t in tasks]))
if num_task_dates > 1:
raise NotImplementedError(
f"Patchwise prediction does not yet support more than a single date at a time, got {num_task_dates}."
)

# tasks should be iterable, if only one is provided, make it a list
if type(tasks) is Task:
tasks = [tasks]

# Perform patchwise predictions
preds = []
for task in tasks:
bbox = task["bbox"]

# Unnormalise coordinates of bounding box of patch
x1 = xr.DataArray([bbox[0], bbox[1]], dims="x1", name="x1")
x2 = xr.DataArray([bbox[2], bbox[3]], dims="x2", name="x2")
bbox_norm = xr.Dataset(coords={"x1": x1, "x2": x2})
bbox_unnorm = self.data_processor.unnormalise(bbox_norm)
unnorm_bbox_x1 = (
bbox_unnorm[orig_x1_name].values.min(),
bbox_unnorm[orig_x1_name].values.max(),
)
unnorm_bbox_x2 = (
bbox_unnorm[orig_x2_name].values.min(),
bbox_unnorm[orig_x2_name].values.max(),
)

# Determine X_t for patch, however, cannot assume min/max ordering of slice coordinates
# Check the order of coordinates in X_t, sometimes they are increasing or decreasing in order.
x1_coords = X_t.coords[orig_x1_name].values
x2_coords = X_t.coords[orig_x2_name].values

if x1_coords[0] < x1_coords[-1]:
x1_slice = slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1])
x1_ascending = True
else:
x1_slice = slice(unnorm_bbox_x1[1], unnorm_bbox_x1[0])
x1_ascending = False

if x2_coords[0] < x2_coords[-1]:
x2_slice = slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1])
x2_ascending = True
else:
x2_slice = slice(unnorm_bbox_x2[1], unnorm_bbox_x2[0])
x2_ascending = False

# Determine X_t for patch with correct slice direction
task_X_t = X_t.sel(**{orig_x1_name: x1_slice, orig_x2_name: x2_slice})
task_X_t_mask = (
X_t_mask.sel(**{orig_x1_name: x1_slice, orig_x2_name: x2_slice})
if X_t_mask
else None
)

# Patchwise prediction
pred = self.predict(task, task_X_t, task_X_t_mask, **kwargs)
# Append patchwise DeepSensor prediction object to list
preds.append(pred)

overlap_norm = tuple(
patch - stride for patch, stride in zip(patch_size, stride)
)
patch_overlap_unnorm = get_patch_overlap(
overlap_norm, self.data_processor, X_t, x1_ascending, x2_ascending
)

patches_per_row = get_patches_per_row(preds)
prediction = stitch_clipped_predictions(
preds,
patch_overlap_unnorm,
patches_per_row,
X_t,
orig_x1_name,
orig_x2_name,
x1_ascending,
x2_ascending,
)

return prediction


def add_valid_time_coord_to_pred_and_move_time_dims(pred: Prediction) -> Prediction:
"""Add a valid time coordinate "time" to a Prediction object based on the
Expand Down
Loading
Loading