训练内存轻量化
最近又在训练模型(炼丹),以前老抱怨,区区2万 samples 也好意思叫大数据,近期的任务似乎听到了我这个抱怨,纷纷都是什么“1700万个句子”,“4000个文档”的数据,对服务器内存一次次的进行着冲击。
虽说我之前已经写过一个CIR(CorpusIterationReader)类实现的文章用来解决类似问题(哎?我那篇文章哪去了,被吃了么……emmmm,以后再重发一次吧。)但是那个类也只能让 pivot 以 “文件指针+instance指针” 的方式进行顺序存取,不是很好处理 “shuffle后随机存取” 的情况,再者,“每个文件中包括多个samples” 的设计在多进程中容易产生冲突。
经 cyx 学长提醒,可以考虑每个 sample 单独作为一个文件。(我觉得吧,这个也会有个小问题,就是这个文件夹里千万别一不小心按一下 ls -al,不然要等好半天了哈哈哈)于是基于学长 OneFileDB 的设计,重构并实现了一些这种处理方案的工具类及工具函数,便于我基于 PyTorch 的模型得以正常训练。
跨版本编码兼容
实现过程中,由于 python2 的老项目和 python3 的新项目都需要使用,于是编码也是一个大难题。参考 Bert 里的 convert_to_unicode,研究了 ujson 和 json.JSONEncoder 是如何将不同编码处理成 unicode 并存为 json 格式,把相关的实现也放了进去。
对于一个 json 文件而言,通常是一个 list,里面以包含多个 dict 的形式存储着 samples。
对于模型而言,我们需要的是,在 sample 的数量足够多时,还要能够较快地通过下标(或者key)来获取到对应的 sample 喂给模型。
# JSON EXAMPLE
# A dataset is a list of sample dicts; each sample carries
# info.sid (a unique id used as the file name in FolderDB),
# plus words / entities / relations.
# Byte strings (u'...'.encode('utf-8')) are deliberately mixed with
# plain strings to exercise the encoding-unification helpers below.
j = [{'info': {'sid': 'test1'},
'words': [{'id': 'w0', 'word': u'电'.encode('utf-8')},
{'id': 'w1', 'word': u'话'.encode('utf-8')},
{'id': 'w2', 'word': '[unused10]'},
{'id': 'w3', 'word': '0'},
{'id': 'w4', 'word': '2'},
{'id': 'w5', 'word': '1'},
{'id': 'w6', 'word': '-'},
{'id': 'w7', 'word': '3'},
{'id': 'w27', 'word': '0'}],
'entities': [], 'relations': []},
{'info': {'sid': 'test2'},
'words': [{'id': 'w0', 'word': u'地'.encode('utf-8')},
{'id': 'w1', 'word': u'址'.encode('utf-8')}],
'entities': [], 'relations': []}]
我们有三种方式来进行存储:
- OneFileDB,即单文件存储,和我们平时直接读一个文件进来没有两样
- FolderDB,文件夹存储,文件夹中的每一个文件是一个 sample
- CFolderDB,加密文件夹存储,是 FolderDB 的继承类,不同点在于 sample 是加密压缩的
# In particular, a json file loaded as a OneFileDB can be converted
# into a FolderDB via the member function `transfer_to_folderdb(path=)`.
db_of = Database('./test.json') # OneFileDB
db_f1 = db_of.transfer_to_folderdb('./test') # FolderDB
db_f2 = Database('./test') # FolderDB
db_cf = Database('./test.cfolder') # CFolderDB
这几种 DB 的使用,也是通常的写入,下标读取,遍历,获得 samples 长度等。
而对于 Folder 类的 DB 来说,还有额外的 append 函数,方便其增加新的 samples。
# Typical usage: write, append (Folder DBs only), iterate, index.
db.write(samples=j)
db_f.append(samples=j)
for idx, item in enumerate(db):
    print(idx, item)
