diff --git a/spring-boot/src/main/java/org/springframework/boot/context/web/SpringBootServletInitializer.java b/spring-boot/src/main/java/org/springframework/boot/context/web/SpringBootServletInitializer.java index e5ae60fd204..3e711be5bdb 100644 --- a/spring-boot/src/main/java/org/springframework/boot/context/web/SpringBootServletInitializer.java +++ b/spring-boot/src/main/java/org/springframework/boot/context/web/SpringBootServletInitializer.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2014 the original author or authors. + * Copyright 2012-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -57,6 +57,7 @@ import org.springframework.web.context.WebApplicationContext; * * @author Dave Syer * @author Phillip Webb + * @author Andy Wilkinson * @see #configure(SpringApplicationBuilder) */ public abstract class SpringBootServletInitializer implements WebApplicationInitializer { @@ -84,6 +85,7 @@ public abstract class SpringBootServletInitializer implements WebApplicationInit protected WebApplicationContext createRootApplicationContext( ServletContext servletContext) { SpringApplicationBuilder builder = new SpringApplicationBuilder(); + builder.main(getClass()); ApplicationContext parent = getExistingRootWebApplicationContext(servletContext); if (parent != null) { this.logger.info("Root context already created (using as parent)."); diff --git a/spring-boot/src/test/java/org/springframework/boot/context/web/SpringBootServletInitializerTests.java b/spring-boot/src/test/java/org/springframework/boot/context/web/SpringBootServletInitializerTests.java index 595b2aac5e5..ddd796b6bef 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/web/SpringBootServletInitializerTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/web/SpringBootServletInitializerTests.java @@ -26,6 +26,7 @@ import org.hamcrest.Matcher; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.springframework.beans.DirectFieldAccessor; import org.springframework.boot.SpringApplication; import org.springframework.boot.builder.SpringApplicationBuilder; import org.springframework.context.annotation.Configuration; @@ -33,12 +34,14 @@ import org.springframework.mock.web.MockServletContext; import org.springframework.web.context.WebApplicationContext; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertThat; /** * Tests for {@link SpringBootServletInitializerTests}. * * @author Phillip Webb + * @author Andy Wilkinson */ public class SpringBootServletInitializerTests { @@ -72,6 +75,17 @@ public class SpringBootServletInitializerTests { equalToSet(Config.class, ErrorPageFilter.class)); } + @SuppressWarnings("rawtypes") + @Test + public void mainClassHasSensibleDefault() throws Exception { + new WithConfigurationAnnotation() + .createRootApplicationContext(this.servletContext); + Class mainApplicationClass = (Class) new DirectFieldAccessor(this.application) + .getPropertyValue("mainApplicationClass"); + assertThat(mainApplicationClass, + is(equalTo((Class) WithConfigurationAnnotation.class))); + } + private Matcher> equalToSet(Object... items) { Set set = new LinkedHashSet(); Collections.addAll(set, items);