diff --git a/dspy/streaming/streamify.py b/dspy/streaming/streamify.py
index d8ec56eb2f..e3058a026a 100644
--- a/dspy/streaming/streamify.py
+++ b/dspy/streaming/streamify.py
@@ -166,52 +166,68 @@ async def use_streaming():
     if not any(isinstance(c, StatusStreamingCallback) for c in callbacks):
         callbacks.append(status_streaming_callback)
 
-    async def generator(args, kwargs, stream: MemoryObjectSendStream):
-        with settings.context(send_stream=stream, callbacks=callbacks, stream_listeners=stream_listeners):
-            prediction = await program(*args, **kwargs)
+    async def generator(args, kwargs, stream: MemoryObjectSendStream, parent_overrides):
+        from dspy.dsp.utils.settings import thread_local_overrides
 
-        await stream.send(prediction)
+        original_overrides = thread_local_overrides.get()
+        token = thread_local_overrides.set({**original_overrides, **parent_overrides})
+        try:
+            with settings.context(send_stream=stream, callbacks=callbacks, stream_listeners=stream_listeners):
+                prediction = await program(*args, **kwargs)
+        finally:
+            thread_local_overrides.reset(token)
 
-    async def async_streamer(*args, **kwargs):
-        send_stream, receive_stream = create_memory_object_stream(16)
-        async with create_task_group() as tg, send_stream, receive_stream:
-            tg.start_soon(generator, args, kwargs, send_stream)
+        await stream.send(prediction)
 
-            async for value in receive_stream:
-                if isinstance(value, ModelResponseStream):
-                    if len(predict_id_to_listener) == 0:
-                        # No listeners are configured, yield the chunk directly for backwards compatibility.
+    def async_streamer(*args, **kwargs):
+        from dspy.dsp.utils.settings import thread_local_overrides
+
+        # capture the parent overrides to pass to the generator
+        # this is to keep the contextvars even after the context block exits
+        parent_overrides = thread_local_overrides.get().copy()
+
+        async def _async_streamer_impl():
+            send_stream, receive_stream = create_memory_object_stream(16)
+            async with create_task_group() as tg, send_stream, receive_stream:
+                tg.start_soon(generator, args, kwargs, send_stream, parent_overrides)
+
+                async for value in receive_stream:
+                    if isinstance(value, ModelResponseStream):
+                        if len(predict_id_to_listener) == 0:
+                            # No listeners are configured, yield the chunk directly for backwards compatibility.
+                            yield value
+                        else:
+                            # We are receiving a chunk from the LM's response stream, delegate it to the listeners to
+                            # determine if we should yield a value to the user.
+                            for listener in predict_id_to_listener[value.predict_id]:
+                                # In some special cases such as Citation API, it is possible that multiple listeners
+                                # return values at the same time due to the chunk buffer of the listener.
+                                if output := listener.receive(value):
+                                    yield output
+                    elif isinstance(value, StatusMessage):
                         yield value
+                    elif isinstance(value, Prediction):
+                        # Flush remaining buffered tokens before yielding the Prediction instance
+                        for listener in stream_listeners:
+                            if final_chunk := listener.finalize():
+                                yield final_chunk
+
+                        if include_final_prediction_in_output_stream:
+                            yield value
+                        elif (
+                            len(stream_listeners) == 0
+                            or any(listener.cache_hit for listener in stream_listeners)
+                            or not any(listener.stream_start for listener in stream_listeners)
+                        ):
+                            yield value
+                        return
                     else:
-                        # We are receiving a chunk from the LM's response stream, delegate it to the listeners to
-                        # determine if we should yield a value to the user.
-                        for listener in predict_id_to_listener[value.predict_id]:
-                            # In some special cases such as Citation API, it is possible that multiple listeners
-                            # return values at the same time due to the chunk buffer of the listener.
-                            if output := listener.receive(value):
-                                yield output
-                elif isinstance(value, StatusMessage):
-                    yield value
-                elif isinstance(value, Prediction):
-                    # Flush remaining buffered tokens before yielding the Prediction instance
-                    for listener in stream_listeners:
-                        if final_chunk := listener.finalize():
-                            yield final_chunk
-
-                    if include_final_prediction_in_output_stream:
+                        # This wildcard case allows for customized streaming behavior.
+                        # It is useful when a users have a custom LM which returns stream chunks in a custom format.
+                        # We let those chunks pass through to the user to handle them as needed.
                         yield value
-                    elif (
-                        len(stream_listeners) == 0
-                        or any(listener.cache_hit for listener in stream_listeners)
-                        or not any(listener.stream_start for listener in stream_listeners)
-                    ):
-                        yield value
-                    return
-                else:
-                    # This wildcard case allows for customized streaming behavior.
-                    # It is useful when a users have a custom LM which returns stream chunks in a custom format.
-                    # We let those chunks pass through to the user to handle them as needed.
-                    yield value
+
+        return _async_streamer_impl()
 
     if async_streaming:
         return async_streamer
diff --git a/tests/streaming/test_streaming.py b/tests/streaming/test_streaming.py
index 89a84250ff..dbdccb88ed 100644
--- a/tests/streaming/test_streaming.py
+++ b/tests/streaming/test_streaming.py
@@ -10,9 +10,15 @@
 
 import dspy
 from dspy.adapters.types import Type
+from dspy.dsp.utils.settings import thread_local_overrides
 from dspy.experimental import Citations, Document
 from dspy.streaming import StatusMessage, StatusMessageProvider, streaming_response
 
+try:
+    from exceptiongroup import BaseExceptionGroup
+except ImportError:
+    BaseExceptionGroup = BaseException
+
 
 @pytest.mark.anyio
 async def test_streamify_yields_expected_response_chunks(litellm_test_server):
@@ -1169,3 +1175,40 @@ def test_stream_listener_could_form_end_identifier_xml_adapter():
     # Should return False for text that cannot form the pattern
     assert listener._could_form_end_identifier("hello world", "XMLAdapter") is False
     assert listener._could_form_end_identifier("some text", "XMLAdapter") is False
+
+
+@pytest.mark.asyncio
+async def test_streamify_context_propagation():
+    """
+    Test that dspy.context() properly propagates LM settings to streamify
+    even when the context exits before async iteration begins.
+    """
+    predict = dspy.Predict("question->answer")
+    test_lm = dspy.LM("openai/gpt-4o-mini", cache=False)
+
+    lm_was_set = []
+
+    async def mock_stream(**kwargs):
+        current_lm = thread_local_overrides.get().get("lm") or dspy.settings.lm
+        lm_was_set.append(current_lm is not None and current_lm == test_lm)
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[["))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]\n\n"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="test"))])
+
+    stream_predict = dspy.streamify(
+        predict,
+        stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")],
+    )
+
+    with mock.patch("litellm.acompletion", side_effect=mock_stream):
+        with dspy.context(lm=test_lm):
+            output_stream = stream_predict(question="test question")
+
+        async for _ in output_stream:
+            break
+
+    assert len(lm_was_set) > 0
+    assert lm_was_set[0]