# 自定义的log模块 (custom logging module)

"""

See README.md for a description of the logging API.

OFF state corresponds to having Logger.CURRENT == Logger.DEFAULT
ON state is otherwise

"""

from collections import OrderedDict
import os
import sys
import shutil
import os.path as osp
import json
# from baselines.common.console_util import fmt_row
from enum import Enum, unique


# Output destinations understood by make_output_format(); callers pass members
# such as ``type.stdout`` / ``type.csv`` in ``Logger(output_formats=[...])``.
# The member names, order and integer values are part of the public interface.
# NOTE(review): this name shadows the builtin ``type`` within this module.
type = unique(Enum(
    'type',
    [('stdout', 0), ('log', 1), ('json', 2), ('csv', 3)],
    module=__name__,
))

# Severity thresholds for Logger.log / Logger.set_level: a message is emitted
# only when its level is >= the logger's current level.
DEBUG, INFO, WARN, ERROR = 10, 20, 30, 40

# Higher than every level above; set_level(DISABLED) silences messages logged
# at DEBUG, INFO, WARN or ERROR.
DISABLED = 50


class OutputFormat:
    """Base class for log sinks; every hook is a no-op by default.

    Subclasses override whichever of ``writekvs`` / ``writeseq`` they support.
    A hook that is not overridden silently ignores its input.
    """

    def writekvs(self, kvs, width):
        """Write a mapping of key/value pairs (no-op in the base class)."""
        return None

    def writeseq(self, args, width):
        """Write a message or a sequence of items (no-op in the base class)."""
        return None

    def close(self):
        """Release resources held by the sink (nothing to release here)."""

class CsvOuputFormat(OutputFormat):
    """Write each logged row as tab-separated text, one row per line.

    Despite the name, cells are separated by a tab character (not commas), and
    every row keeps a trailing tab before the newline. This layout is preserved
    so that rows appended to existing files stay consistent.

    Only ``writeseq`` is implemented. ``writekvs`` is inherited from
    OutputFormat, so ``Logger.dumpkvs`` writes nothing to this sink.
    """

    def __init__(self, file):
        # file: an open text stream. Unless it is a standard stream, this sink
        # owns it and closes it in close().
        self.file = file
        # Decided at construction so that a later swap of sys.stdout (e.g. by
        # contextlib.redirect_stdout) cannot make close() shut a console stream.
        self._owns_file = not any(
            file is std for std in (sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__)
        )

    def writeseq(self, args, width):
        """Write ``args`` as one row and flush.

        args: a string (written as a single cell), an iterable of items (one
            cell each, converted with ``str``), or any other scalar (written as
            a single cell).
        width: unused; accepted for interface compatibility with other sinks.
        """
        if isinstance(args, str) or not hasattr(args, '__iter__'):
            # Strings would otherwise be split into characters, and scalars
            # such as ints are not iterable at all.
            args = [args]
        # Produces the same bytes as writing each cell followed by a tab and
        # then a newline, but uses a single write call.
        self.file.write(''.join(str(item) + '\t' for item in args) + '\n')
        self.file.flush()

    def close(self):
        """Close the underlying file unless it was a standard stream."""
        if self._owns_file:
            self.file.close()

