diff --git a/src/main/kotlin/org/wfanet/measurement/common/ProtoReflection.kt b/src/main/kotlin/org/wfanet/measurement/common/ProtoReflection.kt index bfb7ebe09..30020ce47 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/ProtoReflection.kt +++ b/src/main/kotlin/org/wfanet/measurement/common/ProtoReflection.kt @@ -31,23 +31,24 @@ import com.google.protobuf.fileDescriptorSet /** Utility object for protobuf reflection. */ object ProtoReflection { /** - * Map of file name to [Descriptors.FileDescriptor] for + * [Descriptors.FileDescriptor]s of * [well-known types](https://developers.google.com/protocol-buffers/docs/reference/google.protobuf) * . */ - val WELL_KNOWN_TYPES: Map = + val WELL_KNOWN_TYPES: List = listOf( - TypeProto.getDescriptor(), - DescriptorProtos.getDescriptor(), - WrappersProto.getDescriptor(), - AnyProto.getDescriptor(), - ApiProto.getDescriptor(), - DurationProto.getDescriptor(), - EmptyProto.getDescriptor(), - StructProto.getDescriptor(), - TimestampProto.getDescriptor(), - ) - .associateBy { it.name } + TypeProto.getDescriptor(), + DescriptorProtos.getDescriptor(), + WrappersProto.getDescriptor(), + AnyProto.getDescriptor(), + ApiProto.getDescriptor(), + DurationProto.getDescriptor(), + EmptyProto.getDescriptor(), + StructProto.getDescriptor(), + TimestampProto.getDescriptor(), + ) + + private val WELL_KNOWN_TYPES_BY_NAME = WELL_KNOWN_TYPES.associateBy { it.name } /** * Builds a [DescriptorProtos.FileDescriptorSet] from [descriptor], including direct and @@ -56,16 +57,18 @@ object ProtoReflection { * [Descriptors.FileDescriptor]s of [WELL_KNOWN_TYPES] are excluded from the output. */ fun buildFileDescriptorSet( - descriptor: Descriptors.Descriptor + descriptor: Descriptors.Descriptor, + knownTypes: Iterable = WELL_KNOWN_TYPES ): DescriptorProtos.FileDescriptorSet { val fileDescriptors = mutableSetOf() val rootFileDescriptor: Descriptors.FileDescriptor = descriptor.file fileDescriptors.addDeps(rootFileDescriptor) fileDescriptors.add(rootFileDescriptor) + val knownTypesByName: Map = knownTypes.byName() return fileDescriptorSet { for (fileDescriptor in fileDescriptors) { - if (WELL_KNOWN_TYPES.containsKey(fileDescriptor.name)) { + if (knownTypesByName.containsKey(fileDescriptor.name)) { continue } this.file += fileDescriptor.toProto() @@ -94,16 +97,22 @@ object ProtoReflection { /** Builds [Descriptors.Descriptor]s from [fileDescriptorSets]. */ fun buildDescriptors( - fileDescriptorSets: Iterable + fileDescriptorSets: Iterable, + knownTypes: Iterable = WELL_KNOWN_TYPES ): List { - val fileDescriptors = - FileDescriptorMapBuilder(fileDescriptorSets.flatMap { it.fileList }.associateBy { it.name }) + val knownTypesByName: Map = knownTypes.byName() + val fileDescriptorsByName: Map = + FileDescriptorMapBuilder( + fileDescriptorSets.flatMap { it.fileList }.associateBy { it.name }, + knownTypesByName + ) .build() - return fileDescriptors.values.flatMap { it.messageTypes } + return fileDescriptorsByName.values.flatMap { it.messageTypes } } private class FileDescriptorMapBuilder( - private val fileDescriptorProtos: Map + private val fileDescriptorProtos: Map, + private val knownTypesByName: Map ) { /** Builds a [Map] of file name to [Descriptors.FileDescriptor]. */ fun build(): Map { @@ -140,9 +149,9 @@ object ProtoReflection { if (containsKey(depName)) { continue } - val wellKnownType = WELL_KNOWN_TYPES[depName] - if (wellKnownType != null) { - put(depName, wellKnownType) + val knownType: Descriptors.FileDescriptor? = knownTypesByName[depName] + if (knownType != null) { + put(depName, knownType) continue } @@ -158,4 +167,14 @@ object ProtoReflection { } } } + + private fun Iterable.byName(): + Map { + + if (this === WELL_KNOWN_TYPES) { + // Optimize for common case. + return WELL_KNOWN_TYPES_BY_NAME + } + return associateBy { it.name } + } }