在PyTorch中使用自己的数据集

太累了 看了一上午CSDN还是没搞明白

看的下面的up主的讲解  做一下笔记 免得忘记 

在pytorch中自定义dataset读取数据_哔哩哔哩_bilibili

主要内容:如何划分训练集 验证集 数据读取 预处理 

代码在github上 pytorch_classification文件夹下custom_dataset文件夹中,内有main.py my_dataset.py utils.py三个py文件

 先看main.py文件

import os

import torch
from torchvision import transforms

from my_dataset import MyDataSet
from utils import read_split_data, plot_data_loader_image

# http://download.tensorflow.org/example_images/flower_photos.tgz
root = "/home/wz/my_github/data_set/flower_data/flower_photos"  # 数据集所在根目录


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    train_data_set = MyDataSet(images_path=train_images_path,
                               images_class=train_images_label,
                               transform=data_transform["train"])

    batch_size = 8
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw,
                                               collate_fn=train_data_set.collate_fn)

    # plot_data_loader_image(train_loader)

    for step, data in enumerate(train_loader):
        images, labels = data


if __name__ == '__main__':
    main()

1需要更改第11行root为自己数据集的位置 且文件夹下包含的文件夹名字即为他们的标签

在PyTorch中使用自己的数据集_第1张图片

在PyTorch中使用自己的数据集_第2张图片

 2用read_split_data划分训练集和 验证集

main文件中查看read_split_data,跳转到utils第九行

#val_rate划分验证集占所有样本的比例 默认值是0.2
def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # 保证随机结果可复现 不管在谁的电脑上划分的数据集都一样
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # 遍历文件夹,一个文件夹对应一个类别
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序,保证顺序一致
    flower_class.sort()
    # 生成类别名称以及对应的数字索引
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # 存储训练集的所有图片路径
    train_images_label = []  # 存储训练集图片对应索引信息
    val_images_path = []  # 存储验证集的所有图片路径
    val_images_label = []  # 存储验证集图片对应索引信息
    every_class_num = []  # 存储每个类别的样本总数
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型
    # 遍历每个文件夹下的文件

 utils第54行 设置为True即可以将样本的数量可视化,在第71行return语句返回四个值到main函数中,在return语句设置断点 debug一下main函数

运行结果

Connected to pydev debugger (build 212.5284.44)
using cuda device.
3670 images were found in the dataset.
2939 images for training.

在PyTorch中使用自己的数据集_第3张图片

 3main.py中预处理图片

my_dataset.py 中第20行如果使用的不是RGB图像可以自行更改 

 使用PIL来预处理图片(也可以用opencv 但pytorch中用pil的预处理较多)

想让他不要显示裁剪后的图片,在main.py第43行,plot_data_loader_image,点击跳转到utils文件定义他的位置,将91-96行注释掉 

在PyTorch中使用自己的数据集_第4张图片

你可能感兴趣的:(ResNet网络,pytorch,深度学习,人工智能)