From 0f6562fe71bfa01ddbc8516f1a6bb86c5b037528 Mon Sep 17 00:00:00 2001
From: Dario
Date: Tue, 17 Oct 2023 19:22:33 +0000
Subject: [PATCH 1/6] Use Python 3.10

---
 environment.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/environment.yml b/environment.yml
index 8107da51..30997bfb 100644
--- a/environment.yml
+++ b/environment.yml
@@ -5,7 +5,7 @@ channels:
   - defaults
 dependencies:
   - cudatoolkit=11.6
-  - python=3.7.13
+  - python=3.10
   - pip=22.3.1
   - pytorch=1.12.1
   - torchaudio=0.12.1

From 9df9a7ee96613599231638e489d608b4baf9fb0f Mon Sep 17 00:00:00 2001
From: Dario
Date: Tue, 17 Oct 2023 19:22:57 +0000
Subject: [PATCH 2/6] Setup camera is more verbose but clearer

---
 helpers.py | 49 ++++++++++++++++++++++++++++++++++---------------
 1 file changed, 34 insertions(+), 15 deletions(-)

diff --git a/helpers.py b/helpers.py
index f11a2aeb..a4fcad0c 100644
--- a/helpers.py
+++ b/helpers.py
@@ -5,24 +5,43 @@
 from diff_gaussian_rasterization import GaussianRasterizationSettings as Camera
 
 
-def setup_camera(w, h, k, w2c, near=0.01, far=100):
-    fx, fy, cx, cy = k[0][0], k[1][1], k[0][2], k[1][2]
-    w2c = torch.tensor(w2c).cuda().float()
-    cam_center = torch.inverse(w2c)[:3, 3]
-    w2c = w2c.unsqueeze(0).transpose(1, 2)
-    opengl_proj = torch.tensor([[2 * fx / w, 0.0, -(w - 2 * cx) / w, 0.0],
-                                [0.0, 2 * fy / h, -(h - 2 * cy) / h, 0.0],
-                                [0.0, 0.0, far / (far - near), -(far * near) / (far - near)],
-                                [0.0, 0.0, 1.0, 0.0]]).cuda().float().unsqueeze(0).transpose(1, 2)
-    full_proj = w2c.bmm(opengl_proj)
+def setup_camera(
+        width:int,
+        height:int,
+        intrinsics:np.ndarray,
+        world_2_cam:np.ndarray,
+        near:float=0.01,  # Near and far clipping planes for depth in the camera's view frustum.
+        far:float=100  # meters?
+    ) -> Camera:
+    # Focal lengths (fx, fy, in pixels) and optical center (cx, cy)
+    fx, fy, cx, cy = intrinsics[0][0], intrinsics[1][1], intrinsics[0][2], intrinsics[1][2]
+
+    world_2_cam_tensor:torch.Tensor = torch.tensor(world_2_cam).cuda().float()
+
+    # Position of the camera center in world coordinates.
+    cam_center = torch.inverse(world_2_cam_tensor)[:3, 3]
+    world_2_cam_tensor = world_2_cam_tensor.unsqueeze(0).transpose(1, 2)
+
+    # This matrix is used to map 3D world coordinates to 2D camera coordinates, factoring in the depth.
+    opengl_proj = torch.tensor([
+        [ 2 * fx / width, 0.0, -(width - 2 * cx) / width, 0.0 ],
+        [ 0.0, 2 * fy / height, -(height - 2 * cy) / height, 0.0 ],
+        [ 0.0, 0.0, far / (far - near), -(far * near) / (far - near) ],
+        [ 0.0, 0.0, 1.0, 0.0 ]
+    ]).cuda().float().unsqueeze(0).transpose(1, 2)
+
+    # This will give a matrix that transforms from world coordinates
+    # directly to normalized device coordinates (NDC) used in graphics
+    full_proj = world_2_cam_tensor.bmm(opengl_proj)
+
     cam = Camera(
-        image_height=h,
-        image_width=w,
-        tanfovx=w / (2 * fx),
-        tanfovy=h / (2 * fy),
+        image_height=height,
+        image_width=width,
+        tanfovx=width / (2 * fx),
+        tanfovy=height / (2 * fy),
         bg=torch.tensor([0, 0, 0], dtype=torch.float32, device="cuda"),
         scale_modifier=1.0,
-        viewmatrix=w2c,
+        viewmatrix=world_2_cam_tensor,
         projmatrix=full_proj,
         sh_degree=0,
         campos=cam_center,

From d9c92915e43b416ebccd671aa2f8488c89a0df63 Mon Sep 17 00:00:00 2001
From: Dario
Date: Tue, 17 Oct 2023 19:23:18 +0000
Subject: [PATCH 3/6] Add comments and variable names that tell the story

---
 visualize.py | 232 +++++++++++++++++++++++++++++++++------------------
 1 file changed, 152 insertions(+), 80 deletions(-)

diff --git a/visualize.py b/visualize.py
index c6e7e2d0..c4330258 100644
--- a/visualize.py
+++ b/visualize.py
@@ -2,6 +2,7 @@
 import torch
 import numpy as np
 import open3d as o3d
+from open3d.cuda.pybind.utility import Vector3dVector
 import time
 from diff_gaussian_rasterization import GaussianRasterizer as Renderer
 from helpers import setup_camera, quat_mult
@@ -23,28 +24,65 @@
 FORCE_LOOP = False  # False or True
 # FORCE_LOOP = True  # False or True
 
-w, h = 640, 360
+width:int = 640
+height:int = 360
+
 near, far = 0.01, 100.0
 view_scale = 3.9
 fps = 20
 traj_frac = 25  # 4% of points
 traj_length = 15
-def_pix = torch.tensor(
-    np.stack(np.meshgrid(np.arange(w) + 0.5, np.arange(h) + 0.5, 1), -1).reshape(-1, 3)).cuda().float()
-pix_ones = torch.ones(h * w, 1).cuda().float()
-
-
-def init_camera(y_angle=0., center_dist=2.4, cam_height=1.3, f_ratio=0.82):
-    ry = y_angle * np.pi / 180
-    w2c = np.array([[np.cos(ry), 0., -np.sin(ry), 0.],
-                    [0., 1., 0., cam_height],
-                    [np.sin(ry), 0., np.cos(ry), center_dist],
-                    [0., 0., 0., 1.]])
-    k = np.array([[f_ratio * w, 0, w / 2], [0, f_ratio * w, h / 2], [0, 0, 1]])
-    return w2c, k
-
-def load_scene_data(seq, exp, seg_as_col=False):
+
+# This tensor represents a set of homogeneous 2D coordinates for each pixel in the image.
+# The last dimension (a constant value of 1) makes these coordinates homogeneous,
+# which is useful for matrix multiplications in projective geometry.
+def_pix = torch.tensor(
+    np.stack(
+        np.meshgrid(np.arange(width) + 0.5,
+                    np.arange(height) + 0.5,
+                    1), -1
+    ).reshape(-1, 3)
+).cuda().float()
+
+pix_ones = torch.ones(height * width, 1).cuda().float()
+
+
+def init_camera(
+        y_angle:float=0.,  # degrees
+        center_dist:float=2.4,  # meters?
+        cam_height:float=1.3,  # meters?
+        f_ratio:float=0.82
+    ) -> tuple[np.ndarray, np.ndarray]:
+    radians_y = y_angle * np.pi / 180.  # radians
+
+    world_2_cam = np.array([
+        [np.cos(radians_y), 0., -np.sin(radians_y), 0.],
+        [0., 1., 0., cam_height],
+        [np.sin(radians_y), 0., np.cos(radians_y), center_dist],
+        [0., 0., 0., 1.]
+    ])
+
+    # Focal lengths
+    fx = f_ratio * width  # Focal length in the x direction (in pixels)
+    fy = f_ratio * width  # Assuming square pixels, so fx = fy. Change if different.
+
+    # Optical center coordinates (typically the center of the image)
+    cx = width / 2   # x-coordinate of the principal point (optical center)
+    cy = height / 2  # y-coordinate of the principal point (optical center)
+
+    camera_intrinsics = np.array([
+        [fx, 0, cx],
+        [0, fy, cy],
+        [0, 0, 1]
+    ])
+    return world_2_cam, camera_intrinsics
+
+
+def load_scene_data(
+        seq:str,
+        exp:str,
+        seg_as_col:bool=False
+    ) -> tuple[list[dict[str, torch.Tensor]], torch.Tensor]:
     params = dict(np.load(f"./output/{exp}/{seq}/params.npz"))
     params = {k: torch.tensor(v).cuda().float() for k, v in params.items()}
     is_fg = params['seg_colors'][:, 0] > 0.5
@@ -107,71 +145,102 @@ def calculate_rot_vec(scene_data, is_fg):
     return make_lineset(out_pts, cols, num_lines)
 
 
-def render(w2c, k, timestep_data):
+def render(
+        world_2_cam:np.ndarray,
+        intrinsics:np.ndarray,
+        timestep_data:dict[str, torch.Tensor]
+    ) -> tuple[torch.Tensor, torch.Tensor]:
     with torch.no_grad():
-        cam = setup_camera(w, h, k, w2c, near, far)
-        im, _, depth, = Renderer(raster_settings=cam)(**timestep_data)
-        return im, depth
-
-
-def rgbd2pcd(im, depth, w2c, k, show_depth=False, project_to_cam_w_scale=None):
-    d_near = 1.5
-    d_far = 6
-    invk = torch.inverse(torch.tensor(k).cuda().float())
-    c2w = torch.inverse(torch.tensor(w2c).cuda().float())
+        cam = setup_camera(width, height, intrinsics, world_2_cam, near, far)
+        image, _, depth, = Renderer(raster_settings=cam)(**timestep_data)
+        # [3, height, width], [1, height, width]
+        return image, depth
+
+
+def rgbd_2_pointcloud(
+        image:torch.Tensor,
+        depth:torch.Tensor,
+        world_2_cam:np.ndarray,
+        intrinsics:np.ndarray,
+        show_depth:bool = False,
+        project_to_cam_w_scale:float | None = None
+    ) -> tuple[Vector3dVector, Vector3dVector]:
+    depth_near:float = 1.5
+    depth_far:float = 6
+
+    # The inverse intrinsic matrix is useful when you want to back-project 2D points from the image plane to 3D rays in camera space.
+    inv_intrinsics = torch.inverse(torch.tensor(intrinsics).cuda().float())
+    # This matrix transforms points from the camera frame back to the world frame.
+    inv_world_2_cam = torch.inverse(torch.tensor(world_2_cam).cuda().float())
+
     radial_depth = depth[0].reshape(-1)
-    def_rays = (invk @ def_pix.T).T
+
+    # This operation back-projects 'def_pix' 2D coordinates into 3D camera coordinates,
+    # effectively producing rays (direction vectors) for each pixel that emanate from the camera's optical center.
+    def_rays = (inv_intrinsics @ def_pix.T).T
     def_radial_rays = def_rays / torch.linalg.norm(def_rays, ord=2, dim=-1)[:, None]
-    pts_cam = def_radial_rays * radial_depth[:, None]
-    z_depth = pts_cam[:, 2]
+
+    # Computes the actual 3D coordinates of the points in the scene in the camera's coordinate system using the depth information.
+    camera_space_points = def_radial_rays * radial_depth[:, None]
+    depth_values_along_optical_axis = camera_space_points[:, 2]
+
     if project_to_cam_w_scale is not None:
-        pts_cam = project_to_cam_w_scale * pts_cam / z_depth[:, None]
-    pts4 = torch.concat((pts_cam, pix_ones), 1)
-    pts = (c2w @ pts4.T).T[:, :3]
-    if show_depth:
-        cols = ((z_depth - d_near) / (d_far - d_near))[:, None].repeat(1, 3)
+        camera_space_points = project_to_cam_w_scale * camera_space_points / depth_values_along_optical_axis[:, None]
+
+    # Converts the 3D coordinates in camera_space_points into 4D homogeneous coordinates.
+    camera_space_points_homogeneous = torch.concat((camera_space_points, pix_ones), 1)
+    # Transforms the points from the camera's coordinate space to the world coordinate space.
+    world_space_points = (inv_world_2_cam @ camera_space_points_homogeneous.T).T[:, :3]
+
+    if show_depth:  # Color the points based on their depth values.
+        colors = ((depth_values_along_optical_axis - depth_near) / (depth_far - depth_near))[:, None].repeat(1, 3)
     else:
-        cols = torch.permute(im, (1, 2, 0)).reshape(-1, 3)
-    pts = o3d.utility.Vector3dVector(pts.contiguous().double().cpu().numpy())
-    cols = o3d.utility.Vector3dVector(cols.contiguous().double().cpu().numpy())
-    return pts, cols
+        colors = torch.permute(image, (1, 2, 0)).reshape(-1, 3)
+
+    world_space_points: Vector3dVector = o3d.utility.Vector3dVector(world_space_points.contiguous().double().cpu().numpy())
+    colors: Vector3dVector = o3d.utility.Vector3dVector(colors.contiguous().double().cpu().numpy())
+
+    return world_space_points, colors
 
 
-def visualize(seq, exp):
+def visualize(seq:str, exp:str):
     scene_data, is_fg = load_scene_data(seq, exp)
-    vis = o3d.visualization.Visualizer()
-    vis.create_window(width=int(w * view_scale), height=int(h * view_scale), visible=True)
+    vis = o3d.visualization.Visualizer()  #type: ignore
+    vis.create_window(width=int(width * view_scale), height=int(height * view_scale), visible=True)
 
-    w2c, k = init_camera()
-    im, depth = render(w2c, k, scene_data[0])
-    init_pts, init_cols = rgbd2pcd(im, depth, w2c, k, show_depth=(RENDER_MODE == 'depth'))
-    pcd = o3d.geometry.PointCloud()
-    pcd.points = init_pts
-    pcd.colors = init_cols
-    vis.add_geometry(pcd)
+    world_2_cam, intrinsics = init_camera(y_angle=60.)
+    image, depth = render(world_2_cam, intrinsics, scene_data[0])
+    init_points, init_colors = rgbd_2_pointcloud(image, depth, world_2_cam, intrinsics, show_depth=(RENDER_MODE == 'depth'))
+
+    # A collection of points in 3D space, each potentially with associated properties like color and normals.
+    point_cloud = o3d.geometry.PointCloud()
+    point_cloud.points = init_points
+    point_cloud.colors = init_colors
+    vis.add_geometry(point_cloud)
 
-    linesets = None
-    lines = None
+    lines: o3d.geometry.LineSet | None = None
+    linesets:list[o3d.geometry.LineSet] | None = None
     if ADDITIONAL_LINES is not None:
         if ADDITIONAL_LINES == 'trajectories':
            linesets = calculate_trajectories(scene_data, is_fg)
         elif ADDITIONAL_LINES == 'rotations':
            linesets = calculate_rot_vec(scene_data, is_fg)
+        else:
+            raise ValueError(f"Unsupported value for ADDITIONAL_LINES: {ADDITIONAL_LINES}")
         lines = o3d.geometry.LineSet()
-        lines.points = linesets[0].points
-        lines.colors = linesets[0].colors
-        lines.lines = linesets[0].lines
+        lines.points = linesets[0].points  #type: ignore
+        lines.colors = linesets[0].colors  #type: ignore
+        lines.lines = linesets[0].lines  #type: ignore
         vis.add_geometry(lines)
 
-    view_k = k * view_scale
-    view_k[2, 2] = 1
+    # Adjust the focal length and optical center according to the view_scale.
+    view_intrinsics = intrinsics * view_scale
+    view_intrinsics[2, 2] = 1  # don't scale depth
     view_control = vis.get_view_control()
     cparams = o3d.camera.PinholeCameraParameters()
-    cparams.extrinsic = w2c
-    cparams.intrinsic.intrinsic_matrix = view_k
-    cparams.intrinsic.height = int(h * view_scale)
-    cparams.intrinsic.width = int(w * view_scale)
+    cparams.extrinsic = world_2_cam
+    cparams.intrinsic.intrinsic_matrix = view_intrinsics
+    cparams.intrinsic.height = int(height * view_scale)
+    cparams.intrinsic.width = int(width * view_scale)
     view_control.convert_from_pinhole_camera_parameters(cparams, allow_arbitrary=True)
 
     render_options = vis.get_render_option()
@@ -184,39 +253,42 @@ def visualize(seq, exp):
         passed_time = time.time() - start_time
         passed_frames = passed_time * fps
         if ADDITIONAL_LINES == 'trajectories':
-            t = int(passed_frames % (num_timesteps - traj_length)) + traj_length  # Skip t that don't have full traj.
+            time_step = int(passed_frames % (num_timesteps - traj_length)) + traj_length  # Skip t that don't have full traj.
         else:
-            t = int(passed_frames % num_timesteps)
+            time_step = int(passed_frames % num_timesteps)
 
         if FORCE_LOOP:
             num_loops = 1.4
-            y_angle = 360*t*num_loops / num_timesteps
-            w2c, k = init_camera(y_angle)
+            y_angle = 360 * time_step * num_loops / num_timesteps
+            world_2_cam, intrinsics = init_camera(y_angle)
             cam_params = view_control.convert_to_pinhole_camera_parameters()
-            cam_params.extrinsic = w2c
+            cam_params.extrinsic = world_2_cam
             view_control.convert_from_pinhole_camera_parameters(cam_params, allow_arbitrary=True)
         else:  # Interactive control
             cam_params = view_control.convert_to_pinhole_camera_parameters()
-            view_k = cam_params.intrinsic.intrinsic_matrix
-            k = view_k / view_scale
-            k[2, 2] = 1
-            w2c = cam_params.extrinsic
+            view_intrinsics = cam_params.intrinsic.intrinsic_matrix
+            intrinsics = view_intrinsics / view_scale
+            intrinsics[2, 2] = 1
+            world_2_cam = cam_params.extrinsic
 
         if RENDER_MODE == 'centers':
-            pts = o3d.utility.Vector3dVector(scene_data[t]['means3D'].contiguous().double().cpu().numpy())
-            cols = o3d.utility.Vector3dVector(scene_data[t]['colors_precomp'].contiguous().double().cpu().numpy())
+            points = o3d.utility.Vector3dVector(scene_data[time_step]['means3D'].contiguous().double().cpu().numpy())
+            colors = o3d.utility.Vector3dVector(scene_data[time_step]['colors_precomp'].contiguous().double().cpu().numpy())
         else:
-            im, depth = render(w2c, k, scene_data[t])
-            pts, cols = rgbd2pcd(im, depth, w2c, k, show_depth=(RENDER_MODE == 'depth'))
-        pcd.points = pts
-        pcd.colors = cols
-        vis.update_geometry(pcd)
-
-        if ADDITIONAL_LINES is not None:
+            image, depth = render(world_2_cam, intrinsics, scene_data[time_step])
+            points, colors = rgbd_2_pointcloud(image, depth, world_2_cam, intrinsics, show_depth=(RENDER_MODE == 'depth'))
+
+        point_cloud.points = points
+        point_cloud.colors = colors
+        vis.update_geometry(point_cloud)
+
+        if ADDITIONAL_LINES is not None and \
+           lines is not None and \
+           linesets is not None:
             if ADDITIONAL_LINES == 'trajectories':
-                lt = t - traj_length
+                lt = time_step - traj_length
             else:
-                lt = t
+                lt = time_step
             lines.points = linesets[lt].points
             lines.colors = linesets[lt].colors
             lines.lines = linesets[lt].lines

From 062acee63b11e46f8ea177f8cf6e1f0f1cf937c9 Mon Sep 17 00:00:00 2001
From: Dario
Date: Tue, 17 Oct 2023 19:54:39 +0000
Subject: [PATCH 4/6] Removed bad comments

---
 helpers.py   | 2 +-
 visualize.py | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/helpers.py b/helpers.py
index a4fcad0c..028b2b34 100644
--- a/helpers.py
+++ b/helpers.py
@@ -31,7 +31,7 @@ def setup_camera(
     ]).cuda().float().unsqueeze(0).transpose(1, 2)
 
     # This will give a matrix that transforms from world coordinates
-    # directly to normalized device coordinates (NDC) used in graphics
+    # directly to normalized device coordinates (NDC)
     full_proj = world_2_cam_tensor.bmm(opengl_proj)
 
     cam = Camera(
diff --git a/visualize.py b/visualize.py
index c4330258..c4a8b4d3 100644
--- a/visualize.py
+++ b/visualize.py
@@ -34,8 +34,7 @@ traj_length = 15
 
 # This tensor represents a set of homogeneous 2D coordinates for each pixel in the image.
-# The last dimension (a constant value of 1) makes these coordinates homogeneous,
-# which is useful for matrix multiplications in projective geometry.
+# The last dimension (a constant value of 1) makes these coordinates homogeneous.
 def_pix = torch.tensor(
     np.stack(
         np.meshgrid(np.arange(width) + 0.5,

From c4cdcf82d12cf8f086347c544f804e5b91f6ea01 Mon Sep 17 00:00:00 2001
From: Dario
Date: Fri, 20 Oct 2023 08:21:55 +0000
Subject: [PATCH 5/6] Small comment changes and type hints

---
 helpers.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/helpers.py b/helpers.py
index 028b2b34..1be96aa9 100644
--- a/helpers.py
+++ b/helpers.py
@@ -10,8 +10,8 @@ def setup_camera(
         height:int,
         intrinsics:np.ndarray,
         world_2_cam:np.ndarray,
-        near:float=0.01,  # Near and far clipping planes for depth in the camera's view frustum.
-        far:float=100  # meters?
+        near:float,  # Near and far clipping planes for depth in the camera's view frustum. (in meters?)
+        far:float
     ) -> Camera:
     # Focal lengths (fx, fy, in pixels) and optical center (cx, cy)
     fx, fy, cx, cy = intrinsics[0][0], intrinsics[1][1], intrinsics[0][2], intrinsics[1][2]
@@ -50,7 +50,7 @@ def setup_camera(
     return cam
 
 
-def params2rendervar(params):
+def params2rendervar(params:dict) -> dict:
     rendervar = {
         'means3D': params['means3D'],
         'colors_precomp': params['rgb_colors'],

From 7cb50d27c1c01ff570304d3b9ee887a96b2dc24e Mon Sep 17 00:00:00 2001
From: Dario
Date: Fri, 20 Oct 2023 08:22:10 +0000
Subject: [PATCH 6/6] Refactor loss computation

---
 loss.py  |  99 ++++++++++++++++++++++
 train.py | 255 ++++++++++++++++++++++++++++---------------------------
 2 files changed, 229 insertions(+), 125 deletions(-)
 create mode 100644 loss.py

diff --git a/loss.py b/loss.py
new file mode 100644
index 00000000..0ac5c2d4
--- /dev/null
+++ b/loss.py
@@ -0,0 +1,99 @@
+import torch
+
+from helpers import l1_loss_v1
+from helpers import weighted_l2_loss_v1
+from helpers import weighted_l2_loss_v2
+from helpers import l1_loss_v2
+from helpers import quat_mult
+from helpers import params2rendervar
+
+from external import calc_ssim
+from external import build_rotation
+
+from diff_gaussian_rasterization import GaussianRasterizer as Renderer
+
+L1_LOSS_WEIGHT:float = 0.8
+SSIM_LOSS_WEIGHT:float = 0.2
+LOSS_WEIGHTS = {
+    'im': 1.0,
+    'seg': 3.0,
+    'rigid': 4.0,
+    'rot': 4.0,
+    'iso': 2.0,
+    'floor': 2.0,
+    'bg': 20.0,
+    'soft_col_cons': 0.01
+}
+
+def apply_camera_parameters(image: torch.Tensor, params: dict, curr_data: dict) -> torch.Tensor:
+    curr_id = curr_data['id']
+    return torch.exp(params['cam_m'][curr_id])[:, None, None] * image + params['cam_c'][curr_id][:, None, None]
+
+def compute_loss(rendered: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    l1 = l1_loss_v1(rendered, target)
+    ssim = 1.0 - calc_ssim(rendered, target)
+    return L1_LOSS_WEIGHT * l1 + SSIM_LOSS_WEIGHT * ssim
+
+def compute_rigid_loss(fg_pts, rot, variables):
+    neighbor_pts = fg_pts[variables["neighbor_indices"]]
+    curr_offset = neighbor_pts - fg_pts[:, None]
+    curr_offset_in_prev_coord = (rot.transpose(2, 1)[:, None] @ curr_offset[:, :, :, None]).squeeze(-1)
+    return weighted_l2_loss_v2(curr_offset_in_prev_coord, variables["prev_offset"], variables["neighbor_weight"])
+
+def compute_rot_loss(rel_rot, variables):
+    return weighted_l2_loss_v2(rel_rot[variables["neighbor_indices"]], rel_rot[:, None], variables["neighbor_weight"])
+
+def compute_iso_loss(fg_pts, variables):
+    neighbor_pts = fg_pts[variables["neighbor_indices"]]
+    curr_offset = neighbor_pts - fg_pts[:, None]
+    curr_offset_mag = torch.sqrt((curr_offset ** 2).sum(-1) + 1e-20)
+    return weighted_l2_loss_v1(curr_offset_mag, variables["neighbor_dist"], variables["neighbor_weight"])
+
+def compute_floor_loss(fg_pts):
+    return torch.clamp(fg_pts[:, 1], min=0).mean()
+
+def compute_bg_loss(bg_pts, bg_rot, variables):
+    return l1_loss_v2(bg_pts, variables["init_bg_pts"]) + l1_loss_v2(bg_rot, variables["init_bg_rot"])
+
+
+def get_loss(params:dict, curr_data:dict, variables:dict, is_initial_timestep:bool):
+
+    losses = {}
+
+    # Image
+    rendervar = params2rendervar(params)
+    rendervar['means2D'].retain_grad()
+    image, radius, _, = Renderer(raster_settings=curr_data['cam'])(**rendervar)
+    image = apply_camera_parameters(image, params, curr_data)
+
+    # Segmentation
+    variables['means2D'] = rendervar['means2D']  # Gradient only accum from colour render for densification
+    segrendervar = params2rendervar(params)
+    segrendervar['colors_precomp'] = params['seg_colors']
+    seg, _, _, = Renderer(raster_settings=curr_data['cam'])(**segrendervar)
+
+    losses['im'] = compute_loss(image, curr_data['im'])
+    losses['seg'] = compute_loss(seg, curr_data['seg'])
+
+    if not is_initial_timestep:
+        is_fg = (params['seg_colors'][:, 0] > 0.5).detach()
+        fg_pts = rendervar['means3D'][is_fg]
+        fg_rot = rendervar['rotations'][is_fg]
+        rel_rot = quat_mult(fg_rot, variables["prev_inv_rot_fg"])
+        rot = build_rotation(rel_rot)
+
+        losses['rigid'] = compute_rigid_loss(fg_pts, rot, variables)
+        losses['rot'] = compute_rot_loss(rel_rot, variables)
+        losses['iso'] = compute_iso_loss(fg_pts, variables)
+        losses['floor'] = compute_floor_loss(fg_pts)
+
+        bg_pts = rendervar['means3D'][~is_fg]
+        bg_rot = rendervar['rotations'][~is_fg]
+        losses['bg'] = compute_bg_loss(bg_pts, bg_rot, variables)
+        losses['soft_col_cons'] = l1_loss_v2(params['rgb_colors'], variables["prev_col"])
+
+    loss = sum([LOSS_WEIGHTS[k] * v for k, v in losses.items()])
+    seen = radius > 0
+    variables['max_2D_radius'][seen] = torch.max(radius[seen], variables['max_2D_radius'][seen])
+    variables['seen'] = seen
+    return loss, variables
\ No newline at end of file
diff --git a/train.py b/train.py
index d4b0ed94..5520c8d6 100644
--- a/train.py
+++ b/train.py
@@ -1,59 +1,90 @@
-import torch
 import os
 import json
+import random
 import copy
+
+import torch
 import numpy as np
 from PIL import Image
-from random import randint
 from tqdm import tqdm
+
 from diff_gaussian_rasterization import GaussianRasterizer as Renderer
-from helpers import setup_camera, l1_loss_v1, l1_loss_v2, weighted_l2_loss_v1, weighted_l2_loss_v2, quat_mult, \
-    o3d_knn, params2rendervar, params2cpu, save_params
-from external import calc_ssim, calc_psnr, build_rotation, densify, update_params_and_optimizer
-
-
-def get_dataset(t, md, seq):
-    dataset = []
-    for c in range(len(md['fn'][t])):
-        w, h, k, w2c = md['w'], md['h'], md['k'][t][c], md['w2c'][t][c]
-        cam = setup_camera(w, h, k, w2c, near=1.0, far=100)
-        fn = md['fn'][t][c]
-        im = np.array(copy.deepcopy(Image.open(f"./data/{seq}/ims/{fn}")))
-        im = torch.tensor(im).float().cuda().permute(2, 0, 1) / 255
-        seg = np.array(copy.deepcopy(Image.open(f"./data/{seq}/seg/{fn.replace('.jpg', '.png')}"))).astype(np.float32)
-        seg = torch.tensor(seg).float().cuda()
-        seg_col = torch.stack((seg, torch.zeros_like(seg), 1 - seg))
-        dataset.append({'cam': cam, 'im': im, 'seg': seg_col, 'id': c})
-    return dataset
-
-
-def get_batch(todo_dataset, dataset):
-    if not todo_dataset:
-        todo_dataset = dataset.copy()
-    curr_data = todo_dataset.pop(randint(0, len(todo_dataset) - 1))
-    return curr_data
-
-
-def initialize_params(seq, md):
-    init_pt_cld = np.load(f"./data/{seq}/init_pt_cld.npz")["data"]
-    seg = init_pt_cld[:, 6]
-    max_cams = 50
-    sq_dist, _ = o3d_knn(init_pt_cld[:, :3], 3)
-    mean3_sq_dist = sq_dist.mean(-1).clip(min=0.0000001)
+
+from helpers import setup_camera
+from helpers import o3d_knn
+from helpers import params2rendervar
+from helpers import params2cpu
+from helpers import save_params
+
+from external import calc_psnr
+from external import densify
+from external import update_params_and_optimizer
+
+from loss import get_loss
+
+MAX_CAMS:int = 50
+NUM_NEAREST_NEIGH:int = 3
+SCENE_SIZE_MULT:float = 1.1
+
+# Camera
+NEAR:float = 1.0
+FAR:float = 100.
+
+# Training Hyperparams
+INITIAL_TIMESTEP_ITERATIONS = 10_000
+TIMESTEP_ITERATIONS = 2_000
+
+
+def construct_timestep_dataset(timestep:int, metadata:dict, sequence:str) -> list[dict]:
+    dataset_entries = []
+    for camera_id in range(len(metadata['fn'][timestep])):
+        width, height, intrinsics, extrinsics = metadata['w'], metadata['h'], metadata['k'][timestep][camera_id], metadata['w2c'][timestep][camera_id]
+        camera = setup_camera(width, height, intrinsics, extrinsics, near=NEAR, far=FAR)
+
+        filename = metadata['fn'][timestep][camera_id]
+
+        image = np.array(copy.deepcopy(Image.open(f"./data/{sequence}/ims/{filename}")))
+        image_tensor = torch.tensor(image).float().cuda().permute(2, 0, 1) / 255.
+
+        segmentation = np.array(copy.deepcopy(Image.open(f"./data/{sequence}/seg/{filename.replace('.jpg', '.png')}"))).astype(np.float32)
+        segmentation_tensor = torch.tensor(segmentation).float().cuda()
+        segmentation_color = torch.stack((segmentation_tensor, torch.zeros_like(segmentation_tensor), 1 - segmentation_tensor))
+
+        dataset_entries.append({'cam': camera, 'im': image_tensor, 'seg': segmentation_color, 'id': camera_id})
+
+    return dataset_entries
+
+
+def initialize_batch_sampler(dataset:list[dict]) -> list[int]:
+    indices = list(range(len(dataset)))
+    random.shuffle(indices)
+    return indices
+
+
+def get_data_point(batch_sampler:list[int], dataset:list[dict]) -> dict:
+    if len(batch_sampler) < 1: batch_sampler[:] = initialize_batch_sampler(dataset)  # Refill in place so the caller's list is replenished too.
+    return dataset[batch_sampler.pop()]
+
+
+def initialize_params(sequence:str, metadata:dict) -> tuple[dict, dict]:
+    init_pt_cld:np.ndarray = np.load(f"./data/{sequence}/init_pt_cld.npz")["data"]
+    segmentation = init_pt_cld[:, 6]
+    square_distance, _ = o3d_knn(init_pt_cld[:, :3], NUM_NEAREST_NEIGH)
+    mean_square_distance = square_distance.mean(-1).clip(min=1e-7)
     params = {
         'means3D': init_pt_cld[:, :3],
         'rgb_colors': init_pt_cld[:, 3:6],
-        'seg_colors': np.stack((seg, np.zeros_like(seg), 1 - seg), -1),
-        'unnorm_rotations': np.tile([1, 0, 0, 0], (seg.shape[0], 1)),
-        'logit_opacities': np.zeros((seg.shape[0], 1)),
-        'log_scales': np.tile(np.log(np.sqrt(mean3_sq_dist))[..., None], (1, 3)),
-        'cam_m': np.zeros((max_cams, 3)),
-        'cam_c': np.zeros((max_cams, 3)),
+        'seg_colors': np.stack((segmentation, np.zeros_like(segmentation), 1 - segmentation), -1),
+        'unnorm_rotations': np.tile([1, 0, 0, 0], (segmentation.shape[0], 1)),
+        'logit_opacities': np.zeros((segmentation.shape[0], 1)),
+        'log_scales': np.tile(np.log(np.sqrt(mean_square_distance))[..., None], (1, 3)),
+        'cam_m': np.zeros((MAX_CAMS, 3)),
+        'cam_c': np.zeros((MAX_CAMS, 3)),
     }
     params = {k: torch.nn.Parameter(torch.tensor(v).cuda().float().contiguous().requires_grad_(True)) for k, v in params.items()}
-    cam_centers = np.linalg.inv(md['w2c'][0])[:, :3, 3]  # Get scene radius
-    scene_radius = 1.1 * np.max(np.linalg.norm(cam_centers - np.mean(cam_centers, 0)[None], axis=-1))
+    cam_centers = np.linalg.inv(metadata['w2c'][0])[:, :3, 3]  # Get scene radius
+    scene_radius = SCENE_SIZE_MULT * np.max(np.linalg.norm(cam_centers - np.mean(cam_centers, 0)[None], axis=-1))
     variables = {'max_2D_radius': torch.zeros(params['means3D'].shape[0]).cuda().float(),
                  'scene_radius': scene_radius,
                  'means2D_gradient_accum': torch.zeros(params['means3D'].shape[0]).cuda().float(),
@@ -61,7 +92,7 @@
     return params, variables
 
 
-def initialize_optimizer(params, variables):
+def initialize_optimizer(params:dict, variables:dict):
     lrs = {
         'means3D': 0.00016 * variables['scene_radius'],
         'rgb_colors': 0.0025,
@@ -76,81 +107,42 @@
     return torch.optim.Adam(param_groups, lr=0.0, eps=1e-15)
 
 
-def get_loss(params, curr_data, variables, is_initial_timestep):
-    losses = {}
-
-    rendervar = params2rendervar(params)
-    rendervar['means2D'].retain_grad()
-    im, radius, _, = Renderer(raster_settings=curr_data['cam'])(**rendervar)
-    curr_id = curr_data['id']
-    im = torch.exp(params['cam_m'][curr_id])[:, None, None] * im + params['cam_c'][curr_id][:, None, None]
-    losses['im'] = 0.8 * l1_loss_v1(im, curr_data['im']) + 0.2 * (1.0 - calc_ssim(im, curr_data['im']))
-    variables['means2D'] = rendervar['means2D']  # Gradient only accum from colour render for densification
-
-    segrendervar = params2rendervar(params)
-    segrendervar['colors_precomp'] = params['seg_colors']
-    seg, _, _, = Renderer(raster_settings=curr_data['cam'])(**segrendervar)
-    losses['seg'] = 0.8 * l1_loss_v1(seg, curr_data['seg']) + 0.2 * (1.0 - calc_ssim(seg, curr_data['seg']))
-
-    if not is_initial_timestep:
-        is_fg = (params['seg_colors'][:, 0] > 0.5).detach()
-        fg_pts = rendervar['means3D'][is_fg]
-        fg_rot = rendervar['rotations'][is_fg]
-
-        rel_rot = quat_mult(fg_rot, variables["prev_inv_rot_fg"])
-        rot = build_rotation(rel_rot)
-        neighbor_pts = fg_pts[variables["neighbor_indices"]]
-        curr_offset = neighbor_pts - fg_pts[:, None]
-        curr_offset_in_prev_coord = (rot.transpose(2, 1)[:, None] @ curr_offset[:, :, :, None]).squeeze(-1)
-        losses['rigid'] = weighted_l2_loss_v2(curr_offset_in_prev_coord, variables["prev_offset"],
-                                              variables["neighbor_weight"])
-        losses['rot'] = weighted_l2_loss_v2(rel_rot[variables["neighbor_indices"]], rel_rot[:, None],
-                                            variables["neighbor_weight"])
-
-        curr_offset_mag = torch.sqrt((curr_offset ** 2).sum(-1) + 1e-20)
-        losses['iso'] = weighted_l2_loss_v1(curr_offset_mag, variables["neighbor_dist"], variables["neighbor_weight"])
-
-        losses['floor'] = torch.clamp(fg_pts[:, 1], min=0).mean()
-
-        bg_pts = rendervar['means3D'][~is_fg]
-        bg_rot = rendervar['rotations'][~is_fg]
-        losses['bg'] = l1_loss_v2(bg_pts, variables["init_bg_pts"]) + l1_loss_v2(bg_rot, variables["init_bg_rot"])
-
-        losses['soft_col_cons'] = l1_loss_v2(params['rgb_colors'], variables["prev_col"])
-
-    loss_weights = {'im': 1.0, 'seg': 3.0, 'rigid': 4.0, 'rot': 4.0, 'iso': 2.0, 'floor': 2.0, 'bg': 20.0,
-                    'soft_col_cons': 0.01}
-    loss = sum([loss_weights[k] * v for k, v in losses.items()])
-    seen = radius > 0
-    variables['max_2D_radius'][seen] = torch.max(radius[seen], variables['max_2D_radius'][seen])
-    variables['seen'] = seen
-    return loss, variables
-
-
-def initialize_per_timestep(params, variables, optimizer):
-    pts = params['means3D']
-    rot = torch.nn.functional.normalize(params['unnorm_rotations'])
-    new_pts = pts + (pts - variables["prev_pts"])
-    new_rot = torch.nn.functional.normalize(rot + (rot - variables["prev_rot"]))
-
-    is_fg = params['seg_colors'][:, 0] > 0.5
-    prev_inv_rot_fg = rot[is_fg]
-    prev_inv_rot_fg[:, 1:] = -1 * prev_inv_rot_fg[:, 1:]
-    fg_pts = pts[is_fg]
-    prev_offset = fg_pts[variables["neighbor_indices"]] - fg_pts[:, None]
-    variables['prev_inv_rot_fg'] = prev_inv_rot_fg.detach()
-    variables['prev_offset'] = prev_offset.detach()
+def initialize_per_timestep(
+    params: dict[str, torch.Tensor],
+    variables: dict[str, torch.Tensor],
+    optimizer: torch.optim.Optimizer
+) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+
+    current_points = params['means3D']
+    current_rotations_normalized = torch.nn.functional.normalize(params['unnorm_rotations'])
+
+    # Calculate momentum-like updates
+    new_points = current_points + (current_points - variables["prev_pts"])
+    new_rotations = torch.nn.functional.normalize(current_rotations_normalized + (current_rotations_normalized - variables["prev_rot"]))
+
+    # Extract foreground entities' info
+    foreground_mask = params['seg_colors'][:, 0] > 0.5
+    previous_inverse_rotations_foreground = current_rotations_normalized[foreground_mask]
+    previous_inverse_rotations_foreground[:, 1:] = -1 * previous_inverse_rotations_foreground[:, 1:]
+    foreground_points = current_points[foreground_mask]
+    previous_offsets = foreground_points[variables["neighbor_indices"]] - foreground_points[:, None]
+
+    # Update previous values in the variables dictionary
+    variables['prev_inv_rot_fg'] = previous_inverse_rotations_foreground.detach()
+    variables['prev_offset'] = previous_offsets.detach()
     variables["prev_col"] = params['rgb_colors'].detach()
-    variables["prev_pts"] = pts.detach()
-    variables["prev_rot"] = rot.detach()
+    variables["prev_pts"] = current_points.detach()
+    variables["prev_rot"] = current_rotations_normalized.detach()
 
-    new_params = {'means3D': new_pts, 'unnorm_rotations': new_rot}
-    params = update_params_and_optimizer(new_params, params, optimizer)
+    # Update the params dictionary
+    updated_params = {'means3D': new_points, 'unnorm_rotations': new_rotations}
+    params = update_params_and_optimizer(updated_params, params, optimizer)
     return params, variables
 
+
 def initialize_post_first_timestep(params, variables, optimizer, num_knn=20):
     is_fg = params['seg_colors'][:, 0] > 0.5
     init_fg_pts = params['means3D'][is_fg]
@@ -184,39 +176,55 @@ def report_progress(params, data, i, progress_bar, every_i=100):
     progress_bar.update(every_i)
 
 
-def train(seq, exp):
-    if os.path.exists(f"./output/{exp}/{seq}"):
-        print(f"Experiment '{exp}' for sequence '{seq}' already exists. Exiting.")
+def train(sequence:str, exp_name:str):
+    if os.path.exists(f"./output/{exp_name}/{sequence}"):
+        print(f"Experiment '{exp_name}' for sequence '{sequence}' already exists. Exiting.")
         return
-    md = json.load(open(f"./data/{seq}/train_meta.json", 'r'))  # metadata
-    num_timesteps = len(md['fn'])
-    params, variables = initialize_params(seq, md)
+
+    metadata = json.load(open(f"./data/{sequence}/train_meta.json", 'r'))
+    num_timesteps = len(metadata['fn'])
+
+    params, variables = initialize_params(sequence, metadata)
     optimizer = initialize_optimizer(params, variables)
     output_params = []
-    for t in range(num_timesteps):
-        dataset = get_dataset(t, md, seq)
-        todo_dataset = []
-        is_initial_timestep = (t == 0)
+
+    for timestep in range(num_timesteps):
+        dataset = construct_timestep_dataset(timestep, metadata, sequence)
+        batch_sampler = initialize_batch_sampler(dataset)
+        is_initial_timestep = (timestep == 0)
         if not is_initial_timestep:
+            # "momentum-based update"
             params, variables = initialize_per_timestep(params, variables, optimizer)
-        num_iter_per_timestep = 10000 if is_initial_timestep else 2000
-        progress_bar = tqdm(range(num_iter_per_timestep), desc=f"timestep {t}")
+
+        num_iter_per_timestep = INITIAL_TIMESTEP_ITERATIONS if is_initial_timestep else TIMESTEP_ITERATIONS
+        progress_bar = tqdm(range(num_iter_per_timestep), desc=f"timestep {timestep}")
+
         for i in range(num_iter_per_timestep):
-            curr_data = get_batch(todo_dataset, dataset)
+            curr_data = get_data_point(batch_sampler, dataset)
             loss, variables = get_loss(params, curr_data, variables, is_initial_timestep)
             loss.backward()
+
+            optimizer.step()
+            optimizer.zero_grad(set_to_none=True)
+
             with torch.no_grad():
                 report_progress(params, dataset[0], i, progress_bar)
                 if is_initial_timestep:
                     params, variables = densify(params, variables, optimizer, i)
-                optimizer.step()
-                optimizer.zero_grad(set_to_none=True)
+
         progress_bar.close()
         output_params.append(params2cpu(params, is_initial_timestep))
         if is_initial_timestep:
            variables = initialize_post_first_timestep(params, variables, optimizer)
-    save_params(output_params, seq, exp)
+
+    save_params(output_params, sequence, exp_name)
 
 
-if __name__ == "__main__":
+def main():
     exp_name = "exp1"
     for sequence in ["basketball", "boxes", "football", "juggle", "softball", "tennis"]:
         train(sequence, exp_name)
         torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
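
Two of the ideas in this series are easy to sanity-check in isolation.

First, the back-projection in rgbd_2_pointcloud inverts the usual pinhole projection. Below is a minimal numpy sketch (not part of the patches; K, point_cam and pixel are illustrative names): projecting a camera-space point with the intrinsics, then back-projecting the resulting pixel with the inverse intrinsics and the point's depth, recovers the original point.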
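import numpy as np

# Hypothetical intrinsics built the same way init_camera builds them:
# focal length f_ratio * width, optical center at the image center.
width, height, f_ratio = 640, 360, 0.82
K = np.array([[f_ratio * width, 0.0, width / 2],
              [0.0, f_ratio * width, height / 2],
              [0.0, 0.0, 1.0]])

point_cam = np.array([0.3, -0.2, 2.0])  # a camera-space point with z > 0

# Project: homogeneous pixel coordinates (u, v, 1) = K @ p / z.
pixel = K @ point_cam / point_cam[2]

# Back-project: a ray through the pixel, scaled by the z-depth.
ray = np.linalg.inv(K) @ pixel  # (x/z, y/z, 1)
recovered = ray * point_cam[2]

assert np.allclose(recovered, point_cam)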
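The round trip above scales by z-depth for brevity; rgbd_2_pointcloud instead treats the rasterizer's depth output as radial (hence radial_depth), so it normalizes each ray to unit length before scaling, and the appended pix_ones column lets a single 4x4 matrix multiply carry the homogeneous points into world space.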
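Second, the momentum-like update in initialize_per_timestep is plain constant-velocity extrapolation: each point (and each quaternion, which is then re-normalized) is advanced by its last frame-to-frame delta. A toy sketch with made-up positions: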
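import torch

prev_pts = torch.tensor([[0.0, 0.0, 0.0]])  # position at t-1 (made up)
pts = torch.tensor([[0.1, 0.0, 0.0]])       # position at t: moved +0.1 along x

# Constant-velocity guess for t+1: x + (x - x_prev), as in initialize_per_timestep.
new_pts = pts + (pts - prev_pts)  # tensor([[0.2000, 0.0000, 0.0000]])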