使用百度飞浆提供的paddle框架实现蝴蝶分类,环境:paddle 2.0.2,opencv 4.5.4.58,pycharm编译器。
随机查看一个蝴蝶图片及其类别
data_path= '.\Butterfly20\*\*.jpg'
but_files =glob.glob(data_path) #获取Butterfly20中的所有图片地址
print('图片数据为',len(but_files))
#随机显示一个样品的图片
index=random.choice(but_files) # 随机获取一个图片
print(index) # 查看地址
name=index.split('\\')[-2] # 获取标签,得到的是训练集中随机蝴蝶的类别
img = Image.open(index) # 打开图片
img = cv2.imread(index) # 图片处理
print(img.shape) # 输出图片形状(441,600,3)
img = img[:,:,::-1] # 三通道,-1表示从右往左切片,opencv输入为BGR,故从右往左切片为RGB三通道
print(f'该样本标签为:{name}')
cv2.imshow("ran_img",img)
cv2.waitKey(0)
测试输出为:
写一个Reader类,其中定义三个函数,分别为初始化、处理图像、计算长度,使用Reader类加载训练集与数据集
### 查看数据类型
data_list = [] #用个列表保存每个样本的读取路径、标签
# 由于属种名称本身是字符串,而输入模型的是数字。需要构造一个字典,把某个数字代表该属种名称。键是属种名称,值是整数。
label_list=[]
with open("E:/Pycharm/workspace/OpenCV/butterfly/species.txt") as f:
for line in f:
a,b = line.strip("\n").split(" ") #a为1-20的序号,b为每个种类的name
label_list.append([b, int(a)-1]) #将20种txt种的类别加入label_list数组种
label_dic = dict(label_list) #dict创建一个字典,字典中有20种蝴蝶类型
butterfly_path = './Butterfly20/'
#若项目目录内已经有train_set与validation_set两个数据集,则删除,之后重新创建这两个数据集
if(os.path.exists('E:/Pycharm/workspace/OpenCV/butterfly/train_set.txt')): # 判断有误文件
os.remove('E:/Pycharm/workspace/OpenCV/butterfly/train_set.txt') # 删除文件
if(os.path.exists('E:/Pycharm/workspace/OpenCV/butterfly/validation_set.txt')):
os.remove('E:/Pycharm/workspace/OpenCV/butterfly/validation_set.txt')
for i in os.listdir(butterfly_path): #得到Butterfly20里的所有文件夹
if i not in '.DS_Store': #DB_Store里是20种蝴蝶类型的名字
for j in os.listdir(os.path.join(butterfly_path, i)): #路径拼接,拼接后为./Butterfly20/20种名字,j从这个路径里提取序号.jpg
data_list.append(f'{os.path.join(butterfly_path, i, j)}\t{label_dic[i]}\n') #前一个大括号是每个图片具体路径,后一个是其种类的序号
random.shuffle(data_list) # 乱序
print(data_list[0]) #打印随机选出的第一个图片以及其属于的种类号
data_len = len(data_list)
count = 0
for data in data_list:
if count <= data_len*0.8:
with open('E:/Pycharm/workspace/OpenCV/butterfly/train_set.txt', 'a')as f: # 80%写入训练集
f.write(data)
count += 1
else:
with open('E:/Pycharm/workspace/OpenCV/butterfly/validation_set.txt', 'a')as tf: # 20%写入验证集
tf.write(data)
count += 1
# 自定义数据读取器
class Reader(Dataset):
def __init__(self, mode='train_set'):
"""
初始化函数
"""
self.data = []
with open(f'{mode}_set.txt') as f: #train_set或validation_set
for line in f.readlines():
info = line.strip().split('\t') #strip函数去掉首部等于参数值的字符,无参数表示删掉换行符
if len(info) > 0:
self.data.append([info[0].strip(), info[1].strip()])
def __getitem__(self, index): #将图片转换为(224,224)像素大小
"""
读取图片,对图片进行归一化处理,返回图片和 标签
"""
image_file, label = self.data[index] # 获取数据
img = Image.open(image_file) # 读取图片
img = img.convert('RGB')
img = img.resize((224, 224), Image.ANTIALIAS) # 图片大小样式归一化
img = np.array(img).astype('float32') # 转换成数组类型浮点型32位
img = img.transpose((2, 0, 1)) # 读出来的图像是rgb,rgb,rbg..., 转置为 rrr...,ggg...,bbb...
img = img / 255.0 # 数据缩放到0-1的范围
return img, np.array(label, dtype='int64')
def __len__(self):
"""
获取样本总数
"""
return len(self.data)
#调用Reader类,其中三个函数都会走
# 训练的数据提供器
train_dataset = Reader(mode='train')
# 测试的数据提供器
eval_dataset = Reader(mode='validation')
# 查看训练和测试数据的大小
print('train大小:', train_dataset.__len__())
print('eval大小:', eval_dataset.__len__())
# 随机查看图片数据、大小及标签
for data, label in eval_dataset:
print(data)
print(np.array(data).shape) #(3,224,224)
print(label)
break #只循环一次即可
使用paddle框架构造神经网络,选用resnet152网络用于图像分类,最后分为20类
import paddle.nn.functional as F
#定义模型
class MyNet(paddle.nn.Layer):
def __init__(self):
super(MyNet,self).__init__()
self.layer=paddle.vision.models.resnet152(pretrained=True) #152层的resnet模型,预训练模型只需要设定模型参数pretained=True
self.dropout=paddle.nn.Dropout(p=0.5) #Dropout值设为0.5,
self.fc1 = paddle.nn.Linear(1000, 512) #fc为全连接层,与模型训练后为1000个输出,要最后分20类
self.fc2 = paddle.nn.Linear(512, 20) #两个全连接层实现1000-20
#网络的前向计算过程
def forward(self,x):
x=self.layer(x) #resnet152模型
x=self.dropout(x) #值为0.5的Dropout
x=self.fc1(x) #第一个全连接层
x=F.relu(x) #使用relu函数激活
x=self.fc2(x) #第二个全连接层得到20个分类特征
return x
resnet网络结构如下:
用构建好的resnet152网络进行训练
model = paddle.Model(MyNet())
model.summary((1, 3, 224, 224)) #输出各层参数
input_define = paddle.static.InputSpec(shape=[-1,3,224,224], dtype="float32", name="img")
label_define = paddle.static.InputSpec(shape=[-1,1], dtype="int64", name="label")
#实例化网络对象并定义优化器等训练逻辑
model = MyNet()
model = paddle.Model(model,inputs=input_define,labels=label_define) #用Paddle.Model()对模型进行封装
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
#上述优化器中的学习率(learning_rate)参数很重要。要是训练过程中得到的准确率呈震荡状态,忽大忽小,可以试试进一步把学习率调低。
model.prepare(optimizer=optimizer, #指定优化器
loss=paddle.nn.CrossEntropyLoss(), #指定损失函数
metrics=paddle.metric.Accuracy()) #指定评估方法
callback = paddle.callbacks.VisualDL(log_dir='./visualdl_log')
model.fit(train_data=train_dataset, #训练数据集
eval_data=eval_dataset, #测试数据集
batch_size=64, #一个批次的样本数量
epochs=100, #迭代轮次
save_dir="./visualdl_log", #把模型参数、优化器参数保存至自定义的文件夹
save_freq=20, #设定每隔多少个epoch保存模型参数及优化器参数
log_freq=100, #打印日志的频率
verbose=1, # 日志展示模式
shuffle=True, # 是否打乱数据集顺序
callbacks=callback # 回调函数使用
)
result = model.evaluate(eval_dataset, verbose=1)
print(result)
model.save('E:/Pycharm/workspace/OpenCV/butterfly/butterfly_model') # 保存模型
随机使用一张图片,通过训练好的网络进行预测蝴蝶的种类,该蝴蝶属于第15类
def load_image(file): #加载测试图片并处理图片
# 打开图片
im = Image.open(file)
# 将图片调整为跟训练数据一样的大小
im = im.convert('RGB')
im = im.resize((224, 224), Image.ANTIALIAS)
# 建立图片矩阵 类型为float32
im = np.array(im).astype(np.float32)
# 矩阵转置
im = im.transpose((2, 0, 1))
# 将像素值从[0-255]转换为[0-1]
im = im / 255.0
# print(im)
im = np.expand_dims(im, axis=0)
# 保持和之前输入image维度一致
print('im_shape的维度:', im.shape)
return im
from PIL import Image
# site = 255 # 读取图片位置
model_state_dict = paddle.load('E:/Pycharm/workspace/OpenCV/butterfly/butterfly_model.pdparams') # 读取模型
model = MyNet() # 实例化模型
model.set_state_dict(model_state_dict) #浅拷贝,读取模型
model.eval() #不进行BN与dropout,使用所有全职计算
img = load_image(index)
print(paddle.to_tensor(img).shape)
# print(paddle.reshape(paddle.to_tensor(img), (1, 3, 224, 224)))
ceshi = model(paddle.reshape(paddle.to_tensor(img), (1, 3, 224, 224))) # 测试
print('预测的结果为:', list(label_dic.keys())[np.argmax(ceshi.numpy())]) # 获取值
with open("./work/result.txt", "w") as f:
for r in result:
f.write("{}\n".format(r))
Image.open(index) # 显示图片