Merge pull request #19399 from bono007

* gh-19205:
  Polish "Add SSL support to RSocketServer"
  Add SSL support to RSocketServer

Closes gh-19399
This commit is contained in:
Brian Clozel 2020-09-10 15:26:51 +02:00
commit 53607ea777
7 changed files with 203 additions and 29 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2012-2019 the original author or authors.
* Copyright 2012-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,12 +19,15 @@ package org.springframework.boot.autoconfigure.rsocket;
import java.net.InetAddress;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import org.springframework.boot.rsocket.server.RSocketServer;
import org.springframework.boot.web.server.Ssl;
/**
* {@link ConfigurationProperties properties} for RSocket support.
*
* @author Brian Clozel
* @author Chris Bono
* @since 2.2.0
*/
@ConfigurationProperties("spring.rsocket")
@ -59,6 +62,9 @@ public class RSocketProperties {
*/
private String mappingPath;
@NestedConfigurationProperty
private Ssl ssl;
public Integer getPort() {
return this.port;
}
@ -91,6 +97,14 @@ public class RSocketProperties {
this.mappingPath = mappingPath;
}
public Ssl getSsl() {
return this.ssl;
}
public void setSsl(Ssl ssl) {
this.ssl = ssl;
}
}
}

View File

@ -97,6 +97,7 @@ public class RSocketServerAutoConfiguration {
PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull();
map.from(properties.getServer().getAddress()).to(factory::setAddress);
map.from(properties.getServer().getPort()).to(factory::setPort);
map.from(properties.getServer().getSsl()).to(factory::setSsl);
factory.setRSocketServerCustomizers(customizers.orderedStream().collect(Collectors.toList()));
return factory;
}

View File

