error: (-215:Assertion failed) samples.cols == var_count && samples.type() == CV_32F in function ‘解决

背景(可直接跳过,主体代码都不用看,直接看下面的原因)

问题是出在我作为一个实训项目组长的时候的。组内任务是编写一个Django前端框架、cv摄像头、knn及svm机器学习算法三个模块整合在一起的项目,其中一个组员​​​​​​​在编写svm的预测部分时报了如下错误

我们的代码如下:

import cv2
import numpy as np

SZ = 20
bin_n = 16

affine_flags = cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR


# 使用图片的二阶矩对图片进行抗扭斜处理
def deskew(img):
    m = cv2.moments(img)  # 获取图片的矩
    if abs(m['mu02']) < 1e-2:
        return img.copy()
    skew = m['mu11'] / m['mu02']
    M = np.float32([[1, skew, -0.5 * SZ * skew], [0, 1, 0]])
    img = cv2.warpAffine(img, M, (SZ, SZ), flags=affine_flags)
    return img


# 计算图像的hog描述符
def get_hog():
    hog = cv2.HOGDescriptor((20,20), (8, 8), (4, 4), (8, 8), 9, 1, -1, 0, 0.2, 1, 64, True)
    print("hog descriptor size: {}".format(hog.getDescriptorSize()))
    return hog


# 获取图片
img = cv2.imread('3.jpg', 0)

# 读取模型
model = cv2.ml.SVM_load('svm_data.dat')

# 对图像进行处理
bin_norm = deskew(img)
hog = get_hog()
sample = hog.compute(bin_norm)

# 预测结果
digit = model.predict(sample)
print(digit)

 完整报错如下:

Traceback (most recent call last):
  File "C:\*反正是桌面*\zhouge\testv1tosee.py", line 40, in 
    digit = model.predict(sample)
cv2.error: OpenCV(4.4.0) C:\Users\appveyor\AppData\Local\Temp\1\pip-req-build-95hbg2jt\opencv\modules\ml\src\svm.cpp:2013: error: (-215:Assertion failed) samples.cols == var_count && samples.type() == CV_32F in function 'cv::ml::SVMImpl::predict'

报错解决方法

***原因***

报错的原因是hog的形状不对

至于hog的形状应该是怎么样的,应该和之前的训练有关我不是具体负责这块的,只介绍一下(看了好多其他文章都没提到的)解决问题的思路

解决

把get_hog函数换成下面的就可以了:

def hog(img):
    gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)
    gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)
    mag, ang = cv2.cartToPolar(gx, gy)
    bins = np.int32(bin_n * ang / (2 * np.pi))
    bin_cells = bins[:10, :10], bins[10:, :10], bins[:10, 10:], bins[10:, 10:]
    mag_cells = mag[:10, :10], mag[10:, :10], mag[:10, 10:], mag[10:, 10:]
    hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)]
    hist = np.hstack(hists)
    return hist

完整代码如下:

import cv2
import numpy as np

SZ = 20
bin_n = 16

affine_flags = cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR


# 使用图片的二阶矩对图片进行抗扭斜处理
def deskew(img):
    m = cv2.moments(img)  # 获取图片的矩
    if abs(m['mu02']) < 1e-2:
        return img.copy()
    skew = m['mu11'] / m['mu02']
    M = np.float32([[1, skew, -0.5 * SZ * skew], [0, 1, 0]])
    img = cv2.warpAffine(img, M, (SZ, SZ), flags=affine_flags)
    return img

'''  原来的生成hog的函数
# 计算图像的hog描述符
def get_hog():
    hog = cv2.HOGDescriptor((20,20), (8, 8), (4, 4), (8, 8), 9, 1, -1, 0, 0.2, 1, 64, True)
    # print("hog descriptor size: {}".format(hog.getDescriptorSize()))
    return hog
'''


# 新的计算hog的函数
def hog(img):
    gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)
    gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)
    mag, ang = cv2.cartToPolar(gx, gy)
    bins = np.int32(bin_n * ang / (2 * np.pi))
    bin_cells = bins[:10, :10], bins[10:, :10], bins[:10, 10:], bins[10:, 10:]
    mag_cells = mag[:10, :10], mag[10:, :10], mag[:10, 10:], mag[10:, 10:]
    hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)]
    hist = np.hstack(hists)
    return hist


# 获取图片
img = cv2.imread('2.jpg', 0)
print(img.shape)

# 读取模型
model=cv2.ml.SVM_load('svm_data.dat')

# 对图像进行处理
bin_norm=deskew(img)
sample=hog(deskew(img))
trainData = np.float32(sample).reshape(-1, 64)  # 注意,这里又reshape了一下

# 预测结果
digit=model.predict(trainData)
print(int(digit[-1][0]))

后记

我们组是分工协作,遇到这个问题一开始我还没留意,以为还是像其他问题一样,顶多过半天负责这个的人就解决了。

没想到这个问题卡了1、2天,最后还差点卡了我们组上交的进度。

不过真正上手解决,大概也就1个白天的时间。

作为程序员最大的收获就是学会了看源代码。倒不是说看源码就能直接解决,看源码最重要的是给你解决问题的决心。像这个问题,我用的是python,一开始看源码,cv2的库不好弄,折腾了半天,折腾完看到源码后都是这样的:

def SVM_load(filepath): # real signature unknown; restored from __doc__
    """
    SVM_load(filepath) -> retval
    .   @brief Loads and creates a serialized svm from a file
    .        *
    .        * Use SVM::save to serialize and store an SVM to disk.
    .        * Load the SVM from this file again, by calling this function with the path to the file.
    .        *
    .        * @param filepath path to serialized svm
    """
    pass

代码里面写个pass,这代表这个代码具体是用C或C++写的。

最后我在github上找到了源码。放一段:

    ...
    float predict( InputArray _samples, OutputArray _results, int flags ) const CV_OVERRIDE
    {
        float result = 0;
        Mat samples = _samples.getMat(), results;
        int nsamples = samples.rows;
        bool returnDFVal = (flags & RAW_OUTPUT) != 0;

        CV_Assert( samples.cols == var_count && samples.type() == CV_32F );

        if( _results.needed() )
        {
            _results.create( nsamples, 1, samples.type() );
            results = _results.getMat();
        }
        else
        {
            CV_Assert( nsamples == 1 );
            results = Mat(1, 1, CV_32F, &result);
        }
    ...

这段就定位到了我错误出现的地方。

看到源码,就要完全读懂吗?不是的。其实在出问题的这里,你只要简单了解一下CV_Assert是cv库自置的报错的一个函数,括号内bool值为false就报错;以及Mat是类似numpy的array,可以存图之类的,你这里的工作其实就已经做到了做完了。

真看完源代码啊?你可以去自己研究库了。

但在这看问题解决方法的,难道不都是连自己代码问题都没搞明白才来求助于网络的人吗?

着手之处,还得是解决手头上的问题。

那么,看完这段、理解完报错出现的地方,我们知道了什么?是确认了报错就是因为传入的参数形状不对,是让你有底气判断错误的方向特别是,网上其他的报错解决贴没有提到这样的解决方法,需要自己作出对问题的抉择时,这样的信心就尤为重要。

确定方向后,解决就很容易上手了。我就把工作交给组员,不到一上午他就做出来了。之前他被这个问题卡着,一方面是因为网上没有现成的方案,另一方面也是自己没有解决问题的方向。而看源码能够弥补前者,为后者创造条件

其他的收获就是作为组长的了。不过,前面这些写的也够了,就不多说了。

另外本人在这。

你可能感兴趣的:(opencv,python,svm,源代码管理,c++)