# 一、分割任务 (Part 1: segmentation task -- random train/test split)
"""
将数据集随机分成训练集、测试集
传入参数:
ratio = 0.7 # 训练样本比例
path = "/home/pi/20190701_0705" # 数据路径
new_path = "/home/pi/20190701_0705_new2" # 保存路径
使用方法:
temp = Generate_Train_and_Test(path, new_path, ratio)
temp.splict_data()
"""
import os
import random
import shutil

import cv2
def makeDir(path):
    """Create directory *path*, including missing parents, if it is not there yet.

    Errors are reported on stdout and turned into a return code instead of
    being raised, so batch scripts keep going.

    Returns:
        0 if the directory was created, 1 if it already existed as a
        directory, -2 if it could not be created (for example because a
        regular file occupies *path*).
    """
    if os.path.isdir(path):
        return 1
    try:
        os.makedirs(path)
    except FileExistsError:
        # Either another process created the directory between the check and
        # makedirs (fine), or a non-directory occupies the path (not fine).
        if os.path.isdir(path):
            return 1
        print('%s exists and is not a directory' % path)
        return -2
    except OSError as e:
        print(str(e))
        return -2
    return 0
class Generate_Train_and_Test:
    """Randomly split a per-class image folder into train and test sets.

    ``path`` holds one sub-folder per class. The split is written to
    ``new_path/train/<class>/`` and ``new_path/test/<class>/``.

    Args:
        path: source dataset root (one sub-folder per class).
        new_path: destination root; created if missing.
        ratio: fraction of each class's files that goes to train. The first
            ``int(ratio * n)`` shuffled files are train, the rest are test.

    Raises:
        ValueError: if ``ratio`` is outside ``[0, 1]``.
    """

    def __init__(self, path, new_path, ratio):
        # A ratio above 1 used to surface much later as an IndexError.
        if not 0 <= ratio <= 1:
            raise ValueError("ratio must be within [0, 1], got %r" % (ratio,))
        self.path = path
        self.new_path = new_path
        self.ratio = ratio
        self.train_sample_path = os.path.join(new_path, "train")
        self.test_sample_path = os.path.join(new_path, "test")
        # makedirs also creates new_path itself. Unlike a print-and-continue
        # helper it raises, so a failure cannot go unnoticed until copy time.
        os.makedirs(self.train_sample_path, exist_ok=True)
        os.makedirs(self.test_sample_path, exist_ok=True)

    def splict_data(self):
        """Shuffle each class folder and copy its files into train / test.

        Files are copied byte-for-byte (``shutil.copyfile``) instead of being
        decoded and re-encoded with OpenCV. JPEGs therefore lose no quality,
        and a non-image file cannot crash the run. Entries of ``path`` that
        are not directories are ignored, and no output folders are created
        for them.
        """
        for name in os.listdir(self.path):
            class_dir = os.path.join(self.path, name)
            if not os.path.isdir(class_dir):
                continue
            print("process class name=%s" % name)
            train_dir = os.path.join(self.train_sample_path, name)
            test_dir = os.path.join(self.test_sample_path, name)
            os.makedirs(train_dir, exist_ok=True)
            os.makedirs(test_dir, exist_ok=True)

            image_names = [f for f in os.listdir(class_dir)
                           if os.path.isfile(os.path.join(class_dir, f))]
            random.shuffle(image_names)
            # After shuffling, the head of the list is train, the tail is test.
            train_end = int(self.ratio * len(image_names))
            self._copy_files(class_dir, image_names[:train_end], train_dir)
            self._copy_files(class_dir, image_names[train_end:], test_dir)

    # Correctly spelled alias; ``splict_data`` is kept for existing callers.
    split_data = splict_data

    @staticmethod
    def _copy_files(src_dir, names, dst_dir):
        """Copy every file in *names* from *src_dir* into *dst_dir*."""
        for fname in names:
            shutil.copyfile(os.path.join(src_dir, fname),
                            os.path.join(dst_dir, fname))
