Deep Learning Notes (6): Training Issues of Classification Networks, Part 2

3.4 New model

3.4.1 Save model

PyTorch provides two ways to save a model. We recommend the method that saves only the parameters (the state_dict), because it is more flexible and does not depend on a fixed model class definition.

When saving parameters, we save not only the learnable parameters of the model but also the state of the optimizer.

A common PyTorch convention is to save models using either a .pt or .pth file extension.

Read more about saving and loading from this link.
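The convention above can be sketched with a hypothetical tiny `nn.Linear` model (standing in for the notebook's `FeedForwardNeuralNetwork`): both state_dicts are bundled into one checkpoint file and later restored into freshly constructed objects.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical tiny model standing in for FeedForwardNeuralNetwork.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Bundle both state_dicts into one checkpoint file (.pt by convention).
path = os.path.join(tempfile.gettempdir(), "checkpoint_demo.pt")
torch.save({"model": model.state_dict(),
            "optimizer": optimizer.state_dict()}, path)

# Restore into freshly constructed objects of the same shapes.
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.01, momentum=0.9)
checkpoint = torch.load(path)
new_model.load_state_dict(checkpoint["model"])
new_optimizer.load_state_dict(checkpoint["optimizer"])

print(torch.equal(new_model.weight, model.weight))  # True
```

Restoring the optimizer's state_dict matters when resuming training: momentum buffers (and the buffers of adaptive optimizers) would otherwise restart from scratch.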

# show parameters in model

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("\nOptimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
Model's state_dict:
hidden1.weight   torch.Size([100, 784])
hidden1.bias     torch.Size([100])
hidden2.weight   torch.Size([100, 100])
hidden2.bias     torch.Size([100])
hidden3.weight   torch.Size([100, 100])
hidden3.bias     torch.Size([100])
classification_layer.weight      torch.Size([10, 100])
classification_layer.bias    torch.Size([10])
hidden1_bn.weight    torch.Size([100])
hidden1_bn.bias      torch.Size([100])
hidden1_bn.running_mean      torch.Size([100])
hidden1_bn.running_var   torch.Size([100])
hidden1_bn.num_batches_tracked   torch.Size([])
hidden2_bn.weight    torch.Size([100])
hidden2_bn.bias      torch.Size([100])
hidden2_bn.running_mean      torch.Size([100])
hidden2_bn.running_var   torch.Size([100])
hidden2_bn.num_batches_tracked   torch.Size([])
hidden3_bn.weight    torch.Size([100])
hidden3_bn.bias      torch.Size([100])
hidden3_bn.running_mean      torch.Size([100])
hidden3_bn.running_var   torch.Size([100])
hidden3_bn.num_batches_tracked   torch.Size([])

Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.75, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4755824576, 4755820904, 4750998264, 4757925536, 4757922584, 4758702408, 4758703200, 4758702552, 4758702480, 4758702264, 4758703704, 4758702912, 4764186232, 4764188032]}]
# save model

save_path = './model.pt'
torch.save(model.state_dict(), save_path)
# load the saved parameters from file
saved_parametes = torch.load(save_path)
print(saved_parametes)
OrderedDict([('hidden1.weight', tensor([[ 0.0061,  0.0296, -0.0111,  ...,  0.0030, -0.0219, -0.0101],
        ...,
        [ 0.0564,  0.0475,  0.0173,  ...,  0.0403,  0.0442,  0.0449]])), ('hidden1.bias', tensor([-0.0168, -0.0027, -0.0294,  ..., -0.0440, -0.0806])), ('hidden2.weight', tensor([...])), ..., ('hidden3_bn.running_var', tensor([1., 1., 1.,  ..., 1., 1., 1.])), ('hidden3_bn.num_batches_tracked', tensor(0))])
(output truncated: the remaining weight, bias, and batch-norm tensors print in the same form)
# initialize a new model from the saved parameters
new_model = FeedForwardNeuralNetwork(input_size, hidden_size, output_size)
new_model.load_state_dict(saved_parametes)

3.4.2 Load model

Use the evaluate function to compute the average loss and accuracy of the new_model on the test_loader.

# TODO
new_test_loss, new_test_accuracy = evaluate(test_loader, new_model, loss_fn)
message = 'Average loss: {:.4f}, Accuracy: {:.4f}'.format(new_test_loss, new_test_accuracy)
print(message)
Average loss: 14.7253, Accuracy: 95.2300

4. Advanced Training

4.1 l2_norm

We can add the L2 regularization term below to the loss through the SGD optimizer's weight_decay argument:
\begin{equation}
L_{norm} = \sum_{i=1}^{m}{\theta_{i}^{2}}
\end{equation}
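A minimal sketch of what weight_decay does inside SGD: it adds weight_decay * θ to each parameter's gradient, which is the gradient of the penalty above (up to a constant factor). With a zero data gradient, the update comes from the penalty alone:

```python
import torch

# One parameter; the data loss contributes a zero gradient,
# so the update is driven purely by weight_decay.
theta = torch.nn.Parameter(torch.tensor([2.0]))
opt = torch.optim.SGD([theta], lr=0.1, weight_decay=0.5)

opt.zero_grad()
loss = (theta * 0).sum()  # data loss: gradient is 0
loss.backward()
opt.step()

# SGD adds weight_decay * theta to the gradient, so:
# theta_new = theta - lr * weight_decay * theta = 2.0 - 0.1 * 0.5 * 2.0 = 1.9
print(theta.item())  # 1.9
```

This is why a large weight_decay steadily shrinks every weight toward zero regardless of the data.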

4.1.1 l2_norm = 0.01

Set l2_norm=0.01, then train and observe the curves.

### Hyper parameters
batch_size = 128
n_epochs = 5
learning_rate = 0.01
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0.01 # use l2 penalty
get_grad = False

# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm) 
train_accs, train_losses, test_losses, test_accs = fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
show_curve(train_accs, test_accs, 'Accs')
show_curve(train_losses, test_losses, 'Losses')
Epoch: 1/5. Train set: Average loss: 1.9034, Accuracy: 74.8583
Epoch: 1/5. Validation set: Average loss: 0.9461, Accuracy: 75.4200
Epoch: 2/5. Train set: Average loss: 0.6313, Accuracy: 86.2433
Epoch: 2/5. Validation set: Average loss: 0.4580, Accuracy: 86.5500
Epoch: 3/5. Train set: Average loss: 0.4135, Accuracy: 89.0417
Epoch: 3/5. Validation set: Average loss: 0.3631, Accuracy: 89.3100
Epoch: 4/5. Train set: Average loss: 0.3531, Accuracy: 90.2200
Epoch: 4/5. Validation set: Average loss: 0.3268, Accuracy: 90.4500
Epoch: 5/5. Train set: Average loss: 0.3227, Accuracy: 90.9317
Epoch: 5/5. Validation set: Average loss: 0.3030, Accuracy: 91.1100
[Figure: training/validation accuracy curves ('Accs')]
[Figure: training/validation loss curves ('Losses')]

4.1.2 Problem 5

Consider the influence of the regularization term's proportion in the total loss: train the model with l2_norm = 1.

Hints: because Jupyter keeps earlier variables in the notebook's context, the model and the optimizer need to be re-created. They can be redefined with the following code. Note that the default initialization is used here.

# TODO

### Hyper parameters
batch_size = 128
n_epochs = 5
learning_rate = 0.01
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 1 # use l2 penalty
get_grad = False

# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm) 
# TODO
# Train
train_accs, train_losses, test_losses, test_accs = fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
show_curve(train_accs, test_accs, 'Accs')
show_curve(train_losses, test_losses, 'Losses')
Epoch: 1/5. Train set: Average loss: 2.3071, Accuracy: 11.2367
Epoch: 1/5. Validation set: Average loss: 2.3024, Accuracy: 11.3500
Epoch: 2/5. Train set: Average loss: 2.3073, Accuracy: 11.2367
Epoch: 2/5. Validation set: Average loss: 2.3024, Accuracy: 11.3500
Epoch: 3/5. Train set: Average loss: 2.3073, Accuracy: 11.2367
Epoch: 3/5. Validation set: Average loss: 2.3024, Accuracy: 11.3500
Epoch: 4/5. Train set: Average loss: 2.3073, Accuracy: 11.2367
Epoch: 4/5. Validation set: Average loss: 2.3024, Accuracy: 11.3500
Epoch: 5/5. Train set: Average loss: 2.3073, Accuracy: 11.2367
Epoch: 5/5. Validation set: Average loss: 2.3024, Accuracy: 11.3500
image
image

We can see that when the l2 penalty is too large, the regularization term dominates the loss: the weights are driven toward zero and accuracy stays stuck near chance level (about 11%).
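The weight_decay argument of torch.optim.SGD implements the l2 penalty by adding l2_norm * w to each parameter's gradient before the update. A minimal standalone sketch (a single toy parameter, not the notebook's model) confirming the update rule:

```python
import torch

# One toy parameter; the loss is identically zero, so the only gradient
# comes from weight decay: SGD updates w <- w - lr * (grad + l2_norm * w)
w = torch.nn.Parameter(torch.tensor([2.0]))
optimizer = torch.optim.SGD([w], lr=0.1, weight_decay=1.0)

loss = (w * 0).sum()   # zero loss => zero data gradient
loss.backward()
optimizer.step()

# w <- 2.0 - 0.1 * (0 + 1.0 * 2.0) = 1.8
print(w.item())        # approximately 1.8
```

With l2_norm = 1 every step also shrinks each weight by lr * w, which is why training above cannot escape chance accuracy.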

4.2 dropout

During training, randomly zeroes some of the elements of the input tensor with probability p using samples from a Bernoulli distribution.

Each channel will be zeroed out independently on every forward call.
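This behavior is easy to see directly with nn.Dropout: during training, survivors are scaled by 1/(1-p) so the expected activation is unchanged, and in eval mode the layer is the identity. A small standalone sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(10)

# Training mode: each element is zeroed with probability p = 0.5 and the
# survivors are scaled by 1/(1-p) = 2, so the expected value is unchanged
drop.train()
y_train = drop(x)
print(y_train)          # a mixture of 0.0 and 2.0

# Eval mode: dropout is the identity
drop.eval()
y_eval = drop(x)
print(y_eval)           # all ones
```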

Hints: because Jupyter keeps variables in its context, the model and the optimizer need to be re-declared. They can be redefined using the following code. Note that the default initialization is used here.

### Hyper parameters
batch_size = 128
n_epochs = 5
learning_rate = 0.01
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0 # without using l2 penalty
get_grad = False

# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm) 
# Set dropout to True and probability = 0.5
model.set_use_dropout(True)
train_accs, train_losses, test_losses, test_accs = fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
show_curve(train_accs, test_accs, 'Accs')
show_curve(train_losses, test_losses, 'Losses')
Epoch: 1/5. Train set: Average loss: 0.3335, Accuracy: 92.6233
Epoch: 1/5. Validation set: Average loss: 0.2438, Accuracy: 92.5300
Epoch: 2/5. Train set: Average loss: 0.3065, Accuracy: 93.3100
Epoch: 2/5. Validation set: Average loss: 0.2221, Accuracy: 93.1600
Epoch: 3/5. Train set: Average loss: 0.2794, Accuracy: 93.8617
Epoch: 3/5. Validation set: Average loss: 0.2036, Accuracy: 93.6500
Epoch: 4/5. Train set: Average loss: 0.2576, Accuracy: 94.3500
Epoch: 4/5. Validation set: Average loss: 0.1894, Accuracy: 94.1400
Epoch: 5/5. Train set: Average loss: 0.2373, Accuracy: 94.7400
Epoch: 5/5. Validation set: Average loss: 0.1768, Accuracy: 94.5100
image
image

4.3 batch_normalization

Batch normalization is a technique for improving the performance and stability of artificial neural networks:

\begin{equation}
y=\frac{x-E[x]}{\sqrt{Var[x]+\epsilon}} * \gamma + \beta,
\end{equation}

where $\gamma$ and $\beta$ are learnable parameters.
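The formula can be checked numerically against nn.BatchNorm1d in training mode (freshly initialized, so gamma = 1 and beta = 0); note that the normalization uses the biased batch variance. A standalone sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(32, 100)        # a batch of 32 samples with 100 features

bn = nn.BatchNorm1d(100)        # freshly initialized: gamma = 1, beta = 0
bn.train()                      # training mode: normalize with batch statistics
y = bn(x)

# Reproduce y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta by hand;
# BatchNorm normalizes with the biased batch variance (unbiased=False)
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
y_manual = (x - mean) / torch.sqrt(var + bn.eps) * bn.weight + bn.bias

print(torch.allclose(y, y_manual, atol=1e-5))   # True
```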

Hints: because Jupyter keeps variables in its context, the model and the optimizer need to be re-declared. They can be redefined using the following code. Note that the default initialization is used here.

### Hyper parameters
batch_size = 128
n_epochs = 5
learning_rate = 0.01
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0 # without using l2 penalty
get_grad = False

# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm) 
model.set_use_bn(True)
model.use_bn
True
train_accs, train_losses, test_losses, test_accs = fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
show_curve(train_accs, test_accs, 'Accs')
show_curve(train_losses, test_losses, 'Losses')
Epoch: 1/5. Train set: Average loss: 1.0761, Accuracy: 91.1733
Epoch: 1/5. Validation set: Average loss: 0.4680, Accuracy: 91.1000
Epoch: 2/5. Train set: Average loss: 0.3410, Accuracy: 94.5100
Epoch: 2/5. Validation set: Average loss: 0.2490, Accuracy: 94.1800
Epoch: 3/5. Train set: Average loss: 0.2136, Accuracy: 95.9850
Epoch: 3/5. Validation set: Average loss: 0.1795, Accuracy: 95.5600
Epoch: 4/5. Train set: Average loss: 0.1589, Accuracy: 96.8617
Epoch: 4/5. Validation set: Average loss: 0.1459, Accuracy: 96.3400
Epoch: 5/5. Train set: Average loss: 0.1268, Accuracy: 97.4000
Epoch: 5/5. Validation set: Average loss: 0.1269, Accuracy: 96.6400
image
image

4.4 data augmentation

data augmentation enlarges the effective training set with label-preserving transforms, which can improve generalization on the test dataset

# only add random horizontal flip
train_transform_1 = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), # Convert a PIL Image or numpy.ndarray to tensor.
    # Normalize a tensor image with mean and standard deviation
    transforms.Normalize((0.1307,), (0.3081,))
])

# only add random crop
train_transform_2 = transforms.Compose([
    transforms.RandomCrop(size=[28,28], padding=4),
    transforms.ToTensor(), # Convert a PIL Image or numpy.ndarray to tensor.
    # Normalize a tensor image with mean and standard deviation
    transforms.Normalize((0.1307,), (0.3081,))
])

# add random horizontal flip and random crop
train_transform_3 = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(size=[28,28], padding=4),
    transforms.ToTensor(), # Convert a PIL Image or numpy.ndarray to tensor.
    # Normalize a tensor image with mean and standard deviation
    transforms.Normalize((0.1307,), (0.3081,))
])
# reload train_loader using trans

train_dataset_1 = torchvision.datasets.MNIST(root='./data', 
                            train=True, 
                            transform=train_transform_1,
                            download=False)

train_loader_1 = torch.utils.data.DataLoader(dataset=train_dataset_1, 
                                           batch_size=batch_size, 
                                           shuffle=True)
print(train_dataset_1)
Dataset MNIST
    Number of datapoints: 60000
    Split: train
    Root Location: ./data
    Transforms (if any): Compose(
                             RandomHorizontalFlip(p=0.5)
                             ToTensor()
                             Normalize(mean=(0.1307,), std=(0.3081,))
                         )
    Target Transforms (if any): None
### Hyper parameters
batch_size = 128
n_epochs = 5
learning_rate = 0.01
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0 # without using l2 penalty
get_grad = False

# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm) 
train_accs, train_losses, test_losses, test_accs = fit(train_loader_1, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
show_curve(train_accs, test_accs, 'Accs')
show_curve(train_losses, test_losses, 'Losses')
Epoch: 1/5. Train set: Average loss: 2.0015, Accuracy: 66.7167
Epoch: 1/5. Validation set: Average loss: 1.2088, Accuracy: 67.6700
Epoch: 2/5. Train set: Average loss: 0.8502, Accuracy: 78.9600
Epoch: 2/5. Validation set: Average loss: 0.6482, Accuracy: 79.7700
Epoch: 3/5. Train set: Average loss: 0.6221, Accuracy: 82.1050
Epoch: 3/5. Validation set: Average loss: 0.5469, Accuracy: 82.7900
Epoch: 4/5. Train set: Average loss: 0.5425, Accuracy: 83.7417
Epoch: 4/5. Validation set: Average loss: 0.4863, Accuracy: 84.2700
Epoch: 5/5. Train set: Average loss: 0.4813, Accuracy: 85.9383
Epoch: 5/5. Validation set: Average loss: 0.4333, Accuracy: 86.1800
image
image

4.5 Problem 6

Use the provided train_transform_2 and train_transform_3 to reload train_loader, then train with fit.

Hints: because Jupyter keeps variables in its context, the model and the optimizer need to be re-declared. Note that the default initialization is used here.

# TODO

# reload train_loader using train_transform_2

train_dataset_2 = torchvision.datasets.MNIST(root='./data', 
                            train=True, 
                            transform=train_transform_2,
                            download=False)

train_loader_2 = torch.utils.data.DataLoader(dataset=train_dataset_2, 
                                           batch_size=batch_size, 
                                           shuffle=True)

train_accs, train_losses, test_losses, test_accs = fit(train_loader_2, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
show_curve(train_accs, test_accs, 'Accs')
show_curve(train_losses, test_losses, 'Losses')
Epoch: 1/5. Train set: Average loss: 1.3406, Accuracy: 62.0983
Epoch: 1/5. Validation set: Average loss: 0.9176, Accuracy: 74.7300
Epoch: 2/5. Train set: Average loss: 1.0130, Accuracy: 72.4767
Epoch: 2/5. Validation set: Average loss: 0.7144, Accuracy: 79.6100
Epoch: 3/5. Train set: Average loss: 0.7818, Accuracy: 78.9767
Epoch: 3/5. Validation set: Average loss: 0.5295, Accuracy: 84.8800
Epoch: 4/5. Train set: Average loss: 0.6261, Accuracy: 82.3433
Epoch: 4/5. Validation set: Average loss: 0.4338, Accuracy: 87.4800
Epoch: 5/5. Train set: Average loss: 0.5252, Accuracy: 85.5233
Epoch: 5/5. Validation set: Average loss: 0.3735, Accuracy: 89.0300
image
image
# TODO

# reload train_loader using train_transform_3

train_dataset_3 = torchvision.datasets.MNIST(root='./data', 
                            train=True, 
                            transform=train_transform_3,
                            download=False)

train_loader_3 = torch.utils.data.DataLoader(dataset=train_dataset_3, 
                                           batch_size=batch_size, 
                                           shuffle=True)

train_accs, train_losses, test_losses, test_accs = fit(train_loader_3, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
show_curve(train_accs, test_accs, 'Accs')
show_curve(train_losses, test_losses, 'Losses')
Epoch: 1/5. Train set: Average loss: 0.7662, Accuracy: 78.1667
Epoch: 1/5. Validation set: Average loss: 0.4631, Accuracy: 86.4500
Epoch: 2/5. Train set: Average loss: 0.6339, Accuracy: 80.5283
Epoch: 2/5. Validation set: Average loss: 0.4665, Accuracy: 86.0200
Epoch: 3/5. Train set: Average loss: 0.5718, Accuracy: 81.7750
Epoch: 3/5. Validation set: Average loss: 0.4170, Accuracy: 86.5700
Epoch: 4/5. Train set: Average loss: 0.5321, Accuracy: 83.5950
Epoch: 4/5. Validation set: Average loss: 0.3840, Accuracy: 87.6100
Epoch: 5/5. Train set: Average loss: 0.5019, Accuracy: 83.8067
Epoch: 5/5. Validation set: Average loss: 0.3902, Accuracy: 87.7000
image
image

5. Visualization of training and validation phase

We could use TensorBoard to visualize our training and validation phases.
You could find an example here

6. Gradient explosion and vanishing

We have embedded code that records the gradients of the hidden2 and hidden3 layers. By observing how their gradients change, we can see whether the gradient flow is normal or not.

To plot the gradient changes, you need to set get_grad=True in the fit function
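The plotting itself lives inside fit, but the core idea is simply to read each parameter's .grad after backward(). A minimal standalone sketch using a stand-in 784-100-100-10 network (nn.Sequential here, not the notebook's FeedForwardNeuralNetwork):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A stand-in 784-100-100-10 network (nn.Sequential here, not the
# notebook's FeedForwardNeuralNetwork)
model = nn.Sequential(
    nn.Linear(784, 100), nn.ReLU(),
    nn.Linear(100, 100), nn.ReLU(),
    nn.Linear(100, 10),
)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(128, 784)
target = torch.randint(0, 10, (128,))

loss = loss_fn(model(x), target)
loss.backward()

# After backward(), every parameter's .grad is populated; its norm
# reveals vanishing (near zero) or exploding (huge / nan) gradients
for name, p in model.named_parameters():
    print(f"{name}: grad norm = {p.grad.norm().item():.6f}")
```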

### Hyper parameters
batch_size = 128
n_epochs = 15
learning_rate = 0.01
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0 # use l2 penalty
get_grad = True

# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm) 
fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
Epoch: 1/15. Train set: Average loss: 1.8883, Accuracy: 77.2633
Epoch: 1/15. Validation set: Average loss: 0.8983, Accuracy: 77.9100
Epoch: 2/15. Train set: Average loss: 0.5687, Accuracy: 87.7217
Epoch: 2/15. Validation set: Average loss: 0.4038, Accuracy: 88.0700
Epoch: 3/15. Train set: Average loss: 0.3675, Accuracy: 89.9283
Epoch: 3/15. Validation set: Average loss: 0.3260, Accuracy: 90.1600
Epoch: 4/15. Train set: Average loss: 0.3123, Accuracy: 91.1600
Epoch: 4/15. Validation set: Average loss: 0.2863, Accuracy: 91.4200
Epoch: 5/15. Train set: Average loss: 0.2793, Accuracy: 92.1150
Epoch: 5/15. Validation set: Average loss: 0.2593, Accuracy: 92.2500
Epoch: 6/15. Train set: Average loss: 0.2543, Accuracy: 92.8367
Epoch: 6/15. Validation set: Average loss: 0.2384, Accuracy: 92.8200
Epoch: 7/15. Train set: Average loss: 0.2336, Accuracy: 93.4067
Epoch: 7/15. Validation set: Average loss: 0.2208, Accuracy: 93.4100
Epoch: 8/15. Train set: Average loss: 0.2155, Accuracy: 93.9067
Epoch: 8/15. Validation set: Average loss: 0.2052, Accuracy: 93.8500
Epoch: 9/15. Train set: Average loss: 0.1995, Accuracy: 94.3783
Epoch: 9/15. Validation set: Average loss: 0.1911, Accuracy: 94.1600
Epoch: 10/15. Train set: Average loss: 0.1854, Accuracy: 94.7917
Epoch: 10/15. Validation set: Average loss: 0.1789, Accuracy: 94.5500
Epoch: 11/15. Train set: Average loss: 0.1727, Accuracy: 95.1583
Epoch: 11/15. Validation set: Average loss: 0.1682, Accuracy: 94.8800
Epoch: 12/15. Train set: Average loss: 0.1615, Accuracy: 95.4683
Epoch: 12/15. Validation set: Average loss: 0.1588, Accuracy: 95.1600
Epoch: 13/15. Train set: Average loss: 0.1516, Accuracy: 95.7700
Epoch: 13/15. Validation set: Average loss: 0.1507, Accuracy: 95.3900
Epoch: 14/15. Train set: Average loss: 0.1427, Accuracy: 96.0317
Epoch: 14/15. Validation set: Average loss: 0.1437, Accuracy: 95.6500
Epoch: 15/15. Train set: Average loss: 0.1348, Accuracy: 96.2417
Epoch: 15/15. Validation set: Average loss: 0.1376, Accuracy: 95.8400





([77.26333333333334,
  87.72166666666666,
  89.92833333333333,
  91.16,
  92.115,
  92.83666666666667,
  93.40666666666667,
  93.90666666666667,
  94.37833333333333,
  94.79166666666667,
  95.15833333333333,
  95.46833333333333,
  95.77,
  96.03166666666667,
  96.24166666666666],
 [1.8883255884433403,
  0.5687443313117211,
  0.36754155533117616,
  0.31234517640983445,
  0.27934257469625556,
  0.25430761317475736,
  0.23359582908292356,
  0.21554398813690895,
  0.1995451689307761,
  0.1853731685023532,
  0.17268824516835377,
  0.16149521451921034,
  0.1515944946843844,
  0.142730517917846,
  0.13476479675971034],
 [0.8983050381081014,
  0.40381407219020626,
  0.32599611438905135,
  0.2863018473586704,
  0.25928632353868664,
  0.23837185495450527,
  0.22084368661611894,
  0.20515649761014346,
  0.19110500274956982,
  0.17893974940422214,
  0.16822792386895494,
  0.15882641767870775,
  0.15071836245965353,
  0.14373108235341084,
  0.1375972312651103],
 [77.91,
  88.07,
  90.16,
  91.42,
  92.25,
  92.82,
  93.41,
  93.85,
  94.16,
  94.55,
  94.88,
  95.16,
  95.39,
  95.65,
  95.84])
image

6.1.1 Gradient Vanishing

Set learning_rate = 1e-20

### Hyper parameters
batch_size = 128
n_epochs = 15
learning_rate = 1e-20
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0 # use l2 penalty
get_grad = True

# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm) 
fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad=get_grad)
Epoch: 1/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 1/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 2/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 2/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 3/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 3/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 4/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 4/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 5/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 5/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 6/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 6/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 7/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 7/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 8/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 8/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 9/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 9/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 10/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 10/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 11/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 11/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 12/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 12/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 13/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 13/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 14/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 14/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 15/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 15/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900





([14.683333333333334,
  14.683333333333334,
  14.683333333333334,
  14.683333333333334,
  14.683333333333334,
  14.683333333333334,
  14.683333333333334,
  14.683333333333334,
  14.683333333333334,
  14.683333333333334,
  14.683333333333334,
  14.683333333333334,
  14.683333333333334,
  14.683333333333334,
  14.683333333333334],
 [2.3074400210991883,
  2.3074400210991883,
  2.3074400210991883,
  2.3074400210991883,
  2.3074400210991883,
  2.3074400210991883,
  2.3074400210991883,
  2.3074400210991883,
  2.3074400210991883,
  2.3074400210991883,
  2.3074400210991883,
  2.3074400210991883,
  2.3074400210991883,
  2.3074400210991883,
  2.3074400210991883],
 [2.3010528570489037,
  2.3010528570489037,
  2.3010528570489037,
  2.3010528570489037,
  2.3010528570489037,
  2.3010528570489037,
  2.3010528570489037,
  2.3010528570489037,
  2.3010528570489037,
  2.3010528570489037,
  2.3010528570489037,
  2.3010528570489037,
  2.3010528570489037,
  2.3010528570489037,
  2.3010528570489037],
 [15.29,
  15.29,
  15.29,
  15.29,
  15.29,
  15.29,
  15.29,
  15.29,
  15.29,
  15.29,
  15.29,
  15.29,
  15.29,
  15.29,
  15.29])
image
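Why nothing moves at all: with learning_rate = 1e-20 the update lr * grad is around 1e-21, far below float32's roughly 7 decimal digits of relative precision at typical weight magnitudes, so the subtraction returns the weight bit-for-bit unchanged. A tiny sketch with illustrative magnitudes (the 0.05 and 0.1 values are assumptions, not taken from the model):

```python
import torch

w = torch.tensor([0.05], dtype=torch.float32)   # a typical weight magnitude
grad = torch.tensor([0.1])                      # a typical gradient magnitude

learning_rate = 1e-20
update = learning_rate * grad                   # about 1e-21

# float32 carries ~7 decimal digits, so 0.05 - 1e-21 rounds back to 0.05:
# the weight is bit-for-bit unchanged and training makes no progress
w_new = w - update
print(torch.equal(w, w_new))   # True
```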

6.1.2 Gradient Explosion

6.1.2.1 learning rate

Set learning_rate = 1.0168, which is large enough to destabilize training.

### Hyper parameters
batch_size = 128
n_epochs = 15
learning_rate = 1.0168
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0 # not to use l2 penalty
get_grad = True

# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm) 
fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad=True)
Epoch: 1/15. Train set: Average loss: 2.0630, Accuracy: 26.7583
Epoch: 1/15. Validation set: Average loss: 2.1282, Accuracy: 26.7700
Epoch: 2/15. Train set: Average loss: 2.2670, Accuracy: 10.0900
Epoch: 2/15. Validation set: Average loss: 2.2986, Accuracy: 9.7600
Epoch: 3/15. Train set: Average loss: 2.1061, Accuracy: 18.4283
Epoch: 3/15. Validation set: Average loss: 2.0783, Accuracy: 17.8500
Epoch: 4/15. Train set: Average loss: 2.0247, Accuracy: 19.7433
Epoch: 4/15. Validation set: Average loss: 1.9792, Accuracy: 19.1600
Epoch: 5/15. Train set: Average loss: 1.8996, Accuracy: 27.6817
Epoch: 5/15. Validation set: Average loss: 1.7469, Accuracy: 27.8600
Epoch: 6/15. Train set: Average loss: 1.9673, Accuracy: 19.7900
Epoch: 6/15. Validation set: Average loss: 1.8792, Accuracy: 19.4400
Epoch: 7/15. Train set: Average loss: 1.9726, Accuracy: 19.0433
Epoch: 7/15. Validation set: Average loss: 1.9119, Accuracy: 18.3700
Epoch: 8/15. Train set: Average loss: 1.8971, Accuracy: 19.3833
Epoch: 8/15. Validation set: Average loss: 2.0936, Accuracy: 19.1200
Epoch: 9/15. Train set: Average loss: 2.0608, Accuracy: 21.0750
Epoch: 9/15. Validation set: Average loss: 1.9886, Accuracy: 21.0400


/Users/nino/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:80: RuntimeWarning: overflow encountered in square
/Users/nino/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:81: RuntimeWarning: overflow encountered in square


Epoch: 10/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 10/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 11/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 11/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 12/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 12/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 13/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 13/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 14/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 14/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 15/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 15/15. Validation set: Average loss: nan, Accuracy: 9.8000





([26.758333333333333,
  10.09,
  18.428333333333335,
  19.743333333333332,
  27.68166666666667,
  19.79,
  19.043333333333333,
  19.383333333333333,
  21.075,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666],
 [2.0630336391110706,
  2.267040777664918,
  2.1061492269365196,
  2.024679694165531,
  1.8995909976144123,
  1.9672510224020379,
  1.9726365409855149,
  1.8970981736977894,
  2.060765859153536,
  nan,
  nan,
  nan,
  nan,
  nan,
  nan],
 [2.1281878012645095,
  2.2986260969427565,
  2.0783027516135686,
  1.9791795317130754,
  1.7469420357595515,
  1.8791846338706681,
  1.9119218029553378,
  2.093622092959247,
  1.988625618475902,
  nan,
  nan,
  nan,
  nan,
  nan,
  nan],
 [26.77,
  9.76,
  17.85,
  19.16,
  27.86,
  19.44,
  18.37,
  19.12,
  21.04,
  9.8,
  9.8,
  9.8,
  9.8,
  9.8,
  9.8])
image
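A common remedy for explosion caused by a large step size, not used in this notebook, is to clip the gradient norm before optimizer.step() with torch.nn.utils.clip_grad_norm_. A standalone sketch on a toy linear model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

# Fabricate enormous inputs so the gradients explode
x = torch.randn(4, 10) * 1e6
loss = model(x).pow(2).mean()
loss.backward()

# Rescale the total gradient norm down to max_norm before the update
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

total = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
print(total.item())     # <= 1.0 (up to rounding)
optimizer.step()
```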

6.1.2.2 normalization for input data

6.1.2.3 unsuitable weight initialization

### Hyper parameters
batch_size = 128
n_epochs = 15
learning_rate = 1
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0 # not to use l2 penalty
get_grad = True

# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm) 
# re-initialize parameters with an unsuitably large scale
def wrong_weight_bias_reset(model):
    """Initialize the model's parameters from a normal distribution with
    mean=0, std=1 (too large a scale for these layers).
    """
    for m in model.modules():
        if isinstance(m, nn.Linear):
            # initialize the linear layer with the given mean and std
            mean, std = 0, 1

            # Initialization method
            torch.nn.init.normal_(m.weight, mean, std)
            torch.nn.init.normal_(m.bias, mean, std)
wrong_weight_bias_reset(model)
show_weight_bias(model)
/Users/nino/anaconda3/lib/python3.7/site-packages/matplotlib/figure.py:2299: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  warnings.warn("This figure includes Axes that are not compatible "
image
fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad=True)
/Users/nino/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:80: RuntimeWarning: overflow encountered in square
/Users/nino/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:81: RuntimeWarning: overflow encountered in square


Epoch: 1/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 1/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 2/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 2/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 3/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 3/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 4/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 4/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 5/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 5/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 6/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 6/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 7/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 7/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 8/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 8/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 9/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 9/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 10/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 10/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 11/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 11/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 12/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 12/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 13/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 13/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 14/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 14/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 15/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 15/15. Validation set: Average loss: nan, Accuracy: 9.8000





([9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666,
  9.871666666666666],
 [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
 [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
 [9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8])
image
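The failure comes from the scale, not the distribution: with std = 1 weights on a 784-input layer, each pre-activation is a sum of 784 terms and its standard deviation grows to about sqrt(784) = 28, which blows up downstream. Fan-in-scaled schemes such as Kaiming initialization, shown here as a standard alternative rather than part of this notebook, keep the scale controlled:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(784, 100)
x = torch.randn(512, 784)

# N(0, 1) weights (as in wrong_weight_bias_reset): the pre-activation
# std grows like sqrt(fan_in) = sqrt(784) = 28
nn.init.normal_(layer.weight, 0, 1)
nn.init.normal_(layer.bias, 0, 1)
std_bad = layer(x).std().item()
print(std_bad)          # roughly 28

# Kaiming (He) init scales the weights by sqrt(2 / fan_in), keeping the
# pre-activation std near sqrt(2) for ReLU networks
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
nn.init.zeros_(layer.bias)
std_good = layer(x).std().item()
print(std_good)         # roughly 1.4
```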

References

  1. Training a Classifier
  2. Save Model and Load Model
  3. Visualize your training phase
  4. Exploding and Vanishing Gradients
  5. Gradient disappearance and gradient explosion in neural network training
  6. tensorboardX
