diff --git a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoTestExecutionListener.java b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoTestExecutionListener.java index 975d916b744..be4561d45a5 100644 --- a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoTestExecutionListener.java +++ b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoTestExecutionListener.java @@ -64,9 +64,10 @@ public class MockitoTestExecutionListener extends AbstractTestExecutionListener } @Override - public void beforeTestMethod(TestContext testContext) { + public void beforeTestMethod(TestContext testContext) throws Exception { if (Boolean.TRUE.equals( testContext.getAttribute(DependencyInjectionTestExecutionListener.REINJECT_DEPENDENCIES_ATTRIBUTE))) { + closeMocks(testContext); initMocks(testContext); reinjectFields(testContext); } @@ -77,6 +78,11 @@ public class MockitoTestExecutionListener extends AbstractTestExecutionListener closeMocks(testContext); } + @Override + public void afterTestClass(TestContext testContext) throws Exception { + closeMocks(testContext); + } + private void initMocks(TestContext testContext) { if (hasMockitoAnnotations(testContext)) { testContext.setAttribute(MOCKS_ATTRIBUTE_NAME, MockitoAnnotations.openMocks(testContext.getTestInstance())); diff --git a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/MockitoTestExecutionListenerIntegrationTests.java b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/MockitoTestExecutionListenerIntegrationTests.java index 7dbe1367005..1357629a616 100644 --- a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/MockitoTestExecutionListenerIntegrationTests.java +++ b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/MockitoTestExecutionListenerIntegrationTests.java @@ -21,11 +21,13 @@ import java.util.UUID; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.ClassOrderer; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.MethodOrderer; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestClassOrder; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; import org.junit.jupiter.api.TestMethodOrder; @@ -36,6 +38,8 @@ import org.mockito.MockedStatic; import org.springframework.boot.test.context.TestConfiguration; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.annotation.DirtiesContext.ClassMode; import org.springframework.test.context.junit.jupiter.SpringExtension; import static org.assertj.core.api.Assertions.assertThat; @@ -77,6 +81,122 @@ class MockitoTestExecutionListenerIntegrationTests { } + @Nested + @TestMethodOrder(MethodOrderer.OrderAnnotation.class) + @DirtiesContext(classMode = ClassMode.BEFORE_EACH_TEST_METHOD) + class MockedStaticTestsDirtiesContext { + + private static final UUID uuid = UUID.randomUUID(); + + @Mock + private MockedStatic mockedStatic; + + @Test + @Order(1) + @Disabled + void shouldReturnConstantValueDisabled() { + this.mockedStatic.when(UUID::randomUUID).thenReturn(uuid); + UUID result = UUID.randomUUID(); + assertThat(result).isEqualTo(uuid); + } + + @Test + @Order(2) + void shouldNotFailBecauseOfMockedStaticNotBeingClosed() { + this.mockedStatic.when(UUID::randomUUID).thenReturn(uuid); + UUID result = UUID.randomUUID(); + assertThat(result).isEqualTo(uuid); + } + + @Test + @Order(3) + void shouldNotFailBecauseOfMockedStaticNotBeingClosedWhenMocksAreReinjected() { + this.mockedStatic.when(UUID::randomUUID).thenReturn(uuid); + UUID result = UUID.randomUUID(); + assertThat(result).isEqualTo(uuid); + } + + } + + @Nested + @TestMethodOrder(MethodOrderer.OrderAnnotation.class) + @TestClassOrder(ClassOrderer.OrderAnnotation.class) + class MockedStaticTestsIfClassContainsOnlyDisabledTests { + + @Nested + @Order(1) + class TestClass1 { + + private static final UUID uuid = UUID.randomUUID(); + + @Mock + private MockedStatic mockedStatic; + + @Test + @Order(1) + @Disabled + void disabledTest() { + this.mockedStatic.when(UUID::randomUUID).thenReturn(uuid); + } + + } + + @Nested + @Order(2) + class TestClass2 { + + private static final UUID uuid = UUID.randomUUID(); + + @Mock + private MockedStatic mockedStatic; + + @Test + @Order(1) + void shouldNotFailBecauseMockedStaticHasNotBeenClosed() { + this.mockedStatic.when(UUID::randomUUID).thenReturn(uuid); + UUID result = UUID.randomUUID(); + assertThat(result).isEqualTo(uuid); + } + + } + + } + + @Nested + @TestMethodOrder(MethodOrderer.OrderAnnotation.class) + @TestClassOrder(ClassOrderer.OrderAnnotation.class) + class MockedStaticTestsIfClassContainsNoTests { + + @Nested + @Order(1) + class TestClass1 { + + @Mock + private MockedStatic mockedStatic; + + } + + @Nested + @Order(2) + class TestClass2 { + + private static final UUID uuid = UUID.randomUUID(); + + @Mock + private MockedStatic mockedStatic; + + @Test + @Order(1) + void shouldNotFailBecauseMockedStaticHasNotBeenClosed() { + this.mockedStatic.when(UUID::randomUUID).thenReturn(uuid); + UUID result = UUID.randomUUID(); + assertThat(result).isEqualTo(uuid); + } + + } + + } + @Nested @TestMethodOrder(MethodOrderer.OrderAnnotation.class) class ConfigureMockInBeforeEach {