CARAFE: 轻量级通用上采样算子
CARAFE 分为两个主要模块,分别是上采样核预测模块和特征重组模块。假设上采样倍率为 $\sigma$,超参数取 $k_{encoder}=3$、$k_{up}=5$(这是性能与计算量之间的折中)。
class CARAFE(nn.Module):
    """CARAFE: Content-Aware ReAssembly of FEatures (content-aware upsampling).

    The operator has two stages:

    1. Kernel prediction: a 1x1 conv compresses the input to ``inC // 4``
       channels. An ``encoder_size`` x ``encoder_size`` conv then predicts
       ``up_factor**2 * kernel_size**2`` channels. ``pixel_shuffle``
       rearranges these into one ``kernel_size**2`` reassembly kernel per
       output location, and softmax normalizes each kernel.
    2. Feature reassembly: each output location is the kernel-weighted sum
       of the ``kernel_size`` x ``kernel_size`` neighbourhood of its source
       pixel in the zero-padded input.

    A final 1x1 conv maps ``inC`` channels to ``outC`` channels.

    Args:
        inC: Number of input channels.
        outC: Number of output channels.
        kernel_size: Reassembly kernel size (k_up). Must be a positive odd
            integer.
        up_factor: Integer upsampling factor (sigma). Must be >= 1.
        encoder_size: Kernel size of the kernel-prediction conv (k_encoder).
            Must be a positive odd integer. Defaults to ``kernel_size``,
            which reproduces the original single-size behaviour.

    Raises:
        ValueError: If a kernel size is not a positive odd integer or
            ``up_factor`` is < 1.

    Shape:
        input ``(N, inC, H, W)`` -> output ``(N, outC, up_factor*H, up_factor*W)``.
    """

    def __init__(self, inC, outC, kernel_size=3, up_factor=2, encoder_size=None):
        super(CARAFE, self).__init__()
        if encoder_size is None:
            encoder_size = kernel_size
        # Padding of k // 2 preserves H and W only for odd k. An even size
        # would make the reshapes in forward() fail with an opaque error.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}")
        if encoder_size < 1 or encoder_size % 2 == 0:
            raise ValueError(
                f"encoder_size must be a positive odd integer, got {encoder_size}")
        if up_factor < 1:
            raise ValueError(f"up_factor must be >= 1, got {up_factor}")
        self.kernel_size = kernel_size
        self.up_factor = up_factor
        self.encoder_size = encoder_size
        # Channel compressor feeding the kernel-prediction encoder.
        self.down = nn.Conv2d(inC, inC // 4, 1)
        # Predicts S^2 reassembly kernels of size K^2 per input pixel.
        self.encoder = nn.Conv2d(inC // 4,
                                 self.up_factor ** 2 * self.kernel_size ** 2,
                                 self.encoder_size, 1, self.encoder_size // 2)
        self.out = nn.Conv2d(inC, outC, 1)

    def forward(self, in_tensor):
        """Upsample ``in_tensor`` from (N, C, H, W) to (N, outC, S*H, S*W)."""
        N, C, H, W = in_tensor.size()
        k, s = self.kernel_size, self.up_factor

        # --- Kernel prediction module ---
        kernel_tensor = self.down(in_tensor)                # (N, C//4, H, W)
        kernel_tensor = self.encoder(kernel_tensor)         # (N, S^2 * K^2, H, W)
        kernel_tensor = F.pixel_shuffle(kernel_tensor, s)   # (N, K^2, S*H, S*W)
        # Normalize each output location's K^2 weights so they sum to 1.
        kernel_tensor = F.softmax(kernel_tensor, dim=1)     # (N, K^2, S*H, S*W)
        # Group the S x S output locations that share one source pixel.
        kernel_tensor = kernel_tensor.unfold(2, s, step=s)  # (N, K^2, H, S*W, S)
        kernel_tensor = kernel_tensor.unfold(3, s, step=s)  # (N, K^2, H, W, S, S)
        kernel_tensor = kernel_tensor.reshape(N, k ** 2, H, W, s ** 2)  # (N, K^2, H, W, S^2)
        kernel_tensor = kernel_tensor.permute(0, 2, 3, 1, 4)  # (N, H, W, K^2, S^2)

        # --- Content-aware reassembly module ---
        # tensor.unfold(dim, size, step) extracts sliding windows along dim.
        pad = k // 2
        # Zero-pad so every source pixel has a full K x K neighbourhood.
        in_tensor = F.pad(in_tensor, pad=(pad, pad, pad, pad),
                          mode='constant', value=0)  # (N, C, H+2*pad, W+2*pad)
        in_tensor = in_tensor.unfold(2, k, step=1)   # (N, C, H, W+2*pad, K)
        in_tensor = in_tensor.unfold(3, k, step=1)   # (N, C, H, W, K, K)
        in_tensor = in_tensor.reshape(N, C, H, W, -1)  # (N, C, H, W, K^2)
        in_tensor = in_tensor.permute(0, 2, 3, 1, 4)   # (N, H, W, C, K^2)

        # Weighted neighbourhood sum for each of the S^2 sub-positions.
        out_tensor = torch.matmul(in_tensor, kernel_tensor)  # (N, H, W, C, S^2)
        out_tensor = out_tensor.reshape(N, H, W, -1)         # (N, H, W, C*S^2)
        out_tensor = out_tensor.permute(0, 3, 1, 2)          # (N, C*S^2, H, W)
        # Channel index c*S^2 + row*S + col matches pixel_shuffle's layout.
        out_tensor = F.pixel_shuffle(out_tensor, s)          # (N, C, S*H, S*W)
        out_tensor = self.out(out_tensor)                    # (N, outC, S*H, S*W)
        return out_tensor
if __name__ == '__main__':
    # Smoke test: map 20 -> 10 channels with 2x upsampling of a 10x10 map.
    data = torch.rand(4, 20, 10, 10)
    carafe = CARAFE(20, 10)
    print(carafe(data).size())  # expected: torch.Size([4, 10, 20, 20])