class HumanOutputFormat(OutputFormat):
    """Human-readable sink: boxed key/value tables and aligned message rows."""

    def __init__(self, file):
        # file: an open text stream (sys.stdout, or a file this sink owns).
        self.file = file
        # Decided at construction so that a later swap of sys.stdout (e.g. by
        # contextlib.redirect_stdout) cannot make close() shut a console stream.
        self._owns_file = not any(
            file is std for std in (sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__)
        )

    def writekvs(self, kvs, width_max):
        """Write ``kvs`` as a two-column ASCII table and flush.

        kvs: mapping of key -> value. Values that expose ``__float__`` are
            formatted with ``'%-8.3g'``; other values and all keys go through
            ``str``.
        width_max: keys and value strings longer than this are truncated and
            end with ``'...'``.

        An empty mapping writes nothing.
        """
        key2str = OrderedDict()
        for key, val in kvs.items():
            valstr = '%-8.3g' % (val,) if hasattr(val, '__float__') else str(val)
            key2str[self._truncate(str(key), width_max)] = self._truncate(valstr, width_max)

        if not key2str:
            # max() below would raise ValueError on an empty sequence.
            return

        keywidth = max(map(len, key2str))
        valwidth = max(map(len, key2str.values()))

        # Each line is '| ' + key + ' | ' + val + ' |', i.e. 7 extra characters.
        dashes = '-' * (keywidth + valwidth + 7)
        lines = [dashes]
        for key, val in key2str.items():
            lines.append('| %s | %s |' % (key.ljust(keywidth), val.ljust(valwidth)))
        lines.append(dashes)
        self.file.write('\n'.join(lines) + '\n')
        self.file.flush()

    def _truncate(self, s, width_max):
        """Return ``s`` cut to ``width_max`` characters, ending in '...' when cut."""
        if len(s) <= width_max:
            return s
        # max() keeps the slice non-negative for very small widths.
        return s[:max(width_max - 3, 0)] + '...'

    def writeseq(self, args, width):
        """Write one line and flush.

        args: a string is treated as a one-cell row. A list/tuple is rendered
            as cells right-aligned to ``width`` and joined by ' | '. Any other
            value is written via ``str``.
        width: cell width used for list/tuple/str rows.
        """
        if isinstance(args, str):
            args = [args]
        if isinstance(args, (list, tuple)):
            args = self._fmt_row(args, width)
        self.file.write(str(args))
        self.file.write('\n')
        self.file.flush()

    @staticmethod
    def _fmt_item(item, width):
        """Render one cell right-aligned to ``width`` characters.

        Floats use '%7.5f', switching to '%7.2e' for nonzero magnitudes
        outside [1e-4, 1e4]. Objects with ``ndim == 0`` and an ``item()``
        method (e.g. 0-d array-likes) are unwrapped first. Everything else
        uses ``str``.
        """
        if getattr(item, 'ndim', None) == 0 and hasattr(item, 'item'):
            item = item.item()
        if isinstance(item, float):
            mag = abs(item)
            if 0 < mag and (mag < 1e-4 or mag > 1e4):
                rep = '%7.2e' % item
            else:
                rep = '%7.5f' % item
        else:
            rep = str(item)
        return rep.rjust(width)

    @classmethod
    def _fmt_row(cls, row, width):
        """Join the rendered cells of ``row`` with ' | '.

        Stands in for ``fmt_row`` from baselines.common.console_util, whose
        import is commented out at the top of this file (the original call
        raised NameError). NOTE(review): the cell format is modelled on that
        helper; verify exact parity if it matters.
        """
        return ' | '.join(cls._fmt_item(item, width) for item in row)

    def close(self):
        """Close the underlying file unless it was a standard stream."""
        if self._owns_file:
            self.file.close()


class JSONOutputFormat(OutputFormat):
    """Write each key/value batch as one JSON object per line."""

    def __init__(self, file):
        # file: an open text stream. Unless it is a standard stream, this sink
        # owns it and closes it in close().
        self.file = file
        # Decided at construction so that a later swap of sys.stdout (e.g. by
        # contextlib.redirect_stdout) cannot make close() shut a console stream.
        self._owns_file = not any(
            file is std for std in (sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__)
        )

    def writekvs(self, kvs, width):
        """Serialize ``kvs`` as a single JSON line and flush.

        Values with a ``dtype`` attribute (presumably numpy scalars/arrays) are
        converted with ``.tolist()``:
        - scalar results are coerced to ``float``, as before;
        - list results (multi-element arrays, which previously crashed) are
          written as JSON arrays.

        ``kvs`` itself is left untouched. The Logger passes the same dict to
        every sink, so mutating it would leak these conversions into the
        other sinks. ``width`` is unused.
        """
        row = {}
        for key, val in kvs.items():
            if hasattr(val, 'dtype'):
                val = val.tolist()
                if not isinstance(val, list):
                    val = float(val)
            row[key] = val
        self.file.write(json.dumps(row) + '\n')
        self.file.flush()

    def close(self):
        """Close the underlying file unless it was a standard stream."""
        if self._owns_file:
            self.file.close()


