解析flatbuffer格式的tflite文件,转成可读的python dict格式,并可描述模型完整推理流程。
tf.lite.Interpreter可以读tflite模型,但是其python接口没有描述模型结构(op node节点间的连接关系)
比如,interpreter.get_tensor_details()获取的信息,如下
[{'name': 'input_13',
'index': 0,
'shape': array([ 1, 1, 240, 1], dtype=int32),
'shape_signature': array([ -1, 1, 240, 1], dtype=int32),
'dtype': numpy.float32,
'quantization': (0.0, 0),
'quantization_parameters': {'scales': array([], dtype=float32),
'zero_points': array([], dtype=int32),
'quantized_dimension': 0},
'sparsity_parameters': {}},
{'name': 'model_12/dense_144/BiasAdd/ReadVariableOp/resource',
'index': 1,
'shape': array([2], dtype=int32),
'shape_signature': array([2], dtype=int32),
'dtype': numpy.float32,
'quantization': (0.0, 0),
'quantization_parameters': {'scales': array([], dtype=float32),
'zero_points': array([], dtype=int32),
'quantized_dimension': 0},
'sparsity_parameters': {}},
{'name': 'model_12/dense_145/BiasAdd/ReadVariableOp/resource',
....
按本文方式,可以直接获取节点的op参数、输入、输出序号,如下
subg
{'inputs': [0],
'name': [109, 97, 105, 110],
'operators': [{'builtin_options': {'dilation_h_factor': 1,
'dilation_w_factor': 1,
'fused_activation_function': 1,
'padding': 0,
'stride_h': 1,
'stride_w': 1},
'builtin_options_type': 1,
'custom_options': None,
'custom_options_format': 0,
'inputs': [0, 52, 68], //输入节点
'intermediates': None,
'mutating_variable_inputs': None,
'opcode_index': 0,
'outputs': [82]}, //输出节点
{'builtin_options': {'depth_multiplier': 1,
...
tensors
[{'buffer': 1,
'is_variable': False,
'name': [105, 110, 112, 117, 116, 95, 49, 51],
'quantization': {'details': None,
'details_type': 0,
'max': None,
'min': None,
'quantized_dimension': 0,
'scale': None,
'zero_point': None},
'shape': [1, 1, 240, 1],
'shape_signature': [-1, 1, 240, 1],
'sparsity': None,
'type': 0},
{'buffer': 2,
...
#/tensorflow/lite/tools/visualize.py
import re

import numpy as np
import tensorflow as tf
from tensorflow.lite.python import schema_py_generated as schema_fb
def BuiltinCodeToName(code):
    """Converts a builtin op code enum to a readable name."""
    # BuiltinOperator is a generated enum-like class: its class dict maps
    # names (e.g. "CONV_2D") to integer codes. Pick the first matching name.
    matches = [
        name
        for name, value in schema_fb.BuiltinOperator.__dict__.items()
        if value == code
    ]
    return matches[0] if matches else None
def CamelCaseToSnakeCase(camel_case_input):
    """Converts an identifier in CamelCase to snake_case."""
    # Pass 1: break before an uppercase letter that starts a capitalized word
    # (e.g. "BuiltinOptions" -> "Builtin_Options").
    step_one = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input)
    # Pass 2: break between a lowercase letter/digit and an uppercase letter
    # (handles acronym boundaries like "HTTPResponse").
    step_two = re.sub("([a-z0-9])([A-Z])", r"\1_\2", step_one)
    return step_two.lower()
def FlatbufferToDict(fb, preserve_as_numpy):
    """Recursively convert a flatbuffer object tree into plain Python types.

    Args:
      fb: a flatbuffer object-API object (e.g. schema_fb.ModelT), a scalar,
        a string, a numpy array, or a sequence of any of these.
      preserve_as_numpy: if True, numpy arrays are returned as-is; otherwise
        they are converted to Python lists. The "buffers" attribute is always
        kept as numpy regardless, so large weight blobs are not expanded into
        huge Python lists.

    Returns:
      A dict / list / scalar mirror of ``fb`` with CamelCase attribute names
      converted to snake_case keys.
    """
    if isinstance(fb, (int, float, str)):
        return fb
    if hasattr(fb, "__dict__"):
        result = {}
        for attribute_name in dir(fb):
            attribute = getattr(fb, attribute_name)
            # Skip methods and private/dunder attributes.
            if callable(attribute) or attribute_name.startswith("_"):
                continue
            snake_name = CamelCaseToSnakeCase(attribute_name)
            # Raw buffer data stays numpy for efficiency; everything else
            # follows the caller's preference.
            preserve = True if attribute_name == "buffers" else preserve_as_numpy
            result[snake_name] = FlatbufferToDict(attribute, preserve)
        return result
    # Note: this check must stay before the generic __len__ branch, since
    # ndarrays also implement __len__.
    if isinstance(fb, np.ndarray):
        return fb if preserve_as_numpy else fb.tolist()
    if hasattr(fb, "__len__"):
        return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb]
    return fb
def CreateDictFromFlatbuffer(buffer_data):
    """Parse raw .tflite flatbuffer bytes into a nested plain-Python dict."""
    root = schema_fb.Model.GetRootAsModel(buffer_data, 0)
    # The object API (ModelT) gives attribute access to every field,
    # which FlatbufferToDict then walks recursively.
    model_t = schema_fb.ModelT.InitFromObj(root)
    return FlatbufferToDict(model_t, preserve_as_numpy=False)
转换
# Read the model.
with open('xxx.tflite', 'rb') as f:
    model_buffer = f.read()

# The interpreter is used later to fetch concrete tensor values by index
# via interpreter.get_tensor(idx).
interpreter = tf.lite.Interpreter(model_content=model_buffer)
interpreter.allocate_tensors()

# Parse the same bytes into a plain dict describing the graph structure.
data = CreateDictFromFlatbuffer(model_buffer)
op_codes = data['operator_codes']   # ops supported/registered by the model
subg = data['subgraphs'][0]         # graph description: the ops that make up the model
tensors = subg['tensors']           # tensor descriptions: layer params and weights

for layer in subg['operators']:
    # Layer name: opcode index -> builtin code -> readable name.
    op_idx = layer['opcode_index']
    op_code = op_codes[op_idx]['builtin_code']
    layer_name = BuiltinCodeToName(op_code)
    # Layer parameters (e.g. strides/padding for a conv op).
    layer_param = layer['builtin_options']
    # Input/output tensor indices of this op.
    input_tensor_idx = layer['inputs']
    output_tensor_idx = layer['outputs']
    # Input tensor.
    input_idx = input_tensor_idx[0]
    # Filter weight. NOTE(review): assumes a conv-style op whose inputs are
    # ordered [input, weight, bias] — confirm per op type before relying on it.
    weight_idx = input_tensor_idx[1]
    weight = interpreter.get_tensor(weight_idx)  # fetch the actual weight values
    # shape[0] of the weight tensor; presumably the output-channel (filter)
    # count for conv weights — verify against the op's weight layout.
    filters = tensors[weight_idx]['shape'][0]
    # Filter bias.
    bias_idx = input_tensor_idx[2]
上述方法在取tensor数值时,用了interpreter.get_tensor(idx)的方式。 实际上tensor数值也可以从data['buffers']中获取,只不过data['buffers']将tensor解析成uint8_t了。
tensors[52]
{'buffer': 53,
'is_variable': False,
'name': [109,
...
}
interpreter.get_tensor(52) //32位浮点
array([[[[ 0.89609855],
[-0.76255393],
[ 0.2671022 ]]],
...
data['buffers'][53]['data'] //8位
array([183, 102, 101, 63, 188, 54, ...
可以自己验证一个数试试
uint8_t a[] = {183, 102,101,63};
printf("%f\n", *(float*)a); // = 0.896099
1.tensor_idx : 每个operator里标注的input、output索引(同netron里显示的location)
buffer = interpreter.get_tensor(tensor_idx) //得到数值
tensors = subg['tensors']
tensor = tensors[tensor_idx] //得到一个Tensor对象
2.buffer_idx:Tensor对象中表述数据位置的索引
buffer_idx = tensor['buffer']
buffer = data['buffers'][buffer_idx] //得到数值