# -*- coding: utf-8 -*-
'''
@Time : 2020/5/20 22:28
@Author : HHNa
@FileName: dataset.py
@Software: PyCharm
'''
import os, sys, glob, shutil, json
import cv2
from PIL import Image
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
class SVHNDataset(Dataset):
    """Image/label dataset for multi-digit (SVHN-style) character recognition.

    Item ``index`` is ``(image, label)``:

    * ``image`` is ``img_path[index]`` opened with PIL and converted to RGB,
      then passed through ``transform`` when one is given.
    * ``label`` is a 1-D ``torch.int64`` tensor of exactly ``max_len`` class
      ids: ``img_label[index]`` truncated to ``max_len`` entries and
      right-padded with ``pad_value``.

    Args:
        img_path: Sequence of image file paths, indexed in parallel with
            ``img_label``.
        img_label: Sequence of per-image label sequences (one class id per
            character).
        transform: Optional callable applied to the PIL image.
        max_len: Fixed label length; longer labels are truncated.
        pad_value: Class id used to pad labels shorter than ``max_len``.
    """

    def __init__(self, img_path, img_label, transform=None, max_len=5, pad_value=10):
        self.img_path = img_path
        self.img_label = img_label
        self.transform = transform  # None means "use the RGB image as-is"
        self.max_len = max_len
        # In the original SVHN release class 10 stood for the digit 0. Here 10
        # is used as the padding ("no character") class, presumably because
        # this dataset's labels already use 0 for the digit 0. TODO confirm
        # against the annotation file.
        self.pad_value = pad_value

    def _encode_label(self, label):
        """Return *label* truncated/right-padded to ``max_len`` as an int64 tensor."""
        digits = [int(d) for d in label][: self.max_len]
        digits += [self.pad_value] * (self.max_len - len(digits))
        # Explicit int64 (Long) dtype. The removed ``np.int`` alias produced
        # int32 on Windows, which classification losses reject.
        return torch.tensor(digits, dtype=torch.int64)

    def __getitem__(self, index):
        # ``with`` releases the underlying file handle as soon as the pixels
        # have been copied by ``convert``.
        with Image.open(self.img_path[index]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self._encode_label(self.img_label[index])

    def __len__(self):
        return len(self.img_path)
# Training images, sorted so the sample order is deterministic across runs.
train_path = glob.glob('./data/train/mchar_train/*.png')
train_path.sort()

# Read the annotations inside ``with`` so the file handle is closed promptly.
# JSON is UTF-8 by specification, so don't depend on the OS locale encoding.
with open('./data/train/mchar_train.json', encoding='utf-8') as json_file:
    train_json = json.load(json_file)

# NOTE(review): labels are taken in the JSON's key order and paired by
# position with the sorted image paths. This assumes the keys are the image
# file names in that same sorted order. Confirm against the data.
train_label = [train_json[x]['label'] for x in train_json]

train_loader = torch.utils.data.DataLoader(
    SVHNDataset(train_path, train_label,
                transforms.Compose([
                    transforms.Resize((64, 128)),
                    transforms.ColorJitter(0.3, 0.3, 0.2),
                    transforms.RandomRotation(5),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                ])),
    batch_size=10,   # number of samples per batch
    shuffle=False,   # whether to shuffle the sample order
    num_workers=10,  # number of worker *processes* used for loading
)

if __name__ == '__main__':
    # Iterating starts the worker processes. On Windows (spawn start method)
    # every worker re-imports this module, so the iteration must sit behind
    # the main guard or PyTorch raises a RuntimeError during bootstrapping.
    for data in train_loader:
        break
Traceback (most recent call last):
File "G:/my_github_project/skills/CV_DataWhale_517/dataset.py", line 19, in <module>
import torchvision as transforms
File "C:\Users\hhn\Anaconda3\envs\pytorch_1.2_gpu\lib\site-packages\torchvision\__init__.py", line 1, in <module>
from torchvision import models
File "C:\Users\hhn\Anaconda3\envs\pytorch_1.2_gpu\lib\site-packages\torchvision\models\__init__.py", line 11, in <module>
from . import detection
File "C:\Users\hhn\Anaconda3\envs\pytorch_1.2_gpu\lib\site-packages\torchvision\models\detection\__init__.py", line 1, in <module>
from .faster_rcnn import *
File "C:\Users\hhn\Anaconda3\envs\pytorch_1.2_gpu\lib\site-packages\torchvision\models\detection\faster_rcnn.py", line 7, in <module>
from torchvision.ops import misc as misc_nn_ops
File "C:\Users\hhn\Anaconda3\envs\pytorch_1.2_gpu\lib\site-packages\torchvision\ops\__init__.py", line 1, in <module>
from .boxes import nms, box_iou
File "C:\Users\hhn\Anaconda3\envs\pytorch_1.2_gpu\lib\site-packages\torchvision\ops\boxes.py", line 2, in <module>
from torchvision import _C
ImportError: DLL load failed: 找不到指定的模块。
Process finished with exit code 1
尝试过的解决办法：
https://blog.csdn.net/zhenyu_an/article/details/103940020/
但按这种方式下载非常慢。
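后来我发现，原因可能是 torchvision 与 PyTorch 的版本不匹配：环境名 pytorch_1.2_gpu 表明装的是 PyTorch 1.2，与之配套的是 torchvision 0.4.0，而已安装的 torchvision 0.3 对应的是 PyTorch 1.1。
所以我先卸载了 torchvision 0.3（pip uninstall torchvision），
再安装 torchvision 0.4.0：pip install torchvision-0.4.0-cp36-cp36m-win_amd64.whl
问题完美解决！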