# 二、分割和分类通用 (Part 2: utilities shared by segmentation and classification)
import random
import os
import cv2
import shutil
import numpy as np
import csv
# 把包含原图和mask图的文件夹, 分成文件夹bg, ng1, ng2, ng3...
def split_cls(img_dir, class_name_list, color_list):
    """Move labelme image/mask pairs into one sub-folder per class.

    ``img_dir`` holds pairs ``<stem>.png`` + ``<stem>_mask.png``. Each mask is
    tested against ``color_list[:-1]`` in order, and the first colour that
    appears anywhere in the mask selects ``class_name_list[j]``. A mask
    containing none of those colours goes to ``class_name_list[-1]``; the
    last entry is treated as background and its colour is never tested.
    The pair is copied to ``img_dir/<class>/`` and then the originals are
    deleted.

    A pair whose original image is missing, or whose mask cannot be
    decoded, is reported and left in place instead of aborting the run.

    Args:
        img_dir: folder produced by labelme, holding the images and masks.
        class_name_list: class names; index ``j`` matches ``color_list[j]``.
        color_list: ``[c0, c1, c2]`` colours, compared against mask pixels in
            the channel order ``cv2.imread`` produces (BGR for colour images).
    """
    mask_suffix = '_mask.png'
    for mask_name in os.listdir(img_dir):
        if not mask_name.endswith(mask_suffix):
            continue
        stem = mask_name[:-len(mask_suffix)]
        mask_path = os.path.join(img_dir, mask_name)
        img_path = os.path.join(img_dir, stem + '.png')
        # Previously the mask was copied, and then os.remove(img_path)
        # crashed the whole run when the image was missing.
        if not os.path.isfile(img_path):
            print("%s not exist!" % img_path)
            continue
        mask = cv2.imread(mask_path)
        if mask is None:
            print("cannot read mask %s" % mask_path)
            continue

        cls_name = class_name_list[-1]  # default: background (last entry)
        # NOTE(review): cv2.imread yields BGR. If the CSV lists RGB, colours
        # whose R and B differ will never match -- confirm the CSV channel order.
        for name, color in zip(class_name_list, color_list[:-1]):
            target = tuple(color)
            # inRange sets pixels equal to ``target`` to 255 and others to 0;
            # any hit means the mask contains this class.
            if np.any(cv2.inRange(mask, target, target)):
                cls_name = name
                break

        cls_dir = os.path.join(img_dir, cls_name)
        save_img_cls(mask_path, os.path.join(cls_dir, mask_name))
        save_img_cls(img_path, os.path.join(cls_dir, stem + '.png'))
        # Delete the originals so the pair ends up moved, not duplicated.
        os.remove(mask_path)
        os.remove(img_path)
def get_color_list(csv_path):
    """Read class names and their label colours from a CSV class dictionary.

    Each non-blank row is ``name,c0,c1,c2``. The three integers are returned
    in the order they appear in the file.

    Args:
        csv_path: path of the class dictionary. Its extension must be
            ``.csv`` (case-insensitive).

    Returns:
        ``(class_names, label_values)``: a list of names and a parallel list
        of ``[c0, c1, c2]`` integer triples.

    Raises:
        ValueError: if *csv_path* is not a ``.csv`` file, or a colour field
            is not an integer.
    """
    _, file_extension = os.path.splitext(csv_path)
    if file_extension.lower() != ".csv":
        # This used to be ``return ValueError(...)``. Callers unpacking two
        # values then failed with an unrelated TypeError.
        raise ValueError("File is not a CSV!")
    class_names = []
    label_values = []
    # newline='' is what the csv module expects for file objects it reads.
    with open(csv_path, 'r', newline='') as csvfile:
        for row in csv.reader(csvfile, delimiter=','):
            if not row:  # blank line
                continue
            class_names.append(row[0])
            label_values.append([int(row[1]), int(row[2]), int(row[3])])
    return class_names, label_values
