Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 31 additions & 20 deletions python/agents/RAG/rag/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
import uuid

import google.auth
from dotenv import load_dotenv
from google.adk.agents import Agent
from google.adk.tools.retrieval.vertex_ai_rag_retrieval import (
Expand All @@ -28,32 +29,42 @@
from .prompts import return_instructions_root

load_dotenv()

_, project_id = google.auth.default()
os.environ.setdefault("GOOGLE_CLOUD_PROJECT", project_id)
os.environ["GOOGLE_CLOUD_LOCATION"] = "global"
os.environ.setdefault("GOOGLE_GENAI_USE_VERTEXAI", "True")

_ = instrument_adk_with_arize()

# Initialize tools list
tools = []

ask_vertex_retrieval = VertexAiRagRetrieval(
name="retrieve_rag_documentation",
description=(
"Use this tool to retrieve documentation and reference materials for the question from the RAG corpus,"
),
rag_resources=[
rag.RagResource(
# please fill in your own rag corpus
# here is a sample rag corpus for testing purpose
# e.g. projects/123/locations/us-central1/ragCorpora/456
rag_corpus=os.environ.get("RAG_CORPUS")
)
],
similarity_top_k=10,
vector_distance_threshold=0.6,
)
# Only add RAG retrieval tool if RAG_CORPUS is configured
rag_corpus = os.environ.get("RAG_CORPUS")
if rag_corpus:
ask_vertex_retrieval = VertexAiRagRetrieval(
name="retrieve_rag_documentation",
description=(
"Use this tool to retrieve documentation and reference materials for the question from the RAG corpus,"
),
rag_resources=[
rag.RagResource(
# please fill in your own rag corpus
# here is a sample rag corpus for testing purpose
# e.g. projects/123/locations/us-central1/ragCorpora/456
rag_corpus=rag_corpus
)
],
similarity_top_k=10,
vector_distance_threshold=0.6,
)
tools.append(ask_vertex_retrieval)

with using_session(session_id=uuid.uuid4()):
root_agent = Agent(
model="gemini-2.0-flash-001",
model="gemini-3-flash-preview",
name="ask_rag_agent",
instruction=return_instructions_root(),
tools=[
ask_vertex_retrieval,
],
tools=tools,
)
Loading