A modular implementation of the BERT (Bidirectional Encoder Representations from Transformers) architecture.
This repo provides a structured framework for both pretraining and fine-tuning BERT models with performance-focused components.
The project is organized into clear domains to separate concerns between modeling, data handling, and training:
The model is broken down into atomic components, allowing for independent testing and modification:
- `attention`: Implements Multi-Head Self-Attention, including the scaling and masking logic.
- `encoders`: Contains the `Encoder` blocks and the `StackedEncoder`, which orchestrates the depth of the model.
- `bert`: Assembly logic for the `BertBackbone` and the specific heads for `PreTraining` (MLM/NSP) and `SequenceClassification` (composed as sketched below).
- `embeddings`: Handles the summation of word, position, and token-type embeddings.
- `pooler`: Extracts the representation of the `[CLS]` token for downstream tasks.
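To show how these pieces fit together, here is a minimal, illustrative composition in the spirit of the backbone. The class name, constructor signature, and forward arguments below are assumptions made for the sketch, not the repository's actual API:

```python
import torch
from torch import nn


class BackboneSketch(nn.Module):
    """Hypothetical composition of the atomic components (not the repo's real BertBackbone)."""

    def __init__(self, embeddings: nn.Module, stacked_encoder: nn.Module, pooler: nn.Module):
        super().__init__()
        self.embeddings = embeddings      # sum of word, position, and token-type embeddings
        self.encoder = stacked_encoder    # the stack of Encoder blocks
        self.pooler = pooler              # extracts the [CLS] representation

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        hidden = self.embeddings(input_ids, token_type_ids)   # (batch, seq, hidden)
        hidden = self.encoder(hidden, attention_mask)          # contextualized token states
        pooled = self.pooler(hidden)                           # [CLS] vector for task heads
        return hidden, pooled
```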
The data module manages the transition from raw text to structured tensors using a strict schema-driven approach.
- Pretraining Logic: Implements the standard BERT objectives, dynamic Masked Language Modeling (MLM) and Next Sentence Prediction (NSP); a rough sketch of the dynamic masking step follows this list.
- Core Data Types: Defines the structural interface to ensure tensors remain consistent across training, validation, and inference.
- Dataset Management: Handles the heavy lifting of padding, sequence truncation, and the creation of efficient data streams.
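The dynamic MLM objective decides which tokens to mask per batch rather than fixing the masks at preprocessing time. The sketch below illustrates the standard 80/10/10 masking recipe; the function name, tensor layout, and the `-100` ignore index are assumptions for the example, not the repo's actual implementation:

```python
import torch


def dynamic_mlm_mask(
    input_ids: torch.Tensor, mask_token_id: int, vocab_size: int, mlm_prob: float = 0.15
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select ~15% of tokens as MLM targets: 80% -> [MASK], 10% -> random token, 10% -> unchanged."""
    labels = input_ids.clone()
    selected = torch.rand(input_ids.shape) < mlm_prob
    labels[~selected] = -100  # positions that are not MLM targets are ignored by the loss
    # (a real pipeline also excludes special tokens such as [CLS], [SEP], and padding)

    masked_ids = input_ids.clone()
    use_mask = selected & (torch.rand(input_ids.shape) < 0.8)
    masked_ids[use_mask] = mask_token_id

    use_random = selected & ~use_mask & (torch.rand(input_ids.shape) < 0.5)
    masked_ids[use_random] = torch.randint(vocab_size, (int(use_random.sum()),))
    return masked_ids, labels
```

The masked ids are what the model sees, while the labels carry the original tokens only at the selected positions.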
Training logic is abstracted into specialized workflows to handle different stages of the model lifecycle:
- Core Optimization: A centralized engine that manages gradient updates, weight decay, and linear learning-rate scheduling (a sketch of the schedule follows this list).
- Pretraining Workflow: Specialized for large-scale, long-running tasks featuring infinite data streaming and periodic state synchronization.
- Task Adaptation: A workflow focused on supervised learning, featuring evaluation loops and performance-based model versioning.
- Experiment Tracking: Manages the persistence of training metrics, artifacts, and global configuration state, and integrates visualization via TensorBoard.
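The linear schedule mentioned above typically ramps the learning rate up over a warmup phase and then decays it linearly to zero. A minimal sketch (the shape is the common convention; this is not necessarily how the repo's engine implements it):

```python
def linear_lr_multiplier(step: int, total_steps: int, warmup_steps: int) -> float:
    """Factor applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear warmup from 0 to 1
    remaining = total_steps - step
    return max(0.0, remaining / max(1, total_steps - warmup_steps))  # linear decay to 0
```

In PyTorch such a multiplier can be plugged into `torch.optim.lr_scheduler.LambdaLR`; whether the repo's engine uses that wrapper or its own scheduler is not shown here.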
The repository includes a Rust-based backend to eliminate the performance bottlenecks typically found in text preprocessing.
- Logic: The heavy computation for the WordPiece and BPE (Byte Pair Encoding) algorithms is implemented in Rust, which allows for faster vocabulary training and text encoding than pure Python implementations (a Python sketch of the BPE merge loop follows this list).
- Python Interface: Native modules are compiled and exposed to the Python environment, providing a seamless experience where Rust speed meets Python's ease of use.
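To make the BPE side concrete, here is a pure-Python sketch of a single merge step: count adjacent symbol pairs over a word-frequency table and merge the most frequent pair. It only illustrates the algorithm that the Rust backend accelerates; the function and data layout are assumptions, not the repo's code:

```python
from collections import Counter


def bpe_merge_step(words: dict[tuple[str, ...], int]) -> dict[tuple[str, ...], int]:
    """One BPE iteration: count adjacent symbol pairs and merge the most frequent pair."""
    pairs: Counter[tuple[str, str]] = Counter()
    for symbols, freq in words.items():
        for pair in zip(symbols, symbols[1:]):
            pairs[pair] += freq
    if not pairs:
        return words

    best = max(pairs, key=pairs.get)  # the pair to merge in this iteration
    merged: dict[tuple[str, ...], int] = {}
    for symbols, freq in words.items():
        out: list[str] = []
        i = 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                out.append(symbols[i] + symbols[i + 1])  # fuse the pair into one symbol
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = merged.get(tuple(out), 0) + freq
    return merged


# example word-frequency table: each word is a tuple of its current symbols
words = {("l", "o", "w"): 5, ("l", "o", "w", "e", "r"): 2}
words = bpe_merge_step(words)  # merges the most frequent adjacent pair everywhere
```

Vocabulary training repeats this step until the target vocabulary size is reached, which is exactly the loop that benefits from moving out of Python.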
The `notes/` directory contains several notebooks that explain the internal mechanics of the different components.
These serve as a technical companion to the source code, documenting the "why" behind specific implementation choices.
Ensure you have the Rust toolchain installed, as well as uv, the Python package and project manager:
```bash
uv venv --python 3.13
source .venv/bin/activate
make install

# add src and rust to your python path
export PYTHONPATH=$(pwd)/src:$(pwd)/rust
```

You can refer to the notes to see how to train a tokenizer or a BERT model, or how to use any of the modules.
Below is a snippet showing how to pretrain a BERT model with all default settings:
```python
from pathlib import Path

from data.pretraining import PretrainingCorpusData, PretrainingDataset
from modules.bert.pretraining import BertForPreTraining
from settings import SETTINGS
from token_encoders.rust.bpe import RustBPETokenizer
from tracker import ExperimentTracker
from trainers.pretraining import PreTrainer

# train a tokenizer
tokenizer = RustBPETokenizer(SETTINGS.tokenizer)
tokenizer.train([Path("data/wikitext-103-raw-v1/tokenizer.txt").read_text()])
tokenizer.save(Path("saved_tokenizers/bpe"))

# load corpus data and prepare the pretraining dataset
corpus_data = PretrainingCorpusData.from_file(
    Path("data/wikitext-103-raw-v1/pretraining_bert.txt")
)
dataset = PretrainingDataset(
    data=corpus_data, tokenizer=tokenizer, loader_settings=SETTINGS.loader
)

# define the BERT model
model = BertForPreTraining(SETTINGS.bert)

# define the trainer and start training
trainer = PreTrainer(
    model=model,
    train_dataset=dataset,
    settings=SETTINGS.pretrainer,
    tracker=ExperimentTracker(SETTINGS.tracker, [SETTINGS.bert, SETTINGS.tokenizer]),
)
trainer.train()
```