深度学习入门——mini_batch小批量数据提取

  • 在全部数据中提取出小批量的数据,作为全部数据的近似。

  • 神经网络的学习也就是针对每个mini_batch数据进行学习

# oding:utf-8

import sys, os
sys.path.append(os.pardir)
import numpy as np
from dataset.mnist import load_mnist


(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)

print(x_train.shape) # (6000. 784) 6000个数据,784维
print(t_train.shape) # (6000, 10) 6000个数据,10维

# --------------------------------抽取小批量的数据---------------------------------------------------
# 抽取小批量的数据
train_size = x_train.shape[0]
batch_size = 10 # 抽10个
batch_mask = np.random.choice(train_size, batch_size) # 从6000个数据中随机抽取10个 获得其索引

x_batch = x_train[batch_mask] # 通过索引取出该值
t_batch = t_train[batch_mask] # 通过索引去除该监督值

主要采用 np.random.choice(train_size, batch_size) # 从6000个数据中随机抽取10个 获得其索引

你可能感兴趣的:(深度学习入门)