|
1 | 1 | import json
|
2 | 2 | import time
|
3 | 3 | from os import getenv
|
4 |
| -from typing import Optional |
| 4 | +from typing import Any, Dict, Optional |
5 | 5 | from uuid import uuid4
|
6 | 6 |
|
7 | 7 | from agno.agent import Agent
|
8 |
| -from agno.media import ImageArtifact, VideoArtifact |
| 8 | +from agno.media import AudioArtifact, ImageArtifact, VideoArtifact |
9 | 9 | from agno.models.response import FileType
|
10 | 10 | from agno.tools import Toolkit
|
11 | 11 | from agno.utils.log import logger
|
12 | 12 |
|
13 | 13 | try:
|
14 | 14 | import requests
|
| 15 | + from requests.exceptions import RequestException |
15 | 16 | except ImportError:
|
16 | 17 | raise ImportError("`requests` not installed. Please install using `pip install requests`")
|
17 | 18 |
|
| 19 | +MODELS_LAB_URLS = { |
| 20 | + "MP4": "https://modelslab.com/api/v6/video/text2video", |
| 21 | + "MP3": "https://modelslab.com/api/v6/voice/music_gen", |
| 22 | + "GIF": "https://modelslab.com/api/v6/video/text2video", |
| 23 | +} |
| 24 | + |
| 25 | +MODELS_LAB_FETCH_URLS = { |
| 26 | + "MP4": "https://modelslab.com/api/v6/video/fetch", |
| 27 | + "MP3": "https://modelslab.com/api/v6/voice/fetch", |
| 28 | + "GIF": "https://modelslab.com/api/v6/video/fetch", |
| 29 | +} |
| 30 | + |
18 | 31 |
|
19 | 32 | class ModelsLabTools(Toolkit):
|
20 | 33 | def __init__(
|
21 | 34 | self,
|
22 | 35 | api_key: Optional[str] = None,
|
23 |
| - url: str = "https://modelslab.com/api/v6/video/text2video", |
24 |
| - fetch_url: str = "https://modelslab.com/api/v6/video/fetch", |
25 |
| - # Whether to wait for the video to be ready |
26 | 36 | wait_for_completion: bool = False,
|
27 |
| - # Time to add to the ETA to account for the time it takes to fetch the video |
28 | 37 | add_to_eta: int = 15,
|
29 |
| - # Maximum time to wait for the video to be ready |
30 | 38 | max_wait_time: int = 60,
|
31 | 39 | file_type: FileType = FileType.MP4,
|
32 | 40 | ):
|
33 | 41 | super().__init__(name="models_labs")
|
34 | 42 |
|
35 |
| - self.url = url |
36 |
| - self.fetch_url = fetch_url |
| 43 | + file_type_str = file_type.value.upper() |
| 44 | + self.url = MODELS_LAB_URLS[file_type_str] |
| 45 | + self.fetch_url = MODELS_LAB_FETCH_URLS[file_type_str] |
37 | 46 | self.wait_for_completion = wait_for_completion
|
38 | 47 | self.add_to_eta = add_to_eta
|
39 | 48 | self.max_wait_time = max_wait_time
|
40 | 49 | self.file_type = file_type
|
41 | 50 | self.api_key = api_key or getenv("MODELS_LAB_API_KEY")
|
| 51 | + |
42 | 52 | if not self.api_key:
|
43 | 53 | logger.error("MODELS_LAB_API_KEY not set. Please set the MODELS_LAB_API_KEY environment variable.")
|
44 | 54 |
|
45 | 55 | self.register(self.generate_media)
|
46 | 56 |
|
47 |
| - def generate_media(self, agent: Agent, prompt: str) -> str: |
48 |
| - """Use this function to generate a video or image given a prompt. |
49 |
| -
|
50 |
| - Args: |
51 |
| - prompt (str): A text description of the desired video. |
| 57 | + def _create_payload(self, prompt: str) -> Dict[str, Any]: |
| 58 | + """Create payload based on file type.""" |
| 59 | + base_payload: Dict[str, Any] = { |
| 60 | + "key": self.api_key, |
| 61 | + "prompt": prompt, |
| 62 | + "webhook": None, |
| 63 | + "track_id": None, |
| 64 | + } |
| 65 | + |
| 66 | + if self.file_type in [FileType.MP4, FileType.GIF]: |
| 67 | + video_template = { |
| 68 | + "height": 512, |
| 69 | + "width": 512, |
| 70 | + "num_frames": 25, |
| 71 | + "negative_prompt": "low quality", |
| 72 | + "model_id": "cogvideox", |
| 73 | + "instant_response": False, |
| 74 | + "output_type": self.file_type.value, |
| 75 | + } |
| 76 | + base_payload |= video_template # Use |= instead of update() |
| 77 | + else: |
| 78 | + audio_template = { |
| 79 | + "base64": False, |
| 80 | + "temp": False, |
| 81 | + } |
| 82 | + base_payload |= audio_template # Use |= instead of update() |
| 83 | + |
| 84 | + return base_payload |
| 85 | + |
| 86 | + def _add_media_artifact(self, agent: Agent, media_id: str, media_url: str, eta: Optional[str] = None) -> None: |
| 87 | + """Add appropriate media artifact based on file type.""" |
| 88 | + if self.file_type == FileType.MP4: |
| 89 | + agent.add_video(VideoArtifact(id=str(media_id), url=media_url, eta=str(eta))) |
| 90 | + elif self.file_type == FileType.GIF: |
| 91 | + agent.add_image(ImageArtifact(id=str(media_id), url=media_url)) |
| 92 | + elif self.file_type == FileType.MP3: |
| 93 | + agent.add_audio(AudioArtifact(id=str(media_id), url=media_url)) |
| 94 | + |
| 95 | + def _wait_for_media(self, media_id: str, eta: int) -> bool: |
| 96 | + """Wait for media generation to complete.""" |
| 97 | + time_to_wait = min(eta + self.add_to_eta, self.max_wait_time) |
| 98 | + logger.info(f"Waiting for {time_to_wait} seconds for {self.file_type.value} to be ready") |
| 99 | + |
| 100 | + for seconds_waited in range(time_to_wait): |
| 101 | + try: |
| 102 | + fetch_response = requests.post( |
| 103 | + f"{self.fetch_url}/{media_id}", |
| 104 | + json={"key": self.api_key}, |
| 105 | + headers={"Content-Type": "application/json"}, |
| 106 | + ) |
| 107 | + fetch_result = fetch_response.json() |
| 108 | + |
| 109 | + if fetch_result.get("status") == "success": |
| 110 | + return True |
| 111 | + |
| 112 | + time.sleep(1) |
| 113 | + |
| 114 | + except RequestException as e: |
| 115 | + logger.warning(f"Error during fetch attempt {seconds_waited}: {e}") |
| 116 | + |
| 117 | + return False |
52 | 118 |
|
53 |
| - Returns: |
54 |
| - str: A message indicating if the video has been generated successfully or an error message. |
55 |
| - """ |
| 119 | + def generate_media(self, agent: Agent, prompt: str) -> str: |
| 120 | + """Generate media (video, image, or audio) given a prompt.""" |
56 | 121 | if not self.api_key:
|
57 | 122 | return "Please set the MODELS_LAB_API_KEY"
|
58 | 123 |
|
59 | 124 | try:
|
60 |
| - payload = json.dumps( |
61 |
| - { |
62 |
| - "key": self.api_key, |
63 |
| - "prompt": prompt, |
64 |
| - "height": 512, |
65 |
| - "width": 512, |
66 |
| - "num_frames": 25, |
67 |
| - "webhook": None, |
68 |
| - "output_type": self.file_type.value, |
69 |
| - "track_id": None, |
70 |
| - "negative_prompt": "low quality", |
71 |
| - "model_id": "cogvideox", |
72 |
| - "instant_response": False, |
73 |
| - } |
74 |
| - ) |
75 |
| - |
| 125 | + payload = json.dumps(self._create_payload(prompt)) |
76 | 126 | headers = {"Content-Type": "application/json"}
|
77 |
| - logger.debug(f"Generating video for prompt: {prompt}") |
78 |
| - response = requests.request("POST", self.url, data=payload, headers=headers) |
| 127 | + |
| 128 | + logger.debug(f"Generating {self.file_type.value} for prompt: {prompt}") |
| 129 | + response = requests.post(self.url, data=payload, headers=headers) |
79 | 130 | response.raise_for_status()
|
80 | 131 |
|
81 | 132 | result = response.json()
|
| 133 | + |
82 | 134 | if "error" in result:
|
83 |
| - logger.error(f"Failed to generate video: {result['error']}") |
| 135 | + error_msg = f"Failed to generate {self.file_type.value}: {result['error']}" |
| 136 | + logger.error(error_msg) |
84 | 137 | return f"Error: {result['error']}"
|
85 | 138 |
|
86 | 139 | eta = result["eta"]
|
87 | 140 | url_links = result["future_links"]
|
88 |
| - logger.info(f"Media will be ready in {eta} seconds") |
89 |
| - logger.info(f"Media URLs: {url_links}") |
| 141 | + media_id = str(uuid4()) |
90 | 142 |
|
91 |
| - video_id = str(uuid4()) |
92 |
| - |
93 |
| - logger.debug(f"Result: {result}") |
94 | 143 | for media_url in url_links:
|
95 |
| - if self.file_type == FileType.MP4: |
96 |
| - agent.add_video(VideoArtifact(id=str(video_id), url=media_url, eta=str(eta))) |
97 |
| - elif self.file_type == FileType.GIF: |
98 |
| - agent.add_image(ImageArtifact(id=str(video_id), url=media_url)) |
| 144 | + self._add_media_artifact(agent, media_id, media_url, str(eta)) |
99 | 145 |
|
100 | 146 | if self.wait_for_completion and isinstance(eta, int):
|
101 |
| - video_ready = False |
102 |
| - seconds_waited = 0 |
103 |
| - time_to_wait = min(eta + self.add_to_eta, self.max_wait_time) |
104 |
| - logger.info(f"Waiting for {time_to_wait} seconds for video to be ready") |
105 |
| - while not video_ready and seconds_waited < time_to_wait: |
106 |
| - time.sleep(1) |
107 |
| - seconds_waited += 1 |
108 |
| - # Fetch the video from the ModelsLabs API |
109 |
| - fetch_payload = json.dumps({"key": self.api_key}) |
110 |
| - fetch_headers = {"Content-Type": "application/json"} |
111 |
| - logger.debug(f"Fetching video from {self.fetch_url}/{video_id}") |
112 |
| - fetch_response = requests.request( |
113 |
| - "POST", f"{self.fetch_url}/{video_id}", data=fetch_payload, headers=fetch_headers |
114 |
| - ) |
115 |
| - fetch_result = fetch_response.json() |
116 |
| - logger.debug(f"Fetch result: {fetch_result}") |
117 |
| - if fetch_result.get("status") == "success": |
118 |
| - video_ready = True |
119 |
| - break |
120 |
| - |
121 |
| - return f"Video has been generated successfully and will be ready in {eta} seconds" |
| 147 | + if self._wait_for_media(media_id, eta): |
| 148 | + logger.info("Media generation completed successfully") |
| 149 | + else: |
| 150 | + logger.warning("Media generation timed out") |
| 151 | + |
| 152 | + return f"{self.file_type.value.capitalize()} has been generated successfully and will be ready in {eta} seconds" |
| 153 | + |
| 154 | + except RequestException as e: |
| 155 | + error_msg = f"Network error while generating {self.file_type.value}: {e}" |
| 156 | + logger.error(error_msg) |
| 157 | + return f"Error: {error_msg}" |
122 | 158 | except Exception as e:
|
123 |
| - logger.error(f"Failed to generate video: {e}") |
124 |
| - return f"Error: {e}" |
| 159 | + error_msg = f"Unexpected error while generating {self.file_type.value}: {e}" |
| 160 | + logger.error(error_msg) |
| 161 | + return f"Error: {error_msg}" |
0 commit comments