dspy/streaming/streamify.py (56 additions, 40 deletions)
@@ -166,52 +166,68 @@ async def use_streaming():
     if not any(isinstance(c, StatusStreamingCallback) for c in callbacks):
         callbacks.append(status_streaming_callback)

-    async def generator(args, kwargs, stream: MemoryObjectSendStream):
-        with settings.context(send_stream=stream, callbacks=callbacks, stream_listeners=stream_listeners):
-            prediction = await program(*args, **kwargs)
+    async def generator(args, kwargs, stream: MemoryObjectSendStream, parent_overrides):
+        from dspy.dsp.utils.settings import thread_local_overrides
+
+        original_overrides = thread_local_overrides.get()
+        token = thread_local_overrides.set({**original_overrides, **parent_overrides})
+        try:
+            with settings.context(send_stream=stream, callbacks=callbacks, stream_listeners=stream_listeners):
+                prediction = await program(*args, **kwargs)
+        finally:
+            thread_local_overrides.reset(token)

         await stream.send(prediction)

-    async def async_streamer(*args, **kwargs):
-        send_stream, receive_stream = create_memory_object_stream(16)
-        async with create_task_group() as tg, send_stream, receive_stream:
-            tg.start_soon(generator, args, kwargs, send_stream)
-
-            async for value in receive_stream:
-                if isinstance(value, ModelResponseStream):
-                    if len(predict_id_to_listener) == 0:
-                        # No listeners are configured, yield the chunk directly for backwards compatibility.
-                        yield value
-                    else:
-                        # We are receiving a chunk from the LM's response stream, delegate it to the listeners to
-                        # determine if we should yield a value to the user.
-                        for listener in predict_id_to_listener[value.predict_id]:
-                            # In some special cases such as Citation API, it is possible that multiple listeners
-                            # return values at the same time due to the chunk buffer of the listener.
-                            if output := listener.receive(value):
-                                yield output
-                elif isinstance(value, StatusMessage):
-                    yield value
-                elif isinstance(value, Prediction):
-                    # Flush remaining buffered tokens before yielding the Prediction instance
-                    for listener in stream_listeners:
-                        if final_chunk := listener.finalize():
-                            yield final_chunk
-
-                    if include_final_prediction_in_output_stream:
-                        yield value
-                    elif (
-                        len(stream_listeners) == 0
-                        or any(listener.cache_hit for listener in stream_listeners)
-                        or not any(listener.stream_start for listener in stream_listeners)
-                    ):
-                        yield value
-                    return
-                else:
-                    # This wildcard case allows for customized streaming behavior.
-                    # It is useful when a users have a custom LM which returns stream chunks in a custom format.
-                    # We let those chunks pass through to the user to handle them as needed.
-                    yield value
+    def async_streamer(*args, **kwargs):
+        from dspy.dsp.utils.settings import thread_local_overrides
+
+        # capture the parent overrides to pass to the generator
+        # this is to keep the contextvars even after the context block exits
+        parent_overrides = thread_local_overrides.get().copy()
+
+        async def _async_streamer_impl():
+            send_stream, receive_stream = create_memory_object_stream(16)
+            async with create_task_group() as tg, send_stream, receive_stream:
+                tg.start_soon(generator, args, kwargs, send_stream, parent_overrides)
+
+                async for value in receive_stream:
+                    if isinstance(value, ModelResponseStream):
+                        if len(predict_id_to_listener) == 0:
+                            # No listeners are configured, yield the chunk directly for backwards compatibility.
+                            yield value
+                        else:
+                            # We are receiving a chunk from the LM's response stream, delegate it to the listeners to
+                            # determine if we should yield a value to the user.
+                            for listener in predict_id_to_listener[value.predict_id]:
+                                # In some special cases such as Citation API, it is possible that multiple listeners
+                                # return values at the same time due to the chunk buffer of the listener.
+                                if output := listener.receive(value):
+                                    yield output
+                    elif isinstance(value, StatusMessage):
+                        yield value
+                    elif isinstance(value, Prediction):
+                        # Flush remaining buffered tokens before yielding the Prediction instance
+                        for listener in stream_listeners:
+                            if final_chunk := listener.finalize():
+                                yield final_chunk
+
+                        if include_final_prediction_in_output_stream:
+                            yield value
+                        elif (
+                            len(stream_listeners) == 0
+                            or any(listener.cache_hit for listener in stream_listeners)
+                            or not any(listener.stream_start for listener in stream_listeners)
+                        ):
+                            yield value
+                        return
+                    else:
+                        # This wildcard case allows for customized streaming behavior.
+                        # It is useful when a users have a custom LM which returns stream chunks in a custom format.
+                        # We let those chunks pass through to the user to handle them as needed.
+                        yield value
+
+        return _async_streamer_impl()

     if async_streaming:
         return async_streamer
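For reference, a minimal usage sketch of the pattern this change supports (the model name, signature, and question are illustrative, and an OpenAI key is assumed to be configured): the `dspy.context()` block exits before the stream is iterated, so `async_streamer` snapshots the parent overrides at call time rather than reading them during iteration.

```python
import asyncio

import dspy


async def main():
    predict = dspy.Predict("question->answer")
    stream_predict = dspy.streamify(
        predict,
        stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")],
    )

    # The context block exits before iteration begins; the captured
    # parent_overrides keep the configured LM visible inside `generator`.
    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini")):
        output_stream = stream_predict(question="What is DSPy?")

    async for chunk in output_stream:
        print(chunk)


asyncio.run(main())
```

Without the snapshot, `generator` would only see whatever overrides were active once the task group started iterating, which is the situation the new test below exercises.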
tests/streaming/test_streaming.py (43 additions, 0 deletions)
@@ -10,9 +10,15 @@

 import dspy
 from dspy.adapters.types import Type
+from dspy.dsp.utils.settings import thread_local_overrides
 from dspy.experimental import Citations, Document
 from dspy.streaming import StatusMessage, StatusMessageProvider, streaming_response

+try:
+    from exceptiongroup import BaseExceptionGroup
+except ImportError:
+    BaseExceptionGroup = BaseException
+

 @pytest.mark.anyio
 async def test_streamify_yields_expected_response_chunks(litellm_test_server):
@@ -1169,3 +1175,40 @@ def test_stream_listener_could_form_end_identifier_xml_adapter():
     # Should return False for text that cannot form the pattern
     assert listener._could_form_end_identifier("hello world", "XMLAdapter") is False
     assert listener._could_form_end_identifier("some text", "XMLAdapter") is False
+
+
+@pytest.mark.asyncio
+async def test_streamify_context_propagation():
+    """
+    Test that dspy.context() properly propagates LM settings to streamify
+    even when the context exits before async iteration begins.
+    """
+    predict = dspy.Predict("question->answer")
+    test_lm = dspy.LM("openai/gpt-4o-mini", cache=False)
+
+    lm_was_set = []
+
+    async def mock_stream(**kwargs):
+        current_lm = thread_local_overrides.get().get("lm") or dspy.settings.lm
+        lm_was_set.append(current_lm is not None and current_lm == test_lm)
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[["))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]\n\n"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="test"))])
+
+    stream_predict = dspy.streamify(
+        predict,
+        stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")],
+    )
+
+    with mock.patch("litellm.acompletion", side_effect=mock_stream):
+        with dspy.context(lm=test_lm):
+            output_stream = stream_predict(question="test question")
+
+        async for _ in output_stream:
+            break
+
+    assert len(lm_was_set) > 0
+    assert lm_was_set[0]
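The `set`/`reset` pair wrapped in `try`/`finally` inside `generator` is the standard `contextvars` token pattern from the Python standard library. A self-contained sketch of that pattern (the `ContextVar` name and values here are illustrative, not DSPy's):

```python
from contextvars import ContextVar

# Illustrative stand-in for dspy's thread_local_overrides.
overrides: ContextVar[dict] = ContextVar("overrides", default={})


def run_with_parent_overrides(parent_overrides: dict) -> dict:
    # Merge the caller's snapshot into the current context, then always
    # restore the previous value, mirroring the new `generator` body.
    original = overrides.get()
    token = overrides.set({**original, **parent_overrides})
    try:
        return overrides.get()
    finally:
        overrides.reset(token)


print(run_with_parent_overrides({"lm": "my-lm"}))  # {'lm': 'my-lm'}
print(overrides.get())                             # {} (back to the default)
```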