Suppressed UserWarnings from litellm pydantic issues (#1833)
* Suppressed UserWarnings from litellm pydantic issues

* Bump the litellm version floor to 1.56.4

* Fix failing Ollama tests
bhancockio authored Dec 31, 2024
1 parent 4469461 commit ba89e43
Showing 12 changed files with 2,235 additions and 540 deletions.
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ dependencies = [
# Core Dependencies
"pydantic>=2.4.2",
"openai>=1.13.3",
"litellm>=1.44.22",
"litellm>=1.56.4",
"instructor>=1.3.3",

# Text Processing
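The only dependency change is the litellm floor moving from 1.44.22 to 1.56.4. A quick, standard-library-only sketch for checking that an existing environment already meets the new floor (the naive tuple comparison assumes a plain "X.Y.Z" version string with no pre-release tags):

# Verify the installed litellm against the new ">=1.56.4" floor from the diff above.
from importlib.metadata import version

installed = version("litellm")
# Naive comparison; assumes a plain "X.Y.Z" version string.
if tuple(int(part) for part in installed.split(".")[:3]) < (1, 56, 4):
    raise SystemExit(f"litellm {installed} is older than the required 1.56.4")
print(f"litellm {installed} satisfies the >=1.56.4 requirement")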
src/crewai/crew.py (30 changes: 19 additions & 11 deletions)
@@ -726,11 +726,7 @@ def _execute_tasks(

# Determine which tools to use - task tools take precedence over agent tools
tools_for_task = task.tools or agent_to_use.tools or []
- tools_for_task = self._prepare_tools(
-     agent_to_use,
-     task,
-     tools_for_task
- )
+ tools_for_task = self._prepare_tools(agent_to_use, task, tools_for_task)

self._log_task_start(task, agent_to_use.role)

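The comment in this hunk ("task tools take precedence over agent tools") leans on Python's `or` chaining: an empty or missing task tool list falls through to the agent's tools, and finally to an empty list. A tiny illustrative sketch using plain lists in place of Tool objects:

# Illustrative only: how "task.tools or agent_to_use.tools or []" resolves.
task_tools = []            # task defines no tools; an empty list is falsy
agent_tools = ["search"]   # agent-level defaults

assert (task_tools or agent_tools or []) == ["search"]          # falls back to agent tools
assert (["calculator"] or agent_tools or []) == ["calculator"]   # task tools win when present
assert ([] or [] or []) == []                                    # neither defined: empty list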
@@ -797,14 +793,18 @@ def _handle_conditional_task(
return skipped_task_output
return None

- def _prepare_tools(self, agent: BaseAgent, task: Task, tools: List[Tool]) -> List[Tool]:
+ def _prepare_tools(
+     self, agent: BaseAgent, task: Task, tools: List[Tool]
+ ) -> List[Tool]:
# Add delegation tools if agent allows delegation
if agent.allow_delegation:
if self.process == Process.hierarchical:
if self.manager_agent:
tools = self._update_manager_tools(task, tools)
else:
raise ValueError("Manager agent is required for hierarchical process.")
raise ValueError(
"Manager agent is required for hierarchical process."
)

elif agent and agent.allow_delegation:
tools = self._add_delegation_tools(task, tools)
@@ -823,7 +823,9 @@ def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
return self.manager_agent
return task.agent

- def _merge_tools(self, existing_tools: List[Tool], new_tools: List[Tool]) -> List[Tool]:
+ def _merge_tools(
+     self, existing_tools: List[Tool], new_tools: List[Tool]
+ ) -> List[Tool]:
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
if not new_tools:
return existing_tools
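The docstring above promises a merge that avoids duplicates by tool name. A standalone sketch of that technique on stand-in objects (not the crew.py implementation itself; here a duplicate name from the new list is simply skipped, whereas the real method may resolve collisions differently):

from dataclasses import dataclass
from typing import List

@dataclass
class FakeTool:
    """Stand-in for crewai's Tool; only the name matters for de-duplication."""
    name: str

def merge_by_name(existing: List[FakeTool], new: List[FakeTool]) -> List[FakeTool]:
    if not new:
        return existing
    seen = {tool.name for tool in existing}
    merged = list(existing)
    for tool in new:
        if tool.name not in seen:   # skip tools whose name is already present
            merged.append(tool)
            seen.add(tool.name)
    return merged

merged = merge_by_name([FakeTool("search")], [FakeTool("search"), FakeTool("scrape")])
assert [t.name for t in merged] == ["search", "scrape"]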
@@ -839,7 +841,9 @@ def _merge_tools(self, existing_tools: List[Tool], new_tools: List[Tool]) -> List[Tool]:

return tools

- def _inject_delegation_tools(self, tools: List[Tool], task_agent: BaseAgent, agents: List[BaseAgent]):
+ def _inject_delegation_tools(
+     self, tools: List[Tool], task_agent: BaseAgent, agents: List[BaseAgent]
+ ):
delegation_tools = task_agent.get_delegation_tools(agents)
return self._merge_tools(tools, delegation_tools)

@@ -856,7 +860,9 @@ def _add_delegation_tools(self, task: Task, tools: List[Tool]):
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
if not tools:
tools = []
- tools = self._inject_delegation_tools(tools, task.agent, agents_for_delegation)
+ tools = self._inject_delegation_tools(
+     tools, task.agent, agents_for_delegation
+ )
return tools

def _log_task_start(self, task: Task, role: str = "None"):
@@ -870,7 +876,9 @@ def _update_manager_tools(self, task: Task, tools: List[Tool]):
if task.agent:
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
else:
- tools = self._inject_delegation_tools(tools, self.manager_agent, self.agents)
+ tools = self._inject_delegation_tools(
+     tools, self.manager_agent, self.agents
+ )
return tools

def _get_context(self, task: Task, task_outputs: List[TaskOutput]):
src/crewai/llm.py (8 changes: 5 additions & 3 deletions)
@@ -6,8 +6,10 @@
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union

- import litellm
- from litellm import get_supported_openai_params
+ with warnings.catch_warnings():
+     warnings.simplefilter("ignore", UserWarning)
+     import litellm
+     from litellm import get_supported_openai_params

from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
@@ -138,7 +140,7 @@ def __init__(
self.kwargs = kwargs

litellm.drop_params = True
- litellm.set_verbose = False

self.set_callbacks(callbacks)
self.set_env_callbacks()

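Both llm.py hunks rely on the same standard-library pattern: `warnings.catch_warnings()` snapshots the warning filters and restores them on exit, so `simplefilter("ignore", UserWarning)` only silences warnings raised while the block runs (here, the litellm imports). A self-contained sketch of the technique with a simulated noisy import, since reproducing litellm's actual warning is beside the point:

import warnings

def noisy_import():
    # Stand-in for a module that raises a UserWarning at import time.
    warnings.warn("pydantic-style noise", UserWarning)

with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)   # applies only inside this block
    noisy_import()                                 # warning is swallowed

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    noisy_import()                                 # filters were restored, so it surfaces again
    assert len(caught) == 1 and issubclass(caught[0].category, UserWarning)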
src/crewai/utilities/internal_instructor.py (8 changes: 5 additions & 3 deletions)
@@ -1,3 +1,4 @@
+ import warnings
from typing import Any, Optional, Type


@@ -25,9 +26,10 @@ def set_instructor(self):
if self.agent and not self.llm:
self.llm = self.agent.function_calling_llm or self.agent.llm

- # Lazy import
- import instructor
- from litellm import completion
+ with warnings.catch_warnings():
+     warnings.simplefilter("ignore", UserWarning)
+     import instructor
+     from litellm import completion

self._client = instructor.from_litellm(
completion,
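For context on the wrapped imports: `instructor.from_litellm(completion)` patches litellm's completion function so responses can be validated straight into a pydantic model, which is what the `self._client` above is used for. A rough usage sketch, assuming the usual instructor client interface and a model litellm can reach; the `User` schema and model name are illustrative:

import warnings
from pydantic import BaseModel

with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    import instructor
    from litellm import completion

class User(BaseModel):          # illustrative response schema
    name: str
    age: int

client = instructor.from_litellm(completion)
user = client.chat.completions.create(
    model="gpt-4o-mini",        # assumes credentials for this model are configured
    response_model=User,
    messages=[{"role": "user", "content": "Extract: Jane is 31 years old."}],
)
print(user)                     # e.g. User(name='Jane', age=31)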
src/crewai/utilities/token_counter_callback.py (20 changes: 12 additions & 8 deletions)
@@ -1,3 +1,5 @@
+ import warnings

from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import Usage

@@ -12,11 +14,13 @@ def log_success_event(self, kwargs, response_obj, start_time, end_time):
if self.token_cost_process is None:
return

usage: Usage = response_obj["usage"]
self.token_cost_process.sum_successful_requests(1)
self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens)
self.token_cost_process.sum_completion_tokens(usage.completion_tokens)
if usage.prompt_tokens_details:
self.token_cost_process.sum_cached_prompt_tokens(
usage.prompt_tokens_details.cached_tokens
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
usage: Usage = response_obj["usage"]
self.token_cost_process.sum_successful_requests(1)
self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens)
self.token_cost_process.sum_completion_tokens(usage.completion_tokens)
if usage.prompt_tokens_details:
self.token_cost_process.sum_cached_prompt_tokens(
usage.prompt_tokens_details.cached_tokens
)
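The handler in this file is a litellm `CustomLogger`, so it only takes effect once registered with litellm's callback list. A hedged sketch of that wiring, following litellm's documented custom-logger pattern; the counter class here is a simplified stand-in for crewai's token process:

import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    import litellm
    from litellm.integrations.custom_logger import CustomLogger

class SimpleTokenCounter(CustomLogger):
    """Simplified stand-in for TokenCalcHandler: tallies tokens from successful calls."""

    def __init__(self):
        super().__init__()
        self.prompt_tokens = 0
        self.completion_tokens = 0

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        usage = response_obj["usage"]
        self.prompt_tokens += usage.prompt_tokens
        self.completion_tokens += usage.completion_tokens

counter = SimpleTokenCounter()
litellm.callbacks = [counter]   # registration per litellm's custom-logger docs
# After successful litellm.completion(...) calls, counter holds the running totals.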
tests/agent_test.py (23 changes: 10 additions & 13 deletions)
@@ -1445,34 +1445,31 @@ def test_llm_call_with_all_attributes():


@pytest.mark.vcr(filter_headers=["authorization"])
- def test_agent_with_ollama_gemma():
+ def test_agent_with_ollama_llama3():
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
-     llm=LLM(
-         model="ollama/gemma2:latest",
-         base_url="http://localhost:8080",
-     ),
+     llm=LLM(model="ollama/llama3.2:3b", base_url="http://localhost:11434"),
)

assert isinstance(agent.llm, LLM)
assert agent.llm.model == "ollama/gemma2:latest"
assert agent.llm.base_url == "http://localhost:8080"
assert agent.llm.model == "ollama/llama3.2:3b"
assert agent.llm.base_url == "http://localhost:11434"

task = "Respond in 20 words. Who are you?"
response = agent.llm.call([{"role": "user", "content": task}])

assert response
assert len(response.split()) <= 25 # Allow a little flexibility in word count
assert "Gemma" in response or "AI" in response or "language model" in response
assert "Llama3" in response or "AI" in response or "language model" in response


@pytest.mark.vcr(filter_headers=["authorization"])
- def test_llm_call_with_ollama_gemma():
+ def test_llm_call_with_ollama_llama3():
llm = LLM(
model="ollama/gemma2:latest",
base_url="http://localhost:8080",
model="ollama/llama3.2:3b",
base_url="http://localhost:11434",
temperature=0.7,
max_tokens=30,
)
@@ -1482,7 +1479,7 @@ def test_llm_call_with_ollama_gemma():

assert response
assert len(response.split()) <= 25 # Allow a little flexibility in word count
assert "Gemma" in response or "AI" in response or "language model" in response
assert "Llama3" in response or "AI" in response or "language model" in response


@pytest.mark.vcr(filter_headers=["authorization"])
@@ -1578,7 +1575,7 @@ def test_agent_execute_task_with_ollama():
role="test role",
goal="test goal",
backstory="test backstory",
llm=LLM(model="ollama/gemma2:latest", base_url="http://localhost:8080"),
llm=LLM(model="ollama/llama3.2:3b", base_url="http://localhost:11434"),
)

task = Task(
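The rewritten tests point at Ollama's default endpoint (port 11434) and the llama3.2:3b model; the VCR cassettes replay recorded responses, so a live server is only needed when re-recording. A minimal sketch of the same call outside pytest, assuming a local Ollama server with that model pulled:

from crewai.llm import LLM

llm = LLM(model="ollama/llama3.2:3b", base_url="http://localhost:11434")
response = llm.call([{"role": "user", "content": "Respond in 20 words. Who are you?"}])
print(response)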