本文使用纯 Python 和 PyTorch 对比实现门控循环单元GRU及其反向传播.
原理和详细解释, 请参考 :
门控循环单元GRUCell详解及反向传播的梯度求导
文章索引 :
https://blog.csdn.net/oBrightLamp/article/details/85067981
文件目录 : vanilla_nn/grucell.py
import numpy as np
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` element-wise.

    The exponential is only ever evaluated at ``-|x|``, so it cannot
    overflow. The naive form calls ``exp`` on a large positive number for
    strongly negative inputs, which triggers a NumPy overflow warning.

    Args:
        x: Real scalar or array-like.

    Returns:
        A NumPy scalar for scalar input, otherwise an ndarray with the
        shape of ``x``.
    """
    x = np.asarray(x)
    decay = np.exp(-np.abs(x))  # always within [0, 1]
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x) -- same value, no overflow.
    result = np.where(x >= 0, 1 / (1 + decay), decay / (1 + decay))
    # Indexing with () unwraps 0-d arrays so scalar input yields a scalar.
    return result[()]
class GRUCell:
    """Gated recurrent unit cell with a hand-written backward pass (NumPy).

    The parameters are stacked along their first axis in the order
    reset gate, update gate, candidate ("cell") gate; each part spans
    ``hidden_size`` rows (weights) or entries (biases).

    Every forward call pushes its inputs and gate activations onto stacks and
    every ``backward`` call pops the most recent step, so backward calls must
    be issued in the reverse order of the forward calls. Parameter gradients
    are appended once per backward call to the ``*_grad_stack`` lists; summing
    a list gives the total gradient over all processed steps.
    ``h_next_stack`` keeps every forward output and is never popped.
    """

    def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
        """Store the parameters and create the empty bookkeeping containers.

        Args:
            weight_ih: Input-to-hidden weights, gate-stacked along axis 0.
            weight_hh: Hidden-to-hidden weights, gate-stacked along axis 0.
            bias_ih: Input-to-hidden bias, gate-stacked.
            bias_hh: Hidden-to-hidden bias, gate-stacked.
        """
        self.weight_ih = weight_ih
        self.weight_hh = weight_hh
        self.bias_ih = bias_ih
        self.bias_hh = bias_hh

        # Recurrent gradient carried from the previously back-propagated step.
        self.dh_prev = None

        # Per-step parameter gradients (one entry per backward call).
        self.weight_ih_grad_stack = []
        self.weight_hh_grad_stack = []
        self.bias_ih_grad_stack = []
        self.bias_hh_grad_stack = []

        # Forward caches (pushed in __call__) and backward results.
        self.x_stack = []
        self.dx_list = []
        self.h_prev_stack = []
        self.h_next_stack = []
        self.dh_prev_list = []
        self.reset_gate_stack = []
        self.update_gate_stack = []
        self.cell_gate_stack = []

    @staticmethod
    def _gate_slices(h_size):
        """Return the (reset, update, candidate) slices of a gate-stacked axis."""
        return (slice(0, h_size),
                slice(h_size, 2 * h_size),
                slice(2 * h_size, None))

    def __call__(self, x, h_prev):
        """Run one forward step and cache what ``backward`` needs.

        Args:
            x: Step input, 2-D with one sample per row.
            h_prev: Previous hidden state, 2-D with the hidden size on axis 1.

        Returns:
            The next hidden state
            ``(1 - update) * candidate + update * h_prev``.
        """
        xw_vector = np.dot(x, self.weight_ih.T) + self.bias_ih
        hv_vector = np.dot(h_prev, self.weight_hh.T) + self.bias_hh

        h_size = np.shape(h_prev)[1]
        r, u, c = self._gate_slices(h_size)

        reset_gate = sigmoid(xw_vector[:, r] + hv_vector[:, r])
        update_gate = sigmoid(xw_vector[:, u] + hv_vector[:, u])
        # The reset gate scales the whole hidden contribution, bias included.
        cell_gate = np.tanh(xw_vector[:, c] + hv_vector[:, c] * reset_gate)

        h_next = (1 - update_gate) * cell_gate + update_gate * h_prev

        self.x_stack.append(x)
        self.reset_gate_stack.append(reset_gate)
        self.update_gate_stack.append(update_gate)
        self.cell_gate_stack.append(cell_gate)
        self.h_prev_stack.append(h_prev)
        self.h_next_stack.append(h_next)

        # A new forward step clears the carried recurrent gradient, so the
        # first backward call after it only sees the gradient passed in.
        self.dh_prev = np.zeros_like(h_next)
        return h_next

    def backward(self, dh_next):
        """Back-propagate through the most recently cached forward step.

        The gradient carried over from the previous ``backward`` call is
        added to ``dh_next`` before it is propagated. The input gradient is
        prepended to ``dx_list`` and the hidden gradient to ``dh_prev_list``,
        so both lists end up in forward order. Parameter gradients for this
        step are appended to the ``*_grad_stack`` lists.

        Args:
            dh_next: Gradient with respect to this step's output hidden
                state, 2-D with the hidden size on axis 1.

        Returns:
            Gradient with respect to the step's incoming hidden state.
        """
        x = self.x_stack.pop()
        h_prev = self.h_prev_stack.pop()
        reset_gate = self.reset_gate_stack.pop()
        update_gate = self.update_gate_stack.pop()
        cell_gate = self.cell_gate_stack.pop()

        h_size = np.shape(dh_next)[1]
        r, u, c = self._gate_slices(h_size)
        wr, wu, wc = (self.weight_ih[part, :] for part in (r, u, c))
        vr, vu, vc = (self.weight_hh[part, :] for part in (r, u, c))
        bc = self.bias_hh[c]

        # Total gradient reaching h_next: direct part plus recurrent part.
        dh = dh_next + self.dh_prev

        d_update_gate = dh * (h_prev - cell_gate)
        d_cell_gate = dh * (1 - update_gate)

        # Gradients w.r.t. the pre-activations (sigmoid' and tanh').
        d_au = d_update_gate * update_gate * (1 - update_gate)
        d_ac = d_cell_gate * (1 - np.square(cell_gate))

        d_reset_gate = d_ac * (np.dot(h_prev, vc.T) + bc)
        d_ar = d_reset_gate * reset_gate * (1 - reset_gate)

        # Gradient flowing into the reset-scaled hidden term of the candidate.
        d_ac_reset = d_ac * reset_gate

        dh_prev = dh * update_gate + np.dot(d_ar, vr)
        dh_prev += np.dot(d_au, vu) + np.dot(d_ac_reset, vc)
        self.dh_prev_list.insert(0, dh_prev)
        self.dh_prev = dh_prev

        dx = np.dot(d_ar, wr) + np.dot(d_au, wu) + np.dot(d_ac, wc)
        self.dx_list.insert(0, dx)

        dw = np.vstack([np.dot(d_ar.T, x),
                        np.dot(d_au.T, x),
                        np.dot(d_ac.T, x)])
        self.weight_ih_grad_stack.append(dw)

        dv = np.vstack([np.dot(d_ar.T, h_prev),
                        np.dot(d_au.T, h_prev),
                        np.dot(d_ac_reset.T, h_prev)])
        self.weight_hh_grad_stack.append(dv)

        self.bias_ih_grad_stack.append(
            np.hstack([np.sum(d_ar, axis=0),
                       np.sum(d_au, axis=0),
                       np.sum(d_ac, axis=0)]))
        # The candidate's hidden bias sits inside the reset-gated term.
        self.bias_hh_grad_stack.append(
            np.hstack([np.sum(d_ar, axis=0),
                       np.sum(d_au, axis=0),
                       np.sum(d_ac_reset, axis=0)]))
        return dh_prev
import torch
import numpy as np
from vanilla_nn.grucell import GRUCell
# Single-step check: run the NumPy GRUCell and torch.nn.GRUCell on identical
# parameters and inputs, then print outputs and gradients side by side.
np.random.seed(123)
torch.random.manual_seed(123)
np.set_printoptions(precision=6, suppress=True, linewidth=80)

# Reference cell in float64; the NumPy cell is built from its parameters.
torch_cell = torch.nn.GRUCell(3, 4).double()
numpy_cell = GRUCell(
    *(getattr(torch_cell, param_name).data.numpy()
      for param_name in ("weight_ih", "weight_hh", "bias_ih", "bias_hh"))
)

# NumPy draws happen in the order input -> hidden state -> upstream gradient;
# keep that order so the seeded values stay the same.
x_np = np.random.random((2, 3))
x_t = torch.tensor(x_np, requires_grad=True)
h_np = np.random.random((2, 4))
h_t = torch.tensor(h_np, requires_grad=True)

h_next_np = numpy_cell(x_np, h_np)
h_next_t = torch_cell(x_t, h_t)

grad_out_np = np.random.random((2, 4))
grad_out_t = torch.tensor(grad_out_np, requires_grad=True)
dh_prev_np = numpy_cell.backward(grad_out_np)
h_next_t.backward(grad_out_t)

# (numpy label, numpy value, torch label, torch value) for every comparison.
comparisons = [
    ("out_numpy :\n", h_next_np,
     "out_torch :\n", h_next_t.data.numpy()),
    ("dh_numpy :\n", dh_prev_np,
     "dh_torch :\n", h_t.grad.data.numpy()),
    ("dx_numpy :\n", np.array(numpy_cell.dx_list),
     "dx_torch :\n", x_t.grad.data.numpy()),
    ("w_ih_numpy :\n", np.sum(numpy_cell.weight_ih_grad_stack, axis=0),
     "w_ih_torch :\n", torch_cell.weight_ih.grad.data.numpy()),
    ("w_hh_numpy :\n", np.sum(numpy_cell.weight_hh_grad_stack, axis=0),
     "w_hh_torch :\n", torch_cell.weight_hh.grad.data.numpy()),
    ("b_ih_numpy :\n", np.sum(numpy_cell.bias_ih_grad_stack, axis=0),
     "b_ih_torch :\n", torch_cell.bias_ih.grad.data.numpy()),
    ("b_hh_numpy :\n", np.sum(numpy_cell.bias_hh_grad_stack, axis=0),
     "b_hh_torch :\n", torch_cell.bias_hh.grad.data.numpy()),
]

print("--- 代码输出 ---")
for position, (np_label, np_value, torch_label, torch_value) in enumerate(comparisons):
    if position > 0:
        print("---------")  # separator between consecutive comparisons
    print(np_label, np_value)
    print(torch_label, torch_value)
"""
--- 代码输出 ---
out_numpy :
[[ 0.537654 0.419409 0.334602 0.552652]
[ 0.23298 0.546675 0.322881 0.331436]]
out_torch :
[[ 0.537654 0.419409 0.334602 0.552652]
[ 0.23298 0.546675 0.322881 0.331436]]
---------
dh_numpy :
[[ 0.134685 0.262583 0.150232 0.011796]
[ 0.29401 0.209619 0.38664 0.512479]]
dh_torch :
[[ 0.134685 0.262583 0.150232 0.011796]
[ 0.29401 0.209619 0.38664 0.512479]]
---------
dx_numpy :
[[[ 0.100331 0.206714 -0.235714]
[ 0.172891 0.255504 -0.155857]]]
dx_torch :
[[ 0.100331 0.206714 -0.235714]
[ 0.172891 0.255504 -0.155857]]
---------
w_ih_numpy :
[[ 0.000184 0.002742 0.001377]
[-0.063094 -0.039634 -0.027325]
[-0.009026 -0.012194 -0.007132]
[ 0.015227 0.016667 0.010103]
[ 0.087729 0.052286 0.036599]
[ 0.076196 0.052676 0.035376]
[ 0.027772 0.029146 0.017808]
[-0.084754 -0.101578 -0.060585]
[ 0.247172 0.215229 0.136668]
[ 0.432748 0.299664 0.20116 ]
[ 0.239466 0.250673 0.153232]
[ 0.122349 0.134346 0.081388]]
w_ih_torch :
[[ 0.000184 0.002742 0.001377]
[-0.063094 -0.039634 -0.027325]
[-0.009026 -0.012194 -0.007132]
[ 0.015227 0.016667 0.010103]
[ 0.087729 0.052286 0.036599]
[ 0.076196 0.052676 0.035376]
[ 0.027772 0.029146 0.017808]
[-0.084754 -0.101578 -0.060585]
[ 0.247172 0.215229 0.136668]
[ 0.432748 0.299664 0.20116 ]
[ 0.239466 0.250673 0.153232]
[ 0.122349 0.134346 0.081388]]
---------
w_hh_numpy :
[[-0.002084 0.001193 0.00044 -0.001252]
[-0.076799 -0.06724 -0.045178 -0.028548]
[-0.005254 -0.012093 -0.007229 -0.000766]
[ 0.012294 0.018921 0.011737 0.003278]
[ 0.109266 0.092423 0.062486 0.041131]
[ 0.088518 0.083027 0.055124 0.032029]
[ 0.023524 0.034034 0.02126 0.006616]
[-0.060689 -0.108654 -0.066362 -0.013765]
[ 0.089854 0.108185 0.069158 0.028726]
[ 0.262026 0.239856 0.159908 0.095748]
[ 0.078294 0.109812 0.068845 0.022568]
[ 0.071772 0.108447 0.067409 0.019455]]
w_hh_torch :
[[-0.002084 0.001193 0.00044 -0.001252]
[-0.076799 -0.06724 -0.045178 -0.028548]
[-0.005254 -0.012093 -0.007229 -0.000766]
[ 0.012294 0.018921 0.011737 0.003278]
[ 0.109266 0.092423 0.062486 0.041131]
[ 0.088518 0.083027 0.055124 0.032029]
[ 0.023524 0.034034 0.02126 0.006616]
[-0.060689 -0.108654 -0.066362 -0.013765]
[ 0.089854 0.108185 0.069158 0.028726]
[ 0.262026 0.239856 0.159908 0.095748]
[ 0.078294 0.109812 0.068845 0.022568]
[ 0.071772 0.108447 0.067409 0.019455]]
---------
b_ih_numpy :
[ 0.001392 -0.096389 -0.016547 0.026264 0.13283 0.118438 0.047374 -0.149915
0.402955 0.672871 0.408214 0.211218]
b_ih_torch :
[ 0.001392 -0.096389 -0.016547 0.026264 0.13283 0.118438 0.047374 -0.149915
0.402955 0.672871 0.408214 0.211218]
---------
b_hh_numpy :
[ 0.001392 -0.096389 -0.016547 0.026264 0.13283 0.118438 0.047374 -0.149915
0.151978 0.342735 0.153074 0.15066 ]
b_hh_torch :
[ 0.001392 -0.096389 -0.016547 0.026264 0.13283 0.118438 0.047374 -0.149915
0.151978 0.342735 0.153074 0.15066 ]
"""
import torch
import numpy as np
from vanilla_nn.grucell import GRUCell
# Sequence check: unroll the NumPy GRUCell over time and compare outputs and
# gradients with a single-layer torch.nn.GRU sharing the same parameters.
np.random.seed(123)
torch.random.manual_seed(123)
np.set_printoptions(precision=6, suppress=True, linewidth=80)

gru_torch = torch.nn.GRU(4, 5, 1).double()
# Layer 0 parameters in the order the script passes them to GRUCell.
w_ih, w_hh, b_ih, b_hh = gru_torch.all_weights[0][:4]
gru_numpy = GRUCell(w_ih.data.numpy(),
                    w_hh.data.numpy(),
                    b_ih.data.numpy(),
                    b_hh.data.numpy())

# Draw order (input, initial state, upstream gradient) fixes the seeded values.
x_numpy = np.random.random((3, 3, 4))
x_torch = torch.tensor(x_numpy, requires_grad=True)
h_numpy = np.random.random((1, 3, 5))
h_torch = torch.tensor(h_numpy, requires_grad=True)
dh_numpy = np.random.random((3, 3, 5))
dh_torch = torch.tensor(dh_numpy, requires_grad=True)

out_torch, hn_torch = gru_torch(x_torch, h_torch)
out_torch.backward(dh_torch)

# Forward through time: iterate over the leading (time) axis of the input
# instead of a hard-coded step count.
h_step_numpy = h_numpy[0]
for x_step in x_numpy:
    h_step_numpy = gru_numpy(x_step, h_step_numpy)

# Backward through time: feed the per-step output gradients last step first.
for dh_step in dh_numpy[::-1]:
    gru_numpy.backward(dh_step)

# (label prefix, numpy value, torch value) for every comparison.
report = [
    ("out", np.array(gru_numpy.h_next_stack), out_torch.data.numpy()),
    ("dx", np.array(gru_numpy.dx_list), x_torch.grad.data.numpy()),
    ("dw_ih", np.sum(gru_numpy.weight_ih_grad_stack, 0), w_ih.grad.data.numpy()),
    ("dw_hh", np.sum(gru_numpy.weight_hh_grad_stack, 0), w_hh.grad.data.numpy()),
    ("db_ih", np.sum(gru_numpy.bias_ih_grad_stack, 0), b_ih.grad.data.numpy()),
    ("db_hh", np.sum(gru_numpy.bias_hh_grad_stack, 0), b_hh.grad.data.numpy()),
]
for prefix, numpy_value, torch_value in report:
    print("--- 代码输出 ---")
    print(f"{prefix}_numpy :\n", numpy_value)
    print(f"{prefix}_torch :\n", torch_value)
"""
--- 代码输出 ---
out_numpy :
[[[ 0.307578 0.38917 0.868305 0.226222 0.190614]
[ 0.073207 0.063558 0.439481 0.389793 -0.115967]
[ 0.342953 0.445099 0.494275 0.18942 -0.11399 ]]
[[ 0.262178 0.124979 0.753652 -0.024015 -0.082316]
[ 0.053777 -0.031063 0.358107 0.072536 -0.212627]
[ 0.208464 0.104906 0.509116 0.026988 -0.162194]]
[[ 0.141425 -0.034831 0.656582 -0.071536 -0.178199]
[-0.001007 -0.050403 0.306474 -0.05571 -0.174536]
[ 0.096458 -0.031054 0.448428 -0.061591 -0.166288]]]
out_torch :
[[[ 0.307578 0.38917 0.868305 0.226222 0.190614]
[ 0.073207 0.063558 0.439481 0.389793 -0.115967]
[ 0.342953 0.445099 0.494275 0.18942 -0.11399 ]]
[[ 0.262178 0.124979 0.753652 -0.024015 -0.082316]
[ 0.053777 -0.031063 0.358107 0.072536 -0.212627]
[ 0.208464 0.104906 0.509116 0.026988 -0.162194]]
[[ 0.141425 -0.034831 0.656582 -0.071536 -0.178199]
[-0.001007 -0.050403 0.306474 -0.05571 -0.174536]
[ 0.096458 -0.031054 0.448428 -0.061591 -0.166288]]]
--- 代码输出 ---
dx_numpy :
[[[-0.099959 0.125734 -0.100467 0.057584]
[-0.175371 0.278181 -0.08639 -0.081519]
[ 0.05688 0.215838 0.006216 -0.02745 ]]
[[-0.030292 0.117966 -0.023817 -0.078948]
[-0.041677 0.158772 0.007034 0.036202]
[-0.028633 0.213837 0.008546 -0.06969 ]]
[[-0.077594 0.189323 -0.042506 -0.126031]
[-0.021397 0.159047 0.027051 -0.110323]
[ 0.014817 0.163531 0.020699 -0.00104 ]]]
dx_torch :
[[[-0.099959 0.125734 -0.100467 0.057584]
[-0.175371 0.278181 -0.08639 -0.081519]
[ 0.05688 0.215838 0.006216 -0.02745 ]]
[[-0.030292 0.117966 -0.023817 -0.078948]
[-0.041677 0.158772 0.007034 0.036202]
[-0.028633 0.213837 0.008546 -0.06969 ]]
[[-0.077594 0.189323 -0.042506 -0.126031]
[-0.021397 0.159047 0.027051 -0.110323]
[ 0.014817 0.163531 0.020699 -0.00104 ]]]
--- 代码输出 ---
dw_ih_numpy :
[[ 0.054313 0.040754 0.046713 0.056132]
[ 0.143938 0.086123 0.133868 0.142727]
[-0.012153 -0.020432 -0.025273 -0.01912 ]
[ 0.250502 0.158463 0.181569 0.266292]
[ 0.089863 0.081379 0.079555 0.086468]
[ 0.380143 0.288339 0.288571 0.369766]
[ 0.421876 0.260856 0.32739 0.447516]
[ 0.095548 0.049706 0.056802 0.111859]
[ 0.518252 0.359941 0.455362 0.608289]
[ 0.466701 0.236542 0.426777 0.473724]
[ 1.355748 1.005534 1.19756 1.463865]
[ 1.438667 0.910484 1.333382 1.454368]
[ 0.813703 0.66155 0.821699 0.913758]
[ 1.679467 1.182209 1.29286 1.748798]
[ 1.473526 1.200475 1.428137 1.480017]]
dw_ih_torch :
[[ 0.054313 0.040754 0.046713 0.056132]
[ 0.143938 0.086123 0.133868 0.142727]
[-0.012153 -0.020432 -0.025273 -0.01912 ]
[ 0.250502 0.158463 0.181569 0.266292]
[ 0.089863 0.081379 0.079555 0.086468]
[ 0.380143 0.288339 0.288571 0.369766]
[ 0.421876 0.260856 0.32739 0.447516]
[ 0.095548 0.049706 0.056802 0.111859]
[ 0.518252 0.359941 0.455362 0.608289]
[ 0.466701 0.236542 0.426777 0.473724]
[ 1.355748 1.005534 1.19756 1.463865]
[ 1.438667 0.910484 1.333382 1.454368]
[ 0.813703 0.66155 0.821699 0.913758]
[ 1.679467 1.182209 1.29286 1.748798]
[ 1.473526 1.200475 1.428137 1.480017]]
--- 代码输出 ---
dw_hh_numpy :
[[ 0.029407 0.040854 0.067315 0.028967 0.006582]
[ 0.065199 0.106626 0.161274 0.110149 0.040565]
[-0.00217 0.002182 -0.009461 -0.011639 0.012655]
[ 0.154114 0.269694 0.300844 0.181655 0.08886 ]
[ 0.041277 0.04673 0.102476 0.038294 -0.006607]
[ 0.22554 0.374378 0.425336 0.237448 0.09482 ]
[ 0.251627 0.479232 0.472401 0.366764 0.17784 ]
[ 0.066873 0.099892 0.162182 0.052397 0.031526]
[ 0.313399 0.564707 0.579177 0.443881 0.153558]
[ 0.212389 0.407104 0.504583 0.422809 0.217823]
[ 0.224363 0.357882 0.464804 0.271478 0.07211 ]
[ 0.242434 0.376752 0.565195 0.340365 0.105577]
[ 0.20171 0.312777 0.443363 0.292979 0.040601]
[ 0.442316 0.728961 0.869799 0.489859 0.181804]
[ 0.333121 0.456453 0.815045 0.469363 0.037041]]
dw_hh_torch :
[[ 0.029407 0.040854 0.067315 0.028967 0.006582]
[ 0.065199 0.106626 0.161274 0.110149 0.040565]
[-0.00217 0.002182 -0.009461 -0.011639 0.012655]
[ 0.154114 0.269694 0.300844 0.181655 0.08886 ]
[ 0.041277 0.04673 0.102476 0.038294 -0.006607]
[ 0.22554 0.374378 0.425336 0.237448 0.09482 ]
[ 0.251627 0.479232 0.472401 0.366764 0.17784 ]
[ 0.066873 0.099892 0.162182 0.052397 0.031526]
[ 0.313399 0.564707 0.579177 0.443881 0.153558]
[ 0.212389 0.407104 0.504583 0.422809 0.217823]
[ 0.224363 0.357882 0.464804 0.271478 0.07211 ]
[ 0.242434 0.376752 0.565195 0.340365 0.105577]
[ 0.20171 0.312777 0.443363 0.292979 0.040601]
[ 0.442316 0.728961 0.869799 0.489859 0.181804]
[ 0.333121 0.456453 0.815045 0.469363 0.037041]]
--- 代码输出 ---
db_ih_numpy :
[ 0.106404 0.257179 -0.039206 0.442932 0.181347 0.660625 0.72062 0.218483
0.946051 0.759723 2.604006 2.604498 1.66155 3.041845 2.817076]
db_ih_torch :
[ 0.106404 0.257179 -0.039206 0.442932 0.181347 0.660625 0.72062 0.218483
0.946051 0.759723 2.604006 2.604498 1.66155 3.041845 2.817076]
--- 代码输出 ---
db_hh_numpy :
[ 0.106404 0.257179 -0.039206 0.442932 0.181347 0.660625 0.72062 0.218483
0.946051 0.759723 0.763548 0.904376 0.79913 1.342985 1.467562]
db_hh_torch :
[ 0.106404 0.257179 -0.039206 0.442932 0.181347 0.660625 0.72062 0.218483
0.946051 0.759723 0.763548 0.904376 0.79913 1.342985 1.467562]
"""