Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update closed loop server with new interface #9

Merged
merged 3 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 40 additions & 15 deletions nerfstudio/scripts/closed_loop/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,61 +13,86 @@
# limitations under the License.
from __future__ import annotations

import base64
import io
from typing import Literal, Union

import numpy as np
import torch
import tyro
import uvicorn
from fastapi import FastAPI, Response
from fastapi import FastAPI, HTTPException, Response
from PIL import Image
from torch import Tensor

from nerfstudio.scripts.closed_loop.models import ActorTrajectory, RenderInput
from nerfstudio.scripts.closed_loop.models import ActorTrajectory, ImageFormat, RenderInput
from nerfstudio.scripts.closed_loop.server import ClosedLoopServer

app = FastAPI()


@app.get("/alive")
async def alive() -> bool:
def alive() -> bool:
return True


@app.get("/get_actors")
async def get_actors() -> list[ActorTrajectory]:
def get_actors() -> list[ActorTrajectory]:
"""Get actor trajectories."""
actor_trajectories = cl_server.get_actor_trajectories()
actor_trajectories = [ActorTrajectory.from_torch(act_traj) for act_traj in actor_trajectories]
return actor_trajectories


@app.post("/update_actors")
async def update_actors(actor_trajectories: list[ActorTrajectory]) -> None:
def update_actors(actor_trajectories: list[ActorTrajectory]) -> None:
"""Update actor trajectories (keys correspond to actor uuids)."""
torch_actor_trajectories = [act_traj.to_torch() for act_traj in actor_trajectories]
cl_server.update_actor_trajectories(torch_actor_trajectories)


@app.post("/render_image", response_class=Response, responses={200: {"content": {"image/png": {}}}})
async def render_image(data: RenderInput) -> Response:
@app.post(
"/render_image",
response_class=Response,
responses={200: {"content": {"text/plain": {}, "image/png": {}, "image/jpeg": {}}}},
)
def get_image(data: RenderInput) -> Response:
torch_pose = torch.tensor(data.pose, dtype=torch.float32)
render = cl_server.get_image(torch_pose, data.timestamp, data.camera_name)
return Response(content=_torch_to_png(render), media_type="image/png")
if data.image_format == ImageFormat.raw:
return Response(content=_torch_to_bytestr(render), media_type="text/plain")
elif data.image_format == ImageFormat.png:
return Response(content=_torch_to_img(render, "png"), media_type="image/png")
elif data.image_format in (ImageFormat.jpg, ImageFormat.jpeg):
return Response(content=_torch_to_img(render, "jpeg"), media_type="text/jpeg")
else:
raise HTTPException(
status_code=400, detail=f"Invalid image format: {data.image_format}, must be 'raw', 'png', 'jpg', or 'jpeg'"
)


@app.get("/start_time")
async def get_start_time() -> int:
def get_start_time() -> int:
return int(cl_server.min_time * 1e6)


def _torch_to_png(render: Tensor) -> bytes:
"""Convert a torch tensor to a PNG image."""
def _torch_to_bytestr(render: Tensor) -> bytes:
"""Convert a torch tensor to a base64 encoded bytestring."""
buff = io.BytesIO()
img = (render * 255).to(torch.uint8).cpu()
torch.save(img, buff)
return base64.b64encode(buff.getvalue())


def _torch_to_img(render: Tensor, format: Union[Literal["jpeg"], Literal["png"]]) -> bytes:
"""Convert a torch tensor to a PNG or JPG image."""
if format not in ("jpeg", "png"):
raise ValueError(f"Invalid format: {format}")

img = Image.fromarray((render * 255).cpu().numpy().astype(np.uint8))
image_stream = io.BytesIO()
img.save(image_stream, format="PNG")
image_bytes = image_stream.getvalue()
return image_bytes
buff = io.BytesIO()
img.save(buff, format=format.upper())
return buff.getvalue()


if __name__ == "__main__":
Expand Down
10 changes: 10 additions & 0 deletions nerfstudio/scripts/closed_loop/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,21 @@
# limitations under the License.
from __future__ import annotations

from enum import Enum
from typing import List, TypedDict

import torch
from pydantic import BaseModel
from torch import Tensor


class ImageFormat(str, Enum):
raw = "raw" # will return a raw tensor, works good when sending across same machine
png = "png" # more suitable if sent over network
jpg = "jpg" # more suitable if sent over network, pseudo for jpeg
jpeg = "jpeg" # more suitable if sent over network


class TrajectoryDict(TypedDict):
uuid: str
poses: Tensor
Expand Down Expand Up @@ -66,3 +74,5 @@ class RenderInput(BaseModel):
"""Timestamp in microseconds"""
camera_name: str
"""Camera name"""
image_format: ImageFormat = ImageFormat.raw
"""What format to return the image in. Defaults to raw tensor."""
Loading