Skip to content

Commit

Permalink
Fix API deserialization if a list field is missing from the JSON
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Jan 17, 2025
1 parent ccd19e7 commit 38c30b1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
8 changes: 6 additions & 2 deletions ai_diffusion/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import Field, dataclass, field, is_dataclass, fields
from dataclasses import Field, dataclass, field, is_dataclass, fields, MISSING
from copy import copy
from enum import Enum
from types import GenericAlias, UnionType
Expand Down Expand Up @@ -293,8 +293,12 @@ def _object(self, type, input: dict):
return type(*values)

def _field(self, field: Field, value):
if value is None:
if value is None and field.default is not MISSING:
return field.default
elif value is None and field.default_factory is not MISSING:
return field.default_factory()
elif value is None:
return None
field_type = field.type
if isinstance(field_type, UnionType):
field_type = get_args(field_type)[0]
Expand Down
22 changes: 21 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
ExtentInput,
ImageInput,
ConditioningInput,
RegionInput,
)
from ai_diffusion.image import Extent, Image, ImageFileFormat
from ai_diffusion.image import Extent, Image, ImageFileFormat, Bounds
from ai_diffusion.resources import ControlMode
from ai_diffusion.util import ensure

Expand Down Expand Up @@ -58,3 +59,22 @@ def test_serialize():
assert _ensure_cmp(result_control[1].image) == _ensure_cmp(input_control[1].image)
assert result_control[2].image is None
assert result == input


def test_deserialize_list_default():
input = WorkflowInput(WorkflowKind.generate)
input.images = ImageInput(ExtentInput(Extent(1, 1), Extent(2, 2), Extent(3, 3), Extent(4, 4)))
input.conditioning = ConditioningInput(
"prompt",
regions=[
RegionInput(
Image.create(Extent(2, 2), Qt.GlobalColor.red), Bounds(0, 0, 2, 2), "positive", []
)
],
)

data = input.to_dict()
del data["conditioning"]["regions"][0]["loras"]
result = WorkflowInput.from_dict(data)
assert result.conditioning is not None
assert len(result.conditioning.regions[0].loras) == 0

0 comments on commit 38c30b1

Please sign in to comment.