Consider HandlerMethodValidationException in DefaultErrorAttributes

See gh-39865
This commit is contained in:
Yanming Zhou 2024-03-11 11:39:18 +08:00 committed by Andy Wilkinson
parent 4b61ae415b
commit 20e9ff9f3d
4 changed files with 141 additions and 7 deletions

View File

@ -32,6 +32,7 @@ import org.springframework.http.HttpStatus;
import org.springframework.util.StringUtils;
import org.springframework.validation.BindingResult;
import org.springframework.validation.ObjectError;
import org.springframework.validation.method.MethodValidationResult;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ResponseStatusException;
@ -57,6 +58,7 @@ import org.springframework.web.server.ServerWebExchange;
* @author Stephane Nicoll
* @author Michele Mancioppi
* @author Scott Frederick
* @author Yanming Zhou
* @since 2.0.0
* @see ErrorAttributes
*/
@ -113,6 +115,14 @@ public class DefaultErrorAttributes implements ErrorAttributes {
if (error instanceof BindingResult) {
return error.getMessage();
}
if (error instanceof MethodValidationResult methodValidationResult) {
long errorCount = methodValidationResult.getAllErrors()
.stream()
.filter(ObjectError.class::isInstance)
.count();
return "Validation failed for method: %s, with %d %s".formatted(methodValidationResult.getMethod(),
errorCount, (errorCount > 1) ? "errors" : "error");
}
if (error instanceof ResponseStatusException responseStatusException) {
return responseStatusException.getReason();
}
@ -147,6 +157,12 @@ public class DefaultErrorAttributes implements ErrorAttributes {
errorAttributes.put("errors", result.getAllErrors());
}
}
if (error instanceof MethodValidationResult result) {
if (result.hasErrors()) {
errorAttributes.put("errors",
result.getAllErrors().stream().filter(ObjectError.class::isInstance).toList());
}
}
}
@Override

View File

