将人脸分割成几部分,送入并行的网络结构。出现的问题是:
使用torchvision.datasets.ImageFolder加载数据集后,如果在torch.utils.data.DataLoader中开启shuffle,每个DataLoader会各自独立地洗牌,这几部分图像及其Label的顺序就无法一一对应,也就无法在并行网络中对各部分的图像特征进行融合。
以下是我的解决方案,直接给代码,有时间再来详细解释。
folder.py
# -*- coding: utf-8 -*-
import os
import torch
import torchvision
from PIL import Image
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import random
# File-name suffixes that is_image_file() accepts as images; they are compared
# against the lower-cased file name, so they must be written in lower case.
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
def make_dataset(dir, class_to_idx):
    """Collect every image below ``dir`` as ``(path, class_index)`` and shuffle the list.

    Args:
        dir (string): root directory; each immediate sub-directory is one class
        class_to_idx (dict): maps a class sub-directory name to its integer label

    Returns:
        list: ``(image_path, class_index)`` tuples in random order
    """
    root_dir = os.path.expanduser(dir)
    samples = []
    for class_name in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, class_name)
        if os.path.isdir(class_dir):
            for current_dir, _, file_names in sorted(os.walk(class_dir)):
                # The label lookup stays per image, so a class folder without
                # images never touches class_to_idx.
                samples.extend(
                    (os.path.join(current_dir, file_name), class_to_idx[class_name])
                    for file_name in sorted(file_names)
                    if is_image_file(file_name)
                )
    # Shuffle once here, in place of the shuffle option of DataLoader.
    random.shuffle(samples)
    return samples
def is_image_file(filename):
    """Tell whether a file name carries one of the known image extensions.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the lower-cased filename ends with an entry of IMG_EXTENSIONS
    """
    # str.endswith accepts a tuple of candidate suffixes and checks them all.
    return filename.lower().endswith(tuple(IMG_EXTENSIONS))
def find_classes(dir):
    """List the class sub-directories of ``dir`` and number them alphabetically.

    Args:
        dir (string): root directory; every immediate sub-directory is one class.
            A leading ``~`` is expanded, matching what make_dataset() does.

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted list of
        sub-directory names and ``class_to_idx`` maps each name to its position
        in that list. Plain files in ``dir`` are ignored.
    """
    dir = os.path.expanduser(dir)
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: index for index, name in enumerate(classes)}
    return classes, class_to_idx
def default_loader(path):
    """Load an image with the backend torchvision is currently configured for.

    Args:
        path (string): path of the image file

    Returns:
        The result of accimage_loader() when the backend is 'accimage',
        otherwise the result of pil_loader().
    """
    from torchvision import get_image_backend

    load = accimage_loader if get_image_backend() == 'accimage' else pil_loader
    return load(path)
def pil_loader(path):
    """Read the image at ``path`` with PIL and return it converted to RGB.

    Args:
        path (string): path of the image file

    Returns:
        The image produced by ``convert('RGB')``.
    """
    # Open the path as a file ourselves to avoid a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as stream, Image.open(stream) as image:
        rgb_image = image.convert('RGB')
    return rgb_image
def accimage_loader(path):
    """Read the image at ``path`` with accimage, falling back to PIL on IOError.

    Args:
        path (string): path of the image file

    Returns:
        An ``accimage.Image``, or the result of pil_loader() if accimage
        raised IOError.
    """
    import accimage

    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem: retry with the PIL-based loader.
        return pil_loader(path)
    else:
        return image
