【tensorflow】Cifar10卷积神经网络实时训练过程

在当前执行环境下,可能出现如下报错:
AttributeError: module 'tensorflow.python.ops.image_ops' has no attribute 'per_image_whitening'
此时需要将 cifar10_input.py 第 182 行的 per_image_whitening 改为 per_image_standardization

cifar10_input.py 和 cifar10.py 的下载地址

#coding=utf-8
import cifar10,cifar10_input
from matplotlib import pyplot as plt
import tensorflow as tf
import numpy as np
import time
import math

max_steps = 4500
batch_size = 128

data_dir = './cifar10_data/cifar-10-batches-bin'
#下载好的数据集所在的文件夹

def variable_with_weight_loss(shape, stddev, wl):
    """Create a weight Variable initialised from a truncated normal.

    When ``wl`` is not None, an L2 weight-decay term scaled by ``wl`` is
    registered in the 'losses' collection, so it is later summed into the
    total loss by ``loss()``.
    """
    var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
    if wl is None:
        return var
    # wl controls how strongly this layer's L2 penalty weighs in the total loss.
    penalty = tf.multiply(tf.nn.l2_loss(var), wl, name='weight_loss')
    tf.add_to_collection('losses', penalty)
    return var


def loss(logits, labels):
    """Return total loss: mean softmax cross-entropy + all L2 penalties.

    logits: [batch_size, num_classes] raw network outputs.
    labels: [batch_size] integer class ids (one per example).
    """
    labels = tf.cast(labels, tf.int64)
    # sparse_softmax_cross_entropy_with_logits softmaxes the logits internally
    # and returns one cross-entropy value per example.
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels, name='cross_entropy_per_example')
    # Average the per-example losses over the batch.
    mean_ce = tf.reduce_mean(per_example, name='cross_entropy')
    tf.add_to_collection('losses', mean_ce)
    # Sum cross-entropy with every weight-decay term registered in 'losses'.
    return tf.add_n(tf.get_collection('losses'), name='total_loss')
  

# cifar10.maybe_download_and_extract()
# Uncomment the line above if the dataset has not been downloaded yet;
# when data_dir already contains the extracted files, the download is skipped.

# Training pipeline: randomly distorted (cropped/flipped) 24x24x3 image batches.
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir,
                                                            batch_size=batch_size)

# Evaluation pipeline: center-cropped 24x24x3 image batches from the test split.
images_test, labels_test = cifar10_input.inputs(eval_data=True,
                                                data_dir=data_dir,
                                                batch_size=batch_size)                                                  
#images_train, labels_train = cifar10.distorted_inputs()
#images_test, labels_test = cifar10.inputs(eval_data=True)

# Placeholders fed with one batch of images/labels per sess.run.
image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])
# Original images are 32x32x3; distorted_inputs randomly crops/flips to 24x24x3,
# while inputs (test set) only center-crops to 24x24x3.
#logits = inference(image_holder)

######################### Layer 1: convolution ############################
# 5x5 conv, 3 -> 64 channels; no weight decay on conv weights (wl=0.0).
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2, wl=0.0)
kernel1 = tf.nn.conv2d(image_holder, weight1, [1, 1, 1, 1], padding='SAME')
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))  # output: 24*24*64
#####################################################################

######################### Layer 2: max-pool + LRN #########################
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],padding='SAME')
# 3x3 window, stride 2 -> output 12*12*64
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75) 
# Local response normalization; batch normalization is the modern replacement.
######################################################################

######################### Layer 3: convolution ############################
weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2, wl=0.0)
kernel2 = tf.nn.conv2d(norm1, weight2, [1, 1, 1, 1], padding='SAME')
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))# output: 12*12*64
###########################################################################

######################### Layer 4: LRN + max-pool #########################
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],padding='SAME')
# output: 6*6*64
###########################################################################

# Flatten the conv output to [batch_size, dim] for the dense layers.
reshape = tf.reshape(pool2, [batch_size, -1])
dim = reshape.get_shape()[1].value # flattened size: 2304 = 6*6*64

######################### Layer 5: fully connected ########################
# Dense layers carry L2 weight decay (wl=0.004).
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, wl=0.004)
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)
#########################################################################


