Skip to content

Commit

Permalink
Merge pull request #1084 from MichaelDecent/comm_pkg20
Browse files Browse the repository at this point in the history
Add Swarmauri LeptonAI Model Community Packages
  • Loading branch information
cobycloud authored Jan 15, 2025
2 parents eb80c2b + f6a4588 commit 3593a9e
Show file tree
Hide file tree
Showing 8 changed files with 747 additions and 1 deletion.
1 change: 1 addition & 0 deletions pkgs/community/swarmauri_llm_communityleptonai/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Swarmauri Example Community Package
57 changes: 57 additions & 0 deletions pkgs/community/swarmauri_llm_communityleptonai/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
[tool.poetry]
name = "swarmauri_llm_communityleptonai"
version = "0.6.0.dev1"
description = "Swarmauri Lepton AI Model"
authors = ["Jacob Stewart <[email protected]>"]
license = "Apache-2.0"
readme = "README.md"
repository = "http://github.com/swarmauri/swarmauri-sdk"
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12"
]

[tool.poetry.dependencies]
python = ">=3.10,<3.13"

# Swarmauri
swarmauri_core = { path = "../../core" }
swarmauri_base = { path = "../../base" }

# Dependencies
leptonai = "^0.22.0"

[tool.poetry.group.dev.dependencies]
flake8 = "^7.0"
pytest = "^8.0"
pytest-asyncio = ">=0.24.0"
pytest-xdist = "^3.6.1"
pytest-json-report = "^1.5.0"
python-dotenv = "*"
requests = "^2.32.3"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
norecursedirs = ["combined", "scripts"]

markers = [
"test: standard test",
"unit: Unit tests",
"integration: Integration tests",
"acceptance: Acceptance tests",
"experimental: Experimental tests"
]
log_cli = true
log_cli_level = "INFO"
log_cli_format = "%(asctime)s [%(levelname)s] %(message)s"
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
asyncio_default_fixture_loop_scope = "function"

[tool.poetry.plugins."swarmauri.llms"]
LeptonAIImgGenModel = "swarmauri_llm_communityleptonai.LeptonAIImgGenModel:LeptonAIImgGenModel"
LeptonAIModel = "swarmauri_llm_communityleptonai.LeptonAIImgGenModel:LeptonAIModel"
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import os
import asyncio
import requests
from io import BytesIO
from PIL import Image
from typing import List, Literal
from pydantic import Field, ConfigDict
from swarmauri_base.llms.LLMBase import LLMBase


class LeptonAIImgGenModel(LLMBase):
"""
A model for generating images from text using Lepton AI's SDXL image generation model.
It returns the image as bytes.
Get your API KEY from Lepton AI.
"""

api_key: str = Field(default_factory=lambda: os.environ.get("LEPTON_API_KEY"))
model_name: str = Field(default="sdxl")
type: Literal["LeptonAIImgGenModel"] = "LeptonAIImgGenModel"
base_url: str = Field(default="https://sdxl.lepton.run")

model_config = ConfigDict(protected_namespaces=())

def __init__(self, **data):
super().__init__(**data)
if self.api_key:
os.environ["LEPTON_API_KEY"] = self.api_key

def _send_request(self, prompt: str, **kwargs) -> bytes:
"""Send a request to Lepton AI's API for image generation."""
client = requests.Session()
client.headers.update({"Authorization": f"Bearer {self.api_key}"})

payload = {
"prompt": prompt,
"height": kwargs.get("height", 1024),
"width": kwargs.get("width", 1024),
"guidance_scale": kwargs.get("guidance_scale", 5),
"high_noise_frac": kwargs.get("high_noise_frac", 0.75),
"seed": kwargs.get("seed", None),
"steps": kwargs.get("steps", 30),
"use_refiner": kwargs.get("use_refiner", False),
}

response = client.post(f"{self.base_url}/run", json=payload)
response.raise_for_status()
return response.content

def generate_image(self, prompt: str, **kwargs) -> bytes:
"""Generates an image based on the prompt and returns the image as bytes."""
return self._send_request(prompt, **kwargs)

async def agenerate_image(self, prompt: str, **kwargs) -> bytes:
"""Asynchronously generates an image based on the prompt and returns the image as bytes."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self.generate_image, prompt, **kwargs)

def batch(self, prompts: List[str], **kwargs) -> List[bytes]:
"""
Generates images for a batch of prompts.
Returns a list of image bytes.
"""
image_bytes_list = []
for prompt in prompts:
image_bytes = self.generate_image(prompt=prompt, **kwargs)
image_bytes_list.append(image_bytes)
return image_bytes_list

async def abatch(
self, prompts: List[str], max_concurrent: int = 5, **kwargs
) -> List[bytes]:
"""
Asynchronously generates images for a batch of prompts.
Returns a list of image bytes.
"""
semaphore = asyncio.Semaphore(max_concurrent)

async def process_prompt(prompt):
async with semaphore:
return await self.agenerate_image(prompt=prompt, **kwargs)

tasks = [process_prompt(prompt) for prompt in prompts]
return await asyncio.gather(*tasks)

@staticmethod
def save_image(image_bytes: bytes, filename: str):
"""Utility method to save the image bytes to a file."""
with open(filename, "wb") as f:
f.write(image_bytes)
print(f"Image saved as {filename}")

@staticmethod
def display_image(image_bytes: bytes):
"""Utility method to display the image using PIL."""
image = Image.open(BytesIO(image_bytes))
image.show()
Loading

0 comments on commit 3593a9e

Please sign in to comment.