diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 85d4dd5fe..2c574b3f3 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1852,13 +1852,11 @@ def get_example_request_body( state: dict, include_all: bool = False, ) -> dict: - preferred_fields = cls.get_example_preferred_fields(state) return extract_model_fields( cls.RequestModel, state, - condition=lambda field_name, field: include_all - or field.required - or field_name in preferred_fields, + include_all=include_all, + preferred_fields=cls.get_example_preferred_fields(state), ) def get_example_response_body( @@ -1878,7 +1876,7 @@ def get_example_response_body( output = extract_model_fields( self.ResponseModel, state, - condition=lambda field_name, field: include_all or field.required, + include_all=include_all, ) if as_async: return dict( @@ -1945,13 +1943,18 @@ def render_output_caption(): def extract_model_fields( model: typing.Type[BaseModel], state: dict, - condition: typing.Callable[[str, "pydantic.ModelField"], bool], + include_all: bool = False, + preferred_fields: list[str] = None, ) -> dict: """Only returns required fields unless include_all is set to True.""" return { field_name: state.get(field_name) for field_name, field in model.__fields__.items() - if condition(field_name, field) + if ( + include_all + or field.required + or (preferred_fields and field_name in preferred_fields) + ) }