diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/LazyInitializationBeanFactoryPostProcessor.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/LazyInitializationBeanFactoryPostProcessor.java index 6eb7a509bb0..61469e8bbc5 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/LazyInitializationBeanFactoryPostProcessor.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/LazyInitializationBeanFactoryPostProcessor.java @@ -16,10 +16,12 @@ package org.springframework.boot; +import java.util.ArrayList; import java.util.Collection; import org.springframework.beans.BeansException; import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.beans.factory.SmartInitializingSingleton; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; @@ -30,6 +32,11 @@ import org.springframework.core.Ordered; * {@link BeanFactoryPostProcessor} to set lazy-init on bean definitions that are not * {@link LazyInitializationExcludeFilter excluded} and have not already had a value * explicitly set. + *

+ * Note that {@link SmartInitializingSingleton SmartInitializingSingletons} are + * automatically excluded from lazy initialization to ensure that their + * {@link SmartInitializingSingleton#afterSingletonsInstantiated() callback method} is + * invoked. * * @author Andy Wilkinson * @author Madhura Bhave @@ -42,9 +49,7 @@ public final class LazyInitializationBeanFactoryPostProcessor implements BeanFac @Override public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { - // Take care not to force the eager init of factory beans when getting filters - Collection filters = beanFactory - .getBeansOfType(LazyInitializationExcludeFilter.class, false, false).values(); + Collection filters = getFilters(beanFactory); for (String beanName : beanFactory.getBeanDefinitionNames()) { BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName); if (beanDefinition instanceof AbstractBeanDefinition) { @@ -53,6 +58,14 @@ public final class LazyInitializationBeanFactoryPostProcessor implements BeanFac } } + private Collection getFilters(ConfigurableListableBeanFactory beanFactory) { + // Take care not to force the eager init of factory beans when getting filters + ArrayList filters = new ArrayList<>( + beanFactory.getBeansOfType(LazyInitializationExcludeFilter.class, false, false).values()); + filters.add(LazyInitializationExcludeFilter.forBeanTypes(SmartInitializingSingleton.class)); + return filters; + } + private void postProcess(ConfigurableListableBeanFactory beanFactory, Collection filters, String beanName, AbstractBeanDefinition beanDefinition) { diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/LazyInitializationBeanFactoryPostProcessorTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/LazyInitializationBeanFactoryPostProcessorTests.java new file mode 100644 index 00000000000..e42544fb6bd --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/LazyInitializationBeanFactoryPostProcessorTests.java @@ -0,0 +1,90 @@ +/* + * Copyright 2012-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot; + +import java.util.ArrayList; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.SmartInitializingSingleton; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link LazyInitializationBeanFactoryPostProcessor}. + * + * @author Andy Wilkinson + */ +class LazyInitializationBeanFactoryPostProcessorTests { + + @Test + void whenLazyInitializationIsEnabledThenNormalBeansAreNotInitializedUntilRequired() { + try (AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext()) { + context.addBeanFactoryPostProcessor(new LazyInitializationBeanFactoryPostProcessor()); + context.register(BeanState.class, ExampleBean.class); + context.refresh(); + BeanState beanState = context.getBean(BeanState.class); + assertThat(beanState.initializedBeans).isEmpty(); + context.getBean(ExampleBean.class); + assertThat(beanState.initializedBeans).containsExactly(ExampleBean.class); + } + } + + @Test + void whenLazyInitializationIsEnabledThenSmartInitializingSingletonsAreInitializedDuringRefresh() { + try (AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext()) { + context.addBeanFactoryPostProcessor(new LazyInitializationBeanFactoryPostProcessor()); + context.register(BeanState.class, ExampleSmartInitializingSingleton.class); + context.refresh(); + BeanState beanState = context.getBean(BeanState.class); + assertThat(beanState.initializedBeans).containsExactly(ExampleSmartInitializingSingleton.class); + assertThat(context.getBean(ExampleSmartInitializingSingleton.class).callbackInvoked).isTrue(); + } + } + + static class ExampleBean { + + ExampleBean(BeanState beanState) { + beanState.initializedBeans.add(getClass()); + } + + } + + static class ExampleSmartInitializingSingleton implements SmartInitializingSingleton { + + private boolean callbackInvoked; + + ExampleSmartInitializingSingleton(BeanState beanState) { + beanState.initializedBeans.add(getClass()); + } + + @Override + public void afterSingletonsInstantiated() { + this.callbackInvoked = true; + } + + } + + static class BeanState { + + private final List> initializedBeans = new ArrayList<>(); + + } + +}