
Make it possible to stream structured while using .iter #2078


Closed
dmontagu wants to merge 2 commits from dmontagu/stream-structured-with-iter

Conversation

dmontagu (Contributor):

Addresses the request here to support structured streaming while using .iter.

In particular, the analogous example to the whales streaming stuff would look like:

async with agent.iter(my_prompt) as run:
    async for node in run:
        ...
        # probably only want to do it on these nodes?
        if Agent.is_model_request_node(node):
            async with node.stream(run.ctx) as request_stream:
                async for message, last in request_stream.stream_responses():
                    maybe_whales = await run.validate_structured_output(message, allow_partial=not last)
                    if maybe_whales is not None:
                        whales = maybe_whales.output
        ...

I have some concerns about the type names etc. but I think overall this seems like a reasonable thing to add.

@@ -95,7 +95,7 @@ async def _validate_response(
     match = self._output_schema.find_named_tool(message.parts, output_tool_name)
     if match is None:
         raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-            f'Invalid response, unable to find tool: {self._output_schema.tool_names()}'
+            f'Invalid response, unable to find tool {output_tool_name!r}; expected one of {self._output_schema.tool_names()}'
dmontagu (Contributor, Author):

This error message and the one below just seemed confusing; I suspect they're hard/impossible to hit in practice, but they should still include the relevant information IMO.
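For illustration (tool names hypothetical, and assuming `tool_names()` returns a list of names), the before/after messages would render roughly as:

# before: doesn't say which tool was missing
Invalid response, unable to find tool: ['final_result', 'get_whales']
# after: names the missing tool and the expected set
Invalid response, unable to find tool 'final_result'; expected one of ['final_result', 'get_whales']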

dmontagu force-pushed the dmontagu/stream-structured-with-iter branch from 9b55b99 to 215b39f on June 25, 2025 at 21:53.
@@ -2052,6 +2052,69 @@ def usage(self) -> _usage.Usage:
"""Get usage statistics for the run so far, including token usage, model requests, and so on."""
return self._graph_run.state.usage

async def validate_structured_output(
self, response: _messages.ModelResponse, *, tool_name: str | None = None, allow_partial: bool = False
) -> FinalResult[OutputDataT] | None:
dmontagu (Contributor, Author):

My biggest concern with this implementation is that it is annotated as returning a FinalResult. That type is exactly what you want to be returning here, I think (well, it would be enough to just use Option, but that's not public, and why not include tool info?). But the name FinalResult is kind of confusing when the purpose of adding this was to support cases where it literally is not the final result, i.e. partial streaming. On the other hand, it's pretty clear what's going on imo, and do we really want to introduce another Option-like result marker type? I'm open to others' opinions.
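For concreteness, the Option-like alternative would just be a small marker mirroring FinalResult's shape; a sketch, with the name and exact fields purely illustrative:

from __future__ import annotations

from dataclasses import dataclass
from typing import Generic, TypeVar

OutputDataT = TypeVar('OutputDataT')


@dataclass
class StreamedOutput(Generic[OutputDataT]):
    """Hypothetical marker type: FinalResult's shape without the 'final' connotation."""

    output: OutputDataT
    tool_name: str | None = None
    tool_call_id: str | None = None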

Comment on lines +2065 to +2068:

+            tool_name: If provided, this should be the name of the tool that will produce the output.
+                This is only included so that you can skip the tool-lookup step if you already know which tool
+                produced the output, which may be the case if calling this method in a loop while streaming.
+                (You can get the `tool_name` from the return value of this method when it is not `None`.)
dmontagu (Contributor, Author):

Not sure if we really want to add the complexity associated with allowing the end user to specify the tool, or just always look up the tool here. I'd be okay either way; dropping this should be trivial if we want to (just delete the argument and the first if-statement referencing it below).
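For reference, the loop the docstring has in mind would look something like this (a sketch building on the example at the top; it assumes the return value exposes `tool_name`, as the docstring above says):

tool_name: str | None = None
async for message, last in request_stream.stream_responses():
    result = await run.validate_structured_output(
        message, tool_name=tool_name, allow_partial=not last
    )
    if result is not None:
        # reuse the tool name on later iterations to skip the lookup
        tool_name = result.tool_name
        whales = result.output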

if isinstance(output_schema, _output.ToolOutputSchema):
    if tool_name is None:
        # infer the tool name from the response parts
        for part, _ in output_schema.find_tool(response.parts):
Review comment (Contributor):

This returns the same tuple[call, output_tool] we get from find_named_tool below, so if we're in this branch, we can skip that extra lookup.
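Something like this sketch, i.e. (assuming `find_tool` yields the same `(call, output_tool)` pairs that `find_named_tool` returns):

if isinstance(output_schema, _output.ToolOutputSchema):
    match = None
    if tool_name is None:
        # take the first output tool call directly; no second lookup needed
        for match in output_schema.find_tool(response.parts):
            break
    else:
        match = output_schema.find_named_tool(response.parts, tool_name)
    # either way, `match` is now the same (call, output_tool) pair or None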


Docs Preview

commit: 3ebc6a2
Preview URL: https://c77af922-pydantic-ai-previews.pydantic.workers.dev

@@ -2052,6 +2052,69 @@ def usage(self) -> _usage.Usage:
"""Get usage statistics for the run so far, including token usage, model requests, and so on."""
return self._graph_run.state.usage

async def validate_structured_output(
Review comment (Contributor):

This is basically a copy-paste of the other method in _agent_graph; I suggest moving it to a shared util.
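Something like this signature, say (the location and exact parameters are assumptions; the body would be the lookup-and-validate logic from the hunks above):

# hypothetical shared helper, e.g. in _agent_graph.py or _output.py
async def validate_structured_output(
    response: _messages.ModelResponse,
    output_schema: _output.OutputSchema[OutputDataT],
    *,
    tool_name: str | None = None,
    allow_partial: bool = False,
) -> FinalResult[OutputDataT] | None:
    ...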

DouweM (Contributor) commented on Jul 25, 2025:

AgentStream already has a stream_output() method that works this way as of #2134; the current whales example below uses it:

"""Information about whales — an example of streamed structured response validation.

This script streams structured responses from GPT-4 about whales, validates the data
and displays it as a dynamic table using Rich as the data is received.

Run with:

    uv run -m pydantic_ai_examples.stream_whales
"""

from typing import Annotated

import logfire
from pydantic import Field
from rich.console import Console
from rich.live import Live
from rich.table import Table
from typing_extensions import NotRequired, TypedDict

from pydantic_ai import Agent
from pydantic_ai.output import NativeOutput

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')
logfire.instrument_pydantic_ai()


class Whale(TypedDict):
    name: str
    length: Annotated[float, Field(description='Average length of an adult whale in meters.')]
    weight: NotRequired[
        Annotated[
            float,
            Field(description='Average weight of an adult whale in kilograms.', ge=50),
        ]
    ]
    ocean: NotRequired[str]
    description: NotRequired[Annotated[str, Field(description='Short Description')]]


agent = Agent('openai:gpt-4o', output_type=NativeOutput(list[Whale]))


async def main():
    console = Console()
    with Live('\n' * 36, console=console) as live:
        console.print('Requesting data...', style='cyan')
        async with agent.iter('Generate me details of 5 species of Whale.') as run:
            async for node in run:
                if Agent.is_model_request_node(node):
                    async with node.stream(run.ctx) as request_stream:
                        console.print('Response:', style='green')

                        async for whales in request_stream.stream_output(debounce_by=0.01):
                            table = Table(
                                title='Species of Whale',
                                caption='Streaming Structured responses from GPT-4',
                                width=120,
                            )
                            table.add_column('ID', justify='right')
                            table.add_column('Name')
                            table.add_column('Avg. Length (m)', justify='right')
                            table.add_column('Avg. Weight (kg)', justify='right')
                            table.add_column('Ocean')
                            table.add_column('Description', justify='right')

                            for wid, whale in enumerate(whales, start=1):
                                table.add_row(
                                    str(wid),
                                    whale['name'],
                                    f'{whale["length"]:0.0f}',
                                    f'{w:0.0f}' if (w := whale.get('weight')) else '…',
                                    whale.get('ocean') or '…',
                                    whale.get('description') or '…',
                                )
                            live.update(table)


if __name__ == '__main__':
    import asyncio

    asyncio.run(main())

DouweM closed this on Jul 25, 2025.