diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py index 83f1a02d96..48180fc321 100644 --- a/haystack/components/generators/openai.py +++ b/haystack/components/generators/openai.py @@ -170,6 +170,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAIGenerator": def run( self, prompt: str, + system_prompt: Optional[str] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, generation_kwargs: Optional[Dict[str, Any]] = None, ): @@ -178,6 +179,9 @@ def run( :param prompt: The string prompt to use for text generation. + :param system_prompt: + The system prompt to use for text generation. If this runtime system prompt is omitted, the system + prompt defined at initialization time is used, if any. :param streaming_callback: A callback function that is called when a new token is received from the stream. :param generation_kwargs: @@ -189,7 +193,9 @@ def run( for each response. """ message = ChatMessage.from_user(prompt) - if self.system_prompt: + if system_prompt is not None: + messages = [ChatMessage.from_system(system_prompt), message] + elif self.system_prompt: messages = [ChatMessage.from_system(self.system_prompt), message] else: messages = [message] @@ -237,7 +243,8 @@ def run( "meta": [message.meta for message in completions], } - def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage: + @staticmethod + def _connect_chunks(chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage: """ Connects the streaming chunks into a single ChatMessage. """ @@ -252,7 +259,8 @@ def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessa ) return complete_response - def _build_message(self, completion: Any, choice: Any) -> ChatMessage: + @staticmethod + def _build_message(completion: Any, choice: Any) -> ChatMessage: """ Converts the response from the OpenAI API to a ChatMessage. 
@@ -276,7 +284,8 @@ def _build_message(self, completion: Any, choice: Any) -> ChatMessage: ) return chat_message - def _build_chunk(self, chunk: Any) -> StreamingChunk: + @staticmethod + def _build_chunk(chunk: Any) -> StreamingChunk: """ Converts the response from the OpenAI API to a StreamingChunk. @@ -293,7 +302,8 @@ def _build_chunk(self, chunk: Any) -> StreamingChunk: chunk_message.meta.update({"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason}) return chunk_message - def _check_finish_reason(self, message: ChatMessage) -> None: + @staticmethod + def _check_finish_reason(message: ChatMessage) -> None: """ Check the `finish_reason` returned with the OpenAI completions. diff --git a/test/components/generators/test_openai.py b/test/components/generators/test_openai.py index a164fa5e0e..047299b199 100644 --- a/test/components/generators/test_openai.py +++ b/test/components/generators/test_openai.py @@ -49,23 +49,6 @@ def test_init_with_parameters(self, monkeypatch): assert component.client.timeout == 40.0 assert component.client.max_retries == 1 - def test_init_with_parameters(self, monkeypatch): - monkeypatch.setenv("OPENAI_TIMEOUT", "100") - monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") - component = OpenAIGenerator( - api_key=Secret.from_token("test-api-key"), - model="gpt-4o-mini", - streaming_callback=print_streaming_chunk, - api_base_url="test-base-url", - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - assert component.client.api_key == "test-api-key" - assert component.model == "gpt-4o-mini" - assert component.streaming_callback is print_streaming_chunk - assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} - assert component.client.timeout == 100.0 - assert component.client.max_retries == 10 - def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = OpenAIGenerator() @@ -331,3 +314,22 @@ def 
__call__(self, chunk: StreamingChunk) -> None: assert callback.counter > 1 assert "Paris" in callback.responses + + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + def test_run_with_system_prompt(self): + generator = OpenAIGenerator( + model="gpt-4o-mini", + system_prompt="You answer in Portuguese, regardless of the language in which a question is asked", + ) + result = generator.run("Can you explain the Pythagoras theorem?") + assert "teorema" in result["replies"][0] + + result = generator.run( + "Can you explain the Pythagoras theorem?", + system_prompt="You answer in German, regardless of the language in which a question is asked.", + ) + assert "Pythagoras" in result["replies"][0]