Translation: CVPR 2020 Domain Adaptation / Semantic Segmentation: FDA: Fourier Domain Adaptation for Semantic Segmentation (HheeFish's blog, CSDN)
Paper: https://openaccess.thecvf.com/content_CVPR_2020/papers/Yang_FDA_Fourier_Domain_Adaptation_for_Semantic_Segmentation_CVPR_2020_paper.pdf
Code: https://github.com/YanchaoYang/FDA/
Domain adaptation via style transfer made easy using the Fourier transform: FDA needs no deep network for style transfer and involves no adversarial training. The proposed Fourier Domain Adaptation method works in three steps:
Step 1: Apply FFT to source and target images.
Step 2: Replace the low frequency part of the source amplitude with that from the target.
Step 3: Apply inverse FFT to the modified source spectrum.
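The three steps can be written as one formula. The notation below is my own, loosely following the paper, so treat the symbol names as an assumption. $\mathcal{F}$ is the 2-D FFT. $\mathcal{F}^A$ and $\mathcal{F}^P$ are its amplitude and phase. $M_\beta$ is a mask that equals 1 inside a centred low-frequency window of half-size $\lfloor \beta \min(H, W) \rfloor$ and 0 elsewhere. Under these definitions, the source image $x^s$ rendered in the style of a target image $x^t$ is:

$$
x^{s \to t} = \mathcal{F}^{-1}\Big( \big[ M_\beta \circ \mathcal{F}^A(x^t) + (1 - M_\beta) \circ \mathcal{F}^A(x^s) \big] \, e^{\, i \, \mathcal{F}^P(x^s)} \Big)
$$

Step 1 computes $\mathcal{F}(x^s)$ and $\mathcal{F}(x^t)$. Step 2 is the masked mix of amplitudes inside the brackets. Step 3 is the outer $\mathcal{F}^{-1}$. The source phase is kept throughout.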
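Spectral transfer: the source image is mapped to the target "style" without changing its semantic content. A randomly sampled target image provides the style. The low-frequency part of the source image's amplitude spectrum is swapped with the corresponding part of the target's spectrum. The resulting "source image in target style" shows a perceptually smaller domain gap and improves transfer learning for semantic segmentation.

Effect of the size β of the region in which the spectrum is swapped: increasing β reduces the domain gap but introduces artifacts (see the zoomed-in insets in the paper's figure). β is tuned up to the point where artifacts in the transformed images become noticeable.

Ablation study on the GTA5→CityScapes task: segmentation networks trained with different β reach similar performance. The paper uses three ablation values, β = 0.01, 0.05 and 0.09. The results differ, but not by much.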
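In the key code, β is the argument `L`, and it is converted into a pixel half-width `b = floor(min(H, W) * L)` of the swapped window. The helper `window_half_size` is a hypothetical name used only for this sketch. The 512×1024 image size is an arbitrary illustrative choice, not a setting taken from the paper.

```python
import numpy as np

def window_half_size(h, w, beta):
    # Same rule as the key code: b = floor(min(h, w) * beta)
    return int(np.floor(min(h, w) * beta))

# Hypothetical 512 x 1024 (height x width) input, for illustration only.
for beta in (0.01, 0.05, 0.09):
    print(beta, window_half_size(512, 1024, beta))
# 0.01 5
# 0.05 25
# 0.09 46
```

Even the largest ablation value swaps only a small block of bins around the zero frequency, which fits the observation that the style change is mostly global (colour, brightness, contrast).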
Key code:
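```python
# https://github.com/YanchaoYang/FDA/blob/master/utils/__init__.py
import torch
import numpy as np

# Usage in the training loop (src_img, trg_img: float tensors of shape Bx3xHxW):
# LB = 0.01, 0.05 or 0.09 depending on the setting; the paper suggests keeping it below 0.15
# src_in_trg = FDA_source_to_target(src_img, trg_img, L=LB)  # source to target, paired with src_lbl
# trg_in_trg = trg_img                                          # target to target
```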
```python
def extract_ampl_phase(fft_im):
    # fft_im: real tensor of shape Bx3xHxWx2 (last dim = real and imaginary parts)
    fft_amp = fft_im[:, :, :, :, 0] ** 2 + fft_im[:, :, :, :, 1] ** 2
    fft_amp = torch.sqrt(fft_amp)  # amplitude |F|
    fft_pha = torch.atan2(fft_im[:, :, :, :, 1], fft_im[:, :, :, :, 0])  # phase angle(F)
    return fft_amp, fft_pha
```
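`extract_ampl_phase` works on the old layout, where a complex spectrum is stored as a trailing real/imaginary pair. The sketch below restates its formulas in NumPy instead of torch, so it runs without PyTorch, and uses a dummy random batch. It checks that these formulas agree with the complex helpers `np.abs` and `np.angle`. It also checks that amplitude and phase together lose no information.

```python
import numpy as np

rng = np.random.default_rng(0)
# Dummy Bx3xHxW batch and its 2-D FFT, stored as a trailing real/imag pair
# (the Bx3xHxWx2 layout that extract_ampl_phase expects).
x = rng.standard_normal((2, 3, 8, 8))
F = np.fft.fft2(x, axes=(-2, -1))
pair = np.stack([F.real, F.imag], axis=-1)

amp = np.sqrt(pair[..., 0] ** 2 + pair[..., 1] ** 2)  # same formula as extract_ampl_phase
pha = np.arctan2(pair[..., 1], pair[..., 0])

print(np.allclose(amp, np.abs(F)), np.allclose(pha, np.angle(F)))  # True True
# amplitude * e^{i * phase} rebuilds the spectrum; the inverse FFT rebuilds x
x_back = np.fft.ifft2(amp * np.exp(1j * pha), axes=(-2, -1)).real
print(np.allclose(x_back, x))  # True
```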
```python
def low_freq_mutate(amp_src, amp_trg, L=0.1):
    # Swap low-frequency amplitudes in the unshifted spectrum (no fftshift),
    # where the low frequencies sit in the four corners.
    _, _, h, w = amp_src.size()
    b = (np.floor(np.amin((h, w)) * L)).astype(int)  # get b
    amp_src[:, :, 0:b, 0:b] = amp_trg[:, :, 0:b, 0:b]                  # top left
    amp_src[:, :, 0:b, w - b:w] = amp_trg[:, :, 0:b, w - b:w]          # top right
    amp_src[:, :, h - b:h, 0:b] = amp_trg[:, :, h - b:h, 0:b]          # bottom left
    amp_src[:, :, h - b:h, w - b:w] = amp_trg[:, :, h - b:h, w - b:w]  # bottom right
    return amp_src
```
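Why the four corners? In the unshifted FFT layout, index 0 is the zero frequency and the negative frequencies wrap around to the end of each axis. The sketch below uses a hypothetical `corner_mask` helper and a tiny 8×8 grid with b = 2. It marks the bins that `low_freq_mutate` overwrites and reads off their integer frequencies with `np.fft.fftfreq`. It then shows that `np.fft.fftshift` turns the four corners into one centred 2b×2b block.

```python
import numpy as np

def corner_mask(h, w, b):
    # Boolean mask of the four b x b corners that low_freq_mutate overwrites
    m = np.zeros((h, w), dtype=bool)
    m[0:b, 0:b] = True          # top left
    m[0:b, w - b:w] = True      # top right
    m[h - b:h, 0:b] = True      # bottom left
    m[h - b:h, w - b:w] = True  # bottom right
    return m

h, w, b = 8, 8, 2
m = corner_mask(h, w, b)
# Integer frequency of each row/column in the unshifted layout: [0, 1, 2, 3, -4, -3, -2, -1]
fy = np.fft.fftfreq(h, d=1.0 / h)
fx = np.fft.fftfreq(w, d=1.0 / w)
FY, FX = np.meshgrid(fy, fx, indexing="ij")
print(FY[m].min(), FY[m].max())  # -2.0 1.0 -> frequencies -b .. b-1

rows, cols = np.nonzero(np.fft.fftshift(m))
print(rows.min(), rows.max(), cols.min(), cols.max())  # 2 5 2 5 -> centred 4x4 block
```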
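```python
def low_freq_mutate_np(amp_src, amp_trg, L=0.1):
    # amp_src, amp_trg: CxHxW amplitude spectra in the unshifted FFT layout
    # move the zero frequency to the centre (fftshift returns a new array)
    a_src = np.fft.fftshift(amp_src, axes=(-2, -1))
    a_trg = np.fft.fftshift(amp_trg, axes=(-2, -1))
    _, h, w = a_src.shape
    b = (np.floor(np.amin((h, w)) * L)).astype(int)
    c_h = np.floor(h / 2.0).astype(int)  # centre row (zero frequency)
    c_w = np.floor(w / 2.0).astype(int)  # centre column (zero frequency)
    h1 = c_h - b
    h2 = c_h + b + 1
    w1 = c_w - b
    w2 = c_w + b + 1
    # swap the centred (2b+1)x(2b+1) low-frequency window
    a_src[:, h1:h2, w1:w2] = a_trg[:, h1:h2, w1:w2]
    # move the zero frequency back to the corner
    a_src = np.fft.ifftshift(a_src, axes=(-2, -1))
    return a_src
```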
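The NumPy version swaps a centred window covering frequencies −b … b. That is (2b+1)² bins per channel, one more row and column than the four b×b corners swapped by `low_freq_mutate`. The sketch below checks this using a copy of the function above. It uses dummy amplitudes, all zeros for the source and all ones for the target, so that the number of replaced bins is just the sum of the output.

```python
import numpy as np

def low_freq_mutate_np(amp_src, amp_trg, L=0.1):
    # Copy of the function above (condensed): swap a centred window of the shifted spectrum
    a_src = np.fft.fftshift(amp_src, axes=(-2, -1))
    a_trg = np.fft.fftshift(amp_trg, axes=(-2, -1))
    _, h, w = a_src.shape
    b = int(np.floor(min(h, w) * L))
    c_h, c_w = h // 2, w // 2
    win = (slice(None), slice(c_h - b, c_h + b + 1), slice(c_w - b, c_w + b + 1))
    a_src[win] = a_trg[win]
    return np.fft.ifftshift(a_src, axes=(-2, -1))

amp_src = np.zeros((3, 32, 32))  # dummy source amplitudes
amp_trg = np.ones((3, 32, 32))   # dummy target amplitudes
out = low_freq_mutate_np(amp_src, amp_trg, L=0.1)  # b = floor(32 * 0.1) = 3
print(int(out.sum()))                # 147 = 3 channels * (2*3 + 1)**2
print(out[0, 0, 0], out[0, 16, 16])  # 1.0 0.0 -> DC bin replaced, Nyquist bin kept
print(amp_src.sum())                 # 0.0 -> the input array is not modified

out0 = low_freq_mutate_np(amp_src, amp_trg, L=0.01)  # b = floor(0.32) = 0
print(int(out0.sum()))               # 3 -> one DC bin per channel is still swapped
```

The last case is a small behavioural difference between the two versions. When `L` is so small that b = 0, this NumPy version still swaps the zero-frequency (DC) bin, while the torch `low_freq_mutate` swaps nothing.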
```python
def FDA_source_to_target(src_img, trg_img, L=0.1):
    # Exchange the low-frequency amplitude of src_img with that of trg_img.
    # input: src_img, trg_img as float tensors of shape Bx3xHxW
    # get FFT of both source and target as Bx3xHxWx2 real/imag pairs
    # (the original used torch.rfft(..., signal_ndim=2, onesided=False), removed in PyTorch 1.8)
    fft_src = torch.view_as_real(torch.fft.fft2(src_img.clone(), dim=(-2, -1)))
    fft_trg = torch.view_as_real(torch.fft.fft2(trg_img.clone(), dim=(-2, -1)))
    # extract amplitude and phase of both FFTs
    amp_src, pha_src = extract_ampl_phase(fft_src.clone())
    amp_trg, pha_trg = extract_ampl_phase(fft_trg.clone())
    # replace the low-frequency amplitude part of source with that from target
    amp_src_ = low_freq_mutate(amp_src.clone(), amp_trg.clone(), L=L)
    # recompose FFT of source (zeros_like keeps the dtype and device of the input)
    fft_src_ = torch.zeros_like(fft_src)
    fft_src_[:, :, :, :, 0] = torch.cos(pha_src.clone()) * amp_src_.clone()
    fft_src_[:, :, :, :, 1] = torch.sin(pha_src.clone()) * amp_src_.clone()
    # get the recomposed image: source content, target style
    # (replaces torch.irfft(..., onesided=False, signal_sizes=[imgH, imgW]))
    src_in_trg = torch.fft.ifft2(torch.view_as_complex(fft_src_), dim=(-2, -1)).real
    return src_in_trg
```
```python
def FDA_source_to_target_np(src_img, trg_img, L=0.1):
    # Exchange the low-frequency amplitude of src_img with that of trg_img.
    # input: src_img, trg_img as float arrays of shape CxHxW
    src_img_np = src_img  # for a torch tensor: src_img.cpu().numpy()
    trg_img_np = trg_img  # for a torch tensor: trg_img.cpu().numpy()
    # get FFT of both source and target
    fft_src_np = np.fft.fft2(src_img_np, axes=(-2, -1))
    fft_trg_np = np.fft.fft2(trg_img_np, axes=(-2, -1))
    # extract amplitude and phase of both FFTs
    amp_src, pha_src = np.abs(fft_src_np), np.angle(fft_src_np)
    amp_trg, pha_trg = np.abs(fft_trg_np), np.angle(fft_trg_np)
    # mutate the amplitude part of source with target
    amp_src_ = low_freq_mutate_np(amp_src, amp_trg, L=L)
    # mutated FFT of source: new amplitude, original source phase
    fft_src_ = amp_src_ * np.exp(1j * pha_src)
    # get the mutated image and keep its real part
    src_in_trg = np.fft.ifft2(fft_src_, axes=(-2, -1))
    src_in_trg = np.real(src_in_trg)
    return src_in_trg
```
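One concrete effect of the swap is easy to verify. The zero-frequency (DC) bin always lies inside the swapped window, and a non-negative source image has DC phase 0. The output's per-channel mean therefore becomes the target's per-channel mean, so the global colour balance is taken from the target. The sketch below checks this with a condensed copy of the NumPy pipeline above. The random "images" and the colour offset are dummy data chosen only for illustration.

```python
import numpy as np

def FDA_source_to_target_np(src_img, trg_img, L=0.1):
    # Condensed copy of the NumPy pipeline above; images are float arrays of shape (C, H, W)
    fft_src = np.fft.fft2(src_img, axes=(-2, -1))
    fft_trg = np.fft.fft2(trg_img, axes=(-2, -1))
    amp_src, pha_src = np.abs(fft_src), np.angle(fft_src)
    a_src = np.fft.fftshift(amp_src, axes=(-2, -1))
    a_trg = np.fft.fftshift(np.abs(fft_trg), axes=(-2, -1))
    _, h, w = a_src.shape
    b = int(np.floor(min(h, w) * L))
    c_h, c_w = h // 2, w // 2
    win = (slice(None), slice(c_h - b, c_h + b + 1), slice(c_w - b, c_w + b + 1))
    a_src[win] = a_trg[win]  # swap the low-frequency amplitudes
    amp_src_ = np.fft.ifftshift(a_src, axes=(-2, -1))
    return np.real(np.fft.ifft2(amp_src_ * np.exp(1j * pha_src), axes=(-2, -1)))

rng = np.random.default_rng(0)
# Dummy non-negative "images"; the target is brighter and has a different colour balance.
src = rng.uniform(0.0, 100.0, size=(3, 64, 64))
trg = rng.uniform(0.0, 100.0, size=(3, 64, 64)) + np.array([50.0, 0.0, 20.0])[:, None, None]
out = FDA_source_to_target_np(src, trg, L=0.05)
print(out.shape)  # (3, 64, 64)
print(np.allclose(out.mean(axis=(1, 2)), trg.mean(axis=(1, 2))))  # True
```

Higher-frequency content, and the phase that carries the scene layout, still comes from the source. That is why the source labels can be reused unchanged.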