【PyTorch】基于PyTorch的softmax分类Fanshion_MNIST数据集

Fashion-MNIST数据集介绍
Fashion-MNIST是一个替代MNIST手写数字集的图像数据集。 它是由Zalando(一家德国的时尚科技公司)旗下的研究部门提供。其涵盖了来自10种类别的共7万个不同商品的正面图片。Fashion-MNIST的大小、格式和训练集/测试集划分与原始的MNIST完全一致。60000/10000的训练测试数据划分,28x28的灰度图片。你可以直接用它来测试你的机器学习和深度学习算法性能,且不需要改动任何的代码
softmax介绍
用来解决多分类问题,将输出的没一行中的概率最大值所对应的索引表示为这个元素的对应类别。可参考:https://www.cnblogs.com/wangyarui/p/8670769.html

代码及介绍

1.导入需要的包

import torch 
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
from IPython import display
import numpy as np

2. 下载torch中自带的数据集Fashion-MNIST并放入适当的位置

mnist_train = torchvision.datasets.FashionMNIST(root='~/Desktop/OpenCV_demo/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Desktop/OpenCV_demo/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

3. 定义函数显示数据集中的10个类别

def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt','trouser','pullover','dress','coat',
                  'sandal','shirt','sneaker','bag','ankle boot']
    return [text_labels[int(i)]for i in labels]

4. 定义显示图像的函数,并显示前十个样本及标签

def use_svg_display():
    """Use svg format to display plot in jupyter"""
    display.set_matplotlib_formats('svg')
#在一行内画出多张图像和对应标签的函数 
def show_fashion_mnist(images, labels):
    use_svg_display()
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()
#前10个样本图像和文本标签
X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

输出10个样本及标签:
【PyTorch】基于PyTorch的softmax分类Fanshion_MNIST数据集_第1张图片
5. 设置batch_size并确定训练集和测试集

#batch_size的设置
batch_size = 256
if sys.platform.startswith('win'):
    num_workers = 0
else:
    num_workers = 4
train_iter = torch.utils.data.DataLoader(
    mnist_train,batch_size = batch_size,shuffle = True,num_workers =  num_workers)
test_iter = torch.utils.data.DataLoader(
    mnist_test,batch_size = batch_size,shuffle = True,num_workers =  num_workers)

6. 初始化参数

#输入图片大小为28*28= 784 输出类别为10 因此,设置权重为784*10 偏置为1*10
num_inputs = 784
num_outputs = 10
#初始换权重和参数 
#权重 初始化成 均值为零,方差为0.01的大小为748*10的正太随机数
#偏置初始化为0
w = torch.tensor(np.random.normal(0,0.01,(num_inputs,num_outputs)),dtype = torch.float)
b = torch.zeros(num_outputs,dtype = torch.float)

#需要设置参数梯度
w.requires_grad_(requires_grad = True)
b.requires_grad_(requires_grad = True)

7. 定义softmax运算

#softmax 运算 输出为概率值  先指数运算 然后对行求和 然后求每一个元素的概率
def softmax(X):
    X_exp = X.exp()
    partition = X_exp.sum(dim = 1, keepdim = True)
    return X_exp/partition

8. 定义线性网络

#利用view函数啊将X设置成 未知*num_inputs大小 其中-1表示根据 num_inputs自行判断
#定义的网络模型为线性的网络
def net(X):
    return softmax(torch.mm(X.view((-1,num_inputs)),w)+b)

9. 定义损失函数

#定损失函数  softmax使用交叉熵损失  H= -log y_hat
#gather函数  y.view(-1,1)将y设置为 未知*1元素,其中未知根据具体情况自动生成
#1代表dim=1 表示行 将行元素对应y.view(-1,1)取出
#可以用一句代替 loss = nn.CrossEntropyLoss()
def cross_entropy(y_hat,y):
    return -torch.log(y_hat.gather(1,y.view(-1,1)))

10. 计算准确率

#计算准确率
def accuracy(y_hat,y):
    #argmax(dim=1)所有行上最大值所对应的索引值(类别) 如果和y相等 证明预测正确 
    #mean() 求均值 得到准确率
    return (y_hat.argmax(dim=1)==y).float().mean().item()

11. 在模型上评价数据集的精度

#在模型上评价数据集的准确率  .item()将Tensor转换为number
def evaluate_accuracy(data_iter,net):
    acc_sum,n = 0.0,0
    for X,y in data_iter:
        #计算判断准确的元素
        acc_sum += (net(X).argmax(dim=1)==y).float().sum().item()
        #通过shape 获得y的列元素
        n += y.shape[0]
    return acc_sum/n

12. 定义优化算法

#优化算法  小批量随机梯度下降算法
#可以用一句代替 optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
def sgd(params,lr,batch_size):
    for param in params:
        param.data -= lr*param.grad/batch_size

13. 训练模型

#训练模型
num_epochs,lr= 5,0.1

