diff --git a/util/internal/test/common/src/main/java/org/hibernate/search/util/impl/test/extension/MultiRunExtension.java b/util/internal/test/common/src/main/java/org/hibernate/search/util/impl/test/extension/MultiRunExtension.java new file mode 100644 index 00000000000..2847ba8328b --- /dev/null +++ b/util/internal/test/common/src/main/java/org/hibernate/search/util/impl/test/extension/MultiRunExtension.java @@ -0,0 +1,233 @@ +/* + * Hibernate Search, full-text search for your domain model + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.search.util.impl.test.extension; + +import static java.util.Spliterators.spliteratorUnknownSize; +import static java.util.stream.StreamSupport.stream; +import static org.junit.platform.commons.util.AnnotationUtils.findRepeatableAnnotations; +import static org.junit.platform.commons.util.AnnotationUtils.isAnnotated; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Spliterator; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.AfterEachCallback; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.Extension; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.ParameterContext; +import org.junit.jupiter.api.extension.ParameterResolutionException; +import org.junit.jupiter.api.extension.ParameterResolver; +import org.junit.jupiter.api.extension.TestTemplateInvocationContext; +import org.junit.jupiter.api.extension.TestTemplateInvocationContextProvider; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; +import org.junit.jupiter.params.support.AnnotationConsumerInitializer; +import org.junit.platform.commons.util.AnnotationUtils; +import org.junit.platform.commons.util.ReflectionUtils; + +import org.opentest4j.AssertionFailedError; + +public final class MultiRunExtension + implements TestTemplateInvocationContextProvider, AfterEachCallback, ParameterResolver { + private List envArguments; + private int envIndex = 0; + private boolean envInitialized = false; + private Method initMethod; + + @Override + public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) + throws ParameterResolutionException { + return true; + } + + @Override + public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) + throws ParameterResolutionException { + if ( envArguments == null ) { + envArguments = new ArrayList<>(); + Method testMethod = extensionContext.getRequiredTestMethod(); + + for ( ArgumentsSource source : findRepeatableAnnotations( testMethod, ArgumentsSource.class ) ) { + ArgumentsProvider argumentsProvider = + AnnotationConsumerInitializer.initialize( testMethod, ReflectionUtils.newInstance( source.value() ) ); + try { + envArguments.addAll( + argumentsProvider.provideArguments( extensionContext ) + .map( Arguments::get ) + .collect( Collectors.toList() ) ); + } + catch (Exception e) { + throw new IllegalStateException( "unable to read arguments.", e ); + } + } + } + return envArguments.get( envIndex )[parameterContext.getIndex()]; + } + + @Target({ ElementType.METHOD }) + @Retention(RetentionPolicy.RUNTIME) + @TestTemplate + @ExtendWith(MultiRunExtension.class) + public @interface EnvironmentTest { + String init() default "init"; + } + + @Target({ ElementType.METHOD }) + @Retention(RetentionPolicy.RUNTIME) + public @interface TestToExecute { + } + + @Override + public boolean supportsTestTemplate(ExtensionContext context) { + return isAnnotated( context.getTestMethod(), EnvironmentTest.class ); + } + + @Override + public Stream provideTestTemplateInvocationContexts(ExtensionContext context) { + EnvironmentTest environment = context.getTestMethod() + .flatMap( method -> AnnotationUtils.findAnnotation( method, EnvironmentTest.class ) ) + .orElseThrow( IllegalStateException::new ); + + String init = environment.init(); + Class testClass = context.getTestClass().orElseThrow(); + try { + initMethod = testClass.getDeclaredMethod( init, context.getRequiredTestMethod().getParameterTypes() ); + initMethod.setAccessible( true ); + } + catch (NoSuchMethodException e) { + throw new IllegalStateException( "Cannot locate init method.", e ); + } + + + // find actual "tests" that we'll invoke via reflection: + List testMethods = new ArrayList<>(); + for ( Method method : testClass.getDeclaredMethods() ) { + if ( AnnotationUtils.isAnnotated( method, TestToExecute.class ) ) { + testMethods.add( method ); + } + } + + if ( testMethods.isEmpty() ) { + throw new IllegalStateException( "No tests to execute were found." ); + } + + return stream( spliteratorUnknownSize( + new Iterator() { + + Iterator test = testMethods.iterator(); + + @Override + public boolean hasNext() { + if ( Boolean.TRUE.equals( read( context, StoreKey.STOP_RUNNING, Boolean.class ) ) ) { + return false; + } + if ( test.hasNext() ) { + return true; + } + else { + envIndex++; + envInitialized = false; + } + + if ( envIndex < envArguments.size() ) { + test = testMethods.iterator(); + return test.hasNext(); + } + return false; + } + + @Override + public TestTemplateInvocationContext next() { + if ( envInitialized ) { + Method testMethod = test.next(); + write( context, StoreKey.TEST_TO_RUN, testMethod ); + + return new TestTemplateInvocationContext() { + + @Override + public String getDisplayName(int invocationIndex) { + return "Env #" + envIndex + ": " + testMethod.getName(); + } + + @Override + public List getAdditionalExtensions() { + return TestTemplateInvocationContext.super.getAdditionalExtensions(); + } + }; + } + else { + return new TestTemplateInvocationContext() { + + @Override + public String getDisplayName(int invocationIndex) { + return "Env #" + envIndex + ": Initializing"; + } + + @Override + public List getAdditionalExtensions() { + return TestTemplateInvocationContext.super.getAdditionalExtensions(); + } + }; + } + } + }, Spliterator.NONNULL + ), false ); + } + + @Override + public void afterEach(ExtensionContext extensionContext) throws Exception { + // that's where we actually execute the test or init the env + if ( envInitialized ) { + Method testMethod = read( extensionContext, StoreKey.TEST_TO_RUN, Method.class ); + testMethod.setAccessible( true ); + testMethod.invoke( extensionContext.getRequiredTestInstance() ); + } + else { + try { + initMethod.invoke( extensionContext.getRequiredTestInstance(), envArguments.get( envIndex ) ); + } + catch (Exception e) { + envIndex++; + if ( envIndex >= envArguments.size() ) { + write( extensionContext, StoreKey.STOP_RUNNING, Boolean.TRUE ); + } + throw new AssertionFailedError( "Unable to init the env, stopping further execution", e ); + } + envInitialized = true; + } + } + + private void write(ExtensionContext context, StoreKey key, Object value) { + ExtensionContext.Store store = context.getRoot().getStore( + ExtensionContext.Namespace.create( context.getRequiredTestMethod() ) + ); + store.put( key, value ); + } + + private T read(ExtensionContext context, StoreKey key, Class clazz) { + ExtensionContext.Store store = context.getRoot().getStore( + ExtensionContext.Namespace.create( context.getRequiredTestMethod() ) + ); + return store.get( key, clazz ); + } + + private enum StoreKey { + TEST_TO_RUN, + STOP_RUNNING + } +} diff --git a/util/internal/test/common/src/test/java/org/hibernate/search/util/impl/test/reflect/Experiments2Test.java b/util/internal/test/common/src/test/java/org/hibernate/search/util/impl/test/reflect/Experiments2Test.java new file mode 100644 index 00000000000..bdef4195fe7 --- /dev/null +++ b/util/internal/test/common/src/test/java/org/hibernate/search/util/impl/test/reflect/Experiments2Test.java @@ -0,0 +1,295 @@ +/* + * Hibernate Search, full-text search for your domain model + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.search.util.impl.test.reflect; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.hibernate.search.util.impl.test.function.ThrowingConsumer; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DynamicContainer; +import org.junit.jupiter.api.DynamicNode; +import org.junit.jupiter.api.DynamicTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestFactory; +import org.junit.jupiter.api.extension.BeforeEachCallback; +import org.junit.jupiter.api.extension.DynamicTestInvocationContext; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.InvocationInterceptor; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.platform.commons.util.AnnotationUtils; +import org.junit.platform.commons.util.ReflectionUtils; + +class Experiments2Test { + + public static Stream params() { + return Stream.of( + Arguments.of( "string1", 1, true ), + Arguments.of( "string2", 2, true ), + Arguments.of( "string3", 3, true ), + Arguments.of( "string3", 3, false ) + ); + } + + @TempDir + Path tmpDir; + + @RegisterExtension + MyExtension extensionTrackingState = new MyExtension(); + + @TestFactory + Stream testsInnerClass() throws Exception { + return generateTestSequence( + new TestDescriptor<>( this, TheseAreActualTestsButTheyNeedParameters.class ), + params() + ); + } + + @BeforeEach + void setUp() { + // this will work as "before all" for the "nested test class" that we execute from `@TestFactory` + System.err.println( "before each root" ); + } + + @TestFactory + Stream testsFromStaticClass() throws Exception { + return generateTestSequence( + new TestDescriptor<>( this, TheseAreActualTestsButTheyNeedParametersAndClassIsStatic.class ), + params() + ); + } + + public class TheseAreActualTestsButTheyNeedParameters { + + // inner extensions won't work: + @TempDir + Path otherTmpDir; + + private final String string; + private final int number; + private final boolean bool; + + public TheseAreActualTestsButTheyNeedParameters(String string, int number, boolean bool) { + this.string = string; + this.number = number; + this.bool = bool; + + if ( extensionTrackingState.setState( new Object[] { string, number } ) ) { + System.err.println( "init" ); + System.err.println( "\t" + string ); + System.err.println( "\t" + number ); + System.err.println( "\t" + bool ); + } + } + + @Test + void test1() { + System.err.println( "test1" ); + assertThat( number ).isPositive(); + } + + @Test + void test2() { + System.err.println( "test2" ); + assertThat( string ).startsWith( "string" ); + } + + @Test + void test3() { + assertThat( tmpDir ).exists(); + assertThat( otherTmpDir ).isNull(); + } + } + + public static class TheseAreActualTestsButTheyNeedParametersAndClassIsStatic { + private final String string; + private final int number; + private final boolean bool; + + public TheseAreActualTestsButTheyNeedParametersAndClassIsStatic(String string, int number, boolean bool) { + this.string = string; + this.number = number; + this.bool = bool; + + System.err.println( "init" ); + System.err.println( "\t" + string ); + System.err.println( "\t" + number ); + System.err.println( "\t" + bool ); + } + + @BeforeEach + void setUp() { + System.err.println( "do something important" ); + } + + @Test + void test1() { + System.err.println( "test1" ); + assertThat( number ).isPositive(); + } + + @Test + void test2() { + System.err.println( "test2" ); + assertThat( string ).startsWith( "string" ); + } + } + + + public static class TestDescriptor { + private final Object testInstance; + private final Class klass; + private Constructor constructor; + + private final boolean innerClass; + + protected TestDescriptor(Object testInstance, Class klass) { + this.testInstance = testInstance; + this.klass = klass; + this.innerClass = klass.isMemberClass() && !Modifier.isStatic( klass.getModifiers() ); + } + + T createInstance(Object[] arguments) throws InvocationTargetException, InstantiationException, IllegalAccessException { + if ( constructor == null ) { + List> candidateConstructors = Arrays.stream( klass.getDeclaredConstructors() ) + .filter( c -> c.getParameterCount() == arguments.length + ( innerClass ? 1 : 0 ) ) + .collect( Collectors.toList() ); + if ( candidateConstructors.size() != 1 ) { + throw new IllegalStateException( + "Cannot find a suitable constructor to instantiate the test instance of " + klass + + ". Test class must have a single constructor with arguments matching the list of parameter arguments" ); + } + + constructor = (Constructor) candidateConstructors.get( 0 ); + constructor.setAccessible( true ); + } + + Object[] args; + if ( innerClass ) { + args = new Object[arguments.length + 1]; + args[0] = testInstance; + for ( int i = 0; i < arguments.length; i++ ) { + args[i + 1] = arguments[i]; + } + } + else { + args = arguments; + } + + return constructor.newInstance( args ); + } + + private String groupTestName(Object[] arguments) { + StringBuilder sb = new StringBuilder( "Configuration group:" ); + for ( int i = 0; i < arguments.length; i++ ) { + sb.append( " [%s]" ); + } + return String.format( Locale.ROOT, sb.toString(), arguments ); + } + + } + + public static Stream generateTestSequence( + TestDescriptor testDescriptor, + Stream streamOfArguments + ) { + + final List testMethods = AnnotationUtils.findAnnotatedMethods( testDescriptor.klass, Test.class, + ReflectionUtils.HierarchyTraversalMode.TOP_DOWN ); + if ( testMethods.isEmpty() ) { + throw new IllegalStateException( "No test methods found in " + testDescriptor.klass + + ". Class must have tests annotated with @org.junit.jupiter.api.Test." ); + } + + final List beforeEachMethods = AnnotationUtils.findAnnotatedMethods( testDescriptor.klass, + BeforeEach.class, ReflectionUtils.HierarchyTraversalMode.TOP_DOWN ); + + final List afterEachMethods = AnnotationUtils.findAnnotatedMethods( testDescriptor.klass, + AfterEach.class, ReflectionUtils.HierarchyTraversalMode.BOTTOM_UP ); + + return streamOfArguments + .map( args -> { + final Object[] arguments = args.get(); + + final ThrowingConsumer testInvoker = new ThrowingConsumer<>() { + private Object instance; + + @Override + public void accept(Method testMethod) throws Exception { + if ( instance == null ) { + instance = testDescriptor.createInstance( arguments ); + } + + try { + for ( Method method : beforeEachMethods ) { + method.invoke( instance ); + } + + testMethod.invoke( instance ); + + } + finally { + for ( Method method : afterEachMethods ) { + method.invoke( instance ); + } + } + } + }; + + return DynamicContainer.dynamicContainer( + testDescriptor.groupTestName( arguments ), + testMethods.stream() + .map( method -> DynamicTest.dynamicTest( + method.getName() + "()", + () -> testInvoker.accept( method ) ) + ) + ); + } ); + } + + + private static class MyExtension implements InvocationInterceptor, BeforeEachCallback { + private Object[] state; + + @Override + public void interceptDynamicTest(Invocation invocation, DynamicTestInvocationContext invocationContext, + ExtensionContext extensionContext) + throws Throwable { + System.err.println( "do nothing" ); + InvocationInterceptor.super.interceptDynamicTest( invocation, invocationContext, extensionContext ); + } + + public boolean setState(Object[] state) { + if ( !Arrays.equals( this.state, state ) ) { + this.state = state; + return true; + } + return false; + } + + @Override + public void beforeEach(ExtensionContext extensionContext) throws Exception { + // this will work as "before all" for the "nested test class" that we execute from `@TestFactory` + System.err.println( "before each from extension callback" ); + } + } + +} diff --git a/util/internal/test/common/src/test/java/org/hibernate/search/util/impl/test/reflect/ExperimentsTest.java b/util/internal/test/common/src/test/java/org/hibernate/search/util/impl/test/reflect/ExperimentsTest.java new file mode 100644 index 00000000000..b2c8f8dd9e7 --- /dev/null +++ b/util/internal/test/common/src/test/java/org/hibernate/search/util/impl/test/reflect/ExperimentsTest.java @@ -0,0 +1,58 @@ +/* + * Hibernate Search, full-text search for your domain model + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.search.util.impl.test.reflect; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Arrays; +import java.util.List; + +import org.hibernate.search.util.impl.test.extension.MultiRunExtension; + +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +class ExperimentsTest { + + public static List params() { + return Arrays.asList( + Arguments.of( "string1", 1, true ), + Arguments.of( "string2", 2, true ), + Arguments.of( "string3", 3, true ) + ); + } + + @MultiRunExtension.EnvironmentTest(init = "init") + @MethodSource("params") + void env(String string, int number, boolean bool) { + System.err.println( "env" ); + // this one is really ignored, and should be empty. alternatively it can be treated as "@BeforeEach" + // since this method will be executed before both for the init and each @TestToExecute + + // input parameters are here so we can use default argument sources and feed these arguments to the init method when needed. + // note init method *MUST* match the same type/order/amount of parameters. + } + + public void init(String string, int number, boolean bool) { + System.err.println( "init" ); + System.err.println( "\t" + string ); + System.err.println( "\t" + number ); + System.err.println( "\t" + bool ); + } + + @MultiRunExtension.TestToExecute + void test1() { + System.err.println( "test1" ); + assertThat( 1 ).isPositive(); + } + + @MultiRunExtension.TestToExecute + void test2() { + System.err.println( "test2" ); + assertThat( 1 ).isPositive(); + } +}