Merge pull request #1131 from JohnSnowLabs/feature/add-support-for-chat-and-instruct-model-types

Feature/add support for chat and instruct model types
chakravarthik27 authored Oct 21, 2024
2 parents 3db1a94 + 572bbe4 commit 0d7589d
Showing 9 changed files with 203 additions and 48 deletions.
13 changes: 8 additions & 5 deletions langtest/langtest.py
@@ -13,6 +13,8 @@

from pkg_resources import resource_filename

from langtest.types import DatasetConfig, ModelConfig

from .tasks import TaskManager
from .augmentation import AugmentRobustness, TemplaticAugment
from .datahandler.datasource import DataFactory
@@ -90,8 +92,8 @@ class Harness:
def __init__(
self,
task: Union[str, dict],
model: Optional[Union[list, dict]] = None,
data: Optional[Union[list, dict]] = None,
model: Optional[Union[List[ModelConfig], ModelConfig]] = None,
data: Optional[Union[List[DatasetConfig], DatasetConfig]] = None,
config: Optional[Union[str, dict]] = None,
benchmarking: dict = None,
):
@@ -156,11 +158,12 @@ def __init__(
raise ValueError(Errors.E003())

if isinstance(model, dict):
hub, model = model["hub"], model["model"]
hub, model, model_type = model["hub"], model["model"], model.get("type")
self.hub = hub
self._actual_model = model
else:
hub = None
model_type = None

# loading task

@@ -215,14 +218,14 @@ def __init__(
hub = i["hub"]

model_dict[model] = self.task.model(
model, hub, **self._config.get("model_parameters", {})
model, hub, model_type, **self._config.get("model_parameters", {})
)

self.model = model_dict

else:
self.model = self.task.model(
model, hub, **self._config.get("model_parameters", {})
model, hub, model_type, **self._config.get("model_parameters", {})
)
# end model selection
formatted_config = json.dumps(self._config, indent=1)
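For context, a minimal sketch of how the new "type" key flows through the Harness constructor (dataset values are illustrative; an OpenAI API key is assumed to be configured):

from langtest import Harness

harness = Harness(
    task="question-answering",
    model={"hub": "openai", "model": "gpt-4o", "type": "chat"},
    data={"data_source": "BoolQ", "split": "test-tiny"},
)
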
32 changes: 24 additions & 8 deletions langtest/modelhandler/llm_modelhandler.py
@@ -1,6 +1,8 @@
import inspect

from typing import Any, List, Union
import langchain.llms as lc
import langchain.chat_models as chat_models
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_core.exceptions import OutputParserException
@@ -71,7 +73,9 @@ def load_model(cls, hub: str, path: str, *args, **kwargs) -> "PretrainedModelFor
ValueError: If the model is not found online or locally.
ConfigError: If there is an error in the model configuration.
"""
exclude_args = ["task", "device", "stream"]
exclude_args = ["task", "device", "stream", "model_type", "chat_template"]

model_type = kwargs.get("model_type", None)

filtered_kwargs = kwargs.copy()

@@ -93,24 +97,36 @@ def load_model(cls, hub: str, path: str, *args, **kwargs) -> "PretrainedModelFor
"gpt-3.5-turbo-1106",
"gpt-4o-2024-05-13",
"gpt-4o",
):
"o1-preview",
"o1-mini",
) and hub in ["openai", "azure-openai"]:
if hub == "openai":
from langchain_openai.chat_models import ChatOpenAI

model = ChatOpenAI(model=path, *args, **filtered_kwargs)
elif hub == "azure-openai":
from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain_openai.chat_models import AzureChatOpenAI

model = AzureChatOpenAI(model=path, *args, **filtered_kwargs)

return cls(hub, model, *args, **filtered_kwargs)
elif hub == "ollama":
from langchain.chat_models.ollama import ChatOllama
# elif hub == "ollama":
# from langchain.chat_models.ollama import ChatOllama

model = ChatOllama(model=path, *args, **filtered_kwargs)
return cls(hub, model, *args, **filtered_kwargs)
# model = ChatOllama(model=path, *args, **filtered_kwargs)
# return cls(hub, model, *args, **filtered_kwargs)
else:
model = getattr(lc, LANGCHAIN_HUBS[hub])
from .utils import CHAT_MODEL_CLASSES

if model_type and hub in CHAT_MODEL_CLASSES:
hub_module = getattr(chat_models, hub)
model = getattr(hub_module, CHAT_MODEL_CLASSES[hub])
elif model_type in [None, "instruct"]:
model = getattr(lc, LANGCHAIN_HUBS[hub])
else:
raise ValueError(
f"{hub} hub is not supported for the given model type"
)
default_args = inspect.getfullargspec(model).kwonlyargs
if "model" in default_args:
cls.model = model(model=path, *args, **filtered_kwargs)
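In effect, the new branch resolves the model class in two steps. A simplified sketch of the dispatch, not the verbatim source:

import langchain.llms as lc
import langchain.chat_models as chat_models
from typing import Optional

from langtest.modelhandler import LANGCHAIN_HUBS
from langtest.modelhandler.utils import CHAT_MODEL_CLASSES

def resolve_model_class(hub: str, model_type: Optional[str]):
    # A truthy model_type with a supported hub selects the chat class,
    # e.g. chat_models.cohere.ChatCohere for hub="cohere".
    if model_type and hub in CHAT_MODEL_CLASSES:
        hub_module = getattr(chat_models, hub)
        return getattr(hub_module, CHAT_MODEL_CLASSES[hub])
    # None or "instruct" keeps the completion-style LLM class, as before.
    if model_type in (None, "instruct"):
        return getattr(lc, LANGCHAIN_HUBS[hub])
    raise ValueError(f"{hub} hub is not supported for the given model type")
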
35 changes: 29 additions & 6 deletions langtest/modelhandler/transformers_modelhandler.py
@@ -685,6 +685,7 @@ def __init__(self, model, **kwargs):
)

self.model = model
self.model_type = kwargs.get("model_type", None)

@classmethod
def _try_initialize_model(cls, path, device, tasks, **kwargs):
@@ -729,6 +730,9 @@ def load_model(cls, path: str, **kwargs):
- PretrainedModelForQA: An instance of the PretrainedModelForQA class.
"""
try:
# set the model_type from kwargs
model_type = kwargs.get("model_type", None)

# Setup and pop specific kwargs
new_tokens_key = "max_new_tokens"

@@ -758,7 +762,7 @@ def load_model(cls, path: str, **kwargs):
)
else:
model = HuggingFacePipeline(pipeline=path)
return cls(model)
return cls(model, model_type=model_type)
else:
if isinstance(path, str):
model = cls._try_initialize_model(
@@ -773,7 +777,7 @@ def load_model(cls, path: str, **kwargs):
else:
model = HuggingFacePipeline(pipeline=path)

return cls(model)
return cls(model, model_type=model_type)

except Exception as e:
raise ValueError(Errors.E090(error_message=e))
@@ -792,10 +796,29 @@ def predict(self, text: Union[str, dict], prompt: dict, **kwargs) -> str:
- str: The generated prediction.
"""
try:
prompt_template = SimplePromptTemplate(**prompt)
text = prompt_template.format(**text)
output = self.model._generate([text])
return output[0]
if self.model_type == "chat":
from langtest.prompts import PromptManager

prompt_manager = PromptManager()
examples = prompt_manager.get_prompt(hub="transformers")

if examples:
prompt["template"] = "".join(
f"{k.title()}:\n{{{k}}}\n" for k in text.keys()
)
prompt_template = SimplePromptTemplate(**prompt)
text = prompt_template.format(**text)
messages = [*examples, {"role": "user", "content": text}]
else:
messages = [{"role": "user", "content": text}]
output = self.model._generate([messages])
return output[0].strip()

else:
prompt_template = SimplePromptTemplate(**prompt)
text = prompt_template.format(**text)
output = self.model._generate([text])
return output[0]
except Exception as e:
raise ValueError(Errors.E089(error_message=e))

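The chat branch hands the pipeline a list of role/content messages rather than a flat prompt string. Illustratively (contents are made up; examples come from PromptManager):

messages = [
    {"role": "user", "content": "Question:\nIs water wet?\nAnswer:\n"},
]
# With few-shot examples: messages = [*examples, {"role": "user", "content": text}]
# The pipeline is then invoked as self.model._generate([messages]).
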
47 changes: 47 additions & 0 deletions langtest/modelhandler/utils.py
@@ -0,0 +1,47 @@
# This file contains the model classes that are used in the model handler.
# from langchain

CHAT_MODEL_CLASSES = {
"anthropic": "ChatAnthropic",
"anyscale": "ChatAnyscale",
"azure_openai": "AzureChatOpenAI",
"baichuan": "ChatBaichuan",
"baidu_qianfan_endpoint": "QianfanChatEndpoint",
"bedrock": "BedrockChat",
"cohere": "ChatCohere",
"databricks": "ChatDatabricks",
"deepinfra": "ChatDeepInfra",
"ernie": "ErnieBotChat",
"everlyai": "ChatEverlyAI",
"fake": "FakeListChatModel",
"fireworks": "ChatFireworks",
"gigachat": "GigaChat",
"google_palm": "ChatGooglePalm",
"gpt_router": "GPTRouter",
"huggingface": "ChatHuggingFace",
"human": "HumanInputChatModel",
"hunyuan": "ChatHunyuan",
"javelin_ai_gateway": "ChatJavelinAIGateway",
"jinachat": "JinaChat",
"kinetica": "ChatKinetica",
"konko": "ChatKonko",
"litellm": "ChatLiteLLM",
"litellm_router": "ChatLiteLLMRouter",
"llama_edge": "LlamaEdgeChatService",
"maritalk": "ChatMaritalk",
"minimax": "MiniMaxChat",
"mlflow": "ChatMlflow",
"mlflow_ai_gateway": "ChatMLflowAIGateway",
"ollama": "ChatOllama",
"openai": "ChatOpenAI",
"pai_eas_endpoint": "PaiEasChatEndpoint",
"perplexity": "ChatPerplexity",
"promptlayer_openai": "PromptLayerChatOpenAI",
"sparkllm": "ChatSparkLLM",
"tongyi": "ChatTongyi",
"vertexai": "ChatVertexAI",
"volcengine_maas": "VolcEngineMaasChat",
"yandex": "ChatYandexGPT",
"yuan2": "ChatYuan2",
"zhipuai": "ChatZhipuAI",
}
17 changes: 14 additions & 3 deletions langtest/prompts.py
@@ -33,11 +33,17 @@ def get_template(self):

temp = []
order_less = []
for field in self.__dict__:

sorted_fields = sorted(
self.__dict__.keys(), key=lambda x: self.__field_order.index(x.lower())
)

for field in sorted_fields:
if field in self.__field_order:
temp.append(f"{field.title()}: {{{field}}}")
else:
order_less.append(f"{field.title()}: {{{field}}}")

if order_less:
temp.extend(order_less)
return "\n" + "\n".join(temp)
@@ -169,7 +175,7 @@ def prompt_style(self):
return final_prompt

def get_prompt(self, hub=None):
if hub == "lm-studio":
if hub in ("lm-studio", "transformers"):
return self.lm_studio_prompt()
return self.prompt_style()

@@ -194,7 +200,12 @@ def lm_studio_prompt(self):

# assistant role
temp_ai["role"] = "assistant"
temp_ai["content"] = example.ai.get_template.format(**example.ai.get_example)
temp_ai["content"] = (
example.ai.get_template.format(**example.ai.get_example)
.replace("Answer:", "")
.strip()
+ "\n\n"
)

messages.append(temp_user)
messages.append(temp_ai)
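Since get_prompt() now routes hub="transformers" through lm_studio_prompt() as well, both hubs receive few-shot turns of this shape (contents illustrative), with the trailing "Answer:" label stripped from each assistant turn:

messages = [
    {"role": "user", "content": "Question:\nIs the sky blue?\nAnswer:\n"},
    {"role": "assistant", "content": "Yes\n\n"},
]
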
15 changes: 13 additions & 2 deletions langtest/tasks/task.py
@@ -1,7 +1,7 @@
import ast
import re
from abc import ABC, abstractmethod
from typing import Union
from typing import Literal, Union
from langtest.modelhandler import ModelAPI, LANGCHAIN_HUBS, INSTALLED_HUBS
from langtest.errors import Errors, ColumnNameError

@@ -23,7 +23,14 @@ def create_sample(cls, *args, **kwargs) -> samples.Sample:
pass

@classmethod
def load_model(cls, model_path: str, model_hub: str, *args, **kwargs):
def load_model(
cls,
model_path: str,
model_hub: str,
model_type: Literal["chat", "completion"] = None,
*args,
**kwargs,
):
"""Load the model."""

models = ModelAPI.model_registry
@@ -54,6 +61,10 @@ def load_model(cls, model_path: str, model_hub: str, *args, **kwargs):
if "server_prompt" in kwargs:
cls.server_prompt = kwargs.get("server_prompt")
kwargs.pop("server_prompt")

if model_type:
kwargs["model_type"] = model_type

try:
if model_hub in LANGCHAIN_HUBS:
# LLM models
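A sketch of the call path, with the TaskManager usage shown as an assumption rather than verbatim API: Harness forwards the optional third argument, and load_model() tucks it into kwargs["model_type"] for the hub-specific loader.

from langtest.tasks import TaskManager

task = TaskManager("question-answering")
# "chat" ends up in kwargs["model_type"] inside the hub loader
model = task.model("gpt-4o", "openai", "chat")
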
37 changes: 37 additions & 0 deletions langtest/types.py
@@ -0,0 +1,37 @@
from typing import Literal, TypedDict, Union, List


class ModelConfig(TypedDict):
"""
ModelConfig is a TypedDict that defines the configuration for a model.
Attributes:
model (str): The name of the model.
type (Literal['chat', 'completion']): The type of the model, either 'chat' or 'completion'.
hub (str): The hub where the model is located.
"""

model: str
type: Literal["chat", "completion"]
hub: str


class DatasetConfig(TypedDict):
"""
DatasetConfig is a TypedDict that defines the configuration for a dataset.
Attributes:
data_source (str): The source of the data, e.g., a file path.
split (str): The data split, e.g., 'train', 'test', or 'validation'.
subset (str): A specific subset of the data, if applicable.
feature_column (Union[str, List[str]]): The column(s) representing the features in the dataset.
target_column (Union[str, List[str]]): The column(s) representing the target variable(s) in the dataset.
source (str): The original source of the dataset, e.g., 'huggingface'.
"""

data_source: str
split: str
subset: str
feature_column: Union[str, List[str]]
target_column: Union[str, List[str]]
source: str
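
These TypedDicts document the dict shapes Harness already accepts. An illustrative instantiation (values are examples, not requirements):

from langtest.types import ModelConfig, DatasetConfig

model_cfg = ModelConfig(model="gpt-4o", type="chat", hub="openai")
data_cfg = DatasetConfig(
    data_source="BoolQ",
    split="test-tiny",
    subset="default",
    feature_column="question",
    target_column="answer",
    source="huggingface",
)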
4 changes: 2 additions & 2 deletions langtest/utils/config_utils.py
@@ -1,4 +1,4 @@
from typing import List
from typing import Any, Dict, List
from pkg_resources import resource_filename


@@ -12,7 +12,7 @@
"default": resource_filename("langtest", "data/config/QA_summarization_config.yml"),
}

DEFAULTS_CONFIG = {
DEFAULTS_CONFIG: Dict[str, Any] = {
"question-answering": LLM_DEFAULTS_CONFIG,
"summarization": LLM_DEFAULTS_CONFIG,
"ideology": resource_filename("langtest", "data/config/political_config.yml"),