【transformer】【pytorch】DeiT的数据增强

1 main中的相关参数

# Excerpt from the body of `def get_args_parser():`. `parser` is created earlier
# in that function (presumably an argparse.ArgumentParser -- not shown here).
parser.add_argument('--input-size', default=224, type=int, help='images input size')

# Color jitter strength.
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
# Auto-augment policy string; build_transform forwards it as `auto_augment`.
# (The original line ended with a stray trailing comma, which turned the call
# into a discarded 1-tuple expression; it has been removed.)
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)')
# Interpolation method used for training-time resizing.
parser.add_argument('--train-interpolation', type=str, default='bicubic',
                    help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
# Repeated augmentation: both flags write to `repeated_aug`; the default set
# below is True, so only --no-repeated-aug changes the value.
parser.add_argument('--repeated-aug', action='store_true')
parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
parser.set_defaults(repeated_aug=True)

# Random-erasing parameters. build_transform forwards reprob/remode/recount;
# resplit is not forwarded by build_transform in this file.
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                    help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
                    help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
                    help='Do not random erase first (clean) augmentation split')

2 datasets中的build_transform函数

这里没有直接返回timm库中create_transform的结果,而是先拿到返回的transform,以便在需要时对其中的变换进行修改(例如输入尺寸不大于32时,把第一个变换替换为RandomCrop)。

def build_transform(is_train, args):
    """Build the image transform pipeline for training or evaluation.

    :param is_train: truthy to build the training pipeline via create_transform,
        falsy to build the evaluation pipeline.
    :param args: namespace providing input_size, color_jitter, aa,
        train_interpolation, reprob, remode and recount.
    :return: the transform returned by create_transform (training) or a
        transforms.Compose of the evaluation steps.
    """
    # Inputs of 32px or less skip the resize / resized-crop steps.
    needs_resize = args.input_size > 32

    # Evaluation pipeline.
    if not is_train:
        eval_steps = []
        if needs_resize:
            # Keep the same resize-to-crop ratio as 224px images (256 -> 224).
            resized = int((256 / 224) * args.input_size)
            # interpolation=3 is the integer code PIL uses for bicubic.
            eval_steps.append(transforms.Resize(resized, interpolation=3))
            eval_steps.append(transforms.CenterCrop(args.input_size))
        # Tensor conversion and normalisation always come last.
        eval_steps.append(transforms.ToTensor())
        eval_steps.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
        return transforms.Compose(eval_steps)

    # Training pipeline: this should always dispatch to transforms_imagenet_train.
    train_transform = create_transform(
        input_size=args.input_size,
        is_training=True,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.train_interpolation,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
    )
    if not needs_resize:
        # The returned object exposes its steps as the list `.transforms`, so the
        # leading RandomResizedCropAndInterpolation can be swapped for a RandomCrop.
        train_transform.transforms[0] = transforms.RandomCrop(args.input_size, padding=4)
    return train_transform

3 create_transform函数

来源:transforms_factory.py(timm库)
该函数是一个工厂函数:训练时(is_training=True)会根据传入的参数调用transforms_imagenet_train来构建数据增强流水线。

# Excerpt of create_transform (timm, transforms_factory.py); the body is elided.
# Parameters marked with a trailing '#' below are the ones build_transform passes in.
def create_transform(
        input_size,# passed in: args.input_size
        is_training=False,# passed in: True
        use_prefetcher=False,
        no_aug=False,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,# passed in: args.color_jitter
        auto_augment=None,# passed in: args.aa (default 'rand-m9-mstd0.5-inc1')
        interpolation='bilinear',# passed in: args.train_interpolation (default 'bicubic')
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        re_prob=0.,# passed in: args.reprob
        re_mode='const',# passed in: args.remode
        re_count=1,# passed in: args.recount
        re_num_splits=0,
        crop_pct=None,
        tf_preprocessing=False,
        separate=False):

 		 ...
  # code paths that are not used here have been omitted
      # NOTE(review): img_size is not defined in this excerpt; presumably it is
      # derived from input_size in the elided code -- confirm against timm.
      transform = transforms_imagenet_train(# builds the ImageNet-style training augmentation pipeline
          img_size,
          scale=scale,
          ratio=ratio,
          hflip=hflip,
          vflip=vflip,
          color_jitter=color_jitter,
          auto_augment=auto_augment,
          interpolation=interpolation,
          use_prefetcher=use_prefetcher,
          mean=mean,
          std=std,
          re_prob=re_prob,
          re_mode=re_mode,
          re_count=re_count,
          re_num_splits=re_num_splits,
          separate=separate)

