EANet笔记

 

核心网络 ResNet

pa_avg_pool

pa_max_pool

全局池化（平均池化或最大池化）

import torch.nn as nn


class GlobalPool(object):
    """Callable that pools a feature map down to one vector per sample.

    The pooling operator is chosen once at construction time from
    ``cfg.max_or_avg``: the string ``'avg'`` selects adaptive average
    pooling, and any other value selects adaptive max pooling.
    """

    def __init__(self, cfg):
        """Build the pooling layer.

        Args:
            cfg: object exposing a ``max_or_avg`` attribute.
        """
        if cfg.max_or_avg == 'avg':
            pool_cls = nn.AdaptiveAvgPool2d
        else:
            # Every value other than 'avg' falls through to max pooling.
            pool_cls = nn.AdaptiveMaxPool2d
        # Output size 1 collapses the spatial dimensions to a single cell.
        self.pool = pool_cls(1)

    def __call__(self, in_dict):
        """Pool ``in_dict['feat']`` and flatten it per sample.

        Args:
            in_dict: mapping with a ``'feat'`` entry holding the feature
                tensor (presumably shaped [N, C, H, W] -- confirm with caller).

        Returns:
            A dict with a single key ``'feat_list'`` whose value is a
            one-element list containing the pooled tensor, reshaped to
            ``(size(0), -1)``.
        """
        pooled = self.pool(in_dict['feat'])
        batch_size = pooled.size(0)
        flattened = pooled.view(batch_size, -1)
        return {'feat_list': [flattened]}

 

import torch.nn.functional as F


def pa_avg_pool(in_dict):
    """Mask weighted avg pooling.
    Args:
        feat: pytorch tensor, with shape [N, C, H, W]
        mask: pytorch tensor, with shape [N, pC, pH, pW]
    Returns:
        feat_list: a list (length = pC) of pytorch tensors with shape [N, C]
        visible: pytorch tensor with shape [N, pC]
    """
    feat = in_di

你可能感兴趣的:(深度学习宝典)