Python 开源项目records库学习

records库源码学习

records 项目地址:https://github.com/kennethreitz/records

  • 该项目是大神kennethreitz写的一个只有500行代码的库
  • 用来入门学习一个开源项目 个人觉得还是很不错的
  • 项目源码名为records.py位于根目录下
  • 作者使用了pipenv来管理相关依赖
  • 你可以fork这个项目后,使用pipenv install安装相关依赖。前提是你已经安装了pipenv

不BB直接上代码?

# -*- coding: utf-8 -*-

import os
from sys import stdout
from collections import OrderedDict
from contextlib import contextmanager
from inspect import isclass

import tablib
from docopt import docopt
from sqlalchemy import create_engine, exc, inspect, text

DATABASE_URL = os.environ.get('DATABASE_URL')


def isexception(obj):
    """Given an object, return a boolean indicating whether it is an instance
    or subclass of :py:class:`Exception`.
    """
    if isinstance(obj, Exception):  # 判断obj是否是一个Exception对象
        return True
    if isclass(obj) and issubclass(obj, Exception):  # 判断obj是否是一个class 并且 是否是Exception的子类
        return True
    return False


class Record(object):
    """A row, from a query, from a database."""
    __slots__ = ('_keys', '_values')  # 简单来讲__slots__属性控制了该类可以绑定的属性只有_keys和_values
                                      # 单下划线保护变量,不能直接访问
    def __init__(self, keys, values):
        self._keys = keys
        self._values = values

        # Ensure that lengths match properly.
        assert len(self._keys) == len(self._values)

    # keys和values方法就是获取保护变量的值的函数
    def keys(self):
        """Returns the list of column names from the query."""
        return self._keys

    def values(self):
        """Returns the list of values from the query."""
        return self._values

    def __repr__(self):  # repr就是对象的显示,与str相似但不同
        return ''.format(self.export('json')[1:-1])  # 切片1:-1是因为这是一个str类型去掉左右的大括号字符串

    def __getitem__(self, key):  # 该方法使该类的对象可以实现索引的功能,类似于:a_list[index]
        # Support for index-based lookup.
        # 如果传入的是int类型的,返回该int索引对应的values
        if isinstance(key, int):
            return self.values()[key]

        # Support for string-based lookup.
        # 如果传入的是字符串字段名称
        if key in self.keys():  # 如果有这个字符串对应的键
            i = self.keys().index(key)  # 根据键获取其在键的序列中的位置
            if self.keys().count(key) > 1:  # 如果有多于1个该键则raise一个Error
                raise KeyError("Record contains multiple '{}' fields.".format(key))
            return self.values()[i]  # 根据在键序列中的位置到值序列中获取值
        # 没有则抛出一个Error
        raise KeyError("Record contains no '{}' field.".format(key))

    def __getattr__(self, key):  # 作用于属性查找的最后一步,用来兜底;
                                 # 这里使用该方法是为了该类的对象除了通过索引,还可以通过.属性来访问value
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(e)

    def __dir__(self):  # 当dir()函数被调用时调用
        standard = dir(super(Record, self))
        # Merge standard attrs with generated ones (from column names).
        return sorted(standard + [str(k) for k in self.keys()])  # 获取了父类object类的属性然后与record的keys作为属性添加进来
                                                                 # sorted内置函数提供了对序列对象的排序功能
    def get(self, key, default=None):  # 本人认为这个get方法的实现类似于dict中的get方法,同样是等效于r.get('A')==r['A']
        """Returns the value for a given key, or default."""
        try:
            return self[key]
        except KeyError:
            return default
    # get __getattr__ __getitem__ 3个方法实现了通过索引,属性,get方法 3种方式获取value值
    def as_dict(self, ordered=False):
        """Returns the row as a dictionary, as ordered."""
        items = zip(self.keys(), self.values())  # zip 来将相互对应的key和value对应组合成多个元组 组成的list
        # items=[(k1,v1),(k2,v2),.....]
        return OrderedDict(items) if ordered else dict(items)  # 如果需要顺序字典则生成顺序字典,否则普通字典

    @property
    def dataset(self):  # 创建一个只读属性dataset,该属性返回一个Tablib的Dataset对象
        """A Tablib Dataset containing the row."""
        data = tablib.Dataset()
        data.headers = self.keys()  # 为Dataset对象设置headers属性为record对象的键

        row = _reduce_datetimes(self.values())  # 处理key对应的values值,如果是datetime类型的values将转换为字符串
        data.append(row)  # 为Dataset添加一条记录

        return data

    def export(self, format, **kwargs):
        """Exports the row to the given format."""
        return self.dataset.export(format, **kwargs)  # 这里调用了Dataset对象的export方法导出数据,导出的为str类型


