Skip to content

Commit c58795f

Browse files
committed
fix: apply code review comments
1 parent 3e9ccbd commit c58795f

File tree

6 files changed: +49 additions, -23 deletions

pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,7 @@ async def wrapped_run_workflow(
115115
infer_name: bool = True,
116116
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
117117
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
118+
response_prefix: str | None = None,
118119
**_deprecated_kwargs: Never,
119120
) -> AgentRunResult[Any]:
120121
with self._dbos_overrides():
@@ -131,6 +132,7 @@ async def wrapped_run_workflow(
131132
infer_name=infer_name,
132133
toolsets=toolsets,
133134
event_stream_handler=event_stream_handler,
135+
response_prefix=response_prefix,
134136
**_deprecated_kwargs,
135137
)
136138

@@ -152,6 +154,7 @@ def wrapped_run_sync_workflow(
152154
infer_name: bool = True,
153155
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
154156
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
157+
response_prefix: str | None = None,
155158
**_deprecated_kwargs: Never,
156159
) -> AgentRunResult[Any]:
157160
with self._dbos_overrides():
@@ -168,6 +171,7 @@ def wrapped_run_sync_workflow(
168171
infer_name=infer_name,
169172
toolsets=toolsets,
170173
event_stream_handler=event_stream_handler,
174+
response_prefix=response_prefix,
171175
**_deprecated_kwargs,
172176
)
173177

@@ -240,6 +244,7 @@ async def run(
240244
infer_name: bool = True,
241245
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
242246
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
247+
response_prefix: str | None = None,
243248
) -> AgentRunResult[OutputDataT]: ...
244249

245250
@overload
@@ -258,6 +263,7 @@ async def run(
258263
infer_name: bool = True,
259264
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
260265
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
266+
response_prefix: str | None = None,
261267
) -> AgentRunResult[RunOutputDataT]: ...
262268

263269
async def run(
@@ -275,6 +281,7 @@ async def run(
275281
infer_name: bool = True,
276282
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
277283
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
284+
response_prefix: str | None = None,
278285
**_deprecated_kwargs: Never,
279286
) -> AgentRunResult[Any]:
280287
"""Run the agent with a user prompt in async mode.
@@ -308,6 +315,7 @@ async def main():
308315
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
309316
toolsets: Optional additional toolsets for this run.
310317
event_stream_handler: Optional event stream handler to use for this run.
318+
response_prefix: Optional response prefix to use for this run.
311319
312320
Returns:
313321
The result of the run.
@@ -325,6 +333,7 @@ async def main():
325333
infer_name=infer_name,
326334
toolsets=toolsets,
327335
event_stream_handler=event_stream_handler,
336+
response_prefix=response_prefix,
328337
**_deprecated_kwargs,
329338
)
330339

@@ -344,6 +353,7 @@ def run_sync(
344353
infer_name: bool = True,
345354
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
346355
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
356+
response_prefix: str | None = None,
347357
) -> AgentRunResult[OutputDataT]: ...
348358

349359
@overload
@@ -362,6 +372,7 @@ def run_sync(
362372
infer_name: bool = True,
363373
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
364374
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
375+
response_prefix: str | None = None,
365376
) -> AgentRunResult[RunOutputDataT]: ...
366377

367378
def run_sync(
@@ -379,6 +390,7 @@ def run_sync(
379390
infer_name: bool = True,
380391
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
381392
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
393+
response_prefix: str | None = None,
382394
**_deprecated_kwargs: Never,
383395
) -> AgentRunResult[Any]:
384396
"""Synchronously run the agent with a user prompt.
@@ -411,6 +423,7 @@ def run_sync(
411423
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
412424
toolsets: Optional additional toolsets for this run.
413425
event_stream_handler: Optional event stream handler to use for this run.
426+
response_prefix: Optional response prefix to use for this run.
414427
415428
Returns:
416429
The result of the run.
@@ -428,6 +441,7 @@ def run_sync(
428441
infer_name=infer_name,
429442
toolsets=toolsets,
430443
event_stream_handler=event_stream_handler,
444+
response_prefix=response_prefix,
431445
**_deprecated_kwargs,
432446
)
433447

@@ -447,6 +461,7 @@ def run_stream(
447461
infer_name: bool = True,
448462
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
449463
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
464+
response_prefix: str | None = None,
450465
) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, OutputDataT]]: ...
451466

