在 PyTorch 中,写模型就像搭积木:你需要灵活、清晰地把各个“模块”拼成一个能跑的网络。而本章就是你的“模型搭建说明书”,从最基础的 nn.Module
原理,到多输入输出、模型调试保存技巧一网打尽。
你将收获:
nn.Module
到底帮我们做了什么?nn.Module
这几乎是所有 PyTorch 教材、工程项目的主流写法:
class MLP(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.layer1 = nn.Linear(input_dim, hidden_dim)
self.act = nn.ReLU()
self.layer2 = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
x = self.act(self.layer1(x))
x = self.layer2(x)
return x
清晰、灵活、可调试。推荐指数:
nn.Sequential
model = nn.Sequential(
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 10)
)
结构简单写得快,但不支持分支、跳连、多个输入输出。
适合原型开发或结构简单的 MLP。推荐指数:
import torch.nn.functi