To get a better feel for the class from the previous post, I searched GitHub and found the following example:
import tensorflow as tf

class MyLayer(tf.keras.layers.Layer):
    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(int(input_shape[1]), self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        print("build: shape of input: ", input_shape)
        print("build: shape of kernel: ", self.kernel)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        print("call: dot of x & kernel: ", tf.keras.backend.dot(x, self.kernel))
        return tf.keras.backend.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)
In Keras, besides using Lambda to define your own layer, you can also define one by subclassing. The latter looks a lot classier.
Let's not debate for now whether the dense-style layer above is correct; it is used here purely to study class inheritance.
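For comparison, here is a minimal sketch of the Lambda route (the doubling function is my own example, not from the post). Lambda works well for stateless operations, but it has no build step for creating trainable weights like the kernel above, which is exactly where subclassing pays off.

import tensorflow as tf

# Stateless transform: trivially expressed with Lambda, no trainable variables.
double = tf.keras.layers.Lambda(lambda t: 2.0 * t)
print(double(tf.constant([1., 2., 3.])))  # doubles each element: [2., 4., 6.]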
x=tf.keras.Input(12,dtype=tf.float32)
x_out=MyLayer(16)(x)
build: shape of input: (?, 12)
build: shape of kernel:
call: dot of x & kernel: Tensor("my_layer/MatMul:0", shape=(?, 16), dtype=float32)
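As a rough sketch (not in the original example), the symbolic x and x_out above can be wrapped into a model to check the layer's parameter count:

# Reusing x and x_out defined above.
model = tf.keras.Model(inputs=x, outputs=x_out)
model.summary()  # my_layer holds a 12 x 16 kernel, i.e. 192 trainable parameters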
This way of using the layer is aimed at layers: the input is a layer output (a symbolic Input)!!! Not a plain tensor. A failing example is shown below:
>>> y=tf.constant([1,2,3])
>>> y_out=MyLayer(16)(y)
Traceback (most recent call last):
  File "", line 1, in
    y_out=MyLayer(16)(y)
  File "D:\python\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 824, in __call__
    self._maybe_build(inputs)
  File "D:\python\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 2146, in _maybe_build
    self.build(input_shapes)
  File "D:/python/pycode/tf_keras_layers_Layer_.py", line 13, in build
    shape=(int(input_shape[1]), self.output_dim),
  File "D:\python\lib\site-packages\tensorflow_core\python\framework\tensor_shape.py", line 870, in __getitem__
    return self._dims[key]
IndexError: list index out of range
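One way to read this traceback (my own interpretation, not stated in the post): tf.constant([1,2,3]) is a rank-1 tensor, so its input_shape is (3,) and input_shape[1] does not exist, hence the IndexError. With a rank-2 float tensor the same layer should build and run, roughly like this:

import tensorflow as tf

y2 = tf.constant([[1., 2., 3.]])  # rank 2, so input_shape is (1, 3)
y2_out = MyLayer(16)(y2)          # builds a (3, 16) kernel; output shape is (1, 16)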
The class can work out the input_shape of the incoming layer tensor; call is where the layer's logic lives, and it does not need to be invoked explicitly.
The keyword argument of call is conventionally named inputs, which looks more standard; compare the other layer classes in Keras, such as Dense:
| build(self, input_shape)
| Creates the variables of the layer (optional, for subclass implementers).
|
| This is a method that implementers of subclasses of `Layer` or `Model`
| can override if they need a state-creation step in-between
| layer instantiation and layer call.
|
| This is typically used to create the weights of `Layer` subclasses.
|
| Arguments:
| input_shape: Instance of `TensorShape`, or list of instances of
| `TensorShape` if the layer expects a list of inputs
| (one instance per input).
|
| call(self, inputs)
| This is where the layer's logic lives.
|
| Arguments:
| inputs: Input tensor, or list/tuple of input tensors.
| **kwargs: Additional keyword arguments.
|
| Returns:
| A tensor or list/tuple of tensors.
build creates the layer's variables, so add_weight calls and the like clearly belong there, and the variables must be attached to self, e.g. self.kernel, self.bias.
call then operates on these variables, and whatever it returns is the final output of the layer.
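Putting those two points together, a slightly tidier version of the example might look like the sketch below; the name MyDense and the bias term are my own additions for illustration, not part of the original code.

import tensorflow as tf

class MyDense(tf.keras.layers.Layer):
    def __init__(self, output_dim, **kwargs):
        super(MyDense, self).__init__(**kwargs)
        self.output_dim = output_dim

    def build(self, input_shape):
        # All variables are created here and attached to self.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(int(input_shape[-1]), self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        self.bias = self.add_weight(name='bias',
                                    shape=(self.output_dim,),
                                    initializer='zeros',
                                    trainable=True)
        super(MyDense, self).build(input_shape)

    def call(self, inputs):
        # The argument is named inputs, matching the built-in layers.
        return tf.keras.backend.dot(inputs, self.kernel) + self.bias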
Now consider the class from the previous post, where the input is not a layer!!! It still throws an error, as follows:
x=tf.keras.Input(12,dtype=tf.float32)
inputs=list(map(Func(),x))
Traceback (most recent call last):
  File "", line 1, in
    list(map(Func(),x))
  File "D:\python\lib\site-packages\tensorflow_core\python\framework\ops.py", line 547, in __iter__
    self._disallow_iteration()
  File "D:\python\lib\site-packages\tensorflow_core\python\framework\ops.py", line 543, in _disallow_iteration
    self._disallow_in_graph_mode("iterating over `tf.Tensor`")
  File "D:\python\lib\site-packages\tensorflow_core\python\framework\ops.py", line 523, in _disallow_in_graph_mode
    " this function with @tf.function.".format(task))
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
The error says that iterating over a tf.Tensor is not allowed in graph execution!!!
Yet this layer class is perfectly usable, and it also works when the input is a Tensor, which slaps my earlier claim in the face; it really has to be judged case by case:
>>> Func()(x)
>>> y=tf.constant([1.,2,3])
>>> y
>>> Func()(y)
Since it is just an identity, similar to the identity mapping in ResNet, this layer class can simply be dropped.
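If Func from the previous post really is a pure identity, it would amount to something like this hypothetical sketch, and the model behaves the same whether it is wired in or left out:

import tensorflow as tf

class Identity(tf.keras.layers.Layer):
    # Hypothetical stand-in for Func: no weights, returns the input untouched.
    def call(self, inputs):
        return inputs

x = tf.keras.Input(12, dtype=tf.float32)
x_out = Identity()(x)  # carries the same values as x, so the layer can be dropped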