Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions src/tracksdata/array/_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ class GraphArrayView(BaseReadOnlyArray):
buffer_cache_size : int, optional
The maximum number of buffers to keep in the cache for the array.
If None, the default buffer cache size is used.
mask_attr_key : str, optional
The node attribute key used to retrieve mask data. Defaults to "mask".
bbox_attr_key : str, optional
The node attribute key used to retrieve bounding box data. Defaults to "bbox".
"""

def __init__(
Expand All @@ -152,13 +156,17 @@ def __init__(
chunk_shape: tuple[int, ...] | int | None = None,
buffer_cache_size: int | None = None,
dtype: np.dtype | None = None,
mask_attr_key: str = DEFAULT_ATTR_KEYS.MASK,
bbox_attr_key: str = DEFAULT_ATTR_KEYS.BBOX,
):
if attr_key not in graph.node_attr_keys(return_ids=True):
raise ValueError(f"Attribute key '{attr_key}' not found in graph. Expected '{graph.node_attr_keys()}'")

self.graph = graph
self._attr_key = attr_key
self._offset = offset
self._mask_attr_key = mask_attr_key
self._bbox_attr_key = bbox_attr_key

if dtype is None:
# Infer the dtype from the graph's attribute
Expand Down Expand Up @@ -198,7 +206,7 @@ def __init__(

self._spatial_filter = self.graph.bbox_spatial_filter(
frame_attr_key=DEFAULT_ATTR_KEYS.T,
bbox_attr_key=DEFAULT_ATTR_KEYS.BBOX,
bbox_attr_key=self._bbox_attr_key,
)
self.graph.node_added.connect(self._on_node_added)
self.graph.node_removed.connect(self._on_node_removed)
Expand Down Expand Up @@ -348,10 +356,10 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda
"""
subgraph = self._spatial_filter[(slice(time, time), *volume_slicing)]
df = subgraph.node_attrs(
attr_keys=[self._attr_key, DEFAULT_ATTR_KEYS.MASK],
attr_keys=[self._attr_key, self._mask_attr_key],
)

for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=True):
for mask, value in zip(df[self._mask_attr_key], df[self._attr_key], strict=True):
mask: Mask
mask.paint_buffer(buffer, value, offset=self._offset)

Expand Down Expand Up @@ -394,13 +402,18 @@ def _invalidate_from_attrs(self, attrs: dict) -> None:
Invalidate cache region touched by node attributes.

