From 93ac306f30fe0622f8950ba121f02c094a6439a5 Mon Sep 17 00:00:00 2001 From: Matt Leader Date: Wed, 3 Apr 2024 16:36:15 -0400 Subject: [PATCH] Oneof refactor 2 (#126) * move classes up another level * simplify test fixture class names * replace some string literals of discriminators and object ids with a variable or call to a class name --- src/arcaflow_plugin_sdk/test_schema.py | 407 +++++++++++++------------ 1 file changed, 219 insertions(+), 188 deletions(-) diff --git a/src/arcaflow_plugin_sdk/test_schema.py b/src/arcaflow_plugin_sdk/test_schema.py index 08eb351..5fd41bf 100644 --- a/src/arcaflow_plugin_sdk/test_schema.py +++ b/src/arcaflow_plugin_sdk/test_schema.py @@ -15,6 +15,53 @@ SchemaBuildException, ) +# default discriminator field name used by the OneOfType +# when no discriminator field name is declared +default_discriminator = "_type" + +# The string "type_" is the discriminator identifier that will +# be embedded in StrInline. It must match the OneOfType's +# discriminator field name. +discriminator_field_name = "type_" + + +@dataclasses.dataclass +class Basic: + msg: str + + +@dataclasses.dataclass +class Basic2: + msg2: str + + +@dataclasses.dataclass +class Basic3: + b: int + + +@dataclasses.dataclass +class InlineStr: + type_: str + a: str + + +@dataclasses.dataclass +class InlineInt: + type_: int + msg: str + + +@dataclasses.dataclass +class InlineInt2: + type_: int + msg2: str + + +@dataclasses.dataclass +class BasicUnion: + union_basic: typing.Union[Basic, Basic2] + class Color(enum.Enum): GREEN = "green" @@ -448,47 +495,26 @@ class TestSubclass(TestParent): class OneOfTest(unittest.TestCase): - @dataclasses.dataclass - class OneOfData1: - a: str - - @dataclasses.dataclass - class OneOfData2: - b: int - - @dataclasses.dataclass - class OneOfDataEmbedded1: - type_: str - a: str - - # default discriminator field name - discriminator_default = "_type" - - # The string "type_" is the discriminator identifier that will - # be embedded in OneOfDataEmbedded1. It must match the OneOfType's - # discriminator field name. - discriminator_field_name = "type_" - def setUp(self): self.obj_b = schema.ObjectType( - OneOfTest.OneOfData2, {"b": PropertyType(schema.IntType())} + Basic3, {"b": PropertyType(schema.IntType())} ) self.scope_basic = schema.ScopeType( { "a": schema.ObjectType( - OneOfTest.OneOfData1, - {"a": PropertyType(schema.StringType())}, + Basic, + {"msg": PropertyType(schema.StringType())}, ), "b": self.obj_b, }, "a", ) - self.scope_embedded = schema.ScopeType( + self.scope_mixed_type = schema.ScopeType( { "a": schema.ObjectType( - OneOfTest.OneOfDataEmbedded1, + InlineStr, { - self.discriminator_field_name: PropertyType( + discriminator_field_name: PropertyType( schema.StringType(), ), "a": PropertyType(schema.StringType()), @@ -499,18 +525,6 @@ def setUp(self): "a", ) - def test_unserialize_error_discriminator_type(self): - s_type = schema.OneOfStringType( - { - "a": schema.RefType("a", self.scope_basic), - "b": schema.RefType("b", self.scope_basic), - }, - scope=self.scope_basic, - ) - with self.assertRaises(ConstraintException): - s_type.unserialize({ - self.discriminator_default: "1", 1: "Hello world!"}) - def test_unserialize(self): s_type = schema.OneOfStringType( { @@ -526,31 +540,35 @@ def test_unserialize(self): with self.assertRaises(ConstraintException): s_type.unserialize({"b": 42}) - # invalid type for 'data' argument + # Invalid type, string, for discriminator value + # that requires an integer + with self.assertRaises(ConstraintException): + s_type.unserialize({default_discriminator: "k", 1: "Hello world!"}) + + # Invalid type for 'data' argument with self.assertRaises(ConstraintException): s_type.unserialize([]) # Mismatching key value with self.assertRaises(ConstraintException): - s_type.unserialize({ - self.discriminator_default: "a", "b": "Hello world!"}) + s_type.unserialize( + {default_discriminator: "a", "b": "Hello world!"} + ) # Invalid key value with self.assertRaises(ConstraintException): - s_type.unserialize({ - self.discriminator_default: 1, "a": "Hello world!"}) + s_type.unserialize({default_discriminator: 1, "a": "Hello world!"}) # Invalid discriminator with self.assertRaises(ConstraintException): - s_type.unserialize({ - self.discriminator_default: 1, "b": "Hello world!"}) + s_type.unserialize({default_discriminator: 1, "b": "Hello world!"}) - unserialized_data: OneOfTest.OneOfData1 = s_type.unserialize( - {self.discriminator_default: "a", "a": "Hello world!"} + unserialized_data: Basic = s_type.unserialize( + {default_discriminator: "a", "msg": "Hello world!"} ) - self.assertIsInstance(unserialized_data, OneOfTest.OneOfData1) - self.assertEqual(unserialized_data.a, "Hello world!") - unserialized_data2: OneOfTest.OneOfData2 = s_type.unserialize( - {self.discriminator_default: "b", "b": 42} + self.assertIsInstance(unserialized_data, Basic) + self.assertEqual(unserialized_data.msg, "Hello world!") + unserialized_data2: Basic3 = s_type.unserialize( + {default_discriminator: "b", "b": 42} ) - self.assertIsInstance(unserialized_data2, OneOfTest.OneOfData2) + self.assertIsInstance(unserialized_data2, Basic3) self.assertEqual(unserialized_data2.b, 42) s_type_int = schema.OneOfIntType( @@ -560,107 +578,108 @@ def test_unserialize(self): }, scope=self.scope_basic, ) - unserialized_data3: OneOfTest.OneOfData1 = s_type_int.unserialize( - {self.discriminator_default: 1, "a": "Hello world!"} + unserialized_data3: Basic = s_type_int.unserialize( + {default_discriminator: 1, "msg": "Hello world!"} ) - self.assertIsInstance(unserialized_data3, OneOfTest.OneOfData1) - self.assertEqual(unserialized_data3.a, "Hello world!") + self.assertIsInstance(unserialized_data3, Basic) + self.assertEqual(unserialized_data3.msg, "Hello world!") def test_unserialize_embedded(self): s = schema.OneOfStringType( { - "a": schema.RefType("a", self.scope_embedded), - "b": schema.RefType("b", self.scope_embedded), + "a": schema.RefType("a", self.scope_mixed_type), + "b": schema.RefType("b", self.scope_mixed_type), }, - scope=self.scope_embedded, - discriminator_field_name=self.discriminator_field_name, + scope=self.scope_mixed_type, + discriminator_field_name=discriminator_field_name, ) - unserialized_data: OneOfTest.OneOfDataEmbedded1 = s.unserialize( - {self.discriminator_field_name: "a", "a": "Hello world!"} + unserialized_data: InlineStr = s.unserialize( + {discriminator_field_name: "a", "a": "Hello world!"} ) - self.assertIsInstance(unserialized_data, OneOfTest.OneOfDataEmbedded1) + self.assertIsInstance(unserialized_data, InlineStr) self.assertEqual( - getattr(unserialized_data, self.discriminator_field_name), "a") + getattr(unserialized_data, discriminator_field_name), "a" + ) self.assertEqual(unserialized_data.a, "Hello world!") - unserialized_data2: OneOfTest.OneOfData2 = s.unserialize( - {self.discriminator_field_name: "b", "b": 42} + unserialized_data2: Basic3 = s.unserialize( + {discriminator_field_name: "b", "b": 42} ) - self.assertIsInstance(unserialized_data2, OneOfTest.OneOfData2) + self.assertIsInstance(unserialized_data2, Basic3) self.assertEqual(unserialized_data2.b, 42) def test_validation(self): - s = schema.OneOfStringType[OneOfTest.OneOfDataEmbedded1]( + s = schema.OneOfStringType[InlineStr]( { - "a": schema.RefType("a", self.scope_embedded), - "b": schema.RefType("b", self.scope_embedded), + "a": schema.RefType("a", self.scope_mixed_type), + "b": schema.RefType("b", self.scope_mixed_type), }, - scope=self.scope_embedded, - discriminator_field_name=self.discriminator_field_name, + scope=self.scope_mixed_type, + discriminator_field_name=discriminator_field_name, ) with self.assertRaises(ConstraintException): # noinspection PyTypeChecker - s.validate(OneOfTest.OneOfDataEmbedded1(None, "Hello world!")) + s.validate(InlineStr(None, "Hello world!")) with self.assertRaises(ConstraintException): - s.validate(OneOfTest.OneOfDataEmbedded1("b", "Hello world!")) + s.validate(InlineStr("b", "Hello world!")) with self.assertRaises(ConstraintException): # noinspection PyTypeChecker - s.validate(OneOfTest.OneOfData1("Hello world!")) + s.validate(Basic("Hello world!")) - s.validate(OneOfTest.OneOfDataEmbedded1("a", "Hello world!")) + s.validate(InlineStr("a", "Hello world!")) def test_serialize(self): s = schema.OneOfStringType( { - "a": schema.RefType("a", self.scope_embedded), - "b": schema.RefType("b", self.scope_embedded), + "a": schema.RefType("a", self.scope_mixed_type), + "b": schema.RefType("b", self.scope_mixed_type), }, - scope=self.scope_embedded, - discriminator_field_name=self.discriminator_field_name, + scope=self.scope_mixed_type, + discriminator_field_name=discriminator_field_name, ) self.assertEqual( - s.serialize(OneOfTest.OneOfDataEmbedded1("a", "Hello world!")), - {self.discriminator_field_name: "a", "a": "Hello world!"}, + s.serialize(InlineStr("a", "Hello world!")), + {discriminator_field_name: "a", "a": "Hello world!"}, ) self.assertEqual( - s.serialize(OneOfTest.OneOfData2(42)), - {self.discriminator_field_name: "b", "b": 42}, + s.serialize(Basic3(42)), + {discriminator_field_name: "b", "b": 42}, ) with self.assertRaises(ConstraintException): # noinspection PyTypeChecker - s.serialize(OneOfTest.OneOfData1("Hello world!")) + s.serialize(Basic("Hello world!")) with self.assertRaises(ConstraintException): - s.serialize(OneOfTest.OneOfDataEmbedded1("b", "Hello world!")) + s.serialize(InlineStr("b", "Hello world!")) def test_object(self): scope = schema.ScopeType({}, "") s = schema.OneOfStringType( { "a": schema.ObjectType( - OneOfTest.OneOfDataEmbedded1, + InlineStr, { - self.discriminator_field_name: PropertyType( + discriminator_field_name: PropertyType( schema.StringType(), ), "a": PropertyType(schema.StringType()), }, ), "b": schema.ObjectType( - OneOfTest.OneOfData2, {"b": PropertyType(schema.IntType())} + Basic3, {"b": PropertyType(schema.IntType())} ), }, scope=scope, - discriminator_field_name=self.discriminator_field_name, + discriminator_field_name=discriminator_field_name, ) unserialized_data = s.unserialize( - {self.discriminator_field_name: "b", "b": 42} + {discriminator_field_name: "b", "b": 42} ) - self.assertIsInstance(unserialized_data, OneOfTest.OneOfData2) + self.assertIsInstance(unserialized_data, Basic3) class SerializationTest(unittest.TestCase): @@ -1064,75 +1083,74 @@ class TestData: ) def test_union(self): - @dataclasses.dataclass - class A: - a: str - - @dataclasses.dataclass - class B: - b: str - - @dataclasses.dataclass - class TestData: - a: typing.Union[A, B] - - scope = schema.build_object_schema(TestData) - self.assertEqual("TestData", scope.root) - self.assertIsInstance(scope.objects["TestData"], schema.ObjectType) - self.assertIsInstance(scope.objects["A"], schema.ObjectType) - self.assertIsInstance(scope.objects["B"], schema.ObjectType) + scope = schema.build_object_schema(BasicUnion) + self.assertEqual(BasicUnion.__name__, scope.root) + self.assertIsInstance( + scope.objects[BasicUnion.__name__], schema.ObjectType + ) + self.assertIsInstance(scope.objects[Basic.__name__], schema.ObjectType) + self.assertIsInstance( + scope.objects[Basic2.__name__], schema.ObjectType + ) self.assertIsInstance( - scope.objects["TestData"].properties["a"].type, + scope.objects[BasicUnion.__name__].properties["union_basic"].type, schema.OneOfStringType, ) one_of_type: schema.OneOfStringType = ( - scope.objects["TestData"].properties["a"].type + scope.objects[BasicUnion.__name__].properties["union_basic"].type + ) + self.assertEqual( + one_of_type.discriminator_field_name, default_discriminator + ) + self.assertIsInstance( + one_of_type.types[Basic.__name__], schema.RefType ) - self.assertEqual(one_of_type.discriminator_field_name, "_type") - self.assertIsInstance(one_of_type.types["A"], schema.RefType) - self.assertEqual(one_of_type.types["A"].id, "A") - self.assertIsInstance(one_of_type.types["B"], schema.RefType) - self.assertEqual(one_of_type.types["B"].id, "B") + self.assertEqual(one_of_type.types[Basic.__name__].id, "Basic") + self.assertIsInstance( + one_of_type.types[Basic2.__name__], schema.RefType + ) + self.assertEqual(one_of_type.types[Basic2.__name__].id, "Basic2") def test_union_custom_discriminator(self): - @dataclasses.dataclass - class A: - discriminator: int - a: str - - @dataclasses.dataclass - class B: - discriminator: int - b: str - @dataclasses.dataclass class TestData: - a: typing.Annotated[ + union: typing.Annotated[ typing.Union[ - typing.Annotated[A, schema.discriminator_value(1)], - typing.Annotated[B, schema.discriminator_value(2)], + typing.Annotated[InlineInt, schema.discriminator_value(1)], + typing.Annotated[ + InlineInt2, schema.discriminator_value(2) + ], ], - schema.discriminator("discriminator"), + schema.discriminator(discriminator_field_name), ] scope = schema.build_object_schema(TestData) - self.assertEqual("TestData", scope.root) - self.assertIsInstance(scope.objects["TestData"], schema.ObjectType) - self.assertIsInstance(scope.objects["A"], schema.ObjectType) - self.assertIsInstance(scope.objects["B"], schema.ObjectType) + self.assertEqual(TestData.__name__, scope.root) + self.assertIsInstance( + scope.objects[TestData.__name__], schema.ObjectType + ) + self.assertIsInstance( + scope.objects[InlineInt.__name__], schema.ObjectType + ) + self.assertIsInstance( + scope.objects[InlineInt2.__name__], schema.ObjectType + ) self.assertIsInstance( - scope.objects["TestData"].properties["a"].type, schema.OneOfIntType + scope.objects[TestData.__name__].properties["union"].type, + schema.OneOfIntType, ) one_of_type: schema.OneOfIntType = ( - scope.objects["TestData"].properties["a"].type + scope.objects[TestData.__name__].properties["union"].type + ) + self.assertEqual( + one_of_type.discriminator_field_name, discriminator_field_name ) - self.assertEqual(one_of_type.discriminator_field_name, "discriminator") self.assertIsInstance(one_of_type.types[1], schema.RefType) - self.assertEqual(one_of_type.types[1].id, "A") + self.assertEqual(one_of_type.types[1].id, InlineInt.__name__) self.assertIsInstance(one_of_type.types[2], schema.RefType) - self.assertEqual(one_of_type.types[2].id, "B") + self.assertEqual(one_of_type.types[2].id, InlineInt2.__name__) def test_optional(self): @dataclasses.dataclass @@ -1240,6 +1258,7 @@ class TestData: class JSONSchemaTest(unittest.TestCase): + def _execute_test_cases(self, test_cases): for name in test_cases.keys(): defs = schema._JSONSchemaDefs() @@ -1451,43 +1470,31 @@ class TestData: self.assertEqual(expected, result) def test_one_of(self): - @dataclasses.dataclass - class A: - a: str - - @dataclasses.dataclass - class B: - b: str - - @dataclasses.dataclass - class TestData: - a: typing.Union[A, B] - scope = schema.ScopeType( {}, - "TestData", + BasicUnion.__name__, ) scope.objects = { - "TestData": schema.ObjectType( - TestData, + str(BasicUnion.__name__): schema.ObjectType( + BasicUnion, { - "a": schema.PropertyType( + "union_basic": schema.PropertyType( schema.OneOfStringType( { - "a": schema.RefType("A", scope), - "b": schema.RefType("B", scope), + "a": schema.RefType(Basic.__name__, scope), + "b": schema.RefType(Basic2.__name__, scope), }, scope, - "_type", + default_discriminator, ) ) }, ), - "A": schema.ObjectType( - A, {"a": schema.PropertyType(schema.StringType())} + Basic.__name__: schema.ObjectType( + Basic, {"msg": schema.PropertyType(schema.StringType())} ), - "B": schema.ObjectType( - B, {"b": schema.PropertyType(schema.StringType())} + Basic2.__name__: schema.ObjectType( + Basic2, {"msg2": schema.PropertyType(schema.StringType())} ), } defs = schema._JSONSchemaDefs() @@ -1495,79 +1502,103 @@ class TestData: self.assertEqual( { "$defs": { - "TestData": { + BasicUnion.__name__: { "type": "object", "properties": { - "a": { + "union_basic": { "oneOf": [ { "$ref": ( - "#/$defs/A_discriminated_string_a" + f"#/$defs/{Basic.__name__}" + "_discriminated_string_" + "a" ) }, { "$ref": ( - "#/$defs/B_discriminated_string_b" + f"#/$defs/{Basic2.__name__}" + "_discriminated_string_" + "b" ) }, ] } }, - "required": ["a"], + "required": ["union_basic"], "additionalProperties": False, "dependentRequired": {}, }, - "A": { + Basic.__name__: { "type": "object", "properties": { - "a": {"type": "string"}, - "_type": {"type": "string", "const": "a"}, + "msg": {"type": "string"}, + default_discriminator: { + "type": "string", + "const": "a", + }, }, - "required": ["_type", "a"], + "required": [default_discriminator, "msg"], "additionalProperties": False, "dependentRequired": {}, }, - "A_discriminated_string_a": { + f"{Basic.__name__}_discriminated_string_a": { "type": "object", "properties": { - "a": {"type": "string"}, - "_type": {"type": "string", "const": "a"}, + "msg": {"type": "string"}, + default_discriminator: { + "type": "string", + "const": "a", + }, }, - "required": ["_type", "a"], + "required": [default_discriminator, "msg"], "additionalProperties": False, "dependentRequired": {}, }, - "B": { + Basic2.__name__: { "type": "object", "properties": { - "b": {"type": "string"}, - "_type": {"type": "string", "const": "b"}, + "msg2": {"type": "string"}, + default_discriminator: { + "type": "string", + "const": "b", + }, }, - "required": ["_type", "b"], + "required": [default_discriminator, "msg2"], "additionalProperties": False, "dependentRequired": {}, }, - "B_discriminated_string_b": { + f"{Basic2.__name__}_discriminated_string_b": { "type": "object", "properties": { - "b": {"type": "string"}, - "_type": {"type": "string", "const": "b"}, + "msg2": {"type": "string"}, + default_discriminator: { + "type": "string", + "const": "b", + }, }, - "required": ["_type", "b"], + "required": [default_discriminator, "msg2"], "additionalProperties": False, "dependentRequired": {}, }, }, "type": "object", "properties": { - "a": { + "union_basic": { "oneOf": [ - {"$ref": "#/$defs/A_discriminated_string_a"}, - {"$ref": "#/$defs/B_discriminated_string_b"}, + { + "$ref": f"#/$defs/{Basic.__name__}" + f"_discriminated_string_" + f"a" + }, + { + "$ref": f"#/$defs/{Basic2.__name__}" + f"_discriminated_string_" + f"b" + }, ] } }, - "required": ["a"], + "required": ["union_basic"], "additionalProperties": False, "dependentRequired": {}, },