diff --git a/examples/v1/config/rl_disagg_single.py b/examples/v1/config/rl_disagg_single.py index 0be84c632..1e26eb734 100644 --- a/examples/v1/config/rl_disagg_single.py +++ b/examples/v1/config/rl_disagg_single.py @@ -126,7 +126,7 @@ # 3. judger -judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", judger_type="router") +judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") # 4. train worker diff --git a/tests/rl/test_update_weight_disaggregated.py b/tests/rl/test_update_weight_disaggregated.py new file mode 100644 index 000000000..934b416a0 --- /dev/null +++ b/tests/rl/test_update_weight_disaggregated.py @@ -0,0 +1,326 @@ +import os +import hashlib +import sys +import tempfile +import time +import unittest +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) +TEST_DIR = Path(__file__).resolve().parent +if str(TEST_DIR) not in sys.path: + sys.path.insert(0, str(TEST_DIR)) + +import ray +import torch +import torch.distributed as dist + +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.data_proto.rl_data import SampleParams, RolloutState +from xtuner.v1.config import ( + AdamWConfig, + FSDPConfig, + LRConfig, +) +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers +from xtuner.v1.rl.trainer import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker +from xtuner.v1.rl.loss import GRPOLossConfig as LossConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.utils import ray_method + +TEST_TEXT_MESSAGES = [{"role": "user", "content": "Hello!"}] +MODEL_PATH = os.environ.get("MODEL_PATH") or os.environ.get("QWEN3_VL_DENSE_PATH") + + +class HashingTrainingWorker(BaseTrainingWorker): + def _init_update_weighter(self): + super()._init_update_weighter() + self._test_update_weight_sha256 = hashlib.sha256() + self._test_update_weight_bucket_count = 0 + + @ray_method + def reset_update_weight_sha256(self): + self._test_update_weight_sha256 = hashlib.sha256() + self._test_update_weight_bucket_count = 0 + + @ray_method + def get_update_weight_sha256(self): + return { + "rank": self.rank, + "sha256": self._test_update_weight_sha256.hexdigest(), + "bucket_count": self._test_update_weight_bucket_count, + } + + def request_update_params(self, state_dict, train_enable_ep=False, finished=False, profile_context=None): + if state_dict and dist.get_rank() == 0: + for name in sorted(state_dict): + tensor = state_dict[name].detach().contiguous().cpu() + self._test_update_weight_sha256.update(name.encode("utf-8")) + self._test_update_weight_sha256.update(str(tensor.dtype).encode("utf-8")) + self._test_update_weight_sha256.update(str(tuple(tensor.shape)).encode("utf-8")) + self._test_update_weight_sha256.update(tensor.view(torch.uint8).numpy().tobytes()) + self._test_update_weight_bucket_count += 1 + return super().request_update_params( + state_dict, + train_enable_ep=train_enable_ep, + finished=finished, + ) + + +class TestUpdateWeight(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + if MODEL_PATH is None: + raise unittest.SkipTest("MODEL_PATH is not set") + os.environ["XTUNER_USE_FA3"] = "1" + + @classmethod + def tearDownClass(cls) -> None: + del os.environ["XTUNER_USE_FA3"] + + def setUp(self): + ray.init(num_cpus=80, ignore_reinit_error=True) + self.model_path = MODEL_PATH + self.temp_dir = tempfile.TemporaryDirectory() + 
self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + self.init_config() + self.pg = AutoAcceleratorWorkers.build_placement_group( + self.train_resources_cfg, + name=f"test_update_weight_train_{id(self)}", + ) + + def tearDown(self): + ray.shutdown() + self.temp_dir.cleanup() + + def init_config(self): + train_num_workers = int(os.environ.get("TRAIN_NUM_WORKERS", "4")) + rollout_num_workers = int(os.environ.get("ROLLOUT_NUM_WORKERS", "4")) + rollout_tp_size = int(os.environ.get("ROLLOUT_TP_SIZE", str(rollout_num_workers))) + rollout_ep_size = int(os.environ.get("ROLLOUT_EP_SIZE", "1")) + train_ep_size = int(os.environ.get("TRAIN_EP_SIZE", "1")) + + self.train_resources_cfg = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=train_num_workers, + num_cpus_per_worker=float(os.environ.get("TRAIN_CPUS_PER_WORKER", "12")), + cpu_memory_per_worker=int(os.environ.get("TRAIN_CPU_MEMORY_PER_WORKER", str(16 * 1024**3))), + ) + self.rollout_resources_cfg = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=rollout_num_workers, + num_cpus_per_worker=float(os.environ.get("ROLLOUT_CPUS_PER_WORKER", "12")), + cpu_memory_per_worker=int(os.environ.get("ROLLOUT_CPU_MEMORY_PER_WORKER", str(16 * 1024**3))), + ) + self.rollout_cfg = RolloutConfig( + env="test_rollout", + model_path=MODEL_PATH, + model_name=os.path.basename(MODEL_PATH).lower(), + tokenizer_path=MODEL_PATH, + rollout_cross_node_comm=False, + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpus_per_node=int(os.environ.get("GPUS_PER_NODE", "8")), # gpu: 8, npu: 16 + dtype="bfloat16", + skip_load_weights=True, + context_length=int(os.environ.get("ROLLOUT_CONTEXT_LENGTH", "256")), + worker_log_dir=self.worker_log_dir, + gpu_memory_utilization=float(os.environ.get("ROLLOUT_GPU_MEMORY_UTILIZATION", "0.5")), + ) + + # training config + model_cfg = get_model_config_from_hf(Path(MODEL_PATH)) + optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False) + fsdp_cfg: FSDPConfig = FSDPConfig(ep_size=train_ep_size) + lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7) + self.worker_cfg: WorkerConfig = WorkerConfig( + model_cfg=model_cfg, + optim_cfg=optim_cfg, + loss_cfg=LossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type="vanilla", + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.001, + kl_loss_type="low_var_kl", + mode="eager"), + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + load_from=MODEL_PATH, + sp_size=1, + pack_max_length=1024, + ) + + def _build_train_controller(self, worker_cls=BaseTrainingWorker): + TrainingWorker = ray.remote( + runtime_env={ + "env_vars": { + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", + } + }, + )(worker_cls) + train_workers, _ = AutoAcceleratorWorkers.from_placement_group( + TrainingWorker, self.worker_cfg, self.pg + ) + ray.get([worker.test_all_reduce.remote() for worker in train_workers]) + train_controller = TrainingController(workers=train_workers) + train_controller.set_train_rollout_mode("disaggregated") + return train_controller + + def _build_sglang_rollout_controller(self): + rollout_pg = AutoAcceleratorWorkers.build_placement_group( + self.rollout_resources_cfg, + name=f"test_update_weight_rollout_{id(self)}", + ) + self.rollout_cfg.skip_load_weights = False + return ray.remote(RolloutController).remote( + self.rollout_cfg, + rollout_pg, + ) + + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy 
backend is not enabled") + def test_lmdeploy_update_weight_and_generate(self): + # init train + TrainingWorker = ray.remote( + runtime_env={ + "env_vars": { + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", + } + }, + )(BaseTrainingWorker) + train_workers, _ = AutoAcceleratorWorkers.from_placement_group( + TrainingWorker, self.worker_cfg, self.pg + ) + futures = [ worker.test_all_reduce.remote() for worker in train_workers ] + ray.get(futures) + train_controller = TrainingController( + workers=train_workers, + ) + # fixed sample params + sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) + + # init rollout_controller and rollout baseline + self.rollout_cfg.skip_load_weights = False + rollout_controller = ray.remote(RolloutController).remote( + self.rollout_cfg, + self.pg, + ) + + input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) + res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) + + # start update weight test + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) + train_controller.update_rollout_info(info_dict) + + # update weights + ray.get(rollout_controller.offload.remote()) + train_controller.onload(target="all") + train_controller.offload("optimizer") + ray.get(rollout_controller.onload_weights.remote()) + train_controller.update_weights() + train_controller.offload("model") + ray.get(rollout_controller.onload_kvcache.remote()) + + res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) + self.assertEqual(res_update_weight.response, res_baseline.response) + ray.get(rollout_controller.shutdown.remote(), timeout=60) + + @unittest.skipIf(os.environ.get("XTUNER_USE_SGLANG", "0") == "0", "sglang backend is not enabled") + def test_sglang_disaggregated_update_weight_and_generate(self): + # init train on a dedicated placement group + TrainingWorker = ray.remote( + runtime_env={ + "env_vars": { + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", + } + }, + )(BaseTrainingWorker) + train_workers, _ = AutoAcceleratorWorkers.from_placement_group( + TrainingWorker, self.worker_cfg, self.pg + ) + futures = [worker.test_all_reduce.remote() for worker in train_workers] + ray.get(futures) + train_controller = TrainingController(workers=train_workers) + train_controller.set_train_rollout_mode("disaggregated") + + # init rollout on a separate placement group + rollout_pg = AutoAcceleratorWorkers.build_placement_group( + self.rollout_resources_cfg, + name=f"test_update_weight_rollout_{id(self)}", + ) + self.rollout_cfg.skip_load_weights = False + rollout_controller = ray.remote(RolloutController).remote( + self.rollout_cfg, + rollout_pg, + ) + + sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) + input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) + res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) + + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) + train_controller.update_rollout_info(info_dict) + + train_controller.update_weights() + + res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) + self.assertEqual(res_update_weight.response, res_baseline.response) + ray.get(rollout_controller.shutdown.remote(), timeout=60) + + @unittest.skipIf(os.environ.get("XTUNER_USE_SGLANG", "0") == "0", "sglang 
backend is not enabled") + def test_sglang_disaggregated_update_weight_after_pause_and_generate(self): + train_controller = self._build_train_controller() + rollout_controller = self._build_sglang_rollout_controller() + + sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) + input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) + res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) + + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) + train_controller.update_rollout_info(info_dict) + + ray.get(rollout_controller.pause_generation.remote()) + time.sleep(float(os.environ.get("XTUNER_UPDATE_WEIGHT_PAUSE_SLEEP", "2"))) + train_controller.update_weights() + ray.get(rollout_controller.continue_generation.remote()) + + res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) + self.assertEqual(res_update_weight.response, res_baseline.response) + ray.get(rollout_controller.shutdown.remote(), timeout=60) + + @unittest.skipIf(os.environ.get("XTUNER_USE_SGLANG", "0") == "0", "sglang backend is not enabled") + def test_sglang_disaggregated_update_weight_sha256_is_stable(self): + train_controller = self._build_train_controller(worker_cls=HashingTrainingWorker) + rollout_controller = self._build_sglang_rollout_controller() + + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) + train_controller.update_rollout_info(info_dict) + + ray.get([worker.reset_update_weight_sha256.remote() for worker in train_controller.workers]) + train_controller.update_weights() + first_hashes = ray.get([worker.get_update_weight_sha256.remote() for worker in train_controller.workers]) + + ray.get([worker.reset_update_weight_sha256.remote() for worker in train_controller.workers]) + train_controller.update_weights() + second_hashes = ray.get([worker.get_update_weight_sha256.remote() for worker in train_controller.workers]) + + first_rank0_hash = next(item for item in first_hashes if item["rank"] == 0) + second_rank0_hash = next(item for item in second_hashes if item["rank"] == 0) + self.assertGreater(first_rank0_hash["bucket_count"], 0) + self.assertEqual(first_rank0_hash["sha256"], second_rank0_hash["sha256"]) + self.assertEqual(first_rank0_hash["bucket_count"], second_rank0_hash["bucket_count"]) + ray.get(rollout_controller.shutdown.remote(), timeout=60) + + +if __name__ == "__main__": + unittest.main() diff --git a/xtuner/v1/rl/rollout/sglang.py b/xtuner/v1/rl/rollout/sglang.py index 1ea8398cb..93374220a 100644 --- a/xtuner/v1/rl/rollout/sglang.py +++ b/xtuner/v1/rl/rollout/sglang.py @@ -2,6 +2,7 @@ import os from typing import Any, Dict, List, Union +import ray import numpy as np import requests import torch @@ -221,6 +222,8 @@ def _transform_rollout_config_to_server_configs(self): ep_size = num_gpus_per_engine if self.config.expert_parallel_size > 1 else self.config.expert_parallel_size nnodes = max(1, num_gpus_per_engine // self.config.gpus_per_node) node_rank = self.rank // self.config.gpus_per_node if nnodes > 1 else 0 + assigned_gpu_id = int(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) + init_kwargs = dict( model_path=self.config.model_path, trust_remote_code=True, @@ -228,7 +231,7 @@ def _transform_rollout_config_to_server_configs(self): port=self.server_port, nccl_port=self.nccl_port, dist_init_addr=self.dist_init_addr, - base_gpu_id=self.rank % self.config.gpus_per_node, + base_gpu_id=assigned_gpu_id, gpu_id_step=1, nnodes=nnodes, 
node_rank=node_rank, diff --git a/xtuner/v1/rl/trainer/controller.py b/xtuner/v1/rl/trainer/controller.py index 3ce836bef..74f992463 100644 --- a/xtuner/v1/rl/trainer/controller.py +++ b/xtuner/v1/rl/trainer/controller.py @@ -292,6 +292,9 @@ def onload(self, target: Literal["model", "optimizer", "all"] = "all"): def update_rollout_info(self, info_dict): ray.get([worker.update_rollout_info.remote(**info_dict) for worker in self.workers]) # type: ignore[attr-defined] + def set_train_rollout_mode(self, train_rollout_mode: str): + ray.get([worker.set_train_rollout_mode.remote(train_rollout_mode) for worker in self.workers]) + def update_weights(self): """Update the weights of the training workers.""" handles = [worker.update_weights.remote() for worker in self.workers] diff --git a/xtuner/v1/rl/trainer/update_weighter.py b/xtuner/v1/rl/trainer/update_weighter.py new file mode 100644 index 000000000..a2dcfc06b --- /dev/null +++ b/xtuner/v1/rl/trainer/update_weighter.py @@ -0,0 +1,889 @@ +import os +import socket +from concurrent.futures import ThreadPoolExecutor +from datetime import timedelta +from itertools import chain +from threading import Lock +from typing import Any, Dict, List, TypeAlias, cast + +import requests +import torch +import torch.distributed as dist +import tqdm +from packaging.version import parse as parse_version +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + Store, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, +) +from torch.distributed.tensor import DTensor + +from xtuner.v1.model.compose.base import BaseComposeConfig +from xtuner.v1.model.compose.qwen3_vl import Qwen3VLForConditionalGeneration +from xtuner.v1.model.moe.moe import MoE +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.utils import ( + get_device, + get_torch_device_module, + monkey_unpatch_torch_reductions, + ray_method, +) +from xtuner.v1.utils.load_spec import LoadEnum, LoadSpec + + +DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices +ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping service names to their URLs +RolloutEngineInfo: TypeAlias = list[tuple[int, str, int]] # (rollout rank, server url, engine gpu count) +DEVICE = get_device() +DEVICE_MODULE = get_torch_device_module() + + +class UpdateWeighter: + def _init_update_weighter(self): + # Used to update weight to rollout engine + self.rollout_device_mesh: DeviceMesh | None = None + self.rollout_url: str | None = None + self.rollout_cfg_info: dict = dict() + self.endpoints: dict[str, str] = dict() + self.endpoints["update_weights"] = "update_weights" + + self.rollout_engine_rank_mesh_array: DeviceMeshRaw = [] + self.rollout_server_url_dict: ServiceUrlMap = {} + self.worker_server_urls_status: dict[str, bool] = {} + + self._global_hf_keys_mapping_cache: dict[str, list[str]] = dict() + self._ipc_tensor_bytes: int = int(self.config.update_weight_bucket_size_in_gb * 1024**3) + self._update_params_ipc_tensor = None + self._update_params_ipc_event = None + self._sglang_disagg_group: dist.ProcessGroup | None = None + self._sglang_disagg_group_name: str | None = None + self._sglang_disagg_engine_urls: list[str] = [] + self._sglang_disagg_executor: ThreadPoolExecutor | None = None + self._train_update_sync_group: dist.ProcessGroup | None = None + self._sglang_disagg_update_lock = Lock() + + @ray_method + def update_rollout_info( + self, + engine_rank_mesh_array: 
DeviceMeshRaw,
+        server_url_dict: ServiceUrlMap,
+        rollout_config: RolloutConfig,
+        worker_server_urls_status: Dict[str, bool],
+        api_server_url: str | None = None,
+    ):
+        """Update the rollout information for the training worker."""
+        tp = rollout_config.tensor_parallel_size
+        ep = rollout_config.expert_parallel_size
+        assert tp == 1 or ep == 1, "Either tensor parallel size or expert parallel size must be 1."
+        if self.rollout_device_mesh is None:
+            self.rollout_device_mesh = DeviceMesh(
+                "cpu", mesh=engine_rank_mesh_array, mesh_dim_names=("engine_instance", "engine_parallel")
+            )
+        # Capture the previous URL before it is overwritten so that a changed rollout target
+        # tears down the stale sglang weight-update group below.
+        old_rollout_url = self.rollout_url
+        rollout_server_url = server_url_dict.get(self.rank, "")
+        if worker_server_urls_status.get(rollout_server_url, "False") is False:
+            self.logger.error(f"Rollout server url {rollout_server_url} is not available.")
+            self.rollout_url = None
+        else:
+            self.rollout_url = rollout_server_url
+
+        self.rollout_engine_rank_mesh_array = [[int(rank) for rank in ranks] for ranks in engine_rank_mesh_array]
+        self.rollout_server_url_dict = {int(rank): url for rank, url in server_url_dict.items()}
+        self.worker_server_urls_status = worker_server_urls_status
+
+        if old_rollout_url != self.rollout_url:
+            self._reset_sglang_disagg_group()
+
+        self.rollout_cfg_info["tp"] = tp
+        self.rollout_cfg_info["ep"] = ep
+        self.rollout_cfg_info["api_key"] = rollout_config.api_key
+        if os.environ.get("XTUNER_USE_SGLANG", "0") == "1":
+            self.rollout_cfg_info["backend"] = "sglang"
+        elif os.environ.get("XTUNER_USE_VLLM", "0") == "1":
+            self.rollout_cfg_info["backend"] = "vllm"
+        else:
+            self.rollout_cfg_info["backend"] = (rollout_config.extra_rollout_config or dict()).get(
+                "lmdeploy_backend", "pytorch"
+            )
+
+    @ray_method
+    def set_train_rollout_mode(self, train_rollout_mode: str):
+        mode = train_rollout_mode.lower()
+        if mode == "colocate":
+            self.is_train_rollout_colocated = True
+        elif mode == "disaggregated":
+            self.is_train_rollout_colocated = False
+        else:
+            raise ValueError(
+                f"Unsupported train_rollout_mode: {train_rollout_mode!r}. "
+                "Expected 'colocate' or 'disaggregated'."
+ ) + + if self.is_train_rollout_colocated: + self._reset_sglang_disagg_group() + + def _reset_sglang_disagg_group(self): + if self._sglang_disagg_executor is not None: + self._sglang_disagg_executor.shutdown(wait=False, cancel_futures=True) + try: + if self._sglang_disagg_group is not None: + dist.destroy_process_group(self._sglang_disagg_group) + except Exception: + pass + self._sglang_disagg_group = None + self._sglang_disagg_group_name = None + self._sglang_disagg_engine_urls = [] + self._sglang_disagg_executor = None + + def _get_train_update_sync_group(self) -> dist.ProcessGroup: + if self._train_update_sync_group is None: + ranks = list(range(dist.get_world_size())) + self._train_update_sync_group = dist.new_group(ranks=ranks, backend="gloo") + return self._train_update_sync_group + + @ray_method + def update_weights(self): + """Update the model weights.""" + if self.is_train_rollout_colocated: + self._update_weights_colocated() + else: + self._update_weights_disaggregated() + + def _update_weights_colocated(self): + DEVICE_MODULE.empty_cache() + self._update_params_ipc_event = DEVICE_MODULE.Event(interprocess=True) + if self.rollout_cfg_info.get("backend") == "turbomind": + self._update_weights_by_layer() + else: + if isinstance(self.config.model_cfg, BaseComposeConfig): + self._update_weights_hf_generator(submodule="language_model", final_update=False) + self._update_weights_hf_generator(submodule="vision_tower", final_update=False) + self._update_weights_hf_generator(submodule="multi_modal_projector", final_update=True) + else: + self._update_weights_hf_generator(final_update=True) + self._update_params_ipc_tensor = None + self._update_params_ipc_event = None + + DEVICE_MODULE.empty_cache() + + def _update_weights_disaggregated(self): + assert self.rollout_cfg_info.get("backend") == "sglang", ( + "Only SGLang disaggregated weight update is implemented now." 
+ ) + DEVICE_MODULE.empty_cache() + try: + if isinstance(self.config.model_cfg, BaseComposeConfig): + self._update_weights_hf_generator(submodule="language_model", final_update=False) + self._update_weights_hf_generator(submodule="vision_tower", final_update=False) + self._update_weights_hf_generator(submodule="multi_modal_projector", final_update=True) + else: + self._update_weights_hf_generator(final_update=True) + finally: + DEVICE_MODULE.empty_cache() + + def _rl_get_fused_ep_hf_param(self, model: MoE, target_ep_rank: int, target_ep_size: int, bucket_size: int): + fused_param_groups: list[tuple[torch.Tensor, LoadSpec]] = model._group_param_by_load_spec(LoadEnum.FUSED) + model_ep_size = 1 if model.fsdp_config is None else model.fsdp_config.ep_size + if not fused_param_groups: + return + + def _get_hf_params( + fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]], + ) -> tuple[list[torch.Tensor], list[str]]: + hf_keys_list: list[str] = [] + hf_tensor_list: list[torch.Tensor] = [] + + for fsdp_tensor, load_spec in fsdp_tensor_list: + hf_keys = load_spec.hf_keys + if model_ep_size > 1 and model.ep_mesh is not None: + if load_spec.name not in self._global_hf_keys_mapping_cache: + global_hf_keys: list[list[str] | None] = [None] * model_ep_size + dist.all_gather_object(global_hf_keys, hf_keys, group=model.ep_mesh.get_group()) + global_hf_keys_gathered = cast(list[list[str]], global_hf_keys) + self._global_hf_keys_mapping_cache[load_spec.name] = list( + chain.from_iterable(global_hf_keys_gathered) + ) + hf_keys = self._global_hf_keys_mapping_cache[load_spec.name] + + fused_full_tensor = fsdp_tensor.bfloat16() + if isinstance(fused_full_tensor, DTensor): + fused_full_tensor = fused_full_tensor.full_tensor() + dim = cast(int, load_spec.dim) + num_split = len(hf_keys) + hf_tensor_size = fused_full_tensor.shape[dim] / num_split + assert hf_tensor_size.is_integer(), "Internal Error, hf_tensor_size is not integer" + hf_tensor_size = int(hf_tensor_size) + + hf_tensor = fused_full_tensor.split([hf_tensor_size] * num_split, dim=dim) + assert num_split % target_ep_size == 0, ( + f"len(hf_keys) of '{hf_keys}' is {num_split}, it must be divisible by target_ep_size {target_ep_size}" + ) + start_idx = (num_split // target_ep_size) * target_ep_rank + end_idx = (num_split // target_ep_size) * (target_ep_rank + 1) + + hf_keys_list.extend(hf_keys[start_idx:end_idx]) + hf_tensor_list.extend(hf_tensor[start_idx:end_idx]) + + hf_tensor_list = [ + model.param_to_safetensor(safetensor, name) for safetensor, name in zip(hf_tensor_list, hf_keys_list) + ] + + return hf_tensor_list, hf_keys_list + + safetensor_size = 0 + dtype = torch.bfloat16 + tensor_list: list[tuple[torch.Tensor, LoadSpec]] = [] + + for param, load_spec in fused_param_groups: + tensor_size = dtype.itemsize * param.numel() // target_ep_size + if safetensor_size + tensor_size > bucket_size and tensor_list: + hf_params, name_list = _get_hf_params(tensor_list) + yield name_list, hf_params + safetensor_size = tensor_size + name_list = load_spec.hf_keys.copy() + tensor_list = [(param, load_spec)] + continue + safetensor_size += tensor_size + tensor_list.append((param, load_spec)) + + if tensor_list: + hf_params, name_list = _get_hf_params(tensor_list) + yield name_list, hf_params + + @torch.no_grad() + def _update_weights_hf_generator(self, submodule=None, final_update=False): + """Update the model weights.""" + self.endpoints["update_weights"] = "update_weights" + assert self.rollout_device_mesh is not None + + model = self._engine.model + if submodule: 
+ model = getattr(model, submodule) + + dtype = torch.bfloat16 + bucket_size = int(self.config.update_weight_bucket_size_in_gb * 1024**3) + same_gen = model._get_same_hf_param( + model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE, bucket_size=bucket_size + ) + + train_enable_ep = model.fsdp_config is not None and model.fsdp_config.ep_size > 1 + if train_enable_ep: + if self.rollout_cfg_info["ep"] > 1: + fused_gen = self._rl_get_fused_ep_hf_param( + model, + target_ep_rank=self.rollout_device_mesh["engine_parallel"].get_coordinate()[0], + target_ep_size=self.rollout_device_mesh["engine_parallel"].size(), + bucket_size=bucket_size, + ) + else: + fused_gen = self._rl_get_fused_ep_hf_param( + model, + target_ep_rank=0, + target_ep_size=1, + bucket_size=bucket_size, + ) + else: + fused_gen = model._get_fused_hf_param( + model._group_param_by_load_spec(LoadEnum.FUSED), + dtype=dtype, + device=DEVICE, + bucket_size=bucket_size, + update_weights_for_rl=True, + ) + shard_gen = model._get_shard_hf_param( + model._group_param_by_load_spec(LoadEnum.SHARD), dtype=dtype, device=DEVICE, bucket_size=bucket_size + ) + + for name_list, fused_param_list in fused_gen: + state_dict = {name: param.detach() for name, param in zip(name_list, fused_param_list)} + self.request_update_params(state_dict, train_enable_ep=train_enable_ep, finished=False) + del state_dict, name_list, fused_param_list + + for name_list, param_list in chain(same_gen, shard_gen): + state_dict = {name: param.detach() for name, param in zip(name_list, param_list)} + self.request_update_params(state_dict, train_enable_ep=train_enable_ep, finished=False) + del state_dict, name_list, param_list + + if self.rollout_cfg_info["backend"] == "pytorch" and final_update: + self.request_update_params({}, train_enable_ep=train_enable_ep, finished=True) + + if self.is_train_rollout_colocated: + dist.barrier() + else: + dist.barrier(group=self._get_train_update_sync_group()) + DEVICE_MODULE.empty_cache() + return + + def _update_weights_by_layer(self): + """Update the model weights.""" + self.endpoints["update_weights"] = "update_weights" + assert self.rollout_device_mesh is not None + + model = self._engine.model + DEVICE_MODULE.empty_cache() + + if isinstance(model.config, BaseComposeConfig): + # TODO: support float8 for vision compose model + dtype = torch.bfloat16 + else: + if (model.config.float8_cfg is not None) and (model.config.float8_cfg.enable_float8): + dtype = torch.float8_e4m3fn + else: + dtype = torch.bfloat16 + + def get_params(tensor_list, name_list, save_dtype): + _tensor_list, _spec_list = list(zip(*tensor_list)) + fsdp_unshard_tensor_list = model._fsdp_foreach_allgather(_tensor_list, _spec_list) + if save_dtype == torch.float8_e4m3fn: + fsdp_unshard_tensor_list, name_list = model._to_float8( + fsdp_unshard_tensor_list, name_list, _tensor_list, save_dtype + ) + return fsdp_unshard_tensor_list, name_list + + saved_list = [] + is_qwen3vl = False + if isinstance(model.config, BaseComposeConfig): + language_model = model.language_model + if isinstance(model, Qwen3VLForConditionalGeneration): + is_qwen3vl = True + else: + language_model = model + + if is_qwen3vl: + vision_hf_prefix = "model.visual." + projector_hf_prefix = "model.visual." + else: + vision_hf_prefix = "model.vision_tower." + projector_hf_prefix = "model.multi_modal_projector." 
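+        # Walk the language model layer by layer: allgather each layer's FSDP-sharded
+        # parameters, remap them to their HuggingFace checkpoint names (including the
+        # .mlp.experts / .mlp.gate renames below), and push each per-layer bucket to the
+        # rollout engine through request_update_params.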
+ + for i, layer in tqdm.tqdm(language_model.layers.items(), desc="[gather weight]"): + tensor_list = [] + name_list = [] + for sub_name, param in layer.state_dict().items(): + if isinstance(model.config, BaseComposeConfig): + saved_list.append(f"language_model.layers.{i}.{sub_name}") + else: + saved_list.append(f"layers.{i}.{sub_name}") + local_tensor = param._local_tensor if isinstance(param, DTensor) else param + local_tensor = local_tensor.bfloat16() + load_spec = language_model.load_spec_mapping.get(f"layers.{i}.{sub_name}") + + if isinstance(model.config, BaseComposeConfig): + name = f"model.language_model.layers.{i}.{sub_name}" + else: + name = f"model.layers.{i}.{sub_name}" + + if ".experts." in name and ".mlp.experts." not in name: + name = name.replace(".experts.", ".mlp.experts.") + if ".gate." in name and ".mlp.gate." not in name: + name = name.replace(".gate.", ".mlp.gate.") + name_list.append(name) + tensor_list.append((local_tensor, load_spec)) + fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) + state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) + self.request_update_params(state_dict) + + for name, param in model.state_dict().items(): + if name in saved_list: + continue + local_tensor = param._local_tensor if isinstance(param, DTensor) else param + local_tensor = local_tensor.bfloat16() + load_spec = model.load_spec_mapping.get(name) + + if isinstance(model.config, BaseComposeConfig): + if "vision_tower." in name: + name = name.replace("vision_tower.", vision_hf_prefix) + elif "multi_modal_projector." in name: + name = name.replace("multi_modal_projector.", projector_hf_prefix) + elif name == "language_model.norm.weight": + name = "model.language_model.norm.weight" + elif name == "language_model.embed_tokens.weight": + name = "model.language_model.embed_tokens.weight" + elif name == "language_model.lm_head.weight": + name = "lm_head.weight" + else: + if name == "norm.weight": + name = "model.norm.weight" + elif name == "embed_tokens.weight": + name = "model.embed_tokens.weight" + tensor_list = [(local_tensor, load_spec)] + name_list = [name] + fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) + state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) + self.request_update_params(state_dict) + + if self.rollout_cfg_info["backend"] in ("pytorch", "vllm"): + self.request_update_params({}, finished=True) + + dist.barrier() + DEVICE_MODULE.empty_cache() + return + + @staticmethod + def _compute_state_dict_bytes(state_dict: Dict[str, torch.Tensor]) -> int: + total_bytes = 0 + for tensor in state_dict.values(): + total_bytes += tensor.numel() * tensor.element_size() + return total_bytes + + + @staticmethod + def _init_external_process_group( + backend: str | Backend | None = None, + init_method: str | None = None, + timeout: timedelta | None = None, + world_size: int = -1, + rank: int = -1, + store: Store | None = None, + group_name: str | None = None, + pg_options: Any | None = None, + ) -> dist.ProcessGroup: + assert (store is None) or (init_method is None), "Cannot specify both store and init_method." 
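+        # Build the group by hand (rendezvous + _new_process_group_helper) rather than via
+        # dist.init_process_group, so this extra trainer<->rollout weight-update group can be
+        # created without disturbing the default process group of the training job.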
+ if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + backend = Backend(backend) if backend else Backend("undefined") + if timeout is None: + timeout = default_pg_timeout + + if store is None: + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + if group_name is not None: + store = PrefixStore(group_name, store) + + pg_options_param_name = ( + "backend_options" if parse_version(torch.__version__) >= parse_version("2.6") else "pg_options" + ) + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + ) + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + return pg + + @staticmethod + def _create_ipc_tensor(size_in_bytes: int, dtype: torch.dtype): + return torch.empty(size_in_bytes, dtype=torch.uint8, device=DEVICE).view(dtype) + + def _get_sglang_disagg_engine_info(self) -> RolloutEngineInfo: + engine_info: RolloutEngineInfo = [] + seen_urls: set[str] = set() + rank_to_engine_size: dict[int, int] = {} + for engine_ranks in self.rollout_engine_rank_mesh_array: + engine_size = len(engine_ranks) + for rank in engine_ranks: + rank_to_engine_size[int(rank)] = engine_size + + for rank, url in sorted(self.rollout_server_url_dict.items(), key=lambda item: int(item[0])): + rank = int(rank) + if not url or url in seen_urls: + continue + if self.worker_server_urls_status.get(url, False) is False: + continue + seen_urls.add(url) + engine_info.append( + ( + rank, + url, + rank_to_engine_size.get(rank, max(self.rollout_cfg_info["tp"], self.rollout_cfg_info["ep"])), + ) + ) + return engine_info + + def _ensure_sglang_disagg_group(self): + if self._sglang_disagg_group is not None: + return + engine_info = self._get_sglang_disagg_engine_info() + if not engine_info: + self.logger.error("No active rollout engine url, cannot init sglang weight update group") + return + + os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False" + backend = "nccl" + + master_address = None + master_port = None + # get address and port for weight-update + try: + import ray + + master_address = ray.util.get_node_ip_address() + except Exception: + master_address = socket.gethostbyname(socket.gethostname()) + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", 0)) + master_port = int(sock.getsockname()[1]) + + group_name = f"xtuner_sglang_weight_update_{self.rank}" + world_size = sum(engine_size for _, _, engine_size in engine_info) + 1 + + self._sglang_disagg_executor = ThreadPoolExecutor(max_workers=max(1, len(engine_info))) + init_futures = [] + rank_offset = 1 + for _, url, engine_size in engine_info: + payload = { + "master_address": master_address, + "master_port": master_port, + "rank_offset": rank_offset, + "world_size": world_size, + "group_name": group_name, + "backend": backend, + } + init_futures.append( + self._sglang_disagg_executor.submit( + requests.post, + f"{url}/init_weights_update_group", + json=payload, + ) + ) + rank_offset += engine_size + + self._sglang_disagg_group = self._init_external_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name=group_name, + ) + + for init_future in init_futures: + response = 
init_future.result() + response.raise_for_status() + result = response.json() + assert result.get("success", True), ( + f"SGLang init_weights_update_group failed: {result.get('message', result)}" + ) + + self._sglang_disagg_group_name = group_name + self._sglang_disagg_engine_urls = [url for _, url, _ in engine_info] + + def _request_update_params_sglang_disaggregated(self, state_dict): + if not state_dict: + return + + train_sync_group = self._get_train_update_sync_group() + head_rank = 0 + if dist.get_rank() != head_rank: + dist.barrier(group=train_sync_group) + return + + self._ensure_sglang_disagg_group() + if self._sglang_disagg_group is None: + dist.barrier(group=train_sync_group) + return + + assert self._sglang_disagg_executor is not None + assert self._sglang_disagg_group_name is not None + with self._sglang_disagg_update_lock: + try: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket + except Exception as e: + raise RuntimeError( + "Disaggregated update_weights currently only supports sglang builds " + "that provide `sglang.srt.model_executor.model_runner.FlattenedTensorBucket`." + ) from e + + names = list(state_dict.keys()) + tensors = [tensor.detach().to(device=DEVICE, non_blocking=True).contiguous() for tensor in state_dict.values()] + payload = { + "names": names, + "dtypes": [str(tensor.dtype).replace("torch.", "") for tensor in tensors], + "shapes": [list(tensor.shape) for tensor in tensors], + "group_name": self._sglang_disagg_group_name, + "load_format": "flattened_bucket", + } + update_futures = [ + self._sglang_disagg_executor.submit( + requests.post, + f"{url}/update_weights_from_distributed", + json=payload, + ) + for url in self._sglang_disagg_engine_urls + ] + assert self._sglang_disagg_group is not None + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=list(zip(names, tensors))) + flattened_tensor = flattened_tensor_bucket.get_flattened_tensor() + + dist.broadcast(flattened_tensor, src=0, group=self._sglang_disagg_group) + DEVICE_MODULE.synchronize() + for update_future in update_futures: + response = update_future.result() + response.raise_for_status() + result = response.json() + assert result.get("success", True), ( + f"SGLang update_weights_from_distributed failed: {result.get('message', result)}" + ) + dist.barrier(group=train_sync_group) + + @ray_method + def request_update_params(self, state_dict, train_enable_ep=False, finished=False): + """Send a request to update the parameters on the rollout workers. + + This method serializes the state dictionary and sends it to the + appropriate rollout worker via an HTTP request. + + Args: + state_dict (dict | list): The state dictionary containing the model + parameters to update. + train_enable_ep (bool): Whether the training engine enables expert parallelism. + Defaults to False. + finished (bool): A flag indicating whether this is the final + batch of updates. Defaults to False. 
+ """ + + if self.rollout_cfg_info["backend"] == "sglang" and not self.is_train_rollout_colocated: + self._request_update_params_sglang_disaggregated(state_dict) + return + + cpu_mesh = self.rollout_device_mesh["engine_parallel"] + cpu_group = cpu_mesh.get_group() + head_rank = cpu_mesh.mesh[0].item() + if self.rollout_url is None: + self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip") + return + + if self.rollout_cfg_info["backend"] == "vllm": + + def serialize_state_dict(state_dict: dict) -> str: + import base64 + from io import BytesIO + from multiprocessing.reduction import ForkingPickler + + from torch.multiprocessing.reductions import reduce_tensor + + data = [(k, reduce_tensor(v)) for k, v in state_dict.items()] + buf = BytesIO() + ForkingPickler(buf).dump(data) + buf.seek(0) + return base64.b64encode(buf.read()).decode("utf-8") + + serialized_data = [None] * self.rollout_cfg_info["tp"] + dist.gather_object( + serialize_state_dict(state_dict), + serialized_data if dist.get_rank() == head_rank else None, + dst=head_rank, + group=cpu_group, + ) + if dist.get_rank() == head_rank: + headers = { + "Content-Type": "application/json", + } + data_ = json.dumps(dict(serialized_named_tensors=serialized_data, finished=finished)) + data = dict(method="update_weight_npu_ipc", args=[data_]) + response = requests.post(f"{self.rollout_url}/collective_rpc", headers=headers, json=data) + assert response.status_code == 200, f"response.status_code = {response.status_code}" + + if finished: + dist.barrier(group=cpu_group) + return + + if self.rollout_cfg_info["backend"] == "pytorch": + # TODO(chenchiyu): remove lmdeploy related code + from lmdeploy.utils import serialize_state_dict + + try: + from lmdeploy.utils import FlattenedTensorBucket + + use_flattened_tensor_bucket = True + except Exception: + use_flattened_tensor_bucket = False + + if self.rollout_cfg_info["backend"] == "pytorch" and self.rollout_cfg_info["tp"] > 1: + serialized_data = [None] * self.rollout_cfg_info["tp"] + if use_flattened_tensor_bucket and state_dict: + state_dict_bytes = self._compute_state_dict_bytes(state_dict) + send_ipc_tensor = ( + state_dict_bytes > self._ipc_tensor_bytes or self._update_params_ipc_tensor is None + ) + if send_ipc_tensor: + self._ipc_tensor_bytes = max(self._ipc_tensor_bytes, state_dict_bytes) + if self._update_params_ipc_tensor is not None: + self._update_params_ipc_event.wait() + torch.cuda.synchronize() + self._update_params_ipc_tensor = self._create_ipc_tensor( + self._ipc_tensor_bytes, state_dict[next(iter(state_dict))].dtype + ) + else: + self._update_params_ipc_event.wait() + + flattened_tensor_bucket = FlattenedTensorBucket( + named_tensors=list(state_dict.items()), + flattened_tensor=self._update_params_ipc_tensor, + ) + metadata = flattened_tensor_bucket.get_metadata() + flattened_tensor_data = { + "metadata": metadata, + "require_clone": False, + } + self._update_params_ipc_event.record() + + if send_ipc_tensor: + flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor() + flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle() + tp_serialized_data = serialize_state_dict(flattened_tensor_data) + else: + tp_serialized_data = serialize_state_dict(state_dict) + dist.gather_object( + tp_serialized_data, + serialized_data if dist.get_rank() == head_rank else None, + dst=head_rank, + group=cpu_group, + ) + elif self.rollout_cfg_info["backend"] == "pytorch": + if use_flattened_tensor_bucket: + 
state_dict_bytes = self._compute_state_dict_bytes(state_dict) + send_ipc_tensor = ( + state_dict_bytes > self._ipc_tensor_bytes or self._update_params_ipc_tensor is None + ) + if send_ipc_tensor: + self._ipc_tensor_bytes = max(self._ipc_tensor_bytes, state_dict_bytes) + if self._update_params_ipc_tensor is not None: + # wait previous ipc event recorded of lmdeploy + self._update_params_ipc_event.wait() + torch.cuda.synchronize() + self._update_params_ipc_tensor = self._create_ipc_tensor( + self._ipc_tensor_bytes, state_dict[next(iter(state_dict))].dtype + ) + else: + self._update_params_ipc_event.wait() + + flattened_tensor_bucket = FlattenedTensorBucket( + named_tensors=list(state_dict.items()), + flattened_tensor=self._update_params_ipc_tensor, + ) + metadata = flattened_tensor_bucket.get_metadata() + flattened_tensor_data = { + "metadata": metadata, + "require_clone": False, + } + self._update_params_ipc_event.record() + + if send_ipc_tensor: + flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor() + flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle() + serialized_data = serialize_state_dict(flattened_tensor_data) + else: + serialized_data = serialize_state_dict(state_dict) + else: + # for turbomind backend, only head_rank should serialize data + serialized_data = serialize_state_dict(state_dict) if dist.get_rank() == head_rank else None + else: + # sglang + from sglang.srt.utils import MultiprocessingSerializer + from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions + + try: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket + + use_flattened_tensor_bucket = True + except Exception: + use_flattened_tensor_bucket = False + + # NOTE: xtuner目前去掉sglang的patch也不会出问题,但为了保险起见,还是保留patch逻辑,并且在update_weights结束后unpatch + monkey_patch_torch_reductions() + state_dict = state_dict.items() + if self.rollout_cfg_info["tp"] == 1: + if use_flattened_tensor_bucket: + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_dict) + metadata = flattened_tensor_bucket.get_metadata() + + flattened_tensor_data = { + "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), + "metadata": metadata, + } + serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) + else: + serialized_data = MultiprocessingSerializer.serialize(state_dict, output_str=True) + + serialized_data = [serialized_data] + else: + serialized_data = [None] * self.rollout_cfg_info["tp"] + if use_flattened_tensor_bucket: + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_dict) + metadata = flattened_tensor_bucket.get_metadata() + + flattened_tensor_data = { + "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), + "metadata": metadata, + } + tp_serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) + dist.gather_object( + tp_serialized_data, + serialized_data if dist.get_rank() == head_rank else None, + dst=head_rank, + group=cpu_group, + ) + else: + tp_serialized_data = MultiprocessingSerializer.serialize(state_dict, output_str=True) + dist.gather_object( + tp_serialized_data, + serialized_data if dist.get_rank() == head_rank else None, + dst=head_rank, + group=cpu_group, + ) + + if dist.get_rank() == head_rank: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.rollout_cfg_info['api_key']}", + } + if self.rollout_cfg_info["backend"] == "sglang": + payload = { + 
"serialized_named_tensors": serialized_data, + "flush_cache": False, + } + try: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket + + use_flattened_tensor_bucket = True + except Exception: + use_flattened_tensor_bucket = False + if use_flattened_tensor_bucket: + payload["load_format"] = "flattened_bucket" + + url = f"{self.rollout_url}/update_weights_from_tensor" + response = requests.post(url, json=payload or {}) + response.raise_for_status() + else: + data = dict(serialized_named_tensors=serialized_data, finished=finished) + try: + from lmdeploy.utils import FlattenedTensorBucket + + use_flattened_tensor_bucket = True + except Exception: + use_flattened_tensor_bucket = False + + if use_flattened_tensor_bucket: + data["load_format"] = "flattened_bucket" + response = requests.post( + f"{self.rollout_url}/{self.endpoints['update_weights']}", headers=headers, json=data + ) + assert response.status_code == 200, f"response.status_code = {response.status_code}" + + # TODO(chenchiyu): narrow this condition + if finished or (self.rollout_cfg_info["backend"] == "pytorch" and train_enable_ep and self.rollout_cfg_info["tp"] > 1): + # This barrier is aim to make each tp head rank sync with other ranks in engine_parallel group + # which could not be barrier by `fsdp_foreach_allgather` of the next state dict. (Happens in same_gen, shard not tested) + # Without barrier, some ranks in engine_parallel group would not wait for current iter data ipc event recording in lmdeploy. + # They would write next iter state_dict into the ipc tensor before lmdeploy load current iter weight. + dist.barrier(group=cpu_group) + + monkey_unpatch_torch_reductions() + return diff --git a/xtuner/v1/rl/trainer/worker.py b/xtuner/v1/rl/trainer/worker.py index fd63d4085..868b68df5 100644 --- a/xtuner/v1/rl/trainer/worker.py +++ b/xtuner/v1/rl/trainer/worker.py @@ -56,6 +56,7 @@ from ..rollout_is import merge_rollout_is_metrics +from .update_weighter import UpdateWeighter DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping service names to their URLs @@ -63,30 +64,6 @@ DEVICE_MODULE = get_torch_device_module() -def serialize_state_dict(state_dict: dict) -> str: - """Serialize state dict to str. - - The consumer should use it on same node. As the producer and consumer may - have different GPU visibility, we use reduce_tensor instead of ForkingPickler.dumps - to fix the device_id when loading the serialized tensor. - - Args: - state_dict (dict[str, torch.Tensor]): state dict to serialize. - Returns: - str: serialized state dict. 
- """ - import base64 - from io import BytesIO - from multiprocessing.reduction import ForkingPickler - - from torch.multiprocessing.reductions import reduce_tensor - - data = [(k, reduce_tensor(v)) for k, v in state_dict.items()] - buf = BytesIO() - ForkingPickler(buf).dump(data) - buf.seek(0) - return base64.b64encode(buf.read()).decode("utf-8") - def calculate_entropy( shifted_labels_list: Sequence[torch.Tensor], @@ -213,7 +190,7 @@ class WorkerLogItem(TypedDict): sft_train_metrics: NotRequired[dict[str, float]] -class TrainingWorker(SingleAcceleratorWorker): +class TrainingWorker(SingleAcceleratorWorker, UpdateWeighter): _SAVE_OPTIMIZER_DIR = "optimizer" _SAVE_MODEL_DIR = "model" _SAVE_SFT_DATALOADER_DIR = "sft_dataloader" @@ -265,12 +242,6 @@ def __init__( self._optimizer_steps = worker_cfg.optimizer_steps - # Used to update weight to rollout engine - self.rollout_device_mesh: DeviceMesh | None = None - self.rollout_url: str | None = None - self.rollout_cfg_info: dict = dict() - self.endpoints: dict[str, str] = dict() - self.endpoints["update_weights"] = "update_weights" if worker_cfg.loss_cfg.chunk_size is not None: mode = "chunk" else: @@ -280,6 +251,8 @@ def __init__( if isinstance(worker_cfg.model_cfg, BaseComposeConfig): if hasattr(worker_cfg.model_cfg.text_config, "mtp_config"): self.mtp_config = worker_cfg.model_cfg.text_config.mtp_config + + self._init_update_weighter() def _init_sft(self, worker_cfg: WorkerConfig): self._sft_dataloader_config = worker_cfg.sft_dataloader_cfg @@ -909,606 +882,6 @@ def onload_model(self): def onload_optimizer(self): self._engine.put_optimizer_to_device(DEVICE) - @ray_method - def update_rollout_info( - self, - engine_rank_mesh_array: DeviceMeshRaw, - server_url_dict: ServiceUrlMap, - rollout_config: RolloutConfig, - worker_server_urls_status: Dict[str, bool], - api_server_url: str | None = None, - ): - """Update the rollout information for the training worker.""" - tp = rollout_config.tensor_parallel_size - ep = rollout_config.expert_parallel_size - assert tp == 1 or ep == 1, "Either tensor parallel size or engine parallel size must be 1." 
- if self.rollout_device_mesh is None: - self.rollout_device_mesh = DeviceMesh( - "cpu", mesh=engine_rank_mesh_array, mesh_dim_names=("engine_instance", "engine_parallel") - ) - rollout_server_url = server_url_dict.get(self.rank, "") - if worker_server_urls_status.get(rollout_server_url, "False") is False: - self.logger.error(f"Rollout server url {rollout_server_url} is not available.") - self.rollout_url = None - else: - self.rollout_url = rollout_server_url - self.rollout_cfg_info["tp"] = tp - self.rollout_cfg_info["ep"] = ep - self.rollout_cfg_info["api_key"] = rollout_config.api_key - if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": - self.rollout_cfg_info["backend"] = "sglang" - elif os.environ.get("XTUNER_USE_VLLM", "0") == "1": - self.rollout_cfg_info["backend"] = "vllm" - else: - self.rollout_cfg_info["backend"] = (rollout_config.extra_rollout_config or dict()).get( - "lmdeploy_backend", "pytorch" - ) - - @ray_method - def update_weights(self): - """Update the model weights.""" - if self.rollout_cfg_info.get("backend") == "turbomind": - self._update_weights_by_layer() - else: - if isinstance(self.config.model_cfg, BaseComposeConfig): - self._update_weights_hf_generator(submodule="vision_tower", final_update=False) - self._update_weights_hf_generator(submodule="multi_modal_projector", final_update=False) - self._update_weights_hf_generator(submodule="language_model", final_update=True) - else: - self._update_weights_hf_generator() - - def _update_weights_hf_generator(self, submodule=None, final_update=True): - """Update the model weights.""" - self.endpoints["update_weights"] = "update_weights" - assert self.rollout_device_mesh is not None - - model = self._engine.model - if submodule: - model = getattr(model, submodule) - - DEVICE_MODULE.empty_cache() - - # TODO: force bfloat16 dtype for now - dtype = torch.bfloat16 - - bucket_size = int(self.config.update_weight_bucket_size_in_gb * 1024**3) - same_gen = model._get_same_hf_param( - model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE, bucket_size=bucket_size - ) - fused_gen = model._get_fused_hf_param( - model._group_param_by_load_spec(LoadEnum.FUSED), - dtype=dtype, - device=DEVICE, - bucket_size=bucket_size, - update_weights_for_rl=True, - ) - shard_gen = model._get_shard_hf_param( - model._group_param_by_load_spec(LoadEnum.SHARD), dtype=dtype, device=DEVICE, bucket_size=bucket_size - ) - - for name_list, fused_param_list in fused_gen: - state_dict = {name: param.detach() for name, param in zip(name_list, fused_param_list)} - if model.fsdp_config.ep_size > 1: - # When ep_size > 1, generator generates part of the fused param on each ep rank in one ep_group. - # We can all gather them to get full fused param but it would lead to a larger memory usage. - # So we broadcast the part fused param from each ep rank in ep_group sequentially, - # and update the part of the fused param sequentially to reduce memory usage. 
- if isinstance(model.config, BaseComposeConfig): - ep_mesh: DeviceMesh = model.language_model.ep_mesh - else: - ep_mesh: DeviceMesh = model.ep_mesh - ep_group = ep_mesh.get_group() - global_rank = dist.get_rank() - for src_global_rank in dist.get_process_group_ranks(ep_group): - broadcast_state_dict = dict() - for key, tensor in state_dict.items(): - obj_to_broadcast = [key, tensor.to("meta")] if global_rank == src_global_rank else [None, None] - dist.broadcast_object_list(obj_to_broadcast, src=src_global_rank, group=ep_group) - real_key, meta_tensor = obj_to_broadcast - buffer = ( - state_dict[real_key] - if global_rank == src_global_rank - else torch.empty_like(meta_tensor, device=DEVICE) - ) - dist.broadcast(buffer, src=src_global_rank, group=ep_group) - broadcast_state_dict[real_key] = buffer - self.request_update_params(broadcast_state_dict, finished=False) - del broadcast_state_dict, buffer - else: - self.request_update_params(state_dict, finished=False) - del state_dict, name_list, fused_param_list - - for name_list, param_list in chain(same_gen, shard_gen): - state_dict = {name: param.detach() for name, param in zip(name_list, param_list)} - self.request_update_params(state_dict, finished=False) - del state_dict, name_list, param_list - - if self.rollout_cfg_info["backend"] in ("pytorch", "vllm") and final_update: - self.request_update_params({}, finished=True) - - dist.barrier() - DEVICE_MODULE.empty_cache() - return - - def _update_weights_by_layer(self): - """Update the model weights.""" - self.endpoints["update_weights"] = "update_weights" - assert self.rollout_device_mesh is not None - - model = self._engine.model - DEVICE_MODULE.empty_cache() - - if isinstance(model.config, BaseComposeConfig): - # TODO: support float8 for vision compose model - dtype = torch.bfloat16 - else: - if (model.config.float8_cfg is not None) and (model.config.float8_cfg.enable_float8): - dtype = torch.float8_e4m3fn - else: - dtype = torch.bfloat16 - - def get_params(tensor_list, name_list, save_dtype): - _tensor_list, _spec_list = list(zip(*tensor_list)) - fsdp_unshard_tensor_list = model._fsdp_foreach_allgather(_tensor_list, _spec_list) - if save_dtype == torch.float8_e4m3fn: - fsdp_unshard_tensor_list, name_list = model._to_float8( - fsdp_unshard_tensor_list, name_list, _tensor_list, save_dtype - ) - return fsdp_unshard_tensor_list, name_list - - saved_list = [] - is_qwen3vl = False - if isinstance(model.config, BaseComposeConfig): - language_model = model.language_model - if isinstance(model, Qwen3VLForConditionalGeneration): - is_qwen3vl = True - else: - language_model = model - - if is_qwen3vl: - vision_hf_prefix = "model.visual." - projector_hf_prefix = "model.visual." - else: - vision_hf_prefix = "model.vision_tower." - projector_hf_prefix = "model.multi_modal_projector." - - for i, layer in tqdm.tqdm(language_model.layers.items(), desc="[gather weight]"): - tensor_list = [] - name_list = [] - for sub_name, param in layer.state_dict().items(): - if isinstance(model.config, BaseComposeConfig): - saved_list.append(f"language_model.layers.{i}.{sub_name}") - else: - saved_list.append(f"layers.{i}.{sub_name}") - local_tensor = param._local_tensor if isinstance(param, DTensor) else param - local_tensor = local_tensor.bfloat16() - load_spec = language_model.load_spec_mapping.get(f"layers.{i}.{sub_name}") - - if isinstance(model.config, BaseComposeConfig): - name = f"model.language_model.layers.{i}.{sub_name}" - else: - name = f"model.layers.{i}.{sub_name}" - - if ".experts." 
in name and ".mlp.experts." not in name: - name = name.replace(".experts.", ".mlp.experts.") - if ".gate." in name and ".mlp.gate." not in name: - name = name.replace(".gate.", ".mlp.gate.") - name_list.append(name) - tensor_list.append((local_tensor, load_spec)) - fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) - state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) - self.request_update_params(state_dict) - - for name, param in model.state_dict().items(): - if name in saved_list: - continue - local_tensor = param._local_tensor if isinstance(param, DTensor) else param - local_tensor = local_tensor.bfloat16() - load_spec = model.load_spec_mapping.get(name) - - if isinstance(model.config, BaseComposeConfig): - if "vision_tower." in name: - name = name.replace("vision_tower.", vision_hf_prefix) - elif "multi_modal_projector." in name: - name = name.replace("multi_modal_projector.", projector_hf_prefix) - elif name == "language_model.norm.weight": - name = "model.language_model.norm.weight" - elif name == "language_model.embed_tokens.weight": - name = "model.language_model.embed_tokens.weight" - elif name == "language_model.lm_head.weight": - name = "lm_head.weight" - else: - if name == "norm.weight": - name = "model.norm.weight" - elif name == "embed_tokens.weight": - name = "model.embed_tokens.weight" - tensor_list = [(local_tensor, load_spec)] - name_list = [name] - fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) - state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) - self.request_update_params(state_dict) - - if self.rollout_cfg_info["backend"] in ("pytorch", "vllm"): - self.request_update_params({}, finished=True) - - dist.barrier() - DEVICE_MODULE.empty_cache() - return - - # def update_weights1(self): - # """Update the model weights.""" - # self.endpoints["update_weights"] = "update_weights" - # assert self.rollout_device_mesh is not None - # time1 = time.time() - - # model = self._engine.model - # DEVICE_MODULE.empty_cache() - - # if (model.config.float8_cfg is not None) and (model.config.float8_cfg.enable_float8): - # dtype = torch.float8_e4m3fn - # else: - # dtype = torch.bfloat16 - - # fused_params = [] - # for name, param in model.state_dict().items(): - # load_spec = model.load_spec_mapping.get(name) - # if load_spec.load_enum == LoadEnum.FUSED: - # fused_params.append((name, param, load_spec)) - - # # TODO: decouple update_weights from the model structure - # bucket_size = 1024**3 - # safetensor_size = 0 - # tensor_list: list[tuple[torch.Tensor, LoadSpec]] = [] - # name_list: list[str] = [] - # for name, param, load_spec in fused_params: - # local_tensor = param._local_tensor if isinstance(param, DTensor) else param - # local_tensor = local_tensor.bfloat16() - # if safetensor_size + model._get_tensor_size(param, dtype) > bucket_size: - # _tensor_list, _spec_list = list(zip(*tensor_list)) - # fsdp_unshard_tensor_list = model._fsdp_foreach_allgather(_tensor_list, _spec_list) - # if dtype == torch.float8_e4m3fn: - # fsdp_unshard_tensor_list, name_list = model._to_float8( - # fsdp_unshard_tensor_list, name_list, _tensor_list, dtype - # ) - # state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) - # self.request_update_params(state_dict) - # safetensor_size = 0 - # tensor_list = [(local_tensor, load_spec)] - # name_list = ["model." 
+ name.replace(".experts.", ".mlp.experts.")] - # continue - # safetensor_size += model._get_tensor_size(param, dtype) - # tensor_list.append((local_tensor, load_spec)) - # name_list.append("model." + name.replace(".experts.", ".mlp.experts.")) - - # if tensor_list: - # assert len(name_list) == len(tensor_list) - # _tensor_list, _spec_list = list(zip(*tensor_list)) - # fsdp_unshard_tensor_list = model._fsdp_foreach_allgather(_tensor_list, _spec_list) - # if dtype == torch.float8_e4m3fn: - # fsdp_unshard_tensor_list, name_list = model._to_float8( - # fsdp_unshard_tensor_list, name_list, _tensor_list, dtype - # ) - # state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) - # self.request_update_params(state_dict) - - # same_gen = model._get_same_hf_param( - # model._group_param_by_load_spec(LoadEnum.SAME), - # dtype=dtype, - # device="cuda", - # bucket_size=1024**3, - # ) - # for name_list, gathered_tensor_list in tqdm.tqdm(same_gen, desc="[update dense weights]"): - # state_dict = dict(zip(name_list, gathered_tensor_list)) - # self.request_update_params(state_dict) - # del state_dict - - # self.request_update_params({}, finished=True) - - # dist.barrier() - # logger.info(f"update weights time: {time.time() - time1}") - # DEVICE_MODULE.empty_cache() - # return - - # def update_weights(self): - # """Update the model weights.""" - # self.endpoints["update_weights"] = "update_weights" - # assert self.rollout_device_mesh is not None - - # model = self._engine.model - # DEVICE_MODULE.empty_cache() - - # saved_keys = [] - # gather_duration = [] - # weight_duration = [] - # reshard_duration = [] - - # # update decoder layers - # for i, layer in tqdm.tqdm(model.layers.items(), desc="[gather weight]"): - # start = time.perf_counter() - # layer.unshard() - # layer_state_dict = {} - - # for sub_name, param in layer.named_parameters(): - # if "_checkpoint_wrapped_module." in sub_name: - # sub_name = sub_name.replace("_checkpoint_wrapped_module.", "") - # if isinstance(param, DTensor): - # param = param.to_local() - - # if isinstance(param, WeightWithDynamicTilewiseFloat8CastTensor): - # param = param._tensor - - # if isinstance(param, Float8Tensor): - # scale_name = f"model.layers.{i}.{sub_name}_scale_inv" - # assert "fused_w1w3" in sub_name or "fused_w2" in sub_name - # # save scale_inv parameter to state_dict - # scale_tensor = param._scale - # quant_tensor = param._data - # ep_mesh = model.ep_mesh - # if ep_mesh.size() > 1: - # scale_tensor = torch.cat(dist.nn.all_gather(scale_tensor, group=ep_mesh.get_group()), dim=0) - # quant_tensor = torch.cat(dist.nn.all_gather(quant_tensor, group=ep_mesh.get_group()), dim=0) - # layer_state_dict[scale_name] = scale_tensor.detach() - # # set `param` which will be added to state_dict at the bottom of the for-block - # param = quant_tensor - - # param = param.to(DEVICE) - # name = f"model.layers.{i}.{sub_name}" - # saved_keys.append(name.replace("model.", "")) - # if ".experts." in name and ".mlp." not in name: - # name = name.replace(".experts.", ".mlp.experts.") - # if ".gate." in name and ".mlp." 
not in name: - # name = name.replace(".gate.", ".mlp.gate.") - # layer_state_dict[name] = param.detach() - # gather_duration.append(time.perf_counter() - start) - # start = time.perf_counter() - # self.request_update_params(layer_state_dict, finished=True) - # breakpoint() - # weight_duration.append(time.perf_counter() - start) - - # start = time.perf_counter() - # del layer_state_dict - # layer.reshard() - # reshard_duration.append(time.perf_counter() - start) - - # if dist.get_rank() == 0: - # logger.debug( - # f"Rank 0 Gather decoder layers done, total {sum(gather_duration):.2f}s, avg " - # f"{sum(gather_duration) / len(gather_duration):.2f}s" - # ) - # logger.debug( - # f"Rank 0 migrate/save decoder layers done, total {sum(weight_duration):.2f}s, avg " - # f"{sum(weight_duration) / len(weight_duration):.2f}s" - # ) - # logger.debug( - # f"Rank 0 reshard decoder layers done, total {sum(reshard_duration):.2f}s, avg " - # f"{sum(reshard_duration) / len(reshard_duration):.2f}s" - # ) - - # # update other params - # model.norm.unshard() - # model.lm_head.unshard() - # model.embed_tokens.unshard() - # others_state_dict = {} - # for name, param in model.named_parameters(): - # if "_checkpoint_wrapped_module." in name: - # continue - # if name not in saved_keys: - # saved_keys.append(name) - # if name == "norm.weight": - # name = "model.norm.weight" - # if name == "embed_tokens.weight": - # name = "model.embed_tokens.weight" - # if isinstance(param, DTensor): - # param = param.to_local() - # others_state_dict[name] = param.detach() - # self.request_update_params(others_state_dict, finished=True) - # model.norm.reshard() - # model.lm_head.reshard() - # model.embed_tokens.reshard() - # del others_state_dict - # del param - - # dist.barrier() - # DEVICE_MODULE.empty_cache() - # return - - @ray_method - def request_update_params(self, state_dict, finished=False): - """Send a request to update the parameters on the rollout workers. - - This method serializes the state dictionary and sends it to the - appropriate rollout worker via an HTTP request. - - Args: - state_dict (dict | list): The state dictionary containing the model - parameters to update. - finished (bool): A flag indicating whether this is the final - batch of updates. Defaults to False. 
- """ - cpu_mesh = self.rollout_device_mesh["engine_parallel"] - cpu_group = cpu_mesh.get_group() - head_rank = cpu_mesh.mesh[0].item() - if self.rollout_url is None: - self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip") - return - - if self.rollout_cfg_info["backend"] == "vllm": - - def serialize_state_dict(state_dict: dict) -> str: - import base64 - from io import BytesIO - from multiprocessing.reduction import ForkingPickler - - from torch.multiprocessing.reductions import reduce_tensor - - data = [(k, reduce_tensor(v)) for k, v in state_dict.items()] - buf = BytesIO() - ForkingPickler(buf).dump(data) - buf.seek(0) - return base64.b64encode(buf.read()).decode("utf-8") - - serialized_data = [None] * self.rollout_cfg_info["tp"] - dist.gather_object( - serialize_state_dict(state_dict), - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - if dist.get_rank() == head_rank: - headers = { - "Content-Type": "application/json", - } - data_ = json.dumps(dict(serialized_named_tensors=serialized_data, finished=finished)) - data = dict(method="update_weight_npu_ipc", args=[data_]) - response = requests.post(f"{self.rollout_url}/collective_rpc", headers=headers, json=data) - assert response.status_code == 200, f"response.status_code = {response.status_code}" - - if finished: - dist.barrier(group=cpu_group) - return - - if self.rollout_cfg_info["backend"] == "pytorch": - # TODO(chenchiyu): remove lmdeploy related code - from lmdeploy.utils import serialize_state_dict - - try: - from lmdeploy.utils import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - - if self.rollout_cfg_info["backend"] == "pytorch" and self.rollout_cfg_info["tp"] > 1: - serialized_data = [None] * self.rollout_cfg_info["tp"] - if use_flattened_tensor_bucket: - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=list(state_dict.items())) - metadata = flattened_tensor_bucket.get_metadata() - flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - "metadata": metadata, - } - tp_serialized_data = serialize_state_dict(flattened_tensor_data) - else: - tp_serialized_data = serialize_state_dict(state_dict) - dist.gather_object( - tp_serialized_data, - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - elif self.rollout_cfg_info["backend"] == "pytorch": - if use_flattened_tensor_bucket: - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=list(state_dict.items())) - metadata = flattened_tensor_bucket.get_metadata() - flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - "metadata": metadata, - } - serialized_data = serialize_state_dict(flattened_tensor_data) - else: - serialized_data = serialize_state_dict(state_dict) - else: - # for turbomind backend, only head_rank should serialize data - serialized_data = serialize_state_dict(state_dict) if dist.get_rank() == head_rank else None - else: - # sglang - from sglang.srt.utils import MultiprocessingSerializer - from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions - - try: - from sglang.srt.model_executor.model_runner import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - - # NOTE: xtuner目前去掉sglang的patch也不会出问题,但为了保险起见,还是保留patch逻辑,并且在update_weights结束后unpatch - monkey_patch_torch_reductions() - 
state_dict = state_dict.items() - if self.rollout_cfg_info["tp"] == 1: - if use_flattened_tensor_bucket: - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_dict) - metadata = flattened_tensor_bucket.get_metadata() - - flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - "metadata": metadata, - } - serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) - else: - serialized_data = MultiprocessingSerializer.serialize(state_dict, output_str=True) - - serialized_data = [serialized_data] - else: - serialized_data = [None] * self.rollout_cfg_info["tp"] - if use_flattened_tensor_bucket: - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_dict) - metadata = flattened_tensor_bucket.get_metadata() - - flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - "metadata": metadata, - } - tp_serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) - dist.gather_object( - tp_serialized_data, - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - else: - tp_serialized_data = MultiprocessingSerializer.serialize(state_dict, output_str=True) - dist.gather_object( - tp_serialized_data, - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - - if dist.get_rank() == head_rank: - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.rollout_cfg_info['api_key']}", - } - if self.rollout_cfg_info["backend"] == "sglang": - payload = { - "serialized_named_tensors": serialized_data, - "flush_cache": False, - } - try: - from sglang.srt.model_executor.model_runner import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - if use_flattened_tensor_bucket: - payload["load_format"] = "flattened_bucket" - - url = f"{self.rollout_url}/update_weights_from_tensor" - response = requests.post(url, json=payload or {}) - response.raise_for_status() - else: - data = dict(serialized_named_tensors=serialized_data, finished=finished) - try: - from lmdeploy.utils import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - - if use_flattened_tensor_bucket: - data["load_format"] = "flattened_bucket" - response = requests.post( - f"{self.rollout_url}/{self.endpoints['update_weights']}", headers=headers, json=data - ) - assert response.status_code == 200, f"response.status_code = {response.status_code}" - - if finished: - dist.barrier(group=cpu_group) - - monkey_unpatch_torch_reductions() - return - @ray_method def save(self, checkpoint_path: Path | str, no_save_optimizer: bool = False): """Save the DCP checkpoint of the training worker.""" diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 9245e9f18..bb599d6c7 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -898,6 +898,8 @@ def __init__(self, cfg: RLColocateTrainerConfig): if self._debug_rollout: self.logger.warning("Debug rollout mode is enabled, rollout will not be offloaded.") + self.train_controller.set_train_rollout_mode("colocate") + def _sync_weights_from_train_workers(self) -> None: self.logger.info("Rollout workers skip load weights, update weights from train workers.") ray.get(self.rollout_controller.offload.remote()) @@ -1032,6 +1034,7 @@ def __init__(self, cfg: 
RLDisaggregatedTrainerConfig): "Debug rollout mode is enabled. Disaggregated training keeps rollout workers resident." ) + self.train_controller.set_train_rollout_mode("disaggregated") def _build_disaggregated_placement_groups( self, train_resources: AcceleratorResourcesConfig, @@ -1155,5 +1158,7 @@ async def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict): self.fake_update_weights() def fake_update_weights(self): + ray.get(self.rollout_controller.pause_generation.remote()) self.train_controller.update_weights() + ray.get(self.rollout_controller.continue_generation.remote()) self.logger.info("Rollout workers updated weights through fake disaggregated sync.")
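Editor's note on the removed expert-parallel sync path (the first removal hunk above): each EP rank takes a turn as broadcast source, first publishing (name, meta-tensor) pairs so peers can allocate receive buffers, then broadcasting the actual weights, and finally forwarding the gathered shard to the rollout side. A minimal sketch of that pattern, assuming an initialized process group; `send` is a stand-in for request_update_params, not a real API:

import torch
import torch.distributed as dist


def broadcast_ep_state_dict(state_dict, ep_group, send, device="cuda"):
    """Round-robin every EP rank as source so each expert shard is pushed exactly once."""
    rank = dist.get_rank()
    for src in dist.get_process_group_ranks(ep_group):
        gathered = {}
        # Assumes every EP rank holds the same number of entries, as in the removed code.
        for key, tensor in state_dict.items():
            # Publish (name, meta tensor) first so receivers know shape and dtype.
            obj = [key, tensor.to("meta")] if rank == src else [None, None]
            dist.broadcast_object_list(obj, src=src, group=ep_group)
            name, meta = obj
            buf = state_dict[name] if rank == src else torch.empty_like(meta, device=device)
            dist.broadcast(buf, src=src, group=ep_group)
            gathered[name] = buf
        send(gathered)  # e.g. request_update_params(gathered, finished=False)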
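Editor's note on the removed _update_weights_by_layer path: before a shard is sent, every parameter name is rewritten into the HF checkpoint key the rollout engine expects (a "model." or "model.language_model." prefix, with MoE experts and router gate moved under ".mlp."). A hypothetical helper illustrating that mapping; the real code builds these names inline per layer:

def to_hf_key(name: str, is_compose_model: bool) -> str:
    """Map an internal parameter name to the HF-style key expected by the rollout engine."""
    prefix = "model.language_model." if is_compose_model else "model."
    key = prefix + name  # e.g. "layers.0.attn.qkv.weight" -> "model.layers.0.attn.qkv.weight"
    # MoE experts and the router gate live under `.mlp.` in the HF layout.
    if ".experts." in key and ".mlp.experts." not in key:
        key = key.replace(".experts.", ".mlp.experts.")
    if ".gate." in key and ".mlp.gate." not in key:
        key = key.replace(".gate.", ".mlp.gate.")
    return key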
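Editor's note on the removed request_update_params sglang branch: each training rank serializes its shard with MultiprocessingSerializer, the tensor-parallel group gathers the payloads onto the head rank, and the head rank posts them to the rollout server's update_weights_from_tensor endpoint. A condensed sketch of that path, omitting the optional FlattenedTensorBucket packing and the single-TP shortcut; rollout_url, cpu_group, head_rank, and tp are placeholders for the worker's own state:

import requests
import torch.distributed as dist
from sglang.srt.utils import MultiprocessingSerializer


def push_weights_to_sglang(state_dict, rollout_url, cpu_group, head_rank, tp):
    """Gather per-rank serialized shards on the head rank and send one HTTP update."""
    payload = MultiprocessingSerializer.serialize(list(state_dict.items()), output_str=True)
    gathered = [None] * tp if dist.get_rank() == head_rank else None
    dist.gather_object(payload, gathered, dst=head_rank, group=cpu_group)
    if dist.get_rank() == head_rank:
        resp = requests.post(
            f"{rollout_url}/update_weights_from_tensor",
            json={"serialized_named_tensors": gathered, "flush_cache": False},
        )
        resp.raise_for_status()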
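Editor's note on the rl_trainer.py hunks, which carry this PR's behavioural change: each trainer constructor now declares its mode to the training controller via set_train_rollout_mode ("colocate" or "disaggregated"), and fake_update_weights pauses rollout generation before pushing weights so in-flight requests never sample from a half-updated model. A sketch of the intended call order, taking the controller handles as given:

import ray


def sync_weights(train_controller, rollout_controller, logger):
    """Pause rollout, push the latest trainer weights, then resume generation."""
    ray.get(rollout_controller.pause_generation.remote())    # stop scheduling new requests
    train_controller.update_weights()                         # blocking weight push from train workers
    ray.get(rollout_controller.continue_generation.remote())  # serve again with refreshed weights
    logger.info("Rollout workers updated weights.")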