73 changes: 37 additions & 36 deletions models/retriever/faiss_filter.py
@@ -1,6 +1,7 @@
import json
import os
import time
+import hashlib
from collections import defaultdict
from itertools import combinations
from typing import Dict, List, Set, Tuple
@@ -26,10 +27,10 @@ def __init__(self, dataset, graph: nx.MultiDiGraph, model_name: str = "all-MiniL
self.cache_dir = cache_dir
os.makedirs(cache_dir, exist_ok=True)
self.dataset = dataset

-# Create dataset-specific cache directory
-dataset_cache_dir = f"{self.cache_dir}/{self.dataset}"
-os.makedirs(dataset_cache_dir, exist_ok=True)
+# Use an ASCII-safe directory name to avoid Windows + FAISS failures on Chinese (non-ASCII) paths
+safe_suffix = hashlib.md5(str(dataset).encode("utf-8")).hexdigest()[:8]
+self.dataset_cache_dir = os.path.join(self.cache_dir, safe_suffix)
+os.makedirs(self.dataset_cache_dir, exist_ok=True)

self.triple_index = None
self.comm_index = None
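Reviewer note (illustrative, not part of the diff): the new layout maps each dataset to `cache_dir/<first 8 hex chars of md5(dataset)>`. A minimal standalone sketch of that naming scheme, assuming only the logic visible in this hunk:

```python
import hashlib
import os

def dataset_cache_dir(cache_dir: str, dataset: str) -> str:
    """Mirror the ASCII-safe layout added above: cache_dir/<md5(dataset)[:8]>."""
    safe_suffix = hashlib.md5(str(dataset).encode("utf-8")).hexdigest()[:8]
    path = os.path.join(cache_dir, safe_suffix)
    os.makedirs(path, exist_ok=True)
    return path

# A non-ASCII dataset name never reaches FAISS as a path component:
# dataset_cache_dir("cache", "中文数据集") -> "cache/<8 hex chars>"
```

One trade-off of hashing: the mapping is one-way, so recovering which dataset a cache directory belongs to requires recomputing md5 over the known dataset names.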
@@ -551,7 +552,7 @@ def clear_embedding_cache(self, max_cache_size: int = 10000):

def save_embedding_cache(self):
"""Save embedding cache to disk using numpy format to avoid pickle issues"""
cache_path = f"{self.cache_dir}/{self.dataset}/node_embedding_cache.pt"
cache_path = f"{self.dataset_cache_dir}/node_embedding_cache.pt"
try:
if not self.node_embedding_cache:
return False
@@ -601,7 +602,7 @@ def save_embedding_cache(self):

def load_embedding_cache(self):
"""从磁盘加载嵌入缓存"""
cache_path = f"{self.cache_dir}/{self.dataset}/node_embedding_cache.pt"
cache_path = f"{self.dataset_cache_dir}/node_embedding_cache.pt"
if os.path.exists(cache_path):
try:
file_size = os.path.getsize(cache_path)
@@ -787,14 +788,14 @@ def _precompute_node_embeddings(self, batch_size: int = 100, force_recompute: bo
def build_indices(self):
"""Build FAISS Index only if they don't already exist and are consistent with current graph"""
# Check if all indices and embedding files already exist
node_path = f"{self.cache_dir}/{self.dataset}/node.index"
relation_path = f"{self.cache_dir}/{self.dataset}/relation.index"
triple_path = f"{self.cache_dir}/{self.dataset}/triple.index"
comm_path = f"{self.cache_dir}/{self.dataset}/comm.index"
node_embed_path = f"{self.cache_dir}/{self.dataset}/node_embeddings.pt"
relation_embed_path = f"{self.cache_dir}/{self.dataset}/relation_embeddings.pt"
node_map_path = f"{self.cache_dir}/{self.dataset}/node_map.json"
dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt"
node_path = f"{self.dataset_cache_dir}/node.index"
relation_path = f"{self.dataset_cache_dir}/relation.index"
triple_path = f"{self.dataset_cache_dir}/triple.index"
comm_path = f"{self.dataset_cache_dir}/comm.index"
node_embed_path = f"{self.dataset_cache_dir}/node_embeddings.pt"
relation_embed_path = f"{self.dataset_cache_dir}/relation_embeddings.pt"
node_map_path = f"{self.dataset_cache_dir}/node_map.json"
dim_transform_path = f"{self.dataset_cache_dir}/dim_transform.pt"

all_exist = (os.path.exists(node_path) and
os.path.exists(relation_path) and
@@ -883,7 +884,7 @@ def build_indices(self):

def _save_dim_transform(self):
"""Save dimension transform state to disk"""
dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt"
dim_transform_path = f"{self.dataset_cache_dir}/dim_transform.pt"
try:
save_data = {
'model_dim': self.model_dim,
@@ -901,7 +902,7 @@ def build_indices(self):

def _load_dim_transform(self):
"""Load dimension transform state from disk"""
dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt"
dim_transform_path = f"{self.dataset_cache_dir}/dim_transform.pt"
if not os.path.exists(dim_transform_path):
return False

@@ -956,7 +957,7 @@ def _build_node_index(self):
# Store embeddings on CPU to save GPU memory
self.node_embeddings = embeddings.cpu()
# Save as .pt for consistency across the codebase
-torch.save(self.node_embeddings, f"{self.cache_dir}/{self.dataset}/node_embeddings.pt")
+torch.save(self.node_embeddings, f"{self.dataset_cache_dir}/node_embeddings.pt")

# Build FAISS index
embeddings_np = embeddings.cpu().numpy()
@@ -965,9 +966,9 @@ def _build_node_index(self):
faiss.normalize_L2(embeddings_np)
index.add(embeddings_np)

faiss.write_index(index, f"{self.cache_dir}/{self.dataset}/node.index")
faiss.write_index(index, f"{self.dataset_cache_dir}/node.index")
self.node_map = {str(i): n for i, n in enumerate(nodes)}
with open(f"{self.cache_dir}/{self.dataset}/node_map.json", 'w') as f:
with open(f"{self.dataset_cache_dir}/node_map.json", 'w') as f:
json.dump(self.node_map, f)

self.node_index = index
@@ -983,7 +984,7 @@ def _build_relation_index(self):
# Store embeddings on CPU
self.relation_embeddings = embeddings.cpu()
# Save as .pt for consistency across the codebase
-torch.save(self.relation_embeddings, f"{self.cache_dir}/{self.dataset}/relation_embeddings.pt")
+torch.save(self.relation_embeddings, f"{self.dataset_cache_dir}/relation_embeddings.pt")

# Build FAISS index
embeddings_np = embeddings.cpu().numpy()
@@ -992,9 +993,9 @@ def _build_relation_index(self):
faiss.normalize_L2(embeddings_np)
index.add(embeddings_np)

faiss.write_index(index, f"{self.cache_dir}/{self.dataset}/relation.index")
faiss.write_index(index, f"{self.dataset_cache_dir}/relation.index")
self.relation_map = {str(i): r for i, r in enumerate(relations)}
with open(f"{self.cache_dir}/{self.dataset}/relation_map.json", 'w') as f:
with open(f"{self.dataset_cache_dir}/relation_map.json", 'w') as f:
json.dump(self.relation_map, f)

self.relation_index = index
@@ -1014,8 +1015,8 @@ def _build_triple_index(self):
faiss.normalize_L2(embeddings)
index.add(embeddings)

faiss.write_index(index, f"{self.cache_dir}/{self.dataset}/triple.index")
with open(f"{self.cache_dir}/{self.dataset}/triple_map.json", 'w') as f:
faiss.write_index(index, f"{self.dataset_cache_dir}/triple.index")
with open(f"{self.dataset_cache_dir}/triple_map.json", 'w') as f:
json.dump({i: n for i, n in enumerate(triples)}, f)

self.triple_index = index
@@ -1050,21 +1051,21 @@ def _build_community_index(self):
faiss.normalize_L2(embeddings)
index.add(embeddings)

faiss.write_index(index, f"{self.cache_dir}/{self.dataset}/comm.index")
with open(f"{self.cache_dir}/{self.dataset}/comm_map.json", 'w') as f:
faiss.write_index(index, f"{self.dataset_cache_dir}/comm.index")
with open(f"{self.dataset_cache_dir}/comm_map.json", 'w') as f:
json.dump({i: n for i, n in enumerate(valid_communities)}, f)

self.comm_index = index
self.comm_map = {str(i): n for i, n in enumerate(valid_communities)}

def _load_indices(self):
logger.info("Starting _load_indices...")
triple_path = f"{self.cache_dir}/{self.dataset}/triple.index"
comm_path = f"{self.cache_dir}/{self.dataset}/comm.index"
node_path = f"{self.cache_dir}/{self.dataset}/node.index"
relation_path = f"{self.cache_dir}/{self.dataset}/relation.index"
node_embed_path = f"{self.cache_dir}/{self.dataset}/node_embeddings.pt"
relation_embed_path = f"{self.cache_dir}/{self.dataset}/relation_embeddings.pt"
triple_path = f"{self.dataset_cache_dir}/triple.index"
comm_path = f"{self.dataset_cache_dir}/comm.index"
node_path = f"{self.dataset_cache_dir}/node.index"
relation_path = f"{self.dataset_cache_dir}/relation.index"
node_embed_path = f"{self.dataset_cache_dir}/node_embeddings.pt"
relation_embed_path = f"{self.dataset_cache_dir}/relation_embeddings.pt"

logger.debug(f"Checking cache files...")
logger.debug(f"node_path exists: {os.path.exists(node_path)}")
@@ -1077,22 +1078,22 @@ def _load_indices(self):
if os.path.exists(node_path):
logger.debug("Loading node index...")
self.node_index = faiss.read_index(node_path)
with open(f"{self.cache_dir}/{self.dataset}/node_map.json", 'r') as f:
with open(f"{self.dataset_cache_dir}/node_map.json", 'r') as f:
self.node_map = json.load(f)

if os.path.exists(relation_path):
self.relation_index = faiss.read_index(relation_path)
with open(f"{self.cache_dir}/{self.dataset}/relation_map.json", 'r') as f:
with open(f"{self.dataset_cache_dir}/relation_map.json", 'r') as f:
self.relation_map = json.load(f)

if os.path.exists(triple_path):
self.triple_index = faiss.read_index(triple_path)
with open(f"{self.cache_dir}/{self.dataset}/triple_map.json", 'r') as f:
with open(f"{self.dataset_cache_dir}/triple_map.json", 'r') as f:
self.triple_map = json.load(f)

if os.path.exists(comm_path):
self.comm_index = faiss.read_index(comm_path)
with open(f"{self.cache_dir}/{self.dataset}/comm_map.json", 'r') as f:
with open(f"{self.dataset_cache_dir}/comm_map.json", 'r') as f:
self.comm_map = json.load(f)

if os.path.exists(node_embed_path):
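End-of-diff note (hypothetical helper, not in the PR): since `build_indices()` only rebuilds when one of the cached artifacts is missing, a caller could pre-check completeness of a dataset's cache directory. A sketch under that assumption, reusing the file names from the hunks above:

```python
import os

# File names copied from the diff above; the helper itself is illustrative.
EXPECTED_FILES = (
    "node.index", "relation.index", "triple.index", "comm.index",
    "node_embeddings.pt", "relation_embeddings.pt",
    "node_map.json", "dim_transform.pt",
)

def cache_complete(dataset_cache_dir: str) -> bool:
    """True if every artifact build_indices() checks for already exists."""
    return all(os.path.exists(os.path.join(dataset_cache_dir, name))
               for name in EXPECTED_FILES)
```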