Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@
[submodule "thirdparty/FastVGGT"]
path = thirdparty/FastVGGT
url = https://github.com/mystorm16/FastVGGT.git
[submodule "thirdparty/Pi3"]
path = thirdparty/Pi3
url = https://github.com/yyfz/Pi3.git
28 changes: 21 additions & 7 deletions gtsfm/evaluation/compare_colmap_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,26 @@
import gtsfm.utils.logger as logger_utils
import gtsfm.utils.metrics as metric_utils
from gtsfm.cluster_optimizer import save_metrics_reports
from gtsfm.evaluation.metrics import GtsfmMetricsGroup
from gtsfm.evaluation.metrics import GtsfmMetric, GtsfmMetricsGroup
from gtsfm.utils import align, transform

logger = logger_utils.get_logger()


def _is_auc_metric_name(metric_name: str) -> bool:
return metric_name.startswith("pose_auc_@")


def _convert_scalar_auc_metrics_to_percent(metrics: List[GtsfmMetric]) -> List[GtsfmMetric]:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just update the AUC code to save it in percent, we dont need this.

converted_metrics: List[GtsfmMetric] = []
for metric in metrics:
if metric.dim == 0 and _is_auc_metric_name(metric.name) and metric.data is not None:
converted_metrics.append(GtsfmMetric(metric.name, float(metric.data) * 100.0))
else:
converted_metrics.append(metric)
return converted_metrics


def load_poses(colmap_dirpath: str) -> Dict[str, Pose3]:
"""Returns mapping from image filename to associated camera pose."""
wTi_list, img_fnames, _, _, _, _ = io_utils.read_scene_data_from_colmap_format(colmap_dirpath)
Expand Down Expand Up @@ -145,7 +159,7 @@ def export_metrics_group_to_csv(metrics_group: GtsfmMetricsGroup, output_path: s
def _format_pose_auc(metrics_group: GtsfmMetricsGroup) -> str:
auc_parts = []
for metric in metrics_group.metrics:
if not metric.name.startswith("pose_auc_@"):
if not _is_auc_metric_name(metric.name):
continue
if metric.data is None:
continue
Expand All @@ -154,7 +168,7 @@ def _format_pose_auc(metrics_group: GtsfmMetricsGroup) -> str:
except (TypeError, ValueError):
continue
suffix = metric.name.replace("pose_auc_", "")
auc_parts.append(f"{suffix}={value:.3f}")
auc_parts.append(f"{suffix}={value:.2f}%")
return ", ".join(auc_parts)


Expand Down Expand Up @@ -218,11 +232,11 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str)

