Without further ado, here is the code:
import argparse
import random
from pathlib import Path

import numpy as np
import torch
from PIL import Image

import util.misc as utils
from datasets.KINS import make_coco_transforms
from models import build_model
# TODO: replace with the classes of your own dataset
CLASSES = [
    'pedestrian', 'cyclist', 'person-sitting', 'van', 'car', 'tram', 'truck', 'misc', 'bicycle', 'N/A'
]
COLORS = [
[0.000,0.447,0.741], [0.850,0.325,0.098], [0.929, 0.694, 0.125], [0.494, 0.184, 0.556],
[0.466, 0.647, 0.188], [0.301, 0.785, 0.933]
]
def get_args_parser():
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
parser.add_argument('--lr', default=1e-4, type=float)
parser.add_argument('--lr_backbone', default=1e-5, type=float)
parser.add_argument('--batch_size', default=1, type=int)
parser.add_argument('--weight_decay', default=1e-4, type=float)
parser.add_argument('--epochs', default=300, type=int)
parser.add_argument('--lr_drop', default=200, type=int)
parser.add_argument('--clip_max_norm', default=0.1, type=float,
help='gradient clipping max norm')
# Model parameters
parser.add_argument('--frozen_weights', type=str, default=None,
help="Path to the pretrained model. If set, only the mask head will be trained")
# * Backbone
parser.add_argument('--backbone', default='resnet50', type=str,
help="Name of the convolutional backbone to use")
parser.add_argument('--dilation', action='store_true',
help="If true, we replace stride with dilation in the last convolutional block (DC5)")
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
help="Type of positional embedding to use on top of the image features")
# * Transformer
parser.add_argument('--enc_layers', default=6, type=int,
help="Number of encoding layers in the transformer")
parser.add_argument('--dec_layers', default=6, type=int,
help="Number of decoding layers in the transformer")
parser.add_argument('--dim_feedforward', default=2048, type=int,
help="Intermediate size of the feedforward layers in the transformer blocks")
parser.add_argument('--hidden_dim', default=256, type=int,
help="Size of the embeddings (dimension of the transformer)")
parser.add_argument('--dropout', default=0.1, type=float,
help="Dropout applied in the transformer")
parser.add_argument('--nheads', default=8, type=int,
help="Number of attention heads inside the transformer's attentions")
parser.add_argument('--num_queries', default=100, type=int,
help="Number of query slots")
parser.add_argument('--pre_norm', action='store_true')
    parser.add_argument('--transformer', type=str, default='normalize',
                        help='the image patch transformer variant to use')
# * Segmentation
parser.add_argument('--masks', action='store_true',
help="Train segmentation head if the flag is provided")
# Loss
parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
help="Disables auxiliary decoding losses (loss at each layer)")
# * Matcher
parser.add_argument('--set_cost_class', default=1, type=float,
help="Class coefficient in the matching cost")
parser.add_argument('--set_cost_bbox', default=5, type=float,
help="L1 box coefficient in the matching cost")
parser.add_argument('--set_cost_giou', default=2, type=float,
help="giou box coefficient in the matching cost")
# * Loss coefficients
parser.add_argument('--mask_loss_coef', default=1, type=float)
parser.add_argument('--dice_loss_coef', default=1, type=float)
parser.add_argument('--bbox_loss_coef', default=5, type=float)
parser.add_argument('--giou_loss_coef', default=2, type=float)
parser.add_argument('--eos_coef', default=0.1, type=float,
help="Relative classification weight of the no-object class")
# dataset parameters
parser.add_argument('--dataset_file', default='KINS')
parser.add_argument('--coco_path', type=str, default='data_object_image_2_KINS')
parser.add_argument('--coco_panoptic_path', type=str)
parser.add_argument('--remove_difficult', action='store_true')
parser.add_argument('--output_dir', default='./checkpoints',
help='path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=42, type=int)
# parser.add_argument('--resume', default='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth', help='resume from checkpoint')
parser.add_argument('--resume', default='checkpoints/checkpoint.pth',
help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
    # NOTE: argparse's type=bool treats any non-empty string as True, so use a flag instead
    parser.add_argument('--eval', action='store_true', help='whether to eval')
parser.add_argument('--num_workers', default=1, type=int)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
return parser
def main(args):
utils.init_distributed_mode(args)
print("git:\n {}\n".format(utils.get_sha()))
if args.frozen_weights is not None:
assert args.masks, "Frozen training is meant for segmentation only"
print(args)
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
model, criterion, postprocessors = build_model(args)
model.to(device)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
    # kept from the training script for reference; this inference demo builds no optimizer from it
    param_dicts = [
{"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
{
"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
"lr": args.lr_backbone,
},
]
    # load the checkpoint
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, model_dir='checkpoints', map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        print("loaded checkpoint from {}".format(args.resume))
        model_without_ddp.load_state_dict(checkpoint['model'])
    # TODO: change to the path of your own image
    img_path = '/home/Fzh/detr-main/data_object_image_2_KINS/training/image_2/000015.png'
    img = Image.open(img_path).convert('RGB')  # ensure a 3-channel input
    img_size = img.size
    transform = make_coco_transforms('test')
    img_t, _ = transform(img, None)
    img_t = img_t.unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        output = model(img_t)
    print("detection done!")
    # pred_logits: [batch, num_queries, num_classes + 1]; drop the trailing "no object" column
    probs = output['pred_logits'].softmax(-1)[0, :, :-1]
    # keep only predictions whose best class score exceeds 0.7
    keep = probs.max(-1).values > 0.7
    bbox_scaled = rescale_boxes(output['pred_boxes'][0, keep], img_size)
    pro = probs[keep]
    print(pro, bbox_scaled)
    plot_results(img, pro, bbox_scaled)
def plot_results(pil_img, prob, boxes):
    import matplotlib.pyplot as plt
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), COLORS * 100):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()
def box_cxcywh_to_xyxy(x):
    # convert (center_x, center_y, w, h) to (x_min, y_min, x_max, y_max)
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_boxes(out_box, size):
    # scale normalized [0, 1] boxes to absolute pixel coordinates
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_box).cpu()
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b
if __name__ == '__main__':
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
args = parser.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
Things you need to change: the path to your own model checkpoint, the path of the image you want to visualize, and the class list at the top.
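For reference, here is a minimal, self-contained sketch of what the post-processing step above does: softmax over the class logits, dropping the trailing "no object" column, thresholding at 0.7, and rescaling normalized (cx, cy, w, h) boxes to pixel coordinates. The random tensors, the class count, and the 1242x375 image size are made-up values for illustration only; it runs without the model or the dataset:

# Standalone sketch (illustrative, not part of the original script).
import torch

num_queries, num_classes = 100, 10          # assumed values for the demo
pred_logits = torch.randn(1, num_queries, num_classes + 1)
pred_boxes = torch.rand(1, num_queries, 4)  # normalized (cx, cy, w, h)

probs = pred_logits.softmax(-1)[0, :, :-1]  # drop the "no object" column
keep = probs.max(-1).values > 0.7           # confidence threshold

def cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    return torch.stack([x_c - 0.5 * w, y_c - 0.5 * h,
                        x_c + 0.5 * w, y_c + 0.5 * h], dim=1)

img_w, img_h = 1242, 375                    # a typical KITTI/KINS image size
boxes = cxcywh_to_xyxy(pred_boxes[0, keep])
boxes = boxes * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
print(probs[keep].shape, boxes.shape)       # kept detections and their pixel boxes

Assuming you save the full script as, say, predict.py, running python predict.py with the defaults above will load checkpoints/checkpoint.pth and pop up the matplotlib visualization.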