Jtester+unitils+testng:DAO单元测试文件模板自动生成


         在使用 jtester+unitils+testng 做数据库接口的单元测试框架中, 常常需要编写一些 wiki 及 DAOTest java 文件,  比如:

/**
 * Hand-written example of a DbFit-driven DAO test. Each test names a
 * "when" wiki file that prepares the table and, where given, a "then" wiki
 * file describing the rows expected afterwards.
 */
public class XXXDefaultDAOTest extends BaseRegionDbDAOTestCase {

	// NOTE(review): injected by bean name; presumably the field name is the
	// bean name, so it is deliberately left unchanged.
	@SpringBeanByName
	private XXXDefaultDAO XXXDefaultDAO;
	
	/** Inserts one fully populated row into an empty table. */
	@Test
	@DbFit(when="XXXDefaultDAOTest.initBlank.when.wiki", then="XXXDefaultDAOTest.queryOneRecord.then.wiki")
	public void testInsertXXXDefaultDO() {
		XXXDefaultDO row = new XXXDefaultDO();
		row.setId(1L);
		row.setCidrBlock("192.168.10.10");
		row.setIpProtocol("tcp");
		row.setPortRange("3000:4000");
		row.setPolicy(Policy.POLICY_ACCEPT);
		row.setNic(Nic.INTRANET);
		row.setPriority(65533L);
		row.setType(1L);
		row.setIsDeleted(0L);
		row.setDescription("test1");
		row.setGmtCreate(new Date());
		row.setGmtModify(new Date());

		XXXDefaultDAO.insertXXXDefaultDO(row);
	}

	/** Counts with an empty example object against the seeded table. */
	@Test
	@DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki")
	public void testCountXXXDefaultDOByExample() {
		XXXDefaultDO emptyExample = new XXXDefaultDO();
		Assert.assertTrue(XXXDefaultDAO.countXXXDefaultDOByExample(emptyExample) == 1);
	}

	/** Loads the seeded row, changes three columns and writes it back. */
	@Test
	@DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki", then="XXXDefaultDAOTest.testUpdate.then.wiki")
	public void testUpdateXXXDefaultDO() {
		XXXDefaultDO seeded = XXXDefaultDAO.findXXXDefaultDOByPrimaryKey(6L);
		seeded.setIpProtocol("udp");
		seeded.setNic(Nic.INTERNET);
		seeded.setDescription("desc");

		XXXDefaultDAO.updateXXXDefaultDO(seeded);
	}

	/** Queries by cidr block and policy and checks every returned row. */
	@Test
	@DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki")
	public void testFindListByExample() {
		final String expectedCidr = "10.152.126.83";
		final Policy expectedPolicy = Policy.POLICY_ACCEPT;

		XXXDefaultDO example = new XXXDefaultDO();
		example.setCidrBlock(expectedCidr);
		example.setPolicy(expectedPolicy);

		List<XXXDefaultDO> matches = XXXDefaultDAO.findListByExample(example);
		Assert.assertEquals(matches.size(), 1);
		for (XXXDefaultDO match : matches) {
			Assert.assertEquals(match.getCidrBlock(), expectedCidr);
			Assert.assertEquals(match.getPolicy(), expectedPolicy);
		}
	}

	/** Reads the seeded row by primary key and checks each column. */
	@Test
	@DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki")
	public void testFindXXXDefaultDOByPrimaryKey() {
		XXXDefaultDO seeded = XXXDefaultDAO.findXXXDefaultDOByPrimaryKey(6L);

		Assert.assertEquals(seeded.getCidrBlock(), "10.152.126.83");
		Assert.assertEquals(seeded.getIpProtocol(), "all");
		Assert.assertEquals(seeded.getPortRange(), "");
		Assert.assertEquals(seeded.getPolicy(), Policy.POLICY_ACCEPT);
		Assert.assertEquals(seeded.getNic(), Nic.BOTH);
		Assert.assertEquals(seeded.getPriority().longValue(),1L);
		Assert.assertEquals(seeded.getType().intValue(), 1);
		Assert.assertEquals(seeded.getIsDeleted().intValue(), 0);
		Assert.assertEquals(seeded.getDescription(), "bie dong");
	}

