pytorch入门(二)基于PyTorch的图像分类

pytorch入门(二)基于PyTorch的图像分类

  • 鱼猫分类器
    • 任务
    • 数据
    • dataset和data loader
    • 建立训练数据集
    • 建立验证和测试数据集
    • 建立dataloder
    • batch_size
    • 设置神经网络
    • lossfunction
    • optimizing
    • 训练
      • 参数
    • 预测
    • 模型保存
  • 全部代码

pytorch入门(一)简介pytorch

鱼猫分类器

任务

设计一个区分鱼和猫的分类器。

pytorch入门(二)基于PyTorch的图像分类_第1张图片 pytorch入门(二)基于PyTorch的图像分类_第2张图片

数据

ImageNet:用于训练神经网络的标准图像集合,它包含超过1400万张图像和20000个图像类别。
github有相应的下载代码,但我这总是没跑对,所以从其他渠道直接把图片下下来了(提取码:pypt)

dataset和data loader

pytorch通过dataset和data loader建立数据与神经网络的联系。
符合以下类的数据集,可以被喂入神经网络进行训练。

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError

len:返回数据长度。
getitem:返回每个batch_size的对应的张量和标签。

建立训练数据集

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image


train_data_path = "./train/"

transforms = transforms.Compose([
    transforms.Resize(64),  # 转变为64*64
    transforms.ToTensor(),  # 将图像转为tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # normalize
                         std=[0.229, 0.224, 0.225])
])
train_data = torchvision.datasets.ImageFolder(root=train_data_path, transform=transforms)

建立验证和测试数据集

val_data_path = "./val/"
val_data = torchvision.datasets.ImageFolder(root=val_data_path,
                                            transform=transforms)
test_data_path = "./test/"
test_data = torchvision.datasets.ImageFolder(root=test_data_path,
                                             transform=transforms)

建立dataloder

batch_size = 64

train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

batch_size

理解为每次喂入神经网络的数据数量。更大的batch_size可以让程序更好的学习全局信息,但会占用更大的内存空间。pytorch默认batch_size为1。

设置神经网络

其包含:

  1. 一个输入层:处理输入的张量
  2. 输出层:用于判断类型
  3. 隐藏层
    pytorch入门(二)基于PyTorch的图像分类_第3张图片
# 定义神经网络
class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.fc1 = nn.Linear(12288, 84)  # 64*64*3
        self.fc2 = nn.Linear(84, 20)
        self.fc3 = nn.Linear(20, 2)
    def forward(self, x):
        x = nn.

你可能感兴趣的:(pytorch,分类,深度学习)