def save_img_cls(srcfile, dstfile):
    """Copy *srcfile* to *dstfile*, creating the destination folder if needed.

    A missing source is reported on stdout and skipped rather than raised,
    so batch callers can keep going.

    Returns:
        True if the file was copied, False if *srcfile* does not exist.
    """
    if not os.path.isfile(srcfile):
        print("%s not exist!" % (srcfile))
        return False
    fpath = os.path.dirname(dstfile)
    # dirname is '' for a bare file name, and os.makedirs('') would raise.
    # exist_ok avoids a check-then-create race.
    if fpath:
        os.makedirs(fpath, exist_ok=True)
    shutil.copyfile(srcfile, dstfile)
    print("copy %s -> %s" % (srcfile, dstfile))
    return True
def makeDir(path):
    """Create directory *path*, including missing parents, if it is not there yet.

    Errors are reported on stdout and turned into a return code instead of
    being raised, so batch scripts keep going.

    Returns:
        0 if the directory was created, 1 if it already existed as a
        directory, -2 if it could not be created (for example because a
        regular file occupies *path*; the old code reported 0 here).
    """
    if os.path.isdir(path):
        return 1
    try:
        os.makedirs(path)
    except FileExistsError:
        # Either another process created the directory between the check and
        # makedirs (fine), or a non-directory occupies the path (not fine).
        if os.path.isdir(path):
            return 1
        print('%s exists and is not a directory' % path)
        return -2
    except OSError as e:
        print(str(e))
        return -2
    return 0
