
Commit 9013986
centralizing provider initialization, adding all supported providers, adding additional untested agent modules, migrated to functions for provider init.
Your Name committed Jun 17, 2024
1 parent acc55b8 commit 9013986
Showing 1 changed file with 44 additions and 37 deletions.
81 changes: 44 additions & 37 deletions app/modules/llm_provider.py
@@ -1,6 +1,13 @@
import os
from srt_core.config import Config
from srt_core.utils.logger import Logger
from llama_cpp_agent.providers import (
    LlamaCppServerProvider,
    VLLMServerProvider,
    TGIServerProvider,
    LlamaCppPythonProvider,
    GroqProvider
)
from llama_cpp import Llama

class LLMProvider:
    def __init__(self):

@@ -18,6 +25,8 @@ def load_llm_settings(self, llm_name):
"agent_provider": llm_config["agent_provider"],
"server_name": llm_config["server"],
"max_tokens": llm_config["max_tokens"],
"model": llm_config.get("model", ""),
"api_key": llm_config.get("api_key", "")
}

    def set_llm_attributes(self, llm_prefix, llm_settings):
@@ -38,49 +47,47 @@ def load_provider_settings(self):
        self.set_llm_attributes('summary', self.summary_llm_settings)
        self.set_llm_attributes('chat', self.chat_llm_settings)

if "llama_cpp_server" in self.default_llm_settings["agent_provider"]:
from llama_cpp_agent.providers import LlamaCppServerProvider
self.default_provider = LlamaCppServerProvider(self.default_llm_settings["url"])
self.summary_provider = LlamaCppServerProvider(self.summary_llm_settings["url"])
self.chat_provider = LlamaCppServerProvider(self.chat_llm_settings["url"])
elif "llama_cpp_python" in self.default_llm_settings["agent_provider"]:
from llama_cpp import Llama
from llama_cpp_agent.providers import LlamaCppPythonProvider
python_cpp_llm = Llama(
model_path=f"models/{self.default_llm_settings['filename']}",
self.default_provider = self._initialize_provider(self.default_llm_settings)
self.summary_provider = self._initialize_provider(self.summary_llm_settings)
self.chat_provider = self._initialize_provider(self.chat_llm_settings)

def _initialize_provider(self, llm_settings):
self.logger.info(f"Initializing provider with settings: {llm_settings}")
provider_type = llm_settings["agent_provider"]

if provider_type == "vllm_server":
return VLLMServerProvider(
llm_settings["url"],
llm_settings["huggingface"],
llm_settings["huggingface"],
self.config.openai_compatible_api_key,
)
elif provider_type == "llama_cpp_server":
return LlamaCppServerProvider(llm_settings["url"])
elif provider_type == "tgi_server":
return TGIServerProvider(server_address=llm_settings["url"])
elif provider_type == "llama_cpp_python":
llama_model = Llama(
model_path=f"models/{llm_settings['filename']}",
flash_attn=True,
n_threads=40,
n_gpu_layers=81,
n_batch=1024,
n_ctx=self.default_llm_settings["max_tokens"],
)
self.default_provider = LlamaCppPythonProvider(python_cpp_llm)
self.summary_provider = LlamaCppPythonProvider(python_cpp_llm)
self.chat_provider = LlamaCppPythonProvider(python_cpp_llm)
elif "tgi_server" in self.default_llm_settings["agent_provider"]:
from llama_cpp_agent.providers import TGIServerProvider
self.default_provider = TGIServerProvider(server_address=self.default_llm_settings["url"])
self.summary_provider = TGIServerProvider(server_address=self.summary_llm_settings["url"])
self.chat_provider = TGIServerProvider(server_address=self.chat_llm_settings["url"])
elif "vllm_server" in self.default_llm_settings["agent_provider"]:
from llama_cpp_agent.providers import VLLMServerProvider
self.default_provider = VLLMServerProvider(
base_url=self.default_llm_settings["url"],
model=self.default_llm_settings["huggingface"],
huggingface_model=self.default_llm_settings["huggingface"],
)
self.summary_provider = VLLMServerProvider(
base_url=self.summary_llm_settings["url"],
model=self.summary_llm_settings["huggingface"],
huggingface_model=self.summary_llm_settings["huggingface"],
n_ctx=llm_settings["max_tokens"],
)
self.chat_provider = VLLMServerProvider(
base_url=self.chat_llm_settings["url"],
model=self.chat_llm_settings["huggingface"],
huggingface_model=self.chat_llm_settings["huggingface"],
return LlamaCppPythonProvider(llama_model)
elif provider_type == "groq":
return GroqProvider(
base_url=llm_settings["url"],
model=llm_settings["model"],
huggingface_model=llm_settings["huggingface"],
api_key=llm_settings["api_key"]
)
elif provider_type == "llama_cpp_python_server":
return LlamaCppServerProvider(llm_settings["url"], llama_cpp_python_server=True)
else:
self.logger.error(f"Unsupported llama-cpp-agent provider: {self.default_llm_settings['agent_provider']}, {self.summary_llm_settings['agent_provider']}, {self.chat_llm_settings['agent_provider']}")
self.logger.error(f"Unsupported provider: {provider_type}")
raise ValueError(f"Unsupported provider: {provider_type}")

# Example usage
if __name__ == "__main__":
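The body of the `__main__` block is collapsed in this view. What follows is a minimal usage sketch, not the file's actual example: it assumes `LLMProvider.__init__` takes no arguments and wires up `Config`, `Logger`, and `load_provider_settings` itself, none of which is visible in this diff.

    # Hypothetical usage sketch; the real example block is collapsed above.
    provider_manager = LLMProvider()          # assumed to load config and build all three providers
    print(provider_manager.default_provider)  # e.g. a LlamaCppServerProvider instance
    print(provider_manager.summary_provider)
    print(provider_manager.chat_provider)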

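For reference, the keys read by `load_llm_settings` and `_initialize_provider` imply a per-model config entry shaped roughly like the dict below. This is an inferred sketch, not the project's actual config format; only the key names come from the code above, and all values are placeholders.

# Inferred shape of one llm_config entry, based solely on the keys the
# diff reads: url, agent_provider, server, max_tokens, model, api_key,
# huggingface, filename. Values here are illustrative placeholders.
example_llm_config = {
    "url": "http://localhost:8000",
    "agent_provider": "vllm_server",  # must match a branch in _initialize_provider
    "server": "local-vllm",
    "max_tokens": 4096,
    "model": "example-model-id",      # read only by the groq branch
    "api_key": "YOUR_API_KEY",        # read only by the groq branch
    "huggingface": "example-org/example-model",
    "filename": "model.gguf",         # read only by the llama_cpp_python branch
}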