网上从文件中读取样本和标签的资料很多,但大多讲的不全面,或只讲原理,或只有变为.tfrecords部分,或没有调用的栗子。寄几and男票一起捣鼓了两天,终于有了目前这个完整版的代码,希望对看到的朋友有所帮助。
样本图示如图1,标签文件train_y.csv如图2,这是个2分类问题。
图1
图2
我们的图片存储路径如图3红框所示,标签文件train_y.csv存储路径如图3绿框所示。
我们用ray14_train.py进行train,这个.py文件和train_y.csv不在同一目录下。所以,在标签文件train_y.csv中,我们需要将图片名称这一列变为相对路径,如图4所示,这个新csv我们存为y_train.csv,测试集也这么处理。
图3
图4
import numpy as np
import pandas as pd
import cv2
import csv
from os import path as osp
import os
base_path = os.path.join('images','images224')
train_y_path = os.path.join(base_path,'train_y.csv')
train_y = np.loadtxt(train_y_path, delimiter=",", skiprows=0, usecols=(0,1), dtype=str)
train_y_pd = pd.DataFrame(train_y)
for i in range(train_y.shape[0]):
train_y_pd.iloc[i,0] = os.path.join(base_path,train_y[i,0])
train_y_pd.to_csv(os.path.join(base_path, 'y_train.csv'),header=None,index=None)
def load_file(example_list_file):
lines = np.genfromtxt(example_list_file,delimiter=",",dtype=[('col1', 'S120'), ('col2', 'i8')])
examples = []
labels = []
for example,label in lines:
examples.append(example)
labels.append(label)
#convert to numpy array
return np.asarray(examples),np.asarray(labels),len(lines)
def extract_image(filename,height,width):
# print(filename)
image = cv2.imread(filename)
# image = cv2.resize(image,(height,width))
b,g,r = cv2.split(image)
rgb_image = cv2.merge([r,g,b])
return rgb_image
def trans2tfRecord(train_file,name,output_dir,height,width):
if not os.path.exists(output_dir) or os.path.isfile(output_dir):
os.makedirs(output_dir)
_examples,_labels,examples_num = load_file(train_file)
filename = name + '.tfrecords'
writer = tf.python_io.TFRecordWriter(filename)
for i,[example,label] in enumerate(zip(_examples,_labels)):
# print("NO{}".format(i))
#need to convert the example(bytes) to utf-8
example = example.decode("UTF-8")
image = extract_image(example,height,width)
image_raw = image.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'image_raw':_bytes_feature(image_raw),
'height':_int64_feature(image.shape[0]),
'width': _int64_feature(32),
'depth': _int64_feature(32),
'label': _int64_feature(label)
}))
writer.write(example.SerializeToString())
writer.close()
return filename
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def read_tfRecord(file_tfRecord,shuffle=False):
# 这个函数需要传入一个文件名,系统会自动将它转为一个文件名队列,这个队列存的是训练或测试过程用到的数据
# tf.train.string_input_producer有两个重要的参数,一个是num_epochs,这个设成默认none就行,none表示无限次
# 它表示将全部样本入队次数,一般程序迭代几次就入队几次。程序运行开始,数据就开始出队,为了保证队列一直不空,
# 我们设为none,使全部样本入队无数次(无限循环)。
# 另外一个就是shuffle,shuffle是指在一个epoch内文件的顺序是否被打乱(但是我测试时发现无论是True还是False,其实都打乱了)。
queue = tf.train.string_input_producer([file_tfRecord], shuffle=shuffle)
reader = tf.TFRecordReader()
_,serialized_example = reader.read(queue)
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'height': tf.FixedLenFeature([], tf.int64),
'width':tf.FixedLenFeature([], tf.int64),
'depth': tf.FixedLenFeature([], tf.int64),
'label': tf.FixedLenFeature([], tf.int64)
}
)
image = tf.decode_raw(features['image_raw'],tf.uint8)
#height = tf.cast(features['height'], tf.int64)
#width = tf.cast(features['width'], tf.int64)
image = tf.reshape(image,[224,224,3])
image = tf.cast(image, tf.float32)
image = tf.image.per_image_standardization(image)
label = tf.cast(features['label'], tf.int64)
print(image,label)
return image,label
with tf.Session() as sess:
# 训练过程
base_path = os.path.join('images','images224')
data_train_path = os.path.join(base_path,'y_train.csv')
data_test_path = os.path.join(base_path,'y_test.csv')
# 首次执行程序需要运行一旦生成之后就可以注释掉了:利用csv生成y_train.tfrecords和y_test.tfrecords文件,这俩文件是训练集和测试集的样本与标签,
filename = trans2tfRecord(data_train_path, 'y_train', base_path, 224, 224)
filename2 = trans2tfRecord(data_train_path, 'y_test', base_path, 224, 224)
img_batch, path_batch = read_tfRecord(filename, shuffle=True)
img_batch2, path_batch2 = read_tfRecord(filename2, shuffle=False)
image_batches, label_batches = tf.train.batch([img_batch, path_batch], batch_size=batch, capacity=4096)
image_batches2, label_batches2 = tf.train.batch([img_batch2, path_batch2], batch_size=batch, capacity=4096)
tf.local_variables_initializer().run()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
# 定义一个模型
model=ATDA(sess=sess)
model.create_model()
# 训练模型:(image_batches,label_batches)是训练集,(image_batches2,label_batches2)是测试集,
model.fit_ATDA(source_train=image_batches, y_train=label_batches,
target_val=image_batches2, y_val=label_batches2,
# n是训练集总数,my_number是测试集总数,my_catelogy是标签种类,batch是迭代次数
nb_epoch=epochs, n = 86524, my_number = 25596, my_catelogy = 2,batch = 16)
coord.request_stop() # 请求线程结束
coord.join() # 等待线程结束
8.model.fit_ATDA(),这部分是训练模型。
def fit_ATDA(source_train, y_train, target_val, y_val, nb_epoch=30,
n = 86524, my_number = 25596, my_catelogy = 2, batch = 4):
for e in range(nb_epoch):
n_batch = 0
for my_batch_train in range(int(n/batch)):
Xu_batch, Yu_batch = self.sess.run([source_train, y_train])
Xu_batch = transform_batch_images(Xu_batch)
Yu_batch = np_utils.to_categorical(Yu_batch, 2)
# print('train label',Yu_batch)
feed_dict = { self.x: Xu_batch, self.y_: Yu_batch ,self.istrain:True}
cost, Ft_loss = self.sess.run([cost, Ft_loss], feed_dict=feed_dict)
n_batch += 1
#every 1000 minibatch print loss
if n_batch % 1000==0:
print("Epoch %d total_loss %f Ft_loss %f" % (e + 1, cost,Ft_loss))
其中,从文件读取部分代码是:
Xu_batch, Yu_batch = self.sess.run([source_train, y_train])
9.测试的代码就不写了,类似8。
参考资料:
1.https://zhuanlan.zhihu.com/p/27238630
2.https://www.cnblogs.com/wktwj/p/7257526.html