From fca86163ec78bc8338021db0ecea5180a0f90fc7 Mon Sep 17 00:00:00 2001
From: Garvit Luhadia
Date: Mon, 5 May 2025 18:14:40 -0400
Subject: [PATCH 1/2] initializing loss optimization

---
 src/data_collator_mask.py | 43 +++++++++++++++++++++++++++++++++
 src/trainer.py            | 51 +++++++++++++++++++++++++++++++++------
 2 files changed, 87 insertions(+), 7 deletions(-)
 create mode 100644 src/data_collator_mask.py

diff --git a/src/data_collator_mask.py b/src/data_collator_mask.py
new file mode 100644
index 0000000..85e072c
--- /dev/null
+++ b/src/data_collator_mask.py
@@ -0,0 +1,43 @@
+# src/data_collator_mask.py
+import torch
+from typing import List, Dict, Any
+from transformers import PreTrainedTokenizerBase
+
+class DataCollatorForCausalLMMaskRequests:
+    """
+    Pads both input_ids and a precomputed 'labels' array
+    (with -100 in positions we want to ignore) into batched tensors.
+    """
+    def __init__(self, tokenizer: PreTrainedTokenizerBase, pad_to_multiple_of: int = None):
+        self.tokenizer = tokenizer
+        self.pad_to_multiple_of = pad_to_multiple_of
+
+    def __call__(self, features: List[Dict[str,Any]]) -> Dict[str,torch.Tensor]:
+        # 1) pad input_ids + attention_mask
+        batch = self.tokenizer.pad(
+            features,
+            return_tensors="pt",
+            padding=True,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+        )
+        input_ids = batch["input_ids"]
+
+        # 2) build & pad the labels list
+        labels_list = [torch.tensor(f["labels"], dtype=torch.long) for f in features]
+        labels_padded = torch.nn.utils.rnn.pad_sequence(
+            labels_list,
+            batch_first=True,
+            padding_value=-100
+        )
+
+        # 3) if our labels are shorter (no. of tokens) than input_ids, pad with -100
+        if labels_padded.size() != input_ids.size():
+            diff = input_ids.size(1) - labels_padded.size(1)
+            labels_padded = torch.nn.functional.pad(
+                labels_padded,
+                (0, diff),
+                value=-100
+            )
+
+        batch["labels"] = labels_padded
+        return batch
diff --git a/src/trainer.py b/src/trainer.py
index 59e18b5..e5fdcca 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -12,6 +12,8 @@
 from transformers import Trainer as TransformersTrainer
 from transformers import TrainingArguments
 
+from typing import List, Dict, Any
+import torch
 
 class Trainer:
 
@@ -54,10 +56,17 @@ def __init__(self, cfg: DictConfig) -> None:
         data_path = Path(cfg.dataset.path)
         self.train_dataset = self.get_dataset(data_path / "train", "train")
         self.test_dataset = self.get_dataset(data_path / "test", "test")
-        collator = DataCollatorForLanguageModeling(
-            self.tokenizer,
-            mlm=False,              # causal‑LM
-            pad_to_multiple_of=8,   # keeps tensor shapes friendly for fp16
+        # collator = DataCollatorForLanguageModeling(
+        #     self.tokenizer,
+        #     mlm=False,              # causal‑LM
+        #     pad_to_multiple_of=8,   # keeps tensor shapes friendly for fp16
+        # )
+        # use our custom collator so our 'labels' survive padding
+
+        from src.data_collator_mask import DataCollatorForCausalLMMaskRequests
+        collator = DataCollatorForCausalLMMaskRequests(
+            tokenizer=self.tokenizer,
+            pad_to_multiple_of=8,
         )
 
         # Checkpointing
@@ -103,9 +112,37 @@ def __init__(self, cfg: DictConfig) -> None:
     def run(self):
         self.trainer.train()
 
-    def tokenize(self, example):
-        ids = self.tokenizer(example["text"]).input_ids
-        return {"input_ids": ids, "len": len(ids)}
+    # def tokenize(self, example):
+    #     ids = self.tokenizer(example["text"]).input_ids
+    #     return {"input_ids": ids, "len": len(ids)}
+
+    def tokenize(self, example: Dict[str,str]) -> Dict[str,Any]:
+        """
+        Split the raw text into lines. For each line:
+          - if it starts with 'req:', mask all its tokens (label=-100)
+          - if it starts with 'res:', label tokens normally (so they count in the loss)
+        """
+        raw = example["text"].splitlines(keepends=True)
+        input_ids: List[int] = []
+        labels: List[int] = []
+
+        for line in raw:
+            # tokenize *without* special‑tokens so we preserve the exact segments
+            ids = self.tokenizer(line, add_special_tokens=False).input_ids
+            input_ids.extend(ids)
+
+            if line.startswith("req:"):
+                # mask all request tokens
+                labels.extend([-100] * len(ids))
+            else:
+                # treat as response (or any other line) – compute loss
+                labels.extend(ids)
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+            "len": len(input_ids),
+        }
 
     def get_dataset(self, path: Path, split: str) -> Dataset:
         txt_files = [str(p) for p in path.glob("*.txt")]

From 3aaf28433c1fd2b57f56bb1960bbc877717ac66a Mon Sep 17 00:00:00 2001
From: Garvit Luhadia
Date: Mon, 5 May 2025 18:57:19 -0400
Subject: [PATCH 2/2] simplifying loss masking

---
 src/data_collator_mask.py | 43 ---------------------------------------
 src/trainer.py            | 39 ++++++++++++++---------------------
 2 files changed, 15 insertions(+), 67 deletions(-)
 delete mode 100644 src/data_collator_mask.py

diff --git a/src/data_collator_mask.py b/src/data_collator_mask.py
deleted file mode 100644
index 85e072c..0000000
--- a/src/data_collator_mask.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# src/data_collator_mask.py
-import torch
-from typing import List, Dict, Any
-from transformers import PreTrainedTokenizerBase
-
-class DataCollatorForCausalLMMaskRequests:
-    """
-    Pads both input_ids and a precomputed 'labels' array
-    (with -100 in positions we want to ignore) into batched tensors.
-    """
-    def __init__(self, tokenizer: PreTrainedTokenizerBase, pad_to_multiple_of: int = None):
-        self.tokenizer = tokenizer
-        self.pad_to_multiple_of = pad_to_multiple_of
-
-    def __call__(self, features: List[Dict[str,Any]]) -> Dict[str,torch.Tensor]:
-        # 1) pad input_ids + attention_mask
-        batch = self.tokenizer.pad(
-            features,
-            return_tensors="pt",
-            padding=True,
-            pad_to_multiple_of=self.pad_to_multiple_of,
-        )
-        input_ids = batch["input_ids"]
-
-        # 2) build & pad the labels list
-        labels_list = [torch.tensor(f["labels"], dtype=torch.long) for f in features]
-        labels_padded = torch.nn.utils.rnn.pad_sequence(
-            labels_list,
-            batch_first=True,
-            padding_value=-100
-        )
-
-        # 3) if our labels are shorter (no. of tokens) than input_ids, pad with -100
-        if labels_padded.size() != input_ids.size():
-            diff = input_ids.size(1) - labels_padded.size(1)
-            labels_padded = torch.nn.functional.pad(
-                labels_padded,
-                (0, diff),
-                value=-100
-            )
-
-        batch["labels"] = labels_padded
-        return batch
diff --git a/src/trainer.py b/src/trainer.py
index e5fdcca..1d0dbf9 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -1,6 +1,9 @@
 import os
 from pathlib import Path
 
+from typing import List, Dict, Any
+import torch
+
 from datasets import Dataset, load_dataset
 from omegaconf import DictConfig, OmegaConf
 from transformers import (
@@ -12,8 +15,6 @@
 from transformers import Trainer as TransformersTrainer
 from transformers import TrainingArguments
 
-from typing import List, Dict, Any
-import torch
 
 class Trainer:
 
@@ -61,11 +62,11 @@ def __init__(self, cfg: DictConfig) -> None:
         #     mlm=False,              # causal‑LM
         #     pad_to_multiple_of=8,   # keeps tensor shapes friendly for fp16
         # )
-        # use our custom collator so our 'labels' survive padding
 
-        from src.data_collator_mask import DataCollatorForCausalLMMaskRequests
-        collator = DataCollatorForCausalLMMaskRequests(
+        # default causal‑LM collator will pad and set labels=input_ids
+        collator = DataCollatorForLanguageModeling(
             tokenizer=self.tokenizer,
+            mlm=False,
             pad_to_multiple_of=8,
         )
 
         # Checkpointing
@@ -122,28 +123,18 @@ def tokenize(self, example: Dict[str,str]) -> Dict[str,Any]:
           - if it starts with 'req:', mask all its tokens (label=-100)
           - if it starts with 'res:', label tokens normally (so they count in the loss)
         """
-        raw = example["text"].splitlines(keepends=True)
-        input_ids: List[int] = []
-        labels: List[int] = []
-
-        for line in raw:
-            # tokenize *without* special‑tokens so we preserve the exact segments
-            ids = self.tokenizer(line, add_special_tokens=False).input_ids
-            input_ids.extend(ids)
-
-            if line.startswith("req:"):
-                # mask all request tokens
-                labels.extend([-100] * len(ids))
-            else:
-                # treat as response (or any other line) – compute loss
-                labels.extend(ids)
-
+        # Just tokenize the full API call + response; let the collator set labels = input_ids
+        tok = self.tokenizer(
+            example["text"],
+            truncation=True,
+            max_length=self.context_length,
+        )
         return {
-            "input_ids": input_ids,
-            "labels": labels,
-            "len": len(input_ids),
+            "input_ids": tok["input_ids"],
+            "len": len(tok["input_ids"]),
         }
 
+
     def get_dataset(self, path: Path, split: str) -> Dataset:
         txt_files = [str(p) for p in path.glob("*.txt")]
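
Note on what [PATCH 2/2] relies on: DataCollatorForLanguageModeling with mlm=False pads each batch and copies input_ids into labels, inserting -100 only at padded positions. After this change, request ('req:') tokens therefore contribute to the loss again, which is exactly what the deleted DataCollatorForCausalLMMaskRequests prevented. Below is a minimal sketch of that collator behaviour; it is an illustration only, not part of either commit, and the "gpt2" tokenizer plus the two sample strings are stand-ins for whatever model and data the trainer actually loads.

    # Illustrative sketch (not part of the patch): with mlm=False the collator
    # pads the batch and sets labels equal to input_ids, using -100 only where
    # padding was added.
    from transformers import AutoTokenizer, DataCollatorForLanguageModeling

    tokenizer = AutoTokenizer.from_pretrained("gpt2")   # stand-in model name
    tokenizer.pad_token = tokenizer.eos_token           # gpt2 ships without a pad token

    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False, pad_to_multiple_of=8
    )
    features = [
        {"input_ids": tokenizer("req: GET /items", add_special_tokens=False)["input_ids"]},
        {"input_ids": tokenizer("res: 200 OK", add_special_tokens=False)["input_ids"]},
    ]
    batch = collator(features)
    # batch["labels"] mirrors batch["input_ids"]; the padded tail is -100, but the
    # 'req:' tokens are now labelled, so they count in the loss (unlike PATCH 1/2).
    print(batch["input_ids"].shape)
    print(batch["labels"])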