Hand-writing a spring-mvc framework with DI (dependency injection) in six steps

 

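Spring MVC is essentially a wrapper around the servlet: with it, adding a new request no longer means registering another servlet by hand in web.xml. The main problem it solves is binding request paths to the controller methods that developers write.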

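I recently watched a series of videos analyzing how spring-mvc works and came away with a rough picture of its working principle and request flow. To deepen that understanding, I wrote my own spring-mvc by hand.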

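The core really comes down to a few points: match the request URL against the value of the custom annotation on each method; use reflection to get the method's parameters; in the dispatch servlet, read the corresponding parameter values out of the HttpServletRequest; call method.invoke(controller instance, arguments) to get a result; and finally write that result into the HttpServletResponse.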
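A stripped-down sketch of that core, with no servlet container involved: a plain Map<String, String> stands in for HttpServletRequest, and the hypothetical @Param annotation plays the role of @WzyRequestParam. It reads each parameter's annotation, converts the raw string to Integer where the parameter type asks for it, and calls method.invoke.

```java
import java.lang.annotation.*;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for @WzyRequestParam: runtime-visible, placed on method parameters
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
@interface Param {
    String value();
}

class DemoController {
    public String hello(@Param("id") Integer id, @Param("name") String name) {
        return "hello " + id + " : " + name;
    }
}

class CoreInvokeSketch {
    // Builds the argument array from raw request parameters, then calls method.invoke
    static Object invoke(Object controller, Method method, Map<String, String> request)
            throws ReflectiveOperationException {
        Parameter[] parameters = method.getParameters();
        Object[] args = new Object[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            Param param = parameters[i].getAnnotation(Param.class);
            if (param == null) {
                continue; // parameters without @Param are passed as null
            }
            String raw = request.get(param.value());
            if (parameters[i].getType() == Integer.class) {
                args[i] = raw == null ? null : Integer.valueOf(raw);
            } else {
                args[i] = raw;
            }
        }
        return method.invoke(controller, args);
    }

    static Object demo() {
        Map<String, String> request = new HashMap<>();
        request.put("id", "7");
        request.put("name", "wzy");
        try {
            Method hello = DemoController.class.getMethod("hello", Integer.class, String.class);
            return invoke(new DemoController(), hello, request);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

CoreInvokeSketch.demo() returns "hello 7 : wzy". The loop walks the parameters by index, so each converted value lands in the slot of the parameter it was read for.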

First, just as when developing with spring-mvc, write the web.xml file, the controller classes, the service classes, and so on.

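web.xml (the entry point) registers our own servlet, so don't configure Spring's DispatcherServlet here. (Be confident and point it straight at your own class; once the servlet is configured, create that class and write it from scratch.)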

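<web-app>
  <display-name>Archetype Created Web Application</display-name>
  <servlet>
      <servlet-name>spring-mvc</servlet-name>
      <servlet-class>cn.wzy.servlet.WzyDispatchServlet</servlet-class>
      <init-param>
          <param-name>contextConfigLocation</param-name>
          <param-value>config.properties</param-value>
      </init-param>
      <load-on-startup>1</load-on-startup>
  </servlet>
    <servlet-mapping>
        <servlet-name>spring-mvc</servlet-name>
        <url-pattern>*.do</url-pattern>
    </servlet-mapping>
</web-app>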
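Because web.xml maps *.do to the servlet, every request URI ending in .do reaches WzyDispatchServlet, which must reduce it to the key it registered the handler under (for example /user/hello.do). The helper below is a hypothetical sketch of that step: it strips only a leading context path, whereas String.replace would also remove a later occurrence of the same text. matchesDoMapping is only a rough approximation of the servlet specification's extension matching.

```java
class HandlerKeySketch {
    // Hypothetical helper: request URI -> key used in handlerMapping.
    // Removes the context path only when it is a prefix of the URI.
    static String lookupKey(String requestUri, String contextPath) {
        if (contextPath != null && !contextPath.isEmpty() && requestUri.startsWith(contextPath)) {
            return requestUri.substring(contextPath.length());
        }
        return requestUri;
    }

    // Rough model of the *.do extension mapping: would this path reach the servlet?
    static boolean matchesDoMapping(String path) {
        return path.endsWith(".do");
    }
}
```

With a context path of /demo, lookupKey("/demo/user/demo/x.do", "/demo") yields /user/demo/x.do, while "/demo/user/demo/x.do".replace("/demo", "") would yield /user/x.do.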
The file passed as contextConfigLocation is deliberately simple: a properties file is much easier to parse than XML.

config.properties is as simple as it gets: it only sets the package to scan, cn.wzy.

package=cn.wzy
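A minimal sketch of how that single entry can be read with java.util.Properties. To stay self-contained it parses a string instead of the classpath stream the servlet opens; ConfigSketch and its methods are illustrative names only.

```java
import java.io.IOException;
import java.io.StringReader;
import java.util.Properties;

class ConfigSketch {
    // Reads the "package" entry from text in the same key=value format as config.properties
    static String scanPackage(String configText) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(configText));
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
        String pkg = props.getProperty("package");
        if (pkg == null || pkg.trim().isEmpty()) {
            throw new IllegalArgumentException("config has no 'package' entry");
        }
        return pkg.trim();
    }

    // cn.wzy -> cn/wzy, the path form a ClassLoader resource lookup expects
    static String toResourcePath(String packageName) {
        return packageName.replace('.', '/');
    }
}
```

Failing fast on a missing entry gives a clearer error than the NullPointerException the scanner would otherwise hit later.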

Controller and service classes (written just as in everyday Spring development; the only difference is that every annotation is our own):

Controller layer:

package cn.wzy.controller;

import cn.wzy.annotation.WzyAutowired;
import cn.wzy.annotation.WzyController;
import cn.wzy.annotation.WzyRequestMapping;
import cn.wzy.annotation.WzyRequestParam;
import cn.wzy.service.UserService;
import com.alibaba.druid.support.json.JSONUtils;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
 * @author wzy "Neither short nor long, exactly eight characters."
 * @since 2018/9/2 14:17
 */
@WzyController
@WzyRequestMapping("/user")
public class UserController {

	@WzyAutowired
	private UserService service;

	@WzyRequestMapping("/hello.do")
	public String hello(@WzyRequestParam("id") Integer id,
	                    @WzyRequestParam("name") String name) throws IOException {
		return service.sayHello(id,name);
	}

	@WzyRequestMapping("/hello2.do")
	public Object hello2(@WzyRequestParam("id") Integer id,
	                    @WzyRequestParam("name") String name) throws IOException {
		Map<String, Object> datas = new HashMap<>();
		datas.put("id",id);
		datas.put("name",name);
		return JSONUtils.toJSONString(datas);
	}

}

Service layer:

package cn.wzy.service;

/**
 * @author wzy "Neither short nor long, exactly eight characters."
 * @since 2018/9/2 14:21
 */
public interface UserService {

	String sayHello(Integer id,String name);
}

Service implementation:

package cn.wzy.service.impl;

import cn.wzy.annotation.WzyService;
import cn.wzy.service.UserService;

/**
 * @author wzy "Neither short nor long, exactly eight characters."
 * @since 2018/9/2 14:21
 */
@WzyService
public class UserServiceImpl implements UserService {
	@Override
	public String sayHello(Integer id, String name) {
		return "hello " + id + " : " + name ;
	}
}
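UserServiceImpl only implements UserService, yet the controller's field is typed as the interface. The servlet later handles this by registering each service under its own bean name and under every interface it implements, with names produced by a lowerFirst helper that adds 32 to the first character. That arithmetic is only right for an upper-case ASCII first letter. This hypothetical sketch compares it with Character.toLowerCase; for readability it keys interfaces by simple name, whereas the servlet uses the fully qualified name (which works as long as injection looks up the same key).

```java
import java.util.ArrayList;
import java.util.List;

interface Greeter {}

class GreeterImpl implements Greeter {}

class BeanNameSketch {
    // Default bean name: simple class name with its first letter lower-cased
    static String defaultBeanName(Class<?> type) {
        String simple = type.getSimpleName();
        return Character.toLowerCase(simple.charAt(0)) + simple.substring(1);
    }

    // The "+= 32" trick, reproduced for comparison: only correct when the first char is 'A'..'Z'
    static String plus32(String name) {
        char[] ch = name.toCharArray();
        ch[0] += 32;
        return new String(ch);
    }

    // Hypothetical: every key a service bean is registered under (its own name, then its interfaces)
    static List<String> registrationKeys(Class<?> serviceClass) {
        List<String> keys = new ArrayList<>();
        keys.add(defaultBeanName(serviceClass));
        for (Class<?> i : serviceClass.getInterfaces()) {
            keys.add(defaultBeanName(i));
        }
        return keys;
    }
}
```

plus32("UserService") gives "userService", but plus32 applied to a name that already starts with a lower-case letter, such as a fully qualified "cn.wzy..." name, produces a non-letter character instead.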

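Next, define the custom annotations. Each one lives in its own file in the cn.wzy.annotation package:

package cn.wzy.annotation;

import java.lang.annotation.*;

/**
 * Marks a field to be filled by dependency injection.
 * @author wzy "Neither short nor long, exactly eight characters."
 * @since 2018/9/2 14:18
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface WzyAutowired {

	String value() default "";

}

package cn.wzy.annotation;

import java.lang.annotation.*;

/**
 * Marks a class as a controller.
 * @author wzy "Neither short nor long, exactly eight characters."
 * @since 2018/9/2 14:18
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface WzyController {

	String value() default "";

}

package cn.wzy.annotation;

import java.lang.annotation.*;

/**
 * Maps a URL to a controller class (prefix) or method.
 * @author wzy "Neither short nor long, exactly eight characters."
 * @since 2018/9/2 14:18
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD,ElementType.TYPE})
public @interface WzyRequestMapping {

	String value() default "";

}

package cn.wzy.annotation;

import java.lang.annotation.*;

/**
 * Binds a request parameter to a method parameter.
 * @author wzy "Neither short nor long, exactly eight characters."
 * @since 2018/9/2 14:18
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface WzyRequestParam {

	String value() default "";

}

package cn.wzy.annotation;

import java.lang.annotation.*;

/**
 * Marks a class as a service bean.
 * @author wzy "Neither short nor long, exactly eight characters."
 * @since 2018/9/2 14:18
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface WzyService {

	String value() default "";

}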
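Every annotation above uses RetentionPolicy.RUNTIME because the servlet discovers them through reflection while the application runs; with the default CLASS retention they are recorded in the class file but are not visible through reflection. The sketch below checks this with two hypothetical field annotations, and also shows that value() default "" yields an empty string when no name is given, which the servlet treats as "derive the bean name".

```java
import java.lang.annotation.*;
import java.lang.reflect.Field;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
@interface RuntimeInject { String value() default ""; }

// No @Retention: defaults to RetentionPolicy.CLASS
@Target(ElementType.FIELD)
@interface ClassRetainedInject { String value() default ""; }

class InjectionTarget {
    @RuntimeInject("userService") Object named;
    @RuntimeInject Object unnamed;
    @ClassRetainedInject Object invisible;
}

class RetentionSketch {
    static Field field(String name) {
        try {
            return InjectionTarget.class.getDeclaredField(name);
        } catch (NoSuchFieldException e) {
            throw new IllegalStateException(e);
        }
    }

    // true only if reflection can see the annotation on the field
    static boolean visible(String fieldName, Class<? extends Annotation> type) {
        return field(fieldName).isAnnotationPresent(type);
    }

    // Empty string means "no explicit name given"
    static String declaredName(String fieldName) {
        return field(fieldName).getAnnotation(RuntimeInject.class).value();
    }
}
```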

Everything above is what you would do when using spring-mvc normally. Next comes the part where we implement the core of spring-mvc ourselves.

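Step 1: load the configuration file, i.e. the config.properties above (it has just one entry, the package to scan);

Step 2: scan all class files under that package;

Step 3: for each class, check whether it should be instantiated (whether it carries one of our annotations);

Step 4: for each instance created, inject objects into its annotated fields;

Step 5: parse the classes annotated as controllers and record each request URL together with its handler method;

Step 6: on each request, resolve the path and invoke the method via reflection.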
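Steps 3 to 6 can be tried without a servlet container or classpath scanning. The sketch below is a hypothetical in-memory version: it receives an already-scanned list of classes, uses its own stand-in annotations (@Ctl, @Svc, @Inject, @Mapping), keys beans by fully qualified class name instead of lowerFirst names, and only dispatches to handlers without parameters.

```java
import java.lang.annotation.*;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;

@Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) @interface Ctl {}
@Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) @interface Svc {}
@Retention(RetentionPolicy.RUNTIME) @Target(ElementType.FIELD) @interface Inject {}
@Retention(RetentionPolicy.RUNTIME) @Target({ElementType.TYPE, ElementType.METHOD}) @interface Mapping { String value(); }

interface GreetService { String greet(String name); }

@Svc class GreetServiceImpl implements GreetService {
    public String greet(String name) { return "hello " + name; }
}

@Ctl @Mapping("/user") class GreetController {
    @Inject GreetService service;

    @Mapping("/hello.do") public String hello() { return service.greet("wzy"); }
}

class MiniContainer {
    final Map<String, Object> ioc = new HashMap<>();
    final List<Object> controllers = new ArrayList<>();
    final Map<String, Method> urlToMethod = new HashMap<>();
    final Map<String, Object> urlToController = new HashMap<>();

    MiniContainer(List<Class<?>> scanned) throws ReflectiveOperationException {
        // Step 3: instantiate annotated classes; services are also keyed by their interfaces
        for (Class<?> c : scanned) {
            boolean isController = c.isAnnotationPresent(Ctl.class);
            if (!isController && !c.isAnnotationPresent(Svc.class)) continue;
            Object bean = c.getDeclaredConstructor().newInstance();
            ioc.put(c.getName(), bean);
            if (isController) {
                controllers.add(bean);
            } else {
                for (Class<?> i : c.getInterfaces()) ioc.put(i.getName(), bean);
            }
        }
        // Step 4: fill @Inject fields with the bean registered under the field's type name
        for (Object bean : ioc.values()) {
            for (Field f : bean.getClass().getDeclaredFields()) {
                if (!f.isAnnotationPresent(Inject.class)) continue;
                f.setAccessible(true);
                f.set(bean, ioc.get(f.getType().getName()));
            }
        }
        // Step 5: URL = class-level mapping + method-level mapping
        for (Object bean : controllers) {
            Mapping base = bean.getClass().getAnnotation(Mapping.class);
            String baseUrl = base == null ? "" : base.value();
            for (Method m : bean.getClass().getDeclaredMethods()) {
                Mapping mapping = m.getAnnotation(Mapping.class);
                if (mapping == null) continue;
                urlToMethod.put(baseUrl + mapping.value(), m);
                urlToController.put(baseUrl + mapping.value(), bean);
            }
        }
    }

    // Step 6 (simplified): look the URL up and invoke the no-argument handler
    Object dispatch(String url) {
        Method m = urlToMethod.get(url);
        if (m == null) return null; // a servlet would answer 404 here
        try {
            return m.invoke(urlToController.get(url));
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    static MiniContainer demo() {
        try {
            return new MiniContainer(Arrays.<Class<?>>asList(
                    GreetController.class, GreetServiceImpl.class, GreetService.class));
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

Injection runs only after every bean exists, which is why steps 3 and 4 are separate passes: the order in which classes were scanned does not matter.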

Core code (the full project is on GitHub: https://github.com/1510460325/springframework):

package cn.wzy.servlet;

import cn.wzy.annotation.*;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.*;

/**
 * @author wzy "Neither short nor long, exactly eight characters."
 * @since 2018/9/2 13:58
 */
public class WzyDispatchServlet extends HttpServlet {

	private Properties properties = new Properties();

	/**
	 * Fully qualified names of all scanned classes
	 */
	private List<String> classNames = new ArrayList<>();

	/**
	 * IoC container: bean name -> instance
	 */
	private Map<String, Object> ioc = new HashMap<>();

	/**
	 * Request URL -> handler
	 */
	private Map<String, Handler> handlerMapping = new HashMap<>();

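	@Override
	public void init(ServletConfig config) throws ServletException {
		// Keep the ServletConfig so getServletConfig()/getServletContext() keep working
		super.init(config);
		//1. Load the configuration file
		loadConfig(config.getInitParameter("contextConfigLocation"));
		//2. Scan the package for class files
		scanner(properties.getProperty("package"));
		//3. Instantiate annotated classes
		instance();
		//4. Inject dependencies
		autowired();
		//5. Build the handlerMapping
		initHandlerMapping();
	}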

	/**
	 * Bind request URLs to controller methods
	 */
	private void initHandlerMapping() {
		for (Object value: ioc.values()) {
			Class<?> clazz = value.getClass();
			if (!clazz.isAnnotationPresent(WzyController.class)) {
				continue;
			}
			// A class-level @WzyRequestMapping supplies the URL prefix, e.g. "/user"
			String baseUrl = "";
			if (clazz.isAnnotationPresent(WzyRequestMapping.class)) {
				WzyRequestMapping url = clazz.getAnnotation(WzyRequestMapping.class);
				if (!url.value().equals("")) {
					baseUrl += url.value();
				}
			}

			Method[] methods = clazz.getDeclaredMethods();
			for (Method method: methods) {
				if (!method.isAnnotationPresent(WzyRequestMapping.class)) {
					continue;
				}
				WzyRequestMapping url = method.getAnnotation(WzyRequestMapping.class);
				Handler handler = new Handler();
				handler.controller = value;
				handler.method = method;
				handler.url = baseUrl + url.value();
				handlerMapping.put(handler.url, handler);
			}
		}
	}

	/**
	 * Dependency injection
	 */
	private void autowired() {
		if (ioc.isEmpty()) {
			return;
		}
		for (Map.Entry<String, Object> entry: ioc.entrySet()) {
			Field[] fields = entry.getValue().getClass().getDeclaredFields();
			for (Field field: fields) {
				if (!field.isAnnotationPresent(WzyAutowired.class)) {
					continue;
				}

				WzyAutowired autowired = field.getAnnotation(WzyAutowired.class);
				String beanName = autowired.value();
				if (beanName.equals("")) {
					// Same key instance() uses when it registers a bean under its interface
					beanName = lowerFirst(field.getType().getName());
				}
				field.setAccessible(true);
				try {
					field.set(entry.getValue(), ioc.get(beanName));
				} catch (IllegalAccessException e) {
					e.printStackTrace();
					System.out.println("===Autowired " + field.getName() + " failed=====");
				}
			}
		}
	}

	/**
	 * Instantiate every class annotated with @WzyController or @WzyService
	 */
	private void instance() {
		if (classNames.isEmpty()) {
			return;
		}
		try {
			for (String className: classNames) {
				Class<?> clazz = Class.forName(className);
				if (clazz.isAnnotationPresent(WzyController.class)) {
					String beanName = lowerFirst(clazz.getSimpleName());
					ioc.put(beanName, clazz.getDeclaredConstructor().newInstance());
				} else if (clazz.isAnnotationPresent(WzyService.class)) {
					WzyService service = clazz.getAnnotation(WzyService.class);
					String beanName = service.value();
					if (beanName.equals("")) {
						beanName = lowerFirst(clazz.getSimpleName());
					}
					Object instance = clazz.getDeclaredConstructor().newInstance();
					ioc.put(beanName, instance);

					// Also register the bean under each interface it implements,
					// so fields typed as the interface can be injected
					for (Class<?> c: clazz.getInterfaces()) {
						ioc.put(lowerFirst(c.getName()), instance);
					}
				}
			}
		} catch (ReflectiveOperationException e) {
			e.printStackTrace();
		}
	}

	private String lowerFirst(String className) {
		// Character.toLowerCase is safe for any first character; adding 32 only works for 'A'..'Z'
		return Character.toLowerCase(className.charAt(0)) + className.substring(1);
	}

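	/**
	 * Recursively scan every class under the package and store its name in classNames for the instantiation step
	 * @param packageName
	 */
	private void scanner(String packageName) {
		// ClassLoader resource names are written without a leading "/"
		URL url = this.getClass().getClassLoader().getResource(packageName.replace(".", "/"));
		if (url == null) {
			throw new IllegalStateException("Package not found on classpath: " + packageName);
		}
		File classDir = new File(url.getFile());
		File[] files = classDir.listFiles();
		if (files == null) {
			return;
		}
		for (File file: files) {
			if (file.isDirectory()) {
				scanner(packageName + "." + file.getName());
			} else if (file.getName().endsWith(".class")) {
				classNames.add(packageName + "." + file.getName().replace(".class", ""));
			}
		}
	}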

	/**
	 * Load the configuration file
	 * @param location
	 */
	private void loadConfig(String location) {
		// try-with-resources closes the stream; a null resource is simply skipped on close
		try (InputStream is = this.getClass().getClassLoader().getResourceAsStream(location)) {
			if (is == null) {
				throw new IllegalStateException("Config file not found on classpath: " + location);
			}
			properties.load(is);
		} catch (IOException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Entry in handlerMapping: the controller instance, its method, and the URL
	 */
	private class Handler{
		protected Object controller;
		protected Method method;
		protected String url;

		public Handler() {
		}
	}

	@Override
	protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
		this.doPost(req, resp);
	}

	@Override
	protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
		doDispatcher(req,resp);
	}

	protected void doDispatcher(HttpServletRequest req, HttpServletResponse resp) throws IOException {
		//Get the request URI without the leading context path
		String url = req.getRequestURI();
		String contextPath = req.getContextPath();
		if (url.startsWith(contextPath)) {
			url = url.substring(contextPath.length());
		}
		//Map the URL to its controller method
		Handler handler = handlerMapping.get(url);
		if (handler == null) {
			resp.sendError(HttpServletResponse.SC_NOT_FOUND, "No handler for " + url);
			return;
		}
		Method targetMethod = handler.method;
		//Bind request parameters to the method's arguments, one parameter position at a time
		Class<?>[] types = targetMethod.getParameterTypes();
		Object[] params = new Object[types.length];
		Annotation[][] annotations = targetMethod.getParameterAnnotations();
		for (int index = 0; index < annotations.length; index++) {
			for (Annotation an: annotations[index]) {
				if (an instanceof WzyRequestParam) {
					String value = req.getParameter(((WzyRequestParam) an).value());
					if (types[index] == Integer.class) {
						params[index] = (value == null || value.isEmpty()) ? null : Integer.valueOf(value);
					} else {
						params[index] = value;
					}
				}
			}
		}
		try {
			Object result = targetMethod.invoke(handler.controller, params);
			resp.setContentType("application/json");
			resp.setCharacterEncoding("utf-8");
			resp.getWriter().write(String.valueOf(result));
		} catch (IllegalAccessException | InvocationTargetException e) {
			e.printStackTrace();
			resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, e.toString());
		}
	}
}
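Step 2 (the scanner method) assumes the compiled classes sit in a directory on disk, as they usually do under WEB-INF/classes when the container unpacks the WAR; it would not find classes packed inside a jar. The hypothetical sketch below isolates that walk, keeps only *.class files, strips just the suffix, and exercises it on a throwaway directory laid out like cn/wzy/controller.

```java
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

class ScannerSketch {
    // Recursively walks a classes directory, turning each *.class file into a fully qualified name
    static void scan(File dir, String packageName, List<String> out) {
        File[] files = dir.listFiles();
        if (files == null) {
            return; // not a directory, or it could not be read
        }
        for (File file : files) {
            String name = file.getName();
            if (file.isDirectory()) {
                scan(file, packageName + "." + name, out);
            } else if (name.endsWith(".class")) {
                out.add(packageName + "." + name.substring(0, name.length() - ".class".length()));
            }
        }
    }

    // Builds a throwaway tree shaped like <classes>/cn/wzy/controller and scans it from cn.wzy
    static List<String> demo() {
        try {
            Path root = Files.createTempDirectory("classes");
            Path controllerDir = Files.createDirectories(root.resolve("cn").resolve("wzy").resolve("controller"));
            Files.createFile(controllerDir.resolve("UserController.class"));
            Files.createFile(controllerDir.resolve("README.txt")); // ignored: not a class file
            List<String> found = new ArrayList<>();
            scan(root.resolve("cn").resolve("wzy").toFile(), "cn.wzy", found);
            return found;
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }
}
```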
