Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion config/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,15 @@ class Environment(BaseSettings):
NEO4J_USERNAME: Optional[str] = Field(default=os.getenv("NEO4J_USERNAME"))
NEO4J_PASSWORD: Optional[str] = Field(default=os.getenv("NEO4J_PASSWORD"))
# Bedrock
## Generative Model
BEDROCK_MODEL_ID: Optional[str] = Field(
default=os.getenv("BEDROCK_MODEL_ID", "us.anthropic.claude-3-5-sonnet-20240620-v1:0"))
## Embedding Model
BEDROCK_EMBEDDING_MODEL_ID: Optional[str] = Field(
default=os.getenv("BEDROCK_EMBEDDING_MODEL_ID", "amazon.titan-embed-text-v2:0"))
# MongoDB
MONGO_URI: Optional[str] = Field(default=os.getenv("MONGO_URI", "mongodb://root:pwd@127.0.0.1:27017?authSource=admin"))
MONGO_DB_NAME: Optional[str] = Field(default=os.getenv("MONGO_DB_NAME", "db0"))
MONGO_DB_NAME: Optional[str] = Field(default=os.getenv("MONGO_DB_NAME", "db0"))
# QDrant
QDRANT_URL: Optional[str] = Field(default=os.getenv("QDRANT_URL", "http://localhost:6333"))
QDRANT_API_KEY: Optional[str] = Field(default=os.getenv("QDRANT_API_KEY"))
5 changes: 4 additions & 1 deletion core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langgraph.graph import StateGraph

