@@ -43,36 +43,52 @@ def __init__(
         dtype: torch.dtype = torch.float32,
         unit_norm: bool = False,
         faiss_cfg: FaissConfig | None = None,
+        processor: GradientProcessor | None = None,
+        unstructured: bool = False,
     ):
         self.device = device
         self.dtype = dtype
         self.unit_norm = unit_norm
         self.faiss_index = None
 
         # Load the gradient processor
-        self.processor = GradientProcessor(projection_dim=16)
-        # self.processor = GradientProcessor.load(index_path, map_location=device)
+        self.processor = processor or GradientProcessor.load(
+            index_path, map_location=device
+        )
 
         # Load the gradient index
         if faiss_cfg:
             self.faiss_index = FaissIndex(index_path, faiss_cfg, device, unit_norm)
             self.N = self.faiss_index.ntotal
         else:
             mmap = load_gradients(index_path)
-
-            # Copy gradients into device memory
-            self.grads = {
-                name: torch.tensor(mmap[name], device=device, dtype=dtype)
-                for name in mmap.dtype.names
-            }
             self.N = mmap[mmap.dtype.names[0]].shape[0]
 
-            if unit_norm:
-                norm = torch.cat([grad for grad in self.grads.values()], dim=1).norm(
-                    dim=1, keepdim=True
-                )
-                for name in self.grads:
-                    self.grads[name] /= norm
+            # Load the gradients: the unstructured path flattens all modules into a
+            # single [N, D] float16 matrix kept in host memory (streamed to the GPU
+            # in shards by `score`); the structured path copies each module's
+            # gradients into device memory.
+            if unstructured:
+                import numpy as np
+                from numpy.lib.recfunctions import structured_to_unstructured
+
+                mmap = structured_to_unstructured(mmap).astype(np.float16)
+                print(f"Loaded {mmap.shape[0] * mmap.shape[1]:,} float16 elements "
+                      f"({mmap.nbytes / 1024 ** 3:.2f} GB of host RAM)")
+                self.grads = torch.from_numpy(mmap)
+
+                if unit_norm:
+                    norm = self.grads.norm(dim=1, keepdim=True) + torch.finfo(dtype).eps
+                    self.grads /= norm
+            else:
+                self.grads = {
+                    name: torch.tensor(mmap[name], device=device, dtype=dtype)
+                    for name in mmap.dtype.names
+                }
+
+                if unit_norm:
+                    norm = torch.cat([grad for grad in self.grads.values()], dim=1).norm(
+                        dim=1, keepdim=True
+                    )
+                    for name in self.grads:
+                        self.grads[name] /= norm
 
     def search(
         self,
@@ -124,9 +140,113 @@ def search(
 
         return torch.topk(scores, k)
 
+    def score(
+        self,
+        queries: dict[str, Tensor],
+        # modules: list[str] | None = None,
+        batch_size: int = 1024,
+        onload_device: str = "cuda",
+    ):
+        """
+        Compute inner-product scores between the queries and every example in the
+        index, onloading shards of the gradient matrix onto `onload_device` and
+        multiplying them against the queries batch by batch.
+
+        Args:
+            queries: Mapping from module name to a query tensor of shape [..., d].
+            batch_size: Number of index rows to onload per shard.
+            onload_device: Device on which the sharded matrix products are computed.
+
+        Returns:
+            A tensor of shape [N, num_queries] containing the inner product between
+            every indexed example and each query.
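+
+        Example (an illustrative sketch; `index` and `query_grads` are hypothetical
+        names, not part of this API):
+
+            scores = index.score(query_grads, batch_size=2048)
+            top_values, top_indices = scores.topk(10, dim=0)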
+        """
+        assert not self.faiss_index, "FAISS index does not implement onloaded search."
+        assert isinstance(self.grads, Tensor), (
+            "Onloaded scoring requires the index to be loaded with unstructured=True."
+        )
+
+        q = {name: item.to(self.device, self.dtype) for name, item in queries.items()}
+
+        if self.unit_norm:
+            norm = torch.cat(list(q.values()), dim=1).norm(dim=1, keepdim=True)
+            for name in q:
+                q[name] /= norm + 1e-8
+
+        # modules = modules or list(q.keys())
+        modules = list(q.keys())
+
+        q_tensor = torch.cat([q[name] for name in modules], dim=1).to(onload_device)
+        scores = torch.zeros(
+            self.N, q_tensor.shape[0], device=self.device, dtype=self.dtype
+        )
+
+        # Stream shards of the gradient matrix onto the onload device and collect
+        # the inner products back on the index device.
+        for i in range(0, self.N, batch_size):
+            batch = self.grads[i : i + batch_size].to(onload_device, self.dtype)
+            batch_scores = batch @ q_tensor.mT
+            scores[i : i + batch_size] = batch_scores.to(self.device)
+
+        return scores
+
+    @contextmanager
+    def trace_score(
+        self,
+        module: nn.Module,
+        k: int | None,
+        *,
+        precondition: bool = False,
+        target_modules: set[str] | None = None,
+    ):
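+        """
+        Context manager that collects per-module gradients with `GradientCollector`,
+        optionally preconditions them, and on exit fills `result["scores"]` with
+        inner-product scores against every example in the index (see `score`).
+
+        Illustrative usage (a sketch; `attributor`, `model`, and `batch` are
+        placeholder names, not part of this API):
+
+            with attributor.trace_score(model, k=None) as result:
+                model(**batch).loss.backward()
+            scores = result["scores"]
+        """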
+
+        mod_grads = defaultdict(list)
+        result = {}
+
+        def callback(name: str, g: Tensor, indices: list[int]):
+            # Precondition the gradient with the eigendecomposition of the
+            # preconditioner, clamping eigenvalues to avoid NaNs from small
+            # negative values introduced by numerical error.
+            if precondition:
+                eigval, eigvec = self.processor.preconditioners_eigen[name]
+
+                eigval_clamped = torch.clamp(eigval.to(torch.float64), min=0.0)
+                eigval_inverse_sqrt = 1.0 / (
+                    eigval_clamped.sqrt() + torch.finfo(torch.float64).eps
+                )
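+                # A sketch of the underlying algebra (assuming `preconditioners_eigen`
+                # stores C = V diag(eigval) V^T): P = V diag(eigval)^{-1/2} V^T = C^{-1/2},
+                # so `g @ P` applies the inverse square root of the preconditioner to
+                # each flattened gradient row.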
+
+                P = (
+                    eigvec.to(eigval_inverse_sqrt.dtype)
+                    * eigval_inverse_sqrt
+                    @ eigvec.mT.to(eigval_inverse_sqrt.dtype)
+                )
+                g = g.flatten(1).type_as(P)
+                assert not P.isnan().any().item(), "P is nan"
+                assert not g.isnan().any().item(), "g is nan"
+                g = g @ P
+            else:
+                g = g.flatten(1)
+
+            # Store the gradient for later use
+            mod_grads[name].append(g.to(self.device, self.dtype, non_blocking=True))
+
+        with GradientCollector(module, callback, self.processor, target_modules):
+            yield result
+
+        if not mod_grads:
+            raise ValueError("No grads collected. Did you forget to call backward?")
+
+        queries = {name: torch.cat(g, dim=1) for name, g in mod_grads.items()}
+
+        if any(q.isnan().any() for q in queries.values()):
+            raise ValueError("NaN found in queries.")
+
+        result["scores"] = self.score(queries)
+
     @contextmanager
     def trace(
-        self, module: nn.Module, k: int | None, *, precondition: bool = False, target_modules: set[str] | None = None
+        self,
+        module: nn.Module,
+        k: int | None,
+        *,
+        precondition: bool = False,
+        target_modules: set[str] | None = None,
+        score: bool = False,
     ) -> Generator[TraceResult, None, None]:
         """
         Context manager to trace the gradients of a module and return the
@@ -139,9 +259,26 @@ def callback(name: str, g: Tensor, indices: list[int]):
-            # Precondition the gradient using Cholesky solve
+            # Precondition the gradient with the eigendecomposition of the
+            # preconditioner, clamping eigenvalues to avoid NaNs from small
+            # negative values introduced by numerical error.
             if precondition:
                 eigval, eigvec = self.processor.preconditioners_eigen[name]
-                eigval_inverse_sqrt = 1.0 / (eigval).sqrt()
-                P = eigvec * eigval_inverse_sqrt @ eigvec.mT
+
+                eigval_clamped = torch.clamp(eigval.to(torch.float64), min=0.0)
+                eigval_inverse_sqrt = 1.0 / (
+                    eigval_clamped.sqrt() + torch.finfo(torch.float64).eps
+                )
+                P = (
+                    eigvec.to(eigval_inverse_sqrt.dtype)
+                    * eigval_inverse_sqrt
+                    @ eigvec.mT.to(eigval_inverse_sqrt.dtype)
+                )
                 g = g.flatten(1).type_as(P)
+                assert not P.isnan().any().item(), "P is nan"
+                assert not g.isnan().any().item(), "g is nan"
                 g = g @ P
             else:
                 g = g.flatten(1)