pytorch的LSTM层的batch first参数

DataLoader返回数据时候一般第一维都是batch,pytorch的LSTM层默认输入和输出都是batch在第二维。pytorch的LSTM层的batch first参数_第1张图片
如果按照默认的输入和输出结构,可能需要自己定义DataLoader的collate_fn函数,将batch放在第一维。

我一开始就是费了一些劲,捣鼓了半天。后来发现有batch first这个参数,将其设为True就可以将batch放在第一维。(其实一开始看文档的时候注意到了,但是后来写代码忘记它了,回过头来看的时候简直要气死!!)

还有就是使用这个参数的时候有一点要注意,看官方文档:
在这里插入图片描述
设置batch first为true后,input和output都会变为batch在第一维,但是我们有时候也会用到hn和cn,那它们两个是会变呢还是不变呢?
作为懒星人,先去百度了一下,有一篇博客是这样说的:
pytorch的LSTM层的batch first参数_第2张图片
所以我在写代码时就按照博客所说的来写了,但是报错了。。。。
只能自己上手实验了。


```python
import torch.nn as nn
import torch
import numpy as np

model = nn.LSTM(input_size=6, hidden_size=10, num_layers=1, batch_first=True)
model = model.double()

x = np.random.randn(100, 10, 6)

x = torch.from_numpy(x)
print(x.shape)

y, (hn, cn) = model(x)  # 不提供h0和c0,默认全0
print('y:', y.shape)
print('hn:', hn.shape)
print('cn:', cn.shape)

运行结果:

pytorch的LSTM层的batch first参数_第3张图片

根据运行结果来看,设置batch first为true,只有输入input和输出output的batch会在第一维,hn和cn是不会变的。使用的时候要注意,会很容易弄混。
还有就是,这里并没有提供h0和c0,如果需要提供h0和c0,也需要注意shape。

你可能感兴趣的:(pytorch,pytorch,lstm,batch,深度学习,python)