Problem description:

When pruning my own network with TensorFlow Model Optimization 0.6.0, I ran into layers that the library does not yet support:

ValueError: Please initialize `Prune` with a supported layer. Layers should either be supported by the PruneRegistry (built-in keras layers) or should be a PrunableLayer instance, or should has a customer defined get_prunable_weights method. You passed:
Case 1: The unsupported layer is a native Keras layer

(Reference: the post by blogger 专业混水.)

prune_registry.py contains the layers that TensorFlow currently supports for pruning, together with the weights to be pruned in each layer (by default the kernel; layers such as ReLU and BatchNormalization return an empty list, since they have no prunable parameters).

Find the dictionary _LAYERS_WEIGHTS_MAP in prune_registry.py and modify it by hand. With that change it runs with no errors, but only on the premise that the layer is a built-in Keras layer.
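Editing the installed file works, but the same change can also be made at runtime without touching site-packages. A minimal sketch, assuming (hypothetically) that DepthwiseConv2D is the layer you want pruned; the internal module path below was checked against 0.6.0 and may move in later versions:

import tensorflow as tf
from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry

# Patch the registry dict at runtime instead of editing prune_registry.py on disk.
# DepthwiseConv2D keeps its weights in `depthwise_kernel`, not `kernel`.
prune_registry.PruneRegistry._LAYERS_WEIGHTS_MAP[
    tf.keras.layers.DepthwiseConv2D] = ['depthwise_kernel']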
Case 2: The unsupported layer is a custom layer (a composition of native Keras layers)

First, here is my own custom layer: the Transition Layer from DenseNet, but in a one-dimensional version.
class TransitionBlock(tf.keras.layers.Layer, tfmot.sparsity.keras.PrunableLayer):
    def __init__(self, num_channels, name, **kwargs):
        super(TransitionBlock, self).__init__(**kwargs)
        self.num_channels = num_channels
        self.block_name = name
        self.batch_norm = tf.keras.layers.BatchNormalization(name="{}_BatchNorm".format(name))
        self.relu = tf.keras.layers.ReLU(name="{}_ReLU".format(name))
        self.conv = tf.keras.layers.Conv1D(num_channels, kernel_size=1,
                                           kernel_initializer='he_normal',
                                           name="{}_Conv1D".format(name))
        self.max_pool = tf.keras.layers.MaxPooling1D(name="{}_MaxPool".format(name))

    def get_config(self):
        config = super().get_config().copy()
        # Store the constructor arguments; layer objects themselves are not serializable.
        config.update({
            "num_channels": self.num_channels,
            "name": self.block_name,
        })
        return config

    def get_prunable_weights(self):
        return [self.conv.kernel]  # a list is returned, following the official example

    def call(self, x):
        x = self.batch_norm(x)
        x = self.relu(x)
        x_concat = self.conv(x)
        x = self.max_pool(x_concat)
        return x_concat, x  # x_concat is used for the concat during upsampling
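A quick sanity check (batch size and input shape are illustrative assumptions). Note that get_prunable_weights() can only return the kernel after the layer has been built, i.e. after it has been called on data at least once:

block = TransitionBlock(num_channels=32, name="trans1")
x_concat, x = block(tf.zeros([8, 128, 64]))  # (batch, steps, channels); builds the sublayers
print(block.get_prunable_weights())          # -> [<tf.Variable ...Conv1D/kernel:0 shape=(1, 64, 32) ...>]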
Two points need attention here:
1. When defining the custom layer, inherit from both classes, tf.keras.layers.Layer and tfmot.sparsity.keras.PrunableLayer, following the official TensorFlow example "Prune custom Keras layer or modify parts of layer to prune".
2. The get_prunable_weights() method.
Normally, when pruning a tf.keras.layers.Dense or tf.keras.layers.Conv1D, its kernel is pruned by default. Inside a custom layer, however, these sublayers are not recognized, which produces the ValueError above, so you need to override get_prunable_weights yourself.
In the case of this post, among self.batch_norm, self.relu, self.conv, and self.max_pool, only self.conv has weights that can be pruned.
As in the official example, get_prunable_weights() returns a list.
Looking at the source code of tfmot.sparsity.keras.PrunableLayer,
class PrunableLayer(object):
    """Abstract Base Class for making your own keras layer prunable.

    Custom keras layers which want to add pruning should implement this class.
    """

    @abc.abstractmethod
    def get_prunable_weights(self):
        """Returns list of prunable weight tensors.

        All the weight tensors which the layer wants to be pruned during
        training must be returned by this method.

        Returns: List of weight tensors/kernels in the keras layer which must be
        pruned during training.
        """
        raise NotImplementedError('Must be implemented in subclasses.')
we find that inheriting from tfmot.sparsity.keras.PrunableLayer really just means inheriting a single abstract method. Accordingly, overriding it in my custom Transition layer,

def get_prunable_weights(self):
    return [self.conv.kernel]

is all that is needed.
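With the override in place, and assuming every layer in the model is now either registered or a PrunableLayer, the whole model can be wrapped directly. A minimal sketch (the schedule values are illustrative assumptions; if some layers are still unsupported, see Case 3 below):

import tensorflow_model_optimization as tfmot

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0, final_sparsity=0.5,
        begin_step=0, end_step=1000),
}
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)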
Case 3: Layers that do not belong to tf.keras.layers, such as TFOpLambda

Use tf.keras.models.clone_model to wrap only some of the layers:
def apply_pruning_to(layer):
    accepted_layers = [
        tf.keras.layers.Conv1D,
        tf.keras.layers.MaxPooling1D,
        tf.keras.layers.BatchNormalization,
        tf.keras.layers.ReLU,
        tf.keras.layers.GlobalAveragePooling1D,  # channel attention
        tf.keras.layers.GlobalMaxPooling1D,      # channel attention
        tf.keras.layers.Reshape,                 # channel attention
        tf.keras.layers.Add,                     # channel attention
        tf.keras.layers.Activation,              # channel attention
        tf.keras.layers.Multiply,
        tf.keras.layers.Concatenate,
        tf.keras.layers.Conv1DTranspose,
        # ConvBlock,
        DenseBlock,
        TransitionBlock,
    ]
    for accepted in accepted_layers:
        if isinstance(layer, accepted):
            return tfmot.sparsity.keras.prune_low_magnitude(layer)
    return layer

model_for_pruning = tf.keras.models.clone_model(
    model,
    clone_function=apply_pruning_to,
)
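The cloned model still needs the pruning callback during training, and the wrappers should be stripped before export. A minimal sketch (optimizer, loss, and the data names x_train/y_train are illustrative assumptions):

model_for_pruning.compile(optimizer='adam', loss='mse')
model_for_pruning.fit(
    x_train, y_train,
    epochs=2,
    # Required: advances the pruning step and applies the sparsity masks.
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
)
# Remove the pruning wrappers before saving/exporting.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)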
Reference 1: Yamnet clustering AttributeError: Exception encountered when calling layer "tf.__operators__.add" (type TFOpLambda) #972
Reference 2: Prune some layers (Sequential and Functional)
Below, for reference, are the weights to be pruned for each built-in layer type:
class PruneRegistry(object):
    """Registry responsible for built-in keras layers."""

    # The keys represent built-in keras layers and the values represent the
    # variables within the layers which hold the kernel weights. This
    # allows the wrapper to access and modify the weights.
    _LAYERS_WEIGHTS_MAP = {
        layers.ELU: [],
        layers.LeakyReLU: [],
        layers.ReLU: [],
        layers.Softmax: [],
        layers.ThresholdedReLU: [],
        layers.Conv1D: ['kernel'],
        layers.Conv2D: ['kernel'],
        layers.Conv2DTranspose: ['kernel'],
        layers.Conv3D: ['kernel'],
        layers.Conv3DTranspose: ['kernel'],
        layers.Cropping1D: [],
        layers.Cropping2D: [],
        layers.Cropping3D: [],
        layers.DepthwiseConv2D: [],
        layers.SeparableConv1D: ['pointwise_kernel'],
        layers.SeparableConv2D: ['pointwise_kernel'],
        layers.UpSampling1D: [],
        layers.UpSampling2D: [],
        layers.UpSampling3D: [],
        layers.ZeroPadding1D: [],
        layers.ZeroPadding2D: [],
        layers.ZeroPadding3D: [],
        layers.Activation: [],
        layers.ActivityRegularization: [],
        layers.Dense: ['kernel'],
        layers.Dropout: [],
        layers.Flatten: [],
        layers.Lambda: [],
        layers.Masking: [],
        layers.Permute: [],
        layers.RepeatVector: [],
        layers.Reshape: [],
        layers.SpatialDropout1D: [],
        layers.SpatialDropout2D: [],
        layers.SpatialDropout3D: [],
        layers.Embedding: ['embeddings'],
        layers.LocallyConnected1D: ['kernel'],
        layers.LocallyConnected2D: ['kernel'],
        layers.Add: [],
        layers.Average: [],
        layers.Concatenate: [],
        layers.Dot: [],
        layers.Maximum: [],
        layers.Minimum: [],
        layers.Multiply: [],
        layers.Subtract: [],
        layers.AlphaDropout: [],
        layers.GaussianDropout: [],
        layers.GaussianNoise: [],
        layers.BatchNormalization: [],
        layers.LayerNormalization: [],
        layers.AveragePooling1D: [],
        layers.AveragePooling2D: [],
        layers.AveragePooling3D: [],
        layers.GlobalAveragePooling1D: [],
        layers.GlobalAveragePooling2D: [],
        layers.GlobalAveragePooling3D: [],
        layers.GlobalMaxPooling1D: [],
        layers.GlobalMaxPooling2D: [],
        layers.GlobalMaxPooling3D: [],
        layers.MaxPooling1D: [],
        layers.MaxPooling2D: [],
        layers.MaxPooling3D: [],
        layers.MultiHeadAttention: [
            '_query_dense.kernel', '_key_dense.kernel', '_value_dense.kernel',
            '_output_dense.kernel'
        ],
        layers.experimental.SyncBatchNormalization: [],
        layers.experimental.preprocessing.Rescaling.__class__: [],
        TensorFlowOpLayer: [],
        layers_compat_v1.BatchNormalization: [],
    }
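Instead of scanning this table by eye, the registry can also be queried programmatically. A small sketch using the internal API (checked against 0.6.0; the module path may move in later versions):

import tensorflow as tf
from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry

print(prune_registry.PruneRegistry.supports(tf.keras.layers.Conv1D(8, 3)))  # True: prunes 'kernel'
print(prune_registry.PruneRegistry.supports(tf.keras.layers.ReLU()))        # True: empty weight list
print(prune_registry.PruneRegistry.supports(TransitionBlock(32, "t")))      # False: custom layers go through PrunableLayer instead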