Commit a318ba0

cr

baskaryan committed Oct 26, 2023
1 parent 5846c8e
Showing 3 changed files with 35 additions and 5 deletions.
5 changes: 0 additions & 5 deletions .github/workflows/scheduled_test.yml
@@ -72,11 +72,6 @@ jobs:
           FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
         run: |
           make scheduled_tests
-          if [[ ${{ matrix.python-version }} != "3.8" ]]
-          then
-            poetry run pytest tests/integration_tests/llms/test_fireworks.py
-            poetry run pytest tests/integration_tests/chat_models/test_fireworks.py
-          fi
       - name: Ensure the tests did not create any additional files
         shell: bash
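
The five deleted lines gated the two Fireworks test files on the Python version from inside CI. The commit moves that gating into the test modules themselves via pytest's module-level skip, so a single `make scheduled_tests` invocation can collect everything. A minimal, self-contained sketch of the pattern; the placeholder test is illustrative, not part of the commit:

import sys

import pytest

# Runs at import time: on Python 3.8 this raises pytest's Skipped exception
# and the whole module is reported as skipped, instead of failing on the
# unimportable fireworks-ai dependency.
if sys.version_info < (3, 9):
    pytest.skip("requires Python >= 3.9", allow_module_level=True)


def test_placeholder() -> None:
    # Skipped entirely on Python 3.8; runs normally on 3.9 and later.
    assert True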
18 changes: 18 additions & 0 deletions libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
@@ -1,17 +1,22 @@
 """Test ChatFireworks wrapper."""
+import sys
 
 import pytest
 
 from langchain.chat_models.fireworks import ChatFireworks
 from langchain.schema import ChatGeneration, ChatResult, LLMResult
 from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
 
+if sys.version_info < (3, 9):
+    pytest.skip("fireworks-ai requires Python > 3.8", allow_module_level=True)
+
 
 @pytest.fixture
 def chat() -> ChatFireworks:
     return ChatFireworks(model_kwargs={"temperature": 0, "max_tokens": 512})
 
 
+@pytest.mark.scheduled
 def test_chat_fireworks(chat: ChatFireworks) -> None:
     """Test ChatFireworks wrapper."""
     message = HumanMessage(content="What is the weather in Redwood City, CA today")
@@ -20,12 +25,14 @@ def test_chat_fireworks(chat: ChatFireworks) -> None:
     assert isinstance(response.content, str)
 
 
+@pytest.mark.scheduled
 def test_chat_fireworks_model() -> None:
     """Test ChatFireworks wrapper handles model_name."""
     chat = ChatFireworks(model="foo")
     assert chat.model == "foo"
 
 
+@pytest.mark.scheduled
 def test_chat_fireworks_system_message(chat: ChatFireworks) -> None:
     """Test ChatFireworks wrapper with system message."""
     system_message = SystemMessage(content="You are to chat with the user.")
@@ -35,6 +42,7 @@ def test_chat_fireworks_system_message(chat: ChatFireworks) -> None:
     assert isinstance(response.content, str)
 
 
+@pytest.mark.scheduled
 def test_chat_fireworks_generate() -> None:
     """Test ChatFireworks wrapper with generate."""
     chat = ChatFireworks(model_kwargs={"n": 2})
@@ -50,6 +58,7 @@ def test_chat_fireworks_generate() -> None:
             assert generation.text == generation.message.content
 
 
+@pytest.mark.scheduled
 def test_chat_fireworks_multiple_completions() -> None:
     """Test ChatFireworks wrapper with multiple completions."""
     chat = ChatFireworks(model_kwargs={"n": 5})
@@ -62,6 +71,7 @@ def test_chat_fireworks_multiple_completions() -> None:
         assert isinstance(generation.message.content, str)
 
 
+@pytest.mark.scheduled
 def test_chat_fireworks_llm_output_contains_model_id(chat: ChatFireworks) -> None:
     """Test llm_output contains model_id."""
     message = HumanMessage(content="Hello")
@@ -70,13 +80,15 @@ def test_chat_fireworks_llm_output_contains_model_id(chat: ChatFireworks) -> None:
     assert llm_result.llm_output["model"] == chat.model
 
 
+@pytest.mark.scheduled
 def test_fireworks_invoke(chat: ChatFireworks) -> None:
     """Tests chat completion with invoke"""
     result = chat.invoke("How is the weather in New York today?", stop=[","])
     assert isinstance(result.content, str)
     assert result.content[-1] == ","
 
 
+@pytest.mark.scheduled
 @pytest.mark.asyncio
 async def test_fireworks_ainvoke(chat: ChatFireworks) -> None:
     """Tests chat completion with invoke"""
@@ -85,6 +97,7 @@ async def test_fireworks_ainvoke(chat: ChatFireworks) -> None:
     assert result.content[-1] == ","
 
 
+@pytest.mark.scheduled
 def test_fireworks_batch(chat: ChatFireworks) -> None:
     """Test batch tokens from ChatFireworks."""
     result = chat.batch(
@@ -104,6 +117,7 @@ def test_fireworks_batch(chat: ChatFireworks) -> None:
         assert token.content[-1] == ","
 
 
+@pytest.mark.scheduled
 @pytest.mark.asyncio
 async def test_fireworks_abatch(chat: ChatFireworks) -> None:
     """Test batch tokens from ChatFireworks."""
@@ -124,13 +138,15 @@ async def test_fireworks_abatch(chat: ChatFireworks) -> None:
         assert token.content[-1] == ","
 
 
+@pytest.mark.scheduled
 def test_fireworks_streaming(chat: ChatFireworks) -> None:
     """Test streaming tokens from Fireworks."""
 
     for token in chat.stream("I'm Pickle Rick"):
         assert isinstance(token.content, str)
 
 
+@pytest.mark.scheduled
 def test_fireworks_streaming_stop_words(chat: ChatFireworks) -> None:
     """Test streaming tokens with stop words."""
 
@@ -141,6 +157,7 @@ def test_fireworks_streaming_stop_words(chat: ChatFireworks) -> None:
     assert last_token[-1] == ","
 
 
+@pytest.mark.scheduled
 @pytest.mark.asyncio
 async def test_chat_fireworks_agenerate() -> None:
     """Test ChatFireworks wrapper with generate."""
@@ -157,6 +174,7 @@ async def test_chat_fireworks_agenerate() -> None:
             assert generation.text == generation.message.content
 
 
+@pytest.mark.scheduled
 @pytest.mark.asyncio
 async def test_fireworks_astream(chat: ChatFireworks) -> None:
     """Test streaming tokens from Fireworks."""
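
Every test above now carries @pytest.mark.scheduled. pytest warns about, and under --strict-markers rejects, unregistered markers, so the marker presumably has to be registered somewhere in the repo. A hypothetical conftest.py sketch of such a registration; the repo's actual configuration is not part of this diff and may differ:

# conftest.py (hypothetical sketch)
def pytest_configure(config):
    # Register the custom marker so `pytest -m scheduled` can select these
    # tests and strict-marker runs accept it as known.
    config.addinivalue_line("markers", "scheduled: mark tests to run on a schedule")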
17 changes: 17 additions & 0 deletions libs/langchain/tests/integration_tests/llms/test_fireworks.py
@@ -1,4 +1,5 @@
 """Test Fireworks AI API Wrapper."""
+import sys
 from typing import Generator
 
 import pytest
@@ -12,18 +13,23 @@
 )
 from langchain.schema import LLMResult
 
+if sys.version_info < (3, 9):
+    pytest.skip("fireworks-ai requires Python > 3.8", allow_module_level=True)
+
 
 @pytest.fixture
 def llm() -> Fireworks:
     return Fireworks(model_kwargs={"temperature": 0, "max_tokens": 512})
 
 
+@pytest.mark.scheduled
 def test_fireworks_call(llm: Fireworks) -> None:
     """Test valid call to fireworks."""
     output = llm("How is the weather in New York today?")
     assert isinstance(output, str)
 
 
+@pytest.mark.scheduled
 def test_fireworks_in_chain() -> None:
     """Tests fireworks AI in a Langchain chain"""
     human_message_prompt = HumanMessagePromptTemplate(
@@ -39,19 +45,22 @@ def test_fireworks_in_chain() -> None:
     assert isinstance(output, str)
 
 
+@pytest.mark.scheduled
 def test_fireworks_model_param() -> None:
     """Tests model parameters for Fireworks"""
     llm = Fireworks(model="foo")
     assert llm.model == "foo"
 
 
+@pytest.mark.scheduled
 def test_fireworks_invoke(llm: Fireworks) -> None:
     """Tests completion with invoke"""
     output = llm.invoke("How is the weather in New York today?", stop=[","])
     assert isinstance(output, str)
     assert output[-1] == ","
 
 
+@pytest.mark.scheduled
 @pytest.mark.asyncio
 async def test_fireworks_ainvoke(llm: Fireworks) -> None:
     """Tests completion with invoke"""
@@ -60,6 +69,7 @@ async def test_fireworks_ainvoke(llm: Fireworks) -> None:
     assert output[-1] == ","
 
 
+@pytest.mark.scheduled
 def test_fireworks_batch(llm: Fireworks) -> None:
     """Tests completion with invoke"""
     llm = Fireworks()
@@ -78,6 +88,7 @@ def test_fireworks_batch(llm: Fireworks) -> None:
         assert token[-1] == ","
 
 
+@pytest.mark.scheduled
 @pytest.mark.asyncio
 async def test_fireworks_abatch(llm: Fireworks) -> None:
     """Tests completion with invoke"""
@@ -96,6 +107,7 @@ async def test_fireworks_abatch(llm: Fireworks) -> None:
         assert token[-1] == ","
 
 
+@pytest.mark.scheduled
 def test_fireworks_multiple_prompts(
     llm: Fireworks,
 ) -> None:
@@ -106,6 +118,7 @@ def test_fireworks_multiple_prompts(
     assert len(output.generations) == 2
 
 
+@pytest.mark.scheduled
 def test_fireworks_streaming(llm: Fireworks) -> None:
     """Test stream completion."""
     generator = llm.stream("Who's the best quarterback in the NFL?")
@@ -115,6 +128,7 @@ def test_fireworks_streaming(llm: Fireworks) -> None:
         assert isinstance(token, str)
 
 
+@pytest.mark.scheduled
 def test_fireworks_streaming_stop_words(llm: Fireworks) -> None:
     """Test stream completion with stop words."""
     generator = llm.stream("Who's the best quarterback in the NFL?", stop=[","])
@@ -127,6 +141,7 @@ def test_fireworks_streaming_stop_words(llm: Fireworks) -> None:
     assert last_token[-1] == ","
 
 
+@pytest.mark.scheduled
 @pytest.mark.asyncio
 async def test_fireworks_streaming_async(llm: Fireworks) -> None:
     """Test stream completion."""
@@ -140,13 +155,15 @@ async def test_fireworks_streaming_async(llm: Fireworks) -> None:
     assert last_token[-1] == ","
 
 
+@pytest.mark.scheduled
 @pytest.mark.asyncio
 async def test_fireworks_async_agenerate(llm: Fireworks) -> None:
     """Test async."""
     output = await llm.agenerate(["What is the best city to live in California?"])
     assert isinstance(output, LLMResult)
 
 
+@pytest.mark.scheduled
 @pytest.mark.asyncio
 async def test_fireworks_multiple_prompts_async_agenerate(llm: Fireworks) -> None:
     output = await llm.agenerate(
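
With the marker applied consistently, the scheduled suite can be selected by marker expression instead of by listing test files, which is what lets the workflow above shrink to a bare make scheduled_tests. A hedged sketch of an equivalent invocation; the helper name, paths, and the actual make recipe are assumptions, not shown in this commit:

# run_scheduled.py (hypothetical helper, equivalent to `pytest -m scheduled`)
import sys

import pytest

if __name__ == "__main__":
    # Select only tests marked @pytest.mark.scheduled; modules that skip
    # themselves at import time (e.g. on Python 3.8) show up as skips.
    sys.exit(pytest.main(["-m", "scheduled", "tests/integration_tests"]))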
