diff --git a/pyproject.toml b/pyproject.toml
index 23b95716..d8003525 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,8 @@ dependencies = [
   "benchmark-db-writer @ git+https://github.com/AI-Hypercomputer/aotc.git@2ff16e670df20b497ddaf1f86920dbb5dd9f0c8f#subdirectory=src/aotc/benchmark_db_writer",
   "dacite==1.9.2",
   "click~=8.1.8",
-  "google-cloud-storage==2.19.0"
+  "google-cloud-storage==2.19.0",
+  "gcsfs"
 ]
 
 [project.optional-dependencies]
diff --git a/torchprime/data/dataset.py b/torchprime/data/dataset.py
index ed9aacdd..c961784a 100644
--- a/torchprime/data/dataset.py
+++ b/torchprime/data/dataset.py
@@ -1,11 +1,14 @@
 """Utilities for preparing datasets for basic training tasks."""
 
 import json
+import logging
 
 import fsspec
-from datasets import Dataset, DatasetDict, load_dataset
+from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 from transformers.tokenization_utils import PreTrainedTokenizerBase
 
+logger = logging.getLogger(__name__)
+
 
 def _load_json_dataset(path: str, split: str) -> Dataset:
   """Load a dataset from a JSON Lines file.
@@ -33,6 +36,7 @@ def _load_hf_dataset(
   config: str | None,
   split: str,
   cache_dir: str | None,
+  streaming: bool = False,
 ) -> Dataset:
   """Download and return a dataset from Hugging Face Hub.
 
@@ -41,12 +45,19 @@
     config: Optional configuration name.
     split: Split to load.
     cache_dir: Directory where the dataset cache should live.
+    streaming: Whether to stream the dataset.
 
   Returns:
     The loaded ``Dataset`` instance for ``split``.
   """
-  data = load_dataset(name, config, split=split, cache_dir=cache_dir)
+  data = load_dataset(
+    name,
+    config,
+    split=split,
+    cache_dir=cache_dir,
+    streaming=streaming,
+  )
   assert isinstance(data, Dataset | DatasetDict)
   if isinstance(data, DatasetDict):
     data = data[split]
@@ -59,6 +70,7 @@ def load_hf_or_json_dataset(
   file_dataset_path: str | None = None,
   split: str = "train",
   cache_dir: str | None = None,
+  streaming: bool = False,
 ):
   """Loads a dataset either from Hugging Face Hub or a local/remote JSONL file.
 
@@ -72,12 +84,19 @@
     file_dataset_path: Optional path to a JSONL file (local or remote).
     split: Dataset split to load (default is "train").
     cache_dir: Optional directory to use for dataset caching (HF only).
+    streaming: Whether to stream the dataset (HF only).
 
   Returns:
     A HuggingFace ``Dataset`` instance.
   """
   if hf_dataset_name:
-    data = _load_hf_dataset(hf_dataset_name, hf_dataset_config_name, split, cache_dir)
+    data = _load_hf_dataset(
+      hf_dataset_name,
+      hf_dataset_config_name,
+      split,
+      cache_dir,
+      streaming,
+    )
   elif file_dataset_path:
     data = _load_json_dataset(file_dataset_path, split)
   else:
@@ -89,6 +108,7 @@
 
 
 def make_train_dataset(
+  cached_dataset_path: str | None = None,
   hf_dataset_name: str | None = None,
   hf_dataset_config_name: str | None = None,
   file_dataset_path: str | None = None,
@@ -97,6 +117,10 @@
   *,
   tokenizer: PreTrainedTokenizerBase,
   block_size: int,
+  text_column: str = "text",
+  streaming: bool = False,
+  num_proc: int | None = None,
+  **kwargs,
 ) -> Dataset:
   """Loads and tokenizes a dataset, then chunks it into fixed-size blocks for training.
 
@@ -106,6 +130,7 @@
   for efficient language modeling, especially on accelerators like TPUs.
 
   Args:
+    cached_dataset_path: Optional path to a pre-processed, cached dataset.
     hf_dataset_name: Optional Hugging Face dataset name. (e.g., "wikitext").
     hf_dataset_config_name: Optional HF dataset config name.
(e.g., "wikitext-103-raw-v1"). file_dataset_path: Optional path or ``gs://`` URI to a JSONL dataset. @@ -113,24 +138,38 @@ def make_train_dataset( cache_dir: Optional directory for HF dataset cache. tokenizer: A Hugging Face tokenizer used to tokenize the input text. block_size: The fixed length of each chunked training example. + text_column: The name of the column containing the text to be tokenized. + streaming: Whether to stream the dataset. + num_proc: Number of processes for multiprocessing. + **kwargs: Unused keyword arguments. Returns: A `Dataset` object containing tokenized and block-wise grouped training examples, each with keys `"input_ids"` and `"labels"`. """ + if cached_dataset_path: + logger.info(f"Loading cached dataset from: {cached_dataset_path}") + # `load_from_disk` works seamlessly with local paths and GCS URIs. + data = load_from_disk(cached_dataset_path) + return data + + logger.info("No `cached_dataset_path` provided. Processing dataset on-the-fly...") + data = load_hf_or_json_dataset( hf_dataset_name=hf_dataset_name, hf_dataset_config_name=hf_dataset_config_name, file_dataset_path=file_dataset_path, split=split, cache_dir=cache_dir, + streaming=streaming, ) column_names = list(data.features) data = data.map( - lambda samples: tokenizer(samples["text"]), + lambda samples: tokenizer(samples[text_column]), batched=True, remove_columns=column_names, + num_proc=num_proc, ) def group_texts(examples): @@ -155,5 +194,5 @@ def group_texts(examples): result["labels"] = result["input_ids"].copy() return result - data = data.map(group_texts, batched=True) + data = data.map(group_texts, batched=True, num_proc=num_proc) return data diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py index a5134337..36916bc4 100644 --- a/torchprime/launcher/cli.py +++ b/torchprime/launcher/cli.py @@ -4,6 +4,7 @@ import getpass import json +import logging import os import re import subprocess @@ -19,6 +20,7 @@ from dataclasses_json import dataclass_json from pathspec import PathSpec from pathspec.patterns import GitWildMatchPattern # type: ignore +from transformers import AutoTokenizer from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer @@ -77,6 +79,78 @@ def cli(ctx, interactive): ctx.obj["interactive"] = interactive +@cli.command() +@click.option("--dataset-name", required=True, help="Name of the Hugging Face dataset.") +@click.option( + "--dataset-config-name", + default=None, + help="Configuration name of the Hugging Face dataset.", +) +@click.option( + "--tokenizer-name", required=True, help="Name of the Hugging Face tokenizer." +) +@click.option( + "--output-path", + required=True, + help="Path to save the processed dataset (local or GCS).", +) +@click.option( + "--block-size", type=int, default=4096, help="Sequence length for packing." +) +@click.option("--split", default="train", help="Dataset split to process.") +@click.option("--text-column", default="text", help="The column containing text data.") +@click.option( + "--cache-dir", default=None, help="Directory to cache the raw dataset downloads." 
+)
+@click.option(
+  "--num-workers", type=int, default=50, help="Number of worker processes for preprocessing."
+)
+def preprocess(
+  dataset_name,
+  dataset_config_name,
+  tokenizer_name,
+  output_path,
+  block_size,
+  split,
+  text_column,
+  cache_dir,
+  num_workers,
+):
+  """Preprocesses a dataset and saves it to a specified location."""
+  from torchprime.data.dataset import make_train_dataset
+
+  logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    handlers=[logging.StreamHandler(sys.stdout)],
+  )
+  logger = logging.getLogger(__name__)
+  logger.setLevel(logging.INFO)
+
+  logger.info("Starting dataset preprocessing...")
+
+  logger.info(f"Loading tokenizer: {tokenizer_name}")
+  tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+  if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
+  logger.info("Loading and preprocessing raw dataset...")
+  processed_dataset = make_train_dataset(
+    hf_dataset_name=dataset_name,
+    hf_dataset_config_name=dataset_config_name,
+    split=split,
+    tokenizer=tokenizer,
+    block_size=block_size,
+    text_column=text_column,
+    streaming=False,
+    cache_dir=cache_dir,
+    num_proc=num_workers,
+  )
+
+  logger.info("Preprocessing finished. Now saving to disk...")
+  logger.info(f"Saving processed dataset to: {output_path}")
+  processed_dataset.save_to_disk(output_path)
+  logger.info("Preprocessing complete.")
+
+
 @cli.command()
 @click.option("--cluster", required=True, help="Name of the XPK cluster")
 @click.option("--project", required=True, help="GCP project the cluster belongs to")
diff --git a/torchprime/metrics/metrics.py b/torchprime/metrics/metrics.py
index fa13ac1b..5ac5b817 100644
--- a/torchprime/metrics/metrics.py
+++ b/torchprime/metrics/metrics.py
@@ -19,6 +19,9 @@ class Metrics:
   step_execution_time: timedelta | None
   """The average time to execute a training step."""
 
+  dataset_load_time: timedelta | None
+  """The time it took to load and process the dataset."""
+
   mfu: float | None
   """Model FLOPs Utilization."""
 
@@ -60,10 +63,14 @@ def __init__(self):
     self.mfu = None
     self.tokens_per_second = None
     self.num_steps = None
+    self.dataset_load_time: float | None = None
 
   def log_step_execution_time(self, step_execution_time: float):
     self.step_execution_time = step_execution_time
 
+  def log_dataset_load_time(self, dataset_load_time: float):
+    self.dataset_load_time = dataset_load_time
+
   def log_mfu(self, mfu: float):
     self.mfu = mfu
 
@@ -80,6 +87,9 @@ def finalize(self) -> Metrics:
       step_execution_time=timedelta(seconds=self.step_execution_time)
       if self.step_execution_time
       else None,
+      dataset_load_time=timedelta(seconds=self.dataset_load_time)
+      if self.dataset_load_time is not None
+      else None,
       mfu=self.mfu,
       tokens_per_second=self.tokens_per_second,
       num_steps=self.num_steps,
diff --git a/torchprime/torch_xla_models/configs/default.yaml b/torchprime/torch_xla_models/configs/default.yaml
index 0cc84745..52e1138e 100644
--- a/torchprime/torch_xla_models/configs/default.yaml
+++ b/torchprime/torch_xla_models/configs/default.yaml
@@ -31,6 +31,10 @@ output_dir: outputs
 # If unspecified, defaults to the current date and time.
 run_name: null
 
+dataset:
+  # Default path for preprocessed data, can be overridden in dataset-specific configs
+  cached_dataset_path: null
+
 # The virtual device mesh shape to use within a TPU slice. This is also called
 # the "ICI mesh", since devices within a slice enjoy a faster network called
 # "Inter-Chip Interconnect".
diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py
index a8dcabef..826aad22 100644
--- a/torchprime/torch_xla_models/train.py
+++ b/torchprime/torch_xla_models/train.py
@@ -2,6 +2,7 @@
 
 import logging
 import sys
+from timeit import default_timer as timer
 
 import datasets
 import hydra
@@ -69,12 +70,25 @@ def main(config: omegaconf.DictConfig):
   trainer_cls = torchprime.torch_xla_models.trainer.TRAINERS.get(
     config.task.name, torchprime.torch_xla_models.trainer.Trainer
   )
+
+  load_time_start = timer()
   data = retry.retry(lambda: dataset_fn(**config.dataset, tokenizer=tokenizer))
+  load_time_end = timer()
+  load_time_seconds = load_time_end - load_time_start
   dataset_name = getattr(config.dataset, "hf_dataset_name", None) or getattr(
     config.dataset, "file_dataset_path", "unknown"
   )
-  logger.info("Loaded dataset `%s`, size=%d (packed) samples", dataset_name, len(data))
+  num_tokens = len(data) * config.dataset.block_size
+  tokens_per_second = num_tokens / load_time_seconds
+  logger.info("--- Dataset Loading Benchmark ---")
+  logger.info("  Dataset: %s", dataset_name)
+  logger.info("  Num samples: %d", len(data))
+  logger.info("  Total tokens: %d", num_tokens)
+  logger.info(f"  Load time: {load_time_seconds:.2f} seconds")
+  logger.info(f"  Tokens/sec: {tokens_per_second:,.2f}")
+  logger.info("---------------------------------")
+  metrics_logger.log_dataset_load_time(load_time_seconds)
 
   trainer = trainer_cls(
     model=model,
diff --git a/torchprime/torch_xla_models/trainer/base_trainer.py b/torchprime/torch_xla_models/trainer/base_trainer.py
index 65e502eb..ec2f0918 100644
--- a/torchprime/torch_xla_models/trainer/base_trainer.py
+++ b/torchprime/torch_xla_models/trainer/base_trainer.py
@@ -17,6 +17,7 @@
 from pathlib import Path
 from timeit import default_timer as timer
 
+import numpy as np
 import torch
 import torch.nn.utils as nn_utils
 import torch_xla
@@ -49,6 +50,7 @@
   setup_sharding_and_mesh,
 )
 from torchprime.torch_xla_models.topology import get_num_slices
+from torchprime.utils.data_load_benchmark_logger import DataLoadBenchmarkLogger
 from torchprime.utils.profiling import ensure_profile_end_step
 
 logger = logging.getLogger(__name__)
@@ -91,10 +93,17 @@ def __init__(
     self.device = xm.xla_device()
     self.global_batch_size = self.config.task.global_batch_size
     self.train_dataset = train_dataset
+    self.dataloader_wait_times = []
 
     # Initialize tensorboard metrics writer
     self._initialize_tensorboard_writer()
 
+    self.benchmark_logger = DataLoadBenchmarkLogger(
+      self.config.output_dir,
+      "dataloader_benchmark.csv",
+      fieldnames=["epoch", "step", "wait_time_ms", "compute_time_ms"],
+    )
+
     # -- Model transformations --
     # Recursively replace `nn.Linear` layers with einsum operations in the model.
     # Without this patch, an `nn.Linear` module will flatten non-contracting dimensions
@@ -205,7 +214,9 @@ def _get_train_dataloader(self) -> pl.MpDeviceLoader:
       drop_last=True,
     )
     loader = pl.MpDeviceLoader(
-      dataloader, self.device, input_sharding=self.input_sharding_spec
+      dataloader,
+      self.device,
+      input_sharding=self.input_sharding_spec,
     )
     return loader
 
@@ -235,6 +246,7 @@ def train_loop(self) -> None:
     epoch = 0
 
     for step in range(max_step):
+      wait_start_time = timer()
       try:
         batch = next(train_iterator)
       except StopIteration:
@@ -243,34 +255,59 @@
        train_iterator = iter(train_loader)
        batch = next(train_iterator)
 
+      wait_end_time = timer()
+      batch_wait_time = wait_end_time - wait_start_time
+      batch_wait_time_ms = batch_wait_time * 1000
+
      trace_start_time = timer()
      loss, grad_norm = self.train_step(batch)
      trace_end_time = timer()
+      compute_time_ms = (trace_end_time - trace_start_time) * 1000
+
+      self.dataloader_wait_times.append(batch_wait_time)
+      self.benchmark_logger.log_step(
+        epoch=step / steps_per_epoch,
+        step=step,
+        wait_time_ms=batch_wait_time_ms,
+        compute_time_ms=compute_time_ms,
+      )
+      logger.info(
+        f"Epoch: {epoch:.4f}, step: {step}, batch loading time: {batch_wait_time_ms:.2f} ms"
+      )
 
      if step % self.config.logging_steps == 0:
 
        def step_closure(
-          epoch, step, loss, grad_norm, trace_start_time, trace_end_time, lr
+          fractional_epoch, step, loss, grad_norm, trace_start_time, trace_end_time, lr
        ):
          loss = loss.detach().item()
          grad_norm = grad_norm.detach().item()
+          compute_time_ms = (trace_end_time - trace_start_time) * 1000
+
+          # A moving average of wait time over the last logging window.
+          wait_time_ms = (
+            np.mean(self.dataloader_wait_times[-self.config.logging_steps :]) * 1000
+          )
+          step_time_ms = compute_time_ms + wait_time_ms
          logger.info(
-            "Epoch: %.4f, step: %d, loss: %.4f, grad_norm: %.4f, lr: %.2e, trace time: %.2f ms",
-            step / steps_per_epoch,
+            "Epoch: %.4f, step: %d, loss: %.4f, grad_norm: %.4f, lr: %.2e, step time: %.2f ms (compute: %.2f, wait: %.2f)",
+            fractional_epoch,
            step,
            loss,
            grad_norm,
            lr,
-            (trace_end_time - trace_start_time) * 1000,
+            step_time_ms,
+            compute_time_ms,
+            wait_time_ms,
          )
-          self._log_to_tensorboard(epoch, step, loss, lr, grad_norm)
+          self._log_to_tensorboard(fractional_epoch, step, loss, lr, grad_norm)
          if math.isnan(loss):
            raise ValueError(f"Loss is NaN at step {step}")
 
        xm.add_step_closure(
          step_closure,
          args=(
-            epoch,
+            step / steps_per_epoch,
            step,
            loss,
            grad_norm,
@@ -333,6 +370,9 @@ def finalize_training(self, metrics_logger) -> None:
     # Print and save metrics
     metrics = metrics_logger.finalize()
     logger.info("***** train metrics *****\n%s", metrics)
+
+    # The benchmark logger now handles file operations within each log_step call.
+    logger.info("Saving data loading time benchmark log...")
     metrics.save(Path(self.config.output_dir) / "train_metrics.json")
 
     # Save the hydra config
diff --git a/torchprime/utils/data_load_benchmark_logger.py b/torchprime/utils/data_load_benchmark_logger.py
new file mode 100644
index 00000000..658bc07d
--- /dev/null
+++ b/torchprime/utils/data_load_benchmark_logger.py
@@ -0,0 +1,30 @@
+import csv
+from pathlib import Path
+from typing import Any
+
+
+class DataLoadBenchmarkLogger:
+  """A simple logger for writing data loading benchmark data to a CSV file."""
+
+  def __init__(self, output_dir: str, filename: str, fieldnames: list[str]):
+    """Initializes the logger.
+
+    Args:
+      output_dir: The directory where the log file will be saved.
+      filename: The name of the CSV file.
+      fieldnames: The list of column names for the CSV file.
+    """
+    self.output_path = Path(output_dir) / filename
+    self.fieldnames = fieldnames
+    # Ensure the output directory exists.
+    self.output_path.parent.mkdir(parents=True, exist_ok=True)
+
+  def log_step(self, **kwargs: Any):
+    """Logs a single step of benchmark data."""
+    file_exists = self.output_path.exists()
+
+    with self.output_path.open("a", newline="") as f:
+      writer = csv.DictWriter(f, fieldnames=self.fieldnames)
+      if not file_exists or f.tell() == 0:
+        writer.writeheader()
+      writer.writerow(kwargs)
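
Usage sketch (not part of the patch): the snippet below illustrates how the new `cached_dataset_path` argument on `make_train_dataset` is expected to interact with data written by `save_to_disk` (e.g. via the new `preprocess` command). The tokenizer name, dataset name, and GCS path are placeholders, not values taken from this change.

    # Hedged example only: the tokenizer, dataset, and output path below are hypothetical.
    from transformers import AutoTokenizer

    from torchprime.data.dataset import make_train_dataset

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

    # On-the-fly processing: previous behavior, still used when no cache path is given.
    data = make_train_dataset(
      hf_dataset_name="wikitext",
      hf_dataset_config_name="wikitext-103-raw-v1",
      tokenizer=tokenizer,
      block_size=4096,
    )

    # Loading a dataset previously saved with `save_to_disk`; per the patch, this branch
    # returns the cached data directly via `load_from_disk`, and the required
    # tokenizer/block_size keywords are not used for re-tokenization.
    cached = make_train_dataset(
      cached_dataset_path="gs://my-bucket/preprocessed/wikitext-4096",  # placeholder path
      tokenizer=tokenizer,
      block_size=4096,
    )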