records库源码学习
records 项目地址:https://github.com/kennethreitz/records
- 该项目是大神kennethreitz写的一个只有500行代码的库
- 用来入门学习开源项目,个人觉得还是很不错的
- 项目源码名为records.py位于根目录下
- 作者使用了pipenv来管理相关依赖
- 你可以在fork这个项目之后,使用以下命令:
pipenv install
来安装相关依赖(前提是你已经安装了pipenv)。
不BB,直接上代码:
import os
from sys import stdout
from collections import OrderedDict
from contextlib import contextmanager
from inspect import isclass
import tablib
from docopt import docopt
from sqlalchemy import create_engine, exc, inspect, text
# Default database URL, read once at import time from the DATABASE_URL
# environment variable; None when the variable is not set.
DATABASE_URL = os.environ.get('DATABASE_URL')
def isexception(obj):
    """Tell whether ``obj`` is an :py:class:`Exception` instance or subclass.

    Returns ``True`` for exception instances and for exception classes,
    ``False`` for everything else.
    """
    return isinstance(obj, Exception) or (
        isclass(obj) and issubclass(obj, Exception)
    )
class Record(object):
    """A row, from a query, from a database.

    Holds two parallel sequences: the column names (``keys``) and the
    matching values (``values``).  Values can be read by position
    (``record[0]``), by column name (``record['name']``) or as an attribute
    (``record.name``).
    """
    __slots__ = ('_keys', '_values')

    def __init__(self, keys, values):
        self._keys = keys
        self._values = values

        # Every column name must pair with exactly one value.
        assert len(self._keys) == len(self._values)

    def keys(self):
        """Returns the list of column names from the query."""
        return self._keys

    def values(self):
        """Returns the list of values from the query."""
        return self._values

    def __repr__(self):
        # [1:-1] drops the first and last character of the JSON export --
        # presumably the brackets of the one-element JSON array.
        return '<Record {}>'.format(self.export('json')[1:-1])

    def __getitem__(self, key):
        """Return a value by integer position or by column name.

        Raises ``KeyError`` when the column name is missing or ambiguous,
        and whatever the underlying values sequence raises for a bad index.
        """
        # Support for index-based lookup.
        if isinstance(key, int):
            return self.values()[key]

        # Support for string-based lookup.  Copy the names into a list so
        # that key containers without ``index``/``count`` (e.g. key views)
        # work as well.
        names = list(self.keys())
        matches = names.count(key)
        if matches > 1:
            raise KeyError("Record contains multiple '{}' fields.".format(key))
        if matches == 1:
            return self.values()[names.index(key)]

        raise KeyError("Record contains no '{}' field.".format(key))

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails.  If a slot is
        # still unset, treating its name as a column would call keys()
        # again and recurse without end, so fail fast instead.
        if key in self.__slots__:
            raise AttributeError(key)
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(e)

    def __dir__(self):
        # Merge in the column names so they show up in dir().
        standard = dir(super(Record, self))
        return sorted(standard + [str(k) for k in self.keys()])

    def get(self, key, default=None):
        """Returns the value for a given key, or default."""
        try:
            return self[key]
        except KeyError:
            return default

    def as_dict(self, ordered=False):
        """Returns the row as a dictionary (an OrderedDict if ``ordered``)."""
        items = zip(self.keys(), self.values())
        return OrderedDict(items) if ordered else dict(items)

    @property
    def dataset(self):
        """A Tablib Dataset containing the row."""
        data = tablib.Dataset()
        data.headers = self.keys()

        # Values exposing ``isoformat`` are turned into strings first.
        row = _reduce_datetimes(self.values())
        data.append(row)

        return data

    def export(self, format, **kwargs):
        """Exports the row to the given format."""
        return self.dataset.export(format, **kwargs)
