diff --git a/tests/model/test_qwen3_moe.py b/tests/model/test_qwen3_moe.py
index 1c069fdfae..42c62458b4 100644
--- a/tests/model/test_qwen3_moe.py
+++ b/tests/model/test_qwen3_moe.py
@@ -347,6 +347,149 @@ def test_save_hf(self, device, dispatcher, ep_size):
         self.assertListEqual(safetensor_keys, model_index_keys)
         dist.barrier()
+
+    @parametrize.parametrize(
+        "device,dispatcher,ep_size",
+        [
+            ("cuda", None, 1),
+            ("cuda", "all2all", 4),
+            ("cuda", "all2all", 8),
+        ],
+    )
+    def test_async_save_hf(self, device, dispatcher, ep_size):
+        self.create_pg(device)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            syncdir = [tmpdir]
+            if self.world_size > 1:
+                dist.broadcast_object_list(syncdir, src=0)
+            tmpdir = Path(syncdir[0])
+            saved_hf_path = tmpdir / "hf-1"
+            origin_hf_path = Path(QWEN3_MOE_PATH)
+            origin_index_path = origin_hf_path / "model.safetensors.index.json"
+            saved_index_path = saved_hf_path / "model.safetensors.index.json"
+
+            with torch.device("meta"):
+                cfg = get_model_config_from_hf(QWEN3_MOE_PATH)
+                cfg.compile_cfg = False
+                cfg.dispatcher = dispatcher
+                cfg.ep_size = ep_size
+                qwen_model = cfg.build().to(torch.bfloat16)
+
+            fsdp_config = FSDPConfig(
+                ep_size=ep_size,
+                cpu_offload=False,
+            )
+            qwen_model.fully_shard(fsdp_config=fsdp_config)
+            qwen_model.from_hf(QWEN3_MOE_PATH)
+
+            tokenizer = AutoTokenizer.from_pretrained(QWEN3_MOE_PATH, trust_remote_code=True)
+
+            qwen_model.async_save_hf(hf_dir=saved_hf_path)
+            qwen_model.wait_async_hf()
+
+            if dist.get_rank() == 0:
+                tokenizer.save_pretrained(str(saved_hf_path))
+
+            dist.barrier()
+
+            self.assertTrue(saved_hf_path.exists())
+            self.assertTrue(saved_index_path.exists())
+
+            dist.barrier()
+
+            if dist.get_rank() == 0:
+                with open(origin_index_path, "r") as f:
+                    origin_index = json.load(f)
+                with open(saved_index_path, "r") as f:
+                    saved_index = json.load(f)
+                with open(origin_hf_path / "config.json", "r") as f:
+                    origin_config = json.load(f)
+                with open(saved_hf_path / "config.json", "r") as f:
+                    saved_config = json.load(f)
+
+                self.assertTrue(check_dict_equal(origin_config, saved_config))
+                self.assertListEqual(
+                    sorted(origin_index["weight_map"].keys()),
+                    sorted(saved_index["weight_map"].keys()),
+                )
+
+                cache_fh = {}
+                for key in origin_index["weight_map"].keys():
+                    origin_safetensor_name = origin_index["weight_map"][key]
+                    saved_safetensor_name = saved_index["weight_map"][key]
+
+                    if origin_safetensor_name not in cache_fh:
+                        cache_fh[origin_safetensor_name] = safe_open(
+                            str(origin_hf_path / origin_safetensor_name), framework="pt"
+                        )
+                    if saved_safetensor_name not in cache_fh:
+                        cache_fh[saved_safetensor_name] = safe_open(
+                            str(saved_hf_path / saved_safetensor_name), framework="pt"
+                        )
+
+                    origin_tensor = cache_fh[origin_safetensor_name].get_tensor(key)
+                    saved_tensor = cache_fh[saved_safetensor_name].get_tensor(key)
+                    self.assertTrue(torch.equal(origin_tensor, saved_tensor), f"tensor {key} is not equal")
+
+                safetensor_keys = []
+                for safetensor_path in saved_hf_path.glob("*.safetensors"):
+                    fh = cache_fh[safetensor_path.name]
+                    safetensor_keys.extend(fh.keys())
+                safetensor_keys.sort()
+                model_index_keys = list(saved_index["weight_map"].keys())
+                model_index_keys.sort()
+                self.assertListEqual(safetensor_keys, model_index_keys)
+
+            dist.barrier()
+
+            del qwen_model
+            torch.cuda.empty_cache()
+
+            input_ids = tokenizer("吃葡萄不吐葡萄皮", return_tensors="pt").input_ids.to("cuda")
+            labels = input_ids.clone()
+
+            hf_origin_model = AutoModelForCausalLM.from_pretrained(
+                origin_hf_path,
+                torch_dtype=torch.bfloat16,
+                trust_remote_code=True,
+                device_map="cuda",
+            )
+            patch_hf_rms_norm(hf_origin_model)
+            hf_origin_model.eval()
+            with torch.no_grad():
+                origin_output = hf_origin_model(input_ids=input_ids, labels=labels)
+                origin_loss = origin_output.loss.detach().cpu()
+                origin_logits = origin_output.logits.detach().cpu()
+
+            del hf_origin_model
+            del origin_output
+            torch.cuda.empty_cache()
+
+            hf_saved_model = AutoModelForCausalLM.from_pretrained(
+                saved_hf_path,
+                torch_dtype=torch.bfloat16,
+                trust_remote_code=True,
+                device_map="cuda",
+            )
+            patch_hf_rms_norm(hf_saved_model)
+            hf_saved_model.eval()
+            with torch.no_grad():
+                saved_output = hf_saved_model(input_ids=input_ids, labels=labels)
+                saved_loss = saved_output.loss.detach().cpu()
+                saved_logits = saved_output.logits.detach().cpu()
+
+            self.assertTrue(
+                torch.allclose(origin_loss, saved_loss, rtol=1e-2, atol=1e-2),
+                f"origin_loss={origin_loss.item()}, saved_loss={saved_loss.item()}",
+            )
+            self.assertTrue(torch.equal(origin_logits.argmax(dim=-1), saved_logits.argmax(dim=-1)))
+
+            del hf_saved_model
+            del saved_output
+            torch.cuda.empty_cache()
+
+        dist.barrier()
 
     def test_fope_auto_config_with_remote_code(self):
         self.create_pg('cuda')
@@ -462,6 +605,7 @@ def world_size(self) -> int:
 def create_model_from_hf(load_from: Path, dispatcher: str, ep_size: int):
     with torch.device("meta"):
         cfg : Qwen3MoEConfig = get_model_config_from_hf(load_from)
+        cfg.compile_cfg = False
         cfg.dispatcher = dispatcher
         cfg.ep_size = ep_size
         qwen_model = cfg.build()
diff --git a/tests/train/test_trainer_async_hf.py b/tests/train/test_trainer_async_hf.py
new file mode 100644
index 0000000000..6fe0380df8
--- /dev/null
+++ b/tests/train/test_trainer_async_hf.py
@@ -0,0 +1,260 @@
+import json
+import os
+import tempfile
+from pathlib import Path
+from unittest.mock import Mock, patch
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.testing._internal.common_distributed import DistributedTestBase
+
+from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
+from xtuner.v1.datasets import FTDPTokenizeFnConfig
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config
+from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
+from xtuner.v1.train.trainer import Trainer
+from xtuner.v1.utils.device import get_device
+
+
+DEVICE = get_device()
+
+
+class FakeHFModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(10, 10)
+        self.hf_index_calls = []
+
+    def forward(self, x):
+        return self.linear(x)
+
+    def _write_hf_index_and_config(self, hf_dir: Path | str, weight_map: dict[str, str]):
+        hf_dir = Path(hf_dir)
+        hf_dir.mkdir(parents=True, exist_ok=True)
+        self.hf_index_calls.append({"hf_dir": hf_dir, "weight_map": dict(weight_map)})
+        (hf_dir / "config.json").write_text('{"model_type": "fake_model"}')
+        with (hf_dir / "model.safetensors.index.json").open("w") as f:
+            json.dump({"metadata": {"total_size": len(weight_map)}, "weight_map": weight_map}, f)
+
+
+class FakeAsyncHFEngine:
+    def __init__(self):
+        self.save_hf_calls = []
+        self.wait_async_hf_calls = []
+        self.train_step_calls = 0
+        self.grad_norm_calls = 0
+        self.optimizer_step_calls = 0
+        self.async_hf_status_ok = True
+        self.async_hf_status_error = ""
+
+        self.model = model = FakeHFModel()
+        self.optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+        self.has_freeze_params = False
+        self._pending_async_hf = None
+
+    def grad_accumulation_steps(self, *args, **kwargs):
+        return 1
+
+    def train_step(self, *args, **kwargs):
+        self.train_step_calls += 1
+        return {
+            "total_loss": 1.8,
+            "step_consumed_tokens": 100,
+            "step_consumed_img_tokens": 0.0,
+            "grad_norm": torch.tensor(1.0),
+            "efficient_attn_ratio": 0.5,
+            "img_efficient_attn_ratio": 0.0,
+            "logs_info": {"local_loss": 1.0, "reduced_llm_loss": 0.8},
+            "extra_info": ModelForwardExtraLogInfo(),
+        }
+
+    def step_optimizer(self, *args, **kwargs):
+        self.optimizer_step_calls += 1
+        return 1.0
+
+    def clip_grad_norm(self, do_clip: bool = True, dtype=torch.float32):
+        self.grad_norm_calls += 1
+        return torch.tensor(1.0)
+
+    def save_hf(self, hf_dir: Path | str):
+        finalized_hf_dir = self.wait_async_hf()
+        hf_dir = Path(hf_dir)
+        hf_dir.mkdir(parents=True, exist_ok=True)
+        self.save_hf_calls.append(hf_dir)
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        shard_name = f"model-rank{rank}-of-{world_size}.safetensors"
+        (hf_dir / shard_name).write_text(f"fake async model weights for rank {rank}")
+        weight_map = {f"layers.rank{rank}.weight": shard_name} if self.async_hf_status_ok else {}
+        self._pending_async_hf = {
+            "hf_dir": hf_dir,
+            "ok": self.async_hf_status_ok,
+            "error": self.async_hf_status_error,
+            "weight_map": weight_map,
+        }
+        return finalized_hf_dir
+
+    def wait_async_hf(self):
+        self.wait_async_hf_calls.append(None)
+        if self._pending_async_hf is None:
+            return None
+
+        pending = self._pending_async_hf
+        local_status = {
+            "rank": dist.get_rank(),
+            "ok": bool(pending["ok"]),
+            "error": str(pending["error"]),
+            "weight_map": pending["weight_map"],
+        }
+        all_status = [None for _ in range(dist.get_world_size())]
+        dist.all_gather_object(all_status, local_status)
+        if not all(status["ok"] for status in all_status):
+            self._pending_async_hf = None
+            failed = ", ".join(
+                f"rank={status['rank']}({status['error']})" for status in all_status if not status["ok"]
+            )
+            raise RuntimeError(f"Async HF save global consistency check failed: {failed}")
+
+        merged_weight_map = {}
+        for status in all_status:
+            merged_weight_map.update(status["weight_map"])
+
+        if dist.get_rank() == 0:
+            self.model._write_hf_index_and_config(
+                hf_dir=Path(pending["hf_dir"]),
+                weight_map=merged_weight_map,
+            )
+        self._pending_async_hf = None
+        return Path(pending["hf_dir"])
+
+
+def prepare(fn):
+    def wrapper(self, *args, **kwargs):
+        self.alpaca_path = Path(__file__).resolve().parents[1] / "resource" / "openai_sft.jsonl"
+        self.tokenizer_path = None
+        self.temp_dir = tempfile.TemporaryDirectory()
+        self.fake_hf_model_dir = Path(self.temp_dir.name) / "fake_hf_model"
+        self.work_dir = Path(self.temp_dir.name) / "work_dir"
+
+        self.fake_hf_model_dir.mkdir()
+        (self.fake_hf_model_dir / "config.json").write_text('{"model_type": "fake_model"}')
+        (self.fake_hf_model_dir / "model.safetensors").write_text("fake weights")
+        ret = fn(self, *args, **kwargs)
+        self.temp_dir.cleanup()
+        return ret
+
+    return wrapper
+
+
+class TestTrainerAsyncSaveHF(DistributedTestBase):
+    def create_pg(self, device):
+        ret = super().create_pg(device)
+        os.environ["LOCAL_RANK"] = str(dist.get_rank())
+        return ret
+
+    def _broadcast_work_dir(self):
+        work_dir_list = [self.work_dir]
+        dist.broadcast_object_list(work_dir_list, src=0)
+        self.work_dir = Path(work_dir_list[0])
+
+    def _build_trainer(self, total_step=10, **kwargs):
+        trainer_kwargs = dict(
+            load_from=str(self.fake_hf_model_dir),
+            model_cfg=Qwen3MoE30BA3Config(),
+            optim_cfg=AdamWConfig(lr=1e-4, weight_decay=0.01),
+            fsdp_cfg=FSDPConfig(tp_size=1),
+            dataset_cfg=[
+                {
+                    "dataset": DatasetConfig(name="alpaca", anno_path=self.alpaca_path, sample_ratio=1.0),
+                    "tokenize_fn": FTDPTokenizeFnConfig(),
+                },
+            ],
+            dataloader_cfg=DataloaderConfig(),
+            lr_cfg=LRConfig(lr_type="constant", warmup_ratio=0.1, lr_min=1e-6),
+            tokenizer_path=self.tokenizer_path,
+            global_batch_size=2,
+            total_step=total_step,
+            work_dir=str(self.work_dir),
+            hf_interval=3,
+            hf_max_keep=2,
+            checkpoint_interval=None,
+            snapshot_interval=None,
+            seed=42,
+            debug=False,
+        )
+        trainer_kwargs.update(kwargs)
+        return Trainer(**trainer_kwargs)
+
+    @property
+    def world_size(self) -> int:
+        return int(os.getenv("XTUNER_TEST_WORLD_SIZE", "2"))
+
+    @patch("xtuner.v1.train.trainer.time.sleep", Mock())
+    @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
+    @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
+    @patch(
+        "xtuner.v1.train.trainer.Trainer.build_engine",
+        Mock(side_effect=lambda *args, **kwargs: FakeAsyncHFEngine()),
+    )
+    @prepare
+    def test_async_save_hf_interval(self):
+        self.create_pg(DEVICE)
+        self._broadcast_work_dir()
+        trainer = self._build_trainer(async_hf_export=True)
+        trainer.fit()
+        dist.barrier()
+
+        self.assertEqual(len(trainer._engine.save_hf_calls), 4)
+        if dist.get_rank() == 0:
+            self.assertEqual(len(trainer._engine.model.hf_index_calls), 4)
+
+            exp_dir = self.work_dir / trainer.exp_dir.name
+            hf_dirs = sorted(d.name for d in exp_dir.iterdir() if d.name.startswith("hf-") and d.is_dir())
+            self.assertEqual(hf_dirs, ["hf-10", "hf-9", "hf-latest"])
+
+            latest_hf = exp_dir / "hf-latest"
+            self.assertTrue(latest_hf.is_symlink())
+            self.assertEqual(latest_hf.resolve(), (exp_dir / "hf-10").resolve())
+
+            self.assertEqual(
+                [Path(path).name for path in trainer.meta.latest_exp.hf_checkpoint_list],
+                ["hf-9", "hf-10"],
+            )
+
+            hf10_dir = exp_dir / "hf-10"
+            index_path = hf10_dir / "model.safetensors.index.json"
+            self.assertTrue(index_path.exists())
+            with index_path.open("r") as f:
+                index_info = json.load(f)
+            weight_map = index_info["weight_map"]
+            self.assertEqual(sorted(weight_map.keys()), [f"layers.rank{rank}.weight" for rank in range(self.world_size)])
+            self.assertEqual(
+                sorted(weight_map.values()),
+                [f"model-rank{rank}-of-{self.world_size}.safetensors" for rank in range(self.world_size)],
+            )
+            for shard_name in weight_map.values():
+                self.assertTrue((hf10_dir / shard_name).exists())
+        else:
+            self.assertEqual(len(trainer._engine.model.hf_index_calls), 0)
+        dist.barrier()
+
+    @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
+    @patch(
+        "xtuner.v1.train.trainer.Trainer.build_engine",
+        Mock(side_effect=lambda *args, **kwargs: FakeAsyncHFEngine()),
+    )
+    @prepare
+    def test_async_save_hf_raises_on_writer_failure(self):
+        self.create_pg(DEVICE)
+        self._broadcast_work_dir()
+        trainer = self._build_trainer(total_step=3, async_hf_export=True)
+        trainer._engine.async_hf_status_ok = False
+        trainer._engine.async_hf_status_error = "mock async hf failure"
+        trainer._cur_step = 3
+
+        trainer._maybe_save_hf()
+        with self.assertRaisesRegex(RuntimeError, "Async HF save global consistency check failed"):
+            trainer._engine.wait_async_hf()
+        self.assertIsNone(trainer._engine._pending_async_hf)
diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py
index 8414d8dc13..baeae8a8fa 100644
--- a/xtuner/v1/engine/train_engine.py
+++ b/xtuner/v1/engine/train_engine.py
@@ -134,6 +134,7 @@ def __init__(
         optim_cfg: OptimConfig,
         fsdp_cfg: FSDPConfig,
         intra_layer_micro_batch: int = 1,
+        async_hf_export: bool = False,
     ) -> None:
         self.model_cfg = model_cfg
         self.optim_cfg = optim_cfg
@@ -143,6 +144,7 @@ def __init__(
         self.intra_layer_micro_batch = intra_layer_micro_batch
         self._count = 0
         self.has_freeze_params = self.__has_freeze_params()
+        self._async_hf_export = async_hf_export
 
     def __has_freeze_params(self) -> bool:
         has_freeze_params = False
@@ -285,14 +287,24 @@ def clean_param_name(name: str) -> str:
             name = name.replace("_orig_mod.", "")
             return name
 
-    def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16):
+    def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16) -> Path | None:
         """Save the hf model to the given directory.
 
         Args:
            hf_dir (str): The directory to save the model.
            save_dtype (torch.dtype): The dtype to save the model parameters, bfloat16 or float8.
+
+        Returns:
+            The directory of the export finalized by this call; in async mode
+            this is the previously pending export (or None on the first call).
         """
-        self.model.save_hf(hf_dir=hf_dir, save_dtype=save_dtype)
+        if self._async_hf_export:
+            return self.model.async_save_hf(hf_dir=hf_dir, save_dtype=save_dtype)
+        return self.model.save_hf(hf_dir=hf_dir, save_dtype=save_dtype)
+
+    def wait_async_hf(self) -> Path | None:
+        """Block until the pending async HF export (if any) is finalized."""
+        return self.model.wait_async_hf()
 
     # TODO: Support async save
     def save_dcp(
diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py
index 221fd88ce7..ad40905f48 100644
--- a/xtuner/v1/model/base.py
+++ b/xtuner/v1/model/base.py
@@ -4,11 +4,12 @@
 import pydoc
 import re
 from concurrent.futures import Future, ThreadPoolExecutor, wait
+from dataclasses import dataclass
 from functools import reduce
 from importlib import import_module
 from itertools import chain
 from pathlib import Path
-from shutil import copy, copytree
+from shutil import copy, copytree, rmtree
 from typing import Annotated, Any, Generator, Iterable, Literal, Mapping, Sequence, cast
 
 import torch
@@ -81,6 +82,15 @@ class BatchForwardInfo(TypedDict):
     extra_info: ModelForwardExtraLogInfo
 
 
+@dataclass
+class _AsyncHFSaveHandle:
+    futures: list[Future]
+    executor: ThreadPoolExecutor
+    hf_dir: Path
+    tmp_hf_dir: Path
+    weight_map: dict[str, str]
+
+
 class TorchCompileOption(TypedDict):
     fullgraph: NotRequired[bool]
     dynamic: NotRequired[bool | None]
@@ -524,6 +534,8 @@ def __init__(self, config: XTunerBaseModelConfig):
         self.config = config
         self._hf_path: Path | None = None  # type: ignore
 
+        self._async_hf_tensor_cache: dict[tuple[Any, ...], torch.Tensor] = {}
+        self._pending_async_hf: _AsyncHFSaveHandle | None = None
         self._compile_cfg = self._resolve_compile_cfg(self.config)
         self._float8_handler: Float8Handler | None = None
 
@@ -658,9 +670,155 @@ def traverse(module):
             ignored_params=ignored_params if ignored_params else None,
         )
 
-    def save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16, safetensors_prefix: str = "model"):
+    def save_hf(
+        self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16, safetensors_prefix: str = "model"
+    ) -> Path:
         with profile_time_and_memory(f"[Saving HF to [{safetensors_prefix}]{hf_dir} cost]"):
             self._save_hf(hf_dir=hf_dir, save_dtype=save_dtype, safetensors_prefix=safetensors_prefix)
+        return Path(hf_dir)
+
+    def async_save_hf(
+        self,
+        hf_dir: Path | str,
+        save_dtype: torch.dtype = torch.bfloat16,
+        safetensors_prefix: str = "model",
+    ) -> Path | None:
+        """Stage an async HF export and return the previously finalized one.
+
+        The current weights are copied into reusable (pinned) CPU buffers and
+        written to ``<hf_dir>.incomplete`` by a background thread; the export
+        is finalized (consistency-checked, index written, directory renamed to
+        ``hf_dir``) by the next call to ``wait_async_hf``.
+
+        Returns:
+            The directory of the previously pending export, or None if there
+            was none.
+        """
+        if self._hf_path is None and self.config.hf_config is None:
+            raise NotImplementedError(
+                "The model is not loaded from Huggingface, and the `hf_config` property is not implemented, so it cannot be saved in Huggingface format."
+ ) + finalized_hf_dir = self.wait_async_hf() + + if isinstance(hf_dir, str): + hf_dir = Path(hf_dir) + tmp_hf_dir = hf_dir.with_name(f"{hf_dir.name}.incomplete") + rank = dist.get_rank() if dist.is_initialized() else 0 + if rank == 0: + if tmp_hf_dir.exists(): + rmtree(tmp_hf_dir) + if dist.is_initialized(): + dist.barrier() + tmp_hf_dir.mkdir(parents=True, exist_ok=True) + if dist.is_initialized(): + dist.barrier() + + file_to_names: list[tuple[str, list[str]]] = [] + weight_map: dict[str, str] = {} + for safetensor_name, name_list, hf_tensor_list in self._iter_hf_save_chunks( + save_dtype=save_dtype, + safetensors_prefix=safetensors_prefix, + device=DEVICE, + ): + cached_names: list[str] = [] + for name, hf_tensor in zip(name_list, hf_tensor_list): + cache_key = (("root", "hf"), ("name", name)) + self._get_or_update_async_hf_cpu_tensor( + hf_tensor, + cache=self._async_hf_tensor_cache, + path=cache_key, + ) + cached_names.append(name) + weight_map[name] = safetensor_name + if cached_names: + file_to_names.append((safetensor_name, cached_names)) + del hf_tensor_list + + if hasattr(DEVICE_MODULE, "synchronize"): + DEVICE_MODULE.synchronize() + + executor = ThreadPoolExecutor(max_workers=1) + save_futures: list[Future] = [] + for filename, names in file_to_names: + tensors: dict[str, torch.Tensor] = {} + for name in names: + cache_key = (("root", "hf"), ("name", name)) + cached_tensor = cast(torch.Tensor | None, self._async_hf_tensor_cache.get(cache_key)) + if cached_tensor is None: + executor.shutdown() + raise RuntimeError(f"Missing cached async HF tensor for key: {name}") + tensors[name] = cached_tensor + save_futures.append(executor.submit(_save_file, tensors, tmp_hf_dir / filename)) + + self._pending_async_hf = _AsyncHFSaveHandle( + futures=save_futures, + executor=executor, + hf_dir=hf_dir, + tmp_hf_dir=tmp_hf_dir, + weight_map=weight_map, + ) + return finalized_hf_dir + + def wait_async_hf(self) -> Path | None: + if self._pending_async_hf is None: + return None + + handle = self._pending_async_hf + local_ok = True + local_error = "" + local_weight_map: dict[str, str] = {} + + try: + for future in handle.futures: + future.result() + local_weight_map = handle.weight_map + except Exception as exc: + local_ok = False + local_error = str(exc) + finally: + handle.executor.shutdown() + + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + local_status = {"rank": rank, "ok": local_ok, "error": local_error, "weight_map": local_weight_map} + if dist.is_initialized(): + all_status: list[dict[str, Any]] = [None for _ in range(world_size)] # type: ignore[list-item] + dist.all_gather_object(all_status, local_status) + else: + all_status = [local_status] + + if not all(status["ok"] for status in all_status): + failed = ", ".join( + f"rank={status['rank']}({status['error']})" for status in all_status if not status["ok"] + ) + self._pending_async_hf = None + raise RuntimeError(f"Async HF save global consistency check failed: {failed}") + + merged_weight_map: dict[str, str] = {} + for status in all_status: + merged_weight_map.update(status["weight_map"]) + + if rank == 0: + self._write_hf_index_and_config(hf_dir=handle.tmp_hf_dir, weight_map=merged_weight_map) + if dist.is_initialized(): + dist.barrier() + if rank == 0: + if handle.hf_dir.exists(): + rmtree(handle.hf_dir) + handle.tmp_hf_dir.rename(handle.hf_dir) + if dist.is_initialized(): + dist.barrier() + self._pending_async_hf = None + return handle.hf_dir def 
     def safetensors_to_params(
         self,
@@ -1477,6 +1615,120 @@ def _get_safe_tensor_num(self, dtype: torch.dtype) -> int:
             + math.ceil(fused_size / bucket_size)
         )
 
+    def _iter_hf_save_chunks(
+        self,
+        save_dtype: torch.dtype = torch.bfloat16,
+        safetensors_prefix: str = "model",
+        device: torch.device | str = "cpu",
+    ) -> Generator[tuple[str, list[str], list[torch.Tensor]], None, None]:
+        """Yield (safetensors filename, param names, tensors) chunks for this rank."""
+        assert save_dtype in [torch.float8_e4m3fn, torch.bfloat16], f"save_dtype {save_dtype} is not supported"
+
+        shard_gen = self._get_shard_hf_param(
+            self._group_param_by_load_spec(LoadEnum.SHARD),
+            dtype=save_dtype,
+            device=device,
+        )
+        same_gen = self._get_same_hf_param(
+            self._group_param_by_load_spec(LoadEnum.SAME),
+            dtype=save_dtype,
+            device=device,
+        )
+        fused_gen = self._get_fused_hf_param(
+            self._group_param_by_load_spec(LoadEnum.FUSED),
+            dtype=save_dtype,
+            device=device,
+        )
+
+        is_others_save_rank = not dist.is_initialized() or dist.get_rank() == 0
+        save_rank = dist.get_rank() if dist.is_initialized() else 0
+
+        saved_names: set[str] = set()
+        safetensor_index = 0
+
+        for name_list, hf_tensor_list in fused_gen:
+            if not name_list:
+                continue
+            safetensor_index += 1
+            safetensor_name = f"{safetensors_prefix}-{safetensor_index:04d}-fused-save_rank{save_rank}.safetensors"
+            saved_names.update(name_list)
+            yield safetensor_name, name_list, hf_tensor_list
+
+        safetensor_index = 0
+        for name_list, hf_tensor_list in chain(same_gen, shard_gen):
+            safetensor_index += 1
+            safetensor_name = f"{safetensors_prefix}-{safetensor_index:04d}-others-save_rank{save_rank}.safetensors"
+            if not is_others_save_rank:
+                continue
+
+            unique_name_list: list[str] = []
+            unique_hf_tensor_list: list[torch.Tensor] = []
+            for name, hf_tensor in zip(name_list, hf_tensor_list):
+                if name in saved_names:
+                    continue
+                saved_names.add(name)
+                unique_name_list.append(name)
+                unique_hf_tensor_list.append(hf_tensor)
+            if unique_name_list:
+                yield safetensor_name, unique_name_list, unique_hf_tensor_list
+
+    @staticmethod
+    def _allocate_async_hf_cpu_buffer_like(tensor: torch.Tensor) -> torch.Tensor:
+        cpu_tensor = torch.empty_like(tensor, device="cpu")
+        if tensor.is_cuda:
+            cpu_tensor = cpu_tensor.pin_memory()
+        return cpu_tensor
+
+    def _get_or_update_async_hf_cpu_tensor(
+        self,
+        tensor: torch.Tensor,
+        cache: dict[tuple[Any, ...], torch.Tensor],
+        path: tuple[Any, ...],
+    ) -> torch.Tensor:
+        detached = tensor.detach()
+
+        cached = cache.get(path)
+        if cached is None or (
+            cached.shape != detached.shape
+            or cached.dtype != detached.dtype
+            or cached.layout != detached.layout
+            or cached.stride() != detached.stride()
+        ):
+            cached = self._allocate_async_hf_cpu_buffer_like(detached)
+            cache[path] = cached
+
+        # NOTE: non-blocking D2H copy into pinned memory; callers must
+        # synchronize the device stream before reading the CPU buffer.
+        cached.copy_(detached, non_blocking=detached.is_cuda)
+        return cached
+
+    def _write_hf_non_weight_files(self, hf_dir: Path) -> None:
+        if self._hf_path is not None:
+            for file in cast(Path, self._hf_path).iterdir():
+                if file.suffix != ".safetensors":
+                    target_path = hf_dir / file.name
+                    if file.is_file():
+                        copy(file, target_path)
+                    else:
+                        copytree(file, target_path, ignore_dangling_symlinks=True, dirs_exist_ok=True)
+            return
+
+        if self.config.hf_config is not None:
+            self.config.save_hf(hf_dir)
+            return
+
+        raise RuntimeError("Internal Error, both self.config.hf_config and self._hf_path are None")
+
+    def _write_hf_index_and_config(self, hf_dir: Path | str, weight_map: Mapping[str, str]) -> None:
+        if isinstance(hf_dir, str):
+            hf_dir = Path(hf_dir)
+
+        self._write_hf_non_weight_files(hf_dir)
+
"model.safetensors.index.json", "w") as f: + index = {"weight_map": dict(weight_map), "metadata": {}} + json.dump(index, f, indent=2, ensure_ascii=False) + def _save_hf( self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16, safetensors_prefix: str = "model" ): diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 3e0dbfd9fa..5def465618 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -10,7 +10,7 @@ from datetime import datetime, timedelta from pathlib import Path from shutil import rmtree -from typing import Annotated, Callable, Literal, Protocol, Sequence, Sized, cast, overload, runtime_checkable +from typing import Any, Annotated, Callable, Literal, Protocol, Sequence, Sized, cast, overload, runtime_checkable import torch import torch.distributed as dist @@ -334,6 +334,7 @@ class TrainerConfig(BaseModel): strict_load: bool = True checkpoint_interval: int | None = -1 checkpoint_maxkeep: int | None = -1 + async_hf_export: bool = False skip_checkpoint_validation: bool = False # Suggest enabled if fsdp_size is larger than 512 patch_for_dcp_finish: bool = False snapshot_interval: int | None = None @@ -458,6 +459,7 @@ def __init__( strict_load: bool = True, checkpoint_interval: int | None = -1, checkpoint_maxkeep: int | None = -1, + async_hf_export: bool = False, skip_checkpoint_validation: bool = False, # Suggest enabled if fsdp_size is larger than 512 patch_for_dcp_finish: bool = False, snapshot_interval: int | None = None, @@ -513,15 +515,18 @@ def __init__( if not self._can_save_hf: assert_info = ( - f"`hf_interval`: {hf_interval} and `hf_max_keep`: {hf_max_keep} " + f"`hf_interval`: {hf_interval}, `hf_max_keep`: {hf_max_keep} and " + f"`async_hf_export`: {async_hf_export} " f"should be None when `load_from` is not a Huggingface model path, " ) if is_hf_path is False and error_info is not None: assert_info += f", HF path load error Info: {error_info}" - assert hf_interval is None and hf_max_keep is None, assert_info + assert hf_interval is None and hf_max_keep is None and async_hf_export is False, assert_info self._checkpoint_interval = checkpoint_interval self._checkpoint_maxkeep = checkpoint_maxkeep + self._async_hf_export = async_hf_export + self._pending_async_hf_meta: dict[str, Any] | None = None self._snapshot_interval = snapshot_interval self._check_health_interval = check_health_interval self._hf_max_keep = hf_max_keep @@ -692,6 +697,7 @@ def from_config(cls, config: TrainerConfig) -> Self: strict_load=config.strict_load, checkpoint_interval=config.checkpoint_interval, checkpoint_maxkeep=config.checkpoint_maxkeep, + async_hf_export=config.async_hf_export, skip_checkpoint_validation=config.skip_checkpoint_validation, patch_for_dcp_finish=config.patch_for_dcp_finish, snapshot_interval=config.snapshot_interval, @@ -792,6 +798,47 @@ def fit(self): if self.cur_step % 50 == 0: gc.collect() + if self._async_hf_export: + finalized_hf_path = self._engine.wait_async_hf() + if finalized_hf_path is not None: + assert self._pending_async_hf_meta is not None + assert finalized_hf_path == self._pending_async_hf_meta["path"] + latest_hf_link = self.exp_dir / "hf-latest" + + self.meta.latest_exp.hf_checkpoint_list.append(str(finalized_hf_path)) + + if self._hf_max_keep is not None and len(self.meta.latest_exp.hf_checkpoint_list) > self._hf_max_keep: + deleted_hf_checkpoints = self.meta.latest_exp.hf_checkpoint_list[: -self._hf_max_keep] + self.meta.latest_exp.hf_checkpoint_list = self.meta.latest_exp.hf_checkpoint_list[ + 
+                        -self._hf_max_keep :
+                    ]
+                    for hf_dir in deleted_hf_checkpoints:
+                        if self.rank == 0 and Path(hf_dir).exists():
+                            rmtree(hf_dir)
+
+                if self.rank == 0:
+                    if isinstance(self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
+                        self.tokenizer.save_pretrained(str(finalized_hf_path))
+                    latest_hf_link.unlink(missing_ok=True)
+                    latest_hf_link.symlink_to(finalized_hf_path.absolute(), target_is_directory=True)
+
+                meta_path = self.work_dir / self._META_PATH
+
+                if self.rank == 0:
+                    with meta_path.open("w") as f:
+                        f.write(self.meta.model_dump_json(indent=2))
+
+                hooks = self.hooks_config.get_hooks(HookStage.AFTER_SAVE_HF)
+                for hook in hooks:
+                    hook(
+                        checkpoint=finalized_hf_path,
+                        step=self._pending_async_hf_meta["step"],
+                        epoch=self._pending_async_hf_meta["epoch"],
+                        total_step=self.total_step,
+                        total_epoch=self.total_epoch,
+                    )
+                self._pending_async_hf_meta = None
+
         # TODO: Should use flush rather than close
         self._exp_tracker.close()
         if self._metrics_recorder:
@@ -1025,6 +1072,7 @@ def build_engine(
             fsdp_cfg=fsdp_config,
             model_cfg=model_config,
             intra_layer_micro_batch=intra_layer_micro_batch,
+            async_hf_export=self._async_hf_export,
         )
         if model_path is not None and (model_config.dcp_ignore_frozen_params or load_checkpoint_path is None):
             engine.from_hf(hf_path=model_path, strict=strict)
@@ -1575,7 +1623,30 @@ def _maybe_save_hf(self):
             return
 
         save_hf_path = self.exp_dir / f"hf-{self.cur_step}"
+
+        finalized_hf_path = self._engine.save_hf(str(save_hf_path))
+        if self._async_hf_export:
+            if finalized_hf_path is not None:
+                assert self._pending_async_hf_meta is not None
+                assert finalized_hf_path == self._pending_async_hf_meta["path"]
+                finalized_step = self._pending_async_hf_meta["step"]
+                finalized_epoch = self._pending_async_hf_meta["epoch"]
+            # Overwrite the pending record with the export staged by this call;
+            # it is finalized by the next save or at the end of fit().
+            self._pending_async_hf_meta = {
+                "path": save_hf_path,
+                "step": self.cur_step,
+                "epoch": self._cur_epoch,
+            }
+            if finalized_hf_path is None:
+                return
+        else:
+            assert finalized_hf_path is not None
+            finalized_step = self.cur_step
+            finalized_epoch = self._cur_epoch
+
         latest_hf_link = self.exp_dir / "hf-latest"
+        save_hf_path = finalized_hf_path
 
         self.meta.latest_exp.hf_checkpoint_list.append(str(save_hf_path))
@@ -1586,7 +1656,6 @@ def _maybe_save_hf(self):
             if self.rank == 0 and Path(hf_dir).exists():
                 rmtree(hf_dir)
 
-        self._engine.save_hf(str(save_hf_path))
         if self.rank == 0:
             if isinstance(self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
                 self.tokenizer.save_pretrained(str(save_hf_path))
@@ -1604,8 +1673,8 @@ def _maybe_save_hf(self):
         for hook in hooks:
             hook(
                 checkpoint=save_hf_path,
-                step=self.cur_step,
-                epoch=self._cur_epoch,
+                step=finalized_step,
+                epoch=finalized_epoch,
                 total_step=self.total_step,
                 total_epoch=self.total_epoch,
             )
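Usage note (not part of the patch): the sketch below shows how the new two-phase export is driven end to end. `TrainerConfig`, `Trainer.from_config`, and the `async_hf_export` / `hf_interval` / `hf_max_keep` knobs come from this diff; `load_from` and all other values are illustrative placeholders, and the remaining required `TrainerConfig` fields are elided.

```python
# Minimal sketch, assuming the TrainerConfig fields visible in this diff.
# With async_hf_export=True, each save stages safetensors writes on a
# background thread; the next save (or the end of fit()) finalizes the
# previous export, checks all ranks, and renames "<dir>.incomplete".
from xtuner.v1.train.trainer import Trainer, TrainerConfig

config = TrainerConfig(
    load_from="/path/to/hf/model",  # must be a Huggingface model path
    hf_interval=1000,               # stage an HF export every 1000 steps
    hf_max_keep=2,                  # keep only the two most recent exports
    async_hf_export=True,           # overlap safetensors writes with training
    # ... model_cfg, optim_cfg, dataset_cfg, dataloader_cfg, lr_cfg, etc.
)
trainer = Trainer.from_config(config)
trainer.fit()  # any still-pending export is finalized before fit() returns
```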