This article implements the matrix-based computation of the convolution function and its backpropagation in pure Python, and verifies the results against TensorFlow and PyTorch.
For the underlying theory and a detailed explanation, see the article:
Matrix-based Computation of the Convolution Function and Backpropagation of Its Gradient
Series index:
https://blog.csdn.net/oBrightLamp/article/details/85067981
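In brief, the method unfolds every kernel-sized window of the input into one row of a matrix (the im2col trick), so that convolution collapses into a single matrix product. Using my own shorthand for the symbols (input height $x_h$, kernel height $k_h$, stride $s$, and likewise for widths):

$$ \mathrm{conv}_h = \left\lfloor \frac{x_h - k_h}{s} \right\rfloor + 1, \qquad Y_{2m} = X_{2m} W_{2m}^{\top} + b $$

where $X_{2m}$ holds one flattened window per row and $W_{2m}$ is the kernel tensor flattened to one row per output channel.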
File path: vanilla_nn/convolution_matrix.py
import numpy as np


class Batch2ConvMatrix:
    """Unfold (im2col) a batch of feature maps into a 2-D matrix,
    so that convolution becomes a single matrix multiplication."""

    def __init__(self, stride, kernel_h, kernel_w):
        self.stride = stride
        self.kernel_h = kernel_h
        self.kernel_w = kernel_w
        self.x = None
        self.conv_size = None

    def __call__(self, x):
        self.x = x
        x_nums, x_channels, x_height, x_width = np.shape(self.x)
        # Output spatial size for a valid (no-padding) convolution.
        conv_height = int((x_height - self.kernel_h) / self.stride) + 1
        conv_width = int((x_width - self.kernel_w) / self.stride) + 1
        scan = np.zeros((x_nums, conv_height, conv_width,
                         x_channels, self.kernel_h, self.kernel_w))
        # Copy every kernel-sized window into its own slot.
        for n in range(x_nums):
            for h in range(conv_height):
                for w in range(conv_width):
                    for c in range(x_channels):
                        start_h = h * self.stride
                        start_w = w * self.stride
                        end_h = start_h + self.kernel_h
                        end_w = start_w + self.kernel_w
                        scan[n, h, w, c] = \
                            x[n, c, start_h:end_h, start_w:end_w]
        # One row per output position, one column per kernel element.
        conv_matrix = scan.reshape(x_nums * conv_height * conv_width, -1)
        self.conv_size = [x_nums, x_channels, conv_height, conv_width]
        return conv_matrix

    def backward(self, dx2m):
        # Fold the unfolded gradient back onto the input tensor.
        # Windows overlap when stride < kernel size, so gradients
        # from different output positions accumulate.
        dx = np.zeros_like(self.x)
        kh = self.kernel_h
        kw = self.kernel_w
        xn, xc, ch, cw = self.conv_size
        dx2m = dx2m.reshape((xn, ch, cw, xc, kh, kw))
        for n in range(xn):
            for c in range(xc):
                for h in range(ch):
                    for w in range(cw):
                        start_h = h * self.stride
                        start_w = w * self.stride
                        end_h = start_h + self.kernel_h
                        end_w = start_w + self.kernel_w
                        dx[n, c][start_h:end_h, start_w:end_w] \
                            += dx2m[n, h, w, c]
        return dx

class Conv2d:
    """A minimal 2-D convolution layer built on Batch2ConvMatrix.
    The weight layout is (out_channels, in_channels, kernel_h,
    kernel_w), matching PyTorch's nn.Conv2d."""

    def __init__(self, stride, weight=None, bias=None):
        self.stride = stride
        self.weight = weight
        self.bias = bias
        self.b2m = None
        self.x2m = None
        self.w2m = None
        self.dw = None
        self.db = None

    def __call__(self, x):
        wn, wc, wh, ww = np.shape(self.weight)
        if self.b2m is None:
            self.b2m = Batch2ConvMatrix(self.stride, wh, ww)
        x2m = self.b2m(x)                  # (n * oh * ow, c * kh * kw)
        w2m = self.weight.reshape(wn, -1)  # (out_channels, c * kh * kw)
        xn, xc, oh, ow = self.b2m.conv_size
        out_matrix = np.matmul(x2m, w2m.T) + self.bias
        out = out_matrix.reshape((xn, oh, ow, wn))
        self.x2m = x2m
        self.w2m = w2m
        out = out.transpose((0, 3, 1, 2))  # NHWC back to NCHW
        return out

    def backward(self, d_loss):
        on, oc, oh, ow = np.shape(d_loss)
        # Mirror the forward reshape: NCHW -> NHWC -> (n * oh * ow, oc).
        d_loss = d_loss.transpose((0, 2, 3, 1))
        d_loss = d_loss.reshape((on * oh * ow, -1))
        dw = np.matmul(d_loss.T, self.x2m)
        self.dw = dw.reshape(np.shape(self.weight))
        self.db = np.sum(d_loss, axis=0)
        dx2m = np.matmul(d_loss, self.w2m)
        dx = self.b2m.backward(dx2m)
        return dx
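Because the forward pass is the single affine map $Y_{2m} = X_{2m} W_{2m}^{\top} + b$, the gradients computed in Conv2d.backward follow from ordinary matrix calculus. Writing $D = \partial L / \partial Y_{2m}$ for the incoming gradient in row-matrix form:

$$ \frac{\partial L}{\partial W_{2m}} = D^{\top} X_{2m}, \qquad \frac{\partial L}{\partial b} = \sum_{\text{rows}} D, \qquad \frac{\partial L}{\partial X_{2m}} = D\, W_{2m} $$

Batch2ConvMatrix.backward then folds $\partial L / \partial X_{2m}$ back onto the input tensor, accumulating gradients wherever windows overlap. As a quick sanity check of the shapes (the sizes below are arbitrary examples, not taken from the article):

import numpy as np
from vanilla_nn.convolution_matrix import Batch2ConvMatrix

b2m = Batch2ConvMatrix(stride=2, kernel_h=3, kernel_w=3)
x = np.random.random((2, 3, 5, 7))  # NCHW input
x2m = b2m(x)
print(x2m.shape)      # (2 * 2 * 3, 3 * 3 * 3) = (12, 27)
print(b2m.conv_size)  # [2, 3, 2, 3]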
TensorFlow test code:

import numpy as np
import tensorflow as tf

from vanilla_nn.convolution_matrix import Conv2d

# TensorFlow 1.x, running in eager mode.
tf.enable_eager_execution()
tf.set_random_seed(123)
np.random.seed(123)
np.set_printoptions(6, suppress=True, linewidth=120)

x_numpy = np.random.random((2, 3, 5, 7))
x_tf = tf.constant(x_numpy)
conv_tf = tf.layers.Conv2D(
    filters=4, kernel_size=3, strides=(2, 2),
    data_format="channels_first")

with tf.GradientTape(persistent=True) as t:
    t.watch(x_tf)
    y_tf = conv_tf(x_tf)

# TF stores the kernel as (kernel_h, kernel_w, in, out); transpose it
# to (out, in, kernel_h, kernel_w) for the numpy implementation.
conv_numpy = Conv2d(
    stride=2,
    weight=conv_tf.get_weights()[0].transpose(3, 2, 0, 1),
    bias=conv_tf.get_weights()[1])
y_numpy = conv_numpy(x_numpy)

dy_numpy = np.random.random(y_numpy.shape)
dy_tf = tf.constant(dy_numpy)
dx_numpy = conv_numpy.backward(dy_numpy)
dx_tf = t.gradient(y_tf, x_tf, dy_tf)
dw_tf, db_tf = t.gradient(y_tf, conv_tf.weights, dy_tf)

print("y_numpy\n", y_numpy[0][0])
print("y_tf\n", y_tf.numpy()[0][0])
print("dx_numpy\n", dx_numpy[0][0])
print("dx_tf\n", dx_tf.numpy()[0][0])
print("dw_numpy\n", conv_numpy.dw[0][0])
# dw_tf below is still in TF's (kernel_h, kernel_w, in, out) layout;
# transpose with (3, 2, 0, 1) to compare it element-wise with dw_numpy.
print("dw_tf\n", dw_tf.numpy()[0][0])
print("db_numpy\n", conv_numpy.db)
print("db_tf\n", db_tf.numpy())
"""
y_numpy
[[-0.410179 -0.379269 -1.083951]
[-0.340203 -0.57836 -0.620306]]
y_tf
[[-0.410179 -0.379269 -1.083951]
[-0.340203 -0.57836 -0.620306]]
dx_numpy
[[-0.220716 0.140123 -0.246691 0.017347 -0.073237 -0.027538 0.058724]
[-0.067507 -0.137531 0.273061 -0.048839 0.257607 -0.097253 0.445625]
[-0.403482 -0.176619 -0.319394 -0.279691 -0.49433 -0.325839 -0.105012]
[-0.114447 -0.17873 0.223171 -0.076912 0.27387 -0.123164 0.544131]
[-0.089011 -0.26426 0.003028 -0.24742 -0.286531 -0.404775 0.071425]]
dx_tf
[[-0.220716 0.140123 -0.246691 0.017347 -0.073237 -0.027538 0.058724]
[-0.067507 -0.137531 0.273061 -0.048839 0.257607 -0.097253 0.445625]
[-0.403482 -0.176619 -0.319394 -0.279691 -0.49433 -0.325839 -0.105012]
[-0.114447 -0.17873 0.223171 -0.076912 0.27387 -0.123164 0.544131]
[-0.089011 -0.26426 0.003028 -0.24742 -0.286531 -0.404775 0.071425]]
dw_numpy
[[ 3.759475 4.10321 3.357347]
[ 3.621035 3.578632 3.329442]
[ 3.183582 4.388502 2.181801]]
dw_tf
[[ 3.759475 2.767449 3.618431 3.229403]
[ 3.718784 3.304698 3.817697 3.17113 ]
[ 3.056882 1.940038 2.436826 2.322039]]
db_numpy
[ 6.804211 5.354626 6.683215 5.545097]
db_tf
[ 6.804211 5.354626 6.683215 5.545097]
"""
PyTorch test code:

import torch
import numpy as np

from vanilla_nn.convolution_matrix import Conv2d

np.random.seed(123)
torch.manual_seed(123)
np.set_printoptions(6, suppress=True, linewidth=120)

x_numpy = np.random.random((2, 3, 5, 7))
x_torch = torch.tensor(x_numpy, requires_grad=True)
conv_torch = torch.nn.Conv2d(3, 4, (3, 3), stride=2).double()

# PyTorch already stores weights as (out, in, kernel_h, kernel_w),
# so the parameters can be copied over without transposing.
conv_numpy = Conv2d(
    stride=2,
    weight=conv_torch.weight.data.numpy(),
    bias=conv_torch.bias.data.numpy())

y_numpy = conv_numpy(x_numpy)
y_torch = conv_torch(x_torch)

dy_numpy = np.random.random(y_numpy.shape)
dy_torch = torch.tensor(dy_numpy)

y_torch.backward(dy_torch)
dx_numpy = conv_numpy.backward(dy_numpy)
dx_torch = x_torch.grad.data.numpy()

print("y_numpy\n", y_numpy[0][0])
print("y_torch\n", y_torch.data.numpy()[0][0])
print("dx_numpy\n", dx_numpy[0][0])
print("dx_torch\n", dx_torch[0][0])
print("dw_numpy\n", conv_numpy.dw[0][0])
print("dw_torch\n", conv_torch.weight.grad.data.numpy()[0][0])
print("db_numpy\n", conv_numpy.db)
print("db_torch\n", conv_torch.bias.grad.data.numpy())
"""
y_numpy
[[ 0.038209 -0.104653 -0.37583 ]
[-0.101538 -0.192953 -0.152639]]
y_torch
[[ 0.038209 -0.104653 -0.37583 ]
[-0.101538 -0.192953 -0.152639]]
dx_numpy
[[-0.040876 0.002517 -0.160813 -0.109504 -0.154843 -0.022449 -0.304437]
[ 0.093016 -0.133319 0.289393 -0.048426 0.246752 0.000857 -0.040514]
[-0.076863 -0.00029 -0.240455 0.11569 -0.242307 0.337566 -0.519987]
[ 0.092797 -0.117967 0.250818 -0.041134 0.187269 -0.040959 -0.022162]
[-0.077402 -0.066061 -0.083124 0.102747 0.026502 0.250037 -0.120596]]
dx_torch
[[-0.040876 0.002517 -0.160813 -0.109504 -0.154843 -0.022449 -0.304437]
[ 0.093016 -0.133319 0.289393 -0.048426 0.246752 0.000857 -0.040514]
[-0.076863 -0.00029 -0.240455 0.11569 -0.242307 0.337566 -0.519987]
[ 0.092797 -0.117967 0.250818 -0.041134 0.187269 -0.040959 -0.022162]
[-0.077402 -0.066061 -0.083124 0.102747 0.026502 0.250037 -0.120596]]
dw_numpy
[[ 3.759475 4.10321 3.357347]
[ 3.621035 3.578632 3.329442]
[ 3.183582 4.388502 2.181801]]
dw_torch
[[ 3.759475 4.10321 3.357347]
[ 3.621035 3.578632 3.329442]
[ 3.183582 4.388502 2.181801]]
db_numpy
[ 6.804211 5.354626 6.683215 5.545097]
db_torch
[ 6.804211 5.354626 6.683215 5.545097]
"""