基于paddlepaddle构建resnet神经网络的蝴蝶分类

一、序言

使用百度飞浆提供的paddle框架实现蝴蝶分类,环境:paddle 2.0.2,opencv 4.5.4.58,pycharm编译器。

目录结构:
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第1张图片

  • Butterfly20里有20个文件夹,分别代表20种蝴蝶种类,每个文件夹内有多个同种类的蝴蝶照片
  • Butterfly20_test里有200张蝴蝶照片用于测试训练好的网络
  • visualdl_log里存放训练好的网络,使用log文件格式
  • species.txt里存放20种类别的名称和序号
  • train_set和validation_set在运行时随机分配

二、准备数据

随机查看一个蝴蝶图片及其类别

data_path= '.\Butterfly20\*\*.jpg'
but_files =glob.glob(data_path) #获取Butterfly20中的所有图片地址

print('图片数据为',len(but_files))

#随机显示一个样品的图片
index=random.choice(but_files)  # 随机获取一个图片
print(index)  # 查看地址

name=index.split('\\')[-2]  # 获取标签,得到的是训练集中随机蝴蝶的类别
img = Image.open(index)  # 打开图片
img = cv2.imread(index)  # 图片处理
print(img.shape)  # 输出图片形状(441,600,3)
img = img[:,:,::-1] # 三通道,-1表示从右往左切片,opencv输入为BGR,故从右往左切片为RGB三通道
print(f'该样本标签为:{name}')
cv2.imshow("ran_img",img)
cv2.waitKey(0)

测试输出为:

基于paddlepaddle构建resnet神经网络的蝴蝶分类_第2张图片

写一个Reader类,其中定义三个函数,分别为初始化、处理图像、计算长度,使用Reader类加载训练集与数据集

### 查看数据类型
data_list = [] #用个列表保存每个样本的读取路径、标签
# 由于属种名称本身是字符串,而输入模型的是数字。需要构造一个字典,把某个数字代表该属种名称。键是属种名称,值是整数。
label_list=[]
with open("E:/Pycharm/workspace/OpenCV/butterfly/species.txt") as f:
    for line in f:
        a,b = line.strip("\n").split(" ") #a为1-20的序号,b为每个种类的name
        label_list.append([b, int(a)-1]) #将20种txt种的类别加入label_list数组种
label_dic = dict(label_list) #dict创建一个字典,字典中有20种蝴蝶类型

butterfly_path = './Butterfly20/'
#若项目目录内已经有train_set与validation_set两个数据集,则删除,之后重新创建这两个数据集
if(os.path.exists('E:/Pycharm/workspace/OpenCV/butterfly/train_set.txt')):  # 判断有误文件
    os.remove('E:/Pycharm/workspace/OpenCV/butterfly/train_set.txt')  # 删除文件
if(os.path.exists('E:/Pycharm/workspace/OpenCV/butterfly/validation_set.txt')):
    os.remove('E:/Pycharm/workspace/OpenCV/butterfly/validation_set.txt')

for i in os.listdir(butterfly_path): #得到Butterfly20里的所有文件夹
    if i not in '.DS_Store': #DB_Store里是20种蝴蝶类型的名字
        for j in os.listdir(os.path.join(butterfly_path, i)): #路径拼接,拼接后为./Butterfly20/20种名字,j从这个路径里提取序号.jpg
            data_list.append(f'{os.path.join(butterfly_path, i, j)}\t{label_dic[i]}\n') #前一个大括号是每个图片具体路径,后一个是其种类的序号

random.shuffle(data_list)  # 乱序
print(data_list[0]) #打印随机选出的第一个图片以及其属于的种类号
data_len = len(data_list)
count = 0

for data in data_list:
    if count <= data_len*0.8:
        with open('E:/Pycharm/workspace/OpenCV/butterfly/train_set.txt', 'a')as f: # 80%写入训练集
            f.write(data)
            count += 1
    else:
        with open('E:/Pycharm/workspace/OpenCV/butterfly/validation_set.txt', 'a')as tf:  # 20%写入验证集
            tf.write(data)
            count += 1

# 自定义数据读取器
class Reader(Dataset):
    def __init__(self, mode='train_set'):
        """
        初始化函数
        """
        self.data = []
        with open(f'{mode}_set.txt') as f: #train_set或validation_set
            for line in f.readlines():
                info = line.strip().split('\t') #strip函数去掉首部等于参数值的字符,无参数表示删掉换行符
                if len(info) > 0:
                    self.data.append([info[0].strip(), info[1].strip()])

    def __getitem__(self, index): #将图片转换为(224,224)像素大小
        """
        读取图片,对图片进行归一化处理,返回图片和 标签
        """
        image_file, label = self.data[index]  # 获取数据
        img = Image.open(image_file)  # 读取图片
        img = img.convert('RGB')
        img = img.resize((224, 224), Image.ANTIALIAS)  # 图片大小样式归一化
        img = np.array(img).astype('float32')  # 转换成数组类型浮点型32位
        img = img.transpose((2, 0, 1))  # 读出来的图像是rgb,rgb,rbg..., 转置为 rrr...,ggg...,bbb...
        img = img / 255.0  # 数据缩放到0-1的范围
        return img, np.array(label, dtype='int64')

    def __len__(self):
        """
        获取样本总数
        """
        return len(self.data)

#调用Reader类,其中三个函数都会走
# 训练的数据提供器
train_dataset = Reader(mode='train')
# 测试的数据提供器
eval_dataset = Reader(mode='validation')

