#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
View on TensorFlow.org | Run in Google Colab | View source on GitHub | Download notebook |
Welcome to the comprehensive guide for Keras weight pruning.
This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the
API docs.
The following use cases are covered:
For configuration of the pruning algorithm, refer to the tfmot.sparsity.keras.prune_low_magnitude
API docs.
If you only want to find the APIs you need and understand their purpose, you can run this section but skip reading it.
! pip install -q tensorflow-model-optimization
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
%load_ext tensorboard
import tempfile
# Toy dataset shared by every example: one sample with 20 features and a
# one-hot label over 20 classes.
input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
# `to_categorical` expects integer class ids in [0, num_classes), so draw the
# label with `randint` rather than from a normal distribution.
y_train = tf.keras.utils.to_categorical(
    np.random.randint(0, 20, size=1), num_classes=20)
def setup_model():
    """Build the small Dense -> Flatten Sequential model used by the examples."""
    dense = tf.keras.layers.Dense(20, input_shape=input_shape)
    flatten = tf.keras.layers.Flatten()
    return tf.keras.Sequential([dense, flatten])
def setup_pretrained_weights():
    """Fit a fresh model on the toy data and save its weights to a temp path.

    Returns:
        The path (suffix `.tf`) that was passed to `model.save_weights`.
    """
    model = setup_model()
    model.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy'],
    )
    model.fit(x_train, y_train)

    # Reserve a unique path. Unlike `mkstemp`, whose returned descriptor was
    # never closed here, the handle is released when the `with` block exits;
    # `delete=False` keeps the file so the path stays reserved.
    with tempfile.NamedTemporaryFile(suffix='.tf', delete=False) as tmp_file:
        pretrained_weights = tmp_file.name

    model.save_weights(pretrained_weights)
    return pretrained_weights
def get_gzipped_model_size(model):
    """Return the size, in bytes, of the model after saving and compressing it.

    The model is written with `model.save(path, include_optimizer=False)` to a
    temporary `.h5` file, which is then stored in a DEFLATE-compressed zip
    archive. The size of that archive is returned.

    Args:
        model: Object exposing `save(filepath, include_optimizer=...)`.

    Returns:
        Size of the zip archive in bytes.
    """
    import os
    import zipfile

    # `mkstemp` hands back an already-open descriptor; close it immediately so
    # it does not leak (only the path is needed).
    keras_fd, keras_file = tempfile.mkstemp('.h5')
    os.close(keras_fd)
    zipped_fd, zipped_file = tempfile.mkstemp('.zip')
    os.close(zipped_fd)

    try:
        model.save(keras_file, include_optimizer=False)
        with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
            f.write(keras_file)
        return os.path.getsize(zipped_file)
    finally:
        # Neither path is returned to the caller, so remove both temp files.
        for path in (keras_file, zipped_file):
            if os.path.exists(path):
                os.remove(path)
# Build one model whose return value is discarded.
setup_model()
# Train once and keep the saved-weights path so later cells can load it.
pretrained_weights = setup_pretrained_weights()
Train on 1 samples
1/1 [==============================] - 0s 326ms/sample - loss: 16.1181 - accuracy: 0.0000e+00
Tips for better model accuracy:
To make the whole model train with pruning, apply tfmot.sparsity.keras.prune_low_magnitude
to the model.
# Fresh model with the same architecture as the one the weights came from.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended.
# Apply pruning to the whole model.
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
# Print the layer / parameter table of the wrapped model.
model_for_pruning.summary()
WARNING:tensorflow:From d:\python_virtualenv\tf2.1_gpu\lib\site-packages\tensorflow_model_optimization\python\core\sparsity\keras\pruning_wrapper.py:199: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
prune_low_magnitude_dense_2 (None, 20) 822
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20) 1
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________
Pruning a model can have a negative effect on accuracy. You can selectively prune layers of a model to explore the trade-off between accuracy, speed, and model size.
Tips for better model accuracy:
More:
tfmot.sparsity.keras.prune_low_magnitude
API docs provide details on how to vary the pruning configuration per layer. In the example below, prune only the Dense
layers.
# Create a base model and restore the pretrained weights into it.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
# Helper function uses `prune_low_magnitude` to make only the
# Dense layers train with pruning.
def apply_pruning_to_dense(layer):
    """Return `layer` wrapped for pruning if it is a Dense layer, else as is."""
    if not isinstance(layer, tf.keras.layers.Dense):
        return layer
    return tfmot.sparsity.keras.prune_low_magnitude(layer)
