数据库存储过程的单元测试工具

继昨天完成数据库数据录入与核对工具之后,今天又写了一个数据库存储过程单元测试的小工具,在此与大家分享。

GitHub 地址:https://github.com/xunmeng2002/python/tree/master/dbtest

本模块依赖于昨天完成的 check_db 模块(CSV 解析与数据库数据检查)。

本模块包含两个源文件和一个配置文件:test_struct.py 定义测试套件、测试用例的结构及相关函数;run_utest.py 负责收集测试用例、读取配置并执行测试;配置文件为 utest.ini。代码基于 Python 2 编写。

test_struct.py:

# encoding: utf-8
import os
import sys
import traceback
import datetime

sys.path.append("..\source_py")
from common_utils import common_utils
from common_utils import db_struct
from check_db import db_operate
from check_db import csv_parse


class TestSuite:
    """A named group of test cases that share one directory.

    Attributes:
        name: suite name.
        dir: path of the suite directory.
        cases: case objects to run. Each must provide ``exec_case(cursor)``,
            which returns a truthy value on success. Each also has an
            ``elapse`` attribute that ``exec_suite`` overwrites.
        pass_num / fail_num: counters accumulated by ``exec_suite``.
    """

    def __init__(self, name, dir, cases):
        self.name = name
        self.cases = cases
        self.dir = dir
        self.fail_num = 0
        self.pass_num = 0

    def exec_suite(self, cursor):
        """Run every case in order on *cursor* and tally passes and failures.

        Each case's wall-clock run time, in whole seconds, is stored in
        ``case.elapse``.
        """
        print("exec suite[%s]:" % self.name)
        for case in self.cases:
            start = datetime.datetime.now()
            if case.exec_case(cursor):
                self.pass_num += 1
            else:
                self.fail_num += 1
            # total_seconds() covers the whole duration. timedelta.seconds is
            # only the 0-86399 seconds component and silently drops full days.
            case.elapse = int((datetime.datetime.now() - start).total_seconds())


