Hand-Writing the Spring Source Code: Implementing a Simple Spring Framework

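This article walks you through implementing a simple Spring framework, covering singleton and prototype (non-singleton) beans, dependency injection and lazy loading. The article is updated continuously (this part does not implement dependency injection yet), so if you are interested, stay tuned.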

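Contents

I. Create a Java project

II. Start implementing Spring

1. Create the BeanFactory interface

2. Create the ApplicationContext interface

3. Create an implementation of the ApplicationContext interface

4. Implement the Spring IoC feature

Create the configuration class

Create custom annotations

@Lazy

@Bean

@Scope

@Configuration

@Component

@Repository

@Service

@Controller

@ComponentScan

Create the bean definition

Modify AnnotationConfigApplicationContext

5. Create a singleton bean and a non-singleton bean

6. Use Spring to get bean objects

7. To be continued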


I. Create a Java project

First, create a Java project named spring.


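Once the project has been created, create the three nested packages com, example and spring one level at a time, i.e. the package com.example.spring.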


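II. Start implementing Spring

The most important and most fundamental class in Spring is the Spring container, which creates and manages objects. To make type conversion convenient, the type-based lookup methods of the interface declare a type parameter (generics), so the caller receives a typed result without casting.

1. Create the BeanFactory interface

BeanFactory is the top-level interface of the Spring container. It declares three overloaded methods for getting a bean: by name, by type, and by name and type.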

package com.example.spring;

/**
 * @author heyunlin
 * @version 1.0
 */
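public interface BeanFactory {

    // Get a bean by its name
    Object getBean(String beanName);

    // Get the only bean of the given type; T is inferred from the Class argument
    <T> T getBean(Class<T> type);

    // Get a bean by name and check that it has the given type
    <T> T getBean(String beanName, Class<T> type);
}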
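Why declare the type-based lookups as generic methods? The minimal sketch below uses a hypothetical map-backed class, MapBeanFactory (not part of the project), to show that `<T> T getBean(Class<T> type)` lets the caller receive a typed reference, with `Class.cast` doing a checked runtime conversion instead of an unchecked `(T)` cast. Like the container in this article, it matches types exactly.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

// Hypothetical, minimal container used only to illustrate the generic signatures
class MapBeanFactory {

    private final Map<String, Object> beans = new HashMap<>();

    void register(String beanName, Object bean) {
        beans.put(beanName, Objects.requireNonNull(bean));
    }

    Object getBean(String beanName) {
        Object bean = beans.get(beanName);
        if (bean == null) {
            throw new IllegalStateException("No bean named '" + beanName + "'");
        }
        return bean;
    }

    // T is inferred from the Class<T> argument, so the call site needs no cast
    <T> T getBean(Class<T> type) {
        T match = null;
        for (Object bean : beans.values()) {
            // Exact type match, as in the article's container
            if (bean.getClass().equals(type)) {
                if (match != null) {
                    throw new IllegalStateException("More than one bean of type " + type.getName());
                }
                match = type.cast(bean);
            }
        }
        if (match == null) {
            throw new IllegalStateException("No bean of type " + type.getName());
        }
        return match;
    }

    // Class.cast throws ClassCastException if the bean has another type
    <T> T getBean(String beanName, Class<T> type) {
        return type.cast(getBean(beanName));
    }
}

class TypedLookupDemo {
    public static void main(String[] args) {
        MapBeanFactory factory = new MapBeanFactory();
        factory.register("greeting", "hello");
        factory.register("answer", 42);

        String greeting = factory.getBean(String.class);          // typed, no cast
        Integer answer = factory.getBean("answer", Integer.class); // typed, no cast
        System.out.println(greeting + " " + answer);               // prints "hello 42"
    }
}
```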

2. Create the ApplicationContext interface

The ApplicationContext interface extends BeanFactory.

package com.example.spring;

/**
 * @author heyunlin
 * @version 1.0
 */
public interface ApplicationContext extends BeanFactory {
    
}

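3. Create an implementation of the ApplicationContext interface

Create a class that implements ApplicationContext and all of the abstract methods declared in the interface. For now, each method simply returns null.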

package com.example.spring;

/**
 * @author heyunlin
 * @version 1.0
 */
public class AnnotationConfigApplicationContext implements ApplicationContext {

    @Override
    public Object getBean(String beanName) {
        return null;
    }

    @Override
    public <T> T getBean(Class<T> type) {
        return null;
    }

    @Override
    public <T> T getBean(String beanName, Class<T> type) {
        return null;
    }

}

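4. Implement the Spring IoC feature

First, component scanning needs a base path to scan. It can be specified with the @ComponentScan annotation on the configuration class; if it is not specified, it defaults to the package of the configuration class.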
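A minimal sketch of the two steps behind the scan path, assuming the classes are loaded from a directory on the classpath (not from a JAR): resolve the base package, falling back to the configuration class's package when the @ComponentScan value is empty, then walk the package directory and turn every `.class` file into a fully qualified class name. `ScanPathSketch`, `resolveBasePackage` and `listClassNames` are hypothetical names, not part of the project.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class ScanPathSketch {

    // An empty value means "use the package of the configuration class"
    static String resolveBasePackage(Class<?> configClass, String componentScanValue) {
        if (componentScanValue != null && !componentScanValue.isEmpty()) {
            return componentScanValue;
        }
        String name = configClass.getName();
        int lastDot = name.lastIndexOf('.');
        return lastDot < 0 ? "" : name.substring(0, lastDot);
    }

    // Build class names from the path relative to the package directory, joining
    // the segments with '.' so the result does not depend on the OS path separator
    static List<String> listClassNames(Path packageDir, String basePackage) throws IOException {
        try (Stream<Path> paths = Files.walk(packageDir)) {
            return paths
                    .filter(Files::isRegularFile)
                    .filter(p -> p.getFileName().toString().endsWith(".class"))
                    .map(p -> {
                        StringBuilder name = new StringBuilder(basePackage);
                        for (Path segment : packageDir.relativize(p)) {
                            if (name.length() > 0) {
                                name.append('.');
                            }
                            name.append(segment);
                        }
                        // Drop the ".class" suffix
                        return name.substring(0, name.length() - ".class".length());
                    })
                    .sorted()
                    .collect(Collectors.toList());
        }
    }
}
```

In the container, `packageDir` would be the directory obtained from `ClassLoader.getResource(basePackage.replace('.', '/'))`.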

Create the configuration class

Create a class in the current package that configures the package to scan.

package com.example.spring;

/**
 * @author heyunlin
 * @version 1.0
 */
@ComponentScan("com.example.spring")
public class SpringConfig {
    
}

Create custom annotations

@Lazy
package com.example.spring;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @author heyunlin
 * @version 1.0
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface Lazy {

    // @Lazy on its own means "lazy"; write @Lazy(false) to turn it off explicitly
    boolean value() default true;
}
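Each of these annotations is declared with @Retention(RetentionPolicy.RUNTIME) because the container reads them through reflection while the program runs. The self-contained sketch below uses two hypothetical annotations, RuntimeFlag and ClassFileFlag, to show that only the RUNTIME one is visible to `isAnnotationPresent`/`getAnnotation`, and that an annotation written without arguments reports its declared default value.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

// Kept in the class file and available to reflection at run time
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@interface RuntimeFlag {
    boolean value() default true;
}

// Recorded in the class file but NOT visible to reflection
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.CLASS)
@interface ClassFileFlag {
}

@RuntimeFlag
@ClassFileFlag
class AnnotatedBean {
}

class RetentionDemo {
    public static void main(String[] args) {
        System.out.println(AnnotatedBean.class.isAnnotationPresent(RuntimeFlag.class));   // true
        System.out.println(AnnotatedBean.class.isAnnotationPresent(ClassFileFlag.class)); // false
        // No explicit value was given, so the declared default is returned
        System.out.println(AnnotatedBean.class.getAnnotation(RuntimeFlag.class).value()); // true
    }
}
```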

@Bean 
package com.example.spring;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @author heyunlin
 * @version 1.0
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Bean {

    String value();
}

@Scope
package com.example.spring;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @author heyunlin
 * @version 1.0
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface Scope {

    String value();
}

@Configuration
package com.example.spring;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @author heyunlin
 * @version 1.0
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Component
public @interface Configuration {

    String value() default "";
}

@Component
package com.example.spring;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @author heyunlin
 * @version 1.0
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface Component {

    String value() default "";
}

@Repository
package com.example.spring;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @author heyunlin
 * @version 1.0
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Component
public @interface Repository {

    String value() default "";
}

@Service
package com.example.spring;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @author heyunlin
 * @version 1.0
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Component
public @interface Service {

    String value() default "";
}

@Controller
package com.example.spring;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @author heyunlin
 * @version 1.0
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Component
public @interface Controller {

    String value() default "";
}
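@Configuration, @Repository, @Service and @Controller are themselves annotated with @Component, which makes @Component a meta-annotation. Note that `Class.isAnnotationPresent(Component.class)` only sees annotations placed directly on the class, so a scanner has to look one level further, at the annotations on each annotation type. The sketch below uses hypothetical stand-ins (MyComponent, MyService, OrderService and so on) so that it compiles on its own.

```java
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@interface MyComponent {
    String value() default "";
}

// A stereotype: meta-annotated with @MyComponent, like @Service in the article
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@MyComponent
@interface MyService {
    String value() default "";
}

@MyService
class OrderService {
}

@MyComponent
class OrderRepository {
}

class PlainClass {
}

class MetaAnnotationDemo {

    // True if the class carries @MyComponent directly or through one of its annotations
    static boolean isComponent(Class<?> type) {
        if (type.isAnnotationPresent(MyComponent.class)) {
            return true;
        }
        for (Annotation annotation : type.getAnnotations()) {
            if (annotation.annotationType().isAnnotationPresent(MyComponent.class)) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println(OrderService.class.isAnnotationPresent(MyComponent.class)); // false: direct annotations only
        System.out.println(isComponent(OrderService.class));    // true
        System.out.println(isComponent(OrderRepository.class)); // true
        System.out.println(isComponent(PlainClass.class));      // false
    }
}
```

Note that `isComponent(MyService.class)` is also true, because the annotation type itself carries @MyComponent; a scanner should therefore skip interfaces and annotation types before trying to instantiate anything.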

@ComponentScan
package com.example.spring;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @author heyunlin
 * @version 1.0
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface ComponentScan {
    
    // Package to scan; empty means the package of the configuration class
    String value() default "";
}

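Create the bean definition

BeanDefinition records what the container needs to know in order to create a bean: its type, its scope and whether it is lazily loaded.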

package com.example.spring;

/**
 * @author heyunlin
 * @version 1.0
 */
public class BeanDefinition {

    /**
     * The bean's type
     */
    private Class<?> type;

    /**
     * The bean's scope, e.g. "singleton" or "prototype"
     */
    private String scope;

    /**
     * Whether the bean is lazily loaded
     */
    private boolean lazy;

    public Class<?> getType() {
        return type;
    }

    public void setType(Class<?> type) {
        this.type = type;
    }

    public String getScope() {
        return scope;
    }

    public void setScope(String scope) {
        this.scope = scope;
    }

    public boolean isLazy() {
        return lazy;
    }

    public void setLazy(boolean lazy) {
        this.lazy = lazy;
    }

}
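When the scanner fills in a BeanDefinition, scope and lazy follow two rules: without @Scope a bean is a singleton, and laziness only matters for singletons, because a prototype is created on every request anyway. Here is a minimal sketch of just those rules as a pure function; DefinitionRules, ResolvedDefinition and resolve are hypothetical names, and null stands for "annotation absent".

```java
// Result of resolving @Scope/@Lazy for one class
final class ResolvedDefinition {
    final String scope;
    final boolean lazy;

    ResolvedDefinition(String scope, boolean lazy) {
        this.scope = scope;
        this.lazy = lazy;
    }
}

class DefinitionRules {

    /**
     * @param scopeValue value of @Scope, or null when the class has no @Scope
     * @param lazyValue  value of @Lazy, or null when the class has no @Lazy
     */
    static ResolvedDefinition resolve(String scopeValue, Boolean lazyValue) {
        String scope = scopeValue == null ? "singleton" : scopeValue;
        // Only a singleton can be lazy; other scopes are always created on demand
        boolean lazy = "singleton".equals(scope) && Boolean.TRUE.equals(lazyValue);
        return new ResolvedDefinition(scope, lazy);
    }

    public static void main(String[] args) {
        ResolvedDefinition a = resolve(null, null);        // plain @Component
        ResolvedDefinition b = resolve(null, true);        // @Component @Lazy(true)
        ResolvedDefinition c = resolve("prototype", true); // @Lazy ignored for prototypes
        System.out.println(a.scope + " " + a.lazy); // singleton false
        System.out.println(b.scope + " " + b.lazy); // singleton true
        System.out.println(c.scope + " " + c.lazy); // prototype false
    }
}
```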

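Modify AnnotationConfigApplicationContext

1. Add a field clazz that holds the configuration class passed to the constructor.

2. Create a map inside the Spring container that holds the bean definitions, keyed by bean name.

3. Create the singleton pool, also a map, that holds the singleton beans.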

package com.example.spring;

import java.io.File;
import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * @author heyunlin
 * @version 1.0
 */
public class AnnotationConfigApplicationContext implements ApplicationContext {

    /**
     * Bean definitions, keyed by bean name
     */
    private final Map<String, BeanDefinition> beanDefinitionMap = new HashMap<>();

    /**
     * Singleton pool
     */
    private final Map<String, Object> singletonObjects = new HashMap<>();

    /**
     * The configuration class passed to the constructor
     */
    public final Class<?> clazz;

    public AnnotationConfigApplicationContext(Class<?> clazz) throws ClassNotFoundException {
        this.clazz = clazz;

        // Scan the components and save them as BeanDefinitions
        scan(clazz);

        // Eagerly create every non-lazy singleton and put it into the singleton pool
        for (Map.Entry<String, BeanDefinition> entry : beanDefinitionMap.entrySet()) {
            String beanName = entry.getKey();
            BeanDefinition beanDefinition = entry.getValue();

            if (isSingleton(beanDefinition.getScope()) && !beanDefinition.isLazy()) {
                Object bean = createBean(beanDefinition);

                singletonObjects.put(beanName, bean);
            }
        }
    }

    /**
     * Create a bean object
     * @param beanDefinition the bean definition
     * @return Object the created bean
     */
    private Object createBean(BeanDefinition beanDefinition) {
        Class<?> beanType = beanDefinition.getType();

        /*
         * Constructor inference (planned, not implemented yet):
         * 1. No constructor declared: call the default no-arg constructor
         * 2. Constructors declared:
         *   - exactly one constructor
         *     - 0 parameters: no-arg constructor
         *     - n parameters: pass n empty (null) arguments
         *   - more than one constructor: inference fails, throw an exception
         * For now only the public no-arg constructor is used.
         */
        try {
            Constructor<?> constructor = beanType.getConstructor();

            return constructor.newInstance();
        } catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
            // Fail fast instead of silently putting null into the container
            throw new IllegalStateException("Failed to create bean of type " + beanType.getName(), e);
        }
    }

    private boolean isSingleton(String scope) {
        return "singleton".equals(scope);
    }

    private void scan(Class<?> clazz) throws ClassNotFoundException {
        // Use the value of @ComponentScan; fall back to the package of the configuration class
        String basePackage = "";

        if (clazz.isAnnotationPresent(ComponentScan.class)) {
            basePackage = clazz.getAnnotation(ComponentScan.class).value();
        }

        if ("".equals(basePackage)) {
            basePackage = clazz.getPackage().getName();
        }

        String path = basePackage.replace(".", "/");
        URL resource = clazz.getClassLoader().getResource(path);

        if (resource == null) {
            throw new IllegalStateException("Scan path not found: " + basePackage);
        }

        try {
            // toURI() decodes characters such as %20 that URL.getFile() would leave escaped
            loopFor(new File(resource.toURI()), basePackage);
        } catch (URISyntaxException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Recursively walk a package directory
     * @param file the directory of the package
     * @param packageName the package that the directory corresponds to
     */
    private void loopFor(File file, String packageName) throws ClassNotFoundException {
        File[] children = file.listFiles();

        if (children == null) {
            return;
        }

        for (File child : children) {
            if (child.isDirectory()) {
                // Descend into the sub-package
                loopFor(child, packageName + "." + child.getName());
            } else {
                toBeanDefinitionMap(child, packageName);
            }
        }
    }

    private void toBeanDefinitionMap(File file, String packageName) throws ClassNotFoundException {
        String fileName = file.getName();

        // Only compiled classes are candidates; skip resources and other files
        if (!fileName.endsWith(".class")) {
            return;
        }

        // Build the fully qualified class name from the package and the file name,
        // which works on every OS regardless of the path separator
        String className = packageName + "." + fileName.substring(0, fileName.length() - ".class".length());
        Class<?> loadClass = clazz.getClassLoader().loadClass(className);

        // Interfaces and annotation types (e.g. @Service itself) cannot be instantiated
        if (loadClass.isInterface()) {
            return;
        }

        String beanName = findComponentName(loadClass);

        // Not a component
        if (beanName == null) {
            return;
        }

        if ("".equals(beanName)) {
            beanName = getBeanName(loadClass);
        }

        // Without @Scope the bean is a singleton
        String scope = "singleton";

        if (loadClass.isAnnotationPresent(Scope.class)) {
            scope = loadClass.getAnnotation(Scope.class).value();
        }

        // @Lazy only applies to singletons; non-singletons are always created on demand
        boolean lazy = isSingleton(scope)
                && loadClass.isAnnotationPresent(Lazy.class)
                && loadClass.getAnnotation(Lazy.class).value();

        // Save the bean definition
        BeanDefinition beanDefinition = new BeanDefinition();

        beanDefinition.setType(loadClass);
        beanDefinition.setLazy(lazy);
        beanDefinition.setScope(scope);

        beanDefinitionMap.put(beanName, beanDefinition);
    }

    /**
     * Find the bean name declared by @Component, or by an annotation that is itself
     * annotated with @Component (@Service, @Repository, @Controller, @Configuration)
     * @param loadClass the scanned class
     * @return String the declared name, "" if no name was given, null if the class is not a component
     */
    private String findComponentName(Class<?> loadClass) {
        if (loadClass.isAnnotationPresent(Component.class)) {
            return loadClass.getAnnotation(Component.class).value();
        }

        for (Annotation annotation : loadClass.getAnnotations()) {
            if (annotation.annotationType().isAnnotationPresent(Component.class)) {
                try {
                    // Every stereotype annotation in this project declares String value()
                    return (String) annotation.annotationType().getMethod("value").invoke(annotation);
                } catch (ReflectiveOperationException e) {
                    return "";
                }
            }
        }

        return null;
    }

    /**
     * Derive the bean name from the class name:
     * lower-case the first letter (UserService -> userService),
     * but keep the class name when its first two letters are both upper case (URLService)
     * @param loadClass the bean's Class object
     * @return String beanName
     */
    private String getBeanName(Class<?> loadClass) {
        String simpleName = loadClass.getSimpleName();

        if (simpleName.length() > 1
                && Character.isUpperCase(simpleName.charAt(0))
                && Character.isUpperCase(simpleName.charAt(1))) {
            return simpleName;
        }

        return Character.toLowerCase(simpleName.charAt(0)) + simpleName.substring(1);
    }

    @Override
    public Object getBean(String beanName) {
        BeanDefinition beanDefinition = beanDefinitionMap.get(beanName);

        if (beanDefinition == null) {
            throw new IllegalStateException("No bean named '" + beanName + "'");
        }

        return getBean(beanName, beanDefinition);
    }

    @Override
    public <T> T getBean(Class<T> type) {
        if (type == null) {
            throw new IllegalStateException("Bean type must not be null!");
        }

        // Collect the names of all beans whose type is exactly the requested type
        List<String> beanNames = new ArrayList<>();

        for (Map.Entry<String, BeanDefinition> entry : beanDefinitionMap.entrySet()) {
            if (entry.getValue().getType().equals(type)) {
                beanNames.add(entry.getKey());
            }
        }

        // Exactly one matching bean is required
        if (beanNames.size() != 1) {
            throw new IllegalStateException("Expected exactly one bean of type " + type.getName()
                    + " but found " + beanNames.size());
        }

        String beanName = beanNames.get(0);

        return type.cast(getBean(beanName, beanDefinitionMap.get(beanName)));
    }

    @Override
    public <T> T getBean(String beanName, Class<T> type) {
        if (type == null) {
            throw new IllegalStateException("Bean type must not be null!");
        }

        BeanDefinition beanDefinition = beanDefinitionMap.get(beanName);

        if (beanDefinition == null) {
            throw new IllegalStateException("No bean named '" + beanName + "'");
        }

        if (!type.equals(beanDefinition.getType())) {
            throw new IllegalStateException("Bean '" + beanName + "' is not of type " + type.getName());
        }

        return type.cast(getBean(beanName, beanDefinition));
    }

    /**
     * Common lookup used by all getBean methods
     * @param beanName the bean name
     * @param beanDefinition the bean definition
     * @return Object the bean
     */
    private Object getBean(String beanName, BeanDefinition beanDefinition) {
        if (isSingleton(beanDefinition.getScope())) {
            Object bean = singletonObjects.get(beanName);

            // A lazy singleton is created on first access and then cached
            if (bean == null) {
                bean = createBean(beanDefinition);

                singletonObjects.put(beanName, bean);
            }

            return bean;
        }

        // Non-singleton (e.g. prototype): a new instance on every call
        return createBean(beanDefinition);
    }

}
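The comment in createBean lists constructor-inference rules that are planned but not implemented yet. A minimal sketch of those rules, assuming that "pass empty arguments" means null for reference parameters and the zero value for primitive parameters (null cannot be passed to a primitive parameter). ConstructorInference and the sample classes are hypothetical, not part of the project.

```java
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;

class ConstructorInference {

    static Object instantiate(Class<?> type) {
        Constructor<?>[] constructors = type.getDeclaredConstructors();

        // Rule 3: more than one constructor, inference fails
        if (constructors.length != 1) {
            throw new IllegalStateException("Cannot choose a constructor for " + type.getName()
                    + ": found " + constructors.length);
        }

        // Rules 1 and 2: a class without an explicit constructor still has the implicit
        // no-arg one, so in both cases there is exactly one constructor here
        Constructor<?> constructor = constructors[0];
        Class<?>[] parameterTypes = constructor.getParameterTypes();
        Object[] args = new Object[parameterTypes.length];
        for (int i = 0; i < parameterTypes.length; i++) {
            args[i] = emptyValue(parameterTypes[i]);
        }

        try {
            constructor.setAccessible(true);
            return constructor.newInstance(args);
        } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
            throw new IllegalStateException("Failed to instantiate " + type.getName(), e);
        }
    }

    // null for reference types, the zero value for primitives
    private static Object emptyValue(Class<?> type) {
        if (!type.isPrimitive()) return null;
        if (type == boolean.class) return false;
        if (type == char.class) return '\0';
        if (type == byte.class) return (byte) 0;
        if (type == short.class) return (short) 0;
        if (type == int.class) return 0;
        if (type == long.class) return 0L;
        if (type == float.class) return 0f;
        return 0d; // double
    }
}

class NoConstructor {
}

class WithArgs {
    final String name;
    final int count;

    WithArgs(String name, int count) {
        this.name = name;
        this.count = count;
    }
}

class TwoConstructors {
    TwoConstructors() {
    }

    TwoConstructors(String s) {
    }
}

class ConstructorInferenceDemo {
    public static void main(String[] args) {
        System.out.println(ConstructorInference.instantiate(NoConstructor.class) instanceof NoConstructor); // true
        WithArgs w = (WithArgs) ConstructorInference.instantiate(WithArgs.class);
        System.out.println(w.name + " " + w.count); // null 0
        try {
            ConstructorInference.instantiate(TwoConstructors.class);
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```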

5. Create a singleton bean and a non-singleton bean

Create a singleton bean, UserService

package com.example.spring;

/**
 * @author heyunlin
 * @version 1.0
 */
@Component
public class UserService {

}
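Because @Component on UserService has no value, the container derives the bean name from the class name, which is why the example later looks it up as "userService". A minimal sketch of that rule, following the JavaBeans convention that `java.beans.Introspector.decapitalize` also implements: lower-case the first letter, unless the first two letters are both upper case. BeanNames is a hypothetical helper class.

```java
class BeanNames {

    // UserService -> userService; URLService stays URLService
    static String defaultBeanName(Class<?> type) {
        return decapitalize(type.getSimpleName());
    }

    static String decapitalize(String name) {
        if (name == null || name.isEmpty()) {
            return name;
        }
        // Keep names whose first two letters are both upper case, e.g. acronyms
        if (name.length() > 1 && Character.isUpperCase(name.charAt(0))
                && Character.isUpperCase(name.charAt(1))) {
            return name;
        }
        return Character.toLowerCase(name.charAt(0)) + name.substring(1);
    }

    public static void main(String[] args) {
        System.out.println(decapitalize("UserService")); // userService
        System.out.println(decapitalize("UserMapper"));  // userMapper
        System.out.println(decapitalize("URLService"));  // URLService
    }
}
```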

Create a non-singleton (prototype) bean, UserMapper

package com.example.spring;

/**
 * @author heyunlin
 * @version 1.0
 */
@Component
@Scope("prototype")
public class UserMapper {

}

6. Use Spring to get bean objects

package com.example.spring;

/**
 * @author heyunlin
 * @version 1.0
 */
public class SpringExample {

    public static void main(String[] args) throws ClassNotFoundException {
        AnnotationConfigApplicationContext applicationContext = new AnnotationConfigApplicationContext(SpringConfig.class);


        Object userService1 = applicationContext.getBean(UserService.class);
        Object userService2 = applicationContext.getBean(UserService.class);
        Object userService3 = applicationContext.getBean("userService");
        Object userService4 = applicationContext.getBean("userService");
        Object userService5 = applicationContext.getBean("userService", UserService.class);
        Object userService6 = applicationContext.getBean("userService", UserService.class);

        System.out.println(userService1);
        System.out.println(userService2);
        System.out.println(userService3);
        System.out.println(userService4);
        System.out.println(userService5);
        System.out.println(userService6);

        System.out.println("----------------------------------------------------");

        Object userMapper1 = applicationContext.getBean(UserMapper.class);
        Object userMapper2 = applicationContext.getBean(UserMapper.class);
        Object userMapper3 = applicationContext.getBean("userMapper");
        Object userMapper4 = applicationContext.getBean("userMapper");
        Object userMapper5 = applicationContext.getBean("userMapper", UserMapper.class);
        Object userMapper6 = applicationContext.getBean("userMapper", UserMapper.class);

        System.out.println(userMapper1);
        System.out.println(userMapper2);
        System.out.println(userMapper3);
        System.out.println(userMapper4);
        System.out.println(userMapper5);
        System.out.println(userMapper6);
    }

}

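All three getBean methods return the same UserService object, so the six userService lines print the same instance: UserService is a singleton and is taken from the singleton pool. UserMapper is declared with @Scope("prototype"), so every lookup creates a new instance and the six userMapper lines refer to six different objects.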
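The difference comes from the lookup logic: a singleton is cached in the singleton pool once it has been created, while a prototype is built anew on every call. A stripped-down sketch of that logic, using a hypothetical ScopedRegistry that takes a `Supplier` instead of a BeanDefinition so it compiles on its own; identity is compared with `==`.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

class ScopedRegistry {

    private final Map<String, Supplier<?>> factories = new HashMap<>();
    private final Map<String, Boolean> singletonFlags = new HashMap<>();
    private final Map<String, Object> singletonObjects = new HashMap<>();

    void register(String beanName, boolean singleton, Supplier<?> factory) {
        factories.put(beanName, factory);
        singletonFlags.put(beanName, singleton);
    }

    Object getBean(String beanName) {
        Supplier<?> factory = factories.get(beanName);
        if (factory == null) {
            throw new IllegalStateException("No bean named '" + beanName + "'");
        }
        if (singletonFlags.get(beanName)) {
            // Create once, then always return the cached instance
            Object bean = singletonObjects.get(beanName);
            if (bean == null) {
                bean = factory.get();
                singletonObjects.put(beanName, bean);
            }
            return bean;
        }
        // Prototype: a fresh instance on every call
        return factory.get();
    }

    public static void main(String[] args) {
        ScopedRegistry registry = new ScopedRegistry();
        registry.register("userService", true, Object::new);
        registry.register("userMapper", false, Object::new);

        System.out.println(registry.getBean("userService") == registry.getBean("userService")); // true
        System.out.println(registry.getBean("userMapper") == registry.getBean("userMapper"));   // false
    }
}
```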

7. To be continued

The article and the code are being updated continuously. Stay tuned!

Source code of the hand-written Spring Framework: https://gitee.com/he-yunlin/spring.git
