Commit

Added EulerFlow to readme
kylevedder committed Nov 13, 2024
1 parent 15ffa1c commit b021ad5
Showing 4 changed files with 48 additions and 47 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -17,6 +17,7 @@ The Zoo supports the following methods:
- [Neural Scene Flow Prior (NSFP)](https://arxiv.org/abs/2111.01253)
- [Fast NSF](https://arxiv.org/abs/2304.09121)
- [Liu et al. 2024](https://arxiv.org/abs/2403.16116)
- [EulerFlow](https://vedder.io/eulerflow)


If you use this codebase, please cite the following paper:
4 changes: 2 additions & 2 deletions models/components/neural_reps/__init__.py
@@ -1,6 +1,6 @@
from .nsfp_raw_mlp import NSFPRawMLP, ActivationFn
from .eulerflow_raw_mlp import (
EulerFlowFlowMLP,
EulerFlowMLP,
EulerFlowOccFlowMLP,
QueryDirection,
ModelFlowResult,
@@ -14,7 +14,7 @@
__all__ = [
"NSFPRawMLP",
"ActivationFn",
"EulerFlowFlowMLP",
"EulerFlowMLP",
"EulerFlowOccFlowMLP",
"Liu2024FusionRawMLP",
"fNT",
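With the export renamed, downstream code now imports the MLP under EulerFlowMLP. A minimal import sketch against the module path shown in this diff, assuming the repository root is on the Python path:

# Sketch only: pulls the renamed class (formerly EulerFlowFlowMLP) from the public exports above.
from models.components.neural_reps import EulerFlowMLP, ActivationFn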
2 changes: 1 addition & 1 deletion models/components/neural_reps/eulerflow_raw_mlp.py
@@ -113,7 +113,7 @@ def __len__(self):
return super().__len__() + self.n_freq


class EulerFlowFlowMLP(NSFPRawMLP):
class EulerFlowMLP(NSFPRawMLP):

def __init__(
self,
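EulerFlowMLP subclasses NSFPRawMLP, so it is constructed like the other neural reps; the optimization-loop variants later in this diff pass act_fn, encoder, and num_layers to it. A short usage sketch using only keyword arguments that appear elsewhere in this commit (all other constructor arguments assumed to keep their defaults):

# Sketch: act_fn is one of the knobs exercised by the eulerflow.py variants below;
# num_layers and encoder are varied there in the same way.
from models.components.neural_reps import EulerFlowMLP, ActivationFn

flow_field = EulerFlowMLP(act_fn=ActivationFn.SINC)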
88 changes: 44 additions & 44 deletions models/mini_batch_optimization/eulerflow.py
@@ -1,7 +1,7 @@
from pytorch_lightning.loggers.logger import Logger
from .mini_batch_optim_loop import MiniBatchOptimizationLoop, MinibatchedSceneFlowInputSequence
from models.components.neural_reps import (
EulerFlowFlowMLP,
EulerFlowMLP,
QueryDirection,
ModelFlowResult,
ActivationFn,
@@ -33,20 +33,20 @@
import numpy as np


class PointCloudLossType(enum.Enum):
class PointCloudLoss(enum.Enum):
TRUNCATED_CHAMFER_FORWARD = "truncated_chamfer"
TRUNCATED_CHAMFER_FORWARD_BACKWARD = "truncated_chamfer_forward_backward"
TRUNCATED_KD_TREE_FORWARD = "truncated_kd_tree_forward"
TRUNCATED_KD_TREE_FORWARD_BACKWARD = "truncated_kd_tree_forward_backward"


class PointCloudTargetType(enum.Enum):
class PointCloudTarget(enum.Enum):
LIDAR = "lidar"
LIDAR_CAMERA = "lidar_camera"


@dataclass
class EulerFlowPreprocessedInput:
class EulerFlowTorchedInput:
full_global_pcs: list[torch.Tensor]
full_global_pcs_mask: list[torch.Tensor]
full_global_auxillary_pcs: list[torch.Tensor | None]
@@ -109,17 +109,17 @@ def __init__(
self,
full_input_sequence: TorchFullFrameInputSequence,
speed_threshold: float,
pc_target_type: PointCloudTargetType | str,
pc_loss_type: PointCloudLossType | str,
pc_target_type: PointCloudTarget | str,
pc_loss_type: PointCloudLoss | str,
enable_k_step_loss: bool = True,
enable_cycle_consistency: bool = True,
model: torch.nn.Module = EulerFlowFlowMLP(),
model: torch.nn.Module = EulerFlowMLP(),
) -> None:
super().__init__(full_input_sequence)
self.model = model
self.speed_threshold = speed_threshold
self.pc_target_type = PointCloudTargetType(pc_target_type)
self.pc_loss_type = PointCloudLossType(pc_loss_type)
self.pc_target_type = PointCloudTarget(pc_target_type)
self.pc_loss_type = PointCloudLoss(pc_loss_type)

self._prep_neural_prior(self.model)
self.enable_k_step_loss = enable_k_step_loss
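Note that the constructor re-wraps whatever it is given with PointCloudTarget(...) and PointCloudLoss(...), so callers may pass either an enum member or its string value. A small sketch of that coercion, using member values defined earlier in this file:

# Standard by-value Enum lookup: the string form resolves to the same member object.
assert PointCloudLoss("truncated_chamfer") is PointCloudLoss.TRUNCATED_CHAMFER_FORWARD
assert PointCloudTarget("lidar_camera") is PointCloudTarget.LIDAR_CAMERA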
@@ -132,9 +132,9 @@ def _prep_kdtrees(self) -> list[KDTreeWrapper]:
kd_trees = []
for idx in tqdm.tqdm(range(len(full_rep)), desc="Building KD Trees"):
match self.pc_target_type:
case PointCloudTargetType.LIDAR:
case PointCloudTarget.LIDAR:
target_pc = full_rep.get_global_lidar_pc(idx, with_grad=False)
case PointCloudTargetType.LIDAR_CAMERA:
case PointCloudTarget.LIDAR_CAMERA:
target_pc = full_rep.get_global_lidar_auxillary_pc(idx, with_grad=False)
case _:
raise ValueError(f"Unknown point cloud target type: {self.pc_target_type}")
@@ -159,7 +159,7 @@ def init_weights(m):

def _preprocess(
self, input_sequence: TorchFullFrameInputSequence
) -> EulerFlowPreprocessedInput:
) -> EulerFlowTorchedInput:

full_global_pcs: list[torch.Tensor] = []
full_global_pcs_mask: list[torch.Tensor] = []
@@ -191,7 +191,7 @@ def _preprocess(
sequence_idxes = list(range(len(input_sequence)))
sequence_total_length = len(input_sequence)

return EulerFlowPreprocessedInput(
return EulerFlowTorchedInput(
full_global_pcs=full_global_pcs,
full_global_pcs_mask=full_global_pcs_mask,
full_global_auxillary_pcs=full_global_camera_pcs,
@@ ... @@
def _is_occupied_cost(self, model_res: ModelFlowResult) -> list[BaseCostProblem]:
return []

def _cycle_consistency(self, rep: EulerFlowPreprocessedInput) -> BaseCostProblem:
def _cycle_consistency(self, rep: EulerFlowTorchedInput) -> BaseCostProblem:
cost_problems: list[BaseCostProblem] = []
for idx in range(len(rep) - 1):
pc = rep.get_global_lidar_pc(idx)
@@ -230,20 +230,20 @@ def _cycle_consistency(self, rep: EulerFlowPreprocessedInput
cost_problems.extend(self._is_occupied_cost(model_res_reverse))
return AdditiveCosts(cost_problems)

def _get_kd_tree(self, rep: EulerFlowPreprocessedInput, rep_idx: int) -> KDTreeWrapper:
def _get_kd_tree(self, rep: EulerFlowTorchedInput, rep_idx: int) -> KDTreeWrapper:
global_idx = rep.sequence_idxes[rep_idx]
if self.kd_trees is None:
self.kd_trees = self._prep_kdtrees()
return self.kd_trees[global_idx]

def _process_k_step_subk(
self,
rep: EulerFlowPreprocessedInput,
rep: EulerFlowTorchedInput,
anchor_pc: torch.Tensor,
anchor_idx: int,
subk: int,
query_direction: QueryDirection,
loss_type: PointCloudLossType,
loss_type: PointCloudLoss,
speed_limit: float | None,
) -> tuple[BaseCostProblem, torch.Tensor]:
sequence_idx = rep.sequence_idxes[anchor_idx]
@@ ... @@

def _get_target_pc() -> torch.Tensor:
match self.pc_target_type:
case PointCloudTargetType.LIDAR:
case PointCloudTarget.LIDAR:
target_pc = rep.get_global_lidar_pc(anchor_idx + index_offset)
case PointCloudTargetType.LIDAR_CAMERA:
case PointCloudTarget.LIDAR_CAMERA:
target_pc = rep.get_global_lidar_auxillary_pc(anchor_idx + index_offset)
return target_pc

match loss_type:
case PointCloudLossType.TRUNCATED_CHAMFER_FORWARD:
case PointCloudLoss.TRUNCATED_CHAMFER_FORWARD:
problem: BaseCostProblem = TruncatedChamferLossProblem(
warped_pc=anchor_pc,
target_pc=_get_target_pc(),
distance_type=ChamferDistanceType.FORWARD_ONLY,
)
case PointCloudLossType.TRUNCATED_CHAMFER_FORWARD_BACKWARD:
case PointCloudLoss.TRUNCATED_CHAMFER_FORWARD_BACKWARD:
problem = TruncatedChamferLossProblem(
warped_pc=anchor_pc,
target_pc=_get_target_pc(),
distance_type=ChamferDistanceType.BOTH_DIRECTION,
)
case PointCloudLossType.TRUNCATED_KD_TREE_FORWARD:
case PointCloudLoss.TRUNCATED_KD_TREE_FORWARD:
problem = TruncatedForwardKDTreeLossProblem(
warped_pc=anchor_pc, kdtree=self._get_kd_tree(rep, anchor_idx + index_offset)
)
case PointCloudLossType.TRUNCATED_KD_TREE_FORWARD_BACKWARD:
case PointCloudLoss.TRUNCATED_KD_TREE_FORWARD_BACKWARD:
problem = TruncatedForwardBackwardKDTreeLossProblem(
warped_pc=anchor_pc,
target_pc=_get_target_pc(),
@@ ... @@

def _k_step_cost(
self,
rep: EulerFlowPreprocessedInput,
rep: EulerFlowTorchedInput,
k: int,
query_direction: QueryDirection,
loss_type: PointCloudLossType,
loss_type: PointCloudLoss,
speed_limit: float | None = None,
) -> BaseCostProblem:
assert k >= 1, f"Expected k >= 1, but got {k}"
@@ -365,7 +365,7 @@ def optim_forward_single(

def _compute_ego_flow(
self,
rep: EulerFlowPreprocessedInput,
rep: EulerFlowTorchedInput,
query_idx: int,
direction: QueryDirection = QueryDirection.FORWARD,
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -525,17 +525,17 @@ class EulerFlowOptimizationLoop(MiniBatchOptimizationLoop):
def __init__(
self,
speed_threshold: float,
pc_target_type: PointCloudTargetType | str,
pc_target_type: PointCloudTarget | str,
pc_loss_type: (
PointCloudLossType | str
) = PointCloudLossType.TRUNCATED_KD_TREE_FORWARD_BACKWARD,
PointCloudLoss | str
) = PointCloudLoss.TRUNCATED_KD_TREE_FORWARD_BACKWARD,
*args,
**kwargs,
):
super().__init__(model_class=self._model_class(), *args, **kwargs)
self.speed_threshold = speed_threshold
self.pc_target_type = PointCloudTargetType(pc_target_type)
self.pc_loss_type = PointCloudLossType(pc_loss_type)
self.pc_target_type = PointCloudTarget(pc_target_type)
self.pc_loss_type = PointCloudLoss(pc_loss_type)

def _model_class(self) -> type[BaseOptimizationModel]:
return EulerFlowModel
@@ -574,7 +574,7 @@ def _model_constructor_args(
self, full_input_sequence: TorchFullFrameInputSequence
) -> dict[str, any]:
return super()._model_constructor_args(full_input_sequence) | dict(
model=EulerFlowFlowMLP(act_fn=ActivationFn.SINC)
model=EulerFlowMLP(act_fn=ActivationFn.SINC)
)


@@ -584,7 +584,7 @@ def _model_constructor_args(
self, full_input_sequence: TorchFullFrameInputSequence
) -> dict[str, any]:
return super()._model_constructor_args(full_input_sequence) | dict(
model=EulerFlowFlowMLP(act_fn=ActivationFn.GAUSSIAN)
model=EulerFlowMLP(act_fn=ActivationFn.GAUSSIAN)
)


@@ -594,7 +594,7 @@ def _model_constructor_args(
self, full_input_sequence: TorchFullFrameInputSequence
) -> dict[str, any]:
return super()._model_constructor_args(full_input_sequence) | dict(
model=EulerFlowFlowMLP(encoder=FourierTemporalEmbedding())
model=EulerFlowMLP(encoder=FourierTemporalEmbedding())
)


@@ -604,7 +604,7 @@ def _model_constructor_args(
self, full_input_sequence: TorchFullFrameInputSequence
) -> dict[str, any]:
return super()._model_constructor_args(full_input_sequence) | dict(
model=EulerFlowFlowMLP(num_layers=22)
model=EulerFlowMLP(num_layers=22)
)


@@ -614,7 +614,7 @@ def _model_constructor_args(
self, full_input_sequence: TorchFullFrameInputSequence
) -> dict[str, any]:
return super()._model_constructor_args(full_input_sequence) | dict(
model=EulerFlowFlowMLP(num_layers=20)
model=EulerFlowMLP(num_layers=20)
)


@@ -624,7 +624,7 @@ def _model_constructor_args(
self, full_input_sequence: TorchFullFrameInputSequence
) -> dict[str, any]:
return super()._model_constructor_args(full_input_sequence) | dict(
model=EulerFlowFlowMLP(num_layers=18)
model=EulerFlowMLP(num_layers=18)
)


@@ -634,7 +634,7 @@ def _model_constructor_args(
self, full_input_sequence: TorchFullFrameInputSequence
) -> dict[str, any]:
return super()._model_constructor_args(full_input_sequence) | dict(
model=EulerFlowFlowMLP(num_layers=16)
model=EulerFlowMLP(num_layers=16)
)


@@ -644,7 +644,7 @@ def _model_constructor_args(
self, full_input_sequence: TorchFullFrameInputSequence
) -> dict[str, any]:
return super()._model_constructor_args(full_input_sequence) | dict(
model=EulerFlowFlowMLP(num_layers=14)
model=EulerFlowMLP(num_layers=14)
)


@@ -654,7 +654,7 @@ def _model_constructor_args(
self, full_input_sequence: TorchFullFrameInputSequence
) -> dict[str, any]:
return super()._model_constructor_args(full_input_sequence) | dict(
model=EulerFlowFlowMLP(num_layers=12)
model=EulerFlowMLP(num_layers=12)
)


@@ -664,7 +664,7 @@ def _model_constructor_args(
self, full_input_sequence: TorchFullFrameInputSequence
) -> dict[str, any]:
return super()._model_constructor_args(full_input_sequence) | dict(
model=EulerFlowFlowMLP(num_layers=10)
model=EulerFlowMLP(num_layers=10)
)


@@ -674,7 +674,7 @@ def _model_constructor_args(
self, full_input_sequence: TorchFullFrameInputSequence
) -> dict[str, any]:
return super()._model_constructor_args(full_input_sequence) | dict(
model=EulerFlowFlowMLP(num_layers=6)
model=EulerFlowMLP(num_layers=6)
)


@@ -684,7 +684,7 @@ def _model_constructor_args(
self, full_input_sequence: TorchFullFrameInputSequence
) -> dict[str, any]:
return super()._model_constructor_args(full_input_sequence) | dict(
model=EulerFlowFlowMLP(num_layers=4)
model=EulerFlowMLP(num_layers=4)
)


@@ -694,5 +694,5 @@ def _model_constructor_args(
self, full_input_sequence: TorchFullFrameInputSequence
) -> dict[str, any]:
return super()._model_constructor_args(full_input_sequence) | dict(
model=EulerFlowFlowMLP(num_layers=2)
model=EulerFlowMLP(num_layers=2)
)
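Each of the variants above overrides only the model entry: super()._model_constructor_args(...) returns the base keyword arguments as a dict, and the | union (Python 3.9+) lets the right-hand dict win on key collisions. A stripped-down sketch of the pattern, with a placeholder base dict standing in for the real constructor arguments:

# Hypothetical stand-in for the merge performed by the subclasses above;
# keys on the right of | override those on the left.
base_args = dict(model=EulerFlowMLP())
variant_args = base_args | dict(model=EulerFlowMLP(num_layers=18))  # variant-specific depth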
