Skip to content

Commit

Permalink
[Refactor] Consolidate logic for wrapping/unwrapping map values in fr…
Browse files Browse the repository at this point in the history
…amework types.

This CL consolidates the logic from a few different classes into a single class.

This CL also simplifies some sorting logic in `MapMultibindingValidator` by relying on the natural sorting order of `RequestKind`.

RELNOTES=N/A
PiperOrigin-RevId: 679237839
  • Loading branch information
bcorso authored and Dagger Team committed Sep 26, 2024
1 parent 2097a04 commit ce0dfe6
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 145 deletions.
46 changes: 28 additions & 18 deletions java/dagger/internal/codegen/base/MapType.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,22 @@

import androidx.room.compiler.processing.XType;
import com.google.auto.value.AutoValue;
import com.google.common.collect.ImmutableSet;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.TypeName;
import dagger.internal.codegen.javapoet.TypeNames;
import dagger.internal.codegen.model.Key;
import dagger.internal.codegen.model.RequestKind;
import dagger.internal.codegen.xprocessing.XTypes;

/** Information about a {@link java.util.Map} type. */
@AutoValue
public abstract class MapType {
// TODO(b/28555349): support PROVIDER_OF_LAZY here too
/** The valid framework request kinds allowed on a multibinding map value. */
public static final ImmutableSet<RequestKind> VALID_FRAMEWORK_REQUEST_KINDS =
ImmutableSet.of(RequestKind.PROVIDER, RequestKind.PRODUCER, RequestKind.PRODUCED);

private XType type;

/** The map type itself. */
Expand Down Expand Up @@ -74,35 +81,38 @@ public boolean valuesAreTypeOf(ClassName className) {

/** Returns {@code true} if the raw type of {@link #valueType()} is a framework type. */
public boolean valuesAreFrameworkType() {
return FrameworkTypes.isFrameworkType(valueType());
return valueRequestKind() != RequestKind.INSTANCE;
}

/**
* {@code V} if {@link #valueType()} is a framework type like {@code Provider<V>} or {@code
* Producer<V>}.
* Returns the map's {@link #valueType()} without any wrapping framework type, if one exists.
*
* <p>In particular, this method returns {@code V} for all of the following map types:
* {@code Map<K,V>}, {@code Map<K,Provider<V>>}, {@code Map<K,Producer<V>>}, and
* {@code Map<K,Produced<V>>}.
*
* @throws IllegalStateException if {@link #isRawType()} is true or {@link #valueType()} is not a
* framework type
* <p>Note that we don't consider {@code Lazy} a framework type for this particular case, so this
* method will return {@code Lazy<V>} for {@code Map<K,Lazy<V>>}.
*
* @throws IllegalStateException if {@link #isRawType()} is true.
*/
public XType unwrappedFrameworkValueType() {
checkState(valuesAreFrameworkType(), "called unwrappedFrameworkValueType() on %s", type());
return uncheckedUnwrappedValueType();
return valuesAreFrameworkType() ? unwrapType(valueType()) : valueType();
}

/**
* {@code V} if {@link #valueType()} is a {@code WrappingClass<V>}.
* Returns the {@link RequestKind} of the {@link #valueType()}.
*
* @throws IllegalStateException if {@link #isRawType()} is true or {@link #valueType()} is not a
* {@code WrappingClass<V>}
* @throws IllegalArgumentException if {@link #isRawType()} is true.
*/
// TODO(b/202033221): Consider using stricter input type, e.g. FrameworkType.
public XType unwrappedValueType(ClassName wrappingClass) {
checkState(valuesAreTypeOf(wrappingClass), "expected values to be %s: %s", wrappingClass, this);
return uncheckedUnwrappedValueType();
}

private XType uncheckedUnwrappedValueType() {
return unwrapType(valueType());
public RequestKind valueRequestKind() {
checkArgument(!isRawType());
for (RequestKind frameworkRequestKind : VALID_FRAMEWORK_REQUEST_KINDS) {
if (valuesAreTypeOf(RequestKinds.frameworkClassName(frameworkRequestKind))) {
return frameworkRequestKind;
}
}
return RequestKind.INSTANCE;
}

/** {@code true} if {@code type} is a {@link java.util.Map} type. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterables.getOnlyElement;
import static dagger.internal.codegen.base.RequestKinds.extractKeyType;
import static dagger.internal.codegen.base.RequestKinds.frameworkClassName;
import static dagger.internal.codegen.base.RequestKinds.getRequestKind;
import static dagger.internal.codegen.binding.AssistedInjectionAnnotations.isAssistedParameter;
import static dagger.internal.codegen.model.RequestKind.FUTURE;
import static dagger.internal.codegen.model.RequestKind.INSTANCE;
import static dagger.internal.codegen.model.RequestKind.MEMBERS_INJECTION;
import static dagger.internal.codegen.model.RequestKind.PRODUCER;
import static dagger.internal.codegen.model.RequestKind.PROVIDER;
import static dagger.internal.codegen.xprocessing.XTypes.isTypeOf;
import static dagger.internal.codegen.xprocessing.XTypes.unwrapType;
Expand Down Expand Up @@ -104,21 +102,11 @@ private DependencyRequest forMultibindingContribution(
.build();
}

// TODO(b/28555349): support PROVIDER_OF_LAZY here too
private static final ImmutableSet<RequestKind> WRAPPING_MAP_VALUE_FRAMEWORK_TYPES =
ImmutableSet.of(PROVIDER, PRODUCER);

private RequestKind multibindingContributionRequestKind(
Key multibindingKey, ContributionBinding multibindingContribution) {
switch (multibindingContribution.contributionType()) {
case MAP:
MapType mapType = MapType.from(multibindingKey);
for (RequestKind kind : WRAPPING_MAP_VALUE_FRAMEWORK_TYPES) {
if (mapType.valuesAreTypeOf(frameworkClassName(kind))) {
return kind;
}
}
// fall through
return MapType.from(multibindingKey).valueRequestKind();
case SET:
case SET_VALUES:
return INSTANCE;
Expand Down
88 changes: 24 additions & 64 deletions java/dagger/internal/codegen/binding/KeyFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import static dagger.internal.codegen.javapoet.TypeNames.isFutureType;
import static dagger.internal.codegen.xprocessing.XTypes.isDeclared;
import static dagger.internal.codegen.xprocessing.XTypes.unwrapType;
import static java.util.Arrays.asList;

import androidx.room.compiler.processing.XAnnotation;
import androidx.room.compiler.processing.XMethodElement;
Expand All @@ -39,7 +38,6 @@
import dagger.Binds;
import dagger.BindsOptionalOf;
import dagger.internal.codegen.base.ContributionType;
import dagger.internal.codegen.base.FrameworkTypes;
import dagger.internal.codegen.base.MapType;
import dagger.internal.codegen.base.OptionalType;
import dagger.internal.codegen.base.RequestKinds;
Expand Down Expand Up @@ -286,82 +284,44 @@ public Key forProductionComponentMonitor() {
public Key unwrapMapValueType(Key key) {
if (MapType.isMap(key)) {
MapType mapType = MapType.from(key);
if (!mapType.isRawType()) {
for (ClassName frameworkClass :
asList(TypeNames.PROVIDER, TypeNames.PRODUCER, TypeNames.PRODUCED)) {
if (mapType.valuesAreTypeOf(frameworkClass)) {
return key.withType(
DaggerType.from(
mapOf(mapType.keyType(), mapType.unwrappedValueType(frameworkClass))));
}
}
if (!mapType.isRawType() && mapType.valuesAreFrameworkType()) {
return key.withType(
DaggerType.from(mapOf(mapType.keyType(), mapType.unwrappedFrameworkValueType())));
}
}
return key;
}

/** Converts a {@link Key} of type {@code Map<K, V>} to {@code Map<K, Provider<V>>}. */
private Key wrapMapValue(Key key, ClassName newWrappingClassName) {
checkArgument(FrameworkTypes.isFrameworkType(processingEnv.requireType(newWrappingClassName)));
return wrapMapKey(key, newWrappingClassName).get();
}

/**
* If {@code key}'s type is {@code Map<K, CurrentWrappingClass<Bar>>}, returns a key with type
* {@code Map<K, NewWrappingClass<Bar>>} with the same qualifier. Otherwise returns {@link
* Optional#empty()}.
*
* <p>Returns {@link Optional#empty()} if {@code newWrappingClass} is not in the classpath.
*
* @throws IllegalArgumentException if {@code newWrappingClass} is the same as {@code
* currentWrappingClass}
*/
public Optional<Key> rewrapMapKey(
Key possibleMapKey, ClassName currentWrappingClassName, ClassName newWrappingClassName) {
checkArgument(!currentWrappingClassName.equals(newWrappingClassName));
if (MapType.isMap(possibleMapKey)) {
MapType mapType = MapType.from(possibleMapKey);
if (!mapType.isRawType() && mapType.valuesAreTypeOf(currentWrappingClassName)) {
XTypeElement wrappingElement = processingEnv.findTypeElement(newWrappingClassName);
if (wrappingElement == null) {
// This target might not be compiled with Producers, so wrappingClass might not have an
// associated element.
return Optional.empty();
}
XType wrappedValueType =
processingEnv.getDeclaredType(
wrappingElement, mapType.unwrappedValueType(currentWrappingClassName));
return Optional.of(
possibleMapKey.withType(DaggerType.from(mapOf(mapType.keyType(), wrappedValueType))));
}
}
return Optional.empty();
}

/**
* If {@code key}'s type is {@code Map<K, Foo>} and {@code Foo} is not {@code WrappingClass
* <Bar>}, returns a key with type {@code Map<K, WrappingClass<Foo>>} with the same qualifier.
* Otherwise returns {@link Optional#empty()}.
* Returns a key with the type {@code Map<K, FrameworkType<V>>} if the given key has a type of
* {@code Map<K, V>}. Otherwise, returns the unaltered key.
*
* <p>Returns {@link Optional#empty()} if {@code WrappingClass} is not in the classpath.
* @throws IllegalArgumentException if the {@code frameworkClassName} is not a valid framework
* type for multibinding maps.
* @throws IllegalStateException if the {@code key} is already wrapped in a (different) framework
* type.
*/
private Optional<Key> wrapMapKey(Key possibleMapKey, ClassName wrappingClassName) {
if (MapType.isMap(possibleMapKey)) {
MapType mapType = MapType.from(possibleMapKey);
if (!mapType.isRawType() && !mapType.valuesAreTypeOf(wrappingClassName)) {
XTypeElement wrappingElement = processingEnv.findTypeElement(wrappingClassName);
if (wrappingElement == null) {
private Key wrapMapValue(Key key, ClassName frameworkClassName) {
checkArgument(
MapType.VALID_FRAMEWORK_REQUEST_KINDS.stream()
.map(RequestKinds::frameworkClassName)
.anyMatch(frameworkClassName::equals));
if (MapType.isMap(key)) {
MapType mapType = MapType.from(key);
if (!mapType.isRawType() && !mapType.valuesAreTypeOf(frameworkClassName)) {
checkState(!mapType.valuesAreFrameworkType());
XTypeElement frameworkTypeElement = processingEnv.findTypeElement(frameworkClassName);
if (frameworkTypeElement == null) {
// This target might not be compiled with Producers, so wrappingClass might not have an
// associated element.
return Optional.empty();
return key;
}
XType wrappedValueType =
processingEnv.getDeclaredType(wrappingElement, mapType.valueType());
return Optional.of(
possibleMapKey.withType(DaggerType.from(mapOf(mapType.keyType(), wrappedValueType))));
processingEnv.getDeclaredType(frameworkTypeElement, mapType.valueType());
return key.withType(DaggerType.from(mapOf(mapType.keyType(), wrappedValueType)));
}
}
return Optional.empty();
return key;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
package dagger.internal.codegen.bindinggraphvalidation;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.Multimaps.filterKeys;
import static dagger.internal.codegen.base.Formatter.INDENT;
import static dagger.internal.codegen.extension.DaggerStreams.toImmutableSet;
import static dagger.internal.codegen.extension.DaggerStreams.toImmutableSetMultimap;
import static dagger.internal.codegen.model.BindingKind.MULTIBOUND_MAP;
import static dagger.internal.codegen.xprocessing.XAnnotations.getClassName;
import static javax.tools.Diagnostic.Kind.ERROR;
Expand All @@ -29,21 +27,21 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.SetMultimap;
import com.squareup.javapoet.ClassName;
import dagger.internal.codegen.base.MapType;
import dagger.internal.codegen.binding.BindingNode;
import dagger.internal.codegen.binding.ContributionBinding;
import dagger.internal.codegen.binding.Declaration;
import dagger.internal.codegen.binding.DeclarationFormatter;
import dagger.internal.codegen.binding.KeyFactory;
import dagger.internal.codegen.javapoet.TypeNames;
import dagger.internal.codegen.model.Binding;
import dagger.internal.codegen.model.BindingGraph;
import dagger.internal.codegen.model.DiagnosticReporter;
import dagger.internal.codegen.model.Key;
import dagger.internal.codegen.validation.ValidationBindingGraphPlugin;
import dagger.internal.codegen.xprocessing.XAnnotations;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Set;
import javax.inject.Inject;

Expand Down Expand Up @@ -92,41 +90,19 @@ public void visitGraph(BindingGraph bindingGraph, DiagnosticReporter diagnosticR
* </ol>
*/
private ImmutableSet<Binding> mapMultibindings(BindingGraph bindingGraph) {
ImmutableSetMultimap<Key, Binding> mapMultibindings =
bindingGraph.bindings().stream()
.filter(node -> node.kind().equals(MULTIBOUND_MAP))
.collect(toImmutableSetMultimap(Binding::key, node -> node));

// Mutlbindings for Map<K, V>
SetMultimap<Key, Binding> plainValueMapMultibindings =
filterKeys(mapMultibindings, key -> !MapType.from(key).valuesAreFrameworkType());

// Multibindings for Map<K, Provider<V>> where Map<K, V> isn't in plainValueMapMultibindings
SetMultimap<Key, Binding> providerValueMapMultibindings =
filterKeys(
mapMultibindings,
key ->
MapType.from(key).valuesAreTypeOf(TypeNames.PROVIDER)
&& !plainValueMapMultibindings.containsKey(keyFactory.unwrapMapValueType(key)));

// Multibindings for Map<K, Producer<V>> where Map<K, V> isn't in plainValueMapMultibindings and
// Map<K, Provider<V>> isn't in providerValueMapMultibindings
SetMultimap<Key, Binding> producerValueMapMultibindings =
filterKeys(
mapMultibindings,
key ->
MapType.from(key).valuesAreTypeOf(TypeNames.PRODUCER)
&& !plainValueMapMultibindings.containsKey(keyFactory.unwrapMapValueType(key))
&& !providerValueMapMultibindings.containsKey(
keyFactory
.rewrapMapKey(key, TypeNames.PRODUCER, TypeNames.PROVIDER)
.get()));

return new ImmutableSet.Builder<Binding>()
.addAll(plainValueMapMultibindings.values())
.addAll(providerValueMapMultibindings.values())
.addAll(producerValueMapMultibindings.values())
.build();
Set<Key> visitedKeys = new HashSet<>();
return bindingGraph.bindings().stream()
.filter(binding -> binding.kind().equals(MULTIBOUND_MAP))
// Sort by the order of the value in the RequestKind:
// (Map<K, V>, then Map<K, Provider<V>>, then Map<K, Producer<V>>).
.sorted(Comparator.comparing(binding -> MapType.from(binding.key()).valueRequestKind()))
// Only take the first binding (post sorting) per unwrapped key.
.filter(binding -> visitedKeys.add(unwrappedKey(binding)))
.collect(toImmutableSet());
}

private Key unwrappedKey(Binding binding) {
return keyFactory.unwrapMapValueType(binding.key());
}

private ImmutableSet<ContributionBinding> mapBindingContributions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import static com.google.common.base.Preconditions.checkNotNull;
import static dagger.internal.codegen.binding.MapKeys.getMapKeyExpression;
import static dagger.internal.codegen.binding.SourceFiles.mapFactoryClassName;
import static dagger.internal.codegen.extension.DaggerCollectors.toOptional;

import androidx.room.compiler.processing.XProcessingEnv;
import com.squareup.javapoet.ClassName;
Expand All @@ -35,7 +34,6 @@
import dagger.internal.codegen.binding.MultiboundMapBinding;
import dagger.internal.codegen.javapoet.TypeNames;
import dagger.internal.codegen.model.DependencyRequest;
import java.util.stream.Stream;

/** A factory creation expression for a multibound map. */
final class MapFactoryCreationExpression extends MultibindingFactoryCreationExpression {
Expand Down Expand Up @@ -71,15 +69,7 @@ public CodeBlock creationExpression() {
TypeName valueTypeName = TypeName.OBJECT;
if (!useRawType()) {
MapType mapType = MapType.from(binding.key());
// TODO(ronshapiro): either inline this into mapFactoryClassName, or add a
// mapType.unwrappedValueType() method that doesn't require a framework type
valueTypeName =
Stream.of(TypeNames.PROVIDER, TypeNames.PRODUCER, TypeNames.PRODUCED)
.filter(mapType::valuesAreTypeOf)
.map(mapType::unwrappedValueType)
.collect(toOptional())
.orElseGet(mapType::valueType)
.getTypeName();
valueTypeName = mapType.unwrappedFrameworkValueType().getTypeName();
builder.add(
"<$T, $T>",
useLazyClassKey ? TypeNames.STRING : mapType.keyType().getTypeName(),
Expand Down

0 comments on commit ce0dfe6

Please sign in to comment.