diff --git a/.github/copilot-instruction.md b/.github/copilot-instructions.md similarity index 100% rename from .github/copilot-instruction.md rename to .github/copilot-instructions.md diff --git a/.gitmodules b/.gitmodules index 4981a746f..16f6d8752 100644 --- a/.gitmodules +++ b/.gitmodules @@ -13,3 +13,6 @@ [submodule "thirdparty/FastVGGT"] path = thirdparty/FastVGGT url = https://github.com/mystorm16/FastVGGT.git +[submodule "thirdparty/Pi3"] + path = thirdparty/Pi3 + url = https://github.com/yyfz/Pi3.git diff --git a/gtsfm/evaluation/compare_colmap_outputs.py b/gtsfm/evaluation/compare_colmap_outputs.py index a84fd3145..0868af0e6 100644 --- a/gtsfm/evaluation/compare_colmap_outputs.py +++ b/gtsfm/evaluation/compare_colmap_outputs.py @@ -20,12 +20,26 @@ import gtsfm.utils.logger as logger_utils import gtsfm.utils.metrics as metric_utils from gtsfm.cluster_optimizer import save_metrics_reports -from gtsfm.evaluation.metrics import GtsfmMetricsGroup +from gtsfm.evaluation.metrics import GtsfmMetric, GtsfmMetricsGroup from gtsfm.utils import align, transform logger = logger_utils.get_logger() +def _is_auc_metric_name(metric_name: str) -> bool: + return metric_name.startswith("pose_auc_@") + + +def _convert_scalar_auc_metrics_to_percent(metrics: List[GtsfmMetric]) -> List[GtsfmMetric]: + converted_metrics: List[GtsfmMetric] = [] + for metric in metrics: + if metric.dim == 0 and _is_auc_metric_name(metric.name) and metric.data is not None: + converted_metrics.append(GtsfmMetric(metric.name, float(metric.data) * 100.0)) + else: + converted_metrics.append(metric) + return converted_metrics + + def load_poses(colmap_dirpath: str) -> Dict[str, Pose3]: """Returns mapping from image filename to associated camera pose.""" wTi_list, img_fnames, _, _, _, _ = io_utils.read_scene_data_from_colmap_format(colmap_dirpath) @@ -145,7 +159,7 @@ def export_metrics_group_to_csv(metrics_group: GtsfmMetricsGroup, output_path: s def _format_pose_auc(metrics_group: GtsfmMetricsGroup) -> str: auc_parts = [] for metric in metrics_group.metrics: - if not metric.name.startswith("pose_auc_@"): + if not _is_auc_metric_name(metric.name): continue if metric.data is None: continue @@ -154,7 +168,7 @@ def _format_pose_auc(metrics_group: GtsfmMetricsGroup) -> str: except (TypeError, ValueError): continue suffix = metric.name.replace("pose_auc_", "") - auc_parts.append(f"{suffix}={value:.3f}") + auc_parts.append(f"{suffix}={value:.2f}%") return ", ".join(auc_parts) @@ -218,11 +232,11 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) rotation_angular_errors = relative_rotation_error_metric.data translation_angular_errors = relative_translation_error_metric.data - metrics.extend( - metric_utils.compute_pose_auc_metric( - rotation_angular_errors, translation_angular_errors, save_dir=output_dirpath - ) + pose_auc_metrics = metric_utils.compute_pose_auc_metric( + rotation_angular_errors, translation_angular_errors, save_dir=output_dirpath ) + pose_auc_metrics = _convert_scalar_auc_metrics_to_percent(pose_auc_metrics) + metrics.extend(pose_auc_metrics) ba_pose_metrics = GtsfmMetricsGroup(name="ba_pose_error_metrics", metrics=metrics) diff --git a/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py b/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py index e9993c383..e1f6a799f 100644 --- a/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py +++ b/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py @@ -26,6 +26,32 @@ logger = logger_utils.get_logger() +def _is_auc_metric_name(metric_name: str) -> bool: + return ( + metric_name.startswith("pose_auc_@") + or metric_name.startswith("rotation_auc_@") + or metric_name.startswith("translation_auc_@") + ) + + +def _convert_scalar_auc_metrics_to_percent(metrics: List[GtsfmMetric]) -> List[GtsfmMetric]: + converted_metrics: List[GtsfmMetric] = [] + for metric in metrics: + if metric.dim == 0 and _is_auc_metric_name(metric.name) and metric.data is not None: + converted_metrics.append(GtsfmMetric(metric.name, float(metric.data) * 100.0)) + else: + converted_metrics.append(metric) + return converted_metrics + + +def _build_short_cluster_plot_filename(recon_dir: Path, root: Path) -> str: + """Build filename as '___camera_centers.png'.""" + del root # not needed after simplifying naming policy + cluster_name = recon_dir.parent.name or "cluster" + recon_name = recon_dir.name or "recon" + return f"{cluster_name}__{recon_name}_camera_centers.png" + + def _read_images_txt_with_names(images_txt: Path) -> Dict[str, Pose3]: """Read poses from COLMAP images.txt keyed by image NAME.""" if not images_txt.exists(): @@ -200,7 +226,7 @@ def _compute_pose_metrics(baseline_list: List[Pose3], current_aligned_list: List rotation_auc_values = metric_utils.pose_auc(rotation_angular_errors, thresholds_deg) metrics.extend( [ - GtsfmMetric(f"rotation_auc_@{threshold}_deg", auc) + GtsfmMetric(f"rotation_auc_@{threshold}_deg", float(auc) * 100.0) for threshold, auc in zip(thresholds_deg, rotation_auc_values) ] ) @@ -209,15 +235,15 @@ def _compute_pose_metrics(baseline_list: List[Pose3], current_aligned_list: List translation_auc_values = metric_utils.pose_auc(translation_angular_errors, thresholds_deg) metrics.extend( [ - GtsfmMetric(f"translation_auc_@{threshold}_deg", auc) + GtsfmMetric(f"translation_auc_@{threshold}_deg", float(auc) * 100.0) for threshold, auc in zip(thresholds_deg, translation_auc_values) ] ) - metrics.extend( - metric_utils.compute_pose_auc_metric( - relative_rotation_error_metric.data, relative_translation_error_metric.data, thresholds_deg=thresholds_deg - ) + pose_auc_metrics = metric_utils.compute_pose_auc_metric( + relative_rotation_error_metric.data, relative_translation_error_metric.data, thresholds_deg=thresholds_deg ) + pose_auc_metrics = _convert_scalar_auc_metrics_to_percent(pose_auc_metrics) + metrics.extend(pose_auc_metrics) return GtsfmMetricsGroup(name="ba_pose_error_metrics", metrics=metrics) @@ -440,14 +466,14 @@ def _plot_pose_auc_boxplot(auc_values_by_label: Dict[str, List[float]], output_p ax = fig.add_subplot(111) ax.boxplot(data, vert=True, patch_artist=True) ax.set_title(title) - ax.set_ylabel("AUC") + ax.set_ylabel("AUC (%)") ax.set_xticks(range(1, len(labels) + 1)) ax.set_xticklabels(labels, rotation=30, ha="right") stats_lines = [] for label, values in zip(labels, data): mean_val = float(np.mean(values)) median_val = float(np.median(values)) - stats_lines.append(f"{label}: mean={mean_val:.3f}, med={median_val:.3f}") + stats_lines.append(f"{label}: mean={mean_val:.2f}%, med={median_val:.2f}%") if stats_lines: ax.text( 0.02, @@ -488,7 +514,7 @@ def _plot_pose_auc_vs_input_images( ax.set_title("Pose AUC vs input images (all clusters)") ax.set_xlabel("input images (current count)") - ax.set_ylabel("AUC") + ax.set_ylabel("AUC (%)") ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.5) ax.legend(loc="best", fontsize=8) fig.tight_layout() @@ -596,7 +622,7 @@ def _format_auc(metrics_group: GtsfmMetricsGroup, prefix: str) -> str: except (TypeError, ValueError): continue suffix = metric.name.replace(f"{prefix}_", "") - auc_parts.append(f"{suffix}={value:.3f}") + auc_parts.append(f"{suffix}={value:.2f}%") return ", ".join(auc_parts) @@ -832,8 +858,7 @@ def main() -> None: else: _print_metrics(str(recon_dir), metrics_group) if fig_output_dir is not None: - safe_name = str(recon_dir).replace(os.sep, "__") - plot_path = fig_output_dir / f"{safe_name}_camera_centers.png" + plot_path = fig_output_dir / _build_short_cluster_plot_filename(recon_dir, root) pose_auc_text = _format_auc(metrics_group, "pose_auc") rotation_auc_text = _format_auc(metrics_group, "rotation_auc") translation_auc_text = _format_auc(metrics_group, "translation_auc") diff --git a/gtsfm/utils/pycolmap_utils.py b/gtsfm/utils/pycolmap_utils.py index 806b12248..1c5dbff60 100644 --- a/gtsfm/utils/pycolmap_utils.py +++ b/gtsfm/utils/pycolmap_utils.py @@ -61,6 +61,10 @@ def colmap_camera_to_gtsam_calibration(camera: ColmapCamera) -> CALIBRATION_TYPE # See https://github.com/colmap/colmap/blob/1f6812e333a1e4b2ef56aa74e2c3873e4e3a40cd/src/colmap/sensor/models.h#L273 # noqa: E501 fx, fy, cx, cy, k1, k2, p1, p2 = camera.params[:8] return gtsam.Cal3DS2(fx, fy, 0.0, cx, cy, k1, k2, p1, p2) + elif camera_model_name == "SIMPLE_PINHOLE": + # See https://github.com/colmap/colmap/blob/1f6812e333a1e4b2ef56aa74e2c3873e4e3a40cd/src/colmap/sensor/models.h#L196 # noqa: E501 + f, cx, cy = camera.params + return gtsam.Cal3_S2(f, f, 0.0, cx, cy) elif camera_model_name == "PINHOLE": # See https://github.com/colmap/colmap/blob/1f6812e333a1e4b2ef56aa74e2c3873e4e3a40cd/src/colmap/sensor/models.h#L196 # noqa: E501 fx, fy, cx, cy = camera.params diff --git a/pipeline/1-partition/partition_metis_megaloc.py b/pipeline/1-partition/partition_metis_megaloc.py new file mode 100644 index 000000000..15e98bc2c --- /dev/null +++ b/pipeline/1-partition/partition_metis_megaloc.py @@ -0,0 +1,190 @@ +"""Run MegaLoc retrieval + METIS partitioning and persist the visibility graph + cluster tree. + +This script mirrors the image_pairs_generator and graph_partitioner configuration in +`gtsfm/configs/vggt.yaml`. It loads images, generates a visibility graph, partitions +the graph with METIS, and saves both the graph and tree under the chosen output root. +""" + +from __future__ import annotations + +import argparse +import pickle +import time +from pathlib import Path + +import hydra +from dask.distributed import Client, LocalCluster +from hydra.utils import instantiate + +import gtsfm.utils.logger as logger_utils +from gtsfm.common.outputs import OutputPaths, prepare_output_paths +from gtsfm.graph_partitioner.graph_partitioner_base import GraphPartitionerBase +from gtsfm.graph_partitioner.single_partitioner import SinglePartitioner +from gtsfm.loader.loader_base import LoaderBase +from gtsfm.products.visibility_graph import VisibilityGraph +from gtsfm.retriever.image_pairs_generator import ImagePairsGenerator + +logger = logger_utils.get_logger() + + +def _build_components( + config_name: str, + dataset_dir: str, + images_dir: str | None, + max_resolution: int | None, +) -> tuple[LoaderBase, ImagePairsGenerator, GraphPartitionerBase]: + overrides: list[str] = [f"loader.dataset_dir={dataset_dir}"] + if images_dir is not None: + overrides.append(f"loader.images_dir={images_dir}") + if max_resolution is not None: + overrides.append(f"loader.max_resolution={max_resolution}") + + with hydra.initialize_config_module(config_module="gtsfm.configs", version_base=None): + cfg = hydra.compose(config_name=config_name, overrides=overrides) + + loader: LoaderBase = instantiate(cfg.loader) + image_pairs_generator: ImagePairsGenerator = instantiate(cfg.image_pairs_generator) + graph_partitioner: GraphPartitionerBase = instantiate(cfg.graph_partitioner) + return loader, image_pairs_generator, graph_partitioner + + +def _run_retriever( + client: Client, loader: LoaderBase, image_pairs_generator: ImagePairsGenerator, output_paths: OutputPaths +) -> VisibilityGraph: + start_time = time.time() + batch_size = image_pairs_generator._batch_size + transforms = image_pairs_generator.get_preprocessing_transforms() + image_batch_futures = loader.get_all_descriptor_image_batches_as_futures(client, batch_size, *transforms) + image_fnames = loader.image_filenames() + + logger.info("🔥 Running image pair retrieval...") + visibility_graph = image_pairs_generator.run( + client=client, + image_batch_futures=image_batch_futures, + image_fnames=image_fnames, + plots_output_dir=output_paths.plots, + ) + + try: + image_pairs_generator._retriever.save_diagnostics( + image_fnames=image_fnames, + pairs=visibility_graph, + plots_output_dir=output_paths.plots, + ) + except Exception as exc: # pragma: no cover - diagnostic path best-effort + logger.warning("Failed to persist retriever diagnostics: %s", exc) + + logger.info("🚀 Image pair retrieval took %.2f min.", (time.time() - start_time) / 60.0) + return visibility_graph + + +def _save_visibility_graph(graph: VisibilityGraph, output_paths: OutputPaths) -> None: + try: + with open(output_paths.results / "visibility_graph.pkl", "wb") as f: + pickle.dump(graph, f) + except Exception as exc: + logger.warning("Failed to serialize visibility graph: %s", exc) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run MegaLoc+METIS partitioning and save outputs.") + parser.add_argument( + "--dataset_dir", + type=str, + required=True, + help="Dataset root containing images/ (Olsson-style loader default).", + ) + parser.add_argument( + "--images_dir", + type=str, + default=None, + help="Optional path to images directory (overrides loader default).", + ) + parser.add_argument( + "--output_root", + type=str, + default=str(Path.cwd()), + help="Root directory to store results (will create output_root/results).", + ) + parser.add_argument( + "--config_name", + type=str, + default="vggt_megaloc_phototourism", + help="Config in gtsfm/configs to load (default: vggt).", + ) + parser.add_argument( + "--max_resolution", + type=int, + default=None, + help="Override loader max resolution (if unset, uses config default).", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="Number of local Dask workers.", + ) + parser.add_argument( + "--threads_per_worker", + type=int, + default=1, + help="Threads per Dask worker.", + ) + parser.add_argument( + "--worker_memory_limit", + type=str, + default="32GB", + help="Memory limit per worker, e.g. 16GB.", + ) + parser.add_argument( + "--dashboard_address", + type=str, + default=":8787", + help="Dask dashboard address, set to empty string to disable.", + ) + parser.add_argument( + "--single_cluster", + action="store_true", + help="Skip METIS and output a single cluster containing all retrieved image pairs.", + ) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + + output_root = Path(args.output_root) + output_paths = prepare_output_paths(output_root, None) + + loader, image_pairs_generator, graph_partitioner = _build_components( + config_name=args.config_name, + dataset_dir=args.dataset_dir, + images_dir=args.images_dir, + max_resolution=args.max_resolution, + ) + + logger.info("🌟 Starting Dask local cluster...") + cluster = LocalCluster( + n_workers=args.num_workers, + threads_per_worker=args.threads_per_worker, + memory_limit=args.worker_memory_limit, + dashboard_address=args.dashboard_address, + ) + + with Client(cluster) as client: + visibility_graph = _run_retriever(client, loader, image_pairs_generator, output_paths) + + if args.single_cluster: + logger.info("🔥 Skipping METIS; creating a single cluster with all retrieved pairs...") + cluster_tree = SinglePartitioner().run(visibility_graph) + else: + logger.info("🔥 Running METIS partitioning...") + cluster_tree = graph_partitioner.run(visibility_graph) + graph_partitioner.log_partition_details(cluster_tree, output_paths) + _save_visibility_graph(visibility_graph, output_paths) + + logger.info("✅ Saved visibility_graph.pkl and cluster_tree.pkl under %s", output_paths.results) + + +if __name__ == "__main__": + main() diff --git a/pipeline/2-reconstruction/Pi3/run_on_cluster.py b/pipeline/2-reconstruction/Pi3/run_on_cluster.py new file mode 100644 index 000000000..6047d9222 --- /dev/null +++ b/pipeline/2-reconstruction/Pi3/run_on_cluster.py @@ -0,0 +1,341 @@ +"""Run Pi3 reconstruction per cluster using a saved cluster_tree.pkl. + +Writes COLMAP text outputs under: + /results/.../ +matching the cluster tree directory structure. +""" + +from __future__ import annotations + +import argparse +import math +import os +import pickle +import sys +from pathlib import Path +from typing import Iterable, Optional, Sequence + +import hydra +import numpy as np +import torch +from hydra.utils import instantiate +from PIL import Image +from torchvision import transforms + +import thirdparty.colmap.scripts.python.read_write_model as colmap_io +from gtsfm.common.outputs import prepare_output_paths +from gtsfm.products.visibility_graph import visibility_graph_keys +from gtsfm.utils.tree import PreOrderIter, Tree +_PI3_PROJECT_ROOT = Path("/nethome/xzhang979/nvme/gtsfm/thirdparty/Pi3") +if not _PI3_PROJECT_ROOT.is_dir(): + _PI3_PROJECT_ROOT = Path(__file__).resolve().parents[3] / "thirdparty" / "Pi3" +if str(_PI3_PROJECT_ROOT) not in sys.path: + sys.path.append(str(_PI3_PROJECT_ROOT)) + +from pi3.utils.geometry import depth_edge +from pi3.models.pi3 import Pi3 + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run Pi3 on clusters from a saved ClusterTree.") + parser.add_argument("--cluster_tree_path", type=str, required=True, help="Path to cluster_tree.pkl") + parser.add_argument("--dataset_dir", type=str, required=True, help="Dataset root (used for loader).") + parser.add_argument("--images_root", type=str, default=None, help="Root directory for images.") + parser.add_argument( + "--output_root", + type=str, + required=True, + help="Base output directory (results will be in /results/...).", + ) + parser.add_argument("--config_name", type=str, default="vggt", help="Config in gtsfm/configs for loader.") + parser.add_argument("--max_resolution", type=int, default=None, help="Optional loader max resolution override.") + parser.add_argument("--model_name", type=str, default="pi3", help="Per-cluster output model folder name.") + parser.add_argument("--min_images", type=int, default=2, help="Skip clusters with fewer images.") + parser.add_argument("--run_leaf", action="store_true", default=True, help="Run on leaf clusters.") + parser.add_argument("--run_parent", action="store_true", default=True, help="Run on non-leaf clusters.") + parser.add_argument("--run_root", action="store_true", default=True, help="Run on root cluster.") + parser.add_argument( + "--no_skip_existing", + action="store_false", + dest="skip_existing", + default=True, + help="Recompute even if output already exists.", + ) + parser.add_argument("--ckpt", type=str, default=None, help="Optional checkpoint path.") + parser.add_argument("--device", type=str, default="cuda", help="Device: cuda or cpu.") + parser.add_argument("--focal_length", type=float, default=None, help="Optional SIMPLE_PINHOLE focal length.") + parser.add_argument( + "--pixel_limit", + type=int, + default=255000, + help="Target max image pixels after resizing (same behavior as Pi3 examples).", + ) + return parser.parse_args() + + +def _build_loader(config_name: str, dataset_dir: str, images_dir: str | None, max_resolution: int | None): + overrides: list[str] = [f"loader.dataset_dir={dataset_dir}"] + if images_dir is not None: + overrides.append(f"loader.images_dir={images_dir}") + if max_resolution is not None: + overrides.append(f"loader.max_resolution={max_resolution}") + + with hydra.initialize_config_module(config_module="gtsfm.configs", version_base=None): + cfg = hydra.compose(config_name=config_name, overrides=overrides) + return instantiate(cfg.loader) + + +def _load_cluster_tree(cluster_tree_path: str): + with open(cluster_tree_path, "rb") as f: + return pickle.load(f) + + +def _resolve_image_paths(image_names: Sequence[str], images_root: str | None) -> list[str]: + resolved_paths: list[str] = [] + for name in image_names: + if os.path.isabs(name): + resolved_paths.append(name) + else: + if images_root is None: + raise ValueError("images_root is required when image filenames are relative.") + resolved_paths.append(os.path.join(images_root, name)) + return resolved_paths + + +def _iter_clusters_with_paths(cluster_tree) -> Iterable[tuple[tuple[int, ...], Sequence[tuple[int, int]], bool]]: + path_tree: Tree[tuple[tuple[int, ...], Sequence[tuple[int, int]]]] = cluster_tree.map_with_path( + lambda path, visibility_graph: (path, visibility_graph) + ) + for node in PreOrderIter(path_tree): + path, visibility_graph = node.value + yield path, visibility_graph, node.is_leaf() + + +def _should_run_cluster(path: tuple[int, ...], is_leaf: bool, args: argparse.Namespace) -> bool: + if path == () and not args.run_root: + return False + if is_leaf and args.run_leaf: + return True + if (not is_leaf) and args.run_parent: + return True + return False + + +def _load_images_from_paths( + image_paths: Sequence[str], pixel_limit: int, device: torch.device +) -> tuple[torch.Tensor, tuple[int, int]]: + sources: list[Image.Image] = [] + for image_path in image_paths: + sources.append(Image.open(image_path).convert("RGB")) + + if not sources: + raise ValueError("No images loaded for cluster.") + + first_img = sources[0] + w_orig, h_orig = first_img.size + scale = math.sqrt(pixel_limit / (w_orig * h_orig)) if w_orig * h_orig > 0 else 1.0 + w_target, h_target = w_orig * scale, h_orig * scale + k, m = round(w_target / 14), round(h_target / 14) + while (k * 14) * (m * 14) > pixel_limit: + if k / max(1, m) > w_target / max(1.0, h_target): + k -= 1 + else: + m -= 1 + target_w = max(1, k) * 14 + target_h = max(1, m) * 14 + + to_tensor = transforms.ToTensor() + tensors: list[torch.Tensor] = [] + for img in sources: + resized = img.resize((target_w, target_h), Image.Resampling.LANCZOS) + tensors.append(to_tensor(resized)) + imgs = torch.stack(tensors, dim=0).to(device) + return imgs, (target_w, target_h) + + +def _write_colmap_text( + colmap_dir: str, + image_names: Sequence[str], + camera_poses_wTc: np.ndarray, + points_xyz: np.ndarray, + points_rgb: np.ndarray, + image_size: tuple[int, int], + focal_length: Optional[float], +) -> None: + if len(image_names) != camera_poses_wTc.shape[0]: + raise ValueError("Number of image names must match number of camera poses.") + os.makedirs(colmap_dir, exist_ok=True) + + width, height = image_size + if focal_length is None: + focal_length = 0.5 * (width + height) + + cameras = {} + + images = {} + for idx, name in enumerate(image_names, start=1): + cameras[idx] = colmap_io.Camera( + id=idx, + model="SIMPLE_PINHOLE", + width=int(width), + height=int(height), + params=np.array([float(focal_length), float(width) / 2.0, float(height) / 2.0]), + ) + wTc = camera_poses_wTc[idx - 1] + r_wc = wTc[:3, :3] + t_wc = wTc[:3, 3] + r_cw = r_wc.T + t_cw = -r_cw @ t_wc + qvec = colmap_io.rotmat2qvec(r_cw) + images[idx] = colmap_io.Image( + id=idx, + qvec=qvec, + tvec=t_cw, + camera_id=idx, + name=name, + xys=np.zeros((0, 2), dtype=np.float64), + point3D_ids=np.zeros((0,), dtype=np.int64), + ) + + points3d = {} + for idx, (xyz, rgb) in enumerate(zip(points_xyz, points_rgb), start=1): + points3d[idx] = colmap_io.Point3D( + id=idx, + xyz=xyz, + rgb=rgb, + error=0.0, + image_ids=np.zeros((0,), dtype=np.int32), + point2D_idxs=np.zeros((0,), dtype=np.int32), + ) + + colmap_io.write_model(cameras, images, points3d, path=colmap_dir, ext=".txt") + + +def _setup_model(device: torch.device, ckpt: str | None) -> Pi3: + if ckpt is not None: + model = Pi3().to(device).eval() + if ckpt.endswith(".safetensors"): + from safetensors.torch import load_file + + weight = load_file(ckpt) + else: + weight = torch.load(ckpt, map_location=device, weights_only=False) + model.load_state_dict(weight) + else: + model = Pi3.from_pretrained("yyfz233/Pi3").to(device).eval() + return model + + +def _run_pi3_on_cluster( + model: Pi3, + device: torch.device, + image_paths: Sequence[str], + image_names: Sequence[str], + output_dir: Path, + focal_length: float | None, + pixel_limit: int, +) -> None: + imgs, (width, height) = _load_images_from_paths(image_paths, pixel_limit=pixel_limit, device=device) + if imgs.shape[0] == 0: + raise ValueError("No images available after preprocessing.") + + if device.type == "cuda": + major, _ = torch.cuda.get_device_capability(device=device) + dtype = torch.bfloat16 if major >= 8 else torch.float16 + with torch.no_grad(): + with torch.amp.autocast("cuda", dtype=dtype): + res = model(imgs[None]) + else: + with torch.no_grad(): + res = model(imgs[None]) + + masks = torch.sigmoid(res["conf"][..., 0]) > 0.1 + non_edge = ~depth_edge(res["local_points"][..., 2], rtol=0.03) + masks = torch.logical_and(masks, non_edge)[0] + + points_xyz = res["points"][0][masks].cpu().numpy() + points_rgb = (imgs.permute(0, 2, 3, 1)[masks].cpu().numpy() * 255.0).round().astype(np.uint8) + + _write_colmap_text( + colmap_dir=str(output_dir), + image_names=image_names, + camera_poses_wTc=res["camera_poses"][0].cpu().numpy(), + points_xyz=points_xyz, + points_rgb=points_rgb, + image_size=(width, height), + focal_length=focal_length, + ) + + +def main() -> None: + args = _parse_args() + cluster_tree = _load_cluster_tree(args.cluster_tree_path) + + loader = _build_loader(args.config_name, args.dataset_dir, args.images_root, args.max_resolution) + image_names = loader.image_filenames() + images_root = args.images_root + if images_root is None and hasattr(loader, "_images_dir"): + images_root = getattr(loader, "_images_dir") + image_paths = _resolve_image_paths(image_names, images_root) + + requested_device = args.device + if requested_device == "cuda" and not torch.cuda.is_available(): + print("CUDA requested but not available. Falling back to CPU.") + requested_device = "cpu" + device = torch.device(requested_device) + model = _setup_model(device, args.ckpt) + + output_root = Path(args.output_root) + output_root.mkdir(parents=True, exist_ok=True) + log_path = output_root / "pi3_cluster.log" + + def log(message: str) -> None: + with open(log_path, "a", encoding="utf-8") as f: + f.write(f"{message}\n") + print(message) + + for path, visibility_graph, is_leaf in _iter_clusters_with_paths(cluster_tree): + if not _should_run_cluster(path, is_leaf, args): + continue + + image_indices = sorted(visibility_graph_keys(visibility_graph)) + if len(image_indices) < args.min_images: + log(f"Skipping {path}: only {len(image_indices)} images.") + continue + if max(image_indices) >= len(image_names): + log(f"Skipping {path}: image index out of range (max={max(image_indices)}).") + continue + + cluster_image_paths = [image_paths[idx] for idx in image_indices] + missing_paths = [p for p in cluster_image_paths if not os.path.exists(p)] + if missing_paths: + log(f"Skipping {path}: missing {len(missing_paths)} images.") + continue + + cluster_image_names = [os.path.basename(image_names[idx]) for idx in image_indices] + output_paths = prepare_output_paths(output_root, path) + output_dir = output_paths.results / args.model_name + if args.skip_existing and (output_dir / "cameras.txt").exists(): + log(f"Skipping {path}: output already exists at {output_dir}.") + continue + + try: + log(f"Running Pi3 for {path} -> {output_dir}") + _run_pi3_on_cluster( + model=model, + device=device, + image_paths=cluster_image_paths, + image_names=cluster_image_names, + output_dir=output_dir, + focal_length=args.focal_length, + pixel_limit=args.pixel_limit, + ) + except Exception as exc: + log(f"Failed {path}: {exc!r}") + finally: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +if __name__ == "__main__": + main() diff --git a/pipeline/2-reconstruction/vggt/run_on_cluster.py b/pipeline/2-reconstruction/vggt/run_on_cluster.py new file mode 100644 index 000000000..8c4c30033 --- /dev/null +++ b/pipeline/2-reconstruction/vggt/run_on_cluster.py @@ -0,0 +1,1552 @@ +""" +Run VGGT on clusters defined by a saved ClusterTree pickle. + +This script mirrors the VGGT pipeline in +`pipeline/2-reconstruction/vggt/evaluation/test_co3d_cluster.py`, but it derives +cluster membership directly from a `cluster_tree.pkl` (generated by +`pipeline/1-partition/partition_metis_megaloc.py`). It writes COLMAP text outputs +under a GTSFM-style results tree, e.g. `/results/C_1/.../vggt`. +""" + +from __future__ import annotations + +import argparse +import gc +import os +import pickle +import random +import sqlite3 +import shutil +import sys +import tempfile +from pathlib import Path +from typing import Iterable, Sequence + +import hydra +import numpy as np +import torch +import torch.nn.functional as F +from hydra.utils import instantiate +from PIL import Image + +from gtsfm.common.outputs import prepare_output_paths +from gtsfm.products.visibility_graph import visibility_graph_keys +from gtsfm.utils.tree import PreOrderIter, Tree + +REPO_ROOT = Path(__file__).resolve().parents[3] +THIRDPARTY_VGGT_ROOT = REPO_ROOT / "thirdparty" / "vggt" +if THIRDPARTY_VGGT_ROOT.exists(): + sys.path.insert(0, str(THIRDPARTY_VGGT_ROOT)) +THIRDPARTY_LIGHTGLUE_ROOT = REPO_ROOT / "thirdparty" / "LightGlue" +if THIRDPARTY_LIGHTGLUE_ROOT.exists(): + sys.path.insert(0, str(THIRDPARTY_LIGHTGLUE_ROOT)) + +from vggt.models.vggt import VGGT +from vggt.utils.geometry import unproject_depth_map_to_point_map +from vggt.utils.load_fn import load_and_preprocess_images +from vggt.utils.pose_enc import pose_encoding_to_extri_intri + +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run VGGT on clusters from a saved ClusterTree.") + parser.add_argument("--cluster_tree_path", type=str, required=True, help="Path to cluster_tree.pkl") + parser.add_argument("--dataset_dir", type=str, required=True, help="Dataset root (used for loader).") + parser.add_argument("--images_root", type=str, default=None, help="Root directory for images.") + parser.add_argument( + "--output_root", + type=str, + default="results/gerrard-hall/2-reconstruction/vggt_cluster_run", + help="Base output directory (results will be in /results/...).", + ) + parser.add_argument( + "--ba_output_root", + type=str, + default=None, + help="Optional base output directory for BA results (defaults to output_root).", + ) + parser.add_argument( + "--config_name", + type=str, + default="vggt", + help="Config in gtsfm/configs to load for the image loader.", + ) + parser.add_argument( + "--max_resolution", + type=int, + default=None, + help="Optional loader max resolution override.", + ) + parser.add_argument( + "--model_path", + type=str, + default=str(THIRDPARTY_VGGT_ROOT / "weights" / "model.pt"), + help="Path to the VGGT model checkpoint.", + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility.") + parser.add_argument("--use_ba", action="store_true", default=False, help="Use BA for reconstruction.") + parser.add_argument( + "--ba_loss", + type=str, + choices=["TRIVIAL", "SOFT_L1", "CAUCHY", "HUBER"], + default="CAUCHY", + help="Robust loss type for pycolmap BA.", + ) + parser.add_argument( + "--robust_ba", + action="store_true", + default=False, + help="Enable robust BA loss settings.", + ) + parser.add_argument( + "--no_robust_ba", + action="store_false", + dest="robust_ba", + help="Disable robust BA loss settings and use non-robust BA loss.", + ) + parser.add_argument( + "--ba_loss_scale", + type=float, + default=2.0, + help="Robust loss scale for pycolmap BA.", + ) + parser.add_argument( + "--ba_refine_intrinsics", + action="store_true", + default=False, + help="Allow BA to refine intrinsics (focal/principal point/extra params).", + ) + parser.add_argument( + "--ba_use_gt_calibration", + action="store_true", + default=False, + help="Use ground-truth camera calibration from benchmark COLMAP model as BA initialization.", + ) + parser.add_argument( + "--ba_gt_calibration_dir", + type=str, + default=None, + help="Directory containing benchmark COLMAP text model (cameras.txt/images.txt/points3D.txt).", + ) + parser.add_argument( + "--ba_use_gt_pose", + action="store_true", + default=False, + help="Use ground-truth camera poses from benchmark COLMAP model as BA initialization.", + ) + parser.add_argument( + "--ba_gt_pose_dir", + type=str, + default=None, + help="Directory containing benchmark COLMAP text model (images.txt required) for GT pose initialization.", + ) + parser.add_argument( + "--ba_tracker", + type=str, + choices=["vggt", "vggsfm", "colmap"], + default="vggt", + help="Tracker/backend selection (vggt, vggsfm, or colmap). In non-BA mode, colmap runs COLMAP feature tracking.", + ) + parser.add_argument("--img_load_resolution", type=int, default=1024, help="Square load resolution for VGGT input.") + parser.add_argument("--vggt_fixed_resolution", type=int, default=518, help="VGGT internal inference resolution.") + parser.add_argument( + "--tracking_max_query_pts", + type=int, + default=256, + help="Max number of query points for VGGT tracking.", + ) + parser.add_argument( + "--tracking_query_frame_num", + type=int, + default=3, + help="Number of query frames for VGGT tracking.", + ) + parser.add_argument( + "--tracking_keypoint_extractor", + type=str, + default="aliked", + help="Keypoint extractor(s) for VGGT tracking (e.g. aliked, sp, sift, aliked+sp).", + ) + parser.add_argument( + "--conf_thres_value", type=float, default=5.0, help="Confidence threshold value for depth filtering." + ) + parser.add_argument( + "--max_reproj_error", + type=float, + default=12.0, + help="Maximum reprojection error (pixels) when filtering track observations for export.", + ) + parser.add_argument( + "--point_source", + type=str, + choices=["depth", "triangulation"], + default="depth", + help="3D point source: VGGT depth sampling or triangulation from 2D correspondences + VGGT poses.", + ) + parser.add_argument( + "--triangulation_min_views", + type=int, + default=2, + help="Minimum number of observations required to triangulate a track.", + ) + parser.add_argument( + "--save_tracking_outputs", + action="store_true", + default=False, + help="Also save raw VGGT tracking outputs (pre-BA) in COLMAP text format.", + ) + parser.add_argument("--run_leaf", action="store_true", default=True, help="Run VGGT on leaf clusters.") + parser.add_argument("--run_parent", action="store_true", default=True, help="Run VGGT on non-leaf clusters.") + parser.add_argument("--run_root", action="store_true", default=True, help="Run VGGT on the root cluster.") + parser.add_argument("--min_images", type=int, default=2, help="Skip clusters with fewer images.") + parser.add_argument( + "--no_skip_existing", + action="store_false", + dest="skip_existing", + default=True, + help="Recompute even if output already exists.", + ) + args = parser.parse_args() + if args.images_root is None: + args.images_root = os.path.join(args.dataset_dir, "images") + if not (args.run_leaf or args.run_parent or args.run_root): + args.run_leaf = True + return args + + +def _build_loader(config_name: str, dataset_dir: str, images_dir: str | None, max_resolution: int | None): + overrides: list[str] = [f"loader.dataset_dir={dataset_dir}"] + if images_dir is not None: + overrides.append(f"loader.images_dir={images_dir}") + if max_resolution is not None: + overrides.append(f"loader.max_resolution={max_resolution}") + + with hydra.initialize_config_module(config_module="gtsfm.configs", version_base=None): + cfg = hydra.compose(config_name=config_name, overrides=overrides) + return instantiate(cfg.loader) + + +def _load_cluster_tree(cluster_tree_path: str): + with open(cluster_tree_path, "rb") as f: + return pickle.load(f) + + +def _resolve_image_paths(image_names: Sequence[str], images_root: str | None) -> list[str]: + resolved_paths = [] + for name in image_names: + if os.path.isabs(name): + resolved_paths.append(name) + else: + if images_root is None: + raise ValueError("images_root is required when image filenames are relative.") + resolved_paths.append(os.path.join(images_root, name)) + return resolved_paths + + +def _iter_clusters_with_paths(cluster_tree) -> Iterable[tuple[tuple[int, ...], Sequence[tuple[int, int]], bool]]: + path_tree: Tree[tuple[tuple[int, ...], Sequence[tuple[int, int]]]] = cluster_tree.map_with_path( + lambda path, visibility_graph: (path, visibility_graph) + ) + for node in PreOrderIter(path_tree): + path, visibility_graph = node.value + yield path, visibility_graph, node.is_leaf() + + +def _should_run_cluster(path: tuple[int, ...], is_leaf: bool, args: argparse.Namespace) -> bool: + if path == () and not args.run_root: + return False + if is_leaf and args.run_leaf: + return True + if (not is_leaf) and args.run_parent: + return True + return False + + +def _log_message(log_path: Path, message: str) -> None: + with open(log_path, "a", encoding="utf-8") as log_file: + log_file.write(f"{message}\n") + print(message) + + +def farthest_point_sampling(distance_matrix: torch.Tensor, num_samples: int, most_common_frame_index: int = 0) -> list[int]: + distance_matrix = distance_matrix.clamp(min=0) + num_frames = distance_matrix.size(0) + selected_indices = [most_common_frame_index] + check_distances = distance_matrix[selected_indices] + + while len(selected_indices) < num_samples: + farthest_point = torch.argmax(check_distances) + selected_indices.append(farthest_point.item()) + check_distances = distance_matrix[farthest_point] + check_distances[selected_indices] = 0 + if len(selected_indices) == num_frames: + break + + return selected_indices + + +def generate_rank_by_dino( + images: torch.Tensor, + query_frame_num: int, + image_size: int = 518, + model_name: str = "dinov2_vitb14_reg", + device: str = "cuda", + spatial_similarity: bool = True, +) -> list[int]: + del image_size # maintained for compatibility with original signature + dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name) + dino_v2_model.eval() + dino_v2_model = dino_v2_model.to(device) + + resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1) + resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1) + images_resnet_norm = (images - resnet_mean) / resnet_std + + with torch.no_grad(): + frame_feat = dino_v2_model(images_resnet_norm, is_training=True) + + if spatial_similarity: + frame_feat = frame_feat["x_norm_patchtokens"] + frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) + frame_feat_norm = frame_feat_norm.permute(1, 0, 2) + similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)).mean(dim=0) + else: + frame_feat = frame_feat["x_norm_clstoken"] + frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) + similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) + + distance_matrix = 100 - similarity_matrix.clone() + similarity_matrix.fill_diagonal_(-100) + similarity_sum = similarity_matrix.sum(dim=1) + most_common_frame_index = torch.argmax(similarity_sum).item() + return farthest_point_sampling(distance_matrix, query_frame_num, most_common_frame_index) + + +def calculate_index_mappings(query_index: int, frame_num: int, device: str | torch.device | None = None) -> torch.Tensor: + new_order = torch.arange(frame_num) + new_order[0] = query_index + new_order[query_index] = 0 + if device is not None: + new_order = new_order.to(device) + return new_order + + +def switch_tensor_order(tensors: list[torch.Tensor | None], order: torch.Tensor, dim: int = 1) -> list[torch.Tensor | None]: + return [torch.index_select(tensor, dim, order) if tensor is not None else None for tensor in tensors] + + +def predict_track( + model: VGGT, + images: torch.Tensor, + query_points: torch.Tensor, + dtype: torch.dtype = torch.bfloat16, + use_tf32_for_track: bool = True, + iters: int = 4, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + with torch.no_grad(): + with torch.cuda.amp.autocast(dtype=dtype): + images = images[None] + aggregated_tokens_list, ps_idx = model.aggregator(images) + if not use_tf32_for_track: + track_list, vis_score, conf_score = model.track_head( + aggregated_tokens_list, images, ps_idx, query_points=query_points, iters=iters + ) + + if use_tf32_for_track: + with torch.cuda.amp.autocast(enabled=False): + track_list, vis_score, conf_score = model.track_head( + aggregated_tokens_list, images, ps_idx, query_points=query_points, iters=iters + ) + + pred_track = track_list[-1] + return pred_track.squeeze(0), vis_score.squeeze(0), conf_score.squeeze(0) + + +def initialize_feature_extractors( + max_query_num: int, det_thres: float, extractor_method: str = "aliked", device: str = "cuda" +) -> dict[str, torch.nn.Module]: + from lightglue import ALIKED, SIFT, SuperPoint + + extractors = {} + methods = extractor_method.lower().split("+") + + for method in methods: + method = method.strip() + if method == "aliked": + extractors["aliked"] = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres).to(device).eval() + elif method == "sp": + extractors["sp"] = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres).to(device).eval() + elif method == "sift": + extractors["sift"] = SIFT(max_num_keypoints=max_query_num).to(device).eval() + else: + print(f"Warning: Unknown feature extractor '{method}', ignoring.") + + if not extractors: + print(f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default.") + extractors["aliked"] = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres).to(device).eval() + + return extractors + + +def extract_keypoints(query_image: torch.Tensor, extractors: dict[str, torch.nn.Module], max_query_num: int) -> torch.Tensor: + query_points_round = None + with torch.no_grad(): + for extractor in extractors.values(): + query_points_data = extractor.extract(query_image) + extractor_points = query_points_data["keypoints"].round() + if query_points_round is not None: + query_points_round = torch.cat([query_points_round, extractor_points], dim=1) + else: + query_points_round = extractor_points + + if query_points_round.shape[1] > max_query_num: + random_point_indices = torch.randperm(query_points_round.shape[1])[:max_query_num] + query_points_round = query_points_round[:, random_point_indices, :] + return query_points_round + + +def _run_vggt_geometry(model: VGGT, images: torch.Tensor, dtype: torch.dtype) -> dict[str, torch.Tensor]: + with torch.no_grad(): + with torch.cuda.amp.autocast(dtype=dtype): + predictions = model(images) + + with torch.cuda.amp.autocast(dtype=torch.float64): + extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:]) + pred_extrinsic = extrinsic[0] + pred_intrinsic = intrinsic[0] + depth_map = predictions["depth"][0] + depth_conf = predictions["depth_conf"][0] + world_points = unproject_depth_map_to_point_map(depth_map, pred_extrinsic, pred_intrinsic) + + return { + "pred_extrinsic": pred_extrinsic, + "pred_intrinsic": pred_intrinsic, + "depth_map": depth_map, + "depth_conf": depth_conf, + "world_points": torch.from_numpy(world_points).to(images.device), + } + + +def _decode_colmap_pair_id(pair_id: int) -> tuple[int, int]: + pair_id_scale = 2147483647 + image_id2 = pair_id % pair_id_scale + image_id1 = (pair_id - image_id2) // pair_id_scale + return int(image_id1), int(image_id2) + + +def _load_colmap_db_tracks(database_path: str) -> tuple[dict[int, np.ndarray], list[tuple[int, int, np.ndarray]]]: + connection = sqlite3.connect(database_path) + try: + keypoints_by_image: dict[int, np.ndarray] = {} + for image_id, rows, cols, data in connection.execute("SELECT image_id, rows, cols, data FROM keypoints"): + if data is None or rows == 0: + continue + keypoints = np.frombuffer(data, dtype=np.float32).reshape(rows, cols)[:, :2] + keypoints_by_image[int(image_id)] = keypoints + + pair_matches: list[tuple[int, int, np.ndarray]] = [] + if any(True for _ in connection.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='two_view_geometries'")): + query = "SELECT pair_id, rows, cols, data FROM two_view_geometries" + else: + query = "SELECT pair_id, rows, cols, data FROM matches" + for pair_id, rows, cols, data in connection.execute(query): + if data is None or rows == 0: + continue + matches = np.frombuffer(data, dtype=np.uint32).reshape(rows, cols)[:, :2] + image_id1, image_id2 = _decode_colmap_pair_id(int(pair_id)) + pair_matches.append((image_id1, image_id2, matches)) + finally: + connection.close() + return keypoints_by_image, pair_matches + + +def run_colmap_tracking( + model: VGGT, + images: torch.Tensor, + image_names: Sequence[str], + dtype: torch.dtype = torch.bfloat16, + camera_type: str = "SIMPLE_PINHOLE", + apply_depth_filter: bool = True, +) -> dict: + if len(image_names) != images.shape[0]: + raise ValueError("image_names/images length mismatch for COLMAP tracking.") + assert "RADIAL" not in camera_type, "RADIAL camera is not supported yet" + + try: + import pycolmap + except ModuleNotFoundError as exc: + raise ModuleNotFoundError("pycolmap is required for colmap tracker mode.") from exc + + device = images.device + frame_num = images.shape[0] + geometry = _run_vggt_geometry(model, images, dtype) + pred_extrinsic = geometry["pred_extrinsic"] + pred_intrinsic = geometry["pred_intrinsic"] + depth_map = geometry["depth_map"] + depth_conf = geometry["depth_conf"] + world_points = geometry["world_points"] + + with tempfile.TemporaryDirectory(prefix="vggt_colmap_") as tmp_dir: + temp_root = Path(tmp_dir) + images_dir = temp_root / "images" + images_dir.mkdir(parents=True, exist_ok=True) + database_path = temp_root / "database.db" + + alias_to_idx: dict[str, int] = {} + images_np = (images.detach().cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8).transpose(0, 2, 3, 1) + for idx in range(frame_num): + alias_name = f"{idx:06d}.png" + alias_to_idx[alias_name] = idx + Image.fromarray(images_np[idx]).save(images_dir / alias_name) + + pycolmap.extract_features(str(database_path), str(images_dir)) + pycolmap.match_exhaustive(str(database_path)) + + keypoints_by_image, pair_matches = _load_colmap_db_tracks(str(database_path)) + connection = sqlite3.connect(str(database_path)) + try: + image_id_to_idx: dict[int, int] = {} + for image_id, name in connection.execute("SELECT image_id, name FROM images"): + if name in alias_to_idx: + image_id_to_idx[int(image_id)] = alias_to_idx[name] + finally: + connection.close() + + parent: dict[tuple[int, int], tuple[int, int]] = {} + + def _find(node: tuple[int, int]) -> tuple[int, int]: + if node not in parent: + parent[node] = node + while parent[node] != node: + parent[node] = parent[parent[node]] + node = parent[node] + return node + + def _union(a: tuple[int, int], b: tuple[int, int]) -> None: + root_a, root_b = _find(a), _find(b) + if root_a != root_b: + parent[root_b] = root_a + + for image_id1, image_id2, matches in pair_matches: + if image_id1 not in image_id_to_idx or image_id2 not in image_id_to_idx: + continue + for kp1, kp2 in matches: + _union((image_id1, int(kp1)), (image_id2, int(kp2))) + + components: dict[tuple[int, int], list[tuple[int, int]]] = {} + for node in parent: + root = _find(node) + components.setdefault(root, []).append(node) + + tracks_list: list[list[tuple[int, float, float]]] = [] + points3d_list: list[torch.Tensor] = [] + conf_list: list[torch.Tensor] = [] + + for nodes in components.values(): + frame_obs: dict[int, tuple[float, float]] = {} + for image_id, kp_idx in nodes: + frame_idx = image_id_to_idx.get(image_id) + if frame_idx is None: + continue + keypoints = keypoints_by_image.get(image_id) + if keypoints is None or kp_idx >= keypoints.shape[0]: + continue + if frame_idx in frame_obs: + continue + xy = keypoints[kp_idx] + frame_obs[frame_idx] = (float(xy[0]), float(xy[1])) + if len(frame_obs) < 2: + continue + + obs = sorted((fidx, xy[0], xy[1]) for fidx, xy in frame_obs.items()) + seed_frame, seed_x, seed_y = obs[0] + sx = int(np.clip(round(seed_x), 0, world_points.shape[2] - 1)) + sy = int(np.clip(round(seed_y), 0, world_points.shape[1] - 1)) + points3d_list.append(world_points[seed_frame, sy, sx]) + conf_list.append(depth_conf[seed_frame, sy, sx]) + tracks_list.append(obs) + + num_tracks = len(tracks_list) + pred_tracks = torch.zeros((frame_num, num_tracks, 2), dtype=torch.float32, device=device) + pred_vis_scores = torch.zeros((frame_num, num_tracks), dtype=torch.float32, device=device) + pred_conf_scores = torch.zeros((frame_num, num_tracks), dtype=torch.float32, device=device) + masks = torch.zeros((frame_num, num_tracks), dtype=torch.bool, device=device) + + for tidx, obs in enumerate(tracks_list): + for frame_idx, x, y in obs: + pred_tracks[frame_idx, tidx, 0] = float(x) + pred_tracks[frame_idx, tidx, 1] = float(y) + pred_vis_scores[frame_idx, tidx] = 1.0 + pred_conf_scores[frame_idx, tidx] = 1.0 + masks[frame_idx, tidx] = True + + if points3d_list: + pred_world_points = torch.stack(points3d_list, dim=0).to(device=device, dtype=torch.float32) + pred_world_points_conf = torch.stack(conf_list, dim=0).to(device=device, dtype=torch.float32) + else: + pred_world_points = torch.zeros((0, 3), dtype=torch.float32, device=device) + pred_world_points_conf = torch.zeros((0,), dtype=torch.float32, device=device) + + if apply_depth_filter: + filtered_flag = pred_world_points_conf > 1.5 + if filtered_flag.sum() > 0: + pred_world_points = pred_world_points[filtered_flag] + pred_world_points_conf = pred_world_points_conf[filtered_flag] + pred_tracks = pred_tracks[:, filtered_flag] + pred_vis_scores = pred_vis_scores[:, filtered_flag] + pred_conf_scores = pred_conf_scores[:, filtered_flag] + masks = masks[:, filtered_flag] + + _, _, H, W = images.shape + image_size = torch.tensor([W, H], dtype=torch.float32, device=device) + return { + "pred_tracks": pred_tracks, + "pred_vis_scores": pred_vis_scores, + "pred_conf_scores": pred_conf_scores, + "pred_world_points": pred_world_points, + "pred_world_points_conf": pred_world_points_conf, + "pred_extrinsic": pred_extrinsic, + "pred_intrinsic": pred_intrinsic, + "image_size": image_size, + "masks": masks, + "device": device, + "camera_type": camera_type, + "depth_map": depth_map, + "depth_conf": depth_conf, + } + + +def run_vggt_tracking( + model: VGGT, + images: torch.Tensor, + image_names: Sequence[str] | None = None, + dtype: torch.dtype = torch.bfloat16, + max_query_num: int = 2048, + det_thres: float = 0.005, + query_frame_num: int = 3, + extractor_method: str = "aliked+sp+sift", + camera_type: str = "SIMPLE_PINHOLE", + apply_depth_filter: bool = True, +) -> dict: + del image_names + assert "RADIAL" not in camera_type, "RADIAL camera is not supported yet" + + device = images.device + frame_num = images.shape[0] + + geometry = _run_vggt_geometry(model, images, dtype) + pred_extrinsic = geometry["pred_extrinsic"] + pred_intrinsic = geometry["pred_intrinsic"] + depth_map = geometry["depth_map"] + depth_conf = geometry["depth_conf"] + world_points = geometry["world_points"] + + query_frame_indexes = generate_rank_by_dino( + images, query_frame_num, image_size=518, model_name="dinov2_vitb14_reg", device=device, spatial_similarity=False + ) + if 0 in query_frame_indexes: + query_frame_indexes.remove(0) + query_frame_indexes = [0, *query_frame_indexes] + + world_points_conf = depth_conf.to(device) + torch.cuda.empty_cache() + + pred_tracks = [] + pred_vis_scores = [] + pred_conf_scores = [] + pred_world_points = [] + pred_world_points_conf = [] + + extractors = initialize_feature_extractors(max_query_num, det_thres, extractor_method, str(device)) + + for query_index in query_frame_indexes: + query_image = images[query_index] + query_points_round = extract_keypoints(query_image, extractors, max_query_num) + + reorder_index = calculate_index_mappings(query_index, frame_num, device=device) + reorder_images = switch_tensor_order([images], reorder_index, dim=0)[0] + reorder_tracks, reorder_vis_score, reorder_conf_score = predict_track( + model, reorder_images, query_points_round, dtype=dtype, use_tf32_for_track=True, iters=4 + ) + pred_track, pred_vis, pred_score = switch_tensor_order( + [reorder_tracks, reorder_vis_score, reorder_conf_score], reorder_index, dim=0 + ) + pred_tracks.append(pred_track) + pred_vis_scores.append(pred_vis) + pred_conf_scores.append(pred_score) + + query_points_round_long = query_points_round.squeeze(0).long() + query_world_points = world_points[query_index][query_points_round_long[:, 1], query_points_round_long[:, 0]] + query_world_points_conf = world_points_conf[query_index][query_points_round_long[:, 1], query_points_round_long[:, 0]] + pred_world_points.append(query_world_points) + pred_world_points_conf.append(query_world_points_conf) + + pred_tracks = torch.cat(pred_tracks, dim=1) + pred_vis_scores = torch.cat(pred_vis_scores, dim=1) + pred_conf_scores = torch.cat(pred_conf_scores, dim=1) + pred_world_points = torch.cat(pred_world_points, dim=0) + pred_world_points_conf = torch.cat(pred_world_points_conf, dim=0) + + if apply_depth_filter: + filtered_flag = pred_world_points_conf > 1.5 + if filtered_flag.sum() > max_query_num // 2: + pred_world_points = pred_world_points[filtered_flag] + pred_world_points_conf = pred_world_points_conf[filtered_flag] + pred_tracks = pred_tracks[:, filtered_flag] + pred_vis_scores = pred_vis_scores[:, filtered_flag] + pred_conf_scores = pred_conf_scores[:, filtered_flag] + + torch.cuda.empty_cache() + _, _, H, W = images.shape + image_size = torch.tensor([W, H], dtype=pred_tracks.dtype, device=device) + masks = torch.logical_and(pred_vis_scores > 0.05, pred_conf_scores > 0.2) + + return { + "pred_tracks": pred_tracks, + "pred_vis_scores": pred_vis_scores, + "pred_conf_scores": pred_conf_scores, + "pred_world_points": pred_world_points, + "pred_world_points_conf": pred_world_points_conf, + "pred_extrinsic": pred_extrinsic, + "pred_intrinsic": pred_intrinsic, + "image_size": image_size, + "masks": masks, + "device": device, + "camera_type": camera_type, + "depth_map": depth_map, + "depth_conf": depth_conf, + } + + +def setup_model(args: argparse.Namespace): + np.random.seed(args.seed) + torch.manual_seed(args.seed) + random.seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + model = VGGT() + model.load_state_dict(torch.load(args.model_path)) + model.eval() + model = model.to(device) + return model, device, dtype + + +def _rotmat_to_colmap_qvec(R: np.ndarray) -> np.ndarray: + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = ( + np.array( + [ + [Rxx - Ryy - Rzz, 0.0, 0.0, 0.0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0.0, 0.0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0.0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], + ], + dtype=np.float64, + ) + / 3.0 + ) + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + + +def _camera_params_from_K(K: np.ndarray, camera_model: str) -> list[float]: + fx = float(K[0, 0]) + fy = float(K[1, 1]) + cx = float(K[0, 2]) + cy = float(K[1, 2]) + if camera_model == "PINHOLE": + return [fx, fy, cx, cy] + if camera_model == "SIMPLE_PINHOLE": + return [(fx + fy) * 0.5, cx, cy] + raise ValueError(f"Unsupported camera model for export: {camera_model}") + + +def _project_points( + points3d: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray, eps: float = 1e-8 +) -> tuple[np.ndarray, np.ndarray]: + points_h = np.concatenate([points3d, np.ones((points3d.shape[0], 1), dtype=points3d.dtype)], axis=1) + cam_points = np.einsum("sij,nj->sni", extrinsics, points_h) + z = cam_points[:, :, 2] + x = cam_points[:, :, 0] / np.maximum(z, eps) + y = cam_points[:, :, 1] / np.maximum(z, eps) + ones = np.ones_like(x) + normalized = np.stack([x, y, ones], axis=-1) + pixels = np.einsum("sij,snj->sni", intrinsics, normalized)[:, :, :2] + return pixels, z + + +def _triangulate_track_point( + observations: np.ndarray, projection_matrices: np.ndarray +) -> np.ndarray | None: + rows = [] + for (x, y), proj in zip(observations, projection_matrices): + rows.append(x * proj[2, :] - proj[0, :]) + rows.append(y * proj[2, :] - proj[1, :]) + A = np.asarray(rows, dtype=np.float64) + if A.shape[0] < 4: + return None + _, _, vt = np.linalg.svd(A, full_matrices=False) + homog = vt[-1] + if abs(float(homog[3])) < 1e-12: + return None + return (homog[:3] / homog[3]).astype(np.float32) + + +def _replace_points_with_triangulation( + tracking_outputs: dict, + min_views: int, + max_reproj_error: float, +) -> dict: + tracks = tracking_outputs["pred_tracks"].detach().cpu().numpy() + masks = tracking_outputs["masks"].detach().cpu().numpy().astype(bool) + intrinsics = tracking_outputs["pred_intrinsic"].detach().cpu().numpy() + extrinsics = tracking_outputs["pred_extrinsic"].detach().cpu().numpy() + + num_frames, num_tracks, _ = tracks.shape + if num_tracks == 0: + return tracking_outputs + + projection_matrices = np.matmul(intrinsics, extrinsics[:, :3, :]) + triangulated_points: list[np.ndarray] = [] + valid_track_indices: list[int] = [] + + for track_idx in range(num_tracks): + obs_mask = masks[:, track_idx] + obs_indices = np.where(obs_mask)[0] + if obs_indices.size < min_views: + continue + + observations = tracks[obs_indices, track_idx, :] + point3d = _triangulate_track_point(observations, projection_matrices[obs_indices]) + if point3d is None: + continue + + point3d_batch = point3d.reshape(1, 3) + projected_xy, projected_z = _project_points(point3d_batch, extrinsics[obs_indices], intrinsics[obs_indices]) + reproj_err = np.linalg.norm(projected_xy[:, 0, :] - observations, axis=1) + if np.count_nonzero(np.logical_and(projected_z[:, 0] > 0, reproj_err < max_reproj_error)) < min_views: + continue + + triangulated_points.append(point3d) + valid_track_indices.append(track_idx) + + if not valid_track_indices: + device = tracking_outputs["pred_tracks"].device + tracking_outputs["pred_tracks"] = tracking_outputs["pred_tracks"][:, :0] + tracking_outputs["pred_vis_scores"] = tracking_outputs["pred_vis_scores"][:, :0] + tracking_outputs["pred_conf_scores"] = tracking_outputs["pred_conf_scores"][:, :0] + tracking_outputs["masks"] = tracking_outputs["masks"][:, :0] + tracking_outputs["pred_world_points"] = torch.zeros((0, 3), dtype=torch.float32, device=device) + tracking_outputs["pred_world_points_conf"] = torch.zeros((0,), dtype=torch.float32, device=device) + return tracking_outputs + + keep_idx = torch.tensor(valid_track_indices, dtype=torch.long, device=tracking_outputs["pred_tracks"].device) + tracking_outputs["pred_tracks"] = tracking_outputs["pred_tracks"].index_select(1, keep_idx) + tracking_outputs["pred_vis_scores"] = tracking_outputs["pred_vis_scores"].index_select(1, keep_idx) + tracking_outputs["pred_conf_scores"] = tracking_outputs["pred_conf_scores"].index_select(1, keep_idx) + tracking_outputs["masks"] = tracking_outputs["masks"].index_select(1, keep_idx) + tracking_outputs["pred_world_points"] = torch.tensor( + np.asarray(triangulated_points, dtype=np.float32), + dtype=torch.float32, + device=tracking_outputs["pred_tracks"].device, + ) + tracking_outputs["pred_world_points_conf"] = torch.ones( + (len(valid_track_indices),), dtype=torch.float32, device=tracking_outputs["pred_tracks"].device + ) + return tracking_outputs + + +def _export_tracking_to_colmap_text( + tracking_outputs: dict, + image_names: Sequence[str], + out_dir: str | Path, + images_tensor: torch.Tensor, + shared_camera: bool = False, + min_track_length: int = 2, + max_reproj_error: float = 12.0, +) -> None: + os.makedirs(out_dir, exist_ok=True) + + tracks = tracking_outputs["pred_tracks"].detach().cpu().numpy() + masks = tracking_outputs["masks"].detach().cpu().numpy().astype(bool) + points3d = tracking_outputs["pred_world_points"].detach().cpu().numpy() + intrinsics = tracking_outputs["pred_intrinsic"].detach().cpu().numpy() + extrinsics = tracking_outputs["pred_extrinsic"].detach().cpu().numpy() + + S, N, _ = tracks.shape + if S != len(image_names): + raise ValueError(f"Mismatch: {S} frames in tracking output but {len(image_names)} image names provided.") + + image_size = tracking_outputs["image_size"].detach().cpu().numpy() + width = int(round(float(image_size[0]))) + height = int(round(float(image_size[1]))) + camera_model = tracking_outputs.get("camera_type", "SIMPLE_PINHOLE") + if camera_model not in {"SIMPLE_PINHOLE", "PINHOLE"}: + camera_model = "SIMPLE_PINHOLE" + + projected_xy, projected_z = _project_points(points3d, extrinsics, intrinsics) + reproj_error = np.linalg.norm(projected_xy - tracks, axis=-1) + reproj_mask = np.logical_and(projected_z > 0, reproj_error < max_reproj_error) + valid_observations = np.logical_and(masks, reproj_mask) + + obs_counts = valid_observations.sum(axis=0) + valid_point_mask = obs_counts >= min_track_length + point_id_map = np.full(N, -1, dtype=np.int64) + valid_point_indices = np.where(valid_point_mask)[0] + for idx, point_idx in enumerate(valid_point_indices): + point_id_map[point_idx] = idx + 1 + + images_np = (images_tensor.detach().cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8) + images_np = images_np.transpose(0, 2, 3, 1) + + image_observations = [[] for _ in range(S)] + point_tracks = {int(point_id_map[p]): [] for p in valid_point_indices} + point_rgb = {} + + for frame_idx in range(S): + visible_points = np.where(valid_observations[frame_idx])[0] + for point_idx in visible_points: + point3d_id = int(point_id_map[point_idx]) + if point3d_id < 0: + continue + x, y = tracks[frame_idx, point_idx] + p2d_idx = len(image_observations[frame_idx]) + image_observations[frame_idx].append((float(x), float(y), point3d_id)) + point_tracks[point3d_id].append((frame_idx + 1, p2d_idx)) + if point3d_id not in point_rgb: + u = int(np.clip(round(float(x)), 0, width - 1)) + v = int(np.clip(round(float(y)), 0, height - 1)) + rgb = images_np[frame_idx, v, u] + point_rgb[point3d_id] = (int(rgb[0]), int(rgb[1]), int(rgb[2])) + + with open(os.path.join(out_dir, "cameras.txt"), "w", encoding="utf-8") as f: + num_cameras = 1 if shared_camera else S + f.write("# Camera list with one line of data per camera:\n") + f.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n") + f.write(f"# Number of cameras: {num_cameras}\n") + if shared_camera: + params = _camera_params_from_K(intrinsics[0], camera_model) + f.write(f"1 {camera_model} {width} {height} {' '.join(map(str, params))}\n") + else: + for frame_idx in range(S): + params = _camera_params_from_K(intrinsics[frame_idx], camera_model) + camera_id = frame_idx + 1 + f.write(f"{camera_id} {camera_model} {width} {height} {' '.join(map(str, params))}\n") + + with open(os.path.join(out_dir, "images.txt"), "w", encoding="utf-8") as f: + f.write("# Image list with two lines of data per image:\n") + f.write("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n") + f.write("# POINTS2D[] as (X, Y, POINT3D_ID)\n") + f.write(f"# Number of images: {S}\n") + for frame_idx in range(S): + image_id = frame_idx + 1 + camera_id = 1 if shared_camera else image_id + R = extrinsics[frame_idx, :3, :3] + t = extrinsics[frame_idx, :3, 3] + qvec = _rotmat_to_colmap_qvec(R) + name = image_names[frame_idx] + f.write( + f"{image_id} {qvec[0]} {qvec[1]} {qvec[2]} {qvec[3]} " + f"{t[0]} {t[1]} {t[2]} {camera_id} {name}\n" + ) + p2d_tokens = [] + for x, y, point3d_id in image_observations[frame_idx]: + p2d_tokens.extend([str(x), str(y), str(point3d_id)]) + f.write((" ".join(p2d_tokens) + "\n") if p2d_tokens else "\n") + + with open(os.path.join(out_dir, "points3D.txt"), "w", encoding="utf-8") as f: + f.write("# 3D point list with one line of data per point:\n") + f.write("# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n") + f.write(f"# Number of points: {len(valid_point_indices)}\n") + for point_idx in valid_point_indices: + point3d_id = int(point_id_map[point_idx]) + xyz = points3d[point_idx] + rgb = point_rgb.get(point3d_id, (255, 255, 255)) + track_tokens = [] + for image_id, p2d_idx in point_tracks[point3d_id]: + track_tokens.extend([str(image_id), str(p2d_idx)]) + f.write( + f"{point3d_id} {xyz[0]} {xyz[1]} {xyz[2]} {rgb[0]} {rgb[1]} {rgb[2]} 1.0 " + + " ".join(track_tokens) + + "\n" + ) + + +def run_vggt_reconstruction( + args: argparse.Namespace, + model: VGGT, + device: str, + dtype: torch.dtype, + image_path_list: Sequence[str], + output_dir: str, + image_name_list: Sequence[str] | None = None, + tracker_backend: str = "vggt", +) -> bool: + if len(image_path_list) == 0: + raise ValueError("No images provided to VGGT.") + if image_name_list is None: + image_name_list = [os.path.basename(path) for path in image_path_list] + if next(model.parameters()).device.type != device: + model.to(device) + + images = load_and_preprocess_images(image_path_list).to(device=device) + use_depth_points = args.point_source == "depth" + if tracker_backend == "colmap": + tracking_outputs = run_colmap_tracking( + model, + images, + image_names=list(image_name_list), + dtype=dtype, + apply_depth_filter=use_depth_points, + ) + else: + tracking_outputs = run_vggt_tracking( + model, + images, + image_names=list(image_name_list), + dtype=dtype, + max_query_num=args.tracking_max_query_pts, + query_frame_num=args.tracking_query_frame_num, + extractor_method=args.tracking_keypoint_extractor, + apply_depth_filter=use_depth_points, + ) + + if args.point_source == "triangulation": + tracking_outputs = _replace_points_with_triangulation( + tracking_outputs=tracking_outputs, + min_views=args.triangulation_min_views, + max_reproj_error=args.max_reproj_error, + ) + + os.makedirs(output_dir, exist_ok=True) + _export_tracking_to_colmap_text( + tracking_outputs=tracking_outputs, + image_names=image_name_list, + out_dir=output_dir, + images_tensor=images, + shared_camera=False, + min_track_length=2, + max_reproj_error=args.max_reproj_error, + ) + return True + + +def _promote_tracking_writer_outputs(output_dir: Path, log_path: Path) -> None: + """Overwrite top-level COLMAP txt files with tracking_outputs writer results when available.""" + tracking_dir = output_dir / "tracking_outputs" + required = ("cameras.txt", "images.txt", "points3D.txt") + missing = [name for name in required if not (tracking_dir / name).exists()] + if missing: + _log_message(log_path, f"Tracking writer outputs missing in {tracking_dir}: {missing}. Keeping existing files.") + return + + for name in required: + shutil.copy2(tracking_dir / name, output_dir / name) + _log_message(log_path, f"Promoted tracking writer outputs to {output_dir}.") + + +def _parse_colmap_cameras_txt(path: Path) -> dict[int, tuple[str, int, int, list[float]]]: + cameras: dict[int, tuple[str, int, int, list[float]]] = {} + with open(path, "r", encoding="utf-8") as f: + for raw_line in f: + line = raw_line.strip() + if not line or line.startswith("#"): + continue + parts = line.split() + if len(parts) < 5: + continue + camera_id = int(parts[0]) + model = parts[1] + width = int(parts[2]) + height = int(parts[3]) + params = [float(v) for v in parts[4:]] + cameras[camera_id] = (model, width, height, params) + return cameras + + +def _parse_colmap_images_name_to_camera(path: Path) -> dict[str, int]: + name_to_camera: dict[str, int] = {} + with open(path, "r", encoding="utf-8") as f: + lines = f.readlines() + + non_comment_idx = 0 + for raw_line in lines: + line = raw_line.strip() + if not line or line.startswith("#"): + continue + if non_comment_idx % 2 == 0: + parts = line.split() + if len(parts) >= 10: + camera_id = int(parts[8]) + image_name = " ".join(parts[9:]) + name_to_camera[image_name] = camera_id + non_comment_idx += 1 + return name_to_camera + + +def _parse_colmap_images_pose(path: Path) -> dict[str, tuple[float, float, float, float, float, float, float]]: + name_to_pose: dict[str, tuple[float, float, float, float, float, float, float]] = {} + with open(path, "r", encoding="utf-8") as f: + lines = f.readlines() + + non_comment_idx = 0 + for raw_line in lines: + line = raw_line.strip() + if not line or line.startswith("#"): + continue + if non_comment_idx % 2 == 0: + parts = line.split() + if len(parts) >= 10: + image_name = " ".join(parts[9:]) + pose = tuple(float(v) for v in parts[1:8]) + name_to_pose[image_name] = pose + non_comment_idx += 1 + return name_to_pose + + +def _apply_gt_pose_to_input_model(input_dir: Path, gt_pose_dir: Path, log_path: Path) -> Path: + required = ("cameras.txt", "images.txt", "points3D.txt") + missing_input = [name for name in required if not (input_dir / name).exists()] + if missing_input: + _log_message(log_path, f"GT pose override skipped: missing {missing_input} in {input_dir}.") + return input_dir + if not (gt_pose_dir / "images.txt").exists(): + _log_message(log_path, f"GT pose override skipped: missing images.txt in {gt_pose_dir}.") + return input_dir + + gt_name_to_pose = _parse_colmap_images_pose(gt_pose_dir / "images.txt") + gt_basename_to_pose: dict[str, tuple[float, float, float, float, float, float, float]] = {} + for name, pose in gt_name_to_pose.items(): + base = os.path.basename(name) + if base not in gt_basename_to_pose: + gt_basename_to_pose[base] = pose + + with open(input_dir / "images.txt", "r", encoding="utf-8") as f: + lines = f.readlines() + + out_lines: list[str] = [] + updated_count = 0 + non_comment_idx = 0 + for raw_line in lines: + line = raw_line.rstrip("\n") + stripped = line.strip() + if not stripped or stripped.startswith("#"): + out_lines.append(raw_line) + continue + + if non_comment_idx % 2 == 0: + parts = stripped.split() + if len(parts) >= 10: + image_name = " ".join(parts[9:]) + pose = gt_name_to_pose.get(image_name) + if pose is None: + pose = gt_basename_to_pose.get(os.path.basename(image_name)) + if pose is not None: + parts[1:8] = [str(v) for v in pose] + raw_line = " ".join(parts) + "\n" + updated_count += 1 + out_lines.append(raw_line) + non_comment_idx += 1 + + if updated_count == 0: + _log_message(log_path, f"GT pose override found no matching images between {input_dir} and {gt_pose_dir}.") + return input_dir + + temp_dir = Path(tempfile.mkdtemp(prefix="ba_gt_pose_")) + shutil.copy2(input_dir / "cameras.txt", temp_dir / "cameras.txt") + shutil.copy2(input_dir / "points3D.txt", temp_dir / "points3D.txt") + with open(temp_dir / "images.txt", "w", encoding="utf-8") as f: + f.writelines(out_lines) + _log_message(log_path, f"Applied GT poses to {updated_count} images using {gt_pose_dir}. Camera intrinsics untouched.") + return temp_dir + + +def _write_colmap_cameras_txt(path: Path, cameras: dict[int, tuple[str, int, int, list[float]]]) -> None: + with open(path, "w", encoding="utf-8") as f: + f.write("# Camera list with one line of data per camera:\n") + f.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n") + f.write(f"# Number of cameras: {len(cameras)}\n") + for camera_id in sorted(cameras): + model, width, height, params = cameras[camera_id] + f.write(f"{camera_id} {model} {width} {height} {' '.join(map(str, params))}\n") + + +def _extract_intrinsics(model: str, params: list[float]) -> tuple[float, float, float, float]: + if model == "SIMPLE_PINHOLE": + f, cx, cy = params[:3] + return float(f), float(f), float(cx), float(cy) + if model == "PINHOLE": + fx, fy, cx, cy = params[:4] + return float(fx), float(fy), float(cx), float(cy) + if model in {"SIMPLE_RADIAL", "SIMPLE_RADIAL_FISHEYE"}: + f, cx, cy = params[:3] + return float(f), float(f), float(cx), float(cy) + if len(params) >= 4: + fx, fy, cx, cy = params[:4] + return float(fx), float(fy), float(cx), float(cy) + raise ValueError(f"Unsupported camera model/params for intrinsic extraction: {model}, params={params}") + + +def _convert_scaled_intrinsics( + src_model: str, + src_params: list[float], + src_width: int, + src_height: int, + dst_model: str, + dst_params: list[float], + dst_width: int, + dst_height: int, +) -> list[float]: + fx, fy, cx, cy = _extract_intrinsics(src_model, src_params) + sx = float(dst_width) / float(max(src_width, 1)) + sy = float(dst_height) / float(max(src_height, 1)) + fx *= sx + fy *= sy + cx *= sx + cy *= sy + + out_params = list(dst_params) + if dst_model == "SIMPLE_PINHOLE": + f = 0.5 * (fx + fy) + if len(out_params) < 3: + out_params = [0.0, 0.0, 0.0] + out_params[0], out_params[1], out_params[2] = float(f), float(cx), float(cy) + return out_params + if dst_model in {"PINHOLE", "OPENCV", "FULL_OPENCV", "OPENCV_FISHEYE"}: + if len(out_params) < 4: + out_params = [0.0, 0.0, 0.0, 0.0] + out_params[0], out_params[1], out_params[2], out_params[3] = float(fx), float(fy), float(cx), float(cy) + return out_params + if dst_model in {"SIMPLE_RADIAL", "SIMPLE_RADIAL_FISHEYE"}: + f = 0.5 * (fx + fy) + if len(out_params) < 3: + out_params = [0.0, 0.0, 0.0] + out_params[0], out_params[1], out_params[2] = float(f), float(cx), float(cy) + return out_params + raise ValueError(f"Unsupported target camera model for calibration override: {dst_model}") + + +def _apply_gt_calibration_to_input_model(input_dir: Path, gt_calibration_dir: Path, log_path: Path) -> Path: + required = ("cameras.txt", "images.txt", "points3D.txt") + missing_input = [name for name in required if not (input_dir / name).exists()] + if missing_input: + _log_message(log_path, f"GT calibration override skipped: missing {missing_input} in {input_dir}.") + return input_dir + + missing_gt = [name for name in ("cameras.txt", "images.txt") if not (gt_calibration_dir / name).exists()] + if missing_gt: + _log_message(log_path, f"GT calibration override skipped: missing {missing_gt} in {gt_calibration_dir}.") + return input_dir + + input_cameras = _parse_colmap_cameras_txt(input_dir / "cameras.txt") + input_name_to_cam = _parse_colmap_images_name_to_camera(input_dir / "images.txt") + gt_cameras = _parse_colmap_cameras_txt(gt_calibration_dir / "cameras.txt") + gt_name_to_cam = _parse_colmap_images_name_to_camera(gt_calibration_dir / "images.txt") + + gt_basename_to_cam: dict[str, int] = {} + for name, cam_id in gt_name_to_cam.items(): + base = os.path.basename(name) + if base not in gt_basename_to_cam: + gt_basename_to_cam[base] = cam_id + + camera_to_image_names: dict[int, list[str]] = {} + for image_name, camera_id in input_name_to_cam.items(): + camera_to_image_names.setdefault(camera_id, []).append(image_name) + + updated_cameras = dict(input_cameras) + updated_count = 0 + for camera_id, (dst_model, dst_w, dst_h, dst_params) in input_cameras.items(): + image_names = camera_to_image_names.get(camera_id, []) + gt_cam_id = None + for image_name in image_names: + if image_name in gt_name_to_cam: + gt_cam_id = gt_name_to_cam[image_name] + break + base = os.path.basename(image_name) + if base in gt_basename_to_cam: + gt_cam_id = gt_basename_to_cam[base] + break + if gt_cam_id is None or gt_cam_id not in gt_cameras: + continue + + src_model, src_w, src_h, src_params = gt_cameras[gt_cam_id] + try: + new_params = _convert_scaled_intrinsics( + src_model=src_model, + src_params=src_params, + src_width=src_w, + src_height=src_h, + dst_model=dst_model, + dst_params=dst_params, + dst_width=dst_w, + dst_height=dst_h, + ) + except ValueError as exc: + _log_message(log_path, f"GT calibration override skipped for camera {camera_id}: {exc}") + continue + + updated_cameras[camera_id] = (dst_model, dst_w, dst_h, new_params) + updated_count += 1 + + if updated_count == 0: + _log_message(log_path, f"GT calibration override found no matching cameras between {input_dir} and {gt_calibration_dir}.") + return input_dir + + temp_dir = Path(tempfile.mkdtemp(prefix="ba_gt_calibration_")) + shutil.copy2(input_dir / "images.txt", temp_dir / "images.txt") + shutil.copy2(input_dir / "points3D.txt", temp_dir / "points3D.txt") + _write_colmap_cameras_txt(temp_dir / "cameras.txt", updated_cameras) + _log_message(log_path, f"Applied GT calibration to {updated_count} cameras using {gt_calibration_dir}.") + return temp_dir + + +def _run_ba_on_saved_reconstruction( + input_dir: Path, + ba_output_dir: Path, + log_path: Path, + skip_existing: bool, + robust_ba: bool, + ba_loss: str, + ba_loss_scale: float, + ba_refine_intrinsics: bool, + ba_use_gt_calibration: bool, + ba_gt_calibration_dir: str | None, + ba_use_gt_pose: bool, + ba_gt_pose_dir: str | None, +) -> bool: + if skip_existing and (ba_output_dir / "cameras.txt").exists(): + _log_message(log_path, f"Skipping BA: output already exists at {ba_output_dir}.") + return True + + required = ("cameras.txt", "images.txt", "points3D.txt") + missing = [name for name in required if not (input_dir / name).exists()] + if missing: + _log_message(log_path, f"Skipping BA: missing {missing} in {input_dir}.") + return False + + try: + import pycolmap + except ModuleNotFoundError as exc: + raise ModuleNotFoundError("pycolmap is required for BA mode.") from exc + + ba_input_dir = input_dir + temp_ba_input_dirs: list[Path] = [] + if ba_use_gt_calibration: + if ba_gt_calibration_dir is None: + _log_message(log_path, "GT calibration requested but --ba_gt_calibration_dir is not set; using original calibration.") + else: + ba_input_dir = _apply_gt_calibration_to_input_model(input_dir, Path(ba_gt_calibration_dir), log_path) + if ba_input_dir != input_dir: + temp_ba_input_dirs.append(ba_input_dir) + + if ba_use_gt_pose: + if ba_gt_pose_dir is None: + _log_message(log_path, "GT pose requested but --ba_gt_pose_dir is not set; using original poses.") + else: + prev_input_dir = ba_input_dir + ba_input_dir = _apply_gt_pose_to_input_model(ba_input_dir, Path(ba_gt_pose_dir), log_path) + if ba_input_dir != prev_input_dir: + temp_ba_input_dirs.append(ba_input_dir) + + try: + reconstruction = pycolmap.Reconstruction(str(ba_input_dir)) + if reconstruction.num_images() == 0: + _log_message(log_path, f"Skipping BA: empty reconstruction at {input_dir}.") + return False + + ba_options = pycolmap.BundleAdjustmentOptions() + if hasattr(ba_options, "loss_function_type"): + selected_loss = ba_loss if robust_ba else "TRIVIAL" + if hasattr(pycolmap, "LossFunctionType") and hasattr(pycolmap.LossFunctionType, selected_loss): + ba_options.loss_function_type = getattr(pycolmap.LossFunctionType, selected_loss) + else: + ba_options.loss_function_type = selected_loss + if robust_ba and hasattr(ba_options, "loss_function_scale"): + ba_options.loss_function_scale = float(ba_loss_scale) + + for attr_name in ("refine_focal_length", "refine_principal_point", "refine_extra_params"): + if hasattr(ba_options, attr_name): + setattr(ba_options, attr_name, bool(ba_refine_intrinsics)) + + if hasattr(ba_options, "solver_options"): + solver_options = ba_options.solver_options + if hasattr(solver_options, "max_num_iterations"): + solver_options.max_num_iterations = 50 + if hasattr(solver_options, "function_tolerance"): + solver_options.function_tolerance = 1e-6 + if hasattr(solver_options, "gradient_tolerance"): + solver_options.gradient_tolerance = 1e-10 + + pycolmap.bundle_adjustment(reconstruction, ba_options) + + ba_output_dir.mkdir(parents=True, exist_ok=True) + reconstruction.write_text(str(ba_output_dir)) + _log_message(log_path, f"Saved BA reconstruction to {ba_output_dir}.") + return True + finally: + for temp_dir in temp_ba_input_dirs: + shutil.rmtree(temp_dir, ignore_errors=True) + + +def _get_tracker_output_subdir(args: argparse.Namespace) -> str: + if not args.use_ba: + return "colmap" if args.ba_tracker == "colmap" else "vggt" + return args.ba_tracker + + +def _get_tracking_backend(args: argparse.Namespace) -> str: + if args.ba_tracker == "colmap": + return "colmap" + return "vggt" + + +def _cleanup_after_cluster() -> None: + gc.collect() + try: + import torch + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except Exception: + # Best-effort cleanup: skip torch-specific cleanup if unavailable. + pass + + +def main() -> None: + args = _parse_args() + if not args.use_ba: + # Force tracker execution for every non-BA cluster run. + args.save_tracking_outputs = True + else: + default_gt_dir = Path(args.dataset_dir) / "dslr_calibration_undistorted" + if args.ba_use_gt_calibration and args.ba_gt_calibration_dir is None and default_gt_dir.exists(): + args.ba_gt_calibration_dir = str(default_gt_dir) + if args.ba_use_gt_pose and args.ba_gt_pose_dir is None and default_gt_dir.exists(): + args.ba_gt_pose_dir = str(default_gt_dir) + + cluster_tree = _load_cluster_tree(args.cluster_tree_path) + if cluster_tree is None: + raise ValueError(f"cluster_tree.pkl was empty or invalid: {args.cluster_tree_path}") + + image_names = image_paths = None + if not args.use_ba: + loader = _build_loader(args.config_name, args.dataset_dir, args.images_root, args.max_resolution) + image_names = loader.image_filenames() + images_root = args.images_root + if images_root is None and hasattr(loader, "_images_dir"): + images_root = getattr(loader, "_images_dir") + image_paths = _resolve_image_paths(image_names, images_root) + + output_root = Path(args.output_root) + ba_output_root = Path(args.ba_output_root) if args.use_ba and args.ba_output_root else output_root + log_path = output_root / "vggt_cluster.log" + output_root.mkdir(parents=True, exist_ok=True) + + model = device = dtype = None + if not args.use_ba: + model, device, dtype = setup_model(args) + _log_message( + log_path, + f"Reconstruction mode: tracker={_get_tracking_backend(args)}, point_source={args.point_source}", + ) + + if args.use_ba: + _log_message( + log_path, + ( + f"BA mode (tracker={args.ba_tracker}): reading reconstructions from {output_root} " + f"and writing optimized models to {ba_output_root}." + ), + ) + + for path, visibility_graph, is_leaf in _iter_clusters_with_paths(cluster_tree): + if not _should_run_cluster(path, is_leaf, args): + continue + + image_indices = sorted(visibility_graph_keys(visibility_graph)) + if len(image_indices) < args.min_images: + _log_message(log_path, f"Skipping {path}: only {len(image_indices)} images.") + continue + + output_paths = prepare_output_paths(output_root, path) + tracker_output_subdir = _get_tracker_output_subdir(args) + output_dir = output_paths.results / tracker_output_subdir + if args.use_ba: + ba_output_paths = prepare_output_paths(ba_output_root, path) + ba_output_dir = ba_output_paths.results / tracker_output_subdir + else: + ba_output_dir = None + + try: + if args.use_ba: + _log_message( + log_path, + ( + f"Running BA for {path}: {output_dir} -> {ba_output_dir} " + f"(robust_ba={args.robust_ba}, loss={args.ba_loss}, scale={args.ba_loss_scale}, " + f"refine_intrinsics={args.ba_refine_intrinsics}, " + f"use_gt_calibration={args.ba_use_gt_calibration}, " + f"use_gt_pose={args.ba_use_gt_pose})" + ), + ) + _run_ba_on_saved_reconstruction( + output_dir, + ba_output_dir, + log_path, + args.skip_existing, + args.robust_ba, + args.ba_loss, + args.ba_loss_scale, + args.ba_refine_intrinsics, + args.ba_use_gt_calibration, + args.ba_gt_calibration_dir, + args.ba_use_gt_pose, + args.ba_gt_pose_dir, + ) + continue + + if max(image_indices) >= len(image_names): + _log_message(log_path, f"Skipping {path}: image index out of range (max={max(image_indices)}).") + continue + + cluster_image_names = [image_names[idx] for idx in image_indices] + cluster_image_paths = [image_paths[idx] for idx in image_indices] + missing_paths = [p for p in cluster_image_paths if not os.path.exists(p)] + if missing_paths: + _log_message(log_path, f"Skipping {path}: missing {len(missing_paths)} images.") + continue + + run_output_dir = output_dir + if args.skip_existing and (run_output_dir / "cameras.txt").exists(): + _log_message(log_path, f"Skipping {path}: output already exists at {run_output_dir}.") + continue + + _log_message(log_path, f"Running VGGT for {path} -> {run_output_dir}") + run_vggt_reconstruction( + args, + model, + device, + dtype, + cluster_image_paths, + str(run_output_dir), + image_name_list=cluster_image_names, + tracker_backend=_get_tracking_backend(args), + ) + _promote_tracking_writer_outputs(run_output_dir, log_path) + except Exception as exc: + _log_message(log_path, f"Failed {path}: {exc!r}") + finally: + _cleanup_after_cluster() + + +if __name__ == "__main__": + main() diff --git a/pipeline/3-cluster_ba/cluster_ba.py b/pipeline/3-cluster_ba/cluster_ba.py new file mode 100644 index 000000000..af405f6ef --- /dev/null +++ b/pipeline/3-cluster_ba/cluster_ba.py @@ -0,0 +1,260 @@ +""" +Run VGGT tracking + bundle adjustment for each cluster reconstruction. + +This script mirrors the directory structure of an input results tree (e.g. +`.../2-reconstruction/vggt_cluster_run/results`) and re-runs VGGT tracking/BA +per `vggt` folder. Outputs are written in COLMAP format under a user-specified +output root with the same relative layout. +""" + +from __future__ import annotations + +import argparse +import importlib +import os +import sys +from pathlib import Path + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run VGGT tracking + BA for each cluster reconstruction.") + parser.add_argument( + "--input_root", + type=str, + required=True, + help=( + "Path to a results root or run root. If this path contains a " + "`results/` subdir, that subdir is scanned for vggt reconstructions." + ), + ) + parser.add_argument( + "--output_root", + type=str, + default=str(Path.cwd()), + help=( + "Base output directory. Results are written under " + "`/results/...` unless output_root itself is named `results`." + ), + ) + parser.add_argument( + "--images_root", + type=str, + default=None, + help="Root directory for images referenced in images.txt (required if names are relative).", + ) + parser.add_argument( + "--no_skip_existing", + action="store_false", + dest="skip_existing", + default=True, + help="Recompute even if output already exists.", + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility.") + parser.add_argument( + "--model_path", + type=str, + default="pipeline/2-reconstruction/vggt/weights/model.pt", + help="Path to the VGGT model checkpoint.", + ) + parser.add_argument("--use_ba", action="store_true", default=True, help="Use BA for reconstruction.") + parser.add_argument( + "--no_use_ba", + action="store_false", + dest="use_ba", + help="Disable BA (still runs VGGT feed-forward reconstruction).", + ) + parser.add_argument( + "--ba_tracker", + type=str, + choices=["vggt", "vggsfm"], + default="vggt", + help="Tracker used for BA (vggt or vggsfm).", + ) + parser.add_argument("--img_load_resolution", type=int, default=1024, help="Square load resolution for VGGT input.") + parser.add_argument("--vggt_fixed_resolution", type=int, default=518, help="VGGT internal inference resolution.") + parser.add_argument( + "--conf_thres_value", type=float, default=5.0, help="Confidence threshold value for depth filtering." + ) + return parser.parse_args() + + +def _resolve_results_root(path: Path) -> Path: + if (path / "results").is_dir(): + return path / "results" + return path + + +def _resolve_output_results_root(path: Path) -> Path: + if path.name == "results": + return path + return path / "results" + + +def _is_relative_to(path: Path, base: Path) -> bool: + try: + path.relative_to(base) + return True + except ValueError: + return False + + +def _has_colmap_files(path: Path) -> bool: + has_cameras = (path / "cameras.txt").exists() or (path / "cameras.bin").exists() + has_images = (path / "images.txt").exists() or (path / "images.bin").exists() + has_points = (path / "points3D.txt").exists() or (path / "points3D.bin").exists() + return has_cameras and has_images and has_points + + +def _iter_vggt_dirs(results_root: Path, output_results_root: Path) -> list[Path]: + vggt_dirs: list[Path] = [] + for dirpath, _, _ in os.walk(results_root): + path = Path(dirpath) + if _is_relative_to(path, output_results_root): + continue + if path.name != "vggt": + continue + if _has_colmap_files(path): + vggt_dirs.append(path) + return vggt_dirs + + +def _log(log_path: Path, message: str) -> None: + log_path.parent.mkdir(parents=True, exist_ok=True) + with open(log_path, "a", encoding="utf-8") as log_file: + log_file.write(f"{message}\n") + print(message) + + +def _ensure_ba_module() -> None: + try: + importlib.import_module("ba") + return + except ModuleNotFoundError: + candidate_eval_dirs: list[Path] = [] + search_roots = [Path(__file__).resolve()] + list(Path(__file__).resolve().parents) + search_roots += [Path.cwd().resolve()] + list(Path.cwd().resolve().parents) + for root in search_roots: + thirdparty_eval = root / "thirdparty" / "vggt" / "evaluation" + if (thirdparty_eval / "ba.py").exists(): + candidate_eval_dirs.append(thirdparty_eval) + pipeline_eval = root / "pipeline" / "2-reconstruction" / "vggt" / "evaluation" + if (pipeline_eval / "ba.py").exists(): + candidate_eval_dirs.append(pipeline_eval) + + lightglue_root = None + for root in search_roots: + candidate = root / "thirdparty" / "LightGlue" + if candidate.exists(): + lightglue_root = candidate + break + if lightglue_root is not None: + sys.path.insert(0, str(lightglue_root)) + + for eval_dir in candidate_eval_dirs: + sys.path.insert(0, str(eval_dir)) + sys.path.insert(0, str(eval_dir.parent)) + try: + importlib.import_module("ba") + return + except ModuleNotFoundError: + continue + raise + + +def _parse_colmap_images_txt(images_txt_path: Path) -> list[str]: + image_names: list[str] = [] + with open(images_txt_path, "r", encoding="utf-8") as images_file: + lines = images_file.readlines() + idx = 0 + while idx < len(lines): + line = lines[idx].strip() + if not line or line.startswith("#"): + idx += 1 + continue + parts = line.split() + if len(parts) < 10: + raise ValueError(f"Invalid images.txt line in {images_txt_path}: {line}") + image_names.append(" ".join(parts[9:])) + idx += 2 + return image_names + + +def _resolve_image_paths(image_names: list[str], images_root: str | None) -> list[str]: + resolved_paths = [] + for name in image_names: + if os.path.isabs(name): + resolved_paths.append(name) + else: + if images_root is None: + raise ValueError("images_root is required when image filenames are relative.") + resolved_paths.append(os.path.join(images_root, name)) + return resolved_paths + + +def _setup_vggt(args: argparse.Namespace): + _ensure_ba_module() + test_module = importlib.import_module("evaluation.test_co3d_cluster") + setup_model = getattr(test_module, "setup_model") + run_vggt_reconstruction = getattr(test_module, "run_vggt_reconstruction") + model, device, dtype = setup_model(args) + return run_vggt_reconstruction, model, device, dtype + + +def main() -> None: + args = _parse_args() + + input_root = Path(args.input_root) + output_root = Path(args.output_root) + input_results_root = _resolve_results_root(input_root) + output_results_root = _resolve_output_results_root(output_root) + + log_path = output_root / "cluster_ba.log" + vggt_dirs = _iter_vggt_dirs(input_results_root, output_results_root) + if not vggt_dirs: + _log(log_path, f"No vggt reconstructions found under {input_results_root}.") + return + + run_vggt_reconstruction, model, device, dtype = _setup_vggt(args) + + for vggt_dir in vggt_dirs: + rel_path = vggt_dir.relative_to(input_results_root) + output_dir = output_results_root / rel_path + + if args.skip_existing and _has_colmap_files(output_dir): + _log(log_path, f"Skipping {vggt_dir}: output already exists at {output_dir}.") + continue + + images_txt_path = vggt_dir / "images.txt" + if not images_txt_path.exists(): + _log(log_path, f"Skipping {vggt_dir}: missing images.txt.") + continue + + image_names = _parse_colmap_images_txt(images_txt_path) + try: + image_paths = _resolve_image_paths(image_names, args.images_root) + except Exception as exc: + _log(log_path, f"Skipping {vggt_dir}: {exc!r}") + continue + + missing = [path for path in image_paths if not os.path.exists(path)] + if missing: + _log(log_path, f"Skipping {vggt_dir}: missing {len(missing)} images.") + continue + + _log(log_path, f"Running VGGT+BA for {vggt_dir} -> {output_dir}") + try: + run_vggt_reconstruction( + args, + model, + device, + dtype, + image_paths, + str(output_dir), + image_name_list=image_names, + ) + except Exception as exc: + _log(log_path, f"Failed {vggt_dir}: {exc!r}") + + +if __name__ == "__main__": + main() diff --git a/pipeline/4-alignment/alignment.py b/pipeline/4-alignment/alignment.py new file mode 100644 index 000000000..ae638ab86 --- /dev/null +++ b/pipeline/4-alignment/alignment.py @@ -0,0 +1,430 @@ +""" +Align and merge cluster reconstructions using Sim(3) estimated from overlapping cameras. + +Input: + - COLMAP text reconstructions stored under `/results/.../` + (e.g. `.../results/C_1/C_1_2/vggt`). + - Cluster tree pickle from the partition stage. + +Output: + - COLMAP text reconstructions written under `/results/.../` + (default `merged`) for every non-leaf cluster node. + - COLMAP text reconstructions for the original (pre-merge) clusters written under + `/results/.../` (default `_original`). +""" + +from __future__ import annotations + +import argparse +import pickle +import shlex +import subprocess +from pathlib import Path +from typing import Dict, Iterable, Optional, Tuple + +from gtsam import Pose3, Similarity3, SfmTrack + +import gtsfm.utils.logger as logger_utils +import gtsfm.common.types as gtsfm_types +from gtsfm.common.gtsfm_data import GtsfmData +from gtsfm.common.outputs import cluster_label +from gtsfm.utils.tree import PostOrderIter, Tree + +logger = logger_utils.get_logger() + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Align and merge cluster reconstructions using Sim(3).") + parser.add_argument("--cluster_tree_path", type=str, required=True, help="Path to cluster_tree.pkl") + parser.add_argument( + "--input_root", + type=str, + required=True, + help="Root containing `results/` from the reconstruction stage.", + ) + parser.add_argument( + "--output_root", + type=str, + default=None, + help="Root to write merged outputs (defaults to input_root).", + ) + parser.add_argument( + "--input_model_name", + type=str, + default="vggt", + help="Name of per-cluster COLMAP model folder to read.", + ) + parser.add_argument( + "--original_model_name", + type=str, + default=None, + help="Name of per-cluster COLMAP model folder to write original (pre-merge) models. " + "Defaults to '_original'.", + ) + parser.add_argument( + "--output_model_name", + type=str, + default="merged", + help="Name of per-cluster COLMAP model folder to write.", + ) + parser.add_argument( + "--min_common_cameras", + type=int, + default=2, + help="Minimum number of overlapping cameras required to estimate Sim(3).", + ) + parser.add_argument( + "--drop_child_if_fail", + action="store_true", + default=True, + help="Drop child if alignment or merging fails.", + ) + parser.add_argument( + "--no_drop_child_if_fail", + action="store_false", + dest="drop_child_if_fail", + help="Keep child in place (identity Sim(3)) if alignment fails.", + ) + parser.add_argument( + "--skip_existing", + action="store_true", + default=False, + help="Skip writing if merged output already exists (will be reused for parents).", + ) + parser.add_argument( + "--run_colmap_ba", + action="store_true", + default=False, + help="Run COLMAP bundle_adjuster on each merged model before propagating to parents.", + ) + parser.add_argument( + "--colmap_path", + type=str, + default="colmap", + help="Path to the COLMAP executable used for bundle_adjuster.", + ) + parser.add_argument( + "--colmap_ba_extra_args", + type=str, + default="", + help="Extra arguments passed to COLMAP bundle_adjuster (single string, shell-like).", + ) + parser.add_argument( + "--convert_ba_to_txt", + action="store_true", + default=False, + help="After COLMAP BA, run model_converter to produce a TXT model in a separate folder.", + ) + return parser.parse_args() + + +def _load_cluster_tree(cluster_tree_path: str) -> Tree: + with open(cluster_tree_path, "rb") as f: + cluster_tree = pickle.load(f) + if cluster_tree is None: + raise ValueError(f"cluster_tree.pkl was empty or invalid: {cluster_tree_path}") + return cluster_tree + + +def _cluster_dir(root: Path, path: Tuple[int, ...]) -> Path: + cluster_dir = root / "results" + if path: + for depth in range(len(path)): + cluster_dir = cluster_dir / cluster_label(path[: depth + 1]) + return cluster_dir + + +def _model_dir(root: Path, path: Tuple[int, ...], model_name: str) -> Path: + return _cluster_dir(root, path) / model_name + + +def _read_scene(model_dir: Path, name_to_idx: Optional[Dict[str, int]] = None) -> Optional[GtsfmData]: + if not model_dir.exists(): + return None + if not (model_dir / "images.txt").exists() and not (model_dir / "images.bin").exists(): + return None + try: + scene = GtsfmData.read_colmap(str(model_dir)) + if name_to_idx is not None: + scene = _remap_scene_to_global_indices(scene, name_to_idx) + return scene + except Exception as exc: + logger.warning("Failed to read COLMAP model at %s: %s", model_dir, exc) + return None + + +def _pose_map_by_name(scene: GtsfmData) -> Dict[str, Pose3]: + pose_map: Dict[str, Pose3] = {} + for idx, pose in scene.poses().items(): + name = scene.get_image_info(idx).name + if name is None: + continue + pose_map[name] = pose + return pose_map + + +def _parse_colmap_image_names(images_txt: Path) -> set[str]: + names: set[str] = set() + if not images_txt.exists(): + return names + + # Preserve empty lines. COLMAP text models use two lines per image, and images with no + # observations legitimately have an empty second line. + lines = images_txt.read_text().splitlines() + i = 0 + while i < len(lines): + line = lines[i].strip() + if not line or line.startswith("#"): + i += 1 + continue + parts = line.split() + if len(parts) >= 10: + try: + int(parts[0]) # IMAGE_ID + int(parts[8]) # CAMERA_ID + names.add(parts[9]) # NAME + i += 2 + continue + except ValueError: + pass + i += 1 + return names + + +def _build_global_name_map(input_root: Path, model_name: str) -> Dict[str, int]: + base = input_root / "results" + names: set[str] = set() + for images_txt in base.rglob(f"{model_name}/images.txt"): + names.update(_parse_colmap_image_names(images_txt)) + if not names: + raise ValueError(f"No image names found under {base} for model {model_name}.") + sorted_names = sorted(names) + return {name: idx for idx, name in enumerate(sorted_names)} + + +def _remap_scene_to_global_indices(scene: GtsfmData, name_to_idx: Dict[str, int]) -> GtsfmData: + remapped = GtsfmData(number_images=len(name_to_idx)) + + # Remap cameras and image info by global image name. + for idx, pose in scene.poses().items(): + info = scene.get_image_info(idx) + name = info.name + if name is None or name not in name_to_idx: + continue + new_idx = name_to_idx[name] + camera = scene.get_camera(idx) + if camera is None: + continue + calibration = camera.calibration() + camera_type = gtsfm_types.get_camera_class_for_calibration(calibration) + remapped.add_camera(new_idx, camera_type(pose, calibration)) # type: ignore + remapped.set_image_info(new_idx, name=name, shape=info.shape) + + # Remap tracks by global image name. + for track in scene.tracks(): + new_track = SfmTrack(track.point3()) + for k in range(track.numberMeasurements()): + i, uv = track.measurement(k) + name = scene.get_image_info(i).name + if name is None or name not in name_to_idx: + continue + new_track.addMeasurement(name_to_idx[name], uv) + if new_track.numberMeasurements() > 0: + new_track.r = getattr(track, "r", 0) + new_track.g = getattr(track, "g", 0) + new_track.b = getattr(track, "b", 0) + remapped.add_track(new_track) + + if scene.has_gaussian_splats(): + remapped.set_gaussian_splats(scene.get_gaussian_splats()) + + return remapped + + +def _sim3_from_common_names( + a_scene: GtsfmData, b_scene: GtsfmData, min_common: int +) -> Similarity3: + a_map = _pose_map_by_name(a_scene) + b_map = _pose_map_by_name(b_scene) + common_names = [name for name in a_map if name in b_map] + if len(common_names) < min_common: + raise ValueError(f"Found only {len(common_names)} overlapping cameras (need {min_common}).") + pose_pairs = [(a_map[name], b_map[name]) for name in common_names] + return Similarity3.Align(pose_pairs) + + +def _align_and_merge( + parent: GtsfmData, + child: GtsfmData, + min_common: int, + drop_child_if_fail: bool, +) -> GtsfmData: + try: + aSb = _sim3_from_common_names(parent, child, min_common=min_common) + except Exception as exc: + if drop_child_if_fail: + logger.warning("Dropping child due to alignment failure: %s", exc) + return parent + logger.warning("Alignment failed; using identity Sim(3): %s", exc) + aSb = Similarity3() + + try: + return parent.merged_with(child, aSb) + except Exception as exc: + logger.warning("Failed to merge child: %s", exc) + return parent + + +def _iter_path_tree(cluster_tree: Tree) -> Iterable[Tree[Tuple[int, ...]]]: + path_tree: Tree[Tuple[int, ...]] = cluster_tree.map_with_path(lambda path, _: path) + return PostOrderIter(path_tree) + + +def _run_colmap_bundle_adjuster( + input_model_dir: Path, output_model_dir: Path, colmap_path: str, extra_args: str +) -> bool: + if not input_model_dir.exists(): + logger.warning("COLMAP BA skipped: model dir missing at %s", input_model_dir) + return False + cmd = [ + colmap_path, + "bundle_adjuster", + "--input_path", + str(input_model_dir), + "--output_path", + str(output_model_dir), + ] + if extra_args: + cmd.extend(shlex.split(extra_args)) + try: + logger.info("Running COLMAP BA: %s", " ".join(cmd)) + subprocess.run(cmd, check=True) + return True + except Exception as exc: + logger.warning("COLMAP BA failed for %s: %s", input_model_dir, exc) + return False + + +def _run_colmap_model_converter(input_model_dir: Path, output_model_dir: Path, colmap_path: str) -> bool: + if not input_model_dir.exists(): + logger.warning("COLMAP model conversion skipped: model dir missing at %s", input_model_dir) + return False + cmd = [ + colmap_path, + "model_converter", + "--input_path", + str(input_model_dir), + "--output_path", + str(output_model_dir), + "--output_type", + "TXT", + ] + try: + logger.info("Running COLMAP model_converter: %s", " ".join(cmd)) + subprocess.run(cmd, check=True) + return True + except Exception as exc: + logger.warning("COLMAP model conversion failed for %s: %s", input_model_dir, exc) + return False + + +def main() -> None: + args = _parse_args() + + input_root = Path(args.input_root) + output_root = Path(args.output_root) if args.output_root is not None else input_root + original_model_name = args.original_model_name or f"{args.input_model_name}_original" + cluster_tree = _load_cluster_tree(args.cluster_tree_path) + name_to_idx = _build_global_name_map(input_root, args.input_model_name) + + merged_cache: Dict[Tuple[int, ...], Optional[GtsfmData]] = {} + + for node in _iter_path_tree(cluster_tree): + path = node.value + input_dir = _model_dir(input_root, path, args.input_model_name) + current = _read_scene(input_dir, name_to_idx) + if current is None: + logger.warning("Skipping %s: missing model at %s", path, input_dir) + merged_cache[path] = None + continue + + original_dir = _model_dir(output_root, path, original_model_name) + if not args.skip_existing or not (original_dir / "cameras.txt").exists(): + original_dir.mkdir(parents=True, exist_ok=True) + current.export_as_colmap_text(original_dir) + logger.info("Wrote original model for %s to %s", path, original_dir) + + pre_ba_dir = _model_dir(output_root, path, f"{args.output_model_name}_pre_ba") + ba_dir = _model_dir(output_root, path, f"{args.output_model_name}_colmap_ba") + ba_txt_dir = _model_dir(output_root, path, f"{args.output_model_name}_colmap_ba_txt") + + if node.is_leaf(): + merged_cache[path] = current + continue + + if args.skip_existing: + cached = None + if args.run_colmap_ba and (ba_dir / "cameras.txt").exists(): + cached = _read_scene(ba_dir, name_to_idx) + elif (pre_ba_dir / "cameras.txt").exists(): + cached = _read_scene(pre_ba_dir, name_to_idx) + if cached is not None: + merged_cache[path] = cached + logger.info("Reusing existing merged output at %s", ba_dir if args.run_colmap_ba else pre_ba_dir) + continue + + merged = current + for child in node.children: + child_path = child.value + child_scene = merged_cache.get(child_path) + if child_scene is None: + if args.run_colmap_ba: + child_scene = _read_scene( + _model_dir(output_root, child_path, f"{args.output_model_name}_colmap_ba"), + name_to_idx, + ) + if child_scene is None: + child_scene = _read_scene( + _model_dir(output_root, child_path, f"{args.output_model_name}_pre_ba"), + name_to_idx, + ) + if child_scene is None and args.convert_ba_to_txt: + child_scene = _read_scene( + _model_dir(output_root, child_path, f"{args.output_model_name}_colmap_ba_txt"), + name_to_idx, + ) + if child_scene is None: + child_scene = _read_scene(_model_dir(input_root, child_path, args.input_model_name), name_to_idx) + if child_scene is None: + logger.warning("Missing child model for %s -> %s", path, child_path) + continue + merged = _align_and_merge( + merged, + child_scene, + min_common=args.min_common_cameras, + drop_child_if_fail=args.drop_child_if_fail, + ) + + pre_ba_dir.mkdir(parents=True, exist_ok=True) + merged.export_as_colmap_text(pre_ba_dir) + if args.run_colmap_ba: + ba_dir.mkdir(parents=True, exist_ok=True) + if _run_colmap_bundle_adjuster(pre_ba_dir, ba_dir, args.colmap_path, args.colmap_ba_extra_args): + merged_ba = _read_scene(ba_dir, name_to_idx) + if merged_ba is not None: + merged = merged_ba + else: + logger.warning("COLMAP BA completed but failed to reload model at %s", ba_dir) + if args.convert_ba_to_txt: + ba_txt_dir.mkdir(parents=True, exist_ok=True) + _run_colmap_model_converter(ba_dir, ba_txt_dir, args.colmap_path) + merged_cache[path] = merged + logger.info( + "Wrote merged model for %s to %s", + path, + ba_dir if args.run_colmap_ba else pre_ba_dir, + ) + + +if __name__ == "__main__": + main() diff --git a/pipeline/5-global_ba/global_ba.py b/pipeline/5-global_ba/global_ba.py new file mode 100644 index 000000000..b19680941 --- /dev/null +++ b/pipeline/5-global_ba/global_ba.py @@ -0,0 +1,5 @@ +''' +In this file, implement global_ba + + +''' \ No newline at end of file diff --git a/pipeline/run_pipeline.sh b/pipeline/run_pipeline.sh new file mode 100644 index 000000000..42565a8ce --- /dev/null +++ b/pipeline/run_pipeline.sh @@ -0,0 +1,246 @@ +#!/usr/bin/env bash + +set -euo pipefail +export HF_HOME=/nethome/xzhang979/nvme/cache + +if [[ "${1:-}" == "-h" || "${1:-}" == "--help" ]]; then + echo "Usage: $0 [tracker] [--reconstruction_method {vggt_cluster|pi3}]" + echo "Example: $0 gerrard-hall" + echo "Example: $0 gerrard-hall vggsfm" + echo "Example: $0 gerrard-hall --reconstruction_method pi3" + exit 0 +fi + +if [[ $# -lt 1 ]]; then + echo "Usage: $0 [tracker] [--reconstruction_method {vggt_cluster|pi3}]" + echo "Example: $0 gerrard-hall" + echo "Example: $0 gerrard-hall vggsfm" + echo "Example: $0 gerrard-hall --reconstruction_method pi3" + exit 1 +fi + +DATASET_NAME="$1" +shift + +TRACKER="vggt" +RECONSTRUCTION_METHOD="vggt_cluster" + +if [[ $# -gt 0 ]]; then + case "$1" in + vggt|vggsfm|colmap) + TRACKER="$1" + shift + ;; + esac +fi + +while [[ $# -gt 0 ]]; do + case "$1" in + --reconstruction_method) + if [[ $# -lt 2 ]]; then + echo "Error: --reconstruction_method requires a value" + exit 1 + fi + RECONSTRUCTION_METHOD="$2" + if [[ "${RECONSTRUCTION_METHOD}" != "vggt_cluster" && "${RECONSTRUCTION_METHOD}" != "pi3" ]]; then + echo "Error: --reconstruction_method must be one of: vggt_cluster, pi3" + exit 1 + fi + shift 2 + ;; + -h|--help) + echo "Usage: $0 [tracker] [--reconstruction_method {vggt_cluster|pi3}]" + echo "Example: $0 gerrard-hall" + echo "Example: $0 gerrard-hall vggsfm" + echo "Example: $0 gerrard-hall --reconstruction_method pi3" + exit 0 + ;; + *) + echo "Error: unknown argument '$1'" + exit 1 + ;; + esac +done + +if [[ "${TRACKER}" != "vggt" && "${TRACKER}" != "vggsfm" && "${TRACKER}" != "colmap" ]]; then + echo "Error: tracker must be one of: vggt, vggsfm, colmap" + exit 1 +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +DATASET_DIR="${REPO_ROOT}/benchmarks/${DATASET_NAME}" +if [[ "${RECONSTRUCTION_METHOD}" == "pi3" ]]; then + RESULTS_ROOT="${REPO_ROOT}/pipeline/results/${DATASET_NAME}_pi3" +else + RESULTS_ROOT="${REPO_ROOT}/pipeline/results/${DATASET_NAME}" +fi +CLUSTER_TREE_PATH="${RESULTS_ROOT}/1-partition/results/cluster_tree.pkl" + +RECON_RUN_NAME="vggt_cluster_run" +RECON_MODEL_NAME="${TRACKER}" +if [[ "${RECONSTRUCTION_METHOD}" == "pi3" ]]; then + RECON_RUN_NAME="pi3_run" + RECON_MODEL_NAME="pi3" +fi +RECON_OUTPUT_ROOT="${RESULTS_ROOT}/2-reconstruction/${RECON_RUN_NAME}" + +if [[ ! -d "${DATASET_DIR}" ]]; then + echo "Error: dataset directory not found: ${DATASET_DIR}" + exit 1 +fi + +BASELINE_DIR="$(find "${DATASET_DIR}" -mindepth 1 -maxdepth 4 -type d \( -name "sparse" -o -name "colmap" -o -name "sfm" \) -print -quit)" +if [[ -z "${BASELINE_DIR}" ]]; then + echo "Error: baseline directory not found. Expected one of:" + echo " any 'sparse', 'colmap', or 'sfm' directory under ${DATASET_DIR}" + exit 1 +fi + +# Ensure conda activation works in non-interactive shells. +if [[ "${CONDA_DEFAULT_ENV:-}" != "gtsfm-v2" ]]; then + if command -v conda >/dev/null 2>&1; then + eval "$(conda shell.bash hook)" || true + fi + + if [[ -z "${CONDA_EXE:-}" ]]; then + for candidate in \ + "${HOME}/miniconda3/etc/profile.d/conda.sh" \ + "${HOME}/anaconda3/etc/profile.d/conda.sh" \ + "/opt/conda/etc/profile.d/conda.sh"; do + if [[ -f "${candidate}" ]]; then + # shellcheck disable=SC1090 + source "${candidate}" + break + fi + done + fi + + if ! command -v conda >/dev/null 2>&1; then + echo "Error: conda command is not available in this shell." + echo "Please run with gtsfm-v2 already active, or ensure conda is installed and initialized." + exit 1 + fi + + conda activate gtsfm-v2 +fi + +# partition +python "${REPO_ROOT}/pipeline/1-partition/partition_metis_megaloc.py" \ + --dataset_dir "${DATASET_DIR}" \ + --output_root "${RESULTS_ROOT}/1-partition" + +# reconstruction +if [[ "${RECONSTRUCTION_METHOD}" == "pi3" ]]; then + python "${REPO_ROOT}/pipeline/2-reconstruction/Pi3/run_on_cluster.py" \ + --cluster_tree_path "${CLUSTER_TREE_PATH}" \ + --dataset_dir "${DATASET_DIR}" \ + --output_root "${RECON_OUTPUT_ROOT}" \ + --model_name "${RECON_MODEL_NAME}" +else + python "${REPO_ROOT}/pipeline/2-reconstruction/vggt/run_on_cluster.py" \ + --cluster_tree_path "${CLUSTER_TREE_PATH}" \ + --dataset_dir "${DATASET_DIR}" \ + --output_root "${RECON_OUTPUT_ROOT}" \ + --ba_tracker "${TRACKER}" +fi + +python "${REPO_ROOT}/pipeline/utils/check_tracks.py" \ + --recon_root "${RECON_OUTPUT_ROOT}/results" \ + --images_root "${DATASET_DIR}" \ + --model_name "${RECON_MODEL_NAME}" + +if [[ "${RECONSTRUCTION_METHOD}" != "pi3" ]]; then + python "${REPO_ROOT}/pipeline/2-reconstruction/vggt/run_on_cluster.py" \ + --cluster_tree_path "${CLUSTER_TREE_PATH}" \ + --dataset_dir "${DATASET_DIR}" \ + --output_root "${RECON_OUTPUT_ROOT}" \ + --ba_tracker "${TRACKER}" \ + --use_ba \ + --ba_output_root "${RESULTS_ROOT}/3-cluster_ba/vggt_cluster_run" + + python "${REPO_ROOT}/pipeline/utils/check_tracks.py" \ + --recon_root "${RESULTS_ROOT}/3-cluster_ba/vggt_cluster_run/results" \ + --images_root "${DATASET_DIR}" \ + --model_name "${TRACKER}" +fi + +python "${REPO_ROOT}/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py" \ + --baseline "${BASELINE_DIR}" \ + --root "${RECON_OUTPUT_ROOT}" \ + --recon_name "${RECON_MODEL_NAME}" \ + --csv_output "${RECON_OUTPUT_ROOT}/${RECON_MODEL_NAME}_eval/cluster_pose_metrics.csv" + +if [[ "${RECONSTRUCTION_METHOD}" != "pi3" ]]; then + python "${REPO_ROOT}/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py" \ + --baseline "${BASELINE_DIR}" \ + --root "${RESULTS_ROOT}/3-cluster_ba/vggt_cluster_run" \ + --recon_name "${TRACKER}" \ + --csv_output "${RESULTS_ROOT}/3-cluster_ba/vggt_cluster_run/${TRACKER}_ba_eval/cluster_pose_metrics.csv" +fi + +# alignment +eval_reconstruction() { + local current_model_dir="$1" + local output_dir="$2" + local recon_name + recon_name="$(basename "${current_model_dir}")" + python "${REPO_ROOT}/gtsfm/evaluation/compare_colmap_outputs.py" \ + --baseline "${BASELINE_DIR}" \ + --current "${current_model_dir}" \ + --output "${output_dir}" + python "${REPO_ROOT}/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py" \ + --baseline "${BASELINE_DIR}" \ + --root "${output_dir}" \ + --recon_name "${recon_name}" \ + --csv_output "${output_dir}/vggt_eval/cluster_pose_metrics.csv" +} + +## case 1 +python "${REPO_ROOT}/pipeline/4-alignment/alignment.py" \ + --cluster_tree_path "${CLUSTER_TREE_PATH}" \ + --input_root "${RECON_OUTPUT_ROOT}" \ + --input_model_name "${RECON_MODEL_NAME}" \ + --output_root "${RESULTS_ROOT}/4-alignment" +eval_reconstruction \ + "${RESULTS_ROOT}/4-alignment/results/merged_pre_ba" \ + "${RESULTS_ROOT}/4-alignment/results" + +if [[ "${RECONSTRUCTION_METHOD}" != "pi3" ]]; then + ## case 2 + python "${REPO_ROOT}/pipeline/4-alignment/alignment.py" \ + --cluster_tree_path "${CLUSTER_TREE_PATH}" \ + --input_root "${RESULTS_ROOT}/3-cluster_ba/vggt_cluster_run" \ + --input_model_name "${TRACKER}" \ + --output_root "${RESULTS_ROOT}/4-alignment-clusterba" + eval_reconstruction \ + "${RESULTS_ROOT}/4-alignment-clusterba/results/merged_pre_ba" \ + "${RESULTS_ROOT}/4-alignment-clusterba/results" +fi + +## case 3 +python "${REPO_ROOT}/pipeline/4-alignment/alignment.py" \ + --cluster_tree_path "${CLUSTER_TREE_PATH}" \ + --input_root "${RECON_OUTPUT_ROOT}" \ + --input_model_name "${RECON_MODEL_NAME}" \ + --output_root "${RESULTS_ROOT}/5-global_ba" \ + --run_colmap_ba \ + --convert_ba_to_txt +eval_reconstruction \ + "${RESULTS_ROOT}/5-global_ba/results/merged_colmap_ba_txt" \ + "${RESULTS_ROOT}/5-global_ba/results" + +if [[ "${RECONSTRUCTION_METHOD}" != "pi3" ]]; then + ## case 4 + python "${REPO_ROOT}/pipeline/4-alignment/alignment.py" \ + --cluster_tree_path "${CLUSTER_TREE_PATH}" \ + --input_root "${RESULTS_ROOT}/3-cluster_ba/vggt_cluster_run" \ + --output_root "${RESULTS_ROOT}/5-global_ba-cluster_ba" \ + --input_model_name "${TRACKER}" \ + --run_colmap_ba \ + --convert_ba_to_txt + eval_reconstruction \ + "${RESULTS_ROOT}/5-global_ba-cluster_ba/results/merged_colmap_ba_txt" \ + "${RESULTS_ROOT}/5-global_ba-cluster_ba/results" +fi diff --git a/pipeline/run_pipeline_eth3dmvs.sh b/pipeline/run_pipeline_eth3dmvs.sh new file mode 100644 index 000000000..8f9604533 --- /dev/null +++ b/pipeline/run_pipeline_eth3dmvs.sh @@ -0,0 +1,454 @@ +#!/usr/bin/env bash + +set -euo pipefail +export HF_HOME=/nethome/xzhang979/nvme/cache +TRACKER="vggt" +RECONSTRUCTION_METHOD="vggt_cluster" +POINT_SOURCE="depth" +TRIANGULATION_MIN_VIEWS=2 +BA_USE_GT_CALIBRATION=0 +BA_GT_CALIBRATION_DIR="" +SINGLE_CLUSTER=0 + +usage() { + echo "Usage: $0 [tracker] [--reconstruction_method {vggt_cluster|pi3}] [--point_source {depth|triangulation}] [--triangulation_min_views N] [--ba_use_gt_calibration] [--ba_gt_calibration_dir DIR] [--single_cluster]" + echo "Example: $0" + echo "Example: $0 vggsfm" + echo "Example: $0 --reconstruction_method pi3" + echo "Example: $0 colmap --point_source triangulation --triangulation_min_views 3" + echo "Example: $0 colmap --ba_use_gt_calibration" + echo "Example: $0 vggt --single_cluster" +} + +while [[ $# -gt 0 ]]; do + case "$1" in + vggt|vggsfm|colmap) + TRACKER="$1" + shift + ;; + --point_source) + if [[ $# -lt 2 ]]; then + echo "Error: --point_source requires a value" + usage + exit 1 + fi + POINT_SOURCE="$2" + if [[ "${POINT_SOURCE}" != "depth" && "${POINT_SOURCE}" != "triangulation" ]]; then + echo "Error: --point_source must be one of: depth, triangulation" + usage + exit 1 + fi + shift 2 + ;; + --reconstruction_method) + if [[ $# -lt 2 ]]; then + echo "Error: --reconstruction_method requires a value" + usage + exit 1 + fi + RECONSTRUCTION_METHOD="$2" + if [[ "${RECONSTRUCTION_METHOD}" != "vggt_cluster" && "${RECONSTRUCTION_METHOD}" != "pi3" ]]; then + echo "Error: --reconstruction_method must be one of: vggt_cluster, pi3" + usage + exit 1 + fi + shift 2 + ;; + --triangulation_min_views) + if [[ $# -lt 2 ]]; then + echo "Error: --triangulation_min_views requires an integer value" + usage + exit 1 + fi + TRIANGULATION_MIN_VIEWS="$2" + if ! [[ "${TRIANGULATION_MIN_VIEWS}" =~ ^[0-9]+$ ]] || [[ "${TRIANGULATION_MIN_VIEWS}" -lt 2 ]]; then + echo "Error: --triangulation_min_views must be an integer >= 2" + usage + exit 1 + fi + shift 2 + ;; + --ba_use_gt_calibration) + BA_USE_GT_CALIBRATION=1 + shift + ;; + --ba_gt_calibration_dir) + if [[ $# -lt 2 ]]; then + echo "Error: --ba_gt_calibration_dir requires a directory path" + usage + exit 1 + fi + BA_GT_CALIBRATION_DIR="$2" + shift 2 + ;; + --single_cluster) + SINGLE_CLUSTER=1 + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Error: unknown argument '$1'" + usage + exit 1 + ;; + esac +done + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +DATASET_ROOT="${REPO_ROOT}/benchmarks/eth3dmvs" +PARTITION_MODE_TAG="metis" +if [[ "${SINGLE_CLUSTER}" -eq 1 ]]; then + PARTITION_MODE_TAG="single_cluster" +fi +if [[ "${RECONSTRUCTION_METHOD}" == "pi3" ]]; then + RUN_TAG="pi3_${PARTITION_MODE_TAG}" +else + RUN_TAG="${TRACKER}_${POINT_SOURCE}_${PARTITION_MODE_TAG}" +fi +RESULTS_ROOT_BASE="${REPO_ROOT}/pipeline/results/eth3dmvs_${RUN_TAG}" + +if [[ ! -d "${DATASET_ROOT}" ]]; then + echo "Error: dataset directory not found: ${DATASET_ROOT}" + exit 1 +fi + +# Ensure conda activation works in non-interactive shells. +if [[ "${CONDA_DEFAULT_ENV:-}" != "gtsfm-v2" ]]; then + if command -v conda >/dev/null 2>&1; then + eval "$(conda shell.bash hook)" || true + fi + + if [[ -z "${CONDA_EXE:-}" ]]; then + for candidate in \ + "${HOME}/miniconda3/etc/profile.d/conda.sh" \ + "${HOME}/anaconda3/etc/profile.d/conda.sh" \ + "/opt/conda/etc/profile.d/conda.sh"; do + if [[ -f "${candidate}" ]]; then + # shellcheck disable=SC1090 + source "${candidate}" + break + fi + done + fi + + if ! command -v conda >/dev/null 2>&1; then + echo "Error: conda command is not available in this shell." + echo "Please run with gtsfm-v2 already active, or ensure conda is installed and initialized." + exit 1 + fi + + conda activate gtsfm-v2 +fi + +prepare_eval_baseline() { + local source_baseline_dir="$1" + local prepared_baseline_dir="$2" + + mkdir -p "${prepared_baseline_dir}" + cp "${source_baseline_dir}/cameras.txt" "${prepared_baseline_dir}/cameras.txt" + cp "${source_baseline_dir}/points3D.txt" "${prepared_baseline_dir}/points3D.txt" + + # Keep COLMAP image records, but normalize NAME to basename to match reconstruction outputs. + awk ' + BEGIN {non_comment_idx = 0} + { + if ($0 ~ /^#/ || $0 ~ /^[[:space:]]*$/) { + print $0 + next + } + + if (non_comment_idx % 2 == 0) { + name = $10 + sub(/^.*\//, "", name) + printf "%s %s %s %s %s %s %s %s %s %s\n", $1, $2, $3, $4, $5, $6, $7, $8, $9, name + } else { + print $0 + } + + non_comment_idx++ + } + ' "${source_baseline_dir}/images.txt" > "${prepared_baseline_dir}/images.txt" +} + +run_scene() { + local split="$1" + local scene_name="$2" + local dataset_dir="${DATASET_ROOT}/${split}/${scene_name}" + local results_root="${RESULTS_ROOT_BASE}/${split}/${scene_name}" + local cluster_tree_path="${results_root}/1-partition/results/cluster_tree.pkl" + local scene_images_dir + local baseline_dir + local eval_baseline_dir + local recon_run_name="vggt_cluster_run" + local recon_model_name="${TRACKER}" + local aligned_original_model_dir="${TRACKER}_original" + if [[ "${RECONSTRUCTION_METHOD}" == "pi3" ]]; then + recon_run_name="pi3_run" + recon_model_name="pi3" + aligned_original_model_dir="pi3_original" + fi + local recon_output_root="${results_root}/2-reconstruction/${recon_run_name}" + local ba_run_name="vggt_cluster_run" + if [[ "${BA_USE_GT_CALIBRATION}" -eq 1 ]]; then + ba_run_name="${ba_run_name}__gtcalib-on" + else + ba_run_name="${ba_run_name}__gtcalib-off" + fi + + echo "========================================" + echo "Running ETH3D scene: ${split}/${scene_name}" + echo "Dataset: ${dataset_dir}" + echo "Results: ${results_root}" + + if [[ -d "${dataset_dir}/dslr_calibration_undistorted" ]]; then + baseline_dir="${dataset_dir}/dslr_calibration_undistorted" + else + baseline_dir="$(find "${dataset_dir}" -mindepth 1 -maxdepth 4 -type d \( -name "sparse" -o -name "colmap" \) -print -quit)" + fi + + if [[ -z "${baseline_dir}" ]]; then + echo "Error: baseline directory not found for scene ${split}/${scene_name}." + echo "Expected: ${dataset_dir}/dslr_calibration_undistorted" + exit 1 + fi + + for required_file in cameras.txt images.txt points3D.txt; do + if [[ ! -f "${baseline_dir}/${required_file}" ]]; then + echo "Error: baseline file missing: ${baseline_dir}/${required_file}" + exit 1 + fi + done + + if [[ -d "${dataset_dir}/images/dslr_images_undistorted" ]]; then + scene_images_dir="${dataset_dir}/images/dslr_images_undistorted" + elif [[ -d "${dataset_dir}/images" ]]; then + scene_images_dir="${dataset_dir}/images" + else + echo "Error: image directory not found for scene ${split}/${scene_name}." + echo "Expected one of:" + echo " ${dataset_dir}/images/dslr_images_undistorted" + echo " ${dataset_dir}/images" + exit 1 + fi + + eval_baseline_dir="${results_root}/_eval_baseline" + prepare_eval_baseline "${baseline_dir}" "${eval_baseline_dir}" + + echo "Images: ${scene_images_dir}" + echo "Baseline (source): ${baseline_dir}" + echo "Baseline (eval): ${eval_baseline_dir}" + + # partition + partition_cmd=( + python "${REPO_ROOT}/pipeline/1-partition/partition_metis_megaloc.py" + --dataset_dir "${dataset_dir}" + --images_dir "${scene_images_dir}" + --output_root "${results_root}/1-partition" + ) + if [[ "${SINGLE_CLUSTER}" -eq 1 ]]; then + partition_cmd+=(--single_cluster) + fi + "${partition_cmd[@]}" + + # reconstruction + if [[ "${RECONSTRUCTION_METHOD}" == "pi3" ]]; then + python "${REPO_ROOT}/pipeline/2-reconstruction/Pi3/run_on_cluster.py" \ + --cluster_tree_path "${cluster_tree_path}" \ + --dataset_dir "${dataset_dir}" \ + --images_root "${scene_images_dir}" \ + --output_root "${recon_output_root}" \ + --model_name "${recon_model_name}" + else + python "${REPO_ROOT}/pipeline/2-reconstruction/vggt/run_on_cluster.py" \ + --cluster_tree_path "${cluster_tree_path}" \ + --dataset_dir "${dataset_dir}" \ + --images_root "${scene_images_dir}" \ + --output_root "${recon_output_root}" \ + --ba_tracker "${TRACKER}" \ + --point_source "${POINT_SOURCE}" \ + --triangulation_min_views "${TRIANGULATION_MIN_VIEWS}" + fi + + python "${REPO_ROOT}/pipeline/utils/check_tracks.py" \ + --recon_root "${recon_output_root}/results" \ + --images_root "${dataset_dir}" \ + --model_name "${recon_model_name}" + + if [[ "${RECONSTRUCTION_METHOD}" != "pi3" ]]; then + ba_cmd=( + python "${REPO_ROOT}/pipeline/2-reconstruction/vggt/run_on_cluster.py" + --cluster_tree_path "${cluster_tree_path}" + --dataset_dir "${dataset_dir}" + --images_root "${scene_images_dir}" + --output_root "${recon_output_root}" + --ba_tracker "${TRACKER}" + --use_ba + --ba_output_root "${results_root}/3-cluster_ba/${ba_run_name}" + ) + if [[ "${BA_USE_GT_CALIBRATION}" -eq 1 ]]; then + ba_cmd+=(--ba_use_gt_calibration) + if [[ -n "${BA_GT_CALIBRATION_DIR}" ]]; then + ba_cmd+=(--ba_gt_calibration_dir "${BA_GT_CALIBRATION_DIR}") + fi + fi + "${ba_cmd[@]}" + fi + + if [[ "${RECONSTRUCTION_METHOD}" != "pi3" ]]; then + python "${REPO_ROOT}/pipeline/utils/check_tracks.py" \ + --recon_root "${results_root}/3-cluster_ba/${ba_run_name}/results" \ + --images_root "${dataset_dir}" \ + --model_name "${TRACKER}" + fi + + python "${REPO_ROOT}/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py" \ + --baseline "${eval_baseline_dir}" \ + --root "${recon_output_root}" \ + --recon_name "${recon_model_name}" \ + --csv_output "${recon_output_root}/${recon_model_name}_eval/cluster_pose_metrics.csv" + + if [[ "${RECONSTRUCTION_METHOD}" != "pi3" ]]; then + python "${REPO_ROOT}/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py" \ + --baseline "${eval_baseline_dir}" \ + --root "${results_root}/3-cluster_ba/${ba_run_name}" \ + --recon_name "${TRACKER}" \ + --csv_output "${results_root}/3-cluster_ba/${ba_run_name}/${TRACKER}_ba_eval/cluster_pose_metrics.csv" + fi + + eval_reconstruction() { + local current_model_dir="$1" + local output_dir="$2" + local fallback_model_dir="${3:-}" + local recon_name + + if [[ ! -f "${current_model_dir}/images.txt" && ! -f "${current_model_dir}/images.bin" ]]; then + if [[ -n "${fallback_model_dir}" && ( -f "${fallback_model_dir}/images.txt" || -f "${fallback_model_dir}/images.bin" ) ]]; then + echo "Model not found at ${current_model_dir}; falling back to ${fallback_model_dir}" + current_model_dir="${fallback_model_dir}" + else + echo "Error: evaluation model missing at ${current_model_dir}" + if [[ -n "${fallback_model_dir}" ]]; then + echo "Fallback model also missing at ${fallback_model_dir}" + fi + exit 1 + fi + fi + + recon_name="$(basename "${current_model_dir}")" + + python "${REPO_ROOT}/gtsfm/evaluation/compare_colmap_outputs.py" \ + --baseline "${eval_baseline_dir}" \ + --current "${current_model_dir}" \ + --output "${output_dir}" + + python "${REPO_ROOT}/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py" \ + --baseline "${eval_baseline_dir}" \ + --root "${output_dir}" \ + --recon_name "${recon_name}" \ + --csv_output "${output_dir}/vggt_eval/cluster_pose_metrics.csv" + } + + # case 1 + python "${REPO_ROOT}/pipeline/4-alignment/alignment.py" \ + --cluster_tree_path "${cluster_tree_path}" \ + --input_root "${recon_output_root}" \ + --input_model_name "${recon_model_name}" \ + --output_root "${results_root}/4-alignment" + + eval_reconstruction \ + "${results_root}/4-alignment/results/merged_pre_ba" \ + "${results_root}/4-alignment/results" \ + "${results_root}/4-alignment/results/${aligned_original_model_dir}" + + if [[ "${RECONSTRUCTION_METHOD}" != "pi3" ]]; then + # case 2 + python "${REPO_ROOT}/pipeline/4-alignment/alignment.py" \ + --cluster_tree_path "${cluster_tree_path}" \ + --input_root "${results_root}/3-cluster_ba/${ba_run_name}" \ + --input_model_name "${TRACKER}" \ + --output_root "${results_root}/4-alignment-clusterba" + + eval_reconstruction \ + "${results_root}/4-alignment-clusterba/results/merged_pre_ba" \ + "${results_root}/4-alignment-clusterba/results" \ + "${results_root}/4-alignment-clusterba/results/${aligned_original_model_dir}" + fi + + # case 3 + python "${REPO_ROOT}/pipeline/4-alignment/alignment.py" \ + --cluster_tree_path "${cluster_tree_path}" \ + --input_root "${recon_output_root}" \ + --input_model_name "${recon_model_name}" \ + --output_root "${results_root}/5-global_ba" \ + --run_colmap_ba \ + --convert_ba_to_txt + + eval_reconstruction \ + "${results_root}/5-global_ba/results/merged_colmap_ba_txt" \ + "${results_root}/5-global_ba/results" \ + "${results_root}/5-global_ba/results/${aligned_original_model_dir}" + + if [[ "${RECONSTRUCTION_METHOD}" != "pi3" ]]; then + # case 4 + python "${REPO_ROOT}/pipeline/4-alignment/alignment.py" \ + --cluster_tree_path "${cluster_tree_path}" \ + --input_root "${results_root}/3-cluster_ba/${ba_run_name}" \ + --output_root "${results_root}/5-global_ba-cluster_ba" \ + --input_model_name "${TRACKER}" \ + --run_colmap_ba \ + --convert_ba_to_txt + + eval_reconstruction \ + "${results_root}/5-global_ba-cluster_ba/results/merged_colmap_ba_txt" \ + "${results_root}/5-global_ba-cluster_ba/results" \ + "${results_root}/5-global_ba-cluster_ba/results/${aligned_original_model_dir}" + fi +} + +mapfile -t SCENES < <(find "${DATASET_ROOT}" -mindepth 2 -maxdepth 2 -type d | sort) + +if [[ ${#SCENES[@]} -eq 0 ]]; then + echo "Error: no scene directories found under ${DATASET_ROOT}" + exit 1 +fi + +echo "Found ${#SCENES[@]} ETH3D scenes." +echo "Run tag: ${RUN_TAG}" +echo "Results root base: ${RESULTS_ROOT_BASE}" +echo "Reconstruction method: ${RECONSTRUCTION_METHOD}" +echo "Tracker: ${TRACKER}" +echo "Point source: ${POINT_SOURCE} (triangulation_min_views=${TRIANGULATION_MIN_VIEWS})" +echo "BA use GT calibration: ${BA_USE_GT_CALIBRATION} (dir=${BA_GT_CALIBRATION_DIR:-auto})" +FAILED_SCENES=() +SUCCEEDED_COUNT=0 +for scene_path in "${SCENES[@]}"; do + split="$(basename "$(dirname "${scene_path}")")" + scene_name="$(basename "${scene_path}")" + + # Execute each scene in an isolated shell so a failure does not stop the full dataset run. + set +e + ( + set -euo pipefail + run_scene "${split}" "${scene_name}" + ) + scene_rc=$? + set -e + + if [[ ${scene_rc} -ne 0 ]]; then + FAILED_SCENES+=("${split}/${scene_name}") + echo "Scene failed (${scene_rc}): ${split}/${scene_name}" + else + SUCCEEDED_COUNT=$((SUCCEEDED_COUNT + 1)) + echo "Scene completed: ${split}/${scene_name}" + fi +done + +echo "Completed ETH3D scenes: ${SUCCEEDED_COUNT}/${#SCENES[@]} succeeded." +if [[ ${#FAILED_SCENES[@]} -gt 0 ]]; then + echo "Failed scenes (${#FAILED_SCENES[@]}):" + printf ' - %s\n' "${FAILED_SCENES[@]}" +fi diff --git a/pipeline/utils/check_tracks.py b/pipeline/utils/check_tracks.py new file mode 100644 index 000000000..ea5476423 --- /dev/null +++ b/pipeline/utils/check_tracks.py @@ -0,0 +1,949 @@ +"""Check track quality for COLMAP reconstructions under cluster folders. + +This script: +1) Finds cluster model folders named `--model_name` under `--recon_root`. +2) Computes reprojection errors for all track measurements in each cluster. +3) Saves summary statistics (JSON + CSV). +4) Saves overlay visualizations for up to `--max_images` images per cluster. + +Example: + python pipeline/utils/check_tracks.py \ + --recon_root /nethome/xzhang979/nvme/gtsfm/pipeline/results/gerrard-hall/2-reconstruction/vggt_cluster_run/results \ + --images_root /nethome/xzhang979/nvme/gtsfm/benchmarks/gerrard-hall \ + --model_name vggt +""" + +from __future__ import annotations + +import argparse +import colorsys +import csv +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import numpy as np +from PIL import Image as PILImage +from PIL import ImageDraw + +try: + import cv2 +except ImportError: + cv2 = None + +try: + import matplotlib.pyplot as plt +except ImportError: + plt = None + +from gtsfm.common.gtsfm_data import GtsfmData + + +@dataclass +class MeasurementPair: + """One 2D measurement and its reprojection for visualization/error calculation.""" + + track_idx: int + measured: np.ndarray + reprojected: np.ndarray + error: float + + +IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"} + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Compute and visualize track reprojection errors in cluster reconstructions.") + parser.add_argument("--recon_root", type=str, required=True, help="Root directory containing per-cluster outputs.") + parser.add_argument("--images_root", type=str, required=True, help="Root directory for source images.") + parser.add_argument("--model_name", type=str, default="vggt", help="COLMAP model directory name per cluster.") + parser.add_argument( + "--output_root", + type=str, + default=None, + help="Output root for reports/visualizations. Defaults to /check_tracks.", + ) + parser.add_argument("--max_images", type=int, default=5, help="Maximum overlay images to save per cluster.") + parser.add_argument( + "--max_tracks_per_image", + type=int, + default=500, + help="Maximum track measurements to draw per image (for readability/speed).", + ) + parser.add_argument("--line_width", type=int, default=1, help="Line width in visualization.") + parser.add_argument("--dot_radius", type=int, default=2, help="Dot radius in visualization.") + parser.add_argument("--random_seed", type=int, default=0, help="Random seed for visualization sampling.") + parser.add_argument( + "--max_correspondences_per_pair", + type=int, + default=20, + help="Maximum cross-image correspondence lines to draw for each adjacent image pair.", + ) + parser.add_argument( + "--corr_low_residual_only", + action="store_true", + help="For stitched correspondences, only keep points with low reprojection residual.", + ) + parser.add_argument( + "--corr_residual_thresh_px", + type=float, + default=None, + help="Residual threshold (px) for --corr_low_residual_only. Defaults to cluster p95 if omitted.", + ) + parser.add_argument("--hist_bins", type=int, default=80, help="Number of bins for reprojection histograms.") + parser.add_argument( + "--hist_clip_px", + type=float, + default=100.0, + help="Clip histogram x-axis to this pixel value for readability.", + ) + return parser.parse_args() + + +def _is_colmap_model_dir(model_dir: Path) -> bool: + has_images = (model_dir / "images.txt").exists() or (model_dir / "images.bin").exists() + has_points = (model_dir / "points3D.txt").exists() or (model_dir / "points3D.bin").exists() + has_cameras = (model_dir / "cameras.txt").exists() or (model_dir / "cameras.bin").exists() + return has_images and has_points and has_cameras + + +def _find_model_dirs(recon_root: Path, model_name: str) -> list[Path]: + matches: list[Path] = [] + for path in recon_root.rglob(model_name): + if path.is_dir() and _is_colmap_model_dir(path): + matches.append(path) + return sorted(matches) + + +def _parse_colmap_recon_sizes(model_dir: Path) -> dict[int, tuple[int, int]]: + """Parse image_id -> (recon_h, recon_w) from COLMAP text model files, if available.""" + cameras_txt = model_dir / "cameras.txt" + images_txt = model_dir / "images.txt" + if not cameras_txt.exists() or not images_txt.exists(): + return {} + + camera_hw: dict[int, tuple[int, int]] = {} + for line in cameras_txt.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + parts = line.split() + if len(parts) < 4: + continue + try: + camera_id = int(parts[0]) + width = int(parts[2]) + height = int(parts[3]) + except ValueError: + continue + camera_hw[camera_id] = (height, width) + + image_hw: dict[int, tuple[int, int]] = {} + lines = images_txt.read_text(encoding="utf-8").splitlines() + i = 0 + while i < len(lines): + line = lines[i].strip() + if not line or line.startswith("#"): + i += 1 + continue + parts = line.split() + if len(parts) >= 9: + try: + image_id = int(parts[0]) + camera_id = int(parts[8]) + except ValueError: + i += 2 + continue + if camera_id in camera_hw: + image_hw[image_id] = camera_hw[camera_id] + i += 2 + + return image_hw + + +def _build_image_index(images_root: Path) -> dict[str, list[Path]]: + """Index all images by lowercase basename to support flexible lookup.""" + index: dict[str, list[Path]] = {} + for p in images_root.rglob("*"): + if not p.is_file(): + continue + if p.suffix.lower() not in IMAGE_EXTENSIONS: + continue + key = p.name.lower() + index.setdefault(key, []).append(p) + return index + + +def _get_image_path(image_name: str, images_root: Path, image_index: dict[str, list[Path]]) -> Optional[Path]: + # 1) Try relative path from images_root as-is. + direct = images_root / image_name + if direct.exists(): + return direct + + # 2) Try common nested path under images/. + nested = images_root / "images" / image_name + if nested.exists(): + return nested + + # 3) Fallback to basename index (handles unknown subfolder layout). + candidates = image_index.get(Path(image_name).name.lower(), []) + if candidates: + return candidates[0] + + return None + + +def _collect_pairs_for_camera(gtsfm_data: GtsfmData, camera_idx: int) -> list[MeasurementPair]: + camera = gtsfm_data.get_camera(camera_idx) + if camera is None: + return [] + + pairs: list[MeasurementPair] = [] + for track_idx, measurement_idx in gtsfm_data.get_measurements_for_camera(camera_idx): + track = gtsfm_data.get_track(track_idx) + _, uv_measured = track.measurement(measurement_idx) + uv_reproj, success = camera.projectSafe(track.point3()) + if not success: + continue + + measured = np.asarray(uv_measured, dtype=float) + reprojected = np.asarray(uv_reproj, dtype=float) + error = float(np.linalg.norm(measured - reprojected)) + pairs.append(MeasurementPair(track_idx=track_idx, measured=measured, reprojected=reprojected, error=error)) + return pairs + + +def _track_color(track_idx: int) -> tuple[int, int, int]: + """Deterministic RGB color for each track id.""" + hue = (track_idx * 0.61803398875) % 1.0 + r, g, b = colorsys.hsv_to_rgb(hue, 0.75, 0.95) + return int(r * 255), int(g * 255), int(b * 255) + + +def _select_adjacent_cameras( + candidate_camera_indices: list[int], + per_camera_mean_err: list[tuple[int, float]], + max_images: int, +) -> list[int]: + """Pick a consecutive camera window centered around the worst-error camera.""" + if not candidate_camera_indices: + return [] + + ordered = sorted(candidate_camera_indices) + if len(ordered) <= max_images: + return ordered + + anchor = per_camera_mean_err[0][0] if per_camera_mean_err else ordered[len(ordered) // 2] + anchor_pos = ordered.index(anchor) if anchor in ordered else len(ordered) // 2 + half = max_images // 2 + start = max(0, anchor_pos - half) + end = start + max_images + if end > len(ordered): + end = len(ordered) + start = end - max_images + return ordered[start:end] + + +def _draw_pairs_on_image( + image_rgb: np.ndarray, + pairs: list[MeasurementPair], + scale_u: float, + scale_v: float, + line_width: int, + dot_radius: int, +) -> np.ndarray: + image_rgb = image_rgb.astype(np.uint8) + + if cv2 is not None: + image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) + + for pair in pairs: + x_meas = int(round(pair.measured[0] * scale_u)) + y_meas = int(round(pair.measured[1] * scale_v)) + x_rep = int(round(pair.reprojected[0] * scale_u)) + y_rep = int(round(pair.reprojected[1] * scale_v)) + r, g, b = _track_color(pair.track_idx) + color_bgr = (b, g, r) + + # Use unique track color. Reprojected point is slightly larger. + cv2.line( + image_bgr, + (x_rep, y_rep), + (x_meas, y_meas), + color_bgr, + thickness=line_width, + lineType=cv2.LINE_AA, + ) + cv2.circle( + image_bgr, + (x_meas, y_meas), + dot_radius, + (160, 160, 160), + thickness=-1, + lineType=cv2.LINE_AA, + ) + cv2.circle( + image_bgr, + (x_rep, y_rep), + dot_radius + 1, + color_bgr, + thickness=-1, + lineType=cv2.LINE_AA, + ) + return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + + pil_image = PILImage.fromarray(image_rgb) + draw = ImageDraw.Draw(pil_image) + for pair in pairs: + x_meas = float(pair.measured[0] * scale_u) + y_meas = float(pair.measured[1] * scale_v) + x_rep = float(pair.reprojected[0] * scale_u) + y_rep = float(pair.reprojected[1] * scale_v) + color = _track_color(pair.track_idx) + draw.line([(x_rep, y_rep), (x_meas, y_meas)], fill=color, width=line_width) + draw.ellipse( + [(x_meas - dot_radius, y_meas - dot_radius), (x_meas + dot_radius, y_meas + dot_radius)], + fill=(160, 160, 160), + ) + draw.ellipse( + [(x_rep - (dot_radius + 1), y_rep - (dot_radius + 1)), (x_rep + (dot_radius + 1), y_rep + (dot_radius + 1))], + fill=color, + ) + return np.asarray(pil_image) + + +def _compose_correspondence_canvas( + items: list[tuple[int, np.ndarray, dict[int, tuple[float, float]]]], + max_correspondences_per_pair: int, + rng: np.random.Generator, + point_radius: int, + line_width: int, +) -> tuple[np.ndarray, int]: + """Create a single stitched image with adjacent-image correspondence lines.""" + if not items: + return np.zeros((1, 1, 3), dtype=np.uint8), 0 + + target_h = int(min(img.shape[0] for _, img, _ in items)) + resized_images: list[np.ndarray] = [] + track_to_xy_per_image: list[dict[int, tuple[float, float]]] = [] + widths: list[int] = [] + + for _, image_rgb, track_xy in items: + h, w = image_rgb.shape[:2] + if h != target_h: + new_w = max(1, int(round(w * (target_h / h)))) + if cv2 is not None: + resized = cv2.resize(image_rgb, (new_w, target_h), interpolation=cv2.INTER_AREA) + else: + resized = np.asarray(PILImage.fromarray(image_rgb).resize((new_w, target_h), PILImage.Resampling.BILINEAR)) + else: + resized = image_rgb + new_w = w + + sx = new_w / w if w > 0 else 1.0 + sy = target_h / h if h > 0 else 1.0 + + track_map: dict[int, tuple[float, float]] = {} + for tid, (x0, y0) in track_xy.items(): + x = float(x0 * sx) + y = float(y0 * sy) + track_map[tid] = (x, y) + + resized_images.append(resized.astype(np.uint8)) + track_to_xy_per_image.append(track_map) + widths.append(new_w) + + gap = 24 + canvas_w = int(sum(widths) + gap * (len(widths) - 1)) + canvas_h = target_h + canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8) + + x_offsets: list[int] = [] + x_cursor = 0 + for img in resized_images: + h, w = img.shape[:2] + canvas[:h, x_cursor : x_cursor + w] = img + x_offsets.append(x_cursor) + x_cursor += w + gap + + correspondences_drawn = 0 + + if cv2 is not None: + canvas_bgr = cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR) + for i in range(len(track_to_xy_per_image) - 1): + a_map = track_to_xy_per_image[i] + b_map = track_to_xy_per_image[i + 1] + common = list(set(a_map.keys()).intersection(b_map.keys())) + if not common: + continue + if len(common) > max_correspondences_per_pair: + sampled = rng.choice(len(common), size=max_correspondences_per_pair, replace=False) + common = [common[int(k)] for k in sampled] + + for tid in common: + ax, ay = a_map[tid] + bx, by = b_map[tid] + p1 = (int(round(ax + x_offsets[i])), int(round(ay))) + p2 = (int(round(bx + x_offsets[i + 1])), int(round(by))) + r, g, b = _track_color(tid) + color_bgr = (b, g, r) + cv2.line(canvas_bgr, p1, p2, color_bgr, thickness=line_width, lineType=cv2.LINE_AA) + cv2.circle(canvas_bgr, p1, point_radius + 1, color_bgr, thickness=-1, lineType=cv2.LINE_AA) + cv2.circle(canvas_bgr, p2, point_radius + 1, color_bgr, thickness=-1, lineType=cv2.LINE_AA) + correspondences_drawn += 1 + + return cv2.cvtColor(canvas_bgr, cv2.COLOR_BGR2RGB), correspondences_drawn + + pil_canvas = PILImage.fromarray(canvas) + draw = ImageDraw.Draw(pil_canvas) + for i in range(len(track_to_xy_per_image) - 1): + a_map = track_to_xy_per_image[i] + b_map = track_to_xy_per_image[i + 1] + common = list(set(a_map.keys()).intersection(b_map.keys())) + if not common: + continue + if len(common) > max_correspondences_per_pair: + sampled = rng.choice(len(common), size=max_correspondences_per_pair, replace=False) + common = [common[int(k)] for k in sampled] + + for tid in common: + ax, ay = a_map[tid] + bx, by = b_map[tid] + x1, y1 = float(ax + x_offsets[i]), float(ay) + x2, y2 = float(bx + x_offsets[i + 1]), float(by) + color = _track_color(tid) + draw.line([(x1, y1), (x2, y2)], fill=color, width=line_width) + draw.ellipse( + [(x1 - (point_radius + 1), y1 - (point_radius + 1)), (x1 + (point_radius + 1), y1 + (point_radius + 1))], + fill=color, + ) + draw.ellipse( + [(x2 - (point_radius + 1), y2 - (point_radius + 1)), (x2 + (point_radius + 1), y2 + (point_radius + 1))], + fill=color, + ) + correspondences_drawn += 1 + + return np.asarray(pil_canvas), correspondences_drawn + + +def _stats(values: list[float]) -> dict[str, float]: + if not values: + return { + "count": 0, + "mean": float("nan"), + "median": float("nan"), + "p90": float("nan"), + "p95": float("nan"), + "max": float("nan"), + } + arr = np.asarray(values, dtype=float) + return { + "count": float(arr.size), + "mean": float(np.mean(arr)), + "median": float(np.median(arr)), + "p90": float(np.percentile(arr, 90)), + "p95": float(np.percentile(arr, 95)), + "max": float(np.max(arr)), + } + + +def _in_image_bounds(uv: np.ndarray, width: int, height: int) -> bool: + x, y = float(uv[0]), float(uv[1]) + return 0.0 <= x < float(width) and 0.0 <= y < float(height) + + +def _compute_oob_sanity( + per_camera: dict[int, list[MeasurementPair]], + gtsfm_data: GtsfmData, + recon_hw_by_image: dict[int, tuple[int, int]], +) -> dict[str, float]: + measured_oob = 0 + reprojected_oob = 0 + total = 0 + per_image_measured_oob_max = 0.0 + per_image_reprojected_oob_max = 0.0 + + for camera_idx, pairs in per_camera.items(): + if not pairs: + continue + info = gtsfm_data.get_image_info(camera_idx) + if camera_idx in recon_hw_by_image: + h, w = recon_hw_by_image[camera_idx] + elif info.shape is not None: + h, w = info.shape + else: + continue + if w <= 0 or h <= 0: + continue + + img_measured_oob = 0 + img_reprojected_oob = 0 + for pair in pairs: + total += 1 + if not _in_image_bounds(pair.measured, w, h): + measured_oob += 1 + img_measured_oob += 1 + if not _in_image_bounds(pair.reprojected, w, h): + reprojected_oob += 1 + img_reprojected_oob += 1 + + img_total = len(pairs) + if img_total > 0: + per_image_measured_oob_max = max(per_image_measured_oob_max, img_measured_oob / img_total) + per_image_reprojected_oob_max = max(per_image_reprojected_oob_max, img_reprojected_oob / img_total) + + measured_oob_rate = (measured_oob / total) if total > 0 else float("nan") + reprojected_oob_rate = (reprojected_oob / total) if total > 0 else float("nan") + + return { + "measured_oob_count": float(measured_oob), + "reprojected_oob_count": float(reprojected_oob), + "measured_oob_rate": float(measured_oob_rate), + "reprojected_oob_rate": float(reprojected_oob_rate), + "per_image_measured_oob_rate_max": float(per_image_measured_oob_max), + "per_image_reprojected_oob_rate_max": float(per_image_reprojected_oob_max), + } + + +def _compute_adjacent_pair_displacement_sanity( + selected_cameras: list[int], + per_camera: dict[int, list[MeasurementPair]], +) -> dict[str, float]: + if len(selected_cameras) < 2: + return { + "adjacent_pair_count": 0.0, + "adjacent_pair_min_common_tracks": float("nan"), + "adjacent_pair_median_common_tracks": float("nan"), + "adjacent_pair_disp_median_px": float("nan"), + "adjacent_pair_disp_p90_px": float("nan"), + "adjacent_pair_disp_max_px": float("nan"), + } + + pair_common_counts: list[int] = [] + all_displacements: list[float] = [] + for cam_a, cam_b in zip(selected_cameras[:-1], selected_cameras[1:]): + a_map = {p.track_idx: p.measured for p in per_camera.get(cam_a, [])} + b_map = {p.track_idx: p.measured for p in per_camera.get(cam_b, [])} + common_ids = set(a_map.keys()).intersection(b_map.keys()) + pair_common_counts.append(len(common_ids)) + if not common_ids: + continue + for tid in common_ids: + all_displacements.append(float(np.linalg.norm(a_map[tid] - b_map[tid]))) + + if pair_common_counts: + pair_common_arr = np.asarray(pair_common_counts, dtype=float) + min_common = float(np.min(pair_common_arr)) + median_common = float(np.median(pair_common_arr)) + else: + min_common = float("nan") + median_common = float("nan") + + if all_displacements: + disp_arr = np.asarray(all_displacements, dtype=float) + disp_median = float(np.median(disp_arr)) + disp_p90 = float(np.percentile(disp_arr, 90)) + disp_max = float(np.max(disp_arr)) + else: + disp_median = float("nan") + disp_p90 = float("nan") + disp_max = float("nan") + + return { + "adjacent_pair_count": float(max(0, len(selected_cameras) - 1)), + "adjacent_pair_min_common_tracks": min_common, + "adjacent_pair_median_common_tracks": median_common, + "adjacent_pair_disp_median_px": disp_median, + "adjacent_pair_disp_p90_px": disp_p90, + "adjacent_pair_disp_max_px": disp_max, + } + + +def _save_histogram( + errors: list[float], + output_path: Path, + title: str, + bins: int, + clip_px: float, +) -> bool: + """Save histogram plot if matplotlib is available.""" + if plt is None or not errors: + return False + + arr = np.asarray(errors, dtype=float) + shown = np.clip(arr, 0.0, clip_px) + mean_v = float(np.mean(arr)) + median_v = float(np.median(arr)) + p90_v = float(np.percentile(arr, 90)) + p95_v = float(np.percentile(arr, 95)) + std_v = float(np.std(arr)) + min_v = float(np.min(arr)) + max_v = float(np.max(arr)) + n_v = int(arr.size) + + fig = plt.figure(figsize=(8, 4.5), dpi=140) + ax = fig.add_subplot(111) + ax.hist(shown, bins=bins, color="#2f80ed", edgecolor="#1b4f9c", alpha=0.9) + ax.axvline(mean_v, color="#0b1f3b", linestyle="-", linewidth=1.6, label=f"mean={mean_v:.3f}") + ax.axvline(median_v, color="#f2994a", linestyle="--", linewidth=1.5, label=f"median={median_v:.3f}") + ax.axvline(p90_v, color="#27ae60", linestyle="--", linewidth=1.3, label=f"p90={p90_v:.3f}") + ax.axvline(p95_v, color="#eb5757", linestyle="--", linewidth=1.3, label=f"p95={p95_v:.3f}") + ax.set_title(title) + ax.set_xlabel(f"Reprojection error (px), clipped to {clip_px:g}") + ax.set_ylabel("Measurements") + ax.grid(True, alpha=0.25) + stats_text = ( + f"n={n_v}\n" + f"min={min_v:.3f}px\n" + f"max={max_v:.3f}px\n" + f"std={std_v:.3f}px\n" + f"mean={mean_v:.3f}px\n" + f"median={median_v:.3f}px\n" + f"p90={p90_v:.3f}px\n" + f"p95={p95_v:.3f}px" + ) + ax.text( + 0.985, + 0.98, + stats_text, + transform=ax.transAxes, + va="top", + ha="right", + fontsize=8.5, + bbox={"boxstyle": "round", "facecolor": "white", "alpha": 0.9, "edgecolor": "#7f8c8d"}, + ) + ax.legend(loc="upper left", fontsize=8) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path) + plt.close(fig) + return True + + +def _process_cluster( + model_dir: Path, + recon_root: Path, + images_root: Path, + output_root: Path, + max_images: int, + max_tracks_per_image: int, + line_width: int, + dot_radius: int, + max_correspondences_per_pair: int, + corr_low_residual_only: bool, + corr_residual_thresh_px: float | None, + hist_bins: int, + hist_clip_px: float, + rng: np.random.Generator, + image_index: dict[str, list[Path]], +) -> tuple[dict[str, object], list[float]]: + cluster_rel = model_dir.relative_to(recon_root) + cluster_output_dir = output_root / cluster_rel + cluster_output_dir.mkdir(parents=True, exist_ok=True) + + gtsfm_data = GtsfmData.read_colmap(str(model_dir)) + recon_hw_by_image = _parse_colmap_recon_sizes(model_dir) + camera_indices = sorted(gtsfm_data.get_valid_camera_indices()) + + all_errors: list[float] = [] + per_camera: dict[int, list[MeasurementPair]] = {} + per_camera_mean_err: list[tuple[int, float]] = [] + + for camera_idx in camera_indices: + pairs = _collect_pairs_for_camera(gtsfm_data, camera_idx) + if not pairs: + continue + per_camera[camera_idx] = pairs + camera_err = float(np.mean([pair.error for pair in pairs])) + per_camera_mean_err.append((camera_idx, camera_err)) + all_errors.extend(pair.error for pair in pairs) + + stats = _stats(all_errors) + corr_residual_thresh = float(stats["p95"]) if corr_residual_thresh_px is None else float(corr_residual_thresh_px) + summary: dict[str, object] = { + "cluster": str(cluster_rel), + "num_cameras": len(camera_indices), + "num_tracks": int(gtsfm_data.number_tracks()), + "num_measurements": int(stats["count"]), + "mean_reproj_error_px": stats["mean"], + "median_reproj_error_px": stats["median"], + "p90_reproj_error_px": stats["p90"], + "p95_reproj_error_px": stats["p95"], + "max_reproj_error_px": stats["max"], + "viz_images": [], + "histogram_plot": "", + "correspondence_plot": "", + "correspondence_lines": 0, + "measured_correspondence_plot": "", + "measured_correspondence_lines": 0, + "corr_residual_thresh_px": corr_residual_thresh if corr_low_residual_only else float("nan"), + "missing_images": 0, + "saved_overlay_count": 0, + "measured_oob_count": 0, + "reprojected_oob_count": 0, + "measured_oob_rate": float("nan"), + "reprojected_oob_rate": float("nan"), + "per_image_measured_oob_rate_max": float("nan"), + "per_image_reprojected_oob_rate_max": float("nan"), + "adjacent_pair_count": 0, + "adjacent_pair_min_common_tracks": float("nan"), + "adjacent_pair_median_common_tracks": float("nan"), + "adjacent_pair_disp_median_px": float("nan"), + "adjacent_pair_disp_p90_px": float("nan"), + "adjacent_pair_disp_max_px": float("nan"), + } + + # Choose adjacent images centered around the worst reprojection-error camera. + per_camera_mean_err.sort(key=lambda x: x[1], reverse=True) + selected_cameras = _select_adjacent_cameras(list(per_camera.keys()), per_camera_mean_err, max_images=max_images) + summary.update(_compute_oob_sanity(per_camera=per_camera, gtsfm_data=gtsfm_data, recon_hw_by_image=recon_hw_by_image)) + summary.update(_compute_adjacent_pair_displacement_sanity(selected_cameras=selected_cameras, per_camera=per_camera)) + composed_items: list[tuple[int, np.ndarray, dict[int, tuple[float, float]]]] = [] + composed_items_measured: list[tuple[int, np.ndarray, dict[int, tuple[float, float]]]] = [] + + for camera_idx in selected_cameras: + info = gtsfm_data.get_image_info(camera_idx) + if info.name is None: + continue + + image_path = _get_image_path(info.name, images_root, image_index) + if image_path is None: + summary["missing_images"] = int(summary["missing_images"]) + 1 + continue + + if cv2 is not None: + image_bgr = cv2.imread(str(image_path), cv2.IMREAD_COLOR) + if image_bgr is None: + continue + image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + else: + image_rgb = np.asarray(PILImage.open(image_path).convert("RGB")) + + loaded_h, loaded_w = image_rgb.shape[:2] + recon_h: int + recon_w: int + if camera_idx in recon_hw_by_image: + recon_h, recon_w = recon_hw_by_image[camera_idx] + elif info.shape is not None: + recon_h, recon_w = info.shape + else: + recon_h, recon_w = loaded_h, loaded_w + + # Draw only on reconstruction-resolution images. + if recon_w > 0 and recon_h > 0 and (loaded_w != recon_w or loaded_h != recon_h): + if cv2 is not None: + image_rgb = cv2.resize(image_rgb, (recon_w, recon_h), interpolation=cv2.INTER_AREA) + else: + image_rgb = np.asarray( + PILImage.fromarray(image_rgb).resize((recon_w, recon_h), PILImage.Resampling.BILINEAR) + ) + scale_u = 1.0 + scale_v = 1.0 + + pairs = per_camera[camera_idx] + if len(pairs) > max_tracks_per_image: + sampled_indices = rng.choice(len(pairs), size=max_tracks_per_image, replace=False) + sampled_pairs = [pairs[int(i)] for i in sampled_indices] + else: + sampled_pairs = pairs + + overlay_rgb = _draw_pairs_on_image( + image_rgb=image_rgb, + pairs=sampled_pairs, + scale_u=scale_u, + scale_v=scale_v, + line_width=line_width, + dot_radius=dot_radius, + ) + + stem = Path(info.name).stem + out_name = f"{camera_idx:06d}_{stem}_tracks.jpg" + out_path = cluster_output_dir / out_name + if cv2 is not None: + cv2.imwrite(str(out_path), cv2.cvtColor(overlay_rgb, cv2.COLOR_RGB2BGR)) + else: + PILImage.fromarray(overlay_rgb).save(out_path, quality=95) + summary["viz_images"].append(str(out_path.relative_to(output_root))) + summary["saved_overlay_count"] = int(summary["saved_overlay_count"]) + 1 + track_xy: dict[int, tuple[float, float]] = {} + track_xy_measured: dict[int, tuple[float, float]] = {} + for p in pairs: + if corr_low_residual_only and p.error > corr_residual_thresh: + continue + track_xy[p.track_idx] = (float(p.reprojected[0]), float(p.reprojected[1])) + track_xy_measured[p.track_idx] = (float(p.measured[0]), float(p.measured[1])) + composed_items.append((camera_idx, image_rgb, track_xy)) + composed_items_measured.append((camera_idx, image_rgb, track_xy_measured)) + + if len(composed_items) >= 2: + composed_items.sort(key=lambda x: x[0]) + corr_img, corr_count = _compose_correspondence_canvas( + items=composed_items, + max_correspondences_per_pair=max_correspondences_per_pair, + rng=rng, + point_radius=dot_radius, + line_width=max(1, line_width), + ) + corr_path = cluster_output_dir / "adjacent_reprojected_correspondences.jpg" + if cv2 is not None: + cv2.imwrite(str(corr_path), cv2.cvtColor(corr_img, cv2.COLOR_RGB2BGR)) + else: + PILImage.fromarray(corr_img).save(corr_path, quality=95) + summary["correspondence_plot"] = str(corr_path.relative_to(output_root)) + summary["correspondence_lines"] = int(corr_count) + + if len(composed_items_measured) >= 2: + composed_items_measured.sort(key=lambda x: x[0]) + corr_img_m, corr_count_m = _compose_correspondence_canvas( + items=composed_items_measured, + max_correspondences_per_pair=max_correspondences_per_pair, + rng=rng, + point_radius=dot_radius, + line_width=max(1, line_width), + ) + corr_path_m = cluster_output_dir / "adjacent_measured_correspondences.jpg" + if cv2 is not None: + cv2.imwrite(str(corr_path_m), cv2.cvtColor(corr_img_m, cv2.COLOR_RGB2BGR)) + else: + PILImage.fromarray(corr_img_m).save(corr_path_m, quality=95) + summary["measured_correspondence_plot"] = str(corr_path_m.relative_to(output_root)) + summary["measured_correspondence_lines"] = int(corr_count_m) + + hist_path = cluster_output_dir / "reprojection_hist.png" + if _save_histogram( + errors=all_errors, + output_path=hist_path, + title=f"Cluster {cluster_rel} reprojection error", + bins=hist_bins, + clip_px=hist_clip_px, + ): + summary["histogram_plot"] = str(hist_path.relative_to(output_root)) + + return summary, all_errors + + +def _write_reports(output_root: Path, summaries: list[dict[str, object]], global_errors: list[float], bins: int, clip_px: float) -> None: + output_root.mkdir(parents=True, exist_ok=True) + + json_path = output_root / "cluster_reprojection_summary.json" + with open(json_path, "w", encoding="utf-8") as f: + json.dump(summaries, f, indent=2) + + csv_path = output_root / "cluster_reprojection_summary.csv" + fieldnames = [ + "cluster", + "num_cameras", + "num_tracks", + "num_measurements", + "mean_reproj_error_px", + "median_reproj_error_px", + "p90_reproj_error_px", + "p95_reproj_error_px", + "max_reproj_error_px", + "viz_images", + "histogram_plot", + "correspondence_plot", + "correspondence_lines", + "measured_correspondence_plot", + "measured_correspondence_lines", + "corr_residual_thresh_px", + "missing_images", + "saved_overlay_count", + "measured_oob_count", + "reprojected_oob_count", + "measured_oob_rate", + "reprojected_oob_rate", + "per_image_measured_oob_rate_max", + "per_image_reprojected_oob_rate_max", + "adjacent_pair_count", + "adjacent_pair_min_common_tracks", + "adjacent_pair_median_common_tracks", + "adjacent_pair_disp_median_px", + "adjacent_pair_disp_p90_px", + "adjacent_pair_disp_max_px", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for summary in summaries: + row = dict(summary) + row["viz_images"] = ";".join(summary.get("viz_images", [])) + writer.writerow(row) + + _save_histogram( + errors=global_errors, + output_path=output_root / "global_reprojection_hist.png", + title="Global reprojection error across clusters", + bins=bins, + clip_px=clip_px, + ) + + +def main() -> None: + args = _parse_args() + + recon_root = Path(args.recon_root) + images_root = Path(args.images_root) + output_root = Path(args.output_root) if args.output_root else recon_root / "check_tracks" + rng = np.random.default_rng(args.random_seed) + image_index = _build_image_index(images_root) + print(f"Indexed {sum(len(v) for v in image_index.values())} images under {images_root}") + + model_dirs = _find_model_dirs(recon_root, args.model_name) + if not model_dirs: + raise ValueError(f"No COLMAP model directories named '{args.model_name}' found under: {recon_root}") + + summaries: list[dict[str, object]] = [] + global_errors: list[float] = [] + for model_dir in model_dirs: + try: + summary, cluster_errors = _process_cluster( + model_dir=model_dir, + recon_root=recon_root, + images_root=images_root, + output_root=output_root, + max_images=args.max_images, + max_tracks_per_image=args.max_tracks_per_image, + line_width=args.line_width, + dot_radius=args.dot_radius, + max_correspondences_per_pair=args.max_correspondences_per_pair, + corr_low_residual_only=args.corr_low_residual_only, + corr_residual_thresh_px=args.corr_residual_thresh_px, + hist_bins=args.hist_bins, + hist_clip_px=args.hist_clip_px, + rng=rng, + image_index=image_index, + ) + summaries.append(summary) + global_errors.extend(cluster_errors) + print( + f"[OK] {summary['cluster']}: " + f"tracks={summary['num_tracks']}, " + f"measurements={summary['num_measurements']}, " + f"mean={summary['mean_reproj_error_px']:.3f}px, " + f"p95={summary['p95_reproj_error_px']:.3f}px, " + f"overlays={summary['saved_overlay_count']}/{args.max_images}, " + f"corr_lines={summary['correspondence_lines']}, " + f"missing_images={summary['missing_images']}, " + f"oob(m/r)={100.0 * float(summary['measured_oob_rate']):.2f}%/" + f"{100.0 * float(summary['reprojected_oob_rate']):.2f}%, " + f"adj_min_common={int(float(summary['adjacent_pair_min_common_tracks'])) if not np.isnan(float(summary['adjacent_pair_min_common_tracks'])) else -1}" + ) + except Exception as exc: + print(f"[FAIL] {model_dir}: {exc}") + + summaries.sort(key=lambda s: float(s["mean_reproj_error_px"]), reverse=True) + _write_reports(output_root, summaries, global_errors=global_errors, bins=args.hist_bins, clip_px=args.hist_clip_px) + + print(f"\nSaved reports to: {output_root}") + print(f" - {output_root / 'cluster_reprojection_summary.json'}") + print(f" - {output_root / 'cluster_reprojection_summary.csv'}") + if plt is not None: + print(f" - {output_root / 'global_reprojection_hist.png'}") + else: + print(" - matplotlib not available; histogram plots were skipped") + + +if __name__ == "__main__": + main() diff --git a/thirdparty/Pi3 b/thirdparty/Pi3 new file mode 160000 index 000000000..c8ad630ae --- /dev/null +++ b/thirdparty/Pi3 @@ -0,0 +1 @@ +Subproject commit c8ad630ae16c91c25f8e9b9a07adc78adc3049ca