Skip to content

Commit

Permalink
IMPROVEMENT: support Openai API v1 for Azure OpenAI completions (lang…
Browse files Browse the repository at this point in the history
…chain-ai#13231)

Hi,
this PR adds support for OpenAI API v1 for Azure OpenAI completion API.
@baskaryan @hwchase17

---------

Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
mspronesti and baskaryan authored Nov 14, 2023
1 parent fc886cc commit 344cab0
Show file tree
Hide file tree
Showing 4 changed files with 337 additions and 14 deletions.
2 changes: 1 addition & 1 deletion cookbook/openai_v1_cookbook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@
"\n",
"\n",
"RECOMMENDED CHANGES:\n",
"- When using AzureChatOpenAI, if passing in an Azure endpoint (eg https://example-resource.azure.openai.com/) this should be specified via the `azure_endpoint` parameter or the `AZURE_OPENAI_ENDPOINT`. We're maintaining backwards compatibility for now with specifying this via `openai_api_base`/`base_url` or env var `OPENAI_API_BASE` but this shouldn't be relied upon.\n",
"- When using `AzureChatOpenAI` or `AzureOpenAI`, if passing in an Azure endpoint (eg https://example-resource.azure.openai.com/) this should be specified via the `azure_endpoint` parameter or the `AZURE_OPENAI_ENDPOINT`. We're maintaining backwards compatibility for now with specifying this via `openai_api_base`/`base_url` or env var `OPENAI_API_BASE` but this shouldn't be relied upon.\n",
"- When using Azure chat or embedding models, pass in API keys either via `openai_api_key` parameter or `AZURE_OPENAI_API_KEY` parameter. We're maintaining backwards compatibility for now with specifying this via `OPENAI_API_KEY` but this shouldn't be relied upon."
]
},
Expand Down
3 changes: 3 additions & 0 deletions libs/langchain/langchain/chat_models/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class AzureChatOpenAI(ChatOpenAI):
openai_api_type: str = ""
"""Legacy, for openai<1.0.0 support."""
validate_base_url: bool = True
"""For backwards compatibility. If legacy val openai_api_base is passed in, try to
infer if it is a base_url or azure_endpoint and update accordingly.
"""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
Expand Down
164 changes: 151 additions & 13 deletions libs/langchain/langchain/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def _stream_response_to_generation_chunk(
stream_response: Dict[str, Any],
) -> GenerationChunk:
"""Convert a stream response to a generation chunk."""
if not stream_response["choices"]:
return GenerationChunk(text="")
return GenerationChunk(
text=stream_response["choices"][0]["text"],
generation_info=dict(
Expand Down Expand Up @@ -746,21 +748,154 @@ class AzureOpenAI(BaseOpenAI):
openai = AzureOpenAI(model_name="text-davinci-003")
"""

deployment_name: str = ""
"""Deployment name to use."""
azure_endpoint: Union[str, None] = None
"""Your Azure endpoint, including the resource.
Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.
Example: `https://example-resource.azure.openai.com/`
"""
deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
"""A model deployment.
If given sets the base client URL to include `/deployments/{azure_deployment}`.
Note: this means you won't be able to use non-deployment endpoints.
"""
openai_api_version: str = Field(default="", alias="api_version")
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
openai_api_key: Union[str, None] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
azure_ad_token: Union[str, None] = None
"""Your Azure Active Directory token.
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
For more:
https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
""" # noqa: E501
azure_ad_token_provider: Union[str, None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every request.
"""
openai_api_type: str = ""
openai_api_version: str = ""
"""Legacy, for openai<1.0.0 support."""
validate_base_url: bool = True
"""For backwards compatibility. If legacy val openai_api_base is passed in, try to
infer if it is a base_url or azure_endpoint and update accordingly.
"""

@root_validator()
def validate_azure_settings(cls, values: Dict) -> Dict:
values["openai_api_version"] = get_from_dict_or_env(
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["streaming"] and values["n"] > 1:
raise ValueError("Cannot stream results when n > 1.")
if values["streaming"] and values["best_of"] > 1:
raise ValueError("Cannot stream results when best_of > 1.")

# Check OPENAI_KEY for backwards compatibility.
# TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
# other forms of azure credentials.
values["openai_api_key"] = (
values["openai_api_key"]
or os.getenv("AZURE_OPENAI_API_KEY")
or os.getenv("OPENAI_API_KEY")
)

values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
"AZURE_OPENAI_ENDPOINT"
)
values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
"AZURE_OPENAI_AD_TOKEN"
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
)
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_api_version",
"OPENAI_API_VERSION",
"openai_proxy",
"OPENAI_PROXY",
default="",
)
values["openai_organization"] = (
values["openai_organization"]
or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION")
)
values["openai_api_version"] = values["openai_api_version"] or os.getenv(
"OPENAI_API_VERSION"
)
values["openai_api_type"] = get_from_dict_or_env(
values, "openai_api_type", "OPENAI_API_TYPE", "azure"
values, "openai_api_type", "OPENAI_API_TYPE", default="azure"
)
try:
import openai
except ImportError:
raise ImportError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
if is_openai_v1():
# For backwards compatibility. Before openai v1, no distinction was made
# between azure_endpoint and base_url (openai_api_base).
openai_api_base = values["openai_api_base"]
if openai_api_base and values["validate_base_url"]:
if "/openai" not in openai_api_base:
values["openai_api_base"] = (
values["openai_api_base"].rstrip("/") + "/openai"
)
warnings.warn(
"As of openai>=1.0.0, Azure endpoints should be specified via "
f"the `azure_endpoint` param not `openai_api_base` "
f"(or alias `base_url`). Updating `openai_api_base` from "
f"{openai_api_base} to {values['openai_api_base']}."
)
if values["deployment_name"]:
warnings.warn(
"As of openai>=1.0.0, if `deployment_name` (or alias "
"`azure_deployment`) is specified then "
"`openai_api_base` (or alias `base_url`) should not be. "
"Instead use `deployment_name` (or alias `azure_deployment`) "
"and `azure_endpoint`."
)
if values["deployment_name"] not in values["openai_api_base"]:
warnings.warn(
"As of openai>=1.0.0, if `openai_api_base` "
"(or alias `base_url`) is specified it is expected to be "
"of the form "
"https://example-resource.azure.openai.com/openai/deployments/example-deployment. " # noqa: E501
f"Updating {openai_api_base} to "
f"{values['openai_api_base']}."
)
values["openai_api_base"] += (
"/deployments/" + values["deployment_name"]
)
values["deployment_name"] = None
client_params = {
"api_version": values["openai_api_version"],
"azure_endpoint": values["azure_endpoint"],
"azure_deployment": values["deployment_name"],
"api_key": values["openai_api_key"],
"azure_ad_token": values["azure_ad_token"],
"azure_ad_token_provider": values["azure_ad_token_provider"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
"default_query": values["default_query"],
"http_client": values["http_client"],
}
values["client"] = openai.AzureOpenAI(**client_params).completions
values["async_client"] = openai.AsyncAzureOpenAI(
**client_params
).completions

else:
values["client"] = openai.Completion

return values

@property
Expand All @@ -772,11 +907,14 @@ def _identifying_params(self) -> Mapping[str, Any]:

@property
def _invocation_params(self) -> Dict[str, Any]:
openai_params = {
"engine": self.deployment_name,
"api_type": self.openai_api_type,
"api_version": self.openai_api_version,
}
if is_openai_v1():
openai_params = {"model": self.deployment_name}
else:
openai_params = {
"engine": self.deployment_name,
"api_type": self.openai_api_type,
"api_version": self.openai_api_version,
}
return {**openai_params, **super()._invocation_params}

@property
Expand Down
182 changes: 182 additions & 0 deletions libs/langchain/tests/integration_tests/llms/test_azure_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""Test AzureOpenAI wrapper."""
import os
from typing import Any, Generator

import pytest

from langchain.callbacks.manager import CallbackManager
from langchain.llms import AzureOpenAI
from langchain.schema import (
LLMResult,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "")
OPENAI_API_BASE = os.environ.get("AZURE_OPENAI_API_BASE", "")
OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "")
DEPLOYMENT_NAME = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", "")


def _get_llm(**kwargs: Any) -> AzureOpenAI:
return AzureOpenAI(
deployment_name=DEPLOYMENT_NAME,
openai_api_version=OPENAI_API_VERSION,
openai_api_base=OPENAI_API_BASE,
openai_api_key=OPENAI_API_KEY,
**kwargs,
)


@pytest.mark.scheduled
@pytest.fixture
def llm() -> AzureOpenAI:
return _get_llm(
max_tokens=10,
)


@pytest.mark.scheduled
def test_openai_call(llm: AzureOpenAI) -> None:
"""Test valid call to openai."""
output = llm("Say something nice:")
assert isinstance(output, str)


@pytest.mark.scheduled
def test_openai_streaming(llm: AzureOpenAI) -> None:
"""Test streaming tokens from AzureOpenAI."""
generator = llm.stream("I'm Pickle Rick")

assert isinstance(generator, Generator)

full_response = ""
for token in generator:
assert isinstance(token, str)
full_response += token
assert full_response


@pytest.mark.scheduled
@pytest.mark.asyncio
async def test_openai_astream(llm: AzureOpenAI) -> None:
"""Test streaming tokens from AzureOpenAI."""
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token, str)


@pytest.mark.scheduled
@pytest.mark.asyncio
async def test_openai_abatch(llm: AzureOpenAI) -> None:
"""Test streaming tokens from AzureOpenAI."""
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)


@pytest.mark.asyncio
async def test_openai_abatch_tags(llm: AzureOpenAI) -> None:
"""Test streaming tokens from AzureOpenAI."""
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token, str)


@pytest.mark.scheduled
def test_openai_batch(llm: AzureOpenAI) -> None:
"""Test streaming tokens from AzureOpenAI."""
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)


@pytest.mark.scheduled
@pytest.mark.asyncio
async def test_openai_ainvoke(llm: AzureOpenAI) -> None:
"""Test streaming tokens from AzureOpenAI."""
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result, str)


@pytest.mark.scheduled
def test_openai_invoke(llm: AzureOpenAI) -> None:
"""Test streaming tokens from AzureOpenAI."""
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result, str)


@pytest.mark.scheduled
def test_openai_multiple_prompts(llm: AzureOpenAI) -> None:
"""Test completion with multiple prompts."""
output = llm.generate(["I'm Pickle Rick", "I'm Pickle Rick"])
assert isinstance(output, LLMResult)
assert isinstance(output.generations, list)
assert len(output.generations) == 2


def test_openai_streaming_best_of_error() -> None:
"""Test validation for streaming fails if best_of is not 1."""
with pytest.raises(ValueError):
_get_llm(best_of=2, streaming=True)


def test_openai_streaming_n_error() -> None:
"""Test validation for streaming fails if n is not 1."""
with pytest.raises(ValueError):
_get_llm(n=2, streaming=True)


def test_openai_streaming_multiple_prompts_error() -> None:
"""Test validation for streaming fails if multiple prompts are given."""
with pytest.raises(ValueError):
_get_llm(streaming=True).generate(["I'm Pickle Rick", "I'm Pickle Rick"])


@pytest.mark.scheduled
def test_openai_streaming_call() -> None:
"""Test valid call to openai."""
llm = _get_llm(max_tokens=10, streaming=True)
output = llm("Say foo:")
assert isinstance(output, str)


def test_openai_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = _get_llm(
max_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
verbose=True,
)
llm("Write me a sentence with 100 words.")
assert callback_handler.llm_streams == 11


@pytest.mark.scheduled
@pytest.mark.asyncio
async def test_openai_async_generate() -> None:
"""Test async generation."""
llm = _get_llm(max_tokens=10)
output = await llm.agenerate(["Hello, how are you?"])
assert isinstance(output, LLMResult)


@pytest.mark.asyncio
async def test_openai_async_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = _get_llm(
max_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
verbose=True,
)
result = await llm.agenerate(["Write me a sentence with 100 words."])
assert callback_handler.llm_streams == 11
assert isinstance(result, LLMResult)

0 comments on commit 344cab0

Please sign in to comment.