	/** Deletes the seeded row twice: the first call removes it, the second finds nothing. */
	@Test
	@DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki", then="")
	public void testDeleteXXXDefaultDOByPrimaryKey() {
		Integer firstAttempt = XXXDefaultDAO.deleteXXXDefaultDOByPrimaryKey(6L);
		Assert.assertEquals(firstAttempt.intValue(), 1);

		Integer secondAttempt = XXXDefaultDAO.deleteXXXDefaultDOByPrimaryKey(6L);
		Assert.assertEquals(secondAttempt.intValue(), 0);
	}

}

       其中, 数据准备文件在 *.when.wiki 中, 数据验证文件在 *.then.wiki 中, 数据库中只需要保证正确的表结构即可。 每次单元测试都是自动化可重复的。

       XXXDefaultDAOTest.initBlank.when.wiki

|connect|
|clean table|xxx_default|
      XXXDefaultDAOTest.initRecords.when.wiki

|connect|
|clean table|xxx_default|
|clean table|xxx|
|insert|xxx_default|
| id | gmt_create          | gmt_modify          | cidr_block    | ip_protocol | port_range | policy | nic | priority | type | is_deleted | description |
|  6 | 2014-04-08 20:18:04 | 2014-04-08 20:18:04 | 10.152.126.83 | all         |            | accept |   3 |        1 |    1 |          0 | bie dong    |
    XXXDefaultDAOTest.queryOneRecord.then.wiki

|connect|
|query|select cidr_block,  ip_protocol, port_range, policy, nic, priority, type, is_deleted, description from xxx_default |
|cidr_block    | ip_protocol | port_range | policy | nic | priority | type | is_deleted | description |
|192.168.10.10 | tcp         | 3000:4000  | accept |   2 |    65533 |    1 |          0 | test1       |
    XXXDefaultDAOTest.testUpdate.then.wiki

|connect|
|query|select cidr_block,  ip_protocol, port_range, policy, nic, priority, type, is_deleted, description from xxx_default|
| cidr_block    | ip_protocol | port_range | policy | nic | priority | type | is_deleted | description |
| 10.152.126.83 | udp         |            | accept |   1 |        1 |    1 |          0 | desc        |

       显然, 如果每个 DAO 测试类都写这些 WIKI  及 DAO 类(set/get 字段很耗体力), 那会是比较大的工作量。 这时候, 最好能够自动生成这些文件或文件模板, 减少手工的劳动量。

   因此, 我编写了一个 python 程序, 根据指定的配置, 自动生成相关的 wiki 文件及 DAOTest java 测试文件模板。

   readcfg.py :  读取DAO测试类信息的配置文件    

from ConfigParser import ConfigParser

# Module-level parser shared by the accessor functions below; it is loaded
# once, when this module is imported.
config = ConfigParser()
# The path is relative, so it is resolved against the current working directory.
# NOTE(review): ConfigParser.read() skips files it cannot open instead of
# raising, so a missing daotest.conf just yields zero sections.
config.read("daotest.conf")

def getAllDAOTestInfo():
	'''
		Return a dict mapping every section name of the loaded config
		to the info dict built by getDAOTestInfo for that section.
	'''
	return dict((sectionName, getDAOTestInfo(sectionName)) for sectionName in config.sections())

def getDAOTestInfo(daoTestName) :
    '''
        Read the four options of config section `daoTestName` and return
        them as a dict keyed by option name; values are the raw strings
        returned by config.get.
    '''
    # Looked up in this order, so a missing option fails on the first one absent.
    optionNames = ('DaoTestName', 'TableName', 'FieldArray', 'NumTypeFields')
    return dict((optionName, config.get(daoTestName, optionName)) for optionName in optionNames)

    create_daotest_wiki.py : 生成 dao 测试的测试文件模板:    

    

import readcfg
import time
import re

def gene_daotest(daoTestInfo):
	'''
		Generate the wiki files and the java test file for one DAO test.

		daoTestInfo: dict with the keys 'DaoTestName', 'TableName',
		'FieldArray' and 'NumTypeFields'; the last two are comma-separated
		strings. Progress and the elapsed time are printed to stdout.
	'''
	commaSep = re.compile('\s*,\s*')

	daoTestName = daoTestInfo['DaoTestName']
	tableName = daoTestInfo['TableName']
	fieldArray = commaSep.split(daoTestInfo['FieldArray'])
	numTypeFields = set(commaSep.split(daoTestInfo['NumTypeFields']))

	# A single pre-formatted string reproduces the spacing of the original
	# comma-separated print statement.
	print(' ***  %s  start...\n' % daoTestName)

	startTime = time.clock()
	gene_daotest_wiki_really(daoTestName, tableName, fieldArray, numTypeFields)
	gene_daotest_java(daoTestName, tableName, fieldArray, numTypeFields)
	elapsedMs = (time.clock() - startTime) * 1000

	print(' ***  %s  finished.\n' % daoTestName)
	print('time cost:  %sms.\n' % str(elapsedMs))

