前言
- 本文旨在将我训练 yolo 模型制作 VOC 数据集时使用的代码工具罗列出来,方便大家使用。
收集图片代码
- 代码来源
【Python 爬虫】收集图片
- 完整代码
"""
Created on 2021/4/19 11:47
Filename : spider_image_baidu.py
Author : Taosy
Zhihu : https://www.zhihu.com/people/1105936347
Github : https://github.com/AFei19911012
Description: Spider - get images from baidu
"""
import random
import time
import requests
import os
import re
def get_images_from_baidu(keyword, page_num, save_dir):
    """Download images from Baidu image search into *save_dir*.

    Args:
        keyword: search term, sent as both 'queryWord' and 'word'.
        page_num: number of result pages to fetch (30 thumbnails per page).
        save_dir: output directory; images are saved as 000000.jpg, 000001.jpg, ...

    Returns:
        int: the number of images successfully written.
    """
    header = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/78.0.3904.108 Safari/537.36'}
    url = 'https://image.baidu.com/search/acjson?'
    # Create the output directory once up front instead of re-checking per page.
    os.makedirs(save_dir, exist_ok=True)
    n = 0
    for pn in range(0, 30 * page_num, 30):
        # Random 0-1 s delay between page requests to avoid hammering the server.
        time.sleep(random.randint(0, 10) / 10.0)
        param = {'tn': 'resultjson_com',
                 'ipn': 'rj',
                 'ct': 201326592,
                 'is': '',
                 'fp': 'result',
                 'queryWord': keyword,
                 'cl': 2,
                 'lm': -1,
                 'ie': 'utf-8',
                 'oe': 'utf-8',
                 'adpicid': '',
                 'st': -1,
                 'z': '',
                 'ic': '',
                 'hd': '',
                 'latest': '',
                 'copyright': '',
                 'word': keyword,
                 's': '',
                 'se': '',
                 'tab': '',
                 'width': '',
                 'height': '',
                 'face': 0,
                 'istype': 2,
                 'qc': '',
                 'nc': '1',
                 'fr': '',
                 'expermode': '',
                 'force': '',
                 'cg': '',
                 'pn': pn,  # result offset: 0, 30, 60, ...
                 'rn': '30',
                 'gsm': '1e',
                 '1618827096642': ''
                 }
        # timeout so a stalled connection cannot hang the crawl forever.
        response = requests.get(url=url, headers=header, params=param, timeout=30)
        if response.status_code != 200:
            # Original code silently skipped non-200 pages; make that visible.
            print(f'Request failed with status {response.status_code}.')
            continue
        print('Request success.')
        response.encoding = 'utf-8'
        html = response.text
        # Thumbnail URLs are embedded in the JSON-ish payload as "thumbURL":"...",
        image_url_list = re.findall('"thumbURL":"(.*?)",', html, re.S)
        print(image_url_list)
        for image_url in image_url_list:
            try:
                image_data = requests.get(url=image_url, headers=header, timeout=30).content
            except requests.RequestException:
                # A single broken image link should not abort the whole run.
                continue
            with open(os.path.join(save_dir, f'{n:06d}.jpg'), 'wb') as fp:
                fp.write(image_data)
            n += 1
    return n
if __name__ == '__main__':
    # Search term doubles as the output folder name, created next to this script.
    search_keyword = '狗'
    get_images_from_baidu(search_keyword, page_num=2, save_dir=search_keyword)
    print('Get images finished.')
