Pytorch入门(六)使用ResNet-18网络训练自制海贼王数据集

Pytorch入门(六)使用ResNet-18网络训练自制海贼王数据集_第1张图片

文章目录

  • 预测效果
  • 数据集处理
  • 训练-starting
  • 代码+数据集

预测效果

效果图先行。
Pytorch入门(六)使用ResNet-18网络训练自制海贼王数据集_第2张图片
Pytorch入门(六)使用ResNet-18网络训练自制海贼王数据集_第3张图片
Pytorch入门(六)使用ResNet-18网络训练自制海贼王数据集_第4张图片
Pytorch入门(六)使用ResNet-18网络训练自制海贼王数据集_第5张图片

Pytorch入门(六)使用ResNet-18网络训练自制海贼王数据集_第6张图片

数据集处理

只是进行了数据集的切换,针对上篇博客Pytorch入门(五)只是改动了数据集处理方式,使用transforms.Resize()对图像大小进行了修改。图片数量较少,所以在测试集,验证集划分上没有那么严格标准。

transforms.Resize([h, w])
例如transforms.Resize([32,32]),将图片修改为32x32大小的特征图
虽然会改变图片的长宽比,但是本身并没有发生裁切。
仍可以通过resize方法返回原来的形状

加载数据集代码:

# 准备数据集并预处理
transform_train = transforms.Compose([
    transforms.Resize([32,32]),
    transforms.RandomCrop(32, padding=4),  # 先四周填充0,在吧图像随机裁剪成32*32
    transforms.RandomHorizontalFlip(),  # 图像一半的概率翻转,一半的概率不翻转
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # R,G,B每层的归一化用到的均值和方差
])

transform_test = transforms.Compose([
    transforms.Resize([32, 32]),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.ImageFolder(root='data/train', transform=transform_train)  # 训练数据集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True,
                                          num_workers=2)  # 生成一个个batch进行批训练,组成batch的时候顺序打乱取

testset = torchvision.datasets.ImageFolder(root='data/test', transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=True, num_workers=2)

训练-starting

将原来每100个批次打印一次损失改为了每16个批次计算一次损失:

if total_train_step % 16 == 0:print('[训练次数:%d] Loss: %.03f'% (total_train_step, total_train_loss))

训练效果稳步提升。
Pytorch入门(六)使用ResNet-18网络训练自制海贼王数据集_第7张图片

代码+数据集

项目灵感来源:chgl16
原项目采用的是CNN,本篇博客将网络骨架进行了替换。采用了ResNet18残差网络,代码可以看我之前的博客。想要数据集如果打不开github,可以去我的资源里面下载数据集,已经整理上传。大家也可以在网络上爬取自己喜欢的图片然后使用自己喜欢的网络进行分类、预测。
Pytorch入门(六)使用ResNet-18网络训练自制海贼王数据集_第8张图片

你可能感兴趣的:(深度学习,pytorch,深度学习,人工智能)