def gene_daotest_wiki_really(daoTestName, tableName, fieldArray, numTypeFields):
	'''
		Generate the four DbFit wiki files used by the DAO test java file.

		daoTestName:   prefix of every generated file name
		tableName:     table named in the clean/insert/query rows
		fieldArray:    list of column names, in table order
		numTypeFields: set of the column names holding numeric values

		Files are written to the current directory:
		  <daoTestName>.initBlank.when.wiki       connect + clean table
		  <daoTestName>.initRecords.when.wiki     the above + one inserted row
		  <daoTestName>.queryOneRecord.then.wiki  query + one expected row
		  <daoTestName>.testUpdate.then.wiki      same content as queryOneRecord
	'''

	conn = '|connect|'
	clean_table = '|clean table|' + tableName + '|'
	insert_table = '|insert|' + tableName + '|'
	all_fields = '|' + getfieldsWithSep(fieldArray, 0, '|') + '|'
	# The "then" files leave out id and gmt_* columns.
	query_stmt = '|query|' + 'select ' + getfieldsWithSep(fieldArray, 0, ', ', filterTimeAndIdFieldFunc) + ' from ' + tableName + '|'
	query_fields = '|' + getfieldsWithSep(fieldArray, 0, '|', filterTimeAndIdFieldFunc) + '|'
	all_fields_default_values = '|' + getfieldValuesWithSep(fieldArray, numTypeFields, 0, '|') + '|'
	query_fields_default_values = '|' + getfieldValuesWithSep(fieldArray, numTypeFields, 0, '|', filterTimeAndIdFieldFunc) + '|'

	expected_row = [conn, query_stmt, query_fields, query_fields_default_values]
	wiki_files = [
		('.initBlank.when.wiki', [conn, clean_table]),
		('.initRecords.when.wiki', [conn, clean_table, insert_table, all_fields, all_fields_default_values]),
		('.queryOneRecord.then.wiki', expected_row),
		('.testUpdate.then.wiki', expected_row),
	]

	for suffix, lines in wiki_files:
		# "with" closes the file; the original referenced f.close without
		# calling it, so nothing was ever explicitly closed.
		with open(daoTestName + suffix, 'w') as wiki_file:
			wiki_file.write('\n'.join(lines))
	
def gene_daotest_java(daoTestName, tableName, fieldArray, numTypeFields):
	'''
		Generate <daoTestName>.java from TemplateDefaultDAOTest.java.

		The template text has 'XXX' replaced by the part of daoTestName that
		precedes 'DAOTest', 'YYY' by the same prefix with its first letter
		lowered, and the '$setFields' / '$AssertGetValues' placeholders by the
		generated setter / assertion statements (gmt_* fields excluded).

		tableName is accepted for signature symmetry and is not used here.

		Raises ValueError when daoTestName has no non-empty prefix in front of
		'DAOTest'; previously such a name silently produced a wrong prefix
		(or an IndexError for an empty one).
	'''

	daoPrefixIndex = daoTestName.find('DAOTest')
	if daoPrefixIndex <= 0:
		raise ValueError("DaoTestName '%s' must look like '<Prefix>DAOTest'" % daoTestName)

	XXXReplacer = daoTestName[0: daoPrefixIndex]
	YYYReplacer = firstLowerCase(XXXReplacer)
	filteredFieldArray = getFilteredFields(fieldArray, filterTimeFieldFunc)

	# Read the template before creating the output file, so a missing template
	# does not leave an empty .java file behind.
	with open('TemplateDefaultDAOTest.java') as f_daotest_tmpl:
		content = f_daotest_tmpl.read()

	contentReplaced = content.replace('XXX', XXXReplacer).replace('YYY', YYYReplacer)  \
							 .replace('$setFields', geneSetFields(filteredFieldArray, numTypeFields, YYYReplacer)) \
							 .replace('$AssertGetValues', geneAssertGetValues(filteredFieldArray, numTypeFields, YYYReplacer))

	# "with" guarantees the generated file is flushed and closed.
	with open(daoTestName+'.java', 'w') as f_daotest_java:
		f_daotest_java.write(contentReplaced)


