diff --git a/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/autoconfigure/MetricsFilter.java b/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/autoconfigure/MetricsFilter.java index a68f68f0fe3..b045e5da65c 100644 --- a/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/autoconfigure/MetricsFilter.java +++ b/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/autoconfigure/MetricsFilter.java @@ -42,6 +42,9 @@ import org.springframework.web.util.UrlPathHelper; @Order(Ordered.HIGHEST_PRECEDENCE) final class MetricsFilter extends OncePerRequestFilter { + private static final String ATTRIBUTE_STOP_WATCH = MetricsFilter.class.getName() + + ".StopWatch"; + private static final int UNDEFINED_HTTP_STATUS = 999; private static final String UNKNOWN_PATH_SUFFIX = "/unmapped"; @@ -57,12 +60,16 @@ final class MetricsFilter extends OncePerRequestFilter { this.gaugeService = gaugeService; } + @Override + protected boolean shouldNotFilterAsyncDispatch() { + return false; + } + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws ServletException, IOException { - StopWatch stopWatch = new StopWatch(); - stopWatch.start(); + StopWatch stopWatch = createStopWatchIfNecessary(request); String path = new UrlPathHelper().getPathWithinApplication(request); int status = HttpStatus.INTERNAL_SERVER_ERROR.value(); try { @@ -70,11 +77,24 @@ final class MetricsFilter extends OncePerRequestFilter { status = getStatus(response); } finally { - stopWatch.stop(); - recordMetrics(request, path, status, stopWatch.getTotalTimeMillis()); + if (!request.isAsyncStarted()) { + stopWatch.stop(); + request.removeAttribute(ATTRIBUTE_STOP_WATCH); + recordMetrics(request, path, status, stopWatch.getTotalTimeMillis()); + } } } + private StopWatch createStopWatchIfNecessary(HttpServletRequest request) { + StopWatch stopWatch = (StopWatch) request.getAttribute(ATTRIBUTE_STOP_WATCH); + if (stopWatch == null) { + stopWatch = new StopWatch(); + stopWatch.start(); + request.setAttribute(ATTRIBUTE_STOP_WATCH, stopWatch); + } + return stopWatch; + } + private int getStatus(HttpServletResponse response) { try { return response.getStatus(); diff --git a/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/autoconfigure/MetricFilterAutoConfigurationTests.java b/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/autoconfigure/MetricFilterAutoConfigurationTests.java index bfd8b2baded..041881131f4 100644 --- a/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/autoconfigure/MetricFilterAutoConfigurationTests.java +++ b/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/autoconfigure/MetricFilterAutoConfigurationTests.java @@ -17,6 +17,7 @@ package org.springframework.boot.actuate.autoconfigure; import java.io.IOException; +import java.util.concurrent.CountDownLatch; import javax.servlet.Filter; import javax.servlet.FilterChain; @@ -34,21 +35,28 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.annotation.Order; import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.stereotype.Component; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.bind.annotation.ResponseStatus; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.request.async.DeferredResult; import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.util.NestedServletException; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; import static org.mockito.BDDMockito.willAnswer; import static org.mockito.BDDMockito.willThrow; import static org.mockito.Matchers.anyDouble; @@ -57,7 +65,10 @@ import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.asyncDispatch; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** @@ -202,6 +213,56 @@ public class MetricFilterAutoConfigurationTests { context.close(); } + @Test + public void correctlyRecordsMetricsForDeferredResultResponse() throws Exception { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext( + Config.class, MetricFilterAutoConfiguration.class); + MetricsFilter filter = context.getBean(MetricsFilter.class); + CountDownLatch latch = new CountDownLatch(1); + MockMvc mvc = MockMvcBuilders + .standaloneSetup(new MetricFilterTestController(latch)).addFilter(filter) + .build(); + String attributeName = MetricsFilter.class.getName() + ".StopWatch"; + MvcResult result = mvc.perform(post("/create")).andExpect(status().isOk()) + .andExpect(request().asyncStarted()) + .andExpect(request().attribute(attributeName, is(notNullValue()))) + .andReturn(); + latch.countDown(); + mvc.perform(asyncDispatch(result)).andExpect(status().isCreated()) + .andExpect(request().attribute(attributeName, is(nullValue()))); + verify(context.getBean(CounterService.class)).increment("status.201.create"); + context.close(); + } + + @Test + public void correctlyRecordsMetricsForFailedDeferredResultResponse() throws Exception { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext( + Config.class, MetricFilterAutoConfiguration.class); + MetricsFilter filter = context.getBean(MetricsFilter.class); + CountDownLatch latch = new CountDownLatch(1); + MockMvc mvc = MockMvcBuilders + .standaloneSetup(new MetricFilterTestController(latch)).addFilter(filter) + .build(); + String attributeName = MetricsFilter.class.getName() + ".StopWatch"; + MvcResult result = mvc.perform(post("/createFailure")).andExpect(status().isOk()) + .andExpect(request().asyncStarted()) + .andExpect(request().attribute(attributeName, is(notNullValue()))) + .andReturn(); + latch.countDown(); + try { + mvc.perform(asyncDispatch(result)); + fail(); + } + catch (Exception ex) { + assertThat(result.getRequest().getAttribute(attributeName), is(nullValue())); + verify(context.getBean(CounterService.class)).increment( + "status.500.createFailure"); + } + finally { + context.close(); + } + } + @Configuration public static class Config { @@ -220,6 +281,16 @@ public class MetricFilterAutoConfigurationTests { @RestController class MetricFilterTestController { + private final CountDownLatch latch; + + MetricFilterTestController() { + this(null); + } + + MetricFilterTestController(CountDownLatch latch) { + this.latch = latch; + } + @RequestMapping("templateVarTest/{someVariable}") public String testTemplateVariableResolution(@PathVariable String someVariable) { return someVariable; @@ -237,6 +308,43 @@ public class MetricFilterAutoConfigurationTests { public String testException() { throw new RuntimeException(); } + + @RequestMapping("create") + public DeferredResult> create() { + final DeferredResult> result = new DeferredResult>(); + new Thread(new Runnable() { + @Override + public void run() { + try { + MetricFilterTestController.this.latch.await(); + result.setResult(new ResponseEntity("Done", + HttpStatus.CREATED)); + } + catch (InterruptedException ex) { + } + } + }).start(); + return result; + } + + @RequestMapping("createFailure") + public DeferredResult> createFailure() { + final DeferredResult> result = new DeferredResult>(); + new Thread(new Runnable() { + @Override + public void run() { + try { + MetricFilterTestController.this.latch.await(); + result.setErrorResult(new Exception("It failed")); + } + catch (InterruptedException ex) { + + } + } + }).start(); + return result; + } + } @Component