class RecordCollection(object):  # 记录集类
    """A set of excellent Records from a query."""
    def __init__(self, rows):  # 初始化一些属性
        self._rows = rows
        self._all_rows = []
        self.pending = True

    def __repr__(self):  # 显示
        return ''.format(len(self), self.pending)

    def __iter__(self): # 个人简单理解这里就是一种动态读取数据的方式,应该是为了提高性能吧
        """Iterate over all rows, consuming the underlying generator
        only when necessary."""
        i = 0
        while True:
            # Other code may have iterated between yields,
            # so always check the cache.
            if i < len(self):
                yield self[i]
            else:
                # Throws StopIteration when done.
                # Prevent StopIteration bubbling from generator, following https://www.python.org/dev/peps/pep-0479/
                try:
                    yield next(self)
                except StopIteration:
                    return
            i += 1

    def next(self):  # 单纯的调用__next__()方法
        return self.__next__()

    def __next__(self):  # 实现该方法证明该类的对象也是迭代器
        try:
            nextrow = next(self._rows)  # 迭代类对象的_rows属性的下一个元素
            self._all_rows.append(nextrow)  # 添加_rows迭代到的下一个元素并添加到_all_rows属性中去
            return nextrow  # 也就是说每次迭代该类对象就是迭代其_rows属性的元素
        except StopIteration:
            self.pending = False
            raise StopIteration('RecordCollection contains no more rows.')

    def __getitem__(self, key):  # 定制该类的索引用法
        # key就是传入的索引,[]的索引只能是int和slice两种类型
        is_int = isinstance(key, int)

        # Convert RecordCollection[1] into slice.
        if is_int:
            key = slice(key, key + 1)  # 将你的整数索引转化为一个切片索引,该切片索引切出的就是本身
        
        # key有可能直接就是一个slice类型的对象
        while len(self) < key.stop or key.stop is None:  # 判断如果你访问的索引大于目前_all_rows的长度就迭代该对象去                                                         
            try:                         # 获取下一个元素,之所以要or key.stop is None是因为如果传入的就是一个索引就会出现stop为空
                next(self)  # 调用该对象的__next__()方法,往_all_rows属性中添加元素
            except StopIteration:
                break
        # 上面的while执行完_all_rows属性中就会填充至你所需最大元素的量
        rows = self._all_rows[key]  # rows就是你索引需要的元素
        if is_int:  # 个人认为这里判断传入的是整数索引还是切片索引,key不是int就是slice
            return rows[0]
        else:  # 如果是slice对象则rows就是一个切片后的list,返回的结果就是一个新的RecordCollection对象
            return RecordCollection(iter(rows))

    def __len__(self):  # 要想对象能够使用len()函数就需要内部实现__len__()方法该方法的返回值就是len()运算该对象的结果
        return len(self._all_rows)  # len将返回_all_rows属性list的长度

    def export(self, format, **kwargs):  # 将记录集导出指定格式
        """Export the RecordCollection to a given format (courtesy of Tablib)."""
        return self.dataset.export(format, **kwargs)

    @property
    def dataset(self):  # 只读的dataset属性用来获取Dataset对象
        """A Tablib Dataset representation of the RecordCollection."""
        # Create a new Tablib Dataset.
        data = tablib.Dataset()

        # If the RecordCollection is empty, just return the empty set
        # Check number of rows by typecasting to list
        if len(list(self)) == 0:  # 如果该对象中的长度为0则返回空的Dataset
            return data

        # Set the column names as headers on Tablib Dataset.
        first = self[0]  # self[0]是返回__all_rows中的第0条记录
        data.headers = first.keys() # 调用first对象的keys()方法来设置Dataset对象的表头
        for row in self.all():
            row = _reduce_datetimes(row.values())
            data.append(row)

        return data

    def all(self, as_dict=False, as_ordereddict=False):    # 返回所有记录,可设置返回为dict
        """Returns a list of all rows for the RecordCollection. If they haven't
        been fetched yet, consume the iterator and cache the results."""

        # By calling list it calls the __iter__ method
        rows = list(self)  # 遍历self自身,就是获取其_all_rows属性中的元素的list

        if as_dict:
            return [r.as_dict() for r in rows]  # 调用_all_rows中每个元素自身的as_dict()方法
        elif as_ordereddict:
            return [r.as_dict(ordered=True) for r in rows] # 同上

        return rows 

    def as_dict(self, ordered=False):  # 将整个record集合转化为dict
        return self.all(as_dict=not(ordered), as_ordereddict=ordered)

    def first(self, default=None, as_dict=False, as_ordereddict=False):
        """Returns a single record for the RecordCollection, or `default`. If
        `default` is an instance or subclass of Exception, then raise it
        instead of returning it."""

        # Try to get a record, or return/raise default.
        try:
            record = self[0]  # 获取_all_rows中的[0]第一个元素
        except IndexError:
            if isexception(default):
                raise default
            return default

        # Cast and return.
        if as_dict:
            return record.as_dict()
        elif as_ordereddict:
            return record.as_dict(ordered=True)
        else:
            return record

    def one(self, default=None, as_dict=False, as_ordereddict=False):
        """Returns a single record for the RecordCollection, ensuring that it
        is the only record, or returns `default`. If `default` is an instance
        or subclass of Exception, then raise it instead of returning it."""

        # Ensure that we don't have more than one row.
        try:
            self[1]  # 去尝试访问第二个元素
        except IndexError:
            return self.first(default=default, as_dict=as_dict, as_ordereddict=as_ordereddict)  # 出错了就return他的第一个元素
        else: # 尝试访问第二个元素没问题就抛出一个valueError 
            raise ValueError('RecordCollection contained more than one row. '
                             'Expects only one row when using '
                             'RecordCollection.one')

    def scalar(self, default=None):
        """Returns the first column of the first row, or `default`."""
        row = self.one()
        return row[0] if row else default

    # 综合解读RecordCollection这个类,他提供了动态获取row中的元素,你访问到的最大元素位置,其会动态为你获取
    # 使用list(x)会迭代所有元素,此时会自动将row属性中的所有元素添加到_all_rows属性中
    # 使用索引访问 也是动态获取到你所取索引之前的所有元素到_all_rows中
    

