References
Building and training a basic network, with notes on commonly used functions:
the torch.nn.functional module
the nn.Module module
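A quick orientation before the code (a minimal sketch of my own, not from the original): torch.nn.functional exposes stateless functions you call directly, while nn.Module subclasses bundle learnable parameters with a forward pass. For a parameter-free op like ReLU the two are interchangeable:
import torch
import torch.nn as nn
import torch.nn.functional as F
x = torch.randn(2, 4)
y1 = F.relu(x)              # functional form: a plain function, nothing stored
y2 = nn.ReLU()(x)           # module form: an object, handy inside nn.Sequential
print(torch.equal(y1, y2))  # True: both compute max(x, 0) elementwise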
%matplotlib inline
from pathlib import Path
import requests
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)
URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"
if not (PATH / FILENAME).exists():
    content = requests.get(URL + FILENAME).content
    (PATH / FILENAME).open("wb").write(content)
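Note: the deeplearning.net host has been offline for long stretches. If the download fails, the official PyTorch tutorials repository hosts a copy of the same file; pointing URL there is a drop-in change (an assumption about current hosting, worth verifying):
# Assumed mirror of mnist.pkl.gz used by the PyTorch "What is torch.nn really?" tutorial.
URL = "https://github.com/pytorch/tutorials/raw/main/_static/"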
import pickle
import gzip
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
import matplotlib.pyplot as plt
import numpy as np
plt.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)  # 784 = 28*28*1 (H*W*C)
(50000, 784)
from IPython.display import Image
Image(filename = 'path')  # 'path' is a placeholder for an image file to display inline
import torch
# Convert the NumPy arrays to tensors for training
x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid))
n, c = x_train.shape
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())
tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]) tensor([5, 0, 4, ..., 8, 4, 8])
torch.Size([50000, 784])
tensor(0) tensor(9)
import torch.nn.functional as F
loss_func = F.cross_entropy
def model(xb):
    return xb.mm(weights) + bias
bs = 64
xb = x_train[:bs] # batch
yb = y_train[:bs]
weights = torch.randn([784, 10], dtype = torch.float, requires_grad = True)
bias = torch.zeros(10, requires_grad = True)
print(loss_func(model(xb),yb))
tensor(15.3532, grad_fn=<NllLossBackward>)
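As a side note (a small verification sketch, my addition): F.cross_entropy fuses F.log_softmax and F.nll_loss, so composing them by hand reproduces the value above:
logits = model(xb)
# cross_entropy(logits, targets) == nll_loss(log_softmax(logits, dim=1), targets)
print(F.nll_loss(F.log_softmax(logits, dim=1), yb))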
import torch.nn as nn
class Mnist_NN(nn.Module):
    def __init__(self):
        super(Mnist_NN, self).__init__()
        # two fully connected hidden layers, then a 10-way output layer
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        return self.out(x)
net = Mnist_NN()
print(net)
Mnist_NN(
(hidden1): Linear(in_features=784, out_features=128, bias=True)
(hidden2): Linear(in_features=128, out_features=256, bias=True)
(out): Linear(in_features=256, out_features=10, bias=True)
)
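A quick shape check (my addition) pushes two samples through the untrained network; each output row holds one logit per digit class:
with torch.no_grad():
    print(net(x_train[:2]).shape)  # torch.Size([2, 10])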
for name, parameter in net.named_parameters():
    print(name, parameter, parameter.size())
hidden1.weight Parameter containing:
tensor([[ 0.0254, -0.0086, -0.0126, ..., 0.0136, 0.0186, 0.0039],
        [-0.0018, -0.0231, -0.0341, ..., 0.0008, -0.0026, -0.0084],
        ...,
        [ 0.0250, -0.0224, 0.0345, ..., 0.0154, -0.0297, 0.0179]],
       requires_grad=True) torch.Size([128, 784])
hidden1.bias Parameter containing:
tensor([-6.5502e-03, -5.5020e-03, -3.2130e-02, ..., 1.5004e-02, 3.1973e-02],
       requires_grad=True) torch.Size([128])
hidden2.weight Parameter containing:
tensor([[-0.0347, 0.0167, -0.0324, ..., -0.0582, 0.0603, 0.0834],
        ...,
        [-0.0772, 0.0241, 0.0877, ..., -0.0487, 0.0440, -0.0370]],
       requires_grad=True) torch.Size([256, 128])
hidden2.bias Parameter containing:
tensor([-4.4112e-02, -2.9887e-02, -2.8265e-02, ..., 3.9917e-02, -2.7212e-02,
        -2.1461e-02], requires_grad=True) torch.Size([256])
out.weight Parameter containing:
tensor([[ 0.0095, 0.0141, -0.0027, ..., 0.0513, 0.0086, 0.0472],
        ...,
        [-0.0130, -0.0108, -0.0064, ..., 0.0567, -0.0260, 0.0179]],
       requires_grad=True) torch.Size([10, 256])
out.bias Parameter containing:
tensor([-0.0549, 0.0200, -0.0585, -0.0196, 0.0322, 0.0346, -0.0047, 0.0268,
        0.0280, -0.0420], requires_grad=True) torch.Size([10])
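The sizes printed above cover every learnable tensor; a one-liner (my addition) totals them:
# 784*128 + 128 + 128*256 + 256 + 256*10 + 10 = 136074
print(sum(p.numel() for p in net.parameters()))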
from torch.utils.data import TensorDataset, DataLoader
bs = 64
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(dataset = train_ds, batch_size = bs, shuffle = True)
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(dataset = valid_ds, batch_size = bs * 2)
def get_dataloader(train_ds, valid_ds, batch_size):
    return (
        DataLoader(dataset = train_ds, batch_size = batch_size, shuffle = True),
        # validation runs without gradients, so a larger batch fits in memory
        DataLoader(dataset = valid_ds, batch_size = batch_size * 2),
    )
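To confirm the loaders yield what the model expects (a quick check, my addition), draw a single batch:
xb, yb = next(iter(train_dl))
print(xb.shape, yb.shape)  # torch.Size([64, 784]) torch.Size([64])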
import numpy as np
def loss_batch(model, loss_func, xb, yb, opt=None):
    # forward pass
    loss = loss_func(model(xb), yb)
    if opt is not None:
        # zero the accumulated gradients
        opt.zero_grad()
        # backpropagate
        loss.backward()
        # update the parameters
        opt.step()
    return loss.item(), len(xb)
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        # training mode: update parameters batch by batch
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        # evaluation mode: size-weighted average loss over the validation set
        model.eval()
        with torch.no_grad():
            losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('step ' + str(step), 'validation loss: ' + str(val_loss))
from torch.optim import SGD
def get_model():
    model = Mnist_NN()
    return model, SGD(model.parameters(), lr = 0.001)
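A common variation (my suggestion, not part of the original): Adam usually converges much faster than plain SGD on this task; a hypothetical get_model_adam would differ by one line:
from torch.optim import Adam

def get_model_adam():  # hypothetical variant; not used in the run below
    model = Mnist_NN()
    return model, Adam(model.parameters(), lr = 0.001)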
train_dl, valid_dl = get_dataloader(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)
step 0 validation loss: 2.275792053222656
step 1 validation loss: 2.2407693840026854
step 2 validation loss: 2.187168152999878
step 3 validation loss: 2.0995729919433592
step 4 validation loss: 1.957993221282959
step 5 validation loss: 1.7496695947647094
step 6 validation loss: 1.4938231435775757
step 7 validation loss: 1.246651389503479
step 8 validation loss: 1.0482983011245728
step 9 validation loss: 0.9003914960861206
step 10 validation loss: 0.7902145584106446
step 11 validation loss: 0.7077002679824829
step 12 validation loss: 0.6448616947174072
step 13 validation loss: 0.5954384657859803
step 14 validation loss: 0.5561129764556885
step 15 validation loss: 0.524388509464264
step 16 validation loss: 0.49813378930091856
step 17 validation loss: 0.4757344177246094
step 18 validation loss: 0.4564040184020996
step 19 validation loss: 0.4398327290534973
step 20 validation loss: 0.4252701699256897
step 21 validation loss: 0.4126383014202118
step 22 validation loss: 0.40154306089878083
step 23 validation loss: 0.39144011669158935
step 24 validation loss: 0.3828064109802246
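The loop above only reports the loss; a short follow-up (my addition) estimates validation accuracy by taking the argmax over the 10 logits:
correct = 0
with torch.no_grad():
    for xb, yb in valid_dl:
        preds = model(xb).argmax(dim=1)        # predicted digit per sample
        correct += (preds == yb).sum().item()
print('validation accuracy:', correct / len(valid_ds))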