tf.sparse_to_dense()函数理解

今天看代码的时候看到一个陌生的函tf.sparse_to_dense(),看了很多博客感觉都没有解释到点子上,看了函数才有了一点理解,记录如下:

import tensorflow as tf
import numpy
indices = tf.reshape(tf.range(0, 10 ,1), [10, 1])
labels=tf.expand_dims(tf.constant([0,2,3,6,7,9,1,3,5,4]),1)
print(indices)
print(labels)
onehot = tf.sparse_to_dense(
      tf.concat(values=[indices, labels], axis=1),
      [10, 10], 1.0, 0.0)
with tf.Session()  as sess:
      a = sess.run(onehot)
      print(a)

上面的代码很简单,先是产生一个(10,1)的随机labels,然后在产生一个0-9的vector作为labels的排序,然后就调用tf.sparse_to_dense()函数,这个函数的参数如下:
sparse_indices,
output_shape,
sparse_values,
default_value=0,
validate_indices=True,
name=None
其中output_shape为输出tensor的形状,这里我们采用10*10,第一个10表示10个labels,第二个10表示我们有10个分类,因为onehot数据最终要作为网络的标签。
sparse_values、default_value为两个值我们设为1和0
最重要的一个参数就是sparse_indices,我一开始很不是很理解为什么要为labels产生一个0-9的排序?我们看一下代码中的解释:
#If sparse_indices is scalar
dense[i] = (i == sparse_indices ? sparse_values : default_value)
#If sparse_indices is a vector, then for each i
dense[sparse_indices[i]] = sparse_values[i]
#If sparse_indices is an n by d matrix, then for each i in [0, n)
dense[sparse_indices[i][0], …, sparse_indices[i][d-1]] = sparse_values[i]
这个参数可以为scalar、vector以及matrix
结合我们的代码,我们可以发现通过 tf.concat函数我们将我们的labels和0-9对应起来,效果如下图:
tf.sparse_to_dense()函数理解_第1张图片
再根据文档中的解释,dense[0,0]=1,dense[1,2]=1…以此类推,然后dense中剩余的空位置用0填满,效果如下:
tf.sparse_to_dense()函数理解_第2张图片
说到这里一切似乎都很明朗了,以上就是我对这个函数的粗浅理解,当然这个函数的功能还很强大,我只是从产生onehot标签的角度去理解,如有错误欢迎指正。

你可能感兴趣的:(tf.sparse_to_dense()函数理解)