1. Create a Python 3.7 virtual environment (I use conda here)
conda create -n jax python=3.7
conda activate jax
2. Download the jax and jaxlib wheels and their dependencies, then install them inside the activated environment
cd C:\Users\dz\Downloads\jax-0.2.9_and_jaxlib-0.1.61-cp37-win_amd64
pip install jaxlib-0.1.61-cp37-none-win_amd64.whl
pip install jax==0.2.9
pip install matplotlib
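A quick sanity check from the activated environment confirms the wheels installed correctly (the printed version should match the wheel, here 0.2.9):
python -c "import jax; print(jax.__version__)"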
3. Write a .py file to test whether the installation succeeded
import jax.numpy as jnp
import matplotlib.pyplot as plt
x_jnp=jnp.linspace(0,10,1000)
y_jnp=jnp.sin(x_jnp)*jnp.cos(x_jnp)
print(x_jnp,y_jnp)
plt.plot(x_jnp,y_jnp)
plt.show()
import jax.numpy as jnp
from jax import random
key=random.PRNGKey(0)  # random seed
x=random.normal(key,(10,),dtype=jnp.float32)  # generate a 1-D array of 10 random numbers
print(x)
print(type(x))
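JAX's random numbers are purely functional: calling random.normal again with the same key returns the same values. To draw fresh numbers, split the key first; a minimal sketch continuing the snippet above:
key,subkey=random.split(key)  # derive a new subkey from the old key
x2=random.normal(subkey,(10,))  # different values from x, reproducibly
print(x2)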
1. jax.grad transforms a function into a new function that computes its gradient
import jax.numpy as jnp
from jax import random
import jax
def sum_of_squares(x):
    return jnp.sum(x**2)
sum_of_squares_dx=jax.grad(sum_of_squares)  # jax.grad takes a numerical function written in Python and returns a new Python function that computes the gradient of the original
x=jnp.asarray([1.0,2.0,3.0,4.0])
print(sum_of_squares(x))  # sum of squares -> 30.0
print(sum_of_squares_dx(x))  # derivative w.r.t. each element of x -> [2. 4. 6. 8.]
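Because jax.grad returns an ordinary Python function, it composes with itself; a brief sketch (not from the original) taking a second derivative of a scalar function:
f=lambda x:x**3
print(jax.grad(jax.grad(f))(2.0))  # d^2/dx^2 of x^3 is 6x -> 12.0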
2. To take the gradient with respect to a different argument (or several), set argnums
import jax.numpy as jnp
import jax
def sum_squared_error(x,y):
    return jnp.sum((x-y)**2)
sum_squared_error_dx_dy=jax.grad(sum_squared_error,argnums=(0,1))
x = jnp.asarray([2.0, 3.0, 4.0, 5.0])
y = jnp.asarray([1.0, 2.0, 3.0, 4.0])
# derivative w.r.t. x: 2(x-y); w.r.t. y: 2(y-x)
print(sum_squared_error_dx_dy(x,y))  # ([2. 2. 2. 2.], [-2. -2. -2. -2.])
3. To get both the value and the gradient of a function, use jax.value_and_grad
import jax.numpy as jnp
import jax
def sum_squared_error(x,y):
    return jnp.sum((x-y)**2)
sum_squared_error_dx_dy=jax.value_and_grad(sum_squared_error)
x = jnp.asarray([2.0, 3.0, 4.0, 5.0])
y = jnp.asarray([1.0, 2.0, 3.0, 4.0])
#(DeviceArray(4., dtype=float32), DeviceArray([2., 2., 2., 2.], dtype=float32))
print(sum_squared_error_dx_dy(x,y))#jax.value_and_grad(f)(*xs) == (f(*xs), jax.grad(f)(*xs))
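Since value_and_grad returns a (value, gradient) pair, a typical training step unpacks it; a small sketch using the arrays above:
value,grad_x=sum_squared_error_dx_dy(x,y)
print(value)   # 4.0, the loss itself
print(grad_x)  # [2. 2. 2. 2.], the gradient w.r.t. x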
4. When the function returns a tuple (a scalar loss plus intermediate results) instead of a single scalar, use has_aux=True
import jax.numpy as jnp
import jax
def sum_squared_error(x,y):
    return jnp.sum((x-y)**2),x-y
"""jax.grad is only defined on scalar functions,
and our new function returns a tuple.
But we need to return a tuple to return our intermediate results!
This is where has_aux comes in"""
sum_squared_error_dx_dy=jax.grad(sum_squared_error,has_aux=True)
x = jnp.asarray([2.0, 3.0, 4.0, 5.0])
y = jnp.asarray([1.0, 2.0, 3.0, 4.0])
#(DeviceArray([2., 2., 2., 2.], dtype=float32), DeviceArray([1., 1., 1., 1.], dtype=float32))
print(sum_squared_error_dx_dy(x,y))
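With has_aux=True the transformed function returns a (gradient, auxiliary) pair, so both pieces can be unpacked directly; a small sketch:
grads,residual=sum_squared_error_dx_dy(x,y)
print(grads)     # gradient w.r.t. x: [2. 2. 2. 2.]
print(residual)  # the auxiliary intermediate x-y: [1. 1. 1. 1.]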
5. DeviceArray updates are indexed and functional: they return a new array instead of modifying in place
import jax.numpy as jnp
import numpy as np
# 1. Modifying a NumPy array in place
"""x=np.array([1,2,3])
def in_place_modify(x):
    x[0]=4
    return None
in_place_modify(x)
print(x)#[4 2 3]"""
# 2. Modifying a jnp array: the indexed update returns a new array; the old one is unaffected
y=jnp.array([1,2,3])
def jax_in_place_modify(x):
    return x.at[0].set(4)
print(jax_in_place_modify(y))  # [4 2 3]
print(y)  # [1 2 3]
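set is only one of the indexed-update operators; the .at property also supports operations such as add, each returning a new array. A brief sketch on the same y:
print(y.at[0].add(10))  # [11 2 3]; y itself is still unchanged
print(y)                # [1 2 3]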
Input is a 2-element array; the first layer has 3 neurons and the second layer has 1 neuron; the network outputs the XOR of the two inputs, as follows:
import itertools
import jax
import jax.numpy as jnp
import numpy as np
learning_rate=1
inputs=jnp.array([[0,0],[0,1],[1,0],[1,1]])
def sigmoid(x):
    return 1/(1+jnp.exp(-x))
def net(params,x):
    w1,b1,w2,b2=params
    hidden=jnp.tanh(jnp.dot(w1,x)+b1)
    return sigmoid(jnp.dot(w2,hidden)+b2)  # output a probability for the 0/1 classification
def loss(params,x,y):
    out=net(params,x)
    cross_entropy=-y*jnp.log(out)-(1-y)*jnp.log(1-out)
    return cross_entropy
def test_all_inputs(inputs,params):
    predictions=[int(net(params,inp)>0.5) for inp in inputs]
    for inp,out in zip(inputs,predictions):
        print(inp,'->',out)
    return (predictions==[np.bitwise_xor(*inp) for inp in inputs])  # compare the network's outputs against the true XOR of each input pair
# 1. jax.grad takes a function and returns a new function that computes the gradient of the original. By default the gradient is taken with respect to the first argument; this can be controlled with jax.grad's argnums parameter.
loss_grad=jax.grad(loss)
def initial_params():
    return [np.random.randn(3,2),np.random.randn(3),np.random.randn(3),np.random.randn()]
params=initial_params()  # initialize the parameters
for n in itertools.count():  # training loop
    x=inputs[np.random.choice(inputs.shape[0])]  # pick one of the four samples at random
    y=np.bitwise_xor(*x)  # XOR of the two input values
    grads=loss_grad(params,x,y)
    params=[param-learning_rate*grad for param,grad in zip(params,grads)]  # parameter update
    if not n%100:
        print('Iteration {}'.format(n))  # test once every 100 training steps
        if test_all_inputs(inputs,params):  # stop once every prediction is correct
            break
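An optional tweak, not in the original: since loss_grad is called repeatedly with inputs of fixed shape, compiling it once with jax.jit (in place of the definition above) speeds up the loop with no other changes:
loss_grad=jax.jit(jax.grad(loss))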
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
# 1. Data
xs=np.random.normal(size=(100,))
noise=np.random.normal(scale=0.1,size=(100,))
ys=xs*3-1+noise
plt.scatter(xs,ys)
# plt.show()
# 2. Model: \hat y(x; \theta) = wx + b
def model(theta,x):
    w,b=theta
    return w*x+b
def loss_fn(theta,x,y):
    prediction=model(theta,x)
    return jnp.mean((prediction-y)**2)  # squared error: J(x, y; \theta) = (\hat y - y)^2
# 3. Parameter update: \theta_{new} = \theta - 0.1 (\nabla_\theta J)(x, y; \theta)
def update(theta,x,y,lr=0.1):
    return theta-lr*jax.grad(loss_fn)(theta,x,y)
theta=jnp.array([1.,1.])
for _ in range(1000):
    theta=update(theta,xs,ys)
plt.plot(xs,model(theta,xs))
plt.show()
w,b=theta
print(f"w:{w:<.2f},b:{b:<.2f}") #w:2.99,b:-1.00
Next, train a multi-layer perceptron on MNIST with stax and jax.experimental.optimizers. First install the extra dependencies:
pip install tensorflow_datasets -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install tensorflow
import tensorflow as tf
import tensorflow_datasets as tfds
import jax
import jax.numpy as jnp
from jax import jit,grad,random
from jax.experimental import optimizers,stax
num_classes= 10
input_shape=(-1,28*28)
step_size=0.001  # learning rate
batch_size=128
momentum_mass=0.9
rng=random.PRNGKey(0)
# 1. Data
x_train=jnp.load(r"C:\Users\dz\Desktop\JAX\mnist_train_x.npy")
y_train=jnp.load(r"C:\Users\dz\Desktop\JAX\mnist_train_y.npy")
total_train_imgs=len(y_train)
def one_hot_nojit(x,k=10,dtype=jnp.float32):
    return jnp.array(x[:,None]==jnp.arange(k),dtype)  # one-hot encode integer labels
y_train=one_hot_nojit(y_train)
ds_train=tf.data.Dataset.from_tensor_slices((x_train,y_train)).shuffle(1024).batch(256).prefetch(tf.data.experimental.AUTOTUNE)
ds_train=tfds.as_numpy(ds_train)
# 2. Network
init_random_params,predict=stax.serial(
    stax.Dense(1024),stax.Relu,
    stax.Dense(1024),stax.Relu,
    stax.Dense(10),stax.LogSoftmax)
def pred_check(params,batch):
    inputs,targets=batch
    predict_result=predict(params,inputs)
    predicted_class=jnp.argmax(predict_result,axis=1)
    targets=jnp.argmax(targets,axis=1)
    return jnp.sum(predicted_class==targets)  # number of correct predictions in the batch
def loss(params,batch):
    inputs,targets=batch
    return jnp.mean(jnp.sum(-targets*predict(params,inputs),axis=1))  # cross entropy (predict returns log-probabilities)
opt_init,opt_update,get_params=optimizers.adam(step_size=2e-4)
_,init_params=init_random_params(rng,input_shape)
opt_state=opt_init(init_params)
def update(i,opt_state,batch):
    params=get_params(opt_state)
    return opt_update(i,grad(loss)(params,batch),opt_state)
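# optional sketch: jit (imported above but otherwise unused) can compile the
# update step; since batch shapes are fixed, the one-time compilation cost
# buys much faster training iterations
update=jit(update)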
# 3. Training
for _ in range(17):
    itercount=0
    for batch_raw in ds_train:
        data=batch_raw[0].reshape((-1,28*28))
        targets=batch_raw[1].reshape((-1,10))
        opt_state=update(itercount,opt_state,(data,targets))
        itercount+=1
params=get_params(opt_state)
train_acc=[]
correct_preds=0.0
for batch_raw in ds_train:
    data=batch_raw[0].reshape((-1,28*28))
    targets=batch_raw[1]
    correct_preds+=pred_check(params,(data,targets))
train_acc.append(correct_preds/float(total_train_imgs))
print(f"training set accuracy:{train_acc}")
A 3-class classification problem on 4 features, solved with a 2-layer perceptron.
from sklearn.datasets import load_iris
import jax.numpy as jnp
from jax import random,grad
import jax
# 1. Data
data=load_iris()
iris_data=jnp.float32(data.data)  # convert the features to float32
iris_target=jnp.float32(data.target)
iris_data=jax.random.shuffle(random.PRNGKey(17),iris_data)  # pseudo-randomly shuffle the data
iris_target=jax.random.shuffle(random.PRNGKey(17),iris_target)  # the same key gives the same permutation, so data and labels stay aligned
def one_hot_nojit(x,k=3,dtype=jnp.float32):
    return jnp.array(x[:,None]==jnp.arange(k),dtype)
iris_target=one_hot_nojit(iris_target)
# 2. Network structure
def Dense(dense_shape=[1,1]):
    rng=random.PRNGKey(17)
    weight=random.normal(rng,shape=dense_shape)
    bias=random.normal(rng,shape=(dense_shape[-1],))
    params=[weight,bias]  # default parameter structure; apply_fun can be called with explicit params instead
    def apply_fun(inputs,params=params):
        w,b=params
        return jnp.dot(inputs,w)+b  # affine transform of the inputs
    return apply_fun
def selu(x,alpha=1.67,lmbda=1.05):
    return lmbda*jnp.where(x>0,x,alpha*jnp.exp(x)-alpha)
def softmax(x,axis=-1):
    unnormalized=jnp.exp(x)
    return unnormalized/unnormalized.sum(axis,keepdims=True)
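# aside (not in the original): exp can overflow for large logits; a numerically
# stable softmax subtracts the per-row max first, e.g.:
# def softmax(x,axis=-1):
#     x=x-x.max(axis,keepdims=True)
#     unnormalized=jnp.exp(x)
#     return unnormalized/unnormalized.sum(axis,keepdims=True)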
def cross_entropy(y_true,y_pred):
    y_true=jnp.array(y_true)
    y_pred=jnp.array(y_pred)
    return -jnp.sum(y_true*jnp.log(y_pred+1e-7),axis=-1)
def mlp(x,params):
    a0,b0,a1,b1=params
    x=Dense()(x,[a0,b0])  # first layer: 4 -> 5
    x=jax.nn.selu(x)
    x=Dense()(x,[a1,b1])  # second layer: 5 -> 3
    x=softmax(x,axis=-1)
    return x
def loss_mlp(params,x,y):
    preds=mlp(x,params)
    loss_value=cross_entropy(y,preds)
    return jnp.mean(loss_value)
rng=random.PRNGKey(17)
a0=random.normal(rng,shape=(4,5))
b0=random.normal(rng,shape=(5,))
a1=random.normal(rng,shape=(5,3))
b1=random.normal(rng,shape=(3,))
params=[a0,b0,a1,b1]
learning_rate=2.17e-4
# 3. Training
for i in range(20000):
    loss=loss_mlp(params,iris_data,iris_target)
    if i%1000==0:
        predict_result=mlp(iris_data,params)
        predicted_class=jnp.argmax(predict_result,axis=1)
        _iris_target=jnp.argmax(iris_target,axis=1)
        accuracy=jnp.sum(predicted_class==_iris_target)/len(_iris_target)
        print("i:",i,"loss:",loss,"accuracy:",accuracy)
    params_grad=grad(loss_mlp)(params,iris_data,iris_target)
    params=[(p-g*learning_rate) for p,g in zip(params,params_grad)]  # gradient-descent update
predict_result=mlp(iris_data,params)
predicted_class=jnp.argmax(predict_result,axis=1)
iris_target=jnp.argmax(iris_target,axis=1)
accuracy=jnp.sum(predicted_class==iris_target)/len(iris_target)
print(accuracy)