Skip to content

Commit 2e65936

Browse files
ankursharmas and copybara-github
authored and committed
feat: Add implementation of BaseEvalService that runs evals locally
This change: - Introduces the LocalEvalService Class. - Implements only the "perform_inference" method. Evaluate method will be implemented in the next CL. - Adds required test coverage. PiperOrigin-RevId: 771151722
1 parent 162228d commit 2e65936

File tree

5 files changed

+398
-3
lines changed

5 files changed

+398
-3
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing_extensions import override
18+
19+
from ..agents import BaseAgent
20+
21+
22+
class IdentityAgentCreator:
  """An AgentCreator implementation that always returns a copy of the root agent.

  NOTE(review): the docstring says this implements the AgentCreator
  interface, but no base class is declared here (which also made the former
  `@override` marker on get_agent incorrect) -- confirm whether this class
  should subclass that interface.
  """

  def __init__(self, root_agent: BaseAgent):
    """Stores the agent whose clones will be handed out by get_agent().

    Args:
      root_agent: The agent to copy on every get_agent() call.
    """
    self._root_agent = root_agent

  def get_agent(
      self,
  ) -> BaseAgent:
    """Returns a clone of the root agent."""
    return self._root_agent.clone()

src/google/adk/evaluation/base_eval_service.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from abc import ABC
1818
from abc import abstractmethod
19+
from enum import Enum
1920
from typing import AsyncGenerator
2021
from typing import Optional
2122

