數據集爲5個不同類別的圖片集,每個圖片集大概有3W張圖片。所以要建立一個train訓練的txt文件和一個val驗證的txt文件,裏面放圖片的路徑,因爲只是練手用,所以不放test驗證。
最終要的結果是從每個文件裏拿出28000個訓練和剩下差不多3000個用來測試。
import os
a=0
while(a<5):
dir = '/home/zyx/data/pic/'+str(a)+'/'
label = a
files = os.listdir(dir)
files.sort()
train = open('/home/zyx/data/train.txt','a')
val = open('/home/zyx/data/val.txt', 'a')
i = 1
for file in files:
if i<29000:
fileType = os.path.split(file)
if fileType[1] == '.txt':
continue
name = str(dir) + file + ' ' + str(int(label)) +'\n'
train.write(name)
i = i+1
print(i)
else:
fileType = os.path.split(file)
if fileType[1] == '.txt':
continue
name = str(dir) +file + ' ' + str(int(label)) +'\n'
val.write(name)
i = i+1
print(i)
val.close()
train.close()
print(a)
a = a + 1
結果
然後就可以開始寫網絡和訓練模型了
因爲我的圖片數據集裏有/home/zyx/data/pic/0/0_original_108475 (2).JPG_6c664301-0796-43f1-ba25-f19aa62537b4.JPG 0比較奇怪的命名,所以要把讀取數據的地方稍微做一些修改
class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
if len(words)>2:
words[0] = str((words[0]))+' '+str((words[1]))
words[1] = words[2]
print(len(words))
imgs.append((words[0],int(words[1])))
print((words[0],int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
return img,label
def __len__(self):
return len(self.imgs)
基本上再用這個沒啥問題可以直接用了