Additionally check for null on registring Servlets and Filters

See gh-482
This commit is contained in:
Dave Syer 2014-03-13 13:12:44 +00:00
parent 3d43771136
commit 7f8316708a
3 changed files with 29 additions and 3 deletions

View File

@ -230,7 +230,14 @@ public class FilterRegistrationBean extends RegistrationBean {
@Override
public void onStartup(ServletContext servletContext) throws ServletException {
Assert.notNull(this.filter, "Filter must not be null");
configure(servletContext.addFilter(getOrDeduceName(this.filter), this.filter));
String name = getOrDeduceName(this.filter);
FilterRegistration.Dynamic added = servletContext.addFilter(name, this.filter);
if (added == null) {
logger.info("Filter " + name
+ " was not registered (possibly already registered?)");
return;
}
configure(added);
}
/**

View File

@ -26,6 +26,7 @@ import javax.servlet.Servlet;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRegistration;
import javax.servlet.ServletRegistration.Dynamic;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -156,8 +157,15 @@ public class ServletRegistrationBean extends RegistrationBean {
@Override
public void onStartup(ServletContext servletContext) throws ServletException {
Assert.notNull(this.servlet, "Servlet must not be null");
logger.info("Mapping servlet: '" + getServletName() + "' to " + this.urlMappings);
configure(servletContext.addServlet(getServletName(), this.servlet));
String name = getServletName();
logger.info("Mapping servlet: '" + name + "' to " + this.urlMappings);
Dynamic added = servletContext.addServlet(name, this.servlet);
if (added == null) {
logger.info("Servlet " + name
+ " was not registered (possibly already registered?)");
return;
}
configure(added);
}
/**

View File

@ -38,6 +38,7 @@ import org.mockito.MockitoAnnotations;
import static org.mockito.BDDMockito.given;
import static org.mockito.Matchers.anyObject;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
@ -79,6 +80,16 @@ public class ServletRegistrationBeanTests {
verify(this.registration).addMapping("/*");
}
@Test
public void startupWithDoubleRegistration() throws Exception {
ServletRegistrationBean bean = new ServletRegistrationBean(this.servlet);
given(this.servletContext.addServlet(anyString(), (Servlet) anyObject()))
.willReturn(null);
bean.onStartup(this.servletContext);
verify(this.servletContext).addServlet("mockServlet", this.servlet);
verify(this.registration, times(0)).setAsyncSupported(true);
}
@Test
public void startupWithSpecifiedValues() throws Exception {
ServletRegistrationBean bean = new ServletRegistrationBean();