Add native-image support for RestTemplateBuilder

Closes gh-31888
This commit is contained in:
Moritz Halbritter 2022-08-01 14:48:58 +02:00
parent a3d4431d2e
commit ed1f6ad543
3 changed files with 113 additions and 19 deletions

View File

@ -16,13 +16,15 @@
package org.springframework.boot.web.client;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.springframework.beans.BeanUtils;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeHint.Builder;
import org.springframework.aot.hint.TypeReference;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.OkHttp3ClientHttpRequestFactory;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.util.ClassUtils;
@ -31,30 +33,38 @@ import org.springframework.util.ClassUtils;
* based on the available implementations on the classpath.
*
* @author Stephane Nicoll
* @author Moritz Halbritter
* @since 2.1.0
*/
public class ClientHttpRequestFactorySupplier implements Supplier<ClientHttpRequestFactory> {
private static final Map<String, String> REQUEST_FACTORY_CANDIDATES;
private static final boolean APACHE_HTTP_CLIENT_PRESENT = ClassUtils.isPresent("org.apache.http.client.HttpClient",
null);
static {
Map<String, String> candidates = new LinkedHashMap<>();
candidates.put("org.apache.http.client.HttpClient",
"org.springframework.http.client.HttpComponentsClientHttpRequestFactory");
candidates.put("okhttp3.OkHttpClient", "org.springframework.http.client.OkHttp3ClientHttpRequestFactory");
REQUEST_FACTORY_CANDIDATES = Collections.unmodifiableMap(candidates);
}
private static final boolean OKHTTP_CLIENT_PRESENT = ClassUtils.isPresent("okhttp3.OkHttpClient", null);
@Override
public ClientHttpRequestFactory get() {
for (Map.Entry<String, String> candidate : REQUEST_FACTORY_CANDIDATES.entrySet()) {
ClassLoader classLoader = getClass().getClassLoader();
if (ClassUtils.isPresent(candidate.getKey(), classLoader)) {
Class<?> factoryClass = ClassUtils.resolveClassName(candidate.getValue(), classLoader);
return (ClientHttpRequestFactory) BeanUtils.instantiateClass(factoryClass);
}
if (APACHE_HTTP_CLIENT_PRESENT) {
return new HttpComponentsClientHttpRequestFactory();
}
if (OKHTTP_CLIENT_PRESENT) {
return new OkHttp3ClientHttpRequestFactory();
}
return new SimpleClientHttpRequestFactory();
}
static class ClientHttpRequestFactorySupplierRuntimeHints {
static void registerHints(RuntimeHints hints, ClassLoader classLoader, Consumer<Builder> callback) {
hints.reflection().registerType(HttpComponentsClientHttpRequestFactory.class, (hint) -> callback
.accept(hint.onReachableType(TypeReference.of("org.apache.http.client.HttpClient"))));
hints.reflection().registerType(OkHttp3ClientHttpRequestFactory.class,
(hint) -> callback.accept(hint.onReachableType(TypeReference.of("okhttp3.OkHttpClient"))));
hints.reflection().registerType(SimpleClientHttpRequestFactory.class, (hint) -> callback
.accept(hint.onReachableType(TypeReference.of(SimpleClientHttpRequestFactory.class))));
}
}
}

View File

@ -29,13 +29,19 @@ import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Supplier;
import reactor.netty.http.client.HttpClientRequest;
import org.springframework.aot.hint.ExecutableMode;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.aot.hint.TypeReference;
import org.springframework.beans.BeanUtils;
import org.springframework.context.annotation.ImportRuntimeHints;
import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
@ -56,7 +62,7 @@ import org.springframework.web.util.UriTemplateHandler;
* converters}, {@link #errorHandler(ResponseErrorHandler) error handlers} and
* {@link #uriTemplateHandler(UriTemplateHandler) UriTemplateHandlers}.
* <p>
* By default the built {@link RestTemplate} will attempt to use the most suitable
* By default, the built {@link RestTemplate} will attempt to use the most suitable
* {@link ClientHttpRequestFactory}, call {@link #detectRequestFactory(boolean)
* detectRequestFactory(false)} if you prefer to keep the default. In a typical
* auto-configured Spring Boot application this builder is available as a bean and can be
@ -71,6 +77,7 @@ import org.springframework.web.util.UriTemplateHandler;
* @author Ilya Lukyanovich
* @since 1.4.0
*/
@ImportRuntimeHints(RestTemplateBuilder.RestTemplateBuilderRuntimeHints.class)
public class RestTemplateBuilder {
private final RequestFactoryCustomizer requestFactoryCustomizer;
@ -789,4 +796,23 @@ public class RestTemplateBuilder {
}
static class RestTemplateBuilderRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
hints.reflection().registerField(Objects.requireNonNull(
ReflectionUtils.findField(AbstractClientHttpRequestFactoryWrapper.class, "requestFactory")));
ClientHttpRequestFactorySupplier.ClientHttpRequestFactorySupplierRuntimeHints.registerHints(hints,
classLoader, (hint) -> {
hint.withMethod("setConnectTimeout", List.of(TypeReference.of(int.class)),
(method) -> method.withMode(ExecutableMode.INVOKE));
hint.withMethod("setReadTimeout", List.of(TypeReference.of(int.class)),
(method) -> method.withMode(ExecutableMode.INVOKE));
hint.withMethod("setBufferRequestBody", List.of(TypeReference.of(boolean.class)),
(method) -> method.withMode(ExecutableMode.INVOKE));
});
}
}
}

