diff --git a/src/trainer.py b/src/trainer.py
index 59e18b5..1d0dbf9 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -1,6 +1,9 @@
 import os
 from pathlib import Path
+from typing import List, Dict, Any
+import torch
+
 
 from datasets import Dataset, load_dataset
 from omegaconf import DictConfig, OmegaConf
 from transformers import (
@@ -54,10 +57,17 @@ def __init__(self, cfg: DictConfig) -> None:
         data_path = Path(cfg.dataset.path)
         self.train_dataset = self.get_dataset(data_path / "train", "train")
         self.test_dataset = self.get_dataset(data_path / "test", "test")
+        # collator = DataCollatorForLanguageModeling(
+        #     self.tokenizer,
+        #     mlm=False,  # causal‑LM
+        #     pad_to_multiple_of=8,  # keeps tensor shapes friendly for fp16
+        # )
+
+        # default causal‑LM collator will pad and set labels=input_ids
         collator = DataCollatorForLanguageModeling(
-            self.tokenizer,
-            mlm=False,  # causal‑LM
-            pad_to_multiple_of=8,  # keeps tensor shapes friendly for fp16
+            tokenizer=self.tokenizer,
+            mlm=False,
+            pad_to_multiple_of=8,
         )
 
         # Checkpointing
@@ -103,9 +113,27 @@ def __init__(self, cfg: DictConfig) -> None:
     def run(self):
         self.trainer.train()
 
-    def tokenize(self, example):
-        ids = self.tokenizer(example["text"]).input_ids
-        return {"input_ids": ids, "len": len(ids)}
+    # def tokenize(self, example):
+    #     ids = self.tokenizer(example["text"]).input_ids
+    #     return {"input_ids": ids, "len": len(ids)}
+
+    def tokenize(self, example: Dict[str,str]) -> Dict[str,Any]:
+        """
+        Tokenize the full example text (request + response) as one sequence.
+        No per-line label masking happens here; the causal-LM collator sets
+        labels = input_ids, so all non-padding tokens count toward the loss.
+        """
+        # Just tokenize the full API call + response; let the collator set labels = input_ids
+        tok = self.tokenizer(
+            example["text"],
+            truncation=True,
+            max_length=self.context_length,
+        )
+        return {
+            "input_ids": tok["input_ids"],
+            "len": len(tok["input_ids"]),
+        }
+
 
     def get_dataset(self, path: Path, split: str) -> Dataset:
         txt_files = [str(p) for p in path.glob("*.txt")]