Skip to content
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
ed818c2
potential fix
mkolodner-sc May 19, 2026
abb8e56
Update
mkolodner-sc May 19, 2026
a0e84fa
Update
mkolodner-sc May 19, 2026
088fe1b
Improvements
mkolodner-sc May 19, 2026
5ca621c
Change int16 to int32
mkolodner-sc May 19, 2026
ac2ef26
Fix degree tensor tests and type checks
mkolodner-sc May 28, 2026
7ad9faa
Merge branch 'mkolodner-sc/ppr_gs_memory' of github.com:Snapchat/GiGL…
mkolodner-sc May 28, 2026
d850b37
Add E2E PPR graphstore test
mkolodner-sc May 28, 2026
845704b
Update
mkolodner-sc May 28, 2026
ebbc318
Fixes
mkolodner-sc May 28, 2026
65eac99
Fix PPR graph-store sampling worker capacity
mkolodner-sc May 28, 2026
97bd538
Fix
mkolodner-sc May 29, 2026
92c9f51
more fixes
mkolodner-sc May 29, 2026
7e31417
change back
mkolodner-sc May 29, 2026
d9d2086
Avoid cast for heterogeneous inference node ids
mkolodner-sc May 29, 2026
fd1e9ae
Trim branch to PPR sampler fixes
mkolodner-sc May 29, 2026
71e1fa1
Merge remote-tracking branch 'origin/main' into mkolodner-sc/ppr_gs_m…
mkolodner-sc May 29, 2026
2ef9548
Keep PPR test ty ignores
mkolodner-sc May 29, 2026
b08f0e5
Remove stale PPR test ty ignore
mkolodner-sc May 29, 2026
a6eedd1
Use union shape for PPR degree tensors
mkolodner-sc May 29, 2026
68ab0f2
Restore useful degree computation comments
mkolodner-sc May 29, 2026
e71ccdb
Remove sampler diagnostic wrapper
mkolodner-sc May 29, 2026
f76e548
Simplify degree all-reduce helper
mkolodner-sc May 29, 2026
23ee86f
Document degree tensor assumptions
mkolodner-sc May 29, 2026
3b3497d
Address PPR degree review comments
mkolodner-sc May 29, 2026
5ac1c63
Address PPR degree memory review comments
mkolodner-sc May 29, 2026
aa42d7a
Comments
mkolodner-sc May 29, 2026
2641834
Document PPR degree tensor dtype rationale
mkolodner-sc May 29, 2026
1ff8635
Address remaining comments
mkolodner-sc May 29, 2026
5548260
Fix
mkolodner-sc Jun 1, 2026
a9df285
Improve solution
mkolodner-sc Jun 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gigl/distributed/base_dist_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def create_mp_producer(
if isinstance(degree_tensors, dict):
logger.info(
f"Pre-computed degree tensors for PPR sampling across "
f"{len(degree_tensors)} edge types."
f"{len(degree_tensors)} node types."
)
else:
logger.info(
Expand Down
32 changes: 18 additions & 14 deletions gigl/distributed/dist_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
Union[FeatureInfo, dict[EdgeType, FeatureInfo]]
] = None,
degree_tensor: Optional[
Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
Union[torch.Tensor, dict[NodeType, torch.Tensor]]
] = None,
max_labels_per_anchor_node: Optional[int] = None,
edge_weights: Optional[
Expand Down Expand Up @@ -111,7 +111,7 @@ def __init__(
Note this will be None in the homogeneous case if the data has no node features, or will only contain node types with node features in the heterogeneous case.
edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Dimension of edge features and its data type, will be a dict if heterogeneous.
Note this will be None in the homogeneous case if the data has no edge features, or will only contain edge types with edge features in the heterogeneous case.
degree_tensor: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Pre-computed degree tensor. Lazily computed on first access via the degree_tensor property.
degree_tensor: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]]: Optional pre-computed degree tensor, will be a dict keyed by node type if heterogeneous. If not provided, it is lazily computed on first access via the degree_tensor property.
max_labels_per_anchor_node (Optional[int]): Optional cap for how many
labels to materialize per anchor node for ABLP label fetching.
edge_weights: Per-edge sampling weights for this rank's partition.
Expand Down Expand Up @@ -152,7 +152,7 @@ def __init__(
self._edge_feature_info = edge_feature_info

self._degree_tensor: Optional[
Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
Union[torch.Tensor, dict[NodeType, torch.Tensor]]
] = degree_tensor
self._max_labels_per_anchor_node = max_labels_per_anchor_node
self._edge_weights: Optional[
Expand Down Expand Up @@ -315,23 +315,25 @@ def edge_feature_info(
@property
def degree_tensor(
self,
) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]:
) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]:
"""
Lazily compute and return the degree tensor for the graph.
Lazily compute and return degree tensors for the graph.

On first access, computes node degrees from the graph partition and uses
all-reduce to aggregate across all machines. Requires torch.distributed
to be initialized.
all-reduce to aggregate across all machines. For heterogeneous graphs,
degrees are summed across all incident edge types per anchor node type
before the all-reduce, so the per-edge-type tensor is never stored.
Requires torch.distributed to be initialized.

Over-counting correction (for processes sharing the same data on the same
machine) is handled automatically by detecting the distributed topology.

The result is cached for subsequent accesses.

Returns:
Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: The aggregated degree tensor.
- For homogeneous graphs: A tensor of shape [num_nodes].
- For heterogeneous graphs: A dict mapping EdgeType to degree tensors.
Union[torch.Tensor, dict[NodeType, torch.Tensor]]: Degree tensor for
homogeneous graphs, or total degree tensors keyed by node type
for heterogeneous graphs.

Raises:
RuntimeError: If torch.distributed is not initialized.
Expand All @@ -341,7 +343,9 @@ def degree_tensor(
if self.graph is None:
raise ValueError("Dataset graph is None. Cannot compute degrees.")

self._degree_tensor = compute_and_broadcast_degree_tensor(self.graph)
self._degree_tensor = compute_and_broadcast_degree_tensor(
self.graph, self._edge_dir
)
return self._degree_tensor

@property
Expand Down Expand Up @@ -943,7 +947,7 @@ def share_ipc(
Optional[Union[int, dict[NodeType, int]]],
Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]],
Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]],
Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]],
Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]],
Optional[int],
Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]],
]:
Expand All @@ -967,7 +971,7 @@ def share_ipc(
Optional[Union[int, dict[NodeType, int]]]: Number of test nodes on the current machine. Will be a dict if heterogeneous.
Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]: Node feature dim and its data type, will be a dict if heterogeneous
Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Edge feature dim and its data type, will be a dict if heterogeneous
Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Degree tensors, will be a dict if heterogeneous
Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]]: Degree tensors, will be a dict keyed by node type if heterogeneous
Optional[int]: Optional per-anchor label cap for ABLP label fetching
Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Per-edge sampling weights for this rank's partition
"""
Expand Down Expand Up @@ -1256,7 +1260,7 @@ def _rebuild_distributed_dataset(
Optional[
Union[FeatureInfo, dict[EdgeType, FeatureInfo]]
], # Edge feature dim and its data type
Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], # Degree tensors
Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]], # Degree tensors
Optional[int], # Optional per-anchor label cap for ABLP label fetching
Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], # edge_weights
],
Expand Down
150 changes: 66 additions & 84 deletions gigl/distributed/dist_ppr_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from graphlearn_torch.utils import merge_dict

