简介:我相信大部分人在最早接触AI的时候所学习到的就是图像分类,比如非常经典的MINIST手写数字识别,可以说类似于C语言的“Hello World”那样对于我们的启发意义。接下来图像分类这一系列的文章我将系统总结图像分类的整体流程和具体操作整理下来,方便大家理清有关图像分类的知识点。主要要感谢B站up主@同济子豪兄 的视频教学和代码实例,希望大家多去B站一键三连!
建议参考其他博主的方法利用Anaconda平台安装所需要的环境,或者推荐利用这个云GPU平台,这样可以减少安装环境可能遇到的各种麻烦,从而专心学习专业知识。平台链接:Featurize
在构建图像分类数据集时要注意删除无关图片,这个操作需要人工去完成。
图像分类数据集的质量直接关系模型训练的质量,而一个好的数据集应该具备多样性,代表性和一致性。
分割训练集和测试集是要注意数据的分布要保持一致,不然就会出现OOD(Out-Of-Distribution)问题。
melon17 瓜果数据集:
代码:
# 下载压缩包
# 如报错 Unable to establish SSL connection. 重新运行本代码块即可
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/melon17/melon17_full.zip
# 解压
!unzip melon17_full.zip >> /dev/null
b. fruit81 水果图像分类数据集:
代码:
# 下载压缩包
# 如报错 Unable to establish SSL connection. 重新运行本代码块即可
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit81/fruit81_full.zip
# 解压
!unzip fruit81_full.zip >> /dev/null
导入工具包
import os
import numpy as np
import pandas as pd
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
b. 指定数据集路径
# 指定数据集路径
dataset_path = 'fruit81_full'
os.chdir(dataset_path)
os.listdir()
df = pd.DataFrame()
for fruit in tqdm(os.listdir()): # 遍历每个类别
os.chdir(fruit)
for file in os.listdir(): # 遍历每张图像
try:
img = cv2.imread(file)
df = df.append({'类别':fruit, '文件名':file, '图像宽':img.shape[1], '图像高':img.shape[0]}, ignore_index=True)
except:
print(os.path.join(fruit, file), '读取错误')
os.chdir('../')
os.chdir('../')
c. 可视化图像尺寸分布
from scipy.stats import gaussian_kde
from matplotlib.colors import LogNorm
x = df['图像宽']
y = df['图像高']
xy = np.vstack([x,y])
z = gaussian_kde(xy)(xy)
# Sort the points by density, so that the densest points are plotted last
idx = z.argsort()
x, y, z = x[idx], y[idx], z[idx]
plt.figure(figsize=(10,10))
# plt.figure(figsize=(12,12))
plt.scatter(x, y, c=z, s=5, cmap='Spectral_r')
# plt.colorbar()
# plt.xticks([])
# plt.yticks([])
plt.tick_params(labelsize=15)
xy_max = max(max(df['图像宽']), max(df['图像高']))
plt.xlim(xmin=0, xmax=xy_max)
plt.ylim(ymin=0, ymax=xy_max)
plt.ylabel('height', fontsize=25)
plt.xlabel('width', fontsize=25)
plt.savefig('图像尺寸分布.pdf', dpi=120, bbox_inches='tight')
plt.show()
导入工具包
import os
import shutil
import random
import pandas as pd
b. 获得所有类别名称
# 指定数据集路径
dataset_path = 'fruit81_full'
dataset_name = dataset_path.split('_')[0]
print('数据集', dataset_name)
classes = os.listdir(dataset_path)
print(classes)
c. 创建训练集文件夹和测试集文件夹
# 创建 train 文件夹
os.mkdir(os.path.join(dataset_path, 'train'))
# 创建 test 文件夹
os.mkdir(os.path.join(dataset_path, 'val'))
# 在 train 和 test 文件夹中创建各类别子文件夹
for fruit in classes:
os.mkdir(os.path.join(dataset_path, 'train', fruit))
os.mkdir(os.path.join(dataset_path, 'val', fruit))
d. 划分训练集、测试集,移动文件
test_frac = 0.2 # 测试集比例
random.seed(123) # 随机数种子,便于复现
df = pd.DataFrame()
print('{:^18} {:^18} {:^18}'.format('类别', '训练集数据个数', '测试集数据个数'))
for fruit in classes: # 遍历每个类别
# 读取该类别的所有图像文件名
old_dir = os.path.join(dataset_path, fruit)
images_filename = os.listdir(old_dir)
random.shuffle(images_filename) # 随机打乱
# 划分训练集和测试集
testset_numer = int(len(images_filename) * test_frac) # 测试集图像个数
testset_images = images_filename[:testset_numer] # 获取拟移动至 test 目录的测试集图像文件名
trainset_images = images_filename[testset_numer:] # 获取拟移动至 train 目录的训练集图像文件名
# 移动图像至 test 目录
for image in testset_images:
old_img_path = os.path.join(dataset_path, fruit, image) # 获取原始文件路径
new_test_path = os.path.join(dataset_path, 'val', fruit, image) # 获取 test 目录的新文件路径
shutil.move(old_img_path, new_test_path) # 移动文件
# 移动图像至 train 目录
for image in trainset_images:
old_img_path = os.path.join(dataset_path, fruit, image) # 获取原始文件路径
new_train_path = os.path.join(dataset_path, 'train', fruit, image) # 获取 train 目录的新文件路径
shutil.move(old_img_path, new_train_path) # 移动文件
# 删除旧文件夹
assert len(os.listdir(old_dir)) == 0 # 确保旧文件夹中的所有图像都被移动走
shutil.rmtree(old_dir) # 删除文件夹
# 工整地输出每一类别的数据个数
print('{:^18} {:^18} {:^18}'.format(fruit, len(trainset_images), len(testset_images)))
# 保存到表格中
df = df.append({'class':fruit, 'trainset':len(trainset_images), 'testset':len(testset_images)}, ignore_index=True)
# 重命名数据集文件夹
shutil.move(dataset_path, dataset_name+'_split')
# 数据集各类别数量统计表格,导出为 csv 文件
df['total'] = df['trainset'] + df['testset']
df.to_csv('数据量统计.csv', index=False)
e. 查看文件目录结构
!sudo snap install tree
!tree fruit81_split -L 2
本文所有代码均来自B站up主@ 同济子豪兄 的github中的内容,github链接为: GitHub - TommyZihao/Train_Custom_Dataset: 标注自己的数据集,训练、评估、测试、部署自己的人工智能算法