RCNN Source Code Walkthrough (III) --- The Finetune Training Process

Contents

0. Recap

1. The finetune binary-classification code (finetune.py)

1.1 load_data (the data loading helper)

1.2 The CustomFinetuneDataset class

1.3 The CustomBatchSampler class (custom_batch_sampler.py)

1.4 Training: train_model


0. Recap

In the previous post we processed the detection data into a dataset suitable for binary classification; it lives under the classifier_car directory.
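For orientation, the layout that the code below expects looks roughly like this (a sketch reconstructed from the paths used later in this post; the exact name of the csv file read by parse_car_csv is an assumption):

./data/classifier_car/
    train/
        car.csv               # sample name list (assumed file name, read by parse_car_csv)
        JPEGImages/           # one .jpg per sample, e.g. 000012.jpg
        Annotations/
            000012_1.csv      # positive region proposals for that image
            000012_0.csv      # negative region proposals for that image
    val/                      # same structure as train/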

1. The finetune binary-classification code (finetune.py)

from image_handler import show_images
import numpy as np
 
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_loaders, data_sizes = load_data('./data/classifier_car')
    # Load the AlexNet network
    model = models.alexnet(pretrained=True)

    print(model)
    data_loader = data_loaders["train"]

    print("One iteration yields a full batch of positive and negative samples; with more classes it would yield samples of every class")
    """
    index: 323 image_id: 200 target: 1 image.shape: (254,342,3) [xmin,ymin,xmax,ymax]: [80,39,422,293]
    """

    # inputs holds 128 region crops, targets holds the 128 labels (0/1)
    inputs, targets = next(data_loader.__iter__())
    print(inputs[0].size(), type(inputs[0]))
    trans = transforms.ToPILImage()
    print(type(trans(inputs[0])))
    print(targets)
    print(inputs.shape)
    titles = ["True" if i.item() else "False" for i in targets[0:60]]
    images = [np.array(trans(i)) for i in inputs[0:60]]
    show_images(images, titles=titles, num_cols=12)

    # Turn AlexNet into a binary classifier: replace the last layer with a 2-class output
    num_features = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(num_features, 2)

    print("AlexNet converted to a binary classifier, the last layer now outputs 2 classes", model)
    model = model.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss()
    # Optimizer
    optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
    # Learning rate decay
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Start training
    best_model = train_model(data_loaders, model, criterion, optimizer, lr_scheduler, device=device,
                             num_epochs=10)

    check_dir('./models')
    torch.save(best_model.state_dict(), 'models/alexnet_car.pth')

At the start we load the binary-classification training data prepared in the previous post.

Then we pull one batch from the train loader, just to visualize a few of the proposals.

Since this is a binary classifier, we also have to modify the network head.

Then training starts...

1.1 load_data (the data loading helper)

import os
import copy
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

from utils.data.custom_finetune_dataset import CustomFinetuneDataset
from utils.data.custom_batch_sampler import CustomBatchSampler
from utils.util import check_dir

def load_data(data_root_dir):
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((227, 227)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    data_loaders = {}
    data_sizes = {}

    for name in ['train', 'val']:
        data_dir = os.path.join(data_root_dir, name)
        data_set = CustomFinetuneDataset(data_dir, transform=transform)

        # Randomly sample 128 proposals from all the boxes (32 positive, 96 negative) per batch
        data_sampler = CustomBatchSampler(data_set.get_positive_num(), data_set.get_negative_num(), 32, 96)

        # Build the DataLoader
        data_loader = DataLoader(data_set, batch_size=128, sampler=data_sampler, num_workers=8, drop_last=True)
        data_loaders[name] = data_loader
        data_sizes[name] = data_sampler.__len__()

    return data_loaders, data_sizes

transform applies a series of changes to every input image; Compose chains the operations listed below so that they run in order.

The data_root_dir passed in above is ./data/classifier_car.

We handle the train split first, so data_dir is ./data/classifier_car/train, and the transform above is attached to the dataset.
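As a quick sanity check, here is a minimal sketch (not from the original repo) of what this pipeline does to a single cropped region: an H x W x 3 uint8 NumPy array comes out as a normalized 3 x 227 x 227 float tensor.

import numpy as np
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((227, 227)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# A fake 254x342 crop, just to check shapes and value range
crop = np.random.randint(0, 256, size=(254, 342, 3), dtype=np.uint8)
out = transform(crop)
print(out.shape)                            # torch.Size([3, 227, 227])
print(out.min().item(), out.max().item())   # roughly within [-1, 1] after Normalize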

1.2 The CustomFinetuneDataset class

import os
import cv2
import numpy as np

from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
 
from utils.util import parse_car_csv
 
class CustomFinetuneDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        # samples holds the image names
        # Read the list of all images in this train/val split
        samples = parse_car_csv(root_dir)

        # Load all the images
        jpeg_images = [cv2.imread(os.path.join(root_dir, 'JPEGImages', sample_name.zfill(6) + ".jpg")) for sample_name in samples]

        # Paths of all positive-box annotation files
        # Each image's positive proposals live in a file suffixed with _1.csv
        positive_annotations = [os.path.join(root_dir, 'Annotations', sample_name.zfill(6) + '_1.csv') for sample_name in samples]

        # Paths of all negative-box annotation files
        negative_annotations = [os.path.join(root_dir, 'Annotations', sample_name.zfill(6) + '_0.csv') for sample_name in samples]

        # Number of boxes per annotation file
        positive_sizes = list()
        negative_sizes = list()

        # Box coordinates
        positive_rects = list()
        negative_rects = list()
        
        sample_num = 1

        
        for annotation_path in positive_annotations:
            # rects may hold several boxes at once
            rects = np.loadtxt(annotation_path, dtype=np.int, delimiter=' ')
            # print("sample {} positive boxes: {}".format(sample_num, rects))
            sample_num += 1
            # The file may be empty or contain only a single row
            if len(rects.shape) == 1:
                # Is it a single row (4 coordinates)?
                if rects.shape[0] == 4:
                    positive_rects.append(rects)
                    positive_sizes.append(1)
                else:
                    positive_sizes.append(0)
            else:
                positive_rects.extend(rects)
                positive_sizes.append(len(rects))

        # In the end positive_rects holds every positive box, and positive_sizes records how many boxes each image contributes
        print("positive boxes: {}, positive annotation files: {}".format(len(positive_rects), len(positive_sizes)))
        
        for annotation_path in negative_annotations:
            # rects may hold several boxes at once
            rects = np.loadtxt(annotation_path, dtype=np.int, delimiter=' ')
            # print("sample {} negative boxes: {}".format(sample_num, rects))
            sample_num += 1
            # The file may be empty or contain only a single row
            if len(rects.shape) == 1:
                # Is it a single row (4 coordinates)?
                if rects.shape[0] == 4:
                    negative_rects.append(rects)
                    negative_sizes.append(1)
                else:
                    negative_sizes.append(0)
            else:
                negative_rects.extend(rects)
                negative_sizes.append(len(rects))

        # The negative boxes come right after the positive boxes in the index space!
        print("negative boxes: {}, negative annotation files: {}".format(len(negative_rects), len(negative_sizes)))
 
    
        # The transform to apply
        self.transform = transform

        # All source images (one per sample)
        self.jpeg_images = jpeg_images

        # Number of positive boxes per annotation file
        self.positive_sizes = positive_sizes

        # Number of negative boxes per annotation file
        self.negative_sizes = negative_sizes

        # List of positive boxes
        self.positive_rects = positive_rects

        # List of negative boxes
        self.negative_rects = negative_rects

        # Total number of positive boxes
        self.total_positive_num = int(np.sum(positive_sizes))

        # Total number of negative boxes
        self.total_negative_num = int(np.sum(negative_sizes))


    def __getitem__(self, index: int):
        """
        Train set: 621 positive boxes across 374 annotation files
        Train set: 357451 negative boxes across 374 annotation files
        Val set:   617 positive boxes across 335 annotation files
        Val set:   312808 negative boxes across 335 annotation files
        """

        # Figure out which image this index belongs to (default: the last image)
        image_id = len(self.jpeg_images) - 1
        # print(len(self.positive_sizes))  # 374

        # index is smaller than the number of positive boxes
        if index < self.total_positive_num:
            # Positive sample
            target = 1
            # Take the box (one of the 621 positive boxes)
            xmin, ymin, xmax, ymax = self.positive_rects[index]
            # Find the image it belongs to
            for i in range(len(self.positive_sizes) - 1):
                if np.sum(self.positive_sizes[:i]) <= index < np.sum(self.positive_sizes[:(i + 1)]):
                    image_id = i
                    break
            # Crop the region
            image = self.jpeg_images[image_id][ymin:ymax, xmin:xmax]
        else:
            # Negative sample
            target = 0
            idx = index - self.total_positive_num
            xmin, ymin, xmax, ymax = self.negative_rects[idx]
            # Find the image it belongs to
            for i in range(len(self.negative_sizes) - 1):
                if np.sum(self.negative_sizes[:i]) <= idx < np.sum(self.negative_sizes[:(i + 1)]):
                    image_id = i
                    break
            image = self.jpeg_images[image_id][ymin:ymax, xmin:xmax]

        # Apply the transform (needed so the DataLoader can batch fixed-size tensors)
        if self.transform:
            image = self.transform(image)
        return image, target

    # Total number of boxes
    def __len__(self) -> int:
        return self.total_positive_num + self.total_negative_num

    # Number of positive boxes
    def get_positive_num(self) -> int:
        return self.total_positive_num

    # Number of negative boxes
    def get_negative_num(self) -> int:
        return self.total_negative_num

The __getitem__ method is not that easy to follow, so let's walk through an example:


The box index ranges from 0 to (621 + 357451 - 1).

If the index is smaller than the total number of positive boxes (621), we locate the image the index belongs to and crop the corresponding region out of it. The search works as follows.

Let's use a smaller example: nine boxes belong to five images, so the indices we loop over are 0-8.

positive_sizes = [3, 2, 1, 1, 2], one entry per image: the five images contribute 3, 2, 1, 1 and 2 boxes respectively.

When i = 0 the slice self.positive_sizes[:0] is empty (prefix sum 0) and self.positive_sizes[:1] sums to 3; when i = 1 the two slices sum to 3 and 5, and so on. The condition therefore checks whether index falls between two consecutive prefix sums; if the loop never matches, the index belongs to the last image, which is exactly the default value of image_id.

This gives us image_id, the index of the image the small box belongs to, and we finally return the crop taken from self.jpeg_images[image_id] (374 images in the train split).

In short, __getitem__ returns the cropped region for the given box index together with its class label (0 for negative, 1 for positive).
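To make the lookup concrete, here is a small standalone sketch (not from the repo) of the same prefix-sum search, using the toy numbers from the example above:

import numpy as np

positive_sizes = [3, 2, 1, 1, 2]   # boxes per image: 9 boxes spread over 5 images

def find_image_id(index, sizes):
    """Same logic as in __getitem__: map a box index to the image it belongs to."""
    image_id = len(sizes) - 1          # default: the last image
    for i in range(len(sizes) - 1):
        if np.sum(sizes[:i]) <= index < np.sum(sizes[:(i + 1)]):
            image_id = i
            break
    return image_id

for index in range(sum(positive_sizes)):
    print(index, '->', find_image_id(index, positive_sizes))
# prints: 0,1,2 -> 0   3,4 -> 1   5 -> 2   6 -> 3   7,8 -> 4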

Let's write a small function to test this code:

def test(idx):
    root_dir = '../../data/classifier_car/train'
    train_data_set = CustomFinetuneDataset(root_dir)

    print('positive num: %d' % train_data_set.get_positive_num())
    print('negative num: %d' % train_data_set.get_negative_num())
    print('total num: %d' % train_data_set.__len__())

    image, target = train_data_set.__getitem__(idx)
    print('target: %d' % target)

    # image is the NumPy crop returned by __getitem__ (no transform was passed in)
    pil_image = Image.fromarray(image)
    print(pil_image)
    print(type(pil_image))

    cv2.imshow('image', image)
    cv2.waitKey(0)

1.3 The CustomBatchSampler class (custom_batch_sampler.py)

"""
(data_set.get_positive_num(),data_set.get_negative_num(),32,96)
total number of positive boxes, total number of negative boxes, 32, 96

"""

import random
import numpy as np
from torch.utils.data import Sampler


class CustomBatchSampler(Sampler):

    def __init__(self, num_positive, num_negative, batch_positive, batch_negative) -> None:
        """
        Sampler for the binary-classification dataset.
        Each batch contains batch_positive positive samples and batch_negative negative samples.
        @param num_positive: total number of positive samples
        @param num_negative: total number of negative samples
        @param batch_positive: positive samples per batch
        @param batch_negative: negative samples per batch
        """

        self.num_positive = num_positive
        self.num_negative = num_negative
        self.batch_positive = batch_positive
        self.batch_negative = batch_negative

        length = num_positive + num_negative
        # Build the index list
        self.idx_list = list(range(length))

        self.batch = batch_negative + batch_positive
        self.num_iter = length // self.batch

    def __iter__(self):
        sampler_list = list()
        for i in range(self.num_iter):
            """
            Take 32 random indices from the positive part of self.idx_list
            and 96 random indices from the negative part to form one batch
            """

            # From indices [0 : num_positive) (all positive boxes) sample batch_positive = 32,
            # and from indices [num_positive : ) (all negative boxes) sample batch_negative = 96
            tmp = np.concatenate(
                (random.sample(self.idx_list[:self.num_positive], self.batch_positive),
                 random.sample(self.idx_list[self.num_positive:], self.batch_negative))
            )

            # Shuffle the 128 indices of this batch
            random.shuffle(tmp)
            sampler_list.extend(tmp)
        # Return an iterator over all sampled indices
        return iter(sampler_list)

    # Number of batches * 128 (samples per epoch)
    def __len__(self) -> int:
        return self.num_iter * self.batch

    # Number of batches
    def get_num_batch(self) -> int:
        return self.num_iter

A quick test:

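A minimal sketch of such a test (my own, assuming the dataset from 1.2 is already on disk and both classes are importable as below):

from utils.data.custom_finetune_dataset import CustomFinetuneDataset
from utils.data.custom_batch_sampler import CustomBatchSampler

def test_sampler():
    root_dir = '../../data/classifier_car/train'
    data_set = CustomFinetuneDataset(root_dir)

    # 32 positives + 96 negatives per batch, the same numbers used in load_data
    sampler = CustomBatchSampler(data_set.get_positive_num(),
                                 data_set.get_negative_num(), 32, 96)

    print('batches per epoch: %d' % sampler.get_num_batch())
    print('samples per epoch: %d' % sampler.__len__())

    # Indices below get_positive_num() refer to positive proposals
    first_batch = list(sampler.__iter__())[:128]
    num_pos = sum(1 for i in first_batch if i < data_set.get_positive_num())
    print('positives in the first batch: %d' % num_pos)   # expect 32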

1.4 Training: train_model

def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epochs=25, device=None):

    since = time.time()
    best_model_weights = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in data_loaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history only if in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                lr_scheduler.step()

            # NOTE: data_sizes is assumed to be the dict returned by load_data at module level in finetune.py
            epoch_loss = running_loss / data_sizes[phase]
            epoch_acc = running_corrects.double() / data_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_weights = copy.deepcopy(model.state_dict())

        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # Load the best weights back into the model and return it, so the caller's best_model is usable
    model.load_state_dict(best_model_weights)
    return model

