Skip to content

Commit 705d57b

Browse files
Music generator (#2103)
## Description Generates music from a prompt.
1 parent 4867004 commit 705d57b

File tree

3 files changed

+133
-68
lines changed

3 files changed

+133
-68
lines changed

cookbook/playground/multimodal_agents.py

+27
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,32 @@
5757
),
5858
)
5959

60+
ml_music_agent = Agent(
61+
name="ModelsLab Music Agent",
62+
agent_id="ml_music_agent",
63+
model=OpenAIChat(id="gpt-4o"),
64+
tools=[ModelsLabTools(wait_for_completion=True, file_type=FileType.MP3)],
65+
description="You are an AI agent that can generate music using the ModelsLabs API.",
66+
instructions=[
67+
"When generating music, use the `generate_media` tool with detailed prompts that specify:",
68+
"- The genre and style of music (e.g., classical, jazz, electronic)",
69+
"- The instruments and sounds to include",
70+
"- The tempo, mood and emotional qualities",
71+
"- The structure (intro, verses, chorus, bridge, etc.)",
72+
"Create rich, descriptive prompts that capture the desired musical elements.",
73+
"Focus on generating high-quality, complete instrumental pieces.",
74+
"Keep responses simple and only confirm when music is generated successfully.",
75+
"Do not include any file names, URLs or technical details in responses.",
76+
],
77+
markdown=True,
78+
debug_mode=True,
79+
add_history_to_messages=True,
80+
add_datetime_to_instructions=True,
81+
storage=SqliteAgentStorage(
82+
table_name="ml_music_agent", db_file=image_agent_storage_file
83+
),
84+
)
85+
6086
ml_video_agent = Agent(
6187
name="ModelsLab Video Agent",
6288
agent_id="ml_video_agent",
@@ -147,6 +173,7 @@
147173
agents=[
148174
image_agent,
149175
ml_gif_agent,
176+
ml_music_agent,
150177
ml_video_agent,
151178
fal_agent,
152179
gif_agent,

libs/agno/agno/models/response.py

+1
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,4 @@ class ModelResponse:
3838
class FileType(str, Enum):
3939
MP4 = "mp4"
4040
GIF = "gif"
41+
MP3 = "mp3"

libs/agno/agno/tools/models_labs.py

+105-68
Original file line numberDiff line numberDiff line change
@@ -1,124 +1,161 @@
11
import json
22
import time
33
from os import getenv
4-
from typing import Optional
4+
from typing import Any, Dict, Optional
55
from uuid import uuid4
66

77
from agno.agent import Agent
8-
from agno.media import ImageArtifact, VideoArtifact
8+
from agno.media import AudioArtifact, ImageArtifact, VideoArtifact
99
from agno.models.response import FileType
1010
from agno.tools import Toolkit
1111
from agno.utils.log import logger
1212

1313
try:
1414
import requests
15+
from requests.exceptions import RequestException
1516
except ImportError:
1617
raise ImportError("`requests` not installed. Please install using `pip install requests`")
1718

19+
MODELS_LAB_URLS = {
20+
"MP4": "https://modelslab.com/api/v6/video/text2video",
21+
"MP3": "https://modelslab.com/api/v6/voice/music_gen",
22+
"GIF": "https://modelslab.com/api/v6/video/text2video",
23+
}
24+
25+
MODELS_LAB_FETCH_URLS = {
26+
"MP4": "https://modelslab.com/api/v6/video/fetch",
27+
"MP3": "https://modelslab.com/api/v6/voice/fetch",
28+
"GIF": "https://modelslab.com/api/v6/video/fetch",
29+
}
30+
1831

1932
class ModelsLabTools(Toolkit):
2033
def __init__(
2134
self,
2235
api_key: Optional[str] = None,
23-
url: str = "https://modelslab.com/api/v6/video/text2video",
24-
fetch_url: str = "https://modelslab.com/api/v6/video/fetch",
25-
# Whether to wait for the video to be ready
2636
wait_for_completion: bool = False,
27-
# Time to add to the ETA to account for the time it takes to fetch the video
2837
add_to_eta: int = 15,
29-
# Maximum time to wait for the video to be ready
3038
max_wait_time: int = 60,
3139
file_type: FileType = FileType.MP4,
3240
):
3341
super().__init__(name="models_labs")
3442

35-
self.url = url
36-
self.fetch_url = fetch_url
43+
file_type_str = file_type.value.upper()
44+
self.url = MODELS_LAB_URLS[file_type_str]
45+
self.fetch_url = MODELS_LAB_FETCH_URLS[file_type_str]
3746
self.wait_for_completion = wait_for_completion
3847
self.add_to_eta = add_to_eta
3948
self.max_wait_time = max_wait_time
4049
self.file_type = file_type
4150
self.api_key = api_key or getenv("MODELS_LAB_API_KEY")
51+
4252
if not self.api_key:
4353
logger.error("MODELS_LAB_API_KEY not set. Please set the MODELS_LAB_API_KEY environment variable.")
4454

4555
self.register(self.generate_media)
4656

47-
def generate_media(self, agent: Agent, prompt: str) -> str:
48-
"""Use this function to generate a video or image given a prompt.
49-
50-
Args:
51-
prompt (str): A text description of the desired video.
57+
def _create_payload(self, prompt: str) -> Dict[str, Any]:
58+
"""Create payload based on file type."""
59+
base_payload: Dict[str, Any] = {
60+
"key": self.api_key,
61+
"prompt": prompt,
62+
"webhook": None,
63+
"track_id": None,
64+
}
65+
66+
if self.file_type in [FileType.MP4, FileType.GIF]:
67+
video_template = {
68+
"height": 512,
69+
"width": 512,
70+
"num_frames": 25,
71+
"negative_prompt": "low quality",
72+
"model_id": "cogvideox",
73+
"instant_response": False,
74+
"output_type": self.file_type.value,
75+
}
76+
base_payload |= video_template # Use |= instead of update()
77+
else:
78+
audio_template = {
79+
"base64": False,
80+
"temp": False,
81+
}
82+
base_payload |= audio_template # Use |= instead of update()
83+
84+
return base_payload
85+
86+
def _add_media_artifact(self, agent: Agent, media_id: str, media_url: str, eta: Optional[str] = None) -> None:
87+
"""Add appropriate media artifact based on file type."""
88+
if self.file_type == FileType.MP4:
89+
agent.add_video(VideoArtifact(id=str(media_id), url=media_url, eta=str(eta)))
90+
elif self.file_type == FileType.GIF:
91+
agent.add_image(ImageArtifact(id=str(media_id), url=media_url))
92+
elif self.file_type == FileType.MP3:
93+
agent.add_audio(AudioArtifact(id=str(media_id), url=media_url))
94+
95+
def _wait_for_media(self, media_id: str, eta: int) -> bool:
96+
"""Wait for media generation to complete."""
97+
time_to_wait = min(eta + self.add_to_eta, self.max_wait_time)
98+
logger.info(f"Waiting for {time_to_wait} seconds for {self.file_type.value} to be ready")
99+
100+
for seconds_waited in range(time_to_wait):
101+
try:
102+
fetch_response = requests.post(
103+
f"{self.fetch_url}/{media_id}",
104+
json={"key": self.api_key},
105+
headers={"Content-Type": "application/json"},
106+
)
107+
fetch_result = fetch_response.json()
108+
109+
if fetch_result.get("status") == "success":
110+
return True
111+
112+
time.sleep(1)
113+
114+
except RequestException as e:
115+
logger.warning(f"Error during fetch attempt {seconds_waited}: {e}")
116+
117+
return False
52118

53-
Returns:
54-
str: A message indicating if the video has been generated successfully or an error message.
55-
"""
119+
def generate_media(self, agent: Agent, prompt: str) -> str:
120+
"""Generate media (video, image, or audio) given a prompt."""
56121
if not self.api_key:
57122
return "Please set the MODELS_LAB_API_KEY"
58123

59124
try:
60-
payload = json.dumps(
61-
{
62-
"key": self.api_key,
63-
"prompt": prompt,
64-
"height": 512,
65-
"width": 512,
66-
"num_frames": 25,
67-
"webhook": None,
68-
"output_type": self.file_type.value,
69-
"track_id": None,
70-
"negative_prompt": "low quality",
71-
"model_id": "cogvideox",
72-
"instant_response": False,
73-
}
74-
)
75-
125+
payload = json.dumps(self._create_payload(prompt))
76126
headers = {"Content-Type": "application/json"}
77-
logger.debug(f"Generating video for prompt: {prompt}")
78-
response = requests.request("POST", self.url, data=payload, headers=headers)
127+
128+
logger.debug(f"Generating {self.file_type.value} for prompt: {prompt}")
129+
response = requests.post(self.url, data=payload, headers=headers)
79130
response.raise_for_status()
80131

81132
result = response.json()
133+
82134
if "error" in result:
83-
logger.error(f"Failed to generate video: {result['error']}")
135+
error_msg = f"Failed to generate {self.file_type.value}: {result['error']}"
136+
logger.error(error_msg)
84137
return f"Error: {result['error']}"
85138

86139
eta = result["eta"]
87140
url_links = result["future_links"]
88-
logger.info(f"Media will be ready in {eta} seconds")
89-
logger.info(f"Media URLs: {url_links}")
141+
media_id = str(uuid4())
90142

91-
video_id = str(uuid4())
92-
93-
logger.debug(f"Result: {result}")
94143
for media_url in url_links:
95-
if self.file_type == FileType.MP4:
96-
agent.add_video(VideoArtifact(id=str(video_id), url=media_url, eta=str(eta)))
97-
elif self.file_type == FileType.GIF:
98-
agent.add_image(ImageArtifact(id=str(video_id), url=media_url))
144+
self._add_media_artifact(agent, media_id, media_url, str(eta))
99145

100146
if self.wait_for_completion and isinstance(eta, int):
101-
video_ready = False
102-
seconds_waited = 0
103-
time_to_wait = min(eta + self.add_to_eta, self.max_wait_time)
104-
logger.info(f"Waiting for {time_to_wait} seconds for video to be ready")
105-
while not video_ready and seconds_waited < time_to_wait:
106-
time.sleep(1)
107-
seconds_waited += 1
108-
# Fetch the video from the ModelsLabs API
109-
fetch_payload = json.dumps({"key": self.api_key})
110-
fetch_headers = {"Content-Type": "application/json"}
111-
logger.debug(f"Fetching video from {self.fetch_url}/{video_id}")
112-
fetch_response = requests.request(
113-
"POST", f"{self.fetch_url}/{video_id}", data=fetch_payload, headers=fetch_headers
114-
)
115-
fetch_result = fetch_response.json()
116-
logger.debug(f"Fetch result: {fetch_result}")
117-
if fetch_result.get("status") == "success":
118-
video_ready = True
119-
break
120-
121-
return f"Video has been generated successfully and will be ready in {eta} seconds"
147+
if self._wait_for_media(media_id, eta):
148+
logger.info("Media generation completed successfully")
149+
else:
150+
logger.warning("Media generation timed out")
151+
152+
return f"{self.file_type.value.capitalize()} has been generated successfully and will be ready in {eta} seconds"
153+
154+
except RequestException as e:
155+
error_msg = f"Network error while generating {self.file_type.value}: {e}"
156+
logger.error(error_msg)
157+
return f"Error: {error_msg}"
122158
except Exception as e:
123-
logger.error(f"Failed to generate video: {e}")
124-
return f"Error: {e}"
159+
error_msg = f"Unexpected error while generating {self.file_type.value}: {e}"
160+
logger.error(error_msg)
161+
return f"Error: {error_msg}"

0 commit comments

Comments
 (0)