Pytorch训练CRNN输出为NaN

记录一个非常隐蔽pytorch训练audio相关的Bug。

问题描述

用spectrogram作为input feature,训练几轮之后输出和loss中就有NaN,导致神经网络不收敛。经排查每步的输出,feature中没有NaN,但是网络会输出NaN,loss为NaN。

问题复现

网络输入是torchaudio.transforms.Spectrogram(…,power=1)
model.backward()

torchaudio.transforms.Spectrogram默认的power=2,用默认参数不会有这个问题。

问题分析

在pytorch audio的issue中发现类似问题,torchaudio.transforms.Spectrogram在power=1,也就是用magnitude spectrogram作为输入时,反向传递会出现这个问题。

Spectrogram的计算如下:

stft = torch.stft(wav, n_fft, hop_length, win_length, ...)
norm_sq = stft.pow(2.).sum(-1)
result = norm_sq.sqrt()

如果wav为0,则norm_sq为0,取根号后的result也为0,但是result在计算梯度的时候就是NaN, 因此会在反向传递时出问题。

解决方案

eps = 1e-14  # Add eps to ensure .sqrt is not 0

stft = torch.stft(wav, n_fft, hop_length, win_length, ...)
norm_sq = stft.pow(2.).sum(-1)
result = (norm_sq + eps).sqrt()

给要取根号的值加个很小的数,使得取根号后的result不为0.

解决!

参考:https://github.com/pytorch/audio/issues/993

你可能感兴趣的:(八阿哥图鉴,pytorch,深度学习,人工智能,音频)