Skip to content

Commit ec609ea

Browse files
authored
Merge pull request #11 from DataArcTech/main
merge main
2 parents 3946d22 + 306fcf3 commit ec609ea

49 files changed

Lines changed: 3887 additions & 1257 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

examples/TCL_rag/config.yaml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
llm:
2+
name: openai
3+
base_url: "https://api.gptsapi.net/v1"
4+
api_key: "sk-2T06b7c7f9c3870049fbf8fada596b0f8ef908d1e233KLY2"
5+
model: "gpt-4.1-mini"
6+
7+
embedding:
8+
name: huggingface
9+
model_name: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B"
10+
model_kwargs:
11+
device: "cuda:0"
12+
13+
14+
15+
store:
16+
name: faiss
17+
folder_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store
18+
19+
20+
bm25:
21+
name: bm25
22+
k: 10
23+
data_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json
24+
25+
retriever:
26+
name: vectorstore
27+
28+
reranker:
29+
name: qwen3
30+
model_name_or_path: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_reranker_0.6B"
31+
device_id: "cuda:0"
32+
33+
dataset:
34+
name: TCL
35+

examples/TCL_rag/rag_flow.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import sys
2+
import os
3+
4+
# 添加 RAG-Factory 目录到 Python 路径
5+
rag_factory_path = os.path.join(os.path.dirname(__file__), "..", "..")
6+
sys.path.insert(0, rag_factory_path)
7+
8+
from rag_factory.llms import LLMRegistry
9+
from rag_factory.Embed import EmbeddingRegistry
10+
from rag_factory.Store import VectorStoreRegistry
11+
from rag_factory.Retrieval import RetrieverRegistry
12+
from rag_factory.rerankers import RerankerRegistry
13+
from rag_factory.Retrieval import Document
14+
from typing import List
15+
import json
16+
17+
18+
class TCL_RAG:
19+
def __init__(
20+
self,
21+
*,
22+
llm_config=None,
23+
embedding_config=None,
24+
vector_store_config=None,
25+
bm25_retriever_config=None,
26+
retriever_config=None,
27+
reranker_config=None,
28+
):
29+
llm_config = llm_config or {}
30+
embedding_config = embedding_config or {}
31+
vector_store_config = vector_store_config or {}
32+
bm25_retriever_config = bm25_retriever_config or {}
33+
retriever_config = retriever_config or {}
34+
reranker_config = reranker_config or {}
35+
self.llm = LLMRegistry.create(**llm_config)
36+
self.embedding = EmbeddingRegistry.create(**embedding_config)
37+
self.vector_store = VectorStoreRegistry.load(**vector_store_config, embedding=self.embedding)
38+
self.bm25_retriever = RetrieverRegistry.create(**bm25_retriever_config)
39+
self.bm25_retriever = self.bm25_retriever.from_documents(documents=self._load_data(bm25_retriever_config["data_path"]), preprocess_func=self.chinese_preprocessing_func, k=bm25_retriever_config["k"])
40+
41+
self.retriever = RetrieverRegistry.create(**retriever_config, vectorstore=self.vector_store)
42+
self.multi_path_retriever = RetrieverRegistry.create("multipath", retrievers=[self.bm25_retriever, self.retriever])
43+
self.reranker = RerankerRegistry.create(**reranker_config)
44+
45+
def invoke(self, query: str, k: int = None):
46+
return self.multi_path_retriever.invoke(query, top_k=k)
47+
48+
def rerank(self, query: str, documents: List[Document], k: int = None, batch_size: int = 8):
49+
return self.reranker.rerank(query, documents, k, batch_size)
50+
51+
def _load_data(self, data_path: str):
52+
with open(data_path, "r", encoding="utf-8") as f:
53+
data = json.load(f)
54+
docs = []
55+
for item in data:
56+
content = item.get("full_content", "")
57+
metadata = {"title": item.get("original_filename", "")}
58+
docs.append(Document(content=content, metadata=metadata))
59+
return docs
60+
61+
def chinese_preprocessing_func(self, text: str) -> str:
62+
import jieba
63+
return " ".join(jieba.cut(text))
64+
65+
66+
def answer(self, query: str, documents: List[Document]):
67+
68+
template = (
69+
"你是一位工业领域的专家。根据以下检索到的材料回答用户问题。"
70+
"如果回答所需信息未在材料中出现,请说明无法找到相关信息。\n\n"
71+
"{context}\n\n"
72+
"用户问题:{question}\n"
73+
"答复:"
74+
)
75+
context = "\n".join([doc.content for doc in documents])
76+
prompt = template.format(question=query, context=context)
77+
messages = [
78+
{"role": "system", "content": "你是一位工业领域的专家。"},
79+
{"role": "user", "content": prompt}
80+
]
81+
return self.llm.chat(messages)
82+
83+
84+
85+

