用sqlite3实现稀疏矩阵

用python实现了一个稀疏矩阵。
基本思想是用三元组(行坐标、列坐标、值)来描述矩阵。
将3元组保存在sqlite3的内存表里。

代码如下:
import sqlite3
class SparseMatrix:
    """Sparse matrix stored as (row, column, value) triples in SQLite.

    Every stored cell is one row of an in-memory SQLite table keyed by
    (x, y) = (row, column); cells that are not stored read back as 0.0.
    Results of arithmetic are pruned of near-zero cells by update().

    The default dimensions (2**31 - 1) are so large that the matrix size
    effectively need not be known up front.
    """

    def __init__(self, row_count=2147483647, column_count=2147483647):
        self.db = sqlite3.connect(":memory:")
        self.db.execute(
            "CREATE TABLE matrix (x INTEGER, y INTEGER, value REAL, "
            "PRIMARY KEY (x, y))"
        )
        self.row = row_count        # valid row indices are 0 .. row-1
        self.column = column_count  # valid column indices are 0 .. column-1

    def __del__(self):
        # __init__ may have failed before self.db was assigned.
        db = getattr(self, "db", None)
        if db is not None:
            db.close()

    def _check_index(self, index):
        """Validate a (row, column) subscript and return it as a 2-tuple.

        Raises:
            TypeError: index is not a tuple.
            IndexError: index does not have exactly two parts, or a part is
                negative or not below the matrix dimension.
        """
        if not isinstance(index, tuple):
            raise TypeError("matrix indices must be a (row, column) tuple")
        if len(index) != 2:
            raise IndexError("expected 2 indices, got %d" % len(index))
        row, column = index
        if not (0 <= row < self.row and 0 <= column < self.column):
            raise IndexError("index (%r, %r) out of range" % (row, column))
        return row, column

    def _check_same_shape(self, other):
        """Raise unless other is a SparseMatrix with the same dimensions.

        Raises:
            TypeError: other is not a SparseMatrix.
            IndexError: the row or column counts differ.
        """
        if not isinstance(other, SparseMatrix):
            raise TypeError("operand must be a SparseMatrix")
        if self.row != other.row or self.column != other.column:
            raise IndexError("matrix dimensions do not match")

    def __getitem__(self, index):
        """Return the value at (row, column), or 0.0 if the cell is not stored."""
        row, column = self._check_index(index)
        found = self.db.execute(
            "SELECT value FROM matrix WHERE x=? AND y=?", (row, column)
        ).fetchone()
        return found[0] if found is not None else 0.0

    def __setitem__(self, index, value):
        """Store value at (row, column), replacing any previous value.

        Zero values are kept until update() prunes them.
        """
        row, column = self._check_index(index)
        self.db.execute(
            "INSERT OR REPLACE INTO matrix VALUES (?, ?, ?)", (row, column, value)
        )

    def __add__(self, other):  # self + other
        """Return a new matrix holding self + other (same dimensions required)."""
        self._check_same_shape(other)
        m = self.copy()
        for r, c, v in other:
            m[r, c] = m[r, c] + v
        m.update()
        return m

    def __iadd__(self, other):  # self += other
        """Add other into self in place and return self."""
        self._check_same_shape(other)
        # Materialize other first: when other is self, writing to the table
        # while a SELECT on the same connection is still stepping is undefined.
        for r, c, v in list(other):
            self[r, c] = self[r, c] + v
        self.update()
        return self

    def __sub__(self, other):  # self - other
        """Return a new matrix holding self - other (same dimensions required)."""
        self._check_same_shape(other)
        m = self.copy()
        for r, c, v in other:
            m[r, c] = m[r, c] - v
        m.update()
        return m

    def __mul__(self, other):
        """Return self * other for a SparseMatrix or an int/float scalar.

        The matrix product has dimensions (self.row, other.column).

        Raises:
            TypeError: other is neither a SparseMatrix nor an int/float.
        """
        if isinstance(other, SparseMatrix):
            return self._matmul(other)
        if isinstance(other, (int, float)):
            return self._scale(other)
        raise TypeError("cannot multiply SparseMatrix by %r" % type(other))

    def __rmul__(self, other):
        """Support scalar * matrix; matrix products are handled by __mul__."""
        if isinstance(other, (int, float)):
            return self._scale(other)
        raise TypeError("cannot multiply %r by SparseMatrix" % type(other))

    def _matmul(self, other):
        """Return the matrix product self * other as a new SparseMatrix."""
        # NOTE(review): like the original, no check that self.column equals
        # other.row; entries of other outside self's column range simply never
        # match. Add a check if callers always size their operands exactly.
        # Group other's stored cells by row once, so each cell (r, k) of self
        # is combined only with the cells (k, c) of other it multiplies.
        other_rows = {}
        for k, c, w in other:
            other_rows.setdefault(k, []).append((c, w))
        sums = {}
        for r, k, v in self:  # ordered by (r, k): each sum accumulates in k order
            for c, w in other_rows.get(k, ()):
                sums[(r, c)] = sums.get((r, c), 0.0) + v * w
        m = SparseMatrix(self.row, other.column)
        for (r, c), total in sums.items():
            m[r, c] = total
        m.update()
        return m

    def _scale(self, factor):
        """Return a new matrix with every stored cell multiplied by factor."""
        m = SparseMatrix(self.row, self.column)
        for r, c, v in self:
            m[r, c] = v * factor
        m.update()
        return m

    def __iter__(self):
        """Yield stored cells as (row, column, value), ordered by row then column."""
        cursor = self.db.execute("SELECT x, y, value FROM matrix ORDER BY x, y")
        for cell in cursor:
            yield cell

    def __len__(self):
        """Return the number of stored cells (including zeros not yet pruned)."""
        return self.db.execute("SELECT COUNT(*) FROM matrix").fetchone()[0]

    def insert(self, cells):
        """Store each (row, column, value) in cells, then prune near-zero cells."""
        for row, column, value in cells:
            self[row, column] = value
        self.update()

    def copy(self):
        """Return an independent copy with the same dimensions and cells."""
        m = SparseMatrix(self.row, self.column)
        # Cells were validated when stored in self, so bulk-insert them directly.
        m.db.executemany("INSERT INTO matrix VALUES (?, ?, ?)", self)
        return m

    def getrow(self, row):
        """Return [(column, value), ...] for stored cells in row, by column."""
        return self.db.execute(
            "SELECT y, value FROM matrix WHERE x=? ORDER BY y", (row,)
        ).fetchall()

    def getcolumn(self, column):
        """Return [(row, value), ...] for stored cells in column, by row."""
        return self.db.execute(
            "SELECT x, value FROM matrix WHERE y=? ORDER BY x", (column,)
        ).fetchall()

    def getallrows(self):
        """Return a sorted tuple of row indices that have stored cells (may be empty)."""
        cursor = self.db.execute("SELECT DISTINCT x FROM matrix ORDER BY x")
        return tuple(x for (x,) in cursor)

    def getallcolumns(self):
        """Return a sorted tuple of column indices that have stored cells (may be empty)."""
        cursor = self.db.execute("SELECT DISTINCT y FROM matrix ORDER BY y")
        return tuple(y for (y,) in cursor)

    def update(self, tolerance=1e-7):
        """Delete cells whose value lies within [-tolerance, tolerance], then commit."""
        self.db.execute(
            "DELETE FROM matrix WHERE value BETWEEN ? AND ?", (-tolerance, tolerance)
        )
        self.db.commit()
        

支持下标访问,支持矩阵的加法、减法和乘法运算,支持遍历。
如:
# Example: build two matrices, add them, and print every stored cell.
# The %-formatted print() call prints the same "row column value" line on
# Python 2 and Python 3 (the original print statement is Python-2-only).
m1 = SparseMatrix()
m2 = SparseMatrix()
m1[0, 0] = 1
m2[1, 1] = 2
m = m1 + m2
for row, column, value in m:
    print("%s %s %s" % (row, column, value))


只是实现了功能,速度方面就比较头疼了。
做一次N×N矩阵与N维向量的乘法用了6个多小时(N=520000)。
还有一个支持多CPU并行乘法的版本,写的太难看,就不贴出来了。

不过有一个优点是不需要明确的知道矩阵的大小,在某些情况下还是有一些用处吧,我想。

你可能感兴趣的:(C++,c,python,sqlite,C#)