@@ -56,6 +57,19 @@ class InferenceConfig(BaseModel):
5657
charges.""",
5758
)
5859

60+
parallelism: int = Field(
61+
default=4,
62+
description="""Number of parallel inferences to run during an Eval. Few
63+
factors to consider while changing this value:
64+
65+
1) Your available quota with the model. Models tend to enforce per-minute or
66+
per-second SLAs. Using a larger value could result in the eval quickly consuming
67+
the quota.
68+
69+
2) The tools used by the Agent could also have their SLA. Using a larger value
70+
could also overwhelm those tools.""",
71+
)
72+
5973

6074
class InferenceRequest(BaseModel):
6175
"""Represent a request to perform inferences for the eval cases in an eval set."""
@@ -88,6 +102,14 @@ class InferenceRequest(BaseModel):
88102
)
89103

90104

105+
class InferenceStatus(Enum):
  """Status of the inference."""

  # Default value; status was never explicitly set for this result.
  UNKNOWN = 0
  # Inferences were generated for the eval case without error.
  SUCCESS = 1
  # Inference raised; details are carried in InferenceResult.error_message.
  FAILURE = 2
111+
112+
91113
class InferenceResult(BaseModel):
92114
"""Contains inference results for a single eval case."""
93115

@@ -106,14 +128,25 @@ class InferenceResult(BaseModel):
106128
description="""Id of the eval case for which inferences were generated.""",
107129
)
108130

109-
inferences: list[Invocation] = Field(
110-
description="""Inferences obtained from the Agent for the eval case."""
131+
inferences: Optional[list[Invocation]] = Field(
132+
default=None,
133+
description="""Inferences obtained from the Agent for the eval case.""",
111134
)
112135

113136
session_id: Optional[str] = Field(
114137
description="""Id of the inference session."""
115138
)
116139

140+
status: InferenceStatus = Field(
141+
default=InferenceStatus.UNKNOWN,
142+
description="""Status of the inference.""",
143+
)
144+
145+
error_message: Optional[str] = Field(
146+
default=None,
147+
description="""Error message if the inference failed.""",
148+
)
149+
117150

118151
class EvaluateRequest(BaseModel):
119152
model_config = ConfigDict(

src/google/adk/evaluation/evaluation_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ async def _process_query(
137137
async def _generate_inferences_from_root_agent(
138138
invocations: list[Invocation],
139139
root_agent: Agent,
140-
reset_func: Any,
140+
reset_func: Optional[Any] = None,
141141
initial_session: Optional[SessionInput] = None,
142142
session_id: Optional[str] = None,
143143
session_service: Optional[BaseSessionService] = None,
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import asyncio
18+
import logging
19+
from typing import AsyncGenerator
20+
from typing import Callable
21+
from typing import Optional
22+
import uuid
23+
24+
from typing_extensions import override
25+
26+
from ..agents import BaseAgent
27+
from ..artifacts.base_artifact_service import BaseArtifactService
28+
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
29+
from ..errors.not_found_error import NotFoundError
30+
from ..sessions.base_session_service import BaseSessionService
31+
from ..sessions.in_memory_session_service import InMemorySessionService
32+
from ..utils.feature_decorator import working_in_progress
33+
from .base_eval_service import BaseEvalService
34+
from .base_eval_service import EvaluateRequest
35+
from .base_eval_service import InferenceRequest
36+
from .base_eval_service import InferenceResult
37+
from .base_eval_service import InferenceStatus
38+
from .eval_result import EvalCaseResult
39+
from .eval_set import EvalCase
40+
from .eval_set_results_manager import EvalSetResultsManager
41+
from .eval_sets_manager import EvalSetsManager
42+
from .evaluation_generator import EvaluationGenerator
43+
from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
44+
from .metric_evaluator_registry import MetricEvaluatorRegistry
45+
46+
logger = logging.getLogger('google_adk.' + __name__)
47+
48+
EVAL_SESSION_ID_PREFIX = '___eval___session___'
49+
50+
51+
def _get_session_id() -> str:
52+
return f'{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}'
53+
54+
55+
@working_in_progress("Incomplete feature, don't use yet")
56+
class LocalEvalService(BaseEvalService):
57+
"""An implementation of BaseEvalService, that runs the evals locally."""
58+
59+
def __init__(
60+
self,
61+
root_agent: BaseAgent,
62+
eval_sets_manager: EvalSetsManager,
63+
metric_evaluator_registry: MetricEvaluatorRegistry = DEFAULT_METRIC_EVALUATOR_REGISTRY,
64+
session_service: BaseSessionService = InMemorySessionService(),
65+
artifact_service: BaseArtifactService = InMemoryArtifactService(),
66+
eval_set_results_manager: Optional[EvalSetResultsManager] = None,
67+
session_id_supplier: Callable[[], str] = _get_session_id,
68+
):
69+
self._root_agent = root_agent
70+
self._eval_sets_manager = eval_sets_manager
71+
self._metric_evaluator_registry = metric_evaluator_registry
72+
self._session_service = session_service
73+
self._artifact_service = artifact_service
74+
self._eval_set_results_manager = eval_set_results_manager
75+
self._session_id_supplier = session_id_supplier
76+
77+
@override
78+
async def perform_inference(
79+
self,
80+
inference_request: InferenceRequest,
81+
) -> AsyncGenerator[InferenceResult, None]:
82+
"""Returns InferenceResult obtained from the Agent as and when they are available.
83+
84+
Args:
85+
inference_request: The request for generating inferences.
86+
"""
87+
# Get the eval set from the storage.
88+
eval_set = self._eval_sets_manager.get_eval_set(
89+
app_name=inference_request.app_name,
90+
eval_set_id=inference_request.eval_set_id,
91+
)
92+
93+
if not eval_set:
94+
raise NotFoundError(
95+
f'Eval set with id {inference_request.eval_set_id} not found for app'
96+
f' {inference_request.app_name}'
97+
)
98+
99+
# Select eval cases for which we need to run inferencing. If the inference
100+
# request specified eval cases, then we use only those.
101+
eval_cases = eval_set.eval_cases
102+
if inference_request.eval_case_ids:
103+
eval_cases = [
104+
eval_case
105+
for eval_case in eval_cases
106+
if eval_case.eval_id in inference_request.eval_case_ids
107+
]
108+
109+
root_agent = self._root_agent.clone()
110+
111+
semaphore = asyncio.Semaphore(
112+
value=inference_request.inference_config.parallelism
113+
)
114+
115+
async def run_inference(eval_case):
116+
async with semaphore:
117+
return await self._perform_inference_sigle_eval_item(
118+
app_name=inference_request.app_name,
119+
eval_set_id=inference_request.eval_set_id,
120+
eval_case=eval_case,
121+
root_agent=root_agent,
122+
)
123+
124+
inference_results = [run_inference(eval_case) for eval_case in eval_cases]
125+
for inference_result in asyncio.as_completed(inference_results):
126+
yield await inference_result
127+
128+
@override
129+
async def evaluate(
130+
self,
131+
evaluate_request: EvaluateRequest,
132+
) -> AsyncGenerator[EvalCaseResult, None]:
133+
"""Returns EvalCaseResult for each item as and when they are available.
134+
135+
Args:
136+
evaluate_request: The request to perform metric evaluations on the
137+
inferences.
138+
"""
139+
raise NotImplementedError()
140+
141+
async def _perform_inference_sigle_eval_item(
142+
self,
143+
app_name: str,
144+
eval_set_id: str,
145+
eval_case: EvalCase,
146+
root_agent: BaseAgent,
147+
) -> InferenceResult:
148+
initial_session = eval_case.session_input
149+
session_id = self._session_id_supplier()
150+
inference_result = InferenceResult(
151+
app_name=app_name,
152+
eval_set_id=eval_set_id,
153+
eval_case_id=eval_case.eval_id,
154+
session_id=session_id,
155+
)
156+
157+
try:
158+
inferences = (
159+
await EvaluationGenerator._generate_inferences_from_root_agent(
160+
invocations=eval_case.conversation,
161+
root_agent=root_agent,
162+
initial_session=initial_session,
163+
session_id=session_id,
164+
session_service=self._session_service,
165+
artifact_service=self._artifact_service,
166+
)
167+
)
168+
169+
inference_result.inferences = inferences
170+
inference_result.status = InferenceStatus.SUCCESS
171+
172+
return inference_result
173+
except Exception as e:
174+
# We intentionally catch the Exception as we don't failures to affect
175+
# other inferences.
176+
logger.error(
177+
'Inference failed for eval case `%s` with error %s',
178+
eval_case.eval_id,
179+
e,
180+
)
181+
inference_result.status = InferenceStatus.FAILURE
182+
inference_result.error_message = str(e)
183+
return inference_result

0 commit comments

Comments
 (0)