Training on Your Own Dataset with PyTorch

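Use one of the models that ships with PyTorch (via torchvision) and replace its fully connected layer so that the output size matches the number of classes in your own dataset.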

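import torch
import torch.nn as nn
from torchvision import models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet50(pretrained=False)  # newer torchvision versions use weights=None instead
class_num = 62  # number of classes in the dataset
fc_features = model.fc.in_features  # input width of the original fully connected layer
model.fc = nn.Linear(fc_features, class_num)  # new head with class_num outputs
model = model.to(device)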

Based on the dataset folder layout, build a txt file that lists each image path together with its label.

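import os

data = []
labels = []
filetype = "ppm"
# Run this from inside the split folder, where each class has its own numbered subfolder
for root, dirs, files in os.walk("./"):
    for f in files:
        # splitext also copes with file names that have no dot or several dots
        if os.path.splitext(f)[1] == "." + filetype:
            data.append(root[1:] + "/" + f)  # "./00001" -> "/00001/<file>"
            labels.append(int(root[2:]))     # "./00001" -> 1

with open("datalist.txt", 'w') as fi:
    # One "path label" pair per line, with no trailing newline; also safe when no file matched
    fi.write("\n".join("{} {}".format(d, l) for d, l in zip(data, labels)))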

This produces the result shown in the figure.
[Figure 1: the resulting file list]
Reference: https://blog.csdn.net/sinat_42239797/article/details/90641659
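
To see the list format without the real dataset, the same idea can be tried on a throwaway folder. This is a minimal sketch: `build_datalist` is a hypothetical helper (not part of the original script), and the class folders and file names are made up for the demonstration. It assumes a flat layout in which each class folder name is an integer.

```python
import os
import tempfile


def build_datalist(split_dir, filetype="ppm"):
    """Return sorted (relative_path, label) pairs for images under split_dir."""
    entries = []
    for root, _dirs, files in os.walk(split_dir):
        rel = os.path.relpath(root, split_dir)
        if rel == ".":
            continue  # skip files that sit directly in split_dir
        for name in files:
            if os.path.splitext(name)[1] == "." + filetype:
                # The class folder name (e.g. "00003") doubles as the label
                entries.append(("/" + rel + "/" + name, int(rel)))
    return sorted(entries)


with tempfile.TemporaryDirectory() as tmp:
    # Made-up layout: two class folders, one file with a non-matching extension
    for cls, name in [("00000", "a.ppm"), ("00003", "b.ppm"), ("00003", "readme.csv")]:
        os.makedirs(os.path.join(tmp, cls), exist_ok=True)
        open(os.path.join(tmp, cls, name), "w").close()
    entries = build_datalist(tmp)

for path, label in entries:
    print(path, label)
# /00000/a.ppm 0
# /00003/b.ppm 3
```

Sorting the entries keeps the list stable between runs, since os.walk does not guarantee a particular order.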
Create a Dataset class and define how the data is read.

from PIL import Image
from torch.utils.data import Dataset


def default_loader(path):
    # Force 3-channel RGB so every image has the same layout
    return Image.open(path).convert('RGB')


class BTSCDataset(Dataset):
    def __init__(self, dir, transforms=None, loader=default_loader, train=True):
        super(BTSCDataset, self).__init__()
        self.sub_directory = 'Training' if train else 'Testing'
        # Folder that holds datalist.txt and the class subfolders
        base = dir + "/BelgiumTSC_" + self.sub_directory + "/" + self.sub_directory
        imgs = []
        with open(base + "/datalist.txt", "r") as f:
            data = f.readlines()
        for line in data:
            fields = line.split()
            if not fields:
                continue  # ignore blank lines
            # Each line is "<path> <label>"; the path already starts with "/"
            imgs.append((base + fields[0], int(fields[1])))
        self.imgs = imgs
        self.transforms = transforms
        self.loader = loader

    def __getitem__(self, item):
        fn, label = self.imgs[item]
        img = self.loader(fn)
        if self.transforms is not None:
            # Apply the transform to the image, not to the dataset object
            img = self.transforms(img)
        # Return the label (not the file path) so the training loop gets its targets
        return img, label

    def __len__(self):
        return len(self.imgs)

Reference: https://www.pythonf.cn/read/156398
Design the training procedure.
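
The original post stops here, so what follows is only one possible training loop, not the author's. To keep the sketch runnable anywhere, it uses random tensors in place of BTSCDataset and a tiny linear model in place of ResNet-50. The loss function, optimizer, learning rate, batch size, and epoch count are all assumptions picked for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class_num = 62

# Stand-ins: random images and labels instead of the real dataset,
# and a single linear layer instead of the modified ResNet-50
train_set = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, class_num, (64,)))
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, class_num)).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

epoch_losses = []
for epoch in range(3):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(images)            # forward pass
        loss = criterion(outputs, labels)  # cross-entropy against integer labels
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights

        # Accumulate statistics for the epoch summary
        running_loss += loss.item() * labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    epoch_losses.append(running_loss / total)
    print("epoch {} loss {:.4f} acc {:.3f}".format(epoch + 1, epoch_losses[-1], correct / total))
```

To train on the real data, swap the stand-ins for a DataLoader over BTSCDataset and the ResNet-50 with the replaced fc layer. A separate loop over the Testing split, run under model.eval() and torch.no_grad(), would give a validation accuracy.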
