diff --git a/caikit_nlp/toolkit/text_generation/model_run_utils.py b/caikit_nlp/toolkit/text_generation/model_run_utils.py index ac33130b..49cee8d7 100644 --- a/caikit_nlp/toolkit/text_generation/model_run_utils.py +++ b/caikit_nlp/toolkit/text_generation/model_run_utils.py @@ -167,7 +167,7 @@ def generate_text_func( GENERATE_FUNCTION_ARGS ) - error.type_check("", str, eos_token=eos_token) + error.type_check("", str, allow_none=True, eos_token=eos_token) error.type_check("", str, text=text) error.type_check( diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index 16bf99b9..26ede868 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -73,7 +73,7 @@ def validate_inf_params( ) error.type_check("", str, text=text) error.type_check("", bool, preserve_input_text=preserve_input_text) - error.type_check("", str, eos_token=eos_token) + error.type_check("", str, allow_none=True, eos_token=eos_token) error.type_check( "", int, allow_none=True, max_new_tokens=max_new_tokens )