diff --git a/streamz/dask.py b/streamz/dask.py index d0c9d4e2..ae11a7a2 100644 --- a/streamz/dask.py +++ b/streamz/dask.py @@ -1,5 +1,7 @@ from __future__ import absolute_import, division, print_function +from functools import wraps +from .core import _truthy from operator import getitem from tornado import gen @@ -10,6 +12,36 @@ from .core import Stream from . import core, sources +from collections import Sequence + + +NULL_COMPUTE = "~~NULL_COMPUTE~~" + + +def return_null(func): + @wraps(func) + def inner(x, *args, **kwargs): + tv = func(x, *args, **kwargs) + if tv: + return x + else: + return NULL_COMPUTE + + return inner + + +def filter_null_wrapper(func): + @wraps(func) + def inner(*args, **kwargs): + if any(a is NULL_COMPUTE for a in args) or any( + v is NULL_COMPUTE for v in kwargs.values() + ): + return NULL_COMPUTE + else: + return func(*args, **kwargs) + + return inner + class DaskStream(Stream): """ A Parallel stream using Dask @@ -46,7 +78,7 @@ def __init__(self, *args, **kwargs): @DaskStream.register_api() class map(DaskStream): def __init__(self, upstream, func, *args, **kwargs): - self.func = func + self.func = filter_null_wrapper(func) self.kwargs = kwargs self.args = args @@ -117,12 +149,20 @@ class gather(core.Stream): buffer scatter """ + @gen.coroutine def update(self, x, who=None): client = default_client() result = yield client.gather(x, asynchronous=True) - result2 = yield self._emit(result) - raise gen.Return(result2) + if ( + not ( + isinstance(result, Sequence) + and any(r is NULL_COMPUTE for r in result) + ) + and result is not NULL_COMPUTE + ): + result2 = yield self._emit(result) + raise gen.Return(result2) @DaskStream.register_api() @@ -140,6 +180,24 @@ def update(self, x, who=None): return self._emit(result) +@DaskStream.register_api() +class filter(DaskStream): + def __init__(self, upstream, predicate, *args, **kwargs): + if predicate is None: + predicate = _truthy + self.predicate = return_null(predicate) + stream_name = kwargs.pop("stream_name", None) + self.kwargs = kwargs + self.args = args + + DaskStream.__init__(self, upstream, stream_name=stream_name) + + def update(self, x, who=None): + client = default_client() + result = client.submit(self.predicate, x, *self.args, **self.kwargs) + return self._emit(result) + + @DaskStream.register_api() class buffer(DaskStream, core.buffer): pass diff --git a/streamz/tests/test_dask.py b/streamz/tests/test_dask.py index d4da4fc4..9be7ff57 100644 --- a/streamz/tests/test_dask.py +++ b/streamz/tests/test_dask.py @@ -131,6 +131,67 @@ def test_buffer(c, s, a, b): assert source.loop == c.loop +@pytest.mark.slow +def test_filter(): + source = Stream(asynchronous=True) + futures = scatter(source).filter(lambda x: x % 2 == 0) + futures_L = futures.sink_to_list() + L = futures.gather().sink_to_list() + + for i in range(5): + yield source.emit(i) + + assert L == [0, 2, 4] + assert all(isinstance(f, Future) for f in futures_L) + + +@pytest.mark.slow +def test_filter_buffer(): + source = Stream(asynchronous=True) + futures = scatter(source).filter(lambda x: x % 2 == 0) + futures_L = futures.sink_to_list() + L = futures.buffer(10).gather().sink_to_list() + + for i in range(5): + yield source.emit(i) + while len(L) < 3: + yield gen.sleep(.01) + + assert L == [0, 2, 4] + assert all(isinstance(f, Future) for f in futures_L) + + +@pytest.mark.slow +def test_filter_map(): + source = Stream(asynchronous=True) + futures = ( + scatter(source).filter(lambda x: x % 2 == 0).map(inc) + ) + futures_L = futures.sink_to_list() + L = futures.gather().sink_to_list() + + for i in range(5): + yield source.emit(i) + + assert L == [1, 3, 5] + assert all(isinstance(f, Future) for f in futures_L) + + +@pytest.mark.slow +def test_filter_starmap(): + source = Stream(asynchronous=True) + futures1 = scatter(source).filter(lambda x: x[1] % 2 == 0) + futures = futures1.starmap(add) + futures_L = futures.sink_to_list() + L = futures.gather().sink_to_list() + + for i in range(5): + yield source.emit((i, i)) + + assert L == [0, 4, 8] + assert all(isinstance(f, Future) for f in futures_L) + + @pytest.mark.slow def test_buffer_sync(loop): # noqa: F811 with cluster() as (s, [a, b]):