@ -20,6 +20,7 @@ import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import jakarta.servlet.RequestDispatcher;
@ -29,6 +30,7 @@ import jakarta.servlet.http.HttpServletResponse;
import org.springframework.boot.web.error.ErrorAttributeOptions;
import org.springframework.boot.web.error.ErrorAttributeOptions.Include;
import org.springframework.context.MessageSourceResolvable;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.http.HttpStatus;
@ -36,6 +38,7 @@ import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import org.springframework.validation.BindingResult;
import org.springframework.validation.ObjectError;
import org.springframework.validation.method.MethodValidationResult;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.WebRequest;
import org.springframework.web.servlet.HandlerExceptionResolver;
@ -61,6 +64,7 @@ import org.springframework.web.servlet.ModelAndView;
* @author Stephane Nicoll
* @author Vedran Pavic
* @author Scott Frederick
* @author Yanming Zhou
* @since 2.0.0
* @see ErrorAttributes
*/
@ -145,13 +149,20 @@ public class DefaultErrorAttributes implements ErrorAttributes, HandlerException
}
private void addErrorMessage(Map<String, Object> errorAttributes, WebRequest webRequest, Throwable error) {
BindingResult result = extractBindingResult(error);
if (result == null) {
addExceptionErrorMessage(errorAttributes, webRequest, error);
MethodValidationResult methodValidationResult = extractMethodValidationResult(error);
if (methodValidationResult != null) {
addMethodValidationResultErrorMessage(errorAttributes, methodValidationResult);
}
else {
addBindingResultErrorMessage(errorAttributes, result);
BindingResult bindingResult = extractBindingResult(error);
if (bindingResult != null) {
addBindingResultErrorMessage(errorAttributes, bindingResult);
}
else {
addExceptionErrorMessage(errorAttributes, webRequest, error);
}
}
}
private void addExceptionErrorMessage(Map<String, Object> errorAttributes, WebRequest webRequest, Throwable error) {
@ -189,6 +200,17 @@ public class DefaultErrorAttributes implements ErrorAttributes, HandlerException
errorAttributes.put("errors", result.getAllErrors());
}
private void addMethodValidationResultErrorMessage(Map<String, Object> errorAttributes,
MethodValidationResult result) {
List<? extends MessageSourceResolvable> errors = result.getAllErrors()
.stream()
.filter(ObjectError.class::isInstance)
.toList();
errorAttributes.put("message",
"Validation failed for method='" + result.getMethod() + "'. " + "Error count: " + errors.size());
errorAttributes.put("errors", errors);
}
private BindingResult extractBindingResult(Throwable error) {
if (error instanceof BindingResult bindingResult) {
return bindingResult;
@ -196,6 +218,13 @@ public class DefaultErrorAttributes implements ErrorAttributes, HandlerException
return null;
}
private MethodValidationResult extractMethodValidationResult(Throwable error) {
if (error instanceof MethodValidationResult methodValidationResult) {
return methodValidationResult;
}
return null;
}
private void addStackTrace(Map<String, Object> errorAttributes, Throwable error) {
StringWriter stackTrace = new StringWriter();
error.printStackTrace(new PrintWriter(stackTrace));

View File

@ -35,8 +35,11 @@ import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.validation.BindingResult;
import org.springframework.validation.MapBindingResult;
import org.springframework.validation.ObjectError;
import org.springframework.validation.method.MethodValidationResult;
import org.springframework.validation.method.ParameterValidationResult;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.bind.support.WebExchangeBindException;
import org.springframework.web.method.annotation.HandlerMethodValidationException;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;
@ -50,6 +53,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
* @author Brian Clozel
* @author Stephane Nicoll
* @author Scott Frederick
* @author Yanming Zhou
*/
class DefaultErrorAttributesTests {
@ -246,6 +250,45 @@ class DefaultErrorAttributesTests {
assertThat(attributes).containsEntry("errors", bindingResult.getAllErrors());
}
@Test
void extractMethodValidationResultErrors() throws Exception {
Object target = "test";
Method method = String.class.getMethod("substring", int.class);
MethodParameter parameter = new MethodParameter(method, 0);
MethodValidationResult methodValidationResult = new MethodValidationResult() {
@Override
public Object getTarget() {
return target;
}
@Override
public Method getMethod() {
return method;
}
@Override
public boolean isForReturnValue() {
return false;
}
@Override
public List<ParameterValidationResult> getAllValidationResults() {
return List.of(new ParameterValidationResult(parameter, -1,
List.of(new ObjectError("beginIndex", "beginIndex is negative")), null, null, null));
}
};
HandlerMethodValidationException ex = new HandlerMethodValidationException(methodValidationResult);
MockServerHttpRequest request = MockServerHttpRequest.get("/test").build();
Map<String, Object> attributes = this.errorAttributes.getErrorAttributes(buildServerRequest(request, ex),
ErrorAttributeOptions.of(Include.MESSAGE, Include.BINDING_ERRORS));
assertThat(attributes.get("message")).asString()
.isEqualTo(
"Validation failed for method: public java.lang.String java.lang.String.substring(int), with 1 error");
assertThat(attributes).containsEntry("errors",
methodValidationResult.getAllErrors().stream().filter(ObjectError.class::isInstance).toList());
}
@Test
void extractBindingResultErrorsExcludeMessageAndErrors() throws Exception {
Method method = getClass().getDeclaredMethod("method", String.class);

View File

@ -19,13 +19,16 @@ package org.springframework.boot.web.servlet.error;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import jakarta.servlet.ServletException;
import org.junit.jupiter.api.Test;
import org.springframework.boot.web.error.ErrorAttributeOptions;
import org.springframework.boot.web.error.ErrorAttributeOptions.Include;
import org.springframework.context.MessageSourceResolvable;
import org.springframework.core.MethodParameter;
import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockHttpServletRequest;
@ -34,9 +37,12 @@ import org.springframework.validation.BindException;
import org.springframework.validation.BindingResult;
import org.springframework.validation.MapBindingResult;
import org.springframework.validation.ObjectError;
import org.springframework.validation.method.MethodValidationResult;
import org.springframework.validation.method.ParameterValidationResult;
import org.springframework.web.bind.MethodArgumentNotValidException;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.context.request.WebRequest;
import org.springframework.web.method.annotation.HandlerMethodValidationException;
import org.springframework.web.servlet.ModelAndView;
import static org.assertj.core.api.Assertions.assertThat;
@ -47,6 +53,7 @@ import static org.assertj.core.api.Assertions.assertThat;
* @author Phillip Webb
* @author Vedran Pavic
* @author Scott Frederick
* @author Yanming Zhou
*/
class DefaultErrorAttributesTests {
@ -201,18 +208,57 @@ class DefaultErrorAttributesTests {
testBindingResult(bindingResult, ex, ErrorAttributeOptions.of(Include.MESSAGE, Include.BINDING_ERRORS));
}
@Test
void withHandlerMethodValidationExceptionBindingErrors() {
Object target = "test";
Method method = ReflectionUtils.findMethod(String.class, "substring", int.class);
MethodParameter parameter = new MethodParameter(method, 0);
MethodValidationResult methodValidationResult = new MethodValidationResult() {
@Override
public Object getTarget() {
return target;
}
@Override
public Method getMethod() {
return method;
}
@Override
public boolean isForReturnValue() {
return false;
}
@Override
public List<ParameterValidationResult> getAllValidationResults() {
return List.of(new ParameterValidationResult(parameter, -1,
List.of(new ObjectError("beginIndex", "beginIndex is negative")), null, null, null));
}
};
HandlerMethodValidationException ex = new HandlerMethodValidationException(methodValidationResult);
testErrorsSupplier(methodValidationResult::getAllErrors,
"Validation failed for method='public java.lang.String java.lang.String.substring(int)'. Error count: 1",
ex, ErrorAttributeOptions.of(Include.MESSAGE, Include.BINDING_ERRORS));
}
private void testBindingResult(BindingResult bindingResult, Exception ex, ErrorAttributeOptions options) {
testErrorsSupplier(bindingResult::getAllErrors, "Validation failed for object='objectName'. Error count: 1", ex,
options);
}
private void testErrorsSupplier(Supplier<List<? extends MessageSourceResolvable>> errorsSupplier,
String expectedMessage, Exception ex, ErrorAttributeOptions options) {
this.request.setAttribute("jakarta.servlet.error.exception", ex);
Map<String, Object> attributes = this.errorAttributes.getErrorAttributes(this.webRequest, options);
if (options.isIncluded(Include.MESSAGE)) {
assertThat(attributes).containsEntry("message",
"Validation failed for object='objectName'. Error count: 1");
assertThat(attributes).containsEntry("message", expectedMessage);
}
else {
assertThat(attributes).doesNotContainKey("message");
}
if (options.isIncluded(Include.BINDING_ERRORS)) {
assertThat(attributes).containsEntry("errors", bindingResult.getAllErrors());
assertThat(attributes).containsEntry("errors", errorsSupplier.get());
}
else {
assertThat(attributes).doesNotContainKey("errors");