Skip to content

Commit

Permalink
Fix a corner case where oneof customizations are ignored and simplify…
Browse files Browse the repository at this point in the history
… the interfaces.

In particular, after this change, when a oneof field nullness is customized, we skip recursion analysis for protos (similar to other individual field customizations).

PiperOrigin-RevId: 713280456
  • Loading branch information
hadi88 authored and copybara-github committed Jan 13, 2025
1 parent c99c121 commit 5d96311
Showing 1 changed file with 26 additions and 19 deletions.
45 changes: 26 additions & 19 deletions fuzztest/internal/domains/protobuf_domain_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ class ProtobufDomainUntypedImpl
corpus_type Init(absl::BitGenRef prng) {
if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
FUZZTEST_INTERNAL_CHECK(
!customized_fields_.empty() || !IsNonTerminatingRecursive(),
!IsCustomizedRecursivelyOnly() || !IsNonTerminatingRecursive(),
"Cannot set recursive fields by default.");
const auto* descriptor = prototype_.Get()->GetDescriptor();
corpus_type val;
Expand All @@ -468,10 +468,10 @@ class ProtobufDomainUntypedImpl
if (!oneof_to_field.contains(oneof->index())) {
oneof_to_field[oneof->index()] = SelectAFieldIndexInOneof(
oneof, prng,
/*non_recursive_only=*/customized_fields_.empty());
/*non_recursive_only=*/IsCustomizedRecursivelyOnly());
}
if (oneof_to_field[oneof->index()] != field->index()) continue;
} else if (!IsRequired(field) && customized_fields_.empty() &&
} else if (!IsRequired(field) && IsCustomizedRecursivelyOnly() &&
IsFieldRecursive(field)) {
// We avoid initializing non-required recursive fields by default (if
// they are not explicitly customized). Otherwise, the initialization
Expand Down Expand Up @@ -1662,37 +1662,36 @@ class ProtobufDomainUntypedImpl

bool IsNonTerminatingRecursive() {
absl::flat_hash_set<decltype(prototype_.Get()->GetDescriptor())> parents;
return IsProtoRecursive(prototype_.Get()->GetDescriptor(), parents, policy_,
return IsProtoRecursive(prototype_.Get()->GetDescriptor(), parents,
/*consider_non_terminating_recursions=*/true);
}

bool IsFieldRecursive(const FieldDescriptor* field) {
if (!field->message_type()) return false;
absl::flat_hash_set<decltype(field->message_type())> parents;
return IsProtoRecursive(field->message_type(), parents, policy_,
return IsProtoRecursive(field->message_type(), parents,
/*consider_non_terminating_recursions=*/false);
}

bool IsOneofRecursive(const OneofDescriptor* oneof,
absl::flat_hash_set<const Descriptor*>& parents,
const ProtoPolicy<Message>& policy,
bool consider_non_terminating_recursions) const {
bool is_oneof_recursive = false;
for (int i = 0; i < oneof->field_count(); ++i) {
const auto* field = oneof->field(i);
const auto field_policy = policy.GetOptionalPolicy(field);
const auto field_policy = policy_.GetOptionalPolicy(field);
if (field_policy == OptionalPolicy::kAlwaysNull) continue;
const auto* child = field->message_type();
if (consider_non_terminating_recursions) {
is_oneof_recursive =
field_policy != OptionalPolicy::kWithNull && child &&
IsProtoRecursive(child, parents, policy,
IsProtoRecursive(child, parents,
consider_non_terminating_recursions);
if (!is_oneof_recursive) {
return false;
}
} else {
if (child && IsProtoRecursive(child, parents, policy,
if (child && IsProtoRecursive(child, parents,
consider_non_terminating_recursions)) {
return true;
}
Expand All @@ -1704,13 +1703,12 @@ class ProtobufDomainUntypedImpl
template <typename Descriptor>
bool IsProtoRecursive(const Descriptor* descriptor,
absl::flat_hash_set<const Descriptor*>& parents,
const ProtoPolicy<Message>& policy,
bool consider_non_terminating_recursions) const {
if (parents.contains(descriptor)) return true;
parents.insert(descriptor);
for (int i = 0; i < descriptor->oneof_decl_count(); ++i) {
const auto* oneof = descriptor->oneof_decl(i);
if (IsOneofRecursive(oneof, parents, policy,
if (IsOneofRecursive(oneof, parents,
consider_non_terminating_recursions)) {
parents.erase(descriptor);
return true;
Expand All @@ -1728,23 +1726,23 @@ class ProtobufDomainUntypedImpl
if (consider_non_terminating_recursions) {
const bool should_be_set =
IsRequired(field) ||
(field->is_optional() &&
policy.GetOptionalPolicy(field) == OptionalPolicy::kWithoutNull) ||
(field->is_optional() && policy_.GetOptionalPolicy(field) ==
OptionalPolicy::kWithoutNull) ||
(field->is_repeated() &&
policy.GetMinRepeatedFieldSize(field).has_value() &&
*policy.GetMinRepeatedFieldSize(field) > 0);
policy_.GetMinRepeatedFieldSize(field).has_value() &&
*policy_.GetMinRepeatedFieldSize(field) > 0);
if (!should_be_set) continue;
} else {
const bool can_be_set =
IsRequired(field) ||
(field->is_optional() &&
policy.GetOptionalPolicy(field) != OptionalPolicy::kAlwaysNull) ||
policy_.GetOptionalPolicy(field) != OptionalPolicy::kAlwaysNull) ||
(field->is_repeated() &&
(!policy.GetMaxRepeatedFieldSize(field).has_value() ||
*policy.GetMaxRepeatedFieldSize(field) > 0));
(!policy_.GetMaxRepeatedFieldSize(field).has_value() ||
*policy_.GetMaxRepeatedFieldSize(field) > 0));
if (!can_be_set) continue;
}
if (IsProtoRecursive(child, parents, policy,
if (IsProtoRecursive(child, parents,
consider_non_terminating_recursions)) {
parents.erase(descriptor);
return true;
Expand All @@ -1767,6 +1765,15 @@ class ProtobufDomainUntypedImpl
field->containing_type()->map_value() == field;
}

// Returns true if all domain customizations are defined recursively (through
// policy) and no individual field is customized at the top level. This check
// would be useful in recursion analysis. In particular, recursion analysis
// is only meaningful when all customizations are also recursive.
bool IsCustomizedRecursivelyOnly() {
return customized_fields_.empty() && always_set_oneofs_.empty() &&
uncustomizable_oneofs_.empty() && unset_oneof_fields_.empty();
}

PrototypePtr<Message> prototype_;
bool use_lazy_initialization_;

Expand Down

0 comments on commit 5d96311

Please sign in to comment.