Reading local images and class labels during training
tf.keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False, samplewise_center=False,
    featurewise_std_normalization=False, samplewise_std_normalization=False,
    zca_whitening=False, zca_epsilon=1e-06, rotation_range=0, width_shift_range=0.0,
    height_shift_range=0.0, brightness_range=None, shear_range=0.0, zoom_range=0.0,
    channel_shift_range=0.0, fill_mode='nearest', cval=0.0, horizontal_flip=False,
    vertical_flip=False, rescale=None, preprocessing_function=None,
    data_format=None, validation_split=0.0, dtype=None
)
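Most of these arguments can stay at their defaults; the ones used most often are rescale (pixel normalization), the geometric augmentation ranges, and validation_split. A minimal sketch, assuming you only want normalization plus a held-out validation subset (path is a placeholder):

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Scale pixels to [0, 1] and reserve 20% of the images for validation.
simple_gen = ImageDataGenerator(rescale=1. / 255, validation_split=0.2)
# The same generator serves both subsets through the `subset` argument:
# train_it = simple_gen.flow_from_directory(path, subset='training')
# val_it = simple_gen.flow_from_directory(path, subset='validation')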
train_generator = ImageDataGenerator()  # default generator, reused below with flow_from_directory
# Example: augmenting CIFAR-10 on the fly; featurewise statistics require datagen.fit
from tensorflow.keras.datasets import cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)
# Compute the dataset-wide mean/std needed by featurewise_center / featurewise_std_normalization
datagen.fit(x_train)
# model is assumed to be a compiled Keras model and epochs an integer
for e in range(epochs):
    print('Epoch', e)
    batches = 0
    for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
        model.fit(x_batch, y_batch)
        batches += 1
        if batches >= len(x_train) / 32:
            # flow() loops forever, so break out of the epoch manually
            break
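In recent versions of tf.keras, model.fit also accepts the generator directly, which avoids the manual epoch loop (a sketch; model and epochs are assumed to be defined as above):

model.fit(datagen.flow(x_train, y_train, batch_size=32),
          steps_per_epoch=len(x_train) // 32,
          epochs=epochs)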
To read images organized in class subdirectories, use train_generator.flow_from_directory (a complete sketch follows the directory layout below):
train_iterator = train_generator.flow_from_directory(
    directory=path,        # root directory to read from
    target_size=(h, w),    # every image is resized to this shape
    batch_size=size,       # number of images per batch
    class_mode='binary',   # label format: one of "categorical", "binary", "sparse"
    shuffle=True)
The directory is expected to contain one subdirectory per class:
data/
    train/
        dogs/
            dog001.jpg
            dog002.jpg
            ...
        cats/
            cat001.jpg
            cat002.jpg
            ...
    validation/
        dogs/
            dog001.jpg
            dog002.jpg
            ...
        cats/
            cat001.jpg
            cat002.jpg
            ...
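Putting the pieces together for the layout above, a minimal end-to-end sketch (the directory names, image size, and the compiled model are illustrative assumptions):

from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(rescale=1. / 255)
val_datagen = ImageDataGenerator(rescale=1. / 255)

train_it = train_datagen.flow_from_directory(
    'data/train',            # class names (cats, dogs) are inferred from the subdirectory names
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')
val_it = val_datagen.flow_from_directory(
    'data/validation',
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')

# model is assumed to be a compiled binary classifier taking (150, 150, 3) inputs.
model.fit(train_it,
          steps_per_epoch=train_it.samples // 32,
          validation_data=val_it,
          validation_steps=val_it.samples // 32,
          epochs=10)

An alternative to flow_from_directory, shown next, is to read image paths and labels from txt files and serve them through a custom keras.utils.Sequence.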
def data_from_sequence(train_data_dir, batch_size, num_classes, input_size):
    """Read local images and label files and wrap them as Sequence objects.
    :return: (train_sequence, validation_sequence)
    """
    # 1. Collect the .txt label files and shuffle them once
    label_files = [os.path.join(train_data_dir, filename) for filename in os.listdir(train_data_dir)
                   if filename.endswith('.txt')]
    print(label_files)
    random.shuffle(label_files)

    # 2. Parse each txt file into an image name and an integer label
    img_paths = []
    labels = []
    for index, file_path in enumerate(label_files):
        with open(file_path, 'r') as f:
            line = f.readline()
        line_split = line.strip().split(', ')
        if len(line_split) != 2:
            print('%s is not in the expected "name, label" format' % file_path)
            continue
        # Image file name and label, converted to the right types
        img_name = line_split[0]
        label = int(line_split[1])
        # Build the full image path; paths and labels stay index-aligned
        img_paths.append(os.path.join(train_data_dir, img_name))
        labels.append(label)

    # 3. One-hot encode the labels and apply label smoothing
    labels = to_categorical(labels, num_classes)
    labels = smooth_labels(labels)

    # 4. Split everything into training and validation sets
    train_img_paths, validation_img_paths, train_labels, validation_labels = \
        train_test_split(img_paths, labels, test_size=0.15, random_state=0)
    print('total samples: %d, training samples: %d, validation samples: %d' % (
        len(img_paths), len(train_img_paths), len(validation_img_paths)))

    # 5. Build the Sequence objects (augmentation only on the training split)
    train_sequence = GarbageDataSequence(train_img_paths, train_labels, batch_size,
                                         [input_size, input_size], use_aug=True)
    validation_sequence = GarbageDataSequence(validation_img_paths, validation_labels, batch_size,
                                              [input_size, input_size], use_aug=False)
    return train_sequence, validation_sequence
Imports used by the Sequence-based pipeline above:
import math
import os
import random
import numpy as np
from PIL import Image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical, Sequence
from sklearn.model_selection import train_test_split
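The function above relies on smooth_labels and GarbageDataSequence, which are not shown in the source. Using the imports above, here is an illustrative sketch of what they could look like, assuming standard label smoothing and a keras Sequence that loads and resizes images batch by batch (this is an assumption, not the original implementation):

def smooth_labels(labels, factor=0.1):
    # Standard label smoothing: pull one-hot targets slightly toward the uniform distribution.
    labels = labels * (1 - factor)
    labels = labels + factor / labels.shape[1]
    return labels

class GarbageDataSequence(Sequence):
    """Batch-wise loader: reads images from disk, resizes them, optionally augments."""
    def __init__(self, img_paths, labels, batch_size, img_size, use_aug=False):
        self.img_paths = img_paths
        self.labels = np.array(labels)
        self.batch_size = batch_size
        self.img_size = img_size      # [height, width]
        self.use_aug = use_aug
        self.augmenter = ImageDataGenerator(rotation_range=20, horizontal_flip=True)

    def __len__(self):
        return math.ceil(len(self.img_paths) / self.batch_size)

    def __getitem__(self, idx):
        batch_paths = self.img_paths[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_labels = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]
        imgs = []
        for p in batch_paths:
            img = Image.open(p).convert('RGB').resize((self.img_size[1], self.img_size[0]))
            img = np.asarray(img, dtype=np.float32) / 255.0
            if self.use_aug:
                img = self.augmenter.random_transform(img)
            imgs.append(img)
        return np.stack(imgs), batch_labels

The sequences returned by data_from_sequence can then be passed straight to model.fit (train_data_dir, num_classes, and the compiled model are placeholders):

train_sequence, validation_sequence = data_from_sequence(train_data_dir, 32, num_classes, 224)
model.fit(train_sequence, validation_data=validation_sequence, epochs=10)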
The same idea in PyTorch: a custom Dataset that reads image paths and labels from a list file.
import os
import torch
import torch.utils.data as data
from datasets.data_aug import *
import cv2 as cv
def get_train_path(list_path, file_path):
    """Parse a list file where each line is '<relative/image/path> <label>' with a single-character label."""
    image = []
    label = []
    with open(list_path, "r") as lines:
        for line in lines:
            # line[:-3] strips the trailing ' <label>\n'; line[-2:-1] is the label character
            img_path = os.path.join(file_path, line[:-3])
            image.append(img_path)
            label.append(line[-2:-1])
    return image, label
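For this slicing to work, every line in the list file must end with a space, a single label character, and a newline. A quick illustration with a made-up line:

line = 'faces/img_0001.jpg 3\n'
print(line[:-3])    # 'faces/img_0001.jpg'
print(line[-2:-1])  # '3'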
# Dataset
class expression_dataset(data.Dataset):
    # custom constructor arguments
    def __init__(self, image, label, transforms=None, debug=False, test=False):
        self.paths = image
        self.labels = label
        self.transforms = transforms
        self.debug = debug
        self.test = test

    # number of images
    def __len__(self):
        return len(self.paths)

    # fetch a single sample
    def __getitem__(self, item):
        # path
        img_path = self.paths[item]
        # read image (OpenCV loads BGR)
        img = cv.imread(img_path)
        # convert to RGB
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
        # augmentation (the transform is expected to take and return a numpy array)
        if self.transforms is not None:
            img = self.transforms(img)
        # read label
        label = self.labels[item]
        # return a float tensor built from the (possibly transformed) array, plus the label as int
        return torch.from_numpy(img).float(), int(label)
# image, label = get_train_path(list_path='/home/aries/Downloads/datasets/expression/lists/traine.txt',
#                               file_path='/home/aries/Downloads/datasets/expression/')
#
# train_data = expression_dataset(image, label)
#
# train_dataset = data.DataLoader(train_data, batch_size=2, shuffle=None, num_workers=4)
#
# for i, (img, lbl) in enumerate(train_dataset):
#     print(img.shape)
#     print(lbl)
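Because __getitem__ hands the raw RGB array to self.transforms and then calls torch.from_numpy on the result, any transform must take and return a numpy array. A minimal sketch of such a transform (the target size and the simple /255 normalization are assumptions):

import cv2 as cv
import numpy as np

def basic_transform(img, size=(224, 224)):
    # Resize to the network input size.
    img = cv.resize(img, size)
    # Scale to [0, 1] and move channels first (C x H x W), the layout most PyTorch models expect.
    img = img.astype(np.float32) / 255.0
    img = np.transpose(img, (2, 0, 1))
    return img

# train_data = expression_dataset(image, label, transforms=basic_transform)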