Skip to content

Commit

Permalink
Allow known types to be specified for ProtoReflection. (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
SanjayVas authored Aug 3, 2023
1 parent 8f5eefc commit d6e938e
Showing 1 changed file with 42 additions and 23 deletions.
65 changes: 42 additions & 23 deletions src/main/kotlin/org/wfanet/measurement/common/ProtoReflection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,24 @@ import com.google.protobuf.fileDescriptorSet
/** Utility object for protobuf reflection. */
object ProtoReflection {
/**
* Map of file name to [Descriptors.FileDescriptor] for
* [Descriptors.FileDescriptor]s of
* [well-known types](https://developers.google.com/protocol-buffers/docs/reference/google.protobuf)
* .
*/
val WELL_KNOWN_TYPES: Map<String, Descriptors.FileDescriptor> =
val WELL_KNOWN_TYPES: List<Descriptors.FileDescriptor> =
listOf(
TypeProto.getDescriptor(),
DescriptorProtos.getDescriptor(),
WrappersProto.getDescriptor(),
AnyProto.getDescriptor(),
ApiProto.getDescriptor(),
DurationProto.getDescriptor(),
EmptyProto.getDescriptor(),
StructProto.getDescriptor(),
TimestampProto.getDescriptor(),
)
.associateBy { it.name }
TypeProto.getDescriptor(),
DescriptorProtos.getDescriptor(),
WrappersProto.getDescriptor(),
AnyProto.getDescriptor(),
ApiProto.getDescriptor(),
DurationProto.getDescriptor(),
EmptyProto.getDescriptor(),
StructProto.getDescriptor(),
TimestampProto.getDescriptor(),
)

private val WELL_KNOWN_TYPES_BY_NAME = WELL_KNOWN_TYPES.associateBy { it.name }

/**
* Builds a [DescriptorProtos.FileDescriptorSet] from [descriptor], including direct and
Expand All @@ -56,16 +57,18 @@ object ProtoReflection {
* [Descriptors.FileDescriptor]s of [WELL_KNOWN_TYPES] are excluded from the output.
*/
fun buildFileDescriptorSet(
descriptor: Descriptors.Descriptor
descriptor: Descriptors.Descriptor,
knownTypes: Iterable<Descriptors.FileDescriptor> = WELL_KNOWN_TYPES
): DescriptorProtos.FileDescriptorSet {
val fileDescriptors = mutableSetOf<Descriptors.FileDescriptor>()
val rootFileDescriptor: Descriptors.FileDescriptor = descriptor.file
fileDescriptors.addDeps(rootFileDescriptor)
fileDescriptors.add(rootFileDescriptor)

val knownTypesByName: Map<String, Descriptors.FileDescriptor> = knownTypes.byName()
return fileDescriptorSet {
for (fileDescriptor in fileDescriptors) {
if (WELL_KNOWN_TYPES.containsKey(fileDescriptor.name)) {
if (knownTypesByName.containsKey(fileDescriptor.name)) {
continue
}
this.file += fileDescriptor.toProto()
Expand Down Expand Up @@ -94,16 +97,22 @@ object ProtoReflection {

/** Builds [Descriptors.Descriptor]s from [fileDescriptorSets]. */
fun buildDescriptors(
fileDescriptorSets: Iterable<DescriptorProtos.FileDescriptorSet>
fileDescriptorSets: Iterable<DescriptorProtos.FileDescriptorSet>,
knownTypes: Iterable<Descriptors.FileDescriptor> = WELL_KNOWN_TYPES
): List<Descriptors.Descriptor> {
val fileDescriptors =
FileDescriptorMapBuilder(fileDescriptorSets.flatMap { it.fileList }.associateBy { it.name })
val knownTypesByName: Map<String, Descriptors.FileDescriptor> = knownTypes.byName()
val fileDescriptorsByName: Map<String, Descriptors.FileDescriptor> =
FileDescriptorMapBuilder(
fileDescriptorSets.flatMap { it.fileList }.associateBy { it.name },
knownTypesByName
)
.build()
return fileDescriptors.values.flatMap { it.messageTypes }
return fileDescriptorsByName.values.flatMap { it.messageTypes }
}

private class FileDescriptorMapBuilder(
private val fileDescriptorProtos: Map<String, DescriptorProtos.FileDescriptorProto>
private val fileDescriptorProtos: Map<String, DescriptorProtos.FileDescriptorProto>,
private val knownTypesByName: Map<String, Descriptors.FileDescriptor>
) {
/** Builds a [Map] of file name to [Descriptors.FileDescriptor]. */
fun build(): Map<String, Descriptors.FileDescriptor> {
Expand Down Expand Up @@ -140,9 +149,9 @@ object ProtoReflection {
if (containsKey(depName)) {
continue
}
val wellKnownType = WELL_KNOWN_TYPES[depName]
if (wellKnownType != null) {
put(depName, wellKnownType)
val knownType: Descriptors.FileDescriptor? = knownTypesByName[depName]
if (knownType != null) {
put(depName, knownType)
continue
}

Expand All @@ -158,4 +167,14 @@ object ProtoReflection {
}
}
}

private fun Iterable<Descriptors.FileDescriptor>.byName():
Map<String, Descriptors.FileDescriptor> {

if (this === WELL_KNOWN_TYPES) {
// Optimize for common case.
return WELL_KNOWN_TYPES_BY_NAME
}
return associateBy { it.name }
}
}

0 comments on commit d6e938e

Please sign in to comment.