
Commit 4b5ec22

add token experiment
1 parent f3829ad commit 4b5ec22

8 files changed: +1717 -113 lines

bergson/__main__.py

Lines changed: 2 additions & 2 deletions
@@ -36,9 +36,9 @@ class Build:
 
     def execute(self):
         """Build the gradient dataset."""
-        if not self.cfg.save_index and not self.cfg.save_processor:
+        if not self.cfg.save_index and not self.cfg.save_processor and not self.cfg.create_custom_query:
            raise ValueError(
-                "At least one of save_index or save_processor must be True"
+                "At least one of save_index, save_processor, or create_custom_query must be True"
            )
 
        build_gradient_dataset(self.cfg)
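The new guard means a build must now produce at least one artifact: an index, a processor checkpoint, or a custom query. A minimal sketch of the same check in isolation; the `Cfg` stand-in and `validate` helper are hypothetical, but the field names come from the diff above:

from dataclasses import dataclass


@dataclass
class Cfg:
    # Hypothetical stand-in for bergson's IndexConfig, using the
    # field names that appear in the diff above.
    save_index: bool = False
    save_processor: bool = False
    create_custom_query: bool = False


def validate(cfg: Cfg) -> None:
    # Mirrors the guard in Build.execute: at least one output is required.
    if not (cfg.save_index or cfg.save_processor or cfg.create_custom_query):
        raise ValueError(
            "At least one of save_index, save_processor, or create_custom_query must be True"
        )


validate(Cfg(create_custom_query=True))  # passes
# validate(Cfg())  # would raise ValueError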

bergson/attributor.py

Lines changed: 154 additions & 17 deletions
@@ -43,36 +43,52 @@ def __init__(
         dtype: torch.dtype = torch.float32,
         unit_norm: bool = False,
         faiss_cfg: FaissConfig | None = None,
+        processor: GradientProcessor | None = None,
+        unstructured: bool = False,
     ):
         self.device = device
         self.dtype = dtype
         self.unit_norm = unit_norm
         self.faiss_index = None
 
         # Load the gradient processor
-        self.processor = GradientProcessor(projection_dim=16)
-        # self.processor = GradientProcessor.load(index_path, map_location=device)
+        self.processor = processor or GradientProcessor.load(
+            index_path, map_location=device
+        )
 
         # Load the gradient index
         if faiss_cfg:
             self.faiss_index = FaissIndex(index_path, faiss_cfg, device, unit_norm)
             self.N = self.faiss_index.ntotal
         else:
             mmap = load_gradients(index_path)
-
-            # Copy gradients into device memory
-            self.grads = {
-                name: torch.tensor(mmap[name], device=device, dtype=dtype)
-                for name in mmap.dtype.names
-            }
             self.N = mmap[mmap.dtype.names[0]].shape[0]
 
-            if unit_norm:
-                norm = torch.cat([grad for grad in self.grads.values()], dim=1).norm(
-                    dim=1, keepdim=True
-                )
-                for name in self.grads:
-                    self.grads[name] /= norm
+            # Copy gradients into device memory
+            if unstructured:
+                import numpy as np
+                from numpy.lib.recfunctions import structured_to_unstructured
+
+                # Flatten the structured memmap into a single [N, D] array
+                mmap = structured_to_unstructured(mmap).astype(np.float16)
+                print("Number of elements:", mmap.shape[0] * mmap.shape[1])
+                print(mmap.dtype)
+                print(
+                    f"RAM required assuming float32: "
+                    f"{mmap.shape[0] * mmap.shape[1] * 4 / 1024**3:.2f} GB"
+                )
+                self.grads = torch.from_numpy(mmap)
+
+                if unit_norm:
+                    norm = self.grads.norm(dim=1, keepdim=True) + torch.finfo(dtype).eps
+                    self.grads /= norm
+            else:
+                self.grads = {
+                    name: torch.tensor(mmap[name], device=device, dtype=dtype)
+                    for name in mmap.dtype.names
+                }
+
+                if unit_norm:
+                    norm = torch.cat(list(self.grads.values()), dim=1).norm(
+                        dim=1, keepdim=True
+                    )
+                    for name in self.grads:
+                        self.grads[name] /= norm
 
     def search(
         self,
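For context on the new `unstructured` path: `numpy.lib.recfunctions.structured_to_unstructured` flattens a structured record array, one field per module, into a single 2-D array, which is what lets the whole index live in one tensor and be sliced batch-wise. A standalone sketch with toy field names (the on-disk layout here is an assumption, not bergson's actual format):

import numpy as np
import torch
from numpy.lib.recfunctions import structured_to_unstructured

# Toy structured array: one subarray field per module, as in the gradient memmap
mmap = np.zeros(4, dtype=[("mlp", np.float32, 8), ("attn", np.float32, 8)])

# Concatenate the fields along the last axis into a plain [N, D] array
flat = structured_to_unstructured(mmap).astype(np.float16)  # shape [4, 16]
grads = torch.from_numpy(flat)  # zero-copy view over the NumPy buffer
print(grads.shape, grads.dtype)  # torch.Size([4, 16]) torch.float16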
@@ -124,9 +140,113 @@ def search(
 
         return torch.topk(scores, k)
 
+    def score(
+        self,
+        queries: dict[str, Tensor],
+        batch_size: int = 1024,
+        onload_device: str = "cuda",
+    ):
+        """
+        Score every example in the index against the query or queries,
+        onloading shards of the gradient index into VRAM batch by batch.
+
+        Args:
+            queries: One query tensor of shape [..., d] per module.
+            batch_size: Number of index rows to onload per step.
+            onload_device: Device on which each batch is scored.
+
+        Returns:
+            A tensor of inner products of shape [N, num_queries].
+        """
+        assert not self.faiss_index, "FAISS index does not implement onloaded search."
+
+        q = {name: item.to(self.device, self.dtype) for name, item in queries.items()}
+
+        if self.unit_norm:
+            norm = torch.cat(list(q.values()), dim=1).norm(dim=1, keepdim=True)
+            for name in q:
+                q[name] /= norm + 1e-8
+
+        modules = list(q.keys())
+        q_tensor = torch.cat([q[name] for name in modules], dim=1).to(onload_device)
+
+        scores = torch.zeros(
+            self.N, q_tensor.shape[0], device=self.device, dtype=self.dtype
+        )
+
+        for i in range(0, self.N, batch_size):
+            batch = self.grads[i : i + batch_size].to(onload_device, self.dtype)
+            batch_scores = batch @ q_tensor.mT
+            scores[i : i + batch_size] = batch_scores.to(self.device)
+
+        return scores
+
+    @contextmanager
+    def trace_score(
+        self,
+        module: nn.Module,
+        k: int | None,
+        *,
+        precondition: bool = False,
+        target_modules: set[str] | None = None,
+    ):
+        """
+        Like `trace`, but on exit scores the collected gradients against the
+        whole index and stores them under `result["scores"]`.
+        """
+        mod_grads = defaultdict(list)
+        result = {}
+
+        def callback(name: str, g: Tensor, indices: list[int]):
+            # Precondition the gradient using the eigendecomposition of the
+            # preconditioner: P = V diag(1 / sqrt(clamp(eigval, 0))) V^T
+            if precondition:
+                eigval, eigvec = self.processor.preconditioners_eigen[name]
+
+                eigval_clamped = torch.clamp(eigval.to(torch.float64), min=0.0)
+                eigval_inverse_sqrt = 1.0 / (
+                    eigval_clamped.sqrt() + torch.finfo(torch.float64).eps
+                )
+
+                P = (
+                    eigvec.to(eigval_inverse_sqrt.dtype)
+                    * eigval_inverse_sqrt
+                    @ eigvec.mT.to(eigval_inverse_sqrt.dtype)
+                )
+                g = g.flatten(1).type_as(P)
+                assert not P.isnan().any().item(), "P is nan"
+                assert not g.isnan().any().item(), "g is nan"
+                g = g @ P
+            else:
+                g = g.flatten(1)
+
+            # Store the gradient for later use
+            mod_grads[name].append(g.to(self.device, self.dtype, non_blocking=True))
+
+        with GradientCollector(module, callback, self.processor, target_modules):
+            yield result
+
+        if not mod_grads:
+            raise ValueError("No grads collected. Did you forget to call backward?")
+
+        queries = {name: torch.cat(g, dim=1) for name, g in mod_grads.items()}
+
+        if any(q.isnan().any() for q in queries.values()):
+            raise ValueError("NaN found in queries.")
+
+        result["scores"] = self.score(queries)
+
     @contextmanager
     def trace(
-        self, module: nn.Module, k: int | None, *, precondition: bool = False, target_modules: set[str] | None = None
+        self,
+        module: nn.Module,
+        k: int | None,
+        *,
+        precondition: bool = False,
+        target_modules: set[str] | None = None,
+        score: bool = False,
     ) -> Generator[TraceResult, None, None]:
         """
         Context manager to trace the gradients of a module and return the
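A hedged usage sketch of the new `trace_score` context manager, inferred from the code above; `attributor`, `model`, and `batch` are assumed to exist, and the result dict is populated only after the `with` block exits:

# Sketch only: attributor = Attributor(index_path, ...) is assumed.
with attributor.trace_score(model, k=None, precondition=True) as result:
    loss = model(**batch).loss
    loss.backward()

# On exit the collected per-module gradients are concatenated and scored
# against every row of the index via Attributor.score.
scores = result["scores"]    # shape [N, num_queries]
top = scores[:, 0].topk(10)  # e.g. the 10 highest-scoring examples for query 0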
@@ -139,9 +259,26 @@ def callback(name: str, g: Tensor, indices: list[int]):
-            # Precondition the gradient using Cholesky solve
+            # Precondition the gradient using the eigendecomposition of the
+            # preconditioner, clamping eigenvalues to guard against NaNs
             if precondition:
                 eigval, eigvec = self.processor.preconditioners_eigen[name]
-                eigval_inverse_sqrt = 1.0 / (eigval).sqrt()
-                P = eigvec * eigval_inverse_sqrt @ eigvec.mT
+                eigval_clamped = torch.clamp(eigval.to(torch.float64), min=0.0)
+                eigval_inverse_sqrt = 1.0 / (
+                    eigval_clamped.sqrt() + torch.finfo(torch.float64).eps
+                )
+
+                P = (
+                    eigvec.to(eigval_inverse_sqrt.dtype)
+                    * eigval_inverse_sqrt
+                    @ eigvec.mT.to(eigval_inverse_sqrt.dtype)
+                )
                 g = g.flatten(1).type_as(P)
+                assert not P.isnan().any().item(), "P is nan"
+                assert not g.isnan().any().item(), "g is nan"
                 g = g @ P
             else:
                 g = g.flatten(1)
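The preconditioner built in these callbacks is P = V diag(1 / sqrt(max(eigval, 0))) V^T from the eigendecomposition of the second-moment matrix, with the clamp and epsilon guarding against slightly negative or zero eigenvalues from numerical noise. A self-contained check of the construction on a random PSD matrix (illustration only, not bergson code):

import torch

torch.manual_seed(0)
A = torch.randn(16, 16, dtype=torch.float64)
M = A @ A.mT + 1e-3 * torch.eye(16, dtype=torch.float64)  # a well-conditioned PSD matrix

eigval, eigvec = torch.linalg.eigh(M)
eigval_clamped = torch.clamp(eigval, min=0.0)
inv_sqrt = 1.0 / (eigval_clamped.sqrt() + torch.finfo(torch.float64).eps)

# Same broadcasting trick as the diff: (V * s) @ V^T == V diag(s) V^T,
# since * and @ share precedence and associate left to right.
P = eigvec * inv_sqrt @ eigvec.mT

# P M P should be close to the identity when M is well conditioned
print(torch.dist(P @ M @ P, torch.eye(16, dtype=torch.float64)))  # ~1e-12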

bergson/build.py

Lines changed: 13 additions & 9 deletions
@@ -168,7 +168,7 @@ def worker(
         save_index=cfg.save_index,
         save_processor=cfg.save_processor,
         drop_columns=cfg.drop_columns,
-        create_custom_query=cfg.in_memory_index,
+        create_custom_query=cfg.create_custom_query,
         module_wise=cfg.module_wise,
         token_batch_size=cfg.token_batch_size,
     )
@@ -197,7 +197,7 @@ def flush():
         # Save a processor state checkpoint after each shard
         save_processor=cfg.save_processor,
         drop_columns=cfg.drop_columns,
-        create_custom_query=cfg.in_memory_index,
+        create_custom_query=cfg.create_custom_query,
         module_wise=cfg.module_wise,
         token_batch_size=cfg.token_batch_size,
     )
@@ -240,15 +240,19 @@ def build_gradient_dataset(cfg: IndexConfig):
     tokenizer.model_max_length = min(tokenizer.model_max_length, cfg.token_batch_size)
 
     # Do all the data loading and preprocessing on the main process
-    ds = load_data_string(cfg.data.dataset, cfg.data.split, streaming=cfg.streaming)
+    if cfg.data.subset:
+        ds = load_data_string(
+            cfg.data.dataset, cfg.data.split, cfg.data.subset, streaming=cfg.streaming
+        )
+    else:
+        ds = load_data_string(cfg.data.dataset, cfg.data.split, streaming=cfg.streaming)
 
     remove_columns = ds.column_names if cfg.drop_columns else None
-    ds = ds.map(
-        tokenize,
-        batched=True,
-        fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer),
-        remove_columns=remove_columns,
-    )
+    if not cfg.skip_tokenization:
+        ds = ds.map(
+            tokenize,
+            batched=True,
+            fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer),
+            remove_columns=remove_columns,
+        )
     if cfg.data.reward_column:
         assert isinstance(ds, Dataset), "Dataset required for advantage estimation"
         ds = ds.add_column(
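For reference, the now-conditional tokenization step is a standard Hugging Face `datasets` batched map. A minimal standalone sketch; the toy dataset and the `tokenize` stand-in are assumptions (bergson's helper also takes an `args` kwarg):

from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
ds = Dataset.from_dict({"text": ["hello world", "token experiment"]})


def tokenize(batch, tokenizer):
    # Stand-in for bergson's tokenize helper: emits token ids per example.
    return tokenizer(batch["text"], truncation=True)


ds = ds.map(
    tokenize,
    batched=True,
    fn_kwargs=dict(tokenizer=tokenizer),
    remove_columns=ds.column_names,  # mirrors drop_columns=True
)
print(ds.column_names)  # ['input_ids', 'attention_mask']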
