用python实现了一个稀疏矩阵。
基本思想是用三元组(行坐标、列坐标、值)来描述矩阵。
这些三元组保存在 sqlite3 内存数据库(":memory:")的表里。
代码如下:
import sqlite3
class SparseMatrix:
    """Sparse matrix stored as (row, column, value) triples in an in-memory SQLite table.

    Only explicitly assigned cells are stored; every other cell reads as 0.0.
    ``update()`` deletes cells whose value lies within a small band around
    zero; the arithmetic operators call it on their results.

    The declared size defaults to 2147483647 x 2147483647, so a matrix can be
    used without knowing its real dimensions in advance.
    """

    def __init__(self, row_count=2147483647, column_count=2147483647):
        """Create an empty matrix with the given declared size.

        Args:
            row_count: number of rows; valid row indices are 0..row_count-1.
            column_count: number of columns; valid column indices are
                0..column_count-1.
        """
        self.db = sqlite3.connect(":memory:")
        # (x, y) is the primary key, so each cell exists at most once;
        # "insert or replace" in __setitem__ relies on this.
        self.db.execute("CREATE TABLE 'matrix' ('x' integer, 'y' integer, 'value' real, primary key('x', 'y'));")
        self.row = row_count
        self.column = column_count

    def __del__(self):
        """Close the underlying SQLite connection."""
        self.db.close()

    def _check_index(self, row, column):
        """Raise IndexError unless (row, column) lies inside the declared size."""
        if not (0 <= row < self.row and 0 <= column < self.column):
            raise IndexError("index (%r, %r) out of range for %dx%d matrix"
                             % (row, column, self.row, self.column))

    def _check_same_shape(self, other):
        """Raise TypeError unless other is a SparseMatrix, IndexError unless sizes match."""
        if not isinstance(other, SparseMatrix):
            raise TypeError
        if self.row != other.row or self.column != other.column:
            raise IndexError

    def __getitem__(self, index):
        """Return the value at index = (row, column); cells never stored read as 0.0.

        Raises:
            TypeError: if index is not a tuple.
            IndexError: if index does not have exactly two elements or lies
                outside the matrix (including negative indices).
        """
        if not isinstance(index, tuple):
            raise TypeError
        if len(index) != 2:
            raise IndexError
        row, column = index
        self._check_index(row, column)
        cursor = self.db.execute("select value from matrix where x=? and y=?", (row, column))
        value = cursor.fetchone()
        if value:
            return value[0]
        return 0.0

    def __setitem__(self, index, value):
        """Store value at index = (row, column), replacing any previous value.

        Zero values are stored too; update() purges near-zero cells.

        Raises:
            IndexError: if (row, column) lies outside the matrix.
        """
        row, column = index
        self._check_index(row, column)
        self.db.execute("insert or replace into matrix values(?,?,?)", (row, column, value))

    def __add__(self, other):  # self + other
        """Return a new matrix holding self + other.

        Raises:
            TypeError: if other is not a SparseMatrix.
            IndexError: if the two matrices have different declared sizes.
        """
        self._check_same_shape(other)
        m = self.copy()
        for r, c, v in other:
            m[r, c] = self[r, c] + v
        m.update()
        return m

    def __iadd__(self, other):  # self += other
        """Add other into self in place and return self.

        Raises:
            TypeError: if other is not a SparseMatrix.
            IndexError: if the two matrices have different declared sizes.
        """
        self._check_same_shape(other)
        # Materialise other's cells first: when other is self (m += m) the
        # loop would otherwise write to the table while a SELECT over that
        # same table is still open.
        for r, c, v in list(other):
            self[r, c] = self[r, c] + v
        self.update()
        return self

    def __sub__(self, other):  # self - other
        """Return a new matrix holding self - other.

        Raises:
            TypeError: if other is not a SparseMatrix.
            IndexError: if the two matrices have different declared sizes.
        """
        self._check_same_shape(other)
        m = self.copy()
        for r, c, v in other:
            m[r, c] = self[r, c] - v
        m.update()
        return m

    def __mul__(self, other):  # self * other
        """Return self * other for a SparseMatrix or an int/float scalar.

        Matrix product: the result has size (self.row, other.column) and
        entry (r, c) is the sum over k of self[r, k] * other[k, c], taken
        over stored cells only. No check is made that self.column equals
        other.row. Scalar product: every stored cell is multiplied.

        Raises:
            TypeError: if other is neither a SparseMatrix nor an int/float.
        """
        if isinstance(other, SparseMatrix):
            # Group other's cells by row once (k -> [(c, value), ...]) so the
            # product needs a single pass over self instead of one column
            # query per (row, column) pair.
            other_rows = {}
            for k, c, v in other:
                other_rows.setdefault(k, []).append((c, v))
            # self is iterated ordered by (row, column), so each sum is
            # accumulated in ascending k order.
            sums = {}
            for r, k, a in self:
                for c, b in other_rows.get(k, ()):
                    sums[(r, c)] = sums.get((r, c), 0.0) + a * b
            m = SparseMatrix(self.row, other.column)
            m.insert((r, c, v) for (r, c), v in sums.items())
            return m
        if isinstance(other, (int, float)):
            m = SparseMatrix(self.row, self.column)
            m.insert((r, c, v * other) for r, c, v in self)
            return m
        raise TypeError

    def __rmul__(self, other):  # other * self
        """Support scalar * matrix by delegating to __mul__.

        Only reached when the left operand is not a SparseMatrix; __mul__
        raises TypeError for anything other than int/float.
        """
        return self.__mul__(other)

    def __iter__(self):
        """Yield stored cells as (row, column, value) tuples, ordered by row then column.

        Cells holding zero that have not yet been purged by update() are
        yielded too.
        """
        cursor = self.db.execute("select x,y,value from matrix order by x,y")
        for cell in cursor:
            yield cell

    def __len__(self):
        """Return the number of stored cells (not the declared size)."""
        cursor = self.db.execute("select count(*) from matrix")
        return cursor.fetchone()[0]

    def insert(self, cells):
        """Store every (row, column, value) triple in cells, then purge near-zero cells.

        Raises:
            IndexError: if any cell lies outside the matrix.
        """
        for row, column, value in cells:
            self[row, column] = value
        self.update()

    def copy(self):
        """Return an independent matrix with the same declared size and stored cells."""
        m = SparseMatrix(self.row, self.column)
        for r, c, v in self:
            m[r, c] = v
        return m

    def getrow(self, row):
        """Return [(column, value), ...] for the stored cells of row, ordered by column."""
        cursor = self.db.execute("select y,value from matrix where x=? order by y", (row,))
        return cursor.fetchall()

    def getcolumn(self, column):
        """Return [(row, value), ...] for the stored cells of column, ordered by row."""
        cursor = self.db.execute("select x,value from matrix where y=? order by x", (column,))
        return cursor.fetchall()

    def getallrows(self):
        """Return a tuple of distinct row indices that have stored cells, ascending.

        Returns an empty tuple for an empty matrix.
        """
        rows = self.db.execute("select distinct x from matrix order by x").fetchall()
        return tuple(r[0] for r in rows)

    def getallcolumns(self):
        """Return a tuple of distinct column indices that have stored cells, ascending.

        Returns an empty tuple for an empty matrix.
        """
        columns = self.db.execute("select distinct y from matrix order by y").fetchall()
        return tuple(c[0] for c in columns)

    def update(self, tolerance=0.0000001):
        """Delete cells whose value lies within [-tolerance, tolerance], then commit.

        Args:
            tolerance: half-width of the band treated as zero (default 1e-7).
        """
        self.db.execute("DELETE FROM matrix where value between ? and ?", (-tolerance, tolerance))
        self.db.commit()
支持下标访问,支持矩阵的加法、减法和乘法运算(乘法包括矩阵相乘和与整数/浮点数相乘),支持遍历。
如:
m1 = SparseMatrix()
m2 = SparseMatrix()
m1[0,0] = 1
m2[1,1] = 2
m = m1 + m2
# Prints "0 0 1.0" then "1 1 2.0": iteration yields stored cells ordered by
# row, then column. The single formatted string prints identically on
# Python 2 and Python 3.
for row,column,value in m:
    print("%s %s %s" % (row, column, value))
只是实现了功能,速度方面就比较头疼了。
做一次 N×N 矩阵与 N 维向量的乘法用了 6 个多小时(N=520000)。
还有一个支持多 CPU 并行乘法的版本,写得太难看,就不贴出来了。
不过有一个优点是不需要明确地知道矩阵的大小,我想在某些情况下还是有一些用处的。