Skip to content
Open
271 changes: 202 additions & 69 deletions src/google/adk/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from pathlib import Path
from typing import Optional
from typing import Union
import asyncio
import sys
import threading
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer

import click
from google.genai import types
Expand All @@ -42,6 +47,21 @@
from .utils.service_factory import create_session_service_from_options


class DevModeChangeHandler(FileSystemEventHandler):
"""Watchdog event handler to trigger agent reload upon file changes."""
WATCHED_EXTENSIONS = ('.py', '.yaml')

def __init__(self, loop: asyncio.AbstractEventLoop, reload_event: asyncio.Event):
super().__init__()
self.loop = loop
self.reload_event = reload_event

def on_any_event(self, event):
if event.is_directory:
return
if event.src_path.endswith(self.WATCHED_EXTENSIONS) or getattr(event, 'dest_path', '').endswith(self.WATCHED_EXTENSIONS):
self.loop.call_soon_threadsafe(self.reload_event.set)

class InputFile(BaseModel):
state: dict[str, object]
queries: list[str]
Expand Down Expand Up @@ -98,6 +118,10 @@ async def run_interactively(
session_service: BaseSessionService,
credential_service: BaseCredentialService,
memory_service: Optional[BaseMemoryService] = None,
dev: bool = False,
reload_event: Optional[asyncio.Event] = None,
agent_loader: Optional[AgentLoader] = None,
agent_folder_name: Optional[str] = None,
) -> None:
app = (
root_agent_or_app
Expand All @@ -111,11 +135,88 @@ async def run_interactively(
memory_service=memory_service,
credential_service=credential_service,
)

if dev:
loop = asyncio.get_running_loop()
input_queue = asyncio.Queue()
_EOF_SENTINEL = object()

def _prompt_user(new_line: bool = False):
prompt = '\n[user]: ' if new_line else '[user]: '
sys.stdout.write(prompt)
sys.stdout.flush()

async def _handle_reload():
nonlocal runner
click.secho('\nChanges detected, reloading agent...', fg='yellow')
if not (agent_loader and agent_folder_name):
return
try:
agent_loader.remove_agent_from_cache(agent_folder_name)
new_agent_or_app = agent_loader.load_agent(agent_folder_name)
reloaded_app = (
new_agent_or_app
if isinstance(new_agent_or_app, App)
else App(name=session.app_name, root_agent=new_agent_or_app)
)
new_runner = Runner(
app=reloaded_app,
artifact_service=artifact_service,
session_service=session_service,
memory_service=memory_service,
credential_service=credential_service,
)
await runner.close()
runner = new_runner
except Exception as e:
click.secho(f'Error reloading agent: {e}', fg='red')

def _read_input():
try:
while True:
line = sys.stdin.readline()
if not line: break
loop.call_soon_threadsafe(input_queue.put_nowait, line)
except Exception:
import traceback
print("[ERROR] Exception in stdin reader thread:", file=sys.stderr)
traceback.print_exc(file=sys.stderr)
finally:
if not loop.is_closed():
loop.call_soon_threadsafe(input_queue.put_nowait, _EOF_SENTINEL)

threading.Thread(target=_read_input, daemon=True).start()
_prompt_user()

while True:
query = input('[user]: ')
if not dev or reload_event is None:
query = input('[user]: ')
else:
input_task = asyncio.create_task(input_queue.get())
reload_task = asyncio.create_task(reload_event.wait())
done, pending = await asyncio.wait(
[input_task, reload_task], return_when=asyncio.FIRST_COMPLETED
)

if reload_task in done:
input_task.cancel()
reload_event.clear()

await _handle_reload()

_prompt_user(new_line=True)
continue
else:
reload_task.cancel()
query = input_task.result()
if query is _EOF_SENTINEL:
break

if not query or not query.strip():
if dev:
_prompt_user()
continue
if query == 'exit':
if query.strip() == 'exit':
break
async with Aclosing(
runner.run_async(
Expand All @@ -130,6 +231,10 @@ async def run_interactively(
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')

if dev:
_prompt_user(new_line=True)

await runner.close()


Expand All @@ -141,6 +246,7 @@ async def run_cli(
saved_session_file: Optional[str] = None,
save_session: bool,
session_id: Optional[str] = None,
dev: bool = False,
session_service_uri: Optional[str] = None,
artifact_service_uri: Optional[str] = None,
memory_service_uri: Optional[str] = None,
Expand Down Expand Up @@ -203,6 +309,20 @@ async def run_cli(

credential_service = InMemoryCredentialService()

observer = None
reload_event = None
if dev:
loop = asyncio.get_running_loop()
reload_event = asyncio.Event()
event_handler = DevModeChangeHandler(loop, reload_event)
observer = Observer()
if not agent_root.is_dir():
raise RuntimeError(f"Agent root directory not found or is not a directory: {agent_root}")
watch_path = str(agent_root)
observer.schedule(event_handler, path=watch_path, recursive=True)
observer.start()
click.secho(f"Auto-reload enabled - watching for file changes in {agent_folder_name}...", fg="green")

# Helper function for printing events
def _print_event(event) -> None:
content = event.content
Expand All @@ -214,70 +334,83 @@ def _print_event(event) -> None:
author = event.author or 'system'
click.echo(f'[{author}]: {"".join(text_parts)}')

if input_file:
session = await run_input_file(
app_name=session_app_name,
user_id=user_id,
agent_or_app=agent_or_app,
artifact_service=artifact_service,
session_service=session_service,
memory_service=memory_service,
credential_service=credential_service,
input_path=input_file,
)
elif saved_session_file:
# Load the saved session from file
with open(saved_session_file, 'r', encoding='utf-8') as f:
loaded_session = Session.model_validate_json(f.read())

# Create a new session in the service, copying state from the file
session = await session_service.create_session(
app_name=session_app_name,
user_id=user_id,
state=loaded_session.state if loaded_session else None,
)

# Append events from the file to the new session and display them
if loaded_session:
for event in loaded_session.events:
await session_service.append_event(session, event)
_print_event(event)

await run_interactively(
agent_or_app,
artifact_service,
session,
session_service,
credential_service,
memory_service=memory_service,
)
else:
session = await session_service.create_session(
app_name=session_app_name, user_id=user_id
)
click.echo(f'Running agent {agent_or_app.name}, type exit to exit.')
await run_interactively(
agent_or_app,
artifact_service,
session,
session_service,
credential_service,
memory_service=memory_service,
)

if save_session:
session_id = session_id or input('Session ID to save: ')
session_path = agent_root / f'{session_id}.session.json'

# Fetch the session again to get all the details.
session = await session_service.get_session(
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,
)
session_path.write_text(
session.model_dump_json(indent=2, exclude_none=True, by_alias=True),
encoding='utf-8',
)

print('Session saved to', session_path)
try:
if input_file:
session = await run_input_file(
app_name=session_app_name,
user_id=user_id,
agent_or_app=agent_or_app,
artifact_service=artifact_service,
session_service=session_service,
memory_service=memory_service,
credential_service=credential_service,
input_path=input_file,
)
elif saved_session_file:
# Load the saved session from file
with open(saved_session_file, 'r', encoding='utf-8') as f:
loaded_session = Session.model_validate_json(f.read())

# Create a new session in the service, copying state from the file
session = await session_service.create_session(
app_name=session_app_name,
user_id=user_id,
state=loaded_session.state if loaded_session else None,
)

# Append events from the file to the new session and display them
if loaded_session:
for event in loaded_session.events:
await session_service.append_event(session, event)
_print_event(event)

await run_interactively(
agent_or_app,
artifact_service,
session,
session_service,
credential_service,
memory_service=memory_service,
dev=dev,
reload_event=reload_event,
agent_loader=agent_loader,
agent_folder_name=agent_folder_name,
)
else:
session = await session_service.create_session(
app_name=session_app_name, user_id=user_id
)
click.echo(f'Running agent {agent_or_app.name}, type exit to exit.')
await run_interactively(
agent_or_app,
artifact_service,
session,
session_service,
credential_service,
memory_service=memory_service,
dev=dev,
reload_event=reload_event,
agent_loader=agent_loader,
agent_folder_name=agent_folder_name,
)

if save_session:
session_id = session_id or input('Session ID to save: ')
session_path = agent_root / f'{session_id}.session.json'

# Fetch the session again to get all the details.
session = await session_service.get_session(
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,
)
session_path.write_text(
session.model_dump_json(indent=2, exclude_none=True, by_alias=True),
encoding='utf-8',
)

print('Session saved to', session_path)
finally:
if observer:
observer.stop()
observer.join()
8 changes: 8 additions & 0 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,12 @@ def wrapper(*args, **kwargs):
),
callback=validate_exclusive,
)
@click.option(
"--dev",
is_flag=True,
default=False,
help="Optional. Enable development mode with automatic agent reloading on file changes.",
)
@click.argument(
"agent",
type=click.Path(
Expand All @@ -630,6 +636,7 @@ def cli_run(
session_id: Optional[str],
replay: Optional[str],
resume: Optional[str],
dev: bool = False,
session_service_uri: Optional[str] = None,
artifact_service_uri: Optional[str] = None,
memory_service_uri: Optional[str] = None,
Expand All @@ -656,6 +663,7 @@ def cli_run(
saved_session_file=resume,
save_session=save_session,
session_id=session_id,
dev=dev,
session_service_uri=session_service_uri,
artifact_service_uri=artifact_service_uri,
memory_service_uri=memory_service_uri,
Expand Down
Loading