Pytorch-猫狗分类实战(上)

Pytorch物体识别分类实战案例(上)

前面更新了很多关于pytorch中的函数和模块,但理论学得再好不如实战来得更快,不停地敲敲敲敲代码,你终会变强!

这期的物体识别案例是个人纯手写,主要是用于熟悉pytorch搭建一个完整深度学习项目的所有过程,顺带熟悉pytorch的代码结构,其中包括数据预处理、数据集的加载、bath的读入等等,当我整个代码全部写完调试清楚,对于pytorch的代码构成有了一个大体的框架认识,看起github上别人优秀的代码会省下很多力。

话不多说,接下来看看整个项目的详细内容。

1.项目的文件构成及含义

Pytorch-猫狗分类实战(上)_第1张图片CAT_DOG:这是整个项目的文件夹
data:所有的图片数据都放在该文件夹里
model:用于存放训练好的模型
config.py:项目中使用到的超参数的设置,包括learning_rate,batch_size,epoch等等
data_augumentation.py:数据增强部分代码
data_lodar.py:数据加载部分代码
data_process.py:数据预处理部分代码,从图片变成可读入的文本
loss_function.py:损失函数代码
test.py:最终测试代码
train.py:训练代码
vgg16_net.py:搭建vgg16网络代码
train.txt:用于训练的txt文本,包含图片路径和label
arialuni.ttf:字体文件

本项目中data文件的构成如下:
在这里插入图片描述其中猫的所有图片放在Cat文件夹下,狗的所有图片放在Dog文件夹下。

该数据是kaggle竞赛里的猫和狗的图片,网上都有资源可下载,这里也放一个传送门:提取码:35zi

用于训练的txt文件train.txt的组成如下:
Pytorch-猫狗分类实战(上)_第2张图片其中每一行包含图像路径和label,中间用空格隔开。

这是分类任务,针对检测任务,格式和这一样,只不过label中包含了boundingbox的四个坐标和所属类别一共5个值,基本格式如下:
E:/pycharm/pycharmprojects/pytorch/CAT_DOG/data/Cat/0.jpg x1min,y1min,x1max,y1max,cls1 x2min,y2min,x2max,y2max,cls2…

这可以算是一个项目文件的基本构成模板,以后有新的任务都可以按这个模板去加一些东西进对应的文件里。

2. data_process.py 数据预处理部分

这部分代码就是生成训练可读的txt文本,然后从文本加载图片和label,不同的数据具体处理流程不太一样,但大体相同,可自行阅读体会,每行都写了注释,不明白可留言询问。

import os

root = './data/'	#定义数据所在的根文件夹
dir = os.listdir(root)	#遍历数据文件夹data下的所有类别文件夹
classes = ['cat', 'dog']	#定义类别数

def read_picture(dirs):	#该函数用于读取图片路径
    cat = []	#定义猫这一类的list,用于存放所有猫的路径
    dog = []	#定义狗这一类的list,用于存放所有狗的路径

    for d in dirs:	#循环data文件下的所有类别文件夹
        path = os.path.join(root, d)	#拼接路径'./data/Cat'
        pictures = os.listdir(path)		#遍历该类别文件夹下的所有图片
        #print(len(pictures))
        #print(pictures[0].split('.'))
        if d == 'Cat':	#如果子文件夹是Cat
            for img in range(len(pictures)):	#遍历Cat文件夹中的所有图片
                cat_name = pictures[img].split('.')		#将图片名称分为图片和‘jpg’后缀
                cat_imgname = cat_name[0]	#取出图片名
                cat.append(cat_imgname)		#将图片名放到list中
        else:	#如果子文件是Dog
            for img in range(len(pictures)):	#遍历Dog文件夹中的所有图片
                dog_name = pictures[img].split('.')		#将图片名称分为图片和‘jpg’后缀
                dog_imgname = dog_name[0]	#取出图片名
                dog.append(dog_imgname)		#将图片名放到list中
    return cat, dog

def write_txt():	#该函数用于将图片路径和label写入txt文件中
    cat, dog = read_picture(dir)	#先生成cat和dog的图片名
    txt = open('train.txt', 'w')	#打开要写入的txt文件
    for i in range(len(cat)):	#遍历所有猫的图片
        cat_id = classes.index('cat')	#生成对应的labenl,编码后的label,猫的label为0
        txt.write('E:/pycharm/pycharmprojects/pytorch/CAT_DOG/data/Cat/%s.jpg %d' % (cat[i], cat_id))	#写入txt
        txt.write('\n')		#一行结束换行

    # txt = open('train.txt', 'a')
    for j in range(len(dog)):	#遍历所有狗的图片
        dog_id = classes.index('dog')	#生成对应的labenl,编码后的label,狗的label为1
        txt.write('E:/pycharm/pycharmprojects/pytorch/CAT_DOG/data/Dog/%s.jpg %d' % (dog[j], dog_id))	#写入txt
        txt.write('\n')		#一行结束换行

    print("Finished!")

if __name__ == '__main__':
    write_txt()

3.config.py 超参数的定义

这个py文件用于存放训练中使用到的所有超参数,当然这些超参数不是说一定要在这定义,你也可以在各个py文件中单独定义超参数,这样单独使用一个py文件定义超参数的好处在于便于修改,比如你需要更改学习率或者batch又或是epoch,不用再去每个文件中寻找参数的位置,直接在这个文件中更改即可,具体怎么定义看你们自己的喜好了。