print(len(db))  # FIX: use len() instead of calling __len__ directly
print(db[1])
# coding: utf-8
# ==========================================================================
# Copyright (C) 2016-2020 All rights reserved.
#
# filename : training_dbs_new.py
# origin : cyx / caoyixuan
# author : chendian / [email protected]
# date : 2020-07-21
# desc : An alternative to the original database class (multi-json).
# can be called as a dict or a list.
# ==========================================================================
class Database(object):
    """
    Unified facade over OneFileDB / FolderDB / CFolderDB.

    The backend is picked from the path suffix (see determine_mode),
    or forced to OneFileDB when `samples` are handed in directly.
    """
    def __init__(self, path, samples=None, n_samples=None, read_only=True, load_now=False):
        if samples is not None:
            backend = OneFileDB(path, samples, n_samples=n_samples)
        else:
            mode = self.determine_mode(path)
            logging.info('database mode: {}'.format(mode))
            if mode == 'all_samples_one_file':
                backend = OneFileDB(path, samples=None, n_samples=n_samples,
                                    read_only=read_only, load_now=load_now)
            elif mode in ('one_sample_per_file', 'cfolder'):
                db_cls = FolderDB if mode == 'one_sample_per_file' else CFolderDB
                backend = db_cls(path, n_samples=n_samples,
                                 read_only=read_only, load_now=load_now)
            else:
                raise ValueError("Unknown mode: {}".format(mode))
        self.db = backend
        self.sids = backend.sids
    @staticmethod
    def determine_mode(label_path):
        """Infer the storage mode from the path's suffix."""
        if label_path.endswith('.json'):
            return 'all_samples_one_file'
        if label_path.endswith(('.cfolder', '.cfolder/')):
            return 'cfolder'
        # a directory path without a recognised postfix
        return 'one_sample_per_file'
    def write(self, samples):
        """Persist samples through the underlying backend."""
        return self.db.write(samples)
    def get_by_sid(self, sid):
        """Fetch one sample by its sid."""
        return self.db.get_by_sid(sid)
    def __getitem__(self, item):
        # slices yield lazily via sl(); ints index the backend directly
        return self.sl(item) if isinstance(item, slice) else self.db[item]
    def sl(self, key):
        """Generator over the samples selected by slice `key`."""
        bounds = key.indices(len(self))
        for pos in range(*bounds):
            yield self.db[pos]
    def __len__(self):
        return len(self.db)
    def __iter__(self):
        return iter(self.db)
    def next(self):
        # py2-style iteration, delegated to the backend
        return self.db.next()
    @property
    def all_samples(self):
        """All samples materialized as a list."""
        return self.db.all_samples
if __name__ == "__main__":
    # smoke test: open the folder-backed demo database
    sd = Database('./test')
class TrainDBBase(object):
    """
    Abstract base for a dataset that is immutable once written.

    Subclasses implement write / get_by_sid / __getitem__ / __len__;
    iteration (py2 `next` and py3 `__next__`) comes for free.
    """
    def write(self, samples):
        """Persist samples; must be provided by subclasses."""
        raise NotImplementedError()
    def get_by_sid(self, sid):
        """Fetch a sample by its sid; must be provided by subclasses."""
        raise NotImplementedError()
    def __getitem__(self, item):
        """Fetch a sample by integer index; must be provided by subclasses."""
        raise NotImplementedError()
    def __len__(self):
        """Number of samples; must be provided by subclasses."""
        raise NotImplementedError()
    def __iter__(self):
        # (re)start the cursor consumed by next()
        self.n = 0
        return self
    def next(self):
        cursor = self.n
        if cursor == len(self):
            raise StopIteration
        self.n = cursor + 1
        return self[cursor]
    def __next__(self):
        # py3 protocol delegates to the py2-style next()
        return self.next()
    @property
    def all_samples(self):
        """All samples, fetched by index (does not disturb the iterator cursor)."""
        total = len(self)
        return [self[idx] for idx in range(total)]
class FolderDB(TrainDBBase):
    """
    One sample per file; a DB is a folder, indexed by file name (sid).
    NEW: samples can also be accessed / iterated by integer index.
    """
    def __init__(self, folder, n_samples=None, read_only=True, load_now=False):
        # `read_only` is accepted for interface symmetry; not used here.
        self.folder = folder
        self.compress = False
        self.n_samples = n_samples  # optional cap, for debugging on a subset
        self.sids = None  # lazily loaded from the 'register' file
        if load_now:
            self.load_register()
    def write(self, samples):
        """Write all samples, one file each, plus a 'register' index file."""
        write_one_sample_per_file(samples, self.folder)
    def append(self, samples):
        """Append new samples; their sids must not collide with existing ones."""
        append_write_one_sample_per_file(samples, self.folder)
    def get_by_sid(self, sid):
        """Load a single sample by its sid (the file name)."""
        file_path = path_join(self.folder, sid)
        # FIX: close the handle deterministically (was json.load(open(...)))
        with open(file_path) as f:
            sample = json.load(f)
        return sample
    def __getitem__(self, index):
        self.load_register()
        sid = self.sids[index]
        return self.get_by_sid(sid)
    def __len__(self):
        self.load_register()
        return len(self.sids)
    def load_register(self):
        """Load the sid list once from the folder's register file."""
        if self.sids is not None:
            return
        # calls the module-level load_register() helper
        sids = load_register(self.folder)
        if self.n_samples:
            sids = sids[: self.n_samples]
        self.sids = sids
        assert len(self.sids) == len(set(self.sids)), 'exist duplicated sids'
