Skip to content

Commit

Permalink
fix: code format
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonhp committed Nov 8, 2024
1 parent 667dfac commit 959212a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
6 changes: 3 additions & 3 deletions libs/community/langchain_community/chat_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@
from langchain_community.chat_models.naver import (
ChatClovaX,
)
from langchain_community.chat_models.novita import (
ChatNovita,
)
from langchain_community.chat_models.oci_data_science import (
ChatOCIModelDeployment,
ChatOCIModelDeploymentTGI,
Expand Down Expand Up @@ -190,9 +193,6 @@
from langchain_community.chat_models.zhipuai import (
ChatZhipuAI,
)
from langchain_community.chat_models.novita import (
ChatNovita,
)

__all__ = [
"AzureChatOpenAI",
Expand Down
21 changes: 13 additions & 8 deletions libs/community/langchain_community/chat_models/novita.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,37 @@
"""Wrapper around Novita chat models."""

from typing import Dict
from pydantic import Field, SecretStr

from langchain_core.utils import (
__init__,
convert_to_secret_str,
get_from_dict_or_env,
pre_init,
)
from pydantic import Field, SecretStr

from langchain_community.chat_models import ChatOpenAI

NOVITA_API_BASE = "https://api.novita.ai/v3/openai"


class ChatNovita(ChatOpenAI): # type: ignore[misc]
"""Novita AI LLM.
To use, you should have the ``openai`` python package installed, and the
environment variable ``NOVITA_API_KEY`` set with your API key.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatNovita
chat = ChatNovita(model="gryphe/mythomax-l2-13b")
"""

novita_api_key: SecretStr = Field(default=None, alias="api_key")
model_name: str = Field(default="gryphe/mythomax-l2-13b", alias="model")

@pre_init
@__init__
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the environment is set up correctly."""
values["novita_api_key"] = convert_to_secret_str(
Expand All @@ -55,6 +58,8 @@ def validate_environment(cls, values: Dict) -> Dict:
if not values.get("client"):
values["client"] = openai.OpenAI(**client_params).chat.completions
if not values.get("async_client"):
values["async_client"] = openai.AsyncOpenAI(**client_params).chat.completions
values["async_client"] = openai.AsyncOpenAI(
**client_params
).chat.completions

return values
6 changes: 4 additions & 2 deletions libs/community/tests/unit_tests/chat_models/test_novita.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@

from langchain_community.chat_models import ChatNovita


@pytest.mark.requires("openai")
def test__missing_novita_api_key() -> None:
with pytest.raises(ValidationError) as e:
ChatNovita()
assert "Did not find novita_api_key" in str(e)


@pytest.mark.requires("openai")
def test__all_fields_provided() -> None:
chat = ChatNovita(
api_key="787ee3fd-ff97-4aac-936e-1b09cf74a559",
api_key="your_api_key",
model="gryphe/mythomax-l2-13b",
)
assert chat.novita_api_key.get_secret_value() == "787ee3fd-ff97-4aac-936e-1b09cf74a559"
assert chat.novita_api_key.get_secret_value() == "your_api_key"
assert chat.model_name == "gryphe/mythomax-l2-13b"

0 comments on commit 959212a

Please sign in to comment.