Understanding the WS-DAN.PyTorch Code

WS-DAN.PyTorch

Code author: GuYuc

train.py

import os
import time
import logging
import warnings
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import config
from models import WSDAN
from datasets import get_trainval_datasets
from utils import CenterLoss, AverageMeter, TopKAccuracyMetric, ModelCheckpoint, batch_augment

1. GPU settings

# GPU settings
assert torch.cuda.is_available()
os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU
device = torch.device("cuda")
torch.backends.cudnn.benchmark = True

torch.backends.cudnn.benchmark = True lets cuDNN's built-in auto-tuner search for the most efficient algorithms for the current configuration, improving runtime performance.

2. Loss functions

# General loss functions
cross_entropy_loss = nn.CrossEntropyLoss()
center_loss = CenterLoss()

The cross-entropy formula computed by CrossEntropyLoss:
$$loss(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_{j}\exp(x[j])}\right)$$
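
A quick numerical check of this formula (a standalone sketch with made-up logits, not repo code):

import torch
import torch.nn as nn

logits = torch.tensor([[1.0, 2.0, 0.5]])  # x: one sample, three classes
target = torch.tensor([1])                # class

# -log(exp(x[class]) / sum_j exp(x[j]))
manual = -torch.log(torch.exp(logits[0, 1]) / torch.exp(logits[0]).sum())
builtin = nn.CrossEntropyLoss()(logits, target)
print(manual.item(), builtin.item())  # the two values match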

3. Evaluation metrics

# loss and metric
loss_container = AverageMeter(name='loss')
raw_metric = TopKAccuracyMetric(topk=(1, 5))
crop_metric = TopKAccuracyMetric(topk=(1, 5))
drop_metric = TopKAccuracyMetric(topk=(1, 5))

AverageMeter maintains a running average (e.g., of the loss) across batches. TopKAccuracyMetric: takes the k labels with the highest predicted probability and counts a sample as correct if the target is among them.

class AverageMeter(Metric):
    def __init__(self, name='loss'):
        self.name = name
        self.reset()

    def reset(self):
        self.scores = 0.
        self.total_num = 0.

    def __call__(self, batch_score, sample_num=1):
        self.scores += batch_score
        self.total_num += sample_num
        return self.scores / self.total_num
class TopKAccuracyMetric(Metric):
    def __init__(self, topk=(1,)):
        self.name = 'topk_accuracy'
        self.topk = topk
        self.maxk = max(topk)
        self.reset()

    def reset(self):
        self.corrects = np.zeros(len(self.topk))
        self.num_samples = 0.

    def __call__(self, output, target):
        """Computes the precision@k for the specified values of k"""
        self.num_samples += target.size(0)
        _, pred = output.topk(self.maxk, 1, True, True)  # .topk returns the maxk largest logits per row (values, indices)
        pred = pred.t()  # .t() transposes to (maxk, batch)
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        for i, k in enumerate(self.topk):
            correct_k = correct[:k].view(-1).float().sum(0)
            self.corrects[i] += correct_k.item()

        return self.corrects * 100. / self.num_samples
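
To see how .topk(), .t(), and .eq() interact here, a minimal standalone run (made-up logits, not repo code):

import torch

output = torch.tensor([[0.1, 0.5, 0.4],
                       [0.8, 0.1, 0.1]])  # (batch=2, classes=3)
target = torch.tensor([2, 0])
maxk = 2

_, pred = output.topk(maxk, 1, True, True)  # per-row indices of the 2 largest logits
pred = pred.t()                             # shape (maxk, batch)
correct = pred.eq(target.view(1, -1).expand_as(pred))
print(correct)  # [[False, True], [True, False]]: top-1 gets sample 1 only, top-2 gets both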

4. def main()

Initialization and dataset loading

    ##################################
    # Initialize saving directory
    ##################################
    if not os.path.exists(config.save_dir):
        os.makedirs(config.save_dir)

    ##################################
    # Logging setting
    ##################################
    logging.basicConfig(
        filename=os.path.join(config.save_dir, config.log_name),
        filemode='w',
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

The logging module is part of Python's standard library and is used to emit runtime logs; it supports configuring the log level, the log file path, log rotation, and so on.

Log level INFO: confirms that everything is running as expected.

filename: the name of the log file.

filemode: the mode the log file is opened with, 'w' (overwrite) or 'a' (append).

format: the layout of each record. %(asctime)s: timestamp of the record. %(levelname)s: level name. %(filename)s: name of the source file. %(lineno)d: line number of the logging call. %(message)s: the log message.

warnings.filterwarnings("ignore") installs a filter that suppresses warnings.

    ##################################
    # Load dataset
    ##################################
    train_dataset, validate_dataset = get_trainval_datasets(config.tag,config.image_size)

    train_loader, validate_loader = DataLoader(train_dataset,
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=config.workers,
                                               pin_memory=True), \
                                    DataLoader(validate_dataset,
                                               batch_size=config.batch_size * 4,
                                               shuffle=False,
                                               num_workers=config.workers,
                                               pin_memory=True)
    num_classes = train_dataset.num_classes

batch_size comes from config.batch_size; shuffle=True reshuffles the training data every epoch, while shuffle=False keeps the validation order fixed; num_workers=config.workers loads batches with that many worker subprocesses. pin_memory=True allocates the returned tensors in page-locked (pinned) host memory, which speeds up copying them from host memory to GPU memory.
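
A minimal illustration of these arguments on a toy dataset (the toy tensors here are made up, not repo code):

import torch
from torch.utils.data import DataLoader, TensorDataset

toy = TensorDataset(torch.randn(8, 3, 32, 32), torch.arange(8))
loader = DataLoader(toy,
                    batch_size=4,
                    shuffle=True,                          # reshuffle every epoch
                    num_workers=0,                         # 0 = load in the main process
                    pin_memory=torch.cuda.is_available())  # pinned host memory -> faster copies to GPU
for images, labels in loader:
    print(images.shape, labels)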

get_trainval_datasets() is defined in datasets/__init__.py, which imports bird_dataset.py and the other dataset modules.

__init__.py
from .aircraft_dataset import AircraftDataset
from .bird_dataset import BirdDataset
from .car_dataset import CarDataset
from .dog_dataset import DogDataset

def get_trainval_datasets(tag, resize):
    if tag == 'aircraft':
        return AircraftDataset(phase='train', resize=resize), AircraftDataset(phase='val', resize=resize)
    elif tag == 'bird':
        return BirdDataset(phase='train', resize=resize), BirdDataset(phase='val', resize=resize)
    elif tag == 'car':
        return CarDataset(phase='train', resize=resize), CarDataset(phase='val', resize=resize)
    elif tag == 'dog':
        return DogDataset(phase='train', resize=resize), DogDataset(phase='val', resize=resize)
    else:
        raise ValueError('Unsupported Tag {}'.format(tag))

Selects the dataset class by tag, loads the training and validation splits via phase='train' and phase='val', and resizes the images.

bird_dataset.py:

image_path = {}   # module-level: image id -> relative path, filled in __init__
image_label = {}  # module-level: image id -> class label

class BirdDataset(Dataset):
    def __init__(self, phase='train', resize=500):
        assert phase in ['train', 'val', 'test']
        self.phase = phase
        self.resize = resize
        self.image_id = []
        self.num_classes = 200

        # get image path from images.txt
        with open(os.path.join(DATAPATH, 'images.txt')) as f:
            for line in f.readlines():
                id, path = line.strip().split(' ')
                image_path[id] = path

        # get image label from image_class_labels.txt
        with open(os.path.join(DATAPATH, 'image_class_labels.txt')) as f:
            for line in f.readlines():
                id, label = line.strip().split(' ')
                image_label[id] = int(label)

        # get train/test image id from train_test_split.txt
        with open(os.path.join(DATAPATH, 'train_test_split.txt')) as f:
            for line in f.readlines():
                image_id, is_training_image = line.strip().split(' ')
                is_training_image = int(is_training_image)

                if self.phase == 'train' and is_training_image:
                    self.image_id.append(image_id)
                if self.phase in ('val', 'test') and not is_training_image:
                    self.image_id.append(image_id)

        # transform
        self.transform = get_transform(self.resize, self.phase)

    def __getitem__(self, item):
        # get image id
        image_id = self.image_id[item]

        # image
        image = Image.open(os.path.join(DATAPATH, 'images', image_path[image_id])).convert('RGB')  # (C, H, W)
        image = self.transform(image)

        # return image and label
        return image, image_label[image_id] - 1  # count begin from zero

    def __len__(self):
        return len(self.image_id)


if __name__ == '__main__':
    ds = BirdDataset('train')
    print(len(ds))
    for i in range(0, 10):
        image, label = ds[i]
        print(image.shape, label)

Splits the CUB-200-2011 dataset and sets the image size: the dataset's .txt metadata files are read to obtain image paths and labels, and the images are partitioned into train and val (test) sets.
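
For reference, each of the three metadata files parsed above pairs an image id with a value per line, roughly like this (the exact paths and labels come from the dataset itself):

images.txt:              1 001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg
image_class_labels.txt:  1 1
train_test_split.txt:    1 0   (1 = training image, 0 = val/test image)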

Model initialization

Model initialization: the state is initialized first; .to(device) copies a tensor to the specified device, so that subsequent computation runs on the GPU. main() instantiates class WSDAN(nn.Module), and WSDAN in turn uses BAP(nn.Module).

    ##################################
    # Initialize model
    ##################################
    logs = {}
    start_epoch = 0
    net = WSDAN(num_classes=num_classes, M=config.num_attentions, net=config.net, pretrained=True)

    # feature_center: size of (#classes, #attention_maps * #channel_features)
    feature_center = torch.zeros(num_classes, config.num_attentions * net.num_features).to(device)

    if config.ckpt:
        # Load ckpt and get state_dict
        checkpoint = torch.load(config.ckpt)

        # Get epoch and some logs
        logs = checkpoint['logs']
        start_epoch = int(logs['epoch'])

        # Load weights
        state_dict = checkpoint['state_dict']
        net.load_state_dict(state_dict)
        logging.info('Network loaded from {}'.format(config.ckpt))

        # load feature center
        if 'feature_center' in checkpoint:
            feature_center = checkpoint['feature_center'].to(device)
            logging.info('feature_center loaded from {}'.format(config.ckpt))

    logging.info('Network weights save to {}'.format(config.save_dir))
WSDAN(nn.Module)

num_classes: number of classes. M=config.num_attentions: number of attention maps. net=config.net: feature-extraction backbone. pretrained=True: load pretrained weights.

class WSDAN(nn.Module):
    def __init__(self, num_classes, M=32, net='inception_mixed_6e', pretrained=False):
        super(WSDAN, self).__init__()
        self.num_classes = num_classes
        self.M = M
        self.net = net

        # Network Initialization
        if 'inception' in net:     
            if net == 'inception_mixed_6e':
                self.features = inception_v3(pretrained=pretrained).get_features_mixed_6e()
                self.num_features = 768
            elif net == 'inception_mixed_7c':
                self.features = inception_v3(pretrained=pretrained).get_features_mixed_7c()
                self.num_features = 2048
            else:
                raise ValueError('Unsupported net: %s' % net)
        elif 'vgg' in net:
            self.features = getattr(vgg, net)(pretrained=pretrained).get_features()
            self.num_features = 512
        elif 'resnet' in net:
            self.features = getattr(resnet, net)(pretrained=pretrained).get_features()
            self.num_features = 512 * self.features[-1][-1].expansion
        else:
            raise ValueError('Unsupported net: %s' % net)

        # Attention Maps
        self.attentions = BasicConv2d(self.num_features, self.M, kernel_size=1)

        # Bilinear Attention Pooling
        self.bap = BAP(pool='GAP')

        # Classification Layer
        self.fc = nn.Linear(self.M * self.num_features, self.num_classes, bias=False)

        logging.info('WSDAN: using {} as feature extractor, num_classes: {}, num_attentions: {}'.format(net, self.num_classes, self.M))

# config.py: net = 'inception_mixed_6e'

self.features = inception_v3(pretrained=pretrained).get_features_mixed_6e() loads the pretrained inception_v3 model and keeps its layers up to Mixed_6e as the feature extractor.

self.attentions = BasicConv2d(self.num_features, self.M, kernel_size=1): BasicConv2d is equivalent to convolving with a 1*1 kernel, followed by BatchNorm2d and ReLU; it maps the feature channels down to self.M attention channels.

class BasicConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)

Conv2d: 2-D convolution (input channels, output channels, bias flag); its kernel weights are randomly initialized.

The feature maps entering this layer have size B*C*H*W; after the 1*1 convolution (with BatchNorm2d), the attention maps have size B*M*H*W.
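
A quick shape check of this attention branch (a standalone sketch; the sizes are toy stand-ins, not the repo's configuration):

import torch
import torch.nn as nn

B, C, H, W, M = 2, 768, 26, 26, 32  # batch, num_features, spatial size, num_attentions
attn = nn.Sequential(               # same structure as BasicConv2d: 1x1 conv + BN + ReLU
    nn.Conv2d(C, M, kernel_size=1, bias=False),
    nn.BatchNorm2d(M, eps=0.001),
    nn.ReLU(inplace=True))
feature_maps = torch.randn(B, C, H, W)
print(attn(feature_maps).shape)     # torch.Size([2, 32, 26, 26]): B*M*H*W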

    def forward(self, x):
        batch_size = x.size(0)

        # Feature Maps, Attention Maps and Feature Matrix
        feature_maps = self.features(x)
        if self.net != 'inception_mixed_7c':
            attention_maps = self.attentions(feature_maps)
        else:
            attention_maps = feature_maps[:, :self.M, ...]
        feature_matrix = self.bap(feature_maps, attention_maps)

        # Classification
        p = self.fc(feature_matrix * 100.)

        # Generate Attention Map
        if self.training:
            # Randomly choose one of attention maps Ak
            attention_map = []
            for i in range(batch_size):
                attention_weights = torch.sqrt(attention_maps[i].sum(dim=(1, 2)).detach() + EPSILON)
                attention_weights = F.normalize(attention_weights, p=1, dim=0)
                k_index = np.random.choice(self.M, 2, p=attention_weights.cpu().numpy())
                attention_map.append(attention_maps[i, k_index, ...])
            attention_map = torch.stack(attention_map)  # (B, 2, H, W) - one for cropping, the other for dropping
        else:
            # Object Localization Am = mean(Ak)
            attention_map = torch.mean(attention_maps, dim=1, keepdim=True)  # (B, 1, H, W)

        # p: (B, self.num_classes)
        # feature_matrix: (B, M * C)
        # attention_map: (B, 2, H, W) in training, (B, 1, H, W) in val/testing
        return p, feature_matrix, attention_map

k_index = np.random.choice(self.M, 2, p=attention_weights.cpu().numpy()): p is a probability distribution over the M attention maps, obtained by taking the square root of each map's summed activations and L1-normalizing. During val/test, the mean over all attention maps is used instead.
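
The sampling step in isolation (toy attention maps; two indices are drawn per image, one map for cropping and one for dropping):

import numpy as np
import torch
import torch.nn.functional as F

EPSILON = 1e-12
M = 32
maps_i = torch.rand(M, 26, 26)                          # one sample's attention maps
weights = torch.sqrt(maps_i.sum(dim=(1, 2)) + EPSILON)  # sqrt of each map's total activation
weights = F.normalize(weights, p=1, dim=0)              # L1-normalize -> a probability distribution
k_index = np.random.choice(M, 2, p=weights.numpy())
print(k_index)                                          # e.g. [ 7 19 ]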

    def load_state_dict(self, state_dict, strict=True):
        model_dict = self.state_dict()
        pretrained_dict = {k: v for k, v in state_dict.items()
                           if k in model_dict and model_dict[k].size() == v.size()}

        if len(pretrained_dict) == len(state_dict):
            logging.info('%s: All params loaded' % type(self).__name__)
        else:
            logging.info('%s: Some params were not loaded:' % type(self).__name__)
            not_loaded_keys = [k for k in state_dict.keys() if k not in pretrained_dict.keys()]
            logging.info(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys))

        model_dict.update(pretrained_dict)
        super(WSDAN, self).load_state_dict(model_dict)
BAP(nn.Module)

self.bap = BAP(pool='GAP')

class BAP(nn.Module):
    def __init__(self, pool='GAP'):
        super(BAP, self).__init__()
        assert pool in ['GAP', 'GMP']
        if pool == 'GAP':
            self.pool = None
        else:
            self.pool = nn.AdaptiveMaxPool2d(1)

    def forward(self, features, attentions):
        B, C, H, W = features.size()
        _, M, AH, AW = attentions.size()

        # match size
        if AH != H or AW != W:
            attentions = F.upsample_bilinear(attentions, size=(H, W))

        # feature_matrix: (B, M, C) -> (B, M * C)
        if self.pool is None:
            feature_matrix = (torch.einsum('imjk,injk->imn', (attentions, features)) / float(H * W)).view(B, -1)
        else:
            feature_matrix = []
            for i in range(M):
                AiF = self.pool(features * attentions[:, i:i + 1, ...]).view(B, -1)
                feature_matrix.append(AiF)
            feature_matrix = torch.cat(feature_matrix, dim=1)

        # sign-sqrt
        feature_matrix = torch.sign(feature_matrix) * torch.sqrt(torch.abs(feature_matrix) + EPSILON)

        # l2 normalization along dimension M and C
        feature_matrix = F.normalize(feature_matrix, dim=-1)
        return feature_matrix

if AH != H or AW != W:

attentions = F.upsample_bilinear(attentions, size=(H, W)): if the attention maps and the feature maps differ in height or width, bilinear interpolation resizes the attention maps to match.

1. pool == None:

feature_matrix = (torch.einsum('imjk,injk->imn', (attentions, features)) / float(H * W)).view(B, -1) contracts the attention tensor of shape (i, m, j, k) with the feature tensor of shape (i, n, j, k) over the spatial dimensions j, k (an inner product per map pair), giving a B*M*C tensor. After dividing by H*W, view reshapes it into a 2-D matrix with B rows and M*C columns.

After these transformations, each attended feature map has been reduced to a C-dimensional vector, and the M vectors are concatenated into a single row; each row of feature_matrix holds all the features of one image.

# @zhong
import torch.nn as nn
import torch

features = torch.randn(2, 3, 3, 3)
attentions = torch.randn(2, 5, 3, 3)
F = torch.einsum('imjk,injk->imn', (attentions, features))
print('einsum', F)
print('after view', F.view(2, -1))
einsum tensor([[[  0.3450,  -3.1217,  -2.0802],
         [ -4.8347,  -2.9586,  -1.6513],
         [  1.8663,   0.3958,  -1.9813],
         [ -0.8181,  -2.9455,  -2.1229],
         [  2.4022,  -2.3604,   4.9093]],

        [[ -7.2237,   2.0915,   4.7289],
         [-12.7117,   0.9476,   4.7803],
         [ -0.4658,   3.7223,   2.7384],
         [ -3.6767,   0.4836,  -0.5864],
         [ -0.1416,   0.4434,  -1.8065]]])
after view tensor([[  0.3450,  -3.1217,  -2.0802,  -4.8347,  -2.9586,  -1.6513,   1.8663,
           0.3958,  -1.9813,  -0.8181,  -2.9455,  -2.1229,   2.4022,  -2.3604,
           4.9093],
        [ -7.2237,   2.0915,   4.7289, -12.7117,   0.9476,   4.7803,  -0.4658,
           3.7223,   2.7384,  -3.6767,   0.4836,  -0.5864,  -0.1416,   0.4434,
          -1.8065]])

2. pool != None:

AiF = self.pool(features * attentions[:, i:i + 1, ...]).view(B, -1)

features * attentions[:, i:i + 1, ...]: each of the M attention maps, of size B*1*H*W, is multiplied in turn with the B*C*H*W feature maps, giving M part feature maps of size B*C*H*W.

Next, adaptive pooling (here nn.AdaptiveMaxPool2d(1)) reduces the size to B*C*1*1, and view(B, -1) reshapes it into a B*C matrix. (In x.view(batch_size, -1), batch_size is the number of rows, and -1 lets PyTorch infer the number of columns from the original tensor's size.)

feature_matrix.append(AiF): after the loop, feature_matrix is a list of M matrices of size B*C.

feature_matrix = torch.cat(feature_matrix, dim=1) concatenates them into a matrix with B rows and C*M columns.

# @zhong
import torch.nn as nn
import torch

features = torch.randn(2, 3, 3, 3)
attentions = torch.randn(2, 5, 3, 3)
pool = nn.AdaptiveMaxPool2d(1)

feature_matrix = []
for i in range(5):
    AiF = pool(features * attentions[:, i:i + 1, ...]).view(2, -1)
    feature_matrix.append(AiF)
    print('AiF:\n', AiF)
feature_matrix = torch.cat(feature_matrix, dim=1)
print('feature_matrix:\n', feature_matrix)
AiF:
 tensor([[0.2573, 1.4332, 5.2349],
        [0.3983, 1.6445, 1.8403]])
AiF:
 tensor([[1.1568, 1.7546, 2.5374],
        [3.7578, 0.8067, 1.0593]])
AiF:
 tensor([[1.1870, 2.7297, 1.0319],
        [1.4527, 0.9329, 0.6579]])
AiF:
 tensor([[0.1999, 0.0816, 2.8860],
        [1.6576, 0.8097, 1.1730]])
AiF:
 tensor([[1.5874, 6.1390, 4.3299],
        [1.1993, 2.7865, 0.5152]])
feature_matrix:
 tensor([[0.2573, 1.4332, 5.2349, 1.1568, 1.7546, 2.5374, 1.1870, 2.7297, 1.0319,
         0.1999, 0.0816, 2.8860, 1.5874, 6.1390, 4.3299],
        [0.3983, 1.6445, 1.8403, 3.7578, 0.8067, 1.0593, 1.4527, 0.9329, 0.6579,
         1.6576, 0.8097, 1.1730, 1.1993, 2.7865, 0.5152]])

feature_matrix = torch.sign(feature_matrix) * torch.sqrt(torch.abs(feature_matrix) + EPSILON) applies signed square-root normalization to the result; EPSILON = 1e-12 keeps the argument of the square root positive.

This processing emphasizes the most salient features.


feature_matrix = F.normalize(feature_matrix, dim=-1): L2 normalization.
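
Both steps on a toy row vector (standalone sketch):

import torch
import torch.nn.functional as F

EPSILON = 1e-12
fm = torch.tensor([[4.0, -9.0, 0.25]])
fm = torch.sign(fm) * torch.sqrt(torch.abs(fm) + EPSILON)  # signed sqrt: [2.0, -3.0, 0.5]
fm = F.normalize(fm, dim=-1)                               # divide by the row's L2 norm
print(fm, fm.norm(dim=-1))                                 # unit-length row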

Using CUDA

    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

Moves net onto the device for computation; if more than one CUDA device is visible, nn.DataParallel runs the computation in parallel across them.

Optimizer and LR scheduler

    learning_rate = logs['lr'] if 'lr' in logs else config.learning_rate
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)

    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=2)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

torch.optim.SGD() performs stochastic gradient descent; lr is the learning rate and momentum the momentum coefficient. The update is

$$v = -dx \cdot lr + v \cdot momentum$$

(When the current gradient step $-dx \cdot lr$ points in the same direction as the previous update $v$, the previous update accelerates the search; when the directions are opposite, it damps it.) weight_decay=1e-5 adds L2 regularization, $E(w) \leftarrow E(w) + \frac{\lambda}{2}\lVert w \rVert^2$, to reduce overfitting.

torch.optim.lr_scheduler.StepLR() adjusts the learning rate. Each parameter group's rate is

$$lr \cdot \gamma^{n}, \qquad n = \left\lfloor \frac{epoch}{step\_size} \right\rfloor$$

step_size (int): the decay period; the learning rate is updated every step_size epochs. gamma (float): the multiplicative decay factor.
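
The schedule in isolation (toy parameter; the lr here is a stand-in for config.learning_rate):

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.001, momentum=0.9, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

for epoch in range(6):
    optimizer.step()  # training would happen here
    scheduler.step()
    print(epoch + 1, optimizer.param_groups[0]['lr'])
# lr: 0.001 for the first 2 epochs, 0.0009 after epoch 2, 0.00081 after epoch 4, ...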

ModelCheckpoint (model monitoring)

    callback_monitor = 'val_{}'.format(raw_metric.name)
    callback = ModelCheckpoint(savepath=os.path.join(config.save_dir, config.model_name),
                               monitor=callback_monitor,
                               mode='max')
    if callback_monitor in logs:
        callback.set_best_score(logs[callback_monitor])
    else:
        callback.reset()

Monitors the validation metric and saves the best model; uses class ModelCheckpoint(Callback).

class ModelCheckpoint(Callback):
    def __init__(self, savepath, monitor='val_topk_accuracy', mode='max'):
        self.savepath = savepath
        self.monitor = monitor
        self.mode = mode
        self.reset()
        super(ModelCheckpoint, self).__init__()

    def reset(self):
        if self.mode == 'max':
            self.best_score = float('-inf')
        else:
            self.best_score = float('inf')

    def set_best_score(self, score):
        if isinstance(score, np.ndarray):
            self.best_score = score[0]
        else:
            self.best_score = score

    def on_epoch_begin(self):
        pass

    def on_epoch_end(self, logs, net, **kwargs):
        current_score = logs[self.monitor]
        if isinstance(current_score, np.ndarray):
            current_score = current_score[0]

        if (self.mode == 'max' and current_score > self.best_score) or \
            (self.mode == 'min' and current_score < self.best_score):
            self.best_score = current_score

            if isinstance(net, torch.nn.DataParallel):
                state_dict = net.module.state_dict()
            else:
                state_dict = net.state_dict()

            for key in state_dict.keys():
                state_dict[key] = state_dict[key].cpu()

            if 'feature_center' in kwargs:
                feature_center = kwargs['feature_center']
                feature_center = feature_center.cpu()

                torch.save({
                    'logs': logs,
                    'state_dict': state_dict,
                    'feature_center': feature_center}, self.savepath)
            else:
                torch.save({
                    'logs': logs,
                    'state_dict': state_dict}, self.savepath)

Training

    logging.info('Start training: Total epochs: {}, Batch size: {}, Training size: {}, Validation size: {}'.
                 format(config.epochs, config.batch_size, len(train_dataset), len(validate_dataset)))
    logging.info('')

    for epoch in range(start_epoch, config.epochs):
        callback.on_epoch_begin()

        logs['epoch'] = epoch + 1
        logs['lr'] = optimizer.param_groups[0]['lr']

        logging.info('Epoch {:03d}, Learning Rate {:g}'.format(epoch + 1, optimizer.param_groups[0]['lr']))

        pbar = tqdm(total=len(train_loader), unit=' batches')
        pbar.set_description('Epoch {}/{}'.format(epoch + 1, config.epochs))

        train(logs=logs,
              data_loader=train_loader,
              net=net,
              feature_center=feature_center,
              optimizer=optimizer,
              pbar=pbar)
        validate(logs=logs,
                 data_loader=validate_loader,
                 net=net,
                 pbar=pbar)

        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(logs['val_loss'])
        else:
            scheduler.step()

        callback.on_epoch_end(logs, net, feature_center=feature_center)
        pbar.close()

logging writes the run log.

Loops from start_epoch to the configured final epoch; callback.on_epoch_begin() is currently a no-op (pass).

pbar = tqdm(total=len(train_loader), unit=' batches') creates the progress bar.

Calls def train and def validate.

The isinstance() function checks whether an object is an instance of a given type (taking inheritance into account).

scheduler.step() advances the learning-rate schedule.

def train
def train(**kwargs):
    # Retrieve training configuration
    logs = kwargs['logs']
    data_loader = kwargs['data_loader']
    net = kwargs['net']
    feature_center = kwargs['feature_center']
    optimizer = kwargs['optimizer']
    pbar = kwargs['pbar']

**kwargs collects the keyword arguments passed to the function into a dictionary of key/value pairs, from which each setting is retrieved by name.
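
In isolation (a throwaway example, not repo code):

def show(**kwargs):
    print(kwargs['net'], kwargs['epochs'])  # keyword arguments arrive as the dict kwargs

show(net='wsdan', epochs=80)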

    # metrics initialization
    loss_container.reset()
    raw_metric.reset()
    crop_metric.reset()
    drop_metric.reset()

    # begin training
    start_time = time.time()
    net.train()
    for i, (X, y) in enumerate(data_loader):
        optimizer.zero_grad()

        # obtain data for training
        X = X.to(device)
        y = y.to(device)

The metric objects call their reset() methods to return to the initial state (zeroed counters).

def reset(self):
    self.corrects = np.zeros(len(self.topk))
    self.num_samples = 0.

time.time() returns the current time (used to measure the epoch duration).

enumerate() wraps an iterable (a list, tuple, string, data loader, ...) into an indexed sequence, yielding both the index and the item.

optimizer.zero_grad() resets the gradients to zero.

        ##################################
        # Raw Image
        ##################################
        # raw images forward
        y_pred_raw, feature_matrix, attention_map = net(X)

        # Update Feature Center
        feature_center_batch = F.normalize(feature_center[y], dim=-1)
        feature_center[y] += config.beta * (feature_matrix.detach() - feature_center_batch)

Updates feature_center, which starts as an all-zero matrix: feature_center = torch.zeros(num_classes, config.num_attentions * net.num_features).

detach() returns a new tensor that is detached from the current computation graph but still points to the same storage as the original. The difference feature_matrix.detach() - feature_center_batch, scaled by config.beta, is repeatedly accumulated into feature_center, moving each class center toward its features.
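
A numeric sketch of this moving-average update (toy sizes; beta stands in for config.beta):

import torch
import torch.nn.functional as F

beta = 5e-2
feature_center = torch.zeros(3, 4)  # (num_classes, M * num_features), toy sizes
y = torch.tensor([0, 2])            # the batch labels
feature_matrix = torch.ones(2, 4)   # the (detached) feature matrix of the batch

feature_center_batch = F.normalize(feature_center[y], dim=-1)
feature_center[y] += beta * (feature_matrix - feature_center_batch)
print(feature_center)  # rows 0 and 2 moved a small step (0.05) toward their batch features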

        ##################################
        # Attention Cropping
        ##################################
        with torch.no_grad():
            crop_images = batch_augment(X, attention_map[:, :1, :, :], mode='crop', theta=(0.4, 0.6), padding_ratio=0.1)

        # crop images forward
        y_pred_crop, _, _ = net(crop_images)

torch.no_grad() disables gradient tracking.

Calls def batch_augment().

def batch_augment()
def batch_augment(images, attention_map, mode='crop', theta=0.5, padding_ratio=0.1):
    batches, _, imgH, imgW = images.size()

    if mode == 'crop':
        crop_images = []
        for batch_index in range(batches):
            atten_map = attention_map[batch_index:batch_index + 1]
            if isinstance(theta, tuple):
                theta_c = random.uniform(*theta) * atten_map.max() 
            else:
                theta_c = theta * atten_map.max()

            crop_mask = F.upsample_bilinear(atten_map, size=(imgH, imgW)) >= theta_c
            nonzero_indices = torch.nonzero(crop_mask[0, 0, ...])
            height_min = max(int(nonzero_indices[:, 0].min().item() - padding_ratio * imgH), 0)
            height_max = min(int(nonzero_indices[:, 0].max().item() + padding_ratio * imgH), imgH)
            width_min = max(int(nonzero_indices[:, 1].min().item() - padding_ratio * imgW), 0)
            width_max = min(int(nonzero_indices[:, 1].max().item() + padding_ratio * imgW), imgW)

            crop_images.append(
                F.upsample_bilinear(images[batch_index:batch_index + 1, :, height_min:height_max, width_min:width_max],
                                    size=(imgH, imgW)))
        crop_images = torch.cat(crop_images, dim=0)
        return crop_images

    elif mode == 'drop':
        drop_masks = []
        for batch_index in range(batches):
            atten_map = attention_map[batch_index:batch_index + 1]
            if isinstance(theta, tuple):
                theta_d = random.uniform(*theta) * atten_map.max()
            else:
                theta_d = theta * atten_map.max()

            drop_masks.append(F.upsample_bilinear(atten_map, size=(imgH, imgW)) < theta_d)
        drop_masks = torch.cat(drop_masks, dim=0)
        drop_images = images * drop_masks.float()
        return drop_images

Crop mode:

random.uniform() draws a random real number; when theta is a tuple, a value is drawn uniformly from that range on each call.

The threshold is theta * atten_map.max().

crop_mask = F.upsample_bilinear(atten_map, size=(imgH, imgW)) >= theta_c: atten_map (one attention map per sample in each loop iteration) is bilinearly upsampled to the image size; pixels at or above the threshold become True (1) and pixels below it become False (0).

torch.nonzero(crop_mask[0, 0, ...]) gives the positions of the nonzero entries in crop_mask's first channel; the first column of the result holds row indices, the second column holds column indices.

The bounding box then spans from the minimum to the maximum of the first column (expanded outward by padding_ratio * imgH) in height, and from the minimum to the maximum of the second column (expanded outward by padding_ratio * imgW) in width.

crop_images is obtained by slicing images with these bounds (Python slicing X[:, m:n] takes entries m through n-1, left-inclusive, right-exclusive) and upsampling each crop back to (imgH, imgW).
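
The box computation on a tiny hand-made mask (standalone sketch):

import torch

imgH = imgW = 8
padding_ratio = 0.1
crop_mask = torch.zeros(1, 1, imgH, imgW, dtype=torch.bool)
crop_mask[0, 0, 2:5, 3:6] = True                       # pretend this region passed the threshold

nonzero_indices = torch.nonzero(crop_mask[0, 0, ...])  # column 0: row indices, column 1: col indices
height_min = max(int(nonzero_indices[:, 0].min().item() - padding_ratio * imgH), 0)
height_max = min(int(nonzero_indices[:, 0].max().item() + padding_ratio * imgH), imgH)
width_min = max(int(nonzero_indices[:, 1].min().item() - padding_ratio * imgW), 0)
width_max = min(int(nonzero_indices[:, 1].max().item() + padding_ratio * imgW), imgW)
print(height_min, height_max, width_min, width_max)    # 1 4 2 5: the padded bounding box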

Drop mode:

drop_masks.append(F.upsample_bilinear(atten_map, size=(imgH, imgW)) < theta_d): bilinear upsampling again; pixels below the threshold become True (1, kept), while the high-attention pixels at or above it become False (0, dropped).

drop_masks = torch.cat(drop_masks, dim=0) concatenates the per-sample masks along the batch dimension into a B*1*H*W tensor.

Multiplying drop_masks with the images zeroes out the most-attended region, giving the dropped images.

        ##################################
        # Attention Dropping
        ##################################
        with torch.no_grad():
            drop_images = batch_augment(X, attention_map[:, 1:, :, :], mode='drop', theta=(0.2, 0.5))

        # drop images forward
        y_pred_drop, _, _ = net(drop_images)

Crop receives attention_map[:, :1, :, :], the first of the sampled attention maps.

Drop receives attention_map[:, 1:, :, :], i.e., all of the sampled attention maps except the first (here, the second of the two).

(@zhong's thought: the 32 attention maps of one image describe the same object. Crop only needs a zoomed-in view of one part, i.e., the min and max pixel positions, so a single attention map suffices. Drop is different: it masks each pixel against the threshold, a pixel-level operation, so using more attention maps covers a larger and more accurate feature region.)

        # loss
        batch_loss = cross_entropy_loss(y_pred_raw, y) / 3. + \
                     cross_entropy_loss(y_pred_crop, y) / 3. + \
                     cross_entropy_loss(y_pred_drop, y) / 3. + \
                     center_loss(feature_matrix, feature_center_batch)

        # backward
        batch_loss.backward()
        optimizer.step()

        # metrics: loss and top-1,5 error
        with torch.no_grad():
            epoch_loss = loss_container(batch_loss.item())
            epoch_raw_acc = raw_metric(y_pred_raw, y)
            epoch_crop_acc = crop_metric(y_pred_crop, y)
            epoch_drop_acc = drop_metric(y_pred_drop, y)

batch_loss combines cross-entropy terms for the raw, crop, and drop predictions (weighted 1/3 each), plus the center loss on feature_matrix, four parts summed together.
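
For context, center_loss compares feature_matrix with feature_center_batch; a minimal sketch of what utils.CenterLoss computes (the mean squared distance to the class centers; the repo's exact implementation may differ):

import torch
import torch.nn as nn

class CenterLossSketch(nn.Module):  # hypothetical stand-in for utils.CenterLoss
    def forward(self, features, centers):
        # mean over the batch of the squared L2 distance to each sample's class center
        return ((features - centers) ** 2).sum() / features.size(0)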

backward() backpropagates; optimizer.step() updates the network parameters from the backpropagated gradients.

Metrics: epoch_loss uses AverageMeter to maintain a running average of the loss. epoch_raw/crop/drop_acc use TopKAccuracyMetric, counting a sample as correct when the target is among the k highest-probability labels.

        # end of this batch
        batch_info = 'Loss {:.4f}, Raw Acc ({:.2f}, {:.2f}), Crop Acc ({:.2f}, {:.2f}), Drop Acc ({:.2f}, {:.2f})'.format(
            epoch_loss, epoch_raw_acc[0], epoch_raw_acc[1],
            epoch_crop_acc[0], epoch_crop_acc[1], epoch_drop_acc[0], epoch_drop_acc[1])
        pbar.update()
        pbar.set_postfix_str(batch_info)

    # end of this epoch
    logs['train_{}'.format(loss_container.name)] = epoch_loss
    logs['train_raw_{}'.format(raw_metric.name)] = epoch_raw_acc
    logs['train_crop_{}'.format(crop_metric.name)] = epoch_crop_acc
    logs['train_drop_{}'.format(drop_metric.name)] = epoch_drop_acc
    logs['train_info'] = batch_info
    end_time = time.time()

    # write log for this epoch
    logging.info('Train: {}, Time {:3.2f}'.format(batch_info, end_time - start_time))

End of batch: update the progress bar.

End of epoch: write the accuracies to the log.

def validate
def validate(**kwargs):
    # Retrieve training configuration
    logs = kwargs['logs']
    data_loader = kwargs['data_loader']
    net = kwargs['net']
    pbar = kwargs['pbar']

    # metrics initialization
    loss_container.reset()
    raw_metric.reset()

Retrieves the configuration from kwargs; loss_container and raw_metric are reset to zero.

    # begin validation
    start_time = time.time()
    net.eval()
    with torch.no_grad():
        for i, (X, y) in enumerate(data_loader):
            # obtain data
            X = X.to(device)
            y = y.to(device)

            ##################################
            # Raw Image
            ##################################
            y_pred_raw, _, attention_map = net(X)

Loads the data and obtains the raw predictions and attention_map.

            ##################################
            # Object Localization and Refinement
            ##################################
            crop_images = batch_augment(X, attention_map, mode='crop', theta=0.1, padding_ratio=0.05)
            y_pred_crop, _, _ = net(crop_images)

            ##################################
            # Final prediction
            ##################################
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            # loss
            batch_loss = cross_entropy_loss(y_pred, y)
            epoch_loss = loss_container(batch_loss.item())

            # metrics: top-1,5 error
            epoch_acc = raw_metric(y_pred, y)

Object Localization and Refinement: the prediction is refined by cropping the images.

Final prediction: the final prediction is the average of the raw-image prediction and the crop-image prediction.

Compute loss and accuracy.

    # end of validation
    logs['val_{}'.format(loss_container.name)] = epoch_loss
    logs['val_{}'.format(raw_metric.name)] = epoch_acc
    end_time = time.time()

    batch_info = 'Val Loss {:.4f}, Val Acc ({:.2f}, {:.2f})'.format(epoch_loss, epoch_acc[0], epoch_acc[1])
    pbar.set_postfix_str('{}, {}'.format(logs['train_info'], batch_info))

    # write log for this epoch
    logging.info('Valid: {}, Time {:3.2f}'.format(batch_info, end_time - start_time))
    logging.info('')


if __name__ == '__main__':
    main()
