Swin-Transformer图像分类

文章目录

      • 1. 准备数据集
        • 1.1 数据集存放格式
        • 1.2 config配置文件
      • 2. 训练
        • 2.1 代码中调整了的部分
        • 2.2 训练命令
      • 3. 评估
      • 4. 推理
        • 4.1 推理脚本
        • 4.2 推理命令
        • 4.3 推理结果

源代码地址:Swin-Transformer(https://github.com/microsoft/Swin-Transformer)
本机为Ubuntu系统,为了训练自己的数据集,在原代码的基础上做了一点小调整:

  • 原代码中每个epoch保存一个模型,调整为只保存表现最佳的模型和最后一个epoch的模型
  • 原代码训练的是ImageNet数据集,类别较多,因此输出了两个评估指标:Top-1 Acc和Top-5 Acc;而我自己的数据集只有3个类别,所以调整为输出Top-1 Acc和Top-2 Acc(其实Top-2 Acc意义不大,不输出也可以)
  • 原代码没有给出每个类别各自的Acc,这里补充了该信息在终端的输出
  • 原代码没有推理脚本,简单补充了一个

1. 准备数据集

1.1 数据集存放格式

── imagenet
├── train
│   ├── class1
│   │   ├── cat0001.jpg
│   │   ├── cat0002.jpg
│   │   └── ...
│   ├── class2
│   │   ├── dog0001.jpg
│   │   ├── dog0002.jpg
│   │   └── ...
│   └── class3
│       ├── bird0001.jpg
│       ├── bird0002.jpg
│       └── ...
└── val
    ├── class1
    ├── class2
    └── class3

1.2 config配置文件

以swinv2_base_patch4_window12_192_22k.yaml为例

DATA:
  # 为了配合上方的数据集存放格式,DATASET的value需设置为imagenet
  DATASET: imagenet
  IMG_SIZE: 384
  # NAME_CLASSES是自己增加的字段,用于验证阶段输出每个类别的Acc以及推理阶段的可视化;其顺序需与train/val下类别文件夹按名称排序后的顺序一致(类别索引按文件夹名排序分配)
  NAME_CLASSES: ["cat", "dog", "bird"]
MODEL:
  TYPE: swinv2
  NAME: swinv2_base_patch4_window12_192_22k
  DROP_PATH_RATE: 0.2
  # NUM_CLASSES需在此处手动添加(config.py中默认值为1000),改为自己数据集的类别数
  NUM_CLASSES: 3
  SWINV2:
    EMBED_DIM: 128
    DEPTHS: [ 2, 2, 18, 2 ]
    NUM_HEADS: [ 4, 8, 16, 32 ]
    WINDOW_SIZE: 12
TRAIN:
  EPOCHS: 90
  WARMUP_EPOCHS: 5
  WEIGHT_DECAY: 0.1
  BASE_LR: 1.25e-4 # 4096 batch-size
  WARMUP_LR: 1.25e-7
  MIN_LR: 1.25e-6

针对上方的调整,需要相应地修改config.py文件:

_C.DATA = CN()
# New field: human-readable class names, where entry i is the name of class
# index i (used when logging per-class accuracy and when annotating inference
# results). Empty by default; set DATA.NAME_CLASSES in the YAML config.
_C.DATA.NAME_CLASSES = []

2. 训练

2.1 代码中调整了的部分

  • main.py
if __name__ == '__main__':
    args, config = parse_option()

    # Local single-machine, single-GPU training: provide the env:// rendezvous
    # variables that a distributed launcher (e.g. torchrun) would normally export.
    # setdefault keeps any values a real launcher already set, so the same script
    # still works for multi-GPU runs instead of forcing every process to rank 0.
    os.environ.setdefault('WORLD_SIZE', '1')
    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12345')
    # ...
if config.TRAIN.AUTO_RESUME:
    # Prefer best_ckpt.pth over the most recent checkpoint (get_best=True).
    # Indented with spaces (not a tab) so it can be pasted into the
    # space-indented main() without a TabError.
    # NOTE(review): resuming from the best checkpoint restarts from that epoch,
    # discarding any later progress -- confirm this is intended for interrupted runs.
    resume_file = auto_resume_helper(config.OUTPUT, get_best=True)
# Upstream reported top-1 and top-5 accuracy; this dataset only has 3 classes,
# so top-1 and top-2 accuracy are reported instead, plus the accuracy of every
# individual class.
@torch.no_grad()
def validate(config, data_loader, model):
    """Evaluate ``model`` on ``data_loader`` and log overall and per-class accuracy.

    Gradients are disabled for the whole pass (no autograd graph is built).

    Args:
        config: experiment config; reads MODEL.NUM_CLASSES, DATA.NAME_CLASSES,
            AMP_ENABLE and PRINT_FREQ.
        data_loader: validation loader yielding ``(images, target)`` batches.
        model: network to evaluate; switched to eval mode here.

    Returns:
        Tuple ``(acc1, acc2, loss)`` of the averaged top-1 accuracy, top-2
        accuracy and loss over the validation set.
    """
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc2_meter = AverageMeter()
    num_classes = config.MODEL.NUM_CLASSES
    # running per-class totals: number of samples, number predicted correctly (top-1)
    cla_num_meter = np.zeros(num_classes)
    pre_num_meter = np.zeros(num_classes)

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output = model(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc2 = accuracy(output, target, topk=(1, 2))
        cla_num, pre_num = cla_accuracy(output, target, num_classes)
        cla_num_meter += cla_num
        pre_num_meter += pre_num

        acc1 = reduce_tensor(acc1)
        acc2 = reduce_tensor(acc2)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc2_meter.update(acc2.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@2 {acc2_meter.val:.3f} ({acc2_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@2 {acc2_meter.avg:.3f}')

    # The per-class counts only cover the batches this process saw; sum them over
    # all ranks so they describe the whole validation set (acc1/acc2/loss above
    # already go through reduce_tensor). With a single process this is a no-op.
    counts = torch.from_numpy(np.stack([cla_num_meter, pre_num_meter])).cuda()
    dist.all_reduce(counts, op=dist.ReduceOp.SUM)
    cla_num_meter, pre_num_meter = counts.cpu().numpy()

    names = config.DATA.NAME_CLASSES
    parts = []
    for i in range(num_classes):
        # fall back to the class index when NAME_CLASSES is unset or too short
        name = names[i] if i < len(names) else str(i)
        # a class with no validation samples has no defined accuracy
        if cla_num_meter[i] > 0:
            class_acc = pre_num_meter[i] / cla_num_meter[i]
        else:
            class_acc = float('nan')
        parts.append(f'Acc of {name}: {class_acc}\t')
    logger.info(''.join(parts))
    return acc1_meter.avg, acc2_meter.avg, loss_meter.avg

def cla_accuracy(output, target, num_class):
    """Count, per class, the samples and the correct top-1 predictions of one batch.

    Args:
        output: logits, one row per sample and one column per class.
        target: integer class labels, one per sample, in ``[0, num_class)``.
        num_class: number of classes.

    Returns:
        ``(sam_nums, pre_cor_nums)``: float64 numpy arrays of length ``num_class``
        holding, for each class, the number of samples and the number of those
        predicted correctly.

    Raises:
        IndexError: if a label is ``>= num_class``.
    """
    # top-1 prediction per sample (topk keeps the original tie-breaking)
    _, pred = output.topk(1, 1, True, True)
    pred = pred[:, 0]
    target = target.reshape(-1)
    # bincount runs on the tensors' own device, avoiding the one host<->device
    # sync per sample that an element-wise Python loop over CUDA tensors causes
    sam_nums = torch.bincount(target, minlength=num_class)
    pre_cor_nums = torch.bincount(target[pred == target], minlength=num_class)
    if sam_nums.numel() > num_class:
        raise IndexError(f'target label out of range for num_class={num_class}')
    return (sam_nums.cpu().numpy().astype(np.float64),
            pre_cor_nums.cpu().numpy().astype(np.float64))
# Upstream saved one checkpoint per epoch; this keeps only two files:
#   best_ckpt.pth       - state of the epoch with the highest validation Acc@1
#   last_epoch_ckpt.pth - rolling checkpoint, refreshed every SAVE_FREQ epochs
#                         and on the final epoch
for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
    data_loader_train.sampler.set_epoch(epoch)

    train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,
                    loss_scaler)

    acc1, acc2, loss = validate(config, data_loader_val, model)
    logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")

    # Update the running best before saving so the stored max_accuracy is current.
    is_best = acc1 > max_accuracy
    max_accuracy = max(max_accuracy, acc1)

    if dist.get_rank() == 0:
        # Save the best model whenever validation improves, independent of SAVE_FREQ;
        # otherwise an improvement on a non-save epoch would never be written.
        if is_best:
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,
                            logger, "best_ckpt")
        # Keep the rolling checkpoint up to date even when this epoch is also the
        # best one, so the final epoch always ends up in last_epoch_ckpt.
        if epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,
                            logger, "last_epoch_ckpt")

    logger.info(f'Max accuracy: {max_accuracy:.2f}%')
  • data/build.py
def build_loader(config):
    config.defrost()
    # Upstream assigned `dataset_train, config.MODEL.NUM_CLASSES = ...` here
    # (presumably from this same build_dataset call), overwriting the class count.
    # The class count is now set explicitly via MODEL.NUM_CLASSES in the YAML
    # config, so the second return value is discarded.
    dataset_train, _ = build_dataset(is_train=True, config=config)
  • utils.py
# Auto-resume helper. Upstream always resumed from the most recently modified
# checkpoint; with get_best=True it prefers best_ckpt.pth instead.
def auto_resume_helper(output_dir, get_best=False):
    """Choose the checkpoint to auto-resume from in ``output_dir``.

    Args:
        output_dir: directory the training run writes its ``*.pth`` files to.
        get_best: if True, prefer ``best_ckpt.pth``; when that file does not
            exist, fall back to the most recently modified checkpoint.

    Returns:
        Full path of the chosen checkpoint, or None when ``output_dir`` does not
        exist or contains no ``*.pth`` file.
    """
    if not os.path.isdir(output_dir):
        return None
    checkpoints = [ckpt for ckpt in os.listdir(output_dir) if ckpt.endswith('pth')]
    print(f"All checkpoints founded in {output_dir}: {checkpoints}")
    if not checkpoints:
        return None

    if get_best and "best_ckpt.pth" in checkpoints:
        best_checkpoint = os.path.join(output_dir, 'best_ckpt.pth')
        print(f"The best checkpoint founded: {best_checkpoint}")
        return best_checkpoint

    # Either get_best is False, or no best_ckpt.pth exists yet: resume from the
    # newest checkpoint rather than silently starting training from scratch.
    latest_checkpoint = max((os.path.join(output_dir, d) for d in checkpoints), key=os.path.getmtime)
    print(f"The latest checkpoint founded: {latest_checkpoint}")
    return latest_checkpoint
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger, ckpt_name):
    """Save a resumable training state to ``<config.OUTPUT>/<ckpt_name>.pth``.

    Args:
        config: experiment config; ``config.OUTPUT`` is the target directory and
            the config itself is stored in the checkpoint.
        epoch: epoch index stored in the checkpoint.
        model: module whose ``state_dict()`` is saved.
        max_accuracy: best validation accuracy so far, stored in the checkpoint.
        optimizer, lr_scheduler, loss_scaler: objects whose ``state_dict()`` is saved.
        logger: logger used for progress messages.
        ckpt_name: file name without the ``.pth`` extension; an existing file
            with that name is replaced.
    """
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'scaler': loss_scaler.state_dict(),
                  'epoch': epoch,
                  'config': config}

    save_path = os.path.join(config.OUTPUT, f'{ckpt_name}.pth')
    # The file name is fixed, so the previous checkpoint is overwritten. Write to a
    # temporary file first and atomically swap it in: an interrupted save then
    # leaves the previous checkpoint intact instead of a truncated file. The
    # '.tmp' suffix also keeps it out of *.pth directory scans.
    tmp_path = save_path + '.tmp'
    logger.info(f"{save_path} saving......")
    torch.save(save_state, tmp_path)
    os.replace(tmp_path, save_path)
    logger.info(f"{save_path} saved !!!")

