From 9749016926730c384d3d5c2e33b5f6c0fedc304d Mon Sep 17 00:00:00 2001
From: xu <2993985375@qq.com>
Date: Tue, 16 Dec 2025 16:42:43 +0800
Subject: [PATCH] Fix FAISS cache path for non-ASCII dataset names on Windows

---
 models/retriever/faiss_filter.py | 73 ++++++++++++++++----------------
 1 file changed, 37 insertions(+), 36 deletions(-)

diff --git a/models/retriever/faiss_filter.py b/models/retriever/faiss_filter.py
index b0497b67..c919ee88 100644
--- a/models/retriever/faiss_filter.py
+++ b/models/retriever/faiss_filter.py
@@ -1,6 +1,7 @@
 import json
 import os
 import time
+import hashlib
 from collections import defaultdict
 from itertools import combinations
 from typing import Dict, List, Set, Tuple
@@ -26,10 +27,10 @@ def __init__(self, dataset, graph: nx.MultiDiGraph, model_name: str = "all-MiniL
         self.cache_dir = cache_dir
         os.makedirs(cache_dir, exist_ok=True)
         self.dataset = dataset
-
-        # Create dataset-specific cache directory
-        dataset_cache_dir = f"{self.cache_dir}/{self.dataset}"
-        os.makedirs(dataset_cache_dir, exist_ok=True)
+        # Use an ASCII-safe directory name to avoid Windows + FAISS failures on Chinese paths
+        safe_suffix = hashlib.md5(str(dataset).encode("utf-8")).hexdigest()[:8]
+        self.dataset_cache_dir = os.path.join(self.cache_dir, safe_suffix)
+        os.makedirs(self.dataset_cache_dir, exist_ok=True)
 
         self.triple_index = None
         self.comm_index = None
@@ -551,7 +552,7 @@ def clear_embedding_cache(self, max_cache_size: int = 10000):
 
     def save_embedding_cache(self):
         """Save embedding cache to disk using numpy format to avoid pickle issues"""
-        cache_path = f"{self.cache_dir}/{self.dataset}/node_embedding_cache.pt"
+        cache_path = f"{self.dataset_cache_dir}/node_embedding_cache.pt"
         try:
             if not self.node_embedding_cache:
                 return False
@@ -601,7 +602,7 @@ def save_embedding_cache(self):
 
     def load_embedding_cache(self):
         """Load embedding cache from disk"""
-        cache_path = f"{self.cache_dir}/{self.dataset}/node_embedding_cache.pt"
+        cache_path = f"{self.dataset_cache_dir}/node_embedding_cache.pt"
         if os.path.exists(cache_path):
             try:
                 file_size = os.path.getsize(cache_path)
@@ -787,14 +788,14 @@ def _precompute_node_embeddings(self, batch_size: int = 100, force_recompute: bo
     def build_indices(self):
         """Build FAISS Index only if they don't already exist and are consistent with current graph"""
         # Check if all indices and embedding files already exist
-        node_path = f"{self.cache_dir}/{self.dataset}/node.index"
-        relation_path = f"{self.cache_dir}/{self.dataset}/relation.index"
-        triple_path = f"{self.cache_dir}/{self.dataset}/triple.index"
-        comm_path = f"{self.cache_dir}/{self.dataset}/comm.index"
-        node_embed_path = f"{self.cache_dir}/{self.dataset}/node_embeddings.pt"
-        relation_embed_path = f"{self.cache_dir}/{self.dataset}/relation_embeddings.pt"
-        node_map_path = f"{self.cache_dir}/{self.dataset}/node_map.json"
-        dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt"
+        node_path = f"{self.dataset_cache_dir}/node.index"
+        relation_path = f"{self.dataset_cache_dir}/relation.index"
+        triple_path = f"{self.dataset_cache_dir}/triple.index"
+        comm_path = f"{self.dataset_cache_dir}/comm.index"
+        node_embed_path = f"{self.dataset_cache_dir}/node_embeddings.pt"
+        relation_embed_path = f"{self.dataset_cache_dir}/relation_embeddings.pt"
+        node_map_path = f"{self.dataset_cache_dir}/node_map.json"
+        dim_transform_path = f"{self.dataset_cache_dir}/dim_transform.pt"
 
         all_exist = (os.path.exists(node_path) and
                      os.path.exists(relation_path) and
@@ -883,7 +884,7 @@ def build_indices(self):
 
     def _save_dim_transform(self):
         """Save dimension transform state to disk"""
-        dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt"
+        dim_transform_path = f"{self.dataset_cache_dir}/dim_transform.pt"
         try:
             save_data = {
                 'model_dim': self.model_dim,
@@ -901,7 +902,7 @@ def _save_dim_transform(self):
 
     def _load_dim_transform(self):
         """Load dimension transform state from disk"""
-        dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt"
+        dim_transform_path = f"{self.dataset_cache_dir}/dim_transform.pt"
 
         if not os.path.exists(dim_transform_path):
             return False
@@ -956,7 +957,7 @@ def _build_node_index(self):
         # Store embeddings on CPU to save GPU memory
         self.node_embeddings = embeddings.cpu()
         # Save as .pt for consistency across the codebase
-        torch.save(self.node_embeddings, f"{self.cache_dir}/{self.dataset}/node_embeddings.pt")
+        torch.save(self.node_embeddings, f"{self.dataset_cache_dir}/node_embeddings.pt")
 
         # Build FAISS index
         embeddings_np = embeddings.cpu().numpy()
@@ -965,9 +966,9 @@ def _build_node_index(self):
         faiss.normalize_L2(embeddings_np)
         index.add(embeddings_np)
 
-        faiss.write_index(index, f"{self.cache_dir}/{self.dataset}/node.index")
+        faiss.write_index(index, f"{self.dataset_cache_dir}/node.index")
         self.node_map = {str(i): n for i, n in enumerate(nodes)}
-        with open(f"{self.cache_dir}/{self.dataset}/node_map.json", 'w') as f:
+        with open(f"{self.dataset_cache_dir}/node_map.json", 'w') as f:
             json.dump(self.node_map, f)
 
         self.node_index = index
@@ -983,7 +984,7 @@ def _build_relation_index(self):
         # Store embeddings on CPU
         self.relation_embeddings = embeddings.cpu()
         # Save as .pt for consistency across the codebase
-        torch.save(self.relation_embeddings, f"{self.cache_dir}/{self.dataset}/relation_embeddings.pt")
+        torch.save(self.relation_embeddings, f"{self.dataset_cache_dir}/relation_embeddings.pt")
 
         # Build FAISS index
         embeddings_np = embeddings.cpu().numpy()
@@ -992,9 +993,9 @@ def _build_relation_index(self):
         faiss.normalize_L2(embeddings_np)
         index.add(embeddings_np)
 
-        faiss.write_index(index, f"{self.cache_dir}/{self.dataset}/relation.index")
+        faiss.write_index(index, f"{self.dataset_cache_dir}/relation.index")
         self.relation_map = {str(i): r for i, r in enumerate(relations)}
-        with open(f"{self.cache_dir}/{self.dataset}/relation_map.json", 'w') as f:
+        with open(f"{self.dataset_cache_dir}/relation_map.json", 'w') as f:
             json.dump(self.relation_map, f)
 
         self.relation_index = index
@@ -1014,8 +1015,8 @@ def _build_triple_index(self):
         faiss.normalize_L2(embeddings)
         index.add(embeddings)
 
-        faiss.write_index(index, f"{self.cache_dir}/{self.dataset}/triple.index")
-        with open(f"{self.cache_dir}/{self.dataset}/triple_map.json", 'w') as f:
+        faiss.write_index(index, f"{self.dataset_cache_dir}/triple.index")
+        with open(f"{self.dataset_cache_dir}/triple_map.json", 'w') as f:
             json.dump({i: n for i, n in enumerate(triples)}, f)
 
         self.triple_index = index
@@ -1050,8 +1051,8 @@ def _build_community_index(self):
         faiss.normalize_L2(embeddings)
         index.add(embeddings)
 
-        faiss.write_index(index, f"{self.cache_dir}/{self.dataset}/comm.index")
-        with open(f"{self.cache_dir}/{self.dataset}/comm_map.json", 'w') as f:
+        faiss.write_index(index, f"{self.dataset_cache_dir}/comm.index")
+        with open(f"{self.dataset_cache_dir}/comm_map.json", 'w') as f:
             json.dump({i: n for i, n in enumerate(valid_communities)}, f)
 
         self.comm_index = index
@@ -1059,12 +1060,12 @@ def _load_indices(self):
         logger.info("Starting _load_indices...")
-        triple_path = f"{self.cache_dir}/{self.dataset}/triple.index"
-        comm_path = f"{self.cache_dir}/{self.dataset}/comm.index"
-        node_path = f"{self.cache_dir}/{self.dataset}/node.index"
-        relation_path = f"{self.cache_dir}/{self.dataset}/relation.index"
-        node_embed_path = f"{self.cache_dir}/{self.dataset}/node_embeddings.pt"
-        relation_embed_path = f"{self.cache_dir}/{self.dataset}/relation_embeddings.pt"
+        triple_path = f"{self.dataset_cache_dir}/triple.index"
+        comm_path = f"{self.dataset_cache_dir}/comm.index"
+        node_path = f"{self.dataset_cache_dir}/node.index"
+        relation_path = f"{self.dataset_cache_dir}/relation.index"
+        node_embed_path = f"{self.dataset_cache_dir}/node_embeddings.pt"
+        relation_embed_path = f"{self.dataset_cache_dir}/relation_embeddings.pt"
 
         logger.debug(f"Checking cache files...")
         logger.debug(f"node_path exists: {os.path.exists(node_path)}")
@@ -1077,22 +1078,22 @@ def _load_indices(self):
         if os.path.exists(node_path):
             logger.debug("Loading node index...")
             self.node_index = faiss.read_index(node_path)
-            with open(f"{self.cache_dir}/{self.dataset}/node_map.json", 'r') as f:
+            with open(f"{self.dataset_cache_dir}/node_map.json", 'r') as f:
                 self.node_map = json.load(f)
 
         if os.path.exists(relation_path):
             self.relation_index = faiss.read_index(relation_path)
-            with open(f"{self.cache_dir}/{self.dataset}/relation_map.json", 'r') as f:
+            with open(f"{self.dataset_cache_dir}/relation_map.json", 'r') as f:
                 self.relation_map = json.load(f)
 
         if os.path.exists(triple_path):
             self.triple_index = faiss.read_index(triple_path)
-            with open(f"{self.cache_dir}/{self.dataset}/triple_map.json", 'r') as f:
+            with open(f"{self.dataset_cache_dir}/triple_map.json", 'r') as f:
                 self.triple_map = json.load(f)
 
         if os.path.exists(comm_path):
             self.comm_index = faiss.read_index(comm_path)
-            with open(f"{self.cache_dir}/{self.dataset}/comm_map.json", 'r') as f:
+            with open(f"{self.dataset_cache_dir}/comm_map.json", 'r') as f:
                 self.comm_map = json.load(f)
 
         if os.path.exists(node_embed_path):
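
Note on the scheme above: a minimal, self-contained sketch of how the hashed
cache directory is derived, mirroring the patched __init__. The helper name
ascii_safe_cache_dir and the sample dataset names are illustrative only, not
part of the patch:

    import hashlib
    import os

    def ascii_safe_cache_dir(cache_dir: str, dataset: str) -> str:
        # As in the patch: take the first 8 hex digits of the MD5 of the
        # dataset name. MD5 serves as a stable ASCII-only fingerprint here,
        # not as a security measure; 8 hex digits give 2^32 values, so
        # collisions between dataset names are unlikely but not impossible.
        safe_suffix = hashlib.md5(str(dataset).encode("utf-8")).hexdigest()[:8]
        return os.path.join(cache_dir, safe_suffix)

    # The mapping is deterministic, so caches written under the hashed
    # directory are found again on later runs:
    print(ascii_safe_cache_dir("cache", "中文数据集"))  # cache/<8 hex digits>
    print(ascii_safe_cache_dir("cache", "hotpotqa"))    # ASCII names are hashed too

One consequence worth flagging for reviewers: because every dataset name is
hashed, caches previously written under the literal name (e.g. a directory
named after the dataset itself) are no longer found and will be rebuilt once
under the hashed path.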