Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def invoke(self, question: str, **kwargs) -> str:
}
)
# Compile the workflow
app = workflow.compile()
workflow_graph = workflow.compile()
# Run the workflow
__initial_state = {
"question": question,
Expand All @@ -99,5 +99,5 @@ async def invoke(self, question: str, **kwargs) -> str:
"depth": 0,
"answer": ""
}
result = await app.ainvoke(__initial_state) # type: ignore
result = await workflow_graph.ainvoke(__initial_state) # type: ignore
return result["answer"]
75 changes: 56 additions & 19 deletions workers/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import json

from typing import Optional, cast, Dict, Any, Tuple, List, Union, Type, Literal
from uuid import uuid4
from uuid import uuid4, UUID

from langchain_community.document_loaders import S3FileLoader, AmazonTextractPDFLoader
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import SystemMessage
from langchain_core.messages import SystemMessage, BaseMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMessagePromptTemplate
from langchain_core.runnables import RunnableConfig
Expand All @@ -26,6 +27,12 @@
from services import S3Client
from vectorstore import QdrantClientManager

from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)


examples_: list[dict[str, str]] = [
{
Expand Down Expand Up @@ -616,6 +623,27 @@ def _convert_to_graph_document(
BATCH_SIZE = 10 # Number of documents to process in a batch


class CallBackHandler(BaseCallbackHandler):
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None, # noqa
metadata: Optional[dict[str, Any]] = None, # noqa
**kwargs: Any,
) -> Any:
"""
Handle the start of a chat model run.
:param serialized: The serialized representation of the chat model.
:param messages: The messages being processed.
:param run_id: The unique identifier for the run.
:param parent_run_id: The parent run identifier, if any.
"""
print(f"{run_id}: {messages}")

class LLMGraph(LLMGraphTransformer):
def __init__(self, llm: BaseLanguageModel, prompt: Optional[ChatPromptTemplate] = None):
"""
Expand All @@ -624,6 +652,7 @@ def __init__(self, llm: BaseLanguageModel, prompt: Optional[ChatPromptTemplate]
"""
super().__init__(llm=llm, allowed_nodes=nodes_, allowed_relationships=relationships_, prompt=prompt)

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def process_batch(
self, documents: list[Document], config: Optional[RunnableConfig] = None
) -> list[GraphDocument]:
Expand Down Expand Up @@ -875,6 +904,7 @@ def _get_llm(llm_name: Literal['openai', 'bedrock'] = 'bedrock') -> BaseLanguage
region=env.AWS_REGION,
aws_access_key_id=env.AWS_ACCESS_KEY_ID,
aws_secret_access_key=env.AWS_SECRET_ACCESS_KEY,
max_tokens=2048,
)
raise ValueError("Unsupported LLM type. Use 'openai' or 'bedrock'.")

Expand All @@ -896,8 +926,12 @@ def process(self, key: str):
# Convert the documents to graph documents using LLMGraphTransformer
llm = self._get_llm()
# Create the LLMGraphTransformer with the allowed nodes and relationships
llm_graph = LLMGraph(llm)
graph_documents = llm_graph.process_batch(documents)
llm_graph = LLMGraph(llm, prompt=self._create_unstructured_relationships_prompt(
node_labels=nodes_,
rel_types=relationships_,
))
config = RunnableConfig(callbacks=[CallBackHandler()])
graph_documents = llm_graph.process_batch(documents, config)

# Connect to Neo4j and add the graph documents
graph = Neo4jGraph(url=env.NEO4J_URL, username=env.NEO4J_USERNAME, password=env.NEO4J_PASSWORD)
Expand All @@ -919,21 +953,24 @@ def process(self, key: str):
chain_legal_document = EXTRACT_ENTITIES_PROMPT | structured
# Iterate over the texts and extract metadata
for text in texts:
# Extract metadata from the document
metadata_extraction_result: LegalDocumentMetadata = chain_legal_document.invoke( # type: ignore
{"entities": ", ".join(legal_document_metadata_keys_), "text": text})
# Parse the metadata extraction result
metadata: dict = metadata_extraction_result.model_dump(exclude_none=True)
for k in metadata.keys():
if k in ['document_id', 'document_hash', 'source']: continue
if k in list(metadatas.keys()):
if isinstance(metadatas[k], list):
metadatas[k].extend(metadata[k])
continue
if isinstance(metadatas[k], str):
metadatas[k] += "\n" + metadata[k]
else:
metadatas[k] = metadata[k]
try:
# Extract metadata from the document
metadata_extraction_result: LegalDocumentMetadata = chain_legal_document.invoke( # type: ignore
{"entities": ", ".join(legal_document_metadata_keys_), "text": text})
# Parse the metadata extraction result
metadata: dict = metadata_extraction_result.model_dump(exclude_none=True)
for k in metadata.keys():
if k in ['document_id', 'document_hash', 'source']: continue
if k in list(metadatas.keys()):
if isinstance(metadatas[k], list):
metadatas[k].extend(metadata[k])
continue
if isinstance(metadatas[k], str):
metadatas[k] += "\n" + metadata[k]
else:
metadatas[k] = metadata[k]
except Exception as e_:
print(f"Error extracting metadata from text chunk: {e_}")

# Add the document to the vector database
vectorstore = QdrantClientManager()
Expand Down
Loading