# 查看训练和测试数据的大小
print('train大小:', train_dataset.__len__())
print('eval大小:', eval_dataset.__len__())

# 随机查看图片数据、大小及标签
for data, label in eval_dataset:
    print(data)
    print(np.array(data).shape) #(3,224,224)
    print(label)
    break #只循环一次即可

三、构建网络

使用paddle框架构造神经网络,选用resnet152网络用于图像分类,最后分为20类

import paddle.nn.functional as F
#定义模型
class MyNet(paddle.nn.Layer):
    def __init__(self):
        super(MyNet,self).__init__()
        self.layer=paddle.vision.models.resnet152(pretrained=True) #152层的resnet模型,预训练模型只需要设定模型参数pretained=True
        self.dropout=paddle.nn.Dropout(p=0.5) #Dropout值设为0.5,
        self.fc1 = paddle.nn.Linear(1000, 512) #fc为全连接层,与模型训练后为1000个输出,要最后分20类
        self.fc2 = paddle.nn.Linear(512, 20) #两个全连接层实现1000-20
    #网络的前向计算过程
    def forward(self,x):
        x=self.layer(x) #resnet152模型
        x=self.dropout(x) #值为0.5的Dropout
        x=self.fc1(x) #第一个全连接层
        x=F.relu(x) #使用relu函数激活
        x=self.fc2(x) #第二个全连接层得到20个分类特征
        return x

resnet网络结构如下:

基于paddlepaddle构建resnet神经网络的蝴蝶分类_第3张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第4张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第5张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第6张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第7张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第8张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第9张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第10张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第11张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第12张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第13张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第14张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第15张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第16张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第17张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第18张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第19张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第20张图片
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第21张图片

四、训练网络

用构建好的resnet152网络进行训练

model = paddle.Model(MyNet())
model.summary((1, 3, 224, 224)) #输出各层参数

input_define = paddle.static.InputSpec(shape=[-1,3,224,224], dtype="float32", name="img")
label_define = paddle.static.InputSpec(shape=[-1,1], dtype="int64", name="label")

#实例化网络对象并定义优化器等训练逻辑
model = MyNet()
model = paddle.Model(model,inputs=input_define,labels=label_define) #用Paddle.Model()对模型进行封装
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
#上述优化器中的学习率(learning_rate)参数很重要。要是训练过程中得到的准确率呈震荡状态,忽大忽小,可以试试进一步把学习率调低。

model.prepare(optimizer=optimizer, #指定优化器
              loss=paddle.nn.CrossEntropyLoss(), #指定损失函数
              metrics=paddle.metric.Accuracy()) #指定评估方法

callback = paddle.callbacks.VisualDL(log_dir='./visualdl_log')

model.fit(train_data=train_dataset,     #训练数据集
          eval_data=eval_dataset,         #测试数据集
          batch_size=64,                  #一个批次的样本数量
          epochs=100,                      #迭代轮次
          save_dir="./visualdl_log", #把模型参数、优化器参数保存至自定义的文件夹
          save_freq=20,                    #设定每隔多少个epoch保存模型参数及优化器参数
          log_freq=100,                     #打印日志的频率
          verbose=1,                        # 日志展示模式
          shuffle=True,                     # 是否打乱数据集顺序
          callbacks=callback                # 回调函数使用
        )

result = model.evaluate(eval_dataset, verbose=1)
print(result)

model.save('E:/Pycharm/workspace/OpenCV/butterfly/butterfly_model')  # 保存模型

五、预测图片

随机使用一张图片,通过训练好的网络进行预测蝴蝶的种类,该蝴蝶属于第15类

def load_image(file): #加载测试图片并处理图片
    # 打开图片
    im = Image.open(file)
    # 将图片调整为跟训练数据一样的大小
    im = im.convert('RGB')
    im = im.resize((224, 224), Image.ANTIALIAS)
    # 建立图片矩阵 类型为float32
    im = np.array(im).astype(np.float32)
    # 矩阵转置
    im = im.transpose((2, 0, 1))
    # 将像素值从[0-255]转换为[0-1]
    im = im / 255.0
    # print(im)
    im = np.expand_dims(im, axis=0)
    # 保持和之前输入image维度一致
    print('im_shape的维度:', im.shape)
    return im

from PIL import Image
# site = 255  # 读取图片位置
model_state_dict = paddle.load('E:/Pycharm/workspace/OpenCV/butterfly/butterfly_model.pdparams')  # 读取模型
model = MyNet()  # 实例化模型
model.set_state_dict(model_state_dict) #浅拷贝,读取模型
model.eval() #不进行BN与dropout,使用所有全职计算

img = load_image(index)

print(paddle.to_tensor(img).shape)
# print(paddle.reshape(paddle.to_tensor(img), (1, 3, 224, 224)))
ceshi = model(paddle.reshape(paddle.to_tensor(img), (1, 3, 224, 224)))  # 测试
print('预测的结果为:', list(label_dic.keys())[np.argmax(ceshi.numpy())])  # 获取值
with open("./work/result.txt", "w") as f:
    for r in result:
        f.write("{}\n".format(r))
Image.open(index)  # 显示图片

预测结果:
基于paddlepaddle构建resnet神经网络的蝴蝶分类_第22张图片

你可能感兴趣的:(机器学习,paddlepaddle,神经网络,分类)