######################### Layer 6: fully connected ########################
weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, wl=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))                                      
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)
############################################################################

######################### Layer 7: linear output ##########################
# Raw logits for the 10 CIFAR classes; softmax is applied inside the loss.
weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1/192.0, wl=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)
############################################################################


# BUG FIX: the original did `loss = loss(...)`, shadowing the loss() function
# with a tensor; use a distinct name for the graph node.
total_loss = loss(logits, label_holder)
train_op = tf.train.AdamOptimizer(1e-3).minimize(total_loss)
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)  # per-example top-1 hit flags

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# Start the input-pipeline threads that keep the example queues filled.
tf.train.start_queue_runners()

## Live plotting: left axes = precision curves, right axes = loss curves.
fig = plt.figure()
ax0 = fig.add_subplot(1, 2, 1)
ax1 = fig.add_subplot(1, 2, 2)
plt.ion()
ax0.set_xlim(0, 150)
ax0.set_ylim(0, 1)

ax1.set_xlim(0, 150)
ax1.set_ylim(0, 2)

train_loss_all = []
test_loss_all = []
train_precision_all = []
test_precision_all = []

for e in range(max_steps):

    image_batch, label_batch = sess.run([images_train, labels_train])
    # One training step; also collect this batch's top-1 hits and loss.
    predictions, loss_train, _ = sess.run(
        [top_k_op, total_loss, train_op],
        feed_dict={image_holder: image_batch, label_holder: label_batch})
    if e % 30 == 0:
        precision_train = np.sum(predictions) / batch_size
        num_examples = 10000  # size of the CIFAR-10 test split
        num_iter = int(math.ceil(num_examples / batch_size))
        true_count = 0
        total_sample_count = num_iter * batch_size
        step = 0
        loss_value_list = []

        ## Evaluation pass over the whole test set.
        # BUG FIX: this loop was indented with mixed tabs/spaces; now spaces only.
        while step < num_iter:
            image_batch, label_batch = sess.run([images_test, labels_test])
            predictions, loss_value = sess.run(
                [top_k_op, total_loss],
                feed_dict={image_holder: image_batch, label_holder: label_batch})
            true_count += np.sum(predictions)
            loss_value_list.append(loss_value)
            step += 1

        precision_test = true_count / total_sample_count
        loss_test = sum(loss_value_list) / len(loss_value_list)
        print(str(e) + ",train_precision: " + str(precision_train)[:5] +  ",train_loss: " + str(loss_train)[:5]  + "||" + "test_precision: " +  str(precision_test)[:5] + ",test_loss: " + str(loss_test)[:5])

        train_loss_all.append(loss_train)
        test_loss_all.append(loss_test)
        # BUG FIX: train_precision_all was plotted but never appended, so the
        # train-precision curve stayed empty forever.
        train_precision_all.append(precision_train)
        test_precision_all.append(precision_test)
        # BUG FIX: `ax0.lines.remove(a, b)` passed two args to list.remove
        # (always a TypeError, silently swallowed), and ax0_lines/ax1_lines
        # were overwritten by the second plot() call, so old curves were
        # never removed and accumulated on the axes. Remove each Line2D
        # explicitly; NameError only fires on the first redraw.
        try:
            for old_line in ax0_lines + ax1_lines:
                old_line.remove()
        except NameError:
            pass  # first redraw: nothing to remove yet
        ax0_lines = (
            ax0.plot(range(len(train_precision_all)), train_precision_all, c='b', label="train_precision")
            + ax0.plot(range(len(test_precision_all)), test_precision_all, c='r', label="test_precision"))
        ax1_lines = (
            ax1.plot(range(len(train_loss_all)), train_loss_all, c='b', label="train_loss")
            + ax1.plot(range(len(test_loss_all)), test_loss_all, c='r', label="test_loss"))
        plt.pause(1)


plt.savefig("./fig.png")