@ -91,6 +91,18 @@ class RSocketServerAutoConfigurationTests {
});
}
@Test
void shouldUseSslWhenRocketServerSslIsConfigured() {
reactiveWebContextRunner()
.withPropertyValues("spring.rsocket.server.ssl.keyStore=classpath:rsocket/test.jks",
"spring.rsocket.server.ssl.keyPassword=password", "spring.rsocket.server.port=0")
.run((context) -> assertThat(context).hasSingleBean(RSocketServerFactory.class)
.hasSingleBean(RSocketServerBootstrap.class).hasSingleBean(RSocketServerCustomizer.class)
.getBean(RSocketServerFactory.class)
.hasFieldOrPropertyWithValue("ssl.keyStore", "classpath:rsocket/test.jks")
.hasFieldOrPropertyWithValue("ssl.keyPassword", "password"));
}
@Test
void shouldUseCustomServerBootstrap() {
contextRunner().withUserConfiguration(CustomServerBootstrapConfig.class).run((context) -> assertThat(context)

View File

@ -37,6 +37,9 @@ import org.springframework.boot.rsocket.server.ConfigurableRSocketServerFactory;
import org.springframework.boot.rsocket.server.RSocketServer;
import org.springframework.boot.rsocket.server.RSocketServerCustomizer;
import org.springframework.boot.rsocket.server.RSocketServerFactory;
import org.springframework.boot.web.embedded.netty.SslServerCustomizer;
import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.SslStoreProvider;
import org.springframework.http.client.reactive.ReactorResourceFactory;
import org.springframework.util.Assert;
@ -45,6 +48,7 @@ import org.springframework.util.Assert;
* by Netty.
*
* @author Brian Clozel
* @author Chris Bono
* @since 2.2.0
*/
public class NettyRSocketServerFactory implements RSocketServerFactory, ConfigurableRSocketServerFactory {
@ -61,6 +65,10 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
private List<RSocketServerCustomizer> rSocketServerCustomizers = new ArrayList<>();
private Ssl ssl;
private SslStoreProvider sslStoreProvider;
@Override
public void setPort(int port) {
this.port = port;
@ -76,6 +84,16 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
this.transport = transport;
}
@Override
public void setSsl(Ssl ssl) {
this.ssl = ssl;
}
@Override
public void setSslStoreProvider(SslStoreProvider sslStoreProvider) {
this.sslStoreProvider = sslStoreProvider;
}
/**
* Set the {@link ReactorResourceFactory} to get the shared resources from.
* @param resourceFactory the server resources
@ -133,21 +151,27 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
}
private ServerTransport<CloseableChannel> createWebSocketTransport() {
HttpServer httpServer = HttpServer.create();
if (this.resourceFactory != null) {
HttpServer httpServer = HttpServer.create().runOn(this.resourceFactory.getLoopResources())
.bindAddress(this::getListenAddress);
return WebsocketServerTransport.create(httpServer);
httpServer = httpServer.runOn(this.resourceFactory.getLoopResources());
}
return WebsocketServerTransport.create(getListenAddress());
if (this.ssl != null && this.ssl.isEnabled()) {
SslServerCustomizer sslServerCustomizer = new SslServerCustomizer(this.ssl, null, this.sslStoreProvider);
httpServer = sslServerCustomizer.apply(httpServer);
}
return WebsocketServerTransport.create(httpServer.bindAddress(this::getListenAddress));
}
private ServerTransport<CloseableChannel> createTcpTransport() {
TcpServer tcpServer = TcpServer.create();
if (this.resourceFactory != null) {
TcpServer tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources())
.bindAddress(this::getListenAddress);
return TcpServerTransport.create(tcpServer);
tcpServer = tcpServer.runOn(this.resourceFactory.getLoopResources());
}
return TcpServerTransport.create(getListenAddress());
if (this.ssl != null && this.ssl.isEnabled()) {
TcpSslServerCustomizer sslServerCustomizer = new TcpSslServerCustomizer(this.ssl, this.sslStoreProvider);
tcpServer = sslServerCustomizer.apply(tcpServer);
}
return TcpServerTransport.create(tcpServer.bindAddress(this::getListenAddress));
}
private InetSocketAddress getListenAddress() {
@ -157,4 +181,21 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
return new InetSocketAddress(this.port);
}
private static final class TcpSslServerCustomizer extends SslServerCustomizer {
private TcpSslServerCustomizer(Ssl ssl, SslStoreProvider sslStoreProvider) {
super(ssl, null, sslStoreProvider);
}
private TcpServer apply(TcpServer server) {
try {
return server.secure((contextSpec) -> contextSpec.sslContext(getContextBuilder()));
}
catch (Exception ex) {
throw new IllegalStateException(ex);
}
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2012-2019 the original author or authors.
* Copyright 2012-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,6 +18,9 @@ package org.springframework.boot.rsocket.server;
import java.net.InetAddress;
import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.SslStoreProvider;
/**
* A configurable {@link RSocketServerFactory}.
*
@ -45,4 +48,16 @@ public interface ConfigurableRSocketServerFactory {
*/
void setTransport(RSocketServer.Transport transport);
/**
* Sets the SSL configuration that will be applied to the server's default connector.
* @param ssl the SSL configuration
*/
void setSsl(Ssl ssl);
/**
* Sets a provider that will be used to obtain SSL stores.
* @param sslStoreProvider the SSL store provider
*/
void setSslStoreProvider(SslStoreProvider sslStoreProvider);
}

View File

@ -17,15 +17,19 @@
package org.springframework.boot.rsocket.netty;
import java.net.InetSocketAddress;
import java.time.Duration;
import java.nio.channels.ClosedChannelException;
import java.util.Arrays;
import java.util.concurrent.Callable;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.rsocket.ConnectionSetupPayload;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.SocketAcceptor;
import io.rsocket.transport.netty.client.TcpClientTransport;
import io.rsocket.transport.netty.client.WebsocketClientTransport;
import io.rsocket.util.DefaultPayload;
import org.assertj.core.api.Assertions;
@ -33,9 +37,14 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.mockito.InOrder;
import reactor.core.publisher.Mono;
import reactor.netty.http.client.HttpClient;
import reactor.netty.tcp.TcpClient;
import reactor.test.StepVerifier;
import org.springframework.boot.rsocket.server.RSocketServer;
import org.springframework.boot.rsocket.server.RSocketServer.Transport;
import org.springframework.boot.rsocket.server.RSocketServerCustomizer;
import org.springframework.boot.web.server.Ssl;
import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.core.codec.StringDecoder;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
@ -45,6 +54,7 @@ import org.springframework.messaging.rsocket.RSocketStrategies;
import org.springframework.util.SocketUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.will;
import static org.mockito.Mockito.inOrder;
@ -55,6 +65,7 @@ import static org.mockito.Mockito.mock;
*
* @author Brian Clozel
* @author Leo Li
* @author Chris Bono
*/
class NettyRSocketServerFactoryTests {
@ -62,10 +73,11 @@ class NettyRSocketServerFactoryTests {
private RSocketRequester requester;
private static final Duration TIMEOUT = Duration.ofSeconds(3);
@AfterEach
void tearDown() {
if (this.requester != null) {
this.requester.rsocketClient().dispose();
}
if (this.server != null) {
try {
this.server.stop();
@ -74,9 +86,6 @@ class NettyRSocketServerFactoryTests {
// Ignore
}
}
if (this.requester != null) {
this.requester.rsocketClient().dispose();
}
}
private NettyRSocketServerFactory getFactory() {
@ -94,10 +103,8 @@ class NettyRSocketServerFactoryTests {
return port;
});
this.requester = createRSocketTcpClient();
String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(this.server.address().getPort()).isEqualTo(specificPort);
assertThat(response).isEqualTo(payload);
checkEchoRequest();
}
@Test
@ -107,9 +114,7 @@ class NettyRSocketServerFactoryTests {
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = createRSocketWebSocketClient();
String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(response).isEqualTo(payload);
checkEchoRequest();
}
@Test
@ -122,9 +127,7 @@ class NettyRSocketServerFactoryTests {
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = createRSocketWebSocketClient();
String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(response).isEqualTo(payload);
checkEchoRequest();
}
@Test
@ -144,16 +147,104 @@ class NettyRSocketServerFactoryTests {
}
}
@Test
void tcpTransportBasicSslFromClassPath() {
testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.TCP);
}
@Test
void tcpTransportBasicSslFromFileSystem() {
testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.TCP);
}
@Test
void websocketTransportBasicSslFromClassPath() {
testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.WEBSOCKET);
}
@Test
void websocketTransportBasicSslFromFileSystem() {
testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.WEBSOCKET);
}
private void checkEchoRequest() {
String payload = "test payload";
Mono<String> response = this.requester.route("test").data(payload).retrieveMono(String.class);
StepVerifier.create(response).expectNext(payload).verifyComplete();
}
private void testBasicSslWithKeyStore(String keyStore, String keyPassword, Transport transport) {
NettyRSocketServerFactory factory = getFactory();
factory.setTransport(transport);
Ssl ssl = new Ssl();
ssl.setKeyStore(keyStore);
ssl.setKeyPassword(keyPassword);
factory.setSsl(ssl);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = (transport == Transport.TCP) ? createSecureRSocketTcpClient()
: createSecureRSocketWebSocketClient();
checkEchoRequest();
}
@Test
void tcpTransportSslRejectsInsecureClient() {
NettyRSocketServerFactory factory = getFactory();
factory.setTransport(Transport.TCP);
Ssl ssl = new Ssl();
ssl.setKeyStore("classpath:test.jks");
ssl.setKeyPassword("password");
factory.setSsl(ssl);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = createRSocketTcpClient();
String payload = "test payload";
Mono<String> responseMono = this.requester.route("test").data(payload).retrieveMono(String.class);
StepVerifier.create(responseMono)
.verifyErrorSatisfies((ex) -> assertThatExceptionOfType(ClosedChannelException.class));
}
private RSocketRequester createRSocketTcpClient() {
Assertions.assertThat(this.server).isNotNull();
InetSocketAddress address = this.server.address();
return createRSocketRequesterBuilder().tcp(address.getHostString(), address.getPort());
return createRSocketRequesterBuilder().transport(TcpClientTransport.create(createTcpClient()));
}
private RSocketRequester createRSocketWebSocketClient() {
return createRSocketRequesterBuilder().transport(WebsocketClientTransport.create(createHttpClient(), "/"));
}
private RSocketRequester createSecureRSocketTcpClient() {
return createRSocketRequesterBuilder().transport(TcpClientTransport.create(createSecureTcpClient()));
}
private RSocketRequester createSecureRSocketWebSocketClient() {
return createRSocketRequesterBuilder()
.transport(WebsocketClientTransport.create(createSecureHttpClient(), "/"));
}
private HttpClient createSecureHttpClient() {
HttpClient httpClient = createHttpClient();
SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
.trustManager(InsecureTrustManagerFactory.INSTANCE);
return httpClient.secure((spec) -> spec.sslContext(builder));
}
private HttpClient createHttpClient() {
Assertions.assertThat(this.server).isNotNull();
InetSocketAddress address = this.server.address();
return createRSocketRequesterBuilder().transport(WebsocketClientTransport.create(address));
return HttpClient.create().host(address.getHostName()).port(address.getPort());
}
private TcpClient createSecureTcpClient() {
TcpClient tcpClient = createTcpClient();
SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
.trustManager(InsecureTrustManagerFactory.INSTANCE);
return tcpClient.secure((spec) -> spec.sslContext(builder));
}
private TcpClient createTcpClient() {
Assertions.assertThat(this.server).isNotNull();
InetSocketAddress address = this.server.address();
return TcpClient.create().host(address.getHostName()).port(address.getPort());
}
private RSocketRequester.Builder createRSocketRequesterBuilder() {