from server import SocketManager
from vectorstore import QdrantClientManager
from .base import LLMBedRockBase, GraphState
from .graph import GraphAgent
from .manager import ChatManager
Expand Down Expand Up @@ -35,8 +36,10 @@ def __init__(
username=env.NEO4J_USERNAME,
password=env.NEO4J_PASSWORD,
)
self._vectorstore = QdrantClientManager()
self._agent = GraphAgent(
graph=self._graph,
vectorstore=self._vectorstore,
llm=self._llm,
chat_manager=self._chat_manager,
sio=sio,
Expand Down Expand Up @@ -81,7 +84,7 @@ async def invoke(self, question: str, **kwargs) -> str:
self.route_status,
{
"search_graph": "SearchGraph",
"search_vector": "SearchGraph", # TODO: Implement vector search
"search_vector": "SearchVector",
"answer_final": "Answer",
}
)
Expand Down
38 changes: 32 additions & 6 deletions core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from schemas import AgentGraphSubquery, AgentGraphRoute, AgentGraphStart
from server import SocketManager
from vectorstore import QdrantClientManager
from .base import GraphNodesBase
from .manager import ChatManager
from .prompt import (
Expand All @@ -24,12 +25,14 @@ class GraphAgent(GraphNodesBase):
def __init__(
self,
graph: Neo4jGraph,
vectorstore: QdrantClientManager,
llm: ChatBedrock,
chat_manager: ChatManager,
sio: Optional[SocketManager] = None
):
self._chat_manager = chat_manager
self._graph = graph
self._vectorstore = vectorstore
self._llm = llm
self._sio = sio

Expand Down Expand Up @@ -60,10 +63,34 @@ async def search_vector(self, state: dict) -> dict:
:param state:
:return:
"""
# TODO: Implement the vector search logic
print("--SEARCHING VECTOR--")
depth: int = state["depth"]
depth += 1
information_text = 'Buscando vetores no banco de dados...'
await self._emit("agent_updated", {"status": information_text})
documents: list[dict] = state["documents"]
for q in state["subqueries"]:
print(f"Processing query: {q}")
information_text += f"\n- Consultando: {q}"
try:
result = await self._vectorstore.asearch(
query=q,
k=10,
filters=None, # TODO: Implement filters if needed
)
documents.extend([
{
"query": q,
"content": doc.page_content,
"source": doc.metadata['source']
}
for doc in result
])
information_text += " **OK**"
except Exception as e:
print(f"Error during vector search: {e}")
information_text += f"\n- Erro ao buscar vetores: {e}"
await self._emit("agent_updated", {"status": information_text})
return {"documents": state["documents"], "depth": depth}

async def search_graph(self, state: dict) -> dict:
Expand All @@ -76,15 +103,14 @@ async def search_graph(self, state: dict) -> dict:
information_text = 'Buscando relacionamentos em grafos...'
await self._emit("agent_updated", {"status": information_text})
documents: list[dict] = state["documents"]
subqueries: list[dict] = state["subqueries"]
depth: int = state["depth"]
cypher_prompt_copy = CYPHER_PROMPT.model_copy()
cypher_prompt_copy.template = cypher_prompt_copy.template.replace(
"{{chat_history}}", await self._chat_manager.get_history_as_string())
# Search the graph using the LLM
for query in subqueries:
print(f"Processing query: {query}")
information_text += f"\n- Consultando: {query}"
for q in state["subqueries"]:
print(f"Processing query: {q}")
information_text += f"\n- Consultando: {q}"
try:
cypher_chain = GraphCypherQAChain.from_llm(
self._llm,
Expand All @@ -96,7 +122,7 @@ async def search_graph(self, state: dict) -> dict:
allow_dangerous_requests=True,
validate_cypher=True,
)
document = await cypher_chain.ainvoke(query)
document = await cypher_chain.ainvoke(q)
documents.append(document)
information_text += " **OK**"
except CypherSyntaxError as e:
Expand Down
17 changes: 17 additions & 0 deletions core/prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from langchain.prompts import PromptTemplate


ROUTING_CONSTANTS = {
"search_graph": "Consultando os relacionamentos do grafo",
"search_vector": "Consultando o contexto semântico",
Expand Down Expand Up @@ -119,8 +120,24 @@
template="""Você é um assistente jurídico responsável por fornecer respostas claras, objetivas e fundamentadas a perguntas legais.
Utilize exclusivamente as informações do contexto fornecido para elaborar sua resposta.
Caso o contexto esteja vazio ou não contenha detalhes suficientes, informe educadamente que não há informações suficientes para responder à pergunta.
Caso houver as fontes utilizadas, informe-as ao final da resposta.

Informações disponíveis:
{context}""",
input_variables=["context"],
)

EXTRACT_ENTITIES_PROMPT = PromptTemplate(
template="""You are a legal extraction assistant specialized in the Brazilian legal domain.
Your task is to extract structured legal information from text in order to build a Brazilian legal knowledge graph.
Identify legal entities strictly following the user prompt.

You must produce output in JSON format, containing a single JSON object with the keys:
If something is missing, leave it empty or null. Do not guess or hallucinate. Extract precisely.
{entities}. Use only the explicit information in the text.

Extract the following legal entities from the provided text:
{text}
""",
input_variables=["entities", "text"],
)
13 changes: 12 additions & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ services:
volumes:
- neo4j-data:/data

qdrant:
image: qdrant/qdrant:latest
ports:
- "6333:6333" # HTTP API
- "6334:6334" # gRPC API
volumes:
- qdrant-data:/qdrant/storage

mongo:
image: mongo:latest
environment:
Expand All @@ -29,6 +37,7 @@ services:
- .env
environment:
NEO4J_URL: bolt://neo4j:7687
QDRANT_URL: http://qdrant:6333
command: celery -A workers.tasks worker --loglevel=INFO
volumes:
- ./:/app
Expand All @@ -41,6 +50,7 @@ services:
- .env
environment:
NEO4J_URL: bolt://neo4j:7687
QDRANT_URL: http://qdrant:6333
MONGO_URI: mongodb://root:pwd@mongo:27017/?authSource=admin
ports:
- "8000:8000"
Expand All @@ -50,4 +60,5 @@ services:

volumes:
neo4j-data:
mongodb-data:
mongodb-data:
qdrant-data:
29 changes: 27 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import os
import tempfile
from uuid import uuid4

from services import S3Client
from botocore.exceptions import ClientError

from fastapi import FastAPI, Depends, Request
from starlette.staticfiles import StaticFiles
Expand All @@ -8,8 +13,10 @@
KnowledgeUploadSchema, KnowledgeUpdateResponse,
AgentGraphRAGRequest, AgentGraphRAGResponse,
)

from core import AgentGraphRAGBedRock, ChatManager
from server import SocketManager
from workers import aupload_knowledge_base

app = FastAPI(
title="Chat GraphRAG API",
Expand Down Expand Up @@ -104,20 +111,38 @@ async def update_knowledge(upload: KnowledgeUploadSchema = Depends(KnowledgeUplo
:param upload: The uploaded files containing the knowledge base document.
:return: A confirmation message.
"""
from workers import aupload_knowledge_base
if len(upload.files) == 0:
return KnowledgeUpdateResponse(
success=False,
message="No files uploaded."
)
job_ids: list[str] = []
s3_client = S3Client()
for file in upload.files:
if not file.filename.endswith(('.pdf', '.txt', '.md')):
return KnowledgeUpdateResponse(
success=False,
message="Unsupported file type. Only PDF, TXT, and MD files are allowed."
)
job_id = await aupload_knowledge_base(key=file.filename)
ext = file.filename.split('.')[-1]
key = f"knowledge/{file.filename}-{str(uuid4())}.{ext}"
try:
temp = tempfile.NamedTemporaryFile(delete=False)
with open(temp.name, 'wb') as f:
f.write(file.file.read())
s3_client.upload_file(temp.name, key)
os.remove(temp.name)
except ClientError as e:
return KnowledgeUpdateResponse(
success=False,
message=f"Failed to upload file to S3: {e}"
)
except FileNotFoundError as e:
return KnowledgeUpdateResponse(
success=False,
message=f"File not found: {e}"
)
job_id = await aupload_knowledge_base(key=key)
job_ids.append(job_id)
return KnowledgeUpdateResponse(
success=True,
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ requires-python = ">=3.12"
dependencies = [
"amazon-textract-caller>=0.2.4",
"amazon-textract-textractor>=1.9.2",
"boto3>=1.39.3",
"celery[sqs]>=5.5.3",
"fastapi[standard]>=0.115.14",
"langchain>=0.3.26",
Expand All @@ -15,6 +16,7 @@ dependencies = [
"langchain-mongodb>=0.6.2",
"langchain-neo4j>=0.4.0",
"langchain-openai>=0.3.27",
"langchain-qdrant>=0.2.0",
"langgraph>=0.5.1",
"loguru>=0.7.3",
"neo4j>=5.28.1",
Expand Down
3 changes: 2 additions & 1 deletion schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
AgentGraphRAGResponse,
AgentGraphRAGRequest,
)
from .agent_schema import AgentGraphSubquery, AgentGraphRoute, AgentGraphStart
from .agent_schema import AgentGraphSubquery, AgentGraphRoute, AgentGraphStart
from .document_schema import LegalDocumentMetadata
48 changes: 48 additions & 0 deletions schemas/document_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Optional

from pydantic import BaseModel, Field


class LegalDocumentMetadata(BaseModel):
# Document Identification
title: Optional[str] = Field(None, description="Title of the legal document")
type: Optional[str] = Field(None, description="Type of the legal document")
case_number: Optional[str] = Field(None, description="Official case number")
document_number: Optional[str] = Field(None, description="Internal document number")
creation_date: Optional[str] = Field(None, description="Date when the document was created")
filing_date: Optional[str] = Field(None, description="Date when the document was filed")
signature_date: Optional[str] = Field(None, description="Date when the document was signed")
version: Optional[str] = Field(None, description="Version or revision of the document")
place_of_issue: Optional[str] = Field(None, description="Place where the document was issued")

# Parties Involved
plaintiffs: Optional[list[str]] = Field(None, description="List of plaintiffs")
defendants: Optional[list[str]] = Field(None, description="List of defendants")
lawyers: Optional[list[str]] = Field(None, description="List of lawyers involved")
bar_number: Optional[list[str]] = Field(None, description="List of bar registration numbers")
legal_representatives: Optional[list[str]] = Field(None, description="Other legal representatives")
judge_or_rapporteur: Optional[str] = Field(None, description="Name of judge or rapporteur")
third_parties: Optional[list[str]] = Field(None, description="Interested third parties")

# Procedural Data
court: Optional[str] = Field(None, description="Court where the case is processed")
jurisdiction: Optional[str] = Field(None, description="Jurisdiction")
district: Optional[str] = Field(None, description="Judicial district")
adjudicating_body: Optional[str] = Field(None, description="Adjudicating body or chamber")
case_class: Optional[str] = Field(None, description="Class of the legal action")
nature_of_action: Optional[str] = Field(None, description="Nature of the action")
main_subject: Optional[str] = Field(None, description="Main subject of the action")
secondary_subjects: Optional[list[str]] = Field(None, description="Other related subjects")
case_progress: Optional[str] = Field(None, description="Current case progress stage")
case_stage: Optional[str] = Field(None, description="Current stage in process")

# Legal Information
legal_basis: Optional[list[str]] = Field(None, description="Articles, laws or norms cited")
jurisprudence: Optional[list[str]] = Field(None, description="Precedents or case law cited")
legal_thesis: Optional[str] = Field(None, description="Legal thesis argued")
claims: Optional[list[str]] = Field(None, description="Claims requested")
legal_reasoning: Optional[str] = Field(None, description="Legal reasoning or justification")
provisions: Optional[list[str]] = Field(None, description="Provisions applied")
decision: Optional[str] = Field(None, description="Decision content")
case_value: Optional[str] = Field(None, description="Value attributed to the case")
attorney_fees: Optional[str] = Field(None, description="Agreed or court-appointed attorney fees")
1 change: 1 addition & 0 deletions services/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .s3_client import S3Client
44 changes: 44 additions & 0 deletions services/s3_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import boto3
from config import env


class S3Client:
def __init__(self, bucket_name: str = env.S3_BUCKET_NAME):
self._bucket_name = bucket_name
self._session = boto3.Session(
aws_access_key_id=env.AWS_ACCESS_KEY_ID,
aws_secret_access_key=env.AWS_SECRET_ACCESS_KEY,
region_name=env.AWS_REGION,
)
self._client = self._session.client('s3')

def upload_file(self, file_path: str, object_name: str) -> None:
"""
Upload a file to the S3 bucket.
:param file_path: The path to the file to upload.
:param object_name: The name of the object in the S3 bucket.
"""
try:
self._client.upload_file(file_path, self._bucket_name, object_name)
print(f"File {file_path} uploaded to {self._bucket_name}/{object_name}.")
except Exception as e:
print(f"Error uploading file: {e}")

def delete_object(self, file_path: str) -> None:
"""
Delete a file from the S3 bucket.
:param file_path: The path to the file to delete.
"""
try:
self._client.delete_object(Bucket=self._bucket_name, Key=file_path)
print(f"File {file_path} deleted from {self._bucket_name}.")
except Exception as e:
print(f"Error deleting file: {e}")

@property
def bucket_name(self) -> str:
"""
Get the name of the S3 bucket.
:return: The name of the S3 bucket.
"""
return self._bucket_name
Loading
Loading