feat: Implement structured output by using outlines #745

Status: Open

Wants to merge 85 commits into base: master.

Commits (85):
5e2be38
feat: Add SchemaModel backend for unified schema-based chat completion
Appointat Jul 2, 2024
cfa04c2
chore: Remove unused imports and update model configuration check
Appointat Jul 2, 2024
4fcd4dd
feat: Update SchemaModel to use the latest message for chat completion
Appointat Jul 2, 2024
c737a00
chore: Update SchemaModel to use GPT_3_5_TURBO as the default model t…
Appointat Jul 3, 2024
82ee7f9
feat: Update SchemaModel to include model configuration parameters
Appointat Jul 3, 2024
6958962
feat: Add SchemaModel backend for unified schema-based chat completion
Appointat Jul 3, 2024
2fb2a72
feat: Update pyproject.toml to include outlines dependency
Appointat Jul 3, 2024
b4b250f
temporary update code
raywhoelse Jul 6, 2024
0a904de
Create 2.py
raywhoelse Jul 11, 2024
9030095
structured response
raywhoelse Jul 13, 2024
bce2207
fix issues about code format
raywhoelse Jul 13, 2024
6be570c
add huddle structed response with tools
raywhoelse Jul 14, 2024
fd69855
refactor: Update GroqModel to support multiple Groq models and improv…
Appointat Jul 16, 2024
e04fa38
fix
Appointat Jul 16, 2024
1336673
fix
Appointat Jul 16, 2024
181dd36
update
Appointat Jul 16, 2024
07ac0fb
update
Appointat Jul 16, 2024
813c9bc
update
Appointat Jul 16, 2024
c866a74
update
Appointat Jul 16, 2024
0fde9c4
update
Appointat Jul 16, 2024
a276a05
update
Appointat Jul 16, 2024
50fda64
update
Appointat Jul 16, 2024
7876921
update
Appointat Jul 16, 2024
8d373e6
update
Appointat Jul 16, 2024
89aa3de
update
Appointat Jul 16, 2024
960cb39
Merge branch 'master' into output_parse
raywhoelse Jul 17, 2024
755aead
update conflict with new version
raywhoelse Jul 17, 2024
1b357bd
format code
raywhoelse Jul 17, 2024
84e4afc
format code
raywhoelse Jul 17, 2024
5e1bd6e
remove duplicate func
raywhoelse Jul 17, 2024
d3d173c
add structure unit test
raywhoelse Jul 18, 2024
1b930f9
remove return_json_format_response in judge
raywhoelse Jul 18, 2024
4434643
add comments of structure func
raywhoelse Jul 19, 2024
2aa34c5
chore: Add support for Outlines models in ModelFactory
Appointat Jul 19, 2024
34775a6
Merge branch 'output_parse' of https://github.com/camel-ai/camel into…
Appointat Jul 19, 2024
9d2205b
chore: Remove unused code and refactor structure of 2.py
Appointat Jul 19, 2024
c50e1f5
docs: Update link to Outlines transformers documentation
Appointat Jul 19, 2024
2d92e5f
chore: Refactor model initialization in SchemaModel
Appointat Jul 19, 2024
375f94a
chore: Refactor model initialization in SchemaModel
Appointat Jul 19, 2024
45d19be
chore: Refactor model initialization in SchemaModel
Appointat Jul 19, 2024
685bc5f
chore: Refactor model initialization in SchemaModel
Appointat Jul 19, 2024
717ce6a
fix issues about pre commit error
raywhoelse Jul 19, 2024
b8f7f83
Merge branch 'master' into output_parse
raywhoelse Jul 19, 2024
3085b09
fix issues with pre-commits error
raywhoelse Jul 19, 2024
558e7a5
Merge branch 'output_parse' of https://github.com/camel-ai/camel into…
raywhoelse Jul 19, 2024
497c904
fix issues
raywhoelse Jul 19, 2024
bcea621
fix issues about pre-commit
raywhoelse Jul 19, 2024
8aba978
chore: Rename step method to un in EmbodiedAgent
Appointat Jul 20, 2024
e02293b
small enhancement
Wendong-Fan Jul 22, 2024
a3080a7
fix
Wendong-Fan Jul 22, 2024
4a86a7c
update poetry
raywhoelse Jul 25, 2024
115c2b4
test
raywhoelse Jul 25, 2024
06c68fd
111
raywhoelse Jul 25, 2024
5127266
test
raywhoelse Jul 25, 2024
e16be7f
test
raywhoelse Jul 25, 2024
6cc2e68
fix the comments
Appointat Aug 2, 2024
fb42df5
merge master branch
Appointat Aug 2, 2024
d3aacb0
update
Appointat Aug 2, 2024
e452eb4
merge
Appointat Aug 2, 2024
78cca79
update
Appointat Aug 2, 2024
098fd4a
feat: Import outlines module with type hinting
Appointat Aug 3, 2024
3e309cb
chore: add hugging_face_hub_token
Appointat Aug 3, 2024
e2bdd59
feat: verify the structured format
Appointat Aug 3, 2024
477a2da
Refactor model_type assignment to use variable instead of hard-coded …
Appointat Aug 4, 2024
bc6cffa
Refactor model_type assignment to use variable instead of hard-coded …
Appointat Aug 4, 2024
0c9d01b
Refactor model_type assignment to use variable instead of hard-coded …
Appointat Aug 4, 2024
8601f6e
Refactor model_type assignment to use variable instead of hard-coded …
Appointat Aug 4, 2024
83193af
Refactor model_type assignment to use variable instead of hard-coded …
Appointat Aug 4, 2024
b1fa76a
Refactor model_type assignment to use variable instead of hard-coded …
Appointat Aug 4, 2024
a266d80
Refactor model_type assignment to use variable instead of hard-coded …
Appointat Aug 4, 2024
874dcf1
Refactor model_type assignment to use variable instead of hard-coded …
Appointat Aug 4, 2024
1d39cfc
Refactor model_name assignment to use updated model name
Appointat Aug 4, 2024
fed085e
Merge branch 'master' into output_parse
Appointat Aug 4, 2024
fdba081
Refactor constructor for open-source backend
Appointat Aug 16, 2024
91831b4
chore: Update version to 0.1.6.5 in code and documentation files
Appointat Aug 16, 2024
3bcbd87
fix: Remove unnecessary code in SchemaModel constructor
Appointat Aug 16, 2024
a93c640
chore: Update optional dependencies in pyproject.toml
Appointat Aug 17, 2024
cc0cb3c
feat: Update optional dependencies in pyproject.toml
Appointat Aug 17, 2024
2a31d63
Update role name to 'User' in user message
Appointat Aug 18, 2024
26bc1df
fix: fix the type check
Appointat Aug 18, 2024
116f8e0
feat: Update SchemaModel to use JSON response for output schema valid…
Appointat Aug 18, 2024
9a5d484
fix: Update output_schema parameter type hint to use Type[BaseModel]
Appointat Aug 18, 2024
ea4feb4
fix:Update output_schema parameter type hint to use Type[BaseModel
Appointat Aug 18, 2024
8bf1356
fixe: Update _client type hint in SchemaModel constructor
Appointat Aug 18, 2024
f798e59
Merge branch 'master' into output_parse
Appointat Aug 23, 2024
27 changes: 14 additions & 13 deletions camel/agents/chat_agent.py
@@ -23,6 +23,7 @@
List,
Optional,
Tuple,
Type,
Union,
)

@@ -319,7 +320,7 @@ def record_message(self, message: BaseMessage) -> None:
def step(
self,
input_message: BaseMessage,
output_schema: Optional[BaseModel] = None,
output_schema: Optional[Type[BaseModel]] = None,
) -> ChatAgentResponse:
r"""Performs a single step in the chat session by generating a response
to the input message.
@@ -330,10 +331,10 @@ def step(
either `user` or `assistant` but it will be set to `user`
anyway since for the self agent any incoming message is
external.
output_schema (Optional[BaseModel]): An optional pydantic model
that includes value types and field descriptions used to
generate a structured response by LLM. This schema helps
in defining the expected output format.
output_schema (Optional[Type[BaseModel]]): An optional pydantic
model that includes value types and field descriptions used to
generate a structured response by LLM. This schema helps in
defining the expected output format.

Returns:
ChatAgentResponse: A struct containing the output messages,
@@ -450,7 +451,7 @@ def step(
async def step_async(
self,
input_message: BaseMessage,
output_schema: Optional[BaseModel] = None,
output_schema: Optional[Type[BaseModel]] = None,
) -> ChatAgentResponse:
r"""Performs a single step in the chat session by generating a response
to the input message. This agent step can call async function calls.
@@ -461,10 +462,10 @@ async def step_async(
either `user` or `assistant` but it will be set to `user`
anyway since for the self agent any incoming message is
external.
output_schema (Optional[BaseModel]): An optional pydantic model
that includes value types and field descriptions used to
generate a structured response by LLM. This schema helps
in defining the expected output format.
output_schema (Optional[Type[BaseModel]]): An optional pydantic
model that includes value types and field descriptions used to
generate a structured response by LLM. This schema helps in
defining the expected output format.

Returns:
ChatAgentResponse: A struct containing the output messages,
@@ -614,13 +615,13 @@ def _add_tools_for_func_call(
# result message
return tool_calls, func_assistant_msg, func_result_msg

def _add_output_schema_to_tool_list(self, output_schema: BaseModel):
def _add_output_schema_to_tool_list(self, output_schema: Type[BaseModel]):
r"""Handles the structured output response for OpenAI.
This method processes the given output schema and integrates the
resulting function into the tools for the OpenAI model configuration.
Args:
output_schema (BaseModel): The schema representing the expected
output structure.
output_schema (Type[BaseModel]): The schema representing the
expected output structure.
"""
from camel.toolkits import OpenAIFunction

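The `Optional[BaseModel]` → `Optional[Type[BaseModel]]` change above can be illustrated with a minimal sketch (the `JokeSchema` name is hypothetical, not part of this PR): callers pass the schema *class* itself, not an instance of it.

```python
from typing import Optional, Type

from pydantic import BaseModel


# Hypothetical example schema -- not part of the PR.
class JokeSchema(BaseModel):
    joke: str
    funny_level: int


def step(output_schema: Optional[Type[BaseModel]] = None) -> str:
    # Type[BaseModel] accepts the class object itself, so the callee can
    # validate or instantiate with it later; a plain BaseModel hint would
    # instead suggest an already-constructed instance.
    if output_schema is None:
        return "no schema"
    return output_schema.__name__


print(step(JokeSchema))  # the class is passed -- no parentheses
```

This is why the docstrings were updated in the same commit: the schema "includes value types and field descriptions" at the class level, which is all the backend needs to constrain generation.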
3 changes: 3 additions & 0 deletions camel/configs/openai_config.py
@@ -125,6 +125,8 @@ class OpenSourceConfig(BaseConfig):
which will be used as the API base of OpenAI API.
api_params (ChatGPTConfig): An instance of :obj:ChatGPTConfig to
contain the arguments to be passed to OpenAI API.
model_kwargs (dict, optional): Additional keyword arguments to pass
to the model constructor. (default: :obj:`{}`)
"""

# Maybe the param needs to be renamed.
@@ -133,3 +135,4 @@ class OpenSourceConfig(BaseConfig):
model_path: str
server_url: str
api_params: ChatGPTConfig = Field(default_factory=ChatGPTConfig)
model_kwargs: Optional[dict] = Field(default_factory=dict)
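The `model_kwargs` field added above is consumed defensively by the new backend. A minimal sketch of a hypothetical config dict (the key names follow the `model_config_dict.get(...)` calls in `SchemaModel`; the values are illustrative only):

```python
# Hypothetical config for the OUTLINES_TRANSFORMERS branch; key names
# mirror the `model_config_dict.get(...)` lookups in SchemaModel.
model_config_dict = {
    "device": "cuda",
    "model_kwargs": {"load_in_8bit": True},       # forwarded to the model ctor
    "tokenizer_kwargs": {"padding_side": "left"},  # forwarded to the tokenizer
}

# SchemaModel reads each key with a fallback default, so partial
# configs are accepted without raising KeyError:
device = model_config_dict.get("device", None)
model_kwargs = model_config_dict.get("model_kwargs", {})
tokenizer_kwargs = model_config_dict.get("tokenizer_kwargs", {})
```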
2 changes: 2 additions & 0 deletions camel/models/__init__.py
@@ -25,6 +25,7 @@
from .openai_audio_models import OpenAIAudioModels
from .openai_compatibility_model import OpenAICompatibilityModel
from .openai_model import OpenAIModel
from .schema_model import SchemaModel
from .stub_model import StubModel
from .vllm_model import VLLMModel
from .zhipuai_model import ZhipuAIModel
@@ -46,5 +47,6 @@
'OllamaModel',
'VLLMModel',
'GeminiModel',
'SchemaModel',
'OpenAICompatibilityModel',
]
6 changes: 6 additions & 0 deletions camel/models/model_factory.py
@@ -24,6 +24,7 @@
from camel.models.open_source_model import OpenSourceModel
from camel.models.openai_compatibility_model import OpenAICompatibilityModel
from camel.models.openai_model import OpenAIModel
from camel.models.schema_model import SchemaModel
from camel.models.stub_model import StubModel
from camel.models.vllm_model import VLLMModel
from camel.models.zhipuai_model import ZhipuAIModel
@@ -108,6 +109,11 @@ def create(
model_class = VLLMModel
elif model_platform.is_litellm:
model_class = LiteLLMModel
elif model_platform.is_outlines:
model_class = SchemaModel
return model_class(
model_platform, model_type, model_config_dict, url
)
elif model_platform.is_openai_compatibility_model:
model_class = OpenAICompatibilityModel
else:
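The factory dispatch above can be sketched in miniature with stand-in classes (these are not the real camel API): the Outlines platforms route to `SchemaModel`, and the branch returns immediately rather than falling through to the shared constructor call used by the other platforms.

```python
from enum import Enum


class Platform(Enum):
    OPENAI = "openai"
    OUTLINES_VLLM = "outlines-vllm"

    @property
    def is_outlines(self) -> bool:
        # Membership test, mirroring ModelPlatformType.is_outlines.
        return self in {Platform.OUTLINES_VLLM}


class SchemaModel:
    # Stand-in: the real SchemaModel also takes model_config_dict and url.
    def __init__(self, platform: Platform, model_type: str) -> None:
        self.platform = platform
        self.model_type = model_type


def create(platform: Platform, model_type: str) -> SchemaModel:
    # Mirrors the early return in the `is_outlines` branch: SchemaModel
    # is constructed on the spot instead of after the elif chain.
    if platform.is_outlines:
        return SchemaModel(platform, model_type)
    raise ValueError(f"Unsupported platform in this sketch: {platform}")


model = create(Platform.OUTLINES_VLLM, "mistralai/Mistral-7B-v0.3")
```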
232 changes: 232 additions & 0 deletions camel/models/schema_model.py
@@ -0,0 +1,232 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import json
from typing import (
Any,
Dict,
List,
Optional,
Type,
TypeVar,
Union,
overload,
)

from openai import Stream
from pydantic import BaseModel, ValidationError

from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.types import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
Choice,
ModelPlatformType,
ModelType,
)
from camel.utils import (
BaseTokenCounter,
OpenAITokenCounter,
)

T = TypeVar('T', bound=BaseModel)


class SchemaModel(BaseModelBackend):
r"""Schema model in a unified BaseModelBackend interface, which aims to
generate the formatted response."""

def __init__(
self,
model_platform: ModelPlatformType,
model_type: str,
model_config_dict: Dict[str, Any],
url: Optional[str] = None,
) -> None:
r"""Constructor for open-source backend.

Args:
model_platform (ModelPlatformType): Platform from which the model
originates, including transformers, llama_cpp, and vllm.
model_type (str): Model for which a backend is created, for
example, "mistralai/Mistral-7B-v0.3".
model_config_dict (Dict[str, Any]): A dictionary that will
be fed into openai.ChatCompletion.create().
url (Optional[str]): The url to the OpenAI service.
"""
from outlines import models # type: ignore[import]
self.model_platform = model_platform
self.model_name = model_type
self.model_config_dict = model_config_dict
self._client: Union[models.Transformers, models.LlamaCpp, models.VLLM]
self._url = url

# Since Outlines supports multiple model types, it is necessary to
# read the documentation to learn about the model kwargs:
# https://outlines-dev.github.io/outlines/reference/models/transformers
if self.model_platform == ModelPlatformType.OUTLINES_TRANSFORMERS:
model_kwargs = self.model_config_dict.get("model_kwargs", {})
device = self.model_config_dict.get("device", None)
tokenizer_kwargs = self.model_config_dict.get(
"tokenizer_kwargs", {}
)

self._client = models.transformers(
model_name=self.model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
)
elif self.model_platform == ModelPlatformType.OUTLINES_LLAMACPP:
repo_id = self.model_config_dict.get(
"repo_id", "TheBloke/phi-2-GGUF"
)
filename = self.model_config_dict.get(
"filename", "phi-2.Q4_K_M.gguf"
)
download_dir = self.model_config_dict.get("download_dir", None)
model_kwargs = self.model_config_dict.get("model_kwargs", {})

from llama_cpp import llama_tokenizer # type: ignore[import]

# Initialize the tokenizer
tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(
repo_id
) # type: ignore[attr-defined]

self._client = models.llamacpp( # type: ignore[attr-defined]
repo_id=repo_id,
filename=filename,
download_dir=download_dir,
tokenizer=tokenizer,
**model_kwargs,
)
elif self.model_platform == ModelPlatformType.OUTLINES_VLLM:
model_kwargs = self.model_config_dict.get("model_kwargs", {})

self._client = models.vllm(
model_name=self.model_name,
**model_kwargs,
)
else:
raise ValueError(
f"Unsupported platform for Outlines: {self.model_platform}"
)

self._token_counter: Optional[BaseTokenCounter] = None

@property
def token_counter(self) -> BaseTokenCounter:
r"""Initialize the token counter for the model backend.

Returns:
BaseTokenCounter: The token counter following the model's
tokenization style.
"""
if not self._token_counter:
# The default model type is GPT_3_5_TURBO, since the self-hosted
# models are not supported in the token counter.
self._token_counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
return self._token_counter

@overload
def run(
self,
messages: List[OpenAIMessage],
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]: ...

@overload
def run(
self,
messages: List[OpenAIMessage],
output_schema: Type[T],
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]: ...

def run(
self,
messages: List[OpenAIMessage],
output_schema: Optional[Type[T]] = None,
# Reviewer note: the parameter is typed as `T` rather than `BaseModel`
# because the output-parse schema is a `BaseModel` subclass (`T` is
# bound to `BaseModel` above).
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
if output_schema is None:
raise NotImplementedError(
"run without output_schema is not implemented"
)

from outlines import generate # type: ignore[import]

generator = generate.json(self._client, output_schema)

if not messages:
raise ValueError("The messages list should not be empty.")
message = messages[-1]
message_str = (
f"{message.get('role', '')}: {message.get('content', '')}"
)

parsed_response = generator(message_str)
# Serialize the parsed pydantic instance back to a JSON string.
# (json.dumps(str(...)) would double-encode it and break the
# validation below.)
json_response = parsed_response.model_dump_json()

# Verify the structured format
try:
_ = output_schema(**json.loads(json_response))
except ValidationError as e:
raise ValueError(
f"Generated response does not match the output schema: {e}"
)

import time

response = ChatCompletion(
id=f"chatcmpl-{time.time()}",
created=int(time.time()),
model=self.model_name,
object="chat.completion",
choices=[
Choice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=json_response,
),
finish_reason="stop",
),
],
)

return response

def check_model_config(self):
r"""Check whether the model configuration contains the required
arguments for the schema-based model.

Raises:
Warning: If the model configuration dictionary does not contain
the required arguments for the schema-based model, the warnings
are raised.
"""
# Check the model_name; raise a Warning if it is not found
if "model_name" not in self.model_config_dict:
raise Warning("The model_name is set to the default value.")

@property
def stream(self) -> bool:
r"""Returns whether the model is in stream mode,
which sends partial results each time.

Returns:
bool: Whether the model is in stream mode.
"""
return self.model_config_dict.get('stream', False)
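The schema-verification round trip in `run` can be sketched in isolation (the `Person` schema is hypothetical; this assumes pydantic v2's `model_dump_json`):

```python
import json

from pydantic import BaseModel, ValidationError


# Hypothetical schema standing in for the caller's output_schema.
class Person(BaseModel):
    name: str
    age: int


# Stand-in for what `generate.json(self._client, output_schema)` yields:
# a validated pydantic instance of the requested schema.
parsed_response = Person(name="Ada", age=36)

# Serialize, then re-validate the JSON against the schema -- mirroring
# the try/except ValidationError block in SchemaModel.run.
json_response = parsed_response.model_dump_json()
try:
    Person(**json.loads(json_response))
except ValidationError as e:
    raise ValueError(
        f"Generated response does not match the output schema: {e}"
    )
```

The re-validation is cheap insurance: Outlines constrains generation to the schema's grammar, but the round trip also guarantees the string placed into `ChatCompletionMessage.content` parses back into the caller's model.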
12 changes: 12 additions & 0 deletions camel/types/enums.py
@@ -442,6 +442,9 @@ class ModelPlatformType(Enum):
ZHIPU = "zhipuai"
DEFAULT = "default"
GEMINI = "gemini"
OUTLINES_TRANSFORMERS = "outlines-transformers"
OUTLINES_LLAMACPP = "outlines-llamacpp"
OUTLINES_VLLM = "outlines-vllm"
VLLM = "vllm"
MISTRAL = "mistral"
OPENAI_COMPATIBILITY_MODEL = "openai-compatibility-model"
@@ -507,6 +510,15 @@ def is_gemini(self) -> bool:
r"""Returns whether this platform is Gemini."""
return self is ModelPlatformType.GEMINI

@property
def is_outlines(self) -> bool:
r"""Returns whether this platform is Outlines."""
return self in {
ModelPlatformType.OUTLINES_TRANSFORMERS,
ModelPlatformType.OUTLINES_LLAMACPP,
ModelPlatformType.OUTLINES_VLLM,
}


class AudioModelType(Enum):
TTS_1 = "tts-1"