support azure, other openai compatible providers
vysakh0 committed Jul 19, 2024
1 parent 4938c2b commit c9cbc03
Showing 3 changed files with 123 additions and 42 deletions.
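For reference, this is the configuration surface the commit introduces. A minimal sketch of the environment variables each DRAVID_LLM value reads, derived from the get_client()/get_model() changes in the diff below (all values shown are placeholders):

import os

# OpenAI (the default provider)
os.environ["DRAVID_LLM"] = "openai"
os.environ["OPENAI_API_KEY"] = "sk-placeholder"
os.environ["OPENAI_MODEL"] = "gpt-4o-2024-05-13"  # optional; falls back to DEFAULT_MODEL

# Azure OpenAI
# os.environ["DRAVID_LLM"] = "azure"
# os.environ["AZURE_OPENAI_API_KEY"] = "placeholder"
# os.environ["AZURE_OPENAI_API_VERSION"] = "2023-05-15"
# os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<resource>.openai.azure.com"
# os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "<deployment>"

# Any other OpenAI-compatible provider
# os.environ["DRAVID_LLM"] = "custom"
# os.environ["DRAVID_LLM_API_KEY"] = "placeholder"
# os.environ["DRAVID_LLM_ENDPOINT"] = "https://<host>/v1"  # defaults to https://api.openai.com/v1
# os.environ["DRAVID_LLM_MODEL"] = "<model-name>"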
13 changes: 7 additions & 6 deletions src/drd/api/main.py
@@ -1,21 +1,22 @@
-import click
 import os
+import click
 from .claude_api import call_claude_api_with_pagination, call_claude_vision_api_with_pagination, stream_claude_response
-from .openai_api import call_openai_api_with_pagination, call_openai_vision_api_with_pagination, stream_openai_response
+from .openai_api import call_api_with_pagination, call_vision_api_with_pagination, stream_response
 from ..utils import print_debug, print_info
 from ..utils.loader import Loader
 from ..utils.pretty_print_stream import pretty_print_xml_stream
 from ..utils.parser import parse_dravid_response
 import xml.etree.ElementTree as ET
 
-LLM_PROVIDER = os.getenv('DRAVID_LLM', 'openai').lower()
-
 
 def get_api_functions():
-    if LLM_PROVIDER == 'claude':
+    llm_type = os.getenv('DRAVID_LLM', 'claude').lower()
+    if llm_type == 'claude':
         return call_claude_api_with_pagination, call_claude_vision_api_with_pagination, stream_claude_response
+    elif llm_type in ['openai', 'azure', 'custom']:
+        return call_api_with_pagination, call_vision_api_with_pagination, stream_response
     else:
-        return call_openai_api_with_pagination, call_openai_vision_api_with_pagination, stream_openai_response
+        raise ValueError(f"Unsupported LLM type: {llm_type}")
 
 
 def stream_dravid_api(query, include_context=False, instruction_prompt=None, print_chunk=False):
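A minimal usage sketch of the new dispatch (hypothetical caller; note that drd.api.openai_api constructs its client at import time, so the environment must be set before the import):

import os
os.environ["DRAVID_LLM"] = "openai"  # or "claude", "azure", "custom"
os.environ["OPENAI_API_KEY"] = "sk-placeholder"

from drd.api.main import get_api_functions

# Returns the (text, vision, stream) call triple for the configured provider
call_fn, vision_fn, stream_fn = get_api_functions()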
58 changes: 42 additions & 16 deletions src/drd/api/openai_api.py
@@ -2,23 +2,53 @@
 import json
 import base64
 from typing import Dict, Any, Optional, List, Generator
-from openai import OpenAI
+from openai import OpenAI, AzureOpenAI
 from ..utils.parser import extract_and_parse_xml, parse_dravid_response
 import xml.etree.ElementTree as ET
 import click
 
-MODEL = "gpt-4o-2024-05-13"
+DEFAULT_MODEL = "gpt-4o-2024-05-13"
 MAX_TOKENS = 4000
 
 
-def get_api_key() -> str:
-    api_key = os.getenv('OPENAI_API_KEY')
-    if not api_key:
-        raise ValueError("OPENAI_API_KEY not found in environment variables")
-    return api_key
+def get_env_variable(name: str, default: Optional[str] = None) -> str:
+    value = os.getenv(name, default)
+    if value is None:
+        raise ValueError(f"{name} not found in environment variables")
+    return value
 
 
-client = OpenAI(api_key=get_api_key())
+def get_client():
+    llm_type = get_env_variable('DRAVID_LLM', 'openai').lower()
+
+    if llm_type == 'azure':
+        return AzureOpenAI(
+            api_key=get_env_variable("AZURE_OPENAI_API_KEY"),
+            api_version=get_env_variable("AZURE_OPENAI_API_VERSION"),
+            azure_endpoint=get_env_variable("AZURE_OPENAI_ENDPOINT")
+        )
+    elif llm_type in ['openai', 'custom']:
+        api_key = get_env_variable(
+            "OPENAI_API_KEY" if llm_type == 'openai' else "DRAVID_LLM_API_KEY")
+        api_base = get_env_variable(
+            "DRAVID_LLM_ENDPOINT", "https://api.openai.com/v1") if llm_type == 'custom' else None
+        return OpenAI(api_key=api_key, base_url=api_base)
+    else:
+        raise ValueError(f"Unsupported LLM type: {llm_type}")
+
+
+def get_model():
+    llm_type = get_env_variable('DRAVID_LLM', 'openai').lower()
+    if llm_type == 'azure':
+        return get_env_variable("AZURE_OPENAI_DEPLOYMENT_NAME")
+    elif llm_type == 'custom':
+        return get_env_variable("DRAVID_LLM_MODEL")
+    else:
+        return get_env_variable("OPENAI_MODEL", DEFAULT_MODEL)
+
+
+client = get_client()
+MODEL = get_model()
 
 
 def parse_response(response: str) -> str:
@@ -30,9 +60,8 @@ def parse_response(response: str) -> str:
     return response
 
 
-def call_openai_api_with_pagination(query: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str:
+def call_api_with_pagination(query: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str:
     full_response = ""
-
     messages = [
         {"role": "system", "content": instruction_prompt or ""},
         {"role": "user", "content": query}
@@ -44,7 +73,6 @@ def call_openai_api_with_pagination(query: str, include_context: bool = False, i
             messages=messages,
             max_tokens=MAX_TOKENS
         )
-
         full_response += response.choices[0].message.content
 
         if response.choices[0].finish_reason != 'length':
@@ -56,12 +84,11 @@ def call_openai_api_with_pagination(query: str, include_context: bool = False, i
     return parse_response(full_response)
 
 
-def call_openai_vision_api_with_pagination(query: str, image_path: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str:
+def call_vision_api_with_pagination(query: str, image_path: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str:
     with open(image_path, "rb") as image_file:
         image_data = base64.b64encode(image_file.read()).decode('utf-8')
 
     full_response = ""
-
     messages = [
         {"role": "system", "content": instruction_prompt or ""},
         {
@@ -80,7 +107,6 @@ def call_openai_vision_api_with_pagination(query: str, image_path: str, include_
             messages=messages,
             max_tokens=MAX_TOKENS
         )
-
        full_response += response.choices[0].message.content
 
        if response.choices[0].finish_reason != 'length':
@@ -92,7 +118,7 @@ def call_openai_vision_api_with_pagination(query: str, image_path: str, include_
     return parse_response(full_response)
 
 
-def stream_openai_response(query: str, instruction_prompt: Optional[str] = None) -> Generator[str, None, None]:
+def stream_response(query: str, instruction_prompt: Optional[str] = None) -> Generator[str, None, None]:
     messages = [
         {"role": "system", "content": instruction_prompt or ""},
         {"role": "user", "content": query}
@@ -106,5 +132,5 @@ def stream_openai_response(query: str, instruction_prompt: Optional[str] = None)
     )
 
     for chunk in response:
-        if chunk.choices[0].delta.content is not None:
+        if chunk.choices and chunk.choices[0].delta.content is not None:
             yield chunk.choices[0].delta.content
94 changes: 74 additions & 20 deletions tests/api/test_openai_api.py
@@ -3,33 +3,88 @@
 import os
 import xml.etree.ElementTree as ET
 from io import BytesIO
+from openai import OpenAI, AzureOpenAI
 
 from drd.api.openai_api import (
-    get_api_key,
+    get_env_variable,
+    get_client,
+    get_model,
     parse_response,
-    call_openai_api_with_pagination,
-    call_openai_vision_api_with_pagination,
-    stream_openai_response,
+    call_api_with_pagination,
+    call_vision_api_with_pagination,
+    stream_response,
 )
 
-MODEL = "gpt-4o-2024-05-13"
+DEFAULT_MODEL = "gpt-4o-2024-05-13"
 
 
 class TestOpenAIApiUtils(unittest.TestCase):
 
     def setUp(self):
-        self.api_key = "test_openai_api_key"
+        self.api_key = "test_api_key"
         self.query = "Test query"
         self.image_path = "test_image.jpg"
 
     @patch.dict(os.environ, {"OPENAI_API_KEY": "test_openai_api_key"})
-    def test_get_api_key(self):
-        self.assertEqual(get_api_key(), "test_openai_api_key")
+    def test_get_env_variable_existing(self):
+        self.assertEqual(get_env_variable(
+            "OPENAI_API_KEY"), "test_openai_api_key")
 
     @patch.dict(os.environ, {}, clear=True)
-    def test_get_api_key_missing(self):
+    def test_get_env_variable_missing(self):
         with self.assertRaises(ValueError):
-            get_api_key()
+            get_env_variable("NON_EXISTENT_VAR")
+
+    @patch.dict(os.environ, {"DRAVID_LLM": "openai", "OPENAI_API_KEY": "test_key"})
+    def test_get_client_openai(self):
+        client = get_client()
+        self.assertIsInstance(client, OpenAI)
+        # The v1 OpenAI client exposes the configured key as client.api_key
+        self.assertEqual(client.api_key, "test_key")
+
+    @patch.dict(os.environ, {
+        "DRAVID_LLM": "azure",
+        "AZURE_OPENAI_API_KEY": "test_azure_key",
+        "AZURE_OPENAI_API_VERSION": "2023-05-15",
+        "AZURE_OPENAI_ENDPOINT": "https://test.openai.azure.com"
+    })
+    def test_get_client_azure(self):
+        client = get_client()
+        # The Azure-specific settings (version, endpoint) are not exposed as
+        # simple attributes, so only the client type is asserted here
+        self.assertIsInstance(client, AzureOpenAI)
+
+    @patch.dict(os.environ, {
+        "DRAVID_LLM": "custom",
+        "DRAVID_LLM_API_KEY": "test_custom_key",
+        "DRAVID_LLM_ENDPOINT": "https://custom-llm-endpoint.com"
+    })
+    def test_get_client_custom(self):
+        client = get_client()
+        self.assertIsInstance(client, OpenAI)
+        self.assertEqual(client.api_key, "test_custom_key")
+        # base_url is normalized with a trailing slash by the client
+        self.assertTrue(str(client.base_url).startswith(
+            "https://custom-llm-endpoint.com"))
+
+    @patch.dict(os.environ, {"DRAVID_LLM": "openai", "OPENAI_MODEL": "gpt-4"})
+    def test_get_model_openai(self):
+        self.assertEqual(get_model(), "gpt-4")
+
+    @patch.dict(os.environ, {"DRAVID_LLM": "azure", "AZURE_OPENAI_DEPLOYMENT_NAME": "test-deployment"})
+    def test_get_model_azure(self):
+        self.assertEqual(get_model(), "test-deployment")
+
+    @patch.dict(os.environ, {"DRAVID_LLM": "custom", "DRAVID_LLM_MODEL": "llama-3"})
+    def test_get_model_custom(self):
+        self.assertEqual(get_model(), "llama-3")
 
     def test_parse_response_valid_xml(self):
         xml_response = "<response><content>Test content</content></response>"
@@ -44,35 +99,34 @@ def test_parse_response_invalid_xml(self, mock_echo):
         mock_echo.assert_called_once()
 
     @patch('drd.api.openai_api.client.chat.completions.create')
-    def test_call_openai_api_with_pagination(self, mock_create):
+    def test_call_api_with_pagination(self, mock_create):
         mock_response = MagicMock()
         mock_response.choices[0].message.content = "<response>Test response</response>"
         mock_response.choices[0].finish_reason = 'stop'
         mock_create.return_value = mock_response
 
-        response = call_openai_api_with_pagination(self.query)
+        response = call_api_with_pagination(self.query)
         self.assertEqual(response, "<response>Test response</response>")
 
         mock_create.assert_called_once()
         call_args = mock_create.call_args[1]
-        self.assertEqual(call_args['model'], MODEL)
+        self.assertEqual(call_args['model'], DEFAULT_MODEL)
         self.assertEqual(call_args['messages'][1]['content'], self.query)
 
     @patch('drd.api.openai_api.client.chat.completions.create')
     @patch('builtins.open', new_callable=unittest.mock.mock_open, read_data=b'test image data')
-    def test_call_openai_vision_api_with_pagination(self, mock_open, mock_create):
+    def test_call_vision_api_with_pagination(self, mock_open, mock_create):
         mock_response = MagicMock()
         mock_response.choices[0].message.content = "<response>Test vision response</response>"
         mock_response.choices[0].finish_reason = 'stop'
         mock_create.return_value = mock_response
 
-        response = call_openai_vision_api_with_pagination(
-            self.query, self.image_path)
+        response = call_vision_api_with_pagination(self.query, self.image_path)
         self.assertEqual(response, "<response>Test vision response</response>")
 
         mock_create.assert_called_once()
         call_args = mock_create.call_args[1]
-        self.assertEqual(call_args['model'], MODEL)
+        self.assertEqual(call_args['model'], DEFAULT_MODEL)
         self.assertEqual(call_args['messages'][1]
                          ['content'][0]['type'], 'text')
         self.assertEqual(call_args['messages'][1]
@@ -83,7 +137,7 @@ def test_call_openai_vision_api_with_pagination(self, mock_open, mock_create):
                         ['image_url']['url'].startswith('data:image/jpeg;base64,'))
 
     @patch('drd.api.openai_api.client.chat.completions.create')
-    def test_stream_openai_response(self, mock_create):
+    def test_stream_response(self, mock_create):
         mock_response = MagicMock()
         mock_response.__iter__.return_value = [
             MagicMock(choices=[MagicMock(delta=MagicMock(content="Test"))]),
@@ -92,11 +146,11 @@ def test_stream_openai_response(self, mock_create):
         ]
         mock_create.return_value = mock_response
 
-        result = list(stream_openai_response(self.query))
+        result = list(stream_response(self.query))
         self.assertEqual(result, ["Test", " stream"])
 
         mock_create.assert_called_once()
         call_args = mock_create.call_args[1]
-        self.assertEqual(call_args['model'], MODEL)
+        self.assertEqual(call_args['model'], DEFAULT_MODEL)
         self.assertEqual(call_args['messages'][1]['content'], self.query)
         self.assertTrue(call_args['stream'])

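One consequence of constructing the client at module import time, visible throughout this suite: patches must target the already-built instance in drd.api.openai_api rather than the OpenAI class itself. A minimal sketch of the pattern the tests rely on:

from unittest.mock import MagicMock, patch

# Patch the module-level client created at import, not openai.OpenAI
with patch('drd.api.openai_api.client.chat.completions.create') as mock_create:
    mock_create.return_value = MagicMock()  # a stubbed completion response
    # code under test calls drd.api.openai_api functions here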