# TensorFlowTTS 项目中 .h5 权重文件转 .tflite 的脚本

## FastSpeech2

import tensorflow as tf

import yaml
import numpy as np
import matplotlib.pyplot as plt

import IPython.display as ipd

from tensorflow_tts.processor import LJSpeechProcessor
from tensorflow_tts.processor.ljspeech import LJSPEECH_SYMBOLS

from tensorflow_tts.configs import FastSpeech2Config
from tensorflow_tts.models import TFFastSpeech2

# Convert the FastSpeech2 (Baker v2 config) Keras .h5 checkpoint into a
# TFLite flatbuffer and report its size.
import os

# Root of the local TensorFlowTTS checkout; edit this single value to run the
# script from a different location.
TTS_ROOT = "D:\\TensorFlow\\TensorFlowTTS-master\\TensorFlowTTS"
FS2_CONFIG_PATH = os.path.join(
    TTS_ROOT, "examples", "fastspeech2", "conf", "fastspeech2.baker.v2.yaml"
)
FS2_WEIGHTS_PATH = os.path.join(TTS_ROOT, "content", "fastspeech2-200k.h5")
FS2_TFLITE_PATH = os.path.join(TTS_ROOT, "fastspeech_quant.tflite")

# Read the config explicitly as UTF-8: without `encoding`, open() falls back to
# the locale code page (e.g. cp936 on Chinese Windows), which can fail on or
# garble non-ASCII text in the YAML file.
# NOTE(review): yaml.Loader can instantiate arbitrary Python objects; this is
# acceptable only because the config is a trusted local file.
with open(FS2_CONFIG_PATH, encoding="utf-8") as f:
    fs2_yaml = yaml.load(f, Loader=yaml.Loader)

config = FastSpeech2Config(**fs2_yaml["fastspeech2_params"])
# enable_tflite_convertible=True presumably selects the model's
# TFLite-convertible code path used by `inference_tflite` below -- confirm in
# TFFastSpeech2.
fastspeech2 = TFFastSpeech2(
    config=config, enable_tflite_convertible=True, name="fastspeech2"
)
# Build the variables before loading weights so load_weights has a graph to
# restore into.
fastspeech2._build()
fastspeech2.load_weights(FS2_WEIGHTS_PATH)

# Trace the TFLite inference entry point into a concrete function to convert.
fastspeech2_concrete_function = fastspeech2.inference_tflite.get_concrete_function()

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [fastspeech2_concrete_function]
)
# Default converter optimizations (the output file is named "_quant").
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Prefer TFLite builtin kernels; fall back to TensorFlow (Flex) ops only for
# operations that have no builtin equivalent.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()

with open(FS2_TFLITE_PATH, "wb") as f:
    f.write(tflite_model)

print(f"Model size is {len(tflite_model) / 1024 / 1024.0:f} MBs.")

## multiband_melgan

import tensorflow as tf

import yaml
import numpy as np
import matplotlib.pyplot as plt

import IPython.display as ipd

from tensorflow_tts.processor import LJSpeechProcessor
from tensorflow_tts.processor.ljspeech import LJSPEECH_SYMBOLS

from tensorflow_tts.configs import MultiBandMelGANGeneratorConfig
from tensorflow_tts.models import TFMBMelGANGenerator

# Convert the Multi-Band MelGAN generator (Baker v1 config) Keras .h5
# checkpoint into a TFLite flatbuffer and report its size.
import os

# Root of the local TensorFlowTTS checkout; edit this single value to run the
# script from a different location.
TTS_ROOT = "D:\\TensorFlow\\TensorFlowTTS-master\\TensorFlowTTS"
config_file = os.path.join(
    TTS_ROOT, "examples", "multiband_melgan", "conf", "multiband_melgan.baker.v1.yaml"
)
model_path = os.path.join(TTS_ROOT, "content", "mb.melgan-920k.h5")
MB_MELGAN_TFLITE_PATH = os.path.join(TTS_ROOT, "mbmelgan.tflite")

# Read the config explicitly as UTF-8: without `encoding`, open() falls back to
# the locale code page (e.g. cp936 on Chinese Windows), which can fail on or
# garble non-ASCII text in the YAML file.
# NOTE(review): yaml.Loader can instantiate arbitrary Python objects; this is
# acceptable only because the config is a trusted local file.
with open(config_file, encoding="utf-8") as conf:
    mb_melgan_yaml = yaml.load(conf, Loader=yaml.Loader)

config = MultiBandMelGANGeneratorConfig(
    **mb_melgan_yaml["multiband_melgan_generator_params"]
)
mb_melgan = TFMBMelGANGenerator(config=config, name="MBMelGAN")
# Build the variables before loading weights so load_weights has a graph to
# restore into.
mb_melgan._build()
mb_melgan.load_weights(model_path)

# Smoke test: run eager inference once on random input so loading problems
# surface before conversion. The output is not used.
# Assumes the input layout is (batch, frames, 80 mel bins) -- TODO confirm
# against the config.
fake_mels = tf.random.uniform(shape=[4, 256, 80], dtype=tf.float32)
mb_melgan.inference(fake_mels)

# Trace the TFLite inference entry point into a concrete function to convert.
mb_melgan_concrete_function = mb_melgan.inference_tflite.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [mb_melgan_concrete_function]
)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Prefer TFLite builtin kernels and fall back to TensorFlow (Flex) ops only
# where no builtin exists. Listing SELECT_TF_OPS alone (as before) pushes ops
# onto Flex kernels instead of builtins. It was also inconsistent with the
# FastSpeech2 conversion in this script.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]

# Convert the model.
tflite_model = converter.convert()
with open(MB_MELGAN_TFLITE_PATH, "wb") as t_f:
    t_f.write(tflite_model)
print(f"Model size is {len(tflite_model) / 1024 / 1024.0:f} MBs.")

# 你可能感兴趣的: (tensorflow, TTS, 自然语言处理)