class RecordCollection(object):
    """A set of excellent Records from a query.

    Wraps an iterator of rows and pulls from it lazily: rows already fetched
    are cached in ``_all_rows`` and ``pending`` stays ``True`` until the
    iterator has been exhausted.
    """

    def __init__(self, rows):
        self._rows = rows          # source iterator, advanced with next()
        self._all_rows = []        # rows fetched so far, in order
        self.pending = True        # False once the iterator is exhausted

    def __repr__(self):
        return '<RecordCollection size={} pending={}>'.format(
            len(self), self.pending)

    def __iter__(self):
        """Iterate over all rows, consuming the underlying generator
        only when necessary."""
        i = 0
        while True:
            # Other code may have fetched more rows since the last step.
            if i < len(self):
                yield self[i]
            else:
                # Nothing cached at this position: pull from the source.
                try:
                    yield next(self)
                except StopIteration:
                    return
            i += 1

    def next(self):
        return self.__next__()

    def __next__(self):
        try:
            nextrow = next(self._rows)
            self._all_rows.append(nextrow)
            return nextrow
        except StopIteration:
            self.pending = False
            raise StopIteration('RecordCollection contains no more rows.')

    def __getitem__(self, key):
        """Return one row (int key) or a new RecordCollection (slice key).

        Only as many rows as needed are fetched from the source.  Negative
        indexes, negative slice bounds/steps and open-ended slices need the
        total length, so they fetch everything first.

        Raises ``IndexError`` for an integer key that is out of range.
        """
        is_int = isinstance(key, int)

        if is_int:
            # None means "fetch everything".
            needed = key + 1 if key >= 0 else None
        else:
            bounds = (key.start, key.stop, key.step)
            if key.stop is None or any(
                    b is not None and b < 0 for b in bounds):
                needed = None
            else:
                needed = key.stop

        while needed is None or len(self) < needed:
            try:
                next(self)
            except StopIteration:
                break

        if is_int:
            return self._all_rows[key]
        return RecordCollection(iter(self._all_rows[key]))

    def __len__(self):
        # Counts only the rows fetched so far.
        return len(self._all_rows)

    def export(self, format, **kwargs):
        """Export the RecordCollection to a given format (courtesy of Tablib)."""
        return self.dataset.export(format, **kwargs)

    @property
    def dataset(self):
        """A Tablib Dataset representation of the RecordCollection."""
        # Create a new Tablib Dataset.
        data = tablib.Dataset()

        # An empty collection yields an empty dataset without headers.
        rows = self.all()
        if not rows:
            return data

        # Headers come from the first row.
        data.headers = rows[0].keys()
        for row in rows:
            data.append(_reduce_datetimes(row.values()))

        return data

    def all(self, as_dict=False, as_ordereddict=False):
        """Returns a list of all rows for the RecordCollection. If they haven't
        been fetched yet, consume the iterator and cache the results."""

        # By calling list it calls the __iter__ method.
        rows = list(self)

        if as_dict:
            return [r.as_dict() for r in rows]
        elif as_ordereddict:
            return [r.as_dict(ordered=True) for r in rows]

        return rows

    def as_dict(self, ordered=False):
        """Returns all rows as dicts (OrderedDicts if ``ordered``)."""
        return self.all(as_dict=not(ordered), as_ordereddict=ordered)

    def first(self, default=None, as_dict=False, as_ordereddict=False):
        """Returns a single record for the RecordCollection, or `default`. If
        `default` is an instance or subclass of Exception, then raise it
        instead of returning it."""

        # Try to get a record, or return/raise default.
        try:
            record = self[0]
        except IndexError:
            if isexception(default):
                raise default
            return default

        # Cast and return.
        if as_dict:
            return record.as_dict()
        elif as_ordereddict:
            return record.as_dict(ordered=True)
        else:
            return record

    def one(self, default=None, as_dict=False, as_ordereddict=False):
        """Returns a single record for the RecordCollection, ensuring that it
        is the only record, or returns `default`. If `default` is an instance
        or subclass of Exception, then raise it instead of returning it.

        Raises ``ValueError`` when more than one row is available."""

        # Ensure that we don't have more than one row.
        try:
            self[1]
        except IndexError:
            return self.first(default=default, as_dict=as_dict,
                              as_ordereddict=as_ordereddict)
        else:
            raise ValueError('RecordCollection contained more than one row. '
                             'Expects only one row when using '
                             'RecordCollection.one')

    def scalar(self, default=None):
        """Returns the first column of the first row, or `default`."""
        row = self.one()
        return row[0] if row else default
