Don't call runners in parent ApplicationContext

Update `SpringApplication` so that `ApplicationRunner` and
`CommandLineRunner` beans are not considered from the parent
`ApplicationContext`.

The restores the behavior that applied before commit 7d6532cac4
whilst still retaining the correct run order.

Fixes gh-38647
This commit is contained in:
Phillip Webb 2023-12-15 08:31:29 -08:00
parent e63be1bf73
commit 13fb450563
2 changed files with 111 additions and 27 deletions

View File

@ -17,12 +17,15 @@
package org.springframework.boot;
import java.lang.StackWalker.StackFrame;
import java.lang.reflect.Method;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
@ -38,14 +41,17 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.aot.AotDetector;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.groovy.GroovyBeanDefinitionReader;
import org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanNameGenerator;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.factory.xml.XmlBeanDefinitionReader;
import org.springframework.boot.Banner.Mode;
import org.springframework.boot.context.properties.bind.Bindable;
@ -68,6 +74,8 @@ import org.springframework.context.aot.AotApplicationContextInitializer;
import org.springframework.context.support.AbstractApplicationContext;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.GenericTypeResolver;
import org.springframework.core.OrderComparator;
import org.springframework.core.OrderComparator.OrderSourceProvider;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.core.annotation.Order;
@ -746,33 +754,40 @@ public class SpringApplication {
protected void afterRefresh(ConfigurableApplicationContext context, ApplicationArguments args) {
}
private void callRunners(ApplicationContext context, ApplicationArguments args) {
context.getBeanProvider(Runner.class).orderedStream().forEach((runner) -> {
if (runner instanceof ApplicationRunner applicationRunner) {
callRunner(applicationRunner, args);
}
if (runner instanceof CommandLineRunner commandLineRunner) {
callRunner(commandLineRunner, args);
}
});
private void callRunners(ConfigurableApplicationContext context, ApplicationArguments args) {
ConfigurableListableBeanFactory beanFactory = context.getBeanFactory();
String[] beanNames = beanFactory.getBeanNamesForType(Runner.class);
Map<Runner, String> instancesToBeanNames = new IdentityHashMap<>();
for (String beanName : beanNames) {
instancesToBeanNames.put(beanFactory.getBean(beanName, Runner.class), beanName);
}
Comparator<Object> comparator = getOrderComparator(beanFactory)
.withSourceProvider(new FactoryAwareOrderSourceProvider(beanFactory, instancesToBeanNames));
instancesToBeanNames.keySet().stream().sorted(comparator).forEach((runner) -> callRunner(runner, args));
}
private void callRunner(ApplicationRunner runner, ApplicationArguments args) {
try {
(runner).run(args);
private OrderComparator getOrderComparator(ConfigurableListableBeanFactory beanFactory) {
Comparator<?> dependencyComparator = (beanFactory instanceof DefaultListableBeanFactory defaultListableBeanFactory)
? defaultListableBeanFactory.getDependencyComparator() : null;
return (dependencyComparator instanceof OrderComparator orderComparator) ? orderComparator
: AnnotationAwareOrderComparator.INSTANCE;
}
private void callRunner(Runner runner, ApplicationArguments args) {
if (runner instanceof ApplicationRunner) {
callRunner(ApplicationRunner.class, runner, (applicationRunner) -> applicationRunner.run(args));
}
catch (Exception ex) {
throw new IllegalStateException("Failed to execute ApplicationRunner", ex);
if (runner instanceof CommandLineRunner) {
callRunner(CommandLineRunner.class, runner,
(commandLineRunner) -> commandLineRunner.run(args.getSourceArgs()));
}
}
private void callRunner(CommandLineRunner runner, ApplicationArguments args) {
try {
(runner).run(args.getSourceArgs());
}
catch (Exception ex) {
throw new IllegalStateException("Failed to execute CommandLineRunner", ex);
}
@SuppressWarnings("unchecked")
private <R extends Runner> void callRunner(Class<R> type, Runner runner, ThrowingConsumer<R> call) {
call.throwing(
(message, ex) -> new IllegalStateException("Failed to execute " + ClassUtils.getShortName(type), ex))
.accept((R) runner);
}
private void handleRunFailure(ConfigurableApplicationContext context, Throwable exception,
@ -1598,4 +1613,41 @@ public class SpringApplication {
}
/**
* {@link OrderSourceProvider} used to obtain factory method and target type order
* sources. Based on internal {@link DefaultListableBeanFactory} code.
*/
private class FactoryAwareOrderSourceProvider implements OrderSourceProvider {
private final ConfigurableBeanFactory beanFactory;
private final Map<?, String> instancesToBeanNames;
FactoryAwareOrderSourceProvider(ConfigurableBeanFactory beanFactory, Map<?, String> instancesToBeanNames) {
this.beanFactory = beanFactory;
this.instancesToBeanNames = instancesToBeanNames;
}
@Override
public Object getOrderSource(Object obj) {
String beanName = this.instancesToBeanNames.get(obj);
return (beanName != null) ? getOrderSource(beanName, obj.getClass()) : null;
}
private Object getOrderSource(String beanName, Class<?> instanceType) {
try {
RootBeanDefinition beanDefinition = (RootBeanDefinition) this.beanFactory
.getMergedBeanDefinition(beanName);
Method factoryMethod = beanDefinition.getResolvedFactoryMethod();
Class<?> targetType = beanDefinition.getTargetType();
targetType = (targetType != instanceType) ? targetType : null;
return Stream.of(factoryMethod, targetType).filter(Objects::nonNull).toArray();
}
catch (NoSuchBeanDefinitionException ex) {
return null;
}
}
}
}

View File

@ -60,6 +60,7 @@ import org.springframework.boot.availability.AvailabilityChangeEvent;
import org.springframework.boot.availability.AvailabilityState;
import org.springframework.boot.availability.LivenessState;
import org.springframework.boot.availability.ReadinessState;
import org.springframework.boot.builder.ParentContextApplicationContextInitializer;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.boot.context.event.ApplicationContextInitializedEvent;
import org.springframework.boot.context.event.ApplicationEnvironmentPreparedEvent;
@ -630,6 +631,19 @@ class SpringApplicationTests {
assertThat(this.context).has(runTestRunnerBean("runnerC"));
}
@Test
void runCommandLineRunnersAndApplicationRunnersWithParentContext() {
SpringApplication application = new SpringApplication(CommandLineRunConfig.class);
application.setWebApplicationType(WebApplicationType.NONE);
application.addInitializers(new ParentContextApplicationContextInitializer(
new AnnotationConfigApplicationContext(CommandLineRunParentConfig.class)));
this.context = application.run("arg");
assertThat(this.context).has(runTestRunnerBean("runnerA"));
assertThat(this.context).has(runTestRunnerBean("runnerB"));
assertThat(this.context).has(runTestRunnerBean("runnerC"));
assertThat(this.context).doesNotHave(runTestRunnerBean("runnerP"));
}
@Test
void runCommandLineRunnersAndApplicationRunnersUsingOrderOnBeanDefinitions() {
SpringApplication application = new SpringApplication(BeanDefinitionOrderRunnerConfig.class);
@ -1432,7 +1446,7 @@ class SpringApplicationTests {
};
}
private Condition<ConfigurableApplicationContext> runTestRunnerBean(final String name) {
private Condition<ConfigurableApplicationContext> runTestRunnerBean(String name) {
return new Condition<>("run testrunner bean") {
@Override
@ -1642,17 +1656,27 @@ class SpringApplicationTests {
@Bean
TestCommandLineRunner runnerC() {
return new TestCommandLineRunner(Ordered.LOWEST_PRECEDENCE, "runnerB", "runnerA");
return new TestCommandLineRunner("runnerC", Ordered.LOWEST_PRECEDENCE, "runnerB", "runnerA");
}
@Bean
TestApplicationRunner runnerB() {
return new TestApplicationRunner(Ordered.LOWEST_PRECEDENCE - 1, "runnerA");
return new TestApplicationRunner("runnerB", Ordered.LOWEST_PRECEDENCE - 1, "runnerA");
}
@Bean
TestCommandLineRunner runnerA() {
return new TestCommandLineRunner(Ordered.HIGHEST_PRECEDENCE);
return new TestCommandLineRunner("runnerA", Ordered.HIGHEST_PRECEDENCE);
}
}
@Configuration(proxyBeanMethods = false)
static class CommandLineRunParentConfig {
@Bean
TestCommandLineRunner runnerP() {
return new TestCommandLineRunner("runnerP", Ordered.LOWEST_PRECEDENCE);
}
}
@ -1861,12 +1885,16 @@ class SpringApplicationTests {
static class TestCommandLineRunner extends AbstractTestRunner implements CommandLineRunner {
TestCommandLineRunner(int order, String... expectedBefore) {
private final String name;
TestCommandLineRunner(String name, int order, String... expectedBefore) {
super(order, expectedBefore);
this.name = name;
}
@Override
public void run(String... args) {
System.out.println(">>> " + this.name);
markAsRan();
}
@ -1874,12 +1902,16 @@ class SpringApplicationTests {
static class TestApplicationRunner extends AbstractTestRunner implements ApplicationRunner {
TestApplicationRunner(int order, String... expectedBefore) {
private final String name;
TestApplicationRunner(String name, int order, String... expectedBefore) {
super(order, expectedBefore);
this.name = name;
}
@Override
public void run(ApplicationArguments args) {
System.out.println(">>> " + this.name);
markAsRan();
}