使用AlexNet训练自己的分类数据集,使用VGG结合迁移学习训练自己的分类数据集

系统:windows10
语言:python3.8
框架:pytorch1.10

一、使用AlexNet训练自己的分类数据集

备注:现在有很多迁移学习教程都是使用官方的预训练模型,这里主要是想实现从自己预训练好的模型迁移到另一个自己的数据集上。

1.1数据集

博主这里是准备了两个数据集,一个是花分类(5类,3500张左右),一个是猫狗二分类(2000张),我会先使用Alexnet分别对花分类数据集和猫狗分类数据集进行训练,然后使用花分类的模型进行微调,使用迁移学习的方式训练猫狗分类数据集。

这里以猫狗数据集举例,文件夹下面包含猫和狗两个文件夹,分别存放猫狗图片,
data
├─catdog_picture
├─Cat
└─Dog
执行"split_data.py"脚本自动将数据集划分成训练集train和验证集val,花分类数据集同理,执行后改目录如下:
data
├─catdog_picture
│ ├─Cat
│ └─Dog
├─train
│ ├─Cat
│ └─Dog
└─val
├─Cat
└─Dog

split_data.py

import os
from shutil import copy, rmtree
import random


def mk_file(file_path: str):
    if os.path.exists(file_path):
        # 如果文件夹存在,则先删除原文件夹在重新创建
        rmtree(file_path)
    os.makedirs(file_path)


def main():
    # 保证随机可复现
    random.seed(0)

    # 将数据集中10%的数据划分到验证集中
    split_rate = 0.1

    # 指向你的flower_photos文件夹
    cwd = os.getcwd()
    data_root = os.path.join(cwd, "cat_dog")
    origin_flower_path = os.path.join(data_root, "catdog_picture")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)

    flower_class = [cla for cla in os.listdir(origin_flower_path)
                    if os.path.isdir(os.path.join(origin_flower_path, cla))]

    # 建立保存训练集的文件夹
    train_root = os.path.join(data_root, "train")
    mk_file(train_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(train_root, cla))

    # 建立保存验证集的文件夹
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(val_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla)
        images = os.listdir(cla_path)
        num = len(images)
        # 随机采样验证集的索引
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                # 将分配至验证集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                # 将分配至训练集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
        print()

    print("processing done!")


if __name__ == '__main__':
    main()

1.2 分别训练花分类数据集、猫狗数据集

很多经典网络pytorch都有官方实现,并提供预训练模型,这里我们先自己搭建一下Alexnet网络模型。后面在VGG部分,展示如何直接加载官方模型。

model.py

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        #主干网络
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )

        #分类头
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),                    #Dropout一般不用于卷积层,只用于这里的全连接层,卷积层可用Spatial Dropout
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )

        #初始化权重,当前版本一般不使用,因为pytorch已经默认使用了凯明初始化
        if init_weights:
            self._initialize_weights()

    #前向传播函数
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)  #展平操作从索引1开始,也就是channel,w,h,  tensor的顺序是(batch,channel,w,h)
        x = self.classifier(x)
        return x

    #初始化权重函数
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

train.py

这里先训练花分类数据集,得到权重AlexNet.pth,再用这个权重模型对猫狗分类数据集进行迁移学习

#训练30轮,准确率为0.80

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # 数据预处理
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path  返回到上上级目录
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path   #在上上级目录中,添加上这个路径
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)        #断言,如果不存在这个路径,那么就报错

    #创建数据集(打包数据集)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    #创建验证集(打包验证集)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)


    flower_list = train_dataset.class_to_idx   #class_to_idx就是获取train_dataset下每个文件夹的名称,并按字典返回
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}

    cla_dict = dict((val, key) for key, val in flower_list.items())     #将folwer的键值对反过来

    # write dict into json file,方便预测的时候调用
    json_str = json.dumps(cla_dict, indent=4)
    with open('flower_class.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = 0  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    #加载训练集
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)


    #加载验证集
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=True,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))


    net = AlexNet(num_classes=5, init_weights=True)   #实例化网络

    #加载之前训练的模型参数,接续训练,如果要重新训练,直接屏蔽以下代码即可
    #weights_path = "./AlexNet.pth"
    #assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    #net.load_state_dict(torch.load(weights_path))

    net.to(device)  # 放到GPU上


    loss_function = nn.CrossEntropyLoss()                  #损失函数
    # pata = list(net.parameters())
    optimizer = optim.Adam(net.parameters(), lr=0.0002)    #优化器,学习率lr是超参数

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()                                        #有些步骤在train过程使用,test过程不用,比如dropout,BN,所以使用net.train和net.eval区分
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)    #进度条,(iterable=可迭代对象,file: 输出指向位置, 默认是终端, 一般不需要设置)
        for step, data in enumerate(train_bar):
            images, labels = data                           # 从data中读取数据,也就是从train_loder中取数据
            optimizer.zero_grad()                           #梯度清零
            outputs = net(images.to(device))                #将数据传入网络
            loss = loss_function(outputs, labels.to(device))  #计算损失
            loss.backward()                                  #反向传播
            optimizer.step()                                 #更新权重

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]                       #预测概率最大的种类
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item() #如果预测正确的话,acc+=1

        val_accurate = acc / val_num                                           #acc/测试集总数,得到正确率
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        print()

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

