【pytorch机器学习 数据增强】

【python】数据增强

数据增强的作用：
1. 增加训练数据量，提高模型的泛化能力；
2. 引入噪声数据，提升模型的鲁棒性；
3. 在一定程度上缓解过拟合问题（样本过少时模型容易过拟合）；
4. 缓解样本不平衡问题：当某个类别的样本过少时，可以通过数据增强扩充该类别的样本数量。

本脚本在借鉴他人函数与脚本的基础上改编而成。
部分函数代码参考自：https://blog.csdn.net/weixin_43149427/article/details/95034118
使用方法（脚本以当前工作目录定位 mydata 文件夹，请在 mydata 所在的目录下运行）：
1. 用 PyCharm 打开并运行；
2. 保存为 .py 文件，作为脚本双击运行。
功能：
在 mydata 文件夹下创建 data_agu 文件夹，并在其中为 data 中的每个类别创建同名子文件夹；读取 mydata/data 下各子文件夹内的图片进行数据增强，结果保存到 data_agu 中对应的子文件夹。注意：每次运行都会先删除已有的 data_agu 文件夹再重新创建。
扩展：
脚本中已定义了多种增强函数（水平/竖直翻转、旋转、随机颜色、对比度/亮度/颜色增强），可参照下方 main 函数中的调用方式自行组合新的增强方法。

import os
from shutil import copy, rmtree
import random
from PIL import Image
from PIL import ImageEnhance

import cv2
import numpy as np

def mk_file(file_path: str):
    """Create an empty directory at *file_path*.

    If something already exists at that path it is removed recursively
    first, so any previous contents are lost. Missing parent directories
    are created as needed.
    """
    already_present = os.path.exists(file_path)
    if already_present:
        # Start from a clean slate: drop the old tree before recreating it.
        rmtree(file_path)
    os.makedirs(file_path)

def flip(root_path, img_name):
    """Return a horizontally mirrored (left-right) copy of an image.

    Args:
        root_path: directory that contains the image.
        img_name: file name of the image inside *root_path*.

    Returns:
        A new PIL image flipped left-to-right.
    """
    image_path = os.path.join(root_path, img_name)
    return Image.open(image_path).transpose(Image.Transpose.FLIP_LEFT_RIGHT)

def flip1(root_path, img_name):
    """Return a vertically mirrored (top-bottom) copy of an image.

    Args:
        root_path: directory that contains the image.
        img_name: file name of the image inside *root_path*.

    Returns:
        A new PIL image flipped top-to-bottom.
    """
    image_path = os.path.join(root_path, img_name)
    return Image.open(image_path).transpose(Image.Transpose.FLIP_TOP_BOTTOM)

def rotation20(root_path, img_name):
    """Return a copy of an image rotated 20 degrees counter-clockwise.

    ``Image.rotate`` is called without ``expand``, so the output keeps the
    original size: the rotated corners are clipped and the uncovered area
    gets PIL's default fill.

    Args:
        root_path: directory that contains the image.
        img_name: file name of the image inside *root_path*.
    """
    image_path = os.path.join(root_path, img_name)
    angle_degrees = 20
    return Image.open(image_path).rotate(angle_degrees)

def rotation_90(root_path, img_name):
    """Return a copy of an image rotated 90 degrees clockwise.

    The previous ``img.rotate(-90)`` kept the original W x H canvas, so on a
    non-square image it clipped two sides and padded the other two with
    empty fill. A transpose is an exact, lossless quarter turn that swaps
    width and height instead.

    Args:
        root_path: directory that contains the image.
        img_name: file name of the image inside *root_path*.

    Returns:
        A new PIL image rotated 90 degrees clockwise.
    """
    # The context manager closes the source file; transpose() returns a new,
    # fully loaded image, so the result stays valid after the file is closed.
    with Image.open(os.path.join(root_path, img_name)) as img:
        # ROTATE_270 (270 deg counter-clockwise) == 90 deg clockwise, which is
        # the same direction as the original rotate(-90).
        return img.transpose(Image.Transpose.ROTATE_270)

def randomColor(root_path, img_name):
    """Return a copy of an image with randomly jittered color properties.

    Four ImageEnhance adjustments are chained in this order. Each one uses
    a factor drawn from NumPy's global RNG, in steps of 0.1:

    * saturation (``ImageEnhance.Color``): factor in [0.0, 3.0]
    * brightness (``ImageEnhance.Brightness``): factor in [1.0, 2.0]
    * contrast (``ImageEnhance.Contrast``): factor in [1.0, 2.0]
    * sharpness (``ImageEnhance.Sharpness``): factor in [0.0, 3.0]

    Args:
        root_path: directory that contains the image.
        img_name: file name of the image inside *root_path*.
    """
    result = Image.open(os.path.join(root_path, img_name))
    # (enhancer, low, high): arguments to np.random.randint, whose upper
    # bound is exclusive; the drawn integer is then divided by 10.
    jitter_plan = (
        (ImageEnhance.Color, 0, 31),
        (ImageEnhance.Brightness, 10, 21),
        (ImageEnhance.Contrast, 10, 21),
        (ImageEnhance.Sharpness, 0, 31),
    )
    for enhancer_cls, low, high in jitter_plan:
        strength = np.random.randint(low, high) / 10.
        result = enhancer_cls(result).enhance(strength)
    return result


