Contents
- Outline
- Where
  - where(tensor)
  - where(cond,A,B)
- scatter_nd
  - 1-D
  - 2-D
- meshgrid
  - Points
  - NumPy implementation
  - TensorFlow 2 implementation
Outline
- where
- scatter_nd
- meshgrid
Where
where(tensor)
- tf.where(tensor) returns the coordinates of the True entries, e.g. in the table below
row / col | 0 | 1 | 2 |
---|---|---|---|
0 | True | False | False |
1 | False | True | False |
2 | False | False | True |
import tensorflow as tf
a = tf.random.normal([3, 3])
a
mask = a > 0
mask
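# values of the elements where mask is True
tf.boolean_mask(a, mask)
# indices of the True elements, i.e. of the elements > 0 (one [row, col] pair per row)
indices = tf.where(mask)
indices
# fetch the values > 0 again using those indices
tf.gather_nd(a, indices)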
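To see what each call returns without running TensorFlow, here is a NumPy analogue, assuming a fixed input in place of tf.random.normal so the result is predictable. The input is chosen so that the mask equals the table above; np.argwhere plays the role of tf.where(mask), a[mask] of tf.boolean_mask, and indexing with the coordinate columns of tf.gather_nd.

```python
import numpy as np

# Fixed 3x3 input (an assumption; the original uses tf.random.normal)
a = np.array([[ 0.5, -1.0, -0.2],
              [-0.3,  1.5, -0.7],
              [-2.0, -0.1,  0.9]])

mask = a > 0  # True only on the diagonal, matching the table above

# Analogue of tf.boolean_mask(a, mask): selected values in row-major order
values = a[mask]

# Analogue of tf.where(mask): one [row, col] coordinate per True element
indices = np.argwhere(mask)

# Analogue of tf.gather_nd(a, indices): look the values up again by coordinate
gathered = a[tuple(indices.T)]

print(indices.tolist())                  # [[0, 0], [1, 1], [2, 2]]
print(values.tolist())                   # [0.5, 1.5, 0.9]
print(np.array_equal(values, gathered))  # True
```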
where(cond,A,B)
mask
A = tf.ones([3, 3])
B = tf.zeros([3, 3])
# where mask is True the element is taken from A, where it is False from B
tf.where(mask, A, B)
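The three-argument form selects element by element, the same rule as NumPy's np.where. A small NumPy sketch with a fixed input (an assumption for reproducibility) shows the A/B selection and a common variant: passing the tensor itself as A keeps the positive entries and zeroes the rest, which in TensorFlow would be written tf.where(a > 0, a, tf.zeros_like(a)).

```python
import numpy as np

# Fixed input (assumed for illustration)
a = np.array([[ 0.5, -1.0],
              [-0.3,  1.5]])
mask = a > 0

A = np.ones_like(a)
B = np.zeros_like(a)

# True positions take the value from A, False positions from B
picked = np.where(mask, A, B)

# Same rule with a itself as A: keep positive entries, replace the rest with 0
relu_like = np.where(mask, a, np.zeros_like(a))

print(picked.tolist())     # [[1.0, 0.0], [0.0, 1.0]]
print(relu_like.tolist())  # [[0.5, 0.0], [0.0, 1.5]]
```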
scatter_nd
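- tf.scatter_nd(indices, updates, shape)
  - indices: where to write, one index per row
  - updates: the values (or whole slices) to write, one per index row
  - shape: shape of the all-zero output tensor that receives the updates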
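As a mental model, scatter_nd behaves like the NumPy function below. scatter_nd_reference is a hypothetical helper, not part of TensorFlow or NumPy: it starts from a zero tensor of the requested shape and adds each row of updates at the position named by the matching row of indices. If each index row has K components, it addresses an element when K equals len(shape) and a whole slice when K is smaller. TensorFlow's documentation states that updates sent to duplicate indices are summed, which is why the sketch uses += rather than =.

```python
import numpy as np

def scatter_nd_reference(indices, updates, shape):
    """Hypothetical NumPy reference for the semantics of tf.scatter_nd."""
    indices = np.asarray(indices)
    updates = np.asarray(updates)
    out = np.zeros(shape, dtype=updates.dtype)  # the all-zero base tensor
    for idx, upd in zip(indices, updates):
        # idx selects an element (K == len(shape)) or a slice (K < len(shape));
        # += makes duplicate indices accumulate
        out[tuple(idx)] += upd
    return out

# Index 1 appears twice, so it receives 2 + 3
print(scatter_nd_reference([[1], [1], [3]], [2, 3, 4], [5]).tolist())  # [0, 5, 0, 4, 0]
```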
一维
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
shape = tf.constant([8])
# scatter updates onto an all-zero base tensor of the given shape, at the positions listed in indices
tf.scatter_nd(indices, updates, shape)
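For this input the result should be [0, 11, 0, 10, 9, 0, 0, 12]: position 4 gets 9, position 3 gets 10, position 1 gets 11, position 7 gets 12, and every other slot keeps the base value 0. A quick NumPy check of that layout; because the indices here are distinct, plain fancy assignment gives the same result as accumulation.

```python
import numpy as np

indices = np.array([[4], [3], [1], [7]])
updates = np.array([9, 10, 11, 12])

base = np.zeros(8, dtype=updates.dtype)
# indices[:, 0] is the column of target positions [4, 3, 1, 7]
base[indices[:, 0]] = updates

print(base.tolist())  # [0, 11, 0, 10, 9, 0, 0, 12]
```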
2-D
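indices = tf.constant([[0], [2]])  # two index rows, each addressing a slice along axis 0
updates = tf.constant([
    [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
    [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
])
updates.shape  # TensorShape([2, 4, 4])
shape = tf.constant([4, 4, 4])
# write the two 4x4 slices into output[0] and output[2]; output[1] and output[3] stay 0
tf.scatter_nd(indices, updates, shape)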
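Each index row here has a single component, so it picks a whole 4x4 slice along axis 0 of the 4x4x4 output: the first block of updates lands at output[0], the second at output[2], and output[1] and output[3] remain zero. A NumPy check of that expected layout:

```python
import numpy as np

block = np.array([[5] * 4, [6] * 4, [7] * 4, [8] * 4])  # the 4x4 slice used twice above
updates = np.stack([block, block])                      # shape (2, 4, 4)

out = np.zeros((4, 4, 4), dtype=updates.dtype)
out[[0, 2]] = updates  # write the two slices at positions 0 and 2 of axis 0

print(out.shape)                              # (4, 4, 4)
print(out[0].tolist() == block.tolist())      # True
print(int(out[1].sum()), int(out[3].sum()))   # 0 0
```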
meshgrid
Goal: enumerate every grid point [x, y] with x and y in {-2, -1, 0, 1, 2}, in this order:
- [-2,-2]
- [-1,-2]
- [0,-2]
- ...
- [-2,-1]
- [-1,-1]
- ...
- [2,2]
Points
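- [y,x,w]: axis 0 indexes the grid row (y), axis 1 the grid column (x), and w holds the 2 coordinates of one point
- [5,5,2]: the shape of that tensor for the 5x5 grid
- [N,2]: the same points flattened into a list of N = 25 coordinate pairs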
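A small NumPy sketch of these shapes, assuming the same 5x5 grid over [-2, 2]. Each point stores its coordinates in the order [x, y], matching the code that follows, so grid[row, col] is [x[col], y[row]].

```python
import numpy as np

x = np.linspace(-2, 2, 5)
y = np.linspace(-2, 2, 5)
points_x, points_y = np.meshgrid(x, y)         # each of shape (5, 5); rows follow y
grid = np.stack([points_x, points_y], axis=2)  # [y, x, w] layout, shape (5, 5, 2)

flat = grid.reshape(-1, 2)                     # [N, 2] with N = 5 * 5 = 25

print(grid.shape, flat.shape)  # (5, 5, 2) (25, 2)
# Row index 1 is y = -1, column index 3 is x = 1
print(grid[1, 3].tolist())     # [1.0, -1.0]
```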
NumPy implementation
import numpy as np
points = []
# y in the outer loop and x in the inner loop, so x varies fastest
for y in np.linspace(-2, 2, 5):
    for x in np.linspace(-2, 2, 5):
        points.append([x, y])
np.array(points)
array([[-2., -2.],
[-1., -2.],
[ 0., -2.],
[ 1., -2.],
[ 2., -2.],
[-2., -1.],
[-1., -1.],
[ 0., -1.],
[ 1., -1.],
[ 2., -1.],
[-2., 0.],
[-1., 0.],
[ 0., 0.],
[ 1., 0.],
[ 2., 0.],
[-2., 1.],
[-1., 1.],
[ 0., 1.],
[ 1., 1.],
[ 2., 1.],
[-2., 2.],
[-1., 2.],
[ 0., 2.],
[ 1., 2.],
[ 2., 2.]])
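The double loop can be replaced by a vectorized version. This sketch uses np.meshgrid with its default indexing='xy' and flattens the grid to [N, 2]; reshaping row by row walks x fastest, exactly like the inner loop, so the 25 rows come out in the same order.

```python
import numpy as np

# Loop version from above
points = []
for y in np.linspace(-2, 2, 5):
    for x in np.linspace(-2, 2, 5):
        points.append([x, y])
loop_points = np.array(points)

# Vectorized version: build the grid, then flatten it to [N, 2]
xs, ys = np.meshgrid(np.linspace(-2, 2, 5), np.linspace(-2, 2, 5))
vec_points = np.stack([xs, ys], axis=-1).reshape(-1, 2)

print(np.array_equal(loop_points, vec_points))  # True
```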
TensorFlow 2 implementation
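y = tf.linspace(-2., 2., 5)  # [-2., -1., 0., 1., 2.]
y
x = tf.linspace(-2., 2., 5)
x
# default indexing='xy': both outputs have shape [len(y), len(x)]
points_x, points_y = tf.meshgrid(x, y)
points_x.shape  # TensorShape([5, 5])
points_x  # every row is a copy of x
points_y  # every column is a copy of y
# pair the coordinates along a new last axis -> shape [5, 5, 2]
points = tf.stack([points_x, points_y], axis=2)
points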
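With the default indexing='xy', tf.meshgrid(x, y) returns tensors of shape [len(y), len(x)], so points ends up as [5, 5, 2] and tf.reshape(points, [-1, 2]) gives the [N, 2] list. Both tf.meshgrid and np.meshgrid also accept indexing='ij', which swaps the two axes. Because the grid here is square, the shapes look identical either way; the NumPy sketch below assumes unequal lengths purely to make the difference visible.

```python
import numpy as np

x = np.linspace(-2, 2, 5)  # 5 x values
y = np.linspace(-1, 1, 3)  # 3 y values (unequal on purpose, an assumption)

gx_xy, gy_xy = np.meshgrid(x, y)                 # default indexing='xy'
gx_ij, gy_ij = np.meshgrid(x, y, indexing='ij')

print(gx_xy.shape, gx_ij.shape)        # (3, 5) (5, 3)
# the 'ij' output is the transpose of the 'xy' output
print(np.array_equal(gx_ij, gx_xy.T))  # True
```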