深度学习之简单分类

深度学习之简单分类

简单二元分类

制造数据

from sklearn.model_selection import train_test_split
from sklearn import datasets
import matplotlib.pyplot as plt
from tensorflow import keras

# Generate 1000 points in two Gaussian clusters; y holds each point's
# cluster label. A fixed random_state keeps the data reproducible.
X, y = datasets.make_blobs(n_samples=1000, centers=2, random_state=8)

# Visualise the clusters: first feature vs. second feature, coloured by label.
first_feature, second_feature = X[:, 0], X[:, 1]
plt.scatter(first_feature, second_feature, c=y)
plt.show()

构建模型并训练

# Binary classifier: a 32-unit hidden layer followed by a single sigmoid
# unit whose output is the predicted probability of class 1.
# The hidden layer needs a non-linearity: without one, Dense(32) -> Dense(1)
# composes into a single linear map and the hidden layer adds nothing.
# ReLU also matches the multi-class model later in this tutorial.
model = keras.models.Sequential([
    keras.layers.Dense(32, input_shape=X.shape[1:], activation='relu'),
    keras.layers.Dense(1, activation=keras.activations.sigmoid),
])
model.summary()

# BinaryAccuracy thresholds the sigmoid output (default 0.5) before comparing
# it with the 0/1 labels. keras.metrics.Accuracy() instead tests predictions
# and labels for exact equality, so it only counts a sample as correct when
# the probability is exactly 0.0 or 1.0 -- a misleading score here.
# NOTE(review): learning_rate=0.1 is 100x RMSprop's default (0.001); this
# likely explains the fully saturated 0./1. predictions in the output below.
model.compile(
    loss=keras.losses.binary_crossentropy,
    optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
    metrics=[keras.metrics.BinaryAccuracy()],
)
# validation_split holds out the last 25% of (X, y) for validation.
model.fit(X, y, validation_split=0.25, epochs=20)

查看前 10 个样本的真实标签与预测结果（注意：这些样本属于训练部分，validation_split 取的是最后 25% 的数据）

# Show the true labels of the first ten samples, then the model's predicted
# probability of class 1 for the same samples. These samples belong to the
# training portion: validation_split takes the *last* 25% of the data.
first_ten = slice(0, 10)
print(y[first_ten])
y_pre = model.predict(X[first_ten])
print(y_pre)


[0 1 1 0 0 1 0 1 1 1]

[[0. 1. 1. 0. 0. 1. 0. 1. 1. 1.]]

多分类

制造数据

from sklearn.model_selection import train_test_split
from sklearn import datasets
import matplotlib.pyplot as plt
from tensorflow import keras

# Generate 1000 points in three Gaussian clusters; y holds each point's
# cluster label (0, 1 or 2). A fixed random_state keeps the data reproducible.
X, y = datasets.make_blobs(n_samples=1000, centers=3, random_state=8)

# Visualise the clusters: first feature vs. second feature, coloured by label.
first_feature, second_feature = X[:, 0], X[:, 1]
plt.scatter(first_feature, second_feature, c=y)
plt.show()

构建模型并训练

# Multi-class classifier: a 32-unit ReLU hidden layer followed by a 3-unit
# softmax layer that outputs one probability per cluster.
model = keras.models.Sequential()
model.add(keras.layers.Dense(32, activation='relu', input_shape=X.shape[1:]))
model.add(keras.layers.Dense(3, activation=keras.activations.softmax))
model.summary()

# The labels are integers rather than one-hot vectors, which is what sparse
# categorical cross-entropy expects.
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.sparse_categorical_crossentropy,
    metrics=['accuracy'],
)
# Hold out the last 25% of the samples for validation.
model.fit(X, y, epochs=20, validation_split=0.25)

查看前 10 个样本的真实标签与各类别的预测概率

# Inspect the first n_show samples: true labels vs. predicted probabilities.
import numpy as np

n_show = 10
print(y[:n_show])

# model.predict returns one row of class probabilities per sample (the output
# layer is a 3-unit softmax), i.e. an array that is already (n_show, 3), so
# it can be printed directly -- no reshape and no hard-coded shape needed.
y_pre = model.predict(X[:n_show])
print(y_pre)
# Predicted class = column with the highest probability; compare it with the
# true labels printed above.
print(np.argmax(y_pre, axis=1))


[1 2 1 1 1 2 1 2 1 1]

[[4.50088549e-03 9.95355964e-01 1.43211320e-04]
 [3.52771860e-03 2.17666663e-03 9.94295657e-01]
 [5.39137749e-04 9.99391794e-01 6.91057867e-05]
 [3.10646836e-03 9.93669093e-01 3.22450022e-03]
 [1.59081508e-04 9.99381661e-01 4.59307077e-04]
 [3.76076205e-04 2.09796475e-03 9.97525990e-01]
 [1.03477845e-02 9.88485038e-01 1.16714649e-03]
 [8.82121618e-04 1.39025709e-04 9.98978853e-01]
 [2.29390264e-02 9.75332916e-01 1.72808918e-03]
 [4.69710241e-04 9.99316335e-01 2.13949577e-04]]

你可能感兴趣的:(tensorflow)