使用Flax、Jax、TF经历记录

近期复现别人代码时需要使用flax、jax、tensorflow,在安装过程中遇到了一些bug,记录在此。

Tips:

  • 使用import jax.numpy as jnp; a = jnp.zeros([2,3]); print(a.device()) 可以确认Jax使用的是GPU。
  • 使用 tf.config.list_physical_devices('GPU') 可以确认 TensorFlow 使用的是 GPU。
  • 默认情况下,TensorFlow 会映射进程可见的所有 GPU(取决于 CUDA_VISIBLE_DEVICES)的几乎全部内存。这是为了减少内存碎片,更有效地利用设备上相对宝贵的 GPU 内存资源。可以将环境变量 TF_FORCE_GPU_ALLOW_GROWTH 设置为 true来开启内存增长。(参考https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth)
  • 当第一个 JAX 操作运行时,JAX 将预分配 90% 当前可用的 GPU 内存。可以设置XLA_PYTHON_CLIENT_PREALLOCATE=false来禁用预分配,或者设置XLA_PYTHON_CLIENT_MEM_FRACTION=.XX来定义比例。(参考GPU memory allocation — JAX documentation)
  • Jax需要提前安装cuda和cudnn,可以通过canda安装,例如conda install cudatoolkit cudnn -c conda-forge; conda install cuda-nvcc -c nvidia

Problem 1:

使用tensorflow-datasets读取数据集可能会出现如下bug:

W tensorflow/core/platform/cloud/google_auth_provider.cc:178] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "Not found: Could not locate the credentials file.". Retrieving token from GCE failed with "Aborted: All 10 retry attempts failed. The last failure: Unavailable: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Couldn't resolve host 'metadata'".

 这是由于tensorflow-datasets的版本过高导致的,可以考虑降级2.1.0

Problem 2:

显存不足,可以对XLA或TF的设置进行修改,主要为限制单卡占用比例以及允许动态分配

config.gpu_options.per_process_gpu_memory_fraction = 0.5 # 程序最多只能占用指定gpu50%的显存 config.gpu_options.allow_growth = True #程序按需申请内存

Problem 3:

当使用conda安装cuda、cudnn时jax可能会报如下错

Unimplemented: DNN library is not found. 

这是因为cudnn库没有成功被jax链接到,需要设置环境变量 LD_LIBRARY。(参考https://github.com/google/jax/issues/4920)

export LD_LIBRARY_PATH=$CONDA_PREFIX/lib/

使用Flax、Jax、TF经历记录_第1张图片

如果需要设置每个conda环境单独的LD_LIBRARY环境变量,可以参考如下方法

使用Flax、Jax、TF经历记录_第2张图片

 Problem 4:

jax0.3.15在使用GPU时也可能报如下错误:(参考Jax - ALCF User Guides)

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: no kernel image is available for execution on the device

可以设置以下环境变量来解决

export XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"

Problem 5:

可能出现如下错误,问题可能是 JAX 试图预分配太多内存(参考https://github.com/google/jax/discussions/6332)

RuntimeError: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3296): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'

可以设置以下环境变量来解决

export XLA_PYTHON_CLIENT_MEM_FRACTION=.7

你可能感兴趣的:(ubuntu,tensorflow,python,深度学习,机器学习,ubuntu)