-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathcache.py
More file actions
238 lines (209 loc) · 8.65 KB
/
cache.py
File metadata and controls
238 lines (209 loc) · 8.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
"""
缓存管理器 - 基于文件 digest + LRU 淘汰
支持分块下载的断点续传缓存
"""
import hashlib
import os
import shutil
import sqlite3
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Dict, List, Optional
class CacheManager:
    """File cache keyed by URL digest, with LRU eviction.

    Two SQLite tables live alongside the cached files in ``cache_dir``:

    * ``files``  -- one row per fully cached file (LRU bookkeeping).
    * ``chunks`` -- per-URL byte-range bookkeeping so interrupted
      chunked downloads can be resumed.
    """

    # SQLite's CURRENT_TIMESTAMP stores UTC as "YYYY-MM-DD HH:MM:SS".
    # TTL cutoffs generated in Python are compared against those values
    # as strings, so they must use the same timezone AND the same layout.
    _TS_FORMAT = '%Y-%m-%d %H:%M:%S'

    def __init__(self, cache_dir: str, max_size_gb: float = 100):
        """Create *cache_dir* (and the metadata DB inside it) if missing.

        Args:
            cache_dir: directory holding cached files plus ``meta.db``.
            max_size_gb: soft cap on total cached bytes; exceeding it
                triggers LRU eviction on the next ``put()``.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.max_bytes = int(max_size_gb * 1024**3)
        self.db_path = self.cache_dir / "meta.db"
        self._init_db()

    def _connect(self) -> sqlite3.Connection:
        """Open a connection to the metadata database."""
        return sqlite3.connect(self.db_path)

    @classmethod
    def _utc_cutoff(cls, ttl_hours: int) -> str:
        """Return the UTC expiry cutoff formatted like CURRENT_TIMESTAMP.

        Bug fix: the original compared a *local*, 'T'-separated
        ``datetime.isoformat()`` string against SQLite's UTC
        "YYYY-MM-DD HH:MM:SS" values, skewing the TTL by the local UTC
        offset and mis-ordering same-date timestamps.
        """
        cutoff = datetime.now(timezone.utc) - timedelta(hours=ttl_hours)
        return cutoff.strftime(cls._TS_FORMAT)

    def _init_db(self):
        """Create the metadata tables and indexes (idempotent)."""
        conn = self._connect()
        try:
            c = conn.cursor()
            # Fully cached files, keyed by URL digest.
            c.execute('''
            CREATE TABLE IF NOT EXISTS files (
                digest TEXT PRIMARY KEY,
                filepath TEXT NOT NULL,
                size INTEGER NOT NULL,
                accessed TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                created TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            ''')
            c.execute('CREATE INDEX IF NOT EXISTS idx_accessed ON files(accessed)')
            # Byte-range bookkeeping for resumable chunked downloads.
            c.execute('''
            CREATE TABLE IF NOT EXISTS chunks (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                url TEXT NOT NULL,
                total_size INTEGER NOT NULL,
                chunk_start INTEGER NOT NULL,
                chunk_end INTEGER NOT NULL,
                downloaded INTEGER DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                UNIQUE(url, chunk_start, chunk_end)
            )
            ''')
            c.execute('CREATE INDEX IF NOT EXISTS idx_chunks_url ON chunks(url)')
            c.execute('CREATE INDEX IF NOT EXISTS idx_chunks_updated ON chunks(updated_at)')
            conn.commit()
        finally:
            conn.close()

    def _get_digest(self, url: str, content_type: str = "") -> str:
        """Cache key: sha256 of the URL.

        ``content_type`` is deliberately ignored so it cannot affect
        cache hits (kept in the signature for caller compatibility).
        """
        return hashlib.sha256(url.encode()).hexdigest()

    def get(self, url: str, content_type: str = "") -> Optional[str]:
        """Return the cached file path for *url*, or None on a miss.

        A hit refreshes the LRU access timestamp.  If the metadata row
        exists but the file was removed out-of-band, the stale row is
        deleted (the original left it behind, inflating the size total
        used for eviction) and None is returned.
        """
        digest = self._get_digest(url, content_type)
        conn = self._connect()
        try:
            c = conn.cursor()
            c.execute('SELECT filepath FROM files WHERE digest=?', (digest,))
            row = c.fetchone()
            if row is None:
                return None
            filepath = row[0]
            if os.path.exists(filepath):
                c.execute('UPDATE files SET accessed=CURRENT_TIMESTAMP WHERE digest=?', (digest,))
                conn.commit()
                return filepath
            # File vanished on disk: drop the stale metadata row.
            c.execute('DELETE FROM files WHERE digest=?', (digest,))
            conn.commit()
            return None
        finally:
            conn.close()

    def put(self, url: str, filepath: str, content_type: str = "") -> str:
        """Store *filepath* in the cache for *url*; return its digest.

        Uses a hard link when possible (no extra disk space), falling
        back to a copy across filesystems.  NOTE(review): if the target
        file already exists it is kept as-is, so re-putting changed
        content for the same URL does not refresh the cached bytes --
        confirm this immutability is intended.
        """
        size = os.path.getsize(filepath)
        digest = self._get_digest(url, content_type)
        target = self.cache_dir / digest
        if not os.path.exists(target):
            try:
                os.link(filepath, target)  # hard link saves space
            except OSError:
                # Hard links fail across filesystems; copy instead.
                shutil.copy2(filepath, target)
        conn = self._connect()
        try:
            conn.execute('''
            INSERT OR REPLACE INTO files (digest, filepath, size, accessed, created)
            VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
            ''', (digest, str(target), size))
            conn.commit()
        finally:
            conn.close()
        # Enforce the size cap via LRU eviction.
        self._evict_if_needed()
        return digest

    def _evict_if_needed(self):
        """Evict least-recently-accessed files until under max_bytes."""
        conn = self._connect()
        try:
            c = conn.cursor()
            c.execute('SELECT SUM(size) FROM files')
            total = c.fetchone()[0] or 0
            while total > self.max_bytes:
                # Oldest access time first == LRU victim.
                c.execute('''
                SELECT digest, filepath, size FROM files
                ORDER BY accessed ASC LIMIT 1
                ''')
                row = c.fetchone()
                if not row:
                    break
                digest, filepath, size = row
                if os.path.exists(filepath):
                    os.remove(filepath)
                c.execute('DELETE FROM files WHERE digest=?', (digest,))
                total -= size
            conn.commit()
        finally:
            conn.close()

    def get_stats(self) -> dict:
        """Return aggregate cache statistics (count, sizes, timestamps)."""
        conn = self._connect()
        try:
            c = conn.cursor()
            c.execute('SELECT COUNT(*), SUM(size), MIN(created), MAX(accessed) FROM files')
            count, size, first, last = c.fetchone()
        finally:
            conn.close()
        return {
            "count": count or 0,
            "size_bytes": size or 0,
            "size_gb": (size or 0) / 1024**3,
            "first_cached": first,
            "last_accessed": last
        }

    # ========== chunked-download bookkeeping ==========
    def get_downloaded_chunks(self, url: str, total_size: int, chunk_ttl_hours: int = 48) -> List[Dict]:
        """List completed chunks for *url*, ordered by chunk_start.

        Only chunks whose recorded ``total_size`` matches and whose
        ``updated_at`` falls within the TTL window are returned.
        """
        conn = self._connect()
        try:
            c = conn.cursor()
            c.execute('''
            SELECT chunk_start, chunk_end, downloaded, updated_at FROM chunks
            WHERE url = ? AND total_size = ? AND updated_at > ? AND downloaded = 1
            ORDER BY chunk_start
            ''', (url, total_size, self._utc_cutoff(chunk_ttl_hours)))
            rows = c.fetchall()
        finally:
            conn.close()
        return [
            {'start': start, 'end': end, 'downloaded': flag == 1, 'updated_at': ts}
            for start, end, flag, ts in rows
        ]

    def mark_chunks_downloaded(self, url: str, total_size: int, chunks: List[tuple]):
        """Mark several chunks as fully downloaded in one transaction.

        Args:
            chunks: list of (chunk_start, chunk_end) tuples.
        """
        if not chunks:
            return
        conn = self._connect()
        try:
            conn.executemany('''
            INSERT INTO chunks (url, total_size, chunk_start, chunk_end, downloaded, updated_at)
            VALUES (?, ?, ?, ?, 1, CURRENT_TIMESTAMP)
            ON CONFLICT(url, chunk_start, chunk_end) DO UPDATE SET
                downloaded = 1,
                updated_at = CURRENT_TIMESTAMP
            ''', [(url, total_size, start, end) for start, end in chunks])
            conn.commit()
        finally:
            conn.close()

    def mark_chunk_downloaded(self, url: str, total_size: int, chunk_start: int, chunk_end: int):
        """Mark a single chunk as downloaded (legacy single-chunk API)."""
        self.mark_chunks_downloaded(url, total_size, [(chunk_start, chunk_end)])

    def mark_chunk_pending(self, url: str, total_size: int, chunk_start: int, chunk_end: int):
        """Mark a chunk as not-yet-downloaded (used when resuming)."""
        conn = self._connect()
        try:
            conn.execute('''
            INSERT INTO chunks (url, total_size, chunk_start, chunk_end, downloaded, updated_at)
            VALUES (?, ?, ?, ?, 0, CURRENT_TIMESTAMP)
            ON CONFLICT(url, chunk_start, chunk_end) DO UPDATE SET
                downloaded = 0,
                updated_at = CURRENT_TIMESTAMP
            ''', (url, total_size, chunk_start, chunk_end))
            conn.commit()
        finally:
            conn.close()

    def clear_chunks_for_url(self, url: str):
        """Delete every chunk record for *url*."""
        conn = self._connect()
        try:
            conn.execute('DELETE FROM chunks WHERE url = ?', (url,))
            conn.commit()
        finally:
            conn.close()

    def cleanup_expired_chunks(self, chunk_ttl_hours: int = 48):
        """Delete chunk records older than the TTL; return the count."""
        conn = self._connect()
        try:
            c = conn.cursor()
            c.execute('DELETE FROM chunks WHERE updated_at < ?',
                      (self._utc_cutoff(chunk_ttl_hours),))
            deleted = c.rowcount
            conn.commit()
        finally:
            conn.close()
        return deleted