Skip to content

Commit

Permalink
Rename files and functions
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Mar 19, 2024
1 parent 4f3f351 commit 2bfcee3
Show file tree
Hide file tree
Showing 15 changed files with 65 additions and 74 deletions.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import torch
from torch import nn

from gaussian_splatting.utils.graphics import (getProjectionMatrix,
getWorld2View2)
from gaussian_splatting.utils.graphics import (get_projection_matrix,
get_world_2_view)


class Camera(nn.Module):
Expand Down Expand Up @@ -62,10 +62,10 @@ def __init__(
self.scale = scale

self.world_view_transform = (
torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
torch.tensor(get_world_2_view(R, T, trans, scale)).transpose(0, 1).cuda()
)
self.projection_matrix = (
getProjectionMatrix(
get_projection_matrix(
znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
)
.transpose(0, 1)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
import os
import random

from gaussian_splatting.scene.dataset_readers import readColmapSceneInfo
from gaussian_splatting.utils.camera import (camera_to_JSON,
cameraList_from_camInfos)
from gaussian_splatting.dataset.dataset_readers import read_colmap_scene_info
from gaussian_splatting.utils.camera import camera_to_json, load_cameras


class Dataset:
Expand All @@ -35,7 +34,7 @@ def __init__(
self.test_cameras = {}

if os.path.exists(os.path.join(source_path, "sparse")):
scene_info = readColmapSceneInfo(source_path, keep_eval)
scene_info = read_colmap_scene_info(source_path, keep_eval)
else:
assert False, "Could not recognize scene type!"

Expand All @@ -45,11 +44,11 @@ def __init__(

for resolution_scale in resolution_scales:
print("Loading Training Cameras")
self.train_cameras[resolution_scale] = cameraList_from_camInfos(
self.train_cameras[resolution_scale] = load_cameras(
scene_info.train_cameras, resolution_scale, resolution
)
print("Loading Test Cameras")
self.test_cameras[resolution_scale] = cameraList_from_camInfos(
self.test_cameras[resolution_scale] = load_cameras(
scene_info.test_cameras, resolution_scale, resolution
)

Expand All @@ -67,12 +66,12 @@ def save_scene_info(self, model_path):
if self.scene_info.train_cameras:
camlist.extend(self.scene_info.train_cameras)
for id, cam in enumerate(camlist):
json_cams.append(camera_to_JSON(id, cam))
json_cams.append(camera_to_json(id, cam))
with open(os.path.join(model_path, "cameras.json"), "w") as file:
json.dump(json_cams, file)

def getTrainCameras(self, scale=1.0):
def get_train_cameras(self, scale=1.0):
return self.train_cameras[scale]

def getTestCameras(self, scale=1.0):
def get_test_cameras(self, scale=1.0):
return self.test_cameras[scale]
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
from PIL import Image
from plyfile import PlyData, PlyElement

from gaussian_splatting.scene.colmap_loader import (qvec2rotmat,
read_extrinsics_binary,
read_extrinsics_text,
read_intrinsics_binary,
read_intrinsics_text,
read_points3D_binary,
read_points3D_text)
from gaussian_splatting.scene.gaussian_model import BasicPointCloud
from gaussian_splatting.utils.graphics import focal2fov, getWorld2View2
from gaussian_splatting.dataset.colmap_loader import (qvec2rotmat,
read_extrinsics_binary,
read_extrinsics_text,
read_intrinsics_binary,
read_intrinsics_text,
read_points3D_binary,
read_points3D_text)
from gaussian_splatting.utils.graphics import (BasicPointCloud, focal2fov,
get_world_2_view)


class CameraInfo(NamedTuple):
Expand All @@ -49,7 +49,7 @@ class SceneInfo(NamedTuple):
ply_path: str


def getNerfppNorm(cam_info):
def get_nerfpp_norm(cam_info):
def get_center_and_diag(cam_centers):
cam_centers = np.hstack(cam_centers)
avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
Expand All @@ -61,7 +61,7 @@ def get_center_and_diag(cam_centers):
cam_centers = []

for cam in cam_info:
W2C = getWorld2View2(cam.R, cam.T)
W2C = get_world_2_view(cam.R, cam.T)
C2W = np.linalg.inv(W2C)
cam_centers.append(C2W[:3, 3:4])

Expand All @@ -73,7 +73,7 @@ def get_center_and_diag(cam_centers):
return {"translate": translate, "radius": radius}


def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
def read_colmap_cameras(cam_extrinsics, cam_intrinsics, images_folder):
cam_infos = []
for idx, key in enumerate(cam_extrinsics):
sys.stdout.write("\r")
Expand Down Expand Up @@ -125,7 +125,7 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
return cam_infos


def fetchPly(path):
def fetch_ply(path):
plydata = PlyData.read(path)
vertices = plydata["vertex"]
positions = np.vstack([vertices["x"], vertices["y"], vertices["z"]]).T
Expand All @@ -134,7 +134,7 @@ def fetchPly(path):
return BasicPointCloud(points=positions, colors=colors, normals=normals)


def storePly(path, xyz, rgb):
def store_ply(path, xyz, rgb):
# Define the dtype for the structured array
dtype = [
("x", "f4"),
Expand All @@ -160,7 +160,7 @@ def storePly(path, xyz, rgb):
ply_data.write(path)


def readColmapSceneInfo(path, keep_eval, llffhold=8):
def read_colmap_scene_info(path, keep_eval, llffhold=8):
try:
cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
Expand All @@ -172,7 +172,7 @@ def readColmapSceneInfo(path, keep_eval, llffhold=8):
cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)

cam_infos_unsorted = readColmapCameras(
cam_infos_unsorted = read_colmap_cameras(
cam_extrinsics=cam_extrinsics,
cam_intrinsics=cam_intrinsics,
images_folder=os.path.join(path, "images"),
Expand All @@ -186,7 +186,7 @@ def readColmapSceneInfo(path, keep_eval, llffhold=8):
train_cam_infos = cam_infos
test_cam_infos = []

nerf_normalization = getNerfppNorm(train_cam_infos)
nerf_normalization = get_nerfpp_norm(train_cam_infos)

ply_path = os.path.join(path, "sparse/0/points3D.ply")
bin_path = os.path.join(path, "sparse/0/points3D.bin")
Expand All @@ -199,9 +199,9 @@ def readColmapSceneInfo(path, keep_eval, llffhold=8):
xyz, rgb, _ = read_points3D_binary(bin_path)
except:
xyz, rgb, _ = read_points3D_text(txt_path)
storePly(ply_path, xyz, rgb)
store_ply(ply_path, xyz, rgb)
try:
pcd = fetchPly(ply_path)
pcd = fetch_ply(ply_path)
except:
pcd = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from gaussian_splatting.utils.general import (build_rotation,
build_scaling_rotation,
inverse_sigmoid, strip_symmetric)
from gaussian_splatting.utils.graphics import BasicPointCloud
from gaussian_splatting.utils.sh import RGB2SH
from gaussian_splatting.utils.system import mkdir_p

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from diff_gaussian_rasterization import (GaussianRasterizationSettings,
GaussianRasterizer)

from gaussian_splatting.scene.gaussian_model import GaussianModel
from gaussian_splatting.model import GaussianModel


def render(
Expand Down
15 changes: 7 additions & 8 deletions gaussian_splatting/training.py → gaussian_splatting/trainer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import os
import uuid
from argparse import Namespace
from random import randint

import torch
from tqdm import tqdm

from gaussian_splatting.gaussian_renderer import render
from gaussian_splatting.dataset.dataset import Dataset
from gaussian_splatting.model import GaussianModel
from gaussian_splatting.optimizer import Optimizer
from gaussian_splatting.scene import Dataset
from gaussian_splatting.scene.gaussian_model import GaussianModel
from gaussian_splatting.render import render
from gaussian_splatting.utils.general import safe_state
from gaussian_splatting.utils.image import psnr
from gaussian_splatting.utils.loss import l1_loss, ssim
Expand Down Expand Up @@ -82,7 +81,7 @@ def run(self):

# Pick a random camera
if not cameras:
cameras = self.dataset.getTrainCameras().copy()
cameras = self.dataset.get_train_cameras().copy()
camera = cameras.pop(randint(0, len(cameras) - 1))

# Render image
Expand Down Expand Up @@ -173,10 +172,10 @@ def _report(self, iteration):
# Report test and samples of training set
torch.cuda.empty_cache()
validation_configs = {
"test": self.dataset.getTestCameras(),
"test": self.dataset.get_test_cameras(),
"train": [
self.dataset.getTrainCameras()[
idx % len(self.dataset.getTrainCameras())
self.dataset.get_train_cameras()[
idx % len(self.dataset.get_train_cameras())
]
for idx in range(5, 30, 5)
],
Expand Down
24 changes: 13 additions & 11 deletions gaussian_splatting/utils/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import numpy as np

from gaussian_splatting.scene.cameras import Camera
from gaussian_splatting.dataset.cameras import Camera
from gaussian_splatting.utils.general import PILtoTorch
from gaussian_splatting.utils.graphics import fov2focal

Expand Down Expand Up @@ -65,16 +65,16 @@ def load_camera(resolution, cam_id, cam_info, resolution_scale):
)


def load_cameras(cameras_infos, resolution_scale, resolution):
    """Build a Camera object for every camera info entry.

    Each entry of *cameras_infos* is paired with its positional index
    (used as the camera id) and forwarded to ``load_camera``; the
    resulting cameras are returned in input order.
    """
    loaded = []
    for cam_id, cam_info in enumerate(cameras_infos):
        loaded.append(load_camera(resolution, cam_id, cam_info, resolution_scale))
    return loaded


def camera_to_JSON(id, camera: Camera):
def camera_to_json(_id, camera: Camera):
Rt = np.zeros((4, 4))
Rt[:3, :3] = camera.R.transpose()
Rt[:3, 3] = camera.T
Expand All @@ -84,8 +84,9 @@ def camera_to_JSON(id, camera: Camera):
pos = W2C[:3, 3]
rot = W2C[:3, :3]
serializable_array_2d = [x.tolist() for x in rot]
camera_entry = {
"id": id,

json_camera = {
"id": _id,
"img_name": camera.image_name,
"width": camera.width,
"height": camera.height,
Expand All @@ -94,4 +95,5 @@ def camera_to_JSON(id, camera: Camera):
"fy": fov2focal(camera.FovY, camera.height),
"fx": fov2focal(camera.FovX, camera.width),
}
return camera_entry

return json_camera
2 changes: 0 additions & 2 deletions gaussian_splatting/utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
#

import random
import sys
from datetime import datetime

import numpy as np
import torch
Expand Down
12 changes: 2 additions & 10 deletions gaussian_splatting/utils/graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,7 @@ def geom_transform_points(points, transf_matrix):
return (points_out[..., :3] / denom).squeeze(dim=0)


def getWorld2View(R, t):
    """Assemble a 4x4 homogeneous world-to-view matrix.

    The upper-left 3x3 block is the transpose of the rotation ``R``,
    the last column carries the translation ``t``, and the bottom row
    is ``[0, 0, 0, 1]``. Returned as float32.
    """
    world_to_view = np.eye(4)
    world_to_view[:3, :3] = R.T
    world_to_view[:3, 3] = t
    return np.float32(world_to_view)


def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0):
def get_world_2_view(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0):
Rt = np.zeros((4, 4))
Rt[:3, :3] = R.transpose()
Rt[:3, 3] = t
Expand All @@ -54,7 +46,7 @@ def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0):
return np.float32(Rt)


def getProjectionMatrix(znear, zfar, fovX, fovY):
def get_projection_matrix(znear, zfar, fovX, fovY):
tanHalfFovY = math.tan((fovY / 2))
tanHalfFovX = math.tan((fovX / 2))

Expand Down
3 changes: 2 additions & 1 deletion gaussian_splatting/utils/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def mkdir_p(folder_path):
raise


def searchForMaxIteration(folder):
def search_for_max_iteration(folder):
    """Return the largest iteration number encoded in *folder*'s filenames.

    Every entry in *folder* is expected to end with ``_<iteration>``;
    the trailing integer of each name is parsed and the maximum is
    returned. Raises ValueError if the folder is empty.
    """
    return max(int(name.rsplit("_", 1)[-1]) for name in os.listdir(folder))
13 changes: 7 additions & 6 deletions scripts/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
import torchvision
from tqdm import tqdm

from gaussian_splatting.gaussian_renderer import GaussianModel, render
from gaussian_splatting.scene import Dataset
from gaussian_splatting.dataset.dataset import Dataset
from gaussian_splatting.model import GaussianModel
from gaussian_splatting.render import render
from gaussian_splatting.utils.general import safe_state
from gaussian_splatting.utils.system import searchForMaxIteration
from gaussian_splatting.utils.system import search_for_max_iteration


def render_set(model_path, name, iteration, views, gaussian_model):
Expand Down Expand Up @@ -58,7 +59,7 @@ def render_sets(
gaussian_model = GaussianModel(dataset.sh_degree)

if iteration == -1:
iteration = searchForMaxIteration(
iteration = search_for_max_iteration(
os.path.join(self.model_path, "point_cloud")
)
print(f"Loading trained model at iteration {iteration}.")
Expand All @@ -76,15 +77,15 @@ def render_sets(
render_set(
model_path,
"train",
dataset.getTrainCameras(),
dataset.get_train_cameras(),
gaussian_model,
)

if not skip_test:
render_set(
model_path,
"test",
dataset.getTestCameras(),
dataset.get_test_cameras(),
gaussian_model,
)

Expand Down
2 changes: 1 addition & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import sys
from argparse import ArgumentParser

from gaussian_splatting.training import Trainer
from gaussian_splatting.trainer import Trainer

if __name__ == "__main__":
# Set up command line argument parser
Expand Down
2 changes: 1 addition & 1 deletion scripts/train_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
timeout=10800,
)
def f():
from gaussian_splatting.training import Trainer
from gaussian_splatting.trainer import Trainer

trainer = Trainer(source_path="/workspace/data/phil_open/5")
trainer.run()
Expand Down

0 comments on commit 2bfcee3

Please sign in to comment.