Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Swarmauri LeptonAI Model Community Packages #1084

Merged
merged 2 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkgs/community/swarmauri_llm_communityleptonai/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Swarmauri Example Community Package
57 changes: 57 additions & 0 deletions pkgs/community/swarmauri_llm_communityleptonai/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
[tool.poetry]
name = "swarmauri_llm_communityleptonai"
version = "0.6.0.dev1"
description = "Swarmauri Lepton AI Model"
authors = ["Jacob Stewart <[email protected]>"]
license = "Apache-2.0"
readme = "README.md"
repository = "http://github.com/swarmauri/swarmauri-sdk"
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12"
]

[tool.poetry.dependencies]
python = ">=3.10,<3.13"

# Swarmauri
swarmauri_core = { path = "../../core" }
swarmauri_base = { path = "../../base" }

# Dependencies
leptonai = "^0.22.0"

[tool.poetry.group.dev.dependencies]
flake8 = "^7.0"
pytest = "^8.0"
pytest-asyncio = ">=0.24.0"
pytest-xdist = "^3.6.1"
pytest-json-report = "^1.5.0"
python-dotenv = "*"
requests = "^2.32.3"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
norecursedirs = ["combined", "scripts"]

markers = [
"test: standard test",
"unit: Unit tests",
"integration: Integration tests",
"acceptance: Acceptance tests",
"experimental: Experimental tests"
]
log_cli = true
log_cli_level = "INFO"
log_cli_format = "%(asctime)s [%(levelname)s] %(message)s"
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
asyncio_default_fixture_loop_scope = "function"

[tool.poetry.plugins."swarmauri.llms"]
LeptonAIImgGenModel = "swarmauri_llm_communityleptonai.LeptonAIImgGenModel:LeptonAIImgGenModel"
LeptonAIModel = "swarmauri_llm_communityleptonai.LeptonAIImgGenModel:LeptonAIModel"
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import os
import asyncio
import requests
from io import BytesIO
from PIL import Image
from typing import List, Literal
from pydantic import Field, ConfigDict
from swarmauri_base.llms.LLMBase import LLMBase


class LeptonAIImgGenModel(LLMBase):
"""
A model for generating images from text using Lepton AI's SDXL image generation model.
It returns the image as bytes.
Get your API KEY from Lepton AI.
"""

api_key: str = Field(default_factory=lambda: os.environ.get("LEPTON_API_KEY"))
model_name: str = Field(default="sdxl")
type: Literal["LeptonAIImgGenModel"] = "LeptonAIImgGenModel"
base_url: str = Field(default="https://sdxl.lepton.run")

model_config = ConfigDict(protected_namespaces=())

def __init__(self, **data):
super().__init__(**data)
if self.api_key:
os.environ["LEPTON_API_KEY"] = self.api_key

def _send_request(self, prompt: str, **kwargs) -> bytes:
"""Send a request to Lepton AI's API for image generation."""
client = requests.Session()
client.headers.update({"Authorization": f"Bearer {self.api_key}"})

payload = {
"prompt": prompt,
"height": kwargs.get("height", 1024),
"width": kwargs.get("width", 1024),
"guidance_scale": kwargs.get("guidance_scale", 5),
"high_noise_frac": kwargs.get("high_noise_frac", 0.75),
"seed": kwargs.get("seed", None),
"steps": kwargs.get("steps", 30),
"use_refiner": kwargs.get("use_refiner", False),
}

response = client.post(f"{self.base_url}/run", json=payload)
response.raise_for_status()
return response.content

def generate_image(self, prompt: str, **kwargs) -> bytes:
"""Generates an image based on the prompt and returns the image as bytes."""
return self._send_request(prompt, **kwargs)

async def agenerate_image(self, prompt: str, **kwargs) -> bytes:
"""Asynchronously generates an image based on the prompt and returns the image as bytes."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self.generate_image, prompt, **kwargs)

def batch(self, prompts: List[str], **kwargs) -> List[bytes]:
"""
Generates images for a batch of prompts.
Returns a list of image bytes.
"""
image_bytes_list = []
for prompt in prompts:
image_bytes = self.generate_image(prompt=prompt, **kwargs)
image_bytes_list.append(image_bytes)
return image_bytes_list

async def abatch(
self, prompts: List[str], max_concurrent: int = 5, **kwargs
) -> List[bytes]:
"""
Asynchronously generates images for a batch of prompts.
Returns a list of image bytes.
"""
semaphore = asyncio.Semaphore(max_concurrent)

async def process_prompt(prompt):
async with semaphore:
return await self.agenerate_image(prompt=prompt, **kwargs)

tasks = [process_prompt(prompt) for prompt in prompts]
return await asyncio.gather(*tasks)

@staticmethod
def save_image(image_bytes: bytes, filename: str):
"""Utility method to save the image bytes to a file."""
with open(filename, "wb") as f:
f.write(image_bytes)
print(f"Image saved as {filename}")

@staticmethod
def display_image(image_bytes: bytes):
"""Utility method to display the image using PIL."""
image = Image.open(BytesIO(image_bytes))
image.show()
Loading
Loading