diff --git a/datasource-decorator-spring-boot-autoconfigure/build.gradle.kts b/datasource-decorator-spring-boot-autoconfigure/build.gradle.kts index c318db6..c9a20e0 100644 --- a/datasource-decorator-spring-boot-autoconfigure/build.gradle.kts +++ b/datasource-decorator-spring-boot-autoconfigure/build.gradle.kts @@ -42,6 +42,7 @@ dependencies { testImplementation("org.apache.commons:commons-dbcp2:2.9.0") testImplementation("org.apache.tomcat:tomcat-jdbc:10.1.5") testImplementation("com.zaxxer:HikariCP:5.0.1") + testImplementation("org.flywaydb:flyway-core:9.5.1") } tasks { diff --git a/datasource-decorator-spring-boot-autoconfigure/src/main/java/com/github/gavlyukovskiy/boot/jdbc/decorator/DecoratedDataSource.java b/datasource-decorator-spring-boot-autoconfigure/src/main/java/com/github/gavlyukovskiy/boot/jdbc/decorator/DecoratedDataSource.java index e98413a..7c185be 100644 --- a/datasource-decorator-spring-boot-autoconfigure/src/main/java/com/github/gavlyukovskiy/boot/jdbc/decorator/DecoratedDataSource.java +++ b/datasource-decorator-spring-boot-autoconfigure/src/main/java/com/github/gavlyukovskiy/boot/jdbc/decorator/DecoratedDataSource.java @@ -20,6 +20,7 @@ import javax.sql.DataSource; +import java.sql.SQLException; import java.util.List; import java.util.stream.Collectors; @@ -83,6 +84,24 @@ public List getDecoratingChain() { return decoratingChain; } + @Override + @SuppressWarnings("unchecked") + public T unwrap(Class iface) throws SQLException { + // Spring Boot unwrapping simply passes 'unwrap(DataSource.class)' expecting real datasource to be returned + // if the real datasource type matches - return real datasource + if (iface.isInstance(getRealDataSource())) { + return (T) getRealDataSource(); + } + // As some decorators don't consider their types during unwrapping + // if their type is specifically requested, we can return the decorator itself + for (DataSourceDecorationStage dataSourceDecorationStage : decoratingChain) { + if (iface.isInstance(dataSourceDecorationStage.getDataSource())) { + return (T) dataSourceDecorationStage.getDataSource(); + } + } + return super.unwrap(iface); + } + @Override public String toString() { return decoratingChain.stream() diff --git a/datasource-decorator-spring-boot-autoconfigure/src/test/java/com/github/gavlyukovskiy/boot/jdbc/decorator/DataSourceDecoratorAutoConfigurationTests.java b/datasource-decorator-spring-boot-autoconfigure/src/test/java/com/github/gavlyukovskiy/boot/jdbc/decorator/DataSourceDecoratorAutoConfigurationTests.java index 77ec445..74a90ce 100644 --- a/datasource-decorator-spring-boot-autoconfigure/src/test/java/com/github/gavlyukovskiy/boot/jdbc/decorator/DataSourceDecoratorAutoConfigurationTests.java +++ b/datasource-decorator-spring-boot-autoconfigure/src/test/java/com/github/gavlyukovskiy/boot/jdbc/decorator/DataSourceDecoratorAutoConfigurationTests.java @@ -27,6 +27,7 @@ import org.springframework.beans.DirectFieldAccessor; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.context.PropertyPlaceholderAutoConfiguration; +import org.springframework.boot.autoconfigure.flyway.FlywayAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -254,6 +255,35 @@ void testRoutingDataSourceIsNotDecorated() { }); } + @Test + void testUnwrapsRealDataSourceForFlyway() { + ApplicationContextRunner contextRunner = this.contextRunner + .withConfiguration(AutoConfigurations.of(FlywayAutoConfiguration.class)) + .withPropertyValues("spring.flyway.user=sa"); + + contextRunner.run(context -> { + DataSource dataSource = context.getBean(DataSource.class); + assertThat(dataSource).isInstanceOf(DecoratedDataSource.class); + assertThat(dataSource.unwrap(HikariDataSource.class)).isInstanceOf(HikariDataSource.class); + assertThat(dataSource.unwrap(DataSource.class)).isInstanceOf(HikariDataSource.class); + }); + } + + @Test + void testUnwrapsProxyDataSourcesFromChain() { + ApplicationContextRunner contextRunner = this.contextRunner + .withConfiguration(AutoConfigurations.of(FlywayAutoConfiguration.class)) + .withPropertyValues("spring.flyway.user=sa"); + + contextRunner.run(context -> { + DataSource dataSource = context.getBean(DataSource.class); + assertThat(dataSource).isInstanceOf(DecoratedDataSource.class); + assertThat(dataSource.unwrap(ProxyDataSource.class)).isInstanceOf(ProxyDataSource.class); + assertThat(dataSource.unwrap(P6DataSource.class)).isInstanceOf(P6DataSource.class); + assertThat(dataSource.unwrap(FlexyPoolDataSource.class)).isInstanceOf(FlexyPoolDataSource.class); + }); + } + private AbstractListAssert, Object, ObjectAssert> assertThatDataSourceDecoratingChain(DataSource dataSource) { return assertThat(((DecoratedDataSource) dataSource).getDecoratingChain()).extracting("dataSource").extracting("class"); }