refactor: default for max_new_tokens to 512 in Hugging Face generators (#7370)

* set default for max_new_tokens to 512 in Hugging Face generators

* add release notes

* fix tests

* remove issues from release note

---------

Co-authored-by: christopherkeibel <[email protected]>
Co-authored-by: Julian Risch <[email protected]>
3 people authored and silvanocerza committed Apr 8, 2024
1 parent f112034 commit 7e59a7f
Showing 7 changed files with 31 additions and 16 deletions.
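
At its core the change is a single dict.setdefault call in each generator's __init__, so an explicitly passed max_new_tokens always takes precedence over the new default. A minimal sketch of that behavior (the standalone helper below is illustrative, not part of the diff):

def with_default_max_new_tokens(generation_kwargs: dict) -> dict:
    # setdefault only inserts the key when it is absent,
    # so an explicit caller value is never overwritten.
    generation_kwargs.setdefault("max_new_tokens", 512)
    return generation_kwargs

assert with_default_max_new_tokens({})["max_new_tokens"] == 512
assert with_default_max_new_tokens({"max_new_tokens": 64})["max_new_tokens"] == 64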
1 change: 1 addition & 0 deletions haystack/components/generators/chat/hugging_face_tgi.py
@@ -128,6 +128,7 @@ def __init__(
check_generation_params(generation_kwargs, ["n"])
generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
generation_kwargs["stop_sequences"].extend(stop_words or [])
+generation_kwargs.setdefault("max_new_tokens", 512)

self.model = model
self.url = url
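For context, both TGI generators first merge stop_words into TGI's stop_sequences before the new default is filled in. A rough standalone sketch of that init-time normalization (the function name and the defensive list copy are mine, not the repo's):

def build_generation_kwargs(generation_kwargs=None, stop_words=None):
    # Mirror the init logic above: copy the kwargs, merge stop words
    # into stop_sequences, then fill in the max_new_tokens default.
    kwargs = dict(generation_kwargs or {})
    kwargs["stop_sequences"] = list(kwargs.get("stop_sequences", []))
    kwargs["stop_sequences"].extend(stop_words or [])
    kwargs.setdefault("max_new_tokens", 512)
    return kwargs

assert build_generation_kwargs({"n": 5}, ["stop"]) == {
    "n": 5,
    "stop_sequences": ["stop"],
    "max_new_tokens": 512,
}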
1 change: 1 addition & 0 deletions haystack/components/generators/hugging_face_local.py
@@ -132,6 +132,7 @@ def __init__(
"Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
"Please specify only one of them."
)
+generation_kwargs.setdefault("max_new_tokens", 512)

self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
self.generation_kwargs = generation_kwargs
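From a user's perspective the default is transparent; assuming the package's public import path for this component, usage would look roughly like:

from haystack.components.generators import HuggingFaceLocalGenerator

# Relies on the new default: max_new_tokens is filled in as 512.
generator = HuggingFaceLocalGenerator(task="text-generation")

# An explicit value wins, since setdefault never overwrites it.
short_generator = HuggingFaceLocalGenerator(
    task="text-generation",
    generation_kwargs={"max_new_tokens": 128},
)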
1 change: 1 addition & 0 deletions haystack/components/generators/hugging_face_tgi.py
@@ -115,6 +115,7 @@ def __init__(
check_generation_params(generation_kwargs, ["n"])
generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
generation_kwargs["stop_sequences"].extend(stop_words or [])
+generation_kwargs.setdefault("max_new_tokens", 512)

self.model = model
self.url = url
@@ -0,0 +1,4 @@
+---
+enhancements:
+- |
+  Set max_new_tokens default to 512 in Hugging Face generators.
18 changes: 11 additions & 7 deletions test/components/generators/chat/test_hugging_face_tgi.py
@@ -68,7 +68,11 @@ def test_initialize_with_valid_model_and_generation_parameters(
)
generator.warm_up()

-assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
+assert generator.generation_kwargs == {
+    **generation_kwargs,
+    **{"stop_sequences": ["stop"]},
+    **{"max_new_tokens": 512},
+}
assert generator.tokenizer is not None
assert generator.client is not None
assert generator.streaming_callback == streaming_callback
@@ -90,7 +94,7 @@ def test_to_dict(self, mock_check_valid_model):
# Assert that the init_params dictionary contains the expected keys and values
assert init_params["model"] == "NousResearch/Llama-2-7b-chat-hf"
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
-assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}
+assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}

def test_from_dict(self, mock_check_valid_model):
generator = HuggingFaceTGIChatGenerator(
@@ -104,7 +108,7 @@ def test_from_dict(self, mock_check_valid_model):

generator_2 = HuggingFaceTGIChatGenerator.from_dict(result)
assert generator_2.model == "NousResearch/Llama-2-7b-chat-hf"
-assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
+assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
assert generator_2.streaming_callback is streaming_callback_handler

def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer, mock_list_inference_deployed_models):
@@ -203,7 +207,7 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters(
# check kwargs passed to text_generation
# note that n, because it is not a text generation parameter, was not passed to text_generation
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}

assert isinstance(response, dict)
assert "replies" in response
@@ -238,7 +242,7 @@ def test_generate_multiple_text_responses_with_valid_prompt_and_generation_param

# check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}

# note how n caused n replies to be generated
assert isinstance(response, dict)
@@ -266,7 +270,7 @@ def test_generate_text_with_stop_words(
# check kwargs passed to text_generation
# we translate stop_words to stop_sequences
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]}
+assert kwargs == {"details": True, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}

# Assert that the response contains the generated replies
assert "replies" in response
@@ -342,7 +346,7 @@ def mock_iter(self):

# check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": [], "stream": True}
+assert kwargs == {"details": True, "stop_sequences": [], "stream": True, "max_new_tokens": 512}

# Assert that the streaming callback was called twice
assert streaming_call_count == 2
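The updated assertions above build the expected dict by **-unpacking the fixture kwargs and layering the new entries on top; a tiny self-contained example of that pattern (the fixture value here is assumed for illustration):

generation_kwargs = {"n": 5}
expected = {
    **generation_kwargs,
    **{"stop_sequences": ["stop"]},
    **{"max_new_tokens": 512},
}
assert expected == {"n": 5, "stop_sequences": ["stop"], "max_new_tokens": 512}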
@@ -24,7 +24,7 @@ def test_init_default(self, model_info_mock, monkeypatch):
"token": None,
"device": ComponentDevice.resolve_device(None).to_hf(),
}
-assert generator.generation_kwargs == {}
+assert generator.generation_kwargs == {"max_new_tokens": 512}
assert generator.pipeline is None

def test_init_custom_token(self):
@@ -125,7 +125,7 @@ def test_init_set_return_full_text(self):
"""
generator = HuggingFaceLocalGenerator(task="text-generation")

-assert generator.generation_kwargs == {"return_full_text": False}
+assert generator.generation_kwargs == {"max_new_tokens": 512, "return_full_text": False}

def test_init_fails_with_both_stopwords_and_stoppingcriteria(self):
with pytest.raises(
18 changes: 11 additions & 7 deletions test/components/generators/test_hugging_face_tgi.py
@@ -63,7 +63,11 @@ def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_
)

assert generator.model == model
-assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
+assert generator.generation_kwargs == {
+    **generation_kwargs,
+    **{"stop_sequences": ["stop"]},
+    **{"max_new_tokens": 512},
+}
assert generator.tokenizer is None
assert generator.client is not None
assert generator.streaming_callback == streaming_callback
@@ -84,7 +88,7 @@ def test_to_dict(self, mock_check_valid_model):
# Assert that the init_params dictionary contains the expected keys and values
assert init_params["model"] == "mistralai/Mistral-7B-v0.1"
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
-assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}
+assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}

def test_from_dict(self, mock_check_valid_model):
generator = HuggingFaceTGIGenerator(
@@ -99,7 +103,7 @@ def test_from_dict(self, mock_check_valid_model):
# now deserialize, call from_dict
generator_2 = HuggingFaceTGIGenerator.from_dict(result)
assert generator_2.model == "mistralai/Mistral-7B-v0.1"
-assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]}
+assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}
assert generator_2.streaming_callback is streaming_callback_handler

def test_initialize_with_invalid_url(self, mock_check_valid_model):
@@ -135,7 +139,7 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters(
# check kwargs passed to text_generation
# note how n was not passed to text_generation
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}

assert isinstance(response, dict)
assert "replies" in response
@@ -168,7 +172,7 @@ def test_generate_multiple_text_responses_with_valid_prompt_and_generation_param
# check kwargs passed to text_generation
# note how n was not passed to text_generation
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": ["stop"]}
+assert kwargs == {"details": True, "stop_sequences": ["stop"], "max_new_tokens": 512}

assert isinstance(response, dict)
assert "replies" in response
@@ -208,7 +212,7 @@ def test_generate_text_with_stop_words(

# check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]}
+assert kwargs == {"details": True, "stop_sequences": ["stop", "words"], "max_new_tokens": 512}

# Assert that the response contains the generated replies
assert "replies" in response
@@ -284,7 +288,7 @@ def mock_iter(self):

# check kwargs passed to text_generation
_, kwargs = mock_text_generation.call_args
-assert kwargs == {"details": True, "stop_sequences": [], "stream": True}
+assert kwargs == {"details": True, "stop_sequences": [], "stream": True, "max_new_tokens": 512}

# Assert that the streaming callback was called twice
assert streaming_call_count == 2
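The repeated `_, kwargs = mock_text_generation.call_args` pattern in these tests unpacks unittest.mock's (args, kwargs) pair for the most recent call; a minimal standalone illustration:

from unittest.mock import MagicMock

mock_text_generation = MagicMock()
mock_text_generation("prompt", details=True, stop_sequences=[], max_new_tokens=512)

# call_args is an (args, kwargs) tuple for the most recent call.
_, kwargs = mock_text_generation.call_args
assert kwargs == {"details": True, "stop_sequences": [], "max_new_tokens": 512}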