class Database(object):
    """A Database. Encapsulates a url and an SQLAlchemy engine with a pool of
    connections.
    """

    def __init__(self, db_url=None, **kwargs):
        # If no db_url was provided, fallback to $DATABASE_URL.
        self.db_url = db_url or DATABASE_URL  # DATABASE_URL是从环境变量中提取的值,赋值操作中的or表示or前为true则赋值or前的

        if not self.db_url:
            raise ValueError('You must provide a db_url.')  # 如果既没有传参 环境变量也没用找到则抛出错误

        # Create an engine.
        self._engine = create_engine(self.db_url, **kwargs)  # create_engine是调用SQLAlchemy中的方法
        self.open = True

    def close(self):  # 调用engine对象的dispose()方法
        """Closes the Database."""
        self._engine.dispose()
        self.open = False  # 综上可看出,这个open属性应该是表示是否连接打开了数据库

    def __enter__(self):  # 紧跟with后面的语句被求值后,返回对象的__enter__()方法被调用,这个方法的返回值将被赋值给as后面的变量
        return self

    def __exit__(self, exc, val, traceback):  #  当with后面的代码块全部被执行完之后,将调用前面返回对象的__exit__()方法。
        self.close()

    def __repr__(self):  # Database对象的显示
        return ''.format(self.open)

    def get_table_names(self, internal=False):
        """Returns a list of table names for the connected database."""

        # Setup SQLAlchemy for Database inspection.
        return inspect(self._engine).get_table_names()  # 都是在调用SQLAlchemy的一些函数,不深究,知道其返回什么就o
                                                        # 将返回所连接数据库中的所有数据表名的一个list

    def get_connection(self):
        """Get a connection to this Database. Connections are retrieved from a
        pool.
        """
        if not self.open:
            raise exc.ResourceClosedError('Database closed.')  # 如果数据库未打开则抛出一个错误

        return Connection(self._engine.connect())  # 返回一个Connection对象,并未该对象传入Engine对象的connect()方法的返回值

    def query(self, query, fetchall=False, **params):  # 实际就是调用Connection对象的query方法
        """Executes the given SQL query against the Database. Parameters can,
        optionally, be provided. Returns a RecordCollection, which can be
        iterated over to get result rows as dictionaries.
        """
        with self.get_connection() as conn:  # conn就是Connection对象
            return conn.query(query, fetchall, **params)  # 调用Connection的query方法,跳到下面看这个方法

    def bulk_query(self, query, *multiparams):  # 实际是调用Connection对象的bulk_query方法
        """Bulk insert or update."""

        with self.get_connection() as conn:
            conn.bulk_query(query, *multiparams)

    def query_file(self, path, fetchall=False, **params):  # 同上
        """Like Database.query, but takes a filename to load a query from."""

        with self.get_connection() as conn:
            return conn.query_file(path, fetchall, **params)

    def bulk_query_file(self, path, *multiparams):  # 同上
        """Like Database.bulk_query, but takes a filename to load a query from."""

        with self.get_connection() as conn:
            conn.bulk_query_file(path, *multiparams)

    @contextmanager  # 实现了与__enter__与__exit__一样的上下文管理功能,只不过这是一个装饰器,是用来针对一个方法的
    def transaction(self):  # 事务管理
        """A context manager for executing a transaction on this Database."""
        # @contextmanager装饰的方法,当你with这个方法时会顺序执行到yield语句,yield语句返回的对象给as后的变量
        # 然后执行with下的语句块,正常执行完毕,又回到yield后的语句顺序执行
        conn = self.get_connection()
        tx = conn.transaction()  # 调用Connection对象的transaction方法
        try:
            yield conn
            tx.commit()  # 调用SQLAlchemy的事务相关的方法,提交
        except:
            tx.rollback()  # 回滚
        finally:
            conn.close()  # 关闭


