Skip to content

Commit d7fdb71

Browse files
authored
Merge pull request #434 from martindurant/river
Add River nodes and examples
2 parents e6515f4 + 4cdb0d7 commit d7fdb71

File tree

7 files changed

+390
-123
lines changed

7 files changed

+390
-123
lines changed

examples/river_kmeans.ipynb

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "accbccab",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import random\n",
11+
"\n",
12+
"import pandas as pd\n",
13+
"\n",
14+
"from streamz import Stream\n",
15+
"import hvplot.streamz\n",
16+
"from streamz.river import RiverTrain\n",
17+
"from river import cluster\n",
18+
"import holoviews as hv\n",
19+
"from panel.pane.holoviews import HoloViews\n",
20+
"hv.extension('bokeh')"
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"id": "8a2ef27a",
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"model = cluster.KMeans(n_clusters=3, sigma=0.1, mu=0.5)\n",
31+
"centres = [[random.random(), random.random()] for _ in range(3)]\n",
32+
"\n",
33+
"def gen(move_chance=0.05):\n",
34+
" centre = int(random.random() * 3) # 3x faster than random.randint(0, 2)\n",
35+
" if random.random() < move_chance:\n",
36+
" centres[centre][0] += random.random() / 5 - 0.1\n",
37+
" centres[centre][1] += random.random() / 5 - 0.1\n",
38+
" value = {'x': random.random() / 20 + centres[centre][0],\n",
39+
" 'y': random.random() / 20 + centres[centre][1]}\n",
40+
" return value\n",
41+
"\n",
42+
"\n",
43+
"def get_clusters(model):\n",
44+
" # return [{\"x\": xcen, \"y\": ycen}, ...] for each centre\n",
45+
" data = [{'x': v['x'], 'y': v['y']} for k, v in model.centers.items()]\n",
46+
" return pd.DataFrame(data, index=range(3))"
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": null,
52+
"id": "e6451048",
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
56+
"s = Stream.from_periodic(gen, 0.03)\n",
57+
"km = RiverTrain(model, pass_model=True)\n",
58+
"s.map(lambda x: (x,)).connect(km) # learn takes a tuple of (x,[ y[, w]])\n",
59+
"ex = pd.DataFrame({'x': [0.5], 'y': [0.5]})\n",
60+
"ooo = s.map(lambda x: pd.DataFrame([x])).to_dataframe(example=ex)\n",
61+
"out = km.map(get_clusters)\n",
62+
"\n",
63+
"# start things\n",
64+
"s.emit(gen()) # set initial model\n",
65+
"for i, (x, y) in enumerate(centres):\n",
66+
" model.centers[i]['x'] = x\n",
67+
" model.centers[i]['y'] = y\n"
68+
]
69+
},
70+
{
71+
"cell_type": "code",
72+
"execution_count": null,
73+
"id": "1b4de451",
74+
"metadata": {},
75+
"outputs": [],
76+
"source": [
77+
"pout = out.to_dataframe(example=ex)\n",
78+
"pl = (ooo.hvplot.scatter('x', 'y', color=\"blue\", backlog=50) *\n",
79+
" pout.hvplot.scatter('x', 'y', color=\"red\", backlog=3))\n",
80+
"pl.opts(xlim=(-0.2, 1.2), ylim=(-0.2, 1.2), height=600, width=600)\n",
81+
"pl"
82+
]
83+
},
84+
{
85+
"cell_type": "code",
86+
"execution_count": null,
87+
"id": "c24d2363",
88+
"metadata": {},
89+
"outputs": [],
90+
"source": [
91+
"s.start()"
92+
]
93+
},
94+
{
95+
"cell_type": "code",
96+
"execution_count": null,
97+
"id": "18cfd94e",
98+
"metadata": {},
99+
"outputs": [],
100+
"source": [
101+
"s.stop()"
102+
]
103+
},
104+
{
105+
"cell_type": "code",
106+
"execution_count": null,
107+
"id": "4537495c",
108+
"metadata": {},
109+
"outputs": [],
110+
"source": []
111+
}
112+
],
113+
"metadata": {
114+
"kernelspec": {
115+
"display_name": "Python 3",
116+
"language": "python",
117+
"name": "python3"
118+
},
119+
"language_info": {
120+
"codemirror_mode": {
121+
"name": "ipython",
122+
"version": 3
123+
},
124+
"file_extension": ".py",
125+
"mimetype": "text/x-python",
126+
"name": "python",
127+
"nbconvert_exporter": "python",
128+
"pygments_lexer": "ipython3",
129+
"version": "3.8.8"
130+
}
131+
},
132+
"nbformat": 4,
133+
"nbformat_minor": 5
134+
}