实时显示的图片:
【tensorflow】Cifar10卷积神经网络实时训练过程_第1张图片
依次是 准确率 和 损失曲线。蓝色是训练集,红色的是测试集。
这里测试集每次测试所有样本,而训练集只取当前batch的结果,所以每次损失值相对震荡。
在实际应用过程中,当训练集样本数量不是很大时候,将测试集和训练集一样全部算一下损失,然后取平均。当训练集样本数量很大时候,我们在训练集中随机选择一部分样本,计算平均损失。

终端输出:

0,train_precision: 0.125,train_loss: 4.678||test_precision: 0.103,test_loss: 4.584
30,train_precision: 0.265,train_loss: 2.833||test_precision: 0.280,test_loss: 2.690
60,train_precision: 0.273,train_loss: 2.203||test_precision: 0.319,test_loss: 2.135
90,train_precision: 0.312,train_loss: 2.031||test_precision: 0.393,test_loss: 1.881
120,train_precision: 0.304,train_loss: 1.994||test_precision: 0.374,test_loss: 1.824
150,train_precision: 0.335,train_loss: 1.968||test_precision: 0.396,test_loss: 1.787
180,train_precision: 0.476,train_loss: 1.577||test_precision: 0.438,test_loss: 1.645
210,train_precision: 0.382,train_loss: 1.681||test_precision: 0.470,test_loss: 1.585
240,train_precision: 0.445,train_loss: 1.737||test_precision: 0.445,test_loss: 1.604
270,train_precision: 0.406,train_loss: 1.688||test_precision: 0.460,test_loss: 1.593
300,train_precision: 0.421,train_loss: 1.575||test_precision: 0.508,test_loss: 1.478
330,train_precision: 0.484,train_loss: 1.530||test_precision: 0.494,test_loss: 1.502
360,train_precision: 0.507,train_loss: 1.585||test_precision: 0.516,test_loss: 1.470
390,train_precision: 0.476,train_loss: 1.531||test_precision: 0.542,test_loss: 1.399
420,train_precision: 0.453,train_loss: 1.632||test_precision: 0.542,test_loss: 1.383
450,train_precision: 0.546,train_loss: 1.356||test_precision: 0.561,test_loss: 1.339
480,train_precision: 0.421,train_loss: 1.611||test_precision: 0.536,test_loss: 1.378
510,train_precision: 0.507,train_loss: 1.351||test_precision: 0.567,test_loss: 1.323
540,train_precision: 0.515,train_loss: 1.505||test_precision: 0.561,test_loss: 1.323
570,train_precision: 0.515,train_loss: 1.504||test_precision: 0.584,test_loss: 1.301
600,train_precision: 0.453,train_loss: 1.473||test_precision: 0.545,test_loss: 1.351
630,train_precision: 0.578,train_loss: 1.363||test_precision: 0.584,test_loss: 1.286
660,train_precision: 0.507,train_loss: 1.456||test_precision: 0.551,test_loss: 1.355
690,train_precision: 0.562,train_loss: 1.269||test_precision: 0.609,test_loss: 1.226
720,train_precision: 0.601,train_loss: 1.167||test_precision: 0.597,test_loss: 1.246
750,train_precision: 0.609,train_loss: 1.341||test_precision: 0.614,test_loss: 1.196
780,train_precision: 0.570,train_loss: 1.298||test_precision: 0.612,test_loss: 1.235
810,train_precision: 0.507,train_loss: 1.393||test_precision: 0.597,test_loss: 1.247
840,train_precision: 0.546,train_loss: 1.283||test_precision: 0.633,test_loss: 1.173
870,train_precision: 0.625,train_loss: 1.159||test_precision: 0.631,test_loss: 1.164
900,train_precision: 0.546,train_loss: 1.338||test_precision: 0.624,test_loss: 1.180
930,train_precision: 0.617,train_loss: 1.184||test_precision: 0.633,test_loss: 1.181
960,train_precision: 0.601,train_loss: 1.130||test_precision: 0.651,test_loss: 1.132
990,train_precision: 0.632,train_loss: 1.181||test_precision: 0.639,test_loss: 1.139
1020,train_precision: 0.562,train_loss: 1.242||test_precision: 0.633,test_loss: 1.158
1050,train_precision: 0.671,train_loss: 1.249||test_precision: 0.649,test_loss: 1.115
1080,train_precision: 0.632,train_loss: 1.162||test_precision: 0.623,test_loss: 1.195
1110,train_precision: 0.562,train_loss: 1.280||test_precision: 0.633,test_loss: 1.155
1140,train_precision: 0.695,train_loss: 1.183||test_precision: 0.659,test_loss: 1.106
1170,train_precision: 0.531,train_loss: 1.448||test_precision: 0.634,test_loss: 1.190
1200,train_precision: 0.648,train_loss: 1.114||test_precision: 0.644,test_loss: 1.126
1230,train_precision: 0.648,train_loss: 1.183||test_precision: 0.657,test_loss: 1.104
1260,train_precision: 0.585,train_loss: 1.225||test_precision: 0.668,test_loss: 1.094
1290,train_precision: 0.562,train_loss: 1.250||test_precision: 0.662,test_loss: 1.097
1320,train_precision: 0.585,train_loss: 1.284||test_precision: 0.656,test_loss: 1.114
1350,train_precision: 0.656,train_loss: 1.158||test_precision: 0.676,test_loss: 1.069
1380,train_precision: 0.734,train_loss: 0.980||test_precision: 0.655,test_loss: 1.100
1410,train_precision: 0.617,train_loss: 1.299||test_precision: 0.648,test_loss: 1.113
1440,train_precision: 0.546,train_loss: 1.308||test_precision: 0.653,test_loss: 1.098
1470,train_precision: 0.625,train_loss: 1.097||test_precision: 0.661,test_loss: 1.088
1500,train_precision: 0.601,train_loss: 1.261||test_precision: 0.662,test_loss: 1.083
1530,train_precision: 0.703,train_loss: 1.064||test_precision: 0.651,test_loss: 1.122
1560,train_precision: 0.601,train_loss: 1.201||test_precision: 0.679,test_loss: 1.039
1590,train_precision: 0.585,train_loss: 1.263||test_precision: 0.675,test_loss: 1.066
1620,train_precision: 0.710,train_loss: 0.965||test_precision: 0.684,test_loss: 1.032
1650,train_precision: 0.679,train_loss: 1.052||test_precision: 0.690,test_loss: 1.038
1680,train_precision: 0.679,train_loss: 1.174||test_precision: 0.665,test_loss: 1.083
1710,train_precision: 0.539,train_loss: 1.401||test_precision: 0.671,test_loss: 1.064
1740,train_precision: 0.585,train_loss: 1.295||test_precision: 0.684,test_loss: 1.030
1770,train_precision: 0.640,train_loss: 1.150||test_precision: 0.672,test_loss: 1.049
1800,train_precision: 0.703,train_loss: 1.042||test_precision: 0.678,test_loss: 1.054
1830,train_precision: 0.617,train_loss: 1.293||test_precision: 0.657,test_loss: 1.085
1860,train_precision: 0.679,train_loss: 0.986||test_precision: 0.687,test_loss: 1.023
1890,train_precision: 0.710,train_loss: 1.037||test_precision: 0.684,test_loss: 1.022
1920,train_precision: 0.664,train_loss: 1.048||test_precision: 0.682,test_loss: 1.031
1950,train_precision: 0.656,train_loss: 0.983||test_precision: 0.690,test_loss: 1.020
1980,train_precision: 0.632,train_loss: 1.106||test_precision: 0.693,test_loss: 1.006
2010,train_precision: 0.617,train_loss: 1.110||test_precision: 0.684,test_loss: 1.035
2040,train_precision: 0.609,train_loss: 1.160||test_precision: 0.679,test_loss: 1.032
2070,train_precision: 0.609,train_loss: 1.160||test_precision: 0.681,test_loss: 1.031
2100,train_precision: 0.656,train_loss: 1.055||test_precision: 0.666,test_loss: 1.054
2130,train_precision: 0.648,train_loss: 1.114||test_precision: 0.686,test_loss: 1.027
2160,train_precision: 0.718,train_loss: 1.033||test_precision: 0.699,test_loss: 1.004
2190,train_precision: 0.648,train_loss: 1.079||test_precision: 0.694,test_loss: 1.017
2220,train_precision: 0.656,train_loss: 1.195||test_precision: 0.688,test_loss: 1.021
2250,train_precision: 0.726,train_loss: 1.008||test_precision: 0.688,test_loss: 1.018
2280,train_precision: 0.609,train_loss: 1.126||test_precision: 0.699,test_loss: 0.986
2310,train_precision: 0.625,train_loss: 1.177||test_precision: 0.695,test_loss: 0.987
2340,train_precision: 0.718,train_loss: 0.972||test_precision: 0.706,test_loss: 0.965
2370,train_precision: 0.656,train_loss: 1.059||test_precision: 0.690,test_loss: 1.018
2400,train_precision: 0.664,train_loss: 1.103||test_precision: 0.693,test_loss: 1.005
2430,train_precision: 0.601,train_loss: 1.297||test_precision: 0.704,test_loss: 0.974
2460,train_precision: 0.679,train_loss: 0.966||test_precision: 0.701,test_loss: 0.977
2490,train_precision: 0.671,train_loss: 1.135||test_precision: 0.690,test_loss: 0.999
2520,train_precision: 0.687,train_loss: 1.054||test_precision: 0.683,test_loss: 1.019
2550,train_precision: 0.625,train_loss: 1.187||test_precision: 0.713,test_loss: 0.969
2580,train_precision: 0.703,train_loss: 0.999||test_precision: 0.710,test_loss: 0.949
2610,train_precision: 0.656,train_loss: 1.196||test_precision: 0.699,test_loss: 0.987
2640,train_precision: 0.687,train_loss: 0.954||test_precision: 0.707,test_loss: 0.956
2670,train_precision: 0.664,train_loss: 1.092||test_precision: 0.699,test_loss: 0.964
2700,train_precision: 0.664,train_loss: 1.003||test_precision: 0.713,test_loss: 0.934
2730,train_precision: 0.679,train_loss: 0.945||test_precision: 0.707,test_loss: 0.954
2760,train_precision: 0.664,train_loss: 1.086||test_precision: 0.699,test_loss: 0.988
2790,train_precision: 0.796,train_loss: 0.892||test_precision: 0.724,test_loss: 0.922
2820,train_precision: 0.695,train_loss: 0.996||test_precision: 0.695,test_loss: 1.016
2850,train_precision: 0.703,train_loss: 0.937||test_precision: 0.704,test_loss: 0.974
2880,train_precision: 0.679,train_loss: 0.950||test_precision: 0.712,test_loss: 0.952
2910,train_precision: 0.734,train_loss: 0.897||test_precision: 0.715,test_loss: 0.934
2940,train_precision: 0.625,train_loss: 1.174||test_precision: 0.712,test_loss: 0.934
2970,train_precision: 0.734,train_loss: 0.929||test_precision: 0.697,test_loss: 0.978
3000,train_precision: 0.625,train_loss: 1.195||test_precision: 0.715,test_loss: 0.943
3030,train_precision: 0.625,train_loss: 1.081||test_precision: 0.716,test_loss: 0.934
3060,train_precision: 0.718,train_loss: 0.970||test_precision: 0.713,test_loss: 0.932
3090,train_precision: 0.648,train_loss: 1.111||test_precision: 0.718,test_loss: 0.948
3120,train_precision: 0.710,train_loss: 0.997||test_precision: 0.716,test_loss: 0.934
3150,train_precision: 0.726,train_loss: 0.972||test_precision: 0.696,test_loss: 0.999
3180,train_precision: 0.640,train_loss: 1.202||test_precision: 0.710,test_loss: 0.951
3210,train_precision: 0.640,train_loss: 1.100||test_precision: 0.712,test_loss: 0.954
3240,train_precision: 0.648,train_loss: 1.066||test_precision: 0.721,test_loss: 0.922
3270,train_precision: 0.679,train_loss: 1.181||test_precision: 0.729,test_loss: 0.906
3300,train_precision: 0.710,train_loss: 0.966||test_precision: 0.685,test_loss: 1.028
3330,train_precision: 0.625,train_loss: 1.166||test_precision: 0.710,test_loss: 0.964
3360,train_precision: 0.742,train_loss: 0.957||test_precision: 0.714,test_loss: 0.927
3390,train_precision: 0.648,train_loss: 1.099||test_precision: 0.712,test_loss: 0.956
3420,train_precision: 0.765,train_loss: 0.903||test_precision: 0.721,test_loss: 0.930
3450,train_precision: 0.710,train_loss: 0.989||test_precision: 0.733,test_loss: 0.899
3480,train_precision: 0.710,train_loss: 1.057||test_precision: 0.712,test_loss: 0.930
3510,train_precision: 0.679,train_loss: 1.106||test_precision: 0.719,test_loss: 0.926
3540,train_precision: 0.703,train_loss: 1.031||test_precision: 0.729,test_loss: 0.895
3570,train_precision: 0.687,train_loss: 1.005||test_precision: 0.710,test_loss: 0.947
3600,train_precision: 0.601,train_loss: 1.099||test_precision: 0.714,test_loss: 0.930
3630,train_precision: 0.773,train_loss: 0.803||test_precision: 0.725,test_loss: 0.907
3660,train_precision: 0.656,train_loss: 0.996||test_precision: 0.726,test_loss: 0.903
3690,train_precision: 0.679,train_loss: 0.980||test_precision: 0.718,test_loss: 0.933
3720,train_precision: 0.593,train_loss: 1.234||test_precision: 0.725,test_loss: 0.912
3750,train_precision: 0.718,train_loss: 0.974||test_precision: 0.716,test_loss: 0.914
3780,train_precision: 0.765,train_loss: 0.903||test_precision: 0.730,test_loss: 0.894
3810,train_precision: 0.671,train_loss: 0.941||test_precision: 0.724,test_loss: 0.898
3840,train_precision: 0.640,train_loss: 1.219||test_precision: 0.715,test_loss: 0.911
3870,train_precision: 0.726,train_loss: 0.885||test_precision: 0.717,test_loss: 0.910
3900,train_precision: 0.75,train_loss: 0.879||test_precision: 0.733,test_loss: 0.876
3930,train_precision: 0.671,train_loss: 0.987||test_precision: 0.722,test_loss: 0.916
3960,train_precision: 0.710,train_loss: 1.028||test_precision: 0.726,test_loss: 0.911
3990,train_precision: 0.742,train_loss: 0.987||test_precision: 0.740,test_loss: 0.883
4020,train_precision: 0.632,train_loss: 1.177||test_precision: 0.722,test_loss: 0.912
4050,train_precision: 0.703,train_loss: 0.973||test_precision: 0.728,test_loss: 0.890
4080,train_precision: 0.718,train_loss: 1.012||test_precision: 0.701,test_loss: 0.951
4110,train_precision: 0.703,train_loss: 0.911||test_precision: 0.733,test_loss: 0.900
4140,train_precision: 0.671,train_loss: 1.006||test_precision: 0.727,test_loss: 0.889
4170,train_precision: 0.742,train_loss: 1.010||test_precision: 0.719,test_loss: 0.919
4200,train_precision: 0.632,train_loss: 1.026||test_precision: 0.738,test_loss: 0.863
4230,train_precision: 0.687,train_loss: 0.851||test_precision: 0.693,test_loss: 0.973
4260,train_precision: 0.695,train_loss: 1.070||test_precision: 0.732,test_loss: 0.893
4290,train_precision: 0.648,train_loss: 1.024||test_precision: 0.714,test_loss: 0.923
4320,train_precision: 0.625,train_loss: 1.160||test_precision: 0.727,test_loss: 0.889
4350,train_precision: 0.726,train_loss: 0.920||test_precision: 0.732,test_loss: 0.873
4380,train_precision: 0.765,train_loss: 0.831||test_precision: 0.705,test_loss: 0.954
4410,train_precision: 0.640,train_loss: 1.178||test_precision: 0.730,test_loss: 0.901
4440,train_precision: 0.765,train_loss: 0.762||test_precision: 0.741,test_loss: 0.859
4470,train_precision: 0.664,train_loss: 1.128||test_precision: 0.730,test_loss: 0.886

最好的结果是准确率0.741,结合训练集来看当前训练过程或者网络架构欠拟合。

你可能感兴趣的:(深度学习)