From faad8d939fee800aeb03a74d10fe069fb8c160c8 Mon Sep 17 00:00:00 2001 From: ZeguanXiao Date: Sat, 27 Sep 2025 12:54:40 +0800 Subject: [PATCH 1/6] fix: random sampling in ForgetRetainDataset --- src/data/unlearn.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/data/unlearn.py b/src/data/unlearn.py index 0cb0bada..1f6db300 100644 --- a/src/data/unlearn.py +++ b/src/data/unlearn.py @@ -33,14 +33,18 @@ def __len__(self): def __getitem__(self, idx): item = {} + g = torch.Generator() + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + seed = int(torch.empty((), dtype=torch.int64).random_().item() + rank) + g.manual_seed(seed) if self.anchor == "forget": item["forget"] = self.forget[idx] if self.retain: - retain_idx = torch.randint(0, len(self.retain), (1,)).item() + retain_idx = torch.randint(0, len(self.retain), (1,), generator=g).item() item["retain"] = self.retain[retain_idx] elif self.anchor == "retain": item["retain"] = self.retain[idx] if self.forget: - forget_idx = torch.randint(0, len(self.forget), (1,)).item() + forget_idx = torch.randint(0, len(self.forget), (1,), generator=g).item() item["forget"] = self.forget[forget_idx] return item From c079574411df2102b4d89bfd10a9919d2b87ec88 Mon Sep 17 00:00:00 2001 From: ZeguanXiao Date: Wed, 1 Oct 2025 13:05:27 +0800 Subject: [PATCH 2/6] feat: add seed parameter for reproducibility in ForgetRetainDataset --- src/data/__init__.py | 4 ++-- src/data/unlearn.py | 8 +++++--- src/train.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/data/__init__.py b/src/data/__init__.py index c24b0b03..93c092e5 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -46,7 +46,7 @@ def get_datasets(dataset_cfgs: Union[Dict, DictConfig], **kwargs): return dataset -def get_data(data_cfg: DictConfig, mode="train", **kwargs): +def get_data(data_cfg: DictConfig, mode="train", seed=0, **kwargs): data = {} data_cfg = dict(data_cfg) anchor = data_cfg.pop("anchor", "forget") @@ -56,7 +56,7 @@ def get_data(data_cfg: DictConfig, mode="train", **kwargs): return data elif mode == "unlearn": unlearn_splits = {k: v for k, v in data.items() if k not in ("eval", "test")} - unlearn_dataset = ForgetRetainDataset(**unlearn_splits, anchor=anchor) + unlearn_dataset = ForgetRetainDataset(**unlearn_splits, anchor=anchor, seed=seed) data["train"] = unlearn_dataset for split in unlearn_splits: data.pop(split) diff --git a/src/data/unlearn.py b/src/data/unlearn.py index 1f6db300..9fd6e7f5 100644 --- a/src/data/unlearn.py +++ b/src/data/unlearn.py @@ -4,17 +4,19 @@ class ForgetRetainDataset(Dataset): # https://github.com/OPTML-Group/SOUL/blob/main/src/dataset/Base.py - def __init__(self, forget, retain, anchor="forget"): + def __init__(self, forget, retain, anchor="forget", seed=0): """Wraps the forget retain dataset into unlearning dataset. Args: forget (Dataset): Forget Dataset retain (Dataset): Retain Dataset anchor (str, optional): Specifies which dataset to anchor while randomly sampling from the other dataset. Defaults to 'forget'. + seed (int, optional): Random seed for reproducibility. Defaults to 0. 
""" self.forget = forget self.retain = retain self.anchor = anchor + self.seed = seed def __len__(self): """Ensures the sampled dataset matches the anchor dataset's length.""" @@ -35,8 +37,8 @@ def __getitem__(self, idx): item = {} g = torch.Generator() rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - seed = int(torch.empty((), dtype=torch.int64).random_().item() + rank) - g.manual_seed(seed) + rank_seed = self.seed + rank + g.manual_seed(rank_seed) if self.anchor == "forget": item["forget"] = self.forget[idx] if self.retain: diff --git a/src/train.py b/src/train.py index a2f81c8d..4e6a0224 100644 --- a/src/train.py +++ b/src/train.py @@ -23,7 +23,7 @@ def main(cfg: DictConfig): # Load Dataset data_cfg = cfg.data data = get_data( - data_cfg, mode=mode, tokenizer=tokenizer, template_args=template_args + data_cfg, mode=mode, tokenizer=tokenizer, template_args=template_args, seed=cfg.trainer.args.seed ) # Load collator From 60e099a02ad8cbcb275b200ba86ca5dd96b94e44 Mon Sep 17 00:00:00 2001 From: ZeguanXiao Date: Thu, 9 Oct 2025 00:10:08 +0800 Subject: [PATCH 3/6] refactor: fix lint --- src/data/__init__.py | 4 +++- src/data/unlearn.py | 8 ++++++-- src/train.py | 6 +++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/data/__init__.py b/src/data/__init__.py index 93c092e5..aa39bcbb 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -56,7 +56,9 @@ def get_data(data_cfg: DictConfig, mode="train", seed=0, **kwargs): return data elif mode == "unlearn": unlearn_splits = {k: v for k, v in data.items() if k not in ("eval", "test")} - unlearn_dataset = ForgetRetainDataset(**unlearn_splits, anchor=anchor, seed=seed) + unlearn_dataset = ForgetRetainDataset( + **unlearn_splits, anchor=anchor, seed=seed + ) data["train"] = unlearn_dataset for split in unlearn_splits: data.pop(split) diff --git a/src/data/unlearn.py b/src/data/unlearn.py index 9fd6e7f5..bbf745a1 100644 --- a/src/data/unlearn.py +++ b/src/data/unlearn.py @@ -42,11 +42,15 @@ def __getitem__(self, idx): if self.anchor == "forget": item["forget"] = self.forget[idx] if self.retain: - retain_idx = torch.randint(0, len(self.retain), (1,), generator=g).item() + retain_idx = torch.randint( + 0, len(self.retain), (1,), generator=g + ).item() item["retain"] = self.retain[retain_idx] elif self.anchor == "retain": item["retain"] = self.retain[idx] if self.forget: - forget_idx = torch.randint(0, len(self.forget), (1,), generator=g).item() + forget_idx = torch.randint( + 0, len(self.forget), (1,), generator=g + ).item() item["forget"] = self.forget[forget_idx] return item diff --git a/src/train.py b/src/train.py index 4e6a0224..5e8f6db5 100644 --- a/src/train.py +++ b/src/train.py @@ -23,7 +23,11 @@ def main(cfg: DictConfig): # Load Dataset data_cfg = cfg.data data = get_data( - data_cfg, mode=mode, tokenizer=tokenizer, template_args=template_args, seed=cfg.trainer.args.seed + data_cfg, + mode=mode, + tokenizer=tokenizer, + template_args=template_args, + seed=cfg.trainer.args.seed, ) # Load collator From 7a8b5fda95bea8b24636fcf7fe3adb782c921071 Mon Sep 17 00:00:00 2001 From: ZeguanXiao Date: Thu, 9 Oct 2025 00:20:30 +0800 Subject: [PATCH 4/6] fix: ensure unique random seed per item in ForgetRetainDataset --- src/data/unlearn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data/unlearn.py b/src/data/unlearn.py index bbf745a1..dff81be7 100644 --- a/src/data/unlearn.py +++ b/src/data/unlearn.py @@ -37,7 +37,7 @@ def __getitem__(self, idx): item = {} g = 
torch.Generator() rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - rank_seed = self.seed + rank + rank_seed = self.seed + rank + idx g.manual_seed(rank_seed) if self.anchor == "forget": item["forget"] = self.forget[idx] From 19a6141d8baf4b1d25c3a888dd876a764d85def7 Mon Sep 17 00:00:00 2001 From: ZeguanXiao Date: Sun, 19 Oct 2025 10:42:15 +0800 Subject: [PATCH 5/6] fix: remove seed arg from data pipeline and make rank-aware seeding in seed_everything() --- src/data/__init__.py | 8 +++----- src/data/unlearn.py | 20 ++++++-------------- src/train.py | 1 + src/trainer/utils.py | 11 +++++++---- 4 files changed, 17 insertions(+), 23 deletions(-) diff --git a/src/data/__init__.py b/src/data/__init__.py index aa39bcbb..590cafae 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -46,7 +46,7 @@ def get_datasets(dataset_cfgs: Union[Dict, DictConfig], **kwargs): return dataset -def get_data(data_cfg: DictConfig, mode="train", seed=0, **kwargs): +def get_data(data_cfg: DictConfig, mode="train", **kwargs): data = {} data_cfg = dict(data_cfg) anchor = data_cfg.pop("anchor", "forget") @@ -56,9 +56,7 @@ def get_data(data_cfg: DictConfig, mode="train", seed=0, **kwargs): return data elif mode == "unlearn": unlearn_splits = {k: v for k, v in data.items() if k not in ("eval", "test")} - unlearn_dataset = ForgetRetainDataset( - **unlearn_splits, anchor=anchor, seed=seed - ) + unlearn_dataset = ForgetRetainDataset(**unlearn_splits, anchor=anchor) data["train"] = unlearn_dataset for split in unlearn_splits: data.pop(split) @@ -104,4 +102,4 @@ def get_collators(collator_cfgs, **kwargs): _register_data(ForgetRetainDataset) # Register collators -_register_collator(DataCollatorForSupervisedDataset) +_register_collator(DataCollatorForSupervisedDataset) \ No newline at end of file diff --git a/src/data/unlearn.py b/src/data/unlearn.py index dff81be7..190ed682 100644 --- a/src/data/unlearn.py +++ b/src/data/unlearn.py @@ -4,19 +4,17 @@ class ForgetRetainDataset(Dataset): # https://github.com/OPTML-Group/SOUL/blob/main/src/dataset/Base.py - def __init__(self, forget, retain, anchor="forget", seed=0): + def __init__(self, forget, retain, anchor="forget"): """Wraps the forget retain dataset into unlearning dataset. Args: forget (Dataset): Forget Dataset retain (Dataset): Retain Dataset anchor (str, optional): Specifies which dataset to anchor while randomly sampling from the other dataset. Defaults to 'forget'. - seed (int, optional): Random seed for reproducibility. Defaults to 0. 
""" self.forget = forget self.retain = retain self.anchor = anchor - self.seed = seed def __len__(self): """Ensures the sampled dataset matches the anchor dataset's length.""" @@ -35,22 +33,16 @@ def __len__(self): def __getitem__(self, idx): item = {} - g = torch.Generator() - rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - rank_seed = self.seed + rank + idx - g.manual_seed(rank_seed) if self.anchor == "forget": item["forget"] = self.forget[idx] if self.retain: - retain_idx = torch.randint( - 0, len(self.retain), (1,), generator=g - ).item() + retain_idx = torch.randint(0, len(self.retain), (1,)).item() + print(retain_idx) + exit() item["retain"] = self.retain[retain_idx] elif self.anchor == "retain": item["retain"] = self.retain[idx] if self.forget: - forget_idx = torch.randint( - 0, len(self.forget), (1,), generator=g - ).item() + forget_idx = torch.randint(0, len(self.forget), (1,)).item() item["forget"] = self.forget[forget_idx] - return item + return item \ No newline at end of file diff --git a/src/train.py b/src/train.py index 5e8f6db5..05d74230 100644 --- a/src/train.py +++ b/src/train.py @@ -59,6 +59,7 @@ def main(cfg: DictConfig): evaluators=evaluators, template_args=template_args, ) + seed_everything(trainer_args.seed) if trainer_args.do_train: trainer.train() diff --git a/src/trainer/utils.py b/src/trainer/utils.py index af883e8a..bb04a08a 100644 --- a/src/trainer/utils.py +++ b/src/trainer/utils.py @@ -6,10 +6,13 @@ def seed_everything(seed=42): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) + # Each process gets a different seed to ensure different samples of unanchored data + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + rank_seed = seed + rank + random.seed(rank_seed) + np.random.seed(rank_seed) + torch.manual_seed(rank_seed) + torch.cuda.manual_seed_all(rank_seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False From 80ef3c7e8a8d558b86c019ee81e46df9b7b5ea8b Mon Sep 17 00:00:00 2001 From: ZeguanXiao Date: Sun, 19 Oct 2025 14:24:48 +0800 Subject: [PATCH 6/6] fix: use rank-specific seeding for ForgetRetainDataset --- src/data/unlearn.py | 15 +++++++++++---- src/train.py | 10 +++++++++- src/trainer/utils.py | 11 ++++------- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/data/unlearn.py b/src/data/unlearn.py index 190ed682..c58a0596 100644 --- a/src/data/unlearn.py +++ b/src/data/unlearn.py @@ -15,6 +15,15 @@ def __init__(self, forget, retain, anchor="forget"): self.forget = forget self.retain = retain self.anchor = anchor + self.generator = torch.Generator() + + def set_rank_seed(self, seed: int): + """Set the rank-specific seed for this dataset. + + This should be called after trainer initialization to ensure each rank + uses a unique seed for different unanchored data. 
+ """ + self.generator.manual_seed(seed) def __len__(self): """Ensures the sampled dataset matches the anchor dataset's length.""" @@ -36,13 +45,11 @@ def __getitem__(self, idx): if self.anchor == "forget": item["forget"] = self.forget[idx] if self.retain: - retain_idx = torch.randint(0, len(self.retain), (1,)).item() - print(retain_idx) - exit() + retain_idx = torch.randint(0, len(self.retain), (1,), generator=self.generator).item() item["retain"] = self.retain[retain_idx] elif self.anchor == "retain": item["retain"] = self.retain[idx] if self.forget: - forget_idx = torch.randint(0, len(self.forget), (1,)).item() + forget_idx = torch.randint(0, len(self.forget), (1,), generator=self.generator).item() item["forget"] = self.forget[forget_idx] return item \ No newline at end of file diff --git a/src/train.py b/src/train.py index 05d74230..6eb02717 100644 --- a/src/train.py +++ b/src/train.py @@ -1,6 +1,8 @@ +import torch import hydra from omegaconf import DictConfig from data import get_data, get_collators +from data.unlearn import ForgetRetainDataset from model import get_model from trainer import load_trainer from evals import get_evaluators @@ -59,7 +61,13 @@ def main(cfg: DictConfig): evaluators=evaluators, template_args=template_args, ) - seed_everything(trainer_args.seed) + + # Set rank-specific seed for ForgetRetainDataset after trainer initialization + train_dataset = data.get("train", None) + if isinstance(train_dataset, ForgetRetainDataset): + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + rank_seed = cfg.trainer.args.seed + rank + train_dataset.set_rank_seed(rank_seed) if trainer_args.do_train: trainer.train() diff --git a/src/trainer/utils.py b/src/trainer/utils.py index bb04a08a..af883e8a 100644 --- a/src/trainer/utils.py +++ b/src/trainer/utils.py @@ -6,13 +6,10 @@ def seed_everything(seed=42): - # Each process gets a different seed to ensure different samples of unanchored data - rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - rank_seed = seed + rank - random.seed(rank_seed) - np.random.seed(rank_seed) - torch.manual_seed(rank_seed) - torch.cuda.manual_seed_all(rank_seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False
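
Reviewer note (not part of any patch): a minimal, self-contained sketch of how the rank-specific
seeding introduced in PATCH 6 is expected to behave. The RangeSplit stand-in dataset, the
ForgetRetainSketch class name, and the seed value 42 are illustrative assumptions only; the
torch.Generator / set_rank_seed mechanics mirror the actual change in src/data/unlearn.py and
src/train.py, simplified here to the anchor="forget" case.

# Illustrative sketch of the post-PATCH-6 sampling behaviour; run standalone, no DDP needed.
import torch
from torch.utils.data import Dataset


class RangeSplit(Dataset):
    """Stand-in for a real forget/retain split: item i is just the integer i."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> int:
        return idx


class ForgetRetainSketch(Dataset):
    """Simplified ForgetRetainDataset (anchor='forget' only) with a per-instance generator."""

    def __init__(self, forget: Dataset, retain: Dataset):
        self.forget = forget
        self.retain = retain
        # Sampling of the unanchored split uses this generator, not the global RNG.
        self.generator = torch.Generator()

    def set_rank_seed(self, seed: int) -> None:
        # Called once per process after trainer setup, as train.py does after PATCH 6.
        self.generator.manual_seed(seed)

    def __len__(self) -> int:
        return len(self.forget)

    def __getitem__(self, idx: int) -> dict:
        # Anchor on the forget item, draw a random retain item from the seeded generator.
        retain_idx = torch.randint(
            0, len(self.retain), (1,), generator=self.generator
        ).item()
        return {"forget": self.forget[idx], "retain": self.retain[retain_idx]}


if __name__ == "__main__":
    base_seed = 42  # stands in for cfg.trainer.args.seed
    for rank in (0, 1):  # emulate two DDP ranks without torch.distributed
        ds = ForgetRetainSketch(RangeSplit(4), RangeSplit(1000))
        ds.set_rank_seed(base_seed + rank)  # same "seed + rank" rule as train.py
        print(f"rank {rank}:", [ds[i]["retain"] for i in range(len(ds))])

Running the script twice prints identical retain indices for each simulated rank (reproducible),
while rank 0 and rank 1 draw different indices from the retain split — the reproducibility versus
cross-rank diversity trade-off that the last three patches converge on.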