rotation_angular_errors = relative_rotation_error_metric.data
translation_angular_errors = relative_translation_error_metric.data
metrics.extend(
metric_utils.compute_pose_auc_metric(
rotation_angular_errors, translation_angular_errors, save_dir=output_dirpath
)
pose_auc_metrics = metric_utils.compute_pose_auc_metric(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update this method to return percent values?

rotation_angular_errors, translation_angular_errors, save_dir=output_dirpath
)
pose_auc_metrics = _convert_scalar_auc_metrics_to_percent(pose_auc_metrics)
metrics.extend(pose_auc_metrics)

ba_pose_metrics = GtsfmMetricsGroup(name="ba_pose_error_metrics", metrics=metrics)

Expand Down
49 changes: 37 additions & 12 deletions gtsfm/evaluation/compare_colmap_outputs_by_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,32 @@
logger = logger_utils.get_logger()


def _is_auc_metric_name(metric_name: str) -> bool:
return (
metric_name.startswith("pose_auc_@")
or metric_name.startswith("rotation_auc_@")
or metric_name.startswith("translation_auc_@")
)


def _convert_scalar_auc_metrics_to_percent(metrics: List[GtsfmMetric]) -> List[GtsfmMetric]:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicate (but also unnecessary) function

converted_metrics: List[GtsfmMetric] = []
for metric in metrics:
if metric.dim == 0 and _is_auc_metric_name(metric.name) and metric.data is not None:
converted_metrics.append(GtsfmMetric(metric.name, float(metric.data) * 100.0))
else:
converted_metrics.append(metric)
return converted_metrics


def _build_short_cluster_plot_filename(recon_dir: Path, root: Path) -> str:
"""Build filename as '<cluster>__<recon_name>_camera_centers.png'."""
del root # not needed after simplifying naming policy
cluster_name = recon_dir.parent.name or "cluster"
recon_name = recon_dir.name or "recon"
return f"{cluster_name}__{recon_name}_camera_centers.png"


def _read_images_txt_with_names(images_txt: Path) -> Dict[str, Pose3]:
"""Read poses from COLMAP images.txt keyed by image NAME."""
if not images_txt.exists():
Expand Down Expand Up @@ -200,7 +226,7 @@ def _compute_pose_metrics(baseline_list: List[Pose3], current_aligned_list: List
rotation_auc_values = metric_utils.pose_auc(rotation_angular_errors, thresholds_deg)
metrics.extend(
[
GtsfmMetric(f"rotation_auc_@{threshold}_deg", auc)
GtsfmMetric(f"rotation_auc_@{threshold}_deg", float(auc) * 100.0)
for threshold, auc in zip(thresholds_deg, rotation_auc_values)
]
)
Expand All @@ -209,15 +235,15 @@ def _compute_pose_metrics(baseline_list: List[Pose3], current_aligned_list: List
translation_auc_values = metric_utils.pose_auc(translation_angular_errors, thresholds_deg)
metrics.extend(
[
GtsfmMetric(f"translation_auc_@{threshold}_deg", auc)
GtsfmMetric(f"translation_auc_@{threshold}_deg", float(auc) * 100.0)
for threshold, auc in zip(thresholds_deg, translation_auc_values)
]
)
metrics.extend(
metric_utils.compute_pose_auc_metric(
relative_rotation_error_metric.data, relative_translation_error_metric.data, thresholds_deg=thresholds_deg
)
pose_auc_metrics = metric_utils.compute_pose_auc_metric(
relative_rotation_error_metric.data, relative_translation_error_metric.data, thresholds_deg=thresholds_deg
)
pose_auc_metrics = _convert_scalar_auc_metrics_to_percent(pose_auc_metrics)
metrics.extend(pose_auc_metrics)

return GtsfmMetricsGroup(name="ba_pose_error_metrics", metrics=metrics)

Expand Down Expand Up @@ -440,14 +466,14 @@ def _plot_pose_auc_boxplot(auc_values_by_label: Dict[str, List[float]], output_p
ax = fig.add_subplot(111)
ax.boxplot(data, vert=True, patch_artist=True)
ax.set_title(title)
ax.set_ylabel("AUC")
ax.set_ylabel("AUC (%)")
ax.set_xticks(range(1, len(labels) + 1))
ax.set_xticklabels(labels, rotation=30, ha="right")
stats_lines = []
for label, values in zip(labels, data):
mean_val = float(np.mean(values))
median_val = float(np.median(values))
stats_lines.append(f"{label}: mean={mean_val:.3f}, med={median_val:.3f}")
stats_lines.append(f"{label}: mean={mean_val:.2f}%, med={median_val:.2f}%")
if stats_lines:
ax.text(
0.02,
Expand Down Expand Up @@ -488,7 +514,7 @@ def _plot_pose_auc_vs_input_images(

ax.set_title("Pose AUC vs input images (all clusters)")
ax.set_xlabel("input images (current count)")
ax.set_ylabel("AUC")
ax.set_ylabel("AUC (%)")
ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
ax.legend(loc="best", fontsize=8)
fig.tight_layout()
Expand Down Expand Up @@ -596,7 +622,7 @@ def _format_auc(metrics_group: GtsfmMetricsGroup, prefix: str) -> str:
except (TypeError, ValueError):
continue
suffix = metric.name.replace(f"{prefix}_", "")
auc_parts.append(f"{suffix}={value:.3f}")
auc_parts.append(f"{suffix}={value:.2f}%")
return ", ".join(auc_parts)


Expand Down Expand Up @@ -832,8 +858,7 @@ def main() -> None:
else:
_print_metrics(str(recon_dir), metrics_group)
if fig_output_dir is not None:
safe_name = str(recon_dir).replace(os.sep, "__")
plot_path = fig_output_dir / f"{safe_name}_camera_centers.png"
plot_path = fig_output_dir / _build_short_cluster_plot_filename(recon_dir, root)
pose_auc_text = _format_auc(metrics_group, "pose_auc")
rotation_auc_text = _format_auc(metrics_group, "rotation_auc")
translation_auc_text = _format_auc(metrics_group, "translation_auc")
Expand Down
4 changes: 4 additions & 0 deletions gtsfm/utils/pycolmap_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def colmap_camera_to_gtsam_calibration(camera: ColmapCamera) -> CALIBRATION_TYPE
# See https://github.com/colmap/colmap/blob/1f6812e333a1e4b2ef56aa74e2c3873e4e3a40cd/src/colmap/sensor/models.h#L273 # noqa: E501
fx, fy, cx, cy, k1, k2, p1, p2 = camera.params[:8]
return gtsam.Cal3DS2(fx, fy, 0.0, cx, cy, k1, k2, p1, p2)
elif camera_model_name == "SIMPLE_PINHOLE":
# See https://github.com/colmap/colmap/blob/1f6812e333a1e4b2ef56aa74e2c3873e4e3a40cd/src/colmap/sensor/models.h#L196 # noqa: E501
f, cx, cy = camera.params
return gtsam.Cal3_S2(f, f, 0.0, cx, cy)
elif camera_model_name == "PINHOLE":
# See https://github.com/colmap/colmap/blob/1f6812e333a1e4b2ef56aa74e2c3873e4e3a40cd/src/colmap/sensor/models.h#L196 # noqa: E501
fx, fy, cx, cy = camera.params
Expand Down
190 changes: 190 additions & 0 deletions pipeline/1-partition/partition_metis_megaloc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
"""Run MegaLoc retrieval + METIS partitioning and persist the visibility graph + cluster tree.

This script mirrors the image_pairs_generator and graph_partitioner configuration in
`gtsfm/configs/vggt.yaml`. It loads images, generates a visibility graph, partitions
the graph with METIS, and saves both the graph and tree under the chosen output root.
"""

from __future__ import annotations

import argparse
import pickle
import time
from pathlib import Path

import hydra
from dask.distributed import Client, LocalCluster
from hydra.utils import instantiate

import gtsfm.utils.logger as logger_utils
from gtsfm.common.outputs import OutputPaths, prepare_output_paths
from gtsfm.graph_partitioner.graph_partitioner_base import GraphPartitionerBase
from gtsfm.graph_partitioner.single_partitioner import SinglePartitioner
from gtsfm.loader.loader_base import LoaderBase
from gtsfm.products.visibility_graph import VisibilityGraph
from gtsfm.retriever.image_pairs_generator import ImagePairsGenerator

logger = logger_utils.get_logger()


def _build_components(
config_name: str,
dataset_dir: str,
images_dir: str | None,
max_resolution: int | None,
) -> tuple[LoaderBase, ImagePairsGenerator, GraphPartitionerBase]:
overrides: list[str] = [f"loader.dataset_dir={dataset_dir}"]
if images_dir is not None:
overrides.append(f"loader.images_dir={images_dir}")
if max_resolution is not None:
overrides.append(f"loader.max_resolution={max_resolution}")

with hydra.initialize_config_module(config_module="gtsfm.configs", version_base=None):
cfg = hydra.compose(config_name=config_name, overrides=overrides)

loader: LoaderBase = instantiate(cfg.loader)
image_pairs_generator: ImagePairsGenerator = instantiate(cfg.image_pairs_generator)
graph_partitioner: GraphPartitionerBase = instantiate(cfg.graph_partitioner)
return loader, image_pairs_generator, graph_partitioner


def _run_retriever(
client: Client, loader: LoaderBase, image_pairs_generator: ImagePairsGenerator, output_paths: OutputPaths
) -> VisibilityGraph:
start_time = time.time()
batch_size = image_pairs_generator._batch_size
transforms = image_pairs_generator.get_preprocessing_transforms()
image_batch_futures = loader.get_all_descriptor_image_batches_as_futures(client, batch_size, *transforms)
image_fnames = loader.image_filenames()

logger.info("🔥 Running image pair retrieval...")
visibility_graph = image_pairs_generator.run(
client=client,
image_batch_futures=image_batch_futures,
image_fnames=image_fnames,
plots_output_dir=output_paths.plots,
)

try:
image_pairs_generator._retriever.save_diagnostics(
image_fnames=image_fnames,
pairs=visibility_graph,
plots_output_dir=output_paths.plots,
)
except Exception as exc: # pragma: no cover - diagnostic path best-effort
logger.warning("Failed to persist retriever diagnostics: %s", exc)

logger.info("🚀 Image pair retrieval took %.2f min.", (time.time() - start_time) / 60.0)
return visibility_graph


def _save_visibility_graph(graph: VisibilityGraph, output_paths: OutputPaths) -> None:
try:
with open(output_paths.results / "visibility_graph.pkl", "wb") as f:
pickle.dump(graph, f)
except Exception as exc:
logger.warning("Failed to serialize visibility graph: %s", exc)


def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Run MegaLoc+METIS partitioning and save outputs.")
parser.add_argument(
"--dataset_dir",
type=str,
required=True,
help="Dataset root containing images/ (Olsson-style loader default).",
)
parser.add_argument(
"--images_dir",
type=str,
default=None,
help="Optional path to images directory (overrides loader default).",
)
parser.add_argument(
"--output_root",
type=str,
default=str(Path.cwd()),
help="Root directory to store results (will create output_root/results).",
)
parser.add_argument(
"--config_name",
type=str,
default="vggt_megaloc_phototourism",
help="Config in gtsfm/configs to load (default: vggt).",
)
parser.add_argument(
"--max_resolution",
type=int,
default=None,
help="Override loader max resolution (if unset, uses config default).",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="Number of local Dask workers.",
)
parser.add_argument(
"--threads_per_worker",
type=int,
default=1,
help="Threads per Dask worker.",
)
parser.add_argument(
"--worker_memory_limit",
type=str,
default="32GB",
help="Memory limit per worker, e.g. 16GB.",
)
parser.add_argument(
"--dashboard_address",
type=str,
default=":8787",
help="Dask dashboard address, set to empty string to disable.",
)
parser.add_argument(
"--single_cluster",
action="store_true",
help="Skip METIS and output a single cluster containing all retrieved image pairs.",
)
return parser.parse_args()


def main() -> None:
args = _parse_args()

output_root = Path(args.output_root)
output_paths = prepare_output_paths(output_root, None)

loader, image_pairs_generator, graph_partitioner = _build_components(
config_name=args.config_name,
dataset_dir=args.dataset_dir,
images_dir=args.images_dir,
max_resolution=args.max_resolution,
)

logger.info("🌟 Starting Dask local cluster...")
cluster = LocalCluster(
n_workers=args.num_workers,
threads_per_worker=args.threads_per_worker,
memory_limit=args.worker_memory_limit,
dashboard_address=args.dashboard_address,
)

with Client(cluster) as client:
visibility_graph = _run_retriever(client, loader, image_pairs_generator, output_paths)

if args.single_cluster:
logger.info("🔥 Skipping METIS; creating a single cluster with all retrieved pairs...")
cluster_tree = SinglePartitioner().run(visibility_graph)
else:
logger.info("🔥 Running METIS partitioning...")
cluster_tree = graph_partitioner.run(visibility_graph)
graph_partitioner.log_partition_details(cluster_tree, output_paths)
_save_visibility_graph(visibility_graph, output_paths)

logger.info("✅ Saved visibility_graph.pkl and cluster_tree.pkl under %s", output_paths.results)


if __name__ == "__main__":
main()
Loading