class TestCase:
    """One stored-procedure test case backed by a directory.

    Directory layout used by this class:
        command.txt -- SQL statements separated by ';'. Text from '#' to the
                       end of a line is a comment. ``EXPECT(a, b)`` and
                       ``EXPECT_NOT_EQUAL(a, b)`` statements are assertions
                       comparing ``select a`` with ``select b``.
        init/       -- data loaded before the commands run
                       (db_operate.load_csv_into_db).
        expect/     -- expected data checked after the commands run
                       (csv_parse.parse_csvs + db_operate.check_all_result).

    The schema prefixes ``admin.``, ``history.``, ``init.`` and ``sync.`` in
    the commands are rewritten to the configured database names.
    """

    def __init__(self, name, dir, admin_db, history_db, init_db, sync_db):
        self.name = name
        self.dir = dir
        self.admin_db = admin_db
        self.history_db = history_db
        self.init_db = init_db
        self.sync_db = sync_db
        self.command = ""     # command.txt contents with comments removed
        self.result = True    # outcome of the last exec_case run
        self.msg = ''         # failure description or traceback
        self.elapse = 0       # run time in seconds; filled in by the caller
        self.read_command()

    def read_command(self):
        """Append the contents of ``<dir>/command.txt`` to ``self.command``.

        Everything from the first '#' on a line is dropped, including lines
        that start with '#'. A space is inserted between lines, so one
        statement may span several lines even when trim_return strips the
        line ending.
        """
        command_path = os.path.join(self.dir, "command.txt")
        with open(command_path) as command_file:
            for line in command_file:
                # NOTE(review): trim_return presumably strips the trailing line
                # ending -- confirm in common_utils.
                line = common_utils.trim_return(line)
                # '>= 0' (not '> 0') so a comment starting in column 0 is removed.
                comment_at = line.find("#")
                if comment_at >= 0:
                    line = line[:comment_at]
                self.command += line + " "

    def replace_db_name(self, sql):
        """Return *sql* with the schema prefixes mapped to real database names.

        ``admin.``, ``history.``, ``init.`` and ``sync.`` become
        ``<admin_db>.``, ``<history_db>.``, ``<init_db>.`` and ``<sync_db>.``.
        Only whole identifiers are matched, so ``myadmin.t`` is left alone.
        All four prefixes are rewritten in one pass, so an inserted name is
        never rewritten again.
        """
        import re
        names = {
            'admin': self.admin_db,
            'history': self.history_db,
            'init': self.init_db,
            'sync': self.sync_db,
        }
        # TODO(review): an earlier version also called sql.replace('', <db>)
        # for each database. Replacing the empty string inserts the name
        # between every character and corrupts the statement. The intended
        # placeholder tokens appear to have been lost; restore them here if
        # they are needed.
        return re.sub(r'\b(admin|history|init|sync)\.',
                      lambda match: names[match.group(1)] + '.', sql)

    def check_expect(self, cursor, sql):
        """Evaluate an ``EXPECT(a, b)`` or ``EXPECT_NOT_EQUAL(a, b)`` statement.

        The arguments are split at the first comma that is not nested inside
        parentheses or quotes, so either side may contain sub-queries,
        function calls or commas.

        Returns:
            True when the expectation holds. Otherwise the failure is recorded
            in ``self.result``/``self.msg`` and False is returned.
        Raises:
            ValueError: if *sql* is not a well-formed EXPECT statement.
        """
        statement = sql.strip().rstrip(';').strip()
        if statement.startswith('EXPECT_NOT_EQUAL'):
            keyword, check = 'EXPECT_NOT_EQUAL', self.expect_not_equal
        elif statement.startswith('EXPECT'):
            keyword, check = 'EXPECT', self.expect_equal
        else:
            raise ValueError("not an EXPECT statement: %r" % (sql,))
        value1, value2 = self._split_expect_args(statement[len(keyword):])
        return check(cursor, value1, value2)

    def _split_expect_args(self, args_text):
        """Split ``(a, b)`` into ``('a', 'b')`` at the top-level comma."""
        text = args_text.strip()
        if not (text.startswith('(') and text.endswith(')')):
            raise ValueError("malformed EXPECT arguments: %r" % (args_text,))
        inner = text[1:-1]
        depth = 0
        quote = None
        for index, char in enumerate(inner):
            if quote:
                if char == quote:
                    quote = None
            elif char in ("'", '"'):
                quote = char
            elif char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
            elif char == ',' and depth == 0:
                return inner[:index].strip(), inner[index + 1:].strip()
        raise ValueError("EXPECT needs two comma-separated arguments: %r" % (args_text,))

    def _select_value(self, cursor, expression):
        """Run ``select <expression>`` and return the first column of the first row."""
        cursor.execute("select " + expression)
        return cursor.fetchone()[0]

    def expect_equal(self, cursor, value1, value2):
        """Return True if ``select value1`` equals ``select value2``; otherwise record the failure."""
        v1 = self._select_value(cursor, value1)
        v2 = self._select_value(cursor, value2)
        if v1 != v2:
            self.result = False
            self.msg = "EXPECT(%s,%s):[%s]不等于[%s]" % (value1, value2, str(v1), str(v2))
            return False
        return True

    def expect_not_equal(self, cursor, value1, value2):
        """Return True if ``select value1`` differs from ``select value2``; otherwise record the failure."""
        v1 = self._select_value(cursor, value1)
        v2 = self._select_value(cursor, value2)
        if v1 == v2:
            self.result = False
            self.msg = "EXPECT_NOT_EQUAL(%s,%s):[%s]等于[%s]" % (value1, value2, str(v1), str(v2))
            return False
        return True

    def exec_case(self, cursor):
        """Run this case: load init data, execute the commands, check the data.

        1. Load CSV data from ``<dir>/init``.
        2. Execute each ';'-separated statement. EXPECT and EXPECT_NOT_EQUAL
           statements are evaluated as assertions; the first failing one
           ends the case.
        3. Compare the database against ``<dir>/expect``.

        Any exception marks the case as failed and stores the traceback in
        ``self.msg``.

        Returns:
            True if the case passed, otherwise False (also stored in
            ``self.result``).
        """
        print("\texec case[%s]." % self.name)
        # Reset so that re-running a case does not inherit an earlier outcome.
        self.result = True
        self.msg = ''
        try:
            init_dir = os.path.join(self.dir, "init")
            expect_dir = os.path.join(self.dir, "expect")
            db_operate.load_csv_into_db(cursor, init_dir, self.admin_db, self.history_db,
                                        self.init_db, self.sync_db)
            for sql in common_utils.trim_return(self.command).split(';'):
                # Strip so whitespace-only fragments are skipped and EXPECT is
                # recognised even after a line break or space.
                sql = self.replace_db_name(sql).strip()
                if not sql:
                    continue
                if sql.startswith("EXPECT"):
                    # check_expect records self.result/self.msg on failure.
                    if not self.check_expect(cursor, sql):
                        return False
                    continue
                try:
                    cursor.execute(sql)
                except Exception:
                    print(sql)  # show the failing statement before propagating
                    raise       # bare raise keeps the original traceback
            db_info = db_struct.DbInfo(self.admin_db, self.history_db, self.init_db, self.sync_db)
            csv_parse.parse_csvs(cursor, expect_dir, db_info)
            self.result, self.msg = db_operate.check_all_result(cursor, db_info)
            return self.result
        except Exception:
            self.msg = traceback.format_exc()
            self.result = False
            return self.result

run_utest.py:

