Skip to content

Commit

Permalink
Factor out OneOfSchema from OneOfStringSchema and OneOfIntSchema (#131)
Browse files Browse the repository at this point in the history
* factor out OneOfSchema from OneOfStringSchema and OneOfIntSchema

* move some assignments post_init for OneOf*Schema

* change OneOfIntSchema's jsonschema and openapi fragment 'const' type from 'string' to 'integer'
  • Loading branch information
mfleader authored May 2, 2024
1 parent b6be00d commit 5ad27d8
Show file tree
Hide file tree
Showing 2 changed files with 338 additions and 207 deletions.
335 changes: 128 additions & 207 deletions src/arcaflow_plugin_sdk/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2578,7 +2578,114 @@ def _to_openapi_fragment(


@dataclass
class OneOfStringSchema(_JSONSchemaGenerator, _OpenAPIGenerator):
class OneOfSchema(_JSONSchemaGenerator, _OpenAPIGenerator):
types: typing.Union[
Dict[str, typing.Annotated[_OBJECT_LIKE, discriminator("type_id")]],
Dict[int, typing.Annotated[_OBJECT_LIKE, discriminator("type_id")]],
]
discriminator_inlined: typing.Annotated[
bool,
_name("Discriminator field inlined"),
_description(
"Whether or not the discriminator is inlined in the underlying"
" objects' schema"
),
]
oneof_type: typing.Annotated[str, _name("One Of Type Schema Name")] = None
discriminator_type: typing.Annotated[str, _name("Discriminator Type")] = (
None
)
discriminator_field_name: typing.Annotated[
str,
_name("Discriminator field name"),
_description(
"Name of the field used to discriminate between possible values."
),
] = "_type"

def _insert_discriminator(
self,
discriminated_object: typing.Dict[str, typing.Any],
discriminator_val: str,
) -> typing.Dict[str, typing.Any]:
"""Add a discriminator field as a property of a member type.
This function adds a member type's discriminator field as a property
with a constant value equal to its discriminated value. The
discriminator field is moved to the zeroth index of the list of
required fields in a data packet.
:param discriminated_object: A Python dict which represents the
relevant fragment of the scope's JSON definition.
:param discriminator_val: The value that represents the given object in
its discriminated union.
"""
if self.discriminator_inlined:
# update the object's schema to show the only valid value
# for this object's discriminator
discriminated_object["properties"][
self.discriminator_field_name
] = {
"type": self.discriminator_type,
"const": discriminator_val,
}
# discriminator field is already present in the required
# list when the discriminator is inlined
discriminated_object["required"].remove(
self.discriminator_field_name
)
# discriminator must have the first position
discriminated_object["required"].insert(
0, self.discriminator_field_name
)

def _to_jsonschema_fragment(
self, scope: typing.ForwardRef("ScopeSchema"), defs: _JSONSchemaDefs
) -> any:
one_of = []
for k, v in self.types.items():
# noinspection PyProtectedMember
_ = scope.objects[v.id]._to_jsonschema_fragment(scope, defs)
self._insert_discriminator(defs.defs[v.id], str(k))
if v.display is not None:
if v.display.name is not None:
defs.defs[v.id]["title"] = v.display.name
if v.display.description is not None:
defs.defs[v.id]["description"] = v.display.description
name = v.id + self.oneof_type + str(k)
defs.defs[name] = defs.defs[v.id]
one_of.append({"$ref": "#/$defs/" + name})
return {"oneOf": one_of}

def _to_openapi_fragment(
self, scope: typing.ForwardRef("ScopeSchema"), defs: _OpenAPIComponents
) -> any:
one_of = []
discriminator_mapping = {}
for k, v in self.types.items():
# noinspection PyProtectedMember
_ = scope.objects[v.id]._to_openapi_fragment(scope, defs)
name = v.id + self.oneof_type + str(k)
discriminator_mapping[k] = "#/components/schemas/" + name
self._insert_discriminator(defs.defs[v.id], str(k))
if v.display is not None:
if v.display.name is not None:
defs.defs[v.id]["title"] = v.display.name
if v.display.description is not None:
defs.defs[v.id]["description"] = v.display.description
defs.components[name] = defs.defs[v.id]
one_of.append({"$ref": "#/components/schemas/" + name})
return {
"oneOf": one_of,
"discriminator": {
"propertyName": self.discriminator_field_name,
"mapping": discriminator_mapping,
},
}


@dataclass
class OneOfStringSchema(OneOfSchema):
"""This class holds the definition of variable types with a string
discriminator. This type acts as a split for a case where multiple possible
object types can be present in a field. This type requires that there be a
Expand Down Expand Up @@ -2701,112 +2808,14 @@ class OneOfStringSchema(_JSONSchemaGenerator, _OpenAPIGenerator):
""" # noqa: E501

types: Dict[str, typing.Annotated[_OBJECT_LIKE, discriminator("type_id")]]
discriminator_inlined: typing.Annotated[
bool,
_name("Discriminator field inlined"),
_description(
"True if the discriminator is a field in each schema of the"
" underlying objects"
),
]
discriminator_field_name: typing.Annotated[
str,
_name("Discriminator field name"),
_description(
"Name of the field whose value is used to discriminate between"
" possible subobject types. If this field is present in any of the"
" subobjects it must have a type of string."
),
] = "_type"

def _insert_discriminator(
self,
discriminated_object: typing.Dict[str, typing.Any],
discriminator_val: str,
) -> typing.Dict[str, typing.Any]:
"""Add a discriminator field as a property of a member type.
This function adds a member type's discriminator field as a property
with a constant value equal to its discriminated value. The
discriminator field is moved to the zeroth index of the list of
required fields in a data packet.
:param discriminated_object: A Python dict which represents the
relevant fragment of the scope's JSON definition.
:param discriminator_val: The value that represents the given object in
its discriminated union.
"""
if self.discriminator_inlined:
# update the object's schema to show the only valid value
# for this object's discriminator
discriminated_object["properties"][
self.discriminator_field_name
] = {
"type": "string",
"const": discriminator_val,
}
# discriminator field is already present in the required
# list when the discriminator is inlined
discriminated_object["required"].remove(
self.discriminator_field_name
)
# discriminator must have the first position
discriminated_object["required"].insert(
0, self.discriminator_field_name
)

def _to_jsonschema_fragment(
self, scope: typing.ForwardRef("ScopeSchema"), defs: _JSONSchemaDefs
) -> any:
one_of = []
for k, v in self.types.items():
# noinspection PyProtectedMember
scope.objects[v.id]._to_jsonschema_fragment(scope, defs)

self._insert_discriminator(defs.defs[v.id], k)

if v.display is not None:
if v.display.name is not None:
defs.defs[v.id]["title"] = v.display.name
if v.display.description is not None:
defs.defs[v.id]["description"] = v.display.description

name = v.id + "_discriminated_string_" + _id_typeize(k)
defs.defs[name] = defs.defs[v.id]
one_of.append({"$ref": "#/$defs/" + name})
return {"oneOf": one_of}

def _to_openapi_fragment(
self, scope: typing.ForwardRef("ScopeSchema"), defs: _OpenAPIComponents
) -> any:
one_of = []
discriminator_mapping = {}
for k, v in self.types.items():
# noinspection PyProtectedMember
scope.objects[v.id]._to_openapi_fragment(scope, defs)

name = v.id + "_discriminated_string_" + _id_typeize(k)
discriminator_mapping[k] = "#/components/schemas/" + name
self._insert_discriminator(defs.defs[v.id], k)
if v.display is not None:
if v.display.name is not None:
defs.defs[v.id]["title"] = v.display.name
if v.display.description is not None:
defs.defs[v.id]["description"] = v.display.description

defs.components[name] = defs.defs[v.id]
one_of.append({"$ref": "#/components/schemas/" + name})
return {
"oneOf": one_of,
"discriminator": {
"propertyName": self.discriminator_field_name,
"mapping": discriminator_mapping,
},
}
def __post_init__(self):
self.oneof_type = "_discriminated_string_"
self.discriminator_type = "string"


@dataclass
class OneOfIntSchema(_JSONSchemaGenerator, _OpenAPIGenerator):
class OneOfIntSchema(OneOfSchema):
"""This class holds the definition of variable types with an integer
discriminator. This type acts as a split for a case where multiple possible
object types can be present in a field. This type requires that there be a
Expand Down Expand Up @@ -2912,106 +2921,10 @@ class OneOfIntSchema(_JSONSchemaGenerator, _OpenAPIGenerator):
""" # noqa: E501

types: Dict[int, typing.Annotated[_OBJECT_LIKE, discriminator("type_id")]]
discriminator_inlined: typing.Annotated[
bool,
_name("Discriminator field inlined"),
_description(
"Whether or not the discriminator is inlined in the underlying"
" objects' schema"
),
]
discriminator_field_name: typing.Annotated[
str,
_name("Discriminator field name"),
_description(
"Name of the field used to discriminate between possible values."
" If this field ispresent on any of the component objects it must"
" also be an int."
),
] = "_type"

def _insert_discriminator(
self,
discriminated_object: typing.Dict[str, typing.Any],
discriminator_val: str,
) -> typing.Dict[str, typing.Any]:
"""Add a discriminator field as a property of a member type.
This function adds a member type's discriminator field as a property
with a constant value equal to its discriminated value. The
discriminator field is moved to the zeroth index of the list of
required fields in a data packet.
:param discriminated_object: A Python dict which represents the
relevant fragment of the scope's JSON definition.
:param discriminator_val: The value that represents the given object in
its discriminated union.
"""
if self.discriminator_inlined:
# update the object's schema to show the only valid value
# for this object's discriminator
discriminated_object["properties"][
self.discriminator_field_name
] = {
"type": "string",
"const": discriminator_val,
}
# discriminator field is already present in the required
# list when the discriminator is inlined
discriminated_object["required"].remove(
self.discriminator_field_name
)
# discriminator must have the first position
discriminated_object["required"].insert(
0, self.discriminator_field_name
)

def _to_jsonschema_fragment(
self, scope: typing.ForwardRef("ScopeSchema"), defs: _JSONSchemaDefs
) -> any:
one_of = []
for k, v in self.types.items():
# noinspection PyProtectedMember
scope.objects[v.id]._to_jsonschema_fragment(scope, defs)

self._insert_discriminator(defs.defs[v.id], str(k))
if v.display is not None:
if v.display.name is not None:
defs.defs[v.id]["title"] = v.display.name
if v.display.description is not None:
defs.defs[v.id]["description"] = v.display.description
name = v.id + "_discriminated_int_" + str(k)
defs.defs[name] = defs.defs[v.id]
one_of.append({"$ref": "#/$defs/" + name})
return {"oneOf": one_of}

def _to_openapi_fragment(
self, scope: typing.ForwardRef("ScopeSchema"), defs: _OpenAPIComponents
) -> any:
one_of = []
discriminator_mapping = {}
for k, v in self.types.items():
# noinspection PyProtectedMember
scope.objects[v.id]._to_openapi_fragment(scope, defs)
name = v.id + "_discriminated_int_" + str(k)
discriminator_mapping[k] = "#/components/schemas/" + name

self._insert_discriminator(defs.defs[v.id], str(k))
if v.display is not None:
if v.display.name is not None:
defs.defs[v.id]["title"] = v.display.name
if v.display.description is not None:
defs.defs[v.id]["description"] = v.display.description

defs.components[name] = defs.defs[v.id]
one_of.append({"$ref": "#/components/schemas/" + name})
return {
"oneOf": one_of,
"discriminator": {
"propertyName": self.discriminator_field_name,
"mapping": discriminator_mapping,
},
}
def __post_init__(self):
self.oneof_type = "_discriminated_int_"
self.discriminator_type = "integer"


@dataclass
Expand Down Expand Up @@ -5621,7 +5534,10 @@ def __init__(
):
# noinspection PyArgumentList
OneOfStringSchema.__init__(
self, types, discriminator_inlined, discriminator_field_name
self,
types=types,
discriminator_inlined=discriminator_inlined,
discriminator_field_name=discriminator_field_name,
)
_OneOfType.__init__(
self,
Expand Down Expand Up @@ -5659,7 +5575,10 @@ def __init__(
):
# noinspection PyArgumentList
OneOfIntSchema.__init__(
self, types, discriminator_inlined, discriminator_field_name
self,
types=types,
discriminator_inlined=discriminator_inlined,
discriminator_field_name=discriminator_field_name,
)
_OneOfType.__init__(
self,
Expand Down Expand Up @@ -7019,12 +6938,14 @@ def _resolve_union(
types[discriminator_value] = f.type
if discriminator_type is str:
return OneOfStringType(
types,
scope,
types=types,
scope=scope,
discriminator_inlined=False,
)
else:
return OneOfIntType(types, scope, discriminator_inlined=False)
return OneOfIntType(
types=types, scope=scope, discriminator_inlined=False
)

@classmethod
def _resolve_pattern(cls, t, type_hints: type, path, scope: ScopeType):
Expand Down
Loading

0 comments on commit 5ad27d8

Please sign in to comment.