[FDA] Changing an image's spectral style via the Fourier transform for domain adaptation

FDA: Fourier Domain Adaptation for Semantic Segmentation, CVPR 2020

Chinese translation: "FDA: Fourier Domain Adaptation for Semantic Segmentation" (HheeFish's blog on CSDN)

Paper: https://openaccess.thecvf.com/content_CVPR_2020/papers/Yang_FDA_Fourier_Domain_Adaptation_for_Semantic_Segmentation_CVPR_2020_paper.pdf

Code: https://github.com/YanchaoYang/FDA/

Domain adaptation via style transfer made easy using the Fourier transform. FDA needs no deep networks for style transfer and involves no adversarial training. The proposed Fourier Domain Adaptation method works in three steps:

Step 1: Apply FFT to source and target images.

Step 2: Replace the low frequency part of the source amplitude with that from the target.

Step 3: Apply inverse FFT to the modified source spectrum.
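
In the paper's notation (a compact restatement: \mathcal{F}^A and \mathcal{F}^P denote the amplitude and phase of the FFT, and M_\beta is a mask equal to 1 only on the centered \beta-portion of the spectrum):

x^{s \to t} = \mathcal{F}^{-1}\big( M_\beta \circ \mathcal{F}^A(x^t) + (1 - M_\beta) \circ \mathcal{F}^A(x^s),\; \mathcal{F}^P(x^s) \big)

The phase of the source image, which carries the semantic content, is kept, while the low-frequency part of its amplitude, which roughly encodes "style" such as overall color and illumination, is taken from the target image.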

Figure 1 (spectral transfer): a source image is mapped to the target "style" without changing its semantic content. A randomly sampled target image provides the style: the low-frequency components of the source image's spectrum are swapped with the target's. The resulting "source image in target style" shows a perceptually smaller domain gap and improves transfer learning for semantic segmentation.

Figure 2 (effect of β, the size of the spectral region that is swapped): increasing β reduces the domain gap but introduces artifacts (see the enlarged insets). β is tuned up to the point where artifacts in the transformed image become noticeable.

Figure 3 (ablation study on the GTA5 → CityScapes task): segmentation networks trained with different β maintain similar performance.

The paper ablates three values, β = 0.01, 0.05, and 0.09; the results differ, but not by much.
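
To see how much of the spectrum these settings actually touch, note that in the code below the number of swapped low-frequency bins per side is b = floor(min(H, W) * β). For a hypothetical 512×1024 input, β = 0.01 gives b = floor(512 * 0.01) = 5, so the NumPy version replaces only an 11×11 (i.e. (2b+1)×(2b+1)) block around the centered DC component; even β = 0.09 swaps just a 93×93 block, a small fraction of the full spectrum.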

Key code:

# https://github.com/YanchaoYang/FDA/blob/master/utils/__init__.py

import torch
import numpy as np

# LB = 0.01, 0.05 or 0.09, depending on the data; the paper suggests keeping it below 0.15
# source to target, target to target
src_in_trg = FDA_source_to_target( src_img, trg_img, L=LB )   # still trained with the source label (src_lbl)
trg_in_trg = trg_img

def extract_ampl_phase(fft_im):
    # fft_im: size should be bx3xhxwx2, where the last dim holds (real, imag) as returned by torch.rfft
    fft_amp = fft_im[:,:,:,:,0]**2 + fft_im[:,:,:,:,1]**2
    fft_amp = torch.sqrt(fft_amp)
    fft_pha = torch.atan2( fft_im[:,:,:,:,1], fft_im[:,:,:,:,0] )
    return fft_amp, fft_pha

def low_freq_mutate( amp_src, amp_trg, L=0.1 ):
    # the spectrum here is NOT fftshifted, so the lowest frequencies sit at the four corners
    _, _, h, w = amp_src.size()
    b = (  np.floor(np.amin((h,w))*L)  ).astype(int)     # b = floor(min(h,w) * L)
    amp_src[:,:,0:b,0:b]     = amp_trg[:,:,0:b,0:b]      # top left
    amp_src[:,:,0:b,w-b:w]   = amp_trg[:,:,0:b,w-b:w]    # top right
    amp_src[:,:,h-b:h,0:b]   = amp_trg[:,:,h-b:h,0:b]    # bottom left
    amp_src[:,:,h-b:h,w-b:w] = amp_trg[:,:,h-b:h,w-b:w]  # bottom right
    return amp_src

def low_freq_mutate_np( amp_src, amp_trg, L=0.1 ):
    # shift the zero frequency to the center, then swap a (2b+1)x(2b+1) square around it
    a_src = np.fft.fftshift( amp_src, axes=(-2, -1) )
    a_trg = np.fft.fftshift( amp_trg, axes=(-2, -1) )

    _, h, w = a_src.shape
    b = (  np.floor(np.amin((h,w))*L)  ).astype(int)
    c_h = np.floor(h/2.0).astype(int)
    c_w = np.floor(w/2.0).astype(int)

    h1 = c_h-b
    h2 = c_h+b+1
    w1 = c_w-b
    w2 = c_w+b+1

    a_src[:,h1:h2,w1:w2] = a_trg[:,h1:h2,w1:w2]
    a_src = np.fft.ifftshift( a_src, axes=(-2, -1) )
    return a_src

def FDA_source_to_target(src_img, trg_img, L=0.1):
    # exchange magnitude
    # input: src_img, trg_img

    # get fft of both source and target
    # NOTE: torch.rfft / torch.irfft are the pre-1.8 API (removed in PyTorch 1.8),
    # so this function requires torch <= 1.7; see the torch.fft sketch below for newer versions
    fft_src = torch.rfft( src_img.clone(), signal_ndim=2, onesided=False )
    fft_trg = torch.rfft( trg_img.clone(), signal_ndim=2, onesided=False )

    # extract amplitude and phase of both ffts
    amp_src, pha_src = extract_ampl_phase( fft_src.clone())
    amp_trg, pha_trg = extract_ampl_phase( fft_trg.clone())

    # replace the low frequency amplitude part of source with that from target
    amp_src_ = low_freq_mutate( amp_src.clone(), amp_trg.clone(), L=L )

    # recompose fft of source
    fft_src_ = torch.zeros( fft_src.size(), dtype=torch.float, device=fft_src.device )  # keep the result on the input's device
    fft_src_[:,:,:,:,0] = torch.cos(pha_src.clone()) * amp_src_.clone()
    fft_src_[:,:,:,:,1] = torch.sin(pha_src.clone()) * amp_src_.clone()

    # get the recomposed image: source content, target style
    _, _, imgH, imgW = src_img.size()
    src_in_trg = torch.irfft( fft_src_, signal_ndim=2, onesided=False, signal_sizes=[imgH,imgW] )

    return src_in_trg
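
Since the function above depends on the old torch.rfft / torch.irfft API, here is a rough sketch (not the authors' code) of how the same amplitude swap could be expressed with the torch.fft module available in PyTorch >= 1.8:

def FDA_source_to_target_new(src_img, trg_img, L=0.1):
    # sketch for PyTorch >= 1.8; inputs are float tensors of shape (B, C, H, W)
    fft_src = torch.fft.fft2(src_img, dim=(-2, -1))
    fft_trg = torch.fft.fft2(trg_img, dim=(-2, -1))

    amp_src, pha_src = torch.abs(fft_src), torch.angle(fft_src)
    amp_trg = torch.abs(fft_trg)

    # same corner swap as low_freq_mutate above (spectrum is not fftshifted)
    _, _, h, w = amp_src.size()
    b = int(np.floor(min(h, w) * L))
    amp_src[..., 0:b, 0:b]     = amp_trg[..., 0:b, 0:b]      # top left
    amp_src[..., 0:b, w-b:w]   = amp_trg[..., 0:b, w-b:w]    # top right
    amp_src[..., h-b:h, 0:b]   = amp_trg[..., h-b:h, 0:b]    # bottom left
    amp_src[..., h-b:h, w-b:w] = amp_trg[..., h-b:h, w-b:w]  # bottom right

    # recompose amp * exp(i * phase), inverse FFT, keep the real part
    fft_src_ = torch.polar(amp_src, pha_src)
    return torch.fft.ifft2(fft_src_, dim=(-2, -1)).real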

def FDA_source_to_target_np( src_img, trg_img, L=0.1 ):
    # exchange magnitude
    # input: src_img, trg_img as (C, H, W) numpy arrays (move tensors to CPU / numpy first)

    src_img_np = src_img #.cpu().numpy()
    trg_img_np = trg_img #.cpu().numpy()

    # get fft of both source and target
    fft_src_np = np.fft.fft2( src_img_np, axes=(-2, -1) )
    fft_trg_np = np.fft.fft2( trg_img_np, axes=(-2, -1) )

    # extract amplitude and phase of both ffts
    amp_src, pha_src = np.abs(fft_src_np), np.angle(fft_src_np)
    amp_trg, pha_trg = np.abs(fft_trg_np), np.angle(fft_trg_np)

    # mutate the amplitude part of source with target
    amp_src_ = low_freq_mutate_np( amp_src, amp_trg, L=L )

    # mutated fft of source
    fft_src_ = amp_src_ * np.exp( 1j * pha_src )

    # get the mutated image
    src_in_trg = np.fft.ifft2( fft_src_, axes=(-2, -1) )
    src_in_trg = np.real(src_in_trg)

    return src_in_trg
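
A minimal end-to-end sketch of the NumPy path (file names, the β value, and the use of Pillow are illustrative assumptions, not part of the repo; any two RGB images of the same size will do):

import numpy as np
from PIL import Image

# load two same-sized RGB images (hypothetical file names)
src = np.asarray(Image.open("source.png").convert("RGB"), dtype=np.float64)
trg = np.asarray(Image.open("target.png").convert("RGB"), dtype=np.float64)

# the helpers above expect channel-first arrays: (C, H, W)
src_chw = src.transpose(2, 0, 1)
trg_chw = trg.transpose(2, 0, 1)

src_in_trg = FDA_source_to_target_np(src_chw, trg_chw, L=0.01)

# back to (H, W, C); the recomposed image can leave [0, 255], so clip before saving
out = np.clip(src_in_trg.transpose(1, 2, 0), 0, 255).astype(np.uint8)
Image.fromarray(out).save("source_in_target_style.png")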
