Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions benchmarks/graph_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
N_LINEAGES = 50


def _noop(*_args) -> None:
    """
    Accept any positional arguments and do nothing.

    Used in the benchmarks as a slot connected to a view's `node_updated`
    signal, so that the signal-payload path is exercised; every argument is
    discarded.
    """
    return None


def _build_graph(backend_name: str, n_nodes: int) -> td.graph.BaseGraph:
graph = BACKENDS[backend_name]()
graph.add_node_attr_key("score", dtype=pl.Float64)
Expand Down Expand Up @@ -60,6 +64,16 @@ def setup(self, backend_name: str, n_nodes: int) -> None:
self.removal_targets = all_ids[:N_OPS]
self.update_targets = all_ids[: N_OPS * 4]

# Separate view with a no-op listener attached. Without a listener,
# update_node_attrs skips the signal-payload computation entirely, so
# the P2-2 optimization (deriving new_attrs from old + applied) isn't
# exercised. This view is the BBoxSpatialFilter / GraphArrayView use case.
self.listened_view = self.graph.filter().subgraph()
self.listened_view.node_updated.connect(_noop)
# Smaller batch, representative of interactive editing where the saved
# query overhead is a larger fraction of the total work.
self.listener_update_targets = all_ids[:N_OPS]

# --- remove_node ------------------------------------------------------

def time_remove_node_root(self, backend_name: str, n_nodes: int) -> None:
Expand All @@ -78,6 +92,9 @@ def time_update_node_attrs_root(self, backend_name: str, n_nodes: int) -> None:
def time_update_node_attrs_view(self, backend_name: str, n_nodes: int) -> None:
    """
    Benchmark body: bulk-update the `score` attribute through `self.view`.

    Every node in `self.update_targets` is set to ``1.0`` in a single
    `update_node_attrs` call. `backend_name` and `n_nodes` are not read here;
    presumably they are the benchmark parameters handed to every timing
    method — confirm against the benchmark runner.
    """
    update = self.view.update_node_attrs
    update(node_ids=self.update_targets, attrs={"score": 1.0})

def time_update_node_attrs_view_with_listener(self, backend_name: str, n_nodes: int) -> None:
    """
    Benchmark body: bulk-update `score` through `self.listened_view`.

    Same call as the plain view benchmark, but issued on `self.listened_view`
    with the node ids in `self.listener_update_targets`, each set to ``1.0``.
    `backend_name` and `n_nodes` are not read here; presumably they are the
    benchmark parameters handed to every timing method — confirm against the
    benchmark runner.
    """
    update = self.listened_view.update_node_attrs
    update(node_ids=self.listener_update_targets, attrs={"score": 1.0})

# --- filter (standalone, materialized to ids) ------------------------

def time_filter_node_ids(self, backend_name: str, n_nodes: int) -> None:
Expand Down
164 changes: 134 additions & 30 deletions src/tracksdata/graph/_graph_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Literal, cast, overload

import bidict
import numpy as np
import polars as pl
import rustworkx as rx

Expand Down Expand Up @@ -428,6 +429,29 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None

return parent_node_ids

def _remove_node_local(self, node_id: int) -> None:
    """
    Drop a node from the view's local rx_graph and its ID mappings.

    This is local bookkeeping only: nothing is validated, the view's
    `node_removed` signal is blocked while the node is removed, and the root
    graph is never called. Validation, signalling and any root update are the
    caller's job.

    Parameters
    ----------
    node_id : int
        External ID of the node. It is looked up in
        `self._external_to_local`, so it must be present there.
    """
    local_idx = self._external_to_local[node_id]

    # Snapshot the incident edge ids first. rustworkx discards those edges
    # together with the node, and once they are gone the only way to find the
    # matching `_edge_map_to_root` entries would be a scan of the whole map.
    # `all_edges=True` matters: the default yields out-edges only, which would
    # leave the in-edge entries behind.
    doomed_edge_ids = list(self.rx_graph.incident_edges(local_idx, all_edges=True))

    # Silence the view-level signal during the underlying removal; emitting
    # is left to the caller.
    with self.node_removed.blocked():
        super().remove_node(local_idx)

    self._remove_id_mapping(external_id=node_id)

    # `pop(..., None)` tolerates edge ids that have no mapping entry.
    for local_edge_id in doomed_edge_ids:
        self._edge_map_to_root.pop(local_edge_id, None)