def make_output_format(format, ev_dir=None, prefix='', overwrite=True):
    """Create the OutputFormat sink for one ``type`` member.

    format: a ``type`` member (stdout, log, json or csv).
    ev_dir: directory for file-based sinks. It is created if missing whenever
        it is given, and it is required for every format except type.stdout.
    prefix: optional file-name prefix, joined to the base name with '_'.
    overwrite: truncate existing files when True; append when False.

    File names under ``ev_dir``: log -> log.txt, json -> progress.json,
    csv -> log.csv (each preceded by ``prefix + '_'`` when a prefix is given).

    Raises:
        ValueError: ``format`` is not a known ``type`` member.
        TypeError: a file-based format was requested without ``ev_dir``
            (previously an opaque TypeError from os.path.join).
    """
    if ev_dir is not None:
        os.makedirs(ev_dir, exist_ok=True)

    if format == type.stdout:
        return HumanOutputFormat(sys.stdout)
    elif format == type.log:
        filename, sink_cls = 'log.txt', HumanOutputFormat
    elif format == type.json:
        filename, sink_cls = 'progress.json', JSONOutputFormat
    elif format == type.csv:
        filename, sink_cls = 'log.csv', CsvOuputFormat
    else:
        raise ValueError('Unknown format specified: %s' % (format,))

    if ev_dir is None:
        raise TypeError('ev_dir is required for %s output' % (format,))
    if prefix != '':
        prefix += '_'
    mode = 'wt' if overwrite else 'at'
    # Explicit UTF-8 so the files do not depend on the platform's locale
    # encoding (messages may contain non-ASCII text).
    log_file = open(osp.join(ev_dir, prefix + filename), mode, encoding='utf-8')
    return sink_cls(log_file)

class Logger:
    """Fan-out logger: forwards messages and key/value batches to its sinks.

    A Logger can be used as a context manager; leaving the ``with`` block
    calls close().
    """

    # Module-wide instance used by the module-level helper functions; it is
    # assigned at module level after the class definition.
    DEFAULT = None

    def __init__(self, output_formats, dir=None, name='', width_log=30, width_kv=30, overwrite=True):
        """Create one sink per entry of ``output_formats``.

        output_formats: iterable of ``type`` members.
        dir: directory for file-based sinks (passed to make_output_format).
        name: file-name prefix for file-based sinks.
        width_log: default cell width for log().
        width_kv: cell width for dumpkvs() (see dumpkvs about overrides).
        overwrite: truncate existing log files (True) or append (False).

        If creating a sink raises, the sinks already created are closed before
        the exception propagates.
        """
        self.name2val = OrderedDict()  # values staged since the last dumpkvs()
        self.level = INFO
        self.dir = dir
        self.width_log = width_log
        self.width_kv = width_kv
        self.output_formats = []
        try:
            for fmt in output_formats:
                self.output_formats.append(
                    make_output_format(fmt, dir, name, overwrite=overwrite))
        except BaseException:
            # Do not leak sinks (and their open files) created before the failure.
            self.close()
            raise

    def log(self, args, level=INFO, width=None):
        """Send ``args`` to every sink if ``level`` >= the current level.

        width: cell width for this call only; defaults to ``self.width_log``.
        """
        if width is None:
            width = self.width_log
        if self.level <= level:
            self._log(args, width)

    def logkv(self, key, val):
        """Stage ``key`` = ``val``; a later value for the same key replaces it."""
        self.name2val[key] = val

    def dumpkvs(self, width=None):
        """Write the staged key/value pairs to every sink, then clear them.

        width: if given, it replaces ``self.width_kv`` for this call AND all
            later calls (this behavior is kept as-is).
        Does nothing when no values are staged.
        """
        if not self.name2val:
            return
        if width is not None:
            self.width_kv = width
        for fmt in self.output_formats:
            fmt.writekvs(self.name2val, self.width_kv)
        self.name2val.clear()

    def _log(self, args, width):
        """Forward ``args`` to every sink's writeseq, without a level check."""
        for fmt in self.output_formats:
            fmt.writeseq(args, width=width)

    def set_level(self, level):
        """Set the minimum level that log() will emit."""
        self.level = level

    def close(self):
        """Call close() on every sink."""
        for fmt in self.output_formats:
            fmt.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never swallow exceptions from the with-block

