详解Pytorch中的torch.nn.MSELoss函数(包括每个参数的分析)

一、函数介绍

Pytorch中MSELoss函数的接口声明如下,具体网址可以点这里。

torch.nn.MSELoss(size_average=None, reduce=None, reduction=‘mean’)

该函数默认用于计算两个输入对应元素差值平方和的均值。具体地,在深度学习中,可以使用该函数用来计算两个特征图的相似性。

二、使用方式
import torch

# input和target分别为MESLoss的两个输入
input = torch.tensor([0.,0.,0.])
target = torch.tensor([1.,2.,3.])

# MSELoss函数的具体使用方法如下所示,其中MSELoss函数的参数均为默认参数。
loss = torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')
loss = loss(input, target)

print(loss)

# input和target逐元素差值平方和的均值计算如下,可以看到与上述MSELoss函数的返回值相同。
# 证明了MSELoss默认用于计算两个输入逐元素差值平方和的均值。
print(((1-0)*(1-0)+(2-0)*(2-0)+(3-0)*(3-0))/3.)

在这里插入图片描述

三、参数介绍

如果同时给出了reduce、size_average、reduction三个参数,则首先看前两个参数。如果前两个参数均为None,则函数的返回值由reduction参数决定。如果前两个参数不全为None,则函数的返回值由前两个参数决定,在这种情况下,为None的那个参数默认为True。确定了三个参数的取值后,根据下述规则进行计算即可:

  • reduce=True时,若size_average=True,则返回一个batch中所有样本损失的均值,结果为标量。注意,对于MESLoss函数来说,首先对该batch中的所有样本损失进行逐元素均值操作,然后对得到N个值再进行均值操作即得到返回值(假设批大小为N,即该batch中共有N个样本),用官网的话来说,就是The mean operation still operates over all the elements, and divides by N.
  • reduce=True时,若size_average=False,则返回一个batch中所有样本损失的和,结果为标量注意,对于MESLoss函数来说,首先对该batch中的所有样本损失进行逐元素求和操作,然后对得到N个值再进行求和操作即得到返回值(假设批大小为N,即该batch中共有N个样本),用官网的话来说,就是The sum operation still operates over all the elements.
  • reduce=False时,则size_average参数失效,即无论size_average参数为False还是True,效果都是一样的。此时,函数返回的是一个batch中每个样本的损失,结果为向量
  • reduction参数包含了reduce和size_average参数的双重含义。即,当reduction=‘none’时,相当于reduce=False;当reduction=‘sum’时,相当于reduce=True且size_average=False;当reduction=‘mean’时,相当于reduce=True且size_average=True;这也是为什么reduce和size_average参数将在后续版本中被弃用的原因

实际上,大家在使用该函数时完全不用考虑地这么细致。上面之所以分析地这么细致只是想系统地对该函数进行一个分析讲解,用于帮助那些喜欢深究的同学。如果你只是想快速地使用该函数,只需要将前两个参数即reduce和size_average参数置为None,然后对reduction进行传参即可;由于该函数的前两个参数本身就默认为None,因此只需要对reduction进行传参即可,具体使用例子可以参考第四部分。

四、实例讲解

1.当reduction='mean’时,即返回一个batch中所有样本损失的均值。

import torch
import torch.nn.functional as F
input = [[[0.,0.,0.],
          [0.,0.,0.],
          [0.,0.,0.]],

         [[0.,0.,0.],
          [0.,0.,0.],
          [0.,0.,0.]]]
input = torch.tensor(input)

target = [[[1.,2.,3.],
           [4.,5.,6.],
           [7.,8.,9.]],

          [[11.,12.,13.],
           [14.,15.,16.],
           [17.,18.,19.]]]
target = torch.tensor(target)

loss = torch.nn.MSELoss(reduction='mean') # loss = torch.nn.MSELoss()效果相同,因为reduction参数默认为'mean'。
loss = loss(input, target)
print(loss)

# 注意,下式最后除以2是指该函数输入的批大小为2;下式中除以9是指该函数输入的批数据中每个样本的元素个数为9。
mean_result = ((1.*1. + 2.*2. + 3.*3. + 4.*4. + 5.*5. + 6.*6. + 7.*7. + 8.*8. + 9.*9.)/9 + (11.*11. + 12.*12. + 13.*13. + 14.*14. + 15.*15. + 16.*16. + 17.*17. + 18.*18. + 19.*19.)/9) / 2
print(mean_result)

在这里插入图片描述

2.当reduction='sum’时,即返回一个batch中所有样本损失的和。

import torch
import torch.nn.functional as F
input = [[[0.,0.,0.],
          [0.,0.,0.],
          [0.,0.,0.]],

         [[0.,0.,0.],
          [0.,0.,0.],
          [0.,0.,0.]]]
input = torch.tensor(input)

target = [[[1.,2.,3.],
           [4.,5.,6.],
           [7.,8.,9.]],

          [[11.,12.,13.],
           [14.,15.,16.],
           [17.,18.,19.]]]
target = torch.tensor(target)

loss = torch.nn.MSELoss(reduction='sum')
loss = loss(input, target)
print(loss)
sum_result = ((1.*1. + 2.*2. + 3.*3. + 4.*4. + 5.*5. + 6.*6. + 7.*7. + 8.*8. + 9.*9.) + (11.*11. + 12.*12. + 13.*13. + 14.*14. + 15.*15. + 16.*16. + 17.*17. + 18.*18. + 19.*19.))
print(sum_result)

在这里插入图片描述

3.当reduction=‘none’时,即返回的是一个batch中每个样本的损失。

import torch
import torch.nn.functional as F
input = [[[0.,0.,0.],
          [0.,0.,0.],
          [0.,0.,0.]],

         [[0.,0.,0.],
          [0.,0.,0.],
          [0.,0.,0.]]]
input = torch.tensor(input)

target = [[[1.,2.,3.],
           [4.,5.,6.],
           [7.,8.,9.]],

          [[11.,12.,13.],
           [14.,15.,16.],
           [17.,18.,19.]]]
target = torch.tensor(target)

loss = torch.nn.MSELoss(reduction='none')
loss = loss(input, target)
print(loss)

详解Pytorch中的torch.nn.MSELoss函数(包括每个参数的分析)_第1张图片

五、参考链接
  • https://blog.csdn.net/u013548568/article/details/81532605
  • https://blog.csdn.net/RadiantJeral/article/details/86585152

你可能感兴趣的:(深度学习,pytorch,深度学习,人工智能)