深度学习近年来在计算机视觉、自然语言处理等领域取得了巨大成功,而残差网络(ResNet)作为一种经典的深度神经网络架构,因其解决了深层网络中的梯度消失问题而广受关注。ResNet通过引入“残差连接”(skip connection),使得网络可以直接学习输入和输出之间的差异,从而允许更深的网络结构。本篇博客将通过PyTorch实现一个ResNet模型,并结合代码和训练结果,带你一步步理解残差网络的原理与应用。
我们将使用Fashion-MNIST数据集,通过PyTorch实现ResNet的训练过程,并展示其训练过程中的损失和准确率变化。代码部分将包括数据加载、模型定义、训练函数以及可视化工具的实现,最后我们会分析训练结果并总结经验。
残差网络(Residual Network,简称ResNet)是由何恺明(Kaiming He)等人于2015年提出的深度学习模型,首次出现在论文《Deep Residual Learning for Image Recognition》中。ResNet的提出解决了深度神经网络中的一个核心问题:随着网络层数的增加,模型的训练变得更加困难,容易出现梯度消失或梯度爆炸的现象,导致性能下降。这种现象被称为“退化问题”(degradation problem),即更深的网络反而可能表现得不如较浅的网络。
为了解决这一问题,ResNet引入了“残差学习”(Residual Learning)的概念,通过“残差连接”(skip connection)使得网络可以直接学习输入和输出之间的差异,从而允许更深的网络结构,同时保持良好的训练性能。ResNet在2015年的ImageNet竞赛中取得了冠军,其创新性和实用性使其成为深度学习领域的一个里程碑。
ResNet的核心思想是“残差学习”。传统深度网络中,每一层网络试图直接学习目标函数 ( H(x) ),其中 ( x ) 是输入。然而,随着网络深度的增加,直接拟合 ( H(x) ) 变得越来越困难。ResNet提出了一种新的思路:与其直接学习 ( H(x) ),不如学习残差函数 ( F(x) = H(x) - x )。这样,目标函数可以表示为:
H ( x ) = F ( x ) + x H(x) = F(x) + x H(x)=F(x)+x
这里的 ( F(x) ) 是残差,( x ) 是通过“跳跃连接”(skip connection)直接传递的输入。残差学习的核心假设是:如果 ( F(x) ) 接近于0(即恒等映射),那么 ( H(x) \approx x ),这比直接学习复杂的 ( H(x) ) 更容易优化。
残差学习通过“残差块”实现。一个典型的残差块包含以下结构:
残差块的结构如下图所示:
通过这种结构,即使某些层没有学到有用的特征(即 F ( x ) ≈ 0 F(x) \approx 0 F(x)≈0),网络仍然可以通过跳跃连接保留输入信息,从而避免退化问题。
ResNet的整体架构由多个残差块堆叠而成。根据网络深度的不同,ResNet有多种变体,如ResNet-18、ResNet-34、ResNet-50、ResNet-101等。以下以ResNet-18为例,介绍其典型结构:
在更深的ResNet(如ResNet-50及以上)中,使用了瓶颈块来提高效率。瓶颈块的结构如下:
瓶颈块通过降维和升维操作,显著减少了计算量,同时保持了网络的表达能力。
ResNet的提出带来了以下几个关键优势:
ResNet广泛应用于计算机视觉任务,包括但不限于:
此外,ResNet的思想也被扩展到其他领域,例如自然语言处理(NLP)中的Transformer模型(通过类似残差连接的结构)。
我们首先需要加载Fashion-MNIST数据集,这是一个包含10类服装图像的经典数据集,适合用于测试分类模型。以下是数据加载的代码:
import torchvision
import torchvision.transforms as transforms
from torch.utils import data
import multiprocessing
def get_dataloader_workers():
"""使用电脑支持的最大进程数来读取数据"""
return multiprocessing.cpu_count()
def load_data_fashion_mnist(batch_size, resize=None):
"""
下载Fashion-MNIST数据集,然后将其加载到内存中。
参数:
batch_size (int): 每个数据批次的大小。
resize (int, 可选): 图像的目标尺寸。如果为 None,则不调整大小。
返回:
tuple: 包含训练 DataLoader 和测试 DataLoader 的元组。
"""
# 定义变换管道
trans = [transforms.ToTensor()]
if resize:
trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans)
# 加载 Fashion-MNIST 训练和测试数据集
mnist_train = torchvision.datasets.FashionMNIST(
root="./data",
train=True,
transform=trans,
download=True
)
mnist_test = torchvision.datasets.FashionMNIST(
root="./data",
train=False,
transform=trans,
download=True
)
# 返回 DataLoader 对象
return (
data.DataLoader(
mnist_train,
batch_size,
shuffle=True,
num_workers=get_dataloader_workers()
),
data.DataLoader(
mnist_test,
batch_size,
shuffle=False,
num_workers=get_dataloader_workers()
)
)
在上述代码中,我们定义了load_data_fashion_mnist
函数,用于加载Fashion-MNIST数据集。batch_size
参数控制每个批次的数据量,resize
参数允许调整图像大小(这里我们将图像调整为96x96)。transforms.ToTensor()
将图像转换为PyTorch张量格式,并将像素值归一化到[0, 1]。
在实际调用时,我们设置batch_size=256
,并将图像调整为96x96:
import utils_for_data
batch_size = 256
train_iter, test_iter = utils_for_data.load_data_fashion_mnist(batch_size, resize=96)
在本节中,我们将使用PyTorch定义一个残差网络(ResNet)模型,包含残差块(Residual Block)和整体网络结构。
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义残差块类,继承自 nn.Module
class Residual(nn.Module):
"""
残差块(Re