Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/layout_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Layout Functions

rustworkx.random_layout
rustworkx.spring_layout
rustworkx.kamada_kawai_layout
rustworkx.bipartite_layout
rustworkx.circular_layout
rustworkx.shell_layout
Expand Down
27 changes: 27 additions & 0 deletions releasenotes/notes/add-kamada-kawai-layout-b5a09cd65656e99f.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
---
features:
- |
Added a new layout function,
:func:`~rustworkx.kamada_kawai_layout`, which positions nodes using
the Kamada-Kawai path-length cost function. The function works with
both :class:`~rustworkx.PyGraph` and :class:`~rustworkx.PyDiGraph`
inputs. The implementation follows the original 1989 algorithm of
Kamada and Kawai: an outer loop selects the node with the largest
partial-gradient norm and an inner loop applies a 2D Newton step
against the local Hessian until convergence.

Disconnected graphs are handled by laying out each connected
component independently and packing the components in a horizontal
row. This avoids the visual collapse seen with single-objective
Kamada-Kawai on disconnected inputs.

Example usage:

.. jupyter-execute::

import rustworkx
from rustworkx.visualization import mpl_draw

graph = rustworkx.generators.hexagonal_lattice_graph(2, 2)
layout = rustworkx.kamada_kawai_layout(graph)
mpl_draw(graph, pos=layout)
107 changes: 107 additions & 0 deletions rustworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,113 @@ def networkx_converter(graph, keep_attributes: bool = False):
return new_graph


@_rustworkx_dispatch
def kamada_kawai_layout(
graph,
pos=None,
fixed=None,
weight_fn=None,
default_weight=1.0,
epsilon=1e-4,
max_outer=500,
max_inner=10,
scale=1.0,
center=None,
):
"""
Position nodes using the Kamada-Kawai path-length cost-function.

The layout minimises the energy

.. math::

E = \\frac{1}{2} \\sum_{i<j} k_{ij} (|p_i - p_j| - l_{ij})^2

where :math:`d_{ij}` is the graph-theoretic shortest path between
nodes :math:`i` and :math:`j`, :math:`l_{ij} \\propto d_{ij}` is the
desired display distance, and :math:`k_{ij} = 1 / d_{ij}^2` is the
spring constant. Minimisation follows the original Kamada and Kawai
(1989) scheme: at each outer step the node with the largest
partial-gradient norm is selected and updated by a 2D Newton step
against the local 2x2 Hessian.

Disconnected graphs are laid out by running Kamada-Kawai independently
on each connected component and packing the components in a row, so
components remain visibly separated rather than fighting for space
inside a single global energy minimisation.

:param graph: Graph to be used. Can either be a
:class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph`.
:param dict pos:
Initial node positions as a dictionary with node ids as keys
and values as a coordinate list. If ``None``, a per-component
circular layout is used as the starting point. (``default=None``)
:param set fixed: Nodes to keep fixed at initial position.
Error raised if ``fixed`` is specified and ``pos`` is not.
(``default=None``)
:param weight_fn: An optional weight function for an edge. It
will accept a single argument, the edge's weight object, and
return a float used as the edge weight in the all-pairs
shortest path computation.
:param float default_weight: Edge weight when ``weight_fn`` is not
provided. (``default=1.0``)
:param float epsilon: Convergence threshold for the maximum
partial-gradient norm. (``default=1e-4``)
:param int max_outer: Maximum number of outer iterations.
(``default=500``)
:param int max_inner: Maximum number of inner Newton steps per
outer iteration. (``default=10``)
:param float scale: Scale factor for positions. Not used unless
``fixed`` is ``None``. (``default=1.0``)
:param list center: Coordinate pair around which to center the
layout. Not used unless ``fixed`` is ``None``.
(``default=None``)

:returns: A dictionary of positions keyed by node id.
:rtype: dict
"""
raise TypeError(f"Invalid Input Type {type(graph)} for graph")
"""Convert a networkx graph object into a rustworkx graph object.

.. note::

networkx is **not** a dependency of rustworkx and this function
is provided as a convenience method for users of both networkx and
rustworkx. This function will not work unless you install networkx
independently.

:param networkx.Graph graph: The networkx graph to convert.
:param bool keep_attributes: If ``True``, add networkx node attributes to
the data payload in the nodes of the output rustworkx graph. When set to
``True``, the node data payloads in the output rustworkx graph object
will be dictionaries with the node attributes from the input networkx
graph where the ``"__networkx_node__"`` key contains the node from the
input networkx graph.

:returns: A rustworkx graph, either a :class:`~rustworkx.PyDiGraph` or a
:class:`~rustworkx.PyGraph` based on whether the input graph is directed
or not.
:rtype: :class:`~rustworkx.PyDiGraph` or :class:`~rustworkx.PyGraph`
"""
if graph.is_directed():
new_graph = PyDiGraph(multigraph=graph.is_multigraph())
else:
new_graph = PyGraph(multigraph=graph.is_multigraph())
nodes = list(graph.nodes)
node_indices = dict(zip(nodes, new_graph.add_nodes_from(nodes)))
new_graph.add_edges_from(
[(node_indices[x[0]], node_indices[x[1]], x[2]) for x in graph.edges(data=True)]
)

