SiamRPN++测试过程

import argparse
import collections
import datetime
import imp
import os
import pickle
import time
import lmdb
import ipdb
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.nn.modules.module import Module
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from dataloader.dataset import ImagnetVIDDataset
from network.SiamRPN import *
from utils.AverageMeter import AverageMeter
from utils.Logger import Logger
from utils.loss import rpn_cross_entropy_balance, rpn_smoothL1

# Smoke test for a trained SiamRPN checkpoint: load the weights, push one random
# input pair through the network and print the shapes of the two outputs.

# Build the network and restore the trained weights. The checkpoint is expected
# to be a dict whose 'network' entry holds the model state_dict.
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine; the network and the inputs below all stay on the CPU.
net = SiamRPN()
params = torch.load('your_own_trained_weight.pth', map_location='cpu')
net.load_state_dict(params['network'])
# Switch to inference mode so layers such as batch-norm / dropout (if the
# network has any) behave as they do at test time.
net.eval()

# Random N(0, 1) inputs. Presumably `a` is the template patch (127x127) and
# `b` the search-region patch (255x255) -- confirm against SiamRPN.forward.
# torch.Tensor(...) only allocates uninitialised memory, so BOTH inputs must be
# filled explicitly; torch.randn allocates and samples in one step.
a = torch.randn(1, 3, 127, 127)
b = torch.randn(1, 3, 255, 255)

# Forward-only test: no autograd graph is needed.
with torch.no_grad():
    aa, bb = net(a, b)

# Per the accompanying notes, aa is the classification (cls) output with shape
# (1, 10, 25, 25) and bb is the regression (reg) output with shape (1, 20, 25, 25).
print(aa.shape)
print(bb.shape)

aa 为分类（cls）分支的输出，形状为 (1, 10, 25, 25)。

bb 为回归（reg）分支的输出，形状为 (1, 20, 25, 25)。

demo 已修改完毕：

（图 1：SiamRPN++ 测试过程）

你可能感兴趣的:(SiamRPN++测试过程)