Skip to content

Commit cdffa3c

Browse files
committed
fixed bug in dask+cupy args and added unit test
1 parent a536d2a commit cdffa3c

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e
209209
| Name | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray |
210210
|:----------:|:----------------------:|:--------------------:|:-------------------:|:------:|
211211
| [Aspect](xrspatial/aspect.py) | ✅️ | ✅️ | ✅️ | ✅️ |
212-
| [Curvature](xrspatial/curvature.py) | ✅️ | | | ⚠️ |
212+
| [Curvature](xrspatial/curvature.py) | ✅️ |⚠️✅️ |⚠️✅️ |️✅|
213213
| [Hillshade](xrspatial/hillshade.py) | ✅️ | ✅️ | | |
214214
| [Slope](xrspatial/slope.py) | ✅️ | ✅️ | ✅️ | ⚠️✅️ |
215215
| [Terrain Generation](xrspatial/terrain.py) | ✅️ | ✅️ | ✅️ | |

xrspatial/curvature.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ class cupy(object):
1515
from numba import cuda
1616

1717
# local modules
18-
from xrspatial.utils import (ArrayTypeFunctionMapping, cuda_args, get_dataarray_resolution, ngjit,
19-
not_implemented_func)
18+
from xrspatial.utils import (ArrayTypeFunctionMapping, cuda_args, get_dataarray_resolution, ngjit)
2019

2120

2221
@ngjit
@@ -84,10 +83,14 @@ def _run_cupy(data: cupy.ndarray,
8483

8584
return out
8685

86+
8787
def _run_dask_cupy(data: da.Array,
88-
cellsize: Union[int, float]) -> da.Array:
88+
cellsize: Union[int, float]) -> da.Array:
8989
data = data.astype(cupy.float32)
90-
_func = partial(_cpu, cellsize=cellsize)
90+
cellsize_arr = cupy.array([float(cellsize)], dtype='f4')
91+
92+
_func = partial(_run_cupy, cellsize=cellsize_arr)
93+
9194
out = data.map_overlap(_func,
9295
depth=(1, 1),
9396
boundary=cupy.nan,

xrspatial/tests/test_curvature.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from xrspatial import curvature
55
from xrspatial.tests.general_checks import (assert_numpy_equals_cupy,
6+
assert_numpy_equals_dask_cupy,
67
assert_numpy_equals_dask_numpy, create_test_raster,
78
cuda_and_cupy_available, general_output_checks)
89

@@ -87,9 +88,6 @@ def test_numpy_equals_cupy_random_data(random_data):
8788
numpy_agg = create_test_raster(random_data, backend='numpy')
8889
cupy_agg = create_test_raster(random_data, backend='cupy')
8990
assert_numpy_equals_cupy(numpy_agg, cupy_agg, curvature)
90-
# NOTE: Dask + GPU code paths don't currently work because of
91-
# dask casting cupy arrays to numpy arrays during
92-
# https://github.com/dask/dask/issues/4842
9391

9492

9593
@pytest.mark.parametrize("size", [(2, 4), (10, 15)])
@@ -99,3 +97,13 @@ def test_numpy_equals_dask_random_data(random_data):
9997
numpy_agg = create_test_raster(random_data, backend='numpy')
10098
dask_agg = create_test_raster(random_data, backend='dask')
10199
assert_numpy_equals_dask_numpy(numpy_agg, dask_agg, curvature)
100+
101+
102+
@cuda_and_cupy_available
103+
@pytest.mark.parametrize("size", [(2, 4), (10, 15)])
104+
@pytest.mark.parametrize(
105+
"dtype", [np.int32, np.int64, np.uint32, np.uint64, np.float32, np.float64])
106+
def test_numpy_equals_dask_cupy_random_data(random_data):
107+
numpy_agg = create_test_raster(random_data, backend='numpy')
108+
dask_cupy_agg = create_test_raster(random_data, backend='dask+cupy')
109+
assert_numpy_equals_dask_cupy(numpy_agg, dask_cupy_agg, curvature, atol=1e-6, rtol=1e-6)

0 commit comments

Comments
 (0)