读取keras保存的h5文件,显示各层的权重

 # hdf5的数据结构主要是File - Group - Dataset三级,

# 数据集dataset, 是同一类型数据的多维数组; 组group, 是一种容器结构

# 参考我们的文件系统,不同的文件存放在不同的目录下:

# 目录就是group,描述了数据集DataSet的分类信息,通过group有效的将多种dataset进行管理和划分

# 文件就是dataset,表示具体的数据

测试文件下载  :   blstm_model.h5 和 best_model.weights 
链接: https://pan.baidu.com/s/189lGr5foy4AafwVFGGa3dw 
提取码: gtye
import os
import h5py
import numpy as np

def print_model_h5_wegiths(weight_file_path):
    # weights的tensor保存在Dataset的value中,而每一集都会有attrs保存各网络层的属性

    f = h5py.File(weight_file_path)  # 读取weights h5文件返回File类
    try:
        if len(f.attrs.items()):
            print("{} contains: ".format(f.filename))  # weight_file_path
            print("Root attributes:")
        for key, value in f.attrs.items():
            print("  {}: {}".format(key, value))
            # 输出储存在File类中的attrs信息,一般是各层的名称 : layer_names\ backend \keras_version

        for layer, g in f.items():
            # 读取各层的名称以及包含层信息的Group类
            print("  {}  with Group :  {}".format(layer, g))     # model_weights with Group :  ),
            print("    Attributes:")
            for key, value in g.attrs.items():
                # 输出储存在Group类中的attrs信息,一般是各层的weights和bias及他们的名称
                # eg ;weight_names: [b'attention_2/q_kernel:0' b'attention_2/k_kernel:0' b'attention_2/w_kernel:0']
                print("      {}: {}".format(key, value))
                #
                print("    Dataset:")   # np.array(f.get(key)).shape()
            for name, d in g.items():  # 读取各层储存具体信息的Dataset类
                print('name:   ', name, d)

                if str(f.filename).endswith('.weights'):
                    for k, v in d.items():
                        # 输出储存在Dataset中的层名称和权重,也可以打印dataset的attrs
                        # k , v   embeddings:0 
                        print(' {} with shape : {} or {}  '.format(k, np.array(d.get(k)).shape, np.array(v).shape))
                        print("      {} have weights : {}".format(k, np.array(v)))   # 各层的权重

                if str(f.filename).endswith('.h5'):
                    for k, v in d.items():  # v 等价于  d.get(k)
                        print(k, v)
                        # Adam 

    finally:
        f.close()

print('当前工作路径:', os.getcwd())

model_weights = r'../ckpt/best_model.weights'
print_model_h5_wegiths(model_weights )
print('***'*10)
h5_weight = r'../ckpt/blstm_model.h5'
print_model_h5_wegiths(h5_weight)
 

 

你可能感兴趣的:(读取keras保存的h5文件,显示各层的权重)