def geneAssertGetValues(fieldArray, numTypeFields, YYYReplacer):
	'''
		Build one 'Assert.assertEquals(<var>.get<Field>(), <value>);' java
		statement per field, comparing against the field's default value.
		Values of fields not in numTypeFields are wrapped in double quotes.
		Every statement is followed by a newline and two tabs.
	'''
	statementFormat = 'Assert.assertEquals(%s.get%s(), %s%s%s);\n%s'
	statements = []
	for field in fieldArray:
		if field in numTypeFields:
			quoteStr = ''
		else:
			quoteStr = '"'
		statements.append(statementFormat % (
			YYYReplacer, transformField(field),
			quoteStr, getDefaultValueForField(field, numTypeFields), quoteStr,
			indentTimes(2)))
	return ''.join(statements)
	
def geneSetFields(fieldArray, numTypeFields, YYYReplacer):
	'''
		Build one '<var>.set<Field>(<value>);' java statement per field,
		using the field's default value. Values of fields not in
		numTypeFields are wrapped in double quotes. Every statement is
		followed by a newline and two tabs.
	'''
	statementFormat = '%s.set%s(%s%s%s);\n%s'
	statements = []
	for field in fieldArray:
		if field in numTypeFields:
			quoteStr = ''
		else:
			quoteStr = '"'
		statements.append(statementFormat % (
			YYYReplacer, transformField(field),
			quoteStr, getDefaultValueForField(field, numTypeFields), quoteStr,
			indentTimes(2)))
	return ''.join(statements)
	
def transformField(field):
	'''
	   Convert a field name from underscore form to upper camel form.
	   eg.  cidr_block ==> CidrBlock

	   Empty parts produced by leading, trailing or doubled underscores
	   contribute nothing (eg. a__b ==> AB) instead of raising IndexError.
	'''
	# Slicing with [:1] is safe on an empty part, unlike indexing with [0].
	return ''.join(part[:1].upper() + part[1:] for part in field.split('_'))

def indentTimes(num):
	'''
		Return a string of `num` tab characters (empty when num <= 0).
	'''
	return '\t' * num
	
def firstLowerCase(input):
	'''
		Return `input` with its first letter lowered.
		eg. NcDAOTest ==> ncDAOTest
		An empty string is returned unchanged (it used to raise IndexError).
	'''
	# input[:1] is '' for an empty string, whereas input[0] would raise.
	return input[:1].lower() + input[1:]
	
def firstSuperCase(input):
	'''
		Return `input` with its first letter uppered.
		eg. ncDAOTest ==> NcDAOTest
		An empty string is returned unchanged (it used to raise IndexError).
	'''
	# input[:1] is '' for an empty string, whereas input[0] would raise.
	return input[:1].upper() + input[1:]
	
def nopFunc(field):
	'''
		Filter predicate that accepts every field; it is the default
		filterFunc of getfieldsWithSep and getfieldValuesWithSep.
	'''
	return True	

def currTime():
	'''
		Return the current local time formatted as 'YYYY-mm-dd HH:MM:SS'.
	'''
	# strftime falls back to the current local time when no time tuple is passed.
	return time.strftime('%Y-%m-%d %H:%M:%S')
	
def getDefaultValueForField(field, numTypeFields):
	'''
	    Return the default value (always a string) used for `field` in the
	    generated wiki rows and java statements.
	    If you want to set a more proper value, do it here.

	    Rules, first match wins:
	      'id' or a name containing '_id'              -> '1'
	      a name containing 'gmt_'                     -> current time
	      a name listed in numTypeFields               -> '0'
	      an underscore-separated part starting or
	      ending with 'ip', or a name containing 'addr' -> '172.16.0.1'
	      a name containing 'cidr'                     -> '172.16.0.0/22'
	      anything else                                -> 'test-' + field
	'''
	if field == 'id' or field.find('_id') != -1:
		return '1'
	if field.find('gmt_') != -1:
		return currTime()
	# Numeric columns are decided before the ip/cidr heuristics: their value
	# is emitted unquoted, so a dotted string would produce invalid java.
	if field in numTypeFields:
		return '0'
	# Match 'ip' only at the edge of a name part; a plain substring test
	# also hit words such as 'description'.
	looksLikeIp = any(part.startswith('ip') or part.endswith('ip') for part in field.split('_'))
	if looksLikeIp or field.find('addr') != -1:
		return '172.16.0.1'
	if field.find('cidr') != -1:
		return '172.16.0.0/22'
	return 'test-'+field
		
