Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Live portrait runner #199

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
5f99980
update to live-portrait workflow.
JJassonn69 Sep 12, 2024
89e140f
update to make base container lean
JJassonn69 Sep 12, 2024
e4f0f9e
reverting the debug patch
JJassonn69 Sep 12, 2024
1ebd2f5
go api bindings
JJassonn69 Sep 13, 2024
d886ad3
Merge branch 'main' into live-portrait-runner
JJassonn69 Sep 20, 2024
edefccf
Update docker.go
JJassonn69 Sep 20, 2024
d365a85
Update requirements.txt
JJassonn69 Sep 20, 2024
665480a
Update requirements.txt to install custom wheel for running liveportrait
JJassonn69 Sep 20, 2024
9ce666f
Update requirements.txt
JJassonn69 Sep 20, 2024
d9dd2e6
upgrades to requirements
JJassonn69 Sep 20, 2024
514fbc3
changes to rebase
JJassonn69 Sep 20, 2024
576a741
something was out of place
JJassonn69 Sep 20, 2024
31fa793
update for some issues fixed in test production
JJassonn69 Sep 25, 2024
9b50bc3
update from testing final
JJassonn69 Sep 27, 2024
52b8a8f
update to isolate the build context from base image.
JJassonn69 Oct 11, 2024
298ee70
final push changing the openapi specs and requirements.
JJassonn69 Oct 11, 2024
baa8b9d
update to docker file to include the mapping for sperate pipeline image
JJassonn69 Oct 11, 2024
6496990
Merge branch 'main' into live-portrait-runner
JJassonn69 Oct 11, 2024
1a3f21f
update to optimize the frames animations
JJassonn69 Oct 12, 2024
ace91c6
openapi bindings and go api generation
JJassonn69 Oct 12, 2024
caae90b
Update Dockerfile
JJassonn69 Oct 23, 2024
6479eca
Merge branch 'main' into live-portrait-runner
JJassonn69 Oct 23, 2024
e9e29a9
Update multipart.go
JJassonn69 Oct 23, 2024
9117c02
Update live_portrait.py
JJassonn69 Oct 23, 2024
315032e
Update live_portrait.py
JJassonn69 Oct 23, 2024
7e53d2b
Update live_portrait.py
JJassonn69 Oct 23, 2024
061a387
Merge branch 'main' into live-portrait-runner
JJassonn69 Oct 26, 2024
b8be128
Merge branch 'main' into pr/199
JJassonn69 Nov 4, 2024
67cd90c
Merge branch 'main' into pr/199
JJassonn69 Nov 13, 2024
b77cd3d
Merge branch 'main' into pr/199
JJassonn69 Nov 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions runner/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
from app.pipelines.upscale import UpscalePipeline

return UpscalePipeline(model_id)
case "live-portrait":
from app.pipelines.live_portrait import Inference

return Inference()
case "segment-anything-2":
from app.pipelines.segment_anything_2 import SegmentAnything2Pipeline

Expand Down Expand Up @@ -111,6 +115,10 @@ def load_route(pipeline: str) -> any:
from app.routes import upscale

return upscale.router
case "live-portrait":
from app.routes import live_portrait

return live_portrait.router
case "segment-anything-2":
from app.routes import segment_anything_2

Expand Down
36 changes: 36 additions & 0 deletions runner/app/pipelines/live_portrait.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# coding: utf-8
from liveportrait.core.config.argument_config import ArgumentConfig
from liveportrait.core.config.inference_config import InferenceConfig
from liveportrait.core.config.crop_config import CropConfig
from liveportrait.core.live_portrait_pipeline import LivePortraitPipeline
from app.pipelines.base import Pipeline
import logging

logger = logging.getLogger(__name__)

class Inference(Pipeline):
    """Callable pipeline that runs LivePortrait on a source image and a driving input."""

    def __init__(self):
        # A single ArgumentConfig, created with no arguments, is kept on the
        # instance and reused (and mutated) by every call.
        self.args = ArgumentConfig()

    @staticmethod
    def _partial_fields(target_class, kwargs):
        """Instantiate target_class from the kwargs entries it has an attribute for."""
        accepted = {}
        for key, value in kwargs.items():
            if hasattr(target_class, key):
                accepted[key] = value
        return target_class(**accepted)

    def __call__(self, source_image=None, driving_info=None) -> any:
        """Run the live portrait inference pipeline.

        Args:
            source_image: Stored on the argument config as ``source_image``.
            driving_info: Stored on the argument config as ``driving_info``.

        Returns:
            Whatever ``LivePortraitPipeline.execute`` returns for the updated
            argument config.
        """
        self.args.source_image = source_image
        self.args.driving_info = driving_info

        # Both configs are filled from the same argument namespace; each one
        # only receives the entries it declares an attribute for.
        arg_values = self.args.__dict__
        inference_cfg = self._partial_fields(InferenceConfig, arg_values)
        crop_cfg = self._partial_fields(CropConfig, arg_values)

        # NOTE(review): a new LivePortraitPipeline is constructed on every call.
        # If construction is expensive (not visible from here), consider
        # building it once -- TODO confirm the configs do not depend on the
        # per-call arguments before hoisting.
        runner = LivePortraitPipeline(
            inference_cfg=inference_cfg,
            crop_cfg=crop_cfg,
        )
        return runner.execute(self.args)
