从代码学习深度学习 - 残差网络(ResNet)PyTorch版

文章目录

  • 前言
  • 一、残差网络(ResNet)介绍
    • 1.1. 背景与动机
    • 1.2. 核心思想:残差学习
      • 残差块(Residual Block)
    • 1.3. ResNet的网络架构
      • ResNet-18架构
      • 不同深度的ResNet
      • 瓶颈块(Bottleneck Block)
    • 1.4. 优势与特点
    • 1.5. 应用场景
  • 二、代码解析与实现
    • 2.1. 数据加载
    • 2.2. 模型定义
    • 2.3. 训练工具函数
    • 2.4. 可视化工具
  • 三、模型训练与结果分析
    • 3.1. 训练模型
    • 3.2. 训练结果
    • 3.3. 结果分析
  • 总结


前言

深度学习近年来在计算机视觉、自然语言处理等领域取得了巨大成功,而残差网络(ResNet)作为一种经典的深度神经网络架构,因其解决了深层网络中的梯度消失问题而广受关注。ResNet通过引入“残差连接”(skip connection),使得网络可以直接学习输入和输出之间的差异,从而允许更深的网络结构。本篇博客将通过PyTorch实现一个ResNet模型,并结合代码和训练结果,带你一步步理解残差网络的原理与应用。

我们将使用Fashion-MNIST数据集,通过PyTorch实现ResNet的训练过程,并展示其训练过程中的损失和准确率变化。代码部分将包括数据加载、模型定义、训练函数以及可视化工具的实现,最后我们会分析训练结果并总结经验。


一、残差网络(ResNet)介绍

1.1. 背景与动机

残差网络(Residual Network,简称ResNet)是由何恺明(Kaiming He)等人于2015年提出的深度学习模型,首次出现在论文《Deep Residual Learning for Image Recognition》中。ResNet的提出解决了深度神经网络中的一个核心问题:随着网络层数的增加,模型的训练变得更加困难,容易出现梯度消失或梯度爆炸的现象,导致性能下降。这种现象被称为“退化问题”(degradation problem),即更深的网络反而可能表现得不如较浅的网络。

为了解决这一问题,ResNet引入了“残差学习”(Residual Learning)的概念,通过“残差连接”(skip connection)使得网络可以直接学习输入和输出之间的差异,从而允许更深的网络结构,同时保持良好的训练性能。ResNet在2015年的ImageNet竞赛中取得了冠军,其创新性和实用性使其成为深度学习领域的一个里程碑。

1.2. 核心思想:残差学习

ResNet的核心思想是“残差学习”。传统深度网络中,每一层网络试图直接学习目标函数 ( H(x) ),其中 ( x ) 是输入。然而,随着网络深度的增加,直接拟合 ( H(x) ) 变得越来越困难。ResNet提出了一种新的思路:与其直接学习 ( H(x) ),不如学习残差函数 ( F(x) = H(x) - x )。这样,目标函数可以表示为:

H ( x ) = F ( x ) + x H(x) = F(x) + x H(x)=F(x)+x

这里的 ( F(x) ) 是残差,( x ) 是通过“跳跃连接”(skip connection)直接传递的输入。残差学习的核心假设是:如果 ( F(x) ) 接近于0(即恒等映射),那么 ( H(x) \approx x ),这比直接学习复杂的 ( H(x) ) 更容易优化。

残差块(Residual Block)

残差学习通过“残差块”实现。一个典型的残差块包含以下结构:

  • 主路径:由若干卷积层、批归一化(BatchNorm)和激活函数(通常是ReLU)组成,负责学习残差 ( F(x) )。
  • 跳跃连接:将输入 ( x ) 直接加到主路径的输出上,形成 ( H(x) = F(x) + x )。
  • 激活:在加法操作后通常再应用一个ReLU激活函数。

残差块的结构如下图所示:

从代码学习深度学习 - 残差网络(ResNet)PyTorch版_第1张图片

通过这种结构,即使某些层没有学到有用的特征(即 F ( x ) ≈ 0 F(x) \approx 0 F(x)0),网络仍然可以通过跳跃连接保留输入信息,从而避免退化问题。

1.3. ResNet的网络架构

ResNet的整体架构由多个残差块堆叠而成。根据网络深度的不同,ResNet有多种变体,如ResNet-18、ResNet-34、ResNet-50、ResNet-101等。以下以ResNet-18为例,介绍其典型结构:

ResNet-18架构

从代码学习深度学习 - 残差网络(ResNet)PyTorch版_第2张图片

  • 输入层:接受输入图像(例如224x224x3的RGB图像)。
  • 初始卷积层:一个7x7卷积核(步幅为2),输出64个通道,后面接批归一化和ReLU激活。
  • 最大池化层:一个3x3最大池化层(步幅为2),用于降采样。
  • 残差层:包含4个阶段(stage),每个阶段由若干残差块组成:
    • Stage 1:2个残差块,输出通道数为64。
    • Stage 2:2个残差块,输出通道数为128。
    • Stage 3:2个残差块,输出通道数为256。
    • Stage 4:2个残差块,输出通道数为512。
  • 全局平均池化:将特征图的空间维度压缩为1x1。
  • 全连接层:输出分类结果(例如ImageNet的1000类)。

