|
49 | 49 | from pydantic import ValidationError
|
50 | 50 | from starlette.types import Lifespan
|
51 | 51 | from typing_extensions import override
|
| 52 | +from watchdog.events import FileSystemEventHandler |
| 53 | +from watchdog.observers import Observer |
52 | 54 |
|
53 | 55 | from ..agents import RunConfig
|
54 | 56 | from ..agents.live_request_queue import LiveRequest
|
|
87 | 89 | logger = logging.getLogger("google_adk." + __name__)
|
88 | 90 |
|
89 | 91 | _EVAL_SET_FILE_EXTENSION = ".evalset.json"
|
| 92 | +_app_name = "" |
| 93 | +_runners_to_clean = set() |
| 94 | + |
| 95 | + |
| 96 | +class AgentChangeEventHandler(FileSystemEventHandler): |
| 97 | + |
| 98 | + def __init__(self, agent_loader: AgentLoader): |
| 99 | + self.agent_loader = agent_loader |
| 100 | + |
| 101 | + def on_modified(self, event): |
| 102 | + if not (event.src_path.endswith(".py") or event.src_path.endswith(".yaml")): |
| 103 | + return |
| 104 | + logger.info("Change detected in agents directory: %s", event.src_path) |
| 105 | + self.agent_loader.remove_agent_from_cache(_app_name) |
| 106 | + _runners_to_clean.add(_app_name) |
90 | 107 |
|
91 | 108 |
|
92 | 109 | class ApiServerSpanExporter(export.SpanExporter):
|
@@ -205,6 +222,7 @@ def get_fast_api_app(
|
205 | 222 | host: str = "127.0.0.1",
|
206 | 223 | port: int = 8000,
|
207 | 224 | trace_to_cloud: bool = False,
|
| 225 | + reload_agents: bool = False, |
208 | 226 | lifespan: Optional[Lifespan[FastAPI]] = None,
|
209 | 227 | ) -> FastAPI:
|
210 | 228 | # InMemory tracing dict.
|
@@ -235,14 +253,16 @@ def get_fast_api_app(
|
235 | 253 |
|
236 | 254 | @asynccontextmanager
|
237 | 255 | async def internal_lifespan(app: FastAPI):
|
238 |
| - |
239 | 256 | try:
|
240 | 257 | if lifespan:
|
241 | 258 | async with lifespan(app) as lifespan_context:
|
242 | 259 | yield lifespan_context
|
243 | 260 | else:
|
244 | 261 | yield
|
245 | 262 | finally:
|
| 263 | + if reload_agents: |
| 264 | + observer.stop() |
| 265 | + observer.join() |
246 | 266 | # Create tasks for all runner closures to run concurrently
|
247 | 267 | await cleanup.close_runners(list(runner_dict.values()))
|
248 | 268 |
|
@@ -336,6 +356,13 @@ async def internal_lifespan(app: FastAPI):
|
336 | 356 | # initialize Agent Loader
|
337 | 357 | agent_loader = AgentLoader(agents_dir)
|
338 | 358 |
|
| 359 | + # Set up a file system watcher to detect changes in the agents directory. |
| 360 | + observer = Observer() |
| 361 | + if reload_agents: |
| 362 | + event_handler = AgentChangeEventHandler(agent_loader) |
| 363 | + observer.schedule(event_handler, agents_dir, recursive=True) |
| 364 | + observer.start() |
| 365 | + |
339 | 366 | @app.get("/list-apps")
|
340 | 367 | def list_apps() -> list[str]:
|
341 | 368 | base_path = Path.cwd() / agents_dir
|
@@ -390,6 +417,9 @@ async def get_session(
|
390 | 417 | )
|
391 | 418 | if not session:
|
392 | 419 | raise HTTPException(status_code=404, detail="Session not found")
|
| 420 | + |
| 421 | + global _app_name |
| 422 | + _app_name = app_name |
393 | 423 | return session
|
394 | 424 |
|
395 | 425 | @app.get(
|
@@ -947,6 +977,11 @@ async def process_messages():
|
947 | 977 |
|
948 | 978 | async def _get_runner_async(app_name: str) -> Runner:
|
949 | 979 | """Returns the runner for the given app."""
|
| 980 | + if app_name in _runners_to_clean: |
| 981 | + _runners_to_clean.remove(app_name) |
| 982 | + runner = runner_dict.pop(app_name, None) |
| 983 | + await cleanup.close_runners(list([runner])) |
| 984 | + |
950 | 985 | envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
|
951 | 986 | if app_name in runner_dict:
|
952 | 987 | return runner_dict[app_name]
|
|
0 commit comments