Draft
43 commits
ed818c2
potential fix
mkolodner-sc May 19, 2026
abb8e56
Update
mkolodner-sc May 19, 2026
a0e84fa
Update
mkolodner-sc May 19, 2026
088fe1b
Improvements
mkolodner-sc May 19, 2026
5ca621c
Change int16 to int32
mkolodner-sc May 19, 2026
ac2ef26
Fix degree tensor tests and type checks
mkolodner-sc May 28, 2026
7ad9faa
Merge branch 'mkolodner-sc/ppr_gs_memory' of github.com:Snapchat/GiGL…
mkolodner-sc May 28, 2026
d850b37
Add E2E PPR graphstore test
mkolodner-sc May 28, 2026
845704b
Update
mkolodner-sc May 28, 2026
ebbc318
Fixes
mkolodner-sc May 28, 2026
65eac99
Fix PPR graph-store sampling worker capacity
mkolodner-sc May 28, 2026
97bd538
Fix
mkolodner-sc May 29, 2026
92c9f51
more fixes
mkolodner-sc May 29, 2026
7e31417
change back
mkolodner-sc May 29, 2026
d9d2086
Avoid cast for heterogeneous inference node ids
mkolodner-sc May 29, 2026
fd1e9ae
Trim branch to PPR sampler fixes
mkolodner-sc May 29, 2026
71e1fa1
Merge remote-tracking branch 'origin/main' into mkolodner-sc/ppr_gs_m…
mkolodner-sc May 29, 2026
a49a650
Add graph-store PPR E2E wiring
mkolodner-sc May 29, 2026
2ef9548
Keep PPR test ty ignores
mkolodner-sc May 29, 2026
8c1dd36
Merge branch 'mkolodner-sc/ppr_gs_memory' into mkolodner-sc/graph_sto…
mkolodner-sc May 29, 2026
b08f0e5
Remove stale PPR test ty ignore
mkolodner-sc May 29, 2026
851ed8b
Merge branch 'mkolodner-sc/ppr_gs_memory' into mkolodner-sc/graph_sto…
mkolodner-sc May 29, 2026
a6eedd1
Use union shape for PPR degree tensors
mkolodner-sc May 29, 2026
68ab0f2
Restore useful degree computation comments
mkolodner-sc May 29, 2026
ab6aecd
Merge branch 'mkolodner-sc/ppr_gs_memory' into mkolodner-sc/graph_sto…
mkolodner-sc May 29, 2026
e71ccdb
Remove sampler diagnostic wrapper
mkolodner-sc May 29, 2026
ee5806b
Merge branch 'mkolodner-sc/ppr_gs_memory' into mkolodner-sc/graph_sto…
mkolodner-sc May 29, 2026
f76e548
Simplify degree all-reduce helper
mkolodner-sc May 29, 2026
98bb3f9
Merge branch 'mkolodner-sc/ppr_gs_memory' into mkolodner-sc/graph_sto…
mkolodner-sc May 29, 2026
23ee86f
Document degree tensor assumptions
mkolodner-sc May 29, 2026
f0e3275
Merge branch 'mkolodner-sc/ppr_gs_memory' into mkolodner-sc/graph_sto…
mkolodner-sc May 29, 2026
3b3497d
Address PPR degree review comments
mkolodner-sc May 29, 2026
5ac1c63
Address PPR degree memory review comments
mkolodner-sc May 29, 2026
a24e32a
Merge branch 'mkolodner-sc/ppr_gs_memory' into mkolodner-sc/graph_sto…
mkolodner-sc May 29, 2026
2f35f22
Configure graph-store PPR sampler options inline
mkolodner-sc May 29, 2026
aa42d7a
Comments
mkolodner-sc May 29, 2026
188525f
Clarify graph-store PPR sampler args
mkolodner-sc May 29, 2026
2641834
Document PPR degree tensor dtype rationale
mkolodner-sc May 29, 2026
1ff8635
Address remaining comments
mkolodner-sc May 29, 2026
5548260
Fix
mkolodner-sc Jun 1, 2026
0e734f3
Merge branch 'mkolodner-sc/ppr_gs_memory' into mkolodner-sc/graph_sto…
mkolodner-sc Jun 1, 2026
a9df285
Improve solution
mkolodner-sc Jun 1, 2026
7757cec
Merge branch 'mkolodner-sc/ppr_gs_memory' into mkolodner-sc/graph_sto…
mkolodner-sc Jun 1, 2026
8 changes: 8 additions & 0 deletions Makefile
@@ -270,6 +270,14 @@ run_hom_cora_sup_gs_e2e_test:
--test_spec_uri="tests/e2e_tests/e2e_tests.yaml" \
--test_names="hom_cora_sup_gs_test"

run_hom_cora_sup_gs_ppr_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH}
run_hom_cora_sup_gs_ppr_e2e_test: compile_gigl_kubeflow_pipeline
run_hom_cora_sup_gs_ppr_e2e_test:
uv run python tests/e2e_tests/e2e_test.py \
--compiled_pipeline_path=$(compiled_pipeline_path) \
--test_spec_uri="tests/e2e_tests/e2e_tests.yaml" \
--test_names="hom_cora_sup_gs_ppr_test"

run_het_dblp_sup_gs_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH}
run_het_dblp_sup_gs_e2e_test: compile_gigl_kubeflow_pipeline
run_het_dblp_sup_gs_e2e_test:
@@ -0,0 +1,87 @@
# This config runs homogeneous CORA supervised training and inference in Graph Store mode
# with PPR sampling. It intentionally reuses the standard graph-store training/inference
# entrypoints, changing only the sampler args and keeping the loop short for E2E coverage.
graphMetadata:
edgeTypes:
- dstNodeType: paper
relation: cites
srcNodeType: paper
nodeTypes:
- paper
datasetConfig:
dataPreprocessorConfig:
dataPreprocessorConfigClsPath: gigl.src.mocking.mocking_assets.passthrough_preprocessor_config_for_mocked_assets.PassthroughPreprocessorConfigForMockedAssets
dataPreprocessorArgs:
mocked_dataset_name: 'cora_homogeneous_node_anchor_edge_features_user_defined_labels'
trainerConfig:
trainerArgs:
log_every_n_batch: "1"
num_neighbors: "[10, 10]"
# Parsed in the graph-store training entrypoint and passed directly as
# kwargs to PPRSamplerOptions in gigl/distributed/sampler_options.py.
# Presence of ppr_sampler_options selects PPR; otherwise this example uses
# k-hop sampling configured by num_neighbors.
ppr_sampler_options: >-
{
"alpha": 0.5,
"eps": 0.0001,
"max_ppr_nodes": 20,
"num_neighbors_per_hop": 100,
"max_fetch_iterations": 2
}
sampling_workers_per_process: "2"
main_batch_size: "8"
random_batch_size: "8"
num_max_train_batches: "4"
num_val_batches: "4"
val_every_n_batch: "1"
command: python -m examples.link_prediction.graph_store.homogeneous_training
graphStoreStorageConfig:
command: python -m examples.link_prediction.graph_store.storage_main
storageArgs:
sample_edge_direction: "in"
splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter"
splitter_kwargs: >-
{
"sampling_direction": "in",
"should_convert_labels_to_edges": True,
"num_val": 0.25,
"num_test": 0.25
}
num_server_sessions: "1"
inferencerConfig:
inferencerArgs:
log_every_n_batch: "1"
num_neighbors: "[10, 10]"
# Parsed in the graph-store inference entrypoint and passed directly as
# kwargs to PPRSamplerOptions in gigl/distributed/sampler_options.py.
# Presence of ppr_sampler_options selects PPR; otherwise this example uses
# k-hop sampling configured by num_neighbors.
ppr_sampler_options: >-
{
"alpha": 0.5,
"eps": 0.0001,
"max_ppr_nodes": 20,
"num_neighbors_per_hop": 100,
"max_fetch_iterations": 2
}
sampling_workers_per_inference_process: "2"
inferenceBatchSize: 256
command: python -m examples.link_prediction.graph_store.homogeneous_inference
graphStoreStorageConfig:
command: python -m examples.link_prediction.graph_store.storage_main
storageArgs:
sample_edge_direction: "in"
num_server_sessions: "1"
sharedConfig:
shouldSkipInference: false
shouldSkipModelEvaluation: true
taskMetadata:
nodeAnchorBasedLinkPredictionTaskMetadata:
supervisionEdgeTypes:
- dstNodeType: paper
relation: cites
srcNodeType: paper
featureFlags:
should_run_glt_backend: 'True'
data_preprocessor_num_shards: '2'
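The config above sets `alpha`, `eps`, and `max_ppr_nodes` without defining them. As a hedged illustration only, the sketch below assumes they carry their conventional meanings in approximate personalized PageRank: teleport probability, per-degree residual threshold, and a cap on how many top-scoring nodes are kept. It is a textbook forward-push routine written for this note, not GiGL's sampler; `forward_push_ppr` and the toy adjacency list are hypothetical, and `num_neighbors_per_hop` and `max_fetch_iterations` are not modeled.

```python
from collections import defaultdict


def forward_push_ppr(adj, source, alpha=0.5, eps=1e-4, max_ppr_nodes=20):
    """Textbook forward-push approximation of personalized PageRank.

    adj maps a node id to a list of neighbor ids. The parameter meanings
    are assumptions (see lead-in), not taken from PPRSamplerOptions.
    """
    estimate = defaultdict(float)  # approximate PPR score per node
    residual = defaultdict(float)  # probability mass not yet pushed
    residual[source] = 1.0
    stack = [source]
    while stack:
        node = stack.pop()
        neighbors = adj.get(node, [])
        mass = residual[node]
        if not neighbors:
            # Dangling node: nothing to push to, so keep the mass here.
            estimate[node] += mass
            residual[node] = 0.0
            continue
        if mass < eps * len(neighbors):
            # Residual is below the per-degree threshold; skip the push.
            continue
        estimate[node] += alpha * mass
        residual[node] = 0.0
        share = (1.0 - alpha) * mass / len(neighbors)
        for nbr in neighbors:
            residual[nbr] += share
            if residual[nbr] >= eps * max(len(adj.get(nbr, [])), 1):
                stack.append(nbr)
    # Keep only the highest-scoring nodes, breaking ties by node id.
    ranked = sorted(estimate.items(), key=lambda kv: (-kv[1], kv[0]))
    return ranked[:max_ppr_nodes]


# Toy graph: node 0 is linked with nodes 1 and 2 in both directions.
toy_adj = {0: [1, 2], 1: [0], 2: [0]}
scores = forward_push_ppr(toy_adj, source=0)
top_one = forward_push_ppr(toy_adj, source=0, max_ppr_nodes=1)
```

Under these assumptions, a smaller `eps` pushes more residual mass into the estimates at the cost of more push steps, and `max_ppr_nodes` bounds the size of the neighborhood that is kept per seed.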
60 changes: 35 additions & 25 deletions examples/link_prediction/graph_store/homogeneous_inference.py
@@ -82,12 +82,13 @@
"""

import argparse
import ast
import gc
import os
import sys
import time
from dataclasses import dataclass
from typing import Union
from typing import Optional, Union

import torch
import torch.multiprocessing as mp
@@ -101,6 +102,7 @@
from gigl.common.utils.gcs import GcsUtils
from gigl.distributed.graph_store.compute import init_compute_process
from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions
from gigl.distributed.utils import get_graph_store_info
from gigl.env.distributed import GraphStoreInfo
from gigl.nn import LinkPredictionGNN
@@ -114,12 +116,6 @@

logger = Logger()

# Default number of inference processes per machine incase one isnt provided in inference args
Review comment (collaborator, PR author): In Graph Store mode, the source of truth should be cluster_info.num_processes_per_compute, not a local CPU/GPU heuristic. The previous fallback could make inference spawn a different number of compute processes than storage expected, causing storage rendezvous failures like “only N/M clients joined.”

# i.e. `local_world_size` is not provided, and we can't infer automatically.
# If there are GPUs attached to the machine, we automatically infer to setting
# LOCAL_WORLD_SIZE == # of gpus on the machine.
DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE = 4


@dataclass(frozen=True)
class InferenceProcessArgs:
@@ -143,6 +139,7 @@ class InferenceProcessArgs:
inference_batch_size (int): Batch size to use for inference.
num_neighbors (Union[list[int], dict[EdgeType, list[int]]]): Fanout for subgraph sampling,
where the ith item corresponds to the number of items to sample for the ith hop.
sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling.
sampling_workers_per_inference_process (int): Number of sampling workers per inference
process.
sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) allocated for
@@ -169,6 +166,7 @@ class InferenceProcessArgs:
# Inference configuration
inference_batch_size: int
num_neighbors: Union[list[int], dict[EdgeType, list[int]]]
sampler_options: Optional[SamplerOptions]
sampling_workers_per_inference_process: int
sampling_worker_shared_channel_size: str
log_every_n_batch: int
@@ -242,6 +240,7 @@ def _inference_process(
# For large-scale settings, consider setting this field to 30-60 seconds to ensure dataloaders
# don't compete for memory during initialization, causing OOM
process_start_gap_seconds=0,
sampler_options=args.sampler_options,
)
# Initialize a LinkPredictionGNN model and load parameters from
# the saved model.
@@ -455,25 +454,23 @@ def _run_example_inference(
if arg_local_world_size is not None:
local_world_size = int(arg_local_world_size)
logger.info(f"Using local_world_size from inferencer_args: {local_world_size}")
if torch.cuda.is_available() and local_world_size != torch.cuda.device_count():
logger.warning(
f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. "
"This may lead to unexpected failures with NCCL communication incase GPUs are being used for "
+ "training/inference. Consider setting local_world_size to the number of GPUs."
)
else:
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
# If GPUs are available, we set the local_world_size to the number of GPUs
local_world_size = torch.cuda.device_count()
logger.info(
f"Detected {local_world_size} GPUs. Thus, setting local_world_size to {local_world_size}"
)
else:
# If no GPUs are available, we set the local_world_size to the number of inference processes per machine
logger.info(
f"No GPUs detected. Thus, setting local_world_size to `{DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE}`"
)
local_world_size = DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE
local_world_size = cluster_info.num_processes_per_compute
logger.info(
f"Using local_world_size from cluster_info.num_processes_per_compute: {local_world_size}"
)
if local_world_size != cluster_info.num_processes_per_compute:
raise ValueError(
f"Graph Store local_world_size={local_world_size} must match "
f"cluster_info.num_processes_per_compute="
f"{cluster_info.num_processes_per_compute}"
)
if torch.cuda.is_available() and local_world_size != torch.cuda.device_count():
logger.warning(
f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. "
"This may lead to unexpected failures with NCCL communication incase GPUs are being used for "
+ "training/inference. Consider setting local_world_size to the number of GPUs."
)

if cluster_info.compute_node_rank == 0:
gcs_utils = GcsUtils()
@@ -494,6 +491,10 @@ def _run_example_inference(
# Parses the fanout as a string. For the homogeneous case, the fanouts should be specified
# as a string of a list of integers, such as "[10, 10]".
num_neighbors = parse_fanout(inferencer_args.get("num_neighbors", "[10, 10]"))
sampler_options: Optional[SamplerOptions] = None
sampler_options_args = inferencer_args.get("ppr_sampler_options")
if sampler_options_args is not None and sampler_options_args.strip():
sampler_options = PPRSamplerOptions(**ast.literal_eval(sampler_options_args))

# While the ideal value for `sampling_workers_per_inference_process` has been identified to be
# between `2` and `4`, this may need some tuning depending on the pipeline. We default this
@@ -516,6 +517,14 @@ def _run_example_inference(

log_every_n_batch = int(inferencer_args.get("log_every_n_batch", "50"))

logger.info(
f"Got inference args local_world_size={local_world_size}, "
f"num_neighbors={num_neighbors}, sampler_options={sampler_options}, "
f"sampling_workers_per_inference_process={sampling_workers_per_inference_process}, "
f"sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, "
f"log_every_n_batch={log_every_n_batch}"
)

# When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine.
inference_args = InferenceProcessArgs(
local_world_size=local_world_size,
@@ -528,6 +537,7 @@ def _run_example_inference(
edge_feature_dim=edge_feature_dim,
inference_batch_size=inference_batch_size,
num_neighbors=num_neighbors,
sampler_options=sampler_options,
sampling_workers_per_inference_process=sampling_workers_per_inference_process,
sampling_worker_shared_channel_size=sampling_worker_shared_channel_size,
log_every_n_batch=log_every_n_batch,
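The local_world_size handling changed in this file can be summarized in a small standalone sketch. It is illustrative only: `ClusterInfoSketch` and `resolve_local_world_size` are hypothetical names standing in for `GraphStoreInfo` and the inline logic in `_run_example_inference`, and only the `num_processes_per_compute` field is modeled. The point it shows is that the cluster topology is the default and that an explicit argument which disagrees with it fails fast instead of surfacing later as a storage rendezvous error.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ClusterInfoSketch:
    """Hypothetical stand-in for the one GraphStoreInfo field used here."""

    num_processes_per_compute: int


def resolve_local_world_size(
    arg_value: Optional[str], cluster_info: ClusterInfoSketch
) -> int:
    """Default to the cluster topology and reject any mismatch."""
    expected = cluster_info.num_processes_per_compute
    # Args arrive as strings, so an explicit value is cast to int.
    local_world_size = int(arg_value) if arg_value is not None else expected
    if local_world_size != expected:
        raise ValueError(
            f"Graph Store local_world_size={local_world_size} must match "
            f"cluster_info.num_processes_per_compute={expected}"
        )
    return local_world_size


cluster_info = ClusterInfoSketch(num_processes_per_compute=2)
defaulted = resolve_local_world_size(None, cluster_info)
explicit = resolve_local_world_size("2", cluster_info)
try:
    resolve_local_world_size("4", cluster_info)
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True
```

Raising on a mismatch, rather than silently overriding the argument, keeps a misconfigured `local_world_size` visible to whoever wrote the config.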
31 changes: 26 additions & 5 deletions examples/link_prediction/graph_store/homogeneous_training.py
@@ -118,6 +118,7 @@
"""

import argparse
import ast
import gc
import os
import statistics
@@ -143,6 +144,7 @@
shutdown_compute_process,
)
from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions
from gigl.distributed.utils import get_available_device, get_graph_store_info
from gigl.env.distributed import GraphStoreInfo
from gigl.nn import LinkPredictionGNN, RetrievalLoss
@@ -191,6 +193,7 @@ def _setup_dataloaders(
split: Literal["train", "val", "test"],
cluster_info: GraphStoreInfo,
num_neighbors: list[int] | dict[EdgeType, list[int]],
sampler_options: Optional[SamplerOptions],
sampling_workers_per_process: int,
main_batch_size: int,
random_batch_size: int,
@@ -205,6 +208,7 @@ def _setup_dataloaders(
split (Literal["train", "val", "test"]): The current split which we are loading data for.
cluster_info (GraphStoreInfo): Cluster topology info for graph store mode.
num_neighbors: Fanout for subgraph sampling.
sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling.
sampling_workers_per_process (int): Number of sampling workers per training/testing process.
main_batch_size (int): Batch size for main dataloader with query and labeled nodes.
random_batch_size (int): Batch size for random negative dataloader.
@@ -240,6 +244,7 @@ def _setup_dataloaders(
channel_size=sampling_worker_shared_channel_size,
process_start_gap_seconds=process_start_gap_seconds,
shuffle=shuffle,
sampler_options=sampler_options,
)

logger.info(f"---Rank {rank} finished setting up main loader for split={split}")
@@ -266,6 +271,7 @@ def _setup_dataloaders(
channel_size=sampling_worker_shared_channel_size,
process_start_gap_seconds=process_start_gap_seconds,
shuffle=shuffle,
sampler_options=sampler_options,
)

logger.info(
@@ -375,6 +381,7 @@ class TrainingProcessArgs:
sampling_workers_per_process (int): Number of sampling workers per training/testing process.
sampling_worker_shared_channel_size (str): Shared-memory buffer size for the channel during sampling.
process_start_gap_seconds (int): Time to sleep between dataloader initializations.
sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling.
main_batch_size (int): Batch size for main dataloader.
random_batch_size (int): Batch size for random negative dataloader.
learning_rate (float): Learning rate for the optimizer.
@@ -400,6 +407,7 @@ class TrainingProcessArgs:

# Sampling config
num_neighbors: list[int] | dict[EdgeType, list[int]]
sampler_options: Optional[SamplerOptions]
sampling_workers_per_process: int
sampling_worker_shared_channel_size: str
process_start_gap_seconds: int
@@ -463,6 +471,7 @@ def _training_process(
split="train",
cluster_info=args.cluster_info,
num_neighbors=args.num_neighbors,
sampler_options=args.sampler_options,
sampling_workers_per_process=args.sampling_workers_per_process,
main_batch_size=args.main_batch_size,
random_batch_size=args.random_batch_size,
@@ -481,6 +490,7 @@ def _training_process(
split="val",
cluster_info=args.cluster_info,
num_neighbors=args.num_neighbors,
sampler_options=args.sampler_options,
sampling_workers_per_process=args.sampling_workers_per_process,
main_batch_size=args.main_batch_size,
random_batch_size=args.random_batch_size,
@@ -637,6 +647,7 @@ def _training_process(
split="test",
cluster_info=args.cluster_info,
num_neighbors=args.num_neighbors,
sampler_options=args.sampler_options,
sampling_workers_per_process=args.sampling_workers_per_process,
main_batch_size=args.main_batch_size,
random_batch_size=args.random_batch_size,
@@ -837,13 +848,17 @@ def _run_example_training(
# Training Hyperparameters
trainer_args = dict(gbml_config_pb_wrapper.trainer_config.trainer_args)

if torch.cuda.is_available():
default_local_world_size = torch.cuda.device_count()
else:
default_local_world_size = 2
local_world_size = int(
trainer_args.get("local_world_size", str(default_local_world_size))
trainer_args.get(
"local_world_size", str(cluster_info.num_processes_per_compute)
)
)
if local_world_size != cluster_info.num_processes_per_compute:
raise ValueError(
f"Graph Store local_world_size={local_world_size} must match "
f"cluster_info.num_processes_per_compute="
f"{cluster_info.num_processes_per_compute}"
)

if torch.cuda.is_available():
if local_world_size > torch.cuda.device_count():
@@ -853,6 +868,10 @@ def _run_example_training(

fanout = trainer_args.get("num_neighbors", "[10, 10]")
num_neighbors = parse_fanout(fanout)
sampler_options: Optional[SamplerOptions] = None
sampler_options_args = trainer_args.get("ppr_sampler_options")
if sampler_options_args is not None and sampler_options_args.strip():
sampler_options = PPRSamplerOptions(**ast.literal_eval(sampler_options_args))

sampling_workers_per_process: int = int(
trainer_args.get("sampling_workers_per_process", "4")
@@ -880,6 +899,7 @@ def _run_example_training(
logger.info(
f"Got training args local_world_size={local_world_size}, \
num_neighbors={num_neighbors}, \
sampler_options={sampler_options}, \
sampling_workers_per_process={sampling_workers_per_process}, \
main_batch_size={main_batch_size}, \
random_batch_size={random_batch_size}, \
@@ -931,6 +951,7 @@ def _run_example_training(
node_feature_dim=node_feature_dim,
edge_feature_dim=edge_feature_dim,
num_neighbors=num_neighbors,
sampler_options=sampler_options,
sampling_workers_per_process=sampling_workers_per_process,
sampling_worker_shared_channel_size=sampling_worker_shared_channel_size,
process_start_gap_seconds=process_start_gap_seconds,
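The `ppr_sampler_options` parsing added in this file can be tried in isolation. The following is a minimal sketch, not GiGL code: `PPRSamplerOptionsSketch` and `parse_sampler_options` are hypothetical stand-ins, and the dataclass fields are assumed from the keys used in the E2E config rather than from the real `PPRSamplerOptions` in `gigl/distributed/sampler_options.py`. It also shows one reason `ast.literal_eval` suits these string-valued args: a kwargs string containing the Python literal `True`, as in the E2E config's `splitter_kwargs`, parses with `ast.literal_eval` but is rejected by `json.loads`.

```python
import ast
import json
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class PPRSamplerOptionsSketch:
    """Hypothetical stand-in; field names copied from the E2E config keys."""

    alpha: float
    eps: float
    max_ppr_nodes: int
    num_neighbors_per_hop: int
    max_fetch_iterations: int


def parse_sampler_options(args: dict[str, str]) -> Optional[PPRSamplerOptionsSketch]:
    """Return PPR options if the key is present and non-blank, else None (k-hop)."""
    raw = args.get("ppr_sampler_options")
    if raw is None or not raw.strip():
        return None
    return PPRSamplerOptionsSketch(**ast.literal_eval(raw))


trainer_args = {
    "num_neighbors": "[10, 10]",
    "ppr_sampler_options": '{ "alpha": 0.5, "eps": 0.0001, "max_ppr_nodes": 20, '
    '"num_neighbors_per_hop": 100, "max_fetch_iterations": 2 }',
}
options = parse_sampler_options(trainer_args)
khop = parse_sampler_options({"num_neighbors": "[10, 10]"})
blank = parse_sampler_options({"ppr_sampler_options": "   "})

# Python-literal booleans parse with ast.literal_eval but are invalid JSON.
splitter_kwargs_raw = '{"should_convert_labels_to_edges": True, "num_val": 0.25}'
splitter_kwargs = ast.literal_eval(splitter_kwargs_raw)
try:
    json.loads(splitter_kwargs_raw)
    json_accepts_it = True
except json.JSONDecodeError:
    json_accepts_it = False
```

Because the parsed dict is splatted directly into the options class, a misspelled key in the config raises a `TypeError` at startup rather than being silently ignored.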