From 92d3df371ea4de3ff424c0f8379abb86967673bc Mon Sep 17 00:00:00 2001 From: root Date: Fri, 10 Apr 2026 18:56:12 +0000 Subject: [PATCH] feat: add custom strategy support --- comet/models/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comet/models/base.py b/comet/models/base.py index 9872d2b..d535279 100644 --- a/comet/models/base.py +++ b/comet/models/base.py @@ -27,6 +27,7 @@ import numpy as np import pytorch_lightning as ptl +from pytorch_lightning.strategies.strategy import Strategy import torch from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Subset @@ -558,6 +559,7 @@ def predict( accelerator: str = "auto", num_workers: int = None, length_batching: bool = True, + strategy: Union[str, Strategy] = "auto" ) -> Prediction: """Method that receives a list of samples (dictionaries with translations, sources and/or references) and returns segment-level scores, system level score @@ -648,7 +650,7 @@ def predict( logger=False, callbacks=callbacks, accelerator=accelerator if gpus > 0 else "cpu", - strategy="auto" if gpus < 2 else "ddp", + strategy=strategy if (gpus < 2 or isinstance(strategy, Strategy)) else "ddp", enable_progress_bar=enable_progress_bar, ) return_predictions = False if gpus > 1 else True