A deep learning project that implements a Variational Autoencoder (VAE) with a U-Net-like architecture, self-attention, and perceptual loss to colorize grayscale images of birds. This project was developed for the M.S. Machine Learning course to explore generative models and advanced computer vision techniques.
- U-Net VAE: Implements a Variational Autoencoder with an Encoder-Decoder structure featuring skip connections, similar to a U-Net.
- Advanced Architecture: Utilizes modern components including Residual Blocks, Squeeze-and-Excitation (SE) blocks, and Self-Attention layers.
- Hybrid Loss Function: Combines standard MSE, KL Divergence (for the latent space), and VGG19-based Perceptual Loss to generate sharp, high-fidelity images.
- Kaggle Integration: Includes a simple script to download and set up the "Birds 20 Species" dataset directly from Kaggle.
- Modular & Scripted: Refactored from a notebook into a clean, modular Python project with separate scripts for training and evaluation.
- Generative Models: Variational Autoencoders (VAEs)
- Representation Learning: Learning a compressed latent space for complex data.
- Encoder-Decoder Architecture: Using a U-Net structure with skip connections to preserve low-level details.
- Attention Mechanisms: `SelfAttention` and `SEBlock` modules to focus on relevant image features (a minimal SE-block sketch follows this list).
- Transfer Learning: Using a pre-trained VGG19 network to compute a Perceptual Loss, which aligns better with human visual perception.
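For illustration, here is a minimal squeeze-and-excitation block in PyTorch. This is a generic sketch of the technique, not a copy of the `SEBlock` in `src/model.py`; the reduction ratio is a placeholder choice.

```python
import torch
import torch.nn as nn

class SEBlock(nn.Module):
    """Squeeze-and-Excitation: reweight channels using global context."""
    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)          # squeeze: (B, C, H, W) -> (B, C, 1, 1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),                            # per-channel gates in [0, 1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, _, _ = x.shape
        w = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        return x * w                                 # excite: rescale each channel
```

The sigmoid gate turns globally pooled statistics into per-channel weights, letting the network emphasize informative feature maps and suppress less useful ones.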
The core of this project is a Variational Autoencoder (VAE), which learns a probabilistic mapping from grayscale images to color images. The architecture is a U-Net-style Encoder-Decoder.
- Encoder (`src/model.py:Encoder`):
  - Takes a 1-channel (grayscale) 160x160 image as input.
  - Consists of a stack of `ResidualBlock`s (each containing `SEBlock`s) and `MaxPool2d` layers to downsample the image and extract features.
  - Skip connections are saved at each downsampling stage.
  - A `SelfAttention` layer is applied at the bottleneck to capture global dependencies.
  - Finally, it outputs the `mu` (mean) and `logvar` (log-variance) that define the latent space distribution.
- Decoder (`src/model.py:Decoder`):
  - A sample `z` is drawn from the latent space using the reparameterization trick (a short sketch follows this list).
  - This sample is passed through a `SelfAttention` layer and upsampled using `ConvTranspose2d` layers.
  - At each upsampling stage, the corresponding skip connection from the encoder is concatenated. This gives the decoder access to low-level features (like edges) from the input, which is crucial for generating sharp images.
  - The final layer is a `Conv2d` followed by a `Sigmoid` to output a 3-channel (RGB) 160x160 image.
- Hybrid Loss Function (`src/vae_system.py:VAEColorizer.compute_loss`): The model is trained on a hybrid loss that balances three objectives (sketches of the individual terms and their combination follow this list):
  - Pixel-wise Reconstruction ($L_{MSE}$): A standard Mean Squared Error loss between the predicted and actual color images.
  - Perceptual Loss ($L_{perceptual}$): This loss (from `src/vae_system.py:PerceptualLoss`) feeds the predicted and actual images through a pre-trained VGG19 network, then computes the L1 loss between the features extracted at several layers. This encourages the model to generate images that are perceptually similar to the target, resulting in sharper and more realistic textures.
  - Kullback-Leibler (KL) Divergence ($L_{KLD}$): The standard VAE regularization term that forces the latent space to approximate a unit Gaussian distribution, enabling the model's generative capabilities.

  The final loss is a weighted sum:

  $$L_{Recon} = L_{MSE} + \lambda \cdot L_{perceptual}$$

  $$L_{total} = L_{Recon} + \beta \cdot L_{KLD}$$
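As a reference for the sampling step mentioned above, here is a minimal sketch of the reparameterization trick. The project's own version lives in the model/system code and may differ in detail.

```python
import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Sample z ~ N(mu, sigma^2) while keeping gradients w.r.t. mu and logvar."""
    std = torch.exp(0.5 * logvar)   # sigma = exp(logvar / 2)
    eps = torch.randn_like(std)     # noise drawn outside the computation graph
    return mu + eps * std           # differentiable sample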
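The perceptual term can be implemented roughly as below. This is an illustrative sketch, not the project's `PerceptualLoss`; the chosen VGG19 layer indices are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class VGGPerceptualLoss(nn.Module):
    """L1 distance between VGG19 feature maps of the prediction and the target."""
    def __init__(self, layer_ids=(3, 8, 17, 26)):            # placeholder layer indices
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()
        for p in vgg.parameters():
            p.requires_grad_(False)                           # frozen feature extractor
        self.vgg = vgg
        self.layer_ids = set(layer_ids)
        # ImageNet statistics expected by the pre-trained VGG19
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        x = (pred - self.mean) / self.std
        y = (target - self.mean) / self.std
        loss = x.new_zeros(())
        for i, layer in enumerate(self.vgg):
            x, y = layer(x), layer(y)
            if i in self.layer_ids:
                loss = loss + F.l1_loss(x, y)                 # compare intermediate features
        return loss
```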
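Finally, a hedged sketch of how the three terms could be combined, mirroring the formulas above. `perceptual_loss_fn`, `beta`, and `lam` are placeholders; the actual implementation and weights live in `src/vae_system.py:VAEColorizer.compute_loss` and `src/config.py`.

```python
import torch
import torch.nn.functional as F

def compute_hybrid_loss(pred, target, mu, logvar, perceptual_loss_fn,
                        beta: float = 1.0, lam: float = 0.1):
    """L_total = (MSE + lam * perceptual) + beta * KLD (illustrative)."""
    mse = F.mse_loss(pred, target)
    perceptual = perceptual_loss_fn(pred, target)   # e.g. an instance of VGGPerceptualLoss
    # KL(q(z|x) || N(0, I)), summed over latent dims and averaged over the batch
    kld = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
    recon = mse + lam * perceptual
    return recon + beta * kld
```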
The model was trained for 50 epochs. The plots show its performance and the quality of the latent space.
- Training & Validation Loss: The model converges steadily, with the validation loss tracking the training loss closely, indicating no significant overfitting.
- Colorization Examples: The model successfully learns to colorize the birds. It captures the general color scheme (e.g., blue for jays, yellow/green for others) and applies it plausibly. The perceptual loss helps maintain sharpness, though some color "bleeding" can be observed.
- Latent Space (t-SNE): A t-SNE plot visualizes the 128-dimensional latent space in 2D. Some clusters emerge, suggesting the encoder learns to group images with similar structural features (e.g., bird pose, background texture) together in the latent space, which is a key goal of representation learning (see the sketch after this list).
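For reference, a minimal sketch of how such a t-SNE plot can be produced from the encoder's latent means. The batch layout and the encoder's return signature are assumptions made for illustration; the project's actual plotting code lives in `evaluate.py` and `src/utils.py`.

```python
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

@torch.no_grad()
def plot_latent_tsne(encoder, dataloader, device="cuda",
                     out_path="outputs/latent_space_tsne.png"):
    """Project the 128-D latent means to 2D with t-SNE (illustrative sketch)."""
    encoder.eval()
    mus, labels = [], []
    for gray, _color, label in dataloader:     # assumed batch layout: (gray, color, class label)
        mu = encoder(gray.to(device))[0]       # assumes the encoder returns (mu, logvar, ...)
        mus.append(mu.cpu().numpy())
        labels.append(label.numpy())
    embedded = TSNE(n_components=2, perplexity=30).fit_transform(np.concatenate(mus))
    plt.figure(figsize=(8, 8))
    plt.scatter(embedded[:, 0], embedded[:, 1], c=np.concatenate(labels), cmap="tab20", s=5)
    plt.title("t-SNE of the latent space")
    plt.savefig(out_path)
```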
```
pytorch-vae-image-colorization/
├── .gitignore
├── LICENSE
├── README.md
├── requirements.txt              # Project dependencies
├── notebooks/
│   └── colorization_demo.ipynb   # Main notebook to run all scripts
├── scripts/
│   ├── setup_dataset.py          # Script to download and set up the Kaggle dataset
│   ├── train.py                  # Main script to train the model
│   └── evaluate.py               # Main script to evaluate the model and generate plots
├── logs/                         # Directory for log files (e.g., training.log)
├── models/                       # Directory for saved .pth models
├── outputs/                      # Directory for saved plots (loss, t-SNE, etc.)
└── src/
    ├── __init__.py               # Makes 'src' a Python package
    ├── config.py                 # Stores all hyperparameters and paths
    ├── dataset.py                # Contains BirdDataset class and DataLoader functions
    ├── model.py                  # Contains Encoder, Decoder, and other nn.Modules
    ├── vae_system.py             # Contains the main VAEColorizer system and PerceptualLoss
    ├── engine.py                 # Contains the training and evaluation loops
    └── utils.py                  # Contains logging, data loading, and plotting helpers
```
- Clone the Repository:

  ```bash
  git clone https://github.com/msmrexe/pytorch-vae-image-colorization.git
  cd pytorch-vae-image-colorization
  ```
- Setup Environment & Dataset:

  Install the dependencies:

  ```bash
  pip install -r requirements.txt
  ```

  Download the Kaggle data. This project uses the "Birds 20 Species" dataset from Kaggle; the `setup_dataset.py` script will download it for you and prompt for your Kaggle username and API key. (You can get your API key by creating a `kaggle.json` file from your Kaggle Account settings page.)

  ```bash
  python setup_dataset.py
  ```

  This will download and extract the data into a `./data/` folder.
- Run Training:

  To train the model, run the `train.py` script. You can adjust hyperparameters using command-line arguments (e.g., `--epochs 100`).

  ```bash
  # Train with default settings (50 epochs)
  python train.py

  # Train for a different number of epochs
  python train.py --epochs 75
  ```

  The script will save the final model to `models/vae_colorizer.pth` and the loss plot to `outputs/training_loss_plot.png`.
- Run Evaluation:

  After training, run the `evaluate.py` script to generate the results grid and t-SNE plot using the saved model.

  ```bash
  python evaluate.py
  ```

  This will save `colorization_results.png` and `latent_space_tsne.png` to the `outputs/` directory.
- Run via Notebook:

  Alternatively, you can run all steps (install, setup, train, evaluate) sequentially inside the `Colorization_Demo.ipynb` notebook.
Feel free to connect or reach out if you have any questions!
- Maryam Rezaee
- GitHub: @msmrexe
- Email: ms.maryamrezaee@gmail.com
This project is licensed under the MIT License. See the LICENSE file for full details.