452467
@overload
@@ -465,6 +480,7 @@ def run_stream(
465480
infer_name: bool = True,
466481
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
467482
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
483+
response_prefix: str | None = None,
468484
) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
469485

470486
@asynccontextmanager
@@ -483,6 +499,7 @@ async def run_stream(
483499
infer_name: bool = True,
484500
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
485501
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
502+
response_prefix: str | None = None,
486503
**_deprecated_kwargs: Never,
487504
) -> AsyncIterator[StreamedRunResult[AgentDepsT, Any]]:
488505
"""Run the agent with a user prompt in async mode, returning a streamed response.
@@ -513,6 +530,7 @@ async def main():
513530
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
514531
toolsets: Optional additional toolsets for this run.
515532
event_stream_handler: Optional event stream handler to use for this run. It will receive all the events up until the final result is found, which you can then read or stream from inside the context manager.
533+
response_prefix: Optional response prefix to use for this run.
516534
517535
Returns:
518536
The result of the run.
@@ -537,6 +555,7 @@ async def main():
537555
infer_name=infer_name,
538556
toolsets=toolsets,
539557
event_stream_handler=event_stream_handler,
558+
response_prefix=response_prefix,
540559
**_deprecated_kwargs,
541560
) as result:
542561
yield result
@@ -556,6 +575,7 @@ def iter(
556575
usage: _usage.RunUsage | None = None,
557576
infer_name: bool = True,
558577
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
578+
response_prefix: str | None = None,
559579
**_deprecated_kwargs: Never,
560580
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ...
561581

@@ -574,6 +594,7 @@ def iter(
574594
usage: _usage.RunUsage | None = None,
575595
infer_name: bool = True,
576596
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
597+
response_prefix: str | None = None,
577598
**_deprecated_kwargs: Never,
578599
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...
579600

@@ -592,6 +613,7 @@ async def iter(
592613
usage: _usage.RunUsage | None = None,
593614
infer_name: bool = True,
594615
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
616+
response_prefix: str | None = None,
595617
**_deprecated_kwargs: Never,
596618
) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
597619
"""A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.
@@ -627,6 +649,7 @@ async def main():
627649
system_prompts=(),
628650
system_prompt_functions=[],
629651
system_prompt_dynamic_functions={},
652+
response_prefix=None,
630653
),
631654
ModelRequestNode(
632655
request=ModelRequest(
@@ -637,6 +660,7 @@ async def main():
637660
)
638661
]
639662
)
663+
response_prefix=None,
640664
),
641665
CallToolsNode(
642666
model_response=ModelResponse(
@@ -666,6 +690,7 @@ async def main():
666690
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
667691
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
668692
toolsets: Optional additional toolsets for this run.
693+
response_prefix: Optional response prefix to use for this run.
669694
670695
Returns:
671696
The result of the run.
@@ -688,6 +713,7 @@ async def main():
688713
usage=usage,
689714
infer_name=infer_name,
690715
toolsets=toolsets,
716+
response_prefix=response_prefix,
691717
**_deprecated_kwargs,
692718
) as run:
693719
yield run

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -626,7 +626,12 @@ class AnthropicStreamedResponse(StreamedResponse):
626626

627627
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
628628
current_block: BetaContentBlock | None = None
629-
first_text_delta = True
629+
630+
# Handle response prefix by emitting it as the first text event
631+
if response_prefix := self.model_request_parameters.response_prefix:
632+
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=response_prefix)
633+
if maybe_event is not None:
634+
yield maybe_event
630635

631636
async for event in self._response:
632637
if isinstance(event, BetaRawMessageStartEvent):
@@ -670,12 +675,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
670675

