STN(Spatial Transformer Networks)网络学习(附代码)

参考资料:
[1]. spatial transformer network 李宏毅教学视频
[2]. 知乎 Spatial Transformer Networks
[3]. 详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了
[4]. kevinzakka/spatial-transformer-network

代码:

from scipy import ndimage
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2


def gen_grid(o_dims):
    """Build a flattened homogeneous sampling grid for an output image.

    Args:
        o_dims: (height, width) of the output image.

    Returns:
        A float32 tensor of shape (3, height * width). Row 0 holds the
        normalized y coordinate of every output pixel, row 1 the normalized
        x coordinate (values i / height and j / width, i.e. in [0, 1)), and
        row 2 is all ones so a (2, 3) affine matrix can be applied with a
        single matmul.
    """
    out_h, out_w = o_dims

    # endpoint=False -> coordinates are k / size for k in range(size).
    xs = np.linspace(0, 1.0, out_w, endpoint=False)
    ys = np.linspace(0, 1.0, out_h, endpoint=False)

    grid_x, grid_y = np.meshgrid(xs, ys)
    homogeneous = np.ones(shape=o_dims)

    # Stack in (y, x, 1) order -> (3, out_h, out_w), then flatten the spatial axes.
    stacked = np.stack([grid_y, grid_x, homogeneous], axis=0)
    flat = stacked.reshape(3, -1)

    return tf.convert_to_tensor(flat, dtype='float32')


def get_pixel_value(imgs, x, y):
    """Gather pixels from a batch of images at integer (x, y) positions.

    Args:
        imgs: tensor indexed as imgs[batch, row, col, ...]; presumably
            (B, H, W, C) — TODO confirm for callers other than this file.
        x: int32 tensor of shape (B, out_H, out_W) holding column indices.
        y: int32 tensor of shape (B, out_H, out_W) holding row indices.

    Returns:
        Tensor whose element [n, i, j] is imgs[n, y[n, i, j], x[n, i, j]]
        (trailing axes of imgs, e.g. channels, are carried along).
    """
    batch, out_h, out_w = x.shape[0], x.shape[1], x.shape[2]

    # Batch index for every output location, shape (B, out_H, out_W).
    batch_idx = tf.tile(
        tf.reshape(tf.range(batch), shape=(batch, 1, 1)),
        [1, out_h, out_w],
    )

    # Each index triple is (batch, row, col) as gather_nd expects.
    indices = tf.stack([batch_idx, y, x], axis=3)
    return tf.gather_nd(imgs, indices)


def test_get_pixel_value():
    """Smoke-test get_pixel_value on a small synthetic image batch.

    Builds an (N, H, W, C) integer batch, samples every output location at
    pixel (0, 0) and prints the shape of the gathered result.
    """
    n, h, w, c = 2, 4, 5, 3
    imgs = tf.reshape(tf.range(n * c * h * w), shape=[n, c, h, w])
    imgs = tf.transpose(imgs, [0, 2, 3, 1])  # (N, C, H, W) -> (N, H, W, C)

    x = tf.zeros(shape=(n, h, w), dtype='int32')
    y = tf.zeros(shape=(n, h, w), dtype='int32')

    ret = get_pixel_value(imgs, x, y)

    # A scoped session is closed on exit; a bare InteractiveSession stays
    # open (and remains the default session) after the function returns.
    with tf.Session() as sess:
        print(sess.run(ret).shape)


def bilinear_interpolation(imgs, x, y, as_uint8=True):
    """Sample a batch of images at fractional pixel coordinates (bilinear).

    Args:
        imgs: float tensor of images, presumably (B, H, W, C) — TODO confirm.
        x: float32 tensor (B, out_H, out_W) of column coordinates in pixels.
        y: float32 tensor (B, out_H, out_W) of row coordinates in pixels.
        as_uint8: if True (default, the original behaviour) the result is
            clipped to [0, 255] and cast to uint8 for display. Pass False to
            keep the float result, e.g. when gradients must flow through it.

    Returns:
        Tensor (B, out_H, out_W, ...) of interpolated pixel values.

    Out-of-image handling is zero padding: a neighbouring corner that lies
    outside the image contributes nothing. Samples more than one pixel
    outside the image are therefore 0, and samples within one pixel of the
    border fade linearly toward 0.
    """
    height = int(imgs.shape[1])
    width = int(imgs.shape[2])

    # Integer-valued corner coordinates, kept in float for the weights.
    x0f = tf.floor(x)
    y0f = tf.floor(y)
    x1f = x0f + 1.0
    y1f = y0f + 1.0

    # Bilinear weights from the true (unclipped) corners. Computing them from
    # clipped corners makes both corners collapse at the last row/column,
    # which cancels the weights and turns valid border samples black.
    wa = (x1f - x) * (y - y0f)  # corner (x0, y1)
    wb = (x - x0f) * (y - y0f)  # corner (x1, y1)
    wc = (x1f - x) * (y1f - y)  # corner (x0, y0)
    wd = (x - x0f) * (y1f - y)  # corner (x1, y0)

    x0 = tf.cast(x0f, 'int32')
    x1 = tf.cast(x1f, 'int32')
    y0 = tf.cast(y0f, 'int32')
    y1 = tf.cast(y1f, 'int32')

    def _inside(xi, yi):
        """Return 1.0 where (xi, yi) is a valid pixel index, else 0.0."""
        ok_x = tf.logical_and(xi >= 0, xi <= width - 1)
        ok_y = tf.logical_and(yi >= 0, yi <= height - 1)
        return tf.cast(tf.logical_and(ok_x, ok_y), 'float32')

    # Zero padding: corners outside the image get weight 0.
    wa = wa * _inside(x0, y1)
    wb = wb * _inside(x1, y1)
    wc = wc * _inside(x0, y0)
    wd = wd * _inside(x1, y0)

    # Clip only to keep the gather in bounds; clipped corners already have
    # zero weight, so the value they fetch does not matter.
    x0 = tf.clip_by_value(x0, 0, width - 1)
    x1 = tf.clip_by_value(x1, 0, width - 1)
    y0 = tf.clip_by_value(y0, 0, height - 1)
    y1 = tf.clip_by_value(y1, 0, height - 1)

    a = get_pixel_value(imgs, x0, y1)
    b = get_pixel_value(imgs, x1, y1)
    c = get_pixel_value(imgs, x0, y0)
    d = get_pixel_value(imgs, x1, y0)

    # Broadcast the (B, out_H, out_W) weights over the trailing channel axis.
    inter_img = (a * tf.expand_dims(wa, axis=3)
                 + b * tf.expand_dims(wb, axis=3)
                 + c * tf.expand_dims(wc, axis=3)
                 + d * tf.expand_dims(wd, axis=3))

    if as_uint8:
        # Keep values in the displayable range before casting to uint8.
        inter_img = tf.clip_by_value(inter_img, 0, 255)
        inter_img = tf.cast(inter_img, 'uint8')

    return inter_img


