JAX, a powerful accelerator for machine learning: making numpy 30x faster

jax.numpy is numpy for CPU, GPU, and TPU, with excellent support for automatic differentiation, aimed at high-performance machine learning research.

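As a quick illustration of the automatic differentiation mentioned above, jax.grad turns a Python function into a function that computes its derivative. A minimal sketch (the cubic function here is just an illustrative example):

from jax import grad

# f(x) = x^3, whose derivative is 3 * x^2
f = lambda x: x ** 3
df = grad(f)     # df now computes the derivative of f

print(df(2.0))   # prints 12.0
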
Today I want to see just how fast it really is. I ran both experiments on the same machine, which has no GPU.

First, install JAX:

pip install jax jaxlib
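
If you want to confirm which backend JAX will actually use (CPU, GPU, or TPU), jax.devices() lists the devices it can see; a quick check:

import jax

# On a machine without a GPU this prints something like [CpuDevice(id=0)]
print(jax.devices())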

Let's first try plain numpy:

import time

import numpy as np

# 5000 x 5000 random matrix in float32
x = np.random.random([5000, 5000]).astype(np.float32)

st = time.time()
try:
    y = np.matmul(x, x)  # matrix product x @ x
except Exception:
    print("error")
print(time.time() - st)
print(y)

The output:

[root@node opt]# python np.py
4.424036026000977
[[1236.3004 1240.3048 1211.4501 ... 1225.7804 1237.1368 1235.1566]
 [1235.5778 1246.7327 1208.7142 ... 1238.117  1232.439  1226.5779]
 [1235.0111 1244.4628 1211.5264 ... 1238.5541 1246.9045 1244.6909]
 ...
 [1229.7677 1238.8345 1210.4467 ... 1219.8604 1234.0862 1220.1482]
 [1231.9464 1251.9636 1212.1384 ... 1235.8513 1236.8677 1240.5355]
 [1254.0636 1265.74   1241.6528 ... 1245.015  1259.153  1247.0613]]

Now let's try the numpy that ships with JAX:

import time

import jax.numpy as np  # drop-in replacement for numpy
from jax import random

# 5000 x 5000 random matrix generated with JAX's PRNG
x = random.uniform(random.PRNGKey(0), [5000, 5000])

st = time.time()
try:
    y = np.matmul(x, x)  # matrix product x @ x
except Exception:
    print("error")
print(time.time() - st)
print(y)

The output:

[root@node opt]# python jax_np.py
/opt/AN/lib/python3.7/site-packages/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
0.013895750045776367
[[1261.0647 1244.5797 1237.2269 ... 1264.7208 1246.0367 1260.5391]
 [1256.1    1239.737  1237.5562 ... 1257.1333 1243.5856 1243.5979]
 [1261.2687 1239.5006 1250.6697 ... 1259.8387 1250.6825 1248.5712]
 ...
 [1265.9805 1230.9077 1244.4961 ... 1264.2374 1241.5995 1244.9274]
 [1262.9971 1253.961  1256.2424 ... 1266.3489 1255.1581 1274.1865]
 [1273.3524 1252.4921 1261.0496 ... 1273.2394 1272.829  1267.7483]]

Comparing the two runs: numpy on its own took about 4.4 seconds, while numpy through JAX took only about 0.014 seconds, a speedup of well over 30x on this raw timing. Remarkably fast.
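
One caveat is worth knowing: JAX dispatches work asynchronously, so a time.time() taken right after the matmul mostly measures the dispatch, not the full computation. For a stricter comparison you would normally force the result to finish with block_until_ready(). A minimal sketch of the adjusted timing (same setup as above):

import time

import jax.numpy as jnp
from jax import random

x = random.uniform(random.PRNGKey(0), [5000, 5000])

st = time.time()
# block_until_ready() waits until the computation has actually completed,
# so the measurement covers the matmul itself, not just its dispatch
y = jnp.matmul(x, x).block_until_ready()
print(time.time() - st)

The exact speedup you end up with therefore depends on the machine and on how numpy's BLAS backend is configured.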

In the same vein, JAX also ships jax.scipy and other submodules that can replace the corresponding parts of native scipy.
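
For example, jax.scipy.special.logsumexp can stand in for scipy.special.logsumexp. A small sketch with made-up input:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

# log-sum-exp of a small example vector, computed by JAX instead of scipy
a = jnp.array([1.0, 2.0, 3.0])
print(logsumexp(a))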

 
