diff --git a/aana/core/models/base.py b/aana/core/models/base.py
index d129c53a..efb295e7 100644
--- a/aana/core/models/base.py
+++ b/aana/core/models/base.py
@@ -58,7 +58,7 @@ def merged_options(default_options: OptionType, options: OptionType) -> OptionTy
     if type(default_options) != type(options):
         raise ValueError("Option type mismatch.")  # noqa: TRY003
     default_options_dict = default_options.model_dump()
-    for k, v in options.model_dump().items():
+    for k, v in options.model_dump(exclude_unset=True).items():
         if v is not None:
             default_options_dict[k] = v
     return options.__class__.model_validate(default_options_dict)
diff --git a/aana/deployments/hf_text_generation_deployment.py b/aana/deployments/hf_text_generation_deployment.py
index 0d4be92b..a46c5d8e 100644
--- a/aana/deployments/hf_text_generation_deployment.py
+++ b/aana/deployments/hf_text_generation_deployment.py
@@ -68,8 +68,11 @@ async def generate_stream(
         prompt = str(prompt)
 
         if sampling_params is None:
-            sampling_params = SamplingParams()
-        sampling_params = merged_options(self.default_sampling_params, sampling_params)
+            sampling_params = self.default_sampling_params
+        else:
+            sampling_params = merged_options(
+                self.default_sampling_params, sampling_params
+            )
 
         prompt_input = self.tokenizer(
             prompt, return_tensors="pt", add_special_tokens=False
diff --git a/aana/deployments/idefics_2_deployment.py b/aana/deployments/idefics_2_deployment.py
index 6536bf88..09069e84 100644
--- a/aana/deployments/idefics_2_deployment.py
+++ b/aana/deployments/idefics_2_deployment.py
@@ -105,8 +105,11 @@ async def chat_stream(
         transformers.set_seed(42)
 
         if sampling_params is None:
-            sampling_params = SamplingParams()
-        sampling_params = merged_options(self.default_sampling_params, sampling_params)
+            sampling_params = self.default_sampling_params
+        else:
+            sampling_params = merged_options(
+                self.default_sampling_params, sampling_params
+            )
 
         messages, images = dialog.to_objects()
         text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
@@ -191,8 +194,11 @@ async def chat_batch(
         transformers.set_seed(42)
 
         if sampling_params is None:
-            sampling_params = SamplingParams()
-        sampling_params = merged_options(self.default_sampling_params, sampling_params)
+            sampling_params = self.default_sampling_params
+        else:
+            sampling_params = merged_options(
+                self.default_sampling_params, sampling_params
+            )
 
         text_batch = []
         image_batch = []
diff --git a/aana/deployments/vllm_deployment.py b/aana/deployments/vllm_deployment.py
index 708849b5..50386cc6 100644
--- a/aana/deployments/vllm_deployment.py
+++ b/aana/deployments/vllm_deployment.py
@@ -227,8 +227,11 @@ async def generate_stream(  # noqa: C901
             prompt_token_ids = prompt
 
         if sampling_params is None:
-            sampling_params = SamplingParams()
-        sampling_params = merged_options(self.default_sampling_params, sampling_params)
+            sampling_params = self.default_sampling_params
+        else:
+            sampling_params = merged_options(
+                self.default_sampling_params, sampling_params
+            )
 
         json_schema = sampling_params.json_schema
         regex_string = sampling_params.regex_string
diff --git a/aana/tests/files/expected/image_text_generation/phi-3.5-vision-instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json b/aana/tests/files/expected/image_text_generation/phi-3.5-vision-instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json
index c44d4688..ba93432f 100644
--- a/aana/tests/files/expected/image_text_generation/phi-3.5-vision-instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json
+++ b/aana/tests/files/expected/image_text_generation/phi-3.5-vision-instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json
@@ -1 +1 @@
-" The painting is done by Vincent van Gogh."
\ No newline at end of file
+" Vincent van Gogh"
\ No newline at end of file
diff --git a/aana/tests/units/test_merge_options.py b/aana/tests/units/test_merge_options.py
index f95ddd09..6f813bfe 100644
--- a/aana/tests/units/test_merge_options.py
+++ b/aana/tests/units/test_merge_options.py
@@ -12,6 +12,7 @@ class MyOptions(BaseModel):
     field1: str
     field2: int | None = None
     field3: bool
+    field4: str = "default"
 
 
 def test_merged_options_same_type():
@@ -46,3 +47,17 @@ class AnotherOptions(BaseModel):
 
     with pytest.raises(ValueError):
         merged_options(default, to_merge)
+
+
+def test_merged_options_unset():
+    """Test merged_options with unset fields."""
+    default = MyOptions(field1="default1", field2=2, field3=True, field4="new_default")
+    to_merge = MyOptions(field1="merge1", field3=False)  # field4 is not set
+    merged = merged_options(default, to_merge)
+
+    assert merged.field1 == "merge1"
+    assert merged.field2 == 2
+    assert merged.field3 is False
+    assert (
+        merged.field4 == "new_default"
+    )  # Should retain value from default_options as it's not set in options