examples/TCL_rag/test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from rag_flow import TCL_RAG
2+
import yaml
3+
4+
# 加载配置文件
5+
with open('/data/FinAi_Mapping_Knowledge/chenmingzhen/RAG-Factory/examples/TCL_rag/config.yaml', 'r', encoding='utf-8') as f:
6+
config = yaml.safe_load(f)
7+
8+
llm_config = config['llm']
9+
embedding_config = config['embedding']
10+
reranker_config = config['reranker']
11+
bm25_retriever_config = config['bm25']
12+
retriever_config = config['retriever']
13+
vector_store_config = config['store']
14+
15+
16+
17+
18+
if __name__ == "__main__":
19+
20+
rag = TCL_RAG(llm_config=llm_config,
21+
embedding_config=embedding_config,
22+
reranker_config=reranker_config,
23+
retriever_config=retriever_config,
24+
vector_store_config=vector_store_config,
25+
bm25_retriever_config=bm25_retriever_config)
26+
27+
result = rag.invoke("毛细管设计规范按照什么标准",k=20)
28+
29+
answer = rag.answer("毛细管设计规范按照什么标准",result)
30+
31+
32+
print(answer)

examples/bm25/config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
retriever:
2+
name: bm25
3+
k: 8

examples/bm25/main.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import sys
2+
import os
3+
4+
rag_factory_path = os.path.join(os.path.dirname(__file__), "..", "..")
5+
sys.path.insert(0, rag_factory_path)
6+
7+
import json
8+
from rag_factory.Retrieval import Document
9+
from rag_factory.Retrieval import RetrieverRegistry
10+
11+
import yaml
12+
13+
14+
def load_data(jsonl_path: str):
15+
with open(jsonl_path, "r", encoding="utf-8") as f:
16+
data = json.load(f)
17+
docs = []
18+
for item in data:
19+
content = item.get("full_content", "")
20+
metadata = {"title": item.get("original_title", "")}
21+
docs.append(Document(content=content, metadata=metadata))
22+
return docs
23+
24+
def chinese_preprocessing_func(text: str) -> str:
25+
import jieba
26+
return " ".join(jieba.cut(text))
27+
28+
if __name__ == "__main__":
29+
docs = load_data("/data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json")
30+
with open("/data/FinAi_Mapping_Knowledge/chenmingzhen/RAG-Factory/examples/bm25/config.yaml", "r", encoding="utf-8") as f:
31+
config = yaml.safe_load(f)
32+
33+
bm25_retriever = RetrieverRegistry.create(**config["retriever"])
34+
bm25_retriever = bm25_retriever.from_documents(documents=docs, preprocess_func=chinese_preprocessing_func, k=config["retriever"]["k"])
35+
36+
print(bm25_retriever.invoke("什么是TCL?"))
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
store:
2+
name: faiss # 数据库
3+
folder_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store # 保存路径
4+
5+
6+
embedding:
7+
name: huggingface # 嵌入模型
8+
model_name: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B" # 模型路径
9+
model_kwargs:
10+
device: "cuda:1" # 设备
11+
12+
dataset:
13+
name: TCL
14+
data_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import sys
2+
import os
3+
4+
# 添加 RAG-Factory 目录到 Python 路径
5+
rag_factory_path = os.path.join(os.path.dirname(__file__), "..", "..")
6+
sys.path.insert(0, rag_factory_path)
7+
8+
from rag_factory.Store import VectorStoreRegistry
9+
from rag_factory.Embed import EmbeddingRegistry
10+
import yaml
11+
from rag_factory.Retrieval import Document
12+
import json
13+
14+
15+
with open("/data/FinAi_Mapping_Knowledge/chenmingzhen/RAG-Factory/examples/faiss_construct/config.yaml", "r", encoding="utf-8") as f:
16+
config = yaml.safe_load(f)
17+
18+
store_config = config["store"]
19+
embedding_config = config["embedding"]
20+
dataset_config = config["dataset"]["data_path"]
21+
embedding = EmbeddingRegistry.create(**embedding_config)
22+
store = VectorStoreRegistry.create(**store_config, embedding=embedding)
23+
24+
25+
if __name__ == "__main__":
26+
27+
# 读取数据
28+
with open(dataset_config, "r", encoding="utf-8") as f:
29+
docs = []
30+
data = json.load(f)
31+
for item in data:
32+
full_content = item.get("full_content", "")
33+
metadata = {
34+
"title": item.get("original_filename"),
35+
}
36+
37+
docs.append(Document(content=full_content, metadata=metadata))
38+
39+
# 创建向量库
40+
vectorstore = store.from_documents(docs, embedding=embedding)
41+
42+
# 保存到本地
43+
vectorstore.save_local(store_config["folder_path"])