def getfieldsWithSep(fieldArray, index=0, sep='|', filterFunc=nopFunc):
	'''
		Join the field names accepted by filterFunc with `sep`, skipping the
		first `index` accepted fields.

		Raises ValueError when index is outside [0, len(fieldArray)].
	'''
	if index < 0 or index > len(fieldArray):
		# %-formatting: the original str + int concatenation raised TypeError
		# before the intended message could be built.
		raise ValueError('index %s invalid: must be in [0,%s]' % (index, len(fieldArray)))
	# list() keeps the slice valid even if the filter helper returns an iterator.
	fieldFilteredArray = list(getFilteredFields(fieldArray, filterFunc))
	return sep.join(fieldFilteredArray[index:])
	
def getfieldValuesWithSep(fieldArray, numTypeFields, index=0, sep='|', filterFunc=nopFunc):
	'''
		Join the default values of the fields accepted by filterFunc with
		`sep`, skipping the first `index` accepted fields, so the value row
		lines up with the header row built by getfieldsWithSep.

		Raises ValueError when index is outside [0, len(fieldArray)].
	'''
	if index < 0 or index > len(fieldArray):
		# %-formatting: the original str + int concatenation raised TypeError
		# before the intended message could be built.
		raise ValueError('index %s invalid: must be in [0,%s]' % (index, len(fieldArray)))
	# list() keeps the slice valid even if the filter helper returns an iterator.
	fieldFilteredArray = list(getFilteredFields(fieldArray, filterFunc))
	# index was validated but never applied before; apply it as the
	# header-row function does.
	fieldDefaultValues = [getDefaultValueForField(field, numTypeFields)
	                      for field in fieldFilteredArray[index:]]
	return sep.join(fieldDefaultValues)
	
def filterTimeAndIdFieldFunc(field):
	'''
		Filter predicate: reject 'id' and any name containing 'gmt_'.
	'''
	return not ('gmt_' in field or field == 'id')

def filterTimeFieldFunc(field):
	'''
		Filter predicate: reject any name containing 'gmt_'.
	'''
	return 'gmt_' not in field
	
def getFilteredFields(fieldArray, filterFunc):
	'''
		Return the fields of fieldArray for which filterFunc(field) is true,
		keeping their order.
	'''
	return [field for field in fieldArray if filterFunc(field)]
	
if __name__ == '__main__':
	# Generate the wiki and java files for every section of the config file.
	allDAOTest = readcfg.getAllDAOTestInfo()
	for daoTestName in allDAOTest:
		gene_daotest(allDAOTest[daoTestName])

     daotest.conf:  配置文件

[VmDAOTest]
DaoTestName=VmDAOTest
TableName=vm
FieldArray=id,gmt_create,gmt_modify,vm_name,cores,mem,disk,status,nc_id,is_deleted
NumTypeFields=id,cores,mem,disk,status,nc_id,is_deleted

[NcDAOTest]
DaoTestName=NcDAOTest
TableName=nc
FieldArray=id,gmt_create,gmt_modify,hostname,ip,avail_cpu, avail_mem, avail_disk
NumTypeFields=id,avail_cpu, avail_mem, avail_disk

     DAO java 文件模板 (即 create_daotest_wiki.py 读取的 TemplateDefaultDAOTest.java):

     

package xxx.dao.regiondb.impl;

import java.util.Date;
import java.util.List;

import org.jtester.unitils.dbfit.DbFit;
import org.testng.Assert;
import org.testng.annotations.Test;
import org.unitils.spring.annotation.SpringBeanByName;

import xxx.BaseRegionDbDAOTestCase;
import xxx.constant.group.Nic;
import xxx.constant.group.Policy;
import xxx.dao.regiondb.XXXDAO;
import xxx.model.db.XXXDO;