4 transforms_imagenet_train函数

来源:同上

1)transform返回列表:
* 通过separate参数选择返回三个独立的Compose(primary、secondary、final),还是合并成一个Compose;
* 第一个变换是RandomResizedCropAndInterpolation,也就是前面build_transform中通过transform.transforms[0]替换掉的那个变换。
2)primary_tfl:
* RandomResizedCropAndInterpolation
* RandomHorizontalFlip(可选)
* RandomVerticalFlip(可选)
3)secondary_tfl:
* rand_augment_transform
* ColorJitter(与上一项二选一:只有在没有设置auto_augment且color_jitter不为None时才会使用)
4)final_tfl:
* ToTensor
* Normalize
* RandomErasing

5)aa_params:
* translate_const
* img_mean
* interpolation(仅当interpolation不为'random'时才会加入)

def transforms_imagenet_train(
        img_size=224,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        interpolation='random',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_num_splits=0,
        separate=False,
):
    """
    Assemble the training-time augmentation pipeline in three stages.

    * primary: random resized crop, plus horizontal / vertical flips when their
      probability is greater than zero
    * secondary: an auto-augment policy when `auto_augment` is set, otherwise
      ColorJitter when `color_jitter` is not None
    * final: ToNumpy when `use_prefetcher` is set, otherwise ToTensor + Normalize,
      followed by RandomErasing when `re_prob` is greater than zero

    If separate==True, the three stages are returned as a tuple of 3 separate
    transforms for use in a mixing dataset that passes
     * all data through the first (primary) transform, called the 'clean' data
     * a portion of the data through the secondary transform
     * normalizes and converts the branches above with the third, final transform
    Otherwise a single Compose of all three stages, in that order, is returned.
    """
    scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
    ratio = tuple(ratio or (3./4., 4./3.))  # default imagenet ratio range

    # --- primary stage -----------------------------------------------------
    primary_ops = [
        RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
    if hflip > 0.:
        primary_ops.append(transforms.RandomHorizontalFlip(p=hflip))
    if vflip > 0.:
        primary_ops.append(transforms.RandomVerticalFlip(p=vflip))

    # --- secondary stage ---------------------------------------------------
    secondary_ops = []
    if auto_augment:
        assert isinstance(auto_augment, str)
        # A tuple img_size contributes its smaller side.
        shortest_side = min(img_size) if isinstance(img_size, tuple) else img_size
        aa_params = {
            'translate_const': int(shortest_side * 0.45),
            'img_mean': tuple([min(255, round(255 * channel)) for channel in mean]),
        }
        # 'random' (or an empty value) leaves the interpolation choice to the policy.
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = _pil_interp(interpolation)
        # The prefix of the policy string selects which factory builds the op.
        if auto_augment.startswith('rand'):
            policy_op = rand_augment_transform(auto_augment, aa_params)
        elif auto_augment.startswith('augmix'):
            aa_params['translate_pct'] = 0.3
            policy_op = augment_and_mix_transform(auto_augment, aa_params)
        else:
            policy_op = auto_augment_transform(auto_augment, aa_params)
        secondary_ops.append(policy_op)
    elif color_jitter is not None:
        # color jitter is enabled when not using AA
        if isinstance(color_jitter, (list, tuple)):
            # 3 values for brightness/contrast/saturation, or 4 to also jitter hue
            assert len(color_jitter) in (3, 4)
        else:
            # a scalar is duplicated for brightness, contrast and saturation (no hue)
            color_jitter = (float(color_jitter),) * 3
        secondary_ops.append(transforms.ColorJitter(*color_jitter))

    # --- final stage -------------------------------------------------------
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        final_ops = [ToNumpy()]
    else:
        final_ops = [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std)),
        ]
        if re_prob > 0.:
            final_ops.append(
                RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))

    if not separate:
        return transforms.Compose(primary_ops + secondary_ops + final_ops)
    return transforms.Compose(primary_ops), transforms.Compose(secondary_ops), transforms.Compose(final_ops)

5 rand_augment_transform函数

调用:secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
auto_augment:rand-m9-mstd0.5-inc1
aa_params: translate_const, img_mean(由于传入的interpolation是'bicubic'而不是'random',还会包含interpolation)