# Clone the base model with `tf.keras.models.clone_model`, letting
# `apply_pruning_to_dense` decide, layer by layer, what gets pruned.
model_for_pruning = tf.keras.models.clone_model(
    base_model, clone_function=apply_pruning_to_dense)

model_for_pruning.summary()
Model: "sequential_3"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
prune_low_magnitude_dense_3 (None, 20) 822
_________________________________________________________________
flatten_3 (Flatten) (None, 20) 0
=================================================================
Total params: 822
Trainable params: 420
Non-trainable params: 402
_________________________________________________________________
While this example used the type of the layer to decide what to prune, the easiest way to prune a particular layer is to set its name
property, and look for that name in the clone_function.
# Show the name of the base model's first layer.
print(base_model.layers[0].name)
dense_3
This is not compatible with fine-tuning with pruning, which is why it may be less accurate than the above examples which
support fine-tuning.
While prune_low_magnitude
can be applied while defining the initial model, loading the weights after does not work in the below examples.
Functional example
# Functional API: wrap the `Dense` layer with `prune_low_magnitude` so that it
# trains with pruning; the `Flatten` layer is left unwrapped.
i = tf.keras.Input(shape=(20,))
pruned_dense = tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(10))
x = pruned_dense(i)
o = tf.keras.layers.Flatten()(x)

model_for_pruning = tf.keras.Model(inputs=i, outputs=o)
model_for_pruning.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 20)] 0
_________________________________________________________________
prune_low_magnitude_dense_4 (None, 10) 412
_________________________________________________________________
flatten_4 (Flatten) (None, 10) 0
=================================================================
Total params: 412
Trainable params: 210
Non-trainable params: 202
_________________________________________________________________
Sequential example
# Sequential API: wrap the `Dense` layer with `prune_low_magnitude` so that it
# trains with pruning; the `Flatten` layer is left unwrapped.
pruned_dense = tfmot.sparsity.keras.prune_low_magnitude(
    tf.keras.layers.Dense(20, input_shape=input_shape))
model_for_pruning = tf.keras.Sequential([pruned_dense, tf.keras.layers.Flatten()])

model_for_pruning.summary()
Model: "sequential_4"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
prune_low_magnitude_dense_5 (None, 20) 822
_________________________________________________________________
flatten_5 (Flatten) (None, 20) 0
=================================================================
Total params: 822
Trainable params: 420
Non-trainable params: 402
_________________________________________________________________
Tips for better model accuracy:
More:
tfmot.sparsity.keras.prune_low_magnitude
API docs provide details on how to vary the pruning configuration per layer.class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.dense = tf.keras.layers.Dense(10)
self.flatten = tf.keras.layers.Flatten()
self.dense2 = tf.keras.Sequential([tf.keras.layers.Dense(10)])
def call(self, inputs):
x = self.dense(inputs)
x = self.flatten(x)
return self.dense2(x)
model_for_pruning = MyModel()
# It's optional but recommended to load pretrained weights before this for
# model accuracy.
# Replace the `dense` layer and the nested `dense2` model with their
# pruning-wrapped versions; `flatten` is not touched.
model_for_pruning.dense = tfmot.sparsity.keras.prune_low_magnitude(model_for_pruning.dense)
model_for_pruning.dense2 = tfmot.sparsity.keras.prune_low_magnitude(model_for_pruning.dense2)
# Build the model for inputs of shape (None, 20), then print its summary.
model_for_pruning.build(input_shape=(None, 20))
model_for_pruning.summary()
Model: "my_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten_6 (Flatten) multiple 0
_________________________________________________________________
prune_low_magnitude_dense_6 multiple 412
_________________________________________________________________
sequential_5 (Sequential) multiple 212
=================================================================
Total params: 624
Trainable params: 320
Non-trainable params: 304
_________________________________________________________________
Common mistake: pruning the bias usually harms model accuracy too much.
tfmot.sparsity.keras.PrunableLayer
serves two use cases:
For example, the API defaults to only pruning the kernel of the
Dense
layer. The example below prunes the bias also.
class MyDenseLayer(tf.keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):
    """Dense layer that lists both its kernel and its bias as prunable."""

    def get_prunable_weights(self):
        # Prune bias also, though that usually harms model accuracy too much.
        prunable = [self.kernel, self.bias]
        return prunable