接下来训练猫狗分类数据集,首先修改数据集路径,然后需要在上面的train.py做如下修改,即需要先加载预训练模型,即花分类模型,然后修改最后的全连接层,将输出通道数改为我们的分类个数

    net = AlexNet(num_classes=5,init_weights=True)    #这里设置的num_classes必须与预训练权重模型一致
    #微调,使用花朵分类训练好的权重作为预训练权重
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    net.load_state_dict(torch.load(weights_path),strict=False)      #加载预训练权重

    #修改最后的全连接层,将输出通道数改为我们的分类个数
    net.classifier= nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, 2),
        )

    net.to(device)

这样我们完成了从花分类数据集到猫狗分类数据集的迁移学习。

predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    #图片预处理函数
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = r"E:\01_school_study\11_deep_learning_for_image_processing_master\data_set\flower_data\val\tulips\38287568_627de6ca20.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)

    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)   #图片预处理

    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)  #扩充一个batch维度,因为这里读入的图片只有C,W,H三个维度——》[N, C, H, W]

    # read class_indict
    json_path = './flower_class.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = AlexNet(num_classes=5).to(device)

    # load model weights
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    #加载权重文件
    model.load_state_dict(torch.load(weights_path))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

二、使用VGG结合迁移学习训练自己的分类数据集

上面使用alexnet完成了从自己的一个数据集到另一个数据集的迁移学习,接下来使用vgg实现从官方加载预训练模型进行迁移学习。

vggnet.py

因为我们将加载pytorch官方实现的模型,其实这里是用不上这个的,还是写在这里

import torch.nn as nn
import torch

# official pretrain weights
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()

        #主干网络,作为函数的一个参数,由人为传入
        self.features = features

        #分类头
        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        #input: N x 3 x 224 x 224
        x = self.features(x)                 # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)    #从channel这个维度开始展平   # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():               #遍历每一层
            if isinstance(m, nn.Conv2d):       #如果是卷积层
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:         #如果有偏置
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):     #如果是全连接层
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

#创建主干网络
def make_features(cfg: list):
    layers = []         #存放我们创建的每一层结构
    in_channels = 3     #输入通道为3
    for v in cfg:
        if v == "M":    #如果是池化层
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:           #卷积层
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)    #(输入深度,输出深度,卷积核尺寸,padding)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v                                                 #更新输入深度
    return nn.Sequential(*layers)                                           # *,通过非关键字参数形式输入到Sequential,就可以生成网络结构


#对应readme表中VGG不同配置的参数,卷积核个数,‘M’池化
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name="vgg16", **kwargs):
    assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
    cfg = cfgs[model_name]

    model = VGG(make_features(cfg), **kwargs)      #make_features(cfg)就是在创建主干网络,并作为参数,传入VGG
    return model


# vgg_model=vgg(model_name='vgg16')

train.py

这里是加载自己写的模型,重头开始训练

#不采用预训练权重,直接从头开始训练10轮,准确率为0.816

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = 0  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 10
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

train_.py

这里是直接导入预训好的vgg16网络,微调

#直接导入预训好的vgg16网络,微调
#采用预训练权重训练10轮,不改变数据标准化处理,准确率达到0.93
#采用预训练权重训练10轮,改变数据标准化处理与预训练模型一致,准确率达到0.89

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets,models
import torch.optim as optim
from tqdm import tqdm

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = 0  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    # 导入与与训练好的VGG16网络
    # vgg16 = models.vgg16(pretrained=True)   #会自动下载vgg16的预训练权重到C:\Users\Administrator/.cache\torch\hub\checkpoints\vgg16-397923af.pth

    #如果已经下载好了,可以使用以下方法加载
    vgg16=models.vgg16(pretrained=False)

    #如果不想使用官方的预训练模型,只想使用它的vgg网络,直接屏蔽一下代码即可
    weights_path='vgg16_pre.pth'
    vgg16.load_state_dict(torch.load(weights_path),strict=False)
    # 将vgg16的特征提取层参数冻结,不对其进行更新
    for param in vgg16.features.parameters():
        param.requires_grad_(False)
    net = vgg16

    #修改最后的全连接层,这里的classifier可以ctrl键点击models.vgg16查看
    net.classifier = nn.Sequential(
        nn.Linear(512 * 7 * 7, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, 5),
    )

    net.to(device)


    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 10
    best_acc = 0.0
    save_path = './VGG16Net.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')

if __name__ == '__main__':
    main()

到这里我们就实现了从官方加载模型并进行迁移学习,可以只加载其初始化的模型重头开始训练,也可以带预训练权重一起加载,进行迁移学习。

predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)
    
    # create model
    model = vgg(model_name="vgg16", num_classes=5).to(device)
    # load model weights
    weights_path = "./vgg16Net.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

你可能感兴趣的:(分类,迁移学习,python)