不同深度的ResNet

  • ResNet-18:包含18层(16个卷积层 + 1个初始卷积层 + 1个全连接层)。
  • ResNet-34:包含34层,结构类似ResNet-18,但每个阶段的残差块数量更多。
  • ResNet-50/101/152:引入了“瓶颈块”(Bottleneck Block),使用1x1卷积降维和升维,以减少计算量,适合更深的网络。

瓶颈块(Bottleneck Block)

在更深的ResNet(如ResNet-50及以上)中,使用了瓶颈块来提高效率。瓶颈块的结构如下:

  • 1x1卷积:降维,减少通道数。
  • 3x3卷积:进行特征提取。
  • 1x1卷积:升维,恢复通道数。
  • 跳跃连接:将输入加到输出上。

瓶颈块通过降维和升维操作,显著减少了计算量,同时保持了网络的表达能力。

1.4. 优势与特点

ResNet的提出带来了以下几个关键优势:

  1. 解决退化问题:通过残差连接,ResNet允许网络深度大幅增加(例如ResNet-152),而不会导致性能下降。
  2. 易于优化:残差学习使得网络更容易训练,即使是数百层的网络也能通过梯度下降有效优化。
  3. 泛化能力强:ResNet在多种任务(如图像分类、目标检测、语义分割)中表现出色,成为许多后续模型的基础。
  4. 模块化设计:残差块的模块化设计使得ResNet易于扩展和修改,适应不同的任务需求。

1.5. 应用场景

ResNet广泛应用于计算机视觉任务,包括但不限于:

  • 图像分类:如ImageNet数据集上的分类任务。
  • 目标检测:作为骨干网络(backbone)用于Faster R-CNN、YOLO等模型。
  • 语义分割:在DeepLab、U-Net等模型中使用ResNet提取特征。
  • 迁移学习:预训练的ResNet模型常用于迁移学习,适配小规模数据集。

此外,ResNet的思想也被扩展到其他领域,例如自然语言处理(NLP)中的Transformer模型(通过类似残差连接的结构)。

二、代码解析与实现

2.1. 数据加载

我们首先需要加载Fashion-MNIST数据集,这是一个包含10类服装图像的经典数据集,适合用于测试分类模型。以下是数据加载的代码:

import torchvision
import torchvision.transforms as transforms
from torch.utils import data
import multiprocessing

def get_dataloader_workers():
    """使用电脑支持的最大进程数来读取数据"""
    return multiprocessing.cpu_count()

def load_data_fashion_mnist(batch_size, resize=None):
    """
    下载Fashion-MNIST数据集,然后将其加载到内存中。
    
    参数:
        batch_size (int): 每个数据批次的大小。
        resize (int, 可选): 图像的目标尺寸。如果为 None,则不调整大小。
    
    返回:
        tuple: 包含训练 DataLoader 和测试 DataLoader 的元组。
    """
    # 定义变换管道
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    
    # 加载 Fashion-MNIST 训练和测试数据集
    mnist_train = torchvision.datasets.FashionMNIST(
        root="./data",
        train=True,
        transform=trans,
        download=True
    )
    mnist_test = torchvision.datasets.FashionMNIST(
        root="./data",
        train=False,
        transform=trans,
        download=True
    )
    
    # 返回 DataLoader 对象
    return (
        data.DataLoader(
            mnist_train,
            batch_size,
            shuffle=True,
            num_workers=get_dataloader_workers()
        ),
        data.DataLoader(
            mnist_test,
            batch_size,
            shuffle=False,
            num_workers=get_dataloader_workers()
        )
    )

在上述代码中,我们定义了load_data_fashion_mnist函数,用于加载Fashion-MNIST数据集。batch_size参数控制每个批次的数据量,resize参数允许调整图像大小(这里我们将图像调整为96x96)。transforms.ToTensor()将图像转换为PyTorch张量格式,并将像素值归一化到[0, 1]。

在实际调用时,我们设置batch_size=256,并将图像调整为96x96:

import utils_for_data
batch_size = 256
train_iter, test_iter = utils_for_data.load_data_fashion_mnist(batch_size, resize=96)

2.2. 模型定义

在本节中,我们将使用PyTorch定义一个残差网络(ResNet)模型,包含残差块(Residual Block)和整体网络结构。

import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义残差块类,继承自 nn.Module
class Residual(nn.Module):
    """
    残差块(Re

你可能感兴趣的:(深度学习-pytorch版,深度学习,pytorch)