torchvision is an image-processing library distributed separately from PyTorch itself.
For a detailed introduction, see DrHW's article.
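torchvision mainly consists of the following packages [1]:
- `torchvision.datasets`: loaders for common datasets such as MNIST and CIFAR-10
- `torchvision.models`: common model architectures such as AlexNet, VGG and ResNet, optionally with pretrained weights
- `torchvision.transforms`: common image transformations
- `torchvision.utils`: utilities such as `make_grid` and `save_image`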
The following template is quoted from reference [2]:
```python
"""
Input pipeline for a custom dataset
"""
from torch.utils.data.dataset import Dataset


class CustomDataset(Dataset):
    def __init__(self):
        """
        Put the initialization steps here.
        """
        # TODO
        # 1. Initialize file paths or a list of file names.
        pass

    def __getitem__(self, index):
        """
        Return one (data, label) pair. It can be called explicitly:
            img, label = custom_dataset.__getitem__(99)
        or, equivalently, by indexing: custom_dataset[99]
        """
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.transforms).
        # 3. Return a data pair (e.g. image and label).
        pass

    def __len__(self):
        """
        Return the total number of samples.
        """
        # You should change 9 to the total size of your dataset.
        return 9  # e.g. 9 is size of dataset
```
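As a rough illustration of how the three TODO steps fit together, here is a toy map-style dataset that keeps its samples in memory instead of reading files. It is a sketch under assumptions: the class name `InMemoryDataset` and its arguments are hypothetical, and it does not subclass `torch.utils.data.Dataset` only so that it runs without PyTorch installed; in real code you would subclass `Dataset` as in the template.

```python
class InMemoryDataset:
    """Hypothetical toy dataset: samples live in two Python lists."""

    def __init__(self, images, labels, transform=None):
        # __init__ step: store the data (a real dataset would store file paths here)
        assert len(images) == len(labels), "need one label per image"
        self.images = images
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        # Step 1: "read" one sample
        img = self.images[index]
        # Step 2: optional preprocessing
        if self.transform is not None:
            img = self.transform(img)
        # Step 3: return a (data, label) pair
        return img, self.labels[index]

    def __len__(self):
        # Total number of samples, computed from the data instead of hard-coded
        return len(self.images)


ds = InMemoryDataset([[0, 1], [2, 3], [4, 5]], ["a", "b", "c"],
                     transform=lambda img: [2 * p for p in img])
print(len(ds))            # 3
print(ds[1])              # ([4, 6], 'b')
print(ds.__getitem__(1))  # same as ds[1]
```

Computing `__len__` from the stored data avoids the hard-coded `9` in the template going stale when the dataset changes.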
```python
from torch.utils.data.dataset import Dataset
from torchvision import transforms


class MyCustomDataset(Dataset):
    def __init__(self, transforms=None):  # add your own arguments before `transforms`
        # stuff
        ...
        self.transforms = transforms

    def __getitem__(self, index):
        # stuff
        ...
        data = ...   # some data you have read (e.g. a PIL image)
        label = ...  # the corresponding label
        # If a transform was given, apply it
        if self.transforms is not None:
            data = self.transforms(data)
        return (data, label)

    def __len__(self):
        return count  # replace `count` with the number of samples


if __name__ == '__main__':
    # Define our transforms (1)
    transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()])
    # Create the dataset (pass your own arguments as well)
    custom_dataset = MyCustomDataset(transforms=transformations)
```
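`transforms.Compose` applies the listed transforms one after another, in list order, each receiving the previous one's output. A simplified stand-in shows the idea; `SimpleCompose` is a hypothetical class written for illustration, not torchvision's actual source, and the string operations stand in for real image transforms such as `CenterCrop` and `ToTensor`.

```python
class SimpleCompose:
    """Minimal sketch of what a Compose-style transform does."""

    def __init__(self, transforms):
        # Keep the callables in the order they were given
        self.transforms = list(transforms)

    def __call__(self, data):
        # Feed the output of each transform into the next one
        for t in self.transforms:
            data = t(data)
        return data


pipeline = SimpleCompose([str.strip, str.upper, lambda s: s + "!"])
print(pipeline("  hello "))  # HELLO!
```

Order matters: with `Compose([CenterCrop(100), ToTensor()])` the crop runs on the PIL image before it is converted to a tensor.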
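```python
from torch.utils.data.dataset import Dataset
from torchvision import transforms


class MyCustomDataset(Dataset):
    def __init__(self):  # add your own arguments here
        # stuff
        ...
        # (2) One way is to define each transform separately
        self.center_crop = transforms.CenterCrop(100)
        self.to_tensor = transforms.ToTensor()
        # (3) Or compose them like this
        self.transformations = \
            transforms.Compose([transforms.CenterCrop(100),
                                transforms.ToTensor()])

    def __getitem__(self, index):
        # stuff
        ...
        data = ...   # some data you have read (e.g. a PIL image)
        label = ...  # the corresponding label
        # The calls in __init__ only constructed the transform objects;
        # calling an object here invokes its __call__(), which applies the transform
        data = self.center_crop(data)  # (2)
        data = self.to_tensor(data)    # (2)
        # Or, instead of the two lines above, use the composed version:
        # data = self.transformations(data)  # (3)
        # Note: only one of (2) and (3) is needed
        return (data, label)

    def __len__(self):
        return count  # replace `count` with the number of samples


if __name__ == '__main__':
    custom_dataset = MyCustomDataset()  # pass your own arguments
```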
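The difference between constructing a transform and calling it is easiest to see in a hand-written transform class. Below is a hypothetical `CenterCropList` that center-crops an image stored as a nested list of pixel rows: `__init__` only records the target size, and `__call__` does the work each time the object is called. It is a sketch; torchvision's own `CenterCrop` also handles PIL images and tensors, and may round the offsets differently.

```python
class CenterCropList:
    """Hypothetical center crop for an image given as a list of pixel rows."""

    def __init__(self, size):
        # Construction only stores the parameter; no image is touched yet
        self.size = size

    def __call__(self, img):
        # Called on every sample: cut a size x size window from the middle
        h, w = len(img), len(img[0])
        if self.size > h or self.size > w:
            raise ValueError("crop size larger than image")
        top = (h - self.size) // 2
        left = (w - self.size) // 2
        return [row[left:left + self.size] for row in img[top:top + self.size]]


crop = CenterCropList(2)  # __init__ runs once
image = [[0, 1, 2, 3],
         [4, 5, 6, 7],
         [8, 9, 10, 11],
         [12, 13, 14, 15]]
print(crop(image))        # __call__ runs: [[5, 6], [9, 10]]
```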
Another case is when a CSV file stores the pixel values of the images we need (some MNIST tutorials do it this way). Then `__getitem__()` has to be changed accordingly.
Label | pixel_1 | pixel_2 | … |
---|---|---|---|
1 | 50 | 99 | … |
0 | 21 | 223 | … |
9 | 44 | 112 | … |
```python
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision import transforms


class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path, height, width, transforms=None):
        """
        Args:
            csv_path (string): path to the csv file
            height (int): image height
            width (int): image width
            transforms: transform(s) to apply to each image
        """
        self.data = pd.read_csv(csv_path)
        self.labels = np.asarray(self.data.iloc[:, 0])
        self.height = height
        self.width = width
        self.transforms = transforms

    def __getitem__(self, index):
        single_image_label = self.labels[index]
        # Read all pixel values and reshape the 1D array ([784]) into a 2D array ([28, 28])
        img_as_np = np.asarray(self.data.iloc[index, 1:]).reshape(self.height, self.width).astype('uint8')
        # Convert the numpy array into a grayscale PIL image
        img_as_img = Image.fromarray(img_as_np)
        img_as_img = img_as_img.convert('L')
        # Convert the image to a tensor
        if self.transforms is not None:
            img_as_tensor = self.transforms(img_as_img)
        else:
            # `transforms` here is the torchvision module (the argument only exists in __init__)
            img_as_tensor = transforms.ToTensor()(img_as_img)
        # Return the image and its label
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return len(self.data.index)


if __name__ == "__main__":
    transformations = transforms.Compose([transforms.ToTensor()])
    custom_mnist_from_csv = \
        CustomDatasetFromCSV('../data/mnist_in_csv.csv', 28, 28, transformations)
```
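The core of `__getitem__` above is turning one CSV row into a label plus a `height × width` pixel grid. The sketch below does the same with only the standard library, on a made-up 2×2 "image" so the reshape is easy to check by hand; `row_to_image` is a hypothetical helper, and in the dataset itself pandas and NumPy do this work.

```python
import csv
import io


def row_to_image(row, height, width):
    """Split a CSV row into (label, image), image being `height` rows of `width` pixels."""
    label = int(row[0])
    pixels = [int(p) for p in row[1:]]
    if len(pixels) != height * width:
        raise ValueError(f"expected {height * width} pixels, got {len(pixels)}")
    # Row-major reshape: the same order NumPy's reshape uses by default
    image = [pixels[r * width:(r + 1) * width] for r in range(height)]
    return label, image


# Made-up CSV in the same layout as the table above: label first, then pixels
csv_text = "label,pixel_1,pixel_2,pixel_3,pixel_4\n1,50,99,0,255\n0,21,223,7,8\n"
reader = csv.reader(io.StringIO(csv_text))
next(reader)  # skip the header row, which pd.read_csv treats as column names by default
samples = [row_to_image(row, 2, 2) for row in reader]
print(samples[0])  # (1, [[50, 99], [0, 255]])
```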
PyTorch's `DataLoader` essentially calls `__getitem__()` and assembles the results into batches. We can use it like this:
```python
import torch
from torchvision import transforms

...  # CustomDatasetFromCSV as defined above

if __name__ == "__main__":
    # Define the transforms
    transformations = transforms.Compose([transforms.ToTensor()])
    # Custom dataset
    custom_mnist_from_csv = \
        CustomDatasetFromCSV('../data/mnist_in_csv.csv',
                             28, 28,
                             transformations)
    # Define the data loader
    mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,
                                                    batch_size=10,
                                                    shuffle=False)
    for images, labels in mn_dataset_loader:
        # Feed the data to your model here
        pass
```
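With `shuffle=False`, that batching amounts to walking the indices in order, calling `__getitem__` for each, and grouping the results. The hypothetical `simple_batches` below is a simplified model of that loop, assuming the default `drop_last=False` (so the last batch may be smaller) and replacing PyTorch's tensor-stacking collate step with plain lists; the real `DataLoader` also supports shuffling, worker processes and custom collate functions.

```python
def simple_batches(dataset, batch_size):
    """Yield (inputs, labels) batches from a map-style dataset, in index order."""
    batch = []
    for index in range(len(dataset)):
        batch.append(dataset[index])  # this is the __getitem__ call
        if len(batch) == batch_size:
            yield collate(batch)
            batch = []
    if batch:  # leftover samples form a smaller final batch
        yield collate(batch)


def collate(samples):
    # Turn [(x0, y0), (x1, y1), ...] into ([x0, x1, ...], [y0, y1, ...])
    inputs, labels = zip(*samples)
    return list(inputs), list(labels)


# A list of (data, label) pairs already supports __getitem__ and __len__
toy = [(i * 10, i % 2) for i in range(5)]
for images, labels in simple_batches(toy, batch_size=2):
    print(images, labels)
# [0, 10] [0, 1]
# [20, 30] [0, 1]
# [40] [0]
```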
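Note that when training on multiple GPUs with `nn.DataParallel`, each batch produced by the DataLoader is split as evenly as possible across the GPUs along the batch dimension. If the batch size is too small, each GPU receives only a few samples (or none), and you may not benefit from using multiple GPUs.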
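To see why a small batch hurts, consider how a batch is divided. Assuming the split behaves like `torch.chunk` along the batch dimension (each piece gets `ceil(batch_size / num_gpus)` samples, and the last piece may be smaller or missing entirely), the hypothetical helper below computes the per-GPU sizes in plain Python.

```python
import math


def split_sizes(batch_size, num_gpus):
    """Per-device sizes when a batch is cut into chunks of ceil(batch_size / num_gpus)."""
    chunk = math.ceil(batch_size / num_gpus)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        sizes.append(min(chunk, remaining))
        remaining -= chunk
    return sizes


print(split_sizes(10, 4))  # [3, 3, 3, 1]
print(split_sizes(32, 4))  # [8, 8, 8, 8]
print(split_sizes(2, 4))   # [1, 1]: only 2 of the 4 GPUs get any data
```

With a batch of 2 on 4 GPUs, half of the devices sit idle; with a batch of 10 the last GPU gets a single sample.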
A more complete example: a dataset that returns each image together with a bounding box parsed from its XML annotation file (in test mode it returns only the image):

```python
import random
import re

import torch
from PIL import Image, ImageDraw
from torch.utils.data.dataset import Dataset
from torchvision import transforms


class MyDataset(Dataset):
    def __init__(self, file_folder, is_test=False, transform=None):
        self.img_folder_path = '../input/images/Images/'
        self.annotation_folder_path = '../input/annotations/Annotation/'
        self.file_folder = file_folder  # list of image file names (relative paths)
        self.transform = transform
        self.is_test = is_test

    def __getitem__(self, idx):
        file = self.file_folder[idx]
        img_path = self.img_folder_path + file
        img = Image.open(img_path).convert('RGB')
        if not self.is_test:
            # The annotation file has the image's name without the extension
            annotation_path = self.annotation_folder_path + file.split('.')[0]
            with open(annotation_path) as f:
                annotation = f.read()
            xy = self.get_xy(annotation)
            box = torch.FloatTensor(list(xy))
            # Rescale the box from the original image size to `dims`
            new_box = self.box_resize(box, img)
            if self.transform is not None:
                img = self.transform(img)
            return img, new_box
        else:
            if self.transform is not None:
                img = self.transform(img)
            return img

    def __len__(self):
        return len(self.file_folder)

    def get_xy(self, annotation):
        # Extract the coordinates from XML tags such as <xmin>48</xmin>
        xmin = int(re.findall('(?<=<xmin>)[0-9]+?(?=</xmin>)', annotation)[0])
        xmax = int(re.findall('(?<=<xmax>)[0-9]+?(?=</xmax>)', annotation)[0])
        ymin = int(re.findall('(?<=<ymin>)[0-9]+?(?=</ymin>)', annotation)[0])
        ymax = int(re.findall('(?<=<ymax>)[0-9]+?(?=</ymax>)', annotation)[0])
        return xmin, ymin, xmax, ymax

    def show_box(self):
        # Draw the bounding box on a random image for a quick visual check
        file = random.choice(self.file_folder)
        annotation_path = self.annotation_folder_path + file.split('.')[0]
        img_box = Image.open(self.img_folder_path + file)
        with open(annotation_path) as f:
            annotation = f.read()
        draw = ImageDraw.Draw(img_box)
        xy = self.get_xy(annotation)
        print('bbox:', xy)
        draw.rectangle(xy=[xy[:2], xy[2:]])
        return img_box

    def box_resize(self, box, img, dims=(332, 332)):
        # dims is (height, width); box is (xmin, ymin, xmax, ymax)
        old_dims = torch.FloatTensor([img.width, img.height, img.width, img.height]).unsqueeze(0)
        new_box = box / old_dims
        new_dims = torch.FloatTensor([dims[1], dims[0], dims[1], dims[0]]).unsqueeze(0)
        new_box = new_box * new_dims
        return new_box
```
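`get_xy` pulls the coordinates out with regular expressions, and `box_resize` rescales them from the original image size to `dims`. The sketch below performs both steps in plain Python on a made-up annotation string. It assumes PASCAL VOC-style XML (a `<bndbox>` element containing `<xmin>`, `<ymin>`, `<xmax>`, `<ymax>`), and swaps the regexes for the standard-library `xml.etree.ElementTree` parser, which does not depend on the exact text layout of the file; the helper names `parse_box` and `resize_box` are hypothetical.

```python
import xml.etree.ElementTree as ET


def parse_box(annotation):
    """Return (xmin, ymin, xmax, ymax) of the first <bndbox> in a VOC-style XML string."""
    root = ET.fromstring(annotation)
    bndbox = root.find(".//bndbox")
    return tuple(int(bndbox.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax"))


def resize_box(box, old_size, new_size):
    """Scale a box from old_size=(width, height) to new_size=(width, height)."""
    (old_w, old_h), (new_w, new_h) = old_size, new_size
    xmin, ymin, xmax, ymax = box
    # x coordinates scale with the width, y coordinates with the height
    return (xmin * new_w / old_w, ymin * new_h / old_h,
            xmax * new_w / old_w, ymax * new_h / old_h)


annotation = """<annotation>
  <size><width>500</width><height>400</height></size>
  <object><bndbox>
    <xmin>50</xmin><ymin>40</ymin><xmax>250</xmax><ymax>200</ymax>
  </bndbox></object>
</annotation>"""

box = parse_box(annotation)
print(box)                                      # (50, 40, 250, 200)
print(resize_box(box, (500, 400), (332, 332)))  # (33.2, 33.2, 166.0, 166.0)
```

Note that `resize_box` takes sizes as (width, height), whereas `box_resize` above takes `dims` as (height, width).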
Another example is a face-landmarks dataset that returns each sample as a dict holding the image and its landmark coordinates:

```python
import os

import pandas as pd
from skimage import io
from torch.utils.data.dataset import Dataset


class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        # First column: image file name; remaining columns: landmark coordinates
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        # Convert the landmark columns to a NumPy array of (x, y) rows
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}
        if self.transform:
            sample = self.transform(sample)
        return sample
```
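Each CSV row stores the landmarks flattened as x1, y1, x2, y2, …, which `reshape(-1, 2)` turns into one (x, y) row per point, and `transform` receives the whole `sample` dict rather than just the image. The sketch below illustrates both with NumPy on made-up numbers; `NormalizeLandmarks` is a hypothetical dict-aware transform (not part of torchvision) that expresses each coordinate as a fraction of the image size and leaves the image untouched.

```python
import numpy as np


class NormalizeLandmarks:
    """Hypothetical transform: landmark coordinates as fractions of the image size."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        h, w = image.shape[:2]
        # Column 0 holds x (divide by width), column 1 holds y (divide by height)
        scaled = landmarks / np.array([w, h], dtype=float)
        return {'image': image, 'landmarks': scaled}


# A made-up row: 3 landmarks flattened as x1, y1, x2, y2, x3, y3
flat = np.array([10, 20, 30, 40, 50, 60])
landmarks = flat.astype('float').reshape(-1, 2)  # shape (3, 2): one (x, y) per row
image = np.zeros((100, 200, 3), dtype=np.uint8)  # height 100, width 200, RGB
sample = NormalizeLandmarks()({'image': image, 'landmarks': landmarks})
print(landmarks.shape)                 # (3, 2)
print(sample['landmarks'][0].tolist())  # [0.05, 0.2]
```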
References:
1. https://www.cnblogs.com/yjphhw/p/9773333.html
2. https://github.com/yunjey/pytorch-tutorial/