From 53ee4f245c9dfc4f834edb5e15b302f74fcefb39 Mon Sep 17 00:00:00 2001 From: NightRaven109 Date: Tue, 10 Jun 2025 15:42:35 -0500 Subject: [PATCH 1/3] Add prune train ckpt script, and Add local weight loading option to inference.py --- README.md | 15 ++ examples/inference/inference.py | 14 +- .../convert_checkpoint_to_safetensors.py | 135 ++++++++++++++++++ 3 files changed, 162 insertions(+), 2 deletions(-) create mode 100644 examples/training/convert_checkpoint_to_safetensors.py diff --git a/README.md b/README.md index 704757b..6556cd1 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,15 @@ python examples/inference/inference.py \ --output_path output_images ``` +Or run weights locally with + +```bash +python examples/inference/inference.py \ +--model_path path_to_local_model_directory \ +--source_image path_to_your_image.jpg \ +--output_path output_images +``` + See the trained models on the HF Hub 🤗 - [Surface normals Checkpoint](https://huggingface.co/jasperai/LBM_normals) - [Depth Checkpoint](https://huggingface.co/jasperai/LBM_depth) @@ -163,6 +172,12 @@ To train the model, you can use the following command: python examples/training/train_lbm_surface.py examples/training/config/surface.yaml ``` +To prune trained output ckpt to just model weights for inference + +```bash +python examples/training/convert_checkpoint_to_safetensors.py --checkpoint_path examples/training/output --output_dir out +``` + *Note*: Make sure to update the relevant section of the `yaml` file to use your own data and log the results on your own [WandB](https://wandb.ai/site). ## Citation diff --git a/examples/inference/inference.py b/examples/inference/inference.py index d81eb22..3560e8f 100644 --- a/examples/inference/inference.py +++ b/examples/inference/inference.py @@ -21,14 +21,24 @@ default="normals", choices=["normals", "depth", "relighting"], ) +parser.add_argument( + "--model_path", + type=str, + help="Path to local model directory (overrides model_name if provided)", +) args = parser.parse_args() def main(): - # download the weights from HF hub - if not os.path.exists(os.path.join(PATH, "ckpts", f"{args.model_name}")): + # Use custom model path if provided + if args.model_path: + logging.info(f"Loading LBM model from custom path: {args.model_path}") + model = get_model(args.model_path, torch_dtype=torch.bfloat16, device="cuda") + + # Otherwise use model_name with HF hub or local cache + elif not os.path.exists(os.path.join(PATH, "ckpts", f"{args.model_name}")): logging.info(f"Downloading {args.model_name} LBM model from HF hub...") model = get_model( f"jasperai/LBM_{args.model_name}", diff --git a/examples/training/convert_checkpoint_to_safetensors.py b/examples/training/convert_checkpoint_to_safetensors.py new file mode 100644 index 0000000..49218a5 --- /dev/null +++ b/examples/training/convert_checkpoint_to_safetensors.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +""" +Convert PyTorch Lightning checkpoint to safetensors format for inference. + +This script converts the large training checkpoints (~14GB) that include optimizer state +and training metadata to lightweight safetensors files (~5GB) with just model weights. +""" + +import argparse +import logging +import os +import shutil +from pathlib import Path + +import torch +import yaml +from safetensors.torch import save_file + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def convert_checkpoint_to_safetensors( + checkpoint_path: str, + output_dir: str, + config_path: str = None, +): + """ + Convert a PyTorch Lightning checkpoint to safetensors format. + + Args: + checkpoint_path: Path to the .ckpt file + output_dir: Directory to save the converted files + config_path: Path to config.yaml (if None, will look in checkpoint directory) + """ + checkpoint_path = Path(checkpoint_path) + output_dir = Path(output_dir) + + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + # Create output directory + output_dir.mkdir(parents=True, exist_ok=True) + + # Load checkpoint + logger.info(f"Loading checkpoint from {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + + # Extract model state dict + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + raise ValueError("No 'state_dict' found in checkpoint") + + # Remove "model." prefix from keys (as done in inference/utils.py line 76) + logger.info("Cleaning state dict - removing 'model.' prefix") + cleaned_state_dict = {} + model_prefix = "model." + + for key, value in state_dict.items(): + if key.startswith(model_prefix): + new_key = key[len(model_prefix):] + # Clone tensors to break memory sharing (fixes safetensors shared memory error) + cleaned_state_dict[new_key] = value.clone() + else: + # Keep keys that don't have the model prefix + cleaned_state_dict[key] = value.clone() + + # Save as safetensors + safetensors_path = output_dir / "model.safetensors" + logger.info(f"Saving safetensors to {safetensors_path}") + save_file(cleaned_state_dict, safetensors_path) + + # Handle config.yaml + if config_path is None: + # Look for config.yaml in the same directory as checkpoint + config_path = checkpoint_path.parent / "config.yaml" + else: + config_path = Path(config_path) + + if config_path.exists(): + output_config_path = output_dir / "config.yaml" + logger.info(f"Copying config from {config_path} to {output_config_path}") + shutil.copy2(config_path, output_config_path) + else: + logger.warning(f"Config file not found at {config_path}") + logger.info("You may need to manually create config.yaml for inference") + + # Log size comparison + original_size = checkpoint_path.stat().st_size / (1024**3) # GB + new_size = safetensors_path.stat().st_size / (1024**3) # GB + + logger.info(f"Conversion complete!") + logger.info(f"Original checkpoint: {original_size:.2f} GB") + logger.info(f"Safetensors file: {new_size:.2f} GB") + logger.info(f"Size reduction: {((original_size - new_size) / original_size * 100):.1f}%") + logger.info(f"Output directory: {output_dir}") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert PyTorch Lightning checkpoint to safetensors format" + ) + parser.add_argument( + "--checkpoint_path", + required=True, + help="Path to the .ckpt file to convert" + ) + parser.add_argument( + "--output_dir", + required=True, + help="Directory to save the converted files" + ) + parser.add_argument( + "--config_path", + help="Path to config.yaml (optional, will look in checkpoint directory if not provided)" + ) + + args = parser.parse_args() + + try: + convert_checkpoint_to_safetensors( + checkpoint_path=args.checkpoint_path, + output_dir=args.output_dir, + config_path=args.config_path, + ) + except Exception as e: + logger.error(f"Conversion failed: {e}") + return 1 + + return 0 + + +if __name__ == "__main__": + exit(main()) \ No newline at end of file From 72658086ac086f4dbfaee0505c2b53cacbec8f6f Mon Sep 17 00:00:00 2001 From: Benjamin Gregg <90132896+Night1099@users.noreply.github.com> Date: Tue, 10 Jun 2025 17:11:07 -0500 Subject: [PATCH 2/3] Fix LPIPS loss bug when using l1 --- .../training/convert_checkpoint_to_safetensors.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/training/convert_checkpoint_to_safetensors.py b/examples/training/convert_checkpoint_to_safetensors.py index 49218a5..23b5f2b 100644 --- a/examples/training/convert_checkpoint_to_safetensors.py +++ b/examples/training/convert_checkpoint_to_safetensors.py @@ -52,17 +52,25 @@ def convert_checkpoint_to_safetensors( else: raise ValueError("No 'state_dict' found in checkpoint") - # Remove "model." prefix from keys (as done in inference/utils.py line 76) - logger.info("Cleaning state dict - removing 'model.' prefix") + # Remove "model." prefix from keys and filter out LPIPS loss weights + logger.info("Cleaning state dict - removing 'model.' prefix and filtering training-only weights") cleaned_state_dict = {} model_prefix = "model." for key, value in state_dict.items(): if key.startswith(model_prefix): new_key = key[len(model_prefix):] + # Skip LPIPS loss weights that are only used during training + if new_key.startswith("lpips_loss."): + logger.debug(f"Skipping training-only weight: {new_key}") + continue # Clone tensors to break memory sharing (fixes safetensors shared memory error) cleaned_state_dict[new_key] = value.clone() else: + # Skip LPIPS loss weights that don't have model prefix + if key.startswith("lpips_loss."): + logger.debug(f"Skipping training-only weight: {key}") + continue # Keep keys that don't have the model prefix cleaned_state_dict[key] = value.clone() @@ -132,4 +140,4 @@ def main(): if __name__ == "__main__": - exit(main()) \ No newline at end of file + exit(main()) From cf57e41e5799f1a65e9e9e073e90d5bb3ab092d4 Mon Sep 17 00:00:00 2001 From: Benjamin Gregg <90132896+Night1099@users.noreply.github.com> Date: Wed, 11 Jun 2025 00:11:33 -0500 Subject: [PATCH 3/3] Revert LPIPS weights removal, made in error, didnt have right config --- .../training/convert_checkpoint_to_safetensors.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/examples/training/convert_checkpoint_to_safetensors.py b/examples/training/convert_checkpoint_to_safetensors.py index 23b5f2b..6288ef0 100644 --- a/examples/training/convert_checkpoint_to_safetensors.py +++ b/examples/training/convert_checkpoint_to_safetensors.py @@ -52,25 +52,17 @@ def convert_checkpoint_to_safetensors( else: raise ValueError("No 'state_dict' found in checkpoint") - # Remove "model." prefix from keys and filter out LPIPS loss weights - logger.info("Cleaning state dict - removing 'model.' prefix and filtering training-only weights") + # Remove "model." prefix from keys (as done in inference/utils.py line 76) + logger.info("Cleaning state dict - removing 'model.' prefix") cleaned_state_dict = {} model_prefix = "model." for key, value in state_dict.items(): if key.startswith(model_prefix): new_key = key[len(model_prefix):] - # Skip LPIPS loss weights that are only used during training - if new_key.startswith("lpips_loss."): - logger.debug(f"Skipping training-only weight: {new_key}") - continue # Clone tensors to break memory sharing (fixes safetensors shared memory error) cleaned_state_dict[new_key] = value.clone() else: - # Skip LPIPS loss weights that don't have model prefix - if key.startswith("lpips_loss."): - logger.debug(f"Skipping training-only weight: {key}") - continue # Keep keys that don't have the model prefix cleaned_state_dict[key] = value.clone()