diff --git a/hloc/extract_features.py b/hloc/extract_features.py
index f7fd6990..5594d567 100644
--- a/hloc/extract_features.py
+++ b/hloc/extract_features.py
@@ -125,6 +125,17 @@
             "resize_max": 1024,
         },
     },
+    "xfeat": {
+        "output": "feats-xfeat-n5000-r1600",
+        "model": {
+            "name": "xfeat",
+            "max_keypoints": 5000,
+        },
+        "preprocessing": {
+            "grayscale": False,
+            "resize_max": 1600,
+        },
+    },
     # Global descriptors
     "dir": {
         "output": "global-feats-dir",
diff --git a/hloc/extractors/xfeat.py b/hloc/extractors/xfeat.py
new file mode 100644
index 00000000..5dc230f2
--- /dev/null
+++ b/hloc/extractors/xfeat.py
@@ -0,0 +1,33 @@
+import torch
+
+from hloc import logger
+
+from ..utils.base_model import BaseModel
+
+
+class XFeat(BaseModel):
+    default_conf = {
+        "keypoint_threshold": 0.005,
+        "max_keypoints": -1,
+    }
+    required_inputs = ["image"]
+
+    def _init(self, conf):
+        self.net = torch.hub.load(
+            "verlab/accelerated_features",
+            "XFeat",
+            pretrained=True,
+            top_k=self.conf["max_keypoints"],
+        )
+        logger.info("Load XFeat(sparse) model done.")
+
+    def _forward(self, data):
+        pred = self.net.detectAndCompute(
+            data["image"], top_k=self.conf["max_keypoints"]
+        )[0]
+        pred = {
+            "keypoints": pred["keypoints"][None],
+            "scores": pred["scores"][None],
+            "descriptors": pred["descriptors"].T[None],
+        }
+        return pred
diff --git a/hloc/match_features.py b/hloc/match_features.py
index 679e81e9..ad818a0c 100644
--- a/hloc/match_features.py
+++ b/hloc/match_features.py
@@ -42,6 +42,13 @@
             "features": "aliked",
         },
     },
+    "xfeat+lighterglue": {
+        "output": "matches-xfeat-lighterglue",
+        "model": {
+            "name": "lighterglue",
+            "features": "xfeat",
+        },
+    },
     "superglue": {
         "output": "matches-superglue",
         "model": {
diff --git a/hloc/matchers/lighterglue.py b/hloc/matchers/lighterglue.py
new file mode 100644
index 00000000..c5752d92
--- /dev/null
+++ b/hloc/matchers/lighterglue.py
@@ -0,0 +1,58 @@
+import torch
+from lightglue import LightGlue as LightGlue_
+
+from ..utils.base_model import BaseModel
+
+
+class LighterGlue(BaseModel):
+    default_conf_xfeat = {
+        "name": "lighterglue",  # just for interfacing
+        "input_dim": 64,  # input descriptor dimension (autoselected from weights)
+        "descriptor_dim": 96,
+        "add_scale_ori": False,
+        "add_laf": False,  # for KeyNetAffNetHardNet
+        "scale_coef": 1.0,  # to compensate for the SIFT scale bigger than KeyNet
+        "n_layers": 6,
+        "num_heads": 1,
+        "flash": True,  # enable FlashAttention if available.
+        "mp": False,  # enable mixed precision
+        "depth_confidence": -1,  # early stopping, disable with -1
+        "width_confidence": 0.95,  # point pruning, disable with -1
+        "filter_threshold": 0.1,  # match threshold
+        "weights": None,
+    }
+    required_inputs = [
+        "image0",
+        "keypoints0",
+        "descriptors0",
+        "image1",
+        "keypoints1",
+        "descriptors1",
+    ]
+
+    def _init(self, conf):
+        LightGlue_.default_conf = self.default_conf_xfeat
+        self.net = LightGlue_(None, **conf)
+        url = "https://github.com/verlab/accelerated_features/raw/main/weights/xfeat-lighterglue.pt"  # noqa: E501
+        state_dict = torch.hub.load_state_dict_from_url(url)
+
+        # rename old state dict entries
+        for i in range(self.net.conf.n_layers):
+            pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
+            state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
+            pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
+            state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
+        state_dict = {k.replace("matcher.", ""): v for k, v in state_dict.items()}
+
+        self.net.load_state_dict(state_dict, strict=False)
+
+    def _forward(self, data):
+        data["descriptors0"] = data["descriptors0"].transpose(-1, -2)
+        data["descriptors1"] = data["descriptors1"].transpose(-1, -2)
+
+        return self.net(
+            {
+                "image0": {k[:-1]: v for k, v in data.items() if k[-1] == "0"},
+                "image1": {k[:-1]: v for k, v in data.items() if k[-1] == "1"},
+            }
+        )
diff --git a/hloc/pairs_from_retrieval.py b/hloc/pairs_from_retrieval.py
index 32336801..4ccfbc10 100644
--- a/hloc/pairs_from_retrieval.py
+++ b/hloc/pairs_from_retrieval.py
@@ -81,6 +81,7 @@ def main(
     db_list=None,
     db_model=None,
     db_descriptors=None,
+    match_mask=None,
 ):
     logger.info("Extracting image pairs from a retrieval database.")

@@ -108,8 +109,15 @@
     query_desc = get_descriptors(query_names, descriptors)
     sim = torch.einsum("id,jd->ij", query_desc.to(device), db_desc.to(device))

-    # Avoid self-matching
-    self = np.array(query_names)[:, None] == np.array(db_names)[None]
+    if match_mask is None:
+        # Avoid self-matching
+        self = np.array(query_names)[:, None] == np.array(db_names)[None]
+    else:
+        assert match_mask.shape == (
+            len(query_names),
+            len(db_names),
+        ), "mask shape must match size of query and database images!"
+        self = match_mask
     pairs = pairs_from_score_matrix(sim, self, num_matched, min_score=0)
     pairs = [(query_names[i], db_names[j]) for i, j in pairs]

diff --git a/hloc/pairs_from_sequential.py b/hloc/pairs_from_sequential.py
new file mode 100644
index 00000000..2bb9811d
--- /dev/null
+++ b/hloc/pairs_from_sequential.py
@@ -0,0 +1,150 @@
+import argparse
+import collections.abc as collections
+import os
+from pathlib import Path
+from typing import List, Optional, Union
+
+import numpy as np
+
+from hloc import logger, pairs_from_retrieval
+from hloc.utils.io import list_h5_names
+from hloc.utils.parsers import parse_image_lists, parse_retrieval
+
+
+def main(
+    output: Path,
+    image_list: Optional[Union[Path, List[str]]] = None,
+    features: Optional[Path] = None,
+    window_size: Optional[int] = 10,
+    quadratic_overlap: bool = True,
+    use_loop_closure: bool = False,
+    retrieval_path: Optional[Union[Path, str]] = None,
+    retrieval_interval: Optional[int] = 2,
+    num_loc: Optional[int] = 5,
+) -> None:
+    """
+    Generate pairs of images based on sequential matching and optional loop closure.
+
+    Args:
+        output (Path): The output file path where the pairs will be saved.
+        image_list (Optional[Union[Path, List[str]]]):
+            A path to a file containing a list of images or a list of image names.
+        features (Optional[Path]):
+            A path to a feature file containing image features.
+        window_size (Optional[int]):
+            The size of the window for sequential matching. Default is 10.
+        quadratic_overlap (bool):
+            Whether to use quadratic overlap in sequential matching. Default is True.
+        use_loop_closure (bool):
+            Whether to use loop closure for additional matching. Default is False.
+        retrieval_path (Optional[Union[Path, str]]):
+            The path to the retrieval file for loop closure.
+        retrieval_interval (Optional[int]):
+            The interval for selecting query images for loop closure. Default is 2.
+        num_loc (Optional[int]):
+            The number of top retrieval matches to consider for loop closure.
+            Default is 5.
+
+    Raises:
+        ValueError: If neither image_list nor features is provided,
+            or if image_list is of an unknown type.
+
+    Returns:
+        None
+    """
+    if image_list is not None:
+        if isinstance(image_list, (str, Path)):
+            print(image_list)
+            names_q = parse_image_lists(image_list)
+        elif isinstance(image_list, collections.Iterable):
+            names_q = list(image_list)
+        else:
+            raise ValueError(f"Unknown type for image list: {image_list}")
+    elif features is not None:
+        names_q = list_h5_names(features)
+    else:
+        raise ValueError("Provide either a list of images or a feature file.")
+
+    pairs = []
+    N = len(names_q)
+
+    for i in range(N - 1):
+        for j in range(i + 1, min(i + window_size + 1, N)):
+            pairs.append((names_q[i], names_q[j]))
+
+            if quadratic_overlap:
+                q = 2 ** (j - i)
+                if q > window_size and i + q < N:
+                    pairs.append((names_q[i], names_q[i + q]))
+
+    if use_loop_closure:
+        retrieval_pairs_tmp: Path = output.parent / "retrieval-pairs-tmp.txt"
+
+        # The match mask marks, for each query image, which database images NOT to
+        # include in the retrieval search, i.e. pairs that are already covered by
+        # sequential matching.
+
+        query_list = names_q[::retrieval_interval]
+        M = len(query_list)
+        match_mask = np.zeros((M, N), dtype=bool)
+
+        for i in range(M):
+            for k in range(window_size + 1):
+                if i * retrieval_interval - k >= 0 and i * retrieval_interval - k < N:
+                    match_mask[i][i * retrieval_interval - k] = 1
+                if i * retrieval_interval + k >= 0 and i * retrieval_interval + k < N:
+                    match_mask[i][i * retrieval_interval + k] = 1
+
+                if quadratic_overlap:
+                    if (
+                        i * retrieval_interval - 2**k >= 0
+                        and i * retrieval_interval - 2**k < N
+                    ):
+                        match_mask[i][i * retrieval_interval - 2**k] = 1
+                    if (
+                        i * retrieval_interval + 2**k >= 0
+                        and i * retrieval_interval + 2**k < N
+                    ):
+                        match_mask[i][i * retrieval_interval + 2**k] = 1
+
+        pairs_from_retrieval.main(
+            retrieval_path,
+            retrieval_pairs_tmp,
+            num_matched=num_loc,
+            match_mask=match_mask,
+            db_list=names_q,
+            query_list=query_list,
+        )
+
+        retrieval = parse_retrieval(retrieval_pairs_tmp)
+
+        for key, val in retrieval.items():
+            for match in val:
+                pairs.append((key, match))
+
+        os.unlink(retrieval_pairs_tmp)
+
+    logger.info(f"Found {len(pairs)} pairs.")
+    with open(output, "w") as f:
+        f.write("\n".join(" ".join([i, j]) for i, j in pairs))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="""
+        Create a list of image pairs from the image sequence in alphabetical order
+        """
+    )
+    parser.add_argument("--output", required=True, type=Path)
+    parser.add_argument("--image_list", type=Path)
+    parser.add_argument("--features", type=Path)
+    parser.add_argument(
+        "--window_size", type=int, default=10, help="Size of the matching window"
+    )
+    parser.add_argument(
+        "--quadratic_overlap",
+        action="store_true",
+        help="Whether to match images against their quadratic neighbors.",
+    )
+    args = parser.parse_args()
+    main(**args.__dict__)
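
Usage note (not part of the diff): a minimal sketch of how the added pieces could be chained with the standard hloc pipeline helpers. The config names (`xfeat`, `xfeat+lighterglue`) and the `pairs_from_sequential.main` keyword arguments come from the hunks above; the dataset and output paths are placeholders.

```python
from pathlib import Path

from hloc import extract_features, match_features, pairs_from_sequential

images = Path("datasets/scene/images")  # placeholder image directory
outputs = Path("outputs/scene")  # placeholder output directory
seq_pairs = outputs / "pairs-sequential.txt"

# Local-feature and matcher configs registered by this diff.
feature_conf = extract_features.confs["xfeat"]
matcher_conf = match_features.confs["xfeat+lighterglue"]

# Extract XFeat keypoints/descriptors; returns the path to the feature file.
features = extract_features.main(feature_conf, images, outputs)

# Sequential pairs over the images listed in the feature file. Enabling
# use_loop_closure=True would additionally need a global-descriptor file
# passed as retrieval_path (e.g. from the "netvlad" config).
pairs_from_sequential.main(
    output=seq_pairs,
    features=features,
    window_size=10,
    quadratic_overlap=True,
)

# Match the sequential pairs with LighterGlue.
match_features.main(matcher_conf, seq_pairs, feature_conf["output"], outputs)
```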