#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This tutorial demonstrates how to classify a highly imbalanced dataset in which the number of examples in one class greatly outnumbers the examples in another. You will work with the Credit Card Fraud Detection dataset hosted on Kaggle. The aim is to detect a mere 492 fraudulent transactions from 284,807 transactions in total. You will use Keras to define the model and class weights to help the model learn from the imbalanced data.
This tutorial contains complete code to:
- Load a CSV file using Pandas.
- Create train, validation, and test sets.
- Define and train a model using Keras (including setting class weights).
- Evaluate the model using various metrics (including precision and recall).
- Try common techniques for dealing with imbalanced data, such as class weighting and oversampling.
import tensorflow as tf
from tensorflow import keras
import os
import tempfile
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
Pandas is a Python library with many helpful utilities for loading and working with structured data and can be used to download CSVs into a dataframe.
Note: This dataset has been collected and analysed during a research collaboration of Worldline and the Machine Learning Group of ULB (Université Libre de Bruxelles) on big data mining and fraud detection. More details on current and past projects on related topics are available here and the page of the DefeatFraud project.
raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
  | Time | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | ... | V21 | V22 | V23 | V24 | V25 | V26 | V27 | V28 | Amount | Class
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.0 | -1.359807 | -0.072781 | 2.536347 | 1.378155 | -0.338321 | 0.462388 | 0.239599 | 0.098698 | 0.363787 | ... | -0.018307 | 0.277838 | -0.110474 | 0.066928 | 0.128539 | -0.189115 | 0.133558 | -0.021053 | 149.62 | 0 |
1 | 0.0 | 1.191857 | 0.266151 | 0.166480 | 0.448154 | 0.060018 | -0.082361 | -0.078803 | 0.085102 | -0.255425 | ... | -0.225775 | -0.638672 | 0.101288 | -0.339846 | 0.167170 | 0.125895 | -0.008983 | 0.014724 | 2.69 | 0 |
2 | 1.0 | -1.358354 | -1.340163 | 1.773209 | 0.379780 | -0.503198 | 1.800499 | 0.791461 | 0.247676 | -1.514654 | ... | 0.247998 | 0.771679 | 0.909412 | -0.689281 | -0.327642 | -0.139097 | -0.055353 | -0.059752 | 378.66 | 0 |
3 | 1.0 | -0.966272 | -0.185226 | 1.792993 | -0.863291 | -0.010309 | 1.247203 | 0.237609 | 0.377436 | -1.387024 | ... | -0.108300 | 0.005274 | -0.190321 | -1.175575 | 0.647376 | -0.221929 | 0.062723 | 0.061458 | 123.50 | 0 |
4 | 2.0 | -1.158233 | 0.877737 | 1.548718 | 0.403034 | -0.407193 | 0.095921 | 0.592941 | -0.270533 | 0.817739 | ... | -0.009431 | 0.798278 | -0.137458 | 0.141267 | -0.206010 | 0.502292 | 0.219422 | 0.215153 | 69.99 | 0 |
5 rows × 31 columns
raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()
  | Time | V1 | V2 | V3 | V4 | V5 | V26 | V27 | V28 | Amount | Class
---|---|---|---|---|---|---|---|---|---|---|---|
count | 284807.000000 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 284807.000000 | 284807.000000 |
mean | 94813.859575 | 1.165980e-15 | 3.416908e-16 | -1.373150e-15 | 2.086869e-15 | 9.604066e-16 | 1.687098e-15 | -3.666453e-16 | -1.220404e-16 | 88.349619 | 0.001727 |
std | 47488.145955 | 1.958696e+00 | 1.651309e+00 | 1.516255e+00 | 1.415869e+00 | 1.380247e+00 | 4.822270e-01 | 4.036325e-01 | 3.300833e-01 | 250.120109 | 0.041527 |
min | 0.000000 | -5.640751e+01 | -7.271573e+01 | -4.832559e+01 | -5.683171e+00 | -1.137433e+02 | -2.604551e+00 | -2.256568e+01 | -1.543008e+01 | 0.000000 | 0.000000 |
25% | 54201.500000 | -9.203734e-01 | -5.985499e-01 | -8.903648e-01 | -8.486401e-01 | -6.915971e-01 | -3.269839e-01 | -7.083953e-02 | -5.295979e-02 | 5.600000 | 0.000000 |
50% | 84692.000000 | 1.810880e-02 | 6.548556e-02 | 1.798463e-01 | -1.984653e-02 | -5.433583e-02 | -5.213911e-02 | 1.342146e-03 | 1.124383e-02 | 22.000000 | 0.000000 |
75% | 139320.500000 | 1.315642e+00 | 8.037239e-01 | 1.027196e+00 | 7.433413e-01 | 6.119264e-01 | 2.409522e-01 | 9.104512e-02 | 7.827995e-02 | 77.165000 | 0.000000 |
max | 172792.000000 | 2.454930e+00 | 2.205773e+01 | 9.382558e+00 | 1.687534e+01 | 3.480167e+01 | 3.517346e+00 | 3.161220e+01 | 3.384781e+01 | 25691.160000 | 1.000000 |
Let’s look at the dataset imbalance:
neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n Total: {}\n Positive: {} ({:.2f}% of total)\n'.format(
total, pos, 100 * pos / total))
Examples:
Total: 284807
Positive: 492 (0.17% of total)
This shows the small fraction of positive samples.
The raw data has a few issues. First, the `Time` and `Amount` columns are too variable to use directly. Drop the `Time` column (since it's not clear what it means) and take the log of the `Amount` column to reduce its range.
cleaned_df = raw_df.copy()
# You don't want the `Time` column.
cleaned_df.pop('Time')
# The `Amount` column covers a huge range. Convert to log-space.
eps=0.001 # 0 => 0.1¢
cleaned_df['Log Amount'] = np.log(cleaned_df.pop('Amount')+eps)
Split the dataset into train, validation, and test sets. The validation set is used during the model fitting to evaluate the loss and any metrics; however, the model is not fit with this data. The test set is completely unused during the training phase and is only used at the end to evaluate how well the model generalizes to new data. This is especially important with imbalanced datasets, where overfitting is a significant concern due to the lack of training data.
# Use a utility from sklearn to split and shuffle our dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)
# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))
train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)
Normalize the input features using the sklearn StandardScaler.
This will set the mean to 0 and standard deviation to 1.
Note: The `StandardScaler` is only fit using the `train_features` to be sure the model is not peeking at the validation or test sets.
scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)
val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)
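# Clip the standardized features to [-5, 5] to limit the influence of extreme outliers.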
train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)
print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)
print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)
Caution: If you want to deploy a model, it's critical that you preserve the preprocessing calculations. The easiest way is to implement them as layers and attach them to your model before export.
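For example, here is a minimal sketch (not part of the original notebook) of baking the standardization into the exported model with a `tf.keras.layers.Normalization` layer. It assumes a recent TF 2.x release where that layer is available, a trained `model`, and a copy of the unscaled training features in a hypothetical `raw_train_features` array; it only covers the scaling step, not the log-transform or clipping.
# Sketch only: attach normalization as a layer so the exported model scales its own inputs.
normalizer = tf.keras.layers.Normalization(axis=-1)
normalizer.adapt(raw_train_features)  # learn per-feature mean and variance
inputs = keras.Input(shape=(raw_train_features.shape[-1],))
outputs = model(normalizer(inputs))   # scale, then classify
export_model = keras.Model(inputs, outputs)
export_model.save('fraud_model')      # the preprocessing now ships with the model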
Next, compare the distributions of the positive and negative examples over a few features. Good questions to ask yourself at this point are: Do these distributions make sense? Yes, you've normalized the input, and the values are mostly concentrated in the `+/- 2` range. Can you see the difference between the two distributions?
pos_df = pd.DataFrame(train_features[bool_train_labels], columns=train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train_df.columns)
sns.jointplot(pos_df['V5'], pos_df['V6'],
kind='hex', xlim = (-5,5), ylim = (-5,5))
plt.suptitle("Positive distribution")
sns.jointplot(neg_df['V5'], neg_df['V6'],
kind='hex', xlim = (-5,5), ylim = (-5,5))
_ = plt.suptitle("Negative distribution")
Define a function that creates a simple neural network with a densely connected hidden layer, a dropout layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent:
METRICS = [
keras.metrics.TruePositives(name='tp'),
keras.metrics.FalsePositives(name='fp'),
keras.metrics.TrueNegatives(name='tn'),
keras.metrics.FalseNegatives(name='fn'),
keras.metrics.BinaryAccuracy(name='accuracy'),
keras.metrics.Precision(name='precision'),
keras.metrics.Recall(name='recall'),
keras.metrics.AUC(name='auc'),
]
def make_model(metrics=METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(lr=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model
Notice that a few of the metrics defined above can be computed by the model and will be helpful when evaluating its performance:
Accuracy is the fraction of correctly classified examples: $\frac{\text{true samples}}{\text{total samples}}$
Precision is the fraction of predicted positives that were correctly classified: $\frac{\text{true positives}}{\text{true positives + false positives}}$
Recall is the fraction of actual positives that were correctly classified: $\frac{\text{true positives}}{\text{true positives + false negatives}}$
Note: Accuracy is not a helpful metric for this task. You can get 99.8%+ accuracy on this task by predicting False all the time.
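To see that concretely, a quick check (a sketch using the counts computed earlier):
# An always-"not fraud" classifier gets every negative right and every positive wrong.
print('Accuracy of always predicting False: {:.4f}'.format(neg / total))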
Now create and train your model using the function that was defined earlier. Notice that the model is fit using a larger-than-default batch size of 2048; this is important to ensure that each batch has a decent chance of containing a few positive samples. If the batch size were too small, many batches would likely have no fraudulent transactions to learn from.
Note: This model will not handle the class imbalance well. You will improve it later in this tutorial.
EPOCHS = 100
BATCH_SIZE = 2048
early_stopping = tf.keras.callbacks.EarlyStopping(
monitor='val_auc',
verbose=1,
patience=10,
mode='max',
restore_best_weights=True)
model = make_model()
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 16) 480
_________________________________________________________________
dropout (Dropout) (None, 16) 0
_________________________________________________________________
dense_1 (Dense) (None, 1) 17
=================================================================
Total params: 497
Trainable params: 497
Non-trainable params: 0
_________________________________________________________________
Test run the model:
model.predict(train_features[:10])
array([[0.9230855 ],
[0.3435619 ],
[0.7657857 ],
[0.78466564],
[0.44818196],
[0.33391467],
[0.34766468],
[0.9538829 ],
[0.5203672 ],
[0.67297024]], dtype=float32)
These initial guesses are not great. You know the dataset is imbalanced, so set the output layer's bias to reflect that (see: A Recipe for Training Neural Networks: "init well"). This can help with initial convergence.
With the default bias initialization the loss should be about math.log(2) = 0.69314
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 1.1084
The correct bias to set can be derived from:
$p_0 = pos/(pos + neg) = 1/(1+e^{-b_0})$
$b_0 = -\log_e(1/p_0 - 1)$
$b_0 = \log_e(pos/neg)$
initial_bias = np.log([pos/neg])
initial_bias
array([-6.35935934])
Set that as the initial bias, and the model will give much more reasonable initial guesses.
It should be near: pos/total = 0.0018
model = make_model(output_bias = initial_bias)
model.predict(train_features[:10])
array([[0.00191665],
[0.00075228],
[0.00076816],
[0.00162692],
[0.00106873],
[0.00045646],
[0.00063379],
[0.00207557],
[0.00037811],
[0.00023752]], dtype=float32)
With this initialization the initial loss should be approximately:
$-p_0 \log(p_0) - (1-p_0) \log(1-p_0) = 0.01317$
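A quick numeric check of that formula, as a sketch using the `pos` and `total` counts from earlier:
# Evaluate the expected initial binary cross-entropy at the tuned bias.
p_0 = pos / total
expected_loss = -p_0 * np.log(p_0) - (1 - p_0) * np.log(1 - p_0)
print('Expected initial loss: {:.5f}'.format(expected_loss))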
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.0139
This initial loss is about 50 times less than it would have been with the naive initialization.
This way the model doesn’t need to spend the first few epochs just learning that positive examples are unlikely. This also makes it easier to read plots of the loss during training.
To make the various training runs more comparable, keep this initial model’s weights in a checkpoint file, and load them into each model before training.
initial_weights = os.path.join(tempfile.mkdtemp(),'initial_weights')
model.save_weights(initial_weights)
Before moving on, quickly confirm that the careful bias initialization actually helped.
Train the model for 20 epochs, with and without this careful initialization, and compare the losses:
model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
train_features,
train_labels,
batch_size=BATCH_SIZE,
epochs=20,
validation_data=(val_features, val_labels),
verbose=0)
model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
train_features,
train_labels,
batch_size=BATCH_SIZE,
epochs=20,
validation_data=(val_features, val_labels),
verbose=0)
def plot_loss(history, label, n):
  # Use a log scale to show the wide range of values.
  plt.semilogy(history.epoch, history.history['loss'],
               color=colors[n], label='Train ' + label)
  plt.semilogy(history.epoch, history.history['val_loss'],
               color=colors[n], label='Val ' + label,
               linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  plt.legend()
plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)
The above figure makes it clear: In terms of validation loss, on this problem, this careful initialization gives a clear advantage.
model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
train_features,
train_labels,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
callbacks = [early_stopping],
validation_data=(val_features, val_labels))
Train on 182276 samples, validate on 45569 samples
Epoch 1/100
182276/182276 [==============================] - 2s 13us/sample - loss: 0.0125 - tp: 57.0000 - fp: 145.0000 - tn: 181801.0000 - fn: 273.0000 - accuracy: 0.9977 - precision: 0.2822 - recall: 0.1727 - auc: 0.7707 - val_loss: 0.0056 - val_tp: 26.0000 - val_fp: 6.0000 - val_tn: 45486.0000 - val_fn: 51.0000 - val_accuracy: 0.9987 - val_precision: 0.8125 - val_recall: 0.3377 - val_auc: 0.9020
Epoch 2/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0075 - tp: 125.0000 - fp: 51.0000 - tn: 181895.0000 - fn: 205.0000 - accuracy: 0.9986 - precision: 0.7102 - recall: 0.3788 - auc: 0.8794 - val_loss: 0.0044 - val_tp: 44.0000 - val_fp: 7.0000 - val_tn: 45485.0000 - val_fn: 33.0000 - val_accuracy: 0.9991 - val_precision: 0.8627 - val_recall: 0.5714 - val_auc: 0.9088
Epoch 3/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0066 - tp: 161.0000 - fp: 41.0000 - tn: 181905.0000 - fn: 169.0000 - accuracy: 0.9988 - precision: 0.7970 - recall: 0.4879 - auc: 0.8883 - val_loss: 0.0041 - val_tp: 45.0000 - val_fp: 7.0000 - val_tn: 45485.0000 - val_fn: 32.0000 - val_accuracy: 0.9991 - val_precision: 0.8654 - val_recall: 0.5844 - val_auc: 0.9154
Epoch 4/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0060 - tp: 172.0000 - fp: 36.0000 - tn: 181910.0000 - fn: 158.0000 - accuracy: 0.9989 - precision: 0.8269 - recall: 0.5212 - auc: 0.8897 - val_loss: 0.0039 - val_tp: 48.0000 - val_fp: 7.0000 - val_tn: 45485.0000 - val_fn: 29.0000 - val_accuracy: 0.9992 - val_precision: 0.8727 - val_recall: 0.6234 - val_auc: 0.9348
Epoch 5/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0055 - tp: 179.0000 - fp: 29.0000 - tn: 181917.0000 - fn: 151.0000 - accuracy: 0.9990 - precision: 0.8606 - recall: 0.5424 - auc: 0.9025 - val_loss: 0.0037 - val_tp: 51.0000 - val_fp: 7.0000 - val_tn: 45485.0000 - val_fn: 26.0000 - val_accuracy: 0.9993 - val_precision: 0.8793 - val_recall: 0.6623 - val_auc: 0.9348
Epoch 6/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0053 - tp: 181.0000 - fp: 34.0000 - tn: 181912.0000 - fn: 149.0000 - accuracy: 0.9990 - precision: 0.8419 - recall: 0.5485 - auc: 0.9074 - val_loss: 0.0036 - val_tp: 51.0000 - val_fp: 7.0000 - val_tn: 45485.0000 - val_fn: 26.0000 - val_accuracy: 0.9993 - val_precision: 0.8793 - val_recall: 0.6623 - val_auc: 0.9348
Epoch 7/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0052 - tp: 188.0000 - fp: 31.0000 - tn: 181915.0000 - fn: 142.0000 - accuracy: 0.9991 - precision: 0.8584 - recall: 0.5697 - auc: 0.9137 - val_loss: 0.0035 - val_tp: 54.0000 - val_fp: 8.0000 - val_tn: 45484.0000 - val_fn: 23.0000 - val_accuracy: 0.9993 - val_precision: 0.8710 - val_recall: 0.7013 - val_auc: 0.9348
Epoch 8/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0048 - tp: 200.0000 - fp: 35.0000 - tn: 181911.0000 - fn: 130.0000 - accuracy: 0.9991 - precision: 0.8511 - recall: 0.6061 - auc: 0.9153 - val_loss: 0.0034 - val_tp: 55.0000 - val_fp: 8.0000 - val_tn: 45484.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8730 - val_recall: 0.7143 - val_auc: 0.9348
Epoch 9/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0048 - tp: 209.0000 - fp: 35.0000 - tn: 181911.0000 - fn: 121.0000 - accuracy: 0.9991 - precision: 0.8566 - recall: 0.6333 - auc: 0.9171 - val_loss: 0.0034 - val_tp: 45.0000 - val_fp: 6.0000 - val_tn: 45486.0000 - val_fn: 32.0000 - val_accuracy: 0.9992 - val_precision: 0.8824 - val_recall: 0.5844 - val_auc: 0.9349
Epoch 10/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0047 - tp: 189.0000 - fp: 26.0000 - tn: 181920.0000 - fn: 141.0000 - accuracy: 0.9991 - precision: 0.8791 - recall: 0.5727 - auc: 0.9049 - val_loss: 0.0033 - val_tp: 56.0000 - val_fp: 8.0000 - val_tn: 45484.0000 - val_fn: 21.0000 - val_accuracy: 0.9994 - val_precision: 0.8750 - val_recall: 0.7273 - val_auc: 0.9348
Epoch 11/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0046 - tp: 199.0000 - fp: 33.0000 - tn: 181913.0000 - fn: 131.0000 - accuracy: 0.9991 - precision: 0.8578 - recall: 0.6030 - auc: 0.9276 - val_loss: 0.0033 - val_tp: 52.0000 - val_fp: 7.0000 - val_tn: 45485.0000 - val_fn: 25.0000 - val_accuracy: 0.9993 - val_precision: 0.8814 - val_recall: 0.6753 - val_auc: 0.9348
Epoch 12/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0046 - tp: 203.0000 - fp: 30.0000 - tn: 181916.0000 - fn: 127.0000 - accuracy: 0.9991 - precision: 0.8712 - recall: 0.6152 - auc: 0.9231 - val_loss: 0.0032 - val_tp: 55.0000 - val_fp: 8.0000 - val_tn: 45484.0000 - val_fn: 22.0000 - val_accuracy: 0.9993 - val_precision: 0.8730 - val_recall: 0.7143 - val_auc: 0.9348
Epoch 13/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0044 - tp: 199.0000 - fp: 29.0000 - tn: 181917.0000 - fn: 131.0000 - accuracy: 0.9991 - precision: 0.8728 - recall: 0.6030 - auc: 0.9171 - val_loss: 0.0032 - val_tp: 58.0000 - val_fp: 9.0000 - val_tn: 45483.0000 - val_fn: 19.0000 - val_accuracy: 0.9994 - val_precision: 0.8657 - val_recall: 0.7532 - val_auc: 0.9348
Epoch 14/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0044 - tp: 204.0000 - fp: 36.0000 - tn: 181910.0000 - fn: 126.0000 - accuracy: 0.9991 - precision: 0.8500 - recall: 0.6182 - auc: 0.9186 - val_loss: 0.0032 - val_tp: 60.0000 - val_fp: 9.0000 - val_tn: 45483.0000 - val_fn: 17.0000 - val_accuracy: 0.9994 - val_precision: 0.8696 - val_recall: 0.7792 - val_auc: 0.9348
Epoch 15/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0045 - tp: 197.0000 - fp: 34.0000 - tn: 181912.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8528 - recall: 0.5970 - auc: 0.9126 - val_loss: 0.0032 - val_tp: 59.0000 - val_fp: 9.0000 - val_tn: 45483.0000 - val_fn: 18.0000 - val_accuracy: 0.9994 - val_precision: 0.8676 - val_recall: 0.7662 - val_auc: 0.9348
Epoch 16/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0044 - tp: 199.0000 - fp: 36.0000 - tn: 181910.0000 - fn: 131.0000 - accuracy: 0.9991 - precision: 0.8468 - recall: 0.6030 - auc: 0.9218 - val_loss: 0.0031 - val_tp: 61.0000 - val_fp: 9.0000 - val_tn: 45483.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8714 - val_recall: 0.7922 - val_auc: 0.9348
Epoch 17/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0042 - tp: 210.0000 - fp: 33.0000 - tn: 181913.0000 - fn: 120.0000 - accuracy: 0.9992 - precision: 0.8642 - recall: 0.6364 - auc: 0.9127 - val_loss: 0.0031 - val_tp: 60.0000 - val_fp: 9.0000 - val_tn: 45483.0000 - val_fn: 17.0000 - val_accuracy: 0.9994 - val_precision: 0.8696 - val_recall: 0.7792 - val_auc: 0.9348
Epoch 18/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0041 - tp: 203.0000 - fp: 27.0000 - tn: 181919.0000 - fn: 127.0000 - accuracy: 0.9992 - precision: 0.8826 - recall: 0.6152 - auc: 0.9279 - val_loss: 0.0031 - val_tp: 61.0000 - val_fp: 9.0000 - val_tn: 45483.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.8714 - val_recall: 0.7922 - val_auc: 0.9348
Epoch 19/100
161792/182276 [=========================>....] - ETA: 0s - loss: 0.0043 - tp: 183.0000 - fp: 31.0000 - tn: 161475.0000 - fn: 103.0000 - accuracy: 0.9992 - precision: 0.8551 - recall: 0.6399 - auc: 0.9238Restoring model weights from the end of the best epoch.
182276/182276 [==============================] - 1s 3us/sample - loss: 0.0043 - tp: 213.0000 - fp: 34.0000 - tn: 181912.0000 - fn: 117.0000 - accuracy: 0.9992 - precision: 0.8623 - recall: 0.6455 - auc: 0.9248 - val_loss: 0.0032 - val_tp: 53.0000 - val_fp: 7.0000 - val_tn: 45485.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.8833 - val_recall: 0.6883 - val_auc: 0.9348
Epoch 00019: early stopping
In this section, you will produce plots of your model's accuracy and loss on the training and validation sets. These are useful to check for overfitting, which you can learn more about in the Overfit and underfit tutorial.
Additionally, you can produce these plots for any of the metrics you created above; false negatives are shown as an example in the sketch after the plotting code below.
def plot_metrics(history):
  metrics = ['loss', 'auc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_", " ").capitalize()
    plt.subplot(2, 2, n+1)
    plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8, 1])
    else:
      plt.ylim([0, 1])
    plt.legend()
plot_metrics(baseline_history)
Note that the validation curve generally performs better than the training curve. This is mainly caused by the fact that the dropout layer is not active when evaluating the model.
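As an example of a metric beyond the four plotted by `plot_metrics`, here is a small sketch that plots false negatives per epoch, using the `fn` metric name defined earlier:
# Sketch: false negatives per epoch for the baseline run.
plt.figure()
plt.plot(baseline_history.epoch, baseline_history.history['fn'],
         color=colors[0], label='Train')
plt.plot(baseline_history.epoch, baseline_history.history['val_fn'],
         color=colors[0], linestyle='--', label='Val')
plt.xlabel('Epoch')
plt.ylabel('False negatives')
plt.legend()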
You can use a confusion matrix to summarize the actual vs. predicted labels where the X axis is the predicted label and the Y axis is the actual label.
train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
def plot_cm(labels, predictions, p=0.5):
  cm = confusion_matrix(labels, predictions > p)
  plt.figure(figsize=(5, 5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(p))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))
Evaluate your model on the test dataset and display the results for the metrics you created above.
baseline_results = model.evaluate(test_features, test_labels,
batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()
plot_cm(test_labels, test_predictions_baseline)
loss : 0.002795267753867907
tp : 62.0
fp : 9.0
tn : 56868.0
fn : 23.0
accuracy : 0.9994382
precision : 0.87323946
recall : 0.7294118
auc : 0.92923415
Legitimate Transactions Detected (True Negatives): 56868
Legitimate Transactions Incorrectly Detected (False Positives): 9
Fraudulent Transactions Missed (False Negatives): 23
Fraudulent Transactions Detected (True Positives): 62
Total Fraudulent Transactions: 85
If the model had predicted everything perfectly, this would be a diagonal matrix where values off the main diagonal, indicating incorrect predictions, would be zero. In this case the matrix shows that you have relatively few false positives, meaning that there were relatively few legitimate transactions that were incorrectly flagged. However, you would likely want to have even fewer false negatives despite the cost of increasing the number of false positives. This trade off may be preferable because false negatives would allow fraudulent transactions to go through, whereas false positives may cause an email to be sent to a customer to ask them to verify their card activity.
Now plot the ROC. This plot is useful because it shows, at a glance, the range of performance the model can reach just by tuning the output threshold.
def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5, 20])
  plt.ylim([80, 100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')
It looks like the precision is relatively high, but the recall and the area under the ROC curve (AUC) aren’t as high as you might like. Classifiers often face challenges when trying to maximize both precision and recall, which is especially true when working with imbalanced datasets. It is important to consider the costs of different types of errors in the context of the problem you care about. In this example, a false negative (a fraudulent transaction is missed) may have a financial cost, while a false positive (a transaction is incorrectly flagged as fraudulent) may decrease user happiness.
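The original walkthrough stops at the ROC here, but on heavily imbalanced data a precision-recall curve is often just as informative. A minimal sketch (an addition, not part of the original notebook) using scikit-learn's `precision_recall_curve`:
# Sketch: precision-recall curves for the baseline model.
def plot_prc(name, labels, predictions, **kwargs):
  precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, predictions)
  plt.plot(recall, precision, label=name, linewidth=2, **kwargs)
  plt.xlabel('Recall')
  plt.ylabel('Precision')
  plt.grid(True)

plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')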
The goal is to identify fraudulent transactions, but you don't have very many of those positive samples to work with, so you would want the classifier to heavily weight the few examples that are available. You can do this by passing Keras weights for each class through the `class_weight` parameter. These will cause the model to "pay more attention" to examples from an under-represented class.
# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg)*(total)/2.0
weight_for_1 = (1 / pos)*(total)/2.0
class_weight = {0: weight_for_0, 1: weight_for_1}
print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.50
Weight for class 1: 289.44
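As an optional cross-check (not in the original), scikit-learn's "balanced" heuristic computes the same total / (2 × per-class count) weights. Note that it is applied to the training labels here, so the numbers will differ slightly from the whole-dataset weights above:
# Sketch: the same weighting rule via scikit-learn, computed on the training split.
from sklearn.utils.class_weight import compute_class_weight
balanced = compute_class_weight(class_weight='balanced',
                                classes=np.array([0, 1]),
                                y=train_labels)
print('sklearn balanced weights:', dict(zip([0, 1], balanced)))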
Now try re-training and evaluating the model with class weights to see how that affects the predictions.
Note: Using `class_weight` changes the range of the loss, which may affect the stability of training depending on the optimizer. Optimizers whose step size depends on the magnitude of the gradient, like `optimizers.SGD`, may fail. The optimizer used here, `optimizers.Adam`, is unaffected by the scaling change. Also note that because of the weighting, the total losses are not comparable between the two models.
weighted_model = make_model()
weighted_model.load_weights(initial_weights)
weighted_history = weighted_model.fit(
train_features,
train_labels,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
callbacks = [early_stopping],
validation_data=(val_features, val_labels),
# The class weights go here
class_weight=class_weight)
WARNING:tensorflow:sample_weight modes were coerced from
...
to
['...']
WARNING:tensorflow:sample_weight modes were coerced from
...
to
['...']
Train on 182276 samples, validate on 45569 samples
Epoch 1/100
182276/182276 [==============================] - 3s 15us/sample - loss: 1.7348 - tp: 97.0000 - fp: 420.0000 - tn: 181526.0000 - fn: 233.0000 - accuracy: 0.9964 - precision: 0.1876 - recall: 0.2939 - auc: 0.8278 - val_loss: 0.6994 - val_tp: 46.0000 - val_fp: 42.0000 - val_tn: 45450.0000 - val_fn: 31.0000 - val_accuracy: 0.9984 - val_precision: 0.5227 - val_recall: 0.5974 - val_auc: 0.9318
Epoch 2/100
182276/182276 [==============================] - 1s 4us/sample - loss: 0.8274 - tp: 218.0000 - fp: 729.0000 - tn: 181217.0000 - fn: 112.0000 - accuracy: 0.9954 - precision: 0.2302 - recall: 0.6606 - auc: 0.8966 - val_loss: 0.4231 - val_tp: 61.0000 - val_fp: 60.0000 - val_tn: 45432.0000 - val_fn: 16.0000 - val_accuracy: 0.9983 - val_precision: 0.5041 - val_recall: 0.7922 - val_auc: 0.9535
Epoch 3/100
182276/182276 [==============================] - 1s 4us/sample - loss: 0.6246 - tp: 237.0000 - fp: 1159.0000 - tn: 180787.0000 - fn: 93.0000 - accuracy: 0.9931 - precision: 0.1698 - recall: 0.7182 - auc: 0.9212 - val_loss: 0.3129 - val_tp: 64.0000 - val_fp: 100.0000 - val_tn: 45392.0000 - val_fn: 13.0000 - val_accuracy: 0.9975 - val_precision: 0.3902 - val_recall: 0.8312 - val_auc: 0.9660
Epoch 4/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.4203 - tp: 258.0000 - fp: 1846.0000 - tn: 180100.0000 - fn: 72.0000 - accuracy: 0.9895 - precision: 0.1226 - recall: 0.7818 - auc: 0.9566 - val_loss: 0.2543 - val_tp: 67.0000 - val_fp: 216.0000 - val_tn: 45276.0000 - val_fn: 10.0000 - val_accuracy: 0.9950 - val_precision: 0.2367 - val_recall: 0.8701 - val_auc: 0.9741
Epoch 5/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.3672 - tp: 268.0000 - fp: 2442.0000 - tn: 179504.0000 - fn: 62.0000 - accuracy: 0.9863 - precision: 0.0989 - recall: 0.8121 - auc: 0.9548 - val_loss: 0.2238 - val_tp: 68.0000 - val_fp: 421.0000 - val_tn: 45071.0000 - val_fn: 9.0000 - val_accuracy: 0.9906 - val_precision: 0.1391 - val_recall: 0.8831 - val_auc: 0.9757
Epoch 6/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.3624 - tp: 277.0000 - fp: 3336.0000 - tn: 178610.0000 - fn: 53.0000 - accuracy: 0.9814 - precision: 0.0767 - recall: 0.8394 - auc: 0.9513 - val_loss: 0.2056 - val_tp: 68.0000 - val_fp: 528.0000 - val_tn: 44964.0000 - val_fn: 9.0000 - val_accuracy: 0.9882 - val_precision: 0.1141 - val_recall: 0.8831 - val_auc: 0.9800
Epoch 7/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2907 - tp: 282.0000 - fp: 3990.0000 - tn: 177956.0000 - fn: 48.0000 - accuracy: 0.9778 - precision: 0.0660 - recall: 0.8545 - auc: 0.9673 - val_loss: 0.1882 - val_tp: 69.0000 - val_fp: 614.0000 - val_tn: 44878.0000 - val_fn: 8.0000 - val_accuracy: 0.9864 - val_precision: 0.1010 - val_recall: 0.8961 - val_auc: 0.9810
Epoch 8/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.3237 - tp: 281.0000 - fp: 4760.0000 - tn: 177186.0000 - fn: 49.0000 - accuracy: 0.9736 - precision: 0.0557 - recall: 0.8515 - auc: 0.9584 - val_loss: 0.1769 - val_tp: 69.0000 - val_fp: 702.0000 - val_tn: 44790.0000 - val_fn: 8.0000 - val_accuracy: 0.9844 - val_precision: 0.0895 - val_recall: 0.8961 - val_auc: 0.9828
Epoch 9/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.3026 - tp: 286.0000 - fp: 5234.0000 - tn: 176712.0000 - fn: 44.0000 - accuracy: 0.9710 - precision: 0.0518 - recall: 0.8667 - auc: 0.9572 - val_loss: 0.1708 - val_tp: 69.0000 - val_fp: 757.0000 - val_tn: 44735.0000 - val_fn: 8.0000 - val_accuracy: 0.9832 - val_precision: 0.0835 - val_recall: 0.8961 - val_auc: 0.9827
Epoch 10/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2378 - tp: 296.0000 - fp: 5624.0000 - tn: 176322.0000 - fn: 34.0000 - accuracy: 0.9690 - precision: 0.0500 - recall: 0.8970 - auc: 0.9671 - val_loss: 0.1646 - val_tp: 69.0000 - val_fp: 794.0000 - val_tn: 44698.0000 - val_fn: 8.0000 - val_accuracy: 0.9824 - val_precision: 0.0800 - val_recall: 0.8961 - val_auc: 0.9855
Epoch 11/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2690 - tp: 289.0000 - fp: 5909.0000 - tn: 176037.0000 - fn: 41.0000 - accuracy: 0.9674 - precision: 0.0466 - recall: 0.8758 - auc: 0.9649 - val_loss: 0.1585 - val_tp: 71.0000 - val_fp: 841.0000 - val_tn: 44651.0000 - val_fn: 6.0000 - val_accuracy: 0.9814 - val_precision: 0.0779 - val_recall: 0.9221 - val_auc: 0.9865
Epoch 12/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2430 - tp: 291.0000 - fp: 6291.0000 - tn: 175655.0000 - fn: 39.0000 - accuracy: 0.9653 - precision: 0.0442 - recall: 0.8818 - auc: 0.9722 - val_loss: 0.1539 - val_tp: 71.0000 - val_fp: 892.0000 - val_tn: 44600.0000 - val_fn: 6.0000 - val_accuracy: 0.9803 - val_precision: 0.0737 - val_recall: 0.9221 - val_auc: 0.9871
Epoch 13/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2201 - tp: 295.0000 - fp: 6526.0000 - tn: 175420.0000 - fn: 35.0000 - accuracy: 0.9640 - precision: 0.0432 - recall: 0.8939 - auc: 0.9761 - val_loss: 0.1513 - val_tp: 72.0000 - val_fp: 927.0000 - val_tn: 44565.0000 - val_fn: 5.0000 - val_accuracy: 0.9795 - val_precision: 0.0721 - val_recall: 0.9351 - val_auc: 0.9869
Epoch 14/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2891 - tp: 289.0000 - fp: 6851.0000 - tn: 175095.0000 - fn: 41.0000 - accuracy: 0.9622 - precision: 0.0405 - recall: 0.8758 - auc: 0.9576 - val_loss: 0.1494 - val_tp: 72.0000 - val_fp: 941.0000 - val_tn: 44551.0000 - val_fn: 5.0000 - val_accuracy: 0.9792 - val_precision: 0.0711 - val_recall: 0.9351 - val_auc: 0.9871
Epoch 15/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2286 - tp: 294.0000 - fp: 6655.0000 - tn: 175291.0000 - fn: 36.0000 - accuracy: 0.9633 - precision: 0.0423 - recall: 0.8909 - auc: 0.9718 - val_loss: 0.1467 - val_tp: 72.0000 - val_fp: 920.0000 - val_tn: 44572.0000 - val_fn: 5.0000 - val_accuracy: 0.9797 - val_precision: 0.0726 - val_recall: 0.9351 - val_auc: 0.9874
Epoch 16/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2780 - tp: 292.0000 - fp: 6693.0000 - tn: 175253.0000 - fn: 38.0000 - accuracy: 0.9631 - precision: 0.0418 - recall: 0.8848 - auc: 0.9603 - val_loss: 0.1450 - val_tp: 72.0000 - val_fp: 933.0000 - val_tn: 44559.0000 - val_fn: 5.0000 - val_accuracy: 0.9794 - val_precision: 0.0716 - val_recall: 0.9351 - val_auc: 0.9886
Epoch 17/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2387 - tp: 295.0000 - fp: 6554.0000 - tn: 175392.0000 - fn: 35.0000 - accuracy: 0.9639 - precision: 0.0431 - recall: 0.8939 - auc: 0.9693 - val_loss: 0.1408 - val_tp: 72.0000 - val_fp: 910.0000 - val_tn: 44582.0000 - val_fn: 5.0000 - val_accuracy: 0.9799 - val_precision: 0.0733 - val_recall: 0.9351 - val_auc: 0.9895
Epoch 18/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2403 - tp: 289.0000 - fp: 6299.0000 - tn: 175647.0000 - fn: 41.0000 - accuracy: 0.9652 - precision: 0.0439 - recall: 0.8758 - auc: 0.9716 - val_loss: 0.1373 - val_tp: 72.0000 - val_fp: 885.0000 - val_tn: 44607.0000 - val_fn: 5.0000 - val_accuracy: 0.9805 - val_precision: 0.0752 - val_recall: 0.9351 - val_auc: 0.9897
Epoch 19/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2364 - tp: 297.0000 - fp: 6218.0000 - tn: 175728.0000 - fn: 33.0000 - accuracy: 0.9657 - precision: 0.0456 - recall: 0.9000 - auc: 0.9649 - val_loss: 0.1360 - val_tp: 72.0000 - val_fp: 839.0000 - val_tn: 44653.0000 - val_fn: 5.0000 - val_accuracy: 0.9815 - val_precision: 0.0790 - val_recall: 0.9351 - val_auc: 0.9901
Epoch 20/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2281 - tp: 296.0000 - fp: 6241.0000 - tn: 175705.0000 - fn: 34.0000 - accuracy: 0.9656 - precision: 0.0453 - recall: 0.8970 - auc: 0.9700 - val_loss: 0.1347 - val_tp: 72.0000 - val_fp: 839.0000 - val_tn: 44653.0000 - val_fn: 5.0000 - val_accuracy: 0.9815 - val_precision: 0.0790 - val_recall: 0.9351 - val_auc: 0.9900
Epoch 21/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2269 - tp: 298.0000 - fp: 6413.0000 - tn: 175533.0000 - fn: 32.0000 - accuracy: 0.9646 - precision: 0.0444 - recall: 0.9030 - auc: 0.9733 - val_loss: 0.1337 - val_tp: 72.0000 - val_fp: 855.0000 - val_tn: 44637.0000 - val_fn: 5.0000 - val_accuracy: 0.9811 - val_precision: 0.0777 - val_recall: 0.9351 - val_auc: 0.9902
Epoch 22/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2394 - tp: 294.0000 - fp: 6593.0000 - tn: 175353.0000 - fn: 36.0000 - accuracy: 0.9636 - precision: 0.0427 - recall: 0.8909 - auc: 0.9694 - val_loss: 0.1310 - val_tp: 72.0000 - val_fp: 849.0000 - val_tn: 44643.0000 - val_fn: 5.0000 - val_accuracy: 0.9813 - val_precision: 0.0782 - val_recall: 0.9351 - val_auc: 0.9905
Epoch 23/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2429 - tp: 294.0000 - fp: 6548.0000 - tn: 175398.0000 - fn: 36.0000 - accuracy: 0.9639 - precision: 0.0430 - recall: 0.8909 - auc: 0.9669 - val_loss: 0.1295 - val_tp: 72.0000 - val_fp: 872.0000 - val_tn: 44620.0000 - val_fn: 5.0000 - val_accuracy: 0.9808 - val_precision: 0.0763 - val_recall: 0.9351 - val_auc: 0.9905
Epoch 24/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2247 - tp: 293.0000 - fp: 6638.0000 - tn: 175308.0000 - fn: 37.0000 - accuracy: 0.9634 - precision: 0.0423 - recall: 0.8879 - auc: 0.9738 - val_loss: 0.1278 - val_tp: 72.0000 - val_fp: 889.0000 - val_tn: 44603.0000 - val_fn: 5.0000 - val_accuracy: 0.9804 - val_precision: 0.0749 - val_recall: 0.9351 - val_auc: 0.9905
Epoch 25/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2297 - tp: 298.0000 - fp: 6542.0000 - tn: 175404.0000 - fn: 32.0000 - accuracy: 0.9639 - precision: 0.0436 - recall: 0.9030 - auc: 0.9664 - val_loss: 0.1262 - val_tp: 72.0000 - val_fp: 862.0000 - val_tn: 44630.0000 - val_fn: 5.0000 - val_accuracy: 0.9810 - val_precision: 0.0771 - val_recall: 0.9351 - val_auc: 0.9908
Epoch 26/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2083 - tp: 297.0000 - fp: 6787.0000 - tn: 175159.0000 - fn: 33.0000 - accuracy: 0.9626 - precision: 0.0419 - recall: 0.9000 - auc: 0.9767 - val_loss: 0.1243 - val_tp: 72.0000 - val_fp: 891.0000 - val_tn: 44601.0000 - val_fn: 5.0000 - val_accuracy: 0.9803 - val_precision: 0.0748 - val_recall: 0.9351 - val_auc: 0.9916
Epoch 27/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2168 - tp: 298.0000 - fp: 6596.0000 - tn: 175350.0000 - fn: 32.0000 - accuracy: 0.9636 - precision: 0.0432 - recall: 0.9030 - auc: 0.9722 - val_loss: 0.1239 - val_tp: 72.0000 - val_fp: 856.0000 - val_tn: 44636.0000 - val_fn: 5.0000 - val_accuracy: 0.9811 - val_precision: 0.0776 - val_recall: 0.9351 - val_auc: 0.9909
Epoch 28/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2329 - tp: 295.0000 - fp: 6381.0000 - tn: 175565.0000 - fn: 35.0000 - accuracy: 0.9648 - precision: 0.0442 - recall: 0.8939 - auc: 0.9714 - val_loss: 0.1232 - val_tp: 72.0000 - val_fp: 852.0000 - val_tn: 44640.0000 - val_fn: 5.0000 - val_accuracy: 0.9812 - val_precision: 0.0779 - val_recall: 0.9351 - val_auc: 0.9910
Epoch 29/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2318 - tp: 291.0000 - fp: 6544.0000 - tn: 175402.0000 - fn: 39.0000 - accuracy: 0.9639 - precision: 0.0426 - recall: 0.8818 - auc: 0.9729 - val_loss: 0.1224 - val_tp: 72.0000 - val_fp: 883.0000 - val_tn: 44609.0000 - val_fn: 5.0000 - val_accuracy: 0.9805 - val_precision: 0.0754 - val_recall: 0.9351 - val_auc: 0.9909
Epoch 30/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2296 - tp: 296.0000 - fp: 6691.0000 - tn: 175255.0000 - fn: 34.0000 - accuracy: 0.9631 - precision: 0.0424 - recall: 0.8970 - auc: 0.9701 - val_loss: 0.1229 - val_tp: 72.0000 - val_fp: 876.0000 - val_tn: 44616.0000 - val_fn: 5.0000 - val_accuracy: 0.9807 - val_precision: 0.0759 - val_recall: 0.9351 - val_auc: 0.9908
Epoch 31/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1995 - tp: 300.0000 - fp: 6367.0000 - tn: 175579.0000 - fn: 30.0000 - accuracy: 0.9649 - precision: 0.0450 - recall: 0.9091 - auc: 0.9754 - val_loss: 0.1237 - val_tp: 72.0000 - val_fp: 856.0000 - val_tn: 44636.0000 - val_fn: 5.0000 - val_accuracy: 0.9811 - val_precision: 0.0776 - val_recall: 0.9351 - val_auc: 0.9908
Epoch 32/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1990 - tp: 295.0000 - fp: 5891.0000 - tn: 176055.0000 - fn: 35.0000 - accuracy: 0.9675 - precision: 0.0477 - recall: 0.8939 - auc: 0.9764 - val_loss: 0.1228 - val_tp: 71.0000 - val_fp: 797.0000 - val_tn: 44695.0000 - val_fn: 6.0000 - val_accuracy: 0.9824 - val_precision: 0.0818 - val_recall: 0.9221 - val_auc: 0.9911
Epoch 33/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2254 - tp: 295.0000 - fp: 6106.0000 - tn: 175840.0000 - fn: 35.0000 - accuracy: 0.9663 - precision: 0.0461 - recall: 0.8939 - auc: 0.9678 - val_loss: 0.1223 - val_tp: 71.0000 - val_fp: 811.0000 - val_tn: 44681.0000 - val_fn: 6.0000 - val_accuracy: 0.9821 - val_precision: 0.0805 - val_recall: 0.9221 - val_auc: 0.9910
Epoch 34/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2104 - tp: 298.0000 - fp: 6251.0000 - tn: 175695.0000 - fn: 32.0000 - accuracy: 0.9655 - precision: 0.0455 - recall: 0.9030 - auc: 0.9727 - val_loss: 0.1202 - val_tp: 71.0000 - val_fp: 780.0000 - val_tn: 44712.0000 - val_fn: 6.0000 - val_accuracy: 0.9828 - val_precision: 0.0834 - val_recall: 0.9221 - val_auc: 0.9913
Epoch 35/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2226 - tp: 298.0000 - fp: 6012.0000 - tn: 175934.0000 - fn: 32.0000 - accuracy: 0.9668 - precision: 0.0472 - recall: 0.9030 - auc: 0.9696 - val_loss: 0.1198 - val_tp: 71.0000 - val_fp: 766.0000 - val_tn: 44726.0000 - val_fn: 6.0000 - val_accuracy: 0.9831 - val_precision: 0.0848 - val_recall: 0.9221 - val_auc: 0.9913
Epoch 36/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2286 - tp: 288.0000 - fp: 6152.0000 - tn: 175794.0000 - fn: 42.0000 - accuracy: 0.9660 - precision: 0.0447 - recall: 0.8727 - auc: 0.9706 - val_loss: 0.1183 - val_tp: 72.0000 - val_fp: 833.0000 - val_tn: 44659.0000 - val_fn: 5.0000 - val_accuracy: 0.9816 - val_precision: 0.0796 - val_recall: 0.9351 - val_auc: 0.9919
Epoch 37/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1987 - tp: 296.0000 - fp: 6312.0000 - tn: 175634.0000 - fn: 34.0000 - accuracy: 0.9652 - precision: 0.0448 - recall: 0.8970 - auc: 0.9788 - val_loss: 0.1187 - val_tp: 72.0000 - val_fp: 848.0000 - val_tn: 44644.0000 - val_fn: 5.0000 - val_accuracy: 0.9813 - val_precision: 0.0783 - val_recall: 0.9351 - val_auc: 0.9918
Epoch 38/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1642 - tp: 303.0000 - fp: 6336.0000 - tn: 175610.0000 - fn: 27.0000 - accuracy: 0.9651 - precision: 0.0456 - recall: 0.9182 - auc: 0.9837 - val_loss: 0.1192 - val_tp: 72.0000 - val_fp: 795.0000 - val_tn: 44697.0000 - val_fn: 5.0000 - val_accuracy: 0.9824 - val_precision: 0.0830 - val_recall: 0.9351 - val_auc: 0.9912
Epoch 39/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1929 - tp: 298.0000 - fp: 6066.0000 - tn: 175880.0000 - fn: 32.0000 - accuracy: 0.9665 - precision: 0.0468 - recall: 0.9030 - auc: 0.9793 - val_loss: 0.1186 - val_tp: 72.0000 - val_fp: 837.0000 - val_tn: 44655.0000 - val_fn: 5.0000 - val_accuracy: 0.9815 - val_precision: 0.0792 - val_recall: 0.9351 - val_auc: 0.9911
Epoch 40/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2026 - tp: 300.0000 - fp: 6033.0000 - tn: 175913.0000 - fn: 30.0000 - accuracy: 0.9667 - precision: 0.0474 - recall: 0.9091 - auc: 0.9755 - val_loss: 0.1178 - val_tp: 72.0000 - val_fp: 798.0000 - val_tn: 44694.0000 - val_fn: 5.0000 - val_accuracy: 0.9824 - val_precision: 0.0828 - val_recall: 0.9351 - val_auc: 0.9912
Epoch 41/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1792 - tp: 302.0000 - fp: 6117.0000 - tn: 175829.0000 - fn: 28.0000 - accuracy: 0.9663 - precision: 0.0470 - recall: 0.9152 - auc: 0.9808 - val_loss: 0.1178 - val_tp: 72.0000 - val_fp: 819.0000 - val_tn: 44673.0000 - val_fn: 5.0000 - val_accuracy: 0.9819 - val_precision: 0.0808 - val_recall: 0.9351 - val_auc: 0.9912
Epoch 42/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2007 - tp: 299.0000 - fp: 6146.0000 - tn: 175800.0000 - fn: 31.0000 - accuracy: 0.9661 - precision: 0.0464 - recall: 0.9061 - auc: 0.9759 - val_loss: 0.1176 - val_tp: 72.0000 - val_fp: 822.0000 - val_tn: 44670.0000 - val_fn: 5.0000 - val_accuracy: 0.9819 - val_precision: 0.0805 - val_recall: 0.9351 - val_auc: 0.9912
Epoch 43/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1904 - tp: 296.0000 - fp: 6116.0000 - tn: 175830.0000 - fn: 34.0000 - accuracy: 0.9663 - precision: 0.0462 - recall: 0.8970 - auc: 0.9795 - val_loss: 0.1173 - val_tp: 72.0000 - val_fp: 838.0000 - val_tn: 44654.0000 - val_fn: 5.0000 - val_accuracy: 0.9815 - val_precision: 0.0791 - val_recall: 0.9351 - val_auc: 0.9918
Epoch 44/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2120 - tp: 295.0000 - fp: 6079.0000 - tn: 175867.0000 - fn: 35.0000 - accuracy: 0.9665 - precision: 0.0463 - recall: 0.8939 - auc: 0.9732 - val_loss: 0.1168 - val_tp: 72.0000 - val_fp: 848.0000 - val_tn: 44644.0000 - val_fn: 5.0000 - val_accuracy: 0.9813 - val_precision: 0.0783 - val_recall: 0.9351 - val_auc: 0.9920
Epoch 45/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1706 - tp: 299.0000 - fp: 6089.0000 - tn: 175857.0000 - fn: 31.0000 - accuracy: 0.9664 - precision: 0.0468 - recall: 0.9061 - auc: 0.9834 - val_loss: 0.1151 - val_tp: 72.0000 - val_fp: 835.0000 - val_tn: 44657.0000 - val_fn: 5.0000 - val_accuracy: 0.9816 - val_precision: 0.0794 - val_recall: 0.9351 - val_auc: 0.9922
Epoch 46/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1970 - tp: 297.0000 - fp: 6023.0000 - tn: 175923.0000 - fn: 33.0000 - accuracy: 0.9668 - precision: 0.0470 - recall: 0.9000 - auc: 0.9751 - val_loss: 0.1160 - val_tp: 72.0000 - val_fp: 774.0000 - val_tn: 44718.0000 - val_fn: 5.0000 - val_accuracy: 0.9829 - val_precision: 0.0851 - val_recall: 0.9351 - val_auc: 0.9922
Epoch 47/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1958 - tp: 301.0000 - fp: 5909.0000 - tn: 176037.0000 - fn: 29.0000 - accuracy: 0.9674 - precision: 0.0485 - recall: 0.9121 - auc: 0.9739 - val_loss: 0.1172 - val_tp: 72.0000 - val_fp: 785.0000 - val_tn: 44707.0000 - val_fn: 5.0000 - val_accuracy: 0.9827 - val_precision: 0.0840 - val_recall: 0.9351 - val_auc: 0.9920
Epoch 48/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1994 - tp: 295.0000 - fp: 6009.0000 - tn: 175937.0000 - fn: 35.0000 - accuracy: 0.9668 - precision: 0.0468 - recall: 0.8939 - auc: 0.9773 - val_loss: 0.1155 - val_tp: 72.0000 - val_fp: 844.0000 - val_tn: 44648.0000 - val_fn: 5.0000 - val_accuracy: 0.9814 - val_precision: 0.0786 - val_recall: 0.9351 - val_auc: 0.9921
Epoch 49/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1714 - tp: 305.0000 - fp: 6433.0000 - tn: 175513.0000 - fn: 25.0000 - accuracy: 0.9646 - precision: 0.0453 - recall: 0.9242 - auc: 0.9819 - val_loss: 0.1152 - val_tp: 72.0000 - val_fp: 819.0000 - val_tn: 44673.0000 - val_fn: 5.0000 - val_accuracy: 0.9819 - val_precision: 0.0808 - val_recall: 0.9351 - val_auc: 0.9922
Epoch 50/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2130 - tp: 298.0000 - fp: 5848.0000 - tn: 176098.0000 - fn: 32.0000 - accuracy: 0.9677 - precision: 0.0485 - recall: 0.9030 - auc: 0.9720 - val_loss: 0.1159 - val_tp: 72.0000 - val_fp: 799.0000 - val_tn: 44693.0000 - val_fn: 5.0000 - val_accuracy: 0.9824 - val_precision: 0.0827 - val_recall: 0.9351 - val_auc: 0.9922
Epoch 51/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1886 - tp: 297.0000 - fp: 5934.0000 - tn: 176012.0000 - fn: 33.0000 - accuracy: 0.9673 - precision: 0.0477 - recall: 0.9000 - auc: 0.9767 - val_loss: 0.1145 - val_tp: 72.0000 - val_fp: 788.0000 - val_tn: 44704.0000 - val_fn: 5.0000 - val_accuracy: 0.9826 - val_precision: 0.0837 - val_recall: 0.9351 - val_auc: 0.9922
Epoch 52/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1742 - tp: 299.0000 - fp: 5854.0000 - tn: 176092.0000 - fn: 31.0000 - accuracy: 0.9677 - precision: 0.0486 - recall: 0.9061 - auc: 0.9840 - val_loss: 0.1152 - val_tp: 72.0000 - val_fp: 771.0000 - val_tn: 44721.0000 - val_fn: 5.0000 - val_accuracy: 0.9830 - val_precision: 0.0854 - val_recall: 0.9351 - val_auc: 0.9920
Epoch 53/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1871 - tp: 294.0000 - fp: 5902.0000 - tn: 176044.0000 - fn: 36.0000 - accuracy: 0.9674 - precision: 0.0474 - recall: 0.8909 - auc: 0.9821 - val_loss: 0.1146 - val_tp: 72.0000 - val_fp: 790.0000 - val_tn: 44702.0000 - val_fn: 5.0000 - val_accuracy: 0.9826 - val_precision: 0.0835 - val_recall: 0.9351 - val_auc: 0.9919
Epoch 54/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1867 - tp: 299.0000 - fp: 6031.0000 - tn: 175915.0000 - fn: 31.0000 - accuracy: 0.9667 - precision: 0.0472 - recall: 0.9061 - auc: 0.9807 - val_loss: 0.1146 - val_tp: 72.0000 - val_fp: 811.0000 - val_tn: 44681.0000 - val_fn: 5.0000 - val_accuracy: 0.9821 - val_precision: 0.0815 - val_recall: 0.9351 - val_auc: 0.9920
Epoch 55/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.2040 - tp: 298.0000 - fp: 5879.0000 - tn: 176067.0000 - fn: 32.0000 - accuracy: 0.9676 - precision: 0.0482 - recall: 0.9030 - auc: 0.9754 - val_loss: 0.1128 - val_tp: 72.0000 - val_fp: 752.0000 - val_tn: 44740.0000 - val_fn: 5.0000 - val_accuracy: 0.9834 - val_precision: 0.0874 - val_recall: 0.9351 - val_auc: 0.9924
Epoch 56/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1847 - tp: 301.0000 - fp: 5600.0000 - tn: 176346.0000 - fn: 29.0000 - accuracy: 0.9691 - precision: 0.0510 - recall: 0.9121 - auc: 0.9804 - val_loss: 0.1145 - val_tp: 72.0000 - val_fp: 743.0000 - val_tn: 44749.0000 - val_fn: 5.0000 - val_accuracy: 0.9836 - val_precision: 0.0883 - val_recall: 0.9351 - val_auc: 0.9923
Epoch 57/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1704 - tp: 302.0000 - fp: 5460.0000 - tn: 176486.0000 - fn: 28.0000 - accuracy: 0.9699 - precision: 0.0524 - recall: 0.9152 - auc: 0.9821 - val_loss: 0.1157 - val_tp: 72.0000 - val_fp: 712.0000 - val_tn: 44780.0000 - val_fn: 5.0000 - val_accuracy: 0.9843 - val_precision: 0.0918 - val_recall: 0.9351 - val_auc: 0.9922
Epoch 58/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1868 - tp: 296.0000 - fp: 5714.0000 - tn: 176232.0000 - fn: 34.0000 - accuracy: 0.9685 - precision: 0.0493 - recall: 0.8970 - auc: 0.9811 - val_loss: 0.1162 - val_tp: 72.0000 - val_fp: 749.0000 - val_tn: 44743.0000 - val_fn: 5.0000 - val_accuracy: 0.9835 - val_precision: 0.0877 - val_recall: 0.9351 - val_auc: 0.9919
Epoch 59/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1586 - tp: 302.0000 - fp: 5572.0000 - tn: 176374.0000 - fn: 28.0000 - accuracy: 0.9693 - precision: 0.0514 - recall: 0.9152 - auc: 0.9868 - val_loss: 0.1172 - val_tp: 72.0000 - val_fp: 699.0000 - val_tn: 44793.0000 - val_fn: 5.0000 - val_accuracy: 0.9846 - val_precision: 0.0934 - val_recall: 0.9351 - val_auc: 0.9919
Epoch 60/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1461 - tp: 307.0000 - fp: 5554.0000 - tn: 176392.0000 - fn: 23.0000 - accuracy: 0.9694 - precision: 0.0524 - recall: 0.9303 - auc: 0.9868 - val_loss: 0.1177 - val_tp: 72.0000 - val_fp: 690.0000 - val_tn: 44802.0000 - val_fn: 5.0000 - val_accuracy: 0.9847 - val_precision: 0.0945 - val_recall: 0.9351 - val_auc: 0.9917
Epoch 61/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1838 - tp: 301.0000 - fp: 5501.0000 - tn: 176445.0000 - fn: 29.0000 - accuracy: 0.9697 - precision: 0.0519 - recall: 0.9121 - auc: 0.9797 - val_loss: 0.1179 - val_tp: 72.0000 - val_fp: 668.0000 - val_tn: 44824.0000 - val_fn: 5.0000 - val_accuracy: 0.9852 - val_precision: 0.0973 - val_recall: 0.9351 - val_auc: 0.9918
Epoch 62/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1784 - tp: 298.0000 - fp: 5279.0000 - tn: 176667.0000 - fn: 32.0000 - accuracy: 0.9709 - precision: 0.0534 - recall: 0.9030 - auc: 0.9821 - val_loss: 0.1180 - val_tp: 72.0000 - val_fp: 693.0000 - val_tn: 44799.0000 - val_fn: 5.0000 - val_accuracy: 0.9847 - val_precision: 0.0941 - val_recall: 0.9351 - val_auc: 0.9918
Epoch 63/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1648 - tp: 303.0000 - fp: 5722.0000 - tn: 176224.0000 - fn: 27.0000 - accuracy: 0.9685 - precision: 0.0503 - recall: 0.9182 - auc: 0.9847 - val_loss: 0.1186 - val_tp: 71.0000 - val_fp: 707.0000 - val_tn: 44785.0000 - val_fn: 6.0000 - val_accuracy: 0.9844 - val_precision: 0.0913 - val_recall: 0.9221 - val_auc: 0.9916
Epoch 64/100
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1691 - tp: 298.0000 - fp: 5686.0000 - tn: 176260.0000 - fn: 32.0000 - accuracy: 0.9686 - precision: 0.0498 - recall: 0.9030 - auc: 0.9834 - val_loss: 0.1192 - val_tp: 71.0000 - val_fp: 699.0000 - val_tn: 44793.0000 - val_fn: 6.0000 - val_accuracy: 0.9845 - val_precision: 0.0922 - val_recall: 0.9221 - val_auc: 0.9914
Epoch 65/100
182272/182276 [============================>.] - ETA: 0s - loss: 0.1721 - tp: 299.0000 - fp: 5598.0000 - tn: 176344.0000 - fn: 31.0000 - accuracy: 0.9691 - precision: 0.0507 - recall: 0.9061 - auc: 0.9822Restoring model weights from the end of the best epoch.
182276/182276 [==============================] - 1s 3us/sample - loss: 0.1721 - tp: 299.0000 - fp: 5598.0000 - tn: 176348.0000 - fn: 31.0000 - accuracy: 0.9691 - precision: 0.0507 - recall: 0.9061 - auc: 0.9822 - val_loss: 0.1191 - val_tp: 71.0000 - val_fp: 651.0000 - val_tn: 44841.0000 - val_fn: 6.0000 - val_accuracy: 0.9856 - val_precision: 0.0983 - val_recall: 0.9221 - val_auc: 0.9917
Epoch 00065: early stopping
plot_metrics(weighted_history)
train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
weighted_results = weighted_model.evaluate(test_features, test_labels,
batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()
plot_cm(test_labels, test_predictions_weighted)
loss : 0.07203275692189394
tp : 77.0
fp : 892.0
tn : 55985.0
fn : 8.0
accuracy : 0.9842
precision : 0.07946336
recall : 0.90588236
auc : 0.98777395
Legitimate Transactions Detected (True Negatives): 55985
Legitimate Transactions Incorrectly Detected (False Positives): 892
Fraudulent Transactions Missed (False Negatives): 8
Fraudulent Transactions Detected (True Positives): 77
Total Fraudulent Transactions: 85
Here you can see that with class weights the accuracy and precision are lower because there are more false positives, but conversely the recall and AUC are higher because the model also found more true positives. Despite having lower accuracy, this model has higher recall (and identifies more fraudulent transactions). Of course, there is a cost to both types of error (you wouldn't want to bug users by flagging too many legitimate transactions as fraudulent, either). Carefully consider the trade-offs between these different types of errors for your application; one lever for managing them is the decision threshold, as sketched below.
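For example, rather than always cutting predictions at 0.5, you can sweep the decision threshold and pick the operating point that matches your tolerance for false positives. The following is a minimal sketch using sklearn's precision_recall_curve; it assumes the test_labels and test_predictions_weighted arrays computed above, and the 90% recall target is only an illustrative choice.
from sklearn.metrics import precision_recall_curve

# Sweep the decision threshold over the weighted model's test predictions.
precision, recall, thresholds = precision_recall_curve(
    test_labels, np.ravel(test_predictions_weighted))

# precision[:-1] and recall[:-1] line up with `thresholds`.
# Find the largest threshold that still keeps recall at or above the target.
target_recall = 0.90
ok = recall[:-1] >= target_recall
idx = np.where(ok)[0][-1]

print('threshold: ', thresholds[idx])
print('precision: ', precision[idx])
print('recall:    ', recall[idx])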
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')
plt.legend(loc='lower right')
A related approach would be to resample the dataset by oversampling the minority class.
pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]
pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]
You can balance the dataset manually by choosing the right number of random indices from the positive examples:
ids = np.arange(len(pos_features))
# Sample positive indices with replacement until there are as many positives as negatives.
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]
res_pos_features.shape
(181946, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)
order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]
resampled_features.shape
(363892, 29)
tf.data
If you're using tf.data, the easiest way to produce balanced examples is to start with a positive dataset and a negative dataset, and merge them. See the tf.data guide for more examples.
BUFFER_SIZE = 100000
def make_ds(features, labels):
ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
ds = ds.shuffle(BUFFER_SIZE).repeat()
return ds
pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)
Each dataset provides (feature, label) pairs:
for features, label in pos_ds.take(1):
print("Features:\n", features.numpy())
print()
print("Label: ", label.numpy())
Features:
[-5. 5. -5. 4.5746928 -5. -3.67004432
-5. 5. -3.28896941 -5. 3.97921294 -5.
1.29803459 -5. -0.13897438 -5. -5. -5.
0.80165403 2.20401439 2.46854325 -2.99536368 -2.19642465 0.28573075
4.04395904 -0.43493746 3.14386429 1.12349325 0.82290934]
Label: 1
Merge the two together using tf.data.experimental.sample_from_datasets:
resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
for features, label in resampled_ds.take(1):
print(label.numpy().mean())
0.509765625
To use this dataset, you’ll need the number of steps per epoch.
The definition of “epoch” in this case is less clear. Say it’s the number of batches required to see each negative example once:
resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
278.0
Now try training the model with the resampled dataset instead of using class weights to see how these methods compare.
Note: Because the data was balanced by replicating the positive examples, the total dataset size is larger, and each epoch runs for more training steps.
resampled_model = make_model()
resampled_model.load_weights(initial_weights)
# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1]
output_layer.bias.assign([0])
val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2)
resampled_history = resampled_model.fit(
resampled_ds,
epochs=EPOCHS,
steps_per_epoch=resampled_steps_per_epoch,
callbacks = [early_stopping],
validation_data=val_ds)
Train for 278.0 steps, validate for 23 steps
Epoch 1/100
278/278 [==============================] - 9s 33ms/step - loss: 0.3692 - tp: 246860.0000 - fp: 55512.0000 - tn: 229068.0000 - fn: 37904.0000 - accuracy: 0.8359 - precision: 0.8164 - recall: 0.8669 - auc: 0.9167 - val_loss: 0.1837 - val_tp: 71.0000 - val_fp: 1107.0000 - val_tn: 44385.0000 - val_fn: 6.0000 - val_accuracy: 0.9756 - val_precision: 0.0603 - val_recall: 0.9221 - val_auc: 0.9832
Epoch 2/100
278/278 [==============================] - 7s 25ms/step - loss: 0.1902 - tp: 260564.0000 - fp: 15781.0000 - tn: 268988.0000 - fn: 24011.0000 - accuracy: 0.9301 - precision: 0.9429 - recall: 0.9156 - auc: 0.9763 - val_loss: 0.1088 - val_tp: 70.0000 - val_fp: 877.0000 - val_tn: 44615.0000 - val_fn: 7.0000 - val_accuracy: 0.9806 - val_precision: 0.0739 - val_recall: 0.9091 - val_auc: 0.9858
Epoch 3/100
278/278 [==============================] - 7s 25ms/step - loss: 0.1554 - tp: 262593.0000 - fp: 11114.0000 - tn: 274225.0000 - fn: 21412.0000 - accuracy: 0.9429 - precision: 0.9594 - recall: 0.9246 - auc: 0.9843 - val_loss: 0.0864 - val_tp: 72.0000 - val_fp: 833.0000 - val_tn: 44659.0000 - val_fn: 5.0000 - val_accuracy: 0.9816 - val_precision: 0.0796 - val_recall: 0.9351 - val_auc: 0.9861
Epoch 4/100
278/278 [==============================] - 7s 25ms/step - loss: 0.1378 - tp: 264847.0000 - fp: 9681.0000 - tn: 274469.0000 - fn: 20347.0000 - accuracy: 0.9473 - precision: 0.9647 - recall: 0.9287 - auc: 0.9880 - val_loss: 0.0748 - val_tp: 70.0000 - val_fp: 736.0000 - val_tn: 44756.0000 - val_fn: 7.0000 - val_accuracy: 0.9837 - val_precision: 0.0868 - val_recall: 0.9091 - val_auc: 0.9848
Epoch 5/100
278/278 [==============================] - 7s 25ms/step - loss: 0.1262 - tp: 265413.0000 - fp: 8862.0000 - tn: 276029.0000 - fn: 19040.0000 - accuracy: 0.9510 - precision: 0.9677 - recall: 0.9331 - auc: 0.9902 - val_loss: 0.0686 - val_tp: 71.0000 - val_fp: 717.0000 - val_tn: 44775.0000 - val_fn: 6.0000 - val_accuracy: 0.9841 - val_precision: 0.0901 - val_recall: 0.9221 - val_auc: 0.9852
Epoch 6/100
278/278 [==============================] - 7s 26ms/step - loss: 0.1189 - tp: 266903.0000 - fp: 8422.0000 - tn: 275713.0000 - fn: 18306.0000 - accuracy: 0.9531 - precision: 0.9694 - recall: 0.9358 - auc: 0.9915 - val_loss: 0.0637 - val_tp: 71.0000 - val_fp: 665.0000 - val_tn: 44827.0000 - val_fn: 6.0000 - val_accuracy: 0.9853 - val_precision: 0.0965 - val_recall: 0.9221 - val_auc: 0.9859
Epoch 7/100
278/278 [==============================] - 7s 24ms/step - loss: 0.1112 - tp: 267946.0000 - fp: 8074.0000 - tn: 276335.0000 - fn: 16989.0000 - accuracy: 0.9560 - precision: 0.9707 - recall: 0.9404 - auc: 0.9928 - val_loss: 0.0606 - val_tp: 71.0000 - val_fp: 701.0000 - val_tn: 44791.0000 - val_fn: 6.0000 - val_accuracy: 0.9845 - val_precision: 0.0920 - val_recall: 0.9221 - val_auc: 0.9814
Epoch 8/100
278/278 [==============================] - 7s 25ms/step - loss: 0.1042 - tp: 269747.0000 - fp: 7725.0000 - tn: 275851.0000 - fn: 16021.0000 - accuracy: 0.9583 - precision: 0.9722 - recall: 0.9439 - auc: 0.9939 - val_loss: 0.0544 - val_tp: 71.0000 - val_fp: 625.0000 - val_tn: 44867.0000 - val_fn: 6.0000 - val_accuracy: 0.9862 - val_precision: 0.1020 - val_recall: 0.9221 - val_auc: 0.9767
Epoch 9/100
278/278 [==============================] - 7s 25ms/step - loss: 0.0994 - tp: 270028.0000 - fp: 7464.0000 - tn: 276896.0000 - fn: 14956.0000 - accuracy: 0.9606 - precision: 0.9731 - recall: 0.9475 - auc: 0.9945 - val_loss: 0.0506 - val_tp: 71.0000 - val_fp: 610.0000 - val_tn: 44882.0000 - val_fn: 6.0000 - val_accuracy: 0.9865 - val_precision: 0.1043 - val_recall: 0.9221 - val_auc: 0.9770
Epoch 10/100
278/278 [==============================] - 7s 25ms/step - loss: 0.0944 - tp: 270369.0000 - fp: 7217.0000 - tn: 277663.0000 - fn: 14095.0000 - accuracy: 0.9626 - precision: 0.9740 - recall: 0.9505 - auc: 0.9951 - val_loss: 0.0448 - val_tp: 71.0000 - val_fp: 530.0000 - val_tn: 44962.0000 - val_fn: 6.0000 - val_accuracy: 0.9882 - val_precision: 0.1181 - val_recall: 0.9221 - val_auc: 0.9763
Epoch 11/100
278/278 [==============================] - 7s 25ms/step - loss: 0.0901 - tp: 271376.0000 - fp: 6868.0000 - tn: 277402.0000 - fn: 13698.0000 - accuracy: 0.9639 - precision: 0.9753 - recall: 0.9519 - auc: 0.9957 - val_loss: 0.0425 - val_tp: 71.0000 - val_fp: 508.0000 - val_tn: 44984.0000 - val_fn: 6.0000 - val_accuracy: 0.9887 - val_precision: 0.1226 - val_recall: 0.9221 - val_auc: 0.9769
Epoch 12/100
278/278 [==============================] - 7s 24ms/step - loss: 0.0855 - tp: 272011.0000 - fp: 6648.0000 - tn: 277864.0000 - fn: 12821.0000 - accuracy: 0.9658 - precision: 0.9761 - recall: 0.9550 - auc: 0.9961 - val_loss: 0.0398 - val_tp: 70.0000 - val_fp: 481.0000 - val_tn: 45011.0000 - val_fn: 7.0000 - val_accuracy: 0.9893 - val_precision: 0.1270 - val_recall: 0.9091 - val_auc: 0.9721
Epoch 13/100
277/278 [============================>.] - ETA: 0s - loss: 0.0818 - tp: 271398.0000 - fp: 6566.0000 - tn: 277527.0000 - fn: 11805.0000 - accuracy: 0.9676 - precision: 0.9764 - recall: 0.9583 - auc: 0.9964Restoring model weights from the end of the best epoch.
278/278 [==============================] - 7s 25ms/step - loss: 0.0819 - tp: 272380.0000 - fp: 6597.0000 - tn: 278513.0000 - fn: 11854.0000 - accuracy: 0.9676 - precision: 0.9764 - recall: 0.9583 - auc: 0.9964 - val_loss: 0.0366 - val_tp: 70.0000 - val_fp: 425.0000 - val_tn: 45067.0000 - val_fn: 7.0000 - val_accuracy: 0.9905 - val_precision: 0.1414 - val_recall: 0.9091 - val_auc: 0.9724
Epoch 00013: early stopping
If the training process were considering the whole dataset on each gradient update, this oversampling would be basically identical to the class weighting.
But when training the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: instead of each positive example being shown in one batch with a large weight, it is shown in many different batches, each time with a small weight. This smoother gradient signal makes the model easier to train, as the rough sketch below illustrates.
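As a rough back-of-the-envelope comparison, you can look at how many positive examples an average batch contains under each approach. This sketch assumes the pos, neg, and BATCH_SIZE values defined earlier in the tutorial; the exact numbers are only illustrative.
# With the original class distribution, a batch holds only a handful of
# positives, each of which would need a large class weight; with 50/50
# resampling, roughly half of every batch is positive, each example
# contributing a small, unweighted gradient term.
original_pos_per_batch = BATCH_SIZE * pos / (pos + neg)
resampled_pos_per_batch = BATCH_SIZE * 0.5
print('Expected positives per batch (original):  {:.2f}'.format(original_pos_per_batch))
print('Expected positives per batch (resampled): {:.1f}'.format(resampled_pos_per_batch))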
Note that the distributions of metrics will be different here, because the training data has a totally different distribution from the validation and test data.
plot_metrics(resampled_history )
Because training is easier on the balanced data, the above training procedure may overfit quickly. So break up the epochs to give callbacks.EarlyStopping finer control over when to stop training.
resampled_model = make_model()
resampled_model.load_weights(initial_weights)
# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1]
output_layer.bias.assign([0])
resampled_history = resampled_model.fit(
resampled_ds,
# These are not real epochs
steps_per_epoch = 20,
epochs=10*EPOCHS,
callbacks = [early_stopping],
validation_data=(val_ds))
Train for 20 steps, validate for 23 steps
Epoch 1/1000
20/20 [==============================] - 3s 130ms/step - loss: 0.8130 - tp: 13175.0000 - fp: 8205.0000 - tn: 12183.0000 - fn: 7397.0000 - accuracy: 0.6191 - precision: 0.6162 - recall: 0.6404 - auc: 0.6651 - val_loss: 0.6155 - val_tp: 68.0000 - val_fp: 13826.0000 - val_tn: 31666.0000 - val_fn: 9.0000 - val_accuracy: 0.6964 - val_precision: 0.0049 - val_recall: 0.8831 - val_auc: 0.9084
Epoch 2/1000
20/20 [==============================] - 0s 20ms/step - loss: 0.5626 - tp: 16527.0000 - fp: 7660.0000 - tn: 12786.0000 - fn: 3987.0000 - accuracy: 0.7156 - precision: 0.6833 - recall: 0.8056 - auc: 0.8254 - val_loss: 0.5651 - val_tp: 71.0000 - val_fp: 10902.0000 - val_tn: 34590.0000 - val_fn: 6.0000 - val_accuracy: 0.7606 - val_precision: 0.0065 - val_recall: 0.9221 - val_auc: 0.9433
Epoch 3/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.4767 - tp: 17520.0000 - fp: 6712.0000 - tn: 13750.0000 - fn: 2978.0000 - accuracy: 0.7634 - precision: 0.7230 - recall: 0.8547 - auc: 0.8763 - val_loss: 0.4996 - val_tp: 72.0000 - val_fp: 7455.0000 - val_tn: 38037.0000 - val_fn: 5.0000 - val_accuracy: 0.8363 - val_precision: 0.0096 - val_recall: 0.9351 - val_auc: 0.9529
Epoch 4/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.4173 - tp: 17844.0000 - fp: 5599.0000 - tn: 14835.0000 - fn: 2682.0000 - accuracy: 0.7978 - precision: 0.7612 - recall: 0.8693 - auc: 0.9019 - val_loss: 0.4401 - val_tp: 72.0000 - val_fp: 4936.0000 - val_tn: 40556.0000 - val_fn: 5.0000 - val_accuracy: 0.8916 - val_precision: 0.0144 - val_recall: 0.9351 - val_auc: 0.9585
Epoch 5/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.3830 - tp: 17926.0000 - fp: 4826.0000 - tn: 15719.0000 - fn: 2489.0000 - accuracy: 0.8214 - precision: 0.7879 - recall: 0.8781 - auc: 0.9149 - val_loss: 0.3892 - val_tp: 72.0000 - val_fp: 3306.0000 - val_tn: 42186.0000 - val_fn: 5.0000 - val_accuracy: 0.9273 - val_precision: 0.0213 - val_recall: 0.9351 - val_auc: 0.9624
Epoch 6/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.3457 - tp: 18192.0000 - fp: 3931.0000 - tn: 16466.0000 - fn: 2371.0000 - accuracy: 0.8461 - precision: 0.8223 - recall: 0.8847 - auc: 0.9288 - val_loss: 0.3478 - val_tp: 71.0000 - val_fp: 2381.0000 - val_tn: 43111.0000 - val_fn: 6.0000 - val_accuracy: 0.9476 - val_precision: 0.0290 - val_recall: 0.9221 - val_auc: 0.9663
Epoch 7/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.3188 - tp: 18236.0000 - fp: 3353.0000 - tn: 17139.0000 - fn: 2232.0000 - accuracy: 0.8636 - precision: 0.8447 - recall: 0.8910 - auc: 0.9380 - val_loss: 0.3144 - val_tp: 71.0000 - val_fp: 1974.0000 - val_tn: 43518.0000 - val_fn: 6.0000 - val_accuracy: 0.9565 - val_precision: 0.0347 - val_recall: 0.9221 - val_auc: 0.9701
Epoch 8/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.3010 - tp: 18374.0000 - fp: 2860.0000 - tn: 17533.0000 - fn: 2193.0000 - accuracy: 0.8766 - precision: 0.8653 - recall: 0.8934 - auc: 0.9436 - val_loss: 0.2856 - val_tp: 71.0000 - val_fp: 1710.0000 - val_tn: 43782.0000 - val_fn: 6.0000 - val_accuracy: 0.9623 - val_precision: 0.0399 - val_recall: 0.9221 - val_auc: 0.9727
Epoch 9/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.2814 - tp: 18325.0000 - fp: 2459.0000 - tn: 17977.0000 - fn: 2199.0000 - accuracy: 0.8863 - precision: 0.8817 - recall: 0.8929 - auc: 0.9502 - val_loss: 0.2625 - val_tp: 71.0000 - val_fp: 1538.0000 - val_tn: 43954.0000 - val_fn: 6.0000 - val_accuracy: 0.9661 - val_precision: 0.0441 - val_recall: 0.9221 - val_auc: 0.9753
Epoch 10/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.2682 - tp: 18439.0000 - fp: 2290.0000 - tn: 18188.0000 - fn: 2043.0000 - accuracy: 0.8942 - precision: 0.8895 - recall: 0.9003 - auc: 0.9550 - val_loss: 0.2412 - val_tp: 71.0000 - val_fp: 1396.0000 - val_tn: 44096.0000 - val_fn: 6.0000 - val_accuracy: 0.9692 - val_precision: 0.0484 - val_recall: 0.9221 - val_auc: 0.9772
Epoch 11/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.2528 - tp: 18608.0000 - fp: 2036.0000 - tn: 18351.0000 - fn: 1965.0000 - accuracy: 0.9023 - precision: 0.9014 - recall: 0.9045 - auc: 0.9594 - val_loss: 0.2227 - val_tp: 71.0000 - val_fp: 1299.0000 - val_tn: 44193.0000 - val_fn: 6.0000 - val_accuracy: 0.9714 - val_precision: 0.0518 - val_recall: 0.9221 - val_auc: 0.9792
Epoch 12/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.2432 - tp: 18338.0000 - fp: 1886.0000 - tn: 18803.0000 - fn: 1933.0000 - accuracy: 0.9068 - precision: 0.9067 - recall: 0.9046 - auc: 0.9628 - val_loss: 0.2063 - val_tp: 71.0000 - val_fp: 1213.0000 - val_tn: 44279.0000 - val_fn: 6.0000 - val_accuracy: 0.9732 - val_precision: 0.0553 - val_recall: 0.9221 - val_auc: 0.9809
Epoch 13/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.2352 - tp: 18561.0000 - fp: 1723.0000 - tn: 18762.0000 - fn: 1914.0000 - accuracy: 0.9112 - precision: 0.9151 - recall: 0.9065 - auc: 0.9646 - val_loss: 0.1917 - val_tp: 71.0000 - val_fp: 1128.0000 - val_tn: 44364.0000 - val_fn: 6.0000 - val_accuracy: 0.9751 - val_precision: 0.0592 - val_recall: 0.9221 - val_auc: 0.9818
Epoch 14/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.2288 - tp: 18532.0000 - fp: 1604.0000 - tn: 18930.0000 - fn: 1894.0000 - accuracy: 0.9146 - precision: 0.9203 - recall: 0.9073 - auc: 0.9657 - val_loss: 0.1801 - val_tp: 71.0000 - val_fp: 1073.0000 - val_tn: 44419.0000 - val_fn: 6.0000 - val_accuracy: 0.9763 - val_precision: 0.0621 - val_recall: 0.9221 - val_auc: 0.9827
Epoch 15/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2191 - tp: 18532.0000 - fp: 1497.0000 - tn: 19043.0000 - fn: 1888.0000 - accuracy: 0.9174 - precision: 0.9253 - recall: 0.9075 - auc: 0.9687 - val_loss: 0.1706 - val_tp: 71.0000 - val_fp: 1032.0000 - val_tn: 44460.0000 - val_fn: 6.0000 - val_accuracy: 0.9772 - val_precision: 0.0644 - val_recall: 0.9221 - val_auc: 0.9831
Epoch 16/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.2119 - tp: 18621.0000 - fp: 1401.0000 - tn: 19130.0000 - fn: 1808.0000 - accuracy: 0.9217 - precision: 0.9300 - recall: 0.9115 - auc: 0.9707 - val_loss: 0.1622 - val_tp: 71.0000 - val_fp: 992.0000 - val_tn: 44500.0000 - val_fn: 6.0000 - val_accuracy: 0.9781 - val_precision: 0.0668 - val_recall: 0.9221 - val_auc: 0.9838
Epoch 17/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.2063 - tp: 18544.0000 - fp: 1279.0000 - tn: 19258.0000 - fn: 1879.0000 - accuracy: 0.9229 - precision: 0.9355 - recall: 0.9080 - auc: 0.9716 - val_loss: 0.1550 - val_tp: 71.0000 - val_fp: 985.0000 - val_tn: 44507.0000 - val_fn: 6.0000 - val_accuracy: 0.9783 - val_precision: 0.0672 - val_recall: 0.9221 - val_auc: 0.9842
Epoch 18/1000
20/20 [==============================] - 1s 30ms/step - loss: 0.2030 - tp: 18601.0000 - fp: 1273.0000 - tn: 19308.0000 - fn: 1778.0000 - accuracy: 0.9255 - precision: 0.9359 - recall: 0.9128 - auc: 0.9731 - val_loss: 0.1482 - val_tp: 71.0000 - val_fp: 963.0000 - val_tn: 44529.0000 - val_fn: 6.0000 - val_accuracy: 0.9787 - val_precision: 0.0687 - val_recall: 0.9221 - val_auc: 0.9844
Epoch 19/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.1971 - tp: 18773.0000 - fp: 1172.0000 - tn: 19233.0000 - fn: 1782.0000 - accuracy: 0.9279 - precision: 0.9412 - recall: 0.9133 - auc: 0.9743 - val_loss: 0.1425 - val_tp: 70.0000 - val_fp: 966.0000 - val_tn: 44526.0000 - val_fn: 7.0000 - val_accuracy: 0.9786 - val_precision: 0.0676 - val_recall: 0.9091 - val_auc: 0.9844
Epoch 20/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1935 - tp: 18785.0000 - fp: 1154.0000 - tn: 19220.0000 - fn: 1801.0000 - accuracy: 0.9279 - precision: 0.9421 - recall: 0.9125 - auc: 0.9753 - val_loss: 0.1370 - val_tp: 70.0000 - val_fp: 942.0000 - val_tn: 44550.0000 - val_fn: 7.0000 - val_accuracy: 0.9792 - val_precision: 0.0692 - val_recall: 0.9091 - val_auc: 0.9844
Epoch 21/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1900 - tp: 18756.0000 - fp: 1121.0000 - tn: 19355.0000 - fn: 1728.0000 - accuracy: 0.9304 - precision: 0.9436 - recall: 0.9156 - auc: 0.9762 - val_loss: 0.1320 - val_tp: 71.0000 - val_fp: 925.0000 - val_tn: 44567.0000 - val_fn: 6.0000 - val_accuracy: 0.9796 - val_precision: 0.0713 - val_recall: 0.9221 - val_auc: 0.9848
Epoch 22/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1864 - tp: 18733.0000 - fp: 1026.0000 - tn: 19464.0000 - fn: 1737.0000 - accuracy: 0.9325 - precision: 0.9481 - recall: 0.9151 - auc: 0.9767 - val_loss: 0.1277 - val_tp: 71.0000 - val_fp: 916.0000 - val_tn: 44576.0000 - val_fn: 6.0000 - val_accuracy: 0.9798 - val_precision: 0.0719 - val_recall: 0.9221 - val_auc: 0.9848
Epoch 23/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1824 - tp: 18721.0000 - fp: 1059.0000 - tn: 19494.0000 - fn: 1686.0000 - accuracy: 0.9330 - precision: 0.9465 - recall: 0.9174 - auc: 0.9782 - val_loss: 0.1239 - val_tp: 72.0000 - val_fp: 910.0000 - val_tn: 44582.0000 - val_fn: 5.0000 - val_accuracy: 0.9799 - val_precision: 0.0733 - val_recall: 0.9351 - val_auc: 0.9854
Epoch 24/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.1779 - tp: 18756.0000 - fp: 1002.0000 - tn: 19486.0000 - fn: 1716.0000 - accuracy: 0.9336 - precision: 0.9493 - recall: 0.9162 - auc: 0.9790 - val_loss: 0.1211 - val_tp: 72.0000 - val_fp: 923.0000 - val_tn: 44569.0000 - val_fn: 5.0000 - val_accuracy: 0.9796 - val_precision: 0.0724 - val_recall: 0.9351 - val_auc: 0.9852
Epoch 25/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1794 - tp: 18787.0000 - fp: 1046.0000 - tn: 19492.0000 - fn: 1635.0000 - accuracy: 0.9345 - precision: 0.9473 - recall: 0.9199 - auc: 0.9792 - val_loss: 0.1176 - val_tp: 71.0000 - val_fp: 919.0000 - val_tn: 44573.0000 - val_fn: 6.0000 - val_accuracy: 0.9797 - val_precision: 0.0717 - val_recall: 0.9221 - val_auc: 0.9856
Epoch 26/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1734 - tp: 18754.0000 - fp: 962.0000 - tn: 19581.0000 - fn: 1663.0000 - accuracy: 0.9359 - precision: 0.9512 - recall: 0.9185 - auc: 0.9803 - val_loss: 0.1142 - val_tp: 71.0000 - val_fp: 904.0000 - val_tn: 44588.0000 - val_fn: 6.0000 - val_accuracy: 0.9800 - val_precision: 0.0728 - val_recall: 0.9221 - val_auc: 0.9859
Epoch 27/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1703 - tp: 18944.0000 - fp: 929.0000 - tn: 19458.0000 - fn: 1629.0000 - accuracy: 0.9375 - precision: 0.9533 - recall: 0.9208 - auc: 0.9809 - val_loss: 0.1118 - val_tp: 71.0000 - val_fp: 898.0000 - val_tn: 44594.0000 - val_fn: 6.0000 - val_accuracy: 0.9802 - val_precision: 0.0733 - val_recall: 0.9221 - val_auc: 0.9857
Epoch 28/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1677 - tp: 18925.0000 - fp: 903.0000 - tn: 19538.0000 - fn: 1594.0000 - accuracy: 0.9390 - precision: 0.9545 - recall: 0.9223 - auc: 0.9817 - val_loss: 0.1094 - val_tp: 71.0000 - val_fp: 899.0000 - val_tn: 44593.0000 - val_fn: 6.0000 - val_accuracy: 0.9801 - val_precision: 0.0732 - val_recall: 0.9221 - val_auc: 0.9858
Epoch 29/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1628 - tp: 18892.0000 - fp: 873.0000 - tn: 19621.0000 - fn: 1574.0000 - accuracy: 0.9403 - precision: 0.9558 - recall: 0.9231 - auc: 0.9829 - val_loss: 0.1071 - val_tp: 71.0000 - val_fp: 897.0000 - val_tn: 44595.0000 - val_fn: 6.0000 - val_accuracy: 0.9802 - val_precision: 0.0733 - val_recall: 0.9221 - val_auc: 0.9861
Epoch 30/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1652 - tp: 18780.0000 - fp: 922.0000 - tn: 19703.0000 - fn: 1555.0000 - accuracy: 0.9395 - precision: 0.9532 - recall: 0.9235 - auc: 0.9821 - val_loss: 0.1048 - val_tp: 71.0000 - val_fp: 891.0000 - val_tn: 44601.0000 - val_fn: 6.0000 - val_accuracy: 0.9803 - val_precision: 0.0738 - val_recall: 0.9221 - val_auc: 0.9857
Epoch 31/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.1611 - tp: 18794.0000 - fp: 850.0000 - tn: 19757.0000 - fn: 1559.0000 - accuracy: 0.9412 - precision: 0.9567 - recall: 0.9234 - auc: 0.9831 - val_loss: 0.1025 - val_tp: 71.0000 - val_fp: 886.0000 - val_tn: 44606.0000 - val_fn: 6.0000 - val_accuracy: 0.9804 - val_precision: 0.0742 - val_recall: 0.9221 - val_auc: 0.9858
Epoch 32/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1607 - tp: 18852.0000 - fp: 881.0000 - tn: 19649.0000 - fn: 1578.0000 - accuracy: 0.9400 - precision: 0.9554 - recall: 0.9228 - auc: 0.9832 - val_loss: 0.1003 - val_tp: 71.0000 - val_fp: 878.0000 - val_tn: 44614.0000 - val_fn: 6.0000 - val_accuracy: 0.9806 - val_precision: 0.0748 - val_recall: 0.9221 - val_auc: 0.9861
Epoch 33/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1586 - tp: 18903.0000 - fp: 869.0000 - tn: 19654.0000 - fn: 1534.0000 - accuracy: 0.9413 - precision: 0.9560 - recall: 0.9249 - auc: 0.9836 - val_loss: 0.0982 - val_tp: 71.0000 - val_fp: 872.0000 - val_tn: 44620.0000 - val_fn: 6.0000 - val_accuracy: 0.9807 - val_precision: 0.0753 - val_recall: 0.9221 - val_auc: 0.9863
Epoch 34/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1581 - tp: 18776.0000 - fp: 837.0000 - tn: 19794.0000 - fn: 1553.0000 - accuracy: 0.9417 - precision: 0.9573 - recall: 0.9236 - auc: 0.9839 - val_loss: 0.0953 - val_tp: 71.0000 - val_fp: 851.0000 - val_tn: 44641.0000 - val_fn: 6.0000 - val_accuracy: 0.9812 - val_precision: 0.0770 - val_recall: 0.9221 - val_auc: 0.9859
Epoch 35/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1539 - tp: 18821.0000 - fp: 812.0000 - tn: 19799.0000 - fn: 1528.0000 - accuracy: 0.9429 - precision: 0.9586 - recall: 0.9249 - auc: 0.9845 - val_loss: 0.0938 - val_tp: 71.0000 - val_fp: 844.0000 - val_tn: 44648.0000 - val_fn: 6.0000 - val_accuracy: 0.9813 - val_precision: 0.0776 - val_recall: 0.9221 - val_auc: 0.9860
Epoch 36/1000
20/20 [==============================] - 1s 31ms/step - loss: 0.1537 - tp: 18820.0000 - fp: 804.0000 - tn: 19802.0000 - fn: 1534.0000 - accuracy: 0.9429 - precision: 0.9590 - recall: 0.9246 - auc: 0.9847 - val_loss: 0.0926 - val_tp: 71.0000 - val_fp: 847.0000 - val_tn: 44645.0000 - val_fn: 6.0000 - val_accuracy: 0.9813 - val_precision: 0.0773 - val_recall: 0.9221 - val_auc: 0.9863
Epoch 37/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.1518 - tp: 18999.0000 - fp: 781.0000 - tn: 19671.0000 - fn: 1509.0000 - accuracy: 0.9441 - precision: 0.9605 - recall: 0.9264 - auc: 0.9850 - val_loss: 0.0912 - val_tp: 71.0000 - val_fp: 837.0000 - val_tn: 44655.0000 - val_fn: 6.0000 - val_accuracy: 0.9815 - val_precision: 0.0782 - val_recall: 0.9221 - val_auc: 0.9864
Epoch 38/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1514 - tp: 18991.0000 - fp: 764.0000 - tn: 19655.0000 - fn: 1550.0000 - accuracy: 0.9435 - precision: 0.9613 - recall: 0.9245 - auc: 0.9852 - val_loss: 0.0900 - val_tp: 71.0000 - val_fp: 834.0000 - val_tn: 44658.0000 - val_fn: 6.0000 - val_accuracy: 0.9816 - val_precision: 0.0785 - val_recall: 0.9221 - val_auc: 0.9859
Epoch 39/1000
20/20 [==============================] - 1s 28ms/step - loss: 0.1499 - tp: 18879.0000 - fp: 760.0000 - tn: 19789.0000 - fn: 1532.0000 - accuracy: 0.9440 - precision: 0.9613 - recall: 0.9249 - auc: 0.9856 - val_loss: 0.0890 - val_tp: 71.0000 - val_fp: 835.0000 - val_tn: 44657.0000 - val_fn: 6.0000 - val_accuracy: 0.9815 - val_precision: 0.0784 - val_recall: 0.9221 - val_auc: 0.9861
Epoch 40/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1495 - tp: 19031.0000 - fp: 748.0000 - tn: 19636.0000 - fn: 1545.0000 - accuracy: 0.9440 - precision: 0.9622 - recall: 0.9249 - auc: 0.9856 - val_loss: 0.0883 - val_tp: 71.0000 - val_fp: 841.0000 - val_tn: 44651.0000 - val_fn: 6.0000 - val_accuracy: 0.9814 - val_precision: 0.0779 - val_recall: 0.9221 - val_auc: 0.9862
Epoch 41/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1470 - tp: 19029.0000 - fp: 716.0000 - tn: 19696.0000 - fn: 1519.0000 - accuracy: 0.9454 - precision: 0.9637 - recall: 0.9261 - auc: 0.9861 - val_loss: 0.0878 - val_tp: 71.0000 - val_fp: 848.0000 - val_tn: 44644.0000 - val_fn: 6.0000 - val_accuracy: 0.9813 - val_precision: 0.0773 - val_recall: 0.9221 - val_auc: 0.9863
Epoch 42/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1433 - tp: 19007.0000 - fp: 724.0000 - tn: 19745.0000 - fn: 1484.0000 - accuracy: 0.9461 - precision: 0.9633 - recall: 0.9276 - auc: 0.9868 - val_loss: 0.0868 - val_tp: 71.0000 - val_fp: 846.0000 - val_tn: 44646.0000 - val_fn: 6.0000 - val_accuracy: 0.9813 - val_precision: 0.0774 - val_recall: 0.9221 - val_auc: 0.9862
Epoch 43/1000
20/20 [==============================] - 0s 23ms/step - loss: 0.1469 - tp: 19050.0000 - fp: 748.0000 - tn: 19668.0000 - fn: 1494.0000 - accuracy: 0.9453 - precision: 0.9622 - recall: 0.9273 - auc: 0.9862 - val_loss: 0.0853 - val_tp: 71.0000 - val_fp: 824.0000 - val_tn: 44668.0000 - val_fn: 6.0000 - val_accuracy: 0.9818 - val_precision: 0.0793 - val_recall: 0.9221 - val_auc: 0.9862
Epoch 44/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1438 - tp: 18976.0000 - fp: 749.0000 - tn: 19781.0000 - fn: 1454.0000 - accuracy: 0.9462 - precision: 0.9620 - recall: 0.9288 - auc: 0.9867 - val_loss: 0.0842 - val_tp: 71.0000 - val_fp: 818.0000 - val_tn: 44674.0000 - val_fn: 6.0000 - val_accuracy: 0.9819 - val_precision: 0.0799 - val_recall: 0.9221 - val_auc: 0.9852
Epoch 45/1000
20/20 [==============================] - 1s 32ms/step - loss: 0.1401 - tp: 19026.0000 - fp: 703.0000 - tn: 19756.0000 - fn: 1475.0000 - accuracy: 0.9468 - precision: 0.9644 - recall: 0.9281 - auc: 0.9875 - val_loss: 0.0836 - val_tp: 71.0000 - val_fp: 816.0000 - val_tn: 44676.0000 - val_fn: 6.0000 - val_accuracy: 0.9820 - val_precision: 0.0800 - val_recall: 0.9221 - val_auc: 0.9853
Epoch 46/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1416 - tp: 19128.0000 - fp: 736.0000 - tn: 19625.0000 - fn: 1471.0000 - accuracy: 0.9461 - precision: 0.9629 - recall: 0.9286 - auc: 0.9873 - val_loss: 0.0827 - val_tp: 71.0000 - val_fp: 816.0000 - val_tn: 44676.0000 - val_fn: 6.0000 - val_accuracy: 0.9820 - val_precision: 0.0800 - val_recall: 0.9221 - val_auc: 0.9853
Epoch 47/1000
18/20 [==========================>...] - ETA: 0s - loss: 0.1357 - tp: 17173.0000 - fp: 618.0000 - tn: 17810.0000 - fn: 1263.0000 - accuracy: 0.9490 - precision: 0.9653 - recall: 0.9315 - auc: 0.9886Restoring model weights from the end of the best epoch.
20/20 [==============================] - 1s 27ms/step - loss: 0.1354 - tp: 19048.0000 - fp: 688.0000 - tn: 19817.0000 - fn: 1407.0000 - accuracy: 0.9489 - precision: 0.9651 - recall: 0.9312 - auc: 0.9885 - val_loss: 0.0811 - val_tp: 71.0000 - val_fp: 797.0000 - val_tn: 44695.0000 - val_fn: 6.0000 - val_accuracy: 0.9824 - val_precision: 0.0818 - val_recall: 0.9221 - val_auc: 0.9855
Epoch 00047: early stopping
plot_metrics(resampled_history)
train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
resampled_results = resampled_model.evaluate(test_features, test_labels,
batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
print(name, ': ', value)
print()
plot_cm(test_labels, test_predictions_resampled)
loss : 0.09073417115702664
tp : 77.0
fp : 1036.0
tn : 55841.0
fn : 8.0
accuracy : 0.981672
precision : 0.06918239
recall : 0.90588236
auc : 0.9753755
Legitimate Transactions Detected (True Negatives): 55841
Legitimate Transactions Incorrectly Detected (False Positives): 1036
Fraudulent Transactions Missed (False Negatives): 8
Fraudulent Transactions Detected (True Positives): 77
Total Fraudulent Transactions: 85
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')
plot_roc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right')
Imbalanced data classification is an inherently difficult task since there are so few samples to learn from. You should always start with the data: do your best to collect as many samples as possible, and give substantial thought to which features may be relevant so the model can get the most out of your minority class. At some point your model may struggle to improve and yield the results you want, so it is important to keep in mind the context of your problem and the trade-offs between different types of errors.