diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java index 232eb1db427..71e5ce4785f 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2019 the original author or authors. + * Copyright 2012-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,12 +19,15 @@ package org.springframework.boot.autoconfigure.rsocket; import java.net.InetAddress; import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.boot.rsocket.server.RSocketServer; +import org.springframework.boot.web.server.Ssl; /** * {@link ConfigurationProperties properties} for RSocket support. * * @author Brian Clozel + * @author Chris Bono * @since 2.2.0 */ @ConfigurationProperties("spring.rsocket") @@ -59,6 +62,9 @@ public class RSocketProperties { */ private String mappingPath; + @NestedConfigurationProperty + private Ssl ssl; + public Integer getPort() { return this.port; } @@ -91,6 +97,14 @@ public class RSocketProperties { this.mappingPath = mappingPath; } + public Ssl getSsl() { + return this.ssl; + } + + public void setSsl(Ssl ssl) { + this.ssl = ssl; + } + } } diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfiguration.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfiguration.java index 7b13f7beabf..80902e1de3d 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfiguration.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfiguration.java @@ -97,6 +97,7 @@ public class RSocketServerAutoConfiguration { PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull(); map.from(properties.getServer().getAddress()).to(factory::setAddress); map.from(properties.getServer().getPort()).to(factory::setPort); + map.from(properties.getServer().getSsl()).to(factory::setSsl); factory.setRSocketServerCustomizers(customizers.orderedStream().collect(Collectors.toList())); return factory; } diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfigurationTests.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfigurationTests.java index 2134011e8a1..ae354082305 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfigurationTests.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfigurationTests.java @@ -91,6 +91,18 @@ class RSocketServerAutoConfigurationTests { }); } + @Test + void shouldUseSslWhenRocketServerSslIsConfigured() { + reactiveWebContextRunner() + .withPropertyValues("spring.rsocket.server.ssl.keyStore=classpath:rsocket/test.jks", + "spring.rsocket.server.ssl.keyPassword=password", "spring.rsocket.server.port=0") + .run((context) -> assertThat(context).hasSingleBean(RSocketServerFactory.class) + .hasSingleBean(RSocketServerBootstrap.class).hasSingleBean(RSocketServerCustomizer.class) + .getBean(RSocketServerFactory.class) + .hasFieldOrPropertyWithValue("ssl.keyStore", "classpath:rsocket/test.jks") + .hasFieldOrPropertyWithValue("ssl.keyPassword", "password")); + } + @Test void shouldUseCustomServerBootstrap() { contextRunner().withUserConfiguration(CustomServerBootstrapConfig.class).run((context) -> assertThat(context) diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/resources/rsocket/test.jks b/spring-boot-project/spring-boot-autoconfigure/src/test/resources/rsocket/test.jks new file mode 100644 index 00000000000..0fc3e802f75 Binary files /dev/null and b/spring-boot-project/spring-boot-autoconfigure/src/test/resources/rsocket/test.jks differ diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java index 7668e629932..278d066e416 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java @@ -37,6 +37,9 @@ import org.springframework.boot.rsocket.server.ConfigurableRSocketServerFactory; import org.springframework.boot.rsocket.server.RSocketServer; import org.springframework.boot.rsocket.server.RSocketServerCustomizer; import org.springframework.boot.rsocket.server.RSocketServerFactory; +import org.springframework.boot.web.embedded.netty.SslServerCustomizer; +import org.springframework.boot.web.server.Ssl; +import org.springframework.boot.web.server.SslStoreProvider; import org.springframework.http.client.reactive.ReactorResourceFactory; import org.springframework.util.Assert; @@ -45,6 +48,7 @@ import org.springframework.util.Assert; * by Netty. * * @author Brian Clozel + * @author Chris Bono * @since 2.2.0 */ public class NettyRSocketServerFactory implements RSocketServerFactory, ConfigurableRSocketServerFactory { @@ -61,6 +65,10 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur private List rSocketServerCustomizers = new ArrayList<>(); + private Ssl ssl; + + private SslStoreProvider sslStoreProvider; + @Override public void setPort(int port) { this.port = port; @@ -76,6 +84,16 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur this.transport = transport; } + @Override + public void setSsl(Ssl ssl) { + this.ssl = ssl; + } + + @Override + public void setSslStoreProvider(SslStoreProvider sslStoreProvider) { + this.sslStoreProvider = sslStoreProvider; + } + /** * Set the {@link ReactorResourceFactory} to get the shared resources from. * @param resourceFactory the server resources @@ -133,21 +151,27 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur } private ServerTransport createWebSocketTransport() { + HttpServer httpServer = HttpServer.create(); if (this.resourceFactory != null) { - HttpServer httpServer = HttpServer.create().runOn(this.resourceFactory.getLoopResources()) - .bindAddress(this::getListenAddress); - return WebsocketServerTransport.create(httpServer); + httpServer = httpServer.runOn(this.resourceFactory.getLoopResources()); } - return WebsocketServerTransport.create(getListenAddress()); + if (this.ssl != null && this.ssl.isEnabled()) { + SslServerCustomizer sslServerCustomizer = new SslServerCustomizer(this.ssl, null, this.sslStoreProvider); + httpServer = sslServerCustomizer.apply(httpServer); + } + return WebsocketServerTransport.create(httpServer.bindAddress(this::getListenAddress)); } private ServerTransport createTcpTransport() { + TcpServer tcpServer = TcpServer.create(); if (this.resourceFactory != null) { - TcpServer tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources()) - .bindAddress(this::getListenAddress); - return TcpServerTransport.create(tcpServer); + tcpServer = tcpServer.runOn(this.resourceFactory.getLoopResources()); } - return TcpServerTransport.create(getListenAddress()); + if (this.ssl != null && this.ssl.isEnabled()) { + TcpSslServerCustomizer sslServerCustomizer = new TcpSslServerCustomizer(this.ssl, this.sslStoreProvider); + tcpServer = sslServerCustomizer.apply(tcpServer); + } + return TcpServerTransport.create(tcpServer.bindAddress(this::getListenAddress)); } private InetSocketAddress getListenAddress() { @@ -157,4 +181,21 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur return new InetSocketAddress(this.port); } + private static final class TcpSslServerCustomizer extends SslServerCustomizer { + + private TcpSslServerCustomizer(Ssl ssl, SslStoreProvider sslStoreProvider) { + super(ssl, null, sslStoreProvider); + } + + private TcpServer apply(TcpServer server) { + try { + return server.secure((contextSpec) -> contextSpec.sslContext(getContextBuilder())); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + } + + } + } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java index cb974105429..40de9b3f7f0 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2019 the original author or authors. + * Copyright 2012-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,9 @@ package org.springframework.boot.rsocket.server; import java.net.InetAddress; +import org.springframework.boot.web.server.Ssl; +import org.springframework.boot.web.server.SslStoreProvider; + /** * A configurable {@link RSocketServerFactory}. * @@ -45,4 +48,16 @@ public interface ConfigurableRSocketServerFactory { */ void setTransport(RSocketServer.Transport transport); + /** + * Sets the SSL configuration that will be applied to the server's default connector. + * @param ssl the SSL configuration + */ + void setSsl(Ssl ssl); + + /** + * Sets a provider that will be used to obtain SSL stores. + * @param sslStoreProvider the SSL store provider + */ + void setSslStoreProvider(SslStoreProvider sslStoreProvider); + } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java index d116d9c6f60..30cc5bfb67f 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java @@ -17,15 +17,19 @@ package org.springframework.boot.rsocket.netty; import java.net.InetSocketAddress; -import java.time.Duration; +import java.nio.channels.ClosedChannelException; import java.util.Arrays; import java.util.concurrent.Callable; import io.netty.buffer.PooledByteBufAllocator; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslProvider; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.SocketAcceptor; +import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.client.WebsocketClientTransport; import io.rsocket.util.DefaultPayload; import org.assertj.core.api.Assertions; @@ -33,9 +37,14 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.mockito.InOrder; import reactor.core.publisher.Mono; +import reactor.netty.http.client.HttpClient; +import reactor.netty.tcp.TcpClient; +import reactor.test.StepVerifier; import org.springframework.boot.rsocket.server.RSocketServer; +import org.springframework.boot.rsocket.server.RSocketServer.Transport; import org.springframework.boot.rsocket.server.RSocketServerCustomizer; +import org.springframework.boot.web.server.Ssl; import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.core.codec.StringDecoder; import org.springframework.core.io.buffer.NettyDataBufferFactory; @@ -45,6 +54,7 @@ import org.springframework.messaging.rsocket.RSocketStrategies; import org.springframework.util.SocketUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.will; import static org.mockito.Mockito.inOrder; @@ -55,6 +65,7 @@ import static org.mockito.Mockito.mock; * * @author Brian Clozel * @author Leo Li + * @author Chris Bono */ class NettyRSocketServerFactoryTests { @@ -62,10 +73,11 @@ class NettyRSocketServerFactoryTests { private RSocketRequester requester; - private static final Duration TIMEOUT = Duration.ofSeconds(3); - @AfterEach void tearDown() { + if (this.requester != null) { + this.requester.rsocketClient().dispose(); + } if (this.server != null) { try { this.server.stop(); @@ -74,9 +86,6 @@ class NettyRSocketServerFactoryTests { // Ignore } } - if (this.requester != null) { - this.requester.rsocketClient().dispose(); - } } private NettyRSocketServerFactory getFactory() { @@ -94,10 +103,8 @@ class NettyRSocketServerFactoryTests { return port; }); this.requester = createRSocketTcpClient(); - String payload = "test payload"; - String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); assertThat(this.server.address().getPort()).isEqualTo(specificPort); - assertThat(response).isEqualTo(payload); + checkEchoRequest(); } @Test @@ -107,9 +114,7 @@ class NettyRSocketServerFactoryTests { this.server = factory.create(new EchoRequestResponseAcceptor()); this.server.start(); this.requester = createRSocketWebSocketClient(); - String payload = "test payload"; - String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); - assertThat(response).isEqualTo(payload); + checkEchoRequest(); } @Test @@ -122,9 +127,7 @@ class NettyRSocketServerFactoryTests { this.server = factory.create(new EchoRequestResponseAcceptor()); this.server.start(); this.requester = createRSocketWebSocketClient(); - String payload = "test payload"; - String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); - assertThat(response).isEqualTo(payload); + checkEchoRequest(); } @Test @@ -144,16 +147,104 @@ class NettyRSocketServerFactoryTests { } } + @Test + void tcpTransportBasicSslFromClassPath() { + testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.TCP); + } + + @Test + void tcpTransportBasicSslFromFileSystem() { + testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.TCP); + } + + @Test + void websocketTransportBasicSslFromClassPath() { + testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.WEBSOCKET); + } + + @Test + void websocketTransportBasicSslFromFileSystem() { + testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.WEBSOCKET); + } + + private void checkEchoRequest() { + String payload = "test payload"; + Mono response = this.requester.route("test").data(payload).retrieveMono(String.class); + StepVerifier.create(response).expectNext(payload).verifyComplete(); + } + + private void testBasicSslWithKeyStore(String keyStore, String keyPassword, Transport transport) { + NettyRSocketServerFactory factory = getFactory(); + factory.setTransport(transport); + Ssl ssl = new Ssl(); + ssl.setKeyStore(keyStore); + ssl.setKeyPassword(keyPassword); + factory.setSsl(ssl); + this.server = factory.create(new EchoRequestResponseAcceptor()); + this.server.start(); + this.requester = (transport == Transport.TCP) ? createSecureRSocketTcpClient() + : createSecureRSocketWebSocketClient(); + checkEchoRequest(); + } + + @Test + void tcpTransportSslRejectsInsecureClient() { + NettyRSocketServerFactory factory = getFactory(); + factory.setTransport(Transport.TCP); + Ssl ssl = new Ssl(); + ssl.setKeyStore("classpath:test.jks"); + ssl.setKeyPassword("password"); + factory.setSsl(ssl); + this.server = factory.create(new EchoRequestResponseAcceptor()); + this.server.start(); + this.requester = createRSocketTcpClient(); + String payload = "test payload"; + Mono responseMono = this.requester.route("test").data(payload).retrieveMono(String.class); + StepVerifier.create(responseMono) + .verifyErrorSatisfies((ex) -> assertThatExceptionOfType(ClosedChannelException.class)); + } + private RSocketRequester createRSocketTcpClient() { - Assertions.assertThat(this.server).isNotNull(); - InetSocketAddress address = this.server.address(); - return createRSocketRequesterBuilder().tcp(address.getHostString(), address.getPort()); + return createRSocketRequesterBuilder().transport(TcpClientTransport.create(createTcpClient())); } private RSocketRequester createRSocketWebSocketClient() { + return createRSocketRequesterBuilder().transport(WebsocketClientTransport.create(createHttpClient(), "/")); + } + + private RSocketRequester createSecureRSocketTcpClient() { + return createRSocketRequesterBuilder().transport(TcpClientTransport.create(createSecureTcpClient())); + } + + private RSocketRequester createSecureRSocketWebSocketClient() { + return createRSocketRequesterBuilder() + .transport(WebsocketClientTransport.create(createSecureHttpClient(), "/")); + } + + private HttpClient createSecureHttpClient() { + HttpClient httpClient = createHttpClient(); + SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK) + .trustManager(InsecureTrustManagerFactory.INSTANCE); + return httpClient.secure((spec) -> spec.sslContext(builder)); + } + + private HttpClient createHttpClient() { Assertions.assertThat(this.server).isNotNull(); InetSocketAddress address = this.server.address(); - return createRSocketRequesterBuilder().transport(WebsocketClientTransport.create(address)); + return HttpClient.create().host(address.getHostName()).port(address.getPort()); + } + + private TcpClient createSecureTcpClient() { + TcpClient tcpClient = createTcpClient(); + SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK) + .trustManager(InsecureTrustManagerFactory.INSTANCE); + return tcpClient.secure((spec) -> spec.sslContext(builder)); + } + + private TcpClient createTcpClient() { + Assertions.assertThat(this.server).isNotNull(); + InetSocketAddress address = this.server.address(); + return TcpClient.create().host(address.getHostName()).port(address.getPort()); } private RSocketRequester.Builder createRSocketRequesterBuilder() {