def train_softmax(net,train_iter,test_iter,loss,num_epochs,batch_size,
                  params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        #损失值、正确数量、总数 初始化
        train_l_sum,train_acc_sum,n = 0.0,0.0,0
        
        for X,y in train_iter:
            y_hat = net(X)
            l = loss(y_hat,y).sum()
            
             # 梯度清零 损失函数和优化函数梯度清零
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
                    
            l.backward()
            if optimizer is None:
                sgd(params, lr, batch_size)
            else:
                optimizer.step() 
            
            train_l_sum += l.item()
            train_acc_sum +=(y_hat.argmax(dim=1)==y).sum().item()
            n += y.shape[0]
            
        test_acc = evaluate_accuracy(test_iter,net)
        print('epoch %d, loss %.4f, train acc %.3f,test acc %.3f'
              %(epoch+1,train_l_sum/n,train_acc_sum/n,test_acc))
        
train_softmax(net,train_iter,test_iter,cross_entropy,num_epochs,batch_size,[w,b],lr)

输出结果:
【PyTorch】基于PyTorch的softmax分类Fanshion_MNIST数据集_第2张图片
14. 测试模型


X,y = iter(test_iter).next()

true_labels = get_fashion_mnist_labels(y.numpy())
pred_labels = get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]

show_fashion_mnist(X[0:9], titles[0:9])

输出结果:
【PyTorch】基于PyTorch的softmax分类Fanshion_MNIST数据集_第3张图片
完整代码:

import torch 
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
from IPython import display
import numpy as np

mnist_train = torchvision.datasets.FashionMNIST(root='~/Desktop/OpenCV_demo/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Desktop/OpenCV_demo/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt','trouser','pullover','dress','coat',
                  'sandal','shirt','sneaker','bag','ankle boot']
    return [text_labels[int(i)]for i in labels]
def use_svg_display():
    """Use svg format to display plot in jupyter"""
    display.set_matplotlib_formats('svg')
#在一行内画出多张图像和对应标签的函数 
def show_fashion_mnist(images, labels):
    use_svg_display()
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()
#前10个样本图像和文本标签
X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

#batch_size的设置
batch_size = 256
if sys.platform.startswith('win'):
    num_workers = 0
else:
    num_workers = 4
train_iter = torch.utils.data.DataLoader(
    mnist_train,batch_size = batch_size,shuffle = True,num_workers =  num_workers)
test_iter = torch.utils.data.DataLoader(
    mnist_test,batch_size = batch_size,shuffle = True,num_workers =  num_workers)

#输入图片大小为28*28= 784 输出类别为10 因此,设置权重为784*10 偏置为1*10
num_inputs = 784
num_outputs = 10
#初始换权重和参数 
#权重 初始化成 均值为零,方差为0.01的大小为748*10的正太随机数
#偏置初始化为0
w = torch.tensor(np.random.normal(0,0.01,(num_inputs,num_outputs)),dtype = torch.float)
b = torch.zeros(num_outputs,dtype = torch.float)

#需要设置参数梯度
w.requires_grad_(requires_grad = True)
b.requires_grad_(requires_grad = True)

#softmax 运算 输出为概率值  先指数运算 然后对行求和 然后求每一个元素的概率
def softmax(X):
    X_exp = X.exp()
    partition = X_exp.sum(dim = 1, keepdim = True)
    return X_exp/partition
#利用view函数啊将X设置成 未知*num_inputs大小 其中-1表示根据 num_inputs自行判断
#定义的网络模型为线性的网络
def net(X):
    return softmax(torch.mm(X.view((-1,num_inputs)),w)+b)
#定损失函数  softmax使用交叉熵损失  H= -log y_hat
#gather函数  y.view(-1,1)将y设置为 未知*1元素,其中未知根据具体情况自动生成
#1代表dim=1 表示行 将行元素对应y.view(-1,1)取出
def cross_entropy(y_hat,y):
    return -torch.log(y_hat.gather(1,y.view(-1,1)))
def accuracy(y_hat,y):
    #argmax(dim=1)所有行上最大值所对应的索引值(类别) 如果和y相等 证明预测正确 
    #mean() 求均值 得到准确率
    return (y_hat.argmax(dim=1)==y).float().mean().item()
def evaluate_accuracy(data_iter,net):
    acc_sum,n = 0.0,0
    for X,y in data_iter:
        #计算判断准确的元素
        acc_sum += (net(X).argmax(dim=1)==y).float().sum().item()
        #通过shape 获得y的列元素
        n += y.shape[0]
    return acc_sum/n
def sgd(params,lr,batch_size):
    for param in params:
        param.data -= lr*param.grad/batch_size
#训练
num_epochs,lr= 5,0.1
def train_softmax(net,train_iter,test_iter,loss,num_epochs,batch_size,
                  params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        #损失值、正确数量、总数 初始化
        train_l_sum,train_acc_sum,n = 0.0,0.0,0
        
        for X,y in train_iter:
            y_hat = net(X)
            l = loss(y_hat,y).sum()
            
             # 梯度清零 损失函数和优化函数梯度清零
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
                    
            l.backward()
            if optimizer is None:
                sgd(params, lr, batch_size)
            else:
                optimizer.step() 
            
            train_l_sum += l.item()
            train_acc_sum +=(y_hat.argmax(dim=1)==y).sum().item()
            n += y.shape[0]
            
        test_acc = evaluate_accuracy(test_iter,net)
        print('epoch %d, loss %.4f, train acc %.3f,test acc %.3f'
              %(epoch+1,train_l_sum/n,train_acc_sum/n,test_acc))
        
train_softmax(net,train_iter,test_iter,cross_entropy,num_epochs,batch_size,[w,b],lr)
#测试
X,y = iter(test_iter).next()
true_labels = get_fashion_mnist_labels(y.numpy())
pred_labels = get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]

show_fashion_mnist(X[0:9], titles[0:9])

你可能感兴趣的:(【PyTorch】基于PyTorch的softmax分类Fanshion_MNIST数据集)