140 changes: 140 additions & 0 deletions runner/app/routes/live_portrait.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import logging
import os
import multiprocessing
import cv2
from typing import Annotated, Dict, Tuple, Union

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.utils import HTTPError, VideoResponse, http_error, image_to_data_url, handle_pipeline_exception
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import Image, ImageFile

# Ask Pillow to load image files whose data is truncated instead of failing.
ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)

router = APIRouter()

# Pipeline specific error handling configuration: each key names a specific
# error type and maps to a (message, HTTP status code) pair. The route passes
# this mapping to handle_pipeline_exception as custom_error_config.
PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = dict(
    faceNotDetected=(
        "No face detected in either driving video or source image.",
        status.HTTP_500_INTERNAL_SERVER_ERROR,
    ),
)


# Response declarations shared by both registrations of the route: a JSON
# success entry carrying a schema name override, plus HTTPError-modelled
# entries for the listed error status codes.
RESPONSES = {
    status.HTTP_200_OK: {
        "content": {
            "application/json": {
                "schema": {"x-speakeasy-name-override": "data"},
            },
        },
    },
    **{
        error_status: {"model": HTTPError}
        for error_status in (
            status.HTTP_400_BAD_REQUEST,
            status.HTTP_401_UNAUTHORIZED,
            status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
    },
}

def process_frame(frame):
    """Turn one OpenCV video frame into a frame entry for the response.

    Args:
        frame: Image array as read by OpenCV, with channels in BGR order.

    Returns:
        A dict with the frame encoded as a data URL under ``"url"`` and the
        constant placeholders ``"seed": 0`` and ``"nsfw": False``.
    """
    # Reversing the channel axis turns OpenCV's BGR ordering into RGB.
    image = Image.fromarray(frame[:, :, ::-1], 'RGB')
    data_url = image_to_data_url(image)
    # LivePortrait uses no seeds and performs no NSFW checks, so both fields
    # are fixed values.
    return dict(url=data_url, seed=0, nsfw=False)

def _collect_video_frames(video_path):
    """Read every frame of the video at video_path and convert it with process_frame.

    The capture handle is released even when reading or conversion fails.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
            futures = []
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                futures.append(pool.apply_async(process_frame, (frame,)))

            # Collect inside the with-block: leaving it terminates the pool,
            # so results must be fetched before that happens.
            return [future.get() for future in futures]
    finally:
        cap.release()


@router.post(
    "/live-portrait",
    response_model=VideoResponse,
    responses=RESPONSES,
    description="Generate a video from a provided source image and driving video.",
)
@router.post(
    "/live-portrait/",
    response_model=VideoResponse,
    responses=RESPONSES,
    include_in_schema=False,
)
async def live_portrait(
    source_image: Annotated[
        UploadFile,
        File(description="Uploaded source image to animate."),
    ],
    driving_video: Annotated[
        UploadFile,
        File(description="Uploaded driving video to guide the animation."),
    ],
    model_id: Annotated[
        str, Form(description="No model id needed as leave empty.")
    ] = "",
    pipeline: Pipeline = Depends(get_pipeline),
    token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
    """Animate the uploaded source image using the uploaded driving video.

    Args:
        source_image: Uploaded image; written to a temporary file for the pipeline.
        driving_video: Uploaded video; written to a temporary file for the pipeline.
        model_id: Accepted as a form field but not used by this handler.
        pipeline: Pipeline instance, called with the two temporary file paths.
        token: Optional bearer credentials, checked against the ``AUTH_TOKEN``
            environment variable when that variable is set.

    Returns:
        ``{"frames": [frames]}`` on success, where ``frames`` holds one entry
        per frame of the video the pipeline produced; a 401 ``JSONResponse``
        when the bearer token is rejected; otherwise the response built by
        ``handle_pipeline_exception``.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    auth_token = os.environ.get("AUTH_TOKEN")
    if auth_token:
        if not token or token.credentials != auth_token:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                headers={"WWW-Authenticate": "Bearer"},
                content=http_error("Invalid bearer token"),
            )

    # Bound before the try block so the cleanup below can never hit an unbound
    # name (and thereby replace the error response) when an earlier step raised.
    result_video_path = None
    try:
        # A per-request directory keeps overlapping requests from overwriting
        # each other's uploads. The file basenames are unchanged so anything
        # the pipeline derives from them is unaffected. The directory and its
        # contents are removed when the with-block exits.
        with tempfile.TemporaryDirectory() as workdir:
            temp_video_path = os.path.join(workdir, "temp_driving_video.mp4")
            with open(temp_video_path, "wb") as buffer:
                buffer.write(await driving_video.read())

            temp_image_path = os.path.join(workdir, "temp_source_image.jpg")
            with open(temp_image_path, "wb") as buffer:
                buffer.write(await source_image.read())

            result_video_path = pipeline(
                source_image=temp_image_path,
                driving_info=temp_video_path
            )

            output_frames = _collect_video_frames(result_video_path)

    except Exception as e:
        logger.error(f"LivePortraitPipeline error: {e}")
        return handle_pipeline_exception(
            e,
            default_error_message="Live-portrait pipeline error.",
            custom_error_config=PIPELINE_ERROR_CONFIG,
        )
    finally:
        # Remove the generated video, if the pipeline got far enough to make one.
        if result_video_path is not None and os.path.exists(result_video_path):
            os.remove(result_video_path)

    # Return frames wrapped in an outer list, adhering to the required schema
    return {"frames": [output_frames]}
