
Commit 65f0109

ankursharmas authored and copybara-github committed
feat: Add implementation of BaseEvalService that runs evals locally
This change:
- Introduces the LocalEvalService class.
- Implements only the "perform_inference" method. The evaluate method will be implemented in the next CL.
- Adds required test coverage.

PiperOrigin-RevId: 771151722
1 parent 9bd539e commit 65f0109

File tree: 6 files changed, +493 -3 lines changed

src/google/adk/evaluation/agent_creator.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC
from abc import abstractmethod

from typing_extensions import override

from ..agents import BaseAgent


class AgentCreator(ABC):
  """Creates an Agent for the purposes of Eval."""

  @abstractmethod
  def get_agent(
      self,
  ) -> BaseAgent:
    """Returns an instance of an Agent to be used for Eval purposes."""


class IdentityAgentCreator(AgentCreator):
  """An implementation of the AgentCreator interface that always returns a copy of the root agent."""

  def __init__(self, root_agent: BaseAgent):
    self._root_agent = root_agent

  @override
  def get_agent(
      self,
  ) -> BaseAgent:
    """Returns a deep copy of the root agent."""
    # TODO: Use Agent.clone() when the PR is merged.
    return self._root_agent.model_copy(deep=True)
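
For orientation (not part of the diff), a minimal sketch of how IdentityAgentCreator is meant to be used; the Agent constructor arguments below are illustrative placeholders, mirroring the test file at the end of this commit:

from google.adk.agents import Agent
from google.adk.evaluation.agent_creator import IdentityAgentCreator

# Illustrative root agent; any BaseAgent instance works.
root_agent = Agent(
    model="gemini-2.0-flash",
    name="my_agent",
    instruction="my agent instructions",
)
creator = IdentityAgentCreator(root_agent=root_agent)

# Each call returns an independent deep copy, so state mutated during one eval
# run does not leak into another.
agent_for_eval = creator.get_agent()
assert agent_for_eval is not root_agent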

src/google/adk/evaluation/base_eval_service.py

Lines changed: 35 additions & 2 deletions
@@ -16,6 +16,7 @@
 
 from abc import ABC
 from abc import abstractmethod
+from enum import Enum
 from typing import AsyncGenerator
 from typing import Optional
 
@@ -56,6 +57,19 @@ class InferenceConfig(BaseModel):
       charges.""",
   )
 
+  max_inference_parallelism: int = Field(
+      default=4,
+      description="""Number of parallel inferences to run during an Eval. A few
+      factors to consider while changing this value:
+
+      1) Your available quota with the model. Models tend to enforce per-minute
+      or per-second SLAs. Using a larger value could result in the eval quickly
+      consuming the quota.
+
+      2) The tools used by the Agent could also have their own SLAs. Using a
+      larger value could also overwhelm those tools.""",
+  )
+
 
 class InferenceRequest(BaseModel):
   """Represent a request to perform inferences for the eval cases in an eval set."""
@@ -88,6 +102,14 @@ class InferenceRequest(BaseModel):
   )
 
 
+class InferenceStatus(Enum):
+  """Status of the inference."""
+
+  UNKNOWN = 0
+  SUCCESS = 1
+  FAILURE = 2
+
+
 class InferenceResult(BaseModel):
   """Contains inference results for a single eval case."""
 
@@ -106,14 +128,25 @@ class InferenceResult(BaseModel):
       description="""Id of the eval case for which inferences were generated.""",
   )
 
-  inferences: list[Invocation] = Field(
-      description="""Inferences obtained from the Agent for the eval case."""
+  inferences: Optional[list[Invocation]] = Field(
+      default=None,
+      description="""Inferences obtained from the Agent for the eval case.""",
   )
 
   session_id: Optional[str] = Field(
       description="""Id of the inference session."""
   )
 
+  status: InferenceStatus = Field(
+      default=InferenceStatus.UNKNOWN,
+      description="""Status of the inference.""",
+  )
+
+  error_message: Optional[str] = Field(
+      default=None,
+      description="""Error message if the inference failed.""",
+  )
+
 
 class EvaluateRequest(BaseModel):
   model_config = ConfigDict(
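
To illustrate the new fields (this snippet is not part of the commit; it only uses field names that LocalEvalService below reads, and assumes any InferenceConfig or InferenceRequest fields not shown keep their defaults):

from google.adk.evaluation.base_eval_service import InferenceConfig
from google.adk.evaluation.base_eval_service import InferenceRequest

# Cap concurrent inferences at 2 to stay within model and tool rate limits.
request = InferenceRequest(
    app_name="my_app",
    eval_set_id="my_eval_set",
    eval_case_ids=["case_1"],  # omit to run every case in the eval set
    inference_config=InferenceConfig(max_inference_parallelism=2),
)

# With the InferenceResult changes above, a failed inference no longer has to
# raise: check result.status and result.error_message, and expect
# result.inferences to be None on failure.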

src/google/adk/evaluation/evaluation_generator.py

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ async def _process_query(
 async def _generate_inferences_from_root_agent(
     invocations: list[Invocation],
     root_agent: Agent,
-    reset_func: Any,
+    reset_func: Optional[Any] = None,
     initial_session: Optional[SessionInput] = None,
     session_id: Optional[str] = None,
     session_service: Optional[BaseSessionService] = None,
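
Making reset_func optional lets callers such as the new LocalEvalService below invoke _generate_inferences_from_root_agent without supplying a reset function.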
src/google/adk/evaluation/local_eval_service.py

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
import logging
from typing import AsyncGenerator
from typing import Callable
from typing import Optional
import uuid

from typing_extensions import override

from ..agents import Agent
from ..artifacts.base_artifact_service import BaseArtifactService
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..errors.not_found_error import NotFoundError
from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..utils.feature_decorator import working_in_progress
from .agent_creator import AgentCreator
from .base_eval_service import BaseEvalService
from .base_eval_service import EvaluateRequest
from .base_eval_service import InferenceRequest
from .base_eval_service import InferenceResult
from .base_eval_service import InferenceStatus
from .eval_result import EvalCaseResult
from .eval_set import EvalCase
from .eval_set_results_manager import EvalSetResultsManager
from .eval_sets_manager import EvalSetsManager
from .evaluation_generator import EvaluationGenerator
from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
from .metric_evaluator_registry import MetricEvaluatorRegistry


logger = logging.getLogger('google_adk.' + __name__)

EVAL_SESSION_ID_PREFIX = '___eval___session___'


def _get_session_id() -> str:
  return f'{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}'


@working_in_progress("Incomplete feature, don't use yet")
class LocalEvalService(BaseEvalService):
  """An implementation of BaseEvalService that runs the evals locally."""

  def __init__(
      self,
      agent_creator: AgentCreator,
      eval_sets_manager: EvalSetsManager,
      metric_evaluator_registry: MetricEvaluatorRegistry = DEFAULT_METRIC_EVALUATOR_REGISTRY,
      session_service: BaseSessionService = InMemorySessionService(),
      artifact_service: BaseArtifactService = InMemoryArtifactService(),
      eval_set_results_manager: Optional[EvalSetResultsManager] = None,
      session_id_supplier: Callable[[], str] = _get_session_id,
  ):
    self._agent_creator = agent_creator
    self._eval_sets_manager = eval_sets_manager
    self._metric_evaluator_registry = metric_evaluator_registry
    self._session_service = session_service
    self._artifact_service = artifact_service
    self._eval_set_results_manager = eval_set_results_manager
    self._session_id_supplier = session_id_supplier

  @override
  async def perform_inference(
      self,
      inference_request: InferenceRequest,
  ) -> AsyncGenerator[InferenceResult, None]:
    """Returns InferenceResults obtained from the Agent as and when they are available.

    Args:
      inference_request: The request for generating inferences.
    """
    # Get the eval set from the storage.
    eval_set = self._eval_sets_manager.get_eval_set(
        app_name=inference_request.app_name,
        eval_set_id=inference_request.eval_set_id,
    )

    if not eval_set:
      raise NotFoundError(
          f'Eval set with id {inference_request.eval_set_id} not found for app'
          f' {inference_request.app_name}'
      )

    # Select eval cases for which we need to run inferencing. If the inference
    # request specified eval cases, then we use only those.
    eval_cases = eval_set.eval_cases
    if inference_request.eval_case_ids:
      eval_cases = [
          eval_case
          for eval_case in eval_cases
          if eval_case.eval_id in inference_request.eval_case_ids
      ]

    root_agent = self._agent_creator.get_agent()

    semaphore = asyncio.Semaphore(
        value=inference_request.inference_config.max_inference_parallelism
    )

    async def run_inference(eval_case):
      async with semaphore:
        return await self._perform_inference_single_eval_item(
            app_name=inference_request.app_name,
            eval_set_id=inference_request.eval_set_id,
            eval_case=eval_case,
            root_agent=root_agent,
        )

    inference_results = [run_inference(eval_case) for eval_case in eval_cases]
    for inference_result in asyncio.as_completed(inference_results):
      yield await inference_result

  @override
  async def evaluate(
      self,
      evaluate_request: EvaluateRequest,
  ) -> AsyncGenerator[EvalCaseResult, None]:
    """Returns EvalCaseResult for each item as and when they are available.

    Args:
      evaluate_request: The request to perform metric evaluations on the
        inferences.
    """
    raise NotImplementedError()

  async def _perform_inference_single_eval_item(
      self,
      app_name: str,
      eval_set_id: str,
      eval_case: EvalCase,
      root_agent: Agent,
  ) -> InferenceResult:
    initial_session = eval_case.session_input
    session_id = self._session_id_supplier()
    inference_result = InferenceResult(
        app_name=app_name,
        eval_set_id=eval_set_id,
        eval_case_id=eval_case.eval_id,
        session_id=session_id,
    )

    try:
      inferences = (
          await EvaluationGenerator._generate_inferences_from_root_agent(
              invocations=eval_case.conversation,
              root_agent=root_agent,
              initial_session=initial_session,
              session_id=session_id,
              session_service=self._session_service,
              artifact_service=self._artifact_service,
          )
      )

      inference_result.inferences = inferences
      inference_result.status = InferenceStatus.SUCCESS

      return inference_result
    except Exception as e:
      # We intentionally catch the Exception as we don't want failures to
      # affect other inferences.
      logger.error(
          'Inference failed for eval case `%s` with error %s',
          eval_case.eval_id,
          e,
      )
      inference_result.status = InferenceStatus.FAILURE
      inference_result.error_message = str(e)
      return inference_result
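
A hedged end-to-end sketch of how this service is expected to be driven (not part of the commit): my_root_agent and my_eval_sets_manager are placeholders for a real Agent and a concrete EvalSetsManager, the LocalEvalService import path is assumed from the module above, and the @working_in_progress decorator may prevent real use until the feature is complete.

from google.adk.evaluation.agent_creator import IdentityAgentCreator
from google.adk.evaluation.base_eval_service import InferenceConfig
from google.adk.evaluation.base_eval_service import InferenceRequest
from google.adk.evaluation.base_eval_service import InferenceStatus
from google.adk.evaluation.local_eval_service import LocalEvalService  # path assumed


async def run_all_inferences(my_root_agent, my_eval_sets_manager):
  # Session and artifact services, plus the session id supplier, fall back to
  # the in-memory defaults shown in __init__ above.
  service = LocalEvalService(
      agent_creator=IdentityAgentCreator(root_agent=my_root_agent),
      eval_sets_manager=my_eval_sets_manager,
  )
  request = InferenceRequest(
      app_name="my_app",
      eval_set_id="my_eval_set",
      inference_config=InferenceConfig(max_inference_parallelism=2),
  )
  # Results are yielded in completion order (asyncio.as_completed), not in the
  # order the eval cases appear in the eval set.
  async for result in service.perform_inference(inference_request=request):
    if result.status == InferenceStatus.FAILURE:
      print(f'{result.eval_case_id} failed: {result.error_message}')
    else:
      print(f'{result.eval_case_id}: {len(result.inferences)} invocation(s)')

# Example invocation: asyncio.run(run_all_inferences(root_agent, eval_sets_manager))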
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.adk.agents import Agent
from google.adk.evaluation.agent_creator import IdentityAgentCreator
from google.adk.tools.tool_context import ToolContext
from google.genai import types


def _method_1(arg1: int, tool_context: ToolContext) -> (int, ToolContext):
  return (arg1, tool_context)


async def _method_2(arg1: list[int]) -> list[int]:
  return arg1


_TEST_SUB_AGENT = Agent(
    model="gemini-2.0-flash",
    name="test_sub_agent",
    description="test sub-agent description",
    instruction="test sub-agent instructions",
    tools=[
        _method_1,
        _method_2,
    ],
)

_TEST_AGENT_1 = Agent(
    model="gemini-2.0-flash",
    name="test_agent_1",
    description="test agent description",
    instruction="test agent instructions",
    tools=[
        _method_1,
        _method_2,
    ],
    sub_agents=[_TEST_SUB_AGENT],
    generate_content_config=types.GenerateContentConfig(
        safety_settings=[
            types.SafetySetting(  # avoid false alarm about rolling dice.
                category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                threshold=types.HarmBlockThreshold.OFF,
            ),
        ]
    ),
)


def test_identity_agent_creator():
  creator = IdentityAgentCreator(root_agent=_TEST_AGENT_1)

  agent1 = creator.get_agent()
  agent2 = creator.get_agent()

  assert isinstance(agent1, Agent)
  assert isinstance(agent2, Agent)

  assert agent1 is not _TEST_AGENT_1  # Ensure it's a copy
  assert agent2 is not _TEST_AGENT_1  # Ensure it's a copy
  assert agent1 is not agent2  # Ensure different copies are returned

  assert agent1.sub_agents[0] is not _TEST_SUB_AGENT  # Ensure it's a copy
  assert agent2.sub_agents[0] is not _TEST_SUB_AGENT  # Ensure it's a copy
  assert (
      agent1.sub_agents[0] is not agent2.sub_agents[0]
  )  # Ensure different copies are returned
