diff --git a/pkgs/community/swarmauri_llm_communityleptonai/README.md b/pkgs/community/swarmauri_llm_communityleptonai/README.md
new file mode 100644
index 00000000..cd26902a
--- /dev/null
+++ b/pkgs/community/swarmauri_llm_communityleptonai/README.md
@@ -0,0 +1,35 @@
+# Swarmauri LeptonAI Community Package
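+
+Community bindings for the Lepton AI platform: `LeptonAIModel` (chat
+completions) and `LeptonAIImgGenModel` (SDXL image generation).
+
+A minimal usage sketch, assuming a valid `LEPTON_API_KEY` environment
+variable (class and method names come from the modules in this package):
+
+```python
+import os
+
+from swarmauri_llm_communityleptonai import LeptonAIImgGenModel
+
+# Generate an image from a text prompt and save it to disk.
+model = LeptonAIImgGenModel(api_key=os.environ["LEPTON_API_KEY"])
+image_bytes = model.generate_image(prompt="A cute cat playing with a ball of yarn")
+LeptonAIImgGenModel.save_image(image_bytes, "cat.png")
+```
+
+Chat completion follows the same pattern (a sketch; the conversation
+classes come from `swarmauri_standard`, as used in this package's tests):
+
+```python
+import os
+
+from swarmauri_llm_communityleptonai import LeptonAIModel
+from swarmauri_standard.conversations.Conversation import Conversation
+from swarmauri_standard.messages.HumanMessage import HumanMessage
+
+llm = LeptonAIModel(api_key=os.environ["LEPTON_API_KEY"])  # defaults to "llama3-8b"
+conversation = Conversation()
+conversation.add_message(HumanMessage(content="Hello"))
+llm.predict(conversation=conversation)
+print(conversation.get_last().content)
+```
\ No newline at end of file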
+ """ + + api_key: str = Field(default_factory=lambda: os.environ.get("LEPTON_API_KEY")) + model_name: str = Field(default="sdxl") + type: Literal["LeptonAIImgGenModel"] = "LeptonAIImgGenModel" + base_url: str = Field(default="https://sdxl.lepton.run") + + model_config = ConfigDict(protected_namespaces=()) + + def __init__(self, **data): + super().__init__(**data) + if self.api_key: + os.environ["LEPTON_API_KEY"] = self.api_key + + def _send_request(self, prompt: str, **kwargs) -> bytes: + """Send a request to Lepton AI's API for image generation.""" + client = requests.Session() + client.headers.update({"Authorization": f"Bearer {self.api_key}"}) + + payload = { + "prompt": prompt, + "height": kwargs.get("height", 1024), + "width": kwargs.get("width", 1024), + "guidance_scale": kwargs.get("guidance_scale", 5), + "high_noise_frac": kwargs.get("high_noise_frac", 0.75), + "seed": kwargs.get("seed", None), + "steps": kwargs.get("steps", 30), + "use_refiner": kwargs.get("use_refiner", False), + } + + response = client.post(f"{self.base_url}/run", json=payload) + response.raise_for_status() + return response.content + + def generate_image(self, prompt: str, **kwargs) -> bytes: + """Generates an image based on the prompt and returns the image as bytes.""" + return self._send_request(prompt, **kwargs) + + async def agenerate_image(self, prompt: str, **kwargs) -> bytes: + """Asynchronously generates an image based on the prompt and returns the image as bytes.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.generate_image, prompt, **kwargs) + + def batch(self, prompts: List[str], **kwargs) -> List[bytes]: + """ + Generates images for a batch of prompts. + Returns a list of image bytes. + """ + image_bytes_list = [] + for prompt in prompts: + image_bytes = self.generate_image(prompt=prompt, **kwargs) + image_bytes_list.append(image_bytes) + return image_bytes_list + + async def abatch( + self, prompts: List[str], max_concurrent: int = 5, **kwargs + ) -> List[bytes]: + """ + Asynchronously generates images for a batch of prompts. + Returns a list of image bytes. 
+ """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_prompt(prompt): + async with semaphore: + return await self.agenerate_image(prompt=prompt, **kwargs) + + tasks = [process_prompt(prompt) for prompt in prompts] + return await asyncio.gather(*tasks) + + @staticmethod + def save_image(image_bytes: bytes, filename: str): + """Utility method to save the image bytes to a file.""" + with open(filename, "wb") as f: + f.write(image_bytes) + print(f"Image saved as {filename}") + + @staticmethod + def display_image(image_bytes: bytes): + """Utility method to display the image using PIL.""" + image = Image.open(BytesIO(image_bytes)) + image.show() diff --git a/pkgs/community/swarmauri_llm_communityleptonai/swarmauri_llm_communityleptonai/LeptonAIModel.py b/pkgs/community/swarmauri_llm_communityleptonai/swarmauri_llm_communityleptonai/LeptonAIModel.py new file mode 100644 index 00000000..539b3fa7 --- /dev/null +++ b/pkgs/community/swarmauri_llm_communityleptonai/swarmauri_llm_communityleptonai/LeptonAIModel.py @@ -0,0 +1,285 @@ +import json +from openai import OpenAI, AsyncOpenAI +from typing import List, Dict, Literal, Optional, Iterator, AsyncIterator +import asyncio +from pydantic import Field +from swarmauri_core.typing import SubclassUnion +from swarmauri_base.messages.MessageBase import MessageBase +from swarmauri_base.llms.base.LLMBase import LLMBase + +from swarmauri_standard.messages.AgentMessage import AgentMessage, UsageData +from swarmauri.utils.duration_manager import DurationManager + + +class LeptonAIModel(LLMBase): + """ + Provider resources: https://www.lepton.ai/playground + """ + + api_key: str + allowed_models: List[str] = [ + "llama2-13b", + "llama3-1-405b", + "llama3-1-70b", + "llama3-1-8b", + "llama3-70b", + "llama3-8b", + "mixtral-8x7b", + "mistral-7b", + "nous-hermes-llama2-13b", + "openchat-3-5", + "qwen2-72b", + "toppy-m-7b", + "wizardlm-2-7b", + "wizardlm-2-8x22b", + ] + + name: str = "llama3-8b" + type: Literal["LeptonAIModel"] = "LeptonAIModel" + client: OpenAI = Field(default=None, exclude=True) + async_client: AsyncOpenAI = Field(default=None, exclude=True) + + def __init__(self, **data): + super().__init__(**data) + url = f"https://{self.name}.lepton.run/api/v1/" + self.client = OpenAI(base_url=url, api_key=self.api_key) + self.async_client = AsyncOpenAI(base_url=url, api_key=self.api_key) + + def _format_messages( + self, messages: List[SubclassUnion[MessageBase]] + ) -> List[Dict[str, str]]: + message_properties = ["content", "role", "name"] + formatted_messages = [ + message.model_dump(include=message_properties, exclude_none=True) + for message in messages + ] + return formatted_messages + + def _get_system_context(self, messages: List[SubclassUnion[MessageBase]]) -> str: + system_context = None + for message in messages: + if message.role == "system": + system_context = message.content + return system_context + + def _prepare_messages(self, conversation): + formatted_messages = self._format_messages(conversation.history) + system_context = self._get_system_context(conversation.history) + if system_context: + formatted_messages = [ + {"role": "system", "content": system_context}, + formatted_messages[-1], + ] + return formatted_messages + + def _prepare_usage_data( + self, + usage_data, + prompt_time: float = 0.0, + completion_time: float = 0.0, + ): + """ + Prepares and extracts usage data and response timing. 
+ """ + total_time = prompt_time + completion_time + + usage = UsageData( + prompt_tokens=usage_data.get("prompt_tokens", 0), + completion_tokens=usage_data.get("completion_tokens", 0), + total_tokens=usage_data.get("total_tokens", 0), + prompt_time=prompt_time, + completion_time=completion_time, + total_time=total_time, + ) + + return usage + + def predict( + self, + conversation, + temperature: Optional[float] = 0.5, + max_tokens: Optional[int] = 256, + top_p: Optional[float] = 0.8, + stream: Optional[bool] = False, + ): + formatted_messages = self._prepare_messages(conversation) + + with DurationManager() as prompt_timer: + response = self.client.chat.completions.create( + model=self.name, + messages=formatted_messages, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + stream=stream, + ) + + result = json.loads(response.model_dump_json()) + message_content = result["choices"][0]["message"]["content"] + usage_data = result.get("usage", {}) + + usage = self._prepare_usage_data( + usage_data, + prompt_timer.duration, + ) + + conversation.add_message(AgentMessage(content=message_content, usage=usage)) + return conversation + + async def apredict( + self, + conversation, + temperature: Optional[float] = 0.5, + max_tokens: Optional[int] = 256, + top_p: Optional[float] = 0.8, + ): + """Asynchronous version of predict""" + formatted_messages = self._prepare_messages(conversation) + + with DurationManager() as prompt_timer: + response = await self.async_client.chat.completions.create( + model=self.name, + messages=formatted_messages, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + ) + + result = json.loads(response.model_dump_json()) + message_content = result["choices"][0]["message"]["content"] + usage_data = result.get("usage", {}) + + usage = self._prepare_usage_data( + usage_data, + prompt_timer.duration, + ) + + conversation.add_message(AgentMessage(content=message_content, usage=usage)) + return conversation + + def stream( + self, + conversation, + temperature: Optional[float] = 0.5, + max_tokens: Optional[int] = 256, + top_p: Optional[float] = 0.8, + ) -> Iterator[str]: + """Synchronously stream the response token by token""" + formatted_messages = self._prepare_messages(conversation) + + with DurationManager() as prompt_timer: + stream = self.client.chat.completions.create( + model=self.name, + messages=formatted_messages, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + stream=True, + stream_options={"include_usage": True}, + ) + + collected_content = [] + usage_data = {} + + with DurationManager() as completion_timer: + for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + collected_content.append(content) + yield content + + if hasattr(chunk, "usage") and chunk.usage is not None: + usage_data = chunk.usage + + full_content = "".join(collected_content) + usage = self._prepare_usage_data( + usage_data.model_dump(), + prompt_timer.duration, + completion_timer.duration, + ) + + conversation.add_message(AgentMessage(content=full_content, usage=usage)) + + async def astream( + self, + conversation, + temperature: Optional[float] = 0.5, + max_tokens: Optional[int] = 256, + top_p: Optional[float] = 0.8, + ) -> AsyncIterator[str]: + """Asynchronously stream the response token by token""" + formatted_messages = self._prepare_messages(conversation) + + with DurationManager() as prompt_timer: + stream = await self.async_client.chat.completions.create( + model=self.name, 
+                messages=formatted_messages,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                stream=True,
+                stream_options={"include_usage": True},
+            )
+
+        usage_data = {}
+        collected_content = []
+
+        with DurationManager() as completion_timer:
+            async for chunk in stream:
+                if chunk.choices and chunk.choices[0].delta.content:
+                    content = chunk.choices[0].delta.content
+                    collected_content.append(content)
+                    yield content
+
+                if hasattr(chunk, "usage") and chunk.usage is not None:
+                    usage_data = chunk.usage.model_dump()  # final chunk carries usage
+
+        full_content = "".join(collected_content)
+
+        usage = self._prepare_usage_data(
+            usage_data,
+            prompt_timer.duration,
+            completion_timer.duration,
+        )
+        conversation.add_message(AgentMessage(content=full_content, usage=usage))
+
+    def batch(
+        self,
+        conversations: List,
+        temperature: Optional[float] = 0.5,
+        max_tokens: Optional[int] = 256,
+        top_p: Optional[float] = 0.8,
+    ) -> List:
+        """Synchronously process multiple conversations."""
+        return [
+            self.predict(
+                conv,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+            )
+            for conv in conversations
+        ]
+
+    async def abatch(
+        self,
+        conversations: List,
+        temperature: Optional[float] = 0.5,
+        max_tokens: Optional[int] = 256,
+        top_p: Optional[float] = 0.8,
+        max_concurrent: int = 5,
+    ) -> List:
+        """Process multiple conversations in parallel with controlled concurrency."""
+        semaphore = asyncio.Semaphore(max_concurrent)
+
+        async def process_conversation(conv):
+            async with semaphore:
+                return await self.apredict(
+                    conv,
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                    top_p=top_p,
+                )
+
+        tasks = [process_conversation(conv) for conv in conversations]
+        return await asyncio.gather(*tasks)
diff --git a/pkgs/community/swarmauri_llm_communityleptonai/swarmauri_llm_communityleptonai/__init__.py b/pkgs/community/swarmauri_llm_communityleptonai/swarmauri_llm_communityleptonai/__init__.py
new file mode 100644
index 00000000..80d328b7
--- /dev/null
+++ b/pkgs/community/swarmauri_llm_communityleptonai/swarmauri_llm_communityleptonai/__init__.py
@@ -0,0 +1,17 @@
+from .LeptonAIImgGenModel import LeptonAIImgGenModel
+from .LeptonAIModel import LeptonAIModel
+
+__version__ = "0.6.0.dev1"
+__long_desc__ = """
+
+# Swarmauri LeptonAI Plugin
+
+Components Included:
+- LeptonAIModel
+- LeptonAIImgGenModel
+
+Visit us at: https://swarmauri.com
+Follow us at: https://github.com/swarmauri
+Star us at: https://github.com/swarmauri/swarmauri-sdk
+
+"""
diff --git a/pkgs/community/swarmauri_llm_communityleptonai/tests/unit/LeptonAIImgGenModel_unit_test.py b/pkgs/community/swarmauri_llm_communityleptonai/tests/unit/LeptonAIImgGenModel_unit_test.py
new file mode 100644
index 00000000..4d6fd8e6
--- /dev/null
+++ b/pkgs/community/swarmauri_llm_communityleptonai/tests/unit/LeptonAIImgGenModel_unit_test.py
@@ -0,0 +1,82 @@
+import pytest
+import os
+from swarmauri_llm_communityleptonai.LeptonAIImgGenModel import LeptonAIImgGenModel
+from swarmauri_standard.utils.timeout_wrapper import timeout
+from dotenv import load_dotenv
+
+load_dotenv()
+
+LEPTON_API_KEY = os.getenv("LEPTON_API_KEY")
+
+
+@pytest.fixture(scope="module")
+def lepton_ai_imggen_model():
+    if not LEPTON_API_KEY:
+        pytest.skip("Skipping tests due to missing Lepton API key")
+    model = LeptonAIImgGenModel(api_key=LEPTON_API_KEY)
+    return model
+
+
+@pytest.mark.unit
+def test_ubc_type(lepton_ai_imggen_model):
+    assert lepton_ai_imggen_model.type == "LeptonAIImgGenModel"
+
+
+@pytest.mark.unit
+def test_serialization(lepton_ai_imggen_model):
+    assert (
+        lepton_ai_imggen_model.id
+        == LeptonAIImgGenModel.model_validate_json(
+            lepton_ai_imggen_model.model_dump_json()
+        ).id
+    )
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_generate_image(lepton_ai_imggen_model):
+    prompt = "A cute cat playing with a ball of yarn"
+    image_bytes = lepton_ai_imggen_model.generate_image(prompt=prompt)
+    assert isinstance(image_bytes, bytes)
+    assert len(image_bytes) > 0
+
+
+@timeout(5)
+@pytest.mark.asyncio
+@pytest.mark.unit
+async def test_agenerate_image(lepton_ai_imggen_model):
+    prompt = "A serene landscape with mountains and a lake"
+    image_bytes = await lepton_ai_imggen_model.agenerate_image(prompt=prompt)
+    assert isinstance(image_bytes, bytes)
+    assert len(image_bytes) > 0
+
+
+@timeout(5)
+@pytest.mark.unit
+def test_batch(lepton_ai_imggen_model):
+    prompts = [
+        "A futuristic city skyline",
+        "A tropical beach at sunset",
+        "A steaming cup of coffee on a wooden table",
+    ]
+    result_image_bytes_list = lepton_ai_imggen_model.batch(prompts=prompts)
+    assert len(result_image_bytes_list) == len(prompts)
+    for image_bytes in result_image_bytes_list:
+        assert isinstance(image_bytes, bytes)
+        assert len(image_bytes) > 0
+
+
+@timeout(5)
+@pytest.mark.asyncio
+@pytest.mark.unit
+async def test_abatch(lepton_ai_imggen_model):
+    prompts = [
+        "An abstract painting with vibrant colors",
+        "A snowy mountain peak",
+        "A vintage car on a rural road",
+    ]
+    result_image_bytes_list = await lepton_ai_imggen_model.abatch(prompts=prompts)
+    assert len(result_image_bytes_list) == len(prompts)
+    for image_bytes in result_image_bytes_list:
+        assert isinstance(image_bytes, bytes)
+        assert len(image_bytes) > 0
diff --git a/pkgs/community/swarmauri_llm_communityleptonai/tests/unit/LeptonAIModel_unit_test.py b/pkgs/community/swarmauri_llm_communityleptonai/tests/unit/LeptonAIModel_unit_test.py
new file mode 100644
index 00000000..8fc82f79
--- /dev/null
+++ b/pkgs/community/swarmauri_llm_communityleptonai/tests/unit/LeptonAIModel_unit_test.py
@@ -0,0 +1,212 @@
+import logging
+
+import pytest
+import time
+import os
+import asyncio
+from swarmauri_llm_communityleptonai.LeptonAIModel import LeptonAIModel as LLM
+from swarmauri_standard.conversations.Conversation import Conversation
+from swarmauri_standard.messages.HumanMessage import HumanMessage
+from swarmauri_standard.messages.SystemMessage import SystemMessage
+from swarmauri_standard.utils.timeout_wrapper import timeout
+
+from swarmauri_standard.messages.AgentMessage import UsageData
+from dotenv import load_dotenv
+
+load_dotenv()
+
+API_KEY = os.getenv("LEPTON_API_KEY")
+
+
+@pytest.fixture(scope="module")
+def leptonai_model():
+    if not API_KEY:
+        pytest.skip("Skipping due to environment variable not set")
+    llm = LLM(api_key=API_KEY)
+    return llm
+
+
+def get_allowed_models():
+    if not API_KEY:
+        return []
+    llm = LLM(api_key=API_KEY)
+    return llm.allowed_models
+
+
+@pytest.mark.unit
+def test_ubc_resource(leptonai_model):
+    assert leptonai_model.resource == "LLM"
+
+
+@pytest.mark.unit
+def test_ubc_type(leptonai_model):
+    assert leptonai_model.type == "LeptonAIModel"
+
+
+@pytest.mark.unit
+def test_serialization(leptonai_model):
+    assert (
+        leptonai_model.id
+        == LLM.model_validate_json(leptonai_model.model_dump_json()).id
+    )
+
+
+@pytest.mark.unit
+def test_default_name(leptonai_model):
+    assert leptonai_model.name == "llama3-8b"
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+def test_no_system_context(leptonai_model, model_name):
+    model = leptonai_model
+    model.name = model_name
+    conversation = Conversation()
+
+    input_data = "Hello"
+    human_message = HumanMessage(content=input_data)
+    conversation.add_message(human_message)
+    time.sleep(1)
+
+    model.predict(conversation=conversation)
+    prediction = conversation.get_last().content
+    assert isinstance(prediction, str)
+    assert isinstance(conversation.get_last().usage, UsageData)
+    logging.info(conversation.get_last().usage)
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+def test_preamble_system_context(leptonai_model, model_name):
+    model = leptonai_model
+    model.name = model_name
+    conversation = Conversation()
+
+    system_context = 'You only respond with the following phrase, "Jeff"'
+    system_message = SystemMessage(content=system_context)
+    conversation.add_message(system_message)
+
+    input_data = "Hi"
+    human_message = HumanMessage(content=input_data)
+    conversation.add_message(human_message)
+    time.sleep(1)
+
+    model.predict(conversation=conversation)
+    prediction = conversation.get_last().content
+    assert isinstance(prediction, str)
+    assert "Jeff" in prediction
+    assert isinstance(conversation.get_last().usage, UsageData)
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+def test_stream(leptonai_model, model_name):
+    model = leptonai_model
+    model.name = model_name
+    conversation = Conversation()
+
+    input_data = "Write a short story about a cat."
+    human_message = HumanMessage(content=input_data)
+    conversation.add_message(human_message)
+    time.sleep(1)
+
+    collected_tokens = []
+    for token in model.stream(conversation=conversation):
+        assert isinstance(token, str)
+        collected_tokens.append(token)
+
+    full_response = "".join(collected_tokens)
+    assert len(full_response) > 0
+    assert conversation.get_last().content == full_response
+    assert isinstance(conversation.get_last().usage, UsageData)
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.unit
+async def test_apredict(leptonai_model, model_name):
+    model = leptonai_model
+    model.name = model_name
+    conversation = Conversation()
+
+    input_data = "Hello"
+    human_message = HumanMessage(content=input_data)
+    conversation.add_message(human_message)
+    await asyncio.sleep(1)
+
+    result = await model.apredict(conversation=conversation)
+    prediction = result.get_last().content
+    assert isinstance(prediction, str)
+    assert isinstance(conversation.get_last().usage, UsageData)
+    logging.info(conversation.get_last().usage)
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.unit
+async def test_astream(leptonai_model, model_name):
+    model = leptonai_model
+    model.name = model_name
+    conversation = Conversation()
+
+    input_data = "Write a short story about a dog."
+    human_message = HumanMessage(content=input_data)
+    conversation.add_message(human_message)
+    await asyncio.sleep(1)
+
+    collected_tokens = []
+    async for token in model.astream(conversation=conversation):
+        assert isinstance(token, str)
+        collected_tokens.append(token)
+
+    full_response = "".join(collected_tokens)
+    assert len(full_response) > 0
+    assert conversation.get_last().content == full_response
+    assert isinstance(conversation.get_last().usage, UsageData)
+    logging.info(conversation.get_last().usage)
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.unit
+def test_batch(leptonai_model, model_name):
+    model = leptonai_model
+    model.name = model_name
+    conversations = []
+    for prompt in ["Hello", "Hi there", "Good morning"]:
+        conv = Conversation()
+        conv.add_message(HumanMessage(content=prompt))
+        conversations.append(conv)
+    time.sleep(1)
+
+    results = model.batch(conversations=conversations)
+    assert len(results) == len(conversations)
+    for result in results:
+        assert isinstance(result.get_last().content, str)
+        assert isinstance(result.get_last().usage, UsageData)
+
+
+@timeout(5)
+@pytest.mark.parametrize("model_name", get_allowed_models())
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.unit
+async def test_abatch(leptonai_model, model_name):
+    model = leptonai_model
+    model.name = model_name
+    conversations = []
+    for prompt in ["Hello", "Hi there", "Good morning"]:
+        conv = Conversation()
+        conv.add_message(HumanMessage(content=prompt))
+        conversations.append(conv)
+    await asyncio.sleep(1)
+
+    results = await model.abatch(conversations=conversations)
+    assert len(results) == len(conversations)
+    for result in results:
+        assert isinstance(result.get_last().content, str)
+        assert isinstance(result.get_last().usage, UsageData)
diff --git a/pkgs/standards/swarmauri_standard/swarmauri_standard/image_gens/BlackForestImgGenModel.py b/pkgs/standards/swarmauri_standard/swarmauri_standard/image_gens/BlackForestImgGenModel.py
index 7a3c3ffd..7e2f8774 100644
--- a/pkgs/standards/swarmauri_standard/swarmauri_standard/image_gens/BlackForestImgGenModel.py
+++ b/pkgs/standards/swarmauri_standard/swarmauri_standard/image_gens/BlackForestImgGenModel.py
@@ -8,7 +8,8 @@
 from swarmauri_base.image_gens.ImageGenBase import ImageGenBase
 from swarmauri_core.ComponentBase import ComponentBase
 
-@ComponentBase.register_type(ImageGenBase, 'BlackForestImgGenModel')
+
+@ComponentBase.register_type(ImageGenBase, "BlackForestImgGenModel")
 class BlackForestImgGenModel(ImageGenBase):
     """
     A model for generating images using FluxPro's image generation models through the Black Forest API.