Magenta Hacking Notes 4: Melody RNN's Data Representation and Reading tfrecord Files

This post describes the concrete form of Melody RNN's data representation, and how to read the .tfrecord files that Melody RNN produces after conversion.

Magenta version: 1.1.1

Data representation and reading tfrecord files

First, let's take the simplest of songs, "Twinkle, Twinkle, Little Star", as our example.

Before anything else, import the libraries we need:

import tensorflow as tf
import magenta as mgt
import numpy as np

# This line is added because jupyter notebook has a bug with tf.app.flags.FLAGS,
# see https://github.com/tensorflow/tensorflow/issues/17702
tf.app.flags.DEFINE_string('f', '', 'kernel')
WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

First, we convert the MIDI file of "Twinkle, Twinkle, Little Star" into a NoteSequence, using the method introduced in the previous post:

import magenta.scripts.convert_dir_to_note_sequences as cvrt
cvrt.FLAGS.input_dir = r'Dataset\raw\data-representation-example'
cvrt.FLAGS.output_file = r'Dataset\pre\data-representation-example.tfrecord'
cvrt.FLAGS.recursive = True
cvrt.FLAGS.log = 'INFO'
unused_argv = ''
cvrt.main(unused_argv)
INFO:tensorflow:Converting files in 'Dataset\raw\data-representation-example\'.
INFO:tensorflow:0 files converted.
INFO:tensorflow:Converted MIDI file Dataset\raw\data-representation-example\star.mid.

After that, we need to convert the NoteSequence into the input form the Melody RNN model expects.

As with convert_dir_to_note_sequences.py in the previous post, we need to pass parameters to Melody RNN's data-conversion script.

Melody RNN's data-conversion script is melody_rnn_create_dataset.py.

# First, we have to delete the FLAGS defined during the earlier
# convert_dir_to_note_sequences.py run, otherwise FLAGS with the same
# names will conflict with each other.
# https://stackoverflow.com/a/51211037/8764874
def del_all_flags(FLAGS):
    flags_dict = FLAGS._flags()
    keys_list = [keys for keys in flags_dict]
    for keys in keys_list:
        FLAGS.__delattr__(keys)


del_all_flags(tf.flags.FLAGS)
import magenta.models.melody_rnn.melody_rnn_create_dataset as create_dataset
print(create_dataset.FLAGS)
magenta.models.melody_rnn.melody_rnn_config_flags:
  --config: Which config to use. Must be one of 'basic', 'lookback', or
    'attention'. Mutually exclusive with `--melody_encoder_decoder`.
  --generator_description: A description of the generator. Overrides the default
    if `--config` is also supplied.
  --generator_id: A unique ID for the generator. Overrides the default if
    `--config` is also supplied.
  --hparams: Comma-separated list of `name=value` pairs. For each pair, the
    value of the hyperparameter named `name` is set to `value`. This mapping is
    merged with the default hyperparameters.
    (default: '')
  --melody_encoder_decoder: Which encoder/decoder to use. Must be one of
    'onehot', 'lookback', or 'key'. Mutually exclusive with `--config`.

magenta.models.melody_rnn.melody_rnn_create_dataset:
  --eval_ratio: Fraction of input to set aside for eval set. Partition is
    randomly selected.
    (default: '0.1')
    (a number)
  --input: TFRecord to read NoteSequence protos from.
  --log: The threshold for what messages will be logged DEBUG, INFO, WARN,
    ERROR, or FATAL.
    (default: 'INFO')
  --output_dir: Directory to write training and eval TFRecord files. The
    TFRecord files are populated with  SequenceExample protos.

absl.flags:
  --flagfile: Insert flag definitions from the given file into the command line.
    (default: '')
  --undefok: comma-separated list of flag names that it is okay to specify on
    the command line even if the program does not define a flag with that name.
    IMPORTANT: flags in this list that have arguments MUST use the --flag=value
    format.
    (default: '')

One thing worth noting is that the printed FLAGS include not only the FLAGS from magenta.models.melody_rnn.melody_rnn_create_dataset but also FLAGS from magenta.models.melody_rnn.melody_rnn_config_flags. This is because the line from magenta.models.melody_rnn import melody_rnn_config_flags in melody_rnn_create_dataset.py pulls in the FLAGS defined in melody_rnn_config_flags.py. For debugging in a jupyter notebook, we can simply modify the attribute values of create_dataset.FLAGS directly. But if you want to edit and run the Python files themselves, don't forget to also adjust the parameters set in melody_rnn_config_flags.py.
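
As a quick aside, absl/tf.app.flags keeps a single process-wide FLAGS object, so the FLAGS exposed by melody_rnn_config_flags and by melody_rnn_create_dataset should be the very same object. A small sanity-check sketch (assuming both modules simply alias tf.app.flags.FLAGS):

from magenta.models.melody_rnn import melody_rnn_config_flags

# If both modules alias the global tf.app.flags.FLAGS, this prints True,
# and setting create_dataset.FLAGS.config also affects the config flag
# that melody_rnn_config_flags reads.
print(melody_rnn_config_flags.FLAGS is create_dataset.FLAGS)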

Now we set Melody RNN's configuration parameters. Here we use basic_rnn for now. We don't need to set every parameter every time: some steps simply don't use certain parameters (for the data-conversion step, for instance, we obviously don't need to decide the network size or the number of LSTM layers), and some parameters are filled in automatically from the values of others. As the descriptions of generator_description and generator_id show, once config is set, those variables no longer need to be set.

Next, we set the parameters:

create_dataset.FLAGS.config = 'basic_rnn'
create_dataset.FLAGS.input = r'Dataset\pre\data-representation-example.tfrecord'
create_dataset.FLAGS.output_dir = r'Dataset\melody_rnn\data-representation-example'
create_dataset.FLAGS.eval_ratio = 0.1
unused_argv = ''
create_dataset.main(unused_argv)
WARNING:tensorflow:From D:\Anaconda3\envs\magenta\lib\site-packages\magenta\pipelines\pipeline.py:310: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`
WARNING:tensorflow:Key signatures ignored by TranspositionPipeline.
INFO:tensorflow:

Completed.

INFO:tensorflow:Processed 1 inputs total. Produced 1 outputs.
INFO:tensorflow:DAGPipeline_MelodyExtractor_training_melodies_discarded_too_few_pitches: 0
INFO:tensorflow:DAGPipeline_MelodyExtractor_training_melodies_discarded_too_long: 0
INFO:tensorflow:DAGPipeline_MelodyExtractor_training_melodies_discarded_too_short: 0
INFO:tensorflow:DAGPipeline_MelodyExtractor_training_melodies_truncated: 0
INFO:tensorflow:DAGPipeline_MelodyExtractor_training_melody_lengths_in_bars:
  [10,20): 1
INFO:tensorflow:DAGPipeline_MelodyExtractor_training_polyphonic_tracks_discarded: 0
INFO:tensorflow:DAGPipeline_RandomPartition_training_melodies_count: 1
INFO:tensorflow:DAGPipeline_TranspositionPipeline_training_skipped_due_to_range_exceeded: 0
INFO:tensorflow:DAGPipeline_TranspositionPipeline_training_transpositions_generated: 1

From the printed messages we can see that Melody RNN's data cleaning is fairly thorough. During conversion, Melody RNN appears to transpose the music into a unified key, split it automatically, and filter out segments that are too long or too short. Two files, training_melodies.tfrecord and eval_melodies.tfrecord, are written to the output directory Dataset\melody_rnn\data-representation-example. Under the default setting (the eval_ratio flag), the eval set is 10% of the dataset; since we only have one example here, the eval set is empty.
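
If you want to confirm the split yourself, a minimal sketch (assuming the two files were indeed written to the output_dir set above) is to count the records in each file with the old tf.python_io API:

import os

# Count the SequenceExamples written to the training and eval files.
# With a single input melody and eval_ratio=0.1 we expect 1 and 0.
output_dir = r'Dataset\melody_rnn\data-representation-example'
for name in ['training_melodies.tfrecord', 'eval_melodies.tfrecord']:
    path = os.path.join(output_dir, name)
    count = sum(1 for _ in tf.python_io.tf_record_iterator(path))
    print(name, count)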

Now let's look at how to read a .tfrecord file, and how the data is stored inside it.

Reading tfrecord data requires building an input queue. Here we also need to define, in advance, the dimensions of the features to be parsed. The feature definitions used below can be found in make_sequence_example in sequence_example_lib.py; what is shown here is the older queue-based code, but it reads the file just the same.

sequence_example_file_paths = [r'Dataset\melody_rnn\data-representation-example\training_melodies.tfrecord']
file_queue = tf.train.string_input_producer(sequence_example_file_paths)
reader = tf.TFRecordReader()
# reader.read returns a (key, serialized SequenceExample) pair
read_queue, serialized_example = reader.read(file_queue)

# Each time step holds a 38-dimensional one-hot input vector and an int64 label
sequence_features = {
    'inputs': tf.FixedLenSequenceFeature(shape=[38],
                                         dtype=tf.float32),
    'labels': tf.FixedLenSequenceFeature(shape=[],
                                         dtype=tf.int64)}

# parse_single_sequence_example returns (context features, sequence features)
single_queue, sequence = tf.parse_single_sequence_example(
    serialized_example, sequence_features=sequence_features)

WARNING:tensorflow:From :1: string_input_producer (from tensorflow.python.training.input) is deprecated and will be removed in a future version.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensor_slices(string_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If `shuffle=False`, omit the `.shuffle(...)`.
WARNING:tensorflow:From D:\Anaconda3\envs\magenta\lib\site-packages\tensorflow\python\training\input.py:278: input_producer (from tensorflow.python.training.input) is deprecated and will be removed in a future version.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensor_slices(input_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If `shuffle=False`, omit the `.shuffle(...)`.
WARNING:tensorflow:From D:\Anaconda3\envs\magenta\lib\site-packages\tensorflow\python\training\input.py:190: limit_epochs (from tensorflow.python.training.input) is deprecated and will be removed in a future version.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensors(tensor).repeat(num_epochs)`.
WARNING:tensorflow:From D:\Anaconda3\envs\magenta\lib\site-packages\tensorflow\python\training\input.py:199: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From D:\Anaconda3\envs\magenta\lib\site-packages\tensorflow\python\training\input.py:199: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From D:\Anaconda3\envs\magenta\lib\site-packages\tensorflow\python\training\input.py:202: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
WARNING:tensorflow:From :2: TFRecordReader.__init__ (from tensorflow.python.ops.io_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.TFRecordDataset`.

For more detail on reading tfrecord files, you can look here or here. This queue-based way of reading tfrecord files is gradually being deprecated in newer TensorFlow versions (as the warnings above show); if you want to build a pipeline from scratch, look into tf.data.
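
For reference, here is a rough tf.data counterpart of the queue-based code above (a sketch only, not how Magenta itself loads these files; it assumes the same 38-dimensional feature spec):

# Read the same file with tf.data instead of queues (graph mode, TF 1.x style).
dataset = tf.data.TFRecordDataset(
    [r'Dataset\melody_rnn\data-representation-example\training_melodies.tfrecord'])

def parse_fn(serialized):
    # parse_single_sequence_example returns (context, sequence) features
    _, parsed = tf.parse_single_sequence_example(
        serialized,
        sequence_features={
            'inputs': tf.FixedLenSequenceFeature([38], dtype=tf.float32),
            'labels': tf.FixedLenSequenceFeature([], dtype=tf.int64)})
    return parsed['inputs'], parsed['labels']

dataset = dataset.map(parse_fn)
iterator = dataset.make_one_shot_iterator()
inputs_t, labels_t = iterator.get_next()

with tf.Session() as tmp_sess:
    inputs_v, labels_v = tmp_sess.run([inputs_t, labels_t])
    print(inputs_v.shape, labels_v.shape)   # expect (191, 38) and (191,)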

Because file reading is usually combined with multithreading, we need the threading machinery here: coord = tf.train.Coordinator() and threads = tf.train.start_queue_runners(sess, coord). sequence acts like an iterator: each time it is run it returns the next example. Here we fetch only one example.

sess = tf.InteractiveSession()
data_number = 1
# The queue runners need a coordinator and background threads to fill the queue
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord)
try:
    for i in range(data_number):
        seq = sess.run(sequence)
except tf.errors.OutOfRangeError:
    print("done")
finally:
    coord.request_stop()
coord.join(threads)
WARNING:tensorflow:From :3: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
print(seq['inputs'].shape)
print(seq['labels'].shape)
(191, 38)
(191,)
print(seq['inputs'])
[[0. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 ...
 [1. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]]
print(np.argmax(seq['inputs'],axis=-1))
[14  0  0  0 14  0  0  0 21  0  0  0 21  0  0  0 23  0  0  0 23  0  0  0
 21  0  0  0  0  0  0  0 19  0  0  0 19  0  0  0 18  0  0  0 18  0  0  0
 16  0  0  0 16  0  0  0 14  0  0  0  0  0  0  0 21  0  0  0 21  0  0  0
 19  0  0  0 19  0  0  0 18  0  0  0 18  0  0  0 16  0  0  0  0  0  0  0
 21  0  0  0 21  0  0  0 19  0  0  0 19  0  0  0 18  0  0  0 18  0  0  0
 16  0  0  0  0  0  0  0 14  0  0  0 14  0  0  0 21  0  0  0 21  0  0  0
 23  0  0  0 23  0  0  0 21  0  0  0  0  0  0  0 19  0  0  0 19  0  0  0
 18  0  0  0 18  0  0  0 16  0  0  0 16  0  0  0 14  0  0  0  0  0  0]
print(seq['labels'])
[ 0  0  0 14  0  0  0 21  0  0  0 21  0  0  0 23  0  0  0 23  0  0  0 21
  0  0  0  0  0  0  0 19  0  0  0 19  0  0  0 18  0  0  0 18  0  0  0 16
  0  0  0 16  0  0  0 14  0  0  0  0  0  0  0 21  0  0  0 21  0  0  0 19
  0  0  0 19  0  0  0 18  0  0  0 18  0  0  0 16  0  0  0  0  0  0  0 21
  0  0  0 21  0  0  0 19  0  0  0 19  0  0  0 18  0  0  0 18  0  0  0 16
  0  0  0  0  0  0  0 14  0  0  0 14  0  0  0 21  0  0  0 21  0  0  0 23
  0  0  0 23  0  0  0 21  0  0  0  0  0  0  0 19  0  0  0 19  0  0  0 18
  0  0  0 18  0  0  0 16  0  0  0 16  0  0  0 14  0  0  0  0  0  0  0]

We can see that in inputs the data is stored as one-hot vectors, while in labels it is stored directly as integers. Comparing the two sequences, the labels sequence is the inputs sequence shifted left by one time step. In other words, the input is $X_1, \dots, X_t$ and the output is $X_2, \dots, X_{t+1}$.
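
We can check this shift-by-one relationship numerically on the seq dict we just fetched (a quick sketch):

# labels[t] should equal argmax(inputs[t+1]) for every step except the last.
events = np.argmax(seq['inputs'], axis=-1)
print(np.array_equal(events[1:], seq['labels'][:-1]))   # expect True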

Also, since our tune contains no rests, the note-off event (represented by "1") never appears. basic_rnn only represents pitches in the MIDI range [48, 84), i.e. C3 through B5, plus two special markers (note-off and no-event), which is exactly the 38 dimensions we saw above; the first note, C4 (MIDI pitch 60), therefore corresponds to 60 - 48 + 2 = 14, and the later numbers follow the same rule. Likewise, the first note "14" is followed by three "0"s, from which we can infer that one time step here is a quarter of a quarter note, i.e. a sixteenth note.
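
To make the mapping concrete, here is a tiny sketch that converts the one-hot indices back to MIDI pitches (it assumes basic_rnn's defaults: a minimum note of 48 and two special events, which is what the 38-dimensional input implies):

MIN_NOTE = 48       # lowest representable MIDI pitch (C3), assumed default
NUM_SPECIAL = 2     # index 0 = no-event, index 1 = note-off

def index_to_pitch(i):
    # Map a one-hot index back to a MIDI pitch or a special-event name.
    if i == 0:
        return 'no-event'
    if i == 1:
        return 'note-off'
    return i - NUM_SPECIAL + MIN_NOTE

events = np.argmax(seq['inputs'], axis=-1)
print([index_to_pitch(i) for i in events[:8]])
# expect [60, 'no-event', 'no-event', 'no-event', 60, 'no-event', 'no-event', 'no-event']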
