Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 30 additions & 29 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,44 +25,45 @@ classifiers = [ # List of https://pypi.org/classifiers/
]
dependencies = [
# go/keep-sorted start
"PyYAML>=6.0.2, <7.0.0", # For APIHubToolset.
"absolufy-imports>=0.3.1, <1.0.0", # For Agent Engine deployment.
"anyio>=4.9.0, <5.0.0;python_version>='3.10'", # For MCP Session Manager
"authlib>=1.5.1, <2.0.0", # For RestAPI Tool
"click>=8.1.8, <9.0.0", # For CLI tools
"fastapi>=0.115.0, <1.0.0", # FastAPI framework
"google-api-python-client>=2.157.0, <3.0.0", # Google API client discovery
"google-cloud-aiplatform[agent_engines]>=1.112.0, <2.0.0",# For VertexAI integrations, e.g. example store.
"google-cloud-bigtable>=2.32.0", # For Bigtable database
"google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool
"google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool
"google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database
"google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription
"google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service
"google-genai>=1.41.0, <2.0.0", # Google GenAI SDK
"graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering
"mcp>=1.8.0, <2.0.0;python_version>='3.10'", # For MCP Toolset
"opentelemetry-api>=1.37.0, <=1.37.0", # OpenTelemetry - limit upper version for sdk and api to not risk breaking changes from unstable _logs package.
"PyYAML>=6.0.2, <7.0.0", # For APIHubToolset.
"absolufy-imports>=0.3.1, <1.0.0", # For Agent Engine deployment.
"anyio>=4.9.0, <5.0.0;python_version>='3.10'", # For MCP Session Manager
"authlib>=1.5.1, <2.0.0", # For RestAPI Tool
"click>=8.1.8, <9.0.0", # For CLI tools
"fastapi>=0.115.0, <1.0.0", # FastAPI framework
"google-api-python-client>=2.157.0, <3.0.0", # Google API client discovery
"google-cloud-aiplatform[agent_engines]>=1.112.0, <2.0.0", # For VertexAI integrations, e.g. example store.
"google-cloud-bigtable>=2.32.0", # For Bigtable database
"google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool
"google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool
"google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database
"google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription
"google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service
"google-genai>=1.41.0, <2.0.0", # Google GenAI SDK
"graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering
"mcp>=1.8.0, <2.0.0;python_version>='3.10'", # For MCP Toolset
"opentelemetry-api>=1.37.0, <=1.37.0", # OpenTelemetry - limit upper version for sdk and api to not risk breaking changes from unstable _logs package.
"opentelemetry-exporter-gcp-logging>=1.9.0a0, <2.0.0",
"opentelemetry-exporter-gcp-monitoring>=1.9.0a0, <2.0.0",
"opentelemetry-exporter-gcp-trace>=1.9.0, <2.0.0",
"opentelemetry-exporter-otlp-proto-http>=1.36.0",
"opentelemetry-resourcedetector-gcp>=1.9.0a0, <2.0.0",
"opentelemetry-sdk>=1.37.0, <=1.37.0",
"pydantic>=2.0, <3.0.0", # For data validation/models
"python-dateutil>=2.9.0.post0, <3.0.0", # For Vertext AI Session Service
"python-dotenv>=1.0.0, <2.0.0", # To manage environment variables
"pydantic>=2.0, <3.0.0", # For data validation/models
"python-dateutil>=2.9.0.post0, <3.0.0", # For Vertext AI Session Service
"python-dotenv>=1.0.0, <2.0.0", # To manage environment variables
"requests>=2.32.4, <3.0.0",
"sqlalchemy-spanner>=1.14.0", # Spanner database session service
"sqlalchemy>=2.0, <3.0.0", # SQL database ORM
"starlette>=0.46.2, <1.0.0", # For FastAPI CLI
"tenacity>=8.0.0, <9.0.0", # For Retry management
"sqlalchemy-spanner>=1.14.0", # Spanner database session service
"sqlalchemy>=2.0, <3.0.0", # SQL database ORM
"starlette>=0.46.2, <1.0.0", # For FastAPI CLI
"tenacity>=8.0.0, <9.0.0", # For Retry management
"typing-extensions>=4.5, <5",
"tzlocal>=5.3, <6.0", # Time zone utilities
"uvicorn>=0.34.0, <1.0.0", # ASGI server for FastAPI
"watchdog>=6.0.0, <7.0.0", # For file change detection and hot reload
"websockets>=15.0.1, <16.0.0", # For BaseLlmFlow
"tzlocal>=5.3, <6.0", # Time zone utilities
"uvicorn>=0.34.0, <1.0.0", # ASGI server for FastAPI
"watchdog>=6.0.0, <7.0.0", # For file change detection and hot reload
"websockets>=15.0.1, <16.0.0", # For BaseLlmFlow
# go/keep-sorted end
"typesense>=1.1.1",
]
dynamic = ["version"]

Expand Down
10 changes: 10 additions & 0 deletions src/google/adk/sessions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,13 @@
'DatabaseSessionService require sqlalchemy>=2.0, please ensure it is'
' installed correctly.'
)

