pytorch(一)——用python自動生成train,val文件

任務——分類

數據集爲5個不同類別的圖片集,每個圖片集大概有3W張圖片。所以要建立一個train訓練的txt文件和一個val驗證的txt文件,裏面放圖片的路徑,因爲只是練手用,所以不放test驗證。

pytorch(一)——用python自動生成train,val文件_第1张图片
最終要的結果是從每個文件裏拿出28000個訓練和剩下差不多3000個用來測試。

import os
a=0
while(a<5):

    dir = '/home/zyx/data/pic/'+str(a)+'/'
    label = a

    files = os.listdir(dir)
    files.sort()
    train = open('/home/zyx/data/train.txt','a')
    val = open('/home/zyx/data/val.txt', 'a')
    i = 1
    for file in files:
        if i<29000:
            fileType = os.path.split(file)
            if fileType[1] == '.txt':
                continue
            name =  str(dir) +  file + ' ' + str(int(label)) +'\n'
            train.write(name)
            i = i+1
            print(i)
        else:
            fileType = os.path.split(file)
            if fileType[1] == '.txt':
                continue
            name = str(dir) +file + ' ' + str(int(label)) +'\n'
            val.write(name)
            i = i+1
            print(i)


    val.close()
    train.close()
    print(a)
    a = a + 1

結果

然後就可以開始寫網絡和訓練模型了
因爲我的圖片數據集裏有/home/zyx/data/pic/0/0_original_108475 (2).JPG_6c664301-0796-43f1-ba25-f19aa62537b4.JPG 0比較奇怪的命名,所以要把讀取數據的地方稍微做一些修改

class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.strip('\n')
            line = line.rstrip()
            words = line.split()
            if len(words)>2:
                words[0] = str((words[0]))+' '+str((words[1]))
                words[1] = words[2]
            print(len(words))

            imgs.append((words[0],int(words[1])))
            print((words[0],int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img,label

    def __len__(self):
        return len(self.imgs)

基本上再用這個沒啥問題可以直接用了

你可能感兴趣的:(PyTorch,Python)