Skip to content

Commit

Permalink
Mask API Key for OpenAI based ChatModels (OpenAI,AzureOpenAi,Konko)
Browse files Browse the repository at this point in the history
  • Loading branch information
onesolpark committed Oct 30, 2023
1 parent b213850 commit 57d1e94
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 26 deletions.
18 changes: 10 additions & 8 deletions libs/langchain/langchain/chat_models/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from __future__ import annotations

import logging
from typing import Any, Dict, Mapping
from typing import Any, Dict, Mapping, Optional

from langchain.chat_models.openai import ChatOpenAI
from langchain.pydantic_v1 import root_validator
from langchain.pydantic_v1 import SecretStr, root_validator
from langchain.schema import ChatResult
from langchain.utils import get_from_dict_or_env
from langchain.utils import convert_to_secret_str, get_from_dict_or_env

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,17 +56,19 @@ class AzureChatOpenAI(ChatOpenAI):
openai_api_type: str = ""
openai_api_base: str = ""
openai_api_version: str = ""
openai_api_key: str = ""
openai_api_key: Optional[SecretStr] = None
openai_organization: str = ""
openai_proxy: str = ""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(
values,
"openai_api_key",
"OPENAI_API_KEY",
values["openai_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"openai_api_key",
"OPENAI_API_KEY",
)
)
values["openai_api_base"] = get_from_dict_or_env(
values,
Expand Down
24 changes: 12 additions & 12 deletions libs/langchain/langchain/chat_models/konko.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
)
from langchain.chat_models.base import _generate_from_stream
from langchain.chat_models.openai import ChatOpenAI, _convert_delta_to_message_chunk
from langchain.pydantic_v1 import Field, root_validator
from langchain.pydantic_v1 import Field, SecretStr, root_validator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import AIMessageChunk, BaseMessage
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env
from langchain.utils import convert_to_secret_str, get_from_dict_or_env

