纯Python和PyTorch对比实现门控循环单元GRU及反向传播

摘要

本文使用纯 Python 和 PyTorch 对比实现门控循环单元GRU及其反向传播.

相关

原理和详细解释, 请参考 :

门控循环单元GRUCell详解及反向传播的梯度求导

文章索引 :
https://blog.csdn.net/oBrightLamp/article/details/85067981

正文

1. GRUCell 类

文件路径 : vanilla_nn/grucell.py

import numpy as np


def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` element-wise.

    The naive formula evaluates ``exp(-x)``, which overflows for large
    negative ``x`` and triggers a floating-point warning.  Here the
    exponential is only ever taken of a non-positive number, so it stays
    within ``(0, 1]``.

    Args:
        x: A scalar or array-like of real numbers.

    Returns:
        The sigmoid of ``x`` with the same shape as the input; a scalar
        input yields a NumPy scalar rather than a 0-d array.
    """
    x = np.asarray(x)
    # exp(-|x|) never overflows; it is exp(-x) for x >= 0 and exp(x) for x < 0.
    decay = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + exp(-x));  x < 0: exp(x) / (1 + exp(x)) -- same value.
    result = np.where(x >= 0, 1 / (1 + decay), decay / (1 + decay))
    # np.where returns a 0-d array for scalar input; unwrap it to a scalar.
    return result[()] if result.ndim == 0 else result


class GRUCell:
    """GRU cell on NumPy arrays with a hand-written backward pass.

    The weight matrices and bias vectors stack three gate blocks along their
    first axis, in the order reset, update, candidate ("cell").  Each block
    has as many rows as the hidden state has columns.

    Every forward call pushes the values needed for differentiation onto
    internal stacks.  Every ``backward`` call pops one forward step, so
    ``backward`` must be called in the reverse order of the forward calls.
    Parameter gradients are appended per step to the ``*_grad_stack`` lists;
    summing a list gives the total gradient.  ``dx_list`` and
    ``dh_prev_list`` receive new entries at the front.  ``h_next_stack``
    is only appended to and is never popped by ``backward``.
    """

    def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
        """Store the parameters and create empty bookkeeping containers.

        Args:
            weight_ih: Input-to-hidden weights, gate blocks stacked by rows.
            weight_hh: Hidden-to-hidden weights, gate blocks stacked by rows.
            bias_ih: Input-side bias, gate blocks concatenated.
            bias_hh: Hidden-side bias, gate blocks concatenated.
        """
        # Parameters.
        self.weight_ih, self.weight_hh = weight_ih, weight_hh
        self.bias_ih, self.bias_hh = bias_ih, bias_hh

        # Gradient flowing back through the recurrent connection; set to
        # zeros by every forward call and updated by every backward call.
        self.dh_prev = None

        # Values cached by the forward pass, one entry per step.
        self.x_stack = []
        self.h_prev_stack = []
        self.h_next_stack = []
        self.reset_gate_stack = []
        self.update_gate_stack = []
        self.cell_gate_stack = []

        # Per-step gradients produced by the backward pass.
        self.dx_list = []
        self.dh_prev_list = []
        self.weight_ih_grad_stack = []
        self.weight_hh_grad_stack = []
        self.bias_ih_grad_stack = []
        self.bias_hh_grad_stack = []

    def __call__(self, x, h_prev):
        """Run one forward step and cache what ``backward`` will need.

        Args:
            x: 2-D input for this step; rows are samples.
            h_prev: 2-D previous hidden state; its second axis gives the
                hidden size.

        Returns:
            The next hidden state, same shape as ``h_prev``.
        """
        from_input = np.dot(x, self.weight_ih.T) + self.bias_ih
        from_hidden = np.dot(h_prev, self.weight_hh.T) + self.bias_hh

        hidden = np.shape(h_prev)[1]
        reset_cols = slice(0, hidden)
        update_cols = slice(hidden, 2 * hidden)
        cell_cols = slice(2 * hidden, None)

        reset_gate = sigmoid(
            from_input[:, reset_cols] + from_hidden[:, reset_cols])
        update_gate = sigmoid(
            from_input[:, update_cols] + from_hidden[:, update_cols])
        # The reset gate scales the hidden-side pre-activation (bias included).
        cell_gate = np.tanh(
            from_input[:, cell_cols] + from_hidden[:, cell_cols] * reset_gate)

        h_next = (1 - update_gate) * cell_gate + update_gate * h_prev

        cached = (
            (self.x_stack, x),
            (self.reset_gate_stack, reset_gate),
            (self.update_gate_stack, update_gate),
            (self.cell_gate_stack, cell_gate),
            (self.h_prev_stack, h_prev),
            (self.h_next_stack, h_next),
        )
        for stack, value in cached:
            stack.append(value)

        # A new forward step starts with no gradient from later steps.
        self.dh_prev = np.zeros_like(h_next)

        return h_next

    def backward(self, dh_next):
        """Back-propagate through the most recent forward step not yet undone.

        The gradient passed in is added to the recurrent gradient left by the
        previous ``backward`` call (``self.dh_prev``), so callers only supply
        the gradient arriving directly at this step's output.

        Args:
            dh_next: 2-D gradient with respect to this step's output.

        Returns:
            The gradient with respect to this step's ``h_prev``; the same
            array is stored in ``self.dh_prev`` and at the front of
            ``self.dh_prev_list``.
        """
        x = self.x_stack.pop()
        h_prev = self.h_prev_stack.pop()
        reset_gate = self.reset_gate_stack.pop()
        update_gate = self.update_gate_stack.pop()
        cell_gate = self.cell_gate_stack.pop()

        hidden = np.shape(dh_next)[1]
        reset_rows = slice(0, hidden)
        update_rows = slice(hidden, 2 * hidden)
        cell_rows = slice(2 * hidden, None)

        w_reset = self.weight_ih[reset_rows, :]
        w_update = self.weight_ih[update_rows, :]
        w_cell = self.weight_ih[cell_rows, :]

        v_reset = self.weight_hh[reset_rows, :]
        v_update = self.weight_hh[update_rows, :]
        v_cell = self.weight_hh[cell_rows, :]
        b_cell_hh = self.bias_hh[cell_rows]

        # Total gradient at the output: direct part plus recurrent part.
        grad_h = dh_next + self.dh_prev

        grad_update = grad_h * (h_prev - cell_gate)
        grad_cell = grad_h * (1 - update_gate)

        # Gradients at the pre-activations (sigmoid' and tanh').
        d_au = grad_update * update_gate * (1 - update_gate)
        d_ac = grad_cell * (1 - cell_gate * cell_gate)

        grad_reset = d_ac * (np.dot(h_prev, v_cell.T) + b_cell_hh)
        d_ar = grad_reset * reset_gate * (1 - reset_gate)

        # Gradient reaching the hidden-side candidate pre-activation.
        d_ac_hidden = d_ac * reset_gate

        dh_prev = (
            (grad_h * update_gate + np.dot(d_ar, v_reset))
            + (np.dot(d_au, v_update) + np.dot(d_ac_hidden, v_cell))
        )
        self.dh_prev_list.insert(0, dh_prev)
        self.dh_prev = dh_prev

        dx = np.dot(d_ar, w_reset) + np.dot(d_au, w_update) + np.dot(d_ac, w_cell)
        self.dx_list.insert(0, dx)

        input_side = (d_ar, d_au, d_ac)
        hidden_side = (d_ar, d_au, d_ac_hidden)

        self.weight_ih_grad_stack.append(
            np.vstack([np.dot(grad.T, x) for grad in input_side]))
        self.weight_hh_grad_stack.append(
            np.vstack([np.dot(grad.T, h_prev) for grad in hidden_side]))

        self.bias_ih_grad_stack.append(
            np.hstack([np.sum(grad, axis=0) for grad in input_side]))
        self.bias_hh_grad_stack.append(
            np.hstack([np.sum(grad, axis=0) for grad in hidden_side]))

        return dh_prev

2. GRUCell 测试

import torch
import numpy as np
from vanilla_nn.grucell import GRUCell

# Seed both generators so the random inputs and the PyTorch parameter
# initialisation repeat from run to run, and fix NumPy's print format.
np.random.seed(123)
torch.random.manual_seed(123)
np.set_printoptions(precision=6, suppress=True, linewidth=80)

# Reference cell: input size 3, hidden size 4, converted with .double().
grucell_torch = torch.nn.GRUCell(3, 4).double()

# The NumPy cell is built from the reference cell's own parameters.
grucell_numpy = GRUCell(
    *(
        parameter.detach().numpy()
        for parameter in (
            grucell_torch.weight_ih,
            grucell_torch.weight_hh,
            grucell_torch.bias_ih,
            grucell_torch.bias_hh,
        )
    )
)

# Random batch of 2 samples; the torch copies track gradients.
x_numpy = np.random.random((2, 3))
h_numpy = np.random.random((2, 4))
x_torch = torch.tensor(x_numpy, requires_grad=True)
h_torch = torch.tensor(h_numpy, requires_grad=True)

# Forward pass through both implementations.
next_h_numpy = grucell_numpy(x_numpy, h_numpy)
next_h_torch = grucell_torch(x_torch, h_torch)

# Same upstream gradient for both backward passes.
dh_numpy = np.random.random((2, 4))
dh_torch = torch.tensor(dh_numpy, requires_grad=True)

# From here on dh_numpy holds the gradient with respect to the input hidden
# state, as returned by the NumPy cell.
dh_numpy = grucell_numpy.backward(dh_numpy)
next_h_torch.backward(dh_torch)

# (caption, value) rows; each NumPy row is followed by its PyTorch twin.
comparisons = (
    ("out_numpy :\n", next_h_numpy),
    ("out_torch :\n", next_h_torch.detach().numpy()),
    ("dh_numpy :\n", dh_numpy),
    ("dh_torch :\n", h_torch.grad.detach().numpy()),
    ("dx_numpy :\n", np.array(grucell_numpy.dx_list)),
    ("dx_torch :\n", x_torch.grad.detach().numpy()),
    ("w_ih_numpy :\n", np.sum(grucell_numpy.weight_ih_grad_stack, axis=0)),
    ("w_ih_torch :\n", grucell_torch.weight_ih.grad.detach().numpy()),
    ("w_hh_numpy :\n", np.sum(grucell_numpy.weight_hh_grad_stack, axis=0)),
    ("w_hh_torch :\n", grucell_torch.weight_hh.grad.detach().numpy()),
    ("b_ih_numpy :\n", np.sum(grucell_numpy.bias_ih_grad_stack, axis=0)),
    ("b_ih_torch :\n", grucell_torch.bias_ih.grad.detach().numpy()),
    ("b_hh_numpy :\n", np.sum(grucell_numpy.bias_hh_grad_stack, axis=0)),
    ("b_hh_torch :\n", grucell_torch.bias_hh.grad.detach().numpy()),
)

for position, (caption, values) in enumerate(comparisons):
    # A separator precedes every NumPy/PyTorch pair; the first one is a title.
    if position % 2 == 0:
        print("--- 代码输出 ---" if position == 0 else "---------")
    print(caption, values)

"""
--- 代码输出 ---
out_numpy :
 [[ 0.537654  0.419409  0.334602  0.552652]
 [ 0.23298   0.546675  0.322881  0.331436]]
out_torch :
 [[ 0.537654  0.419409  0.334602  0.552652]
 [ 0.23298   0.546675  0.322881  0.331436]]
---------
dh_numpy :
 [[ 0.134685  0.262583  0.150232  0.011796]
 [ 0.29401   0.209619  0.38664   0.512479]]
dh_torch :
 [[ 0.134685  0.262583  0.150232  0.011796]
 [ 0.29401   0.209619  0.38664   0.512479]]
---------
dx_numpy :
 [[[ 0.100331  0.206714 -0.235714]
  [ 0.172891  0.255504 -0.155857]]]
dx_torch :
 [[ 0.100331  0.206714 -0.235714]
 [ 0.172891  0.255504 -0.155857]]
---------
w_ih_numpy :
 [[ 0.000184  0.002742  0.001377]
 [-0.063094 -0.039634 -0.027325]
 [-0.009026 -0.012194 -0.007132]
 [ 0.015227  0.016667  0.010103]
 [ 0.087729  0.052286  0.036599]
 [ 0.076196  0.052676  0.035376]
 [ 0.027772  0.029146  0.017808]
 [-0.084754 -0.101578 -0.060585]
 [ 0.247172  0.215229  0.136668]
 [ 0.432748  0.299664  0.20116 ]
 [ 0.239466  0.250673  0.153232]
 [ 0.122349  0.134346  0.081388]]
w_ih_torch :
 [[ 0.000184  0.002742  0.001377]
 [-0.063094 -0.039634 -0.027325]
 [-0.009026 -0.012194 -0.007132]
 [ 0.015227  0.016667  0.010103]
 [ 0.087729  0.052286  0.036599]
 [ 0.076196  0.052676  0.035376]
 [ 0.027772  0.029146  0.017808]
 [-0.084754 -0.101578 -0.060585]
 [ 0.247172  0.215229  0.136668]
 [ 0.432748  0.299664  0.20116 ]
 [ 0.239466  0.250673  0.153232]
 [ 0.122349  0.134346  0.081388]]
---------
w_hh_numpy :
 [[-0.002084  0.001193  0.00044  -0.001252]
 [-0.076799 -0.06724  -0.045178 -0.028548]
 [-0.005254 -0.012093 -0.007229 -0.000766]
 [ 0.012294  0.018921  0.011737  0.003278]
 [ 0.109266  0.092423  0.062486  0.041131]
 [ 0.088518  0.083027  0.055124  0.032029]
 [ 0.023524  0.034034  0.02126   0.006616]
 [-0.060689 -0.108654 -0.066362 -0.013765]
 [ 0.089854  0.108185  0.069158  0.028726]
 [ 0.262026  0.239856  0.159908  0.095748]
 [ 0.078294  0.109812  0.068845  0.022568]
 [ 0.071772  0.108447  0.067409  0.019455]]
w_hh_torch :
 [[-0.002084  0.001193  0.00044  -0.001252]
 [-0.076799 -0.06724  -0.045178 -0.028548]
 [-0.005254 -0.012093 -0.007229 -0.000766]
 [ 0.012294  0.018921  0.011737  0.003278]
 [ 0.109266  0.092423  0.062486  0.041131]
 [ 0.088518  0.083027  0.055124  0.032029]
 [ 0.023524  0.034034  0.02126   0.006616]
 [-0.060689 -0.108654 -0.066362 -0.013765]
 [ 0.089854  0.108185  0.069158  0.028726]
 [ 0.262026  0.239856  0.159908  0.095748]
 [ 0.078294  0.109812  0.068845  0.022568]
 [ 0.071772  0.108447  0.067409  0.019455]]
---------
b_ih_numpy :
 [ 0.001392 -0.096389 -0.016547  0.026264  0.13283   0.118438  0.047374 -0.149915
  0.402955  0.672871  0.408214  0.211218]
b_ih_torch :
 [ 0.001392 -0.096389 -0.016547  0.026264  0.13283   0.118438  0.047374 -0.149915
  0.402955  0.672871  0.408214  0.211218]
---------
b_hh_numpy :
 [ 0.001392 -0.096389 -0.016547  0.026264  0.13283   0.118438  0.047374 -0.149915
  0.151978  0.342735  0.153074  0.15066 ]
b_hh_torch :
 [ 0.001392 -0.096389 -0.016547  0.026264  0.13283   0.118438  0.047374 -0.149915
  0.151978  0.342735  0.153074  0.15066 ]
"""

3. GRU 测试

import torch
import numpy as np
from vanilla_nn.grucell import GRUCell

# Seed both generators so the random inputs and the PyTorch parameter
# initialisation repeat from run to run, and fix NumPy's print format.
np.random.seed(123)
torch.random.manual_seed(123)
np.set_printoptions(precision=6, suppress=True, linewidth=80)

# Reference: single-layer GRU, input size 4, hidden size 5, in double.
gru_torch = torch.nn.GRU(4, 5, 1).double()

# The first four entries of the layer-0 parameter list are passed, in order,
# as weight_ih, weight_hh, bias_ih and bias_hh of the NumPy cell.
gru_numpy = GRUCell(
    *(parameter.detach().numpy() for parameter in gru_torch.all_weights[0][:4])
)

# The first axis of x and dh indexes the time step (the loops below walk it).
x_numpy = np.random.random((3, 3, 4))
h_numpy = np.random.random((1, 3, 5))
dh_numpy = np.random.random((3, 3, 5))
x_torch = torch.tensor(x_numpy, requires_grad=True)
h_torch = torch.tensor(h_numpy, requires_grad=True)
dh_torch = torch.tensor(dh_numpy, requires_grad=True)

out_torch, hn_torch = gru_torch(x_torch, h_torch)

out_torch.backward(dh_torch)

# Unroll the NumPy cell over the sequence, feeding each output back in.
h0_numpy = h_numpy[0]
for x_step in x_numpy:
    h0_numpy = gru_numpy(x_step, h0_numpy)

# Back-propagate from the last time step to the first.
for dh_step in dh_numpy[::-1]:
    gru_numpy.backward(dh_step)

# (caption, value) rows; each NumPy row is followed by its PyTorch twin.
comparisons = (
    ("out_numpy :\n", np.array(gru_numpy.h_next_stack)),
    ("out_torch :\n", out_torch.detach().numpy()),
    ("dx_numpy :\n", np.array(gru_numpy.dx_list)),
    ("dx_torch :\n", x_torch.grad.detach().numpy()),
    ("dw_ih_numpy :\n", np.sum(gru_numpy.weight_ih_grad_stack, 0)),
    ("dw_ih_torch :\n", gru_torch.all_weights[0][0].grad.detach().numpy()),
    ("dw_hh_numpy :\n", np.sum(gru_numpy.weight_hh_grad_stack, 0)),
    ("dw_hh_torch :\n", gru_torch.all_weights[0][1].grad.detach().numpy()),
    ("db_ih_numpy :\n", np.sum(gru_numpy.bias_ih_grad_stack, 0)),
    ("db_ih_torch :\n", gru_torch.all_weights[0][2].grad.detach().numpy()),
    ("db_hh_numpy :\n", np.sum(gru_numpy.bias_hh_grad_stack, 0)),
    ("db_hh_torch :\n", gru_torch.all_weights[0][3].grad.detach().numpy()),
)

for position, (caption, values) in enumerate(comparisons):
    # The same title line precedes every NumPy/PyTorch pair.
    if position % 2 == 0:
        print("--- 代码输出 ---")
    print(caption, values)

"""
--- 代码输出 ---
out_numpy :
 [[[ 0.307578  0.38917   0.868305  0.226222  0.190614]
  [ 0.073207  0.063558  0.439481  0.389793 -0.115967]
  [ 0.342953  0.445099  0.494275  0.18942  -0.11399 ]]

 [[ 0.262178  0.124979  0.753652 -0.024015 -0.082316]
  [ 0.053777 -0.031063  0.358107  0.072536 -0.212627]
  [ 0.208464  0.104906  0.509116  0.026988 -0.162194]]

 [[ 0.141425 -0.034831  0.656582 -0.071536 -0.178199]
  [-0.001007 -0.050403  0.306474 -0.05571  -0.174536]
  [ 0.096458 -0.031054  0.448428 -0.061591 -0.166288]]]
out_torch :
 [[[ 0.307578  0.38917   0.868305  0.226222  0.190614]
  [ 0.073207  0.063558  0.439481  0.389793 -0.115967]
  [ 0.342953  0.445099  0.494275  0.18942  -0.11399 ]]

 [[ 0.262178  0.124979  0.753652 -0.024015 -0.082316]
  [ 0.053777 -0.031063  0.358107  0.072536 -0.212627]
  [ 0.208464  0.104906  0.509116  0.026988 -0.162194]]

 [[ 0.141425 -0.034831  0.656582 -0.071536 -0.178199]
  [-0.001007 -0.050403  0.306474 -0.05571  -0.174536]
  [ 0.096458 -0.031054  0.448428 -0.061591 -0.166288]]]
--- 代码输出 ---
dx_numpy :
 [[[-0.099959  0.125734 -0.100467  0.057584]
  [-0.175371  0.278181 -0.08639  -0.081519]
  [ 0.05688   0.215838  0.006216 -0.02745 ]]

 [[-0.030292  0.117966 -0.023817 -0.078948]
  [-0.041677  0.158772  0.007034  0.036202]
  [-0.028633  0.213837  0.008546 -0.06969 ]]

 [[-0.077594  0.189323 -0.042506 -0.126031]
  [-0.021397  0.159047  0.027051 -0.110323]
  [ 0.014817  0.163531  0.020699 -0.00104 ]]]
dx_torch :
 [[[-0.099959  0.125734 -0.100467  0.057584]
  [-0.175371  0.278181 -0.08639  -0.081519]
  [ 0.05688   0.215838  0.006216 -0.02745 ]]

 [[-0.030292  0.117966 -0.023817 -0.078948]
  [-0.041677  0.158772  0.007034  0.036202]
  [-0.028633  0.213837  0.008546 -0.06969 ]]

 [[-0.077594  0.189323 -0.042506 -0.126031]
  [-0.021397  0.159047  0.027051 -0.110323]
  [ 0.014817  0.163531  0.020699 -0.00104 ]]]
--- 代码输出 ---
dw_ih_numpy :
 [[ 0.054313  0.040754  0.046713  0.056132]
 [ 0.143938  0.086123  0.133868  0.142727]
 [-0.012153 -0.020432 -0.025273 -0.01912 ]
 [ 0.250502  0.158463  0.181569  0.266292]
 [ 0.089863  0.081379  0.079555  0.086468]
 [ 0.380143  0.288339  0.288571  0.369766]
 [ 0.421876  0.260856  0.32739   0.447516]
 [ 0.095548  0.049706  0.056802  0.111859]
 [ 0.518252  0.359941  0.455362  0.608289]
 [ 0.466701  0.236542  0.426777  0.473724]
 [ 1.355748  1.005534  1.19756   1.463865]
 [ 1.438667  0.910484  1.333382  1.454368]
 [ 0.813703  0.66155   0.821699  0.913758]
 [ 1.679467  1.182209  1.29286   1.748798]
 [ 1.473526  1.200475  1.428137  1.480017]]
dw_ih_torch :
 [[ 0.054313  0.040754  0.046713  0.056132]
 [ 0.143938  0.086123  0.133868  0.142727]
 [-0.012153 -0.020432 -0.025273 -0.01912 ]
 [ 0.250502  0.158463  0.181569  0.266292]
 [ 0.089863  0.081379  0.079555  0.086468]
 [ 0.380143  0.288339  0.288571  0.369766]
 [ 0.421876  0.260856  0.32739   0.447516]
 [ 0.095548  0.049706  0.056802  0.111859]
 [ 0.518252  0.359941  0.455362  0.608289]
 [ 0.466701  0.236542  0.426777  0.473724]
 [ 1.355748  1.005534  1.19756   1.463865]
 [ 1.438667  0.910484  1.333382  1.454368]
 [ 0.813703  0.66155   0.821699  0.913758]
 [ 1.679467  1.182209  1.29286   1.748798]
 [ 1.473526  1.200475  1.428137  1.480017]]
--- 代码输出 ---
dw_hh_numpy :
 [[ 0.029407  0.040854  0.067315  0.028967  0.006582]
 [ 0.065199  0.106626  0.161274  0.110149  0.040565]
 [-0.00217   0.002182 -0.009461 -0.011639  0.012655]
 [ 0.154114  0.269694  0.300844  0.181655  0.08886 ]
 [ 0.041277  0.04673   0.102476  0.038294 -0.006607]
 [ 0.22554   0.374378  0.425336  0.237448  0.09482 ]
 [ 0.251627  0.479232  0.472401  0.366764  0.17784 ]
 [ 0.066873  0.099892  0.162182  0.052397  0.031526]
 [ 0.313399  0.564707  0.579177  0.443881  0.153558]
 [ 0.212389  0.407104  0.504583  0.422809  0.217823]
 [ 0.224363  0.357882  0.464804  0.271478  0.07211 ]
 [ 0.242434  0.376752  0.565195  0.340365  0.105577]
 [ 0.20171   0.312777  0.443363  0.292979  0.040601]
 [ 0.442316  0.728961  0.869799  0.489859  0.181804]
 [ 0.333121  0.456453  0.815045  0.469363  0.037041]]
dw_hh_torch :
 [[ 0.029407  0.040854  0.067315  0.028967  0.006582]
 [ 0.065199  0.106626  0.161274  0.110149  0.040565]
 [-0.00217   0.002182 -0.009461 -0.011639  0.012655]
 [ 0.154114  0.269694  0.300844  0.181655  0.08886 ]
 [ 0.041277  0.04673   0.102476  0.038294 -0.006607]
 [ 0.22554   0.374378  0.425336  0.237448  0.09482 ]
 [ 0.251627  0.479232  0.472401  0.366764  0.17784 ]
 [ 0.066873  0.099892  0.162182  0.052397  0.031526]
 [ 0.313399  0.564707  0.579177  0.443881  0.153558]
 [ 0.212389  0.407104  0.504583  0.422809  0.217823]
 [ 0.224363  0.357882  0.464804  0.271478  0.07211 ]
 [ 0.242434  0.376752  0.565195  0.340365  0.105577]
 [ 0.20171   0.312777  0.443363  0.292979  0.040601]
 [ 0.442316  0.728961  0.869799  0.489859  0.181804]
 [ 0.333121  0.456453  0.815045  0.469363  0.037041]]
--- 代码输出 ---
db_ih_numpy :
 [ 0.106404  0.257179 -0.039206  0.442932  0.181347  0.660625  0.72062   0.218483
  0.946051  0.759723  2.604006  2.604498  1.66155   3.041845  2.817076]
db_ih_torch :
 [ 0.106404  0.257179 -0.039206  0.442932  0.181347  0.660625  0.72062   0.218483
  0.946051  0.759723  2.604006  2.604498  1.66155   3.041845  2.817076]
--- 代码输出 ---
db_hh_numpy :
 [ 0.106404  0.257179 -0.039206  0.442932  0.181347  0.660625  0.72062   0.218483
  0.946051  0.759723  0.763548  0.904376  0.79913   1.342985  1.467562]
db_hh_torch :
 [ 0.106404  0.257179 -0.039206  0.442932  0.181347  0.660625  0.72062   0.218483
  0.946051  0.759723  0.763548  0.904376  0.79913   1.342985  1.467562]
"""

你可能感兴趣的:(深度学习编程)