def remove_node(self, node_id: int) -> None:
"""
Remove a node from the graph.
Expand Down Expand Up @@ -462,26 +486,7 @@ def remove_node(self, node_id: int) -> None:
self._root.remove_node(node_id)

if self.sync:
# Get the local node ID and remove from local graph
local_node_id = self._external_to_local[node_id]

with self.node_removed.blocked():
super().remove_node(local_node_id)

# Remove the node mapping
self._remove_id_mapping(external_id=node_id)

# Update edge mappings - remove edges involving this node
edges_to_remove = []
edge_indices = self.rx_graph.edge_indices()
for local_edge_id, _ in list(self._edge_map_to_root.items()):
# Check if this edge is still in the local graph
if local_edge_id not in edge_indices:
edges_to_remove.append(local_edge_id)

for edge_id in edges_to_remove:
if edge_id in self._edge_map_to_root:
del self._edge_map_to_root[edge_id]
self._remove_node_local(node_id)
else:
self._out_of_sync = True

Expand All @@ -490,6 +495,44 @@ def remove_node(self, node_id: int) -> None:
if view_signal_on:
self.node_removed.emit(node_id, old_attrs)

def remove_node_from_view(self, node_id: int) -> None:
    """
    Drop a node from this view while leaving the root graph as it is.

    Only the view's local rx_graph and ID mappings change. Afterwards the view
    is no longer a strict filter of the root, but its own state stays
    consistent, so traversals (successors/predecessors) keep working.

    The view's `node_removed` signal is emitted when it has listeners; the
    root's signal is not, since the root was not modified.

    Parameters
    ----------
    node_id : int
        The ID of the node to remove from the view.

    Raises
    ------
    ValueError
        If the node_id does not exist in the view.
    RuntimeError
        If `sync=False` — view-only removal requires a maintained local view.
    """
    if node_id not in self._external_to_local:
        raise ValueError(f"Node {node_id} does not exist in the graph.")
    if not self.sync:
        raise RuntimeError("remove_node_from_view requires sync=True; the local view is not maintained otherwise.")

    # The attribute snapshot is only needed for the signal payload, so it is
    # taken (before the removal) only when the signal will actually fire.
    emit_removed = is_signal_on(self.node_removed)
    removed_attrs = self.nodes[node_id].to_dict() if emit_removed else None

    self._remove_node_local(node_id)

    if emit_removed:
        self.node_removed.emit(node_id, removed_attrs)

def add_edge(
self,
source_id: int,
Expand Down Expand Up @@ -544,6 +587,18 @@ def bulk_add_edges(self, edges: list[dict[str, Any]], return_ids: bool = False)
if return_ids:
return parent_edge_ids

def _remove_edge_local(self, edge_id: int) -> None:
    """
    Remove an edge from this view's local rx_graph and edge mapping only.

    No validation, no root call. Caller guarantees `edge_id` (root id) is
    present in `self._edge_map_from_root`.

    Parameters
    ----------
    edge_id : int
        Root edge id of the edge to drop from the view.
    """
    local_edge_id = self._edge_map_from_root[edge_id]

    # Remove by edge index rather than by (source, target) endpoints:
    #   * `edge_index_map()` materialises every edge of the graph just to look
    #     up one pair of endpoints, making each removal scale with graph size;
    #   * removing by endpoints is ambiguous when parallel edges connect the
    #     same two nodes and may drop a different edge than the one mapped to
    #     `edge_id`.
    # NOTE(review): `remove_edge_from_index` is expected to be a no-op for an
    # index that is not in the graph, whereas the previous lookup raised; this
    # is only reachable if the caller breaks the precondition above — confirm.
    self.rx_graph.remove_edge_from_index(local_edge_id)

    del self._edge_map_to_root[local_edge_id]

def remove_edge(
self,
source_id: int | None = None,
Expand All @@ -568,14 +623,57 @@ def remove_edge(
# Remove from the local graph if synced
if self.sync:
if edge_id in self._edge_map_from_root:
local_edge_id = self._edge_map_from_root[edge_id]
edge_map = self.rx_graph.edge_index_map()
src, tgt, _ = edge_map[local_edge_id]
self.rx_graph.remove_edge(src, tgt)
del self._edge_map_to_root[local_edge_id]
self._remove_edge_local(edge_id)
else:
self._out_of_sync = True

def remove_edge_from_view(
    self,
    source_id: int | None = None,
    target_id: int | None = None,
    *,
    edge_id: int | None = None,
) -> None:
    """
    Drop an edge from this view while leaving the root graph as it is.

    The edge is identified either directly by `edge_id` or by its
    `(source_id, target_id)` endpoints, which are resolved to an edge id
    through the root graph. Because the root is not modified, the view stops
    being a strict filter of the root after this call.

    Parameters
    ----------
    source_id : int, optional
        Source node id of the edge. Required if `edge_id` is not given.
    target_id : int, optional
        Target node id of the edge. Required if `edge_id` is not given.
    edge_id : int, optional
        Root edge id. If given, `source_id` and `target_id` are ignored.

    Raises
    ------
    ValueError
        If neither `edge_id` nor both endpoints are given, or the edge is
        not in the view.
    RuntimeError
        If `sync=False` — view-only removal requires a maintained local view.
    """
    if not self.sync:
        raise RuntimeError("remove_edge_from_view requires sync=True; the local view is not maintained otherwise.")

    root_edge_id = edge_id
    if root_edge_id is None:
        endpoints_given = source_id is not None and target_id is not None
        if not endpoints_given:
            raise ValueError("Provide either edge_id or both source_id and target_id.")
        # Resolve the endpoints against the root graph; a missing edge is
        # reported as ValueError, with the original error chained.
        try:
            root_edge_id = self._root.edge_id(source_id, target_id)
        except rx.NoEdgeBetweenNodes as err:
            raise ValueError(f"Edge {source_id}->{target_id} does not exist in the graph.") from err

    # The edge may exist in the root and still be absent from this view.
    if root_edge_id not in self._edge_map_from_root:
        raise ValueError(f"Edge {root_edge_id} does not exist in the view.")

    self._remove_edge_local(root_edge_id)

def _get_neighbors(
self,
neighbors_func: Callable[[rx.PyDiGraph, int], rx.NodeIndices],
Expand Down Expand Up @@ -740,12 +838,18 @@ def update_node_attrs(
self._out_of_sync = True

if view_signal_on or root_signal_on:
new_attrs_by_id = (
self._root.filter(node_ids=node_ids)
.node_attrs(attr_keys=signal_keys)
.rows_by_key(key=DEFAULT_ATTR_KEYS.NODE_ID, named=True, unique=True, include_key=True)
)
old_attrs_by_id = cast(dict[int, dict[str, Any]], old_attrs_by_id) # for mypy
# Derive new_attrs by overlaying applied `attrs` onto old_attrs, instead of
# re-querying root. Mirrors the broadcasting semantics of
# `_root.update_node_attrs`: scalars apply to all nodes, sequences index by
# position in `node_ids`.
new_attrs_by_id: dict[int, dict[str, Any]] = {}
for i, node_id in enumerate(node_ids):
new_attrs = dict(old_attrs_by_id[node_id])
for k, v in attrs.items():
if k in new_attrs:
new_attrs[k] = v if np.isscalar(v) else v[i]
new_attrs_by_id[node_id] = new_attrs
if root_signal_on:
for node_id in node_ids:
self._root.node_updated.emit(
Expand Down
Loading
Loading