27 changes: 25 additions & 2 deletions src/google/adk/models/lite_llm.py
@@ -64,6 +64,19 @@
_NEW_LINE = "\n"
_EXCLUDED_PART_FIELD = {"inline_data": {"data"}}

# Mapping of LiteLLM finish_reason strings to FinishReason enum values
# Note: tool_calls/function_call map to STOP because:
# 1. FinishReason.TOOL_CALL enum does not exist (as of google-genai 0.8.0)
# 2. Tool calls represent normal completion (model stopped to invoke tools)
# 3. Gemini native responses use STOP for tool calls (see lite_llm.py:910)
_FINISH_REASON_MAPPING = {
"length": types.FinishReason.MAX_TOKENS,
"stop": types.FinishReason.STOP,
"tool_calls": types.FinishReason.STOP, # Normal completion with tool invocation
"function_call": types.FinishReason.STOP, # Legacy function call variant
"content_filter": types.FinishReason.SAFETY,
}


class ChatCompletionFileUrlObject(TypedDict, total=False):
file_data: str
@@ -494,13 +507,23 @@ def _model_response_to_generate_content_response(
"""

message = None
if response.get("choices", None):
message = response["choices"][0].get("message", None)
finish_reason = None
if choices := response.get("choices"):
first_choice = choices[0]
message = first_choice.get("message", None)
finish_reason = first_choice.get("finish_reason", None)

if not message:
raise ValueError("No message in response")

llm_response = _message_to_generate_content_response(message)
if finish_reason:
# Map LiteLLM finish_reason strings to FinishReason enum
# This provides type consistency with Gemini native responses and avoids warnings
finish_reason_str = str(finish_reason).lower()
llm_response.finish_reason = _FINISH_REASON_MAPPING.get(
finish_reason_str, types.FinishReason.OTHER
)
if response.get("usage", None):
llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
prompt_token_count=response["usage"].get("prompt_tokens", 0),
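As a sanity check on the new fallback, the mapping can be exercised outside the ADK. A minimal sketch, not part of this PR: the table is re-declared locally (since _FINISH_REASON_MAPPING is a module-private name) and only google-genai is assumed to be installed.

# Illustrative sketch only: how unknown finish_reason strings fall back to
# FinishReason.OTHER. The mapping is repeated here so the snippet runs standalone.
from google.genai import types

finish_reason_mapping = {
    "length": types.FinishReason.MAX_TOKENS,
    "stop": types.FinishReason.STOP,
    "tool_calls": types.FinishReason.STOP,
    "function_call": types.FinishReason.STOP,
    "content_filter": types.FinishReason.SAFETY,
}

for raw in ("stop", "length", "tool_calls", "brand_new_reason"):
    mapped = finish_reason_mapping.get(str(raw).lower(), types.FinishReason.OTHER)
    print(f"{raw} -> {mapped.name}")
# stop -> STOP
# length -> MAX_TOKENS
# tool_calls -> STOP
# brand_new_reason -> OTHER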
7 changes: 6 additions & 1 deletion src/google/adk/models/llm_response.py
@@ -16,6 +16,7 @@

from typing import Any
from typing import Optional
from typing import Union

from google.genai import types
from pydantic import alias_generators
@@ -78,7 +79,11 @@ class LlmResponse(BaseModel):
"""

finish_reason: Optional[types.FinishReason] = None
"""The finish reason of the response."""
"""The finish reason of the response.

Always a types.FinishReason enum. String values from underlying model providers
are mapped to corresponding enum values (with fallback to OTHER for unknown values).
"""

error_code: Optional[str] = None
"""Error code if the response is an error. Code varies by model."""
4 changes: 3 additions & 1 deletion src/google/adk/telemetry/tracing.py
@@ -281,9 +281,11 @@ def trace_call_llm(
llm_response.usage_metadata.candidates_token_count,
)
if llm_response.finish_reason:
# finish_reason is always a FinishReason enum
finish_reason_str = llm_response.finish_reason.name.lower()
span.set_attribute(
'gen_ai.response.finish_reasons',
[llm_response.finish_reason.value.lower()],
[finish_reason_str],
)


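Because finish_reason is now always an enum, name.lower() is what lands on the span attribute. A small sketch of the resulting values, again assuming only google-genai is installed; it is illustrative and not part of the diff.

# Illustrative sketch only: values produced by FinishReason.name.lower(),
# as now recorded under 'gen_ai.response.finish_reasons'.
from google.genai import types

for reason in (
    types.FinishReason.STOP,
    types.FinishReason.MAX_TOKENS,
    types.FinishReason.SAFETY,
    types.FinishReason.OTHER,
):
    print([reason.name.lower()])
# ['stop']
# ['max_tokens']
# ['safety']
# ['other']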
118 changes: 118 additions & 0 deletions tests/unittests/models/test_litellm.py
@@ -1903,3 +1903,121 @@ def test_non_gemini_litellm_no_warning():
# Test with non-Gemini model
LiteLlm(model="openai/gpt-4o")
assert len(w) == 0


@pytest.mark.parametrize(
"finish_reason,response_content,expected_content,has_tool_calls",
[
("length", "Test response", "Test response", False),
("stop", "Complete response", "Complete response", False),
(
"tool_calls",
"",
"",
True,
),
("content_filter", "", "", False),
],
ids=["length", "stop", "tool_calls", "content_filter"],
)
@pytest.mark.asyncio
async def test_finish_reason_propagation(
mock_acompletion,
lite_llm_instance,
finish_reason,
response_content,
expected_content,
has_tool_calls,
):
"""Test that finish_reason is properly propagated from LiteLLM response."""
tool_calls = None
if has_tool_calls:
tool_calls = [
ChatCompletionMessageToolCall(
type="function",
id="test_id",
function=Function(
name="test_function",
arguments='{"arg": "value"}',
),
)
]

mock_response = ModelResponse(
choices=[
Choices(
message=ChatCompletionAssistantMessage(
role="assistant",
content=response_content,
tool_calls=tool_calls,
),
finish_reason=finish_reason,
)
]
)
mock_acompletion.return_value = mock_response

llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
],
)

async for response in lite_llm_instance.generate_content_async(llm_request):
assert response.content.role == "model"
# Verify finish_reason is mapped to FinishReason enum
assert isinstance(response.finish_reason, types.FinishReason)
# Verify correct enum mapping using dictionary
expected_mapping = {
"length": types.FinishReason.MAX_TOKENS,
"stop": types.FinishReason.STOP,
"tool_calls": types.FinishReason.STOP,
"content_filter": types.FinishReason.SAFETY,
}
assert response.finish_reason == expected_mapping[finish_reason]
if expected_content:
assert response.content.parts[0].text == expected_content
if has_tool_calls:
assert len(response.content.parts) > 0
assert response.content.parts[-1].function_call.name == "test_function"

mock_acompletion.assert_called_once()



@pytest.mark.asyncio
async def test_finish_reason_unknown_maps_to_other(
mock_acompletion, lite_llm_instance
):
"""Test that unknown finish_reason values map to FinishReason.OTHER."""
mock_response = ModelResponse(
choices=[
Choices(
message=ChatCompletionAssistantMessage(
role="assistant",
content="Test response",
),
finish_reason="unknown_reason_type",
)
]
)
mock_acompletion.return_value = mock_response

llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
],
)

async for response in lite_llm_instance.generate_content_async(llm_request):
assert response.content.role == "model"
# Unknown finish_reason should map to OTHER
assert isinstance(response.finish_reason, types.FinishReason)
assert response.finish_reason == types.FinishReason.OTHER

mock_acompletion.assert_called_once()