Add support for deriving a DataSourceBuilder from a wrapped DataSource

Closes gh-31605
This commit is contained in:
Andy Wilkinson 2022-07-08 17:23:11 +01:00
parent d58f33f1ce
commit fa43e1f378
2 changed files with 104 additions and 1 deletions

View File

@ -236,7 +236,35 @@ public final class DataSourceBuilder<T extends DataSource> {
throw new IllegalStateException("Unable to unwrap embedded database", ex);
}
}
return new DataSourceBuilder<>(dataSource);
try {
while (dataSource.isWrapperFor(DataSource.class)) {
DataSource unwrapped = dataSource.unwrap(DataSource.class);
if (unwrapped == dataSource) {
break;
}
dataSource = unwrapped;
}
}
catch (SQLException ex) {
// Try to continue with the existing, potentially still wrapped, DataSource
}
return new DataSourceBuilder<>(unwrap(dataSource));
}
private static DataSource unwrap(DataSource dataSource) {
try {
while (dataSource.isWrapperFor(DataSource.class)) {
DataSource unwrapped = dataSource.unwrap(DataSource.class);
if (unwrapped == dataSource) {
return unwrapped;
}
dataSource = unwrapped;
}
}
catch (SQLException ex) {
// Try to continue with the existing, potentially still wrapped, DataSource
}
return dataSource;
}
/**

View File

@ -18,11 +18,14 @@ package org.springframework.boot.jdbc;
import java.io.Closeable;
import java.io.IOException;
import java.io.PrintWriter;
import java.net.URL;
import java.net.URLClassLoader;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.util.Arrays;
import java.util.logging.Logger;
import javax.sql.DataSource;
@ -343,6 +346,19 @@ class DataSourceBuilderTests {
assertThat(built.getUrl()).startsWith("jdbc:hsqldb:mem");
}
@Test
void buildWhenDerivedFromWrappedDataSource() {
HikariDataSource dataSource = new HikariDataSource();
dataSource.setUsername("test");
dataSource.setPassword("secret");
dataSource.setJdbcUrl("jdbc:h2:test");
DataSourceBuilder<?> builder = DataSourceBuilder.derivedFrom(wrap(wrap(dataSource)));
HikariDataSource built = (HikariDataSource) builder.username("test2").password("secret2").build();
assertThat(built.getUsername()).isEqualTo("test2");
assertThat(built.getPassword()).isEqualTo("secret2");
assertThat(built.getJdbcUrl()).isEqualTo("jdbc:h2:test");
}
@Test // gh-26644
void buildWhenDerivedFromExistingDatabaseWithTypeChange() {
HikariDataSource dataSource = new HikariDataSource();
@ -384,6 +400,65 @@ class DataSourceBuilderTests {
assertThat(testSource.getPassword()).isEqualTo("secret");
}
private DataSource wrap(DataSource target) {
return new DataSourceWrapper(target);
}
private static final class DataSourceWrapper implements DataSource {
private final DataSource delegate;
private DataSourceWrapper(DataSource delegate) {
this.delegate = delegate;
}
@Override
public Logger getParentLogger() throws SQLFeatureNotSupportedException {
return this.delegate.getParentLogger();
}
@Override
public <T> T unwrap(Class<T> iface) throws SQLException {
return this.delegate.unwrap(iface);
}
@Override
public boolean isWrapperFor(Class<?> iface) throws SQLException {
return this.delegate.isWrapperFor(iface);
}
@Override
public Connection getConnection() throws SQLException {
return this.delegate.getConnection();
}
@Override
public Connection getConnection(String username, String password) throws SQLException {
return this.delegate.getConnection(username, password);
}
@Override
public PrintWriter getLogWriter() throws SQLException {
return this.delegate.getLogWriter();
}
@Override
public void setLogWriter(PrintWriter out) throws SQLException {
this.delegate.setLogWriter(out);
}
@Override
public void setLoginTimeout(int seconds) throws SQLException {
this.delegate.setLoginTimeout(seconds);
}
@Override
public int getLoginTimeout() throws SQLException {
return this.delegate.getLoginTimeout();
}
}
final class HidePackagesClassLoader extends URLClassLoader {
private final String[] hiddenPackages;