class Database(object):
    """A Database. Encapsulates a url and an SQLAlchemy engine with a pool of
    connections.
    """

    def __init__(self, db_url=None, **kwargs):
        """Create the engine for ``db_url``.

        Falls back to the module-level ``DATABASE_URL`` when ``db_url`` is
        not given; extra keyword arguments are passed to ``create_engine``.
        Raises ``ValueError`` when no URL is available.
        """
        # If no db_url was provided, fallback to $DATABASE_URL.
        self.db_url = db_url or DATABASE_URL

        if not self.db_url:
            raise ValueError('You must provide a db_url.')

        # Create an engine.
        self._engine = create_engine(self.db_url, **kwargs)
        self.open = True

    def close(self):
        """Closes the Database."""
        self._engine.dispose()
        self.open = False

    def __enter__(self):
        return self

    def __exit__(self, exc, val, traceback):
        self.close()

    def __repr__(self):
        return '<Database open={}>'.format(self.open)

    def get_table_names(self, internal=False):
        """Returns a list of table names for the connected database.

        ``internal`` is accepted for compatibility but is not used here.
        """
        return inspect(self._engine).get_table_names()

    def get_connection(self):
        """Get a connection to this Database. Connections are retrieved from a
        pool.

        Raises ``exc.ResourceClosedError`` when the Database has been closed.
        """
        if not self.open:
            raise exc.ResourceClosedError('Database closed.')

        return Connection(self._engine.connect())

    def query(self, query, fetchall=False, **params):
        """Executes the given SQL query against the Database. Parameters can,
        optionally, be provided. Returns a RecordCollection, which can be
        iterated over to get result rows.
        """
        # NOTE(review): the connection is closed when this ``with`` block
        # exits, which is before a lazily-consumed RecordCollection is
        # iterated when ``fetchall`` is False -- confirm the driver still
        # yields rows in that case.
        with self.get_connection() as conn:
            return conn.query(query, fetchall, **params)

    def bulk_query(self, query, *multiparams):
        """Bulk insert or update."""
        with self.get_connection() as conn:
            conn.bulk_query(query, *multiparams)

    def query_file(self, path, fetchall=False, **params):
        """Like Database.query, but takes a filename to load a query from."""
        with self.get_connection() as conn:
            return conn.query_file(path, fetchall, **params)

    def bulk_query_file(self, path, *multiparams):
        """Like Database.bulk_query, but takes a filename to load a query from."""
        with self.get_connection() as conn:
            conn.bulk_query_file(path, *multiparams)

    @contextmanager
    def transaction(self):
        """A context manager for executing a transaction on this Database.

        Yields a Connection.  The transaction is committed when the block
        finishes normally; on any exception it is rolled back and the
        exception is re-raised to the caller.  The connection is always
        closed afterwards.
        """
        conn = self.get_connection()
        try:
            tx = conn.transaction()
            try:
                yield conn
                tx.commit()
            except BaseException:
                # Undo the work, then let the caller see the failure
                # instead of silently discarding it.
                tx.rollback()
                raise
        finally:
            conn.close()
class Connection(object):
    """A Database connection.

    Thin wrapper around the connection object passed to the constructor;
    ``open`` tracks whether ``close`` has been called.
    """

    def __init__(self, connection):
        self._conn = connection
        self.open = not connection.closed

    def close(self):
        """Closes the wrapped connection and marks this object as closed."""
        self._conn.close()
        self.open = False

    def __enter__(self):
        return self

    def __exit__(self, exc, val, traceback):
        self.close()

    def __repr__(self):
        return '<Connection open={}>'.format(self.open)

    @staticmethod
    def _read_query_file(path):
        """Return the text of the query file at ``path``.

        Raises ``IOError`` when the path does not exist or is a directory.
        """
        # If path doesn't exists
        if not os.path.exists(path):
            raise IOError("File '{}' not found!".format(path))

        # If it's a directory
        if os.path.isdir(path):
            raise IOError("'{}' is a directory!".format(path))

        # Read the given .sql file into memory.
        with open(path) as f:
            return f.read()

    def query(self, query, fetchall=False, **params):
        """Executes the given SQL query against the connected Database.
        Parameters can, optionally, be provided. Returns a RecordCollection,
        which can be iterated over to get result rows as Record objects.
        """
        # Execute the given query.
        cursor = self._conn.execute(text(query), **params)

        # Rows are wrapped lazily, one Record per row as it is read.
        row_gen = (Record(cursor.keys(), row) for row in cursor)

        # Convert psycopg2 results to RecordCollection.
        results = RecordCollection(row_gen)

        # Fetch all results if desired.
        if fetchall:
            results.all()

        return results

    def bulk_query(self, query, *multiparams):
        """Bulk insert or update."""
        self._conn.execute(text(query), *multiparams)

    def query_file(self, path, fetchall=False, **params):
        """Like Connection.query, but takes a filename to load a query from.

        Raises ``IOError`` when the path does not exist or is a directory.
        """
        query = self._read_query_file(path)

        # Defer processing to self.query method.
        return self.query(query=query, fetchall=fetchall, **params)

    def bulk_query_file(self, path, *multiparams):
        """Like Connection.bulk_query, but takes a filename to load a query
        from.

        Raises ``IOError`` when the path does not exist or is a directory.
        """
        query = self._read_query_file(path)

        self._conn.execute(text(query), *multiparams)

    def transaction(self):
        """Returns a transaction object. Call ``commit`` or ``rollback``
        on the returned object as appropriate."""
        return self._conn.begin()
