From 22017e82cc7dedb6d3c40f05fd2daf5015b64b3a Mon Sep 17 00:00:00 2001 From: Akshay Krishnan Date: Mon, 19 Jan 2026 12:40:51 -0500 Subject: [PATCH 01/24] store pre ba result --- .../cluster_optimizer_cacher.py | 21 +++++++++- gtsfm/cluster_optimizer/cluster_vggt.py | 41 ++++++++++++++++--- gtsfm/frontend/vggt.py | 25 ++++++----- 3 files changed, 70 insertions(+), 17 deletions(-) diff --git a/gtsfm/cluster_optimizer/cluster_optimizer_cacher.py b/gtsfm/cluster_optimizer/cluster_optimizer_cacher.py index b59e37e00..4edad8d99 100644 --- a/gtsfm/cluster_optimizer/cluster_optimizer_cacher.py +++ b/gtsfm/cluster_optimizer/cluster_optimizer_cacher.py @@ -3,6 +3,7 @@ from __future__ import annotations import hashlib +import os import typing from pathlib import Path from typing import Optional, TYPE_CHECKING @@ -34,11 +35,12 @@ class ClusterOptimizerCacher(ClusterOptimizerBase): """Caches the delayed bundle result produced by a cluster optimizer.""" - def __init__(self, optimizer: ClusterOptimizerBase) -> None: + def __init__(self, optimizer: ClusterOptimizerBase, cache_subdir: Optional[str] = None) -> None: """Initializes the cacher with the actual cluster optimizer object. Args: optimizer: cluster optimizer to use in case of cache miss. + cache_subdir: Optional subdirectory (relative to cache root) for storing cache entries. 
""" super().__init__( pose_angular_error_thresh=optimizer.pose_angular_error_thresh, @@ -46,6 +48,8 @@ def __init__(self, optimizer: ClusterOptimizerBase) -> None: ) self._optimizer = optimizer self._optimizer_hash = hashlib.sha1(repr(optimizer).encode()).hexdigest() + self._cache_subdir = cache_subdir if cache_subdir is not None else os.getenv("GTSFM_CACHE_SUBDIR") + self._cache_root_path = self._resolve_cache_root(self._cache_subdir) def __repr__(self) -> str: return repr(self._optimizer) @@ -64,20 +68,33 @@ def __getstate__(self) -> dict[str, object]: return { "_optimizer": self._optimizer, "_optimizer_hash": self._optimizer_hash, + "_cache_subdir": self._cache_subdir, } def __setstate__(self, state: dict[str, object]) -> None: """Restore state and keep worker routing consistent.""" self._optimizer = typing.cast(ClusterOptimizerBase, state["_optimizer"]) self._optimizer_hash = typing.cast(str, state["_optimizer_hash"]) + self._cache_subdir = typing.cast(Optional[str], state.get("_cache_subdir")) + self._cache_root_path = self._resolve_cache_root(self._cache_subdir) # Re-initialize the base class to mimic the constructor. 
super().__init__( pose_angular_error_thresh=self._optimizer.pose_angular_error_thresh, output_worker=self._optimizer._output_worker, ) + @staticmethod + def _resolve_cache_root(cache_subdir: Optional[str]) -> Path: + """Resolve the cache root path, optionally using a subdirectory or absolute override.""" + if not cache_subdir: + return CACHE_ROOT_PATH + subdir_path = Path(cache_subdir) + if subdir_path.is_absolute(): + return subdir_path + return CACHE_ROOT_PATH / subdir_path + def _get_cache_path(self, cache_key: str) -> Path: - return CACHE_ROOT_PATH / "cluster_optimizer" / f"{cache_key}.pbz2" + return self._cache_root_path / "cluster_optimizer" / f"{cache_key}.pbz2" def _hash_one_view_data(self, one_view_data: Optional["OneViewData"]) -> str: """Compute a stable hash for OneViewData contents.""" diff --git a/gtsfm/cluster_optimizer/cluster_vggt.py b/gtsfm/cluster_optimizer/cluster_vggt.py index 9c934d50b..c74ea9e2e 100644 --- a/gtsfm/cluster_optimizer/cluster_vggt.py +++ b/gtsfm/cluster_optimizer/cluster_vggt.py @@ -14,7 +14,7 @@ from gtsfm.cluster_optimizer.cluster_optimizer_base import ClusterComputationGraph, ClusterContext, ClusterOptimizerBase from gtsfm.common.gtsfm_data import GtsfmData from gtsfm.evaluation.metrics import GtsfmMetric, GtsfmMetricsGroup -from gtsfm.frontend.vggt import VggtConfiguration +from gtsfm.frontend.vggt import VggtConfiguration, VggtReconstruction from gtsfm.products.visibility_graph import visibility_graph_keys from gtsfm.ui.gtsfm_process import UiMetadata from gtsfm.utils.logger import get_logger @@ -65,7 +65,7 @@ def _run_vggt_pipeline( model_cache_key: Hashable | None = None, loader_kwargs: dict[str, Any] | None = None, **kwargs, -) -> GtsfmData: +) -> VggtReconstruction: torch.manual_seed(seed) np.random.seed(seed) if torch.cuda.is_available(): @@ -83,18 +83,29 @@ def _run_vggt_pipeline( cached_model = _resolve_vggt_model(model_cache_key, loader_kwargs) if cached_model is not None: kwargs = {**kwargs, "model": cached_model} 
- return vggt.run_reconstruction_gtsfm_data_only(image_batch, **kwargs) + return vggt.run_reconstruction(image_batch, **kwargs) def _save_reconstruction_as_text( result: GtsfmData, results_path: Path, + *, + subdir: str = "vggt", ) -> None: - target_dir = results_path / "vggt" + target_dir = results_path / subdir target_dir.mkdir(parents=True, exist_ok=True) result.export_as_colmap_text(target_dir) +def _save_pre_ba_reconstruction_as_text( + pre_ba_result: Optional[GtsfmData], + results_path: Path, +) -> None: + if pre_ba_result is None: + return + _save_reconstruction_as_text(pre_ba_result, results_path, subdir="vggt_pre_ba") + + def _aggregate_vggt_metrics(result: GtsfmData) -> GtsfmMetricsGroup: num_cameras = len(result.get_valid_camera_indices()) num_points3d = result.number_tracks() @@ -107,6 +118,16 @@ def _aggregate_vggt_metrics(result: GtsfmData) -> GtsfmMetricsGroup: ) +def _extract_post_ba_result(result: VggtReconstruction) -> GtsfmData: + """Extract the post-BA reconstruction from the VGGT pipeline output.""" + return result.gtsfm_data + + +def _extract_pre_ba_result(result: VggtReconstruction) -> Optional[GtsfmData]: + """Extract the optional pre-BA reconstruction for debugging.""" + return result.pre_ba_data + + class ClusterVGGT(ClusterOptimizerBase): """Cluster optimizer that runs VGGT to generate COLMAP-style reconstructions.""" @@ -137,6 +158,7 @@ def __init__( enable_protection: bool = False, extra_model_kwargs: Optional[dict[str, Any]] = None, run_bundle_adjustment_on_leaf: bool = False, + store_pre_ba_result: bool = False, run_bundle_adjustment_on_parent: bool = True, max_reproj_error: float = 8.0, plot_reprojection_histograms: bool = True, @@ -170,6 +192,7 @@ def __init__( self._use_sparse_attention = use_sparse_attention self._dtype = inference_dtype self._run_bundle_adjustment_on_leaf = run_bundle_adjustment_on_leaf + self._store_pre_ba_result = store_pre_ba_result if fast_dtype is not None: if self._dtype is None: self._dtype = fast_dtype @@ 
-266,6 +289,7 @@ def create_computation_graph( model_ctor_kwargs=self._model_ctor_kwargs.copy(), use_sparse_attention=self._use_sparse_attention, run_bundle_adjustment_on_leaf=self._run_bundle_adjustment_on_leaf, + store_pre_ba_result=self._store_pre_ba_result, max_reproj_error=self._max_reproj_error, ) @@ -273,7 +297,7 @@ def create_computation_graph( context.loader, global_indices, self._image_load_resolution ) - result_graph = delayed(_run_vggt_pipeline)( + reconstruction_graph = delayed(_run_vggt_pipeline)( image_batch_graph, seed=self._seed, original_coords=original_coords_graph, @@ -285,6 +309,7 @@ def create_computation_graph( loader_kwargs=self._loader_kwargs or None, cluster_label=context.label, ) + result_graph = delayed(_extract_post_ba_result)(reconstruction_graph) metrics_tasks = [delayed(_aggregate_vggt_metrics)(result_graph)] @@ -296,6 +321,12 @@ def create_computation_graph( context.output_paths.results, ) ) + io_tasks.append( + delayed(_save_pre_ba_reconstruction_as_text)( + delayed(_extract_pre_ba_result)(reconstruction_graph), + context.output_paths.results, + ) + ) return ClusterComputationGraph( io_tasks=tuple(io_tasks), diff --git a/gtsfm/frontend/vggt.py b/gtsfm/frontend/vggt.py index bdb4bab9c..efb9bc869 100644 --- a/gtsfm/frontend/vggt.py +++ b/gtsfm/frontend/vggt.py @@ -151,6 +151,7 @@ class VggtConfiguration: model_ctor_kwargs: dict[str, Any] = field(default_factory=dict) use_sparse_attention: bool = False run_bundle_adjustment_on_leaf: bool = False + store_pre_ba_result: bool = False # Tracking-specific parameters: tracking: bool = True @@ -181,7 +182,8 @@ class VggtReconstruction: """Outputs from the VGGT reconstruction helper. Attributes: - gtsfm_data: Sparse scene estimate including cameras and tracks in original image coordinates. + gtsfm_data: Sparse scene estimate (post-BA if enabled) in original image coordinates. + pre_ba_data: Optional sparse scene estimate before bundle adjustment (debug-only). 
points_3d: Dense point cloud filtered by VGGT confidence scores. points_rgb: Per-point RGB colors aligned with ``points_3d``. tracking_result: Optional dense tracking payload in the square VGGT coordinate frame. @@ -190,6 +192,7 @@ class VggtReconstruction: gtsfm_data: GtsfmData points_3d: np.ndarray points_rgb: np.ndarray + pre_ba_data: GtsfmData | None = None tracking_result: "VGGTTrackingResult | None" = None def visualize_tracks( @@ -382,7 +385,8 @@ def _high_confidence_pointcloud(config: VggtConfiguration, vggt_output: VggtOutp ) depth_conf_np = vggt_output.depth_confidence.to(torch.float32).cpu().numpy() - conf_mask = depth_conf_np >= config.confidence_threshold + conf_threshold = min(config.confidence_threshold, depth_conf_np.mean() - depth_conf_np.std()) + conf_mask = depth_conf_np >= conf_threshold conf_mask = randomly_limit_trues(conf_mask, config.max_num_points) # limit number of points if asked return points_3d[conf_mask], points_rgb[conf_mask] @@ -400,9 +404,6 @@ def _is_point_in_front_of_camera(camera, point_xyz: np.ndarray, *, epsilon: floa return float(z_val) > epsilon - - - def _convert_vggt_outputs_to_gtsfm_data( *, vggt_output: VggtOutput, @@ -413,7 +414,7 @@ def _convert_vggt_outputs_to_gtsfm_data( points_3d: np.ndarray, points_rgb: np.ndarray, tracking_result: VGGTTrackingResult | None = None, -) -> GtsfmData: +) -> tuple[GtsfmData, GtsfmData | None]: """Convert raw VGGT predictions into ``GtsfmData``.""" extrinsic_np = vggt_output.extrinsic.to(torch.float32).cpu().numpy() @@ -503,21 +504,24 @@ def _convert_vggt_outputs_to_gtsfm_data( track.addMeasurement(global_idx, Point2(rescaled_u, rescaled_v)) gtsfm_data.add_track(track) + gtsfm_data_pre_ba: GtsfmData | None = None if config.run_bundle_adjustment_on_leaf: + if config.store_pre_ba_result: + gtsfm_data_pre_ba = gtsfm_data if gtsfm_data.number_tracks() == 0: logger.warning("Skipping bundle adjustment because VGGT produced no valid tracks.") else: try: gtsfm_data, should_run_ba = 
data_utils.remove_cameras_with_no_tracks(gtsfm_data, "node-level BA") if not should_run_ba: - return gtsfm_data + return gtsfm_data, gtsfm_data_pre_ba optimizer = BundleAdjustmentOptimizer() gtsfm_data_with_ba, _ = optimizer.run_simple_ba(gtsfm_data, verbose=False) - return gtsfm_data_with_ba + return gtsfm_data_with_ba, gtsfm_data_pre_ba except Exception as exc: logger.warning("⚠️ Failed to run bundle adjustment: %s", exc) - return gtsfm_data + return gtsfm_data, gtsfm_data_pre_ba def _offload_vggt_model(model: Optional[VGGT]) -> None: @@ -808,7 +812,7 @@ def run_reconstruction( points_3d, points_rgb = _high_confidence_pointcloud(cfg, vggt_output) - gtsfm_data = _convert_vggt_outputs_to_gtsfm_data( + gtsfm_data, gtsfm_data_pre_ba = _convert_vggt_outputs_to_gtsfm_data( config=cfg, vggt_output=vggt_output, original_coords=original_coords, @@ -825,6 +829,7 @@ def run_reconstruction( return VggtReconstruction( gtsfm_data=gtsfm_data, + pre_ba_data=gtsfm_data_pre_ba, points_3d=points_3d, points_rgb=points_rgb, tracking_result=tracking_result, From bd91d3be759843ca9fbb08a9904f382e9c1fe3c6 Mon Sep 17 00:00:00 2001 From: Akshay Krishnan Date: Tue, 20 Jan 2026 12:30:36 -0500 Subject: [PATCH 02/24] track visualization script --- create_tracks_viz.sh | 17 ++ gtsfm/visualization/visualize_tracks.py | 372 ++++++++++++++++++++++++ 2 files changed, 389 insertions(+) create mode 100644 create_tracks_viz.sh create mode 100644 gtsfm/visualization/visualize_tracks.py diff --git a/create_tracks_viz.sh b/create_tracks_viz.sh new file mode 100644 index 000000000..641a5ef77 --- /dev/null +++ b/create_tracks_viz.sh @@ -0,0 +1,17 @@ +# python gtsfm/visualization/visualize_tracks.py \ +# --result_root /coc/flash5/akrishnan86/gtsfm/outputs/metis_skydio32/results/ \ +# --loader_config colmap \ +# --dataset_dir /coc/flash5/akrishnan86/gtsfm/data/skydio32 + + + +# python gtsfm/visualization/visualize_tracks.py \ +# --result_root 
/coc/flash5/akrishnan86/gtsfm/outputs/metis_cm2_palace_0_4/results/ \ +# --loader_config olsson \ +# --dataset_dir /coc/flash5/akrishnan86/gtsfm/data/palace_fine_arts + +python gtsfm/visualization/visualize_tracks.py \ + --result_root /coc/flash5/akrishnan86/gtsfm/outputs/metis_vggt_palace/results/ \ + --loader_config olsson \ + --dataset_dir /coc/flash5/akrishnan86/gtsfm/data/palace_fine_arts \ + --line_only \ No newline at end of file diff --git a/gtsfm/visualization/visualize_tracks.py b/gtsfm/visualization/visualize_tracks.py new file mode 100644 index 000000000..cbff0686d --- /dev/null +++ b/gtsfm/visualization/visualize_tracks.py @@ -0,0 +1,372 @@ +"""Visualize reprojection errors for tracks stored in COLMAP text outputs. + +This script reconstructs a GtsfmData object from COLMAP text files, builds a +single loader from a Hydra config, and overlays reprojection error vectors on +each image. Each measurement draws a line between the reprojected point and the +measured 2D keypoint, with an optional dot for the track. + +The script searches `--result_root` recursively for folders containing COLMAP +`cameras.txt`, `images.txt`, and `points3D.txt`, then writes visualizations to +`<result_root>/tracks_viz/...` mirroring the COLMAP folder structure.
+""" + +from __future__ import annotations + +import argparse +import colorsys +import os +from pathlib import Path +from typing import Iterable, List, Set, Tuple + +import cv2 +import hydra +import numpy as np +from hydra.utils import instantiate +from PIL import Image as PILImage +from PIL.Image import Image as PILImageType + +import gtsfm.utils.logger as logger_utils +from gtsfm.common.gtsfm_data import GtsfmData +from gtsfm.loader.loader_base import LoaderBase + +logger = logger_utils.get_logger() + + +def _build_loader( + loader_config: str, + dataset_dir: str, + images_dir: str | None, + max_resolution: int | None, +) -> LoaderBase: + """Instantiate a loader using a Hydra config.""" + overrides: List[str] = [f"dataset_dir={dataset_dir}"] + if images_dir is not None: + overrides.append(f"images_dir={images_dir}") + if max_resolution is not None: + overrides.append(f"max_resolution={max_resolution}") + + config_dir = Path(__file__).resolve().parents[1] / "configs" / "loader" + with hydra.initialize_config_dir(config_dir=str(config_dir), version_base=None): + cfg = hydra.compose(config_name=loader_config, overrides=overrides) + return instantiate(cfg) + + +def _collect_reprojection_pairs( + gtsfm_data: GtsfmData, + camera_idx: int, + allowed_track_indices: Set[int], +) -> List[Tuple[int, np.ndarray, np.ndarray]]: + """Collect (track_idx, measured, reprojected) for a given camera index.""" + camera = gtsfm_data.get_camera(camera_idx) + if camera is None: + return [] + + pairs: List[Tuple[int, np.ndarray, np.ndarray]] = [] + measurements = gtsfm_data.get_measurements_for_camera(camera_idx) + for track_idx, measurement_idx in measurements: + if track_idx not in allowed_track_indices: + continue + track = gtsfm_data.get_track(track_idx) + image_idx, uv_measured = track.measurement(measurement_idx) + assert image_idx == camera_idx, "Measurement image index does not match camera index" + uv_reproj, success = camera.projectSafe(track.point3()) + if not success: + 
continue + pairs.append((track_idx, np.array(uv_measured, dtype=float), np.array(uv_reproj, dtype=float))) + return pairs + + +def _track_color(track_idx: int) -> Tuple[int, int, int]: + """Assign a consistent, distinguishable RGB color per track index.""" + hue = (track_idx * 0.61803398875) % 1.0 + r, g, b = colorsys.hsv_to_rgb(hue, 0.7, 0.95) + return int(r * 255), int(g * 255), int(b * 255) + + +def _draw_reprojection_overlay( + image_array: np.ndarray, + pairs: Iterable[Tuple[int, np.ndarray, np.ndarray]], + *, + line_color: Tuple[int, int, int], + dot_radius: int, + line_width: int, + draw_measured: bool, + measured_color: Tuple[int, int, int], + scale_u: float, + scale_v: float, + dot_on_measured: bool, + line_only: bool, +) -> PILImageType: + """Draw reprojection overlays on an image using OpenCV.""" + image_rgb = image_array.astype(np.uint8) + image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) + line_color_bgr = (line_color[2], line_color[1], line_color[0]) + measured_color_bgr = (measured_color[2], measured_color[1], measured_color[0]) + + for track_idx, uv_measured, uv_reproj in pairs: + x_meas = float(uv_measured[0]) * scale_u + y_meas = float(uv_measured[1]) * scale_v + x_rep = float(uv_reproj[0]) * scale_u + y_rep = float(uv_reproj[1]) * scale_v + reproj_color = _track_color(track_idx) + dot_x, dot_y = (x_meas, y_meas) if dot_on_measured else (x_rep, y_rep) + + pt_rep = (int(round(x_rep)), int(round(y_rep))) + pt_meas = (int(round(x_meas)), int(round(y_meas))) + pt_dot = (int(round(dot_x)), int(round(dot_y))) + + cv2.line(image_bgr, pt_rep, pt_meas, line_color_bgr, thickness=line_width, lineType=cv2.LINE_AA) + if not line_only: + reproj_color_bgr = (reproj_color[2], reproj_color[1], reproj_color[0]) + cv2.circle(image_bgr, pt_dot, dot_radius, reproj_color_bgr, thickness=-1, lineType=cv2.LINE_AA) + if draw_measured: + cv2.circle( + image_bgr, + pt_meas, + dot_radius, + measured_color_bgr, + thickness=max(1, line_width), + lineType=cv2.LINE_AA, + ) + 
+ image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + return PILImage.fromarray(image_rgb) + + +def _resolve_output_name(gtsfm_data: GtsfmData, loader: LoaderBase, image_idx: int) -> str: + """Resolve output filename based on COLMAP image names, with loader fallback.""" + info = gtsfm_data.get_image_info(image_idx) + if info.name: + return Path(info.name).name + filenames = loader.image_filenames() + if 0 <= image_idx < len(filenames): + name = Path(filenames[image_idx]).name + if name: + return name + return f"image_{image_idx:06d}.png" + + +def _build_loader_name_maps(loader: LoaderBase) -> tuple[dict[str, int], dict[str, list[int]]]: + """Build lookup maps from loader filenames to loader indices.""" + filenames = loader.image_filenames() + full_map: dict[str, int] = {} + base_map: dict[str, list[int]] = {} + for idx, name in enumerate(filenames): + full_map[name] = idx + base = Path(name).name + base_map.setdefault(base, []).append(idx) + return full_map, base_map + + +def _resolve_loader_index( + gtsfm_data: GtsfmData, image_idx: int, full_map: dict[str, int], base_map: dict[str, list[int]] +) -> int | None: + """Resolve loader index for a COLMAP image index based on filename.""" + info = gtsfm_data.get_image_info(image_idx) + if info.name: + if info.name in full_map: + return full_map[info.name] + base = Path(info.name).name + if base in base_map: + if len(base_map[base]) > 1: + logger.warning("Multiple loader matches for %s; using first.", base) + return base_map[base][0] + return None + + +def _has_colmap_text_files(directory: str) -> bool: + """Check whether a directory contains COLMAP text outputs.""" + required = {"cameras.txt", "images.txt", "points3D.txt"} + try: + entries = set(os.listdir(directory)) + except FileNotFoundError: + return False + return required.issubset(entries) + + +def _find_colmap_dirs(root_dir: str) -> List[str]: + """Recursively find all subdirectories containing COLMAP text files.""" + matches: List[str] = [] + for dirpath, 
_, _ in os.walk(root_dir): + if _has_colmap_text_files(dirpath): + matches.append(dirpath) + return matches + + +def _visualize_tracks_for_dir(args: argparse.Namespace, colmap_dir: str, output_dir: str, loader: LoaderBase) -> None: + """Visualize reprojection errors for one COLMAP directory.""" + logger.info("Loading reconstruction from %s", colmap_dir) + try: + gtsfm_data = GtsfmData.read_colmap(colmap_dir) + except Exception as exc: + logger.exception("Skipping %s due to error: %s", colmap_dir, exc) + return + + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + full_map, base_map = _build_loader_name_maps(loader) + + camera_indices = sorted(gtsfm_data.get_valid_camera_indices()) + if args.max_images is not None: + camera_indices = camera_indices[: args.max_images] + + num_tracks = gtsfm_data.number_tracks() + if args.max_pairs is not None and num_tracks > args.max_pairs: + rng = np.random.default_rng(args.random_seed) + sampled = rng.choice(num_tracks, size=args.max_pairs, replace=False) + allowed_track_indices = set(int(idx) for idx in sampled) + else: + allowed_track_indices = set(range(num_tracks)) + + for camera_idx in camera_indices: + loader_idx = _resolve_loader_index(gtsfm_data, camera_idx, full_map, base_map) + if loader_idx is None: + logger.warning("Skipping camera %d with no loader match", camera_idx) + continue + + pairs = _collect_reprojection_pairs(gtsfm_data, camera_idx, allowed_track_indices) + if not pairs: + logger.info("No valid measurements for image %d", camera_idx) + continue + + image = loader.get_image(loader_idx) + resized_h, resized_w = image.height, image.width + info = gtsfm_data.get_image_info(camera_idx) + if info.shape is not None: + orig_h, orig_w = info.shape + else: + orig_h, orig_w = resized_h, resized_w + scale_u = resized_w / orig_w if orig_w > 0 else 1.0 + scale_v = resized_h / orig_h if orig_h > 0 else 1.0 + + overlay = _draw_reprojection_overlay( + image.value_array, + pairs, + 
line_color=tuple(args.line_color), + dot_radius=args.dot_radius, + line_width=args.line_width, + draw_measured=args.draw_measured, + measured_color=tuple(args.measured_color), + scale_u=scale_u, + scale_v=scale_v, + dot_on_measured=args.dot_on_measured, + line_only=args.line_only, + ) + + output_name = _resolve_output_name(gtsfm_data, loader, camera_idx) + output_file = output_path / output_name + overlay.save(output_file) + logger.info("Saved %s", output_file) + + +def visualize_tracks(args: argparse.Namespace) -> None: + """Visualize reprojection errors across all COLMAP directories under result_root.""" + colmap_dirs = _find_colmap_dirs(args.result_root) + if not colmap_dirs: + logger.warning("No COLMAP text directories found under %s", args.result_root) + return + + logger.info("Instantiating loader config=%s", args.loader_config) + loader = _build_loader( + loader_config=args.loader_config, + dataset_dir=args.dataset_dir, + images_dir=args.images_dir, + max_resolution=args.max_resolution, + ) + + viz_root = Path(args.result_root) / "tracks_viz" + for colmap_dir in colmap_dirs: + rel_path = Path(colmap_dir).relative_to(args.result_root) + output_dir = viz_root / rel_path + _visualize_tracks_for_dir(args, colmap_dir, str(output_dir), loader) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Overlay reprojection error vectors on images for COLMAP reconstructions." 
+ ) + parser.add_argument( + "--result_root", + type=str, + required=True, + help="Root directory to recursively search for COLMAP text outputs.", + ) + parser.add_argument( + "--loader_config", + type=str, + default="colmap", + help="Loader config name from gtsfm/configs/loader (e.g., colmap, tanks_and_temples).", + ) + parser.add_argument( + "--dataset_dir", + type=str, + required=True, + help="Dataset root for the loader (passed as loader.dataset_dir).", + ) + parser.add_argument( + "--images_dir", + type=str, + default=None, + help="Optional images directory (passed as loader.images_dir).", + ) + parser.add_argument( + "--max_resolution", + type=int, + default=None, + help="Optional max resolution override for loader.", + ) + parser.add_argument( + "--max_images", + type=int, + default=None, + help="Limit the number of images to visualize.", + ) + parser.add_argument( + "--max_pairs", + type=int, + default=None, + help="Limit the number of tracks drawn across all images (randomly sampled).", + ) + parser.add_argument("--dot_radius", type=int, default=2, help="Radius for reprojection dot.") + parser.add_argument("--line_width", type=int, default=1, help="Line width for reprojection error.") + parser.add_argument( + "--line_color", + type=int, + nargs=3, + default=(255, 0, 0), + help="RGB color for reprojection error lines.", + ) + parser.add_argument( + "--random_seed", + type=int, + default=0, + help="Random seed for sampling tracks when max_pairs is set.", + ) + parser.add_argument( + "--draw_measured", + action="store_true", + help="Draw an outline circle at the measured 2D point.", + ) + parser.add_argument( + "--dot_on_measured", + action="store_true", + help="Draw the colored dot on the measured point instead of the reprojection.", + ) + parser.add_argument( + "--line_only", + action="store_true", + help="Draw only the line; use a small line-colored dot at the line head.", + ) + parser.add_argument( + "--measured_color", + type=int, + nargs=3, + 
default=(0, 255, 0), + help="RGB color for measured point outlines.", + ) + return parser.parse_args() + + +if __name__ == "__main__": + visualize_tracks(_parse_args()) From ac5c784216f848e059ccbc273f06cffc4e3bd866 Mon Sep 17 00:00:00 2001 From: Akshay Krishnan Date: Wed, 21 Jan 2026 16:10:48 -0500 Subject: [PATCH 03/24] trajectory aligner sim3 use --- gtsfm/cluster_merging.py | 72 ++++++++++++++++++++++++++++++--- gtsfm/configs/vggt_megaloc.yaml | 52 ++++++++++++++++++++++++ gtsfm/scene_optimizer.py | 12 +++--- 3 files changed, 125 insertions(+), 11 deletions(-) create mode 100644 gtsfm/configs/vggt_megaloc.yaml diff --git a/gtsfm/cluster_merging.py b/gtsfm/cluster_merging.py index 3a30fa1a7..9cc56c33c 100644 --- a/gtsfm/cluster_merging.py +++ b/gtsfm/cluster_merging.py @@ -7,9 +7,10 @@ from pathlib import Path from typing import TYPE_CHECKING, Optional, Tuple +import gtsam import numpy as np from dask.distributed import Client, Future -from gtsam import Similarity3, Pose3 +from gtsam import Similarity3, Pose3, UnaryMeasurementPose3, TrajectoryAlignerSim3 import gtsfm.utils.logger as logger_utils import gtsfm.common.types as gtsfm_types @@ -35,6 +36,58 @@ _SCENE_LABEL_ATTR = "_gtsfm_cluster_label" +def _create_unary_measurements(scene: GtsfmData) -> list[UnaryMeasurementPose3]: + # TODO(akshay-krishnan): investigate using a scene-dependent noise model + # perhaps * np.exp(-len(scene.get_valid_camera_indices()) / 100.0) + noise_model = gtsam.noiseModel.Diagonal.Sigmas( + np.array([1e-2, 1e-2, 1e-2, 1e-1, 1e-1, 1e-1]) + ) + unary_measurements = [] + for i, camera in scene.get_camera_poses().items(): + if camera is None: + continue + unary_measurement = UnaryMeasurementPose3(i, camera, noise_model) + unary_measurements.append(unary_measurement) + return unary_measurements + + +def merge_scenes_with_sim3_nonlinear(parent_scene: GtsfmData, children_scenes: list[GtsfmData]) -> GtsfmData: + if len(children_scenes) == 0: + return parent_scene + + aTi_measurements = 
_create_unary_measurements(parent_scene) + parent_camera_ids = set(parent_scene.get_valid_camera_indices()) + valid_child_scenes = [] + + for i, child_scene in enumerate(children_scenes): + child_camera_ids = set(child_scene.get_valid_camera_indices()) + common_camera_ids = parent_camera_ids & child_camera_ids + if len(common_camera_ids) == 0: + logger.warning("Child scene %d has insufficient overlap with parent, skipping", i) + continue + valid_child_scenes.append(child_scene) + + if len(valid_child_scenes) == 0: + return parent_scene + + aTi_measurements = _create_unary_measurements(parent_scene) + bTi_measurements = [_create_unary_measurements(child_scene) for child_scene in valid_child_scenes] + aligner = TrajectoryAlignerSim3(aTi_measurements, bTi_measurements) + result = aligner.solve() + + opt_aTi = {i: result.atPose3(i) for i in parent_scene.get_valid_camera_indices() if i in result.keys()} + + merged = parent_scene + for i, aTi in opt_aTi.items(): + merged.update_camera_pose(i, aTi) + + for i, child_scene in enumerate(valid_child_scenes): + opt_bSa = result.atSimilarity3(gtsam.Symbol("S", i).key()) + opt_aSb = opt_bSa.inverse() + merged = merged.merged_with(child_scene, opt_aSb) # type: ignore + return merged + + @dataclass(frozen=True) class MergedNodeResult: """Results of merging child scenes with parent scenes in the reconstruction tree. @@ -406,6 +459,7 @@ def combine_results( drop_camera_with_no_track: bool = True, drop_child_if_merging_fail: bool = True, store_full_data: bool = False, + use_nonlinear_sim3_alignment: bool = False, ) -> MergedNodeResult: """Run the merging and parent BA pipeline using already-transformed children. @@ -469,10 +523,18 @@ def _finalize_result(result_scene: Optional[GtsfmData]) -> MergedNodeResult: merged = current _log_scene_reprojection_stats(merged, "Current node", plot_histograms=plot_reprojection_histograms) - # Merge all children into the merged scene. 
- for i, child in enumerate(valid_child_scenes): - merged = _align_and_merge_results(merged, child, drop_if_merging_fails=drop_child_if_merging_fail) - _log_scene_reprojection_stats(merged, f"Merged with child #{i+1}", plot_histograms=plot_reprojection_histograms) + if use_nonlinear_sim3_alignment: + merged = merge_scenes_with_sim3_nonlinear(merged, valid_child_scenes) + _log_scene_reprojection_stats( + merged, "Merged with children (nonlinear alignment)", plot_histograms=plot_reprojection_histograms + ) + else: + # Merge all children into the merged scene. + for i, child in enumerate(valid_child_scenes): + merged = _align_and_merge_results(merged, child, drop_if_merging_fails=drop_child_if_merging_fail) + _log_scene_reprojection_stats( + merged, f"Merged with child #{i+1}", plot_histograms=plot_reprojection_histograms + ) _propagate_scene_metadata(merged, metadata_source) diff --git a/gtsfm/configs/vggt_megaloc.yaml b/gtsfm/configs/vggt_megaloc.yaml new file mode 100644 index 000000000..a7f80269e --- /dev/null +++ b/gtsfm/configs/vggt_megaloc.yaml @@ -0,0 +1,52 @@ +# VGGT cluster-only configuration. + +# @package _global_ +_target_: gtsfm.scene_optimizer.SceneOptimizer + +loader: + _target_: gtsfm.loader.Olsson + dataset_dir: ??? # Required: set to the dataset root on the command line. 
+ images_dir: null + max_resolution: 760 + +image_pairs_generator: + _target_: gtsfm.retriever.image_pairs_generator.ImagePairsGenerator + global_descriptor: + _target_: gtsfm.frontend.cacher.global_descriptor_cacher.GlobalDescriptorCacher + global_descriptor_obj: + _target_: gtsfm.frontend.global_descriptor.MegaLoc + retriever: + _target_: gtsfm.retriever.Similarity + num_matched: 15 + min_score: 0.4 + batch_size: 16 + +graph_partitioner: + _target_: gtsfm.graph_partitioner.Metis + +cluster_optimizer: + _target_: gtsfm.cluster_optimizer.Cacher + optimizer: + _target_: gtsfm.cluster_optimizer.cluster_vggt.ClusterVGGT + weights_path: null + image_load_resolution: 1024 + inference_resolution: 518 + conf_threshold: 5.0 + max_num_points: 100000 + tracking: true + tracking_max_query_pts: 512 + tracking_query_frame_num: 3 + tracking_fine_tracking: false + track_vis_thresh: 0.2 + camera_type: PINHOLE + drop_outlier_after_camera_merging: false + drop_child_if_merging_fail: true + drop_camera_with_no_track: true + seed: 42 + plot_reprojection_histograms: true + run_bundle_adjustment_on_leaf: false + run_bundle_adjustment_on_parent: true + model_cache_key: null + # store_pre_ba_result: true + +use_nonlinear_sim3_alignment: false \ No newline at end of file diff --git a/gtsfm/scene_optimizer.py b/gtsfm/scene_optimizer.py index 917f3c089..4d7356b05 100644 --- a/gtsfm/scene_optimizer.py +++ b/gtsfm/scene_optimizer.py @@ -108,14 +108,13 @@ def __init__( output_root: str = DEFAULT_OUTPUT_ROOT, output_worker: Optional[str] = None, plot_reprojection_histograms: bool = True, + use_nonlinear_sim3_merging: bool = False, ) -> None: self.loader = loader self.image_pairs_generator = image_pairs_generator self.graph_partitioner = graph_partitioner self.cluster_optimizer = cluster_optimizer - self._run_bundle_adjustment_on_parent = getattr( - self.cluster_optimizer, "run_bundle_adjustment_on_parent", True - ) + self._run_bundle_adjustment_on_parent = getattr(self.cluster_optimizer, 
"run_bundle_adjustment_on_parent", True) self._plot_reprojection_histograms = getattr( self.cluster_optimizer, "plot_reprojection_histograms", plot_reprojection_histograms ) @@ -124,7 +123,7 @@ def __init__( ) self._drop_camera_with_no_track = getattr(self.cluster_optimizer, "drop_camera_with_no_track", True) self._drop_child_if_merging_fail = getattr(self.cluster_optimizer, "drop_child_if_merging_fail", True) - + self._use_nonlinear_sim3_merging = use_nonlinear_sim3_merging self.output_root = Path(output_root) if output_worker is not None: self.cluster_optimizer._output_worker = output_worker @@ -231,7 +230,7 @@ def to_context(path: tuple[int, ...], visibility_graph: VisibilityGraph) -> Clus # Returns handles to various outputs: reconstruction, metrics, io_barrier etc. handles_tree = context_tree.map(self._schedule_single_cluster) - # Get the reconstruction handle and run merging to get a tree of merged result handles. + # Get the reconstruction handle and run merging to get a tree of merged result handles. 
reconstruction_tree = handles_tree.map(lambda handle: handle.reconstruction) cameras_gt = self.loader.get_gt_cameras() @@ -248,6 +247,7 @@ def merge_fn( drop_camera_with_no_track=self._drop_camera_with_no_track, drop_child_if_merging_fail=self._drop_child_if_merging_fail, store_full_data=False, + use_nonlinear_sim3_alignment=self._use_nonlinear_sim3_merging, ) merged_future_tree = submit_tree_map_with_children(client, reconstruction_tree, merge_fn) @@ -270,7 +270,7 @@ def merge_fn( base_metrics_groups.extend(metrics_groups) base_metrics_groups.append(merged_result.metrics) root_merge_future = merge_future - elif metrics_groups: + else: merged_result = merge_future.result() metrics_groups.append(merged_result.metrics) save_metrics_reports(metrics_groups, str(handle.output_paths.metrics)) From 6aa34930367e6f9687f3738ee5137bebcdc5d822 Mon Sep 17 00:00:00 2001 From: Harneet Singh Khanuja Date: Sun, 25 Jan 2026 19:44:13 -0500 Subject: [PATCH 04/24] Replacing VGGSfM tracker with VGGT tracker module --- gtsfm/cluster_optimizer/cluster_vggt.py | 33 +- gtsfm/configs/vggt.yaml | 12 +- gtsfm/frontend/vggt.py | 418 ++++++++++++------------ gtsfm/loader/loader_base.py | 134 ++++++++ 4 files changed, 356 insertions(+), 241 deletions(-) diff --git a/gtsfm/cluster_optimizer/cluster_vggt.py b/gtsfm/cluster_optimizer/cluster_vggt.py index 9c934d50b..fc8e531f7 100644 --- a/gtsfm/cluster_optimizer/cluster_vggt.py +++ b/gtsfm/cluster_optimizer/cluster_vggt.py @@ -32,13 +32,9 @@ def _resize_to_square_tensor(image: np.ndarray, target_size: int) -> torch.Tenso return (tensor.squeeze(0)) / 255.0 -def _load_vggt_inputs(loader, indices: list[int], target_size: int): +def _load_vggt_inputs(loader, indices: list[int], mode: str): """Load and preprocess a batch of images for VGGT.""" - - def resize_transform(arr: np.ndarray) -> torch.Tensor: - return _resize_to_square_tensor(arr, target_size) - - return loader.load_image_batch_vggt(indices, target_size, resize_transform) + return 
loader.load_image_batch_vggt_loader(indices, mode=mode) def _resolve_vggt_model(cache_key: Hashable | None, loader_kwargs: dict[str, Any] | None) -> Any | None: @@ -113,15 +109,14 @@ class ClusterVGGT(ClusterOptimizerBase): def __init__( self, weights_path: Optional[str] = None, - image_load_resolution: int = 1024, - inference_resolution: int = 518, conf_threshold: float = 5.0, max_num_points: int = 100000, tracking: bool = False, - tracking_max_query_pts: int = 1000, - tracking_query_frame_num: int = 4, - tracking_fine_tracking: bool = True, - track_vis_thresh: float = 0.2, + tracking_max_query_pts: int = 2048, + tracking_query_frame_num: int = 3, + track_vis_thresh: float = 0.05, + track_conf_thresh: float = 0.2, + keypoint_extractor: str = "aliked+sp+sift", camera_type: str = "PINHOLE", seed: int = 42, scene_dir: Optional[str] = None, @@ -154,15 +149,14 @@ def __init__( run_bundle_adjustment_on_parent=run_bundle_adjustment_on_parent, ) self._weights_path = Path(weights_path) if weights_path is not None else None - self._image_load_resolution = image_load_resolution - self._inference_resolution = inference_resolution self._conf_threshold = conf_threshold self._max_points_for_colmap = max_num_points self._tracking = tracking self._tracking_max_query_pts = tracking_max_query_pts self._tracking_query_frame_num = tracking_query_frame_num - self._tracking_fine_tracking = tracking_fine_tracking self._track_vis_thresh = track_vis_thresh + self._track_conf_thresh = track_conf_thresh + self._keypoint_extractor = keypoint_extractor self._camera_type = camera_type self._max_reproj_error = max_reproj_error self._seed = seed @@ -216,8 +210,6 @@ def _maybe_set_model_kw(key: str, value: Any) -> None: def __repr__(self) -> str: components = [ f"weights_path={self._weights_path}", - f"image_load_resolution={self._image_load_resolution}", - f"inference_resolution={self._inference_resolution}", f"camera_type={self._camera_type}", f"dtype={self._dtype}", 
f"use_sparse_attention={self._use_sparse_attention}", @@ -253,15 +245,14 @@ def create_computation_graph( image_names = tuple(str(image_filenames[idx]) for idx in keys) config = VggtConfiguration( - vggt_fixed_resolution=self._inference_resolution, - img_load_resolution=self._image_load_resolution, confidence_threshold=self._conf_threshold, max_num_points=self._max_points_for_colmap, tracking=self._tracking, max_query_pts=self._tracking_max_query_pts, query_frame_num=self._tracking_query_frame_num, - fine_tracking=self._tracking_fine_tracking, track_vis_thresh=self._track_vis_thresh, + track_conf_thresh=self._track_conf_thresh, + keypoint_extractor=self._keypoint_extractor, dtype=self._dtype, model_ctor_kwargs=self._model_ctor_kwargs.copy(), use_sparse_attention=self._use_sparse_attention, @@ -270,7 +261,7 @@ def create_computation_graph( ) image_batch_graph, original_coords_graph = delayed(_load_vggt_inputs, nout=2)( - context.loader, global_indices, self._image_load_resolution + context.loader, global_indices, mode="crop" # mode is fixed to "crop" ) result_graph = delayed(_run_vggt_pipeline)( diff --git a/gtsfm/configs/vggt.yaml b/gtsfm/configs/vggt.yaml index 979cca861..01781aa8a 100644 --- a/gtsfm/configs/vggt.yaml +++ b/gtsfm/configs/vggt.yaml @@ -7,7 +7,7 @@ loader: _target_: gtsfm.loader.Olsson dataset_dir: ??? # Required: set to the dataset root on the command line. images_dir: null - max_resolution: 760 + max_resolution: 518 # VGGT recommended max resolution. Non editable. 
mode is fixed to "crop" image_pairs_generator: _target_: gtsfm.retriever.image_pairs_generator.ImagePairsGenerator @@ -29,15 +29,15 @@ cluster_optimizer: optimizer: _target_: gtsfm.cluster_optimizer.cluster_vggt.ClusterVGGT weights_path: null - image_load_resolution: 1024 - inference_resolution: 518 conf_threshold: 5.0 max_num_points: 100000 tracking: true - tracking_max_query_pts: 512 + tracking_max_query_pts: 2048 tracking_query_frame_num: 3 - tracking_fine_tracking: false - track_vis_thresh: 0.2 + keypoint_extractor: aliked+sp+sift + track_vis_thresh: 0.05 + track_conf_thresh: 0.2 + max_reproj_error: 0 # 0.0 means no filtering based on reproj error camera_type: PINHOLE drop_outlier_after_camera_merging: false drop_child_if_merging_fail: true diff --git a/gtsfm/frontend/vggt.py b/gtsfm/frontend/vggt.py index bdb4bab9c..1124e3ef2 100644 --- a/gtsfm/frontend/vggt.py +++ b/gtsfm/frontend/vggt.py @@ -13,7 +13,6 @@ import numpy as np import torch -import torch.nn.functional as F from gtsam import Point2, Point3 from torch.amp import autocast as amp_autocast # type: ignore @@ -97,14 +96,11 @@ def _import_from_vanilla_vggt(module_suffix: str) -> ModuleType: logger.info("⚡ FastVGGT enabled via thirdparty/FastVGGT.") else: logger.info("📷 Using vanilla VGGT (FastVGGT submodule not detected).") +from vggt.utils.geometry import unproject_depth_map_to_point_map # type: ignore from vggt.utils.helper import randomly_limit_trues # type: ignore from vggt.utils.load_fn import load_and_preprocess_images_square # type: ignore from vggt.utils.pose_enc import pose_encoding_to_extri_intri # type: ignore -from gtsfm.frontend.anysplat import ( - batchify_unproject_depth_map_to_point_map as _anysplat_batchify_unproject, -) # type: ignore - DEFAULT_FIXED_RESOLUTION = 518 _DTYPE_ALIASES: dict[str, torch.dtype] = { @@ -142,8 +138,6 @@ def _resolve_dtype_argument(arg: Optional[Union[str, torch.dtype]]) -> Optional[ class VggtConfiguration: """Configuration for the high-level VGGT 
reconstruction pipeline.""" - img_load_resolution: int = 1024 - vggt_fixed_resolution: int = DEFAULT_FIXED_RESOLUTION seed: int = 42 confidence_threshold: float = 5.0 max_num_points: int = 100000 @@ -154,11 +148,11 @@ class VggtConfiguration: # Tracking-specific parameters: tracking: bool = True - max_query_pts: int = 1000 - query_frame_num: int = 4 - keypoint_extractor: str = "aliked+sp" - fine_tracking: bool = True - track_vis_thresh: float = 0.2 + max_query_pts: int = 2048 + query_frame_num: int = 3 + track_vis_thresh: float = 0.05 + track_conf_thresh: float = 0.2 + keypoint_extractor: str = "aliked+sp+sift" max_reproj_error: float = 8.0 @@ -168,7 +162,7 @@ class VggtOutput: # TODO(Frank): derive from base class shared with AnySplat (i device: torch.device dtype: torch.dtype - resized_images: torch.Tensor + images: torch.Tensor extrinsic: torch.Tensor intrinsic: torch.Tensor depth_map: torch.Tensor @@ -310,76 +304,10 @@ def load_model( return model -def _rescale_intrinsic_for_original_resolution( - intrinsic: np.ndarray, - reconstruction_resolution: int, - image_width: float, - image_height: float, -) -> np.ndarray: - """Adapt intrinsics estimated on a square crop back to the original image size.""" - resized = intrinsic.copy() - # print('image_width, image_height: ', image_width, image_height) - resize_ratio = max(image_width, image_height) / float(reconstruction_resolution) - resized[:2, :] *= resize_ratio - resized[0, 2] = image_width / 2.0 - resized[1, 2] = image_height / 2.0 - return resized - - -def _convert_measurement_to_original_resolution( - uv: Tuple[float, float], - original_coord: np.ndarray, - inference_resolution: int, - img_load_resolution: int, - *, - measurement_in_load_resolution: bool = False, -) -> Tuple[float, float]: - """Convert VGGT coordinates back to the original image coordinate system. - - Args: - uv: Input measurement in either inference or load resolution space. 
- original_coord: Metadata describing the crop location within the padded square, expressed at load resolution. - inference_resolution: Resolution of VGGT inference grid. - img_load_resolution: Resolution used when images were padded/resized prior to inference. - measurement_in_load_resolution: Set ``True`` if ``uv`` already lives in the load resolution. - """ - - x_infer, y_infer = uv - x1, y1 = original_coord[0], original_coord[1] - width, height = original_coord[4], original_coord[5] - - # VGGT runs on the ``img_load_resolution`` square; vggt_output down-samples that square to the - # (typically smaller) ``inference_resolution``. Undo that downscale so we can use the crop - # metadata stored in ``original_coord``. - if measurement_in_load_resolution: - x_load = x_infer - y_load = y_infer - else: - scale_back_to_load = float(img_load_resolution) / float(inference_resolution) - x_load = x_infer * scale_back_to_load - y_load = y_infer * scale_back_to_load - - # ``original_coord`` encodes the location of the original, possibly rectangular, image within - # the padded square (in *load* resolution). Remove the padding and scale the remaining pixels - # back to the native resolution. 
- max_side = float(max(width, height)) - resize_ratio = max_side / float(img_load_resolution) - u = (x_load - x1) * resize_ratio - v = (y_load - y1) * resize_ratio - - max_u = max(width - 0.5, 0.0) - max_v = max(height - 0.5, 0.0) - u = float(np.clip(u, 0.0, max_u)) - v = float(np.clip(v, 0.0, max_v)) - return u, v - - def _high_confidence_pointcloud(config: VggtConfiguration, vggt_output: VggtOutput) -> Tuple[np.ndarray, np.ndarray]: """Convert raw VGGT predictions into point attributes.""" points_3d = vggt_output.dense_points.to(torch.float32).cpu().numpy() - points_rgb = (vggt_output.resized_images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1) * 255).astype( - np.uint8 - ) + points_rgb = (vggt_output.images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8) depth_conf_np = vggt_output.depth_confidence.to(torch.float32).cpu().numpy() conf_mask = depth_conf_np >= config.confidence_threshold @@ -400,9 +328,6 @@ def _is_point_in_front_of_camera(camera, point_xyz: np.ndarray, *, epsilon: floa return float(z_val) > epsilon - - - def _convert_vggt_outputs_to_gtsfm_data( *, vggt_output: VggtOutput, @@ -427,9 +352,8 @@ def _convert_vggt_outputs_to_gtsfm_data( image_width = float(original_coords_np[local_idx, 4]) image_height = float(original_coords_np[local_idx, 5]) - scaled_intrinsic = _rescale_intrinsic_for_original_resolution( - intrinsic_np[local_idx], config.vggt_fixed_resolution, image_width, image_height - ) + scaled_intrinsic = intrinsic_np[local_idx] + camera = torch_utils.camera_from_matrices(extrinsic_np[local_idx], scaled_intrinsic) gtsfm_data.add_camera(global_idx, camera) # type: ignore[arg-type] gtsfm_data.set_image_info( @@ -445,20 +369,20 @@ def _convert_vggt_outputs_to_gtsfm_data( if tracking_result: - # track masks according to visibility, reprojection error, etc + max_reproj_error = float(config.max_reproj_error) track_mask = tracking_result.visibilities > config.track_vis_thresh - inlier_num = track_mask.sum(0) - - 
valid_mask = inlier_num >= 2 # a track is invalid if without two inliers if tracking_result.confidences is not None: - valid_mask = np.logical_and(valid_mask, tracking_result.confidences > config.confidence_threshold) - valid_idx = np.nonzero(valid_mask)[0] + track_mask = np.logical_and(track_mask, tracking_result.confidences > config.track_conf_thresh) - max_reproj_error = float(config.max_reproj_error) enforce_reproj_filter = ( tracking_result.points_3d is not None and np.isfinite(max_reproj_error) and max_reproj_error > 0.0 ) + inlier_num = track_mask.sum(0) + min_measurements = 2 + valid_mask = inlier_num >= min_measurements # a track is invalid if without two inliers + valid_idx = np.nonzero(valid_mask)[0] + for valid_id in valid_idx: rgb: np.ndarray if tracking_result.colors is not None and valid_id < tracking_result.colors.shape[0]: @@ -478,29 +402,24 @@ def _convert_vggt_outputs_to_gtsfm_data( camera = gtsfm_data.get_camera(global_idx) if not _is_point_in_front_of_camera(camera, point_xyz): continue - rescaled_u, rescaled_v = _convert_measurement_to_original_resolution( - (float(u), float(v)), - original_coords_np[local_id], - config.vggt_fixed_resolution, - config.img_load_resolution, - measurement_in_load_resolution=True, - ) + float_u = float(u) + float_v = float(v) if enforce_reproj_filter: projected = camera.project(gtsam_point) proj_u = float(projected[0]) proj_v = float(projected[1]) - reproj_err = float(np.hypot(rescaled_u - proj_u, rescaled_v - proj_v)) + reproj_err = float(np.hypot(float_u - proj_u, float_v - proj_v)) max_error_for_track = max(max_error_for_track, reproj_err) - per_track_measurements.append((global_idx, rescaled_u, rescaled_v)) + per_track_measurements.append((global_idx, float_u, float_v)) - if len(per_track_measurements) < 2: + if len(per_track_measurements) < min_measurements: continue if enforce_reproj_filter and max_error_for_track > max_reproj_error: continue track = torch_utils.colored_track_from_point(point_xyz, rgb) - for 
global_idx, rescaled_u, rescaled_v in per_track_measurements: - track.addMeasurement(global_idx, Point2(rescaled_u, rescaled_v)) + for global_idx, float_u, float_v in per_track_measurements: + track.addMeasurement(global_idx, Point2(float_u, float_v)) gtsfm_data.add_track(track) if config.run_bundle_adjustment_on_leaf: @@ -538,11 +457,7 @@ def run_VGGT( model: Optional[VGGT] = None, weights_path: PathLike | None = None, ) -> VggtOutput: - """Run VGGT on a batch of images and return raw model predictions. - - Set ``return_dense_points`` to ``True`` to additionally compute the full per-pixel - point cloud using the optional AnySplat acceleration path (when available). - """ + """Run VGGT and unproject depth using the geometry helper.""" if images.ndim != 4 or images.shape[1] != 3: raise ValueError("VGGT expects images shaped as (N, 3, H, W).") @@ -570,22 +485,10 @@ def run_VGGT( assert model is not None images = images.to(resolved_device, dtype=resolved_dtype) - res = cfg.vggt_fixed_resolution if cfg else DEFAULT_FIXED_RESOLUTION - resized_images = F.interpolate(images, size=(res, res), mode="bilinear", align_corners=False, antialias=True) - # print('resized_images: ', resized_images.shape) 518, 518 - - # FastVGGT requires the model to know the actual patch grid dimensions used for token merging. - patch_w = max(1, resized_images.shape[-1] // getattr(model.aggregator, "patch_size", 14)) - patch_h = max(1, resized_images.shape[-2] // getattr(model.aggregator, "patch_size", 14)) - if hasattr(model, "update_patch_dimensions"): - try: - model.update_patch_dimensions(patch_w, patch_h) - except Exception as exc: # pragma: no cover - best effort for FastVGGT compatibility - logger.warning("Failed to update VGGT patch dimensions (%dx%d): %s", patch_w, patch_h, exc) # FastVGGT requires the model to know the actual patch grid dimensions used for token merging. 
- patch_w = max(1, resized_images.shape[-1] // getattr(model.aggregator, "patch_size", 14)) - patch_h = max(1, resized_images.shape[-2] // getattr(model.aggregator, "patch_size", 14)) + patch_w = max(1, images.shape[-1] // getattr(model.aggregator, "patch_size", 14)) + patch_h = max(1, images.shape[-2] // getattr(model.aggregator, "patch_size", 14)) if hasattr(model, "update_patch_dimensions"): try: model.update_patch_dimensions(patch_w, patch_h) @@ -599,28 +502,32 @@ def run_VGGT( with torch.no_grad(): with autocast_ctx: - batched = resized_images.unsqueeze(0) # make into (training) batch of 1 + batched = images.unsqueeze(0) # make into (training) batch of 1 tokens, ps_idx = model.aggregator(batched) # transformer backbone + with torch.cuda.amp.autocast(dtype=torch.float32): pose_enc = model.camera_head(tokens)[-1] extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, batched.shape[-2:]) depth_map, depth_conf = model.depth_head(tokens, batched, ps_idx) - assert _anysplat_batchify_unproject is not None, "Anysplat dependencies not available" - dense_points = _anysplat_batchify_unproject(depth_map, extrinsic, intrinsic) - depth_confidence = depth_conf.squeeze(0) if depth_confidence.ndim == 4 and depth_confidence.shape[-1] == 1: depth_confidence = depth_confidence.squeeze(-1) + depth_map_fp32 = depth_map.squeeze(0).to(dtype=torch.float32) + extrinsic_fp32 = extrinsic.squeeze(0).to(dtype=torch.float32) + intrinsic_fp32 = intrinsic.squeeze(0).to(dtype=torch.float32) + dense_points_np = unproject_depth_map_to_point_map(depth_map_fp32, extrinsic_fp32, intrinsic_fp32) + dense_points = torch.from_numpy(dense_points_np).to(device=resolved_device, dtype=torch.float32) + return VggtOutput( device=resolved_device, dtype=resolved_dtype, - resized_images=resized_images, + images=images, extrinsic=extrinsic.squeeze(0), intrinsic=intrinsic.squeeze(0), depth_map=depth_map.squeeze(0), depth_confidence=depth_confidence, - dense_points=dense_points.squeeze(0), + 
dense_points=dense_points, ) @@ -646,22 +553,17 @@ class VGGTTrackingResult: colors: Optional[np.ndarray] -def _import_predict_tracks(): - """Return the vendored ``predict_tracks`` helper from the VGGT submodule. - - The tracker lives in ``thirdparty/vggt``. We keep this import behind a small helper so that runtime - errors surface with a clear explanation when the submodule is missing. - """ +def _import_vggsfm_utils(): + """Return the vendored vggsfm utilities module from the VGGT submodule.""" try: - from vggt.dependency.track_predict import predict_tracks as _predict_tracks # type: ignore + from vggt.dependency import vggsfm_utils as _vggsfm_utils # type: ignore except ImportError as exc: # pragma: no cover - exercised only when the submodule is absent - # FastVGGT strips the tracker utilities, so fall back to the vanilla VGGT namespace if possible. if _USING_FASTVGGT: try: - tracker_module = _import_from_vanilla_vggt("dependency.track_predict") - logger.info("Using tracker utilities from the vanilla VGGT submodule.") - return tracker_module.predict_tracks # type: ignore[attr-defined] + tracker_module = _import_from_vanilla_vggt("dependency.vggsfm_utils") + logger.info("Using vggsfm utilities from the vanilla VGGT submodule.") + return tracker_module # type: ignore[return-value] except ImportError as fallback_exc: exc = fallback_exc @@ -672,42 +574,19 @@ def _import_predict_tracks(): if _USING_FASTVGGT: hint += " FastVGGT does not bundle the tracker code, so the vanilla VGGT submodule must remain accessible." raise ImportError(hint) from exc - return _predict_tracks + return _vggsfm_utils -def run_vggt_tracking( - images: torch.Tensor, vggt_output: VggtOutput, *, config: Optional[VggtConfiguration] = None +def _run_vggt_head_tracking( + vggt_output: VggtOutput, + *, + model: VGGT, + config: Optional[VggtConfiguration] = None, ) -> VGGTTrackingResult: - """Generate dense feature tracks using the VGGSfM tracker shipped with VGGT. 
- - Parameters: - images: Tensor shaped ``(num_frames, 3, H, W)`` at the *square* VGGT load resolution. You can reuse - the ``images`` tensor that you passed into :func:`run_reconstruction`; typically this is the output - from ``load_and_preprocess_images_square`` prior to interpolation. - vggt_output: Output from :func:`run_VGGT`. The ``depth_confidence`` and optional ``dense_points`` tensors - are consumed directly, avoiding redundant transfers or recomputation. - config: Optional :class:`VggtConfiguration`. We reuse the existing configuration container because - it already captures the tracker-specific parameters (``max_query_pts``, ``query_frame_num``, etc.). - tracker_kwargs: Optional dictionary to override individual keyword arguments passed to the underlying - :func:`vggt.dependency.track_predict.predict_tracks` function. This is useful if you want to tweak - settings not exposed via :class:`VggtConfiguration`. - - Returns: - :class:`VGGTTrackingResult` aggregating the numpy arrays emitted by the tracker. The visibility scores can - be thresholded manually, e.g. ``mask = result.visibilities > config.vis_thresh``. The tracks are expressed - in the same *square* coordinate frame as ``images``; remember to rescale them back to the original image - crop using :func:`_convert_measurement_to_original_resolution` if you plan to add them to ``GtsfmData``. 
- - Example: - >>> vggt_output = run_VGGT(image_batch, model=model, dtype=dtype, return_dense_points=True) - >>> cfg = VggtConfiguration() - >>> tracking = run_vggt_tracking(image_batch, vggt_output, config=cfg) - >>> high_quality = tracking.visibilities > cfg.vis_thresh - >>> first_track_pixels = tracking.tracks[:, 0] - """ + """Generate dense feature tracks using the VGGT track head.""" cfg = config or VggtConfiguration() - predict_tracks = _import_predict_tracks() + vggsfm_utils = _import_vggsfm_utils() device = vggt_output.device if device.type != "cuda": @@ -716,47 +595,156 @@ def run_vggt_tracking( "Re-run the pipeline with CUDA available." ) - dtype = torch.float32 # Tracker stack (LightGlue / DINO) expects fp32 inputs. - - if images.device != device or images.dtype != dtype: - logger.info("Moving VGGT tracking inputs to %s (dtype=%s) for DINO attention.", device, dtype) - images = images.to(device=device, dtype=dtype, non_blocking=True) + images = vggt_output.images + if images.device != device or images.dtype != torch.float32: + images = images.to(device=device, dtype=torch.float32, non_blocking=True) + + frame_num = images.shape[0] + query_frame_indexes = vggsfm_utils.generate_rank_by_dino( + images, + query_frame_num=cfg.query_frame_num, + image_size=518, + model_name="dinov2_vitb14_reg", + device=device, + spatial_similarity=False, + ) + if 0 in query_frame_indexes: + query_frame_indexes.remove(0) + query_frame_indexes = [0, *query_frame_indexes] + + extractors = vggsfm_utils.initialize_feature_extractors( + max_query_num=cfg.max_query_pts, + extractor_method=cfg.keypoint_extractor, + device=device, + ) - conf_tensor = vggt_output.depth_confidence.to(device="cpu", dtype=dtype, non_blocking=True) - points_tensor = vggt_output.dense_points.to(device="cpu", dtype=dtype, non_blocking=True) + dense_points = vggt_output.dense_points + depth_confidence = vggt_output.depth_confidence + height, width = images.shape[-2:] + + pred_tracks = [] + pred_vis_scores = 
[] + pred_conf_scores = [] + pred_world_points = [] + pred_world_points_conf = [] + pred_colors = [] + + for query_index in query_frame_indexes: + query_image = images[query_index] + query_points = vggsfm_utils.extract_keypoints(query_image, extractors, round_keypoints=True) + if query_points is None or query_points.shape[1] == 0: + continue + + query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)] + if query_points.shape[1] > cfg.max_query_pts: + query_points = query_points[:, : cfg.max_query_pts] + + query_points_round = query_points.squeeze(0).round().long() + query_points_round[:, 0] = query_points_round[:, 0].clamp(0, width - 1) + query_points_round[:, 1] = query_points_round[:, 1].clamp(0, height - 1) + + pred_color = ( + images[query_index][:, query_points_round[:, 1], query_points_round[:, 0]].permute(1, 0).cpu().numpy() + * 255.0 + ).astype(np.uint8) + + pred_point_3d = dense_points[query_index][query_points_round[:, 1], query_points_round[:, 0]] + + pred_conf = None + if depth_confidence is not None: + pred_conf = depth_confidence[query_index][query_points_round[:, 1], query_points_round[:, 0]] + + if query_points.shape[1] == 0: + continue + + reorder_index = vggsfm_utils.calculate_index_mappings(query_index, frame_num, device=device) + reorder_images = vggsfm_utils.switch_tensor_order([images], reorder_index, dim=0)[0] + + with torch.no_grad(): + with amp_autocast("cuda", dtype=vggt_output.dtype): + aggregated_tokens_list, ps_idx = model.aggregator(reorder_images[None]) + if aggregated_tokens_list and aggregated_tokens_list[0].dtype != torch.float32: + aggregated_tokens_list = [tokens.float() for tokens in aggregated_tokens_list] + with amp_autocast("cuda", dtype=torch.float32): + track_list, vis_scores, conf_scores = model.track_head( + aggregated_tokens_list, + reorder_images[None], + ps_idx, + query_points=query_points, + ) - with torch.no_grad(): - tracks, vis_scores, confidences, points_3d, colors = predict_tracks( - 
images, - conf=conf_tensor, - points_3d=points_tensor, - masks=None, # ignored anyway ! - max_query_pts=cfg.max_query_pts, - query_frame_num=cfg.query_frame_num, - keypoint_extractor=cfg.keypoint_extractor, - fine_tracking=cfg.fine_tracking, + pred_track = track_list[-1] + pred_track = pred_track.squeeze(0) + vis_scores = vis_scores.squeeze(0) + conf_scores = conf_scores.squeeze(0) + reordered = vggsfm_utils.switch_tensor_order([pred_track, vis_scores, conf_scores], reorder_index, dim=0) + pred_track, pred_vis, pred_conf_score = reordered + + pred_tracks.append(pred_track) + pred_vis_scores.append(pred_vis) + if pred_conf_score is not None: + pred_conf_scores.append(pred_conf_score) + pred_world_points.append(pred_point_3d) + if pred_conf is not None: + pred_world_points_conf.append(pred_conf) + pred_colors.append(pred_color) + + if not pred_tracks: + empty_tracks = np.zeros((frame_num, 0, 2), dtype=np.float32) + empty_vis = np.zeros((frame_num, 0), dtype=np.float32) + empty_conf = np.zeros((0,), dtype=np.float32) if depth_confidence is not None else None + empty_points = np.zeros((0, 3), dtype=np.float32) + empty_colors = np.zeros((0, 3), dtype=np.uint8) + return VGGTTrackingResult( + tracks=empty_tracks, + visibilities=empty_vis, + confidences=empty_conf, + points_3d=empty_points, + colors=empty_colors, ) - # print("images: ", images.shape) - # print("conf_tensor: ", conf_tensor.shape) - # print("tracks: ", tracks.shape) - # print("vis_scores: ", vis_scores.shape) - # print("confidences: ", confidences.shape) - # print("points_3d: ", points_3d.shape) - # print("colors: ", colors.shape) - # # images: torch.Size([4, 3, 1024, 1024]) - # # conf_tensor: torch.Size([4, 518, 518]) - # # tracks: (4, 2901, 2) - # # vis_scores: (4, 2901) - # # confidences: (2901,) - # # points_3d: (2901, 3) - # # colors: (2901, 3) + tracks = torch.cat(pred_tracks, dim=1) + vis_scores = torch.cat(pred_vis_scores, dim=1) + confidences = torch.cat(pred_conf_scores, dim=1) if pred_conf_scores 
else None + points_3d = torch.cat(pred_world_points, dim=0) if pred_world_points else None + points_3d_conf = torch.cat(pred_world_points_conf, dim=0) if pred_world_points_conf else None + colors = np.concatenate(pred_colors, axis=0) if pred_colors else None + + if points_3d_conf is not None and points_3d is not None: + filtered_flag = points_3d_conf > 1.5 + if int(filtered_flag.sum().item()) > cfg.max_query_pts // 2: + tracks = tracks[:, filtered_flag] + vis_scores = vis_scores[:, filtered_flag] + if confidences is not None: + confidences = confidences[:, filtered_flag] + points_3d = points_3d[filtered_flag] + points_3d_conf = points_3d_conf[filtered_flag] + if colors is not None: + colors = colors[filtered_flag.cpu().numpy()] return VGGTTrackingResult( - tracks=tracks, visibilities=vis_scores, confidences=confidences, points_3d=points_3d, colors=colors + tracks=tracks.float().cpu().numpy(), + visibilities=vis_scores.float().cpu().numpy(), + confidences=confidences.float().cpu().numpy() if confidences is not None else None, + points_3d=points_3d.float().cpu().numpy() if points_3d is not None else None, + colors=colors, ) +def run_vggt_tracking( + vggt_output: VggtOutput, + *, + config: Optional[VggtConfiguration] = None, + model: Optional[VGGT] = None, +) -> VGGTTrackingResult: + """Generate dense feature tracks using the configured VGGT tracking backend.""" + + cfg = config or VggtConfiguration() + if model is None: + raise ValueError("VGGT tracking_head requires a loaded VGGT model.") + return _run_vggt_head_tracking(vggt_output, model=model, config=cfg) + + # --- VGGT reconstruction ------------------------------------------------- @@ -773,8 +761,7 @@ def run_reconstruction( """Run VGGT on a batch of images and convert outputs to ``GtsfmData``. Args: - images: Tensor shaped ``(num_frames, 3, H, W)`` at the *square* VGGT load resolution. You can - obtain this tensor by calling ``load_and_preprocess_images_square`` prior to interpolation. 
+ images: Tensor shaped ``(num_frames, 3, H, W)`` at the VGGT load resolution. image_indices: Sequence of global image indices corresponding to the provided ``images`` batch. image_names: Optional sequence of image filenames corresponding to the provided ``images`` batch. original_coords: Tensor shaped ``(num_frames, 6)`` giving the original image crop metadata @@ -794,18 +781,22 @@ def run_reconstruction( torch.cuda.manual_seed(cfg.seed) torch.cuda.manual_seed_all(cfg.seed) + model_for_tracking = None + if cfg.tracking and model_for_tracking is None: + model_for_tracking = model + vggt_output = run_VGGT(images, config=cfg, model=model, weights_path=weights_path) + tracking_result = None + if cfg.tracking: + tracking_result = run_vggt_tracking(vggt_output, config=cfg, model=model_for_tracking) + if cfg.tracking and vggt_output.device.type == "cuda": - if model is not None: - _offload_vggt_model(model) + if model_for_tracking is not None: + _offload_vggt_model(model_for_tracking) else: torch.cuda.empty_cache() - tracking_result = None - if cfg.tracking: - tracking_result = run_vggt_tracking(images, vggt_output, config=cfg) - points_3d, points_rgb = _high_confidence_pointcloud(cfg, vggt_output) gtsfm_data = _convert_vggt_outputs_to_gtsfm_data( @@ -835,8 +826,7 @@ def run_reconstruction_gtsfm_data_only(images: torch.Tensor, **kwargs) -> GtsfmD """Run VGGT on a batch of images and convert outputs to ``GtsfmData``. Args: - images: Tensor shaped ``(num_frames, 3, H, W)`` at the *square* VGGT load resolution. You can - obtain this tensor by calling ``load_and_preprocess_images_square`` prior to interpolation. + images: Tensor shaped ``(num_frames, 3, H, W)`` at the VGGT load resolution. **kwargs: Additional keyword arguments passed to :func:`run_reconstruction`. 
Returns: diff --git a/gtsfm/loader/loader_base.py b/gtsfm/loader/loader_base.py index f6f4cc096..415ddef14 100644 --- a/gtsfm/loader/loader_base.py +++ b/gtsfm/loader/loader_base.py @@ -12,6 +12,8 @@ from dask.delayed import Delayed, delayed from dask.distributed import Client, Future from gtsam import Cal3_S2, Cal3Bundler, Cal3DS2, Pose3 # type: ignore +from PIL import Image as PILImage +from torchvision import transforms as TF from trimesh import Trimesh import gtsfm.common.types as gtsfm_types @@ -492,6 +494,138 @@ def load_image_batch_vggt( transformed = batch_transform(batch_tensor) if batch_transform else batch_tensor return transformed, original_coords_tensor + def load_image_batch_vggt_loader(self, indices: List[int], mode="crop"): + """ + A quick start function to load and preprocess images for model input. + This assumes the images should have the same shape for easier batching, + but VGGT model can also work well with different shapes. + + Args: + indices: List of image indices to load + mode (str, optional): Preprocessing mode, either "crop" or "pad". + - "crop" (default): Sets width to 518px and center crops height if needed. + - "pad": Preserves all pixels by making the largest dimension 518px + and padding the smaller dimension to reach a square shape. 
+ + Returns: + torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) + + Raises: + ValueError: If the input list is empty or if mode is invalid + + Notes: + - Images with different dimensions will be padded with white (value=1.0) + - A warning is printed when images have different shapes + - When mode="crop": The function ensures width=518px while maintaining aspect ratio + and height is center-cropped if larger than 518px + - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio + and the smaller dimension is padded to reach a square shape (518x518) + - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements + """ + # Check for empty list + if len(indices) == 0: + raise ValueError("At least 1 image is required") + + # Validate mode + if mode not in ["crop", "pad"]: + raise ValueError("Mode must be either 'crop' or 'pad'") + + images = [] + shapes = set() + to_tensor = TF.ToTensor() + target_size = 518 + + # First process all images and collect their shapes + for idx in indices: + # Open image + img = self.get_image(idx).value_array + + img = PILImage.fromarray(img) + + width, height = img.size + + if mode == "pad": + # Make the largest dimension 518px while maintaining aspect ratio + if width >= height: + new_width = target_size + new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14 + else: + new_height = target_size + new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14 + else: # mode == "crop" + # Original behavior: set width to 518px + new_width = target_size + # Calculate height maintaining aspect ratio, divisible by 14 + new_height = round(height * (new_width / width) / 14) * 14 + + # Resize with new dimensions (width, height) + img = img.resize((new_width, new_height), PILImage.Resampling.BICUBIC) + img = to_tensor(img) # Convert to tensor (0, 1) + + # Center crop height if it's larger than 518 (only in 
crop mode) + if mode == "crop" and new_height > target_size: + start_y = (new_height - target_size) // 2 + img = img[:, start_y : start_y + target_size, :] + + # For pad mode, pad to make a square of target_size x target_size + if mode == "pad": + h_padding = target_size - img.shape[1] + w_padding = target_size - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + # Pad with white (value=1.0) + img = torch.nn.functional.pad( + img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 + ) + + shapes.add((img.shape[1], img.shape[2])) + images.append(img) + + # Check if we have different shapes + # In theory our model can also work well with different shapes + if len(shapes) > 1: + print(f"Warning: Found images with different shapes: {shapes}") + # Find maximum dimensions + max_height = max(shape[0] for shape in shapes) + max_width = max(shape[1] for shape in shapes) + + # Pad images if necessary + padded_images = [] + for img in images: + h_padding = max_height - img.shape[1] + w_padding = max_width - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + img = torch.nn.functional.pad( + img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 + ) + padded_images.append(img) + images = padded_images + + images = torch.stack(images) # concatenate images + + # Ensure correct shape when single image + if len(indices) == 1: + # Verify shape is (1, C, H, W) + if images.dim() == 3: + images = images.unsqueeze(0) + + height, width = images.shape[-2], images.shape[-1] + coords = np.tile([0.0, 0.0, float(width), float(height), float(width), float(height)], (len(indices), 1)) + original_coords_tensor = torch.from_numpy(coords).float() + + return images, original_coords_tensor + def 
get_all_descriptor_image_batches_as_futures( self, client: Client, From 90b252def2b43a2753fa217768929dadb96fb0aa Mon Sep 17 00:00:00 2001 From: nantonzhang Date: Mon, 26 Jan 2026 23:19:19 -0500 Subject: [PATCH 05/24] Decrease min_score to increase connectivity --- gtsfm/configs/vggt.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsfm/configs/vggt.yaml b/gtsfm/configs/vggt.yaml index 01781aa8a..b881a46da 100644 --- a/gtsfm/configs/vggt.yaml +++ b/gtsfm/configs/vggt.yaml @@ -18,7 +18,7 @@ image_pairs_generator: retriever: _target_: gtsfm.retriever.Similarity num_matched: 5 - min_score: 0.79 + min_score: 0.25 batch_size: 16 graph_partitioner: From a1c3c348346032e47e0e2099d0776ca9fa529009 Mon Sep 17 00:00:00 2001 From: nantonzhang Date: Mon, 26 Jan 2026 23:22:39 -0500 Subject: [PATCH 06/24] Use Metis by default --- gtsfm/configs/vggt.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsfm/configs/vggt.yaml b/gtsfm/configs/vggt.yaml index b881a46da..85e6b5fcf 100644 --- a/gtsfm/configs/vggt.yaml +++ b/gtsfm/configs/vggt.yaml @@ -22,7 +22,7 @@ image_pairs_generator: batch_size: 16 graph_partitioner: - _target_: gtsfm.graph_partitioner.Single + _target_: gtsfm.graph_partitioner.Metis cluster_optimizer: _target_: gtsfm.cluster_optimizer.Cacher From a7d6c90a501b6d6d913e305ec9ad35ed30480d74 Mon Sep 17 00:00:00 2001 From: nantonzhang Date: Mon, 26 Jan 2026 23:26:00 -0500 Subject: [PATCH 07/24] exclude all result folders --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 29a4be5f9..62fe37926 100644 --- a/.gitignore +++ b/.gitignore @@ -155,6 +155,7 @@ data/ # Data dumped by GTSFM directory debug/ plots/ +**/*result*/ results/ result_metrics/ *.html From 562448f4fe1f08704b07f5b4cbb08a3b080b5e61 Mon Sep 17 00:00:00 2001 From: nantonzhang Date: Tue, 27 Jan 2026 09:24:00 -0500 Subject: [PATCH 08/24] change default to no ba --- gtsfm/configs/vggt.yaml | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/gtsfm/configs/vggt.yaml b/gtsfm/configs/vggt.yaml index 85e6b5fcf..59d8e23bf 100644 --- a/gtsfm/configs/vggt.yaml +++ b/gtsfm/configs/vggt.yaml @@ -45,5 +45,5 @@ cluster_optimizer: seed: 42 plot_reprojection_histograms: true run_bundle_adjustment_on_leaf: false - run_bundle_adjustment_on_parent: true + run_bundle_adjustment_on_parent: false model_cache_key: null From b9286e4802c7a9d7db5edd0dae70e42b19121532 Mon Sep 17 00:00:00 2001 From: nantonzhang Date: Tue, 27 Jan 2026 10:07:41 -0500 Subject: [PATCH 09/24] Add more metrics to eval code --- gtsfm/evaluation/compare_colmap_outputs.py | 149 ++- .../compare_colmap_outputs_by_cluster.py | 898 ++++++++++++++++++ 2 files changed, 1033 insertions(+), 14 deletions(-) create mode 100644 gtsfm/evaluation/compare_colmap_outputs_by_cluster.py diff --git a/gtsfm/evaluation/compare_colmap_outputs.py b/gtsfm/evaluation/compare_colmap_outputs.py index 6e163d697..66b8016c1 100644 --- a/gtsfm/evaluation/compare_colmap_outputs.py +++ b/gtsfm/evaluation/compare_colmap_outputs.py @@ -1,13 +1,17 @@ """Script to compare two reconstructions in Colmap's output format. 
-Authors: Ayush Baid +Authors: Ayush Baid, Xinan Zhang """ import argparse +import csv +import json import os -from typing import Dict, Tuple +import textwrap +from typing import Dict, List, Optional, Tuple import numpy as np +import matplotlib.pyplot as plt import pycolmap from gtsam import Point3, Pose3, Rot3, Similarity3 from scipy.spatial.transform import Rotation @@ -53,7 +57,108 @@ def align_with_colmap( return aSb, aligned_dict -def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) -> None: +def plot_camera_centers( + baseline_wTi_list: List[Pose3], + current_wTi_list: List[Pose3], + output_dirpath: str, + title: Optional[str] = None, +) -> None: + """Save a 3D scatter plot of baseline and current camera centers.""" + baseline_centers = np.stack([pose.translation() for pose in baseline_wTi_list]) + current_centers_list = [pose.translation() for pose in current_wTi_list] + current_centers = np.stack(current_centers_list) if current_centers_list else np.empty((0, 3)) + + fig = plt.figure(figsize=(7, 7)) + ax = fig.add_subplot(111, projection="3d") + if baseline_centers.size: + center = baseline_centers.mean(axis=0) + mean_radius = np.linalg.norm(baseline_centers - center, axis=1).mean() + arrow_len = max(mean_radius * 0.15, 1e-3) + else: + arrow_len = 1.0 + + for pose in baseline_wTi_list: + origin = pose.transformFrom(Point3(0.0, 0.0, 0.0)) + tip = pose.transformFrom(Point3(0.0, 0.0, arrow_len)) + direction = tip - origin + ax.quiver( + origin[0], origin[1], origin[2], + direction[0], direction[1], direction[2], + color="tab:blue", linewidth=0.5, arrow_length_ratio=0.2, alpha=0.6 + ) + for pose in current_wTi_list: + origin = pose.transformFrom(Point3(0.0, 0.0, 0.0)) + tip = pose.transformFrom(Point3(0.0, 0.0, arrow_len)) + direction = tip - origin + ax.quiver( + origin[0], origin[1], origin[2], + direction[0], direction[1], direction[2], + color="tab:orange", linewidth=0.5, arrow_length_ratio=0.2, alpha=0.6 + ) + + ax.scatter( + 
baseline_centers[:, 0], + baseline_centers[:, 1], + baseline_centers[:, 2], + s=10, + c="tab:blue", + label="baseline", + ) + if current_centers.size: + ax.scatter( + current_centers[:, 0], + current_centers[:, 1], + current_centers[:, 2], + s=10, + c="tab:orange", + label="current", + ) + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_zlabel("z") + ax.legend(loc="best") + wrapped = "\n".join(textwrap.wrap(title, width=80)) if title else "" + if wrapped: + fig.suptitle(wrapped, fontsize=9, y=0.98) + fig.tight_layout(rect=[0, 0, 1, 0.92]) + fig.savefig(os.path.join(output_dirpath, "camera_centers.png"), dpi=300) + plt.close(fig) + + +def export_metrics_group_to_csv(metrics_group: GtsfmMetricsGroup, output_path: str) -> None: + """Export a metrics group to a CSV file.""" + rows: List[Dict[str, str]] = [] + for metric in metrics_group.metrics: + if metric.dim == 0: + value = "" if metric.data is None else f"{float(metric.data):.6f}" + rows.append({"metric_name": metric.name, "value": value}) + else: + summary_json = json.dumps(metric.summary, sort_keys=True) + rows.append({"metric_name": metric.name, "value": summary_json}) + + with open(output_path, "w", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=["metric_name", "value"]) + writer.writeheader() + writer.writerows(rows) + + +def _format_pose_auc(metrics_group: GtsfmMetricsGroup) -> str: + auc_parts = [] + for metric in metrics_group.metrics: + if not metric.name.startswith("pose_auc_@"): + continue + if metric.data is None: + continue + try: + value = float(metric.data) + except (TypeError, ValueError): + continue + suffix = metric.name.replace("pose_auc_", "") + auc_parts.append(f"{suffix}={value:.3f}") + return ", ".join(auc_parts) + + +def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) -> GtsfmMetricsGroup: """Compare the pose metrics between two reconstructions (Colmap format). 
Args: @@ -82,25 +187,32 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) len(common_fnames), ) + baseline_wTi_list: List[Pose3] = [] + current_wTi_list: List[Optional[Pose3]] = [] + for fname, wTi in baseline_wTi_dict.items(): + baseline_wTi_list.append(wTi) + current_wTi_list.append(current_wTi_dict.get(fname)) + if not args.use_pycolmap_alignment: - aSb = align.sim3_from_Pose3_maps_robust(baseline_wTi_dict, current_wTi_dict) - current_wTi_dict = transform.Pose3_map_with_sim3(aSb, current_wTi_dict) + aSb = align.sim3_from_optional_Pose3s(baseline_wTi_list, current_wTi_list) + current_wTi_list = transform.optional_Pose3s_with_sim3(aSb, current_wTi_list) + current_wTi_dict = {fname: aSb.transformFrom(pose) for fname, pose in current_wTi_dict.items()} - i2Ri1_dict_gt, i2Ui1_dict_gt = metric_utils.get_all_relative_rotations_translations(baseline_wTi_dict) + i2Ri1_dict_gt, i2Ui1_dict_gt = metric_utils.get_all_relative_rotations_translations(baseline_wTi_list) - wRi_aligned, wti_aligned = metric_utils.get_rotations_translations_from_poses(current_wTi_dict) - baseline_wRi, baseline_wti = metric_utils.get_rotations_translations_from_poses(baseline_wTi_dict) + wRi_aligned_list, wti_aligned_list = metric_utils.get_rotations_translations_from_poses(current_wTi_list) + baseline_wRi_list, baseline_wti_list = metric_utils.get_rotations_translations_from_poses(baseline_wTi_list) metrics = [] - metrics.append(metric_utils.compute_rotation_angle_metric(wRi_aligned, baseline_wRi)) - metrics.append(metric_utils.compute_translation_distance_metric(wti_aligned, baseline_wti)) - metrics.append(metric_utils.compute_translation_angle_metric(baseline_wTi_dict, current_wTi_dict)) + metrics.append(metric_utils.compute_rotation_angle_metric(wRi_aligned_list, baseline_wRi_list)) + metrics.append(metric_utils.compute_translation_distance_metric(wti_aligned_list, baseline_wti_list)) + 
metrics.append(metric_utils.compute_translation_angle_metric(baseline_wTi_list, current_wTi_list)) relative_rotation_error_metric = metric_utils.compute_relative_rotation_angle_metric( - i2Ri1_dict_gt, current_wTi_dict, store_full_data=True + i2Ri1_dict_gt, current_wTi_list, store_full_data=True ) metrics.append(relative_rotation_error_metric) relative_translation_error_metric = metric_utils.compute_relative_translation_angle_metric( - i2Ui1_dict_gt, current_wTi_dict, store_full_data=True + i2Ui1_dict_gt, current_wTi_list, store_full_data=True ) metrics.append(relative_translation_error_metric) @@ -114,7 +226,14 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) ba_pose_metrics = GtsfmMetricsGroup(name="ba_pose_error_metrics", metrics=metrics) + auc_text = _format_pose_auc(ba_pose_metrics) + title = eval_dirpath + if auc_text: + title = f"{title}\nPose AUC: {auc_text}" + plot_camera_centers(baseline_wTi_list, list(current_wTi_dict.values()), output_dirpath, title=title) + save_metrics_reports([ba_pose_metrics], metrics_path=output_dirpath) + return ba_pose_metrics if __name__ == "__main__": @@ -139,4 +258,6 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) os.makedirs(args.output, exist_ok=True) - compare_poses(args.baseline, args.current, args.output) + ba_pose_metrics = compare_poses(args.baseline, args.current, args.output) + export_metrics_group_to_csv(ba_pose_metrics, os.path.join(args.output, f"{ba_pose_metrics.name}.csv")) + diff --git a/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py b/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py new file mode 100644 index 000000000..869ebbcef --- /dev/null +++ b/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py @@ -0,0 +1,898 @@ +"""Compare COLMAP reconstructions using image-name alignment. 
+ +This script walks a results tree, finds cluster reconstructions under a given subfolder +name (default: "vggt"), and evaluates camera pose quality against a COLMAP baseline. +""" + +from __future__ import annotations + +import argparse +import csv +import json +import os +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Tuple +import textwrap + +import matplotlib.pyplot as plt +import numpy as np +from gtsam import Pose3, Rot3 + +import gtsfm.utils.logger as logger_utils +import gtsfm.utils.metrics as metric_utils +from gtsfm.evaluation.metrics import GtsfmMetric, GtsfmMetricsGroup +from gtsfm.utils import align + +logger = logger_utils.get_logger() + + +def _read_images_txt_with_names(images_txt: Path) -> Dict[str, Pose3]: + """Read poses from COLMAP images.txt keyed by image NAME.""" + if not images_txt.exists(): + raise FileNotFoundError(f"{images_txt} does not exist.") + + with images_txt.open("r") as f: + lines = f.readlines() + + poses_by_name: Dict[str, Pose3] = {} + for line in lines: + if not line.strip() or line.startswith("#"): + continue + parts = line.split() + if len(parts) < 10: + continue + _image_id_str, qw, qx, qy, qz, tx, ty, tz, _camera_id = parts[:9] + img_fname = " ".join(parts[9:]) + iRw = Rot3(float(qw), float(qx), float(qy), float(qz)) + wTi = Pose3(iRw, np.array([tx, ty, tz], dtype=np.float64)).inverse() + if not np.isfinite(wTi.translation()).all(): + logger.warning("Skipping non-finite pose for %s in %s", img_fname, images_txt) + continue + poses_by_name[img_fname] = wTi + return poses_by_name + + +def _read_images_txt_with_names_and_cameras(images_txt: Path) -> Tuple[Dict[str, Pose3], Dict[str, int]]: + """Read poses and camera ids from COLMAP images.txt keyed by image NAME.""" + if not images_txt.exists(): + raise FileNotFoundError(f"{images_txt} does not exist.") + + with images_txt.open("r") as f: + lines = f.readlines() + + poses_by_name: Dict[str, Pose3] = {} + camera_by_name: Dict[str, int] = {} + 
for line in lines: + if not line.strip() or line.startswith("#"): + continue + parts = line.split() + if len(parts) < 10: + continue + _image_id_str, qw, qx, qy, qz, tx, ty, tz, camera_id_str = parts[:9] + img_fname = " ".join(parts[9:]) + iRw = Rot3(float(qw), float(qx), float(qy), float(qz)) + wTi = Pose3(iRw, np.array([tx, ty, tz], dtype=np.float64)).inverse() + if not np.isfinite(wTi.translation()).all(): + logger.warning("Skipping non-finite pose for %s in %s", img_fname, images_txt) + continue + camera_id = int(camera_id_str) + poses_by_name[img_fname] = wTi + camera_by_name[img_fname] = camera_id + return poses_by_name, camera_by_name + + +def _read_cameras_txt_with_ids(cameras_txt: Path) -> Dict[int, Dict[str, float]]: + """Read camera intrinsics from COLMAP cameras.txt keyed by CAMERA_ID.""" + if not cameras_txt.exists(): + raise FileNotFoundError(f"{cameras_txt} does not exist.") + + with cameras_txt.open("r") as f: + lines = f.readlines() + + cameras_by_id: Dict[int, Dict[str, float]] = {} + for line in lines[3:]: + if line.startswith("#") or not line.strip(): + continue + parts = line.split() + camera_id = int(parts[0]) + model = parts[1] + width = int(parts[2]) + height = int(parts[3]) + params = list(map(float, parts[4:])) + if model == "SIMPLE_PINHOLE": + fx, cx, cy = params + fy = fx + k1 = 0.0 + k2 = 0.0 + elif model == "SIMPLE_RADIAL": + fx, cx, cy, k1 = params + k2 = 0.0 + fy = fx + elif model == "RADIAL": + fx, cx, cy, k1, k2 = params + fy = fx + elif model == "PINHOLE": + fx, fy, cx, cy = params + k1 = 0.0 + k2 = 0.0 + elif model == "OPENCV": + fx, fy, cx, cy, k1, k2, _p1, _p2, *_rest = params + elif model == "OPENCV_FISHEYE": + fx, fy, cx, cy, k1, k2, *_rest = params + else: + logger.warning("Unsupported camera model %s; skipping camera_id=%d", model, camera_id) + continue + cameras_by_id[camera_id] = { + "model": model, + "width": float(width), + "height": float(height), + "fx": fx, + "fy": fy, + "cx": cx, + "cy": cy, + "k1": k1, + "k2": k2, 
+ } + return cameras_by_id + + +def _find_cluster_recon_dirs(root: Path, recon_name: str) -> Iterable[Path]: + """Yield directories that match the recon_name and contain images.txt.""" + for dirpath, dirnames, filenames in os.walk(root): + if os.path.basename(dirpath) != recon_name: + continue + if "images.txt" in filenames: + yield Path(dirpath) + + +def _build_pose_lists( + baseline_poses: Dict[str, Pose3], + current_poses: Dict[str, Pose3], + cluster_label: str, +) -> Tuple[List[str], List[Pose3], List[Pose3]]: + """Align poses by image NAME and return matched pose lists.""" + common_names = sorted(set(baseline_poses.keys()) & set(current_poses.keys())) + if not common_names: + missing_in_baseline = sorted(set(current_poses.keys()) - set(baseline_poses.keys())) + missing_in_current = sorted(set(baseline_poses.keys()) - set(current_poses.keys())) + if missing_in_baseline: + logger.warning( + "No common images for %s; missing in baseline (sample): %s", + cluster_label, + ", ".join(missing_in_baseline[:5]), + ) + if missing_in_current: + logger.warning( + "No common images for %s; missing in current (sample): %s", + cluster_label, + ", ".join(missing_in_current[:5]), + ) + else: + logger.info("Common images for %s: %d", cluster_label, len(common_names)) + baseline_list = [baseline_poses[name] for name in common_names] + current_list = [current_poses[name] for name in common_names] + return common_names, baseline_list, current_list + + +def _compute_pose_metrics(baseline_list: List[Pose3], current_aligned_list: List[Pose3]) -> GtsfmMetricsGroup: + """Compute the same pose metrics as compare_colmap_outputs, without plotting.""" + i2Ri1_dict_gt, i2Ui1_dict_gt = metric_utils.get_all_relative_rotations_translations(baseline_list) + wRi_aligned_list, wti_aligned_list = metric_utils.get_rotations_translations_from_poses(current_aligned_list) + baseline_wRi_list, baseline_wti_list = metric_utils.get_rotations_translations_from_poses(baseline_list) + + metrics = [] + 
metrics.append(metric_utils.compute_rotation_angle_metric(wRi_aligned_list, baseline_wRi_list)) + metrics.append(metric_utils.compute_translation_distance_metric(wti_aligned_list, baseline_wti_list)) + metrics.append(metric_utils.compute_translation_angle_metric(baseline_list, current_aligned_list)) + relative_rotation_error_metric = metric_utils.compute_relative_rotation_angle_metric( + i2Ri1_dict_gt, current_aligned_list, store_full_data=True + ) + metrics.append(relative_rotation_error_metric) + relative_translation_error_metric = metric_utils.compute_relative_translation_angle_metric( + i2Ui1_dict_gt, current_aligned_list, store_full_data=True + ) + metrics.append(relative_translation_error_metric) + thresholds_deg = (1.0, 2.5, 5.0, 10.0, 20.0) + if relative_rotation_error_metric.data is not None: + rotation_angular_errors = np.asarray(relative_rotation_error_metric.data) + rotation_auc_values = metric_utils.pose_auc(rotation_angular_errors, thresholds_deg) + metrics.extend( + [ + GtsfmMetric(f"rotation_auc_@{threshold}_deg", auc) + for threshold, auc in zip(thresholds_deg, rotation_auc_values) + ] + ) + if relative_translation_error_metric.data is not None: + translation_angular_errors = np.asarray(relative_translation_error_metric.data) + translation_auc_values = metric_utils.pose_auc(translation_angular_errors, thresholds_deg) + metrics.extend( + [ + GtsfmMetric(f"translation_auc_@{threshold}_deg", auc) + for threshold, auc in zip(thresholds_deg, translation_auc_values) + ] + ) + metrics.extend( + metric_utils.compute_pose_auc_metric( + relative_rotation_error_metric.data, relative_translation_error_metric.data, thresholds_deg=thresholds_deg + ) + ) + + return GtsfmMetricsGroup(name="ba_pose_error_metrics", metrics=metrics) + + +def _estimate_sim3_ransac( + baseline_list: List[Pose3], + current_list: List[Pose3], + max_hypotheses: int, + inlier_thresh: float, + rng: np.random.Generator, + cluster_label: str, +) -> align.Similarity3: + """Estimate Sim(3) 
using simple RANSAC over camera centers with refit on inliers.""" + n_to_align = len(baseline_list) + if n_to_align < 2: + logger.warning("SIM(3) alignment uses at least 2 frames; Skipping for %s", cluster_label) + return align.Similarity3(Rot3(), np.zeros(3), 1.0) + + baseline_centers = np.stack([pose.translation() for pose in baseline_list]) + current_centers = np.stack([pose.translation() for pose in current_list]) + best_inliers: Optional[np.ndarray] = None + best_count = -1 + best_mean_error = float("inf") + best_aSb: Optional[align.Similarity3] = None + + for _ in range(max_hypotheses): + sample_idx = rng.choice(n_to_align, size=2, replace=False) + baseline_sample = {i: baseline_list[idx] for i, idx in enumerate(sample_idx)} + current_sample = {i: current_list[idx] for i, idx in enumerate(sample_idx)} + try: + aSb_candidate = align.sim3_from_Pose3_maps(baseline_sample, current_sample) + except Exception: + continue + transformed = np.stack([aSb_candidate.transformFrom(p) for p in current_centers]) + errors = np.linalg.norm(baseline_centers - transformed, axis=1) + inliers = errors <= inlier_thresh + count = int(np.count_nonzero(inliers)) + mean_error = float(errors[inliers].mean()) if count > 0 else float("inf") + if count > best_count or (count == best_count and mean_error < best_mean_error): + best_count = count + best_mean_error = mean_error + best_inliers = inliers + best_aSb = aSb_candidate + + if best_aSb is None or best_inliers is None: + logger.warning("Robust Sim3 failed; falling back to all-poses alignment for %s", cluster_label) + baseline_dict = {i: pose for i, pose in enumerate(baseline_list)} + current_dict = {i: pose for i, pose in enumerate(current_list)} + return align.sim3_from_Pose3_maps(baseline_dict, current_dict) + + inlier_indices = np.where(best_inliers)[0] + if len(inlier_indices) < 2: + logger.warning( + "Robust Sim3 inliers too few (%d/%d); using best hypothesis for %s", + len(inlier_indices), + n_to_align, + cluster_label, + ) + 
return best_aSb + + baseline_inliers = {i: baseline_list[idx] for i, idx in enumerate(inlier_indices)} + current_inliers = {i: current_list[idx] for i, idx in enumerate(inlier_indices)} + aSb_refit = align.sim3_from_Pose3_maps(baseline_inliers, current_inliers) + logger.info( + "Robust Sim3 for %s: inliers=%d/%d, thresh=%.3f", + cluster_label, + len(inlier_indices), + n_to_align, + inlier_thresh, + ) + return aSb_refit + + +def _align_poses( + baseline_list: List[Pose3], + current_list: List[Pose3], + use_ransac: bool, + max_hypotheses: int, + inlier_thresh: float, + rng: np.random.Generator, + cluster_label: str, +) -> Tuple[List[Pose3], align.Similarity3]: + """Align current poses to baseline using Sim(3), optionally with RANSAC+refit.""" + baseline_dict = {i: pose for i, pose in enumerate(baseline_list)} + current_dict = {i: pose for i, pose in enumerate(current_list)} + if use_ransac: + aSb = _estimate_sim3_ransac( + baseline_list, current_list, max_hypotheses, inlier_thresh, rng, cluster_label + ) + else: + aSb = align.sim3_from_Pose3_maps(baseline_dict, current_dict) + current_aligned_list = [aSb.transformFrom(pose) for pose in current_list] + return current_aligned_list, aSb + + +def _plot_camera_centers( + baseline_list: List[Pose3], + current_list: List[Pose3], + output_path: Path, + title: str, +) -> None: + """Save a 3D scatter plot of baseline and current camera centers.""" + baseline_centers = np.stack([pose.translation() for pose in baseline_list]) + current_centers_list = [pose.translation() for pose in current_list] + current_centers = np.stack(current_centers_list) if current_centers_list else np.empty((0, 3)) + + fig = plt.figure(figsize=(7, 7)) + ax = fig.add_subplot(111, projection="3d") + if baseline_centers.size: + center = baseline_centers.mean(axis=0) + mean_radius = np.linalg.norm(baseline_centers - center, axis=1).mean() + arrow_len = max(mean_radius * 0.15, 1e-3) + else: + arrow_len = 1.0 + + for pose in baseline_list: + origin = 
pose.transformFrom(np.array([0.0, 0.0, 0.0])) + tip = pose.transformFrom(np.array([0.0, 0.0, arrow_len])) + direction = tip - origin + ax.quiver( + origin[0], origin[1], origin[2], + direction[0], direction[1], direction[2], + color="tab:blue", linewidth=0.5, arrow_length_ratio=0.2, alpha=0.6 + ) + for pose in current_list: + origin = pose.transformFrom(np.array([0.0, 0.0, 0.0])) + tip = pose.transformFrom(np.array([0.0, 0.0, arrow_len])) + direction = tip - origin + ax.quiver( + origin[0], origin[1], origin[2], + direction[0], direction[1], direction[2], + color="tab:orange", linewidth=0.5, arrow_length_ratio=0.2, alpha=0.6 + ) + + ax.scatter( + baseline_centers[:, 0], + baseline_centers[:, 1], + baseline_centers[:, 2], + s=10, + c="tab:blue", + label="baseline", + ) + if current_centers.size: + ax.scatter( + current_centers[:, 0], + current_centers[:, 1], + current_centers[:, 2], + s=10, + c="tab:orange", + label="current", + ) + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_zlabel("z") + ax.legend(loc="best") + if title: + wrapped_lines = [] + for line in title.splitlines(): + wrapped_lines.extend(textwrap.wrap(line, width=80) or [""]) + wrapped = "\n".join(wrapped_lines) + else: + wrapped = "" + if wrapped: + fig.suptitle(wrapped, fontsize=9, y=0.98) + fig.tight_layout(rect=[0, 0, 1, 0.92]) + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(output_path), dpi=300) + plt.close(fig) + + +def _summarize_pose_errors( + baseline_list: List[Pose3], + current_aligned_list: List[Pose3], + cluster_label: str, +) -> None: + """Log median/mean absolute pose errors after alignment.""" + if not baseline_list or not current_aligned_list: + return + rot_errors_deg: List[float] = [] + trans_errors: List[float] = [] + for baseline_pose, current_pose in zip(baseline_list, current_aligned_list): + rel = baseline_pose.between(current_pose) + rot_vec = Rot3.Logmap(rel.rotation()) + rot_errors_deg.append(float(np.rad2deg(np.linalg.norm(rot_vec)))) + 
trans_errors.append(float(np.linalg.norm(rel.translation()))) + logger.info( + "Pose errors for %s: rot_deg median=%.3f mean=%.3f; trans median=%.3f mean=%.3f", + cluster_label, + float(np.median(rot_errors_deg)), + float(np.mean(rot_errors_deg)), + float(np.median(trans_errors)), + float(np.mean(trans_errors)), + ) + + +def _plot_pose_auc_boxplot(auc_values_by_label: Dict[str, List[float]], output_path: Path, title: str) -> None: + """Save box plots for AUC metrics across all clusters.""" + preferred_order = ["@1.0_deg", "@2.5_deg", "@5.0_deg", "@10.0_deg", "@20.0_deg"] + labels = [label for label in preferred_order if auc_values_by_label.get(label)] + if not labels: + labels = sorted(auc_values_by_label.keys()) + data = [auc_values_by_label[label] for label in labels if auc_values_by_label.get(label)] + labels = [label for label in labels if auc_values_by_label.get(label)] + if not data: + return + + fig = plt.figure(figsize=(6, 4)) + ax = fig.add_subplot(111) + ax.boxplot(data, vert=True, patch_artist=True) + ax.set_title(title) + ax.set_ylabel("AUC") + ax.set_xticks(range(1, len(labels) + 1)) + ax.set_xticklabels(labels, rotation=30, ha="right") + stats_lines = [] + for label, values in zip(labels, data): + mean_val = float(np.mean(values)) + median_val = float(np.median(values)) + stats_lines.append(f"{label}: mean={mean_val:.3f}, med={median_val:.3f}") + if stats_lines: + ax.text( + 0.02, + 0.98, + "\n".join(stats_lines), + transform=ax.transAxes, + va="top", + ha="left", + fontsize=8, + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8), + ) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(output_path), dpi=300) + plt.close(fig) + + +def _plot_pose_auc_vs_input_images( + auc_by_label_and_count: Dict[str, List[Tuple[int, float]]], + output_path: Path, +) -> None: + """Plot pose AUC at each threshold vs. 
input image count across clusters.""" + preferred_order = ["@1.0_deg", "@2.5_deg", "@5.0_deg", "@10.0_deg", "@20.0_deg"] + labels = [label for label in preferred_order if auc_by_label_and_count.get(label)] + if not labels: + labels = sorted(auc_by_label_and_count.keys()) + + fig = plt.figure(figsize=(7, 5)) + ax = fig.add_subplot(111) + for label in labels: + pairs = auc_by_label_and_count.get(label, []) + if not pairs: + continue + pairs_sorted = sorted(pairs, key=lambda pair: pair[0]) + x_vals = [pair[0] for pair in pairs_sorted] + y_vals = [pair[1] for pair in pairs_sorted] + ax.plot(x_vals, y_vals, marker="o", linewidth=1.0, markersize=4, alpha=0.85, label=label) + + ax.set_title("Pose AUC vs input images (all clusters)") + ax.set_xlabel("input images (current count)") + ax.set_ylabel("AUC") + ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.5) + ax.legend(loc="best", fontsize=8) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(output_path), dpi=300) + plt.close(fig) + + +def _plot_intrinsics_deltas_boxplot(deltas: Dict[str, List[float]], output_path: Path, title: str) -> None: + """Save box plots for normalized intrinsics deltas for a cluster.""" + labels = ["delta_fx_norm", "delta_fy_norm", "delta_cx_norm", "delta_cy_norm"] + data = [deltas.get(label, []) for label in labels] + if not any(data): + return + + fig = plt.figure(figsize=(6, 4)) + ax = fig.add_subplot(111) + ax.boxplot(data, vert=True, patch_artist=True) + ax.set_title(title) + ax.set_ylabel("normalized by baseline value") + ax.set_xticks(range(1, len(labels) + 1)) + ax.set_xticklabels(labels, rotation=20, ha="right") + stats_lines = [] + for label, values in zip(labels, data): + if not values: + continue + mean_val = float(np.mean(values)) + median_val = float(np.median(values)) + stats_lines.append(f"{label}: mean={mean_val:.3f}, med={median_val:.3f}") + if stats_lines: + ax.text( + 0.02, + 0.98, + "\n".join(stats_lines), + 
transform=ax.transAxes, + va="top", + ha="left", + fontsize=8, + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8), + ) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(output_path), dpi=300) + plt.close(fig) + + +def _plot_fov_deltas_boxplot(deltas: Dict[str, List[float]], output_path: Path, title: str) -> None: + """Save box plots for FOV deltas (degrees) for a cluster.""" + labels = ["delta_fovx_deg", "delta_fovy_deg"] + data = [deltas.get(label, []) for label in labels] + if not any(data): + return + + fig = plt.figure(figsize=(6, 4)) + ax = fig.add_subplot(111) + ax.boxplot(data, vert=True, patch_artist=True) + ax.set_title(title) + ax.set_ylabel("degrees") + ax.set_xticks(range(1, len(labels) + 1)) + ax.set_xticklabels(labels, rotation=20, ha="right") + stats_lines = [] + for label, values in zip(labels, data): + if not values: + continue + mean_val = float(np.mean(values)) + median_val = float(np.median(values)) + stats_lines.append(f"{label}: mean={mean_val:.3f}, med={median_val:.3f}") + if stats_lines: + ax.text( + 0.02, + 0.98, + "\n".join(stats_lines), + transform=ax.transAxes, + va="top", + ha="left", + fontsize=8, + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8), + ) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(output_path), dpi=300) + plt.close(fig) + + +def _print_metrics(label: str, metrics_group: GtsfmMetricsGroup) -> None: + logger.info("=== %s ===", label) + for metric in metrics_group.metrics: + if metric.dim == 0: + value = "" if metric.data is None else f"{float(metric.data):.6f}" + logger.info("%s: %s", metric.name, value) + else: + logger.info("%s: %s", metric.name, json.dumps(metric.summary, sort_keys=True)) + + +def _format_auc(metrics_group: GtsfmMetricsGroup, prefix: str) -> str: + auc_parts = [] + prefix_token = f"{prefix}_@" + for metric in metrics_group.metrics: + if not 
metric.name.startswith(prefix_token): + continue + if metric.data is None: + continue + try: + value = float(metric.data) + except (TypeError, ValueError): + continue + suffix = metric.name.replace(f"{prefix}_", "") + auc_parts.append(f"{suffix}={value:.3f}") + return ", ".join(auc_parts) + + +def export_metrics_group_to_csv( + metrics_group: GtsfmMetricsGroup, + cluster_label: str, + baseline_count: int, + current_count: int, + common_count: int, + output_path: Path, + rows: List[Dict[str, str]], +) -> None: + """Append metrics for a cluster into a shared CSV row list.""" + auc_values: List[float] = [] + for metric in metrics_group.metrics: + if metric.dim == 0: + value = "" if metric.data is None else f"{float(metric.data):.6f}" + if metric.name.startswith("pose_auc_@") and metric.data is not None: + try: + auc_values.append(float(metric.data)) + except (TypeError, ValueError): + pass + else: + value = json.dumps(metric.summary, sort_keys=True) + rows.append( + { + "cluster": cluster_label, + "baseline_count": str(baseline_count), + "current_count": str(current_count), + "common_count": str(common_count), + "metric_name": metric.name, + "value": value, + } + ) + if auc_values: + rows.append( + { + "cluster": cluster_label, + "baseline_count": str(baseline_count), + "current_count": str(current_count), + "common_count": str(common_count), + "metric_name": "pose_auc_avg", + "value": f"{float(np.mean(auc_values)):.6f}", + } + ) + + if output_path.exists() and output_path.stat().st_size > 0: + return + + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", newline="") as csvfile: + writer = csv.DictWriter( + csvfile, + fieldnames=["cluster", "baseline_count", "current_count", "common_count", "metric_name", "value"], + ) + writer.writeheader() + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--baseline", required=True, help="Path to baseline COLMAP directory.") + parser.add_argument("--root", 
required=True, help="Root directory to traverse for cluster reconstructions.") + parser.add_argument("--recon_name", default="vggt", help="Subdirectory name for reconstructions.") + parser.add_argument( + "--csv_output", + default=None, + help="Optional path to a single CSV file for all cluster metrics.", + ) + parser.add_argument( + "--fig_output_dir", + default=None, + help="Optional directory to save per-cluster camera_centers.png plots.", + ) + parser.add_argument( + "--robust_sim3", + action="store_true", + default=False, + help="Use simple RANSAC+refit for Sim(3) alignment.", + ) + parser.add_argument( + "--robust_sim3_max_hypotheses", + type=int, + default=200, + help="Max RANSAC hypotheses for robust Sim(3) alignment.", + ) + parser.add_argument( + "--robust_sim3_inlier_thresh", + type=float, + default=0.1, + help="Inlier threshold on camera-center error for robust Sim(3) alignment.", + ) + parser.add_argument( + "--robust_sim3_seed", + type=int, + default=0, + help="Random seed for robust Sim(3) alignment.", + ) + args = parser.parse_args() + + baseline_images = Path(args.baseline) / "images.txt" + baseline_cameras_txt = Path(args.baseline) / "cameras.txt" + baseline_poses, baseline_camera_by_name = _read_images_txt_with_names_and_cameras(baseline_images) + baseline_cameras = _read_cameras_txt_with_ids(baseline_cameras_txt) + fig_output_dir = Path(args.fig_output_dir) if args.fig_output_dir else None + if fig_output_dir is None and args.csv_output: + fig_output_dir = Path(args.csv_output).parent / "cluster_camera_centers" + + root = Path(args.root) + recon_dirs = sorted(_find_cluster_recon_dirs(root, args.recon_name)) + if not recon_dirs: + raise FileNotFoundError(f"No reconstructions named '{args.recon_name}' with images.txt under {root}") + + logger.info("Found %d reconstructions under %s", len(recon_dirs), root) + + csv_rows: List[Dict[str, str]] = [] + all_pose_auc_values: Dict[str, List[float]] = {} + all_pose_auc_by_label_and_count: Dict[str, 
List[Tuple[int, float]]] = {} + all_rotation_auc_values: Dict[str, List[float]] = {} + all_translation_auc_values: Dict[str, List[float]] = {} + all_intrinsics_deltas: Dict[str, List[float]] = { + "delta_fx_norm": [], + "delta_fy_norm": [], + "delta_cx_norm": [], + "delta_cy_norm": [], + } + all_fov_deltas: Dict[str, List[float]] = { + "delta_fovx_deg": [], + "delta_fovy_deg": [], + } + rng = np.random.default_rng(args.robust_sim3_seed) + for recon_dir in recon_dirs: + current_images = recon_dir / "images.txt" + current_cameras_txt = recon_dir / "cameras.txt" + current_poses, current_camera_by_name = _read_images_txt_with_names_and_cameras(current_images) + try: + current_cameras = _read_cameras_txt_with_ids(current_cameras_txt) + except FileNotFoundError: + logger.warning("Missing cameras.txt for %s; skipping intrinsics comparison.", recon_dir) + current_cameras = {} + common_names, baseline_list, current_list = _build_pose_lists( + baseline_poses, current_poses, cluster_label=str(recon_dir) + ) + baseline_count = len(baseline_poses) + current_count = len(current_poses) + common_count = len(common_names) + if len(common_names) < 2: + logger.warning( + "Skipping %s (baseline=%d, current=%d, common=%d)", + recon_dir, + baseline_count, + current_count, + common_count, + ) + continue + current_aligned_list, _aSb = _align_poses( + baseline_list, + current_list, + use_ransac=args.robust_sim3, + max_hypotheses=args.robust_sim3_max_hypotheses, + inlier_thresh=args.robust_sim3_inlier_thresh, + rng=rng, + cluster_label=str(recon_dir), + ) + metrics_group = _compute_pose_metrics(baseline_list, current_aligned_list) + _summarize_pose_errors(baseline_list, current_aligned_list, str(recon_dir)) + intrinsics_deltas: Dict[str, List[float]] = { + "delta_fx_norm": [], + "delta_fy_norm": [], + "delta_cx_norm": [], + "delta_cy_norm": [], + } + fov_deltas: Dict[str, List[float]] = { + "delta_fovx_deg": [], + "delta_fovy_deg": [], + } + for name in common_names: + base_cam_id = 
baseline_camera_by_name.get(name) + curr_cam_id = current_camera_by_name.get(name) + if base_cam_id is None or curr_cam_id is None: + continue + base = baseline_cameras.get(base_cam_id) + curr = current_cameras.get(curr_cam_id) + if base is None or curr is None: + continue + base_w, base_h = base["width"], base["height"] + curr_w, curr_h = curr["width"], curr["height"] + if curr_w > 0 and curr_h > 0 and (base_w != curr_w or base_h != curr_h): + sx = base_w / curr_w + sy = base_h / curr_h + curr_fx = curr["fx"] * sx + curr_fy = curr["fy"] * sy + curr_cx = curr["cx"] * sx + curr_cy = curr["cy"] * sy + else: + curr_fx = curr["fx"] + curr_fy = curr["fy"] + curr_cx = curr["cx"] + curr_cy = curr["cy"] + if base_w > 0 and base_h > 0: + if base["fx"] != 0: + intrinsics_deltas["delta_fx_norm"].append(abs(base["fx"] - curr_fx) / abs(base["fx"])) + if base["fy"] != 0: + intrinsics_deltas["delta_fy_norm"].append(abs(base["fy"] - curr_fy) / abs(base["fy"])) + if base["cx"] != 0: + intrinsics_deltas["delta_cx_norm"].append(abs(base["cx"] - curr_cx) / abs(base["cx"])) + if base["cy"] != 0: + intrinsics_deltas["delta_cy_norm"].append(abs(base["cy"] - curr_cy) / abs(base["cy"])) + base_fovx = 2.0 * np.degrees(np.arctan(base_w / (2.0 * base["fx"]))) + base_fovy = 2.0 * np.degrees(np.arctan(base_h / (2.0 * base["fy"]))) + curr_fovx = 2.0 * np.degrees(np.arctan(base_w / (2.0 * curr_fx))) + curr_fovy = 2.0 * np.degrees(np.arctan(base_h / (2.0 * curr_fy))) + fov_deltas["delta_fovx_deg"].append(abs(base_fovx - curr_fovx)) + fov_deltas["delta_fovy_deg"].append(abs(base_fovy - curr_fovy)) + for key, values in intrinsics_deltas.items(): + all_intrinsics_deltas[key].extend(values) + for key, values in fov_deltas.items(): + all_fov_deltas[key].extend(values) + if args.csv_output: + export_metrics_group_to_csv( + metrics_group, + cluster_label=str(recon_dir), + baseline_count=baseline_count, + current_count=current_count, + common_count=common_count, + output_path=Path(args.csv_output), + 
rows=csv_rows, + ) + else: + _print_metrics(str(recon_dir), metrics_group) + if fig_output_dir is not None: + safe_name = str(recon_dir).replace(os.sep, "__") + plot_path = fig_output_dir / f"{safe_name}_camera_centers.png" + pose_auc_text = _format_auc(metrics_group, "pose_auc") + rotation_auc_text = _format_auc(metrics_group, "rotation_auc") + translation_auc_text = _format_auc(metrics_group, "translation_auc") + title_lines = [f"{recon_dir}"] + if pose_auc_text: + title_lines.append(f"Pose AUC: {pose_auc_text}") + if rotation_auc_text: + title_lines.append(f"Rotation AUC: {rotation_auc_text}") + if translation_auc_text: + title_lines.append(f"Translation AUC: {translation_auc_text}") + title = "\n".join(title_lines) + _plot_camera_centers(baseline_list, current_aligned_list, plot_path, title) + # Intrinsics stats are annotated in the plot; no terminal logging. + for metric in metrics_group.metrics: + if metric.name.startswith("pose_auc_@") and metric.data is not None: + try: + value = float(metric.data) + except (TypeError, ValueError): + continue + label = metric.name.replace("pose_auc_", "") + all_pose_auc_values.setdefault(label, []).append(value) + all_pose_auc_by_label_and_count.setdefault(label, []).append((current_count, value)) + elif metric.name.startswith("rotation_auc_@") and metric.data is not None: + try: + value = float(metric.data) + except (TypeError, ValueError): + continue + label = metric.name.replace("rotation_auc_", "") + all_rotation_auc_values.setdefault(label, []).append(value) + elif metric.name.startswith("translation_auc_@") and metric.data is not None: + try: + value = float(metric.data) + except (TypeError, ValueError): + continue + label = metric.name.replace("translation_auc_", "") + all_translation_auc_values.setdefault(label, []).append(value) + + if args.csv_output and csv_rows: + output_path = Path(args.csv_output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("a", newline="") as csvfile: + 
writer = csv.DictWriter( + csvfile, + fieldnames=["cluster", "baseline_count", "current_count", "common_count", "metric_name", "value"], + ) + writer.writerows(csv_rows) + if fig_output_dir is not None and all_pose_auc_values: + auc_plot_path = fig_output_dir / "pose_auc_boxplot_all_clusters.png" + _plot_pose_auc_boxplot(all_pose_auc_values, auc_plot_path, "Pose AUC by threshold (all clusters)") + if fig_output_dir is not None and all_pose_auc_by_label_and_count: + auc_vs_images_plot_path = fig_output_dir / "pose_auc_vs_input_images.png" + _plot_pose_auc_vs_input_images(all_pose_auc_by_label_and_count, auc_vs_images_plot_path) + if fig_output_dir is not None and all_rotation_auc_values: + rotation_auc_plot_path = fig_output_dir / "rotation_auc_boxplot_all_clusters.png" + _plot_pose_auc_boxplot(all_rotation_auc_values, rotation_auc_plot_path, "Rotation AUC by threshold (all clusters)") + if fig_output_dir is not None and all_translation_auc_values: + translation_auc_plot_path = fig_output_dir / "translation_auc_boxplot_all_clusters.png" + _plot_pose_auc_boxplot( + all_translation_auc_values, + translation_auc_plot_path, + "Translation AUC by threshold (all clusters)", + ) + if fig_output_dir is not None and any(all_intrinsics_deltas.values()): + intrinsics_plot_path = fig_output_dir / "intrinsics_deltas_all_clusters.png" + _plot_intrinsics_deltas_boxplot( + all_intrinsics_deltas, + intrinsics_plot_path, + "Intrinsics Δ (normalized, all clusters)", + ) + if fig_output_dir is not None and any(all_fov_deltas.values()): + fov_plot_path = fig_output_dir / "fov_deltas_all_clusters.png" + _plot_fov_deltas_boxplot( + all_fov_deltas, + fov_plot_path, + "FOV Δ (degrees, all clusters)", + ) + + +if __name__ == "__main__": + main() + From 883d601c4117ca6e0be8354f0ef9bedf0f2ce812 Mon Sep 17 00:00:00 2001 From: Akshay Krishnan Date: Tue, 27 Jan 2026 10:11:38 -0500 Subject: [PATCH 10/24] merging updates --- gtsfm/cluster_merging.py | 24 ++++++++++++++++-------- 1 file changed, 16 
insertions(+), 8 deletions(-) diff --git a/gtsfm/cluster_merging.py b/gtsfm/cluster_merging.py index 9cc56c33c..6ab654c47 100644 --- a/gtsfm/cluster_merging.py +++ b/gtsfm/cluster_merging.py @@ -39,9 +39,7 @@ def _create_unary_measurements(scene: GtsfmData) -> list[UnaryMeasurementPose3]: # TODO(akshay-krishnan): investigate using a scene-dependent noise model # perhaps * np.exp(-len(scene.get_valid_camera_indices()) / 100.0) - noise_model = gtsam.noiseModel.Diagonal.Sigmas( - np.array([1e-2, 1e-2, 1e-2, 1e-1, 1e-1, 1e-1]) - ) + noise_model = gtsam.noiseModel.Diagonal.Sigmas(np.array([1e-2, 1e-2, 1e-2, 1e-1, 1e-1, 1e-1])) unary_measurements = [] for i, camera in scene.get_camera_poses().items(): if camera is None: @@ -277,9 +275,7 @@ def _get_pose_metrics( aligned_result_data = result_data.align_via_sim3_and_transform(poses_gt) return metrics_utils.compute_ba_pose_metrics( - gt_wTi=poses_gt, - computed_wTi=aligned_result_data.get_camera_poses(), - save_dir=save_dir, + gt_wTi=poses_gt, computed_wTi=aligned_result_data.get_camera_poses(), save_dir=save_dir, store_full_data=True ) @@ -386,6 +382,8 @@ def _drop_outlier_tracks(scene: GtsfmData) -> GtsfmData: Returns: The scene with outlier tracks dropped. """ + if scene.number_tracks() == 0: + return scene track_errors: list[float] = [] tracks = scene.tracks() cameras = scene.cameras() @@ -538,10 +536,12 @@ def _finalize_result(result_scene: Optional[GtsfmData]) -> MergedNodeResult: _propagate_scene_metadata(merged, metadata_source) - if drop_outlier_after_camera_merging and merged is not None and merged.number_tracks() > 0: - merged = _drop_outlier_tracks(merged) + if merged is None: + return _finalize_result(None) if not run_bundle_adjustment_on_parent: + if drop_outlier_after_camera_merging: + merged = _drop_outlier_tracks(merged) return _finalize_result(merged) # Log cameras that have no supporting track measurements before running BA. 
@@ -560,6 +560,14 @@ def _finalize_result(result_scene: Optional[GtsfmData]) -> MergedNodeResult: "merged result (with ba)", plot_histograms=plot_reprojection_histograms, ) + if drop_outlier_after_camera_merging: + merged_with_ba = _drop_outlier_tracks(merged_with_ba) + _log_scene_reprojection_stats( + merged_with_ba, + "merged result (with ba + outlier filtering)", + plot_histograms=plot_reprojection_histograms, + ) + # TODO: the order here is different from the merging order above, we should fix this. if merged.has_gaussian_splats(): logger.info("🫱🏻‍🫲🏽 Merging Gaussians") From 806606a0edfa46b6653a70b18ed7141d07678b5d Mon Sep 17 00:00:00 2001 From: Akshay Krishnan Date: Tue, 27 Jan 2026 10:13:13 -0500 Subject: [PATCH 11/24] save more metrics for debugging --- .../averaging/translation/averaging_1dsfm.py | 10 +-- gtsfm/bundle/bundle_adjustment.py | 10 ++- gtsfm/cluster_optimizer/cluster_vggt.py | 75 ++++++++++++++++--- gtsfm/common/gtsfm_data.py | 15 +++- gtsfm/evaluation/metrics.py | 2 +- gtsfm/frontend/vggt.py | 14 +++- gtsfm/runner.py | 2 +- gtsfm/utils/metrics.py | 14 +++- 8 files changed, 112 insertions(+), 30 deletions(-) diff --git a/gtsfm/averaging/translation/averaging_1dsfm.py b/gtsfm/averaging/translation/averaging_1dsfm.py index 3a121d5a4..52d101a07 100644 --- a/gtsfm/averaging/translation/averaging_1dsfm.py +++ b/gtsfm/averaging/translation/averaging_1dsfm.py @@ -51,7 +51,7 @@ MAX_PROJECTION_DIRECTIONS = 2000 OUTLIER_WEIGHT_THRESHOLD = 0.125 -NOISE_MODEL_DIMENSION = 3 # chordal distances on Unit3 +NOISE_MODEL_DIMENSION = 2 # chordal distances on Unit3 NOISE_MODEL_SIGMA = 0.01 HUBER_LOSS_K = 1.3 # default value from GTSAM @@ -73,7 +73,7 @@ L = symbol_shorthand.B # for track (landmark) translation variables RelativeDirectionsDict = AnnotatedGraph[Unit3] -DUMMY_NOISE_MODEL = gtsam.noiseModel.Isotropic.Sigma(3, 1e-2) # MFAS does not use this. +DUMMY_NOISE_MODEL = gtsam.noiseModel.Isotropic.Sigma(2, 1e-2) # MFAS does not use this. 
class TranslationAveraging1DSFM(TranslationAveragingBase): @@ -465,9 +465,9 @@ def __run_averaging( ) noise_model = gtsam.noiseModel.Isotropic.Sigma(NOISE_MODEL_DIMENSION, NOISE_MODEL_SIGMA) - if self._robust_measurement_noise: - huber_loss = gtsam.noiseModel.mEstimator.Huber.Create(HUBER_LOSS_K) - noise_model = gtsam.noiseModel.Robust.Create(huber_loss, noise_model) + # if self._robust_measurement_noise: + # huber_loss = gtsam.noiseModel.mEstimator.Huber.Create(HUBER_LOSS_K) + # noise_model = gtsam.noiseModel.Robust.Create(huber_loss, noise_model) w_i1Ui2_measurements = self._binary_measurements_from_dict(w_i2Ui1_dict, w_i2Ui1_dict_tracks, noise_model) diff --git a/gtsfm/bundle/bundle_adjustment.py b/gtsfm/bundle/bundle_adjustment.py index ae0a5616f..8d3200666 100644 --- a/gtsfm/bundle/bundle_adjustment.py +++ b/gtsfm/bundle/bundle_adjustment.py @@ -255,12 +255,18 @@ def __construct_simple_factor_graph( first_camera = initial_data.get_camera(cameras_to_model[0]) assert first_camera is not None, "First camera in initial data is None" graph.push_back( - PriorFactorPose3( + gtsam.NonlinearEqualityPose3( X(cameras_to_model[0]), first_camera.pose(), - Isotropic.Sigma(CAM_POSE3_DOF, self._cam_pose3_prior_noise_sigma), ) ) + # graph.push_back( + # PriorFactorPose3( + # X(cameras_to_model[0]), + # first_camera.pose(), + # Isotropic.Sigma(CAM_POSE3_DOF, self._cam_pose3_prior_noise_sigma), + # ) + # ) if initial_data.number_tracks() > 0: graph.push_back( diff --git a/gtsfm/cluster_optimizer/cluster_vggt.py b/gtsfm/cluster_optimizer/cluster_vggt.py index c74ea9e2e..2f213541b 100644 --- a/gtsfm/cluster_optimizer/cluster_vggt.py +++ b/gtsfm/cluster_optimizer/cluster_vggt.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Any, Hashable, Optional, Union +from gtsam import Pose3 import numpy as np import torch import torch.nn.functional as F @@ -12,12 +13,14 @@ import gtsfm.frontend.vggt as vggt from gtsfm.cluster_optimizer.cluster_optimizer_base import 
ClusterComputationGraph, ClusterContext, ClusterOptimizerBase +import gtsfm.common.types as gtsfm_types from gtsfm.common.gtsfm_data import GtsfmData from gtsfm.evaluation.metrics import GtsfmMetric, GtsfmMetricsGroup from gtsfm.frontend.vggt import VggtConfiguration, VggtReconstruction from gtsfm.products.visibility_graph import visibility_graph_keys from gtsfm.ui.gtsfm_process import UiMetadata from gtsfm.utils.logger import get_logger +import gtsfm.utils.metrics as metrics_utils logger = get_logger() @@ -106,18 +109,57 @@ def _save_pre_ba_reconstruction_as_text( _save_reconstruction_as_text(pre_ba_result, results_path, subdir="vggt_pre_ba") -def _aggregate_vggt_metrics(result: GtsfmData) -> GtsfmMetricsGroup: - num_cameras = len(result.get_valid_camera_indices()) - num_points3d = result.number_tracks() - return GtsfmMetricsGroup( - "vggt_runtime_metrics", - [ - GtsfmMetric("num_cameras", num_cameras), - GtsfmMetric("num_points3d", num_points3d), - ], +def _get_pose_metrics( + result_data: GtsfmData, + cameras_gt: list[Optional[gtsfm_types.CAMERA_TYPE]], + save_dir: Optional[str] = None, +) -> GtsfmMetricsGroup: + """Compute pose metrics for a VGGT result after aligning with ground truth.""" + image_idxs = list(result_data._image_info.keys()) + poses_gt: dict[int, Pose3] = {} + for i in image_idxs: + if i >= len(cameras_gt): + continue + camera = cameras_gt[i] + if camera is not None: + poses_gt[i] = camera.pose() + if len(poses_gt) == 0: + return GtsfmMetricsGroup(name="ba_pose_error_metrics", metrics=[]) + aligned_result_data = result_data.align_via_sim3_and_transform(poses_gt) + computed_wTi: dict[int, Optional[Pose3]] = {i: pose for i, pose in aligned_result_data.get_camera_poses().items()} + return metrics_utils.compute_ba_pose_metrics( + gt_wTi=poses_gt, + computed_wTi=computed_wTi, + save_dir=save_dir, + store_full_data=True, ) +def _aggregate_vggt_metrics( + result: GtsfmData, + cameras_gt: Optional[list[Optional[gtsfm_types.CAMERA_TYPE]]] = None, + 
pre_ba_result: Optional[GtsfmData] = None, + *, + save_dir: Optional[str] = None, +) -> list[GtsfmMetricsGroup]: + def _build_metrics_group(scene: GtsfmData, name: str) -> GtsfmMetricsGroup: + metrics_group = GtsfmMetricsGroup( + name, + [ + GtsfmMetric("num_cameras", len(scene.get_valid_camera_indices())), + GtsfmMetric("num_points3d", scene.number_tracks()), + ], + ) + if cameras_gt is not None: + metrics_group.extend(_get_pose_metrics(scene, cameras_gt, save_dir=save_dir)) + return metrics_group + + metrics_groups = [_build_metrics_group(result, "cluster_vggt_metrics")] + if pre_ba_result is not None: + metrics_groups.append(_build_metrics_group(pre_ba_result, "cluster_vggt_pre_ba_metrics")) + return metrics_groups + + def _extract_post_ba_result(result: VggtReconstruction) -> GtsfmData: """Extract the post-BA reconstruction from the VGGT pipeline output.""" return result.gtsfm_data @@ -310,8 +352,17 @@ def create_computation_graph( cluster_label=context.label, ) result_graph = delayed(_extract_post_ba_result)(reconstruction_graph) - - metrics_tasks = [delayed(_aggregate_vggt_metrics)(result_graph)] + pre_ba_result_graph = delayed(_extract_pre_ba_result)(reconstruction_graph) + + cameras_gt = [context.one_view_data_dict[idx].camera_gt for idx in range(context.num_images)] + metrics_tasks = [ + delayed(_aggregate_vggt_metrics)( + result_graph, + cameras_gt=cameras_gt, + pre_ba_result=pre_ba_result_graph, + save_dir=str(context.output_paths.metrics), + ) + ] io_tasks: list[Delayed] = [] with self._output_annotation(): @@ -323,7 +374,7 @@ def create_computation_graph( ) io_tasks.append( delayed(_save_pre_ba_reconstruction_as_text)( - delayed(_extract_pre_ba_result)(reconstruction_graph), + pre_ba_result_graph, context.output_paths.results, ) ) diff --git a/gtsfm/common/gtsfm_data.py b/gtsfm/common/gtsfm_data.py index 6b8d48f22..da8c69275 100644 --- a/gtsfm/common/gtsfm_data.py +++ b/gtsfm/common/gtsfm_data.py @@ -540,6 +540,14 @@ def get_tracks(self) -> 
List[SfmTrack]: """Returns all tracks.""" return self._tracks + def update_camera_pose(self, index: int, pose: Pose3) -> None: + """Updates the pose of a camera at index.""" + if index not in self._cameras: + raise ValueError(f"Camera at index {index} not found") + K = self._cameras[index].calibration() + new_camera = gtsfm_types.get_camera_class_for_calibration(K)(pose, K) + self._cameras[index] = new_camera + def add_camera(self, index: int, camera: gtsfm_types.CAMERA_TYPE) -> None: """Adds camera at index if not already present.""" if camera is None: @@ -838,11 +846,12 @@ def align_via_sim3_and_transform(self, aTi: dict[int, Pose3]) -> "GtsfmData": def get_metrics(self, suffix: str, store_full_data: bool = False) -> List[GtsfmMetric]: """Helper to get bundle adjustment metrics from a GtsfmData object with a suffix for metric names.""" metrics = [] - metrics.append(GtsfmMetric(name="number_cameras", data=len(self.get_valid_camera_indices()))) - metrics.append(GtsfmMetric("number_tracks" + suffix, self.number_tracks())) + metrics.append(GtsfmMetric(name=f"number_images{suffix}", data=self.number_images())) + metrics.append(GtsfmMetric(name=f"number_cameras{suffix}", data=len(self.get_valid_camera_indices()))) + metrics.append(GtsfmMetric(name=f"number_tracks{suffix}", data=self.number_tracks())) metrics.append( GtsfmMetric( - name="3d_track_lengths" + suffix, + name=f"3d_track_lengths{suffix}", data=self.get_track_lengths(), plot_type=GtsfmMetric.PlotType.HISTOGRAM, store_full_data=store_full_data, diff --git a/gtsfm/evaluation/metrics.py b/gtsfm/evaluation/metrics.py index 470ad6bb5..23b4b3f9f 100644 --- a/gtsfm/evaluation/metrics.py +++ b/gtsfm/evaluation/metrics.py @@ -208,7 +208,7 @@ def get_metric_as_dict(self) -> Dict[str, Any]: The metric as a dict representation explained above. 
""" if self._dim == 0: - return {self._name: self._data.tolist()} + return {self._name: round(self._data.tolist(), 4)} metric_dict = {SUMMARY_KEY: self.summary} if self._data is not None: metric_dict[FULL_DATA_KEY] = self._data.tolist() diff --git a/gtsfm/frontend/vggt.py b/gtsfm/frontend/vggt.py index efb9bc869..df5ef7f05 100644 --- a/gtsfm/frontend/vggt.py +++ b/gtsfm/frontend/vggt.py @@ -451,8 +451,12 @@ def _convert_vggt_outputs_to_gtsfm_data( inlier_num = track_mask.sum(0) valid_mask = inlier_num >= 2 # a track is invalid if without two inliers + confidence_threshold = config.confidence_threshold + confidence_threshold = min( + confidence_threshold, np.mean(tracking_result.confidences) + np.std(tracking_result.confidences) + ) if tracking_result.confidences is not None: - valid_mask = np.logical_and(valid_mask, tracking_result.confidences > config.confidence_threshold) + valid_mask = np.logical_and(valid_mask, tracking_result.confidences > confidence_threshold) valid_idx = np.nonzero(valid_mask)[0] max_reproj_error = float(config.max_reproj_error) @@ -460,6 +464,8 @@ def _convert_vggt_outputs_to_gtsfm_data( tracking_result.points_3d is not None and np.isfinite(max_reproj_error) and max_reproj_error > 0.0 ) + logger.info("num points 3d: %d, num valid idx: %d", tracking_result.points_3d.shape[0], len(valid_idx)) + for valid_id in valid_idx: rgb: np.ndarray if tracking_result.colors is not None and valid_id < tracking_result.colors.shape[0]: @@ -492,12 +498,12 @@ def _convert_vggt_outputs_to_gtsfm_data( proj_v = float(projected[1]) reproj_err = float(np.hypot(rescaled_u - proj_u, rescaled_v - proj_v)) max_error_for_track = max(max_error_for_track, reproj_err) + # if reproj_err > max_reproj_error: + # continue per_track_measurements.append((global_idx, rescaled_u, rescaled_v)) if len(per_track_measurements) < 2: continue - if enforce_reproj_filter and max_error_for_track > max_reproj_error: - continue track = torch_utils.colored_track_from_point(point_xyz, rgb) 
for global_idx, rescaled_u, rescaled_v in per_track_measurements: @@ -515,7 +521,7 @@ def _convert_vggt_outputs_to_gtsfm_data( gtsfm_data, should_run_ba = data_utils.remove_cameras_with_no_tracks(gtsfm_data, "node-level BA") if not should_run_ba: return gtsfm_data, gtsfm_data_pre_ba - optimizer = BundleAdjustmentOptimizer() + optimizer = BundleAdjustmentOptimizer(robust_measurement_noise=False, calibration_prior_noise_sigma=10) gtsfm_data_with_ba, _ = optimizer.run_simple_ba(gtsfm_data, verbose=False) return gtsfm_data_with_ba, gtsfm_data_pre_ba except Exception as exc: diff --git a/gtsfm/runner.py b/gtsfm/runner.py index fda8540b8..83c33ce04 100644 --- a/gtsfm/runner.py +++ b/gtsfm/runner.py @@ -150,7 +150,7 @@ def construct_argparser(self) -> argparse.ArgumentParser: ) parser.add_argument("--threads_per_worker", type=int, default=1, help="Number of threads per each worker.") parser.add_argument( - "--worker_memory_limit", type=str, default="16GB", help="Memory limit per worker, e.g. `16GB`" + "--worker_memory_limit", type=str, default="32GB", help="Memory limit per worker, e.g. `16GB`" ) parser.add_argument("--dashboard_port", type=str, default=":8787", help="dask dashboard port number") parser.add_argument( diff --git a/gtsfm/utils/metrics.py b/gtsfm/utils/metrics.py index a58294fd4..945a74bfd 100644 --- a/gtsfm/utils/metrics.py +++ b/gtsfm/utils/metrics.py @@ -444,6 +444,7 @@ def compute_ba_pose_metrics( gt_wTi: dict[int, Pose3], computed_wTi: dict[int, Optional[Pose3]], save_dir: Optional[str] = None, + store_full_data: bool = False, ) -> GtsfmMetricsGroup: """Compute pose errors w.r.t. GT for the bundle adjustment result. @@ -453,6 +454,7 @@ def compute_ba_pose_metrics( gt_wTi: Dict of ground truth poses keyed by camera id. computed_wTi: Dict of computed poses keyed by camera id. save_dir: Directory to save the metrics plots. + store_full_data: Whether to store full data. 
Returns: A group of metrics that describe errors associated with a bundle adjustment result (w.r.t. GT). @@ -477,12 +479,20 @@ def compute_ba_pose_metrics( gt_wTi_opt: dict[int, Optional[Pose3]] = {i: pose for i, pose in gt_wTi.items()} translation_angular_errors = get_relative_translation_angles(i2Ui1_gt_opt, computed_wTi_opt, include_none=True) metrics.append( - GtsfmMetric("relative_translation_angle_error_deg", np.array(translation_angular_errors, dtype=np.float32)) + GtsfmMetric( + "relative_translation_angle_error_deg", + np.array(translation_angular_errors, dtype=np.float32), + store_full_data=store_full_data, + ) ) metrics.append(compute_translation_angle_metric(gt_wTi_opt, computed_wTi_opt)) rotation_angular_errors = get_relative_rotation_angles(i2Ri1_gt_opt, computed_wTi_opt, include_none=True) metrics.append( - GtsfmMetric("relative_rotation_angle_error_deg", np.array(rotation_angular_errors, dtype=np.float32)) + GtsfmMetric( + "relative_rotation_angle_error_deg", + np.array(rotation_angular_errors, dtype=np.float32), + store_full_data=store_full_data, + ) ) metrics.extend(compute_pose_auc_metric(rotation_angular_errors, translation_angular_errors, save_dir=save_dir)) From 8ce3153e0500798fca389e6469ac26aeeb711bd4 Mon Sep 17 00:00:00 2001 From: nantonzhang Date: Tue, 27 Jan 2026 20:48:16 -0500 Subject: [PATCH 12/24] more detailed eval code --- gtsfm/evaluation/compare_colmap_outputs.py | 19 +++++++++---------- .../compare_colmap_outputs_by_cluster.py | 19 ++++++++++--------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/gtsfm/evaluation/compare_colmap_outputs.py b/gtsfm/evaluation/compare_colmap_outputs.py index 66b8016c1..a84fd3145 100644 --- a/gtsfm/evaluation/compare_colmap_outputs.py +++ b/gtsfm/evaluation/compare_colmap_outputs.py @@ -194,25 +194,25 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) current_wTi_list.append(current_wTi_dict.get(fname)) if not args.use_pycolmap_alignment: - aSb = 
align.sim3_from_optional_Pose3s(baseline_wTi_list, current_wTi_list) + aSb = align.sim3_from_Pose3_maps_robust(baseline_wTi_dict, current_wTi_dict) current_wTi_list = transform.optional_Pose3s_with_sim3(aSb, current_wTi_list) current_wTi_dict = {fname: aSb.transformFrom(pose) for fname, pose in current_wTi_dict.items()} - i2Ri1_dict_gt, i2Ui1_dict_gt = metric_utils.get_all_relative_rotations_translations(baseline_wTi_list) + i2Ri1_dict_gt, i2Ui1_dict_gt = metric_utils.get_all_relative_rotations_translations(baseline_wTi_dict) - wRi_aligned_list, wti_aligned_list = metric_utils.get_rotations_translations_from_poses(current_wTi_list) - baseline_wRi_list, baseline_wti_list = metric_utils.get_rotations_translations_from_poses(baseline_wTi_list) + wRi_aligned_dict, wti_aligned_dict = metric_utils.get_rotations_translations_from_poses(current_wTi_dict) + baseline_wRi_dict, baseline_wti_dict = metric_utils.get_rotations_translations_from_poses(baseline_wTi_dict) metrics = [] - metrics.append(metric_utils.compute_rotation_angle_metric(wRi_aligned_list, baseline_wRi_list)) - metrics.append(metric_utils.compute_translation_distance_metric(wti_aligned_list, baseline_wti_list)) - metrics.append(metric_utils.compute_translation_angle_metric(baseline_wTi_list, current_wTi_list)) + metrics.append(metric_utils.compute_rotation_angle_metric(wRi_aligned_dict, baseline_wRi_dict)) + metrics.append(metric_utils.compute_translation_distance_metric(wti_aligned_dict, baseline_wti_dict)) + metrics.append(metric_utils.compute_translation_angle_metric(baseline_wTi_dict, current_wTi_dict)) relative_rotation_error_metric = metric_utils.compute_relative_rotation_angle_metric( - i2Ri1_dict_gt, current_wTi_list, store_full_data=True + i2Ri1_dict_gt, current_wTi_dict, store_full_data=True ) metrics.append(relative_rotation_error_metric) relative_translation_error_metric = metric_utils.compute_relative_translation_angle_metric( - i2Ui1_dict_gt, current_wTi_list, store_full_data=True + 
i2Ui1_dict_gt, current_wTi_dict, store_full_data=True ) metrics.append(relative_translation_error_metric) @@ -260,4 +260,3 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) ba_pose_metrics = compare_poses(args.baseline, args.current, args.output) export_metrics_group_to_csv(ba_pose_metrics, os.path.join(args.output, f"{ba_pose_metrics.name}.csv")) - diff --git a/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py b/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py index 869ebbcef..fd0c2a20d 100644 --- a/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py +++ b/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py @@ -176,20 +176,22 @@ def _build_pose_lists( def _compute_pose_metrics(baseline_list: List[Pose3], current_aligned_list: List[Pose3]) -> GtsfmMetricsGroup: """Compute the same pose metrics as compare_colmap_outputs, without plotting.""" - i2Ri1_dict_gt, i2Ui1_dict_gt = metric_utils.get_all_relative_rotations_translations(baseline_list) - wRi_aligned_list, wti_aligned_list = metric_utils.get_rotations_translations_from_poses(current_aligned_list) - baseline_wRi_list, baseline_wti_list = metric_utils.get_rotations_translations_from_poses(baseline_list) + baseline_dict = {i: pose for i, pose in enumerate(baseline_list)} + current_dict = {i: pose for i, pose in enumerate(current_aligned_list)} + i2Ri1_dict_gt, i2Ui1_dict_gt = metric_utils.get_all_relative_rotations_translations(baseline_dict) + wRi_aligned_dict, wti_aligned_dict = metric_utils.get_rotations_translations_from_poses(current_dict) + baseline_wRi_dict, baseline_wti_dict = metric_utils.get_rotations_translations_from_poses(baseline_dict) metrics = [] - metrics.append(metric_utils.compute_rotation_angle_metric(wRi_aligned_list, baseline_wRi_list)) - metrics.append(metric_utils.compute_translation_distance_metric(wti_aligned_list, baseline_wti_list)) - metrics.append(metric_utils.compute_translation_angle_metric(baseline_list, current_aligned_list)) + 
metrics.append(metric_utils.compute_rotation_angle_metric(wRi_aligned_dict, baseline_wRi_dict)) + metrics.append(metric_utils.compute_translation_distance_metric(wti_aligned_dict, baseline_wti_dict)) + metrics.append(metric_utils.compute_translation_angle_metric(baseline_dict, current_dict)) relative_rotation_error_metric = metric_utils.compute_relative_rotation_angle_metric( - i2Ri1_dict_gt, current_aligned_list, store_full_data=True + i2Ri1_dict_gt, current_dict, store_full_data=True ) metrics.append(relative_rotation_error_metric) relative_translation_error_metric = metric_utils.compute_relative_translation_angle_metric( - i2Ui1_dict_gt, current_aligned_list, store_full_data=True + i2Ui1_dict_gt, current_dict, store_full_data=True ) metrics.append(relative_translation_error_metric) thresholds_deg = (1.0, 2.5, 5.0, 10.0, 20.0) @@ -895,4 +897,3 @@ def main() -> None: if __name__ == "__main__": main() - From eceb70b27979df2873defc9513cecf7981ee3de7 Mon Sep 17 00:00:00 2001 From: nantonzhang Date: Tue, 27 Jan 2026 21:03:23 -0500 Subject: [PATCH 13/24] optimize cacher --- .../cluster_optimizer_base.py | 2 ++ .../cluster_optimizer_cacher.py | 26 ++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/gtsfm/cluster_optimizer/cluster_optimizer_base.py b/gtsfm/cluster_optimizer/cluster_optimizer_base.py index 3a2b3deb4..bf6496d38 100644 --- a/gtsfm/cluster_optimizer/cluster_optimizer_base.py +++ b/gtsfm/cluster_optimizer/cluster_optimizer_base.py @@ -102,6 +102,7 @@ def __init__( drop_outlier_after_camera_merging: bool = True, plot_reprojection_histograms: bool = True, run_bundle_adjustment_on_parent: bool = True, + run_bundle_adjustment_on_leaf: bool = False, output_worker: None | str = None, ) -> None: self.drop_child_if_merging_fail = drop_child_if_merging_fail @@ -109,6 +110,7 @@ def __init__( self.drop_outlier_after_camera_merging = drop_outlier_after_camera_merging self.plot_reprojection_histograms = plot_reprojection_histograms 
self.run_bundle_adjustment_on_parent = run_bundle_adjustment_on_parent + self.run_bundle_adjustment_on_leaf = run_bundle_adjustment_on_leaf self._pose_angular_error_thresh = pose_angular_error_thresh self._output_worker = output_worker diff --git a/gtsfm/cluster_optimizer/cluster_optimizer_cacher.py b/gtsfm/cluster_optimizer/cluster_optimizer_cacher.py index 4edad8d99..e320bdc67 100644 --- a/gtsfm/cluster_optimizer/cluster_optimizer_cacher.py +++ b/gtsfm/cluster_optimizer/cluster_optimizer_cacher.py @@ -42,8 +42,13 @@ def __init__(self, optimizer: ClusterOptimizerBase, cache_subdir: Optional[str] optimizer: cluster optimizer to use in case of cache miss. cache_subdir: Optional subdirectory (relative to cache root) for storing cache entries. """ + run_bundle_adjustment_on_leaf = getattr(optimizer, "run_bundle_adjustment_on_leaf", None) + if run_bundle_adjustment_on_leaf is None: + run_bundle_adjustment_on_leaf = getattr(optimizer, "_run_bundle_adjustment_on_leaf", False) super().__init__( pose_angular_error_thresh=optimizer.pose_angular_error_thresh, + run_bundle_adjustment_on_leaf=run_bundle_adjustment_on_leaf, + run_bundle_adjustment_on_parent=getattr(optimizer, "run_bundle_adjustment_on_parent", True), output_worker=optimizer._output_worker, ) self._optimizer = optimizer @@ -78,8 +83,13 @@ def __setstate__(self, state: dict[str, object]) -> None: self._cache_subdir = typing.cast(Optional[str], state.get("_cache_subdir")) self._cache_root_path = self._resolve_cache_root(self._cache_subdir) # Re-initialize the base class to mimic the constructor. 
+ run_bundle_adjustment_on_leaf = getattr(self._optimizer, "run_bundle_adjustment_on_leaf", None) + if run_bundle_adjustment_on_leaf is None: + run_bundle_adjustment_on_leaf = getattr(self._optimizer, "_run_bundle_adjustment_on_leaf", False) super().__init__( pose_angular_error_thresh=self._optimizer.pose_angular_error_thresh, + run_bundle_adjustment_on_leaf=run_bundle_adjustment_on_leaf, + run_bundle_adjustment_on_parent=getattr(self._optimizer, "run_bundle_adjustment_on_parent", True), output_worker=self._optimizer._output_worker, ) @@ -144,11 +154,25 @@ def _save_result_to_cache(self, result: GtsfmData, cache_path: Path) -> GtsfmDat io_utils.write_to_bz2_file(result, cache_path) return result + def _save_cached_result_outputs(self, result: GtsfmData, results_path: Path) -> None: + """Persist cached outputs expected by downstream tooling. + + Currently used to re-export VGGT reconstructions in COLMAP text format. + """ + if "VGGT" not in type(self._optimizer).__name__: + return + target_dir = results_path / "vggt" + target_dir.mkdir(parents=True, exist_ok=True) + result.export_as_colmap_text(target_dir) + def create_computation_graph(self, context: ClusterContext) -> ClusterComputationGraph | None: cached_result = self._load_result_from_cache(context) if cached_result is not None: cached_graph: Delayed = delayed(lambda r: r, pure=False)(cached_result) - return ClusterComputationGraph(io_tasks=tuple(), metric_tasks=tuple(), sfm_result=cached_graph) + io_tasks = ( + delayed(self._save_cached_result_outputs, pure=False)(cached_graph, context.output_paths.results), + ) + return ClusterComputationGraph(io_tasks=io_tasks, metric_tasks=tuple(), sfm_result=cached_graph) computation = self._optimizer.create_computation_graph(context) if computation is None or computation.sfm_result is None: From 86f568a027f6b8088181809133c0eef3ab583130 Mon Sep 17 00:00:00 2001 From: nantonzhang Date: Tue, 27 Jan 2026 21:04:08 -0500 Subject: [PATCH 14/24] skip use_nonlinear_sim3_alignment 
for now --- gtsfm/configs/vggt.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gtsfm/configs/vggt.yaml b/gtsfm/configs/vggt.yaml index 59d8e23bf..71fa0d5d0 100644 --- a/gtsfm/configs/vggt.yaml +++ b/gtsfm/configs/vggt.yaml @@ -47,3 +47,5 @@ cluster_optimizer: run_bundle_adjustment_on_leaf: false run_bundle_adjustment_on_parent: false model_cache_key: null + +# use_nonlinear_sim3_alignment: false \ No newline at end of file From 01d9c81ab95d5e1f3d4641c4d434adcbd9df5f1f Mon Sep 17 00:00:00 2001 From: Harneet Singh Khanuja Date: Thu, 5 Feb 2026 17:51:20 -0500 Subject: [PATCH 15/24] resolving comments --- gtsfm/cluster_optimizer/cluster_vggt.py | 10 +++++---- gtsfm/configs/vggt.yaml | 1 - gtsfm/frontend/vggt.py | 29 ++++++++++++------------- gtsfm/loader/loader_base.py | 2 +- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/gtsfm/cluster_optimizer/cluster_vggt.py b/gtsfm/cluster_optimizer/cluster_vggt.py index f98479009..8f7b8b7b3 100644 --- a/gtsfm/cluster_optimizer/cluster_vggt.py +++ b/gtsfm/cluster_optimizer/cluster_vggt.py @@ -5,22 +5,22 @@ from pathlib import Path from typing import Any, Hashable, Optional, Union -from gtsam import Pose3 import numpy as np import torch import torch.nn.functional as F from dask.delayed import Delayed, delayed +from gtsam import Pose3 +import gtsfm.common.types as gtsfm_types import gtsfm.frontend.vggt as vggt +import gtsfm.utils.metrics as metrics_utils from gtsfm.cluster_optimizer.cluster_optimizer_base import ClusterComputationGraph, ClusterContext, ClusterOptimizerBase -import gtsfm.common.types as gtsfm_types from gtsfm.common.gtsfm_data import GtsfmData from gtsfm.evaluation.metrics import GtsfmMetric, GtsfmMetricsGroup from gtsfm.frontend.vggt import VggtConfiguration, VggtReconstruction from gtsfm.products.visibility_graph import visibility_graph_keys from gtsfm.ui.gtsfm_process import UiMetadata from gtsfm.utils.logger import get_logger -import gtsfm.utils.metrics as metrics_utils logger = 
get_logger() @@ -326,8 +326,10 @@ def create_computation_graph( max_reproj_error=self._max_reproj_error, ) + # mode is fixed to "crop", it resizes the width to 518 while maintaining aspect ratio and only if + # height is > 518 then crops image_batch_graph, original_coords_graph = delayed(_load_vggt_inputs, nout=2)( - context.loader, global_indices, mode="crop" # mode is fixed to "crop" + context.loader, global_indices, mode="crop" ) reconstruction_graph = delayed(_run_vggt_pipeline)( diff --git a/gtsfm/configs/vggt.yaml b/gtsfm/configs/vggt.yaml index 71fa0d5d0..a1e93eb06 100644 --- a/gtsfm/configs/vggt.yaml +++ b/gtsfm/configs/vggt.yaml @@ -7,7 +7,6 @@ loader: _target_: gtsfm.loader.Olsson dataset_dir: ??? # Required: set to the dataset root on the command line. images_dir: null - max_resolution: 518 # VGGT recommended max resolution. Non editable. mode is fixed to "crop" image_pairs_generator: _target_: gtsfm.retriever.image_pairs_generator.ImagePairsGenerator diff --git a/gtsfm/frontend/vggt.py b/gtsfm/frontend/vggt.py index ed6b9c6fa..3ea1c1b12 100644 --- a/gtsfm/frontend/vggt.py +++ b/gtsfm/frontend/vggt.py @@ -377,8 +377,9 @@ def _convert_vggt_outputs_to_gtsfm_data( track_mask = tracking_result.visibilities > config.track_vis_thresh inlier_num = track_mask.sum(0) - valid_mask = inlier_num >= 2 # a track is invalid if without two inliers - confidence_threshold = config.confidence_threshold + min_measurements = 2 + valid_mask = inlier_num >= min_measurements # a track is invalid if without two inliers + confidence_threshold = config.track_conf_thresh confidence_threshold = min( confidence_threshold, np.mean(tracking_result.confidences) + np.std(tracking_result.confidences) ) @@ -416,18 +417,16 @@ def _convert_vggt_outputs_to_gtsfm_data( camera = gtsfm_data.get_camera(global_idx) if not _is_point_in_front_of_camera(camera, point_xyz): continue - float_u = float(u) - float_v = float(v) if enforce_reproj_filter: projected = camera.project(gtsam_point) proj_u = 
float(projected[0]) proj_v = float(projected[1]) - reproj_err = float(np.hypot(float_u - proj_u, float_v - proj_v)) + reproj_err = float(np.hypot(u - proj_u, v - proj_v)) max_error_for_track = max(max_error_for_track, reproj_err) - per_track_measurements.append((global_idx, float_u, float_v)) + per_track_measurements.append((global_idx, u, v)) - # if len(per_track_measurements) < min_measurements: - # continue + if len(per_track_measurements) < min_measurements: + continue track = torch_utils.colored_track_from_point(point_xyz, rgb) for global_idx, float_u, float_v in per_track_measurements: @@ -528,19 +527,19 @@ def run_VGGT( if depth_confidence.ndim == 4 and depth_confidence.shape[-1] == 1: depth_confidence = depth_confidence.squeeze(-1) - depth_map_fp32 = depth_map.squeeze(0).to(dtype=torch.float32) - extrinsic_fp32 = extrinsic.squeeze(0).to(dtype=torch.float32) - intrinsic_fp32 = intrinsic.squeeze(0).to(dtype=torch.float32) - dense_points_np = unproject_depth_map_to_point_map(depth_map_fp32, extrinsic_fp32, intrinsic_fp32) + depth_map = depth_map.squeeze(0).to(dtype=torch.float32) + extrinsic = extrinsic.squeeze(0).to(dtype=torch.float32) + intrinsic = intrinsic.squeeze(0).to(dtype=torch.float32) + dense_points_np = unproject_depth_map_to_point_map(depth_map, extrinsic, intrinsic) dense_points = torch.from_numpy(dense_points_np).to(device=resolved_device, dtype=torch.float32) return VggtOutput( device=resolved_device, dtype=resolved_dtype, images=images, - extrinsic=extrinsic.squeeze(0), - intrinsic=intrinsic.squeeze(0), - depth_map=depth_map.squeeze(0), + extrinsic=extrinsic, + intrinsic=intrinsic, + depth_map=depth_map, depth_confidence=depth_confidence, dense_points=dense_points, ) diff --git a/gtsfm/loader/loader_base.py b/gtsfm/loader/loader_base.py index 415ddef14..13e2cd89a 100644 --- a/gtsfm/loader/loader_base.py +++ b/gtsfm/loader/loader_base.py @@ -589,7 +589,7 @@ def load_image_batch_vggt_loader(self, indices: List[int], mode="crop"): # Check if we 
have different shapes # In theory our model can also work well with different shapes if len(shapes) > 1: - print(f"Warning: Found images with different shapes: {shapes}") + logger.warning("Found images with different shapes: %s", shapes) # Find maximum dimensions max_height = max(shape[0] for shape in shapes) max_width = max(shape[1] for shape in shapes) From 0943229a8d4249236db55be8c81704b237b46203 Mon Sep 17 00:00:00 2001 From: Harneet Singh Khanuja Date: Thu, 5 Feb 2026 17:55:20 -0500 Subject: [PATCH 16/24] adding 2 parameters in yaml file --- gtsfm/configs/vggt_megaloc.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gtsfm/configs/vggt_megaloc.yaml b/gtsfm/configs/vggt_megaloc.yaml index a7f80269e..2bff9d9bc 100644 --- a/gtsfm/configs/vggt_megaloc.yaml +++ b/gtsfm/configs/vggt_megaloc.yaml @@ -38,6 +38,8 @@ cluster_optimizer: tracking_query_frame_num: 3 tracking_fine_tracking: false track_vis_thresh: 0.2 + track_conf_thresh: 0.2 + max_reproj_error: 0 # 0.0 means no filtering based on reproj error camera_type: PINHOLE drop_outlier_after_camera_merging: false drop_child_if_merging_fail: true From de2242dc1a38d9a336752156cb83d2389fcf2c5d Mon Sep 17 00:00:00 2001 From: Harneet Singh Khanuja Date: Thu, 5 Feb 2026 18:07:53 -0500 Subject: [PATCH 17/24] Fixing dependencies --- environment_linux.yml | 2 +- environment_linux_cpuonly.yml | 2 +- environment_mac.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/environment_linux.yml b/environment_linux.yml index 46639b053..8852d6f95 100644 --- a/environment_linux.yml +++ b/environment_linux.yml @@ -56,7 +56,7 @@ dependencies: - pydegensac - colour - trimesh[easy] - - gtsam==4.3a0 + - gtsam-develop - pydot # dust3r/mast3r - roma diff --git a/environment_linux_cpuonly.yml b/environment_linux_cpuonly.yml index 03827ba55..5e7b5ff6e 100644 --- a/environment_linux_cpuonly.yml +++ b/environment_linux_cpuonly.yml @@ -56,7 +56,7 @@ dependencies: - pydegensac - colour - trimesh[easy] - - gtsam==4.3a0 + 
- gtsam-develop - pydot - roma - tqdm diff --git a/environment_mac.yml b/environment_mac.yml index 97303fb95..df85d2352 100644 --- a/environment_mac.yml +++ b/environment_mac.yml @@ -57,7 +57,7 @@ dependencies: - pydegensac - colour - trimesh[easy] - - gtsam==4.3a0 + - gtsam-develop - pydot # dust3r/mast3r - roma From 040b1087ab4641c3f34b16ff1d2de895d5d0a071 Mon Sep 17 00:00:00 2001 From: Harneet Singh Khanuja Date: Thu, 5 Feb 2026 23:32:03 -0500 Subject: [PATCH 18/24] Resolving comments --- create_tracks_viz.sh | 17 --- .../averaging/translation/averaging_1dsfm.py | 6 +- gtsfm/bundle/bundle_adjustment.py | 18 +-- gtsfm/cluster_optimizer/cluster_vggt.py | 2 +- gtsfm/configs/vggt.yaml | 2 - gtsfm/configs/vggt_megaloc.yaml | 4 +- .../compare_colmap_outputs_by_cluster.py | 36 +++-- gtsfm/frontend/vggt.py | 139 +++++++++++++++++- gtsfm/loader/loader_base.py | 134 ----------------- 9 files changed, 172 insertions(+), 186 deletions(-) delete mode 100644 create_tracks_viz.sh diff --git a/create_tracks_viz.sh b/create_tracks_viz.sh deleted file mode 100644 index 641a5ef77..000000000 --- a/create_tracks_viz.sh +++ /dev/null @@ -1,17 +0,0 @@ -# python gtsfm/visualization/visualize_tracks.py \ -# --result_root /coc/flash5/akrishnan86/gtsfm/outputs/metis_skydio32/results/ \ -# --loader_config colmap \ -# --dataset_dir /coc/flash5/akrishnan86/gtsfm/data/skydio32 - - - -# python gtsfm/visualization/visualize_tracks.py \ -# --result_root /coc/flash5/akrishnan86/gtsfm/outputs/metis_cm2_palace_0_4/results/ \ -# --loader_config olsson \ -# --dataset_dir /coc/flash5/akrishnan86/gtsfm/data/palace_fine_arts - -python gtsfm/visualization/visualize_tracks.py \ - --result_root /coc/flash5/akrishnan86/gtsfm/outputs/metis_vggt_palace/results/ \ - --loader_config olsson \ - --dataset_dir /coc/flash5/akrishnan86/gtsfm/data/palace_fine_arts \ - --line_only \ No newline at end of file diff --git a/gtsfm/averaging/translation/averaging_1dsfm.py b/gtsfm/averaging/translation/averaging_1dsfm.py 
index a8d0b8b38..2c4463176 100644 --- a/gtsfm/averaging/translation/averaging_1dsfm.py +++ b/gtsfm/averaging/translation/averaging_1dsfm.py @@ -465,9 +465,9 @@ def __run_averaging( ) noise_model = gtsam.noiseModel.Isotropic.Sigma(NOISE_MODEL_DIMENSION, NOISE_MODEL_SIGMA) - # if self._robust_measurement_noise: - # huber_loss = gtsam.noiseModel.mEstimator.Huber.Create(HUBER_LOSS_K) - # noise_model = gtsam.noiseModel.Robust.Create(huber_loss, noise_model) + if self._robust_measurement_noise: + huber_loss = gtsam.noiseModel.mEstimator.Huber.Create(HUBER_LOSS_K) + noise_model = gtsam.noiseModel.Robust.Create(huber_loss, noise_model) w_i1Ui2_measurements = self._binary_measurements_from_dict(w_i2Ui1_dict, w_i2Ui1_dict_tracks, noise_model) diff --git a/gtsfm/bundle/bundle_adjustment.py b/gtsfm/bundle/bundle_adjustment.py index 53e46c39e..d897a66d9 100644 --- a/gtsfm/bundle/bundle_adjustment.py +++ b/gtsfm/bundle/bundle_adjustment.py @@ -12,13 +12,7 @@ import gtsam # type: ignore import numpy as np from dask.delayed import Delayed -from gtsam import ( - BetweenFactorPose3, - NonlinearFactorGraph, - PriorFactorPose3, - PriorFactorPoint3, - Values, -) +from gtsam import BetweenFactorPose3, NonlinearFactorGraph, PriorFactorPoint3, PriorFactorPose3, Values from gtsam.noiseModel import Diagonal, Isotropic, Robust, mEstimator # type: ignore from gtsam.symbol_shorthand import K, P, X # type: ignore @@ -248,18 +242,12 @@ def __construct_simple_factor_graph( first_camera = initial_data.get_camera(cameras_to_model[0]) assert first_camera is not None, "First camera in initial data is None" graph.push_back( - gtsam.NonlinearEqualityPose3( + PriorFactorPose3( X(cameras_to_model[0]), first_camera.pose(), + Isotropic.Sigma(CAM_POSE3_DOF, self._cam_pose3_prior_noise_sigma), ) ) - # graph.push_back( - # PriorFactorPose3( - # X(cameras_to_model[0]), - # first_camera.pose(), - # Isotropic.Sigma(CAM_POSE3_DOF, self._cam_pose3_prior_noise_sigma), - # ) - # ) # Add prior factor on the position 
of the first landmark to fix the scale. if initial_data.number_tracks() > 0: diff --git a/gtsfm/cluster_optimizer/cluster_vggt.py b/gtsfm/cluster_optimizer/cluster_vggt.py index 8f7b8b7b3..aab76bbd9 100644 --- a/gtsfm/cluster_optimizer/cluster_vggt.py +++ b/gtsfm/cluster_optimizer/cluster_vggt.py @@ -37,7 +37,7 @@ def _resize_to_square_tensor(image: np.ndarray, target_size: int) -> torch.Tenso def _load_vggt_inputs(loader, indices: list[int], mode: str): """Load and preprocess a batch of images for VGGT.""" - return loader.load_image_batch_vggt_loader(indices, mode=mode) + return vggt.load_image_batch_vggt_loader(loader, indices, mode=mode) def _resolve_vggt_model(cache_key: Hashable | None, loader_kwargs: dict[str, Any] | None) -> Any | None: diff --git a/gtsfm/configs/vggt.yaml b/gtsfm/configs/vggt.yaml index a1e93eb06..4fc525117 100644 --- a/gtsfm/configs/vggt.yaml +++ b/gtsfm/configs/vggt.yaml @@ -46,5 +46,3 @@ cluster_optimizer: run_bundle_adjustment_on_leaf: false run_bundle_adjustment_on_parent: false model_cache_key: null - -# use_nonlinear_sim3_alignment: false \ No newline at end of file diff --git a/gtsfm/configs/vggt_megaloc.yaml b/gtsfm/configs/vggt_megaloc.yaml index 2bff9d9bc..d4cbe08a7 100644 --- a/gtsfm/configs/vggt_megaloc.yaml +++ b/gtsfm/configs/vggt_megaloc.yaml @@ -49,6 +49,6 @@ cluster_optimizer: run_bundle_adjustment_on_leaf: false run_bundle_adjustment_on_parent: true model_cache_key: null - # store_pre_ba_result: true + store_pre_ba_result: true -use_nonlinear_sim3_alignment: false \ No newline at end of file +use_nonlinear_sim3_alignment: false diff --git a/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py b/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py index fd0c2a20d..0ad746992 100644 --- a/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py +++ b/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py @@ -10,9 +10,9 @@ import csv import json import os +import textwrap from pathlib import Path from typing import Dict, 
Iterable, List, Optional, Tuple -import textwrap import matplotlib.pyplot as plt import numpy as np @@ -304,9 +304,7 @@ def _align_poses( baseline_dict = {i: pose for i, pose in enumerate(baseline_list)} current_dict = {i: pose for i, pose in enumerate(current_list)} if use_ransac: - aSb = _estimate_sim3_ransac( - baseline_list, current_list, max_hypotheses, inlier_thresh, rng, cluster_label - ) + aSb = _estimate_sim3_ransac(baseline_list, current_list, max_hypotheses, inlier_thresh, rng, cluster_label) else: aSb = align.sim3_from_Pose3_maps(baseline_dict, current_dict) current_aligned_list = [aSb.transformFrom(pose) for pose in current_list] @@ -338,18 +336,32 @@ def _plot_camera_centers( tip = pose.transformFrom(np.array([0.0, 0.0, arrow_len])) direction = tip - origin ax.quiver( - origin[0], origin[1], origin[2], - direction[0], direction[1], direction[2], - color="tab:blue", linewidth=0.5, arrow_length_ratio=0.2, alpha=0.6 + origin[0], + origin[1], + origin[2], + direction[0], + direction[1], + direction[2], + color="tab:blue", + linewidth=0.5, + arrow_length_ratio=0.2, + alpha=0.6, ) for pose in current_list: origin = pose.transformFrom(np.array([0.0, 0.0, 0.0])) tip = pose.transformFrom(np.array([0.0, 0.0, arrow_len])) direction = tip - origin ax.quiver( - origin[0], origin[1], origin[2], - direction[0], direction[1], direction[2], - color="tab:orange", linewidth=0.5, arrow_length_ratio=0.2, alpha=0.6 + origin[0], + origin[1], + origin[2], + direction[0], + direction[1], + direction[2], + color="tab:orange", + linewidth=0.5, + arrow_length_ratio=0.2, + alpha=0.6, ) ax.scatter( @@ -871,7 +883,9 @@ def main() -> None: _plot_pose_auc_vs_input_images(all_pose_auc_by_label_and_count, auc_vs_images_plot_path) if fig_output_dir is not None and all_rotation_auc_values: rotation_auc_plot_path = fig_output_dir / "rotation_auc_boxplot_all_clusters.png" - _plot_pose_auc_boxplot(all_rotation_auc_values, rotation_auc_plot_path, "Rotation AUC by threshold (all clusters)") + 
_plot_pose_auc_boxplot( + all_rotation_auc_values, rotation_auc_plot_path, "Rotation AUC by threshold (all clusters)" + ) if fig_output_dir is not None and all_translation_auc_values: translation_auc_plot_path = fig_output_dir / "translation_auc_boxplot_all_clusters.png" _plot_pose_auc_boxplot( diff --git a/gtsfm/frontend/vggt.py b/gtsfm/frontend/vggt.py index 3ea1c1b12..5deeb5511 100644 --- a/gtsfm/frontend/vggt.py +++ b/gtsfm/frontend/vggt.py @@ -9,12 +9,14 @@ from importlib.machinery import ModuleSpec from pathlib import Path from types import ModuleType -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np import torch from gtsam import Point2, Point3 +from PIL import Image as PILImage from torch.amp import autocast as amp_autocast # type: ignore +from torchvision import transforms as TF from gtsfm.bundle.bundle_adjustment import BundleAdjustmentOptimizer from gtsfm.common.gtsfm_data import GtsfmData @@ -134,6 +136,140 @@ def _resolve_dtype_argument(arg: Optional[Union[str, torch.dtype]]) -> Optional[ raise TypeError(f"Unsupported dtype specifier of type {type(arg)!r}: {arg!r}") +def load_image_batch_vggt_loader(loader, indices: List[int], mode="crop"): + """ + A quick start function to load and preprocess images for model input. + This assumes the images should have the same shape for easier batching, + but VGGT model can also work well with different shapes. + + Args: + loader: Loader instance providing ``get_image``. + indices: List of image indices to load. + mode (str, optional): Preprocessing mode, either "crop" or "pad". + - "crop" (default): Sets width to 518px and center crops height if needed. + - "pad": Preserves all pixels by making the largest dimension 518px + and padding the smaller dimension to reach a square shape. 
+ + Returns: + torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) + + Raises: + ValueError: If the input list is empty or if mode is invalid + + Notes: + - Images with different dimensions will be padded with white (value=1.0) + - A warning is printed when images have different shapes + - When mode="crop": The function ensures width=518px while maintaining aspect ratio + and height is center-cropped if larger than 518px + - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio + and the smaller dimension is padded to reach a square shape (518x518) + - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements + """ + # Check for empty list + if len(indices) == 0: + raise ValueError("At least 1 image is required") + + # Validate mode + if mode not in ["crop", "pad"]: + raise ValueError("Mode must be either 'crop' or 'pad'") + + images = [] + shapes = set() + to_tensor = TF.ToTensor() + target_size = 518 + + # First process all images and collect their shapes + for idx in indices: + # Open image + img = loader.get_image(idx).value_array + + img = PILImage.fromarray(img) + + width, height = img.size + + if mode == "pad": + # Make the largest dimension 518px while maintaining aspect ratio + if width >= height: + new_width = target_size + new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14 + else: + new_height = target_size + new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14 + else: # mode == "crop" + # Original behavior: set width to 518px + new_width = target_size + # Calculate height maintaining aspect ratio, divisible by 14 + new_height = round(height * (new_width / width) / 14) * 14 + + # Resize with new dimensions (width, height) + img = img.resize((new_width, new_height), PILImage.Resampling.BICUBIC) + img = to_tensor(img) # Convert to tensor (0, 1) + + # Center crop height if it's larger than 518 (only 
in crop mode) + if mode == "crop" and new_height > target_size: + start_y = (new_height - target_size) // 2 + img = img[:, start_y : start_y + target_size, :] + + # For pad mode, pad to make a square of target_size x target_size + if mode == "pad": + h_padding = target_size - img.shape[1] + w_padding = target_size - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + # Pad with white (value=1.0) + img = torch.nn.functional.pad( + img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 + ) + + shapes.add((img.shape[1], img.shape[2])) + images.append(img) + + # Check if we have different shapes + # In theory our model can also work well with different shapes + if len(shapes) > 1: + logger.warning("Found images with different shapes: %s", shapes) + # Find maximum dimensions + max_height = max(shape[0] for shape in shapes) + max_width = max(shape[1] for shape in shapes) + + # Pad images if necessary + padded_images = [] + for img in images: + h_padding = max_height - img.shape[1] + w_padding = max_width - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + img = torch.nn.functional.pad( + img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 + ) + padded_images.append(img) + images = padded_images + + images = torch.stack(images) # concatenate images + + # Ensure correct shape when single image + if len(indices) == 1: + # Verify shape is (1, C, H, W) + if images.dim() == 3: + images = images.unsqueeze(0) + + height, width = images.shape[-2], images.shape[-1] + coords = np.tile([0.0, 0.0, float(width), float(height), float(width), float(height)], (len(indices), 1)) + original_coords_tensor = torch.from_numpy(coords).float() + + return images, original_coords_tensor + + 
@dataclass class VggtConfiguration: """Configuration for the high-level VGGT reconstruction pipeline.""" @@ -860,6 +996,7 @@ def run_reconstruction_gtsfm_data_only(images: torch.Tensor, **kwargs) -> GtsfmD "VGGT_SUBMODULE_PATH", "LIGHTGLUE_SUBMODULE_PATH", "default_dtype", + "load_image_batch_vggt_loader", "load_and_preprocess_images_square", "resolve_weights_path", "load_model", diff --git a/gtsfm/loader/loader_base.py b/gtsfm/loader/loader_base.py index 13e2cd89a..f6f4cc096 100644 --- a/gtsfm/loader/loader_base.py +++ b/gtsfm/loader/loader_base.py @@ -12,8 +12,6 @@ from dask.delayed import Delayed, delayed from dask.distributed import Client, Future from gtsam import Cal3_S2, Cal3Bundler, Cal3DS2, Pose3 # type: ignore -from PIL import Image as PILImage -from torchvision import transforms as TF from trimesh import Trimesh import gtsfm.common.types as gtsfm_types @@ -494,138 +492,6 @@ def load_image_batch_vggt( transformed = batch_transform(batch_tensor) if batch_transform else batch_tensor return transformed, original_coords_tensor - def load_image_batch_vggt_loader(self, indices: List[int], mode="crop"): - """ - A quick start function to load and preprocess images for model input. - This assumes the images should have the same shape for easier batching, - but VGGT model can also work well with different shapes. - - Args: - indices: List of image indices to load - mode (str, optional): Preprocessing mode, either "crop" or "pad". - - "crop" (default): Sets width to 518px and center crops height if needed. - - "pad": Preserves all pixels by making the largest dimension 518px - and padding the smaller dimension to reach a square shape. 
- - Returns: - torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) - - Raises: - ValueError: If the input list is empty or if mode is invalid - - Notes: - - Images with different dimensions will be padded with white (value=1.0) - - A warning is printed when images have different shapes - - When mode="crop": The function ensures width=518px while maintaining aspect ratio - and height is center-cropped if larger than 518px - - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio - and the smaller dimension is padded to reach a square shape (518x518) - - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements - """ - # Check for empty list - if len(indices) == 0: - raise ValueError("At least 1 image is required") - - # Validate mode - if mode not in ["crop", "pad"]: - raise ValueError("Mode must be either 'crop' or 'pad'") - - images = [] - shapes = set() - to_tensor = TF.ToTensor() - target_size = 518 - - # First process all images and collect their shapes - for idx in indices: - # Open image - img = self.get_image(idx).value_array - - img = PILImage.fromarray(img) - - width, height = img.size - - if mode == "pad": - # Make the largest dimension 518px while maintaining aspect ratio - if width >= height: - new_width = target_size - new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14 - else: - new_height = target_size - new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14 - else: # mode == "crop" - # Original behavior: set width to 518px - new_width = target_size - # Calculate height maintaining aspect ratio, divisible by 14 - new_height = round(height * (new_width / width) / 14) * 14 - - # Resize with new dimensions (width, height) - img = img.resize((new_width, new_height), PILImage.Resampling.BICUBIC) - img = to_tensor(img) # Convert to tensor (0, 1) - - # Center crop height if it's larger than 518 (only in 
crop mode) - if mode == "crop" and new_height > target_size: - start_y = (new_height - target_size) // 2 - img = img[:, start_y : start_y + target_size, :] - - # For pad mode, pad to make a square of target_size x target_size - if mode == "pad": - h_padding = target_size - img.shape[1] - w_padding = target_size - img.shape[2] - - if h_padding > 0 or w_padding > 0: - pad_top = h_padding // 2 - pad_bottom = h_padding - pad_top - pad_left = w_padding // 2 - pad_right = w_padding - pad_left - - # Pad with white (value=1.0) - img = torch.nn.functional.pad( - img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 - ) - - shapes.add((img.shape[1], img.shape[2])) - images.append(img) - - # Check if we have different shapes - # In theory our model can also work well with different shapes - if len(shapes) > 1: - logger.warning("Found images with different shapes: %s", shapes) - # Find maximum dimensions - max_height = max(shape[0] for shape in shapes) - max_width = max(shape[1] for shape in shapes) - - # Pad images if necessary - padded_images = [] - for img in images: - h_padding = max_height - img.shape[1] - w_padding = max_width - img.shape[2] - - if h_padding > 0 or w_padding > 0: - pad_top = h_padding // 2 - pad_bottom = h_padding - pad_top - pad_left = w_padding // 2 - pad_right = w_padding - pad_left - - img = torch.nn.functional.pad( - img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 - ) - padded_images.append(img) - images = padded_images - - images = torch.stack(images) # concatenate images - - # Ensure correct shape when single image - if len(indices) == 1: - # Verify shape is (1, C, H, W) - if images.dim() == 3: - images = images.unsqueeze(0) - - height, width = images.shape[-2], images.shape[-1] - coords = np.tile([0.0, 0.0, float(width), float(height), float(width), float(height)], (len(indices), 1)) - original_coords_tensor = torch.from_numpy(coords).float() - - return images, original_coords_tensor - def 
get_all_descriptor_image_batches_as_futures( self, client: Client, From 6f02af830efa31c19f8dc0870133abba798d6866 Mon Sep 17 00:00:00 2001 From: Harneet Singh Khanuja Date: Fri, 6 Feb 2026 02:06:15 -0500 Subject: [PATCH 19/24] Minor filtering changes to frontend/vggt.py --- gtsfm/frontend/vggt.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/gtsfm/frontend/vggt.py b/gtsfm/frontend/vggt.py index 5deeb5511..369c74bae 100644 --- a/gtsfm/frontend/vggt.py +++ b/gtsfm/frontend/vggt.py @@ -510,28 +510,25 @@ def _convert_vggt_outputs_to_gtsfm_data( if tracking_result: # track masks according to visibility, reprojection error, etc + max_reproj_error = float(config.max_reproj_error) track_mask = tracking_result.visibilities > config.track_vis_thresh - inlier_num = track_mask.sum(0) - min_measurements = 2 - valid_mask = inlier_num >= min_measurements # a track is invalid if without two inliers confidence_threshold = config.track_conf_thresh confidence_threshold = min( confidence_threshold, np.mean(tracking_result.confidences) + np.std(tracking_result.confidences) ) if tracking_result.confidences is not None: - valid_mask = np.logical_and(valid_mask, tracking_result.confidences > confidence_threshold) - valid_idx = np.nonzero(valid_mask)[0] - - max_reproj_error = float(config.max_reproj_error) - track_mask = tracking_result.visibilities > config.track_vis_thresh - if tracking_result.confidences is not None: - track_mask = np.logical_and(track_mask, tracking_result.confidences > config.track_conf_thresh) + track_mask = np.logical_and(track_mask, tracking_result.confidences > confidence_threshold) enforce_reproj_filter = ( tracking_result.points_3d is not None and np.isfinite(max_reproj_error) and max_reproj_error > 0.0 ) + inlier_num = track_mask.sum(0) + min_measurements = 2 + valid_mask = inlier_num >= min_measurements # a track is invalid if without two inliers + valid_idx = np.nonzero(valid_mask)[0] + logger.info("num points 3d: %d, 
num valid idx: %d", tracking_result.points_3d.shape[0], len(valid_idx)) for valid_id in valid_idx: @@ -563,7 +560,8 @@ def _convert_vggt_outputs_to_gtsfm_data( if len(per_track_measurements) < min_measurements: continue - + if enforce_reproj_filter and max_error_for_track > max_reproj_error: + continue track = torch_utils.colored_track_from_point(point_xyz, rgb) for global_idx, float_u, float_v in per_track_measurements: track.addMeasurement(global_idx, Point2(float_u, float_v)) From c3e99b8cfe886721f93e3a82bc537961e42e8b5e Mon Sep 17 00:00:00 2001 From: Harneet Singh Khanuja Date: Sat, 7 Feb 2026 12:05:28 -0500 Subject: [PATCH 20/24] Pinning GTSAM version --- environment_linux.yml | 2 +- environment_linux_cpuonly.yml | 2 +- environment_mac.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/environment_linux.yml b/environment_linux.yml index 8852d6f95..479688b81 100644 --- a/environment_linux.yml +++ b/environment_linux.yml @@ -56,7 +56,7 @@ dependencies: - pydegensac - colour - trimesh[easy] - - gtsam-develop + - gtsam-develop==4.3a1.dev202602040056 - pydot # dust3r/mast3r - roma diff --git a/environment_linux_cpuonly.yml b/environment_linux_cpuonly.yml index 5e7b5ff6e..00955e1f3 100644 --- a/environment_linux_cpuonly.yml +++ b/environment_linux_cpuonly.yml @@ -56,7 +56,7 @@ dependencies: - pydegensac - colour - trimesh[easy] - - gtsam-develop + - gtsam-develop==4.3a1.dev202602040056 - pydot - roma - tqdm diff --git a/environment_mac.yml b/environment_mac.yml index df85d2352..59eb40997 100644 --- a/environment_mac.yml +++ b/environment_mac.yml @@ -57,7 +57,7 @@ dependencies: - pydegensac - colour - trimesh[easy] - - gtsam-develop + - gtsam-develop==4.3a1.dev202602040056 - pydot # dust3r/mast3r - roma From 5c40d50e0961e0bddc14b330ba713d77cacea15d Mon Sep 17 00:00:00 2001 From: Harneet Singh Khanuja Date: Sat, 7 Feb 2026 12:16:29 -0500 Subject: [PATCH 21/24] Pinning GTSAM version in pyproject --- pyproject.toml | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 227f3ec8e..cb9e386d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ dependencies = [ "pydegensac", "colour", "trimesh[easy]", - "gtsam==4.3a0", + "gtsam-develop==4.3a1.dev202602040056", "pydot", # Dust3r/Mast3r From a921e568a1ae067efd863c4f59d871d2a70a6b75 Mon Sep 17 00:00:00 2001 From: Harneet Singh Khanuja Date: Sat, 7 Feb 2026 13:40:00 -0500 Subject: [PATCH 22/24] Fixing noise model in tests --- tests/averaging/translation/test_averaging_1dsfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/averaging/translation/test_averaging_1dsfm.py b/tests/averaging/translation/test_averaging_1dsfm.py index c02413af7..f2670a418 100644 --- a/tests/averaging/translation/test_averaging_1dsfm.py +++ b/tests/averaging/translation/test_averaging_1dsfm.py @@ -48,7 +48,7 @@ def test_binary_measurements_from_dict(self): (0, 2): Unit3(Point3(0, 1, 0)), (1, 2): Unit3(Point3(0, 0, 1)), } - noise_model = gtsam.noiseModel.Isotropic.Sigma(3, 0.1) + noise_model = gtsam.noiseModel.Isotropic.Sigma(2, 0.1) expected_measurement_idxs = set( [(C(i2), C(i1)) for (i1, i2) in w_i2Ui1_dict.keys()] + [(C(i2), L(i1)) for (i1, i2) in w_i2Ui1_dict_tracks.keys()] From 04b53137b831f6389d817abe7c7d51079380811e Mon Sep 17 00:00:00 2001 From: Harneet Singh Khanuja Date: Sat, 7 Feb 2026 14:24:33 -0500 Subject: [PATCH 23/24] Fixing tests --- tests/utils/test_align.py | 6 +++--- tests/utils/test_metric_utils.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/utils/test_align.py b/tests/utils/test_align.py index dff1fb3fd..f2f3dd1de 100644 --- a/tests/utils/test_align.py +++ b/tests/utils/test_align.py @@ -952,12 +952,12 @@ def test_align_gtsfm_data_via_Sim3_to_poses_skydio32(self) -> None: aligned_metrics = aligned_filtered_data.get_metrics(suffix="_filtered") - assert unaligned_metrics[3].name == "reprojection_errors_filtered_px" - assert 
aligned_metrics[3].name == "reprojection_errors_filtered_px" + assert unaligned_metrics[4].name == "reprojection_errors_filtered_px" + assert aligned_metrics[4].name == "reprojection_errors_filtered_px" # Reprojection error should be unaffected by Sim(3) alignment. for key in ["min", "max", "median", "mean", "stddev"]: - assert np.isclose(unaligned_metrics[3].summary[key], aligned_metrics[3].summary[key]) + assert np.isclose(unaligned_metrics[4].summary[key], aligned_metrics[4].summary[key]) if __name__ == "__main__": diff --git a/tests/utils/test_metric_utils.py b/tests/utils/test_metric_utils.py index 0c3e09967..35c4d43b3 100644 --- a/tests/utils/test_metric_utils.py +++ b/tests/utils/test_metric_utils.py @@ -177,14 +177,14 @@ def test_get_metrics_for_sfmdata_skydio32(self) -> None: ) metrics = aligned_filtered_data.get_metrics(suffix="_filtered") - assert metrics[0].name == "number_cameras" - assert np.isclose(metrics[0]._data, np.array(5.0, dtype=np.float32)) + assert metrics[1].name == "number_cameras" + assert np.isclose(metrics[1]._data, np.array(5.0, dtype=np.float32)) - assert metrics[1].name == "number_tracks_filtered" - assert np.isclose(metrics[1]._data, np.array(7.0, dtype=np.float32)) + assert metrics[2].name == "number_tracks_filtered" + assert np.isclose(metrics[2]._data, np.array(7.0, dtype=np.float32)) - assert metrics[2].name == "3d_track_lengths_filtered" - assert metrics[2].summary == { + assert metrics[3].name == "3d_track_lengths_filtered" + assert metrics[3].summary == { "min": 2, "max": 2, "median": 2.0, @@ -195,8 +195,8 @@ def test_get_metrics_for_sfmdata_skydio32(self) -> None: "invalid": 0, } - assert metrics[3].name == "reprojection_errors_filtered_px" - assert metrics[3].summary == {"min": np.nan, "max": np.nan, "median": np.nan, "mean": np.nan, "stddev": np.nan} + assert metrics[4].name == "reprojection_errors_filtered_px" + assert metrics[4].summary == {"min": np.nan, "max": np.nan, "median": np.nan, "mean": np.nan, "stddev": np.nan} 
def test_compute_percentage_change_improve() -> None: From 62bdf266a40be3ac0c0b5b962102b4bb24c15150 Mon Sep 17 00:00:00 2001 From: Harneet Singh Khanuja Date: Sat, 7 Feb 2026 14:57:40 -0500 Subject: [PATCH 24/24] Fixing tests --- tests/utils/test_metric_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_metric_utils.py b/tests/utils/test_metric_utils.py index 35c4d43b3..aa9c89cab 100644 --- a/tests/utils/test_metric_utils.py +++ b/tests/utils/test_metric_utils.py @@ -177,7 +177,7 @@ def test_get_metrics_for_sfmdata_skydio32(self) -> None: ) metrics = aligned_filtered_data.get_metrics(suffix="_filtered") - assert metrics[1].name == "number_cameras" + assert metrics[1].name == "number_cameras_filtered" assert np.isclose(metrics[1]._data, np.array(5.0, dtype=np.float32)) assert metrics[2].name == "number_tracks_filtered"