class CFolderDB(FolderDB):
    """FolderDB variant whose per-sample files are zlib-compressed JSON."""
    def write(self, samples):
        # compress=True makes the writer helper zlib-encrypt each file
        write_one_sample_per_file(samples, self.folder, compress=True)
    def get_by_sid(self, sid):
        # decrypt=True: zlib-decompress the file before json-decoding
        target = path_join(self.folder, sid)
        return json_load(path=target, mode='r', decrypt=True)
class OneFileDB(TrainDBBase):
    """A whole dataset stored in a single (optionally compressed) JSON file."""
    def __init__(self, file_path, samples=None, n_samples=None, read_only=True, load_now=False):
        # `read_only` is accepted for interface symmetry; not used here.
        self.file_path = file_path
        self.sids = None            # sid list, parallel to self.samples
        self.samples = None         # lazily loaded list of sample dicts
        self.compress = False
        self.sid_to_sample = None   # sid -> sample lookup table
        self.n_samples = n_samples  # optional cap, for debugging on a subset
        if samples is not None:
            self.set_samples(samples)
        elif load_now:
            self.load()
    def write(self, samples):
        """Dump samples to self.file_path as (optionally encrypted) JSON."""
        json_dump(
            obj_=samples, path=self.file_path,
            mode='w', encrypt=self.compress)
    def get_by_sid(self, sid):
        """Fetch a sample by its sid; raises KeyError if absent."""
        self.load()
        return self.sid_to_sample[sid]
    def load(self):
        """Load samples from disk lazily; no-op if already loaded."""
        if self.samples is not None:
            return
        samples = json_load(
            path=self.file_path, mode='r',
            decrypt=self.compress)
        self.set_samples(samples)
    def set_samples(self, samples):
        # optionally truncate: make a minor database for testing
        if self.n_samples:
            samples = samples[: self.n_samples]
        self.samples = samples
        self.sids = [s['info']['sid'] for s in self.samples]
        self.sid_to_sample = {s['info']['sid']: s for s in self.samples}
    def transfer_to_folderdb(self, path):
        """Materialize this DB as a FolderDB under `path` and return it."""
        # FIX: ensure samples are loaded first — previously this crashed
        # with samples=None when load() had never been called.
        self.load()
        write_one_sample_per_file(
            answers=self.samples,
            folder=path,
            compress=self.compress)
        return Database(path=path)
    def __getitem__(self, item):
        self.load()
        return self.samples[item]
    def __len__(self):
        self.load()
        return len(self.samples)
这种任务,最麻烦的就是 Python2 和 Python3 之间的兼容性,兼容性最麻烦的又体现在编码上:Python2 的 unicode 编码即 Python3 的 str,Python2 的 str 编码即 Python3 的 bytes,于是
from __future__ import unicode_literals
from six import PY2, PY3
import logging
import os
import zlib
import numpy as np
from io import open
# Record which json implementation got imported; json_dumps() branches
# on this because ujson's API differs from the stdlib's.
JSON_MODULE = None
try:
    # if you have ujson, it will be faster
    # but the calling method is different.
    import ujson as json
    JSON_MODULE = 'ujson'
except ImportError:
    # fall back to the stdlib implementation
    import json
    JSON_MODULE = 'json'
class JsonBytesEncoder(json.JSONEncoder):
    """JSONEncoder for json.dumps that decodes bytes values to unicode."""
    def default(self, obj):
        # raw bytes are decoded as utf-8 text; anything else falls back
        # to the base encoder (which raises TypeError for unknown types)
        if isinstance(obj, bytes):
            return convert_to_unicode(obj)
        return super(JsonBytesEncoder, self).default(obj)
