基于python通过SVM+SURF实现花卉图像分类

数据集为牛津大学库里的17类花卉图像提取码c4s4,该程序的思路是参考手势识别的项目所修改。

1. 提取所有花卉图像的SIFT特征

opencv里有直接调用sift特征提取的函数,下列操作是将所有类别图像文件夹遍历,批量提取sift特征,并将特征量化到一个文本文件中方便后续操作。

path = './' + 'feature' + '/' #保存特征的路径
path_img = './' + 'image' + '/' #数据集路径

def calcSURFFeature(img):
	gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
	sift = cv2.xfeatures2d.SURF_create(200) #特征点数量限制自行设置 
	kps, des = sift.detectAndCompute(gray, None)
	return des
	
if __name__ == "__main__":
    for i in range(1, 18): #读取17类图像的文件夹
        for j in range(1, 21): #读取每类文件夹的图像数量
            #文件夹路径+命名形式
            #数据集内图像名称未按顺序命名,建议先改名后提取特征
            roi = cv2.imread(path_img + str(i) + '_' + str(j) + '.jpg') 
            descirptor_in_use = fd.calcSURFFeature(roi)
            fd_name = path + str(i) + '_' + str(j) + '.txt' #形成特征
            with open(fd_name, 'w', encoding='utf-8') as f:
                for k in range(1, len(descirptor_in_use)):
                    x_record = descirptor_in_use[k]
                    f.write(str(x_record))
                    f.write(' ')
                f.write('\n')
            print(i, '_', j, '完成')

2. 将文本文件中特征形式统一

由于上一步得出的文本中的surf特征都是矩阵形式,为了后续方便分类,将文本中不必要的字符都删除掉。

path = './' + 'feature' + '/'
xmls = glob.glob(path+'*.txt')

for one_xml in xmls:
    print(one_xml)
    f = open(one_xml, 'r+', encoding='utf-8')
    all_the_lines = f.readlines()
    f.seek(0)
    f.truncate()
    for line in all_the_lines:
        line = line.replace('[', '')
        line = line.replace(']', '')
        line = line.replace(' ', '')
        line = line.replace('  ', '')
        line = line.replace('   ', '')
        line = line.replace('\n', '')
        line = line.replace('.', ' ')
        f.write(line)
    f.close()

3. 用SVM分类

先将文本文件中数据归一化

#路径设置
path = './' + 'feature' + '/'
model_path = "./model/"
test_path = "./test_feature/"

test_accuracy = []

def txtToVector(filename, N):
	returnVec = np.zeros((1, N))
	fr = open(filename)
	lineStr = fr.readline()
	lineStr = lineStr.split(' ')
	for i in range(N):
		returnVec[0, i] = int(lineStr[i])
	return returnVec

再到SVM分类器的训练过程,SVM参数设定很重要,参数设置要求越高时间越长,网上有很多SVM参数设定的文章。训练过程中有很多数据用不了,可以直接删除。数据能否使用,与后面设定的N值有关,可以进行debug形式查看哪些数据不能用,但过程比较繁琐。

def tran_SVM(N):
	svc = SVC()
	parameters = {'kernel': ('linear', 'rbf'),
				  'C': [1, 3, 5, 7, 9],
				  'gamma': [0.0001, 0.001, 0.1, 1, 10, 100]}  
	hwLabels = []  # 存放类别标签
	trainingFileList = listdir(path)
	m = len(trainingFileList)
	trainingMat = np.zeros((m, N))
	for i in range(m):
		fileNameStr = trainingFileList[i]
		classNumber = int(fileNameStr.split('_')[0])
		hwLabels.append(classNumber)
		trainingMat[i, :] = txtToVector(path + fileNameStr, N)  
	print("完成")
	clf = GridSearchCV(svc, parameters, cv=5, n_jobs=4) #此处参数设定也很重要
	clf.fit(trainingMat, hwLabels)
	best_model = clf.best_estimator_
	print("SVM模型保存中...")
	save_path = model_path + "svm_efd_" + "train_model.m"
	joblib.dump(best_model, save_path)  # 保存最好的模型

def test_SVM(clf, N):
	testFileList = listdir(test_path)
	errorCount = 0  # 记录错误个数
	mTest = len(testFileList)
	for i in range(mTest):
		fileNameStr = testFileList[i]
		classNum = int(fileNameStr.split('_')[0])
		vectorTest = txtToVector(test_path + fileNameStr, N)
		valTest = clf.predict(vectorTest)
		if valTest != classNum:
			errorCount += 1
	print("总共错了%d个数据\n错误率为%f%%" % (errorCount, errorCount / mTest * 100))

if __name__ == "__main__": #此处N值设定影响训练结果
	tran_SVM(50000)
	clf = joblib.load(model_path + "svm_efd_" + "train_model.m")
	test_SVM(clf,50000)

4. 总结

关于N值的设定我调了很久,N值越大结果的精确度越高,但也对数据有很多限制,因此也删除了许多数据,我用AMD训练时间花了一个多小时。大家可以尝试其他特征提取后再做SVM分类,我用SURF特征提取精确度并不高,不知道有哪些位置可以改正以提高识别率,欢迎交流。

你可能感兴趣的:(python,机器学习,分类算法)