如果你的cudatoolkit是9.x版本的,在执行两个很大的batch做matmal的时候,可能会报一个很奇怪的错误:
这个问题困扰了我好几天。从网上查阅了很多资料,才发现是cublas的内部的一个保护机制。当你对两个batch做matmul的时候,如果batch的大小大于172800(大概是这么一个数),就会报错。不太确定cudatoolkit10.x还有没有类似的问题,但是至少cudatoolkit9.x都会遇到这个问题,所以只能想办法把batch改小一点。
注意这里说的batch大小是说矩阵相乘的前面的维度的综合。比如你要做的操作是:
tf.matmul(tf.ones([512, 1024, 4, 2]), tf.ones([512, 1024, 2, 1]))
也会报错的。虽然后面真实相乘的矩阵很小,但是512*1024>172800了,所以会报错。
不信的话,你可以用下面的程序测试一下:
import tensorflow as tf
import numpy as np
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
tf.Session(config=config).close()
def calc():
N = 15 # works for N <= 14
a = 64
b = 16
X = np.random.rand(N, 11520, b, 1).astype(np.float32)
print(X.nbytes*1e-6, "MB")
W = np.random.rand(N, 11520, a, b).astype(np.float32)
print(W.nbytes*1e-6, "MB")
X_ = tf.constant(X, name="X-constant", dtype=tf.float32)
W_ = tf.constant(W, name="W-constant", dtype=tf.float32)
# tf.matmul(W_, X_, name="mymatmul")
return W_ @ X_
tf.reset_default_graph()
a = calc()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
b = sess.run(a)
sess.close()
print(b.shape)