try:
from .typesense_session_service import TypesenseSessionService

__all__.append('TypesenseSessionService')
except ImportError:
logger.debug(
'TypesenseSessionService requires typesense>=1.1.1, please ensure it is'
' installed correctly.'
)
51 changes: 51 additions & 0 deletions src/google/adk/sessions/_session_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
"""Utility functions for session service."""
from __future__ import annotations

import copy
from typing import Any
from typing import Optional

from google.genai import types

from .state import State


def decode_content(
content: Optional[dict[str, Any]],
Expand All @@ -36,3 +39,51 @@ def decode_grounding_metadata(
if not grounding_metadata:
return None
return types.GroundingMetadata.model_validate(grounding_metadata)


def extract_state_delta(
state: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
"""Extracts state deltas for app, user, and session scopes.

Args:
state: The state dictionary containing mixed scopes.

Returns:
A tuple of (app_state_delta, user_state_delta, session_state_delta).
"""
app_state_delta = {}
user_state_delta = {}
session_state_delta = {}
if state:
for key in state.keys():
if key.startswith(State.APP_PREFIX):
app_state_delta[key.removeprefix(State.APP_PREFIX)] = state[key]
elif key.startswith(State.USER_PREFIX):
user_state_delta[key.removeprefix(State.USER_PREFIX)] = state[key]
elif not key.startswith(State.TEMP_PREFIX):
session_state_delta[key] = state[key]
return app_state_delta, user_state_delta, session_state_delta


def merge_state(
app_state: dict[str, Any],
user_state: dict[str, Any],
session_state: dict[str, Any],
) -> dict[str, Any]:
"""Merges app, user, and session states into a single state dictionary.

Args:
app_state: The app-level state.
user_state: The user-level state.
session_state: The session-level state.

Returns:
A merged state dictionary with appropriate prefixes.
"""
merged_state = copy.deepcopy(session_state)
for key in app_state.keys():
merged_state[State.APP_PREFIX + key] = app_state[key]
for key in user_state.keys():
merged_state[State.USER_PREFIX + key] = user_state[key]
return merged_state
37 changes: 7 additions & 30 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@

from . import _session_util
from ..events.event import Event
from ._session_util import extract_state_delta
from ._session_util import merge_state
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsResponse
Expand Down Expand Up @@ -463,7 +465,7 @@ async def create_session(
sql_session.add(storage_user_state)

# Extract state deltas
app_state_delta, user_state_delta, session_state = _extract_state_delta(
app_state_delta, user_state_delta, session_state = extract_state_delta(
state
)

Expand All @@ -490,7 +492,7 @@ async def create_session(
sql_session.refresh(storage_session)

# Merge states for response
merged_state = _merge_state(app_state, user_state, session_state)
merged_state = merge_state(app_state, user_state, session_state)
session = storage_session.to_session(state=merged_state)
return session

Expand Down Expand Up @@ -545,7 +547,7 @@ async def get_session(
session_state = storage_session.state

# Merge states
merged_state = _merge_state(app_state, user_state, session_state)
merged_state = merge_state(app_state, user_state, session_state)

# Convert storage session to session
events = [e.to_event() for e in reversed(storage_events)]
Expand Down Expand Up @@ -576,7 +578,7 @@ async def list_sessions(
sessions = []
for storage_session in results:
session_state = storage_session.state
merged_state = _merge_state(app_state, user_state, session_state)
merged_state = merge_state(app_state, user_state, session_state)

sessions.append(storage_session.to_session(state=merged_state))
return ListSessionsResponse(sessions=sessions)
Expand Down Expand Up @@ -636,7 +638,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
if event.actions:
if event.actions.state_delta:
app_state_delta, user_state_delta, session_state_delta = (
_extract_state_delta(event.actions.state_delta)
extract_state_delta(event.actions.state_delta)
)

# Merge state and update storage
Expand All @@ -661,28 +663,3 @@ async def append_event(self, session: Session, event: Event) -> Event:
# Also update the in-memory session
await super().append_event(session=session, event=event)
return event


def _extract_state_delta(state: dict[str, Any]):
app_state_delta = {}
user_state_delta = {}
session_state_delta = {}
if state:
for key in state.keys():
if key.startswith(State.APP_PREFIX):
app_state_delta[key.removeprefix(State.APP_PREFIX)] = state[key]
elif key.startswith(State.USER_PREFIX):
user_state_delta[key.removeprefix(State.USER_PREFIX)] = state[key]
elif not key.startswith(State.TEMP_PREFIX):
session_state_delta[key] = state[key]
return app_state_delta, user_state_delta, session_state_delta


def _merge_state(app_state, user_state, session_state):
# Merge states for response
merged_state = copy.deepcopy(session_state)
for key in app_state.keys():
merged_state[State.APP_PREFIX + key] = app_state[key]
for key in user_state.keys():
merged_state[State.USER_PREFIX + key] = user_state[key]
return merged_state
Loading