diff --git a/src/drd/api/ollama_api.py b/src/drd/api/ollama_api.py new file mode 100644 index 0000000..b6475ee --- /dev/null +++ b/src/drd/api/ollama_api.py @@ -0,0 +1,53 @@ +import requests +from typing import Dict, Any, Generator, Optional +import json + +OLLAMA_ENDPOINT = "http://localhost:11434/api" + + +def get_ollama_client(): + # This is a placeholder function to maintain consistency with other APIs + # Ollama doesn't require a client object, but we'll use this to set up any necessary configurations + return None + + +def call_ollama_api(model: str, prompt: str, system_prompt: str = "") -> str: + data = { + "model": model, + "prompt": prompt, + "system": system_prompt, + "stream": False + } + response = requests.post(f"{OLLAMA_ENDPOINT}/generate", json=data) + response.raise_for_status() + return response.json()["response"] + + +def stream_ollama_response(model: str, prompt: str, system_prompt: str = "") -> Generator[str, None, None]: + data = { + "model": model, + "prompt": prompt, + "system": system_prompt, + "stream": True + } + response = requests.post( + f"{OLLAMA_ENDPOINT}/generate", json=data, stream=True) + response.raise_for_status() + + for line in response.iter_lines(): + if line: + chunk = json.loads(line) + if chunk.get("response"): + yield chunk["response"] + + +def call_ollama_api_with_pagination(query: str, model: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str: + full_response = call_ollama_api(model, query, instruction_prompt or "") + return full_response + +# Note: Ollama doesn't have built-in support for image input like OpenAI. +# For vision-related tasks, we'd need to use a different approach or model. + + +def call_ollama_vision_api_with_pagination(query: str, image_path: str, model: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str: + raise NotImplementedError("Vision API is not supported for Ollama models") diff --git a/src/drd/api/openai_api.py b/src/drd/api/openai_api.py index 5f0770e..a9b0436 100644 --- a/src/drd/api/openai_api.py +++ b/src/drd/api/openai_api.py @@ -6,6 +6,7 @@ from ..utils.parser import extract_and_parse_xml, parse_dravid_response import xml.etree.ElementTree as ET import click +from .ollama_api import get_ollama_client, call_ollama_api_with_pagination, stream_ollama_response DEFAULT_MODEL = "gpt-4o-2024-05-13" MAX_TOKENS = 4000 @@ -33,6 +34,8 @@ def get_client(): api_key = get_env_variable("DRAVID_LLM_API_KEY") api_base = get_env_variable("DRAVID_LLM_ENDPOINT") return OpenAI(api_key=api_key, base_url=api_base) + elif llm_type == 'ollama': + return get_ollama_client() else: raise ValueError(f"Unsupported LLM type: {llm_type}") @@ -41,7 +44,7 @@ def get_model(): llm_type = get_env_variable('DRAVID_LLM', 'openai').lower() if llm_type == 'azure': return get_env_variable("AZURE_OPENAI_DEPLOYMENT_NAME") - elif llm_type == 'custom': + elif llm_type == 'custom' or llm_type == 'ollama': return get_env_variable("DRAVID_LLM_MODEL") else: return get_env_variable("OPENAI_MODEL", DEFAULT_MODEL) @@ -57,8 +60,13 @@ def parse_response(response: str) -> str: def call_api_with_pagination(query: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str: - client = get_client() + llm_type = get_env_variable('DRAVID_LLM', 'openai').lower() model = get_model() + + if llm_type == 'ollama': + return call_ollama_api_with_pagination(query, model, include_context, instruction_prompt) + + client = get_client() full_response = "" messages = [ {"role": "system", "content": instruction_prompt or ""}, @@ -83,6 +91,11 @@ def call_api_with_pagination(query: str, include_context: bool = False, instruct def call_vision_api_with_pagination(query: str, image_path: str, include_context: bool = False, instruction_prompt: Optional[str] = None) -> str: + llm_type = get_env_variable('DRAVID_LLM', 'openai').lower() + if llm_type == 'ollama': + raise NotImplementedError( + "Vision API is not supported for Ollama models") + client = get_client() model = get_model() with open(image_path, "rb") as image_file: @@ -119,8 +132,14 @@ def call_vision_api_with_pagination(query: str, image_path: str, include_context def stream_response(query: str, instruction_prompt: Optional[str] = None) -> Generator[str, None, None]: - client = get_client() + llm_type = get_env_variable('DRAVID_LLM', 'openai').lower() model = get_model() + + if llm_type == 'ollama': + yield from stream_ollama_response(model, query, instruction_prompt or "") + return + + client = get_client() messages = [ {"role": "system", "content": instruction_prompt or ""}, {"role": "user", "content": query} diff --git a/tests/api/test_openai_api.py b/tests/api/test_openai_api.py index 39514cc..88883fb 100644 --- a/tests/api/test_openai_api.py +++ b/tests/api/test_openai_api.py @@ -1,4 +1,5 @@ import unittest +import requests from unittest.mock import patch, MagicMock import os from openai import OpenAI, AzureOpenAI @@ -22,7 +23,77 @@ def setUp(self): self.query = "Test query" self.image_path = "test_image.jpg" - # ... (keep the existing tests for get_env_variable, get_client, get_model, and parse_response) ... + @patch.dict(os.environ, {"OPENAI_API_KEY": "test_openai_api_key"}) + def test_get_env_variable_existing(self): + self.assertEqual(get_env_variable( + "OPENAI_API_KEY"), "test_openai_api_key") + + @patch.dict(os.environ, {}, clear=True) + def test_get_env_variable_missing(self): + with self.assertRaises(ValueError): + get_env_variable("NON_EXISTENT_VAR") + + @patch.dict(os.environ, {"DRAVID_LLM": "openai", "OPENAI_API_KEY": "test_key"}) + def test_get_client_openai(self): + client = get_client() + self.assertIsInstance(client, OpenAI) + self.assertEqual(client.api_key, "test_key") + + @patch.dict(os.environ, { + "DRAVID_LLM": "azure", + "AZURE_OPENAI_API_KEY": "test_azure_key", + "AZURE_OPENAI_API_VERSION": "2023-05-15", + "AZURE_OPENAI_ENDPOINT": "https://test.openai.azure.com" + }) + def test_get_client_azure(self): + client = get_client() + self.assertIsInstance(client, AzureOpenAI) + # Note: We can't directly access these attributes in the new OpenAI client + # Instead, we can check if the client was initialized correctly + self.assertTrue(isinstance(client, AzureOpenAI)) + + @patch.dict(os.environ, { + "DRAVID_LLM": "azure", + "AZURE_OPENAI_DEPLOYMENT_NAME": "test-deployment" + }) + def test_get_model_azure(self): + model = get_model() + self.assertEqual(model, "test-deployment") + + @patch.dict(os.environ, { + "DRAVID_LLM": "custom", + "DRAVID_LLM_API_KEY": "test_custom_key", + "DRAVID_LLM_ENDPOINT": "https://custom-llm-endpoint.com" + }) + def test_get_client_custom(self): + client = get_client() + self.assertIsInstance(client, OpenAI) + self.assertEqual(client.api_key, "test_custom_key") + self.assertEqual(client.base_url, "https://custom-llm-endpoint.com") + + @patch.dict(os.environ, {"DRAVID_LLM": "openai", "OPENAI_MODEL": "gpt-4"}) + def test_get_model_openai(self): + self.assertEqual(get_model(), "gpt-4") + + @patch.dict(os.environ, {"DRAVID_LLM": "azure", "AZURE_OPENAI_DEPLOYMENT_NAME": "test-deployment"}) + def test_get_model_azure(self): + self.assertEqual(get_model(), "test-deployment") + + @patch.dict(os.environ, {"DRAVID_LLM": "custom", "DRAVID_LLM_MODEL": "llama-3"}) + def test_get_model_custom(self): + self.assertEqual(get_model(), "llama-3") + + def test_parse_response_valid_xml(self): + xml_response = "Test content" + parsed = parse_response(xml_response) + self.assertEqual(parsed, xml_response) + + @patch('drd.api.openai_api.click.echo') + def test_parse_response_invalid_xml(self, mock_echo): + invalid_xml = "Not XML" + parsed = parse_response(invalid_xml) + self.assertEqual(parsed, invalid_xml) + mock_echo.assert_called_once() @patch('drd.api.openai_api.get_client') @patch('drd.api.openai_api.get_model') @@ -99,6 +170,83 @@ def test_stream_response(self, mock_get_model, mock_get_client): self.assertEqual(call_args['messages'][1]['content'], self.query) self.assertTrue(call_args['stream']) + @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"}) + def test_get_client_ollama(self): + client = get_client() + self.assertIsNone(client) # Ollama doesn't use a client object + + @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"}) + def test_get_model_ollama(self): + model = get_model() + self.assertEqual(model, "starcoder") + + @patch('requests.post') + @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"}) + def test_call_api_with_pagination_ollama(self, mock_post): + mock_response = MagicMock() + mock_response.json.return_value = { + "response": "Test Ollama response"} + mock_post.return_value = mock_response + + response = call_api_with_pagination(self.query) + self.assertEqual(response, "Test Ollama response") + + mock_post.assert_called_once_with( + "http://localhost:11434/api/generate", + json={ + "model": "starcoder", + "prompt": self.query, + "system": "", + "stream": False + } + ) + + @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"}) + def test_call_vision_api_with_pagination_ollama(self): + with self.assertRaises(NotImplementedError): + call_vision_api_with_pagination(self.query, self.image_path) + + @patch('requests.post') + @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"}) + def test_stream_response_ollama(self, mock_post): + mock_response = MagicMock() + mock_response.iter_lines.return_value = [ + b'{"response":"Test"}', + b'{"response":" stream"}', + b'{"done":true}' + ] + mock_post.return_value = mock_response + + result = list(stream_response(self.query)) + self.assertEqual(result, ["Test", " stream"]) + + mock_post.assert_called_once_with( + "http://localhost:11434/api/generate", + json={ + "model": "starcoder", + "prompt": self.query, + "system": "", + "stream": True + }, + stream=True + ) + + @patch('requests.post') + @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"}) + def test_call_api_with_pagination_ollama_error(self, mock_post): + mock_post.side_effect = requests.RequestException("Ollama API error") + + with self.assertRaises(requests.RequestException): + call_api_with_pagination(self.query) + + @patch('requests.post') + @patch.dict(os.environ, {"DRAVID_LLM": "ollama", "DRAVID_LLM_MODEL": "starcoder"}) + def test_stream_response_ollama_error(self, mock_post): + mock_post.side_effect = requests.RequestException("Ollama API error") + + with self.assertRaises(requests.RequestException): + list(stream_response(self.query)) + if __name__ == '__main__': unittest.main()