Skip to content

Commit

Permalink
Add AssistedInject to Hilt ViewModel.
Browse files Browse the repository at this point in the history
RELNOTES=Add support for using `@AssistedInject` with `@HiltViewModel`.
PiperOrigin-RevId: 572216305
  • Loading branch information
kuanyingchou authored and Dagger Team committed Oct 10, 2023
1 parent 774bee1 commit 8327177
Show file tree
Hide file tree
Showing 16 changed files with 1,172 additions and 92 deletions.
1 change: 1 addition & 0 deletions java/dagger/hilt/android/internal/lifecycle/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ android_library(
"@maven//:androidx_lifecycle_lifecycle_viewmodel",
"@maven//:androidx_lifecycle_lifecycle_viewmodel_savedstate",
"@maven//:androidx_savedstate_savedstate",
"@maven//:org_jetbrains_kotlin_kotlin_stdlib",
],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (C) 2023 The Dagger Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package dagger.hilt.android.internal.lifecycle;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import javax.inject.Qualifier;

/**
* Internal qualifier for the multibinding map of assisted factories for @AssistedInject-annotated
* ViewModels used by the {@link dagger.hilt.android.lifecycle.HiltViewModelFactory}.
*/
@Qualifier
@Retention(RetentionPolicy.CLASS)
@Target({ElementType.METHOD, ElementType.PARAMETER})
public @interface HiltViewModelAssistedMap {}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

package dagger.hilt.android.internal.lifecycle;

import static androidx.lifecycle.SavedStateHandleSupport.createSavedStateHandle;

import android.app.Activity;
import android.os.Bundle;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.lifecycle.AbstractSavedStateViewModelFactory;
import androidx.lifecycle.SavedStateHandle;
import androidx.lifecycle.ViewModel;
import androidx.lifecycle.ViewModelProvider;
import androidx.lifecycle.viewmodel.CreationExtras;
Expand All @@ -37,6 +37,7 @@
import java.util.Map;
import java.util.Set;
import javax.inject.Provider;
import kotlin.jvm.functions.Function1;

/**
* View Model Provider Factory for the Hilt Extension.
Expand All @@ -55,20 +56,32 @@ public final class HiltViewModelFactory implements ViewModelProvider.Factory {
public interface ViewModelFactoriesEntryPoint {
@HiltViewModelMap
Map<String, Provider<ViewModel>> getHiltViewModelMap();

// From ViewModel class names to user defined @AssistedFactory-annotated implementations.
@HiltViewModelAssistedMap
Map<String, Object> getHiltViewModelAssistedMap();
}

/** Creation extra key for the callbacks that create @AssistedInject-annotated ViewModels. */
public static final CreationExtras.Key<Function1<Object, ViewModel>> CREATION_CALLBACK_KEY =
new CreationExtras.Key<Function1<Object, ViewModel>>() {};

/** Hilt module for providing the empty multi-binding map of ViewModels. */
@Module
@InstallIn(ViewModelComponent.class)
interface ViewModelModule {
@Multibinds
@HiltViewModelMap
Map<String, ViewModel> hiltViewModelMap();

@Multibinds
@HiltViewModelAssistedMap
Map<String, Object> hiltViewModelAssistedMap();
}

private final Set<String> hiltViewModelKeys;
private final ViewModelProvider.Factory delegateFactory;
private final AbstractSavedStateViewModelFactory hiltViewModelFactory;
private final ViewModelProvider.Factory hiltViewModelFactory;

