本文是《CS231n: Convolutional Neural Networks for Visual Recognition》课程的学习笔记。
课程讲义:http://cs231n.stanford.edu/slides/2017/cs231n_2017_lecture2.pdf
知识点:
图像分类是计算机视觉中最重要的一个任务,简单的说,就是给你一张包含某一特定物体的图像,让你识别出这个图像包含的物体是什么。
比如说,给了一张猫的照片,我们可以非常简单的识别出这是猫,这是因为我们的大脑做了非常多的关于这方面的学习。如果让电脑来识别这张图像,那就完全不一样了,因为电脑和我们看到的东西是完全不一样的,电脑看到的是记录着这张图像的所有的像素。
从计算机的角度,图像其实是一个个像素组成的,那么,即使仍然是一只猫的图像,由于下面情况的出现,会导致像素完全不同:
图像分类的API,可以用如下的代码表示:
def classify_image(image):
# some classify code here (magic).
return class_label
由于前面说到的在猫的识别上存在很多的问题,所以这里很难去定义一套算法来做识别。
为了解决这类问题,John Canny在1986年就通过检测特征点来提取猫的边缘的方式来完成猫的识别,这种方式对猫的识别可能有用,但是如果换了一种物体,则需要重新进行计算和检测了。
那现在,我们可以通过数据驱动的方法(Data-Driven Approach)来完成相关的任务:
def train(images, labels):
# Machine Learning
return model
def predict(model, test_images):
# Use model to predict labels
return test_labels
最邻近法在训练数据过程中只是单纯的把数据记录下来,在预测阶段则是将数据与训练数据进行对比来找出最相似的结果。
这里有一个比较好的训练数据叫做CIFAR10,包含10个分类,总共有5万张训练图像、1万张测试图像。
在比较图像时,我们一般会用到两种距离:
下面以L1距离为例。
L1 Distance为 d 1 ( I 1 − I 2 ) = ∑ ( I 1 p − I 2 p ) d1(I_1 - I_2) = \sum(I_1^p - I_2^p) d1(I1−I2)=∑(I1p−I2p)
即按照对应的像素来求差值的绝对值,最后所有的差值求和。
前面说过,训练数据阶段是读取并存储数据阶段,代码为:
def train(self, X, y):
self.train_X = X
self.train_y = y
预测阶段则是将数据与训练数据进行对比来找出最相似的结果,这里就需要用到L1 Distance了:
for i in xrange(num_test):
distances = np.sum(np.abs(self.train_X - X[i, :]), axis=1)
min_index = np.argmin(distances)
y_predict[i] = self.train_y[min_indix]
从最邻近算法可以看出,在训练阶段其时间复杂度为O(1),但是在预测阶段其时间复杂度为O(N),这种情况使得最邻近算法在现实生产中不可用,因为我们需要保证在用户能接触到的预测阶段的时间复杂度较低(运算较快)。
在后面要学到的卷积神经网络中,时间复杂度与最邻近法恰好相反,其是在训练阶段时间复杂度很大而在预测阶段很小。
除了从最邻近的点来做判断,也可以指定K参数,即从最邻近的K个点来做参考,也称为K邻近算法。
具体可以通过该Demo来查看http://vision.stanford.edu/teaching/cs231n-demos/knn/
从上面的例子中可以看到随着K和距离方式这两个参数的选取不同,最终得到的结果也不同,而这两个参数也称为「超参数」,即人为设定的参数,这些参数会随着问题的不同而选用不同的值,并且需要不断的尝试。
在设定最邻近值函数的超参数时,有几种设置训练数据的方式:
总结一下,现实中基本不会使用K最邻近算法来做图像分类,主要是三个原因:
关于第2、3点,我目前还不懂,还需要再做一些研究。
线性分类(Linear Classifiers)是神经网络的基础,
线性分类是一种参数化的方法,输入图像x,然后应用参数(权重)w,经过一个函数计算,最后得到不同分类的得分(下图中是分为10类)。
以CIFAR-10为例,线性分类主要是将输入数据的所有特征做一个平均,最终得到一个统一的模板来尝试识别不同的分类。所以得到的分类器不会太好。❓(这部分的理解也不好)
在通过线性分类得到不同分类的得分后,我们如何来判断这个W参数的效果呢?这就涉及到损失函数和优化了,后面章节再来探讨。