from gigl.distributed.base_sampler import BaseDistNeighborSampler
from gigl.types.graph import is_label_edge_type
from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE, is_label_edge_type

# Trailing "." is an intentional separator. These constants are used both to
# write metadata keys (f"{KEY}{repr(edge_type)}" → e.g. "ppr_edge_index.('user', 'to', 'story')")
Expand All @@ -26,14 +26,14 @@
PPR_EDGE_INDEX_METADATA_KEY = "ppr_edge_index."
PPR_WEIGHT_METADATA_KEY = "ppr_weight."

# Sentinel type names for homogeneous graphs. The PPR algorithm uses
# dict[NodeType, ...] internally for both homo and hetero graphs; these
# sentinels let the homogeneous path reuse the same dict-based code.
_PPR_HOMOGENEOUS_NODE_TYPE = "ppr_homogeneous_node_type"
# Sentinel edge type for homogeneous graphs. The PPR algorithm uses
# dict[NodeType, ...] internally for both homo and hetero graphs; the
# DEFAULT_HOMOGENEOUS_NODE_TYPE sentinel lets the homogeneous path reuse
# the same dict-based code.
_PPR_HOMOGENEOUS_EDGE_TYPE = (
_PPR_HOMOGENEOUS_NODE_TYPE,
DEFAULT_HOMOGENEOUS_NODE_TYPE,
"to",
_PPR_HOMOGENEOUS_NODE_TYPE,
DEFAULT_HOMOGENEOUS_NODE_TYPE,
)


