Environment: PyCharm
Frameworks: TensorFlow, Keras
To begin with, the Keras website provides an example of using MNIST data in TFRecord format with Keras (https://keras.io/examples/mnist_tfrecord/). However, that example never shows how to actually read a TFRecord file; it glosses over the step with a single read_data_sets() call. In fact, if your data already consists of fairly small images, there is no need to write it into TFRecords at all, and you can follow this author's approach instead:
https://baijiahao.baidu.com/s?id=1628460932421002169&wfr=spider&for=pc
My data are FA files generated by preprocessing diffusion MRI scans: three-dimensional brain images stored in .nii.gz format with a size of 145 * 170 * 145 * 1. With the approach above, just reading the data would take a huge amount of time, so I modified it slightly.
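For reference, a single FA volume can be loaded and downsampled like this (a minimal sketch; the file path and six-digit subject ID are illustrative, while load_nifti and the scipy zoom call are the same ones used in the write-out code below):

from dipy.io.image import load_nifti
import scipy.ndimage as nd

# load_nifti returns the voxel data and the affine matrix.
data, affine = load_nifti('/media/MyFA/100001_dti_FA.nii.gz')  # illustrative path
print(data.shape)  # a 3D FA volume
# Downsample every dimension by 0.4 to make training tractable.
small = nd.interpolation.zoom(data, 0.4, prefilter=False)
print(small.shape)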
The next two code blocks handle writing and reading TFRecords; if your data is of a different kind, you can skip them.
The following code reads the FA data in Python and writes it into TFRecords.
import os
import re
import csv
import time
import random
from glob import glob

import scipy.ndimage as nd
import tensorflow as tf
from dipy.io.image import load_nifti

savePath = '/home/wenjingxi/MRI/tfrecord'
dataPathList = glob('/media//MyFA/*dti_FA.nii.gz')
# Downscaling factor; you can skip downscaling, but the raw volumes are too large to train on easily.
zoom_rate = 0.4
if not os.path.isdir(savePath):
    os.makedirs(savePath)

# Read the label table.
def read_csv(filePath):
    csv_file = csv.reader(open(filePath))
    l = {}
    for r in csv_file:
        l[r[0]] = list(map(int, r[6:12]))
    return l

t1 = time.time()
print('zoom rate: {}'.format(zoom_rate))
random.shuffle(dataPathList)
dataPathList_1 = dataPathList[0:len(dataPathList)]
labels = read_csv('label.csv')
for i in range(len(dataPathList_1)):
    # One TFRecord file per subject.
    savePath_t = os.path.join(savePath, 'dataset_{}.tfrecord'.format(i))
    writer = tf.python_io.TFRecordWriter(savePath_t)
    p_fa = dataPathList_1[i]
    data_fa, affine_fa = load_nifti(p_fa)
    data = data_fa
    print('data shape:{}'.format(data.shape))
    data = nd.interpolation.zoom(data, zoom_rate, prefilter=False)
    print("data shape:{}".format(data.shape))
    # The six-digit subject ID in the file name is the key into the label table.
    m = re.search('[0-9]{6}', p_fa)
    seq = m.group()
    print('seq:' + seq)
    print(labels[seq])
    # The parse function below decodes these bytes as float32, so data should be float32 here.
    img_raw = data.tobytes()
    print('cut last 4 num:', labels[seq])
    example = tf.train.Example(features=tf.train.Features(feature={
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=labels[seq])),
        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
    }))
    writer.write(example.SerializeToString())
    print('No {} finish time cost:{} min'.format(i, (time.time() - t1) // 60))
    writer.close()
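To confirm the records were written correctly, one file can be read back outside of any graph (a minimal check; tf.python_io.tf_record_iterator is the TF 1.x record reader, and the path below is illustrative):

import tensorflow as tf

path = '/home/wenjingxi/MRI/tfrecord/dataset_0.tfrecord'  # illustrative
for record in tf.python_io.tf_record_iterator(path):
    example = tf.train.Example.FromString(record)
    label = list(example.features.feature['label'].int64_list.value)
    img_bytes = example.features.feature['img_raw'].bytes_list.value[0]
    print('label:', label)
    print('image bytes:', len(img_bytes))
    break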
The following is the code for reading the TFRecords back; for other kinds of data, a simpler approach may work just as well.
def _parse_function_60(example_proto):
    # The label was written as an Int64List, so the feature needs its length here;
    # two values per record, matching the reshape to [1, 2] below.
    features = {"label": tf.FixedLenFeature([2], tf.int64),
                "img_raw": tf.FixedLenFeature([], tf.string)}
    parsed_features = tf.parse_single_example(example_proto, features)
    # Decode the raw bytes back into a float32 volume and restore its shape.
    img = tf.decode_raw(parsed_features['img_raw'], tf.float32)
    img = tf.reshape(img, [58, 70, 58, 1])
    img = tf.cast(img, tf.float32)
    print('img shape: {}'.format(img.get_shape()))
    label = tf.cast(parsed_features['label'], tf.int64)
    label = tf.reshape(label, [1, 2])
    return img, label
def load_data(sess, filename, batch_size, zoom_rate, shuffle_buffer=None):
    dataset = tf.data.TFRecordDataset(filename)
    # Pick the parse function that matches the resolution the records were written at.
    if zoom_rate == 60:
        _parse_function = _parse_function_60
    dataset = dataset.map(_parse_function)
    if shuffle_buffer:
        dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    next_batch = iterator.get_next()
    sess.run(iterator.initializer)
    return next_batch
def load_data_with_val(sess, batch_size, zoom_rate, shuffle_buffer=None, cross=0, brain_area=None, modal='flirt'):
    data_dir = '/media/tfrecords/3_flirt/0.4zoom_rate'
    file_list = []
    for i in range(3):
        p = os.path.join(data_dir, 'dataset_{}.tfrecord'.format(i))
        file_list.append(p)
    print(file_list)
    # Hold out one file for validation and one for testing; train on the rest.
    val_file = file_list[1]
    test_file = file_list[0]
    train_files = [file for file in file_list if file != val_file and file != test_file]
    next_batch_t = load_data(sess, filename=train_files, batch_size=batch_size, zoom_rate=zoom_rate, shuffle_buffer=shuffle_buffer)
    next_batch_v = load_data(sess, filename=val_file, batch_size=batch_size, zoom_rate=zoom_rate, shuffle_buffer=shuffle_buffer)
    next_batch_test = load_data(sess, filename=test_file, batch_size=batch_size, zoom_rate=zoom_rate, shuffle_buffer=shuffle_buffer)
    return next_batch_t, next_batch_v, next_batch_test
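A quick sanity check of what these iterators return (a sketch, assuming the three TFRecord files exist under data_dir; zoom_rate=60 only selects _parse_function_60, and the batch size is illustrative):

import tensorflow as tf

with tf.Session() as sess:
    next_batch_t, next_batch_v, next_batch_test = load_data_with_val(sess, batch_size=4, zoom_rate=60)
    imgs, labs = sess.run(next_batch_t)
    print(imgs.shape)  # expected (4, 58, 70, 58, 1)
    print(labs.shape)  # expected (4, 1, 2)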
The sess used from here on is the same sess used to read the TFRecords; the only difference is that, to fetch batches, the dataset tensors are handed to My_Custom_Generator and run there.
For model training, pass the two tensors training_set and val_set to the train() function; fit_generator() then pulls data from the My_Custom_Generator() subclass. Remember to pass the session along.
def train(sess, training_set, val_set):
    # Excerpted from a training class: self.batch_size and model are defined on the enclosing class.
    my_training_batch_generator = My_Custom_Generator(sess, training_set, self.batch_size, 855)
    my_validation_batch_generator = My_Custom_Generator(sess, val_set, self.batch_size, 100)
    model.fit_generator(generator=my_training_batch_generator,
                        steps_per_epoch=15,
                        epochs=10,
                        verbose=1,
                        validation_data=my_validation_batch_generator,
                        validation_steps=3)
My_Custom_Generator() subclasses keras.utils.Sequence and uses sess.run() to read out the training_set / val_set tensors defined above.
import numpy as np
import keras

class My_Custom_Generator(keras.utils.Sequence):
    def __init__(self, sess, data_set, batch_size, dataset_size):
        self.sess = sess
        self.data_set = data_set
        self.batch_size = batch_size
        self.dataset_size = dataset_size

    def __len__(self):
        # Number of batches per epoch.
        return int(np.ceil(self.dataset_size / float(self.batch_size)))

    def __getitem__(self, idx):
        # Each call evaluates the dataset tensors and returns one numpy batch.
        data, label = self.sess.run(self.data_set)
        label = np.squeeze(label)
        return data, label
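Putting the pieces together, the wiring between the dataset tensors, the generator, and fit_generator() looks roughly like this (a sketch, assuming the TFRecord files exist at the paths above and a compiled model is available as in the author's training class; the batch size is illustrative):

import tensorflow as tf

batch_size = 57  # illustrative
with tf.Session() as sess:
    # Build and initialize the dataset iterators in this session.
    training_set, val_set, test_set = load_data_with_val(sess, batch_size=batch_size, zoom_rate=60)
    # Wrap the tensors in Sequence generators, the same way train() does.
    train_gen = My_Custom_Generator(sess, training_set, batch_size, 855)
    val_gen = My_Custom_Generator(sess, val_set, batch_size, 100)
    x, y = train_gen[0]
    print(x.shape, y.shape)  # one batch of volumes and squeezed labels
    # model.fit_generator(train_gen, steps_per_epoch=15, epochs=10,
    #                     validation_data=val_gen, validation_steps=3)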
This post works with 3D data that is downscaled to 0.4x before being fed into the model. The model itself is not the subject of this post; readers can write their own to fit their needs.