这个项目的灵感是来自一款叫做 JMH 的基准测试框架。
是为了更好的回答实现了同样功能的函数,哪些操作更高效,指导我们更好的了解优化我们的程序。
在这个过程中也很好的学习了关于影响到程序尤其是 java 程序执行的一些因素。
整个阶段,用到的是以下知识:
//只存在运行时期
@Retention(RetentionPolicy.RUNTIME)
//方法的注解,被注解的方法就是测试用例
@Target(ElementType.METHOD)
public @interface Benchmark {
}
/**
* 需要测试的次数和组数
*/
@Target({ElementType.METHOD,ElementType.TYPE})
//只存在运行时期
@Retention(RetentionPolicy.RUNTIME)
public @interface Measurement {
//表示每组测试的次数
int iteratations();
//表示一共有多少组测试
int groups();
}
/**
* 预热的注释
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface WarmUp {
//需要进行预热的次数,默认为0
int iterations() default 0;
}
public class CaseLoader {
/**
* 用于自动加载所有的测试用例类
* @return 返回测试用例类集合List
* @throws IOException
* @throws ClassNotFoundException
* @throws IllegalAccessException
* @throws InstantiationException
*/
public CaseRunner load() throws IOException, ClassNotFoundException, IllegalAccessException, InstantiationException {
//包名称
String packName = "arivan.cases";
List classNameList = loadClassName();
//用于存放所有实现了Case接口的类
List caseList = new ArrayList();
for (String className : classNameList) {
Class> cls = Class.forName(packName + "." + className);
//判断当前类cls是否为Case的实现类
if (hasInterface(cls,Case.class)) {
caseList.add((Case)cls.newInstance());
}
}
return new CaseRunner(caseList);
}
/**
* 用于判断当前类cls是否实现了intf接口
* @param cls 当前类cls
* @param intf 接口intf
* @return 若cls实现了接口intf,返回true,否则返回false
*/
private boolean hasInterface(Class> cls, Class> intf) {
Class>[] intfs = cls.getInterfaces();
for (Class> i : intfs) {
if (i == intf) {
return true;
}
}
return false;
}
/**
* 通过固定类来取得所有存放于测试用例包cases下的 *.class文件名集合
* @return 返回List
* @throws IOException
*/
private List loadClassName() throws IOException {
//用于存放所有类名的List集合
List classNameList = new ArrayList();
//1 根据一个固定类找到一个类加载器
String packg = "arivan/cases";
ClassLoader classLoader = this.getClass().getClassLoader();
//2 根据类加载器找到类文件所在的路径
Enumeration urls = classLoader.getResources(packg);
while (urls.hasMoreElements()) {
URL url = urls.nextElement();
if (!url.getProtocol().equals("file")) {
//暂时不支持非 *.class文件
continue;
}
//进行转码,防止出现乱码问题
String dirName = URLDecoder.decode(url.getPath(),"UTF-8");
File dir = new File(dirName);
if (!dir.isDirectory()) {
//如果当前dir不是目录,则跳过
continue;
}
//3 扫描路径所在的所有 *.class文件,作为类文件
File[] files = dir.listFiles();
if (files == null) {
continue;
}
for (File file : files) {
String fileName = file.getName();
//判断后缀是否为.class
if (!fileName.endsWith(".class")) {
continue;
}
//取得类名
String className = fileName.substring(0,fileName.length() - 6);
classNameList.add(className);
}
}
return classNameList;
}
}
public class CaseRunner {
//默认每组运行次数
private static final int DEFAULT_ITERATIONS = 10;
//默认运行组数
private static final int DEFAULT_GROUPS = 5;
//用于记录所有的待测类实例
private final List caseList;
//用于存放每一个测试用例的详细信息(测试用时最大值、最小值、总和、组数、每组次数、测试的方法名)
public List reports = new ArrayList();
public CaseRunner(List list) {
this.caseList = list;
}
/**
* 开始运行测试方法
*/
public void run() throws InvocationTargetException, IllegalAccessException {
for (Case bCase : caseList) {
int iterations = DEFAULT_ITERATIONS;
int groups = DEFAULT_GROUPS;
//先获取类级别的配置注解
Measurement classMeasurement = bCase.getClass().getAnnotation(Measurement.class);
if (classMeasurement != null) {
iterations = classMeasurement.iteratations();
groups = classMeasurement.groups();
}
//找到类对象中需要测试的方法
//获取对象的所有方法
Method[] methods = bCase.getClass().getMethods();
for (Method method : methods) {
Benchmark benchmark = method.getAnnotation(Benchmark.class);
//判断当前方法是否有Benchmark注解
if (benchmark == null) {
continue;
}
Measurement methodMeasureMent = method.getAnnotation(Measurement.class);
if (methodMeasureMent != null) {
iterations = methodMeasureMent.iteratations();
groups = methodMeasureMent.groups();
}
runCase(bCase, method, iterations, groups);
}
}
}
/**
* 此类用于存放预热方法及其方法所在类的实例对象
*/
private class WarmMethod {
Method method; //预热方法
Case aCase; //预热方法所在类的实例对象
int iteratons; //需要预热的次数
public WarmMethod(Method method, Case aCase , int iteratons) {
this.method = method;
this.aCase = aCase;
this.iteratons = iteratons;
}
}
/**
* 取得所有的预热方法
* @return 返回预热方法
*/
private WarmMethod addWarmUpMethods() {
for (Case bcase : caseList) {
Method[] methods = bcase.getClass().getMethods();
for (Method method : methods) {
WarmUp warmUp = method.getAnnotation(WarmUp.class);
if (warmUp != null) {
return new WarmMethod(method,bcase,warmUp.iterations());
}
}
}
return null;
}
/**
* 实际运行方法,用于运行测试用例
* @param bCase 测试方法所在类的实例对象
* @param method 需要测试的方法
* @param iterations 每组测试的次数
* @param groups 一共需要测试的组数
*/
private void runCase(Case bCase, Method method, int iterations, int groups) throws InvocationTargetException, IllegalAccessException {
WarmMethod warmMethod = addWarmUpMethods();
//记录时间耗费的最大值
long maxTime = Long.MIN_VALUE;
//记录时间耗费的最小值
long minTime = Long.MAX_VALUE;
//记录时间总和
long sum = 0;
//记录平均时间
long averageTime = 0;
System.out.println("------------------"+ method.getName() + "测试开始------------------");
for (int i = 1; i <= groups; i++) {
System.out.println("第" + i + "组测试开始:");
//每组测试用例运行之前,运行预热方法
if (warmMethod != null) {
for (int j = 0; j < warmMethod.iteratons; j++) {
warmMethod.method.invoke(warmMethod.aCase);
}
System.out.println(" 预热完成,开始测试!!!");
}
//预热完成后,执行正式测试
for (int j = 1; j <= iterations; j++) {
Object obj = method.invoke(bCase);
long time = (Long)obj;
maxTime = time > maxTime ? time : maxTime;
minTime = time < minTime ? time : minTime;
sum += time;
}
}
sum = sum - maxTime - minTime;
averageTime = sum / (groups * iterations - 2);
Report report = new Report(maxTime,minTime,sum,averageTime,method.getName(),bCase.getClass().getName(),groups,iterations);
this.reports.add(report);
System.out.println("------------------"+ method.getName() + "测试结束------------------");
System.out.println();
}
/**
* 用于记录当前方法测试使用的最长时间,最短时间,总时间,去掉最大最小时间后的平均时间,测试的组数,每组测试的次数
*/
public class Report {
private long maxTime;
private long minTime;
private long sumTime;
private long averageTime;
private String methodName;
private String caseName;
private int groups;
private int iterations;
public Report(long maxTime, long minTime, long sumTime, long averageTime, String methodName, String caseName,
int groups, int iterations) {
this.maxTime = maxTime;
this.minTime = minTime;
this.sumTime = sumTime;
this.averageTime = averageTime;
this.methodName = methodName;
this.caseName = caseName;
this.groups = groups;
this.iterations = iterations;
}
public long getMaxTime() {
return maxTime;
}
public long getMinTime() {
return minTime;
}
public long getSumTime() {
return sumTime;
}
public long getAverageTime() {
return averageTime;
}
public String getMethodName() {
return methodName;
}
public String getCaseName() {
return caseName;
}
public int getGroups() {
return groups;
}
public int getIterations() {
return iterations;
}
}
}
/**
* 测试快速排序和归并排序
*/
@Measurement(iteratations = 10,groups = 5)
public class SortCase {
/**
* 生成一个随机数组
* @param n 随机数组个数
* @param max 数组元素的最大值
* @return 返回生成的数组
*/
private static int[] createArray(int n, int max) {
int[] arr = new int[n];
Random random = new Random(max);
for (int i = 0; i < n; i++) {
arr[i] = random.nextInt(max);
}
return arr;
}
/**
* 预热方法
*/
@WarmUp(iterations = 1) //此处设置预热1次
public static void warmUp() {
int[] arr = createArray(10000,100000);
}
/**
* 测试系统排序用时
* @return
*/
@Benchmark
public static long SystemSort() {
int[] arr = createArray(100000,100000);
long start = System.nanoTime();
Arrays.sort(arr);
long end = System.nanoTime();
System.out.println(" SystemSort排序用时:" + (end - start) + "纳秒");
return end - start;
}
/**
* 测试快速排序
*/
@Benchmark
public static long quickSort() {
int[] arr = createArray(100000,100000);
long start = System.nanoTime();
int n = arr.length;
if (n <= 1) {
return 0;
}
sort(arr,0,arr.length-1);
long end = System.nanoTime();
System.out.println(" 递归法快速排序用时:" + (end - start) + "纳秒");
return end - start;
}
private static void sort(int[] arr, int l, int r) {
if (l >= r) {
return;
}
int mid = partation(arr,l,r);
sort(arr,l,mid-1);
sort(arr,mid+1,r);
}
private static int partation(int[] arr, int l, int r) {
int randomIndex = (int) (Math.random() * (r - l + 1) + l);
swap(arr,l,randomIndex);
int value = arr[l];
int i = l + 1;
int j = l;
int k = r + 1;
while (i < k) {
if (arr[i] < value) {
swap(arr,i++,++j);
} else if (arr[i] > value) {
swap(arr,i,--k);
} else {
i++;
}
}
swap(arr,l,j);
return j;
}
private static void swap(int[] arr, int i, int j) {
int temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}
/**
*归并排序
*/
@Benchmark
public static long mergeSort() {
int[] arr = createArray(100000,100000);
long start = System.nanoTime();
int n = arr.length;
if (n <= 1) {
return 0;
}
partition(arr,0,arr.length-1);
long end = System.nanoTime();
System.out.println(" 归并排序用时:" + (end - start) + "纳秒");
return end - start;
}
private static void partition(int[] arr, int l, int r) {
if (l >= r) {
return;
}
int mid = l + (r - l) / 2;
partition(arr,l,mid);
partition(arr,mid+1,r);
if (arr[mid] > arr[mid+1]) {
merge(arr,l,mid,r);
}
}
private static void merge(int[] arr, int l,int mid, int r) {
int[] array = new int[r-l+1];
int i = l;
int j = mid + 1;
int k = 0;
while (i <= mid && j <= r) {
if (arr[i] <= arr[j]) {
array[k++] = arr[i++];
} else {
array[k++] = arr[j++];
}
}
int start = i;
int end = mid;
if (j <= r) {
start = j;
end = r;
}
while (start <= end) {
array[k++] = arr[start++];
}
for (int m = 0; m < array.length; m++) {
arr[l+m] = array[m];
}
}
}
public class Reoprts {
/**
* 生成给定测试的测试报告
* @param runner 给定测试
*/
public static void report(CaseRunner runner) {
//首先取得记录所有测试用例详细信息的集合
List reports = runner.reports;
//将同一个待测类的所有测试用例放在同一个Map中,便于生成测试报告
Map> map = new HashMap>();
for (CaseRunner.Report report : reports) {
//key代表每一个测试用例所在的类的类名
String key = report.getCaseName();
List value = null;
if (!map.containsKey(key)) {
value = new ArrayList();
value.add(report);
} else {
value = map.get(key);
value.add(report);
}
map.put(key,value);
}
//对不同实例类的待测方法进行分开生成测试报告
for (String key : map.keySet()) {
List list = map.get(key);
//同一个实例类中所有待测方法的测试报告一起生成
System.out.println("-----------------测试报告----------------");
for (CaseRunner.Report report : list) {
System.out.println("测试方法:"+ report.getMethodName());
System.out.println(" 测试分为:" + report.getGroups() + " 组");
System.out.println(" 每组测试:" + report.getIterations() + " 次");
System.out.println(" 使用的总时间:" + report.getSumTime() + " 纳秒");
System.out.println(" 使用最多时间:" + report.getMaxTime() + " 纳秒");
System.out.println(" 使用最少时间:" + report.getMinTime() + " 纳秒");
System.out.println(" 平均使用用时:" + report.getAverageTime() + " 纳秒");
System.out.println();
}
System.out.println("----------------------------------------");
}
}
}
1.对由100000 个 值为0到100000的随机整数组成的数组,分别使用递归法快速排序、归并排序以及系统的Arrays.sort()进行排序,测出这三个方法的性能(此处指排序耗时)差异。每个方式测试五组,每组测试10次。
归并排序第一组测试
Array.sort()第一组测试
递归快排第一组测试
2.两次测试报告
3.结论
经过多次测试,可得出:在100000个随即数据的排序中,系统的Arrays.sort() 是性能最高的,递归快速排序与归并排序的性能相差不大。