蒙特卡洛dropout

链接是这个https://blog.csdn.net/weixin_26731327/article/details/109070481?utm_medium=distribute.pc_relevant.none-task-blog-title-2&spm=1001.2101.3001.4242
自我总结先写前面,我认为蒙特卡洛dropout首先肯定是测试的时候也开着dropout,然后就是测试n次的测试集,然后求n次的输出概率的平均值,得到不确定性,以此再取沿轴的最大值;
普通的softmax在测试时候的dropout是关闭着的。就是直接一个测试就一个输出,然后直接取沿轴最大值当作输出、

以下就是常规的分类代码,测试时,dropout在普通情况下是关闭的,但是在蒙特卡咯情况下是开启的

(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()
 
 
model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=(28, 28)))
model.add(keras.layers.Dropout(0.25))
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dropout(0.25))
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dropout(0.25))
model.add(keras.layers.Dense(10, activation="softmax"))
 
 
optimizer = keras.optimizers.Nadam(lr=0.001)
model.compile(loss="sparse_categorical_crossentropy", 
              optimizer=optimizer, metrics=["accuracy"])
model.fit(X_train, y_train, epochs=50)
model.evaluate(X_test, y_test)

模型准确性的计算:

可以生成任意数量的预测,就是说可以预测任意多次

def predict_proba(X, model, num_samples):
    preds = [model(X, training=True) for _ in range(num_samples)]
    return np.stack(preds).mean(axis=0)
     
def predict_class(X, model, num_samples):
    proba_preds = predict_proba(X, model, num_samples)
    return np.argmax(proba_preds, axis=1)
y_pred = predict_class(X_test, model, 100)
acc = np.mean(y_pred == y_test)

由以上代码可以看出,先弄num_samples次预测,然后取平均值,然后再沿着某一轴取最大值,即可得到比原来好的预测效果。

预测不确定性

y_pred_proba = predict_proba(X_test, model, 100)
 
 
softmax_output = np.round(model.predict(X_test[1:2]), 3)
mc_pred_proba = np.round(y_pred_proba[1], 3)
print(softmax_output, mc_pred_proba)
softmax_output: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
mc_pred_proba: [0. 0. 0.989 0.008 0.001 0. 0. 0.001 0.001 0. ]

softmax_output: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] 
mc_pred_proba: [0. 0. 0.989 0.008 0.001 0. 0. 0.001 0.001 0. ] [0. 0. 0.989 0.008 0.001 0. 0. 0.001 0.001 0. ]

你可能感兴趣的:(tensorflow杂记,python,深度学习,神经网络)