if keep_attributes:
for node, node_index in node_indices.items():
attributes = graph.nodes[node]
attributes["__networkx_node__"] = node
new_graph[node_index] = attributes

return new_graph


@_rustworkx_dispatch
def bipartite_layout(
graph,
Expand Down
14 changes: 14 additions & 0 deletions rustworkx/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ from .rustworkx import digraph_bipartite_layout as digraph_bipartite_layout
from .rustworkx import graph_bipartite_layout as graph_bipartite_layout
from .rustworkx import digraph_circular_layout as digraph_circular_layout
from .rustworkx import graph_circular_layout as graph_circular_layout
from .rustworkx import digraph_kamada_kawai_layout as digraph_kamada_kawai_layout
from .rustworkx import graph_kamada_kawai_layout as graph_kamada_kawai_layout
from .rustworkx import digraph_random_layout as digraph_random_layout
from .rustworkx import graph_random_layout as graph_random_layout
from .rustworkx import digraph_shell_layout as digraph_shell_layout
Expand Down Expand Up @@ -504,6 +506,18 @@ def spring_layout(
center: tuple[float, float] | None = ...,
seed: int | None = ...,
) -> Pos2DMapping: ...
def kamada_kawai_layout(
graph: PyGraph[_S, _T] | PyDiGraph[_S, _T],
pos: dict[int, tuple[float, float]] | None = ...,
fixed: set[int] | None = ...,
weight_fn: Callable[[_T], float] | None = ...,
default_weight: float = ...,
epsilon: float = ...,
max_outer: int = ...,
max_inner: int = ...,
scale: float = ...,
center: tuple[float, float] | None = ...,
) -> Pos2DMapping: ...
def networkx_converter(graph: Any, keep_attributes: bool = ...) -> PyGraph | PyDiGraph: ...
def bipartite_layout(
graph: PyGraph[_S, _T] | PyDiGraph[_S, _T],
Expand Down
26 changes: 26 additions & 0 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,32 @@ def graph_spring_layout(
seed: int | None = ...,
/,
) -> Pos2DMapping: ...
def digraph_kamada_kawai_layout(
graph: PyDiGraph[_S, _T],
pos: dict[int, tuple[float, float]] | None = ...,
fixed: set[int] | None = ...,
weight_fn: Callable[[_T], float] | None = ...,
default_weight: float = ...,
epsilon: float = ...,
max_outer: int = ...,
max_inner: int = ...,
scale: float = ...,
center: tuple[float, float] | None = ...,
/,
) -> Pos2DMapping: ...
def graph_kamada_kawai_layout(
graph: PyGraph[_S, _T],
pos: dict[int, tuple[float, float]] | None = ...,
fixed: set[int] | None = ...,
weight_fn: Callable[[_T], float] | None = ...,
default_weight: float = ...,
epsilon: float = ...,
max_outer: int = ...,
max_inner: int = ...,
scale: float = ...,
center: tuple[float, float] | None = ...,
/,
) -> Pos2DMapping: ...

# Line graph

Expand Down
Loading