目录
model.fit的作用
model.fit的示例
model.fit可用于以指定的迭代次数训练模型。可以设置的参数很多,重点理解黄色标注的参数,这些比较常用。
x=None,
y=None,
batch_size=None,
epochs=1,
verbose='auto',
callbacks=None,
validation_split=0.0,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
validation_batch_size=None,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False
x是输入数据,数据类型可以是:numpy array, tensor, tf.data的dataset,返回(输入,目标)或者(输入,目标,权重)的生成器 。或者由numpy array\tensor组成的列表, 值为numpy array\tensor,键为名称字符串的字典。
y是目标数据(也可称标签),数据类型和x一致。如果输入数据x是dataset或者生成器类型,则不能指定y,因为目标数据可以从x中获取。
batch_size是指每个batch的样本数量。如果没有指定,batch_size为32。如果输入数据x是dataset或者生成器类型,则不能指定batch_size,因为x每次就会生成一组batch数据。
epochs设置训练多少个epoch。一个epoch是指对全部的输入数据都迭代一遍。
verbose可以控制进度条的展示。verbose可选值包括‘auto’,0,1,2。0表示静默,什么都不显示,1表示每个epoch展示1根进度条,2表示每个epoch一行log打印,没有进度条。‘auto’大部分情况下默认值是1。
callbacks可以设置回调函数,内置的回调函数可查tf.keras.callbacks.Callback。可以将训练过程想象成一个生产线,回调函数就是在这个生产线上的一些检测点,可以通过回调函数的结果观测整个生产线内部的信息。比如想看看每个epoch的损失值、准确度,回调函数tf.keras.callbacks.History, tf.keras.callbacks.ProgbarLogger就可以帮助我们查看,这两个回调函数是默认使用的,无需额外传入。
内置的回调函数比较多,常见的有
- EarlyStopping(当检测到某种评价指标不再改进时,停止训练)
- LearningRateScheduler(设置学习率的变化方式)
- TensorBoard(在tensorboard中可视化一些信息)
validation_split是一个0~1之间的浮点数。设置训练数据中多少比例的数据用于做验证数据。当输入数据x为dataset或者生成器时,不支持这项设置。
validation_data设置验证数据,数据类型和输入数据允许的类型一样。数据设置成(输入,目标)这种tuple格式即可。如果同时设置了validation_split和validation_data, 则validation_data才是真正会被用于验证的数据。这部分数据不参与训练,只在每个epoch迭代结束时,用于评测。
shuffle设置是否在每个epoch开始前,随机打乱训练数据的顺序。如果输入数据x是dataset或者生成器,这个参数会被忽略。可以设置成True, False或者batch。设置为batch时,会打乱每个batch中的数据。只在steps_per_epoch为None时,设置成‘batch’才会生效。
class_weight在训练过程中,对不同类别的样本进行加权。意为对某些样本多多留意。
sample_weight可以设置样本的权重,在训练过程中,用于对损失函数中不同样本损失值进行加权。当输入数据是dataset类型时,不支持设置该参数,因为dataset类型可支持(输入,目标,权重)这个权重与这里的sample_weight是同一含义。
initial_epochs初始化epoch次数,可以设置经过多少个epoch后,开始训练。可以保证模型在开始的时候,能够运行。
steps_per_epoch可以设置迭代多少个batch算是迭代完一个epoch。默认情况下,steps_per_epoch是对训练数据进行完整的迭代需要的次数,比如数据总量/batch_size
validation_steps意思与steps_per_epoch差不多,只不过这个是指验证数据。
validataion_batch_size验证数据每个batch包含多少个样本,如果没有指定,则validation_batch_size使用batch_size的值。
validation_freq表示迭代多少个epoch后,做一次验证。可以设置成一个整数,或者列表。设置成整数,如3,表示每迭代3个epoch,做一次验证。设置成列表[1,5,10],表示迭代1个,5个,10个epoch后,各做一次验证。
max_queue_size生成器的最大队列长度,一般不用设置。
workers设置线程数
use_multiprocessing设置是否使用基于进程的线程,一般不用设置。
import tensorflow as tf
# load data
(x_train,y_train),(x_test,y_test) = tf.keras.datasets.fashion_mnist.load_data()
# create model
model = tf.keras.Sequential(
[
tf.keras.layers.Flatten(input_shape=(28,28), name='flatten_1'),
tf.keras.layers.Dense(units=200, activation='relu', name='dense_1'),
tf.keras.layers.Dropout(0.2, name='dropout_1'),
tf.keras.layers.Dense(units=10, name='output')
]
)
# compile
model.compile(
optimizer = 'adam',
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
)
# train
model.fit(
x_train,
y_train,
batch_size = 50,
epochs = 10,
validation_data = (x_test, y_test)
)
'''
Epoch 1/10
1200/1200 [==============================] - 5s 4ms/step - loss: 3.4873 - sparse_categorical_accuracy: 0.6353 - val_loss: 0.8207 - val_sparse_categorical_accuracy: 0.7339
Epoch 2/10
1200/1200 [==============================] - 3s 3ms/step - loss: 0.8110 - sparse_categorical_accuracy: 0.6995 - val_loss: 0.6741 - val_sparse_categorical_accuracy: 0.7494
Epoch 3/10
1200/1200 [==============================] - 3s 2ms/step - loss: 0.7143 - sparse_categorical_accuracy: 0.7346 - val_loss: 0.5563 - val_sparse_categorical_accuracy: 0.8101
Epoch 4/10
1200/1200 [==============================] - 3s 2ms/step - loss: 0.6897 - sparse_categorical_accuracy: 0.7458 - val_loss: 0.5783 - val_sparse_categorical_accuracy: 0.7941
Epoch 5/10
1200/1200 [==============================] - 3s 2ms/step - loss: 0.6713 - sparse_categorical_accuracy: 0.7511 - val_loss: 0.5178 - val_sparse_categorical_accuracy: 0.8183
Epoch 6/10
1200/1200 [==============================] - 3s 2ms/step - loss: 0.6683 - sparse_categorical_accuracy: 0.7534 - val_loss: 0.6003 - val_sparse_categorical_accuracy: 0.7940
Epoch 7/10
1200/1200 [==============================] - 3s 2ms/step - loss: 0.6576 - sparse_categorical_accuracy: 0.7558 - val_loss: 0.5320 - val_sparse_categorical_accuracy: 0.8178
Epoch 8/10
1200/1200 [==============================] - 3s 2ms/step - loss: 0.6411 - sparse_categorical_accuracy: 0.7682 - val_loss: 0.5684 - val_sparse_categorical_accuracy: 0.7848
Epoch 9/10
1200/1200 [==============================] - 3s 2ms/step - loss: 0.6124 - sparse_categorical_accuracy: 0.7742 - val_loss: 0.5932 - val_sparse_categorical_accuracy: 0.8051
Epoch 10/10
1200/1200 [==============================] - 3s 2ms/step - loss: 0.6264 - sparse_categorical_accuracy: 0.7717 - val_loss: 0.5898 - val_sparse_categorical_accuracy: 0.7700
'''