diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index 1d49f50d79..e46d435bc8 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -18,6 +18,11 @@ from pathlib import Path from typing import Optional from typing import Union +import asyncio +import sys +import threading +from watchdog.events import FileSystemEventHandler +from watchdog.observers import Observer import click from google.genai import types @@ -42,6 +47,21 @@ from .utils.service_factory import create_session_service_from_options +class DevModeChangeHandler(FileSystemEventHandler): + """Watchdog event handler to trigger agent reload upon file changes.""" + WATCHED_EXTENSIONS = ('.py', '.yaml') + + def __init__(self, loop: asyncio.AbstractEventLoop, reload_event: asyncio.Event): + super().__init__() + self.loop = loop + self.reload_event = reload_event + + def on_any_event(self, event): + if event.is_directory: + return + if event.src_path.endswith(self.WATCHED_EXTENSIONS) or getattr(event, 'dest_path', '').endswith(self.WATCHED_EXTENSIONS): + self.loop.call_soon_threadsafe(self.reload_event.set) + class InputFile(BaseModel): state: dict[str, object] queries: list[str] @@ -98,6 +118,10 @@ async def run_interactively( session_service: BaseSessionService, credential_service: BaseCredentialService, memory_service: Optional[BaseMemoryService] = None, + dev: bool = False, + reload_event: Optional[asyncio.Event] = None, + agent_loader: Optional[AgentLoader] = None, + agent_folder_name: Optional[str] = None, ) -> None: app = ( root_agent_or_app @@ -111,11 +135,88 @@ async def run_interactively( memory_service=memory_service, credential_service=credential_service, ) + + if dev: + loop = asyncio.get_running_loop() + input_queue = asyncio.Queue() + _EOF_SENTINEL = object() + + def _prompt_user(new_line: bool = False): + prompt = '\n[user]: ' if new_line else '[user]: ' + sys.stdout.write(prompt) + sys.stdout.flush() + + async def _handle_reload(): + nonlocal runner + click.secho('\nChanges detected, reloading agent...', fg='yellow') + if not (agent_loader and agent_folder_name): + return + try: + agent_loader.remove_agent_from_cache(agent_folder_name) + new_agent_or_app = agent_loader.load_agent(agent_folder_name) + reloaded_app = ( + new_agent_or_app + if isinstance(new_agent_or_app, App) + else App(name=session.app_name, root_agent=new_agent_or_app) + ) + new_runner = Runner( + app=reloaded_app, + artifact_service=artifact_service, + session_service=session_service, + memory_service=memory_service, + credential_service=credential_service, + ) + await runner.close() + runner = new_runner + except Exception as e: + click.secho(f'Error reloading agent: {e}', fg='red') + + def _read_input(): + try: + while True: + line = sys.stdin.readline() + if not line: break + loop.call_soon_threadsafe(input_queue.put_nowait, line) + except Exception: + import traceback + print("[ERROR] Exception in stdin reader thread:", file=sys.stderr) + traceback.print_exc(file=sys.stderr) + finally: + if not loop.is_closed(): + loop.call_soon_threadsafe(input_queue.put_nowait, _EOF_SENTINEL) + + threading.Thread(target=_read_input, daemon=True).start() + _prompt_user() + while True: - query = input('[user]: ') + if not dev or reload_event is None: + query = input('[user]: ') + else: + input_task = asyncio.create_task(input_queue.get()) + reload_task = asyncio.create_task(reload_event.wait()) + done, pending = await asyncio.wait( + [input_task, reload_task], return_when=asyncio.FIRST_COMPLETED + ) + + if reload_task in done: + input_task.cancel() + reload_event.clear() + + await _handle_reload() + + _prompt_user(new_line=True) + continue + else: + reload_task.cancel() + query = input_task.result() + if query is _EOF_SENTINEL: + break + if not query or not query.strip(): + if dev: + _prompt_user() continue - if query == 'exit': + if query.strip() == 'exit': break async with Aclosing( runner.run_async( @@ -130,6 +231,10 @@ async def run_interactively( if event.content and event.content.parts: if text := ''.join(part.text or '' for part in event.content.parts): click.echo(f'[{event.author}]: {text}') + + if dev: + _prompt_user(new_line=True) + await runner.close() @@ -141,6 +246,7 @@ async def run_cli( saved_session_file: Optional[str] = None, save_session: bool, session_id: Optional[str] = None, + dev: bool = False, session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, @@ -203,6 +309,20 @@ async def run_cli( credential_service = InMemoryCredentialService() + observer = None + reload_event = None + if dev: + loop = asyncio.get_running_loop() + reload_event = asyncio.Event() + event_handler = DevModeChangeHandler(loop, reload_event) + observer = Observer() + if not agent_root.is_dir(): + raise RuntimeError(f"Agent root directory not found or is not a directory: {agent_root}") + watch_path = str(agent_root) + observer.schedule(event_handler, path=watch_path, recursive=True) + observer.start() + click.secho(f"Auto-reload enabled - watching for file changes in {agent_folder_name}...", fg="green") + # Helper function for printing events def _print_event(event) -> None: content = event.content @@ -214,70 +334,83 @@ def _print_event(event) -> None: author = event.author or 'system' click.echo(f'[{author}]: {"".join(text_parts)}') - if input_file: - session = await run_input_file( - app_name=session_app_name, - user_id=user_id, - agent_or_app=agent_or_app, - artifact_service=artifact_service, - session_service=session_service, - memory_service=memory_service, - credential_service=credential_service, - input_path=input_file, - ) - elif saved_session_file: - # Load the saved session from file - with open(saved_session_file, 'r', encoding='utf-8') as f: - loaded_session = Session.model_validate_json(f.read()) - - # Create a new session in the service, copying state from the file - session = await session_service.create_session( - app_name=session_app_name, - user_id=user_id, - state=loaded_session.state if loaded_session else None, - ) - - # Append events from the file to the new session and display them - if loaded_session: - for event in loaded_session.events: - await session_service.append_event(session, event) - _print_event(event) - - await run_interactively( - agent_or_app, - artifact_service, - session, - session_service, - credential_service, - memory_service=memory_service, - ) - else: - session = await session_service.create_session( - app_name=session_app_name, user_id=user_id - ) - click.echo(f'Running agent {agent_or_app.name}, type exit to exit.') - await run_interactively( - agent_or_app, - artifact_service, - session, - session_service, - credential_service, - memory_service=memory_service, - ) - - if save_session: - session_id = session_id or input('Session ID to save: ') - session_path = agent_root / f'{session_id}.session.json' - - # Fetch the session again to get all the details. - session = await session_service.get_session( - app_name=session.app_name, - user_id=session.user_id, - session_id=session.id, - ) - session_path.write_text( - session.model_dump_json(indent=2, exclude_none=True, by_alias=True), - encoding='utf-8', - ) - - print('Session saved to', session_path) + try: + if input_file: + session = await run_input_file( + app_name=session_app_name, + user_id=user_id, + agent_or_app=agent_or_app, + artifact_service=artifact_service, + session_service=session_service, + memory_service=memory_service, + credential_service=credential_service, + input_path=input_file, + ) + elif saved_session_file: + # Load the saved session from file + with open(saved_session_file, 'r', encoding='utf-8') as f: + loaded_session = Session.model_validate_json(f.read()) + + # Create a new session in the service, copying state from the file + session = await session_service.create_session( + app_name=session_app_name, + user_id=user_id, + state=loaded_session.state if loaded_session else None, + ) + + # Append events from the file to the new session and display them + if loaded_session: + for event in loaded_session.events: + await session_service.append_event(session, event) + _print_event(event) + + await run_interactively( + agent_or_app, + artifact_service, + session, + session_service, + credential_service, + memory_service=memory_service, + dev=dev, + reload_event=reload_event, + agent_loader=agent_loader, + agent_folder_name=agent_folder_name, + ) + else: + session = await session_service.create_session( + app_name=session_app_name, user_id=user_id + ) + click.echo(f'Running agent {agent_or_app.name}, type exit to exit.') + await run_interactively( + agent_or_app, + artifact_service, + session, + session_service, + credential_service, + memory_service=memory_service, + dev=dev, + reload_event=reload_event, + agent_loader=agent_loader, + agent_folder_name=agent_folder_name, + ) + + if save_session: + session_id = session_id or input('Session ID to save: ') + session_path = agent_root / f'{session_id}.session.json' + + # Fetch the session again to get all the details. + session = await session_service.get_session( + app_name=session.app_name, + user_id=session.user_id, + session_id=session.id, + ) + session_path.write_text( + session.model_dump_json(indent=2, exclude_none=True, by_alias=True), + encoding='utf-8', + ) + + print('Session saved to', session_path) + finally: + if observer: + observer.stop() + observer.join() diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index b817d4b43a..60b5a6dafe 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -618,6 +618,12 @@ def wrapper(*args, **kwargs): ), callback=validate_exclusive, ) +@click.option( + "--dev", + is_flag=True, + default=False, + help="Optional. Enable development mode with automatic agent reloading on file changes.", +) @click.argument( "agent", type=click.Path( @@ -630,6 +636,7 @@ def cli_run( session_id: Optional[str], replay: Optional[str], resume: Optional[str], + dev: bool = False, session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, @@ -656,6 +663,7 @@ def cli_run( saved_session_file=resume, save_session=save_session, session_id=session_id, + dev=dev, session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, memory_service_uri=memory_service_uri, diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index f7df1bf17f..893c674251 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -16,9 +16,12 @@ from __future__ import annotations +import asyncio import json from pathlib import Path +import sys from textwrap import dedent +import time import types from typing import Any from typing import Dict @@ -519,3 +522,66 @@ async def test_run_interactively_whitespace_and_exit( # verify: assistant echoed once with 'echo:hello' assert any("echo:hello" in m for m in echoed) + + +@pytest.mark.asyncio +async def test_run_interactively_dev_reload( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """run_interactively should reload the agent when reload_event is set.""" + + session_service = InMemorySessionService() + sess = await session_service.create_session(app_name="dummy", user_id="u") + artifact_service = InMemoryArtifactService() + credential_service = InMemoryCredentialService() + root_agent = BaseAgent(name="root") + + reload_event = asyncio.Event() + sys_stdin_readline_calls = [] + + main_loop = asyncio.get_running_loop() + + def mock_readline(): + sys_stdin_readline_calls.append(True) + if len(sys_stdin_readline_calls) == 1: + # Return a normal query first + return "hello\n" + elif len(sys_stdin_readline_calls) == 2: + # Sleep a bit to allow the loop to run, then trigger the reload + time.sleep(0.1) + # In tests, we need to set the event thread-safely + main_loop.call_soon_threadsafe(reload_event.set) + time.sleep(0.1) + return "exit\n" + return "exit\n" + + monkeypatch.setattr(sys.stdin, "readline", mock_readline) + + echoed: list[str] = [] + monkeypatch.setattr(click, "echo", lambda msg, **kw: echoed.append(msg)) + monkeypatch.setattr(click, "secho", lambda msg, **kw: echoed.append(msg)) + + class DummyAgentLoader: + removed = False + reloaded = False + def remove_agent_from_cache(self, name): + self.removed = True + + def load_agent(self, name): + self.reloaded = True + return BaseAgent(name="reloaded_root") + + loader = DummyAgentLoader() + + await cli.run_interactively( + root_agent, artifact_service, sess, session_service, credential_service, + dev=True, reload_event=reload_event, agent_loader=loader, agent_folder_name="dummy_folder" + ) + + # Check that the agent handled the first message + assert any("echo:hello" in m for m in echoed) + # Check that the reload message was printed + assert any("reloading agent..." in m for m in echoed) + # Check that the loader cache was manipulated + assert loader.removed is True + assert loader.reloaded is True