diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/jetty/JettyEmbeddedWebAppContext.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/jetty/JettyEmbeddedWebAppContext.java index 3bc9f30bbf8..baf8e758660 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/jetty/JettyEmbeddedWebAppContext.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/jetty/JettyEmbeddedWebAppContext.java @@ -38,7 +38,8 @@ class JettyEmbeddedWebAppContext extends WebAppContext { } void deferredInitialize() throws Exception { - ((JettyEmbeddedServletHandler) getServletHandler()).deferredInitialize(); + JettyEmbeddedServletHandler handler = (JettyEmbeddedServletHandler) getServletHandler(); + getContext().call(handler::deferredInitialize, null); } private static final class JettyEmbeddedServletHandler extends ServletHandler { diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java index 521e547ae32..f59f916aad6 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java @@ -40,6 +40,7 @@ import java.util.Collection; import java.util.Date; import java.util.EnumSet; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; @@ -65,6 +66,7 @@ import jakarta.servlet.Filter; import jakarta.servlet.FilterChain; import jakarta.servlet.FilterConfig; import jakarta.servlet.GenericServlet; +import jakarta.servlet.ServletConfig; import jakarta.servlet.ServletContext; import jakarta.servlet.ServletContextEvent; import jakarta.servlet.ServletContextListener; @@ -1382,6 +1384,26 @@ public abstract class AbstractServletWebServerFactoryTests { + " \\(http(/1.1)?\\), [0-9]+ \\(http(/1.1)?\\)( with context path '(/)?')?"); } + @Test + void servletComponentsAreInitializedWithTheSameThreadContextClassLoader() { + AbstractServletWebServerFactory factory = getFactory(); + ThreadContextClassLoaderCapturingServlet servlet = new ThreadContextClassLoaderCapturingServlet(); + ThreadContextClassLoaderCapturingFilter filter = new ThreadContextClassLoaderCapturingFilter(); + ThreadContextClassLoaderCapturingListener listener = new ThreadContextClassLoaderCapturingListener(); + this.webServer = factory.getWebServer((context) -> { + context.addServlet("tcclCapturingServlet", servlet).setLoadOnStartup(0); + context.addFilter("tcclCapturingFilter", filter); + context.addListener(listener); + }); + this.webServer.start(); + assertThat(servlet.contextClassLoader).isNotNull(); + assertThat(filter.contextClassLoader).isNotNull(); + assertThat(listener.contextClassLoader).isNotNull(); + assertThat(new HashSet<>( + Arrays.asList(servlet.contextClassLoader, filter.contextClassLoader, listener.contextClassLoader))) + .hasSize(1); + } + protected Future initiateGetRequest(int port, String path) { return initiateGetRequest(HttpClients.createMinimal(), port, path); } @@ -1455,7 +1477,7 @@ public abstract class AbstractServletWebServerFactoryTests { compression.setExcludedUserAgents(excludedUserAgents); } factory.setCompression(compression); - factory.addInitializers(new ServletRegistrationBean(new HttpServlet() { + factory.addInitializers(new ServletRegistrationBean<>(new HttpServlet() { @Override protected void service(HttpServletRequest req, HttpServletResponse resp) throws IOException { @@ -1833,4 +1855,43 @@ public abstract class AbstractServletWebServerFactoryTests { } + static class ThreadContextClassLoaderCapturingServlet extends HttpServlet { + + private ClassLoader contextClassLoader; + + @Override + public void init(ServletConfig config) throws ServletException { + this.contextClassLoader = Thread.currentThread().getContextClassLoader(); + } + + } + + static class ThreadContextClassLoaderCapturingListener implements ServletContextListener { + + private ClassLoader contextClassLoader; + + @Override + public void contextInitialized(ServletContextEvent sce) { + this.contextClassLoader = Thread.currentThread().getContextClassLoader(); + } + + } + + static class ThreadContextClassLoaderCapturingFilter implements Filter { + + private ClassLoader contextClassLoader; + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + this.contextClassLoader = Thread.currentThread().getContextClassLoader(); + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { + chain.doFilter(request, response); + } + + } + }