记录一下学习深度学习的一些。本篇简述如何在 Windows 上训练一个模型来识别猫狗。
所使用的环境:
猫狗大战的数据集找不到官网,这里使用 Kaggle 的数据集,也提供百度网盘的下载地址。
下载完数据后,解压,可以看到训练数据集的图片都以 cat.xxx.jpg
和 dog.xxx.jpg
命名,可以从名字中获取标签,而测试数据集的标签无法提取,emmm,不嫌累的可以直接一张张标注。
再查看一些其他信息。
可以看到有 12,500 只猫和 12,500 只狗,图像的高宽大约为500,有部分特别小,有两个特别大,查看一下
可以看到其中有部分脏数据,这对模型的训练会造成很大的困扰,但是 25,000 张图像要清洗一遍太费精力了,待解决。
数据集中的图像有大有小,训练模型要求输入是统一大小的,直接将图像缩放至统一大小对人类判别猫狗影响不大,因此将所有的图像全部缩放至 256,再随机裁剪至 224,进行数据增强。
transform = {
'train': transforms.Compose([
transforms.Resize((256,256)),
transforms.RandomCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize((256,256)),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
}
datasets = {
'train': DogsCatsSet(train_set, transform['train']),
'val': DogsCatsSet(val_set, transform['val'])
}
data_sizes = {
'train': len(datasets['train']),
'val': len(datasets['val'])
}
dataloaders = {
'train': DataLoader(datasets['train'], batch_size=batch_size, shuffle=True, pin_memory=True),
'val': DataLoader(datasets['val'], batch_size=batch_size, shuffle=False, pin_memory=False)
}
用 ResNet 来训练猫狗分类器,使用 ImageNet 的预训练权重,进行微调
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.models import resnet50
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_ft = resnet50(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
model_tf = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_ft.parameters(), lr=1e-3)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
model_tf = train(model_ft, dataloaders, dataset_sizes, criterion, optimizer, exp_lr_scheduler, device, 20)
本次测试由于测试集没有真实标签,所以就无办法直接验证测试集的准确率了,测试部分图像并将输入图像展示出来
test_transform = transforms.Compose([
transforms.Resize((256,256)),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
test_set = DogsCatsSet(test_list, test_transform)
test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=True, pin_memory=True)
visualize_preds(model_ft, device, test_dataloader, 64)
可以看到测试的大部分结果都是正确的,但对于其他的测试样本或者真实的数据表现如何,就要实践才知道了。
save_path = 'dogs_vs_cats.pt'
torch.save(model_ft.state_dict(), save_path)