diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h b/fuzztest/internal/domains/protobuf_domain_impl.h index 279bf2dd..621b11db 100644 --- a/fuzztest/internal/domains/protobuf_domain_impl.h +++ b/fuzztest/internal/domains/protobuf_domain_impl.h @@ -456,7 +456,7 @@ class ProtobufDomainUntypedImpl corpus_type Init(absl::BitGenRef prng) { if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed; FUZZTEST_INTERNAL_CHECK( - !customized_fields_.empty() || !IsNonTerminatingRecursive(), + !IsCustomizedRecursivelyOnly() || !IsNonTerminatingRecursive(), "Cannot set recursive fields by default."); const auto* descriptor = prototype_.Get()->GetDescriptor(); corpus_type val; @@ -468,10 +468,10 @@ class ProtobufDomainUntypedImpl if (!oneof_to_field.contains(oneof->index())) { oneof_to_field[oneof->index()] = SelectAFieldIndexInOneof( oneof, prng, - /*non_recursive_only=*/customized_fields_.empty()); + /*non_recursive_only=*/IsCustomizedRecursivelyOnly()); } if (oneof_to_field[oneof->index()] != field->index()) continue; - } else if (!IsRequired(field) && customized_fields_.empty() && + } else if (!IsRequired(field) && IsCustomizedRecursivelyOnly() && IsFieldRecursive(field)) { // We avoid initializing non-required recursive fields by default (if // they are not explicitly customized). Otherwise, the initialization @@ -1662,37 +1662,36 @@ class ProtobufDomainUntypedImpl bool IsNonTerminatingRecursive() { absl::flat_hash_setGetDescriptor())> parents; - return IsProtoRecursive(prototype_.Get()->GetDescriptor(), parents, policy_, + return IsProtoRecursive(prototype_.Get()->GetDescriptor(), parents, /*consider_non_terminating_recursions=*/true); } bool IsFieldRecursive(const FieldDescriptor* field) { if (!field->message_type()) return false; absl::flat_hash_setmessage_type())> parents; - return IsProtoRecursive(field->message_type(), parents, policy_, + return IsProtoRecursive(field->message_type(), parents, /*consider_non_terminating_recursions=*/false); } bool IsOneofRecursive(const OneofDescriptor* oneof, absl::flat_hash_set& parents, - const ProtoPolicy& policy, bool consider_non_terminating_recursions) const { bool is_oneof_recursive = false; for (int i = 0; i < oneof->field_count(); ++i) { const auto* field = oneof->field(i); - const auto field_policy = policy.GetOptionalPolicy(field); + const auto field_policy = policy_.GetOptionalPolicy(field); if (field_policy == OptionalPolicy::kAlwaysNull) continue; const auto* child = field->message_type(); if (consider_non_terminating_recursions) { is_oneof_recursive = field_policy != OptionalPolicy::kWithNull && child && - IsProtoRecursive(child, parents, policy, + IsProtoRecursive(child, parents, consider_non_terminating_recursions); if (!is_oneof_recursive) { return false; } } else { - if (child && IsProtoRecursive(child, parents, policy, + if (child && IsProtoRecursive(child, parents, consider_non_terminating_recursions)) { return true; } @@ -1704,13 +1703,12 @@ class ProtobufDomainUntypedImpl template bool IsProtoRecursive(const Descriptor* descriptor, absl::flat_hash_set& parents, - const ProtoPolicy& policy, bool consider_non_terminating_recursions) const { if (parents.contains(descriptor)) return true; parents.insert(descriptor); for (int i = 0; i < descriptor->oneof_decl_count(); ++i) { const auto* oneof = descriptor->oneof_decl(i); - if (IsOneofRecursive(oneof, parents, policy, + if (IsOneofRecursive(oneof, parents, consider_non_terminating_recursions)) { parents.erase(descriptor); return true; @@ -1728,23 +1726,23 @@ class ProtobufDomainUntypedImpl if (consider_non_terminating_recursions) { const bool should_be_set = IsRequired(field) || - (field->is_optional() && - policy.GetOptionalPolicy(field) == OptionalPolicy::kWithoutNull) || + (field->is_optional() && policy_.GetOptionalPolicy(field) == + OptionalPolicy::kWithoutNull) || (field->is_repeated() && - policy.GetMinRepeatedFieldSize(field).has_value() && - *policy.GetMinRepeatedFieldSize(field) > 0); + policy_.GetMinRepeatedFieldSize(field).has_value() && + *policy_.GetMinRepeatedFieldSize(field) > 0); if (!should_be_set) continue; } else { const bool can_be_set = IsRequired(field) || (field->is_optional() && - policy.GetOptionalPolicy(field) != OptionalPolicy::kAlwaysNull) || + policy_.GetOptionalPolicy(field) != OptionalPolicy::kAlwaysNull) || (field->is_repeated() && - (!policy.GetMaxRepeatedFieldSize(field).has_value() || - *policy.GetMaxRepeatedFieldSize(field) > 0)); + (!policy_.GetMaxRepeatedFieldSize(field).has_value() || + *policy_.GetMaxRepeatedFieldSize(field) > 0)); if (!can_be_set) continue; } - if (IsProtoRecursive(child, parents, policy, + if (IsProtoRecursive(child, parents, consider_non_terminating_recursions)) { parents.erase(descriptor); return true; @@ -1767,6 +1765,15 @@ class ProtobufDomainUntypedImpl field->containing_type()->map_value() == field; } + // Returns true if all domain customizations are defined recursively (through + // policy) and no individual field is customized at the top level. This check + // would be useful in recursion analysis. In particular, recursion analysis + // is only meaningful when all customizations are also recursive. + bool IsCustomizedRecursivelyOnly() { + return customized_fields_.empty() && always_set_oneofs_.empty() && + uncustomizable_oneofs_.empty() && unset_oneof_fields_.empty(); + } + PrototypePtr prototype_; bool use_lazy_initialization_;