Recognizing the CIFAR-10 dataset with TensorFlow

Source code: https://www.kaggle.com/skooch/cifar-10-in-tensorflow

Reading the pickled CIFAR-10 dataset:

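import pickle

def unpickle(file):
    # Load one CIFAR-10 batch file (python version); keys come back as bytes, e.g. b'data'
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')  # 'batch' avoids shadowing the built-in dict
    return batch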

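The function returns a dictionary holding the image data, the labels and other metadata, each accessed by its key. Because the file is loaded with encoding='bytes', the keys are byte strings such as b'data' and b'labels'.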
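To show what the returned dictionary looks like without downloading the dataset, the sketch below pickles a tiny fake batch in the same layout (bytes keys, a uint8 data matrix with 3072 columns, a list of labels) and reads it back with the same loader. The fake contents, the 5-image size and the file name data_batch_fake are assumptions for illustration only; real batches contain 10000 images.

```python
import os
import pickle
import tempfile

import numpy as np


def unpickle(file):
    # Same loader as above: encoding='bytes' keeps Python 2 strings as bytes
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')


# Hypothetical fake batch mimicking the CIFAR-10 python format:
# bytes keys, 5 random images instead of 10000.
rng = np.random.default_rng(0)
fake = {
    b'batch_label': b'fake batch',
    b'labels': [3, 8, 8, 0, 6],
    b'data': rng.integers(0, 256, size=(5, 3072), dtype=np.uint8),
    b'filenames': [b'img_%d.png' % i for i in range(5)],
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'data_batch_fake')  # hypothetical file name
    with open(path, 'wb') as fo:
        pickle.dump(fake, fo)
    batch = unpickle(path)

print(sorted(batch.keys()))
print(batch[b'data'].shape, batch[b'labels'])
```

With the real files, the same key lookups (batch[b'data'], batch[b'labels']) apply to each data_batch_N file.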

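In each data_batch, the b'data' entry is a 10000 × 3072 uint8 matrix: one row per image, with the 32*32*3 = 3072 values stored channel by channel (1024 red, then 1024 green, then 1024 blue). A single row converts to the usual height × width × channel layout with row.reshape([3,32,32]).transpose(1,2,0); the whole batch converts with data.reshape(-1,3,32,32).transpose(0,2,3,1).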
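A quick way to check the reshape is to fill each color plane with a known value and confirm where the values land. The sketch below uses a hypothetical 2-image batch built with NumPy rather than real CIFAR-10 data.

```python
import numpy as np

# Hypothetical 2-image batch in the CIFAR-10 row layout:
# each row = 1024 red values, then 1024 green, then 1024 blue.
n = 2
data = np.zeros((n, 3072), dtype=np.uint8)
data[:, :1024] = 10        # red plane
data[:, 1024:2048] = 20    # green plane
data[:, 2048:] = 30        # blue plane
data[1, 0] = 255           # image 1, red channel, pixel at row 0, column 0

# One image: (3072,) -> (3, 32, 32) channel-first -> (32, 32, 3) channel-last
img = data[1].reshape([3, 32, 32]).transpose(1, 2, 0)

# Whole batch: (N, 3072) -> (N, 3, 32, 32) -> (N, 32, 32, 3)
images = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

print(img.shape, images.shape)
print(images[1, 0, 0], images[0, 5, 7])
```

The pixel at images[1, 0, 0] carries the 255 written into the red plane together with the green and blue plane values, which confirms the channel-last conversion.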

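import tensorflow as tf  # TensorFlow 1.x API (tf.train, tf.layers and tf.contrib are not in TF 2.x)

global_step = tf.Variable(0, trainable=False)  # step counter; advances when passed to optimizer.minimize(..., global_step=global_step)

learning_rate = tf.train.exponential_decay(
    starting_rate,                        # initial rate: 0.003
    global_step,
    steps_per_epoch * epochs_per_decay,   # decay_steps: decay every 100 epochs
    decay_factor,                         # decay_rate: multiply by 0.5
    staircase=staircase)                  # if True, decay in discrete jumps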
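The schedule follows the formula documented for tf.train.exponential_decay: rate = starting_rate * decay_factor ** (global_step / decay_steps), with integer division when staircase=True. The plain-Python sketch below restates that formula so the values can be checked by hand. steps_per_epoch = 390 is a hypothetical figure (50000 training images with an assumed batch size of 128); the source does not state it.

```python
def exponential_decay(starting_rate, global_step, decay_steps, decay_factor, staircase=False):
    # rate = starting_rate * decay_factor ** (global_step / decay_steps)
    if staircase:
        exponent = global_step // decay_steps  # discrete jumps every decay_steps
    else:
        exponent = global_step / decay_steps   # smooth decay
    return starting_rate * decay_factor ** exponent


steps_per_epoch = 390        # hypothetical: 50000 images // batch size 128
epochs_per_decay = 100       # "100 epochs" from the comment above
decay_steps = steps_per_epoch * epochs_per_decay

for epoch in (0, 50, 100, 150, 200):
    step = epoch * steps_per_epoch
    print(epoch,
          exponential_decay(0.003, step, decay_steps, 0.5, staircase=True),
          exponential_decay(0.003, step, decay_steps, 0.5, staircase=False))
```

With staircase=True the rate stays at 0.003 for the first 100 epochs, then drops to 0.0015; without it the rate shrinks a little on every step.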
conv1 = tf.layers.conv2d(
    X,                           # input images, [batch, height, width, channels]
    filters=64,                  # 64 output channels
    kernel_size=(5, 5),          # 5x5 kernels
    strides=(1, 1),              # stride 1
    padding='SAME',              # zero-pad so the output stays 32x32
    activation=None,             # no activation here: ReLU is applied after batch norm
    kernel_initializer=tf.truncated_normal_initializer(stddev=5e-2, seed=10),
    kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=lamC),  # L2 weight penalty
    name='conv1'
)
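Two numbers are worth checking for this layer: its output shape and what the L2 regularizer contributes. With padding='SAME', TensorFlow produces ceil(input / stride) outputs per spatial dimension, and tf.contrib.layers.l2_regularizer(scale) yields scale * tf.nn.l2_loss(w), i.e. scale * sum(w**2) / 2. In TF 1.x that term is placed in the regularization-loss collection and still has to be added to the training loss (for example via tf.losses.get_regularization_loss()). The NumPy sketch below assumes 32x32x3 inputs; the toy kernel and lamC = 1e-3 are hypothetical, since the source gives no value for lamC.

```python
import math

import numpy as np


def same_padding_output(size, stride):
    # padding='SAME': output size = ceil(input size / stride)
    return math.ceil(size / stride)


def l2_penalty(w, scale):
    # scale * tf.nn.l2_loss(w), where l2_loss(w) = sum(w ** 2) / 2
    return scale * np.sum(w ** 2) / 2


in_h = in_w = 32                 # assumed CIFAR-10 input size
in_c, filters, k = 3, 64, 5      # conv1 settings from above
out_h = same_padding_output(in_h, 1)
out_w = same_padding_output(in_w, 1)
n_params = k * k * in_c * filters + filters  # kernel weights + one bias per filter

w = np.full((k, k, in_c, filters), 0.01)     # toy kernel, not trained weights
lamC = 1e-3                                  # hypothetical regularization scale
print(out_h, out_w, filters, n_params, l2_penalty(w, lamC))
```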
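# Batch normalization over the channel axis (channels last)
bn1 = tf.layers.batch_normalization(
    conv1,
    axis=-1,                     # normalize each channel separately
    momentum=0.99,               # decay for the moving mean/variance
    epsilon=epsilon,             # small constant added to the variance
    center=True,                 # learn an offset (beta)
    scale=True,                  # learn a scale (gamma)
    beta_initializer=tf.zeros_initializer(),
    gamma_initializer=tf.ones_initializer(),
    moving_mean_initializer=tf.zeros_initializer(),
    moving_variance_initializer=tf.ones_initializer(),
    training=training,           # batch statistics if True, moving averages if False
    name='bn1'
)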
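The layer above normalizes each channel with the batch mean and variance during training and keeps exponential moving averages (momentum=0.99) for inference. In TF 1.x those moving-average updates are placed in tf.GraphKeys.UPDATE_OPS and must be run together with the training op, commonly through tf.control_dependencies. The NumPy sketch below is a simplified restatement, not TensorFlow's implementation: it assumes epsilon = 1e-3 (the source leaves epsilon as a variable) and ignores details such as whether the variance fed into the moving average is bias-corrected.

```python
import numpy as np


def batch_norm_train(x, gamma, beta, moving_mean, moving_var, momentum=0.99, epsilon=1e-3):
    # x: [batch, height, width, channels]; statistics per channel (axis=-1)
    mean = x.mean(axis=(0, 1, 2))
    var = x.var(axis=(0, 1, 2))
    y = gamma * (x - mean) / np.sqrt(var + epsilon) + beta
    # Moving averages, later used when training=False
    new_mean = momentum * moving_mean + (1 - momentum) * mean
    new_var = momentum * moving_var + (1 - momentum) * var
    return y, new_mean, new_var


def batch_norm_infer(x, gamma, beta, moving_mean, moving_var, epsilon=1e-3):
    # Inference: normalize with the stored moving statistics instead of the batch's
    return gamma * (x - moving_mean) / np.sqrt(moving_var + epsilon) + beta


rng = np.random.default_rng(0)
x = rng.normal(5.0, 2.0, size=(8, 4, 4, 3))        # stand-in for conv1's output
gamma, beta = np.ones(3), np.zeros(3)              # matches the initializers above
y, mm, mv = batch_norm_train(x, gamma, beta, np.zeros(3), np.ones(3))
print(y.mean(axis=(0, 1, 2)), y.var(axis=(0, 1, 2)), mm)
```

Because momentum is 0.99, the moving mean moves only 1% of the way toward each batch mean, which is why the moving statistics need many training steps before inference results become reliable.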

# Apply ReLU after batch norm
conv1_bn_relu = tf.nn.relu(bn1, name='relu1')

# Drop 10% of the activations, during training only
conv1_bn_relu = tf.layers.dropout(conv1_bn_relu, rate=0.1, seed=9, training=training)
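The last two steps clip negative activations to zero and then apply dropout. TensorFlow's dropout is the "inverted" variant: during training each value is zeroed with probability rate and the survivors are scaled by 1 / (1 - rate), while with training=False the input passes through unchanged. The NumPy sketch below mirrors that behaviour; the random stand-in for bn1's output is an assumption for illustration.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0)


def dropout(x, rate, training, rng):
    # Inverted dropout: zero each value with probability `rate` and scale
    # the kept values by 1 / (1 - rate); identity at inference time.
    if not training:
        return x
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1 - rate), 0.0)


rng = np.random.default_rng(9)
bn_out = rng.normal(size=(2, 32, 32, 64))   # stand-in for bn1's output
act = relu(bn_out)
train_out = dropout(act, rate=0.1, training=True, rng=rng)
eval_out = dropout(act, rate=0.1, training=False, rng=rng)
print(act.mean(), train_out.mean(), eval_out.mean())
```

Scaling at training time keeps the expected activation the same in both modes, so no rescaling is needed at inference.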
