|
import json
import os
import sys
from typing import List

# Make the RAG-Factory package importable when this script is run directly:
# the project root sits two directories above this file. abspath normalizes
# the entry so it is valid regardless of the current working directory.
rag_factory_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "..")
)
sys.path.insert(0, rag_factory_path)

# Project imports must come after the sys.path manipulation above.
from rag_factory.llms import LLMRegistry
from rag_factory.Embed import EmbeddingRegistry
from rag_factory.Store import VectorStoreRegistry
from rag_factory.Retrieval import Document, RetrieverRegistry
from rag_factory.rerankers import RerankerRegistry
| 17 | + |
class TCL_RAG:
    """RAG pipeline combining BM25 and vector retrieval, reranking, and LLM answering.

    Wires together components created through the rag_factory registries:
    an LLM, an embedding model, a vector store, a corpus-backed BM25
    retriever, a dense (vector) retriever, a multi-path retriever fusing
    both, and a reranker.
    """

    def __init__(
        self,
        *,
        llm_config=None,
        embedding_config=None,
        vector_store_config=None,
        bm25_retriever_config=None,
        retriever_config=None,
        reranker_config=None,
    ):
        """Build every pipeline component from the given config dicts.

        Args:
            llm_config: kwargs forwarded to ``LLMRegistry.create``.
            embedding_config: kwargs forwarded to ``EmbeddingRegistry.create``.
            vector_store_config: kwargs forwarded to ``VectorStoreRegistry.load``
                (the embedding model is appended automatically).
            bm25_retriever_config: kwargs forwarded to ``RetrieverRegistry.create``;
                must also contain ``"data_path"`` (a JSON corpus file, see
                ``_load_data``) and ``"k"`` (number of BM25 hits).
            retriever_config: kwargs for the dense retriever (the vector
                store is appended automatically).
            reranker_config: kwargs forwarded to ``RerankerRegistry.create``.

        Raises:
            ValueError: if ``bm25_retriever_config`` lacks ``"data_path"``
                or ``"k"``.
        """
        llm_config = llm_config or {}
        embedding_config = embedding_config or {}
        vector_store_config = vector_store_config or {}
        bm25_retriever_config = bm25_retriever_config or {}
        retriever_config = retriever_config or {}
        reranker_config = reranker_config or {}

        # Fail early with a clear message instead of an opaque KeyError
        # further down (the defaults above make an empty config legal to
        # pass, but these two keys are genuinely required).
        for required in ("data_path", "k"):
            if required not in bm25_retriever_config:
                raise ValueError(
                    f"bm25_retriever_config must include {required!r}"
                )

        self.llm = LLMRegistry.create(**llm_config)
        self.embedding = EmbeddingRegistry.create(**embedding_config)
        self.vector_store = VectorStoreRegistry.load(
            **vector_store_config, embedding=self.embedding
        )
        # The registry-created object is immediately replaced by the
        # corpus-backed retriever built via from_documents.
        self.bm25_retriever = RetrieverRegistry.create(**bm25_retriever_config)
        self.bm25_retriever = self.bm25_retriever.from_documents(
            documents=self._load_data(bm25_retriever_config["data_path"]),
            preprocess_func=self.chinese_preprocessing_func,
            k=bm25_retriever_config["k"],
        )

        self.retriever = RetrieverRegistry.create(
            **retriever_config, vectorstore=self.vector_store
        )
        # Fuse the lexical (BM25) and dense (vector) retrieval paths.
        self.multi_path_retriever = RetrieverRegistry.create(
            "multipath", retrievers=[self.bm25_retriever, self.retriever]
        )
        self.reranker = RerankerRegistry.create(**reranker_config)

    def invoke(self, query: str, k: int = None):
        """Retrieve documents for *query* through the multi-path retriever.

        Args:
            query: the user query string.
            k: number of documents to request (``top_k``); ``None`` defers
                to the retriever's own default.
        """
        return self.multi_path_retriever.invoke(query, top_k=k)

    def rerank(self, query: str, documents: List[Document], k: int = None, batch_size: int = 8):
        """Rerank *documents* against *query* with the configured reranker.

        Args:
            query: the user query string.
            documents: candidate documents to reorder.
            k: how many reranked documents to return; ``None`` defers to
                the reranker's default.
            batch_size: scoring batch size passed through to the reranker.
        """
        return self.reranker.rerank(query, documents, k, batch_size)

    def _load_data(self, data_path: str):
        """Load the BM25 corpus from a JSON file into Document objects.

        The file is expected to be a JSON array of objects carrying
        ``"full_content"`` (document text) and ``"original_filename"``
        (stored as the ``title`` metadata field); missing keys fall back
        to empty strings.
        """
        with open(data_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        docs = []
        for item in data:
            content = item.get("full_content", "")
            metadata = {"title": item.get("original_filename", "")}
            docs.append(Document(content=content, metadata=metadata))
        return docs

    def chinese_preprocessing_func(self, text: str) -> str:
        """Tokenize Chinese *text* with jieba, joining tokens with spaces.

        Used as the BM25 ``preprocess_func``; jieba is imported lazily so
        the module stays importable without it.
        """
        import jieba
        return " ".join(jieba.cut(text))

    def answer(self, query: str, documents: List[Document]):
        """Generate an LLM answer to *query* grounded in *documents*.

        Concatenates the documents' contents into the prompt context and
        sends a system+user chat to the configured LLM. The prompt is in
        Chinese by design (industrial-domain expert persona).
        """
        template = (
            "你是一位工业领域的专家。根据以下检索到的材料回答用户问题。"
            "如果回答所需信息未在材料中出现,请说明无法找到相关信息。\n\n"
            "{context}\n\n"
            "用户问题:{question}\n"
            "答复:"
        )
        context = "\n".join([doc.content for doc in documents])
        prompt = template.format(question=query, context=context)
        messages = [
            {"role": "system", "content": "你是一位工业领域的专家。"},
            {"role": "user", "content": prompt}
        ]
        return self.llm.chat(messages)
| 82 | + |
| 83 | + |
| 84 | + |
| 85 | + |
0 commit comments