自适应池化层快速转换为池化层

为什么要将自适应池化层转化为池化层呢?

onnx中的错误:ONNX export of operator adaptive pooling, since output_size is not constant

因为onnx要求模型的输入参数的固定的,而自适应池化层是根据输入来确定模型参数的。

tensorflow等没有自适应池化层,只要pytorch有

自适应池化Adaptive Pooling与标准的Max/AvgPooling区别在于,自适应池化Adaptive Pooling会根据输入的参数来控制输出output_size,而标准的Max/AvgPooling是通过kernel_size,stride与padding来计算output_size: 
                     output_size = ceil ( (input_size+2∗padding−kernel_size)/stride)+1

通常,在池化层中,padding =0

stride = floor ( (input_size / (output_size) )

kernel_size = input_size − (output_size−1) * stride

padding = 0

1.AdaptiveAvgPool1d

import torch as t
import math
import numpy as np
 
alist = t.randn(2,3,9)
 
inputsz = np.array(alist.shape[2:])
outputsz = np.array([4])
 
stridesz = np.floor(inputsz/outputsz).astype(np.int32)
 
kernelsz = inputsz-(outputsz-1)*stridesz
 
adp = t.nn.AdaptiveAvgPool1d(list(outputsz))
avg = t.nn.AvgPool1d(kernel_size=list(kernelsz),stride=list(stridesz))
adplist = adp(alist)
avglist = avg(alist)
 
print(alist)
print(adplist)
print(avglist)

自适应池化层快速转换为池化层_第1张图片

2. AdaptiveAvgPool2d

import torch as t
import math
import numpy as np
 
alist = t.randn(2,6,7)
 
inputsz = np.array(alist.shape[1:])
outputsz = np.array([2,3])
 
stridesz = np.floor(inputsz/outputsz).astype(np.int32)
 
kernelsz = inputsz-(outputsz-1)*stridesz
 
adp = t.nn.AdaptiveAvgPool2d(list(outputsz))
avg = t.nn.AvgPool2d(kernel_size=list(kernelsz),stride=list(stridesz))
adplist = adp(alist)
avglist = avg(alist)
 
print(alist)
print(adplist)
print(avglist)

自适应池化层快速转换为池化层_第2张图片

 3. AdaptiveAvgPool3d

import torch as t
import math
import numpy as np
alist = t.randn(4,3,2,6,7)
 
inputsz = np.array(alist.shape[2:])
print(inputsz)
outputsz = np.array([inputsz[0],2,3])
 
stridesz = np.floor(inputsz/outputsz).astype(np.int32)
print(stridesz)
 
kernelsz = inputsz-(outputsz-1)*stridesz
 
adp = t.nn.AdaptiveAvgPool3d(list(outputsz))
avg = t.nn.AvgPool3d(kernel_size=list(kernelsz),stride=list(stridesz))
adplist = adp(alist)
avglist = avg(alist)
 
print(alist)
print(adplist)
print(avglist)

你可能感兴趣的:(网络,人工智能,深度学习)