目录
5.4 基于残差网络的手写体数字识别实验
5.4.1 模型构建
5.4.1.1 残差单元
5.4.1.2 残差网络的整体结构
5.4.2 没有残差连接的ResNet18
5.4.2.1 模型训练
5.4.2.2 模型评价
5.4.3 带残差连接的ResNet18
5.4.3.1 模型训练
5.4.3.2 模型评价
5.4.4 与高层API实现版本的对比实验
参考
残差网络(Residual Network,ResNet)是在神经网络模型中给非线性层增加直连边的方式来缓解梯度消失问题,从而使训练深度神经网络变得更加容易。 在残差网络中,最基本的单位为残差单元。
假设为一个或多个神经层,残差单元在的输入和输出之间加上一个直连边。
不同于传统网络结构中让网络去逼近一个目标函数,在残差网络中,将目标函数拆为了两个部分:恒等函数和残差函数
其中为可学习的参数。
一个典型的残差单元如下图所示,由多个级联的卷积层和一个跨层的直连边组成。
一个残差网络通常有很多个残差单元堆叠而成。下面我们来构建一个在计算机视觉中非常典型的残差网络:ResNet18,并重复上一节中的手写体数字识别任务。
构建ResNet18的残差单元,然后在组建完整的网络。
这里,我们实现一个算子ResBlock
来构建残差单元,其中定义了use_residual
参数,用于在后续实验中控制是否使用残差连接。
残差单元包裹的非线性层的输入和输出形状大小应该一致。如果一个卷积层的输入特征图和输出特征图的通道数不一致,则其输出与输入特征图无法直接相加。为了解决上述问题,可以使用1×1大小的卷积将输入特征图的通道数映射为与级联卷积输出特征图的一致通道数。
1×1卷积:与标准卷积完全一样,唯一的特殊点在于卷积核的尺寸是1×1,也就是不去考虑输入数据局部信息之间的关系,而把关注点放在不同通道间。通过使用1×1卷积,可以起到如下作用:
代码如下
导入的库
import json
import gzip
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import Compose, Resize, Normalize, ToTensor
import random
import torch.utils.data as data
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.init import constant_, normal_, uniform_
import time
from torchsummary import summary
from thop import profile
import torch.optim as opt
from torchvision.models import resnet18
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, use_residual=True):
super(ResBlock, self).__init__()
self.stride = stride
self.use_residual = use_residual
# 第一个卷积层,卷积核大小为3×3,可以设置不同输出通道数以及步长
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=self.stride, bias=False)
# 第二个卷积层,卷积核大小为3×3,不改变输入特征图的形状,步长为1
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
# 如果conv2的输出和此残差块的输入数据形状不一致,则use_1x1conv = True
# 当use_1x1conv = True,添加1个1x1的卷积作用在输入数据上,使其形状变成跟conv2一致
if in_channels != out_channels or stride != 1:
self.use_1x1conv = True
else:
self.use_1x1conv = False
# 当残差单元包裹的非线性层输入和输出通道数不一致时,需要用1×1卷积调整通道数后再进行相加运算
if self.use_1x1conv:
self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=self.stride, bias=False)
# 每个卷积层后会接一个批量规范化层,批量规范化的内容在7.5.1中会进行详细介绍
self.bn1 = nn.BatchNorm2d(out_channels)
self.bn2 = nn.BatchNorm2d(out_channels)
if self.use_1x1conv:
self.bn3 = nn.BatchNorm2d(out_channels)
def forward(self, inputs):
y = F.relu(self.bn1(self.conv1(inputs)))
y = self.bn2(self.conv2(y))
if self.use_residual:
if self.use_1x1conv: # 如果为真,对inputs进行1×1卷积,将形状调整成跟conv2的输出y一致
shortcut = self.shortcut(inputs)
shortcut = self.bn3(shortcut)
else: # 否则直接将inputs和conv2的输出y相加
shortcut = inputs
y = torch.add(shortcut, y)
out = F.relu(y)
return out
残差网络就是将很多个残差单元串联起来构成的一个非常深的网络。ResNet18 的网络结构如下图所示。
其中为了便于理解,可以将ResNet18网络划分为6个模块:
ResNet18模型的代码实现如下:
定义模块一:
def make_first_module(in_channels):
# 模块一:7*7卷积、批量规范化、汇聚
m1 = nn.Sequential(nn.Conv2d(in_channels, 64, 7, stride=2, padding=3),
nn.BatchNorm2d(64), nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
return m1
定义模块二到五:
#定义模块二到五
def resnet_module(input_channels, out_channels, num_res_blocks, stride=1, use_residual=True):
blk = []
# 根据num_res_blocks,循环生成残差单元
for i in range(num_res_blocks):
if i == 0: # 创建模块中的第一个残差单元
blk.append(ResBlock(input_channels, out_channels,
stride=stride, use_residual=use_residual))
else: # 创建模块中的其他残差单元
blk.append(ResBlock(out_channels, out_channels, use_residual=use_residual))
return blk
定义完整网络:
# 定义完整网络
class Model_ResNet18(nn.Module):
def __init__(self, in_channels=3, num_classes=10, use_residual=True):
super(Model_ResNet18,self).__init__()
m1 = make_first_module(in_channels)
m2, m3, m4, m5 = make_modules(use_residual)
# 封装模块一到模块6
self.net = nn.Sequential(m1, m2, m3, m4, m5,
# 模块六:汇聚层、全连接层
nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(512, num_classes) )
def forward(self, x):
return self.net(x)
这里同样可以使用torchsummaryAPI统计模型的参数量
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
model = Model_ResNet18(in_channels=1, num_classes=10, use_residual=True).to(device)
summary(model, (1, 32, 32))
运行结果
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 16, 16] 3,200
BatchNorm2d-2 [-1, 64, 16, 16] 128
ReLU-3 [-1, 64, 16, 16] 0
MaxPool2d-4 [-1, 64, 8, 8] 0
Conv2d-5 [-1, 64, 8, 8] 36,864
BatchNorm2d-6 [-1, 64, 8, 8] 128
Conv2d-7 [-1, 64, 8, 8] 36,864
BatchNorm2d-8 [-1, 64, 8, 8] 128
ResBlock-9 [-1, 64, 8, 8] 0
Conv2d-10 [-1, 64, 8, 8] 36,864
BatchNorm2d-11 [-1, 64, 8, 8] 128
Conv2d-12 [-1, 64, 8, 8] 36,864
BatchNorm2d-13 [-1, 64, 8, 8] 128
ResBlock-14 [-1, 64, 8, 8] 0
Conv2d-15 [-1, 128, 4, 4] 73,728
BatchNorm2d-16 [-1, 128, 4, 4] 256
Conv2d-17 [-1, 128, 4, 4] 147,456
BatchNorm2d-18 [-1, 128, 4, 4] 256
Conv2d-19 [-1, 128, 4, 4] 8,192
BatchNorm2d-20 [-1, 128, 4, 4] 256
ResBlock-21 [-1, 128, 4, 4] 0
Conv2d-22 [-1, 128, 4, 4] 147,456
BatchNorm2d-23 [-1, 128, 4, 4] 256
Conv2d-24 [-1, 128, 4, 4] 147,456
BatchNorm2d-25 [-1, 128, 4, 4] 256
ResBlock-26 [-1, 128, 4, 4] 0
Conv2d-27 [-1, 256, 2, 2] 294,912
BatchNorm2d-28 [-1, 256, 2, 2] 512
Conv2d-29 [-1, 256, 2, 2] 589,824
BatchNorm2d-30 [-1, 256, 2, 2] 512
Conv2d-31 [-1, 256, 2, 2] 32,768
BatchNorm2d-32 [-1, 256, 2, 2] 512
ResBlock-33 [-1, 256, 2, 2] 0
Conv2d-34 [-1, 256, 2, 2] 589,824
BatchNorm2d-35 [-1, 256, 2, 2] 512
Conv2d-36 [-1, 256, 2, 2] 589,824
BatchNorm2d-37 [-1, 256, 2, 2] 512
ResBlock-38 [-1, 256, 2, 2] 0
Conv2d-39 [-1, 512, 1, 1] 1,179,648
BatchNorm2d-40 [-1, 512, 1, 1] 1,024
Conv2d-41 [-1, 512, 1, 1] 2,359,296
BatchNorm2d-42 [-1, 512, 1, 1] 1,024
Conv2d-43 [-1, 512, 1, 1] 131,072
BatchNorm2d-44 [-1, 512, 1, 1] 1,024
ResBlock-45 [-1, 512, 1, 1] 0
Conv2d-46 [-1, 512, 1, 1] 2,359,296
BatchNorm2d-47 [-1, 512, 1, 1] 1,024
Conv2d-48 [-1, 512, 1, 1] 2,359,296
BatchNorm2d-49 [-1, 512, 1, 1] 1,024
ResBlock-50 [-1, 512, 1, 1] 0
AdaptiveAvgPool2d-51 [-1, 512, 1, 1] 0
Flatten-52 [-1, 512] 0
Linear-53 [-1, 10] 5,130
================================================================
Total params: 11,175,434
Trainable params: 11,175,434
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 1.05
Params size (MB): 42.63
Estimated Total Size (MB): 43.69
----------------------------------------------------------------
为了验证残差连接对深层卷积神经网络的训练可以起到促进作用,接下来先使用ResNet18(use_residual设置为False)进行手写数字识别实验,再添加残差连接(use_residual设置为True),观察实验对比效果。
为了验证残差连接的效果,先使用没有残差连接的ResNet18进行实验。
使用训练集和验证集进行模型训练,共训练5个epoch。在实验中,保存准确率最高的模型作为最佳模型。代码实现如下:
首先这里的数据集是上个实验中使用到的,以及其中的一些函数在上个实验中都已使用这里把代码已经粘贴过来
# 打印并观察数据集分布情况
train_set, dev_set, test_set = json.load(gzip.open('mnist.json.gz'))
train_images, train_labels = train_set[0][:2000], train_set[1][:2000]
dev_images, dev_labels = dev_set[0][:200], dev_set[1][:200]
test_images, test_labels = test_set[0][:200], test_set[1][:200]
train_set, dev_set, test_set = [train_images, train_labels], [dev_images, dev_labels], [test_images, test_labels]
print('Length of train/dev/test set:{}/{}/{}'.format(len(train_set[0]), len(dev_set[0]), len(test_set[0])))
image, label = train_set[0][0], train_set[1][0]
image, label = np.array(image).astype('float32'), int(label)
# 原始图像数据为长度784的行向量,需要调整为[28,28]大小的图像
image = np.reshape(image, [28, 28])
image = Image.fromarray(image.astype('uint8'), mode='L')
print("The number in the picture is {}".format(label))
plt.figure(figsize=(5, 5))
plt.imshow(image)
plt.savefig('conv-number5.pdf')
plt.show()
# 数据预处理
transforms = Compose([Resize(32), ToTensor(), Normalize(mean=[1], std=[1])])
class MNIST_dataset(data.Dataset):
def __init__(self, dataset, transforms, mode='train'):
self.mode = mode
self.transforms = transforms
self.dataset = dataset
def __getitem__(self, idx):
# 获取图像和标签
image, label = self.dataset[0][idx], self.dataset[1][idx]
image, label = np.array(image).astype('float32'), int(label)
image = np.reshape(image, [28, 28])
image = Image.fromarray(image.astype('uint8'), mode='L')
image = self.transforms(image)
return image, label
def __len__(self):
return len(self.dataset[0])
# 固定随机种子
random.seed(0)
# 加载 mnist 数据集
train_dataset = MNIST_dataset(dataset=train_set, transforms=transforms, mode='train')
test_dataset = MNIST_dataset(dataset=test_set, transforms=transforms, mode='test')
dev_dataset = MNIST_dataset(dataset=dev_set, transforms=transforms, mode='dev')
class RunnerV3(object):
def __init__(self, model, optimizer, loss_fn, metric, **kwargs):
self.model = model
self.optimizer = optimizer
self.loss_fn = loss_fn
self.metric = metric # 只用于计算评价指标
# 记录训练过程中的评价指标变化情况
self.dev_scores = []
# 记录训练过程中的损失函数变化情况
self.train_epoch_losses = [] # 一个epoch记录一次loss
self.train_step_losses = [] # 一个step记录一次loss
self.dev_losses = []
# 记录全局最优指标
self.best_score = 0
def train(self, train_loader, dev_loader=None, **kwargs):
# 将模型切换为训练模式
self.model.train()
# 传入训练轮数,如果没有传入值则默认为0
num_epochs = kwargs.get("num_epochs", 0)
# 传入log打印频率,如果没有传入值则默认为100
log_steps = kwargs.get("log_steps", 100)
# 评价频率
eval_steps = kwargs.get("eval_steps", 0)
# 传入模型保存路径,如果没有传入值则默认为"best_model.pdparams"
save_path = kwargs.get("save_path", "best_model.pdparams")
custom_print_log = kwargs.get("custom_print_log", None)
# 训练总的步数
num_training_steps = num_epochs * len(train_loader)
if eval_steps:
if self.metric is None:
raise RuntimeError('Error: Metric can not be None!')
if dev_loader is None:
raise RuntimeError('Error: dev_loader can not be None!')
# 运行的step数目
global_step = 0
# 进行num_epochs轮训练
for epoch in range(num_epochs):
# 用于统计训练集的损失
total_loss = 0
for step, data in enumerate(train_loader):
X, y = data
# 获取模型预测
logits = self.model(X)
loss = self.loss_fn(logits, y) # 默认求mean
total_loss += loss
# 训练过程中,每个step的loss进行保存
self.train_step_losses.append((global_step, loss.item()))
if log_steps and global_step % log_steps == 0:
print(
f"[Train] epoch: {epoch}/{num_epochs}, step: {global_step}/{num_training_steps}, loss: {loss.item():.5f}")
# 梯度反向传播,计算每个参数的梯度值
loss.backward()
if custom_print_log:
custom_print_log(self)
# 小批量梯度下降进行参数更新
self.optimizer.step()
# 梯度归零
self.optimizer.zero_grad()
# 判断是否需要评价
if eval_steps > 0 and global_step > 0 and \
(global_step % eval_steps == 0 or global_step == (num_training_steps - 1)):
dev_score, dev_loss = self.evaluate(dev_loader, global_step=global_step)
print(f"[Evaluate] dev score: {dev_score:.5f}, dev loss: {dev_loss:.5f}")
# 将模型切换为训练模式
self.model.train()
# 如果当前指标为最优指标,保存该模型
if dev_score > self.best_score:
self.save_model(save_path)
print(
f"[Evaluate] best accuracy performence has been updated: {self.best_score:.5f} --> {dev_score:.5f}")
self.best_score = dev_score
global_step += 1
# 当前epoch 训练loss累计值
trn_loss = (total_loss / len(train_loader)).item()
# epoch粒度的训练loss保存
self.train_epoch_losses.append(trn_loss)
print("[Train] Training done!")
# 模型评估阶段,使用'paddle.no_grad()'控制不计算和存储梯度
@torch.no_grad()
def evaluate(self, dev_loader, **kwargs):
assert self.metric is not None
# 将模型设置为评估模式
self.model.eval()
global_step = kwargs.get("global_step", -1)
# 用于统计训练集的损失
total_loss = 0
# 重置评价
self.metric.reset()
# 遍历验证集每个批次
for batch_id, data in enumerate(dev_loader):
X, y = data
# 计算模型输出
logits = self.model(X)
# 计算损失函数
loss = self.loss_fn(logits, y).item()
# 累积损失
total_loss += loss
# 累积评价
self.metric.update(logits, y)
dev_loss = (total_loss / len(dev_loader))
dev_score = self.metric.accumulate()
# 记录验证集loss
if global_step != -1:
self.dev_losses.append((global_step, dev_loss))
self.dev_scores.append(dev_score)
return dev_score, dev_loss
# 模型评估阶段,使用'paddle.no_grad()'控制不计算和存储梯度
@torch.no_grad()
def predict(self, x, **kwargs):
# 将模型设置为评估模式
self.model.eval()
# 运行模型前向计算,得到预测值
logits = self.model(x)
return logits
def save_model(self, save_path):
torch.save(self.model.state_dict(), save_path)
def load_model(self, model_path):
state_dict = torch.load(model_path)
self.model.load_state_dict(state_dict)
class Accuracy:
def __init__(self, is_logist=True):
"""
输入:
- is_logist: outputs是logist还是激活后的值
"""
# 用于统计正确的样本个数
self.num_correct = 0
# 用于统计样本的总数
self.num_count = 0
self.is_logist = is_logist
def update(self, outputs, labels):
"""
输入:
- outputs: 预测值, shape=[N,class_num]
- labels: 标签值, shape=[N,1]
"""
# 判断是二分类任务还是多分类任务,shape[1]=1时为二分类任务,shape[1]>1时为多分类任务
if outputs.shape[1] == 1: # 二分类
outputs = torch.squeeze(outputs, dim=-1)
if self.is_logist:
# logist判断是否大于0
preds = torch.tensor((outputs >= 0), dtype=torch.float32)
else:
# 如果不是logist,判断每个概率值是否大于0.5,当大于0.5时,类别为1,否则类别为0
preds = torch.tensor((outputs >= 0.5), dtype=torch.float32)
else:
# 多分类时,使用'torch.argmax'计算最大元素索引作为类别
preds = torch.argmax(outputs, dim=1)
# 获取本批数据中预测正确的样本个数
labels = torch.squeeze(labels, dim=-1)
batch_correct = torch.sum(preds == labels).clone().detach().numpy()
batch_count = len(labels)
# 更新num_correct 和 num_count
self.num_correct += batch_correct
self.num_count += batch_count
def accumulate(self):
# 使用累计的数据,计算总的指标
if self.num_count == 0:
return 0
return self.num_correct / self.num_count
def reset(self):
# 重置正确的数目和总数
self.num_correct = 0
self.num_count = 0
def name(self):
return "Accuracy"
def plot(runner, fig_name):
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
train_items = runner.train_step_losses[::30]
train_steps = [x[0] for x in train_items]
train_losses = [x[1] for x in train_items]
plt.plot(train_steps, train_losses, color='#8E004D', label="Train loss")
if runner.dev_losses[0][0] != -1:
dev_steps = [x[0] for x in runner.dev_losses]
dev_losses = [x[1] for x in runner.dev_losses]
plt.plot(dev_steps, dev_losses, color='#E20079', linestyle='--', label="Dev loss")
# 绘制坐标轴和图例
plt.ylabel("loss", fontsize='x-large')
plt.xlabel("step", fontsize='x-large')
plt.legend(loc='upper right', fontsize='x-large')
plt.subplot(1, 2, 2)
# 绘制评价准确率变化曲线
if runner.dev_losses[0][0] != -1:
plt.plot(dev_steps, runner.dev_scores,
color='#E20079', linestyle="--", label="Dev accuracy")
else:
plt.plot(list(range(len(runner.dev_scores))), runner.dev_scores,
color='#E20079', linestyle="--", label="Dev accuracy")
# 绘制坐标轴和图例
plt.ylabel("score", fontsize='x-large')
plt.xlabel("step", fontsize='x-large')
plt.legend(loc='lower right', fontsize='x-large')
plt.savefig(fig_name)
plt.show()
开始训练
# 学习率大小
lr = 0.1
# 批次大小
batch_size = 64
# 加载数据
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_loader = torch.utils.data.DataLoader(dev_dataset, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
# 定义LeNet网络
# 自定义算子实现的LeNet-5
model = Model_ResNet18(in_channels=1, num_classes=10)
# 飞桨API实现的LeNet-5
# model = Paddle_LeNet(in_channels=1, num_classes=10)
# 定义优化器
optimizer = torch.optim.SGD(lr=lr, params=model.parameters())
# 定义损失函数
loss_fn = F.cross_entropy
# 定义评价指标
metric = Accuracy(is_logist=True)
# 实例化RunnerV3
runner = RunnerV3(model, optimizer, loss_fn, metric)
# 启动训练
log_steps = 15
eval_steps = 15
runner.train(train_loader, dev_loader, num_epochs=5, log_steps=log_steps,
eval_steps=eval_steps, save_path="best_model.pdparams")
# 可视化观察训练集与验证集的Loss变化情况
plot(runner, 'cnn-loss2.pdf')
运行结果
可视化结果
使用测试数据对在训练过程中保存的最佳模型进行评价,观察模型在测试集上的准确率以及损失情况。代码实现如下:
# 加载最优模型
runner.load_model('best_model.pdparams')
# 模型评价
score, loss = runner.evaluate(test_loader)
print("[Test] accuracy/loss: {:.4f}/{:.4f}".format(score, loss))
运行结果
[Test] accuracy/loss: 0.9550/0.2173
再使用带残差连接的ResNet18重复上面的实验。
使用带残差连接的ResNet18重复上面的实验,代码实现如下:
# 学习率大小
lr = 0.01
# 批次大小
batch_size = 64
# 加载数据
train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_loader = data.DataLoader(dev_dataset, batch_size=batch_size)
test_loader = data.DataLoader(test_dataset, batch_size=batch_size)
# 定义网络,通过指定use_residual为True,使用残差结构的深层网络
model = Model_ResNet18(in_channels=1, num_classes=10, use_residual=True)
# 定义优化器
optimizer = opt.SGD(lr=lr, params=model.parameters())
# 实例化RunnerV3
runner = RunnerV3(model, optimizer, loss_fn, metric)
# 启动训练
log_steps = 15
eval_steps = 15
runner.train(train_loader, dev_loader, num_epochs=5, log_steps=log_steps,
eval_steps=eval_steps, save_path="best_model.pdparams")
# 可视化观察训练集与验证集的Loss变化情况
plot(runner, 'cnn-loss3.pdf')
运行结果
可视化结果
使用测试数据对在训练过程中保存的最佳模型进行评价,观察模型在测试集上的准确率以及损失情况。
# 加载最优模型
runner.load_model('best_model.pdparams')
# 模型评价
score, loss = runner.evaluate(test_loader)
print("[Test] accuracy/loss: {:.4f}/{:.4f}".format(score, loss))
运行结果
[Test] accuracy/loss: 0.9580/0.0821
添加了残差连接后,模型收敛曲线更平滑。
从输出结果看,和不使用残差连接的ResNet相比,添加了残差连接后,模型效果有了一定的提升。
飞桨高层 API是对飞桨API的进一步封装与升级,提供了更加简洁易用的API,进一步提升了飞桨的易学易用性。
其中,飞桨高层API封装了以下模块:
飞桨高层 API主要包含在paddle.vision
和paddle.text
目录中。
对于Reset18这种比较经典的图像分类网络,pytorch高层API中都为大家提供了实现好的版本,大家可以不再从头开始实现。
代码如下:
from torchvision.models import resnet18
hapi_model = resnet18(pretrained=True)
# 自定义的resnet18模型
model = Model_ResNet18(in_channels=3, num_classes=1000, use_residual=True)
# 获取网络的权重
params = hapi_model.state_dict()
# 用来保存参数名映射后的网络权重
new_params = {}
# 将参数名进行映射
for key in params:
if 'layer' in key:
if 'downsample.0' in key:
new_params['net.' + key[5:8] + '.shortcut' + key[-7:]] = params[key]
elif 'downsample.1' in key:
new_params['net.' + key[5:8] + '.shorcutt' + key[23:]] = params[key]
else:
new_params['net.' + key[5:]] = params[key]
elif 'conv1.weight' == key:
new_params['net.0.0.weight'] = params[key]
elif 'bn1' in key:
new_params['net.0.1' + key[3:]] = params[key]
elif 'fc' in key:
new_params['net.7' + key[2:]] = params[key]
del new_params["net.2.0.shorcutteight"]
del new_params["net.2.0.shorcuttias"]
del new_params["net.2.0.shorcuttunning_mean"]
del new_params["net.2.0.shorcuttunning_var"]
del new_params["net.2.0.shorcuttum_batches_tracked"]
del new_params["net.3.0.shorcutteight"]
del new_params["net.3.0.shorcuttias"]
del new_params["net.3.0.shorcuttunning_mean"]
del new_params["net.3.0.shorcuttunning_var"]
del new_params["net.3.0.shorcuttum_batches_tracked"]
del new_params["net.4.0.shorcutteight"]
del new_params["net.4.0.shorcuttias"]
del new_params["net.4.0.shorcuttunning_mean"]
del new_params["net.4.0.shorcuttunning_var"]
del new_params["net.4.0.shorcuttum_batches_tracked"]
inputs = np.random.randn(*[3, 3, 32, 32])
inputs = inputs.astype('float32')
x = torch.tensor(inputs)
output = hapi_model(x)
hapi_out = hapi_model(x)
# 计算两个模型输出的差异
diff = output - hapi_out
# 取差异最大的值
max_diff = torch.max(diff)
print(max_diff)
通过结果可以知道,高层API版本的resnet18模型和自定义的resnet18模型输出结果是一致的,也就说明两个模型的实现完全一样。
NNDL 实验5(上) - HBU_DAVID - 博客园 (cnblogs.com)https://www.cnblogs.com/hbuwyg/p/16617671.html
NNDL 实验六 卷积神经网络(4)ResNet18实现MNISThttps://blog.csdn.net/qq_38975453/article/details/127167583?spm=1001.2014.3001.5502