Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: add new parameter default_headers #28700

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/docs/integrations/chat/oci_data_science.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -156,6 +156,10 @@
" \"temperature\": 0.2,\n",
" \"max_tokens\": 512,\n",
" }, # other model params...\n",
" default_headers={\n",
" \"route\": \"/v1/chat/completions\",\n",
" # other request headers ...\n",
" },\n",
")"
]
},
Expand Down
37 changes: 35 additions & 2 deletions libs/community/langchain_community/chat_models/oci_data_science.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
)

logger = logging.getLogger(__name__)
DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"


def _is_pydantic_class(obj: Any) -> bool:
Expand All @@ -56,6 +57,13 @@ def _is_pydantic_class(obj: Any) -> bool:
class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
"""OCI Data Science Model Deployment chat model integration.

Prerequisite
The OCI Model Deployment plugins are installable only on
python version 3.9 and above. If you're working inside the notebook,
try installing the python 3.10 based conda pack and running the
following setup.


Setup:
Install ``oracle-ads`` and ``langchain-openai``.

Expand Down Expand Up @@ -90,6 +98,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
Key init args — client params:
auth: dict
ADS auth dictionary for OCI authentication.
default_headers: Optional[Dict]
The headers to be added to the Model Deployment request.

Instantiate:
.. code-block:: python
Expand All @@ -98,14 +108,18 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):

chat = ChatOCIModelDeployment(
endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
model="odsc-llm",
model="odsc-llm", # this is the default model name if deployed with AQUA
streaming=True,
max_retries=3,
model_kwargs={
"max_token": 512,
"temperature": 0.2,
# other model parameters ...
},
default_headers={
"route": "/v1/chat/completions",
# other request headers ...
},
)

Invocation:
Expand Down Expand Up @@ -288,6 +302,25 @@ def _default_params(self) -> Dict[str, Any]:
"stream": self.streaming,
}

def _headers(
self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
"""Construct and return the headers for a request.

Args:
is_async (bool, optional): Indicates if the request is asynchronous.
Defaults to `False`.
body (optional): The request body to be included in the headers if
the request is asynchronous.

Returns:
Dict: A dictionary containing the appropriate headers for the request.
"""
return {
"route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
**super()._headers(is_async=is_async, body=body),
}

def _generate(
self,
messages: List[BaseMessage],
Expand Down Expand Up @@ -701,7 +734,7 @@ def _process_response(self, response_json: dict) -> ChatResult:

for choice in choices:
message = _convert_dict_to_message(choice["message"])
generation_info = dict(finish_reason=choice.get("finish_reason"))
generation_info = {"finish_reason": choice.get("finish_reason")}
if "logprobs" in choice:
generation_info["logprobs"] = choice["logprobs"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from langchain_community.utilities.requests import Requests

logger = logging.getLogger(__name__)
DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"


DEFAULT_TIME_OUT = 300
Expand Down Expand Up @@ -81,6 +82,9 @@ class BaseOCIModelDeployment(Serializable):
max_retries: int = 3
"""Maximum number of retries to make when generating."""

default_headers: Optional[Dict[str, Any]] = None
"""The headers to be added to the Model Deployment request."""

@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Dict:
Expand Down Expand Up @@ -120,12 +124,12 @@ def _headers(
Returns:
Dict: A dictionary containing the appropriate headers for the request.
"""
headers = self.default_headers or {}
if is_async:
signer = self.auth["signer"]
_req = requests.Request("POST", self.endpoint, json=body)
req = _req.prepare()
req = signer(req)
headers = {}
for key, value in req.headers.items():
headers[key] = value

Expand All @@ -135,7 +139,7 @@ def _headers(
)
return headers

return (
headers.update(
{
"Content-Type": DEFAULT_CONTENT_TYPE_JSON,
"enable-streaming": "true",
Expand All @@ -147,6 +151,8 @@ def _headers(
}
)

return headers

def completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
Expand Down Expand Up @@ -383,6 +389,10 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
model="odsc-llm",
streaming=True,
model_kwargs={"frequency_penalty": 1.0},
headers={
"route": "/v1/completions",
# other request headers ...
}
)
llm.invoke("tell me a joke.")

Expand Down Expand Up @@ -426,7 +436,7 @@ def _construct_json_body(self, prompt: str, param:dict) -> dict:
temperature: float = 0.2
"""A non-negative float that tunes the degree of randomness in generation."""

k: int = -1
k: int = 50
"""Number of most likely tokens to consider at each step."""

p: float = 0.75
Expand Down Expand Up @@ -472,6 +482,25 @@ def _identifying_params(self) -> Dict[str, Any]:
**self._default_params,
}

def _headers(
self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
"""Construct and return the headers for a request.

Args:
is_async (bool, optional): Indicates if the request is asynchronous.
Defaults to `False`.
body (optional): The request body to be included in the headers if
the request is asynchronous.

Returns:
Dict: A dictionary containing the appropriate headers for the request.
"""
return {
"route": DEFAULT_INFERENCE_ENDPOINT,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be included as a header?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @efriis, thank you for your comment! In the current implementation of OCI Model Deployment, the route parameter needs to be passed through the header. For text and chat models, this parameter will be hardcoded by default. However, users can modify the value of the route parameter if needed by utilizing the default_header attribute.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good

**super()._headers(is_async=is_async, body=body),
}

def _generate(
self,
prompts: List[str],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
CONST_PROMPT = "This is a prompt."
CONST_COMPLETION = "This is a completion."
CONST_COMPLETION_ROUTE = "/v1/chat/completions"
CONST_COMPLETION_RESPONSE = {
"id": "chat-123456789",
"object": "chat.completion",
Expand Down Expand Up @@ -120,6 +121,7 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
def test_invoke_vllm(*args: Any) -> None:
"""Tests invoking vLLM endpoint."""
llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert isinstance(output, AIMessage)
assert output.content == CONST_COMPLETION
Expand All @@ -132,6 +134,7 @@ def test_invoke_vllm(*args: Any) -> None:
def test_invoke_tgi(*args: Any) -> None:
"""Tests invoking TGI endpoint using OpenAI Spec."""
llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert isinstance(output, AIMessage)
assert output.content == CONST_COMPLETION
Expand All @@ -146,6 +149,7 @@ def test_stream_vllm(*args: Any) -> None:
llm = ChatOCIModelDeploymentVLLM(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = None
count = 0
for chunk in llm.stream(CONST_PROMPT):
Expand Down Expand Up @@ -184,6 +188,7 @@ async def test_stream_async(*args: Any) -> None:
llm = ChatOCIModelDeploymentVLLM(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
with mock.patch.object(
llm,
"_aiter_sse",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
CONST_PROMPT = "This is a prompt."
CONST_COMPLETION = "This is a completion."
CONST_COMPLETION_ROUTE = "/v1/completions"
CONST_COMPLETION_RESPONSE = {
"choices": [
{
Expand Down Expand Up @@ -114,6 +115,7 @@ async def mocked_async_streaming_response(
def test_invoke_vllm(*args: Any) -> None:
"""Tests invoking vLLM endpoint."""
llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert output == CONST_COMPLETION

Expand All @@ -126,6 +128,7 @@ def test_stream_tgi(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = ""
count = 0
for chunk in llm.stream(CONST_PROMPT):
Expand All @@ -143,6 +146,7 @@ def test_generate_tgi(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert output == CONST_COMPLETION

Expand All @@ -161,6 +165,7 @@ async def test_stream_async(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
with mock.patch.object(
llm,
"_aiter_sse",
Expand Down
Loading