# ImageFolder variant whose sample list is supplied by the caller.
class MyImageFolder(datasets.ImageFolder):
    """ImageFolder that serves a caller-supplied list of samples.

    The parent constructor still scans ``root`` (which sets up the class
    attributes); afterwards the scanned sample list is replaced by ``imgs`` so
    that the caller controls both the order and the exact file paths.

    Args:
        imgs (list): ``(image_path, class_index)`` tuples to serve, in order
        root (string): dataset root directory handed to ImageFolder
        transform (callable, optional): applied to each loaded image
        target_transform (callable, optional): applied to each class index
        loader (callable, optional): function that loads an image from a path
    """

    def __init__(self, imgs, root, transform=None, target_transform=None, loader=default_loader):
        super(MyImageFolder, self).__init__(root, transform, target_transform, loader)
        self.imgs = imgs
        # torchvision releases built on DatasetFolder index ``self.samples`` in
        # __getitem__/__len__ and keep ``self.imgs`` only as an alias; without
        # the next two assignments the custom list would be silently ignored
        # there. On older releases the extra attributes are simply unused.
        self.samples = imgs
        self.targets = [label for _, label in imgs]
def imshow(inp, title=None):
    """Imshow for Tensor.

    Args:
        inp: tensor to display; axis 0 is moved to the end before plotting
            (assumes a channels-first image tensor -- TODO confirm)
        title (optional): passed to ``plt.title`` when not None
    """
    channels_last = inp.numpy().transpose((1, 2, 0))
    plt.imshow(channels_last)
    if title is not None:
        plt.title(title)
    # Keep the event loop running for a while so the figure is drawn and stays visible.
    plt.pause(10)
# The only preprocessing step is ToTensor; no resizing or normalisation.
data_tranform = transforms.Compose([transforms.ToTensor()])


def _demo():
    """Print dataset statistics for a local test folder and show one batch."""
    data_dir = r'E:\Experiment\fer2013\datasets\test'
    class_names = ['nature', 'anger', 'disgust', 'fear', 'happy', 'sadness', 'surprise']

    classes, class_to_idx = find_classes(data_dir)
    print(classes)
    print(class_to_idx)

    imgs = make_dataset(data_dir, class_to_idx)
    print(len(imgs))

    sets = MyImageFolder(imgs, data_dir, transform=data_tranform)
    print(len(sets))

    # shuffle stays off: make_dataset() has already randomised the order.
    loader = torch.utils.data.DataLoader(sets, batch_size=4, shuffle=False, num_workers=4)
    inputs, targets = next(iter(loader))
    grid = torchvision.utils.make_grid(inputs)
    titles = [class_names[x] for x in targets]
    imshow(grid, title=titles)


if __name__ == '__main__':
    _demo()
mydataloader.py
# -*- coding: utf-8 -*-
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, transforms
from folder import MyImageFolder, make_dataset, find_classes
# Mini-batch size shared by every DataLoader in this module.
batch_size = 4
# Root directory of the face-part database; it contains the folder_* directories.
ck_folders = r'E:\FERS\CK+\CK_FACEPARTS_DB1'
# Names of the five fold directories: folder_0 ... folder_4.
folders = ['folder_{}'.format(fold) for fold in range(5)]


def _part_dirs(part):
    """Return the relative ``<fold>/<part>`` directory of every fold."""
    return [os.path.join(fold_dir, part) for fold_dir in folders]


global_sets = _part_dirs('global')
lefteye_sets = _part_dirs('lefteye')
righteye_sets = _part_dirs('righteye')
mouth_sets = _part_dirs('mouth')
def _grayscale_tensor(size):
    """Build the shared pipeline: grayscale -> resize to ``size`` -> tensor."""
    return transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize(size),
        transforms.ToTensor(),
    ])


# Preprocessing per face part; only the resize target differs
# (the mouth crop uses (32, 64), the others (64, 64)).
data_transform = {
    'global': _grayscale_tensor((64, 64)),
    'eye': _grayscale_tensor((64, 64)),
    'mouth': _grayscale_tensor((32, 64)),
}
# Class names and indices come from the first fold's global-face folder and
# are reused for every fold.
classes, class_to_idx = find_classes(os.path.join(ck_folders, global_sets[0]))