Falls back to larger invalidation windows when metadata is incomplete.
When the bbox attribute key is not present in the attrs dict (e.g. when
using a non-default bbox key and the update doesn't affect this view),
the method returns without invalidating.
"""

time_value = attrs.get(DEFAULT_ATTR_KEYS.T)
if time_value is None:
raise ValueError(f"Node attributes must contain '{DEFAULT_ATTR_KEYS.T}' key for cache invalidation.")
if DEFAULT_ATTR_KEYS.BBOX not in attrs:
raise ValueError(f"Node attributes must contain '{DEFAULT_ATTR_KEYS.BBOX}' key for cache invalidation.")
if self._bbox_attr_key not in attrs:
# The update doesn't involve this view's bbox attribute —
# nothing to invalidate (e.g. a nuclear GAV seeing a membrane-only update).
return

try:
time = int(np.asarray(time_value).item())
Expand All @@ -411,7 +424,7 @@ def _invalidate_from_attrs(self, attrs: dict) -> None:
if not (0 <= time < self.original_shape[0]):
return

slices = self._bbox_to_slices(attrs[DEFAULT_ATTR_KEYS.BBOX])
slices = self._bbox_to_slices(attrs[self._bbox_attr_key])
if slices is not None:
self._cache.invalidate(time=time, volume_slicing=slices)

Expand Down
116 changes: 116 additions & 0 deletions src/tracksdata/array/_test/test_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,3 +497,119 @@ def test_graph_array_view_invalidates_chunk_on_remove(graph_backend: BaseGraph)
output = np.asarray(array_view[0])
assert output[1, 1] == 1
assert output[5, 5] == 0


def test_graph_array_view_custom_mask_bbox_keys(graph_backend: BaseGraph) -> None:
"""Test GraphArrayView with custom mask_attr_key and bbox_attr_key."""

# Standard mask/bbox attributes
graph_backend.add_node_attr_key("label", dtype=pl.Int64)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4))

# Custom (nuclear) mask/bbox attributes
graph_backend.add_node_attr_key("nuc_mask", pl.Object)
graph_backend.add_node_attr_key("nuc_bbox", pl.Array(pl.Int64, 4), default_value=[0, 0, 0, 0])

# Create masks: membrane is 4x4, nuclear is 2x2 at same location
mem_mask = Mask(np.ones((4, 4), dtype=bool), bbox=np.array([10, 20, 14, 24]))
nuc_mask = Mask(np.ones((2, 2), dtype=bool), bbox=np.array([11, 21, 13, 23]))

graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 5,
DEFAULT_ATTR_KEYS.MASK: mem_mask,
DEFAULT_ATTR_KEYS.BBOX: mem_mask.bbox,
"nuc_mask": nuc_mask,
"nuc_bbox": nuc_mask.bbox,
}
)

# Standard GAV uses default mask/bbox
std_view = GraphArrayView(
graph=graph_backend, shape=(2, 50, 50), attr_key="label"
)
# Custom GAV uses nuclear mask/bbox
nuc_view = GraphArrayView(
graph=graph_backend,
shape=(2, 50, 50),
attr_key="label",
mask_attr_key="nuc_mask",
bbox_attr_key="nuc_bbox",
)

std_result = np.asarray(std_view[0])
nuc_result = np.asarray(nuc_view[0])

# Standard view should have label painted at membrane mask area (4x4)
assert std_result[10, 20] == 5
assert std_result[13, 23] == 5
# Nuclear view should have label painted at nuclear mask area (2x2)
assert nuc_result[11, 21] == 5
assert nuc_result[12, 22] == 5

# Point inside membrane but outside nuclear mask
assert std_result[10, 20] == 5 # inside membrane
assert nuc_result[10, 20] == 0 # outside nuclear

# Total painted area differs
assert np.sum(std_result > 0) == 16 # 4x4
assert np.sum(nuc_result > 0) == 4 # 2x2


def test_graph_array_view_custom_keys_survives_membrane_update(graph_backend: BaseGraph) -> None:
"""Test that a nuclear GAV handles membrane-only updates without error.

When update_node_attrs is called with only the standard mask/bbox,
the nuclear GAV should handle the signal gracefully (the signal attrs
may or may not include 'nuc_bbox' depending on the graph backend).
After the update, the nuclear GAV should still return correct data.
"""
graph_backend.add_node_attr_key("label", dtype=pl.Int64)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4))
graph_backend.add_node_attr_key("nuc_mask", pl.Object)
graph_backend.add_node_attr_key("nuc_bbox", pl.Array(pl.Int64, 4), default_value=[0, 0, 0, 0])

mem_mask = Mask(np.ones((2, 2), dtype=bool), bbox=np.array([1, 1, 3, 3]))
nuc_mask = Mask(np.ones((2, 2), dtype=bool), bbox=np.array([1, 1, 3, 3]))

node_id = graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 1,
DEFAULT_ATTR_KEYS.MASK: mem_mask,
DEFAULT_ATTR_KEYS.BBOX: mem_mask.bbox,
"nuc_mask": nuc_mask,
"nuc_bbox": nuc_mask.bbox,
}
)

nuc_view = GraphArrayView(
graph=graph_backend,
shape=(2, 8, 8),
attr_key="label",
mask_attr_key="nuc_mask",
bbox_attr_key="nuc_bbox",
)

# Verify initial nuclear data is correct
output = np.asarray(nuc_view[0])
assert output[1, 1] == 1
assert output[2, 2] == 1

# Update only the membrane mask — should not crash the nuclear GAV
moved_mem_mask = Mask(np.ones((2, 2), dtype=bool), bbox=np.array([5, 5, 7, 7]))
graph_backend.update_node_attrs(
attrs={
DEFAULT_ATTR_KEYS.MASK: [moved_mem_mask],
DEFAULT_ATTR_KEYS.BBOX: [moved_mem_mask.bbox],
},
node_ids=[node_id],
)

# Nuclear data should still be correct after membrane-only update
output = np.asarray(nuc_view[0])
assert output[1, 1] == 1
assert output[2, 2] == 1
7 changes: 7 additions & 0 deletions src/tracksdata/graph/filters/_spatial_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,9 @@ def _add_node(
new_attrs : dict[str, Any]
Current node attributes to insert into the spatial index.
"""
if self._bbox_attr_key not in new_attrs:
return

from spatial_graph import PointRTree

if self._node_rtree is None:
Expand Down Expand Up @@ -470,6 +473,8 @@ def _remove_node(
"""
if self._node_rtree is None:
return
if self._bbox_attr_key not in old_attrs:
return

positions_min, positions_max = self._attrs_to_bb_window(old_attrs)

Expand All @@ -485,6 +490,8 @@ def _update_node(
old_attrs: dict[str, Any],
new_attrs: dict[str, Any],
) -> None:
if self._bbox_attr_key not in old_attrs:
return
self._remove_node(node_id, old_attrs=old_attrs)
self._add_node(node_id, new_attrs=new_attrs)

Expand Down