diff --git a/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java b/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java index 5eb6e97418b..23feed96924 100644 --- a/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java +++ b/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java @@ -135,7 +135,9 @@ public class SpringApplication { private final Log log = LogFactory.getLog(getClass()); - private Set sources = new LinkedHashSet(); + private Set defaultSources = new LinkedHashSet(); + + private Set additionalSources = new LinkedHashSet(); private Class mainApplicationClass; @@ -161,8 +163,6 @@ public class SpringApplication { private String[] defaultCommandLineArgs; - private boolean sourcesInitialized = false; - /** * Crate a new {@link SpringApplication} instance. The application context will load * beans from the specified sources (see {@link SpringApplication class-level} @@ -193,8 +193,7 @@ public class SpringApplication { private void initialize(Object[] sources) { if (sources != null && sources.length > 0) { - this.sourcesInitialized = true; - this.sources.addAll(Arrays.asList(sources)); + this.additionalSources.addAll(Arrays.asList(sources)); } this.webEnvironment = deduceWebEnvironment(); this.initializers = new ArrayList>(); @@ -249,7 +248,8 @@ public class SpringApplication { // Call all remaining initializers callEnvironmentAwareSpringApplicationInitializers(environment); - Assert.notEmpty(this.sources, "Sources must not be empty"); + Set sources = assembleSources(); + Assert.notEmpty(sources, "Sources must not be empty"); if (this.showBanner) { printBanner(); } @@ -267,12 +267,19 @@ public class SpringApplication { if (this.logStartupInfo) { logStartupInfo(); } - load(context, this.sources.toArray(new Object[this.sources.size()])); + load(context, sources.toArray(new Object[sources.size()])); refresh(context); runCommandLineRunners(context, args); return context; } + private Set assembleSources() { + LinkedHashSet sources = new LinkedHashSet(); + sources.addAll(this.defaultSources); + sources.addAll(this.additionalSources); + return sources; + } + private void callNonEnvironmentAwareSpringApplicationInitializers() { for (ApplicationContextInitializer initializer : this.initializers) { if (initializer instanceof SpringApplicationInitializer @@ -354,7 +361,7 @@ public class SpringApplication { Log applicationLog = getApplicationLog(); new StartupInfoLogger(this.mainApplicationClass).log(applicationLog); if (applicationLog.isDebugEnabled()) { - applicationLog.debug("Sources: " + this.sources); + applicationLog.debug("Sources: " + this.defaultSources); } } @@ -561,13 +568,13 @@ public class SpringApplication { } /** - * Returns a mutable set of the sources that will be used to create an - * ApplicationContext when {@link #run(String...)} is called. + * Returns a mutable set of the sources that will be added to an ApplicationContext + * when {@link #run(String...)} is called. * @return the sources the application sources. * @see #SpringApplication(Object...) */ public Set getSources() { - return this.sources; + return this.defaultSources; } /** @@ -579,11 +586,8 @@ public class SpringApplication { * @see #SpringApplication(Object...) */ public void setSources(Set sources) { - if (this.sourcesInitialized) { - return; - } Assert.notNull(sources, "Sources must not be null"); - this.sources = new LinkedHashSet(sources); + this.defaultSources = new LinkedHashSet(sources); } /** diff --git a/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java b/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java index d638d9743a8..8bb6ba989f2 100644 --- a/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java @@ -16,6 +16,7 @@ package org.springframework.boot; +import java.util.Set; import java.util.concurrent.atomic.AtomicReference; import org.junit.After; @@ -47,6 +48,7 @@ import org.springframework.core.env.StandardEnvironment; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.core.io.ResourceLoader; import org.springframework.mock.web.MockServletContext; +import org.springframework.test.util.ReflectionTestUtils; import org.springframework.web.context.support.StaticWebApplicationContext; import static org.hamcrest.Matchers.equalTo; @@ -268,7 +270,10 @@ public class SpringApplicationTests { application.setWebEnvironment(false); application.setUseMockLoader(true); application.run(); - assertThat(application.getSources().toArray(), equalTo(sources)); + @SuppressWarnings("unchecked") + Set additionalSources = (Set) ReflectionTestUtils.getField( + application, "additionalSources"); + assertThat(additionalSources.toArray(), equalTo(sources)); } @Test