Expand Down Expand Up @@ -74,10 +74,11 @@ class DistPPRNeighborSampler(BaseDistNeighborSampler):
but require more computation. Typical values: 1e-4 to 1e-6.
max_ppr_nodes: Maximum number of nodes to return per seed based on PPR scores.
num_neighbors_per_hop: Maximum number of neighbors to fetch per hop.
total_degree_dtype: Dtype for precomputed total-degree tensors. Defaults
to ``torch.int32``. Use a larger dtype if nodes have exceptionally high
aggregate degrees.
degree_tensors: Pre-computed degree tensors from the dataset.
degree_tensors: Pre-computed total-degree tensors (int32). Homogeneous
graphs use a single tensor; heterogeneous graphs use tensors keyed
by NodeType. Must be pre-computed by the caller through
``DistDataset.degree_tensor`` so workers share a single allocation
rather than recomputing per-worker.
"""

def __init__(
Expand All @@ -87,8 +88,7 @@ def __init__(
eps: float = 1e-4,
max_ppr_nodes: int = 50,
num_neighbors_per_hop: int = 100_000,
total_degree_dtype: torch.dtype = torch.int32,
degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]],
degree_tensors: Union[torch.Tensor, dict[NodeType, torch.Tensor]],
max_fetch_iterations: Optional[int] = None,
**kwargs,
):
Expand Down Expand Up @@ -125,23 +125,14 @@ def __init__(

self._node_type_to_edge_types[anchor_type].append(etype)
else:
self._node_type_to_edge_types[_PPR_HOMOGENEOUS_NODE_TYPE] = [
self._node_type_to_edge_types[DEFAULT_HOMOGENEOUS_NODE_TYPE] = [
_PPR_HOMOGENEOUS_EDGE_TYPE
]
self._is_homogeneous = True

# Precompute total degree per node type: the sum of degrees across all
# edge types traversable from that node type. This is a graph-level
# property used on every PPR iteration, so computing it once at init
# avoids per-node summation and cache lookups in the hot loop.
# TODO (mkolodner-sc): This trades memory for throughput — we
# materialize a tensor per node type to avoid recomputing total degree
# on every neighbor during sampling. Computing it here (rather than in
# the dataset) also keeps the door open for edge-specific degree
# strategies. If memory becomes a bottleneck, revisit this.
self._node_type_to_total_degree: dict[NodeType, torch.Tensor] = (
self._build_total_degree_tensors(degree_tensors, total_degree_dtype)
)
# Normalize the public homogeneous/heterogeneous degree-tensor shape to
# the node-type keyed form used internally by PPR.
self._node_type_to_total_degree = self._normalize_degree_tensors(degree_tensors)

# Build integer ID mappings for the C++ forward-push kernel. String
# NodeType / EdgeType keys are only used at the Python boundary
Expand Down Expand Up @@ -191,57 +182,26 @@ def __init__(
for nt in all_node_types
]

def _build_total_degree_tensors(
def _normalize_degree_tensors(
Comment thread
mkolodner-sc marked this conversation as resolved.
Outdated
self,
degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]],
dtype: torch.dtype,
degree_tensors: Union[torch.Tensor, dict[NodeType, torch.Tensor]],
) -> dict[NodeType, torch.Tensor]:
"""Build total-degree tensors by summing per-edge-type degrees for each node type.

