# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Welcome to an end-to-end example for quantization aware training.
For an introduction to what quantization aware training is and to determine if you should use it (including what’s supported), see the overview page.
To quickly find the APIs you need for your use case (beyond fully-quantizing a model with 8-bits), see the
comprehensive guide.
In this tutorial, you will:
1. Train a tf.keras model for MNIST from scratch.
2. Fine-tune the model by applying the quantization aware training API, see the accuracy, and export a quantization aware model.
3. Use the model to create an actually quantized model for the TFLite backend.
4. See the persistence of accuracy in TFLite and a 4x smaller model.
First, set up the environment:
! pip uninstall -y tensorflow
! pip install -q tf-nightly
! pip install -q tensorflow-model-optimization
import tempfile
import os
import tensorflow as tf
from tensorflow import keras
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])
# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)
1688/1688 [==============================] - 5s 3ms/step - loss: 0.2840 - accuracy: 0.9205 - val_loss: 0.1255 - val_accuracy: 0.9635
You will apply quantization aware training to the whole model and see the effect in the model summary: all layers are now prefixed by "quant".
Note that the resulting model is quantization aware but not quantized (e.g. the weights are float32 instead of int8). The sections that follow show how to create a quantized model from the quantization aware one.
In the comprehensive guide, you can see how to quantize only some layers, which can improve model accuracy; a rough sketch of that pattern is shown below.
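As an illustration only (a minimal sketch based on the clone_function pattern; the names apply_quantization_to_dense, annotated_model, and partially_quantized_model are introduced here and are not part of this tutorial), per-layer quantization looks roughly like this:
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Annotate only the Dense layers for quantization; return all other layers
# unchanged.
def apply_quantization_to_dense(layer):
  if isinstance(layer, tf.keras.layers.Dense):
    return tfmot.quantization.keras.quantize_annotate_layer(layer)
  return layer

# clone_model applies the annotation function to every layer of the model.
annotated_model = tf.keras.models.clone_model(
    model,
    clone_function=apply_quantization_to_dense,
)

# quantize_apply makes only the annotated layers quantization aware.
partially_quantized_model = tfmot.quantization.keras.quantize_apply(annotated_model)
Like quantize_model below, quantize_apply returns a new model that must be recompiled before training. This tutorial quantizes the whole model instead.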
import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])

q_aware_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
quantize_layer (QuantizeLaye (None, 28, 28) 3
_________________________________________________________________
quant_reshape (QuantizeWrapp (None, 28, 28, 1) 1
_________________________________________________________________
quant_conv2d (QuantizeWrappe (None, 26, 26, 12) 147
_________________________________________________________________
quant_max_pooling2d (Quantiz (None, 13, 13, 12) 1
_________________________________________________________________
quant_flatten (QuantizeWrapp (None, 2028) 1
_________________________________________________________________
quant_dense (QuantizeWrapper (None, 10) 20295
=================================================================
Total params: 20,448
Trainable params: 20,410
Non-trainable params: 38
_________________________________________________________________
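To confirm that the quantization aware model is still a float model (this quick check is an addition for illustration, not part of the original tutorial), you can inspect the weight dtypes directly:
# The quantization aware model keeps float32 weights; int8 arithmetic is only
# simulated during training. Print the first few weights to verify.
for weight in q_aware_model.weights[:4]:
  print(weight.name, weight.dtype)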
To demonstrate fine-tuning after training the base model for just one epoch, run quantization aware training on a subset of the training data.
train_images_subset = train_images[0:1000] # out of 60000
train_labels_subset = train_labels[0:1000]
q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)
2/2 [==============================] - 0s 68ms/step - loss: 0.1414 - accuracy: 0.9644 - val_loss: 0.1628 - val_accuracy: 0.9700
For this example, there is minimal to no loss in test accuracy after quantization aware training, compared to the baseline.
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
    test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
Baseline test accuracy: 0.9617000222206116
Quant test accuracy: 0.9639999866485596
Next, convert the quantization aware model for the TFLite backend. After this, you have an actually quantized model with int8 weights and uint8 activations.
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()
INFO:tensorflow:Assets written to: C:\Users\ADMINI~1\AppData\Local\Temp\tmpcwmu8x7i\assets
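As a quick sanity check (an addition for illustration; check_interpreter is a name introduced here), you can list the tensor dtypes recorded in the converted model. The weight tensors should now report int8:
# Inspect the first few tensors in the quantized flatbuffer.
check_interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
check_interpreter.allocate_tensors()
for detail in check_interpreter.get_tensor_details()[:6]:
  print(detail['name'], detail['dtype'])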
Define a helper function to evaluate the TF Lite model on the test dataset.
import numpy as np
def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy
You evaluate the quantized model and see that the accuracy from TensorFlow persists to the TFLite backend.
interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()
test_accuracy = evaluate_model(interpreter)
print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.
Quant TFLite test_accuracy: 0.9597
Quant TF test accuracy: 0.9639999866485596
You create a float TFLite model and then see that the quantized TFLite model
is 4x smaller.
# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()
# Measure sizes of models.
_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')
with open(quant_file, 'wb') as f:
  f.write(quantized_tflite_model)

with open(float_file, 'wb') as f:
  f.write(float_tflite_model)
print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))
INFO:tensorflow:Assets written to: C:\Users\ADMINI~1\AppData\Local\Temp\tmp9bs1h66a\assets
Float model in Mb: 0.08053970336914062
Quantized model in Mb: 0.02347564697265625
In this tutorial, you saw how to create quantization aware models with the TensorFlow Model Optimization Toolkit API and then create quantized models for the TFLite backend.
You saw a 4x model size reduction for an MNIST model, with minimal accuracy difference. To see the latency benefits on mobile, try out the TFLite examples in the TFLite app repository.
We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.