Verifying Matrix-Based Convolution and Its Backpropagation Against TensorFlow and PyTorch in Pure Python

Abstract

This article implements the matrix-based computation of the convolution function and its backpropagation in pure Python (plus NumPy), and verifies the results against TensorFlow and PyTorch.

Related

For the underlying theory and a detailed explanation, see the article:

Matrix-Based Computation of the Convolution Function and Backpropagation of Its Gradient

Series index:
https://blog.csdn.net/oBrightLamp/article/details/85067981

Main Text

1. The Batch2ConvMatrix and Conv2d classes
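
Batch2ConvMatrix implements the classic im2col transform: every kernel-sized window of the input batch is unrolled into one row of a 2-D matrix, so the forward convolution reduces to a single matrix product, out = x2m @ w2m.T + bias. In the backward pass, Conv2d computes dw = d_loss.T @ x2m and db by summing d_loss over output positions, then dx2m = d_loss @ w2m, which Batch2ConvMatrix.backward scatter-adds back onto the input positions.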

File: vanilla_nn/convolution_matrix.py

import numpy as np


class Batch2ConvMatrix:
    def __init__(self, stride, kernel_h, kernel_w):
        self.stride = stride
        self.kernel_h = kernel_h
        self.kernel_w = kernel_w

        self.x = None
        self.conv_size = None

    def __call__(self, x):
        # x: (batch, channels, height, width)
        self.x = x
        x_nums, x_channels, x_height, x_width = np.shape(self.x)

        # Output spatial size of a valid (unpadded) convolution.
        conv_height = int((x_height - self.kernel_h) / self.stride) + 1
        conv_width = int((x_width - self.kernel_w) / self.stride) + 1

        # One kernel-sized window per (sample, output row, output col, channel).
        scan = np.zeros((x_nums, conv_height, conv_width,
                         x_channels, self.kernel_h, self.kernel_w))

        for n in range(x_nums):
            for h in range(conv_height):
                for w in range(conv_width):
                    for c in range(x_channels):
                        start_h = h * self.stride
                        start_w = w * self.stride
                        end_h = start_h + self.kernel_h
                        end_w = start_w + self.kernel_w

                        scan[n, h, w, c] = \
                            x[n, c, start_h:end_h, start_w:end_w]

        # Flatten: one row per output position, one column per kernel weight.
        conv_matrix = scan.reshape(x_nums * conv_height * conv_width, -1)
        self.conv_size = [x_nums, x_channels, conv_height, conv_width]
        return conv_matrix

    def backward(self, dx2m):
        # Scatter-add each window's gradient back onto the input positions
        # it was copied from; overlapping windows accumulate.
        dx = np.zeros_like(self.x)
        kh = self.kernel_h
        kw = self.kernel_w
        xn, xc, ch, cw = self.conv_size

        dx2m = dx2m.reshape((xn, ch, cw, xc, kh, kw))

        for n in range(xn):
            for c in range(xc):
                for h in range(ch):
                    for w in range(cw):
                        start_h = h * self.stride
                        start_w = w * self.stride
                        end_h = start_h + self.kernel_h
                        end_w = start_w + self.kernel_w

                        dx[n, c][start_h:end_h, start_w:end_w] \
                            += dx2m[n, h, w, c]

        return dx


class Conv2d:
    def __init__(self, stride, weight=None, bias=None):
        self.stride = stride
        self.weight = weight
        self.bias = bias

        self.b2m = None
        self.x2m = None
        self.w2m = None

        self.dw = None
        self.db = None

    def __call__(self, x):
        # weight: (out_channels, in_channels, kernel_h, kernel_w)
        wn, wc, wh, ww = np.shape(self.weight)

        if self.b2m is None:
            self.b2m = Batch2ConvMatrix(self.stride, wh, ww)

        x2m = self.b2m(x)                  # (n * oh * ow, in_ch * kh * kw)
        w2m = self.weight.reshape(wn, -1)  # (out_ch, in_ch * kh * kw)
        xn, xc, oh, ow = self.b2m.conv_size

        # The whole convolution as a single matrix product.
        out_matrix = np.matmul(x2m, w2m.T) + self.bias

        out = out_matrix.reshape((xn, oh, ow, wn))

        self.x2m = x2m
        self.w2m = w2m

        out = out.transpose((0, 3, 1, 2))  # back to (batch, out_ch, oh, ow)
        return out

    def backward(self, d_loss):
        on, oc, oh, ow = np.shape(d_loss)

        # Match the row layout of out_matrix: one row per output position.
        d_loss = d_loss.transpose((0, 2, 3, 1))
        d_loss = d_loss.reshape((on * oh * ow, -1))

        dw = np.matmul(d_loss.T, self.x2m)  # weight gradient
        self.dw = dw.reshape(np.shape(self.weight))
        self.db = np.sum(d_loss, axis=0)    # bias gradient

        dx2m = np.matmul(d_loss, self.w2m)  # gradient w.r.t. the unrolled input
        dx = self.b2m.backward(dx2m)        # fold it back onto the input
        return dx
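
The classes above can be exercised on their own before bringing in a framework. A minimal NumPy-only smoke test (the shapes in the comments follow from this particular stride-1, 3x3 configuration; the random values are arbitrary):

import numpy as np
from vanilla_nn.convolution_matrix import Conv2d

np.random.seed(0)
x = np.random.random((1, 3, 5, 5))  # (batch, in_ch, height, width)
w = np.random.random((2, 3, 3, 3))  # (out_ch, in_ch, kernel_h, kernel_w)
b = np.random.random(2)

conv = Conv2d(stride=1, weight=w, bias=b)
y = conv(x)                          # (1, 2, 3, 3)
dx = conv.backward(np.ones_like(y))  # also fills conv.dw and conv.db
print(y.shape, dx.shape, conv.dw.shape, conv.db.shape)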

2. TensorFlow verification

import numpy as np
import tensorflow as tf
from vanilla_nn.convolution_matrix import Conv2d

tf.enable_eager_execution()  # TensorFlow 1.x eager mode
tf.set_random_seed(123)

np.random.seed(123)
np.set_printoptions(6, suppress=True, linewidth=120)

x_numpy = np.random.random((2, 3, 5, 7))
x_tf = tf.constant(x_numpy)

conv_tf = tf.layers.Conv2D(
    filters=4, kernel_size=3, strides=(2, 2),
    data_format="channels_first")

with tf.GradientTape(persistent=True) as t:
    t.watch(x_tf)
    y_tf = conv_tf(x_tf)

conv_numpy = Conv2d(
    stride=2,
    # TF stores kernels as (kh, kw, in_ch, out_ch); transpose to
    # this implementation's (out_ch, in_ch, kh, kw) layout.
    weight=conv_tf.get_weights()[0].transpose(3, 2, 0, 1),
    bias=conv_tf.get_weights()[1])

y_numpy = conv_numpy(x_numpy)

dy_numpy = np.random.random(y_numpy.shape)
dy_tf = tf.constant(dy_numpy)

dx_numpy = conv_numpy.backward(dy_numpy)

dy_dx = t.gradient(y_tf, x_tf, dy_tf)
dw_tf, db_tf = t.gradient(y_tf, conv_tf.weights, dy_tf)

print("y_numpy\n", y_numpy[0][0])
print("y_tf\n", y_tf.numpy()[0][0])

print("dx_numpy\n", dx_numpy[0][0])
print("dx_tf\n", dy_dx.numpy()[0][0])

print("dw_numpy\n", conv_numpy.dw[0][0])
print("dw_tf\n", dw_tf.numpy()[0][0])

print("db_numpy\n", conv_numpy.db)
print("db_tf\n", db_tf.numpy())

"""
y_numpy
 [[-0.410179 -0.379269 -1.083951]
 [-0.340203 -0.57836  -0.620306]]
y_tf
 [[-0.410179 -0.379269 -1.083951]
 [-0.340203 -0.57836  -0.620306]]
dx_numpy
 [[-0.220716  0.140123 -0.246691  0.017347 -0.073237 -0.027538  0.058724]
 [-0.067507 -0.137531  0.273061 -0.048839  0.257607 -0.097253  0.445625]
 [-0.403482 -0.176619 -0.319394 -0.279691 -0.49433  -0.325839 -0.105012]
 [-0.114447 -0.17873   0.223171 -0.076912  0.27387  -0.123164  0.544131]
 [-0.089011 -0.26426   0.003028 -0.24742  -0.286531 -0.404775  0.071425]]
dx_tf
 [[-0.220716  0.140123 -0.246691  0.017347 -0.073237 -0.027538  0.058724]
 [-0.067507 -0.137531  0.273061 -0.048839  0.257607 -0.097253  0.445625]
 [-0.403482 -0.176619 -0.319394 -0.279691 -0.49433  -0.325839 -0.105012]
 [-0.114447 -0.17873   0.223171 -0.076912  0.27387  -0.123164  0.544131]
 [-0.089011 -0.26426   0.003028 -0.24742  -0.286531 -0.404775  0.071425]]
dw_numpy
 [[ 3.759475  4.10321   3.357347]
 [ 3.621035  3.578632  3.329442]
 [ 3.183582  4.388502  2.181801]]
dw_tf
 [[ 3.759475  2.767449  3.618431  3.229403]
 [ 3.718784  3.304698  3.817697  3.17113 ]
 [ 3.056882  1.940038  2.436826  2.322039]]
db_numpy
 [ 6.804211  5.354626  6.683215  5.545097]
db_tf
 [ 6.804211  5.354626  6.683215  5.545097]
 """

3. PyTorch verification

import torch
import numpy as np
from vanilla_nn.convolution_matrix import Conv2d

np.random.seed(123)
torch.manual_seed(123)
np.set_printoptions(6, suppress=True, linewidth=120)

x_numpy = np.random.random((2, 3, 5, 7))
x_torch = torch.tensor(x_numpy, requires_grad=True)

conv_torch = torch.nn.Conv2d(3, 4, (3, 3), stride=2).double()  # float64 to match NumPy
conv_numpy = Conv2d(
    stride=2,
    # PyTorch already stores weights as (out_ch, in_ch, kh, kw).
    weight=conv_torch.weight.data.numpy(),
    bias=conv_torch.bias.data.numpy())

y_numpy = conv_numpy(x_numpy)
y_torch = conv_torch(x_torch)

dy_numpy = np.random.random(y_numpy.shape)
dy_torch = torch.tensor(dy_numpy)

y_torch.backward(dy_torch)

dx_numpy = conv_numpy.backward(dy_numpy)
dx_torch = x_torch.grad.data.numpy()

print("y_numpy\n", y_numpy[0][0])
print("y_torch\n", y_torch.data.numpy()[0][0])

print("dx_numpy\n", dx_numpy[0][0])
print("dx_torch\n", dx_torch[0][0])

print("dw_numpy\n", conv_numpy.dw[0][0])
print("dw_torch\n", conv_torch.weight.grad.data.numpy()[0][0])

print("db_numpy\n", conv_numpy.db)
print("db_torch\n", conv_torch.bias.grad.data.numpy())

"""
y_numpy
 [[ 0.038209 -0.104653 -0.37583 ]
 [-0.101538 -0.192953 -0.152639]]
y_torch
 [[ 0.038209 -0.104653 -0.37583 ]
 [-0.101538 -0.192953 -0.152639]]
dx_numpy
 [[-0.040876  0.002517 -0.160813 -0.109504 -0.154843 -0.022449 -0.304437]
 [ 0.093016 -0.133319  0.289393 -0.048426  0.246752  0.000857 -0.040514]
 [-0.076863 -0.00029  -0.240455  0.11569  -0.242307  0.337566 -0.519987]
 [ 0.092797 -0.117967  0.250818 -0.041134  0.187269 -0.040959 -0.022162]
 [-0.077402 -0.066061 -0.083124  0.102747  0.026502  0.250037 -0.120596]]
dx_torch
 [[-0.040876  0.002517 -0.160813 -0.109504 -0.154843 -0.022449 -0.304437]
 [ 0.093016 -0.133319  0.289393 -0.048426  0.246752  0.000857 -0.040514]
 [-0.076863 -0.00029  -0.240455  0.11569  -0.242307  0.337566 -0.519987]
 [ 0.092797 -0.117967  0.250818 -0.041134  0.187269 -0.040959 -0.022162]
 [-0.077402 -0.066061 -0.083124  0.102747  0.026502  0.250037 -0.120596]]
dw_numpy
 [[ 3.759475  4.10321   3.357347]
 [ 3.621035  3.578632  3.329442]
 [ 3.183582  4.388502  2.181801]]
dw_torch
 [[ 3.759475  4.10321   3.357347]
 [ 3.621035  3.578632  3.329442]
 [ 3.183582  4.388502  2.181801]]
db_numpy
 [ 6.804211  5.354626  6.683215  5.545097]
db_torch
 [ 6.804211  5.354626  6.683215  5.545097]
 """
