对于图片数据加载,使用ImageFolder读取训练集和测试集的文件夹,其中训练集文件夹下分别有多个子文件夹,文件夹的名字即为该类型图片的标签(如图)
获得图片数据train_dataset后,放入DataLoader迭代器,其中有三个自定义参数:batch_size 、 shuffle 、 num_workers,分别表示批次大小、数据是否打乱和使用的线程数。
from torchvision import transforms, datasets
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
transform=data_transform["train"])
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size, shuffle=True,
num_workers=nw)
validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
batch_size=batch_size, shuffle=False,
num_workers=nw)
from torch.utils.data import Dataset
Dataset有三个主要的方法:__init__ (self ,)、__getitem__(self , idx) 、__len__(self)
1、通过__len__(self):函数获得数据集的大小,从而在放入DataLoader后能根据batch_size划分数据;
2、__getitem__(self , idx):其中idx为每个batch_size的下标,从而获得每条数据;
目前有关于电影信息用csv格式保存和电影海报图片,图片的文件名为对应电影的id,如图:
若想同时读取以上文本数据和图片数据,就得复写Dataset类:
class MovieDataset(Dataset):
# data_dir:csv表的路径 ;root_dir:图片的路径
def __init__(self, data_dir, root_dir , transform=None):
# self.data = pd.read_csv(csv_file) # csv总文件读取
# 数据的位置
self.features, self.targets_values, self.ratings = pickle.load(open(data_dir, mode='rb'))
self.uid, self.user_gender, self.user_age, self.user_job = self.features.take(0, 1), self.features.take(2, 1), self.features.take(3,1), self.features.take(4, 1)
self.movie_id, self.movie_categories, self.movie_titles, self.intro ,self.targets = self.features.take(1,1) , self.features.take(6,1), self.features.take(5,1) , self.features.take(7,1) , self.targets_values
self.root_dir = root_dir # 图片路径
self.transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 图像规范化
def __len__(self):
return len(self.movie_id)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir, str(self.movie_id[idx])+'.jpg') # 根据电影id获得对应图片位置
image = Image.open(img_name)
image = image.convert('RGB')
if self.transform:
image = self.transform(image)
user_data = self.get_user_data() # 获得用户数据
movie_data = self.get_movie_data() # 获得电影数据
# 使用列表存储输入数据x和真实结果y
input=[]
input.append(image)
# (user_id , user_gender , user_age , user_job)
input.append(user_data[idx])
# (movie_id , categories , title , intro , intro_lengths)
input.append(movie_data[idx])
# (rating)
input.append(self.targets[idx])
return input
这里明显可以看出__getitem__(self , idx)中的idx为csv表中的索引,可以根据索引将不同类型的数据放入DataLoader里。
for step, X in enumerate(train_bar):
batch_x,batch_y = X[:-1] , X[-1] #获得输入特征和真实结果
复写Dataset可以根据自己的需求将不同的数据放入DataLoader里,复写的类中,需同时有__init__ (self ,)、__getitem__(self , idx) 、__len__(self)三个方法。