Introduction to the hk.LayerNorm Module

hk.LayerNorm is the Haiku module that implements Layer Normalization. Layer normalization is a technique for normalizing neural-network activations, intended to improve training stability and generalization; a minimal numerical sketch follows the parameter list below.

Key parameters:

  • axis (required): the axis or axes to normalize over. Typically the last axis, so that normalization runs over the input features.

  • create_scale (required): whether to create a learnable scale parameter. If True, a learnable scale is created to rescale the magnitude of the normalized values.

  • create_offset (required): whether to create a learnable offset parameter. If True, a learnable offset is created to shift the normalized values.

  • eps (default 1e-5): a small positive constant added to the variance to avoid division by zero.
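As a minimal numerical sketch (the 2×4 toy input is invented here, not the MSA data used later), layer normalization subtracts the mean and divides by the standard deviation along the chosen axis, then applies the learned scale and offset. At initialization the scale is 1 and the offset is 0, so a manual computation should match hk.LayerNorm:

import haiku as hk
import jax
import jax.numpy as jnp

def manual_layer_norm(x, eps=1e-5):
  # Zero mean, unit variance along the last axis.
  mean = jnp.mean(x, axis=-1, keepdims=True)
  var = jnp.var(x, axis=-1, keepdims=True)
  return (x - mean) / jnp.sqrt(var + eps)

toy = jnp.arange(8.0).reshape(2, 4)  # made-up toy input
ln = hk.transform(lambda x: hk.LayerNorm(
    axis=-1, create_scale=True, create_offset=True)(x))
p = ln.init(jax.random.PRNGKey(0), toy)
print(jnp.allclose(ln.apply(p, None, toy), manual_layer_norm(toy), atol=1e-5))  # True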

import haiku as hk
import jax
import jax.numpy as jnp
import pickle

### Custom LayerNorm module
class LayerNorm(hk.LayerNorm):
  """LayerNorm module.

  Equivalent to hk.LayerNorm but with different parameter shapes: they are
  always vectors rather than possibly higher-rank tensors. This makes it easier
  to change the layout whilst keeping the model weight-compatible.
  """

  def __init__(self,
               axis,
               create_scale: bool,
               create_offset: bool,
               eps: float = 1e-5,
               scale_init=None,
               offset_init=None,
               use_fast_variance: bool = False,
               name=None,
               param_axis=None):
    super().__init__(
        axis=axis,
        create_scale=False,
        create_offset=False,
        eps=eps,
        scale_init=None,
        offset_init=None,
        use_fast_variance=use_fast_variance,
        name=name,
        param_axis=param_axis)
    # Parameter creation is deferred to __call__ below, so the parent class
    # is told not to create scale/offset itself.
    self._temp_create_scale = create_scale
    self._temp_create_offset = create_offset

    # Equivalent explicit initializers (Haiku's defaults when none are given):
    # self.scale_init = hk.initializers.Constant(1)
    # self.offset_init = hk.initializers.Constant(0)
    
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    is_bf16 = (x.dtype == jnp.bfloat16)
    if is_bf16:
      x = x.astype(jnp.float32)

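    # param_axis selects the single axis that carries the parameters
    # (default: the last axis), so the parameters are always 1-D vectors
    # of length x.shape[param_axis]; they are reshaped below so they
    # broadcast against the full input shape.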
    param_axis = self.param_axis[0] if self.param_axis else -1
    param_shape = (x.shape[param_axis],)

    param_broadcast_shape = [1] * x.ndim
    param_broadcast_shape[param_axis] = x.shape[param_axis]
    scale = None
    offset = None
    
    # The scale/offset shapes must be broadcastable to the input shape.
    # If self.scale_init / self.offset_init are not set explicitly,
    # Haiku's defaults apply (ones for scale, zeros for offset), the same
    # values as the commented-out explicit initializers in __init__.
    if self._temp_create_scale:
      scale = hk.get_parameter(
          'scale', param_shape, x.dtype, init=self.scale_init)
      scale = scale.reshape(param_broadcast_shape)

    if self._temp_create_offset:
      offset = hk.get_parameter(
          'offset', param_shape, x.dtype, init=self.offset_init)
      offset = offset.reshape(param_broadcast_shape)

    out = super().__call__(x, scale=scale, offset=offset)

    if is_bf16:
      out = out.astype(jnp.bfloat16)

    return out

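# A quick demonstration of the docstring's claim, using a made-up toy input:
# when the parameters live on an interior axis, hk.LayerNorm creates
# higher-rank parameters padded with 1s, while the subclass above always
# creates plain vectors (the expected shapes below assume recent Haiku
# behaviour).
toy = jnp.ones((2, 3, 4))
f_hk = hk.transform(lambda x: hk.LayerNorm(
    axis=[-1], create_scale=True, create_offset=True,
    param_axis=[1], name='ln')(x))
f_custom = hk.transform(lambda x: LayerNorm(
    axis=[-1], create_scale=True, create_offset=True,
    param_axis=[1], name='ln')(x))
key = jax.random.PRNGKey(0)
print(f_hk.init(key, toy)['ln']['scale'].shape)      # expected: (1, 3, 1)
print(f_custom.init(key, toy)['ln']['scale'].shape)  # expected: (3,)
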

with open("Human_HBB_tensor_dict_ensembled.pkl",'rb') as f:
  Human_HBB_tensor_dict = pickle.load(f)

input_data = jnp.array(Human_HBB_tensor_dict['msa_feat'])
print(input_data.shape)
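# 'msa_feat' is assumed here to be an AlphaFold-style MSA feature tensor of
# shape (num_sequences, num_residues, num_channels); only its last (channel)
# axis is normalized below.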

# Wrap in a Haiku transform:
# a LayerNorm layer that normalizes over the last axis (the features) of the
# input and creates learnable scale and offset parameters.
model = hk.transform(lambda x: LayerNorm(axis=[-1], 
                                         create_scale=True,
                                         create_offset=True,
                                         name='msa_feat_norm')(x))
 
print(model)
                                     
## Get the initialized parameters; their shapes are determined by the input
## shape and the model structure.
rng = jax.random.PRNGKey(42)
params = model.init(rng, input_data)
print(params) 
print("params scale shape:")
print(params['msa_feat_norm']['scale'].shape)
print("params offset shape:")
print(params['msa_feat_norm']['offset'].shape)
 
output_data = model.apply(params, rng, input_data)
print("input_data shape:", input_data.shape) 
print("Output Data shape:", output_data.shape)
#print("raw input:", input_data)
print("after LayerNorm:", output_data)
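
# Sanity check: at initialization scale=1 and offset=0, so the output should
# have roughly zero mean and unit variance along the normalized (last) axis.
print("per-row mean (~0):", jnp.mean(output_data, axis=-1).ravel()[:5])
print("per-row std  (~1):", jnp.std(output_data, axis=-1).ravel()[:5])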

### Using the original hk.LayerNorm module
model2 = hk.transform(lambda x: hk.LayerNorm(axis=[-1], 
                                            create_scale=True,
                                            create_offset=True,
                                            name='msa_feat_norm')(x))
 
print(model2)
                                     
params2 = model2.init(rng, input_data)
print(params2) 
print("params2 scale shape:") 
print(params2['msa_feat_norm']['scale'].shape)
print("params2 offset shape:")
print(params2['msa_feat_norm']['offset'].shape)
 
output_data2 = model2.apply(params2, rng, input_data)
print("input_data shape:", input_data.shape) 
print("Output Data shape:", output_data2.shape)
#print("raw input:", input_data)
print("after LayerNorm:", output_data2)
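
# Both wrappers are initialized identically (scale=1, offset=0) and, with
# axis=[-1] and the default param_axis, both create vector parameters of the
# same shape, so their outputs should agree:
print("outputs match:", jnp.allclose(output_data, output_data2, atol=1e-5))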

Reference:

https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=layernorm#layernorm