# Module-wide logger used by the helper functions below: writes to stdout only.
Logger.DEFAULT = Logger([type.stdout])
def log(args, level=INFO, width=None):
    """Forward ``args`` to Logger.DEFAULT.log (same level/width semantics)."""
    target = Logger.DEFAULT
    return target.log(args, level=level, width=width)

def setlevel(level):
    """Change the minimum level emitted by Logger.DEFAULT."""
    target = Logger.DEFAULT
    return target.set_level(level)

def logkv(key, val):
    """Stage ``key`` = ``val`` on Logger.DEFAULT until the next dumpkvs()."""
    target = Logger.DEFAULT
    return target.logkv(key, val)

def dumpkvs(width=None):
    """Flush Logger.DEFAULT's staged key/value pairs to its sinks."""
    target = Logger.DEFAULT
    return target.dumpkvs(width)


import time
class LogTime:
    """Record per-stage durations and write them as rows through a csv Logger.

    Usage: call the instance with a stage name after each stage finishes, then
    call complete() at the end of an iteration to write one row.
    """

    def __init__(self, name, path_logger, interval_showtitle=10):
        """Open (in append mode) the csv log named after ``name`` in ``path_logger``.

        name: file-name prefix passed to the Logger.
        path_logger: directory for the log file.
        interval_showtitle: a header row of stage names is written before the
            first data row and then before every ``interval_showtitle``-th one.

        Raises:
            ValueError: ``interval_showtitle`` is smaller than 1.
        """
        if interval_showtitle < 1:
            raise ValueError('interval_showtitle must be >= 1, got %r' % (interval_showtitle,))
        # perf_counter is monotonic, so measured durations cannot go negative
        # or jump when the wall clock is adjusted (time.time() can).
        self.time = time.perf_counter()
        self.ind = 0  # number of rows written so far
        self.dict = {}  # stage name -> seconds; keys persist across rows
        self.interval_showtitle = interval_showtitle
        self.logger = Logger(dir=path_logger, output_formats=[type.csv], name=name,
                             overwrite=False, width_kv=20, width_log=20)

    def __call__(self, name):
        """Store the seconds elapsed since the previous checkpoint under ``name``."""
        now = time.perf_counter()
        self.dict[name] = now - self.time
        # Reuse one timestamp so no time is lost between consecutive stages.
        self.time = now

    def complete(self):
        """Write the current row: stage durations plus a wall-clock end time."""
        self.dict['time_end'] = time.strftime('%m/%d|%H:%M:%S', time.localtime())
        if self.ind % self.interval_showtitle == 0:
            self.logger.log(list(self.dict.keys()))
        self.logger.log(list(self.dict.values()))
        self.ind += 1

    def close(self):
        """Close the underlying Logger."""
        self.logger.close()


def _demo():
    """Usage example: log a header row to stdout and to <dir>/aa_log.csv."""
    log_dir = "log保存路径"  # directory where the log files are saved
    # File-based sinks get the prefix 'aa' (e.g. aa_log.csv).
    demo_logger = Logger(dir=log_dir, output_formats=[type.stdout, type.csv], name='aa')
    try:
        # Call the Logger instance; ``log`` at module level is a plain function
        # with no ``.log`` attribute, so ``log.log(...)`` raised AttributeError.
        demo_logger.log(('Iter', 'Loss'))  # iteration count and loss
    finally:
        demo_logger.close()


if __name__ == "__main__":
    _demo()

# Web-page residue from the source blog: 你可能感兴趣的:(笔记杂) ("You may also like: (misc notes)")