数据处理三 数据类别平衡(难易样本平衡)

在使用paddlseg进行训练时常常会出现针对于某一类数据预测效果不佳,这些样本通常被称为难样本。在设计loss时可以使用FocalLoss与BootstrappedCrossEntropyLoss作用损失函数进行训练,也可以自行对难样本进行过采样然后再进行训练在训练过程中难样本通常是频率较低的样本,故此可以改变数据的频率尝试将难样本变为普通样本。增强了难样本的频率,则在一个epoch中增加了难样本的迭代次数,使模型在训练过程中学习到更多的难样本特征。

例如,原始数据1000个,难样本100个,易样本900,难易样本比为1:9,每一训练一个epoch,易样本比难样本多学习了8遍,故而使得模型对易样本预测效果更好。当我们将难样本扩展到900个,此时数据总数为1800个(原始总量1000个,有800个重复的难样本数据),难易样本比为1:1,在训练过程中模型则没有了学习的偏好,则可以保证对难样本的预测效果。

1、数据频率统计

按照像素面积将数据分为大中小三类。(我们的数据只有一个连通域,其中连通域面积较大数据数据预测效果较好,连通域面积小的数据预测效果差)。大面积类为易样本,小面积类为难样本,中面积类为普通样本。
通过以下代码对钢卷数据标签进行分类,得出其频率信息。并将不同类别的数据分别拷贝到特定的目录下,以供观察分析(主要是看面积阈值设置的是否合理)。

import cv2,os
import shutil
from PIL import Image
import numpy as np

path=r"D:\实战项目\钢卷塌陷检测\All_train_data\txqy_data\Annotations"
path_list=os.listdir(path)
if not os.path.exists(path+'/small'):
    os.makedirs(path+'/small')
    os.makedirs(path+'/medium')
    os.makedirs(path+'/big')
path_list2=[x for x in path_list if "png" in x]
pl=[0,0,0]
for p in path_list2:
    print(path+'/'+p)
    #img=cv2.imread(path+'/'+p,1) #不支持中文路径
    img=np.array(Image.open(path+'/'+p))
    number=img.sum()

    if number<1200:
        print( "图片 %s 目标像素点个数:%s 属于小面积数据" % (p, number))
        shutil.copy(path+'/'+p,path+'/small')
        #shutil.copy(path+'/'+p.replace('.bmp','.jpg'),'1121/small')
        pl[0]+=1
         
    elif number>8000:
        print( "图片 %s 目标像素点个数:%s 属于大面积数据" % (p, number))
        shutil.copy(path+'/'+p,path+'/big')
        #shutil.copy(path+'/'+p.replace('.bmp','.jpg'),'1121/big')
        pl[2]+=1
        
    else:
        print( "图片 %s 目标像素点个数:%s 属于中等面积数据" % (p, number))
        shutil.copy(path+'/'+p,path+'/medium')
        #shutil.copy(path+'/'+p.replace('.bmp','.jpg'),'1121/medium')
        pl[1]+=1
        
print(pl)

通过下图输出,可以看到小面积数据有89个(难样本),中面积数据385个(一般样本)、大面积数据72个(易样本)

数据处理三 数据类别平衡(难易样本平衡)_第1张图片
代码执行完毕后,会在png图片路径生成big、medium和small三个目录,里面有划分好类别的数据。
数据处理三 数据类别平衡(难易样本平衡)_第2张图片

2、数据频率平衡

以下代码根据频率,设置了变量copy_times=[5,0,0],即将类别0(小面积数据类)对应的png图(标签图)与jpg图(原始图)复制5倍,类别1不进行复制,类别2也不进行复制。
因为在本项目中小面积数据较少,且是预测的难点;而大面积数据预测效果较好,虽然其数量小,但不需要扩展扩展数据主要是使数据频率达到平衡状态,使在训练过程中的小面积数据与中面积数据相等,如果数据平衡后对于小面积数据预测效果还是不好,可以再次进行增强。

path=r"D:\实战项目\钢卷塌陷检测\All_train_data\txqy_data\Annotations"
jpath=path.replace("Annotations","JPEGImages")
path_list=os.listdir(path)
path_list2=[x for x in path_list if "png" in x]

