diff --git a/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/annotation/AbstractOnBeanCondition.java b/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/annotation/AbstractOnBeanCondition.java index 9c3731f45e0..7c827372740 100644 --- a/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/annotation/AbstractOnBeanCondition.java +++ b/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/annotation/AbstractOnBeanCondition.java @@ -16,18 +16,24 @@ package org.springframework.bootstrap.context.annotation; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.BeanFactoryUtils; -import org.springframework.context.annotation.Condition; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ConditionContext; +import org.springframework.context.annotation.ConfigurationCondition; import org.springframework.core.type.AnnotatedTypeMetadata; +import org.springframework.core.type.MethodMetadata; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.MultiValueMap; +import org.springframework.util.ReflectionUtils; +import org.springframework.util.ReflectionUtils.MethodCallback; /** * Base for {@link OnBeanCondition} and {@link OnMissingBeanCondition}. @@ -35,65 +41,95 @@ import org.springframework.util.MultiValueMap; * @author Phillip Webb * @author Dave Syer */ -abstract class AbstractOnBeanCondition implements Condition { +abstract class AbstractOnBeanCondition implements ConfigurationCondition { protected Log logger = LogFactory.getLog(getClass()); - private List beanClasses; - - private List beanNames; - protected abstract Class annotationClass(); - protected List getBeanClasses() { - return this.beanClasses; - } - - protected List getBeanNames() { - return this.beanNames; + @Override + public ConfigurationPhase getConfigurationPhase() { + return ConfigurationPhase.REGISTER_BEAN; } @Override public boolean matches(ConditionContext context, AnnotatedTypeMetadata metadata) { + MultiValueMap attributes = metadata.getAllAnnotationAttributes( + annotationClass().getName(), true); + final List beanClasses = collect(attributes, "value"); + final List beanNames = collect(attributes, "name"); + + if (beanClasses.size() == 0) { + if (metadata instanceof MethodMetadata + && metadata.isAnnotated(Bean.class.getName())) { + try { + final MethodMetadata methodMetadata = (MethodMetadata) metadata; + // We should be safe to load at this point since we are in the + // REGISTER_BEAN phase + Class configClass = ClassUtils.forName( + methodMetadata.getDeclaringClassName(), + context.getClassLoader()); + ReflectionUtils.doWithMethods(configClass, new MethodCallback() { + @Override + public void doWith(Method method) + throws IllegalArgumentException, IllegalAccessException { + if (methodMetadata.getMethodName().equals(method.getName())) { + beanClasses.add(method.getReturnType().getName()); + } + } + }); + } catch (Exception e) { + } + } + } + + Assert.isTrue(beanClasses.size() > 0 || beanNames.size() > 0, + "@" + ClassUtils.getShortName(annotationClass()) + + " annotations must specify at least one bean"); + + return matches(context, metadata, beanClasses, beanNames); + } + + protected boolean matches(ConditionContext context, AnnotatedTypeMetadata metadata, + List beanClasses, List beanNames) throws LinkageError { String checking = ConditionLogUtils.getPrefix(this.logger, metadata); - MultiValueMap attributes = metadata.getAllAnnotationAttributes( - annotationClass().getName(), true); - this.beanClasses = collect(attributes, "value"); - this.beanNames = collect(attributes, "name"); - Assert.isTrue(this.beanClasses.size() > 0 || this.beanNames.size() > 0, "@" - + ClassUtils.getShortName(annotationClass()) - + " annotations must specify at least one bean"); + Boolean considerHierarchy = (Boolean) metadata.getAnnotationAttributes( + annotationClass().getName()).get("considerHierarchy"); + considerHierarchy = (considerHierarchy == null ? false : considerHierarchy); List beanClassesFound = new ArrayList(); List beanNamesFound = new ArrayList(); - for (String beanClass : this.beanClasses) { + for (String beanClass : beanClasses) { try { // eagerInit set to false to prevent early instantiation (some // factory beans will not be able to determine their object type at this // stage, so those are not eligible for matching this condition) - String[] beans = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( - context.getBeanFactory(), - ClassUtils.forName(beanClass, context.getClassLoader()), false, - false); + ConfigurableListableBeanFactory beanFactory = context.getBeanFactory(); + Class type = ClassUtils.forName(beanClass, context.getClassLoader()); + String[] beans = (considerHierarchy ? BeanFactoryUtils + .beanNamesForTypeIncludingAncestors(beanFactory, type, false, + false) : beanFactory.getBeanNamesForType(type, false, + false)); if (beans.length != 0) { beanClassesFound.add(beanClass); } } catch (ClassNotFoundException ex) { } } - for (String beanName : this.beanNames) { - if (context.getBeanFactory().containsBeanDefinition(beanName)) { + for (String beanName : beanNames) { + if (considerHierarchy ? context.getBeanFactory().containsBean(beanName) + : context.getBeanFactory().containsLocalBean(beanName)) { beanNamesFound.add(beanName); } } boolean result = evaluate(beanClassesFound, beanNamesFound); if (this.logger.isDebugEnabled()) { - logFoundResults(checking, "class", this.beanClasses, beanClassesFound); - logFoundResults(checking, "name", this.beanNames, beanClassesFound); + logFoundResults(checking, "class", beanClasses, beanClassesFound); + logFoundResults(checking, "name", beanNames, beanClassesFound); this.logger.debug(checking + "Match result is: " + result); } return result; diff --git a/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/annotation/ConditionalOnBean.java b/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/annotation/ConditionalOnBean.java index 7066408180a..1a1f1ee2817 100644 --- a/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/annotation/ConditionalOnBean.java +++ b/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/annotation/ConditionalOnBean.java @@ -52,4 +52,9 @@ public @interface ConditionalOnBean { */ String[] name() default {}; + /** + * If the application context hierarchy (parent contexts) should be considered. + */ + boolean considerHierarchy() default true; + } diff --git a/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/annotation/ConditionalOnMissingBean.java b/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/annotation/ConditionalOnMissingBean.java index 96a9567dc68..295a46efa76 100644 --- a/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/annotation/ConditionalOnMissingBean.java +++ b/spring-bootstrap/src/main/java/org/springframework/bootstrap/context/annotation/ConditionalOnMissingBean.java @@ -52,4 +52,9 @@ public @interface ConditionalOnMissingBean { */ String[] name() default {}; + /** + * If the application context hierarchy (parent contexts) should be considered. + */ + boolean considerHierarchy() default true; + } diff --git a/spring-bootstrap/src/test/java/org/springframework/bootstrap/context/annotation/OnMissingBeanConditionTests.java b/spring-bootstrap/src/test/java/org/springframework/bootstrap/context/annotation/OnMissingBeanConditionTests.java index 7e13a4e7b6f..5bd94c287ad 100644 --- a/spring-bootstrap/src/test/java/org/springframework/bootstrap/context/annotation/OnMissingBeanConditionTests.java +++ b/spring-bootstrap/src/test/java/org/springframework/bootstrap/context/annotation/OnMissingBeanConditionTests.java @@ -21,13 +21,19 @@ import org.springframework.context.annotation.AnnotationConfigApplicationContext import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; /** + * Tests for {@link OnMissingBeanCondition}. + * * @author Dave Syer + * @author Phillip Webb */ +@SuppressWarnings("resource") public class OnMissingBeanConditionTests { private AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); @@ -49,6 +55,35 @@ public class OnMissingBeanConditionTests { assertEquals("foo", this.context.getBean("foo")); } + @Test + public void hierarchyConsidered() throws Exception { + this.context.register(FooConfiguration.class); + this.context.refresh(); + AnnotationConfigApplicationContext childContext = new AnnotationConfigApplicationContext(); + childContext.setParent(this.context); + childContext.register(HierarchyConsidered.class); + childContext.refresh(); + assertFalse(childContext.containsLocalBean("bar")); + } + + @Test + public void hierarchyNotConsidered() throws Exception { + this.context.register(FooConfiguration.class); + this.context.refresh(); + AnnotationConfigApplicationContext childContext = new AnnotationConfigApplicationContext(); + childContext.setParent(this.context); + childContext.register(HierarchyNotConsidered.class); + childContext.refresh(); + assertTrue(childContext.containsLocalBean("bar")); + } + + @Test + public void impliedOnBeanMethod() throws Exception { + this.context.register(ExampleBeanConfiguration.class, ImpliedOnBeanMethod.class); + this.context.refresh(); + assertThat(this.context.getBeansOfType(ExampleBean.class).size(), equalTo(1)); + } + @Configuration @ConditionalOnMissingBean(name = "foo") protected static class OnBeanNameConfiguration { @@ -66,4 +101,43 @@ public class OnMissingBeanConditionTests { } } + @Configuration + @ConditionalOnMissingBean(name = "foo") + protected static class HierarchyConsidered { + @Bean + public String bar() { + return "bar"; + } + } + + @Configuration + @ConditionalOnMissingBean(name = "foo", considerHierarchy = false) + protected static class HierarchyNotConsidered { + @Bean + public String bar() { + return "bar"; + } + } + + @Configuration + protected static class ExampleBeanConfiguration { + @Bean + public ExampleBean exampleBean() { + return new ExampleBean(); + } + } + + @Configuration + protected static class ImpliedOnBeanMethod { + + @Bean + @ConditionalOnMissingBean + public ExampleBean exampleBean2() { + return new ExampleBean(); + } + + } + + public static class ExampleBean { + } }