- Python >=3.11
- uv
- Download the M5 dataset following the instructions in the README
- Run `./dataset/M5/extract.sh`
- Run `uv sync` to download the required dependencies
- Run `uv run pretrainm5.py` for pre-training (at least 10 epochs)
- Run `uv run trainm5.py` for training (at least 100 epochs)
- Requires at least 70 GB of VRAM with mixed precision
- Set `SCALE_PREC = False` in `trainm5.py` to use FP16 and run on GPUs with less than 40 GB of VRAM
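Before running the steps above, it can help to confirm the prerequisites. The script below is a minimal, hypothetical pre-flight check and is not part of the repository. It verifies the Python version and that `uv` is on `PATH`, and it reads per-GPU memory via `nvidia-smi --query-gpu=memory.total` (assuming an NVIDIA GPU). The function names are illustrative. The only thresholds it uses are the two stated in this list: 70 GB for mixed precision and 40 GB for the FP16 fallback.

```python
"""Hypothetical pre-flight check for the prerequisites above (not part of the repo)."""
import shutil
import subprocess
import sys


def python_ok(min_version: tuple[int, int] = (3, 11)) -> bool:
    """Return True if the running interpreter is at least min_version."""
    return sys.version_info[:2] >= min_version


def uv_available() -> bool:
    """Return True if the `uv` executable can be found on PATH."""
    return shutil.which("uv") is not None


def gpu_vram_gib() -> list[float]:
    """Total memory per NVIDIA GPU in GiB; empty list if nvidia-smi is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return []
    try:
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
            capture_output=True,
            text=True,
            check=True,
        ).stdout
    except (subprocess.CalledProcessError, OSError):
        return []
    # nvidia-smi reports MiB, one GPU per line; skip anything non-numeric.
    return [int(tok) / 1024 for tok in out.split() if tok.isdigit()]


if __name__ == "__main__":
    print(f"Python >= 3.11: {python_ok()}")
    print(f"uv on PATH:     {uv_available()}")
    vram = gpu_vram_gib()
    if not vram:
        print("No NVIDIA GPU detected (nvidia-smi missing or failed)")
    for idx, gib in enumerate(vram):
        hint = ""
        if gib >= 70:
            hint = " (meets the 70 GB mixed-precision requirement)"
        elif gib < 40:
            hint = " (below 40 GB: set SCALE_PREC = False in trainm5.py)"
        print(f"GPU {idx}: {gib:.1f} GiB{hint}")
```

Note that `nvidia-smi` reports MiB, so the printed GiB values will read slightly lower than marketing "GB" figures (for example, an 80 GB card shows roughly 79 to 80 GiB).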