【PyTorch图像语义分割】4. 使用训练好的模型测试

使用训练好的模型测试新图片

  • 1. 图像的加载
  • 2. 用网络forward测试

1. 图像的加载

测试图像的加载仍然是通过继承torch.ultis.data.Dataset加载。在加载训练图像的时候用的是

class UAVDataSet(torch.utils.data.Dataset):

需要可以返回数据、标签。但是在测试新图像的时候没有标签,故只需要返回数据就行,代码如下:

class UAVTestSet(torch.utils.data.Dataset):
    def __init__(self, root, list_path): 
        super(UAVDataSet,self).__init__()
        self.root = root
        self.list_path = list_path
        self.img_ids = [i_id.strip() for i_id in open(list_path)]
        self.files = []
        
        for name in self.img_ids:
            img_file = os.path.join(self.root, "UAVtest/%s.JPG" % name)
            self.files.append({
                "img": img_file,
                "name": name
            })
   
    def __len__(self):
        return len(self.files) 
 
    def __getitem__(self, index):
        
        datafiles = self.files[index]
 
        '''load the datas'''
        name = datafiles["name"]
        image = Image.open(datafiles["img"]).convert('RGB')
        size_origin = image.size # W * H
 
        '''convert PIL Image to numpy array'''
        I = np.asarray(image,np.float32) 
        I = I.transpose((2,0,1))#transpose the  H*W*C to C*H*W
        #print(I.shape,L.shape)
        return I.copy(), np.array(size_origin), name

但要是进行test的话,应该仍然还是需要返回标签的。

  • 改进:这两个类写成一个类?

返回了读取的数据后,加载成Tensor的形式仍然用torch.utils.data.DataLoader(),代码:

TEST_DIRECTORY = './'
TEST_LIST_PATH = './UAVtest.txt'
Batch_size = 1
dst = UAVTestSet(TEST_DIRECTORY,TEST_LIST_PATH)
testloader = torch.utils.data.DataLoader(dst, batch_size = Batch_size)

2. 用网络forward测试

for i, testdata in enumerate(testloader,0):
    inputs, _size, _name = testdata
    outputs = M(inputs) # 输出是 batchsize * C(类别数) * H * W Tensor
    pred = torch.max(outputs,1) # 在1轴(类别)上取max作为输出
    pred = pred[1].detach().numpy() # pred是need_grad()的,用.detach()方法后再.numpy()
                                    # 或者在测试时候设置不需梯度?
    pred = np.squeeze(pred)         # 此时输出1*19*25的ndarray,squeeze成二维
    pred = pred*32                  # 7类,对应为 c * 32 的灰度值,C为 0-6
    
    # 转换成可以imwrite()为.jpg的三通道图像
    pic = np.zeros([19,25,3]) 
    pic[:,:,0] = pred    
    pic[:,:,1] = pred 
    pic[:,:,2] = pred 
    # 最近邻插值恢复到输入图像大小
    pic1 = cv2.resize(pic, (800,600),interpolation=cv2.INTER_NEAREST)
    cv2.imwrite('%d.jpg'%i, pic1)

输出:
在这里插入图片描述

你可能感兴趣的:(PyTorch,图像语义分割)