def json_dumps(obj_, encrypt=False):
    """Serialize obj_ to a JSON string.

    Branches on which json implementation was imported (stdlib vs ujson)
    because their APIs for handling bytes values differ.  When `encrypt`
    is set the result is zlib-compressed and therefore returned as bytes.
    """
    if JSON_MODULE == 'json':
        # stdlib json: bytes values handled by the custom encoder
        _json_str = json.dumps(
            obj_, cls=JsonBytesEncoder)
    elif JSON_MODULE == 'ujson':
        # NOTE(review): only the first character of __version__ is checked —
        # this breaks for ujson versions >= 10; confirm acceptable.
        if int(json.__version__[0]) < 2:
            # standard ujson-1.35 for python2.7
            _json_str = json.dumps(obj_)
        else: # standard ujson-3.0.0 for python3.6
            # reject_bytes=False lets ujson >= 2 accept bytes values
            _json_str = json.dumps(
                obj_, reject_bytes=False)
    else:
        _json_str = json.dumps(obj_)
    if encrypt:
        # zlib_encrypt returns compressed bytes, not str
        return zlib_encrypt(_json_str)
    return _json_str
def json_dump(obj_, path=None, mode='w', stream=None, encrypt=False):
    """Equivalent of json.dump(...): write obj_ either to `stream` or to `path`.

    Encryption forces binary mode, because zlib produces bytes and python3
    raises "TypeError: a bytes-like object is required, not 'str'" when a
    text-mode file receives bytes.
    """
    if encrypt:
        mode = 'wb'
    if stream is not None:
        # a caller-supplied stream already carries its own path and mode
        stream.write(json_dumps(obj_, encrypt))
        return
    with open(path, mode) as f:
        f.write(json_dumps(obj_, encrypt))
def json_loads(str_, decrypt=False):
    """Parse a JSON string; optionally zlib-decompress it first."""
    text = zlib_decrypt(str_) if decrypt else str_
    # every json implementation exposes the same loads() signature
    return json.loads(text)
def json_load(path, mode='r', decrypt=False):
    """Equivalent of json.load(open(path, mode)); handles compressed files."""
    if decrypt:
        # zlib-compressed content must be read as raw bytes
        mode = 'rb'
    with open(path, mode) as f:
        raw = f.read()
    return json_loads(raw, decrypt)
def zlib_encrypt(data):
    """Compress `data` (a container or a string) into zlib bytes."""
    if isinstance(data, (list, dict, tuple)):
        # containers are serialized to a json string first
        text = json_dumps(data)
    else:
        # normalize to unicode (py2-unicode or py3-str)
        text = convert_to_unicode(data)
    # zlib only accepts bytes-like inputs
    return zlib.compress(convert_to_bytes(text))
def zlib_decrypt(str_):
    """Decompress zlib bytes back into a unicode json string."""
    return convert_to_unicode(zlib.decompress(str_))
def path_join(*args):
    """Join path components (coerced to unicode) with the OS separator.

    FIX: the original joined with '' (bare concatenation), which placed
    'register' and sample files NEXT TO the folder (e.g. './testregister')
    rather than inside it whenever the folder path lacked a trailing
    slash.  os.path.join keeps trailing-slash callers working unchanged
    and fixes bare-folder callers.
    """
    return os.path.join(*[convert_to_unicode(each) for each in args])
def write_data(stream, text, encoding='unicode'):
    """Write `text` to `stream`, coerced to the requested representation.

    The basestring types differ between py2 and py3, so callers pick
    'unicode'/'u', 'bytes'/'utf-8'/'b', or anything else to pass the
    text through unchanged.
    """
    if encoding in ('unicode', 'u'):
        payload = convert_to_unicode(text)
    elif encoding in ('bytes', 'utf-8', 'b'):
        payload = convert_to_bytes(text)
    else:
        # unrecognized: write the text through as-is
        payload = text
    stream.write(payload)