rag_factory/Embed/Embedding_Base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from dataclasses import dataclass
33
import asyncio
44
from concurrent.futures import ThreadPoolExecutor
5+
from typing import List
56

67
class Embeddings(ABC):
78
"""嵌入接口"""
89

910
@abstractmethod
10-
def embed_documents(self, texts: list[str]) -> list[list[float]]:
11+
def embed_documents(self, texts: List[str]) -> List[List[float]]:
1112
"""Embed search docs.
1213
1314
Args:
@@ -19,7 +20,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
1920
pass
2021

2122
@abstractmethod
22-
def embed_query(self, text: str) -> list[float]:
23+
def embed_query(self, text: str) -> List[float]:
2324
"""Embed query text.
2425
2526
Args:
@@ -30,7 +31,7 @@ def embed_query(self, text: str) -> list[float]:
3031
"""
3132
pass
3233

33-
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
34+
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
3435
"""Asynchronous Embed search docs.
3536
3637
Args:
@@ -43,7 +44,7 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
4344
ThreadPoolExecutor(), self.embed_documents, texts
4445
)
4546

46-
async def aembed_query(self, text: str) -> list[float]:
47+
async def aembed_query(self, text: str) -> List[float]:
4748
"""Asynchronous Embed query text.
4849
4950
Args:

rag_factory/Embed/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .Embedding_Base import Embeddings
22
from .Embedding_Huggingface import HuggingFaceEmbeddings
3+
from .registry import EmbeddingRegistry
34

4-
__all__ = ["Embeddings", "HuggingFaceEmbeddings"]
5+
__all__ = ["Embeddings", "HuggingFaceEmbeddings", "EmbeddingRegistry"]

rag_factory/Embed/registry.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from typing import Dict, Type, Any, Optional, List
2+
import logging
3+
from .Embedding_Huggingface import HuggingFaceEmbeddings
4+
from .Embedding_Base import Embeddings
5+
6+
class EmbeddingRegistry:
7+
"""嵌入模型注册器,用于管理和创建不同类型的嵌入模型"""
8+
_embeddings: Dict[str, Type[Embeddings]] = {}
9+
10+
@classmethod
11+
def register(cls, name: str, embedding_class: Type[Embeddings]):
12+
"""注册嵌入模型类
13+
14+
Args:
15+
name: 模型名称
16+
embedding_class: 嵌入模型类
17+
"""
18+
cls._embeddings[name] = embedding_class
19+
20+
@classmethod
21+
def create(cls, name: str, **kwargs) -> Embeddings:
22+
"""获取嵌入模型实例
23+
24+
Args:
25+
name: 模型名称
26+
**kwargs: 模型初始化参数
27+
28+
Returns:
29+
嵌入模型实例
30+
31+
Raises:
32+
ValueError: 当模型名称不存在时
33+
"""
34+
if name not in cls._embeddings:
35+
available_embeddings = list(cls._embeddings.keys())
36+
raise ValueError(f"嵌入模型 '{name}' 未注册。可用的模型: {available_embeddings}")
37+
38+
embedding_class = cls._embeddings[name]
39+
return embedding_class(**kwargs)
40+
41+
@classmethod
42+
def list_embeddings(cls) -> List[str]:
43+
"""列出所有已注册的嵌入模型名称
44+
45+
Returns:
46+
已注册的模型名称列表
47+
"""
48+
return list(cls._embeddings.keys())
49+
50+
@classmethod
51+
def is_registered(cls, name: str) -> bool:
52+
"""检查模型是否已注册
53+
54+
Args:
55+
name: 模型名称
56+
57+
Returns:
58+
如果已注册返回True,否则返回False
59+
"""
60+
return name in cls._embeddings
61+
62+
@classmethod
63+
def unregister(cls, name: str) -> bool:
64+
"""取消注册模型
65+
66+
Args:
67+
name: 模型名称
68+
69+
Returns:
70+
成功取消注册返回True,模型不存在返回False
71+
"""
72+
if name in cls._embeddings:
73+
del cls._embeddings[name]
74+
return True
75+
return False
76+
77+
78+
# 注册默认的嵌入模型
79+
EmbeddingRegistry.register("huggingface", HuggingFaceEmbeddings)

0 commit comments

Comments
 (0)