# One shuffled (path, label) list per fold, built from the global-face images.
g_imgs = [make_dataset(os.path.join(ck_folders, fold_dir), class_to_idx)
          for fold_dir in global_sets]


def _retarget(samples, part):
    """Copy ``samples`` with 'global' in each path replaced by ``part``.

    Order and labels are kept, so entry k of the result pairs with entry k of
    ``samples``.
    NOTE(review): str.replace rewrites every occurrence of 'global' in the
    path, not only the directory component -- confirm no other path part
    contains it unintentionally.
    """
    return [(path.replace('global', part), label) for path, label in samples]


lefteye_imgs = [_retarget(fold_samples, 'lefteye') for fold_samples in g_imgs]
righteye_imgs = [_retarget(fold_samples, 'righteye') for fold_samples in g_imgs]
# The misspelled name 'moutn_imgs' is kept because other code refers to it.
moutn_imgs = [_retarget(fold_samples, 'mouth') for fold_samples in g_imgs]
# global face: one dataset per fold, served in the order stored in g_imgs
ck_global_datasets = [
    MyImageFolder(fold_samples, os.path.join(ck_folders, fold_dir),
                  transform=data_transform['global'])
    for fold_samples, fold_dir in zip(g_imgs, global_sets)
]
# ck_global_dataset_k concatenates every fold except fold k (fold k is the test set).
(ck_global_dataset_0,
 ck_global_dataset_1,
 ck_global_dataset_2,
 ck_global_dataset_3,
 ck_global_dataset_4) = [
    torch.utils.data.ConcatDataset(
        [fold for index, fold in enumerate(ck_global_datasets) if index != held_out])
    for held_out in range(5)
]
# Loaders are built for the split that holds out fold 0. shuffle=False keeps
# the order fixed by the sample lists.
ck_global_trainloader = torch.utils.data.DataLoader(
    ck_global_dataset_0, batch_size=batch_size, shuffle=False, num_workers=4)
ck_global_testloader = torch.utils.data.DataLoader(
    ck_global_datasets[0], batch_size=batch_size, shuffle=False, num_workers=4)
# left eye: one dataset per fold, served in the order stored in lefteye_imgs
ck_lefteye_datasets = [
    MyImageFolder(fold_samples, os.path.join(ck_folders, fold_dir),
                  transform=data_transform['eye'])
    for fold_samples, fold_dir in zip(lefteye_imgs, lefteye_sets)
]
# ck_lefteye_dataset_k concatenates every fold except fold k (fold k is the test set).
(ck_lefteye_dataset_0,
 ck_lefteye_dataset_1,
 ck_lefteye_dataset_2,
 ck_lefteye_dataset_3,
 ck_lefteye_dataset_4) = [
    torch.utils.data.ConcatDataset(
        [fold for index, fold in enumerate(ck_lefteye_datasets) if index != held_out])
    for held_out in range(5)
]
# Loaders are built for the split that holds out fold 0. shuffle=False keeps
# the order fixed by the sample lists.
ck_lefteye_trainloader = torch.utils.data.DataLoader(
    ck_lefteye_dataset_0, batch_size=batch_size, shuffle=False, num_workers=4)
ck_lefteye_testloader = torch.utils.data.DataLoader(
    ck_lefteye_datasets[0], batch_size=batch_size, shuffle=False, num_workers=4)
# right eye: one dataset per fold, served in the order stored in righteye_imgs
ck_righteye_datasets = [
    MyImageFolder(fold_samples, os.path.join(ck_folders, fold_dir),
                  transform=data_transform['eye'])
    for fold_samples, fold_dir in zip(righteye_imgs, righteye_sets)
]
# ck_righteye_dataset_k concatenates every fold except fold k (fold k is the test set).
(ck_righteye_dataset_0,
 ck_righteye_dataset_1,
 ck_righteye_dataset_2,
 ck_righteye_dataset_3,
 ck_righteye_dataset_4) = [
    torch.utils.data.ConcatDataset(
        [fold for index, fold in enumerate(ck_righteye_datasets) if index != held_out])
    for held_out in range(5)
]
# Loaders are built for the split that holds out fold 0. shuffle=False keeps
# the order fixed by the sample lists.
ck_righteye_trainloader = torch.utils.data.DataLoader(
    ck_righteye_dataset_0, batch_size=batch_size, shuffle=False, num_workers=4)
