This project implements a Multi-Layer Perceptron (MLP) for classifying images from the CIFAR-10 dataset. It includes data augmentation, feature extraction using a pre-trained ResNet18, and MLP training from scratch with backpropagation. The project is part of CS776 Computer Vision Assignment 1.
- Dataset: CIFAR-10 (downloaded and placed in the `data/` folder).
- Feature Extraction: Uses ResNet18 to extract 512-dimensional features.
- Model: Single hidden layer MLP (64 neurons, ReLU activation) with softmax output and categorical cross-entropy loss.
- Augmentation: Applies rotations, flips, brightness adjustments, and Gaussian noise.
- Training: Mini-batch gradient descent; trained on original and augmented data.
- Derivations: Backpropagation equations are detailed in `readme.pdf`.
- Results: ~67-69% test accuracy (original: 67.9%, augmented: 68.9%).
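
The augmentations listed above could look roughly like the sketch below. It is a NumPy-only illustration, not the notebook's code: the function names, the brightness range, and the noise level are assumptions, and rotation is restricted to multiples of 90 degrees so that no OpenCV call is needed (the project itself may rotate by arbitrary angles).

```python
import numpy as np

def rotate90(image, k=1):
    """Rotate an HxWx3 image by k * 90 degrees (a stand-in for arbitrary-angle rotation)."""
    return np.rot90(image, k=k).copy()

def hflip(image):
    """Mirror the image left to right."""
    return image[:, ::-1].copy()

def adjust_brightness(image, factor):
    """Scale pixel intensities by `factor`, clipping back to the uint8 range."""
    scaled = image.astype(np.float32) * factor
    return np.clip(scaled, 0, 255).astype(np.uint8)

def add_gaussian_noise(image, sigma, rng):
    """Add zero-mean Gaussian noise with standard deviation `sigma` (in pixel units)."""
    noisy = image.astype(np.float32) + rng.normal(0.0, sigma, size=image.shape)
    return np.clip(noisy, 0, 255).astype(np.uint8)

def augment(image, rng):
    """Apply one randomly chosen transformation; all ranges here are assumed values."""
    choice = int(rng.integers(4))
    if choice == 0:
        return rotate90(image, k=int(rng.integers(1, 4)))
    if choice == 1:
        return hflip(image)
    if choice == 2:
        return adjust_brightness(image, rng.uniform(0.7, 1.3))
    return add_gaussian_noise(image, 10.0, rng)
```

Passing a `numpy.random.Generator` explicitly keeps the augmented dataset reproducible across runs, which helps when comparing the original and augmented results.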
- Python 3.x
- Libraries: NumPy, OpenCV, Matplotlib, PyTorch with torchvision (for feature extraction), plus `pickle` and `random` from the Python standard library
Only the third-party packages need to be installed, e.g. `pip install numpy opencv-python matplotlib torch torchvision`.
The CIFAR-10 dataset consists of:
- 50,000 training images
- 10,000 test images
- 10 classes

Download the dataset from the official site and place it in the `data/` folder.
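
As a rough illustration of reading one downloaded batch, the sketch below assumes the Python version of the official archive, where each batch file is a pickled dictionary with `data` (one 3072-value row per image, stored channel by channel) and `labels`. The function name `load_batch` and the exact file location under `data/` are assumptions, not taken from the notebook.

```python
import pickle
import numpy as np

def load_batch(path):
    """Load one CIFAR-10 batch file and return (images, labels).

    images: uint8 array of shape (N, 32, 32, 3)
    labels: int array of shape (N,)
    """
    with open(path, "rb") as f:
        # The official files were pickled under Python 2, hence encoding="bytes"
        batch = pickle.load(f, encoding="bytes")
    data = np.asarray(batch[b"data"], dtype=np.uint8)
    # Each row holds 1024 red, then 1024 green, then 1024 blue values
    images = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    labels = np.asarray(batch[b"labels"])
    return images, labels
```

A call such as `load_batch("data/data_batch_1")` would then return the first training batch, provided the file sits directly in `data/`.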
Run the code via Jupyter Notebook:
- Open `main.ipynb`.
- Execute cells sequentially:
  - Section 1: Image transformations (demo on sample images).
  - Section 2: Create and load augmented dataset.
  - Section 3: Feature extraction (saves to `data/`).
  - Section 5: Train MLP on original and augmented data (saves weights).
  - Section 6: Evaluate on test set using loaded weights.
- To predict: Use the `predict(W1, b1, W2, b2, X)` function after loading weights.
- Paths are relative to `data/`; adjust the `path` variable if needed.
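
For orientation, the training and prediction steps above can be sketched in plain NumPy as follows. Only the `predict(W1, b1, W2, b2, X)` signature, the 512-dimensional input, the 64-unit ReLU hidden layer, and the softmax with cross-entropy come from this README. The weight layout (inputs by outputs), the initialization, the learning rate, and the helper names `init_params`, `forward`, and `train_step` are assumptions, so the notebook's actual code may differ.

```python
import numpy as np

def init_params(n_in=512, n_hidden=64, n_out=10, seed=0):
    """He-style random initialization (an assumed choice)."""
    rng = np.random.default_rng(seed)
    W1 = rng.normal(0.0, np.sqrt(2.0 / n_in), size=(n_in, n_hidden))
    W2 = rng.normal(0.0, np.sqrt(2.0 / n_hidden), size=(n_hidden, n_out))
    return W1, np.zeros(n_hidden), W2, np.zeros(n_out)

def forward(W1, b1, W2, b2, X):
    """Return hidden activations and softmax probabilities for a batch X."""
    A1 = np.maximum(X @ W1 + b1, 0.0)              # ReLU hidden layer
    Z2 = A1 @ W2 + b2
    Z2 = Z2 - Z2.max(axis=1, keepdims=True)        # shift for numerical stability
    expZ = np.exp(Z2)
    return A1, expZ / expZ.sum(axis=1, keepdims=True)

def train_step(W1, b1, W2, b2, X, y, lr=0.01):
    """One mini-batch gradient descent step; y holds integer class labels."""
    n = X.shape[0]
    A1, P = forward(W1, b1, W2, b2, X)
    loss = -np.log(P[np.arange(n), y] + 1e-12).mean()
    dZ2 = P.copy()
    dZ2[np.arange(n), y] -= 1.0                    # softmax + cross-entropy gradient
    dZ2 /= n
    dW2, db2 = A1.T @ dZ2, dZ2.sum(axis=0)
    dZ1 = (dZ2 @ W2.T) * (A1 > 0)                  # backpropagate through ReLU
    dW1, db1 = X.T @ dZ1, dZ1.sum(axis=0)
    params = (W1 - lr * dW1, b1 - lr * db1, W2 - lr * dW2, b2 - lr * db2)
    return params, loss

def predict(W1, b1, W2, b2, X):
    """Return the most probable class index for each row of X."""
    _, P = forward(W1, b1, W2, b2, X)
    return P.argmax(axis=1)
```

With features `X_test` and labels `y_test` as NumPy arrays, test accuracy would be `(predict(W1, b1, W2, b2, X_test) == y_test).mean()`.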
- `data/`: Dataset batches, augmented data, extracted features.
- `feature_extractor.py`: ResNet18 feature extractor.
- `main.ipynb`: Main implementation (data loading, augmentation, MLP, training).
- `readme.pdf`: Assignment details, code structure, training params, backprop derivations.
- `Weights_original` / `Weights_augmented`: Trained model weights.
For full details, refer to readme.pdf.