Code example
import torch.nn as nn

class LSTMClassifier(nn.Module):
    """
    This is the simple RNN model we will be using to perform Sentiment Analysis.
    """

    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        """
        Initialize the model by setting up the various layers.
        """
        super(LSTMClassifier, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.dense = nn.Linear(in_features=hidden_dim, out_features=1)
        self.sig = nn.Sigmoid()

        self.word_dict = None

    def forward(self, x):
        """
        Perform a forward pass of our model on some input.
        """
        x = x.t()                # (batch, 1 + seq_len) -> (1 + seq_len, batch)
        lengths = x[0, :]        # first row holds each review's true length
        reviews = x[1:, :]       # remaining rows are the padded token ids
        embeds = self.embedding(reviews)
        lstm_out, _ = self.lstm(embeds)
        out = self.dense(lstm_out)
        out = out[lengths - 1, range(len(lengths))]  # output at each review's last real token
        return self.sig(out.squeeze())
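The forward pass above assumes a specific batch layout: each row of x is [length, token_1, ..., token_T], so after the transpose the first row holds each review's true length and the remaining rows hold the padded token ids. A minimal sketch of calling the model with that layout (the batch size, sequence length, and random data are assumptions for illustration):

import torch

model = LSTMClassifier(embedding_dim=32, hidden_dim=100, vocab_size=5000)

batch_size, seq_len = 4, 500
tokens = torch.randint(1, 5000, (batch_size, seq_len))    # padded token ids
lengths = torch.randint(1, seq_len + 1, (batch_size, 1))  # true length of each review
batch_X = torch.cat([lengths, tokens], dim=1)             # column 0 carries the length

y_pred = model(batch_X)   # calling the instance dispatches to forward()
print(y_pred.shape)       # torch.Size([4]) -- one sentiment score per review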
def train(model, train_loader, epochs, optimizer, loss_fn, device):
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0
        for batch in train_loader:
            batch_X, batch_y = batch
            print("batch_X.shape=", batch_X.shape)   # debug: inspect batch shapes
            print("batch_y.shape=", batch_y.shape)
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            # get predictions from the model; model(batch_X) dispatches to forward()
            y_pred = model(batch_X)
            # compute the loss and backpropagate
            loss = loss_fn(y_pred, batch_y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        print("Epoch: {}, BCELoss: {}".format(epoch, total_loss / len(train_loader)))
When the LSTMClassifier instance model is called like a function, it is actually forward() that runs, because nn.Module defines a __call__ method that invokes forward.
# A typical third-party call site for the model looks like this:
import torch
import torch.optim as optim
from train.model import LSTMClassifier

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTMClassifier(32, 100, 5000).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.BCELoss()

train(model, train_sample_dl, 5, optimizer, loss_fn, device)
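train_sample_dl is not defined in this snippet; it is expected to be a DataLoader yielding (batch_X, batch_y) pairs in the layout described above. A minimal sketch with random data, purely so the example runs end to end (all shapes and sizes here are assumptions):

from torch.utils.data import DataLoader, TensorDataset

tokens = torch.randint(1, 5000, (64, 500))       # 64 reviews padded to 500 tokens
lengths = torch.randint(1, 501, (64, 1))         # true lengths go in column 0
sample_X = torch.cat([lengths, tokens], dim=1)
sample_y = torch.randint(0, 2, (64,)).float()    # binary sentiment labels for BCELoss

train_sample_dl = DataLoader(TensorDataset(sample_X, sample_y), batch_size=16)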
# The parent class nn.Module defines a __call__ method that calls forward, as shown below
# (source excerpt from an older PyTorch release; recent versions move this logic into _call_impl):
def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result
The keys to understanding this behavior are __call__ and the hook dictionaries (_forward_pre_hooks, _forward_hooks, _backward_hooks); the sketch below shows the order in which they run.
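A minimal sketch demonstrating that order (the Doubler module and the printed labels are made up for illustration): registering a forward pre-hook and a forward hook on a module shows that __call__ runs the pre-hooks first, then forward, then the forward hooks.

import torch
import torch.nn as nn

class Doubler(nn.Module):
    def forward(self, x):
        print("2. forward")
        return x * 2

m = Doubler()
m.register_forward_pre_hook(lambda module, inp: print("1. forward pre-hook"))
m.register_forward_hook(lambda module, inp, out: print("3. forward hook"))

m(torch.ones(3))
# prints:
# 1. forward pre-hook
# 2. forward
# 3. forward hook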
References
https://www.cnblogs.com/SBJBA/p/11355412.html
Method-by-method analysis of nn.Module:
https://blog.csdn.net/lishuiwang/article/details/104505675/
https://www.jb51.net/article/184033.htm
nn.Module source code:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py