动手学深度学习--课堂笔记图片分类数据集

softmax是一个非线性函数,但softmax回归是一个线性模型(linear model):是不是线性的是由决策面是否是线性函数决定的,不是由拟合的数据分布决定的。softmax只是对数据分布做了非线性的处理,但它的决策函数形式还是Xw+b的线性形式。

Fashion-MNIST数据集:包含70000张灰度图像,其中包含60,000个示例的训练集和10,000个示例的测试集,每个示例都是一个28x28灰度图像。主要分为:T恤(T-shirt)、裤子(Trouser)、套头衫(Pullover)、连衣裙(Dress)、外套(Coat)、凉鞋(Sandal)、衬衫(Shirt)、运动鞋(Sneaker)、包(Bag)、靴子(Ankle boot)

1.导入包

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

use_svg_display()函数指定matplotlib软件包输出svg图表以获得更清晰的图像

2.读取数据集

通过框架中的内置函数将Fashion_MNIST数据集下载并读取到内存中

trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans,download=True)

len(mnist_train), len(mnist_test)#训练集与测试集中样本的数量
mnist_train[0][0].shape#训练集中第一个图片

transforms.ToTensor():将图片转化为Tensor.

mnist_train是训练集,mnist_test是测试集,两者是torch.utils.data.Dataset的子类

root="../data", train=True, transform=trans, download=True:将Fashion-MNIST的训练集(train=True)从网上下载(download=True)到(root="../data")上级目录的data中,并确保得到是tensor而不是图片(transform=trans)

输出结果:

Out[3]:表示训练集有60000张图片,测试集有10000张图片。

Out[4]:因为通过transforms.ToTensor()的转换,变成了尺寸为(CxHxW),数据类型为torch.float32,位于[0.0, 1.0] 的Tensor,输出结果[1,28,28]的‘1’表示的是第一维的通道数为1,所以是灰度图像,后面两维中的‘28‘表示图像的高和宽。

3.定义两个可视化的数据集函数

def get_fashion_mnist_labels(labels):
    text_labels=['t-shirt','trouser','pullover','dress','cost','sandal','shirt','sneaker','bag','ankle boot']
    return [text_labels[int(i)] for i in labels]

def show_images(imgs, num_rows, num_cols,titles=None,scale=1.5):
    figsize=(num_cols*scale,num_rows*scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i,(ax,img) in enumerate(zip(axes,imgs)):
        if torch.is_tensor(img):#是否为张量
            ax.imshow(img.numpy())#图片张量
        else:
            ax.imshow(img)#PIL图片
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
      return axes
X,y=next(iter(data.DataLoader(mnist_train,batch_size=18)))#data.DataLoader(mnist_train,batch_size=18):在mist_train数据集中加载数据,每批次要装载18个样品,最后将这些数据封装为Tensor.
show_images(X.reshape(18,28,28),2,9,titles=get_fashion_mnist_labels(y))

1)plt.subplots():plt.subplots()是matplotlib中绘制子图的一种方法。在matplotlib中整个图像为一个Figure对象,在Figure对象中可以包含一个或多个Axes对象,每个Axes(ax)对象都是一个拥有自己坐标系统的绘图区域。plt.subplots()直接在函数内部设置子图纸信息,返回两个变量,一个是Figure实例fig,另一个是AxesSubplot实例ax。fig代表整个图像,ax代表坐标轴和子图。d2l.plt.subplots(num_rows, num_cols, figsize=figsize)中,第一个参数代表子图的行数,第二个参数代表该行图像的列数,第三个参数代表每行的第几个图像。

2)axes.flatten():flatten()是numpy.ndarray.flatten的一个函数,即返回一个一维数组。axes.flatten()表示把axes数组降到一维,默认为按行的方向降。

3)enumerate():获取可迭代对象的每个元素的索引值及该元素值,进行拆包,多用于for循环

4)zip():用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表

5)axe.get_xaxis().set_visible():是设置坐标轴显示与否,包括了刻度与标签,如果设置为False则表示不显示,True为显示。

6)next()与iter():两者要一起使用。iter()函数将Iterable转换为Iterator;对获取到的迭代器(Iterator)不断使用next()函数来获取下一条数据

#flatten()实例
from numpy import *
a = arange(1, 7).reshape(3, 2)
print(a)
print(a.flatten())

输出结果:

[[1 2]
 [3 4]
 [5 6]]
[1 2 3 4 5 6]
#zip()实例
b = [4, 5, 6]
c = [7, 8]
print(list(zip(b, c)))#元素个数与最短的列表一致

 输出结果:

[(4, 7), (5, 8)]
#enumerate()实例
for i, value in enumerate(['a', 'b', 'c', 'd']):
    print(i, value)

 输出结果:

0 a
1 b
2 c
3 d
#axe.get_xaxis().set_visible()实例
import matplotlib.pyplot as plt

fig=plt.figure(figsize=(5,5),dpi=100)   #创建画布
axe=plt.subplot(1,1,1)    #创建子图
axe.set_title('test')   #设置子图标题
fig.savefig('test.png',dpi=100) #保存图片
axe.get_xaxis().set_visible(True)
plt.show()  #展示图片

输出结果:

动手学深度学习--课堂笔记图片分类数据集_第1张图片

#next()&&iter()实例
# 首先获得Iteration对象
it = iter([1, 2, 3, 4, 5])
# 循环
while True:
    try:
        # 获得下一个值
        x = next(it)
        print(x)
    except StopIteration:
        # 遇到StopIteration就退出循环
        break

输出结果:

1
2
3
4
5

 4.读取小批量数据

batch_size = 256

def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

#读取训练数据所需的时间
timer = d2l.Timer()
for X, y in train_iter:
    continue
f'{timer.stop():.2f} sec'

5.整合所有的函数

定义load_data_fashion_mnist函数,用于获取和读取Fashion-MNIST数据集。

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))#修改图片大小
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

 transforms.Compose():串联多个图片变换的操作

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break

你可能感兴趣的:(深度学习,人工智能,python)