一:读取电脑中已经下好的数据集
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
from torchvision import transforms
from torch.utils import data
# 获取数据集
batch_size = 256
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
root="./data", train=True, transform=trans, download=False)
mnist_test = torchvision.datasets.FashionMNIST(
root="./data", train=False, transform=trans, download=False)
def get_dataloader_workers():
"""调用进程来读取数据"""
return 0
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers())
test_iter = data.DataLoader(mnist_test,batch_size,shuffle=True,
num_workers=get_dataloader_workers())
具体路径参考个人电脑
二:实现一个具有单隐藏的多层感知机,它包含256个隐藏单元
# 设计隐藏层
num_inputs,num_outputs,num_hiddens = 784,10,256 # hiddens即为隐藏单元个数
# 隐藏层的输入层
W1 = nn.Parameter(torch.randn(num_inputs,num_hiddens,requires_grad=True)) # randn:以给定的形状创建一个随机数组,数组元素符合标准正态分布 N(0,1) ;nn.Parameter 可加可不加,之前就没有加
b1 = nn.Parameter(torch.zeros(num_hiddens,requires_grad=True)) #偏置的行数等于此层权重的列数
# 隐藏层的输出层 前一个隐藏层的输出
W2 = nn.Parameter(torch.randn(num_hiddens,num_outputs,requires_grad=True)) # 上一层的列数为下一层的行数
b2 = nn.Parameter(torch.randn(num_outputs,requires_grad=True)) # 偏置的行数等于此层的列数
params = [W1,b1,W2,b2]
对于矩阵参数的分析:
权重W = tensor(矩阵行,矩阵列)
行数为输入数,列数为输出数
也就是说,上一层的列数,做下一层的行数
并且,偏置的行数,即偏置的个数应该等于输出的个数,即等于此层的输出数 b=(此层列数)
三:手动实现激活函数ReLU
# 手动实现激活函数
def relu(X):
a = torch.zeros_like(X) # a是和X相同形状的零矩阵 不加like会报错说zero的参数必须是一个整数的元组,而不能是张量
return torch.max(X,a) # 矩阵X与0作比较,返回大的值,即只会保留正数部分
四:编写网络,用自己写的relu做激活函数
# 实现我们自己的模型
def net(X):
X = X.reshape((-1,num_inputs)) # 和隐藏层那里一样,行数作为输入,列数作为输出,这里是整个网络的输入层,输入设为-1,即根据数据集的输入自己定
# num_inputs作为输入层的输出,然后进到隐藏层的输入层
H = relu(X @ W1 + b1) # 矩阵乘法用@符号,简便一点,当然也可以用matmul(X,W1)来实现
return (H @ W2 + b2) # 返回第一层的输出和第二层的权重作乘法,再加上偏置
五:损失函数
loss = nn.CrossEntropyLoss() #用交叉熵做损失函数
六:训练
# 训练
num_epochs,lr = 3,0.1
updater = torch.optim.SGD(params,lr=lr)
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,updater)
资料来源:4.2. 多层感知机的从零开始实现 — 动手学深度学习 2.0.0-beta0 documentation (d2l.ai)