Skip to content

Commit

Permalink
tests: add pytest flags (#38)
Browse files Browse the repository at this point in the history
* tests: add pytest flags
---------

Co-authored-by: Leonid Kuligin <[email protected]>
  • Loading branch information
svidiella and lkuligin authored Feb 28, 2024
1 parent 68ccc51 commit 75d0d8c
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 15 deletions.
2 changes: 1 addition & 1 deletion libs/vertexai/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ TEST_FILE ?= tests/unit_tests/
integration_test integration_tests: TEST_FILE = tests/integration_tests/

test tests integration_test integration_tests:
poetry run pytest $(TEST_FILE)
poetry run pytest --release $(TEST_FILE)


######################
Expand Down
64 changes: 64 additions & 0 deletions libs/vertexai/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
Tests configuration to be executed before tests execution.
"""

from typing import List

import pytest

_RELEASE_FLAG = "release"
_GPU_FLAG = "gpu"
_LONG_FLAG = "long"
_EXTENDED_FLAG = "extended"

_PYTEST_FLAGS = [_RELEASE_FLAG, _GPU_FLAG, _LONG_FLAG, _EXTENDED_FLAG]


def pytest_addoption(parser: pytest.Parser) -> None:
"""
Add flags accepted by pytest CLI.
Args:
parser: The pytest parser object.
Returns:
"""
for flag in _PYTEST_FLAGS:
parser.addoption(
f"--{flag}", action="store_true", default=False, help=f"run {flag} tests"
)


def pytest_configure(config: pytest.Config) -> None:
"""
Add pytest custom configuration.
Args:
config: The pytest config object.
Returns:
"""
for flag in _PYTEST_FLAGS:
config.addinivalue_line(
"markers", f"{flag}: mark test to run as {flag} only test"
)


def pytest_collection_modifyitems(
config: pytest.Config, items: List[pytest.Item]
) -> None:
"""
Skip tests with a marker from our list that were not explicitly invoked.
Args:
config: The pytest config object.
items: The list of tests to be executed.
Returns:
"""
for item in items:
keywords = list(set(item.keywords).intersection(_PYTEST_FLAGS))
if keywords and not any((config.getoption(f"--{kw}") for kw in keywords)):
skip = pytest.mark.skip(reason=f"need --{keywords[0]} option to run")
item.add_marker(skip)
5 changes: 5 additions & 0 deletions libs/vertexai/tests/integration_tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain_google_vertexai.llms import VertexAI


@pytest.mark.release
@pytest.mark.parametrize(
"model_name",
["gemini-pro", "text-bison@001", "code-bison@001"],
Expand All @@ -25,6 +26,7 @@ def test_llm_invoke(model_name: str) -> None:
assert vb.completion_tokens > completion_tokens


@pytest.mark.release
@pytest.mark.parametrize(
"model_name",
["gemini-pro", "chat-bison@001", "codechat-bison@001"],
Expand All @@ -45,6 +47,7 @@ def test_chat_call(model_name: str) -> None:
assert vb.completion_tokens > completion_tokens


@pytest.mark.release
@pytest.mark.parametrize(
"model_name",
["gemini-pro", "text-bison@001", "code-bison@001"],
Expand All @@ -64,6 +67,7 @@ def test_invoke_config(model_name: str) -> None:
assert vb.completion_tokens > completion_tokens


@pytest.mark.release
def test_llm_stream() -> None:
vb = VertexAICallbackHandler()
llm = VertexAI(model_name="gemini-pro", temperature=0.0, callbacks=[vb])
Expand All @@ -74,6 +78,7 @@ def test_llm_stream() -> None:
assert vb.completion_tokens > 0


@pytest.mark.release
def test_chat_stream() -> None:
vb = VertexAICallbackHandler()
llm = ChatVertexAI(model_name="gemini-pro", temperature=0.0, callbacks=[vb])
Expand Down
16 changes: 16 additions & 0 deletions libs/vertexai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]


@pytest.mark.release
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_initialization(model_name: Optional[str]) -> None:
"""Test chat model initialization."""
Expand All @@ -32,6 +33,7 @@ def test_initialization(model_name: Optional[str]) -> None:
assert model.model_name == model.client._model_name.split("/")[-1]


@pytest.mark.release
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_vertexai_single_call(model_name: Optional[str]) -> None:
if model_name:
Expand All @@ -44,6 +46,7 @@ def test_vertexai_single_call(model_name: Optional[str]) -> None:
assert isinstance(response.content, str)


@pytest.mark.release
# mark xfail because Vertex API randomly doesn't respect
# the n/candidate_count parameter
@pytest.mark.xfail
Expand All @@ -56,6 +59,7 @@ def test_candidates() -> None:
assert len(response.generations[0]) == 2


@pytest.mark.release
@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"])
async def test_vertexai_agenerate(model_name: str) -> None:
model = ChatVertexAI(temperature=0, model_name=model_name)
Expand All @@ -76,6 +80,7 @@ async def test_vertexai_agenerate(model_name: str) -> None:
assert int(usage_metadata["candidates_token_count"]) > 0


@pytest.mark.release
@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"])
def test_vertexai_stream(model_name: str) -> None:
model = ChatVertexAI(temperature=0, model_name=model_name)
Expand All @@ -86,6 +91,7 @@ def test_vertexai_stream(model_name: str) -> None:
assert isinstance(chunk, AIMessageChunk)


@pytest.mark.release
async def test_vertexai_astream() -> None:
model = ChatVertexAI(temperature=0, model_name="gemini-pro")
message = HumanMessage(content="Hello")
Expand All @@ -94,6 +100,7 @@ async def test_vertexai_astream() -> None:
assert isinstance(chunk, AIMessageChunk)


@pytest.mark.release
def test_vertexai_single_call_with_context() -> None:
model = ChatVertexAI()
raw_context = (
Expand All @@ -110,6 +117,7 @@ def test_vertexai_single_call_with_context() -> None:
assert isinstance(response.content, str)


@pytest.mark.release
def test_multimodal() -> None:
llm = ChatVertexAI(model_name="gemini-pro-vision")
gcs_url = (
Expand All @@ -129,6 +137,7 @@ def test_multimodal() -> None:
assert isinstance(output.content, str)


@pytest.mark.release
@pytest.mark.xfail(reason="problem on vertex side")
def test_multimodal_history() -> None:
llm = ChatVertexAI(model_name="gemini-pro-vision")
Expand Down Expand Up @@ -158,6 +167,7 @@ def test_multimodal_history() -> None:
assert isinstance(response.content, str)


@pytest.mark.release
def test_vertexai_single_call_with_examples() -> None:
model = ChatVertexAI()
raw_context = "My name is Peter. You are my personal assistant."
Expand All @@ -172,6 +182,7 @@ def test_vertexai_single_call_with_examples() -> None:
assert isinstance(response.content, str)


@pytest.mark.release
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_vertexai_single_call_with_history(model_name: Optional[str]) -> None:
if model_name:
Expand All @@ -188,6 +199,7 @@ def test_vertexai_single_call_with_history(model_name: Optional[str]) -> None:
assert isinstance(response.content, str)


@pytest.mark.release
def test_vertexai_single_call_fails_no_message() -> None:
chat = ChatVertexAI()
with pytest.raises(ValueError) as exc_info:
Expand All @@ -198,6 +210,7 @@ def test_vertexai_single_call_fails_no_message() -> None:
)


@pytest.mark.release
@pytest.mark.parametrize("model_name", ["gemini-pro"])
def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None:
model = ChatVertexAI(model_name=model_name)
Expand All @@ -211,6 +224,7 @@ def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None:
model([system_message, message1, message2, message3])


@pytest.mark.release
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
if model_name:
Expand All @@ -231,6 +245,7 @@ def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
assert isinstance(response.content, str)


@pytest.mark.release
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_get_num_tokens_from_messages(model_name: str) -> None:
if model_name:
Expand All @@ -243,6 +258,7 @@ def test_get_num_tokens_from_messages(model_name: str) -> None:
assert token == 3


@pytest.mark.release
def test_chat_vertexai_gemini_function_calling() -> None:
class MyModel(BaseModel):
name: str
Expand Down
6 changes: 6 additions & 0 deletions libs/vertexai/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from langchain_google_vertexai.embeddings import VertexAIEmbeddings


@pytest.mark.release
def test_initialization() -> None:
"""Test embedding model initialization."""
VertexAIEmbeddings()


@pytest.mark.release
def test_langchain_google_vertexai_embedding_documents() -> None:
documents = ["foo bar"]
model = VertexAIEmbeddings()
Expand All @@ -23,13 +25,15 @@ def test_langchain_google_vertexai_embedding_documents() -> None:
assert model.model_name == "textembedding-gecko@001"


@pytest.mark.release
def test_langchain_google_vertexai_embedding_query() -> None:
document = "foo bar"
model = VertexAIEmbeddings()
output = model.embed_query(document)
assert len(output) == 768


@pytest.mark.release
def test_langchain_google_vertexai_large_batches() -> None:
documents = ["foo bar" for _ in range(0, 251)]
model_uscentral1 = VertexAIEmbeddings(location="us-central1")
Expand All @@ -40,6 +44,7 @@ def test_langchain_google_vertexai_large_batches() -> None:
assert model_asianortheast1.instance["batch_size"] < 50


@pytest.mark.release
def test_langchain_google_vertexai_paginated_texts() -> None:
documents = [
"foo bar",
Expand All @@ -58,6 +63,7 @@ def test_langchain_google_vertexai_paginated_texts() -> None:
assert model.model_name == model.client._model_id


@pytest.mark.release
def test_warning(caplog: pytest.LogCaptureFixture) -> None:
_ = VertexAIEmbeddings()
assert len(caplog.records) == 1
Expand Down
8 changes: 4 additions & 4 deletions libs/vertexai/tests/integration_tests/test_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)


@pytest.mark.skip("CI testing not set up")
@pytest.mark.extended
def test_gemma_model_garden() -> None:
"""In order to run this test, you should provide endpoint names.
Expand All @@ -36,7 +36,7 @@ def test_gemma_model_garden() -> None:
assert llm._llm_type == "gemma_vertexai_model_garden"


@pytest.mark.skip("CI testing not set up")
@pytest.mark.extended
def test_gemma_chat_model_garden() -> None:
"""In order to run this test, you should provide endpoint names.
Expand Down Expand Up @@ -67,7 +67,7 @@ def test_gemma_chat_model_garden() -> None:
assert len(output.content) > 2


@pytest.mark.skip("CI testing not set up")
@pytest.mark.gpu
def test_gemma_kaggle() -> None:
llm = GemmaLocalKaggle(model_name="gemma_2b_en")
output = llm.invoke("What is the meaning of life?")
Expand All @@ -76,7 +76,7 @@ def test_gemma_kaggle() -> None:
assert len(output) > 2


@pytest.mark.skip("CI testing not set up")
@pytest.mark.gpu
def test_gemma_chat_kaggle() -> None:
llm = GemmaChatLocalKaggle(model_name="gemma_2b_en")
text_question1, text_answer1 = "How much is 2+2?", "4"
Expand Down
Loading

0 comments on commit 75d0d8c

Please sign in to comment.