Skip to content

Commit

Permalink
Allow overwriting features or matches
Browse files Browse the repository at this point in the history
  • Loading branch information
sarlinpe committed Jan 4, 2022
1 parent aed4d00 commit 81ab784
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
14 changes: 11 additions & 3 deletions hloc/extract_features.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import torch
from pathlib import Path
from typing import Dict, List, Union, Optional
import h5py
from types import SimpleNamespace
import cv2
Expand Down Expand Up @@ -201,8 +202,13 @@ def __len__(self):


@torch.no_grad()
def main(conf, image_dir, export_dir=None, as_half=False,
image_list=None, feature_path=None):
def main(conf: Dict,
image_dir: Path,
export_dir: Optional[Path] = None,
as_half: bool = False,
image_list: Optional[Union[Path, List[str]]] = None,
feature_path: Optional[Path] = None,
overwrite: bool = False) -> Path:
logger.info('Extracting local features with configuration:'
f'\n{pprint.pformat(conf)}')

Expand All @@ -213,7 +219,7 @@ def main(conf, image_dir, export_dir=None, as_half=False,
feature_path = Path(export_dir, conf['output']+'.h5')
feature_path.parent.mkdir(exist_ok=True, parents=True)
skip_names = set(list_h5_names(feature_path)
if feature_path.exists() else ())
if feature_path.exists() and not overwrite else ())
if set(loader.dataset.names).issubset(set(skip_names)):
logger.info('Skipping the extraction.')
return feature_path
Expand Down Expand Up @@ -244,6 +250,8 @@ def main(conf, image_dir, export_dir=None, as_half=False,

with h5py.File(str(feature_path), 'a') as fd:
try:
if name in fd:
del fd[name]
grp = fd.create_group(name)
for k, v in pred.items():
grp.create_dataset(k, data=v)
Expand Down
24 changes: 17 additions & 7 deletions hloc/match_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,12 @@
}


def main(conf: Dict, pairs: Path, features: Union[Path, str],
export_dir: Optional[Path] = None, matches: Optional[Path] = None,
features_ref: Optional[Path] = None):
def main(conf: Dict,
pairs: Path, features: Union[Path, str],
export_dir: Optional[Path] = None,
matches: Optional[Path] = None,
features_ref: Optional[Path] = None,
overwrite: bool = False) -> Path:

if isinstance(features, Path) or Path(features).exists():
features_q = features
Expand All @@ -79,14 +82,18 @@ def main(conf: Dict, pairs: Path, features: Union[Path, str],
else:
features_ref = [features_ref]

match_from_paths(conf, pairs, matches, features_q, features_ref)
match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite)

return matches


@torch.no_grad()
def match_from_paths(conf: Dict, pairs_path: Path, match_path: Path,
feature_path_q: Path, feature_paths_refs: Path):
def match_from_paths(conf: Dict,
pairs_path: Path,
match_path: Path,
feature_path_q: Path,
feature_paths_refs: Path,
overwrite: bool = False) -> Path:
logger.info('Matching local features with configuration:'
f'\n{pprint.pformat(conf)}')

Expand All @@ -107,7 +114,8 @@ def match_from_paths(conf: Dict, pairs_path: Path, match_path: Path,
model = Model(conf['model']).eval().to(device)

match_path.parent.mkdir(exist_ok=True, parents=True)
skip_pairs = set(list_h5_names(match_path) if match_path.exists() else ())
skip_pairs = set(list_h5_names(match_path)
if match_path.exists() and not overwrite else ())

for (name0, name1) in tqdm(pairs, smoothing=.1):
pair = names_to_pair(name0, name1)
Expand All @@ -131,6 +139,8 @@ def match_from_paths(conf: Dict, pairs_path: Path, match_path: Path,

pred = model(data)
with h5py.File(str(match_path), 'a') as fd:
if pair in fd:
del fd[pair]
grp = fd.create_group(pair)
matches = pred['matches0'][0].cpu().short().numpy()
grp.create_dataset('matches0', data=matches)
Expand Down

0 comments on commit 81ab784

Please sign in to comment.