Understanding the Capsule Network Spread Loss and accuracy_eval (2018-06-16)

1. Loss function:

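The loss used here is the spread loss from the Matrix Capsules with EM Routing paper: given the activation $a_t$ of the target class $t$, each wrong class $i$ is penalized whenever its activation $a_i$ is not at least a margin $m$ below $a_t$:

$$L_i = \bigl(\max(0,\; m - (a_t - a_i))\bigr)^2, \qquad L = \sum_{i \neq t} L_i$$

The code below walks through this computation step by step, using arbitrary one-hot CIFAR-10 label vectors in place of real activations.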
import tensorflow as tf
from include.cifar_10_data import get_data_set

with tf.device('/cpu:0'):
    # Just to observe the computation: arbitrarily take two groups of 10
    # one-hot CIFAR-10 labels and use them as `label` and `activation`.
    images, labels = get_data_set(name="test")

    # The CIFAR-10 test set has 10000 labels; split into 1000 batches of 10 each.
    gg = tf.split(labels, num_or_size_splits=1000, axis=0)
    label = gg[33]       # arbitrarily pick batch 33 (10 samples)
    activation = gg[10]  # arbitrarily pick batch 10 (10 samples)

    # Start observing
    sess = tf.Session()
    activations_shape = activation.get_shape().as_list()
    mask_t = tf.equal(label, 1)  # True at the target class of each sample
    mask_i = tf.equal(label, 0)  # True at the non-target classes

    print('label:\r\n', sess.run(label))
    print('activation:\r\n', sess.run(activation))
    print('mask_t:\r\n', sess.run(mask_t))
    print('mask_i:\r\n', sess.run(mask_i))

    print('boolean_mask:\r\n', sess.run(tf.boolean_mask(activation, mask_t)))

    # Activation of the target class, one value per sample: shape [batch, 1]
    activations_t = tf.reshape(
        tf.boolean_mask(activation, mask_t), [activations_shape[0], 1]
    )

    # Activations of the non-target classes: shape [batch, num_classes - 1]
    activations_i = tf.reshape(
        tf.boolean_mask(activation, mask_i),
        [activations_shape[0], activations_shape[1] - 1]
    )

    margin = 0.5  # start at 0.5; raising it to 0.9 later makes the loss larger
    print('activations_i:\r\n', sess.run(activations_i))
    print('activations_t:\r\n', sess.run(activations_t))
    print('activations_t - activations_i:\r\n',
          sess.run(activations_t - activations_i))
    print('margin-(activations_t - activations_i):\r\n',
          sess.run(margin - (activations_t - activations_i)))

    print('tf.nn.relu(margin-(activations_t - activations_i)):\r\n',
          sess.run(tf.nn.relu(margin - (activations_t - activations_i))))
    print('tf.square:\r\n', sess.run(tf.square(
        tf.nn.relu(margin - (activations_t - activations_i))
    )))
    print('tf.reduce_sum:\r\n', sess.run(tf.reduce_sum(
        tf.square(
            tf.nn.relu(margin - (activations_t - activations_i))
        )
    )))

The run produces:

label:
[[ 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[ 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]
activation:
[[ 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[ 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]
mask_t:
[[False True False False False False False False False False]
[False False False False False False False False False True]
[False False False True False False False False False False]
[False False False False False False True False False False]
[False False False False False False True False False False]
[False False False False False False False False False True]
[False False False True False False False False False False]
[False False False False False False False False True False]
[ True False False False False False False False False False]
[False False False False False False False True False False]]
mask_i:
[[ True False True True True True True True True True]
[ True True True True True True True True True False]
[ True True True False True True True True True True]
[ True True True True True True False True True True]
[ True True True True True True False True True True]
[ True True True True True True True True True False]
[ True True True False True True True True True True]
[ True True True True True True True True False True]
[False True True True True True True True True True]
[ True True True True True True True False True True]]
boolean_mask:
[ 0. 0. 0. 0. 0. 0. 1. 0. 0. 1.]
activations_i:
[[ 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[ 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[ 0. 1. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 1. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
activations_t:
[[ 0.]
[ 0.]
[ 0.]
[ 0.]
[ 0.]
[ 0.]
[ 1.]
[ 0.]
[ 0.]
[ 1.]]
activations_t - activations_i:
[[ 0. 0. 0. -1. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. -1. 0. 0. 0.]
[ 0. 0. 0. 0. 0. -1. 0. 0. 0.]
[ 0. 0. 0. -1. 0. 0. 0. 0. 0.]
[ 0. -1. 0. 0. 0. 0. 0. 0. 0.]
[ 0. -1. 0. 0. 0. 0. 0. 0. 0.]
[ 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[ 0. 0. 0. 0. 0. 0. -1. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. -1. 0.]
[ 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
margin-(activations_t - activations_i):
[[ 0.5 0.5 0.5 1.5 0.5 0.5 0.5 0.5 0.5]
[ 0.5 0.5 0.5 0.5 0.5 1.5 0.5 0.5 0.5]
[ 0.5 0.5 0.5 0.5 0.5 1.5 0.5 0.5 0.5]
[ 0.5 0.5 0.5 1.5 0.5 0.5 0.5 0.5 0.5]
[ 0.5 1.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5]
[ 0.5 1.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5]
[-0.5 -0.5 -0.5 -0.5 -0.5 -0.5 -0.5 -0.5 -0.5]
[ 0.5 0.5 0.5 0.5 0.5 0.5 1.5 0.5 0.5]
[ 0.5 0.5 0.5 0.5 0.5 0.5 0.5 1.5 0.5]
[-0.5 -0.5 -0.5 -0.5 -0.5 -0.5 -0.5 -0.5 -0.5]]
tf.nn.relu(margin-(activations_t - activations_i)):
[[ 0.5 0.5 0.5 1.5 0.5 0.5 0.5 0.5 0.5]
[ 0.5 0.5 0.5 0.5 0.5 1.5 0.5 0.5 0.5]
[ 0.5 0.5 0.5 0.5 0.5 1.5 0.5 0.5 0.5]
[ 0.5 0.5 0.5 1.5 0.5 0.5 0.5 0.5 0.5]
[ 0.5 1.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5]
[ 0.5 1.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
[ 0.5 0.5 0.5 0.5 0.5 0.5 1.5 0.5 0.5]
[ 0.5 0.5 0.5 0.5 0.5 0.5 0.5 1.5 0.5]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. ]]
tf.square:
[[ 0.25 0.25 0.25 2.25 0.25 0.25 0.25 0.25 0.25]
[ 0.25 0.25 0.25 0.25 0.25 2.25 0.25 0.25 0.25]
[ 0.25 0.25 0.25 0.25 0.25 2.25 0.25 0.25 0.25]
[ 0.25 0.25 0.25 2.25 0.25 0.25 0.25 0.25 0.25]
[ 0.25 2.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25]
[ 0.25 2.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
[ 0.25 0.25 0.25 0.25 0.25 0.25 2.25 0.25 0.25]
[ 0.25 0.25 0.25 0.25 0.25 0.25 0.25 2.25 0.25]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. ]]
tf.reduce_sum:
34.0
As expected, two of the ten samples are "correct" (their one-hot activation matches the label), and each of the other eight adds 4.25 to the loss: its row of tf.square holds eight entries of 0.5² = 0.25 and one entry of 1.5² = 2.25, so 8 × 0.25 + 2.25 = 4.25, and the total is 8 × 4.25 = 34.
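Putting the steps above together, a minimal spread-loss helper might look like the sketch below. The name `spread_loss` is a hypothetical choice, and the gradual increase of the margin during training is left to the caller:

def spread_loss(labels, activations, margin):
    """Spread loss sketch: sum over wrong classes of max(0, margin - (a_t - a_i))^2.

    labels:      one-hot tensor of shape [batch, num_classes]
    activations: tensor of shape [batch, num_classes]
    """
    shape = activations.get_shape().as_list()
    mask_t = tf.equal(labels, 1)
    mask_i = tf.equal(labels, 0)
    # Target-class activation a_t, shape [batch, 1]
    a_t = tf.reshape(tf.boolean_mask(activations, mask_t), [shape[0], 1])
    # Non-target activations a_i, shape [batch, num_classes - 1]
    a_i = tf.reshape(tf.boolean_mask(activations, mask_i), [shape[0], shape[1] - 1])
    return tf.reduce_sum(tf.square(tf.nn.relu(margin - (a_t - a_i))))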

2. Evaluation function: accuracy_eval


def accuracy_eval(sess, logits, labels):
    predictions = sess.run(logits)
    print('predictions:\r\n', predictions)
    print('tf.argmax(predictions, 1):\r\n', sess.run(tf.argmax(predictions, 1)))

    print('labels:\r\n', sess.run(labels))
    print('tf.argmax(labels, 1):\r\n', sess.run(tf.argmax(labels, 1)))

    # Compare the predicted class index with the true class index per sample.
    true_or_false = tf.equal(tf.argmax(predictions, 1), tf.argmax(labels, 1))
    print('true_or_false:\r\n', sess.run(true_or_false))

    # Accuracy = fraction of samples predicted correctly.
    accuracy = tf.reduce_mean(tf.cast(true_or_false, tf.float32))
    print('accuracy:\r\n', sess.run(accuracy))

    return accuracy

Output:

predictions:
[[ 0.43357164 0.50969642 0.48628038 0.41604722 0.381396 0.37202153
0.40028509 0.52360743 0.33245403 0.59107369]
[ 0.56551605 0.54600668 0.38575536 0.28260168 0.38372964 0.32026786
0.423181 0.37484613 0.64745504 0.53001797]
[ 0.44975427 0.58631718 0.38979253 0.39016432 0.39017558 0.37531838
0.47394231 0.46395907 0.52134717 0.40106338]
[ 0.43395406 0.45158568 0.50714886 0.52360237 0.42100799 0.42440724
0.49690422 0.42216748 0.30997989 0.45479706]
[ 0.56510764 0.5565328 0.38620496 0.25996378 0.36804253 0.32081208
0.40408149 0.39643031 0.64207578 0.56647772]
[ 0.63262081 0.47189313 0.38610214 0.24103768 0.39650175 0.33255672
0.42739603 0.33855924 0.70086569 0.53474861]
[ 0.42177251 0.36542833 0.54296505 0.50520897 0.48850119 0.49256417
0.4874 0.49314392 0.30180061 0.35400492]
[ 0.46720546 0.59052932 0.38354144 0.35084662 0.39522159 0.36131698
0.45329645 0.46744886 0.55032516 0.42503142]
[ 0.50269079 0.33317381 0.57097989 0.46942753 0.53243518 0.43186834
0.56617087 0.3485783 0.38901323 0.311407 ]
[ 0.46945664 0.31061405 0.60182887 0.46262699 0.53712744 0.47861028
0.49903092 0.44854346 0.31332538 0.3362765 ]]
tf.argmax(predictions, 1):
[9 8 1 3 8 8 2 1 2 2]
labels:
[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[ 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[ 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]]
tf.argmax(labels, 1):
[9 8 0 3 8 8 7 7 4 6]
true_or_false:
[ True True False True True True False False False False]
accuracy:
0.5
accuracy:
0.5
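The same result can be checked without building any TensorFlow ops. A minimal NumPy sketch, assuming `predictions` and the one-hot `labels` are the arrays printed above:

import numpy as np

pred_cls = predictions.argmax(axis=1)     # [9 8 1 3 8 8 2 1 2 2]
true_cls = labels.argmax(axis=1)          # [9 8 0 3 8 8 7 7 4 6]
accuracy = (pred_cls == true_cls).mean()  # 5 of 10 match -> 0.5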
