Reduce startup overhead #2763

Merged · 6 commits · Jan 17, 2024
nerfstudio/cameras/camera_optimizers.py (6 changes: 4 additions & 2 deletions)

@@ -23,6 +23,7 @@
from typing import Literal, Optional, Type, Union

import torch
import tyro
from jaxtyping import Float, Int
from torch import Tensor, nn
from typing_extensions import assert_never
@@ -51,10 +52,11 @@ class CameraOptimizerConfig(InstantiateConfig):
rot_l2_penalty: float = 1e-3
"""L2 penalty on rotation parameters."""

optimizer: Optional[OptimizerConfig] = field(default=None)
# tyro.conf.Suppress prevents us from creating CLI arguments for these fields.
optimizer: tyro.conf.Suppress[Optional[OptimizerConfig]] = field(default=None)
"""Deprecated, now specified inside the optimizers dict"""

scheduler: Optional[SchedulerConfig] = field(default=None)
scheduler: tyro.conf.Suppress[Optional[SchedulerConfig]] = field(default=None)
"""Deprecated, now specified inside the optimizers dict"""

def __post_init__(self):
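The `tyro.conf.Suppress` annotation is what keeps these deprecated fields out of the generated CLI while leaving them usable from Python; presumably the startup saving comes from tyro no longer having to build arguments and help text for these nested configs. A minimal sketch of the pattern (the config class and field names below are illustrative, not from nerfstudio):

```python
from dataclasses import dataclass, field
from typing import Optional

import tyro


@dataclass
class ExampleConfig:
    learning_rate: float = 1e-3
    """Exposed as a --learning-rate flag on the CLI."""

    # Suppressed fields stay readable/writable from Python code, but tyro
    # generates no CLI argument or help text for them.
    legacy_option: tyro.conf.Suppress[Optional[str]] = field(default=None)
    """Deprecated; kept only so existing configs keep loading."""


if __name__ == "__main__":
    config = tyro.cli(ExampleConfig)  # --legacy-option is not accepted here
    print(config)
```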
nerfstudio/cameras/cameras.py (7 changes: 5 additions & 2 deletions)

@@ -23,7 +23,6 @@

import cv2
import torch
import torchvision
from jaxtyping import Float, Int, Shaped
from torch import Tensor
from torch.nn import Parameter
@@ -959,7 +958,11 @@ def to_json(
image_uint8 = (image * 255).detach().type(torch.uint8)
if max_size is not None:
image_uint8 = image_uint8.permute(2, 0, 1)
image_uint8 = torchvision.transforms.functional.resize(image_uint8, max_size, antialias=None) # type: ignore

# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

image_uint8 = TF.resize(image_uint8, max_size, antialias=None) # type: ignore
image_uint8 = image_uint8.permute(1, 2, 0)
image_uint8 = image_uint8.cpu().numpy()
data = cv2.imencode(".jpg", image_uint8)[1].tobytes() # type: ignore
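Since `to_json` can be called repeatedly, it is worth noting that a function-local import only pays the cost once: after the first call the module sits in `sys.modules` and the `import` statement reduces to a dict lookup. A small self-contained check (assumes torchvision is installed; timings vary by machine):

```python
import sys
import time


def torchvision_import_cost() -> float:
    """Time only the function-local torchvision import."""
    start = time.perf_counter()
    import torchvision.transforms.functional as TF  # noqa: F401

    return time.perf_counter() - start


if __name__ == "__main__":
    first = torchvision_import_cost()   # pays the real import cost
    second = torchvision_import_cost()  # near-free: module is already cached
    print(f"first call: {first:.3f} s, second call: {second:.6f} s")
    print("cached:", "torchvision.transforms.functional" in sys.modules)
```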
nerfstudio/data/datamanagers/base_datamanager.py (5 changes: 4 additions & 1 deletion)

@@ -41,6 +41,7 @@
)

import torch
import tyro
from torch import nn
from torch.nn import Parameter
from torch.utils.data.distributed import DistributedSampler
@@ -334,7 +335,9 @@ class VanillaDataManagerConfig(DataManagerConfig):
"""
patch_size: int = 1
"""Size of patch to sample from. If > 1, patch-based sampling will be used."""
camera_optimizer: Optional[CameraOptimizerConfig] = field(default=None)

# tyro.conf.Suppress prevents us from creating CLI arguments for this field.
camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None)
"""Deprecated, has been moved to the model config."""
pixel_sampler: PixelSamplerConfig = field(default_factory=PixelSamplerConfig)
"""Specifies the pixel sampler used to sample pixels from images."""
nerfstudio/data/dataparsers/nerfstudio_dataparser.py (3 changes: 2 additions & 1 deletion)

@@ -20,7 +20,6 @@
from typing import Literal, Optional, Type

import numpy as np
import open3d as o3d
import torch
from PIL import Image

@@ -337,6 +336,8 @@ def _generate_dataparser_outputs(self, split="train"):
return dataparser_outputs

def _load_3D_points(self, ply_file_path: Path, transform_matrix: torch.Tensor, scale_factor: float):
import open3d as o3d # Importing open3d is slow, so we only do it if we need it.

pcd = o3d.io.read_point_cloud(str(ply_file_path))

points3D = torch.from_numpy(np.asarray(pcd.points, dtype=np.float32))
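To decide which dependencies are worth deferring like this, you can time their imports directly. A rough sketch (run it in a fresh interpreter; shared dependencies make later entries look cheaper, and `python -X importtime` gives a more detailed breakdown):

```python
import importlib
import time

# Candidates this PR defers; edit the list to whatever you suspect is slow.
for name in ["torch", "torchvision", "open3d", "nuscenes", "torchmetrics"]:
    start = time.perf_counter()
    try:
        importlib.import_module(name)
    except ImportError:
        print(f"{name:>12}: not installed")
        continue
    print(f"{name:>12}: imported in {time.perf_counter() - start:.2f} s")
```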
nerfstudio/data/dataparsers/nuscenes_dataparser.py (4 changes: 3 additions & 1 deletion)

@@ -22,7 +22,6 @@
import numpy as np
import pyquaternion
import torch
from nuscenes.nuscenes import NuScenes as NuScenesDatabase

from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.data.dataparsers.base_dataparser import DataParser, DataParserConfig, DataparserOutputs
@@ -81,6 +80,9 @@ class NuScenes(DataParser):
config: NuScenesDataParserConfig

def _generate_dataparser_outputs(self, split="train"):
# nuscenes is slow to import, so we only do it if we need it.
from nuscenes.nuscenes import NuScenes as NuScenesDatabase

nusc = NuScenesDatabase(
version=self.config.version,
dataroot=str(self.config.data_dir.absolute()),
nerfstudio/exporter/exporter_utils.py (10 changes: 8 additions & 2 deletions)

@@ -21,10 +21,9 @@

import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import numpy as np
import open3d as o3d
import pymeshlab
import torch
from jaxtyping import Float
@@ -38,6 +37,11 @@
from nerfstudio.pipelines.base_pipeline import Pipeline, VanillaPipeline
from nerfstudio.utils.rich_utils import CONSOLE, ItersPerSecColumn

if TYPE_CHECKING:
# Importing open3d can take ~1 second, so only do it below if we actually
# need it.
import open3d as o3d


@dataclass
class Mesh:
@@ -193,6 +197,8 @@ def generate_point_cloud(
rgbs = torch.cat(rgbs, dim=0)
view_directions = torch.cat(view_directions, dim=0).cpu()

import open3d as o3d

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points.double().cpu().numpy())
pcd.colors = o3d.utility.Vector3dVector(rgbs.double().cpu().numpy())
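This file uses both halves of the deferred-import pattern together: the `if TYPE_CHECKING:` block gives type checkers the `o3d` name for annotations without the runtime ever importing open3d at module load, and the function body imports it on first use. A minimal sketch with illustrative names; `from __future__ import annotations` keeps the return annotation from being evaluated at runtime:

```python
from __future__ import annotations  # annotations stay unevaluated strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only static type checkers see this branch; importing this module at
    # runtime never pays the ~1 s open3d import.
    import open3d as o3d


def make_point_cloud(points) -> o3d.geometry.PointCloud:
    """Build a point cloud from an (N, 3) float64 numpy array."""
    import open3d as o3d  # imported on first call only

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    return pcd
```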
nerfstudio/generative/deepfloyd.py (28 changes: 19 additions & 9 deletions)

@@ -13,6 +13,7 @@
# limitations under the License.

import gc
import sys
from pathlib import Path
from typing import List, Optional, Union

@@ -24,15 +25,7 @@
from torch import Generator, Tensor, nn
from torch.cuda.amp.grad_scaler import GradScaler

from nerfstudio.generative.utils import CatchMissingPackages

try:
from diffusers import DiffusionPipeline, IFPipeline, IFPipeline as IFOrig
from diffusers.pipelines.deepfloyd_if import IFPipelineOutput, IFPipelineOutput as IFOutputOrig
from transformers import T5EncoderModel

except ImportError:
IFPipeline = IFPipelineOutput = T5EncoderModel = CatchMissingPackages()
from nerfstudio.utils.rich_utils import CONSOLE

IMG_DIM = 64

@@ -47,6 +40,16 @@ def __init__(self, device: Union[torch.device, str]):
super().__init__()
self.device = device

try:
from diffusers import DiffusionPipeline, IFPipeline
from transformers import T5EncoderModel

except ImportError:
CONSOLE.print("[bold red]Missing Stable Diffusion packages.")
CONSOLE.print(r"Install using [yellow]pip install nerfstudio\[gen][/yellow]")
CONSOLE.print(r"or [yellow]pip install -e .\[gen][/yellow] if installing from source.")
sys.exit(1)

self.text_encoder = T5EncoderModel.from_pretrained(
"DeepFloyd/IF-I-L-v1.0",
subfolder="text_encoder",
@@ -90,6 +93,8 @@ def delete_text_encoder(self):
gc.collect()
torch.cuda.empty_cache()

from diffusers import DiffusionPipeline, IFPipeline

self.pipe = IFPipeline.from_pretrained(
"DeepFloyd/IF-I-L-v1.0",
text_encoder=None,
@@ -126,6 +131,8 @@ def get_text_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

from diffusers import DiffusionPipeline

assert isinstance(self.pipe, DiffusionPipeline)
with torch.no_grad():
prompt_embeds, negative_embeds = self.pipe.encode_prompt(prompt, negative_prompt=negative_prompt)
@@ -200,6 +207,9 @@ def prompt_to_image(
The generated image.
"""

from diffusers import DiffusionPipeline, IFPipeline as IFOrig
from diffusers.pipelines.deepfloyd_if import IFPipelineOutput as IFOutputOrig

prompts = [prompts] if isinstance(prompts, str) else prompts
negative_prompts = [negative_prompts] if isinstance(negative_prompts, str) else negative_prompts
assert isinstance(self.pipe, DiffusionPipeline)
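Moving the `diffusers`/`transformers` imports from module scope into `__init__` and the individual methods is what lets the old `CatchMissingPackages` shim (and `nerfstudio/generative/utils.py`, deleted further down) go away: users without the optional extras only hit the error when they actually construct the generative model, not when the module is imported. A condensed sketch of that pattern, with an illustrative class name:

```python
import sys

from nerfstudio.utils.rich_utils import CONSOLE


class GenerativeModelStub:
    """Illustrative stand-in for the DeepFloyd / StableDiffusion wrappers."""

    def __init__(self) -> None:
        # Deferred import: `import nerfstudio.generative.deepfloyd` stays
        # cheap, and the missing-package error only appears on first use.
        try:
            from diffusers import IFPipeline
            from transformers import T5EncoderModel
        except ImportError:
            CONSOLE.print("[bold red]Missing Stable Diffusion packages.")
            CONSOLE.print(r"Install using [yellow]pip install nerfstudio\[gen][/yellow]")
            sys.exit(1)

        self.pipeline_cls = IFPipeline
        self.text_encoder_cls = T5EncoderModel
```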
nerfstudio/generative/stable_diffusion.py (22 changes: 13 additions & 9 deletions)

@@ -16,10 +16,10 @@

# Modified from https://github.com/ashawkey/stable-dreamfusion/blob/main/nerf/sd.py

import sys
from pathlib import Path
from typing import List, Optional, Union

import mediapy
import numpy as np
import torch
import torch.nn.functional as F
@@ -28,16 +28,8 @@
from torch import Tensor, nn
from torch.cuda.amp.grad_scaler import GradScaler

from nerfstudio.generative.utils import CatchMissingPackages
from nerfstudio.utils.rich_utils import CONSOLE

try:
from diffusers import DiffusionPipeline, PNDMScheduler, StableDiffusionPipeline

except ImportError:
PNDMScheduler = StableDiffusionPipeline = CatchMissingPackages()


IMG_DIM = 512
CONST_SCALE = 0.18215
SD_IDENTIFIERS = {
@@ -57,6 +49,15 @@ class StableDiffusion(nn.Module):
def __init__(self, device: Union[torch.device, str], num_train_timesteps: int = 1000, version="1-5") -> None:
super().__init__()

try:
from diffusers import DiffusionPipeline, PNDMScheduler, StableDiffusionPipeline

except ImportError:
CONSOLE.print("[bold red]Missing Stable Diffusion packages.")
CONSOLE.print(r"Install using [yellow]pip install nerfstudio\[gen][/yellow]")
CONSOLE.print(r"or [yellow]pip install -e .\[gen][/yellow] if installing from source.")
sys.exit(1)

self.device = device
self.num_train_timesteps = num_train_timesteps

@@ -319,6 +320,9 @@ def generate_image(
with torch.no_grad():
sd = StableDiffusion(cuda_device)
imgs = sd.prompt_to_img(prompt, negative, steps)

import mediapy # Slow to import, so we do it lazily.

mediapy.write_image(str(save_path), imgs[0])


nerfstudio/generative/utils.py (35 changes: 0 additions & 35 deletions)

This file was deleted.

nerfstudio/models/base_surface_model.py (7 changes: 4 additions & 3 deletions)

@@ -25,9 +25,6 @@
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from torchmetrics.functional import structural_similarity_index_measure
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.field_components.encodings import NeRFEncoding
@@ -156,6 +153,10 @@ def populate_modules(self):
self.depth_loss = ScaleAndShiftInvariantLoss(alpha=0.5, scales=1)

# metrics
from torchmetrics.functional import structural_similarity_index_measure
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

self.psnr = PeakSignalNoiseRatio(data_range=1.0)
self.ssim = structural_similarity_index_measure
self.lpips = LearnedPerceptualImagePatchSimilarity()
nerfstudio/models/gaussian_splatting.py (22 changes: 17 additions & 5 deletions)

@@ -25,17 +25,13 @@

import numpy as np
import torch
import torchvision.transforms.functional as TF
from gsplat._torch_impl import quat_to_rotmat
from gsplat.compute_cumulative_intersects import compute_cumulative_intersects
from gsplat.project_gaussians import ProjectGaussians
from gsplat.rasterize import RasterizeGaussians
from gsplat.sh import SphericalHarmonics, num_sh_bases
from pytorch_msssim import SSIM
from sklearn.neighbors import NearestNeighbors
from torch.nn import Parameter
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
from nerfstudio.cameras.cameras import Cameras
@@ -205,6 +201,9 @@ def populate_modules(self):
self.opacities = torch.nn.Parameter(torch.logit(0.1 * torch.ones(self.num_points, 1)))

# metrics
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

self.psnr = PeakSignalNoiseRatio(data_range=1.0)
self.ssim = SSIM(data_range=1.0, size_average=True, channel=3)
self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True)
@@ -248,14 +247,16 @@ def load_state_dict(self, dict, **kwargs): # type: ignore

def k_nearest_sklearn(self, x: torch.Tensor, k: int):
"""
Find k-nearest neighbors using sklearn's NearestNeighbors.
Find k-nearest neighbors using sklearn's NearestNeighbors.
x: The data tensor of shape [num_samples, num_features]
k: The number of neighbors to retrieve
"""
# Convert tensor to numpy array
x_np = x.cpu().numpy()

# Build the nearest neighbors model
from sklearn.neighbors import NearestNeighbors

nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric="euclidean").fit(x_np)

# Find the k-nearest neighbors
@@ -733,6 +734,10 @@ def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
d = self._get_downscale_factor()
if d > 1:
newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]

# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
else:
gt_img = batch["image"]
@@ -756,6 +761,10 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
d = self._get_downscale_factor()
if d > 1:
newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]

# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
else:
gt_img = batch["image"]
@@ -807,6 +816,9 @@ def get_image_metrics_and_images(
"""
d = self._get_downscale_factor()
if d > 1:
# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]
gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
predicted_rgb = TF.resize(outputs["rgb"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
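Finally, scikit-learn is now only imported when `k_nearest_sklearn` runs during Gaussian initialization. A self-contained sketch of the same helper for reference (function name and shapes are illustrative; assumes numpy and scikit-learn are installed):

```python
import numpy as np


def k_nearest(x: np.ndarray, k: int):
    """Distances and indices of the k nearest neighbors of each row of x."""
    # Deferred import: only pay for sklearn when initialization needs it.
    from sklearn.neighbors import NearestNeighbors

    model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric="euclidean").fit(x)
    distances, indices = model.kneighbors(x)
    # Drop column 0: each point is its own nearest neighbor.
    return distances[:, 1:], indices[:, 1:]


if __name__ == "__main__":
    pts = np.random.rand(100, 3).astype(np.float32)
    dists, idx = k_nearest(pts, k=3)
    print(dists.shape, idx.shape)  # (100, 3) (100, 3)
```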