Skip to content

Commit

Permalink
global reconstruction with no pose prior
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Apr 11, 2024
1 parent 0df8051 commit b48029b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 24 deletions.
7 changes: 4 additions & 3 deletions gaussian_splatting/colmap_free/global_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@


class GlobalTrainer(Trainer):
def __init__(self, gaussian_model):
self._model_path = self._prepare_model_path()
def __init__(self, gaussian_model, output_path = None):
self._model_path = self._prepare_model_path(output_path)

self.gaussian_model = gaussian_model
self.cameras = []
Expand Down Expand Up @@ -81,4 +81,5 @@ def run(self, iterations: int = 1000):
viewspace_point_tensor, visibility_filter, radii
)
self._densify_and_prune(True)
self._reset_opacity()
#self._reset_opacity()

Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, gaussian_model):
self.transformation_model.to(gaussian_model.get_xyz.device)

self.optimizer = torch.optim.Adam(
self.transformation_model.parameters(), lr=0.0005
self.transformation_model.parameters(), lr=0.0001
)
self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)

Expand Down Expand Up @@ -57,8 +57,8 @@ def run(self, current_camera, gt_image, iterations: int = 1000, run: int = 0):
progress_bar.set_postfix({"Loss": f"{loss:.{5}f}"})
progress_bar.update(1)

if iteration % 50 == 0 or iteration == iterations - 1:
self._save_artifacts(losses, rendered_image, iteration, run)
#if iteration % 50 == 0 or iteration == iterations - 1:
# self._save_artifacts(losses, rendered_image, iteration, run)

if best_loss is None or best_loss > loss:
best_loss = loss.cpu().item()
Expand Down
39 changes: 21 additions & 18 deletions scripts/train_colmap_free.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import torchvision

# from gaussian_splatting.colmap_free.global_trainer import GlobalTrainer
from gaussian_splatting.colmap_free.global_trainer import GlobalTrainer
from gaussian_splatting.colmap_free.local_initialization_trainer import \
LocalInitializationTrainer
from gaussian_splatting.colmap_free.local_transformation_trainer import \
Expand All @@ -18,41 +18,43 @@

def main():
debug = True
iteration_step_size = 50
global_iterations = 5
step_size = 10
initialization_iterations = 1000
transformation_iterations = 250
global_iterations = 50

photometric_loss = PhotometricLoss(lambda_dssim=0.2)
dataset = ImageDataset(images_path=Path("data/phil/1/input/"))

global_trainer = None
for iteration in range(0, len(dataset), iteration_step_size):
print(f">>> Current: {iteration} / Next: {iteration + iteration_step_size}")
for i in range(len(dataset) // step_size):
print(f">>> Current: {i * step_size} / Next: {(i + 1) * step_size}")

current_image = dataset.get_frame(iteration)
next_image = dataset.get_frame(iteration + iteration_step_size)
current_image = dataset.get_frame(i * step_size)
next_image = dataset.get_frame((i + 1) * step_size)

if iteration == 0:
if i == 0:
current_camera = _get_orthogonal_camera(current_image)
else:
current_camera = next_camera

current_camera, current_gaussian_model = _initialize_local_gaussian_model(
current_image, current_camera, run=iteration
current_image, current_camera, iterations=initialization_iterations, run=i
)

# if iteration == 0:
# global_gaussian_model = copy.deepcopy(current_gaussian_model)
# global_trainer = GlobalTrainer(global_gaussian_model)
if i == 0:
global_gaussian_model = copy.deepcopy(current_gaussian_model)
global_trainer = GlobalTrainer(global_gaussian_model)

if debug:
current_gaussian_model_copy = copy.deepcopy(current_gaussian_model)

next_camera, next_gaussian_model = _transform_local_gaussian_model(
next_image, current_camera, current_gaussian_model, run=iteration
next_image, current_camera, current_gaussian_model,
iterations=transformation_iterations, run=i
)

# global_trainer.add_camera(next_camera)
# global_trainer.run(global_iterations)
global_trainer.add_camera(next_camera)

if debug:
save_artifacts(
Expand All @@ -62,11 +64,12 @@ def main():
next_camera,
next_gaussian_model,
next_image,
iteration,
i,
)
if iteration >= 50:
break

global_trainer.run((i + 1) * global_iterations)

#global_trainer.run(global_iterations)

def _initialize_local_gaussian_model(
image, camera, iterations: int = 250, run: int = 0
Expand Down

0 comments on commit b48029b

Please sign in to comment.