LLM client implementation (#905)
* 1. implemented rest client for gpt with retries
2. retrieving random response in case n > 1 for gpt3
3. unit and integration tests

udit-pandey authored May 4, 2023
1 parent 5f3ab0b commit c75471c
Showing 15 changed files with 689 additions and 368 deletions.
6 changes: 5 additions & 1 deletion kairon/api/models.py
@@ -805,7 +805,11 @@ def check(cls, values):
         from kairon.shared.utils import Utility
 
         if not values.get('hyperparameters'):
-            values['hyperparameters'] = Utility.get_llm_hyperparameters()
+            values['hyperparameters'] = {}
+
+        for key, value in Utility.get_llm_hyperparameters().items():
+            if key not in values['hyperparameters']:
+                values['hyperparameters'][key] = value
         return values
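The validator now merges per key instead of replacing wholesale: user-supplied hyperparameters are kept and only missing keys are back-filled from the defaults. A minimal runnable sketch of these semantics, with hypothetical default values standing in for Utility.get_llm_hyperparameters():

# Hypothetical defaults; the real set comes from Utility.get_llm_hyperparameters().
defaults = {"temperature": 0.0, "top_p": 0.0, "n": 1, "stream": False}

values = {"hyperparameters": {"temperature": 0.7}}  # caller supplied a single key
if not values.get("hyperparameters"):
    values["hyperparameters"] = {}
for key, value in defaults.items():
    if key not in values["hyperparameters"]:
        values["hyperparameters"][key] = value

print(values["hyperparameters"])
# {'temperature': 0.7, 'top_p': 0.0, 'n': 1, 'stream': False}  # the caller's value wins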
7 changes: 7 additions & 0 deletions kairon/shared/actions/data_objects.py
@@ -512,7 +512,14 @@ class KaironFaqAction(Auditlog):
     hyperparameters = DictField(default=Utility.get_llm_hyperparameters)
     llm_prompts = ListField(EmbeddedDocumentField(LlmPrompt), required=True)
 
+    def clean(self):
+        for key, value in Utility.get_llm_hyperparameters().items():
+            if key not in self.hyperparameters:
+                self.hyperparameters.update({key: value})
+
     def validate(self, clean=True):
+        if clean:
+            self.clean()
         if self.num_bot_responses > 5:
             raise ValidationError("num_bot_responses should not be greater than 5")
         if not 0.3 <= self.similarity_threshold <= 1:
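Because MongoEngine calls validate(clean=True) on save, overriding clean() guarantees the defaults are back-filled before the range checks run. A self-contained stand-in for that flow (the class name and default values here are hypothetical):

class FaqActionSketch:  # hypothetical stand-in for KaironFaqAction
    def __init__(self, hyperparameters, num_bot_responses):
        self.hyperparameters = hyperparameters
        self.num_bot_responses = num_bot_responses

    def clean(self):
        for key, value in {"temperature": 0.0, "n": 1}.items():  # assumed defaults
            if key not in self.hyperparameters:
                self.hyperparameters.update({key: value})

    def validate(self, clean=True):
        if clean:
            self.clean()  # back-fill first ...
        if self.num_bot_responses > 5:  # ... then enforce limits
            raise ValueError("num_bot_responses should not be greater than 5")

action = FaqActionSketch({"temperature": 0.7}, num_bot_responses=2)
action.validate()
print(action.hyperparameters)  # {'temperature': 0.7, 'n': 1}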
5 changes: 5 additions & 0 deletions kairon/shared/constants.py
@@ -102,3 +102,8 @@ class ElementTypes(str, Enum):
 
 class WhatsappBSPTypes(str, Enum):
     bsp_360dialog = "360dialog"
+
+
+class GPT3ResourceTypes(str, Enum):
+    embeddings = "embeddings"
+    chat_completion = "chat/completions"
Empty file.
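The enum members double as REST path fragments: joined to the client's base URL (defined in the new module below), they produce the full OpenAI endpoints. A quick self-contained illustration:

from enum import Enum

class GPT3ResourceTypes(str, Enum):
    embeddings = "embeddings"
    chat_completion = "chat/completions"

base_url = "https://api.openai.com/v1/"
print(f"{base_url}{GPT3ResourceTypes.embeddings.value}")       # https://api.openai.com/v1/embeddings
print(f"{base_url}{GPT3ResourceTypes.chat_completion.value}")  # https://api.openai.com/v1/chat/completions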
85 changes: 85 additions & 0 deletions kairon/shared/llm/clients/gpt3.py
@@ -0,0 +1,85 @@
import json
import random
from abc import ABC
from json import JSONDecodeError
from typing import Text
from loguru import logger
from openai.api_requestor import parse_stream
from kairon.exceptions import AppException
from kairon.shared.constants import GPT3ResourceTypes
from kairon.shared.utils import Utility


class LLMResources(ABC):

    def invoke(self, resource: Text, engine: Text, **kwargs):
        raise NotImplementedError("Provider not implemented")


class GPT3Resources(LLMResources):
    resource_url = "https://api.openai.com/v1/"

    def __init__(self, api_key: Text):
        self.api_key = api_key

    def invoke(self, resource: Text, model: Text, **kwargs):
        http_url = f"{self.resource_url}{resource}"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        request_body = kwargs.copy()
        request_body.update({"model": model})
        resp = Utility.execute_http_request(
            "POST", http_url, request_body, headers, max_retries=3, backoff_factor=0.2, return_json=False
        )
        if resp.status_code != 200:
            try:
                resp = resp.json()
                logger.debug(f"GPT response error: {resp}")
                raise AppException(f"{resp['error'].get('message')}. Request id: {resp['error'].get('id')}")
            except JSONDecodeError:
                raise AppException(f"Received non 200 status code: {resp.text}")

        return self.__parse_response(resource, resp, **kwargs)

    def __parse_response(self, resource: Text, response, **kwargs):
        parsers = {
            GPT3ResourceTypes.embeddings.value: self._parse_embeddings_response,
            GPT3ResourceTypes.chat_completion.value: self.__parse_completion_response
        }
        return parsers[resource](response, **kwargs)

    def _parse_embeddings_response(self, response, **hyperparameters):
        raw_response = response.json()
        formatted_response = raw_response["data"][0]["embedding"]
        return formatted_response, raw_response

    def __parse_completion_response(self, response, **kwargs):
        if kwargs.get("stream"):
            formatted_response, raw_response = self._parse_streaming_response(response, kwargs.get("n", 1))
        else:
            formatted_response, raw_response = self._parse_api_response(response)
        return formatted_response, raw_response

    def _parse_streaming_response(self, response, num_choices):
        line = None
        formatted_response = ''
        raw_response = []
        msg_choice = random.randint(0, num_choices - 1)
        try:
            for line in parse_stream(response.iter_lines()):
                line = json.loads(line)
                if line["choices"][0].get("index") == msg_choice and line["choices"][0]['delta'].get('content'):
                    formatted_response = f"{formatted_response}{line['choices'][0]['delta']['content']}"
                raw_response.append(line)
        except (JSONDecodeError, UnicodeDecodeError) as e:
            logger.exception(e)
            raise AppException(f"Received HTTP code {response.status_code} in streaming response from openai: {line}")
        except Exception as e:
            logger.exception(e)
            raise AppException(f"Failed to parse response: {line}")
        return formatted_response, raw_response

    def _parse_api_response(self, response):
        raw_response = response.json()
        msg_choice = random.choice(raw_response['choices'])
        formatted_response = msg_choice['message']['content']
        return formatted_response, raw_response
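A usage sketch of the new client: the API key is a placeholder and the chat model name is an assumption, while the embedding model matches the one used by kairon/shared/llm/gpt3.py below. With n > 1, the completion parsers return one randomly chosen candidate, as the commit message notes.

client = GPT3Resources(api_key="sk-...")  # placeholder key

# Embeddings: invoke() returns (formatted, raw), here the embedding vector plus the raw JSON.
embedding, raw = client.invoke(GPT3ResourceTypes.embeddings.value, "text-embedding-ada-002", input="hello")

# Chat completion: n=2 generates two candidates; _parse_api_response picks one at random.
answer, raw = client.invoke(
    GPT3ResourceTypes.chat_completion.value,
    "gpt-3.5-turbo",  # assumed model name
    messages=[{"role": "user", "content": "Say hi"}],
    n=2,
)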
34 changes: 12 additions & 22 deletions kairon/shared/llm/gpt3.py
@@ -1,8 +1,11 @@
 from kairon.shared.admin.constants import BotSecretType
 from kairon.shared.admin.processor import Sysadmin
+from kairon.shared.constants import GPT3ResourceTypes
 from kairon.shared.data.constant import DEFAULT_SYSTEM_PROMPT, DEFAULT_CONTEXT_PROMPT
 from kairon.shared.llm.base import LLMBase
 from typing import Text, Dict, List, Union
+
+from kairon.shared.llm.clients.gpt3 import GPT3Resources
 from kairon.shared.utils import Utility
 import openai
 from kairon.shared.data.data_objects import BotContent
@@ -25,6 +28,7 @@ def __init__(self, bot: Text):
         self.cached_resp_suffix = "_cached_response_embd"
         self.vector_config = {'size': 1536, 'distance': 'Cosine'}
         self.api_key = Sysadmin.get_bot_secret(bot, BotSecretType.gpt_key.value, raise_err=True)
+        self.client = GPT3Resources(self.api_key)
         self.__logs = []
 
     def train(self, *args, **kwargs) -> Dict:
@@ -75,12 +79,8 @@ def predict(self, query: Text, *args, **kwargs) -> Dict:
         return response
 
     def __get_embedding(self, text: Text) -> List[float]:
-        result = openai.Embedding.create(
-            api_key=self.api_key,
-            model="text-embedding-ada-002",
-            input=text
-        )
-        return result.to_dict_recursive()["data"][0]["embedding"]
+        result, _ = self.client.invoke(GPT3ResourceTypes.embeddings.value, model="text-embedding-ada-002", input=text)
+        return result
 
     def __get_answer(self, query, system_prompt: Text, context: Text, **kwargs):
         query_prompt = kwargs.get('query_prompt')
@@ -97,32 +97,22 @@ def __get_answer(self, query, system_prompt: Text, context: Text, **kwargs):
         messages.extend(previous_bot_responses)
         messages.append({"role": "user", "content": f"{context} \n Q: {query}\n A:"})
 
-        completion = openai.ChatCompletion.create(
-            api_key=self.api_key,
-            messages=messages,
-            **hyperparameters
-        )
-
-        response, raw_response = Utility.format_llm_response(completion, hyperparameters.get('stream', False))
+        completion, raw_response = self.client.invoke(GPT3ResourceTypes.chat_completion.value, messages=messages, **hyperparameters)
         self.__logs.append({'messages': messages, 'raw_completion_response': raw_response,
                             'type': 'answer_query', 'hyperparameters': hyperparameters})
-        return response
+        return completion
 
     def __rephrase_query(self, query, system_prompt: Text, query_prompt: Text, **kwargs):
         messages = [
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": f"{query_prompt}\n\n Q: {query}\n A:"}
         ]
         hyperparameters = kwargs.get('hyperparameters', Utility.get_llm_hyperparameters())
-        completion = openai.ChatCompletion.create(
-            api_key=self.api_key,
-            messages=messages,
-            **hyperparameters
-        )
-        response, raw_response = Utility.format_llm_response(completion, hyperparameters.get('stream', False))
-
+        completion, raw_response = self.client.invoke(GPT3ResourceTypes.chat_completion.value, messages=messages, **hyperparameters)
         self.__logs.append({'messages': messages, 'raw_completion_response': raw_response,
-                            'type': 'rephrase_query', 'hyperparameters': hyperparameters})
-        return response
+                            'type': 'rephrase_query', 'hyperparameters': hyperparameters})
+        return completion
 
     def __create_collection__(self, collection_name: Text):
         Utility.execute_http_request(http_url=urljoin(self.db_url, f"/collections/{collection_name}"),
33 changes: 16 additions & 17 deletions kairon/shared/utils.py
@@ -26,6 +26,8 @@
 
 import pandas as pd
 import requests
+from requests.adapters import HTTPAdapter, Retry
+
 import yaml
 from botocore.exceptions import ClientError
 from fastapi import File, UploadFile
@@ -1288,8 +1290,20 @@ def execute_http_request(
             validate_status: To validate status_code in response. False by default.
             expected_status_code: 200 by default
             err_msg: error message to be raised in case expected status code not received
+            max_retries: Number of times to retry in case of failure. Defaults to 0.
+            status_forcelist: status codes for which retries should be forced
+            backoff_factor: A backoff factor to apply between attempts after the second try. Defaults to 0.
+                For example, if the backoff_factor is 0.1, then Retry.sleep() will sleep for
+                [0.0s, 0.2s, 0.4s, 0.8s, …] between retries. No backoff will ever be longer than backoff_max.
         :return: dict/response object
         """
+        session = requests.Session()
+        max_retries = kwargs.get("max_retries", 0)
+        status_forcelist = kwargs.get("status_forcelist", [104, 502, 503, 504])
+        backoff_factor = kwargs.get("backoff_factor", 0)
+        retries = Retry(total=max_retries, backoff_factor=backoff_factor, status_forcelist=status_forcelist, read=False)
+        session.mount('https://', HTTPAdapter(max_retries=retries))
+        session.mount('http://', HTTPAdapter(max_retries=retries))
         if not headers:
             headers = {}
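For context, the GPT3 client above passes the new keyword arguments exactly this way; following the docstring's example pattern, backoff_factor=0.2 gives sleeps of roughly [0.0s, 0.4s, 0.8s] between the three retries. The endpoint, body, and key below are placeholders:

resp = Utility.execute_http_request(
    "POST", "https://api.openai.com/v1/chat/completions",   # placeholder endpoint
    {"model": "gpt-3.5-turbo"},                             # placeholder request body
    {"Authorization": "Bearer sk-..."},                     # placeholder headers
    max_retries=3, backoff_factor=0.2, return_json=False,   # retries on 104/502/503/504 by default
)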

@@ -1302,14 +1316,14 @@
                     request_method.upper(), http_url, params=request_body, headers=headers, timeout=kwargs.get('timeout')
                 )
             elif request_method.lower() in ['post', 'put']:
-                response = requests.request(
+                response = session.request(
                     request_method.upper(), http_url, json=request_body, headers=headers, timeout=kwargs.get('timeout')
                 )
             else:
                 raise AppException("Invalid request method!")
             logger.debug("raw response: " + str(response.text))
             logger.debug("status " + str(response.status_code))
-        except requests.exceptions.ConnectTimeout:
+        except (requests.exceptions.ConnectTimeout, requests.exceptions.ConnectionError):
             _, _, host, _, _, _, _ = parse_url(http_url)
             raise AppException(f"Failed to connect to service: {host}")
         except Exception as e:
@@ -1637,21 +1651,6 @@ def get_llm_hyperparameters():
                 return hyperparameters
         raise AppException("Could not find any hyperparameters for configured LLM.")
 
-    @staticmethod
-    def format_llm_response(response, is_streamed: bool = False):
-        formatted_response = ''
-        if is_streamed:
-            raw_response = []
-            for chunk in response:
-                for delta in chunk['choices']:
-                    if delta['delta'].get('content'):
-                        formatted_response = f"{formatted_response}{delta['delta'].get('content')}"
-                raw_response.append(chunk.to_dict_recursive())
-        else:
-            formatted_response = ' '.join([choice['message']['content'] for choice in response.to_dict_recursive()['choices']])
-            raw_response = response.to_dict_recursive()
-        return formatted_response, raw_response
-
     @staticmethod
     def create_uuid_from_string(val: str):
         hex_string = hashlib.md5(val.encode("UTF-8")).hexdigest()
