|
22 | 22 | ImageUrl,
|
23 | 23 | ModelRequest,
|
24 | 24 | ModelResponse,
|
| 25 | + PartStartEvent, |
25 | 26 | RetryPromptPart,
|
26 | 27 | SystemPromptPart,
|
27 | 28 | TextPart,
|
@@ -2929,3 +2930,50 @@ def test_deprecated_openai_model(openai_api_key: str):
|
2929 | 2930 |
|
2930 | 2931 | provider = OpenAIProvider(api_key=openai_api_key)
|
2931 | 2932 | OpenAIModel('gpt-4o', provider=provider) # type: ignore[reportDeprecated]
|
| 2933 | + |
| 2934 | + |
| 2935 | +async def test_openai_response_prefix(allow_model_requests: None): |
| 2936 | + """Test that OpenAI models correctly handle response prefix.""" |
| 2937 | + c = completion_message( |
| 2938 | + ChatCompletionMessage(content='Red', role='assistant'), |
| 2939 | + ) |
| 2940 | + mock_client = MockOpenAI.create_mock(c) |
| 2941 | + # Use a model name that supports response prefix (DeepSeek models do) |
| 2942 | + m = OpenAIChatModel('deepseek-chat', provider=OpenAIProvider(openai_client=mock_client)) |
| 2943 | + agent = Agent(m) |
| 2944 | + |
| 2945 | + # Test non-streaming response |
| 2946 | + result = await agent.run('What is the name of color #FF0000', response_prefix="It's name is ") |
| 2947 | + assert result.output == "It's name is Red" |
| 2948 | + |
| 2949 | + # Verify that the response prefix was added to the request |
| 2950 | + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] |
| 2951 | + assert 'messages' in kwargs |
| 2952 | + messages = kwargs['messages'] |
| 2953 | + # Should have user message and assistant message with prefix |
| 2954 | + assert len(messages) == 2 |
| 2955 | + assert messages[0]['role'] == 'user' |
| 2956 | + assert messages[1]['role'] == 'assistant' |
| 2957 | + assert messages[1]['content'] == "It's name is " |
| 2958 | + |
| 2959 | + |
| 2960 | +async def test_openai_response_prefix_stream(allow_model_requests: None): |
| 2961 | + """Test that OpenAI models correctly handle response prefix in streaming.""" |
| 2962 | + stream = [text_chunk('Red'), chunk([])] |
| 2963 | + mock_client = MockOpenAI.create_mock_stream(stream) |
| 2964 | + # Use a model name that supports response prefix (DeepSeek models do) |
| 2965 | + m = OpenAIChatModel('deepseek-chat', provider=OpenAIProvider(openai_client=mock_client)) |
| 2966 | + agent = Agent(m) |
| 2967 | + |
| 2968 | + event_parts: list[Any] = [] |
| 2969 | + async with agent.iter(user_prompt='What is the name of color #FF0000', response_prefix="It's name is ") as agent_run: |
| 2970 | + async for node in agent_run: |
| 2971 | + if Agent.is_model_request_node(node): |
| 2972 | + async with node.stream(agent_run.ctx) as request_stream: |
| 2973 | + async for event in request_stream: |
| 2974 | + event_parts.append(event) |
| 2975 | + |
| 2976 | + # Check that the first text part starts with the prefix |
| 2977 | + text_parts = [p for p in event_parts if isinstance(p, PartStartEvent) and isinstance(p.part, TextPart)] |
| 2978 | + assert len(text_parts) > 0 |
| 2979 | + assert cast(TextPart, text_parts[0].part).content == "It's name is Red" |
0 commit comments