diff --git a/.gitignore b/.gitignore index 898e032..fd9e8d9 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ __pycache__/ #scripts -*.sh +_*.sh # package unidepth.egg-info diff --git a/README.md b/README.md index e83732f..73df8fa 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ [![arXiv](https://img.shields.io/badge/arXiv-2403.18913-blue?logo=arxiv&color=%23B31B1B)](https://arxiv.org/abs/2403.18913) [![ProjectPage](https://img.shields.io/badge/Project_Page-UniDepth-blue)](https://lpiccinelli-eth.github.io/pub/unidepth/) - [![KITTI Benchmark](https://img.shields.io/badge/KITTI%20Benchmark-1st%20(at%20submission%20time)-orange)](https://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_prediction) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/unidepth-universal-monocular-metric-depth/monocular-depth-estimation-on-nyu-depth-v2)](https://paperswithcode.com/sota/monocular-depth-estimation-on-nyu-depth-v2?p=unidepth-universal-monocular-metric-depth) @@ -211,6 +210,11 @@ To summarize the main differences are: - ONNX support +## Training + +Please [visit the training README](scripts/README.md) for more information. + + ## Results ### Metric Depth Estimation diff --git a/configs/train_v1_vitl14.json b/configs/train_v1_vitl14.json new file mode 100644 index 0000000..bab57d0 --- /dev/null +++ b/configs/train_v1_vitl14.json @@ -0,0 +1,109 @@ +{ + "generic": { + "seed": 13, + "deterministic": true + }, + "training": { + "n_iters": 300000, + "batch_size": 32, + "validation_interval": 1000, + "nsteps_accumulation_gradient": 1, + "use_checkpoint": false, + "lr": 1e-4, + "lr_final": 1e-6, + "lr_warmup": 1.0, + "cycle_beta": false, + + "wd": 0.1, + "wd_final": 0.1, + "warmup_iters": 75000, + "ld": 1.0, + + "drop_path": 0.0, + "ema": true, + "f16": true, + "clipping": 1.0, + "losses": { + "depth": { + "name": "SILog", + "weight": 1.0, + "output_fn": "sqrt", + "input_fn": "log", + "dims": [-2,-1], + "integrated": 0.15 + }, + "invariance": { + "name": "SelfDistill", + "weight": 0.1, + "output_fn": "sqrt" + }, + "camera": { + "name": "Regression", + "weight": 0.25, + "gamma": 1.0, + "alpha": 1.0, + "fn": "l2", + "output_fn": "sqrt", + "input_fn": "linear" + } + }}, + "data": { + "image_shape": [480, 640], + "normalization": "imagenet", + "pairs": 2, + "num_frames": 1, + "sampling":{ + "Sintel": 1.0, + "ADT": 1.0, + "KITTI": 1.0, + "HM3D": 1.0, + "ScanNet": 1.0 + }, + "train_datasets": [ + "ScanNet" + ], + "val_datasets": [ + "IBims" + ], + "data_root": "datasets", + "crop": "garg", + "augmentations": { + "random_scale": 2.0, + "random_jitter": 0.4, + "jitter_p": 0.8, + "random_blur": 2.0, + "blur_p": 0.2, + "random_gamma": 0.2, + "gamma_p": 0.8, + "grayscale_p": 0.2, + "flip_p": 0.5, + "test_context": 1.0, + "shape_constraints": { + "ratio_bounds": [0.66, 2.0], + "pixels_max": 2600, + "pixels_min": 1200, + "height_min": 15, + "width_min": 15, + "shape_mult": 14, + "sample": true + } + } + }, + "model": { + "name": "UniDepthV1", + "num_heads": 8, + "expansion": 4, + "pixel_decoder": { + "hidden_dim": 512, + "depths": [3, 2, 1], + "dropout": 0.0 + }, + "pixel_encoder": { + "name": "dinov2_vits14", + "norm": true, + "pretrained": "", + "lr": 1e-5, + "frozen_stages": 0 + } + } +} \ No newline at end of file diff --git a/install.sh b/install.sh deleted file mode 100644 index 4a09022..0000000 --- a/install.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash -NAME=${1} -VENV_DIR=${2} - -python -m venv ${VENV_DIR}/${NAME} - -source ${VENV_DIR}/${NAME}/bin/activate 
- -

pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118 pip install -e . pip install xformers==0.0.24 --index-url https://download.pytorch.org/whl/cu118 diff --git a/requirements.txt b/requirements.txt index a436461..b2da657 100644 --- a/requirements.txt +++ b/requirements.txt @@ -55,6 +55,9 @@ pycodestyle pyflakes pyparsing python-dateutil +# pytorch3d is needed only for chamfer distance calculation +# you can compile it from ops/knn and avoid this dependency +pytorch3d @ "git+https://github.com/facebookresearch/pytorch3d.git@stable" pytz PyYAML requests diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000..939b416 --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,57 @@ +## Training + +We provide the `train.py` script that loads the datasets, initializes the model, and starts the training. From the root of the repo: + +```bash +export REPO=`pwd` +export PYTHONPATH=${REPO}:${PYTHONPATH} + +# Adapt all this to your setup +export TMPDIR="/tmp" +export TORCH_HOME=${TMPDIR} +export HUGGINGFACE_HUB_CACHE=${TMPDIR} +export WANDB_HOME=${TMPDIR} +export DATAROOT= + + +export MASTER_PORT=$((( RANDOM % 600 ) + 29400 )) +if [ $NNODES -gt 1 ]; then + export MASTER_PORT=29400 +fi + +# this is the config that will be used +export CFG="train_v1_vitl14.json" +``` + +If you are on a machine without SLURM, you can run the following: +```bash +# make the following input-dependent for multi-node +export NNODES=1 +export RANK=0 +export MASTER_ADDR=127.0.0.1 +export CUDA_VISIBLE_DEVICES="0" # set yours + +export GPUS=$(echo ${CUDA_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) +echo "Start script with python from: `which python`" +torchrun --rdzv-backend=c10d --nnodes=${NNODES} --nproc_per_node=${GPUS} --rdzv-endpoint ${MASTER_ADDR}:${MASTER_PORT} ${REPO}/scripts/train.py --config-file ${REPO}/configs/${CFG} --distributed +``` + +If your system uses SLURM, all the information will be set by the scheduler and you only have to run: +```bash +srun -c ${SLURM_CPUS_PER_TASK} --kill-on-bad-exit=1 python -u ${REPO}/scripts/train.py --config-file ${REPO}/configs/${CFG} --master-port ${MASTER_PORT} --distributed +``` + +The training is available only for V1.
+We have changes coming in the upcoming months for V2, and its "trainable model code" will be made public then. + + +### Datasets + +We use both image-based and sequence-based datasets. The `ImageDataset` class is kept only for legacy reasons, as we moved image-based datasets to be "dummy" single-frame sequences.
+We [provide two example datasets to get familiar with the pipeline and structure, namely iBims-1 and Sintel](https://drive.google.com/drive/folders/1FKsa5-b3EX0ukZq7bxord5fC5OfUiy16?usp=sharing), image- and sequence-based, respectively.
+You can adapt the data loading and processing to your own data; however, you will need to keep the same interface so that the model stays consistent and trains "out-of-the-box" (see the sketch below).
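+
+As a rough reference, a new sequence-based dataset typically only subclasses `SequenceDataset`, exposes a few class attributes, and forwards the constructor arguments, following the pattern of the `HM3D` and `ADT` loaders in this repo. In the sketch below the class name `MyDataset`, the depth range/scale, and the file names (`train.txt`, `val.txt`, `sequences.json`, `MyDataset.hdf5`) are placeholders for your own data:
+
+```python
+from typing import Any
+
+from unidepth.datasets.sequence_dataset import SequenceDataset
+
+
+class MyDataset(SequenceDataset):
+    # Placeholder metadata: adapt depth range, scale, and file names to your data.
+    min_depth = 0.01
+    max_depth = 80.0
+    depth_scale = 1000.0
+    test_split = "val.txt"
+    train_split = "train.txt"
+    sequences_file = "sequences.json"
+    hdf5_paths = ["MyDataset.hdf5"]
+
+    def __init__(
+        self,
+        image_shape: tuple[int, int],
+        split_file: str,
+        test_mode: bool,
+        normalize: bool,
+        augmentations_db: dict[str, Any],
+        resize_method: str,
+        mini: float,
+        num_frames: int = 1,
+        benchmark: bool = False,
+        decode_fields: list[str] = ["image", "depth"],
+        inplace_fields: list[str] = ["K", "cam2w"],
+        **kwargs,
+    ) -> None:
+        # Forward everything to SequenceDataset, which handles HDF5 loading,
+        # resizing, augmentations and collation.
+        super().__init__(
+            image_shape=image_shape,
+            split_file=split_file,
+            test_mode=test_mode,
+            benchmark=benchmark,
+            normalize=normalize,
+            augmentations_db=augmentations_db,
+            resize_method=resize_method,
+            mini=mini,
+            num_frames=num_frames,
+            decode_fields=decode_fields,
+            inplace_fields=inplace_fields,
+            **kwargs,
+        )
+
+    def pre_pipeline(self, results):
+        # Flags consumed downstream: dense ground truth and annotation quality.
+        results = super().pre_pipeline(results)
+        results["dense"] = [True] * self.num_frames
+        results["quality"] = [1] * self.num_frames
+        return results
+```
+
+Remember to also export the new class from `unidepth/datasets/__init__.py` and to list it under `train_datasets` (and, if desired, `sampling`) in the training config.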
+ + +### Additional dependencies + +We require chamfer distance for the evaluation, hence we rely on Pytorch3D knn, if you find any issue with Pytorch3D installation, you can compile the knn operation under `ops/knn`: `bash compile.sh` from the directory `$REPO/unidepth/ops/knn`. \ No newline at end of file diff --git a/scripts/train.py b/scripts/train.py new file mode 100644 index 0000000..2344bac --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,555 @@ +import argparse +import json +import os +import random +import uuid +from contextlib import nullcontext +from datetime import datetime as dt +from functools import partial +from time import time +from typing import Any, Dict + +import git +import numpy as np +import psutil +import torch +import torch.nn as nn +import torch.utils.data.distributed +import wandb +from torch import distributed as dist +from torch import optim +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from tqdm import tqdm + +import unidepth.datasets as datasets +from unidepth.datasets import (ConcatDataset, DistributedSamplerNoDuplicate, + collate_fn, get_weights) +from unidepth.models import UniDepthV1, UniDepthV2 +from unidepth.ops.scheduler import CosineScheduler +from unidepth.utils.distributed import (barrier, create_local_process_group, + is_main_process, + local_broadcast_process_authkey, + setup_multi_processes, setup_slurm, + sync_string_across_gpus, + sync_tensor_across_gpus) +from unidepth.utils.ema_torch import (DummyExponentialMovingAverage, + ExponentialMovingAverage) +from unidepth.utils.misc import calculate_mean_values, format_seconds +from unidepth.utils.validation import validate + +EMA_INTERVAL = 10 + + +def aggregate_sync_losses(dict_: dict[str, torch.Tensor], device): + keys = list(dict_.keys()) + values = torch.tensor(list(dict_.values()), device=device) + keys = sync_string_across_gpus(keys, device) + values = sync_tensor_across_gpus(values, dim=0).cpu().tolist() + dict_ = calculate_mean_values(keys, values) + return dict_ + + +def main_worker(config: Dict[str, Any], args: argparse.Namespace): + + current_process = psutil.Process(os.getpid()) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + seed = config["generic"]["seed"] + + if not args.distributed: + args.rank = 0 + args.local_rank = 0 + args.world_size = 1 + else: + # initializes the distributed backend which will take care of synchronizing nodes/GPUs + setup_multi_processes(config) + is_slurm = "SLURM_PROCID" in os.environ + if is_slurm: + setup_slurm("nccl", port=args.master_port) + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.local_rank = device = int(os.environ["LOCAL_RANK"]) + if not is_slurm: + dist.init_process_group("nccl", rank=args.rank, world_size=args.world_size) + torch.cuda.set_device(device) + create_local_process_group() + local_broadcast_process_authkey() + print( + f"Start running DDP on: {args.rank} (local: {args.local_rank}) with seed {seed + args.rank}." 
+ ) + config["training"]["batch_size"] = int( + config["training"]["batch_size"] / args.world_size + ) + dist.barrier() + + # Fix seed + # Different for every machine to avoid sampling + # the same element across machines + seed = seed + args.rank + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + + batch_size = config["training"]["batch_size"] + if is_main_process(): + print("Config: ", args.config_file) + print( + f"Torch version:{torch.__version__}, cuda:{torch.version.cuda}, cudnn:{torch.backends.cudnn.version()}, threads:{torch.get_num_threads()}" + ) + print("BatchSize per GPU: ", batch_size) + print( + f"Divided into {config['training']['nsteps_accumulation_gradient']} accumulation step" + ) + + ############################## + ########### MODEL ############ + ############################## + # Build model + model = eval(config["model"]["name"])(config).to(device) + model.eval() + print(f"MODEL: {model.__class__.__name__} at {model.device}") + torch.cuda.empty_cache() + if args.distributed: + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + # Before setting find_unused_parameters=True, be sure no gradient checkpointing is on + # => DDP with double pass (unused params) is incompatible with gradient checkpointing + # My guess: CheckpointFunction does not have a pointer pointing to the original model + # params (?), so DDP Reducer will treat those model parameters as unused, and mark them + # as ready. Later, when the forward-backward is run within CheckpointFunction, those + # parameters will be marked as ready again, which can also trigger this error. + + # Possible workaround (not sure at all): + # with model.no_sync(): + # model(inputs).backward() + # works = [dist.all_reduce(p.grad, async_op=True) for p in model.parameters()] # Manual sync + # ideally one can organize grads to larger buckets to make allreduce more efficient + # for work in works: + # work.wait() + model = DistributedDataParallel( + model, + find_unused_parameters=False, + device_ids=[device], + output_device=device, + ) + + ############################## + ######### OPTIMIZER ########## + ############################## + f16 = config["training"].get("f16", False) + clipping = config["training"].get("clipping", None) + + # Optimize + ddp_model = model.module if args.distributed else model + params = ddp_model.get_params(config) + optimizer = optim.AdamW(params) + + # Load Model: + step = 0 + if config["training"].get("pretrained", None) is not None: + ddp_model.load_pretrained(config["training"]["pretrained"]) + pretrained = torch.load(config["training"]["pretrained"], map_location="cpu") + try: + optimizer.load_state_dict(pretrained["optimizer"]) + except Exception as e: + if is_main_process(): + print("Could not load optimizer state dict:", e) + step = pretrained["step"] + + # EMA + ema_class = ( + ExponentialMovingAverage + if config["training"]["ema"] + else DummyExponentialMovingAverage + ) + ema_handle = ema_class( + model.parameters(), + 1 - (1 - 0.9995) * EMA_INTERVAL, + update_after_step=75000 // EMA_INTERVAL, + tau=20000 // EMA_INTERVAL, + ) + setattr(ema_handle, "num_updates", step // EMA_INTERVAL) + + ############################## + ######### GENERICS ########### + ############################## + resize_method = config["data"].get("resize_method", "hard") + crop = config["data"].get("crop", "garg") + augmentations_db = config["data"].get("augmentations", {}) + 
image_shape = config["data"]["image_shape"] + nsteps_accumulation_gradient = config["training"]["nsteps_accumulation_gradient"] + batch_size = config["training"]["batch_size"] + + is_shell = int(os.environ.get("SHELL_JOB", 0)) + run_id = sync_string_across_gpus( + [f"{dt.now().strftime('%d-%h_%H-%M')}-{uuid.uuid4()}"], device + )[0] + + if not is_shell and is_main_process(): + repo_folder = os.path.dirname(os.path.realpath(__file__)) + try: + repo = git.Repo(repo_folder) + current_head = repo.head if repo.head.is_detached else repo.active_branch + notes = f"MESSAGE: {current_head.commit.message} HASH:{current_head.commit.hexsha} BRANCH:{current_head.name}" + except Exception as e: + print( + f"problem with {repo_folder}, does it exist? (original excpetion: {e})" + ) + notes = "" + + # restore the original batchsize, not acquired by other calls from now on + if args.distributed: + config["training"]["batch_size"] = ( + config["training"]["batch_size"] * args.world_size + ) + wandb.init( + project="UniDepth", + name=run_id, + config=config, + tags=None, + notes=notes, + dir=os.environ.get("WANDB_HOME", os.environ.get("TMPDIR", "/tmp")), + ) + wandb.watch(model) + + ############################## + ########## DATASET ########### + ############################## + # Datasets loading + train_datasets, val_datasets, dims = {}, {}, 0 + if is_main_process(): + print("Loading training datasets...") + for dataset in config["data"]["train_datasets"]: + assert hasattr(datasets, dataset), f"{dataset} not a custom dataset" + train_dataset: datasets.BaseDataset = getattr(datasets, dataset) + train_datasets[dataset] = train_dataset( + image_shape=image_shape, + split_file=train_dataset.train_split, + test_mode=False, + crop=crop, + augmentations_db=augmentations_db, + normalize=config["data"].get("normalization", "imagenet"), + resize_method=resize_method, + mini=1.0, + num_frames=1, + fps_range=[1, 1], + ) + dim = ( + train_datasets[dataset].dataset._addr.numel() * 8 + + train_datasets[dataset].dataset._lst.numel() + ) / (2**20) + if hasattr(train_datasets[dataset], "sequences"): + dim += ( + train_datasets[dataset].sequences._addr.numel() * 8 + + train_datasets[dataset].sequences._lst.numel() + ) / (2**20) + if is_main_process(): + print(f"{dataset}: {dim:.1f}MB") + dims += dim + + if is_main_process(): + print(f"All training datasets loaded, with total size: {dims:.1f}MB") + + barrier() + + assert batch_size % nsteps_accumulation_gradient == 0 + batch_chunk = batch_size // nsteps_accumulation_gradient + + train_dataset = ConcatDataset( + [t for t in train_datasets.values()], + shape_constraints=config["data"]["augmentations"]["shape_constraints"], + pairs=config["data"]["pairs"], + ) + + if is_main_process(): + print("Loading validation datasets...") + for dataset in config["data"]["val_datasets"]: + val_dataset: datasets.BaseDataset = getattr(datasets, dataset) + val_datasets[dataset] = val_dataset( + image_shape=image_shape, + split_file=val_dataset.test_split, + test_mode=True, + crop=crop, + augmentations_db=augmentations_db, + normalize=config["data"].get("normalization", "imagenet"), + resize_method=resize_method, + num_frames=-1, + mini=1.0, + ) + + # Dataset samplers, create distributed sampler pinned to rank + if args.distributed: + weights, num_samples = get_weights(train_datasets, config["data"]["sampling"]) + train_sampler = torch.utils.data.WeightedRandomSampler( + weights, num_samples, replacement=True + ) + valid_samplers = { + k: DistributedSamplerNoDuplicate( + v, + 
num_replicas=args.world_size, + rank=args.rank, + shuffle=False, + drop_last=False, + ) + for k, v in val_datasets.items() + } + else: + train_sampler = RandomSampler(train_dataset) + valid_samplers = {k: SequentialSampler(v) for k, v in val_datasets.items()} + + train_sampler = torch.utils.data.BatchSampler( + train_sampler, batch_size=batch_size, drop_last=True + ) + + # DATASET LOADERS + val_batch_size = 1 + num_workers = int(os.environ.get("SLURM_CPUS_PER_TASK", 4)) + train_loader = DataLoader( + train_dataset, + num_workers=num_workers, + sampler=train_sampler, + pin_memory=True, + collate_fn=partial(collate_fn, is_batched=True), + persistent_workers=num_workers > 0, + ) + val_loaders = { + name_dataset: DataLoader( + dataset, + batch_size=val_batch_size, + shuffle=False, + num_workers=num_workers, + sampler=valid_samplers[name_dataset], + pin_memory=True, + drop_last=False, + collate_fn=partial(collate_fn, is_batched=False), + ) + for name_dataset, dataset in val_datasets.items() + } + + # SCHEDULERS + scheduler_wd = CosineScheduler( + optimizer, + key="weight_decay", + init_value=config["training"]["wd"], + base_value=config["training"]["wd"], + final_value=config["training"]["wd_final"], + warmup_iters=0, + total_iters=config["training"]["n_iters"], + step_init=step - 1, + ) + scheduler_lr = CosineScheduler( + optimizer, + key="lr", + init_value=config["training"]["lr"] * config["training"].get("lr_warmup", 1.0), + final_value=config["training"]["lr_final"], + warmup_iters=config["training"]["warmup_iters"], + total_iters=config["training"]["n_iters"], + step_init=step - 1, + ) + scheduler_betas = CosineScheduler( + optimizer, + key="betas", + init_value=0.95 if config["training"].get("cycle_betas", True) else 0.9, + base_value=0.85 if config["training"].get("cycle_betas", True) else 0.9, + final_value=0.95 if config["training"].get("cycle_betas", True) else 0.9, + warmup_iters=config["training"]["warmup_iters"], + total_iters=config["training"]["n_iters"], + step_init=step - 1, + ) + + # Set loss scaler for half precision training + sanity zeroing grads + # dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + dtype = torch.float16 # bfloat has some issue in terms of resolution... + context = torch.autocast(device_type="cuda", dtype=dtype, enabled=f16) + optimizer.zero_grad(set_to_none=True) + + ############################## + ########## TRAINING ########## + ############################## + # Remember that if i-th layer is frozen, this will break gradient checkpointing + # in layer i+1-th. This is because CheckpointFunction treats the i+1-th input as + # without gradient, thus the i+1-th layer does not have grads (?). 
To solve it, + # just add requires_grad_() to the inputs coming from the frozen layer + ddp_model.train() + start = time() + n_steps = config["training"]["n_iters"] + init_steps = int(step) + track_pbar = is_shell + + if is_main_process(): + print("Is a shell job?", is_shell) + print("Use dtype:", dtype if f16 else torch.float32) + print( + f'Train for {config["training"]["n_iters"]} steps, validate every {config["training"]["validation_interval"]} steps' + ) + print(f"START with {num_workers} workers") + if track_pbar: + pbar = tqdm(total=n_steps - init_steps) + + scaler = torch.cuda.amp.GradScaler(enabled=f16) + track_losses = {} + system_memory = dict(psutil.virtual_memory()._asdict())["available"] / 2**30 + cpid_memory = current_process.memory_info()[0] / 2.0**30 + gpu_mem = (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 2**30 + + while True: + for j, batches in enumerate(train_loader): + + system_memory = ( + 0.99 * system_memory + + 0.01 * dict(psutil.virtual_memory()._asdict())["available"] / 2**30 + ) + cpid_memory = ( + 0.99 * cpid_memory + 0.01 * current_process.memory_info()[0] / 2.0**30 + ) + gpu_mem = ( + 0.99 * gpu_mem + + 0.01 + * (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) + / 2**30 + ) + if j % 1000 == 0 and is_main_process(): + print(f"System information at step {j}") + print(f"System-wide RAM available: {system_memory:.2f}GB") + print(f"CPU utilization: {psutil.cpu_percent(interval=None)}%") + print(f"GPU memory utilized: {gpu_mem:.2f}GB") + + batches["data"] = { + k: v.to(model.device, non_blocking=True) + for k, v in batches["data"].items() + } + for idx in range(nsteps_accumulation_gradient): + batch = {} + batch_slice = slice(idx * batch_chunk, (idx + 1) * batch_chunk) + batch["data"] = {k: v[batch_slice] for k, v in batches["data"].items()} + batch["img_metas"] = batches["img_metas"][batch_slice] + + # remove temporal dimension of the dataloder, here is always 1! 
+ batch["data"] = {k: v.squeeze(1) for k, v in batch["data"].items()} + batch["img_metas"] = [ + {k: v[0] for k, v in meta.items() if isinstance(v, list)} + for meta in batch["img_metas"] + ] + + with ( + model.no_sync() + if idx < nsteps_accumulation_gradient - 1 + else nullcontext() + ): + with context: + preds, losses = model(batch["data"], batch["img_metas"]) + loss = sum(losses["opt"].values()) / nsteps_accumulation_gradient + scaler.scale(loss).backward() + + losses_dict = { + k: v.detach() for loss in losses.values() for k, v in loss.items() + } + track_losses.update( + { + k: 0.99 * track_losses.get(k, v) + + 0.01 * torch.nan_to_num(v, nan=1e5, posinf=1e5, neginf=1e5) + for k, v in losses_dict.items() + } + ) + + if clipping is not None: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), clipping) + + scaler.step(optimizer) + scaler.update() + + scheduler_wd.step() + scheduler_lr.step() + scheduler_betas.step() + optimizer.zero_grad(set_to_none=True) + if step % EMA_INTERVAL == 0: + ema_handle.update() + + if is_main_process() and track_pbar: + pbar.update(1) + + step += 1 + + # LOGGING + if step % 50 == 0: + track_losses = aggregate_sync_losses(track_losses, device=model.device) + if is_main_process(): + try: + wandb.log( + { + **{f"Train/{k}": v for k, v in track_losses.items()}, + **{f"Train/lr": scheduler_lr.get()[-1]}, + **{f"Train/wd": scheduler_wd.get()[-2]}, + **{f"Train/scale_f16": scaler.get_scale()}, + }, + step=step, + ) + except Exception as e: + print("Not logging loss because of:", e) + pass + + if step % 100 == 0 and is_main_process(): + log_loss_dict = {f"Train/{k}": v for k, v in track_losses.items()} + elapsed = int(time() - start) + eta = int(elapsed * (n_steps - step) / max(1, step - init_steps)) + print( + f"Loss at {step}/{n_steps} [{format_seconds(elapsed)}<{format_seconds(eta)}]:" + ) + print(", ".join([f"{k}: {v:.5f}" for k, v in log_loss_dict.items()])) + + # Validation + is_last_step = step >= config["training"]["n_iters"] + is_validation = step % config["training"]["validation_interval"] == 0 + if is_last_step or is_validation: + torch.cuda.empty_cache() + barrier() + if is_main_process(): + print(f"Validation at {step}th step...") + + ddp_model.eval() + start_validation = time() + with torch.no_grad(), ema_handle.average_parameters(): + validate( + model, test_loaders=val_loaders, step=step, context=context + ) + + if is_main_process(): + print(f"Elapsed: {format_seconds(int(time() - start_validation))}") + ddp_model.train() + torch.cuda.empty_cache() + + if step >= config["training"]["n_iters"]: + if is_main_process() and track_pbar: + pbar.close() + wandb.finish(0) + dist.destroy_process_group() + return 0 + + +if __name__ == "__main__": + # dummy folder to avoid Triton conflicts on multi-node training with slurm + if "SLURM_PROCID" in os.environ: + os.environ["TRITON_CACHE_DIR"] = "/tmp" + + parser = argparse.ArgumentParser( + description="Training script", conflict_handler="resolve" + ) + parser.add_argument("--config-file", type=str, required=True) + parser.add_argument("--master-port", type=str) + parser.add_argument("--distributed", action="store_true") + parser.add_argument("--local_rank", type=int, default=0) + + args = parser.parse_args() + with open(args.config_file, "r") as f: + config = json.load(f) + + deterministic = config["generic"].get("deterministic", True) + torch.backends.cudnn.deterministic = deterministic + torch.backends.cudnn.benchmark = not deterministic + # set to false otw f16 + efficient 
memory is not really efficient... i.e. mem is greater than with + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.set_num_threads(1) + main_worker(config, args) diff --git a/unidepth/datasets/__init__.py b/unidepth/datasets/__init__.py new file mode 100644 index 0000000..3ec2787 --- /dev/null +++ b/unidepth/datasets/__init__.py @@ -0,0 +1,7 @@ +from .base_dataset import BaseDataset +from .ibims import IBims +from .kitti import KITTI +from .nyuv2 import NYUv2Depth +from .samplers import DistributedSamplerNoDuplicate +from .sintel import Sintel +from .utils import ConcatDataset, collate_fn, get_weights diff --git a/unidepth/datasets/adt.py b/unidepth/datasets/adt.py new file mode 100644 index 0000000..b77741f --- /dev/null +++ b/unidepth/datasets/adt.py @@ -0,0 +1,64 @@ +from typing import Any + +import torch + +from unidepth.datasets.sequence_dataset import SequenceDataset + + +class ADT(SequenceDataset): + min_depth = 0.01 + max_depth = 20.0 + depth_scale = 1000.0 + test_split = "val.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"ADT.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["camera_params", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, # if not test_mode else [*decode_fields, "points"], + inplace_fields=inplace_fields, + **kwargs, + ) + + def preprocess(self, results): + self.resizer.ctx = None + for i, seq in enumerate(results["sequence_fields"]): + # Create a mask where the distance from the center is less than H/2 + H, W = results[seq]["image"].shape[-2:] + x = torch.linspace(-W / 2 - 0.5, W / 2 + 0.5, W) + y = torch.linspace(-H / 2 - 0.5, H / 2 + 0.5, H) + xv, yv = torch.meshgrid(x, y, indexing="xy") + distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W) + results[seq]["validity_mask"] = distance_from_center < (H / 2) + 20 + return super().preprocess(results) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames + results["synthetic"] = [True] * self.num_frames + results["quality"] = [0] * self.num_frames + return results diff --git a/unidepth/datasets/base_dataset.py b/unidepth/datasets/base_dataset.py new file mode 100644 index 0000000..268fc6d --- /dev/null +++ b/unidepth/datasets/base_dataset.py @@ -0,0 +1,317 @@ +import os +from abc import abstractmethod +from math import ceil, log +from typing import Any, Dict, Tuple + +import numpy as np +import torch +from torch.utils.data import Dataset + +import unidepth.datasets.pipelines as pipelines +from unidepth.utils import (eval_3d, eval_depth, identity, is_main_process, + recursive_index, sync_tensor_across_gpus) +from unidepth.utils.constants import (IMAGENET_DATASET_MEAN, + IMAGENET_DATASET_STD, + OPENAI_DATASET_MEAN, OPENAI_DATASET_STD) + + +class BaseDataset(Dataset): + min_depth = 0.01 + max_depth = 1000.0 + + def __init__( + self, + image_shape: Tuple[int, int], + split_file: str, + test_mode: bool, + benchmark: bool, + normalize: bool, + 
augmentations_db: Dict[str, Any], + resize_method: str, + mini: float, + **kwargs, + ) -> None: + super().__init__() + assert normalize in [None, "imagenet", "openai"] + + self.split_file = split_file + self.test_mode = test_mode + self.data_root = os.environ["DATAROOT"] + self.image_shape = image_shape + self.resize_method = resize_method + self.mini = mini + self.num_frames = 1 + self.metrics_store = {} + self.metrics_count = {} + + if normalize == "imagenet": + self.normalization_stats = { + "mean": torch.tensor(IMAGENET_DATASET_MEAN), + "std": torch.tensor(IMAGENET_DATASET_STD), + } + elif normalize == "openai": + self.normalization_stats = { + "mean": torch.tensor(OPENAI_DATASET_MEAN), + "std": torch.tensor(OPENAI_DATASET_STD), + } + else: + self.normalization_stats = { + "mean": torch.tensor([0.0, 0.0, 0.0]), + "std": torch.tensor([1.0, 1.0, 1.0]), + } + + for k, v in augmentations_db.items(): + setattr(self, k, v) + if not self.test_mode: + self._augmentation_space() + + self.masker = pipelines.AnnotationMask( + min_value=0.0, + max_value=self.max_depth if test_mode else None, + custom_fn=identity, + ) + self.filler = pipelines.RandomFiller(noise_pad=True) + + shape_mult = self.shape_constraints["shape_mult"] + self.image_shape = [ + ceil(self.image_shape[0] / shape_mult) * shape_mult, + ceil(self.image_shape[1] / shape_mult) * shape_mult, + ] + self.resizer = pipelines.ContextCrop( + image_shape=self.image_shape, + train_ctx_range=(1.0 / self.random_scale, 1.0 * self.random_scale), + test_min_ctx=self.test_context, + keep_original=test_mode, + shape_constraints=self.shape_constraints, + ) + + self.collecter = pipelines.Collect( + keys=["image_fields", "mask_fields", "gt_fields", "camera_fields"] + ) + + def __len__(self): + return len(self.dataset) + + def pack_batch(self, results): + for fields_name in [ + "image_fields", + "gt_fields", + "mask_fields", + "camera_fields", + ]: + fields = results.get(fields_name) + packed = { + field: torch.cat( + [results[seq][field] for seq in results["sequence_fields"]] + ) + for field in fields + } + results.update(packed) + return results + + def unpack_batch(self, results): + for fields_name in [ + "image_fields", + "gt_fields", + "mask_fields", + "camera_fields", + ]: + fields = results.get(fields_name) + unpacked = { + field: { + seq: results[field][idx : idx + 1] + for idx, seq in enumerate(results["sequence_fields"]) + } + for field in fields + } + results.update(unpacked) + return results + + def _augmentation_space(self): + self.augmentations_dict = { + "Flip": pipelines.RandomFlip(prob=self.flip_p), + "Jitter": pipelines.RandomColorJitter( + (-self.random_jitter, self.random_jitter), prob=self.jitter_p + ), + "Gamma": pipelines.RandomGamma( + (-self.random_gamma, self.random_gamma), prob=self.gamma_p + ), + "Blur": pipelines.GaussianBlur( + kernel_size=13, sigma=(0.1, self.random_blur), prob=self.blur_p + ), + "Grayscale": pipelines.RandomGrayscale(prob=self.grayscale_p), + } + + def augment(self, results): + for name, aug in self.augmentations_dict.items(): + results = aug(results) + return results + + def prepare_depth_eval(self, inputs, preds): + new_preds = {} + keyframe_idx = getattr(self, "keyframe_idx", None) + slice_idx = slice( + keyframe_idx, keyframe_idx + 1 if keyframe_idx is not None else None + ) + new_gts = inputs["depth"][slice_idx] + new_masks = inputs["depth_mask"][slice_idx].bool() + for key, val in preds.items(): + if "depth" in key: + new_preds[key] = val[slice_idx] + return new_gts, new_preds, new_masks + + def 
prepare_points_eval(self, inputs, preds): + new_preds = {} + new_gts = inputs["points"] + new_masks = inputs["depth_mask"].bool() + for key, val in preds.items(): + if "points" in key: + new_preds[key] = val + return new_gts, new_preds, new_masks + + def add_points(self, inputs): + inputs["points"] = inputs.get("camera_original", inputs["camera"]).reconstruct( + inputs["depth"] + ) + return inputs + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def accumulate_metrics( + self, + inputs, + preds, + keyframe_idx=None, + metrics=["depth", "points", "flow_fwd", "pairwise"], + ): + if "depth" in inputs and "points" not in inputs: + inputs = self.add_points(inputs) + + available_metrics = [] + for metric in metrics: + metric_in_gt = any((metric in k for k in inputs.keys())) + metric_in_pred = any((metric in k for k in preds.keys())) + if metric_in_gt and metric_in_pred: + available_metrics.append(metric) + + if keyframe_idx is not None: + inputs = recursive_index(inputs, slice(keyframe_idx, keyframe_idx + 1)) + preds = recursive_index(preds, slice(keyframe_idx, keyframe_idx + 1)) + + if "depth" in available_metrics: + depth_gt, depth_pred, depth_masks = self.prepare_depth_eval(inputs, preds) + self.accumulate_metrics_depth(depth_gt, depth_pred, depth_masks) + + if "points" in available_metrics: + points_gt, points_pred, points_masks = self.prepare_points_eval( + inputs, preds + ) + self.accumulate_metrics_3d(points_gt, points_pred, points_masks) + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def accumulate_metrics_depth(self, gts, preds, masks): + for eval_type, pred in preds.items(): + log_name = eval_type.replace("depth", "").strip("-").strip("_") + if log_name not in self.metrics_store: + self.metrics_store[log_name] = {} + current_count = self.metrics_count.get( + log_name, torch.tensor([], device=gts.device) + ) + new_count = masks.view(gts.shape[0], -1).sum(dim=-1) + self.metrics_count[log_name] = torch.cat([current_count, new_count]) + for k, v in eval_depth(gts, pred, masks, max_depth=self.max_depth).items(): + current_metric = self.metrics_store[log_name].get( + k, torch.tensor([], device=gts.device) + ) + self.metrics_store[log_name][k] = torch.cat([current_metric, v]) + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def accumulate_metrics_3d(self, gts, preds, masks): + thresholds = torch.linspace( + log(self.min_depth), + log(self.max_depth / 20), + steps=100, + device=gts.device, + ).exp() + for eval_type, pred in preds.items(): + log_name = eval_type.replace("points", "").strip("-").strip("_") + if log_name not in self.metrics_store: + self.metrics_store[log_name] = {} + current_count = self.metrics_count.get( + log_name, torch.tensor([], device=gts.device) + ) + new_count = masks.view(gts.shape[0], -1).sum(dim=-1) + self.metrics_count[log_name] = torch.cat([current_count, new_count]) + for k, v in eval_3d(gts, pred, masks, thresholds=thresholds).items(): + current_metric = self.metrics_store[log_name].get( + k, torch.tensor([], device=gts.device) + ) + self.metrics_store[log_name][k] = torch.cat([current_metric, v]) + + def get_evaluation(self, metrics=None): + metric_vals = {} + for eval_type in metrics if metrics is not None else self.metrics_store.keys(): + assert self.metrics_store[eval_type] + cnts = sync_tensor_across_gpus(self.metrics_count[eval_type]) + for name, val in self.metrics_store[eval_type].items(): + # vals_r = (sync_tensor_across_gpus(val) * cnts / cnts.sum()).sum() + vals_r 
= sync_tensor_across_gpus(val).mean() + metric_vals[f"{eval_type}_{name}".strip("_")] = np.round( + vals_r.cpu().item(), 5 + ) + self.metrics_store[eval_type] = {} + self.metrics_count = {} + return metric_vals + + def log_load_dataset(self): + if is_main_process(): + info = f"Loaded {self.__class__.__name__} with {len(self)} images." + print(info) + + def pre_pipeline(self, results): + results["image_fields"] = results.get("image_fields", set()) + results["gt_fields"] = results.get("gt_fields", set()) + results["mask_fields"] = results.get("mask_fields", set()) + results["sequence_fields"] = results.get("sequence_fields", set()) + results["camera_fields"] = results.get("camera_fields", set()) + results["dataset_name"] = [self.__class__.__name__] * self.num_frames + results["depth_scale"] = [self.depth_scale] * self.num_frames + results["si"] = [False] * self.num_frames + results["synthetic"] = [False] * self.num_frames + results["valid_camera"] = [True] * self.num_frames + results["valid_pose"] = [True] * self.num_frames + return results + + def eval_mask(self, valid_mask): + return valid_mask + + @abstractmethod + def preprocess(self, results): + raise NotImplementedError + + @abstractmethod + def postprocess(self, results): + raise NotImplementedError + + @abstractmethod + def get_mapper(self): + raise NotImplementedError + + @abstractmethod + def get_intrinsics(self, idx, image_name): + raise NotImplementedError + + @abstractmethod + def get_extrinsics(self, idx, image_name): + raise NotImplementedError + + @abstractmethod + def load_dataset(self): + raise NotImplementedError + + @abstractmethod + def get_single_item(self, idx, sample=None, mapper=None): + raise NotImplementedError + + @abstractmethod + def __getitem__(self, idx): + raise NotImplementedError diff --git a/unidepth/datasets/hm3d.py b/unidepth/datasets/hm3d.py new file mode 100644 index 0000000..81a8031 --- /dev/null +++ b/unidepth/datasets/hm3d.py @@ -0,0 +1,49 @@ +from typing import Any + +from unidepth.datasets.sequence_dataset import SequenceDataset + + +class HM3D(SequenceDataset): + min_depth = 0.01 + max_depth = 10.0 + depth_scale = 1000.0 + test_split = "val.txt" + train_split = "full.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"HM3D.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames + results["quality"] = [2] * self.num_frames + return results diff --git a/unidepth/datasets/ibims.py b/unidepth/datasets/ibims.py new file mode 100644 index 0000000..8050b62 --- /dev/null +++ b/unidepth/datasets/ibims.py @@ -0,0 +1,123 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unidepth.datasets.image_dataset import ImageDataset +from unidepth.datasets.sequence_dataset import SequenceDataset +from unidepth.datasets.utils import 
DatasetFromList + + +class IBims(ImageDataset): + min_depth = 0.005 + max_depth = 25.0 + depth_scale = 1000.0 + train_split = "ibims_val.txt" + test_split = "ibims_val.txt" + intrisics_file = "ibims_intrinsics.json" + hdf5_paths = ["ibims.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.crop = crop + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val] + dataset.append(sample) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] + results["quality"] = [1] + return results + + +class IBims_F(SequenceDataset): + min_depth = 0.01 + max_depth = 25.0 + depth_scale = 1000.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["IBims-F.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, float], + resize_method: str, + mini: float, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth", "points"], + inplace_fields: list[str] = ["camera_params"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames + results["quality"] = [1] * self.num_frames + return results diff --git a/unidepth/datasets/image_dataset.py b/unidepth/datasets/image_dataset.py new file mode 100644 index 0000000..505ef0d --- /dev/null +++ b/unidepth/datasets/image_dataset.py @@ -0,0 +1,176 @@ +import io +import os +from time import time +from typing import Any, Dict, List, Tuple + +import numpy as np +import tables +import torch +import torchvision +import torchvision.transforms.v2.functional as TF +from PIL import Image + +from unidepth.datasets.base_dataset import BaseDataset +from unidepth.utils import is_main_process +from unidepth.utils.camera import BatchCamera, Pinhole + +""" +Awful class for legacy reasons, we assume only pinhole cameras +And we "fake" sequences by setting sequence_fields to [(0, 0)] and cam2w as eye(4) +""" + + +class ImageDataset(BaseDataset): + def 
__init__( + self, + image_shape: Tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: Dict[str, Any], + resize_method: str, + mini: float, + benchmark: bool = False, + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.mapper = self.get_mapper() + + def get_single_item(self, idx, sample=None, mapper=None): + sample = self.dataset[idx] if sample is None else sample + mapper = self.mapper if mapper is None else mapper + + results = {} + results = self.pre_pipeline(results) + results["sequence_fields"] = [(0, 0)] + + chunk_idx = ( + int(sample[self.mapper["chunk_idx"]]) if "chunk_idx" in self.mapper else 0 + ) + h5_path = os.path.join(self.data_root, self.hdf5_paths[chunk_idx]) + with tables.File( + h5_path, + mode="r", + libver="latest", + swmr=True, + ) as h5file_chunk: + for key_mapper, idx_mapper in mapper.items(): + if "image" not in key_mapper and "depth" not in key_mapper: + continue + value = sample[idx_mapper] + results[key_mapper] = value + name = key_mapper.replace("_filename", "") + value_root = "/" + value + + if "image" in key_mapper: + results["filename"] = value + file = h5file_chunk.get_node(value_root).read() + image = ( + torchvision.io.decode_image(torch.from_numpy(file)) + .to(torch.uint8) + .squeeze() + ) + results["image_fields"].add(name) + results[f"image_ori_shape"] = image.shape[-2:] + results[name] = image[None, ...] + + # collect camera information for the given image + name = name.replace("image_", "") + results["camera_fields"].update({"camera", "cam2w"}) + K = self.get_intrinsics(idx, value) + if K is None: + K = torch.eye(3) + K[0, 0] = K[1, 1] = 0.7 * self.image_shape[1] + K[0, 2] = 0.5 * self.image_shape[1] + K[1, 2] = 0.5 * self.image_shape[0] + + camera = Pinhole(K=K[None, ...].clone()) + results["camera"] = BatchCamera.from_camera(camera) + results["cam2w"] = self.get_extrinsics(idx, value)[None, ...] + + elif "depth" in key_mapper: + # start = time() + file = h5file_chunk.get_node(value_root).read() + depth = Image.open(io.BytesIO(file)) + depth = TF.pil_to_tensor(depth).squeeze().to(torch.float32) + if depth.ndim == 3: + depth = depth[2] + depth[1] * 255 + depth[0] * 255 * 255 + + results["gt_fields"].add(name) + results[f"depth_ori_shape"] = depth.shape + + depth = ( + depth.view(1, 1, *depth.shape).contiguous() / self.depth_scale + ) + results[name] = depth + + results = self.preprocess(results) + if not self.test_mode: + results = self.augment(results) + results = self.postprocess(results) + return results + + def preprocess(self, results): + self.resizer.ctx = None + results = self.resizer(results) + + num_pts = torch.count_nonzero(results["depth"] > self.min_depth) + if num_pts < 50: + raise IndexError(f"Too few points in depth map ({num_pts})") + + for key in results.get("image_fields", ["image"]): + results[key] = results[key].to(torch.float32) / 255 + return results + + def postprocess(self, results): + # normalize after because color aug requires [0,255]? 
+ for key in results.get("image_fields", ["image"]): + results[key] = TF.normalize(results[key], **self.normalization_stats) + results = self.filler(results) + results = self.unpack_batch(results) + results = self.masker(results) + results = self.collecter(results) + return results + + def __getitem__(self, idx): + try: + if isinstance(idx, (list, tuple)): + results = [self.get_single_item(i) for i in idx] + else: + results = self.get_single_item(idx) + except Exception as e: + print(f"Error loading sequence {idx} for {self.__class__.__name__}: {e}") + idx = np.random.randint(0, len(self.dataset)) + results = self[idx] + return results + + def get_intrinsics(self, idx, image_name): + idx_sample = self.mapper.get("K", 1000) + sample = self.dataset[idx] + if idx_sample >= len(sample): + return None + return sample[idx_sample] + + def get_extrinsics(self, idx, image_name): + idx_sample = self.mapper.get("cam2w", 1000) + sample = self.dataset[idx] + if idx_sample >= len(sample): + return torch.eye(4) + return sample[idx_sample] + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + "K": 2, + } diff --git a/unidepth/datasets/kitti.py b/unidepth/datasets/kitti.py new file mode 100644 index 0000000..98f6afc --- /dev/null +++ b/unidepth/datasets/kitti.py @@ -0,0 +1,150 @@ +import os + +import h5py +import numpy as np +import torch + +from unidepth.datasets.image_dataset import ImageDataset +from unidepth.datasets.pipelines import AnnotationMask, KittiCrop +from unidepth.datasets.utils import DatasetFromList +from unidepth.utils import identity + + +class KITTI(ImageDataset): + CAM_INTRINSIC = { + "2011_09_26": torch.tensor( + [ + [7.215377e02, 0.000000e00, 6.095593e02, 4.485728e01], + [0.000000e00, 7.215377e02, 1.728540e02, 2.163791e-01], + [0.000000e00, 0.000000e00, 1.000000e00, 2.745884e-03], + ] + ), + "2011_09_28": torch.tensor( + [ + [7.070493e02, 0.000000e00, 6.040814e02, 4.575831e01], + [0.000000e00, 7.070493e02, 1.805066e02, -3.454157e-01], + [0.000000e00, 0.000000e00, 1.000000e00, 4.981016e-03], + ] + ), + "2011_09_29": torch.tensor( + [ + [7.183351e02, 0.000000e00, 6.003891e02, 4.450382e01], + [0.000000e00, 7.183351e02, 1.815122e02, -5.951107e-01], + [0.000000e00, 0.000000e00, 1.000000e00, 2.616315e-03], + ] + ), + "2011_09_30": torch.tensor( + [ + [7.070912e02, 0.000000e00, 6.018873e02, 4.688783e01], + [0.000000e00, 7.070912e02, 1.831104e02, 1.178601e-01], + [0.000000e00, 0.000000e00, 1.000000e00, 6.203223e-03], + ] + ), + "2011_10_03": torch.tensor( + [ + [7.188560e02, 0.000000e00, 6.071928e02, 4.538225e01], + [0.000000e00, 7.188560e02, 1.852157e02, -1.130887e-01], + [0.000000e00, 0.000000e00, 1.000000e00, 3.779761e-03], + ] + ), + } + min_depth = 0.05 + max_depth = 80.0 + depth_scale = 256.0 + test_split = "kitti_eigen_test.txt" + train_split = "kitti_eigen_train.txt" + test_split_benchmark = "kitti_test.txt" + hdf5_paths = ["kitti.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + self.crop = crop + self.cropper_base = KittiCrop(crop_size=(352, 1216)) + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + 
os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename = line.strip().split(" ")[0] + depth_filename = line.strip().split(" ")[1] + if depth_filename == "None": + continue + sample = [ + image_filename, + depth_filename, + ] + dataset.append(sample) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def get_intrinsics(self, idx, image_name): + return self.CAM_INTRINSIC[image_name.split("/")[0]][:, :3].clone() + + def preprocess(self, results): + results = self.cropper_base(results) + results = self.resizer(results) + for key in results.get("image_fields", ["image"]): + results[key] = results[key].to(torch.float32) / 255 + return results + + def eval_mask(self, valid_mask, info={}): + """Do grag_crop or eigen_crop for testing""" + mask_height, mask_width = valid_mask.shape[-2:] + eval_mask = torch.zeros_like(valid_mask) + if "garg" in self.crop: + eval_mask[ + ..., + int(0.40810811 * mask_height) : int(0.99189189 * mask_height), + int(0.03594771 * mask_width) : int(0.96405229 * mask_width), + ] = 1 + elif "eigen" in self.crop: + eval_mask[ + ..., + int(0.3324324 * mask_height) : int(0.91351351 * mask_height), + int(0.03594771 * mask_width) : int(0.96405229 * mask_width), + ] = 1 + return torch.logical_and(valid_mask, eval_mask) + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + } + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [False] + return results diff --git a/unidepth/datasets/nyuv2.py b/unidepth/datasets/nyuv2.py new file mode 100644 index 0000000..4ee5c24 --- /dev/null +++ b/unidepth/datasets/nyuv2.py @@ -0,0 +1,94 @@ +import os + +import h5py +import numpy as np +import torch + +from unidepth.datasets.image_dataset import ImageDataset +from unidepth.datasets.utils import DatasetFromList + + +class NYUv2Depth(ImageDataset): + CAM_INTRINSIC = { + "ALL": torch.tensor( + [ + [5.1885790117450188e02, 0, 3.2558244941119034e02], + [0, 5.1946961112127485e02, 2.5373616633400465e02], + [0, 0, 1], + ] + ) + } + min_depth = 0.005 + max_depth = 10.0 + depth_scale = 1000.0 + test_split = "nyu_test.txt" + train_split = "nyu_train.txt" + hdf5_paths = ["nyuv2.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename, _ = line.strip().split(" ") + sample = [ + image_filename, + depth_filename, + ] + dataset.append(sample) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + 
results["dense"] = [True] + return results + + def get_intrinsics(self, idx, image_name): + return self.CAM_INTRINSIC["ALL"].clone() + + def eval_mask(self, valid_mask, info={}): + border_mask = torch.zeros_like(valid_mask) + border_mask[..., 45:-9, 41:-39] = 1 + return torch.logical_and(valid_mask, border_mask) + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + } diff --git a/unidepth/datasets/pipelines/__init__.py b/unidepth/datasets/pipelines/__init__.py new file mode 100644 index 0000000..d7c09b3 --- /dev/null +++ b/unidepth/datasets/pipelines/__init__.py @@ -0,0 +1,9 @@ +from .formating import AnnotationMask, Collect +from .transforms import (Compose, ContextCrop, Crop, GaussianBlur, KittiCrop, + PanoCrop, PanoRoll, RandomAutoContrast, + RandomBrightness, RandomColor, RandomColorJitter, + RandomContrast, RandomEqualize, RandomFiller, + RandomFlip, RandomGamma, RandomGrayscale, + RandomInvert, RandomMasking, RandomPosterize, + RandomSaturation, RandomSharpness, RandomShear, + RandomSolarize, RandomTranslate, Rotate) diff --git a/unidepth/datasets/pipelines/formating.py b/unidepth/datasets/pipelines/formating.py new file mode 100644 index 0000000..51b000b --- /dev/null +++ b/unidepth/datasets/pipelines/formating.py @@ -0,0 +1,95 @@ +from collections.abc import Sequence + +import numpy as np +import torch + + +class Collect(object): + def __init__( + self, + keys, + meta_keys=( + "filename", + "keyframe_idx", + "sequence_name", + "image_filename", + "depth_filename", + "image_ori_shape", + "camera", + "original_camera", + "sfm", + "image_shape", + "resized_shape", + "scale_factor", + "rotation", + "resize_factor", + "flip", + "flip_direction", + "dataset_name", + "paddings", + "max_value", + "log_mean", + "log_std", + "image_rescale", + "focal_rescale", + "depth_rescale", + ), + ): + self.keys = keys + self.meta_keys = meta_keys + + def __call__(self, results): + data_keys = [key for field in self.keys for key in results.get(field, [])] + data = { + key: { + sequence_key: results[key][sequence_key] + for sequence_key in results["sequence_fields"] + } + for key in data_keys + } + data["img_metas"] = { + key: value for key, value in results.items() if key not in data_keys + } + return data + + def __repr__(self): + return ( + self.__class__.__name__ + f"(keys={self.keys}, meta_keys={self.meta_keys})" + ) + + +class AnnotationMask(object): + def __init__(self, min_value, max_value, custom_fn=lambda x: x): + self.min_value = min_value + self.max_value = max_value + self.custom_fn = custom_fn + + def __call__(self, results): + for key in results.get("gt_fields", []): + if key + "_mask" in results["mask_fields"]: + if "flow" in key: + for sequence_idx in results.get("sequence_fields", []): + boundaries = (results[key][sequence_idx] >= -1) & ( + results[key][sequence_idx] <= 1 + ) + boundaries = boundaries[:, :1] & boundaries[:, 1:] + results[key + "_mask"][sequence_idx] = ( + results[key + "_mask"][sequence_idx] & boundaries + ) + continue + for sequence_idx in results.get("sequence_fields", []): + mask = results[key][sequence_idx] > self.min_value + if self.max_value is not None: + mask = mask & (results[key][sequence_idx] < self.max_value) + mask = self.custom_fn(mask, info=results) + if key + "_mask" not in results: + results[key + "_mask"] = {} + results[key + "_mask"][sequence_idx] = mask.bool() + results["mask_fields"].add(key + "_mask") + return results + + def __repr__(self): + return ( + self.__class__.__name__ + + f"(min_value={self.min_value}, 
max_value={ self.max_value})" + ) diff --git a/unidepth/datasets/pipelines/transforms.py b/unidepth/datasets/pipelines/transforms.py new file mode 100644 index 0000000..29750c0 --- /dev/null +++ b/unidepth/datasets/pipelines/transforms.py @@ -0,0 +1,1454 @@ +import os +import random +from copy import deepcopy +from math import ceil, exp, log, log2, log10, tanh +from typing import Dict, List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.v2.functional as TF + +from unidepth.utils.geometric import downsample + + +class PanoCrop: + def __init__(self, crop_v=0.1): + self.crop_v = crop_v + + def _crop_data(self, results, crop_size): + """Function to randomly crop images, bounding boxes, masks, semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + crop_size (tuple): Expected absolute size after cropping, (h, w). + allow_negative_crop (bool): Whether to allow a crop that does not + contain any bbox area. Default to False. + + Returns: + dict: Randomly cropped results, 'image_shape' key in result dict is + updated according to crop size. + """ + offset_w, offset_h = crop_size + left, top, right, bottom = offset_w[0], offset_h[0], offset_w[1], offset_h[1] + H, W = results["image"].shape[-2:] + for key in results.get("image_fields", ["image"]): + img = results[key][..., top : H - bottom, left : W - right] + results[key] = img + results["image_shape"] = tuple(img.shape) + + for key in results.get("gt_fields", []): + results[key] = results[key][..., top : H - bottom, left : W - right] + + for key in results.get("mask_fields", []): + results[key] = results[key][..., top : H - bottom, left : W - right] + + results["camera"].crop(left, top, right, bottom) + return results + + def __call__(self, results): + H, W = results["image"].shape[-2:] + crop_w = (0, 0) + crop_h = (int(H * self.crop_v), int(H * self.crop_v)) + results = self._crop_data(results, (crop_w, crop_h)) + return results + + +class PanoRoll: + def __init__(self, roll=[-0.5, 0.5]): + self.roll = roll + + def __call__(self, results): + W = results["image"].shape[-1] + roll = random.randint(int(W * self.roll[0]), int(W * self.roll[1])) + for key in results.get("image_fields", ["image"]): + img = results[key] + img = torch.roll(img, roll, dims=-1) + results[key] = img + for key in results.get("gt_fields", []): + results[key] = torch.roll(results[key], roll, dims=-1) + for key in results.get("mask_fields", []): + results[key] = torch.roll(results[key], roll, dims=-1) + return results + + +class RandomFlip: + """Flip the points & bbox. + + If the input dict contains the key "flip", then the flag will be used, + otherwise it will be randomly decided by a ratio specified in the init + method. + + Args: + flip_ratio_bev_horizontal (float, optional): The flipping probability + in horizontal direction. Defaults to 0.0. + flip_ratio_bev_vertical (float, optional): The flipping probability + in vertical direction. Defaults to 0.0. 
+ """ + + def __init__(self, direction="horizontal", prob=0.5, **kwargs): + self.flip_ratio = prob + valid_directions = ["horizontal", "vertical", "diagonal"] + if isinstance(direction, str): + assert direction in valid_directions + elif isinstance(direction, list): + assert set(direction).issubset(set(valid_directions)) + else: + raise ValueError("direction must be either str or list of str") + self.direction = direction + + def __call__(self, results): + """Call function to flip points, values in the ``bbox3d_fields`` and + also flip 2D image and its annotations. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Flipped results, 'flip', 'flip_direction', + """ + if "flip" not in results: + if isinstance(self.direction, list): + # None means non-flip + direction_list = self.direction + [None] + else: + # None means non-flip + direction_list = [self.direction, None] + + if isinstance(self.flip_ratio, list): + non_flip_ratio = 1 - sum(self.flip_ratio) + flip_ratio_list = self.flip_ratio + [non_flip_ratio] + else: + non_flip_ratio = 1 - self.flip_ratio + # exclude non-flip + single_ratio = self.flip_ratio / (len(direction_list) - 1) + flip_ratio_list = [single_ratio] * (len(direction_list) - 1) + [ + non_flip_ratio + ] + + cur_dir = np.random.choice(direction_list, p=flip_ratio_list) + + results["flip"] = cur_dir is not None + + if "flip_direction" not in results: + results["flip_direction"] = cur_dir + + if results["flip"]: + # flip image + if results["flip_direction"] != "vertical": + for key in results.get("image_fields", ["image"]): + results[key] = TF.hflip(results[key]) + for key in results.get("mask_fields", []): + results[key] = TF.hflip(results[key]) + for key in results.get("gt_fields", []): + results[key] = TF.hflip(results[key]) + if "flow" in key: # flip u direction + results[key][:, 0] = -results[key][:, 0] + + H, W = results["image"].shape[-2:] + results["camera"] = results["camera"].flip( + H=H, W=W, direction="horizontal" + ) + # results["K"][..., 0, 2] = results["image"].shape[-1] - results["K"][..., 0, 2] + # flip: - t_x rotate around y by: pi - angle_y * 2 + flip_transform = torch.tensor( + [[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], + dtype=torch.float32, + ).unsqueeze(0) + repeats = (results["cam2w"].shape[0],) + (1,) * ( + results["cam2w"].ndim - 1 + ) + results["cam2w"] = flip_transform.repeat(*repeats) @ results["cam2w"] + + if results["flip_direction"] != "horizontal": + for key in results.get("image_fields", ["image"]): + results[key] = TF.vflip(results[key]) + for key in results.get("mask_fields", []): + results[key] = TF.vflip(results[key]) + for key in results.get("gt_fields", []): + results[key] = TF.vflip(results[key]) + results["K"][..., 1, 2] = ( + results["image"].shape[-2] - results["K"][..., 1, 2] + ) + results["flip"] = [results["flip"]] * len(results["image"]) + return results + + def __repr__(self): + """str: Return a string that describes the module.""" + repr_str = self.__class__.__name__ + repr_str += f" flip_ratio={self.flip_ratio})" + return repr_str + + +class Crop: + def __init__( + self, + crop_size, + crop_type="absolute", + crop_offset=(0, 0), + ): + if crop_type not in [ + "relative_range", + "relative", + "absolute", + "absolute_range", + ]: + raise ValueError(f"Invalid crop_type {crop_type}.") + if crop_type in ["absolute", "absolute_range"]: + assert crop_size[0] > 0 and crop_size[1] > 0 + assert isinstance(crop_size[0], int) and isinstance(crop_size[1], int) + else: + assert 0 < 
crop_size[0] <= 1 and 0 < crop_size[1] <= 1 + self.crop_size = crop_size + self.crop_type = crop_type + self.offset_h, self.offset_w = ( + crop_offset[: len(crop_offset) // 2], + crop_offset[len(crop_offset) // 2 :], + ) + + def _get_crop_size(self, image_shape): + h, w = image_shape + if self.crop_type == "absolute": + return (min(self.crop_size[0], h), min(self.crop_size[1], w)) + elif self.crop_type == "absolute_range": + assert self.crop_size[0] <= self.crop_size[1] + crop_h = np.random.randint( + min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1 + ) + crop_w = np.random.randint( + min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1 + ) + return crop_h, crop_w + elif self.crop_type == "relative": + crop_h, crop_w = self.crop_size + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + elif self.crop_type == "relative_range": + crop_size = np.asarray(self.crop_size, dtype=np.float32) + crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size) + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + + def _crop_data(self, results, crop_size): + assert crop_size[0] > 0 and crop_size[1] > 0 + for key in results.get("image_fields", ["image"]): + img = results[key] + img = TF.crop( + img, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1] + ) + results[key] = img + results["image_shape"] = tuple(img.shape) + + for key in results.get("gt_fields", []): + gt = results[key] + results[key] = TF.crop( + gt, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1] + ) + + # crop semantic seg + for key in results.get("mask_fields", []): + mask = results[key] + results[key] = TF.crop( + mask, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1] + ) + + results["K"][..., 0, 2] = results["K"][..., 0, 2] - self.offset_w[0] + results["K"][..., 1, 2] = results["K"][..., 1, 2] - self.offset_h[0] + return results + + def __call__(self, results): + image_shape = results["image"].shape[-2:] + crop_size = self._get_crop_size(image_shape) + results = self._crop_data(results, crop_size) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(crop_size={self.crop_size}, " + repr_str += f"crop_type={self.crop_type}, " + return repr_str + + +class KittiCrop: + def __init__(self, crop_size): + self.crop_size = crop_size + + def _crop_data(self, results, crop_size): + """Function to randomly crop images, bounding boxes, masks, semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + crop_size (tuple): Expected absolute size after cropping, (h, w). + allow_negative_crop (bool): Whether to allow a crop that does not + contain any bbox area. Default to False. + + Returns: + dict: Randomly cropped results, 'image_shape' key in result dict is + updated according to crop size. 
+ """ + assert crop_size[0] > 0 and crop_size[1] > 0 + for key in results.get("image_fields", ["image"]): + img = results[key] + h, w = img.shape[-2:] + offset_h, offset_w = int(h - self.crop_size[0]), int( + (w - self.crop_size[1]) / 2 + ) + + # crop the image + img = TF.crop(img, offset_h, offset_w, crop_size[0], crop_size[1]) + results[key] = img + results["image_shape"] = tuple(img.shape) + + for key in results.get("gt_fields", []): + gt = results[key] + results[key] = TF.crop(gt, offset_h, offset_w, crop_size[0], crop_size[1]) + + # crop semantic seg + for key in results.get("mask_fields", []): + mask = results[key] + results[key] = TF.crop(mask, offset_h, offset_w, crop_size[0], crop_size[1]) + + results["camera"].crop(offset_w, offset_h) + return results + + def __call__(self, results): + """Call function to randomly crop images, bounding boxes, masks, + semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'image_shape' key in result dict is + updated according to crop size. + """ + results = self._crop_data(results, self.crop_size) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(crop_size={self.crop_size}, " + return repr_str + + +class RandomMasking: + def __init__( + self, + mask_ratio, + mask_patch=16, + prob=0.5, + warmup_steps=50000, + sampling="random", + curriculum=False, + ): + self.mask_patch = mask_patch + self.prob = prob + self.mask_ratio = mask_ratio + self.warmup_steps = max(1, warmup_steps) + self.hard_bound = 1 + self.idx = 0 + self.curriculum = curriculum + self.sampling = sampling + self.low_bound = 0.0 + self.up_bound = 0.0 + + def __call__(self, results): + B, _, H, W = results["image"].shape + device = results["image"].device + down_size = H // self.mask_patch, W // self.mask_patch + if np.random.random() > self.prob: # fill with dummy + return self._nop(results, down_size, device) + + validity_mask = results["validity_mask"].float().reshape(B, -1, H, W) + validity_mask = F.interpolate(validity_mask, size=down_size).bool() + validity_mask = validity_mask.reshape(B, 1, *down_size) + is_random = self.is_warmup or results.get("guidance") is None + + if not is_random: + guidance = F.interpolate(results["guidance"], size=(H, W), mode="bilinear") + results["guidance"] = -F.max_pool2d( + -guidance, kernel_size=self.mask_patch, stride=self.mask_patch + ) + + if is_random and self.sampling == "inverse": + sampling = self.inverse_sampling + elif is_random and self.sampling == "random": + sampling = self.random_sampling + else: + sampling = self.guided_sampling + mask_ratio = np.random.uniform(self.low_bound, self.up_bound) + for key in results.get("image_fields", ["image"]): + mask = sampling(results, mask_ratio, down_size, validity_mask, device) + results[key + "_mask"] = mask + return results + + def _nop(self, results, down_size, device): + B = results["image"].shape[0] + for key in results.get("image_fields", ["image"]): + mask_blocks = torch.zeros(size=(B, 1, *down_size), device=device) + results[key + "_mask"] = mask_blocks + return results + + def random_sampling(self, results, mask_ratio, down_size, validity_mask, device): + B = results["image"].shape[0] + prob_blocks = torch.rand(size=(B, 1, *down_size), device=device) + mask_blocks = torch.logical_and(prob_blocks < mask_ratio, validity_mask) + return mask_blocks + + def inverse_sampling(self, results, mask_ratio, down_size, validity_mask, device): + # from PIL import Image + # 
from unidepth.utils import colorize + def area_sample(depth, fx, fy): + dtype = depth.dtype + B = depth.shape[0] + H, W = down_size + depth = downsample(depth, depth.shape[-2] // H) + depth[depth > 200] = 50 # set sky as if depth 50 meters + pixel_area3d = depth / torch.sqrt(fx * fy) + + # Set invalid as -1 (no div problem) -> then clip to 0.0 + pixel_area3d[depth == 0.0] = -1 + prob_density = (1 / pixel_area3d).clamp(min=0.0).square() + prob_density = prob_density / prob_density.sum( + dim=(-1, -2), keepdim=True + ).clamp(min=1e-5) + # Image.fromarray((prob_density[0] * 255 * 100).clamp(min=0.0, max=255.0).squeeze().cpu().byte().numpy()).save("prob_density.png") + + # Sample locations based on prob_density + prob_density_flat = prob_density.view(B, -1) + + # Get the avgerage valid locations, of those we mask self.mask_ratio + valid_locations = (prob_density_flat > 0).to(dtype).sum(dim=1) + + masks = [] + for i in range(B): + num_samples = int(valid_locations[i] * mask_ratio) + mask = torch.zeros_like(prob_density_flat[i]) + # Sample indices + if num_samples > 0: + sampled_indices_flat = torch.multinomial( + prob_density_flat[i], num_samples, replacement=False + ) + mask.scatter_(0, sampled_indices_flat, 1) + masks.append(mask) + return torch.stack(masks).bool().view(B, 1, H, W) + + def random_sample(validity_mask): + prob_blocks = torch.rand( + size=(validity_mask.shape[0], 1, *down_size), device=device + ) + mask = torch.logical_and(prob_blocks < mask_ratio, validity_mask) + return mask + + fx = results["K"][..., 0, 0].view(-1, 1, 1, 1) / self.mask_patch + fy = results["K"][..., 1, 1].view(-1, 1, 1, 1) / self.mask_patch + + valid = ~results["ssi"] & ~results["si"] & results["valid_camera"] + mask_blocks = torch.zeros_like(validity_mask) + if valid.any(): + out = area_sample(results["depth"][valid], fx[valid], fy[valid]) + mask_blocks[valid] = out + if (~valid).any(): + mask_blocks[~valid] = random_sample(validity_mask[~valid]) + + # mask_blocks_ = (mask_blocks.float() * 255).squeeze(1).byte().cpu().numpy() + # Image.fromarray(mask_blocks_[0]).save("mask1.png") + # Image.fromarray(mask_blocks_[-1]).save("mask2.png") + # dd = results["depth"] + # Image.fromarray(colorize(dd[0].squeeze().cpu().numpy())).save("depth1_p.png") + # Image.fromarray(colorize(dd[-1].squeeze().cpu().numpy())).save("depth2_p.png") + # dd = downsample(dd, dd.shape[-2] // down_size[0]) + # Image.fromarray(colorize(dd[0].squeeze().cpu().numpy())).save("depth1.png") + # Image.fromarray(colorize(dd[-1].squeeze().cpu().numpy())).save("depth2.png") + # raise ValueError + + return mask_blocks + + def guided_sampling(self, results, mask_ratio, down_size, validity_mask, device): + # get the lowest (based on guidance) "mask_ratio" quantile of the patches that are in validity mask + B = results["image"].shape[0] + guidance = results["guidance"] + mask_blocks = torch.zeros(size=(B, 1, *down_size), device=device) + for b in range(B): + low_bound = torch.quantile( + guidance[b][validity_mask[b]], max(0.0, self.hard_bound - mask_ratio) + ) + up_bound = torch.quantile( + guidance[b][validity_mask[b]], min(1.0, self.hard_bound) + ) + mask_blocks[b] = torch.logical_and( + guidance[b] < up_bound, guidance[b] > low_bound + ) + mask_blocks = torch.logical_and(mask_blocks, validity_mask) + return mask_blocks + + def step(self): + self.idx += 1 + # schedule hard from 1.0 to self.mask_ratio + if self.curriculum: + step = max(0, self.idx / self.warmup_steps / 2 - 0.5) + self.hard_bound = 1 - (1 - self.mask_ratio) * tanh(step) + self.up_bound 
= self.mask_ratio * tanh(step) + self.low_bound = 0.2 * tanh(step) + + @property + def is_warmup(self): + return self.idx < self.warmup_steps + + +class Rotate: + def __init__( + self, angle, center=None, img_fill_val=(123.68, 116.28, 103.53), prob=0.5 + ): + if isinstance(img_fill_val, (float, int)): + img_fill_val = tuple([float(img_fill_val)] * 3) + elif isinstance(img_fill_val, tuple): + assert len(img_fill_val) == 3, ( + "image_fill_val as tuple must " + f"have 3 elements. got {len(img_fill_val)}." + ) + img_fill_val = tuple([float(val) for val in img_fill_val]) + else: + raise ValueError("image_fill_val must be float or tuple with 3 elements.") + assert np.all( + [0 <= val <= 255 for val in img_fill_val] + ), f"all elements of img_fill_val should between range [0,255] got {img_fill_val}." + assert 0 <= prob <= 1.0, f"The probability should be in range [0,1]bgot {prob}." + self.center = center + self.img_fill_val = img_fill_val + self.prob = prob + self.random = not isinstance(angle, (float, int)) + self.angle = angle + + def _rotate(self, results, angle, center=None, fill_val=0.0): + for key in results.get("image_fields", ["image"]): + img = results[key] + img_rotated = TF.rotate( + img, + angle, + center=center, + interpolation=TF.InterpolationMode.NEAREST_EXACT, + fill=self.img_fill_val, + ) + results[key] = img_rotated.to(img.dtype) + results["image_shape"] = results[key].shape + + for key in results.get("mask_fields", []): + results[key] = TF.rotate( + results[key], + angle, + center=center, + interpolation=TF.InterpolationMode.NEAREST_EXACT, + fill=fill_val, + ) + + for key in results.get("gt_fields", []): + results[key] = TF.rotate( + results[key], + angle, + center=center, + interpolation=TF.InterpolationMode.NEAREST_EXACT, + fill=fill_val, + ) + + def __call__(self, results): + if np.random.random() > self.prob: + return results + + angle = ( + (self.angle[1] - self.angle[0]) * np.random.rand() + self.angle[0] + if self.random + else np.random.choice([-1, 1], size=1) * self.angle + ) + self._rotate(results, angle, None, fill_val=0.0) + results["rotation"] = angle + return results + + +class RandomColor: + def __init__(self, level, prob=0.5): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + + def _adjust_color_img(self, results, factor=1.0): + for key in results.get("image_fields", ["image"]): + results[key] = TF.adjust_hue(results[key], factor) # .to(img.dtype) + + def __call__(self, results): + if np.random.random() > self.prob: + return results + factor = ( + ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) + if self.random + else self.level + ) + self._adjust_color_img(results, factor) + return results + + +class RandomSaturation: + def __init__(self, level, prob=0.5): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + + def _adjust_saturation_img(self, results, factor=1.0): + for key in results.get("image_fields", ["image"]): + # NOTE defaultly the image should be BGR format + results[key] = TF.adjust_saturation(results[key], factor) # .to(img.dtype) + + def __call__(self, results): + if np.random.random() > self.prob: + return results + factor = ( + 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) + if self.random + else 2**self.level + ) + self._adjust_saturation_img(results, factor) + return results + + +class RandomSharpness: + def __init__(self, level, prob=0.5): + self.random = not isinstance(level, (float, int)) + self.level = level + 
        self.prob = prob
+
+    def _adjust_sharpness_img(self, results, factor=1.0):
+        for key in results.get("image_fields", ["image"]):
+            # NOTE: by default the image is expected to be in BGR format
+            results[key] = TF.adjust_sharpness(results[key], factor)  # .to(img.dtype)
+
+    def __call__(self, results):
+        if np.random.random() > self.prob:
+            return results
+        factor = (
+            2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
+            if self.random
+            else 2**self.level
+        )
+        self._adjust_sharpness_img(results, factor)
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f"(level={self.level}, "
+        repr_str += f"prob={self.prob})"
+        return repr_str
+
+
+class RandomSolarize:
+    def __init__(self, level, prob=0.5):
+        self.random = not isinstance(level, (float, int))
+        self.level = level
+        self.prob = prob
+
+    def _adjust_solarize_img(self, results, factor=255.0):
+        for key in results.get("image_fields", ["image"]):
+            results[key] = TF.solarize(results[key], factor)  # .to(img.dtype)
+
+    def __call__(self, results):
+        if np.random.random() > self.prob:
+            return results
+        factor = (
+            ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
+            if self.random
+            else self.level
+        )
+        self._adjust_solarize_img(results, factor)
+        return results
+
+
+class RandomPosterize:
+    def __init__(self, level, prob=0.5):
+        self.random = not isinstance(level, (float, int))
+        self.level = level
+        self.prob = prob
+
+    def _posterize_img(self, results, factor=1.0):
+        for key in results.get("image_fields", ["image"]):
+            results[key] = TF.posterize(results[key], int(factor))  # .to(img.dtype)
+
+    def __call__(self, results):
+        if np.random.random() > self.prob:
+            return results
+        factor = (
+            ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
+            if self.random
+            else self.level
+        )
+        self._posterize_img(results, factor)
+        return results
+
+
+class RandomEqualize:
+    def __init__(self, prob=0.5):
+        assert 0 <= prob <= 1.0, "The probability should be in range [0,1]."
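+        # Equalization, like the other photometric ops in this file, only touches entries in
+        # "image_fields"; depth, mask and other ground-truth fields are never modified.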
+ self.prob = prob + + def _imequalize(self, results): + for key in results.get("image_fields", ["image"]): + results[key] = TF.equalize(results[key]) # .to(img.dtype) + + def __call__(self, results): + if np.random.random() > self.prob: + return results + self._imequalize(results) + return results + + +class RandomBrightness: + def __init__(self, level, prob=0.5): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + + def _adjust_brightness_img(self, results, factor=1.0): + for key in results.get("image_fields", ["image"]): + results[key] = TF.adjust_brightness(results[key], factor) # .to(img.dtype) + + def __call__(self, results, level=None): + if np.random.random() > self.prob: + return results + factor = ( + 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) + if self.random + else 2**self.level + ) + self._adjust_brightness_img(results, factor) + return results + + +class RandomContrast: + def __init__(self, level, prob=0.5): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + + def _adjust_contrast_img(self, results, factor=1.0): + for key in results.get("image_fields", ["image"]): + results[key] = TF.adjust_contrast(results[key], factor) # .to(img.dtype) + + def __call__(self, results, level=None): + if np.random.random() > self.prob: + return results + factor = ( + 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) + if self.random + else 2**self.level + ) + self._adjust_contrast_img(results, factor) + return results + + +class RandomGamma: + def __init__(self, level, prob=0.5): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + + def __call__(self, results, level=None): + if np.random.random() > self.prob: + return results + factor = (self.level[1] - self.level[0]) * np.random.rand() + self.level[0] + for key in results.get("image_fields", ["image"]): + if "original" not in key: + results[key] = TF.adjust_gamma(results[key], 1 + factor) + return results + + +class RandomInvert: + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, results): + if np.random.random() > self.prob: + return results + for key in results.get("image_fields", ["image"]): + if "original" not in key: + results[key] = TF.invert(results[key]) # .to(img.dtype) + return results + + +class RandomAutoContrast: + def __init__(self, prob=0.5): + self.prob = prob + + def _autocontrast_img(self, results): + for key in results.get("image_fields", ["image"]): + img = results[key] + results[key] = TF.autocontrast(img) # .to(img.dtype) + + def __call__(self, results): + if np.random.random() > self.prob: + return results + self._autocontrast_img(results) + return results + + +class RandomShear(object): + def __init__( + self, + level, + prob=0.5, + direction="horizontal", + ): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + self.direction = direction + + def _shear_img(self, results, magnitude): + for key in results.get("image_fields", ["image"]): + img_sheared = TF.affine( + results[key], + angle=0.0, + translate=[0.0, 0.0], + scale=1.0, + shear=magnitude, + interpolation=TF.InterpolationMode.BILINEAR, + fill=0.0, + ) + results[key] = img_sheared + + def _shear_masks(self, results, magnitude): + for key in results.get("mask_fields", []): + mask_sheared = TF.affine( + results[key], + angle=0.0, + translate=[0.0, 0.0], + scale=1.0, + shear=magnitude, + 
                interpolation=TF.InterpolationMode.NEAREST_EXACT,
+                fill=0.0,
+            )
+            results[key] = mask_sheared
+
+    def _shear_gt(
+        self,
+        results,
+        magnitude,
+    ):
+        for key in results.get("gt_fields", []):
+            mask_sheared = TF.affine(
+                results[key],
+                angle=0.0,
+                translate=[0.0, 0.0],
+                scale=1.0,
+                shear=magnitude,
+                interpolation=TF.InterpolationMode.NEAREST_EXACT,
+                fill=0.0,
+            )
+            results[key] = mask_sheared
+
+    def __call__(self, results):
+        if np.random.random() > self.prob:
+            return results
+        magnitude = (
+            ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
+            if self.random
+            else np.random.choice([-1, 1], size=1) * self.level
+        )
+        if self.direction == "horizontal":
+            magnitude = [magnitude, 0.0]
+        else:
+            magnitude = [0.0, magnitude]
+        self._shear_img(results, magnitude)
+        self._shear_masks(results, magnitude)
+        self._shear_gt(results, magnitude)
+        return results
+
+
+class RandomTranslate(object):
+    def __init__(
+        self,
+        range,
+        prob=0.5,
+        direction="horizontal",
+    ):
+        self.range = range
+        self.prob = prob
+        self.direction = direction
+
+    def _translate_img(self, results, magnitude):
+        """Translate the image.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+            magnitude (list[int | float]): Pixel offsets [tx, ty] used for the
+                translation.
+        """
+        for key in results.get("image_fields", ["image"]):
+            img_translated = TF.affine(
+                results[key],
+                angle=0.0,
+                translate=magnitude,
+                scale=1.0,
+                shear=[0.0, 0.0],
+                interpolation=TF.InterpolationMode.BILINEAR,
+                fill=(123.68, 116.28, 103.53),
+            )
+            results[key] = img_translated
+
+    def _translate_mask(self, results, magnitude):
+        """Translate the masks."""
+        for key in results.get("mask_fields", []):
+            mask_translated = TF.affine(
+                results[key],
+                angle=0.0,
+                translate=magnitude,
+                scale=1.0,
+                shear=[0.0, 0.0],
+                interpolation=TF.InterpolationMode.NEAREST_EXACT,
+                fill=0.0,
+            )
+            results[key] = mask_translated
+
+    def _translate_gt(
+        self,
+        results,
+        magnitude,
+    ):
+        """Translate the ground-truth maps."""
+        for key in results.get("gt_fields", []):
+            gt_translated = TF.affine(
+                results[key],
+                angle=0.0,
+                translate=magnitude,
+                scale=1.0,
+                shear=[0.0, 0.0],
+                interpolation=TF.InterpolationMode.NEAREST_EXACT,
+                fill=0.0,
+            )
+            results[key] = gt_translated
+
+    def __call__(self, results):
+        """Call function to translate images, masks and ground-truth maps.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Translated results.
+ """ + if np.random.random() > self.prob: + return results + magnitude = (self.range[1] - self.range[0]) * np.random.rand() + self.range[0] + if self.direction == "horizontal": + magnitude = [magnitude * results["image"].shape[1], 0] + else: + magnitude = [0, magnitude * results["image"].shape[0]] + self._translate_img(results, magnitude) + self._translate_mask(results, magnitude) + self._translate_gt(results, magnitude) + results["K"][..., 0, 2] = results["K"][..., 0, 2] + magnitude[0] + results["K"][..., 1, 2] = results["K"][..., 1, 2] + magnitude[1] + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(range={self.range}, " + repr_str += f"prob={self.prob}, " + repr_str += f"direction={self.direction}, " + return repr_str + + +class RandomColorJitter: + def __init__(self, level, prob=0.9): + self.level = level + self.prob = prob + self.list_transform = [ + self._adjust_brightness_img, + # self._adjust_sharpness_img, + self._adjust_contrast_img, + self._adjust_saturation_img, + self._adjust_color_img, + ] + + def _adjust_contrast_img(self, results, factor=1.0): + """Adjust the image contrast.""" + for key in results.get("image_fields", ["image"]): + if "original" not in key: + img = results[key] + results[key] = TF.adjust_contrast(img, factor) + + def _adjust_sharpness_img(self, results, factor=1.0): + """Adjust the image contrast.""" + for key in results.get("image_fields", ["image"]): + if "original" not in key: + img = results[key] + results[key] = TF.adjust_sharpness(img, factor) + + def _adjust_brightness_img(self, results, factor=1.0): + """Adjust the brightness of image.""" + for key in results.get("image_fields", ["image"]): + if "original" not in key: + img = results[key] + results[key] = TF.adjust_brightness(img, factor) + + def _adjust_saturation_img(self, results, factor=1.0): + """Apply Color transformation to image.""" + for key in results.get("image_fields", ["image"]): + if "original" not in key: + img = results[key] + results[key] = TF.adjust_saturation(img, factor / 2.0) + + def _adjust_color_img(self, results, factor=1.0): + """Apply Color transformation to image.""" + for key in results.get("image_fields", ["image"]): + if "original" not in key: + img = results[key] + results[key] = TF.adjust_hue(img, (factor - 1.0) / 4.0) + + def __call__(self, results): + """Call function for color transformation. + Args: + results (dict): Results dict from loading pipeline. + + Returns: + dict: Results after the transformation. + """ + random.shuffle(self.list_transform) + for op in self.list_transform: + if np.random.random() < self.prob: + factor = 1.0 + ( + (self.level[1] - self.level[0]) * np.random.random() + self.level[0] + ) + op(results, factor) + return results + + +class RandomGrayscale: + def __init__(self, prob=0.1, num_output_channels=3): + super().__init__() + self.prob = prob + self.num_output_channels = num_output_channels + + def __call__(self, results): + if np.random.random() > self.prob: + return results + + for key in results.get("image_fields", ["image"]): + if "original" not in key: + results[key] = TF.rgb_to_grayscale( + results[key], num_output_channels=self.num_output_channels + ) + return results + + +def masked_nearest_interpolation(input, mask, target_size): + """ + Resize the depth map using bilinear interpolation, considering only valid pixels within NxN neighbors. + + Args: + depth (torch.Tensor): The depth map tensor of shape (H, W). 
+ mask (torch.Tensor): The mask tensor of shape (H, W) where 1 indicates valid depth and 0 indicates missing depth. + target_size (tuple): The desired output size (target_H, target_W). + + Returns: + torch.Tensor: The resized depth map. + """ + B, C, H, W = input.shape + target_H, target_W = target_size + mask = mask.float() + + # Generate a grid of coordinates in the target space + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, H - 1, target_H), + torch.linspace(0, W - 1, target_W), + indexing="ij", + ) + grid_y = grid_y.to(input.device) + grid_x = grid_x.to(input.device) + + # Calculate the floor and ceil of the grid coordinates to get the bounding box + x0 = torch.floor(grid_x).long().clamp(0, W - 1) + x1 = (x0 + 1).clamp(0, W - 1) + y0 = torch.floor(grid_y).long().clamp(0, H - 1) + y1 = (y0 + 1).clamp(0, H - 1) + + # Gather depth values at the four corners + Ia = input[..., y0, x0] + Ib = input[..., y1, x0] + Ic = input[..., y0, x1] + Id = input[..., y1, x1] + + # Gather corresponding mask values + ma = mask[..., y0, x0] + mb = mask[..., y1, x0] + mc = mask[..., y0, x1] + md = mask[..., y1, x1] + + # Calculate distances to each neighbor + # The distances are calculated from the center (grid_x, grid_y) to each corner + dist_a = (grid_x - x0.float()) ** 2 + (grid_y - y0.float()) ** 2 # Top-left + dist_b = (grid_x - x0.float()) ** 2 + (grid_y - y1.float()) ** 2 # Bottom-left + dist_c = (grid_x - x1.float()) ** 2 + (grid_y - y0.float()) ** 2 # Top-right + dist_d = (grid_x - x1.float()) ** 2 + (grid_y - y1.float()) ** 2 # Bottom-right + + # Stack the neighbors, their masks, and distances + stacked_values = torch.stack( + [Ia, Ib, Ic, Id], dim=-1 + ) # Shape: (B, C, target_H, target_W, 4) + stacked_masks = torch.stack( + [ma, mb, mc, md], dim=-1 + ) # Shape: (B, 1, target_H, target_W, 4) + stacked_distances = torch.stack( + [dist_a, dist_b, dist_c, dist_d], dim=-1 + ) # Shape: (target_H, target_W, 4) + stacked_distances = ( + stacked_distances.unsqueeze(0).unsqueeze(1).repeat(B, 1, 1, 1, 1) + ) # Shape: (B, 1, target_H, target_W, 4) + + # Set distances to infinity for invalid neighbors (so that invalid neighbors are never chosen) + stacked_distances[stacked_masks == 0] = float("inf") + + # Find the index of the nearest valid neighbor (the one with the smallest distance) + nearest_indices = stacked_distances.argmin(dim=-1, keepdim=True)[ + ..., :1 + ] # Shape: (B, 1, target_H, target_W, 1) + + # Select the corresponding depth value using the nearest valid neighbor index + interpolated_depth = torch.gather( + stacked_values, dim=-1, index=nearest_indices.repeat(1, C, 1, 1, 1) + ).squeeze(-1) + + # Set depth to zero where no valid neighbors were found + interpolated_depth = interpolated_depth * stacked_masks.sum(dim=-1).clip( + min=0.0, max=1.0 + ) + + return interpolated_depth + + +class ContextCrop: + def __init__( + self, + image_shape, + keep_original=False, + test_min_ctx=1.0, + train_ctx_range=[0.5, 1.5], + shape_constraints={}, + ): + self.image_shape = image_shape + self.keep_original = keep_original + self.test_min_ctx = test_min_ctx + self.train_ctx_range = train_ctx_range + self.shape_mult = shape_constraints["shape_mult"] + self.sample = shape_constraints["sample"] + self.ratio_bounds = shape_constraints["ratio_bounds"] + self.pixels = ( + shape_constraints["pixels_max"] + shape_constraints["pixels_min"] + ) / 2 + self.ctx = None + + def _transform_img(self, results, shapes): + for key in results.get("image_fields", ["image"]): + img = self.crop(results[key], **shapes) + img = 
TF.resize(
+                img,
+                results["resized_shape"],
+                interpolation=TF.InterpolationMode.BICUBIC,
+                antialias=True,
+            )
+            results[key] = img
+
+    def _transform_masks(self, results, shapes):
+        for key in results.get("mask_fields", []):
+            mask = self.crop(results[key].float(), **shapes).byte()
+            mask = masked_nearest_interpolation(
+                mask, mask > 0, results["resized_shape"]
+            )
+            results[key] = mask
+
+    def _transform_gt(self, results, shapes):
+        for key in results.get("gt_fields", []):
+            gt = self.crop(results[key], **shapes)
+            gt = masked_nearest_interpolation(gt, gt > 0, results["resized_shape"])
+            results[key] = gt
+
+    @staticmethod
+    def crop(img, height, width, top, left) -> torch.Tensor:
+        h, w = img.shape[-2:]
+        right = left + width
+        bottom = top + height
+        padding_ltrb = [
+            max(-left + min(0, right), 0),
+            max(-top + min(0, bottom), 0),
+            max(right - max(w, left), 0),
+            max(bottom - max(h, top), 0),
+        ]
+        image_cropped = img[..., max(top, 0) : bottom, max(left, 0) : right]
+        return TF.pad(image_cropped, padding_ltrb)
+
+    def test_closest_shape(self, input_ratio):
+        if self.sample:
+            ratio = min(max(input_ratio, self.ratio_bounds[0]), self.ratio_bounds[1])
+            h = round((self.pixels / ratio) ** 0.5)
+            w = h * ratio
+            self.image_shape[0] = int(h) * self.shape_mult
+            self.image_shape[1] = int(w) * self.shape_mult
+
+    def _get_crop_shapes(self, image_shape, ctx=None):
+        h, w = image_shape
+        input_ratio = w / h
+        if self.keep_original:
+            self.test_closest_shape(input_ratio)
+            ctx = 1.0
+        elif ctx is None:
+            ctx = float(
+                torch.empty(1)
+                .uniform_(self.train_ctx_range[0], self.train_ctx_range[1])
+                .item()
+            )
+        output_ratio = self.image_shape[1] / self.image_shape[0]
+
+        if output_ratio <= input_ratio:  # output like 4:3, input wider (e.g. KITTI)
+            if (
+                ctx >= 1
+            ):  # fully in -> use just max_length with sqrt(ctx), here max is width
+                new_w = w * ctx**0.5
+            # the crop sticks out a bit along one dimension only:
+            # in_width will stick out before in_height, i.e. a partial overshoot
+            # new_h > old_h via area -> new_h ** 2 * ratio_new = old_h ** 2 * ratio_old * ctx
+            elif output_ratio / input_ratio * ctx > 1:
+                new_w = w * ctx
+            else:  # fully contained -> use area
+                new_w = w * (ctx * output_ratio / input_ratio) ** 0.5
+            new_h = new_w / output_ratio
+        else:
+            if ctx >= 1:
+                new_h = h * ctx**0.5
+            elif input_ratio / output_ratio * ctx > 1:
+                new_h = h * ctx
+            else:
+                new_h = h * (ctx * input_ratio / output_ratio) ** 0.5
+            new_w = new_h * output_ratio
+        return (int(ceil(new_h - 0.5)), int(ceil(new_w - 0.5))), ctx
+
+    def __call__(self, results):
+        h, w = results["image"].shape[-2:]
+        results["image_ori_shape"] = (h, w)
+        results["camera_fields"].add("camera_original")
+        results["camera_original"] = deepcopy(results["camera"])
+
+        results.get("mask_fields", set()).add("validity_mask")
+        if "validity_mask" not in results:
+            results["validity_mask"] = torch.ones(
+                (results["image"].shape[0], 1, h, w),
+                dtype=torch.uint8,
+                device=results["image"].device,
+            )
+
+        n_iter = 1 if self.keep_original or not self.sample else 100
+
+        min_valid_area = 0.5
+        max_hfov, max_vfov = results["camera"].max_fov[0]  # it is a 1-dim list
+        ctx = None
+        for ii in range(n_iter):
+
+            (height, width), ctx = self._get_crop_shapes((h, w), ctx=self.ctx or ctx)
+            margin_h = h - height
+            margin_w = w - width
+
+            # keep it centered in y direction
+            top = margin_h // 2
+            left = margin_w // 2
+            if not self.keep_original:
+                left = left + np.random.randint(
+                    -self.shape_mult // 2, self.shape_mult // 2 + 1
+                )
+                top = top + np.random.randint(
-self.shape_mult // 2, self.shape_mult // 2 + 1 + ) + + right = left + width + bottom = top + height + x_zoom = self.image_shape[0] / height + paddings = [ + max(-left + min(0, right), 0), + max(bottom - max(h, top), 0), + max(right - max(w, left), 0), + max(-top + min(0, bottom), 0), + ] + + valid_area = ( + h + * w + / (h + paddings[1] + paddings[3]) + / (w + paddings[0] + paddings[2]) + ) + new_hfov, new_vfov = results["camera_original"].get_new_fov( + new_shape=(height, width), original_shape=(h, w) + )[0] + # if valid_area >= min_valid_area or getattr(self, "ctx", None) is not None: + # break + if ( + valid_area >= min_valid_area + and new_hfov < max_hfov + and new_vfov < max_vfov + ): + break + ctx = ( + ctx * 0.96 + ) # if not enough valid area, try again with less ctx (more zoom) + + # save ctx for next iteration of sequences? + self.ctx = ctx + + results["resized_shape"] = self.image_shape + results["paddings"] = paddings # left ,top ,right, bottom + results["image_rescale"] = x_zoom + results["scale_factor"] = results.get("scale_factor", 1.0) * x_zoom + results["camera"] = results["camera"].crop( + left, top, right=w - right, bottom=h - bottom + ) + results["camera"] = results["camera"].resize(x_zoom) + + shapes = dict(height=height, width=width, top=top, left=left) + self._transform_img(results, shapes) + if not self.keep_original: + self._transform_gt(results, shapes) + self._transform_masks(results, shapes) + else: + # only validity_mask (rgb's masks follows rgb transform) #FIXME + mask = results["validity_mask"].float() + mask = self.crop(mask, **shapes).byte() + mask = TF.resize( + mask, + results["resized_shape"], + interpolation=TF.InterpolationMode.NEAREST, + ) + results["validity_mask"] = mask + + # keep original images before photo-augment + results["image_original"] = results["image"].clone() + results["image_fields"].add( + *[ + field.replace("image", "image_original") + for field in results["image_fields"] + ] + ) + + # repeat for batch resized shape and paddings + results["paddings"] = [results["paddings"]] * results["image"].shape[0] + results["resized_shape"] = [results["resized_shape"]] * results["image"].shape[ + 0 + ] + return results + + +class RandomFiller: + def __init__(self, *args, **kwargs): + super().__init__() + + def _transform(self, results): + def fill_noise(size, device): + return torch.normal(0, 2.0, size=size, device=device) + + def fill_black(size, device): + return -4 * torch.ones(size, device=device, dtype=torch.float32) + + def fill_white(size, device): + return 4 * torch.ones(size, device=device, dtype=torch.float32) + + def fill_zero(size, device): + return torch.zeros(size, device=device, dtype=torch.float32) + + B, C = results["image"].shape[:2] + mismatch = B // results["validity_mask"].shape[0] + if mismatch: + results["validity_mask"] = results["validity_mask"].repeat( + mismatch, 1, 1, 1 + ) + validity_mask = results["validity_mask"].repeat(1, C, 1, 1).bool() + filler_fn = np.random.choice([fill_noise, fill_black, fill_white, fill_zero]) + for key in results.get("image_fields", ["image"]): + results[key][~validity_mask] = filler_fn( + size=results[key][~validity_mask].shape, device=results[key].device + ) + + def __call__(self, results): + # generate mask for filler + if "validity_mask" not in results: + paddings = results.get("padding_size", [0] * 4) + height, width = results["image"].shape[-2:] + results.get("mask_fields", []).add("validity_mask") + results["validity_mask"] = torch.zeros_like(results["image"][:, :1]) + 
results["validity_mask"][ + ..., + paddings[1] : height - paddings[3], + paddings[0] : width - paddings[2], + ] = 1.0 + self._transform(results) + return results + + +class GaussianBlur: + def __init__(self, kernel_size, sigma=(0.1, 2.0), prob=0.9): + super().__init__() + self.kernel_size = kernel_size + self.sigma = sigma + self.prob = prob + self.padding = kernel_size // 2 + + def apply(self, x, kernel): + # Pad the input tensor + x = F.pad( + x, (self.padding, self.padding, self.padding, self.padding), mode="reflect" + ) + # Apply the convolution with the Gaussian kernel + return F.conv2d(x, kernel, stride=1, padding=0, groups=x.size(1)) + + def _create_kernel(self, sigma): + # Create a 1D Gaussian kernel + kernel_1d = torch.exp( + -torch.arange(-self.padding, self.padding + 1) ** 2 / (2 * sigma**2) + ) + kernel_1d = kernel_1d / kernel_1d.sum() + + # Expand the kernel to 2D and match size of the input + kernel_2d = kernel_1d.unsqueeze(0) * kernel_1d.unsqueeze(1) + kernel_2d = kernel_2d.view(1, 1, self.kernel_size, self.kernel_size).expand( + 3, 1, -1, -1 + ) + return kernel_2d + + def __call__(self, results): + if np.random.random() > self.prob: + return results + sigma = (self.sigma[1] - self.sigma[0]) * np.random.rand() + self.sigma[0] + kernel = self._create_kernel(sigma) + for key in results.get("image_fields", ["image"]): + if "original" not in key: + results[key] = self.apply(results[key], kernel) + return results + + +class Compose: + def __init__(self, transforms): + self.transforms = deepcopy(transforms) + + def __call__(self, results): + for t in self.transforms: + results = t(results) + return results + + def __setattr__(self, name: str, value) -> None: + super().__setattr__(name, value) + for t in self.transforms: + setattr(t, name, value) + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += f"\n {t}" + format_string += "\n)" + return format_string diff --git a/unidepth/datasets/samplers.py b/unidepth/datasets/samplers.py new file mode 100644 index 0000000..c999abc --- /dev/null +++ b/unidepth/datasets/samplers.py @@ -0,0 +1,13 @@ +import torch + + +class DistributedSamplerNoDuplicate(torch.utils.data.DistributedSampler): + """A distributed sampler that doesn't add duplicates. 
Arguments are the same as DistributedSampler""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self.drop_last and len(self.dataset) % self.num_replicas != 0: + # some ranks may have less samples, that's fine + if self.rank >= len(self.dataset) % self.num_replicas: + self.num_samples -= 1 + self.total_size = len(self.dataset) diff --git a/unidepth/datasets/scannet.py b/unidepth/datasets/scannet.py new file mode 100644 index 0000000..84cfc29 --- /dev/null +++ b/unidepth/datasets/scannet.py @@ -0,0 +1,73 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unidepth.datasets.image_dataset import ImageDataset +from unidepth.datasets.utils import DatasetFromList + + +class ScanNet(ImageDataset): + min_depth = 0.005 + max_depth = 10.0 + depth_scale = 1000.0 + test_split = "scannet_test.txt" + train_split = "scannet_train.txt" + intrisics_file = "scannet_intrinsics.json" + hdf5_paths = ["scannet.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val] + dataset.append(sample) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] + results["quality"] = [1] + return results diff --git a/unidepth/datasets/sequence_dataset.py b/unidepth/datasets/sequence_dataset.py new file mode 100644 index 0000000..0c58b03 --- /dev/null +++ b/unidepth/datasets/sequence_dataset.py @@ -0,0 +1,323 @@ +import json +import os +from functools import partial +from typing import Any, Dict, Tuple + +import h5py +import numpy as np +import tables +import torch +import torchvision.transforms.v2.functional as TF + +from unidepth.datasets.base_dataset import BaseDataset +from unidepth.datasets.utils import DatasetFromList +from unidepth.datasets.utils_decode import (decode_camera, decode_depth, + decode_flow, decode_K, decode_mask, + decode_numpy, decode_rgb, + decode_tensor) +from unidepth.utils.distributed import is_main_process + + +class SequenceDataset(BaseDataset): + DECODE_FNS = { + "image": partial(decode_rgb, name="image"), + "points": partial(decode_numpy, name="points"), + "K": partial(decode_K, name="camera"), + "camera_params": partial(decode_camera, name="camera"), + "cam2w": partial(decode_tensor, name="cam2w"), + "depth": partial(decode_depth, name="depth"), + "flow_fwd": partial(decode_flow, name="flow_fwd"), + "flow_bwd": partial(decode_flow, name="flow_bwd"), + 
"flow_fwd_mask": partial(decode_mask, name="flow_fwd_mask"), + "flow_bwd_mask": partial(decode_mask, name="flow_bwd_mask"), + } + default_fps = 5 + + def __init__( + self, + image_shape: Tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: Dict[str, Any], + resize_method: str, + mini: float, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth", "flow_fwd", "flow_fwd_mask"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.num_frames = num_frames + self.original_num_frames = num_frames + self.decode_fields = decode_fields + self.inplace_fields = inplace_fields + self.fps = self.default_fps + self.fps_range = kwargs.get("fps_range", None) + if self.fps_range is not None: + self.fps_range[1] = min(self.default_fps, self.fps_range[1]) + + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii").strip() + sequences = np.array(h5file[self.sequences_file]).tostring().decode("ascii") + sequences = json.loads(sequences) + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + if len(line.strip().split(" ")) == 1: + print(line) + continue + sequence_name, num_samples = line.strip().split(" ") + dataset.append( + { + "sequence_name": sequence_name, + "num_samples": int(num_samples), + "chunk_idx": 0, + } + ) + + # filter dataset based on attr "invalid_sequences" + invalid_sequences = getattr(self, "invalid_sequences", []) + dataset = [ + sample + for sample in dataset + if sample["sequence_name"] not in invalid_sequences + ] + + self.dataset = DatasetFromList(dataset) + self.sequences = DatasetFromList( + [sequences[sample["sequence_name"]] for sample in dataset] + ) + self.log_load_dataset() + + def get_random_idxs(self, num_samples_sequence): + if self.num_frames == 1: + return [np.random.randint(0, num_samples_sequence)], 0 + + max_image_skip = int((num_samples_sequence - 1) / (self.num_frames - 1)) + fps_lower_bound = self.default_fps / max_image_skip + if self.fps_range is not None: + lowest_fps = max(fps_lower_bound, self.fps_range[0]) + spf = 1 / lowest_fps - np.random.random() * ( + 1 / lowest_fps - 1 / self.fps_range[1] + ) # seconds per frame between 0.1 (10fps) and 5 (0.2fps) + self.fps = 1 / spf + n_skip_frames = int(self.default_fps / self.fps) + start = np.random.randint( + 0, max(1, num_samples_sequence - self.num_frames * n_skip_frames) + ) + idxs = list( + range( + start, + min(num_samples_sequence, self.num_frames * n_skip_frames + start), + n_skip_frames, + ) + ) + return idxs, np.random.randint(0, len(idxs)) + + def get_test_idxs(self, num_samples_sequence, keyframe_idx): + if self.num_frames == 1: + return [ + keyframe_idx if keyframe_idx is not None else num_samples_sequence // 2 + ], 0 + + if self.num_frames == -1: + cap_idxs = min(16, num_samples_sequence) # CAP AT 16 HARDCODED TODO!!! 
+ idxs = [int(i * num_samples_sequence / cap_idxs) for i in range(cap_idxs)] + return idxs, None + + max_image_skip = int((num_samples_sequence - 1) / (self.num_frames - 1)) + fps_lower_bound = self.default_fps / max_image_skip + keyframe_idx = ( + keyframe_idx if keyframe_idx is not None else num_samples_sequence // 2 + ) + + # decimate up to keyframe_idx and after + if self.fps_range is not None: + self.fps = max(fps_lower_bound, self.fps_range[0]) + n_skip_frames = int(self.default_fps / self.fps) + start = max( + keyframe_idx % n_skip_frames, + keyframe_idx - self.num_frames // 2 * n_skip_frames, + ) + + # Case when the keyframe is too close to the end of the sequence, take last self.num_frames frames + if num_samples_sequence < self.num_frames * n_skip_frames + start: + num_frames_after_keyframe = ( + num_samples_sequence - 1 - keyframe_idx + ) // n_skip_frames + num_frames_before_keyframe = self.num_frames - 1 - num_frames_after_keyframe + idxs = list( + range( + keyframe_idx - num_frames_before_keyframe * n_skip_frames, + num_samples_sequence, + n_skip_frames, + ) + ) + + # Case when the keyframe is too close to the beginning of the sequence, take first self.num_frames frames + elif keyframe_idx - self.num_frames // 2 * n_skip_frames < 0: + num_frames_before_keyframe = keyframe_idx // n_skip_frames + num_frames_after_keyframe = self.num_frames - 1 - num_frames_before_keyframe + idxs = list( + range( + keyframe_idx - num_frames_before_keyframe * n_skip_frames, + keyframe_idx + num_frames_after_keyframe * n_skip_frames + 1, + n_skip_frames, + ) + ) + + # Case when the keyframe is not too close to the beginning and not too close to the end of the sequence + else: + idxs = list( + range( + start, + min( + num_samples_sequence, + 1 + self.num_frames * n_skip_frames + start, + ), + n_skip_frames, + ) + ) + + return idxs, idxs.index(keyframe_idx) + + def get_single_sequence(self, idx): + self.num_frames = self.original_num_frames + # sequence_name = self.dataset[idx]["sequence_name"] + sample = self.sequences[idx] + chunk_idx = int(sample.get("chunk_idx", 0)) + h5_path = os.path.join(self.data_root, self.hdf5_paths[chunk_idx]) + + num_samples_sequence = len(sample["image"]) + if self.num_frames > 0 and num_samples_sequence < self.num_frames: + raise IndexError(f"Sequence {idx} has less than {self.num_frames} frames") + keyframe_idx = None + + if not self.test_mode: + # idxs, keyframe_idx = self.get_random_idxs(num_samples_sequence) + start = np.random.randint(0, max(1, num_samples_sequence - self.num_frames)) + idxs = list( + range(start, min(num_samples_sequence, self.num_frames + start)) + ) + keyframe_idx = np.random.randint(0, len(idxs)) + else: + idxs, keyframe_idx = self.get_test_idxs( + num_samples_sequence, sample.get("keyframe_idx", None) + ) + + self.num_frames = len(idxs) + results = {} + results = self.pre_pipeline(results) + results["sequence_fields"] = [(i, 0) for i in range(self.num_frames)] + results["keyframe_idx"] = keyframe_idx + with tables.File( + h5_path, + mode="r", + libver="latest", + swmr=True, + ) as h5file_chunk: + + for i, j in enumerate(idxs): + results[(i, 0)] = { + k: v.copy() for k, v in results.items() if "fields" in k + } + for inplace_field in self.inplace_fields: + inplace_field_ = inplace_field.replace("intrinsics", "K").replace( + "extrinsics", "cam2w" + ) + results = self.DECODE_FNS[inplace_field_]( + results, sample[inplace_field][j], idx=i, sample=sample, j=j + ) + + for i, j in enumerate(idxs): + for decode_field in self.decode_fields: + results = 
self.DECODE_FNS[decode_field]( + results, + h5file_chunk, + sample[decode_field][j], + idx=i, + depth_scale=self.depth_scale, + ) + + results["filename"] = sample["image"][j] + + results = self.preprocess(results) + if not self.test_mode: + results = self.augment(results) + results = self.postprocess(results) + return results + + def preprocess(self, results): + self.resizer.ctx = None + for i, seq in enumerate(results["sequence_fields"]): + results[seq] = self.resizer(results[seq]) + + num_pts = torch.count_nonzero(results[seq]["depth"] > 0) + if num_pts < 50: + raise IndexError(f"Too few points in depth map ({num_pts})") + + for key in results[seq].get("image_fields", ["image"]): + results[seq][key] = results[seq][key].to(torch.float32) / 255 + + # update fields common in sequence + for key in [ + "image_fields", + "gt_fields", + "mask_fields", + "sequence_fields", + "camera_fields", + "paddings", + ]: + if key in results[(0, 0)]: + results[key] = results[(0, 0)][key] + + results = self.pack_batch(results) + return results + + def postprocess(self, results): + # # normalize after because color aug requires [0,255]? + for key in results.get("image_fields", ["image"]): + results[key] = TF.normalize(results[key], **self.normalization_stats) + results = self.filler(results) + results = self.unpack_batch(results) + results = self.masker(results) + results = self.collecter(results) + return results + + def __getitem__(self, idx): + try: + if isinstance(idx, (list, tuple)): + results = [self.get_single_sequence(i) for i in idx] + else: + results = self.get_single_sequence(idx) + except Exception as e: + print(f"Error loading sequence {idx} for {self.__class__.__name__}: {e}") + idx = np.random.randint(0, len(self.dataset)) + results = self[idx] + return results + + def log_load_dataset(self): + if is_main_process(): + info = f"Loaded {self.__class__.__name__} with {sum([len(x['image']) for x in self.sequences])} images in {len(self)} sequences." 
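+            # the image count sums the per-sequence frame lists, while len(self) counts sequences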
+ print(info) diff --git a/unidepth/datasets/sintel.py b/unidepth/datasets/sintel.py new file mode 100644 index 0000000..a5f1859 --- /dev/null +++ b/unidepth/datasets/sintel.py @@ -0,0 +1,49 @@ +from typing import Any + +from unidepth.datasets.sequence_dataset import SequenceDataset + + +class Sintel(SequenceDataset): + min_depth = 0.001 + max_depth = 1000.0 + depth_scale = 1000.0 + test_split = "training.txt" + train_split = "training.txt" + sequences_file = "sequences.json" + hdf5_paths = ["Sintel.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames + results["synthetic"] = [True] * self.num_frames + return results diff --git a/unidepth/datasets/utils.py b/unidepth/datasets/utils.py new file mode 100644 index 0000000..6a413c7 --- /dev/null +++ b/unidepth/datasets/utils.py @@ -0,0 +1,230 @@ +import copy +import multiprocessing as mp +import pickle +from collections import defaultdict +from typing import Any, Dict, List + +import numpy as np +import torch +import torch.utils.data + +from unidepth.utils.distributed import (all_gather, get_local_rank, + get_local_size, get_rank, + get_world_size) + + +class ConcatDataset(torch.utils.data.ConcatDataset): + def __init__(self, datasets, shape_constraints: dict[str, list[int]] = {}, pairs=1): + super().__init__(datasets) + + self.sample = shape_constraints["sample"] + self.shape_mult = shape_constraints["shape_mult"] + self.ratio_bounds = shape_constraints["ratio_bounds"] + self.pixels_max = shape_constraints["pixels_max"] + self.pixels_min = shape_constraints["pixels_min"] + self.height_min = shape_constraints["height_min"] + self.width_min = shape_constraints["width_min"] + self.pairs = pairs + + def sample_shape(self): + if not self.sample: + return + # 1: sample image ratio + ratio = np.random.uniform(*self.ratio_bounds) + # 2: sample image height or width, if ratio > 1 or < 1 + if ratio > 1: + height_min = max(self.height_min, np.sqrt(self.pixels_min / ratio)) + height = np.random.uniform(height_min, np.sqrt(self.pixels_max / ratio)) + width = height * ratio + else: + width_min = max(self.width_min, np.sqrt(self.pixels_min * ratio)) + width = np.random.uniform(width_min, np.sqrt(self.pixels_max * ratio)) + height = width / ratio + # 3: get final shape based on the shape_mult + shape = [int(height) * self.shape_mult, int(width) * self.shape_mult] + for dataset in self.datasets: + setattr(dataset, "image_shape", shape) + setattr(dataset.resizer, "image_shape", shape) + + def __getitem__(self, idxs): + self.sample_shape() + return [ + super(ConcatDataset, self).__getitem__(idx) + for idx in idxs + for _ in range(self.pairs) + ] + + +def _paddings(image_shape, network_shape): + cur_h, cur_w = image_shape + h, w = network_shape + pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 
2 + pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2 + return pad_left, pad_right, pad_top, pad_bottom + + +def collate_fn(in_data: List[List[Dict[str, Any]]], is_batched: bool = True): + out_data = defaultdict(list) + img_metas = [] + in_data = in_data[0] if is_batched else in_data + + # get max_shape and paddings + shapes = [tensor.shape[-2:] for x in in_data for tensor in x["depth"].values()] + max_shape_tuple = tuple(max(elements) for elements in zip(*shapes)) + paddings = [ + [ + _paddings(tensor.shape[-2:], max_shape_tuple) + for tensor in x["depth"].values() + ] + for x in in_data + ] + + for x in in_data: # here iter over batches + padding = paddings.pop(0) + for k, v in x.items(): + if "img_metas" not in k: + values = list(v.values()) + v = torch.cat(values) + out_data[k].append(v) + else: + v["depth_paddings"] = padding + img_metas.append(v) + + return { + "data": {k: torch.stack(v) for k, v in out_data.items()}, + "img_metas": img_metas, + } + + +def local_scatter(array: list[Any]): + """ + Scatter an array from local leader to all local workers. + The i-th local worker gets array[i]. + + Args: + array: Array with same size of #local workers. + """ + if get_world_size() == 1: + return array[0] + if get_local_rank() == 0: + assert len(array) == get_local_size() + all_gather(array) + else: + all_data = all_gather(None) + array = all_data[get_rank() - get_local_rank()] + return array[get_local_rank()] + + +class DatasetFromList(torch.utils.data.Dataset): # type: ignore + """Wrap a list to a torch Dataset. + + We serialize and wrap big python objects in a torch.Dataset due to a + memory leak when dealing with large python objects using multiple workers. + See: https://github.com/pytorch/pytorch/issues/13246 + """ + + def __init__(self, lst: List[Any], deepcopy: bool = False, serialize: bool = True): + """Creates an instance of the class. + + Args: + lst: a list which contains elements to produce. + deepcopy: whether to deepcopy the element when producing it, s.t. + the result can be modified in place without affecting the source + in the list. + serialize: whether to hold memory using serialized objects. When + enabled, data loader workers can use shared RAM from master + process instead of making a copy. + """ + self._copy = deepcopy + self._serialize = serialize + + def _serialize(data: Any): + buffer = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) + return torch.frombuffer(buffer, dtype=torch.uint8) + + if self._serialize: + # load only on 0th rank + if get_local_rank() == 0: + _lst = [_serialize(x) for x in lst] + self._addr = torch.cumsum( + torch.tensor([len(x) for x in _lst], dtype=torch.int64), dim=0 + ) + self._lst = torch.concatenate(_lst) + # Move data to shared memory, obtain a handle to send to each local worker. 
+ handles = [None] + [ + bytes(mp.reduction.ForkingPickler.dumps((self._addr, self._lst))) + for _ in range(get_local_size() - 1) + ] + else: + handles = None + + # Each worker receives the handle from local leader (rank 0) + # then materialize the tensor from shared memory + handle = local_scatter(handles) + if get_local_rank() > 0: + self._addr, self._lst = mp.reduction.ForkingPickler.loads(handle) + + else: + self._lst = lst + + def __len__(self) -> int: + """Return len of list.""" + if self._serialize: + return len(self._addr) + return len(self._lst) + + def __getitem__(self, idx: int) -> Any: + """Return item of list at idx.""" + if self._serialize: + start_addr = 0 if idx == 0 else self._addr[idx - 1] + end_addr = self._addr[idx] + bytes_ = memoryview(self._lst[start_addr:end_addr].numpy()) + return pickle.loads(bytes_) + if self._copy: + return copy.deepcopy(self._lst[idx]) + + return self._lst[idx] + + +def get_weights( + train_datasets: dict[str, torch.utils.data.Dataset], sampling: dict[str, float] +) -> torch.Tensor: + from .image_dataset import ImageDataset + from .sequence_dataset import SequenceDataset + + weights = [] + num_samples = 0 + info_weights = {} + for dataset_name, dataset in train_datasets.items(): + assert ( + dataset_name in sampling + ), f"Dataset {dataset_name} not found in {sampling.keys()}" + + if isinstance(dataset, ImageDataset): + # sum of all samples has weight as in sampling s.t. sampling dataset in general is as in sampling + # inside is uniform + weight = sampling[dataset_name] / len(dataset) + weights.append(torch.full((len(dataset),), weight).double()) + num_samples += len(dataset) + + elif isinstance(dataset, SequenceDataset): + # local weight is num_samples, but global must be as in sampling + # hence is num_samples / (sum num_samples / sampling[dataset_name]) + # s.t. 
sampling anything from the dataset is + # sum(num_samples / (sum num_samples / sampling[dataset_name])) + # -> sampling[dataset_name] + numerator = [int(data["num_samples"]) for data in dataset.dataset] + weights.append( + sampling[dataset_name] + * torch.tensor(numerator).double() + / sum(numerator) + ) + num_samples += sum(numerator) + + else: + weight = sampling[dataset_name] / len(dataset) + weights.append(torch.full((len(dataset),), weight).double()) + + info_weights[dataset_name] = weights[-1][-1] + + return torch.cat(weights), num_samples diff --git a/unidepth/datasets/utils_decode.py b/unidepth/datasets/utils_decode.py new file mode 100644 index 0000000..5e91449 --- /dev/null +++ b/unidepth/datasets/utils_decode.py @@ -0,0 +1,125 @@ +import io + +import cv2 +import numpy as np +import torch +import torchvision +import torchvision.transforms.v2.functional as TF +from PIL import Image + +from unidepth.utils.camera import (EUCM, MEI, BatchCamera, Fisheye624, Pinhole, + Spherical) + + +def decode_depth(results, h5file, value, idx, depth_scale, name="depth", **kwargs): + file = h5file.get_node("/" + value).read() + decoded_data = Image.open(io.BytesIO(file)) + decoded_data = TF.pil_to_tensor(decoded_data).squeeze() + + # decoded_data = cv2.imdecode(np.frombuffer(file, np.uint8), 128 | 1) + # decoded_data = torch.from_numpy(decoded_data).squeeze() + + if decoded_data.ndim == 3: # 24 channel loading + decoded_channels = [ + (decoded_data[0] & 0xFF).to(torch.int32), + (decoded_data[1] & 0xFF).to(torch.int32), + (decoded_data[2] & 0xFF).to(torch.int32), + ] + # Reshape and extract the original depth map + decoded_data = ( + decoded_channels[0] + | (decoded_channels[1] << 8) + | (decoded_channels[2] << 16) + ) + + decoded_data = decoded_data.to(torch.float32) + results.get("gt_fields", set()).add(name) + results[(idx, 0)].get("gt_fields", set()).add(name) + results[f"{name}_ori_shape"] = decoded_data.shape + results[(idx, 0)][name] = ( + decoded_data.view(1, 1, *decoded_data.shape).contiguous() / depth_scale + ) + return results + + +def decode_numpy(results, h5file, value, idx, name="points", **kwargs): + file = h5file.get_node("/" + value).read() + decoded_data = np.load(io.BytesIO(file), allow_pickle=False) + decoded_data = torch.from_numpy(decoded_data).to(torch.float32) + if decoded_data.ndim > 2: + decoded_data = decoded_data.permute(2, 0, 1) + results.get("gt_fields", set()).add(name) + results[(idx, 0)].get("gt_fields", set()).add(name) + results[(idx, 0)][name] = decoded_data.unsqueeze(0) + return results + + +def decode_tensor(results, value, idx, name, **kwargs): + results.get("camera_fields", set()).add(name) + results[(idx, 0)].get("camera_fields", set()).add(name) + results[(idx, 0)][name] = torch.tensor(value).unsqueeze(0) + return results + + +def decode_camera(results, value, idx, name, sample, j, **kwargs): + results.get("camera_fields", set()).add(name) + results[(idx, 0)].get("camera_fields", set()).add(name) + camera = eval(sample["camera_model"][j])(params=torch.tensor(value).unsqueeze(0)) + results[(idx, 0)][name] = BatchCamera.from_camera(camera) + return results + + +def decode_K(results, value, idx, name, **kwargs): + results.get("camera_fields", set()).add(name) + results[(idx, 0)].get("camera_fields", set()).add(name) + camera = Pinhole(K=torch.tensor(value).unsqueeze(0)) + results[(idx, 0)][name] = BatchCamera.from_camera(camera) + return results + + +def decode_mask(results, h5file, value, idx, name, **kwargs): + file = h5file.get_node("/" + value).read() + mask 
= torchvision.io.decode_image(torch.from_numpy(file)).bool().squeeze() + results.get("mask_fields", set()).add(name) + results[(idx, 0)].get("mask_fields", set()).add(name) + results[f"{name}_ori_shape"] = mask.shape[-2:] + results[(idx, 0)][name] = mask.view(1, 1, *mask.shape).contiguous() + return results + + +def decode_rgb(results, h5file, value, idx, name="image", **kwargs): + file = h5file.get_node("/" + value).read() + image = ( + torchvision.io.decode_image(torch.from_numpy(file)).to(torch.uint8).squeeze() + ) + results.get("image_fields", set()).add(name) + results[(idx, 0)].get("image_fields", set()).add(name) + results[f"{name}_ori_shape"] = image.shape[-2:] + if image.ndim == 2: + image = image.unsqueeze(0).repeat(3, 1, 1) + results[(idx, 0)][name] = image.unsqueeze(0) + return results + + +def decode_flow(results, h5file, value, idx, name, **kwargs): + file = h5file.get_node("/" + value).read() + image = ( + torchvision.io.decode_image(torch.from_numpy(file)).to(torch.uint8).squeeze() + ) + decoded_channels = [ + (image[0] & 0xFF).to(torch.int16), + (image[1] & 0xFF).to(torch.int16), + (image[2] & 0xFF).to(torch.int16), + ] + + # Reshape and extract the original 2-channel flow map + flow = torch.zeros((2, image.shape[1], image.shape[2]), dtype=torch.int16) + flow[0] = (decoded_channels[0] | decoded_channels[1] << 8) & 0xFFF + flow[1] = (decoded_channels[1] >> 4 | decoded_channels[2] << 4) & 0xFFF + + results.get("gt_fields", set()).add(name) + results[(idx, 0)].get("gt_fields", set()).add(name) + results[f"{name}_ori_shape"] = flow.shape[-2:] + flow = flow.unsqueeze(0).contiguous().float() + results[(idx, 0)][name] = (0.5 + flow) / 4095.0 * 2 - 1 + return results diff --git a/unidepth/models/backbones/dinov2.py b/unidepth/models/backbones/dinov2.py index 726846d..d394ebb 100644 --- a/unidepth/models/backbones/dinov2.py +++ b/unidepth/models/backbones/dinov2.py @@ -1,3 +1,4 @@ +import contextlib import logging import math from functools import partial @@ -98,6 +99,7 @@ def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()): } parameter_group_vars[group_name]["params"].append(param) parameter_group_names[group_name]["params"].append(name) + return list(parameter_group_vars.values()), [ v["lr"] for k, v in parameter_group_vars.items() ] @@ -137,6 +139,7 @@ def __init__( interpolate_antialias=False, interpolate_offset=0.0, use_norm=False, + frozen_stages=0, ): """ Args: @@ -166,6 +169,7 @@ def __init__( self.num_features = self.embed_dim = ( embed_dim # num_features for consistency with other models ) + self.frozen_stages = frozen_stages self.embed_dims = [embed_dim] * output_idx[-1] self.num_tokens = 1 self.n_blocks = depth @@ -247,7 +251,7 @@ def f(*args, **kwargs): self.chunked_blocks = False self.blocks = nn.ModuleList(blocks_list) - self.norm = norm_layer(embed_dim) + self.norm = nn.LayerNorm(embed_dim) self.use_norm = use_norm self.head = nn.Identity() self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) @@ -301,7 +305,8 @@ def interpolate_pos_encoding(self, x, w, h): def prepare_tokens_with_masks(self, x, masks=None): B, nc, w, h = x.shape - x = self.patch_embed(x) + with torch.no_grad() if self.frozen_stages > -1 else contextlib.nullcontext(): + x = self.patch_embed(x) if masks is not None: masks = masks.bool().view(B, -1, 1) x = torch.where(masks, self.mask_token.to(x.dtype).unsqueeze(0), x) @@ -322,11 +327,19 @@ def forward(self, x, masks=None): x = self.prepare_tokens_with_masks(x, masks) outputs = [] for i, blk in enumerate(self.blocks): - x = blk(x) + 
with ( + torch.no_grad() if i < self.frozen_stages else contextlib.nullcontext() + ): + x = blk(x) outputs.append(x) if self.use_norm: - outputs = [self.norm(out) for out in outputs] + with ( + torch.no_grad() + if self.frozen_stages >= len(self.blocks) + else contextlib.nullcontext() + ): + outputs = [self.norm(out) for out in outputs] class_tokens = [out[:, :1] for out in outputs] outputs = [out[:, self.num_register_tokens + 1 :] for out in outputs] outputs = [out.reshape(batch_size, *shapes, -1) for out in outputs] @@ -345,6 +358,21 @@ def freeze(self) -> None: def train(self, mode=True): super().train(mode) + if self.frozen_stages > -1: + for p in self.patch_embed.parameters(): + p.requires_grad = False + + for i, blk in enumerate(self.blocks): + if i < self.frozen_stages: + blk.eval() + for p in blk.parameters(): + p.requires_grad = False + + for p in self.norm.parameters(): + p.requires_grad = self.frozen_stages <= len(self.blocks) and self.use_norm + + self.cls_token.requires_grad = self.frozen_stages < 1 + self.pos_embed.requires_grad = self.frozen_stages < 1 self.mask_token.requires_grad = False self.register_tokens.requires_grad = False @@ -419,10 +447,10 @@ def _make_dinov2_model( use_norm: bool = False, export: bool = False, interpolate_offset: float = 0.0, + frozen_stages: int = 0, **kwargs, ): model_name = _make_dinov2_model_name(arch_name, patch_size) - vit_kwargs = dict( img_size=img_size, patch_size=patch_size, @@ -435,6 +463,7 @@ def _make_dinov2_model( use_norm=use_norm, export=export, interpolate_offset=interpolate_offset, + frozen_stages=frozen_stages, ) vit_kwargs.update(**kwargs) model = eval(arch_name)(**vit_kwargs) @@ -452,4 +481,6 @@ def _make_dinov2_model( state_dict = torch.load(pretrained, map_location="cpu") info = model.load_state_dict(state_dict, strict=False) print(f"loading from {pretrained} with:", info) + else: + print("Not loading pretrained weights for backbone") return model diff --git a/unidepth/models/unidepthv1/decoder.py b/unidepth/models/unidepthv1/decoder.py index 6448d3a..bdfb363 100644 --- a/unidepth/models/unidepthv1/decoder.py +++ b/unidepth/models/unidepthv1/decoder.py @@ -197,10 +197,9 @@ def forward( ) -> torch.Tensor: features = features.unbind(dim=-1) shapes = self.shapes + rays_hr = rays_hr.detach() # camera_embedding - # torch.cuda.synchronize() - # start = time() rays_embedding_16 = F.normalize( flat_interpolate(rays_hr, old=self.original_shapes, new=shapes), dim=-1 ) @@ -219,8 +218,6 @@ def forward( rays_embedding_16 = self.project_rays16(rsh_cart_8(rays_embedding_16)) rays_embedding_8 = self.project_rays8(rsh_cart_8(rays_embedding_8)) rays_embedding_4 = self.project_rays4(rsh_cart_8(rays_embedding_4)) - # torch.cuda.synchronize() - # print(f"camera_embedding took {time() - start} seconds") features_tokens = torch.cat(features, dim=1) features_tokens_pos = pos_embed + level_embed @@ -377,7 +374,10 @@ def forward(self, inputs, image_metas) -> torch.Tensor: max_stack(original_encoder_outputs[i:j]) for i, j in self.slices_encoder_range ] - cls_tokens = [cls_tokens[-i - 1] for i in range(len(self.slices_encoder_range))] + # detach tokens for camera + cls_tokens = [ + cls_tokens[-i - 1].detach() for i in range(len(self.slices_encoder_range)) + ] # get features in b n d format # level shapes, the shape per level, for swin like [[128, 128], [64, 64],...], for vit [[32,32]] -> mult times resolutions @@ -474,7 +474,6 @@ def build(self, config): expansion = config["model"]["expansion"] dropout = config["model"]["pixel_decoder"]["dropout"] 
depths_encoder = config["model"]["pixel_encoder"]["depths"] - num_steps = config["model"].get("num_steps", 100000) layer_scale = 1.0 self.depth = depth diff --git a/unidepth/models/unidepthv1/unidepthv1.py b/unidepth/models/unidepthv1/unidepthv1.py index e02be9e..cbcbc3d 100644 --- a/unidepth/models/unidepthv1/unidepthv1.py +++ b/unidepth/models/unidepthv1/unidepthv1.py @@ -20,7 +20,7 @@ from unidepth.utils.distributed import is_main_process from unidepth.utils.geometric import (generate_rays, spherical_zbuffer_to_euclidean) -from unidepth.utils.misc import get_params +from unidepth.utils.misc import get_params, match_gt MAP_BACKBONES = {"ViTL14": "vitl14", "ConvNextL": "cnvnxtl"} @@ -108,12 +108,13 @@ def __init__( ): super().__init__() self.build(config) + self.build_losses(config) self.eps = eps def forward(self, inputs, image_metas): rgbs = inputs["image"] - gt_intrinsics = inputs.get("K") - H, W = rgbs.shape[-2:] + B, _, H, W = rgbs.shape + cameras = inputs["camera"] # Encode encoder_outputs, cls_tokens = self.pixel_encoder(rgbs) @@ -125,23 +126,18 @@ def forward(self, inputs, image_metas): inputs["encoder_outputs"] = encoder_outputs inputs["cls_tokens"] = cls_tokens - # Get camera infos, if any - if gt_intrinsics is not None: - rays, angles = generate_rays( - gt_intrinsics, self.image_shape, noisy=self.training - ) - inputs["rays"] = rays - inputs["angles"] = angles - inputs["K"] = gt_intrinsics - self.pixel_decoder.test_fixed_camera = True # use GT camera in fwd + # Get camera rays for supervision, all in unit sphere + inputs["rays"] = rearrange( + cameras.get_rays(shapes=(B, H, W)), "b c h w -> b (h w) c" + ) # Decode - pred_intrinsics, predictions, _ = self.pixel_decoder(inputs, {}) + pred_intrinsics, predictions, depth_features = self.pixel_decoder(inputs, {}) predictions = sum( [ F.interpolate( x.clone(), - size=self.image_shape, + size=(H, W), mode="bilinear", align_corners=False, antialias=True, ) for x in predictions ] ) / len(predictions) # Final 3D points backprojection - pred_angles = generate_rays(pred_intrinsics, (H, W), noisy=False)[-1] + pred_rays, pred_angles = generate_rays(pred_intrinsics, (H, W), noisy=False) + # note: the GT angles (inputs["angles"]) could be used here instead, when available
pred_angles = rearrange(pred_angles, "b (h w) c -> b c h w", h=H, w=W) + + # reshape to match GT paddings and shape + if not self.training: + depth_gt = inputs["depth"] + image_paddings = [image_metas[0]["paddings"]] + depth_paddings = [image_metas[0]["depth_paddings"]] + predictions = match_gt( + predictions, depth_gt, padding1=image_paddings, padding2=depth_paddings + ) + pred_angles = match_gt( + pred_angles, depth_gt, padding1=image_paddings, padding2=depth_paddings + ) + points_3d = torch.cat((pred_angles, predictions), dim=1) points_3d = spherical_zbuffer_to_euclidean( points_3d.permute(0, 2, 3, 1) @@ -162,12 +172,65 @@ # Output data, use for loss computation outputs = { "angles": pred_angles, + "rays": pred_rays, "intrinsics": pred_intrinsics, "points": points_3d, "depth": predictions[:, -1:], + "cond_features": depth_features, } self.pixel_decoder.test_fixed_camera = False - return outputs + losses = self.compute_losses(outputs, inputs, image_metas) + + return outputs, losses + + def compute_losses(self, outputs, inputs, image_metas): + losses = {"opt": {}, "stat": {}} + if ( + not self.training + ): # only compute losses during training, avoiding size mismatches between predictions and GT + return losses + losses_to_be_computed = list(self.losses.keys()) + + # depth loss + si = torch.tensor([x.get("si", False) for x in image_metas], device=self.device) + loss = self.losses["depth"] + depth_losses = loss( + outputs["depth"], + target=inputs["depth"], + mask=inputs["depth_mask"].clone(), + si=si, + ) + losses["opt"][loss.name] = loss.weight * depth_losses.mean() + losses_to_be_computed.remove("depth") + + # camera loss: for simplicity we apply it to rays + # in the original training it was applied to angles + # however, we saw no difference (see supplementary) + loss = self.losses["camera"] + camera_losses = loss(outputs["rays"], target=inputs["rays"]) + losses["opt"][loss.name] = loss.weight * camera_losses.mean() + losses_to_be_computed.remove("camera") + + # invariance loss + flips = torch.tensor( + [x.get("flip", False) for x in image_metas], device=self.device + ) + loss = self.losses["invariance"] + invariance_losses = loss( + outputs["cond_features"], + intrinsics=inputs["camera"].K, + mask=inputs["depth_mask"], + flips=flips, + ) + losses["opt"][loss.name] = loss.weight * invariance_losses.mean() + losses_to_be_computed.remove("invariance") + + # sanity check: no losses should remain to be computed + assert ( + not losses_to_be_computed + ), f"Losses {losses_to_be_computed} not computed, revise `compute_losses` method" + + return losses @torch.no_grad() def infer(self, rgbs: torch.Tensor, intrinsics=None, skip_camera=False): @@ -292,7 +355,7 @@ def get_params(self, config): decoder_p, decoder_lr = get_params( self.pixel_decoder, config["training"]["lr"], config["training"]["wd"] ) - return [*encoder_p, *decoder_p], [*encoder_lr, *decoder_lr] + return [*encoder_p, *decoder_p] @property def device(self): @@ -326,3 +389,10 @@ def build(self, config): self.pixel_encoder = pixel_encoder self.pixel_decoder = Decoder(config) self.image_shape = config["data"]["image_shape"] + + def build_losses(self, config): + self.losses = {} + for loss_name, loss_config in config["training"]["losses"].items(): + mod = importlib.import_module("unidepth.ops.losses") + loss_factory = getattr(mod, loss_config["name"]) + self.losses[loss_name] = loss_factory.build(loss_config) diff --git a/unidepth/ops/__init__.py b/unidepth/ops/__init__.py index eb25b2f..e87e761 100644 
--- a/unidepth/ops/__init__.py +++ b/unidepth/ops/__init__.py @@ -1,9 +1,3 @@ -from .losses import MSE, SelfCons, SILog -from .scheduler import CosineScheduler - -__all__ = [ - "SILog", - "MSE", - "SelfCons", - "CosineScheduler", -] +from .losses import (ARel, Confidence, Dummy, LocalSSI, Regression, + SelfDistill, SILog, TeacherDistill) +from .scheduler import CosineScheduler, PlainCosineScheduler diff --git a/unidepth/ops/knn/__init__.py b/unidepth/ops/knn/__init__.py new file mode 100644 index 0000000..ba469c8 --- /dev/null +++ b/unidepth/ops/knn/__init__.py @@ -0,0 +1,6 @@ +from .functions.knn import knn_gather, knn_points + +__all__ = [ + "knn_points", + "knn_gather", +] diff --git a/unidepth/ops/knn/compile.sh b/unidepth/ops/knn/compile.sh new file mode 100755 index 0000000..348ef71 --- /dev/null +++ b/unidepth/ops/knn/compile.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +export TORCH_CUDA_ARCH_LIST="6.1 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" +# export FORCE_CUDA=1 #if you do not actually have cuda, workaround +python setup.py build install \ No newline at end of file diff --git a/unidepth/ops/knn/functions/__init__.py b/unidepth/ops/knn/functions/__init__.py new file mode 100644 index 0000000..6a54211 --- /dev/null +++ b/unidepth/ops/knn/functions/__init__.py @@ -0,0 +1,6 @@ +from .knn import knn_gather, knn_points + +__all__ = [ + "knn_points", + "knn_gather", +] diff --git a/unidepth/ops/knn/functions/knn.py b/unidepth/ops/knn/functions/knn.py new file mode 100644 index 0000000..42c3634 --- /dev/null +++ b/unidepth/ops/knn/functions/knn.py @@ -0,0 +1,249 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from collections import namedtuple +from typing import Union + +import torch +from KNN import knn_points_backward, knn_points_idx +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +_KNN = namedtuple("KNN", "dists idx knn") + + +class _knn_points(Function): + """ + Torch autograd Function wrapper for KNN C++/CUDA implementations. + """ + + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + ctx, + p1, + p2, + lengths1, + lengths2, + K, + version, + norm: int = 2, + return_sorted: bool = True, + ): + """ + K-Nearest neighbors on point clouds. + + Args: + p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each + containing up to P1 points of dimension D. + p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each + containing up to P2 points of dimension D. + lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the + length of each pointcloud in p1. Or None to indicate that every cloud has + length P1. + lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the + length of each pointcloud in p2. Or None to indicate that every cloud has + length P2. + K: Integer giving the number of nearest neighbors to return. + version: Which KNN implementation to use in the backend. If version=-1, + the correct implementation is selected based on the shapes of the inputs. + norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2). + return_sorted: (bool) whether to return the nearest neighbors sorted in + ascending order of distance. 
+ + Returns: + p1_dists: Tensor of shape (N, P1, K) giving the squared distances to + the nearest neighbors. This is padded with zeros both where a cloud in p2 + has fewer than K points and where a cloud in p1 has fewer than P1 points. + + p1_idx: LongTensor of shape (N, P1, K) giving the indices of the + K nearest neighbors from points in p1 to points in p2. + Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest + neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud + in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points. + """ + if not ((norm == 1) or (norm == 2)): + raise ValueError("Support for 1 or 2 norm.") + + idx, dists = knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version) + + # sort KNN in ascending order if K > 1 + if K > 1 and return_sorted: + if lengths2.min() < K: + P1 = p1.shape[1] + mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None] + # mask has shape [N, K], true where dists irrelevant + mask = mask[:, None].expand(-1, P1, -1) + # mask has shape [N, P1, K], true where dists irrelevant + dists[mask] = float("inf") + dists, sort_idx = dists.sort(dim=2) + dists[mask] = 0 + else: + dists, sort_idx = dists.sort(dim=2) + idx = idx.gather(2, sort_idx) + + ctx.save_for_backward(p1, p2, lengths1, lengths2, idx) + ctx.mark_non_differentiable(idx) + ctx.norm = norm + return dists, idx + + @staticmethod + @once_differentiable + def backward(ctx, grad_dists, grad_idx): + p1, p2, lengths1, lengths2, idx = ctx.saved_tensors + norm = ctx.norm + # TODO(gkioxari) Change cast to floats once we add support for doubles. + if not (grad_dists.dtype == torch.float32): + grad_dists = grad_dists.float() + if not (p1.dtype == torch.float32): + p1 = p1.float() + if not (p2.dtype == torch.float32): + p2 = p2.float() + grad_p1, grad_p2 = knn_points_backward( + p1, p2, lengths1, lengths2, idx, norm, grad_dists + ) + return grad_p1, grad_p2, None, None, None, None, None, None + + +def knn_points( + p1: torch.Tensor, + p2: torch.Tensor, + lengths1: Union[torch.Tensor, None] = None, + lengths2: Union[torch.Tensor, None] = None, + norm: int = 2, + K: int = 1, + version: int = -1, + return_nn: bool = False, + return_sorted: bool = True, +) -> _KNN: + """ + K-Nearest neighbors on point clouds. + + Args: + p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each + containing up to P1 points of dimension D. + p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each + containing up to P2 points of dimension D. + lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the + length of each pointcloud in p1. Or None to indicate that every cloud has + length P1. + lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the + length of each pointcloud in p2. Or None to indicate that every cloud has + length P2. + norm: Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2. + K: Integer giving the number of nearest neighbors to return. + version: Which KNN implementation to use in the backend. If version=-1, + the correct implementation is selected based on the shapes of the inputs. + return_nn: If set to True returns the K nearest neighbors in p2 for each point in p1. + return_sorted: (bool) whether to return the nearest neighbors sorted in + ascending order of distance. + + Returns: + dists: Tensor of shape (N, P1, K) giving the squared distances to + the nearest neighbors. 
This is padded with zeros both where a cloud in p2 + has fewer than K points and where a cloud in p1 has fewer than P1 points. + + idx: LongTensor of shape (N, P1, K) giving the indices of the + K nearest neighbors from points in p1 to points in p2. + Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest + neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud + in p2 has fewer than K points and where a cloud in p1 has fewer than P1 + points. + + nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for + each point in p1. Concretely, `p2_nn[n, i, k]` gives the k-th nearest neighbor + for `p1[n, i]`. Returned if `return_nn` is True. + The nearest neighbors are collected using `knn_gather` + + .. code-block:: + + p2_nn = knn_gather(p2, p1_idx, lengths2) + + which is a helper function that allows indexing any tensor of shape (N, P2, U) with + the indices `p1_idx` returned by `knn_points`. The output is a tensor + of shape (N, P1, K, U). + + """ + if p1.shape[0] != p2.shape[0]: + raise ValueError("pts1 and pts2 must have the same batch dimension.") + if p1.shape[2] != p2.shape[2]: + raise ValueError("pts1 and pts2 must have the same point dimension.") + + p1 = p1.contiguous() + p2 = p2.contiguous() + + P1 = p1.shape[1] + P2 = p2.shape[1] + + if lengths1 is None: + lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device) + if lengths2 is None: + lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device) + + p1_dists, p1_idx = _knn_points.apply( + p1, p2, lengths1, lengths2, K, version, norm, return_sorted + ) + + p2_nn = None + if return_nn: + p2_nn = knn_gather(p2, p1_idx, lengths2) + + return _KNN(dists=p1_dists, idx=p1_idx, knn=p2_nn if return_nn else None) + + +def knn_gather( + x: torch.Tensor, idx: torch.Tensor, lengths: Union[torch.Tensor, None] = None +): + """ + A helper function for knn that allows indexing a tensor x with the indices `idx` + returned by `knn_points`. + + For example, if `dists, idx = knn_points(p, x, lengths_p, lengths, K)` + where p is a tensor of shape (N, L, D) and x a tensor of shape (N, M, D), + then one can compute the K nearest neighbors of p with `p_nn = knn_gather(x, idx, lengths)`. + It can also be applied for any tensor x of shape (N, M, U) where U != D. + + Args: + x: Tensor of shape (N, M, U) containing U-dimensional features to + be gathered. + idx: LongTensor of shape (N, L, K) giving the indices returned by `knn_points`. + lengths: LongTensor of shape (N,) of values in the range [0, M], giving the + length of each example in the batch in x. Or None to indicate that every + example has length M. + Returns: + x_out: Tensor of shape (N, L, K, U) resulting from gathering the elements of x + with idx, s.t. `x_out[n, l, k] = x[n, idx[n, l, k]]`. + If `k > lengths[n]` then `x_out[n, l, k]` is filled with 0.0. 
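+    Example (a minimal illustrative sketch; the shapes below are arbitrary):
+
+    .. code-block::
+
+        x = torch.randn(2, 128, 32)            # (N, M, U) features to gather
+        p = torch.randn(2, 64, 32)             # (N, L, D) query points, here U == D
+        dists, idx, _ = knn_points(p, x, K=3)  # idx has shape (2, 64, 3)
+        x_nn = knn_gather(x, idx)              # (N, L, K, U) == (2, 64, 3, 32)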
+ """ + N, M, U = x.shape + _N, L, K = idx.shape + + if N != _N: + raise ValueError("x and idx must have same batch dimension.") + + if lengths is None: + lengths = torch.full((x.shape[0],), M, dtype=torch.int64, device=x.device) + + idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, U) + # idx_expanded has shape [N, L, K, U] + + x_out = x[:, :, None].expand(-1, -1, K, -1).gather(1, idx_expanded) + # p2_nn has shape [N, L, K, U] + + needs_mask = lengths.min() < K + if needs_mask: + # mask has shape [N, K], true where idx is irrelevant because + # there is less number of points in p2 than K + mask = lengths[:, None] <= torch.arange(K, device=x.device)[None] + + # expand mask to shape [N, L, K, U] + mask = mask[:, None].expand(-1, L, -1) + mask = mask[:, :, :, None].expand(-1, -1, -1, U) + x_out[mask] = 0.0 + + return x_out diff --git a/unidepth/ops/knn/setup.py b/unidepth/ops/knn/setup.py new file mode 100644 index 0000000..6e7f7dd --- /dev/null +++ b/unidepth/ops/knn/setup.py @@ -0,0 +1,61 @@ +import glob +import os + +import torch +from setuptools import find_packages, setup +from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension + +requirements = ["torch", "torchvision"] + + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "src") + + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cuda = glob.glob(os.path.join(extensions_dir, "*.cu")) + + sources = main_file + source_cpu + extension = CppExtension + extra_compile_args = {"cxx": ["-O3"]} + define_macros = [] + + if torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension + sources += source_cuda + define_macros += [("WITH_CUDA", None)] + extra_compile_args["nvcc"] = [ + "-O3", + ] + else: + raise NotImplementedError("Cuda is not available") + + sources = list(set([os.path.join(extensions_dir, s) for s in sources])) + include_dirs = [extensions_dir] + ext_modules = [ + extension( + "KNN", + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + + return ext_modules + + +setup( + name="KNN", + version="0.1", + author="Luigi Piccinelli", + ext_modules=get_extensions(), + packages=find_packages( + exclude=( + "configs", + "tests", + ) + ), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) diff --git a/unidepth/ops/knn/src/knn.cu b/unidepth/ops/knn/src/knn.cu new file mode 100644 index 0000000..ba0732d --- /dev/null +++ b/unidepth/ops/knn/src/knn.cu @@ -0,0 +1,587 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +#include "utils/dispatch.cuh" +#include "utils/mink.cuh" + +// A chunk of work is blocksize-many points of P1. +// The number of potential chunks to do is N*(1+(P1-1)/blocksize) +// call (1+(P1-1)/blocksize) chunks_per_cloud +// These chunks are divided among the gridSize-many blocks. +// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc . +// In chunk i, we work on cloud i/chunks_per_cloud on points starting from +// blocksize*(i%chunks_per_cloud). 
+ +template +__global__ void KNearestNeighborKernelV0( + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + const int64_t* __restrict__ lengths1, + const int64_t* __restrict__ lengths2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2, + const size_t D, + const size_t K, + const size_t norm) { + // Store both dists and indices for knn in global memory. + const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x); + const int64_t chunks_to_do = N * chunks_per_cloud; + for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { + const int64_t n = chunk / chunks_per_cloud; + const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); + int64_t p1 = start_point + threadIdx.x; + if (p1 >= lengths1[n]) + continue; + int offset = n * P1 * K + p1 * K; + int64_t length2 = lengths2[n]; + MinK mink(dists + offset, idxs + offset, K); + for (int p2 = 0; p2 < length2; ++p2) { + // Find the distance between points1[n, p1] and points[n, p2] + scalar_t dist = 0; + for (int d = 0; d < D; ++d) { + scalar_t coord1 = points1[n * P1 * D + p1 * D + d]; + scalar_t coord2 = points2[n * P2 * D + p2 * D + d]; + scalar_t diff = coord1 - coord2; + scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff); + dist += norm_diff; + } + mink.add(dist, p2); + } + } +} + +template +__global__ void KNearestNeighborKernelV1( + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + const int64_t* __restrict__ lengths1, + const int64_t* __restrict__ lengths2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2, + const size_t K, + const size_t norm) { + // Same idea as the previous version, but hoist D into a template argument + // so we can cache the current point in a thread-local array. We still store + // the current best K dists and indices in global memory, so this should work + // for very large K and fairly large D. + scalar_t cur_point[D]; + const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x); + const int64_t chunks_to_do = N * chunks_per_cloud; + for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { + const int64_t n = chunk / chunks_per_cloud; + const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); + int64_t p1 = start_point + threadIdx.x; + if (p1 >= lengths1[n]) + continue; + for (int d = 0; d < D; ++d) { + cur_point[d] = points1[n * P1 * D + p1 * D + d]; + } + int offset = n * P1 * K + p1 * K; + int64_t length2 = lengths2[n]; + MinK mink(dists + offset, idxs + offset, K); + for (int p2 = 0; p2 < length2; ++p2) { + // Find the distance between cur_point and points[n, p2] + scalar_t dist = 0; + for (int d = 0; d < D; ++d) { + scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d]; + scalar_t norm_diff = (norm == 2) ? 
(diff * diff) : abs(diff); + dist += norm_diff; + } + mink.add(dist, p2); + } + } +} + +// This is a shim functor to allow us to dispatch using DispatchKernel1D +template +struct KNearestNeighborV1Functor { + static void run( + size_t blocks, + size_t threads, + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + const int64_t* __restrict__ lengths1, + const int64_t* __restrict__ lengths2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2, + const size_t K, + const size_t norm) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + KNearestNeighborKernelV1<<>>( + points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm); + } +}; + +template +__global__ void KNearestNeighborKernelV2( + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + const int64_t* __restrict__ lengths1, + const int64_t* __restrict__ lengths2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const int64_t N, + const int64_t P1, + const int64_t P2, + const size_t norm) { + // Same general implementation as V2, but also hoist K into a template arg. + scalar_t cur_point[D]; + scalar_t min_dists[K]; + int min_idxs[K]; + const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x); + const int64_t chunks_to_do = N * chunks_per_cloud; + for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { + const int64_t n = chunk / chunks_per_cloud; + const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); + int64_t p1 = start_point + threadIdx.x; + if (p1 >= lengths1[n]) + continue; + for (int d = 0; d < D; ++d) { + cur_point[d] = points1[n * P1 * D + p1 * D + d]; + } + int64_t length2 = lengths2[n]; + MinK mink(min_dists, min_idxs, K); + for (int p2 = 0; p2 < length2; ++p2) { + scalar_t dist = 0; + for (int d = 0; d < D; ++d) { + int offset = n * P2 * D + p2 * D + d; + scalar_t diff = cur_point[d] - points2[offset]; + scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff); + dist += norm_diff; + } + mink.add(dist, p2); + } + for (int k = 0; k < mink.size(); ++k) { + idxs[n * P1 * K + p1 * K + k] = min_idxs[k]; + dists[n * P1 * K + p1 * K + k] = min_dists[k]; + } + } +} + +// This is a shim so we can dispatch using DispatchKernel2D +template +struct KNearestNeighborKernelV2Functor { + static void run( + size_t blocks, + size_t threads, + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + const int64_t* __restrict__ lengths1, + const int64_t* __restrict__ lengths2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const int64_t N, + const int64_t P1, + const int64_t P2, + const size_t norm) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + KNearestNeighborKernelV2<<>>( + points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm); + } +}; + +template +__global__ void KNearestNeighborKernelV3( + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + const int64_t* __restrict__ lengths1, + const int64_t* __restrict__ lengths2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2, + const size_t norm) { + // Same idea as V2, but use register indexing for thread-local arrays. + // Enabling sorting for this version leads to huge slowdowns; I suspect + // that it forces min_dists into local memory rather than registers. + // As a result this version is always unsorted. 
+ scalar_t cur_point[D]; + scalar_t min_dists[K]; + int min_idxs[K]; + const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x); + const int64_t chunks_to_do = N * chunks_per_cloud; + for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { + const int64_t n = chunk / chunks_per_cloud; + const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); + int64_t p1 = start_point + threadIdx.x; + if (p1 >= lengths1[n]) + continue; + for (int d = 0; d < D; ++d) { + cur_point[d] = points1[n * P1 * D + p1 * D + d]; + } + int64_t length2 = lengths2[n]; + RegisterMinK mink(min_dists, min_idxs); + for (int p2 = 0; p2 < length2; ++p2) { + scalar_t dist = 0; + for (int d = 0; d < D; ++d) { + int offset = n * P2 * D + p2 * D + d; + scalar_t diff = cur_point[d] - points2[offset]; + scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff); + dist += norm_diff; + } + mink.add(dist, p2); + } + for (int k = 0; k < mink.size(); ++k) { + idxs[n * P1 * K + p1 * K + k] = min_idxs[k]; + dists[n * P1 * K + p1 * K + k] = min_dists[k]; + } + } +} + +// This is a shim so we can dispatch using DispatchKernel2D +template +struct KNearestNeighborKernelV3Functor { + static void run( + size_t blocks, + size_t threads, + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + const int64_t* __restrict__ lengths1, + const int64_t* __restrict__ lengths2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2, + const size_t norm) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + KNearestNeighborKernelV3<<>>( + points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm); + } +}; + +constexpr int V1_MIN_D = 1; +constexpr int V1_MAX_D = 32; + +constexpr int V2_MIN_D = 1; +constexpr int V2_MAX_D = 8; +constexpr int V2_MIN_K = 1; +constexpr int V2_MAX_K = 32; + +constexpr int V3_MIN_D = 1; +constexpr int V3_MAX_D = 8; +constexpr int V3_MIN_K = 1; +constexpr int V3_MAX_K = 4; + +bool InBounds(const int64_t min, const int64_t x, const int64_t max) { + return min <= x && x <= max; +} + +bool KnnCheckVersion(int version, const int64_t D, const int64_t K) { + if (version == 0) { + return true; + } else if (version == 1) { + return InBounds(V1_MIN_D, D, V1_MAX_D); + } else if (version == 2) { + return InBounds(V2_MIN_D, D, V2_MAX_D) && InBounds(V2_MIN_K, K, V2_MAX_K); + } else if (version == 3) { + return InBounds(V3_MIN_D, D, V3_MAX_D) && InBounds(V3_MIN_K, K, V3_MAX_K); + } + return false; +} + +int ChooseVersion(const int64_t D, const int64_t K) { + for (int version = 3; version >= 1; version--) { + if (KnnCheckVersion(version, D, K)) { + return version; + } + } + return 0; +} + +std::tuple KNearestNeighborIdxCuda( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const int norm, + const int K, + int version) { + // Check inputs are on the same device + at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2}, + lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4}; + at::CheckedFrom c = "KNearestNeighborIdxCuda"; + at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t}); + at::checkAllSameType(c, {p1_t, p2_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(p1.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const auto N = p1.size(0); + const auto P1 = p1.size(1); + const auto P2 = p2.size(1); + const auto D = p2.size(2); + const 
int64_t K_64 = K; + + TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2."); + + TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension"); + auto long_dtype = lengths1.options().dtype(at::kLong); + auto idxs = at::zeros({N, P1, K}, long_dtype); + auto dists = at::zeros({N, P1, K}, p1.options()); + + if (idxs.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(idxs, dists); + } + + if (version < 0) { + version = ChooseVersion(D, K); + } else if (!KnnCheckVersion(version, D, K)) { + int new_version = ChooseVersion(D, K); + std::cout << "WARNING: Requested KNN version " << version + << " is not compatible with D = " << D << "; K = " << K + << ". Falling back to version = " << new_version << std::endl; + version = new_version; + } + + // At this point we should have a valid version no matter what data the user + // gave us. But we can check once more to be sure; however this time + // assert fail since failing at this point means we have a bug in our version + // selection or checking code. + AT_ASSERTM(KnnCheckVersion(version, D, K), "Invalid version"); + + const size_t threads = 256; + const size_t blocks = 256; + if (version == 0) { + AT_DISPATCH_FLOATING_TYPES( + p1.scalar_type(), "knn_kernel_cuda", ([&] { + KNearestNeighborKernelV0<<>>( + p1.contiguous().data_ptr(), + p2.contiguous().data_ptr(), + lengths1.contiguous().data_ptr(), + lengths2.contiguous().data_ptr(), + dists.data_ptr(), + idxs.data_ptr(), + N, + P1, + P2, + D, + K, + norm); + })); + } else if (version == 1) { + AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] { + DispatchKernel1D< + KNearestNeighborV1Functor, + scalar_t, + V1_MIN_D, + V1_MAX_D>( + D, + blocks, + threads, + p1.contiguous().data_ptr(), + p2.contiguous().data_ptr(), + lengths1.contiguous().data_ptr(), + lengths2.contiguous().data_ptr(), + dists.data_ptr(), + idxs.data_ptr(), + N, + P1, + P2, + K, + norm); + })); + } else if (version == 2) { + AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] { + DispatchKernel2D< + KNearestNeighborKernelV2Functor, + scalar_t, + V2_MIN_D, + V2_MAX_D, + V2_MIN_K, + V2_MAX_K>( + D, + K_64, + blocks, + threads, + p1.contiguous().data_ptr(), + p2.contiguous().data_ptr(), + lengths1.contiguous().data_ptr(), + lengths2.contiguous().data_ptr(), + dists.data_ptr(), + idxs.data_ptr(), + N, + P1, + P2, + norm); + })); + } else if (version == 3) { + AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] { + DispatchKernel2D< + KNearestNeighborKernelV3Functor, + scalar_t, + V3_MIN_D, + V3_MAX_D, + V3_MIN_K, + V3_MAX_K>( + D, + K_64, + blocks, + threads, + p1.contiguous().data_ptr(), + p2.contiguous().data_ptr(), + lengths1.contiguous().data_ptr(), + lengths2.contiguous().data_ptr(), + dists.data_ptr(), + idxs.data_ptr(), + N, + P1, + P2, + norm); + })); + } + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(idxs, dists); +} + +// ------------------------------------------------------------- // +// Backward Operators // +// ------------------------------------------------------------- // + +// TODO(gkioxari) support all data types once AtomicAdd supports doubles. +// Currently, support is for floats only. 
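+// Summary of the backward kernel below: it walks the flattened (n, p1, k, d) index
+// space with a grid-stride loop and accumulates grad_dists[n, p1, k] * d(dist)/d(p).
+// For the L2 norm, d(dist)/d(p1_d) = 2 * (p1_d - p2_d) and the p2 gradient is its
+// negation; for the L1 norm only the sign of (p1_d - p2_d) matters. Contributions
+// are scattered into grad_p1 and grad_p2 with atomicAdd, which is why the op is
+// flagged as nondeterministic.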
+__global__ void KNearestNeighborBackwardKernel( + const float* __restrict__ p1, // (N, P1, D) + const float* __restrict__ p2, // (N, P2, D) + const int64_t* __restrict__ lengths1, // (N,) + const int64_t* __restrict__ lengths2, // (N,) + const int64_t* __restrict__ idxs, // (N, P1, K) + const float* __restrict__ grad_dists, // (N, P1, K) + float* __restrict__ grad_p1, // (N, P1, D) + float* __restrict__ grad_p2, // (N, P2, D) + const size_t N, + const size_t P1, + const size_t P2, + const size_t K, + const size_t D, + const size_t norm) { + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = gridDim.x * blockDim.x; + + for (size_t i = tid; i < N * P1 * K * D; i += stride) { + const size_t n = i / (P1 * K * D); // batch index + size_t rem = i % (P1 * K * D); + const size_t p1_idx = rem / (K * D); // index of point in p1 + rem = rem % (K * D); + const size_t k = rem / D; // k-th nearest neighbor + const size_t d = rem % D; // d-th dimension in the feature vector + + const size_t num1 = lengths1[n]; // number of valid points in p1 in batch + const size_t num2 = lengths2[n]; // number of valid points in p2 in batch + if ((p1_idx < num1) && (k < num2)) { + const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k]; + // index of point in p2 corresponding to the k-th nearest neighbor + const int64_t p2_idx = idxs[n * P1 * K + p1_idx * K + k]; + // If the index is the pad value of -1 then ignore it + if (p2_idx == -1) { + continue; + } + float diff = 0.0; + if (norm == 1) { + float sign = + (p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d]) + ? 1.0 + : -1.0; + diff = grad_dist * sign; + } else { // norm is 2 + diff = 2.0 * grad_dist * + (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]); + } + atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff); + atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff); + } + } +} + +std::tuple KNearestNeighborBackwardCuda( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const at::Tensor& idxs, + int norm, + const at::Tensor& grad_dists) { + // Check inputs are on the same device + at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2}, + lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4}, + idxs_t{idxs, "idxs", 5}, grad_dists_t{grad_dists, "grad_dists", 6}; + at::CheckedFrom c = "KNearestNeighborBackwardCuda"; + at::checkAllSameGPU( + c, {p1_t, p2_t, lengths1_t, lengths2_t, idxs_t, grad_dists_t}); + at::checkAllSameType(c, {p1_t, p2_t, grad_dists_t}); + + // This is nondeterministic because atomicAdd + at::globalContext().alertNotDeterministic("KNearestNeighborBackwardCuda"); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(p1.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const auto N = p1.size(0); + const auto P1 = p1.size(1); + const auto P2 = p2.size(1); + const auto D = p2.size(2); + const auto K = idxs.size(2); + + TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension"); + TORCH_CHECK(idxs.size(0) == N, "KNN idxs must have the same batch dimension"); + TORCH_CHECK( + idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1"); + TORCH_CHECK(grad_dists.size(0) == N); + TORCH_CHECK(grad_dists.size(1) == P1); + TORCH_CHECK(grad_dists.size(2) == K); + + auto grad_p1 = at::zeros({N, P1, D}, p1.options()); + auto grad_p2 = at::zeros({N, P2, D}, p2.options()); + + if (grad_p1.numel() 
== 0 || grad_p2.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_p1, grad_p2); + } + + const int blocks = 64; + const int threads = 512; + + KNearestNeighborBackwardKernel<<>>( + p1.contiguous().data_ptr(), + p2.contiguous().data_ptr(), + lengths1.contiguous().data_ptr(), + lengths2.contiguous().data_ptr(), + idxs.contiguous().data_ptr(), + grad_dists.contiguous().data_ptr(), + grad_p1.data_ptr(), + grad_p2.data_ptr(), + N, + P1, + P2, + K, + D, + norm); + + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_p1, grad_p2); +} \ No newline at end of file diff --git a/unidepth/ops/knn/src/knn.h b/unidepth/ops/knn/src/knn.h new file mode 100644 index 0000000..e43a023 --- /dev/null +++ b/unidepth/ops/knn/src/knn.h @@ -0,0 +1,157 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include "utils/pytorch3d_cutils.h" + +// Compute indices of K nearest neighbors in pointcloud p2 to points +// in pointcloud p1. +// +// Args: +// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each +// containing P1 points of dimension D. +// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each +// containing P2 points of dimension D. +// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud. +// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud. +// norm: int specifying the norm for the distance (1 for L1, 2 for L2) +// K: int giving the number of nearest points to return. +// version: Integer telling which implementation to use. +// +// Returns: +// p1_neighbor_idx: LongTensor of shape (N, P1, K), where +// p1_neighbor_idx[n, i, k] = j means that the kth nearest +// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j]. +// It is padded with zeros so that it can be used easily in a later +// gather() operation. +// +// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared +// distance from each point p1[n, p, :] to its K neighbors +// p2[n, p1_neighbor_idx[n, p, k], :]. + +// CPU implementation. +std::tuple KNearestNeighborIdxCpu( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const int norm, + const int K); + +// CUDA implementation +std::tuple KNearestNeighborIdxCuda( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const int norm, + const int K, + const int version); + +// Implementation which is exposed. +std::tuple KNearestNeighborIdx( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const int norm, + const int K, + const int version) { + if (p1.is_cuda() || p2.is_cuda()) { +#ifdef WITH_CUDA + CHECK_CUDA(p1); + CHECK_CUDA(p2); + return KNearestNeighborIdxCuda( + p1, p2, lengths1, lengths2, norm, K, version); +#else + AT_ERROR("Not compiled with GPU support."); +#endif + } + return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K); +} + +// Compute gradients with respect to p1 and p2 +// +// Args: +// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each +// containing P1 points of dimension D. +// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each +// containing P2 points of dimension D. 
+// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud. +// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud. +// p1_neighbor_idx: LongTensor of shape (N, P1, K), where +// p1_neighbor_idx[n, i, k] = j means that the kth nearest +// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j]. +// It is padded with zeros so that it can be used easily in a later +// gather() operation. This is computed from the forward pass. +// norm: int specifying the norm for the distance (1 for L1, 2 for L2) +// grad_dists: FLoatTensor of shape (N, P1, K) which contains the input +// gradients. +// +// Returns: +// grad_p1: FloatTensor of shape (N, P1, D) containing the output gradients +// wrt p1. +// grad_p2: FloatTensor of shape (N, P2, D) containing the output gradients +// wrt p2. + +// CPU implementation. +std::tuple KNearestNeighborBackwardCpu( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const at::Tensor& idxs, + const int norm, + const at::Tensor& grad_dists); + +// CUDA implementation +std::tuple KNearestNeighborBackwardCuda( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const at::Tensor& idxs, + const int norm, + const at::Tensor& grad_dists); + +// Implementation which is exposed. +std::tuple KNearestNeighborBackward( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const at::Tensor& idxs, + const int norm, + const at::Tensor& grad_dists) { + if (p1.is_cuda() || p2.is_cuda()) { +#ifdef WITH_CUDA + CHECK_CUDA(p1); + CHECK_CUDA(p2); + return KNearestNeighborBackwardCuda( + p1, p2, lengths1, lengths2, idxs, norm, grad_dists); +#else + AT_ERROR("Not compiled with GPU support."); +#endif + } + return KNearestNeighborBackwardCpu( + p1, p2, lengths1, lengths2, idxs, norm, grad_dists); +} + +// Utility to check whether a KNN version can be used. +// +// Args: +// version: Integer in the range 0 <= version <= 3 indicating one of our +// KNN implementations. +// D: Number of dimensions for the input and query point clouds +// K: Number of neighbors to be found +// +// Returns: +// Whether the indicated KNN version can be used. +bool KnnCheckVersion(int version, const int64_t D, const int64_t K); \ No newline at end of file diff --git a/unidepth/ops/knn/src/knn_cpu.cpp b/unidepth/ops/knn/src/knn_cpu.cpp new file mode 100644 index 0000000..694ab11 --- /dev/null +++ b/unidepth/ops/knn/src/knn_cpu.cpp @@ -0,0 +1,128 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +std::tuple KNearestNeighborIdxCpu( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const int norm, + const int K) { + const int N = p1.size(0); + const int P1 = p1.size(1); + const int D = p1.size(2); + + auto long_opts = lengths1.options().dtype(torch::kInt64); + torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts); + torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options()); + + auto p1_a = p1.accessor(); + auto p2_a = p2.accessor(); + auto lengths1_a = lengths1.accessor(); + auto lengths2_a = lengths2.accessor(); + auto idxs_a = idxs.accessor(); + auto dists_a = dists.accessor(); + + for (int n = 0; n < N; ++n) { + const int64_t length1 = lengths1_a[n]; + const int64_t length2 = lengths2_a[n]; + for (int64_t i1 = 0; i1 < length1; ++i1) { + // Use a priority queue to store (distance, index) tuples. + std::priority_queue> q; + for (int64_t i2 = 0; i2 < length2; ++i2) { + float dist = 0; + for (int d = 0; d < D; ++d) { + float diff = p1_a[n][i1][d] - p2_a[n][i2][d]; + if (norm == 1) { + dist += abs(diff); + } else { // norm is 2 (default) + dist += diff * diff; + } + } + int size = static_cast(q.size()); + if (size < K || dist < std::get<0>(q.top())) { + q.emplace(dist, i2); + if (size >= K) { + q.pop(); + } + } + } + while (!q.empty()) { + auto t = q.top(); + q.pop(); + const int k = q.size(); + dists_a[n][i1][k] = std::get<0>(t); + idxs_a[n][i1][k] = std::get<1>(t); + } + } + } + return std::make_tuple(idxs, dists); +} + +// ------------------------------------------------------------- // +// Backward Operators // +// ------------------------------------------------------------- // + +std::tuple KNearestNeighborBackwardCpu( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const at::Tensor& idxs, + const int norm, + const at::Tensor& grad_dists) { + const int N = p1.size(0); + const int P1 = p1.size(1); + const int D = p1.size(2); + const int P2 = p2.size(1); + const int K = idxs.size(2); + + torch::Tensor grad_p1 = torch::full({N, P1, D}, 0, p1.options()); + torch::Tensor grad_p2 = torch::full({N, P2, D}, 0, p2.options()); + + auto p1_a = p1.accessor(); + auto p2_a = p2.accessor(); + auto lengths1_a = lengths1.accessor(); + auto lengths2_a = lengths2.accessor(); + auto idxs_a = idxs.accessor(); + auto grad_dists_a = grad_dists.accessor(); + auto grad_p1_a = grad_p1.accessor(); + auto grad_p2_a = grad_p2.accessor(); + + for (int n = 0; n < N; ++n) { + const int64_t length1 = lengths1_a[n]; + int64_t length2 = lengths2_a[n]; + length2 = (length2 < K) ? length2 : K; + for (int64_t i1 = 0; i1 < length1; ++i1) { + for (int64_t k = 0; k < length2; ++k) { + const int64_t i2 = idxs_a[n][i1][k]; + // If the index is the pad value of -1 then ignore it + if (i2 == -1) { + continue; + } + for (int64_t d = 0; d < D; ++d) { + float diff = 0.0; + if (norm == 1) { + float sign = (p1_a[n][i1][d] > p2_a[n][i2][d]) ? 
1.0 : -1.0; + diff = grad_dists_a[n][i1][k] * sign; + } else { // norm is 2 (default) + diff = 2.0f * grad_dists_a[n][i1][k] * + (p1_a[n][i1][d] - p2_a[n][i2][d]); + } + grad_p1_a[n][i1][d] += diff; + grad_p2_a[n][i2][d] += -1.0f * diff; + } + } + } + } + return std::make_tuple(grad_p1, grad_p2); +} \ No newline at end of file diff --git a/unidepth/ops/knn/src/knn_ext.cpp b/unidepth/ops/knn/src/knn_ext.cpp new file mode 100644 index 0000000..f2fc9b4 --- /dev/null +++ b/unidepth/ops/knn/src/knn_ext.cpp @@ -0,0 +1,10 @@ +#include +#include "knn.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +#ifdef WITH_CUDA + m.def("knn_check_version", &KnnCheckVersion); +#endif + m.def("knn_points_idx", &KNearestNeighborIdx); + m.def("knn_points_backward", &KNearestNeighborBackward); +} \ No newline at end of file diff --git a/unidepth/ops/knn/src/utils/dispatch.cuh b/unidepth/ops/knn/src/utils/dispatch.cuh new file mode 100644 index 0000000..af197b4 --- /dev/null +++ b/unidepth/ops/knn/src/utils/dispatch.cuh @@ -0,0 +1,357 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// This file provides utilities for dispatching to specialized versions of +// functions. This is especially useful for CUDA kernels, since specializing +// them to particular input sizes can often allow the compiler to unroll loops +// and place arrays into registers, which can give huge performance speedups. +// +// As an example, suppose we have the following function which is specialized +// based on a compile-time int64_t value: +// +// template +// struct SquareOffset { +// static void run(T y) { +// T val = x * x + y; +// std::cout << val << std::endl; +// } +// } +// +// This function takes one compile-time argument x, and one run-time argument y. +// We might want to compile specialized versions of this for x=0, x=1, etc and +// then dispatch to the correct one based on the runtime value of x. +// One simple way to achieve this is with a lookup table: +// +// template +// void DispatchSquareOffset(const int64_t x, T y) { +// if (x == 0) { +// SquareOffset::run(y); +// } else if (x == 1) { +// SquareOffset::run(y); +// } else if (x == 2) { +// SquareOffset::run(y); +// } +// } +// +// This function takes both x and y as run-time arguments, and dispatches to +// different specialized versions of SquareOffset based on the run-time value +// of x. This works, but it's tedious and error-prone. If we want to change the +// set of x values for which we provide compile-time specializations, then we +// will need to do a lot of tedius editing of the dispatch function. Also, if we +// want to provide compile-time specializations for another function other than +// SquareOffset, we will need to duplicate the entire lookup table. +// +// To solve these problems, we can use the DispatchKernel1D function provided by +// this file instead: +// +// template +// void DispatchSquareOffset(const int64_t x, T y) { +// constexpr int64_t xmin = 0; +// constexpr int64_t xmax = 2; +// DispatchKernel1D(x, y); +// } +// +// DispatchKernel1D uses template metaprogramming to compile specialized +// versions of SquareOffset for all values of x with xmin <= x <= xmax, and +// then dispatches to the correct one based on the run-time value of x. 
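
Note: the binding exposed above as knn_points_idx follows the contract documented in knn.h: for each point in p1 it returns the indices and squared L2 distances of its K nearest neighbours in p2, padded to shape (N, P1, K). The following brute-force pure-PyTorch sketch of that contract is illustrative only (it is not part of the repository and ignores the lengths1/lengths2 padding, the L1-norm option, and the kernel version selection):

import torch

def knn_reference(p1: torch.Tensor, p2: torch.Tensor, K: int):
    # p1: (N, P1, D), p2: (N, P2, D); returns idxs (N, P1, K) and squared
    # L2 distances (N, P1, K), mirroring the output described in knn.h.
    d2 = torch.cdist(p1, p2, p=2) ** 2              # (N, P1, P2) pairwise squared distances
    dists, idxs = d2.topk(K, dim=-1, largest=False)  # K smallest per query point
    return idxs, dists

p1, p2 = torch.randn(2, 128, 3), torch.randn(2, 256, 3)
idxs, dists = knn_reference(p1, p2, K=8)             # both of shape (2, 128, 8)
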
If we +// want to change the range of x values for which SquareOffset is specialized +// at compile-time, then all we have to do is change the values of the +// compile-time constants xmin and xmax. +// +// This file also allows us to similarly dispatch functions that depend on two +// compile-time int64_t values, using the DispatchKernel2D function like this: +// +// template +// struct Sum { +// static void run(T z, T w) { +// T val = x + y + z + w; +// std::cout << val << std::endl; +// } +// } +// +// template +// void DispatchSum(const int64_t x, const int64_t y, int z, int w) { +// constexpr int64_t xmin = 1; +// constexpr int64_t xmax = 3; +// constexpr int64_t ymin = 2; +// constexpr int64_t ymax = 5; +// DispatchKernel2D(x, y, z, w); +// } +// +// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to +// compile specialized versions of sum for all values of (x, y) with +// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct +// specialized version based on the runtime values of x and y. + +// Define some helper structs in an anonymous namespace. +namespace { + +// 1D dispatch: general case. +// Kernel is the function we want to dispatch to; it should take a typename and +// an int64_t as template args, and it should define a static void function +// run which takes any number of arguments of any type. +// In order to dispatch, we will take an additional template argument curN, +// and increment it via template recursion until it is equal to the run-time +// argument N. +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t curN, + typename... Args> +struct DispatchKernelHelper1D { + static void run(const int64_t N, Args... args) { + if (curN == N) { + // The compile-time value curN is equal to the run-time value N, so we + // can dispatch to the run method of the Kernel. + Kernel::run(args...); + } else if (curN < N) { + // Increment curN via template recursion + DispatchKernelHelper1D::run( + N, args...); + } + // We shouldn't get here -- throw an error? + } +}; + +// 1D dispatch: Specialization when curN == maxN +// We need this base case to avoid infinite template recursion. +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + typename... Args> +struct DispatchKernelHelper1D { + static void run(const int64_t N, Args... args) { + if (N == maxN) { + Kernel::run(args...); + } + // We shouldn't get here -- throw an error? + } +}; + +// 2D dispatch, general case. +// This is similar to the 1D case: we take additional template args curN and +// curM, and increment them via template recursion until they are equal to +// the run-time values of N and M, at which point we dispatch to the run +// method of the kernel. +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t curN, + int64_t minM, + int64_t maxM, + int64_t curM, + typename... Args> +struct DispatchKernelHelper2D { + static void run(const int64_t N, const int64_t M, Args... args) { + if (curN == N && curM == M) { + Kernel::run(args...); + } else if (curN < N && curM < M) { + // Increment both curN and curM. This isn't strictly necessary; we could + // just increment one or the other at each step. But this helps to cut + // on the number of recursive calls we make. 
+ DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN + 1, + minM, + maxM, + curM + 1, + Args...>::run(N, M, args...); + } else if (curN < N) { + // Increment curN only + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN + 1, + minM, + maxM, + curM, + Args...>::run(N, M, args...); + } else if (curM < M) { + // Increment curM only + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN, + minM, + maxM, + curM + 1, + Args...>::run(N, M, args...); + } + } +}; + +// 2D dispatch, specialization for curN == maxN +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t minM, + int64_t maxM, + int64_t curM, + typename... Args> +struct DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + maxN, + minM, + maxM, + curM, + Args...> { + static void run(const int64_t N, const int64_t M, Args... args) { + if (maxN == N && curM == M) { + Kernel::run(args...); + } else if (curM < maxM) { + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + maxN, + minM, + maxM, + curM + 1, + Args...>::run(N, M, args...); + } + // We should not get here -- throw an error? + } +}; + +// 2D dispatch, specialization for curM == maxM +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t curN, + int64_t minM, + int64_t maxM, + typename... Args> +struct DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN, + minM, + maxM, + maxM, + Args...> { + static void run(const int64_t N, const int64_t M, Args... args) { + if (curN == N && maxM == M) { + Kernel::run(args...); + } else if (curN < maxN) { + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN + 1, + minM, + maxM, + maxM, + Args...>::run(N, M, args...); + } + // We should not get here -- throw an error? + } +}; + +// 2D dispatch, specialization for curN == maxN, curM == maxM +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t minM, + int64_t maxM, + typename... Args> +struct DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + maxN, + minM, + maxM, + maxM, + Args...> { + static void run(const int64_t N, const int64_t M, Args... args) { + if (maxN == N && maxM == M) { + Kernel::run(args...); + } + // We should not get here -- throw an error? + } +}; + +} // namespace + +// This is the function we expect users to call to dispatch to 1D functions +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + typename... Args> +void DispatchKernel1D(const int64_t N, Args... args) { + if (minN <= N && N <= maxN) { + // Kick off the template recursion by calling the Helper with curN = minN + DispatchKernelHelper1D::run( + N, args...); + } + // Maybe throw an error if we tried to dispatch outside the allowed range? +} + +// This is the function we expect users to call to dispatch to 2D functions +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t minM, + int64_t maxM, + typename... Args> +void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) { + if (minN <= N && N <= maxN && minM <= M && M <= maxM) { + // Kick off the template recursion by calling the Helper with curN = minN + // and curM = minM + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + minN, + minM, + maxM, + minM, + Args...>::run(N, M, args...); + } + // Maybe throw an error if we tried to dispatch outside the specified range? 
+} \ No newline at end of file diff --git a/unidepth/ops/knn/src/utils/index_utils.cuh b/unidepth/ops/knn/src/utils/index_utils.cuh new file mode 100644 index 0000000..d3f7f7a --- /dev/null +++ b/unidepth/ops/knn/src/utils/index_utils.cuh @@ -0,0 +1,224 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// This converts dynamic array lookups into static array lookups, for small +// arrays up to size 32. +// +// Suppose we have a small thread-local array: +// +// float vals[10]; +// +// Ideally we should only index this array using static indices: +// +// for (int i = 0; i < 10; ++i) vals[i] = i * i; +// +// If we do so, then the CUDA compiler may be able to place the array into +// registers, which can have a big performance improvement. However if we +// access the array dynamically, the the compiler may force the array into +// local memory, which has the same latency as global memory. +// +// These functions convert dynamic array access into static array access +// using a brute-force lookup table. It can be used like this: +// +// float vals[10]; +// int idx = 3; +// float val = 3.14f; +// RegisterIndexUtils::set(vals, idx, val); +// float val2 = RegisterIndexUtils::get(vals, idx); +// +// The implementation is based on fbcuda/RegisterUtils.cuh: +// https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh +// To avoid depending on the entire library, we just reimplement these two +// functions. The fbcuda implementation is a bit more sophisticated, and uses +// the preprocessor to generate switch statements that go up to N for each +// value of N. We are lazy and just have a giant explicit switch statement. +// +// We might be able to use a template metaprogramming approach similar to +// DispatchKernel1D for this. However DispatchKernel1D is intended to be used +// for dispatching to the correct CUDA kernel on the host, while this is +// is intended to run on the device. I was concerned that a metaprogramming +// approach for this might lead to extra function calls at runtime if the +// compiler fails to optimize them away, which could be very slow on device. +// However I didn't actually benchmark or test this. 
+template +struct RegisterIndexUtils { + __device__ __forceinline__ static T get(const T arr[N], int idx) { + if (idx < 0 || idx >= N) + return T(); + switch (idx) { + case 0: + return arr[0]; + case 1: + return arr[1]; + case 2: + return arr[2]; + case 3: + return arr[3]; + case 4: + return arr[4]; + case 5: + return arr[5]; + case 6: + return arr[6]; + case 7: + return arr[7]; + case 8: + return arr[8]; + case 9: + return arr[9]; + case 10: + return arr[10]; + case 11: + return arr[11]; + case 12: + return arr[12]; + case 13: + return arr[13]; + case 14: + return arr[14]; + case 15: + return arr[15]; + case 16: + return arr[16]; + case 17: + return arr[17]; + case 18: + return arr[18]; + case 19: + return arr[19]; + case 20: + return arr[20]; + case 21: + return arr[21]; + case 22: + return arr[22]; + case 23: + return arr[23]; + case 24: + return arr[24]; + case 25: + return arr[25]; + case 26: + return arr[26]; + case 27: + return arr[27]; + case 28: + return arr[28]; + case 29: + return arr[29]; + case 30: + return arr[30]; + case 31: + return arr[31]; + }; + return T(); + } + + __device__ __forceinline__ static void set(T arr[N], int idx, T val) { + if (idx < 0 || idx >= N) + return; + switch (idx) { + case 0: + arr[0] = val; + break; + case 1: + arr[1] = val; + break; + case 2: + arr[2] = val; + break; + case 3: + arr[3] = val; + break; + case 4: + arr[4] = val; + break; + case 5: + arr[5] = val; + break; + case 6: + arr[6] = val; + break; + case 7: + arr[7] = val; + break; + case 8: + arr[8] = val; + break; + case 9: + arr[9] = val; + break; + case 10: + arr[10] = val; + break; + case 11: + arr[11] = val; + break; + case 12: + arr[12] = val; + break; + case 13: + arr[13] = val; + break; + case 14: + arr[14] = val; + break; + case 15: + arr[15] = val; + break; + case 16: + arr[16] = val; + break; + case 17: + arr[17] = val; + break; + case 18: + arr[18] = val; + break; + case 19: + arr[19] = val; + break; + case 20: + arr[20] = val; + break; + case 21: + arr[21] = val; + break; + case 22: + arr[22] = val; + break; + case 23: + arr[23] = val; + break; + case 24: + arr[24] = val; + break; + case 25: + arr[25] = val; + break; + case 26: + arr[26] = val; + break; + case 27: + arr[27] = val; + break; + case 28: + arr[28] = val; + break; + case 29: + arr[29] = val; + break; + case 30: + arr[30] = val; + break; + case 31: + arr[31] = val; + break; + } + } +}; \ No newline at end of file diff --git a/unidepth/ops/knn/src/utils/mink.cuh b/unidepth/ops/knn/src/utils/mink.cuh new file mode 100644 index 0000000..7512aab --- /dev/null +++ b/unidepth/ops/knn/src/utils/mink.cuh @@ -0,0 +1,165 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#define MINK_H + +#include "index_utils.cuh" + +// A data structure to keep track of the smallest K keys seen so far as well +// as their associated values, intended to be used in device code. +// This data structure doesn't allocate any memory; keys and values are stored +// in arrays passed to the constructor. +// +// The implementation is generic; it can be used for any key type that supports +// the < operator, and can be used with any value type. +// +// Example usage: +// +// float keys[K]; +// int values[K]; +// MinK mink(keys, values, K); +// for (...) 
{ +// // Produce some key and value from somewhere +// mink.add(key, value); +// } +// mink.sort(); +// +// Now keys and values store the smallest K keys seen so far and the values +// associated to these keys: +// +// for (int k = 0; k < K; ++k) { +// float key_k = keys[k]; +// int value_k = values[k]; +// } +template +class MinK { + public: + // Constructor. + // + // Arguments: + // keys: Array in which to store keys + // values: Array in which to store values + // K: How many values to keep track of + __device__ MinK(key_t* keys, value_t* vals, int K) + : keys(keys), vals(vals), K(K), _size(0) {} + + // Try to add a new key and associated value to the data structure. If the key + // is one of the smallest K seen so far then it will be kept; otherwise it + // it will not be kept. + // + // This takes O(1) operations if the new key is not kept, or if the structure + // currently contains fewer than K elements. Otherwise this takes O(K) time. + // + // Arguments: + // key: The key to add + // val: The value associated to the key + __device__ __forceinline__ void add(const key_t& key, const value_t& val) { + if (_size < K) { + keys[_size] = key; + vals[_size] = val; + if (_size == 0 || key > max_key) { + max_key = key; + max_idx = _size; + } + _size++; + } else if (key < max_key) { + keys[max_idx] = key; + vals[max_idx] = val; + max_key = key; + for (int k = 0; k < K; ++k) { + key_t cur_key = keys[k]; + if (cur_key > max_key) { + max_key = cur_key; + max_idx = k; + } + } + } + } + + // Get the number of items currently stored in the structure. + // This takes O(1) time. + __device__ __forceinline__ int size() { + return _size; + } + + // Sort the items stored in the structure using bubble sort. + // This takes O(K^2) time. + __device__ __forceinline__ void sort() { + for (int i = 0; i < _size - 1; ++i) { + for (int j = 0; j < _size - i - 1; ++j) { + if (keys[j + 1] < keys[j]) { + key_t key = keys[j]; + value_t val = vals[j]; + keys[j] = keys[j + 1]; + vals[j] = vals[j + 1]; + keys[j + 1] = key; + vals[j + 1] = val; + } + } + } + } + + private: + key_t* keys; + value_t* vals; + int K; + int _size; + key_t max_key; + int max_idx; +}; + +// This is a version of MinK that only touches the arrays using static indexing +// via RegisterIndexUtils. If the keys and values are stored in thread-local +// arrays, then this may allow the compiler to place them in registers for +// fast access. +// +// This has the same API as RegisterMinK, but doesn't support sorting. +// We found that sorting via RegisterIndexUtils gave very poor performance, +// and suspect it may have prevented the compiler from placing the arrays +// into registers. 
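
As a plain-Python illustration of the add() bookkeeping shared by MinK above and RegisterMinK that follows: the structure keeps the K smallest (key, value) pairs seen so far and tracks the index of the current maximum, so a worse candidate is rejected in O(1) and a better one is swapped in with an O(K) rescan. The class below is a sketch for intuition only, not part of the repository:

class MinKSketch:
    """Keep the K smallest (key, value) pairs seen so far, tracking the index of
    the current maximum key so rejection is O(1) and acceptance is O(K)."""
    def __init__(self, K: int):
        self.K, self.keys, self.vals, self.max_idx = K, [], [], 0

    def add(self, key, val):
        if len(self.keys) < self.K:
            self.keys.append(key)
            self.vals.append(val)
            if len(self.keys) == 1 or key > self.keys[self.max_idx]:
                self.max_idx = len(self.keys) - 1
        elif key < self.keys[self.max_idx]:           # replace the current worst key
            self.keys[self.max_idx] = key
            self.vals[self.max_idx] = val
            self.max_idx = max(range(self.K), key=lambda i: self.keys[i])

mink = MinKSketch(K=3)
for idx, dist in enumerate([4.0, 1.0, 7.0, 0.5, 3.0]):
    mink.add(dist, idx)
# mink.keys == [3.0, 1.0, 0.5] (unsorted), mink.vals == [4, 1, 3]
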
+template +class RegisterMinK { + public: + __device__ RegisterMinK(key_t* keys, value_t* vals) + : keys(keys), vals(vals), _size(0) {} + + __device__ __forceinline__ void add(const key_t& key, const value_t& val) { + if (_size < K) { + RegisterIndexUtils::set(keys, _size, key); + RegisterIndexUtils::set(vals, _size, val); + if (_size == 0 || key > max_key) { + max_key = key; + max_idx = _size; + } + _size++; + } else if (key < max_key) { + RegisterIndexUtils::set(keys, max_idx, key); + RegisterIndexUtils::set(vals, max_idx, val); + max_key = key; + for (int k = 0; k < K; ++k) { + key_t cur_key = RegisterIndexUtils::get(keys, k); + if (cur_key > max_key) { + max_key = cur_key; + max_idx = k; + } + } + } + } + + __device__ __forceinline__ int size() { + return _size; + } + + private: + key_t* keys; + value_t* vals; + int _size; + key_t max_key; + int max_idx; +}; \ No newline at end of file diff --git a/unidepth/ops/knn/src/utils/pytorch3d_cutils.h b/unidepth/ops/knn/src/utils/pytorch3d_cutils.h new file mode 100644 index 0000000..c46b5ae --- /dev/null +++ b/unidepth/ops/knn/src/utils/pytorch3d_cutils.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.") +#define CHECK_CONTIGUOUS_CUDA(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) \ No newline at end of file diff --git a/unidepth/ops/losses.py b/unidepth/ops/losses.py deleted file mode 100644 index c082769..0000000 --- a/unidepth/ops/losses.py +++ /dev/null @@ -1,428 +0,0 @@ -""" -Author: Luigi Piccinelli -Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) -""" - -from typing import Any, Dict, List, Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - -FNS = { - "sqrt": torch.sqrt, - "log": torch.log, - "log1": lambda x: torch.log(x + 1), - "linear": lambda x: x, - "square": torch.square, - "disp": lambda x: 1 / x, -} - - -FNS_INV = { - "sqrt": torch.square, - "log": torch.exp, - "log1": lambda x: torch.exp(x) - 1, - "linear": lambda x: x, - "square": torch.sqrt, - "disp": lambda x: 1 / x, -} - - -def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): - if mask is None: - return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True) - mask = mask.float() - mask_sum = torch.sum(mask, dim=dim, keepdim=True) - mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( - mask_sum, min=1.0 - ) - mask_var = torch.sum( - mask * (data - mask_mean) ** 2, dim=dim, keepdim=True - ) / torch.clamp(mask_sum, min=1.0) - return mask_mean.squeeze(dim), mask_var.squeeze(dim) - - -def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]): - if mask is None: - return data.mean(dim=dim, keepdim=True) - mask = mask.float() - mask_sum = torch.sum(mask, dim=dim, keepdim=True) - mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( - mask_sum, min=1.0 - ) - return mask_mean - - -def masked_mae(data: torch.Tensor, mask: torch.Tensor, dim: Tuple[int, ...]): - if mask is None: - return data.abs().mean(dim=dim, keepdim=True) - mask = mask.float() - mask_sum = torch.sum(mask, dim=dim, keepdim=True) - mask_mean = 
torch.sum(data.abs() * mask, dim=dim, keepdim=True) / torch.clamp( - mask_sum, min=1.0 - ) - return mask_mean - - -def masked_mse(data: torch.Tensor, mask: torch.Tensor, dim: Tuple[int, ...]): - if mask is None: - return (data**2).mean(dim=dim, keepdim=True) - mask = mask.float() - mask_sum = torch.sum(mask, dim=dim, keepdim=True) - mask_mean = torch.sum((data**2) * mask, dim=dim, keepdim=True) / torch.clamp( - mask_sum, min=1.0 - ) - return mask_mean - - -def masked_median(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): - ndim = data.ndim - data = data.flatten(ndim - len(dim)) - mask = mask.flatten(ndim - len(dim)) - mask_median = torch.median(data[mask], dim=-1).values - return mask_median - - -def masked_median_mad(data: torch.Tensor, mask: torch.Tensor): - data = data.flatten() - mask = mask.flatten() - mask_median = torch.median(data[mask]) - n_samples = torch.clamp(torch.sum(mask.float()), min=1.0) - mask_mad = torch.sum((data[mask] - mask_median).abs()) / n_samples - return mask_median, mask_mad - - -def masked_weighted_mean_var( - data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...] -): - if mask is None: - return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True) - mask = mask.float() - mask_mean = torch.sum(data * mask * weights, dim=dim, keepdim=True) / torch.sum( - mask * weights, dim=dim, keepdim=True - ).clamp(min=1.0) - # V1**2 - V2, V1: sum w_i, V2: sum w_i**2 - denom = torch.sum(weights * mask, dim=dim, keepdim=True).square() - torch.sum( - (mask * weights).square(), dim=dim, keepdim=True - ) - # correction is V1 / (V1**2 - V2), if w_i=1 => N/(N**2 - N) => 1/(N-1) (unbiased estimator of variance, cvd) - correction_factor = torch.sum(mask * weights, dim=dim, keepdim=True) / denom.clamp( - min=1.0 - ) - mask_var = correction_factor * torch.sum( - weights * mask * (data - mask_mean) ** 2, dim=dim, keepdim=True - ) - return mask_mean, mask_var - - -def masked_mean_var_q(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): - if mask is None: - return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True) - mask = mask.float() - mask_sum = torch.sum(mask, dim=dim, keepdim=True) - mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( - mask_sum, min=1.0 - ) - mask_var = torch.sum( - mask * (data - mask_mean) ** 2, dim=dim, keepdim=True - ) / torch.clamp(mask_sum, min=1.0) - return mask_mean, mask_var - - -class SILog(nn.Module): - def __init__( - self, - weight: float, - scale_pred_weight: float = 0.15, - output_fn: str = "sqrt", - input_fn: str = "log", - legacy: bool = False, - abs_rel: bool = False, - norm: bool = False, - eps: float = 1e-5, - ): - super().__init__() - assert output_fn in FNS - self.name: str = self.__class__.__name__ - self.weight: float = weight - - self.scale_pred_weight: float = scale_pred_weight - self.dims = (-4, -3, -2, -1) if legacy else (-2, -1) - self.output_fn = FNS[output_fn] - self.input_fn = FNS[input_fn] - self.abs_rel = abs_rel - self.norm = norm - self.eps: float = eps - - @torch.cuda.amp.autocast(enabled=False) - def forward( - self, - input: torch.Tensor, - target: torch.Tensor, - mask: Optional[torch.Tensor] = None, - interpolate: bool = True, - scale_inv: torch.Tensor | None = None, - ss_inv: torch.Tensor | None = None, - **kwargs, - ) -> torch.Tensor: - if interpolate: - input = F.interpolate( - input, target.shape[-2:], mode="bilinear", align_corners=False - ) - if mask is not None: - mask = mask.to(torch.bool) - if ss_inv is not None: - ss_inv = 
~ss_inv - - if input.shape[1] > 1: - input_ = torch.cat( - [input[:, :-1], self.input_fn(input[:, -1:].clamp(min=self.eps))], dim=1 - ) - target_ = torch.cat( - [target[:, :-1], self.input_fn(target[:, -1:].clamp(min=self.eps))], - dim=1, - ) - error = torch.norm(input_ - target_, dim=1, keepdim=True) - else: - input_ = self.input_fn(input.clamp(min=self.eps)) - target_ = self.input_fn(target.clamp(min=self.eps)) - error = input_ - target_ - - mean_error, var_error = masked_mean_var(data=error, mask=mask, dim=self.dims) - - # prevoiusly was inverted!! - if self.abs_rel: - scale_error = (input - target).abs()[:, -1:] / target[:, -1:].clip( - min=self.eps - ) - scale_error = masked_mean(data=scale_error, mask=mask, dim=self.dims) - else: - scale_error = mean_error**2 - - if var_error.ndim > 1: - var_error = var_error.sum(dim=1) - scale_error = scale_error.sum(dim=1) - - # if scale inv -> mask scale error, if scale/shift, mask the full loss - if scale_inv is not None: - scale_error = (1 - scale_inv.int()) * scale_error - scale_error = self.scale_pred_weight * scale_error - loss = var_error + scale_error - out_loss = self.output_fn(loss.clamp(min=self.eps)) - out_loss = masked_mean(data=out_loss, mask=ss_inv, dim=(0,)) - return out_loss.mean() - - @classmethod - def build(cls, config: Dict[str, Any]): - obj = cls( - weight=config["weight"], - legacy=config["legacy"], - output_fn=config["output_fn"], - input_fn=config["input_fn"], - norm=config.get("norm", False), - scale_pred_weight=config.get("gamma", 0.15), - abs_rel=config.get("abs_rel", False), - ) - return obj - - -class MSE(nn.Module): - def __init__( - self, - weight: float = 1.0, - input_fn: str = "linear", - output_fn: str = "linear", - ): - super().__init__() - self.name: str = self.__class__.__name__ - self.output_fn = FNS[output_fn] - self.input_fn = FNS[input_fn] - self.weight: float = weight - self.eps = 1e-6 - - @torch.cuda.amp.autocast(enabled=False) - def forward( - self, - input: torch.Tensor, - target: torch.Tensor, - mask: torch.Tensor | None = None, - batch_mask: torch.Tensor | None = None, - **kwargs, - ) -> torch.Tensor: - input = input[..., : target.shape[-1]] # B N C or B H W C - error = self.input_fn(input + self.eps) - self.input_fn(target + self.eps) - abs_error = torch.square(error).sum(dim=-1) - mean_error = masked_mean(data=abs_error, mask=mask, dim=(-1,)).mean(dim=-1) - batched_error = masked_mean( - self.output_fn(mean_error.clamp(self.eps)), batch_mask, dim=(0,) - ) - return batched_error.mean(), mean_error.detach() - - @classmethod - def build(cls, config: Dict[str, Any]): - obj = cls( - weight=config["weight"], - output_fn=config["output_fn"], - input_fn=config["input_fn"], - ) - return obj - - -class SelfCons(nn.Module): - def __init__( - self, - weight: float, - scale_pred_weight: float = 0.15, - output_fn: str = "sqrt", - input_fn: str = "log", - abs_rel: bool = False, - norm: bool = False, - eps: float = 1e-5, - ): - super().__init__() - assert output_fn in FNS - self.name: str = self.__class__.__name__ - self.weight: float = weight - - self.scale_pred_weight: float = scale_pred_weight - self.dims = (-2, -1) - self.output_fn = FNS[output_fn] - self.input_fn = FNS[input_fn] - self.abs_rel = abs_rel - self.norm = norm - self.eps: float = eps - - @torch.cuda.amp.autocast(enabled=False) - def forward( - self, - input: torch.Tensor, - mask: torch.Tensor, - metas: List[Dict[str, torch.Tensor]], - ) -> torch.Tensor: - chunks = input.shape[0] // 2 - device = input.device - mask = F.interpolate(mask.float(), 
size=input.shape[-2:], mode="nearest") - - rescales = input.shape[-2] / torch.tensor( - [x["resized_shape"][0] for x in metas], device=device - ) - cams = torch.cat([x["K_target"] for x in metas], dim=0).to(device) - flips = torch.tensor([x["flip"] for x in metas], device=device) - - iters = zip( - input.chunk(chunks), - mask.chunk(chunks), - cams.chunk(chunks), - rescales.chunk(chunks), - flips.chunk(chunks), - ) - inputs0, inputs1, masks = [], [], [] - for i, (pair_input, pair_mask, pair_cam, pair_rescale, pair_flip) in enumerate( - iters - ): - mask0, mask1 = pair_mask - input0, input1 = pair_input - cam0, cam1 = pair_cam - rescale0, rescale1 = pair_rescale - flip0, flip1 = pair_flip - - fx_0 = cam0[0, 0] * rescale0 - fx_1 = cam1[0, 0] * rescale1 - cx_0 = (cam0[0, 2] - 0.5) * rescale0 + 0.5 - cx_1 = (cam1[0, 2] - 0.5) * rescale1 + 0.5 - cy_0 = (cam0[1, 2] - 0.5) * rescale0 + 0.5 - cy_1 = (cam1[1, 2] - 0.5) * rescale1 + 0.5 - - # flip image - if flip0 ^ flip1: - input0 = torch.flip(input0, dims=(2,)) - mask0 = torch.flip(mask0, dims=(2,)) - cx_0 = input0.shape[-1] - cx_0 - - # calc zoom - zoom_x = float(fx_1 / fx_0) - - # apply zoom - input0 = F.interpolate( - input0.unsqueeze(0), - scale_factor=zoom_x, - mode="bilinear", - align_corners=True, - ).squeeze(0) - mask0 = F.interpolate( - mask0.unsqueeze(0), scale_factor=zoom_x, mode="nearest" - ).squeeze(0) - - # calc translation - change_left = int(cx_1 - (cx_0 - 0.5) * zoom_x - 0.5) - change_top = int(cy_1 - (cy_0 - 0.5) * zoom_x - 0.5) - change_right = input1.shape[-1] - change_left - input0.shape[-1] - change_bottom = input1.shape[-2] - change_top - input0.shape[-2] - - # apply translation - pad_left = max(0, change_left) - pad_right = max(0, change_right) - pad_top = max(0, change_top) - pad_bottom = max(0, change_bottom) - - crop_left = max(0, -change_left) - crop_right = max(0, -change_right) - crop_top = max(0, -change_top) - crop_bottom = max(0, -change_bottom) - - input0 = F.pad( - input0, - (pad_left, pad_right, pad_top, pad_bottom), - mode="constant", - value=0, - ) - mask0 = F.pad( - mask0, - (pad_left, pad_right, pad_top, pad_bottom), - mode="constant", - value=0, - ) - input0 = input0[ - :, - crop_top : input0.shape[-2] - crop_bottom, - crop_left : input0.shape[-1] - crop_right, - ] - mask0 = mask0[ - :, - crop_top : mask0.shape[-2] - crop_bottom, - crop_left : mask0.shape[-1] - crop_right, - ] - - mask = torch.logical_and(mask0, mask1) - - inputs0.append(input0) - inputs1.append(input1) - masks.append(mask) - - inputs0 = torch.stack(inputs0, dim=0) - inputs1 = torch.stack(inputs1, dim=0) - masks = torch.stack(masks, dim=0) - loss1 = self.loss(inputs0, inputs1.detach(), masks) - loss2 = self.loss(inputs1, inputs0.detach(), masks) - return torch.cat([loss1, loss2], dim=0).mean() - - def loss( - self, - input: torch.Tensor, - target: torch.Tensor, - mask: torch.Tensor, - ) -> torch.Tensor: - loss = masked_mean( - (input - target).square().mean(dim=1), mask=mask, dim=(-2, -1) - ) - return self.output_fn(loss + self.eps) - - @classmethod - def build(cls, config: Dict[str, Any]): - obj = cls( - weight=config["weight"], - output_fn=config["output_fn"], - input_fn=config["input_fn"], - ) - return obj diff --git a/unidepth/ops/losses/__init__.py b/unidepth/ops/losses/__init__.py new file mode 100644 index 0000000..ea7cc74 --- /dev/null +++ b/unidepth/ops/losses/__init__.py @@ -0,0 +1,7 @@ +from .arel import ARel +from .confidence import Confidence +from .distill import SelfDistill, TeacherDistill +from .dummy import Dummy +from 
.local_ssi import LocalSSI +from .regression import Regression +from .silog import SILog diff --git a/unidepth/ops/losses/arel.py b/unidepth/ops/losses/arel.py new file mode 100644 index 0000000..9d9c0d9 --- /dev/null +++ b/unidepth/ops/losses/arel.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn + +from .utils import FNS, masked_mean + + +class ARel(nn.Module): + def __init__( + self, + weight: float, + output_fn: str = "sqrt", + input_fn: str = "linear", + eps: float = 1e-5, + ): + super().__init__() + self.name: str = self.__class__.__name__ + self.weight: float = weight + self.dims = [-2, -1] + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.eps: float = eps + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, **kwargs + ) -> torch.Tensor: + mask = mask.bool().clone() + + input = self.input_fn(input.float()) + target = self.input_fn(target.float()) + + error = (input - target).norm(dim=1) / target.norm(dim=1).clip(min=0.05) + mask = mask.squeeze(1) + + error_image = masked_mean(data=error, mask=mask, dim=self.dims).squeeze(1, 2) + error_image = self.output_fn(error_image) + return error_image + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + output_fn=config["output_fn"], + input_fn=config["input_fn"], + ) + return obj diff --git a/unidepth/ops/losses/confidence.py b/unidepth/ops/losses/confidence.py new file mode 100644 index 0000000..c7b4b74 --- /dev/null +++ b/unidepth/ops/losses/confidence.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn + +from .utils import FNS, masked_mean + + +class Confidence(nn.Module): + def __init__( + self, + weight: float, + output_fn: str = "sqrt", + input_fn: str = "linear", + rescale: bool = True, + eps: float = 1e-5, + ): + super(Confidence, self).__init__() + self.name: str = self.__class__.__name__ + self.weight = weight + self.rescale = rescale + self.eps = eps + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, + input: torch.Tensor, + target_pred: torch.Tensor, + target_gt: torch.Tensor, + mask: torch.Tensor, + ): + B, C = target_gt.shape[:2] + mask = mask.bool() + target_gt = target_gt.float().reshape(B, C, -1) + target_pred = target_pred.float().reshape(B, C, -1) + input = input.float().reshape(B, -1) + mask = mask.reshape(B, -1) + + if self.rescale: + target_pred = torch.stack( + [ + p * torch.median(gt[:, m]) / torch.median(p[:, m]) + for p, gt, m in zip(target_pred, target_gt, mask) + ] + ) + + error = torch.abs( + (self.input_fn(target_pred) - self.input_fn(target_gt)).norm(dim=1) - input + ) + losses = masked_mean(error, dim=[-1], mask=mask).squeeze(dim=-1) + losses = self.output_fn(losses) + + return losses + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + output_fn=config["output_fn"], + input_fn=config["input_fn"], + rescale=config["rescale"], + ) + return obj diff --git a/unidepth/ops/losses/distill.py b/unidepth/ops/losses/distill.py new file mode 100644 index 0000000..6871a34 --- /dev/null +++ b/unidepth/ops/losses/distill.py @@ -0,0 +1,221 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from .utils import FNS, masked_mean + + +class SelfDistill(nn.Module): + def __init__(self, weight: float, output_fn: str = "sqrt", eps: float = 1e-5): + 
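
In symbols, the two losses defined above: ARel takes the masked mean of the channel-wise relative error, while Confidence trains the predicted confidence map $c$ to regress the per-pixel error magnitude of the depth prediction (optionally after a per-image median rescaling of the prediction onto the ground truth),

$$\mathcal{L}_{\mathrm{ARel}} = \operatorname{mean}_{\text{valid}} \frac{\lVert \hat{x}-x\rVert_2}{\max(\lVert x\rVert_2,\,0.05)}, \qquad \mathcal{L}_{\mathrm{Conf}} = \operatorname{mean}_{\text{valid}} \bigl|\, \lVert f(\hat{d})-f(d)\rVert_2 - c \,\bigr|,$$

where $f$ is the configured input_fn, the means run only over the valid mask, and both losses are finally passed through output_fn.
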
super().__init__() + self.name: str = self.__class__.__name__ + self.weight: float = weight + self.dims = (-2, -1) + self.output_fn = FNS[output_fn] + self.eps: float = eps + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, + input: torch.Tensor, + intrinsics: torch.Tensor, + mask: torch.Tensor, + flips: torch.Tensor, + downsample_ratio=14, + ) -> torch.Tensor: + chunks = input.shape[0] // 2 + mask = F.interpolate(mask.float(), size=input.shape[-2:], mode="nearest") + + iters = zip( + input.chunk(chunks), + mask.chunk(chunks), + intrinsics.chunk(chunks), + flips.chunk(chunks), + ) + inputs0, inputs1, masks = [], [], [] + for i, (pair_input, pair_mask, pair_cam, pair_flip) in enumerate(iters): + + mask0, mask1 = pair_mask + input0, input1 = pair_input + cam0, cam1 = pair_cam + flip0, flip1 = pair_flip + + fx_0 = cam0[0, 0] / downsample_ratio + fx_1 = cam1[0, 0] / downsample_ratio + cx_0 = cam0[0, 2] / downsample_ratio + cx_1 = cam1[0, 2] / downsample_ratio + cy_0 = cam0[1, 2] / downsample_ratio + cy_1 = cam1[1, 2] / downsample_ratio + + # flip image + if flip0 ^ flip1: + input0 = torch.flip(input0, dims=(2,)) + mask0 = torch.flip(mask0, dims=(2,)) + cx_0 = input0.shape[-1] - cx_0 + + # calc zoom + zoom_x = float(fx_1 / fx_0) + + # apply zoom + input0 = F.interpolate( + input0.unsqueeze(0), scale_factor=zoom_x, mode="bilinear" + ).squeeze(0) + mask0 = F.interpolate( + mask0.unsqueeze(0), scale_factor=zoom_x, mode="nearest" + ).squeeze(0) + + # calc translation + change_left = int(cx_1 - (cx_0 - 0.5) * zoom_x - 0.5) + change_top = int(cy_1 - (cy_0 - 0.5) * zoom_x - 0.5) + change_right = input1.shape[-1] - change_left - input0.shape[-1] + change_bottom = input1.shape[-2] - change_top - input0.shape[-2] + + # apply translation + pad_left = max(0, change_left) + pad_right = max(0, change_right) + pad_top = max(0, change_top) + pad_bottom = max(0, change_bottom) + + crop_left = max(0, -change_left) + crop_right = max(0, -change_right) + crop_top = max(0, -change_top) + crop_bottom = max(0, -change_bottom) + + input0 = F.pad( + input0, + (pad_left, pad_right, pad_top, pad_bottom), + mode="constant", + value=0, + ) + mask0 = F.pad( + mask0, + (pad_left, pad_right, pad_top, pad_bottom), + mode="constant", + value=0, + ) + input0 = input0[ + :, + crop_top : input0.shape[-2] - crop_bottom, + crop_left : input0.shape[-1] - crop_right, + ] + mask0 = mask0[ + :, + crop_top : mask0.shape[-2] - crop_bottom, + crop_left : mask0.shape[-1] - crop_right, + ] + + mask = torch.logical_and(mask0, mask1) + + inputs0.append(input0) + inputs1.append(input1) + masks.append(mask) + + inputs0 = torch.stack(inputs0, dim=0) + inputs1 = torch.stack(inputs1, dim=0) + masks = torch.stack(masks, dim=0) + loss1 = self.loss(inputs0, inputs1.detach(), masks) + loss2 = self.loss(inputs1, inputs0.detach(), masks) + return torch.cat([loss1, loss2], dim=0) + + def loss( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + ) -> torch.Tensor: + loss = masked_mean( + (input - target).square().mean(dim=1), mask=mask, dim=[-2, -1] + ) + return self.output_fn(loss + self.eps) + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + output_fn=config["output_fn"], + ) + return obj + + +class TeacherDistill(nn.Module): + def __init__( + self, + weight: float, + output_fn: str = "sqrt", + cross: bool = False, + eps: float = 1e-5, + ): + super().__init__() + assert output_fn in FNS + self.name: str = self.__class__.__name__ + self.weight: 
float = weight + self.dims = (-2, -1) + self.output_fn = FNS[output_fn] + self.eps: float = eps + self.cross = cross + self.threshold = 0.05 + self.head_dim = 64 # hardcoded for vit + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, + student_features: torch.Tensor, + teacher_features: torch.Tensor, + student_tokens: torch.Tensor, + teacher_tokens: torch.Tensor, + mask: torch.Tensor, + # metas: List[Dict[str, torch.Tensor]], + ) -> torch.Tensor: + B = student_features.shape[0] + device = student_features.device + chunks = student_features.shape[0] // 2 + + mask = ( + F.interpolate( + mask.float() + 1e-3, size=student_features.shape[-2:], mode="nearest" + ) + > 0.5 + ) + + # chunk features as self.head_dim + student_features = rearrange( + student_features, "b (n c) h w -> b c h w n", c=self.head_dim + ) + teacher_features = rearrange( + teacher_features, "b (n c) h w -> b c h w n", c=self.head_dim + ) + student_tokens = rearrange( + student_tokens, "b t (n c) -> b t c n", c=self.head_dim + ) + teacher_tokens = rearrange( + teacher_tokens, "b t (n c) -> b t c n", c=self.head_dim + ) + + distance = ( + (student_features - teacher_features) + .square() + .sum(dim=1, keepdim=True) + .sqrt() + .mean(dim=-1) + ) + loss_features = masked_mean(distance, mask=mask, dim=[-2, -1]) + loss_features = self.output_fn(loss_features.clamp(min=self.eps)).squeeze( + 1, 2, 3 + ) + + distance = ( + (student_tokens - teacher_tokens).square().sum(dim=-2).sqrt().mean(dim=-1) + ) + loss_tokens = self.output_fn(distance.clamp(min=self.eps)).squeeze(1) + + return loss_features + 0.01 * loss_tokens + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + output_fn=config["output_fn"], + cross=config["cross"], + ) + return obj diff --git a/unidepth/ops/losses/dummy.py b/unidepth/ops/losses/dummy.py new file mode 100644 index 0000000..77a4c99 --- /dev/null +++ b/unidepth/ops/losses/dummy.py @@ -0,0 +1,17 @@ +import torch +import torch.nn as nn + + +class Dummy(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.name: str = self.__class__.__name__ + self.weight = 1.0 + + def forward(self, dummy: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return torch.tensor([0.0] * dummy.shape[0], device=dummy.device) + + @classmethod + def build(cls, config): + obj = cls() + return obj diff --git a/unidepth/ops/losses/local_ssi.py b/unidepth/ops/losses/local_ssi.py new file mode 100644 index 0000000..e6c40fe --- /dev/null +++ b/unidepth/ops/losses/local_ssi.py @@ -0,0 +1,137 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import FNS, masked_mean, ssi + + +class LocalSSI(nn.Module): + def __init__( + self, + weight: float, + output_fn: str = "sqrt", + patch_size: tuple[int, int] = (32, 32), + min_samples: int = 4, + num_levels: int = 4, + input_fn: str = "linear", + eps: float = 1e-5, + ): + super(LocalSSI, self).__init__() + self.name: str = self.__class__.__name__ + self.weight = weight + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.min_samples = min_samples + self.eps = eps + patch_logrange = np.linspace( + start=np.log2(min(patch_size)), + stop=np.log2(max(patch_size)), + endpoint=True, + num=num_levels + 1, + ) + self.patch_logrange = [ + (x, y) for x, y in zip(patch_logrange[:-1], patch_logrange[1:]) + ] + self.rescale_fn = ssi + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, + input: 
torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + mask = mask.bool() + input = self.input_fn(input.float()) + target = self.input_fn(target.float()) + B, C, H, W = input.shape + total_errors = [] + + for ii, patch_logrange in enumerate(self.patch_logrange): + + log_kernel = ( + np.random.uniform(*patch_logrange) + if self.training + else np.mean(patch_logrange) + ) + kernel_size = int( + (2**log_kernel) * min(input.shape[-2:]) + ) # always smaller than min_shape + kernel_size = (kernel_size, kernel_size) + stride = (int(kernel_size[0] * 0.9), int(kernel_size[1] * 0.9)) + + # unfold is always exceeding right/bottom, roll image only negative + # to have them back in the unfolding window + max_roll = ( + (W - kernel_size[1]) % stride[1], + (H - kernel_size[0]) % stride[0], + ) + roll_x, roll_y = np.random.randint(-max_roll[0], 1), np.random.randint( + -max_roll[1], 1 + ) + input_fold = torch.roll(input, shifts=(roll_y, roll_x), dims=(2, 3)) + target_fold = torch.roll(target, shifts=(roll_y, roll_x), dims=(2, 3)) + mask_fold = torch.roll(mask.float(), shifts=(roll_y, roll_x), dims=(2, 3)) + + # unfold in patches + input_fold = F.unfold( + input_fold, kernel_size=kernel_size, stride=stride + ).permute( + 0, 2, 1 + ) # B N C*H_p*W_p + target_fold = F.unfold( + target_fold, kernel_size=kernel_size, stride=stride + ).permute(0, 2, 1) + mask_fold = ( + F.unfold(mask_fold, kernel_size=kernel_size, stride=stride) + .bool() + .permute(0, 2, 1) + ) + + # calculate error patchwise, then mean over patch, then over image based if sample size is significant + input_fold, target_fold, _ = self.rescale_fn( + input_fold, target_fold, mask_fold, dim=[-1] + ) + error = (input_fold - target_fold).abs() + + # calculate elements more then 95 percentile and lower than 5percentile of error + valid_patches = mask_fold.sum(dim=-1) >= self.min_samples + error_mean_patch = masked_mean(error, mask_fold, dim=[-1]).squeeze(-1) + error_mean_image = self.output_fn(error_mean_patch.clamp(min=self.eps)) + error_mean_image = masked_mean( + error_mean_image, mask=valid_patches, dim=[-1] + ) + total_errors.append(error_mean_image.squeeze(-1)) + + # global + input_rescale = input.reshape(B, C, -1) + target_rescale = target.reshape(B, C, -1) + mask = mask.reshape(B, 1, -1).clone() + input, target, mask = self.rescale_fn( + input_rescale, target_rescale, mask, dim=[-1] + ) + error = (input - target).abs().squeeze(1) + + mask = mask.squeeze(1) + error_mean_image = masked_mean(error, mask, dim=[-1]).squeeze(-1) + error_mean_image = self.output_fn(error_mean_image.clamp(min=self.eps)) + + total_errors.append(error_mean_image) + + errors = torch.stack(total_errors).mean(dim=0) + return errors + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + patch_size=config["patch_size"], + output_fn=config["output_fn"], + min_samples=config["min_samples"], + num_levels=config["num_levels"], + input_fn=config["input_fn"], + ) + return obj diff --git a/unidepth/ops/losses/regression.py b/unidepth/ops/losses/regression.py new file mode 100644 index 0000000..d53a45a --- /dev/null +++ b/unidepth/ops/losses/regression.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn + +from .utils import FNS, REGRESSION_DICT, masked_mean, masked_quantile + + +class Regression(nn.Module): + def __init__( + self, + weight: float, + input_fn: str, + output_fn: str, + alpha: float, + gamma: float, + fn: str, + dims: list[int] = [-1], + quantile: float = 0.0, + **kwargs, + ): + 
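
In symbols, the patchwise normalization used by LocalSSI above: each patch (and, in the final term, the whole image) is normalized to zero mean and unit variance over its valid pixels before the L1 comparison, so per-patch scale and shift cancel out,

$$\tilde{x} = \frac{x - \mu_x}{\sigma_x}, \qquad \tilde{y} = \frac{y - \mu_y}{\sigma_y}, \qquad \mathcal{L}_{\text{patch}} = \operatorname{mean}_{\text{valid}} \lvert \tilde{x} - \tilde{y} \rvert,$$

with $\mu$ and $\sigma$ the masked statistics computed by the ssi helper (which first discards points outside a 1.96σ, roughly 95%, interval), and patches with fewer than min_samples valid pixels skipped.
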
super().__init__() + self.name = self.__class__.__name__ + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.weight = weight + self.dims = dims + self.quantile = quantile + self.alpha = alpha + self.gamma = gamma + self.fn = REGRESSION_DICT[fn] + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + if mask is not None: # usually it is just repeated + mask = mask[:, 0] + + input = self.input_fn(input.float()) + target = self.input_fn(target.float()) + error = self.fn(input - target, gamma=self.gamma, alpha=self.alpha).mean(dim=1) + mean_error = masked_mean(data=error, mask=mask, dim=self.dims).squeeze( + self.dims + ) + mean_error = self.output_fn(mean_error) + return mean_error + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + output_fn=config["output_fn"], + input_fn=config["input_fn"], + dims=config.get("dims", (-1,)), + alpha=config["alpha"], + gamma=config["gamma"], + fn=config["fn"], + ) + return obj diff --git a/unidepth/ops/losses/silog.py b/unidepth/ops/losses/silog.py new file mode 100644 index 0000000..9a67cc3 --- /dev/null +++ b/unidepth/ops/losses/silog.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn + +from .utils import (FNS, REGRESSION_DICT, masked_mean, masked_mean_var, + masked_quantile) + + +class SILog(nn.Module): + def __init__( + self, + weight: float, + input_fn: str = "linear", + output_fn: str = "sqrt", + integrated: float = 0.15, + dims: list[int] = [-3, -2, -1], + eps: float = 1e-5, + ): + super().__init__() + self.name: str = self.__class__.__name__ + self.weight: float = weight + + self.dims = dims + self.input_fn = FNS[input_fn] + self.output_fn = FNS[output_fn] + self.eps: float = eps + self.integrated = integrated + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + si: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + mask = mask.bool() + error = self.input_fn(input.float()) - self.input_fn(target.float()) + mean_error, var_error = masked_mean_var( + data=error, mask=mask, dim=self.dims, keepdim=False + ) + if var_error.ndim > 1: + var_error = var_error.mean(dim=-1) + + if self.integrated > 0.0: + scale_error = mean_error**2 + var_error = var_error + self.integrated * scale_error * (1 - si.int()) + + out_loss = self.output_fn(var_error) + return out_loss + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + dims=config["dims"], + output_fn=config["output_fn"], + input_fn=config["input_fn"], + integrated=config.get("integrated", 0.15), + ) + return obj diff --git a/unidepth/ops/losses/utils.py b/unidepth/ops/losses/utils.py new file mode 100644 index 0000000..32a10cd --- /dev/null +++ b/unidepth/ops/losses/utils.py @@ -0,0 +1,238 @@ +from math import prod +from typing import Any, Dict, List, Optional, Tuple + +import torch + +FNS = { + "sqrt": lambda x: torch.sqrt(x + 1e-4), + "log": lambda x: torch.log(x + 1e-4), + "log1": lambda x: torch.log(x + 1), + "linear": lambda x: x, + "square": torch.square, + "disp": lambda x: 1 / (x + 1e-4), + "disp1": lambda x: 1 / (1 + x), +} + + +FNS_INV = { + "sqrt": torch.square, + "log": torch.exp, + "log1": lambda x: torch.exp(x) - 1, + "linear": lambda x: x, + "square": torch.sqrt, + "disp": lambda x: 1 / x, +} + + +def masked_mean_var( + data: 
torch.Tensor, mask: torch.Tensor, dim: List[int], keepdim: bool = True +): + if mask is None: + return data.mean(dim=dim, keepdim=keepdim), data.var(dim=dim, keepdim=keepdim) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + # data = torch.nan_to_num(data, nan=0.0) + mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( + mask_sum, min=1.0 + ) + mask_var = torch.sum( + mask * (data - mask_mean) ** 2, dim=dim, keepdim=True + ) / torch.clamp(mask_sum, min=1.0) + if not keepdim: + mask_mean, mask_var = mask_mean.squeeze(dim), mask_var.squeeze(dim) + return mask_mean, mask_var + + +def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]): + if mask is None: + return data.mean(dim=dim, keepdim=True) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum( + torch.nan_to_num(data, nan=0.0) * mask, dim=dim, keepdim=True + ) / mask_sum.clamp(min=1.0) + return mask_mean + + +def masked_quantile( + data: torch.Tensor, mask: torch.Tensor | None, dims: List[int], q: float +): + """ + Compute the quantile of the data only where the mask is 1 along specified dimensions. + + Args: + data (torch.Tensor): The input data tensor. + mask (torch.Tensor): The mask tensor with the same shape as data, containing 1s where data should be considered. + dims (list of int): The dimensions to compute the quantile over. + q (float): The quantile to compute, must be between 0 and 1. + + Returns: + torch.Tensor: The quantile computed over the specified dimensions, ignoring masked values. + """ + masked_data = data * mask if mask is not None else data + + # Get a list of all dimensions + all_dims = list(range(masked_data.dim())) + + # Revert negative dimensions + dims = [d % masked_data.dim() for d in dims] + + # Find the dimensions to keep (not included in the `dims` list) + keep_dims = [d for d in all_dims if d not in dims] + + # Permute dimensions to bring `dims` to the front + permute_order = dims + keep_dims + permuted_data = masked_data.permute(permute_order) + + # Reshape into 2D: (-1, remaining_dims) + collapsed_shape = ( + -1, + prod([permuted_data.size(d) for d in range(len(dims), permuted_data.dim())]), + ) + reshaped_data = permuted_data.reshape(collapsed_shape) + if mask is None: + return torch.quantile(reshaped_data, q, dim=0) + + permuted_mask = mask.permute(permute_order) + reshaped_mask = permuted_mask.reshape(collapsed_shape) + + # Calculate quantile along the first dimension where mask is true + quantiles = [] + for i in range(reshaped_data.shape[1]): + valid_data = reshaped_data[:, i][reshaped_mask[:, i]] + if valid_data.numel() == 0: + # print("Warning: No valid data found for quantile calculation.") + quantiles.append(reshaped_data[:, i].min() * 0.99) + else: + quantiles.append(torch.quantile(valid_data, q, dim=0)) + + # Stack back into a tensor with reduced dimensions + quantiles = torch.stack(quantiles) + quantiles = quantiles.reshape( + [permuted_data.size(d) for d in range(len(dims), permuted_data.dim())] + ) + + return quantiles + + +def masked_median(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): + ndim = data.ndim + data = data.flatten(ndim - len(dim)) + mask = mask.flatten(ndim - len(dim)) + mask_median = torch.median(data[..., mask], dim=-1).values + return mask_median + + +def masked_median_mad(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): + ndim = data.ndim + data = data.flatten(ndim - len(dim)) + mask = mask.flatten(ndim - len(dim)) + mask_median = 
torch.median(data[mask], dim=-1, keepdim=True).values + mask_mad = masked_mean((data - mask_median).abs(), mask, dim=[-1]) + return mask_median, mask_mad + + +def masked_weighted_mean_var( + data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...] +): + if mask is None: + return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True) + mask = mask.float() + mask_mean = torch.sum(data * mask * weights, dim=dim, keepdim=True) / torch.sum( + mask * weights, dim=dim, keepdim=True + ).clamp(min=1.0) + # V1**2 - V2, V1: sum w_i, V2: sum w_i**2 + denom = torch.sum(weights * mask, dim=dim, keepdim=True).square() - torch.sum( + (mask * weights).square(), dim=dim, keepdim=True + ) + # correction is V1 / (V1**2 - V2), if w_i=1 => N/(N**2 - N) => 1/(N-1) (unbiased estimator of variance, cvd) + correction_factor = torch.sum(mask * weights, dim=dim, keepdim=True) / denom.clamp( + min=1.0 + ) + mask_var = correction_factor * torch.sum( + weights * mask * (data - mask_mean) ** 2, dim=dim, keepdim=True + ) + return mask_mean, mask_var + + +def ssi( + input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, dim: list[int] +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # recalculate mask with points in 95% confidence interval + input_detach = input.detach() + input_mean, input_var = masked_mean_var(input_detach, mask=mask, dim=dim) + target_mean, target_var = masked_mean_var(target, mask=mask, dim=dim) + input_std = (input_var).clip(min=1e-6).sqrt() + target_std = (target_var).clip(min=1e-6).sqrt() + stable_points_input = torch.logical_and( + input_detach > input_mean - 1.96 * input_std, + input_detach < input_mean + 1.96 * input_std, + ) + stable_points_target = torch.logical_and( + target > target_mean - 1.96 * target_std, + target < target_mean + 1.96 * target_std, + ) + stable_mask = stable_points_target & stable_points_input & mask + + input_mean, input_var = masked_mean_var(input, mask=stable_mask, dim=dim) + target_mean, target_var = masked_mean_var(target, mask=stable_mask, dim=dim) + target_normalized = (target - target_mean) / FNS["sqrt"](target_var) + input_normalized = (input - input_mean) / FNS["sqrt"](input_var) + return input_normalized, target_normalized, stable_mask + + +def ind2sub(idx, cols): + r = idx // cols + c = idx % cols + return r, c + + +def sub2ind(r, c, cols): + idx = r * cols + c + return idx + + +def l2(input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs) -> torch.Tensor: + return gamma * (input_tensor / gamma) ** 2 + + +def l1(input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs) -> torch.Tensor: + return torch.abs(input_tensor) + + +def charbonnier( + input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs +) -> torch.Tensor: + return torch.sqrt(torch.square(input_tensor) + gamma**2) - gamma + + +def cauchy( + input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs +) -> torch.Tensor: + return gamma * torch.log(torch.square(input_tensor) / gamma + 1) + + +def geman_mcclure( + input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs +) -> torch.Tensor: + return gamma * torch.square(input_tensor) / (torch.square(input_tensor) + gamma) + + +def robust_loss( + input_tensor: torch.Tensor, alpha: float, gamma: float = 1.0, *args, **kwargs +) -> torch.Tensor: + coeff = abs(alpha - 2) / alpha + power = torch.square(input_tensor) / abs(alpha - 2) / (gamma**2) + 1 + return ( + gamma * coeff * (torch.pow(power, alpha / 2) - 1) + ) # mult gamma to keep grad magnitude invariant wrt gamma + + 
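# --- Illustrative sketch (editor's addition, not part of the patch) ---
# The penalty functions above are looked up by name through REGRESSION_DICT
# (defined just below) and applied element-wise to a residual, mirroring how the
# Regression loss calls `self.fn(input - target, gamma=..., alpha=...)`.
# Tensor shapes and values here are made up.
#
#     residual = torch.randn(4, 3, 240, 320)        # e.g. predicted - target
#     penalty = REGRESSION_DICT["charbonnier"]      # or "l2", "l1", "cauchy", ...
#     per_elem = penalty(residual, gamma=1.0, alpha=1.0)
#     per_pixel = per_elem.mean(dim=1)              # as in Regression.forward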
+REGRESSION_DICT = { + "l2": l2, + "l1": l1, + "cauchy": cauchy, + "charbonnier": charbonnier, + "geman_mcclure": geman_mcclure, + "robust_loss": robust_loss, +} diff --git a/unidepth/ops/scheduler.py b/unidepth/ops/scheduler.py index a182ff6..c2d7109 100644 --- a/unidepth/ops/scheduler.py +++ b/unidepth/ops/scheduler.py @@ -1,11 +1,63 @@ -""" -Author: Luigi Piccinelli -Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) -""" +import weakref import numpy as np +class PlainCosineScheduler(object): + def __init__( + self, + klass, + key, + warmup_iters, + total_iters, + overwrite=False, + init_value=None, + base_value=None, + final_value=None, + step_init=-1, + ): + super().__init__() + self.iter = step_init + self.overwrite = overwrite + self.base_value = base_value + self.init_value = init_value if init_value is not None else base_value + self.final_value = final_value + self.total_iters = total_iters + self.warmup_iters = warmup_iters + self.key = key + self.klass = klass + self.schedulers = [self.get_scheduler()] + + def get_scheduler(self): + init_value = self.init_value + base_value = self.base_value + final_value = self.final_value + warmup_iters = self.warmup_iters + total_iters = self.total_iters + + # normalize in 0,1, then apply function (power) and denormalize + normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True) + normalized_schedule = np.power(normalized_schedule, 1) + warmup_schedule = (base_value - init_value) * normalized_schedule + init_value + + # main scheduling + iters = np.arange(total_iters - warmup_iters + 1) + schedule = final_value + 0.5 * (base_value - final_value) * ( + 1 + np.cos(np.pi * iters / (len(iters) - 1)) + ) + return np.concatenate((warmup_schedule, schedule)) + + def step(self): + self.iter = self.iter + 1 + vals = self[self.iter] + for i, val in enumerate(vals): + setattr(self.klass, self.key, val) + + def __getitem__(self, it): + it = min(it, self.total_iters) + return [scheduler[it] for scheduler in self.schedulers] + + class CosineScheduler(object): def __init__( self, @@ -44,13 +96,13 @@ def get_schedulers(self, group): # normalize in 0,1, then apply function (power) and denormalize normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True) - normalized_schedule = np.power(normalized_schedule, 2) + normalized_schedule = np.power(normalized_schedule, 1) warmup_schedule = (base_value - init_value) * normalized_schedule + init_value # main scheduling - iters = np.arange(total_iters - warmup_iters) + iters = np.arange(total_iters - warmup_iters + 1) schedule = final_value + 0.5 * (base_value - final_value) * ( - 1 + np.cos(np.pi * iters / len(iters)) + 1 + np.cos(np.pi * iters / (len(iters) - 1)) ) return np.concatenate((warmup_schedule, schedule)) @@ -63,7 +115,7 @@ def step(self): group[self.key] = val def __getitem__(self, it): - it = min(it, self.total_iters - 1) + it = min(it, self.total_iters) return [scheduler[it] for scheduler in self.schedulers] def get(self): diff --git a/unidepth/utils/__init__.py b/unidepth/utils/__init__.py index 08d9035..0fa40d4 100644 --- a/unidepth/utils/__init__.py +++ b/unidepth/utils/__init__.py @@ -1,29 +1,12 @@ +from .camera import invert_pinhole +# from .validation import validate +from .coordinate import coords_grid, normalize_coords from .distributed import (barrier, get_dist_info, get_rank, is_main_process, setup_multi_processes, setup_slurm, sync_tensor_across_gpus) -from .evaluation_depth import DICT_METRICS, eval_depth +from .evaluation_depth 
import (DICT_METRICS, DICT_METRICS_3D, eval_3d, + eval_depth) from .geometric import spherical_zbuffer_to_euclidean, unproject_points -from .misc import format_seconds, get_params, identity, remove_padding +from .misc import (format_seconds, get_params, identity, recursive_index, + remove_padding, to_cpu) from .visualization import colorize, image_grid, log_train_artifacts - -__all__ = [ - "eval_depth", - "DICT_METRICS", - "colorize", - "image_grid", - "log_train_artifacts", - "format_seconds", - "remove_padding", - "get_params", - "identity", - "is_main_process", - "setup_multi_processes", - "setup_slurm", - "sync_tensor_across_gpus", - "barrier", - "get_rank", - "unproject_points", - "spherical_zbuffer_to_euclidean", - "validate", - "get_dist_info", -] diff --git a/unidepth/utils/camera.py b/unidepth/utils/camera.py new file mode 100644 index 0000000..5620078 --- /dev/null +++ b/unidepth/utils/camera.py @@ -0,0 +1,1003 @@ +from copy import deepcopy + +import numpy as np +import torch +import torch.nn.functional as F + +from .coordinate import coords_grid +from .misc import squeeze_list + + +def invert_pinhole(K): + fx = K[..., 0, 0] + fy = K[..., 1, 1] + cx = K[..., 0, 2] + cy = K[..., 1, 2] + K_inv = torch.zeros_like(K) + K_inv[..., 0, 0] = 1.0 / fx + K_inv[..., 1, 1] = 1.0 / fy + K_inv[..., 0, 2] = -cx / fx + K_inv[..., 1, 2] = -cy / fy + K_inv[..., 2, 2] = 1.0 + return K_inv + + +class Camera: + def __init__(self, params=None, K=None): + if params.ndim == 1: + params = params.unsqueeze(0) + + if K is None: + K = ( + torch.eye(3, device=params.device, dtype=params.dtype) + .unsqueeze(0) + .repeat(params.shape[0], 1, 1) + ) + K[..., 0, 0] = params[..., 0] + K[..., 1, 1] = params[..., 1] + K[..., 0, 2] = params[..., 2] + K[..., 1, 2] = params[..., 3] + + self.params = params + self.K = K + self.overlap_mask = None + self.projection_mask = None + + def project(self, xyz): + raise NotImplementedError + + def unproject(self, uv): + raise NotImplementedError + + def get_projection_mask(self): + return self.projection_mask + + def get_overlap_mask(self): + return self.overlap_mask + + def reconstruct(self, depth): + id_coords = coords_grid( + 1, depth.shape[-2], depth.shape[-1], device=depth.device + ) + rays = self.unproject(id_coords) + return ( + rays / rays[:, -1:].clamp(min=1e-4) * depth.clamp(min=1e-4) + ) # assumption z>0!!! 
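    # --- Illustrative sketch (editor's addition, not part of the patch) ---
    # Typical reconstruct/project round trip, using the Pinhole subclass defined
    # further below in this file; the intrinsics and depth values are made up.
    #
    #     params = torch.tensor([[500.0, 500.0, 320.0, 240.0]])   # fx, fy, cx, cy
    #     cam = Pinhole(params=params)
    #     depth = torch.rand(1, 1, 480, 640) * 10.0               # metric depth map
    #     xyz = cam.reconstruct(depth)                            # (1, 3, 480, 640) point cloud
    #     uv = cam.project(xyz)                                   # back to pixel coordinates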
+ + def resize(self, factor): + self.K[..., :2, :] *= factor + self.params[..., :4] *= factor + return self + + def to(self, device, non_blocking=False): + self.params = self.params.to(device, non_blocking=non_blocking) + self.K = self.K.to(device, non_blocking=non_blocking) + return self + + def get_rays(self, shapes, noisy=False): + b, h, w = shapes + uv = coords_grid(1, h, w, device=self.K.device, noisy=noisy) + rays = self.unproject(uv) + return rays / torch.norm(rays, dim=1, keepdim=True).clamp(min=1e-4) + + def get_pinhole_rays(self, shapes, noisy=False): + b, h, w = shapes + uv = coords_grid(b, h, w, device=self.K.device, homogeneous=True, noisy=noisy) + rays = (invert_pinhole(self.K) @ uv.reshape(b, 3, -1)).reshape(b, 3, h, w) + return rays / torch.norm(rays, dim=1, keepdim=True).clamp(min=1e-4) + + def flip(self, H, W, direction="horizontal"): + new_cx = ( + W - self.params[:, 2] if direction == "horizontal" else self.params[:, 2] + ) + new_cy = H - self.params[:, 3] if direction == "vertical" else self.params[:, 3] + self.params = torch.stack( + [self.params[:, 0], self.params[:, 1], new_cx, new_cy], dim=1 + ) + self.K[..., 0, 2] = new_cx + self.K[..., 1, 2] = new_cy + return self + + def clone(self): + return deepcopy(self) + + def crop(self, left, top, right=None, bottom=None): + self.K[..., 0, 2] -= left + self.K[..., 1, 2] -= top + self.params[..., 2] -= left + self.params[..., 3] -= top + return self + + # helper function to get how fov changes based on new original size and new size + def get_new_fov(self, new_shape, original_shape): + new_hfov = 2 * torch.atan( + self.params[..., 2] / self.params[..., 0] * new_shape[1] / original_shape[1] + ) + new_vfov = 2 * torch.atan( + self.params[..., 3] / self.params[..., 1] * new_shape[0] / original_shape[0] + ) + return new_hfov, new_vfov + + def mask_overlap_projection(self, projected): + B, _, H, W = projected.shape + id_coords = coords_grid(B, H, W, device=projected.device) + + # check for mask where flow would overlap with other part of the image + # eleemtns coming from the border are then masked out + flow = projected - id_coords + gamma = 0.1 + sample_grid = gamma * flow + id_coords # sample along the flow + sample_grid[:, 0] = sample_grid[:, 0] / (W - 1) * 2 - 1 + sample_grid[:, 1] = sample_grid[:, 1] / (H - 1) * 2 - 1 + sampled_flow = F.grid_sample( + flow, + sample_grid.permute(0, 2, 3, 1), + mode="bilinear", + align_corners=False, + padding_mode="border", + ) + mask = ( + (1 - gamma) * torch.norm(flow, dim=1, keepdim=True) + < torch.norm(sampled_flow, dim=1, keepdim=True) + ) | (torch.norm(flow, dim=1, keepdim=True) < 1) + return mask + + def _pad_params(self): + # Ensure params are padded to length 16 + if self.params.shape[1] < 16: + padding = torch.zeros( + 16 - self.params.shape[1], + device=self.params.device, + dtype=self.params.dtype, + ) + padding = padding.unsqueeze(0).repeat(self.params.shape[0], 1) + return torch.cat([self.params, padding], dim=1) + return self.params + + @staticmethod + def flatten_cameras(cameras): # -> list[Camera]: + # Recursively flatten BatchCamera into primitive cameras + flattened_cameras = [] + for camera in cameras: + if isinstance(camera, BatchCamera): + flattened_cameras.extend(BatchCamera.flatten_cameras(camera.cameras)) + else: + flattened_cameras.append(camera) + return flattened_cameras + + @staticmethod + def _stack_or_cat_cameras(cameras, func, **kwargs): + # Generalized method to handle stacking or concatenation + flat_cameras = BatchCamera.flatten_cameras(cameras) + 
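        # Cameras in a batch may use different models (Pinhole, EUCM, Fisheye624, ...)
        # with different parameter counts, so each camera's params are zero-padded to a
        # fixed length of 16 via _pad_params() before stacking; the original class names
        # are recorded so BatchCamera can dispatch back to the right model per sample.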
K_matrices = [camera.K for camera in flat_cameras] + padded_params = [camera._pad_params() for camera in flat_cameras] + + stacked_K = func(K_matrices, **kwargs) + stacked_params = func(padded_params, **kwargs) + + # Keep track of the original classes + original_class = [x.__class__.__name__ for x in flat_cameras] + return BatchCamera(stacked_params, stacked_K, original_class, flat_cameras) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func is torch.cat: + return Camera._stack_or_cat_cameras(args[0], func, **kwargs) + + if func is torch.stack: + return Camera._stack_or_cat_cameras(args[0], func, **kwargs) + + return super().__torch_function__(func, types, args, kwargs) + + @property + def device(self): + return self.K.device + + # here we assume that cx,cy are more or less H/2 and W/2 + @property + def hfov(self): + return 2 * torch.atan(self.params[..., 2] / self.params[..., 0]) + + @property + def vfov(self): + return 2 * torch.atan(self.params[..., 3] / self.params[..., 1]) + + @property + def max_fov(self): + return 150.0 / 180.0 * np.pi, 150.0 / 180.0 * np.pi + + +class Pinhole(Camera): + def __init__(self, params=None, K=None): + assert params is not None or K is not None + if params is None: + params = torch.stack( + [K[..., 0, 0], K[..., 1, 1], K[..., 0, 2], K[..., 1, 2]], dim=1 + ) + super().__init__(params=params, K=K) + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def project(self, pcd): + b, _, h, w = pcd.shape + pcd_flat = pcd.reshape(b, 3, -1) # [B, 3, H*W] + cam_coords = self.K @ pcd_flat + pcd_proj = cam_coords[:, :2] / cam_coords[:, -1:].clamp(min=0.01) + pcd_proj = pcd_proj.reshape(b, 2, h, w) + invalid = ( + (pcd_proj[:, 0] >= 0) + & (pcd_proj[:, 0] < w) + & (pcd_proj[:, 1] >= 0) + & (pcd_proj[:, 1] < h) + ) + self.projection_mask = (~invalid).unsqueeze(1) + return pcd_proj + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def unproject(self, uv): + b, _, h, w = uv.shape + uv_flat = uv.reshape(b, 2, -1) # [B, 2, H*W] + uv_homogeneous = torch.cat( + [uv_flat, torch.ones(b, 1, h * w, device=uv.device)], dim=1 + ) # [B, 3, H*W] + K_inv = torch.inverse(self.K.float()) + xyz = K_inv @ uv_homogeneous + xyz = xyz / xyz[:, -1:].clip(min=1e-4) + xyz = xyz.reshape(b, 3, h, w) + self.unprojection_mask = xyz[:, -1:] > 1e-4 + return xyz + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def reconstruct(self, depth): + b, _, h, w = depth.shape + uv = coords_grid(b, h, w, device=depth.device) + xyz = self.unproject(uv) * depth.clip(min=0.0) + return xyz + + +class EUCM(Camera): + def __init__(self, params): + super().__init__(params=params, K=None) + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def project(self, xyz): + H, W = xyz.shape[-2:] + fx, fy, cx, cy, alpha, beta = self.params[:6].unbind(dim=1) + x, y, z = xyz.unbind(dim=1) + d = torch.sqrt(beta * (x**2 + y**2) + z**2) + + x = x / (alpha * d + (1 - alpha) * z).clip(min=1e-3) + y = y / (alpha * d + (1 - alpha) * z).clip(min=1e-3) + + Xnorm = fx * x + cx + Ynorm = fy * y + cy + + coords = torch.stack([Xnorm, Ynorm], dim=1) + + invalid = ( + (coords[:, 0] < 0) + | (coords[:, 0] > W) + | (coords[:, 1] < 0) + | (coords[:, 1] > H) + | (z < 0) + ) + self.projection_mask = (~invalid).unsqueeze(1) + + return coords + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def unproject(self, uv): + u, v = uv.unbind(dim=1) + 
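        # EUCM back-projection: normalize the pixel to (mx, my) on the image plane,
        # recover mz from the distortion parameters (alpha, beta), then normalize
        # (mx, my, mz) to a unit-norm ray. Points outside the model's valid radius
        # (only bounded when alpha > 0.5) are flagged in the unprojection mask.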
fx, fy, cx, cy, alpha, beta = self.params.unbind(dim=1) + mx = (u - cx) / fx + my = (v - cy) / fy + r_square = mx**2 + my**2 + valid_mask = r_square < torch.where( + alpha < 0.5, 1e6, 1 / (beta * (2 * alpha - 1)) + ) + sqrt_val = 1 - (2 * alpha - 1) * beta * r_square + mz = (1 - beta * (alpha**2) * r_square) / ( + alpha * torch.sqrt(sqrt_val.clip(min=1e-5)) + (1 - alpha) + ) + coeff = 1 / torch.sqrt(mx**2 + my**2 + mz**2 + 1e-5) + + x = coeff * mx + y = coeff * my + z = coeff * mz + self.unprojection_mask = valid_mask & (z > 1e-3) + + xnorm = torch.stack((x, y, z.clamp(1e-3)), dim=1) + return xnorm + + +class Spherical(Camera): + def __init__(self, params): + # Hfov and Vofv are in radians and halved! + super().__init__(params=params, K=None) + + def resize(self, factor): + self.K[..., :2, :] *= factor + self.params[..., :6] *= factor + return self + + def crop(self, left, top, right, bottom): + self.K[..., 0, 2] -= left + self.K[..., 1, 2] -= top + self.params[..., 2] -= left + self.params[..., 3] -= top + W, H = self.params[..., 4], self.params[..., 5] + angle_ratio_W = (W - left - right) / W + angle_ratio_H = (H - top - bottom) / H + + self.params[..., 4] -= left + right + self.params[..., 5] -= top + bottom + + # rescale hfov and vfov + self.params[..., 6] *= angle_ratio_W + self.params[..., 7] *= angle_ratio_H + return self + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def project(self, xyz): + width, height = self.params[..., 4], self.params[..., 5] + hfov, vfov = 2 * self.params[..., 6], 2 * self.params[..., 7] + longitude = torch.atan2(xyz[:, 0], xyz[:, 2]) + latitude = torch.asin(xyz[:, 1] / torch.norm(xyz, dim=1).clamp(min=1e-5)) + + u = longitude / hfov * (width - 1) + (width - 1) / 2 + v = latitude / vfov * (height - 1) + (height - 1) / 2 + + return torch.stack([u, v], dim=1) + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def unproject(self, uv): + u, v = uv.unbind(dim=1) + + width, height = self.params[..., 4], self.params[..., 5] + hfov, vfov = 2 * self.params[..., 6], 2 * self.params[..., 7] + longitude = (u - (width - 1) / 2) / (width - 1) * hfov + latitude = (v - (height - 1) / 2) / (height - 1) * vfov + x = torch.cos(latitude) * torch.sin(longitude) + z = torch.cos(latitude) * torch.cos(longitude) + y = torch.sin(latitude) + unit_sphere = torch.stack([x, y, z], dim=1) + unit_sphere = unit_sphere / torch.norm(unit_sphere, dim=1, keepdim=True).clip( + min=1e-5 + ) + + return unit_sphere + + def reconstruct(self, depth): + id_coords = coords_grid( + 1, depth.shape[-2], depth.shape[-1], device=depth.device + ) + return self.unproject(id_coords) * depth + + def get_new_fov(self, new_shape, original_shape): + new_hfov = 2 * self.params[..., 6] * new_shape[1] / original_shape[1] + new_vfov = 2 * self.params[..., 7] * new_shape[0] / original_shape[0] + return new_hfov, new_vfov + + @property + def hfov(self): + return 2 * self.params[..., 6] + + @property + def vfov(self): + return 2 * self.params[..., 7] + + @property + def max_fov(self): + return 2 * np.pi, 0.9 * np.pi # avoid strong distortion on tops + + +class Fisheye624(Camera): + def __init__(self, params): + super().__init__(params=params, K=None) + self.use_radial = self.params[..., 4:10].abs().sum() > 1e-6 + self.use_tangential = self.params[..., 10:12].abs().sum() > 1e-6 + self.use_thin_prism = self.params[..., 12:].abs().sum() > 1e-6 + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def project(self, xyz): + eps = 1e-9 + B, _, 
H, W = xyz.shape + N = H * W + xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3) + + # Radial correction. + z = xyz[:, :, 2].reshape(B, N, 1) + z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z) + ab = xyz[:, :, :2] / z + r = torch.norm(ab, dim=-1, p=2, keepdim=True) + th = torch.atan(r) + th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r) + + th_pow = torch.cat( + [torch.pow(th, 3 + i * 2) for i in range(6)], dim=-1 + ) # Create powers of th (th^3, th^5, ...) + distortion_coeffs = self.params[:, 4:10].reshape(B, 1, 6) + th_k = th + torch.sum(th_pow * distortion_coeffs, dim=-1, keepdim=True) + + xr_yr = th_k * th_divr + uv_dist = xr_yr + + # Tangential correction. + p0 = self.params[..., -6].reshape(B, 1) + p1 = self.params[..., -5].reshape(B, 1) + xr = xr_yr[:, :, 0].reshape(B, N) + yr = xr_yr[:, :, 1].reshape(B, N) + xr_yr_sq = torch.square(xr_yr) + xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) + yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) + rd_sq = xr_sq + yr_sq + uv_dist_tu = uv_dist[:, :, 0] + ( + (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 + ) + uv_dist_tv = uv_dist[:, :, 1] + ( + (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 + ) + uv_dist = torch.stack( + [uv_dist_tu, uv_dist_tv], dim=-1 + ) # Avoids in-place complaint. + + # Thin Prism correction. + s0 = self.params[..., -4].reshape(B, 1) + s1 = self.params[..., -3].reshape(B, 1) + s2 = self.params[..., -2].reshape(B, 1) + s3 = self.params[..., -1].reshape(B, 1) + rd_4 = torch.square(rd_sq) + uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4) + uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4) + + # Finally, apply standard terms: focal length and camera centers. + if self.params.shape[-1] == 15: + fx_fy = self.params[..., 0].reshape(B, 1, 1) + cx_cy = self.params[..., 1:3].reshape(B, 1, 2) + else: + fx_fy = self.params[..., 0:2].reshape(B, 1, 2) + cx_cy = self.params[..., 2:4].reshape(B, 1, 2) + result = uv_dist * fx_fy + cx_cy + + result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2) + invalid = ( + (result[:, 0] < 0) + | (result[:, 0] > W) + | (result[:, 1] < 0) + | (result[:, 1] > H) + ) + self.projection_mask = (~invalid).unsqueeze(1) + self.overlap_mask = self.mask_overlap_projection(result) + + return result + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def unproject(self, uv, max_iters: int = 10): + eps = 1e-3 + B, _, H, W = uv.shape + N = H * W + uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2) + + if self.params.shape[-1] == 15: + fx_fy = self.params[..., 0].reshape(B, 1, 1) + cx_cy = self.params[..., 1:3].reshape(B, 1, 2) + else: + fx_fy = self.params[..., 0:2].reshape(B, 1, 2) + cx_cy = self.params[..., 2:4].reshape(B, 1, 2) + + uv_dist = (uv - cx_cy) / fx_fy + + # Compute xr_yr using Trust-region method. + xr_yr = uv_dist.clone() + max_iters_tanprism = ( + max_iters if self.use_thin_prism or self.use_tangential else 0 + ) + + for _ in range(max_iters_tanprism): + uv_dist_est = xr_yr.clone() + xr = xr_yr[..., 0].reshape(B, N) + yr = xr_yr[..., 1].reshape(B, N) + xr_yr_sq = torch.square(xr_yr) + xr_sq = xr_yr_sq[..., 0].reshape(B, N) + yr_sq = xr_yr_sq[..., 1].reshape(B, N) + rd_sq = xr_sq + yr_sq + + if self.use_tangential: + # Tangential terms. 
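                # Brown-Conrady style decentering terms applied to (xr, yr):
                #   u += (2*xr^2 + r^2) * p0 + 2*xr*yr * p1
                #   v += (2*yr^2 + r^2) * p1 + 2*xr*yr * p0
                # which is exactly what the update below computes.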
+ p0 = self.params[..., -6].reshape(B, 1) + p1 = self.params[..., -5].reshape(B, 1) + uv_dist_est[..., 0] = uv_dist_est[..., 0] + ( + (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 + ) + uv_dist_est[..., 1] = uv_dist_est[..., 1] + ( + (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 + ) + + if self.use_thin_prism: + # Thin Prism terms. + s0 = self.params[..., -4].reshape(B, 1) + s1 = self.params[..., -3].reshape(B, 1) + s2 = self.params[..., -2].reshape(B, 1) + s3 = self.params[..., -1].reshape(B, 1) + rd_4 = torch.square(rd_sq) + uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4) + uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4) + + # Compute the derivative of uv_dist w.r.t. xr_yr. + duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2) + + if self.use_tangential: + duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1 + offdiag = 2.0 * (xr * p1 + yr * p0) + duv_dist_dxr_yr[..., 0, 1] = offdiag + duv_dist_dxr_yr[..., 1, 0] = offdiag + duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0 + + if self.use_thin_prism: + xr_yr_sq_norm = xr_sq + yr_sq + temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm) + duv_dist_dxr_yr[..., 0, 0] = duv_dist_dxr_yr[..., 0, 0] + (xr * temp1) + duv_dist_dxr_yr[..., 0, 1] = duv_dist_dxr_yr[..., 0, 1] + (yr * temp1) + temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm) + duv_dist_dxr_yr[..., 1, 0] = duv_dist_dxr_yr[..., 1, 0] + (xr * temp2) + duv_dist_dxr_yr[..., 1, 1] = duv_dist_dxr_yr[..., 1, 1] + (yr * temp2) + + mat = duv_dist_dxr_yr.reshape(-1, 2, 2) + a = mat[:, 0, 0].reshape(-1, 1, 1) + b = mat[:, 0, 1].reshape(-1, 1, 1) + c = mat[:, 1, 0].reshape(-1, 1, 1) + d = mat[:, 1, 1].reshape(-1, 1, 1) + det = 1.0 / ((a * d) - (b * c)) + top = torch.cat([d, -b], dim=-1) + bot = torch.cat([-c, a], dim=-1) + inv = det * torch.cat([top, bot], dim=-2) + inv = inv.reshape(B, N, 2, 2) + diff = uv_dist - uv_dist_est + a = inv[..., 0, 0] + b = inv[..., 0, 1] + c = inv[..., 1, 0] + d = inv[..., 1, 1] + e = diff[..., 0] + f = diff[..., 1] + step = torch.stack([a * e + b * f, c * e + d * f], dim=-1) + # Newton step. + xr_yr = xr_yr + step + + # Compute theta using Newton's method. 
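        # The radial model maps the incidence angle theta to a distorted radius
        #   theta_d = theta + k1*theta^3 + k2*theta^5 + ... + k6*theta^13
        # (the same powers used in project()). The loop below inverts this scalar
        # polynomial per pixel; since a plain Newton step can overshoot, the update is
        # wrapped in a small trust-region scheme that grows or shrinks the step radius
        # based on the ratio of actual to predicted reduction of the residual.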
+ xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1) + th = xr_yr_norm.clone() + max_iters_radial = max_iters if self.use_radial else 0 + c = ( + torch.tensor([2.0 * i + 3 for i in range(6)], device=self.device) + .reshape(1, 1, 6) + .repeat(B, 1, 1) + ) + radial_params = self.params[..., 4:10].reshape(B, 1, 6) + + # Trust region parameters + delta = torch.full((B, N, 1), 0.1, device=self.device) # Initial trust radius + delta_max = torch.tensor(1.0, device=self.device) # Maximum trust radius + eta = 0.1 # Acceptable reduction threshold + + for i in range(max_iters_radial): + th_sq = th * th + # Compute powers of th^2 up to th^(12) + theta_powers = torch.cat( + [th_sq ** (i + 1) for i in range(6)], dim=-1 + ) # Shape: (B, N, 6) + + # Compute th_radial: radial distortion model applied to th + th_radial = 1.0 + torch.sum( + theta_powers * radial_params, dim=-1, keepdim=True + ) + th_radial = th_radial * th + + # Compute derivative dthd_th + dthd_th = 1.0 + torch.sum( + c * radial_params * theta_powers, dim=-1, keepdim=True + ) + + # Compute residual + residual = th_radial - xr_yr_norm # Shape: (B, N, 1) + residual_norm = torch.norm(residual, dim=2, keepdim=True) + + # Check for convergence + if torch.max(torch.abs(residual)) < eps: + break + + # Avoid division by zero by adding a small epsilon + safe_dthd_th = dthd_th.clone() + zero_derivative_mask = dthd_th.abs() < eps + safe_dthd_th[zero_derivative_mask] = eps + + # Compute Newton's step + step = -residual / safe_dthd_th + + # Compute predicted reduction + predicted_reduction = -(residual * step).sum(dim=2, keepdim=True) + + # Adjust step based on trust region + step_norm = torch.norm(step, dim=2, keepdim=True) + over_trust_mask = step_norm > delta + + # Scale step if it exceeds trust radius + step_scaled = step.clone() + step_scaled[over_trust_mask] = step[over_trust_mask] * ( + delta[over_trust_mask] / step_norm[over_trust_mask] + ) + + # Update theta + th_new = th + step_scaled + + # Compute new residual + th_sq_new = th_new * th_new + theta_powers_new = torch.cat( + [th_sq_new ** (j + 1) for j in range(6)], dim=-1 + ) + th_radial_new = 1.0 + torch.sum( + theta_powers_new * radial_params, dim=-1, keepdim=True + ) + th_radial_new = th_radial_new * th_new + residual_new = th_radial_new - xr_yr_norm + residual_new_norm = torch.norm(residual_new, dim=2, keepdim=True) + + # Compute actual reduction + actual_reduction = residual_norm - residual_new_norm + + # Compute ratio of actual to predicted reduction + rho = actual_reduction / predicted_reduction + rho[(actual_reduction == 0) & (predicted_reduction == 0)] = 1.0 + + # Update trust radius delta + delta_update_mask = rho > 0.5 + delta[delta_update_mask] = torch.min( + 2.0 * delta[delta_update_mask], delta_max + ) + + delta_decrease_mask = rho < 0.2 + delta[delta_decrease_mask] = 0.25 * delta[delta_decrease_mask] + + # Accept or reject the step + accept_step_mask = rho > eta + th = torch.where(accept_step_mask, th_new, th) + + # Compute the ray direction using theta and xr_yr. 
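        # For this fisheye model the undistorted (pinhole-equivalent) radius is tan(theta),
        # so the ray direction in the normalized image plane is tan(theta) * xr_yr / ||xr_yr||;
        # when both theta and ||xr_yr|| are near zero the ray is taken as xr_yr itself to
        # avoid 0/0. A unit z component is appended afterwards to form the projective ray.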
+ close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps) + ray_dir = torch.where(close_to_zero, xr_yr, torch.tan(th) / xr_yr_norm * xr_yr) + + ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2) + ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2) + + return ray + + +class MEI(Camera): + def __init__(self, params): + super().__init__(params=params, K=None) + # fx fy cx cy k1 k2 p1 p2 xi + self.use_radial = self.params[..., 4:6].abs().sum() > 1e-6 + self.use_tangential = self.params[..., 6:8].abs().sum() > 1e-6 + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def unproject(self, uv, max_iters: int = 20): + eps = 1e-6 + B, _, H, W = uv.shape + N = H * W + uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2) + + k1, k2, p0, p1, xi = self.params[..., 4:9].unbind(dim=1) + fx_fy = self.params[..., 0:2].reshape(B, 1, 2) + cx_cy = self.params[..., 2:4].reshape(B, 1, 2) + + uv_dist = (uv - cx_cy) / fx_fy + + # Compute xr_yr using Newton's method. + xr_yr = uv_dist.clone() # Initial guess. + max_iters_tangential = max_iters if self.use_tangential else 0 + for _ in range(max_iters_tangential): + uv_dist_est = xr_yr.clone() + + # Tangential terms. + xr = xr_yr[..., 0] + yr = xr_yr[..., 1] + xr_yr_sq = xr_yr**2 + xr_sq = xr_yr_sq[..., 0] + yr_sq = xr_yr_sq[..., 1] + rd_sq = xr_sq + yr_sq + uv_dist_est[..., 0] = uv_dist_est[..., 0] + ( + (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 + ) + uv_dist_est[..., 1] = uv_dist_est[..., 1] + ( + (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 + ) + + # Compute the derivative of uv_dist w.r.t. xr_yr. + duv_dist_dxr_yr = torch.ones((B, N, 2, 2), device=uv.device) + duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1 + offdiag = 2.0 * (xr * p1 + yr * p0) + duv_dist_dxr_yr[..., 0, 1] = offdiag + duv_dist_dxr_yr[..., 1, 0] = offdiag + duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0 + + mat = duv_dist_dxr_yr.reshape(-1, 2, 2) + a = mat[:, 0, 0].reshape(-1, 1, 1) + b = mat[:, 0, 1].reshape(-1, 1, 1) + c = mat[:, 1, 0].reshape(-1, 1, 1) + d = mat[:, 1, 1].reshape(-1, 1, 1) + det = 1.0 / ((a * d) - (b * c)) + top = torch.cat([d, -b], dim=-1) + bot = torch.cat([-c, a], dim=-1) + inv = det * torch.cat([top, bot], dim=-2) + inv = inv.reshape(B, N, 2, 2) + + diff = uv_dist - uv_dist_est + a = inv[..., 0, 0] + b = inv[..., 0, 1] + c = inv[..., 1, 0] + d = inv[..., 1, 1] + e = diff[..., 0] + f = diff[..., 1] + step = torch.stack([a * e + b * f, c * e + d * f], dim=-1) + + # Newton step. + xr_yr = xr_yr + step + + # Compute theta using Newton's method. + xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1) + th = xr_yr_norm.clone() + max_iters_radial = max_iters if self.use_radial else 0 + for _ in range(max_iters_radial): + th_radial = 1.0 + k1 * torch.pow(th, 2) + k2 * torch.pow(th, 4) + dthd_th = 1.0 + 3.0 * k1 * torch.pow(th, 2) + 5.0 * k2 * torch.pow(th, 4) + th_radial = th_radial * th + step = (xr_yr_norm - th_radial) / dthd_th + # handle dthd_th close to 0. + step = torch.where( + torch.abs(dthd_th) > eps, step, torch.sign(step) * eps * 10.0 + ) + th = th + step + + # Compute the ray direction using theta and xr_yr. 
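        # In the MEI / unified camera model the radial solve yields theta as the
        # undistorted radius directly, so ray_dir = theta * xr_yr / ||xr_yr||. The z
        # component is then recovered from the unit-sphere (mirror) model using xi,
        # with a dedicated branch for the degenerate case xi == 1.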
+ close_to_zero = (torch.abs(th) < eps) & (torch.abs(xr_yr_norm) < eps) + ray_dir = torch.where(close_to_zero, xr_yr, th * xr_yr / xr_yr_norm) + + # Compute the 3D projective ray + rho2_u = ( + ray_dir.norm(p=2, dim=2, keepdim=True) ** 2 + ) # B N 1 # x_c * x_c + y_c * y_c + xi = xi.reshape(B, 1, 1) + sqrt_term = torch.sqrt(1.0 + (1.0 - xi * xi) * rho2_u) + P_z = 1.0 - xi * (rho2_u + 1.0) / (xi + sqrt_term) + + # Special case when xi is 1.0 (unit sphere projection ??) + P_z = torch.where(xi == 1.0, (1.0 - rho2_u) / 2.0, P_z) + + ray = torch.cat([ray_dir, P_z], dim=-1) + ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2) + + return ray + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def project(self, xyz): + eps = 1e-4 + is_flat = xyz.ndim == 3 + B, N = xyz.shape[:2] + + if not is_flat: + B, _, H, W = xyz.shape + N = H * W + xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3) + + k1, k2, p0, p1, xi = self.params[..., 4:].unbind(dim=1) + fx_fy = self.params[..., 0:2].reshape(B, 1, 2) + cx_cy = self.params[..., 2:4].reshape(B, 1, 2) + + norm = xyz.norm(p=2, dim=-1, keepdim=True) + ab = xyz[..., :-1] / (xyz[..., -1:] + xi.reshape(B, 1, 1) * norm) + + # radial correction + r = ab.norm(dim=-1, p=2, keepdim=True) + k1 = self.params[..., 4].reshape(B, 1, 1) + k2 = self.params[..., 5].reshape(B, 1, 1) + # ab / r * th * (1 + k1 * (th ** 2) + k2 * (th**4)) + # but here r = th, no spherical distortion + xr_yr = ab * (1 + k1 * (r**2) + k2 * (r**4)) + + # Tangential correction. + uv_dist = xr_yr + p0 = self.params[:, -3].reshape(B, 1) + p1 = self.params[:, -2].reshape(B, 1) + xr = xr_yr[..., 0].reshape(B, N) + yr = xr_yr[..., 1].reshape(B, N) + xr_yr_sq = torch.square(xr_yr) + xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) + yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) + rd_sq = xr_sq + yr_sq + uv_dist_tu = uv_dist[:, :, 0] + ( + (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 + ) + uv_dist_tv = uv_dist[:, :, 1] + ( + (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 + ) + uv_dist = torch.stack( + [uv_dist_tu, uv_dist_tv], dim=-1 + ) # Avoids in-place complaint. + + result = uv_dist * fx_fy + cx_cy + + if not is_flat: + result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2) + invalid = ( + (result[:, 0] < 0) + | (result[:, 0] > W) + | (result[:, 1] < 0) + | (result[:, 1] > H) + ) + self.projection_mask = (~invalid).unsqueeze(1) + # creates hole in the middle... ?? 
+ # self.overlap_mask = self.mask_overlap_projection(result) + + return result + + +class BatchCamera(Camera): + def __init__(self, params, K, original_class, cameras): + super().__init__(params, K) + self.original_class = original_class + self.cameras = cameras + + # Delegate these methods to original camera + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def project(self, points_3d): + return torch.cat( + [ + camera.project(points_3d[i : i + 1]) + for i, camera in enumerate(self.cameras) + ] + ) + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def unproject(self, points_2d): + val = torch.cat( + [camera.unproject(points_2d) for i, camera in enumerate(self.cameras)] + ) + return val + + def crop(self, left, top, right=None, bottom=None): + val = torch.cat( + [ + camera.crop(left, top, right, bottom) + for i, camera in enumerate(self.cameras) + ] + ) + return val + + def resize(self, ratio): + val = torch.cat([camera.resize(ratio) for i, camera in enumerate(self.cameras)]) + return val + + def reconstruct(self, depth): + val = torch.cat( + [ + camera.reconstruct(depth[i : i + 1]) + for i, camera in enumerate(self.cameras) + ] + ) + return val + + def get_projection_mask(self): + return torch.cat( + [camera.projection_mask for i, camera in enumerate(self.cameras)] + ) + + def to(self, device, non_blocking=False): + self = super().to(device, non_blocking=non_blocking) + self.cameras = [ + camera.to(device, non_blocking=non_blocking) for camera in self.cameras + ] + return self + + def reshape(self, *shape): + # Reshape the intrinsic matrix (K) and params + # we know that the shape of K is (..., 3, 3) and params is (..., 16) + reshaped_K = self.K.reshape(*shape, 3, 3) + reshaped_params = self.params.reshape(*shape, self.params.shape[-1]) + + self.cameras = np.array(self.cameras, dtype=object).reshape(shape).tolist() + self.original_class = ( + np.array(self.original_class, dtype=object).reshape(shape).tolist() + ) + + # Create a new BatchCamera with reshaped K and params + return BatchCamera( + reshaped_params, reshaped_K, self.original_class, self.cameras + ) + + def get_new_fov(self, new_shape, original_shape): + return [ + camera.get_new_fov(new_shape, original_shape) + for i, camera in enumerate(self.cameras) + ] + + def squeeze(self, dim): + return BatchCamera( + self.params.squeeze(dim), + self.K.squeeze(dim), + squeeze_list(self.original_class, dim=dim), + squeeze_list(self.cameras, dim=dim), + ) + + def __getitem__(self, idx): + if isinstance(idx, int): + return self.cameras[idx] + + elif isinstance(idx, slice): + return BatchCamera( + self.params[idx], + self.K[idx], + self.original_class[idx], + self.cameras[idx], + ) + + raise TypeError(f"Invalid index type: {type(idx)}") + + def __setitem__(self, idx, value): + # If it's an integer index, return a single camera + if isinstance(idx, int): + self.cameras[idx] = value + self.params[idx, :] = 0.0 + self.params[idx, : value.params.shape[1]] = value.params[0] + self.K[idx] = value.K[0] + + self.original_class[idx] = getattr( + value, "original_class", value.__class__.__name__ + ) + + # If it's a slice, return a new BatchCamera with sliced cameras + elif isinstance(idx, slice): + # Update each internal attribute using the slice + self.params[idx] = value.params + self.K[idx] = value.K + self.original_class[idx] = value.original_class + self.cameras[idx] = value.cameras + + def __len__(self): + return len(self.cameras) + + @classmethod + def from_camera(cls, camera): + return 
cls(camera.params, camera.K, [camera.__class__.__name__], [camera]) + + @property + def is_perspective(self): + return [isinstance(camera, Pinhole) for camera in self.cameras] + + @property + def is_spherical(self): + return [isinstance(camera, Spherical) for camera in self.cameras] + + @property + def is_eucm(self): + return [isinstance(camera, EUCM) for camera in self.cameras] + + @property + def is_fisheye(self): + return [isinstance(camera, Fisheye624) for camera in self.cameras] + + @property + def is_pinhole(self): + return [isinstance(camera, Pinhole) for camera in self.cameras] + + @property + def hfov(self): + return [camera.hfov for camera in self.cameras] + + @property + def vfov(self): + return [camera.vfov for camera in self.cameras] + + @property + def max_fov(self): + return [camera.max_fov for camera in self.cameras] diff --git a/unidepth/utils/chamfer_distance.py b/unidepth/utils/chamfer_distance.py new file mode 100644 index 0000000..08381c6 --- /dev/null +++ b/unidepth/utils/chamfer_distance.py @@ -0,0 +1,157 @@ +from typing import Union + +import torch + +try: + from pytorch3d.ops.knn import knn_points +except ImportError: + print( + "Pytorch3D is not available. Either install it or compile knn under " + "unidepth/ops/knn with `bash compile.sh`" + ) + from unidepth.ops.knn import knn_points + + +def _validate_chamfer_reduction_inputs( + batch_reduction: Union[str, None], point_reduction: str +): + """Check the requested reductions are valid. + + Args: + batch_reduction: Reduction operation to apply for the loss across the + batch, can be one of ["mean", "sum"] or None. + point_reduction: Reduction operation to apply for the loss across the + points, can be one of ["mean", "sum"]. + """ + if batch_reduction is not None and batch_reduction not in ["mean", "sum"]: + raise ValueError('batch_reduction must be one of ["mean", "sum"] or None') + if point_reduction not in ["mean", "sum"]: + raise ValueError('point_reduction must be one of ["mean", "sum"]') + + +def _handle_pointcloud_input( + points: torch.Tensor, + lengths: Union[torch.Tensor, None], + normals: Union[torch.Tensor, None], +): + """ + If points is an instance of Pointclouds, retrieve the padded points tensor + along with the number of points per batch and the padded normals. + Otherwise, return the input points (and normals) with the number of points per cloud + set to the size of the second dimension of `points`. + """ + if points.ndim != 3: + raise ValueError("Expected points to be of shape (N, P, D)") + X = points + if lengths is not None and (lengths.ndim != 1 or lengths.shape[0] != X.shape[0]): + raise ValueError("Expected lengths to be of shape (N,)") + if lengths is None: + lengths = torch.full( + (X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device + ) + if normals is not None and normals.ndim != 3: + raise ValueError("Expected normals to be of shape (N, P, 3") + + return X, lengths, normals + + +class ChamferDistance(torch.nn.Module): + def forward( + self, + x, + y, + x_lengths=None, + y_lengths=None, + x_normals=None, + y_normals=None, + weights=None, + batch_reduction: Union[str, None] = "mean", + point_reduction: str = "mean", + ): + """ + Chamfer distance between two pointclouds x and y. + + Args: + x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing + a batch of point clouds with at most P1 points in each batch element, + batch size N and feature dimension D. 
+ y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing + a batch of point clouds with at most P2 points in each batch element, + batch size N and feature dimension D. + x_lengths: Optional LongTensor of shape (N,) giving the number of points in each + cloud in x. + y_lengths: Optional LongTensor of shape (N,) giving the number of points in each + cloud in x. + x_normals: Optional FloatTensor of shape (N, P1, D). + y_normals: Optional FloatTensor of shape (N, P2, D). + weights: Optional FloatTensor of shape (N,) giving weights for + batch elements for reduction operation. + batch_reduction: Reduction operation to apply for the loss across the + batch, can be one of ["mean", "sum"] or None. + point_reduction: Reduction operation to apply for the loss across the + points, can be one of ["mean", "sum"]. + + Returns: + 2-element tuple containing + + - **loss**: Tensor giving the reduced distance between the pointclouds + in x and the pointclouds in y. + - **loss_normals**: Tensor giving the reduced cosine distance of normals + between pointclouds in x and pointclouds in y. Returns None if + x_normals and y_normals are None. + """ + _validate_chamfer_reduction_inputs(batch_reduction, point_reduction) + + x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals) + y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals) + + return_normals = x_normals is not None and y_normals is not None + + N, P1, D = x.shape + P2 = y.shape[1] + + # Check if inputs are heterogeneous and create a lengths mask. + is_x_heterogeneous = (x_lengths != P1).any() + is_y_heterogeneous = (y_lengths != P2).any() + x_mask = ( + torch.arange(P1, device=x.device)[None] >= x_lengths[:, None] + ) # shape [N, P1] + y_mask = ( + torch.arange(P2, device=y.device)[None] >= y_lengths[:, None] + ) # shape [N, P2] + + if y.shape[0] != N or y.shape[2] != D: + raise ValueError("y does not have the correct shape.") + if weights is not None: + if weights.size(0) != N: + raise ValueError("weights must be of shape (N,).") + if not (weights >= 0).all(): + raise ValueError("weights cannot be negative.") + if weights.sum() == 0.0: + weights = weights.view(N, 1) + if batch_reduction in ["mean", "sum"]: + return ( + (x.sum((1, 2)) * weights).sum() * 0.0, + (x.sum((1, 2)) * weights).sum() * 0.0, + ) + return ( + (x.sum((1, 2)) * weights) * 0.0, + (x.sum((1, 2)) * weights) * 0.0, + ) + + x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1) + y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1) + + cham_x = x_nn.dists[..., 0] # (N, P1) + cham_y = y_nn.dists[..., 0] # (N, P2) + + if is_x_heterogeneous: + cham_x[x_mask] = 0.0 + if is_y_heterogeneous: + cham_y[y_mask] = 0.0 + + if weights is not None: + cham_x *= weights.view(N, 1) + cham_y *= weights.view(N, 1) + + return cham_x, cham_y, x_nn.idx[..., -1], y_nn.idx[..., -1] diff --git a/unidepth/utils/constants.py b/unidepth/utils/constants.py index fb18fd1..60d1612 100644 --- a/unidepth/utils/constants.py +++ b/unidepth/utils/constants.py @@ -18,5 +18,3 @@ ), dim=0, ) -LOGERR_BINS = torch.linspace(-2, 2, steps=128 + 1) -LINERR_BINS = torch.linspace(-50, 50, steps=256 + 1) diff --git a/unidepth/utils/coordinate.py b/unidepth/utils/coordinate.py new file mode 100644 index 0000000..a091299 --- /dev/null +++ b/unidepth/utils/coordinate.py @@ -0,0 +1,27 @@ +import torch + + +def coords_grid(b, h, w, homogeneous=False, device=None, noisy=False): + pixel_coords_x = torch.linspace(0.5, w - 0.5, w, device=device) + 
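    # Coordinates are sampled at pixel centers, so a w-wide axis spans [0.5, w - 0.5];
    # the same convention is used for the y axis below, and `noisy=True` jitters each
    # coordinate by up to +/- 0.5 px.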
pixel_coords_y = torch.linspace(0.5, h - 0.5, h, device=device) + if noisy: # \pm 0.5px noise + pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5 + pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5 + + stacks = [pixel_coords_x.repeat(h, 1), pixel_coords_y.repeat(w, 1).t()] + if homogeneous: + ones = torch.ones_like(stacks[0]) # [H, W] + stacks.append(ones) + grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] + grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] + if device is not None: + grid = grid.to(device) + + return grid + + +def normalize_coords(coords, h, w): + c = torch.tensor([(w - 1) / 2.0, (h - 1) / 2.0], device=coords.device).view( + 1, 2, 1, 1 + ) + return (coords - c) / c diff --git a/unidepth/utils/distributed.py b/unidepth/utils/distributed.py index 2b2f16f..2b38e3f 100644 --- a/unidepth/utils/distributed.py +++ b/unidepth/utils/distributed.py @@ -1,9 +1,5 @@ -""" -Author: Luigi Piccinelli -Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) -""" - import os +import pickle import platform import subprocess import warnings @@ -14,6 +10,8 @@ from torch import distributed as dist from torch import multiprocessing as mp +_LOCAL_PROCESS_GROUP = None + def is_dist_avail_and_initialized(): if not dist.is_available(): @@ -29,6 +27,35 @@ def get_rank(): return dist.get_rank() +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not is_dist_avail_and_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not is_dist_avail_and_initialized(): + return 1 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + def barrier(): if not is_dist_avail_and_initialized(): return @@ -69,44 +96,34 @@ def setup_multi_processes(cfg): mp.set_start_method(mp_start_method, force=True) # disable opencv multithreading to avoid system being overloaded - opencv_num_threads = cfg.get("opencv_num_threads", 0) - cv2.setNumThreads(opencv_num_threads) + # opencv_num_threads = cfg.get('opencv_num_threads', 0) + # cv2.setNumThreads(opencv_num_threads) # setup OMP threads # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa - workers_per_gpu = cfg.get("workers_per_gpu", 4) - - if "OMP_NUM_THREADS" not in os.environ and workers_per_gpu > 1: - omp_num_threads = 1 - warnings.warn( - f"Setting OMP_NUM_THREADS environment variable for each process " - f"to be {omp_num_threads} in default, to avoid your system being " - f"overloaded, please further tune the variable for optimal " - f"performance in your application as needed." 
- ) - os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) + # workers_per_gpu = cfg.get('workers_per_gpu', 4) + + # if 'OMP_NUM_THREADS' not in os.environ and workers_per_gpu > 1: + # omp_num_threads = 1 + # warnings.warn( + # f'Setting OMP_NUM_THREADS environment variable for each process ' + # f'to be {omp_num_threads} in default, to avoid your system being ' + # f'overloaded, please further tune the variable for optimal ' + # f'performance in your application as needed.') + # os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) # setup MKL threads - if "MKL_NUM_THREADS" not in os.environ and workers_per_gpu > 1: - mkl_num_threads = os.environ.get("OMP_NUM_THREADS", 1) - warnings.warn( - f"Setting MKL_NUM_THREADS environment variable for each process " - f"to be {mkl_num_threads} in default, to avoid your system being " - f"overloaded, please further tune the variable for optimal " - f"performance in your application as needed." - ) - os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads) + # if 'MKL_NUM_THREADS' not in os.environ and workers_per_gpu > 1: + # mkl_num_threads = os.environ.get('OMP_NUM_THREADS', 1) + # warnings.warn( + # f'Setting MKL_NUM_THREADS environment variable for each process ' + # f'to be {mkl_num_threads} in default, to avoid your system being ' + # f'overloaded, please further tune the variable for optimal ' + # f'performance in your application as needed.') + # os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) def setup_slurm(backend: str, port: str) -> None: - """Initialize slurm distributed training environment. - If argument ``port`` is not specified, then the master port will be system - environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system - environment variable, then a default port ``29500`` will be used. - Args: - backend (str): Backend of torch.distributed. - port (int, optional): Master port. Defaults to None. 
- """ proc_id = int(os.environ["SLURM_PROCID"]) ntasks = int(os.environ["SLURM_NTASKS"]) node_list = os.environ["SLURM_NODELIST"] @@ -159,13 +176,10 @@ def sync_tensor_across_gpus(t, dim=0, cat=True): return all_ts -import pickle - - def sync_string_across_gpus(keys: list[str], device, dim=0): keys_serialized = pickle.dumps(keys, protocol=pickle.HIGHEST_PROTOCOL) - keys_serialized_tensor = torch.frombuffer(keys_serialized, dtype=torch.uint8).to( - device + keys_serialized_tensor = ( + torch.frombuffer(keys_serialized, dtype=torch.uint8).clone().to(device) ) keys_serialized_tensor = sync_tensor_across_gpus( keys_serialized_tensor, dim=0, cat=False @@ -176,3 +190,55 @@ def sync_string_across_gpus(keys: list[str], device, dim=0): for key in pickle.loads(bytes(keys.cpu().tolist())) ] return keys + + +def create_local_process_group() -> None: + num_workers_per_machine = torch.cuda.device_count() + global _LOCAL_PROCESS_GROUP + assert _LOCAL_PROCESS_GROUP is None + assert get_world_size() % num_workers_per_machine == 0 + num_machines = get_world_size() // num_workers_per_machine + machine_rank = get_rank() // num_workers_per_machine + for i in range(num_machines): + ranks_on_i = list( + range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine) + ) + pg = dist.new_group(ranks_on_i) + if i == machine_rank: + _LOCAL_PROCESS_GROUP = pg + + +def _get_global_gloo_group(): + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def all_gather(data, group=None): + if get_world_size() == 1: + return [data] + if group is None: + group = ( + _get_global_gloo_group() + ) # use CPU group by default, to reduce GPU RAM usage. + world_size = dist.get_world_size(group) + if world_size == 1: + return [data] + + output = [None for _ in range(world_size)] + dist.all_gather_object(output, data, group=group) + return output + + +def local_broadcast_process_authkey(): + if get_local_size() == 1: + return + local_rank = get_local_rank() + authkey = bytes(mp.current_process().authkey) + all_keys = all_gather(authkey) + local_leader_key = all_keys[get_rank() - local_rank] + if authkey != local_leader_key: + # print("Process authkey is different from the key of local leader! 
workers are launched independently ??") + # print("Overwriting local authkey ...") + mp.current_process().authkey = local_leader_key diff --git a/unidepth/utils/evaluation_depth.py b/unidepth/utils/evaluation_depth.py index ab35a93..8345760 100644 --- a/unidepth/utils/evaluation_depth.py +++ b/unidepth/utils/evaluation_depth.py @@ -1,48 +1,37 @@ -""" -Author: Luigi Piccinelli -Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) -""" - -# We prefer not to install PyTorch3D in the package -# Code commented is how 3D metrics are computed - from collections import defaultdict from functools import partial import torch import torch.nn.functional as F -from unidepth.utils.constants import DEPTH_BINS - -# from chamfer_distance import ChamferDistance +from unidepth.utils.chamfer_distance import ChamferDistance +chamfer_cls = ChamferDistance() -# chamfer_cls = ChamferDistance() - -# def chamfer_dist(tensor1, tensor2): -# x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) -# y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) -# dist1, dist2, idx1, idx2 = chamfer_cls( -# tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths -# ) -# return (torch.sqrt(dist1) + torch.sqrt(dist2)) / 2 +def chamfer_dist(tensor1, tensor2): + x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) + y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) + dist1, dist2, idx1, idx2 = chamfer_cls( + tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths + ) + return (torch.sqrt(dist1) + torch.sqrt(dist2)) / 2 -# def auc(tensor1, tensor2, thresholds): -# x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) -# y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) -# dist1, dist2, idx1, idx2 = chamfer_cls( -# tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths -# ) -# # compute precision recall -# precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds] -# recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds] -# auc_value = torch.trapz( -# torch.tensor(precisions, device=tensor1.device), -# torch.tensor(recalls, device=tensor1.device), -# ) -# return auc_value +def auc(tensor1, tensor2, thresholds): + x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) + y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) + dist1, dist2, idx1, idx2 = chamfer_cls( + tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths + ) + # compute precision recall + precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds] + recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds] + auc_value = torch.trapz( + torch.tensor(precisions, device=tensor1.device), + torch.tensor(recalls, device=tensor1.device), + ) + return auc_value def delta(tensor1, tensor2, exponent): @@ -50,29 +39,30 @@ def delta(tensor1, tensor2, exponent): return (inlier < 1.25**exponent).to(torch.float32).mean() -def ssi(tensor1, tensor2, qtl=0.05): +def tau(tensor1, tensor2, perc): + inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1)) + return (inlier < (1.0 + perc)).to(torch.float32).mean() + + +def ssi(tensor1, tensor2): stability_mat = 1e-9 * torch.eye(2, device=tensor1.device) - error = (tensor1 - tensor2).abs() - mask = error < torch.quantile(error, 1 - qtl) - tensor1_mask = tensor1[mask] - tensor2_mask = tensor2[mask] tensor2_one = torch.stack( - [tensor2_mask.detach(), 
torch.ones_like(tensor2_mask).detach()], dim=1 + [tensor2.detach(), torch.ones_like(tensor2).detach()], dim=1 ) scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ ( - tensor2_one.T @ tensor1_mask.unsqueeze(1) + tensor2_one.T @ tensor1.unsqueeze(1) ) scale, shift = scale_shift.squeeze().chunk(2, dim=0) return tensor2 * scale + shift - # tensor2_one = torch.stack([tensor2.detach(), torch.ones_like(tensor2).detach()], dim=1) - # scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (tensor2_one.T @ tensor1.unsqueeze(1)) - # scale, shift = scale_shift.squeeze().chunk(2, dim=0) - # return tensor2 * scale + shift -def d1_ssi(tensor1, tensor2): - delta_ = delta(tensor1, ssi(tensor1, tensor2), 1.0) - return delta_ +def si(tensor1, tensor2): + return tensor2 * torch.median(tensor1) / torch.median(tensor2) + + +def arel(tensor1, tensor2): + tensor2 = tensor2 * torch.median(tensor1) / torch.median(tensor2) + return (torch.abs(tensor1 - tensor2) / tensor1).mean() def d_auc(tensor1, tensor2): @@ -81,23 +71,23 @@ def d_auc(tensor1, tensor2): return torch.trapz(torch.tensor(deltas, device=tensor1.device), exponents) / 5.0 -# def f1_score(tensor1, tensor2, thresholds): -# x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) -# y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) -# dist1, dist2, idx1, idx2 = chamfer_cls( -# tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths -# ) -# # compute precision recall -# precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds] -# recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds] -# precisions = torch.tensor(precisions, device=tensor1.device) -# recalls = torch.tensor(recalls, device=tensor1.device) -# f1_thresholds = 2 * precisions * recalls / (precisions + recalls) -# f1_thresholds = torch.where( -# torch.isnan(f1_thresholds), torch.zeros_like(f1_thresholds), f1_thresholds -# ) -# f1_value = torch.trapz(f1_thresholds) / len(thresholds) -# return f1_value +def f1_score(tensor1, tensor2, thresholds): + x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) + y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) + dist1, dist2, idx1, idx2 = chamfer_cls( + tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths + ) + # compute precision recall + precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds] + recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds] + precisions = torch.tensor(precisions, device=tensor1.device) + recalls = torch.tensor(recalls, device=tensor1.device) + f1_thresholds = 2 * precisions * recalls / (precisions + recalls) + f1_thresholds = torch.where( + torch.isnan(f1_thresholds), torch.zeros_like(f1_thresholds), f1_thresholds + ) + f1_value = torch.trapz(f1_thresholds) / len(thresholds) + return f1_value DICT_METRICS = { @@ -115,21 +105,21 @@ def d_auc(tensor1, tensor2): "medianlog": lambda gt, pred: 100 * (torch.log(pred) - torch.log(gt)).median().abs(), "d_auc": d_auc, - "d1_ssi": d1_ssi, + "tau": partial(tau, perc=0.03), } -# DICT_METRICS_3D = { -# "chamfer": lambda gt, pred, thresholds: chamfer_dist( -# gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1) -# ), -# "F1": lambda gt, pred, thresholds: f1_score( -# gt.unsqueeze(0).permute(0, 2, 1), -# pred.unsqueeze(0).permute(0, 2, 1), -# thresholds=thresholds, -# ), -# } - +DICT_METRICS_3D = { + "MSE_3d": lambda gt, pred, thresholds: torch.norm(gt - 
pred, dim=0, p=2), + "chamfer": lambda gt, pred, thresholds: chamfer_dist( + gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1) + ), + "F1": lambda gt, pred, thresholds: f1_score( + gt.unsqueeze(0).permute(0, 2, 1), + pred.unsqueeze(0).permute(0, 2, 1), + thresholds=thresholds, + ), +} DICT_METRICS_D = { "a1": lambda gt, pred: (torch.maximum((gt / pred), (pred / gt)) > 1.25**1.0).to( @@ -146,29 +136,35 @@ def eval_depth( preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear") for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)): if max_depth is not None: - mask = torch.logical_and(mask, gt <= max_depth) + mask = mask & (gt <= max_depth) for name, fn in DICT_METRICS.items(): + if name in ["tau", "d1", "arel"]: + for rescale_fn in ["ssi", "si"]: + summary_metrics[f"{name}_{rescale_fn}"].append( + fn(gt[mask], eval(rescale_fn)(gt[mask], pred[mask])) + ) summary_metrics[name].append(fn(gt[mask], pred[mask]).mean()) return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} -# def eval_3d( -# gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, thresholds=None -# ): -# summary_metrics = defaultdict(list) -# w_max = min(gts.shape[-1] // 4, 400) -# gts = F.interpolate( -# gts, (int(w_max * gts.shape[-2] / gts.shape[-1]), w_max), mode="nearest" -# ) -# preds = F.interpolate(preds, gts.shape[-2:], mode="nearest") -# masks = F.interpolate( -# masks.to(torch.float32), gts.shape[-2:], mode="nearest" -# ).bool() -# for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)): -# if not torch.any(mask): -# continue -# for name, fn in DICT_METRICS_3D.items(): -# summary_metrics[name].append( -# fn(gt[:, mask.squeeze()], pred[:, mask.squeeze()], thresholds).mean() -# ) -# return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} +def eval_3d( + gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, thresholds=None +): + summary_metrics = defaultdict(list) + ratio = min( + 1.0, (240 * 320 / masks.sum()) ** 0.5 + ) # rescale to avoid OOM during eval, FIXME + h_max, w_max = int(gts.shape[-2] * ratio), int(gts.shape[-1] * ratio) + gts = F.interpolate(gts, size=(h_max, w_max), mode="nearest-exact") + preds = F.interpolate(preds, size=(h_max, w_max), mode="nearest-exact") + masks = F.interpolate( + masks.float(), size=(h_max, w_max), mode="nearest-exact" + ).bool() + for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)): + if not torch.any(mask): + continue + for name, fn in DICT_METRICS_3D.items(): + summary_metrics[name].append( + fn(gt[:, mask.squeeze()], pred[:, mask.squeeze()], thresholds).mean() + ) + return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} diff --git a/unidepth/utils/geometric.py b/unidepth/utils/geometric.py index 914b94e..e16fbe3 100644 --- a/unidepth/utils/geometric.py +++ b/unidepth/utils/geometric.py @@ -250,3 +250,53 @@ def flat_interpolate( 0, 2, 1 ) # b (h w) c return flat_tensor_interp.contiguous() + + +@torch.jit.script +def dilate(image, kernel_size: int | tuple[int, int]): + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + device, dtype = image.device, image.dtype + padding = (kernel_size[0] // 2, kernel_size[1] // 2) + kernel = torch.ones((1, 1, *kernel_size), dtype=torch.float32, device=image.device) + dilated_image = F.conv2d(image.float(), kernel, padding=padding, stride=1) + dilated_image = torch.where( + dilated_image > 0, + torch.tensor(1.0, device=device), + torch.tensor(0.0, device=device), + ) + return 
dilated_image.to(dtype) + + +@torch.jit.script +def erode(image, kernel_size: int | tuple[int, int]): + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + device, dtype = image.device, image.dtype + padding = (kernel_size[0] // 2, kernel_size[1] // 2) + kernel = torch.ones((1, 1, *kernel_size), dtype=torch.float32, device=image.device) + eroded_image = F.conv2d(image.float(), kernel, padding=padding, stride=1) + eroded_image = torch.where( + eroded_image == (kernel_size[0] * kernel_size[1]), + torch.tensor(1.0, device=device), + torch.tensor(0.0, device=device), + ) + return eroded_image.to(dtype) + + +@torch.jit.script +def iou(mask1: torch.Tensor, mask2: torch.Tensor) -> torch.Tensor: + device = mask1.device + + # Ensure the masks are binary (0 or 1) + mask1 = mask1.to(torch.bool) + mask2 = mask2.to(torch.bool) + + # Compute intersection and union + intersection = torch.sum(mask1 & mask2).to(torch.float32) + union = torch.sum(mask1 | mask2).to(torch.float32) + + # Compute IoU + iou = intersection / union.clip(min=1.0) + + return iou diff --git a/unidepth/utils/misc.py b/unidepth/utils/misc.py index 7084f07..cb961c1 100644 --- a/unidepth/utils/misc.py +++ b/unidepth/utils/misc.py @@ -1,5 +1,5 @@ -from collections import defaultdict -from functools import partial, wraps +from functools import wraps +from time import time import numpy as np import torch @@ -9,33 +9,39 @@ from scipy import interpolate -def max_stack(tensors): +@torch.jit.script +def max_stack(tensors: list[torch.Tensor]) -> torch.Tensor: if len(tensors) == 1: return tensors[0] return torch.stack(tensors, dim=-1).max(dim=-1).values -def last_stack(tensors): +def last_stack(tensors: list[torch.Tensor]) -> torch.Tensor: return tensors[-1] -def first_stack(tensors): +def first_stack(tensors: list[torch.Tensor]) -> torch.Tensor: return tensors[0] -def softmax_stack(tensors, temperature=1.0): +@torch.jit.script +def softmax_stack( + tensors: list[torch.Tensor], temperature: float = 1.0 +) -> torch.Tensor: if len(tensors) == 1: return tensors[0] return F.softmax(torch.stack(tensors, dim=-1) / temperature, dim=-1).sum(dim=-1) -def mean_stack(tensors): +@torch.jit.script +def mean_stack(tensors: list[torch.Tensor]) -> torch.Tensor: if len(tensors) == 1: return tensors[0] return torch.stack(tensors, dim=-1).mean(dim=-1) -def sum_stack(tensors): +@torch.jit.script +def sum_stack(tensors: list[torch.Tensor]) -> torch.Tensor: if len(tensors) == 1: return tensors[0] return torch.stack(tensors, dim=-1).sum(dim=-1) @@ -83,9 +89,11 @@ def get_params(module, lr, wd): (name in skip_list) or any((kw in name for kw in skip_keywords)) or len(param.shape) == 1 + or name.endswith(".gamma") + or name.endswith(".beta") + or name.endswith(".bias") ): - # if (name in skip_list) or any((kw in name for kw in skip_keywords)): - # print(name, skip_keywords) + # if (name in skip_list) or any((kw in name for kw in skip_keywords)) or len(param.shape) == 1: no_decay.append(param) else: has_decay.append(param) @@ -96,7 +104,7 @@ def get_params(module, lr, wd): "lr": lr, "weight_decay_init": wd, "weight_decay_base": wd, - "lr_init": lr, + # "lr_init": lr, "lr_base": lr, } group2 = { @@ -106,7 +114,7 @@ def get_params(module, lr, wd): "weight_decay_init": 0.0, "weight_decay_base": 0.0, "weight_decay_final": 0.0, - "lr_init": lr, + # "lr_init": lr, "lr_base": lr, } return [group1, group2], [lr, lr] @@ -346,28 +354,28 @@ def geometric_progression(a, r, n): def add_padding_metas(out, image_metas): device = out.device # left, right, top, 
bottom - paddings = [img_meta.get("padding_size", [0] * 4) for img_meta in image_metas] + paddings = [img_meta.get("paddings", [0] * 4) for img_meta in image_metas] paddings = torch.stack(paddings).to(device) outs = [F.pad(o, padding, value=0.0) for padding, o in zip(paddings, out)] return torch.stack(outs) +# left, right, top, bottom def remove_padding(out, paddings): - B, C, H, W = out.shape - device = out.device - # left, right, top, bottom - paddings = torch.stack(paddings).to(device) + H, W = out.shape[-2:] outs = [ - o[:, padding[1] : H - padding[3], padding[0] : W - padding[2]] + o[..., padding[2] : H - padding[3], padding[0] : W - padding[1]] for padding, o in zip(paddings, out) ] return torch.stack(outs) def remove_padding_metas(out, image_metas): + B, C, H, W = out.shape + device = out.device # left, right, top, bottom paddings = [ - torch.tensor(img_meta.get("padding_size", [0] * 4)) for img_meta in image_metas + torch.tensor(img_meta.get("paddings", [0] * 4)) for img_meta in image_metas ] return remove_padding(out, paddings) @@ -409,6 +417,15 @@ def remove_leading_dim(infos): return infos +def recursive_index(infos, index): + if isinstance(infos, dict): + return {k: recursive_index(v, index) for k, v in infos.items()} + elif isinstance(infos, torch.Tensor): + return infos[index] + else: + return infos + + def to_cpu(infos): if isinstance(infos, dict): return {k: to_cpu(v) for k, v in infos.items()} @@ -416,3 +433,191 @@ def to_cpu(infos): return infos.detach() else: return infos + + +def masked_mean( + data: torch.Tensor, + mask: torch.Tensor | None = None, + dim: list[int] | None = None, + keepdim: bool = False, +) -> torch.Tensor: + dim = dim if dim is not None else list(range(data.dim())) + if mask is None: + return data.mean(dim=dim, keepdim=keepdim) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( + mask_sum, min=1.0 + ) + return mask_mean.squeeze(dim) if not keepdim else mask_mean + + +class ProfileMethod: + def __init__(self, model, func_name, track_statistics=True, verbose=False): + self.model = model + self.func_name = func_name + self.verbose = verbose + self.track_statistics = track_statistics + self.timings = [] + + def __enter__(self): + # Start timing + if self.verbose: + if torch.cuda.is_available(): + torch.cuda.synchronize() + self.start_time = time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.verbose: + if torch.cuda.is_available(): + torch.cuda.synchronize() + + self.end_time = time() + + elapsed_time = self.end_time - self.start_time + + self.timings.append(elapsed_time) + if self.track_statistics and len(self.timings) > 25: + + # Compute statistics if tracking + timings_array = np.array(self.timings) + mean_time = np.mean(timings_array) + std_time = np.std(timings_array) + quantiles = np.percentile(timings_array, [0, 25, 50, 75, 100]) + print( + f"{self.model.__class__.__name__}.{self.func_name} took {elapsed_time:.4f} seconds" + ) + print(f"Mean Time: {mean_time:.4f} seconds") + print(f"Std Time: {std_time:.4f} seconds") + print( + f"Quantiles: Min={quantiles[0]:.4f}, 25%={quantiles[1]:.4f}, Median={quantiles[2]:.4f}, 75%={quantiles[3]:.4f}, Max={quantiles[4]:.4f}" + ) + + else: + print( + f"{self.model.__class__.__name__}.{self.func_name} took {elapsed_time:.4f} seconds" + ) + + +def profile_method(track_statistics=True, verbose=False): + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + with 
ProfileMethod(self, func.__name__, track_statistics, verbose): + return func(self, *args, **kwargs) + + return wrapper + + return decorator + + +class ProfileFunction: + def __init__(self, func_name, track_statistics=True, verbose=False): + self.func_name = func_name + self.verbose = verbose + self.track_statistics = track_statistics + self.timings = [] + + def __enter__(self): + # Start timing + if self.verbose: + if torch.cuda.is_available(): + torch.cuda.synchronize() + self.start_time = time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.verbose: + if torch.cuda.is_available(): + torch.cuda.synchronize() + + self.end_time = time() + + elapsed_time = self.end_time - self.start_time + + self.timings.append(elapsed_time) + if self.track_statistics and len(self.timings) > 25: + + # Compute statistics if tracking + timings_array = np.array(self.timings) + mean_time = np.mean(timings_array) + std_time = np.std(timings_array) + quantiles = np.percentile(timings_array, [0, 25, 50, 75, 100]) + print(f"{self.func_name} took {elapsed_time:.4f} seconds") + print(f"Mean Time: {mean_time:.4f} seconds") + print(f"Std Time: {std_time:.4f} seconds") + print( + f"Quantiles: Min={quantiles[0]:.4f}, 25%={quantiles[1]:.4f}, Median={quantiles[2]:.4f}, 75%={quantiles[3]:.4f}, Max={quantiles[4]:.4f}" + ) + + else: + print(f"{self.func_name} took {elapsed_time:.4f} seconds") + + +def profile_function(track_statistics=True, verbose=False): + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + with ProfileFunction(func.__name__, track_statistics, verbose): + return func(self, *args, **kwargs) + + return wrapper + + return decorator + + +def squeeze_list(nested_list, dim, current_dim=0): + # If the current dimension is in the list of indices to squeeze + if isinstance(nested_list, list) and len(nested_list) == 1 and current_dim == dim: + return squeeze_list(nested_list[0], dim, current_dim + 1) + elif isinstance(nested_list, list): + return [squeeze_list(item, dim, current_dim + 1) for item in nested_list] + else: + return nested_list + + +def match_gt(tensor1, tensor2, padding1, padding2, mode: str = "bilinear"): + """ + Transform each item in tensor1 batch to match tensor2's dimensions and padding. + + Args: + tensor1 (torch.Tensor): The input tensor to transform, with shape (batch_size, channels, height, width). + tensor2 (torch.Tensor): The target tensor to match, with shape (batch_size, channels, height, width). + padding1 (tuple): Padding applied to tensor1 (pad_left, pad_right, pad_top, pad_bottom). + padding2 (tuple): Desired padding to be applied to match tensor2 (pad_left, pad_right, pad_top, pad_bottom). + + Returns: + torch.Tensor: The batch of transformed tensors matching tensor2's size and padding. 
+ """ + # Get batch size + batch_size = len(tensor1) + src_dtype = tensor1[0].dtype + tgt_dtype = tensor2[0].dtype + + # List to store transformed tensors + transformed_tensors = [] + + for i in range(batch_size): + item1 = tensor1[i] + item2 = tensor2[i] + + h1, w1 = item1.shape[1], item1.shape[2] + pad1_l, pad1_r, pad1_t, pad1_b = ( + padding1[i] if padding1 is not None else (0, 0, 0, 0) + ) + item1_unpadded = item1[:, pad1_t : h1 - pad1_b, pad1_l : w1 - pad1_r] + + h2, w2 = ( + item2.shape[1] - padding2[i][2] - padding2[i][3], + item2.shape[2] - padding2[i][0] - padding2[i][1], + ) + + item1_resized = F.interpolate( + item1_unpadded.unsqueeze(0).to(tgt_dtype), size=(h2, w2), mode=mode + ) + item1_padded = F.pad(item1_resized, tuple(padding2[i])) + transformed_tensors.append(item1_padded) + + transformed_batch = torch.cat(transformed_tensors) + return transformed_batch.to(src_dtype) diff --git a/unidepth/utils/validation.py b/unidepth/utils/validation.py new file mode 100644 index 0000000..511c44f --- /dev/null +++ b/unidepth/utils/validation.py @@ -0,0 +1,89 @@ +import torch +import torch.utils.data.distributed +import wandb +from torch.nn import functional as F + +from unidepth.utils import barrier, is_main_process +from unidepth.utils.misc import remove_padding + + +def original_image(batch, preds=None): + paddings = [ + torch.tensor(pads) + for img_meta in batch["img_metas"] + for pads in img_meta.get("paddings", [[0] * 4]) + ] + paddings = torch.stack(paddings).to(batch["data"]["image"].device)[ + ..., [0, 2, 1, 3] + ] # lrtb + + T, _, H, W = batch["data"]["depth"].shape + batch["data"]["image"] = F.interpolate( + batch["data"]["image"], + (H + paddings[2] + paddings[3], W + paddings[1] + paddings[2]), + mode="bilinear", + align_corners=False, + antialias=True, + ) + batch["data"]["image"] = remove_padding( + batch["data"]["image"], paddings.repeat(T, 1) + ) + + if preds is not None: + for key in ["depth"]: + if key in preds: + preds[key] = F.interpolate( + preds[key], + (H + paddings[2] + paddings[3], W + paddings[1] + paddings[2]), + mode="bilinear", + align_corners=False, + antialias=True, + ) + preds[key] = remove_padding(preds[key], paddings.repeat(T, 1)) + + return batch, preds + + +def log_metrics(metrics_all, step): + for name_ds, metrics in metrics_all.items(): + for metrics_name, metrics_value in metrics.items(): + try: + print(f"Metrics/{name_ds}/{metrics_name} {round(metrics_value, 4)}") + wandb.log( + {f"Metrics/{name_ds}/{metrics_name}": metrics_value}, step=step + ) + except: + pass + + +def validate(model, test_loaders, step, context): + metrics_all = {} + for name_ds, test_loader in test_loaders.items(): + for i, batch in enumerate(test_loader): + with context: + batch["data"] = { + k: v.to(model.device) for k, v in batch["data"].items() + } + # remove temporal dimension of the dataloder, here is always 1! + batch["data"] = {k: v.squeeze(1) for k, v in batch["data"].items()} + batch["img_metas"] = [ + {k: v[0] for k, v in meta.items() if isinstance(v, list)} + for meta in batch["img_metas"] + ] + + preds, losses = model(batch["data"], batch["img_metas"]) + + batch, _ = original_image(batch, preds=None) + test_loader.dataset.accumulate_metrics( + inputs=batch["data"], + preds=preds, + keyframe_idx=batch["img_metas"][0].get("keyframe_idx"), + ) + + barrier() + metrics_all[name_ds] = test_loader.dataset.get_evaluation() + + barrier() + if is_main_process(): + log_metrics(metrics_all=metrics_all, step=step) + return metrics_all
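Usage sketch for the re-enabled depth metrics in unidepth/utils/evaluation_depth.py (illustration only, not part of the patch). It assumes `eval_depth` keeps the `(gts, preds, masks, max_depth)` parameters referenced in the hunk body, and that the bundled `unidepth.utils.chamfer_distance` extension imports cleanly, since the module now instantiates `ChamferDistance` at import time.

```python
# Illustration only (not part of the patch): exercising the new metrics on synthetic data.
import torch

from unidepth.utils.evaluation_depth import eval_depth, si, ssi, tau

B, H, W = 2, 60, 80
gts = torch.rand(B, 1, H, W) * 10 + 0.5   # synthetic metric depth in (0.5, 10.5)
preds = gts * 1.05 + 0.1                  # slightly biased predictions
masks = gts > 1.0                         # boolean validity mask

# Per-sample metrics clipped at 10 m; keys include "tau" plus the rescaled
# variants ("tau_ssi", "tau_si") added by the new loop in eval_depth.
metrics = eval_depth(gts, preds, masks, max_depth=10.0)
print({k: v.mean().item() for k, v in metrics.items()})

# The helpers themselves operate on flat (masked) tensors.
gt_flat, pred_flat = gts[masks], preds[masks]
print("tau@3%:", tau(gt_flat, pred_flat, perc=0.03).item())
print("tau@3%, scale-shift (ssi) aligned:", tau(gt_flat, ssi(gt_flat, pred_flat), perc=0.03).item())
print("tau@3%, median-scale (si) aligned:", tau(gt_flat, si(gt_flat, pred_flat), perc=0.03).item())
```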
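A similar sketch for the binary-mask utilities appended to unidepth/utils/geometric.py (illustration only, not part of the patch). Inputs are single-channel (B, 1, H, W) {0, 1} masks; this assumes the `@torch.jit.script` union annotations (`int | tuple[int, int]`) compile on the installed torch and Python versions.

```python
# Illustration only (not part of the patch): mask morphology and IoU helpers.
import torch

from unidepth.utils.geometric import dilate, erode, iou

mask = torch.zeros(1, 1, 32, 32)
mask[..., 10:20, 10:20] = 1.0

grown = dilate(mask, 3)        # 3x3 structuring element
shrunk = erode(mask, (3, 3))   # same kernel, passed as an explicit (kh, kw) tuple

print(mask.sum().item(), grown.sum().item(), shrunk.sum().item())  # shrunk <= mask <= grown
print("IoU(mask, grown):", iou(mask, grown).item())
```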
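Finally, a sketch of the new `profile_method` decorator from unidepth/utils/misc.py (illustration only, not part of the patch; `TinyHead` below is a hypothetical stand-in module). Timing is only recorded and printed when `verbose=True`, and the elapsed time is CUDA-synchronized when a GPU is available.

```python
# Illustration only (not part of the patch): timing a model method.
import torch
import torch.nn as nn

from unidepth.utils.misc import profile_method


class TinyHead(nn.Module):  # hypothetical module, just to host a decorated method
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

    @profile_method(verbose=True)
    def forward(self, x):
        return self.proj(x).relu()


head = TinyHead()
_ = head(torch.randn(8, 64))  # prints "TinyHead.forward took ... seconds"
```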