671676
elif isinstance(event, BetaRawContentBlockDeltaEvent):
672677
if isinstance(event.delta, BetaTextDelta):
673-
content = event.delta.text
674-
# Prepend response prefix to the first text delta if provided
675-
if first_text_delta and self.model_request_parameters.response_prefix:
676-
content = self.model_request_parameters.response_prefix + content
677-
first_text_delta = False
678-
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=event.index, content=content)
678+
maybe_event = self._parts_manager.handle_text_delta(
679+
vendor_part_id=event.index, content=event.delta.text
680+
)
679681
if maybe_event is not None: # pragma: no branch
680682
yield maybe_event
681683
elif isinstance(event.delta, BetaThinkingDelta):

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -525,13 +525,9 @@ def _process_response(
525525
]
526526

527527
if choice.message.content is not None:
528-
content = choice.message.content
529-
# Prepend response prefix if provided
530-
if model_request_parameters.response_prefix:
531-
content = model_request_parameters.response_prefix + content
532528
items.extend(
533529
(replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part)
534-
for part in split_content_into_text_and_thinking(content, self.profile.thinking_tags)
530+
for part in split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags)
535531
)
536532
if choice.message.tool_calls is not None:
537533
for c in choice.message.tool_calls:
@@ -1243,7 +1239,12 @@ class OpenAIStreamedResponse(StreamedResponse):
12431239
_provider_name: str
12441240

12451241
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
1246-
first_text_delta = True
1242+
# Handle response prefix by emitting it as the first text event
1243+
if response_prefix := self.model_request_parameters.response_prefix:
1244+
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=response_prefix)
1245+
if maybe_event is not None:
1246+
yield maybe_event
1247+
12471248
async for chunk in self._response:
12481249
self._usage += _map_usage(chunk)
12491250

@@ -1266,10 +1267,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
12661267
# Handle the text part of the response
12671268
content = choice.delta.content
12681269
if content is not None:
1269-
# Prepend response prefix to the first text delta if provided
1270-
if first_text_delta and self.model_request_parameters.response_prefix:
1271-
content = self.model_request_parameters.response_prefix + content
1272-
first_text_delta = False
12731270
maybe_event = self._parts_manager.handle_text_delta(
12741271
vendor_part_id='content',
12751272
content=content,

pydantic_ai_slim/pydantic_ai/profiles/deepseek.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,7 @@
55

66
def deepseek_model_profile(model_name: str) -> ModelProfile | None:
77
"""Get the model profile for a DeepSeek model."""
8-
return ModelProfile(ignore_streamed_leading_whitespace='r1' in model_name)
8+
return ModelProfile(
9+
ignore_streamed_leading_whitespace='r1' in model_name,
10+
supports_response_prefix=True,
11+
)

pydantic_ai_slim/pydantic_ai/profiles/openai.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -77,16 +77,12 @@ def openai_model_profile(model_name: str) -> ModelProfile:
7777
# See https://github.com/pydantic/pydantic-ai/issues/974 for more details.
7878
openai_system_prompt_role = 'user' if model_name.startswith('o1-mini') else None
7979

80-
# Enable response prefix for DeepSeek and OpenRouter models
81-
supports_response_prefix = 'deepseek' in model_name.lower() or 'openrouter' in model_name.lower()
82-
8380
return OpenAIModelProfile(
8481
json_schema_transformer=OpenAIJsonSchemaTransformer,
8582
supports_json_schema_output=True,
8683
supports_json_object_output=True,
8784
openai_unsupported_model_settings=openai_unsupported_model_settings,
8885
openai_system_prompt_role=openai_system_prompt_role,
89-
supports_response_prefix=supports_response_prefix,
9086
openai_chat_supports_web_search=supports_web_search,
9187
)
9288

pydantic_ai_slim/pydantic_ai/providers/openrouter.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,9 @@ def model_profile(self, model_name: str) -> ModelProfile | None:
7070

7171
# As OpenRouterProvider is always used with OpenAIChatModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
7272
# we need to maintain that behavior unless json_schema_transformer is set explicitly
73-
return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
73+
return OpenAIModelProfile(
74+
json_schema_transformer=OpenAIJsonSchemaTransformer, supports_response_prefix=True
75+
).update(profile)
7476

7577
@overload
7678
def __init__(self) -> None: ...

0 commit comments

Comments (0)