feat: add support for o3 models & update litellm #130

Merged: 10 commits, Feb 4, 2025
48 changes: 30 additions & 18 deletions dynamiq/nodes/llms/base.py
@@ -270,6 +270,14 @@ def _get_response_format_and_tools(
 
         return response_format, tools
 
+    def update_completion_params(self, params: dict[str, Any]) -> dict[str, Any]:
+        """
+        This method can be overridden by subclasses to update or modify the
+        parameters passed to the completion method.
+        By default, it does not modify the params.
+        """
+        return params
+
     def execute(
         self,
         input_data: BaseLLMInputSchema,
@@ -304,9 +312,9 @@ def execute(
         self.run_on_node_execute_run(callbacks=config.callbacks, prompt_messages=messages, **kwargs)
 
         # Use initialized client if it possible
-        params = self.connection.conn_params
+        params = self.connection.conn_params.copy()
         if self.client and not isinstance(self.connection, HttpApiKey):
-            params = {"client": self.client}
+            params.update({"client": self.client})
 
         current_inference_mode = inference_mode or self.inference_mode
         current_schema = schema or self.schema_
@@ -315,23 +323,27 @@
         )
         tools = tools or base_tools
 
-        response = self._completion(
-            model=self.model,
-            messages=messages,
-            stream=self.streaming.enabled,
-            temperature=self.temperature,
-            max_tokens=self.max_tokens,
-            tools=tools,
-            tool_choice=self.tool_choice,
-            stop=self.stop,
-            top_p=self.top_p,
-            seed=self.seed,
-            presence_penalty=self.presence_penalty,
-            frequency_penalty=self.frequency_penalty,
-            response_format=response_format,
-            drop_params=True,
+        common_params: dict[str, Any] = {
+            "model": self.model,
+            "messages": messages,
+            "stream": self.streaming.enabled,
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens,
+            "tools": tools,
+            "tool_choice": self.tool_choice,
+            "stop": self.stop,
+            "top_p": self.top_p,
+            "seed": self.seed,
+            "presence_penalty": self.presence_penalty,
+            "frequency_penalty": self.frequency_penalty,
+            "response_format": response_format,
+            "drop_params": True,
             **params,
-        )
+        }
+
+        common_params = self.update_completion_params(common_params)
+
+        response = self._completion(**common_params)
 
         handle_completion = (
             self._handle_streaming_completion_response if self.streaming.enabled else self._handle_completion_response
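The change above is a template-method hook: execute() now collects every completion argument into a single common_params dict, lets a subclass rewrite that dict via update_completion_params, and only then calls _completion. Below is a minimal standalone sketch of the same pattern; the class and parameter names are illustrative, not part of the dynamiq API.

from typing import Any


class Client:
    """Illustrative base: collect call arguments, let a subclass adjust them."""

    model = "gpt-4o-mini"
    max_tokens = 256

    def update_completion_params(self, params: dict[str, Any]) -> dict[str, Any]:
        # Default hook: return the params unchanged.
        return params

    def completion_params(self) -> dict[str, Any]:
        common_params: dict[str, Any] = {"model": self.model, "max_tokens": self.max_tokens}
        # A subclass may rename, drop, or add keys before the request is sent.
        return self.update_completion_params(common_params)


class OSeriesClient(Client):
    model = "o3-mini"

    def update_completion_params(self, params: dict[str, Any]) -> dict[str, Any]:
        # o-series models expect max_completion_tokens rather than max_tokens.
        new_params = params.copy()
        new_params["max_completion_tokens"] = new_params.pop("max_tokens")
        return new_params


print(Client().completion_params())        # {'model': 'gpt-4o-mini', 'max_tokens': 256}
print(OSeriesClient().completion_params()) # {'model': 'o3-mini', 'max_completion_tokens': 256}

Note also that drop_params=True stays in the dict: litellm uses it to drop parameters a given model does not support, but dropping alone cannot rename max_tokens to max_completion_tokens, which is what the hook is for.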
24 changes: 24 additions & 0 deletions dynamiq/nodes/llms/openai.py
@@ -1,3 +1,6 @@
+from functools import cached_property
+from typing import Any, ClassVar
+
 from dynamiq.connections import OpenAI as OpenAIConnection
 from dynamiq.nodes.llms.base import BaseLLM
 
@@ -11,6 +14,7 @@ class OpenAI(BaseLLM):
         connection (OpenAIConnection | None): The connection to use for the OpenAI LLM.
     """
     connection: OpenAIConnection | None = None
+    O_SERIES_MODEL_PREFIXES: ClassVar[tuple[str, ...]] = ("o1", "o3")
 
     def __init__(self, **kwargs):
         """Initialize the OpenAI LLM node.
@@ -21,3 +25,23 @@ def __init__(self, **kwargs):
         if kwargs.get("client") is None and kwargs.get("connection") is None:
             kwargs["connection"] = OpenAIConnection()
         super().__init__(**kwargs)
+
+    @cached_property
+    def is_o_series_model(self) -> bool:
+        """
+        Determine if the model belongs to the O-series (e.g. o1 or o3)
+        by checking if the model starts with any of the O-series prefixes.
+        """
+        model_lower = self.model.lower()
+        return any(model_lower.startswith(prefix) for prefix in self.O_SERIES_MODEL_PREFIXES)
+
+    def update_completion_params(self, params: dict[str, Any]) -> dict[str, Any]:
+        """
+        Override the base method to update the completion parameters for OpenAI.
+        For O-series models, use "max_completion_tokens" instead of "max_tokens".
+        """
+        new_params = params.copy()
+        if self.is_o_series_model:
+            new_params["max_completion_tokens"] = self.max_tokens
+            new_params.pop("max_tokens", None)
+        return new_params
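Combined with the base-class hook, the effect for an o-series model is that the token-limit key is swapped just before the request is built. A small illustrative sketch of that logic (a standalone re-implementation with made-up values, not an import from dynamiq):

# What update_completion_params effectively does for an o-series model.
params = {"model": "o3-mini", "max_tokens": 4096, "temperature": 1.0}

O_SERIES_MODEL_PREFIXES = ("o1", "o3")
if params["model"].lower().startswith(O_SERIES_MODEL_PREFIXES):
    params["max_completion_tokens"] = params.pop("max_tokens")

print(params)
# {'model': 'o3-mini', 'temperature': 1.0, 'max_completion_tokens': 4096}

A plain prefix check will also match any future model whose name happens to start with "o1" or "o3", which is presumably why the prefixes live in an easily extended ClassVar rather than being hard-coded in the method.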
84 changes: 84 additions & 0 deletions examples/orchestrators/adaptive_article_o3.py
@@ -0,0 +1,84 @@
+from dynamiq.connections import E2B as E2BConnection
+from dynamiq.connections import Exa, ZenRows
+from dynamiq.nodes.agents.orchestrators.adaptive import AdaptiveOrchestrator
+from dynamiq.nodes.agents.orchestrators.adaptive_manager import AdaptiveAgentManager
+from dynamiq.nodes.agents.react import ReActAgent
+from dynamiq.nodes.agents.simple import SimpleAgent
+from dynamiq.nodes.tools.e2b_sandbox import E2BInterpreterTool
+from dynamiq.nodes.tools.exa_search import ExaTool
+from dynamiq.nodes.tools.zenrows import ZenRowsTool
+from dynamiq.nodes.types import InferenceMode
+from examples.llm_setup import setup_llm
+
+INPUT_TASK = (
+    "Let's find data on optimizing "
+    "SEO campaigns in 2025, analyze it, "
+    "and provide predictions with calculations "
+    "on how to improve and implement these strategies."
+)
+
+
+if __name__ == "__main__":
+    python_tool = E2BInterpreterTool(
+        name="Code Executor",
+        connection=E2BConnection(),
+    )
+
+    zenrows_tool = ZenRowsTool(
+        connection=ZenRows(),
+        name="Web Scraper",
+    )
+
+    exa_tool = ExaTool(
+        connection=Exa(),
+        name="Search Engine",
+    )
+
+    llm = setup_llm(model_provider="gpt", model_name="o3-mini", max_tokens=100000)
+
+    agent_coding = ReActAgent(
+        name="Coding Agent",
+        llm=llm,
+        tools=[python_tool],
+        max_loops=13,
+        inference_mode=InferenceMode.XML,
+    )
+
+    agent_web = ReActAgent(
+        name="Web Agent",
+        llm=llm,
+        tools=[zenrows_tool, exa_tool],
+        max_loops=13,
+        inference_mode=InferenceMode.XML,
+    )
+
+    agent_reflection = SimpleAgent(
+        name="Reflection Agent (Reviewer, Critic)",
+        llm=llm,
+        role=(
+            "Analyze and review the accuracy of any results, "
+            "including tasks, code, or data. "
+            "Offer feedback and suggestions for improvement."
+        ),
+    )
+
+    agent_manager = AdaptiveAgentManager(
+        llm=llm,
+    )
+
+    orchestrator = AdaptiveOrchestrator(
+        name="Adaptive Orchestrator",
+        agents=[agent_coding, agent_web, agent_reflection],
+        manager=agent_manager,
+    )
+
+    result = orchestrator.run(
+        input_data={
+            "input": INPUT_TASK,
+        },
+        config=None,
+    )
+
+    output_content = result.output.get("content")
+    print("RESULT")
+    print(output_content)
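Running this example end to end assumes valid credentials for the E2B, Exa, ZenRows, and OpenAI connections (the dynamiq connection classes typically read these from the environment), and that examples.llm_setup.setup_llm maps model_provider="gpt" to an OpenAI node; the o3-mini model then exercises the max_completion_tokens path added in this PR.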