# Wrap the custom `MyDenseLayer` with `prune_low_magnitude` so that it trains
# with pruning, then stack it with an unwrapped `Flatten` layer.
pruned_layer = tfmot.sparsity.keras.prune_low_magnitude(
    MyDenseLayer(20, input_shape=input_shape))
model_for_pruning = tf.keras.Sequential([pruned_layer, tf.keras.layers.Flatten()])

model_for_pruning.summary()
Model: "sequential_6"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
prune_low_magnitude_my_dense (None, 20) 843
_________________________________________________________________
flatten_7 (Flatten) (None, 20) 0
=================================================================
Total params: 843
Trainable params: 420
Non-trainable params: 423
_________________________________________________________________
Call the tfmot.sparsity.keras.UpdatePruningStep
callback during training.
To help debug training, use the tfmot.sparsity.keras.PruningSummaries
callback.
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights)  # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Directory that receives the pruning summaries.
log_dir = tempfile.mkdtemp()

# Pruning step callback, to be called during training.
update_step = tfmot.sparsity.keras.UpdatePruningStep()
# Log sparsity and other metrics in Tensorboard.
summaries = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)
callbacks = [update_step, summaries]

model_for_pruning.compile(
    loss=tf.keras.losses.categorical_crossentropy,
    optimizer='adam',
    metrics=['accuracy'],
)
model_for_pruning.fit(x_train, y_train, callbacks=callbacks, epochs=2)
#docs_infra: no_execute
%tensorboard --logdir={log_dir}
Train on 1 samples
Epoch 1/2
1/1 [==============================] - 0s 297ms/sample - loss: 2.9753 - accuracy: 0.0000e+00
Epoch 2/2
1/1 [==============================] - 0s 11ms/sample - loss: 2.8398 - accuracy: 0.0000e+00
ERROR: Timed out waiting for TensorBoard to start. It may still be running as pid 5392.
For non-Colab users, you can see the results of a previous run of this code block on TensorBoard.dev.
Call the tfmot.sparsity.keras.UpdatePruningStep
callback during training.
To help debug training, use the tfmot.sparsity.keras.PruningSummaries
callback.
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights)  # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Boilerplate
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
unused_arg = -1
epochs = 2
batches = 1  # example is hardcoded so that the number of batches cannot change.

# Non-boilerplate.
# The optimizer is attached to the model before the callbacks are given the
# model via `set_model`.
model_for_pruning.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(model_for_pruning)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)  # Log sparsity and other metrics in Tensorboard.
log_callback.set_model(model_for_pruning)

step_callback.on_train_begin()  # run pruning callback
for _ in range(epochs):
    log_callback.on_epoch_begin(epoch=unused_arg)  # run pruning callback
    for _ in range(batches):
        step_callback.on_train_batch_begin(batch=unused_arg)  # run pruning callback

        # Only the forward pass and the loss need to be recorded by the tape.
        with tf.GradientTape() as tape:
            logits = model_for_pruning(x_train, training=True)
            loss_value = loss(y_train, logits)

        # Compute and apply gradients after leaving the tape context, so the
        # gradient ops themselves are not recorded.
        grads = tape.gradient(loss_value, model_for_pruning.trainable_variables)
        optimizer.apply_gradients(zip(grads, model_for_pruning.trainable_variables))

    step_callback.on_epoch_end(batch=unused_arg)  # run pruning callback
