VGG Code Implementation

Contents

  • Development environment
  • Preparation
  • Project code structure
  • Cat/dog dataset builder
  • vgg16 model loader
  • VGG inference demo
  • VGG cat/dog dataset training script
  • VGG network architecture code

Development environment

python 3.7
torch 1.8 (cu101)
torchsummary
torchvision 0.6.1 (cu101)
PIL (pillow)
numpy
opencv-python
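
As a quick sanity check, a short snippet like the one below (not part of the original project) prints the installed torch/torchvision versions and whether a CUDA build is usable; the exact version strings on your machine may differ from the list above.

import torch
import torchvision

# Report the installed versions and CUDA availability.
print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())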

Preparation

Pretrained VGG models, download URLs:
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
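
The .pth files can be fetched from these URLs in a browser, or programmatically; below is a minimal sketch using torch.hub (the destination path data/vgg16-397923af.pth is an assumption that matches the project layout described later).

import torch

# Download the vgg16 ImageNet weights into the project's data/ folder
# (destination path is an assumption; adjust it to your own layout).
url = "https://download.pytorch.org/models/vgg16-397923af.pth"
torch.hub.download_url_to_file(url, "data/vgg16-397923af.pth")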

Dogs vs. Cats dataset, download URL:
www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data

Project code structure

[Figure 1: project directory structure]

  • data holds the Dogs vs. Cats training set, the pretrained VGG weights, the image used by the test demo, and the ImageNet class-name files.
  • results stores plotting/TensorBoard output files.
  • src stores the VGG inference demo script and the script that trains VGG on the cat/dog dataset.
  • tools stores shared utilities: the cat/dog dataset builder and the model loader. (A plausible layout is sketched below.)
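
For orientation, a layout consistent with the bullets above and with the import paths used later (tools.my_dataset, tools.common_tools) might look like this; the script names under src are not given in the original post.

project/
├── data/        # cat/dog training images, vgg16-397923af.pth, demo image, ImageNet class-name files
├── results/     # one time-stamped output folder per training run
├── src/         # inference demo script and cat/dog training script (names not given in the post)
└── tools/
    ├── my_dataset.py      # CatDogDataset
    └── common_tools.py    # get_vgg16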

Cat/dog dataset builder

import os
import random
from PIL import Image
from torch.utils.data import Dataset

random.seed(1)


class CatDogDataset(Dataset):
    def __init__(self, data_dir, mode="train", split_n=0.9, rng_seed=620, transform=None):
        """
        猫狗数据分类任务的Dataset
        :param data_dir: 数据集所在路径
        :param mode: train/vaild
        :param split_n: 划分数据刻度
        :param rng_seed: 随机种子
        :param transform: torch.transform,数据预处理
        """
        self.mode = mode
        self.data_dir = data_dir
        self.rng_seed = rng_seed
        self.split_n = split_n
        self.data_info = self._get_img_info()  # data_info holds every (image path, label) pair; the DataLoader fetches samples by index
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # apply the transforms here (convert to tensor, etc.)

        return img, label

    def __len__(self):
        if len(self.data_info) == 0:
            raise Exception("\ndata_dir:{} is an empty dir! Please check your path to images!".format(self.data_dir))
        return len(self.data_info)

    def _get_img_info(self):

        img_names = os.listdir(self.data_dir)
        img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

        random.seed(self.rng_seed)
        random.shuffle(img_names)

        img_labels = [0 if n.startswith('cat') else 1 for n in img_names]

        split_idx = int(len(img_labels) * self.split_n)  # 25000 * 0.9 = 22500
        # split_idx = int(100 * self.split_n)
        if self.mode == "train":
            img_set = img_names[:split_idx]     # first 90% of the images for training
            # img_set = img_names[:22500]       # hard-coded 90% split
            label_set = img_labels[:split_idx]
        elif self.mode == "valid":
            img_set = img_names[split_idx:]
            label_set = img_labels[split_idx:]
        else:
            raise Exception("Unrecognized mode '{}'; only 'train' and 'valid' are supported".format(self.mode))

        path_img_set = [os.path.join(self.data_dir, n) for n in img_set]
        data_info = [(n, l) for n, l in zip(path_img_set, label_set)]

        return data_info
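
A minimal usage sketch of CatDogDataset (the data_dir value and the transform pipeline are assumptions; any transform that ends in ToTensor will do):

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tools.my_dataset import CatDogDataset

# Build the training split and wrap it in a DataLoader (paths/values are illustrative).
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
train_set = CatDogDataset(data_dir="Data/train", mode="train", split_n=0.9,
                          transform=train_transform)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
print(len(train_set))  # 22500 when the full 25000-image Kaggle set is present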

vgg16 model loader

import torch
import torchvision.models as models

def get_vgg16(path_state_dict, device, vis_model=False):
    """
    创建模型,加载参数
    :param path_state_dict: 模型参数
    :param device: cuda or cpu
    :param vis_model: 可视化模型 or not
    :return:
    """
    model = models.vgg16()
    # model = models.vgg16_bn()
    pretrained_state_dict = torch.load(path_state_dict)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    if vis_model:
        from torchsummary import summary
        summary(model, input_size=(3, 224, 224), device="cpu")

    model.to(device)
    return model
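
A usage sketch for get_vgg16 (the weight path is an assumption; the returned model is already in eval mode):

import torch
from tools.common_tools import get_vgg16

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_vgg16("data/vgg16-397923af.pth", device, vis_model=False)

# Dummy forward pass to confirm the model produces 1000 ImageNet logits.
with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224, device=device)
    print(model(dummy).shape)  # torch.Size([1, 1000])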

VGG inference demo

# imports
import os
import time
import json
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
from tools.common_tools import get_vgg16

os.environ['NLS_LANG'] = 'SIMPLIFIED CHINESE_CHINA.UTF8'
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

# apply the given transform to an image
def img_transform(img_rgb, transform=None):
    """
    Convert a PIL image into the tensor form the model expects.
    :param img_rgb: PIL Image object
    :param transform: torchvision transform
    :return: tensor
    """
    if transform is None:
        raise ValueError("transform is None! A transform must be provided to process the image.")
    img_t = transform(img_rgb)
    return img_t


# preprocess the input image for inference
def process_img(path_img):
    # hard code
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    inference_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    # path --> img
    img_rgb = Image.open(path_img).convert('RGB')
    # img --> tensor
    img_tensor = img_transform(img_rgb, inference_transform)
    img_tensor.unsqueeze_(0)  # chw --> bchw
    img_tensor = img_tensor.to(device)
    return img_tensor, img_rgb


# load the class names
def load_class_names(p_clsnames, p_clsnames_cn):
    """
    Load the ImageNet class names.
    :param p_clsnames: path to the English class-name json file
    :param p_clsnames_cn: path to the Chinese class-name text file
    :return: (class_names, class_names_cn)
    """
    with open(p_clsnames, "r") as f:
        class_names = json.load(f)
    with open(p_clsnames_cn, encoding='UTF-8') as f:  # the Chinese file is UTF-8 encoded
        class_names_cn = f.readlines()
    return class_names, class_names_cn


if __name__ == "__main__":
    # config
    path_state_dict = os.path.join(BASE_DIR, "..", "data", "vgg16-397923af.pth")
    # path_img = os.path.join(BASE_DIR, "..", "..", "Data","Golden Retriever from baidu.jpg")
    path_img = os.path.join(BASE_DIR, "..", "..", "Data", "tiger cat.jpg")
    path_classnames = os.path.join(BASE_DIR, "..", "..", "Data", "imagenet1000.json")
    path_classnames_cn = os.path.join(BASE_DIR, "..", "..", "Data", "imagenet_classnames.txt")
    
    # load class names
    cls_n, cls_n_cn = load_class_names(path_classnames, path_classnames_cn)
    # 1/5 load img
    img_tensor, img_rgb = process_img(path_img)
    
    # 2/5 load model
    vgg_model = get_vgg16(path_state_dict, device, True)
    # 3/5 inference  tensor --> vector
    with torch.no_grad():
        time_tic = time.time()
        outputs = vgg_model(img_tensor)
        time_toc = time.time()
    # 4/5 index to class names
    _, pred_int = torch.max(outputs.data, 1)
    _, top5_idx = torch.topk(outputs.data, 5, dim=1)

    pred_idx = int(pred_int.cpu().numpy())
    pred_str, pred_cn = cls_n[pred_idx], cls_n_cn[pred_idx]
    print("img: {} is: {}\n{}".format(os.path.basename(path_img), pred_str, pred_cn))
    print("time consuming:{:.2f}s".format(time_toc - time_tic))

    # 5/5 visualization
    plt.imshow(img_rgb)
    plt.title("predict:{}".format(pred_str))
    top5_num = top5_idx.cpu().numpy().squeeze()
    text_str = [cls_n[t] for t in top5_num]
    for idx in range(len(top5_num)):
        plt.text(5, 15 + idx * 30, "top {}:{}".format(idx + 1, text_str[idx]), bbox=dict(fc='yellow'))
    plt.show()
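
The demo prints raw logits for the top-5 classes; if you want human-readable confidences, a small follow-up sketch (reusing the outputs and cls_n variables produced by the script above) converts them to probabilities with a softmax:

    # Continues the demo above: turn logits into probabilities and list the top-5.
    probs = torch.softmax(outputs, dim=1)
    top5_prob, top5_i = torch.topk(probs, 5, dim=1)
    for p, i in zip(top5_prob.squeeze(0).tolist(), top5_i.squeeze(0).tolist()):
        print("{:>7.2%}  {}".format(p, cls_n[i]))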

VGG cat/dog dataset training script

import os
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from tools.my_dataset import CatDogDataset
from tools.common_tools import get_vgg16
from datetime import datetime


BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


now_time = datetime.now()
time_str = datetime.strftime(now_time, '%m-%d-%H-%M')
log_dir = os.path.join(BASE_DIR, "..", "results", time_str)
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

if __name__ == "__main__":
    # config
    data_dir = os.path.join(BASE_DIR, "..", "..", "Data", "train")
    path_state_dict = os.path.join(BASE_DIR, "..", "data", "vgg16-397923af.pth")
    num_classes = 2

    MAX_EPOCH = 3
    BATCH_SIZE = 32
    LR = 0.001

    log_interval = 2
    val_interval = 1
    classes = 2
    start_epoch = -1
    lr_decay_step = 1

    # ============================ step 1/5 data ============================
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
    normalizes = transforms.Normalize(norm_mean, norm_std)
    valid_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.TenCrop(224, vertical_flip=False),
        transforms.Lambda(lambda crops: torch.stack([normalizes(transforms.ToTensor()(crop)) for crop in crops])),
    ])

    # build Dataset instances
    train_data = CatDogDataset(data_dir=data_dir, mode="train", transform=train_transform)
    valid_data = CatDogDataset(data_dir=data_dir, mode="valid", transform=valid_transform)

    # build DataLoaders
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
    valid_loader = DataLoader(dataset=valid_data, batch_size=4)

    # ============================ step 2/5 model ============================
    vgg16_model = get_vgg16(path_state_dict, device, False)
    num_ftrs = vgg16_model.classifier._modules["6"].in_features
    vgg16_model.classifier._modules["6"] = nn.Linear(num_ftrs, num_classes)

    vgg16_model.to(device)
    
    # ============================ step 3/5 loss function ============================
    criterion = nn.CrossEntropyLoss()
    # ============================ step 4/5 optimizer ============================
    # optionally train the convolutional layers with a smaller (or zero) learning rate
    flag = 0
    # flag = 1
    if flag:
        fc_params_id = list(map(id, vgg16_model.classifier.parameters()))  # ids (memory addresses) of the classifier parameters
        base_params = filter(lambda p: id(p) not in fc_params_id, vgg16_model.parameters())
        optimizer = optim.SGD([
            {'params': base_params, 'lr': LR * 0.1},  # use 0 here to fully freeze the conv layers
            {'params': vgg16_model.classifier.parameters(), 'lr': LR}], momentum=0.9)
    else:
        optimizer = optim.SGD(vgg16_model.parameters(), lr=LR, momentum=0.9)  # plain SGD over all parameters

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
    # ============================ step 5/5 training ============================
    train_curve = list()
    valid_curve = list()

    for epoch in range(start_epoch + 1, MAX_EPOCH):

        loss_mean = 0.
        correct = 0.
        total = 0.

        vgg16_model.train()
        for i, data in enumerate(train_loader):

            # forward
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = vgg16_model(inputs)

            # backward
            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()

            # update weights
            optimizer.step()

            # accumulate classification statistics
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).squeeze().cpu().sum().numpy()

            # log training progress
            loss_mean += loss.item()
            train_curve.append(loss.item())
            if (i+1) % log_interval == 0:
                loss_mean = loss_mean / log_interval
                print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%} lr:{}".format(
                    epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total, scheduler.get_last_lr()))
                loss_mean = 0.

        scheduler.step()  # update the learning rate

        # validate the model
        if (epoch+1) % val_interval == 0:

            correct_val = 0.
            total_val = 0.
            loss_val = 0.
            vgg16_model.eval()
            with torch.no_grad():
                for j, data in enumerate(valid_loader):
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)

                    bs, ncrops, c, h, w = inputs.size()
                    outputs = vgg16_model(inputs.view(-1, c, h, w))
                    outputs_avg = outputs.view(bs, ncrops, -1).mean(1)

                    loss = criterion(outputs_avg, labels)

                    _, predicted = torch.max(outputs_avg.data, 1)
                    total_val += labels.size(0)
                    correct_val += (predicted == labels).squeeze().cpu().sum().numpy()

                    loss_val += loss.item()

                loss_val_mean = loss_val/len(valid_loader)
                valid_curve.append(loss_val_mean)
                print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                    epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))
            vgg16_model.train()

    train_x = range(len(train_curve))
    train_y = train_curve

    train_iters = len(train_loader)
    valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval  # the valid curve logs one loss per epoch, so map its points onto the iteration axis
    valid_y = valid_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Iteration')
    plt.savefig(os.path.join(log_dir, "loss_curve.png"))
    plt.show()
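
The script above only saves the loss-curve figure; if you also want to keep the fine-tuned weights, an optional addition (not in the original post) is to dump the state dict into the same time-stamped results folder:

    # Optional: save the fine-tuned weights next to the loss curve.
    path_checkpoint = os.path.join(log_dir, "vgg16_catdog.pth")
    torch.save(vgg16_model.state_dict(), path_checkpoint)
    print("model saved to {}".format(path_checkpoint))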
	

VGG network architecture code

The official PyTorch model zoo provides VGG11, VGG11_bn, VGG13, VGG13_bn, VGG16, VGG16_bn, VGG19, and VGG19_bn.
[Figure 2]

import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url

# download URLs for the pretrained weights (same list as in the Preparation section)
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}

# VGG network class
class VGG(nn.Module):

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
# convolutional-layer configurations ('M' = max pooling); A → VGG11, B → VGG13, D → VGG16, E → VGG19
cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

# build the convolutional feature extractor from a configuration list
def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

# generic VGG builder: assembles the network and optionally loads pretrained weights
def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model

# vgg16 as an example (configuration 'D', no batch norm)
def vgg16(pretrained=False, progress=True, **kwargs):
    return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
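
A quick usage sketch of the code above; with pretrained=True the weights are downloaded from model_urls on first use, otherwise the model is randomly initialized:

if __name__ == "__main__":
    # Build VGG16 (configuration 'D') and run a dummy forward pass.
    model = vgg16(pretrained=False, num_classes=1000)
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # torch.Size([1, 1000])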
