如何将.h5格式的数据读取并保存为图片

import cv2
import h5py
import numpy as np
from scipy.misc import imsave
from skimage import transform


def load_dataset(train_path='train_happy.h5', test_path='test_happy.h5'):
    """Load the train/test arrays from two HDF5 files.

    Args:
        train_path: HDF5 file containing the ``train_set_x`` and
            ``train_set_y`` datasets.
        test_path: HDF5 file containing the ``test_set_x``, ``test_set_y``
            and ``list_classes`` datasets.

    Returns:
        Tuple ``(train_set_x_orig, train_set_y_orig, test_set_x_orig,
        test_set_y_orig, classes)``. Both label arrays are reshaped to row
        vectors of shape ``(1, m)``; the other arrays are returned as stored.
    """
    # ``with`` closes each file even if a dataset key is missing; ``[:]``
    # reads the data into memory before the file is closed.
    with h5py.File(train_path, "r") as train_dataset:
        train_set_x_orig = np.array(train_dataset["train_set_x"][:])  # train set features
        train_set_y_orig = np.array(train_dataset["train_set_y"][:])  # train set labels

    with h5py.File(test_path, "r") as test_dataset:
        test_set_x_orig = np.array(test_dataset["test_set_x"][:])  # test set features
        test_set_y_orig = np.array(test_dataset["test_set_y"][:])  # test set labels
        classes = np.array(test_dataset["list_classes"][:])  # the list of classes

    # Present labels as (1, m) row vectors.
    train_set_y_orig = train_set_y_orig.reshape((1, train_set_y_orig.shape[0]))
    test_set_y_orig = test_set_y_orig.reshape((1, test_set_y_orig.shape[0]))

    return train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig, classes


def processing(out_dir='images/train/', scale=10):
    """Write every training image to ``out_dir`` as an upscaled JPEG.

    Files are named ``<index>-[<label>].jpg`` (e.g. ``16-[1].jpg``), the
    naming that ``reading()`` expects.

    Args:
        out_dir: Destination directory; created if it does not exist.
        scale: Factor applied to both height and width (64x64 -> 640x640
            with the default of 10).

    Raises:
        OSError: if OpenCV fails to write an image file.
    """
    import os  # only this function needs filesystem helpers

    # NOTE(review): this function no longer uses scipy.misc.imsave (removed
    # in SciPy 1.2); the module-level import of it can be dropped.
    X_train_orig, Y_train_orig, _, _, _ = load_dataset()
    os.makedirs(out_dir, exist_ok=True)

    # Labels come back as a (1, m) row vector; flatten to one label per image.
    labels = Y_train_orig.reshape(-1)

    for i, (image, label) in enumerate(zip(X_train_orig, labels)):
        # Assumes each stored image is 64x64 RGB uint8 (per the dataset
        # notes) -- confirm before reusing with another dataset.
        rgb = image.reshape(64, 64, 3)
        # cv2.resize scales only the spatial axes, never the colour axis, and
        # keeps the original intensities (no per-image contrast stretch).
        big = cv2.resize(rgb, None, fx=scale, fy=scale,
                         interpolation=cv2.INTER_LINEAR)  # (640, 640, 3) by default
        # OpenCV writes BGR, so swap channels to keep the colours correct.
        bgr = cv2.cvtColor(big, cv2.COLOR_RGB2BGR)
        path = os.path.join(out_dir, f"{i}-[{label}].jpg")
        # imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(path, bgr):
            raise OSError(f"failed to write image: {path}")


def reading(path='images/train/16-[1].jpg', scale=0.5):
    """Show one saved image in an OpenCV window until a key is pressed.

    Args:
        path: Image file to display; the default is one written by
            ``processing()``.
        scale: Factor applied to width and height before display.

    Raises:
        OSError: if OpenCV cannot read ``path``.
    """
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    # imread returns None instead of raising when the file is missing or unreadable.
    if image is None:
        raise OSError(f"could not read image: {path}")
    print(image.shape)
    # Resize height/width only; INTER_AREA is OpenCV's recommended mode for shrinking.
    shown = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
    cv2.namedWindow('input_image', cv2.WINDOW_AUTOSIZE)
    cv2.imshow('input_image', shown)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


if __name__ == "__main__":
    # Export the training images; call reading() instead to preview a saved one.
    processing()

## 数据集说明 ##

本数据集包含训练集数据 X_train_orig、训练集标签 Y_train_orig、测试集数据 X_test_orig、测试集标签 Y_test_orig 以及类别 classes。其中共有 600 张训练图和 150 张测试图，每张图片都以 64×64 的 RGB 彩色图像存储。

load_dataset() 函数负责读取 h5 格式的训练集和测试集文件，并把标签整理为形状为 (1, m) 的行向量；processing() 函数负责把训练集中的四维矩阵逐张放大并转换为图片，保存到 images/train/ 目录；reading() 函数负责借助 OpenCV 库读取并显示已保存的图片。

你可能感兴趣的:(Image)