-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathconditioning_cache.py
More file actions
373 lines (309 loc) · 13.7 KB
/
conditioning_cache.py
File metadata and controls
373 lines (309 loc) · 13.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
"""
Conditioning Cache System for SamplerGridTester
Saves and loads encoded CLIP conditioning to avoid re-encoding prompts
"""
import os
import json
import torch
import hashlib
import base64
import numpy as np
from typing import Dict, Tuple, Optional, Any
class ConditioningCache:
    """
    Manages persistent caching of encoded conditioning tensors.

    Two cache layers are maintained:
      * an in-memory dict for the current session (always enabled), keyed by
        (LoRA config, full prompt text) per prompt type, and
      * an optional JSON disk cache where tensors are stored as
        base64-encoded numpy buffers, keyed by an MD5 of
        CLIP hash + LoRA config + prompt text.
    """

    def __init__(self, cache_dir: str, clip_hash: str = "unknown", enable_disk_cache: bool = True):
        """
        Initialize the conditioning cache.

        Args:
            cache_dir: Directory to store cache files (e.g., benchmarks/session_name/)
            clip_hash: Hash of the CLIP model to ensure compatibility
            enable_disk_cache: Whether to save/load cache to/from disk
        """
        self.cache_dir = cache_dir
        self.clip_hash = clip_hash
        self.cache_file = os.path.join(cache_dir, "conditioning_cache.json")
        self.enable_disk_cache = enable_disk_cache
        # Track current LoRA configuration for cache key generation.
        self.current_lora_config = ""
        # Statistics - MUST be initialized BEFORE _load_from_disk(),
        # which records the number of loaded entries in stats["loads"].
        self.stats = {
            "hits": 0,
            "misses": 0,
            "saves": 0,
            "loads": 0
        }
        # In-memory cache for the current session.  BUGFIX: entries are keyed
        # by (lora_config, prompt_text) — keying by prompt alone returned
        # stale conditioning after set_lora_config() changed the CLIP state
        # (the disk key already included the LoRA config, the memory key did not).
        self.memory_cache = {
            "positive": {},
            "negative": {}
        }
        # Load existing cache from disk (uses self.stats) - only if enabled.
        if self.enable_disk_cache:
            self.disk_cache = self._load_from_disk()
        else:
            self.disk_cache = {
                "version": "1.0",
                "clip_hash": self.clip_hash,
                "entries": {}
            }
            print(f"[CondCache] ℹ️ Disk cache disabled - using memory-only cache for this session")

    def _get_clip_hash(self, clip_model) -> str:
        """
        Generate a hash of the CLIP model state to ensure cache validity.

        Args:
            clip_model: The CLIP model object

        Returns:
            16-character MD5 hash of a cheap model signature, or "unknown"
            if the model cannot be inspected.
        """
        try:
            if hasattr(clip_model, 'state_dict'):
                state_dict = clip_model.state_dict()
                # Just hash the keys and shapes, not full tensors (too slow).
                model_signature = str([(k, tuple(v.shape)) for k, v in list(state_dict.items())[:10]])
            elif hasattr(clip_model, 'cond_stage_model'):
                # Try alternative attribute
                model_signature = str(type(clip_model.cond_stage_model))
            else:
                model_signature = str(type(clip_model))
            return hashlib.md5(model_signature.encode()).hexdigest()[:16]
        except Exception as e:
            print(f"[CondCache] Warning: Could not hash CLIP model: {e}")
            return "unknown"

    def _prompt_key(self, prompt_text: str) -> str:
        """
        Generate a disk-cache key from prompt text + LoRA configuration.

        Args:
            prompt_text: The prompt string

        Returns:
            MD5 hash of the prompt + CLIP hash + LoRA config
        """
        # Include LoRA configuration in the cache key to avoid incorrect cache
        # hits when using the same prompt with different LoRAs or strengths.
        combined = f"{self.clip_hash}:{self.current_lora_config}:{prompt_text}"
        return hashlib.md5(combined.encode()).hexdigest()

    def _memory_key(self, prompt_text: str) -> Tuple[str, str]:
        """Key for the in-memory cache: (current LoRA config, full prompt text)."""
        return (self.current_lora_config, prompt_text)

    def set_lora_config(self, lora_config: str):
        """
        Set the current LoRA configuration for cache key generation.

        This must be called whenever the LoRA configuration changes to ensure
        that cached conditioning is only reused when the CLIP model is in the
        same state (same LoRAs with same strengths).

        Args:
            lora_config: String describing the LoRA configuration, e.g.:
                "None" or "lora1.safetensors:0.8:0.6,lora2.safetensors:1.0:1.0"
        """
        self.current_lora_config = lora_config

    def _conditioning_to_dict(self, conditioning) -> Optional[Dict[str, Any]]:
        """
        Convert conditioning tensor to JSON-serializable format.

        Args:
            conditioning: ComfyUI conditioning format [[tensor, {pooled_output: tensor}]]

        Returns:
            Dictionary with base64-encoded tensors, or None on failure.
        """
        try:
            cond_tensor = conditioning[0][0]  # Main conditioning tensor
            pooled_tensor = conditioning[0][1].get("pooled_output")
            # Convert to numpy and then to base64.
            cond_np = cond_tensor.cpu().numpy()
            result = {
                "cond": {
                    "data": base64.b64encode(cond_np.tobytes()).decode('utf-8'),
                    "shape": list(cond_np.shape),
                    "dtype": str(cond_np.dtype)
                }
            }
            if pooled_tensor is not None:
                pooled_np = pooled_tensor.cpu().numpy()
                result["pooled"] = {
                    "data": base64.b64encode(pooled_np.tobytes()).decode('utf-8'),
                    "shape": list(pooled_np.shape),
                    "dtype": str(pooled_np.dtype)
                }
            return result
        except Exception as e:
            print(f"[CondCache] Error serializing conditioning: {e}")
            return None

    def _dict_to_conditioning(self, data: Dict[str, Any]):
        """
        Convert JSON dict back to conditioning tensor.

        Args:
            data: Dictionary with base64-encoded tensors

        Returns:
            ComfyUI conditioning format [[tensor, {pooled_output: tensor}]],
            or None on failure.
        """
        try:
            def _decode(entry: Dict[str, Any]) -> torch.Tensor:
                """Decode one base64 payload back into a writable torch tensor."""
                raw = base64.b64decode(entry["data"])
                arr = np.frombuffer(raw, dtype=np.dtype(entry["dtype"]))
                # BUGFIX: frombuffer yields a read-only view; copy so that
                # torch.from_numpy gets a writable buffer (avoids the
                # "non-writable tensor" warning and shared immutable memory).
                arr = arr.reshape(tuple(entry["shape"])).copy()
                return torch.from_numpy(arr)

            cond_tensor = _decode(data["cond"])
            pooled_tensor = _decode(data["pooled"]) if "pooled" in data else None
            # Return in ComfyUI format.
            return [[cond_tensor, {"pooled_output": pooled_tensor}]]
        except Exception as e:
            print(f"[CondCache] Error deserializing conditioning: {e}")
            return None

    def _load_from_disk(self) -> Dict[str, Any]:
        """
        Load cache from disk.

        Returns:
            Dictionary containing cached conditioning data; a fresh empty
            cache if the file is missing, unreadable, or was built against
            a different CLIP model.
        """
        empty = {
            "version": "1.0",
            "clip_hash": self.clip_hash,
            "entries": {}
        }
        if not os.path.exists(self.cache_file):
            return empty
        try:
            with open(self.cache_file, "r") as f:
                data = json.load(f)
            # Validate CLIP hash matches; a mismatch means cached
            # conditioning came from a different model state.
            if data.get("clip_hash") != self.clip_hash:
                print(f"[CondCache] ⚠️ CLIP model changed (cache: {data.get('clip_hash')}, current: {self.clip_hash})")
                print(f"[CondCache] 🔄 Invalidating old cache...")
                return empty
            num_entries = len(data.get("entries", {}))
            print(f"[CondCache] 📂 Loaded {num_entries} cached conditioning entries")
            self.stats["loads"] = num_entries
            # Debug: Show a sample of what's in the cache.
            if num_entries > 0:
                sample_keys = list(data.get("entries", {}).keys())[:3]
                print(f"[CondCache] 🔍 Sample cache keys: {sample_keys}")
            return data
        except Exception as e:
            print(f"[CondCache] ⚠️ Error loading cache: {e}")
            import traceback
            traceback.print_exc()
            return empty

    def _save_to_disk(self):
        """Save cache to disk (best-effort; failures are logged, not raised)."""
        try:
            os.makedirs(self.cache_dir, exist_ok=True)
            with open(self.cache_file, "w") as f:
                json.dump(self.disk_cache, f, indent=2)
            num_entries = len(self.disk_cache.get("entries", {}))
            print(f"[CondCache] 💾 Saved {num_entries} conditioning entries to cache")
        except Exception as e:
            print(f"[CondCache] ⚠️ Error saving cache: {e}")

    def get(self, prompt_text: str, prompt_type: str = "positive", debug: bool = False) -> Optional[Any]:
        """
        Get cached conditioning for a prompt.

        Args:
            prompt_text: The prompt string
            prompt_type: "positive" or "negative"
            debug: Enable debug logging

        Returns:
            Conditioning tensor or None if not cached
        """
        # Check memory cache first (keyed by LoRA config + full text).
        mem = self.memory_cache.setdefault(prompt_type, {})
        mem_key = self._memory_key(prompt_text)
        if mem_key in mem:
            self.stats["hits"] += 1
            if debug:
                print(f"[CondCache] ✅ Memory cache hit for: {prompt_text[:50]}...")
            return mem[mem_key]
        # Check disk cache (uses hash as key) - only if enabled.
        if not self.enable_disk_cache:
            self.stats["misses"] += 1
            if debug:
                print(f"[CondCache] ❌ Cache miss (disk cache disabled)")
            return None
        key = self._prompt_key(prompt_text)
        if debug:
            print(f"[CondCache] 🔍 Looking for key: {key}")
            print(f"[CondCache] 🔍 Prompt: {prompt_text[:80]}...")
        if key in self.disk_cache.get("entries", {}):
            entry = self.disk_cache["entries"][key]
            conditioning = self._dict_to_conditioning(entry["conditioning"])
            if conditioning is not None:
                # Store in memory cache for fast access under the same
                # LoRA-aware key used above.
                mem[mem_key] = conditioning
                self.stats["hits"] += 1
                if debug:
                    print(f"[CondCache] ✅ Disk cache hit!")
                return conditioning
        self.stats["misses"] += 1
        if debug:
            print(f"[CondCache] ❌ Cache miss")
        return None

    def set(self, prompt_text: str, conditioning, prompt_type: str = "positive"):
        """
        Cache conditioning for a prompt.

        Args:
            prompt_text: The prompt string
            conditioning: The conditioning tensor
            prompt_type: "positive" or "negative"
        """
        # Store in memory cache (always enabled), keyed by LoRA config + text.
        self.memory_cache.setdefault(prompt_type, {})[self._memory_key(prompt_text)] = conditioning
        # Store in disk cache (only if enabled).
        if not self.enable_disk_cache:
            return
        serialized = self._conditioning_to_dict(conditioning)
        if serialized is not None:
            self.disk_cache.setdefault("entries", {})[self._prompt_key(prompt_text)] = {
                "prompt": prompt_text[:200],  # Truncate for readability
                "type": prompt_type,
                "conditioning": serialized
            }
            self.stats["saves"] += 1

    def save(self):
        """Save the cache to disk (only if disk cache is enabled)."""
        if self.enable_disk_cache:
            self._save_to_disk()
        else:
            print(f"[CondCache] ℹ️ Skipping disk save (disk cache disabled)")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with hit/miss/save/load counts, hit_rate (percent, float),
            and total_cached (disk entries, or memory entries when the disk
            cache is disabled — previously always 0 in memory-only mode).
        """
        total_lookups = self.stats["hits"] + self.stats["misses"]
        hit_rate = (self.stats["hits"] / total_lookups) * 100 if total_lookups > 0 else 0
        if self.enable_disk_cache:
            total_cached = len(self.disk_cache.get("entries", {}))
        else:
            total_cached = sum(len(m) for m in self.memory_cache.values())
        return {
            **self.stats,
            "hit_rate": round(hit_rate, 1),
            "total_cached": total_cached
        }

    def print_stats(self):
        """Print cache statistics."""
        stats = self.get_stats()
        print(f"\n{'='*80}")
        print(f"[CondCache] 📊 CONDITIONING CACHE STATISTICS")
        print(f"{'='*80}")
        print(f"  Hits: {stats['hits']}")
        print(f"  Misses: {stats['misses']}")
        print(f"  Hit Rate: {stats['hit_rate']}%")
        print(f"  Saved: {stats['saves']}")
        print(f"  Total Cached: {stats['total_cached']}")
        print(f"{'='*80}\n")