@@ -1,11 +1,12 @@
 # Licensed under a 3-clause BSD style license - see LICENSE.rst

-# Notes on dask re-write
-#
-
+import os
+import uuid
+import tempfile
 from math import ceil
 from itertools import product

+import dask
 import dask.array as da
 import numpy as np
 from astropy.wcs import WCS
@@ -174,7 +175,6 @@ def reproject_and_coadd(
     shape_out = tuple(
         [ceil(shape_out[i] / block_size[i]) * block_size[i] for i in range(len(shape_out))]
     )
-    print(shape_out_original, shape_out)

     if output_array is not None and output_array.shape != shape_out:
         raise ValueError(
@@ -334,6 +334,7 @@ def reproject_and_coadd(
             hdu_in=hdu_in,
             return_footprint=False,
             return_type="dask",
+            parallel=parallel,
             block_size=block_size,
             **kwargs,
         )
@@ -434,22 +435,44 @@ def reproject_and_coadd(
             "combine_function={combine_function} not yet implemented when block_size is set"
         )

-    print([slice(0, shape_out_original[i]) for i in range(len(shape_out_original))])
-
     result = result[
         tuple([slice(0, shape_out_original[i]) for i in range(len(shape_out_original))])
     ]

-    if return_type == "numpy":
-        if output_array is None:
-            return result.compute(scheduler="synchronous"), None
-        else:
-            da.store(
-                result,
-                output_array,
-                compute=True,
-                scheduler="synchronous",
-            )
-            return output_array, None
-    else:
+    if return_type == "dask":
         return result, None
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        if parallel:
+            # As discussed in https://github.com/dask/dask/issues/9556, da.store
+            # will not work well in multiprocessing mode when the destination is
+            # a NumPy array. Instead, in this case we save the dask array to a
+            # zarr array on disk, which can be done in parallel, and re-load it
+            # as a dask array. We can then use da.store in the next step with the
+            # 'synchronous' scheduler, since that step is I/O-limited and does
+            # not need to be done in parallel.
+
+            if isinstance(parallel, int):
+                if parallel > 0:
+                    workers = {"num_workers": parallel}
+                else:
+                    raise ValueError(
+                        "The number of processors to use must be strictly positive"
+                    )
+            else:
+                workers = {}
+
+            zarr_path = os.path.join(tmp_dir, f"{uuid.uuid4()}.zarr")
+
+            with dask.config.set(scheduler="processes", **workers):
+                result.to_zarr(zarr_path)
+            result = da.from_zarr(zarr_path)
+
+        da.store(
+            result,
+            output_array,
+            compute=True,
+            scheduler="synchronous",
+        )
+
+        return output_array, None
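
For context on the zarr round-trip described in the comment above, here is a minimal, self-contained sketch of the same pattern outside of `reproject_and_coadd`. The array shape, chunk size, worker count, and the `output_array` placeholder are illustrative only, and writing to zarr assumes the `zarr` package is installed.

```python
import os
import tempfile
import uuid

import dask
import dask.array as da
import numpy as np

# Stand-ins for the reprojected dask array and the caller-supplied output
# array in reproject_and_coadd (shapes and chunks are arbitrary here).
result = da.random.random((1024, 1024), chunks=(256, 256))
output_array = np.empty(result.shape, dtype=result.dtype)

with tempfile.TemporaryDirectory() as tmp_dir:
    zarr_path = os.path.join(tmp_dir, f"{uuid.uuid4()}.zarr")

    # The expensive computation runs here, in parallel across processes,
    # with each worker writing its own chunks to the zarr store on disk.
    with dask.config.set(scheduler="processes", num_workers=4):
        result.to_zarr(zarr_path)

    # Re-loading the zarr store and copying it into the NumPy array is
    # I/O-bound, so the synchronous scheduler is sufficient and avoids the
    # multiprocessing da.store issue from dask/dask#9556.
    result = da.from_zarr(zarr_path)
    da.store(result, output_array, compute=True, scheduler="synchronous")
```

The key point is that the parallel workers never write into shared process memory: they only write independent chunks to disk, and the final copy into `output_array` happens in a single process.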
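Similarly, the earlier change that rounds `shape_out` up to a multiple of `block_size`, together with the later crop back to `shape_out_original`, follows the pad-and-crop pattern sketched below; the shapes and block size here are made up for illustration.

```python
from math import ceil

import dask.array as da

# Hypothetical requested output shape and block size.
shape_out_original = (1023, 511)
block_size = (256, 256)

# Round each axis up to the nearest multiple of the block size so that the
# dask chunks tile the enlarged output exactly.
shape_out = tuple(
    ceil(shape_out_original[i] / block_size[i]) * block_size[i]
    for i in range(len(shape_out_original))
)
assert shape_out == (1024, 512)

# Placeholder for the blockwise-computed mosaic.
result = da.zeros(shape_out, chunks=block_size)

# Crop the padding off again to recover the shape the caller asked for.
result = result[
    tuple(slice(0, shape_out_original[i]) for i in range(len(shape_out_original)))
]
assert result.shape == shape_out_original
```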