仓库:懂不懂都无所谓的img2npz

感觉很鸡肋,也可能是我太菜get不到牛逼之处
当年一点注释都没写,不愧是我
我恨你,半年前的我(……)

"""
Created on Tue Mar 10 19:00:17 2020
@author: ylylhl

Img2Npz
大约有些无用代码的残骸,懒得删了
"""
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2
import random

# Directory that holds the training images. By default the images sit
# directly in this folder ("loose", not sorted into per-class subfolders);
# if they are sorted into class subfolders, adjust the file-reading loop below.
train_path = './trainData/'

def readimg(path):
    """Load an image and binarize it with Otsu's method.

    The image is read with ``plt.imread``, converted to 8-bit if it was
    returned as floats, converted to single-channel grayscale if it has
    colour channels, thresholded with Otsu's method, and finally scaled so
    every pixel is either 0.0 or 1.0.

    Args:
        path: Path of the image file to read.

    Returns:
        A 2-D float array (image height x width) whose values are 0.0 or 1.0.
    """
    img = plt.imread(path)
    # plt.imread returns floats in [0, 1] for PNG files (uint8 for most other
    # formats), but OpenCV's Otsu thresholding only accepts 8-bit images.
    if np.issubdtype(img.dtype, np.floating):
        img = np.clip(np.rint(img * 255.0), 0, 255).astype(np.uint8)
    if img.ndim == 3:
        # plt.imread yields RGB / RGBA channel order, not OpenCV's BGR order,
        # so use the matching conversion code to get correct gray weights.
        code = cv2.COLOR_RGBA2GRAY if img.shape[2] == 4 else cv2.COLOR_RGB2GRAY
        img = cv2.cvtColor(img, code)
    # Threshold 0 is ignored when THRESH_OTSU is set; Otsu picks it instead.
    # Pixels become either 0 or 255.
    _, img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return img / 255.0

# Mapping from class name (a single character) to its numeric class id;
# the model ultimately predicts the numeric id.
# The alphabet is the digits 0-9 followed by upper-case letters; note that
# G, I, L, O and Z are not part of it.
index = {
    char: code
    for code, char in enumerate("0123456789ABCDEFHJKMNPQRSTUVWXY")
}

# Inverse mapping: numeric class id -> class name.
index_new = {code: char for char, code in index.items()}

# Randomly split the images into a training set and a second set.
# The second set is named "test_*" for historical reasons, but it is used
# as a validation set.
test_dataset, test_label = [], []
train_dataset, train_label = [], []

for name in os.listdir(train_path):
    # The class name is the last character of the file name before its
    # extension, e.g. "xxxx_A.png" -> "A". os.path.splitext handles
    # extensions of any length (".png", ".jpeg", ...), unlike a fixed slice.
    stem = os.path.splitext(name)[0]
    key = stem[-1:].upper()
    if key not in index:
        raise ValueError(
            f"cannot derive a known class name from file name {name!r}"
        )
    label = index[key]
    img = readimg(os.path.join(train_path, name))
    # About 1/8 (10 out of 80) of the images go to the validation set.
    if random.randint(1, 80) <= 10:
        test_dataset.append(img)
        test_label.append(label)
    else:
        train_dataset.append(img)
        train_label.append(label)

# Turn the accumulated Python lists into NumPy arrays, rebinding each name.
train_dataset, train_label, test_dataset, test_label = [
    np.array(part)
    for part in (train_dataset, train_label, test_dataset, test_label)
]

# Store all four arrays in one .npz archive; each keyword becomes the key
# used to read the array back with np.load(...)[key].
np.savez(
    'TrainData.npz',
    train_dataset=train_dataset,
    train_label=train_label,
    test_dataset=test_dataset,
    test_label=test_label,
)

读取数据时只需要:

# Load the archive written by the conversion script. Using np.load as a
# context manager closes the underlying .npz file once the arrays are read,
# and each array is read from the archive only once.
with np.load('TrainData.npz') as data:
    test_labels = data['test_label']
    train_labels = data['train_label']
    test_images = data['test_dataset']
    train_images = data['train_dataset']

# (48, 60, 1): the original image size plus a trailing channel axis.
# NOTE(review): this assumes every stored image is 48x60 pixels; reshape
# raises if the pixel count differs.
test_images = test_images.reshape(len(test_images), 48, 60, 1)
train_images = train_images.reshape(len(train_images), 48, 60, 1)

你可能感兴趣的:(仓库:懂不懂都无所谓的img2npz)