def contrastEnhancement(root_path, img_name, factor=1.5):
    """Return a copy of an image with its contrast adjusted.

    Args:
        root_path: directory that contains the image.
        img_name: file name of the image inside *root_path*.
        factor: ``ImageEnhance.Contrast`` factor. 1.0 returns a copy of the
            original; values above 1.0 increase contrast. Defaults to 1.5,
            the value that used to be hard-coded.

    Returns:
        A new PIL image with adjusted contrast.
    """
    # enhance() blends into a new image, so the result outlives the closed file.
    with Image.open(os.path.join(root_path, img_name)) as image:
        return ImageEnhance.Contrast(image).enhance(factor)

def brightnessEnhancement(root_path, img_name, factor=1.5):
    """Return a copy of an image with its brightness adjusted.

    Args:
        root_path: directory that contains the image.
        img_name: file name of the image inside *root_path*.
        factor: ``ImageEnhance.Brightness`` factor. 1.0 returns a copy of
            the original; values above 1.0 brighten the image. Defaults to
            1.5, the value that used to be hard-coded.

    Returns:
        A new PIL image with adjusted brightness.
    """
    # enhance() blends into a new image, so the result outlives the closed file.
    with Image.open(os.path.join(root_path, img_name)) as image:
        return ImageEnhance.Brightness(image).enhance(factor)

def colorEnhancement(root_path, img_name, factor=1.5):
    """Return a copy of an image with its color saturation adjusted.

    Args:
        root_path: directory that contains the image.
        img_name: file name of the image inside *root_path*.
        factor: ``ImageEnhance.Color`` factor. 1.0 returns a copy of the
            original; values above 1.0 increase saturation. Defaults to 1.5,
            the value that used to be hard-coded.

    Returns:
        A new PIL image with adjusted color saturation.
    """
    # enhance() blends into a new image, so the result outlives the closed file.
    with Image.open(os.path.join(root_path, img_name)) as image:
        return ImageEnhance.Color(image).enhance(factor)

# Lower-cased file extensions treated as images. Other entries in a class
# folder (e.g. Thumbs.db, desktop.ini) are skipped instead of crashing Image.open.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp", ".gif")


def _augment_image(src_dir, file_name, dst_dir):
    """Save the original plus four augmented variants of one image to *dst_dir*.

    Each output is written as ``<stem><suffix>.jpg``, where *stem* is
    *file_name* without its extension.
    """
    stem = os.path.splitext(file_name)[0]
    with Image.open(os.path.join(src_dir, file_name)) as original:
        # convert() produces a loaded copy, so it survives closing the file.
        original_rgb = original.convert("RGB")
    # JPEG cannot store alpha or palette modes, so every output is normalized
    # to RGB. Insertion order fixes both the save order and the call order.
    outputs = {
        "id": original_rgb,
        "be": contrastEnhancement(src_dir, file_name).convert("RGB"),
        "fl": flip(src_dir, file_name).convert("RGB"),
        "ro": rotation20(src_dir, file_name).convert("RGB"),
        "cr": randomColor(src_dir, file_name).convert("RGB"),
    }
    for suffix, image in outputs.items():
        image.save(os.path.join(dst_dir, stem + suffix + ".jpg"))


def main():
    """Augment every image under ./mydata/data into ./mydata/data_agu.

    Expects one sub-folder per class in ``mydata/data``, resolved relative to
    the current working directory. ``mydata/data_agu`` is deleted and
    recreated with one sub-folder per class. For each source image, five
    ``.jpg`` files are written there:

    * the original
    * a contrast-enhanced copy
    * a horizontally flipped copy
    * a copy rotated 20 degrees
    * a randomly color-jittered copy
    """
    # Seed NumPy's global RNG (used by randomColor) so runs are reproducible.
    np.random.seed(0)

    data_root = os.path.join(os.getcwd(), "mydata")
    origin_data_path = os.path.join(data_root, "data")
    assert os.path.exists(origin_data_path), "path '{}' does not exist.".format(origin_data_path)

    # Sorted so the processing order (and therefore the random factors each
    # image receives) does not depend on the filesystem's listing order.
    data_class = sorted(cla for cla in os.listdir(origin_data_path)
                        if os.path.isdir(os.path.join(origin_data_path, cla)))

    # Create the output root and one folder per class.
    save_root = os.path.join(data_root, "data_agu")
    mk_file(save_root)
    for cla in data_class:
        mk_file(os.path.join(save_root, cla))

    for cla in data_class:
        cla_path = os.path.join(origin_data_path, cla)
        save_dir = os.path.join(save_root, cla)
        images = sorted(
            name for name in os.listdir(cla_path)
            if os.path.isfile(os.path.join(cla_path, name))
            and os.path.splitext(name)[1].lower() in _IMAGE_EXTENSIONS
        )
        num = len(images)
        for index, name in enumerate(images):
            _augment_image(cla_path, name, save_dir)
            print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
        print()

    print("processing done!")


# Script entry point: run the augmentation only when executed directly,
# not when this module is imported.
if '__main__' == __name__:
    main()

你可能感兴趣的:(pytorch,机器学习,python)