/**
 * Generator template for a DbFit-driven DAO test class. The generator
 * substitutes the class-name prefix, the variable-name prefix and the two
 * statement placeholders before writing the final java file, so comments here
 * must not spell out the placeholder tokens.
 */
public class XXXDAOTest extends BaseRegionDbDAOTestCase {

	// NOTE(review): injected by bean name; presumably the field name is the bean name.
	@SpringBeanByName
	private XXXDAO YYYDAO;
	
	/** Inserts one populated row into an empty table. */
	@Test
	@DbFit(when="XXXDAOTest.initBlank.when.wiki", then="XXXDAOTest.queryOneRecord.then.wiki")
	public void testInsertXXXDO() {
		XXXDO YYY = new XXXDO();
		$setFields
		YYY.setGmtCreate(new Date());
		YYY.setGmtModify(new Date());
		YYYDAO.insertXXXDO(YYY);
	}

	/** Counts with an empty example object against the seeded table. */
	@Test
	@DbFit(when="XXXDAOTest.initRecords.when.wiki")
	public void testCountXXXDOByExample() {
		XXXDO YYY = new XXXDO();
		Assert.assertTrue(YYYDAO.countXXXDOByExample(YYY).intValue() == 1);
	}

	/** Loads the seeded row, re-applies the generated values and writes it back. */
	@Test
	@DbFit(when="XXXDAOTest.initRecords.when.wiki", then="XXXDAOTest.testUpdate.then.wiki")
	public void testUpdateXXXDO() {
		// The seeded row has id 1, the key the other primary-key tests use;
		// the call previously had no argument at all.
		XXXDO YYY = YYYDAO.findXXXDOByPrimaryKey(1L);
		$setFields
		YYYDAO.updateXXXDO(YYY);
	}

	/** Queries by example and checks every returned row. */
	@Test
	@DbFit(when="XXXDAOTest.initRecords.when.wiki")
	public void testFindListByExample() {
		List<XXXDO> list = YYYDAO.findListByExample(buildExampleXXXDO());
		Assert.assertEquals(list.size(), 1);
		// The loop variable carries the name the generated assertions refer to,
		// so each returned row is checked rather than the query example.
		for (XXXDO YYY: list) {
			$AssertGetValues
		}
	}

	/** Reads the seeded row by primary key and checks each column. */
	@Test
	@DbFit(when="XXXDAOTest.initRecords.when.wiki")
	public void testFindXXXDOByPrimaryKey() {
		XXXDO YYY = YYYDAO.findXXXDOByPrimaryKey(1L);
		$AssertGetValues
	}

	/** Deletes the seeded row twice: the first call removes it, the second finds nothing. */
	@Test
	@DbFit(when="XXXDAOTest.initRecords.when.wiki")
	public void testDeleteXXXDOByPrimaryKey() {
		Integer count = YYYDAO.deleteXXXDOByPrimaryKey(1L);
		Assert.assertEquals(count.intValue(), 1);
		
		Integer nodelete = YYYDAO.deleteXXXDOByPrimaryKey(1L);
		Assert.assertEquals(nodelete.intValue(), 0);
	}

	/** Builds the example object used as the query argument of the list test. */
	private XXXDO buildExampleXXXDO() {
		XXXDO YYY = new XXXDO();
		$setFields
		return YYY;
	}

}

     运行: $ python create_daotest_wiki.py

     生成以下文件: 

     Jtester+unitils+testng:DAO单元测试文件模板自动生成

      其中: 
      VmDAOTest.initBlank.when.wiki
  

|connect|
|clean table|vm|

     VmDAOTest.initRecords.when.wiki    

|connect|
|clean table|vm|
|insert|vm|
|id|gmt_create|gmt_modify|vm_name|cores|mem|disk|status|nc_id|is_deleted|
|1|2014-05-22 12:51:38|2014-05-22 12:51:38|test-vm_name|0|0|0|0|1|0|

     VmDAOTest.queryOneRecord.then.wiki / VmDAOTest.testUpdate.then.wiki       

|connect|
|query|select vm_name, cores, mem, disk, status, nc_id, is_deleted from vm|
|vm_name|cores|mem|disk|status|nc_id|is_deleted|
|test-vm_name|0|0|0|0|1|0|

      生成的DAOTEST Java 文件: 

package xxx.dao.regiondb.impl;

import java.util.Date;
import java.util.List;

