A project demonstrating how to train a Convolutional Neural Network (CNN) on the MNIST dataset to recognize handwritten digits.
Key features:
- Convolutional Neural Network that classifies digits (0–9) in 28×28 grayscale images.
- Batch Normalization and Dropout to help stabilize training and reduce overfitting.
- Data Augmentation (random rotations, shifts, scaling, and shear) to improve generalization.
- Automatic checkpointing of the best model during training.
- Confusion matrix generation for performance analysis.
- Streamlit canvas for drawing digits and obtaining real-time predictions.
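
To make the 28×28 input size concrete, the sketch below traces how the spatial size changes through a small convolutional stack, using the standard output-size formula floor((n + 2p - k) / s) + 1. The layer configuration in `LAYERS` is hypothetical and is not taken from CNN.py; it only illustrates the arithmetic that determines how many features reach the first fully connected layer.

```python
def conv_output_size(n, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling layer."""
    return (n + 2 * padding - kernel) // stride + 1


# Hypothetical stack (an assumption, not the layers defined in CNN.py):
# (name, kernel, stride, padding, out_channels or None for pooling)
LAYERS = [
    ("conv1", 3, 1, 1, 32),
    ("pool1", 2, 2, 0, None),
    ("conv2", 3, 1, 1, 64),
    ("pool2", 2, 2, 0, None),
]


def trace_shapes(size=28, channels=1, layers=LAYERS):
    """Return (name, channels, size) after each layer, starting from the input."""
    shapes = [("input", channels, size)]
    for name, kernel, stride, padding, out_channels in layers:
        size = conv_output_size(size, kernel, stride, padding)
        if out_channels is not None:  # pooling keeps the channel count
            channels = out_channels
        shapes.append((name, channels, size))
    return shapes


if __name__ == "__main__":
    for name, c, s in trace_shapes():
        print(f"{name}: {c} x {s} x {s}")
    _, c, s = trace_shapes()[-1]
    print("flattened features:", c * s * s)
```

With these assumed layers, two 2×2 pooling steps reduce 28 to 14 and then to 7, so the classifier head would receive 64 × 7 × 7 values. A different kernel size or padding changes that number, which is why the first linear layer has to match the convolutional stack.
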
1. Clone this repository or download it as a ZIP:
git clone https://github.com/YourUsername/HandwrittenDigitsAI.git
cd HandwrittenDigitsAI
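2. (Optional) Create and activate a virtual environment (the activation command shown is for Windows; on macOS or Linux, run source venv/bin/activate instead):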
python -m venv venv
venv\Scripts\activate
3. Install the required packages:
pip install -r requirements.txt
To train the model:
- Confirm that a models/ folder exists (the script will create it if needed).
- Run:
python train.py
- The script downloads the MNIST dataset to data/ (if not already present) and begins training.
- After each epoch, the script prints the loss and the accuracy on the test set.
- The best model is automatically saved to models/ under a name that includes the best accuracy and a timestamp.
- A confusion matrix is generated and saved to confusion_matrices/.
To run the web app:
- Ensure that a trained model (e.g., mnist_cnn_xxx.pth) is present in the models/ folder.
- Launch the app:
streamlit run app.py
- A local web page opens in your browser.
- Draw a digit on the black canvas using white strokes.
- The app resizes and normalizes your drawing, then displays the predicted digit alongside a probability distribution.
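
The resize, normalize, and predict steps can be illustrated without any dependencies. The sketch below is not the code in app.py: it swaps in simple block averaging for whatever image-resizing routine the app uses, assumes a 280×280 canvas, and uses the commonly quoted MNIST mean and standard deviation (0.1307 and 0.3081), which may differ from the values used in training. The logits are made up to show how a softmax produces the probability distribution.

```python
import math

# Commonly quoted MNIST statistics; assumed here, not read from train.py.
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081


def downscale(image, factor):
    """Shrink a square 2D list of 0-255 pixels by averaging factor x factor blocks."""
    size = len(image) // factor
    return [
        [
            sum(image[r * factor + dr][c * factor + dc]
                for dr in range(factor) for dc in range(factor)) / factor ** 2
            for c in range(size)
        ]
        for r in range(size)
    ]


def normalize(image, mean=MNIST_MEAN, std=MNIST_STD):
    """Scale pixels to [0, 1], then standardize with the given mean and std."""
    return [[(p / 255.0 - mean) / std for p in row] for row in image]


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    canvas = [[0] * 280 for _ in range(280)]   # black canvas (size is an assumption)
    for r in range(40, 240):                   # one white vertical stroke
        for c in range(135, 145):
            canvas[r][c] = 255
    small = normalize(downscale(canvas, 10))   # 28 x 28 standardized input
    print(len(small), len(small[0]))
    logits = [0.1, 4.0, 0.3, 0.2, 0.0, 0.1, 0.2, 1.5, 0.4, 0.1]  # made-up model output
    probs = softmax(logits)
    print("predicted digit:", probs.index(max(probs)))
```

The predicted digit is the index of the largest probability; showing the full distribution also reveals when the model is torn between two similar digits.
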
HandwrittenDigitsAI/
├── CNN.py # Defines the CNN model class
├── train.py # Script to train/evaluate the model and save the best checkpoint
├── app.py # Streamlit application for real-time digit recognition
├── models/ # Folder storing saved .pth model files
├── confusion_matrices/ # Folder for generated confusion matrix images
├── data/ # MNIST dataset is downloaded here
└── README.md # This README file
