Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/tracksdata/array/_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _validate_shape(
"""Helper function to validate the shape argument."""
if shape is None:
try:
shape = graph.metadata()["shape"]
shape = graph.metadata["shape"]
except KeyError as e:
raise KeyError(
f"`shape` is required to `{func_name}`. "
Expand Down
2 changes: 1 addition & 1 deletion src/tracksdata/functional/_test/test_napari.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_napari_conversion(metadata_shape: bool) -> None:

shape = (2, 10, 22, 32)
if metadata_shape:
graph.update_metadata(shape=shape)
graph.metadata.update(shape=shape)
arg_shape = None
else:
arg_shape = shape
Expand Down
4 changes: 2 additions & 2 deletions src/tracksdata/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Graph backends for representing tracking data as directed graphs in memory or on disk."""

from tracksdata.graph._base_graph import BaseGraph
from tracksdata.graph._base_graph import BaseGraph, MetadataView
from tracksdata.graph._graph_view import GraphView
from tracksdata.graph._rustworkx_graph import IndexedRXGraph, RustWorkXGraph
from tracksdata.graph._sql_graph import SQLGraph

InMemoryGraph = RustWorkXGraph

__all__ = ["BaseGraph", "GraphView", "InMemoryGraph", "IndexedRXGraph", "RustWorkXGraph", "SQLGraph"]
__all__ = ["BaseGraph", "GraphView", "InMemoryGraph", "IndexedRXGraph", "MetadataView", "RustWorkXGraph", "SQLGraph"]
153 changes: 121 additions & 32 deletions src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,75 @@
T = TypeVar("T", bound="BaseGraph")


class MetadataView(dict[str, Any]):
"""Dictionary-like metadata view that syncs mutations back to the graph."""

_MISSING = object()

def __init__(
self,
graph: "BaseGraph",
data: dict[str, Any],
*,
is_public: bool = True,
) -> None:
super().__init__(data)
self._graph = graph
self._is_public = is_public

def __setitem__(self, key: str, value: Any) -> None:
self._graph._set_metadata_with_validation(is_public=self._is_public, **{key: value})
super().__setitem__(key, value)

def __delitem__(self, key: str) -> None:
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
super().__delitem__(key)

def pop(self, key: str, default: Any = _MISSING) -> Any:
self._graph._validate_metadata_key(key, is_public=self._is_public)

if key not in self:
if default is self._MISSING:
raise KeyError(key)
return default

value = super().__getitem__(key)
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
super().pop(key, None)
return value

def popitem(self) -> tuple[str, Any]:
key, value = super().popitem()
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
return key, value

def clear(self) -> None:
keys = list(self.keys())
for key in keys:
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
super().clear()

def setdefault(self, key: str, default: Any = None) -> Any:
if key in self:
return super().__getitem__(key)
self._graph._set_metadata_with_validation(is_public=self._is_public, **{key: default})
super().__setitem__(key, default)
return default

def update(self, *args, **kwargs) -> None:
updates = dict(*args, **kwargs)
if updates:
self._graph._set_metadata_with_validation(is_public=self._is_public, **updates)
super().update(updates)


class BaseGraph(abc.ABC):
"""
Base class for a graph backend.
"""

_PRIVATE_METADATA_PREFIX = "__private_"

node_added = Signal(int)
node_removed = Signal(int)

Expand Down Expand Up @@ -1186,7 +1250,8 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T:
node_attrs = node_attrs.drop(DEFAULT_ATTR_KEYS.NODE_ID)

graph = cls(**kwargs)
graph.update_metadata(**other.metadata())
graph.metadata.update(other.metadata)
graph._private_metadata.update(other._private_metadata)

current_node_attr_schemas = graph._node_attr_schemas()
for k, v in other._node_attr_schemas().items():
Expand Down Expand Up @@ -1786,7 +1851,7 @@ def to_geff(
for k, v in edge_attrs.to_dict().items()
}

td_metadata = self.metadata().copy()
td_metadata = self.metadata.copy()
td_metadata.pop("geff", None) # avoid geff being written multiple times
Comment thread
JoOkuma marked this conversation as resolved.

geff_metadata = geff.GeffMetadata(
Expand Down Expand Up @@ -1824,57 +1889,81 @@ def to_geff(
zarr_format=zarr_format,
)

@abc.abstractmethod
def metadata(self) -> dict[str, Any]:
@property
def metadata(self) -> MetadataView:
"""
Return the metadata of the graph.

Returns
-------
dict[str, Any]
MetadataView
The metadata of the graph as a dictionary.

Examples
--------
```python
metadata = graph.metadata()
metadata = graph.metadata
print(metadata["shape"])
```
"""
return MetadataView(
graph=self,
data={k: v for k, v in self._metadata().items() if not self._is_private_metadata_key(k)},
is_public=True,
)

@abc.abstractmethod
def update_metadata(self, **kwargs) -> None:
"""
Set or update metadata for the graph.
@property
def _private_metadata(self) -> MetadataView:
return MetadataView(
graph=self,
data={k: v for k, v in self._metadata().items() if self._is_private_metadata_key(k)},
is_public=False,
)

Parameters
----------
**kwargs : Any
The metadata items to set by key. Values will be stored as JSON.
@classmethod
def _is_private_metadata_key(cls, key: str) -> bool:
return key.startswith(cls._PRIVATE_METADATA_PREFIX)

def _validate_metadata_key(self, key: str, *, is_public: bool) -> None:
if not isinstance(key, str):
raise TypeError(f"Metadata key must be a string. Got {type(key)}.")
is_private_key = self._is_private_metadata_key(key)
if is_public and is_private_key:
raise ValueError(f"Metadata key '{key}' is reserved for internal use.")
if not is_public and not is_private_key:
raise ValueError(
f"Metadata key '{key}' is not private. Private metadata keys must start with "
f"'{self._PRIVATE_METADATA_PREFIX}'."
)

Examples
--------
```python
graph.update_metadata(shape=[1, 25, 25], path="path/to/image.ome.zarr")
graph.update_metadata(description="Tracking data from experiment 1")
```
"""
def _validate_metadata_keys(self, keys: Sequence[str], *, is_public: bool) -> None:
for key in keys:
self._validate_metadata_key(key, is_public=is_public)

def _set_metadata_with_validation(self, is_public: bool = True, **kwargs) -> None:
self._validate_metadata_keys(kwargs.keys(), is_public=is_public)
self._update_metadata(**kwargs)

def _remove_metadata_with_validation(self, key: str, *, is_public: bool = True) -> None:
self._validate_metadata_key(key, is_public=is_public)
self._remove_metadata(key)

@abc.abstractmethod
def remove_metadata(self, key: str) -> None:
def _metadata(self) -> dict[str, Any]:
"""
Return the full metadata including private keys.
"""
Remove a metadata key from the graph.

Parameters
----------
key : str
The key of the metadata to remove.
@abc.abstractmethod
def _update_metadata(self, **kwargs) -> None:
"""
Backend-specific metadata update implementation without public key validation.
"""

Examples
--------
```python
graph.remove_metadata("shape")
```
@abc.abstractmethod
def _remove_metadata(self, key: str) -> None:
"""
Backend-specific metadata removal implementation without public key validation.
"""

def to_traccuracy_graph(self, array_view_kwargs: dict[str, Any] | None = None) -> "TrackingGraph":
Expand Down
12 changes: 6 additions & 6 deletions src/tracksdata/graph/_graph_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,11 +847,11 @@ def copy(self, **kwargs) -> "GraphView":
"Use `detach` to create a new reference-less graph with the same nodes and edges."
)

def metadata(self) -> dict[str, Any]:
return self._root.metadata()
def _metadata(self) -> dict[str, Any]:
return self._root._metadata()

def update_metadata(self, **kwargs) -> None:
self._root.update_metadata(**kwargs)
def _update_metadata(self, **kwargs) -> None:
self._root._update_metadata(**kwargs)

def remove_metadata(self, key: str) -> None:
self._root.remove_metadata(key)
def _remove_metadata(self, key: str) -> None:
self._root._remove_metadata(key)
8 changes: 4 additions & 4 deletions src/tracksdata/graph/_rustworkx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None:

elif not isinstance(self._graph.attrs, dict):
LOG.warning(
"previous attribute %s will be added to key 'old_attrs' of `graph.metadata()`",
"previous attribute %s will be added to key 'old_attrs' of `graph.metadata`",
self._graph.attrs,
)
self._graph.attrs = {
Expand Down Expand Up @@ -1499,13 +1499,13 @@ def edge_id(self, source_id: int, target_id: int) -> int:
"""
return self.rx_graph.get_edge_data(source_id, target_id)[DEFAULT_ATTR_KEYS.EDGE_ID]

def metadata(self) -> dict[str, Any]:
def _metadata(self) -> dict[str, Any]:
return self._graph.attrs

def update_metadata(self, **kwargs) -> None:
def _update_metadata(self, **kwargs) -> None:
self._graph.attrs.update(kwargs)

def remove_metadata(self, key: str) -> None:
def _remove_metadata(self, key: str) -> None:
self._graph.attrs.pop(key, None)

def edge_list(self) -> list[list[int, int]]:
Expand Down
6 changes: 3 additions & 3 deletions src/tracksdata/graph/_sql_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1992,19 +1992,19 @@ def remove_edge(
raise ValueError(f"Edge {edge_id} does not exist in the graph.")
session.commit()

def metadata(self) -> dict[str, Any]:
def _metadata(self) -> dict[str, Any]:
with Session(self._engine) as session:
result = session.query(self.Metadata).all()
return {row.key: row.value for row in result}

def update_metadata(self, **kwargs) -> None:
def _update_metadata(self, **kwargs) -> None:
with Session(self._engine) as session:
for key, value in kwargs.items():
metadata_entry = self.Metadata(key=key, value=value)
session.merge(metadata_entry)
session.commit()

def remove_metadata(self, key: str) -> None:
def _remove_metadata(self, key: str) -> None:
with Session(self._engine) as session:
session.query(self.Metadata).filter(self.Metadata.key == key).delete()
session.commit()
Expand Down
Loading
Loading