diff --git a/spring-boot-project/spring-boot-devtools/src/test/java/org/springframework/boot/devtools/tunnel/client/TunnelClientTests.java b/spring-boot-project/spring-boot-devtools/src/test/java/org/springframework/boot/devtools/tunnel/client/TunnelClientTests.java index 060a798e081..da6eb662b1d 100644 --- a/spring-boot-project/spring-boot-devtools/src/test/java/org/springframework/boot/devtools/tunnel/client/TunnelClientTests.java +++ b/spring-boot-project/spring-boot-devtools/src/test/java/org/springframework/boot/devtools/tunnel/client/TunnelClientTests.java @@ -25,15 +25,13 @@ import java.nio.channels.Channels; import java.nio.channels.SocketChannel; import java.nio.channels.WritableByteChannel; import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; /** * Tests for {@link TunnelClient}. @@ -98,16 +96,34 @@ class TunnelClientTests { @Test void addListener() throws Exception { TunnelClient client = new TunnelClient(0, this.tunnelConnection); - TunnelClientListener listener = mock(TunnelClientListener.class); + MockTunnelClientListener listener = new MockTunnelClientListener(); client.addListener(listener); int port = client.start(); SocketChannel channel = SocketChannel.open(new InetSocketAddress(port)); - Thread.sleep(200); - channel.close(); + Awaitility.await().atMost(Duration.ofSeconds(30)).until(listener.onOpen::get, (open) -> open == 1); + assertThat(listener.onClose).hasValue(0); client.getServerThread().stopAcceptingConnections(); + channel.close(); + Awaitility.await().atMost(Duration.ofSeconds(30)).until(listener.onClose::get, (close) -> close == 1); client.getServerThread().join(2000); - verify(listener).onOpen(any(SocketChannel.class)); - verify(listener).onClose(any(SocketChannel.class)); + } + + static class MockTunnelClientListener implements TunnelClientListener { + + private final AtomicInteger onOpen = new AtomicInteger(); + + private final AtomicInteger onClose = new AtomicInteger(); + + @Override + public void onOpen(SocketChannel socket) { + this.onOpen.incrementAndGet(); + } + + @Override + public void onClose(SocketChannel socket) { + this.onClose.incrementAndGet(); + } + } static class MockTunnelConnection implements TunnelConnection {