ck_righteye_testloader = torch.utils.data.DataLoader(
    ck_righteye_datasets[0], batch_size=batch_size, shuffle=False, num_workers=4)
# mouth: one dataset per fold, served in the order stored in moutn_imgs
ck_mouth_datasets = [
    MyImageFolder(fold_samples, os.path.join(ck_folders, fold_dir),
                  transform=data_transform['mouth'])
    for fold_samples, fold_dir in zip(moutn_imgs, mouth_sets)
]
# ck_mouth_dataset_k concatenates every fold except fold k (fold k is the test set).
(ck_mouth_dataset_0,
 ck_mouth_dataset_1,
 ck_mouth_dataset_2,
 ck_mouth_dataset_3,
 ck_mouth_dataset_4) = [
    torch.utils.data.ConcatDataset(
        [fold for index, fold in enumerate(ck_mouth_datasets) if index != held_out])
    for held_out in range(5)
]
# Loaders are built for the split that holds out fold 0. shuffle=False keeps
# the order fixed by the sample lists.
ck_mouth_trainloader = torch.utils.data.DataLoader(
    ck_mouth_dataset_0, batch_size=batch_size, shuffle=False, num_workers=4)
ck_mouth_testloader = torch.utils.data.DataLoader(
    ck_mouth_datasets[0], batch_size=batch_size, shuffle=False, num_workers=4)
def imshow(inp, title=None):
    """Imshow for Tensor.

    Args:
        inp: tensor to display; axis 0 is moved to the end before plotting
            (assumes a channels-first image tensor -- TODO confirm)
        title (optional): passed to ``plt.title`` when not None
    """
    # No mean/std de-normalisation is applied; values are only clipped to [0, 1].
    channels_last = inp.numpy().transpose((1, 2, 0))
    clipped = np.clip(channels_last, 0, 1)
    plt.imshow(clipped)
    if title is not None:
        plt.title(title)
    # Keep the event loop running briefly so the figure is actually drawn.
    plt.pause(2)
if __name__ == '__main__':
    # Label legend of the full 8-class set:
    # (0=neutral, 1=anger, 2=contempt, 3=disgust, 4=fear, 5=happy, 6=sadness, 7=surprise)
    # The list used here is the 7-class variant without 'contempt'.
    class_names = ['nature', 'anger', 'disgust', 'fear', 'happy', 'sadness', 'surprise']

    # (training set, held-out test fold, train loader, test loader) per face part,
    # in the order: global face, left eye, right eye, mouth.
    previews = (
        (ck_global_dataset_0, ck_global_datasets[0],
         ck_global_trainloader, ck_global_testloader),
        (ck_lefteye_dataset_0, ck_lefteye_datasets[0],
         ck_lefteye_trainloader, ck_lefteye_testloader),
        (ck_righteye_dataset_0, ck_righteye_datasets[0],
         ck_righteye_trainloader, ck_righteye_testloader),
        (ck_mouth_dataset_0, ck_mouth_datasets[0],
         ck_mouth_trainloader, ck_mouth_testloader),
    )
    for train_set, test_set, train_loader, test_loader in previews:
        print('train size: {}, train batch nums:{}'.format(len(train_set), len(train_loader)))
        print('test size: {}, test batch nums:{}'.format(len(test_set), len(test_loader)))
        # Show the first training batch of this part with its labels as the title.
        inputs, classes = next(iter(train_loader))
        out = torchvision.utils.make_grid(inputs)
        imshow(out, title=[class_names[x] for x in classes])