Polish "Hide loader classes from Tomcat's ServletContext resource paths"

See gh-17538
This commit is contained in:
Andy Wilkinson 2019-06-20 10:54:28 +01:00
parent a81325bbbc
commit 591250f75e
4 changed files with 113 additions and 17 deletions

View File

@ -754,7 +754,8 @@ public class TomcatServletWebServerFactory extends AbstractServletWebServerFacto
@Override @Override
public Set<String> listWebAppPaths(String path) { public Set<String> listWebAppPaths(String path) {
return this.delegate.listWebAppPaths(path).stream() return this.delegate.listWebAppPaths(path).stream()
.filter((p) -> !p.startsWith("org/springframework/boot/loader")).collect(Collectors.toSet()); .filter((webAppPath) -> !webAppPath.startsWith("/org/springframework/boot"))
.collect(Collectors.toSet());
} }
@Override @Override

View File

@ -18,6 +18,8 @@ package com.example;
import java.io.IOException; import java.io.IOException;
import java.net.URL; import java.net.URL;
import java.util.LinkedHashSet;
import java.util.Set;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServlet;
@ -38,31 +40,62 @@ import org.springframework.context.annotation.Bean;
@SpringBootApplication @SpringBootApplication
public class ResourceHandlingApplication { public class ResourceHandlingApplication {
@Bean
public ServletRegistrationBean<?> resourceServletRegistration() {
ServletRegistrationBean<?> registration = new ServletRegistrationBean<HttpServlet>(new GetResourceServlet());
registration.addUrlMappings("/servletContext");
return registration;
}
@Bean
public ServletRegistrationBean<?> resourcePathsServletRegistration() {
ServletRegistrationBean<?> registration = new ServletRegistrationBean<HttpServlet>(
new GetResourcePathsServlet());
registration.addUrlMappings("/resourcePaths");
return registration;
}
public static void main(String[] args) { public static void main(String[] args) {
new SpringApplicationBuilder(ResourceHandlingApplication.class).properties("server.port:0") new SpringApplicationBuilder(ResourceHandlingApplication.class).properties("server.port:0")
.listeners(new WebServerPortFileWriter("target/server.port")).run(args); .listeners(new WebServerPortFileWriter("target/server.port")).run(args);
} }
@Bean private static final class GetResourcePathsServlet extends HttpServlet {
public ServletRegistrationBean<?> resourceServletRegistration() {
ServletRegistrationBean<?> registration = new ServletRegistrationBean<HttpServlet>(new HttpServlet() {
@Override @Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
throws ServletException, IOException { collectResourcePaths("/").forEach(resp.getWriter()::println);
URL resource = getServletContext().getResource(req.getQueryString()); resp.getWriter().flush();
if (resource == null) { }
resp.sendError(404);
} private Set<String> collectResourcePaths(String path) {
else { Set<String> allResourcePaths = new LinkedHashSet<>();
resp.getWriter().println(resource); Set<String> pathsForPath = getServletContext().getResourcePaths(path);
resp.getWriter().flush(); if (pathsForPath != null) {
for (String resourcePath : pathsForPath) {
allResourcePaths.add(resourcePath);
allResourcePaths.addAll(collectResourcePaths(resourcePath));
} }
} }
return allResourcePaths;
}
}
private static final class GetResourceServlet extends HttpServlet {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
URL resource = getServletContext().getResource(req.getQueryString());
if (resource == null) {
resp.sendError(404);
}
else {
resp.getWriter().println(resource);
resp.getWriter().flush();
}
}
});
registration.addUrlMappings("/servletContext");
return registration;
} }
} }

View File

@ -16,7 +16,13 @@
package org.springframework.boot.context.embedded; package org.springframework.boot.context.embedded;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -67,4 +73,29 @@ public class EmbeddedServletContainerWarDevelopmentIntegrationTests
assertThat(entity.getStatusCode()).isEqualTo(HttpStatus.OK); assertThat(entity.getStatusCode()).isEqualTo(HttpStatus.OK);
} }
@Test
public void loaderClassesAreNotAvailableViaResourcePaths() {
ResponseEntity<String> entity = this.rest.getForEntity("/resourcePaths", String.class);
assertThat(entity.getStatusCode()).isEqualTo(HttpStatus.OK);
assertThat(readLines(entity.getBody()))
.noneMatch((resourcePath) -> resourcePath.startsWith("/org/springframework/boot/loader"));
}
private List<String> readLines(String input) {
if (input == null) {
return Collections.emptyList();
}
List<String> lines = new ArrayList<>();
try (BufferedReader reader = new BufferedReader(new StringReader(input))) {
String line;
while ((line = reader.readLine()) != null) {
lines.add(line);
}
return lines;
}
catch (IOException ex) {
throw new RuntimeException("Failed to read lines from input '" + input + "'");
}
}
} }

View File

@ -16,7 +16,13 @@
package org.springframework.boot.context.embedded; package org.springframework.boot.context.embedded;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -90,4 +96,29 @@ public class EmbeddedServletContainerWarPackagingIntegrationTests
assertThat(entity.getStatusCode()).isEqualTo(HttpStatus.NOT_FOUND); assertThat(entity.getStatusCode()).isEqualTo(HttpStatus.NOT_FOUND);
} }
@Test
public void loaderClassesAreNotAvailableViaResourcePaths() {
ResponseEntity<String> entity = this.rest.getForEntity("/resourcePaths", String.class);
assertThat(entity.getStatusCode()).isEqualTo(HttpStatus.OK);
assertThat(readLines(entity.getBody()))
.noneMatch((resourcePath) -> resourcePath.startsWith("/org/springframework/boot/loader"));
}
private List<String> readLines(String input) {
if (input == null) {
return Collections.emptyList();
}
List<String> lines = new ArrayList<>();
try (BufferedReader reader = new BufferedReader(new StringReader(input))) {
String line;
while ((line = reader.readLine()) != null) {
lines.add(line);
}
return lines;
}
catch (IOException ex) {
throw new RuntimeException("Failed to read lines from input '" + input + "'");
}
}
} }