Building Simple Neural Networks with PyTorch

Neural Networks

Neural networks are built with the torch.nn package. nn relies on the autograd package to define models and compute gradients. An nn.Module contains the network's layers and a forward(input) method that returns the output.

[Figure 1: diagram of the convolutional network defined below]

It is a simple feed-forward network: it takes an input, passes it through one layer after another, and finally outputs the computed result.

A typical training procedure for a neural network is:

1. Define a network model with some learnable parameters (also called weights);
2. Iterate over a dataset of inputs;
3. Process the input through the network;
4. Compute the loss (how far the output is from the correct value);
5. Propagate the gradients back into the network's parameters;
6. Update the network's weights, typically using a simple rule: weight = weight - learning_rate * gradient
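The steps above can be sketched as a minimal training loop. The tiny two-layer model, the MSELoss criterion, the random data, and the learning rate here are placeholder choices for illustration only:

```python
import torch
import torch.nn as nn

# A tiny stand-in model; any nn.Module works the same way (step 1).
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
criterion = nn.MSELoss()
learning_rate = 0.01

# Fake dataset: 16 samples with 4 features each, plus scalar targets.
inputs = torch.randn(16, 4)
targets = torch.randn(16, 1)

for epoch in range(5):                 # step 2: iterate over the dataset
    output = model(inputs)             # step 3: process input through the net
    loss = criterion(output, targets)  # step 4: compute the loss
    model.zero_grad()                  # clear gradients from the previous pass
    loss.backward()                    # step 5: backpropagate to the parameters
    with torch.no_grad():              # step 6: weight = weight - lr * gradient
        for p in model.parameters():
            p -= learning_rate * p.grad
```

In practice the manual update in step 6 is usually replaced by an optimizer from torch.optim, but writing it out makes the update rule explicit.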

Defining the network

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

    def __init__(self):
        # Run the parent class (nn.Module) initializer first
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        # 16*5*5: the input starts at 32x32; after two conv+pool stages it is
        # 16 feature maps of size 5x5, so after flattening with view the
        # first fully connected layer receives 16*5*5 input features.
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    # Only forward must be defined; the backward pass is generated
    # automatically by autograd.
    def forward(self, x):
        # Pass x through conv1, apply ReLU, then max-pool over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))

        # Same for conv2; if the pooling window is square,
        # a single number suffices
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)

        # Flatten x into a vector (the total feature count is unchanged)
        # in preparation for the fully connected layers
        x = x.view(-1, self.num_flat_features(x))

        # fc1 followed by the ReLU activation
        x = F.relu(self.fc1(x))

        # fc2 followed by the ReLU activation
        x = F.relu(self.fc2(x))

        # fc3 produces the final output
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)
Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
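The 16*5*5 figure in the comments can be checked by tracing shapes through the two conv/pool stages. This sketch rebuilds just conv1 and conv2 standalone and feeds them a random single-channel 32x32 image (batch size 1 is assumed):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv1 = nn.Conv2d(1, 6, 5)   # same layer shapes as in Net, rebuilt standalone
conv2 = nn.Conv2d(6, 16, 5)

x = torch.randn(1, 1, 32, 32)            # (batch, channels, height, width)
x = F.max_pool2d(F.relu(conv1(x)), 2)
print(x.shape)   # torch.Size([1, 6, 14, 14]): conv 32-5+1=28, pool 28/2=14
x = F.max_pool2d(F.relu(conv2(x)), 2)
print(x.shape)   # torch.Size([1, 16, 5, 5]): conv 14-5+1=10, pool 10/2=5
flat = x.view(-1, 16 * 5 * 5)
print(flat.shape)                        # torch.Size([1, 400]) -> input to fc1
```

Note that 32x32 is the expected input size for this network; a differently sized image would change the 5x5 result and break the hard-coded fc1 input size.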

net.parameters() returns the model's learnable parameters (weights):

params = list(net.parameters())
print(len(params))
print(params[0].size())  # conv1's .weight
10
torch.Size([6, 1, 5, 5])
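The count of 10 comes from one weight and one bias per layer, five layers in all. As a standalone sanity check (rebuilding the same five layers outside the class for illustration), the names, shapes, and total parameter count can be verified:

```python
import torch.nn as nn

# The same five layers as in Net above, rebuilt standalone for illustration
layers = {'conv1': nn.Conv2d(1, 6, 5), 'conv2': nn.Conv2d(6, 16, 5),
          'fc1': nn.Linear(16 * 5 * 5, 120), 'fc2': nn.Linear(120, 84),
          'fc3': nn.Linear(84, 10)}

named_params = [(f'{name}.{pname}', p)
                for name, layer in layers.items()
                for pname, p in layer.named_parameters()]
print(len(named_params))              # 10: a weight and a bias per layer
for name, p in named_params:
    print(name, tuple(p.shape))       # e.g. conv1.weight (6, 1, 5, 5)

total = sum(p.numel() for _, p in named_params)
print(total)                          # 61706 learnable scalars in total
```

The total works out as 6*1*5*5+6 for conv1, 16*6*5*5+16 for conv2, 400*120+120 for fc1, 120*84+84 for fc2, and 84*10+10 for fc3.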

Inspecting the network's parameters in detail

print(list(net.named_parameters()))
[('conv1.weight', Parameter containing:
tensor([[[[-0.0864,  0.1581,  0.1053, -0.0674,  0.0957],
          [ 0.0101, -0.1520, -0.0849,  0.1507,  0.1818],
          [ 0.0768,  0.0109, -0.0293,  0.0672,  0.0463],
          [-0.1766,  0.0446,  0.1647,  0.1359, -0.1144],
          [ 0.1169,  0.1585, -0.1872,  0.1876, -0.1431]]],

        ...], requires_grad=True)), ('conv1.bias', Parameter containing:
tensor([ 0.2000,  0.1894,  0.1787,  0.0921, -0.0161,  0.1778],
       requires_grad=True)), ('conv2.weight', Parameter containing:
tensor(...)), ('conv2.bias', Parameter containing:
tensor(...)), ('fc1.weight', Parameter containing:
tensor(...)), ('fc1.bias', Parameter containing:
tensor(...)), ('fc2.weight', Parameter containing:
tensor(...)), ...]

(output abridged: each layer contributes a (name, Parameter) pair for its weight and another for its bias, and the values are random initializations that differ from run to run)
          4.0747e-02, -2.9195e-02, -6.9564e-02,  8.0910e-02, -4.2163e-02,
         -1.3920e-02, -3.4202e-02,  8.5124e-02,  8.0781e-02, -4.7417e-02,
          1.0826e-01,  2.4278e-02,  3.6846e-02, -9.6196e-02,  1.2560e-02,
          9.7764e-02,  1.0009e-01,  1.0832e-01,  2.5618e-03,  7.1625e-02,
         -7.1904e-02, -6.2205e-02,  5.8912e-02,  2.3969e-02, -9.6306e-02,
          2.9396e-02,  3.1466e-02,  6.3138e-02,  3.7381e-02,  1.0424e-01,
         -1.0732e-01, -1.0593e-01,  3.2480e-02, -1.0435e-01, -1.3005e-02,
          5.2493e-02, -6.7261e-02, -7.5851e-02, -4.7852e-05,  1.2423e-02,
         -7.6787e-02, -1.0262e-02, -1.6454e-02,  9.6626e-02, -1.0664e-01,
         -8.2998e-02, -2.1841e-02, -5.2301e-02, -1.5127e-02]],
       requires_grad=True)), ('fc3.bias', Parameter containing:
tensor([ 0.0554,  0.0078,  0.0239,  0.0951,  0.0102,  0.0927,  0.0383,  0.0847,
         0.0910, -0.0069], requires_grad=True))]

Test with a random 32×32 input. Note: this network (LeNet) expects an input size of 32×32. To train it on the MNIST dataset, resize the images to 32×32 first.

input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)
tensor([[-0.0384,  0.0331,  0.0275,  0.0812, -0.0870,  0.1044,  0.0813,  0.0467,
          0.0679,  0.0307]], grad_fn=<AddmmBackward>)

Zero the gradient buffers of all parameters, then backpropagate with random gradients:

net.zero_grad()
out.backward(torch.randn(1, 10))

Note
torch.nn only supports mini-batches. The entire torch.nn package only supports inputs that are batches of samples, not single samples. For example, nn.Conv2d takes a 4-D tensor of nSamples x nChannels x Height x Width. If you have a single sample, use input.unsqueeze(0) to add a fake batch dimension.
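For example, a single-image tensor can be given a batch dimension like this (a minimal sketch; the shapes match what this LeNet expects):

```python
import torch

single = torch.randn(1, 32, 32)   # one image: channels x height x width
batch = single.unsqueeze(0)       # insert a batch dimension at position 0
print(batch.shape)                # torch.Size([1, 1, 32, 32])
```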

Loss Function

A loss function takes a pair (output, target) as input and computes a value that estimates how far the network's output is from the target.
The nn package provides many different loss functions. nn.MSELoss is a simple one: it computes the mean squared error between the output and the target. For example:

output = net(input)
target = torch.randn(10)  # a dummy target, for example
target = target.view(1, -1)  # make target the same shape as output
criterion = nn.MSELoss()

loss = criterion(output, target)
print(loss)
tensor(0.4847, grad_fn=<MseLossBackward>)
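To check what nn.MSELoss computes, here is a small sketch comparing it with the mean of squared differences written by hand (the default reduction is 'mean'):

```python
import torch
import torch.nn as nn

output = torch.randn(1, 10)
target = torch.randn(1, 10)

criterion = nn.MSELoss()          # default reduction='mean'
loss = criterion(output, target)

# the same value computed by hand
manual = ((output - target) ** 2).mean()
print(torch.allclose(loss, manual))  # True
```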

Backpropagation

Call loss.backward() to backpropagate the error.

You need to clear the existing gradients first, though, otherwise the new gradients will be accumulated onto the existing ones.

Now we call loss.backward() and look at the gradients of conv1's bias before and after the backward pass.

net.zero_grad()     # zero the gradient buffers

print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)

loss.backward()

print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)
conv1.bias.grad before backward
tensor([0., 0., 0., 0., 0., 0.])
conv1.bias.grad after backward
tensor([ 0.0051,  0.0042,  0.0026,  0.0152, -0.0040, -0.0036])
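The accumulation behavior that net.zero_grad() guards against can be seen in isolation. In this small sketch, calling backward() twice on the same loss adds the gradients together (retain_graph=True keeps the graph alive for the second call):

```python
import torch

w = torch.ones(3, requires_grad=True)
loss = (w * 2).sum()              # d(loss)/dw = 2 for every element

loss.backward(retain_graph=True)
print(w.grad)                     # tensor([2., 2., 2.])

loss.backward()                   # gradients accumulate, they are not overwritten
print(w.grad)                     # tensor([4., 4., 4.])

w.grad.zero_()                    # clear, as net.zero_grad() does for each parameter
```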

Update the Weights

In practice, the simplest weight update rule is stochastic gradient descent (SGD):

weight = weight - learning_rate * gradient
We can implement this rule with simple Python code:

learning_rate = 0.01
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)
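A side note: in current PyTorch, mutating .data is discouraged; the same manual update is usually written inside a torch.no_grad() block so the update itself is not tracked by autograd. A sketch with a single hand-set gradient (the values here are illustrative only):

```python
import torch

learning_rate = 0.01
w = torch.tensor([1.0, 2.0], requires_grad=True)
w.grad = torch.tensor([0.5, -0.5])   # pretend loss.backward() produced these

with torch.no_grad():                # do not record the update in the graph
    w -= w.grad * learning_rate

print(w)                             # tensor([0.9950, 2.0050], requires_grad=True)
```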

However, when training a neural network you often want to use different update rules, such as SGD, Nesterov-SGD, Adam, RMSProp, and so on. PyTorch provides the torch.optim package, which implements all of these rules. Using it is very simple:

import torch.optim as optim

# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)

# in your training loop:
optimizer.zero_grad()   # zero the gradient buffers
output = net(input)
loss = criterion(output, target)
loss.backward()
optimizer.step()    # Does the update
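Putting the pieces together, a complete training loop just repeats these steps. A self-contained sketch with a stand-in nn.Linear model (so it runs without the Net class above); the loss should shrink over the iterations:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
net = nn.Linear(32, 10)                 # stand-in model for this sketch
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

input = torch.randn(1, 32)
target = torch.randn(1, 10)

initial = criterion(net(input), target).item()
for step in range(100):
    optimizer.zero_grad()               # clear gradients from the last step
    output = net(input)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()                    # apply the SGD update

print(loss.item() < initial)            # True: the loss decreased
```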
