feat: add custom strategy support#269

Open
dgme-syz wants to merge 1 commit into Unbabel:master from dgme-syz:feature/support-custom-strategy

Conversation

@dgme-syz

Summary

This PR adds support for custom PyTorch Lightning strategies in model.predict() by exposing a strategy argument to the user.

This allows users to manually specify the execution strategy when running predictions on different hardware backends (e.g., NPUs or other custom accelerators).


Motivation

Currently, the predict() method internally determines the strategy using:

strategy="auto" if gpus < 2 else "ddp"

However, this prevents users from overriding the strategy when using custom accelerators.

For example, when using a custom accelerator such as an NPU, PyTorch Lightning may incorrectly infer the device as cpu, which leads to the strategy being initialized with a CPU device:

if len(self._parallel_devices) <= 1:
    if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
        isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
    ):
        device = _determine_root_gpu_device(self._parallel_devices)
    else:
        device = "cpu"
    return SingleDeviceStrategy(device=device)

This makes it difficult to run COMET.predict() on non-CUDA accelerators such as Huawei NPUs.

By allowing users to pass a custom strategy, COMET becomes compatible with a wider range of hardware backends.


Changes

  1. Added a strategy argument to predict():

def predict(
    ...
    strategy: Union[str, Strategy] = "auto"
)

  2. Updated the trainer initialization to respect user-provided strategies:

strategy=strategy if (gpus < 2 or isinstance(strategy, Strategy)) else "ddp"

  3. Added the required import:

from pytorch_lightning.strategies.strategy import Strategy

This change preserves the original default behavior while allowing users to override the strategy when needed.
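The resolution rule in change 2 can be sketched standalone. The sketch below is a hypothetical, self-contained illustration: `Strategy` here is a stand-in class for `pytorch_lightning.strategies.strategy.Strategy`, and `resolve_strategy` is a helper name invented for this example, not a function in COMET.

```python
class Strategy:
    """Stand-in for pytorch_lightning's Strategy base class."""


def resolve_strategy(gpus: int, strategy="auto"):
    """Mirror the predict() rule: a user-provided Strategy object is always
    respected; otherwise single-device runs keep the default and multi-GPU
    runs fall back to "ddp"."""
    if gpus < 2 or isinstance(strategy, Strategy):
        return strategy
    return "ddp"


custom = Strategy()
resolve_strategy(1)           # "auto": single device keeps the default
resolve_strategy(4)           # "ddp": multi-GPU fallback, as before
resolve_strategy(4, custom)   # the custom Strategy object is kept
```

This is why the change is backward compatible: only callers who pass a `Strategy` instance see new behavior.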


Example Usage

With this change, users can run COMET on an NPU by providing a custom accelerator and strategy:

from comet import download_model, load_from_checkpoint
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy

model_path = download_model("Unbabel/XCOMET-XL")
model = load_from_checkpoint(model_path)

data = [
    {
        "src": "10 到 15 分钟可以送到吗",
        "mt": "Can I receive my food in 10 to 15 minutes?",
        "ref": "Can it be delivered between 10 to 15 minutes?"
    }
]

model_output = model.predict(
    data,
    batch_size=8,
    gpus=1,
    accelerator=NPUAccelerator(),
    strategy=SingleDeviceStrategy(device="npu"),
)

Here, NPUAccelerator is a custom PyTorch Lightning accelerator implementation; a minimal example is included below.


Backward Compatibility

This change is fully backward compatible:

  • Default behavior remains unchanged.
  • Existing code using predict() without specifying strategy will behave exactly as before.

Example: Running COMET on NPU

Below is a minimal example of a custom NPUAccelerator implementation
that allows COMET to run on Huawei Ascend NPUs.

import os
import logging
from typing import Any, Dict, Union

import torch
import torch_npu
import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from typing_extensions import override

_log = logging.getLogger(__name__)

class NPUAccelerator(Accelerator):
    """Accelerator for HUAWEI NPU devices."""

    @override
    def setup_device(self, device: torch.device) -> None:
        """
        Raises:
            ValueError: If the selected device is not of type NPU.
        """
        if device.type != "npu":
            raise ValueError(
                f"Device should be of type 'npu', got '{device.type}' instead."
            )
        if device.index is None:
            device = torch.device("npu", 0)
        torch.npu.set_device(device.index)

    @override
    def teardown(self) -> None:
        torch.npu.empty_cache()

    @staticmethod
    @override
    def parse_devices(devices: Any) -> Any:
        # Note: the requested `devices` value is ignored; all visible NPUs are returned.
        return [torch.device("npu", i) for i in range(torch.npu.device_count())]

    @staticmethod
    @override
    def get_parallel_devices(devices: Any) -> Any:
        if isinstance(devices, int):
            return [torch.device("npu", i) for i in range(devices)]
        elif isinstance(devices, list):
            try:
                return [torch.device("npu", i) for i in devices]
            except Exception:
                return devices
        elif devices in ("auto", "npu"):
            return [torch.device("npu", i) for i in range(torch.npu.device_count())]
        return []

    @staticmethod
    @override
    def auto_device_count() -> int:
        return torch.npu.device_count()

    @staticmethod
    @override
    def is_available() -> bool:
        return torch.npu.is_available()
    
    @staticmethod
    @override
    def name() -> str:
        return "NPUAccelerator"

    @override
    def setup(self, trainer: "pl.Trainer") -> None:
        """Called by the Trainer to set up the accelerator."""
        self.set_ascend_flags(trainer.local_rank)
        torch.npu.empty_cache()

    @staticmethod
    def set_ascend_flags(local_rank: int) -> None:
        """Set Ascend NPU environment variables, mirroring CUDA's PCI ordering setup."""
        os.environ["ASCEND_DEVICE_ID"] = str(local_rank)

        all_npu_ids = ",".join(str(x) for x in range(torch.npu.device_count()))
        devices = os.getenv("ASCEND_RT_VISIBLE_DEVICES", all_npu_ids)
        _log.info(f"LOCAL_RANK: {local_rank} - ASCEND_RT_VISIBLE_DEVICES: [{devices}]")

    @override
    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        """Return NPU memory stats."""
        if isinstance(device, str):
            device = torch.device(device)
        try:
            return torch_npu.npu.memory_stats(device)
        except Exception:
            return {}

    @classmethod
    @override
    def register_accelerators(cls, accelerator_registry) -> None:
        accelerator_registry.register(
            "npu",
            cls,
            description="NPU Accelerator - optimized for large-scale machine learning.",
        )
