[pytorch]FixMatch代码详解(超详细)二

无标签数据数据增强方式 randaugment.py

目录

无标签数据数据增强方式 randaugment.py


根据索引返回对应的img和target,用transform参数控制强弱变。

class TransformFixMatch(object):
    def __init__(self, mean, std):
        self.weak = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32,
                                  padding=int(32*0.125),
                                  padding_mode='reflect'])
        self.strong = trandforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32,
                                  padding=int(32*0.125),       
                                  padding_mode='reflect'),
            RandAugmentMC(n=2, m=10)])
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])

    def __call__(self, x):
        weak = self.weak(x)
        strong = self.strong(x)
        return self.normalize(weak), self.normalize(strong)

对无标记样本进行扩增(Augment),扩增分为强扩增和弱扩增,弱扩增使用标准的水平翻转和裁剪;强扩增使用RandAugment和CTAugment两种算法(这里用了RandAugment)。

这段代码定义了一个名为TransformFixMatch的类,它是一个数据转换器,主要用于数据增强,以在训练神经网络时提高性能。这个类有一个构造函数 __init__ ,接受两个参数 mean std,分别表示图像的均值和标准差。在类初始化时,创建了两个数据增强的变换,分别为 weak strong

  • weak 变换使用了随机水平翻转和随机裁剪,其中随机裁剪使用了反射填充(padding_mode='reflect')来避免黑边出现。这个变换用于生成弱变换的数据,即相对于原始图像来说变化比较小的图像。
  • strong 变换同样使用了随机水平翻转和随机裁剪,但它还使用了一个称为RandAugmentMC的数据增强技术,这个技术会随机地应用一些变换,例如旋转、裁剪、缩放等,来进一步增强数据。这个变换用于生成强变换的数据,即相对于原始图像来说变化比较大的图像。

类还有一个方法 __call__ ,接受一个输入图像 x,将其传递给 weak strong 变换,生成相应的弱变换和强变换。然后,这两个变换会被传递给 normalize 变换,它会将图像转换为张量,并将其标准化为均值为 mean,标准差为 std 的值,最后返回这两个标准化后的张量,即弱变换和强变换。

打开dataset中的randaugment.py文件

logger = logging.getLogger(__name__)
PARAMETER_MAX = 10

这段代码定义了一个名为loggerlogging对象,用于记录程序运行时的信息,可以通过该对象将信息打印到控制台或者写入日志文件中。

logging.getLogger(__name__) 返回一个与当前模块同名的logging对象,即logger对象,方便在其他模块中引用此对象。

PARAMETER_MAX 是一个常量,表示模型参数的最大值,一般用于进行模型参数的初始化或者限制参数范围等。这个常量的值为10,可以根据具体问题进行调整。

def AutoContrast(img, **kwarg):
    return PIL.ImageOps.autocontrast(img)

这段代码定义了一个名为AutoContrast的函数,它使用PIL库中的ImageOps.autocontrast函数来对输入的图像进行自动对比度增强,并将增强后的图像作为函数的返回值。

这个函数接受两个参数:img 表示输入的图像,可以是PIL图像对象或者Numpy数组等图像数据类型;**kwarg 表示可变关键字参数,允许传递额外的参数给 ImageOps.autocontrast 函数,以调整增强的程度等。

在函数内部,它直接调用了PIL.ImageOps.autocontrast函数,并将输入的图像作为参数传递给它。该函数会自动计算图像的直方图,并将图像的像素值重新映射到 0-255 的范围内,以增强图像的对比度。最后,函数将增强后的图像作为返回值返回。

