MNIST Dataset

  • 1. Downloading the MNIST dataset
  • 2. Loading and visualizing the data

1. Downloading the MNIST dataset

Baidu Cloud link, extraction code: 0wdr
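The four downloaded MNIST archives (for example train-images-idx3-ubyte.gz) use the simple IDX binary format: two zero bytes, a type code (0x08 means unsigned bytes), the number of dimensions, one big-endian 32-bit size per dimension, then the raw data. As a minimal sketch, assuming the files are gzip-compressed IDX files holding unsigned bytes, they can be read into NumPy without TensorFlow. The function name read_idx is hypothetical and not part of any library.

```python
import gzip
import struct

import numpy as np


def read_idx(path):
    """Hypothetical helper: read an (optionally gzipped) IDX file into a NumPy array.

    Assumes unsigned-byte data (type code 0x08), as used by the MNIST files.
    """
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rb") as f:
        # Header: 2 zero bytes, 1 type-code byte, 1 byte giving the number of dims.
        zero, dtype_code, ndim = struct.unpack(">HBB", f.read(4))
        if zero != 0 or dtype_code != 0x08:
            raise ValueError("not an unsigned-byte IDX file")
        # One big-endian uint32 per dimension, e.g. (count, 28, 28) for images.
        shape = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
        # The remaining bytes are the pixel or label values in row-major order.
        data = np.frombuffer(f.read(), dtype=np.uint8)
    return data.reshape(shape)
```

For example, read_idx("/mnist/train-images-idx3-ubyte.gz") should return a uint8 array of shape (60000, 28, 28): the full training set, before any validation split, with raw 0 to 255 pixel values.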

2. Loading and visualizing the data

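import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# Note: tensorflow.examples.tutorials.mnist ships only with TensorFlow 1.x;
# it was removed in TensorFlow 2.x, so this script needs a TF 1.x install.
from tensorflow.examples.tutorials.mnist import input_data

# Read the four .gz files from /mnist/ (downloading any that are missing)
# and one-hot encode the labels.
mnist = input_data.read_data_sets("/mnist/", one_hot=True)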

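# Load data: each image is a flattened 28x28 picture (784 float32 values in [0, 1]).
# read_data_sets holds out 5,000 of the 60,000 training images as mnist.validation
# by default (validation_size=5000), so the training set has 55,000 images.
x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels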

print("x_train: ", x_train.shape)
print("y_train: ", y_train.shape)
print("x_test: ", x_test.shape)
print("y_test: ", y_test.shape)

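def plot_mnist(data, classes):
    """Show the first 5 images of each digit in a 5x10 grid, one column per digit."""
    for i in range(10):
        # boolean mask selecting every image whose label is i
        idxs = (classes == i)

        # take the first 5 images of class i
        images = data[idxs][0:5]

        for j in range(5):
            # subplot cells are numbered left to right, top to bottom: row j, column i
            plt.subplot(5, 10, i + j*10 + 1)
            # each row of data is a flattened 28x28 image
            plt.imshow(images[j].reshape(28, 28), cmap='gray')
            # print a title only once for each class (top row)
            if j == 0:
                plt.title(i)
            plt.axis('off')
    plt.show()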

# convert the one-hot label rows back to integer class indices 0..9
classes = np.argmax(y_train, 1)
plot_mnist(x_train, classes)

Output:

Extracting /mnist/train-images-idx3-ubyte.gz
Extracting /mnist/train-labels-idx1-ubyte.gz
Extracting /mnist/t10k-images-idx3-ubyte.gz
Extracting /mnist/t10k-labels-idx1-ubyte.gz
x_train: (55000, 784)
y_train: (55000, 10)
x_test: (10000, 784)
y_test: (10000, 10)
[Figure 1: MNIST dataset, the first five training images of each digit 0 to 9]
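The plotting code relies on two NumPy idioms: np.argmax(..., 1) turns each one-hot label row back into its class index, and a boolean mask such as classes == i selects every image of one class in its original order. It also relies on plt.subplot numbering cells from 1 at the top left, moving right along each row. The sketch below checks all three on toy data; the arrays and the helper subplot_position are hypothetical stand-ins, not part of the original script.

```python
import numpy as np

# Hypothetical toy data: 6 "images" of 4 pixels each, with integer labels.
labels = np.array([3, 0, 3, 1, 0, 3])
y_onehot = np.eye(10)[labels]       # one-hot rows, like one_hot=True produces
x = np.arange(6 * 4).reshape(6, 4)  # row k stands in for image k

# argmax along axis 1 recovers the integer class of each one-hot row.
classes = np.argmax(y_onehot, 1)

# A boolean mask keeps only the rows of class 3, preserving their order.
idxs = classes == 3
first_two = x[idxs][:2]


def subplot_position(i, j, ncols=10):
    """Hypothetical helper: 1-based subplot index for column i (digit) and row j (sample)."""
    return i + j * ncols + 1
```

Here classes is [3, 0, 3, 1, 0, 3], and x[idxs][:2] returns rows 0 and 2 of x. subplot_position(0, 1) is 11, the first cell of the second row, which is why plot_mnist places each digit in its own column and stacks its samples downward.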