public HiltViewModelFactory(
@NonNull Set<String> hiltViewModelKeys,
Expand All @@ -77,31 +90,75 @@ public HiltViewModelFactory(
this.hiltViewModelKeys = hiltViewModelKeys;
this.delegateFactory = delegateFactory;
this.hiltViewModelFactory =
new AbstractSavedStateViewModelFactory() {
new ViewModelProvider.Factory() {
@NonNull
@Override
@SuppressWarnings("unchecked")
protected <T extends ViewModel> T create(
@NonNull String key, @NonNull Class<T> modelClass, @NonNull SavedStateHandle handle) {
public <T extends ViewModel> T create(
@NonNull Class<T> modelClass, @NonNull CreationExtras extras) {
RetainedLifecycleImpl lifecycle = new RetainedLifecycleImpl();
ViewModelComponent component = viewModelComponentBuilder
.savedStateHandle(handle)
.viewModelLifecycle(lifecycle)
.build();
ViewModelComponent component =
viewModelComponentBuilder
.savedStateHandle(createSavedStateHandle(extras))
.viewModelLifecycle(lifecycle)
.build();
T viewModel = createViewModel(component, modelClass, extras);
viewModel.addCloseable(lifecycle::dispatchOnCleared);
return viewModel;
}

private <T extends ViewModel> T createViewModel(
@NonNull ViewModelComponent component,
@NonNull Class<T> modelClass,
@NonNull CreationExtras extras) {
Provider<? extends ViewModel> provider =
EntryPoints.get(component, ViewModelFactoriesEntryPoint.class)
.getHiltViewModelMap()
.get(modelClass.getName());
if (provider == null) {
throw new IllegalStateException(
"Expected the @HiltViewModel-annotated class '"
+ modelClass.getName()
+ "' to be available in the multi-binding of "
+ "@HiltViewModelMap but none was found.");
Function1<Object, ViewModel> creationCallback = extras.get(CREATION_CALLBACK_KEY);
Object assistedFactory =
EntryPoints.get(component, ViewModelFactoriesEntryPoint.class)
.getHiltViewModelAssistedMap()
.get(modelClass.getName());

if (assistedFactory == null) {
if (creationCallback == null) {
if (provider == null) {
throw new IllegalStateException(
"Expected the @HiltViewModel-annotated class "
+ modelClass.getName()
+ " to be available in the multi-binding of "
+ "@HiltViewModelMap"
+ " but none was found.");
} else {
return (T) provider.get();
}
} else {
// Provider could be null or non-null.
throw new IllegalStateException(
"Found creation callback but class "
+ modelClass.getName()
+ " does not have an assisted factory specified in @HiltViewModel.");
}
} else {
if (provider == null) {
if (creationCallback == null) {
throw new IllegalStateException(
"Found @HiltViewModel-annotated class "
+ modelClass.getName()
+ " using @AssistedInject but no creation callback"
+ " was provided in CreationExtras.");
} else {
return (T) creationCallback.invoke(assistedFactory);
}
} else {
// Creation callback could be null or non-null.
throw new AssertionError(
"Found the @HiltViewModel-annotated class "
+ modelClass.getName()
+ " in both the multi-bindings of "
+ "@HiltViewModelMap and @HiltViewModelAssistedMap.");
}
}
ViewModel viewModel = provider.get();
viewModel.addCloseable(lifecycle::dispatchOnCleared);
return (T) viewModel;
}
};
}
Expand Down
9 changes: 8 additions & 1 deletion java/dagger/hilt/android/lifecycle/HiltViewModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,11 @@
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.CLASS)
@GeneratesRootInput
public @interface HiltViewModel {}
public @interface HiltViewModel {
/**
* Returns a factory class that can be used to create this ViewModel with assisted injection. The
* default value `Object.class` denotes that no factory is specified and the ViewModel is not
* assisted injected.
*/
Class<?> assistedFactory() default Object.class;
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ public final class AndroidClassNames {
get("dagger.hilt.android.lifecycle", "HiltViewModel");
public static final ClassName HILT_VIEW_MODEL_MAP_QUALIFIER =
get("dagger.hilt.android.internal.lifecycle", "HiltViewModelMap");

public static final ClassName HILT_VIEW_MODEL_ASSISTED_FACTORY_MAP_QUALIFIER =
get("dagger.hilt.android.internal.lifecycle", "HiltViewModelAssistedMap");

public static final ClassName HILT_VIEW_MODEL_KEYS_QUALIFIER =
get("dagger.hilt.android.internal.lifecycle", "HiltViewModelMap", "KeySet");
public static final ClassName VIEW_MODEL = get("androidx.lifecycle", "ViewModel");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ kt_jvm_library(
"//java/dagger/hilt/android/processor/internal:android_classnames",
"//java/dagger/hilt/processor/internal:base_processor",
"//java/dagger/hilt/processor/internal:classnames",
"//java/dagger/hilt/processor/internal:compiler_options",
"//java/dagger/hilt/processor/internal:processor_errors",
"//java/dagger/hilt/processor/internal:processors",
"//java/dagger/internal/codegen/xprocessing",
Expand Down
Loading

0 comments on commit 8327177

Please sign in to comment.