feat: add custom strategy support#269
Open
dgme-syz wants to merge 1 commit intoUnbabel:masterfrom
Open
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR adds support for custom PyTorch Lightning strategies in
model.predict()by exposing astrategyargument to the user.This allows users to manually specify the execution strategy when running predictions on different hardware backends (e.g., NPUs or other custom accelerators).
Motivation
Currently, the
predict()method internally determines the strategy using:However, this prevents users from overriding the strategy when using custom accelerators.
For example, when using a custom accelerator such as an NPU, PyTorch Lightning may incorrectly infer the device as
cpu, which leads to the strategy being initialized with a CPU device:This makes it difficult to run
COMET.predict()on non-CUDA accelerators such as Huawei NPUs.By allowing users to pass a custom
strategy, COMET becomes compatible with a wider range of hardware backends.Changes
strategyargument topredict():This change preserves the original default behavior while allowing users to override the strategy when needed.
Example Usage
With this change, users can run COMET on an NPU by providing a custom accelerator and strategy:
Where
NPUAcceleratoris a custom PyTorch Lightning accelerator implementation.Backward Compatibility
This change is fully backward compatible:
predict()without specifyingstrategywill behave exactly as before.Example: Running COMET on NPU
Below is a minimal example of a custom
NPUAcceleratorimplementationthat allows COMET to run on Huawei Ascend NPUs.
NPUAccelerator example