Skip to content

Commit

Permalink
avoid un-necessary lambda
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy committed Feb 27, 2024
1 parent 6ed3cc3 commit 11b3812
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,13 +1852,11 @@ def get_example_request_body(
state: dict,
include_all: bool = False,
) -> dict:
preferred_fields = cls.get_example_preferred_fields(state)
return extract_model_fields(
cls.RequestModel,
state,
condition=lambda field_name, field: include_all
or field.required
or field_name in preferred_fields,
include_all=include_all,
preferred_fields=cls.get_example_preferred_fields(state),
)

def get_example_response_body(
Expand All @@ -1878,7 +1876,7 @@ def get_example_response_body(
output = extract_model_fields(
self.ResponseModel,
state,
condition=lambda field_name, field: include_all or field.required,
include_all=include_all,
)
if as_async:
return dict(
Expand Down Expand Up @@ -1945,13 +1943,18 @@ def render_output_caption():
def extract_model_fields(
model: typing.Type[BaseModel],
state: dict,
condition: typing.Callable[[str, "pydantic.ModelField"], bool],
include_all: bool = False,
preferred_fields: list[str] = None,
) -> dict:
"""Only returns required fields unless include_all is set to True."""
return {
field_name: state.get(field_name)
for field_name, field in model.__fields__.items()
if condition(field_name, field)
if (
include_all
or field.required
or (preferred_fields and field_name in preferred_fields)
)
}


Expand Down

0 comments on commit 11b3812

Please sign in to comment.