pytorch里,view函数使用需要注意的点

背景概要:整整俩天终于找到问题的原因了,竟然是这个小小view函数导致的!我之前在代码里使用view函数重塑数据的形状,但是训练的时候只要batchsize大于1,网络就不收敛。batchsize设为1虽然损失会降,但是震荡太剧烈了。仔细检查代码,终于发现问题所在了。


我的代码如下:

x=self.embedding1(x_convs1.view(b*T,c*h))
x = x.view(T, b, -1)

如果不考虑数据间的顺序,这样做没什么问题,但是如果数据间的顺序是很重要的话,那么我上述的view函数使用就是错误的,因为打乱了数据顺序。
怎么理解?举个列子,文本识别任务,b代表batchsize,T代表LSTM输出个数
x经过self.embedding1后的形状为[48,68],b为2,T为24,实际上x的形状是这48个数据按顺序一行一行堆起来的,像这样(A)

pytorch里,view函数使用需要注意的点_第1张图片

如果view成[24,2,68],实际上就成了这样的形状 (B)

pytorch里,view函数使用需要注意的点_第2张图片

 现在取第一个样本LSTM的第一个输出,A[0]=B[0][0](都是序号1的数据),好,再取第一个样本LSTM的第二个输出,A[1](序号2的数据)!=B[1][0](序号3的数据),问题显现,这就是为什么我的batchsize设置为1,损失能正常下降,而设置为其他,网络却不收敛。
正解为:

x=self.embedding1(x_convs1.view(b*T,c*h))
x = x.view(b, T, -1)

结论:使用view的时候,一定要考虑是否打乱了数据间的顺序。或者说,你想要的顺序和实际你得到的顺序是否一致!

你可能感兴趣的:(人工智能,pytorch,view函数)