This article introduces deconvolution, also known as transposed convolution, and verifies its computation using TensorFlow and PyTorch.
Series index:
https://blog.csdn.net/oBrightLamp/article/details/85067981
Transposed convolution (transpose convolution), sometimes called deconvolution, is a key component of fully convolutional networks (FCN, Fully Convolutional Networks).
In an ordinary convolutional neural network, convolutions extract features and pooling shrinks the spatial size; these operations are also called down-sampling.
When a fully convolutional network is used for semantic segmentation, the shrunken feature maps must be mapped back to the size of the original input.
We need a corresponding operation that enlarges the size and decodes the feature maps back to the original resolution.
The counterpart of convolution is the transposed convolution.
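The size relationship between the two operations can be written down directly. For a "valid" convolution (no padding) with input size i, kernel size k, and stride s, the output size is (i - k) // s + 1; the transposed convolution inverts this shape arithmetic, producing (i - 1) * s + k. A minimal sketch (the function names are illustrative, not a library API), using the 7x5 input, kernel 3, stride 2 shapes that appear in the experiments below:

```python
def conv_out(i, k, s):
    # Output size of a "valid" (unpadded) convolution.
    return (i - k) // s + 1

def conv_transpose_out(i, k, s):
    # Output size of the corresponding transposed convolution.
    return (i - 1) * s + k

# Input 7x5, kernel 3, stride 2: convolution shrinks 7x5 -> 3x2 ...
assert conv_out(7, 3, 2) == 3 and conv_out(5, 3, 2) == 2
# ... and the transposed convolution maps 3x2 back to 7x5.
assert conv_transpose_out(3, 3, 2) == 7 and conv_transpose_out(2, 3, 2) == 5
```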
Since a convolution can be rewritten as an equivalent affine operation, forward propagation yields the loss value e:

$$Y = XW^T, \qquad e = \mathrm{forward}(Y)$$
The gradient of the loss e with respect to X is:

$$\frac{de}{dX} = \frac{de}{dY} W$$
In a sense, $de/dY$ can be read as the upstream gradient feature map, and $de/dX$ as the gradient feature map decoded by the matrix $W$.
$de/dX$ has the same shape as $X$, and $de/dY$ has the same shape as $Y$.
Inspired by this backward-pass computation, transposed convolution uses the same feature-map decoding process:

$$X' = YW'$$
Its computation is identical to the backward pass of a convolution; only the decoding matrix differs.
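The equivalence above can be checked by hand in one dimension. Below is a minimal NumPy sketch: a 1-D valid convolution is written as an affine map Y = A X, the backward pass multiplies by A^T, and the result coincides with a full convolution of the upstream gradient with the same kernel (the variable names are illustrative):

```python
import numpy as np

# 1-D valid convolution (stride 1) as an affine map y = A @ x:
# kernel w of size 3, input x of size 5 -> output of size 3.
w = np.array([1.0, 2.0, 3.0])
x = np.arange(5, dtype=np.float64)

# Build the banded convolution matrix A (3 x 5).
A = np.zeros((3, 5))
for i in range(3):
    A[i, i:i + 3] = w

y = A @ x          # forward pass
dy = np.ones(3)    # pretend upstream gradient de/dY
dx = A.T @ dy      # backward pass: de/dX = A^T (de/dY)

# The transposed convolution of dy with the same kernel (a "full"
# convolution, growing size 3 back to size 5) gives the same result.
x_up = np.convolve(dy, w)
assert np.allclose(dx, x_up)
```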
import numpy as np
import tensorflow as tf

tf.enable_eager_execution()
tf.set_random_seed(123)
np.set_printoptions(8, suppress=True, linewidth=120)
np.random.seed(123)

# Forward convolution: (1, 3, 7, 5) -> (1, 2, 3, 2)
conv = tf.layers.Conv2D(
    filters=2, kernel_size=3, strides=(2, 2),
    data_format="channels_first")

x_tf = tf.constant(np.random.random((1, 3, 7, 5)))
dy_tf = tf.constant(np.random.random((1, 2, 3, 2)))

# Backward pass: gradient of y w.r.t. x, weighted by the upstream gradient dy
with tf.GradientTape() as t:
    t.watch(x_tf)
    y_tf = conv(x_tf)
dy_dx = t.gradient(y_tf, x_tf, dy_tf)

# Transposed convolution initialized with the same kernel as conv
conv_trans = tf.layers.Conv2DTranspose(
    kernel_initializer=tf.constant_initializer(
        np.array(conv.get_weights()[0])),
    filters=3, kernel_size=3, strides=(2, 2),
    data_format="channels_first")
trans_out = conv_trans(dy_tf)

print("Conv2D backward dx")
print(dy_dx.numpy())
print()
print("Conv2DTranspose")
print(trans_out.numpy())
"""
Conv2D backward dx
[[[[-0.26440486 0.26833869 -0.1497194 0.25107463 0.22341981]
[-0.11080054 -0.17643584 -0.349382 -0.27533132 -0.3187842 ]
[-0.42400125 0.47303904 -0.62131251 0.41661293 -0.02865214]
[-0.21167667 -0.05395512 -0.39484159 -0.24876052 -0.33796068]
[-0.38350013 0.30566768 -0.57411525 0.364068 -0.21861426]
[-0.14180074 -0.04180731 -0.33559074 -0.01265906 -0.16946705]
[-0.13812166 0.02646263 -0.33804969 0.01576805 -0.1967117 ]]
[[ 0.11763588 0.21558463 0.06459053 0.2714485 -0.1065533 ]
[-0.04558774 0.03794783 0.03351889 0.07328753 -0.08867673]
[ 0.22252216 0.17373468 0.31926596 0.30095783 0.0281614 ]
[-0.22279067 -0.01495894 0.13037753 0.0594301 -0.04183767]
[ 0.27010557 0.03643737 0.31739757 0.10868841 0.101587 ]
[-0.14653154 -0.00827151 -0.06755262 -0.01766256 0.11459037]
[ 0.14272034 -0.06190039 0.2653872 -0.07000048 0.11934594]]
[[-0.02599829 -0.13185001 -0.32674567 -0.08481018 -0.22130249]
[ 0.07681449 0.21004136 0.17081842 0.1788548 0.24557813]
[-0.07103747 -0.16904703 -0.57256872 -0.00453903 -0.2766294 ]
[ 0.13991308 0.34539933 -0.01511208 0.23340967 0.18819516]
[-0.23274997 -0.10187769 -0.42749974 -0.05479129 -0.21957189]
[ 0.09386354 0.23249821 0.03583443 0.23084185 -0.09312374]
[-0.21560568 0.05434721 -0.29221142 0.04344786 -0.08432524]]]]
Conv2DTranspose
[[[[-0.26440486 0.26833869 -0.1497194 0.25107463 0.22341981]
[-0.11080054 -0.17643584 -0.349382 -0.27533132 -0.3187842 ]
[-0.42400125 0.47303904 -0.62131251 0.41661293 -0.02865214]
[-0.21167667 -0.05395512 -0.39484159 -0.24876052 -0.33796068]
[-0.38350013 0.30566768 -0.57411525 0.364068 -0.21861426]
[-0.14180074 -0.04180731 -0.33559074 -0.01265906 -0.16946705]
[-0.13812166 0.02646263 -0.33804969 0.01576805 -0.1967117 ]]
[[ 0.11763588 0.21558463 0.06459053 0.2714485 -0.1065533 ]
[-0.04558774 0.03794783 0.03351889 0.07328753 -0.08867673]
[ 0.22252216 0.17373468 0.31926596 0.30095783 0.0281614 ]
[-0.22279067 -0.01495894 0.13037753 0.0594301 -0.04183767]
[ 0.27010557 0.03643737 0.31739757 0.10868841 0.101587 ]
[-0.14653154 -0.00827151 -0.06755262 -0.01766256 0.11459037]
[ 0.14272034 -0.06190039 0.2653872 -0.07000048 0.11934594]]
[[-0.02599829 -0.13185001 -0.32674567 -0.08481018 -0.22130249]
[ 0.07681449 0.21004136 0.17081842 0.1788548 0.24557813]
[-0.07103747 -0.16904703 -0.57256872 -0.00453903 -0.2766294 ]
[ 0.13991308 0.34539933 -0.01511208 0.23340967 0.18819516]
[-0.23274997 -0.10187769 -0.42749974 -0.05479129 -0.21957189]
[ 0.09386354 0.23249821 0.03583443 0.23084185 -0.09312374]
[-0.21560568 0.05434721 -0.29221142 0.04344786 -0.08432524]]]]
"""
import torch
import numpy as np

np.random.seed(123)
np.set_printoptions(precision=8, suppress=True, linewidth=120)
torch.manual_seed(123)

x_torch = torch.tensor(
    np.random.randn(1, 3, 7, 5).astype(np.float32),
    requires_grad=True)
dy = torch.tensor(
    np.random.randn(1, 2, 3, 2).astype(np.float32))

# Forward convolution and its transposed counterpart, sharing the same kernel
downsample = torch.nn.Conv2d(3, 2, 3, stride=2, bias=False)
upsample = torch.nn.ConvTranspose2d(2, 3, 3, stride=2, bias=False)
upsample.weight = downsample.weight

# Backward pass: x_torch.grad accumulates de/dX for the upstream gradient dy
y = downsample(x_torch)
y.backward(dy)

# Transposed convolution applied directly to dy
up_output = upsample(dy)

print("Conv2D backward dx")
print(x_torch.grad.data.numpy())
print()
print("Conv2DTranspose")
print(up_output.data.numpy())
"""
Conv2D backward dx
[[[[ 0.16825794 0.19124855 -0.33899829 -0.11185936 0.11329053]
[ 0.20666172 0.10861827 0.00129423 -0.09700691 -0.03632798]
[ 0.30875325 0.25878775 0.07577793 0.01328135 -0.22357774]
[-0.05278662 0.35138682 -0.02333573 0.02318966 0.14505014]
[ 0.46448994 0.39769357 0.18331857 -0.11662451 0.20657384]
[-0.07072078 0.28886762 -0.45937026 -0.05585691 -0.17530707]
[ 0.24664924 0.29023114 0.1393536 -0.15389174 0.08684592]]
[[-0.1993071 -0.07106322 -0.03068739 0.02783359 0.11258031]
[ 0.01023537 -0.0486452 -0.16939156 0.0054528 0.06320697]
[-0.26690811 0.32631955 -0.3382079 -0.24664044 -0.07278823]
[ 0.31050849 0.19940096 0.02310865 -0.0861482 -0.13465518]
[-0.12364334 0.4635199 -0.24362773 0.22820684 0.19334255]
[ 0.2657707 0.17743842 0.13708432 0.09919833 0.1717405 ]
[ 0.05583961 0.28651747 -0.16637257 -0.18717323 -0.03308342]]
[[ 0.00604475 -0.03519516 0.0296074 0.03315877 0.00219124]
[ 0.21434575 0.14805993 -0.02838825 -0.07707279 -0.01687056]
[ 0.13658504 -0.07158607 -0.2082321 -0.02648996 0.01859709]
[ 0.0191915 -0.0389997 -0.12270755 0.14160722 0.14834535]
[ 0.25714236 -0.19829705 0.00305076 0.08666279 -0.06087164]
[-0.00976412 -0.05168293 -0.52960438 -0.18384263 -0.17310025]
[ 0.15100177 -0.08521508 0.13797237 -0.08828223 -0.06531012]]]]
Conv2DTranspose
[[[[ 0.16825794 0.19124855 -0.33899829 -0.11185936 0.11329053]
[ 0.20666172 0.10861827 0.00129423 -0.09700691 -0.03632798]
[ 0.30875325 0.25878775 0.07577793 0.01328135 -0.22357774]
[-0.05278662 0.35138682 -0.02333573 0.02318966 0.14505014]
[ 0.46448994 0.39769357 0.18331857 -0.11662451 0.20657384]
[-0.07072078 0.28886762 -0.45937026 -0.05585691 -0.17530707]
[ 0.24664924 0.29023114 0.1393536 -0.15389174 0.08684592]]
[[-0.1993071 -0.07106322 -0.03068739 0.02783359 0.11258031]
[ 0.01023537 -0.0486452 -0.16939156 0.0054528 0.06320697]
[-0.26690811 0.32631955 -0.3382079 -0.24664044 -0.07278823]
[ 0.31050849 0.19940096 0.02310865 -0.0861482 -0.13465518]
[-0.12364334 0.4635199 -0.24362773 0.22820684 0.19334255]
[ 0.2657707 0.17743842 0.13708432 0.09919833 0.1717405 ]
[ 0.05583961 0.28651747 -0.16637257 -0.18717323 -0.03308342]]
[[ 0.00604475 -0.03519516 0.0296074 0.03315877 0.00219124]
[ 0.21434575 0.14805993 -0.02838825 -0.07707279 -0.01687056]
[ 0.13658504 -0.07158607 -0.2082321 -0.02648996 0.01859709]
[ 0.0191915 -0.0389997 -0.12270755 0.14160722 0.14834535]
[ 0.25714236 -0.19829705 0.00305076 0.08666279 -0.06087164]
[-0.00976412 -0.05168293 -0.52960438 -0.18384263 -0.17310025]
[ 0.15100177 -0.08521508 0.13797237 -0.08828223 -0.06531012]]]]
"""