Skip to content

Commit 1b7725d

Browse files
pwwpchecopybara-github
authored andcommitted
feat: Implement PluginService for registering and executing plugins
PluginService takes the registration of plugins, and provide the wrapper utilities to execute all plugins. PiperOrigin-RevId: 769834609
1 parent 3901fad commit 1b7725d

File tree

2 files changed

+500
-0
lines changed

2 files changed

+500
-0
lines changed
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from unittest.mock import Mock
18+
19+
from google.adk.agents.base_agent import BaseAgent
20+
from google.adk.agents.callback_context import CallbackContext
21+
from google.adk.agents.invocation_context import InvocationContext
22+
from google.adk.events.event import Event
23+
from google.adk.models.llm_request import LlmRequest
24+
from google.adk.models.llm_response import LlmResponse
25+
from google.adk.plugins.base_plugin import BasePlugin
26+
from google.adk.tools.base_tool import BaseTool
27+
from google.adk.tools.tool_context import ToolContext
28+
from google.genai import types
29+
import pytest
30+
31+
32+
class TestablePlugin(BasePlugin):
33+
__test__ = False
34+
"""A concrete implementation of BasePlugin for testing purposes."""
35+
pass
36+
37+
38+
class FullOverridePlugin(BasePlugin):
39+
__test__ = False
40+
41+
"""A plugin that overrides every single callback method for testing."""
42+
43+
def __init__(self, name: str = "full_override"):
44+
super().__init__(name)
45+
46+
async def on_user_message_callback(self, **kwargs) -> str:
47+
return "overridden_on_user_message"
48+
49+
async def before_run_callback(self, **kwargs) -> str:
50+
return "overridden_before_run"
51+
52+
async def after_run_callback(self, **kwargs) -> str:
53+
return "overridden_after_run"
54+
55+
async def on_event_callback(self, **kwargs) -> str:
56+
return "overridden_on_event"
57+
58+
async def before_agent_callback(self, **kwargs) -> str:
59+
return "overridden_before_agent"
60+
61+
async def after_agent_callback(self, **kwargs) -> str:
62+
return "overridden_after_agent"
63+
64+
async def before_tool_callback(self, **kwargs) -> str:
65+
return "overridden_before_tool"
66+
67+
async def after_tool_callback(self, **kwargs) -> str:
68+
return "overridden_after_tool"
69+
70+
async def before_model_callback(self, **kwargs) -> str:
71+
return "overridden_before_model"
72+
73+
async def after_model_callback(self, **kwargs) -> str:
74+
return "overridden_after_model"
75+
76+
77+
def test_base_plugin_initialization():
78+
"""Tests that a plugin is initialized with the correct name."""
79+
plugin_name = "my_test_plugin"
80+
plugin = TestablePlugin(name=plugin_name)
81+
assert plugin.name == plugin_name
82+
83+
84+
@pytest.mark.asyncio
85+
async def test_base_plugin_default_callbacks_return_none():
86+
"""Tests that the default (non-overridden) callbacks in BasePlugin exist
87+
88+
and return None as expected.
89+
"""
90+
plugin = TestablePlugin(name="default_plugin")
91+
92+
# Mocking all necessary context objects
93+
mock_context = Mock()
94+
mock_user_message = Mock()
95+
96+
# The default implementations should do nothing and return None.
97+
assert (
98+
await plugin.on_user_message_callback(
99+
user_message=mock_user_message,
100+
invocation_context=mock_context,
101+
)
102+
is None
103+
)
104+
assert (
105+
await plugin.before_run_callback(invocation_context=mock_context) is None
106+
)
107+
assert (
108+
await plugin.after_run_callback(invocation_context=mock_context) is None
109+
)
110+
assert (
111+
await plugin.on_event_callback(
112+
invocation_context=mock_context, event=mock_context
113+
)
114+
is None
115+
)
116+
assert (
117+
await plugin.before_agent_callback(
118+
agent=mock_context, callback_context=mock_context
119+
)
120+
is None
121+
)
122+
assert (
123+
await plugin.after_agent_callback(
124+
agent=mock_context, callback_context=mock_context
125+
)
126+
is None
127+
)
128+
assert (
129+
await plugin.before_tool_callback(
130+
tool=mock_context, tool_args={}, tool_context=mock_context
131+
)
132+
is None
133+
)
134+
assert (
135+
await plugin.after_tool_callback(
136+
tool=mock_context, tool_args={}, tool_context=mock_context, result={}
137+
)
138+
is None
139+
)
140+
assert (
141+
await plugin.before_model_callback(
142+
callback_context=mock_context, llm_request=mock_context
143+
)
144+
is None
145+
)
146+
assert (
147+
await plugin.after_model_callback(
148+
callback_context=mock_context, llm_response=mock_context
149+
)
150+
is None
151+
)
152+
153+
154+
@pytest.mark.asyncio
155+
async def test_base_plugin_all_callbacks_can_be_overridden():
156+
"""Verifies that a user can create a subclass of BasePlugin and that all
157+
158+
overridden methods are correctly called.
159+
"""
160+
plugin = FullOverridePlugin()
161+
162+
# Create mock objects for all required arguments. We don't need real
163+
# objects, just placeholders to satisfy the method signatures.
164+
mock_user_message = Mock(spec=types.Content)
165+
mock_invocation_context = Mock(spec=InvocationContext)
166+
mock_callback_context = Mock(spec=CallbackContext)
167+
mock_agent = Mock(spec=BaseAgent)
168+
mock_tool = Mock(spec=BaseTool)
169+
mock_tool_context = Mock(spec=ToolContext)
170+
mock_llm_request = Mock(spec=LlmRequest)
171+
mock_llm_response = Mock(spec=LlmResponse)
172+
mock_event = Mock(spec=Event)
173+
174+
# Call each method and assert it returns the unique string from the override.
175+
# This proves that the subclass's method was executed.
176+
assert (
177+
await plugin.on_user_message_callback(
178+
user_message=mock_user_message,
179+
invocation_context=mock_invocation_context,
180+
)
181+
== "overridden_on_user_message"
182+
)
183+
assert (
184+
await plugin.before_run_callback(
185+
invocation_context=mock_invocation_context
186+
)
187+
== "overridden_before_run"
188+
)
189+
assert (
190+
await plugin.after_run_callback(
191+
invocation_context=mock_invocation_context
192+
)
193+
== "overridden_after_run"
194+
)
195+
assert (
196+
await plugin.on_event_callback(
197+
invocation_context=mock_invocation_context, event=mock_event
198+
)
199+
== "overridden_on_event"
200+
)
201+
assert (
202+
await plugin.before_agent_callback(
203+
agent=mock_agent, callback_context=mock_callback_context
204+
)
205+
== "overridden_before_agent"
206+
)
207+
assert (
208+
await plugin.after_agent_callback(
209+
agent=mock_agent, callback_context=mock_callback_context
210+
)
211+
== "overridden_after_agent"
212+
)
213+
assert (
214+
await plugin.before_model_callback(
215+
callback_context=mock_callback_context, llm_request=mock_llm_request
216+
)
217+
== "overridden_before_model"
218+
)
219+
assert (
220+
await plugin.after_model_callback(
221+
callback_context=mock_callback_context, llm_response=mock_llm_response
222+
)
223+
== "overridden_after_model"
224+
)
225+
assert (
226+
await plugin.before_tool_callback(
227+
tool=mock_tool, tool_args={}, tool_context=mock_tool_context
228+
)
229+
== "overridden_before_tool"
230+
)
231+
assert (
232+
await plugin.after_tool_callback(
233+
tool=mock_tool,
234+
tool_args={},
235+
tool_context=mock_tool_context,
236+
result={},
237+
)
238+
== "overridden_after_tool"
239+
)

0 commit comments

Comments
 (0)