Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion reproject/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def _reproject_dispatcher(
array_path, array_in = _dask_to_numpy_memmap(array_in, local_tmp_dir)
logger.info(f"Numpy memory-mapped array is now at {array_path}")

logger.info(f"Calling {reproject_func.__name__} in non-dask mode")
logger.debug(f"Calling {reproject_func.__name__} in non-dask mode")

try:
return reproject_func(
Expand Down
136 changes: 111 additions & 25 deletions reproject/hips/_dask_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,81 +4,160 @@
import uuid

import numpy as np
from astropy import units as u
from astropy.coordinates import SpectralCoord
from astropy.io import fits
from astropy.utils.data import download_file
from astropy.wcs import WCS
from astropy_healpix import HEALPix, level_to_nside
from dask import array as da

from .high_level import VALID_COORD_SYSTEM
from .utils import is_url, load_properties, map_header, tile_filename
from .utils import (
is_url,
load_properties,
map_header,
skycoord_first,
spectral_coord_to_index,
tile_filename,
)

__all__ = ["hips_as_dask_array"]


class HiPSArray:

def __init__(self, directory_or_url, level=None):
def __init__(self, directory_or_url, level=None, level_depth=None):
# Set up the HiPS dataset at ``directory_or_url`` for tile-based array access.
#
# NOTE(review): this text appears to be a rendered diff in which superseded
# and replacement lines sit back to back (e.g. the two ``def`` lines above
# and the paired ``self._hp`` / ``self._header`` assignments below) - confirm
# against the merged file. The notes here describe the replacement lines.
#
# Parameters
# ----------
# directory_or_url
#     Location of the HiPS dataset; checked with ``is_url`` and passed to
#     ``load_properties``.
# level : optional
#     Spatial level to use. Defaults to the ``hips_order`` property and must
#     lie between 0 and ``hips_order`` inclusive.
# level_depth : optional
#     Spectral level, only consulted when ``dataproduct_type`` is
#     ``spectral-cube``. Must lie between 0 and ``hips_order_freq`` inclusive;
#     when omitted it is derived from the spatial level.
#
# Raises
# ------
# TypeError
#     If ``dataproduct_type`` is neither ``image`` nor ``spectral-cube``.
# ValueError
#     If ``level`` or ``level_depth`` is outside the range given above.

self._directory_or_url = directory_or_url

self._is_url = is_url(directory_or_url)

self._properties = load_properties(directory_or_url)

if self._properties["dataproduct_type"] == "image":
self.ndim = 2
elif self._properties["dataproduct_type"] == "spectral-cube":
self.ndim = 3
else:
raise TypeError(f"HiPS type {self._properties['dataproduct_type']} not recognized")

self._tile_width = int(self._properties["hips_tile_width"])
self._order = int(self._properties["hips_order"])
self._order_spatial = int(self._properties["hips_order"])

if level is None:
self._level = self._order
self._level_spatial = self._order_spatial
else:
if level > self._order:
if level > self._order_spatial:
raise ValueError(
f"HiPS dataset at {directory_or_url} does not contain level {level} data"
f"HiPS dataset at {directory_or_url} does not contain spatial level {level} data"
)
# Level 0 is accepted; only negative values are rejected.
elif level < 0:
raise ValueError("level should be positive")
else:
self._level = int(level)
self._level = self._order if level is None else level
self._level_spatial = int(level)

if self.ndim == 3:

# TODO: check here that the spatial and spectral levels are consistent;
# it may be better not to allow the spectral level to be passed in at all

self._tile_depth = int(self._properties["hips_tile_depth"])
self._order_depth = int(self._properties["hips_order_freq"])

if level_depth is None:
# Lower the spectral level by as many steps as the spatial level was
# lowered relative to the dataset's maximum spatial order.
# NOTE(review): this derived value is not range-checked and looks like
# it could go negative when hips_order_freq is smaller than the spatial
# reduction - confirm.
self._level_depth = self._order_depth - (self._order_spatial - self._level_spatial)
else:
if level_depth > self._order_depth:
raise ValueError(
f"HiPS dataset at {directory_or_url} does not contain spectral level {level_depth} data"
)
# Level 0 is accepted; only negative values are rejected.
elif level_depth < 0:
raise ValueError("level_depth should be positive")
else:
self._level_depth = int(level_depth)

# For cubes, level and tile dimensions are (spatial, spectral) pairs
self._level = (self._level_spatial, self._level_depth)
self._tile_dims = (self._tile_width, self._tile_depth)

else:

# For images, level and tile dimensions are plain scalars
self._level_depth = None
self._level = self._level_spatial
self._tile_dims = self._tile_width

self._tile_format = self._properties["hips_tile_format"]
self._frame_str = self._properties["hips_frame"]
self._frame = VALID_COORD_SYSTEM[self._frame_str]

self._hp = HEALPix(nside=level_to_nside(self._level), frame=self._frame, order="nested")
self._hp = HEALPix(
nside=level_to_nside(self._level_spatial), frame=self._frame, order="nested"
)

self._header = map_header(level=self._level, frame=self._frame, tile_size=self._tile_width)
self._header = map_header(level=self._level, frame=self._frame, tile_dims=self._tile_dims)

self.wcs = WCS(self._header)
self.shape = self.wcs.array_shape

# Determine actual spectral range, because we don't actually want to
# create a dask array with the full possible range of spectral indices
# since this will be huge and unnecessary

if self.ndim == 3:

# em_min / em_max are interpreted as wavelengths in metres
wav_min = SpectralCoord(float(self._properties["em_min"]), u.m)
wav_max = SpectralCoord(float(self._properties["em_max"]), u.m)

index_min = spectral_coord_to_index(self._level_depth, wav_min)
index_max = spectral_coord_to_index(self._level_depth, wav_max)

# Make sure the two indices are in ascending order
if index_min > index_max:
index_min, index_max = index_max, index_min

# Make the upper bound exclusive so that the last tile is included
index_max += 1

# Convert from tile indices to positions along the first array axis
index_min *= self._tile_depth
index_max *= self._tile_depth

self.wcs = self.wcs[index_min:index_max]
self.shape = (index_max - index_min,) + self.shape[1:]

# FIXME: the following is hard-coded
self.dtype = float
self.ndim = 2

self.chunksize = (self._tile_width, self._tile_width)
# One chunk corresponds to one tile
if self.ndim == 2:
self.chunksize = (self._tile_width, self._tile_width)
else:
self.chunksize = (self._tile_depth, self._tile_width, self._tile_width)

# Chunk-sized array filled with NaN
self._nan = np.nan * np.ones(self.chunksize, dtype=self.dtype)

# NaN broadcast to the full array shape
self._blank = np.broadcast_to(np.nan, self.shape)

def __getitem__(self, item):

if item[0].start == item[0].stop or item[1].start == item[1].stop:
return self._blank[item]
for idx in range(self.ndim):
if item[idx].start == item[idx].stop:
return self._blank[item]

# We use two points in different parts of the image because in some
# cases using the exact center or corners can cause issues.
# Determine spatial healpix index - we use two points in different
# parts of the image because in some cases using the exact center or
# corners can cause issues.

istart = item[0].start
irange = item[0].stop - item[0].start
istart = item[-2].start
irange = item[-2].stop - item[-2].start
imid = np.array([istart + 0.25 * irange, istart + 0.75 * irange])

jstart = item[1].start
jrange = item[1].stop - item[1].start
jstart = item[-1].start
jrange = item[-1].stop - item[-1].start
jmid = np.array([jstart + 0.25 * jrange, jstart + 0.75 * jrange])

# Convert pixel coordinates to HEALPix indices

coord = self.wcs.pixel_to_world(jmid, imid)
if self.ndim == 2:
coord = self.wcs.pixel_to_world(jmid, imid)
else:
kmid = 0.5 * (item[0].start + item[0].stop)
coord, spectral_coord = skycoord_first(self.wcs.pixel_to_world(jmid, imid, kmid))

if self._frame_str == "equatorial":
lon, lat = coord.ra.deg, coord.dec.deg
Expand All @@ -94,14 +173,21 @@ def __getitem__(self, item):
elif np.any(invalid):
coord = coord[~invalid]

index = self._hp.skycoord_to_healpix(coord)
spatial_index = self._hp.skycoord_to_healpix(coord)

if np.all(index == -1):
if np.all(spatial_index == -1):
return self._nan

index = np.max(index)
spatial_index = np.max(spatial_index)

# Determine spectral index, if needed
if self.ndim == 3:
spectral_index = spectral_coord_to_index(self._level_depth, spectral_coord).max()
index = (spatial_index, spectral_index)
else:
index = spatial_index

return self._get_tile(level=self._level, index=index)
return self._get_tile(level=self._level, index=index).astype(float)

@functools.lru_cache(maxsize=128) # noqa: B019
def _get_tile(self, *, level, index):
Expand Down
22 changes: 9 additions & 13 deletions reproject/hips/high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
ICRS,
BarycentricTrueEcliptic,
Galactic,
SkyCoord,
SpectralCoord,
)
from astropy.io import fits
from astropy.nddata import block_reduce
Expand All @@ -34,6 +32,7 @@
load_properties,
make_tile_folders,
save_properties,
skycoord_first,
spectral_coord_to_index,
tile_filename,
tile_header,
Expand Down Expand Up @@ -258,14 +257,8 @@ def reproject_to_hips(
cen_skycoord = wcs_in.pixel_to_world(*centers)
cor_skycoord = wcs_in.pixel_to_world(*edges)
else:
for w in wcs_in.pixel_to_world(*centers):
if isinstance(w, SkyCoord):
cen_skycoord = w
for w in wcs_in.pixel_to_world(*edges):
if isinstance(w, SkyCoord):
cor_skycoord = w
if isinstance(w, SpectralCoord):
cor_spectralcoord = w
cen_skycoord, _ = skycoord_first(wcs_in.pixel_to_world(*centers))
cor_skycoord, cor_spectralcoord = skycoord_first(wcs_in.pixel_to_world(*edges))

separations = cor_skycoord.separation(cen_skycoord)

Expand All @@ -280,7 +273,10 @@ def reproject_to_hips(
ran_x = np.random.uniform(-0.5, nx - 0.5, n_ran)
ran_y = np.random.uniform(-0.5, nx - 0.5, n_ran)

ran_world = wcs_in.pixel_to_world(ran_x, ran_y)
if ndim == 2:
ran_world = wcs_in.pixel_to_world(ran_x, ran_y)
elif ndim == 3:
ran_world, _ = skycoord_first(wcs_in.pixel_to_world(ran_x, ran_y, np.zeros(n_ran)))

separations = ran_world[:, None].separation(ran_world[None, :])

Expand Down Expand Up @@ -327,7 +323,7 @@ def reproject_to_hips(

# Determine all the spectral indices at the highest spectral level
spectral_indices_edges = spectral_coord_to_index(level_depth, cor_spectralcoord)
spectral_indices = np.arange(spectral_indices_edges.min(), spectral_indices_edges.max())
spectral_indices = np.arange(spectral_indices_edges.min(), spectral_indices_edges.max() + 1)
indices = [
(int(idx), int(spec_idx)) for (idx, spec_idx) in product(indices, spectral_indices)
]
Expand Down Expand Up @@ -356,6 +352,7 @@ def process(index):
header = tile_header(level=level, index=index, frame=frame, tile_dims=tile_dims)

if isinstance(header, tuple):

array_out1, footprint1 = reproject_function(
(array_in, wcs_in_copy), header[0], **kwargs
)
Expand Down Expand Up @@ -592,7 +589,6 @@ def process(index):
generated_properties["hips_order_freq"] = level_depth
generated_properties["hips_order_min"] = 0
generated_properties["hips_tile_depth"] = tile_depth
generated_properties["hips_tile_depth"] = tile_depth
wav = cor_spectralcoord.to_value(u.m)
generated_properties["em_min"] = wav.min()
generated_properties["em_max"] = wav.max()
Expand Down
69 changes: 68 additions & 1 deletion reproject/hips/tests/test_dask_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ def setup_method(self):
self.original_wcs = WCS(hdu.header)
self.original_array = hdu.data.size + np.arange(hdu.data.size).reshape(hdu.data.shape)

self.original_array_3d = self.original_array.reshape((1,) + self.original_array.shape)
self.original_array_3d = self.original_array_3d * np.arange(1, 11).reshape((10, 1, 1))
assert self.original_array_3d.shape == (10, 240, 480)

self.original_wcs_3d = self.original_wcs.sub([1, 2, 0])
self.original_wcs_3d.wcs.ctype[2] = "FREQ"
self.original_wcs_3d.wcs.crval[2] = 1e10
self.original_wcs_3d.wcs.cdelt[2] = 1e9
self.original_wcs_3d.wcs.crpix[2] = 1
self.original_wcs_3d._naxis = list(self.original_array_3d.shape[::-1])

@pytest.mark.parametrize("frame", ("galactic", "equatorial"))
@pytest.mark.parametrize("level", (0, 1))
def test_roundtrip(self, tmp_path, frame, level):
Expand All @@ -31,6 +42,7 @@ def test_roundtrip(self, tmp_path, frame, level):
(self.original_array, self.original_wcs),
coord_system_out=frame,
level=1,
level_depth=6,
reproject_function=reproject_interp,
output_directory=output_directory,
tile_size=256,
Expand Down Expand Up @@ -82,8 +94,63 @@ def test_level_validation(self, tmp_path):
dask_array, wcs = hips_as_dask_array(output_directory)
assert dask_array.shape == (320, 320)

with pytest.raises(Exception, match=r"does not contain level 2 data"):
with pytest.raises(Exception, match=r"does not contain spatial level 2 data"):
hips_as_dask_array(output_directory, level=2)

with pytest.raises(Exception, match=r"should be positive"):
hips_as_dask_array(output_directory, level=-1)

@pytest.mark.parametrize("frame", ("galactic", "equatorial"))
@pytest.mark.parametrize("level", (0, 1))
def test_roundtrip_3d(self, tmp_path, frame, level):
# Round-trip check for a 3D cube: write it out with reproject_to_hips, read
# it back with hips_as_dask_array, reproject onto (a subset of) the original
# WCS and compare against the original values.

output_directory = tmp_path / "roundtrip"

# Note that we always use level=1 to generate, but use a variable level
# to construct the dask array - this is deliberate and ensures that the
# dask array has a proper separation of maximum and current level.
reproject_to_hips(
(self.original_array_3d, self.original_wcs_3d),
coord_system_out=frame,
level=1,
reproject_function=reproject_interp,
output_directory=output_directory,
tile_size=32,
tile_depth=8,
)

# Represent the HiPS as a dask array
dask_array, wcs = hips_as_dask_array(output_directory, level=level)

# FIXME: at this point we should be able to do:
#
# Reproject back to the original WCS
# final_array, footprint = reproject_interp(
# (dask_array, wcs),
# self.original_wcs_3d,
# shape_out=self.original_array_3d.shape,
# )
#
# However this does not work properly due to this issue:
# https://github.com/astropy/astropy/issues/18690
#
# For now, we pick a sub-region of the array to check

# Keep the whole first axis, drop the first 50 pixels along the other two
subset = (slice(None), slice(50, None), slice(50, None))

final_array, footprint = reproject_interp(
(dask_array, wcs),
self.original_wcs_3d[subset],
shape_out=self.original_array_3d[subset].shape,
)

# NOTE: The last two channels are empty - this is expected and is due to
# the interpolation on the spectral grid, so only the first 8 are compared

valid = ~np.isnan(final_array)[:8]
assert np.sum(valid) > 450000 # similar to 2D test
# Presumably the tolerance is looser for level 0 because coarser tiles are
# read back than were generated - confirm
np.testing.assert_allclose(
final_array[:8][valid],
self.original_array_3d[subset][:8][valid],
rtol=0.1 if level == 1 else 0.4,
)
Loading
Loading