参考资料:
[1]. Spatial Transformer Networks, 李宏毅教学视频
[2]. 知乎 Spatial Transformer Networks
[3]. 详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了
[4]. kevinzakka/spatial-transformer-network
代码:
from scipy import ndimage
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
def gen_grid(o_dims):
    """Build the homogeneous sampling grid for an output of size ``o_dims``.

    Args:
        o_dims: ``(height, width)`` of the output image.

    Returns:
        float32 tensor of shape ``(3, height * width)``. Row 0 holds the
        normalized y coordinate, row 1 the normalized x coordinate (both
        ``k / size`` for ``k`` in ``range(size)``, i.e. in ``[0, 1)``), and
        row 2 is all ones so an affine ``theta`` can be applied by matmul.
    """
    height, width = o_dims
    xs = np.linspace(0, 1.0, width, endpoint=False)
    ys = np.linspace(0, 1.0, height, endpoint=False)
    grid_x, grid_y = np.meshgrid(xs, ys)
    # Order is (y, x, 1): callers split the transformed rows back as y, then x.
    homogeneous = np.stack([grid_y, grid_x, np.ones(shape=o_dims)])
    return tf.convert_to_tensor(homogeneous.reshape(3, -1), dtype='float32')
def get_pixel_value(imgs, x, y):
    """Gather ``imgs[b, y[b, i, j], x[b, i, j]]`` for every batch ``b`` and position ``(i, j)``.

    Args:
        imgs: image tensor indexed as ``(batch, row, col, ...)``.
        x: int32 tensor of column indices, shape ``(B, H, W)`` (must be static).
        y: int32 tensor of row indices, same shape as ``x``.

    Returns:
        Tensor of shape ``(B, H, W) + imgs.shape[3:]`` with the gathered values.
    """
    batch, rows, cols = x.shape[0], x.shape[1], x.shape[2]
    # Broadcast the batch index to every (i, j) so each gather triple is (b, y, x).
    batch_idx = tf.tile(tf.reshape(tf.range(batch), (batch, 1, 1)), [1, rows, cols])
    gather_indices = tf.stack([batch_idx, y, x], axis=3)
    return tf.gather_nd(imgs, gather_indices)
def test_get_pixel_value():
    """Check get_pixel_value: all-zero indices must return pixel (0, 0) everywhere."""
    N, H, W, C = 2, 4, 5, 3
    # Distinct values laid out as NCHW, then moved channels-last -> (N, H, W, C).
    imgs = tf.reshape(tf.range(N * C * H * W), shape=[N, C, H, W])
    imgs = tf.transpose(imgs, [0, 2, 3, 1])
    x = tf.zeros(shape=(N, H, W), dtype='int32')
    y = tf.zeros(shape=(N, H, W), dtype='int32')
    ret = get_pixel_value(imgs, x, y)
    # A scoped session is closed on exit (an InteractiveSession was leaked here).
    with tf.Session() as sess:
        ret_val, imgs_val = sess.run([ret, imgs])
    print(ret_val.shape)
    expected = np.broadcast_to(imgs_val[:, :1, :1, :], (N, H, W, C))
    assert ret_val.shape == (N, H, W, C)
    assert np.array_equal(ret_val, expected)
def bilinear_interpolation(imgs, x, y):
    """Sample ``imgs`` at fractional pixel coordinates with bilinear interpolation.

    Args:
        imgs: image tensor indexed as ``(batch, row, col, channel)`` -- presumably
            ``(B, H, W, C)``.
        x: float32 tensor of column coordinates in pixels, shape ``(B, O_H, O_W)``
            (static shape required by ``get_pixel_value``).
        y: float32 tensor of row coordinates in pixels, same shape as ``x``.

    Returns:
        uint8 tensor of shape ``(B, O_H, O_W, C)``, values clipped to ``[0, 255]``.

    Coordinates outside the image use zero (black) padding: each of the four
    neighbouring corners contributes only if it lies inside the image.
    """
    height = tf.shape(imgs)[1]
    width = tf.shape(imgs)[2]

    # Integer corners BEFORE clipping. The weights must come from these: with
    # clipped corners the four weights sum to 0 on the last row/column, which
    # turned those pixels black even for an identity transform.
    x0 = tf.cast(tf.floor(x), 'int32')
    y0 = tf.cast(tf.floor(y), 'int32')
    x1 = x0 + 1
    y1 = y0 + 1

    def _corner(xc, yc):
        # Clip only for the lookup so gather_nd stays in range, then zero out
        # corners that actually fall outside the image.
        inside = (xc >= 0) & (xc <= width - 1) & (yc >= 0) & (yc <= height - 1)
        pixels = get_pixel_value(imgs,
                                 tf.clip_by_value(xc, 0, width - 1),
                                 tf.clip_by_value(yc, 0, height - 1))
        mask = tf.expand_dims(tf.cast(inside, 'float32'), axis=3)
        return tf.cast(pixels, 'float32') * mask

    a = _corner(x0, y1)
    b = _corner(x1, y1)
    c = _corner(x0, y0)
    d = _corner(x1, y0)

    x0f = tf.cast(x0, 'float32')
    x1f = tf.cast(x1, 'float32')
    y0f = tf.cast(y0, 'float32')
    y1f = tf.cast(y1, 'float32')

    # Each corner is weighted by the area of the opposite sub-rectangle;
    # expand to (B, O_H, O_W, 1) to broadcast over channels.
    wa = tf.expand_dims((x1f - x) * (y - y0f), axis=3)
    wb = tf.expand_dims((x - x0f) * (y - y0f), axis=3)
    wc = tf.expand_dims((x1f - x) * (y1f - y), axis=3)
    wd = tf.expand_dims((x - x0f) * (y1f - y), axis=3)

    inter_img = a * wa + b * wb + c * wc + d * wd
    # Clip before casting so colors display correctly as uint8.
    inter_img = tf.clip_by_value(inter_img, 0, 255)
    return tf.cast(inter_img, 'uint8')
