This commit is contained in:
Phillip Webb 2023-04-07 10:36:18 -04:00
parent 1849b82334
commit 2951cc7594
3 changed files with 61 additions and 88 deletions

View File

@ -19,6 +19,7 @@ package org.springframework.boot.autoconfigure.service.connection;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Stream;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
@ -35,76 +36,62 @@ import org.springframework.core.style.ToStringCreator;
*/
public class ConnectionDetailsFactories {
private List<FactoryDetails> registeredFactories = new ArrayList<>();
private List<Registration<?, ?>> registrations = new ArrayList<>();
public ConnectionDetailsFactories() {
this(SpringFactoriesLoader.forDefaultResourceLocation(ConnectionDetailsFactory.class.getClassLoader()));
}
@SuppressWarnings("rawtypes")
@SuppressWarnings({ "rawtypes", "unchecked" })
ConnectionDetailsFactories(SpringFactoriesLoader loader) {
List<ConnectionDetailsFactory> factories = loader.load(ConnectionDetailsFactory.class);
factories.stream().map(this::factoryDetails).filter(Objects::nonNull).forEach(this::register);
Stream<Registration<?, ?>> registrations = factories.stream().map(Registration::get);
registrations.filter(Objects::nonNull).forEach(this.registrations::add);
}
@SuppressWarnings("unchecked")
private FactoryDetails factoryDetails(ConnectionDetailsFactory<?, ?> factory) {
ResolvableType connectionDetailsFactory = findConnectionDetailsFactory(
ResolvableType.forClass(factory.getClass()));
if (connectionDetailsFactory != null) {
ResolvableType input = connectionDetailsFactory.getGeneric(0);
ResolvableType output = connectionDetailsFactory.getGeneric(1);
return new FactoryDetails(input.getRawClass(), (Class<? extends ConnectionDetails>) output.getRawClass(),
factory);
}
return null;
}
private ResolvableType findConnectionDetailsFactory(ResolvableType type) {
try {
ResolvableType[] interfaces = type.getInterfaces();
for (ResolvableType iface : interfaces) {
if (iface.getRawClass().equals(ConnectionDetailsFactory.class)) {
return iface;
}
}
}
catch (TypeNotPresentException ex) {
// A type referenced by the factory is not present. Skip it.
}
ResolvableType superType = type.getSuperType();
return ResolvableType.NONE.equals(superType) ? null : findConnectionDetailsFactory(superType);
}
private void register(FactoryDetails details) {
this.registeredFactories.add(details);
public <S> ConnectionDetails getConnectionDetails(S source) {
return getConnectionDetailsFactory(source).getConnectionDetails(source);
}
@SuppressWarnings("unchecked")
public <S> ConnectionDetailsFactory<S, ConnectionDetails> getConnectionDetailsFactory(S source) {
Class<S> input = (Class<S>) source.getClass();
List<ConnectionDetailsFactory<S, ConnectionDetails>> matchingFactories = new ArrayList<>();
for (FactoryDetails factoryDetails : this.registeredFactories) {
if (factoryDetails.input.isAssignableFrom(input)) {
matchingFactories.add((ConnectionDetailsFactory<S, ConnectionDetails>) factoryDetails.factory);
Class<S> sourceType = (Class<S>) source.getClass();
List<ConnectionDetailsFactory<S, ConnectionDetails>> result = new ArrayList<>();
for (Registration<?, ?> candidate : this.registrations) {
if (candidate.sourceType().isAssignableFrom(sourceType)) {
result.add((ConnectionDetailsFactory<S, ConnectionDetails>) candidate.factory());
}
}
if (matchingFactories.isEmpty()) {
if (result.isEmpty()) {
throw new ConnectionDetailsFactoryNotFoundException(source);
}
else {
if (matchingFactories.size() == 1) {
return matchingFactories.get(0);
AnnotationAwareOrderComparator.sort(result);
return (result.size() != 1) ? new CompositeConnectionDetailsFactory<>(result) : result.get(0);
}
/**
* A {@link ConnectionDetailsFactory} registration.
*/
private record Registration<S, D extends ConnectionDetails>(Class<S> sourceType, Class<D> connectionDetailsType,
ConnectionDetailsFactory<S, D> factory) {
@SuppressWarnings("unchecked")
private static <S, D extends ConnectionDetails> Registration<S, D> get(ConnectionDetailsFactory<S, D> factory) {
ResolvableType type = ResolvableType.forClass(ConnectionDetailsFactory.class, factory.getClass());
if (!type.hasUnresolvableGenerics()) {
Class<?>[] generics = type.resolveGenerics();
return new Registration<>((Class<S>) generics[0], (Class<D>) generics[1], factory);
}
AnnotationAwareOrderComparator.sort(matchingFactories);
return new CompositeConnectionDetailsFactory<>(matchingFactories);
return null;
}
}
private record FactoryDetails(Class<?> input, Class<? extends ConnectionDetails> output,
ConnectionDetailsFactory<?, ?> factory) {
}
/**
* Composite {@link ConnectionDetailsFactory} implementation.
*
* @param <S> the source type
*/
static class CompositeConnectionDetailsFactory<S> implements ConnectionDetailsFactory<S, ConnectionDetails> {
private final List<ConnectionDetailsFactory<S, ConnectionDetails>> delegates;
@ -114,15 +101,16 @@ public class ConnectionDetailsFactories {
}
@Override
@SuppressWarnings("unchecked")
public ConnectionDetails getConnectionDetails(Object source) {
for (ConnectionDetailsFactory<S, ConnectionDetails> delegate : this.delegates) {
ConnectionDetails connectionDetails = delegate.getConnectionDetails((S) source);
if (connectionDetails != null) {
return connectionDetails;
}
}
return null;
public ConnectionDetails getConnectionDetails(S source) {
return this.delegates.stream()
.map((delegate) -> delegate.getConnectionDetails(source))
.filter(Objects::nonNull)
.findFirst()
.orElse(null);
}
List<ConnectionDetailsFactory<S, ConnectionDetails>> getDelegates() {
return this.delegates;
}
@Override
@ -130,10 +118,6 @@ public class ConnectionDetailsFactories {
return new ToStringCreator(this).append("delegates", this.delegates).toString();
}
List<ConnectionDetailsFactory<S, ConnectionDetails>> getDelegates() {
return this.delegates;
}
}
}

View File

@ -51,19 +51,8 @@ class RedisContainerConnectionDetailsFactory
private RedisContainerConnectionDetails(
ContainerConnectionSource<RedisServiceConnection, RedisConnectionDetails, GenericContainer<?>> source) {
super(source);
this.standalone = new Standalone() {
@Override
public String getHost() {
return source.getContainer().getHost();
}
@Override
public int getPort() {
return source.getContainer().getFirstMappedPort();
}
};
this.standalone = Standalone.of(source.getContainer().getHost(),
source.getContainer().getFirstMappedPort());
}
@Override

View File

@ -17,6 +17,7 @@
package org.springframework.boot.test.autoconfigure.service.connection;
import java.util.List;
import java.util.function.Supplier;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
@ -24,7 +25,6 @@ import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails;
import org.springframework.boot.autoconfigure.service.connection.ConnectionDetailsFactories;
import org.springframework.boot.autoconfigure.service.connection.ConnectionDetailsFactory;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.test.context.ContextCustomizer;
import org.springframework.test.context.MergedContextConfiguration;
@ -61,21 +61,21 @@ class ServiceConnectionContextCustomizer implements ContextCustomizer {
private void registerServiceConnection(BeanDefinitionRegistry registry, ContainerConnectionSource<?, ?, ?> source) {
ConnectionDetails connectionDetails = getConnectionDetails(source);
String beanName = source.getBeanName();
registry.registerBeanDefinition(beanName, createBeanDefinition(connectionDetails));
}
private <S> ConnectionDetails getConnectionDetails(S source) {
ConnectionDetailsFactory<S, ConnectionDetails> factory = this.factories.getConnectionDetailsFactory(source);
ConnectionDetails connectionDetails = factory.getConnectionDetails(source);
Assert.state(connectionDetails != null,
() -> "No connection details created by %s".formatted(factory.getClass().getName()));
return connectionDetails;
register(connectionDetails, registry, source.getBeanName());
}
@SuppressWarnings("unchecked")
private <T> BeanDefinition createBeanDefinition(T instance) {
return new RootBeanDefinition((Class<T>) instance.getClass(), () -> instance);
private <T> void register(ConnectionDetails connectionDetails, BeanDefinitionRegistry registry, String beanName) {
Class<T> beanType = (Class<T>) connectionDetails.getClass();
Supplier<T> beanSupplier = () -> (T) connectionDetails;
BeanDefinition beanDefinition = new RootBeanDefinition(beanType, beanSupplier);
registry.registerBeanDefinition(beanName, beanDefinition);
}
private <S> ConnectionDetails getConnectionDetails(S source) {
ConnectionDetails connectionDetails = this.factories.getConnectionDetails(source);
Assert.state(connectionDetails != null, () -> "No connection details created for %s".formatted(source));
return connectionDetails;
}
List<ContainerConnectionSource<?, ?, ?>> getSources() {