# encoding:utf-8
import os
import ConfigParser
import datetime
import sys
import re
import webbrowser as web
import test_struct

sys.path.append("..\source_py")
from check_db import db_operate

reload(sys)
# Python 2 only: reload(sys) makes sys.setdefaultencoding available again.
# Setting UTF-8 lets implicit str/unicode conversions of the Chinese text
# succeed. This hack does not exist on Python 3.
sys.setdefaultencoding("utf-8")
# Command-line arguments, used globally.
# Option names accepted by parse_args as "--name" or "--name=value".
# NOTE(review): 'only-load-data' is accepted but is not read anywhere in the
# code shown.
arg_names = ['help', 'filter', 'suite-filter', 'only-load-data']
# Parsed options: option name -> value string, or None when given without "=".
args = {}


def get_suites(root_dir, admin_db, history_db, init_db, sync_db):
    """Collect the test suites under *root_dir*.

    Every sub-directory except ``report`` is a suite; get_cases collects its
    cases. Directories are visited in sorted order, so the run order and the
    report are reproducible. With ``--suite-filter=REGEX``, only suites whose
    name matches (re.search) are kept. Suites without cases are dropped.

    Returns:
        list of test_struct.TestSuite.
    """
    suite_filter = args.get("suite-filter")
    test_suites = []
    for name in sorted(os.listdir(root_dir)):
        if name == 'report':
            continue
        suite_dir = os.path.join(root_dir, name)
        if not os.path.isdir(suite_dir):
            continue
        # A bare "--suite-filter" (value None) means "no filter" rather than
        # a TypeError inside re.search.
        if suite_filter and re.search(suite_filter, name) is None:
            continue
        cases = get_cases(suite_dir, admin_db, history_db, init_db, sync_db)
        if cases:
            test_suites.append(test_struct.TestSuite(name, suite_dir, cases))
    return test_suites


def get_cases(suite_dir, admin_db, history_db, init_db, sync_db):
    """Collect the test cases of one suite directory.

    Every sub-directory of *suite_dir* is a case; they are visited in sorted
    order so runs are reproducible. With ``--filter=REGEX``, only cases whose
    name matches (re.search) are kept. A bare ``--filter`` without a value is
    ignored.

    Returns:
        list of test_struct.TestCase. Constructing each one reads its
        command.txt.
    """
    case_filter = args.get("filter")
    cases = []
    for name in sorted(os.listdir(suite_dir)):
        case_dir = os.path.join(suite_dir, name)
        if not os.path.isdir(case_dir):
            continue
        if case_filter and re.search(case_filter, name) is None:
            continue
        cases.append(test_struct.TestCase(name, case_dir, admin_db, history_db, init_db, sync_db))
    return cases


def print_summary(suites):
    """Print a "test result:" header, then one pass/fail line per executed case."""
    print('test result:')
    for suite in suites:
        for case in suite.cases:
            # Column layout: passes are left-aligned, failures right-aligned.
            status = 'pass  ' if case.result else '  fail'
            print('%s:[%s].%s' % (status, suite.name, case.name))


def parse_args():
    """Parse ``--name`` / ``--name=value`` options from sys.argv into ``args``.

    Only names listed in ``arg_names`` are accepted. Anything else prints an
    error and exits with status -1. Only the first '=' separates name from
    value, so a value (e.g. a regex) may itself contain '='. An option given
    without '=' is stored with the value None.
    """
    for arg in sys.argv[1:]:
        key, sep, value = arg.partition("=")
        name = key[2:]
        if not key.startswith('--') or name not in arg_names:
            print('error:%s is invalid argument' % key)
            sys.exit(-1)
        args[name] = value if sep else None


def _html_escape(text):
    """Return *text* as a string with &, < and > escaped for HTML."""
    return ("%s" % (text,)).replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")


