diff --git a/ultravox/model/ultravox_config.py b/ultravox/model/ultravox_config.py index fa77b14..3671250 100644 --- a/ultravox/model/ultravox_config.py +++ b/ultravox/model/ultravox_config.py @@ -63,6 +63,8 @@ class UltravoxConfig(transformers.PretrainedConfig): The LoRA configuration for finetuning the text model. audio_model_lora_config (`LoraConfigSimplified`, *optional*): The LoRA configuration for finetuning the audio model. + audio_latency_block_size (`int`, *optional*, defaults to `None`): + The latency block size for simulating audio streaming. Example: @@ -105,6 +107,7 @@ def __init__( projector_act: str = "swiglu", text_model_lora_config: Optional[LoraConfigSimplified] = None, audio_model_lora_config: Optional[LoraConfigSimplified] = None, + audio_latency_block_size: Optional[int] = None, **kwargs, ): self.ignore_index = ignore_index @@ -147,6 +150,7 @@ def __init__( if isinstance(audio_model_lora_config, dict) else dataclasses.asdict(audio_model_lora_config or LoraConfigSimplified()) ) + self.audio_latency_block_size = audio_latency_block_size self.vocab_size = self.text_config.vocab_size diff --git a/ultravox/model/ultravox_config_test.py b/ultravox/model/ultravox_config_test.py index ebb5cd2..1e92ec4 100644 --- a/ultravox/model/ultravox_config_test.py +++ b/ultravox/model/ultravox_config_test.py @@ -14,9 +14,15 @@ def test_can_load_release(model_id: str): ) config_from_dict = ultravox_config.UltravoxConfig(**orig_config.to_dict()) config_from_diff_dict = ultravox_config.UltravoxConfig(**orig_config.to_diff_dict()) + # To avoid inadvertently ignoring other keys, we explicitly list the keys that must be ignored. 
+ keys_to_ignore = ("audio_latency_block_size",) + orig_values = { + **{k: None for k in keys_to_ignore}, + **orig_config.to_dict(), + } - assert config_from_dict.to_dict() == orig_config.to_dict() - assert config_from_diff_dict.to_dict() == orig_config.to_dict() + assert config_from_dict.to_dict() == orig_values + assert config_from_diff_dict.to_dict() == orig_values assert config_from_dict.text_config.to_dict() == orig_config.text_config.to_dict() assert config_from_dict.audio_config.to_dict() == orig_config.audio_config.to_dict() @@ -25,8 +31,8 @@ def test_can_load_release(model_id: str): config_reloaded_diff = ultravox_config.UltravoxConfig( **config_from_dict.to_diff_dict() ) - assert config_reloaded.to_dict() == orig_config.to_dict() - assert config_reloaded_diff.to_dict() == orig_config.to_dict() + assert config_reloaded.to_dict() == orig_values + assert config_reloaded_diff.to_dict() == orig_values def test_no_config_when_id_present(): diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index 1b0a575..ccea3e0 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -291,7 +291,7 @@ def _create_audio_tower( config.audio_latency_block_size, dtype=config.torch_dtype ) else: - assert config.audio_latency_block_size not in ( + assert config.audio_latency_block_size in ( None, 0, ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'" @@ -305,7 +305,7 @@ def _create_audio_tower( config.audio_latency_block_size, dtype=config.torch_dtype ) else: - assert config.audio_latency_block_size not in ( + assert config.audio_latency_block_size in ( None, 0, ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"