libsvm for python学习(1)

看别人的库是学习写代码的最佳方法。

加了序列的排版真是乱,索性不用序列了。


先看几个最基本的函数

svm_read_problem

def svm_read_problem(data_file_name):
	"""
	svm_read_problem(data_file_name) -> [y, x]

	Read LIBSVM-format data from data_file_name and return labels y
	and data instances x.
	"""
	prob_y = []
	prob_x = []
	for line in open(data_file_name):
		line = line.split(None, 1)
		# In case an instance with all zero features
		if len(line) == 1: line += ['']
		label, features = line
		xi = {}
		for e in features.split():
			ind, val = e.split(":")
			xi[int(ind)] = float(val)
		prob_y += [float(label)]
		prob_x += [xi]
	return (prob_y, prob_x)

看这个地方

line = line.split(None, 1)
# In case an instance with all zero features
if len(line) == 1: line += ['']

第一行是将每一行从第一个字符分割,分割出来就是标签和特征了,程序里把标签称为label,把特征称为data instances

,返回的两个prob_y、prob_x分别是标签和特征向量,前者是一个list<float>,后者是一个list<dict>。

svm_problem结构体

class svm_problem(Structure):
    _names = ["l", "y", "x"]
	_types = [c_int, POINTER(c_double), POINTER(POINTER(svm_node))]
	_fields_ = genFields(_names, _types)

	def __init__(self, y, x, isKernel=None):
		if len(y) != len(x):
			raise ValueError("len(y) != len(x)")
		self.l = l = len(y)

		max_idx = 0
		x_space = self.x_space = []
		for i, xi in enumerate(x):
			tmp_xi, tmp_idx = gen_svm_nodearray(xi,isKernel=isKernel)
			x_space += [tmp_xi]
			max_idx = max(max_idx, tmp_idx)
		self.n = max_idx

		self.y = (c_double * l)()
		for i, yi in enumerate(y): self.y[i] = yi

		self.x = (POINTER(svm_node) * l)()
		for i, xi in enumerate(self.x_space): self.x[i] = xi

   这个类可牵扯到了不少函数:

genFields

def genFields(names, types):
	return list(zip(names, types))

zip函数是python内置函数,用法看这里,作者写得非常简明扼要,这里简单摘录一小段:

    基本用法:

x = [1, 2, 3]
y = [4, 5, 6]
z = [7, 8, 9]
xyz = zip(x, y, z)
print xyz

# 运行结果:[(1, 4, 7), (2, 5, 8), (3, 6, 9)],即竖着对齐下来

    示例2(长度调整):

x = [1, 2, 3]
y = [4, 5, 6, 7]
xy = zip(x, y)print xy

# 运行结果:[(1, 4), (2, 5), (3, 6)]

    示例3:

x = [1, 2, 3]
y = [4, 5, 6]
z = [7, 8, 9]
xyz = zip(x, y, z)
u = zip(*xyz)print u

# 运行的结果是:[(1, 2, 3), (4, 5, 6), (7, 8, 9)]

 作者的注释:(*的用法很有趣

一般认为这是一个unzip的过程,它的运行机制是这样的:
在运行zip(*xyz)之前,xyz的值是:[(1, 4, 7), (2, 5, 8), (3, 6, 9)]
那么,zip(*xyz) 等价于 zip((1, 4, 7), (2, 5, 8), (3, 6, 9))
所以,运行结果是:[(1, 2, 3), (4, 5, 6), (7, 8, 9)]
注:在函数调用中使用*list/tuple的方式表示将list/tuple分开,作为位置参数传递给对应函数(前提是对应函数支持不定个数的位置参数)

svm_problem中其余的几个函数可以看下面,svm_problem本身是一个结构体,定义了l,y,x,l是一个int,表示样本的个数,y是double型指针,也就是一个c数组。c数组的定义都是这个样子(我不知道这个叫什么,所以称他为c数组):

self.y = (c_double * l)()
self.x = (POINTER(svm_node) * l)()

x则是指针的指针,这也很好理解,x实际上是一个序列,序列的每一项又表示一串svm_node,自然就是指针的指针,但是我不知道为什么python里面要用指针,而不直接拿list,应该是为了速度。

综上,svm_problem在定义的时候初始化了三个东西:一个是样本的个数,一个一维c数组,一个二维c数组。

gen_svm_nodearray

svm_node是一个结构体,表示一个特征,有两个成员,分别是特征的index和value,看函数原型:

def gen_svm_nodearray(xi, feature_max=None, isKernel=None):
	if isinstance(xi, dict):
		index_range = xi.keys()
	elif isinstance(xi, (list, tuple)):
		if not isKernel:
			xi = [0] + xi  # idx should start from 1
		index_range = range(len(xi))
	else:
		raise TypeError('xi should be a dictionary, list or tuple')

	if feature_max:
		assert(isinstance(feature_max, int))
		index_range = filter(lambda j: j <= feature_max, index_range)
	if not isKernel:
		index_range = filter(lambda j:xi[j] != 0, index_range)

	index_range = sorted(index_range)
	ret = (svm_node * (len(index_range)+1))()
	ret[-1].index = -1
	for idx, j in enumerate(index_range):
		ret[idx].index = j
		ret[idx].value = xi[j]
	max_idx = 0
	if index_range:
		max_idx = index_range[-1]
	return ret, max_idx

这个是上面的类里调用的一个函数。看几个有趣的地方:

if isinstance(xi, dict):
    index_range = xi.keys()
elif isinstance(xi, (list, tuple)):
    if not isKernel:
	xi = [0] + xi  # idx should start from 1
	index_range = range(len(xi))

这里的xi就是每一行的那个dict,index是特征的序号,value是特征的值。他分别判断了xi是否为dict类型和list/tuple类型,说明这里的输入参数xi不要求一定是dict,也可以是按index顺序的列表或元组,当然svm_read_problem返回的那个prob_x直接就是一个dict了。

后面又用了几个高级函数,其实看高手的库和看正宗的英文原版小说一样,处处都是惊艳的用词。

if feature_max:		
    assert(isinstance(feature_max, int))
    index_range = filter(lambda j: j <= feature_max, index_range)
if not isKernel:
    index_range = filter(lambda j:xi[j] != 0, index_range)

assert,断言feature_max是一个int,否则就会报AssertionError的错。

filter是一个用来挑选list(tuple、string)中符合条件的元素的函数,原型在此:

def filter(function_or_none, sequence): # known special case of filter
    """
    filter(function or None, sequence) -> list, tuple, or string
    
    Return those items of sequence for which function(item) is true.  If
    function is None, return the items that are true.  If sequence is a tuple
    or string, return the same type, else return a list.
    """
    pass

符合function的元素会被挑选出来,用来过滤字符串确实很方便。lambda用来做简单的判断再好不过了。

后面的那几句话就是申请了一个svm_node数组,然后用enumerate的方法将其中的每一个结构体赋值。enumerate是python建议使用的方法,range(0,len(array))这样的,不建议使用。原因也很简单:python的长处在于一字千金,总是把C里面的那套写法背着,就是与python理念相悖的。

最后,gen_svm_nodearray会返回一个装满svm_node的数组和一个最大的id号。


小结:

好了,这里已经解剖完了两个函数了,看看有哪些值得学习的地方:

  1. enumerate

  2. filter

  3. isinstance

  4. zip

  5. assert

  6. 指针


你可能感兴趣的:(libsvm for python学习(1))