针对 Smart 2.3-SNAPSHOT 版本
想必看过 Smart 源码的朋友,一定阅读过以下这段蛋疼的代码。这是一个简单的类扫描器,可以从指定包名、指定注解、指定父类或接口来扫描相应的类。这恶心的代码究竟是哪位大仙写的呢?哎,惭愧啊,就是本人。
大家可以从代码中看到我当时的心酸、烦恼与无奈,大量的冗余,严重违反了 Don't Repeat Yourself 原则。当时是想能跑起来就行,但如果今天还不重构的话,那就再也找不到任何借口了。
/**
* 类操作工具类
*
* @author huangyong
* @since 1.0
*/
public class ClassUtil {
...
/**
* 获取指定包名下的所有类
*/
public static List<Class<?>> getClassList(String packageName) {
List<Class<?>> classList = new ArrayList<Class<?>>();
try {
Enumeration<URL> urls = getClassLoader().getResources(packageName.replace(".", "/"));
while (urls.hasMoreElements()) {
URL url = urls.nextElement();
if (url != null) {
String protocol = url.getProtocol();
if (protocol.equals("file")) {
String packagePath = url.getPath();
addClass(classList, packagePath, packageName);
} else if (protocol.equals("jar")) {
JarURLConnection jarURLConnection = (JarURLConnection) url.openConnection();
JarFile jarFile = jarURLConnection.getJarFile();
Enumeration<JarEntry> jarEntries = jarFile.entries();
while (jarEntries.hasMoreElements()) {
JarEntry jarEntry = jarEntries.nextElement();
String jarEntryName = jarEntry.getName();
if (jarEntryName.endsWith(".class")) {
String className = jarEntryName.substring(0, jarEntryName.lastIndexOf(".")).replaceAll("/", ".");
if (className.substring(0, className.lastIndexOf(".")).equals(packageName)) {
classList.add(loadClass(className, false));
}
}
}
}
}
}
} catch (Exception e) {
logger.error("获取类出错!", e);
throw new RuntimeException(e);
}
return classList;
}
/**
* 获取指定包名下指定注解的所有类
*/
public static List<Class<?>> getClassListByAnnotation(String packageName, Class<? extends Annotation> annotationClass) {
List<Class<?>> classList = new ArrayList<Class<?>>();
try {
Enumeration<URL> urls = getClassLoader().getResources(packageName.replace(".", "/"));
while (urls.hasMoreElements()) {
URL url = urls.nextElement();
if (url != null) {
String protocol = url.getProtocol();
if (protocol.equals("file")) {
String packagePath = url.getPath();
addClassByAnnotation(classList, packagePath, packageName, annotationClass);
} else if (protocol.equals("jar")) {
JarURLConnection jarURLConnection = (JarURLConnection) url.openConnection();
JarFile jarFile = jarURLConnection.getJarFile();
Enumeration<JarEntry> jarEntries = jarFile.entries();
while (jarEntries.hasMoreElements()) {
JarEntry jarEntry = jarEntries.nextElement();
String jarEntryName = jarEntry.getName();
if (jarEntryName.endsWith(".class")) {
String className = jarEntryName.substring(0, jarEntryName.lastIndexOf(".")).replaceAll("/", ".");
Class<?> cls = loadClass(className, false);
if (cls.isAnnotationPresent(annotationClass)) {
classList.add(cls);
}
}
}
}
}
}
} catch (Exception e) {
logger.error("获取类出错!", e);
throw new RuntimeException(e);
}
return classList;
}
/**
* 获取指定包名下指定父类的所有类
*/
public static List<Class<?>> getClassListBySuper(String packageName, Class<?> superClass) {
List<Class<?>> classList = new ArrayList<Class<?>>();
try {
Enumeration<URL> urls = getClassLoader().getResources(packageName.replace(".", "/"));
while (urls.hasMoreElements()) {
URL url = urls.nextElement();
if (url != null) {
String protocol = url.getProtocol();
if (protocol.equals("file")) {
String packagePath = url.getPath();
addClassBySuper(classList, packagePath, packageName, superClass);
} else if (protocol.equals("jar")) {
JarURLConnection jarURLConnection = (JarURLConnection) url.openConnection();
JarFile jarFile = jarURLConnection.getJarFile();
Enumeration<JarEntry> jarEntries = jarFile.entries();
while (jarEntries.hasMoreElements()) {
JarEntry jarEntry = jarEntries.nextElement();
String jarEntryName = jarEntry.getName();
if (jarEntryName.endsWith(".class")) {
String className = jarEntryName.substring(0, jarEntryName.lastIndexOf(".")).replaceAll("/", ".");
Class<?> cls = loadClass(className, false);
if (superClass.isAssignableFrom(cls) && !superClass.equals(cls)) {
classList.add(cls);
}
}
}
}
}
}
} catch (Exception e) {
logger.error("获取类出错!", e);
throw new RuntimeException(e);
}
return classList;
}
private static void addClass(List<Class<?>> classList, String packagePath, String packageName) {
try {
File[] files = getClassFiles(packagePath);
if (files != null) {
for (File file : files) {
String fileName = file.getName();
if (file.isFile()) {
String className = getClassName(packageName, fileName);
classList.add(loadClass(className, false));
} else {
String subPackagePath = getSubPackagePath(packagePath, fileName);
String subPackageName = getSubPackageName(packageName, fileName);
addClass(classList, subPackagePath, subPackageName);
}
}
}
} catch (Exception e) {
logger.error("添加类出错!", e);
throw new RuntimeException(e);
}
}
private static void addClassByAnnotation(List<Class<?>> classList, String packagePath, String packageName, Class<? extends Annotation> annotationClass) {
try {
File[] files = getClassFiles(packagePath);
if (files != null) {
for (File file : files) {
String fileName = file.getName();
if (file.isFile()) {
String className = getClassName(packageName, fileName);
Class<?> cls = loadClass(className, false);
if (cls.isAnnotationPresent(annotationClass)) {
classList.add(cls);
}
} else {
String subPackagePath = getSubPackagePath(packagePath, fileName);
String subPackageName = getSubPackageName(packageName, fileName);
addClassByAnnotation(classList, subPackagePath, subPackageName, annotationClass);
}
}
}
} catch (Exception e) {
logger.error("添加类出错!", e);
throw new RuntimeException(e);
}
}
private static void addClassBySuper(List<Class<?>> classList, String packagePath, String packageName, Class<?> superClass) {
try {
File[] files = getClassFiles(packagePath);
if (files != null) {
for (File file : files) {
String fileName = file.getName();
if (file.isFile()) {
String className = getClassName(packageName, fileName);
Class<?> cls = loadClass(className, false);
if (superClass.isAssignableFrom(cls) && !superClass.equals(cls)) {
classList.add(cls);
}
} else {
String subPackagePath = getSubPackagePath(packagePath, fileName);
String subPackageName = getSubPackageName(packageName, fileName);
addClassBySuper(classList, subPackagePath, subPackageName, superClass);
}
}
}
} catch (Exception e) {
logger.error("添加类出错!", e);
throw new RuntimeException(e);
}
}
...
}
我勒个去,这么长,看得我自己都寒心啊!大家不用太纠结与细节了,粗略看看外貌也就清楚了。
下面是就是我的重构步骤,精华就在其中,代码值得细看。
我要做的第一件事情就是,抽象出一个接口,没错,这个接口就是 ClassUtil
类中的那三个重要的方法。
/**
* 类扫描器
*
* @author huangyong
* @since 2.3
*/
public interface ClassScanner {
/**
* 获取指定包名中的所有类
*/
List<Class<?>> getClassList(String packageName);
/**
* 获取指定包名中指定注解的相关类
*/
List<Class<?>> getClassListByAnnotation(String packageName, Class<? extends Annotation> annotationClass);
/**
* 获取指定包名中指定父类或接口的相关类
*/
List<Class<?>> getClassListBySuper(String packageName, Class<?> superClass);
}
上一步应该是最简单的,下面需要做的,就是将那些冗余的代码使用 模板方法模式
进行抽象,让抽象父类去实现共性的东西,让其子类去完成个性的事情。
/**
* 用于获取类的模板类
*
* @author huangyong
* @since 2.3
*/
public abstract class ClassTemplate {
private static final Logger logger = LoggerFactory.getLogger(ClassTemplate.class);
protected final String packageName;
protected ClassTemplate(String packageName) {
this.packageName = packageName;
}
public final List<Class<?>> getClassList() {
List<Class<?>> classList = new ArrayList<Class<?>>();
try {
// 从包名获取 URL 类型的资源
Enumeration<URL> urls = ClassUtil.getClassLoader().getResources(packageName.replace(".", "/"));
// 遍历 URL 资源
while (urls.hasMoreElements()) {
URL url = urls.nextElement();
if (url != null) {
// 获取协议名(分为 file 与 jar)
String protocol = url.getProtocol();
if (protocol.equals("file")) {
// 若在 class 目录中,则执行添加类操作
String packagePath = url.getPath();
addClass(classList, packagePath, packageName);
} else if (protocol.equals("jar")) {
// 若在 jar 包中,则解析 jar 包中的 entry
JarURLConnection jarURLConnection = (JarURLConnection) url.openConnection();
JarFile jarFile = jarURLConnection.getJarFile();
Enumeration<JarEntry> jarEntries = jarFile.entries();
while (jarEntries.hasMoreElements()) {
JarEntry jarEntry = jarEntries.nextElement();
String jarEntryName = jarEntry.getName();
// 判断该 entry 是否为 class
if (jarEntryName.endsWith(".class")) {
// 获取类名
String className = jarEntryName.substring(0, jarEntryName.lastIndexOf(".")).replaceAll("/", ".");
// 执行添加类操作
doAddClass(classList, className);
}
}
}
}
}
} catch (Exception e) {
logger.error("获取类出错!", e);
}
return classList;
}
private void addClass(List<Class<?>> classList, String packagePath, String packageName) {
try {
// 获取包名路径下的 class 文件或目录
File[] files = new File(packagePath).listFiles(new FileFilter() {
@Override
public boolean accept(File file) {
return (file.isFile() && file.getName().endsWith(".class")) || file.isDirectory();
}
});
// 遍历文件或目录
for (File file : files) {
String fileName = file.getName();
// 判断是否为文件或目录
if (file.isFile()) {
// 获取类名
String className = fileName.substring(0, fileName.lastIndexOf("."));
if (StringUtil.isNotEmpty(packageName)) {
className = packageName + "." + className;
}
// 执行添加类操作
doAddClass(classList, className);
} else {
// 获取子包
String subPackagePath = fileName;
if (StringUtil.isNotEmpty(packagePath)) {
subPackagePath = packagePath + "/" + subPackagePath;
}
// 子包名
String subPackageName = fileName;
if (StringUtil.isNotEmpty(packageName)) {
subPackageName = packageName + "." + subPackageName;
}
// 递归调用
addClass(classList, subPackagePath, subPackageName);
}
}
} catch (Exception e) {
logger.error("添加类出错!", e);
}
}
private void doAddClass(List<Class<?>> classList, String className) {
// 加载类
Class<?> cls = ClassUtil.loadClass(className, false);
// 判断是否可以添加类
if (checkAddClass(cls)) {
// 添加类
classList.add(cls);
}
}
/**
* 验证是否允许添加类
*/
public abstract boolean checkAddClass(Class<?> cls);
}
以上 ClassTemplate
模板类中,提供了唯一模板方法 —— checkAddClass
方法,每个子类必须重写该方法。
需要注意的是,这些并没有提供任何支持 注解
以及 父类或接口
相关的特性,因为这些特性都在下面的 ClassTemplate
的子类中。
以下是提供 注解
特性的 AnnotationClassTemplate
:
/**
* 用于获取注解类的模板类
*
* @author huangyong
* @since 2.3
*/
public abstract class AnnotationClassTemplate extends ClassTemplate {
protected final Class<? extends Annotation> annotationClass;
protected AnnotationClassTemplate(String packageName, Class<? extends Annotation> annotationClass) {
super(packageName);
this.annotationClass = annotationClass;
}
}
以下是提供 父类或接口
特性的 SupperClassTemplate
:
/**
* 用于获取子类的模板类
*
* @author huangyong
* @since 2.3
*/
public abstract class SupperClassTemplate extends ClassTemplate {
protected final Class<?> superClass;
protected SupperClassTemplate(String packageName, Class<?> superClass) {
super(packageName);
this.superClass = superClass;
}
}
现在的结构看起来是这样的:
这些基础建设已经全部完成了,剩下的最后一件事情就是为 ClassScanner
提供一个实现类了。
实现 ClassScanner
接口是一件非常轻松的事情,因为借助上面的 ClassTemplate
,及其相关子类 AnnotationClassTemplate
与 SupperClassTemplate
就能搞定。
/**
* 默认类扫描器
*
* @author huangyong
* @since 2.3
*/
public class DefaultClassScanner implements ClassScanner {
@Override
public List<Class<?>> getClassList(String packageName) {
return new ClassTemplate(packageName) {
@Override
public boolean checkAddClass(Class<?> cls) {
String className = cls.getName();
String pkgName = className.substring(0, className.lastIndexOf("."));
return pkgName.startsWith(packageName);
}
}.getClassList();
}
@Override
public List<Class<?>> getClassListByAnnotation(String packageName, Class<? extends Annotation> annotationClass) {
return new AnnotationClassTemplate(packageName, annotationClass) {
@Override
public boolean checkAddClass(Class<?> cls) {
return cls.isAnnotationPresent(annotationClass);
}
}.getClassList();
}
@Override
public List<Class<?>> getClassListBySuper(String packageName, Class<?> superClass) {
return new SupperClassTemplate(packageName, superClass) {
@Override
public boolean checkAddClass(Class<?> cls) {
return superClass.isAssignableFrom(cls) && !superClass.equals(cls);
}
}.getClassList();
}
}
针对不同的情况,构造不同的 ClassTemplate
对象,从而编写相应的 checkAddClass
方法(实际上是提供一个添加类的检测条件)。
现在,所有的类都在这里了,他们的依赖关系如下图所示:
经过以上重构后,在代码中只需面向 ClassScanner
接口即可实现相应的类扫描操作。
欢迎下载 Smart 源码:
http://git.oschina.net/huangyong/smart
欢迎阅读 Smart 博文:
http://my.oschina.net/huangyong/blog/158380