optimize training
bouchardi committed Apr 10, 2024
commit 11120b1 (parent c283535)
Showing 11 changed files with 44 additions and 39 deletions.
Dockerfile (3 changes: 2 additions & 1 deletion)
@@ -26,7 +26,8 @@ RUN conda init
 RUN conda env create --file environment.yml
 
 # COLMAP dependencies
-RUN DEBIAN_FRONTEND=noninteractive apt-get install -y \
+RUN apt update && \
+    DEBIAN_FRONTEND=noninteractive apt-get install -y \
     git \
     cmake \
     libboost-program-options-dev \
gaussian_splatting/dataset/cameras.py (8 changes: 2 additions & 6 deletions)
@@ -44,16 +44,12 @@ def __init__(

         self.data_device = torch.device("cuda")
 
-        self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
+        self.original_image = image.clamp(0.0, 1.0)
         self.image_width = self.original_image.shape[2]
         self.image_height = self.original_image.shape[1]
 
         if gt_alpha_mask is not None:
-            self.original_image *= gt_alpha_mask.to(self.data_device)
-        else:
-            self.original_image *= torch.ones(
-                (1, self.image_height, self.image_width), device=self.data_device
-            )
+            self.original_image *= gt_alpha_mask
 
         self.zfar = 100.0
         self.znear = 0.01
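Note on the hunk above: the image tensor is no longer copied to CUDA up front, and the unconditional multiply by an all-ones mask (a no-op that allocated a full-size tensor) is gone. A minimal sketch of the new behavior, with shapes assumed for illustration:

    import torch

    image = torch.rand(3, 480, 640)  # stand-in for a loaded image tensor
    gt_alpha_mask = None             # or a (1, H, W) tensor when a mask exists

    original_image = image.clamp(0.0, 1.0)  # no eager .to("cuda") copy
    if gt_alpha_mask is not None:
        original_image *= gt_alpha_mask     # apply only a real mask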
gaussian_splatting/dataset/dataset.py (1 change: 0 additions & 1 deletion)
@@ -18,7 +18,6 @@


 class Dataset:
-
     def __init__(
         self,
         source_path,
gaussian_splatting/metrics/lpipPyTorch/modules/lpips.py (1 change: 0 additions & 1 deletion)
@@ -17,7 +17,6 @@ class LPIPS(nn.Module):
"""

def __init__(self, net_type: str = "alex", version: str = "0.1"):

assert version in ["0.1"], "v0.1 is only supported now"

super(LPIPS, self).__init__()
gaussian_splatting/model.py (1 change: 0 additions & 1 deletion)
@@ -33,7 +33,6 @@ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):


 class GaussianModel:
-
     def __init__(self, sh_degree: int = 3):
         self.active_sh_degree = 0
         self.max_sh_degree = sh_degree
gaussian_splatting/optimizer.py (1 change: 0 additions & 1 deletion)
@@ -125,7 +125,6 @@ def concatenate_points(self, tensors_dict):
             extension_tensor = tensors_dict[group["name"]]
             stored_state = self._optimizer.state.get(group["params"][0], None)
             if stored_state is not None:
-
                 stored_state["exp_avg"] = torch.cat(
                     (stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0
                 )
gaussian_splatting/trainer.py (57 changes: 33 additions & 24 deletions)
@@ -22,11 +22,15 @@ def __init__(
         resolution=-1,
         sh_degree=3,
         checkpoint_path=None,
+        output_path=None,
+        min_num_cameras=15,
     ):
-        self._model_path = self._prepare_model_path()
+        self.model_path = self._prepare_model_path(output_path)
 
         self.dataset = Dataset(source_path, keep_eval=keep_eval, resolution=resolution)
-        self.dataset.save_scene_info(self._model_path)
+        if len(self.dataset.get_train_cameras()) <= min_num_cameras:
+            raise ValueError("Not enough cameras to reconstruct the scene!")
+        self.dataset.save_scene_info(self.model_path)
 
         self.gaussian_model = GaussianModel(sh_degree)
         self.gaussian_model.initialize(self.dataset)
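A hedged usage sketch of the two new constructor arguments; the class name Trainer and all paths are assumptions, only output_path and min_num_cameras come from this diff:

    # Hypothetical call site; "Trainer" and the paths are placeholders.
    trainer = Trainer(
        source_path="data/my_scene",
        output_path="experiments/run1",  # runs now nest under ./output/experiments/run1/
        min_num_cameras=15,              # raises ValueError if too few training cameras
    )
    trainer.run()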
@@ -35,18 +39,18 @@ def __init__(

         self._checkpoint_path = checkpoint_path
 
-        self._debug = False
+        self._debug = True
 
-        self._iterations = 30000
-        self._testing_iterations = [7000, 30000]
-        self._saving_iterations = [7000, 30000]
+        self._iterations = 10000
+        self._testing_iterations = [20, 250, 1000, 2500, 7000, 30000]
+        self._saving_iterations = [20, 250, 1000, 2500, 3500, 4500, 5000, 7000, 10000]
         self._checkpoint_iterations = []
 
         # Loss function
         self._lambda_dssim = 0.2
 
         # Densification and pruning
-        self._opacity_reset_interval = 3000
+        self._opacity_reset_interval = 30000
         self._min_opacity = 0.005
         self._max_screen_size = 20
         self._percent_dense = 0.01
@@ -60,9 +64,11 @@ def __init__(
     def run(self):
         first_iter = 0
         if self._checkpoint_path:
-            gaussian_model_state_dict, self.optimizer_state_dict, first_iter = (
-                torch.load(checkpoint_path)
-            )
+            (
+                gaussian_model_state_dict,
+                self.optimizer_state_dict,
+                first_iter,
+            ) = torch.load(checkpoint_path)
             self.gaussian_model.load_state_dict(gaussian_model_state_dict)
             self.optimizer.load_state_dict(optmizer_state_dict)
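As committed, the resume path unpacks into self.optimizer_state_dict but then reads an undefined optmizer_state_dict, and torch.load receives checkpoint_path rather than self._checkpoint_path. A consistent sketch of the intended logic (an assumption, not the committed code):

    (
        gaussian_model_state_dict,
        optimizer_state_dict,
        first_iter,
    ) = torch.load(self._checkpoint_path)
    self.gaussian_model.load_state_dict(gaussian_model_state_dict)
    self.optimizer.load_state_dict(optimizer_state_dict)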

@@ -101,6 +107,15 @@ def run(self):
             # except Exception:
             #     import pdb; pdb.set_trace()
 
+            if iteration in self._saving_iterations:
+                print("\n[ITER {}] Saving Gaussians".format(iteration))
+                point_cloud_path = os.path.join(
+                    self.model_path, "point_cloud/iteration_{}".format(iteration)
+                )
+                self.gaussian_model.save_ply(
+                    os.path.join(point_cloud_path, "point_cloud.ply")
+                )
+
             with torch.no_grad():
                 # Progress bar
                 ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
@@ -111,17 +126,8 @@ def run(self):
                     progress_bar.close()
 
                 # Log and save
-                if iteration in self._testing_iterations:
-                    self._report(iteration)
-
-                if iteration in self._saving_iterations:
-                    print("\n[ITER {}] Saving Gaussians".format(iteration))
-                    point_cloud_path = os.path.join(
-                        self.model_path, "point_cloud/iteration_{}".format(iteration)
-                    )
-                    self.gaussian_model.save_ply(
-                        os.path.join(point_cloud_path, "point_cloud.ply")
-                    )
+                # if iteration in self._testing_iterations:
+                #     self._report(iteration)
 
                 # Densification
                 if iteration < self._densification_iteration_stop:
@@ -158,9 +164,12 @@ def run(self):
                     self.model_path + "/chkpnt" + str(iteration) + ".pth",
                 )
 
-    def _prepare_model_path(self):
+    def _prepare_model_path(self, output_path):
         unique_str = str(uuid.uuid4())
-        model_path = os.path.join("./output/", unique_str[0:10])
+        model_path = os.path.join(
+            "./output/" if output_path is None else f"./output/{output_path}/",
+            unique_str[0:10],
+        )
 
         # Set up output folder
         print("Output folder: {}".format(model_path))
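A standalone sketch of the new path logic: runs nest under ./output/<output_path>/ when an output path is given, otherwise directly under ./output/. The function name below is illustrative, standing in for _prepare_model_path:

    import os
    import uuid

    def make_model_path(output_path=None):  # illustrative stand-in
        unique_str = str(uuid.uuid4())
        return os.path.join(
            "./output/" if output_path is None else f"./output/{output_path}/",
            unique_str[0:10],
        )

    print(make_model_path())        # e.g. ./output/3f2a9c71-e
    print(make_model_path("exp1"))  # e.g. ./output/exp1/8b0d44c2-1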
@@ -181,7 +190,7 @@ def _report(self, iteration):
             ],
         }
 
-        for config_name, cameras in validation_configs:
+        for config_name, cameras in validation_configs.items():
             if not cameras or len(cameras) == 0:
                 continue

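The .items() fix above matters because iterating a dict directly yields only its keys, so the old loop tried to unpack each key string into two names and failed. A minimal illustration with stand-in data:

    validation_configs = {"test": ["cam0", "cam1"], "train": []}
    for config_name, cameras in validation_configs.items():
        if not cameras or len(cameras) == 0:
            continue
        print(config_name, len(cameras))  # prints "test 2"; "train" is skipped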
scripts/convert.py (3 changes: 2 additions & 1 deletion)
@@ -52,7 +52,8 @@ def main(source_path, resize, use_gpu):
step=" Feature matching",
command=f"colmap sequential_matcher \
--database_path {source_path / 'distorted' / 'database.db'} \
--SiftMatching.use_gpu {use_gpu}",
--SiftMatching.use_gpu {use_gpu} \
--SiftMatching.max_num_matches 10000",
)

_run_command(
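For context, --SiftMatching.max_num_matches caps how many SIFT matches COLMAP keeps per image pair, bounding matching time and memory at some cost in correspondences. A hedged standalone sketch of the same call, using plain subprocess instead of the repo's _run_command helper, with a placeholder path:

    import subprocess

    source_path = "data/my_scene"  # placeholder
    subprocess.run(
        f"colmap sequential_matcher "
        f"--database_path {source_path}/distorted/database.db "
        f"--SiftMatching.use_gpu 1 "
        f"--SiftMatching.max_num_matches 10000",
        shell=True,
        check=True,
    )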
scripts/extract_video_frames.py (2 changes: 1 addition & 1 deletion)
@@ -23,7 +23,7 @@ def main(video_filename: Path, output_path: Path, k: int = 10):
             current_frame += 1
             continue
 
-        output_filename = output_path / f"{current_frame}.jpg"
+        output_filename = output_path / f"{str(current_frame).zfill(6)}.jpg"
         cv2.imwrite(output_filename.as_posix(), frame)
         print(f"Writing {output_filename}.")
 
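Why zero-pad to six digits: unpadded frame names sort lexicographically rather than numerically, which scrambles frame order for any tool that lists files by name (sequential matching depends on that order). A quick check:

    names = [f"{n}.jpg" for n in (1, 2, 10)]
    padded = [f"{str(n).zfill(6)}.jpg" for n in (1, 2, 10)]
    print(sorted(names))   # ['1.jpg', '10.jpg', '2.jpg']  (wrong order)
    print(sorted(padded))  # ['000001.jpg', '000002.jpg', '000010.jpg']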
scripts/metrics.py (1 change: 0 additions & 1 deletion)
@@ -38,7 +38,6 @@ def readImages(renders_dir, gt_dir):


 def evaluate(model_paths):
-
     full_dict = {}
     per_view_dict = {}
     full_dict_polytopeonly = {}
setup.py (5 changes: 4 additions & 1 deletion)
@@ -6,6 +6,9 @@
     packages=find_packages(),  # Automatically discover and include all packages
     scripts=[],
     install_requires=[
-        "pyyaml"
+        "pyyaml",
+        "black",
+        "isort",
+        "importchecker"
     ]
 )
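The comma added after "pyyaml" is load-bearing: adjacent Python string literals concatenate, so omitting it between two entries would silently merge them into one bogus requirement. For illustration:

    broken = ["pyyaml" "black"]   # == ["pyyamlblack"], one bogus requirement
    fixed = ["pyyaml", "black"]   # two separate requirements
    print(broken, fixed)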
