diff --git a/gaussian_splatting/dataset/__init__.py b/gaussian_splatting/dataset/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/gaussian_splatting/scene/cameras.py b/gaussian_splatting/dataset/cameras.py
similarity index 90%
rename from gaussian_splatting/scene/cameras.py
rename to gaussian_splatting/dataset/cameras.py
index ae85edde4..f2c34563d 100644
--- a/gaussian_splatting/scene/cameras.py
+++ b/gaussian_splatting/dataset/cameras.py
@@ -13,8 +13,8 @@
 import torch
 from torch import nn
 
-from gaussian_splatting.utils.graphics import (getProjectionMatrix,
-                                               getWorld2View2)
+from gaussian_splatting.utils.graphics import (get_projection_matrix,
+                                               get_world_2_view)
 
 
 class Camera(nn.Module):
@@ -62,10 +62,10 @@ def __init__(
         self.scale = scale
 
         self.world_view_transform = (
-            torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
+            torch.tensor(get_world_2_view(R, T, trans, scale)).transpose(0, 1).cuda()
         )
         self.projection_matrix = (
-            getProjectionMatrix(
+            get_projection_matrix(
                 znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
             )
             .transpose(0, 1)
diff --git a/gaussian_splatting/scene/colmap_loader.py b/gaussian_splatting/dataset/colmap_loader.py
similarity index 100%
rename from gaussian_splatting/scene/colmap_loader.py
rename to gaussian_splatting/dataset/colmap_loader.py
diff --git a/gaussian_splatting/scene/__init__.py b/gaussian_splatting/dataset/dataset.py
similarity index 77%
rename from gaussian_splatting/scene/__init__.py
rename to gaussian_splatting/dataset/dataset.py
index e4fb47ee4..4e110cd17 100644
--- a/gaussian_splatting/scene/__init__.py
+++ b/gaussian_splatting/dataset/dataset.py
@@ -13,9 +13,8 @@
 import os
 import random
 
-from gaussian_splatting.scene.dataset_readers import readColmapSceneInfo
-from gaussian_splatting.utils.camera import (camera_to_JSON,
-                                             cameraList_from_camInfos)
+from gaussian_splatting.dataset.dataset_readers import read_colmap_scene_info
+from gaussian_splatting.utils.camera import camera_to_json, load_cameras
 
 
 class Dataset:
@@ -35,7 +34,7 @@ def __init__(
         self.test_cameras = {}
 
         if os.path.exists(os.path.join(source_path, "sparse")):
-            scene_info = readColmapSceneInfo(source_path, keep_eval)
+            scene_info = read_colmap_scene_info(source_path, keep_eval)
         else:
             assert False, "Could not recognize scene type!"
 
@@ -45,11 +44,11 @@ def __init__(
 
         for resolution_scale in resolution_scales:
             print("Loading Training Cameras")
-            self.train_cameras[resolution_scale] = cameraList_from_camInfos(
+            self.train_cameras[resolution_scale] = load_cameras(
                 scene_info.train_cameras, resolution_scale, resolution
             )
             print("Loading Test Cameras")
-            self.test_cameras[resolution_scale] = cameraList_from_camInfos(
+            self.test_cameras[resolution_scale] = load_cameras(
                 scene_info.test_cameras, resolution_scale, resolution
             )
 
@@ -67,12 +66,12 @@ def save_scene_info(self, model_path):
         if self.scene_info.train_cameras:
             camlist.extend(self.scene_info.train_cameras)
         for id, cam in enumerate(camlist):
-            json_cams.append(camera_to_JSON(id, cam))
+            json_cams.append(camera_to_json(id, cam))
         with open(os.path.join(model_path, "cameras.json"), "w") as file:
             json.dump(json_cams, file)
 
-    def getTrainCameras(self, scale=1.0):
+    def get_train_cameras(self, scale=1.0):
         return self.train_cameras[scale]
 
-    def getTestCameras(self, scale=1.0):
+    def get_test_cameras(self, scale=1.0):
         return self.test_cameras[scale]
diff --git a/gaussian_splatting/scene/dataset_readers.py b/gaussian_splatting/dataset/dataset_readers.py
similarity index 84%
rename from gaussian_splatting/scene/dataset_readers.py
rename to gaussian_splatting/dataset/dataset_readers.py
index dd8fb1ad7..48fcc25bc 100644
--- a/gaussian_splatting/scene/dataset_readers.py
+++ b/gaussian_splatting/dataset/dataset_readers.py
@@ -17,15 +17,15 @@
 from PIL import Image
 from plyfile import PlyData, PlyElement
 
-from gaussian_splatting.scene.colmap_loader import (qvec2rotmat,
-                                                    read_extrinsics_binary,
-                                                    read_extrinsics_text,
-                                                    read_intrinsics_binary,
-                                                    read_intrinsics_text,
-                                                    read_points3D_binary,
-                                                    read_points3D_text)
-from gaussian_splatting.scene.gaussian_model import BasicPointCloud
-from gaussian_splatting.utils.graphics import focal2fov, getWorld2View2
+from gaussian_splatting.dataset.colmap_loader import (qvec2rotmat,
+                                                      read_extrinsics_binary,
+                                                      read_extrinsics_text,
+                                                      read_intrinsics_binary,
+                                                      read_intrinsics_text,
+                                                      read_points3D_binary,
+                                                      read_points3D_text)
+from gaussian_splatting.utils.graphics import (BasicPointCloud, focal2fov,
+                                               get_world_2_view)
 
 
 class CameraInfo(NamedTuple):
@@ -49,7 +49,7 @@ class SceneInfo(NamedTuple):
     ply_path: str
 
 
-def getNerfppNorm(cam_info):
+def get_nerfpp_norm(cam_info):
     def get_center_and_diag(cam_centers):
         cam_centers = np.hstack(cam_centers)
         avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
@@ -61,7 +61,7 @@ def get_center_and_diag(cam_centers):
 
     cam_centers = []
     for cam in cam_info:
-        W2C = getWorld2View2(cam.R, cam.T)
+        W2C = get_world_2_view(cam.R, cam.T)
         C2W = np.linalg.inv(W2C)
         cam_centers.append(C2W[:3, 3:4])
 
@@ -73,7 +73,7 @@ def get_center_and_diag(cam_centers):
     return {"translate": translate, "radius": radius}
 
 
-def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
+def read_colmap_cameras(cam_extrinsics, cam_intrinsics, images_folder):
     cam_infos = []
     for idx, key in enumerate(cam_extrinsics):
         sys.stdout.write("\r")
@@ -125,7 +125,7 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
     return cam_infos
 
 
-def fetchPly(path):
+def fetch_ply(path):
     plydata = PlyData.read(path)
     vertices = plydata["vertex"]
     positions = np.vstack([vertices["x"], vertices["y"], vertices["z"]]).T
@@ -134,7 +134,7 @@
     return BasicPointCloud(points=positions, colors=colors, normals=normals)
 
 
-def storePly(path, xyz, rgb):
+def store_ply(path, xyz, rgb):
     # Define the dtype for the structured array
     dtype = [
("x", "f4"), @@ -160,7 +160,7 @@ def storePly(path, xyz, rgb): ply_data.write(path) -def readColmapSceneInfo(path, keep_eval, llffhold=8): +def read_colmap_scene_info(path, keep_eval, llffhold=8): try: cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") @@ -172,7 +172,7 @@ def readColmapSceneInfo(path, keep_eval, llffhold=8): cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) - cam_infos_unsorted = readColmapCameras( + cam_infos_unsorted = read_colmap_cameras( cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, "images"), @@ -186,7 +186,7 @@ def readColmapSceneInfo(path, keep_eval, llffhold=8): train_cam_infos = cam_infos test_cam_infos = [] - nerf_normalization = getNerfppNorm(train_cam_infos) + nerf_normalization = get_nerfpp_norm(train_cam_infos) ply_path = os.path.join(path, "sparse/0/points3D.ply") bin_path = os.path.join(path, "sparse/0/points3D.bin") @@ -199,9 +199,9 @@ def readColmapSceneInfo(path, keep_eval, llffhold=8): xyz, rgb, _ = read_points3D_binary(bin_path) except: xyz, rgb, _ = read_points3D_text(txt_path) - storePly(ply_path, xyz, rgb) + store_ply(ply_path, xyz, rgb) try: - pcd = fetchPly(ply_path) + pcd = fetch_ply(ply_path) except: pcd = None diff --git a/gaussian_splatting/scene/gaussian_model.py b/gaussian_splatting/model.py similarity index 99% rename from gaussian_splatting/scene/gaussian_model.py rename to gaussian_splatting/model.py index 77e81b3c2..a23e3fbfc 100644 --- a/gaussian_splatting/scene/gaussian_model.py +++ b/gaussian_splatting/model.py @@ -20,7 +20,6 @@ from gaussian_splatting.utils.general import (build_rotation, build_scaling_rotation, inverse_sigmoid, strip_symmetric) -from gaussian_splatting.utils.graphics import BasicPointCloud from gaussian_splatting.utils.sh import RGB2SH from gaussian_splatting.utils.system import mkdir_p diff --git a/gaussian_splatting/gaussian_renderer/__init__.py b/gaussian_splatting/render.py similarity index 97% rename from gaussian_splatting/gaussian_renderer/__init__.py rename to gaussian_splatting/render.py index e6de931a0..cc4c5df6b 100644 --- a/gaussian_splatting/gaussian_renderer/__init__.py +++ b/gaussian_splatting/render.py @@ -15,7 +15,7 @@ from diff_gaussian_rasterization import (GaussianRasterizationSettings, GaussianRasterizer) -from gaussian_splatting.scene.gaussian_model import GaussianModel +from gaussian_splatting.model import GaussianModel def render( diff --git a/gaussian_splatting/training.py b/gaussian_splatting/trainer.py similarity index 95% rename from gaussian_splatting/training.py rename to gaussian_splatting/trainer.py index 3dfe2cec7..12a2decdc 100644 --- a/gaussian_splatting/training.py +++ b/gaussian_splatting/trainer.py @@ -1,15 +1,14 @@ import os import uuid -from argparse import Namespace from random import randint import torch from tqdm import tqdm -from gaussian_splatting.gaussian_renderer import render +from gaussian_splatting.dataset.dataset import Dataset +from gaussian_splatting.model import GaussianModel from gaussian_splatting.optimizer import Optimizer -from gaussian_splatting.scene import Dataset -from gaussian_splatting.scene.gaussian_model import GaussianModel +from gaussian_splatting.render import render from gaussian_splatting.utils.general import safe_state from gaussian_splatting.utils.image import psnr from gaussian_splatting.utils.loss import l1_loss, ssim @@ -82,7 +81,7 
@@ def run(self): # Pick a random camera if not cameras: - cameras = self.dataset.getTrainCameras().copy() + cameras = self.dataset.get_train_cameras().copy() camera = cameras.pop(randint(0, len(cameras) - 1)) # Render image @@ -173,10 +172,10 @@ def _report(self, iteration): # Report test and samples of training set torch.cuda.empty_cache() validation_configs = { - "test": self.dataset.getTestCameras(), + "test": self.dataset.get_test_cameras(), "train": [ - self.dataset.getTrainCameras()[ - idx % len(self.dataset.getTrainCameras()) + self.dataset.get_train_cameras()[ + idx % len(self.dataset.get_train_cameras()) ] for idx in range(5, 30, 5) ], diff --git a/gaussian_splatting/utils/camera.py b/gaussian_splatting/utils/camera.py index cfffa2db4..1a150c14a 100644 --- a/gaussian_splatting/utils/camera.py +++ b/gaussian_splatting/utils/camera.py @@ -11,7 +11,7 @@ import numpy as np -from gaussian_splatting.scene.cameras import Camera +from gaussian_splatting.dataset.cameras import Camera from gaussian_splatting.utils.general import PILtoTorch from gaussian_splatting.utils.graphics import fov2focal @@ -65,16 +65,16 @@ def load_camera(resolution, cam_id, cam_info, resolution_scale): ) -def cameraList_from_camInfos(cam_infos, resolution_scale, resolution): - camera_list = [] +def load_cameras(cameras_infos, resolution_scale, resolution): + cameras = [ + load_camera(resolution, cam_id, cam_info, resolution_scale) + for cam_id, cam_info in enumerate(cameras_infos) + ] - for cam_id, c in enumerate(cam_infos): - camera_list.append(load_camera(resolution, cam_id, c, resolution_scale)) + return cameras - return camera_list - -def camera_to_JSON(id, camera: Camera): +def camera_to_json(_id, camera: Camera): Rt = np.zeros((4, 4)) Rt[:3, :3] = camera.R.transpose() Rt[:3, 3] = camera.T @@ -84,8 +84,9 @@ def camera_to_JSON(id, camera: Camera): pos = W2C[:3, 3] rot = W2C[:3, :3] serializable_array_2d = [x.tolist() for x in rot] - camera_entry = { - "id": id, + + json_camera = { + "id": _id, "img_name": camera.image_name, "width": camera.width, "height": camera.height, @@ -94,4 +95,5 @@ def camera_to_JSON(id, camera: Camera): "fy": fov2focal(camera.FovY, camera.height), "fx": fov2focal(camera.FovX, camera.width), } - return camera_entry + + return json_camera diff --git a/gaussian_splatting/utils/general.py b/gaussian_splatting/utils/general.py index f1181ba90..c55064af1 100644 --- a/gaussian_splatting/utils/general.py +++ b/gaussian_splatting/utils/general.py @@ -10,8 +10,6 @@ # import random -import sys -from datetime import datetime import numpy as np import torch diff --git a/gaussian_splatting/utils/graphics.py b/gaussian_splatting/utils/graphics.py index 40f742967..5dfd5c2c4 100644 --- a/gaussian_splatting/utils/graphics.py +++ b/gaussian_splatting/utils/graphics.py @@ -32,15 +32,7 @@ def geom_transform_points(points, transf_matrix): return (points_out[..., :3] / denom).squeeze(dim=0) -def getWorld2View(R, t): - Rt = np.zeros((4, 4)) - Rt[:3, :3] = R.transpose() - Rt[:3, 3] = t - Rt[3, 3] = 1.0 - return np.float32(Rt) - - -def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0): +def get_world_2_view(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0): Rt = np.zeros((4, 4)) Rt[:3, :3] = R.transpose() Rt[:3, 3] = t @@ -54,7 +46,7 @@ def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0): return np.float32(Rt) -def getProjectionMatrix(znear, zfar, fovX, fovY): +def get_projection_matrix(znear, zfar, fovX, fovY): tanHalfFovY = math.tan((fovY / 2)) tanHalfFovX = 
math.tan((fovX / 2)) diff --git a/gaussian_splatting/utils/system.py b/gaussian_splatting/utils/system.py index a89fac10a..04e7ece5c 100644 --- a/gaussian_splatting/utils/system.py +++ b/gaussian_splatting/utils/system.py @@ -25,6 +25,7 @@ def mkdir_p(folder_path): raise -def searchForMaxIteration(folder): +def search_for_max_iteration(folder): saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] + return max(saved_iters) diff --git a/scripts/render.py b/scripts/render.py index 71a96a5bc..96a38ae9e 100644 --- a/scripts/render.py +++ b/scripts/render.py @@ -17,10 +17,11 @@ import torchvision from tqdm import tqdm -from gaussian_splatting.gaussian_renderer import GaussianModel, render -from gaussian_splatting.scene import Dataset +from gaussian_splatting.dataset.dataset import Dataset +from gaussian_splatting.model import GaussianModel +from gaussian_splatting.render import render from gaussian_splatting.utils.general import safe_state -from gaussian_splatting.utils.system import searchForMaxIteration +from gaussian_splatting.utils.system import search_for_max_iteration def render_set(model_path, name, iteration, views, gaussian_model): @@ -58,7 +59,7 @@ def render_sets( gaussian_model = GaussianModel(dataset.sh_degree) if iteration == -1: - iteration = searchForMaxIteration( + iteration = search_for_max_iteration( os.path.join(self.model_path, "point_cloud") ) print(f"Loading trained model at iteration {iteration}.") @@ -76,7 +77,7 @@ def render_sets( render_set( model_path, "train", - dataset.getTrainCameras(), + dataset.get_train_cameras(), gaussian_model, ) @@ -84,7 +85,7 @@ def render_sets( render_set( model_path, "test", - dataset.getTestCameras(), + dataset.get_test_cameras(), gaussian_model, ) diff --git a/scripts/train.py b/scripts/train.py index 19b5d16c6..fe6a5612f 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -11,7 +11,7 @@ import sys from argparse import ArgumentParser -from gaussian_splatting.training import Trainer +from gaussian_splatting.trainer import Trainer if __name__ == "__main__": # Set up command line argument parser diff --git a/scripts/train_modal.py b/scripts/train_modal.py index 3531dd972..007b9a1e8 100644 --- a/scripts/train_modal.py +++ b/scripts/train_modal.py @@ -46,7 +46,7 @@ timeout=10800, ) def f(): - from gaussian_splatting.training import Trainer + from gaussian_splatting.trainer import Trainer trainer = Trainer(source_path="/workspace/data/phil_open/5") trainer.run()
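
Reviewer note, not part of the patch: a minimal sketch of the public surface after the renames above. The Trainer usage mirrors scripts/train_modal.py and the GaussianModel construction mirrors scripts/render.py; the Dataset constructor arguments and the data path shown here are assumptions for illustration only.

# Illustrative sketch only; paths and constructor arguments are assumptions.
from gaussian_splatting.dataset.dataset import Dataset
from gaussian_splatting.model import GaussianModel
from gaussian_splatting.render import render  # moved from gaussian_splatting.gaussian_renderer; only its import path changes here
from gaussian_splatting.trainer import Trainer

# End-to-end training, as in scripts/train_modal.py.
trainer = Trainer(source_path="data/my_scene")  # placeholder path
trainer.run()

# Camera access now goes through snake_case getters (formerly getTrainCameras/getTestCameras).
dataset = Dataset(source_path="data/my_scene")  # constructor arguments assumed; see dataset.py
train_cameras = dataset.get_train_cameras()
test_cameras = dataset.get_test_cameras()

# The Gaussian model moved from scene.gaussian_model to the package root, as in scripts/render.py.
gaussians = GaussianModel(dataset.sh_degree)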