# 1. Import packages
import torch
from torch import nn
# 2. Download the Fashion-MNIST dataset and load it into memory
from torchvision import transforms
import torchvision
from torch.utils import data
def load_data_fashion_mnist(batch_size, resize=None, root="../data", num_workers=4):
    """Download Fashion-MNIST and return ``(train_loader, test_loader)``.

    Args:
        batch_size: Number of examples per minibatch for both loaders.
        resize: Optional size passed to ``transforms.Resize``, which is applied
            before ``ToTensor``. Any falsy value skips resizing.
        root: Directory the dataset is downloaded to and read from.
        num_workers: Worker processes used by each ``DataLoader``. Pass 0 to
            load in the main process, e.g. when the calling script has no
            ``if __name__ == "__main__":`` guard and the platform starts worker
            processes with "spawn".

    Returns:
        A tuple of two ``DataLoader`` objects. The training loader shuffles;
        the test loader does not.
    """
    steps = [transforms.ToTensor()]
    if resize:
        # Resize must operate on the PIL image, so it goes before ToTensor.
        steps.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(steps)
    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=num_workers),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=num_workers))
# Minibatch size shared by the training and test loaders.
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)
# 3. Put a flatten layer before the linear layer to reshape the network input
# Flatten each input into a vector, then map its 784 features to 10 outputs
# with one fully connected layer. 784 = 28 * 28, presumably the un-resized
# Fashion-MNIST image size (no resize is requested above) — confirm.
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(in_features=784, out_features=10),
)
# 4. Initialize the weights from a normal distribution
def init_weights(m):
    """Draw the weights of an ``nn.Linear`` module from N(0, 0.01**2), in place.

    Intended for ``net.apply(init_weights)``, which calls it on every submodule.
    Biases and modules of other types are left untouched.
    """
    # isinstance (rather than an exact ``type(m) == nn.Linear`` comparison) so
    # subclasses of nn.Linear are initialized as well.
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)
# Cross-entropy with reduction='none' returns one loss value per example; the
# training loop averages it for backpropagation and sums it for reporting.
loss = nn.CrossEntropyLoss(reduction='none')
# Run init_weights on every submodule of net.
net.apply(init_weights)
# 6. Define the gradient-descent optimizer
num_epochs = 10
# Plain SGD over all model parameters; the training loop steps once per minibatch.
trainer = torch.optim.SGD(params=net.parameters(), lr=0.1)
# 7. Define an Accumulator class for summing running totals
class Accumulator:
    """Keep ``n`` running float totals that are updated element-wise.

    ``add`` pairs its positional arguments with the current totals via ``zip``.
    Extra arguments are therefore ignored, and passing fewer arguments than
    there are totals shortens ``data``.
    """

    def __init__(self, n):
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        # Each value is converted with float(), so 0-d tensors are accepted too.
        self.data = [total + float(value) for total, value in zip(self.data, args)]

    def reset(self):
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, item):
        return self.data[item]
# 8. Compute accuracy
def accuracy(y_hat, y):
    """Count the predictions in ``y_hat`` that match the labels ``y``.

    If ``y_hat`` has more than one column, it is treated as per-class scores and
    reduced with ``argmax`` along dim 1; otherwise it is compared to ``y`` as is.
    Returns the number of matches as a Python float, not a fraction.
    """
    has_class_scores = y_hat.dim() > 1 and y_hat.size(1) > 1
    predicted = y_hat.argmax(dim=1) if has_class_scores else y_hat
    hits = predicted.to(y.dtype) == y
    return float(hits.to(y.dtype).sum())
def evaluate_accuracy(net, data_iter):
    """Return the fraction of examples in ``data_iter`` that ``net`` classifies correctly.

    Switches ``net`` to evaluation mode (and leaves it there) and disables
    gradient tracking while iterating. Raises ZeroDivisionError when
    ``data_iter`` yields no examples.
    """
    net.eval()
    totals = Accumulator(2)  # [correct predictions, examples seen]
    with torch.no_grad():
        for features, labels in data_iter:
            outputs = net(features)
            totals.add(accuracy(outputs, labels), labels.numel())
    n_correct, n_seen = totals[0], totals[1]
    return n_correct / n_seen
# 9. Define training
def train_epoch_ch3(net, train_iter, loss, updater):
    """Train ``net`` for one pass over ``train_iter``.

    ``loss`` should return one value per example (as ``reduction='none'`` does):
    their mean is backpropagated, and their sum is accumulated so the returned
    loss is an average over all examples. ``updater`` is driven through
    ``zero_grad()`` and ``step()``, like a ``torch.optim`` optimizer.

    Returns:
        ``(mean training loss, training accuracy)`` for the epoch.
    """
    net.train()
    stats = Accumulator(3)  # [summed loss, correct predictions, examples seen]
    for features, labels in train_iter:
        outputs = net(features)
        batch_loss = loss(outputs, labels)
        updater.zero_grad()
        batch_loss.mean().backward()
        updater.step()
        stats.add(float(batch_loss.sum()), accuracy(outputs, labels), labels.numel())
    loss_sum, n_correct, n_seen = stats[0], stats[1], stats[2]
    return loss_sum / n_seen, n_correct / n_seen
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):
    """Train ``net`` for ``num_epochs`` epochs, printing a progress line after each.

    Each line shows the ``(mean loss, accuracy)`` tuple from ``train_epoch_ch3``
    and the accuracy on ``test_iter`` from ``evaluate_accuracy``.
    """
    for epoch_number in range(1, num_epochs + 1):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        print(f'epoch {epoch_number}, train_metrics {train_metrics}, test_acc {test_acc}')
# Run the full training loop with the SGD optimizer defined above.
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater=trainer)
def get_fashion_mnist_labels(labels):
    """Translate numeric Fashion-MNIST class indices into their text names.

    ``labels`` may be any iterable of values accepted by ``int()``, such as a
    1-D tensor of class indices. Returns a list of strings in the same order.
    """
    class_names = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')
    return list(map(lambda index: class_names[int(index)], labels))
# 11. Prediction
def predict_ch3(net, test_iter):
    """Predict text labels for the first minibatch of ``test_iter``.

    Only the inputs of that batch are used; its true labels are ignored.

    Returns:
        A list of Fashion-MNIST label names, one per example in the batch.

    Raises:
        ValueError: if ``test_iter`` yields no batches.
    """
    try:
        X, _ = next(iter(test_iter))
    except StopIteration:
        # The original for/break form failed here with an opaque UnboundLocalError.
        raise ValueError("test_iter yielded no batches") from None
    # Inference, as in evaluate_accuracy: eval mode, and no autograd graph needed.
    net.eval()
    with torch.no_grad():
        outputs = net(X)
    return get_fashion_mnist_labels(outputs.argmax(axis=1))
# Show the predicted labels for the first test minibatch. As a bare expression
# statement (notebook style), the result would be silently discarded in a script.
print(predict_ch3(net, test_iter))