A Walkthrough of DIAL-Filters (IAPM Module + LGF Module) in the IA-SEG Project

The IA-SEG project accompanies the paper Improving Nighttime Driving-Scene Segmentation via Dual Image-adaptive Learnable Filters. Its core idea is to attach DIAL-Filters to an existing semantic segmentation model. DIAL-Filters consist of two parts: an image-adaptive processing module (IAPM, i.e. the CNN-PP + DIF modules from IA-YOLO) and a learnable guided filter (LGF). The project is implemented in PyTorch, and it is analyzed here as groundwork for building a domain-adaptive detection pipeline in PyTorch. IA-SEG targets semantic segmentation in nighttime scenes and contains both supervised and unsupervised parts; this post only discusses how to use its core components, the IAPM module (CNN-PP + DIF) and the LGF module. Sections 3 and 4 contain usage examples.

Besides DIAL-Filters, the IA-SEG paper also proposes an unsupervised learning framework, which is only touched on at the end of this post; interested readers can consult the original paper, or my annotated translation of it.

IA-SEG project repository: https://github.com/wenyyu/IA-Seg#arxiv
[Figure 1]

1. CNN-PP Module

1.1 Overview

The CNN-PP module supplies the filter parameters with which the DIF module (named DIP in the code, after IA-YOLO) enhances the image. It is essentially a compact convolutional neural network: its input is a low-resolution version of the original image, and its output is the set of parameters for the DIF module.
[Figure 2]
In IA-SEG, the CNN-PP module has more network parameters (278K, predicting 4 filter parameters) than in IA-YOLO (165K, predicting 15 filter parameters). Note: IA-SEG and IA-YOLO are by the same authors.
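As a quick sanity check on the 278K figure, you can count the parameters of the CNN_PP network quoted in section 1.2 below (my own sketch, assuming the repo is importable and its config parses cleanly in your environment):

import torch
from network.dip import DIP

model = DIP()
print(sum(p.numel() for p in model.parameters()))  # 278276, i.e. ~278K with 4 filter parameters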

1.2 Implementation

Source: https://github.com/wenyyu/IA-Seg/blob/main/network/dip.py
The full file follows. It references an external object cfg, a configuration object whose fields cfg.num_filter_parameters and cfg.filters are used inside CNN_PP.

#! /usr/bin/env python
# coding=utf-8
import torch
import torch.nn as nn

import numpy as np

from configs.train_config import cfg

import time

def conv_downsample(in_filters, out_filters, normalization=False):
    layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]
    layers.append(nn.LeakyReLU(0.2))
    if normalization:
        layers.append(nn.InstanceNorm2d(out_filters, affine=True))
    return layers

class CNN_PP(nn.Module):
    def __init__(self, in_channels=3):
        super(CNN_PP, self).__init__()

        self.model = nn.Sequential(
            nn.Upsample(size=(256,256),mode='bilinear'),
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
            *conv_downsample(16, 32, normalization=True),
            *conv_downsample(32, 64, normalization=True),
            *conv_downsample(64, 128, normalization=True),
            *conv_downsample(128, 128),
            #*discriminator_block(128, 128, normalization=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(128, cfg.num_filter_parameters, 8, padding=0),
        )

    def forward(self, img_input):
        self.Pr = self.model(img_input)
        self.filtered_image_batch = img_input
        filters = cfg.filters
        filters = [x(img_input, cfg) for x in filters]
        self.filter_parameters = []
        self.filtered_images = []

        for j, filter in enumerate(filters):
            # with tf.variable_scope('filter_%d' % j):
            # print('    creating filter:', j, 'name:', str(filter.__class__), 'abbr.',
            #       filter.get_short_name())
            # print('      filter_features:', self.Pr.shape)

            self.filtered_image_batch, filter_parameter = filter.apply(
                self.filtered_image_batch, self.Pr)
            self.filter_parameters.append(filter_parameter)
            self.filtered_images.append(self.filtered_image_batch)

            # print('      output:', self.filtered_image_batch.shape)
        return self.filtered_image_batch, self.filtered_images, self.Pr, self.filter_parameters



def DIP():
    model = CNN_PP()
    return model
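For intuition about the output shape, here is a quick check of my own (same import assumptions as above, with cfg.num_filter_parameters = 4 as in section 1.3): whatever the input resolution, the image is first resized to 256×256, five stride-2 convolutions bring it down to 8×8, and the final 8×8 convolution collapses it to a 1×1 map of filter parameters.

import torch
from network.dip import DIP

model = DIP()
x = torch.rand(2, 3, 512, 1024)   # any resolution works; it is resized internally
Pr = model.model(x)               # run only the CNN-PP backbone
print(Pr.shape)                   # torch.Size([2, 4, 1, 1])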

1.3 Related Code

The dip.py module above references the cfg object, whose source is at:
https://github.com/wenyyu/IA-Seg/blob/main/configs/train_config.py

The parts of that file relevant to CNN-PP and the DIF filters are excerpted below (cfg itself is created earlier in the file):


import argparse
from network.filters import *

cfg.filters = [ExposureFilter, GammaFilter, ContrastFilter, UsmFilter]
# cfg.filters = []

cfg.num_filter_parameters = 4
# these settings drive the DIF module's filtering operations
cfg.exposure_begin_param = 0
cfg.gamma_begin_param = 1
cfg.contrast_begin_param = 2
cfg.usm_begin_param = 3
# Gamma = 1/x ~ x
cfg.curve_steps = 8
cfg.gamma_range = 3
cfg.exposure_range = 3.5
cfg.wb_range = 1.1
cfg.color_curve_range = (0.90, 1.10)
cfg.lab_curve_range = (0.90, 1.10)
cfg.tone_curve_range = (0.5, 2)
cfg.defog_range = (0.1, 1.0)
cfg.usm_range = (0.0, 5)
cfg.cont_range = (0.0, 1.0)
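For clarity, this is how the four channels of the CNN-PP output Pr are routed to the filters through the begin_param indices above (each filter slices Pr[:, begin : begin + num] inside extract_parameters):

# Pr[:, 0] -> ExposureFilter  (exposure_begin_param = 0)
# Pr[:, 1] -> GammaFilter     (gamma_begin_param    = 1)
# Pr[:, 2] -> ContrastFilter  (contrast_begin_param = 2)
# Pr[:, 3] -> UsmFilter       (usm_begin_param      = 3)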

It also depends on the DIF implementation, described later.
CNN-PP is a plug-and-play front-end module: it does not need to be added to the segmentation model itself; it is enough to wire it into the train function. IA-SEG uses CNN-PP as shown below. Note that the input to CNN-PP is a normalized image in [0, 1] (not mean/std standardized), and no loss is computed between CNN-PP's output and any reference image; CNN-PP is optimized purely through the loss at the end of the whole forward pass. This differs from the design in IA-YOLO.

See the original author's code for further details.

CNNPP = dip.DIP().to(device)
CNNPP.train()
model = PSPNet(num_classes=args.num_classes, dgf=args.DGF_FLAG).to(device)
model.train()
optimizer = optim.SGD(list(model.parameters()) + list(CNNPP.parameters()),
                      lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)

for i_iter in range(args.num_steps):
    optimizer.zero_grad()
    for sub_i in range(args.iter_size):
        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)
        enhanced_images_pre, ci_map, Pr, filter_parameters = CNNPP(images)
        enhanced_images = enhanced_images_pre
        # per-image ImageNet standardization before the segmentation network
        for i_pre in range(enhanced_images_pre.shape[0]):
            enhanced_images[i_pre, ...] = standard_transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(
                enhanced_images_pre[i_pre, ...])
        pred_c = model(enhanced_images)

2. DIF Module

2.1 Overview

DIF stands for Differentiable Image Filters. The module is composed of several differentiable filters with adjustable hyperparameters: exposure, gamma, contrast, and sharpening. The DIF code in IA-SEG is adapted from the DIP code in IA-YOLO, ported from TensorFlow to PyTorch, with the filters not needed in IA-SEG (the Tone Filter and Defog Filter) commented out.
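Concretely, with the parameter ranges from the config in section 1.3, the four filters kept in IA-SEG apply, in order:

- Exposure: out = img · 2^E, with E ∈ (-3.5, 3.5)
- Gamma: out = img^γ, with γ ∈ (1/3, 3)
- Contrast: out = lerp(img, enhanced-contrast image, α), with α ∈ (0, 1)
- Sharpening (USM): out = img + λ · (img - Gauss(img)), with λ ∈ (0, 5) and Gauss a fixed 5×5 Gaussian blur

Each scalar (E, γ, α, λ) is one channel of the CNN-PP output, squashed into its range by tanh_range.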

2.2 Implementation

Source: https://github.com/wenyyu/IA-Seg/blob/main/network/filters.py
The implementation follows, with most of the dead commented-out code removed (chiefly the original TensorFlow versions of the Tone Filter, Defog Filter, and so on).

Note that all the differentiable filters inherit from Filter. Of the two constructor arguments net and cfg, only cfg actually does anything (see the annotated config in section 1.3). The helpers rgb2lum, tanh_range, and lerp are imported to give the Filter objects their basic tensor operations.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from network.util_filters import rgb2lum, tanh_range, lerp
from network.util_filters import *

import cv2
import math

# device = torch.device("cuda")

class Filter(nn.Module):

  def __init__(self, net, cfg):
    super(Filter, self).__init__()

    self.cfg = cfg
    # self.height, self.width, self.channels = list(map(int, net.get_shape()[1:]))

    # Specified in child classes
    self.num_filter_parameters = None
    self.short_name = None
    self.filter_parameters = None

  def get_short_name(self):
    assert self.short_name
    return self.short_name

  def get_num_filter_parameters(self):
    assert self.num_filter_parameters
    return self.num_filter_parameters

  def get_begin_filter_parameter(self):
    return self.begin_filter_parameter

  def extract_parameters(self, features):
    # output_dim = self.get_num_filter_parameters(
    # ) + self.get_num_mask_parameters()
    # features = ly.fully_connected(
    #     features,
    #     self.cfg.fc1_size,
    #     scope='fc1',
    #     activation_fn=lrelu,
    #     weights_initializer=tf.contrib.layers.xavier_initializer())
    # features = ly.fully_connected(
    #     features,
    #     output_dim,
    #     scope='fc2',
    #     activation_fn=None,
    #     weights_initializer=tf.contrib.layers.xavier_initializer())
    return features[:, self.get_begin_filter_parameter():(self.get_begin_filter_parameter() + self.get_num_filter_parameters())], \
           features[:, self.get_begin_filter_parameter():(self.get_begin_filter_parameter() + self.get_num_filter_parameters())]

  # Should be implemented in child classes
  def filter_param_regressor(self, features):
    assert False

  # Process the whole image, without masking
  # Should be implemented in child classes
  def process(self, img, param, defog, IcA):
    assert False

  def debug_info_batched(self):
    return False

  def no_high_res(self):
    return False

  # Apply the whole filter with masking
  def apply(self,
            img,
            img_features=None,
            defog_A=None,
            IcA=None,
            specified_parameter=None,
            high_res=None):
    assert (img_features is None) ^ (specified_parameter is None)
    if img_features is not None:
      filter_features, mask_parameters = self.extract_parameters(img_features)
      filter_parameters = self.filter_param_regressor(filter_features)
    else:
      assert not self.use_masking()
      filter_parameters = specified_parameter

    if high_res is not None:
      # working on high res...
      pass
    debug_info = {}
    # We only debug the first image of this batch
    if self.debug_info_batched():
      debug_info['filter_parameters'] = filter_parameters
    else:
      debug_info['filter_parameters'] = filter_parameters[0]
    # self.mask_parameters = mask_parameters
    # self.mask = self.get_mask(img, mask_parameters)
    # debug_info['mask'] = self.mask[0]
    #low_res_output = lerp(img, self.process(img, filter_parameters), self.mask)
    low_res_output = self.process(img, filter_parameters, defog_A, IcA)

    if high_res is not None:
      if self.no_high_res():
        high_res_output = high_res
      else:
        self.high_res_mask = self.get_mask(high_res, mask_parameters)
        # high_res_output = lerp(high_res,
        #                        self.process(high_res, filter_parameters, defog, IcA),
        #                        self.high_res_mask)
    else:
      high_res_output = None
    #return low_res_output, high_res_output, debug_info
    return low_res_output, filter_parameters

  def use_masking(self):
    return self.cfg.masking

  def get_num_mask_parameters(self):
    return 6

  # Input: no need for tanh or sigmoid
  # Closer to 1 values are applied by filter more strongly
  # no additional TF variables inside
  def get_mask(self, img, mask_parameters):
    if not self.use_masking():
      print('* Masking Disabled')
      return tf.ones(shape=(1, 1, 1, 1), dtype=tf.float32)
    else:
      print('* Masking Enabled')
    with tf.name_scope(name='mask'):
      # Six parameters for one filter
      filter_input_range = 5
      assert mask_parameters.shape[1] == self.get_num_mask_parameters()
      mask_parameters = tanh_range(
          l=-filter_input_range, r=filter_input_range,
          initial=0)(mask_parameters)
      size = list(map(int, img.shape[1:3]))
      grid = np.zeros(shape=[1] + size + [2], dtype=np.float32)

      shorter_edge = min(size[0], size[1])
      for i in range(size[0]):
        for j in range(size[1]):
          grid[0, i, j,
               0] = (i + (shorter_edge - size[0]) / 2.0) / shorter_edge - 0.5
          grid[0, i, j,
               1] = (j + (shorter_edge - size[1]) / 2.0) / shorter_edge - 0.5
      grid = tf.constant(grid)
      # Ax + By + C * L + D
      inp = grid[:, :, :, 0, None] * mask_parameters[:, None, None, 0, None] + \
            grid[:, :, :, 1, None] * mask_parameters[:, None, None, 1, None] + \
            mask_parameters[:, None, None, 2, None] * (rgb2lum(img) - 0.5) + \
            mask_parameters[:, None, None, 3, None] * 2
      # Sharpness and inversion
      inp *= self.cfg.maximum_sharpness * mask_parameters[:, None, None, 4,
                                                          None] / filter_input_range
      mask = tf.sigmoid(inp)
      # Strength
      mask = mask * (
          mask_parameters[:, None, None, 5, None] / filter_input_range * 0.5 +
          0.5) * (1 - self.cfg.minimum_strength) + self.cfg.minimum_strength
      print('mask', mask.shape)
    return mask

  # def visualize_filter(self, debug_info, canvas):
  #   # Visualize only the filter information
  #   assert False

  def visualize_mask(self, debug_info, res):
    return cv2.resize(
        debug_info['mask'] * np.ones((1, 1, 3), dtype=np.float32),
        dsize=res,
        interpolation=cv2.cv2.INTER_NEAREST)

  def draw_high_res_text(self, text, canvas):
    cv2.putText(
        canvas,
        text, (30, 128),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8, (0, 0, 0),
        thickness=5)
    return canvas


class ExposureFilter(Filter):

  def __init__(self, net, cfg):
    Filter.__init__(self, net, cfg)
    self.short_name = 'E'
    self.begin_filter_parameter = cfg.exposure_begin_param
    self.num_filter_parameters = 1

  def filter_param_regressor(self, features):#param is in (-self.cfg.exposure_range, self.cfg.exposure_range)
    return tanh_range(
        -self.cfg.exposure_range, self.cfg.exposure_range, initial=0)(features)

  def process(self, img, param, defog, IcA):
    # print('      param:', param)
    # print('      param:', torch.exp(param * np.log(2)))


    # return img * torch.exp(torch.tensor(3.31).cuda() * np.log(2))
    return img * torch.exp(param * np.log(2))


class UsmFilter(Filter):#Usm_param is in [Defog_range]

  def __init__(self, net, cfg):

    Filter.__init__(self, net, cfg)
    self.short_name = 'UF'
    self.begin_filter_parameter = cfg.usm_begin_param
    self.num_filter_parameters = 1

  def filter_param_regressor(self, features):
    return tanh_range(*self.cfg.usm_range)(features)

  def process(self, img, param, defog_A, IcA):


    self.channels = 3
    kernel = [[0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633],
              [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
              [0.01330373, 0.11098164, 0.22508352, 0.11098164, 0.01330373],
              [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
              [0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633]]
    kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
    kernel = np.repeat(kernel, self.channels, axis=0)

    # print('      param:', param)

    kernel = kernel.to(img.device)
    # self.weight = nn.Parameter(data=kernel, requires_grad=False)
    # self.weight.to(device)

    output = F.conv2d(img, kernel, padding=2, groups=self.channels)


    img_out = (img - output) * param + img
    # img_out = (img - output) * torch.tensor(0.043).cuda() + img

    return img_out

class ContrastFilter(Filter):

  def __init__(self, net, cfg):
    Filter.__init__(self, net, cfg)
    self.short_name = 'Ct'
    self.begin_filter_parameter = cfg.contrast_begin_param

    self.num_filter_parameters = 1

  def filter_param_regressor(self, features):
    # return tf.sigmoid(features)
    # return torch.tanh(features)
    return tanh_range(*self.cfg.cont_range)(features)

  def process(self, img, param, defog, IcA):
    # print('      param.shape:', param.shape)

    # luminance = torch.minimum(torch.maximum(rgb2lum(img), 0.0), 1.0)
    luminance = rgb2lum(img)
    zero = torch.zeros_like(luminance)
    one = torch.ones_like(luminance)

    luminance = torch.where(luminance < 0, zero, luminance)
    luminance = torch.where(luminance > 1, one, luminance)

    contrast_lum = -torch.cos(math.pi * luminance) * 0.5 + 0.5
    contrast_image = img / (luminance + 1e-6) * contrast_lum
    return lerp(img, contrast_image, param)
    # return lerp(img, contrast_image, torch.tensor(0.015).cuda())


class ToneFilter(Filter):

  def __init__(self, net, cfg):
    Filter.__init__(self, net, cfg)
    self.curve_steps = cfg.curve_steps
    self.short_name = 'T'
    self.begin_filter_parameter = cfg.tone_begin_param

    self.num_filter_parameters = cfg.curve_steps

  def filter_param_regressor(self, features):
    # tone_curve = tf.reshape(
    #     features, shape=(-1, 1, self.cfg.curve_steps))[:, None, None, :]
    tone_curve = tanh_range(*self.cfg.tone_curve_range)(features)
    return tone_curve

  def process(self, img, param, defog, IcA):
    # img = tf.minimum(img, 1.0)
    # param = tf.constant([[0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6], [0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6],
    #                       [0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6], [0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6],
    #                       [0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6], [0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6]])
    # param = tf.constant([[0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6]])
    # param = tf.reshape(
    #     param, shape=(-1, 1, self.cfg.curve_steps))[:, None, None, :]
    param = torch.unsqueeze(param, 3)
    # print('      param.shape:', param.shape)

    tone_curve = param
    tone_curve_sum = torch.sum(tone_curve, axis=1) + 1e-30
    # print('      tone_curve_sum.shape:', tone_curve_sum.shape)

    total_image = img * 0
    for i in range(self.cfg.curve_steps):
      total_image += torch.clamp(img - 1.0 * i / self.cfg.curve_steps, 0, 1.0 / self.cfg.curve_steps) \
                     * param[:, i, :, :]
    # p_cons = [0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6]
    # for i in range(self.cfg.curve_steps):
    #   total_image += tf.clip_by_value(img - 1.0 * i / self.cfg.curve_steps, 0, 1.0 / self.cfg.curve_steps) \
    #                  * p_cons[i]
    total_image *= self.cfg.curve_steps / tone_curve_sum
    img = total_image
    return img


  # def visualize_filter(self, debug_info, canvas):
  #   curve = debug_info['filter_parameters']
  #   height, width = canvas.shape[:2]
  #   values = np.array([0] + list(curve[0][0][0]))
  #   values /= sum(values) + 1e-30
  #   for j in range(0, self.curve_steps):
  #     values[j + 1] += values[j]
  #   for j in range(self.curve_steps):
  #     p1 = tuple(
  #         map(int, (width / self.curve_steps * j, height - 1 -
  #                   values[j] * height)))
  #     p2 = tuple(
  #         map(int, (width / self.curve_steps * (j + 1), height - 1 -
  #                   values[j + 1] * height)))
  #     cv2.line(canvas, p1, p2, (0, 0, 0), thickness=1)


class GammaFilter(Filter):  #gamma_param is in [1/gamma_range, gamma_range]

  def __init__(self, net, cfg):
    Filter.__init__(self, net, cfg)
    self.short_name = 'G'
    self.begin_filter_parameter = cfg.gamma_begin_param
    self.num_filter_parameters = 1

  def filter_param_regressor(self, features):
    log_gamma_range = np.log(self.cfg.gamma_range)
    # return tf.exp(tanh_range(-log_gamma_range, log_gamma_range)(features))
    return torch.exp(tanh_range(-log_gamma_range, log_gamma_range)(features))

  def process(self, img, param, defog_A, IcA):
    # print('      param:', param)

    # param_1 = param.repeat(1, 3)
    zero = torch.zeros_like(img) + 0.00001
    img = torch.where(img <= 0, zero, img)
    # print("GAMMMA", param)
    return torch.pow(img, param)
    # return torch.pow(img, torch.tensor(0.51).cuda())

    # param_1 = tf.tile(param, [1, 3])
    # return tf.pow(tf.maximum(img, 0.0001), param_1[:, None, None, :])
    # return img

2.3 Related Code

util_filters provides the filters with a few basic utility functions: rgb2lum, tanh_range, and lerp.
Full source: https://github.com/wenyyu/IA-Seg/blob/main/network/util_filters.py
The main code is:

import math
import cv2
import torch
import torch.nn as nn

def rgb2lum(image):
  image = 0.27 * image[:, :, :, 0] + 0.67 * image[:, :, :,
                                                  1] + 0.06 * image[:, :, :, 2]
  return image[:, :, :, None]


def tanh01(x):
  # return tf.tanh(x) * 0.5 + 0.5
  return torch.tanh(x) * 0.5 + 0.5



def tanh_range(l, r, initial=None):

  def get_activation(left, right, initial):

    def activation(x):
      if initial is not None:
        bias = math.atanh(2 * (initial - left) / (right - left) - 1)
      else:
        bias = 0
      return tanh01(x + bias) * (right - left) + left

    return activation

  return get_activation(l, r, initial)




def lerp(a, b, l):
  return (1 - l) * a + l * b
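tanh_range(l, r) returns an activation that squashes any real input into the open interval (l, r); a quick check of my own (run in the same module as the code above):

import torch
squash = tanh_range(0.0, 5.0)
x = torch.tensor([-10.0, 0.0, 10.0])
print(squash(x))   # ~ tensor([0.0000, 2.5000, 5.0000])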

3. Using the IAPM Module

The IAPM module is simply the combination of the CNN-PP and DIF modules described above (the consolidated class below is named IPAM, following the author's spelling). It is singled out here so the code can be detached from the IA-SEG project and used standalone.
In IA-SEG, the DIF module is effectively already embedded into the CNN-PP model, together forming the IAPM module, but the relevant code is scattered across several .py files, which makes it inconvenient to reuse; hence this consolidation.

3.1 Consolidated Code

Install the dependency: pip install easydict
The consolidated code is shown below; only the cfg block at the bottom needs adjusting. An IPAM class is provided, through which image-level domain adaptation can be performed directly.

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import time
import math

#----- Basic helper functions for the filters ------
def rgb2lum(image):
  image = 0.27 * image[:, :, :, 0] + 0.67 * image[:, :, :,
                                                  1] + 0.06 * image[:, :, :, 2]
  return image[:, :, :, None]
def tanh01(x):
  # return tf.tanh(x) * 0.5 + 0.5
  return torch.tanh(x) * 0.5 + 0.5
def tanh_range(l, r, initial=None):
  def get_activation(left, right, initial):
    def activation(x):
      if initial is not None:
        bias = math.atanh(2 * (initial - left) / (right - left) - 1)
      else:
        bias = 0
      return tanh01(x + bias) * (right - left) + left
    return activation
  return get_activation(l, r, initial)
def lerp(a, b, l):
  return (1 - l) * a + l * b
    
#----- Filter implementations ------
class Filter(nn.Module):

  def __init__(self, net, cfg):
    super(Filter, self).__init__()

    self.cfg = cfg
    self.num_filter_parameters = None
    self.short_name = None
    self.filter_parameters = None

  def get_short_name(self):
    assert self.short_name
    return self.short_name

  def get_num_filter_parameters(self):
    assert self.num_filter_parameters
    return self.num_filter_parameters

  def get_begin_filter_parameter(self):
    return self.begin_filter_parameter

  def extract_parameters(self, features):
    return features[:, self.get_begin_filter_parameter():(self.get_begin_filter_parameter() + self.get_num_filter_parameters())], \
           features[:, self.get_begin_filter_parameter():(self.get_begin_filter_parameter() + self.get_num_filter_parameters())]

  # Should be implemented in child classes
  def filter_param_regressor(self, features):
    assert False

  # Process the whole image, without masking
  # Should be implemented in child classes
  def process(self, img, param, defog, IcA):
    assert False

  def debug_info_batched(self):
    return False

  def no_high_res(self):
    return False

  # Apply the whole filter with masking
  def apply(self,
            img,
            img_features=None,
            defog_A=None,
            IcA=None,
            specified_parameter=None,
            high_res=None):
    assert (img_features is None) ^ (specified_parameter is None)
    if img_features is not None:
      filter_features, mask_parameters = self.extract_parameters(img_features)
      filter_parameters = self.filter_param_regressor(filter_features)
    else:
      assert not self.use_masking()
      filter_parameters = specified_parameter

    if high_res is not None:
      # working on high res...
      pass
    debug_info = {}
    # We only debug the first image of this batch
    if self.debug_info_batched():
      debug_info['filter_parameters'] = filter_parameters
    else:
      debug_info['filter_parameters'] = filter_parameters[0]
    # self.mask_parameters = mask_parameters
    # self.mask = self.get_mask(img, mask_parameters)
    # debug_info['mask'] = self.mask[0]
    #low_res_output = lerp(img, self.process(img, filter_parameters), self.mask)
    low_res_output = self.process(img, filter_parameters, defog_A, IcA)

    if high_res is not None:
      if self.no_high_res():
        high_res_output = high_res
      else:
        self.high_res_mask = self.get_mask(high_res, mask_parameters)
        # high_res_output = lerp(high_res,
        #                        self.process(high_res, filter_parameters, defog, IcA),
        #                        self.high_res_mask)
    else:
      high_res_output = None
    #return low_res_output, high_res_output, debug_info
    return low_res_output, filter_parameters

  def use_masking(self):
    return self.cfg.masking

  def get_num_mask_parameters(self):
    return 6

  # Input: no need for tanh or sigmoid
  # Closer to 1 values are applied by filter more strongly
  # no additional TF variables inside
  def get_mask(self, img, mask_parameters):
    if not self.use_masking():
      print('* Masking Disabled')
      return tf.ones(shape=(1, 1, 1, 1), dtype=tf.float32)
    else:
      print('* Masking Enabled')
    with tf.name_scope(name='mask'):
      # Six parameters for one filter
      filter_input_range = 5
      assert mask_parameters.shape[1] == self.get_num_mask_parameters()
      mask_parameters = tanh_range(
          l=-filter_input_range, r=filter_input_range,
          initial=0)(mask_parameters)
      size = list(map(int, img.shape[1:3]))
      grid = np.zeros(shape=[1] + size + [2], dtype=np.float32)

      shorter_edge = min(size[0], size[1])
      for i in range(size[0]):
        for j in range(size[1]):
          grid[0, i, j,
               0] = (i + (shorter_edge - size[0]) / 2.0) / shorter_edge - 0.5
          grid[0, i, j,
               1] = (j + (shorter_edge - size[1]) / 2.0) / shorter_edge - 0.5
      grid = tf.constant(grid)
      # Ax + By + C * L + D
      inp = grid[:, :, :, 0, None] * mask_parameters[:, None, None, 0, None] + \
            grid[:, :, :, 1, None] * mask_parameters[:, None, None, 1, None] + \
            mask_parameters[:, None, None, 2, None] * (rgb2lum(img) - 0.5) + \
            mask_parameters[:, None, None, 3, None] * 2
      # Sharpness and inversion
      inp *= self.cfg.maximum_sharpness * mask_parameters[:, None, None, 4,
                                                          None] / filter_input_range
      mask = tf.sigmoid(inp)
      # Strength
      mask = mask * (
          mask_parameters[:, None, None, 5, None] / filter_input_range * 0.5 +
          0.5) * (1 - self.cfg.minimum_strength) + self.cfg.minimum_strength
      print('mask', mask.shape)
    return mask

  # def visualize_filter(self, debug_info, canvas):
  #   # Visualize only the filter information
  #   assert False

  def visualize_mask(self, debug_info, res):
    return cv2.resize(
        debug_info['mask'] * np.ones((1, 1, 3), dtype=np.float32),
        dsize=res,
        interpolation=cv2.cv2.INTER_NEAREST)

  def draw_high_res_text(self, text, canvas):
    cv2.putText(
        canvas,
        text, (30, 128),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8, (0, 0, 0),
        thickness=5)
    return canvas


class ExposureFilter(Filter):

  def __init__(self, net, cfg):
    Filter.__init__(self, net, cfg)
    self.short_name = 'E'
    self.begin_filter_parameter = cfg.exposure_begin_param
    self.num_filter_parameters = 1

  def filter_param_regressor(self, features):#param is in (-self.cfg.exposure_range, self.cfg.exposure_range)
    return tanh_range(
        -self.cfg.exposure_range, self.cfg.exposure_range, initial=0)(features)

  def process(self, img, param, defog, IcA):
    return img * torch.exp(param * np.log(2))


class UsmFilter(Filter):#Usm_param is in [Defog_range]

  def __init__(self, net, cfg):

    Filter.__init__(self, net, cfg)
    self.short_name = 'UF'
    self.begin_filter_parameter = cfg.usm_begin_param
    self.num_filter_parameters = 1

  def filter_param_regressor(self, features):
    return tanh_range(*self.cfg.usm_range)(features)

  def process(self, img, param, defog_A, IcA):


    self.channels = 3
    kernel = [[0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633],
              [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
              [0.01330373, 0.11098164, 0.22508352, 0.11098164, 0.01330373],
              [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
              [0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633]]
    kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
    kernel = np.repeat(kernel, self.channels, axis=0)

    # print('      param:', param)

    kernel = kernel.to(img.device)
    # self.weight = nn.Parameter(data=kernel, requires_grad=False)
    # self.weight.to(device)

    output = F.conv2d(img, kernel, padding=2, groups=self.channels)


    img_out = (img - output) * param + img
    # img_out = (img - output) * torch.tensor(0.043).cuda() + img

    return img_out

class ContrastFilter(Filter):

  def __init__(self, net, cfg):
    Filter.__init__(self, net, cfg)
    self.short_name = 'Ct'
    self.begin_filter_parameter = cfg.contrast_begin_param

    self.num_filter_parameters = 1

  def filter_param_regressor(self, features):
    return tanh_range(*self.cfg.cont_range)(features)

  def process(self, img, param, defog, IcA):
    # print('      param.shape:', param.shape)

    # luminance = torch.minimum(torch.maximum(rgb2lum(img), 0.0), 1.0)
    luminance = rgb2lum(img)
    zero = torch.zeros_like(luminance)
    one = torch.ones_like(luminance)

    luminance = torch.where(luminance < 0, zero, luminance)
    luminance = torch.where(luminance > 1, one, luminance)

    contrast_lum = -torch.cos(math.pi * luminance) * 0.5 + 0.5
    contrast_image = img / (luminance + 1e-6) * contrast_lum
    return lerp(img, contrast_image, param)
    # return lerp(img, contrast_image, torch.tensor(0.015).cuda())


class ToneFilter(Filter):

  def __init__(self, net, cfg):
    Filter.__init__(self, net, cfg)
    self.curve_steps = cfg.curve_steps
    self.short_name = 'T'
    self.begin_filter_parameter = cfg.tone_begin_param

    self.num_filter_parameters = cfg.curve_steps

  def filter_param_regressor(self, features):
    tone_curve = tanh_range(*self.cfg.tone_curve_range)(features)
    return tone_curve

  def process(self, img, param, defog, IcA):
    param = torch.unsqueeze(param, 3)
    # print('      param.shape:', param.shape)

    tone_curve = param
    tone_curve_sum = torch.sum(tone_curve, axis=1) + 1e-30
    # print('      tone_curve_sum.shape:', tone_curve_sum.shape)

    total_image = img * 0
    for i in range(self.cfg.curve_steps):
      total_image += torch.clamp(img - 1.0 * i / self.cfg.curve_steps, 0, 1.0 / self.cfg.curve_steps) \
                     * param[:, i, :, :]
    total_image *= self.cfg.curve_steps / tone_curve_sum
    img = total_image
    return img

class GammaFilter(Filter):  #gamma_param is in [1/gamma_range, gamma_range]

  def __init__(self, net, cfg):
    Filter.__init__(self, net, cfg)
    self.short_name = 'G'
    self.begin_filter_parameter = cfg.gamma_begin_param
    self.num_filter_parameters = 1

  def filter_param_regressor(self, features):
    log_gamma_range = np.log(self.cfg.gamma_range)
    # return tf.exp(tanh_range(-log_gamma_range, log_gamma_range)(features))
    return torch.exp(tanh_range(-log_gamma_range, log_gamma_range)(features))

  def process(self, img, param, defog_A, IcA):
    # print('      param:', param)

    # param_1 = param.repeat(1, 3)
    zero = torch.zeros_like(img) + 0.00001
    img = torch.where(img <= 0, zero, img)
    # print("GAMMMA", param)
    return torch.pow(img, param)
      
#---------- Filter module configuration ----------
from easydict import EasyDict as edict
cfg = edict()
cfg.num_filter_parameters = 4
# these settings drive the DIF module's filtering operations
cfg.exposure_begin_param = 0
cfg.gamma_begin_param = 1
cfg.contrast_begin_param = 2
cfg.usm_begin_param = 3
# Gamma = 1/x ~ x
cfg.curve_steps = 8
cfg.gamma_range = 3
cfg.exposure_range = 3.5
cfg.wb_range = 1.1
cfg.color_curve_range = (0.90, 1.10)
cfg.lab_curve_range = (0.90, 1.10)
cfg.tone_curve_range = (0.5, 2)
cfg.defog_range = (0.1, 1.0)
cfg.usm_range = (0.0, 5)
cfg.cont_range = (0.0, 1.0)

#---------- DIF module ----------
class DIF(nn.Module):
    """Applies the differentiable filters in sequence, driven by the CNN-PP output Pr."""
    def __init__(self, Filters):
        super(DIF, self).__init__()
        self.Filters = Filters
    def forward(self, img_input, Pr):
        self.filtered_image_batch = img_input
        filters = [x(img_input, cfg) for x in self.Filters]  # instantiate each filter
        self.filter_parameters = []
        self.filtered_images = []
        for j, filter in enumerate(filters):
            # each filter slices its own parameters out of Pr and transforms the image
            self.filtered_image_batch, filter_parameter = filter.apply(
                self.filtered_image_batch, Pr)
            self.filter_parameters.append(filter_parameter)
            self.filtered_images.append(self.filtered_image_batch)
        return self.filtered_image_batch, self.filtered_images, Pr, self.filter_parameters
#---------- IPAM module ----------
def conv_downsample(in_filters, out_filters, normalization=False):
    layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]
    layers.append(nn.LeakyReLU(0.2))
    if normalization:
        layers.append(nn.InstanceNorm2d(out_filters, affine=True))
    return layers
class IPAM(nn.Module):
    def __init__(self):
        super(IPAM, self).__init__()

        # CNN-PP: predicts the filter parameters Pr from a 256x256 resize of the input
        self.CNN_PP = nn.Sequential(
            nn.Upsample(size=(256,256),mode='bilinear'),
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
            *conv_downsample(16, 32, normalization=True),
            *conv_downsample(32, 64, normalization=True),
            *conv_downsample(64, 128, normalization=True),
            *conv_downsample(128, 128),
            nn.Dropout(p=0.5),
            nn.Conv2d(128, cfg.num_filter_parameters, 8, padding=0),
        )
        # DIF: applies the four differentiable filters in sequence
        Filters = [ExposureFilter, GammaFilter, ContrastFilter, UsmFilter]
        self.dif = DIF(Filters)

    def forward(self, img_input):
        self.Pr = self.CNN_PP(img_input)
        out = self.dif(img_input, self.Pr)
        return out

3.2 Usage

The usage code is as follows:

model = IPAM()
print(model)
x=torch.rand((1,3,256,256))
filtered_image_batch,filtered_images,Pr,filter_parameters=model(x)

The output is shown below. filtered_image_batch is the enhanced image; filtered_images is a list of length 4 holding the intermediate image after each enhancement step; Pr is the CNN-PP output; filter_parameters are the actual DIF parameters.

IPAM(
  (CNN_PP): Sequential(
    (0): Upsample(size=(256, 256), mode='bilinear')
    (1): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (2): LeakyReLU(negative_slope=0.2)
    (3): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): LeakyReLU(negative_slope=0.2)
    (6): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2)
    (9): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (11): LeakyReLU(negative_slope=0.2)
    (12): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (13): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (14): LeakyReLU(negative_slope=0.2)
    (15): Dropout(p=0.5, inplace=False)
    (16): Conv2d(128, 4, kernel_size=(8, 8), stride=(1, 1))
  )
  (dif): DIF()
)

3.3 Notes on Usage

Since the IPAM parameters here are untrained, the generated images are fairly random.
The ImgUilt helper code can be found in my post "python工具方法 28". Note that images fed to the IPAM module must be normalized to [0, 1], which can be verified by inspecting the IA-SEG author's dataset code.

import cv2, torch
from ImgUilt import *
import numpy as np

p = r'D:\YOLO_seq\helmet_yolo\images\train\000092.jpg'
im_tensor, img = read_img_as_tensor(p)

model = IPAM().cuda()
im_tensor = (im_tensor / 255).cuda()  # IPAM expects inputs normalized to [0, 1]
filtered_image_batch, filtered_images, Pr, filter_parameters = model(im_tensor)

new_img = tensor2img(filtered_image_batch.detach().cpu() * 255)
myimshows([img, new_img])
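If the ImgUilt helpers are not available, an equivalent self-contained version using OpenCV directly looks like this (my own sketch; the image path is a placeholder):

import cv2
import torch

img = cv2.cvtColor(cv2.imread('test.jpg'), cv2.COLOR_BGR2RGB)
im_tensor = torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0) / 255

model = IPAM()
filtered_image_batch, filtered_images, Pr, filter_parameters = model(im_tensor)

new_img = (filtered_image_batch[0].permute(1, 2, 0).detach()
           .clamp(0, 1).numpy() * 255).astype('uint8')
cv2.imwrite('enhanced.jpg', cv2.cvtColor(new_img, cv2.COLOR_RGB2BGR))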

The result is shown below: after passing through IPAM, img receives a random enhancement; in this run the effect shows up as local edge enhancement.
[Figure 3]
Following the IA-SEG author's usage, optimizing the IPAM parameters requires no extra loss; it is enough to splice the module into the normal model's forward pass. The training code is below: CNNPP's output only feeds the model's input and has no direct connection to any loss.

            enhanced_images_pre, ci_map, Pr, filter_parameters = CNNPP(images)
            enhanced_images = enhanced_images_pre


            for i_pre in range(enhanced_images_pre.shape[0]):
                enhanced_images[i_pre,...] = standard_transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(
                    enhanced_images_pre[i_pre,...])

            if args.model == 'RefineNet' or args.model.startswith('deeplabv3'):
                pred_c = model(enhanced_images)
            else:
                _, pred_c = model(enhanced_images)
            pred_c = interp(pred_c)
            loss_seg = seg_loss(pred_c, labels)


            loss = loss_seg #+ loss_seg_dark_dynamic + loss_seg_mix #+ loss_seg_dark_dynamic #+ loss_enhance
            loss_s = loss / args.iter_size
            loss_s.backward(retain_graph=True)
            loss_seg_value += loss_seg.item() / args.iter_size

In practice you can also follow the IA-YOLO recipe: feed a noise-degraded image to IPAM and compute a loss between the original clean image and IPAM's enhanced output, as sketched below.
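A sketch of that IA-YOLO-style supervision (my own code, not from the IA-SEG repo): degrade the clean image, enhance it with IPAM, and penalize the distance to the clean original.

import torch
import torch.nn.functional as F

ipam = IPAM()
clean = torch.rand(2, 3, 256, 256)                                     # clean images in [0, 1]
degraded = (clean * 0.3 + 0.02 * torch.randn_like(clean)).clamp(0, 1)  # toy darkening + noise
enhanced, _, _, _ = ipam(degraded)
loss_enhance = F.mse_loss(enhanced, clean)
loss_enhance.backward()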

4. LGF Module

4.1 Overview

The guided filter is an edge-preserving, gradient-preserving image operator: it uses the object boundaries in a guidance image to sharpen object saliency. It can suppress saliency outside the target and thereby improve downstream detection or segmentation performance. In effect it performs a fine-grained adjustment of the output feature map. Pseudocode for the LGF module is shown below, where fmean denotes a mean filter with window radius r, and corr, var, and cov abbreviate correlation, variance, and covariance. See the paper for a fuller description.

[Figure 4: guided filter pseudocode]
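For reference, here is the pseudocode reconstructed from the GuidedFilter implementation in section 4.2 (I is the guidance image, p the filtering input, q the output):

mean_I  = fmean(I)           mean_p  = fmean(p)
corr_II = fmean(I .* I)      corr_Ip = fmean(I .* p)
var_I   = corr_II - mean_I .* mean_I
cov_Ip  = corr_Ip - mean_I .* mean_p
A = cov_Ip ./ (var_I + eps)
b = mean_p - A .* mean_I
mean_A = fmean(A)            mean_b = fmean(b)
q = mean_A .* I + mean_b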

4.2 Implementation

Source: https://github.com/wenyyu/IA-Seg/blob/d6393cc87e5ca95ab3b27dee4ec31293256ab9a4/network/guided_filter.py
The file is reproduced below; note that guided_filter.py depends on nothing beyond PyTorch. It defines two classes, GuidedFilter and FastGuidedFilter; IA-SEG does not use FastGuidedFilter. (When FastGuidedFilter's lr_x and hr_x inputs are identical, it behaves essentially like GuidedFilter; the one difference visible in the code is that GuidedFilter additionally box-filters the coefficients A and b, whereas FastGuidedFilter bilinearly resizes them.)

A highlight of the code below is the differentiable box filter, implemented with cumulative sums rather than an explicit convolution.

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable



def diff_x(input, r):
    assert input.dim() == 4

    left   = input[:, :,         r:2 * r + 1]
    middle = input[:, :, 2 * r + 1:         ] - input[:, :,           :-2 * r - 1]
    right  = input[:, :,        -1:         ] - input[:, :, -2 * r - 1:    -r - 1]

    output = torch.cat([left, middle, right], dim=2)

    return output

def diff_y(input, r):
    assert input.dim() == 4

    left   = input[:, :, :,         r:2 * r + 1]
    middle = input[:, :, :, 2 * r + 1:         ] - input[:, :, :,           :-2 * r - 1]
    right  = input[:, :, :,        -1:         ] - input[:, :, :, -2 * r - 1:    -r - 1]

    output = torch.cat([left, middle, right], dim=3)

    return output

class BoxFilter(nn.Module):
    def __init__(self, r):
        super(BoxFilter, self).__init__()

        self.r = r

    def forward(self, x):
        assert x.dim() == 4

        return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)


class FastGuidedFilter(nn.Module):
    def __init__(self, r, eps=1e-8):
        super(FastGuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)


    def forward(self, lr_x, lr_y, hr_x):
        n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
        n_lry, c_lry, h_lry, w_lry = lr_y.size()
        n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()

        assert n_lrx == n_lry and n_lry == n_hrx
        assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
        assert h_lrx == h_lry and w_lrx == w_lry
        assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1

        ## N
        N = self.boxfilter(Variable(lr_x.data.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0)))

        ## mean_x
        mean_x = self.boxfilter(lr_x) / N
        ## mean_y
        mean_y = self.boxfilter(lr_y) / N
        ## cov_xy
        cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
        ## var_x
        var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x

        ## A
        A = cov_xy / (var_x + self.eps)
        ## b
        b = mean_y - A * mean_x

        ## mean_A; mean_b
        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)

        return mean_A*hr_x+mean_b


class GuidedFilter(nn.Module):
    def __init__(self, r, eps=1e-8):
        super(GuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)


    def forward(self, x, y):
        n_x, c_x, h_x, w_x = x.size()
        n_y, c_y, h_y, w_y = y.size()

        assert n_x == n_y
        assert c_x == 1 or c_x == c_y
        assert h_x == h_y and w_x == w_y
        assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1

        # N
        N = self.boxfilter(Variable(x.data.new().resize_((1, 1, h_x, w_x)).fill_(1.0)))

        # mean_x
        mean_x = self.boxfilter(x) / N
        # mean_y
        mean_y = self.boxfilter(y) / N
        # cov_xy
        cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y
        # var_x
        var_x = self.boxfilter(x * x) / N - mean_x * mean_x

        # A
        A = cov_xy / (var_x + self.eps)
        # b
        b = mean_y - A * mean_x

        # mean_A; mean_b
        mean_A = self.boxfilter(A) / N
        mean_b = self.boxfilter(b) / N

        return mean_A * x + mean_b

The code above can be saved as guided_filter.py.
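A quick check of the box filter (my own sketch, not from the repo): BoxFilter(r) computes a (2r+1)×(2r+1) sliding-window sum via cumulative sums, so dividing by the per-pixel window count N recovers a mean filter; interior pixels match an ordinary average filter.

import torch
import torch.nn.functional as F
from guided_filter import BoxFilter

r = 2
x = torch.rand(1, 1, 32, 32)
box = BoxFilter(r)
N = box(torch.ones(1, 1, 32, 32))   # per-pixel count of valid taps (smaller near borders)
mean = box(x) / N
ref = F.avg_pool2d(x, kernel_size=2 * r + 1, stride=1)
print(torch.allclose(mean[..., r:-r, r:-r], ref, atol=1e-5))   # True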

4.3 Usage

I have no semantic segmentation project underway at the moment, so only the usage inside IA-SEG is analyzed here.
GuidedFilter takes two inputs (a guidance map and the image to be filtered), so an extra small network is needed to produce the guidance map.
The code below wraps an ordinary semantic segmentation model into a model containing an LGF. The model returns x1 and x2: x1 is the normal segmentation prediction, and x2 is the LGF-refined result. (The wrapper is lightly adapted here to be self-contained: SegModel stands in for any segmentation network, and num_classes is passed in explicitly.)

class LGFModel(nn.Module):
    def __init__(self, num_classes, dgf, dgf_r, dgf_eps):
        super(LGFModel, self).__init__()
        self.dgf = dgf
        self.model = SegModel()  # any ordinary segmentation network
        if self.dgf:
            # a small learned network that turns the input image into the guidance map
            self.guided_map_conv1 = nn.Conv2d(3, 64, 1)
            self.guided_map_relu1 = nn.ReLU(inplace=True)
            self.guided_map_conv2 = nn.Conv2d(64, num_classes, 1)
            self.guided_filter = GuidedFilter(dgf_r, dgf_eps)

    def forward(self, x1):
        im = x1                   # keep the original image as the guidance source
        x1 = self.model(x1)       # ordinary segmentation prediction
        x2 = x1
        if self.dgf:
            g = self.guided_map_relu1(self.guided_map_conv1(im))
            g = self.guided_map_conv2(g)  # guidance map with num_classes channels
            x2 = F.interpolate(x1, im.size()[2:], mode='bilinear', align_corners=True)
            x2 = self.guided_filter(g, x2)
        return x1, x2
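A usage sketch (my own code; SegModel is a stand-in, and the values of dgf_r and dgf_eps are illustrative rather than the repo defaults):

model = LGFModel(num_classes=19, dgf=True, dgf_r=4, dgf_eps=1e-2)
x = torch.rand(1, 3, 512, 512)
x1, x2 = model(x)   # x1: raw prediction; x2: LGF-refined, at full input resolution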

When training LGFModel, it is usually enough to compute the loss on x2 alone, without backpropagating a loss on x1. If the model converges slowly, a loss on x1 can be backpropagated as well.
