This post is mainly a translation of the Keras Tuner documentation: https://keras-team.github.io/keras-tuner/documentation/tuners/
GitHub repository: https://github.com/keras-team/keras-tuner
Keras Tuner is a distributed hyperparameter optimization framework that searches a user-defined hyperparameter space for the best configuration. It ships with Bayesian Optimization, Hyperband, and Random Search algorithms.
The original docs only sketch the examples and the code does not run as-is, so I changed quite a bit, mainly:
My GPU could not keep up: the HyperXception and HyperResNet models ran for an entire night with no end in sight, so I stopped them. Even after cutting the CIFAR-10 training set from 50,000 down to 10,000 samples they still seem to take a long time. Since this post is only meant as an introduction to Keras Tuner, I did not run them to completion; if you are interested, give them a try and share your results.
Keras Tuner is a distributed hyperparameter optimization framework for tuning Keras models, in particular tf.keras with TensorFlow 2.0. It makes it easy to define a search space and comes with built-in algorithms (Bayesian Optimization, Hyperband, and Random Search) to find good hyperparameter values. Full documentation and tutorials are available on the Keras Tuner website.
Requirements: Python 3.6+ and TensorFlow 2.0+.
Install with pip:
pip install -U keras-tuner
Or install from source:
git clone https://github.com/keras-team/keras-tuner.git
cd keras-tuner
pip install .
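To verify the installation, here is a minimal sanity check (assuming the installed package exposes a __version__ attribute, as the pip releases do):
import kerastuner
print(kerastuner.__version__)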
The example below shows how to use random search to find good hyperparameters for a single-layer dense neural network. First, define a model-building function. It takes an hp argument from which you can sample hyperparameters, such as hp.Int('units', min_value=32, max_value=512, step=32) (an integer from a given range). The function returns a compiled model.
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
# Load the MNIST dataset.
(x_train, y_train), (x_val, y_val) = mnist.load_data()
x_train = np.expand_dims(x_train.astype('float32') / 255, -1)
x_val = np.expand_dims(x_val.astype('float32') / 255, -1)
# One-hot encoding is not needed with sparse_categorical_crossentropy:
# y_train = to_categorical(y_train, 10)
# y_val = to_categorical(y_val, 10)
from kerastuner.tuners import RandomSearch
# Build the model. The hp argument defines the ranges of the
# hyperparameters to tune, which together form the search space.
def build_model(hp):
    model = keras.Sequential()
    model.add(layers.Input(shape=(28, 28, 1)))
    model.add(layers.Flatten())
    model.add(layers.Dense(units=hp.Int('units',
                                        min_value=32,
                                        max_value=512,
                                        step=32),
                           activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    model.compile(
        optimizer=keras.optimizers.Adam(
            hp.Choice('learning_rate',
                      values=[1e-2, 1e-3, 1e-4])),
        # loss='categorical_crossentropy',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
    return model
Next, instantiate a tuner. You specify the model-building function, the name of the objective to optimize (whether to minimize or maximize is inferred automatically for built-in metrics), the total number of trials to run (max_trials), and the number of models that should be built and fit for each trial (executions_per_trial). Available tuners include RandomSearch and Hyperband.
Note: the purpose of running several executions per trial is to reduce result variance, so a model's performance can be measured more precisely. If you want results faster, set executions_per_trial=1 to train each model configuration only once.
# Use random search.
tuner = RandomSearch(
    build_model,
    objective='val_accuracy',  # the objective to optimize, 'val_accuracy' (maximized)
    max_trials=5,              # try 5 hyperparameter configurations in total
    executions_per_trial=3,    # train the model 3 times per trial
    directory='my_dir',
    project_name='helloworld')
You can print a summary of the search space:
tuner.search_space_summary()
Search space summary
|-Default search space size: 2
units (Int)
|-conditions: []
|-default: None
|-max_value: 512
|-min_value: 32
|-sampling: None
|-step: 32
learning_rate (Choice)
|-conditions: []
|-default: 0.01
|-ordered: True
|-values: [0.01, 0.001, 0.0001]
Then start the search for the best hyperparameter configuration. The call to search has the same signature as model.fit().
tuner.search(x_train, y_train,
             epochs=5,
             validation_data=(x_val, y_val))
Output (heavily truncated; the per-step progress lines are far too noisy to keep):
Train on 60000 samples, validate on 10000 samples
Epoch 1/5
60000/60000 [==============================] - 8s 141us/sample - loss: 0.2749 - accuracy: 0.9234 - val_loss: 0.1597 - val_accuracy: 0.9536
Epoch 2/5
60000/60000 [==============================] - 8s 131us/sample - loss: 0.1322 - accuracy: 0.9613 - val_loss: 0.1164 - val_accuracy: 0.9655
Epoch 3/5
...
WARNING:tensorflow:Can save best model only with val_accuracy available, skipping.
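Since search forwards its keyword arguments to model.fit, fit-style options can be passed as well. A minimal sketch, assuming the standard tf.keras EarlyStopping callback, to cut short unpromising executions:
tuner.search(x_train, y_train,
             epochs=5,
             validation_data=(x_val, y_val),
             # stop an execution early when val_loss stops improving
             callbacks=[keras.callbacks.EarlyStopping(monitor='val_loss', patience=1)])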
Here is what happens during the search: models are built repeatedly by calling the model-building function, each time with a hyperparameter configuration sampled from the search space tracked by hp. The tuner progressively explores the space and records the evaluation result of each configuration.
When the search is over, you can retrieve the best model(s):
# Return the best two models.
models = tuner.get_best_models(num_models=2)
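You can also retrieve the best hyperparameter values themselves and rebuild a fresh, untrained model from them; a sketch using the tuner's get_best_hyperparameters method:
# Hyperparameters of the top trial.
best_hp = tuner.get_best_hyperparameters(num_trials=1)[0]
print(best_hp.get('units'), best_hp.get('learning_rate'))
model = build_model(best_hp)  # a fresh model with the winning configuration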
You can also print a summary of the results:
tuner.results_summary()
Results summary
|-Results in my_dir/helloworld
|-Showing 10 best trials
|-Objective(name='val_accuracy', direction='max')
Trial summary
|-Trial ID: 71fc41aef4fc34c049c2f3b22a74252f
|-Score: 0.9792666435241699
|-Best step: 0
Hyperparameters:
|-learning_rate: 0.001
|-units: 256
Trial summary
|-Trial ID: b184c03f4c418071edd3b5afa390f952
|-Score: 0.9789333343505859
|-Best step: 0
Hyperparameters:
|-learning_rate: 0.001
|-units: 224
Trial summary
|-Trial ID: 4111f77d4d668a6a593030c902074bec
|-Score: 0.9765666127204895
|-Best step: 0
Hyperparameters:
|-learning_rate: 0.001
|-units: 128
Trial summary
|-Trial ID: 37e520f0bf10bbb4f3a32806700dce93
|-Score: 0.9545333385467529
|-Best step: 0
Hyperparameters:
|-learning_rate: 0.0001
|-units: 160
Trial summary
|-Trial ID: 9ec337e0d6d8ff7e7c53fdf9931557a5
|-Score: 0.9280333518981934
|-Best step: 0
Hyperparameters:
|-learning_rate: 0.0001
|-units: 32
Detailed logs and checkpoints can be found in the folder where the models are saved, my_dir/helloworld in this example, i.e. directory/project_name.
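Because all trial state lives in that folder, constructing a tuner again with the same directory and project_name in a later session reloads the saved results instead of starting from scratch; a minimal sketch:
# Same directory/project_name: picks up the five trials recorded above.
tuner2 = RandomSearch(
    build_model,
    objective='val_accuracy',
    max_trials=5,
    executions_per_trial=3,
    directory='my_dir',
    project_name='helloworld')
tuner2.results_summary()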
The search space may contain conditional hyperparameters. Below, a for loop creates a tunable number of layers, each of which has its own tunable units parameter. This can be pushed to any level of parameter interdependency, including recursion.
Note that all parameter names must be unique (here, for the i-th loop iteration, the inner parameter is named 'units_' + str(i)); a quick way to inspect the resulting space is shown after the code.
# Build the model.
def build_model(hp):
    model = keras.Sequential()
    # Flatten the MNIST images so the Dense layers below can run
    # on the data loaded earlier.
    model.add(layers.Input(shape=(28, 28, 1)))
    model.add(layers.Flatten())
    # Tune the number of layers.
    for i in range(hp.Int('num_layers', 2, 20)):
        # Each layer's parameter gets a unique name.
        model.add(layers.Dense(units=hp.Int('units_' + str(i),
                                            min_value=32,
                                            max_value=512,
                                            step=32),
                               activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    model.compile(
        optimizer=keras.optimizers.Adam(
            hp.Choice('learning_rate', [1e-2, 1e-3, 1e-4])),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
    return model
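To see how the conditional space unfolds, point a tuner at this function and print the space; a sketch (the project name helloworld_loop is arbitrary):
tuner_loop = RandomSearch(
    build_model,
    objective='val_accuracy',
    max_trials=5,
    directory='my_dir',
    project_name='helloworld_loop')
tuner_loop.search_space_summary()
# Before any search, only 'num_layers', 'units_0', 'units_1' and 'learning_rate'
# are listed: the default num_layers is its minimum (2), and further 'units_i'
# entries get registered as trials sample larger values of 'num_layers'.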
Instead of a model-building function, you can use a subclass of HyperModel. This makes it easy to share and reuse hypermodels. A HyperModel subclass only needs to implement a build(self, hp) method.
from kerastuner import HyperModel

class MyHyperModel(HyperModel):

    def __init__(self, num_classes):
        self.num_classes = num_classes

    def build(self, hp):
        model = keras.Sequential()
        model.add(layers.Input(shape=(28, 28, 1)))
        model.add(layers.Flatten())
        model.add(layers.Dense(units=hp.Int('units',
                                            min_value=32,
                                            max_value=512,
                                            step=32),
                               activation='relu'))
        model.add(layers.Dense(self.num_classes, activation='softmax'))
        model.compile(
            optimizer=keras.optimizers.Adam(
                hp.Choice('learning_rate',
                          values=[1e-2, 1e-3, 1e-4])),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy'])
        return model

hypermodel = MyHyperModel(num_classes=10)

tuner = RandomSearch(
    hypermodel,
    objective='val_accuracy',
    max_trials=10,
    directory='my_dir',
    project_name='helloworld1')

tuner.search(x_train, y_train,
             epochs=5,
             validation_data=(x_val, y_val))
Keras Tuner includes pre-made tunable applications: HyperResNet and HyperXception. These are ready-to-use hypermodels for computer vision. They come pre-compiled with loss="categorical_crossentropy" and metrics=["accuracy"].
# Load the data.
from tensorflow.keras.datasets import cifar10

NUM_CLASSES = 10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# The full dataset trains too slowly on my hardware, so reduce it.
x_train = x_train[:10000]
x_test = x_test[:2000]
y_train = to_categorical(y_train, NUM_CLASSES)[:10000]
y_test = to_categorical(y_test, NUM_CLASSES)[:2000]
Use one of the predefined hypermodels:
from kerastuner.applications import HyperResNet
from kerastuner.tuners import Hyperband

hypermodel = HyperResNet(input_shape=(32, 32, 3), classes=10)

tuner = Hyperband(
    hypermodel,
    objective='val_accuracy',
    max_epochs=5,
    directory='my_dir',
    project_name='cifar10_resnet')

tuner.search(x_train, y_train,
             validation_data=(x_test, y_test))
You can easily restrict the search space to tune just a few parameters. If you have an existing hypermodel and only want to search over some of its parameters (such as the learning rate), pass a hyperparameters argument to the tuner constructor together with tune_new_entries=False, which prevents parameters not listed in hyperparameters from being tuned. Those parameters keep their default values instead.
from kerastuner import HyperParameters
from kerastuner.applications import HyperXception
from kerastuner.tuners import Hyperband

hypermodel = HyperXception(input_shape=(32, 32, 3), classes=10)

hp = HyperParameters()
# This will tune the `learning_rate` parameter over the given choices.
hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

tuner = Hyperband(
    hypermodel,
    hyperparameters=hp,
    # `tune_new_entries=False` prevents unlisted parameters from being tuned.
    tune_new_entries=False,
    objective='val_accuracy',
    max_epochs=5,
    directory='my_dir',
    project_name='cifar10_xception')

tuner.search(x_train, y_train,
             validation_data=(x_test, y_test))
Want to know what other parameters are available? Read the source code.
Default parameter values: whenever you register a hyperparameter inside the model-building function or the build method of a hypermodel, you can specify a default value:
hp.Int('units',
       min_value=32,
       max_value=512,
       step=32,
       default=128)
If you don't, hyperparameters always get an implicit default (for Int, it is equal to min_value).
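A small sketch of how these defaults behave on a fresh HyperParameters object (registering a hyperparameter returns its current value, which starts at the default):
from kerastuner import HyperParameters

hp = HyperParameters()
units = hp.Int('units', min_value=32, max_value=512, step=32, default=128)
print(units)  # 128: the explicit default
lr = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])
print(lr)     # 0.01: with no default given, the first choice is used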
What if you want the opposite: tune all available parameters in a hypermodel except one (such as the learning rate)? Pass a hyperparameters argument containing one (or more) Fixed entries, and set tune_new_entries=True.
hypermodel = HyperXception(input_shape=(32, 32, 3), classes=10)

hp = HyperParameters()
hp.Fixed('learning_rate', value=1e-4)

tuner = Hyperband(
    hypermodel,
    hyperparameters=hp,
    tune_new_entries=True,
    objective='val_accuracy',
    max_epochs=5,
    directory='my_dir',
    project_name='cifar10_xception1')

tuner.search(x_train, y_train,
             validation_data=(x_test, y_test))
If you have a hypermodel but want to override its compiled optimizer, loss, or metrics, you can pass those arguments to the tuner constructor as well:
hypermodel = HyperXception(input_shape=(32, 32, 3), classes=10)

tuner = Hyperband(
    hypermodel,
    optimizer=keras.optimizers.Adam(1e-3),
    loss='mse',
    metrics=[keras.metrics.Precision(name='precision'),
             keras.metrics.Recall(name='recall')],
    objective='val_precision',  # must match the metric name given above
    max_epochs=5,
    directory='my_dir',
    project_name='cifar10_xception2')

tuner.search(x_train, y_train,
             validation_data=(x_test, y_test))
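Once any of these searches finishes, the best model can be pulled out and evaluated on the held-out split; a sketch reusing the CIFAR-10 data above:
best_model = tuner.get_best_models(num_models=1)[0]
print(best_model.evaluate(x_test, y_test))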