Added minibatch support to Test-Time Optimization methods (#16)
* Renamed existing TTO to be whole batch

* Migrated shared resources

* Moved resources around

* Fixed loader import

* Added sequence minibatcher (see the sketch below the file summary)

* Added minibatch tests

* More cleanups. loss_fn currently broken

* Converted to using model based loss fn

* Checkpoint down past the minute

* Added missing arg

* Added minibatcher type

* Moved models to separate folders

* First cut at working NSFP with new wrapper

* Fixed FastNSF and NSFP tests

* Fixed constant baseline

* Fixed Liu 2024

* Deleted cruft

* Added minibatch optim loop as simple extension

* First cut at minibatched gigachad

* Added trainable Gigachad

* Cleaned up naming, passed args to Liu 2024 subcomponents properly
kylevedder authored May 3, 2024
1 parent b007026 commit cb570aa
Showing 96 changed files with 1,512 additions and 1,724 deletions.
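The core of this change is a sequence minibatcher for the test-time optimization (TTO) methods: instead of optimizing a loss over an entire subsequence at once (the "whole batch" loops renamed above), the loop walks the sequence in small windows of frames. The committed implementation is in the diffs below; as a minimal sketch of the idea (function name and windowing policy are assumptions, not the repo's actual code):

```python
from typing import Iterator, List, TypeVar

T = TypeVar("T")


def minibatch_sequence(frames: List[T], minibatch_size: int) -> Iterator[List[T]]:
    """Yield overlapping windows of frames so that every adjacent frame pair
    lands in some minibatch (the overlap policy here is an assumption)."""
    assert minibatch_size >= 2, "need at least one frame pair per window"
    step = minibatch_size - 1  # overlap consecutive windows by one frame
    for start in range(0, len(frames) - 1, step):
        yield frames[start : start + minibatch_size]


# With SEQUENCE_LENGTH = 150 and minibatch_size = 10 (see the new
# noncausal_minibatched config below), this yields windows [0..9],
# [9..18], ..., covering all 149 adjacent frame pairs.
```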
2 changes: 1 addition & 1 deletion .gitignore
```diff
@@ -26,4 +26,4 @@ checkpoint_eval_launch_dir/
 screenshots*/
 eval_results/
 launch_files/
-tests/*/config/
+tests/*/config*/
```
15 changes: 9 additions & 6 deletions README.md
```diff
@@ -9,12 +9,15 @@ Currently, the Zoo supports the following datasets:
 
 The Zoo supports the following methods:
 
-- [FastFlow3D](https://arxiv.org/abs/2103.01306) / [FastFlow3D XL](https://vedder.io/zeroflow)
-- [ZeroFlow and ZeroFlow XL](https://vedder.io/zeroflow)
-- [Neural Scene Flow Prior](https://arxiv.org/abs/2111.01253)
-- [Fast NSF](https://arxiv.org/abs/2304.09121)
-- [Liu et al. 2024](https://arxiv.org/abs/2403.16116)
-- [DeFlow](https://arxiv.org/abs/2401.16122)
+- Feed-forward
+  - [FastFlow3D](https://arxiv.org/abs/2103.01306) / [FastFlow3D XL](https://vedder.io/zeroflow)
+  - [ZeroFlow and ZeroFlow XL](https://vedder.io/zeroflow)
+  - [DeFlow](https://arxiv.org/abs/2401.16122)
+- Test time optimization
+  - [Neural Scene Flow Prior (NSFP)](https://arxiv.org/abs/2111.01253)
+  - [Fast NSF](https://arxiv.org/abs/2304.09121)
+  - [Liu et al. 2024](https://arxiv.org/abs/2403.16116)
 
+
 If you use this codebase, please cite the following paper:
```
63 changes: 34 additions & 29 deletions configs/deflow/argo/train.py
```diff
@@ -10,44 +10,49 @@
 
 SEQUENCE_LENGTH = 2
 
-model = dict(name="DeFlow",
-             args=dict(VOXEL_SIZE={{_base_.VOXEL_SIZE}},
-                       PSEUDO_IMAGE_DIMS={{_base_.PSEUDO_IMAGE_DIMS}},
-                       POINT_CLOUD_RANGE={{_base_.POINT_CLOUD_RANGE}}))
-
-loss_fn = dict(name="FastFlow3DBucketedLoaderLoss", args=dict())
+model = dict(
+    name="DeFlow",
+    args=dict(
+        VOXEL_SIZE={{_base_.VOXEL_SIZE}},
+        PSEUDO_IMAGE_DIMS={{_base_.PSEUDO_IMAGE_DIMS}},
+        POINT_CLOUD_RANGE={{_base_.POINT_CLOUD_RANGE}},
+    ),
+)
 
 ######## TEST DATASET ########
 
 test_dataset_root = "/efs/argoverse2/val/"
 save_output_folder = "/efs/argoverse2/val_deflow_flow/"
 
-test_dataset = dict(name="BucketedSceneFlowDataset",
-                    args=dict(dataset_name="Argoverse2CausalSceneFlow",
-                              root_dir=test_dataset_root,
-                              with_ground=False,
-                              with_rgb=False,
-                              eval_type="bucketed_epe",
-                              eval_args=dict()))
+test_dataset = dict(
+    name="BucketedSceneFlowDataset",
+    args=dict(
+        dataset_name="Argoverse2CausalSceneFlow",
+        root_dir=test_dataset_root,
+        with_ground=False,
+        with_rgb=False,
+        eval_type="bucketed_epe",
+        eval_args=dict(),
+    ),
+)
 
-test_dataloader = dict(
-    args=dict(batch_size=1, num_workers=8, shuffle=False, pin_memory=True))
+test_dataloader = dict(args=dict(batch_size=1, num_workers=8, shuffle=False, pin_memory=True))
 
 ######## TRAIN DATASET ########
 
 train_sequence_dir = "/efs/argoverse2/train/"
 
-train_dataset = dict(name="BucketedSceneFlowDataset",
-                     args=dict(dataset_name="Argoverse2CausalSceneFlow",
-                               root_dir=train_sequence_dir,
-                               with_ground=False,
-                               use_gt_flow=True,
-                               with_rgb=False,
-                               eval_type="bucketed_epe",
-                               eval_args=dict()))
-
-train_dataloader = dict(
-    args=dict(batch_size=16, num_workers=16, shuffle=True, pin_memory=False))
-
-
-
+train_dataset = dict(
+    name="BucketedSceneFlowDataset",
+    args=dict(
+        dataset_name="Argoverse2CausalSceneFlow",
+        root_dir=train_sequence_dir,
+        with_ground=False,
+        use_gt_flow=True,
+        with_rgb=False,
+        eval_type="bucketed_epe",
+        eval_args=dict(),
+    ),
+)
 
+train_dataloader = dict(args=dict(batch_size=16, num_workers=16, shuffle=True, pin_memory=False))
```
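The `{{_base_.VOXEL_SIZE}}`-style placeholders are mmcv/mmengine base-config substitutions: values defined in the inherited `_base_` file are spliced in when the config is parsed. A usage sketch, assuming an mmengine-style loader (the exact Config import this repo uses may differ):

```python
from mmengine.config import Config  # assumption; older codebases use mmcv.Config

cfg = Config.fromfile("configs/deflow/argo/train.py")
# VOXEL_SIZE, PSEUDO_IMAGE_DIMS, and POINT_CLOUD_RANGE are filled in
# from the inherited _base_ config at parse time.
print(cfg.model["args"]["VOXEL_SIZE"])
```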
2 changes: 0 additions & 2 deletions configs/fastflow3d/argo/bucketed_nsfp_distillation_1x.py
```diff
@@ -18,8 +18,6 @@
     ),
 )
 
-loss_fn = dict(name="FastFlow3DBucketedLoaderLoss", args=dict())
-
 ######## TEST DATASET ########
 
 test_dataset_root = "/efs/argoverse2/val/"
```
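Both this config and the DeFlow config above drop their top-level `loss_fn` entries, matching the "Converted to using model based loss fn" bullet: the loss is now owned by the model rather than constructed separately from the config. A minimal sketch of that pattern (class, method, and key names are assumptions, not the repo's actual interface):

```python
from typing import Dict

import torch
import torch.nn as nn


class SceneFlowModel(nn.Module):
    """Hypothetical model that owns its loss instead of relying on a
    config-declared loss_fn entry."""

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        raise NotImplementedError  # produce {"flow": ...} predictions

    def loss_fn(
        self, batch: Dict[str, torch.Tensor], output: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        # Endpoint-error style loss between predicted and ground-truth flow.
        epe = torch.linalg.norm(output["flow"] - batch["gt_flow"], dim=-1)
        return {"loss": epe.mean()}
```

This keeps loss construction next to the prediction format it depends on, so a config only has to name a model.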
6 changes: 0 additions & 6 deletions configs/gigachad_nsf/argo/causal/test.py

This file was deleted.

3 changes: 0 additions & 3 deletions configs/gigachad_nsf/argo/causal/test_cached.py

This file was deleted.

3 changes: 0 additions & 3 deletions configs/gigachad_nsf/argo/causal/train_cached.py

This file was deleted.

6 changes: 0 additions & 6 deletions configs/gigachad_nsf/argo/causal/val.py

This file was deleted.

3 changes: 0 additions & 3 deletions configs/gigachad_nsf/argo/causal/val_cached.py

This file was deleted.

6 changes: 0 additions & 6 deletions configs/gigachad_nsf/argo/causal/val_cached_debug.py

This file was deleted.

6 changes: 0 additions & 6 deletions configs/gigachad_nsf/argo/causal/val_debug.py

This file was deleted.

6 changes: 0 additions & 6 deletions configs/gigachad_nsf/argo/noncausal/test.py

This file was deleted.

3 changes: 0 additions & 3 deletions configs/gigachad_nsf/argo/noncausal/test_cached.py

This file was deleted.

29 changes: 0 additions & 29 deletions configs/gigachad_nsf/argo/noncausal/train.py

This file was deleted.

3 changes: 0 additions & 3 deletions configs/gigachad_nsf/argo/noncausal/train_cached.py

This file was deleted.

6 changes: 0 additions & 6 deletions configs/gigachad_nsf/argo/noncausal/val.py

This file was deleted.

3 changes: 0 additions & 3 deletions configs/gigachad_nsf/argo/noncausal/val_cached.py

This file was deleted.

6 changes: 0 additions & 6 deletions configs/gigachad_nsf/argo/noncausal/val_cached_debug.py

This file was deleted.

18 changes: 0 additions & 18 deletions configs/gigachad_nsf/argo/noncausal/val_debug.py

This file was deleted.

38 changes: 38 additions & 0 deletions configs/gigachad_nsf/argo/noncausal_minibatched/val_debug.py
```diff
@@ -0,0 +1,38 @@
+has_labels = False
+
+SEQUENCE_LENGTH = 150
+
+test_dataset_root = "/efs/argoverse2_small/val/"
+save_output_folder = "/efs/argoverse2_small/val_gigachad_nsf_flow_feather/"
+
+
+model = dict(
+    name="GigachadNSFOptimizationLoop",
+    args=dict(save_flow_every=1, minibatch_size=10, speed_threshold=30.0 / 10.0),
+)
+
+epochs = 5000
+learning_rate = 0.008
+save_every = 500
+validate_every = 500
+
+train_dataset = dict(
+    name="BucketedSceneFlowDataset",
+    args=dict(
+        dataset_name="Argoverse2NonCausalSceneFlow",
+        root_dir=test_dataset_root,
+        with_ground=False,
+        with_rgb=False,
+        eval_type="bucketed_epe",
+        eval_args=dict(),
+        subsequence_length=SEQUENCE_LENGTH,
+        split=dict(split_idx=0, num_splits=2),
+    ),
+)
+
+
+train_dataloader = dict(args=dict(batch_size=1, num_workers=0, shuffle=False, pin_memory=True))
+
+
+test_dataset = train_dataset.copy()
+test_dataloader = train_dataloader.copy()
```
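The new `split=dict(split_idx=0, num_splits=2)` field suggests sharding sequences across parallel jobs. One plausible reading (the actual semantics live in the dataloader, which is not shown in this excerpt):

```python
from typing import List, TypeVar

T = TypeVar("T")


def select_split(sequences: List[T], split_idx: int, num_splits: int) -> List[T]:
    """Hypothetical sharding: each of num_splits jobs takes a strided,
    disjoint subset of the dataset's sequences."""
    assert 0 <= split_idx < num_splits, "split_idx must index into num_splits"
    return sequences[split_idx::num_splits]


# split=dict(split_idx=0, num_splits=2) -> this job optimizes every other
# sequence; a second launch with split_idx=1 covers the remainder.
```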
```diff
@@ -1,17 +1,16 @@
 is_trainable = False
 has_labels = False
 
-test_dataset_root = "/efs/argoverse2/train/"
-save_output_folder = "/efs/argoverse2/train_gigachad_nsf_flow_feather/"
+test_dataset_root = "/efs/argoverse2_small/val/"
+save_output_folder = "/efs/argoverse2_small/val_nsfp_rewritten_flow_debug/"
 
-SEQUENCE_LENGTH = 5
-
-model = dict(name="GigaChadNSFModel", args=dict())
+SEQUENCE_LENGTH = 2
 
-epochs = 20
-learning_rate = 2e-6
-save_every = 500
-validate_every = 500
+model = dict(
+    name="WholeBatchOptimizationLoop",
+    args=dict(model_class="WholeBatchNSFPCycleConsistency", save_flow_every=10),
+)
 
 test_dataset = dict(
     name="BucketedSceneFlowDataset",
@@ -22,7 +21,6 @@
         with_rgb=False,
         eval_type="bucketed_epe",
         eval_args=dict(),
-        subsequence_length=SEQUENCE_LENGTH,
     ),
 )
 
```
47 changes: 9 additions & 38 deletions core_utils/__init__.py
```diff
@@ -1,42 +1,13 @@
-from .loaders import (
-    load_pickle,
-    save_pickle,
-    load_json,
-    save_json,
-    load_csv,
-    save_csv,
-    run_cmd,
-    load_npz,
-    save_npz,
-    load_npy,
-    save_npy,
-    load_txt,
-    save_txt,
-    save_feather,
-    load_feather,
-    save_by_extension,
-    load_by_extension,
-)
-
 from .model_wrapper import ModelWrapper
+from .tb_logging import setup_tb_logger
+from .checkpointing import get_checkpoint_path, setup_model
+from .dataloading import make_dataloader
 
 
 __all__ = [
-    "load_pickle",
-    "save_pickle",
-    "load_json",
-    "save_json",
-    "load_csv",
-    "save_csv",
-    "load_npz",
-    "save_npz",
-    "load_npy",
-    "save_npy",
-    "run_cmd",
-    "load_txt",
-    "save_txt",
-    "load_feather",
-    "save_feather",
-    "save_by_extension",
-    "load_by_extension",
     "ModelWrapper",
-]
+    "setup_tb_logger",
+    "get_checkpoint_path",
+    "setup_model",
+    "make_dataloader",
+]
```
core_utils/checkpointing.py
```diff
@@ -2,6 +2,9 @@
 import datetime
 from pathlib import Path
 import os
+from dataloaders import EvalWrapper
+from core_utils import ModelWrapper
+from typing import Optional
 
 
 def get_rank() -> int:
@@ -16,7 +19,7 @@ def get_rank() -> int:
 
 
 def get_checkpoint_path(cfg: Config) -> Path:
-    checkpoint_dir_name = datetime.datetime.now().strftime("%Y_%m_%d-%I_%M_%S_%p")
+    checkpoint_dir_name = datetime.datetime.now().strftime("%Y_%m_%d-%I_%M_%S_%p_%f")
     cfg_filename = Path(cfg.filename)
     config_name = cfg_filename.stem
     parent_name = cfg_filename.parent.name
@@ -29,3 +32,14 @@ def get_checkpoint_path(cfg: Config) -> Path:
         # Since we're not rank 0, we should grab the most recent directory instead of creating a new one.
         checkpoint_path = sorted(parent_path.glob("*"))[-1]
     return checkpoint_path
+
+
+def setup_model(cfg: Config, evaluator: EvalWrapper, checkpoint: Optional[Path]):
+    if (hasattr(cfg, "is_trainable") and not cfg.is_trainable) or checkpoint is None:
+        model = ModelWrapper(cfg, evaluator=evaluator)
+    else:
+        assert checkpoint is not None, "Must provide checkpoint for validation"
+        checkpoint = Path(checkpoint)
+        assert checkpoint.exists(), f"Checkpoint file {checkpoint} does not exist"
+        model = ModelWrapper.load_from_checkpoint(checkpoint, cfg=cfg, evaluator=evaluator)
+    return model
```
