将tf模型转换为pytorch模型

PRNet将tf转换为pytorch

tf模型可以很方便的转化为pytorch模型吗?可以。
我们需要做哪些步骤呢?

实现方法

我们以PRNet为例,实现tf2torch

  1. 加载tf模型
	from predictor import resfcn256   # import from prnet_dir, maybe using importlib is a better idea
    tf_network_def = resfcn256(256, 256)
    # tensorflow network forward
    net_input = tf.placeholder(
        tf.float32, shape=[None, 256, 256, 3])
    tf_model = tf_network_def(net_input, is_training=False)
    tf_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(allow_growth=False))
    tf_config = tf.ConfigProto(
        device_count={"GPU":0}
    )
    sess = tf.Session(config=tf_config)

    saver = tf.train.Saver(tf_network_def.vars)
    saver.restore(
        sess, os.path.join(args.prnet_dir, 'Data', 'net-data', '256_256_resfcn256_weight'))
	graph = sess.graph
	#print([node.name for node in graph.as_graph_def().node])
  1. 类似tf的网络结构,定义torch的网络结构,并且模型中添加tf_map用于连接tf节点和torch节点的参数命名的不同。记为PRNet_full(此部分略去)
  2. 将tf模型中的参数迁移到torch中:
	torch_dict = OrderedDict()

    for node in graph.as_graph_def().node:
        if node.name in torch_model.tf_map:
            torch_name = torch_model.tf_map[node.name]
            data = graph.get_operation_by_name(node.name).outputs[0]
            data_np = sess.run(data)
            if len(data_np.shape) > 1:
                # weight layouts  |   tensorflow   |     pytorch     |  transpose   |
                # conv2d_transpose (H, W, out, in) -> (in, out, H, W)  (3, 2, 0, 1)
                # conv2d           (H, W, in, out) -> (out, in, H, W)  (3, 2, 0, 1)
                torch_dict[torch_name] = torch.tensor(np.transpose(data_np, (3, 2, 0, 1)).astype(np.float32))
            else:
                torch_dict[torch_name] = torch.tensor(data_np.astype(np.float32))
        else:
            if node.name.find('save') == -1:
                pass
                # print('not in {}'.format(node.name))
    torch.save(torch_dict, 'from_tf.pth')
  1. [可选] 将PRNet_full中那些用于连接的tf_map删除,重新建立一个干净的PRNet,load刚刚得到torch权重。
  2. [可选] 验证对于同样的输入,是否tf和torch得到的结果大致相同。

训练

也可以比较tf和torch训练每一步中是否loss相近。

完结撒花

参考
https://github.com/liguohao96/pytorch-prnet

你可能感兴趣的:(深度学习框架)