Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Apr 10, 2024
1 parent 78fd273 commit 6bfe455
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 30 deletions.
33 changes: 19 additions & 14 deletions gaussian_splatting/colmap_free/local_initialization_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from matplotlib import pyplot as plt
from tqdm import tqdm
from transformers import pipeline
from pathlib import Path

from gaussian_splatting.dataset.cameras import Camera
from gaussian_splatting.model import GaussianModel
Expand Down Expand Up @@ -50,6 +51,10 @@ def __init__(self, image, sh_degree: int = 3, iterations: int = 10000):

safe_state(seed=2234)

self._output_path =Path(" artifacts/local/init/")
self._output_path.mkdir(exist_ok=True, parents=True)


def run(self, iterations: int = 3000):
progress_bar = tqdm(range(iterations), desc="Initialization")

Expand All @@ -60,16 +65,6 @@ def run(self, iterations: int = 3000):
self.camera, self.gaussian_model
)

if iteration % 100 == 0:
plt.cla()
plt.plot(losses)
plt.yscale("log")
plt.savefig("artifacts/local/init/losses.png")

torchvision.utils.save_image(
rendered_image, f"artifacts/local/init/rendered_{iteration}.png"
)

gt_image = self.camera.original_image.cuda()
loss = self._photometric_loss(rendered_image, gt_image)
loss.backward()
Expand All @@ -82,6 +77,9 @@ def run(self, iterations: int = 3000):
best_iteration = iteration
losses.append(loss.cpu().item())

if iteration % 100 == 0:
self._save_artifacts(self, losses, rendered_image, iteration)

with torch.no_grad():
# Densification
if iteration < self._densification_iteration_stop:
Expand Down Expand Up @@ -115,10 +113,7 @@ def run(self, iterations: int = 3000):
f"Training done. Best loss = {best_loss:.{5}f} at iteration {best_iteration}."
)

torchvision.utils.save_image(
rendered_image, f"artifacts/local/init/rendered_{iteration}.png"
)
torchvision.utils.save_image(gt_image, f"artifacts/local/init/gt.png")
torchvision.utils.save_image(gt_image, self._output_path / "gt.png")

def _get_orthogonal_camera(self, image):
camera = Camera(
Expand Down Expand Up @@ -180,3 +175,13 @@ def _load_DPT(self):
depth_estimator = pipeline("depth-estimation", model=checkpoint)

return depth_estimator

def _save_artifacts(self, losses, rendered_image, iteration):
plt.cla()
plt.plot(losses)
plt.yscale("log")
plt.savefig(self._output_path / "losses.png")

torchvision.utils.save_image(
rendered_image, self._output_path / f"rendered_{iteration}.png"
)
26 changes: 15 additions & 11 deletions gaussian_splatting/colmap_free/local_transformation_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torchvision
from matplotlib import pyplot as plt
from tqdm import tqdm
from pathlib import Path

from gaussian_splatting.colmap_free.transformation_model import \
AffineTransformationModel
Expand All @@ -23,9 +24,12 @@ def __init__(self, gaussian_model):
)
self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)

self._output_path = Path(" artifacts/local/transfo/")
self._output_path.mkdir(exist_ok=True, parents=True)

safe_state(seed=2234)

def run(self, current_camera, gt_image, iterations: int = 1000):
def run(self, current_camera, gt_image, iterations: int = 1000, run: int = 0):
gt_image = gt_image.to(self.gaussian_model.get_xyz.device)

progress_bar = tqdm(range(iterations), desc="Transformation")
Expand All @@ -50,18 +54,18 @@ def run(self, current_camera, gt_image, iterations: int = 1000):
progress_bar.set_postfix({"Loss": f"{loss:.{5}f}"})
progress_bar.update(1)

if iteration % 10 == 0:
self._save_artifacts(losses, rendered_image, iteration)
if iteration % 10 == 0 or iteration == len(iterations) - 1:
self._save_artifacts(losses, rendered_image, iteration, run)

if best_loss is None or best_loss > loss:
best_loss = loss.cpu().item()
best_iteration = iteration
best_xyz = xyz.detach()
elif best_loss < loss and patience > 10:
self._save_artifacts(losses, rendered_image, iteration)
break
else:
patience += 1
#elif best_loss < loss and patience > 10:
# self._save_artifacts(losses, rendered_image, iteration, run)
# break
#else:
# patience += 1

progress_bar.close()

Expand All @@ -73,12 +77,12 @@ def run(self, current_camera, gt_image, iterations: int = 1000):

return rotation, translation

def _save_artifacts(self, losses, rendered_image, iteration):
def _save_artifacts(self, losses, rendered_image, iteration, run):
plt.cla()
plt.plot(losses)
plt.yscale("log")
plt.savefig("artifacts/local/transfo/losses.png")
plt.savefig(self._output_path / "losses.png")

torchvision.utils.save_image(
rendered_image, f"artifacts/local/transfo/rendered_{iteration}.png"
rendered_image, self._output_path / f"{run}_rendered_{iteration}.png"
)
12 changes: 7 additions & 5 deletions scripts/train_colmap_free.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

def main():
debug = True
iteration_step_size = 5
initialization_iterations = 50
transformation_iterations = 50
global_iterations = 50
iteration_step_size = 50
initialization_iterations = 250
transformation_iterations = 250
global_iterations = 5

photometric_loss = PhotometricLoss(lambda_dssim=0.2)
dataset = ImageDataset(images_path=Path("data/phil/1/input/"))
Expand Down Expand Up @@ -53,6 +53,7 @@ def main():
current_camera,
next_image,
iterations=transformation_iterations,
run=iteration,
)

# Add new camera to Global3DGS training cameras
Expand All @@ -73,15 +74,16 @@ def main():
next_camera_image, _, _, _ = render(next_camera, current_gaussian_model)
next_gaussian_image, _, _, _ = render(current_camera, next_gaussian_model)
loss = photometric_loss(next_camera_image, next_gaussian_image)
assert loss < 0.01

print(loss)
torchvision.utils.save_image(
next_camera_image, f"artifacts/global/next_camera_{iteration}.png"
)
torchvision.utils.save_image(
next_gaussian_image, f"artifacts/global/next_gaussian_{iteration}.png"
)

break
# global_trainer.add_camera(next_camera)
# global_trainer.run(global_iterations)

Expand Down

0 comments on commit 6bfe455

Please sign in to comment.