diff --git a/examples/v1/config/sft_qwen3_30b_async_verify.py b/examples/v1/config/sft_qwen3_30b_async_verify.py
new file mode 100644
index 0000000000..777823c41c
--- /dev/null
+++ b/examples/v1/config/sft_qwen3_30b_async_verify.py
@@ -0,0 +1,134 @@
+"""Verification config for async checkpoint with a ~30B parameter model.
+
+Uses the Qwen3-32B Dense architecture (64 layers, hidden=5120) to produce a
+~237 GB checkpoint (model ~79GB + optimizer ~158GB), large enough to stress
+async I/O.
+
+Hardware: 8x H200 (141GB each)
+Memory estimate per GPU (FSDP 8-way):
+  - Model weights (bf16): ~60GB / 8 = ~7.5GB
+  - Optimizer states (fp32): ~120GB / 8 = ~15.0GB
+  - Gradients (bf16): ~60GB / 8 = ~7.5GB
+  - Activations (grad ckpt): ~10-20GB
+  Total: ~40-50GB per GPU (well within 141GB)
+
+Requires >= 256 GB host memory per node for CPU staging.
+
+Environment variables:
+    ASYNC_CKPT    - "1" (default) for async, "0" for sync
+    SAVE_OPTIM    - "1" (default) to save optimizer state, "0" to skip it
+    CKPT_INTERVAL - checkpoint interval in steps (default: 50)
+    TOTAL_STEP    - total training steps (default: 500)
+    WORK_DIR      - override work directory path (optional)
+
+Usage:
+    # Async (default)
+    torchrun --nproc_per_node=8 xtuner/v1/train/cli/sft.py \
+        --config examples/v1/config/sft_qwen3_30b_async_verify.py \
+        2>&1 | tee logs/sft_async_qwen3_30b_$(date +%Y%m%d_%H%M%S).log
+
+    # Sync baseline
+    ASYNC_CKPT=0 torchrun --nproc_per_node=8 xtuner/v1/train/cli/sft.py \
+        --config examples/v1/config/sft_qwen3_30b_async_verify.py \
+        2>&1 | tee logs/sft_sync_qwen3_30b_$(date +%Y%m%d_%H%M%S).log
+
+    # Test on slow storage
+    WORK_DIR=/mnt/nfs/ckpt_bench torchrun --nproc_per_node=8 \
+        xtuner/v1/train/cli/sft.py \
+        --config examples/v1/config/sft_qwen3_30b_async_verify.py \
+        2>&1 | tee logs/sft_async_nfs_qwen3_30b_$(date +%Y%m%d_%H%M%S).log
+
+Analysis:
+    Compare logs by grepping for these key timing markers:
+    [Checkpoint Breakdown]     - per-checkpoint blocking time breakdown
+    [Async Checkpoint] Staging - GPU->CPU staging wait (async only)
+    [Async Checkpoint] Upload  - disk I/O wait (async only)
+    Training finished in       - total training wall-clock
+"""
+
+import os
+
+from xtuner.v1.config import AdamWConfig, LRConfig
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+from xtuner.v1.datasets.pt_tokenize_fn import PretrainTokenizeFunctionConfig
+from xtuner.v1.loss import CELossConfig
+from xtuner.v1.model.dense.qwen3 import Qwen3DenseConfig
+from xtuner.v1.module.attention import MHAConfig
+from xtuner.v1.train import TrainerConfig
+
+# ---------------------------------------------------------------------------
+# Environment switches
+# ---------------------------------------------------------------------------
+async_checkpoint = os.environ.get("ASYNC_CKPT", "1") != "0"
+checkpoint_save_optimizer = os.environ.get("SAVE_OPTIM", "1") != "0"
+checkpoint_interval = int(os.environ.get("CKPT_INTERVAL", "50"))
+total_step = int(os.environ.get("TOTAL_STEP", "500"))
+work_dir = os.environ.get("WORK_DIR", None)
+
+# ---------------------------------------------------------------------------
+# Model — Qwen3-32B Dense architecture (no pretrained weights, verification only)
+# ---------------------------------------------------------------------------
+model_cfg = Qwen3DenseConfig(
+    vocab_size=151936,
+    max_position_embeddings=32768,
+    bos_token_id=151643,
+    eos_token_id=151645,
+    num_hidden_layers=64,
+    max_window_layers=64,
+    hidden_size=5120,
+    intermediate_size=17408,
+    rms_norm_eps=1e-6,
+    rope_theta=1000000.0,
+    hidden_act="silu",
+    attention=MHAConfig(
+        num_attention_heads=40,
+        num_key_value_heads=8,
+        head_dim=128,
+        qk_norm=True,
+        sliding_window=None,
+    ),
+    tie_word_embeddings=False,
+)
+
+# ---------------------------------------------------------------------------
+# Data — reuse existing test data
+# ---------------------------------------------------------------------------
+dataset_config = [
+    {
+        "dataset": DatasetConfig(
+            name="pretrain_text",
+            anno_path="tests/resource/pretrain_example_data.jsonl",
+            sample_ratio=1.0,
+        ),
+        "tokenize_fn": PretrainTokenizeFunctionConfig(
+            add_bos_token=False,
+            add_eos_token=True,
+        ),
+    },
+]
+
+dataloader_config = DataloaderConfig(
+    dataset_config_list=dataset_config,
+    pack_max_length=4096,
+)
+
+# ---------------------------------------------------------------------------
+# Optimizer & LR
+# ---------------------------------------------------------------------------
+optim_cfg = AdamWConfig(lr=2e-5, foreach=True)
+lr_cfg = LRConfig(lr_type="cosine", warmup_ratio=0.05)
+
+# ---------------------------------------------------------------------------
+# Trainer
+# ---------------------------------------------------------------------------
+trainer = TrainerConfig(
+    model_cfg=model_cfg,
+    optim_cfg=optim_cfg,
+    dataloader_cfg=dataloader_config,
+    lr_cfg=lr_cfg,
+    loss_cfg=CELossConfig(mode="chunk", chunk_size=1024),
+    global_batch_size=32,
+    total_step=total_step,
+    checkpoint_interval=checkpoint_interval,
+    async_checkpoint=async_checkpoint,
+    checkpoint_save_optimizer=checkpoint_save_optimizer,
+    work_dir=work_dir,
+)
diff --git a/examples/v1/config/sft_qwen3_8b_async_verify.py b/examples/v1/config/sft_qwen3_8b_async_verify.py
new file mode 100644
index 0000000000..707dc62cd1
--- /dev/null
+++ b/examples/v1/config/sft_qwen3_8b_async_verify.py
@@ -0,0 +1,129 @@
+"""Verification config for async checkpoint with an ~8B parameter model.
+
+Uses the Qwen3-8B architecture (36 layers, hidden=4096) to produce an ~87 GB
+checkpoint (model + optimizer), large enough to demonstrate the async benefit.
+
+Requires >= 128 GB host memory per node for CPU staging.
+
+Environment variables:
+    ASYNC_CKPT    - "1" (default) for async, "0" for sync
+    SAVE_OPTIM    - "1" (default) to save optimizer state, "0" to skip it
+    CKPT_INTERVAL - checkpoint interval in steps (default: 10)
+    TOTAL_STEP    - total training steps (default: 100)
+    WORK_DIR      - override work directory path (optional)
+
+Usage:
+    # ---- Experiment 1: High-frequency checkpoint, local SSD ----
+    # Async
+    torchrun --nproc_per_node=8 xtuner/v1/train/cli/sft.py \
+        --config examples/v1/config/sft_qwen3_8b_async_verify.py \
+        2>&1 | tee logs/sft_async_highfreq_qwen3_8b_$(date +%Y%m%d_%H%M%S).log
+
+    # Sync baseline
+    ASYNC_CKPT=0 torchrun --nproc_per_node=8 xtuner/v1/train/cli/sft.py \
+        --config examples/v1/config/sft_qwen3_8b_async_verify.py \
+        2>&1 | tee logs/sft_sync_highfreq_qwen3_8b_$(date +%Y%m%d_%H%M%S).log
+
+    # ---- Experiment 2: Slow storage (NFS/HDFS mount) ----
+    # Point WORK_DIR to a network mount to amplify the I/O gap
+    WORK_DIR=/mnt/nfs/ckpt_bench torchrun --nproc_per_node=8 \
+        xtuner/v1/train/cli/sft.py \
+        --config examples/v1/config/sft_qwen3_8b_async_verify.py \
+        2>&1 | tee logs/sft_async_nfs_qwen3_8b_$(date +%Y%m%d_%H%M%S).log
+
+    ASYNC_CKPT=0 WORK_DIR=/mnt/nfs/ckpt_bench torchrun --nproc_per_node=8 \
+        xtuner/v1/train/cli/sft.py \
+        --config examples/v1/config/sft_qwen3_8b_async_verify.py \
+        2>&1 | tee logs/sft_sync_nfs_qwen3_8b_$(date +%Y%m%d_%H%M%S).log
+
+Analysis:
+    Compare logs by grepping for these key timing markers:
+    [Checkpoint Breakdown]      - per-checkpoint blocking time breakdown
+    [Checkpoint Total Blocking] - total wall-clock blocking per save
+    [Async Checkpoint] Staging  - GPU->CPU staging wait (async only)
+    [Async Checkpoint] Upload   - disk I/O wait (async only)
+    [DCP Collect Model/Optimizer State Dict] - state_dict collection time
+    [DCP save/async_save]       - actual save/async_save API time
+    Training finished in        - total training wall-clock
+
+Verification checklist:
+  - [ ] Async run exits cleanly (no hang after last step)
+  - [ ] Checkpoint files are complete: ls /checkpoints/ckpt-step-*/
+  - [ ] Loss/grad_norm matches between async and sync runs (diff < 1e-4)
+  - [ ] Compare total [Checkpoint Total Blocking] sums between runs
+  - [ ] Check if [Async Checkpoint] Staging shows non-zero waits
+        (means checkpoint_interval is tight enough to stress the pipeline)
+"""
+
+import os
+
+from xtuner.v1.config import AdamWConfig, LRConfig
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+from xtuner.v1.datasets.pt_tokenize_fn import PretrainTokenizeFunctionConfig
+from xtuner.v1.loss import CELossConfig
+from xtuner.v1.model import Qwen3Dense8BConfig
+from xtuner.v1.train import TrainerConfig
+
+# Toggle via environment variables for easy A/B comparison
+async_checkpoint = os.environ.get("ASYNC_CKPT", "1") != "0"
+checkpoint_save_optimizer = os.environ.get("SAVE_OPTIM", "1") != "0"
+
+# Configurable checkpoint frequency and total steps for different experiments:
+# - checkpoint_interval=10 with step_time~2s means ~20s between saves.
+#   Since each save takes ~11s (async) or ~14s (sync), the pipeline is
+#   stressed enough that wait_prev > 0 will appear in async mode,
+#   revealing the true overlap benefit.
+# - total_step=100 gives 10 checkpoint events, enough to average out
+#   cold-start effects while keeping wall-clock under 10 minutes.
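+# Back-of-envelope check (the ~11s/~14s save times above are assumptions from
+# one local 8-GPU run, not guarantees): 10 saves x ~14s of sync blocking is up
+# to ~140s of stalls that async mode can mostly hide, i.e. roughly a 20% swing
+# on a ~10 minute run. Recompute with your own measured step and save times.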
+checkpoint_interval = int(os.environ.get("CKPT_INTERVAL", "10")) +total_step = int(os.environ.get("TOTAL_STEP", "100")) + +# Optional: override work_dir to test on slow storage (NFS/HDFS) +work_dir = os.environ.get("WORK_DIR", None) + +# 36 layers with full hidden dimensions (~8.7B params) +model_cfg = Qwen3Dense8BConfig( + num_hidden_layers=36, + hidden_size=4096, + intermediate_size=14336, + vocab_size=151936, +) + +# Reuse existing test data +sample_max_length = 4096 +pack_max_length = 4096 + +dataset_config = [ + { + "dataset": DatasetConfig( + name="pretrain_text", + anno_path="tests/resource/pretrain_example_data.jsonl", + sample_ratio=1.0, + ), + "tokenize_fn": PretrainTokenizeFunctionConfig( + add_bos_token=False, + add_eos_token=True, + ), + }, +] + +dataloader_config = DataloaderConfig( + dataset_config_list=dataset_config, + pack_max_length=pack_max_length, +) + +optim_cfg = AdamWConfig(lr=2e-5, foreach=True) +lr_cfg = LRConfig(lr_type="cosine", warmup_ratio=0.05) + +trainer = TrainerConfig( + model_cfg=model_cfg, + optim_cfg=optim_cfg, + dataloader_cfg=dataloader_config, + lr_cfg=lr_cfg, + loss_cfg=CELossConfig(mode="chunk", chunk_size=1024), + global_batch_size=32, + total_step=total_step, + checkpoint_interval=checkpoint_interval, + async_checkpoint=async_checkpoint, + checkpoint_save_optimizer=checkpoint_save_optimizer, + work_dir=work_dir, +) diff --git a/examples/v1/scripts/bench_async_checkpoint.sh b/examples/v1/scripts/bench_async_checkpoint.sh new file mode 100755 index 0000000000..346eadd090 --- /dev/null +++ b/examples/v1/scripts/bench_async_checkpoint.sh @@ -0,0 +1,170 @@ +#!/usr/bin/env bash +# Async checkpoint A/B benchmark script. +# +# Runs async first, removes checkpoints, then runs sync, +# and removes checkpoints again after the sync run. +# +# Usage: +# bash examples/v1/scripts/bench_async_checkpoint.sh 8b +# bash examples/v1/scripts/bench_async_checkpoint.sh 30b +# +# Env: +# CKPT_INTERVAL - checkpoint interval in steps (default: 20) +# TOTAL_STEP - total training steps (default: 100) +# WORK_DIR - override checkpoint save path +# NPROC - number of GPUs (default: 8) +# SAVE_OPTIM - "1" (default) save optimizer, "0" skip optimizer + +set -euo pipefail + +MODEL="${1:?Usage: $0 <8b|30b>}" +NPROC="${NPROC:-8}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +LOG_DIR="logs" + +export CKPT_INTERVAL="${CKPT_INTERVAL:-20}" +export TOTAL_STEP="${TOTAL_STEP:-100}" +export SAVE_OPTIM="${SAVE_OPTIM:-1}" + +mkdir -p "$LOG_DIR" +export PYTHONPATH="$(pwd):${PYTHONPATH:-}" + + +case "$MODEL" in + 8b) CONFIG="examples/v1/config/sft_qwen3_8b_async_verify.py" ;; + 30b) CONFIG="examples/v1/config/sft_qwen3_30b_async_verify.py" ;; + *) echo "Unknown model: $MODEL (use 8b or 30b)"; exit 1 ;; +esac + +WORK_DIR="${WORK_DIR:-$(pwd)}" +export WORK_DIR + +SUFFIX="${MODEL}_${TIMESTAMP}_interval_${CKPT_INTERVAL}" +if [ "$SAVE_OPTIM" = "0" ]; then + SUFFIX="${SUFFIX}_no_optim" +fi + +SYNC_LOG="$LOG_DIR/xtuner_sync_${SUFFIX}.log" +ASYNC_LOG="$LOG_DIR/xtuner_async_${SUFFIX}.log" + +# Extract the Trainer's experiment directory (WORK_DIR/) from a log +# file by parsing the DCP save path. +get_exp_dir_from_log() { + local log_file="$1" + # Match: DCP save for /checkpoints/... or DCP async_save for /checkpoints/... 
+    # The experiment dir is everything before "/checkpoints/".
+    # Use \K rather than a lookbehind: PCRE rejects variable-length
+    # lookbehind assertions such as (?<=DCP (save|async_save) for ).
+    local exp_dir
+    exp_dir=$(grep -oP 'DCP (save|async_save) for \K\S+' "$log_file" \
+        | head -1 \
+        | sed 's|/checkpoints/.*||')
+    echo "$exp_dir"
+}
+
+cleanup_checkpoints() {
+    local tag="$1"
+    local log_file="$2"
+
+    echo "--------------------------------------------"
+    echo " Cleaning checkpoints after ${tag} run"
+    echo "--------------------------------------------"
+
+    local exp_dir
+    exp_dir=$(get_exp_dir_from_log "$log_file")
+
+    if [ -n "$exp_dir" ] && [ -d "$exp_dir" ]; then
+        du -sh "$exp_dir" 2>/dev/null || true
+        rm -rf "$exp_dir"
+        echo "Removed experiment dir: $exp_dir"
+    else
+        echo "No experiment dir found from log: $log_file"
+    fi
+
+    if [ -f "${WORK_DIR}/.xtuner" ]; then
+        rm -f "${WORK_DIR}/.xtuner"
+        echo "Removed meta file: ${WORK_DIR}/.xtuner"
+    fi
+
+    if [ -f "${WORK_DIR}/meta.json" ]; then
+        rm -f "${WORK_DIR}/meta.json"
+        echo "Removed legacy meta file: ${WORK_DIR}/meta.json"
+    fi
+
+    echo ""
+}
+
+echo "============================================"
+echo " Async Checkpoint Benchmark — ${MODEL^^}"
+echo "============================================"
+echo "Config: $CONFIG"
+echo "GPUs: $NPROC"
+echo "Total steps: $TOTAL_STEP"
+echo "Save interval: $CKPT_INTERVAL"
+echo "Save optimizer: $SAVE_OPTIM"
+echo "Work dir: $WORK_DIR"
+echo "Run order: async -> cleanup -> sync -> cleanup"
+echo "Sync log: $SYNC_LOG"
+echo "Async log: $ASYNC_LOG"
+echo ""
+
+# ---- Run 1/2: Async ----
+echo "============================================"
+echo " [1/2] Running ASYNC mode..."
+echo "============================================"
+env ASYNC_CKPT=1 SAVE_OPTIM="$SAVE_OPTIM" \
+    torchrun --nproc_per_node="$NPROC" xtuner/v1/train/cli/sft.py \
+    --config "$CONFIG" \
+    2>&1 | tee "$ASYNC_LOG"
+echo ""
+
+# ---- Delete async checkpoints before sync run ----
+cleanup_checkpoints "ASYNC" "$ASYNC_LOG"
+
+# ---- Run 2/2: Sync ----
+echo "============================================"
+echo " [2/2] Running SYNC mode..."
+echo "============================================" +env ASYNC_CKPT=0 SAVE_OPTIM="$SAVE_OPTIM" \ + torchrun --nproc_per_node="$NPROC" xtuner/v1/train/cli/sft.py \ + --config "$CONFIG" \ + 2>&1 | tee "$SYNC_LOG" +echo "" + +# ---- Delete sync checkpoints after sync run ---- +cleanup_checkpoints "SYNC" "$SYNC_LOG" + +# ---- Extract & compare ---- +echo "============================================" +echo " Results Summary" +echo "============================================" +echo "" + +echo "--- Total Training Time ---" +echo "Sync: $(grep 'Training finished in' "$SYNC_LOG" | head -1 || echo '(not found)')" +echo "Async: $(grep 'Training finished in' "$ASYNC_LOG" | head -1 || echo '(not found)')" +echo "" + +echo "--- Checkpoint Breakdown ---" +echo "[Sync]" +grep -E '\[Checkpoint Breakdown\]' "$SYNC_LOG" | head -20 || echo " (not found)" +echo "" +echo "[Async]" +grep -E '\[Checkpoint Breakdown\]' "$ASYNC_LOG" | head -20 || echo " (not found)" +echo "" + +echo "--- Async Staging/Upload Waits ---" +grep -E '\[Async Checkpoint\]' "$ASYNC_LOG" | head -20 || echo " (not found)" +echo "" + +echo "--- Final Loss & Grad Norm ---" +echo "Sync: $(grep -E 'loss=|reduced_llm_loss:|grad_norm:' "$SYNC_LOG" | tail -1 || echo '(not found)')" +echo "Async: $(grep -E 'loss=|reduced_llm_loss:|grad_norm:' "$ASYNC_LOG" | tail -1 || echo '(not found)')" +echo "" + +echo "--- GPU Memory ---" +echo "Sync: $(grep -E 'max_memory' "$SYNC_LOG" | tail -1 || echo '(not found)')" +echo "Async: $(grep -E 'max_memory' "$ASYNC_LOG" | tail -1 || echo '(not found)')" +echo "" + +echo "Logs saved to:" +echo " $SYNC_LOG" +echo " $ASYNC_LOG" diff --git a/tests/engine/test_moe_train_engine.py b/tests/engine/test_moe_train_engine.py index adb3e4615f..023ebf4595 100644 --- a/tests/engine/test_moe_train_engine.py +++ b/tests/engine/test_moe_train_engine.py @@ -310,15 +310,15 @@ def test_checkpoint_save_load(self, device, dispatcher, ep_size, load_from_type) engine.from_hf(load_from, strict=not tiny_model) dist.barrier() - model_dir, optimizer_dir = tmpdir / "model", tmpdir / "optimizer" - engine.save_dcp(model_dir=model_dir, optimizer_dir=optimizer_dir) + weights_dir = tmpdir / "weights" + engine.save_dcp(weights_dir=weights_dir) dist.barrier() time.sleep(1) engine2 = create_engine_from_hf(load_from, dispatcher, ep_size, tiny=tiny_model) engine2.init_model_weights() - engine2.load_dcp(model_dir=model_dir, optimizer_dir=optimizer_dir) + engine2.load_dcp(weights_dir=weights_dir) # 3. 
check # check the model state state_dict = engine.model.state_dict() @@ -379,8 +379,7 @@ def test_load_optimizer_with_new_lr(self, device): temp_dir = [None] dist.broadcast_object_list(temp_dir, src=0) temp_dir = Path(temp_dir[0]) - model_dir = temp_dir / "model" - optimizer_dir = temp_dir / "optimizer" + weights_dir = temp_dir / "weights" moe_cfg = Qwen3MoE30BA3Config( num_hidden_layers=2, ) @@ -394,7 +393,7 @@ def test_load_optimizer_with_new_lr(self, device): fsdp_cfg=fsdp_cfg, ) engine.init_model_weights() - engine.save_dcp(model_dir=model_dir, optimizer_dir=optimizer_dir) + engine.save_dcp(weights_dir=weights_dir) dist.barrier() time.sleep(1) @@ -406,7 +405,7 @@ def test_load_optimizer_with_new_lr(self, device): optim_cfg=optim_cfg2, fsdp_cfg=fsdp_cfg, ) - engine2.load_dcp(model_dir=model_dir, optimizer_dir=optimizer_dir, load_args=False) + engine2.load_dcp(weights_dir=weights_dir, load_args=False) assert len(engine.optimizer.state) == len(engine2.optimizer.state) assert len(engine.optimizer.state) != 0 for param_group in engine2.optimizer.param_groups: @@ -421,7 +420,7 @@ def test_load_optimizer_with_new_lr(self, device): optim_cfg=optim_cfg3, fsdp_cfg=fsdp_cfg, ) - engine3.load_dcp(model_dir=model_dir, optimizer_dir=optimizer_dir, load_states=False) + engine3.load_dcp(weights_dir=weights_dir, load_states=False) assert len(engine3.optimizer.state) == 0 for param_group in engine3.optimizer.param_groups: assert param_group['lr'] == lr1 diff --git a/tests/train/test_trainer.py b/tests/train/test_trainer.py index fcf3d5cd97..845453af0b 100644 --- a/tests/train/test_trainer.py +++ b/tests/train/test_trainer.py @@ -5,6 +5,7 @@ import pickle import shutil import weakref +from concurrent.futures import Future from pydantic import TypeAdapter import torch @@ -79,10 +80,17 @@ def clip_grad_norm(self, do_clip: bool=True, dtype=torch.float32): load_dcp = Mock() - def save_dcp(self, model_dir: Path, optimizer_dir: Path | None): - model_dir.mkdir(parents=True, exist_ok=True) - if optimizer_dir is not None: - optimizer_dir.mkdir(parents=True, exist_ok=True) + def save_dcp(self, weights_dir: Path, save_optimizer: bool = True): + weights_dir.mkdir(parents=True, exist_ok=True) + + def async_save_dcp(self, weights_dir: Path) -> Future: + weights_dir.mkdir(parents=True, exist_ok=True) + f: Future = Future() + f.set_result(None) + return f + + def destroy_async_checkpoint_pg(self) -> None: + pass def prepare(fn): @@ -258,6 +266,56 @@ def test_save_checkpoint_interval(self): assert f"step-{step}" in str(checkpoint) assert os.path.exists(checkpoint) + @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True)) + @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine())) + @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[])) + @prepare + def test_async_save_checkpoint_interval(self): + self.create_pg(DEVICE) + work_dir_list = [self.work_dir] + dist.broadcast_object_list(work_dir_list, src=0) + self.work_dir = Path(work_dir_list[0]) + model_cfg = Qwen3MoE30BA3Config() + optim_cfg = AdamWConfig(lr=1e-4, weight_decay=0.01) + fsdp_cfg = FSDPConfig(tp_size=1) + dataset_cfg = [ + { + "dataset": DatasetConfig(name="alpaca", anno_path=self.alpaca_path, sample_ratio=1.0), + "tokenize_fn": FTDPTokenizeFnConfig(), + }, + ] + dataloader_cfg = DataloaderConfig() + lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0.1, lr_min=1e-6) + + trainer = Trainer( + load_from=str(self.fake_hf_model_dir), + model_cfg=model_cfg, + 
optim_cfg=optim_cfg, + fsdp_cfg=fsdp_cfg, + dataset_cfg=dataset_cfg, + dataloader_cfg=dataloader_cfg, + lr_cfg=lr_cfg, + tokenizer_path=self.tokenizer_path, + global_batch_size=2, + total_step=10, + work_dir=str(self.work_dir), + hf_interval=3, + hf_max_keep=2, + seed=42, + debug=False, + checkpoint_interval=5, + async_checkpoint=True, + ) + + trainer.fit() + dist.barrier() + assert len(trainer.meta.latest_exp.checkpoint_list) == 2 + for checkpoint, step in zip(trainer.meta.latest_exp.checkpoint_list, [5, 10]): + assert f"step-{step}" in str(checkpoint) + assert os.path.exists(checkpoint) + weights_dir = Path(checkpoint) / "weights" + assert weights_dir.exists(), f"Expected weights dir at {weights_dir}" + @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True)) @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine())) @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[])) @@ -824,8 +882,7 @@ def test_resume_and_load_checkpoint_cfg(tmp_path: Path): mock_data_load_state_dict.assert_not_called() mock_lr_load_state_dict.assert_not_called() mock_load_dcp.assert_called_once_with( - model_dir=checkpoint_path/Trainer._SAVE_MODEL_DIR, - optimizer_dir=checkpoint_path/Trainer._SAVE_OPTIMIZER_DIR, + weights_dir=checkpoint_path / Trainer._SAVE_WEIGHTS_DIR, load_states=True, load_args=True, ) @@ -859,8 +916,7 @@ def test_resume_and_load_checkpoint_cfg(tmp_path: Path): mock_data_load_state_dict.assert_called_once() mock_lr_load_state_dict.assert_called_once() mock_load_dcp.assert_called_once_with( - model_dir=latest_checkpoint/Trainer._SAVE_MODEL_DIR, - optimizer_dir=latest_checkpoint/Trainer._SAVE_OPTIMIZER_DIR, + weights_dir=latest_checkpoint / Trainer._SAVE_WEIGHTS_DIR, load_states=True, load_args=True, ) diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index 8414d8dc13..517568a50b 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -1,7 +1,11 @@ +from __future__ import annotations + import json import os +import shutil import threading -from concurrent.futures import wait +import time +from concurrent.futures import Future, wait from pathlib import Path from typing import Any, Dict, List, cast @@ -9,6 +13,7 @@ import torch.distributed as dist import torch.distributed.checkpoint as dcp from safetensors import safe_open +from torch.distributed.checkpoint.filesystem import FileSystemWriter from torch.distributed.checkpoint.state_dict import ( StateDictOptions, get_model_state_dict, @@ -134,6 +139,7 @@ def __init__( optim_cfg: OptimConfig, fsdp_cfg: FSDPConfig, intra_layer_micro_batch: int = 1, + async_checkpoint: bool = False, ) -> None: self.model_cfg = model_cfg self.optim_cfg = optim_cfg @@ -143,6 +149,13 @@ def __init__( self.intra_layer_micro_batch = intra_layer_micro_batch self._count = 0 self.has_freeze_params = self.__has_freeze_params() + self._async_checkpoint_pg: dist.ProcessGroup | None = None + self._async_state_dict_cache: dict[str, Any] | None = None + if async_checkpoint: + # dcp.async_save() performs collectives from a background thread. + # Keep those gloo collectives off the training NCCL process group + # to avoid cross-thread communication conflicts. 
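+            # (The same gloo group is also reused for the commit-thread
+            # barriers in async_save_dcp below.)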
+ self._async_checkpoint_pg = dist.new_group(backend="gloo") def __has_freeze_params(self) -> bool: has_freeze_params = False @@ -294,98 +307,179 @@ def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16): """ self.model.save_hf(hf_dir=hf_dir, save_dtype=save_dtype) - # TODO: Support async save + def _get_dcp_state_dict( + self, + *, + cpu_offload: bool, + save_optimizer: bool = True, + ) -> dict[str, Any]: + options = StateDictOptions( + cpu_offload=cpu_offload, + ignore_frozen_params=self.model_cfg.dcp_ignore_frozen_params, + ) + state_dict: dict[str, Any] = {} + + with profile_time_and_memory("[DCP Collect Model State Dict]"): + state_dict["model"] = get_model_state_dict(self.model, options=options) + + if save_optimizer: + with profile_time_and_memory("[DCP Collect Optimizer State Dict]"): + state_dict["optimizer"] = get_optimizer_state_dict(self.model, self.optimizer, options=options) + + return state_dict + def save_dcp( self, - model_dir: Path, - optimizer_dir: Path | None = None, - ): - rank = dist.get_rank() - - if rank == 0: - model_dir.mkdir(parents=True, exist_ok=True) - if optimizer_dir is not None: - optimizer_dir.mkdir(parents=True, exist_ok=True) - - _options = StateDictOptions(cpu_offload=True, ignore_frozen_params=self.model_cfg.dcp_ignore_frozen_params) - with profile_time_and_memory(f"[DCP Checkpoint to {model_dir}]"): - model_state = get_model_state_dict(self.model, options=_options) + weights_dir: Path, + save_optimizer: bool = True, + ) -> None: + if dist.get_rank() == 0: + weights_dir.mkdir(parents=True, exist_ok=True) + + state_dict = self._get_dcp_state_dict(cpu_offload=True, save_optimizer=save_optimizer) + + with profile_time_and_memory(f"[DCP save for {weights_dir}]"): dcp.save( - model_state, - checkpoint_id=model_dir, + state_dict, + checkpoint_id=weights_dir, ) - with profile_time_and_memory(f"[DCP Checkpoint to {optimizer_dir}]"): - if optimizer_dir is not None: - shard_optimizer_state_dict = get_optimizer_state_dict(self.model, self.optimizer, options=_options) - dcp.save( - shard_optimizer_state_dict, - checkpoint_id=optimizer_dir, - ) + def async_save_dcp( + self, + weights_dir: Path, + save_optimizer: bool = True, + ) -> Future: + + # Match async HF export semantics: write the DCP payload into a + # temporary .incomplete directory and commit it only after every rank's + # async_save future has completed. 
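+        # A crash mid-save therefore leaves only a *.incomplete directory
+        # behind; the next save to the same path removes and rewrites it.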
+        incomplete_dir = weights_dir.with_name(f"{weights_dir.name}.incomplete")
+        if dist.get_rank() == 0:
+            if incomplete_dir.exists():
+                shutil.rmtree(incomplete_dir)
+            incomplete_dir.mkdir(parents=True, exist_ok=True)
+
+        state_dict = self._get_dcp_state_dict(cpu_offload=False, save_optimizer=save_optimizer)
+        storage_writer = self._build_async_storage_writer(incomplete_dir, save_optimizer=save_optimizer)
+
+        t0 = time.time()
+        with profile_time_and_memory(f"[DCP async_save for {weights_dir}]"):
+            dcp_future = dcp.async_save(
+                state_dict,
+                checkpoint_id=incomplete_dir,
+                storage_writer=storage_writer,
+                process_group=self._async_checkpoint_pg,
+            )
+
+        committed_future: Future = Future()
+
+        def commit_async_save(done_future: Future) -> None:
+            def commit() -> None:
+                try:
+                    done_future.result()
+                    if self._async_checkpoint_pg is not None:
+                        dist.barrier(group=self._async_checkpoint_pg)
+                    if dist.get_rank() == 0:
+                        if weights_dir.exists():
+                            shutil.rmtree(weights_dir)
+                        incomplete_dir.rename(weights_dir)
+                    if self._async_checkpoint_pg is not None:
+                        dist.barrier(group=self._async_checkpoint_pg)
+                except BaseException as exc:
+                    elapsed = time.time() - t0
+                    logger.error(f"[DCP async_save for {weights_dir}] failed after {elapsed:.2f}s: {exc}")
+                    if not committed_future.done():
+                        committed_future.set_exception(exc)
+                    return
+
+                elapsed = time.time() - t0
+                logger.info(f"[DCP async_save for {weights_dir}] finished in {elapsed:.2f}s")
+                if not committed_future.done():
+                    committed_future.set_result(None)
+
+            threading.Thread(target=commit, daemon=True).start()
+
+        self._async_state_dict_cache = storage_writer.state_dict_cache
+        dcp_future.add_done_callback(commit_async_save)
+        return committed_future
+
+    def _build_async_storage_writer(self, weights_dir: Path, *, save_optimizer: bool) -> FileSystemWriter:
+        # cache_staged_state_dict keeps pinned staging buffers on the
+        # FileSystemWriter instance. XTuner creates one writer per checkpoint
+        # path, so carry the cache across writers to preserve steady-state
+        # async_save launch performance.
+        if self._async_state_dict_cache is not None:
+            cached_has_optim = "optimizer" in self._async_state_dict_cache
+            if cached_has_optim != save_optimizer:
+                self._async_state_dict_cache = None
+
+        storage_writer = FileSystemWriter(weights_dir, cache_staged_state_dict=True)
+        storage_writer.state_dict_cache = self._async_state_dict_cache
+        return storage_writer
+
+    def destroy_async_checkpoint_pg(self) -> None:
+        """Destroy the dedicated gloo process group used for async checkpoint."""
+        self._async_state_dict_cache = None
+        if self._async_checkpoint_pg is not None:
+            if dist.is_available() and dist.is_initialized():
+                dist.destroy_process_group(self._async_checkpoint_pg)
+            self._async_checkpoint_pg = None
+
+    def __del__(self) -> None:
+        try:
+            self.destroy_async_checkpoint_pg()
+        except Exception:
+            pass
 
     def load_dcp(
         self,
-        model_dir: Path,
-        optimizer_dir: Path | None = None,
+        weights_dir: Path,
         load_states: bool = True,
         load_args: bool = True,
-    ):
-        """Load the dcp model from the given directory.
+    ) -> None:
+        """Load a DCP checkpoint saved in the merged weights format.
 
-        Args:
-            dcp_dir (str): The directory to load the model from.
+        Checkpoints saved with checkpoint_save_optimizer=False contain only
+        model weights; pass load_states=False and load_args=False when loading
+        them, since the optimizer entries they would request do not exist.
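+
+        Args:
+            weights_dir (Path): Directory containing the DCP checkpoint
+                (the merged model + optimizer layout written by save_dcp).
+            load_states (bool): Restore optimizer states from the checkpoint.
+            load_args (bool): Restore optimizer hyper-parameters (lr, betas,
+                eps, weight_decay, ...) instead of keeping the current ones.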
""" - _load_options = StateDictOptions( - cpu_offload=True, ignore_frozen_params=self.model_cfg.dcp_ignore_frozen_params - ) + load_optimizer = load_states or load_args + state_dict = self._get_dcp_state_dict(cpu_offload=True, save_optimizer=load_optimizer) + if self.has_freeze_params: - _set_options = StateDictOptions(cpu_offload=True, strict=False) + set_options = StateDictOptions(cpu_offload=True, strict=False) else: - _set_options = StateDictOptions(cpu_offload=True, strict=True) - with profile_time_and_memory(f"[Load DCP Model from {model_dir}]"): - shard_model_state_dict = get_model_state_dict(self.model, options=_load_options) - # inplace state_dict - dcp.load( - state_dict=shard_model_state_dict, - checkpoint_id=model_dir, + set_options = StateDictOptions(cpu_offload=True, strict=True) + + with profile_time_and_memory(f"[Load DCP from {weights_dir}]"): + dcp.load(state_dict=state_dict, checkpoint_id=weights_dir) + + set_model_state_dict(self.model, state_dict["model"], options=set_options) + + if not load_optimizer: + return + + optimizer_state_dict = state_dict["optimizer"] + if not load_states: + logger.info("Not loading optimizer states") + optimizer_state_dict["state"] = {} + if not load_args: + logger.info("Not loading arg defaults") + param_groups = self.optimizer.state_dict()["param_groups"] + assert len(param_groups) == 1, "Only one param_group is supported now" + init_defaults = param_groups[0] + init_defaults.pop("params") + for param_group in cast(List[Dict[str, Any]], optimizer_state_dict["param_groups"]): + default_keys = list(filter(lambda x: x != "params", param_group.keys())) + for key in default_keys: + param_group.pop(key) + param_group.update(init_defaults) + + set_optimizer_state_dict( + self.model, + self.optimizer, + optim_state_dict=optimizer_state_dict, + options=set_options, ) - set_model_state_dict(self.model, shard_model_state_dict, options=_set_options) - - if optimizer_dir is not None: - with profile_time_and_memory(f"[Load DCP Optimizer] from {optimizer_dir}"): - shard_optimizer_state_dict = get_optimizer_state_dict( - self.model, self.optimizer, options=_load_options - ) - dcp.load( - state_dict=shard_optimizer_state_dict, - checkpoint_id=optimizer_dir, - ) - if not load_states: - logger.info("Not loading optimizer states") - shard_optimizer_state_dict["state"] = {} - if not load_args: - logger.info("Not loading arg defaults") - param_groups = self.optimizer.state_dict()["param_groups"] - # Now we only support one param_group. If we want to support different lr for different parameters, - # we may use multiple param_groups like: - # [{'params': ['net1.weight', 'net2.weight'], 'lr': 0.001}, {'params': ['net3.weight'], 'lr': 0.002}] - # Then we need change the code here - assert len(param_groups) == 1, "Only one param_group is supported now" - init_defaults = param_groups[0] - init_defaults.pop("params") - for param_group in cast(List[Dict[str, Any]], shard_optimizer_state_dict["param_groups"]): - # param_group is like: {'params': ['net1.weight', 'net2.weight'], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.01} - default_keys = list(filter(lambda x: x != "params", param_group.keys())) - for key in default_keys: - param_group.pop(key) - param_group.update(init_defaults) # lr, betas, eps, etc. 
- - set_optimizer_state_dict( - self.model, - self.optimizer, - optim_state_dict=shard_optimizer_state_dict, - options=_set_options, - ) def put_model_to_device(self, device: torch.device | str): """Put the model to the given device.""" diff --git a/xtuner/v1/train/arguments/arguments.py b/xtuner/v1/train/arguments/arguments.py index f763cd4eea..c340addb29 100644 --- a/xtuner/v1/train/arguments/arguments.py +++ b/xtuner/v1/train/arguments/arguments.py @@ -105,6 +105,7 @@ class TrainingArguments(BaseModel): load_scheduler: Annotated[bool, Parameter(group=checkpoint_group, help="load scheduler state from checkpoint")] = ( True ) + async_checkpoint: Annotated[bool, Parameter(group=checkpoint_group, help="enable async checkpoint saving")] = False fsdp_config: Annotated[FSDPConfig | None, Parameter(group=parallel_group, help="FSDP configuration")] = None float8_config: Annotated[Float8Config | None, Parameter(group=parallel_group, help="use float8 training")] = None @@ -165,6 +166,7 @@ def to_trainer_config(self): total_epoch=self.epoch_num, resume_cfg=resume_cfg, work_dir=self.work_dir, + async_checkpoint=self.async_checkpoint, ) def _get_dataset_config(self) -> DatasetConfigList: diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 26ec589191..409edb2420 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -6,11 +6,14 @@ import pickle import sys import time +from concurrent.futures import Future, TimeoutError from contextlib import contextmanager from datetime import datetime, timedelta from pathlib import Path from shutil import rmtree -from typing import Annotated, Callable, Literal, Protocol, Sequence, Sized, cast, overload, runtime_checkable +from typing import ( + Annotated, Callable, Literal, Protocol, Sequence, Sized, cast, overload, runtime_checkable, +) import torch import torch.distributed as dist @@ -334,8 +337,10 @@ class TrainerConfig(BaseModel): strict_load: bool = True checkpoint_interval: int | None = -1 checkpoint_maxkeep: int | None = -1 + checkpoint_save_optimizer: bool = True skip_checkpoint_validation: bool = False # Suggest enabled if fsdp_size is larger than 512 patch_for_dcp_finish: bool = False + async_checkpoint: bool = False snapshot_interval: int | None = None check_health_interval: int | None = None hf_interval: int | None = None @@ -427,8 +432,7 @@ class Trainer: _EXP_TRACKING_PATH = "exp_tracking" _CHECKPOINT_DIR = "checkpoints" - _SAVE_OPTIMIZER_DIR = "optimizer" - _SAVE_MODEL_DIR = "model" + _SAVE_WEIGHTS_DIR = "weights" _SAVE_DATALOADER_DIR = "dataloader" _SAVE_SCHEDULER_DIR = "lr_scheduler" _SAVE_TRAIN_STATE_PATH = "train_state.json" @@ -458,8 +462,10 @@ def __init__( strict_load: bool = True, checkpoint_interval: int | None = -1, checkpoint_maxkeep: int | None = -1, + checkpoint_save_optimizer: bool = True, skip_checkpoint_validation: bool = False, # Suggest enabled if fsdp_size is larger than 512 patch_for_dcp_finish: bool = False, + async_checkpoint: bool = False, snapshot_interval: int | None = None, check_health_interval: int | None = None, hf_interval: int | None = None, @@ -522,6 +528,9 @@ def __init__( self._checkpoint_interval = checkpoint_interval self._checkpoint_maxkeep = checkpoint_maxkeep + self._checkpoint_save_optimizer = checkpoint_save_optimizer + self._async_checkpoint = async_checkpoint + self._pending_checkpoint: Future | None = None self._snapshot_interval = snapshot_interval self._check_health_interval = check_health_interval self._hf_max_keep = hf_max_keep @@ -634,6 +643,7 @@ def __init__( 
load_checkpoint_path=self._load_checkpoint_cfg.checkpoint_path, strict=strict_load, intra_layer_micro_batch=intra_layer_micro_batch, + async_checkpoint=self._async_checkpoint, ) self._lr_cfg = lr_cfg self._lr_scheduler = self.build_lr_scheduler(lr_cfg, self.total_step) @@ -692,8 +702,10 @@ def from_config(cls, config: TrainerConfig) -> Self: strict_load=config.strict_load, checkpoint_interval=config.checkpoint_interval, checkpoint_maxkeep=config.checkpoint_maxkeep, + checkpoint_save_optimizer=config.checkpoint_save_optimizer, skip_checkpoint_validation=config.skip_checkpoint_validation, patch_for_dcp_finish=config.patch_for_dcp_finish, + async_checkpoint=config.async_checkpoint, snapshot_interval=config.snapshot_interval, check_health_interval=config.check_health_interval, hf_interval=config.hf_interval, @@ -793,10 +805,13 @@ def fit(self): gc.collect() # TODO: Should use flush rather than close + self._wait_for_pending_checkpoint() + self._engine.destroy_async_checkpoint_pg() self._exp_tracker.close() if self._metrics_recorder: self._metrics_recorder.close() self.logger.info(f"Training finished in {time.time() - train_begin:.2f} seconds") + dist.barrier() def _prepare_model_input(self, data_batch) -> list[ModelItem]: seq_ctx_list: list[SequenceContext] = [] @@ -1005,6 +1020,7 @@ def build_engine( load_checkpoint_path: str | Path | None, intra_layer_micro_batch: int = 1, strict: bool = True, + async_checkpoint: bool = False, ): """Build the training engine for the transformer model. @@ -1016,6 +1032,7 @@ def build_engine( resume_cfg (ResumeConfig | None): Resume configuration for continuing training. intra_layer_micro_batch (int): Intra-layer micro batch size for gradient accumulation. strict (bool): Whether to strictly load model weights. + async_checkpoint (bool): Whether to create a dedicated gloo process group for async checkpoint. Returns: TrainEngine: Initialized training engine. 
@@ -1025,6 +1042,7 @@ def build_engine( fsdp_cfg=fsdp_config, model_cfg=model_config, intra_layer_micro_batch=intra_layer_micro_batch, + async_checkpoint=async_checkpoint, ) if model_path is not None and (model_config.dcp_ignore_frozen_params or load_checkpoint_path is None): engine.from_hf(hf_path=model_path, strict=strict) @@ -1092,6 +1110,83 @@ def _maybe_check_health(self): raise RuntimeError("Health check failed, exit training") logger.info(f"Health check passed at step {self.cur_step}") + def _wait_for_pending_checkpoint(self, timeout: int = 300) -> None: + if self._pending_checkpoint is None: + return + + future = self._pending_checkpoint + self._pending_checkpoint = None + + try: + future.result(timeout=timeout) + except TimeoutError: + future.cancel() + raise TimeoutError(f"Async checkpoint timed out after {timeout}s") + + def _finalize_checkpoint_metadata( + self, + checkpoint_path: Path, + meta_path: Path, + train_state_path: Path, + is_snapshot: bool, + cur_step: int, + cur_epoch: int, + total_consumed_tokens: int, + train_time_offset: float, + ) -> None: + # Save train state + if self.rank == 0: + with train_state_path.open("w") as f: + f.write( + json.dumps( + { + "cur_step": cur_step, + "cur_epoch": cur_epoch, + "total_consumed_tokens": total_consumed_tokens, + "train_time_offset": train_time_offset, + } + ) + ) + + # Update meta + current_exp = self.meta.latest_exp + ckp_list = current_exp.checkpoint_list if not is_snapshot else current_exp.snap_checkpoint_list + ckp_list.append(str(checkpoint_path)) + current_exp.cur_step = cur_step + current_exp.cur_epoch = cur_epoch + current_exp.consumed_tokens = int(total_consumed_tokens) + current_exp.history[-1]["end"] = cur_step + + # Delete checkpoints and update meta's checkpoint_list + ckp_maxkeep = self._checkpoint_maxkeep if not is_snapshot else 1 + if ckp_maxkeep is not None and ckp_maxkeep > 0 and len(ckp_list) > ckp_maxkeep: + ckp_pop_num = len(ckp_list) - ckp_maxkeep + for _ in range(ckp_pop_num): + deleted_ckp = ckp_list.pop(0) + if self.rank == 0 and Path(deleted_ckp).exists(): + rmtree(deleted_ckp) + + # Save meta, must after deleting checkpoints to ensure the checkpoint_list is updated in the meta file + if self.rank == 0: + with meta_path.open("w") as f: + f.write(self.meta.model_dump_json(indent=2)) + + dist.barrier() + + if is_snapshot: + hooks = self.hooks_config.get_hooks(HookStage.AFTER_SAVE_SNAPSHOT) + else: + hooks = self.hooks_config.get_hooks(HookStage.AFTER_SAVE_DCP) + + for hook in hooks: + hook( + checkpoint=checkpoint_path, + step=cur_step, + epoch=cur_epoch, + total_step=self.total_step, + total_epoch=self.total_epoch, + ) + def _maybe_save(self, is_snapshot: bool = False) -> bool: ckp_interval = self._checkpoint_interval if not is_snapshot else self._snapshot_interval if ckp_interval is None: @@ -1109,26 +1204,29 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: checkpoint_path = self._get_checkpoint_path(epoch=self._cur_epoch, step=self.cur_step, is_snapshot=is_snapshot) checkpoint_path.mkdir(parents=True, exist_ok=True) + self._wait_for_pending_checkpoint() + meta_path = self.work_dir / self._META_PATH - optimizer_path = checkpoint_path / self._SAVE_OPTIMIZER_DIR - model_path = checkpoint_path / self._SAVE_MODEL_DIR + weights_path = checkpoint_path / self._SAVE_WEIGHTS_DIR dataloader_path = checkpoint_path / self._SAVE_DATALOADER_DIR scheduler_path = checkpoint_path / self._SAVE_SCHEDULER_DIR train_state_path = checkpoint_path / self._SAVE_TRAIN_STATE_PATH + total_consumed_tokens = ( + 
self._reduce_number_across_rank(self._local_total_consumed_tokens) + self._init_total_tokens + ) + if self.cur_step % ckp_interval == 0: DEVICE_MODULE.empty_cache() # Save model and optimizer - self._engine.save_dcp( - model_dir=model_path, - optimizer_dir=optimizer_path, - ) - - total_consumed_tokens = ( - self._reduce_number_across_rank(self._local_total_consumed_tokens) + self._init_total_tokens - ) + save_optimizer = self._checkpoint_save_optimizer + future: Future | None = None + if self._async_checkpoint and not is_snapshot: + future = self._engine.async_save_dcp(weights_dir=weights_path, save_optimizer=save_optimizer) + else: + self._engine.save_dcp(weights_dir=weights_path, save_optimizer=save_optimizer) # Save dataloader self._save_dataloader(dataloader_path) @@ -1151,60 +1249,21 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: with config_bin.open("wb") as f: pickle.dump(self._trainer_cfg, f) - dist.barrier() - - # Save train state - if self.rank == 0: - with train_state_path.open("w") as f: - f.write( - json.dumps( - { - "cur_step": self.cur_step, - "cur_epoch": self._cur_epoch, - "total_consumed_tokens": total_consumed_tokens, - "train_time_offset": self._train_time + self._train_time_offset, - } - ) - ) - - # Update meta - current_exp = self.meta.latest_exp - ckp_list = current_exp.checkpoint_list if not is_snapshot else current_exp.snap_checkpoint_list - ckp_list.append(str(checkpoint_path)) - current_exp.cur_step = self.cur_step - current_exp.cur_epoch = self._cur_epoch - current_exp.consumed_tokens = int(total_consumed_tokens) - current_exp.history[-1]["end"] = self.cur_step - - # Delete checkpoints and update meta's checkpoint_list - ckp_maxkeep = self._checkpoint_maxkeep if not is_snapshot else 1 - if ckp_maxkeep is not None and ckp_maxkeep > 0 and len(ckp_list) > ckp_maxkeep: - ckp_pop_num = len(ckp_list) - ckp_maxkeep - for _ in range(ckp_pop_num): - deleted_ckp = ckp_list.pop(0) - if self.rank == 0 and Path(deleted_ckp).exists(): - rmtree(deleted_ckp) - - # Save meta, must after deleting checkpoints to ensure the checkpoint_list is updated in the meta file - if self.rank == 0: - with meta_path.open("w") as f: - f.write(self.meta.model_dump_json(indent=2)) + if future is not None: + self._pending_checkpoint = future dist.barrier() - if is_snapshot: - hooks = self.hooks_config.get_hooks(HookStage.AFTER_SAVE_SNAPSHOT) - else: - hooks = self.hooks_config.get_hooks(HookStage.AFTER_SAVE_DCP) - - for hook in hooks: - hook( - checkpoint=checkpoint_path, - step=self.cur_step, - epoch=self._cur_epoch, - total_step=self.total_step, - total_epoch=self.total_epoch, - ) + self._finalize_checkpoint_metadata( + checkpoint_path=checkpoint_path, + meta_path=meta_path, + train_state_path=train_state_path, + is_snapshot=is_snapshot, + cur_step=self.cur_step, + cur_epoch=self._cur_epoch, + total_consumed_tokens=total_consumed_tokens, + train_time_offset=self._train_time + self._train_time_offset, + ) return True @@ -1784,16 +1843,12 @@ def _load_checkpoint(self): if not resume_from.exists(): raise FileNotFoundError(f"Checkpoint path {resume_from} does not exist.") - model_path = resume_from / self._SAVE_MODEL_DIR - optimizer_path = ( - resume_from / self._SAVE_OPTIMIZER_DIR - if load_checkpoint_cfg.load_optimizer_states or load_checkpoint_cfg.load_optimizer_args - else None - ) + weights_path = resume_from / self._SAVE_WEIGHTS_DIR + if not weights_path.exists(): + raise FileNotFoundError(f"Checkpoint at {resume_from} has no '{self._SAVE_WEIGHTS_DIR}/' directory.") 
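+        # Model weights are always restored; the two flags below only control
+        # whether optimizer states / hyper-parameters come from the checkpoint.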
self._engine.load_dcp( - model_dir=model_path, - optimizer_dir=optimizer_path, + weights_dir=weights_path, load_states=load_checkpoint_cfg.load_optimizer_states, load_args=load_checkpoint_cfg.load_optimizer_args, )