diff --git a/src/drd/api/ollama_api.py b/src/drd/api/ollama_api.py
new file mode 100644
index 0000000..b6475ee
--- /dev/null
+++ b/src/drd/api/ollama_api.py
@@ -0,0 +1,59 @@
+import requests
+from typing import Generator, Optional
+import json
+
+OLLAMA_ENDPOINT = "http://localhost:11434/api"
+
+
+def get_ollama_client():
+ # Placeholder to keep the interface consistent with the other API modules;
+ # Ollama's HTTP API does not require a client object.
+ return None
+
+
+def call_ollama_api(model: str, prompt: str, system_prompt: str = "") -> str:
+ data = {
+ "model": model,
+ "prompt": prompt,
+ "system": system_prompt,
+ "stream": False
+ }
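+ # With "stream": False, /api/generate returns a single JSON object whose
+ # "response" field holds the full completion text.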
+ response = requests.post(f"{OLLAMA_ENDPOINT}/generate", json=data)
+ response.raise_for_status()
+ return response.json()["response"]
+
+
+def stream_ollama_response(model: str, prompt: str, system_prompt: str = "") -> Generator[str, None, None]:
+ data = {
+ "model": model,
+ "prompt": prompt,
+ "system": system_prompt,
+ "stream": True
+ }
+ response = requests.post(
+ f"{OLLAMA_ENDPOINT}/generate", json=data, stream=True)
+ response.raise_for_status()
+
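+ # Ollama streams newline-delimited JSON chunks; the final chunk reports
+ # "done": true with no response text, so the check below skips it.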
+ for line in response.iter_lines():
+ if line:
+ chunk = json.loads(line)
+ if chunk.get("response"):
+ yield chunk["response"]
+
+
+def call_ollama_api_with_pagination(query: str, model: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str:
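+ # include_context is accepted for signature parity with the other backends
+ # but is not used here; the full query is sent to Ollama in one request.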
+ full_response = call_ollama_api(model, query, instruction_prompt or "")
+ return full_response
+
+# Note: image input is not wired up for the Ollama backend here. Vision tasks
+# would need a multimodal model and a different request format.
+
+
+def call_ollama_vision_api_with_pagination(query: str, image_path: str, model: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str:
+ raise NotImplementedError("Vision API is not supported for Ollama models")
diff --git a/src/drd/api/openai_api.py b/src/drd/api/openai_api.py
index 5f0770e..a9b0436 100644
--- a/src/drd/api/openai_api.py
+++ b/src/drd/api/openai_api.py
@@ -6,6 +6,7 @@
from ..utils.parser import extract_and_parse_xml, parse_dravid_response
import xml.etree.ElementTree as ET
import click
+from .ollama_api import get_ollama_client, call_ollama_api_with_pagination, stream_ollama_response
DEFAULT_MODEL = "gpt-4o-2024-05-13"
MAX_TOKENS = 4000
@@ -33,6 +34,8 @@ def get_client():
api_key = get_env_variable("DRAVID_LLM_API_KEY")
api_base = get_env_variable("DRAVID_LLM_ENDPOINT")
return OpenAI(api_key=api_key, base_url=api_base)
+ elif llm_type == 'ollama':
+ return get_ollama_client()
else:
raise ValueError(f"Unsupported LLM type: {llm_type}")
@@ -41,7 +44,7 @@ def get_model():
llm_type = get_env_variable('DRAVID_LLM', 'openai').lower()
if llm_type == 'azure':
return get_env_variable("AZURE_OPENAI_DEPLOYMENT_NAME")
- elif llm_type == 'custom':
+ elif llm_type in ('custom', 'ollama'):
return get_env_variable("DRAVID_LLM_MODEL")
else:
return get_env_variable("OPENAI_MODEL", DEFAULT_MODEL)
@@ -57,8 +60,14 @@ def parse_response(response: str) -> str:
def call_api_with_pagination(query: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str:
- client = get_client()
+ llm_type = get_env_variable('DRAVID_LLM', 'openai').lower()
model = get_model()
+
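+ # Requests to Ollama bypass the OpenAI client and go straight to its local HTTP API.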
+ if llm_type == 'ollama':
+ return call_ollama_api_with_pagination(query, model, include_context, instruction_prompt)
+
+ client = get_client()
full_response = ""
messages = [
{"role": "system", "content": instruction_prompt or ""},
@@ -83,6 +92,11 @@ def call_api_with_pagination(query: str, include_context: bool = False, instruct
def call_vision_api_with_pagination(query: str, image_path: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str:
+ llm_type = get_env_variable('DRAVID_LLM', 'openai').lower()
+ if llm_type == 'ollama':
+ raise NotImplementedError(
+ "Vision API is not supported for Ollama models")
+
client = get_client()
model = get_model()
with open(image_path, "rb") as image_file:
@@ -119,8 +133,14 @@ def call_vision_api_with_pagination(query: str, image_path: str, include_context
def stream_response(query: str, instruction_prompt: Optional[str] = None) -> Generator[str, None, None]:
- client = get_client()
+ llm_type = get_env_variable('DRAVID_LLM', 'openai').lower()
model = get_model()
+
+ if llm_type == 'ollama':
+ yield from stream_ollama_response(model, query, instruction_prompt or "")
+ return
+
+ client = get_client()
messages = [
{"role": "system", "content": instruction_prompt or ""},
{"role": "user", "content": query}
diff --git a/tests/api/test_openai_api.py b/tests/api/test_openai_api.py
index 39514cc..88883fb 100644
--- a/tests/api/test_openai_api.py
+++ b/tests/api/test_openai_api.py
@@ -1,4 +1,5 @@
import unittest
+import requests
from unittest.mock import patch, MagicMock
import os
from openai import OpenAI, AzureOpenAI
@@ -22,7 +23,68 @@ def setUp(self):
self.query = "Test query"
self.image_path = "test_image.jpg"
- # ... (keep the existing tests for get_env_variable, get_client, get_model, and parse_response) ...
+ @patch.dict(os.environ, {"OPENAI_API_KEY": "test_openai_api_key"})
+ def test_get_env_variable_existing(self):
+ self.assertEqual(get_env_variable(
+ "OPENAI_API_KEY"), "test_openai_api_key")
+
+ @patch.dict(os.environ, {}, clear=True)
+ def test_get_env_variable_missing(self):
+ with self.assertRaises(ValueError):
+ get_env_variable("NON_EXISTENT_VAR")
+
+ @patch.dict(os.environ, {"DRAVID_LLM": "openai", "OPENAI_API_KEY": "test_key"})
+ def test_get_client_openai(self):
+ client = get_client()
+ self.assertIsInstance(client, OpenAI)
+ self.assertEqual(client.api_key, "test_key")
+
+ @patch.dict(os.environ, {
+ "DRAVID_LLM": "azure",
+ "AZURE_OPENAI_API_KEY": "test_azure_key",
+ "AZURE_OPENAI_API_VERSION": "2023-05-15",
+ "AZURE_OPENAI_ENDPOINT": "https://test.openai.azure.com"
+ })
+ def test_get_client_azure(self):
+ client = get_client()
+ self.assertIsInstance(client, AzureOpenAI)
+ # The new OpenAI client doesn't expose these settings as simple attributes,
+ # so we just verify that an AzureOpenAI client was constructed.
+
+ @patch.dict(os.environ, {
+ "DRAVID_LLM": "custom",
+ "DRAVID_LLM_API_KEY": "test_custom_key",
+ "DRAVID_LLM_ENDPOINT": "https://custom-llm-endpoint.com"
+ })
+ def test_get_client_custom(self):
+ client = get_client()
+ self.assertIsInstance(client, OpenAI)
+ self.assertEqual(client.api_key, "test_custom_key")
+ self.assertEqual(str(client.base_url), "https://custom-llm-endpoint.com/")  # the client normalizes base_url with a trailing slash
+
+ @patch.dict(os.environ, {"DRAVID_LLM": "openai", "OPENAI_MODEL": "gpt-4"})
+ def test_get_model_openai(self):
+ self.assertEqual(get_model(), "gpt-4")
+
+ @patch.dict(os.environ, {"DRAVID_LLM": "azure", "AZURE_OPENAI_DEPLOYMENT_NAME": "test-deployment"})
+ def test_get_model_azure(self):
+ self.assertEqual(get_model(), "test-deployment")
+
+ @patch.dict(os.environ, {"DRAVID_LLM": "custom", "DRAVID_LLM_MODEL": "llama-3"})
+ def test_get_model_custom(self):
+ self.assertEqual(get_model(), "llama-3")
+
+ def test_parse_response_valid_xml(self):
+ xml_response = "<response>Test content</response>"
+ parsed = parse_response(xml_response)
+ self.assertEqual(parsed, xml_response)
+
+ @patch('drd.api.openai_api.click.echo')
+ def test_parse_response_invalid_xml(self, mock_echo):
+ invalid_xml = "Not XML"
+ parsed = parse_response(invalid_xml)
+ self.assertEqual(parsed, invalid_xml)
+ mock_echo.assert_called_once()
@patch('drd.api.openai_api.get_client')
@patch('drd.api.openai_api.get_model')
@@ -99,6 +161,85 @@ def test_stream_response(self, mock_get_model, mock_get_client):
self.assertEqual(call_args['messages'][1]['content'], self.query)
self.assertTrue(call_args['stream'])
+ @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"})
+ def test_get_client_ollama(self):
+ client = get_client()
+ self.assertIsNone(client) # Ollama doesn't use a client object
+
+ @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"})
+ def test_get_model_ollama(self):
+ model = get_model()
+ self.assertEqual(model, "starcoder")
+
+ @patch('requests.post')
+ @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"})
+ def test_call_api_with_pagination_ollama(self, mock_post):
+ mock_response = MagicMock()
+ mock_response.json.return_value = {
+ "response": "Test Ollama response"}
+ mock_post.return_value = mock_response
+
+ response = call_api_with_pagination(self.query)
+ self.assertEqual(response, "Test Ollama response")
+
+ mock_post.assert_called_once_with(
+ "http://localhost:11434/api/generate",
+ json={
+ "model": "starcoder",
+ "prompt": self.query,
+ "system": "",
+ "stream": False
+ }
+ )
+
+ @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"})
+ def test_call_vision_api_with_pagination_ollama(self):
+ with self.assertRaises(NotImplementedError):
+ call_vision_api_with_pagination(self.query, self.image_path)
+
+ @patch('requests.post')
+ @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"})
+ def test_stream_response_ollama(self, mock_post):
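+ # Simulate Ollama's newline-delimited JSON stream; the final "done" chunk
+ # carries no response text and should be excluded from the output.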
+ mock_response = MagicMock()
+ mock_response.iter_lines.return_value = [
+ b'{"response":"Test"}',
+ b'{"response":" stream"}',
+ b'{"done":true}'
+ ]
+ mock_post.return_value = mock_response
+
+ result = list(stream_response(self.query))
+ self.assertEqual(result, ["Test", " stream"])
+
+ mock_post.assert_called_once_with(
+ "http://localhost:11434/api/generate",
+ json={
+ "model": "starcoder",
+ "prompt": self.query,
+ "system": "",
+ "stream": True
+ },
+ stream=True
+ )
+
+ @patch('requests.post')
+ @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"})
+ def test_call_api_with_pagination_ollama_error(self, mock_post):
+ mock_post.side_effect = requests.RequestException("Ollama API error")
+
+ with self.assertRaises(requests.RequestException):
+ call_api_with_pagination(self.query)
+
+ @patch('requests.post')
+ @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"})
+ def test_stream_response_ollama_error(self, mock_post):
+ mock_post.side_effect = requests.RequestException("Ollama API error")
+
+ with self.assertRaises(requests.RequestException):
+ list(stream_response(self.query))
+
if __name__ == '__main__':
unittest.main()