Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 156 additions & 2 deletions ultrack/core/solve/_test/test_sql_tracking.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd
import pytest
import sqlalchemy as sqla
import toml
from sqlalchemy.orm import Session

from ultrack import solve, to_tracks_layer
from ultrack.config.config import MainConfig
from ultrack.core.database import LinkDB, NodeDB, VarAnnotation
from ultrack.config.config import MainConfig, load_config
from ultrack.core.database import Base, LinkDB, NodeDB, VarAnnotation
from ultrack.core.solve.sqltracking import SQLTracking
from ultrack.utils.constants import NO_PARENT
from ultrack.utils.data import make_config_content

_CONFIG_PARAMS = {
"segmentation.n_workers": 4,
Expand Down Expand Up @@ -50,6 +55,17 @@ def _validate_tracking_solution(config: MainConfig):
continue
assert np.any(group["parent_id"] != NO_PARENT)

# every selected node's parent_id must reference another selected node;
# regression guard for window-boundary phantoms in interleaved solving.
# SQLite stores Boolean as 0/1, so cast before using as a mask.
is_selected = nodes["selected"].astype(bool)
parented = nodes[is_selected & (nodes["parent_id"] != NO_PARENT)]
parents_selected = is_selected.loc[parented["parent_id"].values].values
assert np.all(parents_selected), (
f"{(~parents_selected).sum()} selected nodes have parent_id pointing "
"to a non-selected node (dangling parent)"
)


@pytest.mark.parametrize(
"config_content,timelapse_mock_data",
Expand Down Expand Up @@ -138,3 +154,141 @@ def test_annotations_sql_tracking(
tracks_df_annot, _ = to_tracks_layer(config)

assert len(tracks_df) > len(tracks_df_annot)


def _make_minimal_tracking_config(tmp_path: Path, window_size: int, overlap_size: int):
cfg = make_config_content(
{
"data.working_dir": str(tmp_path),
"data.database": "sqlite",
"tracking.window_size": window_size,
"tracking.overlap_size": overlap_size,
}
)
path = tmp_path / "config.toml"
with open(path, mode="w") as f:
toml.dump(cfg, f)
return load_config(path)


def _seed_nodes(config: MainConfig, n_t: int) -> None:
"""Create empty NodeDB rows at every t in [0, n_t)."""
engine = sqla.create_engine(config.data_config.database_path)
Base.metadata.create_all(engine)
rows = [
dict(
t=t,
id=t,
parent_id=NO_PARENT,
hier_parent_id=NO_PARENT,
t_node_id=0,
t_hier_id=0,
z=0.0,
y=0.0,
x=0.0,
area=1,
selected=False,
pickle=None,
features=None,
node_prob=0.5,
)
for t in range(n_t)
]
with Session(engine) as session:
session.execute(sqla.insert(NodeDB), rows)
session.commit()
config.data_config.metadata_add({"shape": [n_t, 1, 1, 1]})


def _mark_selected(config: MainConfig, time_slices: Tuple[int, ...]) -> None:
engine = sqla.create_engine(config.data_config.database_path)
with Session(engine) as session:
session.execute(
sqla.update(NodeDB).where(NodeDB.t.in_(time_slices)).values(selected=True)
)
session.commit()


def test_compute_layout_first_pass(tmp_path: Path) -> None:
"""With no committed neighbours every side gets the full overlap."""
config = _make_minimal_tracking_config(tmp_path, window_size=3, overlap_size=2)
_seed_nodes(config, n_t=15)

tracker = SQLTracking(config)

layout = tracker._compute_layout(index=2) # middle batch, inner [6, 8]
assert layout.left_anchored is False
assert layout.right_anchored is False
assert (layout.solver_start, layout.solver_end) == (4, 10)
assert (layout.commit_start, layout.commit_end) == (5, 9)

first = tracker._compute_layout(index=0) # leftmost batch, inner [0, 2]
assert first.left_anchored is False
assert first.right_anchored is False
assert (first.solver_start, first.solver_end) == (0, 4)
assert (first.commit_start, first.commit_end) == (0, 3)

last_index = tracker.num_batches - 1
last = tracker._compute_layout(index=last_index)
assert last.right_anchored is False
assert last.solver_end == tracker._max_t
assert last.commit_end == tracker._max_t


def test_compute_layout_anchored(tmp_path: Path) -> None:
"""Committed neighbour slices shrink the layout to the inner range."""
config = _make_minimal_tracking_config(tmp_path, window_size=3, overlap_size=2)
_seed_nodes(config, n_t=15)
# batch index 2 has inner [6, 8]; mark its boundaries to simulate that
# both neighbouring batches have already committed.
_mark_selected(config, time_slices=(5, 9))

tracker = SQLTracking(config)
layout = tracker._compute_layout(index=2)
assert layout.left_anchored is True
assert layout.right_anchored is True
assert (layout.solver_start, layout.solver_end) == (6, 8)
assert (layout.commit_start, layout.commit_end) == (6, 8)


@pytest.mark.parametrize(
"selected_slice, expect_left, expect_right, expected_window",
[
# (marked time slice, left_anchored, right_anchored,
# (solver_start, solver_end, commit_start, commit_end))
(5, True, False, (6, 10, 6, 9)), # only left neighbour committed
(9, False, True, (4, 8, 5, 8)), # only right neighbour committed
],
)
def test_compute_layout_mixed_anchoring(
tmp_path: Path,
selected_slice: int,
expect_left: bool,
expect_right: bool,
expected_window: Tuple[int, int, int, int],
) -> None:
"""Only one side anchored: that side shrinks, the other keeps the overlap."""
config = _make_minimal_tracking_config(tmp_path, window_size=3, overlap_size=2)
_seed_nodes(config, n_t=15)
_mark_selected(config, time_slices=(selected_slice,))

tracker = SQLTracking(config)
layout = tracker._compute_layout(index=2)
assert layout.left_anchored is expect_left
assert layout.right_anchored is expect_right
assert (
layout.solver_start,
layout.solver_end,
layout.commit_start,
layout.commit_end,
) == expected_window


def test_is_committed_at_out_of_range(tmp_path: Path) -> None:
"""Out-of-range times return False without touching the DB."""
config = _make_minimal_tracking_config(tmp_path, window_size=3, overlap_size=2)
_seed_nodes(config, n_t=5)

tracker = SQLTracking(config)
assert tracker._is_committed_at(-1) is False
assert tracker._is_committed_at(tracker._max_t + 1) is False
11 changes: 11 additions & 0 deletions ultrack/core/solve/solver/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def add_nodes(
is_last_t: ArrayLike,
is_border: ArrayLike = False,
node_prob: Optional[ArrayLike] = None,
free_appear: bool = True,
free_disappear: bool = True,
) -> None:
"""Add nodes variables solver.

Expand All @@ -46,6 +48,15 @@ def add_nodes(
Default: False.
node_prob: Optional[ArrayLike]
If provided assigns a node probability score to the objective function.
free_appear : bool
When False, the ``is_first_t`` flag is ignored and appearance is
penalised even at the first slice of the solver window. Use this for
batches whose start slice is anchored to a neighbouring batch's
already-committed selection (interior of the experiment).
Default: True.
free_disappear : bool
Same as ``free_appear`` but for ``is_last_t`` and disappearance.
Default: True.
"""

@abstractmethod
Expand Down
18 changes: 18 additions & 0 deletions ultrack/core/solve/solver/mip_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def add_nodes(
is_last_t: ArrayLike,
is_border: ArrayLike = False,
nodes_prob: Optional[ArrayLike] = None,
free_appear: bool = True,
free_disappear: bool = True,
) -> None:
"""Add nodes slack variables to gurobi model.

Expand All @@ -103,6 +105,15 @@ def add_nodes(
Default: False
nodes_prob: Optional[ArrayLike]
If provided assigns a node probability score to the objective function.
free_appear : bool
When False, appearance is penalised even at slices marked
``is_first_t``. Use for batches whose start slice is anchored to a
neighbouring batch's already-committed selection so the solver
cannot freely spawn tracks at that interior boundary.
Default: True.
free_disappear : bool
Same as ``free_appear`` but for ``is_last_t`` and disappearance.
Default: True.
"""
if self._nodes is not None:
raise ValueError("Nodes have already been added.")
Expand All @@ -114,6 +125,13 @@ def add_nodes(
nodes_prob=nodes_prob,
)

is_first_t = np.asarray(is_first_t, dtype=bool)
is_last_t = np.asarray(is_last_t, dtype=bool)
if not free_appear:
is_first_t = np.zeros_like(is_first_t, dtype=bool)
if not free_disappear:
is_last_t = np.zeros_like(is_last_t, dtype=bool)

LOG.info("# %s nodes at starting `t`.", np.sum(is_first_t))
LOG.info("# %s nodes at last `t`.", np.sum(is_last_t))

Expand Down
Loading
Loading