最近做项目需要用多线程下载数据,于是简单学习了一下。
我先参考了别人的代码,修改后亲测可用。
原文链接:https://junyiseo.com/python/211.html
# -*- coding: UTF-8 -*-
import threading
from time import sleep,ctime
class myThread(threading.Thread):
    """Worker thread that runs ``printImg`` over the half-open range [s, e).

    The whole call to ``printImg`` happens while the module-level
    ``threadLock`` is held, so workers sharing that lock process their
    ranges one at a time rather than interleaved.

    Attributes are plain values copied from the constructor arguments:
    ``threadID`` (caller-chosen identifier), ``name`` (thread name, also
    used in the start-up message), ``s`` (range start) and ``e`` (range end,
    exclusive).
    """

    def __init__(self, threadID, name, s, e):
        threading.Thread.__init__(self)
        self.threadID = threadID
        self.name = name
        self.s = s
        self.e = e

    def run(self):
        """Announce the thread, then process the assigned range under the lock."""
        print ("Starting " + self.name+ctime())
        # Entering the ``with`` block blocks until the lock is acquired.
        # Using a context manager (instead of acquire()/release()) guarantees
        # the lock is released even if printImg raises; otherwise every other
        # worker waiting on the lock would block forever.
        with threadLock:
            # The work this thread is responsible for.
            printImg(self.s, self.e)
# Items to be processed (the integers 0..178); substitute your own list here.
listImg = list(range(179))
# Walk the assigned range; put any other per-item work inside this function.
def printImg(s, e):
    """Print each integer from ``s`` (inclusive) to ``e`` (exclusive), one per line.

    Note that the index itself is printed; the function does not look the
    index up in any list.
    """
    assigned = range(s, e)
    for position in assigned:
        print (position)
totalThread = 3  # number of worker threads to create
lenList = len(listImg)  # total number of items
gap = lenList // totalThread  # items handed to each thread (floor division)
threadLock = threading.Lock()  # lock shared by all workers
threads = []  # every worker that was created

# Create the workers. Worker i covers [i*gap, (i+1)*gap); the last worker
# runs through lenList so the remainder of the floor division is not lost.
for i in range(totalThread):
    start = i * gap
    end = lenList if i == totalThread - 1 else (i + 1) * gap
    threads.append(myThread(i, "Thread-%s" % i, start, end))

# Start every worker.
for thread in threads:
    thread.start()

# Wait for all workers to finish.
for t in threads:
    t.join()
print ("Exiting Main Thread")
然后参考这段代码,我写出了自己的多线程下载数据集的代码,如下:
# -*- coding: UTF-8 -*-
import threading
from time import sleep,ctime
class myThread(threading.Thread):
    """Worker thread that loads one slice of a dataset's ``labels.txt``.

    Lines ``s`` (inclusive) to ``e`` (exclusive) of ``<path>/labels.txt`` are
    read. Each line is split on single spaces: the first field is an image
    file name relative to ``path`` and the second field is its label. For
    every usable line a tuple ``(arr, label_to_array(label))`` is appended to
    ``self.example``, where ``arr`` is the first value returned by
    ``resize_image``.

    Attributes:
        threadID: caller-chosen identifier.
        name: thread name, used in the start/stop messages.
        s, e: slice bounds into the lines of ``labels.txt``.
        encoding: text encoding used to open ``labels.txt``.
        path: directory holding ``labels.txt`` and the images.
        example: list of loaded examples; each worker appends only to its own
            list.
    """

    def __init__(self, threadID, name, s, e, encoding, path):
        threading.Thread.__init__(self)
        self.threadID = threadID
        self.name = name
        self.s = s
        self.e = e
        self.encoding = encoding
        self.path = path
        self.example = []

    def run(self):
        """Load the assigned slice of label lines into ``self.example``."""
        print ("Starting " + self.name + " " + ctime())
        # Read inside ``with`` so the file handle is closed promptly instead
        # of being left for the garbage collector.
        with open(self.path+"/"+"labels.txt", 'r', encoding=self.encoding) as label_file:
            lines = label_file.readlines()[self.s:self.e]
        for line in lines:
            fields = line.split(' ')
            if len(fields) < 2:
                # A line without a space has no label field; indexing it
                # would raise IndexError and kill the whole worker, so report
                # the line and move on.
                print('Malformed label line, Path:', self.path, 'Line:', repr(line))
                continue
            image_path = os.path.join(self.path, fields[0])
            strlabel = fields[1].strip('\n')
            # Labels longer than the configured maximum are skipped.
            if len(strlabel) > max_char_count:
                continue
            try:
                arr, _ = resize_image(image_path, max_image_width)
            # NOTE(review): catching NameError also hides a missing
            # definition/import (e.g. of resize_image); kept as in the
            # original, but consider narrowing this to OSError.
            except (OSError, NameError):
                print('OSError, Path:', image_path)
                continue
            self.example.append((arr, label_to_array(strlabel)))
        print ("Stoping " + self.name + " " + ctime())
def load_data():
    """Load all datasets with worker threads and return the shuffled examples.

    Five datasets are configured in parallel lists: the number of worker
    threads, the number of label lines to read, the text encoding of the
    label file, and the dataset directory. Each dataset's line range is split
    evenly across its workers; the last worker of a dataset also takes the
    remainder, so no trailing lines are dropped when the line count is not
    divisible by the thread count.

    Returns:
        tuple: ``(examples, count)`` where ``examples`` is the shuffled list
        of everything the workers collected in their ``example`` lists and
        ``count`` is ``len(examples)``.
    """
    print('Loading data')
    threadNums = [32, 32, 16, 8, 8]
    picNums = [400000, 400000, 150000, 80000, 70000]
    encodings = ['UTF-8', 'UTF-8', 'gbk', 'gbk', 'gbk']
    paths = [examples_path_1, examples_path_2, examples_path_3, examples_path_4, examples_path_5]
    threads = []
    datasets = zip(threadNums, picNums, encodings, paths)
    for i, (thread_count, pic_count, encoding, path) in enumerate(datasets):
        gap = pic_count // thread_count  # lines per worker (floor division)
        for j in range(thread_count):
            start = gap * j
            # The last worker runs to pic_count so the remainder of the floor
            # division is still loaded.
            end = pic_count if j == thread_count - 1 else gap * (j + 1)
            threads.append(
                myThread(i * 100 + j, "Thread-%s-%s" % (i, j), start, end, encoding, path)
            )
            print(path, start, end)
    for worker in threads:
        worker.start()
    for worker in threads:
        worker.join()
    examples = []
    for index, worker in enumerate(threads):
        # extend() grows the list in place; rebuilding it with ``+`` on every
        # iteration would copy all previously collected examples each time.
        examples.extend(worker.example)
        print (index, worker.threadID, len(worker.example), len(examples))
    print ("Exiting Main Thread")
    print(len(examples))
    random.shuffle(examples)
    return examples, len(examples)