Achieving 99.26% Accuracy on MNIST with CNN in PyTorch

Amit Chejara

In this article, we’ll build a Convolutional Neural Network (CNN) from scratch using PyTorch to classify handwritten digits from the famous MNIST dataset. We’ll walk through every step — from loading and preprocessing the data, designing the model architecture, setting up the training loop, and evaluating performance — culminating in an impressive 99.26% test accuracy. Whether you’re new to deep learning or looking to refine your PyTorch skills, this hands-on guide will help you understand the key components of a successful CNN implementation, complete with best practices for optimization, evaluation, and debugging.
Follow along with the Kaggle notebook and practice side by side with this tutorial.

Setting Up the Foundation: Importing Essential Libraries

# Let's import the necessary packages!

import math
import numpy as np
import h5py
import matplotlib.pyplot as plt
from matplotlib.pyplot import imread
import scipy
from PIL import Image
import pandas as pd
import torch
from torch import nn
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor
from torch.nn.functional import one_hot

%matplotlib inline
np.random.seed(1)
Every machine learning project begins with importing the right tools, and our MNIST classification task is no exception. This code block brings together all the essential Python libraries we’ll need throughout our implementation. We start with fundamental packages like math and numpy for numerical operations, then add visualization power with matplotlib to help us understand our data. The torch imports form the backbone of our implementation - we're using PyTorch's neural network module (nn) for building our CNN, its vision datasets for easy access to MNIST, and transforms for preprocessing our images. Special mentions go to torchvision.transforms.ToTensor for converting images to PyTorch tensors and one_hot for label encoding. We also set a random seed with np.random.seed(1) to ensure reproducibility of our results, and include %matplotlib inline for smooth Jupyter Notebook visualization. This comprehensive import strategy ensures we have all the necessary components for data loading, model building, training, and evaluation readily available as we progress through our implementation.

Loading the MNIST Dataset: Our Handwritten Digit Collection

train_data = MNIST(root="/kaggle/working/mnist-data", train=True, download=True)
test_data = MNIST(root='/kaggle/working/mnist-data', train=False, download=True)

print("Training set size:", len(train_data))
print("Test set size:", len(test_data))
Now that we have our tools ready, it’s time to bring in the star of our project — the MNIST dataset. These two simple lines of code using PyTorch’s built-in MNIST class do the heavy lifting for us: the first loads 60,000 training images (train=True), while the second prepares 10,000 test images (train=False). The root parameter specifies where to store the data in our Kaggle environment, and download=True automatically fetches the dataset if it's not already present.
The beauty of using torchvision.datasets.MNIST is its seamless integration with PyTorch - it handles all the downloading, decompressing, and organizing for us. When we print the dataset sizes, we see the classic MNIST split: 60,000 training samples and 10,000 test samples. This 6:1 split gives us ample data to train our CNN while reserving a substantial set for evaluation. Notice how we're not yet transforming the data - we'll handle preprocessing in the next steps, keeping our pipeline clean and modular.
# Let's check the type of the images in the dataset
type(train_data[0])
output: tuple
# So, the image and the target are stored in a tuple. Let's check the
# types of the image and the target
type(train_data[0][0]), type(train_data[0][1])
output: (PIL.Image.Image, int)
Before diving into preprocessing, let’s examine how PyTorch structures our MNIST data. When we check type(train_data[0]), we discover it returns a tuple - this reveals an important characteristic of how the dataset is organized. In PyTorch's MNIST implementation, each data sample is stored as a (image, label) pair, where:
The first element is a PIL (Python Imaging Library) Image object containing the 28x28 grayscale pixel data
The second element is an integer representing the digit class (0 through 9)
This tuple structure is fundamental to PyTorch’s dataset handling and will influence how we design our data pipeline. The fact that we’re working with PIL Images initially means we’ll need to convert these to tensors before our CNN can process them — a transformation we’ll handle soon. This quick check serves as a valuable reminder that understanding your data structure is just as important as understanding the algorithms that will process it.
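To make that (image, label) structure concrete, here is a minimal sketch; a blank PIL image stands in for train_data[0][0] so the snippet runs without the dataset downloaded:

```python
import numpy as np
from PIL import Image

# A stand-in with the same structure as train_data[0]:
# an (image, label) tuple of a 28x28 grayscale PIL image and an int
sample = (Image.new("L", (28, 28)), 5)

image, label = sample
arr = np.array(image)          # PIL -> numpy for inspection
print(type(image).__name__)    # Image
print(arr.shape, arr.dtype)    # (28, 28) uint8
print(label)                   # 5
```

Raw MNIST pixels arrive as uint8 values in [0, 255]; keep that in mind when we scale them to [0, 1] shortly.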
# OK, great let's look at some images in the dataset and their labels

fig, axes = plt.subplots(1, 3, figsize=(10, 3))

for ax, i in zip(axes, np.random.randint(0, 60000, 3)):
    image, label = train_data[i]
    ax.imshow(image, cmap='gray')
    ax.set_title(f'Label: {label}')
    ax.axis('off')

plt.tight_layout()
plt.show()
Before training our neural network, let’s meet some of the handwritten digits we’ll be working with! This visualization code block creates a clean three-panel display showing random samples from our training set. Here’s what’s happening under the hood:
We create a 1×3 grid of subplots using plt.subplots(), with each image given ample space (10 inches wide × 3 inches tall)
The loop selects three random images (using np.random.randint) and displays them with:
ax.imshow() renders the grayscale pixel data
set_title() shows us the ground truth label
axis('off') removes distracting axes for cleaner visualization
The cmap='gray' parameter ensures we see authentic black-and-white representations, just as our CNN will process them. These samples give us immediate intuition about our task - we can see the variation in handwriting styles, stroke thickness, and digit positioning that our model will need to handle. Notice how the labels match the visible digits (though sometimes the handwriting is surprisingly ambiguous even to human eyes!), establishing our baseline expectation for model performance.

Preparing Our Data: The Transformation Pipeline

# Let's define a transform to convert images to tensors and scale pixel values to [0, 1]
transform = transforms.Compose([
    transforms.ToTensor()
])
train_data = MNIST(root='/kaggle/working/mnist-data', train=True, transform=transform, download=False)
test_data = MNIST(root='/kaggle/working/mnist-data', train=False, transform=transform, download=False)
# Now let's have a look at the data
train_data[0][0]
output:

tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]])
# (output abridged: the border pixels are 0.0000, while the digit's interior pixels hold values up to 1.0000, confirming the [0, 1] scaling)
# Great, this is a tensor, let's check its shape
train_data[0][0].shape
output: torch.Size([1, 28, 28])
Now we reach a critical step in our pipeline — preprocessing the raw images into a format our CNN can digest. We accomplish this through PyTorch’s transformation system:
1. The Transform: We create a transform pipeline using transforms.Compose, which currently contains just one operation - ToTensor(). This simple but powerful conversion:
Changes our PIL Images into PyTorch tensors (the fundamental data structure for all PyTorch operations)
Automatically scales pixel values from [0, 255] to the range [0, 1]
Adds a channel dimension (transforming 28×28 images into 1×28×28 tensors)
2. Re-loading with Transformations: We reload our datasets, this time applying the transform during loading. Notice we set download=False since we've already cached the data.
3. Verification: When we inspect train_data[0][0], we now see a tensor instead of a PIL Image. This tensor contains our normalized pixel values ready for neural network processing.
This transformation is deceptively simple but fundamentally important — it bridges the gap between human-interpretable images and mathematical representations our CNN can learn from. In more complex applications, we might add additional transformations like data augmentation, but for MNIST, this basic conversion suffices.

Structuring Our Data for Effective Training

# Let's set the cross-validation set aside

from torch.utils.data import random_split


cv_size = int(0.5 * len(test_data)) # 50% for CV
test_size = len(test_data) - cv_size # Remaining 50% for Test

cv_data, test_data = random_split(test_data, [cv_size, test_size])
# let's check the type of the sets
type(train_data)
output: torchvision.datasets.mnist.MNIST
def transform_shape(data_set):

    data_examples = torch.zeros(size=(len(data_set), 28, 28))
    targets = torch.zeros(size=(len(data_set), 1))

    for i, instance in enumerate(data_set):
        data_examples[i] = instance[0].reshape(28, 28)
        targets[i] = instance[1]

    targets = one_hot(targets.long(), 10).squeeze()

    return data_examples, targets
new_train_x, new_train_y = transform_shape(train_data)
new_test_x, new_test_y = transform_shape(test_data)
new_cv_x, new_cv_y = transform_shape(cv_data)
# Let's check their shapes
new_train_x.shape, new_train_y.shape, new_test_x.shape, new_test_y.shape, new_cv_x.shape, new_cv_y.shape
output:

(torch.Size([60000, 28, 28]),
torch.Size([60000, 10]),
torch.Size([5000, 28, 28]),
torch.Size([5000, 10]),
torch.Size([5000, 28, 28]),
torch.Size([5000, 10]))
new_train_y[0]
output: tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0])
With our data loaded and transformed, we now implement two crucial preparation steps: creating a validation set and restructuring our data for efficient training.
1. Creating a Validation Split: We wisely divide our original test set into equal parts for cross-validation (cv_data) and final testing (test_data) using PyTorch’s random_split. This gives us:
5,000 samples for validation (to tune hyperparameters)
5,000 samples for final testing (to evaluate model performance) Preserving half the original test set for final evaluation ensures we don’t overfit to our validation metrics.
2. Data Restructuring with transform_shape(): This custom function performs several important transformations:
Reshapes image tensors from (1,28,28) to (28,28) while maintaining pixel values
Converts labels into one-hot encoded vectors using PyTorch’s one_hot function
Returns separate tensors for features (data_examples) and targets The output shapes reveal our prepared datasets:
Training: 60,000 images (28×28) with 10-class one-hot labels
Validation/Test: 5,000 images each with corresponding labels
Why This Matters:
The validation set helps monitor model performance during training
One-hot encoding enables effective multi-class classification
Proper tensor shapes ensure compatibility with our upcoming CNN architecture
Separating features and labels simplifies batch generation
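A quick sanity check on the encoding step: torch.argmax inverts one_hot, which is also how we will read class predictions off the model's 10-value outputs later.

```python
import torch
from torch.nn.functional import one_hot

labels = torch.tensor([5, 0, 9])
encoded = one_hot(labels, num_classes=10)  # shape (3, 10), a single 1 per row
decoded = encoded.argmax(dim=1)            # recovers the original labels

print(encoded[0])   # tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0])
print(decoded)      # tensor([5, 0, 9])
```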

Building Our Handwritten Digit Classifier: A PyTorch CNN Blueprint

def my_beloved_model():
    """
    Implements the forward propagation for the multiclass classification model:
    ZeroPad2d -> Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d -> Flatten -> LazyLinear -> Softmax

    Arguments:
    None

    Returns:
    model -- PyTorch Sequential container
    """
    model = nn.Sequential(
        nn.ZeroPad2d(2),
        nn.Conv2d(1, 16, 5, 1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Flatten(),
        nn.LazyLinear(out_features=10),
        nn.Softmax(dim=1)
    )
    return model
# Let's use the torchsummary package to get a summary of the model we created above!
!pip install torchsummary

from torchsummary import summary
summary(my_beloved_model(), (1, 28, 28))
output:

----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
ZeroPad2d-1 [-1, 1, 32, 32] 0
Conv2d-2 [-1, 16, 28, 28] 416
BatchNorm2d-3 [-1, 16, 28, 28] 32
ReLU-4 [-1, 16, 28, 28] 0
MaxPool2d-5 [-1, 16, 14, 14] 0
Flatten-6 [-1, 3136] 0
Linear-7 [-1, 10] 31,370
Softmax-8 [-1, 10] 0
================================================================
Total params: 31,818
Trainable params: 31,818
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.34
Params size (MB): 0.12
Estimated Total Size (MB): 0.47
----------------------------------------------------------------
Now we reach the heart of our project — constructing the convolutional neural network that will learn to recognize handwritten digits. Our architecture follows a carefully designed sequence of layers, each serving a specific purpose in feature extraction and classification.
We will be using PyTorch’s Sequential container to build our model—but what exactly is it? Think of Sequential as a simple and organized way to stack layers in a neural network, one after another, like building blocks. Instead of manually defining how data flows between layers, Sequential automatically connects them in the order you specify. This makes it perfect for straightforward architectures where the output of one layer directly feeds into the next. For example, in our CNN, we’ll stack layers like convolution, batch normalization, and activation functions in sequence—just like a pipeline. It’s beginner-friendly, reduces boilerplate code, and keeps the model definition clean and readable. Now, let’s see how we use it to construct our digit classifier!
To learn more, see the nn.Sequential entry in the official PyTorch documentation.
The Architecture Breakdown:
Zero Padding (nn.ZeroPad2d): Adds 2 pixels of padding around our 28×28 images, transforming them to 32×32. This preserves edge features during convolution and helps maintain spatial dimensions.
Convolutional Layer (nn.Conv2d): The workhorse of our CNN, using 16 filters of size 5×5 with stride 1. This layer will learn to detect basic visual patterns like edges and curves.
Batch Normalization (nn.BatchNorm2d): Stabilizes training by normalizing the outputs from our convolutional layer, leading to faster convergence and better performance.
ReLU Activation: Introduces non-linearity, allowing our network to learn complex patterns. The simple max(0,x) operation brings our features to life!
Max Pooling (nn.MaxPool2d): Downsamples our feature maps by taking maximum values over 2×2 windows, reducing spatial dimensions while preserving important features.
Flatten Layer: Prepares our 3D feature maps for the dense layer by converting them to a 1D vector (3136 elements in this case).
Lazy Linear Layer: A clever PyTorch feature that automatically calculates input dimensions. This fully-connected layer produces our final 10-class outputs.
Softmax Activation: Converts logits to probabilities, giving us interpretable confidence scores for each digit class.
Why This Design Works: The model summary reveals a compact yet powerful architecture with:
Only 31,818 trainable parameters (efficient for MNIST)
Progressive dimensionality reduction (32×32 → 28×28 → 14×14 → 10)
Balanced feature extraction and classification capabilities
Automatic input dimension handling with LazyLinear
This architecture exemplifies good CNN design principles while remaining simple enough to train quickly on MNIST. The torchsummary output gives us valuable insights into how our data transforms through each layer, helping debug and understand our network’s information flow.
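The progressive dimensionality reduction above can be verified by hand with the standard output-size formula, floor((n + 2p - k) / s) + 1. A small sketch (the conv_out helper is ours, purely for checking, not part of the model):

```python
def conv_out(n, k, s=1, p=0):
    """Output size of a conv/pool layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * p - k) // s + 1

n = 28
n = n + 2 * 2               # ZeroPad2d(2):        28 -> 32
n = conv_out(n, k=5, s=1)   # Conv2d(1, 16, 5, 1): 32 -> 28
n = conv_out(n, k=2, s=2)   # MaxPool2d(2):        28 -> 14

print(n)            # 14
print(16 * n * n)   # 3136 features flow into the LazyLinear layer
```

This matches the torchsummary output exactly: 16 channels of 14×14 feature maps flatten to the 3136 inputs that LazyLinear infers automatically.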

Testing Our Untrained Model: A Crucial Sanity Check

# Make device agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
device

output: 'cpu'
# Now we've got a model, let's see what happens when we pass some data through it.

# Instantiate the model
model = my_beloved_model()

# Ensure the model is on the correct device
model = model.to(device)

# Make predictions
untrained_preds = model(new_cv_x[:5].unsqueeze(dim = 1).to(device))
print(f"Length of predictions: {len(untrained_preds)}, Shape: {untrained_preds.shape}")
print(f"Length of test samples: {len(new_cv_y[:5])}, Shape: {new_cv_y[:5].shape}")
print(f"\nFirst 5 predictions:\n{torch.round(untrained_preds[:5])}")
print(f"\nFirst 5 test labels:\n{new_cv_y[:5]}")
output:

Length of predictions: 5, Shape: torch.Size([5, 10])
Length of test samples: 5, Shape: torch.Size([5, 10])

First 5 predictions:
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<RoundBackward0>)

First 5 test labels:
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]])
Before diving into training, we perform an essential verification step to ensure our model architecture behaves as expected. This “dry run” helps catch potential issues early and confirms our data pipeline is properly connected to our model.
Device-Agnostic Setup: We first implement best practices by making our code device-agnostic. The line device = "cuda" if torch.cuda.is_available() else "cpu" automatically selects GPU acceleration if available, falling back to CPU otherwise. This makes our code more portable across different hardware setups.
The Verification Process:
We instantiate our model and move it to the selected device
Pass a small batch of 5 validation images through the untrained network:
Note the unsqueeze(dim=1) operation - this adds the required channel dimension (from 28×28 to 1×28×28)
The model outputs 10-class probabilities for each sample
Interpreting the Output: The comparison between predictions and actual labels reveals:
The untrained model produces uniform zeros after rounding, since each softmax output hovers near 0.1 and rounds down to 0 - exactly what we expect from random initialization
The output shapes match perfectly ([5, 10] for both predictions and labels)
Each prediction contains 10 values corresponding to class probabilities
The actual labels show the correct one-hot encoded representations
Why This Matters:
Confirms our tensor shapes flow correctly through the network
Verifies our device placement works as intended
Demonstrates the model’s initial random state before learning
Helps debug dimension mismatches early (a common pain point in deep learning)
Validates our preprocessing and data loading pipeline
This simple but crucial step acts as an architectural smoke test, ensuring all components are properly connected before we invest time in training. Seeing those uniform zeros might look disappointing now, but it sets the stage for the exciting transformation we’ll witness during training!
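One more property worth checking on untrained outputs: because the network ends in nn.Softmax, every output row should already be a valid probability distribution. A sketch with a stand-in model of the same input and output shape (not our CNN, just the smallest Softmax-terminated network that accepts 1×28×28 tensors):

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for my_beloved_model(): untrained, ends in Softmax like ours
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax(dim=1))

probs = model(torch.randn(5, 1, 28, 28))
print(probs.shape)        # torch.Size([5, 10])
print(probs.sum(dim=1))   # each row sums to 1 (up to float error) by construction
```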

Configuring Our Model’s Learning Process 🚀

# We will use CrossEntropyLoss since we are performing softmax regression,
# and we will use the Adam optimizer.

loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params = model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)
!pip install torchmetrics
from torchmetrics.classification import Accuracy, Precision, Recall
num_classes = 10  # 10-classes

train_accuracy = Accuracy(num_classes=num_classes, task="multiclass")
train_precision = Precision(num_classes=num_classes, task="multiclass")
train_recall = Recall(num_classes=num_classes, task="multiclass")

cv_accuracy = Accuracy(num_classes=num_classes, task="multiclass")
cv_precision = Precision(num_classes=num_classes, task="multiclass")
cv_recall = Recall(num_classes=num_classes, task="multiclass")
# Let's only use 5000 examples for training and convert the examples into a dataset object so that we can create batches.
from torch.utils.data import TensorDataset, DataLoader

new_train_x, new_train_y = new_train_x[:5000].to(device), (new_train_y[:5000].type(torch.float)).to(device)
new_cv_x, new_cv_y = new_cv_x[: 900].to(device), (new_cv_y[: 900].type(torch.float)).to(device)

dataset = TensorDataset(new_train_x, new_train_y)

# Define batch size
batch_size = 32

# Create DataLoader
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
With our model architecture verified, we now establish the crucial components that will guide its learning:
1. Loss Function & Optimizer Selection
CrossEntropyLoss: The perfect choice for our multi-class classification task, combining softmax activation and negative log likelihood loss in a numerically stable way. This will measure how far our predictions are from the true digit labels.
Adam Optimizer: Our model's "guide", combining the benefits of AdaGrad and RMSProp with:
Learning rate of 0.01 for balanced updates
Automatic momentum adaptation
Per-parameter learning rates
LinearLR Scheduler: Linearly scales the learning rate between a start and end factor each time scheduler.step() is called, smoothing the early phase of training
2. Performance Metrics We install torchmetrics to track three key indicators:
Accuracy: Overall correctness of predictions
Precision: Measure of prediction reliability
Recall: Ability to find all relevant cases Each metric is configured for our 10-class problem, tracking both training and validation performance separately.
3. Efficient Data Handling We optimize our training process by:
Working with a subset (5,000 training, 900 validation samples) for faster iteration
Creating a TensorDataset for clean data management
Implementing a DataLoader with:
Batch size of 32 for stable gradient updates
Shuffling enabled to prevent order bias
Ensuring all tensors are on the correct device (CPU/GPU)
Why This Setup Works:
CrossEntropyLoss is mathematically ideal for classification
Adam optimizer automatically adapts to problem characteristics
Batch training provides computational efficiency
Comprehensive metrics give us multiple performance perspectives
The reduced dataset size allows for quick experimentation
This configuration balances theoretical soundness with practical efficiency, giving our model everything it needs to learn effectively while providing us with clear insights into its progress.
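To see what LinearLR actually does, here is a small self-contained sketch. Note two things: the schedule only takes effect when scheduler.step() is called (typically once per epoch), and with its default arguments it warms the rate *up* from lr/3 to lr over the first 5 steps rather than decaying it:

```python
import torch
from torch import nn

params = [nn.Parameter(torch.zeros(1))]
opt = torch.optim.Adam(params, lr=0.01)

# Defaults written out explicitly: the factor ramps linearly 1/3 -> 1.0 over 5 steps
sched = torch.optim.lr_scheduler.LinearLR(
    opt, start_factor=1 / 3, end_factor=1.0, total_iters=5)

lrs = []
for _ in range(6):
    opt.step()                              # optimizer step (per batch/epoch)
    lrs.append(opt.param_groups[0]["lr"])   # record the current learning rate
    sched.step()                            # advance the schedule

print([round(lr, 5) for lr in lrs])
# [0.00333, 0.00467, 0.006, 0.00733, 0.00867, 0.01]
```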

Training Our CNN: From Random Guessing to 99.2% Accuracy 🎯

def train_model(my_model, epoch_nums, optimizer, loss, data_loader, new_cv_x, new_cv_y):

    # Set the number of epochs
    epochs = epoch_nums

    # Build training and evaluation loop
    for epoch in range(epochs):
        ### Training
        my_model.train()

        epoch_train_loss = 0
        epoch_cv_loss = 0

        for batch in data_loader:
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)

            # 1. Forward pass (the final Softmax layer outputs class probabilities)
            y_probs = my_model(inputs.unsqueeze(dim=1)).squeeze()  # squeeze removes the extra `1` dimension; this won't work unless model and data are on the same device
            y_pred = torch.round(y_probs)  # turn probabilities -> hard 0/1 predictions

            # 2. Calculate the loss
            train_loss = loss(y_probs, targets)
            epoch_train_loss = train_loss

            # 3. Update metrics
            train_accuracy.update(y_pred, targets)
            train_precision.update(y_pred, targets)
            train_recall.update(y_pred, targets)

            # 4. Optimizer zero grad
            optimizer.zero_grad()

            # 5. Loss backwards
            train_loss.backward()

            # 6. Optimizer step
            optimizer.step()

        ### Validation
        my_model.eval()
        with torch.inference_mode():
            # 1. Forward pass
            cv_probs = my_model(new_cv_x.unsqueeze(dim=1)).squeeze()
            cv_pred = torch.round(cv_probs)
            # 2. Calculate loss/accuracy
            cv_loss = loss(cv_probs, new_cv_y)
            epoch_cv_loss = cv_loss
            cv_accuracy.update(cv_pred, new_cv_y)
            cv_precision.update(cv_pred, new_cv_y)
            cv_recall.update(cv_pred, new_cv_y)

        # Compute metrics at the end of the epoch
        epoch_train_accuracy = train_accuracy.compute()
        epoch_train_precision = train_precision.compute()
        epoch_train_recall = train_recall.compute()

        epoch_cv_accuracy = cv_accuracy.compute()
        epoch_cv_precision = cv_precision.compute()
        epoch_cv_recall = cv_recall.compute()

        # Print out what's happening every epoch
        print(f"Epoch: {epoch} | Train Loss: {epoch_train_loss:.4f} | CV loss: {epoch_cv_loss:.4f} \nTrain Accuracy: {epoch_train_accuracy.item():.4f}, Train Precision: {epoch_train_precision.item():.4f}, Train Recall: {epoch_train_recall.item():.4f}\nCV Accuracy: {epoch_cv_accuracy:.4f}, CV Precision: {epoch_cv_precision:.4f}, CV Recall: {epoch_cv_recall:.4f}\n\n")

        # Reset metrics for the next epoch
        train_accuracy.reset()
        train_precision.reset()
        train_recall.reset()

        cv_accuracy.reset()
        cv_precision.reset()
        cv_recall.reset()

train_model(model, 10, optimizer, loss, data_loader, new_cv_x, new_cv_y)
output:

Epoch: 0 | Train Loss: 1.4652 | CV loss: 1.6278
Train Accuracy: 0.9577, Train Precision: 0.9577, Train Recall: 0.9577
CV Accuracy: 0.9591, CV Precision: 0.9591, CV Recall: 0.9591


Epoch: 1 | Train Loss: 1.4615 | CV loss: 1.5263
Train Accuracy: 0.9823, Train Precision: 0.9823, Train Recall: 0.9823
CV Accuracy: 0.9811, CV Precision: 0.9811, CV Recall: 0.9811


Epoch: 2 | Train Loss: 1.6137 | CV loss: 1.5075
Train Accuracy: 0.9911, Train Precision: 0.9911, Train Recall: 0.9911
CV Accuracy: 0.9901, CV Precision: 0.9901, CV Recall: 0.9901


Epoch: 3 | Train Loss: 1.4621 | CV loss: 1.5180
Train Accuracy: 0.9934, Train Precision: 0.9934, Train Recall: 0.9934
CV Accuracy: 0.9906, CV Precision: 0.9906, CV Recall: 0.9906


Epoch: 4 | Train Loss: 1.4618 | CV loss: 1.4967
Train Accuracy: 0.9937, Train Precision: 0.9937, Train Recall: 0.9937
CV Accuracy: 0.9912, CV Precision: 0.9912, CV Recall: 0.9912


Epoch: 5 | Train Loss: 1.4697 | CV loss: 1.5020
Train Accuracy: 0.9957, Train Precision: 0.9957, Train Recall: 0.9957
CV Accuracy: 0.9929, CV Precision: 0.9929, CV Recall: 0.9929


Epoch: 6 | Train Loss: 1.4612 | CV loss: 1.4953
Train Accuracy: 0.9955, Train Precision: 0.9955, Train Recall: 0.9955
CV Accuracy: 0.9922, CV Precision: 0.9922, CV Recall: 0.9922


Epoch: 7 | Train Loss: 1.4633 | CV loss: 1.5041
Train Accuracy: 0.9959, Train Precision: 0.9959, Train Recall: 0.9959
CV Accuracy: 0.9932, CV Precision: 0.9932, CV Recall: 0.9932


Epoch: 8 | Train Loss: 1.4612 | CV loss: 1.4981
Train Accuracy: 0.9968, Train Precision: 0.9968, Train Recall: 0.9968
CV Accuracy: 0.9929, CV Precision: 0.9929, CV Recall: 0.9929


Epoch: 9 | Train Loss: 1.4612 | CV loss: 1.4978
Train Accuracy: 0.9967, Train Precision: 0.9967, Train Recall: 0.9967
CV Accuracy: 0.9932, CV Precision: 0.9932, CV Recall: 0.9932
# Great! We got 99.32% CV accuracy, now let's check the final test set accuracy!

test_accuracy = Accuracy(task="multiclass", num_classes=num_classes)

model.eval()
with torch.inference_mode():
    test_logits = model(new_test_x.unsqueeze(dim=1)).squeeze()

test_pred = test_logits.argmax(dim=1)  # pick the highest-scoring class per image
test_accuracy.update(test_pred, new_test_y)

final_accuracy = test_accuracy.compute()
final_accuracy.item() * 100
output: 99.26000237464905
Our model’s journey from complete ignorance to near-perfect digit recognition is nothing short of remarkable. Let’s break down the training process and its outstanding results:
The Training Loop Explained:
1. Epoch Management: We train for 10 complete passes through the dataset, with each epoch showing progressively better metrics.
2. Batch Processing: For each batch of 32 images, we:
unsqueeze the input to add the channel dimension
compute logits (raw predictions) through a forward pass
calculate the loss with cross-entropy (which combines softmax and negative log-likelihood)
backpropagate the errors and update the weights via the Adam optimizer
3. Validation Phase: After each training epoch, we evaluate on unseen validation data inside torch.inference_mode() for better performance.
4. Metric Tracking: Comprehensive metrics (accuracy, precision, recall) are computed and reset every epoch.
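The per-batch steps above can be sketched in a few lines. This is a minimal placeholder setup, not the tutorial's actual CNN: the `model`, `optimizer`, `loss_fn`, and fake batch below are assumptions chosen only to show the mechanics.

```python
import torch
from torch import nn

# Hypothetical stand-ins for the tutorial's objects, just to show the
# per-batch mechanics: a tiny linear model, Adam, and cross-entropy loss.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

batch_x = torch.rand(32, 28, 28)       # fake batch of 32 grayscale 28x28 images
batch_y = torch.randint(0, 10, (32,))  # fake integer labels

model.train()                               # training mode
logits = model(batch_x.unsqueeze(dim=1))    # add channel dim -> (32, 1, 28, 28)
batch_loss = loss_fn(logits, batch_y)       # softmax + negative log-likelihood
optimizer.zero_grad()                       # clear gradients from the last step
batch_loss.backward()                       # backpropagate errors
optimizer.step()                            # Adam weight update
```

Note that `nn.CrossEntropyLoss` expects raw logits, which is why no softmax layer appears in the forward pass.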
Key Implementation Details:
Proper model mode management (train() vs eval())
Gradient handling with zero_grad() and backward()
Device-agnostic tensor operations
Clean metric computation and resetting
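The eval-mode and `torch.inference_mode()` pattern listed above can be sketched as follows; again, `model`, `cv_x`, and `cv_y` are made-up stand-ins rather than the tutorial's actual tensors.

```python
import torch
from torch import nn

# Placeholder model and validation data, for illustration only.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
cv_x = torch.rand(64, 28, 28)
cv_y = torch.randint(0, 10, (64,))

model.eval()                                  # disable dropout/batch-norm updates
with torch.inference_mode():                  # skip autograd bookkeeping entirely
    cv_logits = model(cv_x.unsqueeze(dim=1))
    cv_pred = cv_logits.argmax(dim=1)         # highest-scoring class per image
    cv_acc = (cv_pred == cv_y).float().mean() # plain accuracy

print(cv_pred.requires_grad)  # False: no graph is built under inference_mode
```

`torch.inference_mode()` is slightly faster than `torch.no_grad()` because the resulting tensors can never be re-enrolled in autograd.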
The Results Speak for Themselves: Our training metrics show rapid convergence to near-perfect performance:
Epoch 0 train accuracy: 95.77% (already strong!)
By epoch 2: 99.01% validation accuracy
Final validation accuracy: 99.32%
Test set accuracy: 99.26% (evaluated on completely held-out data)
Why This Matters:
Demonstrates the power of CNNs for image recognition
Shows proper training practices lead to excellent results
The small gap between train, validation, and test scores indicates minimal overfitting
Accuracy, precision, and recall coincide exactly; with micro-averaged multiclass metrics this is expected, since every misclassification counts as both a false positive and a false negative
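The three-way equality among the printed metrics is structural rather than a sign of luck: under micro-averaging in a single-label multiclass problem, each error is simultaneously one false positive (for the predicted class) and one false negative (for the true class). A quick check with made-up toy labels, using only the standard library:

```python
# Toy predictions and labels, invented purely for illustration.
preds  = [0, 1, 2, 2, 1, 0, 2, 1]
labels = [0, 1, 2, 1, 1, 0, 2, 2]

classes = set(labels) | set(preds)
tp = sum(p == y for p, y in zip(preds, labels))  # micro TP = total correct
fp = sum(p == c and y != c for c in classes for p, y in zip(preds, labels))
fn = sum(y == c and p != c for c in classes for p, y in zip(preds, labels))

accuracy = tp / len(labels)
micro_precision = tp / (tp + fp)
micro_recall = tp / (tp + fn)
print(accuracy, micro_precision, micro_recall)  # all three are equal: 0.75
```

So identical accuracy, precision, and recall values are guaranteed by the averaging scheme; per-class (macro-averaged) metrics would be needed to reveal imbalances between digits.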
Final Evaluation: When we finally unleash our model on the untouched test set — data it has never seen during training or validation — it achieves an outstanding 99.26% accuracy, proving its ability to generalize to new handwritten digits.
This performance puts our model among the top tier of MNIST classifiers, all with a relatively simple architecture and minimal training time. The success validates our entire pipeline from data loading through model architecture to training configuration.
At the heart of Machine Learning and Deep Learning lies a strong foundation in Linear Algebra. If you’re enjoying this series and want to take your understanding to the next level, why not dive into a structured learning path? Coursera offers top-tier Machine Learning and Deep Learning courses designed by experts, helping you bridge the gap between theory and real-world applications. Whether you’re an aspiring data scientist, AI enthusiast, or just love math, these courses provide hands-on projects and industry-relevant knowledge to accelerate your journey. Check them out here and start mastering ML & DL today!

Posted Apr 18, 2025

Built a custom CNN architecture in PyTorch, achieving 99.26% accuracy on the MNIST dataset.