機器學習-異常點檢測

1. 異常檢測算法

無監督異常點檢測算法常見的有三種,分別是:Local Outlier Factor(LOF)-局部異常因子算法、Isolation Forest(iForest)-孤立森林、One-Class SVM-一類支持向量機。

 

2. 模型訓練

import cv2
import os
import numpy as np
from sklearn.ensemble import IsolationForest
from sklearn.externals import joblib
from sklearn.neighbors import LocalOutlierFactor
from sklearn.svm import OneClassSVM


def load_data(img_path):
    """Load every readable image in a directory as a flattened pixel vector.

    Args:
        img_path: Directory whose entries are each passed to ``cv2.imread``.

    Returns:
        ``np.ndarray`` built from one flattened vector per decodable image.
        An empty directory (or one with no decodable files) yields an empty
        array.

    Note:
        The vectors only stack into a 2-D matrix when every image has the
        same number of pixels; images of differing sizes are not resized here.
    """
    rows = []
    for name in os.listdir(img_path):
        file_path = os.path.join(img_path, name)
        img = cv2.imread(file_path)
        # cv2.imread reports an unreadable / non-image file by returning None
        # rather than raising, so guard before touching the array.
        if img is None:
            print('skip unreadable image:', file_path)
            continue
        # Optional normalization (needed when training OneClassSVM per the
        # author's note): use ``img.reshape(-1) / 255.`` instead.
        rows.append(img.reshape(-1))
    return np.array(rows)


if __name__ == '__main__':
    # Training entry point: fit an outlier detector on the images in the
    # training folder and persist the fitted model to disk.
    train_dir = r'D:\data'
    model_file = r'D:\MODEL\L_IT_069.model'

    features = load_data(train_dir)

    detector = IsolationForest(behaviour='new', contamination=0.001)
    # Alternative detectors the author experimented with:
    #   LocalOutlierFactor(n_neighbors=20, contamination=0.00001)
    #   OneClassSVM(kernel='rbf', nu=0, gamma=0.1)  # requires normalized data
    # NOTE(review): nu=0 looks out of range for OneClassSVM -- confirm before
    # enabling that variant.
    detector.fit(features)

    joblib.dump(detector, model_file)

3. 模型預測

import os, cv2
import numpy as np
from sklearn.externals import joblib
import shutil
from collections import Counter


def model_predict(model_path, img_path, save_path):
    """Score every image in a directory with a saved outlier detector.

    Each decodable image is flattened to a pixel vector and passed to the
    model's ``predict``. File names predicted as ``-1`` are printed, as are
    the names of files that could not be decoded. Finally the share of
    directory entries that were both readable and not flagged ``-1`` is
    printed under the label ``acc``.

    Args:
        model_path: Path of the model file written by ``joblib.dump``.
        img_path: Directory containing the images to score.
        save_path: Currently unused; it was the destination of the disabled
            "move flagged images" step.

    Returns:
        The string ``'no img'`` when the directory is empty or none of its
        files could be decoded; otherwise ``None`` (results are printed).
    """
    # NOTE: joblib.load unpickles the file -- only load models you trust.
    clf = joblib.load(model_path)

    path_list = os.listdir(img_path)
    if len(path_list) == 0:
        return 'no img'

    unreadable = 0   # files cv2 could not decode; counted as failures below
    pre_list = []    # file names aligned row-for-row with the feature list
    x = []
    for file in path_list:
        file_path = os.path.join(img_path, file)
        img = cv2.imread(file_path)
        # cv2.imread returns None on failure. The previous ``img == []`` test
        # never matched None and is an invalid elementwise comparison for a
        # real image array.
        if img is None:
            print(file)
            unreadable += 1
            continue
        pre_list.append(file)
        x.append(img.reshape(-1))

    if not x:
        return 'no img'

    predict = clf.predict(np.array(x))
    for file, result in zip(pre_list, predict):
        if result == -1:
            print(file)  # name of an image flagged as an outlier (NG)

    # Unreadable files count against the score together with flagged images.
    acc = 1 - (unreadable + Counter(predict)[-1]) / len(path_list)
    print(model_path, ' acc: ', acc)


if __name__ == '__main__':
    # Prediction entry point: score the images in img_path with a saved model.
    model_path = r'D:\model\test.model'
    img_path = r"D:\data\img"
    # The backslash after the drive letter was missing, which made this a
    # drive-relative path instead of a sibling of img_path under D:\data.
    save_path = r'D:\data\save_img'
    model_predict(model_path, img_path, save_path)

 

你可能感兴趣的:(機器學習,机器学习,算法)