Neural networks can be constructed with the torch.nn package. nn depends on the autograd package to define models and differentiate them. An nn.Module contains layers and a forward(input) method that returns the output.
The network below is a simple feed-forward network: it takes an input, feeds it through one layer after another, and finally outputs the result.
A typical training procedure for a neural network is as follows:
1. Define a neural network model with some learnable parameters (or weights);
2. Iterate over a dataset of inputs;
3. Process the input through the network;
4. Compute the loss (how far the output is from the correct value);
5. Propagate gradients back into the network's parameters;
6. Update the weights of the network, typically using a simple update rule: weight = weight - learning_rate * gradient
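The six steps above can be sketched as a minimal loop. The model, toy dataset, and hyper-parameter values below are placeholders for illustration, not part of the tutorial's own code:

```python
import torch
import torch.nn as nn

# A tiny stand-in model; any nn.Module works the same way.  (step 1)
model = nn.Linear(4, 2)
criterion = nn.MSELoss()
learning_rate = 0.01

# A toy "dataset" of (input, target) pairs.
dataset = [(torch.randn(1, 4), torch.randn(1, 2)) for _ in range(3)]

for x, target in dataset:             # step 2: iterate over the dataset
    output = model(x)                 # step 3: process input through the network
    loss = criterion(output, target)  # step 4: compute the loss
    model.zero_grad()                 # clear old gradients first
    loss.backward()                   # step 5: backpropagate gradients
    with torch.no_grad():             # step 6: weight = weight - lr * gradient
        for p in model.parameters():
            p -= learning_rate * p.grad
```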
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        # Run the parent class's (nn.Module's) initializer first
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        # 16 * 5 * 5: the 32x32 input becomes 16 (channels) x 5 (height) x 5 (width)
        # after two rounds of convolution and pooling; once view() flattens it,
        # the first fully connected layer therefore receives 16 * 5 * 5 inputs
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    # Define the forward pass. It must be defined; once it is, the backward
    # pass is generated automatically (by autograd).
    def forward(self, x):
        # Pass x through conv1, apply ReLU, then max-pool over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # Pass x through conv2, apply ReLU, then max-pool again.
        # If the pooling window is square, you can specify a single number.
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten x into a vector (the total number of features is unchanged)
        # in preparation for the fully connected layers
        x = x.view(-1, self.num_flat_features(x))
        # fc1 followed by ReLU
        x = F.relu(self.fc1(x))
        # fc2 followed by ReLU
        x = F.relu(self.fc2(x))
        # fc3 produces the output
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
net = Net()
print(net)
Net(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
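The in_features=400 of fc1 follows from the size arithmetic hinted at in the code comments: a 32×32 input shrinks to 28×28 after the first 5×5 convolution, to 14×14 after 2×2 pooling, to 10×10 after the second convolution, and to 5×5 after the second pooling, giving 16 × 5 × 5 = 400 features. A quick check (the helper names are illustrative):

```python
def conv_out(size, kernel):
    # valid convolution, stride 1, no padding
    return size - kernel + 1

def pool_out(size, window):
    # non-overlapping max pooling
    return size // window

s = 32
s = pool_out(conv_out(s, 5), 2)  # conv1 + pool: 32 -> 28 -> 14
s = pool_out(conv_out(s, 5), 2)  # conv2 + pool: 14 -> 10 -> 5
print(16 * s * s)                # 400, the input size of fc1
```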
net.parameters()
returns a list of the model's learnable parameters (weights) and their values. Here it has 10 entries: a weight and a bias for each of conv1, conv2, fc1, fc2, and fc3.
params = list(net.parameters())
print(len(params))
print(params[0].size()) # conv1's .weight
10
torch.Size([6, 1, 5, 5])
To inspect the network's parameters in detail:
print(list(net.named_parameters()))
[('conv1.weight', Parameter containing:
tensor([[[[-0.0864, 0.1581, 0.1053, -0.0674, 0.0957],
[ 0.0101, -0.1520, -0.0849, 0.1507, 0.1818],
[ 0.0768, 0.0109, -0.0293, 0.0672, 0.0463],
[-0.1766, 0.0446, 0.1647, 0.1359, -0.1144],
[ 0.1169, 0.1585, -0.1872, 0.1876, -0.1431]]],
[[[-0.0481, -0.0338, -0.0865, 0.1113, 0.0281],
[-0.0110, -0.1859, 0.0683, -0.1943, 0.1785],
[-0.0686, 0.0450, 0.1983, -0.0679, -0.1770],
[ 0.1880, 0.0102, -0.0149, 0.1218, -0.1315],
[-0.1027, 0.0082, -0.0568, 0.1253, -0.0648]]],
[[[-0.1574, -0.1198, 0.1627, -0.1704, 0.0758],
[ 0.0399, 0.0644, -0.1151, -0.1112, 0.0837],
[-0.0485, 0.0777, 0.0932, 0.0524, 0.0491],
[-0.0652, 0.0403, -0.0345, -0.1912, 0.0434],
[ 0.0401, -0.1806, 0.1038, -0.0035, 0.1026]]],
[[[ 0.0366, -0.1110, -0.1358, -0.0388, -0.0136],
[ 0.0998, 0.0714, 0.0491, -0.1141, -0.1639],
[-0.0606, 0.1912, 0.0887, -0.1987, 0.1686],
[-0.0851, -0.0431, 0.0035, 0.0634, 0.1633],
[-0.1842, -0.1439, -0.1178, 0.1330, 0.1844]]],
[[[ 0.1000, 0.1853, 0.0840, 0.0314, 0.1282],
[ 0.0399, 0.1136, 0.1043, -0.1371, 0.0194],
[-0.0125, -0.0674, 0.1311, 0.0281, -0.0820],
[-0.0661, 0.0128, -0.1913, -0.0224, 0.0260],
[ 0.0750, 0.0270, 0.1146, 0.1069, 0.0046]]],
[[[ 0.0444, 0.1200, -0.1469, 0.1303, -0.0595],
[-0.1066, -0.1726, 0.1489, -0.0327, -0.0576],
[ 0.1739, 0.1475, 0.1767, 0.0594, 0.1139],
[ 0.0412, -0.1624, -0.1075, -0.1895, -0.1300],
[-0.1598, -0.1082, -0.0381, -0.0411, 0.0309]]]], requires_grad=True)), ('conv1.bias', Parameter containing:
tensor([ 0.2000, 0.1894, 0.1787, 0.0921, -0.0161, 0.1778],
requires_grad=True)), ('conv2.weight', Parameter containing:
tensor([[[[-4.6751e-02, 6.2027e-02, 6.8025e-02, 6.2829e-02, 5.9207e-02],
[ 3.4858e-02, -4.6945e-02, 3.6028e-03, 9.2037e-03, -5.0105e-02],
[ 1.9097e-02, -5.0400e-02, 6.0714e-02, -2.4261e-02, 2.4184e-02],
[ 6.4232e-03, 4.2253e-02, 3.8566e-02, 7.8425e-02, -4.7642e-02],
[-1.6932e-02, 7.0821e-02, -7.0269e-02, 4.5952e-02, 4.7312e-02]],
[[ 2.4947e-02, -5.6880e-02, -5.0089e-02, -1.8905e-02, -5.6852e-02],
[-2.1401e-02, 7.6649e-02, 5.6511e-02, 2.9109e-02, -4.3187e-02],
[-3.8747e-02, -5.5686e-02, -6.8277e-02, 4.3523e-02, -3.1327e-02],
[ 4.6242e-02, -6.7128e-02, 2.1462e-02, 2.1463e-02, -5.7259e-02],
[ 8.5649e-03, 4.0587e-02, 7.0886e-02, 3.4475e-02, 4.8482e-02]],
[[ 3.4679e-02, -5.7377e-02, -3.8839e-02, 5.0153e-02, 4.7679e-02],
[-6.6696e-02, -5.2049e-02, 6.1444e-03, -1.0752e-02, -3.8693e-02],
[-6.1365e-02, 6.7706e-02, -8.0698e-02, -2.0662e-02, -6.6051e-02],
[-6.0435e-02, -4.1243e-02, 7.0561e-02, -7.8558e-02, -7.2256e-02],
[-2.9384e-02, -5.9898e-02, 1.4626e-02, 1.2504e-02, -1.6448e-03]],
[[-8.3191e-03, -2.4209e-02, -2.4460e-02, -3.7102e-02, -2.8801e-02],
[ 7.0775e-02, -7.7456e-02, 1.6669e-02, -5.7365e-02, 5.1116e-02],
[-2.2186e-03, 7.1819e-02, 1.4087e-02, 1.2796e-02, -8.0466e-02],
[-2.2954e-02, 1.7675e-02, -6.7700e-02, 4.2731e-02, -4.5965e-02],
[ 1.9493e-02, 2.8682e-02, -3.5744e-03, -6.2132e-02, -1.6093e-02]],
[[-4.6586e-02, -5.1907e-03, -2.0616e-02, -5.0160e-03, -1.0218e-02],
[ 3.0281e-02, 1.0599e-02, -4.4635e-02, -6.2386e-03, 1.2839e-02],
[ 6.3183e-02, -4.7749e-02, -1.5905e-02, 3.7993e-03, 6.9684e-03],
[-4.3931e-02, 1.2768e-02, 8.1540e-02, 4.5033e-02, -4.2173e-02],
[ 5.9427e-02, -4.7729e-02, 4.9988e-02, 5.0745e-02, 6.8716e-02]],
[[ 3.1921e-02, 7.3147e-02, 8.1235e-02, 4.7650e-02, 1.7855e-02],
[ 4.9684e-02, 4.8450e-02, 2.7255e-02, -8.2837e-03, -7.4009e-02],
[ 6.2119e-02, -7.7966e-02, -6.6083e-02, 5.2162e-02, 2.8479e-02],
[ 7.4365e-02, -2.8491e-02, 9.6219e-03, 1.9529e-02, -2.1051e-03],
[ 6.5328e-02, 3.6120e-02, -2.6581e-02, 6.8440e-02, -1.8014e-02]]],
[[[ 2.6748e-02, 5.5181e-02, 2.2528e-02, 6.6341e-03, 5.5057e-02],
[ 3.3438e-02, -3.9520e-02, -3.8108e-02, -4.8951e-03, 4.3503e-02],
[-1.7223e-02, -6.1210e-02, -1.6332e-02, 7.2617e-02, 1.2141e-02],
[ 1.5687e-02, -7.5238e-02, -3.2138e-02, -6.2180e-03, 7.4173e-02],
[ 1.3187e-02, 3.0043e-02, -4.9262e-02, 1.0382e-02, -5.2707e-02]],
[[-6.1639e-02, -1.6759e-03, 9.1899e-03, -2.0154e-02, 4.7649e-02],
[ 4.4110e-02, -4.4745e-02, -6.4256e-02, -5.3946e-02, 1.4076e-02],
[ 2.5333e-02, 6.0833e-02, -3.7473e-03, -3.7682e-02, -7.7613e-03],
[ 1.2864e-02, 6.8993e-02, -2.3292e-02, -3.1719e-02, 1.6382e-02],
[ 7.3083e-02, 7.3863e-02, -2.6046e-02, -4.9677e-02, 1.5290e-02]],
[[ 7.4857e-02, -5.2538e-02, -8.0590e-02, -6.6272e-02, -6.3306e-02],
[-4.3562e-02, -6.7632e-02, -7.2244e-02, 6.6524e-02, 3.8218e-02],
[ 2.9765e-02, -1.5771e-02, -1.2642e-02, -6.1716e-02, -6.1412e-02],
[-1.2431e-02, 7.5718e-02, -4.0329e-02, 2.4516e-02, -7.6501e-02],
[-2.5373e-02, -7.6981e-02, 4.2991e-02, -5.8949e-02, -8.1620e-02]],
[[-2.8247e-04, -4.3295e-02, -7.0514e-04, -6.0201e-02, -9.4066e-03],
[ 2.6099e-02, -6.8910e-02, -5.6376e-02, -7.7094e-02, -3.9688e-02],
[ 2.1004e-02, 3.8341e-02, -4.4436e-03, 4.8700e-02, 1.1442e-02],
[-2.7794e-02, -2.9512e-02, -8.3854e-03, 2.5953e-02, -5.8595e-02],
[-9.8773e-03, 3.8318e-02, -6.4992e-02, -6.4106e-02, 2.1669e-02]],
[[-5.6786e-02, -2.3078e-02, -6.3540e-02, 6.4722e-02, -3.6102e-03],
[ 2.6318e-03, -5.8278e-02, -7.0678e-02, 6.4511e-02, 4.9434e-02],
[-1.7529e-02, -3.6845e-02, -4.1389e-02, -1.7150e-02, -1.7080e-02],
[-5.1267e-02, -5.4333e-02, 3.7604e-02, 2.5120e-02, 4.5091e-03],
[ 6.0331e-02, 1.0586e-02, -4.9640e-03, -4.0446e-02, -5.2631e-02]],
[[ 4.2650e-02, 8.0488e-02, 2.2548e-02, 2.5217e-02, 3.9701e-02],
[ 2.4243e-02, 3.7999e-02, -2.1265e-02, -4.0239e-02, 6.2443e-02],
[-6.5399e-02, -4.2270e-03, 7.1457e-02, 7.6897e-02, 7.7068e-03],
[ 3.0916e-02, -1.0364e-02, 1.8940e-02, -5.4591e-02, 3.6602e-02],
[ 7.2685e-02, -5.0956e-02, -7.2337e-02, 7.0943e-02, 7.7663e-02]]],
[[[ 6.8438e-02, -4.1001e-02, 4.9182e-02, -1.0405e-02, 5.2470e-02],
[ 5.1253e-03, 3.2332e-02, 5.4736e-03, 7.6816e-02, -2.4995e-02],
[-1.5874e-02, -6.0776e-02, 6.4105e-03, 3.5231e-02, 2.4181e-03],
[ 4.9819e-03, 8.2966e-03, 4.6681e-02, 6.0710e-02, 2.3426e-02],
[ 3.0622e-03, -4.2392e-02, 7.9669e-02, -6.0518e-02, -7.7647e-02]],
[[ 5.6804e-02, 3.2624e-02, 2.6365e-02, 1.9988e-02, 5.0879e-02],
[ 5.7976e-03, 5.6419e-02, -6.0317e-02, -6.3418e-02, 7.9154e-02],
[ 5.4380e-02, 5.0265e-02, -3.4530e-02, -3.4289e-02, -5.7359e-02],
[-2.0039e-02, -1.1091e-02, 6.2006e-02, 4.2279e-02, 1.3524e-02],
[-5.1808e-02, 1.6028e-02, -1.3753e-02, -4.4715e-02, 6.4370e-02]],
[[-7.5229e-02, -3.1394e-02, -5.7542e-02, -2.5208e-03, 3.8752e-02],
[-3.8169e-02, -5.5508e-02, 4.4447e-03, 1.3224e-02, -9.8716e-03],
[ 5.8168e-02, -3.0163e-02, -1.7469e-02, -6.4066e-02, 5.9589e-02],
[ 1.6919e-02, 2.7378e-02, 2.6752e-02, -2.9841e-03, 6.6029e-03],
[ 4.5793e-02, -1.0817e-02, 5.7493e-02, -1.9364e-02, 6.9119e-02]],
[[-1.7484e-02, -7.0515e-02, -2.1695e-02, 2.5091e-02, 2.9314e-02],
[-3.3494e-02, 8.0373e-02, -5.5042e-02, 5.6274e-02, -6.3458e-02],
[-7.4188e-02, -7.3826e-02, 6.3572e-03, 7.2682e-02, -4.4463e-02],
[-3.3026e-02, -5.7821e-02, 5.5200e-02, -5.6740e-02, 7.5848e-02],
[-1.8855e-02, 5.5604e-03, 2.7479e-02, 3.2113e-02, -7.1571e-02]],
[[-5.9697e-02, 5.5090e-02, 3.4306e-02, -2.8067e-02, 1.0584e-02],
[ 1.9661e-03, -8.2800e-03, 4.1372e-02, 7.3996e-02, -2.4893e-02],
[ 5.5144e-02, 4.6595e-03, 5.8500e-02, -5.4630e-02, -4.6803e-04],
[-1.2829e-02, -7.8442e-02, -5.3737e-02, 4.0796e-02, 6.9439e-02],
[-7.3417e-02, 6.3357e-02, 2.7502e-02, -3.9882e-02, 6.1540e-03]],
[[ 3.7075e-02, 6.7935e-02, 2.5044e-02, -4.7587e-02, 4.8540e-02],
[ 5.1211e-02, -1.1007e-02, 7.1380e-02, 7.2305e-02, -4.2107e-02],
[ 2.6421e-02, 2.0946e-02, -4.6829e-02, 2.6259e-02, -6.5568e-02],
[-4.4335e-02, -8.0355e-02, -2.2185e-02, 8.1275e-02, -4.5410e-02],
[-7.9133e-02, 7.0272e-03, -5.9176e-02, -7.3026e-02, 6.6735e-02]]],
...,
[[[ 4.7453e-02, 1.6624e-03, -4.9159e-02, -6.3126e-02, 7.6058e-02],
[-7.6433e-02, 6.6995e-02, 7.6168e-02, 2.9466e-02, -4.2630e-02],
[-8.1500e-02, 1.6365e-02, 4.9923e-02, 9.0797e-03, -7.3443e-02],
[-9.7154e-03, -3.5494e-02, 7.9512e-02, -3.8434e-02, -1.7049e-02],
[ 3.3263e-02, -4.0653e-02, 1.1763e-02, 1.1847e-02, 1.3137e-02]],
[[-7.0701e-02, 5.2604e-02, -2.0280e-02, -5.1948e-02, 6.1880e-02],
[ 5.0646e-02, -6.0460e-02, 2.2518e-02, 5.4762e-02, 2.6427e-04],
[ 5.5895e-02, -3.1711e-02, -7.7939e-02, -3.7241e-02, 5.3921e-02],
[ 8.0110e-02, 3.4243e-02, -3.8894e-02, 4.3316e-02, -6.2829e-02],
[-2.0985e-02, 2.0693e-02, -4.0446e-02, 5.4924e-02, -7.0742e-02]],
[[ 5.8691e-02, 5.7382e-02, 6.1324e-03, 7.0543e-02, 8.1322e-02],
[ 1.7258e-02, -8.9442e-03, 5.5113e-02, -1.6722e-03, -3.3769e-02],
[-1.2999e-02, -3.7392e-02, 5.2130e-02, -6.1703e-02, 2.5096e-02],
[ 2.4746e-02, 1.8096e-02, -5.4290e-03, 5.6583e-03, -3.2907e-02],
[ 2.8275e-02, -2.2063e-03, 5.2794e-02, 4.5287e-02, 1.2212e-02]],
[[-6.6913e-02, -1.5204e-02, -2.6589e-02, -3.9329e-02, -1.7749e-02],
[ 5.4253e-02, -4.8495e-02, 1.3793e-02, 6.2208e-02, 3.4392e-02],
[-3.9233e-02, 1.6612e-02, -1.9025e-02, -4.1305e-02, -4.1636e-02],
[ 6.3880e-03, -7.3584e-02, -7.6999e-02, 3.9100e-02, 6.4777e-02],
[ 6.0852e-02, -1.7153e-02, 6.4263e-02, -6.6408e-02, -2.8885e-02]],
[[-5.6493e-02, -8.9312e-03, -6.7856e-02, -2.2723e-03, 6.3610e-02],
[-8.4754e-03, -4.5080e-02, 2.0656e-02, 7.2127e-02, 6.0980e-02],
[-2.8793e-02, 2.5090e-02, 6.1947e-02, -5.9376e-02, -7.8424e-02],
[-7.0324e-02, -2.1602e-03, -3.6420e-02, 3.6624e-02, 6.1857e-02],
[-1.1167e-02, 8.7337e-04, 1.7752e-02, 7.5435e-02, -2.4711e-02]],
[[ 7.4637e-02, 6.9347e-02, 3.9538e-02, 4.4932e-02, 3.9881e-02],
[-2.0344e-02, -5.1092e-02, 7.0027e-03, -8.0631e-02, -1.7589e-02],
[-6.1046e-02, 3.9349e-02, -7.4361e-02, 6.2805e-02, -6.0173e-02],
[ 3.9680e-02, 6.1395e-02, 6.7319e-02, 6.8641e-03, -2.2438e-03],
[ 3.9790e-02, 1.5846e-02, 4.8661e-02, 7.7811e-02, 4.8407e-02]]],
[[[-5.9373e-03, -1.3359e-02, -4.8370e-02, -4.6881e-02, 6.3805e-02],
[ 6.6859e-02, 5.4224e-02, -1.5097e-02, 2.9913e-02, 4.1902e-02],
[-4.7653e-02, -5.5400e-02, -4.5789e-02, -2.2652e-02, 1.4833e-03],
[ 4.5867e-03, 9.4236e-03, 6.7346e-02, -2.9957e-02, 2.4964e-02],
[ 1.4203e-02, 7.3159e-02, -3.3535e-02, -4.0967e-02, -6.9503e-02]],
[[ 1.8814e-02, 2.5195e-02, 2.5055e-02, -6.1476e-02, 2.6273e-02],
[ 8.0260e-02, -1.8977e-03, -1.4657e-02, -3.2918e-02, 7.8980e-02],
[ 2.2730e-02, -5.7422e-02, -5.7983e-02, 7.5397e-02, -5.1492e-02],
[-7.7555e-02, -3.0075e-02, -3.7065e-02, 2.9274e-02, 7.8566e-02],
[ 4.8337e-02, -8.0466e-02, -3.4578e-02, 7.4706e-02, -4.7404e-02]],
[[ 7.3456e-02, 4.5859e-02, 4.8964e-02, 6.4293e-02, 3.2808e-02],
[ 8.0656e-02, 8.7235e-03, 5.7811e-02, -7.8568e-02, -2.5927e-02],
[-7.4949e-02, -8.0697e-02, 2.6019e-03, -4.1251e-02, -2.1189e-02],
[ 3.4552e-02, -6.9636e-02, -2.4231e-02, 7.0092e-03, 1.3193e-03],
[-5.3746e-02, -1.5284e-02, -2.3633e-02, -7.1587e-02, -7.4092e-03]],
[[-1.0503e-02, -4.2436e-02, 6.0257e-02, 3.7005e-02, -5.7754e-03],
[ 5.1638e-02, -2.6558e-02, 1.1289e-02, -6.4803e-02, -3.0699e-02],
[ 4.8014e-02, -5.4398e-02, 5.9667e-02, 7.3184e-03, -5.9711e-02],
[-7.0265e-02, -5.3415e-02, 7.1987e-02, 8.8195e-03, 3.9838e-03],
[ 4.5838e-02, 2.0373e-02, 7.5955e-02, -1.4837e-02, 5.9894e-04]],
[[ 2.7225e-03, 8.1387e-02, 7.1768e-02, 5.3041e-02, -1.8938e-02],
[-4.4462e-03, -7.4515e-03, 7.2474e-02, -8.1187e-02, 6.8177e-02],
[-6.3188e-02, 7.8637e-02, 7.5461e-02, -4.3640e-02, -2.7294e-02],
[ 1.4516e-03, -2.1540e-02, 6.3983e-02, -5.5237e-02, 6.4640e-02],
[ 6.7225e-02, -3.2269e-02, 5.2291e-02, 4.1978e-02, 1.0545e-02]],
[[-3.9995e-02, 1.5795e-02, -1.1020e-02, 3.0129e-02, -7.5448e-02],
[ 4.8956e-02, 7.5318e-02, 4.4730e-03, -4.0962e-02, 2.6889e-02],
[ 6.6497e-02, 1.6028e-02, -7.3739e-02, -3.7472e-02, 7.3378e-03],
[ 6.1110e-02, -7.4697e-02, -3.3396e-02, -5.3662e-02, 1.9632e-02],
[-2.7660e-02, -6.3403e-02, 7.5017e-02, -6.2683e-05, -3.7116e-02]]],
[[[-1.8004e-02, 6.5819e-02, 6.2223e-02, 1.6332e-02, 1.8143e-02],
[-3.6172e-02, -6.9717e-02, -5.4694e-03, 2.3101e-02, 7.9629e-03],
[ 6.3475e-02, 3.4132e-02, -5.2767e-03, -3.8902e-02, -5.6896e-02],
[ 6.7207e-02, -4.2420e-02, 5.5340e-02, 3.7744e-02, 4.0298e-02],
[-6.8564e-02, 2.5046e-02, 6.8845e-03, 4.5561e-02, 6.5920e-02]],
[[ 4.8047e-02, -2.8234e-02, 7.3325e-02, -4.8417e-02, 3.9393e-02],
[ 6.5448e-02, -1.5095e-02, 4.1826e-02, -3.8581e-02, -2.3085e-02],
[-3.8703e-03, 8.0055e-02, 3.9432e-02, -3.5234e-02, -7.4530e-02],
[ 6.6605e-02, 7.1790e-02, -2.6339e-02, -6.3635e-02, 5.8536e-02],
[-3.7311e-02, 2.6147e-02, -6.4885e-02, -6.8364e-02, -8.1779e-03]],
[[ 5.7778e-04, 6.5419e-02, 5.4751e-02, -2.7986e-02, 2.1922e-02],
[-1.1870e-02, 4.0985e-02, 2.4626e-02, -6.6320e-03, -4.1242e-02],
[-7.0432e-02, 2.8016e-02, -5.6965e-02, -1.4918e-02, -4.4836e-02],
[ 6.8992e-02, 2.6582e-02, 2.8458e-02, -2.8932e-02, -6.5144e-02],
[ 3.3061e-02, -9.9823e-03, -6.1666e-02, -5.6478e-02, 1.8484e-02]],
[[-3.0624e-02, -4.3539e-02, 6.4283e-02, -3.4875e-02, 7.4699e-02],
[-2.9344e-04, -1.3068e-03, -2.6016e-02, 4.0942e-02, 2.7596e-02],
[ 3.8677e-02, 1.5856e-03, -6.9073e-02, 6.8789e-02, 6.7985e-02],
[ 1.2172e-02, 6.0105e-02, -2.9876e-02, 1.6004e-02, 2.7459e-02],
[-6.4051e-02, -4.5696e-02, -2.2469e-02, -7.0772e-02, 3.1664e-02]],
[[ 4.2269e-02, 1.6835e-02, -1.6207e-02, 1.9925e-02, -1.3160e-02],
[-3.8214e-02, -2.0697e-02, 1.2636e-02, 5.8541e-02, 6.5787e-02],
[ 5.4092e-02, -4.1293e-02, 7.1718e-03, 3.5447e-03, 5.9314e-02],
[-5.3280e-02, 5.7965e-02, -2.5258e-02, 5.3279e-02, 1.2102e-02],
[-6.1438e-02, -4.7414e-02, -6.4839e-02, -4.9603e-02, 3.7255e-02]],
[[ 2.8404e-04, -3.0073e-02, 1.2103e-02, 2.6228e-02, -4.6524e-02],
[ 6.2227e-02, 7.4788e-02, 7.7427e-02, 6.0245e-02, 4.3996e-02],
[-4.6962e-03, -1.5549e-02, 2.2174e-02, -6.1404e-02, -5.2019e-02],
[ 1.6444e-02, 3.0072e-02, 4.8501e-02, 7.1959e-02, 7.8801e-02],
[-3.8354e-03, -1.2520e-02, -7.9206e-02, 4.9887e-02, -7.5002e-02]]]],
requires_grad=True)), ('conv2.bias', Parameter containing:
tensor([-0.0291, -0.0585, -0.0470, 0.0497, -0.0030, 0.0714, -0.0772, -0.0729,
-0.0146, -0.0744, -0.0542, 0.0765, -0.0288, -0.0186, -0.0451, 0.0693],
requires_grad=True)), ('fc1.weight', Parameter containing:
tensor([[ 0.0127, 0.0351, 0.0101, ..., 0.0015, 0.0404, -0.0455],
[ 0.0154, 0.0499, -0.0289, ..., 0.0344, -0.0245, -0.0473],
[ 0.0072, 0.0226, 0.0097, ..., 0.0363, 0.0141, 0.0354],
...,
[ 0.0225, 0.0071, -0.0407, ..., -0.0101, -0.0304, 0.0344],
[ 0.0367, 0.0057, -0.0408, ..., -0.0393, -0.0016, 0.0063],
[ 0.0365, -0.0315, 0.0336, ..., -0.0208, -0.0220, 0.0388]],
requires_grad=True)), ('fc1.bias', Parameter containing:
tensor([ 0.0048, -0.0018, 0.0441, -0.0093, -0.0455, -0.0172, 0.0060, 0.0095,
-0.0444, -0.0034, -0.0198, -0.0411, 0.0237, -0.0362, -0.0433, 0.0259,
-0.0194, 0.0440, -0.0057, 0.0166, 0.0372, -0.0365, 0.0167, 0.0237,
-0.0316, -0.0307, -0.0303, -0.0447, 0.0386, -0.0020, -0.0163, -0.0206,
-0.0429, 0.0362, -0.0114, 0.0069, -0.0320, 0.0469, 0.0204, 0.0199,
0.0402, 0.0199, 0.0290, -0.0053, -0.0403, -0.0448, -0.0317, -0.0443,
-0.0072, 0.0160, -0.0249, -0.0263, -0.0495, -0.0168, -0.0374, -0.0191,
0.0002, -0.0181, 0.0465, -0.0054, -0.0368, -0.0173, -0.0434, 0.0306,
0.0474, 0.0398, 0.0285, -0.0076, -0.0255, 0.0232, 0.0022, 0.0059,
0.0353, 0.0013, 0.0414, -0.0094, 0.0435, -0.0331, -0.0408, -0.0381,
0.0160, -0.0089, 0.0269, 0.0171, 0.0205, 0.0466, 0.0360, -0.0251,
-0.0087, 0.0258, -0.0364, -0.0393, -0.0077, -0.0196, 0.0485, 0.0066,
-0.0440, -0.0421, -0.0475, 0.0283, 0.0257, -0.0366, -0.0414, -0.0026,
0.0096, 0.0189, 0.0289, -0.0230, -0.0331, -0.0086, 0.0300, 0.0070,
-0.0250, 0.0458, -0.0106, -0.0499, -0.0290, 0.0242, -0.0086, -0.0206],
requires_grad=True)), ('fc2.weight', Parameter containing:
tensor([[-0.0868, -0.0417, 0.0663, ..., 0.0089, -0.0459, -0.0327],
[-0.0881, 0.0408, -0.0878, ..., 0.0712, -0.0304, 0.0808],
[ 0.0405, 0.0224, -0.0019, ..., -0.0402, -0.0458, -0.0392],
...,
[-0.0641, 0.0234, 0.0442, ..., -0.0221, 0.0643, 0.0412],
[ 0.0731, 0.0759, -0.0588, ..., 0.0214, 0.0084, 0.0542],
[-0.0466, 0.0032, 0.0363, ..., -0.0843, 0.0305, -0.0020]],
requires_grad=True)), ('fc2.bias', Parameter containing:
tensor([ 1.8544e-03, 7.7148e-02, -2.2453e-02, -5.1801e-02, 4.5913e-02,
-7.2284e-02, -3.2864e-02, -6.8654e-02, -2.1057e-05, 7.0599e-02,
2.5858e-02, 7.3280e-02, 1.1011e-02, 6.3518e-02, -4.0666e-02,
-4.1031e-02, 6.6883e-02, -1.2299e-02, -6.6735e-02, 4.0905e-03,
-6.5498e-02, -2.6964e-03, 8.2400e-02, -4.4789e-02, 3.2078e-02,
-2.4232e-02, -8.0591e-02, 5.7183e-03, -1.8703e-03, -8.2100e-02,
-6.2772e-02, 7.0709e-02, -4.4223e-02, -8.6373e-03, -8.9721e-02,
-8.7231e-02, 2.1281e-02, 6.5149e-02, -5.8912e-02, -6.0408e-03,
3.9662e-02, -5.1060e-02, -5.0107e-02, 7.2825e-02, 8.5803e-02,
3.2855e-02, -4.8813e-02, -6.2276e-02, 6.0332e-02, -2.1124e-03,
2.0599e-02, 5.5399e-02, 1.8062e-02, 3.8259e-02, -1.0740e-02,
-4.9024e-02, 7.9856e-02, 3.3842e-02, 3.2818e-03, -8.6486e-02,
6.8716e-04, -3.4156e-02, -6.9502e-02, 6.9801e-02, -1.9957e-02,
-8.9092e-02, 8.5563e-02, 6.8908e-02, -1.2464e-03, -5.8290e-02,
5.7215e-03, -1.9210e-02, 5.3976e-02, -8.5581e-02, -5.1105e-02,
-1.1988e-02, -8.9009e-02, 7.1955e-02, -7.3027e-02, -3.2002e-02,
-7.7311e-03, 6.5628e-02, 7.7175e-02, -2.6717e-02],
requires_grad=True)), ('fc3.weight', Parameter containing:
tensor([[-5.6418e-02, 8.7256e-03, 1.4857e-03, -4.4241e-02, 2.6271e-02,
6.2223e-02, 9.5526e-02, 8.1587e-02, -1.0604e-01, -1.4852e-02,
2.0481e-02, 6.4787e-02, -1.7821e-02, -7.5547e-02, 7.8957e-02,
-1.3885e-02, -1.0226e-01, -8.2753e-02, 9.8486e-02, -3.4305e-02,
2.7100e-03, 1.8032e-02, -1.0715e-01, 4.9688e-03, -9.7491e-02,
5.7741e-02, 1.0480e-01, -7.5001e-02, -1.0038e-01, -8.3054e-02,
1.4810e-02, -2.5294e-02, -3.5486e-02, 1.0084e-01, 1.2828e-02,
-2.8502e-03, -4.3740e-02, -7.8102e-03, 2.1588e-02, -7.5598e-02,
-3.3590e-02, -6.8185e-02, 3.3957e-02, 2.1573e-02, -5.8785e-02,
6.0361e-02, -8.8535e-02, 6.7021e-02, -9.3063e-02, 1.9414e-02,
-6.7784e-02, 3.7093e-02, 7.5324e-03, -1.9678e-02, 5.4143e-02,
-8.9511e-02, -9.2640e-02, 6.8033e-02, 2.4691e-02, -4.7642e-02,
5.3445e-02, -6.4451e-02, 4.2707e-02, -4.0519e-02, -9.8571e-02,
6.6185e-02, 5.7230e-02, -3.7836e-02, 9.9665e-02, -2.2986e-02,
-5.7756e-02, -6.8581e-02, -1.1435e-02, -3.2040e-02, 1.0683e-01,
-9.7724e-02, -4.2304e-02, -1.2573e-02, 6.7023e-02, 4.1495e-02,
7.6838e-02, -6.6680e-02, -7.8055e-02, 5.0132e-02],
[ 1.9141e-02, 2.5746e-02, 1.0312e-01, 1.4601e-02, 3.0647e-02,
9.3271e-02, 6.4014e-02, -5.4571e-02, -5.6861e-02, -2.6344e-02,
9.0764e-02, 5.0952e-02, -9.0798e-02, 2.1471e-02, -1.0136e-01,
-8.5665e-02, 2.1246e-02, -2.4271e-02, -6.6701e-02, 7.1168e-02,
6.3234e-02, -9.3446e-03, 2.8220e-03, -1.4337e-02, -2.4543e-02,
1.0582e-01, 2.0000e-02, -1.0267e-01, -3.5270e-03, -5.9467e-02,
-2.3974e-03, 2.7704e-02, -6.6317e-02, 8.0701e-02, -1.0008e-01,
-1.0211e-01, -8.1395e-02, -6.5022e-02, 5.9279e-02, -7.2579e-02,
-9.5273e-02, -6.4547e-02, -8.0597e-02, 8.4433e-03, -4.7848e-02,
9.7750e-02, 9.3517e-03, 7.4619e-02, -3.9448e-02, -8.5778e-02,
-9.5482e-02, -5.5330e-03, -1.0796e-01, 6.7140e-02, -2.9456e-02,
-1.8318e-02, -4.2094e-02, 1.8795e-02, 1.0150e-02, 4.3263e-02,
-7.2860e-02, 7.8749e-02, -5.2190e-03, 1.5920e-02, 5.5891e-02,
-4.6358e-02, 9.4703e-02, 8.4561e-02, 5.2485e-02, 7.6034e-02,
-2.0514e-03, 5.1548e-02, -2.3387e-03, -1.0369e-01, -1.7051e-03,
9.5608e-02, -7.4236e-02, 8.8553e-02, -7.3277e-02, -6.1156e-02,
-4.4343e-02, 1.5902e-02, 8.3211e-02, -4.9491e-02],
[-5.4914e-03, -8.3233e-02, 2.0474e-02, -7.9890e-02, -2.6837e-02,
9.9609e-02, -6.1847e-02, 8.2810e-03, -7.8208e-02, -2.2464e-02,
-8.8908e-02, 8.8578e-02, 7.5675e-02, -1.2950e-02, 9.0193e-02,
-2.1504e-02, 8.8890e-02, 9.8521e-02, -7.5039e-02, 1.0464e-01,
9.4199e-02, -3.8682e-02, 8.7018e-02, 8.5465e-02, -7.7055e-02,
-7.1284e-02, -1.0897e-01, -5.4695e-03, -7.9368e-02, 3.3380e-02,
7.2154e-02, 9.5790e-02, 6.1063e-04, -1.7953e-02, 6.1940e-02,
-7.8037e-02, 9.0474e-02, 9.4975e-02, 4.7951e-02, 2.9325e-02,
2.9166e-04, 5.0427e-02, 4.9875e-02, 1.9344e-02, -1.5598e-02,
-8.8184e-02, 4.9534e-02, 4.1846e-02, -3.4013e-02, -5.9683e-02,
-7.4203e-02, 4.6901e-02, -5.7360e-02, 3.9085e-02, -1.0259e-01,
5.0026e-02, 9.3175e-02, -4.9197e-02, -9.1434e-02, -6.5976e-02,
9.9470e-02, 2.5012e-02, 2.3345e-02, -7.4206e-02, 9.9779e-02,
-9.9292e-02, 9.1207e-02, -1.8750e-02, -8.7222e-02, 4.5009e-02,
6.2097e-02, 1.1889e-02, -3.1094e-02, 6.2701e-02, -3.6896e-02,
7.9025e-03, 3.9122e-02, -6.6788e-02, 5.8422e-02, -1.8931e-02,
4.7233e-02, 7.9713e-03, 6.5758e-02, 2.4974e-02],
[ 9.3710e-02, 6.3587e-03, -1.1848e-02, 4.5289e-02, -2.2177e-02,
-7.1277e-02, 4.7252e-02, 1.6109e-02, -6.7931e-02, -1.7756e-02,
2.0796e-02, 7.4046e-02, -1.0572e-01, -9.9937e-02, 6.4527e-04,
7.9186e-02, -2.9363e-02, 9.3051e-02, 2.8610e-02, 5.4603e-02,
1.7861e-02, 5.2665e-02, 9.7904e-02, -2.5723e-02, -3.2497e-02,
9.0306e-02, -8.9747e-02, 7.0985e-02, 4.8148e-02, 8.0181e-02,
-7.6122e-02, -1.0335e-01, -8.7064e-02, 9.8974e-02, -5.0385e-02,
4.0987e-02, 1.0069e-01, 4.4886e-02, -6.0288e-02, 2.8446e-02,
9.9679e-02, 5.8738e-02, -8.6646e-02, 4.9163e-02, -1.0731e-01,
-8.7259e-02, 8.3456e-02, 6.8349e-02, -8.7784e-02, -8.0034e-02,
4.8366e-02, 4.0478e-02, 5.7153e-03, -5.4092e-02, 5.2010e-03,
8.2633e-02, -1.0290e-01, 5.6955e-02, -9.8502e-02, 4.5809e-03,
4.4042e-02, 9.9767e-03, 5.3250e-02, -2.6768e-02, -7.9477e-02,
-1.0056e-01, -1.5802e-03, -6.9598e-03, -6.3054e-02, 1.0204e-01,
-9.7165e-02, -9.1548e-02, 2.5512e-02, -5.9958e-02, -2.2757e-02,
-7.9634e-03, -1.0523e-01, -6.3195e-02, -1.2414e-02, 1.1663e-02,
-8.9895e-02, -8.0842e-02, 1.0750e-01, -3.0337e-02],
[ 5.6841e-02, 8.6600e-02, 1.7552e-02, -5.7772e-03, -4.4277e-02,
-4.1141e-03, -4.0124e-02, 3.6352e-02, -8.9301e-02, 4.6127e-02,
7.3090e-02, -9.3387e-02, -3.1823e-02, 4.5977e-02, 1.0641e-01,
-2.9397e-02, 4.1653e-02, 5.7302e-03, 4.0278e-02, -4.2121e-02,
-1.0518e-01, 9.9239e-02, -2.4550e-02, 5.3995e-02, -4.9111e-02,
-1.2914e-02, 5.2915e-02, 8.5013e-02, 6.0425e-02, 9.0397e-02,
-7.9633e-02, 3.0904e-02, 8.4106e-02, 7.7185e-02, -1.1237e-02,
-3.8323e-02, 4.4578e-03, -4.1946e-02, -9.6759e-02, -1.2592e-02,
-4.1569e-02, -1.0844e-01, -5.4256e-02, 2.7450e-02, -6.3376e-02,
-1.0629e-01, -2.7586e-02, -4.1033e-02, -1.0556e-01, -1.6890e-03,
-1.0441e-01, -2.9892e-02, 9.5217e-03, 8.4235e-02, -1.0753e-01,
2.1481e-02, 5.6448e-03, -9.0530e-02, 1.0382e-01, -2.1958e-02,
4.4257e-02, 4.6737e-02, -6.9540e-02, -5.2136e-02, 6.0439e-02,
4.6486e-02, -7.2450e-02, -1.1493e-02, 1.7681e-02, 6.1643e-02,
7.3240e-02, 2.1737e-02, 1.2148e-02, -9.1230e-02, 7.5856e-02,
5.2213e-02, -2.0461e-02, -5.2718e-02, -1.0508e-01, 1.0872e-02,
-1.5246e-03, -2.6219e-02, 8.0384e-03, 1.4495e-02],
[-5.3898e-02, -6.2903e-02, 9.3278e-02, 5.2151e-02, 2.4894e-02,
7.4073e-02, 5.3318e-02, -2.8024e-02, -1.0801e-01, 1.0287e-01,
4.7442e-02, 5.0882e-02, -2.4797e-02, -9.7050e-02, 8.4431e-02,
6.6385e-02, 1.0674e-01, -1.0502e-01, -5.5089e-03, -2.1849e-02,
4.1641e-02, 6.3477e-02, -7.0544e-02, -1.0295e-01, -6.6842e-02,
8.4617e-03, 6.7929e-02, 5.4032e-02, -1.3090e-02, -2.8971e-02,
-4.3182e-02, -4.4062e-02, -9.7478e-02, -6.2066e-02, -5.4336e-02,
5.0292e-02, 7.3119e-02, 8.2924e-02, 9.4342e-02, 4.5792e-02,
-4.1454e-02, 6.4998e-02, -5.4642e-02, 1.6673e-02, -3.3706e-02,
-9.6359e-02, 1.7345e-04, -9.2701e-02, 8.3176e-03, 2.0206e-02,
-3.3897e-03, -4.6855e-02, 5.2317e-02, -4.5804e-02, 2.2652e-02,
-1.0233e-01, -3.7416e-03, -7.8552e-02, -1.7273e-02, -2.1090e-02,
-7.0454e-03, 6.4824e-02, 1.0302e-01, -5.9749e-03, 1.0142e-01,
7.1721e-02, -3.2308e-02, 1.0354e-01, -2.1937e-02, 9.5663e-02,
-5.3376e-03, -8.3901e-02, 2.2311e-02, 4.4284e-02, -1.1621e-03,
-1.9146e-02, 9.5266e-02, 4.9740e-02, 8.7377e-02, -6.9876e-02,
7.4907e-02, -1.7911e-02, 7.3294e-02, 7.7561e-02],
[-1.0208e-01, 8.0829e-02, -7.1439e-02, 2.2382e-02, 7.3827e-02,
-5.1018e-02, -6.3343e-02, -2.9540e-02, 9.3982e-02, 4.3037e-02,
8.5199e-02, -5.6971e-02, 2.5385e-02, 7.3944e-02, 3.9802e-02,
7.6630e-02, -1.4518e-02, 5.0097e-02, 4.9435e-02, 6.8677e-02,
-1.9400e-02, 9.0476e-03, 9.9804e-02, -9.8907e-02, 7.7342e-02,
5.8050e-02, -1.0695e-02, 2.4084e-02, 6.4561e-02, -1.5142e-02,
-2.9764e-02, 1.3710e-03, 1.3472e-02, -4.0897e-03, 3.9828e-02,
1.0682e-01, -3.4013e-02, -7.1825e-02, 1.0256e-01, -4.8058e-02,
-7.9164e-02, -3.1616e-02, 1.4847e-02, -5.4025e-02, 1.0359e-01,
-7.2879e-02, -5.5342e-02, 9.1632e-02, 9.1646e-02, -4.5751e-02,
6.7954e-02, 3.2295e-02, 6.9718e-03, 9.7817e-02, 1.0138e-01,
5.1222e-02, -6.8627e-02, -4.2274e-02, -5.4111e-02, 2.0167e-02,
9.3684e-02, -7.3912e-02, -4.8360e-02, 2.9101e-02, 5.3984e-02,
-6.4280e-02, -7.3429e-02, 6.6071e-02, -4.8896e-02, -2.6180e-02,
8.5460e-02, 7.1012e-02, 2.1531e-02, -7.9570e-02, 9.2200e-02,
-9.2980e-02, 7.6317e-02, -4.6281e-02, -1.0742e-01, 3.3758e-02,
-1.0100e-01, -5.6719e-02, -1.4088e-02, 6.6273e-02],
[ 5.3003e-02, -8.5680e-02, -3.2159e-04, -6.6686e-02, 1.0357e-01,
7.7528e-02, -7.7695e-02, -3.7343e-02, -6.4867e-02, -3.4528e-02,
3.8912e-02, -2.9453e-02, 7.2781e-02, 8.9834e-02, -7.9001e-02,
-8.3703e-02, 5.7901e-02, 6.9107e-02, 7.9265e-02, 4.9187e-02,
5.0513e-02, -5.4097e-02, -3.2582e-02, 6.4819e-02, 1.9896e-02,
-7.3084e-02, -8.9256e-02, 4.9840e-02, -5.9509e-02, 5.3884e-02,
-3.8777e-02, 6.0162e-02, 1.0393e-01, -2.5981e-02, 3.4439e-02,
7.1359e-02, -8.5326e-02, -8.6582e-02, 2.3416e-02, -4.4381e-02,
-3.0652e-03, 3.3561e-02, 4.0531e-02, 6.6637e-02, -1.0751e-01,
-2.0587e-02, -5.7200e-02, 8.0107e-02, -9.5006e-02, 8.8297e-02,
5.6524e-02, 5.3064e-02, -5.4679e-02, -6.5577e-02, 2.3586e-02,
2.7163e-02, 3.3865e-03, 9.4055e-02, -4.6757e-02, -1.8562e-02,
-7.2532e-02, 4.8080e-02, -9.0010e-02, -1.0510e-01, -7.9523e-02,
-7.7586e-02, -8.7399e-02, -2.3291e-02, 5.0533e-02, 9.0525e-02,
5.7866e-02, 9.3118e-02, 6.4315e-02, -9.9633e-02, 3.1986e-02,
8.2956e-02, 7.0101e-02, -2.0538e-02, 8.6799e-02, -6.9437e-02,
1.0641e-01, 1.0426e-01, 8.0581e-02, -8.0374e-02],
[ 3.2695e-02, -3.4856e-02, 6.2409e-02, -1.3378e-02, 6.0558e-03,
1.4727e-02, -8.5913e-02, 1.0691e-01, 7.7849e-02, -4.7844e-02,
6.0482e-02, -7.5180e-03, -4.0028e-02, 6.0692e-02, -5.5052e-02,
3.3449e-02, -8.4209e-02, -1.3423e-02, 9.3352e-02, 1.3072e-02,
-5.9021e-04, 7.6759e-02, 8.5805e-02, -2.4546e-02, -9.6939e-02,
-4.4700e-02, 5.6368e-02, -1.1727e-02, 5.2196e-02, 6.8510e-02,
4.8723e-02, -3.9295e-02, -3.2981e-02, -3.3015e-02, -1.3612e-02,
-7.1090e-02, -7.5796e-02, -4.9240e-02, 1.4497e-02, 2.2467e-02,
-4.2246e-03, 7.9742e-02, 9.8798e-02, -1.0770e-01, -1.0561e-01,
8.5984e-02, -9.2008e-02, 5.3294e-02, -5.5027e-02, -4.8945e-02,
1.9929e-02, 8.5859e-02, 6.2954e-02, 2.9138e-03, 4.9239e-02,
2.2325e-02, -9.6379e-02, 1.4460e-02, -3.5029e-02, -4.5814e-02,
-7.4607e-02, 7.0636e-02, 6.4684e-02, -2.8305e-02, 1.7848e-03,
-1.5799e-02, 6.4801e-02, 5.6922e-02, -9.3420e-02, -2.4182e-02,
4.9939e-02, -4.6919e-02, -9.4452e-02, -6.2229e-02, 2.6796e-02,
-1.9359e-02, -1.0534e-01, -1.0308e-01, 6.7256e-02, -1.0804e-01,
-5.2163e-02, -3.9063e-02, -1.1436e-02, -4.9283e-02],
[ 3.7929e-02, -9.2263e-02, 8.3047e-02, -1.4673e-03, 1.0762e-01,
4.5225e-02, -6.9877e-02, -8.5745e-02, 4.3992e-02, 8.6473e-02,
-3.9794e-02, -6.4171e-02, 2.1542e-02, 1.5011e-03, -8.4764e-02,
7.6233e-02, 1.4007e-02, -4.1551e-02, 6.5961e-02, -9.7566e-02,
-9.3359e-02, 3.9779e-02, 3.7884e-03, 2.0744e-02, 1.0474e-01,
4.3060e-02, -2.9415e-02, -1.3947e-02, -1.0270e-01, -1.4549e-02,
-3.3965e-02, 7.2888e-02, -1.5545e-02, -2.1496e-02, -9.5266e-02,
4.0747e-02, -2.9195e-02, -6.9564e-02, 8.0910e-02, -4.2163e-02,
-1.3920e-02, -3.4202e-02, 8.5124e-02, 8.0781e-02, -4.7417e-02,
1.0826e-01, 2.4278e-02, 3.6846e-02, -9.6196e-02, 1.2560e-02,
9.7764e-02, 1.0009e-01, 1.0832e-01, 2.5618e-03, 7.1625e-02,
-7.1904e-02, -6.2205e-02, 5.8912e-02, 2.3969e-02, -9.6306e-02,
2.9396e-02, 3.1466e-02, 6.3138e-02, 3.7381e-02, 1.0424e-01,
-1.0732e-01, -1.0593e-01, 3.2480e-02, -1.0435e-01, -1.3005e-02,
5.2493e-02, -6.7261e-02, -7.5851e-02, -4.7852e-05, 1.2423e-02,
-7.6787e-02, -1.0262e-02, -1.6454e-02, 9.6626e-02, -1.0664e-01,
-8.2998e-02, -2.1841e-02, -5.2301e-02, -1.5127e-02]],
requires_grad=True)), ('fc3.bias', Parameter containing:
tensor([ 0.0554, 0.0078, 0.0239, 0.0951, 0.0102, 0.0927, 0.0383, 0.0847,
0.0910, -0.0069], requires_grad=True))]
Try a random 32×32 input. Note: the expected input size of this net (LeNet) is 32×32. To use this net on the MNIST dataset, resize the images to 32×32.
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)
tensor([[-0.0384, 0.0331, 0.0275, 0.0812, -0.0870, 0.1044, 0.0813, 0.0467,
0.0679, 0.0307]], grad_fn=<AddmmBackward>)
Zero the gradient buffers of all parameters, then backpropagate with random gradients:
net.zero_grad()
out.backward(torch.randn(1, 10))
Note
torch.nn only supports mini-batches. The entire torch.nn package only supports inputs that are mini-batches of samples, not single samples. For example, nn.Conv2d takes a 4-dimensional tensor of nSamples × nChannels × Height × Width (number of samples × number of channels × height × width). If you have a single sample, use input.unsqueeze(0) to add a fake batch dimension.
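A quick illustration of unsqueeze(0) adding that batch dimension:

```python
import torch

single = torch.randn(1, 32, 32)  # one sample: channels x height x width
batched = single.unsqueeze(0)    # insert a fake batch dimension at position 0
print(single.shape)              # torch.Size([1, 32, 32])
print(batched.shape)             # torch.Size([1, 1, 32, 32]) - now valid for nn.Conv2d
```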
A loss function takes an (output, target) pair as input and computes a value estimating how far the network's output is from the target.
There are many different loss functions in the nn package. nn.MSELoss is a fairly simple one that computes the mean squared error between the output and the target. For example:
output = net(input)
target = torch.randn(10)  # a dummy target, for example
target = target.view(1, -1)  # make target the same shape as output
criterion = nn.MSELoss()
loss = criterion(output, target)
print(loss)
tensor(0.4847, grad_fn=<MseLossBackward>)
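nn.MSELoss averages the squared element-wise differences. A minimal check against a hand computation, on made-up values rather than the network's output:

```python
import torch
import torch.nn as nn

a = torch.tensor([[1.0, 2.0, 3.0]])
b = torch.tensor([[1.0, 0.0, 0.0]])

criterion = nn.MSELoss()
loss = criterion(a, b)

# mean of squared differences: (0 + 4 + 9) / 3
manual = ((a - b) ** 2).mean()
print(loss.item())  # approximately 4.3333
```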
Call loss.backward() to backpropagate the error.
You need to clear the existing gradients first, though, otherwise the new gradients will be accumulated into the existing ones.
Now we call loss.backward() and look at conv1's bias gradients before and after the backward pass.
net.zero_grad()  # zero the gradient buffers of all parameters
print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)
loss.backward()
print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)
conv1.bias.grad before backward
tensor([0., 0., 0., 0., 0., 0.])
conv1.bias.grad after backward
tensor([ 0.0051, 0.0042, 0.0026, 0.0152, -0.0040, -0.0036])
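The accumulation behavior mentioned above can be seen directly: calling backward() twice without zeroing doubles the gradient. A minimal sketch with a single parameter:

```python
import torch

w = torch.ones(1, requires_grad=True)
loss = (3 * w).sum()          # d(loss)/dw = 3

loss.backward(retain_graph=True)
g_first = w.grad.item()       # 3.0

loss.backward()               # without zeroing, the new gradient is added
g_second = w.grad.item()      # 6.0: accumulated, not overwritten

w.grad.zero_()                # what net.zero_grad() does for each parameter
```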
The simplest weight-update rule used in practice is stochastic gradient descent (SGD):
weight = weight - learning_rate * gradient
We can implement this rule in simple Python code:
learning_rate = 0.01
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)
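On a single parameter, the rule reduces to simple arithmetic (the values here are illustrative):

```python
import torch

weight = torch.tensor([2.0])
grad = torch.tensor([0.5])
learning_rate = 0.01

# weight = weight - learning_rate * gradient
weight = weight - learning_rate * grad
print(weight)  # tensor([1.9950])
```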
However, when using neural networks you may want various different update rules, such as SGD, Nesterov-SGD, Adam, RMSProp, etc. PyTorch provides a package, torch.optim, that implements all of these rules. Using it is simple:
import torch.optim as optim
# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)
# in your training loop:
optimizer.zero_grad() # zero the gradient buffers
output = net(input)
loss = criterion(output, target)
loss.backward()
optimizer.step() # Does the update
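A quick check that optimizer.step() applies exactly the SGD rule above. Since net, input, and target come from earlier in the tutorial, this sketch uses a small stand-in model instead:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

x = torch.randn(4, 3)
target = torch.randn(4, 1)

before = model.weight.detach().clone()
optimizer.zero_grad()            # zero the gradient buffers
loss = criterion(model(x), target)
loss.backward()
optimizer.step()                 # weight <- weight - lr * weight.grad

# plain SGD (no momentum) should match the manual update rule
expected = before - 0.01 * model.weight.grad
```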