Python和PyTorch对比实现affine/linear(仿射/线性)变换函数及全连接层的反向传播

摘要

本文分别使用 NumPy 手动实现(不依赖自动求导)和 PyTorch 自动求导, 对比实现 affine/linear(仿射/线性)变换函数及其反向传播.

相关

原理和详细解释, 请参考文章 :

affine/linear(仿射/线性)变换函数详解及全连接层反向传播的梯度求导

系列文章索引 :
https://blog.csdn.net/oBrightLamp/article/details/85067981

正文

import torch
import numpy as np


class Affine:
    """Fully connected (affine) layer ``y = x @ weight.T + bias`` with a manual backward pass.

    Uses the same parameter layout as ``torch.nn.Linear``: ``weight`` has shape
    ``(out_features, in_features)`` and ``bias`` has shape ``(out_features,)``.
    Both are assigned by the caller after construction. ``bias`` may be left
    as ``None`` for a layer without bias.

    The input ``x`` may have shape ``(in_features,)``, ``(batch, in_features)``
    or any ``(..., in_features)``; the leading dimensions are treated as batch
    dimensions in the backward pass.
    """

    def __init__(self):
        self.x = None       # input cached by __call__ for use in backward
        self.weight = None  # (out_features, in_features), set by the caller
        self.bias = None    # (out_features,), or None for no bias
        self.dx = None      # gradient w.r.t. the input, same shape as x
        self.dw = None      # gradient w.r.t. weight, same shape as weight
        self.db = None      # gradient w.r.t. bias (None when bias is None)

    def __call__(self, x):
        """Forward pass: return ``x @ weight.T (+ bias)`` and cache ``x``."""
        self.x = x
        out = np.dot(self.x, self.weight.T)
        if self.bias is not None:
            out = out + self.bias
        return out

    def backward(self, d_loss):
        """Backward pass.

        Args:
            d_loss: upstream gradient dL/dy, same shape as the forward output
                (``(..., out_features)``).

        Returns:
            dL/dx, same shape as the cached input. ``self.dw`` and ``self.db``
            are updated as a side effect.

        Raises:
            RuntimeError: if called before the forward pass.
        """
        if self.x is None:
            raise RuntimeError("backward() called before the forward pass")

        d_loss = np.asarray(d_loss)
        x = np.asarray(self.x)
        out_features, in_features = self.weight.shape

        # Collapse all leading (batch) dimensions so the weight/bias gradients
        # are sums over every sample, regardless of the input rank. For the
        # common 2-D case these reshapes are no-ops.
        x_2d = x.reshape(-1, in_features)
        d_2d = d_loss.reshape(-1, out_features)

        # dL/dx = dL/dy @ W keeps the leading dimensions of d_loss.
        self.dx = np.dot(d_loss, self.weight)
        # dL/dW = sum over samples of outer(dL/dy, x).
        self.dw = np.dot(d_2d.T, x_2d)
        self.db = np.sum(d_2d, axis=0) if self.bias is not None else None
        return self.dx


# Compare the hand-written NumPy Affine layer with torch.nn.Linear + autograd:
# both get identical inputs, parameters and upstream gradient, and their
# forward outputs and gradients (dx, dw, db) are printed and checked.
np.random.seed(123)
np.set_printoptions(precision=6, suppress=True, linewidth=120)

# Input (batch=3, in_features=4), weight (out_features=5, in_features=4), bias (5,).
x_numpy = np.random.random((3, 4))
w_numpy = np.random.random((5, 4))
b_numpy = np.random.random((5,))

# float64 tensors holding the same values, so both sides compute in the same precision.
x_tensor = torch.tensor(x_numpy, requires_grad=True)
w_tensor = torch.tensor(w_numpy, requires_grad=True)
b_tensor = torch.tensor(b_numpy, requires_grad=True)

affine_numpy = Affine()
affine_numpy.weight = w_numpy
affine_numpy.bias = b_numpy

# Declare the layer with the real feature sizes taken from the weight matrix
# (in_features=4, out_features=5) so the module's metadata matches the
# parameters installed right below.
out_features, in_features = w_numpy.shape
affine_tensor = torch.nn.Linear(in_features, out_features, bias=True)
affine_tensor.weight = torch.nn.Parameter(w_tensor, requires_grad=True)
affine_tensor.bias = torch.nn.Parameter(b_tensor, requires_grad=True)

out_numpy = affine_numpy(x_numpy)
out_tensor = affine_tensor(x_tensor)

# Feed the same random upstream gradient dL/dy into both backward passes.
d_loss_numpy = np.random.random(out_numpy.shape)
d_loss_tensor = torch.tensor(d_loss_numpy)
out_tensor.backward(d_loss_tensor)

dx_numpy = affine_numpy.backward(d_loss_numpy)
dw_numpy = affine_numpy.dw
db_numpy = affine_numpy.db

dx_tensor = x_tensor.grad
dw_tensor = affine_tensor.weight.grad
db_tensor = affine_tensor.bias.grad

print("--- 对比变换结果 ---")
print(out_numpy)
print(out_tensor.detach().numpy())

print("--- 对比 dx ---")
print(dx_numpy)
print(dx_tensor.numpy())

print("--- 对比 dw ---")
print(dw_numpy)
print(dw_tensor.numpy())

print("--- 对比 db ---")
print(db_numpy)
print(db_tensor.numpy())

# Fail loudly (independent of `python -O`) if the manual implementation ever
# diverges from autograd, instead of relying on eyeballing the printout.
np.testing.assert_allclose(out_numpy, out_tensor.detach().numpy())
np.testing.assert_allclose(dx_numpy, dx_tensor.numpy())
np.testing.assert_allclose(dw_numpy, dw_tensor.numpy())
np.testing.assert_allclose(db_numpy, db_tensor.numpy())

"""
代码输出
--- 对比变换结果 ---
[[ 1.250556  1.084776  1.611937  1.115749  1.071461]
 [ 1.667441  1.584755  2.370629  1.479834  1.291984]
 [ 1.339822  1.220394  1.758095  1.076918  1.162823]]
[[ 1.250556  1.084776  1.611937  1.115749  1.071461]
 [ 1.667441  1.584755  2.370629  1.479834  1.291984]
 [ 1.339822  1.220394  1.758095  1.076918  1.162823]]
--- 对比 dx ---
[[ 1.367212  0.91971   1.457424  1.660651]
 [ 1.087256  1.213257  1.109499  1.250769]
 [ 1.245717  1.230932  1.232196  1.764028]]
[[ 1.367212  0.91971   1.457424  1.660651]
 [ 1.087256  1.213257  1.109499  1.250769]
 [ 1.245717  1.230932  1.232196  1.764028]]
--- 对比 dw ---
[[ 1.324482  0.776335  0.852071  1.428347]
 [ 1.20587   0.649376  0.799307  1.183345]
 [ 1.267557  0.750463  1.173819  1.316775]
 [ 0.672773  0.331807  0.428579  0.603458]
 [ 0.825466  0.561481  0.783553  0.996982]]
[[ 1.324482  0.776335  0.852071  1.428347]
 [ 1.20587   0.649376  0.799307  1.183345]
 [ 1.267557  0.750463  1.173819  1.316775]
 [ 0.672773  0.331807  0.428579  0.603458]
 [ 0.825466  0.561481  0.783553  0.996982]]
--- 对比 db ---
[ 2.196234  1.878471  1.98104   0.995037  1.424993]
[ 2.196234  1.878471  1.98104   0.995037  1.424993]
"""

你可能感兴趣的:(深度学习编程)