TensorFlow 2.0 model的predict()方法详解以及自定义predict实现预测数据和真实数据配对输出

TensorFlow模型训练过程中fit()可以直接设置validation_data为test数据集来测试模型的性能。但是通常我们要输出模型的预测值,用来绘制图形等等操作。接下来详细介绍tensorflow的模型预测方法。

# 模型类方法
predict(
    x,
    batch_size=None,
    verbose=0,
    steps=None,
    callbacks=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False
)

参数:

  • x:样本自变量数据X。支持多种类型可以是:

    • A Numpy array (or array-like), or a list of arrays. 注意如果模型是多输入的,那么X就是多个array的list,且list长度和输入的个数相等。输入array的前后顺序要一致。

    例如有一个model的信息如下:
    model = keras.Model( inputs=[image_input, timeseries_input], outputs=[score_output, class_output] )
    那么预测时也应该这么写:
    model.predict([image_input, timeseries_input])

    • A TensorFlow tensor, or a list of tensors.
    • A tf.data.Dataset. A generator or keras.utils.Sequence instance.

    Dataset通常是即包括了X也包括了Y。不用自己分离。tf框架可以自动的处理。

  • batch_size:也是默认32. 当X是Dataset时,不需要指定。

  • verbose:提示信息模式。0 or 1. 0代表沉默模式,1代表提示模式(可以输出一些过程信息到控制台)。

  • steps:这个参数规定了预测的步数,决定model预测batch_size*steps数据。一般指定为None,None时忽略。X是 Dataset时, predict() 将对X的所有数据进行预测。

    这个使用的场景,我认为是当数据很大,也不想拆分数据集的时候,就指定steps,进行有限的预测。

  • callbacks:预测过程中的回调函数设置

在实际使用的时候,如果输入x是DataSet,(Dataset在使用前是必须batch的)可以自动shuffle,发现每一次运行predict(),x的数据也是会重新打乱。所以每一次predict() 输出预测数据的顺序是不一样的。

predict()的x数据类型最好是跟fit()传入的数据类型是一致。

predict()也是基于批处理式的。如果predict()时输入数据的shape[0]或者是batch_size和fit()的batch_size不一致,就会报错。原因predict()也会调用model的__call__()
predict()方法适用于对大量数据输出其预测值,但它不适合用于迭代数据和一次处理少量输入的循环中。对于适合少量数据,直接使用model的__call__()更快:model(x),或model(x, training=False)

predict()与model的__call__()的关系:

y = model.predict(x)y = model(x)都可以运行model返回输出y。predict()对数据进行批量循环(实际上,可以通过predict(x, batch_size=64)指定批量大小),它提取输出的NumPy值。它大致相当于这个:

def predict(x):
    y_batches = []
    for x_batch in get_batches(x):
        y_batch = model(x).numpy()
        y_batches.append(y_batch)
    return np.concatenate(y_batches)

这意味着predict()调用可以扩展到非常大的数组数据的预测而不占用大量的内容。model(x)是在内存中运行的,对大量数据不太友好。另外,这个predict是不可微分的,如果在GradientTape管理内调用它,则无法检索其梯度。当需要检索模型调用的梯度时,应该使用model(x)。

由于tensorflow使用Dataset的数据类型进行模型的训练更加方便,所以输入x一般都是Dataset。Dataset数据最大的问题在于,其batch化是迭代器模式的。每次迭代完元素再次迭代,数据的顺序会重新shuffle,所以会造成对同样的数据的两次predict()得到的预测值的顺序完全不一致。又由于predict() 不会自己进行真实数据和预测数据的匹配。这对于我们获得预测值和其对应的真实值是有困难的。

因为predict()调用一次算是完成了一次Dataset的迭代。而我们如果再去获得Dataset里的真实数据,是一次新的迭代。数据前后顺序不一致

在这种情况下,需要使用model的__call__()方法来自定义model的predict(),实现模型的预测值和真实值的配对输出,参考如下:

def predict(model,db_test,batchz=None,verbose=True):
    '''
    自定义的model的预测方法,主要是实现配对返回预测值和真实值
    :param model:
    :param db_test:
    :param batchz:当db已经进行了batch,这里就不需要赋值
    :param verbose:
    :return:
    '''
    y_pre = np.array([])
    y_tru = np.array([])
    for elem in db_test.as_numpy_iterator():
    	# 注意,这里的model要非训练模式
        batch_y_pre=model(elem[0],training=False).numpy().flatten()
        batch_y_tru=elem[1].flatten()
        y_pre = np.insert(y_pre, len(y_pre), batch_y_pre)
        y_tru = np.insert(y_tru, len(y_tru), batch_y_tru)
    if verbose:
        for pre,tru in zip(y_pre,y_tru):
            print(f"{model.name}的预测值{pre}---->真实值{tru}")
    return y_pre,y_tru

网络模型model,其结构不方便透露,调用自定义的predict()的效果如下:

y_pre,y_tru=predict(model,db_test)
my_rnn的预测值3.8517441749572754---->真实值3.7889999999999997
my_rnn的预测值4.515721797943115---->真实值4.473
my_rnn的预测值4.1115193367004395---->真实值3.963
my_rnn的预测值4.863704204559326---->真实值4.811
my_rnn的预测值3.9370434284210205---->真实值4.306
my_rnn的预测值5.3716230392456055---->真实值5.442
my_rnn的预测值3.5435032844543457---->真实值3.3819999999999997
my_rnn的预测值3.988416910171509---->真实值3.734
my_rnn的预测值5.273722171783447---->真实值5.211
my_rnn的预测值3.800370693206787---->真实值3.3569999999999998
my_rnn的预测值3.845928192138672---->真实值3.7680000000000002
my_rnn的预测值4.800698757171631---->真实值4.5169999999999995
my_rnn的预测值3.680176258087158---->真实值3.6439999999999997
my_rnn的预测值4.042698383331299---->真实值4.492
y_pre
array([3.85174417, 4.5157218 , 4.11151934, 4.8637042 , 3.93704343,
        5.37162304, 3.54350328, 3.98841691, 5.27372217, 3.80037069,
        3.84592819, 4.80069876, 3.68017626, 4.04269838])
y_tru
 array([3.789, 4.473, 3.963, 4.811, 4.306, 5.442, 3.382, 3.734, 5.211,
        3.357, 3.768, 4.517, 3.644, 4.492])

你可能感兴趣的:(#,Tensorflow,python,tensorflow,神经网络)