def _reduce_datetimes(row):
    """Receives a row, converts datetimes to strings.

    Any value exposing an ``isoformat`` attribute is replaced by the result
    of calling it; every other value is kept as is.  Returns a new tuple and
    leaves ``row`` untouched.
    """
    return tuple(
        value.isoformat() if hasattr(value, 'isoformat') else value
        for value in row
    )
def cli():
    """Command-line entry point: run a query and print or export the result.

    The first argument is either the path of a SQL file or a query string.
    An optional export format and ``key=value`` query parameters may follow.

    Exit codes: 62 unsupported format, 64 malformed parameter, 66 query not
    found, 60 missing optional package.
    """
    supported_formats = 'csv tsv json yaml html xls xlsx dbf latex ods'.split()
    formats_lst = ", ".join(supported_formats)
    cli_docs = """Records: SQL for Humans™
A Kenneth Reitz project.

Usage:
  records <query> [<format>] [<params>...] [--url=<url>]
  records (-h | --help)

Options:
  -h --help     Show this screen.
  --url=<url>   The database URL to use. Defaults to $DATABASE_URL.

Supported Formats:
   %(formats_lst)s

   Note: xls, xlsx, dbf, and ods formats are binary, and should only be
         used with redirected output e.g. '$ records sql xls > sql.xls'.

Query Parameters:
    Query parameters can be specified in key=value format, and injected
    into your query in :key format e.g.:

    $ records 'select * from repos where language ~= :lang' lang=python

Notes:
  - While you may specify a database connection string with --url, records
    will automatically default to the value of $DATABASE_URL, if available.
  - Query is intended to be the path of a SQL file, however a query string
    can be provided instead. Use this feature discernfully; it's dangerous.
  - Records is intended for report-style exports of database queries, and
    has not yet been optimized for extremely large data dumps.
""" % dict(formats_lst=formats_lst)

    # Parse the command-line arguments.
    arguments = docopt(cli_docs)

    query = arguments['<query>']
    params = arguments['<params>']
    fmt = arguments.get('<format>')

    # A "format" containing '=' is really the first key=value parameter.
    if fmt and "=" in fmt:
        del arguments['<format>']
        arguments['<params>'].append(fmt)
        fmt = None

    if fmt and fmt not in supported_formats:
        print('%s format not supported.' % fmt)
        print('Supported formats are %s.' % formats_lst)
        raise SystemExit(62)

    # Can't send an empty list if params aren't expected.
    try:
        # Split on the first '=' only, so values may contain '=' themselves.
        params = dict(i.split('=', 1) for i in params)
    except ValueError:
        print('Parameters must be given in key=value format.')
        raise SystemExit(64)

    # Be ready to fail on missing packages
    try:
        # Create the Database; it is closed again when the block exits.
        with Database(arguments['--url']) as db:

            # Execute the query, if it is a found file.
            if os.path.isfile(query):
                rows = db.query_file(query, **params)

            # Execute the query, if it appears to be a query string.
            elif len(query.split()) > 2:
                rows = db.query(query, **params)

            # Otherwise, say the file wasn't found.
            else:
                print('The given query could not be found.')
                raise SystemExit(66)

            # Print results in desired format.
            if fmt:
                content = rows.export(fmt)
                if isinstance(content, bytes):
                    print_bytes(content)
                else:
                    print(content)
            else:
                print(rows.dataset)
    except ImportError as impexc:
        print(impexc.msg)
        print("Used database or format require a package, which is missing.")
        print("Try to install missing packages.")
        raise SystemExit(60)
def print_bytes(content):
    """Write ``content`` to standard output.

    Tries the ``buffer`` attribute of ``stdout`` first and falls back to
    writing to ``stdout`` itself when an ``AttributeError`` is raised.
    """
    try:
        binary_sink = stdout.buffer
        binary_sink.write(content)
    except AttributeError:
        stdout.write(content)
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    cli()