Understanding Convolution and Deconvolution (Transposed Convolution) by Comparing TensorFlow and PyTorch

Abstract

This article introduces deconvolution, also known as transposed convolution, and verifies its computation using TensorFlow and PyTorch.

Related

Series index:
https://blog.csdn.net/oBrightLamp/article/details/85067981

Main Text

1. Application Scenarios

Transposed convolution (Transpose Convolution), sometimes called deconvolution (Deconvolution), is a key component of fully convolutional networks (FCN, Fully Convolutional Networks).

In an ordinary convolutional neural network, convolutions extract features and pooling shrinks the spatial size; these operations are also called down-sampling.

When a fully convolutional network is applied to semantic segmentation, the shrunken feature maps must be mapped back to the size of the original input.

We therefore need a corresponding operation that enlarges the spatial size and decodes the feature maps back to the original resolution.

The counterpart of the convolution operation is the transposed convolution.
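As a quick illustration of this shrink-then-restore round trip, here is a minimal PyTorch sketch (the layer sizes are chosen for illustration only, they are not prescribed by the article):

```python
import torch

# A stride-2 conv shrinks the spatial size; a transposed conv with the
# same kernel size and stride maps it back to the original size.
x = torch.randn(1, 3, 7, 5)
down = torch.nn.Conv2d(3, 2, kernel_size=3, stride=2)
up = torch.nn.ConvTranspose2d(2, 3, kernel_size=3, stride=2)

y = down(x)
print(y.shape)      # torch.Size([1, 2, 3, 2])
print(up(y).shape)  # torch.Size([1, 3, 7, 5])
```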

2. Theory

Since a convolution can be rewritten as an equivalent affine operation, consider a forward pass that produces a loss value e:

Y = XW^T
e = forward(Y)

The gradient of the loss e with respect to X is:

de/dX = (de/dY) W

In a sense, de/dY can be viewed as the upstream gradient feature map, and de/dX as the gradient feature map decoded through the matrix W.

de/dX has the same size as X, and de/dY has the same size as Y.

Inspired by this backward-propagation computation, transposed convolution uses the same feature-map decoding procedure:

X' = Y W'

Its computation is identical to the backward pass of a convolution; the only difference is the decoding matrix that is used.
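The affine view above can be made concrete in a minimal 1D NumPy sketch (the sizes here are hypothetical, not taken from the verification code): write the valid cross-correlation as a matrix product Y = WX, then decode with W^T.

```python
import numpy as np

n, k = 5, 3
rng = np.random.default_rng(0)
x = rng.standard_normal(n)
w = rng.standard_normal(k)
m = n - k + 1  # output length: 3

# Each row of the affine matrix W holds the kernel shifted by one step.
W = np.zeros((m, n))
for i in range(m):
    W[i, i:i + k] = w

y = W @ x            # same as np.correlate(x, w, mode="valid")
x_decoded = W.T @ y  # back to length n, exactly the shape of de/dX

print(np.allclose(y, np.correlate(x, w, mode="valid")))        # True
print(np.allclose(x_decoded, np.convolve(y, w, mode="full")))  # True
```

The second check shows why the decode step is itself a convolution: multiplying by W^T is the same as a full convolution of y with the kernel, which is what a transposed-convolution layer computes.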

3. TensorFlow Verification

import numpy as np
import tensorflow as tf

tf.enable_eager_execution()
tf.set_random_seed(123)

np.set_printoptions(8, suppress=True, linewidth=120)
np.random.seed(123)

conv = tf.layers.Conv2D(
    filters=2, kernel_size=3, strides=(2, 2),
    data_format="channels_first")

x_tf = tf.constant(np.random.random((1, 3, 7, 5)))
dy_tf = tf.constant(np.random.random((1, 2, 3, 2)))

with tf.GradientTape() as t:
    t.watch(x_tf)
    y_tf = conv(x_tf)

dy_dx = t.gradient(y_tf, x_tf, dy_tf)

# Reuse the Conv2D kernel: Conv2D kernels are (kh, kw, in_ch, out_ch)
# = (3, 3, 3, 2) and Conv2DTranspose kernels are (kh, kw, out_ch, in_ch)
# = (3, 3, 3, 2), so the shapes coincide and the weights copy directly.
conv_trans = tf.layers.Conv2DTranspose(
    kernel_initializer=tf.constant_initializer(
        np.array(conv.get_weights()[0])),
    filters=3, kernel_size=3, strides=(2, 2),
    data_format="channels_first")

trans_out = conv_trans(dy_tf)

print("Conv2D backward dx")
print(dy_dx.numpy())
print()
print("Conv2DTranspose")
print(trans_out.numpy())

"""
Conv2D backward dx
[[[[-0.26440486  0.26833869 -0.1497194   0.25107463  0.22341981]
   [-0.11080054 -0.17643584 -0.349382   -0.27533132 -0.3187842 ]
   [-0.42400125  0.47303904 -0.62131251  0.41661293 -0.02865214]
   [-0.21167667 -0.05395512 -0.39484159 -0.24876052 -0.33796068]
   [-0.38350013  0.30566768 -0.57411525  0.364068   -0.21861426]
   [-0.14180074 -0.04180731 -0.33559074 -0.01265906 -0.16946705]
   [-0.13812166  0.02646263 -0.33804969  0.01576805 -0.1967117 ]]

  [[ 0.11763588  0.21558463  0.06459053  0.2714485  -0.1065533 ]
   [-0.04558774  0.03794783  0.03351889  0.07328753 -0.08867673]
   [ 0.22252216  0.17373468  0.31926596  0.30095783  0.0281614 ]
   [-0.22279067 -0.01495894  0.13037753  0.0594301  -0.04183767]
   [ 0.27010557  0.03643737  0.31739757  0.10868841  0.101587  ]
   [-0.14653154 -0.00827151 -0.06755262 -0.01766256  0.11459037]
   [ 0.14272034 -0.06190039  0.2653872  -0.07000048  0.11934594]]

  [[-0.02599829 -0.13185001 -0.32674567 -0.08481018 -0.22130249]
   [ 0.07681449  0.21004136  0.17081842  0.1788548   0.24557813]
   [-0.07103747 -0.16904703 -0.57256872 -0.00453903 -0.2766294 ]
   [ 0.13991308  0.34539933 -0.01511208  0.23340967  0.18819516]
   [-0.23274997 -0.10187769 -0.42749974 -0.05479129 -0.21957189]
   [ 0.09386354  0.23249821  0.03583443  0.23084185 -0.09312374]
   [-0.21560568  0.05434721 -0.29221142  0.04344786 -0.08432524]]]]

Conv2DTranspose
[[[[-0.26440486  0.26833869 -0.1497194   0.25107463  0.22341981]
   [-0.11080054 -0.17643584 -0.349382   -0.27533132 -0.3187842 ]
   [-0.42400125  0.47303904 -0.62131251  0.41661293 -0.02865214]
   [-0.21167667 -0.05395512 -0.39484159 -0.24876052 -0.33796068]
   [-0.38350013  0.30566768 -0.57411525  0.364068   -0.21861426]
   [-0.14180074 -0.04180731 -0.33559074 -0.01265906 -0.16946705]
   [-0.13812166  0.02646263 -0.33804969  0.01576805 -0.1967117 ]]

  [[ 0.11763588  0.21558463  0.06459053  0.2714485  -0.1065533 ]
   [-0.04558774  0.03794783  0.03351889  0.07328753 -0.08867673]
   [ 0.22252216  0.17373468  0.31926596  0.30095783  0.0281614 ]
   [-0.22279067 -0.01495894  0.13037753  0.0594301  -0.04183767]
   [ 0.27010557  0.03643737  0.31739757  0.10868841  0.101587  ]
   [-0.14653154 -0.00827151 -0.06755262 -0.01766256  0.11459037]
   [ 0.14272034 -0.06190039  0.2653872  -0.07000048  0.11934594]]

  [[-0.02599829 -0.13185001 -0.32674567 -0.08481018 -0.22130249]
   [ 0.07681449  0.21004136  0.17081842  0.1788548   0.24557813]
   [-0.07103747 -0.16904703 -0.57256872 -0.00453903 -0.2766294 ]
   [ 0.13991308  0.34539933 -0.01511208  0.23340967  0.18819516]
   [-0.23274997 -0.10187769 -0.42749974 -0.05479129 -0.21957189]
   [ 0.09386354  0.23249821  0.03583443  0.23084185 -0.09312374]
   [-0.21560568  0.05434721 -0.29221142  0.04344786 -0.08432524]]]]
"""

4. PyTorch Verification

import torch
import numpy as np

np.random.seed(123)
np.set_printoptions(precision=8, suppress=True, linewidth=120)

torch.manual_seed(123)

x_torch = torch.tensor(
    np.random.randn(1, 3, 7, 5).astype(np.float32),
    requires_grad=True)

# dy plays the role of the upstream gradient; it needs no grad tracking.
dy = torch.tensor(
    np.random.randn(1, 2, 3, 2).astype(np.float32))

downsample = torch.nn.Conv2d(3, 2, 3, stride=2, bias=False)
upsample = torch.nn.ConvTranspose2d(2, 3, 3, stride=2, bias=False)

# Conv2d weights are (out_ch, in_ch, kh, kw) = (2, 3, 3, 3) and
# ConvTranspose2d weights are (in_ch, out_ch, kh, kw) = (2, 3, 3, 3),
# so the same Parameter can be shared directly.
upsample.weight = downsample.weight

y = downsample(x_torch)
y.backward(dy)

up_output = upsample(dy)

print("Conv2D backward dx")
print(x_torch.grad.data.numpy())
print()
print("Conv2DTranspose")
print(up_output.data.numpy())

"""
Conv2D backward dx
[[[[ 0.16825794  0.19124855 -0.33899829 -0.11185936  0.11329053]
   [ 0.20666172  0.10861827  0.00129423 -0.09700691 -0.03632798]
   [ 0.30875325  0.25878775  0.07577793  0.01328135 -0.22357774]
   [-0.05278662  0.35138682 -0.02333573  0.02318966  0.14505014]
   [ 0.46448994  0.39769357  0.18331857 -0.11662451  0.20657384]
   [-0.07072078  0.28886762 -0.45937026 -0.05585691 -0.17530707]
   [ 0.24664924  0.29023114  0.1393536  -0.15389174  0.08684592]]

  [[-0.1993071  -0.07106322 -0.03068739  0.02783359  0.11258031]
   [ 0.01023537 -0.0486452  -0.16939156  0.0054528   0.06320697]
   [-0.26690811  0.32631955 -0.3382079  -0.24664044 -0.07278823]
   [ 0.31050849  0.19940096  0.02310865 -0.0861482  -0.13465518]
   [-0.12364334  0.4635199  -0.24362773  0.22820684  0.19334255]
   [ 0.2657707   0.17743842  0.13708432  0.09919833  0.1717405 ]
   [ 0.05583961  0.28651747 -0.16637257 -0.18717323 -0.03308342]]

  [[ 0.00604475 -0.03519516  0.0296074   0.03315877  0.00219124]
   [ 0.21434575  0.14805993 -0.02838825 -0.07707279 -0.01687056]
   [ 0.13658504 -0.07158607 -0.2082321  -0.02648996  0.01859709]
   [ 0.0191915  -0.0389997  -0.12270755  0.14160722  0.14834535]
   [ 0.25714236 -0.19829705  0.00305076  0.08666279 -0.06087164]
   [-0.00976412 -0.05168293 -0.52960438 -0.18384263 -0.17310025]
   [ 0.15100177 -0.08521508  0.13797237 -0.08828223 -0.06531012]]]]

Conv2DTranspose
[[[[ 0.16825794  0.19124855 -0.33899829 -0.11185936  0.11329053]
   [ 0.20666172  0.10861827  0.00129423 -0.09700691 -0.03632798]
   [ 0.30875325  0.25878775  0.07577793  0.01328135 -0.22357774]
   [-0.05278662  0.35138682 -0.02333573  0.02318966  0.14505014]
   [ 0.46448994  0.39769357  0.18331857 -0.11662451  0.20657384]
   [-0.07072078  0.28886762 -0.45937026 -0.05585691 -0.17530707]
   [ 0.24664924  0.29023114  0.1393536  -0.15389174  0.08684592]]

  [[-0.1993071  -0.07106322 -0.03068739  0.02783359  0.11258031]
   [ 0.01023537 -0.0486452  -0.16939156  0.0054528   0.06320697]
   [-0.26690811  0.32631955 -0.3382079  -0.24664044 -0.07278823]
   [ 0.31050849  0.19940096  0.02310865 -0.0861482  -0.13465518]
   [-0.12364334  0.4635199  -0.24362773  0.22820684  0.19334255]
   [ 0.2657707   0.17743842  0.13708432  0.09919833  0.1717405 ]
   [ 0.05583961  0.28651747 -0.16637257 -0.18717323 -0.03308342]]

  [[ 0.00604475 -0.03519516  0.0296074   0.03315877  0.00219124]
   [ 0.21434575  0.14805993 -0.02838825 -0.07707279 -0.01687056]
   [ 0.13658504 -0.07158607 -0.2082321  -0.02648996  0.01859709]
   [ 0.0191915  -0.0389997  -0.12270755  0.14160722  0.14834535]
   [ 0.25714236 -0.19829705  0.00305076  0.08666279 -0.06087164]
   [-0.00976412 -0.05168293 -0.52960438 -0.18384263 -0.17310025]
   [ 0.15100177 -0.08521508  0.13797237 -0.08828223 -0.06531012]]]]
"""
