Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions streamz/dask.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import absolute_import, division, print_function
from functools import wraps

from .core import _truthy
from operator import getitem

from tornado import gen
Expand All @@ -11,6 +13,21 @@
from . import core, sources


# Sentinel returned in place of an element that a filter has rejected.
NULL_COMPUTE = "~~NULL_COMPUTE~~"


def return_null(func):
    """Wrap a predicate so it yields its input or ``NULL_COMPUTE``.

    Parameters
    ----------
    func : callable
        Predicate, called as ``func(x, *args, **kwargs)``.

    Returns
    -------
    callable
        ``inner(x, *args, **kwargs)``.  It returns ``x`` unchanged when the
        predicate result is truthy and ``NULL_COMPUTE`` otherwise.  If ``x``
        already is ``NULL_COMPUTE`` the predicate is not called and the
        sentinel is passed straight through.
    """

    @wraps(func)
    def inner(x, *args, **kwargs):
        # An element rejected by an earlier filter must stay rejected.
        # Handing the sentinel string to the user's predicate could raise
        # (e.g. ``x % 2``) or let the sentinel through as if it were data.
        # The isinstance check avoids element-wise ``==`` on array-likes.
        if isinstance(x, str) and x == NULL_COMPUTE:
            return NULL_COMPUTE
        if func(x, *args, **kwargs):
            return x
        return NULL_COMPUTE

    return inner


class DaskStream(Stream):
""" A Parallel stream using Dask

Expand Down Expand Up @@ -140,6 +157,24 @@ def update(self, x, who=None):
return self._emit(result)


@DaskStream.register_api()
class filter(DaskStream):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you also need the modifications to the gather and other nodes as well.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have already made the changes to gather, and I also compared the other nodes. There is a slight difference between dask.starmap and parallel.starmap. Do I need to change that one as well?

def __init__(self, upstream, predicate, *args, **kwargs):
    """Set up a filter node.

    Parameters
    ----------
    upstream :
        Passed on to ``DaskStream.__init__``.
    predicate : callable or None
        Test applied to each element; ``None`` selects ``_truthy``.
    *args, **kwargs
        Stored on the node and forwarded to the predicate on every
        update.  ``stream_name`` is removed from ``kwargs`` and given to
        ``DaskStream.__init__`` instead.
    """
    if predicate is None:
        predicate = _truthy
    # Wrapped so the submitted task returns the element itself (or the
    # NULL_COMPUTE sentinel) rather than the predicate's boolean.
    self.predicate = return_null(predicate)
    # Pop before storing kwargs so ``stream_name`` never reaches the
    # predicate.
    stream_name = kwargs.pop("stream_name", None)
    self.kwargs = kwargs
    self.args = args

    DaskStream.__init__(self, upstream, stream_name=stream_name)

def update(self, x, who=None):
    """Submit the wrapped predicate for ``x`` and emit the result.

    Parameters
    ----------
    x :
        Incoming element, handed to ``client.submit`` as the first task
        argument (presumably a Dask future produced upstream -- TODO
        confirm).
    who : optional
        Not used by this method.

    Returns
    -------
    Whatever ``self._emit`` returns for the submitted task.
    """
    client = self.default_client()
    # The stored positional and keyword arguments follow the element.
    result = client.submit(self.predicate, x, *self.args, **self.kwargs)
    return self._emit(result)


@DaskStream.register_api()
class buffer(DaskStream, core.buffer):
    """Buffer node registered on ``DaskStream``.

    Combines ``DaskStream`` with ``core.buffer`` and defines nothing of its
    own.
    """
Expand Down
61 changes: 61 additions & 0 deletions streamz/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,67 @@ def test_buffer(c, s, a, b):
assert source.loop == c.loop


@pytest.mark.slow
@gen_cluster(client=True)
def test_filter(c, s, a, b):
    """Gathering a filtered Dask stream yields only the accepted elements."""
    # Review outcome: the ``backend`` fixture and the ``backend=`` argument
    # to ``scatter`` were removed because streamz only has a dask backend.
    # NOTE(review): assumes ``gen_cluster`` is imported in this module, as
    # the ``(c, s, a, b)`` signature of ``test_buffer`` suggests -- confirm.
    source = Stream(asynchronous=True)
    futures = scatter(source).filter(lambda x: x % 2 == 0)
    futures_L = futures.sink_to_list()
    L = futures.gather().sink_to_list()

    for i in range(5):
        yield source.emit(i)

    assert L == [0, 2, 4]
    assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.slow
@gen_cluster(client=True)
def test_filter_buffer(c, s, a, b):
    """A filtered Dask stream still gathers correctly through a buffer."""
    # The ``backend`` fixture was dropped: streamz only has a dask backend.
    # NOTE(review): assumes ``gen_cluster`` is imported in this module, as
    # the ``(c, s, a, b)`` signature of ``test_buffer`` suggests -- confirm.
    source = Stream(asynchronous=True)
    futures = scatter(source).filter(lambda x: x % 2 == 0)
    futures_L = futures.sink_to_list()
    L = futures.buffer(10).gather().sink_to_list()

    for i in range(5):
        yield source.emit(i)
    # Wait after all emits: the three accepted results only exist once every
    # element has been sent, so waiting inside the loop could never finish.
    while len(L) < 3:
        yield gen.sleep(.01)

    assert L == [0, 2, 4]
    assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.slow
@gen_cluster(client=True)
def test_filter_map(c, s, a, b):
    """Elements rejected by the filter are not seen after a following map."""
    # The ``backend`` fixture was dropped: streamz only has a dask backend.
    # NOTE(review): assumes ``gen_cluster`` is imported in this module, as
    # the ``(c, s, a, b)`` signature of ``test_buffer`` suggests -- confirm.
    source = Stream(asynchronous=True)
    futures = scatter(source).filter(lambda x: x % 2 == 0).map(inc)
    futures_L = futures.sink_to_list()
    L = futures.gather().sink_to_list()

    for i in range(5):
        yield source.emit(i)

    assert L == [1, 3, 5]
    assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.slow
@gen_cluster(client=True)
def test_filter_starmap(c, s, a, b):
    """Tuples rejected by the filter are not seen after a following starmap."""
    # The ``backend`` fixture was dropped: streamz only has a dask backend.
    # NOTE(review): assumes ``gen_cluster`` is imported in this module, as
    # the ``(c, s, a, b)`` signature of ``test_buffer`` suggests -- confirm.
    source = Stream(asynchronous=True)
    futures1 = scatter(source).filter(lambda x: x[1] % 2 == 0)
    futures = futures1.starmap(add)
    futures_L = futures.sink_to_list()
    L = futures.gather().sink_to_list()

    for i in range(5):
        yield source.emit((i, i))

    assert L == [0, 4, 8]
    assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.slow
def test_buffer_sync(loop): # noqa: F811
with cluster() as (s, [a, b]):
Expand Down