Knn 模型的训练和 加载

训练:

def main():

    (train, train_labels), (test, test_labels) = XX_load_data()
    train = np.array(train / 255., dtype=np.float32)
    test = np.array(test / 255., dtype=np.float32)
    knn = cv2.ml.KNearest_create()

    print(train.shape, train_labels.shape)
    print(test.shape, test_labels.shape)

    knn.train(train, cv2.ml.ROW_SAMPLE, train_labels)
    # knn.save("XXX")

其中  XX_load_data  为自己定义的导入 数据 公式

预测:

 knn = cv2.ml.KNearest_load(r"XXX")
    (train, train_labels), (test, test_labels) = Cap_load_data()
    print(test.shape, test_labels.shape)
    test = np.array(test / 255., dtype=np.float32)
    train = np.array(train / 255., dtype=np.float32)

    # src = test[0].reshape(0, 1)

    # print(test[0])
    ret, result, neighbours, dist = knn.findNearest(test, k=3)

最重要的就是这个模型加载了:

knn = cv2.ml.KNearest_load(r"E:\AOItest\AOIpyOtherMethods\Knn\Cap_knn.xml")
之前是这样写的:

knn = cv2.ml.KNearest_create()
knn.load(r"XXX")

这样报错:

test_samples.type() == CV_32F && test_samples.cols == samples.cols in function 'cv::ml::BruteForceImpl::findNearest'

还有需要注意的是 预测输入的 test应为 2维

你可能感兴趣的:(#,机器学习,Knn,模型加载,分类)