【PyTorch】max、mean、Max/AvgPool

一、torch.max

torch.max(input, dim, keepdim=False)
参数:
\qquad input: 输入 是一个tensor
\qquad dim: max函数索引的维度
\qquad keepdim: 保持输出的维度
     \qquad\qquad\quad\;\; 当keepdim=False时,输出比输入少一个维度(就是指定的dim求范数的维度)。
     \qquad\qquad\quad\;\; 当keepdim=True时,输出与输入维度相同,仅仅是输出在求范数的维度上元素个数变为1。
当max函数中有维数参数的时候,它的返回值为两个,一个为最大值,另一个为最大值的索引.

1.1、当input为二维时

1、不输入dim, 默认是求所有维度所有元素的最大值

import torch

input = torch.rand((2, 2))
print("input=", input)
max = torch.max(input)
print("max=", max)

【PyTorch】max、mean、Max/AvgPool_第1张图片

2、输入dim=0, 默认是求每一列的最大值

import torch

input = torch.rand((2, 2))
print("input=", input)
max, _ = torch.max(input, dim=0)
print("max=", max)

【PyTorch】max、mean、Max/AvgPool_第2张图片
3、输入dim=1, 默认是求每一行的最大值

import torch

input = torch.rand((2, 2))
print("input=", input)
max, _ = torch.max(input, dim=1)
print("max=", max)

【PyTorch】max、mean、Max/AvgPool_第3张图片
4、测试 keepdim=True

import torch

input = torch.rand((2, 2))
print("input=", input)
max, _ = torch.max(input, dim=1, keepdim=True)
print("max=", max, "max_shape=", max.size())

【PyTorch】max、mean、Max/AvgPool_第4张图片

1.2、当input为三维时

1、不输入dim, 默认是求所有维度所有元素的最大值

import torch

input = torch.rand((2, 2, 2))
print(input)
mean = torch.max(input)
print(mean)

【PyTorch】max、mean、Max/AvgPool_第5张图片
2、输入dim=0, 默认是求每一个位置的元素在所有维度上的最大值

import torch

input = torch.rand((2, 3, 3))
print(input)
mean, _ = torch.max(input, dim=0)
print(mean)

【PyTorch】max、mean、Max/AvgPool_第6张图片

3、输入dim=1, 默认是求每个channel上每一列的最大值

import torch

input = torch.rand((2, 3, 3))
print(input)
mean, _ = torch.max(input, dim=1)
print(mean)

【PyTorch】max、mean、Max/AvgPool_第7张图片
4、输入dim=2, 默认是求每个channel上每一行的最大值

import torch

input = torch.rand((2, 3, 3))
print(input)
mean, _ = torch.max(input, dim=2)
print(mean)

【PyTorch】max、mean、Max/AvgPool_第8张图片
5、测试 keepdim=True

import torch

input = torch.rand((2, 3, 3))
print(input)
mean, _ = torch.max(input, dim=0, keepdim=True)
print(mean)
print(mean.shape)

【PyTorch】max、mean、Max/AvgPool_第9张图片

二、nn.AdaptiveMaxPool2d(1)

求出每个channel的最大值

import torch
from torch import nn

input = torch.rand((2, 3, 3))
print(input)

max_pool = nn.AdaptiveMaxPool2d(1)
max = max_pool(input)
print(max)

【PyTorch】max、mean、Max/AvgPool_第10张图片

三、torch.mean

同上torch.max

四、nn.AdaptiveAvgPool2d(1)

同上nn.AdaptiveMaxPool2d(1)

你可能感兴趣的:(PyTorch,torch.max,torch.mean,MaxPool2d,AvgPool2d)