[bs-61] Allow non-default servlets and filters to be registered

* The RegistrationBean (ServletInitializer) now exposes a registration
target object, and this is used to prevent double registration of those
objects.
* If there is a Servlet with bean id "dispatcherServlet" it is mapped to
"/" (unless already registered as a ServletRegistrationBean).

[Fixes #48645559]
This commit is contained in:
Dave Syer 2013-05-29 07:41:48 +01:00
parent bf30e2de90
commit aabff1a774
5 changed files with 120 additions and 27 deletions

View File

@ -20,6 +20,7 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map.Entry;
@ -87,6 +88,14 @@ import org.springframework.web.context.support.WebApplicationContextUtils;
*/
public class EmbeddedWebApplicationContext extends GenericWebApplicationContext {
/**
* Constant value for the DispatcherServlet bean name. A Servlet bean with this name
* is deemed to be the "main" servlet and is automatically given a mapping of "/" by
* default. To change the default behaviour you can use a
* {@link ServletRegistrationBean} or a different bean name.
*/
public static final String DISPATCHER_SERVLET_NAME = "dispatcherServlet";
private EmbeddedServletContainer embeddedServletContainer;
private ServletConfig servletConfig;
@ -184,29 +193,45 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
protected Collection<ServletContextInitializer> getServletContextInitializerBeans() {
Set<ServletContextInitializer> initializers = new LinkedHashSet<ServletContextInitializer>();
Set<Object> targets = new HashSet<Object>();
for (Entry<String, ServletContextInitializer> initializerBean : getOrderedBeansOfType(ServletContextInitializer.class)) {
initializers.add(initializerBean.getValue());
ServletContextInitializer initializer = initializerBean.getValue();
if (initializer instanceof RegistrationBean) {
targets.add(((RegistrationBean) initializer).getRegistrationTarget());
}
if (initializer instanceof ServletRegistrationBean) {
targets.addAll(((ServletRegistrationBean) initializer).getFilters());
}
initializers.add(initializer);
}
if (initializers.isEmpty()) {
List<Entry<String, Servlet>> servletBeans = getOrderedBeansOfType(Servlet.class);
for (Entry<String, Servlet> servletBean : servletBeans) {
String url = (servletBeans.size() == 1 ? "/" : "/"
+ servletBean.getKey().toLowerCase() + "/*");
ServletRegistrationBean registration = new ServletRegistrationBean(
servletBean.getValue(), url);
registration.setName(servletBean.getKey());
initializers.add(registration);
List<Entry<String, Servlet>> servletBeans = getOrderedBeansOfType(Servlet.class);
for (Entry<String, Servlet> servletBean : servletBeans) {
String name = servletBean.getKey();
Servlet servlet = servletBean.getValue();
if (targets.contains(servlet)) {
continue;
}
for (Entry<String, Filter> filterBean : getOrderedBeansOfType(Filter.class)) {
FilterRegistrationBean registration = new FilterRegistrationBean(
filterBean.getValue());
registration.setName(filterBean.getKey());
initializers.add(registration);
String url = (servletBeans.size() == 1 ? "/" : "/" + name + "/*");
if (name.equals(DISPATCHER_SERVLET_NAME)) {
url = "/"; // always map the main dispatcherServlet to "/"
}
ServletRegistrationBean registration = new ServletRegistrationBean(servlet,
url);
registration.setName(name);
initializers.add(registration);
}
for (Entry<String, Filter> filterBean : getOrderedBeansOfType(Filter.class)) {
String name = filterBean.getKey();
Filter filter = filterBean.getValue();
if (targets.contains(filter)) {
continue;
}
FilterRegistrationBean registration = new FilterRegistrationBean(filter);
registration.setName(name);
initializers.add(registration);
}
return initializers;

View File

@ -216,7 +216,12 @@ public class FilterRegistrationBean extends RegistrationBean {
@Override
public void onStartup(ServletContext servletContext) throws ServletException {
Assert.notNull(this.filter, "Filter must not be null");
configure(servletContext.addFilter(getOrDeduceName(this.filter), this.filter));
configure(servletContext.addFilter(getName(), this.filter));
}
@Override
public Object getRegistrationTarget() {
return this.filter;
}
/**

View File

@ -49,6 +49,13 @@ public abstract class RegistrationBean implements ServletContextInitializer {
this.name = name;
}
/**
* @return the name
*/
public String getName() {
return getOrDeduceName(getRegistrationTarget());
}
/**
* Sets if asynchronous operations are support for this registration. If not specified
* defaults to {@code true}.
@ -61,7 +68,7 @@ public abstract class RegistrationBean implements ServletContextInitializer {
* Returns if asynchronous operations are support for this registration.
*/
public boolean isAsyncSupported() {
return asyncSupported;
return this.asyncSupported;
}
/**
@ -92,12 +99,20 @@ public abstract class RegistrationBean implements ServletContextInitializer {
this.initParameters.put(name, value);
}
/**
* The target of the registration (e.g. a Servlet or a Filter) that can be used to
* guess its name if none is supplied explicitly.
*
* @return the target of this registration
*/
public abstract Object getRegistrationTarget();
/**
* Deduces the name for this registration. Will return user specified name or fallback
* to convention based naming.
* @param value the object used for convention based names
*/
protected final String getOrDeduceName(Object value) {
private String getOrDeduceName(Object value) {
return (this.name != null ? this.name : Conventions.getVariableName(value));
}

View File

@ -149,7 +149,12 @@ public class ServletRegistrationBean extends RegistrationBean {
* Returns the servlet name that will be registered.
*/
public String getServletName() {
return getOrDeduceName(this.servlet);
return getName();
}
@Override
public Object getRegistrationTarget() {
return this.servlet;
}
@Override

View File

@ -49,9 +49,9 @@ import static org.junit.Assert.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Matchers.anyObject;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.withSettings;
@ -190,9 +190,32 @@ public class EmbeddedWebApplicationContextTests {
ordered.verify(servletContext).addServlet("servletBean1", servlet1);
ordered.verify(servletContext).addServlet("servletBean2", servlet2);
verify(escf.getRegisteredServlet(0).getRegistration()).addMapping(
"/servletbean1/*");
"/servletBean1/*");
verify(escf.getRegisteredServlet(1).getRegistration()).addMapping(
"/servletbean2/*");
"/servletBean2/*");
}
@Test
public void multipleServletBeansWithMainDispatcher() throws Exception {
addEmbeddedServletContainerFactoryBean();
Servlet servlet1 = mock(Servlet.class,
withSettings().extraInterfaces(Ordered.class));
given(((Ordered) servlet1).getOrder()).willReturn(1);
Servlet servlet2 = mock(Servlet.class,
withSettings().extraInterfaces(Ordered.class));
given(((Ordered) servlet2).getOrder()).willReturn(2);
this.context.registerBeanDefinition("servletBean2", beanDefinition(servlet2));
this.context
.registerBeanDefinition("dispatcherServlet", beanDefinition(servlet1));
this.context.refresh();
MockEmbeddedServletContainerFactory escf = getEmbeddedServletContainerFactory();
ServletContext servletContext = escf.getServletContext();
InOrder ordered = inOrder(servletContext);
ordered.verify(servletContext).addServlet("dispatcherServlet", servlet1);
ordered.verify(servletContext).addServlet("servletBean2", servlet2);
verify(escf.getRegisteredServlet(0).getRegistration()).addMapping("/");
verify(escf.getRegisteredServlet(1).getRegistration()).addMapping(
"/servletBean2/*");
}
@Test
@ -258,7 +281,8 @@ public class EmbeddedWebApplicationContextTests {
}
@Test
public void servletContextInitializerBeansSkipsServletsAndFilters() throws Exception {
public void servletContextInitializerBeansDoesNotSkipServletsAndFilters()
throws Exception {
addEmbeddedServletContainerFactoryBean();
ServletContextInitializer initializer = mock(ServletContextInitializer.class);
Servlet servlet = mock(Servlet.class);
@ -271,8 +295,27 @@ public class EmbeddedWebApplicationContextTests {
ServletContext servletContext = getEmbeddedServletContainerFactory()
.getServletContext();
verify(initializer).onStartup(servletContext);
verify(servletContext, never()).addServlet(anyString(), (Servlet) anyObject());
verify(servletContext, never()).addFilter(anyString(), (Filter) anyObject());
verify(servletContext).addServlet(anyString(), (Servlet) anyObject());
verify(servletContext).addFilter(anyString(), (Filter) anyObject());
}
@Test
public void servletContextInitializerBeansSkipsRegisteredServletsAndFilters()
throws Exception {
addEmbeddedServletContainerFactoryBean();
Servlet servlet = mock(Servlet.class);
Filter filter = mock(Filter.class);
ServletRegistrationBean initializer = new ServletRegistrationBean(servlet, "/foo");
initializer.addFilters(filter);
this.context.registerBeanDefinition("initializerBean",
beanDefinition(initializer));
this.context.registerBeanDefinition("servletBean", beanDefinition(servlet));
this.context.registerBeanDefinition("filterBean", beanDefinition(filter));
this.context.refresh();
ServletContext servletContext = getEmbeddedServletContainerFactory()
.getServletContext();
verify(servletContext, atMost(1)).addServlet(anyString(), (Servlet) anyObject());
verify(servletContext, atMost(1)).addFilter(anyString(), (Filter) anyObject());
}
@Test