5 changes: 5 additions & 0 deletions runner/docker/Dockerfile.live-portrait
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Pipeline-specific image layered on top of the shared runner base image.
FROM livepeer/ai-runner:base

# Install the liveportrait 0.2.0 wheel from a GitHub release asset.
# NOTE(review): the wheel is fetched by URL without a checksum -- confirm this is acceptable.
RUN pip install --no-cache-dir https://github.com/JJassonn69/liveportrait/releases/download/liveportrait-livepeer/liveportrait-0.2.0-py3-none-any.whl

# Serve the FastAPI app with uvicorn on all interfaces, port 8000, using the bundled logging config.
CMD ["uvicorn", "app.main:app", "--log-config", "app/cfg/uvicorn_logging_config.json", "--host", "0.0.0.0", "--port", "8000"]
70 changes: 68 additions & 2 deletions runner/gateway.openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,50 @@ paths:
security:
- HTTPBearer: []
x-speakeasy-name-override: audioToText
/live-portrait:
post:
summary: Live Portrait
description: Generate a video from a provided source image and driving video.
operationId: live_portrait_live_portrait_post
requestBody:
content:
multipart/form-data:
schema:
$ref: '#/components/schemas/Body_live_portrait_live_portrait_post'
required: true
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/VideoResponse'
'400':
description: Bad Request
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPError'
'401':
description: Unauthorized
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPError'
'500':
description: Internal Server Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPError'
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
security:
- HTTPBearer: []
/segment-anything-2:
post:
tags:
Expand Down Expand Up @@ -841,6 +885,29 @@ components:
- image
- model_id
title: Body_genUpscale
Body_live_portrait_live_portrait_post:
properties:
source_image:
type: string
format: binary
title: Source Image
description: Uploaded source image to animate.
driving_video:
type: string
format: binary
title: Driving Video
description: Uploaded driving video to guide the animation.
model_id:
type: string
title: Model Id
description: No model id needed as leave empty.
default: ''
type: object
required:
- source_image
- driving_video
- model_id
title: Body_live_portrait_live_portrait_post
Chunk:
properties:
timestamp:
Expand All @@ -861,8 +928,7 @@ components:
HTTPError:
properties:
detail:
allOf:
- $ref: '#/components/schemas/APIError'
$ref: '#/components/schemas/APIError'
description: Detailed error information.
type: object
required:
Expand Down
4 changes: 4 additions & 0 deletions runner/gen_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
text_to_image,
text_to_speech,
upscale,
live_portrait,
llm,
image_to_text,
)

logging.basicConfig(
Expand Down Expand Up @@ -109,6 +112,7 @@ def write_openapi(fname: str, entrypoint: str = "runner"):
app.include_router(image_to_video.router)
app.include_router(upscale.router)
app.include_router(audio_to_text.router)
app.include_router(live_portrait.router)
app.include_router(segment_anything_2.router)
app.include_router(llm.router)
app.include_router(image_to_text.router)
Expand Down
Loading
Loading