《Python数据分析与挖掘实战》第6章代码问题 Sequential‘ object has no attribute ‘predict_classes‘

问题描述

在第6章中,使用predict_classes会造成报错,经过查询,高版本的tensorflow中已经不存在predict_classes,因此在查询之后决定修改为以下内容:

# 原来的代码
predict_result = net.predict_classes(train[:, :3]).reshape(len(train))  
#修改后的代码
predict_result = np.argmax(net.predict(train[:, :3]), axis=1)

修改之后再次运行代码,发现能够正常运行,但是绘制的混淆矩阵与书中内容相差较大
《Python数据分析与挖掘实战》第6章代码问题 Sequential‘ object has no attribute ‘predict_classes‘_第1张图片
经过检查后,打印输出predict_result,发现结果如下图所示:
《Python数据分析与挖掘实战》第6章代码问题 Sequential‘ object has no attribute ‘predict_classes‘_第2张图片
可以发现np.argmax()输出结果全部为0,经过查询后再次对代码进行修改

# 第一次修改后的代码
predict_result = np.argmax(net.predict(train[:, :3]), axis=1)
# 再次修改后的代码
predict_result = (net.predict(train[:, :3]) > 0.5).astype("int32")

最后再次运行代码,发现能够正常输出,结果正常。
绘制的混淆矩阵如下图所示,

《Python数据分析与挖掘实战》第6章代码问题 Sequential‘ object has no attribute ‘predict_classes‘_第3张图片

你可能感兴趣的:(笔记,python,数据分析,开发语言)