看别人的库是学习写代码的最佳方法。
加了序列的排版真是乱,索性不用序列了。
先看几个最基本的函数
svm_read_problem
def svm_read_problem(data_file_name): """ svm_read_problem(data_file_name) -> [y, x] Read LIBSVM-format data from data_file_name and return labels y and data instances x. """ prob_y = [] prob_x = [] for line in open(data_file_name): line = line.split(None, 1) # In case an instance with all zero features if len(line) == 1: line += [''] label, features = line xi = {} for e in features.split(): ind, val = e.split(":") xi[int(ind)] = float(val) prob_y += [float(label)] prob_x += [xi] return (prob_y, prob_x)
看这个地方
line = line.split(None, 1) # In case an instance with all zero features if len(line) == 1: line += ['']
第一行是将每一行从第一个字符分割,分割出来就是标签和特征了,程序里把标签称为label,把特征称为data instances
,返回的两个prob_y、prob_x分别是标签和特征向量,前者是一个list<float>,后者是一个list<dict>。
svm_problem结构体
class svm_problem(Structure): _names = ["l", "y", "x"] _types = [c_int, POINTER(c_double), POINTER(POINTER(svm_node))] _fields_ = genFields(_names, _types) def __init__(self, y, x, isKernel=None): if len(y) != len(x): raise ValueError("len(y) != len(x)") self.l = l = len(y) max_idx = 0 x_space = self.x_space = [] for i, xi in enumerate(x): tmp_xi, tmp_idx = gen_svm_nodearray(xi,isKernel=isKernel) x_space += [tmp_xi] max_idx = max(max_idx, tmp_idx) self.n = max_idx self.y = (c_double * l)() for i, yi in enumerate(y): self.y[i] = yi self.x = (POINTER(svm_node) * l)() for i, xi in enumerate(self.x_space): self.x[i] = xi
这个类可牵扯到了不少函数:
genFields
def genFields(names, types): return list(zip(names, types))
zip函数是python内置函数,用法看这里,作者写得非常简明扼要,这里简单摘录一小段:
基本用法:
x = [1, 2, 3] y = [4, 5, 6] z = [7, 8, 9] xyz = zip(x, y, z) print xyz # 运行结果:[(1, 4, 7), (2, 5, 8), (3, 6, 9)],即竖着对齐下来
示例2(长度调整):
x = [1, 2, 3] y = [4, 5, 6, 7] xy = zip(x, y)print xy # 运行结果:[(1, 4), (2, 5), (3, 6)]
示例3:
x = [1, 2, 3] y = [4, 5, 6] z = [7, 8, 9] xyz = zip(x, y, z) u = zip(*xyz)print u # 运行的结果是:[(1, 2, 3), (4, 5, 6), (7, 8, 9)]
作者的注释:(*的用法很有趣)
一般认为这是一个unzip的过程,它的运行机制是这样的: 在运行zip(*xyz)之前,xyz的值是:[(1, 4, 7), (2, 5, 8), (3, 6, 9)] 那么,zip(*xyz) 等价于 zip((1, 4, 7), (2, 5, 8), (3, 6, 9)) 所以,运行结果是:[(1, 2, 3), (4, 5, 6), (7, 8, 9)] 注:在函数调用中使用*list/tuple的方式表示将list/tuple分开,作为位置参数传递给对应函数(前提是对应函数支持不定个数的位置参数)
svm_problem中其余的几个函数可以看下面,svm_problem本身是一个结构体,定义了l,y,x,l是一个int,表示样本的个数,y是double型指针,也就是一个c数组。c数组的定义都是这个样子(我不知道这个叫什么,所以称他为c数组):
self.y = (c_double * l)() self.x = (POINTER(svm_node) * l)()
x则是指针的指针,这也很好理解,x实际上是一个序列,序列的每一项又表示一串svm_node,自然就是指针的指针,但是我不知道为什么python里面要用指针,而不直接拿list,应该是为了速度。
综上,svm_problem在定义的时候初始化了三个东西:一个是样本的个数,一个一维c数组,一个二维c数组。
gen_svm_nodearray
svm_node是一个结构体,表示一个特征,有两个成员,分别是特征的index和value,看函数原型:
def gen_svm_nodearray(xi, feature_max=None, isKernel=None): if isinstance(xi, dict): index_range = xi.keys() elif isinstance(xi, (list, tuple)): if not isKernel: xi = [0] + xi # idx should start from 1 index_range = range(len(xi)) else: raise TypeError('xi should be a dictionary, list or tuple') if feature_max: assert(isinstance(feature_max, int)) index_range = filter(lambda j: j <= feature_max, index_range) if not isKernel: index_range = filter(lambda j:xi[j] != 0, index_range) index_range = sorted(index_range) ret = (svm_node * (len(index_range)+1))() ret[-1].index = -1 for idx, j in enumerate(index_range): ret[idx].index = j ret[idx].value = xi[j] max_idx = 0 if index_range: max_idx = index_range[-1] return ret, max_idx
这个是上面的类里调用的一个函数。看几个有趣的地方:
if isinstance(xi, dict): index_range = xi.keys() elif isinstance(xi, (list, tuple)): if not isKernel: xi = [0] + xi # idx should start from 1 index_range = range(len(xi))
这里的xi就是每一行的那个dict,index是特征的序号,value是特征的值。他分别判断了xi是否为dict类型和list/tuple类型,说明这里的输入参数xi不要求一定是dict,也可以是按index顺序的列表或元组,当然svm_read_problem返回的那个prob_x直接就是一个dict了。
后面又用了几个高级函数,其实看高手的库和看正宗的英文原版小说一样,处处都是惊艳的用词。
if feature_max: assert(isinstance(feature_max, int)) index_range = filter(lambda j: j <= feature_max, index_range) if not isKernel: index_range = filter(lambda j:xi[j] != 0, index_range)
assert,断言feature_max是一个int,否则就会报AssertionError的错。
filter是一个用来挑选list(tuple、string)中符合条件的元素的函数,原型在此:
def filter(function_or_none, sequence): # known special case of filter """ filter(function or None, sequence) -> list, tuple, or string Return those items of sequence for which function(item) is true. If function is None, return the items that are true. If sequence is a tuple or string, return the same type, else return a list. """ pass
符合function的元素会被挑选出来,用来过滤字符串确实很方便。lambda用来做简单的判断再好不过了。
后面的那几句话就是申请了一个svm_node数组,然后用enumerate的方法将其中的每一个结构体赋值。enumerate是python建议使用的方法,range(0,len(array))这样的,不建议使用。原因也很简单:python的长处在于一字千金,总是把C里面的那套写法背着,就是与python理念相悖的。
最后,gen_svm_nodearray会返回一个装满svm_node的数组和一个最大的id号。
小结:
好了,这里已经解剖完了两个函数了,看看有哪些值得学习的地方:
enumerate
filter
isinstance
zip
assert
指针