A custom Linear module in Haiku

import haiku as hk
from typing import Union, Sequence
import jax.numpy as jnp
import jax
import numbers
import numpy as np
import pickle
import copy

# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978,
                                            dtype=np.float32)
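
A quick sanity check, assuming scipy is installed, recomputes this factor directly:

from scipy.stats import truncnorm
print(truncnorm.std(a=-2, b=2, loc=0., scale=1.))  # ≈ 0.8796256610342398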


def get_initializer_scale(initializer_name, input_shape):
  """Get Initializer for weights and scale to multiply activations by."""

  if initializer_name == 'zeros':
    w_init = hk.initializers.Constant(0.0)
  else:
    # fan-in scaling
    scale = 1.
    for channel_dim in input_shape:
      # Divide by each input channel dimension (fan-in scaling).
      scale /= channel_dim
    if initializer_name == 'relu':
      scale *= 2

    noise_scale = scale

    stddev = np.sqrt(noise_scale)
    # Adjust stddev for truncation.
    stddev = stddev / TRUNCATED_NORMAL_STDDEV_FACTOR
    # Truncated normal: samples lie within (mean - 2 * stddev, mean + 2 * stddev).
    w_init = hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev)
    
    # The hk.initializers module provides a collection of parameter
    # initializers. Commonly used ones include:
    # 1. hk.initializers.Constant(value)
    # 2. hk.initializers.RandomNormal(stddev=1.0)
    # 3. hk.initializers.TruncatedNormal(stddev=1.0)
    # 4. hk.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='truncated_normal')
    # 5. hk.initializers.Orthogonal(gain=1.0)
    # 6. hk.initializers.Identity(gain=1.0)
  return w_init
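
To make the scaling arithmetic concrete, a worked example (the fan-in of 256 is a made-up value for illustration): with initializer_name='relu', scale = 2/256, and the raw stddev sqrt(2/256) ≈ 0.0884 is widened to compensate for the variance lost to truncation:

print(np.sqrt(2 / 256) / TRUNCATED_NORMAL_STDDEV_FACTOR)  # ≈ 0.1005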


class Linear(hk.Module):
  """Protein folding specific Linear module.

  This differs from the standard Haiku Linear in a few ways:
    * It supports inputs and outputs of arbitrary rank
    * Initializers are specified by strings
  """

  def __init__(self,
               num_output: Union[int, Sequence[int]],
               initializer: str = 'linear',
               num_input_dims: int = 1,
               use_bias: bool = True,
               bias_init: float = 0.,
               precision=None,
               name: str = 'linear'):
    """Constructs Linear Module.

    Args:
      num_output: Number of output channels. Can be tuple when outputting
          multiple dimensions.
      initializer: What initializer to use, should be one of {'linear', 'relu',
        'zeros'}
      num_input_dims: Number of dimensions from the end to project.
      use_bias: Whether to include trainable bias
      bias_init: Value used to initialize bias.
      precision: What precision to use for matrix multiplication, defaults
        to None.
      name: Name of module, used for name scopes.
    """
    super().__init__(name=name)
    if isinstance(num_output, numbers.Integral):
      self.output_shape = (num_output,)
    else:
      self.output_shape = tuple(num_output)
    self.initializer = initializer
    self.use_bias = use_bias
    self.bias_init = bias_init
    self.num_input_dims = num_input_dims
    self.num_output_dims = len(self.output_shape)
    self.precision = precision

  def __call__(self, inputs):
    """Connects Module.

    Args:
      inputs: Tensor with at least num_input_dims dimensions.

    Returns:
      output of shape [...] + num_output.
    """

    if self.num_input_dims > 0:
      in_shape = inputs.shape[-self.num_input_dims:]
    else:
      in_shape = ()
    # Note the distribution used to initialize the weights: fan-in scaling
    # keeps the activation variance stable at initialization.
    weight_init = get_initializer_scale(self.initializer, in_shape)
    
    in_letters = 'abcde'[:self.num_input_dims]
    out_letters = 'hijkl'[:self.num_output_dims]
    
    # The weight shape is the input channel shape concatenated with the output shape.
    weight_shape = in_shape + self.output_shape
    
    # hk.get_parameter fetches (or creates) a parameter in the current module
    # scope. Its four arguments are:
    # 1. name: a string that uniquely identifies the parameter.
    # 2. shape: the shape of the parameter.
    # 3. dtype: the data type of the parameter.
    # 4. init: the initializer (e.g. from hk.initializers) that sets the initial value.
    weights = hk.get_parameter('weights', weight_shape, inputs.dtype, weight_init)

    equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}'
    # equation is a plain string, e.g. '...abc, abch->...h' when
    # num_input_dims=3 and num_output_dims=1.
    
    # jnp.einsum and jnp.dot can both perform matrix multiplication, but
    # einsum takes the contraction rule as a string, which gives it the
    # flexibility to express more general tensor contractions like this one.
    output = jnp.einsum(equation, inputs, weights, precision=self.precision)

    if self.use_bias:
      bias = hk.get_parameter('bias', self.output_shape, inputs.dtype,
                              hk.initializers.Constant(self.bias_init))
      output += bias

    return output
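

Because the Linear module supports arbitrary input and output ranks, it can project several trailing axes onto a multi-dimensional output in one step. A minimal sketch (the shapes and the name 'rank_demo' are hypothetical, for illustration only):

# Project the last two input axes (num_input_dims=2) onto a 2-D output shape.
rank_demo = hk.transform(
    lambda x: Linear((4, 5), num_input_dims=2, name='rank_demo')(x))
demo_x = jnp.ones((3, 8, 16))  # one batch axis, trailing feature axes (8, 16)
demo_params = rank_demo.init(jax.random.PRNGKey(0), demo_x)
print(rank_demo.apply(demo_params, None, demo_x).shape)  # (3, 4, 5)
# The einsum equation built internally is '...ab, abhi->...hi'.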


with open("Human_HBB_tensor_dict_ensembled.pkl",'rb') as f:
  Human_HBB_tensor_dict = pickle.load(f)

#for k, v in Human_HBB_tensor_dict.items():
#    print(v.shape)

#Human_HBB_tensor_dict.keys()

batch = copy.deepcopy(Human_HBB_tensor_dict)

print(batch['aatype'].shape) 
print(batch['msa_feat'].shape)
#print(batch['msa_feat'])

msa_channel = 16  # kept small for this demo; larger values slow the computation down

input_data = batch['msa_feat'].numpy()
print(type(input_data))

# Wrap the module in hk.transform to obtain a pure (init, apply) pair.
model = hk.transform(lambda x: Linear(msa_channel,
                                      name='preprocess_msa',
                                      num_input_dims=1)(x))

print(model)

rng = jax.random.PRNGKey(42)
# print(rng)

## Initialize the parameters; their shapes are derived from the input shape
## and the module structure.
params = model.init(rng, input_data)
# print(params)
print("params weights shape:") 
print(params['preprocess_msa']['weights'].shape)
print("params weights bias:")
print(params['preprocess_msa']['bias'].shape)

output_data = model.apply(params, rng, input_data)
print("input_data shape:", input_data.shape) 
print("Output Data shape:", output_data.shape)
#print("Output Data:", output_data)

References:
Haiku initializers (hk.initializers)

jnp.einsum
