Source code: Swin-Transformer
My machine runs Ubuntu. To train on my own dataset, I made a few small adjustments on top of the original code:
imagenet
├── train
│   ├── class1
│   │   ├── cat0001.jpg
│   │   ├── cat0002.jpg
│   │   └── ...
│   ├── class2
│   │   ├── dog0001.jpg
│   │   ├── dog0002.jpg
│   │   └── ...
│   └── class3
│       ├── bird0001.jpg
│       ├── bird0002.jpg
│       └── ...
└── val
    ├── class1
    ├── class2
    └── class3
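The train/ and val/ folders follow the standard ImageFolder layout: each subfolder is one class, and class indices are assigned in sorted folder-name order (with DATASET: imagenet and no --zip, the repo's build_dataset ends up using torchvision's ImageFolder). A minimal sketch, assuming the tree above sits under ./imagenet, to double-check the class-to-index mapping before training:

import torchvision.datasets as datasets

# Assumes the directory tree shown above is located at ./imagenet
train_set = datasets.ImageFolder('imagenet/train')
val_set = datasets.ImageFolder('imagenet/val')

# class_to_idx is built from the sorted subfolder names; keep NAME_CLASSES in the
# yaml below in the same order so the visualized labels line up.
print(train_set.class_to_idx)   # e.g. {'class1': 0, 'class2': 1, 'class3': 2}
print(len(train_set), len(val_set))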
Take swinv2_base_patch4_window12_192_22k.yaml as an example:
DATA:
  # To match the dataset layout above, DATASET must be set to imagenet
  DATASET: imagenet
  IMG_SIZE: 384
  # NAME_CLASSES is a field I added myself; it is used when visualizing predictions at inference time
  NAME_CLASSES: ["cat", "dog", "bird"]
MODEL:
  TYPE: swinv2
  NAME: swinv2_base_patch4_window12_192_22k
  DROP_PATH_RATE: 0.2
  # NUM_CLASSES is added to the yaml here; its default is 1000
  NUM_CLASSES: 3
  SWINV2:
    EMBED_DIM: 128
    DEPTHS: [ 2, 2, 18, 2 ]
    NUM_HEADS: [ 4, 8, 16, 32 ]
    WINDOW_SIZE: 12
TRAIN:
  EPOCHS: 90
  WARMUP_EPOCHS: 5
  WEIGHT_DECAY: 0.1
  BASE_LR: 1.25e-4 # 4096 batch-size
  WARMUP_LR: 1.25e-7
  MIN_LR: 1.25e-6
To support the adjustments above, the config.py file needs to be modified accordingly:
_C.DATA = CN()
# Default value for the newly added NAME_CLASSES field
_C.DATA.NAME_CLASSES = []
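yacs only merges keys that are already registered on the default config, so adding this default is what lets DATA.NAME_CLASSES in the yaml be picked up instead of raising an error. A minimal standalone sketch of that behaviour (my_cfg.yaml is just a placeholder file name for illustration):

from yacs.config import CfgNode as CN

_C = CN()
_C.DATA = CN()
_C.DATA.NAME_CLASSES = []           # registered default -> the yaml key merges cleanly
cfg = _C.clone()
cfg.merge_from_file('my_cfg.yaml')  # a yaml that sets DATA.NAME_CLASSES
print(cfg.DATA.NAME_CLASSES)
# Without the registered default, merge_from_file raises a
# "Non-existent config key: DATA.NAME_CLASSES" error.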
main.py
if __name__ == '__main__':
    args, config = parse_option()
    # The training environment is a single local machine with a single GPU,
    # so the distributed environment variables are written by hand
    os.environ['WORLD_SIZE'] = '1'
    os.environ['RANK'] = '0'
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    # ...
    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT, get_best=True)
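These four variables are exactly what PyTorch's env:// rendezvous reads, so the repo's unmodified torch.distributed.init_process_group call still works on a single local GPU. A standalone sketch of that handshake (not the repo's code, just the mechanism):

import os
import torch
import torch.distributed as dist

os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('RANK', '0')
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '12345')

# init_method='env://' reads WORLD_SIZE / RANK / MASTER_ADDR / MASTER_PORT from the environment
dist.init_process_group(backend='nccl', init_method='env://')
torch.cuda.set_device(0)
print(dist.get_rank(), dist.get_world_size())   # 0 1
dist.destroy_process_group()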
# The original code reports top-1 and top-5 accuracy, but my dataset only has 3 classes,
# so the output is changed to top-1 and top-2 accuracy.
# Per-class accuracy output is also added.
@torch.no_grad()
def validate(config, data_loader, model):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc2_meter = AverageMeter()
    # per-class sample counts and per-class correct-prediction counts
    cla_num_meter = np.zeros(config.MODEL.NUM_CLASSES)
    pre_num_meter = np.zeros(config.MODEL.NUM_CLASSES)

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output = model(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc2 = accuracy(output, target, topk=(1, 2))
        cla_num, pre_num = cla_accuracy(output, target, config.MODEL.NUM_CLASSES)
        cla_num_meter += cla_num
        pre_num_meter += pre_num

        acc1 = reduce_tensor(acc1)
        acc2 = reduce_tensor(acc2)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc2_meter.update(acc2.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@2 {acc2_meter.val:.3f} ({acc2_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@2 {acc2_meter.avg:.3f}')

    # per-class accuracy
    ans = ''
    acc_each_class = [pre_num_meter[i] / cla_num_meter[i] for i in range(config.MODEL.NUM_CLASSES)]
    for i in range(config.MODEL.NUM_CLASSES):
        ans += f'Acc of {config.DATA.NAME_CLASSES[i]}: {acc_each_class[i]}\t'
    logger.info(ans)

    return acc1_meter.avg, acc2_meter.avg, loss_meter.avg
def cla_accuracy(output, target, num_class):
    # Count, per class, the number of ground-truth samples and the number of correctly predicted samples
    _, pred = output.topk(1, 1, True, True)
    pred = pred.t()[0]
    sam_nums = np.zeros(num_class)
    pre_cor_nums = np.zeros(num_class)
    for i in range(len(target)):
        sam_nums[int(target[i])] += 1
        if int(target[i]) == int(pred[i]):
            pre_cor_nums[int(target[i])] += 1
    return sam_nums, pre_cor_nums
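The Python loop is fine for a small validation set; the same counts can also be computed without a loop via torch.bincount. A sketch of an equivalent vectorized version (same inputs and return values as cla_accuracy above, offered only as an alternative):

import numpy as np
import torch

def cla_accuracy_vectorized(output, target, num_class):
    # top-1 prediction per sample
    pred = output.argmax(dim=1)
    # per-class ground-truth counts
    sam_nums = torch.bincount(target, minlength=num_class)
    # per-class correct counts: only keep targets whose prediction matched
    pre_cor_nums = torch.bincount(target[pred == target], minlength=num_class)
    return sam_nums.cpu().numpy().astype(np.float64), pre_cor_nums.cpu().numpy().astype(np.float64)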
# The original code saves a checkpoint every epoch; this is changed to keep only best_ckpt.pth and last_epoch_ckpt.pth
for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
    data_loader_train.sampler.set_epoch(epoch)

    train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,
                    loss_scaler)

    acc1, acc2, loss = validate(config, data_loader_val, model)
    if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
        if acc1 > max_accuracy:
            ckpt_name = "best_ckpt"
        else:
            ckpt_name = "last_epoch_ckpt"
        save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,
                        logger, ckpt_name)
    logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")

    max_accuracy = max(max_accuracy, acc1)
    logger.info(f'Max accuracy: {max_accuracy:.2f}%')
data/build.py
def build_loader(config):
    config.defrost()
    # The original line was: dataset_train, config.MODEL.NUM_CLASSES = ...
    # The number of classes is already set in the config file, so the returned value is discarded here
    dataset_train, _ = build_dataset(is_train=True, config=config)
utils.py
# Modified so that resuming loads best_ckpt.pth
def auto_resume_helper(output_dir, get_best=False):
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
    print(f"All checkpoints founded in {output_dir}: {checkpoints}")
    # The original code resumes from the most recently modified checkpoint;
    # with get_best=True it loads best_ckpt.pth instead
    if len(checkpoints) > 0 and not get_best:
        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
        print(f"The latest checkpoint founded: {latest_checkpoint}")
        resume_file = latest_checkpoint
    elif get_best and "best_ckpt.pth" in checkpoints:
        print(f"The best checkpoint founded: {os.path.join(output_dir, 'best_ckpt.pth')}")
        resume_file = os.path.join(output_dir, 'best_ckpt.pth')
    else:
        resume_file = None
    return resume_file
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger, ckpt_name):
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'scaler': loss_scaler.state_dict(),
                  'epoch': epoch,
                  'config': config}

    save_path = os.path.join(config.OUTPUT, f'{ckpt_name}.pth')
    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    logger.info(f"{save_path} saved !!!")
Training and evaluation are then launched with:
python main.py --cfg configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml --batch-size 4 --data-path imagenet --pretrained swinv2_base_patch4_window12_192_22k.pth --local_rank 0
python main.py --eval --cfg configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml --resume output/swinv2_base_patch4_window12_192_22k/default/best_ckpt.pth --data-path imagenet --local_rank 0
The original authors do not provide inference code, so here is a simple inference script written by following the evaluation flow.
import os
import argparse

import cv2
import torch
from torchvision import transforms
from config import get_config
from models import build_model
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

try:
    from torchvision.transforms import InterpolationMode


    def _pil_interp(method):
        if method == 'bicubic':
            return InterpolationMode.BICUBIC
        elif method == 'lanczos':
            return InterpolationMode.LANCZOS
        elif method == 'hamming':
            return InterpolationMode.HAMMING
        else:
            # default bilinear, do we want to allow nearest?
            return InterpolationMode.BILINEAR


    import timm.data.transforms as timm_transforms

    timm_transforms._pil_interp = _pil_interp
except ImportError:
    from timm.data.transforms import _pil_interp
def parse_option():
    parser = argparse.ArgumentParser('Swin Transformer inference script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                        help='no: no cache, '
                             'full: cache all data, '
                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
    parser.add_argument('--pretrained',
                        help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
    parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used (deprecated!)')
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')

    # distributed training
    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')

    # for acceleration
    parser.add_argument('--fused_window_process', action='store_true',
                        help='Fused window shift & window partition, similar for reversed part.')
    parser.add_argument('--fused_layernorm', action='store_true', help='Use fused layernorm.')
    ## overwrite optimizer in config (*.yaml) if specified, e.g., fused_adam/fused_lamb
    parser.add_argument('--optim', type=str,
                        help='overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.')

    args, unparsed = parser.parse_known_args()

    config = get_config(args)

    return args, config
if __name__ == '__main__':
    args, config = parse_option()

    transform_test = transforms.Compose(
        [transforms.Resize(
            (config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
            interpolation=_pil_interp(config.DATA.INTERPOLATION)),
         transforms.ToTensor(),
         transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
         ]
    )
    classes = config.DATA.NAME_CLASSES
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = build_model(config)
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
    model.load_state_dict(checkpoint['model'], strict=False)
    model.eval()
    model.to(DEVICE)

    path = config.DATA.DATA_PATH
    testList = os.listdir(path)
    for file in testList:
        # convert to RGB so grayscale/RGBA images also end up with 3 channels
        img = Image.open(os.path.join(path, file)).convert('RGB')
        img = transform_test(img)
        img.unsqueeze_(0)
        img = img.to(DEVICE)
        with torch.no_grad():
            out = model(img)
        _, pred = torch.max(out, 1)

        # draw the prediction on the original image
        ori_img = cv2.imread(os.path.join(path, file))
        text = 'ImageName:{}, predict:{}'.format(file, classes[pred.item()])
        font = cv2.FONT_HERSHEY_SIMPLEX
        txt_size = cv2.getTextSize(text, font, 0.7, 1)[0]
        x0 = int(ori_img.shape[1] / 2.0)
        cv2.putText(ori_img, text, (x0 - int(txt_size[0] / 2.0), int(0 + txt_size[1])), font, 0.7, (0, 0, 255),
                    thickness=1)
        cv2.imshow(os.path.join(path, file), ori_img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
Run the inference script with:
python inference.py --cfg configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml --data-path images/ --pretrained output/swinv2_base_patch4_window12_192_22k/default/best_ckpt.pth --local_rank 0
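cv2.imshow needs a display, which a remote Ubuntu box often lacks. A small headless variant (save_annotated and the inference_out/ folder are hypothetical names; everything else follows the loop above) that writes the annotated image to disk instead of opening a window:

import os
import cv2

def save_annotated(out_dir, file_name, annotated_img, text):
    # Write the annotated image to disk and log the prediction to the console,
    # replacing the cv2.imshow / cv2.waitKey / cv2.destroyAllWindows calls above
    os.makedirs(out_dir, exist_ok=True)
    cv2.imwrite(os.path.join(out_dir, file_name), annotated_img)
    print(text)

# inside the loop above: save_annotated('inference_out', file, ori_img, text)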