import org.jtester.unitils.dbfit.DbFit;
import org.testng.Assert;
import org.testng.annotations.Test;
import org.unitils.spring.annotation.SpringBeanByName;

import xxx.BaseRegionDbDAOTestCase;
import xxx.constant.group.Nic;
import xxx.constant.group.Policy;
import xxx.dao.regiondb.VmDAO;
import xxx.model.db.VmDO;

public class VmDAOTest extends BaseRegionDbDAOTestCase {

	@SpringBeanByName
	private VmDAO vmDAO;
	
	@Test
	@DbFit(when="VmDAOTest.initBlank.when.wiki", then="VmDAOTest.queryOneRecord.then.wiki")
	public void testInsertVmDO() {
		VmDO vm = new VmDO();
		vm.setId(1);
		vm.setVmName("test-vm_name");
		vm.setCores(0);
		vm.setMem(0);
		vm.setDisk(0);
		vm.setStatus(0);
		vm.setNcId(1);
		vm.setIsDeleted(0);
		
		vm.setGmtCreate(new Date());
		vm.setGmtModify(new Date());
		vmDAO.insertVmDO(vm);
	}

	@Test
	@DbFit(when="VmDAOTest.initRecords.when.wiki")
	public void testCountVmDOByExample() {
		VmDO vm = new VmDO();
		Assert.assertTrue(vmDAO.countVmDOByExample(vm).intValue() == 1);
	}

	@Test
	@DbFit(when="VmDAOTest.initRecords.when.wiki", then="VmDAOTest.testUpdate.then.wiki")
	public void testUpdateVmDO() {
		VmDO vm = vmDAO.findVmDOByPrimaryKey();
		vm.setId(1);
		vm.setVmName("test-vm_name");
		vm.setCores(0);
		vm.setMem(0);
		vm.setDisk(0);
		vm.setStatus(0);
		vm.setNcId(1);
		vm.setIsDeleted(0);
		
		vmDAO.updateVmDO(vm);
	}

	@Test
	@DbFit(when="VmDAOTest.initRecords.when.wiki")
	public void testFindListByExample() {
		VmDO vm = new VmDO();
		vm.setId(1);
		vm.setVmName("test-vm_name");
		vm.setCores(0);
		vm.setMem(0);
		vm.setDisk(0);
		vm.setStatus(0);
		vm.setNcId(1);
		vm.setIsDeleted(0);
		
		List<VmDO> list = vmDAO.findListByExample(vm);
		Assert.assertEquals(list.size(), 1);
		for (VmDO vmDO: list) {
			Assert.assertEquals(vm.getId(), 1)
		Assert.assertEquals(vm.getVmName(), "test-vm_name")
		Assert.assertEquals(vm.getCores(), 0)
		Assert.assertEquals(vm.getMem(), 0)
		Assert.assertEquals(vm.getDisk(), 0)
		Assert.assertEquals(vm.getStatus(), 0)
		Assert.assertEquals(vm.getNcId(), 1)
		Assert.assertEquals(vm.getIsDeleted(), 0)
		
		}
	}

	@Test
	@DbFit(when="VmDAOTest.initRecords.when.wiki")
	public void testFindVmDOByPrimaryKey() {
		VmDO vm = vmDAO.findVmDOByPrimaryKey(1L);
		Assert.assertEquals(vm.getId(), 1)
		Assert.assertEquals(vm.getVmName(), "test-vm_name")
		Assert.assertEquals(vm.getCores(), 0)
		Assert.assertEquals(vm.getMem(), 0)
		Assert.assertEquals(vm.getDisk(), 0)
		Assert.assertEquals(vm.getStatus(), 0)
		Assert.assertEquals(vm.getNcId(), 1)
		Assert.assertEquals(vm.getIsDeleted(), 0)
		
	}

	@Test
	@DbFit(when="VmDAOTest.initRecords.when.wiki")
	public void testDeleteVmDOByPrimaryKey() {
		Integer count = vmDAO.deleteVmDOByPrimaryKey(1L);
		Assert.assertEquals(count.intValue(), 1);
		
		Integer nodelete = vmDAO.deleteVmDOByPrimaryKey(1L);
		Assert.assertEquals(nodelete.intValue(), 0);
	}

}

     结语:

     只要是手工劳动, 尽可能自动化。而要做到自动化, 第一是规范标准化, 第二是要发现一些规律性的模式。 

 

你可能感兴趣的:(TestNG)