2.2 训练命令

python main.py --cfg configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml --batch-size 4 --data-path imagenet --pretrained swinv2_base_patch4_window12_192_22k.pth --local_rank 0

3. 评估

python main.py --eval --cfg configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml --resume output/swinv2_base_patch4_window12_192_22k/default/best_ckpt.pth --data-path imagenet --local_rank 0

评估阶段的终端输出:
(此处原为评估阶段终端输出的截图,图片已缺失)

4. 推理

4.1 推理脚本

原作者没有提供推理(inference)代码,这里参照evaluate流程写了一个简单的推理脚本。

import os
import argparse
from torch.autograd import Variable
import cv2

import torch
from torchvision import transforms

from config import get_config
from models import build_model
from PIL import Image

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# Interpolation-name -> torchvision InterpolationMode mapping. When
# InterpolationMode is available, timm's helper is patched with this version;
# if the import fails (e.g. an older torchvision), timm's own helper is used.
try:
    from torchvision.transforms import InterpolationMode


    def _pil_interp(method):
        """Map an interpolation name from the config to an InterpolationMode.

        'bicubic', 'lanczos' and 'hamming' map to the matching mode; any other
        value falls back to bilinear.
        """
        if method == 'bicubic':
            return InterpolationMode.BICUBIC
        elif method == 'lanczos':
            return InterpolationMode.LANCZOS
        elif method == 'hamming':
            return InterpolationMode.HAMMING
        else:
            # default bilinear, do we want to allow nearest?
            return InterpolationMode.BILINEAR


    import timm.data.transforms as timm_transforms

    timm_transforms._pil_interp = _pil_interp