For homogeneous graphs, the total degree is just the single degree tensor.
For heterogeneous graphs, it sums degree tensors across all edge types
traversable from each node type, padding shorter tensors with zeros.

Args:
degree_tensors: Per-edge-type degree tensors from the dataset.
dtype: Dtype for the output tensors.

Returns:
Dict mapping node type to a 1-D tensor of total degrees.
"""
result: dict[NodeType, torch.Tensor] = {}
"""Normalize degree tensors to the node-type keyed shape PPR uses."""
if isinstance(degree_tensors, torch.Tensor):
if not self._is_homogeneous:
raise ValueError(
"Expected degree tensors keyed by node type for heterogeneous PPR sampling."
)
return {DEFAULT_HOMOGENEOUS_NODE_TYPE: degree_tensors}

if self._is_homogeneous:
assert isinstance(degree_tensors, torch.Tensor)
# Single edge type: degree values fit directly in the target dtype.
result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(dtype)
else:
assert isinstance(degree_tensors, dict)
dtype_max = torch.iinfo(dtype).max
for node_type, edge_types in self._node_type_to_edge_types.items():
max_len = 0
for et in edge_types:
if et not in degree_tensors:
raise ValueError(
f"Edge type {et} not found in degree tensors. "
f"Available: {list(degree_tensors.keys())}"
)
max_len = max(max_len, len(degree_tensors[et]))

# Each degree tensor is indexed by node ID (derived from CSR
# indptr), so index i in every edge type's tensor refers to
# the same node. Element-wise summation gives the total degree
# per node across all edge types. Shorter tensors are padded
# implicitly (only the first len(et_degrees) entries are added).
# Sum in int64: aggregate degrees are bounded by partition size
# and fit comfortably within int64 range in practice.
summed = torch.zeros(max_len, dtype=torch.int64)
for et in edge_types:
et_degrees = degree_tensors[et]
summed[: len(et_degrees)] += et_degrees.to(torch.int64)
result[node_type] = summed.clamp(max=dtype_max).to(dtype)

return result
missing_anchor_types = set(self._node_type_to_edge_types.keys()) - set(
degree_tensors.keys()
)
if missing_anchor_types:
raise ValueError(
f"Missing PPR degree tensors for node types: {missing_anchor_types}"
)
return degree_tensors

def _get_destination_type(self, edge_type: EdgeType) -> NodeType:
"""Get the node type at the destination end of an edge type."""
Expand Down Expand Up @@ -294,8 +254,15 @@ async def _batch_fetch_neighbors(
self._sample_one_hop(
srcs=nodes_by_etype_id[eid].to(device),
num_nbr=self._num_neighbors_per_hop,
# _sample_one_hop expects None for homogeneous graphs, not the PPR sentinel.
etype=None if etype == _PPR_HOMOGENEOUS_EDGE_TYPE else etype,
# _sample_one_hop expects None only for true homogeneous graphs.
# Labeled homogeneous ABLP graphs are hetero-backed because label
# edges are represented as separate edge types, so they still need
# the explicit default edge type here.
etype=(
None
if self._is_homogeneous and etype == _PPR_HOMOGENEOUS_EDGE_TYPE
else etype
),
)
)
outputs: list[NeighborOutput] = await asyncio.gather(*sample_tasks)
Expand Down Expand Up @@ -362,7 +329,7 @@ async def _compute_ppr_scores(
valid_counts = tensor([1, 3, 2, 0])
"""
if seed_node_type is None:
seed_node_type = _PPR_HOMOGENEOUS_NODE_TYPE
seed_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
device = seed_nodes.device