class Connection(object):  # 该类的_conn属性对应 SQLAlchemy的Connection对象
    """A Database connection."""

    def __init__(self, connection):  # 根据上面Database类的get_connection()方法猜测这个connection参数是传入SqlAlchemy
        self._conn = connection      # 模块中的Connection对象
        self.open = not connection.closed

    def close(self):
        self._conn.close()  # SQLAlchemy的Connection对象的close方法调用
        self.open = False  # 表示这个Connection已经关闭

    def __enter__(self):
        return self

    def __exit__(self, exc, val, traceback):
        self.close()  # 关闭 SQLAlchemy的Connection的连接

    def __repr__(self):
        return ''.format(self.open)

    def query(self, query, fetchall=False, **params):
        """Executes the given SQL query against the connected Database.
        Parameters can, optionally, be provided. Returns a RecordCollection,
        which can be iterated over to get result rows as dictionaries.
        """

        # Execute the given query.
        cursor = self._conn.execute(text(query), **params) # TODO: PARAMS GO HERE
        # 个人 研究了以下上面这个execute的意义就是使用text()函数可以使你的sql语句中可以动态设置一些参数,这个参数的传入通
        # 过后面的params这个参数传到execute函数中去
        # Row-by-row Record generator.
        # cursor是一个迭代器,该迭代器中是查询到的每一条记录
        row_gen = (Record(cursor.keys(), row) for row in cursor)  # 将查询到的记录的表头跟每个内容传个一个Record对象
                                                                  # 这里使用()生成了一个包含多个Record对象的迭代器对象
        # Convert psycopg2 results to RecordCollection.
        results = RecordCollection(row_gen)

        # Fetch all results if desired.
        if fetchall:
            results.all()

        return results

    def bulk_query(self, query, *multiparams):
        """Bulk insert or update."""

        self._conn.execute(text(query), *multiparams)  # 同样调用了Connection对象的execute方法

    def query_file(self, path, fetchall=False, **params):  # 读取一个文件来执行sql语句
        """Like Connection.query, but takes a filename to load a query from."""

        # If path doesn't exists
        if not os.path.exists(path):
            raise IOError("File '{}' not found!".format(path))  # 判断文件是否存在,不在就抛出Error

        # If it's a directory
        if os.path.isdir(path):
            raise IOError("'{}' is a directory!".format(path))  # 判断路径如果是文件夹就抛出Error

        # Read the given .sql file into memory.
        with open(path) as f:
            query = f.read()

        # Defer processing to self.query method.
        return self.query(query=query, fetchall=fetchall, **params)  # 将读取到的文件传给query参数

    def bulk_query_file(self, path, *multiparams):
        """Like Connection.bulk_query, but takes a filename to load a query
        from.
        """

         # If path doesn't exists
        if not os.path.exists(path):
            raise IOError("File '{}'' not found!".format(path))

        # If it's a directory
        if os.path.isdir(path):
            raise IOError("'{}' is a directory!".format(path))

        # Read the given .sql file into memory.
        with open(path) as f:
            query = f.read()

        self._conn.execute(text(query), *multiparams)

    def transaction(self):
        """Returns a transaction object. Call ``commit`` or ``rollback``
        on the returned object as appropriate."""

        return self._conn.begin()

