diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/servlet/AbstractFilterRegistrationBean.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/servlet/AbstractFilterRegistrationBean.java index 4c0593d93f0..7935b4d104f 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/servlet/AbstractFilterRegistrationBean.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/servlet/AbstractFilterRegistrationBean.java @@ -30,7 +30,9 @@ import javax.servlet.FilterRegistration.Dynamic; import javax.servlet.ServletContext; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; +import org.springframework.web.filter.OncePerRequestFilter; /** * Abstract base {@link ServletContextInitializer} to register {@link Filter}s in a @@ -218,7 +220,14 @@ public abstract class AbstractFilterRegistrationBean extends D super.configure(registration); EnumSet dispatcherTypes = this.dispatcherTypes; if (dispatcherTypes == null) { - dispatcherTypes = EnumSet.of(DispatcherType.REQUEST); + T filter = getFilter(); + if (ClassUtils.isPresent("org.springframework.web.filter.OncePerRequestFilter", + filter.getClass().getClassLoader()) && filter instanceof OncePerRequestFilter) { + dispatcherTypes = EnumSet.allOf(DispatcherType.class); + } + else { + dispatcherTypes = EnumSet.of(DispatcherType.REQUEST); + } } Set servletNames = new LinkedHashSet<>(); for (ServletRegistrationBean servletRegistrationBean : this.servletRegistrationBeans) { diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/FilterRegistrationBeanTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/FilterRegistrationBeanTests.java index b1e7eff2b42..1d157be24c8 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/FilterRegistrationBeanTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/FilterRegistrationBeanTests.java @@ -16,11 +16,20 @@ package org.springframework.boot.web.servlet; +import java.io.IOException; +import java.util.EnumSet; + +import javax.servlet.DispatcherType; import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import org.junit.jupiter.api.Test; import org.springframework.boot.web.servlet.mock.MockFilter; +import org.springframework.web.filter.OncePerRequestFilter; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.eq; @@ -35,6 +44,16 @@ class FilterRegistrationBeanTests extends AbstractFilterRegistrationBeanTests { private final MockFilter filter = new MockFilter(); + private final OncePerRequestFilter oncePerRequestFilter = new OncePerRequestFilter() { + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + filterChain.doFilter(request, response); + } + + }; + @Test void setFilter() throws Exception { FilterRegistrationBean bean = new FilterRegistrationBean<>(); @@ -63,6 +82,15 @@ class FilterRegistrationBeanTests extends AbstractFilterRegistrationBeanTests { .withMessageContaining("ServletRegistrationBeans must not be null"); } + @Test + void startupWithOncePerRequestDefaults() throws Exception { + FilterRegistrationBean bean = new FilterRegistrationBean<>(this.oncePerRequestFilter); + bean.onStartup(this.servletContext); + verify(this.servletContext).addFilter(eq("oncePerRequestFilter"), eq(this.oncePerRequestFilter)); + verify(this.registration).setAsyncSupported(true); + verify(this.registration).addMappingForUrlPatterns(EnumSet.allOf(DispatcherType.class), false, "/*"); + } + @Override protected AbstractFilterRegistrationBean createFilterRegistrationBean( ServletRegistrationBean... servletRegistrationBeans) { diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/NoSpringWebFilterRegistrationBeanTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/NoSpringWebFilterRegistrationBeanTests.java new file mode 100644 index 00000000000..bdfba615d2a --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/NoSpringWebFilterRegistrationBeanTests.java @@ -0,0 +1,48 @@ +/* + * Copyright 2012-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.servlet; + +import javax.servlet.Filter; + +import org.springframework.boot.testsupport.classpath.ClassPathExclusions; +import org.springframework.boot.web.servlet.mock.MockFilter; + +import static org.mockito.ArgumentMatchers.eq; + +/** + * Tests for {@link FilterRegistrationBean} when {@code spring-web} is not on the + * classpath. + * + * @author Andy Wilkinson + */ +@ClassPathExclusions("spring-web-*.jar") +public class NoSpringWebFilterRegistrationBeanTests extends AbstractFilterRegistrationBeanTests { + + private final MockFilter filter = new MockFilter(); + + @Override + protected AbstractFilterRegistrationBean createFilterRegistrationBean( + ServletRegistrationBean... servletRegistrationBeans) { + return new FilterRegistrationBean<>(this.filter, servletRegistrationBeans); + } + + @Override + protected Filter getExpectedFilter() { + return eq(this.filter); + } + +}