Custom layers with tf.keras.layers.Layer

To better understand the class from the previous post, I searched GitHub and found the following example:

import tensorflow as tf

class MyLayer(tf.keras.layers.Layer):

    def __init__(self, output_dim, **kwargs):
        # Initialize the base Layer first, then set our own attributes.
        super(MyLayer, self).__init__(**kwargs)
        self.output_dim = output_dim

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel', 
                                      shape=(int(input_shape[1]), self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        print("build: shape of input: ", input_shape)
        print("build: shape of kernel: ", self.kernel)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        # Compute the product once, then log it and return it.
        output = tf.keras.backend.dot(x, self.kernel)
        print("call: dot of x & kernel: ", output)
        return output

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

Besides defining your own layers with Lambda, Keras also lets you define them by subclassing. The latter looks a lot fancier.
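A minimal sketch of the difference, assuming TensorFlow 2.x with eager execution: a Lambda layer is enough for a stateless expression, while a layer that owns trainable weights is better written as a subclass so Keras tracks those weights. `double` and `Scale` are hypothetical names used only for illustration.

```python
import tensorflow as tf

# Stateless op: a Lambda layer is enough, there are no weights to track.
double = tf.keras.layers.Lambda(lambda t: 2.0 * t)


class Scale(tf.keras.layers.Layer):
    """Hypothetical layer: multiplies its input by one learnable scalar."""

    def build(self, input_shape):
        # The weight is stored on self, so Keras tracks it as trainable.
        self.alpha = self.add_weight(
            name="alpha", shape=(), initializer="ones", trainable=True
        )

    def call(self, inputs):
        return self.alpha * inputs


x = tf.constant([[1.0, 2.0, 3.0]])
scale = Scale()
y_lambda = double(x)  # every value doubled
y_sub = scale(x)      # unchanged while alpha is still 1
```

The subclass ends up with one trainable weight; the Lambda layer has none.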

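Let's set aside for now whether the dense layer above is correct (it has no bias and no activation) and use it purely to learn about class inheritance.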

x = tf.keras.Input(shape=(12,), dtype=tf.float32)
x_out = MyLayer(16)(x)

build: shape of input:  (?, 12)
build: shape of kernel:  
call: dot of x & kernel:  Tensor("my_layer/MatMul:0", shape=(?, 16), dtype=float32)
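The log order above (build first, then call) reflects how Keras builds a layer lazily: on the first call it reads input_shape from the actual input and runs build before call. A minimal eager-mode sketch of the same idea, assuming TensorFlow 2.x; `MyDense` is a hypothetical rewrite of `MyLayer` that uses `tf.matmul` and the last axis of the input shape.

```python
import tensorflow as tf


class MyDense(tf.keras.layers.Layer):
    """Hypothetical variant of MyLayer: a bias-free dense projection."""

    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim

    def build(self, input_shape):
        # Runs once, on the first call; input_shape comes from the real input.
        self.kernel = self.add_weight(
            name="kernel",
            shape=(int(input_shape[-1]), self.output_dim),
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)


layer = MyDense(16)
built_before = layer.built      # False: no weights exist yet
out = layer(tf.ones((4, 12)))   # __call__ runs build((4, 12)), then call
```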

This kind of layer is meant to be applied to layers: the input has to be a layer!!! not a tensor. Here is a failing example:

>>> y=tf.constant([1,2,3])
>>> y_out=MyLayer(16)(y)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
    y_out=MyLayer(16)(y)
  File "D:\python\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 824, in __call__
    self._maybe_build(inputs)
  File "D:\python\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 2146, in _maybe_build
    self.build(input_shapes)
  File "D:/python/pycode/tf_keras_layers_Layer_.py", line 13, in build
    shape=(int(input_shape[1]), self.output_dim),
  File "D:\python\lib\site-packages\tensorflow_core\python\framework\tensor_shape.py", line 870, in __getitem__
    return self._dims[key]
IndexError: list index out of range
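Reading the traceback, the failure happens inside build: `tf.constant([1,2,3])` is a rank-1 tensor of shape (3,), so `input_shape[1]` does not exist. Its values are also int32, while the kernel is float32. A sketch, assuming TensorFlow 2.x eager execution and a hypothetical bias-free layer `MyDense` that reads `input_shape[1]` exactly as `MyLayer` does, in which a float tensor with an explicit batch axis goes through:

```python
import tensorflow as tf


class MyDense(tf.keras.layers.Layer):
    """Hypothetical bias-free dense layer, built the same way as MyLayer."""

    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim

    def build(self, input_shape):
        # Same indexing as MyLayer: needs at least a (batch, features) shape.
        self.kernel = self.add_weight(
            name="kernel",
            shape=(int(input_shape[1]), self.output_dim),
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)


y = tf.constant([1.0, 2.0, 3.0])   # shape (3,): no second axis to read
y_batched = y[tf.newaxis, :]       # shape (1, 3): a batch of one sample
y_out = MyDense(16)(y_batched)
```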

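The class works out the input_shape of the incoming tensor on its own, and call is where the layer's logic lives; you never need to call that method explicitly.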

The argument of call is better named inputs, which looks more conventional; compare the other Keras layer classes, such as Dense:

 |  build(self, input_shape)
 |      Creates the variables of the layer (optional, for subclass implementers).
 |      
 |      This is a method that implementers of subclasses of `Layer` or `Model`
 |      can override if they need a state-creation step in-between
 |      layer instantiation and layer call.
 |      
 |      This is typically used to create the weights of `Layer` subclasses.
 |      
 |      Arguments:
 |        input_shape: Instance of `TensorShape`, or list of instances of
 |          `TensorShape` if the layer expects a list of inputs
 |          (one instance per input).
 |  
 |  call(self, inputs)
 |      This is where the layer's logic lives.
 |      
 |      Arguments:
 |          inputs: Input tensor, or list/tuple of input tensors.
 |          **kwargs: Additional keyword arguments.
 |      
 |      Returns:
 |          A tensor or list/tuple of tensors.
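The docstring notes that input_shape becomes a list when the layer expects a list of inputs. A small sketch, assuming TensorFlow 2.x eager execution, with a hypothetical `WeightedSum` layer: build receives one shape per input, and call receives the matching list of tensors.

```python
import tensorflow as tf


class WeightedSum(tf.keras.layers.Layer):
    """Hypothetical layer: learnable weighted sum of same-shaped inputs."""

    def build(self, input_shape):
        # input_shape is a list here, with one shape per input tensor.
        self.n_inputs = len(input_shape)
        self.w = self.add_weight(
            name="w", shape=(self.n_inputs,), initializer="ones", trainable=True
        )

    def call(self, inputs):
        # Stack along a new last axis, then weight and sum over that axis.
        stacked = tf.stack(inputs, axis=-1)
        return tf.reduce_sum(stacked * self.w, axis=-1)


layer = WeightedSum()
a = tf.ones((2, 3))
b = 2.0 * tf.ones((2, 3))
out = layer([a, b])  # each entry is 1*1 + 1*2 while the weights are still 1
```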

build creates the layer's variables, so calls such as add_weight belong there, and the results must be stored on self, e.g. self.kernel and self.bias.

call then operates on these variables, and whatever it returns is the layer's final output.
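Putting the two roles together, here is a minimal sketch of a Dense-like layer, assuming TensorFlow 2.x. `SimpleDense` and `units` are hypothetical names, and this is not the real Dense implementation. build stores kernel and bias on self, call combines them with the input, and get_config records the constructor argument so the layer can be re-created from its config.

```python
import tensorflow as tf


class SimpleDense(tf.keras.layers.Layer):
    """Hypothetical Dense-like layer: inputs @ kernel + bias, no activation."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        in_dim = int(input_shape[-1])
        # Variables are created here and kept on self.
        self.kernel = self.add_weight(
            name="kernel", shape=(in_dim, self.units),
            initializer="glorot_uniform", trainable=True,
        )
        self.bias = self.add_weight(
            name="bias", shape=(self.units,),
            initializer="zeros", trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        # The layer's logic: operate on the variables created in build.
        return tf.matmul(inputs, self.kernel) + self.bias

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config


layer = SimpleDense(4)
out = layer(tf.ones((2, 3)))
```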

Now consider the class from the previous post, where the input is not a layer!!! It still raises an error:

x = tf.keras.Input(shape=(12,), dtype=tf.float32)
inputs=list(map(Func(),x))

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
    list(map(Func(),x))
  File "D:\python\lib\site-packages\tensorflow_core\python\framework\ops.py", line 547, in __iter__
    self._disallow_iteration()
  File "D:\python\lib\site-packages\tensorflow_core\python\framework\ops.py", line 543, in _disallow_iteration
    self._disallow_in_graph_mode("iterating over `tf.Tensor`")
  File "D:\python\lib\site-packages\tensorflow_core\python\framework\ops.py", line 523, in _disallow_in_graph_mode
    " this function with @tf.function.".format(task))
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

The error says that iterating over a tf.Tensor is not allowed in graph execution!!!
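One plausible reading: `map` has to iterate over x, and x is a symbolic Keras input whose batch size (the `?`) is unknown while the graph is built, so there is nothing concrete to loop over. In eager mode a concrete tensor can be iterated row by row. The sketch assumes TensorFlow 2.x eager execution; `f` is a hypothetical per-row function standing in for `Func()`, and `tf.map_fn` is shown as one graph-friendly way to express the same per-row loop.

```python
import tensorflow as tf


def f(row):
    """Hypothetical stand-in for Func(): returns each row unchanged."""
    return row


t = tf.constant([[1.0, 2.0], [3.0, 4.0]])

# Eager mode: a concrete tensor can be iterated along axis 0.
rows = list(map(f, t))

# Graph-friendly alternative: let TensorFlow run the per-row loop.
mapped = tf.map_fn(f, t)
```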

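Yet this class layer does work, even when the input is a plain Tensor, which contradicts what I said above. Each case has to be analyzed on its own: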

>>> Func()(x)


>>> y=tf.constant([1.,2,3])
>>> y

>>> Func()(y)

Since the layer acts as an identity mapping, much like the identity shortcut in ResNet, it can simply be removed.
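A quick check of that reasoning, assuming the layer simply returns its input and TensorFlow 2.x eager execution; `Identity` is a hypothetical stand-in for `Func()`. Passing a tensor through an identity mapping before another layer gives exactly the same result as skipping it.

```python
import tensorflow as tf


class Identity(tf.keras.layers.Layer):
    """Hypothetical stand-in for Func(): returns its input unchanged."""

    def call(self, inputs):
        return inputs


x = tf.random.uniform((3, 12))
dense = tf.keras.layers.Dense(4)

with_identity = dense(Identity()(x))   # Dense applied after the identity layer
without_identity = dense(x)            # same Dense weights, identity removed
```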

 
