Commit

roll support for ollama
vysakh0 committed Jul 19, 2024
1 parent 70ee5eb commit 010ca23
Showing 3 changed files with 224 additions and 4 deletions.
53 changes: 53 additions & 0 deletions src/drd/api/ollama_api.py
@@ -0,0 +1,53 @@
import requests
from typing import Dict, Any, Generator, Optional
import json

OLLAMA_ENDPOINT = "http://localhost:11434/api"


def get_ollama_client():
    # This is a placeholder function to maintain consistency with the other APIs.
    # Ollama doesn't require a client object, but we'll use this to set up any necessary configuration.
    return None


def call_ollama_api(model: str, prompt: str, system_prompt: str = "") -> str:
    data = {
        "model": model,
        "prompt": prompt,
        "system": system_prompt,
        "stream": False
    }
    response = requests.post(f"{OLLAMA_ENDPOINT}/generate", json=data)
    response.raise_for_status()
    return response.json()["response"]


def stream_ollama_response(model: str, prompt: str, system_prompt: str = "") -> Generator[str, None, None]:
    data = {
        "model": model,
        "prompt": prompt,
        "system": system_prompt,
        "stream": True
    }
    response = requests.post(
        f"{OLLAMA_ENDPOINT}/generate", json=data, stream=True)
    response.raise_for_status()

    for line in response.iter_lines():
        if line:
            chunk = json.loads(line)
            if chunk.get("response"):
                yield chunk["response"]


def call_ollama_api_with_pagination(query: str, model: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str:
    full_response = call_ollama_api(model, query, instruction_prompt or "")
    return full_response

# Note: Ollama doesn't have built-in support for image input like OpenAI.
# For vision-related tasks, we'd need to use a different approach or model.


def call_ollama_vision_api_with_pagination(query: str, image_path: str, model: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str:
    raise NotImplementedError("Vision API is not supported for Ollama models")
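
For reference, here is a minimal usage sketch of the new helpers, assuming an Ollama server is running locally on the default port (11434) and the chosen model has already been pulled; the model name "starcoder" simply mirrors the one used in the tests below.

# Minimal sketch: exercise the Ollama helpers directly (assumes a local Ollama server
# and an already-pulled model; "starcoder" is only an illustrative choice).
from drd.api.ollama_api import call_ollama_api, stream_ollama_response

# Blocking call: returns the complete response as one string.
print(call_ollama_api("starcoder", "Write a docstring for a binary search function."))

# Streaming call: yields text chunks as the model produces them.
for chunk in stream_ollama_response("starcoder", "Explain list comprehensions briefly."):
    print(chunk, end="", flush=True)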
25 changes: 22 additions & 3 deletions src/drd/api/openai_api.py
@@ -6,6 +6,7 @@
from ..utils.parser import extract_and_parse_xml, parse_dravid_response
import xml.etree.ElementTree as ET
import click
from .ollama_api import get_ollama_client, call_ollama_api_with_pagination, stream_ollama_response

DEFAULT_MODEL = "gpt-4o-2024-05-13"
MAX_TOKENS = 4000
@@ -33,6 +34,8 @@ def get_client():
        api_key = get_env_variable("DRAVID_LLM_API_KEY")
        api_base = get_env_variable("DRAVID_LLM_ENDPOINT")
        return OpenAI(api_key=api_key, base_url=api_base)
    elif llm_type == 'ollama':
        return get_ollama_client()
    else:
        raise ValueError(f"Unsupported LLM type: {llm_type}")

@@ -41,7 +44,7 @@ def get_model():
    llm_type = get_env_variable('DRAVID_LLM', 'openai').lower()
    if llm_type == 'azure':
        return get_env_variable("AZURE_OPENAI_DEPLOYMENT_NAME")
    elif llm_type == 'custom':
    elif llm_type == 'custom' or llm_type == 'ollama':
        return get_env_variable("DRAVID_LLM_MODEL")
    else:
        return get_env_variable("OPENAI_MODEL", DEFAULT_MODEL)
@@ -57,8 +60,13 @@ def parse_response(response: str) -> str:


def call_api_with_pagination(query: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str:
    client = get_client()
    llm_type = get_env_variable('DRAVID_LLM', 'openai').lower()
    model = get_model()

    if llm_type == 'ollama':
        return call_ollama_api_with_pagination(query, model, include_context, instruction_prompt)

    client = get_client()
    full_response = ""
    messages = [
        {"role": "system", "content": instruction_prompt or ""},
@@ -83,6 +91,11 @@ def call_api_with_pagination(query: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str:


def call_vision_api_with_pagination(query: str, image_path: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str:
    llm_type = get_env_variable('DRAVID_LLM', 'openai').lower()
    if llm_type == 'ollama':
        raise NotImplementedError(
            "Vision API is not supported for Ollama models")

    client = get_client()
    model = get_model()
    with open(image_path, "rb") as image_file:
@@ -119,8 +132,14 @@ def call_vision_api_with_pagination(query: str, image_path: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str:


def stream_response(query: str, instruction_prompt: Optional[str] = None) -> Generator[str, None, None]:
    client = get_client()
    llm_type = get_env_variable('DRAVID_LLM', 'openai').lower()
    model = get_model()

    if llm_type == 'ollama':
        yield from stream_ollama_response(model, query, instruction_prompt or "")
        return

    client = get_client()
    messages = [
        {"role": "system", "content": instruction_prompt or ""},
        {"role": "user", "content": query}
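
Taken together, the new branches mean the Ollama backend is selected purely through environment variables and then flows through the existing entry points. A minimal sketch of that end-to-end path, assuming the two variables below are the only configuration required (the model name again mirrors the tests and is only an example):

import os

from drd.api.openai_api import call_api_with_pagination, stream_response

# Route the generic entry points to the Ollama branch; the variable names come from this
# diff, and "starcoder" stands in for any locally pulled Ollama model.
os.environ["DRAVID_LLM"] = "ollama"
os.environ["DRAVID_LLM_MODEL"] = "starcoder"

# Non-streaming: dispatches to call_ollama_api_with_pagination.
print(call_api_with_pagination("Summarise the project in one sentence."))

# Streaming: yields chunks from stream_ollama_response as they arrive.
for chunk in stream_response("List three edge cases for a JSON parser."):
    print(chunk, end="", flush=True)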
150 changes: 149 additions & 1 deletion tests/api/test_openai_api.py
@@ -1,4 +1,5 @@
import unittest
import requests
from unittest.mock import patch, MagicMock
import os
from openai import OpenAI, AzureOpenAI
@@ -22,7 +23,77 @@ def setUp(self):
        self.query = "Test query"
        self.image_path = "test_image.jpg"

    # ... (keep the existing tests for get_env_variable, get_client, get_model, and parse_response) ...
@patch.dict(os.environ, {"OPENAI_API_KEY": "test_openai_api_key"})
def test_get_env_variable_existing(self):
self.assertEqual(get_env_variable(
"OPENAI_API_KEY"), "test_openai_api_key")

@patch.dict(os.environ, {}, clear=True)
def test_get_env_variable_missing(self):
with self.assertRaises(ValueError):
get_env_variable("NON_EXISTENT_VAR")

@patch.dict(os.environ, {"DRAVID_LLM": "openai", "OPENAI_API_KEY": "test_key"})
def test_get_client_openai(self):
client = get_client()
self.assertIsInstance(client, OpenAI)
self.assertEqual(client.api_key, "test_key")

@patch.dict(os.environ, {
"DRAVID_LLM": "azure",
"AZURE_OPENAI_API_KEY": "test_azure_key",
"AZURE_OPENAI_API_VERSION": "2023-05-15",
"AZURE_OPENAI_ENDPOINT": "https://test.openai.azure.com"
})
def test_get_client_azure(self):
client = get_client()
self.assertIsInstance(client, AzureOpenAI)
# Note: We can't directly access these attributes in the new OpenAI client
# Instead, we can check if the client was initialized correctly
self.assertTrue(isinstance(client, AzureOpenAI))

@patch.dict(os.environ, {
"DRAVID_LLM": "azure",
"AZURE_OPENAI_DEPLOYMENT_NAME": "test-deployment"
})
def test_get_model_azure(self):
model = get_model()
self.assertEqual(model, "test-deployment")

@patch.dict(os.environ, {
"DRAVID_LLM": "custom",
"DRAVID_LLM_API_KEY": "test_custom_key",
"DRAVID_LLM_ENDPOINT": "https://custom-llm-endpoint.com"
})
def test_get_client_custom(self):
client = get_client()
self.assertIsInstance(client, OpenAI)
self.assertEqual(client.api_key, "test_custom_key")
self.assertEqual(client.base_url, "https://custom-llm-endpoint.com")

@patch.dict(os.environ, {"DRAVID_LLM": "openai", "OPENAI_MODEL": "gpt-4"})
def test_get_model_openai(self):
self.assertEqual(get_model(), "gpt-4")

@patch.dict(os.environ, {"DRAVID_LLM": "azure", "AZURE_OPENAI_DEPLOYMENT_NAME": "test-deployment"})
def test_get_model_azure(self):
self.assertEqual(get_model(), "test-deployment")

@patch.dict(os.environ, {"DRAVID_LLM": "custom", "DRAVID_LLM_MODEL": "llama-3"})
def test_get_model_custom(self):
self.assertEqual(get_model(), "llama-3")

def test_parse_response_valid_xml(self):
xml_response = "<response><content>Test content</content></response>"
parsed = parse_response(xml_response)
self.assertEqual(parsed, xml_response)

@patch('drd.api.openai_api.click.echo')
def test_parse_response_invalid_xml(self, mock_echo):
invalid_xml = "Not XML"
parsed = parse_response(invalid_xml)
self.assertEqual(parsed, invalid_xml)
mock_echo.assert_called_once()

@patch('drd.api.openai_api.get_client')
@patch('drd.api.openai_api.get_model')
@@ -99,6 +170,83 @@ def test_stream_response(self, mock_get_model, mock_get_client):
        self.assertEqual(call_args['messages'][1]['content'], self.query)
        self.assertTrue(call_args['stream'])

@patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"})
def test_get_client_ollama(self):
client = get_client()
self.assertIsNone(client) # Ollama doesn't use a client object

@patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"})
def test_get_model_ollama(self):
model = get_model()
self.assertEqual(model, "starcoder")

@patch('requests.post')
@patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"})
def test_call_api_with_pagination_ollama(self, mock_post):
mock_response = MagicMock()
mock_response.json.return_value = {
"response": "<response>Test Ollama response</response>"}
mock_post.return_value = mock_response

response = call_api_with_pagination(self.query)
self.assertEqual(response, "<response>Test Ollama response</response>")

mock_post.assert_called_once_with(
"http://localhost:11434/api/generate",
json={
"model": "starcoder",
"prompt": self.query,
"system": "",
"stream": False
}
)

@patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"})
def test_call_vision_api_with_pagination_ollama(self):
with self.assertRaises(NotImplementedError):
call_vision_api_with_pagination(self.query, self.image_path)

@patch('requests.post')
@patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"})
def test_stream_response_ollama(self, mock_post):
mock_response = MagicMock()
mock_response.iter_lines.return_value = [
b'{"response":"Test"}',
b'{"response":" stream"}',
b'{"done":true}'
]
mock_post.return_value = mock_response

result = list(stream_response(self.query))
self.assertEqual(result, ["Test", " stream"])

mock_post.assert_called_once_with(
"http://localhost:11434/api/generate",
json={
"model": "starcoder",
"prompt": self.query,
"system": "",
"stream": True
},
stream=True
)

    @patch('requests.post')
    @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"})
    def test_call_api_with_pagination_ollama_error(self, mock_post):
        mock_post.side_effect = requests.RequestException("Ollama API error")

        with self.assertRaises(requests.RequestException):
            call_api_with_pagination(self.query)

    @patch('requests.post')
    @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"})
    def test_stream_response_ollama_error(self, mock_post):
        mock_post.side_effect = requests.RequestException("Ollama API error")

        with self.assertRaises(requests.RequestException):
            list(stream_response(self.query))


if __name__ == '__main__':
    unittest.main()