examples/river_kmeans.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import random
2+
3+
import pandas as pd
4+
5+
from streamz import Stream
6+
import hvplot.streamz
7+
from streamz.river import RiverTrain
8+
from river import cluster
9+
import holoviews as hv
10+
from panel.pane.holoviews import HoloViews
11+
hv.extension('bokeh')
12+
13+
model = cluster.KMeans(n_clusters=3, sigma=0.1, mu=0.5)
14+
centres = [[random.random(), random.random()] for _ in range(3)]
15+
count = [0]
16+
17+
def gen(move_chance=0.05):
18+
centre = int(random.random() * 3) # 3x faster than random.randint(0, 2)
19+
if random.random() < move_chance:
20+
centres[centre][0] += random.random() / 5 - 0.1
21+
centres[centre][1] += random.random() / 5 - 0.1
22+
value = {'x': random.random() / 20 + centres[centre][0],
23+
'y': random.random() / 20 + centres[centre][1]}
24+
count[0] += 1
25+
return value
26+
27+
28+
def get_clusters(model):
29+
# return [{"x": xcen, "y": ycen}, ...] for each centre
30+
data = [{'x': v['x'], 'y': v['y']} for k, v in model.centers.items()]
31+
return pd.DataFrame(data, index=range(3))
32+
33+
34+
def main(viz=True):
35+
# setup pipes
36+
cadance = 0.16 if viz else 0.01
37+
s = Stream.from_periodic(gen, cadance)
38+
km = RiverTrain(model, pass_model=True)
39+
s.map(lambda x: (x,)).connect(km) # learn takes a tuple of (x,[ y[, w]])
40+
ex = pd.DataFrame({'x': [0.5], 'y': [0.5]})
41+
ooo = s.map(lambda x: pd.DataFrame([x])).to_dataframe(example=ex)
42+
out = km.map(get_clusters)
43+
44+
# start things
45+
s.emit(gen()) # set initial model
46+
for i, (x, y) in enumerate(centres):
47+
model.centers[i]['x'] = x
48+
model.centers[i]['y'] = y
49+
50+
print("starting")
51+
s.start()
52+
53+
if viz:
54+
# plot
55+
pout = out.to_dataframe(example=ex)
56+
pl = (ooo.hvplot.scatter('x', 'y', color="blue", backlog=50) *
57+
pout.hvplot.scatter('x', 'y', color="red", backlog=3))
58+
pl.opts(xlim=(-0.2, 1.2), ylim=(-0.2, 1.2), height=600, width=600)
59+
pan = HoloViews(pl)
60+
pan.show()
61+
else:
62+
import time
63+
time.sleep(5)
64+
print(count, "events")
65+
print("Current centres", centres)
66+
print("Output centres", [list(c.values()) for c in model.centers.values()])
67+
s.stop()
68+
69+
if __name__ == "__main__":
70+
main(viz=True)

streamz/core.py

Lines changed: 0 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1902,89 +1902,6 @@ def cb(self):
19021902
yield self._emit(x, self.next_metadata)
19031903

19041904