def STN(input, thetas, o_shape=None):
    """Warp a batch of images with one affine transform per image.

    Each output pixel gets a normalized (y, x) position in [0, 1) from
    gen_grid. thetas[n], a 2x3 matrix, maps the homogeneous vector (y, x, 1)
    to a normalized input position: row 0 yields the input y and row 1 the
    input x. These are scaled by the input height/width to pixel
    coordinates and sampled bilinearly.

    Args:
        input: float tensor of images, presumably (B, H, W, C) — TODO
            confirm. (The name shadows the builtin; kept for compatibility.)
        thetas: float32 tensor of shape (B, 2, 3).
        o_shape: (out_H, out_W) of the output; defaults to the input's (H, W).

    Returns:
        The evaluated result of bilinear_interpolation as a numpy array of
        shape (B, out_H, out_W) plus the trailing input axes. The graph is
        run here in a TF1 session.
    """
    if o_shape is None:
        o_shape = input.shape[1:3]
    # Plain ints work for both numpy (gen_grid) and tf.reshape.
    out_h, out_w = (int(d) for d in o_shape)
    in_h, in_w = (int(d) for d in input.shape[1:3])

    num_batches = thetas.shape[0]

    # Shape (1, 2, 1): broadcasts over batch and grid points, scaling row 0
    # (y) by the input height and row 1 (x) by the input width.
    in_scale = tf.constant([[[in_h], [in_w]]], dtype='float32')

    flatten_grid = gen_grid((out_h, out_w))

    # (B, 2, 3) x (3, out_H*out_W) -> (B, 2, out_H*out_W), then to pixels.
    locations_in_input = tf.matmul(thetas, flatten_grid) * in_scale
    y, x = tf.split(locations_in_input, 2, axis=1)
    x = tf.reshape(x, shape=(num_batches, out_h, out_w))
    y = tf.reshape(y, shape=(num_batches, out_h, out_w))

    output = bilinear_interpolation(input, x, y)

    # A scoped session is closed on exit instead of leaking a new
    # InteractiveSession on every call.
    with tf.Session() as sess:
        return sess.run(output)


if __name__ == '__main__':
    img = cv2.imread('./cat.jpg')
    # cv2.imread returns None instead of raising when the file is missing
    # or cannot be decoded; fail with a clear message rather than in cvtColor.
    if img is None:
        raise FileNotFoundError("could not read image './cat.jpg'")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    height, width = img.shape[:2]

    # Batch of two copies of the same image: shape (2, H, W, 3).
    imgs = tf.convert_to_tensor(np.stack([img, img]), dtype='float32')

    # One 2x3 matrix per image acting on normalized (y, x, 1):
    # row 0 produces the input y, row 1 the input x.
    thetas = tf.convert_to_tensor([
        [[1., 0., .5],   # sample rows shifted down by half the image height
         [0., 1., 0.]],
        [[1., 0., 0.],   # identity
         [0., 1., 0.]],
    ], dtype='float32')

    # Output grid at half the input resolution.
    output = STN(imgs, thetas, (height // 2, width // 2))

    plt.figure()
    plt.subplot(131)
    plt.imshow(img)
    plt.subplot(132)
    plt.imshow(output[0])
    plt.subplot(133)
    plt.imshow(output[1])

    plt.show()

    # Call test_get_pixel_value() here to smoke-test the gather helper.


![图 1:STN(Spatial Transformer Networks)示例代码运行结果](image.png)

你可能感兴趣的:(STN(Spatial Transformer Networks)网络学习(附代码))