Haiku 是一个基于 JAX 的深度学习库,旨在提供简洁、灵活且易于使用的 API,以构建和训练神经网络模型。
import haiku as hk
import jax
import jax.numpy as jnp
### 1. 定义简单的二层神经网络
class SimpleNN(hk.Module):
def __init__(self, hidden_size, output_size):
super().__init__()
self.hidden_size = hidden_size
self.output_size = output_size
def __call__(self, x):
x = hk.Linear(self.hidden_size)(x)
x = jax.nn.relu(x)
x = hk.Linear(self.output_size)(x)
out = jax.nn.sigmoid(x)
return out
### 2. 创建模块实例
# hk.transform将普通的Python函数转换为可训练的Haiku模块。
# 转换后可以进行参数初始化、模块应用等操作。
model = hk.transform(lambda x: SimpleNN(64, 10)(x))
#print(type(model))
#print(model)
### 3. 模块参数初始化
# jax.random.PRNGKey用于伪随机数生成。
# 使用伪随机数生成器(PRNG)可以确保在相同的初始状态下获得相同的随机数序列,从而保持实验的可重复性。
rng = jax.random.PRNGKey(42)
# print(rng)
## 获取初始化的参数,参数的形状需要输入数据的形状以及模型的结构
input_data = jnp.ones((1, 128))
params = model.init(rng, input_data)
## 查看随机初始化的参数,rng保证每次初始化出相同的参数
#print("Initialized Parameters:", params)
#print(params)
print(params['simple_nn/linear']['w'].shape)
print(params['simple_nn/linear_1']['w'].shape)
### 4.模型预测
# apply方法接受模块参数和输入数据,并返回模块的输出数据
# 在模型训练时,apply方法是对整个模块进行前向传播的操作
output_data = model.apply(params, rng, input_data)
print("Output Data:", output_data)