torch程序转tf指南

1 网络迁移

基于torch实现的网络需要修改为基于tf的实现, 类似于翻译的工作, 主要是把torch的算子替换为tf等价的替换。
经常会遇到torch中的算子tf中没有, 或者虽然有但功能不等价的情况, 尤其是后一种情况需要格外注意。
还有一个主要区别是torch中通常定义一个类, 在__init__中定义好需要用到的算子,在forward中进行网络的连接;
但是tf通常是函数式编程, 虽然也可以定义类, 但是通常没必要, 因为tf中很多算子都是函数, 定义和计算是一起完成的。
比如在torch中:

class Net():
    def __init__(self):
        self.conv = torch.nn.Conv2d(...)
    
    def forward(self, input):
        out = self.conv(input)
        return out

__init__先定义一个conv操作, 然后在forward中使用。

但是在tf中:

    out = tf.layers.conv2d(input, ...)

定义和计算是在一起的, 因为tf.layers.conv2d本身只是个函数。

2 数据部分

数据需要用tf的方式进行读取。

3 运行部分

tf是静态图模式, 需要先定义计算图, 然后调用sess.run运行计算图才能得到最终结果。
因为是静态图模式, 要想进行调试输出中间结果是比较困难的。 这里提供了一些小技巧,
可以参考:tensorflow调试小技巧

你可能感兴趣的:(训练框架,tensorflow,深度学习,pytorch)