深度学习中,有时我们需要对数据集进行预处理,这样能够更好的读取数据。
一、png格式生成.npy格式
import numpy as np
import os
from PIL import Image
dir="C:/Users/Administrator/Desktop/trainA"
def getFileArr(dir):
result_arr=[]
label_list=[]
map={}
map_file_result={}
map_file_label={}
map_new={}
count_label=0
count=0
file_list=os.listdir(dir)
for file in file_list:
file_path=os.path.join(dir,file)
label=file.split(".")[0].split("_")[0]
map[file]=label
if label not in label_list:
label_list.append(label)
map_new[label]=count_label
count_label=count_label+1
img=Image.open(file_path)
result=np.array([])
r,g,b=img.split()
r_arr=np.array(r).reshape(4096)
g_arr=np.array(g).reshape(4096)
b_arr=np.array(b).reshape(4096)
img_arr=np.concatenate((r_arr,g_arr,b_arr))
result=np.concatenate((result,img_arr))
result=result.reshape((64,64,3))
result=result/255.0
map_file_result[file]=result
result_arr.append(result)
count=count+1
for file in file_list:
map_file_label[file]=map_new[map[file]]
#map[file]=map_new[map[file]]
ret_arr=[]
for file in file_list:
each_list=[]
label_one_zero=np.zeros(count_label)
result=map_file_result[file]
label=map_file_label[file]
label_one_zero[label]=1.0
#print(label_one_zero)
each_list.append(result)
each_list.append(label_one_zero)
ret_arr.append(each_list)
os.makedirs("C:/Users/Administrator/Desktop/npy")
np.save('C:/Users/Administrator/Desktop/npy/test_data.npy', ret_arr)
return ret_arr
if __name__=="__main__":
ret_arr=getFileArr(dir)
二、.npy格式生成png格式
import numpy as np
from PIL import Image
import os
dir="C:/Users/Administrator/Desktop/npy/"#npy文件路径
dest_dir="C:/Users/Administrator/Desktop/train/"
def npy2jpg(dir,dest_dir):
if os.path.exists(dir)==False:
os.makedirs(dir)
if os.path.exists(dest_dir)==False:
os.makedirs(dest_dir)
file=dir+'test_data.npy'
con_arr=np.load(file)
count=0
for con in con_arr:
arr=con[0]
label=con[1]
print(np.argmax(label))
arr=arr*255
#arr=np.transpose(arr,(2,1,0))
arr=np.reshape(arr,(3,64,64))
r=Image.fromarray(arr[0]).convert("L")
g=Image.fromarray(arr[1]).convert("L")
b=Image.fromarray(arr[2]).convert("L")
img=Image.merge("RGB",(r,g,b))
label_index=np.argmax(label)
img.save(dest_dir+str(label_index)+"_"+str(count)+".png")
count=count+1
if __name__=="__main__":
npy2jpg(dir,dest_dir)
三、注意
根据自己的数据集需要改尺寸和维度以及改路径。
---------------------
作者:蹦跶的小羊羔
来源:CSDN
原文:https://blog.csdn.net/yql_617540298/article/details/82747789
版权声明:本文为博主原创文章,转载请附上博文链接!