使用百度的飞桨框架,总结一些普遍的规律和此框架的简单使用方法
①数据处理:读取数据和预处理操作
②模型设计:网络结构(假设)
③训练配置:优化器(寻解算法)和计算资源配置
④训练过程:循环调用训练过程,前向计算+损失函数(优化目标)+后向传播
⑤保存模型:将训练好的模型保存
①加载飞桨和相关类库
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
import numpy as np
import os
from PIL import Image
②使用飞桨框架提供的Mnist数据集处理函数
③paddle.dataset.mnist.train()
④常见的学术数据集均有现成处理函数(查API可见)
查阅API的方法
a.搜索:在飞桨官网https://aistudio.baidu.com进行查阅
b.分类浏览:在飞桨的API功能分类中寻找
步骤
①声明实例
②加载参数
③灌入数据
④打印结果
注意点
①图片数据归一化
②正确设置路径
③模型“校验”状态
####分析数据集结构,并拆分训练集和测试集
处理数据的五大操作
完整处理流程和异步读取数据
异步读取VS同步读取
同步读取:IO和网络计算串行,速度慢
异步读取:IO和计算通过一个"异步队列"交互,IO把数据不停放入队列,网络计算不停的从队列取 数据,二者同时进行。
PyReader
飞桨提供的异步数据读取器,只需要修改两行代码
创建一个DataLoader对象用于加载Python生成器产生的数据,数据会由Python线程预先读取,并异步送入设定了容量上限的队列中。
#定义DataLoader对象用于加载Python生成器产生的数据
data_loader = fuild.io.DataLoader.from_generator(capacity=5,return_list=True)
#设置数据生成器
data_loader.set_batch_generator(train_loader,places=place)
题目要求:查询API文档,写一个cifar-10数据集的数据读取器,并执行乱序,分批次读取,打印第一个batch数据的shape、类型信息。
#加载飞桨和相关数据处理的库
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
import numpy as np
import os
import random
#设置数据读取器,读取cifar-10数据训练集
trainset = paddle.dataset.cifar.train10()
#包装数据读取器,每次读取的数据数量设置为batch_size=5
train_reader = paddle.batch(trainset, batch_size=5)
#以迭代的形式读取数据
for batch_id,data in enumerate(train_reader()):
#获取图像数据,并转为float32类型
img_data = np.array([x[0] for x in data]).astype('float32')
#获取图像标签数据,并转为float32类型
label_data = np.array([x[1] for x in data]).astype('float32')
#打印数据形状
print("图像数据形状和对应数据为:", img_data.shape, img_data[0])
print("图像标签形状和对应数据为:", label_data.shape, label_data[0])
break
print("\n打印第一个batch的第一个图像,对应标签数字为{}".format(label_data[0]))
#打乱顺序
imgs_length = len(img_data)
#定义数据集每个数据的序号,根据序号读取数据
index_list = list(range(imgs_length))
random.shuffle(index_list)
imgs_list = []
for i in index_list:
img = np.array(img_data[i]+1)*127.5
img = np.reshape(img,[3,32,32]).astype(np.uint8)
img = np.transpose(img,(1,2,0))
imgs_list.append(img)
#显示第一个batch的第一个图像
import matplotlib.pyplot as plt
print(img.shape)
print(img.dtype.name)
plt.figure("This is the first picture of cifar10")
plt.imshow(img)
plt.axis('on')
plt.title('image')
plt.show()
(result.png )]
far10")
plt.imshow(img)
plt.axis('on')
plt.title('image')
plt.show()