forked from graphdeco-inria/gaussian-splatting
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
118 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,72 +1,85 @@ | ||
import copy | ||
from pathlib import Path | ||
|
||
import copy | ||
import numpy as np | ||
import torchvision | ||
import torch | ||
from tqdm import tqdm | ||
|
||
from gaussian_splatting.dataset.cameras import Camera | ||
from gaussian_splatting.render import render | ||
from gaussian_splatting.dataset.image_dataset import ImageDataset | ||
from gaussian_splatting.global_trainer import GlobalTrainer | ||
from gaussian_splatting.local_initialization_trainer import \ | ||
LocalInitializationTrainer | ||
from gaussian_splatting.local_transformation_trainer import \ | ||
LocalTransformationTrainer | ||
from gaussian_splatting.render import render | ||
from gaussian_splatting.global_trainer import GlobalTrainer | ||
from gaussian_splatting.utils.general import PILtoTorch | ||
from gaussian_splatting.utils.loss import PhotometricLoss | ||
|
||
|
||
def main(): | ||
debug = True | ||
iteration_step_size = 5 | ||
initialization_iterations = 50 | ||
transformation_iterations = 50 | ||
|
||
photometric_loss = PhotometricLoss(lambda_dssim=0.2) | ||
dataset = ImageDataset(images_path=Path("data/phil/1/input/")) | ||
first_image = dataset.get_frame(0) | ||
|
||
local_initialization_trainer = LocalInitializationTrainer( | ||
first_image, iterations=500 | ||
) | ||
local_initialization_trainer.run() | ||
# Initialize Local3DGS gaussians | ||
current_image = dataset.get_frame(0) | ||
local_initialization_trainer = LocalInitializationTrainer(current_image) | ||
local_initialization_trainer.run(iterations=initialization_iterations) | ||
|
||
next_gaussian_model = local_initialization_trainer.gaussian_model | ||
local_transformation_trainer = LocalTransformationTrainer(next_gaussian_model) | ||
#global_trainer = GlobalTrainer( | ||
# gaussian_model=gaussian_model, | ||
# cameras=[camera], | ||
# iterations=100 | ||
#) | ||
|
||
current_camera = local_initialization_trainer.camera | ||
for iteration in tqdm(range(iteration_step_size, len(dataset), iteration_step_size)): | ||
# Keep a copy of current gaussians to compare | ||
with torch.no_grad(): | ||
current_gaussian_model = copy.deepcopy(next_gaussian_model) | ||
|
||
step = 5 | ||
gaussian_model = local_initialization_trainer.gaussian_model | ||
camera = local_initialization_trainer.camera | ||
global_trainer = GlobalTrainer( | ||
gaussian_model=gaussian_model, cameras=[camera], iterations=100 | ||
) | ||
for iteration in range(step, len(dataset), step): | ||
next_image = dataset.get_frame(iteration) | ||
local_transformation_trainer = LocalTransformationTrainer( | ||
# Find transformation from current to next camera poses | ||
next_image = PILtoTorch(dataset.get_frame(iteration)) | ||
rotation, translation = local_transformation_trainer.run( | ||
current_camera, | ||
next_image, | ||
camera=camera, | ||
gaussian_model=copy.deepcopy(gaussian_model), | ||
iterations=250, | ||
iterations=transformation_iterations, | ||
) | ||
local_transformation_trainer.run() | ||
rotation, translation = local_transformation_trainer.get_affine_transformation() | ||
|
||
next_image = PILtoTorch(next_image) | ||
# Add new camera to Global3DGS training cameras | ||
next_camera = Camera( | ||
R=np.matmul(camera.R, rotation), | ||
T=camera.T + translation, | ||
FoVx=camera.FoVx, | ||
FoVy=camera.FoVy, | ||
R=np.matmul(current_camera.R, rotation), | ||
T=current_camera.T + translation, | ||
FoVx=current_camera.FoVx, | ||
FoVy=current_camera.FoVy, | ||
image=next_image, | ||
gt_alpha_mask=None, | ||
image_name="patate", | ||
colmap_id=iteration, | ||
uid=iteration, | ||
) | ||
global_trainer.add_camera(next_camera) | ||
global_trainer.run() | ||
|
||
rendered_image, viewspace_point_tensor, visibility_filter, radii = render( | ||
next_camera, | ||
local_initialization_trainer.gaussian_model, | ||
) | ||
torchvision.utils.save_image( | ||
rendered_image, f"artifacts/global/rendered_{iteration}.png" | ||
) | ||
torchvision.utils.save_image(next_image, f"artifacts/global/gt_{iteration}.png") | ||
if debug: | ||
# Save artifact | ||
next_camera_image, _, _, _ = render(next_camera, current_gaussian_model) | ||
next_gaussian_image, _, _, _ = render(current_camera, next_gaussian_model) | ||
loss = photometric_loss(next_camera_image, next_gaussian_image) | ||
assert loss < 0.01 | ||
|
||
print(f">>> Iteration {iteration} / {len(dataset) // step}") | ||
torchvision.utils.save_image(next_camera_image, f"artifacts/global/next_camera_{iteration}.png") | ||
torchvision.utils.save_image(next_gaussian_image, f"artifacts/global/next_gaussian_{iteration}.png") | ||
#global_trainer.add_camera(next_camera) | ||
#global_trainer.run() | ||
|
||
current_camera = next_camera | ||
|
||
if __name__ == "__main__": | ||
main() |