# 全局的方法
def _reduce_datetimes(row):  # 转换datetimes为strings
    """Receives a row, converts datetimes to strings."""

    row = list(row)
    # 这里你可能会考虑为何不for in 这个row 而是利用其索引来操作,其实是因为如果不这样做无法应用你的修改到原始list中的内容
    for i in range(len(row)):
        if hasattr(row[i], 'isoformat'):  # 判断数据是否含有isoformat这个属性或者方法
            row[i] = row[i].isoformat()  # 将datetime类转换为ISO格式的时间字符串
    return tuple(row)
# cli-command line interface 命令行界面
def cli():
    supported_formats = 'csv tsv json yaml html xls xlsx dbf latex ods'.split()  # 将该字符串以空格分割成一个list
    formats_lst=", ".join(supported_formats) # 将上述list以逗号分隔开形成一个字符串
    cli_docs ="""Records: SQL for Humans™
A Kenneth Reitz project.

Usage:
  records  [] [...] [--url=]
  records (-h | --help)

Options:
  -h --help     Show this screen.
  --url=   The database URL to use. Defaults to $DATABASE_URL.

Supported Formats:
   %(formats_lst)s

   Note: xls, xlsx, dbf, and ods formats are binary, and should only be
         used with redirected output e.g. '$ records sql xls > sql.xls'.

Query Parameters:
    Query parameters can be specified in key=value format, and injected
    into your query in :key format e.g.:

    $ records 'select * from repos where language ~= :lang' lang=python

Notes:
  - While you may specify a database connection string with --url, records
    will automatically default to the value of $DATABASE_URL, if available.
  - Query is intended to be the path of a SQL file, however a query string
    can be provided instead. Use this feature discernfully; it's dangerous.
  - Records is intended for report-style exports of database queries, and
    has not yet been optimized for extremely large data dumps.
    """ % dict(formats_lst=formats_lst)

    # Parse the command-line arguments.
    arguments = docopt(cli_docs)

    query = arguments['']
    params = arguments['']
    format = arguments.get('')
    if format and "=" in format:
        del arguments['']
        arguments[''].append(format)
        format = None
    if format and format not in supported_formats:
        print('%s format not supported.' % format)
        print('Supported formats are %s.' % formats_lst)
        exit(62)

    # Can't send an empty list if params aren't expected.
    try:
        params = dict([i.split('=') for i in params])
    except ValueError:
        print('Parameters must be given in key=value format.')
        exit(64)

    # Be ready to fail on missing packages
    try:
        # Create the Database.
        db = Database(arguments['--url'])

        # Execute the query, if it is a found file.
        if os.path.isfile(query):
            rows = db.query_file(query, **params)

        # Execute the query, if it appears to be a query string.
        elif len(query.split()) > 2:
            rows = db.query(query, **params)

        # Otherwise, say the file wasn't found.
        else:
            print('The given query could not be found.')
            exit(66)

        # Print results in desired format.
        if format:
            content = rows.export(format)
            if isinstance(content, bytes):
                print_bytes(content)
            else:
                print(content)
        else:
            print(rows.dataset)
    except ImportError as impexc:
        print(impexc.msg)
        print("Used database or format require a package, which is missing.")
        print("Try to install missing packages.")
        exit(60)


def print_bytes(content):
    try:
        stdout.buffer.write(content)
    except AttributeError:
        stdout.write(content)


# Run the CLI when executed directly.
if __name__ == '__main__':
    cli()

你可能感兴趣的:(python相关)