diff --git a/java/dagger/hilt/android/internal/lifecycle/BUILD b/java/dagger/hilt/android/internal/lifecycle/BUILD index f7314d279c9..7afd1dd9dee 100644 --- a/java/dagger/hilt/android/internal/lifecycle/BUILD +++ b/java/dagger/hilt/android/internal/lifecycle/BUILD @@ -38,6 +38,7 @@ android_library( "@maven//:androidx_lifecycle_lifecycle_viewmodel", "@maven//:androidx_lifecycle_lifecycle_viewmodel_savedstate", "@maven//:androidx_savedstate_savedstate", + "@maven//:org_jetbrains_kotlin_kotlin_stdlib", ], ) diff --git a/java/dagger/hilt/android/internal/lifecycle/HiltViewModelAssistedMap.java b/java/dagger/hilt/android/internal/lifecycle/HiltViewModelAssistedMap.java new file mode 100644 index 00000000000..69bb2b11842 --- /dev/null +++ b/java/dagger/hilt/android/internal/lifecycle/HiltViewModelAssistedMap.java @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2023 The Dagger Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dagger.hilt.android.internal.lifecycle; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import javax.inject.Qualifier; + +/** + * Internal qualifier for the multibinding map of assisted factories for @AssistedInject-annotated + * ViewModels used by the {@link dagger.hilt.android.lifecycle.HiltViewModelFactory}. + */ +@Qualifier +@Retention(RetentionPolicy.CLASS) +@Target({ElementType.METHOD, ElementType.PARAMETER}) +public @interface HiltViewModelAssistedMap {} diff --git a/java/dagger/hilt/android/internal/lifecycle/HiltViewModelFactory.java b/java/dagger/hilt/android/internal/lifecycle/HiltViewModelFactory.java index 52f31b92027..33b1618031d 100644 --- a/java/dagger/hilt/android/internal/lifecycle/HiltViewModelFactory.java +++ b/java/dagger/hilt/android/internal/lifecycle/HiltViewModelFactory.java @@ -16,12 +16,12 @@ package dagger.hilt.android.internal.lifecycle; +import static androidx.lifecycle.SavedStateHandleSupport.createSavedStateHandle; + import android.app.Activity; import android.os.Bundle; import androidx.annotation.NonNull; import androidx.annotation.Nullable; -import androidx.lifecycle.AbstractSavedStateViewModelFactory; -import androidx.lifecycle.SavedStateHandle; import androidx.lifecycle.ViewModel; import androidx.lifecycle.ViewModelProvider; import androidx.lifecycle.viewmodel.CreationExtras; @@ -37,6 +37,7 @@ import java.util.Map; import java.util.Set; import javax.inject.Provider; +import kotlin.jvm.functions.Function1; /** * View Model Provider Factory for the Hilt Extension. @@ -55,8 +56,16 @@ public final class HiltViewModelFactory implements ViewModelProvider.Factory { public interface ViewModelFactoriesEntryPoint { @HiltViewModelMap Map> getHiltViewModelMap(); + + // From ViewModel class names to user defined @AssistedFactory-annotated implementations. + @HiltViewModelAssistedMap + Map getHiltViewModelAssistedMap(); } + /** Creation extra key for the callbacks that create @AssistedInject-annotated ViewModels. */ + public static final CreationExtras.Key> CREATION_CALLBACK_KEY = + new CreationExtras.Key>() {}; + /** Hilt module for providing the empty multi-binding map of ViewModels. */ @Module @InstallIn(ViewModelComponent.class) @@ -64,11 +73,15 @@ interface ViewModelModule { @Multibinds @HiltViewModelMap Map hiltViewModelMap(); + + @Multibinds + @HiltViewModelAssistedMap + Map hiltViewModelAssistedMap(); } private final Set hiltViewModelKeys; private final ViewModelProvider.Factory delegateFactory; - private final AbstractSavedStateViewModelFactory hiltViewModelFactory; + private final ViewModelProvider.Factory hiltViewModelFactory; public HiltViewModelFactory( @NonNull Set hiltViewModelKeys, @@ -77,31 +90,75 @@ public HiltViewModelFactory( this.hiltViewModelKeys = hiltViewModelKeys; this.delegateFactory = delegateFactory; this.hiltViewModelFactory = - new AbstractSavedStateViewModelFactory() { + new ViewModelProvider.Factory() { @NonNull @Override - @SuppressWarnings("unchecked") - protected T create( - @NonNull String key, @NonNull Class modelClass, @NonNull SavedStateHandle handle) { + public T create( + @NonNull Class modelClass, @NonNull CreationExtras extras) { RetainedLifecycleImpl lifecycle = new RetainedLifecycleImpl(); - ViewModelComponent component = viewModelComponentBuilder - .savedStateHandle(handle) - .viewModelLifecycle(lifecycle) - .build(); + ViewModelComponent component = + viewModelComponentBuilder + .savedStateHandle(createSavedStateHandle(extras)) + .viewModelLifecycle(lifecycle) + .build(); + T viewModel = createViewModel(component, modelClass, extras); + viewModel.addCloseable(lifecycle::dispatchOnCleared); + return viewModel; + } + + private T createViewModel( + @NonNull ViewModelComponent component, + @NonNull Class modelClass, + @NonNull CreationExtras extras) { Provider provider = EntryPoints.get(component, ViewModelFactoriesEntryPoint.class) .getHiltViewModelMap() .get(modelClass.getName()); - if (provider == null) { - throw new IllegalStateException( - "Expected the @HiltViewModel-annotated class '" - + modelClass.getName() - + "' to be available in the multi-binding of " - + "@HiltViewModelMap but none was found."); + Function1 creationCallback = extras.get(CREATION_CALLBACK_KEY); + Object assistedFactory = + EntryPoints.get(component, ViewModelFactoriesEntryPoint.class) + .getHiltViewModelAssistedMap() + .get(modelClass.getName()); + + if (assistedFactory == null) { + if (creationCallback == null) { + if (provider == null) { + throw new IllegalStateException( + "Expected the @HiltViewModel-annotated class " + + modelClass.getName() + + " to be available in the multi-binding of " + + "@HiltViewModelMap" + + " but none was found."); + } else { + return (T) provider.get(); + } + } else { + // Provider could be null or non-null. + throw new IllegalStateException( + "Found creation callback but class " + + modelClass.getName() + + " does not have an assisted factory specified in @HiltViewModel."); + } + } else { + if (provider == null) { + if (creationCallback == null) { + throw new IllegalStateException( + "Found @HiltViewModel-annotated class " + + modelClass.getName() + + " using @AssistedInject but no creation callback" + + " was provided in CreationExtras."); + } else { + return (T) creationCallback.invoke(assistedFactory); + } + } else { + // Creation callback could be null or non-null. + throw new AssertionError( + "Found the @HiltViewModel-annotated class " + + modelClass.getName() + + " in both the multi-bindings of " + + "@HiltViewModelMap and @HiltViewModelAssistedMap."); + } } - ViewModel viewModel = provider.get(); - viewModel.addCloseable(lifecycle::dispatchOnCleared); - return (T) viewModel; } }; } diff --git a/java/dagger/hilt/android/lifecycle/HiltViewModel.java b/java/dagger/hilt/android/lifecycle/HiltViewModel.java index 198ec8abeab..72683b9badc 100644 --- a/java/dagger/hilt/android/lifecycle/HiltViewModel.java +++ b/java/dagger/hilt/android/lifecycle/HiltViewModel.java @@ -65,4 +65,11 @@ @Target(ElementType.TYPE) @Retention(RetentionPolicy.CLASS) @GeneratesRootInput -public @interface HiltViewModel {} +public @interface HiltViewModel { + /** + * Returns a factory class that can be used to create this ViewModel with assisted injection. The + * default value `Object.class` denotes that no factory is specified and the ViewModel is not + * assisted injected. + */ + Class assistedFactory() default Object.class; +} diff --git a/java/dagger/hilt/android/processor/internal/AndroidClassNames.java b/java/dagger/hilt/android/processor/internal/AndroidClassNames.java index 2d954481a61..7f5dec060ca 100644 --- a/java/dagger/hilt/android/processor/internal/AndroidClassNames.java +++ b/java/dagger/hilt/android/processor/internal/AndroidClassNames.java @@ -114,6 +114,10 @@ public final class AndroidClassNames { get("dagger.hilt.android.lifecycle", "HiltViewModel"); public static final ClassName HILT_VIEW_MODEL_MAP_QUALIFIER = get("dagger.hilt.android.internal.lifecycle", "HiltViewModelMap"); + + public static final ClassName HILT_VIEW_MODEL_ASSISTED_FACTORY_MAP_QUALIFIER = + get("dagger.hilt.android.internal.lifecycle", "HiltViewModelAssistedMap"); + public static final ClassName HILT_VIEW_MODEL_KEYS_QUALIFIER = get("dagger.hilt.android.internal.lifecycle", "HiltViewModelMap", "KeySet"); public static final ClassName VIEW_MODEL = get("androidx.lifecycle", "ViewModel"); diff --git a/java/dagger/hilt/android/processor/internal/viewmodel/BUILD b/java/dagger/hilt/android/processor/internal/viewmodel/BUILD index f3686a1c818..e5c3eaf50b8 100644 --- a/java/dagger/hilt/android/processor/internal/viewmodel/BUILD +++ b/java/dagger/hilt/android/processor/internal/viewmodel/BUILD @@ -38,6 +38,7 @@ kt_jvm_library( "//java/dagger/hilt/android/processor/internal:android_classnames", "//java/dagger/hilt/processor/internal:base_processor", "//java/dagger/hilt/processor/internal:classnames", + "//java/dagger/hilt/processor/internal:compiler_options", "//java/dagger/hilt/processor/internal:processor_errors", "//java/dagger/hilt/processor/internal:processors", "//java/dagger/internal/codegen/xprocessing", diff --git a/java/dagger/hilt/android/processor/internal/viewmodel/ViewModelMetadata.kt b/java/dagger/hilt/android/processor/internal/viewmodel/ViewModelMetadata.kt index 49b35a7d93f..920410a0b94 100644 --- a/java/dagger/hilt/android/processor/internal/viewmodel/ViewModelMetadata.kt +++ b/java/dagger/hilt/android/processor/internal/viewmodel/ViewModelMetadata.kt @@ -16,83 +16,181 @@ package dagger.hilt.android.processor.internal.viewmodel +import androidx.room.compiler.codegen.XTypeName +import androidx.room.compiler.codegen.toJavaPoet import androidx.room.compiler.processing.ExperimentalProcessingApi +import androidx.room.compiler.processing.XMethodElement import androidx.room.compiler.processing.XProcessingEnv import androidx.room.compiler.processing.XTypeElement import com.squareup.javapoet.ClassName import dagger.hilt.android.processor.internal.AndroidClassNames import dagger.hilt.processor.internal.ClassNames +import dagger.hilt.processor.internal.HiltCompilerOptions import dagger.hilt.processor.internal.ProcessorErrors import dagger.hilt.processor.internal.Processors import dagger.internal.codegen.xprocessing.XAnnotations +import dagger.internal.codegen.xprocessing.XElements +import dagger.internal.codegen.xprocessing.XTypeElements import dagger.internal.codegen.xprocessing.XTypes /** Data class that represents a Hilt injected ViewModel */ @OptIn(ExperimentalProcessingApi::class) -internal class ViewModelMetadata private constructor(val typeElement: XTypeElement) { - val className = typeElement.className +internal class ViewModelMetadata +private constructor(val viewModelElement: XTypeElement, val assistedFactory: XTypeElement) { + val className = viewModelElement.asClassName().toJavaPoet() + + val assistedFactoryClassName: ClassName = assistedFactory.asClassName().toJavaPoet() val modulesClassName = ClassName.get( - typeElement.packageName, + viewModelElement.packageName, "${className.simpleNames().joinToString("_")}_HiltModules" ) companion object { + + private const val ASSISTED_FACTORY_VALUE = "assistedFactory" + + fun getAssistedFactoryMethods(factory: XTypeElement?): List { + return XTypeElements.getAllNonPrivateInstanceMethods(factory) + .filter { it.isAbstract() } + .filter { !it.isJavaDefault() } + } + internal fun create( processingEnv: XProcessingEnv, - typeElement: XTypeElement, + viewModelElement: XTypeElement, ): ViewModelMetadata? { ProcessorErrors.checkState( - XTypes.isSubtype(typeElement.type, processingEnv.requireType(AndroidClassNames.VIEW_MODEL)), - typeElement, + XTypes.isSubtype( + viewModelElement.type, + processingEnv.requireType(AndroidClassNames.VIEW_MODEL) + ), + viewModelElement, "@HiltViewModel is only supported on types that subclass %s.", AndroidClassNames.VIEW_MODEL ) - typeElement - .getConstructors() - .filter { constructor -> - ProcessorErrors.checkState( - !constructor.hasAnnotation(ClassNames.ASSISTED_INJECT), - constructor, - "ViewModel constructor should be annotated with @Inject instead of @AssistedInject." - ) - constructor.hasAnnotation(ClassNames.INJECT) - } - .let { injectConstructors -> - ProcessorErrors.checkState( - injectConstructors.size == 1, - typeElement, - "@HiltViewModel annotated class should contain exactly one @Inject " + - "annotated constructor." - ) - - injectConstructors.forEach { injectConstructor -> + val isAssistedInjectFeatureEnabled = + HiltCompilerOptions.isAssistedInjectViewModelsEnabled(viewModelElement) + + val assistedFactoryType = + viewModelElement + .requireAnnotation(AndroidClassNames.HILT_VIEW_MODEL) + .getAsType(ASSISTED_FACTORY_VALUE) + val assistedFactory = assistedFactoryType.typeElement!! + + if (assistedFactoryType.asTypeName() != XTypeName.ANY_OBJECT) { + ProcessorErrors.checkState( + isAssistedInjectFeatureEnabled, + viewModelElement, + "Specified assisted factory %s for %s in @HiltViewModel but compiler option 'enableAssistedInjectViewModels' was not enabled.", + assistedFactoryType.asTypeName().toJavaPoet(), + XElements.toStableString(viewModelElement), + ) + + ProcessorErrors.checkState( + assistedFactory.hasAnnotation(ClassNames.ASSISTED_FACTORY), + viewModelElement, + "Class %s is not annotated with @AssistedFactory.", + assistedFactoryType.asTypeName().toJavaPoet() + ) + + val assistedFactoryMethod = getAssistedFactoryMethods(assistedFactory).singleOrNull() + + ProcessorErrors.checkState( + assistedFactoryMethod != null, + assistedFactory, + "Cannot find assisted factory method in %s.", + XElements.toStableString(assistedFactory) + ) + + val assistedFactoryMethodType = assistedFactoryMethod!!.asMemberOf(assistedFactoryType) + + ProcessorErrors.checkState( + assistedFactoryMethodType.returnType.asTypeName() == viewModelElement.asClassName(), + assistedFactoryMethod, + "Class %s must have a factory method that returns a %s. Found %s.", + XElements.toStableString(assistedFactory), + XElements.toStableString(viewModelElement), + XTypes.toStableString(assistedFactoryMethodType.returnType) + ) + } + + val injectConstructors = + viewModelElement.getConstructors().filter { constructor -> + if (isAssistedInjectFeatureEnabled) { + constructor.hasAnnotation(ClassNames.INJECT) || + constructor.hasAnnotation(ClassNames.ASSISTED_INJECT) + } else { ProcessorErrors.checkState( - !injectConstructor.isPrivate(), - injectConstructor, - "@Inject annotated constructors must not be private." + !constructor.hasAnnotation(ClassNames.ASSISTED_INJECT), + constructor, + "ViewModel constructor should be annotated with @Inject instead of @AssistedInject." ) + constructor.hasAnnotation(ClassNames.INJECT) } } + val injectAnnotationsMessage = + if (isAssistedInjectFeatureEnabled) { + "@Inject or @AssistedInject" + } else { + "@Inject" + } + + ProcessorErrors.checkState( + injectConstructors.size == 1, + viewModelElement, + "@HiltViewModel annotated class should contain exactly one %s annotated constructor.", + injectAnnotationsMessage + ) + + val injectConstructor = injectConstructors.single() + + ProcessorErrors.checkState( + !injectConstructor.isPrivate(), + injectConstructor, + "%s annotated constructors must not be private.", + injectAnnotationsMessage + ) + + if (injectConstructor.hasAnnotation(ClassNames.ASSISTED_INJECT)) { + // If "enableAssistedInjectViewModels" is not enabled we'll get error: + // "ViewModel constructor should be annotated with @Inject instead of @AssistedInject." + + ProcessorErrors.checkState( + assistedFactoryType.asTypeName() != XTypeName.ANY_OBJECT, + viewModelElement, + "%s must have a valid assisted factory specified in @HiltViewModel when used with assisted injection. Found %s.", + XElements.toStableString(viewModelElement), + XTypes.toStableString(assistedFactoryType) + ) + } else { + ProcessorErrors.checkState( + assistedFactoryType.asTypeName() == XTypeName.ANY_OBJECT, + injectConstructor, + "Found assisted factory %s in @HiltViewModel but the constructor was annotated with @Inject instead of @AssistedInject.", + XTypes.toStableString(assistedFactoryType), + ) + } + ProcessorErrors.checkState( - !typeElement.isNested() || typeElement.isStatic(), - typeElement, + !viewModelElement.isNested() || viewModelElement.isStatic(), + viewModelElement, "@HiltViewModel may only be used on inner classes if they are static." ) - Processors.getScopeAnnotations(typeElement).let { scopeAnnotations -> + Processors.getScopeAnnotations(viewModelElement).let { scopeAnnotations -> ProcessorErrors.checkState( scopeAnnotations.isEmpty(), - typeElement, + viewModelElement, "@HiltViewModel classes should not be scoped. Found: %s", scopeAnnotations.joinToString { XAnnotations.toStableString(it) } ) } - return ViewModelMetadata(typeElement) + return ViewModelMetadata(viewModelElement, assistedFactory) } } } diff --git a/java/dagger/hilt/android/processor/internal/viewmodel/ViewModelModuleGenerator.kt b/java/dagger/hilt/android/processor/internal/viewmodel/ViewModelModuleGenerator.kt index 29f8c3ce528..e7b3dad2f2f 100644 --- a/java/dagger/hilt/android/processor/internal/viewmodel/ViewModelModuleGenerator.kt +++ b/java/dagger/hilt/android/processor/internal/viewmodel/ViewModelModuleGenerator.kt @@ -16,6 +16,7 @@ package dagger.hilt.android.processor.internal.viewmodel +import androidx.room.compiler.codegen.XTypeName import androidx.room.compiler.processing.ExperimentalProcessingApi import androidx.room.compiler.processing.XProcessingEnv import androidx.room.compiler.processing.addOriginatingElement @@ -23,6 +24,7 @@ import com.squareup.javapoet.AnnotationSpec import com.squareup.javapoet.ClassName import com.squareup.javapoet.JavaFile import com.squareup.javapoet.MethodSpec +import com.squareup.javapoet.TypeName import com.squareup.javapoet.TypeSpec import dagger.hilt.android.processor.internal.AndroidClassNames import dagger.hilt.processor.internal.ClassNames @@ -60,20 +62,20 @@ import javax.lang.model.element.Modifier @OptIn(ExperimentalProcessingApi::class) internal class ViewModelModuleGenerator( private val processingEnv: XProcessingEnv, - private val injectedViewModel: ViewModelMetadata + private val viewModelMetadata: ViewModelMetadata ) { fun generate() { val modulesTypeSpec = - TypeSpec.classBuilder(injectedViewModel.modulesClassName) + TypeSpec.classBuilder(viewModelMetadata.modulesClassName) .apply { - addOriginatingElement(injectedViewModel.typeElement) + addOriginatingElement(viewModelMetadata.viewModelElement) Processors.addGeneratedAnnotation(this, processingEnv, ViewModelProcessor::class.java) addAnnotation( AnnotationSpec.builder(ClassNames.ORIGINATING_ELEMENT) .addMember( "topLevelClass", "$T.class", - injectedViewModel.className.topLevelClassName() + viewModelMetadata.className.topLevelClassName() ) .build() ) @@ -85,7 +87,7 @@ internal class ViewModelModuleGenerator( .build() processingEnv.filer.write( - JavaFile.builder(injectedViewModel.modulesClassName.packageName(), modulesTypeSpec).build() + JavaFile.builder(viewModelMetadata.modulesClassName.packageName(), modulesTypeSpec).build() ) } @@ -96,7 +98,13 @@ internal class ViewModelModuleGenerator( ) .addModifiers(Modifier.ABSTRACT) .addMethod(MethodSpec.constructorBuilder().addModifiers(Modifier.PRIVATE).build()) - .addMethod(getViewModelBindsMethod()) + .addMethod( + if (viewModelMetadata.assistedFactory.asClassName() != XTypeName.ANY_OBJECT) { + getAssistedViewModelBindsMethod() + } else { + getViewModelBindsMethod() + } + ) .build() private fun getViewModelBindsMethod() = @@ -105,13 +113,13 @@ internal class ViewModelModuleGenerator( .addAnnotation(ClassNames.INTO_MAP) .addAnnotation( AnnotationSpec.builder(ClassNames.STRING_KEY) - .addMember("value", S, injectedViewModel.className.reflectionName()) + .addMember("value", S, viewModelMetadata.className.reflectionName()) .build() ) .addAnnotation(AndroidClassNames.HILT_VIEW_MODEL_MAP_QUALIFIER) .addModifiers(Modifier.PUBLIC, Modifier.ABSTRACT) .returns(AndroidClassNames.VIEW_MODEL) - .addParameter(injectedViewModel.className, "vm") + .addParameter(viewModelMetadata.className, "vm") .build() private fun getKeyModuleTypeSpec() = @@ -131,12 +139,40 @@ internal class ViewModelModuleGenerator( .addAnnotation(AndroidClassNames.HILT_VIEW_MODEL_KEYS_QUALIFIER) .addModifiers(Modifier.PUBLIC, Modifier.STATIC) .returns(String::class.java) - .addStatement("return $S", injectedViewModel.className.reflectionName()) + .addStatement("return $S", viewModelMetadata.className.reflectionName()) + .build() + + /** + * Should generate: + * ``` + * @Binds + * @IntoMap + * @StringKey("pkg.FooViewModel") + * @HiltViewModelAssistedMap + * public abstract Object bind(FooViewModelAssistedFactory factory); + * ``` + * + * So that we have a HiltViewModelAssistedMap that maps from fully qualified ViewModel names to + * its assisted factory instance. + */ + private fun getAssistedViewModelBindsMethod() = + MethodSpec.methodBuilder("bind") + .addAnnotation(ClassNames.BINDS) + .addAnnotation(ClassNames.INTO_MAP) + .addAnnotation( + AnnotationSpec.builder(ClassNames.STRING_KEY) + .addMember("value", S, viewModelMetadata.className.reflectionName()) + .build() + ) + .addAnnotation(AndroidClassNames.HILT_VIEW_MODEL_ASSISTED_FACTORY_MAP_QUALIFIER) + .addModifiers(Modifier.PUBLIC, Modifier.ABSTRACT) + .addParameter(viewModelMetadata.assistedFactoryClassName, "factory") + .returns(TypeName.OBJECT) .build() private fun createModuleTypeSpec(className: String, component: ClassName) = TypeSpec.classBuilder(className) - .addOriginatingElement(injectedViewModel.typeElement) + .addOriginatingElement(viewModelMetadata.viewModelElement) .addAnnotation(ClassNames.MODULE) .addAnnotation( AnnotationSpec.builder(ClassNames.INSTALL_IN) diff --git a/java/dagger/hilt/android/processor/internal/viewmodel/ViewModelProcessingStep.kt b/java/dagger/hilt/android/processor/internal/viewmodel/ViewModelProcessingStep.kt index 69f7ac28493..55cb289fcd2 100644 --- a/java/dagger/hilt/android/processor/internal/viewmodel/ViewModelProcessingStep.kt +++ b/java/dagger/hilt/android/processor/internal/viewmodel/ViewModelProcessingStep.kt @@ -28,16 +28,13 @@ import dagger.internal.codegen.xprocessing.XElements @OptIn(ExperimentalProcessingApi::class) /** Annotation processor for @ViewModelInject. */ class ViewModelProcessingStep(env: XProcessingEnv) : BaseProcessingStep(env) { + override fun annotationClassNames() = ImmutableSet.of(AndroidClassNames.HILT_VIEW_MODEL) override fun processEach(annotation: ClassName, element: XElement) { val typeElement = XElements.asTypeElement(element) - ViewModelMetadata.create( - processingEnv(), - typeElement, - ) - ?.let { viewModelMetadata -> - ViewModelModuleGenerator(processingEnv(), viewModelMetadata).generate() - } + ViewModelMetadata.create(processingEnv(), typeElement)?.let { viewModelMetadata -> + ViewModelModuleGenerator(processingEnv(), viewModelMetadata).generate() + } } } diff --git a/java/dagger/hilt/processor/internal/ClassNames.java b/java/dagger/hilt/processor/internal/ClassNames.java index ec2e73ed407..890b404ae53 100644 --- a/java/dagger/hilt/processor/internal/ClassNames.java +++ b/java/dagger/hilt/processor/internal/ClassNames.java @@ -73,6 +73,7 @@ public final class ClassNames { get("dagger.hilt.internal.definecomponent", "DefineComponentClasses"); public static final ClassName ASSISTED_INJECT = get("dagger.assisted", "AssistedInject"); + public static final ClassName ASSISTED_FACTORY = get("dagger.assisted", "AssistedFactory"); public static final ClassName BINDS = get("dagger", "Binds"); public static final ClassName BINDS_OPTIONAL_OF = diff --git a/java/dagger/hilt/processor/internal/HiltCompilerOptions.java b/java/dagger/hilt/processor/internal/HiltCompilerOptions.java index cdb2d6c980c..90575599bc7 100644 --- a/java/dagger/hilt/processor/internal/HiltCompilerOptions.java +++ b/java/dagger/hilt/processor/internal/HiltCompilerOptions.java @@ -21,6 +21,7 @@ import androidx.room.compiler.processing.XProcessingEnv; import androidx.room.compiler.processing.XTypeElement; +import androidx.room.compiler.processing.compat.XConverters; import com.google.common.collect.ImmutableSet; import dagger.hilt.processor.internal.optionvalues.BooleanValue; import dagger.hilt.processor.internal.optionvalues.GradleProjectType; @@ -101,6 +102,13 @@ public static GradleProjectType getGradleProjectType(XProcessingEnv env) { return GRADLE_PROJECT_TYPE.get(env); } + public static boolean isAssistedInjectViewModelsEnabled(XTypeElement viewModelElement) { + boolean enabled = + ENABLE_ASSISTED_INJECT_VIEWMODELS.get(XConverters.getProcessingEnv(viewModelElement)) + == BooleanValue.TRUE; + return enabled; + } + /** Do not use! This is for internal use only. */ private static final EnumOption DISABLE_ANDROID_SUPERCLASS_VALIDATION = new EnumOption<>("android.internal.disableAndroidSuperclassValidation", BooleanValue.FALSE); @@ -124,6 +132,9 @@ public static GradleProjectType getGradleProjectType(XProcessingEnv env) { private static final EnumOption GRADLE_PROJECT_TYPE = new EnumOption<>("android.internal.projectType", GradleProjectType.UNSET); + private static final EnumOption ENABLE_ASSISTED_INJECT_VIEWMODELS = + new EnumOption<>( + "enableAssistedInjectViewModels", BooleanValue.TRUE ); private static final ImmutableSet DEPRECATED_OPTIONS = ImmutableSet.of("dagger.hilt.android.useFragmentGetContextFix"); diff --git a/javatests/dagger/hilt/android/AndroidManifest.xml b/javatests/dagger/hilt/android/AndroidManifest.xml index 3c0fa840d1c..a170fb019f7 100644 --- a/javatests/dagger/hilt/android/AndroidManifest.xml +++ b/javatests/dagger/hilt/android/AndroidManifest.xml @@ -65,6 +65,30 @@ android:name=".ViewModelScopedTest$TestActivity" android:exported="false" tools:ignore="MissingClass"/> + + + + + + scenario = + ActivityScenario.launch(TestConfigChangeActivity.class)) { + scenario.onActivity( + activity -> { + assertThat(activity.vm.one.bar).isNotNull(); + assertThat(activity.vm.one.bar).isSameInstanceAs(activity.vm.two.bar); + assertThat(activity.vm.s).isEqualTo("foo"); + }); + scenario.recreate(); + scenario.onActivity( + activity -> { + // Check that we still get the same ViewModel instance after config change and the + // passed assisted arg has no effect anymore. + assertThat(activity.vm.s).isEqualTo("foo"); + }); + } + } + + @Test + public void testKeyedViewModels() { + try (ActivityScenario scenario = + ActivityScenario.launch(TestKeyedViewModelActivity.class)) { + scenario.onActivity( + activity -> { + assertThat(activity.vm1.s).isEqualTo("foo"); + assertThat(activity.vm2.s).isEqualTo("bar"); + }); + } + } + + @Test + public void testNoCreationCallbacks() { + Exception exception = + assertThrows( + IllegalStateException.class, + () -> ActivityScenario.launch(TestNoCreationCallbacksActivity.class).close()); + assertThat(exception) + .hasMessageThat() + .contains( + "Found @HiltViewModel-annotated class" + + " dagger.hilt.android.ViewModelAssistedTest$MyViewModel" + + " using @AssistedInject but no creation callback was provided" + + " in CreationExtras."); + } + + @Test + public void testNoFactory() { + Exception exception = + assertThrows( + RuntimeException.class, + () -> ActivityScenario.launch(TestNoFactoryActivity.class).close()); + assertThat(exception) + .hasMessageThat() + .contains( + "Found creation callback but class" + + " dagger.hilt.android.ViewModelAssistedTest$MyInjectedViewModel does not have an" + + " assisted factory specified in @HiltViewModel."); + } + + @Test + public void testFragmentArgs() { + try (ActivityScenario scenario = + ActivityScenario.launch(TestFragmentArgsActivity.class)) { + scenario.onActivity( + activity -> { + TestFragment fragment = + (TestFragment) activity.getSupportFragmentManager().findFragmentByTag("tag"); + assertThat(fragment.vm.handle.get("key")).isEqualTo("foobar"); + }); + } + } + + @Test + public void testIncompatibleFactories() { + Exception exception = + assertThrows( + ClassCastException.class, + () -> ActivityScenario.launch(TestIncompatibleFactoriesActivity.class).close()); + assertThat(exception) + .hasMessageThat() + .contains( + "class dagger.hilt.android.ViewModelAssistedTest_MyViewModel_Factory_Impl cannot be" + + " cast to class" + + " dagger.hilt.android.ViewModelAssistedTest$MyViewModel$AnotherFactory"); + } + + @AndroidEntryPoint(FragmentActivity.class) + public static class TestConfigChangeActivity + extends Hilt_ViewModelAssistedTest_TestConfigChangeActivity { + + MyViewModel vm; + + @Override + protected void onCreate(@Nullable Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + if (savedInstanceState == null) { + vm = + new ViewModelProvider( + getViewModelStore(), + getDefaultViewModelProviderFactory(), + getCreationExtrasWithCreationCallback( + this, factory -> ((MyViewModel.Factory) factory).create("foo"))) + .get(MyViewModel.class); + } else { + vm = + new ViewModelProvider( + getViewModelStore(), + getDefaultViewModelProviderFactory(), + getCreationExtrasWithCreationCallback( + this, factory -> ((MyViewModel.Factory) factory).create("bar"))) + .get(MyViewModel.class); + } + } + } + + @AndroidEntryPoint(FragmentActivity.class) + public static class TestKeyedViewModelActivity + extends Hilt_ViewModelAssistedTest_TestKeyedViewModelActivity { + + MyViewModel vm1; + MyViewModel vm2; + + @Override + protected void onCreate(@Nullable Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + vm1 = + new ViewModelProvider( + getViewModelStore(), + getDefaultViewModelProviderFactory(), + getCreationExtrasWithCreationCallback( + this, factory -> ((MyViewModel.Factory) factory).create("foo"))) + .get("a", MyViewModel.class); + + vm2 = + new ViewModelProvider( + getViewModelStore(), + getDefaultViewModelProviderFactory(), + getCreationExtrasWithCreationCallback( + this, factory -> ((MyViewModel.Factory) factory).create("bar"))) + .get("b", MyViewModel.class); + } + } + + @AndroidEntryPoint(FragmentActivity.class) + public static class TestNoCreationCallbacksActivity + extends Hilt_ViewModelAssistedTest_TestNoCreationCallbacksActivity { + + MyViewModel vm; + + @Override + protected void onCreate(@Nullable Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + vm = new ViewModelProvider(this).get(MyViewModel.class); + } + } + + @AndroidEntryPoint(FragmentActivity.class) + public static class TestNoFactoryActivity + extends Hilt_ViewModelAssistedTest_TestNoFactoryActivity { + + MyInjectedViewModel vm; + + @Override + protected void onCreate(@Nullable Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + + vm = + new ViewModelProvider( + getViewModelStore(), + getDefaultViewModelProviderFactory(), + getCreationExtrasWithCreationCallback( + this, factory -> ((MyViewModel.Factory) factory).create("bar"))) + .get(MyInjectedViewModel.class); + } + } + + @AndroidEntryPoint(FragmentActivity.class) + public static class TestFragmentArgsActivity + extends Hilt_ViewModelAssistedTest_TestFragmentArgsActivity { + @Override + protected void onCreate(@Nullable Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + if (savedInstanceState == null) { + Fragment f = + getSupportFragmentManager() + .getFragmentFactory() + .instantiate(TestFragment.class.getClassLoader(), TestFragment.class.getName()); + Bundle b = new Bundle(); + b.putString("key", "foobar"); + f.setArguments(b); + getSupportFragmentManager().beginTransaction().add(0, f, "tag").commitNow(); + } + } + } + + @AndroidEntryPoint(FragmentActivity.class) + public static class TestIncompatibleFactoriesActivity + extends Hilt_ViewModelAssistedTest_TestIncompatibleFactoriesActivity { + + MyViewModel vm; + + @Override + protected void onCreate(@Nullable Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + vm = + new ViewModelProvider( + getViewModelStore(), + getDefaultViewModelProviderFactory(), + getCreationExtrasWithCreationCallback( + this, factory -> ((MyViewModel.AnotherFactory) factory).create("foo"))) + .get(MyViewModel.class); + } + } + + @AndroidEntryPoint(Fragment.class) + public static class TestFragment extends Hilt_ViewModelAssistedTest_TestFragment { + + MyViewModel vm; + + @Override + public void onCreate(@Nullable Bundle bundle) { + super.onCreate(bundle); + vm = + new ViewModelProvider( + getViewModelStore(), + getDefaultViewModelProviderFactory(), + getCreationExtrasWithCreationCallback( + this, factory -> ((MyViewModel.Factory) factory).create("foo"))) + .get(MyViewModel.class); + } + } + + private static CreationExtras getCreationExtrasWithCreationCallback( + HasDefaultViewModelProviderFactory owner, Function1 callback) { + MutableCreationExtras extras = + new MutableCreationExtras(owner.getDefaultViewModelCreationExtras()); + extras.set(HiltViewModelFactory.CREATION_CALLBACK_KEY, callback); + return extras; + } + + @HiltViewModel(assistedFactory = MyViewModel.Factory.class) + static class MyViewModel extends ViewModel { + + final DependsOnBarOne one; + final DependsOnBarTwo two; + final SavedStateHandle handle; + final String s; + boolean cleared = false; + + @AssistedInject + MyViewModel( + DependsOnBarOne one, + DependsOnBarTwo two, + ViewModelLifecycle lifecycle, + SavedStateHandle handle, + @Assisted String s) { + this.one = one; + this.two = two; + this.s = s; + this.handle = handle; + lifecycle.addOnClearedListener(() -> cleared = true); + } + + @AssistedFactory + interface Factory { + MyViewModel create(String s); + } + + @AssistedFactory + interface AnotherFactory { + MyViewModel create(String s); + } + } + + @HiltViewModel + static class MyInjectedViewModel extends ViewModel { + + final DependsOnBarOne one; + final DependsOnBarTwo two; + final SavedStateHandle handle; + boolean cleared = false; + + @Inject + MyInjectedViewModel( + DependsOnBarOne one, + DependsOnBarTwo two, + ViewModelLifecycle lifecycle, + SavedStateHandle handle) { + this.one = one; + this.two = two; + this.handle = handle; + lifecycle.addOnClearedListener(() -> cleared = true); + } + } + + @ViewModelScoped + static class Bar { + @Inject + Bar() {} + } + + static class DependsOnBarOne { + final Bar bar; + + @Inject + DependsOnBarOne(Bar bar) { + this.bar = bar; + } + } + + static class DependsOnBarTwo { + final Bar bar; + + @Inject + DependsOnBarTwo(Bar bar) { + this.bar = bar; + } + } +} diff --git a/javatests/dagger/hilt/android/processor/internal/viewmodel/BUILD b/javatests/dagger/hilt/android/processor/internal/viewmodel/BUILD index 5dfb0aaf6e2..53bc73c8c8e 100644 --- a/javatests/dagger/hilt/android/processor/internal/viewmodel/BUILD +++ b/javatests/dagger/hilt/android/processor/internal/viewmodel/BUILD @@ -15,8 +15,8 @@ # Description: # Tests for internal code for implementing Hilt processors. -load("@io_bazel_rules_kotlin//kotlin:kotlin.bzl", "kt_jvm_library") load("//java/dagger/testing/compile:macros.bzl", "kt_compiler_test") +load("@io_bazel_rules_kotlin//kotlin:kotlin.bzl", "kt_jvm_library") package(default_visibility = ["//:src"]) @@ -33,6 +33,7 @@ kt_compiler_test( "//java/dagger/hilt/android/testing/compile", "//java/dagger/internal/codegen/xprocessing", "//java/dagger/internal/codegen/xprocessing:xprocessing-testing", + "//third_party/java/guava/collect", "//third_party/java/junit", ], ) diff --git a/javatests/dagger/hilt/android/processor/internal/viewmodel/ViewModelProcessorTest.kt b/javatests/dagger/hilt/android/processor/internal/viewmodel/ViewModelProcessorTest.kt index b35aed247c9..133770084f4 100644 --- a/javatests/dagger/hilt/android/processor/internal/viewmodel/ViewModelProcessorTest.kt +++ b/javatests/dagger/hilt/android/processor/internal/viewmodel/ViewModelProcessorTest.kt @@ -18,6 +18,7 @@ package dagger.hilt.android.processor.internal.viewmodel import androidx.room.compiler.processing.ExperimentalProcessingApi import androidx.room.compiler.processing.util.Source +import com.google.common.collect.ImmutableMap import dagger.hilt.android.testing.compile.HiltCompilerTests import org.junit.Test import org.junit.runner.RunWith @@ -84,24 +85,23 @@ class ViewModelProcessorTest { } @Test - fun verifySingleAnnotatedConstructor() { + fun verifyNoAssistedInjectViewModels() { val myViewModel = Source.java( "dagger.hilt.android.test.MyViewModel", """ package dagger.hilt.android.test; + import dagger.assisted.AssistedInject; + import dagger.assisted.Assisted; import androidx.lifecycle.ViewModel; import dagger.hilt.android.lifecycle.HiltViewModel; import javax.inject.Inject; @HiltViewModel class MyViewModel extends ViewModel { - @Inject - MyViewModel() { } - - @Inject - MyViewModel(String s) { } + @AssistedInject + MyViewModel(String s, @Assisted int i) { } } """ .trimIndent() @@ -110,18 +110,67 @@ class ViewModelProcessorTest { HiltCompilerTests.hiltCompiler(myViewModel) .withAdditionalJavacProcessors(ViewModelProcessor()) .withAdditionalKspProcessors(KspViewModelProcessor.Provider()) + .withProcessorOptions(ImmutableMap.of("dagger.hilt.enableAssistedInjectViewModels", "false")) .compile { subject -> subject.compilationDidFail() - subject.hasErrorCount(2) - subject.hasErrorContaining( - "Type dagger.hilt.android.test.MyViewModel may only contain one injected constructor. Found: [@Inject dagger.hilt.android.test.MyViewModel(), @Inject dagger.hilt.android.test.MyViewModel(String)]" - ) + subject.hasErrorCount(1) subject.hasErrorContaining( - "@HiltViewModel annotated class should contain exactly one @Inject annotated constructor." + "ViewModel constructor should be annotated with @Inject instead of @AssistedInject." ) } } + @Test + fun verifySingleAnnotatedConstructor() { + val myViewModel = + Source.java( + "dagger.hilt.android.test.MyViewModel", + """ + package dagger.hilt.android.test; + + import androidx.lifecycle.ViewModel; + import dagger.hilt.android.lifecycle.HiltViewModel; + import javax.inject.Inject; + + @HiltViewModel + class MyViewModel extends ViewModel { + @Inject + MyViewModel() { } + + @Inject + MyViewModel(String s) { } + } + """ + .trimIndent() + ) + + listOf(false, true).forEach { enableAssistedInjectViewModels -> + HiltCompilerTests.hiltCompiler(myViewModel) + .withAdditionalJavacProcessors(ViewModelProcessor()) + .withAdditionalKspProcessors(KspViewModelProcessor.Provider()) + .withProcessorOptions( + ImmutableMap.of( + "dagger.hilt.enableAssistedInjectViewModels", + enableAssistedInjectViewModels.toString() + ) + ) + .compile { subject -> + subject.compilationDidFail() + subject.hasErrorCount(2) + subject.hasErrorContaining( + "Type dagger.hilt.android.test.MyViewModel may only contain one injected constructor. Found: [@Inject dagger.hilt.android.test.MyViewModel(), @Inject dagger.hilt.android.test.MyViewModel(String)]" + ) + subject.hasErrorContaining( + if (enableAssistedInjectViewModels) { + "@HiltViewModel annotated class should contain exactly one @Inject or @AssistedInject annotated constructor." + } else { + "@HiltViewModel annotated class should contain exactly one @Inject annotated constructor." + } + ) + } + } + } + @Test fun verifyNonPrivateConstructor() { val myViewModel = @@ -143,15 +192,29 @@ class ViewModelProcessorTest { .trimIndent() ) - HiltCompilerTests.hiltCompiler(myViewModel) - .withAdditionalJavacProcessors(ViewModelProcessor()) - .withAdditionalKspProcessors(KspViewModelProcessor.Provider()) - .compile { subject -> - subject.compilationDidFail() - subject.hasErrorCount(2) - subject.hasErrorContaining("Dagger does not support injection into private constructors") - subject.hasErrorContaining("@Inject annotated constructors must not be private.") - } + listOf(false, true).forEach { enableAssistedInjectViewModels -> + HiltCompilerTests.hiltCompiler(myViewModel) + .withAdditionalJavacProcessors(ViewModelProcessor()) + .withAdditionalKspProcessors(KspViewModelProcessor.Provider()) + .withProcessorOptions( + ImmutableMap.of( + "dagger.hilt.enableAssistedInjectViewModels", + enableAssistedInjectViewModels.toString() + ) + ) + .compile { subject -> + subject.compilationDidFail() + subject.hasErrorCount(2) + subject.hasErrorContaining("Dagger does not support injection into private constructors") + subject.hasErrorContaining( + if (enableAssistedInjectViewModels) { + "@Inject or @AssistedInject annotated constructors must not be private." + } else { + "@Inject annotated constructors must not be private." + } + ) + } + } } @Test @@ -225,4 +288,340 @@ class ViewModelProcessorTest { ) } } + + @Test + fun verifyAssistedFlagIsEnabled() { + val myViewModel = + Source.java( + "dagger.hilt.android.test.MyViewModel", + """ + package dagger.hilt.android.test; + + import dagger.assisted.Assisted; + import dagger.assisted.AssistedInject; + import androidx.lifecycle.ViewModel; + import dagger.hilt.android.lifecycle.HiltViewModel; + + @HiltViewModel(assistedFactory = MyFactory.class) + class MyViewModel extends ViewModel { + @AssistedInject + MyViewModel(String s, @Assisted int i) { } + } + """ + .trimIndent() + ) + val myFactory = + Source.java( + "dagger.hilt.android.test.MyFactory", + """ + package dagger.hilt.android.test; + import dagger.assisted.AssistedFactory; + @AssistedFactory + interface MyFactory { + MyViewModel create(int i); + } + """ + ) + + HiltCompilerTests.hiltCompiler(myViewModel, myFactory) + .withAdditionalJavacProcessors(ViewModelProcessor()) + .withAdditionalKspProcessors(KspViewModelProcessor.Provider()) + .withProcessorOptions(ImmutableMap.of("dagger.hilt.enableAssistedInjectViewModels", "false")) + .compile { subject -> + subject.compilationDidFail() + subject.hasErrorCount(1) + subject.hasErrorContaining( + "Specified assisted factory dagger.hilt.android.test.MyFactory for dagger.hilt.android.test.MyViewModel in @HiltViewModel but compiler option 'enableAssistedInjectViewModels' was not enabled." + ) + } + } + + @Test + fun verifyAssistedFactoryHasMethod() { + val myViewModel = + Source.java( + "dagger.hilt.android.test.MyViewModel", + """ + package dagger.hilt.android.test; + + import dagger.assisted.Assisted; + import dagger.assisted.AssistedInject; + import androidx.lifecycle.ViewModel; + import dagger.hilt.android.lifecycle.HiltViewModel; + + @HiltViewModel(assistedFactory = MyFactory.class) + class MyViewModel extends ViewModel { + @AssistedInject + MyViewModel(String s, @Assisted int i) { } + } + """ + .trimIndent() + ) + val myFactory = + Source.java( + "dagger.hilt.android.test.MyFactory", + """ + package dagger.hilt.android.test; + import dagger.assisted.AssistedFactory; + @AssistedFactory + interface MyFactory {} + """ + ) + + HiltCompilerTests.hiltCompiler(myViewModel, myFactory) + .withAdditionalJavacProcessors(ViewModelProcessor()) + .withAdditionalKspProcessors(KspViewModelProcessor.Provider()) + .withProcessorOptions(ImmutableMap.of("dagger.hilt.enableAssistedInjectViewModels", "true")) + .compile { subject -> + subject.compilationDidFail() + subject.hasErrorCount(2) + subject.hasErrorContaining( + "The @AssistedFactory-annotated type is missing an abstract, non-default method whose return type matches the assisted injection type." + ) + subject.hasErrorContaining( + "Cannot find assisted factory method in dagger.hilt.android.test.MyFactory." + ) + } + } + + @Test + fun verifyAssistedFactoryHasOnlyOneMethod() { + val myViewModel = + Source.java( + "dagger.hilt.android.test.MyViewModel", + """ + package dagger.hilt.android.test; + + import dagger.assisted.Assisted; + import dagger.assisted.AssistedInject; + import androidx.lifecycle.ViewModel; + import dagger.hilt.android.lifecycle.HiltViewModel; + + @HiltViewModel(assistedFactory = MyFactory.class) + class MyViewModel extends ViewModel { + @AssistedInject + MyViewModel(String s, @Assisted int i) { } + } + """ + .trimIndent() + ) + val myFactory = + Source.java( + "dagger.hilt.android.test.MyFactory", + """ + package dagger.hilt.android.test; + import dagger.assisted.AssistedFactory; + @AssistedFactory + interface MyFactory { + MyViewModel create(int i); + String createString(int i); + Integer createInteger(int i); + } + """ + ) + + HiltCompilerTests.hiltCompiler(myViewModel, myFactory) + .withAdditionalJavacProcessors(ViewModelProcessor()) + .withAdditionalKspProcessors(KspViewModelProcessor.Provider()) + .withProcessorOptions(ImmutableMap.of("dagger.hilt.enableAssistedInjectViewModels", "true")) + .compile { subject -> + subject.compilationDidFail() + subject.hasErrorCount(4) + subject.hasErrorContaining( + "The @AssistedFactory-annotated type should contain a single abstract, non-default method but found multiple: [dagger.hilt.android.test.MyFactory.create(int), dagger.hilt.android.test.MyFactory.createString(int), dagger.hilt.android.test.MyFactory.createInteger(int)]" + ) + subject.hasErrorContaining( + "Invalid return type: java.lang.String. An assisted factory's abstract method must return a type with an @AssistedInject-annotated constructor." + ) + subject.hasErrorContaining( + "Invalid return type: java.lang.Integer. An assisted factory's abstract method must return a type with an @AssistedInject-annotated constructor." + ) + subject.hasErrorContaining( + "Cannot find assisted factory method in dagger.hilt.android.test.MyFactory." + ) + } + } + + @Test + fun verifyAssistedFactoryIsAnnotatedWithAssistedFactory() { + val myViewModel = + Source.java( + "dagger.hilt.android.test.MyViewModel", + """ + package dagger.hilt.android.test; + + import dagger.assisted.Assisted; + import dagger.assisted.AssistedInject; + import androidx.lifecycle.ViewModel; + import dagger.hilt.android.lifecycle.HiltViewModel; + + @HiltViewModel(assistedFactory = Integer.class) + class MyViewModel extends ViewModel { + @AssistedInject + MyViewModel(String s, @Assisted int i) { } + } + """ + .trimIndent() + ) + + HiltCompilerTests.hiltCompiler(myViewModel) + .withAdditionalJavacProcessors(ViewModelProcessor()) + .withAdditionalKspProcessors(KspViewModelProcessor.Provider()) + .withProcessorOptions(ImmutableMap.of("dagger.hilt.enableAssistedInjectViewModels", "true")) + .compile { subject -> + subject.compilationDidFail() + subject.hasErrorCount(1) + subject.hasErrorContaining( + "Class java.lang.Integer is not annotated with @AssistedFactory." + ) + } + } + + @Test + fun verifyFactoryMethodHasCorrectReturnType() { + val myViewModel = + Source.java( + "dagger.hilt.android.test.MyViewModel", + """ + package dagger.hilt.android.test; + + import dagger.assisted.Assisted; + import dagger.assisted.AssistedInject; + import androidx.lifecycle.ViewModel; + import dagger.hilt.android.lifecycle.HiltViewModel; + + @HiltViewModel(assistedFactory = MyFactory.class) + class MyViewModel extends ViewModel { + @AssistedInject + MyViewModel(String s, @Assisted int i) { } + } + """ + .trimIndent() + ) + val myFactory = + Source.java( + "dagger.hilt.android.test.MyFactory", + """ + package dagger.hilt.android.test; + import dagger.assisted.AssistedFactory; + @AssistedFactory + interface MyFactory { + String create(int i); + } + """ + ) + + HiltCompilerTests.hiltCompiler(myViewModel, myFactory) + .withAdditionalJavacProcessors(ViewModelProcessor()) + .withAdditionalKspProcessors(KspViewModelProcessor.Provider()) + .withProcessorOptions(ImmutableMap.of("dagger.hilt.enableAssistedInjectViewModels", "true")) + .compile { subject -> + subject.compilationDidFail() + subject.hasErrorCount(2) + subject.hasErrorContaining( + "Invalid return type: java.lang.String. An assisted factory's abstract method must return a type with an @AssistedInject-annotated constructor." + ) + subject.hasErrorContaining( + "Class dagger.hilt.android.test.MyFactory must have a factory method that returns a dagger.hilt.android.test.MyViewModel. Found java.lang.String." + ) + } + } + + @Test + fun verifyAssistedFactoryIsSpecified() { + val myViewModel = + Source.java( + "dagger.hilt.android.test.MyViewModel", + """ + package dagger.hilt.android.test; + + import dagger.assisted.Assisted; + import dagger.assisted.AssistedInject; + import androidx.lifecycle.ViewModel; + import dagger.hilt.android.lifecycle.HiltViewModel; + + @HiltViewModel + class MyViewModel extends ViewModel { + @AssistedInject + MyViewModel(String s, @Assisted int i) { } + } + """ + .trimIndent() + ) + val myFactory = + Source.java( + "dagger.hilt.android.test.MyFactory", + """ + package dagger.hilt.android.test; + import dagger.assisted.AssistedFactory; + @AssistedFactory + interface MyFactory { + MyViewModel create(int i); + } + """ + ) + + HiltCompilerTests.hiltCompiler(myViewModel, myFactory) + .withAdditionalJavacProcessors(ViewModelProcessor()) + .withAdditionalKspProcessors(KspViewModelProcessor.Provider()) + .withProcessorOptions(ImmutableMap.of("dagger.hilt.enableAssistedInjectViewModels", "true")) + .compile { subject -> + subject.compilationDidFail() + subject.hasErrorCount(1) + subject.hasErrorContaining( + "dagger.hilt.android.test.MyViewModel must have a valid assisted factory specified in @HiltViewModel when used with assisted injection. Found java.lang.Object." + ) + } + } + + @Test + fun verifyConstructorHasRightInjectAnnotation() { + val myViewModel = + Source.java( + "dagger.hilt.android.test.MyViewModel", + """ + package dagger.hilt.android.test; + + import dagger.assisted.Assisted; + import dagger.assisted.AssistedInject; + import androidx.lifecycle.ViewModel; + import dagger.hilt.android.lifecycle.HiltViewModel; + import javax.inject.Inject; + + @HiltViewModel(assistedFactory = MyFactory.class) + class MyViewModel extends ViewModel { + @Inject + MyViewModel(String s, int i) { } + } + """ + .trimIndent() + ) + val myFactory = + Source.java( + "dagger.hilt.android.test.MyFactory", + """ + package dagger.hilt.android.test; + import dagger.assisted.AssistedFactory; + @AssistedFactory + interface MyFactory { + MyViewModel create(int i); + } + """ + ) + + HiltCompilerTests.hiltCompiler(myViewModel, myFactory) + .withAdditionalJavacProcessors(ViewModelProcessor()) + .withAdditionalKspProcessors(KspViewModelProcessor.Provider()) + .withProcessorOptions(ImmutableMap.of("dagger.hilt.enableAssistedInjectViewModels", "true")) + .compile { subject -> + subject.compilationDidFail() + subject.hasErrorCount(2) + subject.hasErrorContaining( + "Invalid return type: dagger.hilt.android.test.MyViewModel. An assisted factory's abstract method must return a type with an @AssistedInject-annotated constructor." + ) + subject.hasErrorContaining( + "Found assisted factory dagger.hilt.android.test.MyFactory in @HiltViewModel but the constructor was annotated with @Inject instead of @AssistedInject." + ) + } + } }