copy_times=[5,0,0]
def copy_by_ferquence(copy_times):
    if not os.path.exists(path+'/copy'):
        os.makedirs(path+'/copy')
    for p in path_list2:
        print(path+'/'+p)
        jpgname=p.replace('.png','.jpg')
        #img=cv2.imread(path+'/'+p,1) #不支持中文路径
        img=np.array(Image.open(path+'/'+p))
        number=img.sum()
        if number<1200:#小面积数据类
            for i in range(copy_times[0]):
                print(path+'/'+p,path+'/%s_'%i+p)
                shutil.copy(path+'/'+p,path+'/copy/%s_'%i+p)
                shutil.copy(jpath+'/'+jpgname,path+'/copy/%s_'%i+jpgname)
        elif number>8000:#大面积数据类
            for i in range(copy_times[2]):
                shutil.copy(path+'/'+p,path+'/copy/%s_'%i+p)
                shutil.copy(jpath+'/'+jpgname,path+'/copy/%s_'%i+jpgname)
        else:#中等面积数据类
            for i in range(copy_times[1]):
                shutil.copy(path+'/'+p,path+'/copy/%s_'%i+p)
                shutil.copy(jpath+'/'+jpgname,path+'/copy/%s_'%i+jpgname)
#pl=[89, 395, 72]
copy_by_ferquence(copy_times)

通过以上代码增加的数据存储在copy目录下,具体如下所示。自行将生成的jpg文件与png文件拷贝到相应目录即可。
数据处理三 数据类别平衡(难易样本平衡)_第3张图片

3、完整代码

完整代码如下所,要执行两遍,第一遍先执行get_ferquence函数,获取各个类别的频率信息,然后根据频率信息人为设定各个类别的增强次数,即变量 copy_times。最终将生成好的数据再次拷贝到paddleseg的数据目录下,重新生成txt列表进行训练。

import cv2,os
import shutil
from PIL import Image
import numpy as np

path=r"D:\实战项目\钢卷塌陷检测\All_train_data\txqy_data\Annotations"
jpath=path.replace("Annotations","JPEGImages")
path_list=os.listdir(path)
path_list2=[x for x in path_list if "png" in x]

#获取数据的频率信息
def get_ferquence():
    pl=[0,0,0]
    if not os.path.exists(path+'/small'):
        os.makedirs(path+'/small')
        os.makedirs(path+'/medium')
        os.makedirs(path+'/big')
    for p in path_list2:
        print(path+'/'+p)
        #img=cv2.imread(path+'/'+p,1) #不支持中文路径
        img=np.array(Image.open(path+'/'+p))
        number=img.sum()
        if number<1200:
            print( "图片 %s 目标像素点个数:%s 属于小面积数据" % (p, number))
            shutil.copy(path+'/'+p,path+'/small')
            #shutil.copy(path+'/'+p.replace('.bmp','.jpg'),'1121/small')
            pl[0]+=1
        elif number>8000:
            print( "图片 %s 目标像素点个数:%s 属于大面积数据" % (p, number))
            shutil.copy(path+'/'+p,path+'/big')
            #shutil.copy(path+'/'+p.replace('.bmp','.jpg'),'1121/big')
            pl[2]+=1
            
        else:
            print( "图片 %s 目标像素点个数:%s 属于中等面积数据" % (p, number))
            shutil.copy(path+'/'+p,path+'/medium')
            #shutil.copy(path+'/'+p.replace('.bmp','.jpg'),'1121/medium')
            pl[1]+=1
    return pl
#pl=get_ferquence()

copy_times=[5,0,0]
def copy_by_ferquence(copy_times):
    if not os.path.exists(path+'/copy'):
        os.makedirs(path+'/copy')
    for p in path_list2:
        print(path+'/'+p)
        jpgname=p.replace('.png','.jpg')
        #img=cv2.imread(path+'/'+p,1) #不支持中文路径
        img=np.array(Image.open(path+'/'+p))
        number=img.sum()
        if number<1200:
            for i in range(copy_times[0]):
                print(path+'/'+p,path+'/%s_'%i+p)
                shutil.copy(path+'/'+p,path+'/copy/%s_'%i+p)
                shutil.copy(jpath+'/'+jpgname,path+'/copy/%s_'%i+jpgname)
            #shutil.copy(path+'/'+p.replace('.bmp','.jpg'),'1121/small')
        elif number>8000:
            for i in range(copy_times[2]):
                shutil.copy(path+'/'+p,path+'/copy/%s_'%i+p)
                shutil.copy(jpath+'/'+jpgname,path+'/copy/%s_'%i+jpgname)
            #shutil.copy(path+'/'+p.replace('.bmp','.jpg'),'1121/big')
        else:
            for i in range(copy_times[1]):
                shutil.copy(path+'/'+p,path+'/copy/%s_'%i+p)
                shutil.copy(jpath+'/'+jpgname,path+'/copy/%s_'%i+jpgname)
            #shutil.copy(path+'/'+p.replace('.bmp','.jpg'),'1121/medium')
#pl=[89, 395, 72]
copy_by_ferquence(copy_times)


你可能感兴趣的:(数据处理,python,paddle,数据分析)