Skip to content

Commit

Permalink
Reduce startup overhead (#2763)
Browse files Browse the repository at this point in the history
* Reduce import overhead

* Bump tyro, suppress deprecated config values

* Bump viser

* Fix pyright
  • Loading branch information
brentyi authored Jan 17, 2024
1 parent e00da7d commit 05d3054
Show file tree
Hide file tree
Showing 20 changed files with 120 additions and 93 deletions.
6 changes: 4 additions & 2 deletions nerfstudio/cameras/camera_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Literal, Optional, Type, Union

import torch
import tyro
from jaxtyping import Float, Int
from torch import Tensor, nn
from typing_extensions import assert_never
Expand Down Expand Up @@ -51,10 +52,11 @@ class CameraOptimizerConfig(InstantiateConfig):
rot_l2_penalty: float = 1e-3
"""L2 penalty on rotation parameters."""

optimizer: Optional[OptimizerConfig] = field(default=None)
# tyro.conf.Suppress prevents us from creating CLI arguments for these fields.
optimizer: tyro.conf.Suppress[Optional[OptimizerConfig]] = field(default=None)
"""Deprecated, now specified inside the optimizers dict"""

scheduler: Optional[SchedulerConfig] = field(default=None)
scheduler: tyro.conf.Suppress[Optional[SchedulerConfig]] = field(default=None)
"""Deprecated, now specified inside the optimizers dict"""

def __post_init__(self):
Expand Down
7 changes: 5 additions & 2 deletions nerfstudio/cameras/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import cv2
import torch
import torchvision
from jaxtyping import Float, Int, Shaped
from torch import Tensor
from torch.nn import Parameter
Expand Down Expand Up @@ -959,7 +958,11 @@ def to_json(
image_uint8 = (image * 255).detach().type(torch.uint8)
if max_size is not None:
image_uint8 = image_uint8.permute(2, 0, 1)
image_uint8 = torchvision.transforms.functional.resize(image_uint8, max_size, antialias=None) # type: ignore

# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

image_uint8 = TF.resize(image_uint8, max_size, antialias=None) # type: ignore
image_uint8 = image_uint8.permute(1, 2, 0)
image_uint8 = image_uint8.cpu().numpy()
data = cv2.imencode(".jpg", image_uint8)[1].tobytes() # type: ignore
Expand Down
5 changes: 4 additions & 1 deletion nerfstudio/data/datamanagers/base_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)

import torch
import tyro
from torch import nn
from torch.nn import Parameter
from torch.utils.data.distributed import DistributedSampler
Expand Down Expand Up @@ -334,7 +335,9 @@ class VanillaDataManagerConfig(DataManagerConfig):
"""
patch_size: int = 1
"""Size of patch to sample from. If > 1, patch-based sampling will be used."""
camera_optimizer: Optional[CameraOptimizerConfig] = field(default=None)

# tyro.conf.Suppress prevents us from creating CLI arguments for this field.
camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None)
"""Deprecated, has been moved to the model config."""
pixel_sampler: PixelSamplerConfig = field(default_factory=PixelSamplerConfig)
"""Specifies the pixel sampler used to sample pixels from images."""
Expand Down
3 changes: 2 additions & 1 deletion nerfstudio/data/dataparsers/nerfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import Literal, Optional, Type

import numpy as np
import open3d as o3d
import torch
from PIL import Image

Expand Down Expand Up @@ -337,6 +336,8 @@ def _generate_dataparser_outputs(self, split="train"):
return dataparser_outputs

def _load_3D_points(self, ply_file_path: Path, transform_matrix: torch.Tensor, scale_factor: float):
import open3d as o3d # Importing open3d is slow, so we only do it if we need it.

pcd = o3d.io.read_point_cloud(str(ply_file_path))

points3D = torch.from_numpy(np.asarray(pcd.points, dtype=np.float32))
Expand Down
4 changes: 3 additions & 1 deletion nerfstudio/data/dataparsers/nuscenes_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import numpy as np
import pyquaternion
import torch
from nuscenes.nuscenes import NuScenes as NuScenesDatabase

from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.data.dataparsers.base_dataparser import DataParser, DataParserConfig, DataparserOutputs
Expand Down Expand Up @@ -81,6 +80,9 @@ class NuScenes(DataParser):
config: NuScenesDataParserConfig

def _generate_dataparser_outputs(self, split="train"):
# nuscenes is slow to import, so we only do it if we need it.
from nuscenes.nuscenes import NuScenes as NuScenesDatabase

nusc = NuScenesDatabase(
version=self.config.version,
dataroot=str(self.config.data_dir.absolute()),
Expand Down
10 changes: 8 additions & 2 deletions nerfstudio/exporter/exporter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@

import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import numpy as np
import open3d as o3d
import pymeshlab
import torch
from jaxtyping import Float
Expand All @@ -38,6 +37,11 @@
from nerfstudio.pipelines.base_pipeline import Pipeline, VanillaPipeline
from nerfstudio.utils.rich_utils import CONSOLE, ItersPerSecColumn

if TYPE_CHECKING:
# Importing open3d can take ~1 second, so only do it below if we actually
# need it.
import open3d as o3d


@dataclass
class Mesh:
Expand Down Expand Up @@ -193,6 +197,8 @@ def generate_point_cloud(
rgbs = torch.cat(rgbs, dim=0)
view_directions = torch.cat(view_directions, dim=0).cpu()

import open3d as o3d

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points.double().cpu().numpy())
pcd.colors = o3d.utility.Vector3dVector(rgbs.double().cpu().numpy())
Expand Down
28 changes: 19 additions & 9 deletions nerfstudio/generative/deepfloyd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import gc
import sys
from pathlib import Path
from typing import List, Optional, Union

Expand All @@ -24,15 +25,7 @@
from torch import Generator, Tensor, nn
from torch.cuda.amp.grad_scaler import GradScaler

from nerfstudio.generative.utils import CatchMissingPackages

try:
from diffusers import DiffusionPipeline, IFPipeline, IFPipeline as IFOrig
from diffusers.pipelines.deepfloyd_if import IFPipelineOutput, IFPipelineOutput as IFOutputOrig
from transformers import T5EncoderModel

except ImportError:
IFPipeline = IFPipelineOutput = T5EncoderModel = CatchMissingPackages()
from nerfstudio.utils.rich_utils import CONSOLE

IMG_DIM = 64

Expand All @@ -47,6 +40,16 @@ def __init__(self, device: Union[torch.device, str]):
super().__init__()
self.device = device

try:
from diffusers import DiffusionPipeline, IFPipeline
from transformers import T5EncoderModel

except ImportError:
CONSOLE.print("[bold red]Missing Stable Diffusion packages.")
CONSOLE.print(r"Install using [yellow]pip install nerfstudio\[gen][/yellow]")
CONSOLE.print(r"or [yellow]pip install -e .\[gen][/yellow] if installing from source.")
sys.exit(1)

self.text_encoder = T5EncoderModel.from_pretrained(
"DeepFloyd/IF-I-L-v1.0",
subfolder="text_encoder",
Expand Down Expand Up @@ -90,6 +93,8 @@ def delete_text_encoder(self):
gc.collect()
torch.cuda.empty_cache()

from diffusers import DiffusionPipeline, IFPipeline

self.pipe = IFPipeline.from_pretrained(
"DeepFloyd/IF-I-L-v1.0",
text_encoder=None,
Expand Down Expand Up @@ -126,6 +131,8 @@ def get_text_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

from diffusers import DiffusionPipeline

assert isinstance(self.pipe, DiffusionPipeline)
with torch.no_grad():
prompt_embeds, negative_embeds = self.pipe.encode_prompt(prompt, negative_prompt=negative_prompt)
Expand Down Expand Up @@ -200,6 +207,9 @@ def prompt_to_image(
The generated image.
"""

from diffusers import DiffusionPipeline, IFPipeline as IFOrig
from diffusers.pipelines.deepfloyd_if import IFPipelineOutput as IFOutputOrig

prompts = [prompts] if isinstance(prompts, str) else prompts
negative_prompts = [negative_prompts] if isinstance(negative_prompts, str) else negative_prompts
assert isinstance(self.pipe, DiffusionPipeline)
Expand Down
22 changes: 13 additions & 9 deletions nerfstudio/generative/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

# Modified from https://github.com/ashawkey/stable-dreamfusion/blob/main/nerf/sd.py

import sys
from pathlib import Path
from typing import List, Optional, Union

import mediapy
import numpy as np
import torch
import torch.nn.functional as F
Expand All @@ -28,16 +28,8 @@
from torch import Tensor, nn
from torch.cuda.amp.grad_scaler import GradScaler

from nerfstudio.generative.utils import CatchMissingPackages
from nerfstudio.utils.rich_utils import CONSOLE

try:
from diffusers import DiffusionPipeline, PNDMScheduler, StableDiffusionPipeline

except ImportError:
PNDMScheduler = StableDiffusionPipeline = CatchMissingPackages()


IMG_DIM = 512
CONST_SCALE = 0.18215
SD_IDENTIFIERS = {
Expand All @@ -57,6 +49,15 @@ class StableDiffusion(nn.Module):
def __init__(self, device: Union[torch.device, str], num_train_timesteps: int = 1000, version="1-5") -> None:
super().__init__()

try:
from diffusers import DiffusionPipeline, PNDMScheduler, StableDiffusionPipeline

except ImportError:
CONSOLE.print("[bold red]Missing Stable Diffusion packages.")
CONSOLE.print(r"Install using [yellow]pip install nerfstudio\[gen][/yellow]")
CONSOLE.print(r"or [yellow]pip install -e .\[gen][/yellow] if installing from source.")
sys.exit(1)

self.device = device
self.num_train_timesteps = num_train_timesteps

Expand Down Expand Up @@ -319,6 +320,9 @@ def generate_image(
with torch.no_grad():
sd = StableDiffusion(cuda_device)
imgs = sd.prompt_to_img(prompt, negative, steps)

import mediapy # Slow to import, so we do it lazily.

mediapy.write_image(str(save_path), imgs[0])


Expand Down
35 changes: 0 additions & 35 deletions nerfstudio/generative/utils.py

This file was deleted.

7 changes: 4 additions & 3 deletions nerfstudio/models/base_surface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from torchmetrics.functional import structural_similarity_index_measure
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.field_components.encodings import NeRFEncoding
Expand Down Expand Up @@ -156,6 +153,10 @@ def populate_modules(self):
self.depth_loss = ScaleAndShiftInvariantLoss(alpha=0.5, scales=1)

# metrics
from torchmetrics.functional import structural_similarity_index_measure
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

self.psnr = PeakSignalNoiseRatio(data_range=1.0)
self.ssim = structural_similarity_index_measure
self.lpips = LearnedPerceptualImagePatchSimilarity()
Expand Down
22 changes: 17 additions & 5 deletions nerfstudio/models/gaussian_splatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,13 @@

import numpy as np
import torch
import torchvision.transforms.functional as TF
from gsplat._torch_impl import quat_to_rotmat
from gsplat.compute_cumulative_intersects import compute_cumulative_intersects
from gsplat.project_gaussians import ProjectGaussians
from gsplat.rasterize import RasterizeGaussians
from gsplat.sh import SphericalHarmonics, num_sh_bases
from pytorch_msssim import SSIM
from sklearn.neighbors import NearestNeighbors
from torch.nn import Parameter
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
from nerfstudio.cameras.cameras import Cameras
Expand Down Expand Up @@ -205,6 +201,9 @@ def populate_modules(self):
self.opacities = torch.nn.Parameter(torch.logit(0.1 * torch.ones(self.num_points, 1)))

# metrics
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

self.psnr = PeakSignalNoiseRatio(data_range=1.0)
self.ssim = SSIM(data_range=1.0, size_average=True, channel=3)
self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True)
Expand Down Expand Up @@ -248,14 +247,16 @@ def load_state_dict(self, dict, **kwargs): # type: ignore

def k_nearest_sklearn(self, x: torch.Tensor, k: int):
"""
Find k-nearest neighbors using sklearn's NearestNeighbors.
Find k-nearest neighbors using sklearn's NearestNeighbors.
x: The data tensor of shape [num_samples, num_features]
k: The number of neighbors to retrieve
"""
# Convert tensor to numpy array
x_np = x.cpu().numpy()

# Build the nearest neighbors model
from sklearn.neighbors import NearestNeighbors

nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric="euclidean").fit(x_np)

# Find the k-nearest neighbors
Expand Down Expand Up @@ -733,6 +734,10 @@ def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
d = self._get_downscale_factor()
if d > 1:
newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]

# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
else:
gt_img = batch["image"]
Expand All @@ -756,6 +761,10 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
d = self._get_downscale_factor()
if d > 1:
newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]

# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
else:
gt_img = batch["image"]
Expand Down Expand Up @@ -807,6 +816,9 @@ def get_image_metrics_and_images(
"""
d = self._get_downscale_factor()
if d > 1:
# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]
gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
predicted_rgb = TF.resize(outputs["rgb"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
Expand Down
Loading

0 comments on commit 05d3054

Please sign in to comment.