This article describes the concrete form of the Melody RNN data representation, and how to read the .tfrecord files that Melody RNN saves after conversion.
As an example, we will use one of the simplest songs, "Twinkle, Twinkle, Little Star".
Before anything else, import the libraries we need:
import tensorflow as tf
import magenta as mgt
import numpy as np
# This line is needed because Jupyter Notebook has a bug with tf.app.flags.FLAGS;
# see https://github.com/tensorflow/tensorflow/issues/17702
tf.app.flags.DEFINE_string('f', '', 'kernel')
WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
* https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
* https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.
First, we convert the MIDI file of "Twinkle, Twinkle, Little Star" into a NoteSequence, using the method introduced in the previous section:
import magenta.scripts.convert_dir_to_note_sequences as cvrt
cvrt.FLAGS.input_dir = r'Dataset\raw\data-representation-example'
cvrt.FLAGS.output_file = r'Dataset\pre\data-representation-example.tfrecord'
cvrt.FLAGS.recursive = True
cvrt.FLAGS.log = 'INFO'
unused_argv = ''
cvrt.main(unused_argv)
INFO:tensorflow:Converting files in 'Dataset\raw\data-representation-example\'.
INFO:tensorflow:0 files converted.
INFO:tensorflow:Converted MIDI file Dataset\raw\data-representation-example\star.mid.
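Before moving on, we can do a quick sanity check and read the NoteSequence protos back from the .tfrecord file we just wrote. This is only a minimal sketch, assuming the Magenta 1.x / TensorFlow 1.x APIs used throughout this article (magenta.protobuf.music_pb2 and tf.python_io.tf_record_iterator):
from magenta.protobuf import music_pb2

# Iterate over the serialized NoteSequence protos and print a few basic properties.
for record in tf.python_io.tf_record_iterator(
        r'Dataset\pre\data-representation-example.tfrecord'):
    ns = music_pb2.NoteSequence.FromString(record)
    print(ns.filename, len(ns.notes), ns.total_time)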
After that, we need to convert the NoteSequence into the input format expected by the Melody RNN model. As with convert_dir_to_note_sequences.py in the previous section, we need to pass parameters to Melody RNN's data conversion program, melody_rnn_create_dataset.py.
# First, we need to delete the FLAGS defined while running convert_dir_to_note_sequences.py,
# otherwise FLAGS with the same names will conflict with each other.
# https://stackoverflow.com/a/51211037/8764874
def del_all_flags(FLAGS):
    flags_dict = FLAGS._flags()
    keys_list = [keys for keys in flags_dict]
    for keys in keys_list:
        FLAGS.__delattr__(keys)
del_all_flags(tf.flags.FLAGS)
import magenta.models.melody_rnn.melody_rnn_create_dataset as create_dataset
print(create_dataset.FLAGS)
magenta.models.melody_rnn.melody_rnn_config_flags:
--config: Which config to use. Must be one of 'basic', 'lookback', or
'attention'. Mutually exclusive with `--melody_encoder_decoder`.
--generator_description: A description of the generator. Overrides the default
if `--config` is also supplied.
--generator_id: A unique ID for the generator. Overrides the default if
`--config` is also supplied.
--hparams: Comma-separated list of `name=value` pairs. For each pair, the
value of the hyperparameter named `name` is set to `value`. This mapping is
merged with the default hyperparameters.
(default: '')
--melody_encoder_decoder: Which encoder/decoder to use. Must be one of
'onehot', 'lookback', or 'key'. Mutually exclusive with `--config`.
magenta.models.melody_rnn.melody_rnn_create_dataset:
--eval_ratio: Fraction of input to set aside for eval set. Partition is
randomly selected.
(default: '0.1')
(a number)
--input: TFRecord to read NoteSequence protos from.
--log: The threshold for what messages will be logged DEBUG, INFO, WARN,
ERROR, or FATAL.
(default: 'INFO')
--output_dir: Directory to write training and eval TFRecord files. The
TFRecord files are populated with SequenceExample protos.
absl.flags:
--flagfile: Insert flag definitions from the given file into the command line.
(default: '')
--undefok: comma-separated list of flag names that it is okay to specify on
the command line even if the program does not define a flag with that name.
IMPORTANT: flags in this list that have arguments MUST use the --flag=value
format.
(default: '')
One thing worth noting is that the printed FLAGS include not only the FLAGS from magenta.models.melody_rnn.melody_rnn_create_dataset but also FLAGS from magenta.models.melody_rnn.melody_rnn_config_flags. This is because the line from magenta.models.melody_rnn import melody_rnn_config_flags in melody_rnn_create_dataset.py pulls in the FLAGS defined in melody_rnn_config_flags.py. For debugging in Jupyter Notebook, we can simply modify the attributes of create_dataset.FLAGS directly. But if you prefer to edit the Python files and run them directly, don't forget to also change the corresponding settings in melody_rnn_config_flags.py.
Now we set Melody RNN's configuration parameters. Here we start with basic_rnn. We don't always need to set every parameter: some steps simply don't use certain parameters (at the data conversion step, for example, we obviously don't need to decide the network size or the number of LSTM layers), and some parameters are set automatically based on the values of others. As the descriptions of generator_description, generator_id, etc. indicate, once config is set there is no need to set those variables.
Next, we do the setup:
create_dataset.FLAGS.config = 'basic_rnn'
create_dataset.FLAGS.input = r'Dataset\pre\data-representation-example.tfrecord'
create_dataset.FLAGS.output_dir = r'Dataset\melody_rnn\data-representation-example'
create_dataset.FLAGS.eval_ratio = 0.1
unused_argv = ''
create_dataset.main(unused_argv)
WARNING:tensorflow:From D:\Anaconda3\envs\magenta\lib\site-packages\magenta\pipelines\pipeline.py:310: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and:
`tf.data.TFRecordDataset(path)`
WARNING:tensorflow:Key signatures ignored by TranspositionPipeline.
INFO:tensorflow:
Completed.
INFO:tensorflow:Processed 1 inputs total. Produced 1 outputs.
INFO:tensorflow:DAGPipeline_MelodyExtractor_training_melodies_discarded_too_few_pitches: 0
INFO:tensorflow:DAGPipeline_MelodyExtractor_training_melodies_discarded_too_long: 0
INFO:tensorflow:DAGPipeline_MelodyExtractor_training_melodies_discarded_too_short: 0
INFO:tensorflow:DAGPipeline_MelodyExtractor_training_melodies_truncated: 0
INFO:tensorflow:DAGPipeline_MelodyExtractor_training_melody_lengths_in_bars:
[10,20): 1
INFO:tensorflow:DAGPipeline_MelodyExtractor_training_polyphonic_tracks_discarded: 0
INFO:tensorflow:DAGPipeline_RandomPartition_training_melodies_count: 1
INFO:tensorflow:DAGPipeline_TranspositionPipeline_training_skipped_due_to_range_exceeded: 0
INFO:tensorflow:DAGPipeline_TranspositionPipeline_training_transpositions_generated: 1
From the printed messages we can see that Melody RNN's data cleaning is quite thorough. During conversion, Melody RNN appears to transpose the music to a common key, split it automatically, and filter out segments that are too long or too short. In addition, two files, training_melodies.tfrecord and eval_melodies.tfrecord, are generated under Dataset\melody_rnn\data-representation-example. Under the default settings (the eval_ratio parameter), the evaluation set is 10% of the dataset; since we only have a single melody here, the evaluation set ends up empty.
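We can verify this split by counting the records in the two output files. A minimal sketch, again relying on the TF 1.x tf.python_io.tf_record_iterator API:
import os

output_dir = r'Dataset\melody_rnn\data-representation-example'
for name in ['training_melodies.tfrecord', 'eval_melodies.tfrecord']:
    path = os.path.join(output_dir, name)
    # Count the serialized SequenceExample protos in each file.
    count = sum(1 for _ in tf.python_io.tf_record_iterator(path))
    print(name, count)
If everything above worked, this should report 1 training melody and 0 eval melodies.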
Next, let's look at how to read a .tfrecord file, and how the data is stored inside it.
Reading tfrecord data requires building an input queue. Here we also need to define in advance the dimensions used to parse the data. The feature definition in the code below can be found in make_sequence_example in sequence_example_lib.py; the code given here is an older version, but it reads the files just the same.
sequence_example_file_paths = [
    r'Dataset\melody_rnn\data-representation-example\training_melodies.tfrecord']
file_queue = tf.train.string_input_producer(sequence_example_file_paths)
reader = tf.TFRecordReader()
# reader.read returns a (key, value) pair; the value is a serialized SequenceExample.
read_queue, serialized_example = reader.read(file_queue)
sequence_features = {
    'inputs': tf.FixedLenSequenceFeature(shape=[38],
                                         dtype=tf.float32),
    'labels': tf.FixedLenSequenceFeature(shape=[],
                                         dtype=tf.int64)}
# parse_single_sequence_example returns (context features, sequence features);
# here we only use the sequence features.
single_queue, sequence = tf.parse_single_sequence_example(
    serialized_example, sequence_features=sequence_features)
WARNING:tensorflow:From :1: string_input_producer (from tensorflow.python.training.input) is deprecated and will be removed in a future version.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensor_slices(string_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If `shuffle=False`, omit the `.shuffle(...)`.
WARNING:tensorflow:From D:\Anaconda3\envs\magenta\lib\site-packages\tensorflow\python\training\input.py:278: input_producer (from tensorflow.python.training.input) is deprecated and will be removed in a future version.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensor_slices(input_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If `shuffle=False`, omit the `.shuffle(...)`.
WARNING:tensorflow:From D:\Anaconda3\envs\magenta\lib\site-packages\tensorflow\python\training\input.py:190: limit_epochs (from tensorflow.python.training.input) is deprecated and will be removed in a future version.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensors(tensor).repeat(num_epochs)`.
WARNING:tensorflow:From D:\Anaconda3\envs\magenta\lib\site-packages\tensorflow\python\training\input.py:199: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From D:\Anaconda3\envs\magenta\lib\site-packages\tensorflow\python\training\input.py:199: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From D:\Anaconda3\envs\magenta\lib\site-packages\tensorflow\python\training\input.py:202: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
WARNING:tensorflow:From :2: TFRecordReader.__init__ (from tensorflow.python.ops.io_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.TFRecordDataset`.
There are more detailed tutorials on reading tfrecord files elsewhere. Note that this queue-based way of reading them is deprecated in newer TensorFlow versions; if you are building a pipeline from scratch, look into tf.data instead.
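For reference, here is a minimal sketch of what the tf.data equivalent might look like (still TF 1.x graph mode, reusing the sequence_example_file_paths and sequence_features defined above); the rest of this article continues with the queue-based code.
dataset = tf.data.TFRecordDataset(sequence_example_file_paths)

def parse(serialized):
    # Same feature spec as above; keep only the sequence features.
    _, seq = tf.parse_single_sequence_example(
        serialized, sequence_features=sequence_features)
    return seq

iterator = dataset.map(parse).make_one_shot_iterator()
next_element = iterator.get_next()
# In a session, sess.run(next_element) yields one {'inputs': ..., 'labels': ...} dict.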
Because file reading is usually combined with multithreading, we need the threading utilities coord = tf.train.Coordinator() and threads = tf.train.start_queue_runners(sess, coord) when reading. Here, sequence effectively acts as an iterator: each time it is run, it returns the next example. We only take out a single example here.
sess = tf.InteractiveSession()
data_number = 1
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord)
try:
    for i in range(data_number):
        seq = sess.run(sequence)
except tf.errors.OutOfRangeError:
    print("done")
finally:
    coord.request_stop()
    coord.join(threads)
WARNING:tensorflow:From :3: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
print(seq['inputs'].shape)
print(seq['labels'].shape)
(191, 38)
(191,)
print(seq['inputs'])
[[0. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]
...
[1. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]]
print(np.argmax(seq['inputs'],axis=-1))
[14 0 0 0 14 0 0 0 21 0 0 0 21 0 0 0 23 0 0 0 23 0 0 0
21 0 0 0 0 0 0 0 19 0 0 0 19 0 0 0 18 0 0 0 18 0 0 0
16 0 0 0 16 0 0 0 14 0 0 0 0 0 0 0 21 0 0 0 21 0 0 0
19 0 0 0 19 0 0 0 18 0 0 0 18 0 0 0 16 0 0 0 0 0 0 0
21 0 0 0 21 0 0 0 19 0 0 0 19 0 0 0 18 0 0 0 18 0 0 0
16 0 0 0 0 0 0 0 14 0 0 0 14 0 0 0 21 0 0 0 21 0 0 0
23 0 0 0 23 0 0 0 21 0 0 0 0 0 0 0 19 0 0 0 19 0 0 0
18 0 0 0 18 0 0 0 16 0 0 0 16 0 0 0 14 0 0 0 0 0 0]
print(seq['labels'])
[ 0 0 0 14 0 0 0 21 0 0 0 21 0 0 0 23 0 0 0 23 0 0 0 21
0 0 0 0 0 0 0 19 0 0 0 19 0 0 0 18 0 0 0 18 0 0 0 16
0 0 0 16 0 0 0 14 0 0 0 0 0 0 0 21 0 0 0 21 0 0 0 19
0 0 0 19 0 0 0 18 0 0 0 18 0 0 0 16 0 0 0 0 0 0 0 21
0 0 0 21 0 0 0 19 0 0 0 19 0 0 0 18 0 0 0 18 0 0 0 16
0 0 0 0 0 0 0 14 0 0 0 14 0 0 0 21 0 0 0 21 0 0 0 23
0 0 0 23 0 0 0 21 0 0 0 0 0 0 0 19 0 0 0 19 0 0 0 18
0 0 0 18 0 0 0 16 0 0 0 16 0 0 0 14 0 0 0 0 0 0 0]
We can see that in inputs the data is stored as one-hot vectors, while in labels the data is stored directly as integers. Comparing the two sequences, we find that the labels sequence is the inputs sequence shifted left by one time step. In other words, the input runs from $X_1$ to $X_t$ and the output from $X_2$ to $X_{t+1}$.
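We can confirm this shift-by-one relationship directly on the example we just read:
# argmax recovers the integer event index from each one-hot input vector;
# the labels should equal those events shifted left by one time step.
events = np.argmax(seq['inputs'], axis=-1)
print(np.array_equal(events[1:], seq['labels'][:-1]))  # should print True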
Also, because our melody contains no rests, the note-off event represented by "1" never appears. basic_rnn only represents a limited pitch range (by default MIDI pitches 48 through 83, i.e. C3 to B5) plus two special events (note-off and no-event), which is why each input vector has 36 + 2 = 38 dimensions. The first note, C4 (MIDI pitch 60), therefore maps to 60 - 48 + 2 = 14, and the remaining numbers follow the same pattern. Also, the first "14" is followed by three "0"s, from which we can infer that one time step here corresponds to a quarter of a quarter note, i.e. a sixteenth note.
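As a final check, we can map the events back to MIDI pitches by inverting the encoding just described (an event value of 0 is no-event, 1 is note-off, and any value >= 2 corresponds to pitch value - 2 + 48). This is only a rough decoding sketch that ignores timing, not the exact Magenta implementation:
events = np.argmax(seq['inputs'], axis=-1)
# Keep only real note-on events and convert them back to MIDI pitch numbers.
pitches = [e - 2 + 48 for e in events if e >= 2]
print(pitches[:14])
# [60, 60, 67, 67, 69, 69, 67, 65, 65, 64, 64, 62, 62, 60]
# i.e. C C G G A A G, F F E E D D C -- the familiar opening of "Twinkle, Twinkle, Little Star".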