diff --git a/src/suql/prompt_continuation.py b/src/suql/prompt_continuation.py index 2fa5031..58d5736 100644 --- a/src/suql/prompt_continuation.py +++ b/src/suql/prompt_continuation.py @@ -87,6 +87,7 @@ def _generate( no_line_break_start = "" no_line_break_length = 0 kwargs = { + "model": engine, "messages": [ {"role": "system", "content": filled_prompt + no_line_break_start} ], @@ -97,17 +98,6 @@ def _generate( "presence_penalty": presence_penalty, "stop": stop_tokens, } - engine_model_map = { - "gpt-4": "gpt-4", - "gpt-35-turbo": "gpt-3.5-turbo-1106", - "gpt-3.5-turbo": "gpt-3.5-turbo-1106", - "gpt-4-turbo": "gpt-4-1106-preview", - } - kwargs.update( - { - "model": engine_model_map[engine] if engine in engine_model_map else engine - } - ) generation_output = chat_completion_with_backoff(**kwargs) generation_output = no_line_break_start + generation_output