feat: Add implementation of BaseEvalService that runs evals locally #1756

Merged · 1 commit · Jul 11, 2025
35 changes: 35 additions & 0 deletions src/google/adk/evaluation/agent_creator.py
@@ -0,0 +1,35 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing_extensions import override

from ..agents import BaseAgent


class IdentityAgentCreator:
"""An implementation of the AgentCreator interface that always returns a copy of the root agent."""

def __init__(self, root_agent: BaseAgent):
self._root_agent = root_agent

@override
def get_agent(
self,
) -> BaseAgent:
"""Returns a deep copy of the root agent."""
# TODO: Use Agent.clone() when the PR is merged.
# return self._root_agent.model_copy(deep=True)
return self._root_agent.clone()
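
Usage sketch (not part of this diff): one way IdentityAgentCreator might be wired up. `build_root_agent()` is a hypothetical stand-in for however an application constructs its BaseAgent; the point illustrated is that each `get_agent()` call hands back a copy rather than the shared instance.

from google.adk.evaluation.agent_creator import IdentityAgentCreator

root_agent = build_root_agent()  # hypothetical: any google.adk BaseAgent instance
creator = IdentityAgentCreator(root_agent)

# Each call returns a fresh copy, so per-eval mutations (session state, callbacks)
# do not leak into the shared root agent.
agent_for_eval = creator.get_agent()
assert agent_for_eval is not root_agent
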
37 changes: 35 additions & 2 deletions src/google/adk/evaluation/base_eval_service.py
@@ -16,6 +16,7 @@

from abc import ABC
from abc import abstractmethod
from enum import Enum
from typing import AsyncGenerator
from typing import Optional

@@ -56,6 +57,19 @@ class InferenceConfig(BaseModel):
charges.""",
)

parallelism: int = Field(
default=4,
description="""Number of parallel inferences to run during an Eval. Few
factors to consider while changing this value:

1) Your available quota with the model. Models tend to enforce per-minute or
per-second SLAs. Using a larger value could result in the eval quickly consuming
the quota.

2) The tools used by the Agent could also have their own SLAs. Using a larger
value could also overwhelm those tools.""",
)


class InferenceRequest(BaseModel):
"""Represent a request to perform inferences for the eval cases in an eval set."""
@@ -88,6 +102,14 @@ class InferenceRequest(BaseModel):
)


class InferenceStatus(Enum):
"""Status of the inference."""

UNKNOWN = 0
SUCCESS = 1
FAILURE = 2


class InferenceResult(BaseModel):
"""Contains inference results for a single eval case."""

@@ -106,14 +128,25 @@ class InferenceResult(BaseModel):
description="""Id of the eval case for which inferences were generated.""",
)

inferences: list[Invocation] = Field(
description="""Inferences obtained from the Agent for the eval case."""
inferences: Optional[list[Invocation]] = Field(
default=None,
description="""Inferences obtained from the Agent for the eval case.""",
)

session_id: Optional[str] = Field(
description="""Id of the inference session."""
)

status: InferenceStatus = Field(
default=InferenceStatus.UNKNOWN,
description="""Status of the inference.""",
)

error_message: Optional[str] = Field(
default=None,
description="""Error message if the inference failed.""",
)


class EvaluateRequest(BaseModel):
model_config = ConfigDict(
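
Before moving to the next file, a short sketch (not from this PR) of how the new parallelism knob and the InferenceResult status/error fields might be consumed by a caller. Field names are taken from the diff above; any InferenceRequest/InferenceConfig fields not shown here are assumed to keep their defaults, and the app and eval-set ids are hypothetical.

from google.adk.evaluation.base_eval_service import InferenceConfig
from google.adk.evaluation.base_eval_service import InferenceRequest
from google.adk.evaluation.base_eval_service import InferenceResult
from google.adk.evaluation.base_eval_service import InferenceStatus

request = InferenceRequest(
    app_name='my_app',  # hypothetical
    eval_set_id='smoke_tests',  # hypothetical
    inference_config=InferenceConfig(parallelism=2),  # stay well under model quota
)

def summarize(result: InferenceResult) -> str:
  """Maps the new status/error fields to a one-line summary."""
  if result.status == InferenceStatus.SUCCESS:
    return f'{result.eval_case_id}: {len(result.inferences or [])} invocations'
  return f'{result.eval_case_id}: failed ({result.error_message})'
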
2 changes: 1 addition & 1 deletion src/google/adk/evaluation/evaluation_generator.py
@@ -137,7 +137,7 @@ async def _process_query(
async def _generate_inferences_from_root_agent(
invocations: list[Invocation],
root_agent: Agent,
- reset_func: Any,
+ reset_func: Optional[Any] = None,
initial_session: Optional[SessionInput] = None,
session_id: Optional[str] = None,
session_service: Optional[BaseSessionService] = None,
183 changes: 183 additions & 0 deletions src/google/adk/evaluation/local_eval_service.py
@@ -0,0 +1,183 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
import logging
from typing import AsyncGenerator
from typing import Callable
from typing import Optional
import uuid

from typing_extensions import override

from ..agents import BaseAgent
from ..artifacts.base_artifact_service import BaseArtifactService
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..errors.not_found_error import NotFoundError
from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..utils.feature_decorator import working_in_progress
from .base_eval_service import BaseEvalService
from .base_eval_service import EvaluateRequest
from .base_eval_service import InferenceRequest
from .base_eval_service import InferenceResult
from .base_eval_service import InferenceStatus
from .eval_result import EvalCaseResult
from .eval_set import EvalCase
from .eval_set_results_manager import EvalSetResultsManager
from .eval_sets_manager import EvalSetsManager
from .evaluation_generator import EvaluationGenerator
from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
from .metric_evaluator_registry import MetricEvaluatorRegistry

logger = logging.getLogger('google_adk.' + __name__)

EVAL_SESSION_ID_PREFIX = '___eval___session___'


def _get_session_id() -> str:
return f'{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}'


@working_in_progress("Incomplete feature, don't use yet")
class LocalEvalService(BaseEvalService):
"""An implementation of BaseEvalService, that runs the evals locally."""

def __init__(
self,
root_agent: BaseAgent,
eval_sets_manager: EvalSetsManager,
metric_evaluator_registry: MetricEvaluatorRegistry = DEFAULT_METRIC_EVALUATOR_REGISTRY,
session_service: BaseSessionService = InMemorySessionService(),
artifact_service: BaseArtifactService = InMemoryArtifactService(),
eval_set_results_manager: Optional[EvalSetResultsManager] = None,
session_id_supplier: Callable[[], str] = _get_session_id,
):
self._root_agent = root_agent
self._eval_sets_manager = eval_sets_manager
self._metric_evaluator_registry = metric_evaluator_registry
self._session_service = session_service
self._artifact_service = artifact_service
self._eval_set_results_manager = eval_set_results_manager
self._session_id_supplier = session_id_supplier

@override
async def perform_inference(
self,
inference_request: InferenceRequest,
) -> AsyncGenerator[InferenceResult, None]:
"""Returns InferenceResult obtained from the Agent as and when they are available.

Args:
inference_request: The request for generating inferences.
"""
# Get the eval set from the storage.
eval_set = self._eval_sets_manager.get_eval_set(
app_name=inference_request.app_name,
eval_set_id=inference_request.eval_set_id,
)

if not eval_set:
raise NotFoundError(
f'Eval set with id {inference_request.eval_set_id} not found for app'
f' {inference_request.app_name}'
)

# Select eval cases for which we need to run inferencing. If the inference
# request specified eval cases, then we use only those.
eval_cases = eval_set.eval_cases
if inference_request.eval_case_ids:
eval_cases = [
eval_case
for eval_case in eval_cases
if eval_case.eval_id in inference_request.eval_case_ids
]

root_agent = self._root_agent.clone()

semaphore = asyncio.Semaphore(
value=inference_request.inference_config.parallelism
)

async def run_inference(eval_case):
async with semaphore:
return await self._perform_inference_single_eval_item(
app_name=inference_request.app_name,
eval_set_id=inference_request.eval_set_id,
eval_case=eval_case,
root_agent=root_agent,
)

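# Kick off all inferences at once; the semaphore above caps how many run
# concurrently, and results are yielded below in completion order rather than
# eval-set order.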
inference_results = [run_inference(eval_case) for eval_case in eval_cases]
for inference_result in asyncio.as_completed(inference_results):
yield await inference_result

@override
async def evaluate(
self,
evaluate_request: EvaluateRequest,
) -> AsyncGenerator[EvalCaseResult, None]:
"""Returns EvalCaseResult for each item as and when they are available.

Args:
evaluate_request: The request to perform metric evaluations on the
inferences.
"""
raise NotImplementedError()

async def _perform_inference_single_eval_item(
self,
app_name: str,
eval_set_id: str,
eval_case: EvalCase,
root_agent: BaseAgent,
) -> InferenceResult:
initial_session = eval_case.session_input
session_id = self._session_id_supplier()
inference_result = InferenceResult(
app_name=app_name,
eval_set_id=eval_set_id,
eval_case_id=eval_case.eval_id,
session_id=session_id,
)

try:
inferences = (
await EvaluationGenerator._generate_inferences_from_root_agent(
invocations=eval_case.conversation,
root_agent=root_agent,
initial_session=initial_session,
session_id=session_id,
session_service=self._session_service,
artifact_service=self._artifact_service,
)
)

inference_result.inferences = inferences
inference_result.status = InferenceStatus.SUCCESS

return inference_result
except Exception as e:
# We intentionally catch broad Exception so that a failure in one eval case
# does not affect the other inferences.
logger.error(
'Inference failed for eval case `%s` with error %s',
eval_case.eval_id,
e,
)
inference_result.status = InferenceStatus.FAILURE
inference_result.error_message = str(e)
return inference_result
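
End-to-end sketch (not part of this PR): one way the new service might be driven once the feature lands. `root_agent` and `eval_sets_manager` are assumed to be constructed elsewhere (any BaseAgent plus any EvalSetsManager implementation), the ids are hypothetical, and the class is still gated by @working_in_progress in this change, so running it may require the corresponding feature flag.

import asyncio

from google.adk.evaluation.base_eval_service import InferenceConfig
from google.adk.evaluation.base_eval_service import InferenceRequest
from google.adk.evaluation.local_eval_service import LocalEvalService


async def run_inferences(root_agent, eval_sets_manager):
  service = LocalEvalService(
      root_agent=root_agent,
      eval_sets_manager=eval_sets_manager,
  )
  request = InferenceRequest(
      app_name='my_app',  # hypothetical
      eval_set_id='smoke_tests',  # hypothetical
      inference_config=InferenceConfig(parallelism=2),
  )
  # Results arrive in completion order, not eval-set order.
  async for result in service.perform_inference(inference_request=request):
    print(result.eval_case_id, result.status)


# asyncio.run(run_inferences(my_root_agent, my_eval_sets_manager))  # hypothetical objects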