From b4810b8b91d0f212741636f7ff887122e9355152 Mon Sep 17 00:00:00 2001 From: cbono Date: Wed, 18 Dec 2019 22:11:49 -0600 Subject: [PATCH 1/2] Add SSL support to RSocketServer See gh-19399 --- .../rsocket/RSocketProperties.java | 14 +++ .../RSocketServerAutoConfiguration.java | 1 + .../RSocketServerAutoConfigurationTests.java | 12 ++ .../src/test/resources/rsocket/test.jks | Bin 0 -> 1276 bytes .../netty/NettyRSocketServerFactory.java | 70 +++++++++++- .../ConfigurableRSocketServerFactory.java | 15 +++ .../netty/NettyRSocketServerFactoryTests.java | 108 ++++++++++++++++-- 7 files changed, 205 insertions(+), 15 deletions(-) create mode 100644 spring-boot-project/spring-boot-autoconfigure/src/test/resources/rsocket/test.jks diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java index 232eb1db427..cf18235260e 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java @@ -19,12 +19,15 @@ package org.springframework.boot.autoconfigure.rsocket; import java.net.InetAddress; import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.boot.rsocket.server.RSocketServer; +import org.springframework.boot.web.server.Ssl; /** * {@link ConfigurationProperties properties} for RSocket support. * * @author Brian Clozel + * @author Chris Bono * @since 2.2.0 */ @ConfigurationProperties("spring.rsocket") @@ -59,6 +62,9 @@ public class RSocketProperties { */ private String mappingPath; + @NestedConfigurationProperty + private Ssl ssl; + public Integer getPort() { return this.port; } @@ -91,6 +97,14 @@ public class RSocketProperties { this.mappingPath = mappingPath; } + public Ssl getSsl() { + return this.ssl; + } + + public void setSsl(Ssl ssl) { + this.ssl = ssl; + } + } } diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfiguration.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfiguration.java index 7b13f7beabf..80902e1de3d 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfiguration.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfiguration.java @@ -97,6 +97,7 @@ public class RSocketServerAutoConfiguration { PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull(); map.from(properties.getServer().getAddress()).to(factory::setAddress); map.from(properties.getServer().getPort()).to(factory::setPort); + map.from(properties.getServer().getSsl()).to(factory::setSsl); factory.setRSocketServerCustomizers(customizers.orderedStream().collect(Collectors.toList())); return factory; } diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfigurationTests.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfigurationTests.java index 2134011e8a1..ae354082305 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfigurationTests.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfigurationTests.java @@ -91,6 +91,18 @@ class RSocketServerAutoConfigurationTests { }); } + @Test + void shouldUseSslWhenRocketServerSslIsConfigured() { + reactiveWebContextRunner() + .withPropertyValues("spring.rsocket.server.ssl.keyStore=classpath:rsocket/test.jks", + "spring.rsocket.server.ssl.keyPassword=password", "spring.rsocket.server.port=0") + .run((context) -> assertThat(context).hasSingleBean(RSocketServerFactory.class) + .hasSingleBean(RSocketServerBootstrap.class).hasSingleBean(RSocketServerCustomizer.class) + .getBean(RSocketServerFactory.class) + .hasFieldOrPropertyWithValue("ssl.keyStore", "classpath:rsocket/test.jks") + .hasFieldOrPropertyWithValue("ssl.keyPassword", "password")); + } + @Test void shouldUseCustomServerBootstrap() { contextRunner().withUserConfiguration(CustomServerBootstrapConfig.class).run((context) -> assertThat(context) diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/resources/rsocket/test.jks b/spring-boot-project/spring-boot-autoconfigure/src/test/resources/rsocket/test.jks new file mode 100644 index 0000000000000000000000000000000000000000..0fc3e802f75461dd074facb9611d350db4d5960f GIT binary patch literal 1276 zcmezO_TO6u1_mZ5W@O+hNi8nXP0YzmEM{O}OjTL=XFE`?-k{cikBv*4jgf^>i%F1? zk(GfZ`?F{4vBFug6<%MmmXJ+)mEgQoslXV6!5_H^yana6V1?kw1w zbE0Bi&GHlLHzfosgjrwL)qKcc5I;l0LF?W2lpQg%-cHp$l(#o)?Jkath1@gQN{hG8 zig@wKv#0R7vd_QC=JG%%Ffy=4=$RT=0v*d`(8R=M(8RcU0W%XL6BCP-)w&Y~JZv0V zZ64=rS(uqv84M~6g$xAPm_u3EggJBalM{0?@{3DgVjNh+*s+LlVG-lTBF2m)W*{fd zYiMC$VQ64zW@K(?5e4L0B5?=MWswHLZ0z7LVq$~_7BeF|vl9agPmO-znfkD()@R+bGrT$(AXmN24#zv*XRX z#$UNu(Lmln78u;Jd@N!tBKmU@J0!OJc3G%!N>OO@P1n+F-CorAVRmOQaA8six!iWP z)M3lXpnJ*TI=kIlH(Yxia-ls?xvctEx&P5B6()tKm`>%bo~@fX9{j%TtMU1G!|pw& zZ6BRjIqQ^`bIxR@OmMno&8^H%tpq36Esh&T(+MJ_la+#pV>+3scS% rSocketServerCustomizers = new ArrayList<>(); + private Ssl ssl; + + private SslStoreProvider sslStoreProvider; + @Override public void setPort(int port) { this.port = port; @@ -76,6 +84,16 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur this.transport = transport; } + @Override + public void setSsl(Ssl ssl) { + this.ssl = ssl; + } + + @Override + public void setSslStoreProvider(SslStoreProvider sslStoreProvider) { + this.sslStoreProvider = sslStoreProvider; + } + /** * Set the {@link ReactorResourceFactory} to get the shared resources from. * @param resourceFactory the server resources @@ -133,21 +151,41 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur } private ServerTransport createWebSocketTransport() { + HttpServer httpServer; if (this.resourceFactory != null) { - HttpServer httpServer = HttpServer.create().runOn(this.resourceFactory.getLoopResources()) + httpServer = HttpServer.create().runOn(this.resourceFactory.getLoopResources()) .bindAddress(this::getListenAddress); - return WebsocketServerTransport.create(httpServer); } - return WebsocketServerTransport.create(getListenAddress()); + else { + InetSocketAddress listenAddress = this.getListenAddress(); + httpServer = HttpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort()); + } + + if (this.ssl != null && this.ssl.isEnabled()) { + SslServerCustomizer sslServerCustomizer = new SslServerCustomizer(this.ssl, null, this.sslStoreProvider); + httpServer = sslServerCustomizer.apply(httpServer); + } + + return WebsocketServerTransport.create(httpServer); } private ServerTransport createTcpTransport() { + TcpServer tcpServer; if (this.resourceFactory != null) { - TcpServer tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources()) + tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources()) .bindAddress(this::getListenAddress); - return TcpServerTransport.create(tcpServer); } - return TcpServerTransport.create(getListenAddress()); + else { + InetSocketAddress listenAddress = this.getListenAddress(); + tcpServer = TcpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort()); + } + + if (this.ssl != null && this.ssl.isEnabled()) { + TcpSslServerCustomizer sslServerCustomizer = new TcpSslServerCustomizer(this.ssl, this.sslStoreProvider); + tcpServer = sslServerCustomizer.apply(tcpServer); + } + + return TcpServerTransport.create(tcpServer); } private InetSocketAddress getListenAddress() { @@ -157,4 +195,24 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur return new InetSocketAddress(this.port); } + private static final class TcpSslServerCustomizer extends SslServerCustomizer { + + private TcpSslServerCustomizer(Ssl ssl, SslStoreProvider sslStoreProvider) { + super(ssl, null, sslStoreProvider); + } + + // This does not override the apply in parent - currently just leveraging the + // parent for its "getContextBuilder()" method. This should be refactored when + // we add the concept of http/tcp customizers for RSocket. + private TcpServer apply(TcpServer server) { + try { + return server.secure((contextSpec) -> contextSpec.sslContext(getContextBuilder())); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + } + + } + } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java index cb974105429..afbf549ba2d 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java @@ -18,6 +18,9 @@ package org.springframework.boot.rsocket.server; import java.net.InetAddress; +import org.springframework.boot.web.server.Ssl; +import org.springframework.boot.web.server.SslStoreProvider; + /** * A configurable {@link RSocketServerFactory}. * @@ -45,4 +48,16 @@ public interface ConfigurableRSocketServerFactory { */ void setTransport(RSocketServer.Transport transport); + /** + * Sets the SSL configuration that will be applied to the server's default connector. + * @param ssl the SSL configuration + */ + void setSsl(Ssl ssl); + + /** + * Sets a provider that will be used to obtain SSL stores. + * @param sslStoreProvider the SSL store provider + */ + void setSslStoreProvider(SslStoreProvider sslStoreProvider); + } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java index d116d9c6f60..d329c6dbd47 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java @@ -17,15 +17,20 @@ package org.springframework.boot.rsocket.netty; import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; import java.time.Duration; import java.util.Arrays; import java.util.concurrent.Callable; import io.netty.buffer.PooledByteBufAllocator; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslProvider; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.SocketAcceptor; +import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.client.WebsocketClientTransport; import io.rsocket.util.DefaultPayload; import org.assertj.core.api.Assertions; @@ -33,9 +38,13 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.mockito.InOrder; import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpClient; +import reactor.test.StepVerifier; import org.springframework.boot.rsocket.server.RSocketServer; import org.springframework.boot.rsocket.server.RSocketServerCustomizer; +import org.springframework.boot.rsocket.server.RSocketServer.Transport; +import org.springframework.boot.web.server.Ssl; import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.core.codec.StringDecoder; import org.springframework.core.io.buffer.NettyDataBufferFactory; @@ -45,6 +54,8 @@ import org.springframework.messaging.rsocket.RSocketStrategies; import org.springframework.util.SocketUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.will; import static org.mockito.Mockito.inOrder; @@ -55,6 +66,7 @@ import static org.mockito.Mockito.mock; * * @author Brian Clozel * @author Leo Li + * @author Chris Bono */ class NettyRSocketServerFactoryTests { @@ -93,7 +105,7 @@ class NettyRSocketServerFactoryTests { this.server.start(); return port; }); - this.requester = createRSocketTcpClient(); + this.requester = createRSocketTcpClient(false); String payload = "test payload"; String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); assertThat(this.server.address().getPort()).isEqualTo(specificPort); @@ -106,7 +118,7 @@ class NettyRSocketServerFactoryTests { factory.setTransport(RSocketServer.Transport.WEBSOCKET); this.server = factory.create(new EchoRequestResponseAcceptor()); this.server.start(); - this.requester = createRSocketWebSocketClient(); + this.requester = createRSocketWebSocketClient(false); String payload = "test payload"; String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); assertThat(response).isEqualTo(payload); @@ -121,7 +133,7 @@ class NettyRSocketServerFactoryTests { factory.setResourceFactory(resourceFactory); this.server = factory.create(new EchoRequestResponseAcceptor()); this.server.start(); - this.requester = createRSocketWebSocketClient(); + this.requester = createRSocketWebSocketClient(false); String payload = "test payload"; String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); assertThat(response).isEqualTo(payload); @@ -144,16 +156,94 @@ class NettyRSocketServerFactoryTests { } } - private RSocketRequester createRSocketTcpClient() { - Assertions.assertThat(this.server).isNotNull(); - InetSocketAddress address = this.server.address(); - return createRSocketRequesterBuilder().tcp(address.getHostString(), address.getPort()); + @Test + void tcpTransportBasicSslFromClassPath() { + testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.TCP); } - private RSocketRequester createRSocketWebSocketClient() { + @Test + void tcpTransportBasicSslFromFileSystem() { + testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.TCP); + } + + @Test + void websocketTransportBasicSslFromClassPath() { + testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.WEBSOCKET); + } + + @Test + void websocketTransportBasicSslFromFileSystem() { + testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.WEBSOCKET); + } + + private void testBasicSslWithKeyStore(String keyStore, String keyPassword, Transport transport) { + NettyRSocketServerFactory factory = getFactory(); + factory.setTransport(transport); + Ssl ssl = new Ssl(); + ssl.setKeyStore(keyStore); + ssl.setKeyPassword(keyPassword); + factory.setSsl(ssl); + this.server = factory.create(new EchoRequestResponseAcceptor()); + this.server.start(); + this.requester = (transport == Transport.TCP) ? createRSocketTcpClient(true) + : createRSocketWebSocketClient(true); + String payload = "test payload"; + Mono responseMono = this.requester.route("test").data(payload).retrieveMono(String.class); + StepVerifier.create(responseMono).expectNext(payload).verifyComplete(); + } + + @Test + void tcpTransportSslRejectsInsecureClient() { + NettyRSocketServerFactory factory = getFactory(); + factory.setTransport(Transport.TCP); + Ssl ssl = new Ssl(); + ssl.setKeyStore("classpath:test.jks"); + ssl.setKeyPassword("password"); + factory.setSsl(ssl); + this.server = factory.create(new EchoRequestResponseAcceptor()); + this.server.start(); + this.requester = createRSocketTcpClient(false); + String payload = "test payload"; + Mono responseMono = this.requester.route("test").data(payload).retrieveMono(String.class); + StepVerifier.create(responseMono) + .verifyErrorSatisfies((ex) -> assertThatExceptionOfType(ClosedChannelException.class)); + } + + @Test + void websocketTransportSslRejectsInsecureClient() { + NettyRSocketServerFactory factory = getFactory(); + factory.setTransport(Transport.WEBSOCKET); + Ssl ssl = new Ssl(); + ssl.setKeyStore("classpath:test.jks"); + ssl.setKeyPassword("password"); + factory.setSsl(ssl); + this.server = factory.create(new EchoRequestResponseAcceptor()); + this.server.start(); + // For WebSocket, the SSL failure results in a hang on the initial connect call + assertThatThrownBy(() -> createRSocketWebSocketClient(false)).isInstanceOf(IllegalStateException.class) + .hasStackTraceContaining("Timeout on blocking read"); + } + + private RSocketRequester createRSocketTcpClient(boolean ssl) { + TcpClient tcpClient = createTcpClient(ssl); + return createRSocketRequesterBuilder().connect(TcpClientTransport.create(tcpClient)).block(TIMEOUT); + } + + private RSocketRequester createRSocketWebSocketClient(boolean ssl) { + TcpClient tcpClient = createTcpClient(ssl); + return createRSocketRequesterBuilder().connect(WebsocketClientTransport.create(tcpClient)).block(TIMEOUT); + } + + private TcpClient createTcpClient(boolean ssl) { Assertions.assertThat(this.server).isNotNull(); InetSocketAddress address = this.server.address(); - return createRSocketRequesterBuilder().transport(WebsocketClientTransport.create(address)); + TcpClient tcpClient = TcpClient.create().host(address.getHostName()).port(address.getPort()); + if (ssl) { + SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK) + .trustManager(InsecureTrustManagerFactory.INSTANCE); + tcpClient = tcpClient.secure((spec) -> spec.sslContext(builder)); + } + return tcpClient; } private RSocketRequester.Builder createRSocketRequesterBuilder() { From 0715750eb30acb6935f218ac104d44772a4a71f3 Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Wed, 18 Dec 2019 22:11:49 -0600 Subject: [PATCH 2/2] Polish "Add SSL support to RSocketServer" See gh-19399 --- .../rsocket/RSocketProperties.java | 2 +- .../netty/NettyRSocketServerFactory.java | 29 +---- .../ConfigurableRSocketServerFactory.java | 2 +- .../netty/NettyRSocketServerFactoryTests.java | 107 +++++++++--------- 4 files changed, 62 insertions(+), 78 deletions(-) diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java index cf18235260e..71e5ce4785f 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2019 the original author or authors. + * Copyright 2012-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java index 5859fd50fb3..278d066e416 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java @@ -151,41 +151,27 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur } private ServerTransport createWebSocketTransport() { - HttpServer httpServer; + HttpServer httpServer = HttpServer.create(); if (this.resourceFactory != null) { - httpServer = HttpServer.create().runOn(this.resourceFactory.getLoopResources()) - .bindAddress(this::getListenAddress); + httpServer = httpServer.runOn(this.resourceFactory.getLoopResources()); } - else { - InetSocketAddress listenAddress = this.getListenAddress(); - httpServer = HttpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort()); - } - if (this.ssl != null && this.ssl.isEnabled()) { SslServerCustomizer sslServerCustomizer = new SslServerCustomizer(this.ssl, null, this.sslStoreProvider); httpServer = sslServerCustomizer.apply(httpServer); } - - return WebsocketServerTransport.create(httpServer); + return WebsocketServerTransport.create(httpServer.bindAddress(this::getListenAddress)); } private ServerTransport createTcpTransport() { - TcpServer tcpServer; + TcpServer tcpServer = TcpServer.create(); if (this.resourceFactory != null) { - tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources()) - .bindAddress(this::getListenAddress); + tcpServer = tcpServer.runOn(this.resourceFactory.getLoopResources()); } - else { - InetSocketAddress listenAddress = this.getListenAddress(); - tcpServer = TcpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort()); - } - if (this.ssl != null && this.ssl.isEnabled()) { TcpSslServerCustomizer sslServerCustomizer = new TcpSslServerCustomizer(this.ssl, this.sslStoreProvider); tcpServer = sslServerCustomizer.apply(tcpServer); } - - return TcpServerTransport.create(tcpServer); + return TcpServerTransport.create(tcpServer.bindAddress(this::getListenAddress)); } private InetSocketAddress getListenAddress() { @@ -201,9 +187,6 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur super(ssl, null, sslStoreProvider); } - // This does not override the apply in parent - currently just leveraging the - // parent for its "getContextBuilder()" method. This should be refactored when - // we add the concept of http/tcp customizers for RSocket. private TcpServer apply(TcpServer server) { try { return server.secure((contextSpec) -> contextSpec.sslContext(getContextBuilder())); diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java index afbf549ba2d..40de9b3f7f0 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2019 the original author or authors. + * Copyright 2012-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java index d329c6dbd47..30cc5bfb67f 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java @@ -18,7 +18,6 @@ package org.springframework.boot.rsocket.netty; import java.net.InetSocketAddress; import java.nio.channels.ClosedChannelException; -import java.time.Duration; import java.util.Arrays; import java.util.concurrent.Callable; @@ -38,12 +37,13 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.mockito.InOrder; import reactor.core.publisher.Mono; +import reactor.netty.http.client.HttpClient; import reactor.netty.tcp.TcpClient; import reactor.test.StepVerifier; import org.springframework.boot.rsocket.server.RSocketServer; -import org.springframework.boot.rsocket.server.RSocketServerCustomizer; import org.springframework.boot.rsocket.server.RSocketServer.Transport; +import org.springframework.boot.rsocket.server.RSocketServerCustomizer; import org.springframework.boot.web.server.Ssl; import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.core.codec.StringDecoder; @@ -55,7 +55,6 @@ import org.springframework.util.SocketUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.will; import static org.mockito.Mockito.inOrder; @@ -74,10 +73,11 @@ class NettyRSocketServerFactoryTests { private RSocketRequester requester; - private static final Duration TIMEOUT = Duration.ofSeconds(3); - @AfterEach void tearDown() { + if (this.requester != null) { + this.requester.rsocketClient().dispose(); + } if (this.server != null) { try { this.server.stop(); @@ -86,9 +86,6 @@ class NettyRSocketServerFactoryTests { // Ignore } } - if (this.requester != null) { - this.requester.rsocketClient().dispose(); - } } private NettyRSocketServerFactory getFactory() { @@ -105,11 +102,9 @@ class NettyRSocketServerFactoryTests { this.server.start(); return port; }); - this.requester = createRSocketTcpClient(false); - String payload = "test payload"; - String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); + this.requester = createRSocketTcpClient(); assertThat(this.server.address().getPort()).isEqualTo(specificPort); - assertThat(response).isEqualTo(payload); + checkEchoRequest(); } @Test @@ -118,10 +113,8 @@ class NettyRSocketServerFactoryTests { factory.setTransport(RSocketServer.Transport.WEBSOCKET); this.server = factory.create(new EchoRequestResponseAcceptor()); this.server.start(); - this.requester = createRSocketWebSocketClient(false); - String payload = "test payload"; - String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); - assertThat(response).isEqualTo(payload); + this.requester = createRSocketWebSocketClient(); + checkEchoRequest(); } @Test @@ -133,10 +126,8 @@ class NettyRSocketServerFactoryTests { factory.setResourceFactory(resourceFactory); this.server = factory.create(new EchoRequestResponseAcceptor()); this.server.start(); - this.requester = createRSocketWebSocketClient(false); - String payload = "test payload"; - String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); - assertThat(response).isEqualTo(payload); + this.requester = createRSocketWebSocketClient(); + checkEchoRequest(); } @Test @@ -176,6 +167,12 @@ class NettyRSocketServerFactoryTests { testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.WEBSOCKET); } + private void checkEchoRequest() { + String payload = "test payload"; + Mono response = this.requester.route("test").data(payload).retrieveMono(String.class); + StepVerifier.create(response).expectNext(payload).verifyComplete(); + } + private void testBasicSslWithKeyStore(String keyStore, String keyPassword, Transport transport) { NettyRSocketServerFactory factory = getFactory(); factory.setTransport(transport); @@ -185,11 +182,9 @@ class NettyRSocketServerFactoryTests { factory.setSsl(ssl); this.server = factory.create(new EchoRequestResponseAcceptor()); this.server.start(); - this.requester = (transport == Transport.TCP) ? createRSocketTcpClient(true) - : createRSocketWebSocketClient(true); - String payload = "test payload"; - Mono responseMono = this.requester.route("test").data(payload).retrieveMono(String.class); - StepVerifier.create(responseMono).expectNext(payload).verifyComplete(); + this.requester = (transport == Transport.TCP) ? createSecureRSocketTcpClient() + : createSecureRSocketWebSocketClient(); + checkEchoRequest(); } @Test @@ -202,48 +197,54 @@ class NettyRSocketServerFactoryTests { factory.setSsl(ssl); this.server = factory.create(new EchoRequestResponseAcceptor()); this.server.start(); - this.requester = createRSocketTcpClient(false); + this.requester = createRSocketTcpClient(); String payload = "test payload"; Mono responseMono = this.requester.route("test").data(payload).retrieveMono(String.class); StepVerifier.create(responseMono) .verifyErrorSatisfies((ex) -> assertThatExceptionOfType(ClosedChannelException.class)); } - @Test - void websocketTransportSslRejectsInsecureClient() { - NettyRSocketServerFactory factory = getFactory(); - factory.setTransport(Transport.WEBSOCKET); - Ssl ssl = new Ssl(); - ssl.setKeyStore("classpath:test.jks"); - ssl.setKeyPassword("password"); - factory.setSsl(ssl); - this.server = factory.create(new EchoRequestResponseAcceptor()); - this.server.start(); - // For WebSocket, the SSL failure results in a hang on the initial connect call - assertThatThrownBy(() -> createRSocketWebSocketClient(false)).isInstanceOf(IllegalStateException.class) - .hasStackTraceContaining("Timeout on blocking read"); + private RSocketRequester createRSocketTcpClient() { + return createRSocketRequesterBuilder().transport(TcpClientTransport.create(createTcpClient())); } - private RSocketRequester createRSocketTcpClient(boolean ssl) { - TcpClient tcpClient = createTcpClient(ssl); - return createRSocketRequesterBuilder().connect(TcpClientTransport.create(tcpClient)).block(TIMEOUT); + private RSocketRequester createRSocketWebSocketClient() { + return createRSocketRequesterBuilder().transport(WebsocketClientTransport.create(createHttpClient(), "/")); } - private RSocketRequester createRSocketWebSocketClient(boolean ssl) { - TcpClient tcpClient = createTcpClient(ssl); - return createRSocketRequesterBuilder().connect(WebsocketClientTransport.create(tcpClient)).block(TIMEOUT); + private RSocketRequester createSecureRSocketTcpClient() { + return createRSocketRequesterBuilder().transport(TcpClientTransport.create(createSecureTcpClient())); } - private TcpClient createTcpClient(boolean ssl) { + private RSocketRequester createSecureRSocketWebSocketClient() { + return createRSocketRequesterBuilder() + .transport(WebsocketClientTransport.create(createSecureHttpClient(), "/")); + } + + private HttpClient createSecureHttpClient() { + HttpClient httpClient = createHttpClient(); + SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK) + .trustManager(InsecureTrustManagerFactory.INSTANCE); + return httpClient.secure((spec) -> spec.sslContext(builder)); + } + + private HttpClient createHttpClient() { Assertions.assertThat(this.server).isNotNull(); InetSocketAddress address = this.server.address(); - TcpClient tcpClient = TcpClient.create().host(address.getHostName()).port(address.getPort()); - if (ssl) { - SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK) - .trustManager(InsecureTrustManagerFactory.INSTANCE); - tcpClient = tcpClient.secure((spec) -> spec.sslContext(builder)); - } - return tcpClient; + return HttpClient.create().host(address.getHostName()).port(address.getPort()); + } + + private TcpClient createSecureTcpClient() { + TcpClient tcpClient = createTcpClient(); + SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK) + .trustManager(InsecureTrustManagerFactory.INSTANCE); + return tcpClient.secure((spec) -> spec.sslContext(builder)); + } + + private TcpClient createTcpClient() { + Assertions.assertThat(this.server).isNotNull(); + InetSocketAddress address = this.server.address(); + return TcpClient.create().host(address.getHostName()).port(address.getPort()); } private RSocketRequester.Builder createRSocketRequesterBuilder() {