def rand_augment_transform(config_str, hparams):
    """
    Create a RandAugment transform

    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
    dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
    sections, not order specific, determine
        'm' - integer magnitude of rand augment
        'n' - integer num layers (number of transform ops selected per image)
        'w' - integer probability weight index (index of a set of weights to influence choice of op)
        'mstd' -  float std deviation of magnitude noise applied
        'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
    Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
    'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2

    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme. This dict is updated in place:
    an 'mstd' section adds 'magnitude_std' via setdefault, so a value already present is kept.

    :return: A PyTorch compatible Transform

    Raises AssertionError if the first section is not 'rand' or a section key is unknown.
    Sections that contain no digit are skipped silently.
    """
    # Example call in this file: config_str='rand-m9-mstd0.5-inc1',
    # hparams carrying translate_const and img_mean.
    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    weight_idx = None  # default to no probability weights for op choice
    # Named transform_names (not `transforms`) so the local does not shadow the
    # torchvision `transforms` module name used elsewhere in the file.
    transform_names = _RAND_TRANSFORMS
    config = config_str.split('-')  # e.g. ['rand', 'm9', 'mstd0.5', 'inc1']
    assert config[0] == 'rand'
    for c in config[1:]:  # e.g. 'm9', 'mstd0.5', 'inc1'
        # Split at the first digit: 'mstd0.5' -> ['mstd', '0.5', '']
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'inc':
            # val is a string, and any non-empty string (including '0') is truthy,
            # so it must be converted to a number before being used as a flag.
            if int(val):
                transform_names = _RAND_INCREASING_TRANSFORMS
        elif key == 'm':
            magnitude = int(val)
        elif key == 'n':
            num_layers = int(val)
        elif key == 'w':
            weight_idx = int(val)
        else:
            assert False, 'Unknown RandAugment config section'
    # Build the candidate ops from the selected transform names.
    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transform_names)
    # No 'w' section means no choice weights; otherwise look up the weight set by index.
    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
    # RandAugment draws num_layers ops per image from ra_ops, optionally weighted.
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)


# Op names selected by rand_augment_transform when the 'inc' config section is
# enabled (augmentations that increase in severity with magnitude).
_RAND_INCREASING_TRANSFORMS = [
    'AutoContrast',
    'Equalize',
    'Invert',
    'Rotate',
    'PosterizeIncreasing',
    'SolarizeIncreasing',
    'SolarizeAdd',
    'ColorIncreasing',
    'ContrastIncreasing',
    'BrightnessIncreasing',
    'SharpnessIncreasing',
    'ShearX',
    'ShearY',
    'TranslateXRel',
    'TranslateYRel',
    # 'Cutout' is intentionally absent: it is implemented separately as random erasing.
]

# Experimental per-op selection weights, loosely based on the relative
# improvements mentioned in the paper. They may not improve performance as
# given, but could likely be tuned to do so.
_RAND_CHOICE_WEIGHTS_0 = {
    # geometric ops carry most of the weight
    'Rotate': 0.3,
    'ShearX': 0.2,
    'ShearY': 0.2,
    'TranslateXRel': 0.1,
    'TranslateYRel': 0.1,
    # moderately weighted ops
    'Color': 0.025,
    'Sharpness': 0.025,
    'AutoContrast': 0.025,
    # rarely chosen ops
    'Solarize': 0.005,
    'SolarizeAdd': 0.005,
    'Contrast': 0.005,
    'Brightness': 0.005,
    'Equalize': 0.005,
    # weight 0
    'Posterize': 0,
    'Invert': 0,
}

                                                                                                                                   

6 RandAugment类

class RandAugment:
    """Apply a random selection of augmentation ops to an image.

    ops: sequence of callables, each taking an image and returning an image.
    num_layers: how many ops are drawn for each call.
    choice_weights: optional per-op probabilities passed to the sampler;
        None means no probabilities are supplied.
    """

    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        """Draw num_layers ops and apply them to `img` one after another."""
        # Sample with replacement only when no weights are given;
        # weighted choice draws without replacement.
        unweighted = self.choice_weights is None
        selected = np.random.choice(
            self.ops,
            self.num_layers,
            replace=unweighted,
            p=self.choice_weights,
        )
        result = img
        for apply_op in selected:
            result = apply_op(result)
        return result

你可能感兴趣的:(transformer_CV)