ppr_state = PPRForwardPush(
Expand Down Expand Up @@ -422,12 +389,12 @@ async def _compute_ppr_scores(
if self._is_homogeneous:
assert (
len(ntype_to_flat_ids) == 1
and _PPR_HOMOGENEOUS_NODE_TYPE in ntype_to_flat_ids
and DEFAULT_HOMOGENEOUS_NODE_TYPE in ntype_to_flat_ids
)
return (
ntype_to_flat_ids[_PPR_HOMOGENEOUS_NODE_TYPE],
ntype_to_flat_weights[_PPR_HOMOGENEOUS_NODE_TYPE],
ntype_to_valid_counts[_PPR_HOMOGENEOUS_NODE_TYPE],
ntype_to_flat_ids[DEFAULT_HOMOGENEOUS_NODE_TYPE],
ntype_to_flat_weights[DEFAULT_HOMOGENEOUS_NODE_TYPE],
ntype_to_valid_counts[DEFAULT_HOMOGENEOUS_NODE_TYPE],
)
else:
return (
Expand Down Expand Up @@ -636,17 +603,32 @@ async def _sample_from_nodes(
)

else:
assert isinstance(nodes_to_sample, torch.Tensor)
if isinstance(nodes_to_sample, torch.Tensor):
homogeneous_nodes_to_sample = nodes_to_sample
elif isinstance(nodes_to_sample, dict):
node_types = set(nodes_to_sample.keys())
if node_types != {DEFAULT_HOMOGENEOUS_NODE_TYPE}:
raise ValueError(
f"Expected only {DEFAULT_HOMOGENEOUS_NODE_TYPE} for homogeneous PPR sampling, "
f"received node types: {node_types}"
)
homogeneous_nodes_to_sample = nodes_to_sample[
DEFAULT_HOMOGENEOUS_NODE_TYPE
]
else:
raise TypeError(
f"Expected Tensor or node-type mapping for homogeneous PPR sampling, got {type(nodes_to_sample)}"
)

# Register seeds; local indices 0..N-1 are assigned internally.
# srcs holds their global IDs (same values as homogeneous_nodes_to_sample).
srcs = inducer.init_node(nodes_to_sample)
srcs = inducer.init_node(homogeneous_nodes_to_sample)

(
homo_flat_ids,
homo_flat_weights,
homo_valid_counts,
) = await self._compute_ppr_scores(nodes_to_sample, None)
) = await self._compute_ppr_scores(homogeneous_nodes_to_sample, None)
assert isinstance(homo_flat_ids, torch.Tensor)
assert isinstance(homo_flat_weights, torch.Tensor)
assert isinstance(homo_valid_counts, torch.Tensor)
Expand Down
6 changes: 3 additions & 3 deletions gigl/distributed/dist_sampling_producer.py
Comment thread
mkolodner-sc marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
SamplingConfig,
SamplingType,
)
from graphlearn_torch.typing import EdgeType
from graphlearn_torch.typing import NodeType
from graphlearn_torch.utils import seed_everything
from torch._C import _set_worker_signal_handlers
from torch.utils.data.dataloader import DataLoader
Expand All @@ -55,7 +55,7 @@ def _sampling_worker_loop(
sampling_completed_worker_count, # mp.Value
mp_barrier: Barrier,
sampler_options: SamplerOptions,
degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]],
degree_tensors: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]],
):
dist_sampler = None
try:
Expand Down Expand Up @@ -181,7 +181,7 @@ def __init__(
channel: ChannelBase,
sampler_options: SamplerOptions,
degree_tensors: Optional[
Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
Union[torch.Tensor, dict[NodeType, torch.Tensor]]
] = None,
):
super().__init__(data, sampler_input, sampling_config, worker_options, channel)
Expand Down
Loading