class Generate_Train_Val_Test:
    """Randomly split per-class folders into train / (val) / test sets.

    ``path`` holds one sub-folder per class. The split is written next to
    it, to ``<path>/../refined_data/{train,val,test}/<class>/``; the ``val``
    tree is only created when ``ratio[1] != 0``. Files are copied
    byte-for-byte (``shutil.copyfile``) instead of being decoded and
    re-encoded with OpenCV, so JPEG quality and mask files are preserved
    exactly.

    Args:
        path: dataset root, one sub-folder per class.
        ratio: ``[train, val, test]`` fractions. For a class of ``n``
            shuffled files, ``[:int(ratio[0]*n)]`` goes to train,
            ``[..:int((ratio[0]+ratio[1])*n)]`` to val, and the remainder to
            test. ``ratio[2]`` is only printed.

    Raises:
        ValueError: if a fraction is negative or ``ratio[0] + ratio[1] > 1``.
    """

    # Entries of ``path`` with these endings are never treated as class folders.
    _SKIP_SUFFIXES = ('jpg', 'png', 'jpeg', 'bmp', 'csv', 'csv~')
    _MASK_SUFFIX = '_mask.png'

    def __init__(self, path, ratio):
        # A sum above 1 used to end in an IndexError half-way through copying.
        if min(ratio) < 0 or ratio[0] + ratio[1] > 1 + 1e-9:
            raise ValueError('invalid split ratio: %r' % (ratio,))
        self.path = path
        # os.path.join supplies the separator, so ``path`` no longer needs a
        # trailing slash ('data' + '../' used to give 'data../refined_data').
        self.new_path = os.path.join(path, os.pardir, 'refined_data')
        self.ratio = ratio
        self.train_sample_path = os.path.join(self.new_path, "train")
        self.val_sample_path = os.path.join(self.new_path, "val")
        self.test_sample_path = os.path.join(self.new_path, "test")
        print('train data: %.2f; val data: %.2f; test data: %.2f' % (self.ratio[0], self.ratio[1], self.ratio[2]))
        os.makedirs(self.train_sample_path, exist_ok=True)
        os.makedirs(self.test_sample_path, exist_ok=True)
        if self.ratio[1] != 0:
            os.makedirs(self.val_sample_path, exist_ok=True)

    def get_train_val_test_seg(self):
        """Split the image/mask pairs of every class; each mask follows its image.

        Only images with a matching ``<stem>_mask.png`` take part in the
        split. Images without one are reported and skipped, because
        re-encoding a missing mask used to crash inside ``cv2.imwrite``.
        """
        for name, class_dir in self._class_dirs():
            images = []
            for fname in os.listdir(class_dir):
                if (fname.endswith(self._MASK_SUFFIX)
                        or not os.path.isfile(os.path.join(class_dir, fname))):
                    continue
                if os.path.isfile(os.path.join(class_dir, self._mask_name(fname))):
                    images.append(fname)
                else:
                    print("%s has no mask, skipped" % os.path.join(class_dir, fname))
            for split_root, chosen in self._split(images):
                pairs = [f for img in chosen for f in (img, self._mask_name(img))]
                self._copy_files(class_dir, pairs, os.path.join(split_root, name))

    def get_train_val_test_cls(self):
        """Split the files of every class folder into train / (val) / test."""
        for name, class_dir in self._class_dirs():
            files = [f for f in os.listdir(class_dir)
                     if os.path.isfile(os.path.join(class_dir, f))]
            for split_root, chosen in self._split(files):
                self._copy_files(class_dir, chosen, os.path.join(split_root, name))

    def _class_dirs(self):
        """Yield ``(name, full_path)`` per class folder after creating its output folders."""
        for name in os.listdir(self.path):
            class_dir = os.path.join(self.path, name)
            # Check isdir *before* creating output folders. Previously a stray
            # file such as notes.txt produced empty train/notes.txt folders.
            if name.endswith(self._SKIP_SUFFIXES) or not os.path.isdir(class_dir):
                continue
            print("process class name=%s" % name)
            roots = [self.train_sample_path, self.test_sample_path]
            if self.ratio[1] != 0:
                roots.append(self.val_sample_path)
            for root in roots:
                os.makedirs(os.path.join(root, name), exist_ok=True)
            yield name, class_dir

    def _split(self, names):
        """Shuffle *names* in place; return ``[(split_root, names_for_split), ...]``."""
        random.shuffle(names)
        total = len(names)
        train_end = int(self.ratio[0] * total)
        val_end = int((self.ratio[0] + self.ratio[1]) * total)
        return [
            (self.train_sample_path, names[:train_end]),
            # Empty when ratio[1] == 0, so the missing val folder is never used.
            (self.val_sample_path, names[train_end:val_end]),
            (self.test_sample_path, names[val_end:]),
        ]

    @classmethod
    def _mask_name(cls, image_name):
        """Map ``a.png`` to ``a_mask.png``; splitext also handles ``.jpeg``."""
        return os.path.splitext(image_name)[0] + cls._MASK_SUFFIX

    @staticmethod
    def _copy_files(src_dir, names, dst_dir):
        """Copy every file in *names* from *src_dir* into *dst_dir*."""
        for fname in names:
            shutil.copyfile(os.path.join(src_dir, fname),
                            os.path.join(dst_dir, fname))
if __name__ == '__main__':
    # Step 1 -- segmentation data:
    #   a) move every labelme image/mask pair into a per-class sub-folder,
    #      choosing the class by the colours listed in color.csv;
    #   b) split every class folder into train / (val) / test.
    seg_root = '/home/pi/下载/temp/DF1_Opt1_20190712_0712/'  # images + masks only
    colors_csv = '/home/pi/下载/temp/color.csv'
    names, colors = get_color_list(colors_csv)  # e.g. ng6: 255,255,255  bg: 0,0,0
    split_cls(seg_root, names, colors)

    split_ratio = [0.7, 0, 0.3]  # train / val / test; [0.6, 0.2, 0.2] keeps a val set
    seg_splitter = Generate_Train_Val_Test(seg_root, split_ratio)
    seg_splitter.get_train_val_test_seg()

    # Step 2 -- classification data (disabled). Point the root at a folder
    # that already holds bg, ng1, ng2, ... sub-folders, then call the
    # classification variant instead of the segmentation one:
    #   cls_root = '/home/pi/下载/temp/DF6-Opt1-cropped/'
    #   Generate_Train_Val_Test(cls_root, [0.7, 0, 0.3]).get_train_val_test_cls()