1905-
@Stream.register_api()
1906-
class to_kafka(Stream):
1907-
""" Writes data in the stream to Kafka
1908-
1909-
This stream accepts a string or bytes object. Call ``flush`` to ensure all
1910-
messages are pushed. Responses from Kafka are pushed downstream.
1911-
1912-
Parameters
1913-
----------
1914-
topic : string
1915-
The topic which to write
1916-
producer_config : dict
1917-
Settings to set up the stream, see
1918-
https://docs.confluent.io/current/clients/confluent-kafka-python/#configuration
1919-
https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md
1920-
Examples:
1921-
bootstrap.servers: Connection string (host:port) to Kafka
1922-
1923-
Examples
1924-
--------
1925-
>>> from streamz import Stream
1926-
>>> ARGS = {'bootstrap.servers': 'localhost:9092'}
1927-
>>> source = Stream()
1928-
>>> kafka = source.map(lambda x: str(x)).to_kafka('test', ARGS)
1929-
<to_kafka>
1930-
>>> for i in range(10):
1931-
... source.emit(i)
1932-
>>> kafka.flush()
1933-
"""
1934-
def __init__(self, upstream, topic, producer_config, **kwargs):
1935-
import confluent_kafka as ck
1936-
1937-
self.topic = topic
1938-
self.producer = ck.Producer(producer_config)
1939-
1940-
kwargs["ensure_io_loop"] = True
1941-
Stream.__init__(self, upstream, **kwargs)
1942-
self.stopped = False
1943-
self.polltime = 0.2
1944-
self.loop.add_callback(self.poll)
1945-
self.futures = []
1946-
1947-
@gen.coroutine
1948-
def poll(self):
1949-
while not self.stopped:
1950-
# executes callbacks for any delivered data, in this thread
1951-
# if no messages were sent, nothing happens
1952-
self.producer.poll(0)
1953-
yield gen.sleep(self.polltime)
1954-
1955-
def update(self, x, who=None, metadata=None):
1956-
future = gen.Future()
1957-
self.futures.append(future)
1958-
1959-
@gen.coroutine
1960-
def _():
1961-
while True:
1962-
try:
1963-
# this runs asynchronously, in C-K's thread
1964-
self.producer.produce(self.topic, x, callback=self.cb)
1965-
return
1966-
except BufferError:
1967-
yield gen.sleep(self.polltime)
1968-
except Exception as e:
1969-
future.set_exception(e)
1970-
return
1971-
1972-
self.loop.add_callback(_)
1973-
return future
1974-
1975-
@gen.coroutine
1976-
def cb(self, err, msg):
1977-
future = self.futures.pop(0)
1978-
if msg is not None and msg.value() is not None:
1979-
future.set_result(None)
1980-
yield self._emit(msg.value())
1981-
else:
1982-
future.set_exception(err or msg.error())
1983-
1984-
def flush(self, timeout=-1):
1985-
self.producer.flush(timeout)
1986-
1987-
19881905
def sync(loop, func, *args, **kwargs):
19891906
"""
19901907
Run coroutine in loop running in separate thread.

streamz/river.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from . import Stream
2+
3+
4+
# TODO: most river classes support batches, e.g., learn_many, more efficiently
5+
6+
7+
class RiverTransform(Stream):
8+
"""Pass data through one or more River transforms"""
9+
10+
def __init__(self, model, **kwargs):
11+
super().__init__(**kwargs)
12+
self.model = model
13+
14+
def update(self, x, who=None, metadata=None):
15+
out = self.model.transform_one(*x)
16+
self.emit(out)
17+
18+
19+
class RiverTrain(Stream):
20+
21+
def __init__(self, model, metric=None, pass_model=False, **kwargs):
22+
"""
23+
24+
If metric and pass_model are both defaults, this is effectively
25+
a sink.
26+
27+
:param model: river model or pipeline
28+
:param metric: river metric
29+
If given, it is emitted on every sample
30+
:param pass_model: bool
31+
If True, the (updated) model if emitted for each sample
32+
"""
33+
super().__init__(**kwargs)
34+
self.model = model
35+
if pass_model and metric is not None:
36+
raise TypeError
37+
self.pass_model = pass_model
38+
self.metric = metric
39+
40+
def update(self, x, who=None, metadata=None):
41+
"""
42+
:param x: tuple
43+
(x, [y[, w]) floats for single sample. Include
44+
"""
45+
self.model.learn_one(*x)
46+
if self.metric:
47+
yp = self.model.predict_one(x[0])
48+
weights = x[2] if len(x) > 1 else 1.0
49+
self.emit(self.metric.update(x[1], yp, weights).get(), metadata=metadata)
50+
if self.pass_model:
51+
self.emit(self.model, metadata=metadata)
52+
53+
54+
class RiverPredict(Stream):
55+
56+
def __init__(self, model, **kwargs):
57+
super().__init__(**kwargs)
58+
self.model = model
59+
60+
def update(self, x, who=None, metadata=None):
61+
out = self.model.predict_one(x)
62+
self.emit(out, metadata=metadata)

0 commit comments

Comments
 (0)