Tags: pytorch, image-classification, computer-vision, deep-learning, cifar-10, cnn, python
Train a CIFAR-10 Image Classifier with PyTorch
A hands-on guide to training a simple Convolutional Neural Network (CNN) for image classification on the CIFAR-10 dataset. Learn the fundamental workflow of loading data, defining a model, training, and evaluating with PyTorch.
Level: beginner | Time: 30 min | 5 steps
The play
- Load and Normalize CIFAR-10 Data: First, import PyTorch and torchvision. Then define a series of transforms that convert the images to tensors and normalize them, and use torchvision's built-in `CIFAR10` dataset class to download the training and test sets and wrap them in `DataLoader`s.
- Define the Convolutional Neural Network (CNN): Create a Python class that inherits from `torch.nn.Module` to define your network architecture. This tutorial uses a simple CNN with two convolutional layers and three fully connected layers.
- Define Loss Function and Optimizer: For multi-class classification, cross-entropy loss (`nn.CrossEntropyLoss`) is a standard choice. For optimization, use stochastic gradient descent (SGD) with momentum to update the model's weights during training.
- Train the Network: Loop over the training data for a set number of epochs. For each mini-batch, get the inputs and labels, zero the parameter gradients, run a forward pass, compute the loss, run a backward pass to compute the gradients, and update the weights.
- Test and Evaluate the Model: Evaluate the trained model on the test set. Pass the test images through the network, take the class with the highest score as the prediction, and compare it with the ground-truth labels to compute the overall accuracy.
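Step 3 pairs a network whose last layer is a plain `nn.Linear` with `nn.CrossEntropyLoss`. That works because the loss applies log-softmax itself, so it expects raw scores (logits) and integer class indices rather than probabilities or one-hot vectors. A minimal check, using random logits and made-up labels in place of real network output:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in for a batch of 4 network outputs: raw logits for 10 classes,
# shaped like the output of a final nn.Linear(84, 10) layer.
torch.manual_seed(0)
logits = torch.randn(4, 10)
labels = torch.tensor([3, 0, 9, 1])  # class indices, not one-hot vectors

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)

# Same quantity by hand: log-softmax over classes, pick the log-probability
# of the correct class for each row, negate, and average over the batch.
manual = -F.log_softmax(logits, dim=1)[torch.arange(4), labels].mean()

print(loss.item(), manual.item())
print(torch.allclose(loss, manual))  # True
```

Because the softmax is already built into the loss, adding one at the end of the model's `forward` would apply softmax twice.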
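The momentum in step 3 keeps a running "velocity" of past gradients. With the default `dampening=0` and `nesterov=False`, `torch.optim.SGD` updates each parameter as `v = momentum * v + grad`, then `p = p - lr * v`. The sketch below checks the optimizer against a hand-written version of that rule; the toy loss `(w * c).sum()` is an assumption chosen only because its gradient is the constant `c`:

```python
import torch
import torch.optim as optim

lr, momentum = 0.001, 0.9  # same hyperparameters as the tutorial
w = torch.zeros(3, requires_grad=True)
optimizer = optim.SGD([w], lr=lr, momentum=momentum)

c = torch.tensor([1.0, 2.0, 3.0])  # gradient of (w * c).sum() w.r.t. w

# Hand-written momentum SGD for comparison
w_manual = torch.zeros(3)
velocity = torch.zeros(3)

for _ in range(3):
    optimizer.zero_grad()
    loss = (w * c).sum()
    loss.backward()                        # w.grad is now c
    optimizer.step()

    velocity = momentum * velocity + c     # v = momentum * v + grad
    w_manual = w_manual - lr * velocity    # p = p - lr * v

print(w.detach(), w_manual)
print(torch.allclose(w.detach(), w_manual))  # True
```

With a constant gradient the velocity grows to 1, 1.9, and 2.71 times the gradient over three steps, which is why momentum speeds up progress along directions where gradients consistently agree.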
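Step 4 zeroes the gradients on every iteration because PyTorch accumulates them: each `backward()` call adds to `.grad` instead of overwriting it. A minimal demonstration on a single parameter, where the loss `3 * w` is a made-up example:

```python
import torch
import torch.optim as optim

w = torch.tensor([2.0], requires_grad=True)
optimizer = optim.SGD([w], lr=0.1)

def loss_fn():
    # Toy loss whose gradient with respect to w is always 3
    return (3.0 * w).sum()

loss_fn().backward()
first = w.grad.item()
print(first)   # 3.0

loss_fn().backward()        # no zero_grad() in between
second = w.grad.item()
print(second)  # 6.0 (the two gradients were summed)

optimizer.zero_grad()       # clear the accumulated gradient
loss_fn().backward()
third = w.grad.item()
print(third)   # 3.0
```

Skipping `optimizer.zero_grad()` in the training loop would therefore make every update use the sum of all gradients computed so far.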
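Step 5 turns each row of network output into a predicted class with `torch.max(outputs, 1)`, which returns both the maximum values and their indices; only the indices are needed. A worked example with hand-picked logits for a hypothetical 3-class problem (CIFAR-10 would have 10 columns):

```python
import torch

# Hypothetical logits for 4 images over 3 classes
outputs = torch.tensor([[2.0, 0.1, -1.0],
                        [0.3, 0.2, 1.5],
                        [0.0, 3.0, 0.5],
                        [1.0, 2.0, 0.0]])
labels = torch.tensor([0, 2, 1, 0])

_, predicted = torch.max(outputs, 1)  # index of the highest score in each row
correct = (predicted == labels).sum().item()
total = labels.size(0)

print(predicted)  # tensor([0, 2, 1, 1])
print(f'Accuracy: {100 * correct / total:.1f} %')  # Accuracy: 75.0 %
```

The last image is misclassified (predicted class 1, true class 0), so 3 of 4 predictions are correct.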
Starter code
```python
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# 1. Load and normalize CIFAR-10
# ToTensor() scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) maps them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4


# 2. Define a CNN
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)        # 3 input channels (RGB), 6 filters, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)         # 2x2 pooling halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 feature maps of 5x5 after the second pool
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)           # one raw score (logit) per CIFAR-10 class

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)          # no softmax: nn.CrossEntropyLoss expects raw logits
        return x


def main():
    # 1 (continued). Download CIFAR-10 and wrap it in DataLoaders
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2)

    net = Net()

    # 3. Define a loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    # 4. Train the network
    print('Starting Training...')
    for epoch in range(2):  # loop over the dataset for 2 epochs for a quick test
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            optimizer.zero_grad()              # clear gradients from the previous step
            outputs = net(inputs)              # forward pass
            loss = criterion(outputs, labels)
            loss.backward()                    # backward pass: compute gradients
            optimizer.step()                   # update the weights

            running_loss += loss.item()
            if i % 2000 == 1999:  # print the average loss every 2000 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
                running_loss = 0.0
    print('Finished Training.')

    # 5. Test the network on the test data
    correct = 0
    total = 0
    net.eval()  # evaluation mode (no effect for this Net, but a good habit)
    with torch.no_grad():  # gradients are not needed for evaluation
        for images, labels in testloader:
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)  # index of the highest logit = predicted class
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy of the network on the {total} test images: {100 * correct / total:.2f} %')


# The __main__ guard is required because num_workers > 0 starts worker processes,
# which re-import this script on platforms that use "spawn" (Windows, macOS).
if __name__ == '__main__':
    main()
```
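The `16 * 5 * 5` input size of `fc1` comes from tracking the feature-map shape through the layers: a 5x5 convolution with no padding shrinks each side by 4, and 2x2 max pooling halves it. A sketch that prints the shapes using the same layer settings as `Net` (ReLU is left out because it does not change shapes):

```python
import torch
import torch.nn as nn

conv1 = nn.Conv2d(3, 6, 5)
conv2 = nn.Conv2d(6, 16, 5)
pool = nn.MaxPool2d(2, 2)

x = torch.zeros(1, 3, 32, 32)  # one CIFAR-10-sized image: 3 channels, 32x32
x = conv1(x)
print(x.shape)  # torch.Size([1, 6, 28, 28])   32 - 5 + 1 = 28
x = pool(x)
print(x.shape)  # torch.Size([1, 6, 14, 14])
x = conv2(x)
print(x.shape)  # torch.Size([1, 16, 10, 10])  14 - 5 + 1 = 10
x = pool(x)
print(x.shape)  # torch.Size([1, 16, 5, 5])
x = torch.flatten(x, 1)
print(x.shape)  # torch.Size([1, 400]), i.e. 16 * 5 * 5
```

If you change kernel sizes, padding, or the number of filters, a check like this tells you the new `in_features` for the first `nn.Linear` layer.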