diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
index fb374472a0..cb2449a961 100644
--- a/xtuner/v1/rl/base/worker.py
+++ b/xtuner/v1/rl/base/worker.py
@@ -32,6 +32,7 @@
 from xtuner.v1.model.base import ModelItem, TransformerConfig
 from xtuner.v1.model.compose.base import BaseComposeConfig, BaseComposeModel
 from xtuner.v1.model.compose.qwen3_vl import Qwen3VLForConditionalGeneration
+from xtuner.v1.model.moe.moe import MoE
 from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
 from xtuner.v1.ray.base import SingleAcceleratorWorker
 from xtuner.v1.ray.config import RolloutConfig
@@ -46,7 +47,7 @@
     monkey_unpatch_torch_reductions,
     ray_method,
 )
-from xtuner.v1.utils.load_spec import LoadEnum
+from xtuner.v1.utils.load_spec import LoadEnum, LoadSpec
 from ..loss_fn import kl_penalty
 from .loss import BaseRLLossConfig
@@ -252,6 +253,10 @@ def __init__(
         else:
             mode = "eager"
         self.logprob_cfg = LogProbConfig(chunk_size=worker_cfg.loss_cfg.chunk_size, mode=mode)
+        self._global_hf_keys_mapping_cache: dict[str, list[str]] = dict()
+        self._ipc_tensor_bytes: int = int(self.config.update_weight_bucket_size_in_gb * 1024**3)
+        self._update_params_ipc_tensor = None
+        self._update_params_ipc_event = None

     def _init_sft(self, worker_cfg: WorkerConfig):
         self._sft_dataloader_config = worker_cfg.sft_dataloader_cfg
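The four fields added to `__init__` above set up a reusable CUDA-IPC staging buffer and an interprocess event that later hunks use to hand weight buckets to the rollout engine. Below is a minimal sketch of that single-buffer, single-event ping-pong, with illustrative names and sizes; the consumer side stands in for lmdeploy, whose actual loading loop is not part of this diff:

```python
import torch

# Producer (trainer) side: one grow-only staging buffer plus one interprocess
# event, reused across buckets, mirroring _update_params_ipc_tensor and
# _update_params_ipc_event.
staging = torch.empty(1 << 30, dtype=torch.uint8, device="cuda").view(torch.bfloat16)
event = torch.cuda.Event(interprocess=True)
handle = event.ipc_handle()  # shipped to the consumer process once, out of band

def send_bucket(flat_bucket: torch.Tensor) -> None:
    event.wait()                                    # consumer recorded after its last read
    staging[: flat_bucket.numel()].copy_(flat_bucket)
    event.record()                                  # consumer waits on this before reading

# Consumer side (sketch): rebuild the same event from the handle, wait for the
# producer's record, read the shared buffer, then record again so the producer
# may safely overwrite it:
#   event = torch.cuda.Event.from_ipc_handle(torch.cuda.current_device(), handle)
#   event.wait(); load_weights(staging); event.record()
```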
@@ -853,17 +858,94 @@ def update_rollout_info(
     @ray_method
     def update_weights(self):
         """Update the model weights."""
+        DEVICE_MODULE.empty_cache()
+        self._update_params_ipc_event = DEVICE_MODULE.Event(interprocess=True)
         if self.rollout_cfg_info.get("backend") == "turbomind":
             self._update_weights_by_layer()
         else:
             if isinstance(self.config.model_cfg, BaseComposeConfig):
+                self._update_weights_hf_generator(submodule="language_model", final_update=False)
                 self._update_weights_hf_generator(submodule="vision_tower", final_update=False)
-                self._update_weights_hf_generator(submodule="multi_modal_projector", final_update=False)
-                self._update_weights_hf_generator(submodule="language_model", final_update=True)
+                self._update_weights_hf_generator(submodule="multi_modal_projector", final_update=True)
             else:
-                self._update_weights_hf_generator()
+                self._update_weights_hf_generator(final_update=True)
+        self._update_params_ipc_tensor = None
+        self._update_params_ipc_event = None
+        DEVICE_MODULE.empty_cache()
+
+    def _rl_get_fused_ep_hf_param(self, model: MoE, target_ep_rank: int, target_ep_size: int, bucket_size: int):
+        fused_param_groups: list[tuple[torch.Tensor, LoadSpec]] = model._group_param_by_load_spec(LoadEnum.FUSED)
+        model_ep_size = 1 if model.fsdp_config is None else model.fsdp_config.ep_size
+        if not fused_param_groups:
+            return
+
+        def _get_hf_params(
+            fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]],
+        ) -> tuple[list[torch.Tensor], list[str]]:
+            hf_keys_list: list[str] = []
+            # Split each fused tensor into the target hf tensors
+            hf_tensor_list: list[torch.Tensor] = []
+
+            for fsdp_tensor, load_spec in fsdp_tensor_list:
+                hf_keys = load_spec.hf_keys
+                if model_ep_size > 1 and model.ep_mesh is not None:
+                    if load_spec.name not in self._global_hf_keys_mapping_cache:
+                        global_hf_keys: list[list[str] | None] = [None] * model_ep_size
+                        dist.all_gather_object(global_hf_keys, hf_keys, group=model.ep_mesh.get_group())
+                        global_hf_keys_gathered = cast(list[list[str]], global_hf_keys)
+                        self._global_hf_keys_mapping_cache[load_spec.name] = list(
+                            chain.from_iterable(global_hf_keys_gathered)
+                        )
+                    hf_keys = self._global_hf_keys_mapping_cache[load_spec.name]
+
+                fused_full_tensor = fsdp_tensor.bfloat16()
+                if isinstance(fused_full_tensor, DTensor):
+                    fused_full_tensor = fused_full_tensor.full_tensor()
+                dim = cast(int, load_spec.dim)
+                num_split = len(hf_keys)
+                hf_tensor_size = fused_full_tensor.shape[dim] / num_split
+                assert hf_tensor_size.is_integer(), "Internal Error, hf_tensor_size is not integer"
+                hf_tensor_size = int(hf_tensor_size)
+
+                hf_tensor = fused_full_tensor.split([hf_tensor_size] * num_split, dim=dim)
+                # slice out the target ep rank's share
+                assert num_split % target_ep_size == 0, (
+                    f"len(hf_keys) of '{hf_keys}' is {num_split}, it must be divisible by target_ep_size {target_ep_size}"
+                )
+                start_idx = (num_split // target_ep_size) * target_ep_rank
+                end_idx = (num_split // target_ep_size) * (target_ep_rank + 1)
+
+                hf_keys_list.extend(hf_keys[start_idx:end_idx])
+                hf_tensor_list.extend(hf_tensor[start_idx:end_idx])
+
+            hf_tensor_list = [
+                model.param_to_safetensor(safetensor, name) for safetensor, name in zip(hf_tensor_list, hf_keys_list)
+            ]
+
+            return hf_tensor_list, hf_keys_list
+
+        safetensor_size = 0
+        dtype = torch.bfloat16  # hardcode bfloat16 for now
+        tensor_list: list[tuple[torch.Tensor, LoadSpec]] = []

-    def _update_weights_hf_generator(self, submodule=None, final_update=True):
+        for param, load_spec in fused_param_groups:
+            tensor_size = dtype.itemsize * param.numel() // target_ep_size
+            if safetensor_size + tensor_size > bucket_size and tensor_list:
+                hf_params, name_list = _get_hf_params(tensor_list)
+                yield name_list, hf_params
+                safetensor_size = tensor_size
+                name_list = load_spec.hf_keys.copy()
+                tensor_list = [(param, load_spec)]
+                continue
+            safetensor_size += tensor_size
+            tensor_list.append((param, load_spec))
+
+        if tensor_list:
+            hf_params, name_list = _get_hf_params(tensor_list)
+            yield name_list, hf_params
+
+    @torch.no_grad()
+    def _update_weights_hf_generator(self, submodule=None, final_update=False):
         """Update the model weights."""
         self.endpoints["update_weights"] = "update_weights"
         assert self.rollout_device_mesh is not None
@@ -872,8 +954,6 @@ def _update_weights_hf_generator(self, submodule=None, final_update=True):
         if submodule:
             model = getattr(model, submodule)

-        DEVICE_MODULE.empty_cache()
-
         # TODO: force bfloat16 dtype for now
         dtype = torch.bfloat16

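To make the expert-slicing arithmetic in `_rl_get_fused_ep_hf_param` concrete: a fused tensor holding all experts along `load_spec.dim` is split into one HF tensor per expert key, and each rollout EP rank keeps only its contiguous span. A self-contained toy with made-up sizes, involving no xtuner APIs:

```python
import torch

# 8 experts fused along dim 0; rollout runs with EP size 4, and we are EP rank 1.
num_experts, target_ep_size, target_ep_rank = 8, 4, 1
fused = torch.arange(num_experts * 2, dtype=torch.bfloat16).reshape(num_experts, 2)

per_expert = fused.split([1] * num_experts, dim=0)   # one HF tensor per expert key
assert num_experts % target_ep_size == 0             # same divisibility check as the diff
span = num_experts // target_ep_size                 # experts owned per rollout EP rank
shard = per_expert[span * target_ep_rank : span * (target_ep_rank + 1)]

# EP rank 1 ends up with experts 2 and 3, i.e. rows 2..3 of the fused tensor.
assert torch.equal(torch.cat(shard), fused[2:4])
```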
@@ -881,59 +961,51 @@ def _update_weights_hf_generator(self, submodule=None, final_update=True):
         same_gen = model._get_same_hf_param(
             model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE, bucket_size=bucket_size
         )
-        fused_gen = model._get_fused_hf_param(
-            model._group_param_by_load_spec(LoadEnum.FUSED),
-            dtype=dtype,
-            device=DEVICE,
-            bucket_size=bucket_size,
-            update_weights_for_rl=True,
-        )
+
+        train_enable_ep = model.fsdp_config is not None and model.fsdp_config.ep_size > 1
+        if train_enable_ep:
+            # rollout_device_mesh carries the rollout engine's coordinates,
+            # which match the corresponding training engine ranks
+            if self.rollout_cfg_info["ep"] > 1:
+                fused_gen = self._rl_get_fused_ep_hf_param(
+                    model,
+                    target_ep_rank=self.rollout_device_mesh["engine_parallel"].get_coordinate()[0],
+                    target_ep_size=self.rollout_device_mesh["engine_parallel"].size(),
+                    bucket_size=bucket_size,
+                )
+            else:
+                fused_gen = self._rl_get_fused_ep_hf_param(
+                    model,
+                    target_ep_rank=0,
+                    target_ep_size=1,
+                    bucket_size=bucket_size,
+                )
+        else:
+            fused_gen = model._get_fused_hf_param(
+                model._group_param_by_load_spec(LoadEnum.FUSED),
+                dtype=dtype,
+                device=DEVICE,
+                bucket_size=bucket_size,
+                update_weights_for_rl=True,
+            )
         shard_gen = model._get_shard_hf_param(
             model._group_param_by_load_spec(LoadEnum.SHARD), dtype=dtype, device=DEVICE, bucket_size=bucket_size
         )

         for name_list, fused_param_list in fused_gen:
             state_dict = {name: param.detach() for name, param in zip(name_list, fused_param_list)}
-            if model.fsdp_config.ep_size > 1:
-                # When ep_size > 1, generator generates part of the fused param on each ep rank in one ep_group.
-                # We can all gather them to get full fused param but it would lead to a larger memory usage.
-                # So we broadcast the part fused param from each ep rank in ep_group sequentially,
-                # and update the part of the fused param sequentially to reduce memory usage.
-                if isinstance(model.config, BaseComposeConfig):
-                    ep_mesh: DeviceMesh = model.language_model.ep_mesh
-                else:
-                    ep_mesh: DeviceMesh = model.ep_mesh
-                ep_group = ep_mesh.get_group()
-                global_rank = dist.get_rank()
-                for src_global_rank in dist.get_process_group_ranks(ep_group):
-                    broadcast_state_dict = dict()
-                    for key, tensor in state_dict.items():
-                        obj_to_broadcast = [key, tensor.to("meta")] if global_rank == src_global_rank else [None, None]
-                        dist.broadcast_object_list(obj_to_broadcast, src=src_global_rank, group=ep_group)
-                        real_key, meta_tensor = obj_to_broadcast
-                        buffer = (
-                            state_dict[real_key]
-                            if global_rank == src_global_rank
-                            else torch.empty_like(meta_tensor, device=DEVICE)
-                        )
-                        dist.broadcast(buffer, src=src_global_rank, group=ep_group)
-                        broadcast_state_dict[real_key] = buffer
-                    self.request_update_params(broadcast_state_dict, finished=False)
-                    del broadcast_state_dict, buffer
-            else:
-                self.request_update_params(state_dict, finished=False)
+            self.request_update_params(state_dict, train_enable_ep=train_enable_ep, finished=False)
             del state_dict, name_list, fused_param_list

         for name_list, param_list in chain(same_gen, shard_gen):
             state_dict = {name: param.detach() for name, param in zip(name_list, param_list)}
-            self.request_update_params(state_dict, finished=False)
+            self.request_update_params(state_dict, train_enable_ep=train_enable_ep, finished=False)
             del state_dict, name_list, param_list

         if self.rollout_cfg_info["backend"] in ("pytorch", "vllm") and final_update:
-            self.request_update_params({}, finished=True)
+            self.request_update_params({}, train_enable_ep=train_enable_ep, finished=True)

         dist.barrier()
-        DEVICE_MODULE.empty_cache()
         return

     def _update_weights_by_layer(self):
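The `target_ep_rank`/`target_ep_size` arguments above come straight from a named dimension of the rollout device mesh. A runnable sketch of that lookup under `torchrun --nproc_per_node=4`, using a CPU mesh and an assumed 2x2 layout (the `replicate` dimension name is illustrative; only `engine_parallel` appears in the diff):

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# 4 ranks arranged as 2 replicas x 2 engine-parallel shards.
mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("replicate", "engine_parallel"))
ep_mesh = mesh["engine_parallel"]

target_ep_rank = ep_mesh.get_coordinate()[0]  # this rank's slot along the EP dimension
target_ep_size = ep_mesh.size()               # number of rollout EP shards (2 here)
print(f"global rank {dist.get_rank()}: ep_rank {target_ep_rank} of {target_ep_size}")
```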
@@ -1041,185 +1113,19 @@ def get_params(tensor_list, name_list, save_dtype):
         DEVICE_MODULE.empty_cache()
         return

-    # def update_weights1(self):
-    #     """Update the model weights."""
-    #     self.endpoints["update_weights"] = "update_weights"
-    #     assert self.rollout_device_mesh is not None
-    #     time1 = time.time()
-
-    #     model = self._engine.model
-    #     DEVICE_MODULE.empty_cache()
-
-    #     if (model.config.float8_cfg is not None) and (model.config.float8_cfg.enable_float8):
-    #         dtype = torch.float8_e4m3fn
-    #     else:
-    #         dtype = torch.bfloat16
-
-    #     fused_params = []
-    #     for name, param in model.state_dict().items():
-    #         load_spec = model.load_spec_mapping.get(name)
-    #         if load_spec.load_enum == LoadEnum.FUSED:
-    #             fused_params.append((name, param, load_spec))
-
-    #     # TODO: decouple update_weights from the model structure
-    #     bucket_size = 1024**3
-    #     safetensor_size = 0
-    #     tensor_list: list[tuple[torch.Tensor, LoadSpec]] = []
-    #     name_list: list[str] = []
-    #     for name, param, load_spec in fused_params:
-    #         local_tensor = param._local_tensor if isinstance(param, DTensor) else param
-    #         local_tensor = local_tensor.bfloat16()
-    #         if safetensor_size + model._get_tensor_size(param, dtype) > bucket_size:
-    #             _tensor_list, _spec_list = list(zip(*tensor_list))
-    #             fsdp_unshard_tensor_list = model._fsdp_foreach_allgather(_tensor_list, _spec_list)
-    #             if dtype == torch.float8_e4m3fn:
-    #                 fsdp_unshard_tensor_list, name_list = model._to_float8(
-    #                     fsdp_unshard_tensor_list, name_list, _tensor_list, dtype
-    #                 )
-    #             state_dict = dict(zip(name_list, fsdp_unshard_tensor_list))
-    #             self.request_update_params(state_dict)
-    #             safetensor_size = 0
-    #             tensor_list = [(local_tensor, load_spec)]
-    #             name_list = ["model." + name.replace(".experts.", ".mlp.experts.")]
-    #             continue
-    #         safetensor_size += model._get_tensor_size(param, dtype)
-    #         tensor_list.append((local_tensor, load_spec))
-    #         name_list.append("model." + name.replace(".experts.", ".mlp.experts."))
-
-    #     if tensor_list:
-    #         assert len(name_list) == len(tensor_list)
-    #         _tensor_list, _spec_list = list(zip(*tensor_list))
-    #         fsdp_unshard_tensor_list = model._fsdp_foreach_allgather(_tensor_list, _spec_list)
-    #         if dtype == torch.float8_e4m3fn:
-    #             fsdp_unshard_tensor_list, name_list = model._to_float8(
-    #                 fsdp_unshard_tensor_list, name_list, _tensor_list, dtype
-    #             )
-    #         state_dict = dict(zip(name_list, fsdp_unshard_tensor_list))
-    #         self.request_update_params(state_dict)
-
-    #     same_gen = model._get_same_hf_param(
-    #         model._group_param_by_load_spec(LoadEnum.SAME),
-    #         dtype=dtype,
-    #         device="cuda",
-    #         bucket_size=1024**3,
-    #     )
-    #     for name_list, gathered_tensor_list in tqdm.tqdm(same_gen, desc="[update dense weights]"):
-    #         state_dict = dict(zip(name_list, gathered_tensor_list))
-    #         self.request_update_params(state_dict)
-    #         del state_dict
-
-    #     self.request_update_params({}, finished=True)
-
-    #     dist.barrier()
-    #     logger.info(f"update weights time: {time.time() - time1}")
-    #     DEVICE_MODULE.empty_cache()
-    #     return
-
-    # def update_weights(self):
-    #     """Update the model weights."""
-    #     self.endpoints["update_weights"] = "update_weights"
-    #     assert self.rollout_device_mesh is not None
-
-    #     model = self._engine.model
-    #     DEVICE_MODULE.empty_cache()
-
-    #     saved_keys = []
-    #     gather_duration = []
-    #     weight_duration = []
-    #     reshard_duration = []
-
-    #     # update decoder layers
-    #     for i, layer in tqdm.tqdm(model.layers.items(), desc="[gather weight]"):
-    #         start = time.perf_counter()
-    #         layer.unshard()
-    #         layer_state_dict = {}
-
-    #         for sub_name, param in layer.named_parameters():
-    #             if "_checkpoint_wrapped_module." in sub_name:
-    #                 sub_name = sub_name.replace("_checkpoint_wrapped_module.", "")
-    #             if isinstance(param, DTensor):
-    #                 param = param.to_local()
-
-    #             if isinstance(param, WeightWithDynamicTilewiseFloat8CastTensor):
-    #                 param = param._tensor
-
-    #             if isinstance(param, Float8Tensor):
-    #                 scale_name = f"model.layers.{i}.{sub_name}_scale_inv"
-    #                 assert "fused_w1w3" in sub_name or "fused_w2" in sub_name
-    #                 # save scale_inv parameter to state_dict
-    #                 scale_tensor = param._scale
-    #                 quant_tensor = param._data
-    #                 ep_mesh = model.ep_mesh
-    #                 if ep_mesh.size() > 1:
-    #                     scale_tensor = torch.cat(dist.nn.all_gather(scale_tensor, group=ep_mesh.get_group()), dim=0)
-    #                     quant_tensor = torch.cat(dist.nn.all_gather(quant_tensor, group=ep_mesh.get_group()), dim=0)
-    #                 layer_state_dict[scale_name] = scale_tensor.detach()
-    #                 # set `param` which will be added to state_dict at the bottom of the for-block
-    #                 param = quant_tensor
-
-    #             param = param.to(DEVICE)
-    #             name = f"model.layers.{i}.{sub_name}"
-    #             saved_keys.append(name.replace("model.", ""))
-    #             if ".experts." in name and ".mlp." not in name:
-    #                 name = name.replace(".experts.", ".mlp.experts.")
-    #             if ".gate." in name and ".mlp." not in name:
-    #                 name = name.replace(".gate.", ".mlp.gate.")
-    #             layer_state_dict[name] = param.detach()
-    #         gather_duration.append(time.perf_counter() - start)
-    #         start = time.perf_counter()
-    #         self.request_update_params(layer_state_dict, finished=True)
-    #         breakpoint()
-    #         weight_duration.append(time.perf_counter() - start)
-
-    #         start = time.perf_counter()
-    #         del layer_state_dict
-    #         layer.reshard()
-    #         reshard_duration.append(time.perf_counter() - start)
-
-    #     if dist.get_rank() == 0:
-    #         logger.debug(
-    #             f"Rank 0 Gather decoder layers done, total {sum(gather_duration):.2f}s, avg "
-    #             f"{sum(gather_duration) / len(gather_duration):.2f}s"
-    #         )
-    #         logger.debug(
-    #             f"Rank 0 migrate/save decoder layers done, total {sum(weight_duration):.2f}s, avg "
-    #             f"{sum(weight_duration) / len(weight_duration):.2f}s"
-    #         )
-    #         logger.debug(
-    #             f"Rank 0 reshard decoder layers done, total {sum(reshard_duration):.2f}s, avg "
-    #             f"{sum(reshard_duration) / len(reshard_duration):.2f}s"
-    #         )
-
-    #     # update other params
-    #     model.norm.unshard()
-    #     model.lm_head.unshard()
-    #     model.embed_tokens.unshard()
-    #     others_state_dict = {}
-    #     for name, param in model.named_parameters():
-    #         if "_checkpoint_wrapped_module." in name:
-    #             continue
-    #         if name not in saved_keys:
-    #             saved_keys.append(name)
-    #             if name == "norm.weight":
-    #                 name = "model.norm.weight"
-    #             if name == "embed_tokens.weight":
-    #                 name = "model.embed_tokens.weight"
-    #             if isinstance(param, DTensor):
-    #                 param = param.to_local()
-    #             others_state_dict[name] = param.detach()
-    #     self.request_update_params(others_state_dict, finished=True)
-    #     model.norm.reshard()
-    #     model.lm_head.reshard()
-    #     model.embed_tokens.reshard()
-    #     del others_state_dict
-    #     del param
-
-    #     dist.barrier()
-    #     DEVICE_MODULE.empty_cache()
-    #     return
+    @staticmethod
+    def _compute_state_dict_bytes(state_dict: Dict[str, torch.Tensor]) -> int:
+        total_bytes = 0
+        for tensor in state_dict.values():
+            total_bytes += tensor.numel() * tensor.element_size()
+        return total_bytes
+
+    @staticmethod
+    def _create_ipc_tensor(size_in_bytes: int, dtype: torch.dtype):
+        return torch.empty(size_in_bytes, dtype=torch.uint8, device=DEVICE).view(dtype)

     @ray_method
-    def request_update_params(self, state_dict, finished=False):
+    def request_update_params(self, state_dict, train_enable_ep=False, finished=False):
         """Send a request to update the parameters on the rollout workers.

         This method serializes the state dictionary and sends it to the
@@ -1228,6 +1134,8 @@ def request_update_params(self, state_dict, finished=False):
         Args:
             state_dict (dict | list): The state dictionary containing the model
                 parameters to update.
+            train_enable_ep (bool): Whether the training engine enables expert parallelism.
+                Defaults to False.
             finished (bool): A flag indicating whether this is the final batch
                 of updates. Defaults to False.
""" @@ -1287,12 +1195,36 @@ def serialize_state_dict(state_dict: dict) -> str: if self.rollout_cfg_info["backend"] == "pytorch" and self.rollout_cfg_info["tp"] > 1: serialized_data = [None] * self.rollout_cfg_info["tp"] if use_flattened_tensor_bucket: - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=list(state_dict.items())) + state_dict_bytes = self._compute_state_dict_bytes(state_dict) + send_ipc_tensor = ( + state_dict_bytes > self._ipc_tensor_bytes or self._update_params_ipc_tensor is None + ) + if send_ipc_tensor: + self._ipc_tensor_bytes = max(self._ipc_tensor_bytes, state_dict_bytes) + if self._update_params_ipc_tensor is not None: + # wait previous ipc event recorded of lmdeploy + self._update_params_ipc_event.wait() + torch.cuda.synchronize() + self._update_params_ipc_tensor = self._create_ipc_tensor( + self._ipc_tensor_bytes, state_dict[next(iter(state_dict))].dtype + ) + else: + self._update_params_ipc_event.wait() + + flattened_tensor_bucket = FlattenedTensorBucket( + named_tensors=list(state_dict.items()), + flattened_tensor=self._update_params_ipc_tensor, + ) metadata = flattened_tensor_bucket.get_metadata() flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), "metadata": metadata, + "require_clone": False, } + self._update_params_ipc_event.record() + + if send_ipc_tensor: + flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor() + flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle() tp_serialized_data = serialize_state_dict(flattened_tensor_data) else: tp_serialized_data = serialize_state_dict(state_dict) @@ -1304,12 +1236,36 @@ def serialize_state_dict(state_dict: dict) -> str: ) elif self.rollout_cfg_info["backend"] == "pytorch": if use_flattened_tensor_bucket: - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=list(state_dict.items())) + state_dict_bytes = self._compute_state_dict_bytes(state_dict) + send_ipc_tensor = ( + state_dict_bytes > self._ipc_tensor_bytes or self._update_params_ipc_tensor is None + ) + if send_ipc_tensor: + self._ipc_tensor_bytes = max(self._ipc_tensor_bytes, state_dict_bytes) + if self._update_params_ipc_tensor is not None: + # wait previous ipc event recorded of lmdeploy + self._update_params_ipc_event.wait() + torch.cuda.synchronize() + self._update_params_ipc_tensor = self._create_ipc_tensor( + self._ipc_tensor_bytes, state_dict[next(iter(state_dict))].dtype + ) + else: + self._update_params_ipc_event.wait() + + flattened_tensor_bucket = FlattenedTensorBucket( + named_tensors=list(state_dict.items()), + flattened_tensor=self._update_params_ipc_tensor, + ) metadata = flattened_tensor_bucket.get_metadata() flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), "metadata": metadata, + "require_clone": False, } + self._update_params_ipc_event.record() + + if send_ipc_tensor: + flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor() + flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle() serialized_data = serialize_state_dict(flattened_tensor_data) else: serialized_data = serialize_state_dict(state_dict) @@ -1409,7 +1365,12 @@ def serialize_state_dict(state_dict: dict) -> str: ) assert response.status_code == 200, f"response.status_code = {response.status_code}" - if finished: + # TODO(chenchiyu): narrow this condition + if finished or (self.rollout_cfg_info["backend"] == "pytorch" and train_enable_ep 
@@ -1409,7 +1365,12 @@ def serialize_state_dict(state_dict: dict) -> str:
             )
             assert response.status_code == 200, f"response.status_code = {response.status_code}"

-        if finished:
+        # TODO(chenchiyu): narrow this condition
+        if finished or (self.rollout_cfg_info["backend"] == "pytorch" and train_enable_ep and self.rollout_cfg_info["tp"] > 1):
+            # This barrier makes each tp head rank sync with the other ranks in its engine_parallel group, which the
+            # `fsdp_foreach_allgather` of the next state dict cannot guarantee (observed with same_gen; shard not tested).
+            # Without it, some ranks in the engine_parallel group would not wait for lmdeploy to record the current
+            # iteration's ipc event, and would overwrite the ipc tensor with the next state_dict before lmdeploy loads the weights.
             dist.barrier(group=cpu_group)

         monkey_unpatch_torch_reductions()
diff --git a/xtuner/v1/rl/config/advantage.py b/xtuner/v1/rl/config/advantage.py
index 05a4b8c0eb..5a9c78f0fd 100644
--- a/xtuner/v1/rl/config/advantage.py
+++ b/xtuner/v1/rl/config/advantage.py
@@ -1,7 +1,7 @@
 from typing import Annotated

 from cyclopts import Group, Parameter
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict

 from xtuner.v1.rl.advantage.base import AdvantageEstimator

@@ -12,6 +12,8 @@
 class BaseAdvantageConfig(BaseModel):
     """Intermediate base for discriminated union."""

+    model_config = ConfigDict(extra="forbid")
+
     def build(self) -> AdvantageEstimator:
         raise NotImplementedError("Subclasses must implement this method.")
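On the `advantage.py` side, `extra="forbid"` turns silently ignored config fields into validation errors for every subclass of `BaseAdvantageConfig`. A small pydantic v2 illustration with a hypothetical subclass:

```python
from pydantic import BaseModel, ConfigDict, ValidationError

class DemoAdvantageConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")  # the behavior added by this diff
    group_size: int = 8

try:
    DemoAdvantageConfig(group_sise=16)  # typo: should be group_size
except ValidationError:
    print("extra field rejected instead of silently dropped")
```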