CNN数据增强(1)

数据增强(Data Augmentation

(本人水平有限,请广大读者批评指正!!!!)  

深度学习通常需要大量的数据作为支撑,看到那些公开的数据集,少的也有几十万张,但是在现实中,我们能拥有的数据集网络没有那么到。但是数据量少,往往会造成过拟合等问题,因此需要一些“奇巧淫技”来增强数据,正好本人在看斯坦福的CS231N课程中的这方面介绍,因此做个总结。

结合课程和网上查看的资料,将Data Augmentation总结如下:

1、水平/竖直翻转。

2、随机crop

3、颜色改变。

4、仿射/旋转变换

5、随机改变大小

6、加噪声

7、·······

        下面对上述方法中部分进行具体介绍:

1、Keras

Keras是以tensorflowtheano作为后端的一个极易上手的框架,本人比较懒,所以研究生阶段用的最多的也就是Keras。在Keras中专门有一个图像数据增加的工具ImageDataGenerator。它能满足数据增强的大部分需求。

直接上代码:

from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img

Datagen = ImageDataGenerator(rotation_range=40,
			     shear_range=0.2,
			     zoom_range=0.2,
			     horizontal_flip=True,
			     vertical_flip = True
			     fill_mode='nearest')
#还有其他一些参数,具体请看:https://keras.io/preprocessing/image/ ,如去均值,标准化,ZCA白化,旋转,
#偏移,翻转,缩放等

img = load_img('../data/hand1.jpg')#获取一个PIL图像
x_img = img_to_array(img)
x_img = x_img.reshape((1,)+ x_img.shape)

i = 0
for img_batch in Datagen.flow(x_img,
			      batch_size=1,
			      save_to_dir='../data/pre_Data/'
			      save_prefix='hand',
			      save_format='jpeg'):
	i +=1
	if i > 20:
		break

2caffe

    现在在用caffe,想研究一下caffe中的数据增强,只发现mirror、scale、crop三种,查看了一些资料,需要自己添加一些数据增强的代码,因此最近一直在研究caffe源码(一个c++很烂的人研究源码,也是蛋疼),。

后续补充。

CNN数据增强(1)_第1张图片


3、提供一个链接:https://github.com/aleju/imgaug

   看上去效果很好:

CNN数据增强(1)_第2张图片 

 

4、PCA Jittering的实现:

PCA Jittering最早是由Alex在他2012年赢得的ImageNet竞赛的那篇NIPS中提出的,首先按照RGB三个颜色通道计算均值和标准差,对网络的输入数据进行规范化,随后我们在整个训练集上计算了协方差矩阵,进行特征分解,得到特征向量和特征值,用来做PCA Jittering

本文根据:https://www.zhihu.com/question/35339639中提供的PCA Jittering的代码做了下实验。代码如下

# -*- coding: utf-8 -*-
"""
Created on Wed May 10 10:00:53 2017

@author: xx
"""

import numpy as np
import os
from PIL import Image, ImageOps
import argparse
import random
from scipy import misc

def PCA_Jittering(path):
    img_list = os.listdir(path)
    img_num = len(img_list)
    
    for i in range(img_num):
        img_path = os.path.join(path, img_list[i])
        img = Image.open(img_path)
        
        img = np.asanyarray(img, dtype = 'float32')
        
        img = img / 255.0
        img_size = img.size / 3
        img1 = img.reshape(img_size, 3)
        img1 = np.transpose(img1)
        img_cov = np.cov([img1[0], img1[1], img1[2]])
        lamda, p = np.linalg.eig(img_cov)
        
        p = np.transpose(p)
        
        alpha1 = random.normalvariate(0,3)
        alpha2 = random.normalvariate(0,3)
        alpha3 = random.normalvariate(0,3)
        
        v = np.transpose((alpha1*lamda[0], alpha2*lamda[1], alpha3*lamda[2]))    
        add_num = np.dot(p,v)
        
        img2 = np.array([img[:,:,0]+add_num[0], img[:,:,1]+add_num[1], img[:,:,2]+add_num[2]])
        
        img2 = np.swapaxes(img2,0,2)
        img2 = np.swapaxes(img2,0,1)
        save_name = 'pre'+str(i)+'.png'
        save_path = os.path.join(path, save_name)
        misc.imsave(save_path,img2)
        
        #plt.imshow(img2)
        #plt.show()

效果如下:


CNN数据增强(1)_第3张图片 CNN数据增强(1)_第4张图片

你可能感兴趣的:(CNN)