def STN(input, thetas, o_shape=None):
    """Warp a batch of images with per-image affine transforms and evaluate it.

    Args:
        input: image tensor with static shape, indexed ``(batch, row, col, channel)``
            -- presumably ``(B, H, W, C)`` float32.
        thetas: float32 tensor of shape ``(B, 2, 3)``. It maps homogeneous
            normalized output coordinates ``(y, x, 1)`` (see ``gen_grid``) to
            normalized input coordinates; row 0 yields y, row 1 yields x.
        o_shape: ``(height, width)`` of the output; defaults to the input size.

    Returns:
        numpy uint8 array of shape ``(B, O_H, O_W, C)``.
    """
    in_height, in_width = (int(d) for d in input.get_shape()[1:3])
    if o_shape is None:
        o_shape = (in_height, in_width)
    # Normalize (TensorShape / tuple) to plain ints for reshape and gen_grid.
    o_height, o_width = (int(d) for d in o_shape)
    num_batches = int(thetas.shape[0])

    flatten_grid = gen_grid((o_height, o_width))
    # Tile the grid per batch so the matmul is a plain (B,2,3) x (B,3,N)
    # batch matmul and does not rely on rank-broadcasting support.
    batched_grid = tf.tile(tf.expand_dims(flatten_grid, 0), [num_batches, 1, 1])
    # (B, 2, 3) x (B, 3, O_H*O_W) = (B, 2, O_H*O_W)
    locations_in_input = tf.matmul(thetas, batched_grid)
    # Scale normalized (y, x) back to input pixel units; shape (1, 2, 1) broadcasts.
    in_scale = tf.constant([[[in_height], [in_width]]], dtype='float32')
    locations_in_input = locations_in_input * in_scale

    y, x = tf.split(locations_in_input, 2, axis=1)
    x = tf.reshape(x, shape=(num_batches, o_height, o_width))
    y = tf.reshape(y, shape=(num_batches, o_height, o_width))
    output = bilinear_interpolation(input, x, y)

    # Scoped session: the previous InteractiveSession was never closed and a
    # new one was created on every call.
    with tf.Session() as sess:
        return sess.run(output)
if __name__ == '__main__':
    # Demo: warp the same image with two affine transforms and show both results.
    img = cv2.imread('./cat.jpg')
    if img is None:
        # cv2.imread returns None instead of raising when the file is missing/unreadable.
        raise FileNotFoundError("could not read image './cat.jpg'")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib shows RGB
    height, width = img.shape[0], img.shape[1]

    # Batch of two copies of the image: shape (2, H, W, 3).
    imgs = tf.convert_to_tensor(np.stack([img, img]), dtype='float32')
    # Each theta maps normalized output (y, x, 1) to normalized input coordinates:
    # row 0 -> y, row 1 -> x.
    thetas = [
        [[1., 0., .5],   # shift the sampled rows by half the input height
         [0., 1., 0.]],
        [[1., 0., 0],    # identity in normalized coordinates
         [0., 1., 0]],
    ]
    thetas = tf.convert_to_tensor(thetas, dtype='float32')
    # Half-size output: normalized coordinates make this a 2x downscale.
    output = STN(imgs, thetas, (height // 2, width // 2))

    plt.figure()
    plt.subplot(131)
    plt.imshow(img)
    plt.subplot(132)
    plt.imshow(output[0])
    plt.subplot(133)
    plt.imshow(output[1])
    plt.show()
    # test_get_pixel_value()