except ImportError:
    # Only import failures trigger the fallback; a bare `except:` would also
    # swallow unrelated errors (even KeyboardInterrupt) and hide them here.
    from timm.data.transforms import _pil_interp


def parse_option():
    """Parse the inference command line and build the experiment config.

    Accepts the same options as the training script so the same YAML config and
    overrides can be reused; unknown options are ignored.

    Returns:
        Tuple ``(args, config)`` of the parsed namespace and the config built by
        ``get_config(args)``.
    """
    parser = argparse.ArgumentParser('Swin Transformer inference script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                        help='no: no cache, '
                             'full: cache all data, '
                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
    parser.add_argument('--pretrained',
                        help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
    parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used (deprecated!)')
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is // (default: output)')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')

    # distributed training
    # Inference runs in a single process, so --local_rank is optional (default 0).
    # It is kept for CLI compatibility with training; presumably get_config reads
    # args.local_rank -- verify before removing it.
    parser.add_argument("--local_rank", type=int, default=0, help='local rank for DistributedDataParallel')

    # for acceleration
    parser.add_argument('--fused_window_process', action='store_true',
                        help='Fused window shift & window partition, similar for reversed part.')
    parser.add_argument('--fused_layernorm', action='store_true', help='Use fused layernorm.')
    ## overwrite optimizer in config (*.yaml) if specified, e.g., fused_adam/fused_lamb
    parser.add_argument('--optim', type=str,
                        help='overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.')

    args, _unparsed = parser.parse_known_args()

    config = get_config(args)

    return args, config

if __name__ == '__main__':
    args, config = parse_option()

    # Resize straight to IMG_SIZE x IMG_SIZE, then ImageNet mean/std normalisation.
    transform_test = transforms.Compose([
        transforms.Resize(
            (config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
            interpolation=_pil_interp(config.DATA.INTERPOLATION)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])
    classes = config.DATA.NAME_CLASSES
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = build_model(config)
    # --pretrained is the fine-tuned checkpoint; its weights are stored under 'model'.
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
    # strict=False silently skips mismatched keys; print them so a wrong
    # checkpoint/config pairing is visible instead of producing garbage predictions.
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    model.eval()
    model.to(DEVICE)

    path = config.DATA.DATA_PATH
    font = cv2.FONT_HERSHEY_SIMPLEX
    for file in sorted(os.listdir(path)):
        # os.path.join(path, file) works with or without a trailing '/' on
        # --data-path; `path + file` silently required the trailing slash.
        img_path = os.path.join(path, file)
        if not os.path.isfile(img_path):
            continue
        try:
            # convert('RGB') so grayscale / RGBA / palette images give the three
            # channels that Normalize expects
            with Image.open(img_path) as pil_img:
                img = pil_img.convert('RGB')
        except OSError:
            print(f'skip non-image file: {img_path}')
            continue
        ori_img = cv2.imread(img_path)
        if ori_img is None:
            print(f'skip file that OpenCV cannot read: {img_path}')
            continue

        batch = transform_test(img).unsqueeze(0).to(DEVICE)
        # inference only: no autograd graph needed
        with torch.no_grad():
            out = model(batch)
        _, pred = torch.max(out, 1)
        pred_idx = pred.item()
        # fall back to the class index when NAME_CLASSES is unset or too short
        label = classes[pred_idx] if pred_idx < len(classes) else str(pred_idx)

        text = 'ImageName:{}, predict:{}'.format(file, label)
        txt_size = cv2.getTextSize(text, font, 0.7, 1)[0]
        # centre the label horizontally at the top of the image
        x0 = int(ori_img.shape[1] / 2.0)
        cv2.putText(ori_img, text, (x0 - int(txt_size[0] / 2.0), int(0 + txt_size[1])), font, 0.7, (0, 0, 255), thickness=1)
        cv2.imshow(img_path, ori_img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

4.2 推理命令

python inference.py --cfg configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml --data-path images/ --pretrained output/swinv2_base_patch4_window12_192_22k/default/best_ckpt.pth --local_rank 0

4.3 推理结果

(此处原为推理结果的可视化截图,图片已缺失)

你可能感兴趣的:(transformer,分类,深度学习)