【PyTorch】nn.TransformerEncoder 使用 src_key_padding_mask 时出现nan

问题描述:

        在使用nn.TransformerEncoder时,不使用src_key_padding_mask,编码的输出正常,使用src_key_padding_mask后编码结果变成nan了。

ego_transformer_encoder = nn.TransformerEncoder(ego_encoder_layer, num_layers=6)
ego_transformer_features = ego_transformer_encoder(ego_seq2, src_key_padding_mask=src_padding_mask)

分析解决:

        出现nan的原因来自于src_key_padding_mask,src_key_padding_mask 是一个二值化的tensor,在需要被忽略地方应该是True,在需要保留原值的情况下,是False。检查发现src_key_padding_mask全为True,此时会导致编码后结果全为nan。

        解决方法是更新mask或不使用mask。

你可能感兴趣的:(AI/ML/DL,pytorch,深度学习,人工智能)