diff --git a/ultrack/core/solve/_test/test_sql_tracking.py b/ultrack/core/solve/_test/test_sql_tracking.py index b7c05ef..ff30514 100644 --- a/ultrack/core/solve/_test/test_sql_tracking.py +++ b/ultrack/core/solve/_test/test_sql_tracking.py @@ -1,14 +1,19 @@ +from pathlib import Path +from typing import Tuple + import numpy as np import pandas as pd import pytest import sqlalchemy as sqla +import toml from sqlalchemy.orm import Session from ultrack import solve, to_tracks_layer -from ultrack.config.config import MainConfig -from ultrack.core.database import LinkDB, NodeDB, VarAnnotation +from ultrack.config.config import MainConfig, load_config +from ultrack.core.database import Base, LinkDB, NodeDB, VarAnnotation from ultrack.core.solve.sqltracking import SQLTracking from ultrack.utils.constants import NO_PARENT +from ultrack.utils.data import make_config_content _CONFIG_PARAMS = { "segmentation.n_workers": 4, @@ -50,6 +55,17 @@ def _validate_tracking_solution(config: MainConfig): continue assert np.any(group["parent_id"] != NO_PARENT) + # every selected node's parent_id must reference another selected node; + # regression guard for window-boundary phantoms in interleaved solving. + # SQLite stores Boolean as 0/1, so cast before using as a mask. + is_selected = nodes["selected"].astype(bool) + parented = nodes[is_selected & (nodes["parent_id"] != NO_PARENT)] + parents_selected = is_selected.loc[parented["parent_id"].values].values + assert np.all(parents_selected), ( + f"{(~parents_selected).sum()} selected nodes have parent_id pointing " + "to a non-selected node (dangling parent)" + ) + @pytest.mark.parametrize( "config_content,timelapse_mock_data", @@ -138,3 +154,141 @@ def test_annotations_sql_tracking( tracks_df_annot, _ = to_tracks_layer(config) assert len(tracks_df) > len(tracks_df_annot) + + +def _make_minimal_tracking_config(tmp_path: Path, window_size: int, overlap_size: int): + cfg = make_config_content( + { + "data.working_dir": str(tmp_path), + "data.database": "sqlite", + "tracking.window_size": window_size, + "tracking.overlap_size": overlap_size, + } + ) + path = tmp_path / "config.toml" + with open(path, mode="w") as f: + toml.dump(cfg, f) + return load_config(path) + + +def _seed_nodes(config: MainConfig, n_t: int) -> None: + """Create empty NodeDB rows at every t in [0, n_t).""" + engine = sqla.create_engine(config.data_config.database_path) + Base.metadata.create_all(engine) + rows = [ + dict( + t=t, + id=t, + parent_id=NO_PARENT, + hier_parent_id=NO_PARENT, + t_node_id=0, + t_hier_id=0, + z=0.0, + y=0.0, + x=0.0, + area=1, + selected=False, + pickle=None, + features=None, + node_prob=0.5, + ) + for t in range(n_t) + ] + with Session(engine) as session: + session.execute(sqla.insert(NodeDB), rows) + session.commit() + config.data_config.metadata_add({"shape": [n_t, 1, 1, 1]}) + + +def _mark_selected(config: MainConfig, time_slices: Tuple[int, ...]) -> None: + engine = sqla.create_engine(config.data_config.database_path) + with Session(engine) as session: + session.execute( + sqla.update(NodeDB).where(NodeDB.t.in_(time_slices)).values(selected=True) + ) + session.commit() + + +def test_compute_layout_first_pass(tmp_path: Path) -> None: + """With no committed neighbours every side gets the full overlap.""" + config = _make_minimal_tracking_config(tmp_path, window_size=3, overlap_size=2) + _seed_nodes(config, n_t=15) + + tracker = SQLTracking(config) + + layout = tracker._compute_layout(index=2) # middle batch, inner [6, 8] + assert layout.left_anchored is False + assert layout.right_anchored is False + assert (layout.solver_start, layout.solver_end) == (4, 10) + assert (layout.commit_start, layout.commit_end) == (5, 9) + + first = tracker._compute_layout(index=0) # leftmost batch, inner [0, 2] + assert first.left_anchored is False + assert first.right_anchored is False + assert (first.solver_start, first.solver_end) == (0, 4) + assert (first.commit_start, first.commit_end) == (0, 3) + + last_index = tracker.num_batches - 1 + last = tracker._compute_layout(index=last_index) + assert last.right_anchored is False + assert last.solver_end == tracker._max_t + assert last.commit_end == tracker._max_t + + +def test_compute_layout_anchored(tmp_path: Path) -> None: + """Committed neighbour slices shrink the layout to the inner range.""" + config = _make_minimal_tracking_config(tmp_path, window_size=3, overlap_size=2) + _seed_nodes(config, n_t=15) + # batch index 2 has inner [6, 8]; mark its boundaries to simulate that + # both neighbouring batches have already committed. + _mark_selected(config, time_slices=(5, 9)) + + tracker = SQLTracking(config) + layout = tracker._compute_layout(index=2) + assert layout.left_anchored is True + assert layout.right_anchored is True + assert (layout.solver_start, layout.solver_end) == (6, 8) + assert (layout.commit_start, layout.commit_end) == (6, 8) + + +@pytest.mark.parametrize( + "selected_slice, expect_left, expect_right, expected_window", + [ + # (marked time slice, left_anchored, right_anchored, + # (solver_start, solver_end, commit_start, commit_end)) + (5, True, False, (6, 10, 6, 9)), # only left neighbour committed + (9, False, True, (4, 8, 5, 8)), # only right neighbour committed + ], +) +def test_compute_layout_mixed_anchoring( + tmp_path: Path, + selected_slice: int, + expect_left: bool, + expect_right: bool, + expected_window: Tuple[int, int, int, int], +) -> None: + """Only one side anchored: that side shrinks, the other keeps the overlap.""" + config = _make_minimal_tracking_config(tmp_path, window_size=3, overlap_size=2) + _seed_nodes(config, n_t=15) + _mark_selected(config, time_slices=(selected_slice,)) + + tracker = SQLTracking(config) + layout = tracker._compute_layout(index=2) + assert layout.left_anchored is expect_left + assert layout.right_anchored is expect_right + assert ( + layout.solver_start, + layout.solver_end, + layout.commit_start, + layout.commit_end, + ) == expected_window + + +def test_is_committed_at_out_of_range(tmp_path: Path) -> None: + """Out-of-range times return False without touching the DB.""" + config = _make_minimal_tracking_config(tmp_path, window_size=3, overlap_size=2) + _seed_nodes(config, n_t=5) + + tracker = SQLTracking(config) + assert tracker._is_committed_at(-1) is False + assert tracker._is_committed_at(tracker._max_t + 1) is False diff --git a/ultrack/core/solve/solver/base_solver.py b/ultrack/core/solve/solver/base_solver.py index 0c4bf24..7d0590a 100644 --- a/ultrack/core/solve/solver/base_solver.py +++ b/ultrack/core/solve/solver/base_solver.py @@ -30,6 +30,8 @@ def add_nodes( is_last_t: ArrayLike, is_border: ArrayLike = False, node_prob: Optional[ArrayLike] = None, + free_appear: bool = True, + free_disappear: bool = True, ) -> None: """Add nodes variables solver. @@ -46,6 +48,15 @@ def add_nodes( Default: False. node_prob: Optional[ArrayLike] If provided assigns a node probability score to the objective function. + free_appear : bool + When False, the ``is_first_t`` flag is ignored and appearance is + penalised even at the first slice of the solver window. Use this for + batches whose start slice is anchored to a neighbouring batch's + already-committed selection (interior of the experiment). + Default: True. + free_disappear : bool + Same as ``free_appear`` but for ``is_last_t`` and disappearance. + Default: True. """ @abstractmethod diff --git a/ultrack/core/solve/solver/mip_solver.py b/ultrack/core/solve/solver/mip_solver.py index 6a0007a..d8c2846 100644 --- a/ultrack/core/solve/solver/mip_solver.py +++ b/ultrack/core/solve/solver/mip_solver.py @@ -87,6 +87,8 @@ def add_nodes( is_last_t: ArrayLike, is_border: ArrayLike = False, nodes_prob: Optional[ArrayLike] = None, + free_appear: bool = True, + free_disappear: bool = True, ) -> None: """Add nodes slack variables to gurobi model. @@ -103,6 +105,15 @@ def add_nodes( Default: False nodes_prob: Optional[ArrayLike] If provided assigns a node probability score to the objective function. + free_appear : bool + When False, appearance is penalised even at slices marked + ``is_first_t``. Use for batches whose start slice is anchored to a + neighbouring batch's already-committed selection so the solver + cannot freely spawn tracks at that interior boundary. + Default: True. + free_disappear : bool + Same as ``free_appear`` but for ``is_last_t`` and disappearance. + Default: True. """ if self._nodes is not None: raise ValueError("Nodes have already been added.") @@ -114,6 +125,13 @@ def add_nodes( nodes_prob=nodes_prob, ) + is_first_t = np.asarray(is_first_t, dtype=bool) + is_last_t = np.asarray(is_last_t, dtype=bool) + if not free_appear: + is_first_t = np.zeros_like(is_first_t, dtype=bool) + if not free_disappear: + is_last_t = np.zeros_like(is_last_t, dtype=bool) + LOG.info("# %s nodes at starting `t`.", np.sum(is_first_t)) LOG.info("# %s nodes at last `t`.", np.sum(is_last_t)) diff --git a/ultrack/core/solve/sqltracking.py b/ultrack/core/solve/sqltracking.py index 5ddd55d..2b88f1a 100644 --- a/ultrack/core/solve/sqltracking.py +++ b/ultrack/core/solve/sqltracking.py @@ -1,6 +1,7 @@ import itertools import logging import math +from dataclasses import dataclass from typing import Optional, Tuple import numpy as np @@ -27,6 +28,27 @@ LOG = logging.getLogger(__name__) +@dataclass(frozen=True) +class _BatchLayout: + """Slice ranges used by a single windowed-solve batch. + + The solver builds variables for every time slice in ``[solver_start, + solver_end]`` (the solver window). The database write covers + ``[commit_start, commit_end]`` (the commit window). When a neighbouring + batch has already committed its own solution, the boundary on that side is + anchored: the solver shrinks to share exactly the anchored slice with the + neighbour and the commit window does not extend past the inner range on + that side. + """ + + solver_start: int + solver_end: int + commit_start: int + commit_end: int + left_anchored: bool + right_anchored: bool + + class SQLTracking: def __init__( self, @@ -45,6 +67,7 @@ def __init__( self._tracking_config = config.tracking_config self._data_config = config.data_config self._solver: Optional[MIPSolver] = None + self._layout: Optional[_BatchLayout] = None self._max_t = maximum_time_from_database(self._data_config) if self._tracking_config.window_size is None: @@ -70,18 +93,27 @@ def construct_model(self, index: int = 0) -> None: f"Invalid index {index}, expected between [0, {self.num_batches})." ) + layout = self._compute_layout(index) + self._layout = layout + solver = MIPSolver(self._tracking_config) print(f"Solving ILP batch {index}") + print( + f" solver window [{layout.solver_start}, {layout.solver_end}], " + f"commit [{layout.commit_start}, {layout.commit_end}], " + f"left_anchored={layout.left_anchored}, " + f"right_anchored={layout.right_anchored}" + ) print("Constructing ILP ...") - self._add_nodes(solver=solver, index=index) - self._add_edges(solver=solver, index=index) + self._add_nodes(solver=solver, layout=layout) + self._add_edges(solver=solver, layout=layout) solver.set_standard_constraints() - self._add_overlap_constraints(solver=solver, index=index) - self._add_boundary_constraints(solver=solver, index=index) + self._add_overlap_constraints(solver=solver, layout=layout) + self._add_boundary_constraints(solver=solver, layout=layout) self._solver = solver @@ -142,7 +174,10 @@ def fix_annotations(self, index: int) -> None: engine = sqla.create_engine(self._data_config.database_path) - start_time, end_time = self._window_limits(index, True) + layout = ( + self._layout if self._layout is not None else self._compute_layout(index) + ) + start_time, end_time = layout.solver_start, layout.solver_end with Session(engine) as session: # setting extra slack variables @@ -189,7 +224,10 @@ def fix_ground_truth_matches(self, index: int) -> None: """ engine = sqla.create_engine(self._data_config.database_path) - start_time, end_time = self._window_limits(index, True) + layout = ( + self._layout if self._layout is not None else self._compute_layout(index) + ) + start_time, end_time = layout.solver_start, layout.solver_end with Session(engine) as session: # setting extra slack variables @@ -237,6 +275,10 @@ def set_number_of_segments(self, time: int, number_of_segments: int) -> None: def _window_limits(self, index: int, with_overlap: bool) -> Tuple[int, int]: """Computes time window of a given index, with or without overlap. + Kept for backward compatibility; ``_compute_layout`` is the new + canonical helper that also accounts for whether neighbouring batches + have already committed. + Parameters ---------- index : int @@ -258,15 +300,73 @@ def _window_limits(self, index: int, with_overlap: bool) -> Tuple[int, int]: ) * self._window_size + with_overlap * self._tracking_config.overlap_size return start_time, end_time - 1 - def _add_nodes(self, solver: BaseSolver, index: int) -> None: - """Query nodes from a given batch index and add them to solver. + def _is_committed_at(self, t: int) -> bool: + """Whether some node at time ``t`` is already ``selected`` in the DB. - Parameters - ---------- - index : int - Batch index. + Used to detect that a neighbouring batch has already committed its + solution at the boundary of the current batch. """ - start_time, end_time = self._window_limits(index, True) + if t < 0 or t > self._max_t: + return False + engine = sqla.create_engine(self._data_config.database_path) + with Session(engine) as session: + return ( + session.query(NodeDB.id).where(NodeDB.t == t, NodeDB.selected).first() + is not None + ) + + def _compute_layout(self, index: int) -> _BatchLayout: + """Decide solver and commit windows for ``index`` from DB state. + + A batch is *anchored* on a side when its neighbour batch on that side + has already written a solution there. Anchored batches use a tight + window so the neighbour's committed boundary is reused verbatim; + non-anchored sides use the configured overlap on the solver window + and commit one slice past the inner range so the next-solving + neighbour finds an anchor. + """ + overlap = self._tracking_config.overlap_size + inner_start = index * self._window_size + inner_end = min((index + 1) * self._window_size - 1, self._max_t) + + left_anchored = inner_start > 0 and self._is_committed_at(inner_start - 1) + right_anchored = inner_end < self._max_t and self._is_committed_at( + inner_end + 1 + ) + + if left_anchored: + solver_start = inner_start + commit_start = inner_start + else: + solver_start = max(inner_start - overlap, 0) + commit_start = max(inner_start - 1, 0) + + if right_anchored: + solver_end = inner_end + commit_end = inner_end + else: + solver_end = min(inner_end + overlap, self._max_t) + commit_end = min(inner_end + 1, self._max_t) + + return _BatchLayout( + solver_start=solver_start, + solver_end=solver_end, + commit_start=commit_start, + commit_end=commit_end, + left_anchored=left_anchored, + right_anchored=right_anchored, + ) + + def _add_nodes(self, solver: BaseSolver, layout: _BatchLayout) -> None: + """Query nodes inside ``layout.solver_*`` and add them to the solver. + + Appearance and disappearance are kept free at the solver window + boundaries only when that side is *not* anchored to a neighbouring + batch's committed selection; otherwise the solver pays the regular + penalty so it cannot spawn or terminate tracks at an interior seam. + """ + start_time = layout.solver_start + end_time = layout.solver_end engine = sqla.create_engine(self._data_config.database_path) border_distance = self._tracking_config.image_border_size @@ -317,10 +417,7 @@ def _add_nodes(self, solver: BaseSolver, index: int) -> None: f"Found {df.shape[0] - n_invalid_prob} / {df.shape[0]} valid probs." ) - start_time = max(start_time, 0) - end_time = min(end_time, self._max_t) - - LOG.info(f"Batch {index}, nodes with t between {start_time} and {end_time}") + LOG.info(f"nodes with t between {start_time} and {end_time}") solver.add_nodes( df["id"], @@ -328,17 +425,14 @@ def _add_nodes(self, solver: BaseSolver, index: int) -> None: df["t"] == end_time, is_border=is_border, nodes_prob=nodes_prob, + free_appear=not layout.left_anchored, + free_disappear=not layout.right_anchored, ) - def _add_edges(self, solver: BaseSolver, index: int) -> None: - """Query edges from a given batch index and add them to solver. - - Parameters - ---------- - index : int - Batch index. - """ - start_time, end_time = self._window_limits(index, True) + def _add_edges(self, solver: BaseSolver, layout: _BatchLayout) -> None: + """Query edges inside ``layout.solver_*`` and add them to the solver.""" + start_time = layout.solver_start + end_time = layout.solver_end engine = sqla.create_engine(self._data_config.database_path) with Session(engine) as session: @@ -351,20 +445,17 @@ def _add_edges(self, solver: BaseSolver, index: int) -> None: df = pd.read_sql(query.statement, session.bind) LOG.info( - f"Batch {index}, edges with source nodes with t between {start_time} and {end_time - 1}" + f"edges with source nodes with t between {start_time} and {end_time - 1}" ) solver.add_edges(df["source_id"], df["target_id"], df["weight"]) - def _add_overlap_constraints(self, solver: BaseSolver, index: int) -> None: - """Adds overlaping segmentation constrainsts - - Parameters - ---------- - index : int - Batch index. - """ - start_time, end_time = self._window_limits(index, True) + def _add_overlap_constraints( + self, solver: BaseSolver, layout: _BatchLayout + ) -> None: + """Adds overlapping segmentation constraints inside the solver window.""" + start_time = layout.solver_start + end_time = layout.solver_end engine = sqla.create_engine(self._data_config.database_path) with Session(engine) as session: @@ -377,28 +468,29 @@ def _add_overlap_constraints(self, solver: BaseSolver, index: int) -> None: solver.add_overlap_constraints(df["node_id"], df["ancestor_id"]) - def _add_boundary_constraints(self, solver: BaseSolver, index: int) -> None: - """ - Enforce to solution nodes from the boundary (in time) already selected from adjacent batches. + def _add_boundary_constraints( + self, solver: BaseSolver, layout: _BatchLayout + ) -> None: + """Force this batch's solver to keep the neighbours' boundary picks. - Parameters - ---------- - index : int - Batch index. + At an anchored side the solver window's outermost slice is shared with + the neighbouring batch's already-committed selection, so every + ``selected=True`` node at that slice is forced into this batch's + solution. The non-anchored sides have no neighbouring commit to read. """ - start_time, end_time = self._window_limits(index, True) - engine = sqla.create_engine(self._data_config.database_path) with Session(engine) as session: query = session.query(NodeDB.id).where(NodeDB.selected) - start_nodes = [n for n, in query.where(NodeDB.t == start_time)] - end_nodes = [n for n, in query.where(NodeDB.t == end_time)] + start_nodes = [n for n, in query.where(NodeDB.t == layout.solver_start)] + end_nodes = [n for n, in query.where(NodeDB.t == layout.solver_end)] LOG.info( - f"# {len(start_nodes)} boundary constraints found at at t = {start_time}" + f"# {len(start_nodes)} boundary constraints found at t = {layout.solver_start}" + ) + LOG.info( + f"# {len(end_nodes)} boundary constraints found at t = {layout.solver_end}" ) - LOG.info(f"# {len(end_nodes)} boundary constraints found at at t = {end_time}") solver.enforce_nodes_solution_value(start_nodes, variable="node", value=True) solver.enforce_nodes_solution_value(end_nodes, variable="node", value=True) @@ -406,48 +498,65 @@ def _add_boundary_constraints(self, solver: BaseSolver, index: int) -> None: def add_solution(self, index: int = 0) -> None: """Adds selected nodes to solution in database. + Writes this batch's solution over ``layout.commit_start..commit_end``. + ``parent_id`` is updated everywhere in the commit range except at the + leftmost slice when it coincides with the solver's leftmost slice -- + there the solver has no incoming edge and its ``parent_id`` is + ``NO_PARENT``, which would clobber the previous batch's lineage. + Parameters ---------- index : int Batch index, by default 0, which works for single batch tracking. """ solution = self.solver.solution() - solution["node_id"] = solution.index - start_time, end_time = self._window_limits(index, False) + layout = ( + self._layout if self._layout is not None else self._compute_layout(index) + ) + commit_start = layout.commit_start + commit_end = layout.commit_end + skip_parent_at_start = layout.commit_start == layout.solver_start + + records = solution[["node_id", "parent_id"]].to_dict("records") + ids_only = solution[["node_id"]].to_dict("records") engine = sqla.create_engine(self._data_config.database_path) with Session(engine) as session: - general_stmt = ( - sqla.update(NodeDB) - .where( - NodeDB.t.between(start_time, end_time), - NodeDB.id == sqla.bindparam("node_id"), + # Inner range: write selected + parent_id for every batch node. + inner_lower = commit_start + 1 if skip_parent_at_start else commit_start + if inner_lower <= commit_end: + inner_stmt = ( + sqla.update(NodeDB) + .where( + NodeDB.t.between(inner_lower, commit_end), + NodeDB.id == sqla.bindparam("node_id"), + ) + .values(parent_id=sqla.bindparam("parent_id"), selected=True) + ) + session.connection().execute( + inner_stmt, + records, + execution_options={"synchronize_session": False}, ) - .values(parent_id=sqla.bindparam("parent_id"), selected=True) - ) - session.connection().execute( - general_stmt, - solution[["node_id", "parent_id"]].to_dict("records"), - execution_options={"synchronize_session": False}, - ) - # condition isn't necessary but avoids a useless operation - if start_time > 0: - # insert nodes from start time - 1 without their parent - start_stmt = ( + # Left boundary slice (only when the solver had no incoming edges + # here): mark selected without touching parent_id so we keep the + # previous batch's lineage pointing back into its own inner range. + if skip_parent_at_start: + boundary_stmt = ( sqla.update(NodeDB) .where( - NodeDB.t == start_time - 1, + NodeDB.t == commit_start, NodeDB.id == sqla.bindparam("node_id"), ) .values(selected=True) ) session.connection().execute( - start_stmt, - solution[["node_id"]].to_dict("records"), - execution_options={"syncronize_session": False}, + boundary_stmt, + ids_only, + execution_options={"synchronize_session": False}, ) session.commit()