diff --git a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java index df42dcad9cf..6632387509b 100644 --- a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java +++ b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java @@ -18,16 +18,22 @@ package org.springframework.boot.test.web.client; import java.io.IOException; import java.net.URI; +import java.security.KeyManagementException; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.time.Duration; import java.util.Arrays; import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLContext; import org.apache.hc.client5.http.classic.HttpClient; import org.apache.hc.client5.http.config.RequestConfig; import org.apache.hc.client5.http.cookie.StandardCookieSpec; +import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; import org.apache.hc.client5.http.impl.classic.HttpClients; import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManager; import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManagerBuilder; @@ -35,10 +41,12 @@ import org.apache.hc.client5.http.protocol.HttpClientContext; import org.apache.hc.client5.http.ssl.SSLConnectionSocketFactory; import org.apache.hc.client5.http.ssl.SSLConnectionSocketFactoryBuilder; import org.apache.hc.client5.http.ssl.TrustSelfSignedStrategy; +import org.apache.hc.core5.http.io.SocketConfig; import org.apache.hc.core5.http.protocol.HttpContext; import org.apache.hc.core5.http.ssl.TLS; import org.apache.hc.core5.ssl.SSLContextBuilder; +import org.springframework.boot.web.client.ClientHttpRequestFactorySettings; import org.springframework.boot.web.client.RestTemplateBuilder; import org.springframework.boot.web.client.RootUriTemplateHandler; import org.springframework.core.ParameterizedTypeReference; @@ -138,8 +146,8 @@ public class TestRestTemplate { if (httpClientOptions != null) { ClientHttpRequestFactory requestFactory = builder.buildRequestFactory(); if (requestFactory instanceof HttpComponentsClientHttpRequestFactory) { - builder = builder - .requestFactory(() -> new CustomHttpComponentsClientHttpRequestFactory(httpClientOptions)); + builder = builder.requestFactory( + (settings) -> new CustomHttpComponentsClientHttpRequestFactory(httpClientOptions, settings)); } } if (username != null || password != null) { @@ -1000,43 +1008,71 @@ public class TestRestTemplate { private final boolean enableRedirects; - public CustomHttpComponentsClientHttpRequestFactory(HttpClientOption[] httpClientOptions) { + public CustomHttpComponentsClientHttpRequestFactory(HttpClientOption[] httpClientOptions, + ClientHttpRequestFactorySettings settings) { Set options = new HashSet<>(Arrays.asList(httpClientOptions)); this.cookieSpec = (options.contains(HttpClientOption.ENABLE_COOKIES) ? StandardCookieSpec.STRICT : StandardCookieSpec.IGNORE); this.enableRedirects = options.contains(HttpClientOption.ENABLE_REDIRECTS); - if (options.contains(HttpClientOption.SSL)) { - setHttpClient(createSslHttpClient()); + boolean ssl = options.contains(HttpClientOption.SSL); + if (settings.readTimeout() != null || ssl) { + setHttpClient(createHttpClient(settings.readTimeout(), ssl)); + } + if (settings.connectTimeout() != null) { + setConnectTimeout((int) settings.connectTimeout().toMillis()); + } + if (settings.bufferRequestBody() != null) { + setBufferRequestBody(settings.bufferRequestBody()); } } - private HttpClient createSslHttpClient() { + private HttpClient createHttpClient(Duration readTimeout, boolean ssl) { try { - SSLContext sslContext = new SSLContextBuilder().loadTrustMaterial(null, new TrustSelfSignedStrategy()) - .build(); - SSLConnectionSocketFactory socketFactory = SSLConnectionSocketFactoryBuilder.create() - .setSslContext(sslContext).setTlsVersions(TLS.V_1_3, TLS.V_1_2).build(); - PoolingHttpClientConnectionManager connectionManager = PoolingHttpClientConnectionManagerBuilder - .create().setSSLSocketFactory(socketFactory).build(); - - return HttpClients.custom().setConnectionManager(connectionManager) - .setDefaultRequestConfig(getRequestConfig()).build(); + HttpClientBuilder builder = HttpClients.custom(); + builder.setConnectionManager(createConnectionManager(readTimeout, ssl)); + builder.setDefaultRequestConfig(createRequestConfig()); + return builder.build(); } catch (Exception ex) { - throw new IllegalStateException("Unable to create SSL HttpClient", ex); + throw new IllegalStateException("Unable to create customized HttpClient", ex); } } + private PoolingHttpClientConnectionManager createConnectionManager(Duration readTimeout, boolean ssl) + throws NoSuchAlgorithmException, KeyManagementException, KeyStoreException { + PoolingHttpClientConnectionManagerBuilder builder = PoolingHttpClientConnectionManagerBuilder.create(); + if (ssl) { + builder.setSSLSocketFactory(createSocketFactory()); + } + if (readTimeout != null) { + SocketConfig socketConfig = SocketConfig.custom() + .setSoTimeout((int) readTimeout.toMillis(), TimeUnit.MILLISECONDS).build(); + builder.setDefaultSocketConfig(socketConfig); + } + return builder.build(); + } + + private SSLConnectionSocketFactory createSocketFactory() + throws NoSuchAlgorithmException, KeyManagementException, KeyStoreException { + SSLContext sslContext = new SSLContextBuilder().loadTrustMaterial(null, new TrustSelfSignedStrategy()) + .build(); + return SSLConnectionSocketFactoryBuilder.create().setSslContext(sslContext) + .setTlsVersions(TLS.V_1_3, TLS.V_1_2).build(); + } + @Override protected HttpContext createHttpContext(HttpMethod httpMethod, URI uri) { HttpClientContext context = HttpClientContext.create(); - context.setRequestConfig(getRequestConfig()); + context.setRequestConfig(createRequestConfig()); return context; } - protected RequestConfig getRequestConfig() { - return RequestConfig.custom().setCookieSpec(this.cookieSpec).setAuthenticationEnabled(false) - .setRedirectsEnabled(this.enableRedirects).build(); + protected RequestConfig createRequestConfig() { + RequestConfig.Builder builder = RequestConfig.custom(); + builder.setCookieSpec(this.cookieSpec); + builder.setAuthenticationEnabled(false); + builder.setRedirectsEnabled(this.enableRedirects); + return builder.build(); } } diff --git a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java index 094418926ce..7313dde323f 100644 --- a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java +++ b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java @@ -134,7 +134,7 @@ class TestRestTemplateTests { TestRestTemplate template = new TestRestTemplate(HttpClientOption.ENABLE_REDIRECTS); CustomHttpComponentsClientHttpRequestFactory factory = (CustomHttpComponentsClientHttpRequestFactory) template .getRestTemplate().getRequestFactory(); - RequestConfig config = factory.getRequestConfig(); + RequestConfig config = factory.createRequestConfig(); assertThat(config.isRedirectsEnabled()).isTrue(); } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactories.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactories.java new file mode 100644 index 00000000000..758b27ea8f6 --- /dev/null +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactories.java @@ -0,0 +1,264 @@ +/* + * Copyright 2012-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import org.apache.hc.client5.http.classic.HttpClient; +import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; +import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManager; +import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManagerBuilder; +import org.apache.hc.core5.http.io.SocketConfig; + +import org.springframework.boot.context.properties.PropertyMapper; +import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; +import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; + +/** + * Utility class that can be used to create {@link ClientHttpRequestFactory} instances + * configured using given {@link ClientHttpRequestFactorySettings}. + * + * @author Andy Wilkinson + * @author Phillip Webb + * @since 3.0.0 + */ +public final class ClientHttpRequestFactories { + + static final String APACHE_HTTP_CLIENT_CLASS = "org.apache.hc.client5.http.impl.classic.HttpClients"; + + private static final boolean APACHE_HTTP_CLIENT_PRESENT = ClassUtils.isPresent(APACHE_HTTP_CLIENT_CLASS, null); + + static final String OKHTTP_CLIENT_CLASS = "okhttp3.OkHttpClient"; + + private static final boolean OKHTTP_CLIENT_PRESENT = ClassUtils.isPresent(OKHTTP_CLIENT_CLASS, null); + + private ClientHttpRequestFactories() { + } + + /** + * Return a new {@link ClientHttpRequestFactory} instance using the most appropriate + * implementation. + * @param settings the settings to apply + * @return a new {@link ClientHttpRequestFactory} + */ + public static ClientHttpRequestFactory get(ClientHttpRequestFactorySettings settings) { + Assert.notNull(settings, "Settings must not be null"); + if (APACHE_HTTP_CLIENT_PRESENT) { + return HttpComponents.get(settings); + } + if (OKHTTP_CLIENT_PRESENT) { + return OkHttp.get(settings); + } + return Simple.get(settings); + } + + /** + * Return a new {@link ClientHttpRequestFactory} of the given type, applying + * {@link ClientHttpRequestFactorySettings} using reflection if necessary. + * @param the {@link ClientHttpRequestFactory} type + * @param requestFactoryType the {@link ClientHttpRequestFactory} type + * @param settings the settings to apply + * @return a new {@link ClientHttpRequestFactory} instance + */ + @SuppressWarnings("unchecked") + public static T get(Class requestFactoryType, + ClientHttpRequestFactorySettings settings) { + Assert.notNull(settings, "Settings must not be null"); + if (requestFactoryType == ClientHttpRequestFactory.class) { + return (T) get(settings); + } + if (requestFactoryType == HttpComponentsClientHttpRequestFactory.class) { + return (T) HttpComponents.get(settings); + } + if (requestFactoryType == OkHttp3ClientHttpRequestFactory.class) { + return (T) OkHttp.get(settings); + } + if (requestFactoryType == SimpleClientHttpRequestFactory.class) { + return (T) Simple.get(settings); + } + return get(() -> createRequestFactory(requestFactoryType), settings); + } + + /** + * Return a new {@link ClientHttpRequestFactory} from the given supplier, applying + * {@link ClientHttpRequestFactorySettings} using reflection. + * @param the {@link ClientHttpRequestFactory} type + * @param requestFactorySupplier the {@link ClientHttpRequestFactory} supplier + * @param settings the settings to apply + * @return a new {@link ClientHttpRequestFactory} instance + */ + public static T get(Supplier requestFactorySupplier, + ClientHttpRequestFactorySettings settings) { + return Reflective.get(requestFactorySupplier, settings); + } + + private static T createRequestFactory(Class requestFactory) { + try { + Constructor constructor = requestFactory.getDeclaredConstructor(); + constructor.setAccessible(true); + return constructor.newInstance(); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + } + + /** + * Support for {@link HttpComponentsClientHttpRequestFactory}. + */ + static class HttpComponents { + + static HttpComponentsClientHttpRequestFactory get(ClientHttpRequestFactorySettings settings) { + HttpComponentsClientHttpRequestFactory requestFactory = createRequestFactory(settings.readTimeout()); + PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull(); + map.from(settings::connectTimeout).asInt(Duration::toMillis).to(requestFactory::setConnectTimeout); + map.from(settings::bufferRequestBody).to(requestFactory::setBufferRequestBody); + return requestFactory; + } + + private static HttpComponentsClientHttpRequestFactory createRequestFactory(Duration readTimeout) { + return (readTimeout != null) ? new HttpComponentsClientHttpRequestFactory(createHttpClient(readTimeout)) + : new HttpComponentsClientHttpRequestFactory(); + } + + private static HttpClient createHttpClient(Duration readTimeout) { + SocketConfig socketConfig = SocketConfig.custom() + .setSoTimeout((int) readTimeout.toMillis(), TimeUnit.MILLISECONDS).build(); + PoolingHttpClientConnectionManager connectionManager = PoolingHttpClientConnectionManagerBuilder.create() + .setDefaultSocketConfig(socketConfig).build(); + return HttpClientBuilder.create().setConnectionManager(connectionManager).build(); + } + + } + + /** + * Support for {@link OkHttp3ClientHttpRequestFactory}. + */ + static class OkHttp { + + static OkHttp3ClientHttpRequestFactory get(ClientHttpRequestFactorySettings settings) { + Assert.state(settings.bufferRequestBody() == null, + () -> "OkHttp3ClientHttpRequestFactory does not support request body buffering"); + OkHttp3ClientHttpRequestFactory requestFactory = new OkHttp3ClientHttpRequestFactory(); + PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull(); + map.from(settings::connectTimeout).asInt(Duration::toMillis).to(requestFactory::setConnectTimeout); + map.from(settings::readTimeout).asInt(Duration::toMillis).to(requestFactory::setReadTimeout); + return requestFactory; + } + + } + + /** + * Support for {@link SimpleClientHttpRequestFactory}. + */ + static class Simple { + + static SimpleClientHttpRequestFactory get(ClientHttpRequestFactorySettings settings) { + SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); + PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull(); + map.from(settings::readTimeout).asInt(Duration::toMillis).to(requestFactory::setReadTimeout); + map.from(settings::connectTimeout).asInt(Duration::toMillis).to(requestFactory::setConnectTimeout); + map.from(settings::bufferRequestBody).to(requestFactory::setBufferRequestBody); + return requestFactory; + } + + } + + /** + * Support for reflective configuration of an unknown {@link ClientHttpRequestFactory} + * implementation. + */ + static class Reflective { + + static T get(Supplier requestFactorySupplier, + ClientHttpRequestFactorySettings settings) { + T requestFactory = requestFactorySupplier.get(); + configure(requestFactory, settings); + return requestFactory; + } + + private static void configure(ClientHttpRequestFactory requestFactory, + ClientHttpRequestFactorySettings settings) { + ClientHttpRequestFactory unwrapped = unwrapRequestFactoryIfNecessary(requestFactory); + PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull(); + map.from(settings::connectTimeout).to((connectTimeout) -> setConnectTimeout(unwrapped, connectTimeout)); + map.from(settings::readTimeout).to((readTimeout) -> setReadTimeout(unwrapped, readTimeout)); + map.from(settings::bufferRequestBody) + .to((bufferRequestBody) -> setBufferRequestBody(unwrapped, bufferRequestBody)); + } + + private static ClientHttpRequestFactory unwrapRequestFactoryIfNecessary( + ClientHttpRequestFactory requestFactory) { + if (!(requestFactory instanceof AbstractClientHttpRequestFactoryWrapper)) { + return requestFactory; + } + Field field = ReflectionUtils.findField(AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); + ReflectionUtils.makeAccessible(field); + ClientHttpRequestFactory unwrappedRequestFactory = requestFactory; + while (unwrappedRequestFactory instanceof AbstractClientHttpRequestFactoryWrapper) { + unwrappedRequestFactory = (ClientHttpRequestFactory) ReflectionUtils.getField(field, + unwrappedRequestFactory); + } + return unwrappedRequestFactory; + } + + private static void setConnectTimeout(ClientHttpRequestFactory factory, Duration connectTimeout) { + Method method = findMethod(factory, "setConnectTimeout", int.class); + int timeout = Math.toIntExact(connectTimeout.toMillis()); + invoke(factory, method, timeout); + } + + private static void setReadTimeout(ClientHttpRequestFactory factory, Duration readTimeout) { + Method method = findMethod(factory, "setReadTimeout", int.class); + int timeout = Math.toIntExact(readTimeout.toMillis()); + invoke(factory, method, timeout); + } + + private static void setBufferRequestBody(ClientHttpRequestFactory factory, boolean bufferRequestBody) { + Method method = findMethod(factory, "setBufferRequestBody", boolean.class); + invoke(factory, method, bufferRequestBody); + } + + private static Method findMethod(ClientHttpRequestFactory requestFactory, String methodName, + Class... parameters) { + Method method = ReflectionUtils.findMethod(requestFactory.getClass(), methodName, parameters); + Assert.state(method != null, () -> "Request factory %s does not have a suitable %s method" + .formatted(requestFactory.getClass().getName(), methodName)); + Assert.state(!method.isAnnotationPresent(Deprecated.class), + () -> "Request factory %s has the %s method marked as deprecated" + .formatted(requestFactory.getClass().getName(), methodName)); + return method; + } + + private static void invoke(ClientHttpRequestFactory requestFactory, Method method, Object... parameters) { + ReflectionUtils.invokeMethod(method, requestFactory, parameters); + } + + } + +} diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesRuntimeHints.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesRuntimeHints.java new file mode 100644 index 00000000000..a986c20304a --- /dev/null +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesRuntimeHints.java @@ -0,0 +1,92 @@ +/* + * Copyright 2012-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.lang.reflect.Field; +import java.net.HttpURLConnection; +import java.util.function.Consumer; + +import org.springframework.aot.hint.ExecutableMode; +import org.springframework.aot.hint.ReflectionHints; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.aot.hint.TypeHint; +import org.springframework.aot.hint.TypeReference; +import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; +import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; + +/** + * {@link RuntimeHintsRegistrar} for {@link ClientHttpRequestFactories}. + * + * @author Andy Wilkinson + * @author Phillip Webb + */ +class ClientHttpRequestFactoriesRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(RuntimeHints hints, ClassLoader classLoader) { + if (ClassUtils.isPresent("org.springframework.http.client.ClientHttpRequestFactory", classLoader)) { + registerHints(hints.reflection(), classLoader); + } + } + + private void registerHints(ReflectionHints hints, ClassLoader classLoader) { + hints.registerField(findField(AbstractClientHttpRequestFactoryWrapper.class, "requestFactory")); + if (ClassUtils.isPresent(ClientHttpRequestFactories.APACHE_HTTP_CLIENT_CLASS, classLoader)) { + registerReflectionHints(hints, HttpComponentsClientHttpRequestFactory.class, this::onReachableHttpClient); + } + if (ClassUtils.isPresent(ClientHttpRequestFactories.OKHTTP_CLIENT_CLASS, classLoader)) { + registerReflectionHints(hints, OkHttp3ClientHttpRequestFactory.class, this::onReachableOkHttpClient); + } + registerReflectionHints(hints, SimpleClientHttpRequestFactory.class, this::onReachableHttpUrlConnection); + } + + private void onReachableHttpUrlConnection(TypeHint.Builder typeHint) { + typeHint.onReachableType(HttpURLConnection.class); + } + + private void onReachableHttpClient(TypeHint.Builder typeHint) { + typeHint.onReachableType(TypeReference.of(ClientHttpRequestFactories.APACHE_HTTP_CLIENT_CLASS)); + } + + private void onReachableOkHttpClient(TypeHint.Builder typeHint) { + typeHint.onReachableType(TypeReference.of(ClientHttpRequestFactories.OKHTTP_CLIENT_CLASS)); + } + + private void registerReflectionHints(ReflectionHints hints, + Class requestFactoryType, Consumer hintCustomizer) { + hints.registerType(requestFactoryType, (typeHint) -> { + typeHint.withMethod("setConnectTimeout", TypeReference.listOf(int.class), ExecutableMode.INVOKE); + typeHint.withMethod("setReadTimeout", TypeReference.listOf(int.class), ExecutableMode.INVOKE); + typeHint.withMethod("setBufferRequestBody", TypeReference.listOf(boolean.class), ExecutableMode.INVOKE); + hintCustomizer.accept(typeHint); + }); + } + + private Field findField(Class type, String name) { + Field field = ReflectionUtils.findField(type, name); + Assert.state(field != null, () -> "Unable to find field '%s' on %s".formatted(type.getName(), name)); + return field; + } + +} diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactorySettings.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactorySettings.java new file mode 100644 index 00000000000..89b26a276c4 --- /dev/null +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactorySettings.java @@ -0,0 +1,74 @@ +/* + * Copyright 2012-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.time.Duration; + +import org.springframework.http.client.ClientHttpRequestFactory; + +/** + * Settings that can be applied when creating a {@link ClientHttpRequestFactory}. + * @param connectTimeout the connect timeout + * @param readTimeout the read timeout + * @param bufferRequestBody if request body buffering is used + * @author Andy Wilkinson + * @author Phillip Webb + * @since 3.0.0 + * @see ClientHttpRequestFactories + */ +public record ClientHttpRequestFactorySettings(Duration connectTimeout, Duration readTimeout, + Boolean bufferRequestBody) { + + /** + * Use defaults for the {@link ClientHttpRequestFactory} which can differ depending on + * the implementation. + */ + public static final ClientHttpRequestFactorySettings DEFAULTS = new ClientHttpRequestFactorySettings(null, null, + null); + + /** + * Return a new {@link ClientHttpRequestFactorySettings} instance with an updated + * connect timeout setting . + * @param connectTimeout the new connect timeout setting + * @return a new {@link ClientHttpRequestFactorySettings} instance + */ + public ClientHttpRequestFactorySettings withConnectTimeout(Duration connectTimeout) { + return new ClientHttpRequestFactorySettings(connectTimeout, this.readTimeout, this.bufferRequestBody); + } + + /** + * Return a new {@link ClientHttpRequestFactorySettings} instance with an updated read + * timeout setting. + * @param readTimeout the new read timeout setting + * @return a new {@link ClientHttpRequestFactorySettings} instance + */ + + public ClientHttpRequestFactorySettings withReadTimeout(Duration readTimeout) { + return new ClientHttpRequestFactorySettings(this.connectTimeout, readTimeout, this.bufferRequestBody); + } + + /** + * Return a new {@link ClientHttpRequestFactorySettings} instance with an updated + * buffer request body setting. + * @param bufferRequestBody the new buffer request body setting + * @return a new {@link ClientHttpRequestFactorySettings} instance + */ + public ClientHttpRequestFactorySettings withBufferRequestBody(Boolean bufferRequestBody) { + return new ClientHttpRequestFactorySettings(this.connectTimeout, this.readTimeout, bufferRequestBody); + } + +} diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactorySupplier.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactorySupplier.java index ee6f801869f..7fce4e31cf0 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactorySupplier.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactorySupplier.java @@ -16,17 +16,9 @@ package org.springframework.boot.web.client; -import java.util.function.Consumer; import java.util.function.Supplier; -import org.springframework.aot.hint.RuntimeHints; -import org.springframework.aot.hint.TypeHint.Builder; -import org.springframework.aot.hint.TypeReference; import org.springframework.http.client.ClientHttpRequestFactory; -import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; -import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; -import org.springframework.http.client.SimpleClientHttpRequestFactory; -import org.springframework.util.ClassUtils; /** * A supplier for {@link ClientHttpRequestFactory} that detects the preferred candidate @@ -35,43 +27,15 @@ import org.springframework.util.ClassUtils; * @author Stephane Nicoll * @author Moritz Halbritter * @since 2.1.0 + * @deprecated since 3.0.0 for removal in 3.2.0 in favor of + * {@link ClientHttpRequestFactories} */ +@Deprecated(since = "3.0.0", forRemoval = true) public class ClientHttpRequestFactorySupplier implements Supplier { - private static final String APACHE_HTTP_CLIENT_CLASS = "org.apache.hc.client5.http.impl.classic.HttpClients"; - - private static final boolean APACHE_HTTP_CLIENT_PRESENT = ClassUtils.isPresent(APACHE_HTTP_CLIENT_CLASS, null); - - private static final String OKHTTP_CLIENT_CLASS = "okhttp3.OkHttpClient"; - - private static final boolean OKHTTP_CLIENT_PRESENT = ClassUtils.isPresent(OKHTTP_CLIENT_CLASS, null); - @Override public ClientHttpRequestFactory get() { - if (APACHE_HTTP_CLIENT_PRESENT) { - return new HttpComponentsClientHttpRequestFactory(); - } - if (OKHTTP_CLIENT_PRESENT) { - return new OkHttp3ClientHttpRequestFactory(); - } - return new SimpleClientHttpRequestFactory(); - } - - static class ClientHttpRequestFactorySupplierRuntimeHints { - - static void registerHints(RuntimeHints hints, ClassLoader classLoader, Consumer callback) { - if (ClassUtils.isPresent(APACHE_HTTP_CLIENT_CLASS, classLoader)) { - hints.reflection().registerType(HttpComponentsClientHttpRequestFactory.class, (typeHint) -> callback - .accept(typeHint.onReachableType(TypeReference.of(APACHE_HTTP_CLIENT_CLASS)))); - } - if (ClassUtils.isPresent(OKHTTP_CLIENT_CLASS, classLoader)) { - hints.reflection().registerType(OkHttp3ClientHttpRequestFactory.class, - (typeHint) -> callback.accept(typeHint.onReachableType(TypeReference.of(OKHTTP_CLIENT_CLASS)))); - } - hints.reflection().registerType(SimpleClientHttpRequestFactory.class, (typeHint) -> callback - .accept(typeHint.onReachableType(TypeReference.of(SimpleClientHttpRequestFactory.class)))); - } - + return ClientHttpRequestFactories.get(ClientHttpRequestFactorySettings.DEFAULTS); } } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java index 7f8d0883123..083226a580b 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java @@ -16,9 +16,6 @@ package org.springframework.boot.web.client; -import java.lang.reflect.Constructor; -import java.lang.reflect.Field; -import java.lang.reflect.Method; import java.nio.charset.Charset; import java.time.Duration; import java.util.ArrayList; @@ -29,20 +26,13 @@ import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Set; -import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; import reactor.netty.http.client.HttpClientRequest; -import org.springframework.aot.hint.ExecutableMode; -import org.springframework.aot.hint.RuntimeHints; -import org.springframework.aot.hint.RuntimeHintsRegistrar; -import org.springframework.aot.hint.TypeReference; import org.springframework.beans.BeanUtils; -import org.springframework.context.annotation.ImportRuntimeHints; -import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestInterceptor; @@ -51,7 +41,6 @@ import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import org.springframework.util.ReflectionUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestTemplate; import org.springframework.web.util.UriTemplateHandler; @@ -77,10 +66,9 @@ import org.springframework.web.util.UriTemplateHandler; * @author Ilya Lukyanovich * @since 1.4.0 */ -@ImportRuntimeHints(RestTemplateBuilder.RestTemplateBuilderRuntimeHints.class) public class RestTemplateBuilder { - private final RequestFactoryCustomizer requestFactoryCustomizer; + private final ClientHttpRequestFactorySettings requestFactorySettings; private final boolean detectRequestFactory; @@ -90,7 +78,7 @@ public class RestTemplateBuilder { private final Set interceptors; - private final Supplier requestFactory; + private final Function requestFactory; private final UriTemplateHandler uriTemplateHandler; @@ -111,7 +99,7 @@ public class RestTemplateBuilder { */ public RestTemplateBuilder(RestTemplateCustomizer... customizers) { Assert.notNull(customizers, "Customizers must not be null"); - this.requestFactoryCustomizer = new RequestFactoryCustomizer(); + this.requestFactorySettings = ClientHttpRequestFactorySettings.DEFAULTS; this.detectRequestFactory = true; this.rootUri = null; this.messageConverters = null; @@ -125,18 +113,19 @@ public class RestTemplateBuilder { this.requestCustomizers = Collections.emptySet(); } - private RestTemplateBuilder(RequestFactoryCustomizer requestFactoryCustomizer, boolean detectRequestFactory, + private RestTemplateBuilder(ClientHttpRequestFactorySettings requestFactorySettings, boolean detectRequestFactory, String rootUri, Set> messageConverters, - Set interceptors, Supplier requestFactorySupplier, + Set interceptors, + Function requestFactory, UriTemplateHandler uriTemplateHandler, ResponseErrorHandler errorHandler, BasicAuthentication basicAuthentication, Map> defaultHeaders, Set customizers, Set> requestCustomizers) { - this.requestFactoryCustomizer = requestFactoryCustomizer; + this.requestFactorySettings = requestFactorySettings; this.detectRequestFactory = detectRequestFactory; this.rootUri = rootUri; this.messageConverters = messageConverters; this.interceptors = interceptors; - this.requestFactory = requestFactorySupplier; + this.requestFactory = requestFactory; this.uriTemplateHandler = uriTemplateHandler; this.errorHandler = errorHandler; this.basicAuthentication = basicAuthentication; @@ -153,7 +142,7 @@ public class RestTemplateBuilder { * @return a new builder instance */ public RestTemplateBuilder detectRequestFactory(boolean detectRequestFactory) { - return new RestTemplateBuilder(this.requestFactoryCustomizer, detectRequestFactory, this.rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, detectRequestFactory, this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, this.requestCustomizers); @@ -169,7 +158,7 @@ public class RestTemplateBuilder { * @return a new builder instance */ public RestTemplateBuilder rootUri(String rootUri) { - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, this.requestCustomizers); @@ -200,7 +189,7 @@ public class RestTemplateBuilder { */ public RestTemplateBuilder messageConverters(Collection> messageConverters) { Assert.notNull(messageConverters, "MessageConverters must not be null"); - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, this.rootUri, copiedSetOf(messageConverters), this.interceptors, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, this.requestCustomizers); @@ -230,7 +219,7 @@ public class RestTemplateBuilder { public RestTemplateBuilder additionalMessageConverters( Collection> messageConverters) { Assert.notNull(messageConverters, "MessageConverters must not be null"); - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, this.rootUri, append(this.messageConverters, messageConverters), this.interceptors, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, this.requestCustomizers); @@ -244,7 +233,7 @@ public class RestTemplateBuilder { * @see #messageConverters(HttpMessageConverter...) */ public RestTemplateBuilder defaultMessageConverters() { - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, this.rootUri, copiedSetOf(new RestTemplate().getMessageConverters()), this.interceptors, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, this.requestCustomizers); @@ -275,7 +264,7 @@ public class RestTemplateBuilder { */ public RestTemplateBuilder interceptors(Collection interceptors) { Assert.notNull(interceptors, "interceptors must not be null"); - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, this.rootUri, this.messageConverters, copiedSetOf(interceptors), this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, this.requestCustomizers); @@ -304,7 +293,7 @@ public class RestTemplateBuilder { */ public RestTemplateBuilder additionalInterceptors(Collection interceptors) { Assert.notNull(interceptors, "interceptors must not be null"); - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, this.rootUri, this.messageConverters, append(this.interceptors, interceptors), this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, this.requestCustomizers); @@ -313,37 +302,41 @@ public class RestTemplateBuilder { /** * Set the {@link ClientHttpRequestFactory} class that should be used with the * {@link RestTemplate}. - * @param requestFactory the request factory to use + * @param requestFactoryType the request factory type to use * @return a new builder instance */ - public RestTemplateBuilder requestFactory(Class requestFactory) { - Assert.notNull(requestFactory, "RequestFactory must not be null"); - return requestFactory(() -> createRequestFactory(requestFactory)); - } - - private ClientHttpRequestFactory createRequestFactory(Class requestFactory) { - try { - Constructor constructor = requestFactory.getDeclaredConstructor(); - constructor.setAccessible(true); - return (ClientHttpRequestFactory) constructor.newInstance(); - } - catch (Exception ex) { - throw new IllegalStateException(ex); - } + public RestTemplateBuilder requestFactory(Class requestFactoryType) { + Assert.notNull(requestFactoryType, "RequestFactoryType must not be null"); + return requestFactory((settings) -> ClientHttpRequestFactories.get(requestFactoryType, settings)); } /** * Set the {@code Supplier} of {@link ClientHttpRequestFactory} that should be called * each time we {@link #build()} a new {@link RestTemplate} instance. - * @param requestFactory the supplier for the request factory + * @param requestFactorySupplier the supplier for the request factory * @return a new builder instance * @since 2.0.0 */ - public RestTemplateBuilder requestFactory(Supplier requestFactory) { - Assert.notNull(requestFactory, "RequestFactory Supplier must not be null"); - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, - this.messageConverters, this.interceptors, requestFactory, this.uriTemplateHandler, this.errorHandler, - this.basicAuthentication, this.defaultHeaders, this.customizers, this.requestCustomizers); + public RestTemplateBuilder requestFactory(Supplier requestFactorySupplier) { + Assert.notNull(requestFactorySupplier, "RequestFactorySupplier must not be null"); + return requestFactory((settings) -> ClientHttpRequestFactories.get(requestFactorySupplier, settings)); + } + + /** + * Set the {@link ClientHttpRequestFactorySupplier} that should be called each time we + * {@link #build()} a new {@link RestTemplate} instance. + * @param requestFactoryFunction the settings to request factory function + * @return a new builder instance + * @since 3.0.0 + * @see ClientHttpRequestFactories + */ + public RestTemplateBuilder requestFactory( + Function requestFactoryFunction) { + Assert.notNull(requestFactoryFunction, "RequestFactoryFunction must not be null"); + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, this.rootUri, + this.messageConverters, this.interceptors, requestFactoryFunction, this.uriTemplateHandler, + this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, + this.requestCustomizers); } /** @@ -354,7 +347,7 @@ public class RestTemplateBuilder { */ public RestTemplateBuilder uriTemplateHandler(UriTemplateHandler uriTemplateHandler) { Assert.notNull(uriTemplateHandler, "UriTemplateHandler must not be null"); - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, this.requestCustomizers); } @@ -367,7 +360,7 @@ public class RestTemplateBuilder { */ public RestTemplateBuilder errorHandler(ResponseErrorHandler errorHandler) { Assert.notNull(errorHandler, "ErrorHandler must not be null"); - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, this.requestCustomizers); } @@ -395,7 +388,7 @@ public class RestTemplateBuilder { * @since 2.2.0 */ public RestTemplateBuilder basicAuthentication(String username, String password, Charset charset) { - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, this.errorHandler, new BasicAuthentication(username, password, charset), this.defaultHeaders, this.customizers, this.requestCustomizers); @@ -412,7 +405,7 @@ public class RestTemplateBuilder { public RestTemplateBuilder defaultHeader(String name, String... values) { Assert.notNull(name, "Name must not be null"); Assert.notNull(values, "Values must not be null"); - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, append(this.defaultHeaders, name, values), this.customizers, this.requestCustomizers); @@ -425,7 +418,7 @@ public class RestTemplateBuilder { * @since 2.1.0 */ public RestTemplateBuilder setConnectTimeout(Duration connectTimeout) { - return new RestTemplateBuilder(this.requestFactoryCustomizer.connectTimeout(connectTimeout), + return new RestTemplateBuilder(this.requestFactorySettings.withConnectTimeout(connectTimeout), this.detectRequestFactory, this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, this.requestCustomizers); @@ -438,7 +431,7 @@ public class RestTemplateBuilder { * @since 2.1.0 */ public RestTemplateBuilder setReadTimeout(Duration readTimeout) { - return new RestTemplateBuilder(this.requestFactoryCustomizer.readTimeout(readTimeout), + return new RestTemplateBuilder(this.requestFactorySettings.withReadTimeout(readTimeout), this.detectRequestFactory, this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, this.requestCustomizers); @@ -454,7 +447,7 @@ public class RestTemplateBuilder { * @see HttpComponentsClientHttpRequestFactory#setBufferRequestBody(boolean) */ public RestTemplateBuilder setBufferRequestBody(boolean bufferRequestBody) { - return new RestTemplateBuilder(this.requestFactoryCustomizer.bufferRequestBody(bufferRequestBody), + return new RestTemplateBuilder(this.requestFactorySettings.withBufferRequestBody(bufferRequestBody), this.detectRequestFactory, this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, this.requestCustomizers); @@ -485,7 +478,7 @@ public class RestTemplateBuilder { */ public RestTemplateBuilder customizers(Collection customizers) { Assert.notNull(customizers, "Customizers must not be null"); - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, copiedSetOf(customizers), this.requestCustomizers); @@ -514,7 +507,7 @@ public class RestTemplateBuilder { */ public RestTemplateBuilder additionalCustomizers(Collection customizers) { Assert.notNull(customizers, "RestTemplateCustomizers must not be null"); - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, append(this.customizers, customizers), this.requestCustomizers); @@ -548,7 +541,7 @@ public class RestTemplateBuilder { public RestTemplateBuilder requestCustomizers( Collection> requestCustomizers) { Assert.notNull(requestCustomizers, "RequestCustomizers must not be null"); - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, copiedSetOf(requestCustomizers)); @@ -580,7 +573,7 @@ public class RestTemplateBuilder { public RestTemplateBuilder additionalRequestCustomizers( Collection> requestCustomizers) { Assert.notNull(requestCustomizers, "RequestCustomizers must not be null"); - return new RestTemplateBuilder(this.requestFactoryCustomizer, this.detectRequestFactory, this.rootUri, + return new RestTemplateBuilder(this.requestFactorySettings, this.detectRequestFactory, this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, append(this.requestCustomizers, requestCustomizers)); @@ -651,19 +644,13 @@ public class RestTemplateBuilder { * @since 2.2.0 */ public ClientHttpRequestFactory buildRequestFactory() { - ClientHttpRequestFactory requestFactory = null; if (this.requestFactory != null) { - requestFactory = this.requestFactory.get(); + return this.requestFactory.apply(this.requestFactorySettings); } - else if (this.detectRequestFactory) { - requestFactory = new ClientHttpRequestFactorySupplier().get(); + if (this.detectRequestFactory) { + return ClientHttpRequestFactories.get(this.requestFactorySettings); } - if (requestFactory != null) { - if (this.requestFactoryCustomizer != null) { - this.requestFactoryCustomizer.accept(requestFactory); - } - } - return requestFactory; + return null; } private void addClientHttpRequestInitializer(RestTemplate restTemplate) { @@ -703,118 +690,4 @@ public class RestTemplateBuilder { return Collections.unmodifiableMap(result); } - /** - * Internal customizer used to apply {@link ClientHttpRequestFactory} settings. - */ - private static class RequestFactoryCustomizer implements Consumer { - - private final Duration connectTimeout; - - private final Duration readTimeout; - - private final Boolean bufferRequestBody; - - RequestFactoryCustomizer() { - this(null, null, null); - } - - private RequestFactoryCustomizer(Duration connectTimeout, Duration readTimeout, Boolean bufferRequestBody) { - this.connectTimeout = connectTimeout; - this.readTimeout = readTimeout; - this.bufferRequestBody = bufferRequestBody; - } - - RequestFactoryCustomizer connectTimeout(Duration connectTimeout) { - return new RequestFactoryCustomizer(connectTimeout, this.readTimeout, this.bufferRequestBody); - } - - RequestFactoryCustomizer readTimeout(Duration readTimeout) { - return new RequestFactoryCustomizer(this.connectTimeout, readTimeout, this.bufferRequestBody); - } - - RequestFactoryCustomizer bufferRequestBody(boolean bufferRequestBody) { - return new RequestFactoryCustomizer(this.connectTimeout, this.readTimeout, bufferRequestBody); - } - - @Override - public void accept(ClientHttpRequestFactory requestFactory) { - ClientHttpRequestFactory unwrappedRequestFactory = unwrapRequestFactoryIfNecessary(requestFactory); - if (this.connectTimeout != null) { - setConnectTimeout(unwrappedRequestFactory); - } - if (this.readTimeout != null) { - setReadTimeout(unwrappedRequestFactory); - } - if (this.bufferRequestBody != null) { - setBufferRequestBody(unwrappedRequestFactory); - } - } - - private ClientHttpRequestFactory unwrapRequestFactoryIfNecessary(ClientHttpRequestFactory requestFactory) { - if (!(requestFactory instanceof AbstractClientHttpRequestFactoryWrapper)) { - return requestFactory; - } - Field field = ReflectionUtils.findField(AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); - ReflectionUtils.makeAccessible(field); - ClientHttpRequestFactory unwrappedRequestFactory = requestFactory; - while (unwrappedRequestFactory instanceof AbstractClientHttpRequestFactoryWrapper) { - unwrappedRequestFactory = (ClientHttpRequestFactory) ReflectionUtils.getField(field, - unwrappedRequestFactory); - } - return unwrappedRequestFactory; - } - - private void setConnectTimeout(ClientHttpRequestFactory factory) { - Method method = findMethod(factory, "setConnectTimeout", int.class); - int timeout = Math.toIntExact(this.connectTimeout.toMillis()); - invoke(factory, method, timeout); - } - - private void setReadTimeout(ClientHttpRequestFactory factory) { - Method method = findMethod(factory, "setReadTimeout", int.class); - int timeout = Math.toIntExact(this.readTimeout.toMillis()); - invoke(factory, method, timeout); - } - - private void setBufferRequestBody(ClientHttpRequestFactory factory) { - Method method = findMethod(factory, "setBufferRequestBody", boolean.class); - invoke(factory, method, this.bufferRequestBody); - } - - private Method findMethod(ClientHttpRequestFactory requestFactory, String methodName, Class... parameters) { - Method method = ReflectionUtils.findMethod(requestFactory.getClass(), methodName, parameters); - if (method == null) { - throw new IllegalStateException("Request factory " + requestFactory.getClass() - + " does not have a suitable " + methodName + " method"); - } - else if (method.isAnnotationPresent(Deprecated.class)) { - throw new IllegalStateException("Request factory " + requestFactory.getClass() + " has the " - + methodName + " method marked as deprecated"); - } - return method; - } - - private void invoke(ClientHttpRequestFactory requestFactory, Method method, Object... parameters) { - ReflectionUtils.invokeMethod(method, requestFactory, parameters); - } - - } - - static class RestTemplateBuilderRuntimeHints implements RuntimeHintsRegistrar { - - @Override - public void registerHints(RuntimeHints hints, ClassLoader classLoader) { - hints.reflection().registerField(Objects.requireNonNull( - ReflectionUtils.findField(AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"))); - ClientHttpRequestFactorySupplier.ClientHttpRequestFactorySupplierRuntimeHints.registerHints(hints, - classLoader, (hint) -> { - hint.withMethod("setConnectTimeout", TypeReference.listOf(int.class), ExecutableMode.INVOKE); - hint.withMethod("setReadTimeout", TypeReference.listOf(int.class), ExecutableMode.INVOKE); - hint.withMethod("setBufferRequestBody", TypeReference.listOf(boolean.class), - ExecutableMode.INVOKE); - }); - } - - } - } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/webservices/client/HttpWebServiceMessageSenderBuilder.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/webservices/client/HttpWebServiceMessageSenderBuilder.java index de05c6603c8..f823566492e 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/webservices/client/HttpWebServiceMessageSenderBuilder.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/webservices/client/HttpWebServiceMessageSenderBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2019 the original author or authors. + * Copyright 2012-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,14 +16,14 @@ package org.springframework.boot.webservices.client; -import java.lang.reflect.Method; import java.time.Duration; +import java.util.function.Function; import java.util.function.Supplier; -import org.springframework.boot.web.client.ClientHttpRequestFactorySupplier; +import org.springframework.boot.web.client.ClientHttpRequestFactories; +import org.springframework.boot.web.client.ClientHttpRequestFactorySettings; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.util.Assert; -import org.springframework.util.ReflectionUtils; import org.springframework.ws.transport.WebServiceMessageSender; import org.springframework.ws.transport.http.ClientHttpRequestMessageSender; @@ -40,7 +40,7 @@ public class HttpWebServiceMessageSenderBuilder { private Duration readTimeout; - private Supplier requestFactorySupplier; + private Function requestFactory; /** * Set the connection timeout. @@ -70,50 +70,39 @@ public class HttpWebServiceMessageSenderBuilder { */ public HttpWebServiceMessageSenderBuilder requestFactory( Supplier requestFactorySupplier) { - Assert.notNull(requestFactorySupplier, "RequestFactory Supplier must not be null"); - this.requestFactorySupplier = requestFactorySupplier; + Assert.notNull(requestFactorySupplier, "RequestFactorySupplier must not be null"); + this.requestFactory = (settings) -> ClientHttpRequestFactories.get(requestFactorySupplier, settings); return this; } - public WebServiceMessageSender build() { - ClientHttpRequestFactory requestFactory = (this.requestFactorySupplier != null) - ? this.requestFactorySupplier.get() : new ClientHttpRequestFactorySupplier().get(); - if (this.connectTimeout != null) { - new TimeoutRequestFactoryCustomizer(this.connectTimeout, "setConnectTimeout").customize(requestFactory); - } - if (this.readTimeout != null) { - new TimeoutRequestFactoryCustomizer(this.readTimeout, "setReadTimeout").customize(requestFactory); - } - return new ClientHttpRequestMessageSender(requestFactory); + /** + * Set the {@code Function} of {@link ClientHttpRequestFactorySettings} to + * {@link ClientHttpRequestFactory} that should be called to create the HTTP-based + * {@link WebServiceMessageSender}. + * @param requestFactoryFunction the function for the request factory + * @return a new builder instance + * @since 3.0.0 + */ + public HttpWebServiceMessageSenderBuilder requestFactory( + Function requestFactoryFunction) { + Assert.notNull(requestFactoryFunction, "RequestFactoryFunction must not be null"); + this.requestFactory = requestFactoryFunction; + return this; } /** - * {@link ClientHttpRequestFactory} customizer to call a "set timeout" method. + * Build the {@link WebServiceMessageSender} instance. + * @return the {@link WebServiceMessageSender} instance */ - private static class TimeoutRequestFactoryCustomizer { - - private final Duration timeout; - - private final String methodName; - - TimeoutRequestFactoryCustomizer(Duration timeout, String methodName) { - this.timeout = timeout; - this.methodName = methodName; - } - - void customize(ClientHttpRequestFactory factory) { - ReflectionUtils.invokeMethod(findMethod(factory), factory, Math.toIntExact(this.timeout.toMillis())); - } - - private Method findMethod(ClientHttpRequestFactory factory) { - Method method = ReflectionUtils.findMethod(factory.getClass(), this.methodName, int.class); - if (method != null) { - return method; - } - throw new IllegalStateException( - "Request factory " + factory.getClass() + " does not have a " + this.methodName + "(int) method"); - } + public WebServiceMessageSender build() { + return new ClientHttpRequestMessageSender(getRequestFactory()); + } + private ClientHttpRequestFactory getRequestFactory() { + ClientHttpRequestFactorySettings settings = new ClientHttpRequestFactorySettings(this.connectTimeout, + this.readTimeout, null); + return (this.requestFactory != null) ? this.requestFactory.apply(settings) + : ClientHttpRequestFactories.get(settings); } } diff --git a/spring-boot-project/spring-boot/src/main/resources/META-INF/spring/aot.factories b/spring-boot-project/spring-boot/src/main/resources/META-INF/spring/aot.factories index e1c9f247002..f219219f582 100644 --- a/spring-boot-project/spring-boot/src/main/resources/META-INF/spring/aot.factories +++ b/spring-boot-project/spring-boot/src/main/resources/META-INF/spring/aot.factories @@ -7,6 +7,7 @@ org.springframework.boot.env.PropertySourceRuntimeHints,\ org.springframework.boot.json.JacksonRuntimeHints,\ org.springframework.boot.logging.java.JavaLoggingSystemRuntimeHints,\ org.springframework.boot.logging.logback.LogbackRuntimeHints,\ +org.springframework.boot.web.client.ClientHttpRequestFactoriesRuntimeHints,\ org.springframework.boot.web.embedded.undertow.UndertowWebServer.UndertowWebServerRuntimeHints,\ org.springframework.boot.web.server.MimeMappings.MimeMappingsRuntimeHints diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/context/properties/ConfigurationPropertiesBeanFactoryInitializationAotProcessorTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/context/properties/ConfigurationPropertiesBeanFactoryInitializationAotProcessorTests.java index b0e8ad3b7de..ca78514a13f 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/context/properties/ConfigurationPropertiesBeanFactoryInitializationAotProcessorTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/context/properties/ConfigurationPropertiesBeanFactoryInitializationAotProcessorTests.java @@ -40,6 +40,7 @@ import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.boot.context.properties.bind.BindableRuntimeHintsRegistrar; import org.springframework.boot.context.properties.bind.ConstructorBinding; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; @@ -50,7 +51,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; /** - * Tests for {@link ConfigurationPropertiesBeanFactoryInitializationAotProcessor}. + * Tests for {@link ConfigurationPropertiesBeanFactoryInitializationAotProcessor} and + * {@link BindableRuntimeHintsRegistrar}. * * @author Stephane Nicoll * @author Moritz Halbritter diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/AbstractClientHttpRequestFactoriesTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/AbstractClientHttpRequestFactoriesTests.java new file mode 100644 index 00000000000..14b2b806179 --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/AbstractClientHttpRequestFactoriesTests.java @@ -0,0 +1,83 @@ +/* + * Copyright 2012-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.time.Duration; + +import org.junit.jupiter.api.Test; + +import org.springframework.http.client.ClientHttpRequestFactory; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Base classes for testing of {@link ClientHttpRequestFactories} with different HTTP + * clients on the classpath. + * + * @param the {@link ClientHttpRequestFactory} to be produced + * @author Andy Wilkinson + */ +abstract class AbstractClientHttpRequestFactoriesTests { + + private final Class requestFactoryType; + + protected AbstractClientHttpRequestFactoriesTests(Class requestFactoryType) { + this.requestFactoryType = requestFactoryType; + } + + @Test + void getReturnsRequestFactoryOfExpectedType() { + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories + .get(ClientHttpRequestFactorySettings.DEFAULTS); + assertThat(requestFactory).isInstanceOf(this.requestFactoryType); + } + + @Test + void getOfGeneralTypeReturnsRequestFactoryOfExpectedType() { + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories.get(ClientHttpRequestFactory.class, + ClientHttpRequestFactorySettings.DEFAULTS); + assertThat(requestFactory).isInstanceOf(this.requestFactoryType); + } + + @Test + void getOfSpecificTypeReturnsRequestFactoryOfExpectedType() { + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories.get(this.requestFactoryType, + ClientHttpRequestFactorySettings.DEFAULTS); + assertThat(requestFactory).isInstanceOf(this.requestFactoryType); + } + + @Test + @SuppressWarnings("unchecked") + void getReturnsRequestFactoryWithConfiguredConnectTimeout() { + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories + .get(ClientHttpRequestFactorySettings.DEFAULTS.withConnectTimeout(Duration.ofSeconds(60))); + assertThat(connectTimeout((T) requestFactory)).isEqualTo(Duration.ofSeconds(60).toMillis()); + } + + @Test + @SuppressWarnings("unchecked") + void getReturnsRequestFactoryWithConfiguredReadTimeout() { + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories + .get(ClientHttpRequestFactorySettings.DEFAULTS.withReadTimeout(Duration.ofSeconds(120))); + assertThat(readTimeout((T) requestFactory)).isEqualTo(Duration.ofSeconds(120).toMillis()); + } + + protected abstract long connectTimeout(T requestFactory); + + protected abstract long readTimeout(T requestFactory); + +} diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/AbstractRestTemplateBuilderRequestFactoryConfigurationTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/AbstractRestTemplateBuilderRequestFactoryConfigurationTests.java new file mode 100644 index 00000000000..695be67bb5b --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/AbstractRestTemplateBuilderRequestFactoryConfigurationTests.java @@ -0,0 +1,80 @@ +/* + * Copyright 2012-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.time.Duration; + +import org.junit.jupiter.api.Test; + +import org.springframework.http.client.ClientHttpRequestFactory; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Base class for tests that verify the configuration of the + * {@link ClientHttpRequestFactory} used by {@link RestTemplateBuilder}. + * + * @param the request factory type under test + * @author Andy Wilkinson + */ +abstract class AbstractRestTemplateBuilderRequestFactoryConfigurationTests { + + private final Class factoryType; + + private final RestTemplateBuilder builder = new RestTemplateBuilder(); + + protected AbstractRestTemplateBuilderRequestFactoryConfigurationTests(Class factoryType) { + this.factoryType = factoryType; + } + + @Test + @SuppressWarnings("unchecked") + void connectTimeoutCanBeConfiguredOnFactory() { + ClientHttpRequestFactory requestFactory = this.builder.requestFactory(this.factoryType) + .setConnectTimeout(Duration.ofMillis(1234)).build().getRequestFactory(); + assertThat(connectTimeout((T) requestFactory)).isEqualTo(1234); + } + + @Test + @SuppressWarnings("unchecked") + void readTimeoutCanBeConfiguredOnFactory() { + ClientHttpRequestFactory requestFactory = this.builder.requestFactory(this.factoryType) + .setReadTimeout(Duration.ofMillis(1234)).build().getRequestFactory(); + assertThat(readTimeout((T) requestFactory)).isEqualTo(1234); + } + + @Test + @SuppressWarnings("unchecked") + void connectTimeoutCanBeConfiguredOnDetectedFactory() { + ClientHttpRequestFactory requestFactory = this.builder.setConnectTimeout(Duration.ofMillis(1234)).build() + .getRequestFactory(); + assertThat(connectTimeout((T) requestFactory)).isEqualTo(1234); + } + + @Test + @SuppressWarnings("unchecked") + void readTimeoutCanBeConfiguredOnDetectedFactory() { + ClientHttpRequestFactory requestFactory = this.builder.setReadTimeout(Duration.ofMillis(1234)).build() + .getRequestFactory(); + assertThat(readTimeout((T) requestFactory)).isEqualTo(1234); + } + + protected abstract long connectTimeout(T requestFactory); + + protected abstract long readTimeout(T requestFactory); + +} diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesHttpComponentsTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesHttpComponentsTests.java new file mode 100644 index 00000000000..afae5323ac4 --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesHttpComponentsTests.java @@ -0,0 +1,52 @@ +/* + * Copyright 2012-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import org.apache.hc.client5.http.classic.HttpClient; +import org.apache.hc.core5.http.io.SocketConfig; + +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.test.util.ReflectionTestUtils; + +/** + * Tests for {@link ClientHttpRequestFactories} when Apache Http Components is the + * predominant HTTP client. + * + * @author Andy Wilkinson + */ +class ClientHttpRequestFactoriesHttpComponentsTests + extends AbstractClientHttpRequestFactoriesTests { + + ClientHttpRequestFactoriesHttpComponentsTests() { + super(HttpComponentsClientHttpRequestFactory.class); + } + + @Override + protected long connectTimeout(HttpComponentsClientHttpRequestFactory requestFactory) { + return (int) ReflectionTestUtils.getField(requestFactory, "connectTimeout"); + } + + @Override + protected long readTimeout(HttpComponentsClientHttpRequestFactory requestFactory) { + HttpClient httpClient = requestFactory.getHttpClient(); + Object connectionManager = ReflectionTestUtils.getField(httpClient, "connManager"); + SocketConfig socketConfig = (SocketConfig) ReflectionTestUtils.getField(connectionManager, + "defaultSocketConfig"); + return socketConfig.getSoTimeout().toMilliseconds(); + } + +} diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesOkHttp3Tests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesOkHttp3Tests.java new file mode 100644 index 00000000000..2b92780771c --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesOkHttp3Tests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2012-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.io.File; + +import okhttp3.OkHttpClient; +import org.junit.jupiter.api.Test; + +import org.springframework.boot.testsupport.classpath.ClassPathExclusions; +import org.springframework.boot.testsupport.classpath.ClassPathOverrides; +import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; +import org.springframework.test.util.ReflectionTestUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; + +/** + * Tests for {@link ClientHttpRequestFactories} when OkHttp 3 is the predominant HTTP + * client. + * + * @author Andy Wilkinson + */ +@ClassPathOverrides("com.squareup.okhttp3:okhttp:3.14.9") +@ClassPathExclusions("httpclient5-*.jar") +class ClientHttpRequestFactoriesOkHttp3Tests + extends AbstractClientHttpRequestFactoriesTests { + + ClientHttpRequestFactoriesOkHttp3Tests() { + super(OkHttp3ClientHttpRequestFactory.class); + } + + @Test + void okHttp3IsBeingUsed() { + assertThat(new File(OkHttpClient.class.getProtectionDomain().getCodeSource().getLocation().getFile()).getName()) + .startsWith("okhttp-3."); + } + + @Test + void getFailsWhenBufferRequestBodyIsEnabled() { + assertThatIllegalStateException().isThrownBy(() -> ClientHttpRequestFactories + .get(ClientHttpRequestFactorySettings.DEFAULTS.withBufferRequestBody(true))); + } + + @Override + protected long connectTimeout(OkHttp3ClientHttpRequestFactory requestFactory) { + return ((OkHttpClient) ReflectionTestUtils.getField(requestFactory, "client")).connectTimeoutMillis(); + } + + @Override + protected long readTimeout(OkHttp3ClientHttpRequestFactory requestFactory) { + return ((OkHttpClient) ReflectionTestUtils.getField(requestFactory, "client")).readTimeoutMillis(); + } + +} diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesOkHttp4Tests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesOkHttp4Tests.java new file mode 100644 index 00000000000..826a86d9c43 --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesOkHttp4Tests.java @@ -0,0 +1,67 @@ +/* + * Copyright 2012-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.io.File; + +import okhttp3.OkHttpClient; +import org.junit.jupiter.api.Test; + +import org.springframework.boot.testsupport.classpath.ClassPathExclusions; +import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; +import org.springframework.test.util.ReflectionTestUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; + +/** + * Tests for {@link ClientHttpRequestFactories} when OkHttp 4 is the predominant HTTP + * client. + * + * @author Andy Wilkinson + */ +@ClassPathExclusions("httpclient5-*.jar") +class ClientHttpRequestFactoriesOkHttp4Tests + extends AbstractClientHttpRequestFactoriesTests { + + ClientHttpRequestFactoriesOkHttp4Tests() { + super(OkHttp3ClientHttpRequestFactory.class); + } + + @Test + void okHttp4IsBeingUsed() { + assertThat(new File(OkHttpClient.class.getProtectionDomain().getCodeSource().getLocation().getFile()).getName()) + .startsWith("okhttp-4."); + } + + @Test + void getFailsWhenBufferRequestBodyIsEnabled() { + assertThatIllegalStateException().isThrownBy(() -> ClientHttpRequestFactories + .get(ClientHttpRequestFactorySettings.DEFAULTS.withBufferRequestBody(true))); + } + + @Override + protected long connectTimeout(OkHttp3ClientHttpRequestFactory requestFactory) { + return ((OkHttpClient) ReflectionTestUtils.getField(requestFactory, "client")).connectTimeoutMillis(); + } + + @Override + protected long readTimeout(OkHttp3ClientHttpRequestFactory requestFactory) { + return ((OkHttpClient) ReflectionTestUtils.getField(requestFactory, "client")).readTimeoutMillis(); + } + +} diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesRuntimeHintsTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesRuntimeHintsTests.java new file mode 100644 index 00000000000..886c8c6d3a1 --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesRuntimeHintsTests.java @@ -0,0 +1,91 @@ +/* + * Copyright 2012-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.predicate.ReflectionHintsPredicates; +import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; +import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; +import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link ClientHttpRequestFactoriesRuntimeHints}. + * + * @author Andy Wilkinson + */ +public class ClientHttpRequestFactoriesRuntimeHintsTests { + + @Test + void shouldRegisterHints() { + RuntimeHints hints = new RuntimeHints(); + new ClientHttpRequestFactoriesRuntimeHints().registerHints(hints, getClass().getClassLoader()); + ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection(); + assertThat(reflection + .onField(ReflectionUtils.findField(AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"))) + .accepts(hints); + } + + @Test + void shouldRegisterHttpComponentHints() { + RuntimeHints hints = new RuntimeHints(); + new ClientHttpRequestFactoriesRuntimeHints().registerHints(hints, getClass().getClassLoader()); + ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection(); + assertThat(reflection.onMethod(ReflectionUtils.findMethod(HttpComponentsClientHttpRequestFactory.class, + "setConnectTimeout", int.class))).accepts(hints); + assertThat(reflection.onMethod( + ReflectionUtils.findMethod(HttpComponentsClientHttpRequestFactory.class, "setReadTimeout", int.class))) + .accepts(hints); + assertThat(reflection.onMethod(ReflectionUtils.findMethod(HttpComponentsClientHttpRequestFactory.class, + "setBufferRequestBody", boolean.class))).accepts(hints); + } + + @Test + void shouldRegisterOkHttpHints() { + RuntimeHints hints = new RuntimeHints(); + new ClientHttpRequestFactoriesRuntimeHints().registerHints(hints, getClass().getClassLoader()); + ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection(); + assertThat(reflection.onMethod( + ReflectionUtils.findMethod(OkHttp3ClientHttpRequestFactory.class, "setConnectTimeout", int.class))) + .accepts(hints); + assertThat(reflection.onMethod( + ReflectionUtils.findMethod(OkHttp3ClientHttpRequestFactory.class, "setReadTimeout", int.class))) + .accepts(hints); + } + + @Test + void shouldRegisterSimpleHttpHints() { + RuntimeHints hints = new RuntimeHints(); + new ClientHttpRequestFactoriesRuntimeHints().registerHints(hints, getClass().getClassLoader()); + ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection(); + assertThat(reflection.onMethod( + ReflectionUtils.findMethod(SimpleClientHttpRequestFactory.class, "setConnectTimeout", int.class))) + .accepts(hints); + assertThat(reflection.onMethod( + ReflectionUtils.findMethod(SimpleClientHttpRequestFactory.class, "setReadTimeout", int.class))) + .accepts(hints); + assertThat(reflection.onMethod(ReflectionUtils.findMethod(SimpleClientHttpRequestFactory.class, + "setBufferRequestBody", boolean.class))).accepts(hints); + } + +} diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesSimpleTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesSimpleTests.java new file mode 100644 index 00000000000..a9e75aa6496 --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesSimpleTests.java @@ -0,0 +1,47 @@ +/* + * Copyright 2012-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import org.springframework.boot.testsupport.classpath.ClassPathExclusions; +import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.test.util.ReflectionTestUtils; + +/** + * Tests for {@link ClientHttpRequestFactories} when the simple JDK-based client is the + * predominant HTTP client. + * + * @author Andy Wilkinson + */ +@ClassPathExclusions({ "httpclient5-*.jar", "okhttp-*.jar" }) +class ClientHttpRequestFactoriesSimpleTests + extends AbstractClientHttpRequestFactoriesTests { + + ClientHttpRequestFactoriesSimpleTests() { + super(SimpleClientHttpRequestFactory.class); + } + + @Override + protected long connectTimeout(SimpleClientHttpRequestFactory requestFactory) { + return (int) ReflectionTestUtils.getField(requestFactory, "connectTimeout"); + } + + @Override + protected long readTimeout(SimpleClientHttpRequestFactory requestFactory) { + return (int) ReflectionTestUtils.getField(requestFactory, "readTimeout"); + } + +} diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesTests.java new file mode 100644 index 00000000000..4bff39170bb --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactoriesTests.java @@ -0,0 +1,247 @@ +/* + * Copyright 2012-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.io.IOException; +import java.net.URI; +import java.time.Duration; + +import org.junit.jupiter.api.Test; + +import org.springframework.http.HttpMethod; +import org.springframework.http.client.BufferingClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; +import org.springframework.http.client.SimpleClientHttpRequestFactory; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; + +/** + * Tests for {@link ClientHttpRequestFactories}. + * + * @author Andy Wilkinson + */ +class ClientHttpRequestFactoriesTests { + + @Test + void getReturnsRequestFactoryOfExpectedType() { + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories + .get(ClientHttpRequestFactorySettings.DEFAULTS); + assertThat(requestFactory).isInstanceOf(HttpComponentsClientHttpRequestFactory.class); + } + + @Test + void getOfGeneralTypeReturnsRequestFactoryOfExpectedType() { + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories.get(ClientHttpRequestFactory.class, + ClientHttpRequestFactorySettings.DEFAULTS); + assertThat(requestFactory).isInstanceOf(HttpComponentsClientHttpRequestFactory.class); + } + + @Test + void getOfSimpleFactoryReturnsSimpleFactory() { + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories.get(SimpleClientHttpRequestFactory.class, + ClientHttpRequestFactorySettings.DEFAULTS); + assertThat(requestFactory).isInstanceOf(SimpleClientHttpRequestFactory.class); + } + + @Test + void getOfHttpComponentsFactoryReturnsHttpComponentsFactory() { + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories + .get(HttpComponentsClientHttpRequestFactory.class, ClientHttpRequestFactorySettings.DEFAULTS); + assertThat(requestFactory).isInstanceOf(HttpComponentsClientHttpRequestFactory.class); + } + + @Test + void getOfOkHttpFactoryReturnsOkHttpFactory() { + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories.get(OkHttp3ClientHttpRequestFactory.class, + ClientHttpRequestFactorySettings.DEFAULTS); + assertThat(requestFactory).isInstanceOf(OkHttp3ClientHttpRequestFactory.class); + } + + @Test + void getOfUnknownTypeCreatesFactory() { + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories.get(TestClientHttpRequestFactory.class, + ClientHttpRequestFactorySettings.DEFAULTS); + assertThat(requestFactory).isInstanceOf(TestClientHttpRequestFactory.class); + } + + @Test + void getOfUnknownTypeWithConnectTimeoutCreatesFactoryAndConfiguresConnectTimeout() { + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories.get(TestClientHttpRequestFactory.class, + ClientHttpRequestFactorySettings.DEFAULTS.withConnectTimeout(Duration.ofSeconds(60))); + assertThat(requestFactory).isInstanceOf(TestClientHttpRequestFactory.class); + assertThat(((TestClientHttpRequestFactory) requestFactory).connectTimeout) + .isEqualTo(Duration.ofSeconds(60).toMillis()); + } + + @Test + void getOfUnknownTypeWithReadTimeoutCreatesFactoryAndConfiguresReadTimeout() { + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories.get(TestClientHttpRequestFactory.class, + ClientHttpRequestFactorySettings.DEFAULTS.withReadTimeout(Duration.ofSeconds(90))); + assertThat(requestFactory).isInstanceOf(TestClientHttpRequestFactory.class); + assertThat(((TestClientHttpRequestFactory) requestFactory).readTimeout) + .isEqualTo(Duration.ofSeconds(90).toMillis()); + } + + @Test + void getOfUnknownTypeWithBodyBufferingCreatesFactoryAndConfiguresBodyBuffering() { + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories.get(TestClientHttpRequestFactory.class, + ClientHttpRequestFactorySettings.DEFAULTS.withBufferRequestBody(true)); + assertThat(requestFactory).isInstanceOf(TestClientHttpRequestFactory.class); + assertThat(((TestClientHttpRequestFactory) requestFactory).bufferRequestBody).isTrue(); + } + + @Test + void getOfUnconfigurableTypeWithConnectTimeoutThrows() { + assertThatIllegalStateException() + .isThrownBy(() -> ClientHttpRequestFactories.get(UnconfigurableClientHttpRequestFactory.class, + ClientHttpRequestFactorySettings.DEFAULTS.withConnectTimeout(Duration.ofSeconds(60)))) + .withMessageContaining("suitable setConnectTimeout method"); + } + + @Test + void getOfUnconfigurableTypeWithReadTimeoutThrows() { + assertThatIllegalStateException() + .isThrownBy(() -> ClientHttpRequestFactories.get(UnconfigurableClientHttpRequestFactory.class, + ClientHttpRequestFactorySettings.DEFAULTS.withReadTimeout(Duration.ofSeconds(60)))) + .withMessageContaining("suitable setReadTimeout method"); + } + + @Test + void getOfUnconfigurableTypeWithBodyBufferingThrows() { + assertThatIllegalStateException() + .isThrownBy(() -> ClientHttpRequestFactories.get(UnconfigurableClientHttpRequestFactory.class, + ClientHttpRequestFactorySettings.DEFAULTS.withBufferRequestBody(true))) + .withMessageContaining("suitable setBufferRequestBody method"); + } + + @Test + void getOfTypeWithDeprecatedConnectTimeoutThrowsWithConnectTimeout() { + assertThatIllegalStateException() + .isThrownBy(() -> ClientHttpRequestFactories.get(DeprecatedMethodsClientHttpRequestFactory.class, + ClientHttpRequestFactorySettings.DEFAULTS.withConnectTimeout(Duration.ofSeconds(60)))) + .withMessageContaining("setConnectTimeout method marked as deprecated"); + } + + @Test + void getOfTypeWithDeprecatedReadTimeoutThrowsWithReadTimeout() { + assertThatIllegalStateException() + .isThrownBy(() -> ClientHttpRequestFactories.get(DeprecatedMethodsClientHttpRequestFactory.class, + ClientHttpRequestFactorySettings.DEFAULTS.withReadTimeout(Duration.ofSeconds(60)))) + .withMessageContaining("setReadTimeout method marked as deprecated"); + } + + @Test + void getOfTypeWithDeprecatedBufferRequestBodyThrowsWithBufferRequestBody() { + assertThatIllegalStateException() + .isThrownBy(() -> ClientHttpRequestFactories.get(DeprecatedMethodsClientHttpRequestFactory.class, + ClientHttpRequestFactorySettings.DEFAULTS.withBufferRequestBody(false))) + .withMessageContaining("setBufferRequestBody method marked as deprecated"); + } + + @Test + void connectTimeoutCanBeConfiguredOnAWrappedRequestFactory() { + SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); + BufferingClientHttpRequestFactory result = ClientHttpRequestFactories.get( + () -> new BufferingClientHttpRequestFactory(requestFactory), + ClientHttpRequestFactorySettings.DEFAULTS.withConnectTimeout(Duration.ofMillis(1234))); + assertThat(result).extracting("requestFactory").isSameAs(requestFactory); + assertThat(requestFactory).hasFieldOrPropertyWithValue("connectTimeout", 1234); + } + + @Test + void readTimeoutCanBeConfiguredOnAWrappedRequestFactory() { + SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); + BufferingClientHttpRequestFactory result = ClientHttpRequestFactories.get( + () -> new BufferingClientHttpRequestFactory(requestFactory), + ClientHttpRequestFactorySettings.DEFAULTS.withReadTimeout(Duration.ofMillis(1234))); + assertThat(result).extracting("requestFactory").isSameAs(requestFactory); + assertThat(requestFactory).hasFieldOrPropertyWithValue("readTimeout", 1234); + } + + @Test + void bufferRequestBodyCanBeConfiguredOnAWrappedRequestFactory() { + SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); + assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true); + BufferingClientHttpRequestFactory result = ClientHttpRequestFactories.get( + () -> new BufferingClientHttpRequestFactory(requestFactory), + ClientHttpRequestFactorySettings.DEFAULTS.withBufferRequestBody(false)); + assertThat(result).extracting("requestFactory").isSameAs(requestFactory); + assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", false); + } + + public static class TestClientHttpRequestFactory implements ClientHttpRequestFactory { + + private int connectTimeout; + + private int readTimeout; + + private boolean bufferRequestBody; + + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + throw new UnsupportedOperationException(); + } + + public void setConnectTimeout(int timeout) { + this.connectTimeout = timeout; + } + + public void setReadTimeout(int timeout) { + this.readTimeout = timeout; + } + + public void setBufferRequestBody(boolean bufferRequestBody) { + this.bufferRequestBody = bufferRequestBody; + } + + } + + public static class UnconfigurableClientHttpRequestFactory implements ClientHttpRequestFactory { + + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + throw new UnsupportedOperationException(); + } + + } + + public static class DeprecatedMethodsClientHttpRequestFactory implements ClientHttpRequestFactory { + + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + throw new UnsupportedOperationException(); + } + + @Deprecated(since = "3.0.0", forRemoval = false) + public void setConnectTimeout(int timeout) { + } + + @Deprecated(since = "3.0.0", forRemoval = false) + public void setReadTimeout(int timeout) { + } + + @Deprecated(since = "3.0.0", forRemoval = false) + public void setBufferRequestBody(boolean bufferRequestBody) { + } + + } + +} diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactorySettingsTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactorySettingsTests.java new file mode 100644 index 00000000000..e143d266173 --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactorySettingsTests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2012-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.web.client; + +import java.time.Duration; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link ClientHttpRequestFactorySettings}. + * + * @author Phillip Webb + */ +class ClientHttpRequestFactorySettingsTests { + + private static final Duration ONE_SECOND = Duration.ofSeconds(1); + + @Test + void defaultsHasNullValues() { + ClientHttpRequestFactorySettings settings = ClientHttpRequestFactorySettings.DEFAULTS; + assertThat(settings.connectTimeout()).isNull(); + assertThat(settings.readTimeout()).isNull(); + assertThat(settings.bufferRequestBody()).isNull(); + } + + @Test + void withConnectTimeoutReturnsInstanceWithUpdatedConnectionTimeout() { + ClientHttpRequestFactorySettings settings = ClientHttpRequestFactorySettings.DEFAULTS + .withConnectTimeout(ONE_SECOND); + assertThat(settings.connectTimeout()).isEqualTo(ONE_SECOND); + assertThat(settings.readTimeout()).isNull(); + assertThat(settings.bufferRequestBody()).isNull(); + } + + @Test + void withReadTimeoutReturnsInstanceWithUpdatedReadTimeout() { + ClientHttpRequestFactorySettings settings = ClientHttpRequestFactorySettings.DEFAULTS + .withReadTimeout(ONE_SECOND); + assertThat(settings.connectTimeout()).isNull(); + assertThat(settings.readTimeout()).isEqualTo(ONE_SECOND); + assertThat(settings.bufferRequestBody()).isNull(); + } + + @Test + void withBufferRequestBodyReturnsInstanceWithUpdatedBufferRequestBody() { + ClientHttpRequestFactorySettings settings = ClientHttpRequestFactorySettings.DEFAULTS + .withBufferRequestBody(true); + assertThat(settings.connectTimeout()).isNull(); + assertThat(settings.readTimeout()).isNull(); + assertThat(settings.bufferRequestBody()).isTrue(); + } + +} diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java index 2f7e12cff5c..47f29318e26 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java @@ -18,28 +18,21 @@ package org.springframework.boot.web.client; import java.net.URI; import java.nio.charset.StandardCharsets; -import java.time.Duration; import java.util.Arrays; import java.util.Collections; import java.util.Set; +import java.util.function.Function; import java.util.function.Supplier; -import okhttp3.OkHttpClient; -import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.aot.hint.RuntimeHints; -import org.springframework.aot.hint.predicate.ReflectionHintsPredicates; -import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; -import org.springframework.boot.web.client.RestTemplateBuilder.RestTemplateBuilderRuntimeHints; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; -import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.BufferingClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequestFactory; @@ -47,22 +40,18 @@ import org.springframework.http.client.ClientHttpRequestInitializer; import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.http.client.InterceptingClientHttpRequestFactory; -import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.ResourceHttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.client.MockRestServiceServer; -import org.springframework.util.ReflectionUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestTemplate; import org.springframework.web.util.UriTemplateHandler; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.assertj.core.api.Assertions.assertThatIllegalStateException; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.entry; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.then; @@ -255,7 +244,7 @@ class RestTemplateBuilderTests { void requestFactoryClassWhenFactoryIsNullShouldThrowException() { assertThatIllegalArgumentException() .isThrownBy(() -> this.builder.requestFactory((Class) null)) - .withMessageContaining("RequestFactory must not be null"); + .withMessageContaining("RequestFactoryType must not be null"); } @Test @@ -274,7 +263,15 @@ class RestTemplateBuilderTests { void requestFactoryWhenSupplierIsNullShouldThrowException() { assertThatIllegalArgumentException() .isThrownBy(() -> this.builder.requestFactory((Supplier) null)) - .withMessageContaining("RequestFactory Supplier must not be null"); + .withMessageContaining("RequestFactorySupplier must not be null"); + } + + @Test + void requestFactoryWhenFunctionIsNullShouldThrowException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.builder + .requestFactory((Function) null)) + .withMessageContaining("RequestFactoryFunction must not be null"); } @Test @@ -459,128 +456,6 @@ class RestTemplateBuilderTests { assertThat(template.getRequestFactory()).isInstanceOf(HttpComponentsClientHttpRequestFactory.class); } - @Test - void connectTimeoutCanBeNullToUseDefault() { - ClientHttpRequestFactory requestFactory = this.builder.requestFactory(SimpleClientHttpRequestFactory.class) - .setConnectTimeout(null).build().getRequestFactory(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("connectTimeout", -1); - } - - @Test - void readTimeoutCanBeNullToUseDefault() { - ClientHttpRequestFactory requestFactory = this.builder.requestFactory(SimpleClientHttpRequestFactory.class) - .setReadTimeout(null).build().getRequestFactory(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("readTimeout", -1); - } - - @Test - void connectTimeoutCanBeConfiguredOnHttpComponentsRequestFactory() { - ClientHttpRequestFactory requestFactory = this.builder - .requestFactory(HttpComponentsClientHttpRequestFactory.class).setConnectTimeout(Duration.ofMillis(1234)) - .build().getRequestFactory(); - assertThat(((int) ReflectionTestUtils.getField(requestFactory, "connectTimeout"))).isEqualTo(1234); - } - - @Test - void readTimeoutConfigurationFailsOnHttpComponentsRequestFactory() { - assertThatThrownBy(() -> this.builder.requestFactory(HttpComponentsClientHttpRequestFactory.class) - .setReadTimeout(Duration.ofMillis(1234)).build()).isInstanceOf(IllegalStateException.class) - .hasMessageContaining("setReadTimeout method marked as deprecated"); - } - - @Test - void bufferRequestBodyCanBeConfiguredOnHttpComponentsRequestFactory() { - ClientHttpRequestFactory requestFactory = this.builder - .requestFactory(HttpComponentsClientHttpRequestFactory.class).setBufferRequestBody(false).build() - .getRequestFactory(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", false); - requestFactory = this.builder.requestFactory(HttpComponentsClientHttpRequestFactory.class) - .setBufferRequestBody(true).build().getRequestFactory(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true); - requestFactory = this.builder.requestFactory(HttpComponentsClientHttpRequestFactory.class).build() - .getRequestFactory(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true); - } - - @Test - void connectTimeoutCanBeConfiguredOnSimpleRequestFactory() { - ClientHttpRequestFactory requestFactory = this.builder.requestFactory(SimpleClientHttpRequestFactory.class) - .setConnectTimeout(Duration.ofMillis(1234)).build().getRequestFactory(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("connectTimeout", 1234); - } - - @Test - void readTimeoutCanBeConfiguredOnSimpleRequestFactory() { - ClientHttpRequestFactory requestFactory = this.builder.requestFactory(SimpleClientHttpRequestFactory.class) - .setReadTimeout(Duration.ofMillis(1234)).build().getRequestFactory(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("readTimeout", 1234); - } - - @Test - void bufferRequestBodyCanBeConfiguredOnSimpleRequestFactory() { - ClientHttpRequestFactory requestFactory = this.builder.requestFactory(SimpleClientHttpRequestFactory.class) - .setBufferRequestBody(false).build().getRequestFactory(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", false); - requestFactory = this.builder.requestFactory(SimpleClientHttpRequestFactory.class).setBufferRequestBody(true) - .build().getRequestFactory(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true); - requestFactory = this.builder.requestFactory(SimpleClientHttpRequestFactory.class).build().getRequestFactory(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true); - } - - @Test - void connectTimeoutCanBeConfiguredOnOkHttpRequestFactory() { - ClientHttpRequestFactory requestFactory = this.builder.requestFactory(OkHttp3ClientHttpRequestFactory.class) - .setConnectTimeout(Duration.ofMillis(1234)).build().getRequestFactory(); - assertThat(requestFactory).extracting("client", InstanceOfAssertFactories.type(OkHttpClient.class)) - .extracting(OkHttpClient::connectTimeoutMillis).isEqualTo(1234); - } - - @Test - void readTimeoutCanBeConfiguredOnOkHttp3RequestFactory() { - ClientHttpRequestFactory requestFactory = this.builder.requestFactory(OkHttp3ClientHttpRequestFactory.class) - .setReadTimeout(Duration.ofMillis(1234)).build().getRequestFactory(); - assertThat(requestFactory).extracting("client", InstanceOfAssertFactories.type(OkHttpClient.class)) - .extracting(OkHttpClient::readTimeoutMillis).isEqualTo(1234); - } - - @Test - void bufferRequestBodyCanNotBeConfiguredOnOkHttp3RequestFactory() { - assertThatIllegalStateException() - .isThrownBy(() -> this.builder.requestFactory(OkHttp3ClientHttpRequestFactory.class) - .setBufferRequestBody(false).build().getRequestFactory()) - .withMessageContaining(OkHttp3ClientHttpRequestFactory.class.getName()); - } - - @Test - void connectTimeoutCanBeConfiguredOnAWrappedRequestFactory() { - SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); - this.builder.requestFactory(() -> new BufferingClientHttpRequestFactory(requestFactory)) - .setConnectTimeout(Duration.ofMillis(1234)).build(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("connectTimeout", 1234); - } - - @Test - void readTimeoutCanBeConfiguredOnAWrappedRequestFactory() { - SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); - this.builder.requestFactory(() -> new BufferingClientHttpRequestFactory(requestFactory)) - .setReadTimeout(Duration.ofMillis(1234)).build(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("readTimeout", 1234); - } - - @Test - void bufferRequestBodyCanBeConfiguredOnAWrappedRequestFactory() { - SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); - this.builder.requestFactory(() -> new BufferingClientHttpRequestFactory(requestFactory)) - .setBufferRequestBody(false).build(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", false); - this.builder.requestFactory(() -> new BufferingClientHttpRequestFactory(requestFactory)) - .setBufferRequestBody(true).build(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true); - this.builder.requestFactory(() -> new BufferingClientHttpRequestFactory(requestFactory)).build(); - assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true); - } - @Test void unwrappingDoesNotAffectRequestFactoryThatIsSetOnTheBuiltTemplate() { SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); @@ -589,58 +464,6 @@ class RestTemplateBuilderTests { assertThat(template.getRequestFactory()).isInstanceOf(BufferingClientHttpRequestFactory.class); } - @Test - void shouldRegisterHints() { - RuntimeHints hints = new RuntimeHints(); - new RestTemplateBuilderRuntimeHints().registerHints(hints, getClass().getClassLoader()); - ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection(); - assertThat(reflection - .onField(ReflectionUtils.findField(AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"))) - .accepts(hints); - } - - @Test - void shouldRegisterHttpComponentHints() { - RuntimeHints hints = new RuntimeHints(); - new RestTemplateBuilderRuntimeHints().registerHints(hints, getClass().getClassLoader()); - ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection(); - assertThat(reflection.onMethod(ReflectionUtils.findMethod(HttpComponentsClientHttpRequestFactory.class, - "setConnectTimeout", int.class))).accepts(hints); - assertThat(reflection.onMethod( - ReflectionUtils.findMethod(HttpComponentsClientHttpRequestFactory.class, "setReadTimeout", int.class))) - .accepts(hints); - assertThat(reflection.onMethod(ReflectionUtils.findMethod(HttpComponentsClientHttpRequestFactory.class, - "setBufferRequestBody", boolean.class))).accepts(hints); - } - - @Test - void shouldRegisterOkHttpHints() { - RuntimeHints hints = new RuntimeHints(); - new RestTemplateBuilderRuntimeHints().registerHints(hints, getClass().getClassLoader()); - ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection(); - assertThat(reflection.onMethod( - ReflectionUtils.findMethod(OkHttp3ClientHttpRequestFactory.class, "setConnectTimeout", int.class))) - .accepts(hints); - assertThat(reflection.onMethod( - ReflectionUtils.findMethod(OkHttp3ClientHttpRequestFactory.class, "setReadTimeout", int.class))) - .accepts(hints); - } - - @Test - void shouldRegisterSimpleHttpHints() { - RuntimeHints hints = new RuntimeHints(); - new RestTemplateBuilderRuntimeHints().registerHints(hints, getClass().getClassLoader()); - ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection(); - assertThat(reflection.onMethod( - ReflectionUtils.findMethod(SimpleClientHttpRequestFactory.class, "setConnectTimeout", int.class))) - .accepts(hints); - assertThat(reflection.onMethod( - ReflectionUtils.findMethod(SimpleClientHttpRequestFactory.class, "setReadTimeout", int.class))) - .accepts(hints); - assertThat(reflection.onMethod(ReflectionUtils.findMethod(SimpleClientHttpRequestFactory.class, - "setBufferRequestBody", boolean.class))).accepts(hints); - } - private ClientHttpRequest createRequest(RestTemplate template) { return ReflectionTestUtils.invokeMethod(template, "createRequest", URI.create("http://localhost"), HttpMethod.GET); @@ -654,4 +477,8 @@ class RestTemplateBuilderTests { } + static class TestHttpComponentsClientHttpRequestFactory extends HttpComponentsClientHttpRequestFactory { + + } + } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTestsOkHttp3Tests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTestsOkHttp3Tests.java deleted file mode 100644 index 24b6c20e193..00000000000 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTestsOkHttp3Tests.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright 2012-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.boot.web.client; - -import java.time.Duration; - -import okhttp3.OkHttpClient; -import org.assertj.core.api.InstanceOfAssertFactories; -import org.junit.jupiter.api.Test; - -import org.springframework.boot.testsupport.classpath.ClassPathOverrides; -import org.springframework.http.client.ClientHttpRequestFactory; -import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalStateException; - -/** - * Tests for {@link RestTemplateBuilder} with OkHttp 3.x. - * - * @author Andy Wilkinson - */ -@ClassPathOverrides("com.squareup.okhttp3:okhttp:3.14.9") -class RestTemplateBuilderTestsOkHttp3Tests { - - private RestTemplateBuilder builder = new RestTemplateBuilder(); - - @Test - void connectTimeoutCanBeConfiguredOnOkHttpRequestFactory() { - ClientHttpRequestFactory requestFactory = this.builder.requestFactory(OkHttp3ClientHttpRequestFactory.class) - .setConnectTimeout(Duration.ofMillis(1234)).build().getRequestFactory(); - assertThat(requestFactory).extracting("client", InstanceOfAssertFactories.type(OkHttpClient.class)) - .extracting(OkHttpClient::connectTimeoutMillis).isEqualTo(1234); - } - - @Test - void readTimeoutCanBeConfiguredOnOkHttpRequestFactory() { - ClientHttpRequestFactory requestFactory = this.builder.requestFactory(OkHttp3ClientHttpRequestFactory.class) - .setReadTimeout(Duration.ofMillis(1234)).build().getRequestFactory(); - assertThat(requestFactory).extracting("client", InstanceOfAssertFactories.type(OkHttpClient.class)) - .extracting(OkHttpClient::readTimeoutMillis).isEqualTo(1234); - } - - @Test - void bufferRequestBodyCanNotBeConfiguredOnOkHttpRequestFactory() { - assertThatIllegalStateException() - .isThrownBy(() -> this.builder.requestFactory(OkHttp3ClientHttpRequestFactory.class) - .setBufferRequestBody(false).build().getRequestFactory()) - .withMessageContaining(OkHttp3ClientHttpRequestFactory.class.getName()); - } - -} diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/webservices/client/HttpWebServiceMessageSenderBuilderTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/webservices/client/HttpWebServiceMessageSenderBuilderTests.java index f4de70a23c0..67380ed4806 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/webservices/client/HttpWebServiceMessageSenderBuilderTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/webservices/client/HttpWebServiceMessageSenderBuilderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2019 the original author or authors. + * Copyright 2012-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -57,8 +57,8 @@ class HttpWebServiceMessageSenderBuilderTests { @Test void buildUsesHttpComponentsByDefault() { - ClientHttpRequestMessageSender messageSender = build( - new HttpWebServiceMessageSenderBuilder().setConnectTimeout(Duration.ofSeconds(5))); + ClientHttpRequestMessageSender messageSender = build(new HttpWebServiceMessageSenderBuilder() + .setConnectTimeout(Duration.ofSeconds(5)).setReadTimeout(Duration.ofSeconds(5))); ClientHttpRequestFactory requestFactory = messageSender.getRequestFactory(); assertThat(requestFactory).isInstanceOf(HttpComponentsClientHttpRequestFactory.class); } diff --git a/spring-boot-system-tests/spring-boot-deployment-tests/src/systemTest/java/org/springframework/boot/deployment/AbstractDeploymentTests.java b/spring-boot-system-tests/spring-boot-deployment-tests/src/systemTest/java/org/springframework/boot/deployment/AbstractDeploymentTests.java index b02a6d26b42..3dcedcd68c9 100644 --- a/spring-boot-system-tests/spring-boot-deployment-tests/src/systemTest/java/org/springframework/boot/deployment/AbstractDeploymentTests.java +++ b/spring-boot-system-tests/spring-boot-deployment-tests/src/systemTest/java/org/springframework/boot/deployment/AbstractDeploymentTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2021 the original author or authors. + * Copyright 2012-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.