两小时手写springmvc框架

这篇文章是学习咕泡学院Tom老师“手写Spring框架”视频后整理而来,代码基本照搬视频中的实现,稍作修改,可以帮助我们理解Spring MVC实现的大致原理。

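1、构建Maven工程,只需要导入javax.servlet-api依赖。另外在<build>中配置jetty-maven-plugin,直接通过jetty插件启动项目(执行 mvn jetty:run)。

<dependencies>
  <dependency>
    <groupId>javax.servlet</groupId>
    <artifactId>javax.servlet-api</artifactId>
    <version>3.0.1</version>
    <scope>provided</scope>
  </dependency>
</dependencies>
<build>
  <finalName>springmvc</finalName>
  <plugins>
    <plugin>
      <groupId>org.eclipse.jetty</groupId>
      <artifactId>jetty-maven-plugin</artifactId>
      <version>9.4.7.v20170914</version>
      <configuration>
        <webApp>
          <contextPath>/${project.build.finalName}</contextPath>
        </webApp>
        <stopKey>CTRL+C</stopKey>
        <stopPort>8999</stopPort>
        <scanIntervalSeconds>10</scanIntervalSeconds>
        <scanTargets>
          <scanTarget>src/main/webapp/WEB-INF/web.xml</scanTarget>
        </scanTargets>
      </configuration>
    </plugin>
  </plugins>
</build>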

项目结构:

两小时手写springmvc框架_第1张图片

2、编写主要的代码并配置web.xml

XXAutowired.java

package com.xxx.springmvc.annotation;

import java.lang.annotation.*;
/**
 * Marks a field that the dispatcher's IoC container should fill in after all beans exist.
 *
 * <p>{@link #value()} names the bean to inject; when it is left empty the dispatcher looks the
 * bean up by the fully qualified name of the field's type.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface XXAutowired {

	/** Name of the bean to inject; empty means "resolve by the field's type name". */
	String value() default "";
}

XXController.java

package com.xxx.springmvc.annotation;

import java.lang.annotation.*;

/**
 * Marks a class as a web controller.
 *
 * <p>The dispatcher instantiates every such class it finds while scanning and registers its
 * {@code XXRequestMapping}-annotated public methods as request handlers. The bean name is
 * derived from the simple class name (first letter lower-cased); {@link #value()} is not read
 * by the dispatcher shown in this article.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface XXController {

	/** Optional controller name (currently unused by the dispatcher). */
	String value() default "";
}

XXRequestMapping.java

package com.xxx.springmvc.annotation;

import java.lang.annotation.*;

/**
 * Maps request URLs to a controller.
 *
 * <p>On a class it supplies a path prefix; on a method it supplies the handler path. The
 * dispatcher concatenates the two, collapses repeated slashes and compiles the result as a
 * regular expression that the context-relative request URI must fully match.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface XXRequestMapping {

	/** Path (or path prefix, on a class), e.g. {@code "/user"}. */
	String value() default "";
}

XXRequestParam.java

package com.xxx.springmvc.annotation;
import java.lang.annotation.*;
/**
 * Names the HTTP request parameter whose value should be passed to the annotated
 * handler-method parameter, e.g. {@code @XXRequestParam("name") String name}.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface XXRequestParam {

	/** Request parameter name to bind. */
	String value() default "";
}

XXService.java

package com.xxx.springmvc.annotation;

import java.lang.annotation.*;

/**
 * Marks a class as a service bean managed by the dispatcher's IoC container.
 *
 * <p>The instance is registered under {@link #value()} (or, when empty, the simple class name
 * with a lower-cased first letter) and additionally under the fully qualified name of every
 * interface the class implements.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface XXService {

	/** Bean name; empty means "derive it from the class name". */
	String value() default "";
}

实体类User.java

package com.xxx.springmvc.entity;

/**
 * Simple user entity used by the demo service.
 *
 * <p>{@link #toString()} renders the user as a JSON object; the demo controller writes that
 * text straight into the HTTP response.
 */
public class User {
	private String id;
	private String username;
	private String password;

	public String getId() {
		return id;
	}

	public void setId(String id) {
		this.id = id;
	}

	public String getUsername() {
		return username;
	}

	public void setUsername(String username) {
		this.username = username;
	}

	public String getPassword() {
		return password;
	}

	public void setPassword(String password) {
		this.password = password;
	}

	/**
	 * Renders this user as a JSON object.
	 *
	 * <p>{@code id} is written as a bare JSON number when it is a plain non-negative integer
	 * (as in the demo data) or {@code null}, and as a quoted string otherwise, so the result is
	 * always valid JSON. {@code username} and {@code password} are always quoted; a {@code null}
	 * value is rendered as the string {@code "null"}. Quotes, backslashes and control characters
	 * are escaped so they cannot break the JSON.
	 *
	 * <p>NOTE(review): the password is part of the output; acceptable for this demo only.
	 */
	@Override
	public String toString() {
		return "{\"id\":" + idToJson(id)
				+ ", \"username\":\"" + escapeJson(String.valueOf(username))
				+ "\", \"password\":\"" + escapeJson(String.valueOf(password)) + "\"}";
	}

	public User() {}

	public User(String id, String username, String password) {
		this.id = id;
		this.username = username;
		this.password = password;
	}

	/** Renders {@code id} as a JSON number when possible, otherwise as an escaped JSON string. */
	private static String idToJson(String id) {
		if (id == null || isJsonInteger(id)) {
			return String.valueOf(id);
		}
		return "\"" + escapeJson(id) + "\"";
	}

	/** True for "0" or a digit run without a leading zero (a valid JSON integer literal). */
	private static boolean isJsonInteger(String s) {
		if (s.isEmpty() || (s.length() > 1 && s.charAt(0) == '0')) {
			return false;
		}
		for (int i = 0; i < s.length(); i++) {
			char c = s.charAt(i);
			if (c < '0' || c > '9') {
				return false;
			}
		}
		return true;
	}

	/** Escapes {@code s} for use inside a JSON string literal (without the surrounding quotes). */
	private static String escapeJson(String s) {
		StringBuilder sb = new StringBuilder(s.length() + 8);
		for (int i = 0; i < s.length(); i++) {
			char c = s.charAt(i);
			switch (c) {
				case '"':
					sb.append("\\\"");
					break;
				case '\\':
					sb.append("\\\\");
					break;
				case '\n':
					sb.append("\\n");
					break;
				case '\r':
					sb.append("\\r");
					break;
				case '\t':
					sb.append("\\t");
					break;
				default:
					if (c < 0x20) {
						sb.append(String.format("\\u%04x", (int) c));
					} else {
						sb.append(c);
					}
			}
		}
		return sb.toString();
	}
}

UserService.java接口文件

package com.xxx.springmvc.web.service;

import java.util.List;

import com.xxx.springmvc.entity.User;

public interface UserService {
	String get(String name);
	List list();
}

UserServiceImpl.java

package com.xxx.springmvc.web.service.impl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import com.xxx.springmvc.annotation.XXService;
import com.xxx.springmvc.entity.User;
import com.xxx.springmvc.web.service.UserService;
/**
 * In-memory {@link UserService} backed by a fixed set of demo users.
 *
 * <p>Registered in the dispatcher's container as {@code "userService"} (and, by the dispatcher,
 * under the interface's fully qualified name as well).
 */
@XXService("userService")
public class UserServiceImpl implements UserService {

	/** Demo data keyed by lookup name; filled once when the class is initialised. */
	private static final Map<String, User> users = new HashMap<String, User>();

	static {
		users.put("aa", new User("1", "aaa", "123456"));
		users.put("bb", new User("2", "bbb", "123456"));
		users.put("cc", new User("3", "ccc", "123456"));
		users.put("dd", new User("4", "ddd", "123456"));
		users.put("ee", new User("5", "eee", "123456"));
	}

	/**
	 * Returns the {@code toString()} form of the user stored under {@code name}.
	 *
	 * <p>Unknown (or {@code null}) names fall back to the user stored under {@code "aa"}.
	 */
	@Override
	public String get(String name) {
		User user = users.get(name);
		if (user == null) {
			user = users.get("aa");
		}
		return user.toString();
	}

	/** Returns a new, caller-owned list of all users, in the map's iteration order. */
	@Override
	public List<User> list() {
		return new ArrayList<User>(users.values());
	}
}

UserController.java

package com.xxx.springmvc.web.controller;
import java.io.IOException;
import java.util.List;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.xxx.springmvc.annotation.XXAutowired;
import com.xxx.springmvc.annotation.XXController;
import com.xxx.springmvc.annotation.XXRequestMapping;
import com.xxx.springmvc.annotation.XXRequestParam;
import com.xxx.springmvc.entity.User;
/**
 * Demo controller mounted under {@code /user}.
 *
 * <p>Each handler writes its JSON payload straight into the response and returns a view name,
 * which the dispatcher in this article discards.
 */
@XXRequestMapping("/user")
@XXController
public class UserController {

	/** Filled in by the dispatcher, which resolves it by the field's type name. */
	@XXAutowired
	private com.xxx.springmvc.web.service.UserService userService;

	/**
	 * Writes the user selected by the {@code name} request parameter.
	 *
	 * @param request  current request (unused)
	 * @param response response receiving the JSON
	 * @param name     value of the {@code name} parameter, {@code null} when it was not sent
	 * @return the view name {@code "index"}
	 * @throws IOException if writing the response fails
	 */
	@XXRequestMapping("/index")
	public String index(HttpServletRequest request, HttpServletResponse response,
			@XXRequestParam("name") String name) throws IOException {
		String json = userService.get(name);
		System.out.println(name + "=>" + json);
		writeJson(response, json);
		return "index";
	}

	/**
	 * Writes every user as a JSON array.
	 *
	 * @param request  current request (unused)
	 * @param response response receiving the JSON
	 * @return the view name {@code "list"}
	 * @throws IOException if writing the response fails
	 */
	@XXRequestMapping("/list")
	public String list(HttpServletRequest request, HttpServletResponse response)
			throws IOException {
		writeJson(response, userService.list().toString());
		return "list";
	}

	/** Sets the UTF-8 JSON content type before obtaining the writer, then writes {@code body}. */
	private static void writeJson(HttpServletResponse response, String body) throws IOException {
		response.setContentType("application/json;charset=UTF-8");
		response.getWriter().write(body);
	}
}

核心类:XXDispatcherServlet.java

package com.xxx.springmvc.servlet;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.xxx.springmvc.annotation.XXAutowired;
import com.xxx.springmvc.annotation.XXController;
import com.xxx.springmvc.annotation.XXRequestMapping;
import com.xxx.springmvc.annotation.XXRequestParam;
import com.xxx.springmvc.annotation.XXService;

public class XXDispatcherServlet extends HttpServlet{
	private Properties contextConfig = new Properties();
	private List classNames = new ArrayList();
	private Map ioc = new HashMap();
	private List handlerMapping = new ArrayList();

	private static final long serialVersionUID = -4943120355864715254L;

	
	@Override
	public void init(ServletConfig config) throws ServletException {
		//load config
		doLoadConfig(config.getInitParameter("contextConfigLocation"));
		//scan relative class
		doScanner(contextConfig.getProperty("scanPackage"));
		//init ioc container put relative class to it
		doInstance();
		//inject dependence
		doAutoWired();
		//init handlerMapping
		initHandlerMapping();
	}

	private void initHandlerMapping() {
		if(ioc.isEmpty())return;
		for(Map.Entry entry:ioc.entrySet()){
			Class clazz = entry.getValue().getClass();
			if(!clazz.isAnnotationPresent(XXController.class)){continue;}
			String baseUrl = "";
			if(clazz.isAnnotationPresent(XXRequestMapping.class)){
				XXRequestMapping requestMapping = clazz.getAnnotation(XXRequestMapping.class);
				baseUrl = requestMapping.value();
			}
			Method[] methods = clazz.getMethods();
			for(Method method:methods){
				if(!method.isAnnotationPresent(XXRequestMapping.class)){continue;}
				XXRequestMapping requestMapping = method.getAnnotation(XXRequestMapping.class);
				String url = (baseUrl+requestMapping.value()).replaceAll("/+", "/");
				Pattern pattern = Pattern.compile(url);
				handlerMapping.add(new Handler(pattern, entry.getValue(), method));
				System.out.println("mapped:"+url+"=>"+method);
			}
		}
	}

	private void doAutoWired() {
		if(ioc.isEmpty())return;
		for(Map.Entry entry:ioc.entrySet()){
			//依赖注入->给加了XXAutowired注解的字段赋值
			Field[] fields = entry.getValue().getClass().getDeclaredFields();
			for(Field field:fields){
				if(!field.isAnnotationPresent(XXAutowired.class)){continue;}
				XXAutowired autowired = field.getAnnotation(XXAutowired.class);
				String beanName = autowired.value();
				if("".equals(beanName)){
					beanName = field.getType().getName();
				}
				field.setAccessible(true);
				try {
					field.set(entry.getValue(), ioc.get(beanName));
				} catch (IllegalAccessException e) {
					e.printStackTrace();
					continue;
				}
			}
		}
	}

	private void doInstance() {
		if(classNames.isEmpty())return;
		try {		
			for(String className:classNames){
				Class clazz = Class.forName(className);
				if(clazz.isAnnotationPresent(XXController.class)){
					String beanName = lowerFirstCase(clazz.getSimpleName());
					ioc.put(beanName, clazz.newInstance());
				}else if(clazz.isAnnotationPresent(XXService.class)){
					
					XXService service = clazz.getAnnotation(XXService.class);
					String beanName = service.value();
					if("".equals(beanName)){
						beanName = lowerFirstCase(clazz.getSimpleName());
					}
					Object instance = clazz.newInstance();
					ioc.put(beanName, instance);
					Class[] interfaces = clazz.getInterfaces();
					for(Class i:interfaces){
						ioc.put(i.getName(), instance);
					}
				}else{
					continue;
				}
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	private void doScanner(String packageName) {
		URL resource = 
this.getClass().getClassLoader().getResource("/"+packageName.replaceAll("\\.", "/"));
	    File classDir = new File(resource.getFile());
	    for(File classFile:classDir.listFiles()){
	    	if(classFile.isDirectory()){
	    		doScanner(packageName+"."+classFile.getName());
	    	}else{
	    		String className = (packageName+"."+classFile.getName()).replace(".class", "");
	    		classNames.add(className);
	    	}
	    }
	}

	private void doLoadConfig(String location) {
		InputStream input = this.getClass().getClassLoader().getResourceAsStream(location);
		try {
			contextConfig.load(input);
		} catch (IOException e) {
			e.printStackTrace();
		}finally{
			if(input!=null){
				try {
					input.close();
				} catch (IOException e) {
					e.printStackTrace();
				}
			}
		}
	}

	@Override
	protected void doGet(HttpServletRequest req, HttpServletResponse res)
			throws ServletException, IOException {
		this.doPost(req, res);
	}

	@Override
	protected void doPost(HttpServletRequest req, HttpServletResponse res)
			throws ServletException, IOException {
		doDispatcher(req, res);
	}
	
	public void doDispatcher(HttpServletRequest req,HttpServletResponse res){
		try {
			Handler handler = getHandler(req);
			if(handler==null){
				res.getWriter().write("404 not found.");
				return;
			}
			Class[] paramTypes = handler.method.getParameterTypes();
			Object[] paramValues = new Object[paramTypes.length];
			Map params = req.getParameterMap();
			for(Entry param:params.entrySet()){
				String value = Arrays.toString(param.getValue()).replaceAll("\\[|\\]", "");
				if(!handler.paramIndexMapping.containsKey(param.getKey())){continue;}
				int index = handler.paramIndexMapping.get(param.getKey());
				paramValues[index] = convert(paramTypes[index],value);
			}
			int reqIndex = handler.paramIndexMapping.get(HttpServletRequest.class.getName());
			paramValues[reqIndex] = req;
			int resIndex = handler.paramIndexMapping.get(HttpServletResponse.class.getName());
			paramValues[resIndex] = res;
			handler.method.invoke(handler.controller, paramValues);
		} catch (Exception e) {
			e.printStackTrace();
		}
		String url = req.getRequestURI();
		String contextPath = req.getContextPath();
		url = url.replace(contextPath, "").replaceAll("/+", "/");
		
	}
	private Object convert(Class type, String value) {
		if(Integer.class == type){
			return Integer.valueOf(value);
		}
		return value;
	}

	private String lowerFirstCase(String str){
		char[] chars = str.toCharArray();
		chars[0] += 32;
		return String.valueOf(chars);
	}
	private Handler getHandler(HttpServletRequest req){
		if(handlerMapping.isEmpty()){return null;}
		String url = req.getRequestURI();
		String contextPath = req.getContextPath();
		url = url.replace(contextPath, "").replaceAll("/+", "/");
		for(Handler handler:handlerMapping){
			Matcher matcher = handler.pattern.matcher(url);
			if(!matcher.matches()){continue;}
			return handler;
		}
		return null;
	}
	private class Handler{
		protected Object controller;
		protected Method method;
		protected Pattern pattern;
		protected Map paramIndexMapping;
		protected Handler(Pattern pattern,Object controller,Method method){
			this.pattern = pattern;
			this.controller = controller;
			this.method = method;
			paramIndexMapping = new HashMap();
			putParamIndexMapping(method);
		}
		private void putParamIndexMapping(Method method) {
			Annotation[][] pa = method.getParameterAnnotations();
			for(int i=0;i[] paramTypes = method.getParameterTypes();
			for(int i=0;i type = paramTypes[i];
				if(type == HttpServletRequest.class || type == HttpServletResponse.class){
					paramIndexMapping.put(type.getName(), i);
				}
			}
		}
	}
}

配置文件config.properties

scanPackage=com.xxx.springmvc.web

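配置文件web.xml

<servlet>
  <servlet-name>springmvc</servlet-name>
  <servlet-class>com.xxx.springmvc.servlet.XXDispatcherServlet</servlet-class>
  <init-param>
    <param-name>contextConfigLocation</param-name>
    <param-value>config.properties</param-value>
  </init-param>
  <load-on-startup>1</load-on-startup>
</servlet>
<servlet-mapping>
  <servlet-name>springmvc</servlet-name>
  <url-pattern>/</url-pattern>
</servlet-mapping>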

3、运行:

控制台打印日志:

两小时手写springmvc框架_第2张图片

访问 http://localhost:8080/springmvc/user/index?name=bb:

两小时手写springmvc框架_第3张图片

访问 http://localhost:8080/springmvc/user/list:

两小时手写springmvc框架_第4张图片

访问 http://localhost:8080/springmvc/user/detail 出现404:

你可能感兴趣的:(java)