NUM_CLASSES = 2		#类别数量

CLASS_NAMES = ['cat', 'dog']	#类别名称,这个名称得特别注意,和在data_process.py中的list要一致,
								#因为cat的类别编码成了0,cat的类别只能在list的第一个位置,
								#不然最终会label对不上。

LEARNING_RATE = 0.0001	#学习率

EPOCHS = 5000	#总的epoch数量,即整个数据集会训练的次数

BATCH_SIZE = 2	#batch大小,一次加载两张图片

INPUT_SIZE = 224	#输入网络的图片的大小

3.data_lodar.py 数据加载部分代码

这个py文件中定义了数据怎么包装成torch中Dataset模块的样子,然后可被torch中DataLoadar模块调用,这两个模块在上一篇文章中有详细介绍,放个传送门:pytorch中的数据加载:Dataset与DataLoader

这部分代码也可以当成一个模板,后续再需要处理不同的数据时,只需要重写__init__、 与__len__、 __getitem__三个方法中的代码即可,大家可以细细品味琢磨以下,欢迎大家在评论区询问不懂的地方。

from torch.utils.data.dataset import Dataset    #通过继承torch.utils.data.dataset创建自定义数据集
import cv2

class Data_loader(Dataset):	#继承torch中的Dateset类
    def __init__(self, train_lines, transform):	#初始化参数
        super(Data_loader, self).__init__()	#继承Data_loader
        self.train_lines = train_lines	#读取的txt文件中的数据
        self.transform = transform	#定义数据增强方式

    def __len__(self):	#数据长度
        return len(self.train_lines)	#返回数据长度

    def __getitem__(self, idx):	#定义如何去数据,idx为数据的编号
        img_path = self.train_lines[idx].split(' ')[0]	#从train_lines将图片路径取出来
        label = self.train_lines[idx].split(' ')[1]	#从train_lines将对应图片label取出来
        img = cv2.imread(img_path)	#读取图片矩阵
        sample = {'image':img, 'label':label}	#将图片矩阵和对应的label值使用字典的方式存入一个样本sample中

        if self.transform:	#如果transform有值,则进行数据对应的数据增强
            sample = self.transform(sample)	#进行数据增强后的样本sample,仍然是一个字典形式
        return sample

4.data_augumentation.py 数据增强部分

这个py文件中主要定义了各种数据增强的方法。
数据增强在深度学习中是不可缺少的一部分,可以解决数据量不足的缺陷。

这个py文件也相当于一个模板,本项目只定义了随即缩放和随机裁剪两种数据增强方法的类,以后还有其他的数据增强方法,只需要在该文件中继续定义类,在类下创建方法即可。

import numpy as np
import torch
import config as cfg
import cv2

class Rescale(object):	#定义随机缩放的类 
    def __init__(self, output_size):	#初始化参数
        self.output_size = output_size	#随机缩放输出图片的大小
    def __call__(self,sample):	#定义随机缩放类的具体方法
        img= sample['image']	#取出sample字典中的图像矩阵
        label = sample['label']	#取出sample字典中的图像对应的label
        h, w = img.shape[:2]     #取出图片的高和宽,cv2读取的图像格式为(高,宽,channel)
   
        if h>w:
            new_h, new_w = self.output_size * h / w, self.output_size  # 保证了高宽比不变,短边直接
          	  														  #变成输出图像的大小,对长边进行缩放
        else:
            new_h, new_w = self.output_size, self.output_size * w / h
        new_img = cv2.resize(img, (int(new_w), int(new_h)))     #cv2的resize方法,输入是(宽,高)
        return {'image':new_img, 'label':label}	#返回缩放后的字典

class RandomCrop(object):	#定义随机裁剪的类
    def __init__(self ):	#初始化参数
        self.output_size = cfg.INPUT_SIZE	#随机裁剪后的图像直接会输入进网络层中进行前向传播,
        									#因此输出大小即为config.py中定义的输入图片大小
    def __call__(self, sample):	#定义随机裁剪类的具体方法
        img = sample['image']	#取出sample字典中的图像矩阵
        label = sample['label']	#取出sample字典中的图像对应的label
        h,w = img.shape[:2]     #取出图片的高和宽,cv2读取的图像格式为(高,宽,channel)
        new_h ,new_w = self.output_size, self.output_size	#定义裁剪后图像的大小,即为网络输入大小

        crop_x = np.random.randint(0, w - new_w)	#随机生成裁剪图像左上角的x坐标
        crop_y = np.random.randint(0, h - new_h)	#随机生成裁剪图像左上角的y坐标

        new_img = img[crop_y:crop_y+new_h, crop_x:crop_x+new_w]	#裁剪图像
        return {'image':new_img, 'label':label}

class ToTensor(object):	#定义将数据转换为torch需要的tensor的类
    def __call__(self, sample):	#定义方法
        image, label = sample['image'], sample['label']	#取出sample中的image矩阵和对应的label
        image = image.transpose((2, 0, 1))	#将图像的channel维度放第一位,torch的要求,
        									#cv2读取的图像为(h,w,C),需要转换为(C,h,w)
        
        label = int(label)	#将label转换成int型数据
        label = np.array(label)	#转换为numpy array的形式
        return {
            'image':torch.from_numpy(image),	#将image从array形式转换为tensor
            'label':torch.from_numpy(label)		#将label从array形式转换为tensor
        }

你可能感兴趣的:(pytorch)