diff --git a/comet/models/base.py b/comet/models/base.py index 9872d2b..d535279 100644 --- a/comet/models/base.py +++ b/comet/models/base.py @@ -27,6 +27,7 @@ import numpy as np import pytorch_lightning as ptl +from pytorch_lightning.strategies.strategy import Strategy import torch from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Subset @@ -558,6 +559,7 @@ def predict( accelerator: str = "auto", num_workers: int = None, length_batching: bool = True, + strategy: Union[str, Strategy] = "auto" ) -> Prediction: """Method that receives a list of samples (dictionaries with translations, sources and/or references) and returns segment-level scores, system level score @@ -648,7 +650,7 @@ def predict( logger=False, callbacks=callbacks, accelerator=accelerator if gpus > 0 else "cpu", - strategy="auto" if gpus < 2 else "ddp", + strategy=strategy if (gpus < 2 or isinstance(strategy, Strategy)) else "ddp", enable_progress_bar=enable_progress_bar, ) return_predictions = False if gpus > 1 else True