Add ollama support to llm_core #349

Open

wants to merge 13 commits into develop
9 changes: 9 additions & 0 deletions llm_core/llm_core/models/__init__.py
@@ -7,6 +7,7 @@

DefaultModelConfig: Type[ModelConfig]
MiniModelConfig: ModelConfig
OllamaModelConfig: ModelConfig
default_model_name = os.environ.get("LLM_DEFAULT_MODEL")
evaluation_model_name = os.environ.get("LLM_EVALUATION_MODEL")

@@ -24,6 +25,14 @@
except AttributeError:
pass

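# Optional Ollama provider: register its configuration and expose a ready-made default instance when available.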
try:
import llm_core.models.ollama as ollama_config  # type: ignore
types.append(ollama_config.OllamaModelConfig)
OllamaModelConfig = ollama_config.OllamaModelConfig(model_name="llama3.1:70b", max_tokens=1000, temperature=0, top_p=1, presence_penalty=0, frequency_penalty=0)
# DefaultModelConfig = ollama_config.OllamaModelConfig
except AttributeError:
pass

if not types:
raise EnvironmentError(
"No model configurations available, please set up at least one provider in the environment variables.")
115 changes: 115 additions & 0 deletions llm_core/llm_core/models/ollama.py
@@ -0,0 +1,115 @@
import requests
from enum import Enum
from llm_core.models.model_config import ModelConfig # type: ignore
from pydantic import validator, Field, PositiveInt
from langchain.base_language import BaseLanguageModel
import os
from langchain_community.chat_models import ChatOllama # type: ignore
from athena.logger import logger
from requests.exceptions import RequestException, Timeout

if os.environ.get('GPU_USER') and os.environ.get('GPU_PASSWORD') and os.environ.get('OLLAMA_ENDPOINT'):
try:
if os.environ["GPU_USER"] and os.environ["GPU_PASSWORD"]:
# Basic-auth header reused for every request to the protected Ollama endpoint.
auth_header = {
'Authorization': requests.auth._basic_auth_str(os.environ["GPU_USER"], os.environ["GPU_PASSWORD"])  # type: ignore
}


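# Ask the Ollama server which models it currently hosts (GET /api/tags).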
def get_ollama_models():
url = os.environ["OLLAMA_ENDPOINT"] + "/api/tags"
response = requests.get(url, auth=(os.environ["GPU_USER"], os.environ["GPU_PASSWORD"]), timeout=10)
data = response.json()
model_list = [model['name'] for model in data['models']]
return model_list

ollama_models = get_ollama_models()
available_models = {}

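# Pre-instantiate a ChatOllama client for every model reported by the server; get_model() later looks these up by name.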
if os.environ["OLLAMA_ENDPOINT"]:
available_models = {
name: ChatOllama(
name=name,
model=name,
base_url=os.environ["OLLAMA_ENDPOINT"],
headers=auth_header,
format="json"
) for name in ollama_models
}

default_model_name = "llama3.1:70b"
LlamaModel = Enum('LlamaModel', {name: name for name in available_models}) # type: ignore
logger.info("Available Ollama models: %s", ", ".join(available_models.keys()))

class OllamaModelConfig(ModelConfig):
"""Ollama LLM configuration."""

model_name: LlamaModel = Field(default=default_model_name, # type: ignore
description="The name of the model to use.")

format: str = Field(default="json", description="The format to respond with")

max_tokens: PositiveInt = Field(1000, description="Maximum number of tokens to generate.")

temperature: float = Field(default=0.0, ge=0, le=2, description="Sampling temperature; higher values make the output more random.")

top_p: float = Field(default=1, ge=0, le=1, description="Nucleus sampling probability mass.")

headers: dict = Field(default=auth_header, description="Headers used for authentication.")

presence_penalty: float = Field(default=0, ge=-2, le=2, description="Penalty for introducing tokens that already appear in the text so far.")

frequency_penalty: float = Field(default=0, ge=-2, le=2, description="Penalty for repeating tokens proportional to their frequency so far.")

base_url: str = Field(default="https://gpu-artemis.ase.cit.tum.de/ollama", description="Base URL where Ollama is hosted.")

@validator('max_tokens')
def max_tokens_must_be_positive(cls, v):
"""
Validate that max_tokens is a positive integer.
"""
if v <= 0:
raise ValueError('max_tokens must be a positive integer')
return v

def get_model(self) -> BaseLanguageModel:
"""Get the model from the configuration.

Returns:
BaseLanguageModel: The model.
"""
logger.info("Getting model: %s", self.model_name.value)

model = available_models[self.model_name.value]
kwargs = model.__dict__
secrets = {secret: getattr(model, secret) for secret in model.lc_secrets.keys()}
kwargs.update(secrets)

model_kwargs = kwargs.get("model_kwargs", {})
for attr, value in self.dict().items():
if attr == "model_name":
# Skip model_name
continue
if hasattr(model, attr):
# If the model has the attribute, add it to kwargs
kwargs[attr] = value
else:
# Otherwise, add it to model_kwargs (necessary for chat models)
model_kwargs[attr] = value
kwargs["model_kwargs"] = model_kwargs

allowed_fields = set(self.__fields__.keys())
filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed_fields}
filtered_kwargs["headers"] = auth_header
filtered_kwargs["model"]= self.model_name.value

# Initialize a copy of the model using the filtered kwargs
model = model.__class__(**filtered_kwargs)

return model


class Config:
title = 'Ollama'
except Timeout:
logger.warning("Connection to the Ollama endpoint timed out. Skipping Ollama model setup.")
except RequestException as e:
logger.warning("Failed to connect to the Ollama endpoint: %s. Skipping Ollama model setup.", e)
22 changes: 21 additions & 1 deletion llm_core/llm_core/utils/llm_utils.py
@@ -107,4 +107,24 @@ def get_chat_prompt_with_formatting_instructions(
system_message_prompt.prompt.partial_variables = {"format_instructions": output_parser.get_format_instructions()}
system_message_prompt.prompt.input_variables.remove("format_instructions")
human_message_prompt = HumanMessagePromptTemplate.from_template(human_message + "\n\nJSON response following the provided schema:")
return ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

def get_simple_chat_prompt(
system_message: str,
human_message: str,
) -> ChatPromptTemplate:

sys = """
You are an AI Tutor. You are tasked with grading a student submission based on this problem statement and grading instructions. You must not excede the maximum amount of points. Take time to think, which points on the grading instructions are relevant for the students submission.
Further more, if a feedback is specific to a sentence in the student submission, that specify this as well on your feedback. Also specify, when possible, which grading instruction you are refering to.
# Problem statement
{problem_statement}

# Grading instructions
{grading_instructions}
Max points: {max_points}
"""

system_message_prompt = SystemMessagePromptTemplate.from_template(system_message)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_message)
return ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
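A small usage sketch of the new helper (both message templates are hypothetical; the placeholders are filled by the caller):

from llm_core.utils.llm_utils import get_simple_chat_prompt

chat_prompt = get_simple_chat_prompt(
    system_message="You grade submissions for this exercise: {problem_statement}",
    human_message="Student submission:\n{submission}",
)
messages = chat_prompt.format_messages(problem_statement="...", submission="...")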
39 changes: 36 additions & 3 deletions llm_core/llm_core/utils/predict_and_parse.py
@@ -4,9 +4,35 @@
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.runnables import RunnableSequence
from athena import get_experiment_environment

from langchain_community.chat_models import ChatOllama  # type: ignore
from langchain.output_parsers import PydanticOutputParser

T = TypeVar("T", bound=BaseModel)

def is_ollama(model: BaseLanguageModel) -> bool:
"""Return True if the given model is served through Ollama."""
return isinstance(model, ChatOllama)

async def predict_plain_text(
model: BaseLanguageModel,
chat_prompt: ChatPromptTemplate,
prompt_input: dict,
tags: Optional[List[str]]) -> Optional[str]:
"""Predict plain text using the provided model and prompt.

Args:
model (BaseLanguageModel): The model to predict with.
chat_prompt (ChatPromptTemplate): The prompt template to use.
prompt_input (dict): Input parameters to use for the prompt.
tags (Optional[List[str]]): List of tags to tag the prediction with.

Returns:
Optional[str]: Prediction as a string, or None if it failed.
"""
try:
chain = chat_prompt | model
return await chain.ainvoke(prompt_input, config={"tags": tags})
except Exception as e:
raise ValueError("LLM prediction failed.") from e

async def predict_and_parse(
model: BaseLanguageModel,
chat_prompt: ChatPromptTemplate,
@@ -37,7 +63,14 @@ async def predict_and_parse(
if experiment.run_id is not None:
tags.append(f"run-{experiment.run_id}")


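# Ollama models are not routed through OpenAI-style function calling; their JSON output is parsed with a PydanticOutputParser instead.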
if is_ollama(model):
try:
output_parser = PydanticOutputParser(pydantic_object=pydantic_object)
chain = chat_prompt | model | output_parser
return await chain.ainvoke(prompt_input, config={"tags": tags})
except ValidationError as e:
raise ValueError(f"Could not parse output: {e}") from e

if (use_function_calling):
structured_output_llm = model.with_structured_output(pydantic_object)
chain = chat_prompt | structured_output_llm
@@ -63,4 +96,4 @@
return await chain.ainvoke(prompt_input, config={"tags": tags})
except ValidationError as e:
raise ValueError(f"Could not parse output: {e}") from e

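For illustration, a minimal call of the new plain-text helper (must run inside an async function; model and chat_prompt come from the configuration and helpers above, the input values are placeholders):

result = await predict_plain_text(
    model=model,
    chat_prompt=chat_prompt,
    prompt_input={"problem_statement": "...", "submission": "..."},
    tags=["example-run"],
)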
4 changes: 4 additions & 0 deletions modules/text/module_text_llm/.env.example
@@ -39,3 +39,7 @@ OPENAI_API_VERSION="2024-06-01" # change base if needed
# LANGCHAIN_ENDPOINT="https://api.smith.langchain.com"
# LANGCHAIN_API_KEY="XXX"
# LANGCHAIN_PROJECT="XXX"

GPU_USER=
GPU_PASSWORD=
OLLAMA_ENDPOINT= #'https://gpu-artemis.ase.cit.tum.de/ollama'
@@ -6,6 +6,7 @@
class ApproachType(str, Enum):
basic = "BasicApproach"
chain_of_thought = "ChainOfThought"
ollama_cot = "OllamaChainOfThought"

class ApproachConfig(BaseModel, ABC):
max_input_tokens: int = Field(default=3000, description="Maximum number of tokens in the input prompt.")
@@ -2,14 +2,17 @@
from athena.text import Exercise, Submission, Feedback
from module_text_llm.basic_approach import BasicApproachConfig
from module_text_llm.chain_of_thought_approach import ChainOfThoughtConfig
from module_text_llm.ollama_chain_of_thought_approach import OllamaChainOfThoughtConfig
from module_text_llm.approach_config import ApproachConfig

from module_text_llm.basic_approach.generate_suggestions import generate_suggestions as generate_suggestions_basic
from module_text_llm.chain_of_thought_approach.generate_suggestions import generate_suggestions as generate_cot_suggestions
from module_text_llm.ollama_chain_of_thought_approach.generate_suggestions import generate_suggestions as generate_cot_ollama_suggestions

async def generate_suggestions(exercise: Exercise, submission: Submission, config: ApproachConfig, debug: bool) -> List[Feedback]:
if isinstance(config, BasicApproachConfig):
return await generate_suggestions_basic(exercise, submission, config, debug)
if isinstance(config, ChainOfThoughtConfig):
return await generate_cot_suggestions(exercise, submission, config, debug)
if isinstance(config, OllamaChainOfThoughtConfig):
return await generate_cot_ollama_suggestions(exercise, submission, config, debug)
raise ValueError("Unsupported config type provided.")
@@ -9,7 +9,6 @@
num_tokens_from_prompt,
)
from llm_core.utils.predict_and_parse import predict_and_parse

from module_text_llm.config import BasicApproachConfig
from module_text_llm.helpers.utils import add_sentence_numbers, get_index_range_from_line_range, format_grading_instructions
from module_text_llm.basic_approach.prompt_generate_suggestions import AssessmentModel
@@ -1,14 +1,12 @@
from pydantic import BaseModel, Field
from typing import Literal
from llm_core.models import ModelConfigType, DefaultModelConfig

from module_text_llm.approach_config import ApproachConfig
from module_text_llm.chain_of_thought_approach.prompt_generate_feedback import CoTGenerateSuggestionsPrompt
from module_text_llm.chain_of_thought_approach.prompt_thinking import ThinkingPrompt

class ChainOfThoughtConfig(ApproachConfig):
type: Literal['chain_of_thought'] = 'chain_of_thought'
model: ModelConfigType = Field(default=DefaultModelConfig) # type: ignore
thikning_prompt: ThinkingPrompt = Field(default=ThinkingPrompt())
generate_suggestions_prompt: CoTGenerateSuggestionsPrompt = Field(default=CoTGenerateSuggestionsPrompt())

@@ -52,4 +52,5 @@ class AssessmentModel(BaseModel):
class AssessmentModel(BaseModel):
"""Collection of feedbacks making up an assessment"""

feedbacks: List[FeedbackModel] = Field(description="Assessment feedbacks")

3 changes: 2 additions & 1 deletion modules/text/module_text_llm/module_text_llm/config.py
@@ -4,8 +4,9 @@

from module_text_llm.chain_of_thought_approach import ChainOfThoughtConfig
from module_text_llm.basic_approach import BasicApproachConfig
from module_text_llm.ollama_chain_of_thought_approach import OllamaChainOfThoughtConfig

ApproachConfigUnion = Union[BasicApproachConfig, ChainOfThoughtConfig]
ApproachConfigUnion = Union[BasicApproachConfig, ChainOfThoughtConfig, OllamaChainOfThoughtConfig]

@config_schema_provider
class Configuration(BaseModel):
@@ -0,0 +1,18 @@
from pydantic import BaseModel, Field
from typing import Literal
from llm_core.models import ModelConfigType
try:
from llm_core.models import OllamaModelConfig
except ImportError as e:
print(f"Warning: Failed to import models. {e}")
OllamaModelConfig = None
from module_text_llm.approach_config import ApproachConfig
from module_text_llm.ollama_chain_of_thought_approach.prompt_generate_feedback import CoTGenerateSuggestionsPrompt
from module_text_llm.ollama_chain_of_thought_approach.prompt_thinking import ThinkingPrompt

class OllamaChainOfThoughtConfig(ApproachConfig):
type: Literal['ollama_chain_of_thought'] = 'ollama_chain_of_thought'
model: ModelConfigType = Field(default=OllamaModelConfig) # type: ignore
thikning_prompt: ThinkingPrompt = Field(default=ThinkingPrompt())
generate_suggestions_prompt: CoTGenerateSuggestionsPrompt = Field(default=CoTGenerateSuggestionsPrompt())

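With OllamaChainOfThoughtConfig registered in ApproachConfigUnion, the new approach is selected like the existing ones (sketch; runs inside an async context, with exercise and submission provided by Athena):

config = OllamaChainOfThoughtConfig()  # type "ollama_chain_of_thought", defaults to OllamaModelConfig
feedbacks = await generate_suggestions(exercise, submission, config, debug=False)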