- Python >=3.11
- uv
- Download the M5 dataset following the instructions in the README
- Run `./dataset/M5/extract.sh`
- Run `uv sync` to download the required dependencies
- Run `uv run pretrainm5.py` for pre-training (at least 10 epochs)
- Run `uv run trainm5.py` for training (at least 100 epochs)
- Training requires at least 70 GB of VRAM with mixed precision
- Set `SCALE_PREC = False` in `trainm5.py` to use FP16 and run on GPUs with less than 40 GB of VRAM
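Before running the steps above, it can help to confirm the first two prerequisites (Python >= 3.11 and `uv`). The following is a minimal preflight sketch, not part of the repository: the file name `check_env.py` and the function `check_prerequisites` are hypothetical. It uses only the standard library and does not check GPU memory or the M5 dataset.

```python
# Hypothetical preflight script (e.g. check_env.py); not part of the repository.
from __future__ import annotations  # keeps the annotations below valid on older Pythons

import shutil
import sys


def check_prerequisites(min_python: tuple[int, ...] = (3, 11)) -> list[str]:
    """Return a list of problems; an empty list means both checks passed.

    Only checks the interpreter version and that `uv` is on PATH.
    """
    problems: list[str] = []
    # Compare the running interpreter against the minimum version (3.11 by default).
    if sys.version_info < min_python:
        required = ".".join(str(part) for part in min_python)
        found = ".".join(str(part) for part in sys.version_info[:3])
        problems.append(f"Python >= {required} required, found {found}")
    # `uv sync` and `uv run` need the uv executable to be discoverable.
    if shutil.which("uv") is None:
        problems.append("`uv` was not found on PATH")
    return problems


if __name__ == "__main__":
    for issue in check_prerequisites():
        print(f"missing prerequisite: {issue}")
```

Running `python check_env.py` prints nothing when both checks pass. The version check runs on whatever interpreter launches the script, so run it with the same `python` you intend `uv` to use.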