Unit tests for the system instruction override
leila-messallem committed Dec 19, 2024
1 parent 23a8001 commit 58ce6af
Showing 6 changed files with 339 additions and 4 deletions.
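
All five provider test suites pin down the same contract: a system_instruction passed to the LLM constructor is the default for every call, while an instruction passed directly to invoke overrides it for that call only, and subsequent calls fall back to the default. A minimal sketch of that resolution rule, assuming the override arrives as an optional per-call argument as in the tests below (resolve_system_instruction is a hypothetical helper, not library code):

    from typing import Optional

    def resolve_system_instruction(
        default: Optional[str], override: Optional[str]
    ) -> Optional[str]:
        # A per-call override wins when supplied; otherwise fall back to
        # the instruction fixed at construction time (which may be None).
        return override if override is not None else default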
65 changes: 62 additions & 3 deletions tests/unit/llm/test_anthropic_llm.py
@@ -22,6 +22,7 @@
import pytest
from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM
from neo4j_graphrag.llm.types import LLMResponse


@pytest.fixture
@@ -61,11 +62,9 @@ def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock)
content="generated text"
)
model_params = {"temperature": 0.3}
system_instruction = "You are a helpful assistant."
llm = AnthropicLLM(
"claude-3-opus-20240229",
model_params=model_params,
-system_instruction=system_instruction,
)
message_history = [
{"role": "user", "content": "When does the sun come up in the summer?"},
@@ -79,11 +78,71 @@ def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock)
llm.client.messages.create.assert_called_once_with( # type: ignore[attr-defined]
messages=message_history,
model="claude-3-opus-20240229",
-system=system_instruction,
+system=None,
**model_params,
)


def test_anthropic_invoke_with_message_history_and_system_instruction(
mock_anthropic: Mock,
) -> None:
mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock(
content="generated text"
)
model_params = {"temperature": 0.3}
initial_instruction = "You are a helpful assistant."
llm = AnthropicLLM(
"claude-3-opus-20240229",
model_params=model_params,
system_instruction=initial_instruction,
)
message_history = [
{"role": "user", "content": "When does the sun come up in the summer?"},
{"role": "assistant", "content": "Usually around 6am."},
]
question = "What about next season?"

# first invocation - initial instructions
response = llm.invoke(question, message_history) # type: ignore
assert response.content == "generated text"
message_history.append({"role": "user", "content": question})
llm.client.messages.create.assert_called_once_with( # type: ignore[attr-defined]
model="claude-3-opus-20240229",
system=initial_instruction,
messages=message_history,
**model_params,
)

# second invocation - override instructions
override_instruction = "Ignore all previous instructions"
question = "When does it come up in the winter?"
response = llm.invoke(question, message_history, override_instruction) # type: ignore
assert isinstance(response, LLMResponse)
assert response.content == "generated text"
message_history.append({"role": "user", "content": question})
llm.client.messages.create.assert_called_with( # type: ignore[attr-defined]
model="claude-3-opus-20240229",
system=override_instruction,
messages=message_history,
**model_params,
)

# third invocation - default instructions
question = "When does it set?"
response = llm.invoke(question, message_history) # type: ignore
assert isinstance(response, LLMResponse)
assert response.content == "generated text"
message_history.append({"role": "user", "content": question})
llm.client.messages.create.assert_called_with( # type: ignore[attr-defined]
model="claude-3-opus-20240229",
system=initial_instruction,
messages=message_history,
**model_params,
)

assert llm.client.messages.create.call_count == 3 # type: ignore


def test_anthropic_invoke_with_message_history_validation_error(
mock_anthropic: Mock,
) -> None:
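
Note a provider difference the tests encode: the Anthropic client receives the instruction as the top-level system keyword of messages.create, while the Cohere, MistralAI, Ollama, and OpenAI tests below expect it prepended to the message list as a {"role": "system", ...} entry. A sketch of that message-building step, mirroring what the assertions check (build_messages is a hypothetical helper):

    from typing import Optional

    def build_messages(
        question: str,
        message_history: Optional[list[dict[str, str]]],
        system_instruction: Optional[str],
    ) -> list[dict[str, str]]:
        messages: list[dict[str, str]] = []
        if system_instruction is not None:
            # Chat-style providers carry the instruction as the first message.
            messages.append({"role": "system", "content": system_instruction})
        messages.extend(message_history or [])
        messages.append({"role": "user", "content": question})
        return messages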
55 changes: 55 additions & 0 deletions tests/unit/llm/test_cohere_llm.py
@@ -72,6 +72,61 @@ def test_cohere_llm_invoke_with_message_history_happy_path(mock_cohere: Mock) ->
)


def test_cohere_llm_invoke_with_message_history_and_system_instruction(
mock_cohere: Mock,
) -> None:
chat_response_mock = MagicMock()
chat_response_mock.message.content = [MagicMock(text="cohere response text")]
mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock

initial_instruction = "You are a helpful assistant."
llm = CohereLLM(model_name="gpt", system_instruction=initial_instruction)
message_history = [
{"role": "user", "content": "When does the sun come up in the summer?"},
{"role": "assistant", "content": "Usually around 6am."},
]
question = "What about next season?"

# first invocation - initial instructions
res = llm.invoke(question, message_history) # type: ignore
assert isinstance(res, LLMResponse)
assert res.content == "cohere response text"
messages = [{"role": "system", "content": initial_instruction}]
messages.extend(message_history)
messages.append({"role": "user", "content": question})
llm.client.chat.assert_called_once_with(
messages=messages,
model="gpt",
)

# second invocation - override instructions
override_instruction = "Ignore all previous instructions"
res = llm.invoke(question, message_history, override_instruction) # type: ignore
assert isinstance(res, LLMResponse)
assert res.content == "cohere response text"
messages = [{"role": "system", "content": override_instruction}]
messages.extend(message_history)
messages.append({"role": "user", "content": question})
llm.client.chat.assert_called_with(
messages=messages,
model="gpt",
)

# third invocation - default instructions
res = llm.invoke(question, message_history) # type: ignore
assert isinstance(res, LLMResponse)
assert res.content == "cohere response text"
messages = [{"role": "system", "content": initial_instruction}]
messages.extend(message_history)
messages.append({"role": "user", "content": question})
llm.client.chat.assert_called_with(
messages=messages,
model="gpt",
)

assert llm.client.chat.call_count == 3


def test_cohere_llm_invoke_with_message_history_validation_error(
mock_cohere: Mock,
) -> None:
59 changes: 59 additions & 0 deletions tests/unit/llm/test_mistralai_llm.py
@@ -77,6 +77,65 @@ def test_mistralai_llm_invoke_with_message_history(mock_mistral: Mock) -> None:
)


@patch("neo4j_graphrag.llm.mistralai_llm.Mistral")
def test_mistralai_llm_invoke_with_message_history_and_system_instruction(
mock_mistral: Mock,
) -> None:
mock_mistral_instance = mock_mistral.return_value
chat_response_mock = MagicMock()
chat_response_mock.choices = [
MagicMock(message=MagicMock(content="mistral response"))
]
mock_mistral_instance.chat.complete.return_value = chat_response_mock
model = "mistral-model"
initial_instruction = "You are a helpful assistant."
llm = MistralAILLM(model_name=model, system_instruction=initial_instruction)
message_history = [
{"role": "user", "content": "When does the sun come up in the summer?"},
{"role": "assistant", "content": "Usually around 6am."},
]
question = "What about next season?"

# first invocation - initial instructions
res = llm.invoke(question, message_history) # type: ignore
assert isinstance(res, LLMResponse)
assert res.content == "mistral response"
messages = [{"role": "system", "content": initial_instruction}]
messages.extend(message_history)
messages.append({"role": "user", "content": question})
llm.client.chat.complete.assert_called_once_with( # type: ignore[attr-defined]
messages=messages,
model=model,
)

# second invocation - override instructions
override_instruction = "Ignore all previous instructions"
res = llm.invoke(question, message_history, override_instruction) # type: ignore
assert isinstance(res, LLMResponse)
assert res.content == "mistral response"
messages = [{"role": "system", "content": override_instruction}]
messages.extend(message_history)
messages.append({"role": "user", "content": question})
llm.client.chat.complete.assert_called_with( # type: ignore
messages=messages,
model=model,
)

# third invocation - default instructions
res = llm.invoke(question, message_history) # type: ignore
assert isinstance(res, LLMResponse)
assert res.content == "mistral response"
messages = [{"role": "system", "content": initial_instruction}]
messages.extend(message_history)
messages.append({"role": "user", "content": question})
llm.client.chat.complete.assert_called_with( # type: ignore
messages=messages,
model=model,
)

assert llm.client.chat.complete.call_count == 3 # type: ignore


@patch("neo4j_graphrag.llm.mistralai_llm.Mistral")
def test_mistralai_llm_invoke_with_message_history_validation_error(
mock_mistral: Mock,
57 changes: 57 additions & 0 deletions tests/unit/llm/test_ollama_llm.py
@@ -94,6 +94,63 @@ def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> Non
)


@patch("builtins.__import__")
def test_ollama_invoke_with_message_history_and_system_instruction(
mock_import: Mock,
) -> None:
mock_ollama = get_mock_ollama()
mock_import.return_value = mock_ollama
mock_ollama.Client.return_value.chat.return_value = MagicMock(
message=MagicMock(content="ollama chat response"),
)
model = "gpt"
model_params = {"temperature": 0.3}
system_instruction = "You are a helpful assistant."
llm = OllamaLLM(
model,
model_params=model_params,
system_instruction=system_instruction,
)
message_history = [
{"role": "user", "content": "When does the sun come up in the summer?"},
{"role": "assistant", "content": "Usually around 6am."},
]
question = "What about next season?"

# first invocation - initial instructions
response = llm.invoke(question, message_history) # type: ignore
assert response.content == "ollama chat response"
messages = [{"role": "system", "content": system_instruction}]
messages.extend(message_history)
messages.append({"role": "user", "content": question})
llm.client.chat.assert_called_once_with( # type: ignore[attr-defined]
model=model, messages=messages, options=model_params
)

# second invocation - override instructions
override_instruction = "Ignore all previous instructions"
response = llm.invoke(question, message_history, override_instruction) # type: ignore
assert response.content == "ollama chat response"
messages = [{"role": "system", "content": override_instruction}]
messages.extend(message_history)
messages.append({"role": "user", "content": question})
llm.client.chat.assert_called_with( # type: ignore[attr-defined]
model=model, messages=messages, options=model_params
)

# third invocation - default instructions
response = llm.invoke(question, message_history) # type: ignore
assert response.content == "ollama chat response"
messages = [{"role": "system", "content": system_instruction}]
messages.extend(message_history)
messages.append({"role": "user", "content": question})
llm.client.chat.assert_called_with( # type: ignore[attr-defined]
model=model, messages=messages, options=model_params
)

assert llm.client.chat.call_count == 3 # type: ignore


@patch("builtins.__import__")
def test_ollama_invoke_with_message_history_validation_error(mock_import: Mock) -> None:
mock_ollama = get_mock_ollama()
69 changes: 69 additions & 0 deletions tests/unit/llm/test_openai_llm.py
@@ -64,6 +64,70 @@ def test_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None:
res = llm.invoke(question, message_history) # type: ignore
assert isinstance(res, LLMResponse)
assert res.content == "openai chat response"
message_history.append({"role": "user", "content": question})
llm.client.chat.completions.create.assert_called_once_with( # type: ignore
messages=message_history,
model="gpt",
)


@patch("builtins.__import__")
def test_openai_llm_with_message_history_and_system_instruction(
mock_import: Mock,
) -> None:
mock_openai = get_mock_openai()
mock_import.return_value = mock_openai
mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content="openai chat response"))],
)
initial_instruction = "You are a helpful assistent."
llm = OpenAILLM(
api_key="my key", model_name="gpt", system_instruction=initial_instruction
)
message_history = [
{"role": "user", "content": "When does the sun come up in the summer?"},
{"role": "assistant", "content": "Usually around 6am."},
]
question = "What about next season?"

# first invocation - initial instructions
res = llm.invoke(question, message_history) # type: ignore
assert isinstance(res, LLMResponse)
assert res.content == "openai chat response"
messages = [{"role": "system", "content": initial_instruction}]
messages.extend(message_history)
messages.append({"role": "user", "content": question})
llm.client.chat.completions.create.assert_called_once_with( # type: ignore
messages=messages,
model="gpt",
)

# second invocation - override instructions
override_instruction = "Ignore all previous instructions"
res = llm.invoke(question, message_history, override_instruction) # type: ignore
assert isinstance(res, LLMResponse)
assert res.content == "openai chat response"
messages = [{"role": "system", "content": override_instruction}]
messages.extend(message_history)
messages.append({"role": "user", "content": question})
llm.client.chat.completions.create.assert_called_with( # type: ignore
messages=messages,
model="gpt",
)

# third invocation - default instructions
res = llm.invoke(question, message_history) # type: ignore
assert isinstance(res, LLMResponse)
assert res.content == "openai chat response"
messages = [{"role": "system", "content": initial_instruction}]
messages.extend(message_history)
messages.append({"role": "user", "content": question})
llm.client.chat.completions.create.assert_called_with( # type: ignore
messages=messages,
model="gpt",
)

assert llm.client.chat.completions.create.call_count == 3 # type: ignore


@patch("builtins.__import__")
@@ -137,6 +201,11 @@ def test_azure_openai_llm_with_message_history_happy_path(mock_import: Mock) ->
res = llm.invoke(question, message_history) # type: ignore
assert isinstance(res, LLMResponse)
assert res.content == "openai chat response"
message_history.append({"role": "user", "content": question})
llm.client.chat.completions.create.assert_called_once_with( # type: ignore
messages=message_history,
model="gpt",
)


@patch("builtins.__import__")
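Taken together, the tests document the intended usage. A hedged example against the OpenAI wrapper (the import path and passing the override positionally are inferred from the tests above; this is a sketch, not documented API):

    from neo4j_graphrag.llm.openai_llm import OpenAILLM

    llm = OpenAILLM(
        api_key="my key",
        model_name="gpt",
        system_instruction="You are a helpful assistant.",  # default for every call
    )

    message_history = [
        {"role": "user", "content": "When does the sun come up in the summer?"},
        {"role": "assistant", "content": "Usually around 6am."},
    ]

    # First call: uses the constructor default.
    llm.invoke("What about next season?", message_history)

    # Second call: overrides the instruction for this invocation only.
    llm.invoke(
        "When does it come up in the winter?",
        message_history,
        "Answer in one short sentence.",
    )

    # Third call: falls back to the constructor default again.
    llm.invoke("When does it set?", message_history)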
