Spring Boot 之 Web Filter:支持排序的扩展
为了spring boot支持注解@WebFilter("/*")的web filter组件排序,我们需要对其扩展。
本博客对 web filter 的排序既支持注解 @Order(例如 @Order(Integer.MAX_VALUE)),也支持 Spring 的 Ordered 接口。跟踪源码可知,web filter 的注册主要由 ServletContextInitializerBeans 实现,但由于 Spring 对其中某些方法做了访问权限控制,我只能复制一份源码,修改方法的访问权限,并利用反射技术实现。
package com.sdcuike.spring.extend.web;
import java.lang.reflect.Field;
import java.util.AbstractCollection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.EventListener;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import javax.servlet.Filter;
import javax.servlet.MultipartConfigElement;
import javax.servlet.Servlet;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.aop.scope.ScopedProxyUtils;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.boot.web.servlet.DelegatingFilterProxyRegistrationBean;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.boot.web.servlet.RegistrationBean;
import org.springframework.boot.web.servlet.ServletContextInitializer;
import org.springframework.boot.web.servlet.ServletContextInitializerBeans;
import org.springframework.boot.web.servlet.ServletListenerRegistrationBean;
import org.springframework.boot.web.servlet.ServletRegistrationBean;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* copy自{@link ServletContextInitializerBeans},主要解决了反射解决spring的小bug及访问控制权限,以利于自定义扩展
*
* @author sdcuike
*
* Created on 2017.02.13
*
*/
public class ServletContextInitializerBeansModify extends AbstractCollection {
private static final String DISPATCHER_SERVLET_NAME = "dispatcherServlet";
protected static final Log logger = LogFactory
.getLog(ServletContextInitializerBeansModify.class);
/**
* Seen bean instances or bean names.
*/
protected final Set seen = new HashSet();
protected final MultiValueMap, ServletContextInitializer> initializers;
private List sortedList;
public ServletContextInitializerBeansModify(ListableBeanFactory beanFactory) {
this.initializers = new LinkedMultiValueMap, ServletContextInitializer>();
addServletContextInitializerBeans(beanFactory);
addAdaptableBeans(beanFactory);
List sortedInitializers = new ArrayList();
for (Map.Entry, List> entry : this.initializers
.entrySet()) {
AnnotationAwareOrderComparator.sort(entry.getValue());
sortedInitializers.addAll(entry.getValue());
}
this.sortedList = Collections.unmodifiableList(sortedInitializers);
}
private void addServletContextInitializerBeans(ListableBeanFactory beanFactory) {
for (Entry initializerBean : getOrderedBeansOfType(
beanFactory, ServletContextInitializer.class)) {
addServletContextInitializerBean(initializerBean.getKey(),
initializerBean.getValue(), beanFactory);
}
}
private void addServletContextInitializerBean(String beanName,
ServletContextInitializer initializer, ListableBeanFactory beanFactory) {
if (initializer instanceof ServletRegistrationBean) {
Servlet source = getServlet((ServletRegistrationBean) initializer);
addServletContextInitializerBean(Servlet.class, beanName, initializer,
beanFactory, source);
} else if (initializer instanceof FilterRegistrationBean) {
Filter source = ((FilterRegistrationBean) initializer).getFilter();
addServletContextInitializerBean(Filter.class, beanName, initializer,
beanFactory, source);
} else if (initializer instanceof DelegatingFilterProxyRegistrationBean) {
String source = getTargetFilterName((DelegatingFilterProxyRegistrationBean) initializer);
addServletContextInitializerBean(Filter.class, beanName, initializer,
beanFactory, source);
} else if (initializer instanceof ServletListenerRegistrationBean) {
EventListener source = ((ServletListenerRegistrationBean>) initializer)
.getListener();
addServletContextInitializerBean(EventListener.class, beanName, initializer,
beanFactory, source);
} else {
addServletContextInitializerBean(ServletContextInitializer.class, beanName,
initializer, beanFactory, initializer);
}
}
protected void addServletContextInitializerBean(Class> type, String beanName,
ServletContextInitializer initializer, ListableBeanFactory beanFactory,
Object source) {
this.initializers.add(type, initializer);
if (source != null) {
// Mark the underlying source as seen in case it wraps an existing bean
this.seen.add(source);
}
if (ServletContextInitializerBeansModify.logger.isDebugEnabled()) {
String resourceDescription = getResourceDescription(beanName, beanFactory);
int order = getOrder(initializer);
ServletContextInitializerBeansModify.logger.debug("Added existing "
+ type.getSimpleName() + " initializer bean '" + beanName
+ "'; order=" + order + ", resource=" + resourceDescription);
}
}
protected final String getResourceDescription(String beanName,
ListableBeanFactory beanFactory) {
if (beanFactory instanceof BeanDefinitionRegistry) {
BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory;
return registry.getBeanDefinition(beanName).getResourceDescription();
}
return "unknown";
}
@SuppressWarnings("unchecked")
private void addAdaptableBeans(ListableBeanFactory beanFactory) {
MultipartConfigElement multipartConfig = getMultipartConfig(beanFactory);
addAsRegistrationBean(beanFactory, Servlet.class,
new ServletRegistrationBeanAdapter(multipartConfig));
addAsRegistrationBean(beanFactory, Filter.class,
new FilterRegistrationBeanAdapter());
for (Class> listenerType : ServletListenerRegistrationBean
.getSupportedTypes()) {
addAsRegistrationBean(beanFactory, EventListener.class,
(Class) listenerType,
new ServletListenerRegistrationBeanAdapter());
}
}
private MultipartConfigElement getMultipartConfig(ListableBeanFactory beanFactory) {
List> beans = getOrderedBeansOfType(
beanFactory, MultipartConfigElement.class);
return (beans.isEmpty() ? null : beans.get(0).getValue());
}
private void addAsRegistrationBean(ListableBeanFactory beanFactory, Class type,
RegistrationBeanAdapter adapter) {
addAsRegistrationBean(beanFactory, type, type, adapter);
}
private void addAsRegistrationBean(ListableBeanFactory beanFactory,
Class type, Class beanType, RegistrationBeanAdapter adapter) {
List> beans = getOrderedBeansOfType(beanFactory, beanType,
this.seen);
for (Entry bean : beans) {
if (this.seen.add(bean.getValue())) {
int order = getOrder(bean.getValue());
String beanName = bean.getKey();
// One that we haven't already seen
RegistrationBean registration = adapter.createRegistrationBean(beanName,
bean.getValue(), beans.size());
registration.setName(beanName);
registration.setOrder(order);
this.initializers.add(type, registration);
if (ServletContextInitializerBeansModify.logger.isDebugEnabled()) {
ServletContextInitializerBeansModify.logger.debug(
"Created " + type.getSimpleName() + " initializer for bean '"
+ beanName + "'; order=" + order + ", resource="
+ getResourceDescription(beanName, beanFactory));
}
}
}
}
private final int getOrder(Object value) {
return new AnnotationAwareOrderComparator() {
@Override
public int getOrder(Object obj) {
return super.getOrder(obj);
}
}.getOrder(value);
}
private List> getOrderedBeansOfType(
ListableBeanFactory beanFactory, Class type) {
return getOrderedBeansOfType(beanFactory, type, Collections.emptySet());
}
private List> getOrderedBeansOfType(
ListableBeanFactory beanFactory, Class type, Set> excludes) {
List> beans = new ArrayList>();
Comparator> comparator = new Comparator>() {
@Override
public int compare(Entry o1, Entry o2) {
return AnnotationAwareOrderComparator.INSTANCE.compare(o1.getValue(),
o2.getValue());
}
};
String[] names = beanFactory.getBeanNamesForType(type, true, false);
Map map = new LinkedHashMap();
for (String name : names) {
if (!excludes.contains(name) && !ScopedProxyUtils.isScopedTarget(name)) {
T bean = beanFactory.getBean(name, type);
if (!excludes.contains(bean)) {
map.put(name, bean);
}
}
}
beans.addAll(map.entrySet());
Collections.sort(beans, comparator);
return beans;
}
@Override
public Iterator iterator() {
return this.sortedList.iterator();
}
@Override
public int size() {
return this.sortedList.size();
}
/**
* Adapter to convert a given Bean type into a {@link RegistrationBean} (and hence a {@link ServletContextInitializer}.
*/
protected interface RegistrationBeanAdapter {
RegistrationBean createRegistrationBean(String name, T source,
int totalNumberOfSourceBeans);
}
/**
* {@link RegistrationBeanAdapter} for {@link Servlet} beans.
*/
private static class ServletRegistrationBeanAdapter
implements RegistrationBeanAdapter {
private final MultipartConfigElement multipartConfig;
ServletRegistrationBeanAdapter(MultipartConfigElement multipartConfig) {
this.multipartConfig = multipartConfig;
}
@Override
public RegistrationBean createRegistrationBean(String name, Servlet source,
int totalNumberOfSourceBeans) {
String url = (totalNumberOfSourceBeans == 1 ? "/" : "/" + name + "/");
if (name.equals(DISPATCHER_SERVLET_NAME)) {
url = "/"; // always map the main dispatcherServlet to "/"
}
ServletRegistrationBean bean = new ServletRegistrationBean(source, url);
bean.setMultipartConfig(this.multipartConfig);
return bean;
}
}
/**
* {@link RegistrationBeanAdapter} for {@link Filter} beans.
*/
private static class FilterRegistrationBeanAdapter
implements RegistrationBeanAdapter {
@Override
public RegistrationBean createRegistrationBean(String name, Filter source,
int totalNumberOfSourceBeans) {
return new FilterRegistrationBean(source);
}
}
/**
* {@link RegistrationBeanAdapter} for certain {@link EventListener} beans.
*/
private static class ServletListenerRegistrationBeanAdapter
implements RegistrationBeanAdapter {
@Override
public RegistrationBean createRegistrationBean(String name, EventListener source,
int totalNumberOfSourceBeans) {
return new ServletListenerRegistrationBean(source);
}
}
private Servlet getServlet(ServletRegistrationBean object) {
try {
Field field = object.getClass().getDeclaredField("servlet");
field.setAccessible(true);
return (Servlet) field.get(object);
} catch (NoSuchFieldException | SecurityException | IllegalArgumentException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}
private String getTargetFilterName(DelegatingFilterProxyRegistrationBean object) {
try {
Field field = object.getClass().getDeclaredField("targetBeanName");
field.setAccessible(true);
return (String) field.get(object);
} catch (NoSuchFieldException | SecurityException | IllegalArgumentException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}
}
类的名字没有起好,先这样凑合吧。
利用反射绕过 Spring 对某些方法的访问限制,涉及下面两个方法:
/**
 * Reads the private {@code servlet} field of a {@link ServletRegistrationBean}.
 * Spring does not expose the servlet publicly in this version.
 */
private Servlet getServlet(ServletRegistrationBean object) {
    try {
        // Resolve the field on the declaring class. getClass().getDeclaredField(...) throws
        // NoSuchFieldException for subclasses, because getDeclaredField ignores inherited fields.
        Field field = ServletRegistrationBean.class.getDeclaredField("servlet");
        field.setAccessible(true);
        return (Servlet) field.get(object);
    } catch (NoSuchFieldException | SecurityException | IllegalArgumentException | IllegalAccessException e) {
        throw new RuntimeException(e);
    }
}
/**
 * Reads the private {@code targetBeanName} field of a
 * {@link DelegatingFilterProxyRegistrationBean}.
 */
private String getTargetFilterName(DelegatingFilterProxyRegistrationBean object) {
    try {
        // Resolve the field on the declaring class so that subclasses of the
        // registration bean also work; getDeclaredField ignores inherited fields.
        Field field = DelegatingFilterProxyRegistrationBean.class.getDeclaredField("targetBeanName");
        field.setAccessible(true);
        return (String) field.get(object);
    } catch (NoSuchFieldException | SecurityException | IllegalArgumentException | IllegalAccessException e) {
        throw new RuntimeException(e);
    }
}
因为注册只在应用启动过程中执行,所以反射结果没有加缓存,对性能基本没有影响。
修改了下面方法的访问权限,以便子类覆写(为了区分 Spring 的代码和自己扩展的代码,
我写成了两个类,没有合并到一个类中):
protected void addServletContextInitializerBean(Class> type, String beanName,
ServletContextInitializer initializer, ListableBeanFactory beanFactory,
Object source)
我们开始对web filter的注册做处理:
package com.sdcuike.spring.extend.web;
import javax.servlet.Filter;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.boot.web.servlet.ServletContextInitializer;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.OrderUtils;
/**
* @author sdcuike
*
* Created on 2017.02.13
*
*/
public class ServletContextInitializerBeansExtend extends ServletContextInitializerBeansModify {
public ServletContextInitializerBeansExtend(ListableBeanFactory beanFactory) {
super(beanFactory);
}
@Override
protected void addServletContextInitializerBean(Class> type, String beanName, ServletContextInitializer initializer, ListableBeanFactory beanFactory, Object source) {
Integer order = OrderUtils.getOrder(source.getClass());
if (order == null && source instanceof Ordered) {
order = ((Ordered) source).getOrder();
}
if (Filter.class == type && order != null) {
FilterRegistrationBean filterRegistrationBean = (FilterRegistrationBean) initializer;
filterRegistrationBean.setOrder(order);
}
super.addServletContextInitializerBean(type, beanName, initializer, beanFactory, source);
}
}
这里主要对实现了 Filter 接口的类做处理,即注册 filter 时的排序规则:注解 @Order 的值优先于 Ordered 接口的值。
为了使用我们扩展的注册类,还需要实现一个 EmbeddedWebApplicationContext:
package com.sdcuike.spring.extend.web;
import java.util.Collection;
import org.springframework.boot.context.embedded.EmbeddedWebApplicationContext;
import org.springframework.boot.web.servlet.ServletContextInitializer;
/**
 * Custom {@link EmbeddedWebApplicationContext} that builds the servlet context initializers with
 * {@link ServletContextInitializerBeansExtend}. The purpose is to let {@code @WebFilter} filters
 * be ordered by {@code @Order}.
 *
 * @author sdcuike
 *
 * Created on 2017.02.13
 */
public class EmbeddedWebApplicationContextExtend extends EmbeddedWebApplicationContext {

    /**
     * Returns the initializers to apply to the embedded servlet context. They are collected by
     * {@link ServletContextInitializerBeansExtend} instead of Spring Boot's default implementation.
     *
     * @return the ordered servlet context initializers
     */
    @Override
    protected Collection<ServletContextInitializer> getServletContextInitializerBeans() {
        return new ServletContextInitializerBeansExtend(getBeanFactory());
    }
}
启动spring boot应用:
package com.sdcuike.practice;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import com.sdcuike.spring.extend.web.EmbeddedWebApplicationContextExtend;
/**
 * Spring Boot entry point for the practice application.
 *
 * <p>Installs {@link EmbeddedWebApplicationContextExtend} as the application context class.
 * That context builds its servlet context initializers with the ordering-aware extension.
 *
 * <p>Spring Boot generally recommends placing the main application class in a root package
 * above the other classes. It also recommends Java's package naming convention of a reversed
 * domain name (for example, com.example.project).
 *
 * @author sdcuike
 *
 * Created on 2016.12
 */
@SpringBootApplication
public class Application {

    public static void main(String[] args) {
        final SpringApplication app = new SpringApplication(Application.class);
        // Replace Boot's default embedded web context with the extended one.
        app.setApplicationContextClass(EmbeddedWebApplicationContextExtend.class);
        app.run(args);
    }
}
相关源码:
https://github.com/sdcuike/spring-boot-practice/tree/master/src/main/java/com/sdcuike/spring/extend/web
https://github.com/sdcuike/spring-boot-practice/blob/master/src/main/java/com/sdcuike/practice/Application.java

使用的 Spring Boot 版本:`<spring-boot.version>1.5.1.RELEASE</spring-boot.version>`