PyTorch implementation and training of DeepCaps: Going Deeper with Capsule Networks
DeepCaps_Pytorch is a PyTorch implementation of the architecture described in "DeepCaps: Going Deeper with Capsule Networks" (Rajasegaran et al., 2019). See: https://arxiv.org/abs/1904.09546
This project adapts the DeepCaps architecture for applications such as galaxy morphology classification (using SDSS / Galaxy Zoo data) as part of an academic research investigation.
It demonstrates how Capsule Networks can be extended to deeper architectures and compared against CNN-based baselines.
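To make the capsule idea concrete, the sketch below shows two building blocks that capsule networks, including DeepCaps, are generally built on. The first is the squash nonlinearity, which rescales a capsule's vector to a length in [0, 1) that is read as the probability that an entity is present. The second is the margin loss computed on the lengths of the class capsules. This is a minimal sketch that assumes the standard formulation from the original capsule-network work with its usual constants (m+ = 0.9, m- = 0.1, λ = 0.5). It is not taken from model.py, and the names squash and margin_loss are illustrative.

```python
import torch
import torch.nn.functional as F


def squash(s: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    """Rescale capsule vectors to a length in [0, 1) while keeping their direction."""
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = sq_norm / (1.0 + sq_norm)
    # eps avoids division by zero for all-zero capsules.
    return scale * s / torch.sqrt(sq_norm + eps)


def margin_loss(lengths: torch.Tensor, targets: torch.Tensor,
                m_pos: float = 0.9, m_neg: float = 0.1, lam: float = 0.5) -> torch.Tensor:
    """Margin loss over class-capsule lengths.

    lengths: (batch, num_classes) lengths of the output capsules
    targets: (batch,) integer class labels
    """
    one_hot = F.one_hot(targets, num_classes=lengths.size(1)).float()
    # Penalize the true class capsule when it is shorter than m_pos.
    present = one_hot * torch.clamp(m_pos - lengths, min=0.0) ** 2
    # Penalize every other class capsule when it is longer than m_neg (down-weighted by lam).
    absent = lam * (1.0 - one_hot) * torch.clamp(lengths - m_neg, min=0.0) ** 2
    return (present + absent).sum(dim=1).mean()


# Toy batch: 4 samples, 3 class capsules, 16-dimensional capsule vectors.
poses = torch.randn(4, 3, 16)
v = squash(poses)
lengths = v.norm(dim=-1)          # shape (4, 3), every value below 1
preds = lengths.argmax(dim=1)     # predicted class = longest capsule
loss = margin_loss(lengths, torch.tensor([0, 2, 1, 0]))
```

The predicted class is the index of the longest output capsule, which is why lengths.argmax(dim=1) can serve directly as the classifier output.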
- 🧩 Modular implementation of DeepCaps architecture in PyTorch
- 📈 Training, evaluation, and prediction scripts
- 🪐 Dataset support for SDSS / Galaxy Zoo images
- ⚙️ Configuration system (cfg.py) for hyperparameter tuning
- 📊 Visualization tools for accuracy and loss curves
- 🧠 Easily extendable to new datasets and domains
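As an illustration of the kind of accuracy/loss plot the visualization tools produce, here is a minimal Matplotlib sketch. The function name plot_history, its arguments, and the output file name curves.png are hypothetical and do not reflect the interface of acc_plot.py or plot.py. The sketch assumes per-epoch metrics are already available as Python lists.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, so this also runs on headless servers
import matplotlib.pyplot as plt


def plot_history(train_loss, val_loss, train_acc, val_acc, out_path="curves.png"):
    """Save per-epoch loss and accuracy curves side by side and return the figure."""
    epochs = range(1, len(train_loss) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))

    # Left panel: training vs. validation loss.
    ax_loss.plot(epochs, train_loss, label="train")
    ax_loss.plot(epochs, val_loss, label="validation")
    ax_loss.set_xlabel("epoch")
    ax_loss.set_ylabel("loss")
    ax_loss.legend()

    # Right panel: training vs. validation accuracy.
    ax_acc.plot(epochs, train_acc, label="train")
    ax_acc.plot(epochs, val_acc, label="validation")
    ax_acc.set_xlabel("epoch")
    ax_acc.set_ylabel("accuracy")
    ax_acc.legend()

    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    return fig


# Dummy values purely to exercise the function; they are not results from this project.
fig = plot_history([1.0, 0.7, 0.5], [1.1, 0.8, 0.6], [0.40, 0.60, 0.70], [0.35, 0.55, 0.65])
```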
- Language: Python 3.10+
- Framework: PyTorch
- Visualization: Matplotlib, Seaborn
- Utilities: NumPy, tqdm, Pillow
├── acc_plot.py # Plot accuracy/loss curves
├── cfg.py # Configuration / hyperparameters
├── helpers.py # Utility functions
├── load_data.py # Dataset loader base
├── load_data_sdss.py # SDSS-specific data loader
├── model.py # DeepCaps architecture
├── plot.py # Visualization utilities
├── predictor.py # Inference and prediction
├── train.py # Model training script
├── requirements.txt
└── README.md
Clone the repository and install dependencies:
git clone https://github.com/Commit2Cosmos/DeepCaps_Pytorch.git
cd DeepCaps_Pytorch
pip install -r requirements.txt
If using a GPU, ensure PyTorch is installed with CUDA support. You can verify this by running:
python -c "import torch; print(torch.cuda.is_available())"
This prints True when PyTorch can see a CUDA-capable GPU.
To train the DeepCaps model, run:
python train.py
You can modify hyperparameters (epochs, learning rate, dataset paths, etc.) in cfg.py.
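One way to picture a configuration module like cfg.py is as a small dataclass of hyperparameters. The sketch below is hypothetical: the field names, the placeholder defaults, and the data/sdss path are assumptions and may not match what cfg.py actually defines. It also selects the device with torch.cuda.is_available(), the same check used in the installation step.

```python
from dataclasses import dataclass, field

import torch
from torch import nn


def default_device() -> str:
    # Fall back to the CPU when PyTorch was installed without CUDA support.
    return "cuda" if torch.cuda.is_available() else "cpu"


@dataclass
class Config:
    # Hypothetical fields and placeholder defaults; cfg.py may use other names and values.
    epochs: int = 50
    learning_rate: float = 1e-3
    batch_size: int = 32
    data_dir: str = "data/sdss"
    device: str = field(default_factory=default_device)


# Override individual values per experiment instead of editing the defaults.
cfg = Config(epochs=5, learning_rate=5e-4)

# Stand-in model to show how the config feeds the optimizer; this is not the DeepCaps model.
model = nn.Linear(10, 3).to(cfg.device)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
```

Using default_factory means the CUDA check runs each time a Config is created rather than once when the class is defined.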
Once trained, you can run inference using:
python predictor.py
To adapt DeepCaps to a new dataset:
- Update cfg.py: modify dataset paths, batch size, and training hyperparameters.
- Add a new loader: write a new dataset class in load_data.py following the torch.utils.data.Dataset interface.
- Edit train.py: point the dataset reference to your new loader.
- Run training: DeepCaps will adapt to the new data dimensions and classes.
- Evaluate and visualize: use acc_plot.py and plot.py to plot accuracy, loss, and other metrics.
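For the loader step, here is a minimal sketch of a torch.utils.data.Dataset subclass. It assumes the new dataset is given as a list of image file paths with integer labels. The class name ImageListDataset, the 64×64 resize, and the [0, 1] scaling are illustrative choices, not what load_data.py or load_data_sdss.py actually do. The Dataset interface itself only requires __len__ and __getitem__.

```python
import os
import tempfile
from typing import Callable, Optional, Sequence, Tuple

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class ImageListDataset(Dataset):
    """Hypothetical loader over (image path, integer label) pairs."""

    def __init__(self, paths: Sequence[str], labels: Sequence[int],
                 size: Tuple[int, int] = (64, 64),
                 transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None):
        if len(paths) != len(labels):
            raise ValueError("paths and labels must have the same length")
        self.paths = list(paths)
        self.labels = list(labels)
        self.size = size  # (width, height), as PIL expects
        self.transform = transform

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int):
        img = Image.open(self.paths[idx]).convert("RGB").resize(self.size)
        # HWC uint8 -> CHW float32 in [0, 1], the layout PyTorch conv layers expect.
        x = torch.from_numpy(np.asarray(img, dtype=np.float32) / 255.0).permute(2, 0, 1)
        if self.transform is not None:
            x = self.transform(x)
        return x, torch.tensor(self.labels[idx], dtype=torch.long)


# Demo with two synthetic 32x32 images written to a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for i in range(2):
        path = os.path.join(tmp, f"img_{i}.png")
        Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)).save(path)
        paths.append(path)
    loader = DataLoader(ImageListDataset(paths, [0, 1]), batch_size=2)
    images, labels = next(iter(loader))  # images: (2, 3, 64, 64), labels: (2,)
```

Wrapping the dataset in a DataLoader handles batching and shuffling; the resize target should match whatever input dimensions the model is configured for.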

