diff --git a/nerfstudio/scripts/closed_loop/main.py b/nerfstudio/scripts/closed_loop/main.py
index b4f3c590..00006266 100644
--- a/nerfstudio/scripts/closed_loop/main.py
+++ b/nerfstudio/scripts/closed_loop/main.py
@@ -13,29 +13,31 @@
 # limitations under the License.
 from __future__ import annotations

+import base64
 import io
+from typing import Literal, Union

 import numpy as np
 import torch
 import tyro
 import uvicorn
-from fastapi import FastAPI, Response
+from fastapi import FastAPI, HTTPException, Response
 from PIL import Image
 from torch import Tensor

-from nerfstudio.scripts.closed_loop.models import ActorTrajectory, RenderInput
+from nerfstudio.scripts.closed_loop.models import ActorTrajectory, ImageFormat, RenderInput
 from nerfstudio.scripts.closed_loop.server import ClosedLoopServer

 app = FastAPI()


 @app.get("/alive")
-async def alive() -> bool:
+def alive() -> bool:
     return True


 @app.get("/get_actors")
-async def get_actors() -> list[ActorTrajectory]:
+def get_actors() -> list[ActorTrajectory]:
     """Get actor trajectories."""
     actor_trajectories = cl_server.get_actor_trajectories()
     actor_trajectories = [ActorTrajectory.from_torch(act_traj) for act_traj in actor_trajectories]
@@ -43,31 +45,54 @@ async def get_actors() -> list[ActorTrajectory]:


 @app.post("/update_actors")
-async def update_actors(actor_trajectories: list[ActorTrajectory]) -> None:
+def update_actors(actor_trajectories: list[ActorTrajectory]) -> None:
     """Update actor trajectories (keys correspond to actor uuids)."""
     torch_actor_trajectories = [act_traj.to_torch() for act_traj in actor_trajectories]
     cl_server.update_actor_trajectories(torch_actor_trajectories)


-@app.post("/render_image", response_class=Response, responses={200: {"content": {"image/png": {}}}})
-async def render_image(data: RenderInput) -> Response:
+@app.post(
+    "/render_image",
+    response_class=Response,
+    responses={200: {"content": {"text/plain": {}, "image/png": {}, "image/jpeg": {}}}},
+)
+def get_image(data: RenderInput) -> Response:
     torch_pose = torch.tensor(data.pose, dtype=torch.float32)
     render = cl_server.get_image(torch_pose, data.timestamp, data.camera_name)
-    return Response(content=_torch_to_png(render), media_type="image/png")
+    if data.image_format == ImageFormat.raw:
+        return Response(content=_torch_to_bytestr(render), media_type="text/plain")
+    elif data.image_format == ImageFormat.png:
+        return Response(content=_torch_to_img(render, "png"), media_type="image/png")
+    elif data.image_format in (ImageFormat.jpg, ImageFormat.jpeg):
+        return Response(content=_torch_to_img(render, "jpeg"), media_type="image/jpeg")
+    else:
+        raise HTTPException(
+            status_code=400, detail=f"Invalid image format: {data.image_format}, must be 'raw', 'png', 'jpg', or 'jpeg'"
+        )


 @app.get("/start_time")
-async def get_start_time() -> int:
+def get_start_time() -> int:
     return int(cl_server.min_time * 1e6)


-def _torch_to_png(render: Tensor) -> bytes:
-    """Convert a torch tensor to a PNG image."""
+def _torch_to_bytestr(render: Tensor) -> bytes:
+    """Convert a torch tensor to a base64-encoded bytestring."""
+    buff = io.BytesIO()
+    img = (render * 255).to(torch.uint8).cpu()
+    torch.save(img, buff)
+    return base64.b64encode(buff.getvalue())
+
+
+def _torch_to_img(render: Tensor, format: Union[Literal["jpeg"], Literal["png"]]) -> bytes:
+    """Convert a torch tensor to a PNG or JPEG image."""
+    if format not in ("jpeg", "png"):
+        raise ValueError(f"Invalid format: {format}")
+
     img = Image.fromarray((render * 255).cpu().numpy().astype(np.uint8))
-    image_stream = io.BytesIO()
-    img.save(image_stream, format="PNG")
-    image_bytes = image_stream.getvalue()
-    return image_bytes
+    buff = io.BytesIO()
+    img.save(buff, format=format.upper())
+    return buff.getvalue()


 if __name__ == "__main__":
diff --git a/nerfstudio/scripts/closed_loop/models.py b/nerfstudio/scripts/closed_loop/models.py
index 50297a0b..16490c10 100644
--- a/nerfstudio/scripts/closed_loop/models.py
+++ b/nerfstudio/scripts/closed_loop/models.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import annotations

+from enum import Enum
 from typing import List, TypedDict

 import torch
@@ -20,6 +21,13 @@
 from torch import Tensor


+class ImageFormat(str, Enum):
+    raw = "raw"  # returns a raw tensor; works well when client and server run on the same machine
+    png = "png"  # more suitable when sent over a network
+    jpg = "jpg"  # more suitable when sent over a network; alias for jpeg
+    jpeg = "jpeg"  # more suitable when sent over a network
+
+
 class TrajectoryDict(TypedDict):
     uuid: str
     poses: Tensor
@@ -66,3 +74,5 @@ class RenderInput(BaseModel):
     """Timestamp in microseconds"""
     camera_name: str
     """Camera name"""
+    image_format: ImageFormat = ImageFormat.raw
+    """What format to return the image in. Defaults to raw tensor."""
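
For context on how a caller would consume the new `image_format` options, here is a minimal client sketch. It is not part of the diff: the server URL, pose, timestamp, camera name, and the use of `requests` are placeholder assumptions; only the decoding of the "raw" (base64 + `torch.save`) and "png"/"jpeg" response bodies mirrors the endpoint above.

```python
"""Minimal client sketch for the updated /render_image endpoint (illustrative only)."""
import base64
import io

import requests  # assumed HTTP client; not part of the server code above
import torch
from PIL import Image

SERVER_URL = "http://localhost:8000"  # placeholder host/port

payload = {
    "pose": [[1.0, 0.0, 0.0, 0.0],
             [0.0, 1.0, 0.0, 0.0],
             [0.0, 0.0, 1.0, 0.0],
             [0.0, 0.0, 0.0, 1.0]],  # placeholder pose; shape must match what the server expects
    "timestamp": 0,                  # placeholder timestamp in microseconds
    "camera_name": "front",          # placeholder camera name
}

# image_format="raw": the body is a base64-encoded, torch.save-serialized uint8 tensor.
resp = requests.post(f"{SERVER_URL}/render_image", json={**payload, "image_format": "raw"})
render = torch.load(io.BytesIO(base64.b64decode(resp.content)))

# image_format="png" (or "jpg"/"jpeg"): the body is the encoded image itself.
resp = requests.post(f"{SERVER_URL}/render_image", json={**payload, "image_format": "png"})
image = Image.open(io.BytesIO(resp.content))
```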