Pytorch基础篇
01-PyTorch新手必看:张量是什么?5 分钟教你快速创建张量!
02-张量运算真简单!PyTorch 数值计算操作完全指南
03-Numpy 还是 PyTorch?张量与 Numpy 的神奇转换技巧
04-揭秘数据处理神器:PyTorch 张量拼接与拆分实用技巧
05-深度学习从索引开始:PyTorch 张量索引与切片最全解析
06-张量形状任意改!PyTorch reshape、transpose 操作超详细教程
07-深入解读 PyTorch 张量运算:6 大核心函数全面解析,代码示例一步到位!
08-自动微分到底有多强?PyTorch 自动求导机制深度解析
Pytorch实战篇
09-从零手写线性回归模型:PyTorch 实现深度学习入门教程
10-PyTorch 框架实现线性回归:从数据预处理到模型训练全流程
11-PyTorch 框架实现逻辑回归:从数据预处理到模型训练全流程
12-PyTorch 框架实现多层感知机(MLP):手写数字分类全流程详解
13-PyTorch 时间序列与信号处理全解析:从预测到生成
14-深度学习必备:PyTorch数据加载与预处理全解析
15-PyTorch实战:手把手教你完成MNIST手写数字识别任务
图像分类是深度学习中最经典的任务之一,而MNIST手写数字识别则是入门的最佳起点。本文将带你使用PyTorch从零构建一个简单的图像分类模型,通过卷积神经网络(CNN)和LeNet模型完成端到端的分类任务。无论你是深度学习新手,还是希望复习CNN基础的开发者,这篇文章都会为你提供清晰的步骤和代码示例。学习目标是通过实战理解CNN的核心结构,掌握PyTorch模型训练流程。
图像分类的目标是让模型识别图片中的内容,而卷积神经网络(CNN)是实现这一目标的利器。本节将从基础概念入手,带你了解图像分类和CNN的核心原理。
图像分类是指将图片分配到特定类别,比如识别手写数字“0-9”。MNIST数据集是一个经典的基准数据集,包含大量手写数字图片,非常适合初学者练习。
MNIST就像深度学习的“Hello World”,简单却能揭示图像分类的本质。
CNN通过卷积操作捕捉图像的空间特征,是图像分类的首选模型。
卷积神经网络(CNN)通过卷积层、池化层和全连接层处理图像。LeNet是CNN的早期代表,适合处理小型图像。
下表对比了传统神经网络与CNN的特点:
特点 | 传统神经网络 | CNN |
---|---|---|
输入类型 | 展平的向量 | 保留二维结构 |
参数量 | 较多 | 较少 |
擅长任务 | 表格数据 | 图像数据 |
LeNet是Yann LeCun提出的经典CNN模型,包含:
它结构简单但足以应对MNIST这样的任务。
本节将通过PyTorch实现LeNet模型,完成MNIST手写数字识别的完整流程。从数据加载到模型训练,每一步都配有代码和解析。
PyTorch提供了便捷的工具加载MNIST数据集,并进行预处理。
使用torchvision.datasets加载数据:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义数据变换
transform = transforms.Compose([
transforms.ToTensor(), # 转为张量
transforms.Normalize((0.1307,), (0.3081,)) # 标准化(均值和方差来自MNIST统计)
])
# 加载数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
关键代码解析:
我们可以用Matplotlib简单可视化数据:
import matplotlib.pyplot as plt
images, labels = next(iter(train_loader))
plt.imshow(images[0].numpy().squeeze(), cmap='gray')
plt.title(f'Label: {labels[0].item()}')
plt.show()
这将显示一张手写数字图片及其标签,确保数据加载正确。
接下来,我们用PyTorch定义LeNet模型。
以下是LeNet的完整实现:
import torch.nn as nn
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2) # 输入1通道,输出6通道
self.conv2 = nn.Conv2d(6, 16, kernel_size=5) # 输入6通道,输出16通道
self.pool = nn.MaxPool2d(2, 2) # 池化层
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 全连接层
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) # 输出10类
def forward(self, x):
x = torch.relu(self.conv1(x)) # 卷积 + ReLU
x = self.pool(x) # 池化
x = torch.relu(self.conv2(x)) # 卷积 + ReLU
x = self.pool(x) # 池化
x = x.view(-1, 16 * 5 * 5) # 展平
x = torch.relu(self.fc1(x)) # 全连接 + ReLU
x = torch.relu(self.fc2(x))
x = self.fc3(x) # 输出层
return x
# 实例化模型
model = LeNet()
print(model)
关键点解析:
最后,我们训练模型并测试其性能。
使用交叉熵损失和SGD优化器训练模型:
import torch.optim as optim
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# 训练循环
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
for epoch in range(5): # 训练5个epoch
model.train()
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad() # 清零梯度
outputs = model(images) # 前向传播
loss = criterion(outputs, labels) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新参数
print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
评估模型在测试集上的准确率:
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Test Accuracy: {100 * correct / total:.2f}%")
输出示例:
Test Accuracy: 98.50%
本文通过PyTorch实现了MNIST手写数字识别,从数据加载、LeNet模型构建到训练与评估,完整呈现了图像分类的流程。你不仅掌握了CNN的基础结构,还学会了如何用PyTorch搭建端到端任务。下一步,可以尝试更复杂的数据集(如CIFAR-10)或更深的模型(如ResNet),进一步提升技能。