MNIST 是个单标签(single-label,multi-class)数据集,图片尺寸都是 $28\times28$。可以将 4 幅图按 $2\times2$ 的排列拼在一起,组成一幅 $56\times56$ 的图像,再把这 4 幅图的 one-hot 标签合并(取并集,得到 multi-hot 向量),就可以组成一个简易的多标签(multi-label)数据集。示例:
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
#%matplotlib inline
# Directory handed to the MNIST reader; the generated .npy files are written
# to the same place further down.
MNIST_P = "E:/iTom/dataset/MNIST"
mnist = input_data.read_data_sets(MNIST_P, one_hot=True)

# Sample count of each split.
print(
    mnist.train.num_examples,  # 55000
    mnist.validation.num_examples,  # 5000
    mnist.test.num_examples,  # 10000
)

# Images exactly as the reader delivers them.
x_train, x_val, x_test = (
    mnist.train.images,
    mnist.validation.images,
    mnist.test.images,
)
print(x_train.shape, x_val.shape, x_test.shape)

# View every sample as a 28x28 grid so the images can be tiled spatially.
x_train, x_val, x_test = (
    split.reshape(-1, 28, 28) for split in (x_train, x_val, x_test)
)
print(x_train.shape, x_val.shape, x_test.shape)

# Value range of the pixel data (largest first, then smallest).
print(x_train.max(), x_train.min())

# Label rows (requested as one-hot above), one per image.
y_train, y_val, y_test = (
    mnist.train.labels,
    mnist.validation.labels,
    mnist.test.labels,
)
print(y_train.shape, y_val.shape, y_test.shape)
def integrate(x4, y4):
    """Tile four images into one 2x2 mosaic and merge their labels.

    The first four entries of ``x4`` are arranged as::

        x4[0] | x4[1]
        ------+------
        x4[2] | x4[3]

    Args:
        x4: Array of images indexed along axis 0, intended shape
            ``[4, 28, 28]``.
        y4: Array of per-image label vectors, intended shape ``[4, 10]``.

    Returns:
        A ``(mosaic, labels)`` tuple. ``mosaic`` is twice as tall and twice
        as wide as a single input image (``[56, 56]`` for 28x28 inputs).
        ``labels`` has the dtype of ``y4`` and holds 1 in every position
        whose column sum over ``y4`` is positive, 0 elsewhere.
    """
    top = np.concatenate([x4[0], x4[1]], axis=1)
    bottom = np.concatenate([x4[2], x4[3]], axis=1)
    mosaic = np.concatenate([top, bottom], axis=0)
    # To eyeball a result: plt.imshow(mosaic, cmap="Greys"); plt.show()

    # Sum then threshold, so a class that occurs in several of the four
    # images is still encoded as a single 1 rather than a count.
    merged = (y4.sum(0) > 0).astype(y4.dtype)
    return mosaic, merged
def make(images, labels):
    """Convert a single-label dataset into 2x2-mosaic multi-label samples.

    Consecutive groups of four samples are merged. Within a group the images
    are tiled in the same layout that ``integrate`` produces::

        img[0] | img[1]
        -------+-------
        img[2] | img[3]

    and the four label rows are merged into one row that holds 1 wherever
    their column sum is positive.

    The whole dataset is processed with one reshape/transpose instead of a
    Python loop. A trailing group with fewer than four samples is dropped,
    and an empty dataset yields empty outputs instead of an error.

    Args:
        images: Array of shape ``[N, H, W, ...]``.
        labels: Array of shape ``[N, C, ...]``; its first axis determines
            how many groups are built.

    Returns:
        ``(X, Y)`` where ``X`` has shape ``[N // 4, 2 * H, 2 * W, ...]`` and
        ``Y`` has shape ``[N // 4, C, ...]`` with the dtype of ``labels``.

    Raises:
        ValueError: If ``images`` has fewer than 3 dimensions, ``labels`` has
            fewer than 2, or ``images`` holds fewer samples than the groups
            derived from ``labels`` require.
    """
    images = np.asarray(images)
    labels = np.asarray(labels)
    if images.ndim < 3:
        raise ValueError("images must have shape [N, H, W, ...]")
    if labels.ndim < 2:
        raise ValueError("labels must have shape [N, C, ...]")

    n_groups = labels.shape[0] // 4
    used = 4 * n_groups  # samples beyond this form an incomplete group
    if images.shape[0] < used:
        raise ValueError(
            "images has %d samples but %d are needed to match labels"
            % (images.shape[0], used)
        )

    height, width = images.shape[1:3]
    extra = images.shape[3:]
    # Axes after the reshape: (group, tile_row, tile_col, y, x, ...).
    tiles = images[:used].reshape((n_groups, 2, 2, height, width) + extra)
    # Bring tile_row next to y and tile_col next to x so that collapsing
    # each pair yields the tiled image.
    order = (0, 1, 3, 2, 4) + tuple(range(5, tiles.ndim))
    X = tiles.transpose(order).reshape(
        (n_groups, 2 * height, 2 * width) + extra
    )

    grouped = labels[:used].reshape((n_groups, 4) + labels.shape[1:])
    # Sum then threshold so a class repeated inside a group becomes one 1.
    Y = (grouped.sum(axis=1) > 0).astype(labels.dtype)
    return X, Y
# Build the multi-label version of every split and store it as .npy files.
def _save(filename, array):
    """Write ``array`` to ``filename`` inside the ``MNIST_P`` directory."""
    np.save(os.path.join(MNIST_P, filename), array)


# Test split.
X_test, Y_test = make(x_test, y_test)
print(X_test.shape, Y_test.shape)
_save("x_test.npy", X_test)
_save("y_test.npy", Y_test)

# Training split.
X_train, Y_train = make(x_train, y_train)
print(X_train.shape, Y_train.shape)
_save("x_train.npy", X_train)
_save("y_train.npy", Y_train)

# Validation split.
X_val, Y_val = make(x_val, y_val)
print(X_val.shape, Y_val.shape)
_save("x_val.npy", X_val)
_save("y_val.npy", Y_val)