tensorflow2.x多gpu训练

以前经常用

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
print('Number of devices: %d' % strategy.num_replicas_in_sync)

with strategy.scope():
	# model = 
	# model.compile()

最近用这个会报错:INTERNAL: NCCL: unhandled system error. Set NCCL_DEBUG=WARN for detail.。不知道为什么,所以改用是:

import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()
print('Number of devices: %d' % strategy.num_replicas_in_sync)
with strategy.scope():
	# model = 
	# model.compile()	

特意记录一下。

你可能感兴趣的:(tensorflow2,tensorflow,人工智能,python)