def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input.

    Accepts int/float (stringified first), py3 str/bytes or py2
    str/unicode; raises ValueError for any other type.
    """
    # numbers are rendered as text before the type dispatch below
    if isinstance(text, (int, float)):
        text = '{}'.format(text)
    if PY3:
        if isinstance(text, str): # py3-str is unicode
            return text
        elif isinstance(text, bytes): # py3-bytes is py2-str
            # undecodable byte sequences are silently dropped ("ignore")
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif PY2:
        if isinstance(text, str): # py2-str is py3-bytes
            return text.decode("utf-8", "ignore")
        elif isinstance(text, unicode): # py2-unicode is py3-str
            return text
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")
def convert_to_bytes(text):
    """Converts `text` to a utf-8 byte string (py2-str / py3-bytes)."""
    # already bytes-like on this interpreter: pass through untouched
    if (PY2 and isinstance(text, str)) or (PY3 and isinstance(text, bytes)):
        return text
    # normalize through unicode first, then encode
    return convert_to_unicode(text).encode('utf-8')
def recursive_encoding_unification(cur_node):
    """Recursively convert every string-like leaf to unicode.

    Container types (list / tuple / dict / OrderedDict) are preserved;
    numbers and None pass through untouched.
    """
    from collections import OrderedDict
    reu = recursive_encoding_unification
    if isinstance(cur_node, (list, tuple)):
        return type(cur_node)([reu(member) for member in cur_node])
    if isinstance(cur_node, (dict, OrderedDict)):
        pairs = [(reu(k), reu(v)) for k, v in cur_node.items()]
        return type(cur_node)(pairs)
    if isinstance(cur_node, (int, float)):
        return cur_node
    if cur_node is None:
        return None
    # str / bytes / py2-unicode: only leaf nodes get converted
    return convert_to_unicode(cur_node)
def json_unicode(json_dict):
    """Alias: normalize every string leaf of a json-like object to unicode."""
    return recursive_encoding_unification(json_dict)
def write_one_sample_per_file(answers, folder, compress=False):
    """Write each sample to its own file under `folder`, plus a
    'register' file listing every sid (one per line).

    The sid becomes the file name; `compress` zlib-encrypts each file.
    """
    # FIX: enumerate() was unused — build the register directly
    register = ['{}'.format(s['info']['sid']) for s in answers]
    if not os.path.exists(folder):
        os.mkdir(folder)
    with open(path_join(folder, 'register'), 'w') as fw:
        # 'w' for py2 and py3 is different; write_data normalizes the text
        write_data(fw, '\n'.join(register))
    for s in answers:
        file_path = path_join(folder, s['info']['sid'])
        json_dump(s, path=file_path, encrypt=compress)
def append_write_one_sample_per_file(answers, folder, compress=False):
    """Append samples (and their sids) to an existing FolderDB folder.

    Fails fast when the folder is missing or any new sid already exists;
    gives up after more than 30 samples fail to serialize.
    """
    assert os.path.isdir(folder), 'folder should exist if you want to append to existing dataset'
    sids = load_register(folder)
    conflict_sids = set(sids).intersection([s['info']['sid'] for s in answers])
    assert not conflict_sids, 'some sids already exist: {}'.format(list(conflict_sids)[:10])
    # FIX: enumerate() was unused
    new_register = ['{}'.format(s['info']['sid']) for s in answers]
    # saving bytes is faster, but here is 'append' without 'b'
    # to remain consistent with the text-mode register file
    with open(path_join(folder, 'register'), 'a') as fw:
        write_data(fw, '\n')
        write_data(fw, '\n'.join(new_register))
    failures = 0
    for s in answers:
        try:
            sid_str = convert_to_unicode(s['info']['sid'])
            # FIX: honour the compress flag (was hard-coded encrypt=False)
            json_dump(obj_=s, path=path_join(folder, sid_str), encrypt=compress)
        except OverflowError:
            # FIX: logging.warn is deprecated; use logging.warning
            logging.warning('{} save error'.format(s['info']['sid']))
            failures += 1
            if failures > 30:
                break
def load_register(folder, n_samples=None):
    """Read the sid list from the folder's 'register' file.

    Each line may look like 'extra,sid' — only the last comma field is
    kept.  When `n_samples` is given, reading stops once that many sids
    have been collected.
    """
    with open(path_join(folder, 'register'), 'r') as fr:
        if n_samples is None:
            # fast path: build the whole list in one comprehension
            return [line.strip().split(',')[-1] for line in fr]
        # custom n_samples is usually small; stop reading early.
        # FIX: removed the redundant `if n_samples is not None`
        # re-check — this branch already guarantees it.
        sids = []
        for line in fr:
            sids.append(line.strip().split(',')[-1])
            if len(sids) >= n_samples:
                break
        return sids
def random_ints(n):
    """Return n distinct random non-negative int64 values (numpy array).

    Over-samples 2*n values, de-duplicates, and retries until at least
    n unique values are available.
    """
    assert n < 10 ** 9, 'Too many distinct numbers asked.'
    upper = np.iinfo(np.int64).max
    uniques = np.unique(np.random.randint(0, upper, 2 * n))
    while len(uniques) < n:
        extra = np.random.randint(0, upper, 2 * n)
        # FIX: np.stack requires equal-shaped arrays and raised ValueError
        # here; concatenate merges the two 1-D arrays before de-duplicating.
        uniques = np.unique(np.concatenate([uniques, extra]))
    return uniques[:n]