- 使用方式
- 在 main 中,修改 keyword 后为你想搜索的图片的关键词,代码运行后会在 py 文件的同级目录生成存放图片的文件夹
- 在 main 中,修改 page_num 的数值可以更改爬取的图片页数
- 备注:在代码中加入了随机延迟
time.sleep(random.randint(0, 10) / 10.0)
重复图片去除
- 使用上面的代码爬取的图片中会有重复的图片,我使用下面的代码对重复图片进行了删除
- 原理:通过图片哈希值的比较筛选重复图片
- 代码:
import os
import hashlib
def find_duplicate_images(folder):
    """Find groups of byte-identical images under *folder* (recursively).

    Two files are considered duplicates when the MD5 digest of their raw
    bytes matches.

    Args:
        folder: root directory to scan.

    Returns:
        list[list[str]]: one list of file paths per duplicate group
        (only groups with more than one member are returned).
    """
    image_extensions = (".png", ".jpg", ".jpeg", ".bmp", ".gif")
    image_hashes = {}
    for root, _dirs, files in os.walk(folder):
        for file in files:
            # Compare the lowercased name so .JPG/.PNG etc. are not missed.
            if file.lower().endswith(image_extensions):
                file_path = os.path.join(root, file)
                with open(file_path, "rb") as f:
                    file_hash = hashlib.md5(f.read()).hexdigest()
                image_hashes.setdefault(file_hash, []).append(file_path)
    return [image_paths for image_paths in image_hashes.values() if len(image_paths) > 1]
def delete_duplicate_images(folder):
    """
    Delete duplicate images under *folder*, keeping the first file of each group.
    """
    for group in find_duplicate_images(folder):
        # First path in the group survives; every later one is removed.
        kept, *to_delete = group
        for duplicate_image_path in to_delete:
            os.remove(duplicate_image_path)
            print(f"Deleted duplicate image: {duplicate_image_path}")
        print(f"Original image: {kept} has been kept.")
if __name__ == '__main__':
    # Folder to de-duplicate; replace with your own image directory.
    target_folder = r'F:\学习之路\YOLO\VOC_catANDdog\JPEGImages'
    delete_duplicate_images(target_folder)
- 使用方法:将 main 中 F:\学习之路\YOLO\VOC_catANDdog\JPEGImages 换成你要处理的图片文件夹路径即可
数据集划分
- 代码来源:
目标检测算法—将数据集划分为训练集和验证集
- 完整代码:
import os, random, shutil
def moveimg(fileDir, tarDir, rate=0.1):
    """Move a random sample of files from *fileDir* into *tarDir*.

    Used to split off a validation set from a directory of training images.

    Args:
        fileDir: source directory (a trailing separator is no longer required;
            paths are joined with os.path.join).
        tarDir: destination directory; must already exist.
        rate: fraction of files to move (default 0.1 keeps the original
            hard-coded 10% validation split).
    """
    pathDir = os.listdir(fileDir)
    # int() truncates, matching the original selection size.
    picknumber = int(len(pathDir) * rate)
    sample = random.sample(pathDir, picknumber)
    print(sample)
    for name in sample:
        # os.path.join is portable; the original "fileDir + name" relied on the
        # caller appending "\\" and only worked on Windows.
        shutil.move(os.path.join(fileDir, name), os.path.join(tarDir, name))
def movelabel(file_list, file_label_train, file_label_val):
    """Move YOLO .txt labels matching the given images from train to val dir.

    For each .jpg name in *file_list*, the same-stem .txt file is moved from
    *file_label_train* into *file_label_val* (if it exists).

    Args:
        file_list: iterable of image file names (e.g. os.listdir of the
            validation image directory).
        file_label_train: directory currently holding the label .txt files.
        file_label_val: directory to move the matching labels into.
    """
    for name in file_list:
        stem, ext = os.path.splitext(name)
        # Case-insensitive check so '.JPG' images are handled too; splitext
        # replaces the fragile i[:-4] slice of the original.
        if ext.lower() == '.jpg':
            label_path = os.path.join(file_label_train, stem + '.txt')
            if os.path.exists(label_path):
                shutil.move(label_path, file_label_val)
                print(name + "处理成功!")
if __name__ == '__main__':
    # Source / destination image directories; adjust to your dataset layout.
    fileDir = r"C:\Users\86159\Desktop\hat\JPEGImages" + "\\"
    tarDir = r'C:\Users\86159\Desktop\hat\JPEGImages_val'
    # Move a random sample of images out of the training folder ...
    moveimg(fileDir, tarDir)
    # ... then move their matching label files to the validation label folder.
    movelabel(os.listdir(tarDir),
              r"C:\Users\86159\Desktop\hat\Annotations_yolo",
              r"C:\Users\86159\Desktop\hat\Annotations_val")
- 说明:修改第 7 行中 rate 的值可以调整划分出来的验证集的比例