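While reproducing someone else's code recently, I needed flax, jax, and tensorflow, and ran into several bugs during installation. I am recording them here.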
Tips:
tf.config.list_physical_devices('GPU')
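This call confirms that TensorFlow is using the GPU (it returns a non-empty list when a GPU is visible).
Setting the environment variable TF_FORCE_GPU_ALLOW_GROWTH to true enables memory growth. (See https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth)
Problem 1: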
Reading a dataset with tensorflow-datasets may produce the following message:
W tensorflow/core/platform/cloud/google_auth_provider.cc:178] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "Not found: Could not locate the credentials file.". Retrieving token from GCE failed with "Aborted: All 10 retry attempts failed. The last failure: Unavailable: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Couldn't resolve host 'metadata'".
This is caused by a tensorflow-datasets version that is too new; consider downgrading to 2.1.0.
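Before downgrading, it can help to confirm which version is actually installed. The sketch below uses only the standard library. The helper names installed_version and is_newer_than are hypothetical, and the version comparison is deliberately naive (leading numeric dotted components only).

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def is_newer_than(version, pinned):
    """Naive comparison of dotted numeric versions, e.g. '4.9.2' vs '2.1.0'."""
    def as_tuple(v):
        # Keep only the leading purely numeric components; suffixes are ignored.
        parts = []
        for piece in v.split("."):
            if not piece.isdigit():
                break
            parts.append(int(piece))
        return tuple(parts)

    return as_tuple(version) > as_tuple(pinned)


if __name__ == "__main__":
    current = installed_version("tensorflow-datasets")
    if current is None:
        print("tensorflow-datasets is not installed")
    elif is_newer_than(current, "2.1.0"):
        print(f"installed {current}; to pin: pip install tensorflow-datasets==2.1.0")
    else:
        print(f"installed {current}")
```

If the installed version turns out to be newer than 2.1.0, the pip command printed by the script performs the downgrade suggested above.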
Problem 2:
Out of GPU memory. The XLA or TF settings can be changed, mainly to cap the fraction of memory used per GPU and to allow dynamic allocation.
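config = tf.compat.v1.ConfigProto()  # assumes `import tensorflow as tf` and a TF1-style session config
config.gpu_options.per_process_gpu_memory_fraction = 0.5  # the program may use at most 50% of the given GPU's memory
config.gpu_options.allow_growth = True  # the program requests memory on demand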
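Dynamic allocation can also be requested through environment variables, which covers both TensorFlow and JAX in one place. A minimal sketch, assuming the documented variables TF_FORCE_GPU_ALLOW_GROWTH (TensorFlow) and XLA_PYTHON_CLIENT_PREALLOCATE (JAX); the helper name allow_dynamic_gpu_memory is hypothetical.

```python
import os


def allow_dynamic_gpu_memory(env=os.environ):
    """Ask TensorFlow and JAX to allocate GPU memory on demand.

    The variables must be set before the GPU backends are initialised,
    so the safest place to call this is before importing either framework.
    """
    env["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"       # TensorFlow: grow as needed
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # JAX: no up-front preallocation
    return env


allow_dynamic_gpu_memory()
# import tensorflow as tf   # import only after the variables are set
# import jax
```

Setting the variables inside the script rather than in the shell keeps the configuration next to the code that depends on it, at the cost of having to run it before any framework import.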
Problem 3:
When cuda and cudnn are installed through conda, jax may report the following error:
Unimplemented: DNN library is not found.
This happens because jax failed to link against the cudnn library; the LD_LIBRARY_PATH environment variable needs to be set. (See https://github.com/google/jax/issues/4920)
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib/
If needed, LD_LIBRARY_PATH can also be set separately for each conda environment.
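One common way to do this is through conda's activate.d and deactivate.d hook scripts, which conda sources when an environment is activated or deactivated. The sketch below is an assumption about a reasonable setup, not necessarily the method these notes originally referred to. The helper name write_ld_library_hooks and the file name ld_library_path.sh are hypothetical.

```python
from pathlib import Path

# Sourced on `conda activate`: remember the old value, then prepend the env's lib dir.
ACTIVATE_SH = (
    'export OLD_LD_LIBRARY_PATH="${LD_LIBRARY_PATH:-}"\n'
    'export LD_LIBRARY_PATH="$CONDA_PREFIX/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"\n'
)
# Sourced on `conda deactivate`: restore the previous value.
DEACTIVATE_SH = (
    'export LD_LIBRARY_PATH="$OLD_LD_LIBRARY_PATH"\n'
    'unset OLD_LD_LIBRARY_PATH\n'
)


def write_ld_library_hooks(conda_prefix, filename="ld_library_path.sh"):
    """Create activate/deactivate hook scripts inside one conda environment."""
    prefix = Path(conda_prefix)
    written = []
    for subdir, body in (("activate.d", ACTIVATE_SH), ("deactivate.d", DEACTIVATE_SH)):
        hook_dir = prefix / "etc" / "conda" / subdir
        hook_dir.mkdir(parents=True, exist_ok=True)
        hook = hook_dir / filename
        hook.write_text(body)
        written.append(hook)
    return written
```

With the target environment active, calling write_ld_library_hooks(os.environ["CONDA_PREFIX"]) (after import os) writes the two scripts; the new LD_LIBRARY_PATH takes effect the next time the environment is activated.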
Problem 4:
jax 0.3.15 may also report the following error when using a GPU: (See Jax - ALCF User Guides)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: no kernel image is available for execution on the device
Setting the following environment variable resolves it:
export XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"
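The export above replaces whatever XLA_FLAGS already contained. When other XLA flags are in use, appending may be preferable. A minimal sketch that does this from Python; the helper name add_xla_flag is hypothetical.

```python
import os


def add_xla_flag(flag, env=os.environ):
    """Append `flag` to XLA_FLAGS unless it is already present."""
    current = env.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    env["XLA_FLAGS"] = " ".join(current)
    return env["XLA_FLAGS"]


add_xla_flag("--xla_gpu_force_compilation_parallelism=1")
# import jax   # import only after XLA_FLAGS is set
```

Setting the flag before importing jax ensures it is in place by the time XLA parses XLA_FLAGS.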
Problem 5:
The following error may occur; the likely cause is that JAX tries to preallocate too much memory. (See https://github.com/google/jax/discussions/6332)
RuntimeError: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3296): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'
Setting the following environment variable resolves it:
export XLA_PYTHON_CLIENT_MEM_FRACTION=.7
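The same limit can be applied from inside a script, as long as it happens before JAX initialises the GPU. A minimal sketch; the helper name set_jax_memory_fraction is hypothetical, and the (0, 1] range check is an assumption made here for safety rather than a rule taken from JAX.

```python
import os


def set_jax_memory_fraction(fraction, env=os.environ):
    """Limit the share of GPU memory that JAX preallocates."""
    # Assumed sanity check: treat the value as a fraction of one GPU's memory.
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"
    return env["XLA_PYTHON_CLIENT_MEM_FRACTION"]


set_jax_memory_fraction(0.7)
# import jax   # import only after the variable is set
```

If the error persists at 0.7, trying a lower value is a reasonable next step, since the right fraction depends on what else is using the GPU.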