Commit dc30111

More fixes
1 parent c1ce25e commit dc30111

File tree

2 files changed (+42, -18 lines)


reproject/common.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@

 @delayed(pure=True)
 def as_delayed_memmap_path(array, tmp_dir):
+    tmp_dir = tempfile.mkdtemp()  # FIXME
     if isinstance(array, da.core.Array):
         array_path, _ = _dask_to_numpy_memmap(array, tmp_dir)
     else:
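The hunk above shows as_delayed_memmap_path decorated with @delayed(pure=True). As a side note, here is a minimal sketch (not from this commit) of what pure=True means in dask: calls with identical arguments hash to the same task key, so duplicate calls collapse into a single task. The function name double below is purely illustrative.

    from dask import delayed

    @delayed(pure=True)
    def double(x):
        # Declared pure: dask derives the task key from a hash of the
        # function and its arguments rather than a random token.
        return 2 * x

    a = double(21)
    b = double(21)
    assert a.key == b.key  # identical calls share one task
    print(a.compute())     # 42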

reproject/mosaicking/coadd.py

Lines changed: 41 additions & 18 deletions
@@ -1,11 +1,12 @@
 # Licensed under a 3-clause BSD style license - see LICENSE.rst

-# Notes on dask re-write
-#
-
+import os
+import uuid
+import tempfile
 from math import ceil
 from itertools import product

+import dask
 import dask.array as da
 import numpy as np
 from astropy.wcs import WCS
@@ -174,7 +175,6 @@ def reproject_and_coadd(
         shape_out = tuple(
             [ceil(shape_out[i] / block_size[i]) * block_size[i] for i in range(len(shape_out))]
         )
-        print(shape_out_original, shape_out)

         if output_array is not None and output_array.shape != shape_out:
             raise ValueError(
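For reference, a tiny worked example (not part of the commit) of the rounding in the context lines above: each axis of shape_out is rounded up to the next multiple of the block size, and the padding is trimmed off again further down where the result is sliced back to shape_out_original. The shapes below are made up.

    from math import ceil

    shape_out = (1021, 2048)
    block_size = (512, 512)

    # Round each axis up to a whole number of blocks.
    padded = tuple(ceil(s / b) * b for s, b in zip(shape_out, block_size))
    print(padded)  # (1024, 2048)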
@@ -334,6 +334,7 @@ def reproject_and_coadd(
                 hdu_in=hdu_in,
                 return_footprint=False,
                 return_type="dask",
+                parallel=parallel,
                 block_size=block_size,
                 **kwargs,
             )
@@ -434,22 +435,44 @@ def reproject_and_coadd(
             "combine_function={combine_function} not yet implemented when block_size is set"
         )

-    print([slice(0, shape_out_original[i]) for i in range(len(shape_out_original))])
-
     result = result[
         tuple([slice(0, shape_out_original[i]) for i in range(len(shape_out_original))])
     ]

-    if return_type == "numpy":
-        if output_array is None:
-            return result.compute(scheduler="synchronous"), None
-        else:
-            da.store(
-                result,
-                output_array,
-                compute=True,
-                scheduler="synchronous",
-            )
-            return output_array, None
-    else:
+    if return_type == "dask":
         return result, None
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        if parallel:
+            # As discussed in https://github.com/dask/dask/issues/9556, da.store
+            # will not work well in multiprocessing mode when the destination is a
+            # Numpy array. Instead, in this case we save the dask array to a zarr
+            # array on disk which can be done in parallel, and re-load it as a dask
+            # array. We can then use da.store in the next step using the
+            # 'synchronous' scheduler since that is I/O limited so does not need
+            # to be done in parallel.
+
+            if isinstance(parallel, int):
+                if parallel > 0:
+                    workers = {"num_workers": parallel}
+                else:
+                    raise ValueError(
+                        "The number of processors to use must be strictly positive"
+                    )
+            else:
+                workers = {}
+
+            zarr_path = os.path.join(tmp_dir, f"{uuid.uuid4()}.zarr")
+
+            with dask.config.set(scheduler="processes", **workers):
+                result.to_zarr(zarr_path)
+            result = da.from_zarr(zarr_path)
+
+        da.store(
+            result,
+            output_array,
+            compute=True,
+            scheduler="synchronous",
+        )
+
+        return output_array, None
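The comment in the hunk above describes the key idea: because da.store into an in-memory NumPy array does not work well with the multiprocessing scheduler (https://github.com/dask/dask/issues/9556), the array is first written to zarr in parallel and then copied synchronously. Below is a minimal standalone sketch of that pattern, not the commit's code; it assumes dask and zarr are installed, and the array contents and worker count are placeholders.

    import os
    import tempfile
    import uuid

    import dask
    import dask.array as da
    import numpy as np

    result = da.random.random((1024, 1024), chunks=(256, 256))
    output_array = np.empty(result.shape, dtype=result.dtype)

    with tempfile.TemporaryDirectory() as tmp_dir:
        zarr_path = os.path.join(tmp_dir, f"{uuid.uuid4()}.zarr")

        # Compute and write the chunks in parallel across processes.
        with dask.config.set(scheduler="processes", num_workers=4):
            result.to_zarr(zarr_path)

        # Re-load lazily from the on-disk zarr store, then copy into the
        # NumPy array; this step is I/O bound, so one thread is enough.
        result = da.from_zarr(zarr_path)
        da.store(result, output_array, compute=True, scheduler="synchronous")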