#docs_infra: no_execute
%tensorboard --logdir={log_dir}
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
ERROR: Timed out waiting for TensorBoard to start. It may still be running as pid 6184.
For non-Colab users, you can see the results of a previous run of this code block on TensorBoard.dev.
First, look at the tfmot.sparsity.keras.prune_low_magnitude
API docs
to understand what a pruning schedule is and the math of
each type of pruning schedule.
Tips:
Have a learning rate that’s not too high or too low when the model is pruning. Consider the pruning schedule to be a hyperparameter.
As a quick test, try experimenting with pruning a model to the final sparsity at the beginning of training by setting begin_step
to 0 with a tfmot.sparsity.keras.ConstantSparsity
schedule. You might get lucky with good results.
Do not prune very frequently to give the model time to recover. The pruning schedule provides a decent default frequency.
For general ideas to improve model accuracy, look for tips for your use case(s) under “Define model”.
You must preserve the optimizer step during checkpointing. This means while you can use Keras HDF5 models for checkpointing, you cannot use Keras HDF5 weights.
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights)  # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Reserve a unique `.h5` path. The handle is closed when the `with` block
# exits (a bare `mkstemp` call would leave its descriptor open), and
# `delete=False` keeps the file so the path stays reserved.
with tempfile.NamedTemporaryFile(suffix='.h5', delete=False) as tmp_file:
    keras_model_file = tmp_file.name

# Checkpoint: saving the optimizer is necessary (include_optimizer=True is the default).
model_for_pruning.save(keras_model_file, include_optimizer=True)
The above applies generally. The code below is only needed for the HDF5 model format (not HDF5 weights and other formats).
# Deserialize model: load the HDF5 file inside `prune_scope`.
with tfmot.sparsity.keras.prune_scope():
    loaded_model = tf.keras.models.load_model(keras_model_file)

# Print the layer / parameter table of the restored model.
loaded_model.summary()
WARNING:tensorflow:No training configuration found in save file: the model was *not* compiled. Compile it manually.
Model: "sequential_11"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
prune_low_magnitude_dense_12 (None, 20) 822
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20) 1
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________
Common mistake: both strip_pruning
and applying a standard compression algorithm (e.g. via gzip) are necessary to see the compression
benefits of pruning.
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
# Typically you train the model here.
# Strip the pruning wrappers to obtain the model that is exported.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
print("final model")
model_for_export.summary()
print("\n")
# Compare the compressed size of the wrapped model with that of the stripped one.
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning)))
print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export)))
final model
Model: "sequential_12"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_13 (Dense) (None, 20) 420
_________________________________________________________________
flatten_13 (Flatten) (None, 20) 0
=================================================================
Total params: 420
Trainable params: 420
Non-trainable params: 0
_________________________________________________________________
Size of gzipped pruned model without stripping: 3315.00 bytes
Size of gzipped pruned model with stripping: 2881.00 bytes
Once different backends enable pruning to improve latency, using block sparsity can improve latency for certain hardware.
Increasing the block size will decrease the peak sparsity that’s achievable for a target model accuracy. Despite this, latency can still improve.
For details on what’s supported for block sparsity, see
the tfmot.sparsity.keras.prune_low_magnitude
API docs.
# Fresh model; no pretrained weights are loaded in this cell.
base_model = setup_model()
# For using intrinsics on a CPU with 128-bit registers, together with 8-bit
# quantized weights, a 1x16 block size is nice because the block perfectly
# fits into the register.
pruning_params = {'block_size': [1, 16]}
# Forward the block size to `prune_low_magnitude` as a keyword argument.
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model, **pruning_params)
model_for_pruning.summary()
Model: "sequential_13"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
prune_low_magnitude_dense_14 (None, 20) 822
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20) 1
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________