Skip to content
1 change: 1 addition & 0 deletions docs/components/components.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ The composed initializer supports seeded weight initialization for reproducibili
| checkpoint_saving | default | [CheckpointSaving](../../src/modalities/checkpointing/checkpoint_saving.py)| [CheckpointSavingConfig](../../src/modalities/config/config.py) | -- | Component for saving checkpoints based on a saving and execution strategy. |
| checkpoint_saving_strategy | save_every_k_steps_checkpointing_strategy | [SaveEveryKStepsCheckpointingStrategy](../../src/modalities/checkpointing/checkpoint_saving_strategies.py)| [SaveEveryKStepsCheckpointingStrategyConfig](../../src/modalities/config/config.py) | [CheckpointSavingStrategyIF](../../src/modalities/checkpointing/checkpoint_saving_strategies.py) | Checkpointing strategy saving a checkpoint every k steps |
| checkpoint_saving_strategy | save_k_most_recent_checkpoints_strategy | [SaveKMostRecentCheckpointsStrategy](../../src/modalities/checkpointing/checkpoint_saving_strategies.py)| [SaveKMostRecentCheckpointsStrategyConfig](../../src/modalities/config/config.py) | [CheckpointSavingStrategyIF](../../src/modalities/checkpointing/checkpoint_saving_strategies.py) | Checkpointing strategy saving only the last k checkpoints and deleting the previous ones |
| checkpoint_saving_strategy | keep_every_k_steps_and_m_most_recent_checkpointing_strategy | [KeepEveryKStepsAndMMostRecentCheckpointingStrategy](../../src/modalities/checkpointing/checkpoint_saving_strategies.py)| [KeepEveryKStepsAndMMostRecentCheckpointingStrategyConfig](../../src/modalities/config/config.py) | [CheckpointSavingStrategyIF](../../src/modalities/checkpointing/checkpoint_saving_strategies.py) | Checkpointing strategy that saves a checkpoint whenever it is invoked, permanently keeps the checkpoints whose step count is divisible by k, and additionally keeps the m most recent checkpoints |
Comment thread
BlueCrescent marked this conversation as resolved.
Outdated
| checkpoint_saving_execution | fsdp | [FSDPCheckpointSaving](../../src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py)| [FSDPCheckpointSavingConfig](../../src/modalities/config/config.py) | [CheckpointSavingExecutionABC](../../src/modalities/checkpointing/checkpoint_saving_execution.py) | FSDPCheckpointSaving class for saving checkpoints of FSDP models and optimizers. |
| checkpoint_loading | fsdp | [FSDPCheckpointLoading](../../src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py)| [FSDPCheckpointLoadingConfig](../../src/modalities/config/config.py) | [CheckpointLoadingIF](../../src/modalities/checkpointing/checkpoint_loading.py) | Component for loading FSDP checkpoints|
| checkpoint_loading | torch | [TorchCheckpointLoading](../../src/modalities/checkpointing/torch/torch_checkpoint_loading.py)| [TorchCheckpointLoadingConfig](../../src/modalities/config/config.py) | [CheckpointLoadingIF](../../src/modalities/checkpointing/checkpoint_loading.py) | Component for loading PyTorch checkpoints|
Expand Down
57 changes: 57 additions & 0 deletions src/modalities/checkpointing/checkpoint_saving_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,60 @@ def get_checkpoint_instruction(
"""
save_current = training_progress.num_seen_steps_total % self.k == 0
return CheckpointingInstruction(save_current=save_current, checkpoints_to_delete=[])


class KeepEveryKStepsAndMMostRecentCheckpointingStrategy(CheckpointSavingStrategyIF):
    """Strategy that keeps every k-th step permanently and, in addition, the m most recent checkpoints.

    Every call to ``get_checkpoint_instruction`` requests a save of the current state. Checkpoints that
    drop out of the window of the m most recent ones are scheduled for deletion, unless their total
    step count is divisible by k.
    """

    def __init__(self, k: int, num_recent_checkpoints_to_keep: int = 2):
        """
        Initializes the CheckpointSavingStrategy object.

        Args:
            k (int): The interval of steps to keep. Must be greater than 0.
            num_recent_checkpoints_to_keep (int, optional): Size of the window of most recent checkpoints.
                The window counts every checkpoint, including those divisible by k; once a checkpoint
                leaves the window it is deleted only if its step count is not divisible by k.
                Must be at least 1. Defaults to 2.

        Returns:
            None

        Raises:
            AssertionError: If k is not greater than 0 or num_recent_checkpoints_to_keep is smaller than 1.
        """
        super().__init__()
        self._k = k
        self._num_recent_checkpoints_to_keep = num_recent_checkpoints_to_keep
        # Progress snapshots of the checkpoints currently inside the "most recent" window, oldest first.
        self._saved_recent_checkpoints: list[TrainingProgress] = []
        assert self._k > 0, "k must be greater than 0"
        assert self._num_recent_checkpoints_to_keep >= 1, "num_recent_checkpoints_to_keep must be at least 1"

    def get_checkpoint_instruction(
        self,
        training_progress: TrainingProgress,
        evaluation_result: dict[str, EvaluationResultBatch] | None = None,
        early_stopping_criterion_fulfilled: bool = False,
    ) -> CheckpointingInstruction:
        """
        Returns a CheckpointingInstruction object.

        The instruction always asks to save the current checkpoint. The checkpoints to delete are those
        that just left the window of the most recent checkpoints and whose total step count is not
        divisible by k.

        Args:
            training_progress (TrainingProgress): The training progress.
            evaluation_result (dict[str, EvaluationResultBatch] | None, optional):
                The evaluation result. Not used by this strategy. Defaults to None.
            early_stopping_criterion_fulfilled (bool, optional):
                Whether the early stopping criterion is fulfilled. Not used by this strategy.
                Defaults to False.

        Returns:
            CheckpointingInstruction: The checkpointing instruction object.
        """
        # Record a copy so that later in-place changes to the caller's object do not alter the stored step.
        self._saved_recent_checkpoints.append(dataclasses.replace(training_progress))

        # Everything in front of this index has fallen out of the window of the most recent checkpoints.
        num_expired = max(len(self._saved_recent_checkpoints) - self._num_recent_checkpoints_to_keep, 0)
        expired_checkpoints = self._saved_recent_checkpoints[:num_expired]
        self._saved_recent_checkpoints = self._saved_recent_checkpoints[num_expired:]

        # Do not delete checkpoints that are divisible by k. The total step count is used, i.e. the same
        # counter that the every-k-steps strategy in this module checks, so both strategies agree on
        # which steps count as "every k-th".
        checkpoints_to_delete = [cp for cp in expired_checkpoints if cp.num_seen_steps_total % self._k != 0]

        return CheckpointingInstruction(save_current=True, checkpoints_to_delete=checkpoints_to_delete)
5 changes: 5 additions & 0 deletions src/modalities/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ class SaveKMostRecentCheckpointsStrategyConfig(BaseModel):
k: Annotated[int, Field(strict=True, ge=-1)]


class KeepEveryKStepsAndMMostRecentCheckpointingStrategyConfig(BaseModel):
    # Settings for the strategy that keeps every k-th step plus the most recent checkpoints.

    # Interval of steps to keep; a strictly typed integer greater than 0.
    k: Annotated[int, Field(gt=0, strict=True)]
    # Number of most recent checkpoints to keep; a strictly typed integer of at least 1.
    num_recent_checkpoints_to_keep: Annotated[int, Field(ge=1, strict=True)] = 2


class TorchCheckpointLoadingConfig(BaseModel):
device: PydanticPytorchDeviceType
precision: Optional[PrecisionEnum] = None
Expand Down
8 changes: 8 additions & 0 deletions src/modalities/registry/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from modalities.checkpointing.checkpoint_saving import CheckpointSaving
from modalities.checkpointing.checkpoint_saving_strategies import (
KeepEveryKStepsAndMMostRecentCheckpointingStrategy,
SaveEveryKStepsCheckpointingStrategy,
SaveKMostRecentCheckpointsStrategy,
)
Expand Down Expand Up @@ -47,6 +48,7 @@
GPT2LLMCollateFnConfig,
GPT2MFUCalculatorConfig,
GPT2ModelTPConfig,
KeepEveryKStepsAndMMostRecentCheckpointingStrategyConfig,
LinearLRSchedulerConfig,
LinearWarmupCosineAnnealingLRSchedulerConfig,
LLMDataLoaderConfig,
Expand Down Expand Up @@ -353,6 +355,12 @@ class ComponentEntity:
SaveKMostRecentCheckpointsStrategy,
SaveKMostRecentCheckpointsStrategyConfig,
),
ComponentEntity(
"checkpoint_saving_strategy",
"keep_every_k_steps_and_m_most_recent_checkpointing_strategy",
KeepEveryKStepsAndMMostRecentCheckpointingStrategy,
KeepEveryKStepsAndMMostRecentCheckpointingStrategyConfig,
),
# checkpoint saving execution
ComponentEntity("checkpoint_saving_execution", "fsdp1", FSDP1CheckpointSaving, FSDP1CheckpointSavingConfig),
ComponentEntity("checkpoint_saving_execution", "dcp", DCPCheckpointSaving, DCPCheckpointSavingConfig),
Expand Down
64 changes: 63 additions & 1 deletion tests/checkpointing/test_checkpoint_strategies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import dataclasses

import pytest

from modalities.checkpointing.checkpoint_saving_strategies import SaveKMostRecentCheckpointsStrategy
from modalities.checkpointing.checkpoint_saving_instruction import CheckpointingInstruction
from modalities.checkpointing.checkpoint_saving_strategies import (
KeepEveryKStepsAndMMostRecentCheckpointingStrategy,
SaveKMostRecentCheckpointsStrategy,
)
from modalities.training.training_progress import TrainingProgress


Expand Down Expand Up @@ -43,3 +49,59 @@ def test_checkpoint_strategy_k(
if k != 0 and save_current:
training_progress.num_seen_steps_current_run = 100
assert checkpoint_strategy.saved_step_checkpoints[0].num_seen_steps_current_run == num_seen_steps_current_run


@pytest.mark.parametrize(
    "k, num_recent_checkpoints_to_keep, num_steps",
    [
        (3, 2, 11),
        (2, 1, 10),
        (4, 3, 15),
    ],
)
def test_keep_every_k_steps_keeps_every_k_steps(k: int, num_recent_checkpoints_to_keep: int, num_steps: int) -> None:
    """After num_steps steps, exactly the steps divisible by k and the most recent ones must remain."""
    checkpoint_strategy = KeepEveryKStepsAndMMostRecentCheckpointingStrategy(
        k=k, num_recent_checkpoints_to_keep=num_recent_checkpoints_to_keep
    )
    training_progress = TrainingProgress(
        num_seen_steps_current_run=0,
        num_seen_tokens_current_run=0,
        num_target_steps=20,
        num_target_tokens=40,
    )

    # Simulate training progress and checkpointing
    simulator = _CheckpointSavingSimulator()
    for step in range(1, num_steps + 1):
        training_progress.num_seen_steps_current_run = step
        checkpoint_instruction = checkpoint_strategy.get_checkpoint_instruction(training_progress=training_progress)
        simulator.simulate_training_step(training_progress, checkpoint_instruction)

    # Compare against the exact expected steps (in save order) rather than only checking that each
    # survivor is allowed; otherwise a strategy that deletes too much would pass unnoticed.
    first_recent_step = num_steps - num_recent_checkpoints_to_keep + 1
    expected_steps = [step for step in range(1, num_steps + 1) if step % k == 0 or step >= first_recent_step]
    kept_steps = [ckpt.num_seen_steps_current_run for ckpt in simulator.saved_checkpoints]
    assert kept_steps == expected_steps
Comment thread
BlueCrescent marked this conversation as resolved.
Outdated


def test_keep_every_k_steps_checkpointing_strategy_invalid_arguments() -> None:
    """The constructor must reject a non-positive k and fewer than one recent checkpoint."""
    # (k, num_recent_checkpoints_to_keep) pairs that violate one of the constructor's constraints.
    rejected_arguments = [(0, 1), (-1, 1), (2, 0), (2, -1)]
    for invalid_k, invalid_num_recent in rejected_arguments:
        with pytest.raises(AssertionError):
            KeepEveryKStepsAndMMostRecentCheckpointingStrategy(
                k=invalid_k, num_recent_checkpoints_to_keep=invalid_num_recent
            )


class _CheckpointSavingSimulator:
    """Test helper that tracks which checkpoints would exist after applying checkpointing instructions."""

    def __init__(self):
        # Snapshots of the training progress for every checkpoint that is currently "on disk".
        self.saved_checkpoints: list[TrainingProgress] = []

    def simulate_training_step(
        self, training_progress: TrainingProgress, ckpt_instruction: CheckpointingInstruction
    ) -> None:
        """Apply one instruction: optionally store the current progress, then drop the listed checkpoints.

        Args:
            training_progress (TrainingProgress): Progress at the current step; stored as a copy.
            ckpt_instruction (CheckpointingInstruction): Instruction produced for the current step.
        """
        if ckpt_instruction.save_current:
            snapshot = dataclasses.replace(training_progress)
            self.saved_checkpoints.append(snapshot)

        obsolete = ckpt_instruction.checkpoints_to_delete
        if obsolete:
            # A stored checkpoint survives only if it differs from every checkpoint marked for deletion.
            self.saved_checkpoints = [
                kept for kept in self.saved_checkpoints if all(kept != removed for removed in obsolete)
            ]
Loading