[Error] DistributionStrategy is not supported by tf.keras.models.Model.fit_generator

In a TensorFlow 1.15 environment, training with the MultiWorkerMirroredStrategy distribution strategy fails with the following error:

NotImplementedError: `fit_generator` is not supported for models compiled with tf.distribute.Strategy.

Minimal code to reproduce it:

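import numpy as np
import tensorflow as tf
# strategy = tf.distribute.MirroredStrategy()  # any strategy triggers the same error in 1.15
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()


def generator(batch_size=100):
    # Infinite generator of (inputs, targets) batches for the toy task y = 3x
    while True:
        data = np.random.random((batch_size, 1))
        yield {'inp': data}, 3 * data


with strategy.scope():
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Dense(1))
    model.compile('Adam', 'mae')
    # TF 1.15: raises NotImplementedError because the model was compiled
    # under a tf.distribute.Strategy
    model.fit_generator(generator(), steps_per_epoch=1000, epochs=10)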

Traceback (most recent call last):
  File "", line 1, in <module>
  File "/snap/pycharm-professional/147/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/snap/pycharm-professional/147/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/tsygankov/Development/next/bin/tmp1.py", line 17, in <module>
    model.fit_generator(generator(), steps_per_epoch=1000, epochs=10)
  File "/home/tsygankov/Development/next/.vtf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 1277, in fit_generator
    raise NotImplementedError('`fit_generator` is not supported for '
NotImplementedError: `fit_generator` is not supported for models compiled with tf.distribute.Strategy.

Cause:

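In TensorFlow 1.15, a model compiled under a tf.distribute.Strategy can only be trained with fit. fit_generator checks for a strategy and raises the error above as soon as it is called. Using fit_generator together with a distribution strategy is only supported from TensorFlow 2.1 onwards. In 1.15, one alternative is the Keras utility multi_gpu_model, which does work with fit_generator. Note that it only replicates the model across the GPUs of a single machine, so it does not replace multi-worker training. It can be imported with: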

from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model
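Below is a minimal sketch of the multi_gpu_model route. It assumes a single machine, and the number of visible GPUs is detected at runtime. It follows the pattern from the multi_gpu_model docstring: build the template model on the CPU, wrap it, then compile and train the wrapper with fit_generator. Several pieces are my own additions and are not part of the original post:
- the count_gpus helper is hypothetical;
- the fallback to the plain model when fewer than two GPUs are present keeps the snippet runnable on CPU-only machines, because multi_gpu_model itself raises an error there;
- the getattr fallback to fit only matters on newer Keras versions, where fit_generator was removed.

```python
import numpy as np
import tensorflow as tf

try:
    # Import path from the text (TF 1.15); the module no longer exists in newer TF 2.x
    from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model
except ImportError:
    multi_gpu_model = None


def generator(batch_size=100):
    # Infinite generator of (inputs, targets) batches for the toy task y = 3x
    while True:
        data = np.random.random((batch_size, 1)).astype(np.float32)
        yield data, 3 * data


def count_gpus():
    # Hypothetical helper: number of GPUs visible to TensorFlow
    return len(tf.config.experimental.list_physical_devices('GPU'))


# Build the template model on the CPU, as the multi_gpu_model docstring recommends
with tf.device('/cpu:0'):
    template = tf.keras.models.Sequential(
        [tf.keras.layers.Dense(1, input_shape=(1,))])

n_gpus = count_gpus()
if multi_gpu_model is not None and n_gpus >= 2:
    # Replicate the template on n_gpus GPUs; each input batch is split between them
    model = multi_gpu_model(template, gpus=n_gpus)
else:
    # Assumption: fall back to the plain model (multi_gpu_model needs >= 2 GPUs)
    model = template

model.compile('adam', 'mae')

# fit_generator is what the text pairs with multi_gpu_model; newer Keras only has fit
train = getattr(model, 'fit_generator', model.fit)
history = train(generator(), steps_per_epoch=20, epochs=2)
```

To save the trained weights, call save or save_weights on template rather than on the wrapped model. The multi_gpu_model documentation gives the same advice.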
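If you would rather keep the distribution strategy in 1.15, you can wrap the generator in a tf.data.Dataset with tf.data.Dataset.from_generator and pass it to fit. Several choices here are mine and do not come from the linked issue:
- MirroredStrategy stands in for MultiWorkerMirroredStrategy so the sketch runs on one machine;
- the {'inp': ...} dict is replaced by plain arrays;
- the generator yields single samples, which are then batched with drop_remainder=True to give a static batch shape;
- the step and epoch counts are arbitrary.

With several workers, auto-sharding of a generator-backed dataset behaves differently across versions, so check the tf.data distribution options for your release.

```python
import numpy as np
import tensorflow as tf

# The post uses MultiWorkerMirroredStrategy; MirroredStrategy is used here so the
# sketch runs on a single machine. Swap the strategy on this line if needed.
strategy = tf.distribute.MirroredStrategy()

BATCH_SIZE = 100


def sample_generator():
    # Yields one (x, y) sample at a time for the toy task y = 3x
    while True:
        x = np.random.random((1,)).astype(np.float32)
        yield x, 3 * x


# Wrap the generator in a Dataset; drop_remainder=True gives every batch
# the static shape (BATCH_SIZE, 1)
dataset = tf.data.Dataset.from_generator(
    sample_generator,
    output_types=(tf.float32, tf.float32),
    output_shapes=((1,), (1,)),
).batch(BATCH_SIZE, drop_remainder=True)

with strategy.scope():
    model = tf.keras.models.Sequential(
        [tf.keras.layers.Dense(1, input_shape=(1,))])
    model.compile('adam', 'mae')

# fit (not fit_generator) accepts a Dataset for a strategy-compiled model
history = model.fit(dataset, steps_per_epoch=20, epochs=2)
```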

Reference: DistributionStrategy is not supported by tf.keras.models.Model.fit_generator · Issue #31231 · tensorflow/tensorflow (github.com)
