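Video input of shape (batch, frames, h, w, c) has to be fed to the custom network through a generator, but building each batch repeats the read operation batch*frames times. With the default sequential loop the read speed is about 100 frames/s, which seriously slows down training.
The machine has plenty of CPU cores, so a multi-process scheme is called for: assign one process to each item in the batch and read the data in parallel.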
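To make the bottleneck concrete, here is a minimal sketch of the sequential baseline. It is not the original loader: `read_frame` is a hypothetical stand-in that fabricates a frame instead of reading from disk, and the sizes are arbitrary values chosen for illustration.

```python
import numpy as np


def read_frame(sample_idx, frame_idx, h=4, w=4, c=3):
    """Hypothetical stand-in for one image read (a real loader would hit the disk)."""
    return np.full((h, w, c), sample_idx * 100 + frame_idx, dtype=np.float32)


def sequential_generator(batch_size, frames, h=4, w=4, c=3):
    """Yield batches of shape (batch_size, frames, h, w, c), read one frame at a time."""
    while True:
        batch = np.empty((batch_size, frames, h, w, c), dtype=np.float32)
        # batch_size * frames reads per batch, all executed one after another
        for b in range(batch_size):
            for f in range(frames):
                batch[b, f] = read_frame(b, f, h, w, c)
        yield batch


first_batch = next(sequential_generator(batch_size=2, frames=3))
print(first_batch.shape)
```

Every frame waits for the previous one to finish, so the time per batch grows linearly with batch_size * frames. That is the cost the parallel variants try to remove.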
A Baidu search mostly turns up this program:
import os
import time
from PIL import Image
from multiprocessing import Pool


def get_file_path(path):
    # Collect every file one directory level below `path`
    img_paths = []
    dirs = os.listdir(path)
    for file_dir in dirs:
        file_path = os.path.join(path, file_dir)
        img_names = os.listdir(file_path)
        for img_name in img_names:
            img_path = os.path.join(file_path, img_name)
            img_paths.append(img_path)
    return img_paths


def resize_image(file_name):
    try:
        img = Image.open(file_name)
        # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter
        new_img = img.resize((250, 250), Image.LANCZOS)
        new_img.save(file_name)  # overwrites the original file
    except Exception:
        print(file_name)  # report files that could not be processed


if __name__ == '__main__':
    start = time.time()
    path = r'C:\Users\Alvin_Fang\Downloads\identities_0'
    img_paths = get_file_path(path)
    pool = Pool(6)  # number of worker processes (CPU cores)
    pool.map(resize_image, img_paths)
    pool.close()
    pool.join()
    end = time.time()
    print(end - start)
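One very annoying catch, though:
Note: on Windows, any code that starts processes must sit under the if __name__ == '__main__': guard of the current .py file, otherwise the multiprocessing module does not work properly. On Unix/Linux with the fork start method the guard is not required.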
https://blog.csdn.net/Alvin_FZW/article/details/82886004
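For the official Python explanation, see the programming guidelines (avoid shared state, join processes to avoid zombies, avoid terminating processes):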
https://docs.python.org/zh-cn/3/library/multiprocessing.html
https://docs.python.org/zh-cn/3/library/multiprocessing.html#multiprocessing-programming
Because of this restriction it does not fit my use case, so I dropped this approach.
The other program that searches commonly turn up uses concurrent.futures:
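import concurrent.futures
import glob
import os
import time

import cv2


def load_and_resize(image_filename):
    img = cv2.imread(image_filename)
    # img = cv2.resize(img, (600, 600))


if __name__ == '__main__':
    image_path = r'path\to\images'  # placeholder: the original does not define image_path
    start_time1 = time.time()
    # max_workers defaults to the number of processors on the machine, not 1
    with concurrent.futures.ProcessPoolExecutor() as executor:
        image_files = glob.glob(os.path.join(image_path, '*.jpg'))
        # executor.map() takes the function to run and a list of inputs; each
        # element of the list is the single argument of one call. With 6 cores,
        # 6 items of the list are processed at the same time.
        executor.map(load_and_resize, image_files)
    print('Run time with multi-core parallelism:', round(time.time() - start_time1, 2), 's')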
Like multiprocessing.Pool, however, this has to be placed under the if __name__ == '__main__': guard:
The main module must be importable by worker subprocesses. This means that ProcessPoolExecutor will not work in the interactive interpreter.
Official documentation: https://docs.python.org/3/library/concurrent.futures.html
The only workable solution I found that is free of the main-guard restriction is a thread pool:
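from concurrent import futures
from os.path import join

def load_data_parallel(b):
    # read_image, data_dir and batchsize are defined elsewhere in the author's code
    mean_image = read_image(join(data_dir, 'mean_frame.png'))
    ##————————###  (rest of the per-sample loading code is elided in the original)

with futures.ThreadPoolExecutor(max_workers=8) as executor:
    batch_list = list(range(batchsize))
    executor.map(load_data_parallel, batch_list)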
Official documentation: https://docs.python.org/3/library/concurrent.futures.html
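As a rough sketch of how the thread-pool approach can sit inside a batch generator, the example below loads the clips of one batch in parallel and stacks them into a single array. Everything in it is an assumption made for illustration: `load_clip` is a hypothetical loader that sleeps instead of reading files, the sizes are arbitrary, and returning the clips from the worker is only one possible way to collect results (the original post elides how `load_data_parallel` stores its output).

```python
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np

FRAMES, H, W, C = 3, 4, 4, 3  # illustrative sizes, not taken from the original post


def load_clip(sample_idx):
    """Hypothetical loader: returns one clip of shape (FRAMES, H, W, C)."""
    time.sleep(0.01)  # stands in for disk I/O; like file reads, sleep releases the GIL
    return np.full((FRAMES, H, W, C), sample_idx, dtype=np.float32)


def threaded_generator(batch_size, max_workers=8):
    """Yield batches of shape (batch_size, FRAMES, H, W, C), loading clips in threads."""
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        while True:
            # map() yields results in input order, so clip i ends up at index i
            clips = list(executor.map(load_clip, range(batch_size)))
            yield np.stack(clips)


gen = threaded_generator(batch_size=4)
batch = next(gen)
print(batch.shape)
gen.close()  # leaving the generator exits the with-block and shuts the pool down
```

Consuming the iterator returned by `executor.map()` (here via `list(...)`) also re-raises any exception thrown in a worker, which would otherwise go unnoticed. Threads help here mainly because the work is dominated by I/O; heavy pure-Python preprocessing would still be serialized by the GIL.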