The graphics driver that is already installed does not affect the installation steps below.
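First, download the CUDA version your TensorFlow release requires from the CUDA website: https://developer.nvidia.com/cuda-downloads (for example, the official tensorflow-gpu 1.13.1 Windows build expects CUDA 10.0 with cuDNN 7.4).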
Install CUDA by following the installer prompts, and note the installation path. The default is C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
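To confirm which CUDA versions ended up under that folder, the short Python sketch below lists the version subfolders. It assumes the default install path shown above and that each version lives in a folder named like `v9.0`. The function name `installed_cuda_versions` is hypothetical.

```python
import os
import re

# Default CUDA install root on Windows (from the step above); pass another path if you changed it.
DEFAULT_CUDA_ROOT = r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA"


def installed_cuda_versions(root=DEFAULT_CUDA_ROOT):
    """Return version folder names such as 'v9.0' found under root, sorted numerically."""
    if not os.path.isdir(root):
        return []
    versions = [
        name for name in os.listdir(root)
        if re.fullmatch(r"v\d+\.\d+", name) and os.path.isdir(os.path.join(root, name))
    ]
    # Sort by (major, minor) so that v10.0 comes after v9.0.
    return sorted(versions, key=lambda v: tuple(int(p) for p in v[1:].split(".")))


if __name__ == "__main__":
    print(installed_cuda_versions() or "No CUDA version folders found")
```

The folder name you see here (for example `v10.0`) is the one to use in place of `v9.0` when copying the cuDNN files.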
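From the cuDNN website, download the cuDNN version that matches your CUDA version: https://developer.nvidia.com/cudnn (you must log in; WeChat login is supported).
Once the download finishes, unzip the archive.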
Copy the following files into the corresponding CUDA folders (the paths below use v9.0; substitute the version folder you actually installed):
Copy \cuda\bin\cudnn64_7.dll to C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0\bin.
Copy \cuda\include\cudnn.h to C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0\include.
Copy \cuda\lib\x64\cudnn.lib to C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0\lib\x64.
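The three copy steps can also be scripted. This is a minimal sketch that assumes the extracted archive contains a top-level `cuda` folder and that the cuDNN 7.x file names listed above apply. `copy_cudnn` and the example paths are hypothetical, and writing into C:\Program Files usually requires an administrator command prompt.

```python
import os
import shutil

# (path inside the extracted cuDNN archive, target subfolder under the CUDA version folder).
# Mirrors the three copy steps above; file names assume cuDNN 7.x.
CUDNN_FILES = [
    (os.path.join("cuda", "bin", "cudnn64_7.dll"), "bin"),
    (os.path.join("cuda", "include", "cudnn.h"), "include"),
    (os.path.join("cuda", "lib", "x64", "cudnn.lib"), os.path.join("lib", "x64")),
]


def copy_cudnn(extract_dir, cuda_version_dir):
    """Copy the cuDNN files from extract_dir into cuda_version_dir; return destination paths."""
    copied = []
    for rel_src, dest_sub in CUDNN_FILES:
        src = os.path.join(extract_dir, rel_src)
        dest_dir = os.path.join(cuda_version_dir, dest_sub)
        os.makedirs(dest_dir, exist_ok=True)  # normally already created by the CUDA installer
        copied.append(shutil.copy2(src, dest_dir))  # copy2 returns the new file's path
    return copied
```

Example call with hypothetical paths: `copy_cudnn(r"C:\Users\me\Downloads\cudnn", r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0")`.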
Install the matching version of tensorflow-gpu with pip directly from a Windows command prompt:
From the default PyPI index (slower): pip install tensorflow-gpu==1.13.1
From the USTC mirror (faster): pip install -i https://pypi.mirrors.ustc.edu.cn/simple/ tensorflow-gpu==1.13.1
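If you script the install, the `==` pin is what makes pip fetch exactly 1.13.1. The helper below is a hypothetical sketch that builds the command for the current interpreter (`sys.executable -m pip`), so the package lands in the same Python you later import it from. The mirror URL is the one above.

```python
import sys


def pip_install_command(package, version, index_url=None):
    """Build a pip command that pins package==version, optionally using a mirror index."""
    cmd = [sys.executable, "-m", "pip", "install"]
    if index_url:
        cmd += ["-i", index_url]
    # Without '==', pip would look for a package literally named e.g. 'tensorflow-gpu1.13.1'.
    cmd.append("{}=={}".format(package, version))  # str.format keeps Python 3.5 compatibility
    return cmd
```

Usage: `subprocess.run(pip_install_command("tensorflow-gpu", "1.13.1", "https://pypi.mirrors.ustc.edu.cn/simple/"), check=True)`.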
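To test the installation, run the following script. It pins the constants a and b to the CPU, leaves the matmul for TensorFlow to place (normally on the GPU), and logs the device each op runs on:
import tensorflow as tf

# Creates a graph; a and b are pinned to the CPU.
with tf.device('/cpu:0'):
    a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
    b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')
# No device given, so TensorFlow chooses one (the GPU if it is available).
c = tf.matmul(a, b)
# Creates a session with log_device_placement set to True.
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
# Runs the op.
print(sess.run(c))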
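Because a and b are filled row by row from the same six values, the printed product can be checked by hand. This plain-Python sketch (no TensorFlow involved; the helper names are hypothetical) reproduces the expected result, so the script should print 22, 28, 49 and 64 regardless of which device runs the matmul.

```python
def reshape(values, rows, cols):
    """Row-major reshape of a flat list, the order tf.constant uses to fill shape=[rows, cols]."""
    return [values[r * cols:(r + 1) * cols] for r in range(rows)]


def matmul(x, y):
    """Plain-Python matrix product of two nested lists."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]


values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
a = reshape(values, 2, 3)  # [[1, 2, 3], [4, 5, 6]]
b = reshape(values, 3, 2)  # [[1, 2], [3, 4], [5, 6]]
print(matmul(a, b))  # [[22.0, 28.0], [49.0, 64.0]]
```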
The next script explicitly requests GPU:2. With allow_soft_placement=True, TensorFlow falls back to an available device when GPU:2 does not exist instead of raising an error:
import tensorflow as tf

# Creates a graph; all three ops are requested on GPU:2.
with tf.device('/device:GPU:2'):
    a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
    b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')
    c = tf.matmul(a, b)
# Creates a session with allow_soft_placement and log_device_placement set
# to True.
sess = tf.Session(config=tf.ConfigProto(
    allow_soft_placement=True, log_device_placement=True))
# Runs the op.
print(sess.run(c))
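To see where soft placement actually put each op, you can collect the placement lines from the captured output (for example `python test_gpu.py > placement.log 2>&1`, with a hypothetical script name). This sketch assumes TF 1.x prints placements as `op_name: ... /job:.../device:TYPE:N` on a line of their own. The exact format varies between TensorFlow versions, so treat the regular expression as an assumption to adjust.

```python
import re

# Assumed line shape: "<op>: [(<type>):] /job:.../device:<TYPE>:<N>" (format varies by TF version).
PLACEMENT_RE = re.compile(r"^(\S+?):.*?/job:\S*?/device:([A-Z_]+:\d+)\s*$")


def parse_placements(log_text):
    """Map op name -> device (e.g. 'GPU:0') for each placement line found in log_text."""
    placements = {}
    for line in log_text.splitlines():
        match = PLACEMENT_RE.match(line.strip())
        if match:
            placements[match.group(1)] = match.group(2)
    return placements
```

On a single-GPU machine you would expect the ops requested on GPU:2 to appear as `GPU:0` or `CPU:0` here.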
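On a machine with at least two GPUs, the following script runs the same matmul on GPU:0 and GPU:1 and sums the results on the CPU:
import tensorflow as tf

# Creates a graph with one matmul per GPU.
c = []
for d in ['/device:GPU:0', '/device:GPU:1']:
    with tf.device(d):
        a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3])
        b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2])
        c.append(tf.matmul(a, b))
# Adds the per-GPU results on the CPU ('total' avoids shadowing the built-in sum).
with tf.device('/cpu:0'):
    total = tf.add_n(c)
# Creates a session with log_device_placement set to True.
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
# Runs the op.
print(sess.run(total))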
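Since both GPUs compute the same product, the summed result should be exactly twice the single-device output. The plain-Python sketch below mirrors the per-device-then-`add_n` pattern without TensorFlow; the helper names are hypothetical, and `add_n` here only imitates the element-wise sum that `tf.add_n` performs.

```python
def matmul(x, y):
    """Plain-Python matrix product of two nested lists."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]


def add_n(matrices):
    """Element-wise sum of equally shaped matrices, like tf.add_n."""
    return [[sum(m[i][j] for m in matrices) for j in range(len(matrices[0][0]))]
            for i in range(len(matrices[0]))]


a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
b = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
# One product per device string, as in the loop above.
towers = [matmul(a, b) for _ in ['/device:GPU:0', '/device:GPU:1']]
print(add_n(towers))  # [[44.0, 56.0], [98.0, 128.0]]
```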