Testing DETR (Facebook) on Your Own Dataset

Without further ado, here is the code. The script below loads a trained DETR checkpoint, runs inference on a single image, and plots the detections:

import argparse
import datetime
import json
import random
import time
from pathlib import Path
from PIL import Image
from datasets.KINS import make_coco_transforms
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler

import datasets
import util.misc as utils
from datasets import build_dataset, get_coco_api_from_dataset
from engine import evaluate, train_one_epoch
from models import build_model

# TODO: replace with the classes of your own dataset
CLASSES = [
    'pedestrian', 'cyclist', 'person-sitting', 'van', 'car', 'tram', 'truck', 'misc', 'bicycle', 'N/A'
]
COLORS = [
    [0.000,0.447,0.741], [0.850,0.325,0.098], [0.929, 0.694, 0.125], [0.494, 0.184, 0.556],
    [0.466, 0.647, 0.188], [0.301, 0.785, 0.933]
]

def get_args_parser():
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--lr_backbone', default=1e-5, type=float)
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--lr_drop', default=200, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')

    # Model parameters
    parser.add_argument('--frozen_weights', type=str, default=None,
                        help="Path to the pretrained model. If set, only the mask head will be trained")
    # * Backbone
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")

    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=100, type=int,
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')
    parser.add_argument('--transformer', type=str, default='normalize', help='the modal image patch transformer manner')
    # * Segmentation
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")

    # Loss
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                        help="Disables auxiliary decoding losses (loss at each layer)")
    # * Matcher
    parser.add_argument('--set_cost_class', default=1, type=float,
                        help="Class coefficient in the matching cost")
    parser.add_argument('--set_cost_bbox', default=5, type=float,
                        help="L1 box coefficient in the matching cost")
    parser.add_argument('--set_cost_giou', default=2, type=float,
                        help="giou box coefficient in the matching cost")
    # * Loss coefficients
    parser.add_argument('--mask_loss_coef', default=1, type=float)
    parser.add_argument('--dice_loss_coef', default=1, type=float)
    parser.add_argument('--bbox_loss_coef', default=5, type=float)
    parser.add_argument('--giou_loss_coef', default=2, type=float)
    parser.add_argument('--eos_coef', default=0.1, type=float,
                        help="Relative classification weight of the no-object class")

    # dataset parameters
    parser.add_argument('--dataset_file', default='KINS')
    parser.add_argument('--coco_path', type=str, default='data_object_image_2_KINS')
    parser.add_argument('--coco_panoptic_path', type=str)
    parser.add_argument('--remove_difficult', action='store_true')

    parser.add_argument('--output_dir', default='./checkpoints',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    # parser.add_argument('--resume', default='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth', help='resume from checkpoint')
    parser.add_argument('--resume', default='checkpoints/checkpoint.pth',
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='whether to eval')
    parser.add_argument('--num_workers', default=1, type=int)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser
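
# NOTE: the architecture flags above (backbone, enc_layers, hidden_dim, nheads,
# num_queries, ...) must match the configuration the checkpoint was trained
# with, otherwise load_state_dict below will fail with shape/key mismatches.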

def main(args):
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # optimizer parameter groups (unused in this inference-only script;
    # kept from the original training entry point)
    param_dicts = [
        {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
    ]


    # load checkpoint weights
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, model_dir='checkpoints', map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
            print("load checkpoint from {}".format(args.resume))
        model_without_ddp.load_state_dict(checkpoint['model'])

    # TODO: replace with the path of your own image
    img_path = '/home/Fzh/detr-main/data_object_image_2_KINS/training/image_2/000015.png'
    img = Image.open(img_path).convert('RGB')
    img_size = img.size
    transform = make_coco_transforms('test')
    img_t, _ = transform(img, None)
    img_t = img_t.unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        output = model(img_t)
    print("detect over!")
    # drop the last (no-object) logit, then keep queries whose best
    # class probability exceeds the 0.7 confidence threshold
    probs = output['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probs.max(-1).values > 0.7
    bbox_scaled = rescale_bboxes(output['pred_boxes'][0, keep], img_size)
    pro = probs[keep]

    print(pro, bbox_scaled)
    plot_results(img, pro, bbox_scaled)

def plot_results(pil_img, prob, boxes):
    import matplotlib.pyplot as plt
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), COLORS * 100):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()
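
# If you are running on a headless server, replace plt.show() above with
# plt.savefig('result.png') so the visualization is written to disk instead.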


def box_cxcywh_to_xyxy(x):
    # convert boxes from (center_x, center_y, width, height) to (x_min, y_min, x_max, y_max)
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_box, size):
    # scale normalized [0, 1] boxes back to absolute pixel coordinates
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_box).cpu()
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b


if __name__ == '__main__':
    parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
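
To sanity-check the box post-processing in isolation, here is a tiny self-contained example (a sketch; the input values are made up for illustration):

import torch

def box_cxcywh_to_xyxy(x):
    # same conversion as in the script above
    x_c, y_c, w, h = x.unbind(1)
    return torch.stack([x_c - 0.5 * w, y_c - 0.5 * h,
                        x_c + 0.5 * w, y_c + 0.5 * h], dim=1)

# one box at the image centre, half the image wide and tall, in the
# normalized (cx, cy, w, h) format DETR predicts
box = torch.tensor([[0.5, 0.5, 0.5, 0.5]])
scale = torch.tensor([640., 480., 640., 480.])  # (w, h, w, h) of a 640x480 image
print(box_cxcywh_to_xyxy(box) * scale)
# tensor([[160., 120., 480., 360.]])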

Things to change: the checkpoint path (--resume), the path of the image you want to visualize, and the CLASSES list at the top.
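
To run it, save the script in the root of the detr-main repo (the file name test_img.py below is just an example) and point --resume at your trained weights:

python test_img.py --resume checkpoints/checkpoint.pth --dataset_file KINS

Note that CLASSES must cover the number of classes the checkpoint was trained with: DETR predicts one extra no-object logit per query, which the [:, :-1] slice in main discards before thresholding.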