def Brightness(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Brightness(img).enhance(v)

这段代码定义了一个名为Brightness的函数,用于对输入的图像进行亮度增强,并返回增强后的图像。该函数使用PIL库中的ImageEnhance.Brightness类实现亮度增强,其增强程度由参数 v 控制。

这个函数接受四个参数:img 表示输入的图像,可以是PIL图像对象或者Numpy数组等图像数据类型;v 表示亮度增强的程度,它是一个介于01之间的浮点数,实际的增强程度由 _float_parameter 函数生成max_v 表示亮度增强程度的上限,它是一个浮点数,用于控制亮度增强的幅度,如果 v 大于 max_v,则取 max_v 作为增强程度;bias 表示偏置量,它是一个浮点数,用于调整亮度增强的基准值。

在函数内部,首先使用 _float_parameter 函数生成亮度增强程度 v,然后将其与偏置量 bias 相加得到最终的亮度增强程度。接着,使用 PIL.ImageEnhance.Brightness 类创建一个增强器对象,并将输入的图像作为参数传递给它。最后,调用增强器对象的 enhance 方法,传递亮度增强程度作为参数,并将增强后的图像作为返回值返回。

def Color(img, v,max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Color(img).enhance(v)

这段代码定义了一个名为Color的函数,用于对输入的图像进行颜色增强,并返回增强后的图像。该函数使用PIL库中的ImageEnhance.Color类实现颜色增强,其增强程度由参数 v 控制。

这个函数接受四个参数:img 表示输入的图像,可以是PIL图像对象或者Numpy数组等图像数据类型;v 表示颜色增强的程度,它是一个介于01之间的浮点数,实际的增强程度由 _float_parameter 函数生成;max_v 表示颜色增强程度的上限,它是一个浮点数,用于控制颜色增强的幅度,如果 v 大于 max_v,则取 max_v 作为增强程度;bias 表示偏置量,它是一个浮点数,用于调整颜色增强的基准值。

在函数内部,首先使用 _float_parameter 函数生成颜色增强程度 v,然后将其与偏置量 bias 相加得到最终的颜色增强程度。接着,使用 PIL.ImageEnhance.Color 类创建一个增强器对象,并将输入的图像作为参数传递给它。最后,调用增强器对象的 enhance 方法,传递颜色增强程度作为参数,并将增强后的图像作为返回值返回。

def Contrast(img, v, max_v, bias = 0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Contrast(img).enhance(v)

这段代码定义了一个名为Contrast的函数,用于对输入的图像进行对比度增强,并返回增强后的图像。该函数使用PIL库中的ImageEnhance.Contrast类实现对比度增强,其增强程度由参数 v 控制。

这个函数接受四个参数:img 表示输入的图像,可以是PIL图像对象或者Numpy数组等图像数据类型;v 表示对比度增强的程度,它是一个介于01之间的浮点数,实际的增强程度由 _float_parameter 函数生成;max_v 表示对比度增强程度的上限,它是一个浮点数,用于控制对比度增强的幅度,如果 v 大于 max_v,则取 max_v 作为增强程度;bias 表示偏置量,它是一个浮点数,用于调整对比度增强的基准值。

在函数内部,首先使用 _float_parameter 函数生成对比度增强程度 v,然后将其与偏置量 bias 相加得到最终的对比度增强程度。接着,使用 PIL.ImageEnhance.Contrast 类创建一个增强器对象,并将输入的图像作为参数传递给它。最后,调用增强器对象的 enhance 方法,传递对比度增强程度作为参数,并将增强后的图像作为返回值返回。

def Cutout(img, v, max_v, bias = 0):
    if v == 0:
        return img
    v = _float_parameter(v, max_v) + bias
    v = int(v * min(img.size))
    return CutoutAbs(img, v)

这段代码定义了一个名为Cutout的函数,用于对输入的图像进行随机擦除,并返回擦除后的图像。该函数会首先计算擦除区域的大小,然后在图像中随机选取一个位置,并将选中位置周围的像素点替换成固定的颜色值(例如黑色或者0值),从而实现擦除操作。

这个函数接受四个参数:img 表示输入的图像,可以是PIL图像对象或者Numpy数组等图像数据类型;v 表示擦除区域大小的比例,它是一个介于01之间的浮点数,实际的擦除区域大小由 _float_parameter 函数生成max_v 表示擦除区域大小比例的上限,它是一个浮点数,用于控制擦除区域的大小比例,如果 v 大于 max_v,则取 max_v 作为擦除区域大小比例;bias 表示偏置量,它是一个浮点数,用于调整擦除区域大小比例的基准值。

在函数内部,首先判断参数 v 是否等于0,如果是则直接返回输入的图像。否则,使用 _float_parameter 函数生成擦除区域大小比例 v,然后将其与偏置量 bias 相加得到最终的擦除区域大小比例。接着,计算擦除区域的大小,这里使用了 min(img.size) 函数,表示选取图像宽度和高度中的最小值作为擦除区域的大小,从而保证擦除区域不会超出图像边界。最后,调用 CutoutAbs 函数对输入图像进行随机擦除操作,并将擦除后的图像作为返回值返回。

def CutoutAbs(img, v, **kwarg):
    w, h = img.size
    x0 = np.random.uniform(0, w)
    y0 = np.random.uniform(0, h)
    x0 = int(max(0, x0 - v / 2.))
    y0 = int(max(0, y0 - v / 2.))
    x1 = int(min(w, x0 + v))
    y1 = int(min(h, y0 + v))
    xy = (x0, y0, x1, y1)
    #grey
    color = (127, 127, 127)
    img = img.copy()
    PIL.ImageDraw.Draw(img).rectangle(xy, color)
    return img

这段代码定义了一个名为 CutoutAbs 的函数,用于对输入的图像进行绝对位置的随机擦除,并返回擦除后的图像。该函数会首先随机选取一个擦除区域的左上角坐标(x0,y0),然后根据输入的擦除大小,计算出擦除区域的右下角坐标(x1,y1),并将该区域内的像素值替换成固定的颜色值(例如黑色或者0值),从而实现擦除操作。

这个函数接受三个参数:img 表示输入的图像,可以是PIL图像对象或者Numpy数组等图像数据类型;v 表示擦除区域的大小,它是一个整数,表示擦除区域的边长或者宽高等绝对尺寸;**kwarg 表示可选的关键字参数,这里没有使用到。

在函数内部,首先获取输入图像的宽度 w 和高度 h,然后使用 np.random.uniform 函数在图像中随机选取一个擦除区域的左上角坐标 x0 y0。接着,根据输入的擦除大小 v,计算出擦除区域的右下角坐标 x1 y1,并使用 max min 函数对坐标值进行调整,以确保擦除区域不会超出图像边界。接下来,构造一个由左上角坐标和右下角坐标组成的元组 xy,表示擦除区域的位置和大小。然后,使用固定的颜色值(这里使用 (127, 127, 127) 表示灰色)对擦除区域内的像素进行替换,并返回擦除后的图像。注意,在对输入图像进行擦除操作时,需要使用 img.copy() 复制一份原始图像,以避免修改原始图像数据。

def Equalize(img, **kwarg):
    return PIL.ImageOps.equalize(img)

Equalize 函数是一种图像变换,使用 PIL.ImageOps.equalize 函数对输入图像执行直方图均衡化操作。直方图均衡化是一种常用的图像增强方法,通过增加图像对比度来使得图像更加清晰。该函数会调整图像中每个像素的亮度值,使得输出图像中每个像素值的出现概率更加均衡,从而提高图像的质量。

def Identity(img, **kwarg):
    return img

Identity 函数是一种图像变换,其作用是将输入图像直接返回,即不做任何变换。这种函数通常用于在数据增强过程中,为了使得数据增强过程更加灵活,增加一些随机性。当随机变换的概率为零时,就使用 Identity 函数来保持原始图像不变。

def Invert(img, **kwarg):
    return PIL.ImageOps.invert(img)

Invert 函数是一种图像变换,使用 PIL.ImageOps.invert 函数对输入图像进行反转(即图像中每个像素的值取反)。该变换通常用于增强图像的对比度和细节,使得原本暗淡的部分更加突出,从而更加清晰地呈现图像的特征。

def Posterize(img, v, max_v, bias = 0):
    v = _int_parameter_(v, max_v) + bias
    return PIL.ImageOps.posterize(img, v)  

Posterize 函数是一种图像变换,使用 PIL.ImageOps.posterize 函数对输入图像进行色调分离操作,将图像中每个像素的色调分为 $2^v$ 个等级(其中 $v$ 是一个整数参数)。该变换可以使得图像变得更加简单,同时可以降低图像中的噪声和细节,从而使得图像更加清晰和易于处理。

def Rotate(img, v, max_v, bias = 0):
    v = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.rotate(v) 

Rotate 函数是一种图像变换,使用 PIL 库中的 rotate 函数对输入图像进行旋转操作。该函数会将图像按照给定的角度 $v$ 进行顺时针或逆时针旋转(取决于一个 $50%$ 的概率),从而产生新的图像。旋转操作可以使得图像的内容重新排布,从而提供一种不同于原始图像的视角,也可以用于对图像进行增强和修复,从而使得图像更加鲜明和生动。

def Sharpness(img, v, max_v, bias = 0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Sharpness(img).enhance(v) 

Sharpness 函数是一种图像增强操作,使用 PIL 库中的 ImageEnhance 模块中的 Sharpness 函数对输入图像进行锐化操作。该函数会增强图像的边缘和细节,使得图像更加清晰和锐利。具体来说,该函数会对图像进行一定的卷积操作,从而增强图像的高频分量。函数中的参数 $v$ 控制了锐化的程度,取值范围为 $[0, \text{max}_v]$,其中 $\text{max}_v$ 是参数的最大值。参数 $bias$ 是一个偏移量,用于调整参数 $v$ 的基准值。

def ShearX(img, v, max_v, bias = 0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 

实现对图像进行ShearX变换,ShearX是图像变换中的一种,用于在x方向上对图像进行剪切,使图像在x方向上发生倾斜变化。该函数通过对输入图像进行仿射变换实现ShearX。其中,v是变换强度,max_v是变换强度的上限,bias是变换偏移。函数先从变换强度v和变换强度上限max_v中随机选择一个值,并加上变换偏移bias,得到最终的变换强度。然后,函数根据随机数决定变换方向(正方向或负方向),并利用PIL库提供的transform函数,将输入图像进行仿射变换,实现ShearX变换。

def ShearY(img, v, max_v, bias = 0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 

这段代码实现了对图像进行沿y轴方向的错切变换。函数名为ShearY,参数包括img表示输入的图像,v表示错切变换的强度,max_v表示错切变换的最大强度,bias表示对错切变换强度的偏置。函数首先将v加上bias,然后以一定概率随机地将v变成它的相反数(这个操作是为了增强数据的多样性),接着通过调用图像对象的transform方法,将图像进行错切变换,变换矩阵为(1, 0, 0, v, 1, 0),表示只对y坐标进行错切变换,变换强度为v最后返回变换后的图像。

def Solarize(img, v, max_v, bias=0):
    v = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.solarize(img, 256 - v)

这个函数实现的是对图片进行反相处理(solarization),即将像素值小于阈值v的像素值取反,大于等于阈值v的像素值保持不变。阈值v在0到max_v之间随机选取,并加上一个偏差bias。

def SolarizeAdd(img, v, max_v, bias=0, threshold = 128):
    v = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    img_np = np.array(img).astype(np.int)
    img_np = img_np + v
    img_np = np.clip(img_np, 0, 255)
    img_np = img_np.astype(np.uint8)
    img = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(img, threshold)

SolarizeAdd函数的功能是对输入的图像进行太阳化加操作。具体来说,该函数会将输入图像转换为NumPy数组,对其所有像素值加上一个随机数v,然后将数组中的值裁剪到0和255之间。接下来,将处理后的数组重新转换为图像,并对其进行太阳化操作,将所有像素值低于给定阈值threshold的像素值翻转。最后返回处理后的图像。

def TranslateX(img, v, max_v, bias = 0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    v = int(v * img.size[0])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 

对图像进行水平方向的平移。详细描述:将输入的图像在水平方向上移动一个随机的距离。如果随机数小于0.5,则向左移动,否则向右移动。移动的距离为输入参数v图像宽度(img.size[0])的乘积。如果bias不为0,则移动的距离还要加上bias与图像宽度的乘积,返回一个经过平移后的图像。

def TranslateY(img, v, max_v, bias = 0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    v = int(v * img.size[1])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 

该函数的功能是垂直方向对图像进行平移操作,可根据传入的参数v、max_v和bias进行平移量的控制。函数中使用了PIL库的transform函数和AFFINE变换,通过对图像的变换实现平移效果。其中,v的值通过随机数实现翻转效果。

def _float_parameter(v, max_v):
    return float(v * max_v / PARAMETER_MAX)

该函数的功能是将一个参数 v 从 [0, PARAMETER_MAX] 的整数范围映射到 [0, max_v] 的浮点数范围。其中 PARAMETER_MAX 是一个预先定义的常量,表示参数的最大值。

def _int_parameter(v, max_v):
    return int(v * max_v / PARAMETER_MAX)

这个函数的作用是将0到PARAMETER_MAX之间的整数值v转换为0到max_v之间的整数值。具体来说,它首先将v除以PARAMETER_MAX,然后将结果乘以max_v并取整数。这个函数通常用于从超参数空间中采样一个值,并将其转换为在具体数据增强函数中使用的合适值。

def fixmatch_augment_pool():
    #FixMatch paper
    augs = [(AutoContrast, None, None),
            (Brightness, 0.9, 0.05),
            (Color, 0.9, 0.05),
            (Contrast, 0.9, 0.05),
            (Equalize, None, None),
            (Identity, None, None),
            (Posterize, 4, 4),
            (Rotate, 30, 0),
            (Sharpness, 0.9, 0.05),
            (ShearX, 0.3, 0),
            (ShearY, 0.3, 0),
            (Solarize, 256, 0),
            (TranslateX, 0.3, 0),
            (TranslateY, 0.3, 0)]
    return augs

这段代码定义了一个数据增强池 fixmatch_augment_pool,它包含了一系列的图像增强操作,这些操作可以用于对输入的图像进行随机变换,增加模型对不同样本的泛化能力。具体来说,它包含以下图像增强操作:

  • AutoContrast:自动对比度增强。
  • Brightness:调整亮度。
  • Color:调整饱和度。
  • Contrast:调整对比度。
  • Equalize:直方图均衡化。
  • Identity:不做任何变换。
  • Posterize:将图像量化为指定的位数。
  • Rotate:旋转图像。
  • Sharpness:调整图像锐度。
  • ShearX:在水平方向上剪切图像。
  • ShearY:在垂直方向上剪切图像。
  • Solarize:将图像进行反相处理。
  • TranslateX:在水平方向上平移图像。
  • TranslateY:在垂直方向上平移图像。

其中,每个操作的具体实现细节可以参考前面的函数定义。

def my_augment_pool():
    #Test
    augs = [(AutoContrast, None, None),
            (Brightness, 1.8, 0.1),
            (Color, 1.8, 0.1),
            (Contrast, 1.8, 0.1),
            (Cutout, 0.2, 0),
            (Equalize, None, None),
            (Identity, None, None),
            (Invert, None, None),
            (Posterize, 4, 4),
            (Rotate, 30, 0),
            (Sharpness, 1.8, 0.1),
            (ShearX, 0.3, 0),
            (ShearY, 0.3, 0),
            (Solarize, 256, 0),
            (SolarizeAdd, 110, 0),
            (TranslateX, 0.45, 0),
            (TranslateY, 0.45, 0)]
    return augs
  • Cutout:随机挖空
  • Invert:反转颜色
  • Posterize:压缩图像位数
  • SolarizeAdd:反转像素值并添加增量

其中,大多数操作的参数都包括一个最大值和一个偏移值,用于在一定范围内随机生成操作参数。返回的列表中,每个元素都是一个包含三个元素的元组:增强操作的函数名称,该函数所需的参数(如果有),以及一个与函数名称对应的字典,该字典包含了每个函数特有的默认参数和参数范围。

Class RandAugmentPC(object):
    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = my_augment_pool()

    def __call__(self, img):
        ops = random.choices(self.augment_pool, k = self.n)
        for ops, max_v, bias in ops:
            prob = np.random.uniform(0.2, 0.8)
            if random.random() + prob >= 1:
                img = op(img, v = self.m, max_v = max_v, bias = bias)
        img = CutoutAbs(img, int(32 * 0.5))
        return img

这个类是一个随机数据增强器,通过在给定的一组增强操作中随机选择一些来随机改变图像。类的初始化函数接受两个参数n和m,其中n是选择要应用的增强操作的数量,m是每个增强操作可以改变的强度大小。然后,增强池被设置为自定义的my_augment_pool函数,该函数返回一组增强操作。__call__方法实现了增强,首先随机选择n个增强操作,然后对于每个增强操作,以概率0.2到0.8之间随机选择一个浮点数prob,如果随机数+prob大于等于1,则应用该操作。最后,CutoutAbs函数被应用于图像,该函数随机生成一个区域并将其像素值替换为0。函数返回增强后的图像。

Class RandAugmentMC(object):
    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = fixmatch_augment_pool()

    def __call__(self, img):
        ops = random.choices(self.augment_pool, k = self.n)
        for ops, max_v, bias in ops:
            v = np.random.randint(1, self.m)
            if random.random() < 0.5:
                img = op(img, v = v, max_v = max_v, bias = bias)
        img = CutoutAbs(img, int(32 * 0.5))
        return img

这段代码实现了一个数据增强类RandAugmentMC,用于在训练深度神经网络时对图像数据进行增强处理。其中,n表示从fixmatch_augment_pool中选择的数据增强方法的数量,m表示数据增强方法的强度。该类的__call__方法接收一个图像作为输入,随机选择n个增强方法,并根据给定的强度参数m对图像进行数据增强,然后对增强后的图像进行Cutout处理,并返回增强后的图像。

RandAugmentPCRandAugmentMC是图像增强的类,用于对图像进行多种数据增强操作,以增加模型的鲁棒性和泛化性能。RandAugmentPC通过随机选择一组增强算子,并根据指定的参数在图像上应用,同时使用CutoutAbs函数进行遮挡,使得图像增强后更具有鲁棒性。RandAugmentMC使用了FixMatch论文中提出的增强算子组合,并根据随机生成的参数在图像上应用,同样使用CutoutAbs函数进行遮挡。两者的主要区别在于增强算子的不同和参数生成方式的不同,因此也会对增强效果产生影响。

你可能感兴趣的:(pytorch,深度学习,python)