caffe数据层数据增广

caffe数据层相关学习以及训练在线数据增广

caffe数据层是将已经生成好的LMDB文件中的label和数据读入到Datum数据结构体中,然后将数据转化到Blob中,进而进行数据传递,才能进行数据训练。目前使用的data_layer是经过了别人的改版,能够读入多个LMDB,并且在source_filelists中除了写入图像数据,还会加入每个数据样本所对应的key,从而实现Key—Value一一对应的结果。

本次希望实现的主要功能是希望在训练过程中,能够实时的随机改变图像数据的饱和度、亮度和对比度,进而能够达到数据增广的目的,增加数据的多样性,使得训练得到的模型的泛化性更好。而如果要达到这样的目的,就需要对data_transformer.cpp进行相应的代码更改。除此以外,因为希望在训练过程中,只对负样本进行数据增广,而不改变正样本。所以还需要对data_layer层进行相应更改。因为最初生成的key中已经记录了样本的正负属性,所以可以在对key进行相应处理知道属性是“正”or“负”。

提前在数据结构体中增加了bool is_augmentated变量。如果在data_layer中判断样本为负,则将is_augmentated值设为true,则在data_transformer.cpp中进行相应的数据增广。

在data_transformer.cpp和data_transformer.hpp中首先添加调整图像饱和度、亮度和对比度的函数。几个函数都传入的是Blob数据体的数据头,因为已经Blob的地址都是连续的,这样方便直接进行地址操作。但是要注意,和cv::Mat的数据存储格式HWC有所区别,Blob的存储格式是CHW,因此在使用指针进行地址操作的时候,需要进行相应的转变。

饱和度调整函数代码如下:

[cpp]  view plain  copy
  1. template <typename Dtype> \\注意次数需要使用对应的模板  
  2.     void DataTransformer::Saturation( \\饱和度调整  
  3.         Dtype* img, \\ 第一个坑,尽管blob是float型的,但是因为使用了模板,所以需要将所有和blob挂钩的变量类型都要更改为Dtype,这是一个模板类型的变量。  
  4.         const int img_size,  
  5.         const float alpha_rand  \\ 增广参数的范围,该参数是在网络中设置的,所以需要相应的更改caffe.proto文件,在TransformationParameter中增加相关参数   ) {  
  6.         float alpha = 0.0f;  
  7.         Dtype tempPixel = 0.0f;  
  8.         GenerateRandomFloat_range(alpha_rand, &alpha); \\根据caffe自带的随机数生成器Rand(),得到-alpha_rand~alpha_rand的随机实数  
  9.         // BGR to Gray scale image: R -> 0.299, G -> 0.587, B -> 0.114  
  10.         for (int h = 0; h < img_size; ++h) {  
  11.             for (int w = 0; w < img_size; ++w) {  
  12.                 Dtype gray_color = img[0 * img_size*img_size + h*img_size + w] * 0.114f + img[1 * img_size*img_size + h*img_size + w] * 0.587f +  
  13.                     img[2 * img_size*img_size + h*img_size + w] * 0.299f;   
  14.                  // 因为是已经是直接对Blob进行地址操作,所以可以直接使用指针头进行相应的位置移  
  15.                 for (int c = 0; c < 3; ++c) { t  
  16.                         empPixel = img[c * img_size*img_size + h*img_size + w] * alpha + gray_color * (1.0f - alpha);img[c * img_size*img_size + h*img_size + w] = CheckAugmentatedValue(&tempPixel);   
  17.                  //CheckAugmentatedValue()是特意增加的边界检测,防止修改后的数值超过[0,255]的图像范围  
  18.                 }  
  19.             }  
  20.         }  
  21.     }  
调整亮度和对比度的代码和饱和度的类似,不再增加相关注释

[cpp]  view plain  copy
  1. // assume HWC order and color channels BGR  
  2.     template <typename Dtype>  
  3.     void DataTransformer::Brightness(  
  4.         Dtype* img,  
  5.         const int img_size,  
  6.         const float alpha_rand  
  7.         ) {  
  8.         float alpha = 0.0f;  
  9.         Dtype tempPixel = 0.0f;  
  10.         GenerateRandomFloat_range(alpha_rand, &alpha);  
  11.         LOG(INFO) << "Brightness Alpha = " << alpha;  
  12.         int p = 0;  
  13.         for (int h = 0; h < img_size; ++h) {  
  14.             for (int w = 0; w < img_size; ++w) {  
  15.                 for (int c = 0; c < 3; ++c) {  
  16.                     tempPixel = img[p] * alpha;  
  17.                     img[p] = CheckAugmentatedValue(&tempPixel);  
  18.                     p++;  
  19.                 }  
  20.             }  
  21.         }  
  22.     }  



[cpp]  view plain  copy
  1. template <typename Dtype>  
  2.     void DataTransformer::Contrast(  
  3.         Dtype* img,  
  4.         const int img_size,  
  5.         const float alpha_rand  
  6.         ) {  
  7.         float gray_mean = 0;  
  8.         Dtype tempPixel = 0;  
  9.         for (int h = 0; h < img_size; ++h) {  
  10.             for (int w = 0; w < img_size; ++w) {  
  11.                 // BGR to Gray scale image: R -> 0.299, G -> 0.587, B -> 0.114  
  12.                 gray_mean += img[0 * img_size*img_size + h*img_size + w] * 0.114f + img[1 * img_size*img_size + h*img_size + w] * 0.587f +  
  13.                     img[2 * img_size*img_size + h*img_size + w] * 0.299f;  
  14.             }  
  15.         }  
  16.         gray_mean /= (img_size * img_size);  
  17.   
  18.         float alpha = 0.0f;  
  19.         GenerateRandomFloat_range(alpha_rand, &alpha);  
  20.         LOG(INFO) << "Contrast Alpha = " << alpha;  
  21.         int p = 0;  
  22.         for (int h = 0; h < img_size; ++h) {  
  23.             for (int w = 0; w < img_size; ++w) {  
  24.                 for (int c = 0; c < 3; ++c) {  
  25.                     tempPixel = img[p] * alpha + gray_mean * (1.0f - alpha);  
  26.                     img[p] = CheckAugmentatedValue(&tempPixel);  
  27.                     p++;  
  28.                 }  
  29.             }  
  30.         }  
  31.     }  

你可能感兴趣的:(caffe数据层数据增广)