def report(suites, start_time, end_time):
    """Write an HTML report of *suites* under ./report and open it in a browser.

    The report shows the start, end and elapsed time, the pass/fail totals
    for suites and for cases, and a table with one row per case: suite,
    case, result, and elapsed seconds. Failure messages are HTML-escaped.

    Args:
        suites: list of executed test_struct.TestSuite objects.
        start_time, end_time: datetime.datetime bounds of the whole run.
    """
    report_dir = "./report"
    if not os.path.isdir(report_dir):
        os.makedirs(report_dir)
    html_name = os.path.join(report_dir, "report%s.html" % start_time.strftime("%Y%m%d_%H%M%S"))

    case_pass_num = sum(suite.pass_num for suite in suites)
    case_fail_num = sum(suite.fail_num for suite in suites)
    # A suite counts as failed as soon as one of its cases failed.
    suite_fail_num = sum(1 for suite in suites if suite.fail_num > 0)
    suite_pass_num = len(suites) - suite_fail_num
    # total_seconds(): timedelta.seconds alone would drop whole days.
    elapsed = int((end_time - start_time).total_seconds())

    lines = [
        '<!DOCTYPE html>',
        '<html>',
        '<head>',
        '<meta charset="utf-8">',
        '<title>测试结果</title>',
        '<style>',
        'table { border-collapse: collapse; }',
        'th, td { border: 1px solid #999; padding: 4px 8px; vertical-align: top; }',
        '</style>',
        '</head>',
        '<body>',
        '<h2>测试结果</h2>',
        '<p>开始时间:%s</p>' % start_time.strftime("%Y-%m-%d %H:%M:%S"),
        '<p>结束时间:%s</p>' % end_time.strftime("%Y-%m-%d %H:%M:%S"),
        '<p>执行时间:%d 秒</p>' % elapsed,
        '<p>总共执行测试套件数量:%d,通过数量:%d,失败数量:%d</p>' % (
            suite_pass_num + suite_fail_num, suite_pass_num, suite_fail_num),
        '<p>总共执行测试用例数量:%d,通过数量:%d,失败数量:%d</p>' % (
            case_pass_num + case_fail_num, case_pass_num, case_fail_num),
        '<table>',
        '<tr><th>测试套件</th><th>测试用例</th><th>结果</th><th>耗时</th></tr>',
    ]
    for suite in suites:
        for index, case in enumerate(suite.cases):
            lines.append('<tr>')
            if index == 0:
                # The suite cell spans all of its case rows: name(run, failed).
                lines.append('<td rowspan="%d">%s(%d,%d)</td>' % (
                    len(suite.cases), _html_escape(suite.name),
                    suite.pass_num + suite.fail_num, suite.fail_num))
            lines.append('<td>%s</td>' % _html_escape(case.name))
            if case.result:
                lines.append('<td>pass</td>')
            else:
                # Tracebacks contain '<' (e.g. "<module>") and must be escaped.
                lines.append('<td>failed:<pre>%s</pre></td>' % _html_escape(case.msg))
            lines.append('<td>%d秒</td>' % case.elapse)
            lines.append('</tr>')
    lines.extend(['</table>', '</body>', '</html>'])

    content = "\n".join(lines) + "\n"
    if not isinstance(content, bytes):
        # Text (Python 3 str, or unicode on Python 2) must be encoded for a binary file.
        content = content.encode("utf-8")
    with open(html_name, "wb") as html:
        html.write(content)
    web.open_new_tab(os.path.abspath(html_name))


def show_help():
    """Print the supported command-line options."""
    print('args:')
    print('  --filter=REGEX        [optional] only run test cases whose name matches REGEX')
    print('  --suite-filter=REGEX  [optional] only run test suites whose name matches REGEX')
    print('  --only-load-data      [optional] only load init data')
    print('  --help                [optional] show help info')


def main():
    """Entry point: parse options, run the selected suites, then report.

    Connection settings, database names and the test root directory are read
    from utest.ini ([db] and [path] sections) in the current directory.
    """
    parse_args()
    if 'help' in args:
        show_help()
        sys.exit(0)
    # NOTE(review): '--only-load-data' is accepted by parse_args but is not
    # acted on anywhere in this script.

    cfg = ConfigParser.ConfigParser()
    cfg.read("utest.ini")
    db_user = cfg.get('db', 'user')
    db_password = cfg.get('db', 'password')
    db_host = cfg.get('db', 'host')
    db_port = cfg.get('db', 'port')
    db_database = cfg.get('db', 'curr_db')
    conn = db_operate.connect_db(db_user, db_password, db_host, db_port, db_database)
    cursor = conn.cursor()
    try:
        admin_db = cfg.get("db", "admin_db")
        history_db = cfg.get("db", "history_db")
        init_db = cfg.get("db", "init_db")
        sync_db = cfg.get("db", "sync_db")

        start_time = datetime.datetime.now()
        root_dir = cfg.get('path', 'root_dir')
        suites = get_suites(root_dir, admin_db, history_db, init_db, sync_db)
        for suite in suites:
            suite.exec_suite(cursor)
        end_time = datetime.datetime.now()
    finally:
        # Assumes a DB-API style connection/cursor exposing close() -- confirm
        # against db_operate.connect_db.
        cursor.close()
        conn.close()

    report(suites, start_time, end_time)
    print_summary(suites)


if __name__ == '__main__':
    main()

utest.ini:

[db]
user=test
password=Test@1234
host=192.168.6.125
port=3306
admin_db=test_admin
history_db=test_history
init_db=test_init
sync_db=test_sync
curr_db=test_history

[path]
root_dir=./settlement_test_suite

 

你可能感兴趣的:(Python,SQL,UnitTest,紫云的程序人生)