零基础入门语义分割-Task3网络模型结构发展

# Predict a binary mask for every test image listed in the sample-submission
# CSV and append [file name, RLE-encoded mask] rows to `subm`.
# Relies on names defined earlier in the notebook: pd, np, cv2, torch,
# tqdm_notebook, trfm, DEVICE, model, subm and rle_encode.
TEST_DIR = 'D:/WorkPlace/Jupyter/test_a/'

test_mask = pd.read_csv('D:/WorkPlace/Jupyter/test_a_samplesubmit.csv', sep='\t', names=['name', 'mask'])
test_mask['name'] = test_mask['name'].apply(lambda x: TEST_DIR + x)

for idx, name in enumerate(tqdm_notebook(test_mask['name'])):
    image = cv2.imread(name)
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise surface later as an opaque TypeError inside trfm.
    # Fail early with the offending path (e.g. when only part of test_a was
    # extracted but the CSV still lists every test image).
    if image is None:
        raise FileNotFoundError(f'Cannot read test image: {name}')
    image = trfm(image)
    with torch.no_grad():
        image = image.to(DEVICE)[None]  # add a leading batch dimension
        score = model(image)['out'][0][0]
        score_sigmoid = score.sigmoid().cpu().numpy()
        score_sigmoid = (score_sigmoid > 0.5).astype(np.uint8)  # threshold at 0.5
        score_sigmoid = cv2.resize(score_sigmoid, (512, 512))
    subm.append([name.split('/')[-1], rle_encode(score_sigmoid)])

运行这段代码时遇到了 bug，报错如下：

0%
0/2500 [00:00<?, ?it/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-36-eb7b00e9caef> in <module>
      4 for idx, name in enumerate(tqdm_notebook(test_mask['name'].iloc[:])):#得到属性名,数据类型
      5     image = cv2.imread(name)
----> 6     image = trfm(image)
      7     with torch.no_grad():
      8         image = image.to(DEVICE)[None]

E:\Anaconda\envs\py36\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, img)
     68     def __call__(self, img):
     69         for t in self.transforms:
---> 70             img = t(img)
     71         return img
     72 

E:\Anaconda\envs\py36\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, pic)
    134 
    135         """
--> 136         return F.to_pil_image(pic, self.mode)
    137 
    138     def __repr__(self):

E:\Anaconda\envs\py36\lib\site-packages\torchvision\transforms\functional.py in to_pil_image(pic, mode)
    118     """
    119     if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
--> 120         raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
    121 
    122     elif isinstance(pic, torch.Tensor):

TypeError: pic should be Tensor or ndarray. Got <class 'NoneType'>.

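原因是我只解压了 test_a 中的一部分图片（大约 300 张）用于测试，而 CSV 文件中列出了全部 2500 张测试图片，因此大部分路径在 test_a 目录下并不存在。cv2.imread 读取不存在的文件时不会报错，而是返回 None，于是 trfm(image) 抛出了上面的 TypeError（Got <class 'NoneType'>）。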
解决方法：将 test_a 的全部图片解压后重新运行即可。更稳妥的做法是在 cv2.imread 之后检查返回值是否为 None，这样文件缺失时会直接给出包含文件路径的明确报错。

你可能感兴趣的:(深度学习)