只是进行了数据集的切换,针对上篇博客Pytorch入门(五)只是改动了数据集处理方式,使用transforms.Resize()对图像大小进行了修改。图片数量较少,所以在测试集,验证集划分上没有那么严格标准。
transforms.Resize([h, w])
例如transforms.Resize([32,32]),将图片修改为32x32大小的特征图
虽然会改变图片的长宽比,但是本身并没有发生裁切。
仍可以通过resize方法返回原来的形状
加载数据集代码:
# 准备数据集并预处理
transform_train = transforms.Compose([
transforms.Resize([32,32]),
transforms.RandomCrop(32, padding=4), # 先四周填充0,在吧图像随机裁剪成32*32
transforms.RandomHorizontalFlip(), # 图像一半的概率翻转,一半的概率不翻转
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # R,G,B每层的归一化用到的均值和方差
])
transform_test = transforms.Compose([
transforms.Resize([32, 32]),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.ImageFolder(root='data/train', transform=transform_train) # 训练数据集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=2) # 生成一个个batch进行批训练,组成batch的时候顺序打乱取
testset = torchvision.datasets.ImageFolder(root='data/test', transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=True, num_workers=2)
将原来每100个批次打印一次损失改为了每16个批次计算一次损失:
if total_train_step % 16 == 0:print('[训练次数:%d] Loss: %.03f'% (total_train_step, total_train_loss))
项目灵感来源:chgl16
原项目采用的是CNN,本篇博客将网络骨架进行了替换。采用了ResNet18残差网络,代码可以看我之前的博客。想要数据集如果打不开github,可以去我的资源里面下载数据集,已经整理上传。大家也可以在网络上爬取自己喜欢的图片然后使用自己喜欢的网络进行分类、预测。