View File

@ -33,9 +33,14 @@ import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.predicate.ReflectionHintsPredicates;
import org.springframework.aot.hint.predicate.RuntimeHintsPredicates;
import org.springframework.boot.web.client.RestTemplateBuilder.RestTemplateBuilderRuntimeHints;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper;
import org.springframework.http.client.BufferingClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
@ -50,6 +55,7 @@ import org.springframework.http.converter.ResourceHttpMessageConverter;
import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.client.MockRestServiceServer;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriTemplateHandler;
@ -585,6 +591,58 @@ class RestTemplateBuilderTests {
assertThat(template.getRequestFactory()).isInstanceOf(BufferingClientHttpRequestFactory.class);
}
@Test
void shouldRegisterHints() {
RuntimeHints hints = new RuntimeHints();
new RestTemplateBuilderRuntimeHints().registerHints(hints, getClass().getClassLoader());
ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection();
assertThat(reflection
.onField(ReflectionUtils.findField(AbstractClientHttpRequestFactoryWrapper.class, "requestFactory")))
.accepts(hints);
}
@Test
void shouldRegisterHttpComponentHints() {
RuntimeHints hints = new RuntimeHints();
new RestTemplateBuilderRuntimeHints().registerHints(hints, getClass().getClassLoader());
ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection();
assertThat(reflection.onMethod(ReflectionUtils.findMethod(HttpComponentsClientHttpRequestFactory.class,
"setConnectTimeout", int.class))).accepts(hints);
assertThat(reflection.onMethod(
ReflectionUtils.findMethod(HttpComponentsClientHttpRequestFactory.class, "setReadTimeout", int.class)))
.accepts(hints);
assertThat(reflection.onMethod(ReflectionUtils.findMethod(HttpComponentsClientHttpRequestFactory.class,
"setBufferRequestBody", boolean.class))).accepts(hints);
}
@Test
void shouldRegisterOkHttpHints() {
RuntimeHints hints = new RuntimeHints();
new RestTemplateBuilderRuntimeHints().registerHints(hints, getClass().getClassLoader());
ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection();
assertThat(reflection.onMethod(
ReflectionUtils.findMethod(OkHttp3ClientHttpRequestFactory.class, "setConnectTimeout", int.class)))
.accepts(hints);
assertThat(reflection.onMethod(
ReflectionUtils.findMethod(OkHttp3ClientHttpRequestFactory.class, "setReadTimeout", int.class)))
.accepts(hints);
}
@Test
void shouldRegisterSimpleHttpHints() {
RuntimeHints hints = new RuntimeHints();
new RestTemplateBuilderRuntimeHints().registerHints(hints, getClass().getClassLoader());
ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection();
assertThat(reflection.onMethod(
ReflectionUtils.findMethod(SimpleClientHttpRequestFactory.class, "setConnectTimeout", int.class)))
.accepts(hints);
assertThat(reflection.onMethod(
ReflectionUtils.findMethod(SimpleClientHttpRequestFactory.class, "setReadTimeout", int.class)))
.accepts(hints);
assertThat(reflection.onMethod(ReflectionUtils.findMethod(SimpleClientHttpRequestFactory.class,
"setBufferRequestBody", boolean.class))).accepts(hints);
}
private ClientHttpRequest createRequest(RestTemplate template) {
return ReflectionTestUtils.invokeMethod(template, "createRequest", URI.create("http://localhost"),
HttpMethod.GET);