解决过拟合问题的代码

1、使用更大的数据集

2、数据增强

数据增强包括翻转(水平或垂直)、任意角度旋转、随机缩放、随机裁剪、中心裁剪、添加噪声等。代码放在构建的dataloader中,相当于在训练中读取数据的时候进行随机的数据增强。

这里列出了翻转、旋转和裁剪。因为是显著性检测任务,image和label要一同变换。

from torch.utils.data import Dataset
from skimage import io, transform
import numpy as np
import scipy.io as scio
from sklearn.decomposition import PCA
import random
import torchvision.transforms as tf
from math import *
import cv2


def flip(image,mask): #水平翻转和垂直翻转
    if random.random()>0.5:
        image = np.flip(image, 1)
        mask = np.flip(mask, 1)
    if random.random()<0.5:
        image = np.flip(image, 2)
        mask = np.flip(mask, 2)
    return image, mask


def rotate_image(image, angle, center=None, scale=1.0):
    # grab the dimensions of the image
    (h, w) = image.shape[:2]
    # if the center is None, initialize it as the center of
    if center is None:
        center = (w // 2, h // 2)
    # perform the rotation
    M = cv2.getRotationMatrix2D(center, angle, scale)
    rotated = cv2.warpAffine(image, M, (w, h))
    # return the rotated image
    return rotated


def rotate(image,mask):
    angle = tf.RandomRotation.get_params([-180, 180]) # -180~180随机选一个角度旋转
    for i in range(image.shape[0]):
        image[i] = rotate_image(image[i], angle)
    for i in range(mask.shape[0]):
        mask[i] = rotate_image(mask[i], angle)
    return image, mask


def scale_up(image, mask):
    i_b, i_h, i_w = image.shape
    m_b, m_h, m_w = mask.shape
    sch = random.uniform(1.01,1.5)
    scw = random.uniform(1.01,1.5)
    ih = int(i_h * sch)
    iw = int(i_w * scw)
    image = transform.resize(image, (i_b, ih, iw))
    mask = transform.resize(mask, (m_b, ih, iw))
    rh = random.randrange(0,ih-i_h,1)
    rw = random.randrange(0,iw-i_w,1)
    image = image[:,rh:(rh+i_h),rw:(rw+i_w)]
    mask = mask[:,rh:(rh+i_h),rw:(rw+i_w)]
    return image, mask


def scale_down(image, mask):
    i_b, i_h, i_w = image.shape
    m_b, m_h, m_w = mask.shape
    sch = random.uniform(0.5,0.99)
    scw = random.uniform(0.5,0.99)
    ih = int(i_h * sch)
    iw = int(i_w * scw)
    image = transform.resize(image, (i_b, ih, iw))
    mask = transform.resize(mask, (m_b, ih, iw))
    rh = random.randrange(0,i_h-ih,1)
    rw = random.randrange(0,i_w-iw,1)
    image_ = np.zeros((i_b, i_h, i_w))
    mask_ = np.zeros((i_b, i_h, i_w))
    image_[:,rh:(rh+ih),rw:(rw+iw)] = image
    mask_[:,rh:(rh+ih),rw:(rw+iw)] = mask
    return image_, mask_


def scale(image, mask):
    if random.random()>0.5:
        image, mask = scale_down(image, mask)
    if random.random()<0.5:
        image, mask = scale_up(image, mask)
    return image, mask


def augmentation(img, label):
    img, label = scale(img, label)
    img, label = flip(img, label)
    img, label = rotate(img, label)
    
    return img, label

3、使用batchnorm,增大batchsize

增大batchsize可以使用SyncBatchNorm

在多卡下使用sync可以增大计算batchnrom时的size,以DDP为例

model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],output_device=args.local_rank,find_unused_parameters=True)

4、加入Dropout

dropout和batchsize一起使用会出现问题,eval()远低于train()下效果

需要使用均匀分布dropout,才能和batchnorm搭配使用,详情见我的另外一篇总结Dropout、高斯Dropout、均匀分布Dropout(Uout)_天明月落的博客-CSDN博客

5、加入正则化

正则化分为L1和L2,就是将模型参数量也考虑到优化里面,趋向于最简单的模型

L2正则化最简单,直接在定义网络optimizer的时候对weight_decay参数赋值,例如

optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.1)

值越大考虑的程度越深

L1正则化需要自己写

class Regularization(torch.nn.Module):
    def __init__(self,model,weight_decay,p=1):
        '''
        :param model 模型
        :param weight_decay:正则化参数
        :param p: 当p=2为L2正则化,p=1为L1正则化
        '''
        super(Regularization, self).__init__()
        if weight_decay <= 0:
            print("param weight_decay can not <=0")
            exit(0)
        self.model=model
        self.weight_decay=weight_decay
        self.p=p
        self.weight_list=self.get_weight(model)
        #self.weight_info(self.weight_list)
    def to(self,device):
        '''
        指定运行模式
        :param device: cude or cpu
        :return:
        '''
        self.device=device
        super().to(device)
        return self
    def forward(self, model):
        self.weight_list=self.get_weight(model)#获得最新的权重
        reg_loss = self.regularization_loss(self.weight_list, self.weight_decay, p=self.p)
        return reg_loss
    def get_weight(self,model):
        '''
        获得模型的权重列表
        :param model:
        :return:
        '''
        weight_list = []
        for name, param in model.named_parameters():
            if 'weight' in name:
                weight = (name, param)
                weight_list.append(weight)
        return weight_list
    def regularization_loss(self,weight_list, weight_decay, p=1):
        '''
        计算张量范数
        :param weight_list:
        :param p: 范数计算中的幂指数值,默认求2范数
        :param weight_decay:
        :return:
        '''
        # weight_decay=Variable(torch.FloatTensor([weight_decay]).to(self.device),requires_grad=True)
        # reg_loss=Variable(torch.FloatTensor([0.]).to(self.device),requires_grad=True)
        # weight_decay=torch.FloatTensor([weight_decay]).to(self.device)
        # reg_loss=torch.FloatTensor([0.]).to(self.device)
        reg_loss=0
        for name, w in weight_list:
            l2_reg = torch.norm(w, p=p)
            reg_loss = reg_loss + l2_reg
        reg_loss=weight_decay*reg_loss
        return reg_loss
    def weight_info(self,weight_list):
        '''
        打印权重列表信息
        :param weight_list:
        :return:
        '''
        print("---------------regularization weight---------------")
        for name ,w in weight_list:
            print(name)
        print("---------------------------------------------------")
if args.L1decay>0:
    reg_loss=Regularization(model, args.L1decay, p=1).to(device)
else:
    print("no regularization")
loss = fuse_loss(out, labels_v)
loss = loss + reg_loss(model).item()
loss.backward()

注释:先定义一个正则化类,改变参数可以分别去算L1、L2正则化。然后在计算loss时,把正则化loss和模型loss加起来。正则化loss可以理解为模型参数量的范数和,L1、L2对应两种范数计算方式。在反向传播的时候就会趋向于参数量小的模型达到防治过拟合的目的。

你可能感兴趣的:(深度学习,python,人工智能)