diff --git a/reproject/array_utils.py b/reproject/array_utils.py index 4754baf45..e3e6161e9 100644 --- a/reproject/array_utils.py +++ b/reproject/array_utils.py @@ -107,6 +107,51 @@ def memory_efficient_access(array, chunk): return array[chunk] +def map_coordinates_vindex( + image, + coords, + order=0, + output=None, + cval=np.nan, + mode="constant", +): + + if order != 0: + raise ValueError("Only order==0 is supported") + + if mode != "constant": + raise ValueError("Only mode=='constant' is supported") + + original_shape = image.shape + + # As in map_coordinates + coords = coords.copy() + for i in range(coords.shape[0]): + coords[i][(coords[i] < 0) & (coords[i] >= -0.5)] = 0 + coords[i][(coords[i] < original_shape[i] - 0.5) & (coords[i] >= original_shape[i] - 1)] = ( + original_shape[i] - 1 + ) + + keep = np.ones(coords.shape[1], dtype=bool) + + for i in range(coords.shape[0]): + keep[(coords[i] < 0) | (coords[i] > original_shape[i] - 1)] = False + + if output is None: + output = np.repeat(cval, coords.shape[1]) + else: + output[...] = cval + + coords_sub = [] + for i in range(coords.shape[0]): + coords_sub.append(np.round(coords[i][keep]).astype(int)) + coords_sub = tuple(coords_sub) + + output[keep] = image.vindex[coords_sub] + + return output + + def map_coordinates( image, coords, max_chunk_size=None, output=None, optimize_memory=False, **kwargs ): diff --git a/reproject/common.py b/reproject/common.py index b14f5b7d7..b5bb6f20d 100644 --- a/reproject/common.py +++ b/reproject/common.py @@ -66,6 +66,7 @@ def _reproject_dispatcher( parallel=True, reproject_func_kwargs=None, return_type=None, + dask_method="memmap", ): """ Main function that handles either calling the core algorithms directly or @@ -116,6 +117,7 @@ def _reproject_dispatcher( If this is set to 'pil_image', a PIL ``Image`` object is returned. The 'pil_image' option can only be used if the input was RGB images or if the input data has shape (3, ny, nx) and contains values between 0 and 255. + dask_method : {'memmap', 'vindex', 'nothing'} """ logger = logging.getLogger(__name__) @@ -162,7 +164,7 @@ def _reproject_dispatcher( "been specified" ) - if isinstance(array_in, da.core.Array): + if isinstance(array_in, da.core.Array) and dask_method == "memmap": logger.info("Computing input dask array to Numpy memory-mapped array") array_path, array_in = _dask_to_numpy_memmap(array_in, local_tmp_dir) logger.info(f"Numpy memory-mapped array is now at {array_path}") @@ -178,6 +180,7 @@ def _reproject_dispatcher( array_out=array_out, return_footprint=return_footprint, output_footprint=output_footprint, + dask_method=dask_method, **reproject_func_kwargs, ) if return_type == "pil_image": @@ -210,7 +213,9 @@ def _reproject_dispatcher( "shape": array_in.shape, "offset": array_in.offset, } - elif isinstance(array_in, da.core.Array) or return_type == "dask": + elif ( + isinstance(array_in, da.core.Array) and dask_method == "memmap" + ) or return_type == "dask": if return_type == "dask": # We should use a temporary directory that will persist beyond # the call to the reproject function. @@ -218,6 +223,9 @@ def _reproject_dispatcher( else: tmp_dir = local_tmp_dir array_in_or_path = as_delayed_memmap_path(_ArrayContainer(array_in), tmp_dir) + elif isinstance(array_in, da.core.Array) and dask_method != "memmap": + dask_arrays = {'array': array_in} + array_in_or_path = 'from-dict' else: # Here we could set array_in_or_path to array_in_path if it has # been set previously, but in synchronous and threaded mode it is @@ -228,6 +236,7 @@ def _reproject_dispatcher( def reproject_single_block(a, array_or_path, block_info=None): + if ( a.ndim == 0 or block_info is None @@ -236,6 +245,12 @@ def reproject_single_block(a, array_or_path, block_info=None): ): return np.array([a, a]) + print(array_or_path, type(array_or_path)) + + if isinstance(array_or_path, str) and array_or_path == 'from-dict': + array_or_path = dask_arrays['array'] + + # The WCS class from astropy is not thread-safe, see e.g. # https://github.com/astropy/astropy/issues/16244 # https://github.com/astropy/astropy/issues/16245 @@ -276,6 +291,7 @@ def reproject_single_block(a, array_or_path, block_info=None): wcs_out_sub, shape_out=shape_out, array_out=np.zeros(shape_out), + dask_method=dask_method, **reproject_func_kwargs, ) @@ -307,6 +323,11 @@ def reproject_single_block(a, array_or_path, block_info=None): logger.info("Setting up output dask array with map_blocks") + print("DOING MAP BLOCKS") + + print(array_out_dask) + print(array_out_dask.chunksize) + print(type(array_in_or_path)) result = da.map_blocks( reproject_single_block, array_out_dask, @@ -320,6 +341,8 @@ def reproject_single_block(a, array_or_path, block_info=None): array_in = None array_in_or_path = None + print("FINALLY HERE") + # Truncate extra elements result = result[tuple([slice(None)] + [slice(s) for s in shape_out])] @@ -366,6 +389,8 @@ def reproject_single_block(a, array_or_path, block_info=None): logger.info("Copying output zarr array into output Numpy arrays") + print('HERE, ABOUT TO STORE') + if return_footprint: da.store( [result[0], result[1]], diff --git a/reproject/interpolation/core.py b/reproject/interpolation/core.py index d3021992c..380ce3af6 100644 --- a/reproject/interpolation/core.py +++ b/reproject/interpolation/core.py @@ -4,7 +4,7 @@ from astropy.wcs import WCS from astropy.wcs.utils import pixel_to_pixel -from ..array_utils import map_coordinates +from ..array_utils import map_coordinates, map_coordinates_vindex from ..wcs_utils import has_celestial, pixel_to_pixel_with_roundtrip @@ -55,6 +55,7 @@ def _reproject_full( return_footprint=True, roundtrip_coords=True, output_footprint=None, + dask_method=None, ): """ Reproject n-dimensional data to a new projection using interpolation. @@ -117,16 +118,27 @@ def _reproject_full( # Loop over the broadcasted dimensions in our array, reusing the same # computed transformation each time for i in range(len(array)): + print('dask_method', dask_method) # Interpolate array on to the pixels coordinates in pixel_in - map_coordinates( - array[i], - pixel_in, - order=order, - cval=np.nan, - mode="constant", - output=array_out_loopable[i].ravel(), - max_chunk_size=256 * 1024**2, - ) + if dask_method == "vindex": + map_coordinates_vindex( + array[i], + pixel_in, + order=order, + cval=np.nan, + mode="constant", + output=array_out_loopable[i].ravel(), + ) + else: + map_coordinates( + array[i], + pixel_in, + order=order, + cval=np.nan, + mode="constant", + output=array_out_loopable[i].ravel(), + max_chunk_size=256 * 1024**2, + ) # n.b. We write the reprojected data into array_out_loopable, but array_out # also contains this data and has the user's desired output shape. diff --git a/reproject/interpolation/high_level.py b/reproject/interpolation/high_level.py index 276e36620..1e9b82082 100644 --- a/reproject/interpolation/high_level.py +++ b/reproject/interpolation/high_level.py @@ -27,6 +27,7 @@ def reproject_interp( block_size=None, parallel=False, return_type=None, + dask_method="memmap", ): """ Reproject data to a new projection using interpolation (this is typically @@ -149,4 +150,5 @@ def reproject_interp( roundtrip_coords=roundtrip_coords, ), return_type=return_type, + dask_method=dask_method, ) diff --git a/reproject/mosaicking/coadd.py b/reproject/mosaicking/coadd.py index 9bd3c5b34..913a106e1 100644 --- a/reproject/mosaicking/coadd.py +++ b/reproject/mosaicking/coadd.py @@ -270,6 +270,8 @@ def reproject_and_coadd( slice_out = tuple([slice(imin, imax) for (imin, imax) in bounds]) + print(slice_out) + if isinstance(wcs_out, WCS): wcs_out_indiv = wcs_out[slice_out] else: diff --git a/reproject/tests/test_array_utils.py b/reproject/tests/test_array_utils.py index c424eda4d..e9a77ab22 100644 --- a/reproject/tests/test_array_utils.py +++ b/reproject/tests/test_array_utils.py @@ -1,8 +1,9 @@ import numpy as np +from dask import array as da from numpy.testing import assert_allclose from scipy.ndimage import map_coordinates as scipy_map_coordinates -from reproject.array_utils import map_coordinates +from reproject.array_utils import map_coordinates, map_coordinates_vindex def test_custom_map_coordinates(): @@ -37,3 +38,37 @@ def test_custom_map_coordinates(): ) assert_allclose(result, expected) + + +def test_custom_map_coordinates_vindex(): + np.random.seed(1249) + + data = np.random.random((3, 4)) + + coords = np.random.uniform(0, 3, (2, 10000)) + + expected = scipy_map_coordinates( + np.pad(data, 1, mode="edge"), + coords + 1, + order=0, + cval=np.nan, + mode="constant", + ) + + reset = np.zeros(coords.shape[1], dtype=bool) + + for i in range(coords.shape[0]): + reset |= coords[i] < -0.5 + reset |= coords[i] > data.shape[i] - 0.5 + + expected[reset] = np.nan + + result = map_coordinates_vindex( + da.from_array(data), + coords, + order=0, + cval=np.nan, + mode="constant", + ) + + assert_allclose(result, expected) diff --git a/reproject/utils.py b/reproject/utils.py index 8701982f5..c5100b2e4 100644 --- a/reproject/utils.py +++ b/reproject/utils.py @@ -83,6 +83,8 @@ def hdu_to_numpy_memmap(hdu): array backed by a memmapped buffer as returned by astropy. """ + print(type(hdu)) + if ( hdu.header.get("BSCALE", 1) != 1 or hdu.header.get("BZERO", 0) != 0 @@ -106,6 +108,8 @@ def parse_input_data(input_data, hdu_in=None, source_hdul=None): Parse input data to return a Numpy array and WCS object. """ + print(type(input_data)) + if isinstance(input_data, str | Path): if is_png(input_data) or is_jpeg(input_data): data = np.array(Image.open(input_data)).transpose(2, 0, 1)[:, ::-1] @@ -124,8 +128,10 @@ def parse_input_data(input_data, hdu_in=None, source_hdul=None): else: hdu_in = 0 return parse_input_data(input_data[hdu_in], source_hdul=input_data) - elif isinstance(input_data, PrimaryHDU | ImageHDU | CompImageHDU): + elif isinstance(input_data, PrimaryHDU | ImageHDU) and not isinstance(input_data, CompImageHDU): return (hdu_to_numpy_memmap(input_data), WCS(input_data.header, fobj=source_hdul)) + elif isinstance(input_data, CompImageHDU): + return (input_data.data, WCS(input_data.header, fobj=source_hdul)) elif isinstance(input_data, tuple) and isinstance(input_data[0], np.ndarray | da.core.Array): if isinstance(input_data[1], Header): return input_data[0], WCS(input_data[1])