DEFAULT_API_BASE = "https://api.konko.ai/v1"
DEFAULT_MODEL = "meta-llama/Llama-2-13b-chat-hf"
Expand Down Expand Up @@ -67,8 +67,8 @@ def is_lc_serializable(cls) -> bool:
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[str] = None
konko_api_key: Optional[str] = None
openai_api_key: Optional[SecretStr] = None
konko_api_key: Optional[SecretStr] = None
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to Konko completion API."""
max_retries: int = 6
Expand All @@ -83,8 +83,8 @@ def is_lc_serializable(cls) -> bool:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["konko_api_key"] = get_from_dict_or_env(
values, "konko_api_key", "KONKO_API_KEY"
values["konko_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "konko_api_key", "KONKO_API_KEY")
)
try:
import konko
Expand Down Expand Up @@ -123,23 +123,23 @@ def _default_params(self) -> Dict[str, Any]:

@staticmethod
def get_available_models(
konko_api_key: Optional[str] = None,
openai_api_key: Optional[str] = None,
konko_api_key: Optional[SecretStr] = None,
openai_api_key: Optional[SecretStr] = None,
konko_api_base: str = DEFAULT_API_BASE,
) -> Set[str]:
"""Get available models from Konko API."""

# Try to retrieve the OpenAI API key if it's not passed as an argument
if not openai_api_key:
try:
openai_api_key = os.environ["OPENAI_API_KEY"]
openai_api_key = convert_to_secret_str(os.environ["OPENAI_API_KEY"])
except KeyError:
pass # It's okay if it's not set, we just won't use it

# Try to retrieve the Konko API key if it's not passed as an argument
if not konko_api_key:
try:
konko_api_key = os.environ["KONKO_API_KEY"]
konko_api_key = convert_to_secret_str(os.environ["KONKO_API_KEY"])
except KeyError:
raise ValueError(
"Konko API key must be passed as keyword argument or "
Expand All @@ -149,11 +149,11 @@ def get_available_models(
models_url = f"{konko_api_base}/models"

headers = {
"Authorization": f"Bearer {konko_api_key}",
"Authorization": f"Bearer {konko_api_key.get_secret_value()}",
}

if openai_api_key:
headers["X-OpenAI-Api-Key"] = openai_api_key
headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value()

models_response = requests.get(models_url, headers=headers)

Expand Down
18 changes: 12 additions & 6 deletions libs/langchain/langchain/chat_models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Tuple,
Type,
Union,
cast,
)

from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
Expand All @@ -29,7 +30,7 @@
_generate_from_stream,
)
from langchain.llms.base import create_base_retry_decorator
from langchain.pydantic_v1 import Field, root_validator
from langchain.pydantic_v1 import Field, SecretStr, root_validator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import (
AIMessageChunk,
Expand All @@ -41,7 +42,11 @@
SystemMessageChunk,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)

if TYPE_CHECKING:
import tiktoken
Expand Down Expand Up @@ -168,7 +173,7 @@ def is_lc_serializable(cls) -> bool:
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[str] = None
openai_api_key: Optional[SecretStr] = None
"""Base URL path for API requests,
leave blank if not using a proxy or service emulator."""
openai_api_base: Optional[str] = None
Expand Down Expand Up @@ -230,8 +235,8 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
values["openai_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "openai_api_key", "OPENAI_API_KEY")
)
values["openai_organization"] = get_from_dict_or_env(
values,
Expand Down Expand Up @@ -445,8 +450,9 @@ def _identifying_params(self) -> Dict[str, Any]:
@property
def _client_params(self) -> Dict[str, Any]:
"""Get the parameters used for the openai client."""
self.openai_api_key = cast(SecretStr, self.openai_api_key)
openai_creds: Dict[str, Any] = {
"api_key": self.openai_api_key,
"api_key": self.openai_api_key.get_secret_value(),
"api_base": self.openai_api_base,
"organization": self.openai_organization,
"model": self.model_name,
Expand Down
48 changes: 48 additions & 0 deletions libs/langchain/tests/unit_tests/chat_models/test_azureopenai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from unittest import mock

import pytest
from pytest import MonkeyPatch

from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.pydantic_v1 import SecretStr


@mock.patch.dict(
Expand Down Expand Up @@ -55,3 +57,49 @@ def test_model_name_set_on_chat_result_when_present_in_response(
chat_result.llm_output is not None
and chat_result.llm_output["model_name"] == model_name
)


@mock.patch.dict(
    os.environ,
    {
        "OPENAI_API_BASE": "https://oai.azure.com/",
        "OPENAI_API_VERSION": "2023-05-01",
    },
)
@pytest.mark.requires("openai")
def test_api_key_is_secret_string_and_matches_input() -> None:
    """A plain-text key given to the constructor is stored as a SecretStr
    whose unwrapped value is the original string."""
    chat_model = AzureChatOpenAI(openai_api_key="secret-api-key")
    stored_key = chat_model.openai_api_key
    assert isinstance(stored_key, SecretStr)
    assert stored_key.get_secret_value() == "secret-api-key"


@mock.patch.dict(
    os.environ,
    {
        "OPENAI_API_BASE": "https://oai.azure.com/",
        "OPENAI_API_VERSION": "2023-05-01",
    },
)
@pytest.mark.requires("openai")
def test_api_key_masked_when_passed_via_constructor() -> None:
    """Neither str/repr of the key nor repr of the model may leak the secret."""
    chat_model = AzureChatOpenAI(openai_api_key="secret-api-key")
    masked_key = chat_model.openai_api_key
    assert str(masked_key) == "**********"
    for rendered in (repr(masked_key), repr(chat_model)):
        assert "secret-api-key" not in rendered


@mock.patch.dict(
    os.environ,
    {
        "OPENAI_API_BASE": "https://oai.azure.com/",
        "OPENAI_API_VERSION": "2023-05-01",
    },
)
@pytest.mark.requires("openai")
def test_api_key_masked_when_passed_via_env() -> None:
    """A key picked up from OPENAI_API_KEY must be masked the same way."""
    with MonkeyPatch.context() as env_patch:
        env_patch.setenv("OPENAI_API_KEY", "secret-api-key")
        chat_model = AzureChatOpenAI()
        assert str(chat_model.openai_api_key) == "**********"
        for rendered in (repr(chat_model.openai_api_key), repr(chat_model)):
            assert "secret-api-key" not in rendered
39 changes: 39 additions & 0 deletions libs/langchain/tests/unit_tests/chat_models/test_konko.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
from pytest import MonkeyPatch

from langchain.chat_models.konko import ChatKonko
from langchain.pydantic_v1 import SecretStr


@pytest.mark.requires("konko")
def test_api_key_is_secret_string_and_matches_input() -> None:
    """Both API keys passed to the constructor become SecretStr instances
    that round-trip back to the original plain-text values."""
    chat_model = ChatKonko(
        openai_api_key="secret-openai-api-key", konko_api_key="secret-konko-api-key"
    )
    for stored_key, plain_text in (
        (chat_model.openai_api_key, "secret-openai-api-key"),
        (chat_model.konko_api_key, "secret-konko-api-key"),
    ):
        assert isinstance(stored_key, SecretStr)
        assert stored_key.get_secret_value() == plain_text


@pytest.mark.requires("konko")
def test_api_key_masked_when_passed_via_constructor() -> None:
    """str/repr of either key — and repr of the model itself — must not
    expose the raw secrets."""
    chat_model = ChatKonko(
        openai_api_key="secret-openai-api-key", konko_api_key="secret-konko-api-key"
    )
    for stored_key, plain_text in (
        (chat_model.openai_api_key, "secret-openai-api-key"),
        (chat_model.konko_api_key, "secret-konko-api-key"),
    ):
        assert str(stored_key) == "**********"
        assert plain_text not in repr(stored_key)
        assert plain_text not in repr(chat_model)


@pytest.mark.requires("konko")
def test_api_key_masked_when_passed_via_env() -> None:
    """A key sourced from KONKO_API_KEY must be masked in str/repr output."""
    with MonkeyPatch.context() as env_patch:
        env_patch.setenv("KONKO_API_KEY", "secret-konko-api-key")
        chat_model = ChatKonko()
        assert str(chat_model.konko_api_key) == "**********"
        for rendered in (repr(chat_model.konko_api_key), repr(chat_model)):
            assert "secret-konko-api-key" not in rendered
47 changes: 47 additions & 0 deletions libs/langchain/tests/unit_tests/chat_models/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
"""Test OpenAI Chat API wrapper."""
import json
import os
from typing import Any
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest
from pytest import MonkeyPatch

from langchain.adapters.openai import convert_dict_to_message
from langchain.chat_models.openai import ChatOpenAI
from langchain.pydantic_v1 import SecretStr
from langchain.schema.messages import (
AIMessage,
FunctionMessage,
Expand All @@ -15,6 +19,12 @@
)


@mock.patch.dict(
os.environ,
{
"OPENAI_API_KEY": "test",
},
)
@pytest.mark.requires("openai")
def test_openai_model_param() -> None:
llm = ChatOpenAI(model="foo")
Expand Down Expand Up @@ -79,6 +89,12 @@ def mock_completion() -> dict:
}


@mock.patch.dict(
os.environ,
{
"OPENAI_API_KEY": "test",
},
)
@pytest.mark.requires("openai")
def test_openai_predict(mock_completion: dict) -> None:
llm = ChatOpenAI()
Expand All @@ -101,6 +117,12 @@ def mock_create(*args: Any, **kwargs: Any) -> Any:
assert completed


@mock.patch.dict(
os.environ,
{
"OPENAI_API_KEY": "test",
},
)
@pytest.mark.requires("openai")
async def test_openai_apredict(mock_completion: dict) -> None:
llm = ChatOpenAI()
Expand All @@ -121,3 +143,28 @@ def mock_create(*args: Any, **kwargs: Any) -> Any:
res = llm.predict("bar")
assert res == "Bar Baz"
assert completed


@pytest.mark.requires("openai")
def test_api_key_is_secret_string_and_matches_input() -> None:
    """A plain-text key given to the constructor is stored as a SecretStr
    whose unwrapped value is the original string."""
    chat_model = ChatOpenAI(openai_api_key="secret-api-key")
    stored_key = chat_model.openai_api_key
    assert isinstance(stored_key, SecretStr)
    assert stored_key.get_secret_value() == "secret-api-key"


@pytest.mark.requires("openai")
def test_api_key_masked_when_passed_via_constructor() -> None:
    """Neither str/repr of the key nor repr of the model may leak the secret."""
    chat_model = ChatOpenAI(openai_api_key="secret-api-key")
    masked_key = chat_model.openai_api_key
    assert str(masked_key) == "**********"
    for rendered in (repr(masked_key), repr(chat_model)):
        assert "secret-api-key" not in rendered


@pytest.mark.requires("openai")
def test_api_key_masked_when_passed_via_env() -> None:
    """A key picked up from OPENAI_API_KEY must be masked the same way."""
    with MonkeyPatch.context() as env_patch:
        env_patch.setenv("OPENAI_API_KEY", "secret-api-key")
        chat_model = ChatOpenAI()
        assert str(chat_model.openai_api_key) == "**********"
        for rendered in (repr(chat_model.openai_api_key), repr(chat_model)):
            assert "secret-api-key" not in rendered

0 comments on commit 57d1e94

Please sign in to comment.