From 1d5b503ff647b15e1296546c7a2c33d60254cc0b Mon Sep 17 00:00:00 2001 From: prodesk98 Date: Sun, 20 Jul 2025 19:43:54 -0300 Subject: [PATCH] Enhance agent and knowledge modules: rename workflow variable, add callback handler, and implement error handling for metadata extraction --- core/agent.py | 4 +-- workers/knowledge.py | 75 +++++++++++++++++++++++++++++++++----------- 2 files changed, 58 insertions(+), 21 deletions(-) diff --git a/core/agent.py b/core/agent.py index 8073dbb..98ad053 100644 --- a/core/agent.py +++ b/core/agent.py @@ -89,7 +89,7 @@ async def invoke(self, question: str, **kwargs) -> str: } ) # Compile the workflow - app = workflow.compile() + workflow_graph = workflow.compile() # Run the workflow __initial_state = { "question": question, @@ -99,5 +99,5 @@ async def invoke(self, question: str, **kwargs) -> str: "depth": 0, "answer": "" } - result = await app.ainvoke(__initial_state) # type: ignore + result = await workflow_graph.ainvoke(__initial_state) # type: ignore return result["answer"] diff --git a/workers/knowledge.py b/workers/knowledge.py index 046372e..b675b36 100644 --- a/workers/knowledge.py +++ b/workers/knowledge.py @@ -2,12 +2,13 @@ import json from typing import Optional, cast, Dict, Any, Tuple, List, Union, Type, Literal -from uuid import uuid4 +from uuid import uuid4, UUID from langchain_community.document_loaders import S3FileLoader, AmazonTextractPDFLoader from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship +from langchain_core.callbacks import BaseCallbackHandler from langchain_core.language_models import BaseLanguageModel -from langchain_core.messages import SystemMessage +from langchain_core.messages import SystemMessage, BaseMessage from langchain_core.output_parsers import JsonOutputParser from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMessagePromptTemplate from langchain_core.runnables import RunnableConfig @@ -26,6 +27,12 @@ from services import S3Client from vectorstore import QdrantClientManager +from tenacity import ( + retry, + stop_after_attempt, + wait_random_exponential, +) + examples_: list[dict[str, str]] = [ { @@ -616,6 +623,27 @@ def _convert_to_graph_document( BATCH_SIZE = 10 # Number of documents to process in a batch +class CallBackHandler(BaseCallbackHandler): + def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[BaseMessage]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, # noqa + metadata: Optional[dict[str, Any]] = None, # noqa + **kwargs: Any, + ) -> Any: + """ + Handle the start of a chat model run. + :param serialized: The serialized representation of the chat model. + :param messages: The messages being processed. + :param run_id: The unique identifier for the run. + :param parent_run_id: The parent run identifier, if any. + """ + print(f"{run_id}: {messages}") + class LLMGraph(LLMGraphTransformer): def __init__(self, llm: BaseLanguageModel, prompt: Optional[ChatPromptTemplate] = None): """ @@ -624,6 +652,7 @@ def __init__(self, llm: BaseLanguageModel, prompt: Optional[ChatPromptTemplate] """ super().__init__(llm=llm, allowed_nodes=nodes_, allowed_relationships=relationships_, prompt=prompt) + @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) def process_batch( self, documents: list[Document], config: Optional[RunnableConfig] = None ) -> list[GraphDocument]: @@ -875,6 +904,7 @@ def _get_llm(llm_name: Literal['openai', 'bedrock'] = 'bedrock') -> BaseLanguage region=env.AWS_REGION, aws_access_key_id=env.AWS_ACCESS_KEY_ID, aws_secret_access_key=env.AWS_SECRET_ACCESS_KEY, + max_tokens=2048, ) raise ValueError("Unsupported LLM type. Use 'openai' or 'bedrock'.") @@ -896,8 +926,12 @@ def process(self, key: str): # Convert the documents to graph documents using LLMGraphTransformer llm = self._get_llm() # Create the LLMGraphTransformer with the allowed nodes and relationships - llm_graph = LLMGraph(llm) - graph_documents = llm_graph.process_batch(documents) + llm_graph = LLMGraph(llm, prompt=self._create_unstructured_relationships_prompt( + node_labels=nodes_, + rel_types=relationships_, + )) + config = RunnableConfig(callbacks=[CallBackHandler()]) + graph_documents = llm_graph.process_batch(documents, config) # Connect to Neo4j and add the graph documents graph = Neo4jGraph(url=env.NEO4J_URL, username=env.NEO4J_USERNAME, password=env.NEO4J_PASSWORD) @@ -919,21 +953,24 @@ def process(self, key: str): chain_legal_document = EXTRACT_ENTITIES_PROMPT | structured # Iterate over the texts and extract metadata for text in texts: - # Extract metadata from the document - metadata_extraction_result: LegalDocumentMetadata = chain_legal_document.invoke( # type: ignore - {"entities": ", ".join(legal_document_metadata_keys_), "text": text}) - # Parse the metadata extraction result - metadata: dict = metadata_extraction_result.model_dump(exclude_none=True) - for k in metadata.keys(): - if k in ['document_id', 'document_hash', 'source']: continue - if k in list(metadatas.keys()): - if isinstance(metadatas[k], list): - metadatas[k].extend(metadata[k]) - continue - if isinstance(metadatas[k], str): - metadatas[k] += "\n" + metadata[k] - else: - metadatas[k] = metadata[k] + try: + # Extract metadata from the document + metadata_extraction_result: LegalDocumentMetadata = chain_legal_document.invoke( # type: ignore + {"entities": ", ".join(legal_document_metadata_keys_), "text": text}) + # Parse the metadata extraction result + metadata: dict = metadata_extraction_result.model_dump(exclude_none=True) + for k in metadata.keys(): + if k in ['document_id', 'document_hash', 'source']: continue + if k in list(metadatas.keys()): + if isinstance(metadatas[k], list): + metadatas[k].extend(metadata[k]) + continue + if isinstance(metadatas[k], str): + metadatas[k] += "\n" + metadata[k] + else: + metadatas[k] = metadata[k] + except Exception as e_: + print(f"Error extracting metadata from text chunk: {e_}") # Add the document to the vector database vectorstore = QdrantClientManager()