diff --git a/.dev_scripts/debug_gateway.py b/.dev_scripts/debug_gateway.py new file mode 100644 index 0000000000..db7a087cfe --- /dev/null +++ b/.dev_scripts/debug_gateway.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python +"""Start a RolloutController-backed Gateway for manual protocol debugging. + +This script is intended for end-to-end debugging with real clients such as +Claude Code, Codex, curl, or the OpenAI SDK. It starts the RolloutController, +waits for rollout workers to become ready, then serves the Gateway in the +current process. +""" + +from __future__ import annotations + +import argparse +import os +import sys +import time +from pathlib import Path +from typing import Any + + +DEFAULT_WORK_DIR = Path("/tmp/xtuner_debug_gateway") +DEFAULT_MODEL_PATH = os.environ.get("ROLLOUT_MODEL_PATH") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Start a local XTuner Gateway backed by a RolloutController for manual protocol debugging.\n\n" + "Example:\n" + " python .dev_scripts/debug_gateway.py --model-path /path/to/model --model-name local-test" + ), + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--model-path", + default=DEFAULT_MODEL_PATH, + required=DEFAULT_MODEL_PATH is None, + help="Model path for rollout workers. Defaults to the ROLLOUT_MODEL_PATH environment variable.", + ) + parser.add_argument("--model-name", default=None, help="Model name exposed by the Gateway.") + parser.add_argument("--tokenizer-path", default=None, help="Tokenizer path. Defaults to --model-path.") + parser.add_argument("--rollout-env", default="debug_gateway", help="Rollout environment name.") + parser.add_argument("--ray-address", default="local", help="Ray cluster address. Use 'local' to start one.") + parser.add_argument("--ray-namespace", default="xtuner-debug-gateway", help="Ray namespace for this debug run.") + parser.add_argument("--controller-name", default=None, help="Optional Ray actor name for the RolloutController.") + parser.add_argument( + "--ray-max-concurrency", + type=int, + default=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000)), + help="max_concurrency for the RolloutController actor.", + ) + + parser.add_argument("--tensor-parallel-size", type=int, default=4) + parser.add_argument("--expert-parallel-size", type=int, default=1) + parser.add_argument("--num-workers", type=int, default=4) + parser.add_argument("--num-cpus-per-worker", type=int, default=16) + parser.add_argument("--cpu-memory-per-worker-gb", type=int, default=8) + parser.add_argument("--context-length", type=int, default=32768) + parser.add_argument("--dist-port-base", type=int, default=42000) + parser.add_argument("--api-host", default="127.0.0.1") + parser.add_argument("--api-port", type=int, default=30080) + parser.add_argument("--worker-log-dir", default=str(DEFAULT_WORK_DIR / "worker_logs")) + parser.add_argument("--placement-group-name", default="xtuner_debug_gateway_pg") + parser.add_argument( + "--ready-poll-seconds", + type=float, + default=5.0, + help="Polling interval while waiting for rollout workers to become ready.", + ) + parser.add_argument("--tool-call-parser", default="qwen3", help="Tool call parser used by the rollout backend.") + parser.add_argument("--reasoning-parser", default="qwen3", help="Reasoning parser used by the rollout backend.") + + parser.add_argument("--host", default="127.0.0.1", help="Gateway bind host.") + parser.add_argument("--port", type=int, default=8091, help="Gateway bind port.") + 
parser.add_argument("--log-level", default="info", help="Uvicorn log level.") + parser.add_argument( + "--capture-folder", + default=None, + help="Optional request capture folder. If omitted, defaults to /gateway_captures.", + ) + + return parser.parse_args() + + +def resolve_capture_output_file(capture_folder: str | Path | None) -> Path | None: + if capture_folder is None: + return None + from xtuner.v1.rl.gateway.adapters.capture import resolve_capture_output_path + + return resolve_capture_output_path(capture_folder) + + +def describe_capture_output(capture_folder: str | Path | None) -> str: + capture_output_file = resolve_capture_output_file(capture_folder) + if capture_output_file is None: + return "disabled" + return f"{capture_output_file} (requests with API keys are split into api_key_.jsonl)" + + +def init_ray(address: str, namespace: str) -> dict[str, Any]: + import ray + + ctx = ray.init(address=address, namespace=namespace, ignore_reinit_error=True) + address_info = getattr(ctx, "address_info", {}) or {} + return { + "requested_ray_address": address, + "ray_address": address_info.get("address") or address_info.get("gcs_address") or address, + "namespace": namespace, + "ray_context": address_info, + } + + +def build_rollout_config(args: argparse.Namespace): + from xtuner.v1.rl.rollout.worker import RolloutConfig + + model_path = str(args.model_path) + tokenizer_path = str(args.tokenizer_path or args.model_path) + model_name = args.model_name or Path(model_path).name.lower() + return RolloutConfig( + env=args.rollout_env, + device="GPU", + model_path=model_path, + model_name=model_name, + tokenizer_path=tokenizer_path, + tensor_parallel_size=args.tensor_parallel_size, + expert_parallel_size=args.expert_parallel_size, + context_length=args.context_length, + worker_log_dir=args.worker_log_dir, + dist_port_base=args.dist_port_base, + api_host=args.api_host, + api_port=args.api_port, + tool_call_parser=args.tool_call_parser, + reasoning_parser=args.reasoning_parser, + ) + + +def build_controller(args: argparse.Namespace): + import ray + + from xtuner.v1.rl.rollout.controller import RolloutController + from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers + + resource_config = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=args.num_workers, + num_cpus_per_worker=args.num_cpus_per_worker, + cpu_memory_per_worker=args.cpu_memory_per_worker_gb * 1024**3, + ) + placement_group = AutoAcceleratorWorkers.build_placement_group( + resource_config, + name=args.placement_group_name, + ) + rollout_config = build_rollout_config(args) + actor_options: dict[str, Any] = { + "max_concurrency": args.ray_max_concurrency, + } + if args.controller_name: + actor_options["name"] = args.controller_name + controller = ray.remote(RolloutController).options(**actor_options).remote(rollout_config, placement_group) + print("Created rollout controller.") + return controller, placement_group + + +def wait_for_controller_ready(controller, poll_seconds: float) -> dict[str, Any]: + import ray + + while True: + ready, status = ray.get(controller.get_ready_status.remote()) + if ready: + print(f"Rollout controller ready: {status}") + return status + print(f"Waiting for rollout workers... 
{status}") + time.sleep(poll_seconds) + + +def start_gateway(args: argparse.Namespace, controller) -> None: + from xtuner.v1.rl.gateway.config import GatewayConfig + from xtuner.v1.rl.gateway.server import build_local_gateway_app, serve_gateway + + capture_folder = args.capture_folder + if capture_folder is None: + capture_folder = str(Path(args.worker_log_dir) / GatewayConfig._CAPTURE_PATH_FOLDER) + + cfg = GatewayConfig( + host=args.host, + port=args.port, + auto_start=False, + capture_folder=capture_folder, + log_level=args.log_level, + ) + + app = build_local_gateway_app(controller, config=cfg) + print(f"Starting gateway at http://{cfg.host}:{cfg.port}") + print(f"Gateway capture output: {describe_capture_output(cfg.capture_folder)}") + serve_gateway(app, cfg) + + +def cleanup_controller(controller, placement_group) -> None: + import ray + + try: + ray.get(controller.shutdown.remote(), timeout=300) + except Exception as exc: + print(f"Failed to shutdown rollout controller cleanly: {exc}", file=sys.stderr) + try: + ray.kill(controller, no_restart=True) + except Exception as exc: + print(f"Failed to kill rollout controller: {exc}", file=sys.stderr) + if placement_group is not None: + try: + ray.util.remove_placement_group(placement_group) + except Exception as exc: + print(f"Failed to remove placement group: {exc}", file=sys.stderr) + + +def main() -> None: + args = parse_args() + controller = None + placement_group = None + try: + init_info = init_ray(args.ray_address, args.ray_namespace) + print( + "Initialized Ray: " + f"requested_address={init_info['requested_ray_address']}, " + f"address={init_info['ray_address']}, namespace={init_info['namespace']}" + ) + controller, placement_group = build_controller(args) + wait_for_controller_ready(controller, args.ready_poll_seconds) + start_gateway(args, controller) + finally: + ray_module = sys.modules.get("ray") + if ray_module is not None and ray_module.is_initialized(): + if controller is not None: + cleanup_controller(controller, placement_group) + ray_module.shutdown() + + +if __name__ == "__main__": + main() diff --git a/.dev_scripts/rl_config_factory.py b/.dev_scripts/rl_config_factory.py deleted file mode 100644 index d8a9e42f1c..0000000000 --- a/.dev_scripts/rl_config_factory.py +++ /dev/null @@ -1,129 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, Optional - -from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.utils.rl_test_utils import get_eos_token - - -def _filter_pydantic_kwargs(target_class: Any, kwargs: Dict) -> Dict: - accepted_keys = set(target_class.model_fields.keys()) - return {k: v for k, v in kwargs.items() if k in accepted_keys} - - -def _build_config(config_class, **kwargs): - filtered_params = _filter_pydantic_kwargs(config_class, kwargs) - return config_class(**filtered_params) - - -def get_resources_config(**kwargs) -> AcceleratorResourcesConfig: - return _build_config(AcceleratorResourcesConfig, **kwargs) - - -def get_rollout_config(**kwargs) -> RolloutConfig: - return 
_build_config(RolloutConfig, **kwargs) - - -def get_dataflow_config(**kwargs) -> DataFlowConfig: - return _build_config(DataFlowConfig, **kwargs) - - -def get_replay_buffer_config(tokenizer: Any, **kwargs) -> ReplayBufferConfig: - tokenizer_config = RLTokenizeFnConfig(max_length=kwargs["max_prompt_length"]) - train_dataset = DatasetConfig(anno_path=kwargs["data_path"]) - train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] - dataloader_config = DataloaderConfig(collator="fake_collator", pack_level="none") - return ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, - dataloader_cfg=dataloader_config, - tokenizer=tokenizer, - postprocessor_func=kwargs.get("filter_func"), - ) - - -def get_dapo_judger_config(tokenizer: Any, **kwargs): - dapo_defaults_args = { - "enable_overlong_buffer": True, - "overlong_buffer_len": 4096, - "overlong_penalty_factor": 1.0, - } - dapo_config_params = {**dapo_defaults_args, **kwargs} - eos_token_id = get_eos_token(kwargs["model_path"]) - eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) - from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig - - filtered_params = _filter_pydantic_kwargs(DapoMathJudgerConfig, dapo_config_params) - dapomath_judger_config = DapoMathJudgerConfig( - judger_name="dapo_math", - eos_token=eos_token_str, - max_response_len=kwargs["max_response_length"], - tokenizer=tokenizer, - **filtered_params, - ) - return JudgerConfig(reward_judger_configs=[dapomath_judger_config]) - - -def get_gsm8k_judger_config(): - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") - judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config]) - return judger_cfg - - -def get_evaluator_config(tokenizer: Any, **kwargs) -> Optional[EvaluatorConfig]: - if not kwargs["enable_evaluate"]: - return None - - eval_dataset = DatasetConfig(anno_path=kwargs["eval_data_path"]) - tokenizer_config = RLTokenizeFnConfig(max_length=kwargs["max_prompt_length"]) - eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] - - filtered_params = _filter_pydantic_kwargs(EvaluatorConfig, kwargs) - - return EvaluatorConfig( - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - **filtered_params, - ) - - -def get_train_worker_config(**kwargs) -> WorkerConfig: - from xtuner.v1.model import get_model_config_from_hf - - model_cfg = get_model_config_from_hf(Path(kwargs["model_path"])) - defaults = { - "optim_cfg": AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False), - "loss_cfg": GRPOLossConfig( - policy_loss_cfg={ - "cliprange_high": 0.28, - "cliprange_low": 0.2, - "loss_type": "vanilla", - "clip_ratio_c": 10.0, - "log_prob_diff_min": -20.0, - "log_prob_diff_max": 20.0, - }, - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.0, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, - ), - "lr_cfg": LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6), - "fsdp_cfg": FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1), - "sp_size": 1, - "optimizer_steps": 16, - "pack_max_length": 4096, - } - config_params = {**defaults, **kwargs} - filtered_params = _filter_pydantic_kwargs(WorkerConfig, config_params) - return WorkerConfig(load_from=config_params["model_path"], model_cfg=model_cfg, **filtered_params) diff --git a/.gitignore b/.gitignore index 40288d0c84..0f4bfe2800 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ +zdev/ +old/ +bak/ +exp*/ + # 
Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/autotest/config/rl_qwen3_30B_gsm8k_grpo.py b/autotest/config/rl_qwen3_30B_gsm8k_grpo.py index 13b6c11ff7..cda1234891 100644 --- a/autotest/config/rl_qwen3_30B_gsm8k_grpo.py +++ b/autotest/config/rl_qwen3_30B_gsm8k_grpo.py @@ -2,25 +2,27 @@ from copy import deepcopy from pathlib import Path -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig from xtuner.v1.model import get_model_config_from_hf -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig +from xtuner.v1.rl.advantage import GRPOAdvantageConfig +from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig +from xtuner.v1.rl.agent_loop_manager import ( + AgentLoopManagerConfig, + SamplerConfig, + SyncProduceStrategyConfig, + TaskSpecConfig, +) +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.judger import GSM8KJudgerConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig +from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig work_dir = os.environ["WORK_DIR"] @@ -28,7 +30,7 @@ data_path = os.environ["DATA_PATH"] eval_data_path = os.environ["EVAL_DATA_PATH"] enable_return_routed_experts = True -enable_evaluate = True if eval_data_path != "" else False +enable_evaluate = eval_data_path != "" # basic settings experimental_name = "grpo_gsm8k_tiny" @@ -50,7 +52,7 @@ accelerator="GPU", num_workers=8, num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB + cpu_memory_per_worker=16 * 1024**3, ) # 2. 
rollout @@ -68,58 +70,43 @@ ) # sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) +training_sample_params = SampleParams(max_tokens=max_response_length) evaluation_sample_params = deepcopy(training_sample_params) evaluation_sample_params.top_p = 1.0 evaluation_sample_params.temperature = 0.0 evaluation_sample_params.top_k = 1 -# dataset: 不需要修改 -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. judger -gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") -judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, +# 3. datasets +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) +train_dataset_cfg = [ + { + "dataset": DatasetConfig(name=experimental_name, anno_path=data_path), + "tokenize_fn": tokenizer_config, + } +] +eval_dataset_cfg = [ + { + "dataset": DatasetConfig(name=experimental_name, anno_path=eval_data_path if enable_evaluate else data_path), + "tokenize_fn": tokenizer_config, + } +] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", ) - -evaluator_cfg = ( - EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=None, - sample_params=evaluation_sample_params, - ) - if enable_evaluate - else None +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", ) -# replay buffer config: : 不需要修改 -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) +# 4. judger +judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") -# 5. Train worker -# NOTE: modify model_cfg +# 5. train worker model_cfg = get_model_config_from_hf(Path(model_path)) optim_cfg = AdamWConfig(lr=1e-6, foreach=False) loss_cfg = GRPOLossConfig( @@ -137,7 +124,7 @@ ) lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( +train_worker_cfg = WorkerConfig( model_cfg=model_cfg, load_from=model_path, optim_cfg=optim_cfg, @@ -149,19 +136,52 @@ pack_max_length=pack_max_length, ) -# 6. RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, +# 6. 
agent loop managers +agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=SyncProduceStrategyConfig(), + sampler_config=SamplerConfig(dataloader_cfg=dataloader_cfg, prompt_repeat_k=prompt_repeat_k), + ), +) + +eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=SamplerConfig(dataloader_cfg=eval_dataloader_cfg, prompt_repeat_k=1), + ), +) + +# 7. trainer +trainer = RLColocateTrainerConfig( resources=resources, + train_worker_cfg=train_worker_cfg, rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, tokenizer_path=model_path, - work_dir=work_dir, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=EvaluatorConfig(compute_metric_func=None), + load_from=model_path, total_epochs=total_epochs, + train_batch_size=global_batch_size, + advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), + enable_evaluate=enable_evaluate, + enable_initial_evaluate=enable_evaluate and enable_initial_evaluate, + evaluate_step=evaluate_step, + work_dir=work_dir, hf_interval=hf_interval, exp_tracker="jsonl", ) diff --git a/autotest/config/rl_qwen3_8B_gsm8k_grpo.py b/autotest/config/rl_qwen3_8B_gsm8k_grpo.py deleted file mode 100644 index 1c52f9dca1..0000000000 --- a/autotest/config/rl_qwen3_8B_gsm8k_grpo.py +++ /dev/null @@ -1,166 +0,0 @@ -import os -from copy import deepcopy -from pathlib import Path - -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model import get_model_config_from_hf -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig - - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ["EVAL_DATA_PATH"] -enable_evaluate = True if eval_data_path != "" else False -enable_partial_rollout = int(os.environ.get("ENABLE_PARTIAL_ROLLOUT", "0")) - -# basic settings -experimental_name = "grpo_gsm8k_tiny" -total_epochs = 3 -global_batch_size = 64 -prompt_repeat_k = 5 -rollout_tp_size = 1 -rollout_ep_size = 1 -max_prompt_length = 512 -max_response_length = 1024 -pack_max_length = 32768 -train_optimizer_steps = 1 -hf_interval = 100 
-enable_initial_evaluate = True -evaluate_step = 15 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.85, - context_length=max_prompt_length + max_response_length, - rollout_max_batch_size_per_instance=1024, -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 1.0 -evaluation_sample_params.temperature = 0.0 -evaluation_sample_params.top_k = 1 - -# dataset: 不需要修改 -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. judger -gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") -judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - enable_partial_rollout=enable_partial_rollout, -) - -evaluator_cfg = ( - EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=None, - sample_params=evaluation_sample_params, - ) - if enable_evaluate - else None -) - -# replay buffer config: : 不需要修改 -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) - -# 5. Train worker -model_cfg = get_model_config_from_hf(Path(model_path)) -optim_cfg = AdamWConfig(lr=1e-6, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.2, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. 
RL Trainer
-trainer = RLTrainerConfig(
-    load_from=model_path,
-    resources=resources,
-    rollout_config=rollout_config,
-    dataflow_config=dataflow_config,
-    judger_config=judger_cfg,
-    replay_buffer_config=replay_buffer_cfg,
-    evaluator_config=evaluator_cfg,
-    train_worker_config=train_worker_cfg,
-    tokenizer_path=model_path,
-    work_dir=work_dir,
-    total_epochs=total_epochs,
-    hf_interval=hf_interval,
-    exp_tracker="jsonl",
-)
diff --git a/autotest/config/rl_qwen3_gsm8k_grpo.py b/autotest/config/rl_qwen3_gsm8k_grpo.py
new file mode 100644
index 0000000000..ecd7ed778b
--- /dev/null
+++ b/autotest/config/rl_qwen3_gsm8k_grpo.py
@@ -0,0 +1,206 @@
+"""Example RL Colocate Trainer config (GRPO + GSM8K).
+
+Usage: pass paths in via environment variables, then let the CLI load this config and call trainer_cfg.build().fit().
+Required: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH
+Optional: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE
+"""
+
+import os
+from pathlib import Path
+
+from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
+from xtuner.v1.data_proto.rl_data import SampleParams
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig
+from xtuner.v1.model import get_model_config_from_hf
+from xtuner.v1.rl.advantage import GRPOAdvantageConfig
+from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig
+from xtuner.v1.rl.agent_loop_manager import (
+    AgentLoopManagerConfig,
+    SamplerConfig,
+    SyncProduceStrategyConfig,
+    TaskSpecConfig,
+)
+from xtuner.v1.rl.evaluator import EvaluatorConfig
+from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig
+from xtuner.v1.rl.loss import GRPOLossConfig
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.trainer import WorkerConfig
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig
+from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig
+
+
+# env
+work_dir = os.environ["WORK_DIR"]
+model_path = os.environ["MODEL_PATH"]
+data_path = os.environ["DATA_PATH"]
+eval_data_path = os.environ["EVAL_DATA_PATH"]
+enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
+WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1"))
+
+# basic settings
+experimental_name = "grpo_gsm8k"
+total_train_steps = 45
+evaluate_step = 45
+train_optimizer_steps = 1
+train_batch_size = 64 * train_optimizer_steps
+prompt_repeat_k = 5
+rollout_tp_size = 1
+rollout_ep_size = 1
+max_prompt_length = 512
+max_response_length = 1024
+pack_max_length = 32 * 1024
+
+# 1. resources
+resources = AcceleratorResourcesConfig(
+    accelerator="GPU",
+    num_workers=8 * WORLD_SIZE,
+    num_cpus_per_worker=12,
+    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
+)
+
+# 2. rollout
+rollout_config = RolloutConfig(
+    env=experimental_name,
+    device=resources.accelerator,
+    model_path=model_path,
+    dtype="bfloat16",
+    tensor_parallel_size=rollout_tp_size,
+    expert_parallel_size=rollout_ep_size,
+    gpu_memory_utilization=0.8,
+    context_length=max_response_length + max_prompt_length,
+    enable_return_routed_experts=(enable_return_routed_experts == "1"),
+)
+
+# 3. judger
+judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1)
+
+# 4. 
train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +# 5. train agent loop manager +train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) +train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +produce_strategy_config = SyncProduceStrategyConfig() +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ), +) + +# 6. eval agent loop manager +eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0) +eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, +) +eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=eval_sampler_config, + ), +) + +# 7. evaluator +evaluator_config = EvaluatorConfig(compute_metric_func=None) + +# 8. 
RL Colocate Trainer Config (the CLI obtains the Trainer via config["trainer"].build())
+trainer = RLColocateTrainerConfig(
+    resources=resources,
+    train_worker_cfg=train_worker_cfg,  # TODO: uniform naming of cfg and config
+    rollout_config=rollout_config,
+    tokenizer_path=model_path,
+    replay_buffer_config=dict(),
+    agent_loop_manager_cfg=agent_loop_manager_cfg,
+    eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
+    evaluator_config=evaluator_config,
+    load_from=model_path,
+    total_train_steps=total_train_steps,
+    train_batch_size=train_batch_size,
+    advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8),
+    enable_evaluate=True,
+    enable_initial_evaluate=False,
+    evaluate_step=evaluate_step,
+    work_dir=work_dir,
+    seed=123,
+    debug_rollout=False,
+)
diff --git a/autotest/config/rl_qwen3_vl_geometry3k_grpo.py b/autotest/config/rl_qwen3_vl_geometry3k_grpo.py
index afea2aeac5..79952fd28c 100644
--- a/autotest/config/rl_qwen3_vl_geometry3k_grpo.py
+++ b/autotest/config/rl_qwen3_vl_geometry3k_grpo.py
@@ -1,29 +1,36 @@
 import os
 from copy import deepcopy
-from transformers import AutoTokenizer
-from xtuner.v1.config import (
-    AdamWConfig,
-    FSDPConfig,
-    LRConfig,
-)
+
+from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
 from xtuner.v1.data_proto.rl_data import SampleParams
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+from xtuner.v1.datasets.rl_tokenize_fn import RLQwen3VLTokenizeFnConfig
 from xtuner.v1.model.compose.qwen3_vl import Qwen3VLDense8BConfig
-from xtuner.v1.ray.base import AcceleratorResourcesConfig
-from xtuner.v1.ray.config.worker import RolloutConfig
-from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig
-from xtuner.v1.ray.evaluator import EvaluatorConfig
-from xtuner.v1.ray.judger.controller import JudgerConfig
-from xtuner.v1.rl.base import WorkerConfig
-from xtuner.v1.rl.grpo import GRPOLossConfig
-from xtuner.v1.train.rl_trainer import RLTrainerConfig
-from xtuner.v1.datasets import RLTokenizeFnConfig, DatasetConfig, Qwen3VLTokenizeFnConfig, DataloaderConfig
-from xtuner.v1.ray.judger.geo3k import GEO3KJudgerConfig
+from xtuner.v1.rl.advantage import GRPOAdvantageConfig
+from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig
+from xtuner.v1.rl.agent_loop_manager import (
+    AgentLoopManagerConfig,
+    SamplerConfig,
+    SyncProduceStrategyConfig,
+    TaskSpecConfig,
+)
+from xtuner.v1.rl.evaluator import EvaluatorConfig
+from xtuner.v1.rl.judger import GEO3KJudgerConfig
+from xtuner.v1.rl.loss import GRPOLossConfig
+from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.trainer import WorkerConfig
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig
+from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig
+
+
 work_dir = os.environ["WORK_DIR"]
 model_path = os.environ["MODEL_PATH"]
 data_path = os.environ["DATA_PATH"]
 eval_data_path = os.environ["EVAL_DATA_PATH"]
-enable_evaluate = True if eval_data_path != "" else False
+enable_evaluate = eval_data_path != ""
 media_root = os.environ["MEDIA_ROOT"]
+
 # basic settings
 experimental_name = "grpo_geo3k"
 total_epochs = 15
@@ -39,12 +46,14 @@
 enable_initial_evaluate = True
 evaluate_step = 15
+# 1. resources
 resources = AcceleratorResourcesConfig(
     accelerator="GPU",
     num_workers=8,
     num_cpus_per_worker=12,
-    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
+    cpu_memory_per_worker=16 * 1024**3,
 )
+
 # 2. 
rollout rollout_config = RolloutConfig( env=experimental_name, @@ -54,77 +63,64 @@ tensor_parallel_size=rollout_tp_size, expert_parallel_size=rollout_ep_size, gpu_memory_utilization=0.75, - context_length = max_response_length + max_prompt_length, - # rollout_max_batch_size_per_instance=64, # optional, will be determined automatically if not set + context_length=max_response_length + max_prompt_length, ) + # sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) +training_sample_params = SampleParams(max_tokens=max_response_length) evaluation_sample_params = deepcopy(training_sample_params) evaluation_sample_params.top_p = 1.0 evaluation_sample_params.temperature = 0.0 evaluation_sample_params.top_k = 1 -# dataset: 不需要修改 -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenize_fn_cfg = Qwen3VLTokenizeFnConfig(processor_path=model_path, max_length=max_prompt_length) + +# 3. datasets train_dataset_cfg = [ { - "dataset": DatasetConfig(name="geo3k", - anno_path=data_path, - class_name='VLMJsonlDataset', - media_root=media_root, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length, - tokenize_fn_cfg=tokenize_fn_cfg), + "dataset": DatasetConfig( + name="geo3k", + anno_path=data_path, + class_name="VLMJsonlDataset", + media_root=media_root, + sample_ratio=1.0, + ), + "tokenize_fn": RLQwen3VLTokenizeFnConfig(processor_path=model_path, max_length=max_prompt_length), } ] -eval_dataset_cfg = [] -if enable_evaluate: - eval_dataset_cfg = [ - { - "dataset": DatasetConfig(name="geo3k", - anno_path=eval_data_path, - class_name='VLMJsonlDataset', - media_root=media_root, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length, - tokenize_fn_cfg=tokenize_fn_cfg, - ignore_multimodal_info=True), - } - ] -dataloader_config = DataloaderConfig(num_workers=8, - collator="fake_collator", - pack_level="none") -# 3. judger -geo3k_judger_config = GEO3KJudgerConfig() -judger_cfg = JudgerConfig(reward_judger_configs=[geo3k_judger_config]) -# 4. 
dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - # max_concurrent=64, # optional, will be determined automatically if not set +eval_dataset_cfg = [ + { + "dataset": DatasetConfig( + name="geo3k", + anno_path=eval_data_path if enable_evaluate else data_path, + class_name="VLMJsonlDataset", + media_root=media_root, + sample_ratio=1.0, + ), + "tokenize_fn": RLQwen3VLTokenizeFnConfig( + processor_path=model_path, + max_length=max_prompt_length, + ignore_multimodal_info=True, + ), + } +] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + num_workers=8, ) -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - dataloader_cfg=dataloader_config, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=None, - sample_params=evaluation_sample_params, -) if enable_evaluate else None -# replay buffer config: : 不需要修改 -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + num_workers=8, ) -# 5. Train worker -# NOTE: modify model_cfg + +# 4. judger +judger_config = GEO3KJudgerConfig() + +# 5. train worker model_cfg = Qwen3VLDense8BConfig(freeze_vision=True, freeze_projector=True) optim_cfg = AdamWConfig(lr=1e-6, foreach=False) loss_cfg = GRPOLossConfig( @@ -142,7 +138,7 @@ ) lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False) -train_worker_cfg: WorkerConfig = WorkerConfig( +train_worker_cfg = WorkerConfig( model_cfg=model_cfg, load_from=model_path, optim_cfg=optim_cfg, @@ -153,19 +149,53 @@ optimizer_steps=train_optimizer_steps, pack_max_length=pack_max_length, ) -# 6. RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, + +# 6. agent loop managers +agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=SyncProduceStrategyConfig(), + sampler_config=SamplerConfig(dataloader_cfg=dataloader_cfg, prompt_repeat_k=prompt_repeat_k), + ), +) + +eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=SamplerConfig(dataloader_cfg=eval_dataloader_cfg, prompt_repeat_k=1), + ), +) + +# 7. 
trainer +trainer = RLColocateTrainerConfig( resources=resources, + train_worker_cfg=train_worker_cfg, rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, tokenizer_path=model_path, - work_dir=work_dir, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=EvaluatorConfig(compute_metric_func=None), + load_from=model_path, total_epochs=total_epochs, + train_batch_size=global_batch_size, + advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), + enable_evaluate=enable_evaluate, + enable_initial_evaluate=enable_evaluate and enable_initial_evaluate, + evaluate_step=evaluate_step, + work_dir=work_dir, hf_interval=hf_interval, exp_tracker="jsonl", ) diff --git a/ci/scripts/test_ray_sft.py b/ci/scripts/test_ray_sft.py deleted file mode 100644 index 10943072a8..0000000000 --- a/ci/scripts/test_ray_sft.py +++ /dev/null @@ -1,110 +0,0 @@ -import os -import argparse - -from xtuner.v1.engine import EngineConfig -from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model.moe.moe import BalancingLossConfig, ZLossConfig -from xtuner.v1.datasets import FTDPTokenizeFnConfig -import ray -from xtuner.v1.rl.base.worker import TrainingWorker -from xtuner.v1.ray.base import AutoAcceleratorWorkers, AcceleratorResourcesConfig -from xtuner.v1.train import TrainerConfig -from xtuner.v1.train.trainer import Trainer -from xtuner.v1.loss.ce_loss import CELossConfig -import torch - - -def test_ray(trainer_cfg): - ray.init() - resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - cpu_memory_per_worker=16 * 1024**3, # 16 GB - ) - workers, pg = AutoAcceleratorWorkers.from_config( - TrainingWorker, trainer_cfg, resources - ) - futures = [ worker.test_all_reduce.remote() for worker in workers ] - print(ray.get(futures)) - handles = [worker.fit.remote() for worker in workers] - print(ray.get(handles)) - return - - -def test_torchrun(rank, trainer_cfg): - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = "8" - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = "29000" - os.environ["LOCAL_RANK"] = str(rank) - trainer = Trainer.from_config(trainer_cfg) - trainer.fit() - return - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Test internode EP kernels') - parser.add_argument('--ray', action='store_true') - parser.add_argument('--model-path', type=str, default=os.environ["QWEN3_MOE_PATH"]) - parser.add_argument('--data-path', type=str, default=os.environ["ALPACA_PATH"]) - args = parser.parse_args() - - moe_cfg = Qwen3MoE30BA3Config( - ep_size=1, - balancing_loss_cfg=BalancingLossConfig(), - z_loss_cfg=ZLossConfig(), - ) - - optim_cfg: AdamWConfig = AdamWConfig() - fsdp_cfg: FSDPConfig = FSDPConfig( - torch_compile=True, - cpu_offload=False, - ep_size=1, - ) - dataset_cfg = [ - dict(dataset=DatasetConfig(name='alpaca', anno_path=args.data_path, sample_ratio=1.0), - tokenize_fn=FTDPTokenizeFnConfig()), - ] - - dataloader_cfg = DataloaderConfig( - pack_max_length=512, - max_length=512, - ) - - engine_cfg = EngineConfig( - model_cfg=moe_cfg, - fsdp_cfg=fsdp_cfg, - optim_cfg=optim_cfg, - ) - lr_cfg = 
LRConfig(lr_type="cosine", lr_min=1e-6)
-
-    loss_cfg = CELossConfig()
-
-    trainer_cfg = TrainerConfig(
-        model_cfg=moe_cfg,
-        load_from=args.model_path,
-        tokenizer_path=args.model_path,
-        dataset_cfg=dataset_cfg,
-        dataloader_cfg=dataloader_cfg,
-        optim_cfg=optim_cfg,
-        lr_cfg=lr_cfg,
-        fsdp_cfg=fsdp_cfg,
-        loss_cfg=loss_cfg,
-        global_batch_size=16,
-        total_epoch=1,
-        work_dir="/tmp/qwen3_moe_test",
-        seed=42,
-    )
-
-    if args.ray:
-        test_ray(trainer_cfg)
-    else:
-        torch.multiprocessing.spawn(test_torchrun, args=(trainer_cfg, ), nprocs=8)
-        test_torchrun(trainer_cfg)
-
diff --git a/design/disagg_design_part1.md b/design/disagg_design_part1.md
new file mode 100644
index 0000000000..7caa8488da
--- /dev/null
+++ b/design/disagg_design_part1.md
@@ -0,0 +1,866 @@
+# Disaggregated (Non-Colocated) Training Design Notes
+
+## 1. Design Goals
+
+xtuner already has a colocated training flow:
+
+- rollout and train share the same set of GPUs
+- each training round starts with a rollout phase
+- once rollout finishes, the GPUs switch to training
+- after training, weights are synchronized and the next rollout round begins
+
+This flow is simple to implement, but rollout and train are tightly synchronized and the two phases block each other.
+
+This design adds the missing "disaggregated training" path:
+
+- rollout and train run on separate GPU groups
+- rollout keeps producing data in the background
+- the trainer consumes data from the replay buffer in the foreground, on demand
+- at a weight-sync point, the trainer explicitly interrupts the producer
+- once the weight sync completes, the producer resumes
+
+In one sentence: move from "serial, switch-based rollout/train" to "background production + foreground training + explicit sync points".
+
+---
+
+## 2. Overall Approach
+
+To minimize changes to existing code, this design does not replace the existing `AgentLoopManager + AsyncProduceStrategy + ReplayBuffer` structure. Instead, it splits apart the logic currently coupled inside a single `produce_batch()` call.
+
+Production currently works roughly like this:
+
+1. schedule rollouts
+2. reclaim pending tasks
+3. take a training batch out of the replay buffer
+
+The new design splits this into three reusable steps:
+
+- `_produce_batch_to_buffer(...)`
+  only schedules rollouts and writes the results into the replay buffer
+- `pause_product(...)`
+  explicitly stops and reclaims rollout tasks that have not yet wound down
+- `_get_batch_from_buffer(...)`
+  only takes a training batch from the replay buffer and assembles the statistics
+
+This has two direct benefits:
+
+- the colocated path can keep using `produce_batch()`, which internally becomes the three-step sequence
+- the disaggregated path can reuse the "produce only" and "fetch only" steps separately, forming a background producer + foreground consumer
+
+Pseudocode for the current design lives in `design/disagg_draft.py`.
+
+---
+
+## 3. Key States and the State Machine
+
+### 3.1 `ProduceBatchStatus`
+
+`ProduceBatchStatus` describes the result of "one producer scheduling call"; it is not the manager's global state.
+
+- `NORMAL`
+  - this scheduling call finished normally
+  - it may have written a new batch of samples into the buffer
+  - or it may have found the buffer already holds enough samples, so no new rollouts were issued
+
+- `UPDATE_ABORT`
+  - a weight sync is about to happen
+  - the producer should stop issuing new tasks
+  - the remaining pending rollouts are handed to an explicit outer pause
+
+- `EXPIRED_BATCH`
+  - the model version used by the current rollout is too old
+  - in this design, this is treated as an immediate stop signal
+  - instead of trying to consume stale completed leftovers first, the weight update happens right away, so the rollout weights are refreshed as early as possible and the rollout GPUs do not idle
+
+### 3.2 `AgentLoopManagerStatus`
+
+`AgentLoopManagerStatus` is the manager's global running state; think of it as the state machine of the background producer's main loop.
+
+- `NORMAL`
+  - normal production
+
+- `UPDATE_ABORT`
+  - a weight-update signal has been received
+  - the producer stops issuing new tasks
+  - waiting for the trainer to finish the pause and the weight sync
+
+- `EXPIRED_BATCH`
+  - the current rollout model is too old
+  - when the trainer sees this state, it immediately skips training and goes straight to weight sync
+
+- `FINISH`
+  - the whole training run is over
+  - the producer loop should exit
+
+### 3.3 State Transitions
+
+The global state moves along these paths:
+
+- the initial state is `NORMAL`
+- `NORMAL -> UPDATE_ABORT`
+  - triggered before the trainer starts a weight sync
+- `UPDATE_ABORT -> NORMAL`
+  - `continue_product()` is called after the weight sync completes
+- `NORMAL -> EXPIRED_BATCH`
+  - the current rollout model is too old
+- `EXPIRED_BATCH -> UPDATE_ABORT`
+  - the trainer detects the expiry and enters the weight-sync phase
+- any state -> `FINISH`
+  - training is over
+
+One important distinction (see the sketch after this list):
+
+- `ProduceBatchStatus` is the "local result of a single scheduling call"
+- `AgentLoopManagerStatus` is the "global running state of the background producer"
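+
+A minimal sketch of the two state sets and the allowed global transitions. The enum members come from this document; the `can_transition` helper and its transition table are illustrative only, not an existing API.
+
+```python
+import enum
+
+
+class ProduceBatchStatus(enum.Enum):
+    """Result of one producer scheduling call (local, per call)."""
+
+    NORMAL = "normal"
+    UPDATE_ABORT = "update_abort"
+    EXPIRED_BATCH = "expired_batch"
+
+
+class AgentLoopManagerStatus(enum.Enum):
+    """Global running state of the background producer loop."""
+
+    NORMAL = "normal"
+    UPDATE_ABORT = "update_abort"
+    EXPIRED_BATCH = "expired_batch"
+    FINISH = "finish"
+
+
+# Transitions from section 3.3; FINISH is reachable from any state,
+# so it is handled separately rather than listed per source state.
+_ALLOWED = {
+    AgentLoopManagerStatus.NORMAL: {
+        AgentLoopManagerStatus.UPDATE_ABORT,
+        AgentLoopManagerStatus.EXPIRED_BATCH,
+    },
+    AgentLoopManagerStatus.UPDATE_ABORT: {AgentLoopManagerStatus.NORMAL},
+    AgentLoopManagerStatus.EXPIRED_BATCH: {AgentLoopManagerStatus.UPDATE_ABORT},
+}
+
+
+def can_transition(src: AgentLoopManagerStatus, dst: AgentLoopManagerStatus) -> bool:
+    """Return True if the manager may move from ``src`` to ``dst``."""
+    if dst is AgentLoopManagerStatus.FINISH:
+        return True
+    return dst in _ALLOWED.get(src, set())
+```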
+
+---
+
+## 4. Key Interface Changes
+
+### 4.1 `ProduceBatchResult`
+
+`ProduceBatchResult` gains a new field:
+
+- `status: ProduceBatchStatus`
+
+Usage:
+
+- on the colocated path it usually returns `NORMAL`
+- on the disaggregated path, if `get_batch()` finds the manager already in `EXPIRED_BATCH`, it can return an empty batch immediately and use `status` to tell the trainer: "the current rollout model is too old; skip training this round and go straight to weight sync"
+
+The remaining timing / leftover fields are kept for training logs and debugging.
+
+### 4.2 `AgentLoopManager`
+
+New members:
+
+- `_update_event: asyncio.Event`
+  - set when the trainer triggers a weight update
+  - the producer stops issuing new tasks as soon as it sees it
+
+- `_finish_event: asyncio.Event`
+  - used to safely exit the producer loop when training ends
+
+- `_model_rollout_step: int`
+  - which weight version the rollout is currently using
+  - note this is not the same concept as the producer loop's own `rollout_step`
+
+- `_status: AgentLoopManagerStatus`
+  - the manager's global running state
+
+- `_pause_time_s: float`
+  - duration of the most recent `pause_product()`
+  - reported in the next `ProduceBatchResult`, then reset to zero
+
+New interfaces:
+
+- `pause_product(for_weight_update: bool = False) -> float`
+- `continue_product(model_rollout_step: int) -> None`
+- `produce_loop(batch_size: int, start_rollout_step: int) -> Awaitable[None]`
+- `get_batch(batch_size: int, rollout_step: int) -> Awaitable[ProduceBatchResult]`
+
+### 4.3 `AsyncProduceStrategy`
+
+Changes:
+
+- `pause_product()` is promoted from a private method to a public interface
+- `pending_tasks` becomes the instance attribute `self._pending_tasks`
+- `produce_batch()` becomes "produce into the buffer only" and returns a `ProduceBatchStatus`
+- `produce_batch()` no longer pauses automatically from the inside
+
+The core intent: hand the "stop the producer" action back to the outer layer instead of letting the strategy decide when to wind down.
+
+---
+
+## 5. Producer-Side Design
+
+### 5.1 Why `self._pending_tasks` must persist
+
+In the old design, `pending_tasks` was a local variable inside one `produce_batch()` call.
+
+That suits colocated training, because:
+
+- each `produce_batch()` call winds everything down before returning
+- there is no need to keep pending rollouts alive across calls
+
+The disaggregated scenario is different:
+
+- the producer loop runs continuously
+- rollout tasks may still be in flight between successive `AsyncProduceStrategy.produce_batch()` calls
+- those pending tasks must survive across calls and be paused together at the right moment
+
+Hence `pending_tasks` must be promoted to the instance attribute `self._pending_tasks`.
+
+### 5.2 The new responsibilities of `AsyncProduceStrategy.produce_batch()`
+
+The new `produce_batch()` no longer returns a training batch; it is only responsible for "producing data into the replay buffer".
+
+At the start of each call it must:
+
+1. reclaim the finished tasks in `self._pending_tasks`
+2. immediately check whether the current model version is already too old
+3. if the model is still fresh, check the buffer and decide whether more rollouts need to be issued
+4. write the new results back into the replay buffer
+
+The return value is a `ProduceBatchStatus`:
+
+- `NORMAL`
+- `UPDATE_ABORT`
+- `EXPIRED_BATCH`
+
+### 5.3 Hoisting `pause_product()`
+
+In the old logic, the strategy automatically cleaned up pending rollouts at the end of `produce_batch()`.
+
+The new logic hoists the pause up to the manager / trainer level, because:
+
+- in the disaggregated setup, when to stop the producer should not be decided by the strategy itself
+- the trainer needs to say "stop now" explicitly before a weight sync
+- the pause duration should also be recorded as an explicit operation
+
+Therefore:
+
+- `AsyncProduceStrategy.pause_product()` actually stops and reclaims the pending rollouts
+- `AgentLoopManager.pause_product()` coordinates each strategy's pause across single-task and multi-task setups
+
+### 5.4 New generate-time accounting
+
+`ProducerTimings.generate_times_s` used to be part of `produce_batch()`'s return value.
+
+But `AsyncProduceStrategy.produce_batch()` no longer returns a training batch directly, so it is no longer a good carrier for this kind of result statistic.
+
+The new approach (see the sketch at the end of this section):
+
+- each group's generate time is written directly into the corresponding `RolloutState.extra_fields["group_generate_time_s"]`
+- `AgentLoopManager._get_batch_from_buffer()` rebuilds `group_gen_count / mean / p50 / p99` when fetching a batch
+- `pause_product()` returns `pause_time_s`
+- `AgentLoopManager` stashes it in `_pause_time_s`, reports it with the next `ProduceBatchResult`, then clears it
+
+The statistics are preserved, but no longer depend on the strategy's synchronous return value.
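+
+A control-flow sketch of the reshaped strategy, condensing sections 5.1-5.3. The class and method names come from this document; the helpers at the bottom (`_write_to_replay_buffer`, `_model_too_old`, `_update_requested`, `_missing_groups`, `_rollout_one_group`) are hypothetical placeholders standing in for the real buffer and staleness logic.
+
+```python
+import asyncio
+from enum import Enum
+
+
+class ProduceBatchStatus(Enum):
+    NORMAL = "normal"
+    UPDATE_ABORT = "update_abort"
+    EXPIRED_BATCH = "expired_batch"
+
+
+class AsyncProduceStrategy:
+    """Sketch of the strategy after the interface split."""
+
+    def __init__(self) -> None:
+        # Pending rollouts now live on the instance so they survive
+        # across produce_batch() calls (section 5.1).
+        self._pending_tasks: set[asyncio.Task] = set()
+
+    async def produce_batch(self, batch_size: int) -> ProduceBatchStatus:
+        # 1. Reclaim tasks that finished since the last call.
+        done = {t for t in self._pending_tasks if t.done()}
+        self._pending_tasks -= done
+        for task in done:
+            self._write_to_replay_buffer(task.result())
+        # 2. Stop immediately if the rollout weights are already stale.
+        if self._model_too_old():
+            return ProduceBatchStatus.EXPIRED_BATCH
+        # 3. Stop issuing new work once a weight update is requested.
+        if self._update_requested():
+            return ProduceBatchStatus.UPDATE_ABORT
+        # 4. Otherwise top up rollouts until the buffer can fill the batch.
+        for _ in range(self._missing_groups(batch_size)):
+            self._pending_tasks.add(asyncio.create_task(self._rollout_one_group()))
+        return ProduceBatchStatus.NORMAL
+
+    async def pause_product(self) -> None:
+        # Now public: only the outer layer decides when to wind down (5.3).
+        for task in self._pending_tasks:
+            task.cancel()
+        await asyncio.gather(*self._pending_tasks, return_exceptions=True)
+        self._pending_tasks.clear()
+
+    # --- placeholders for the real buffer / staleness logic ---
+    def _write_to_replay_buffer(self, group) -> None: ...
+    def _model_too_old(self) -> bool: return False
+    def _update_requested(self) -> bool: return False
+    def _missing_groups(self, batch_size: int) -> int: return 0
+    async def _rollout_one_group(self): ...
+```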
+
+---
+
+## 6. Staleness Design
+
+### 6.1 Why `rollout_step` and `model_rollout_step` are distinct
+
+In the disaggregated setup, the producer runs continuously in the background and its own loop counter keeps advancing.
+
+But the information that really matters for a sample is not "which producer loop iteration generated it"; it is:
+
+- which weight version generated these tokens
+
+Therefore:
+
+- `response_rollout_steps` should no longer record "the current producer scheduling step"
+- it should record "the step of the model version the rollout is currently using", i.e. `_model_rollout_step`
+
+So `SingleTurnAgentLoop` / `PartialRolloutHandler` must be changed to receive `model_rollout_step`.
+
+### 6.2 Why sample-staleness recomputation is kept
+
+In the disaggregated scenario, samples may sit in the buffer much longer.
+
+Relying only on the `seq_staleness` recorded when a sample was written is problematic:
+
+- it is just a historical snapshot
+- by the time the sample is actually trained on, it may have aged considerably
+
+So the design still recommends a lightweight helper, e.g.:
+
+- `refresh_seq_staleness(group, current_rollout_step)`
+
+Its job:
+
+- whenever the freshness of old samples needs checking
+- recompute staleness from `response_rollout_steps` and the current training step
+
+### 6.3 How `ExpiredBatch` is determined
+
+`ExpiredBatch` does not simply mean "no data right now", nor "drain the old completed leftovers first and then decide whether to stop".
+
+Its real meaning is:
+
+- the model version used by the current rollout is too old
+- to let the rollout side switch to new weights as soon as possible, it should not keep occupying GPUs and waiting
+- the producer should stop immediately and ask the outer layer to sync weights as soon as possible
+
+So this is the more aggressive shutdown policy:
+
+- as soon as the current rollout model is too old, return `EXPIRED_BATCH`
+- do not try to reuse the buffer's stale completed data first
+- on receiving the signal, the trainer skips training and pushes the weight update forward
+
+The reasoning:
+
+- in the disaggregated setup, the rollout GPU group is decoupled from the train GPU group anyway
+- if the rollout side knows it is stale but keeps waiting for the trainer to drain old data, it wastes GPUs and delays the weight switch
+- so refreshing the rollout weights quickly takes priority over squeezing the last value out of stale samples
+
+---
+
+## 7. AgentLoopManager Design
+
+### 7.1 `_produce_batch_to_buffer(...)`
+
+This new internal helper only produces; it never fetches.
+
+One implementation constraint:
+
+- `model_rollout_step` is no longer passed to `_produce_batch_to_buffer(...)` as an explicit parameter
+- it is read uniformly from `self._model_rollout_step`
+- the colocated path aligns this state at the `produce_batch()` entry via `continue_product(model_rollout_step=rollout_step)`
+- the disaggregated path updates it via `continue_product(...)` after the external weight sync completes
+
+Single task:
+
+- directly call that task's `AsyncProduceStrategy.produce_batch()`
+
+Multiple tasks:
+
+- keep using `get_task_batch_sizes()` for batch allocation
+- schedule each task's produce concurrently with `asyncio.gather()`
+- aggregate the returned `ProduceBatchStatus` values
+
+Status aggregation priority:
+
+- `UPDATE_ABORT` > `EXPIRED_BATCH` > `NORMAL`
+
+Because:
+
+- if any task has received a weight-update signal, the whole producer should stop first
+- only then do we consider tasks that stopped because their rollout model is too old
+
+### 7.2 `_get_batch_from_buffer(...)`
+
+This new internal helper only fetches.
+
+Responsibilities:
+
+1. take a training batch of `COMPLETED` samples from the replay buffer
+2. count the `COMPLETED / ABORTED / EXPIRED` leftovers
+3. rebuild the generate timing from the samples' `extra_fields`
+4. attach the most recent pause's `pause_time_s` to the result
+
+### 7.3 `pause_product(...)`
+
+When `for_weight_update=True`:
+
+1. set `_update_event` first
+2. then switch the manager state to `UPDATE_ABORT`
+3. finally call each task strategy's `pause_product()`
+
+The order matters.
+
+Setting the event first prevents the producer from issuing new tasks after the pause has begun.
+
+### 7.4 `continue_product(...)`
+
+`continue_product(model_rollout_step=...)` restores the producer's control state:
+
+- clear `_update_event`
+- `_status = NORMAL`
+- `_model_rollout_step = current training step`
+
+Only then does the producer know: "the rollout side has switched to the new weights; production can resume".
+
+---
+
+## 8. Adapting the Colocated Path
+
+Although this effort targets the disaggregated path, the colocated path must also adapt to the interface split, to avoid maintaining two divergent logics.
+
+The new colocated `AgentLoopManager.produce_batch()` internally becomes (see the sketch below):
+
+1. `continue_generation()`
+2. `continue_product(model_rollout_step=rollout_step)`
+3. `_produce_batch_to_buffer(...)`
+4. `pause_product(for_weight_update=False)`
+5. `_get_batch_from_buffer(...)`
+6. `pause_generation()`
+
+The point of this:
+
+- colocated and disaggregated paths share the same underlying logic
+- the only difference:
+  - the colocated path does "produce + wind down + fetch" within one call
+  - the disaggregated path produces continuously in the background and fetches separately in the foreground
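+
+A sketch of the six-step colocated sequence, assuming the manager methods named above exist. Whether `continue_generation()` / `pause_generation()` are awaitable is an assumption here; `pause_product()` and `continue_product()` follow the synchronous signatures in section 4.2.
+
+```python
+class AgentLoopManager:
+    """Sketch of the colocated produce_batch() after the split (section 8)."""
+
+    async def produce_batch(self, batch_size: int, rollout_step: int):
+        # 1. Resume rollout generation on the shared GPUs.
+        await self.continue_generation()
+        # 2. Align the rollout weight version before producing (section 7.1).
+        self.continue_product(model_rollout_step=rollout_step)
+        # 3. Produce into the replay buffer only; no batch is returned here.
+        await self._produce_batch_to_buffer(batch_size)
+        # 4. Wind down pending rollouts; this is not a weight update.
+        self.pause_product(for_weight_update=False)
+        # 5. Fetch the training batch plus timing / leftover statistics.
+        result = await self._get_batch_from_buffer(batch_size)
+        # 6. Hand the shared GPUs back to training.
+        await self.pause_generation()
+        return result
+```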
+
+---
+
+## 9. Disaggregated Producer Loop
+
+`produce_loop(batch_size, start_rollout_step)` is the new background production loop for the disaggregated path.
+
+Main logic:
+
+1. keep calling `_produce_batch_to_buffer(batch_size, rollout_step)`
+2. decide the next action from the returned status
+
+On:
+
+- `NORMAL`
+  - this round of production finished normally
+  - the producer advances its own local `rollout_step += 1`
+
+- `EXPIRED_BATCH`
+  - the manager enters `EXPIRED_BATCH`
+  - the producer pauses its work immediately
+  - it does not keep trying to drain the old completed buffer
+  - it waits for the trainer to sync weights and call `continue_product()`
+
+- `UPDATE_ABORT`
+  - the trainer is preparing a weight sync
+  - the producer does not pause itself, to avoid racing with the trainer
+  - it only waits for the external `continue_product()`
+
+- `FINISH`
+  - exit the producer loop
+
+One important constraint:
+
+- after receiving `UPDATE_ABORT`, the producer loop does not pause a second time
+- only the trainer's weight-sync path explicitly calls `pause_product()`
+
+This avoids races around reclaiming pending rollouts twice.
+
+---
+
+## 10. Disaggregated `get_batch(...)`
+
+`get_batch(batch_size, rollout_step)` is the training-side consumption interface.
+
+Normal case:
+
+- it only calls `_get_batch_from_buffer(...)`
+- it does not drive new rollout generation itself
+
+Special case:
+
+- if `AgentLoopManager._status == EXPIRED_BATCH`
+- return `ProduceBatchResult(status=EXPIRED_BATCH, rollout_states=[])` directly
+
+This gives the trainer an unambiguous signal:
+
+- do not train this round
+- go straight to weight sync
+
+---
+
+## 11. Disaggregated Trainer Design
+
+### 11.1 New trainer entry point
+
+Added:
+
+- `xtuner/v1/train/rl_disaggregated_trainer.py`
+
+Its config shape stays consistent with the existing `rl_disagg_single.py` / `rl_disagg_multi.py`.
+
+### 11.2 Why `fit()` stays synchronous
+
+The CLI still uses:
+
+- `trainer = trainer_cfg.build()`
+- `trainer.fit()`
+
+To avoid changing the CLI, the design keeps a synchronous `fit()` that internally wraps `asyncio_run(self._fit())`.
+
+The external interface is unchanged, while internally we can naturally manage:
+
+- the background producer task
+- foreground async batch fetching
+- the priority relationship between eval and the producer
+
+### 11.3 The `_fit()` main flow
+
+The training main flow splits into two parallel strands:
+
+- background: `producer_task = create_task(agent_loop_manager.produce_loop(...))`
+- foreground: the training loop keeps calling `get_batch()`
+
+Each foreground round is roughly (a condensed sketch follows section 12):
+
+1. `produce_result = await agent_loop_manager.get_batch(...)`
+2. if `status != EXPIRED_BATCH`
+   - `_prepare_train_data(...)`
+   - `train_controller.fit(...)`
+3. at a sync point
+   - `agent_loop_manager.pause_product(for_weight_update=True)`
+   - `_sync_weights_and_save(...)`
+4. if this round needs eval
+   - run eval first
+5. `agent_loop_manager.continue_product(model_rollout_step=current_step)`
+
+### 11.4 Why `EXPIRED_BATCH` skips training entirely
+
+The semantics of `EXPIRED_BATCH` are not "no data this round" but:
+
+- the current rollout weight version is too old
+- letting the rollout side refresh its weights quickly takes priority over waiting for the trainer to drain stale samples
+
+Continuing to train on old buffered data at this point causes two problems:
+
+- the rollout side keeps waiting and cannot switch to new weights promptly
+- stale data may keep being consumed
+
+So the policy is:
+
+- skip `_prepare_train_data`
+- skip `train_controller.fit`
+- go straight to weight sync
+
+---
+
+## 12. Weight Sync and Eval Priority
+
+### 12.1 `_sync_weights_and_save(...)`
+
+In the disaggregated setup, the order before a weight sync must be:
+
+1. `agent_loop_manager.pause_product(for_weight_update=True)`
+2. `_sync_weights_and_save(rollout_step)`
+
+and `_sync_weights_and_save(rollout_step)` internally does:
+
+1. `_maybe_save_checkpoint(rollout_step)`
+2. `bind_train_rollout(...)`
+3. `fake_update_weights()`
+
+For now this does not go through the real `train_controller.update_weights()`; it keeps an explicit placeholder:
+
+- `fake_update_weights()`
+
+So when the real cross-GPU sync is plugged in later, the trainer main flow does not need to change.
+
+### 12.2 Why eval takes priority over the producer's continue_product
+
+If `continue_product()` is called right after the weight sync, the producer immediately resumes background generation.
+
+But if this step also needs to run eval, we would get:
+
+- eval and the background producer competing for rollout resources
+
+So the design fixes the order:
+
+- sync weights first
+- if this round needs eval, run eval first
+- call `continue_product()` only after eval finishes
+
+That is: eval has priority over the background producer.
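+
+A condensed sketch of the foreground/background split from sections 9-12. The manager and controller method names are the ones this document introduces; everything else (`self.batch_size`, `_is_sync_point`, `_should_eval`, `_evaluate`, the attribute layout) is hypothetical, and plain `asyncio.run` stands in for the `asyncio_run` helper mentioned in 11.2. `ProduceBatchStatus` is the enum sketched earlier.
+
+```python
+import asyncio
+
+
+class RLDisaggregatedTrainer:
+    """Sketch of the disaggregated trainer's main flow."""
+
+    def fit(self) -> None:
+        # CLI-facing entry stays synchronous (section 11.2).
+        asyncio.run(self._fit())
+
+    async def _fit(self) -> None:
+        # Background: the producer keeps filling the replay buffer.
+        producer_task = asyncio.create_task(
+            self.agent_loop_manager.produce_loop(self.batch_size, self.cur_step)
+        )
+        for step in range(self.cur_step, self.total_train_steps):
+            result = await self.agent_loop_manager.get_batch(self.batch_size, step)
+            expired = result.status == ProduceBatchStatus.EXPIRED_BATCH
+            if not expired:
+                # Normal round: train on the fetched batch.
+                self.train_controller.fit(self._prepare_train_data(result))
+            # An expired batch falls through straight to the sync point (11.4).
+            if expired or self._is_sync_point(step):
+                # Pause exactly once, from the trainer side only (section 9).
+                self.agent_loop_manager.pause_product(for_weight_update=True)
+                self._sync_weights_and_save(step)
+                if self._should_eval(step):
+                    # Eval runs before production resumes (section 12.2).
+                    await self._evaluate(step)
+                self.agent_loop_manager.continue_product(model_rollout_step=step)
+        await producer_task  # exits once the manager reaches FINISH
+```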
+
+---
+
+## 13. Checkpoint 保存与恢复
+
+### 13.1 为什么 checkpoint 需要专门细化
+
+共卡训练里,rollout 和 train 是串行切换的,checkpoint 保存点天然比较清晰。
+
+但非共卡训练里,多了一个持续运行的后台 producer:
+
+- replay buffer 可能正在被 producer 写入
+- strategy 内可能还挂着未收尾的 pending rollout
+- manager 还维护了 `_model_rollout_step` / `_status` / `_update_event` 这些运行时状态
+
+如果不把 save / resume 逻辑说清楚,恢复后很容易出现下面的问题:
+
+- producer 恢复得太早,在 rollout 权重同步前就继续生成
+- replay buffer 恢复了,但 manager 的 `model_rollout_step` 不对
+- checkpoint 拍下来时还有 pending rollout 没收尾,导致恢复语义不一致
+
+所以这里要求 checkpoint 必须在“静止态”拍摄。
+
+### 13.2 安全保存点
+
+checkpoint 的安全保存点固定放在 `_sync_weights_and_save(...)` 中,且必须满足:
+
+1. `agent_loop_manager.pause_product(for_weight_update=True)` 已完成
+2. producer 不会继续补发新任务
+3. replay buffer 不再被后台并发写入
+4. `continue_product()` 尚未发生
+
+因此保存点的语义是:
+
+- trainer 当前步的训练结果已经稳定
+- producer 已暂停
+- rollout 仍需在 resume 后由 train 侧重新同步权重
+
+这里特意把 checkpoint 放在 `continue_product()` 之前,是为了避免 producer 恢复后台生成后,又把系统带回“动态变化态”。
+
+这里还要再强调一个容易误解的点:
+
+- `_maybe_save_checkpoint(rollout_step)` 中,必须显式传入
+  `model_rollout_step_override=rollout_step`
+- 不能直接把当时 manager 内部的 `self._model_rollout_step` 原样存盘
+
+原因是:
+
+- save 的时机在 `pause_product(for_weight_update=True)` 之后
+- 但在 `continue_product(model_rollout_step=rollout_step)` 之前
+- 所以 save 那一刻,manager 里的 `self._model_rollout_step` 仍然是旧的 rollout 权重版本
+- 而主流程的真实意图,是保存“本轮同步完成后,resume 应该继续使用的新 rollout_step”
+
+换句话说,这个 override 不是建议项,而是恢复语义正确性的必要条件。
+
+### 13.3 保存内容
+
+沿用当前 colocate trainer 的三层保存结构:
+
+#### 1. `AgentLoopManager.save(...)`
+
+除现有的 sampler / replay buffer 外,非共卡路径还要保存 manager 自身状态。
+
+建议保存:
+
+- 各 task sampler 状态
+- replay buffer
+- `agent_loop_manager_state.json`
+
+`agent_loop_manager_state.json` 至少包含:
+
+- `model_rollout_step`
+- `status`
+
+其中 `model_rollout_step` 的来源要特别注意:
+
+- 在 `_maybe_save_checkpoint(rollout_step)` 里,必须通过
+  `model_rollout_step_override=rollout_step` 显式写入
+- 不应直接落盘 save 瞬间那个尚未经过 `continue_product()` 推进的旧
+  `self._model_rollout_step`
+
+推荐语义:
+
+- `status` 保存为 `UPDATE_ABORT`
+- 表示这个 checkpoint 拍摄时,producer 已经被暂停
+
+不建议保存:
+
+- `_update_event`
+- `_finish_event`
+- `_pause_time_s`
+- `AsyncProduceStrategy._pending_tasks`
+
+原因:
+
+- event 是运行时同步原语,不适合直接持久化
+- `pause_time_s` 是一次性日志字段,resume 后清零即可
+- `pending_tasks` 是内存里的协程对象,checkpoint 前必须已经 pause 并清空
+
+#### 2. `train_controller.save(...)`
+
+和 colocate trainer 语义一致,保存训练态:
+
+- model
+- optimizer
+- 其他 DCP 状态
+
+#### 3. `trainer_state.json`
+
+建议至少保存:
+
+- `cur_step`
+
+建议额外保存:
+
+- `global_train_step`
+- `model_rollout_step`
+
+其中:
+
+- `cur_step` 决定训练主循环恢复到哪一步
+- `model_rollout_step` 主要用于恢复校验和排障
+
+### 13.4 `model_rollout_step` 为什么要单独保存
+
+`model_rollout_step` 不能只依赖 `cur_step` 间接推导。
+
+原因是:
+
+- checkpoint 是在 `pause_product()` 后、`continue_product()` 前拍摄的
+- 这个时间点的 manager 状态,不一定能直接由训练步数唯一还原
+- 后续如果 sync 策略变化,`cur_step` 与 rollout 实际使用的权重版本也不一定严格一一对应
+
+因此这里更准确的做法是:
+
+- 在 manager state 中显式保存“resume 目标版本”的 `model_rollout_step`
+- 这个值在 `_maybe_save_checkpoint(rollout_step)` 时通过
+  `model_rollout_step_override=rollout_step` 传入
+- resume 后直接按这个保存值恢复
+
+这里之所以不能直接依赖 `self._model_rollout_step`,不是因为它永远不可信,
+而是因为当前设计的 checkpoint 保存点刚好卡在:
+
+1. `pause_product(...)` 已完成
+2. `continue_product(model_rollout_step=rollout_step)` 尚未执行
+
+所以 save 瞬间的 `self._model_rollout_step` 在语义上仍代表“旧 rollout 模型版本”,
+而 resume 后我们真正希望恢复的是“新的 rollout_step 对应版本”。
+
+落到代码上大致是下面的写法(仅为示意;`_should_save` / `_checkpoint_dir` / `_save_trainer_state` 为示意用的假设辅助接口):
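+
+```python
+# 示意:save 时显式 override,固定“resume 应使用的新版本”。
+def _maybe_save_checkpoint(self, rollout_step: int) -> None:
+    if not self._should_save(rollout_step):
+        return
+    ckpt_dir = self._checkpoint_dir(rollout_step)
+    # 此刻 self._model_rollout_step 仍是旧版本,不能直接落盘;
+    # 显式写入“本轮同步完成后 resume 应使用的新 rollout_step”。
+    self.agent_loop_manager.save(
+        ckpt_dir,
+        model_rollout_step_override=rollout_step,
+    )
+    self.train_controller.save(ckpt_dir)
+    self._save_trainer_state(ckpt_dir)  # 写 trainer_state.json(cur_step 等)
+```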
+
+### 13.5 `AgentLoopManager.save(...)` 的约束
+
+为了保证 checkpoint 可恢复,保存前应满足:
+
+- 所有 `AsyncProduceStrategy._pending_tasks` 为空
+- `AgentLoopManager._status == UPDATE_ABORT`
+- `_update_event` 已经置位
+- producer 当前不会继续写 replay buffer
+- 调用方显式以 `model_rollout_step_override=rollout_step` 传入
+  本次 checkpoint 对应的目标 rollout 版本
+
+如果这些条件不满足,建议:
+
+- 拒绝保存,或
+- 在 save 前先强制执行 pause
+
+### 13.6 Resume 顺序
+
+resume 的 source of truth 不是 rollout 的运行时内存,而是:
+
+- train checkpoint
+- replay buffer
+- sampler 状态
+- manager state
+
+推荐恢复顺序如下:
+
+1. `train_controller.resume(checkpoint_path)`
+2. `agent_loop_manager.resume(checkpoint_path)`
+3. `bind_train_rollout(...)`
+4. `fake_update_weights()` 或后续真实权重同步
+5. `agent_loop_manager.continue_product(model_rollout_step=saved_model_rollout_step)`
+6. `fit()` 启动新的 `producer_task = create_task(produce_loop(...))`
+
+这里有三个重要点:
+
+- 不恢复旧的 producer task
+  - producer task 是运行时协程,进程重启后必须重新创建
+- rollout 权重不作为 checkpoint source of truth
+  - resume 后总是从 train 侧重新同步一次 rollout
+- `saved_model_rollout_step` 应该对应新的 `rollout_step`
+  - 也就是 `_maybe_save_checkpoint(rollout_step)` 时显式 override 写进去的值
+  - 不能退回到 save 瞬间那个旧的 `self._model_rollout_step`
+
+### 13.7 `AgentLoopManager.resume(...)` 的目标状态
+
+`AgentLoopManager.resume(...)` 恢复 sampler / replay buffer 后,推荐把 manager 留在一个“暂停态”:
+
+- `_model_rollout_step = saved_model_rollout_step`
+- `_status = UPDATE_ABORT`
+- `_update_event.set()`
+- `_finish_event.clear()`
+- `_pause_time_s = 0.0`
+
+这样做的原因是:
+
+- resume 完成时 producer 还不应立即继续生成
+- 必须等 trainer 先重新把 train 权重同步到 rollout
+- 然后再由 `continue_product()` 切回 `NORMAL`
+
+### 13.8 eval manager 是否保存
+
+默认不保存 `eval_agent_loop_manager`,保持与当前 colocate trainer 一致的语义。
+
+原因:
+
+- eval 数据流不影响训练正确性
+- eval sampler 从头开始通常可接受
+- 这样能减少 checkpoint 内容和恢复复杂度
+
+如果后续需要“精确恢复 eval 进度”,再单独扩展即可。
+
+---
+
+## 14. 测试建议
+
+建议至少覆盖以下场景。
+
+### 14.1 Producer / Strategy
+
+- `AsyncProduceStrategy.produce_batch()` 能正确返回:
+  - `NORMAL`
+  - `UPDATE_ABORT`
+  - `EXPIRED_BATCH`
+- `self._pending_tasks` 能跨调用保留
+- `pause_product()` 外提后,pending 回收逻辑仍正确
+
+### 14.2 Manager
+
+- 单 task / 多 task 的 `_produce_batch_to_buffer()` 行为一致
+- 多 task 下 `task_batch_sizes` 仍正确分配
+- `get_batch()` 在 `EXPIRED_BATCH` 状态下直接返回空 batch + 状态
+- `pause_product(for_weight_update=True)` 先置 `_update_event`
+- `save()` 前若仍有 pending tasks,会拒绝保存或先 pause
+- `resume()` 后 manager 先处于 `UPDATE_ABORT`,而不是直接 `NORMAL`
+
+### 14.3 Trainer
+
+- `EXPIRED_BATCH` 会跳过训练,直接进入同步
+- eval 步上 continue_product 发生在 eval 之后
+- `FINISH` 时 producer task 能正确退出
+- checkpoint 保存点发生在 pause 之后、continue_product 之前
+- resume 后会先做一次 rollout 权重同步,再启动新的 producer_task
+
+### 14.4 端到端测试
+
+- 对配置示例 examples/v1/config/rl_disagg_multi.py 和 examples/v1/config/rl_disagg_single.py,能跑通基本训练流程且不报错
+- 运行脚本参考 zdev/rl_design_disagg.sh
+
+---
+
+## 15. 当前明确的设计取舍
+
+- `ExpiredBatch` 采用更激进的策略:
+  - 只要当前 rollout model 过旧,就直接停
+  - 不再优先尝试消费旧 completed leftovers
+  - 优先让 rollout 侧尽快完成权重更新
+
+- `produce_loop` 的 batch size 采用显式传参:
+  - `produce_loop(batch_size, start_rollout_step)`
+
+- pause 只由 trainer 的权重同步路径显式触发一次
+  - producer 收到 `UPDATE_ABORT` 后不再自行二次 pause
+
+- rollout 当前只支持 `abort_all`
+  - 不做按 request 的精细化取消
+  - 活跃 pending 统一在 pause 阶段处理
+
+- `fit()` 对外保持同步
+  - 内部通过 async `_fit()` 实现非共卡调度
+
+- checkpoint 只在 producer 已暂停的静止态拍摄
+  - 不保存运行中的 pending rollout task
+  - resume 后始终重新创建 producer_task
+  - rollout 权重始终从 train 侧重新同步
+
+---
+
+## 16. 
对应实现锚点 + +本设计主要落在这些模块: + +- `xtuner/v1/rl/agent_loop/producer.py` + - `AsyncProduceStrategy` + +- `xtuner/v1/rl/agent_loop/agent_loop_manager.py` + - `ProduceBatchResult` + - `AgentLoopManager` + +- `xtuner/v1/rl/agent_loop/single_turn_agent_loop.py` + - `generate_sample(...)` + +- `xtuner/v1/rl/agent_loop/utils.py` + - `PartialRolloutHandler.postprocess(...)` + +- `xtuner/v1/train/rl_disaggregated_trainer.py` + - 新增的非共卡 trainer + +如果需要进一步实现,可以优先按照这个顺序推进: + +1. 先把 `AgentLoopManager` 和 `AsyncProduceStrategy` 的接口拆开 +2. 再补 `RLDisaggregatedTrainer` +3. 最后补测试和配置示例的细化 diff --git a/design/disagg_design_part2.md b/design/disagg_design_part2.md new file mode 100644 index 0000000000..70a2e70cf2 --- /dev/null +++ b/design/disagg_design_part2.md @@ -0,0 +1,962 @@ +# `produce_loop` / `AsyncProduceStrategy` 重设计 + +## 目标 + +在尽量少改现有结构的前提下,解决四件事: + +1. 保留 `AgentLoopManager.produce_loop` 本地的 `future_step`,继续按 future step 逐个预取 batch。 +2. `AsyncProduceStrategy.produce_batch` 的动态控制从“当前 buffer 中有多少 completed”改成“消费者已消费 + buffer fresh + pending”的累计口径,避免消费者取走 batch1 后,生产 batch2 时误补 batch1。 +3. staleness / expired 状态只在两个地方写: + - strategy 在 `replay_buffer.put` 前,根据 `progress.next_consumer_step` 刷新。 + - manager 在 `get_batch` 入口按当前 `rollout_step` 刷新 buffer 中已有 completed,并在成功取出 batch 后推进 `progress.next_consumer_step = rollout_step + 1` 再刷新 leftover completed。 +4. `_pending_tasks` 不再用整体赋值覆盖,改成 snapshot + claim 的增量认领,避免 producer 和 pause 并发 drain 同一 task。 + +## 主要考虑点 + +### 1. `consumed_samples` 和 `consumer_step` 不能只传值 + +Opus 方案里 `produce_batch(..., consumed_samples, consumer_step)` 是进入 strategy 时的一次性快照。 + +这仍然有竞态: + +- producer 进入 `produce_batch` 时,`consumed_samples = 0, fresh = batch1`。 +- producer 正在等待 batch2 rollout 完成。 +- consumer 并发取走 batch1,此时 manager 中 `consumed_samples = batch1, fresh = 0`。 +- 如果 strategy 仍使用旧的 `consumed_samples = 0`,它会误以为 batch1 缺失,继续多补一批。 + +因此 strategy 内每次计算动态控制、以及每次 put 前刷新 staleness 时,都必须读取 live 值。 + +推荐接口不是传一组 getter,而是传一个可变 progress 引用: + +```python +progress: ProduceProgress +``` + +只要 Manager 原地更新这个对象,strategy 每次读取 `progress.next_consumer_step` / `progress.consumed_samples[task_name]` 时拿到的就是最新值。 + +`next_consumer_step` 不是“已经完成训练的最新 step”,而是 producer 在 put 新样本时应该面向的消费 step: + +- `get_batch(i)` 开始时,训练侧正在等待 step `i` 的 batch,因此设置 `next_consumer_step = i`。 +- `get_batch(i)` 成功取出非空 batch 后,训练侧即将消费 step `i`,producer 后续应面向 step `i + 1`,因此返回前设置 `next_consumer_step = i + 1`。 +- `EXPIRED_BATCH` 或 finish 空返回没有成功消费 batch,不推进到 `i + 1`。 + +### 2. over-sample 不应放大全部历史累计目标 + +Opus 方案使用: + +```python +desired_window = ceil((1 + over_sample) * target_cumulative) +``` + +这会把已经消费过的历史目标也一起放大。假设 batch size = `B`,当前在预取 batch10,前 9 个 batch 都已消费,`over_sample=0.5`: + +- 必要目标只缺 batch10 的 `B`。 +- 但上式要求窗口达到 `15B`,等价于对前 9 个已经消费掉的 batch 也重新保留超发窗口。 + +修正为“按当前 task batch size 给本轮 target 增加固定超发预算”: + +```python +available = consumed_abs + fresh +target_abs = progress.target_samples[task_name] +oversample_budget = ceil(over_sample * task_batch_size) +scheduled_target = target_abs + oversample_budget +``` + +返回条件仍然是: + +```python +available >= target_abs +``` + +`pending` 只用于决定还要不要继续发任务: + +```python +scheduled_effective = available + pending_count +if scheduled_effective < scheduled_target: + schedule_more() +``` + +这样 over-sample 只给当前 task batch 一个固定 ahead window,不会让历史累计目标反复膨胀。tail-batch mode 下 `oversample_budget = 0`,本轮新增任务固定从 `Status.EXPIRED` pool 取样,不主动停止已有 pending,也不强制清空 expired pool。 + +### 3. 
`_pending_tasks` 不能只靠循环顶部检查 `_update_event` + +`pause_produce(use_global_progress=True)` 会先 set `_update_event`,但随后会立刻进入各 task strategy 的 `pause_produce`。此时后台 producer 可能还停在 `asyncio.wait(self._pending_tasks, ...)` 中。 + +所以不能只假设 producer 会先返回,也不能只在 `produce_batch()` 循环顶部检查一次 event。需要在 strategy 内保证: + +- 同一个 done task 只能被一方认领。 +- `_pending_tasks` 只能增量 add / discard,不能 `self._pending_tasks = set(pending)` 整体覆盖。 +- `_schedule_one()` 在 pending lock 内检查 `update_event.is_set()`;如果 pause 发生在调度中途,本次已创建的 task 必须先加入 pending,再由 pause drain 收尾。 +- `_schedule_tasks_until()` 返回后还要再次检查 `update_event`,避免 pending 已被 pause drain 清空后误返回 `NORMAL`。 + +## 核心状态 + +### Manager 侧 + +新增一个 manager 持有的可变进度对象: + +```python +@dataclass +class ProduceProgress: + next_consumer_step: int + producer_future_step: int + consumed_samples: dict[str, int] + target_samples: dict[str, int] + target_upto_future_step: int + +self._produce_progress: ProduceProgress +``` + +`ProduceProgress` 可以放在 `producer.py` 或一个小的 shared module 中;`agent_loop_manager.py` 已经依赖 `producer.py`,因此由 manager 构造并传给 strategy 不会引入新的反向依赖。 + +含义: + +- `progress.next_consumer_step`:producer 当前应按哪个训练 step 计算新样本的 staleness。fresh disagg 训练初始化为 `1`;`get_batch(i)` 开始时设为 `i`,成功取出 batch 后设为 `i + 1`。 +- `progress.consumed_samples[task]`:consumer 已经从 buffer 取走并用于训练的 group 数,按 task 绝对累计。 +- `progress.producer_future_step`:producer 当前正在预取的 future step。它属于 manager,不属于 strategy。 +- `progress.target_samples[task]`:截至 `progress.target_upto_future_step`,该 task 应该累计生产出的目标 group 数。 +- `progress.target_upto_future_step`:`target_samples` 已经覆盖到的最大 future step;初始化为 `0`。 + +动态控制使用绝对累计口径: + +```python +available(task) = progress.consumed_samples[task] + fresh_completed(task) +required(task) = max(0, progress.target_samples[task] - available(task)) +``` + +只要 target 和 consumed 都是绝对累计量,就不需要维护 Progress Window,也不需要在 sync 后重置窗口。 + +关键约束: + +- `self._produce_progress` 的对象引用应保持稳定,初始化后不要在运行中整体替换。 +- resume/load 时也优先原地更新字段,而不是 `self._produce_progress = ProduceProgress(...)` 后让 strategy 持有旧引用。 +- `consumed_samples` / `target_samples` 也按 key 原地更新;如果必须整体替换 dict,要保证 strategy 没有缓存旧 dict。 +- `progress` 的写入方应收敛在 Manager / 调用方初始化与消费入口: + - Manager 构造、resume、`_ensure_target_upto()`、`get_batch()` 消费计数负责维护全局 `progress`。 + - Strategy 不在传入的 `progress` 上补 key,也不通过 `setdefault()` 修复缺失状态。 + - 传入 `progress` 时,`consumed_samples[task_name]` 和 `target_samples[task_name]` 必须已经存在;缺失应 fail fast。 +- `progress` 的读取方必须显式读取: + - Strategy 内使用 `progress.consumed_samples[task_name]` / `progress.target_samples[task_name]`。 + - 不使用 `dict.get(task_name, 0)` 这类兜底,避免把初始化或 checkpoint 漂移问题隐藏成“目标为 0 / 已消费为 0”。 + - 除了本轮 `target_abs` / `scheduled_target` 这种语义上需要冻结的调度目标,不把 `progress` 字段先复制到局部标量或局部 dict 再使用,例如不要写 `current_rollout_step = progress.next_consumer_step` 或 `target_by_task = dict(progress.target_samples)`;需要字段值时直接读 `progress.xxx`,让并发更新能尽早生效。 + - `progress = self._produce_progress` 这类对象引用别名可以保留;它不复制字段值。 +- `target_cumulative` 只作为过渡期的一致性校验参数;strategy 不用它构造或修复 `progress`。 +- 所有 strategy 调用都必须显式传入已经初始化好的 `progress`,不再支持 `progress=None` 的本地兜底。 + +### Colocate 路径的 progress 约束 + +`AsyncProduceStrategy` 内部只保留一套语义:`available = consumed + fresh_completed`,并和 `target_samples[task_name]` 比较。区别只在 progress 的来源: + +- 非共卡 `produce_loop()` 使用 Manager 的全局 `_produce_progress`,target/consumed 都是跨 step 的绝对累计值。 +- 共卡 `AgentLoopManager.produce_batch()` 不复用非共卡全局进度窗口;它为本次同步调用构造一个局部 `ProduceProgress`,`next_consumer_step` 等于本次 `rollout_step`,含义是“本次同步调用生产出的 batch 要服务的训练 step”。 +- 共卡取走 batch 后,如需要记录 consumed,也应写入这次调用的局部 `ProduceProgress`,不要污染非共卡全局 
`_produce_progress`。 +- 直接调用 `AsyncProduceStrategy.produce_batch(...)` 也必须传入 `progress`;测试或临时调用如需同步语义,应由调用方构造一次性 local progress。 + +### Strategy 侧 + +`AsyncProduceStrategy` 保留 task 私有的 pending 集合: + +```python +self._pending_tasks: set[asyncio.Task] +self._pending_lock: asyncio.Lock +``` + +`pending_count` 不建议再单独维护成第二份可变状态,直接使用: + +```python +pending_count = len(self._pending_tasks) +``` + +如果实现上为了日志或性能保留 `_pending_count`,也必须只在同一个 helper 中和 `_pending_tasks` 同步更新,不能分散手写。 + +## Cumulative Target + +Manager 为当前 `progress.producer_future_step` 维护每个 task 的绝对累计目标。 + +推荐不要每次从 step 1 重新求和,而是维护一个单调前进的 target 计数器,并 checkpoint: + +```python +def _ensure_target_upto(self, batch_size: int, current_future_step: int) -> None: + progress = self._produce_progress + if current_future_step <= progress.target_upto_future_step: + return + + for fs in range(progress.target_upto_future_step + 1, current_future_step + 1): + if len(self.task_runners) == 1: + progress.target_samples[self.task_runners[0].task_name] += batch_size + else: + sizes = self.get_task_batch_sizes(batch_size, fs) + self._validate_task_batch_sizes(sizes, batch_size) + for task_name, n in sizes.items(): + progress.target_samples[task_name] += n + + progress.target_upto_future_step = current_future_step +``` + +Manager 把该 task 从 step 1 到当前 future step 的绝对累计目标维护在 `progress.target_samples[task_name]` 中;strategy 直接从 `progress` 读取,不通过局部 `target_cumulative` 快照驱动生产。 + +`progress.target_samples` 需要 checkpoint。这样即使后续自定义的 `get_task_batch_sizes` 不是纯函数,也不会在 resume 后因为重算历史分配而漂移。 + +strategy 内部实时计算: + +```python +fresh = await replay_buffer.count(task_name, Status.COMPLETED) +available = progress.consumed_samples[task_name] + fresh +required = max(0, progress.target_samples[task_name] - available) +``` + +这里 `progress` 是 Manager 传入的可变引用;strategy 不缓存 `consumed` 或 `next_consumer_step`,每次循环现场读取。`target_abs` / `scheduled_target` 是本轮 produce_batch 的静态调度决策,进入循环前冻结。 + +## AsyncProduceStrategy 动态控制 + +### 新接口 + +`AsyncProduceStrategy.produce_batch` 的进度入口改为必传 `progress`: + +```python +async def produce_batch( + self, + agent_loop, + sampler, + replay_buffer, + batch_size: int, + task_name: str, + rollout_step: int = 0, # disagg 下传 current_future_step + update_event: asyncio.Event | None = None, + *, + model_rollout_step: int, + progress: ProduceProgress, + target_cumulative: int | None = None, +) -> ProduceBatchStatus: +``` + +入口只做 fail fast 校验,不做缺省初始化: + +```python +# fail fast:调用方必须完整初始化 progress。 +if task_name not in progress.consumed_samples: + raise KeyError(...) +if task_name not in progress.target_samples: + raise KeyError(...) +if target_cumulative is not None and target_cumulative != progress.target_samples[task_name]: + raise ValueError(...) 
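+# 校验只读不写:缺 key 直接抛错,
+# 不用 setdefault() 之类的兜底把初始化缺失隐藏成“目标为 0 / 已消费为 0”。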
+``` + +### 主循环 + +伪代码: + +```python +async def produce_batch(...): + current_future_step = rollout_step + if update_event is None: + update_event = asyncio.Event() + _validate_progress_for_task(progress, task_name, target_cumulative) + if progress.target_samples[task_name] <= 0: + return ProduceBatchStatus.NORMAL + + if update_event.is_set(): + return ProduceBatchStatus.UPDATE_ABORT + if self.is_model_expired(current_future_step, model_rollout_step): + return ProduceBatchStatus.EXPIRED_BATCH + + # 只在进入本轮时回收一次跨调用遗留的 done task,避免 done task 长期留在 pending 集合。 + claimed_done = await self._claim_already_done() + await self._put_claimed_tasks(claimed_done, replay_buffer, task_name, progress) + + if update_event.is_set(): + return ProduceBatchStatus.UPDATE_ABORT + if self.is_model_expired(current_future_step, model_rollout_step): + return ProduceBatchStatus.EXPIRED_BATCH + + expired_count = await replay_buffer.count(task_name=task_name, group_status=Status.EXPIRED) + sample_from_expired = ( + self.tail_batch_trigger_size > 0 + and expired_count >= self.tail_batch_trigger_size + ) + target_abs = progress.target_samples[task_name] + oversample_budget = 0 if sample_from_expired else math.ceil(self.over_sample_threshold * batch_size) + scheduled_target = target_abs + oversample_budget + + while True: + if update_event.is_set(): + return ProduceBatchStatus.UPDATE_ABORT + if self.is_model_expired(current_future_step, model_rollout_step): + return ProduceBatchStatus.EXPIRED_BATCH + + fresh = await replay_buffer.count(task_name=task_name, group_status=Status.COMPLETED) + available = progress.consumed_samples[task_name] + fresh + if available >= target_abs: + return ProduceBatchStatus.NORMAL + + pending_count = await self._pending_count() + desired_pending = max(0, scheduled_target - available) + if available + pending_count < scheduled_target: + await self._schedule_tasks_until( + agent_loop=agent_loop, + sampler=sampler, + task_name=task_name, + desired_pending=desired_pending, + sample_from_expired=sample_from_expired, + model_rollout_step=model_rollout_step, + update_event=update_event, + ) + if update_event.is_set(): + return ProduceBatchStatus.UPDATE_ABORT + + pending_snapshot = await self._snapshot_pending() + if update_event.is_set(): + return ProduceBatchStatus.UPDATE_ABORT + if not pending_snapshot: + # sampler 无数据或当前没有 pending;交回 manager,避免忙等。 + return ProduceBatchStatus.NORMAL + + done, _ = await asyncio.wait( + pending_snapshot, + timeout=1, + return_when=asyncio.FIRST_COMPLETED, + ) + claimed_done = await self._claim_done(done) + await self._put_claimed_tasks(claimed_done, replay_buffer, task_name, progress) +``` + +注意: + +- `available >= target_abs` 才表示当前 future step 的必要目标达成。 +- `pending` 不参与返回条件,只参与“是否要继续调度”的判断。 +- `sample_from_expired` 和 `scheduled_target` 是本轮静态决策,放在循环前;循环中只更新 live `available` / `pending_count`。 +- tail-batch mode 下 `scheduled_target == target_abs`,本轮新增任务只从 `Status.EXPIRED` pool 取样,不主动停止已有 pending。 +- 失败、filtered、aborted 的 group 会被 put,但不会增加 `fresh`,下一轮自然会补发。 +- 如果 consumer 在本循环中间取走了 batch,下一次读取 `progress.consumed_samples[task_name]` 会看到新值,不会误补已消费的部分。 + +## staleness / expired 写入策略 + +只保留两个写入点。 + +### 写入点 1:strategy put 前 + +新增一个集中 helper,替代 scattered 的 `update_expired_status` 调用: + +```python +async def _put_generated_group( + self, + items: list[RolloutState], + replay_buffer: ReplayBuffer, + task_name: str, + current_rollout_step: int, +) -> None: + refresh_seq_staleness(items, current_rollout_step) + expire_group_if_needed(items, self.tail_batch_stale_threshold) + 
await replay_buffer.put(items, task_name) +``` + +`expire_group_if_needed` 的语义应覆盖 completed 和 aborted: + +```python +def expire_group_if_needed(group: list[RolloutState], threshold: int) -> list[RolloutState]: + if threshold <= 0: + return group + group_status = update_group_status(group) + if group_status not in (Status.COMPLETED, Status.ABORTED): + return group + if any(sample.seq_staleness >= threshold for sample in group): + for sample in group: + sample.status = Status.EXPIRED + return group +``` + +为什么不直接复用当前 `update_expired_status`: + +- 当前实现只检查 `sample.status == Status.ABORTED` 的样本。 +- 本次需求要求 buffer 中 completed 样本也会因 train_step 推进而过期。 +- 因此需要一个对 completed / aborted group 都生效的 group-level 过期 helper,或扩展 `update_expired_status` 的语义。 + +### 写入点 2:`AgentLoopManager.get_batch` + +`get_batch(i)` 在等待当前 step batch 前,先把 producer 的 staleness 基准切到 `i`,并刷新 buffer 中已有 completed / aborted;成功取出非空 batch 后,再把基准推进到 `i + 1` 并刷新 leftover completed / aborted: + +```python +async def get_batch(self, batch_size: int, rollout_step: int) -> ProduceBatchResult: + progress = self._produce_progress + progress.next_consumer_step = rollout_step + + for task in self.task_runners: + threshold = getattr(task.produce_strategy, "tail_batch_stale_threshold", 0) + await self.replay_buffer.refresh_staleness( + task_name=task.task_name, + current_rollout_step=rollout_step, + tail_batch_stale_threshold=threshold, + ) + + while not self._finish_event.is_set(): + ... + if ready: + result = await self._get_batch_from_buffer(..., consume_progress=progress) + if result.rollout_states: + progress.next_consumer_step = rollout_step + 1 + await self._refresh_staleness_for_all_tasks(rollout_step + 1) + return result +``` + +`_get_single_task_batch_from_buffer` 中对返回 batch 调 `refresh_seq_staleness(group, rollout_step)` 可以保留;那只是刷新即将交给训练侧的数据对象,不再写回 buffer,不算第三个 buffer 状态写入点。 + +这里接受 eventual consistency:`progress.next_consumer_step = i + 1` 与 `refresh_staleness(i + 1)` 不是和 producer `count(COMPLETED)` 共享的全局事务。极短窗口内 producer 可能看到已经推进到 `i + 1` 的 progress,同时 buffer 中还有尚未刷新为 expired 的 completed / aborted leftover,并因此短暂低估缺口。后续的 `refresh_staleness` 和下一轮 produce 会修正 fresh count;当前方案接受这种最终一致性,不为它引入跨 progress / replay buffer 的全局锁。 + +## ReplayBuffer 改动 + +新增: + +```python +async def refresh_staleness( + self, + task_name: str, + current_rollout_step: int, + tail_batch_stale_threshold: int, +) -> int: + ... +``` + +语义: + +- 在 `ReplayBuffer._lock` 下查询该 task 的 `Status.COMPLETED` / `Status.ABORTED` groups。 +- 对每个 group 调 `refresh_seq_staleness(group, current_rollout_step)`。 +- 更新 `StorageItem.staleness = max(sample.seq_staleness for sample in group)`。 +- 如果 `tail_batch_stale_threshold > 0` 且 group 中任意样本 `seq_staleness >= threshold`,则整组样本置 `Status.EXPIRED`,`StorageItem.status = Status.EXPIRED`。 +- 返回本次新翻转为 expired 的 group 数。 + +为避免 destructive `get -> mutate -> put`,storage 增加最小 update 能力: + +```python +class StorageBackend: + async def update(self, items: list[StorageItem]) -> None: ... 
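+    # 契约:按 uid 原地覆盖已有条目,不新增、不重排,
+    # 保证刷新 staleness 不改变 FIFO / staleness policy 的时间顺序。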
+``` + +实现: + +- `NaiveStorage.update`:按 `uid` 覆盖 `_items[uid]`,保留原 `timestamp_id`。 +- `PandasStorage.update`:flush 后按 `uid` 更新 `status / staleness / item` 列。 + +这样刷新 completed / aborted staleness 不会和 consumer 抢样本,也不会改变 FIFO / staleness policy 的时间顺序。 + +## `_pending_tasks` 并发控制 + +### helper + +strategy 内新增小锁保护的 helper: + +```python +async def _snapshot_pending(self) -> set[asyncio.Task]: + async with self._pending_lock: + return set(self._pending_tasks) + +async def _pending_count(self) -> int: + async with self._pending_lock: + return len(self._pending_tasks) + +async def _claim_done(self, done: set[asyncio.Task]) -> set[asyncio.Task]: + async with self._pending_lock: + claimed = done & self._pending_tasks + self._pending_tasks.difference_update(claimed) + return claimed + +async def _claim_already_done(self) -> set[asyncio.Task]: + async with self._pending_lock: + done = {task for task in self._pending_tasks if task.done()} + self._pending_tasks.difference_update(done) + return done +``` + +所有 done task 都必须先 claim,再 `task.result()` / `replay_buffer.put`。 + +### schedule + +`_schedule_tasks_until` 不再直接裸写 set。为避免 pause 正在开始时出现“采样后未纳入 pending”的缝隙,把“检查 `update_event` + sample + create task + add pending”包在同一个 pending lock 中。 + +这个 lock 不覆盖真正的 rollout generate,只覆盖一次轻量采样和 task 创建: + +```python +async def _schedule_one(..., update_event: asyncio.Event | None): + async with self._pending_lock: + if update_event is not None and update_event.is_set(): + return False + if len(self._pending_tasks) >= desired_pending: + return False + + group_status = Status.EXPIRED if sample_from_expired else Status.ABORTED + rollout_state = await sampler.sample(task_name=task_name, group_status=group_status) + task = create_task( + _timed_generate_group( + agent_loop, + rollout_state, + enable_partial_rollout=self.enable_partial_rollout, + ) + ) + self._pending_tasks.add(task) + # 绑定 task 发起时的模型版本。partial rollout 可能已有更早版本的 prefix; + # put 前只为本次新增 token 补这个版本,不能用 task 完成时 manager 的最新版本。 + self._pending_task_model_steps[task] = model_rollout_step + return True +``` + +如果不希望在 lock 内 `await sampler.sample(...)`,也可以引入 `_scheduling_count` 防止 pause 在“采样中但尚未 add pending”时提前返回;但这会增加状态。按“尽量少改且易维护”的约束,短锁方案更直接。 + +`update_event` 是 manager 级暂停信号。`AsyncProduceStrategy` 不再维护 `_pausing` 作为第二套暂停状态;pause drain 只依赖 pending snapshot / claim helper。 + +`_pending_task_model_steps` 是必要状态:它记录每个 pending rollout task 发起时使用的 `model_rollout_step`。partial rollout 样本的 `response_rollout_steps` 可能已经包含旧 prefix 的更早模型版本;task 完成后只为新增 token 追加该 task 发起时的版本,最终 staleness 仍按 `min(response_rollout_steps)` 计算。这个版本不能从 task 完成时的 manager `_model_rollout_step` 读取。 + +### pause + +`ProduceStrategy.pause_produce` 使用统一接口;Manager 对 sync / async strategy 使用同一调用形态: + +```python +async def pause_produce( + self, + agent_loop, + replay_buffer, + task_name: str, + *, + progress: ProduceProgress, +) -> float: + ... 
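+    # sync strategy 的默认实现忽略 progress 并直接返回 0.0;
+    # AsyncProduceStrategy 的异步收尾实现见下文。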
+``` + +Manager 侧按 `use_global_progress` 选择 progress: + +- `use_global_progress=True`:非共卡后台 `produce_loop` 在权重同步点前暂停,传全局 `_produce_progress`,因为后台 producer / trainer consumer 共享同一窗口。 +- `use_global_progress=False`:共卡同步 `produce_batch()` 的本次调用收尾,传本次调用的局部 progress。 +- Sync strategy 的默认实现忽略 `progress` 并返回 `0.0`,因此无需在 Manager 侧按子类分支。 + +`AsyncProduceStrategy.pause_produce` 的收尾逻辑为: + +```python +async def pause_produce(..., *, progress: ProduceProgress) -> float: + pause_start = time.perf_counter() + + if await self._pending_count() == 0: + return 0.0 + + rollout_ctl = await get_agent_loop_rollout_ctl(agent_loop) + await pause_generation(rollout_ctl) + + while True: + pending_snapshot = await self._snapshot_pending() + if not pending_snapshot: + break + + done, _ = await asyncio.wait( + pending_snapshot, + timeout=1, + return_when=asyncio.FIRST_COMPLETED, + ) + claimed_done = await self._claim_done(done) + for task in claimed_done: + task_model_rollout_step = self._pending_task_model_steps.pop(task) + await self._put_generated_group( + task.result(), + replay_buffer, + task_name, + current_rollout_step=progress.next_consumer_step, + model_rollout_step=task_model_rollout_step, + ) + + if await self._pending_count() > 0: + await pause_generation(rollout_ctl) + await asyncio.sleep(1) + + return time.perf_counter() - pause_start +``` + +关键点: + +- producer 和 pause 可以同时 `asyncio.wait` 同一份 snapshot,但只有先 `_claim_done` 的一方会处理结果。 +- 另一方拿到的 done task 因为已经不在 `_pending_tasks` 里,`claimed_done` 为空,不会重复 put。 +- 不再出现 `self._pending_tasks = set(pending_tasks)` 覆盖新 task 或复活旧 task。 + +## Manager 生产流程 + +### `produce_loop` + +`produce_loop` 默认从 `progress.producer_future_step` 继续生产,不再接受 `start_rollout_step` 覆盖入口。测试 / resume 如需指定起点,应直接恢复或设置 `progress.producer_future_step`,保证生产 step 只有一个状态来源。manager 初始化时把生产进度放到绝对坐标系原点: + +```python +self._produce_progress = ProduceProgress( + next_consumer_step=1, + producer_future_step=1, + consumed_samples={task.task_name: 0 for task in self.task_runners}, + target_samples={task.task_name: 0 for task in self.task_runners}, + target_upto_future_step=0, +) +``` + +resume 时从 checkpoint 恢复这些状态;trainer 不再把 `self._cur_step` 传进 `produce_loop`。 + +```python +async def produce_loop(self, batch_size: int): + while not self._finish_event.is_set(): + if self._status == AgentLoopManagerStatus.FINISH: + break + if self._status == AgentLoopManagerStatus.UPDATE_ABORT: + await self._wait_for_status_exit(AgentLoopManagerStatus.UPDATE_ABORT) + continue + if self._status == AgentLoopManagerStatus.EXPIRED_BATCH: + await self._wait_for_status_exit(AgentLoopManagerStatus.EXPIRED_BATCH) + continue + + rollout_ctl = await get_agent_loop_rollout_ctl(self.task_runners[0].agent_loop) + await continue_generation(rollout_ctl) + + status = await self._produce_batch_to_buffer( + batch_size=batch_size, + progress=self._produce_progress, + ) + + if status == ProduceBatchStatus.NORMAL: + self._produce_progress.producer_future_step += 1 + elif status == ProduceBatchStatus.EXPIRED_BATCH: + self._status = AgentLoopManagerStatus.EXPIRED_BATCH + + await asyncio.sleep(0) +``` + +`RLDisaggregatedTrainer._fit` 对应改成: + +```python +producer_task = create_task( + self.agent_loop_manager.produce_loop(batch_size=self.train_batch_size) +) +``` + +恢复训练时,由 `agent_loop_manager.resume(...)` 原地恢复 `self._produce_progress` 的各字段,不再依赖 trainer 传入 `_cur_step`。 + +### 多 task ExpiredBatch 提前停止 + +多 task 下,如果任一 task 在当前 `future_step` 上已经整体过期,其他 task 不应继续生产。 + +因为整体过期只依赖: + +```python +current_future_step +model_rollout_step 
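+# 以及各 task 自己的过期阈值: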
+task.produce_strategy.tail_batch_stale_threshold +``` + +所以 manager 可以在 `asyncio.gather` 前统一预检查: + +实现时把当前 strategy 内部的 `_is_model_expired` 提升为 public wrapper,例如 `is_model_expired`,供 manager 做这个预检查。 + +```python +expired_tasks = [ + task.task_name + for task in self.task_runners + if isinstance(task.produce_strategy, AsyncProduceStrategy) + and task.produce_strategy.is_model_expired( + current_future_step, + self._model_rollout_step, + ) +] +if expired_tasks: + self.logger.info(f"Expired future_step={current_future_step}, tasks={expired_tasks}") + return ProduceBatchStatus.EXPIRED_BATCH +``` + +strategy 内仍保留同样检查,作为单 task / 兼容路径的兜底。 + +### `_produce_batch_to_buffer` + +伪代码: + +```python +async def _produce_batch_to_buffer( + self, + batch_size: int, + progress: ProduceProgress, + *, + task_batch_sizes: dict[str, int] | None = None, +): + current_future_step = progress.producer_future_step + if progress is self._produce_progress: + # 只有后台生产循环使用全局 progress,需要在这里推进累计 target; + # colocate 路径传入的是一次性本地 progress,不能污染全局计数。 + self._ensure_target_upto(batch_size, current_future_step) + + # 当前 step 的 task batch sizes 用于本轮 over-sample 预算;active task 由 progress target 决定。 + current_sizes = ( + self._get_task_batch_sizes_for_step(batch_size, current_future_step) + if task_batch_sizes is None + else task_batch_sizes + ) + self._validate_task_batch_sizes(current_sizes, batch_size) + + if self._any_task_model_expired(current_future_step): + return ProduceBatchStatus.EXPIRED_BATCH + + async def run_task(task): + name = task.task_name + return await task.produce_strategy.produce_batch( + task.agent_loop, + task.sampler, + self.replay_buffer, + current_sizes[name], + name, + rollout_step=current_future_step, + model_rollout_step=self._model_rollout_step, + update_event=self._update_event, + progress=progress, + ) + + # 注意:即使 current_sizes[name] == 0,该 task 也可能需要补之前因 expired/failed 造成的缺口。 + tasks_to_run = [ + task + for task in self.task_runners + if progress.target_samples[task.task_name] > 0 + ] + statuses = await asyncio.gather(*(run_task(task) for task in tasks_to_run)) + return _aggregate_status(statuses) +``` + +## Manager 消费流程 + +### `get_batch` + +`get_batch` 做三件工作: + +1. 函数开始时设置 `progress.next_consumer_step = rollout_step`,表示当前正在等待 step `rollout_step` 的训练 batch。 +2. 入口刷新一次 buffer 中 completed / aborted 样本的 staleness / expired,避免直接消费进入函数时已经过期的 completed。 +3. 
成功取出 batch 后,按实际返回数量更新 `progress.consumed_samples`,再设置 `progress.next_consumer_step = rollout_step + 1` 并按下一 step 刷新 leftover completed / aborted。 + +这里接受 eventual consistency:`refresh_staleness`、producer 的 fresh count 和 consumer 的 get 不是全局事务。为了让逻辑更简单,`get_batch` 不在等待循环里反复 refresh;如果某个 completed / aborted 在等待期间才变 stale,它最多会在本次入口 refresh 与成功消费后 refresh 之间存在一个短暂窗口,下一次入口 / producer 计数 / 成功消费后的 refresh 会收敛状态。 + +伪代码: + +```python +async def get_batch(self, batch_size: int, rollout_step: int) -> ProduceBatchResult: + progress = self._produce_progress + progress.next_consumer_step = rollout_step + await self._refresh_staleness_for_all_tasks(rollout_step) + + while not self._finish_event.is_set(): + if self._status == AgentLoopManagerStatus.EXPIRED_BATCH: + return ProduceBatchResult( + rollout_states=[], + status=ProduceBatchStatus.EXPIRED_BATCH, + ) + + if await self._is_batch_ready(batch_size, rollout_step): + result = await self._get_batch_from_buffer( + batch_size, + rollout_step, + consume_progress=progress, + ) + if result.rollout_states: + progress.next_consumer_step = rollout_step + 1 + await self._refresh_staleness_for_all_tasks(rollout_step + 1) + return result + + await asyncio.sleep(self._STATUS_POLL_INTERVAL_S) + + return ProduceBatchResult(rollout_states=[]) +``` + +`_get_batch_from_buffer(..., consume_progress=progress)` 应按实际结果计数: + +```python +consume_progress.consumed_samples[task_runner.task_name] += len(batch_rollout_states) +``` + +不要只按 `task_batch_sizes` 加,因为实际返回结果才是 buffer 被消费的权威事实。 + +### `pause_produce` + +manager 侧用 `use_global_progress` 区分使用全局 progress 还是本次调用的局部 progress: + +```python +async def pause_produce( + *, + use_global_progress: bool, + progress: ProduceProgress | None = None, +) -> float: + ... +``` + +`pause_produce` 入口先校验参数并选择 progress,再置 event / status: + +```python +if use_global_progress: + pause_progress = self._produce_progress + +self._update_event.set() +self._status = AgentLoopManagerStatus.UPDATE_ABORT +``` + +`use_global_progress=True` 是 sticky pause:状态保持到 trainer 完成权重同步 / 评测后调用 `continue_produce()`。 + +`use_global_progress=False` 用于共卡 `produce_batch()` 的显式收尾,必须传入本次调用的局部 progress;它也会 set `_update_event` / `UPDATE_ABORT`,由下一次 `produce_batch()` 入口的 `continue_produce()` 清理: + +```python +else: + if progress is None: + raise ValueError(...) 
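+    # 共卡收尾必须由调用方显式构造并传入本次调用的局部 progress,这里不做兜底。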
+ pause_progress = progress +``` + +随后调用 strategy 时传同一个 live progress 引用: + +```python +pause_time_s += await strategy.pause_produce( + task.agent_loop, + self.replay_buffer, + task.task_name, + progress=pause_progress, +) +``` + +## 不再需要 Progress Window + +本版使用绝对累计 target 和绝对累计 consumed: + +```python +available_abs = consumed_abs + fresh +required = max(0, target_abs - available_abs) +``` + +因此不需要 `_target_base_step` / `_target_base_consumed`,也不需要在 resume / sync 后重置生产窗口。 + +权重同步后的 `continue_produce(model_rollout_step=...)` 只负责恢复状态机和更新 rollout 侧模型版本: + +```python +def continue_produce(self, model_rollout_step: int) -> None: + self._status = AgentLoopManagerStatus.NORMAL + self._model_rollout_step = model_rollout_step + self._update_event.clear() +``` + +`progress.producer_future_step`、`progress.target_samples`、`progress.target_upto_future_step` 和 `progress.consumed_samples` 都是训练全局绝对进度,不随 sync 重置。 + +## Checkpoint + +保存 manager state 时追加: + +```json +{ + "next_consumer_step": 1, + "consumed_samples": {"task": 0}, + "producer_future_step": 1, + "target_samples": {"task": 0}, + "target_upto_future_step": 0 +} +``` + +恢复时: + +- 读回上述状态后,原地写回 `self._produce_progress` 的字段;这些字段都是新 checkpoint 的必需字段,缺失时直接 fail fast。 +- `_pending_tasks` 不保存;保存前仍要求 pending 为空。 +- 每个 strategy 初始化 `_pending_tasks = set()`;不再保存或恢复 strategy-local pause flag。 +- 保持现有 `resume()` 进入 `UPDATE_ABORT` 且 `_update_event.set()` 的语义,让 trainer 显式 `continue_produce` 后再恢复生产。 + +## 共卡路径 + +`AgentLoopManager.produce_batch` 继续作为 colocate 的同步入口,不改变外部契约。 + +实现方式: + +- 共卡路径不使用 disagg 的全局绝对 target 状态,避免污染后台 producer 的累计进度。 +- 共卡路径在调用 `_produce_batch_to_buffer(..., progress=local_progress)` 前构造本次调用的局部 `ProduceProgress`: + - `next_consumer_step = rollout_step` + - `producer_future_step = rollout_step` + - `consumed_samples = {task_name: 0}` + - `target_samples = current_task_batch_sizes` + +共卡路径每次生产后仍调用 `pause_produce(use_global_progress=False, progress=local_progress)` 收尾 pending,然后 `_get_batch_from_buffer` 返回训练 batch。 + +共卡模式约束:同一个 `AgentLoopManager` 实例只用一种数据提供模式。`SYNC_PRODUCE_BATCH` 收尾会让 manager 保持 `UPDATE_ABORT` / `_update_event.set()`,下一次 `produce_batch()` 入口先调用 `continue_produce(model_rollout_step=rollout_step - 1)` 恢复。不要在两次 sync `produce_batch()` 之间混用 `produce_loop()` / `get_batch()`。 + +## 删除 / 收敛的旧逻辑 + +删除或停止使用: + +- `AsyncProduceStrategy._process_leftover_samples` + - 它 destructive `get -> mutate -> put`,会和 consumer 抢 completed。 + - completed / aborted staleness 刷新统一交给 `ReplayBuffer.refresh_staleness`。 +- 所有 `self._pending_tasks = set(pending_tasks)`。 +- strategy 中基于 `previously_completed_count = replay_buffer.count(COMPLETED)` 的局部 batch 判断。 +- strategy 内 `progress=None` 时构造 local progress 的兜底逻辑。 +- `AsyncProduceStrategy._current_rollout_step` 以及 `pause_produce` 中基于它的 fallback。 +- `AsyncProduceStrategy.produce_batch(..., model_rollout_step=None)` 的 fallback;调用方必须显式传入合法 `model_rollout_step`。 +- `AgentLoopManager.produce_loop(start_rollout_step=...)` 覆盖入口;producer 起点只来自 `progress.producer_future_step`。 +- `_produce_batch_to_buffer(..., rollout_step=...)` / `current_future_step=...` / `use_global_progress` / `progress_override` 这些多入口参数;内部统一使用必传的 `progress.producer_future_step`。 +- `_refresh_staleness_for_all_tasks` 中判断 replay buffer 是否存在刷新接口的 fallback;`ReplayBuffer.refresh_staleness` 是固定依赖,缺失应 fail fast。 +- `get_batch` while 循环内的重复 completed refresh。 +- `_refresh_leftover_counts` 这类只为日志字段再次 recount 的逻辑。 +- resume 时读取 `latest_consumer_step` 或用 `manager_state.get(..., default)` 隐藏字段缺失的兼容逻辑。 + +收敛到: + +- `ReplayBuffer.refresh_staleness` 负责 
buffer 中 completed / aborted 的 in-place 刷新。 +- strategy `_put_generated_group` 负责新生成 / pause drain 结果 put 前刷新。 +- strategy `_claim_done` 负责 pending task 的唯一认领。 + +## 正确性小结 + +### 消费者取走 batch 不会导致 producer 误补 + +生产 batch2 时: + +```python +available_abs = consumed_abs + fresh +``` + +consumer 取走 batch1 后: + +- `fresh` 减少 `B` +- `consumed_abs` 增加 `B` + +所以 `available_abs` 不变,producer 不会把已消费的 batch1 当成缺口。 + +### completed 样本过期会触发补发 + +`get_batch` 在入口按当前 step 刷新 completed,成功取出 batch 后也会按下一 step 刷新 leftover: + +- 过期样本从 `COMPLETED` 翻成 `EXPIRED` +- `fresh` 下降 +- 下一轮 strategy 动态控制看到 `required > 0` +- 自动补发 + +这不是强事务保证:如果刷新后、消费前又有新 completed 变 stale,当前实现允许它短暂保持 completed。这个窗口通过后续入口 refresh 或成功消费后的下一 step refresh 收敛,换取更少的重复 count / refresh 和更简单的状态维护。 + +partial rollout 的 staleness 使用 `min(response_rollout_steps)`,仍由 `refresh_seq_staleness` 统一计算。 + +### 多 task 某 task 整 batch 过期时,全局尽早停 + +manager 在 gather 前检查所有 task 的 `is_model_expired(current_future_step, model_rollout_step)`。 + +只要一个 task expired: + +- 当前 `_produce_batch_to_buffer` 直接返回 `EXPIRED_BATCH` +- 其他 task 不再新发 rollout +- `produce_loop` 设置 manager status 为 `EXPIRED_BATCH` +- consumer 的 `get_batch` 返回空 batch + `EXPIRED_BATCH`,trainer 优先同步权重 + +### `_pending_tasks` 不重复 put、不丢 task + +两边可以同时 wait snapshot,但 done task 必须先 claim: + +```python +claimed = done & self._pending_tasks +self._pending_tasks.difference_update(claimed) +``` + +同一 task 只有一个协程能 claim 成功,因此不会重复 `task.result()` / `replay_buffer.put`。 + +新增 task 只通过 helper add,不再用整体赋值覆盖集合,因此不会抹掉新 task 或复活已完成 task。 + +## 建议测试 + +1. 单 task:producer 已完成 batch1,consumer 取走 batch1,producer 生产 batch2 时只补 batch2,不额外补 batch1。 +2. 单 task:producer 进入 `produce_batch` 后 consumer 中途取走 batch,strategy 通过 live `progress.consumed_samples` 不误补。 +3. completed stale:buffer 里已有 completed partial rollout,`get_batch(rollout_step)` 后超过 threshold 的 group in-place 变成 expired。 +4. put 前 stale:新生成 group 在 put 前按最新 `progress.next_consumer_step` 刷新;如果已经 stale,直接以 expired 入 buffer。 +5. 多 task:任一 task 在当前 future step 上 `EXPIRED_BATCH`,其他 task 不再 schedule。 +6. pending race:让 `produce_batch` 和 `pause_produce` 同时 wait 同一个 pending task,确认 replay_buffer 只 put 一次。 +7. checkpoint:保存 / 恢复后 `progress.producer_future_step`、`progress.target_samples`、`progress.target_upto_future_step` 和 `progress.consumed_samples` 不回退,buffer leftovers 仍可被后续 train step 消费。 +8. fixed over-sample budget:当前只缺 1 个样本时,`over_sample=1, task_batch_size=4` 应调度到 `target_abs + 4`,而不是按缺口只调度到 `available + 2`。 +9. 
tail-batch static mode:进入 tail-batch mode 后,本轮新增任务只从 `Status.EXPIRED` pool 取样,且 `scheduled_target == target_abs` 不超发。 diff --git a/docs/en/api/rl_trainer.rst b/docs/en/api/rl_trainer.rst index 906a6f8a5d..31bef4197e 100644 --- a/docs/en/api/rl_trainer.rst +++ b/docs/en/api/rl_trainer.rst @@ -7,4 +7,7 @@ RL Trainer :toctree: generated :nosignatures: - train.rl_trainer.RLTrainer \ No newline at end of file + train.rl_trainer.RLColocateTrainer + train.rl_trainer.RLColocateTrainerConfig + train.rl_trainer.RLDisaggregatedTrainer + train.rl_trainer.RLDisaggregatedTrainerConfig diff --git a/docs/en/rl/advanced_tutorial/loss.md b/docs/en/rl/advanced_tutorial/loss.md index 130517d2e3..c8f54df849 100644 --- a/docs/en/rl/advanced_tutorial/loss.md +++ b/docs/en/rl/advanced_tutorial/loss.md @@ -9,8 +9,7 @@ All loss calculations in XTuner involve two core components: `LossConfig` and `L ```python import torch import torch.nn as nn -from xtuner.v1.rl.grpo import GRPOLossConfig, GRPOLossContext -from xtuner.v1.rl.base import RLLossContextInputItem +from xtuner.v1.rl.loss import GRPOLossConfig, GRPOLossContext, RLLossContextInputItem from xtuner.v1.data_proto import SequenceContext def gather_logprobs(logits, shifted_labels): diff --git a/docs/en/rl/tutorial/rl_grpo_trainer.md b/docs/en/rl/tutorial/rl_grpo_trainer.md index 95f0efd5a5..fa37a424e7 100644 --- a/docs/en/rl/tutorial/rl_grpo_trainer.md +++ b/docs/en/rl/tutorial/rl_grpo_trainer.md @@ -91,7 +91,7 @@ If you need more fine-grained control (such as distributed inference, inference ```{code-block} python :caption: Configure Inference Environment -from xtuner.v1.ray.config.worker import RolloutConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig model_path = "/path/to/qwen3-8B" # Replace with your model path @@ -143,8 +143,8 @@ For more configuration parameters, please refer to the API documentation: {class :caption: Configure Training Strategy from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.loss import GRPOLossConfig model_path = "/path/to/qwen3-8B" # Fill in your model path train_optimizer_steps = 4 # Training optimization steps diff --git a/docs/zh_cn/api/rl_trainer.rst b/docs/zh_cn/api/rl_trainer.rst index 906a6f8a5d..31bef4197e 100644 --- a/docs/zh_cn/api/rl_trainer.rst +++ b/docs/zh_cn/api/rl_trainer.rst @@ -7,4 +7,7 @@ RL Trainer :toctree: generated :nosignatures: - train.rl_trainer.RLTrainer \ No newline at end of file + train.rl_trainer.RLColocateTrainer + train.rl_trainer.RLColocateTrainerConfig + train.rl_trainer.RLDisaggregatedTrainer + train.rl_trainer.RLDisaggregatedTrainerConfig diff --git a/docs/zh_cn/rl/advanced_tutorial/gateway_api_debug.md b/docs/zh_cn/rl/advanced_tutorial/gateway_api_debug.md new file mode 100644 index 0000000000..06b23c6ab2 --- /dev/null +++ b/docs/zh_cn/rl/advanced_tutorial/gateway_api_debug.md @@ -0,0 +1,207 @@ +# Gateway 兼容接口联调 + +本文记录如何使用真实的 Agent 客户端和 OpenAI SDK 联调 XTuner Gateway,验证 Gateway 对 Anthropic Messages、OpenAI Responses 和 OpenAI Chat Completions 接口的兼容情况。 + +## 适用场景 + +当你修改 Gateway、Rollout Controller、Agent Loop 或协议适配层后,可以按本文流程做一次端到端验证,确认: + +- Gateway 能够接收 `/v1/messages`、`/v1/responses` 和 `/v1/chat/completions` 请求。 +- Claude Code、Codex 等真实 Agent 客户端能够连接到本地 Gateway。 +- 普通对话和工具调用链路都能正常返回。 +- Gateway 的请求捕获日志能够记录调试所需的协议转换信息。 + +## 前置条件 + +1. 
已安装 XTuner 运行环境,并能启动 Rollout Controller 和 Gateway。
+2. Gateway 服务默认监听 `http://127.0.0.1:8091`。
+3. Gateway 模型名配置为 `local-test`。
+4. 鉴权 token 使用本地调试值 `dummy`。
+5. 启动 Gateway 时建议打开 `capture_folder`,便于回看请求、协议适配结果和模型输出。
+
+```{note}
+真实 Agent 客户端会携带较长的系统提示词和工具定义。联调 Claude Code 时建议将上下文长度设置到 32K;联调 Codex 时建议至少设置到 16K。
+```
+
+## 启动 Gateway
+
+先启动 Rollout Controller 和 Gateway。以下命令是本地调试脚本示例:
+
+```bash
+python .dev_scripts/debug_gateway.py \
+    --model-path /path/to/model \
+    --model-name local-test \
+    --context-length 32768
+```
+
+启动时需要确认:
+
+- Gateway 端口为 `8091`。
+- 模型名为 `local-test`。
+- 上下文长度满足当前客户端需求。
+- 已配置 `capture_folder`。
+
+## 验证 Anthropic Messages 接口
+
+Claude Code 通过 Anthropic Messages API 访问 Gateway,可用于验证 `/v1/messages` 的协议适配和工具调用链路。
+
+### 安装 Claude Code
+
+```bash
+curl -fsSL https://claude.ai/install.sh | bash
+```
+
+### 配置环境变量
+
+```bash
+export ANTHROPIC_BASE_URL=http://127.0.0.1:8091
+export ANTHROPIC_AUTH_TOKEN=dummy
+export ANTHROPIC_MODEL=local-test
+export API_TIMEOUT_MS=600000
+```
+
+### 验证普通对话
+
+启动 Claude Code 后发送:
+
+```text
+Reply with exactly: OK
+```
+
+如果客户端能够收到模型回复,说明 `/v1/messages` 的基础请求链路可用。
+
+### 验证工具调用
+
+继续发送以下 prompt:
+
+```text
+Use your tools to find the gateway route definitions, then add a single log line for every incoming request to /v1/messages. Show me the exact file you changed and the patch you would apply.
+```
+
+如果 Claude Code 能够正常调用工具、读取仓库文件,并返回拟修改的文件和 patch,说明工具调用链路可用。
+
+## 验证 OpenAI Responses 接口
+
+Codex 通过 OpenAI Responses API 访问 Gateway,可用于验证 `/v1/responses` 的协议适配和工具调用链路。
+
+### 安装 Codex
+
+按 Codex 官方安装方式完成安装后,配置本地模型提供方。
+
+### 配置 Codex
+
+在 Codex 的 `config.toml` 中添加本地 Gateway provider:
+
+```toml
+model = "local-test"
+model_provider = "xtuner"
+
+[model_providers.xtuner]
+name = "xtuner gateway"
+base_url = "http://127.0.0.1:8091/v1"
+env_key = "XTUNER_GATEWAY_KEY"
+```
+
+配置访问 token:
+
+```bash
+export XTUNER_GATEWAY_KEY=dummy
+```
+
+### 先用 curl 验证接口
+
+启动 Codex 前,先确认 `/v1/responses` 能直接返回:
+
+```bash
+curl http://127.0.0.1:8091/v1/responses \
+  -H 'content-type: application/json' \
+  -H 'authorization: Bearer dummy' \
+  -d '{
+    "model": "local-test",
+    "input": "Reply with exactly OK"
+  }'
+```
+
+如果返回状态为 `completed`,且 `output` 中包含模型回复,说明 Responses 接口基础链路可用。
+
+### 验证普通对话
+
+启动 Codex 后发送:
+
+```text
+你好
+```
+
+如果 Codex 能收到中文回复,说明客户端能够通过本地 Gateway 完成基础对话。
+
+### 验证工具调用
+
+继续发送以下 prompt:
+
+```text
+Use your tools to list the top-level files and directories in the current repository.
+Do not explain your plan.
+Do not answer from memory.
+If you cannot access tools, reply exactly: NO_TOOLS
+```
+
+如果 Codex 返回了仓库顶层文件和目录,而不是 `NO_TOOLS`,说明 Responses 接口下的工具调用链路可用。
+
+## 验证 OpenAI Chat Completions 接口
+
+除了真实 Agent 客户端,也可以使用 OpenAI Python SDK 验证 `/v1/chat/completions`。
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    base_url="http://127.0.0.1:8091/v1",
+    api_key="dummy",
+)
+
+resp = client.chat.completions.create(
+    model="local-test",
+    messages=[
+        {"role": "system", "content": "You are a helpful coding assistant."},
+        {"role": "user", "content": "Reply with exactly: OK"},
+    ],
+    max_tokens=32,
+    temperature=0,
+)
+
+print(resp.choices[0].message.content)
+```
+
+如果输出包含 `OK`,说明 Chat Completions 接口基础链路可用。
+
+## 检查 capture 日志
+
+联调过程中建议同步检查 Gateway 的 `capture_folder` 输出。重点确认每条记录中是否包含:
+
+- `source_protocol`:请求来源协议,例如 `anthropic_messages` 或 `openai_responses`。
+- `internal_messages`:Gateway 转换后发送给 Rollout 的内部消息。
+- `output_messages` 或 `output_text`:模型输出转换回客户端协议后的结果。
+- `rollout_tools` 和 `rollout_tool_choice`:工具定义和工具选择策略。
+- `request_id`:用于串联客户端请求、Gateway 记录和 Rollout 结果。
+
+这些字段能帮助定位问题出在客户端请求、协议适配、Rollout 生成还是响应转换阶段。
+
+## 常见问题
+
+### 客户端请求超时
+
+先检查 Gateway 是否仍在运行,并适当增大客户端超时时间。Claude Code 可设置:
+
+```bash
+export API_TIMEOUT_MS=600000
+```
+
+同时检查 Rollout Controller 是否收到请求,以及推理服务是否有可用并发。
+
+### 客户端上下文过长
+
+真实 Agent 客户端会注入系统提示词、工具 schema 和历史消息。如果请求被截断或报 context length 相关错误,需要增大 Gateway 和推理后端的上下文长度。
+
+### 工具调用没有触发
+
+先使用本文中的工具调用 prompt 做最小复现,再检查 `capture_folder` 中是否记录了工具定义。如果 `rollout_tools` 为空,问题通常出在客户端请求到 Gateway 的协议适配阶段;如果工具定义存在但没有工具调用结果,需要继续检查模型输出和 Agent 客户端的工具执行日志。
diff --git a/docs/zh_cn/rl/advanced_tutorial/loss.md b/docs/zh_cn/rl/advanced_tutorial/loss.md
index 278129d6bd..5578b30a44 100644
--- a/docs/zh_cn/rl/advanced_tutorial/loss.md
+++ b/docs/zh_cn/rl/advanced_tutorial/loss.md
@@ -9,8 +9,7 @@ XTuner 中所有的 loss 计算均涉及两个核心组件 `LossConfig` 和 `Los
 ```python
 import torch
 import torch.nn as nn
-from xtuner.v1.rl.grpo import GRPOLossConfig, GRPOLossContext
-from xtuner.v1.rl.base import RLLossContextInputItem
+from xtuner.v1.rl.loss import GRPOLossConfig, GRPOLossContext, RLLossContextInputItem
 from xtuner.v1.data_proto import SequenceContext
 
 def gather_logprobs(logits, shifted_labels):
diff --git a/docs/zh_cn/rl/tutorial/rl_grpo_trainer.md b/docs/zh_cn/rl/tutorial/rl_grpo_trainer.md
index 267bd14d49..e8f1486741 100644
--- a/docs/zh_cn/rl/tutorial/rl_grpo_trainer.md
+++ b/docs/zh_cn/rl/tutorial/rl_grpo_trainer.md
@@ -92,7 +92,7 @@ replay_buffer_cfg = ReplayBufferConfig(
 ```{code-block} python
 :caption: 配置推理环境
 
-from xtuner.v1.ray.config.worker import RolloutConfig
+from xtuner.v1.rl.rollout.worker import RolloutConfig
 
 model_path = "/path/to/qwen3-8B" # 替换为您的模型路径
 
@@ -144,8 +144,8 @@ judger_cfg = JudgerConfig(
 :caption: 配置训练策略
 from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
 from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig
-from xtuner.v1.rl.base import WorkerConfig
-from xtuner.v1.rl.grpo import GRPOLossConfig
+from xtuner.v1.rl.trainer import WorkerConfig
+from xtuner.v1.rl.loss import GRPOLossConfig
 
 model_path = "/path/to/qwen3-8B" # 填入您的模型路径
 train_optimizer_steps = 4 # 训练优化步数
@@ -201,7 +201,7 @@ evaluator_cfg = EvaluatorConfig(
 除以上的生成和训练配置外,我们需要配置系统所需资源(如GPU、CPU、内存)等,此处我们使用默认的资源配置,示例如下。
 ```{code-block} python
-from xtuner.v1.ray.base import AcceleratorResourcesConfig
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig
 resources = AcceleratorResourcesConfig(
     accelerator="GPU",
     num_accelerators_per_worker=1,
diff --git a/examples/v1/config/rl_dapo_math.py 
b/examples/v1/config/rl_dapo_math.py
new file mode 100644
index 0000000000..67982aad8a
--- /dev/null
+++ b/examples/v1/config/rl_dapo_math.py
@@ -0,0 +1,214 @@
+import os
+from pathlib import Path
+
+from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
+from xtuner.v1.data_proto.rl_data import SampleParams
+from xtuner.v1.rl.advantage import GRPOAdvantageConfig
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig
+from xtuner.v1.model import get_model_config_from_hf
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.judger import DapoMathJudgerConfig
+from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig
+from xtuner.v1.rl.trainer import WorkerConfig
+from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig
+from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, SamplerConfig, SyncProduceStrategyConfig, TaskSpecConfig
+from xtuner.v1.rl.evaluator import EvaluatorConfig
+from xtuner.v1.rl.loss import GRPOLossConfig
+from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig
+
+work_dir = os.environ["WORK_DIR"]
+model_path = os.environ["MODEL_PATH"]
+data_path = os.environ["DATA_PATH"]
+eval_data_path = os.environ["EVAL_DATA_PATH"]
+enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
+NNODE = int(os.environ.get("WORLD_SIZE", "1"))
+
+# basic settings
+experimental_name = "dapo_math"
+total_epochs = 1
+train_batch_size = 512
+prompt_repeat_k = 16
+rollout_tp_size = 1
+rollout_ep_size = 1
+max_prompt_length = 2048
+max_response_length = 8192
+pack_max_length = 32768
+train_optimizer_steps = 16
+hf_interval = 50
+enable_initial_evaluate = True
+evaluate_step = 5
+
+# 1. resources
+resources = AcceleratorResourcesConfig(
+    accelerator="GPU",
+    num_workers=8 * NNODE,
+    num_cpus_per_worker=12,
+    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
+)
+
+# 2. rollout
+rollout_config = RolloutConfig(
+    env=experimental_name,
+    device=resources.accelerator,
+    model_path=model_path,
+    dtype="bfloat16",
+    tensor_parallel_size=rollout_tp_size,
+    expert_parallel_size=rollout_ep_size,
+    gpu_memory_utilization=0.8,
+    context_length=max_response_length + max_prompt_length,
+    enable_return_routed_experts=(enable_return_routed_experts == "1"),
+    rollout_max_batch_size_per_instance=2048,
+)
+
+# 3. judger
+from xtuner.v1.rl.utils import get_eos_token
+from transformers import AutoTokenizer
+eos_token_id = get_eos_token(model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id)
+judger_config = DapoMathJudgerConfig(
+    judger_name="dapo_math",
+    num_ray_actors=1,
+    eos_token=eos_token_str,
+    enable_overlong_buffer=True,
+    max_response_len=max_response_length,
+    overlong_buffer_len=4096,
+    overlong_penalty_factor=1.0,
+    tokenizer=tokenizer,
+)
+
+# 4. 
train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +# 5. train agent loop manager +train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) +train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +produce_strategy_config = SyncProduceStrategyConfig() +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ), +) + +# 6. 
eval agent loop manager +eval_dataset = DatasetConfig( + name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0 +) +eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=0.7, + temperature=0.0, + min_tokens=0, +) +eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=eval_sampler_config, + ), +) + +def dapo_compute_metric(samples): + return {"accuracy": sum(s.reward["acc"] > 0 for s in samples) / len(samples)} + +evaluator_config = EvaluatorConfig(compute_metric_func=dapo_compute_metric) + +trainer = RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config + rollout_config=rollout_config, + tokenizer_path=model_path, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + train_batch_size=train_batch_size, + advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), + enable_evaluate=True, + enable_initial_evaluate=False, + total_train_steps=500, + evaluate_step=evaluate_step, + work_dir=work_dir, + seed=123, + debug_rollout=False, +) diff --git a/examples/v1/config/rl_dapo_math_async.py b/examples/v1/config/rl_dapo_math_async.py new file mode 100644 index 0000000000..9e1cdbbde0 --- /dev/null +++ b/examples/v1/config/rl_dapo_math_async.py @@ -0,0 +1,217 @@ +import os +from pathlib import Path + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto.rl_data import SampleParams +from xtuner.v1.rl.advantage import GRPOAdvantageConfig +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.utils import AcceleratorResourcesConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.judger import DapoMathJudgerConfig +from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig +from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, AsyncProduceStrategyConfig, SamplerConfig, TaskSpecConfig +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig + +work_dir = os.environ["WORK_DIR"] +model_path = os.environ["MODEL_PATH"] +data_path = os.environ["DATA_PATH"] +eval_data_path = os.environ["EVAL_DATA_PATH"] +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") +NNODE = int(os.environ.get("WORLD_SIZE", "1")) + +# basic settings +experimental_name = "dapo_math" +total_epochs = 1 +train_batch_size = 512 +prompt_repeat_k = 16 +rollout_tp_size = 1 
+rollout_ep_size = 1
+max_prompt_length = 2048
+max_response_length = 8192
+pack_max_length = 32768
+train_optimizer_steps = 16
+hf_interval = 50
+enable_initial_evaluate = True
+evaluate_step = 5
+
+# 1. resources
+resources = AcceleratorResourcesConfig(
+    accelerator="GPU",
+    num_workers=8 * NNODE,
+    num_cpus_per_worker=12,
+    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
+)
+
+# 2. rollout
+rollout_config = RolloutConfig(
+    env=experimental_name,
+    device=resources.accelerator,
+    model_path=model_path,
+    dtype="bfloat16",
+    tensor_parallel_size=rollout_tp_size,
+    expert_parallel_size=rollout_ep_size,
+    gpu_memory_utilization=0.8,
+    context_length=max_response_length + max_prompt_length,
+    enable_return_routed_experts=(enable_return_routed_experts == "1"),
+    rollout_max_batch_size_per_instance=2048,
+)
+
+# 3. judger
+from xtuner.v1.rl.utils import get_eos_token
+from transformers import AutoTokenizer
+eos_token_id = get_eos_token(model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id)
+judger_config = DapoMathJudgerConfig(
+    judger_name="dapo_math",
+    num_ray_actors=1,
+    eos_token=eos_token_str,
+    enable_overlong_buffer=True,
+    max_response_len=max_response_length,
+    overlong_buffer_len=4096,
+    overlong_penalty_factor=1.0,
+    tokenizer=tokenizer,
+)
+# 4. train worker
+lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
+fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1)
+model_cfg = get_model_config_from_hf(Path(model_path))
+if hasattr(model_cfg, "balancing_loss_cfg"):
+    model_cfg.balancing_loss_cfg = None
+if hasattr(model_cfg, "z_loss_cfg"):
+    model_cfg.z_loss_cfg = None
+optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1)
+loss_cfg = GRPOLossConfig(
+    policy_loss_cfg=dict(
+        cliprange_high=0.28,
+        cliprange_low=0.2,
+        loss_type=os.environ.get("LOSS_TYPE", "vanilla"),
+        clip_ratio_c=10.0,
+        log_prob_diff_min=-20.0,
+        log_prob_diff_max=20.0,
+    ),
+    ignore_idx=-100,
+    use_kl_loss=False,
+    kl_loss_coef=0.0,
+    kl_loss_type="low_var_kl",
+    mode=os.environ.get("LOSS_MODE", "chunk"),
+    chunk_size=512,
+)
+train_worker_cfg = WorkerConfig(
+    model_cfg=model_cfg,
+    load_from=model_path,
+    optim_cfg=optim_cfg,
+    loss_cfg=loss_cfg,
+    lr_cfg=lr_cfg,
+    fsdp_cfg=fsdp_cfg,
+    sp_size=int(os.environ.get("SP_SIZE", "1")),
+    optimizer_steps=train_optimizer_steps,
+    pack_max_length=pack_max_length,
+)
+
+# 5.
train agent loop manager +train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) +train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +produce_strategy_config = AsyncProduceStrategyConfig( + over_sample_threshold=0.2, + enable_partial_rollout=True, + max_staleness=0, + tail_batch_trigger_size=256, +) +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ), +) + +# 6. eval agent loop manager +eval_dataset = DatasetConfig( + name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0 +) +eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=0.7, + temperature=0.0, + min_tokens=0, +) +eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=eval_sampler_config, + ), +) + +def dapo_compute_metric(samples): + return {"accuracy": sum(s.reward["acc"] > 0 for s in samples) / len(samples)} + +evaluator_config = EvaluatorConfig(compute_metric_func=dapo_compute_metric) + +trainer = RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config + rollout_config=rollout_config, + tokenizer_path=model_path, + replay_buffer_config=AsyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + train_batch_size=train_batch_size, + advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), + enable_evaluate=True, + enable_initial_evaluate=False, + total_train_steps=500, + evaluate_step=evaluate_step, + work_dir=work_dir, + seed=123, + debug_rollout=False, +) diff --git a/examples/v1/config/rl_dapo_math_async_filter.py b/examples/v1/config/rl_dapo_math_async_filter.py new file mode 100644 index 0000000000..6b2d1f94a6 --- /dev/null +++ b/examples/v1/config/rl_dapo_math_async_filter.py @@ -0,0 +1,230 @@ +import os +from pathlib import Path + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto.rl_data import SampleParams +from xtuner.v1.rl.advantage import GRPOAdvantageConfig +from 
xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig
+from xtuner.v1.model import get_model_config_from_hf
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.judger import DapoMathJudgerConfig
+from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig
+from xtuner.v1.rl.trainer import WorkerConfig
+from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig
+from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, AsyncProduceStrategyConfig, SamplerConfig, TaskSpecConfig
+from xtuner.v1.rl.evaluator import EvaluatorConfig
+from xtuner.v1.rl.loss import GRPOLossConfig
+from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig
+
+work_dir = os.environ["WORK_DIR"]
+model_path = os.environ["MODEL_PATH"]
+data_path = os.environ["DATA_PATH"]
+eval_data_path = os.environ["EVAL_DATA_PATH"]
+enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
+NNODE = int(os.environ.get("WORLD_SIZE", "1"))
+
+# basic settings
+experimental_name = "dapo_math"
+total_epochs = 1
+train_batch_size = 512
+prompt_repeat_k = 16
+rollout_tp_size = 1
+rollout_ep_size = 1
+max_prompt_length = 2048
+max_response_length = 8192
+pack_max_length = 32768
+train_optimizer_steps = 16
+hf_interval = 50
+enable_initial_evaluate = True
+evaluate_step = 5
+
+# 1. resources
+resources = AcceleratorResourcesConfig(
+    accelerator="GPU",
+    num_workers=8 * NNODE,
+    num_cpus_per_worker=12,
+    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
+)
+
+# 2. rollout
+rollout_config = RolloutConfig(
+    env=experimental_name,
+    device=resources.accelerator,
+    model_path=model_path,
+    dtype="bfloat16",
+    tensor_parallel_size=rollout_tp_size,
+    expert_parallel_size=rollout_ep_size,
+    gpu_memory_utilization=0.8,
+    context_length=max_response_length + max_prompt_length,
+    enable_return_routed_experts=(enable_return_routed_experts == "1"),
+    rollout_max_batch_size_per_instance=1024,
+)
+
+# 3. judger
+from xtuner.v1.rl.utils import get_eos_token
+from transformers import AutoTokenizer
+eos_token_id = get_eos_token(model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id)
+judger_config = DapoMathJudgerConfig(
+    judger_name="dapo_math",
+    num_ray_actors=1,
+    eos_token=eos_token_str,
+    enable_overlong_buffer=True,
+    max_response_len=max_response_length,
+    overlong_buffer_len=4096,
+    overlong_penalty_factor=1.0,
+    tokenizer=tokenizer,
+)
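+
+# Reward-shaping note (sketch; the exact formula lives in the DapoMath judger):
+# with enable_overlong_buffer=True, a response whose length falls inside the
+# last overlong_buffer_len (4096) tokens of the max_response_len budget is
+# penalized in proportion to how far it runs into the buffer, up to
+# overlong_penalty_factor at the truncation point (DAPO's soft overlong
+# punishment).
+
+# 4.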
train worker
+lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
+fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1)
+model_cfg = get_model_config_from_hf(Path(model_path))
+if hasattr(model_cfg, "balancing_loss_cfg"):
+    model_cfg.balancing_loss_cfg = None
+if hasattr(model_cfg, "z_loss_cfg"):
+    model_cfg.z_loss_cfg = None
+optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1)
+loss_cfg = GRPOLossConfig(
+    policy_loss_cfg=dict(
+        cliprange_high=0.28,
+        cliprange_low=0.2,
+        loss_type=os.environ.get("LOSS_TYPE", "vanilla"),
+        clip_ratio_c=10.0,
+        log_prob_diff_min=-20.0,
+        log_prob_diff_max=20.0,
+    ),
+    ignore_idx=-100,
+    use_kl_loss=False,
+    kl_loss_coef=0.0,
+    kl_loss_type="low_var_kl",
+    mode=os.environ.get("LOSS_MODE", "chunk"),
+    chunk_size=512,
+)
+train_worker_cfg = WorkerConfig(
+    model_cfg=model_cfg,
+    load_from=model_path,
+    optim_cfg=optim_cfg,
+    loss_cfg=loss_cfg,
+    lr_cfg=lr_cfg,
+    fsdp_cfg=fsdp_cfg,
+    sp_size=int(os.environ.get("SP_SIZE", "1")),
+    optimizer_steps=train_optimizer_steps,
+    pack_max_length=pack_max_length,
+)
+
+# 5. train agent loop manager
+train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path)
+tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length)
+train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}]
+dataloader_cfg = DataloaderConfig(
+    dataset_config_list=train_dataset_cfg,
+    pack_max_length=pack_max_length,
+    collator="fake_collator",
+    pack_level="none",
+)
+sampler_config = SamplerConfig(
+    dataloader_cfg=dataloader_cfg,
+    prompt_repeat_k=prompt_repeat_k,
+)
+training_sample_params = SampleParams(
+    max_tokens=max_response_length,
+    top_k=0,
+    top_p=1.0,
+    temperature=1.0,
+    min_tokens=0,
+)
+agent_loop_config = SingleTurnAgentLoopConfig(
+    hf_checkpoint=model_path,
+    sample_params=training_sample_params,
+)
+
+def group_samples_filter_func(rollout_states):
+    """Keep a prompt group only when its rewards are not all identical.
+
+    A group whose rollouts all score the same has zero GRPO advantage and
+    carries no learning signal, so it is dropped (DAPO-style dynamic sampling).
+    """
+    valid_responses = [state for state in rollout_states if state.response_ids is not None]
+    rewards = [res.reward["score"] for res in valid_responses]
+    return len(set(rewards)) != 1
+
+produce_strategy_config = AsyncProduceStrategyConfig(
+    over_sample_threshold=0.2,
+    enable_partial_rollout=True,
+    max_staleness=0,
+    tail_batch_trigger_size=256,
+    is_valid_sample_fn=group_samples_filter_func,
+)
+agent_loop_manager_cfg = AgentLoopManagerConfig(
+    tasks=TaskSpecConfig(
+        task_name="train_task",
+        agent_loop_config=agent_loop_config,
+        judger_config=judger_config,
+        produce_strategy_config=produce_strategy_config,
+        sampler_config=sampler_config,
+    ),
+)
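+
+# Example: with prompt_repeat_k=16, a group whose 16 rollouts all score 1.0
+# (or all 0.0) is rejected by group_samples_filter_func; the producer's
+# over-sampling budget (over_sample_threshold=0.2) presumably exists to
+# backfill such rejected groups.
+
+# 6.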
eval agent loop manager +eval_dataset = DatasetConfig( + name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0 +) +eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=0.7, + temperature=0.0, + min_tokens=0, +) +eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=eval_sampler_config, + ), +) + +def dapo_compute_metric(samples): + return {"accuracy": sum(s.reward["acc"] > 0 for s in samples) / len(samples)} + +evaluator_config = EvaluatorConfig(compute_metric_func=dapo_compute_metric) + +trainer = RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config + rollout_config=rollout_config, + tokenizer_path=model_path, + replay_buffer_config=AsyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + train_batch_size=train_batch_size, + advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), + enable_evaluate=True, + enable_initial_evaluate=False, + evaluate_step=evaluate_step, + work_dir=work_dir, + seed=123, + debug_rollout=False, +) diff --git a/examples/v1/config/rl_disagg_multi.py b/examples/v1/config/rl_disagg_multi.py new file mode 100644 index 0000000000..de4f6f47c6 --- /dev/null +++ b/examples/v1/config/rl_disagg_multi.py @@ -0,0 +1,353 @@ +"""RL Disaggregated Trainer example config (Multi-Task: GSM8K + DAPO Math). 
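+
+Training and rollout run in separate resource pools here; the two tasks below
+share one rollout engine and one judger while sampling from their own datasets.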
+ +Required env vars: + WORK_DIR + MODEL_PATH + GSM8K_DATA_PATH + GSM8K_EVAL_DATA_PATH + DAPO_DATA_PATH + DAPO_EVAL_DATA_PATH + +Common optional env vars: + TRAIN_NUM_WORKERS=4 + ROLLOUT_NUM_WORKERS=4 + TRAIN_BATCH_SIZE=64 + TOTAL_TRAIN_STEPS=4 + SYNC_WEIGHTS_INTERVAL=1 + OVER_SAMPLE_THRESHOLD=0.0 + PARTIAL_ROLLOUT=0 + GSM8K_TASK_WEIGHT=3.0 + DAPO_TASK_WEIGHT=1.0 + ENABLE_EVALUATE=0 +""" + +import os +from pathlib import Path + +from transformers import AutoTokenizer + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto.rl_data import SampleParams +from xtuner.v1.rl.advantage import GRPOAdvantageConfig +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig +from xtuner.v1.rl.agent_loop_manager import ( + AgentLoopManagerConfig, + AsyncProduceStrategyConfig, + SamplerConfig, + SyncProduceStrategyConfig, + TaskSpecConfig, +) +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.judger import DapoMathJudgerConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, get_eos_token +from xtuner.v1.train.rl_trainer import ( + RLDisaggregatedTrainerConfig, +) + + +work_dir = os.environ["WORK_DIR"] +model_path = os.environ["MODEL_PATH"] +gsm8k_data_path = os.environ["GSM8K_DATA_PATH"] +gsm8k_eval_data_path = os.environ["GSM8K_EVAL_DATA_PATH"] +dapo_data_path = os.environ["DAPO_DATA_PATH"] +dapo_eval_data_path = os.environ["DAPO_EVAL_DATA_PATH"] +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") + + +experimental_name = "disaggregated_multi_task_gsm8k_dapo_math" +total_train_steps = int(os.environ.get("TOTAL_TRAIN_STEPS", "4")) +evaluate_step = int(os.environ.get("EVALUATE_STEP", str(total_train_steps))) +train_optimizer_steps = int(os.environ.get("TRAIN_OPTIMIZER_STEPS", "4")) +train_batch_size = int(os.environ.get("TRAIN_BATCH_SIZE", "64")) +sync_weights_interval = int(os.environ.get("SYNC_WEIGHTS_INTERVAL", "1")) +over_sample_threshold = float(os.environ.get("OVER_SAMPLE_THRESHOLD", "0.0")) +partial_rollout = os.environ.get("PARTIAL_ROLLOUT", "0") == "1" +tail_batch_trigger_size = int(os.environ.get("TAIL_BATCH_TRIGGER_SIZE", "0")) +max_staleness = int(os.environ.get("MAX_STALENESS", "0")) +enable_evaluate = os.environ.get("ENABLE_EVALUATE", "0") == "1" +gsm8k_task_weight = float(os.environ.get("GSM8K_TASK_WEIGHT", "3.0")) +dapo_task_weight = float(os.environ.get("DAPO_TASK_WEIGHT", "1.0")) +rollout_tp_size = int(os.environ.get("ROLLOUT_TP_SIZE", "1")) +rollout_ep_size = int(os.environ.get("ROLLOUT_EP_SIZE", "1")) +gsm8k_prompt_repeat_k = int(os.environ.get("GSM8K_PROMPT_REPEAT_K", "3")) +dapo_prompt_repeat_k = int(os.environ.get("DAPO_PROMPT_REPEAT_K", "4")) +gsm8k_max_prompt_length = int(os.environ.get("GSM8K_MAX_PROMPT_LENGTH", "512")) +dapo_max_prompt_length = int(os.environ.get("DAPO_MAX_PROMPT_LENGTH", "2048")) +gsm8k_max_response_length = int(os.environ.get("GSM8K_MAX_RESPONSE_LENGTH", "1024")) +dapo_max_response_length = int(os.environ.get("DAPO_MAX_RESPONSE_LENGTH", "4096")) +pack_max_length = int(os.environ.get("PACK_MAX_LENGTH", "32768")) + +max_prompt_length = dapo_max_prompt_length 
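+# The shared rollout engine must cover the longest task, so the DAPO lengths
+# (the larger of the two tasks) size its context window and response budget.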
+max_response_length = dapo_max_response_length + + +train_resources = AcceleratorResourcesConfig( + accelerator=os.environ.get("ACCELERATOR", "GPU"), + num_workers=int(os.environ.get("TRAIN_NUM_WORKERS", "4")), + num_cpus_per_worker=float(os.environ.get("TRAIN_CPUS_PER_WORKER", "12")), + cpu_memory_per_worker=int(os.environ.get("TRAIN_CPU_MEMORY_PER_WORKER", str(16 * 1024**3))), +) + +rollout_resources = AcceleratorResourcesConfig( + accelerator=os.environ.get("ACCELERATOR", "GPU"), + num_workers=int(os.environ.get("ROLLOUT_NUM_WORKERS", "4")), + num_cpus_per_worker=float(os.environ.get("ROLLOUT_CPUS_PER_WORKER", "12")), + cpu_memory_per_worker=int(os.environ.get("ROLLOUT_CPU_MEMORY_PER_WORKER", str(16 * 1024**3))), +) + + +rollout_config = RolloutConfig( + env=experimental_name, + device=rollout_resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=float(os.environ.get("ROLLOUT_GPU_MEMORY_UTILIZATION", "0.8")), + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=(enable_return_routed_experts == "1"), + rollout_max_batch_size_per_instance=2048, +) + + +eos_token_id = get_eos_token(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) +eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) +judger_config = DapoMathJudgerConfig( + judger_name="dapo_math", + eos_token=eos_token_str, + enable_overlong_buffer=True, + max_response_len=max_response_length, + overlong_buffer_len=4096, + overlong_penalty_factor=1.0, + tokenizer=tokenizer, +) + + +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + + +gsm8k_train_tokenizer_config = RLTextTokenizeFnConfig(max_length=gsm8k_max_prompt_length) +dapo_train_tokenizer_config = RLTextTokenizeFnConfig(max_length=dapo_max_prompt_length) + +gsm8k_train_sampler_config = SamplerConfig( + dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig(name="gsm8k", anno_path=gsm8k_data_path), + "tokenize_fn": gsm8k_train_tokenizer_config, + } + ], + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + ), + prompt_repeat_k=gsm8k_prompt_repeat_k, +) +dapo_train_sampler_config = SamplerConfig( + dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig(name="dapo_math", anno_path=dapo_data_path), + "tokenize_fn": dapo_train_tokenizer_config, + } + ], + 
pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + ), + prompt_repeat_k=dapo_prompt_repeat_k, +) + +gsm8k_train_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams( + max_tokens=gsm8k_max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, + ), +) +dapo_train_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams( + max_tokens=dapo_max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, + ), +) + +if over_sample_threshold > 0 or partial_rollout: + produce_strategy_config = AsyncProduceStrategyConfig( + over_sample_threshold=over_sample_threshold, + enable_partial_rollout=partial_rollout, + tail_batch_trigger_size=tail_batch_trigger_size, + max_staleness=max_staleness, + ) +else: + produce_strategy_config = SyncProduceStrategyConfig() + +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=[ + TaskSpecConfig( + task_name="train_task:dapo_math", + weight=dapo_task_weight, + agent_loop_config=dapo_train_agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=dapo_train_sampler_config, + ), + TaskSpecConfig( + task_name="train_task:gsm8k", + weight=gsm8k_task_weight, + agent_loop_config=gsm8k_train_agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=gsm8k_train_sampler_config, + ), + ], +) + + +gsm8k_eval_sampler_config = SamplerConfig( + dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig(name="gsm8k_eval", anno_path=gsm8k_eval_data_path, sample_ratio=1.0), + "tokenize_fn": gsm8k_train_tokenizer_config, + } + ], + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + ), + prompt_repeat_k=1, +) +dapo_eval_sampler_config = SamplerConfig( + dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig(name="dapo_math_eval", anno_path=dapo_eval_data_path, sample_ratio=1.0), + "tokenize_fn": dapo_train_tokenizer_config, + } + ], + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + ), + prompt_repeat_k=1, +) + +gsm8k_eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams( + max_tokens=gsm8k_max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, + ), +) +dapo_eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams( + max_tokens=dapo_max_response_length, + top_k=1, + top_p=0.7, + temperature=0.0, + min_tokens=0, + ), +) + +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=[ + TaskSpecConfig( + task_name="eval_task:dapo_math", + weight=dapo_task_weight, + agent_loop_config=dapo_eval_agent_loop_config, + judger_config=judger_config, + sampler_config=dapo_eval_sampler_config, + ), + TaskSpecConfig( + task_name="eval_task:gsm8k", + weight=gsm8k_task_weight, + agent_loop_config=gsm8k_eval_agent_loop_config, + judger_config=judger_config, + sampler_config=gsm8k_eval_sampler_config, + ), + ], +) + + +def compute_metric(samples): + return {"accuracy": sum(sample.reward["acc"] > 0 for sample in samples) / len(samples)} + + +evaluator_config = EvaluatorConfig(compute_metric_func=compute_metric) + +trainer = RLDisaggregatedTrainerConfig( + train_resources=train_resources, + rollout_resources=rollout_resources, + 
train_worker_cfg=train_worker_cfg,
+    rollout_config=rollout_config,
+    tokenizer_path=model_path,
+    replay_buffer_config=SyncReplayBufferConfig(),
+    agent_loop_manager_cfg=agent_loop_manager_cfg,
+    eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
+    evaluator_config=evaluator_config,
+    load_from=model_path,
+    train_batch_size=train_batch_size,
+    advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8),
+    total_train_steps=total_train_steps,
+    sync_weights_interval=sync_weights_interval,
+    enable_evaluate=enable_evaluate,
+    enable_initial_evaluate=False,
+    evaluate_step=evaluate_step,
+    work_dir=work_dir,
+    seed=int(os.environ.get("SEED", "123")),
+    debug_rollout=os.environ.get("DEBUG_ROLLOUT", "0") == "1",
+)
diff --git a/examples/v1/config/rl_disagg_single.py b/examples/v1/config/rl_disagg_single.py
new file mode 100644
index 0000000000..0be84c6321
--- /dev/null
+++ b/examples/v1/config/rl_disagg_single.py
@@ -0,0 +1,276 @@
+"""RL Disaggregated Trainer example config (GRPO + GSM8K).
+
+This config uses a mocked disaggregated weight-sync hook until the real cross-device weight update module lands.
+
+Required env vars: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH
+Common optional env vars (defaults match the code below):
+    TRAIN_NUM_WORKERS=4, ROLLOUT_NUM_WORKERS=4, TRAIN_BATCH_SIZE=32,
+    TOTAL_TRAIN_STEPS=16, SYNC_WEIGHTS_INTERVAL=1,
+    OVER_SAMPLE_THRESHOLD=0.0, PARTIAL_ROLLOUT=0,
+    TAIL_BATCH_TRIGGER_SIZE=0, MAX_STALENESS=0, ENABLE_EVALUATE=1
+
+Mode mapping in the current design:
+    Mode 1 (On-Policy):
+        SYNC_WEIGHTS_INTERVAL=1
+        OVER_SAMPLE_THRESHOLD=0.0
+        PARTIAL_ROLLOUT=0
+    Mode 2 (Stream Off-Policy):
+        SYNC_WEIGHTS_INTERVAL>1
+        OVER_SAMPLE_THRESHOLD=0.0
+        PARTIAL_ROLLOUT=0
+    Mode 3 (Async Stale):
+        OVER_SAMPLE_THRESHOLD>0.0
+        PARTIAL_ROLLOUT=0
+    Mode 4 (Async Partial Rollout):
+        OVER_SAMPLE_THRESHOLD>0.0
+        PARTIAL_ROLLOUT=1
+
+Responsibility split:
+    - trainer / step scheduling:
+        TRAIN_BATCH_SIZE, TOTAL_TRAIN_STEPS, SYNC_WEIGHTS_INTERVAL
+    - producer / replay-buffer policy:
+        OVER_SAMPLE_THRESHOLD, PARTIAL_ROLLOUT,
+        TAIL_BATCH_TRIGGER_SIZE, MAX_STALENESS
+"""
+
+import os
+from pathlib import Path
+
+from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
+from xtuner.v1.data_proto.rl_data import SampleParams
+from xtuner.v1.rl.advantage import GRPOAdvantageConfig
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig
+from xtuner.v1.model import get_model_config_from_hf
+from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig
+from xtuner.v1.rl.agent_loop_manager import (
+    AgentLoopManagerConfig,
+    AsyncProduceStrategyConfig,
+    SamplerConfig,
+    SyncProduceStrategyConfig,
+    TaskSpecConfig,
+)
+from xtuner.v1.rl.evaluator import EvaluatorConfig
+from xtuner.v1.rl.judger import GSM8KJudgerConfig
+from xtuner.v1.rl.loss import GRPOLossConfig
+from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.trainer import WorkerConfig
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig
+from xtuner.v1.train.rl_trainer import (
+    RLDisaggregatedTrainerConfig,
+)
+
+
+# env
+work_dir = os.environ["WORK_DIR"]
+model_path = os.environ["MODEL_PATH"]
+data_path = os.environ["DATA_PATH"]
+eval_data_path = os.environ["EVAL_DATA_PATH"]
+enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
+
+
+# basic settings
+experimental_name = "disaggregated_grpo_gsm8k"
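+# Step scheduling below is env-driven; the mode mapping in the module
+# docstring applies (e.g. SYNC_WEIGHTS_INTERVAL=1 with OVER_SAMPLE_THRESHOLD=0.0
+# and PARTIAL_ROLLOUT=0 keeps training fully on-policy, Mode 1).
+total_train_steps = 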
int(os.environ.get("TOTAL_TRAIN_STEPS", "16")) +evaluate_step = int(os.environ.get("EVALUATE_STEP", str(total_train_steps))) +train_optimizer_steps = int(os.environ.get("TRAIN_OPTIMIZER_STEPS", "1")) +train_batch_size = int(os.environ.get("TRAIN_BATCH_SIZE", str(32 * train_optimizer_steps))) +sync_weights_interval = int(os.environ.get("SYNC_WEIGHTS_INTERVAL", "1")) +over_sample_threshold = float(os.environ.get("OVER_SAMPLE_THRESHOLD", "0.0")) +partial_rollout = os.environ.get("PARTIAL_ROLLOUT", "0") == "1" +tail_batch_trigger_size = int(os.environ.get("TAIL_BATCH_TRIGGER_SIZE", "0")) +max_staleness = int(os.environ.get("MAX_STALENESS", "0")) +prompt_repeat_k = int(os.environ.get("PROMPT_REPEAT_K", "4")) +rollout_tp_size = int(os.environ.get("ROLLOUT_TP_SIZE", "1")) +rollout_ep_size = int(os.environ.get("ROLLOUT_EP_SIZE", "1")) +max_prompt_length = int(os.environ.get("MAX_PROMPT_LENGTH", "512")) +max_response_length = int(os.environ.get("MAX_RESPONSE_LENGTH", "1024")) +pack_max_length = int(os.environ.get("PACK_MAX_LENGTH", str(32 * 1024))) +enable_evaluate = os.environ.get("ENABLE_EVALUATE", "1") == "1" + +# execution knobs: +# - sync_weights_interval controls how many train steps share one weight-sync interval +# - over_sample_threshold / partial_rollout feed the train-task produce strategy +# - tail_batch_* controls replay-buffer recycling policy inside AsyncProduceStrategy + + +# 1. resources: default 4 GPUs for training and 4 GPUs for rollout. +train_resources = AcceleratorResourcesConfig( + accelerator=os.environ.get("ACCELERATOR", "GPU"), + num_workers=int(os.environ.get("TRAIN_NUM_WORKERS", "4")), + num_cpus_per_worker=float(os.environ.get("TRAIN_CPUS_PER_WORKER", "12")), + cpu_memory_per_worker=int(os.environ.get("TRAIN_CPU_MEMORY_PER_WORKER", str(16 * 1024**3))), +) + +rollout_resources = AcceleratorResourcesConfig( + accelerator=os.environ.get("ACCELERATOR", "GPU"), + num_workers=int(os.environ.get("ROLLOUT_NUM_WORKERS", "4")), + num_cpus_per_worker=float(os.environ.get("ROLLOUT_CPUS_PER_WORKER", "12")), + cpu_memory_per_worker=int(os.environ.get("ROLLOUT_CPU_MEMORY_PER_WORKER", str(16 * 1024**3))), +) + + +# 2. rollout +rollout_config = RolloutConfig( + env=experimental_name, + device=rollout_resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=float(os.environ.get("ROLLOUT_GPU_MEMORY_UTILIZATION", "0.8")), + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=(enable_return_routed_experts == "1"), +) + + +# 3. judger +judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", judger_type="router") + + +# 4. 
train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None + +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + + +# 5. train agent loop manager +train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) +train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +if over_sample_threshold > 0 or partial_rollout: + produce_strategy_config = AsyncProduceStrategyConfig( + over_sample_threshold=over_sample_threshold, + enable_partial_rollout=partial_rollout, + tail_batch_trigger_size=tail_batch_trigger_size, + max_staleness=max_staleness, + ) +else: + produce_strategy_config = SyncProduceStrategyConfig() +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ), +) + + +# 6. eval agent loop manager +eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0) +eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, +) +eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=eval_sampler_config, + ), +) + + +# 7. 
evaluator
+evaluator_config = EvaluatorConfig(compute_metric_func=None)
+
+
+# 8. RL Disaggregated Trainer Config
+trainer = RLDisaggregatedTrainerConfig(
+    train_resources=train_resources,
+    rollout_resources=rollout_resources,
+    train_worker_cfg=train_worker_cfg,
+    rollout_config=rollout_config,
+    tokenizer_path=model_path,
+    replay_buffer_config=SyncReplayBufferConfig(),
+    agent_loop_manager_cfg=agent_loop_manager_cfg,
+    eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
+    evaluator_config=evaluator_config,
+    load_from=model_path,
+    train_batch_size=train_batch_size,
+    advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8),
+    total_train_steps=total_train_steps,
+    sync_weights_interval=sync_weights_interval,
+    enable_evaluate=enable_evaluate,
+    enable_initial_evaluate=True,
+    evaluate_step=evaluate_step,
+    work_dir=work_dir,
+    seed=int(os.environ.get("SEED", "123")),
+    debug_rollout=os.environ.get("DEBUG_ROLLOUT", "0") == "1",
+)
diff --git a/examples/v1/config/rl_grpo_geo3k_judge.py b/examples/v1/config/rl_grpo_geo3k_judge.py
new file mode 100644
index 0000000000..0c4447bd0c
--- /dev/null
+++ b/examples/v1/config/rl_grpo_geo3k_judge.py
@@ -0,0 +1,226 @@
+"""RL Colocate Trainer example config (GRPO + GEO3K).
+
+Usage: pass the paths via environment variables; the CLI then loads this
+config and runs trainer_cfg.build().fit().
+Required: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH, MEDIA_ROOT
+Optional: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE
+"""
+import os
+
+from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
+from xtuner.v1.data_proto.rl_data import SampleParams
+from xtuner.v1.rl.advantage import GRPOAdvantageConfig
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+from xtuner.v1.datasets.rl_tokenize_fn import RLQwen3VLTokenizeFnConfig
+from xtuner.v1.model.compose.qwen3_vl import Qwen3VLDense8BConfig
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.judger import GEO3KJudgerConfig
+from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig
+from xtuner.v1.rl.trainer import WorkerConfig
+from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig
+from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, SamplerConfig, SyncProduceStrategyConfig, TaskSpecConfig
+from xtuner.v1.rl.evaluator import EvaluatorConfig
+from xtuner.v1.rl.loss import GRPOLossConfig
+from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig
+
+# env
+work_dir = os.environ["WORK_DIR"]
+model_path = os.environ["MODEL_PATH"]
+data_path = os.environ["DATA_PATH"]
+eval_data_path = os.environ["EVAL_DATA_PATH"]
+enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
+NNODE = int(os.environ.get("WORLD_SIZE", "1"))
+media_root = os.environ["MEDIA_ROOT"]
+
+# basic settings
+experimental_name = "grpo_geo3k"
+total_train_steps = 45  # TODO: total_epoch
+evaluate_step = 45
+train_optimizer_steps = 4
+train_batch_size = 1024
+prompt_repeat_k = 5
+rollout_tp_size = 1
+rollout_ep_size = 1
+max_prompt_length = 1024
+max_response_length = 2048
+pack_max_length = 32 * 1024
+
+# 1. resources
+resources = AcceleratorResourcesConfig(
+    accelerator="GPU",
+    num_workers=8 * NNODE,
+    num_cpus_per_worker=12,
+    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
+)
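+
+# num_workers = 8 * NNODE assumes 8 accelerators per node, with WORLD_SIZE
+# read as the node count above.
+
+# 2.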
rollout +rollout_config = RolloutConfig( + env=experimental_name, + device=resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=0.8, + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=(enable_return_routed_experts == "1"), +) + +# 3. judger +judger_config = GEO3KJudgerConfig(num_ray_actors=1) + +# 4. train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) + +# TODO: support get_model_config_from_hf +model_cfg = Qwen3VLDense8BConfig(freeze_vision=True, freeze_projector=True) + +if hasattr(model_cfg.text_config, "balancing_loss_cfg"): + model_cfg.text_config.balancing_loss_cfg = None +if hasattr(model_cfg.text_config, "z_loss_cfg"): + model_cfg.text_config.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +# 5. train agent loop manager +train_dataset_cfg = [ + { + "dataset": DatasetConfig(name="geo3k", + anno_path=data_path, + class_name='VLMJsonlDataset', + media_root=media_root, + sample_ratio=1.0), + "tokenize_fn": RLQwen3VLTokenizeFnConfig(processor_path=model_path, + max_length=max_prompt_length), + } +] + +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + num_workers=8, +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +produce_strategy_config = SyncProduceStrategyConfig() +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ), +) + +# 6. 
eval agent loop manager
+eval_dataset_cfg = [
+    {
+        "dataset": DatasetConfig(name="geo3k",
+                                 anno_path=eval_data_path,
+                                 class_name='VLMJsonlDataset',
+                                 media_root=media_root,
+                                 sample_ratio=1.0),
+        "tokenize_fn": RLQwen3VLTokenizeFnConfig(processor_path=model_path,
+                                                 max_length=max_prompt_length,
+                                                 ignore_multimodal_info=True),
+    }
+]
+
+eval_dataloader_cfg = DataloaderConfig(
+    dataset_config_list=eval_dataset_cfg,
+    pack_max_length=pack_max_length,
+    collator="fake_collator",
+    pack_level="none",
+    num_workers=8,
+)
+eval_sampler_config = SamplerConfig(
+    dataloader_cfg=eval_dataloader_cfg,
+    prompt_repeat_k=1,
+)
+evaluation_sample_params = SampleParams(
+    max_tokens=max_response_length,
+    top_k=1,
+    top_p=1.0,
+    temperature=0.0,
+    min_tokens=0,
+)
+eval_agent_loop_config = SingleTurnAgentLoopConfig(
+    hf_checkpoint=model_path,
+    sample_params=evaluation_sample_params,
+)
+eval_agent_loop_manager_cfg = AgentLoopManagerConfig(
+    tasks=TaskSpecConfig(
+        task_name="eval_task",
+        agent_loop_config=eval_agent_loop_config,
+        judger_config=judger_config,
+        sampler_config=eval_sampler_config,
+    ),
+)
+
+# 7. evaluator
+evaluator_config = EvaluatorConfig(compute_metric_func=None)
+
+# 8. RL Colocate Trainer Config (the CLI obtains the Trainer via config["trainer"].build())
+trainer = RLColocateTrainerConfig(
+    resources=resources,
+    train_worker_cfg=train_worker_cfg,  # TODO: uniform naming of cfg and config
+    rollout_config=rollout_config,
+    tokenizer_path=model_path,
+    replay_buffer_config=SyncReplayBufferConfig(),
+    agent_loop_manager_cfg=agent_loop_manager_cfg,
+    eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
+    evaluator_config=evaluator_config,
+    load_from=model_path,
+    total_train_steps=total_train_steps,
+    train_batch_size=train_batch_size,
+    advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8),
+    enable_evaluate=True,
+    enable_initial_evaluate=False,
+    evaluate_step=evaluate_step,
+    work_dir=work_dir,
+    seed=123,
+    debug_rollout=False,
+)
diff --git a/examples/v1/config/rl_grpo_gsm8k_async.py b/examples/v1/config/rl_grpo_gsm8k_async.py
new file mode 100644
index 0000000000..ab0504a48f
--- /dev/null
+++ b/examples/v1/config/rl_grpo_gsm8k_async.py
@@ -0,0 +1,208 @@
+"""RL Colocate Trainer example config (GRPO + GSM8K).
+
+Usage: pass the paths via environment variables; the CLI then loads this
+config and runs trainer_cfg.build().fit().
+Required: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH
+Optional: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE
+"""
+import os
+from pathlib import Path
+
+from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
+from xtuner.v1.data_proto.rl_data import SampleParams
+from xtuner.v1.rl.advantage import GRPOAdvantageConfig
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig
+from xtuner.v1.model import get_model_config_from_hf
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.judger import GSM8KJudgerConfig
+from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig
+from xtuner.v1.rl.trainer import WorkerConfig
+from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig
+from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, AsyncProduceStrategyConfig, SamplerConfig, TaskSpecConfig
+from xtuner.v1.rl.evaluator import EvaluatorConfig
+from xtuner.v1.rl.loss import GRPOLossConfig
+from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig
+
+# env
+work_dir = os.environ["WORK_DIR"]
+model_path = 
os.environ["MODEL_PATH"] +data_path = os.environ["DATA_PATH"] +eval_data_path = os.environ["EVAL_DATA_PATH"] +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") +NNODE = int(os.environ.get("WORLD_SIZE", "1")) + +# basic settings +experimental_name = "grpo_gsm8k" +total_train_steps = 45 +evaluate_step = 45 +train_optimizer_steps = 1 +train_batch_size = 64 * train_optimizer_steps +prompt_repeat_k = 5 +rollout_tp_size = 1 +rollout_ep_size = 1 +max_prompt_length = 2048 +max_response_length = 8192 +pack_max_length = 10 * 1024 + +# 1. resources +resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8 * NNODE, + num_cpus_per_worker=12, + cpu_memory_per_worker=16 * 1024**3, # 16 GB +) + +# 2. rollout +rollout_config = RolloutConfig( + env=experimental_name, + device=resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=0.8, + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=(enable_return_routed_experts == "1"), +) + +# 3. judger +judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + +# 4. train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +# 5. 
train agent loop manager
+train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path)
+tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length)
+train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}]
+dataloader_cfg = DataloaderConfig(
+    dataset_config_list=train_dataset_cfg,
+    pack_max_length=pack_max_length,
+    collator="fake_collator",
+    pack_level="none",
+)
+sampler_config = SamplerConfig(
+    dataloader_cfg=dataloader_cfg,
+    prompt_repeat_k=prompt_repeat_k,
+)
+training_sample_params = SampleParams(
+    max_tokens=max_response_length,
+    top_k=0,
+    top_p=1.0,
+    temperature=1.0,
+    min_tokens=0,
+)
+agent_loop_config = SingleTurnAgentLoopConfig(
+    hf_checkpoint=model_path,
+    sample_params=training_sample_params,
+)
+# Async production: over-sample and allow partial rollouts so a training batch
+# need not wait on the slowest generations.
+produce_strategy_config = AsyncProduceStrategyConfig(
+    over_sample_threshold=0.8,
+    enable_partial_rollout=True,
+    max_staleness=0,
+    tail_batch_trigger_size=64,
+)
+agent_loop_manager_cfg = AgentLoopManagerConfig(
+    tasks=TaskSpecConfig(
+        task_name="train_task",
+        agent_loop_config=agent_loop_config,
+        judger_config=judger_config,
+        produce_strategy_config=produce_strategy_config,
+        sampler_config=sampler_config,
+    ),
+)
+
+# 6. eval agent loop manager
+eval_dataset = DatasetConfig(
+    name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0
+)
+eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}]
+eval_dataloader_cfg = DataloaderConfig(
+    dataset_config_list=eval_dataset_cfg,
+    pack_max_length=pack_max_length,
+    collator="fake_collator",
+    pack_level="none",
+)
+eval_sampler_config = SamplerConfig(
+    dataloader_cfg=eval_dataloader_cfg,
+    prompt_repeat_k=1,
+)
+evaluation_sample_params = SampleParams(
+    max_tokens=max_response_length,
+    top_k=1,
+    top_p=1.0,
+    temperature=0.0,
+    min_tokens=0,
+)
+eval_agent_loop_config = SingleTurnAgentLoopConfig(
+    hf_checkpoint=model_path,
+    sample_params=evaluation_sample_params,
+)
+eval_agent_loop_manager_cfg = AgentLoopManagerConfig(
+    tasks=TaskSpecConfig(
+        task_name="eval_task",
+        agent_loop_config=eval_agent_loop_config,
+        judger_config=judger_config,
+        sampler_config=eval_sampler_config,
+    ),
+)
+
+# 7. evaluator
+evaluator_config = EvaluatorConfig(compute_metric_func=None)
+
+# 8.
RL Colocate Trainer Config (the CLI obtains the Trainer via config["trainer"].build())
+trainer = RLColocateTrainerConfig(
+    resources=resources,
+    train_worker_cfg=train_worker_cfg,  # TODO: uniform naming of cfg and config
+    rollout_config=rollout_config,
+    tokenizer_path=model_path,
+    replay_buffer_config=AsyncReplayBufferConfig(),
+    agent_loop_manager_cfg=agent_loop_manager_cfg,
+    eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
+    evaluator_config=evaluator_config,
+    load_from=model_path,
+    total_train_steps=total_train_steps,
+    train_batch_size=train_batch_size,
+    advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8),
+    enable_evaluate=True,
+    enable_initial_evaluate=False,
+    evaluate_step=evaluate_step,
+    work_dir=work_dir,
+    seed=123,
+    debug_rollout=False,
+)
diff --git a/examples/v1/config/rl_grpo_gsm8k_judge.py b/examples/v1/config/rl_grpo_gsm8k_judge.py
new file mode 100644
index 0000000000..2d7b09a5d8
--- /dev/null
+++ b/examples/v1/config/rl_grpo_gsm8k_judge.py
@@ -0,0 +1,203 @@
+"""RL Colocate Trainer example config (GRPO + GSM8K).
+
+Usage: pass the paths via environment variables; the CLI then loads this
+config and runs trainer_cfg.build().fit().
+Required: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH
+Optional: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE
+"""
+import os
+from pathlib import Path
+
+from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
+from xtuner.v1.data_proto.rl_data import SampleParams
+from xtuner.v1.rl.advantage import GRPOAdvantageConfig
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig
+from xtuner.v1.model import get_model_config_from_hf
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.judger import GSM8KJudgerConfig
+from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig
+from xtuner.v1.rl.trainer import WorkerConfig
+from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig
+from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, SamplerConfig, SyncProduceStrategyConfig, TaskSpecConfig
+from xtuner.v1.rl.evaluator import EvaluatorConfig
+from xtuner.v1.rl.loss import GRPOLossConfig
+from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig
+
+# env
+work_dir = os.environ["WORK_DIR"]
+model_path = os.environ["MODEL_PATH"]
+data_path = os.environ["DATA_PATH"]
+eval_data_path = os.environ["EVAL_DATA_PATH"]
+enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
+NNODE = int(os.environ.get("WORLD_SIZE", "1"))
+
+# basic settings
+experimental_name = "grpo_gsm8k"
+total_train_steps = 45
+evaluate_step = 45
+train_optimizer_steps = 1
+train_batch_size = 64 * train_optimizer_steps
+prompt_repeat_k = 5
+rollout_tp_size = 1
+rollout_ep_size = 1
+max_prompt_length = 512
+max_response_length = 1024
+pack_max_length = 32 * 1024
+
+# 1. resources
+resources = AcceleratorResourcesConfig(
+    accelerator="GPU",
+    num_workers=8 * NNODE,
+    num_cpus_per_worker=12,
+    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
+)
+
+# 2. rollout
+rollout_config = RolloutConfig(
+    env=experimental_name,
+    device=resources.accelerator,
+    model_path=model_path,
+    dtype="bfloat16",
+    tensor_parallel_size=rollout_tp_size,
+    expert_parallel_size=rollout_ep_size,
+    gpu_memory_utilization=0.8,
+    context_length=max_response_length + max_prompt_length,
+    enable_return_routed_experts=(enable_return_routed_experts == "1"),
+)
+
+# 3.
judger +judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + +# 4. train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +# 5. train agent loop manager +train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) +train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +produce_strategy_config = SyncProduceStrategyConfig() +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ), +) + +# 6. eval agent loop manager +eval_dataset = DatasetConfig( + name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0 +) +eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, +) +eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=eval_sampler_config, + ), +) + +# 7. evaluator +evaluator_config = EvaluatorConfig(compute_metric_func=None) + +# 8. 
RL Colocate Trainer Config (the CLI obtains the Trainer via config["trainer"].build())
+trainer = RLColocateTrainerConfig(
+    resources=resources,
+    train_worker_cfg=train_worker_cfg,  # TODO: uniform naming of cfg and config
+    rollout_config=rollout_config,
+    tokenizer_path=model_path,
+    replay_buffer_config=SyncReplayBufferConfig(),
+    agent_loop_manager_cfg=agent_loop_manager_cfg,
+    eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
+    evaluator_config=evaluator_config,
+    load_from=model_path,
+    total_train_steps=total_train_steps,
+    train_batch_size=train_batch_size,
+    advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8),
+    enable_evaluate=True,
+    enable_initial_evaluate=False,
+    evaluate_step=evaluate_step,
+    work_dir=work_dir,
+    seed=123,
+    debug_rollout=False,
+)
diff --git a/examples/v1/config/rl_grpo_gsm8k_with_tool.py b/examples/v1/config/rl_grpo_gsm8k_with_tool.py
new file mode 100644
index 0000000000..fc38f332d6
--- /dev/null
+++ b/examples/v1/config/rl_grpo_gsm8k_with_tool.py
@@ -0,0 +1,224 @@
+"""RL Colocate Trainer example config (GRPO + GSM8K with tool calling).
+
+Usage: pass the paths via environment variables; the CLI then loads this
+config and runs trainer_cfg.build().fit().
+Required: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH
+Optional: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE
+"""
+import os
+from pathlib import Path
+
+from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
+from xtuner.v1.data_proto.rl_data import SampleParams
+from xtuner.v1.rl.advantage import GRPOAdvantageConfig
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig
+from xtuner.v1.model import get_model_config_from_hf
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.judger import GSM8KJudgerConfig
+from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig
+from xtuner.v1.rl.trainer import WorkerConfig
+from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, SamplerConfig, SyncProduceStrategyConfig, TaskSpecConfig
+from xtuner.v1.rl.evaluator import EvaluatorConfig
+from xtuner.v1.rl.loss import GRPOLossConfig
+from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig
+from xtuner.v1.rl.agent_loop.gsm8k_with_tool import GSM8KToolAgentLoopConfig
+
+# env
+work_dir = os.environ["WORK_DIR"]
+model_path = os.environ["MODEL_PATH"]
+data_path = os.environ["DATA_PATH"]
+eval_data_path = os.environ["EVAL_DATA_PATH"]
+enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
+NNODE = int(os.environ.get("WORLD_SIZE", "1"))
+
+# basic settings
+experimental_name = "grpo_gsm8k_with_tool"
+total_train_steps = 45
+evaluate_step = 45
+train_optimizer_steps = 1
+train_batch_size = 64 * train_optimizer_steps
+prompt_repeat_k = 5
+rollout_tp_size = 1
+rollout_ep_size = 1
+max_prompt_length = 1024
+max_response_length = 2048
+pack_max_length = 8 * 1024
+
+# 1. resources
+resources = AcceleratorResourcesConfig(
+    accelerator="GPU",
+    num_workers=8 * NNODE,
+    num_cpus_per_worker=12,
+    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
+)
+
+# 2. rollout
+rollout_config = RolloutConfig(
+    env=experimental_name,
+    device=resources.accelerator,
+    model_path=model_path,
+    dtype="bfloat16",
+    tensor_parallel_size=rollout_tp_size,
+    expert_parallel_size=rollout_ep_size,
+    gpu_memory_utilization=0.8,
+    context_length=max_response_length + max_prompt_length,
+    enable_return_routed_experts=(enable_return_routed_experts == "1"),
+)
+
+# 3.
+ +# 3. judger +judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + +# 4. train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +# 5. train agent loop manager +# Tool schema in the OpenAI function-calling format; note that "required" must sit at the same level as "properties". +gsm8k_tools = [ + { + "type": "function", + "function": { + "name": "calc_gsm8k_reward", + "description": "A tool for calculating the reward of gsm8k. (1.0 if the parsed answer is correct, 0.0 if the parsed answer is incorrect or cannot be parsed)", + "parameters": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "description": "The model's answer to the GSM8K math problem; must contain only digits", + }, + }, + "required": ["answer"], + }, + }, + } +] +train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length, tools_schema=gsm8k_tools) +train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +agent_loop_config = GSM8KToolAgentLoopConfig( + max_turns=2, + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +produce_strategy_config = SyncProduceStrategyConfig() +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ), +)
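+ +# NOTE: the eval task below mirrors the train task (same tool-calling agent loop and judger) but decodes greedily (top_k=1, temperature=0.0) and repeats each prompt only once.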
+# 6. eval agent loop manager +eval_dataset = DatasetConfig( + name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0 +) +eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, +) +eval_agent_loop_config = GSM8KToolAgentLoopConfig( + max_turns=2, + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=eval_sampler_config, + ), +) + +# 7. evaluator +evaluator_config = EvaluatorConfig(compute_metric_func=None) + +# 8. RL Colocate Trainer Config (the CLI builds the Trainer via config["trainer"].build()) +trainer = RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config + rollout_config=rollout_config, + tokenizer_path=model_path, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + total_train_steps=total_train_steps, + train_batch_size=train_batch_size, + advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), + enable_evaluate=True, + enable_initial_evaluate=False, + evaluate_step=evaluate_step, + work_dir=work_dir, + seed=123, + debug_rollout=False, +) diff --git a/examples/v1/config/rl_interns1_mini_grpo.py b/examples/v1/config/rl_interns1_mini_grpo.py deleted file mode 100644 index dd759408b7..0000000000 --- a/examples/v1/config/rl_interns1_mini_grpo.py +++ /dev/null @@ -1,206 +0,0 @@ -import os -from copy import deepcopy - -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.model.compose.intern_s1 import InternS1MiniConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.datasets import RLTokenizeFnConfig, DatasetConfig, InternS1VLTokenizeFnConfig, DataloaderConfig -from xtuner.v1.ray.judger.geo3k import GEO3KJudgerConfig -from xtuner.v1.rl.config.advantage import GRPOAdvantageConfig - - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ["EVAL_DATA_PATH"] -enable_evaluate = True if eval_data_path != "" else False -media_root = os.environ["MEDIA_ROOT"] - -# basic settings -experimental_name = "grpo_geo3k" -total_epochs = 15 -global_batch_size = 1024 -prompt_repeat_k = 5 -rollout_tp_size = 2 -rollout_ep_size = 1 -max_prompt_length = 4096 # Note: if this is not set fairly large, most of the data gets filtered out -max_response_length = 2048 -pack_max_length = 
32768 -train_optimizer_steps = 4 -hf_interval = 15 -enable_initial_evaluate = True -evaluate_step = 10 - -# grpo quick test: -# total_epochs = 3 -# global_batch_size = 64 -# prompt_repeat_k = 5 -# rollout_tp_size = 1 -# rollout_ep_size = 1 -# max_prompt_length = 512 -# max_response_length = 1024 -# pack_max_length = 32768 -# train_optimizer_steps = 1 -# hf_interval = 100 -# enable_initial_evaluate = True -# evaluate_step = 15 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.75, - context_length=max_prompt_length+max_response_length, - extra_rollout_config={ - "sglang_grammar_backend": 'none', - } - # rollout_max_batch_size_per_instance=16, # optional -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 1.0 -evaluation_sample_params.temperature = 0.0 -evaluation_sample_params.top_k = 1 - -# dataset: no changes needed -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - -tokenize_fn_cfg = InternS1VLTokenizeFnConfig(model_cfg=InternS1MiniConfig(), max_length=max_prompt_length) -train_dataset_cfg = [ - { - "dataset": DatasetConfig(name="geo3k", - anno_path=data_path, - class_name='VLMJsonlDataset', - media_root=media_root, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length, - tokenize_fn_cfg=tokenize_fn_cfg), - } -] - -eval_dataset_cfg = [] -if enable_evaluate: - eval_dataset_cfg = [ - { - "dataset": DatasetConfig(name="geo3k", - anno_path=eval_data_path, - class_name='VLMJsonlDataset', - media_root=media_root, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length, - tokenize_fn_cfg=tokenize_fn_cfg, - ignore_multimodal_info=True), - } - ] - -dataloader_config = DataloaderConfig(num_workers=8, - collator="fake_collator", - pack_level="none") - -# 3. judger -geo3k_judger_config = GEO3KJudgerConfig() -judger_cfg = JudgerConfig(reward_judger_configs=[geo3k_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - # max_concurrent=64, # optional -) - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - dataloader_cfg=dataloader_config, - tokenizer=model_path, - evaluate_step=evaluate_step, - compute_metric_func=None, - sample_params=evaluation_sample_params, -) if enable_evaluate else None - -# replay buffer config: no changes needed -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=model_path -) - -# 5. 
Train worker -# NOTE: modify model_cfg -model_cfg = InternS1MiniConfig(freeze_vision=True, freeze_projector=True) -optim_cfg = AdamWConfig(lr=1e-6, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.2, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, - advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), -) diff --git a/examples/v1/config/rl_multi_task_gsm8k_dapo_math.py b/examples/v1/config/rl_multi_task_gsm8k_dapo_math.py new file mode 100644 index 0000000000..6d8faee875 --- /dev/null +++ b/examples/v1/config/rl_multi_task_gsm8k_dapo_math.py @@ -0,0 +1,313 @@ +"""Example RL Colocate Trainer config (Multi-Task: GSM8K + DAPO Math). + +Required environment variables: + WORK_DIR + MODEL_PATH + GSM8K_DATA_PATH + GSM8K_EVAL_DATA_PATH + DAPO_DATA_PATH + DAPO_EVAL_DATA_PATH + +Optional environment variables: + WORLD_SIZE + ENABLE_RETURN_ROUTED_EXPERTS + LOSS_TYPE + LOSS_MODE + SP_SIZE + GSM8K_TASK_WEIGHT + DAPO_TASK_WEIGHT +""" + +import os +from pathlib import Path + +from transformers import AutoTokenizer + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto.rl_data import SampleParams +from xtuner.v1.rl.advantage import GRPOAdvantageConfig +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig +from xtuner.v1.rl.agent_loop_manager import ( + AgentLoopManagerConfig, + SamplerConfig, + SyncProduceStrategyConfig, + TaskSpecConfig, +) +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.judger import DapoMathJudgerConfig, GSM8KJudgerConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, get_eos_token +from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig + +work_dir = os.environ["WORK_DIR"] +model_path = os.environ["MODEL_PATH"] +gsm8k_data_path = os.environ["GSM8K_DATA_PATH"] +gsm8k_eval_data_path = os.environ["GSM8K_EVAL_DATA_PATH"] +dapo_data_path = os.environ["DAPO_DATA_PATH"] +dapo_eval_data_path = os.environ["DAPO_EVAL_DATA_PATH"] +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") +NNODE = int(os.environ.get("WORLD_SIZE", "1")) + +experimental_name = "multi_task_gsm8k_dapo_math" +total_train_steps = 50 +evaluate_step = 5 +train_optimizer_steps = 8 +train_batch_size = 128
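+# Relative task weights (assumption: the agent loop manager uses TaskSpecConfig.weight to decide how much each task contributes to each training batch). +gsm8k_task_weight = 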
float(os.environ.get("GSM8K_TASK_WEIGHT", "1.0")) +dapo_task_weight = float(os.environ.get("DAPO_TASK_WEIGHT", "1.0")) +rollout_tp_size = 1 +rollout_ep_size = 1 +gsm8k_prompt_repeat_k = 5 +dapo_prompt_repeat_k = 8 +gsm8k_max_prompt_length = 512 +dapo_max_prompt_length = 2048 +gsm8k_max_response_length = 1024 +dapo_max_response_length = 8192 +max_prompt_length = dapo_max_prompt_length +max_response_length = dapo_max_response_length +pack_max_length = 32768 + +resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8 * NNODE, + num_cpus_per_worker=12, + cpu_memory_per_worker=16 * 1024**3, +) + +rollout_config = RolloutConfig( + env=experimental_name, + device=resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=0.8, + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=(enable_return_routed_experts == "1"), + rollout_max_batch_size_per_instance=2048, +) + +eos_token_id = get_eos_token(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) +eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) +judger_config = DapoMathJudgerConfig( + judger_name="dapo_math", + num_ray_actors=1, + eos_token=eos_token_str, + enable_overlong_buffer=True, + max_response_len=max_response_length, + overlong_buffer_len=4096, + overlong_penalty_factor=1.0, + tokenizer=tokenizer, +) + +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +gsm8k_train_tokenizer_config = RLTextTokenizeFnConfig(max_length=gsm8k_max_prompt_length) +dapo_train_tokenizer_config = RLTextTokenizeFnConfig(max_length=dapo_max_prompt_length) + +gsm8k_train_sampler_config = SamplerConfig( + dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig(name="gsm8k", anno_path=gsm8k_data_path), + "tokenize_fn": gsm8k_train_tokenizer_config, + } + ], + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + ), + prompt_repeat_k=gsm8k_prompt_repeat_k, +) +dapo_train_sampler_config = SamplerConfig( + dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig(name="dapo_math", anno_path=dapo_data_path), + "tokenize_fn": dapo_train_tokenizer_config, + } + ], + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + ), + prompt_repeat_k=dapo_prompt_repeat_k, +) + +gsm8k_train_agent_loop_config = 
SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams( + max_tokens=gsm8k_max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, + ), +) +dapo_train_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams( + max_tokens=dapo_max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, + ), +) + +# The GSM8K tasks need their own judger to produce rewards; reusing the DAPO math judger would mis-grade them (assumption: a task without a judger_config cannot score its samples). +gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=[ + TaskSpecConfig( + task_name="train_task:dapo_math", + weight=dapo_task_weight, + agent_loop_config=dapo_train_agent_loop_config, + judger_config=judger_config, + produce_strategy_config=SyncProduceStrategyConfig(), + sampler_config=dapo_train_sampler_config, + ), + TaskSpecConfig( + task_name="train_task:gsm8k", + weight=gsm8k_task_weight, + agent_loop_config=gsm8k_train_agent_loop_config, + judger_config=gsm8k_judger_config, + produce_strategy_config=SyncProduceStrategyConfig(), + sampler_config=gsm8k_train_sampler_config, + ), + ], +) + +gsm8k_eval_sampler_config = SamplerConfig( + dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig(name="gsm8k_eval", anno_path=gsm8k_eval_data_path, sample_ratio=1.0), + "tokenize_fn": gsm8k_train_tokenizer_config, + } + ], + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + ), + prompt_repeat_k=1, +) +dapo_eval_sampler_config = SamplerConfig( + dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig(name="dapo_math_eval", anno_path=dapo_eval_data_path, sample_ratio=1.0), + "tokenize_fn": dapo_train_tokenizer_config, + } + ], + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + ), + prompt_repeat_k=1, +) + +gsm8k_eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams( + max_tokens=gsm8k_max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, + ), +) +dapo_eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams( + max_tokens=dapo_max_response_length, + top_k=1, + top_p=0.7, + temperature=0.0, + min_tokens=0, + ), +) + +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=[ + TaskSpecConfig( + task_name="eval_task:dapo_math", + weight=dapo_task_weight, + agent_loop_config=dapo_eval_agent_loop_config, + judger_config=judger_config, + sampler_config=dapo_eval_sampler_config, + ), + TaskSpecConfig( + task_name="eval_task:gsm8k", + weight=gsm8k_task_weight, + agent_loop_config=gsm8k_eval_agent_loop_config, + judger_config=gsm8k_judger_config, + sampler_config=gsm8k_eval_sampler_config, + ), + ], +) + + +def compute_metric(samples): + return {"accuracy": sum(sample.reward["acc"] > 0 for sample in samples) / len(samples)} + + +evaluator_config = EvaluatorConfig(compute_metric_func=compute_metric) + +trainer = RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, + rollout_config=rollout_config, + tokenizer_path=model_path, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + train_batch_size=train_batch_size, + advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), + enable_evaluate=True, + enable_initial_evaluate=False, + total_train_steps=total_train_steps, + evaluate_step=evaluate_step, + work_dir=work_dir, + seed=123, + debug_rollout=False, +) diff --git a/examples/v1/config/rl_qwen25_7B_dapo.py 
b/examples/v1/config/rl_qwen25_7B_dapo.py deleted file mode 100644 index 98e472fde3..0000000000 --- a/examples/v1/config/rl_qwen25_7B_dapo.py +++ /dev/null @@ -1,182 +0,0 @@ -import os -from copy import deepcopy -from pathlib import Path -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model import get_model_config_from_hf -from xtuner.v1.model.dense.qwen2 import Qwen2Dense7BConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.rl.config.advantage import GRPOAdvantageConfig - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ["EVAL_DATA_PATH"] -enable_evaluate = True if eval_data_path != "" else False -enbale_partial_rollout = int(os.environ.get("ENBALE_PARTIAL_ROLLOUT", "0")) - -# basic settings -experimental_name = "dapo_math" -total_epochs = 1 -global_batch_size = 512 -prompt_repeat_k = 16 -rollout_tp_size = 2 -rollout_ep_size = 1 -max_prompt_length = 2048 -max_response_length = 8192 -pack_max_length = 32768 -train_optimizer_steps = 16 -hf_interval = 50 -enable_initial_evaluate = True -evaluate_step = 5 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.8, - context_length = max_response_length + max_prompt_length, - # rollout_max_batch_size_per_instance=64, # optional, will be determined automatically if not set -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, - top_k=0, - top_p=1.0, - temperature=1.0, - min_tokens=0, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 0.7 - -# dataset -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. 
judger -from xtuner.v1.utils.rl_test_utils import get_eos_token -eos_token_id = get_eos_token(model_path) -eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) -dapomath_judger_config = DapoMathJudgerConfig( - judger_name="dapo_math", - eos_token=eos_token_str, - enable_overlong_buffer = True, - max_response_len =max_response_length, - overlong_buffer_len=4096, - overlong_penalty_factor=1.0, - tokenizer=tokenizer) -judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - enable_partial_rollout=enbale_partial_rollout, - # max_concurrent=64, # optional, will be determined automatically if not set -) - - -def dapo_compute_metric(samples): - return {"accuracy": sum(s.env.judger.reward["acc"] > 0 for s in samples) / len(samples)} - - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=dapo_compute_metric, - sample_params=evaluation_sample_params, -) if enable_evaluate else None - -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) - -# 5. Train worker -model_cfg = get_model_config_from_hf(Path(model_path)) -optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - clip_ratio_c=10.0, - log_prob_diff_min=-20.0, - log_prob_diff_max=20.0, - ), - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.0, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. 
RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, - advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), -) diff --git a/examples/v1/config/rl_qwen25_7B_dapo_async.py b/examples/v1/config/rl_qwen25_7B_dapo_async.py deleted file mode 100644 index b5e4055879..0000000000 --- a/examples/v1/config/rl_qwen25_7B_dapo_async.py +++ /dev/null @@ -1,214 +0,0 @@ -import os -from copy import deepcopy - -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model.dense.qwen2 import Qwen2Dense7BConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.rl.base.rollout_is import RolloutImportanceSampling -from xtuner.v1.model import get_model_config_from_hf -from xtuner.v1.rl.config.advantage import GRPOAdvantageConfig - - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ.get("EVAL_DATA_PATH") -enable_evaluate = True if eval_data_path != "" else False -global_batch_size = int(os.environ.get("GLOBAL_BATCH_SIZE", "16")) -enable_return_routed_experts = 0 -enbale_partial_rollout = 1 -staleness_threshold = 0.2 -tail_batch_candidate_steps = 2 -tail_batch_trigger_size = global_batch_size -max_response_length= 8192 -enable_float8_rollout = 0 - -# basic settings -experimental_name = "dapo_math" -total_epochs = 1 -prompt_repeat_k = 16 -rollout_tp_size = 1 -rollout_ep_size = 1 -max_prompt_length = 2048 -pack_max_length = 32768 -train_optimizer_steps = 16 -hf_interval = 50 -enable_initial_evaluate = True -evaluate_step = 5 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. 
rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.8, - enable_float8=enable_float8_rollout, - context_length = max_response_length + max_prompt_length, - rollout_max_batch_size_per_instance=512, - allow_over_concurrency_ratio=4, - rollout_timeout=7200.0, - enable_return_routed_experts=enable_return_routed_experts, - extra_rollout_config=dict(lmdeploy_log_level="ERROR", lmdeploy_uvicorn_log_level="ERROR"), -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, - top_k=0, - top_p=1.0, - temperature=1.0, - min_tokens=0, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 0.7 - -# dataset -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. judger -from xtuner.v1.utils.rl_test_utils import get_eos_token -eos_token_id = get_eos_token(model_path) -eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) -dapomath_judger_config = DapoMathJudgerConfig( - judger_name="dapo_math", - eos_token=eos_token_str, - enable_overlong_buffer = True, - max_response_len =max_response_length, - overlong_buffer_len=4096, - overlong_penalty_factor=1.0, - tokenizer=tokenizer) -judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - enable_partial_rollout=enbale_partial_rollout, - staleness_threshold=staleness_threshold, - tail_batch_candidate_steps=tail_batch_candidate_steps, - tail_batch_trigger_size=tail_batch_trigger_size -) - -def dapo_compute_metric(samples): - return {"accuracy": sum(s.env.judger.reward["acc"] > 0 for s in samples) / len(samples)} - - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=dapo_compute_metric, - sample_params=evaluation_sample_params, - max_concurrent=1024, -) if enable_evaluate else None - -def group_sample_filter_func(group_samples): - rewards = [d.env.judger.reward["score"] for d in group_samples] - if len(set(rewards)) == 1: - print(f"filter all same reward sample: {rewards}") - return [] - else: - return group_samples - -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, - dataloader_cfg=dataloader_config, - tokenizer=tokenizer, - # postprocessor_func=group_sample_filter_func -) - -# 5. 
Train worker -model_cfg = Qwen2Dense7BConfig() -optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - clip_ratio_c=10.0, - log_prob_diff_min=-20.0, - log_prob_diff_max=20.0, - ), - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.0, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, - rollout_is=RolloutImportanceSampling( - rollout_is_level="token", - rollout_is_mode="both", - rollout_is_threshold=(5, 0.5), - rollout_is_mask_threshold=(5, 0.5), - rollout_is_veto_threshold=(20, 0), - ), -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, - advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), -) diff --git a/examples/v1/config/rl_qwen3_30B_dapo.py b/examples/v1/config/rl_qwen3_30B_dapo.py deleted file mode 100644 index 475031478d..0000000000 --- a/examples/v1/config/rl_qwen3_30B_dapo.py +++ /dev/null @@ -1,181 +0,0 @@ -import os -from copy import deepcopy -from pathlib import Path -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.model import get_model_config_from_hf -from xtuner.v1.rl.config.advantage import GRPOAdvantageConfig - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ["EVAL_DATA_PATH"] -enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", '0') -enable_evaluate = True if eval_data_path != "" else False - -# basic settings -experimental_name = "dapo_math" -total_epochs = 1 -global_batch_size = 512 -prompt_repeat_k = 16 -rollout_tp_size = 2 -rollout_ep_size = 1 -max_prompt_length = 2048 -max_response_length = 8192 -pack_max_length = 32768 -train_optimizer_steps = 16 -hf_interval = 50 -enable_initial_evaluate = True -evaluate_step = 5 - -# 1. 
resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.8, - context_length = max_response_length + max_prompt_length, - # rollout_max_batch_size_per_instance=512, - enable_return_routed_experts=True if enable_return_routed_experts == "1" else False, -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, - top_k=0, - top_p=1.0, - temperature=1.0, - min_tokens=0, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 0.7 - -# dataset -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. judger -from xtuner.v1.utils.rl_test_utils import get_eos_token -eos_token_id = get_eos_token(model_path) -eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) -dapomath_judger_config = DapoMathJudgerConfig( - judger_name="dapo_math", - eos_token=eos_token_str, - enable_overlong_buffer = True, - max_response_len =max_response_length, - overlong_buffer_len=4096, - overlong_penalty_factor=1.0, - tokenizer=tokenizer) -judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, -) - - -def dapo_compute_metric(samples): - return {"accuracy": sum(s.env.judger.reward["acc"] > 0 for s in samples) / len(samples)} - - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=dapo_compute_metric, - sample_params=evaluation_sample_params, -) if enable_evaluate else None - -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) - -# 5. 
Train worker -model_cfg = get_model_config_from_hf(Path(model_path)) -optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - clip_ratio_c=10.0, - log_prob_diff_min=-20.0, - log_prob_diff_max=20.0, - ), - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.0, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, - advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), -) diff --git a/examples/v1/config/rl_qwen3_30B_dapo_async.py b/examples/v1/config/rl_qwen3_30B_dapo_async.py deleted file mode 100644 index bfe31ee9b4..0000000000 --- a/examples/v1/config/rl_qwen3_30B_dapo_async.py +++ /dev/null @@ -1,214 +0,0 @@ -import os -from copy import deepcopy - -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.rl.base.rollout_is import RolloutImportanceSampling -from xtuner.v1.model import get_model_config_from_hf -from xtuner.v1.rl.config.advantage import GRPOAdvantageConfig - - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ.get("EVAL_DATA_PATH") -enable_evaluate = True if eval_data_path != "" else False -global_batch_size = int(os.environ.get("GLOBAL_BATCH_SIZE", "16")) -enable_return_routed_experts = 1 -enbale_partial_rollout = 1 -staleness_threshold = 0.2 -tail_batch_candidate_steps = 2 -tail_batch_trigger_size = global_batch_size -max_response_length= 8192 -enable_float8_rollout = 0 - -# basic settings -experimental_name = "dapo_math" -total_epochs = 1 -prompt_repeat_k = 16 -rollout_tp_size = 1 -rollout_ep_size = 1 -max_prompt_length = 2048 -pack_max_length = 32768 -train_optimizer_steps = 16 -hf_interval = 50 -enable_initial_evaluate = True -evaluate_step = 5 - -# 1. 
resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.8, - enable_float8=enable_float8_rollout, - context_length = max_response_length + max_prompt_length, - rollout_max_batch_size_per_instance=512, - allow_over_concurrency_ratio=4, - rollout_timeout=7200.0, - enable_return_routed_experts=enable_return_routed_experts, - extra_rollout_config=dict(lmdeploy_log_level="ERROR", lmdeploy_uvicorn_log_level="ERROR"), -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, - top_k=0, - top_p=1.0, - temperature=1.0, - min_tokens=0, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 0.7 - -# dataset -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. judger -from xtuner.v1.utils.rl_test_utils import get_eos_token -eos_token_id = get_eos_token(model_path) -eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) -dapomath_judger_config = DapoMathJudgerConfig( - judger_name="dapo_math", - eos_token=eos_token_str, - enable_overlong_buffer = True, - max_response_len =max_response_length, - overlong_buffer_len=4096, - overlong_penalty_factor=1.0, - tokenizer=tokenizer) -judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - enable_partial_rollout=enbale_partial_rollout, - staleness_threshold=staleness_threshold, - tail_batch_candidate_steps=tail_batch_candidate_steps, - tail_batch_trigger_size=tail_batch_trigger_size -) - -def dapo_compute_metric(samples): - return {"accuracy": sum(s.env.judger.reward["acc"] > 0 for s in samples) / len(samples)} - - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=dapo_compute_metric, - sample_params=evaluation_sample_params, - max_concurrent=1024, -) if enable_evaluate else None - -def group_sample_filter_func(group_samples): - rewards = [d.env.judger.reward["score"] for d in group_samples] - if len(set(rewards)) == 1: - print(f"filter all same reward sample: {rewards}") - return [] - else: - return group_samples - -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, - dataloader_cfg=dataloader_config, - tokenizer=tokenizer, - # postprocessor_func=group_sample_filter_func -) - -# 5. 
Train worker -model_cfg = Qwen3MoE30BA3Config(freeze_routers=True) -optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.95), max_grad_norm=1.0, weight_decay=0.1, foreach=False, skip_grad_norm_threshold=0.9, eps=1e-15) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - clip_ratio_c=10.0, - log_prob_diff_min=-20.0, - log_prob_diff_max=20.0, - ), - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.0, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, - rollout_is=RolloutImportanceSampling( - rollout_is_level="token", - rollout_is_mode="both", - rollout_is_threshold=(5, 0.5), - rollout_is_mask_threshold=(5, 0.5), - rollout_is_veto_threshold=(20, 0), - ), -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, - advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), -) diff --git a/examples/v1/config/rl_qwen3_30B_grpo.py b/examples/v1/config/rl_qwen3_30B_grpo.py deleted file mode 100644 index 3396aae537..0000000000 --- a/examples/v1/config/rl_qwen3_30B_grpo.py +++ /dev/null @@ -1,176 +0,0 @@ -import os -from copy import deepcopy -from pathlib import Path -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.model import get_model_config_from_hf -from xtuner.v1.rl.config.advantage import GRPOAdvantageConfig - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ["EVAL_DATA_PATH"] -enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", '0') -enable_evaluate = True if eval_data_path != "" else False - -# basic settings -experimental_name = "grpo_gsm8k" -total_epochs = 15 -global_batch_size = 1024 -prompt_repeat_k = 5 -rollout_tp_size = 2 -rollout_ep_size = 1 -max_prompt_length = 512 -max_response_length = 1024 -pack_max_length = 32768 -train_optimizer_steps = 4 -hf_interval = 15 -enable_initial_evaluate = True -evaluate_step = 10 - -# grpo quick test 
settings -# total_epochs = 3 -# global_batch_size = 64 -# prompt_repeat_k = 5 -# rollout_tp_size = 1 -# rollout_ep_size = 1 -# max_prompt_length = 512 -# max_response_length = 1024 -# pack_max_length = 32768 -# train_optimizer_steps = 1 -# hf_interval = 100 -# enable_initial_evaluate = True -# evaluate_step = 15 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.75, - context_length = max_response_length + max_prompt_length, - enable_return_routed_experts=True if enable_return_routed_experts == "1" else False, -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 1.0 -evaluation_sample_params.temperature = 0.0 -evaluation_sample_params.top_k = 1 - -# dataset: no changes needed -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. judger -dapomath_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") -judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, -) - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=None, - sample_params=evaluation_sample_params, -) if enable_evaluate else None - -# replay buffer config: no changes needed -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) - -# 5. Train worker -# NOTE: modify model_cfg -model_cfg = get_model_config_from_hf(Path(model_path)) -optim_cfg = AdamWConfig(lr=1e-6, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.2, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. 
RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, - advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), -) diff --git a/examples/v1/config/rl_qwen3_30B_grpo_npu.py b/examples/v1/config/rl_qwen3_30B_grpo_npu.py index ebb688dbec..4824909627 100644 --- a/examples/v1/config/rl_qwen3_30B_grpo_npu.py +++ b/examples/v1/config/rl_qwen3_30B_grpo_npu.py @@ -1,33 +1,31 @@ import os from copy import deepcopy from pathlib import Path -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.advantage import GRPOAdvantageConfig +from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig +from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, SamplerConfig, SyncProduceStrategyConfig, TaskSpecConfig +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.judger import GSM8KJudgerConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig +from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig + work_dir = os.environ["WORK_DIR"] model_path = os.environ["MODEL_PATH"] data_path = os.environ["DATA_PATH"] eval_data_path = os.environ["EVAL_DATA_PATH"] -enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", '0') -enable_evaluate = True if eval_data_path != "" else False +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") +enable_evaluate = eval_data_path != "" # basic settings experimental_name = "grpo_gsm8k" @@ -45,26 +43,12 @@ enable_initial_evaluate = True evaluate_step = 10 -# grpo quick test settings -# total_epochs = 3 -# global_batch_size = 64 -# prompt_repeat_k = 5 -# rollout_tp_size = 1 -# rollout_ep_size = 1 -# max_prompt_length = 512 -# max_response_length = 1024 -# pack_max_length = 32768 -# train_optimizer_steps = 1 -# hf_interval = 100 -# enable_initial_evaluate = True -# evaluate_step = 15 - # 1. 
resources resources = AcceleratorResourcesConfig( accelerator="NPU", num_workers=16, num_cpus_per_worker=6, - cpu_memory_per_worker=16 * 1024**3, # 16 GB + cpu_memory_per_worker=16 * 1024**3, ) # 2. rollout @@ -77,59 +61,48 @@ data_parallel_size=rollout_dp_size, expert_parallel_size=rollout_ep_size, gpu_memory_utilization=0.85, - context_length = max_response_length + max_prompt_length, - enable_return_routed_experts=True if enable_return_routed_experts == "1" else False, + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=enable_return_routed_experts == "1", ) # sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) +training_sample_params = SampleParams(max_tokens=max_response_length) evaluation_sample_params = deepcopy(training_sample_params) evaluation_sample_params.top_p = 1.0 evaluation_sample_params.temperature = 0.0 evaluation_sample_params.top_k = 1 -# dataset: no changes needed -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. judger -dapomath_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") -judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, +# 3. datasets +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) +train_dataset_cfg = [ + { + "dataset": DatasetConfig(name=experimental_name, anno_path=data_path), + "tokenize_fn": tokenizer_config, + } +] +eval_dataset_cfg = [ + { + "dataset": DatasetConfig(name=experimental_name, anno_path=eval_data_path if enable_evaluate else data_path), + "tokenize_fn": tokenizer_config, + } +] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", ) - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=None, - sample_params=evaluation_sample_params, -) if enable_evaluate else None - -# replay buffer config: no changes needed -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", ) -# 5. Train worker -# NOTE: modify model_cfg +# 4. judger +judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") + +# 5. 
train worker model_cfg = get_model_config_from_hf(Path(model_path)) optim_cfg = AdamWConfig(lr=1e-6, foreach=False) loss_cfg = GRPOLossConfig( @@ -147,7 +120,7 @@ ) lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( +train_worker_cfg = WorkerConfig( model_cfg=model_cfg, load_from=model_path, optim_cfg=optim_cfg, @@ -159,18 +132,51 @@ pack_max_length=pack_max_length, ) -# 6. RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, +# 6. agent loop managers +agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=SyncProduceStrategyConfig(), + sampler_config=SamplerConfig(dataloader_cfg=dataloader_cfg, prompt_repeat_k=prompt_repeat_k), + ), +) + +eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=SamplerConfig(dataloader_cfg=eval_dataloader_cfg, prompt_repeat_k=1), + ), +) + +# 7. trainer +trainer = RLColocateTrainerConfig( resources=resources, + train_worker_cfg=train_worker_cfg, rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, tokenizer_path=model_path, - work_dir=work_dir, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=EvaluatorConfig(compute_metric_func=None), + load_from=model_path, total_epochs=total_epochs, + train_batch_size=global_batch_size, + advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), + enable_evaluate=enable_evaluate, + enable_initial_evaluate=enable_evaluate and enable_initial_evaluate, + evaluate_step=evaluate_step, + work_dir=work_dir, hf_interval=hf_interval, ) diff --git a/examples/v1/config/rl_qwen3_8B_grpo.py b/examples/v1/config/rl_qwen3_8B_grpo.py deleted file mode 100644 index d8dc3c678f..0000000000 --- a/examples/v1/config/rl_qwen3_8B_grpo.py +++ /dev/null @@ -1,184 +0,0 @@ -import os -from copy import deepcopy -from pathlib import Path -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.model import 
get_model_config_from_hf -from xtuner.v1.rl.config.advantage import GRPOAdvantageConfig - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ["EVAL_DATA_PATH"] -enable_evaluate = True if eval_data_path != "" else False -enbale_partial_rollout = int(os.environ.get("ENBALE_PARTIAL_ROLLOUT", "0")) - -# basic settings -experimental_name = "grpo_gsm8k" -total_epochs = 15 -global_batch_size = 1024 -prompt_repeat_k = 5 -rollout_tp_size = 2 -rollout_ep_size = 1 -max_prompt_length = 512 -max_response_length = 1024 -pack_max_length = 32768 -train_optimizer_steps = 4 -hf_interval = 15 -enable_initial_evaluate = True -evaluate_step = 10 -# TODO: 提供不同模型/不同输入输出长度下最优的rollout_max_batch_size_per_instance配置建议 -# NOTE: 目前Xtuner的数据流并发度由rollout_max_batch_size_per_instance控制,并且提供allow_over_concurrency_ratio来控制数据流并发度略大于推理引擎并发度, -# 具体逻辑可见 xtuner/v1/ray/dataflow/flow.py 中 max_concurrent 的计算方式 -# 当然你也可以手动调整 dataflow_config 中的 max_concurrent 参数来控制数据流并发度 -rollout_max_batch_size_per_instance = 128 - -# grpo quick test settings for rapid accuracy validation within ~30 minutes: -# - Initial eval accuracy: ~25% -# - After training: ~88% eval accuracy -# total_epochs = 3 -# global_batch_size = 64 -# prompt_repeat_k = 5 -# rollout_tp_size = 1 -# rollout_ep_size = 1 -# max_prompt_length = 512 -# max_response_length = 1024 -# pack_max_length = 32768 -# train_optimizer_steps = 1 -# hf_interval = 100 -# enable_initial_evaluate = True -# evaluate_step = 15 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.75, - context_length = max_response_length + max_prompt_length, - # rollout_max_batch_size_per_instance=rollout_max_batch_size_per_instance, # optional, will be determined automatically if not set -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 1.0 -evaluation_sample_params.temperature = 0.0 -evaluation_sample_params.top_k = 1 - -# dataset: 不需要修改 -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) - -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else [] - -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. judger -gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") -judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config]) - -# 4. 
dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - enable_partial_rollout=enbale_partial_rollout, - # max_concurrent=64, # optional, will be determined automatically if not set -) - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=None, - sample_params=evaluation_sample_params, -) if enable_evaluate else None - -# replay buffer config: : 不需要修改 -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) - -# 5. Train worker -model_cfg = get_model_config_from_hf(Path(model_path)) -optim_cfg = AdamWConfig(lr=1e-6, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.2, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, - advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), -) diff --git a/examples/v1/config/rl_qwen3_8B_grpo_tiny.py b/examples/v1/config/rl_qwen3_8B_grpo_tiny.py deleted file mode 100644 index cfa07945ba..0000000000 --- a/examples/v1/config/rl_qwen3_8B_grpo_tiny.py +++ /dev/null @@ -1,136 +0,0 @@ -import os -from copy import deepcopy -from pathlib import Path -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.model import get_model_config_from_hf -from xtuner.v1.rl.config.advantage import GRPOAdvantageConfig - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] - -# basic settings -experimental_name = "grpo_gsm8k_tiny" -total_epochs = 1 -global_batch_size = 128 
-prompt_repeat_k = 8 -rollout_tp_size = 1 -rollout_ep_size = 1 -max_prompt_length = 512 -max_response_length = 1024 -pack_max_length = 32768 -train_optimizer_steps = 1 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=1, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.75, - context_length=max_prompt_length+max_response_length, - # rollout_max_batch_size_per_instance=1024, # optional -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) - -# dataset: 不需要修改 -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length) -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] -dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") - -# 3. judger -gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") -judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, -) - -# replay buffer config: : 不需要修改 -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) - -# 5. Train worker -# NOTE: modify model_cfg -model_cfg = get_model_config_from_hf(Path(model_path)) -optim_cfg = AdamWConfig(lr=1e-6, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.2, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. 
RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), -) diff --git a/examples/v1/config/rl_qwen3_vl_8B_grpo.py b/examples/v1/config/rl_qwen3_vl_8B_grpo.py deleted file mode 100644 index 648f803cd5..0000000000 --- a/examples/v1/config/rl_qwen3_vl_8B_grpo.py +++ /dev/null @@ -1,203 +0,0 @@ -import os -from copy import deepcopy - -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.model.compose.qwen3_vl import Qwen3VLDense8BConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.evaluator import EvaluatorConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.datasets import RLTokenizeFnConfig, DatasetConfig, Qwen3VLTokenizeFnConfig, DataloaderConfig -from xtuner.v1.ray.judger.geo3k import GEO3KJudgerConfig -from xtuner.v1.rl.config.advantage import GRPOAdvantageConfig - - -work_dir = os.environ["WORK_DIR"] -model_path = os.environ["MODEL_PATH"] -data_path = os.environ["DATA_PATH"] -eval_data_path = os.environ["EVAL_DATA_PATH"] -enable_evaluate = True if eval_data_path != "" else False -media_root = os.environ["MEDIA_ROOT"] - -# basic settings -experimental_name = "grpo_geo3k" -total_epochs = 15 -global_batch_size = 1024 -prompt_repeat_k = 5 -rollout_tp_size = 2 -rollout_ep_size = 1 -max_prompt_length = 1024 -max_response_length = 2048 -pack_max_length = 32768 -train_optimizer_steps = 4 -hf_interval = 15 -enable_initial_evaluate = True -evaluate_step = 10 - -# grpo quick test settings: -# total_epochs = 3 -# global_batch_size = 64 -# prompt_repeat_k = 5 -# rollout_tp_size = 1 -# rollout_ep_size = 1 -# max_prompt_length = 512 -# max_response_length = 1024 -# pack_max_length = 32768 -# train_optimizer_steps = 1 -# hf_interval = 100 -# enable_initial_evaluate = True -# evaluate_step = 15 - -# 1. resources -resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB -) - -# 2. 
rollout -rollout_config = RolloutConfig( - env=experimental_name, - device=resources.accelerator, - model_path=model_path, - dtype="bfloat16", - tensor_parallel_size=rollout_tp_size, - expert_parallel_size=rollout_ep_size, - gpu_memory_utilization=0.75, - context_length = max_response_length + max_prompt_length, - # rollout_max_batch_size_per_instance=64, # optional, will be determined automatically if not set -) - -# sampling params -training_sample_params = SampleParams( - max_tokens=max_response_length, -) -evaluation_sample_params = deepcopy(training_sample_params) -evaluation_sample_params.top_p = 1.0 -evaluation_sample_params.temperature = 0.0 -evaluation_sample_params.top_k = 1 - -# dataset: 不需要修改 -train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) -eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - -tokenize_fn_cfg = Qwen3VLTokenizeFnConfig(processor_path=model_path, max_length=max_prompt_length) -train_dataset_cfg = [ - { - "dataset": DatasetConfig(name="geo3k", - anno_path=data_path, - class_name='VLMJsonlDataset', - media_root=media_root, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length, - tokenize_fn_cfg=tokenize_fn_cfg), - } -] - -eval_dataset_cfg = [] -if enable_evaluate: - eval_dataset_cfg = [ - { - "dataset": DatasetConfig(name="geo3k", - anno_path=eval_data_path, - class_name='VLMJsonlDataset', - media_root=media_root, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length, - tokenize_fn_cfg=tokenize_fn_cfg, - ignore_multimodal_info=True), - } - ] - -dataloader_config = DataloaderConfig(num_workers=8, - collator="fake_collator", - pack_level="none") - -# 3. judger -geo3k_judger_config = GEO3KJudgerConfig() -judger_cfg = JudgerConfig(reward_judger_configs=[geo3k_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - # max_concurrent=64, # optional, will be determined automatically if not set -) - -evaluator_cfg = EvaluatorConfig( - enable_evaluate=enable_evaluate, - enable_initial_evaluate=enable_initial_evaluate, - dataset_cfg=eval_dataset_cfg, - dataloader_cfg=dataloader_config, - tokenizer=tokenizer, - evaluate_step=evaluate_step, - compute_metric_func=None, - sample_params=evaluation_sample_params, -) if enable_evaluate else None - -# replay buffer config: : 不需要修改 -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer -) - -# 5. 
Train worker -# NOTE: modify model_cfg -model_cfg = Qwen3VLDense8BConfig(freeze_vision=True, freeze_projector=True) -optim_cfg = AdamWConfig(lr=1e-6, foreach=False) -loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.2, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, -) -lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) -fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False) -train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=model_path, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, -) - -# 6. RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, - resources=resources, - rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - evaluator_config=evaluator_cfg, - train_worker_config=train_worker_cfg, - tokenizer_path=model_path, - work_dir=work_dir, - total_epochs=total_epochs, - hf_interval=hf_interval, - advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), -) diff --git a/examples/v1/config/rl_qwen3p5_vl_35B_grpo_mixdata.py b/examples/v1/config/rl_qwen3p5_vl_35B_grpo_mixdata.py index 93d7ab26cb..ddcda369dd 100644 --- a/examples/v1/config/rl_qwen3p5_vl_35B_grpo_mixdata.py +++ b/examples/v1/config/rl_qwen3p5_vl_35B_grpo_mixdata.py @@ -1,23 +1,29 @@ +import json import os + from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -import json -from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig from xtuner.v1.data_proto.rl_data import SampleParams +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLQwen3VLTokenizeFnConfig from xtuner.v1.model import Qwen3_5_VLMoE35BA3Config -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainerConfig -from xtuner.v1.datasets import RLTokenizeFnConfig, DatasetConfig, Qwen3VLTokenizeFnConfig, DataloaderConfig -from xtuner.v1.rl.base.rollout_is import RolloutImportanceSampling +from xtuner.v1.rl.advantage import GRPOAdvantageConfig +from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig +from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, SamplerConfig, SyncProduceStrategyConfig, TaskSpecConfig +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.judger import DapoMathJudgerConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.trainer import RolloutImportanceSampling, WorkerConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, get_eos_token +from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig + + +def _as_list(value): + return value if isinstance(value, list) else [value] + work_dir = os.environ["WORK_DIR"] model_path = 
os.environ["MODEL_PATH"] @@ -41,7 +47,7 @@ accelerator="GPU", num_workers=8, num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024**3, # 16 GB + cpu_memory_per_worker=16 * 1024**3, ) # 2. rollout @@ -54,7 +60,7 @@ tensor_parallel_size=rollout_tp_size, expert_parallel_size=rollout_ep_size, gpu_memory_utilization=0.8, - context_length = max_response_length + max_prompt_length, + context_length=max_response_length + max_prompt_length, enable_return_routed_experts=True, rollout_max_batch_size_per_instance=512, ) @@ -67,64 +73,89 @@ temperature=1.0, min_tokens=0, ) +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, +) - -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - -with open(meta_data_path) as f: +# 3. datasets +with open(meta_data_path, "r", encoding="utf-8") as f: ds_collections = json.load(f) train_dataset_cfg = [] -for name, _data in ds_collections.items(): - tokenize_fn_cfg = Qwen3VLTokenizeFnConfig(processor_path=model_path, - max_length=max_prompt_length, - system_message=_data.get('system_message', None), - chat_template="qwen3-vl-rl") - _data_cfg = {"dataset": DatasetConfig(name=name, - anno_path=_data['annotation'], - media_root=_data.get('media_root', ''), - sample_ratio=_data.get('sample_ratio', 1.0), - class_name='VLMJsonlDataset'), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length, - tokenize_fn_cfg=tokenize_fn_cfg), - } - train_dataset_cfg.append(_data_cfg) - -dataloader_config = DataloaderConfig(num_workers=8, - collator="fake_collator", - pack_level="none") - -# 3. judger -from xtuner.v1.utils.rl_test_utils import get_eos_token +eval_dataset_cfg = [] +for name, data in ds_collections.items(): + annotations = _as_list(data["annotation"]) + for annotation in annotations: + train_dataset_cfg.append( + { + "dataset": DatasetConfig( + name=name, + anno_path=annotation, + media_root=data.get("media_root", ""), + sample_ratio=data.get("sample_ratio", 1.0), + class_name="VLMJsonlDataset", + ), + "tokenize_fn": RLQwen3VLTokenizeFnConfig( + processor_path=model_path, + max_length=max_prompt_length, + system_message=data.get("system_message", None), + chat_template="qwen3.5-vl", + ), + } + ) + eval_dataset_cfg.append( + { + "dataset": DatasetConfig( + name=name, + anno_path=annotation, + media_root=data.get("media_root", ""), + sample_ratio=data.get("sample_ratio", 1.0), + class_name="VLMJsonlDataset", + ), + "tokenize_fn": RLQwen3VLTokenizeFnConfig( + processor_path=model_path, + max_length=max_prompt_length, + system_message=data.get("system_message", None), + chat_template="qwen3.5-vl", + ignore_multimodal_info=True, + ), + } + ) + +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + num_workers=8, + collator="fake_collator", + pack_level="none", + pack_max_length=pack_max_length, +) +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + num_workers=8, + collator="fake_collator", + pack_level="none", + pack_max_length=pack_max_length, +) + +# 4. 
judger +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) eos_token_id = get_eos_token(model_path) eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) -dapomath_judger_config = DapoMathJudgerConfig( - judger_name="dapo_math", +judger_config = DapoMathJudgerConfig( + judger_name="dapo_math", eos_token=eos_token_str, - enable_overlong_buffer = True, - max_response_len =max_response_length, - overlong_buffer_len=4096, - overlong_penalty_factor=1.0, - tokenizer=tokenizer) -judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config]) - -# 4. dataflow and evaluator -dataflow_config = DataFlowConfig( - env=experimental_name, - prompt_repeat_k=prompt_repeat_k, - global_batch_size=global_batch_size, - sample_params=training_sample_params, - # max_concurrent=64, # optional, will be determined automatically if not set -) - - -# replay buffer config: : 不需要修改 -replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer + enable_overlong_buffer=True, + max_response_len=max_response_length, + overlong_buffer_len=4096, + overlong_penalty_factor=1.0, + tokenizer=tokenizer, ) -# 5. Train worker -# NOTE: modify model_cfg +# 5. train worker model_cfg = Qwen3_5_VLMoE35BA3Config(freeze_vision=True, freeze_projector=True) optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False) loss_cfg = GRPOLossConfig( @@ -152,7 +183,7 @@ ) lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1, fp32_lm_head=True) -train_worker_cfg: WorkerConfig = WorkerConfig( +train_worker_cfg = WorkerConfig( model_cfg=model_cfg, load_from=model_path, optim_cfg=optim_cfg, @@ -164,17 +195,51 @@ pack_max_length=pack_max_length, ) -# 6. RL Trainer -trainer = RLTrainerConfig( - load_from=model_path, +# 6. agent loop managers +agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, +) +agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=SyncProduceStrategyConfig(), + sampler_config=SamplerConfig(dataloader_cfg=dataloader_cfg, prompt_repeat_k=prompt_repeat_k), + ), +) + +eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=SamplerConfig(dataloader_cfg=eval_dataloader_cfg, prompt_repeat_k=1), + ), +) + +# 7. 
trainer +trainer = RLColocateTrainerConfig( resources=resources, + train_worker_cfg=train_worker_cfg, rollout_config=rollout_config, - dataflow_config=dataflow_config, - judger_config=judger_cfg, - replay_buffer_config=replay_buffer_cfg, - train_worker_config=train_worker_cfg, tokenizer_path=model_path, - work_dir=work_dir, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=EvaluatorConfig(compute_metric_func=None), + load_from=model_path, total_epochs=total_epochs, + train_batch_size=global_batch_size, + advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), + enable_evaluate=False, + enable_initial_evaluate=False, + evaluate_step=1, + work_dir=work_dir, hf_interval=hf_interval, ) diff --git a/examples/v1/scripts/run_rl.sh b/examples/v1/scripts/run_rl.sh index cdba439d37..d727500238 100644 --- a/examples/v1/scripts/run_rl.sh +++ b/examples/v1/scripts/run_rl.sh @@ -23,6 +23,8 @@ else ACCELERATOR_PER_NODE=${7:-8} fi +ulimit -n 65536 # OSError: [Errno 24] Too many open files + export PYTHONPATH=$(pwd):$PYTHONPATH # ray 环境变量 diff --git a/examples/v1/scripts/run_rl_submit.sh b/examples/v1/scripts/run_rl_submit.sh index 68e47d1323..5b7b645d65 100644 --- a/examples/v1/scripts/run_rl_submit.sh +++ b/examples/v1/scripts/run_rl_submit.sh @@ -21,6 +21,8 @@ else ACCELERATOR_PER_NODE=${7:-8} fi +ulimit -n 65536 # OSError: [Errno 24] Too many open files + export PYTHONPATH=$(pwd):$PYTHONPATH # NOTE: if you add new env vars, please also add them to RUNTIME_ENV_JSON in step 4. # master 节点的 IP 地址 diff --git a/recipe/claude_code/calculator_tool.py b/recipe/claude_code/calculator_tool.py new file mode 100644 index 0000000000..7c617b13b8 --- /dev/null +++ b/recipe/claude_code/calculator_tool.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import json +import re +import sys +from pathlib import Path + +from xtuner.v1.data_proto.rl_data import RolloutState +from xtuner.v1.rl.judger.native import Judger + + +CALCULATOR_TOOL_NAME = "mcp__calculator__calculator" +CALCULATOR_PROMPT = """You are NOT allowed to do arithmetic yourself. + +You MUST use the calculator tool to compute the result. +The calculator tool is named mcp__calculator__calculator. +Your first assistant response MUST be exactly one structured tool call and nothing else: + + + +23 + 19 + + + + +Question: +What is 23 + 19? + +Return only the final answer.""" + +CALCULATOR_SYSTEM_PROMPT = """You are testing an agent loop tool-calling path. +The only successful behavior is: +1. Call the mcp__calculator__calculator tool with {"expression": "23 + 19"}. +2. Read the tool result. +3. Return only the final answer as plain text: 42. + +Do not solve arithmetic directly. Do not describe a tool call in prose. +Do not generate a title. Do not use boxed answer formatting. 
+For Qwen-style tool calls, use this exact XML form: + + + +23 + 19 + + +""" + + +class CalculatorJudger(Judger): + async def judge(self, rollout_state: RolloutState) -> RolloutState: + stdout = rollout_state.extra_fields.get("claudecode_cli_stdout") or "" + answer = "" + if stdout: + try: + answer = normalize_answer(json.loads(stdout).get("result")) + except Exception: + answer = normalize_answer(stdout) + rollout_state.reward = { + "score": 1.0 if answer == "42" else 0.0, + "answer": answer, + } + return rollout_state + + +def normalize_answer(value: object) -> str: + text = "" if value is None else str(value) + text = text.strip().strip("`").strip() + boxed = re.search(r"\\boxed\{([^{}]+)\}", text) + if boxed: + return boxed.group(1).strip() + final_answer = re.search(r"final answer\s*:\s*([^\n]+)", text, flags=re.IGNORECASE) + if final_answer: + text = final_answer.group(1).strip() + return text.strip().strip("`").strip() + + +def write_calculator_mcp_server(work_dir: Path) -> tuple[Path, Path]: + mcp_server_path = work_dir / "calculator_mcp_server.py" + mcp_config_path = work_dir / "calculator_mcp_config.json" + mcp_server_path.write_text( + """ +from __future__ import annotations + +import ast +import operator + +from fastmcp import FastMCP + + +mcp = FastMCP("calculator") + +OPS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.USub: operator.neg, +} + + +def _eval(node): + if isinstance(node, ast.Expression): + return _eval(node.body) + if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)): + return node.value + if isinstance(node, ast.BinOp) and type(node.op) in OPS: + return OPS[type(node.op)](_eval(node.left), _eval(node.right)) + if isinstance(node, ast.UnaryOp) and type(node.op) in OPS: + return OPS[type(node.op)](_eval(node.operand)) + raise ValueError("Only simple arithmetic expressions are supported.") + + +@mcp.tool(name="calculator", description="Evaluate a simple arithmetic expression") +def calculator(expression: str) -> str: + value = _eval(ast.parse(expression, mode="eval")) + if isinstance(value, float) and value.is_integer(): + value = int(value) + return str(value) + + +if __name__ == "__main__": + mcp.run() +""".lstrip(), + encoding="utf-8", + ) + mcp_config_path.write_text( + json.dumps( + { + "mcpServers": { + "calculator": { + "command": sys.executable, + "args": [str(mcp_server_path)], + } + } + }, + indent=2, + ), + encoding="utf-8", + ) + return mcp_server_path, mcp_config_path diff --git a/recipe/claude_code/claudecode_agent_loop.py b/recipe/claude_code/claudecode_agent_loop.py new file mode 100644 index 0000000000..c62e113d69 --- /dev/null +++ b/recipe/claude_code/claudecode_agent_loop.py @@ -0,0 +1,350 @@ +from __future__ import annotations + +import asyncio +import copy +import os +from pathlib import Path +from typing import Any +from uuid import uuid4 + +import httpx +from pydantic import ConfigDict, Field + +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status +from xtuner.v1.rl.agent_loop.agent_loop import AgentLoop, AgentLoopConfig +from xtuner.v1.rl.judger.native import Judger +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.utils import chat_trace_records_to_rollout_states + + +DEFAULT_READONLY_INSTRUCTION = ( + "You are running inside an automated rollout collection job. 
" + "Work in read-only mode: inspect files and report findings, but do not edit, create, delete, move, " + "format, commit, push, install dependencies, or run commands that write to the repository or external services." +) + + +class ClaudeCodeAgentLoopConfig(AgentLoopConfig): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + claude_command: list[str] = Field(default_factory=lambda: ["$HOME/.local/bin/claude"]) + cwd: str | None = None + timeout_s: float = 600.0 + api_timeout_ms: int = 600000 + max_turns: int = 5 + output_format: str = "json" + permission_mode: str = "plan" + tools: str | None = "Read,Grep,Glob,LS,Bash" + allowed_tools: str | None = None + disallowed_tools: str | None = "Edit,Write,MultiEdit,NotebookEdit" + mcp_config: list[str] = Field(default_factory=list) + strict_mcp_config: bool = False + system_prompt: str | None = None + append_system_prompt: str | None = None + readonly_instruction: str = DEFAULT_READONLY_INSTRUCTION + bare: bool = True + extra_env: dict[str, str] = Field(default_factory=dict) + + def build_local( + self, + rollout_controller, + judger: Judger | None = None, + logger=None, + ) -> ClaudeCodeAgentLoop: + return ClaudeCodeAgentLoop( + claude_command=self.claude_command, + cwd=self.cwd, + timeout_s=self.timeout_s, + api_timeout_ms=self.api_timeout_ms, + max_turns=self.max_turns, + output_format=self.output_format, + permission_mode=self.permission_mode, + tools=self.tools, + allowed_tools=self.allowed_tools, + disallowed_tools=self.disallowed_tools, + mcp_config=self.mcp_config, + strict_mcp_config=self.strict_mcp_config, + system_prompt=self.system_prompt, + append_system_prompt=self.append_system_prompt, + readonly_instruction=self.readonly_instruction, + bare=self.bare, + extra_env=self.extra_env, + rollout_ctl=rollout_controller, + sample_params=self.sample_params, + hf_checkpoint=self.hf_checkpoint, + judger=judger, + logger=logger, + ) + + +class ClaudeCodeAgentLoop(AgentLoop): + def __init__( + self, + claude_command: list[str], + cwd: str | None, + timeout_s: float, + api_timeout_ms: int, + max_turns: int, + output_format: str, + permission_mode: str, + tools: str | None, + allowed_tools: str | None, + disallowed_tools: str | None, + mcp_config: list[str], + strict_mcp_config: bool, + system_prompt: str | None, + append_system_prompt: str | None, + readonly_instruction: str, + bare: bool, + extra_env: dict[str, str], + rollout_ctl: RolloutController, + sample_params: SampleParams, + hf_checkpoint: str, + judger: Judger | None = None, + logger=None, + ) -> None: + super().__init__( + rollout_ctl=rollout_ctl, + sample_params=sample_params, + hf_checkpoint=hf_checkpoint, + judger=judger, + logger=logger, + ) + self.claude_command = claude_command + self.cwd = cwd + self.timeout_s = timeout_s + self.api_timeout_ms = api_timeout_ms + self.max_turns = max_turns + self.output_format = output_format + self.permission_mode = permission_mode + self.tools = tools + self.allowed_tools = allowed_tools + self.disallowed_tools = disallowed_tools + self.mcp_config = mcp_config + self.strict_mcp_config = strict_mcp_config + self.system_prompt = system_prompt + self.append_system_prompt = append_system_prompt + self.readonly_instruction = readonly_instruction + self.bare = bare + self.extra_env = extra_env + + async def generate_sample( # type: ignore[override] + self, rollout_state: RolloutState, **kwargs + ) -> list[RolloutState]: + try: + metadata = await self.rollout_ctl.get_rollout_metadata.remote() # type: ignore[attr-defined] + 
gateway_url = metadata.get("api_server_url") + rollout_config = metadata.get("rollout_config") + model_name = getattr(rollout_config, "model_name", None) or "rollout-controller" + if not gateway_url: + return [ + self._failed_state( + rollout_state, + "Gateway is not started. Configure GatewayConfig(auto_start=True) " + "before using ClaudeCodeAgentLoop.", + ) + ] + + api_key = f"claudecode_{uuid4().hex}" + command = self._build_command(rollout_state, model_name=model_name) + returncode, stdout, stderr = await self._run_claude(command, gateway_url, model_name, api_key) + records = await self._pop_trace_store_records(gateway_url, api_key) + rollout_extra_fields = { + "claudecode_api_key": api_key, + "claudecode_cli_returncode": returncode, + "claudecode_cli_stdout": self._truncate(stdout), + "claudecode_cli_stderr": self._truncate(stderr), + } + + if not records: + reason = "Claude Code finished without trace store records for this api_key." + if returncode != 0: + reason += f" returncode={returncode}, stderr={self._truncate(stderr)}" + return [self._failed_state(rollout_state, reason, extra_fields=rollout_extra_fields)] + + reward = None + if self.judger is not None: + judge_state = rollout_state.model_copy(deep=True) + judge_state.extra_fields = { + **copy.deepcopy(rollout_state.extra_fields), + **copy.deepcopy(rollout_extra_fields), + } + judged_state = await self.judger.judge(judge_state) + if judged_state.reward is None: + return [ + self._failed_state( + rollout_state, + "Judger completed without setting reward.", + extra_fields=rollout_extra_fields, + ) + ] + reward = copy.deepcopy(judged_state.reward) + + states = chat_trace_records_to_rollout_states( + rollout_state=rollout_state, + records=records, + tokenizer=self.tokenizer, + extra_fields=rollout_extra_fields, + ) + if not states: + return [ + self._failed_state( + rollout_state, + "Gateway trace records did not contain trainable turns.", + extra_fields=rollout_extra_fields, + ) + ] + + completed_states = [state for state in states if state.status == Status.COMPLETED] + if reward is not None: + for state in completed_states: + state.reward = copy.deepcopy(reward) + return states + except Exception as exc: + return [self._failed_state(rollout_state, f"ClaudeCodeAgentLoop failed: {exc}")] + + def _build_command(self, rollout_state: RolloutState, *, model_name: str) -> list[str]: + command = [os.path.expandvars(os.path.expanduser(part)) for part in self.claude_command] + prompt = self._build_prompt(rollout_state) + if self.bare: + command.append("--bare") + if self.system_prompt: + command.extend(["--system-prompt", self.system_prompt]) + if self.append_system_prompt: + command.extend(["--append-system-prompt", self.append_system_prompt]) + for config in self.mcp_config: + command.extend(["--mcp-config", os.path.expandvars(os.path.expanduser(config))]) + if self.strict_mcp_config: + command.append("--strict-mcp-config") + command.extend( + [ + "-p", + prompt, + "--output-format", + self.output_format, + "--permission-mode", + self.permission_mode, + "--model", + model_name, + "--max-turns", + str(self.max_turns), + "--no-session-persistence", + ] + ) + if self.tools is not None: + command.extend(["--tools", self.tools]) + if self.allowed_tools: + command.extend(["--allowedTools", self.allowed_tools]) + if self.disallowed_tools: + command.extend(["--disallowedTools", self.disallowed_tools]) + return command + + def _build_prompt(self, rollout_state: RolloutState) -> str: + content = "" + for message in 
reversed(rollout_state.message): + if message.get("role") == "user": + content = self._message_content_to_text(message.get("content")) + break + if not content and rollout_state.message: + content = self._message_content_to_text(rollout_state.message[-1].get("content")) + if not self.readonly_instruction: + return content + return f"{self.readonly_instruction}\n\nTask:\n{content}" + + def _message_content_to_text(self, content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict): + if "text" in item: + parts.append(str(item["text"])) + elif item.get("type") == "text": + parts.append(str(item.get("text", ""))) + else: + parts.append(str(item)) + else: + parts.append(str(item)) + return "\n".join(part for part in parts if part) + return str(content) + + async def _run_claude( + self, + command: list[str], + gateway_url: str, + model_name: str, + api_key: str, + ) -> tuple[int, str, str]: + env = os.environ.copy() + env.update( + { + "ANTHROPIC_BASE_URL": gateway_url, + "ANTHROPIC_AUTH_TOKEN": api_key, + "ANTHROPIC_API_KEY": api_key, + "ANTHROPIC_MODEL": model_name, + "API_TIMEOUT_MS": str(self.api_timeout_ms), + "PATH": f"{Path.home() / '.local' / 'bin'}:{env.get('PATH', '')}", + } + ) + env.update({key: os.path.expandvars(os.path.expanduser(value)) for key, value in self.extra_env.items()}) + + process = await asyncio.create_subprocess_exec( + *command, + cwd=str(Path(self.cwd or os.getcwd()).resolve()), + env=env, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for(process.communicate(), timeout=self.timeout_s) + except asyncio.TimeoutError: + process.kill() + stdout_bytes, stderr_bytes = await process.communicate() + returncode = process.returncode if process.returncode is not None else -9 + stderr = stderr_bytes.decode("utf-8", errors="replace") + stderr = f"Claude Code timed out after {self.timeout_s}s.\n{stderr}" + return returncode, stdout_bytes.decode("utf-8", errors="replace"), stderr + + return ( + process.returncode if process.returncode is not None else 0, + stdout_bytes.decode("utf-8", errors="replace"), + stderr_bytes.decode("utf-8", errors="replace"), + ) + + async def _pop_trace_store_records(self, gateway_url: str, api_key: str) -> list[dict[str, Any]]: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post( + f"{gateway_url.rstrip('/')}/trace_store/pop", + headers={"Authorization": f"Bearer {api_key}"}, + ) + response.raise_for_status() + payload = response.json() + records = payload.get("records", []) + if not isinstance(records, list): + return [] + return records + + def _failed_state( + self, + rollout_state: RolloutState, + error_msg: str, + *, + extra_fields: dict[str, Any] | None = None, + ) -> RolloutState: + failed = rollout_state.model_copy(deep=True) + failed.status = Status.FAILED + failed.error_msg = error_msg + if extra_fields: + failed.extra_fields = { + **copy.deepcopy(rollout_state.extra_fields), + **copy.deepcopy(extra_fields), + } + return failed + + def _truncate(self, text: str, max_chars: int = 4096) -> str: + if len(text) <= max_chars: + return text + return text[:max_chars] + "..." 
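A minimal driver sketch (not part of the patch) for exercising ClaudeCodeAgentLoop in isolation. It assumes a RolloutController actor named `controller` whose Gateway has already been started via `start_gateway` (the test module below shows the full setup); the checkpoint path is a placeholder, and `task_name` mirrors the test.

import asyncio

from calculator_tool import CALCULATOR_PROMPT, CalculatorJudger
from claudecode_agent_loop import ClaudeCodeAgentLoopConfig
from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status

# `controller`: a RolloutController actor with an active Gateway (assumed here).
cfg = ClaudeCodeAgentLoopConfig(
    hf_checkpoint="/path/to/model",  # placeholder checkpoint path
    sample_params=SampleParams(max_tokens=1024, temperature=0.0),
    permission_mode="bypassPermissions",
    readonly_instruction="",  # drop the read-only preamble for tool-call tasks
)
loop = cfg.build_local(rollout_controller=controller, judger=CalculatorJudger())

state = RolloutState(
    message=[{"role": "user", "content": CALCULATOR_PROMPT}],
    task_name="calculator_tool_call",
)
states = asyncio.run(loop.generate_sample(state))
print(sum(s.status == Status.COMPLETED for s in states), states[-1].reward)
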
diff --git a/recipe/claude_code/run_claudecode_tool_e2e.sh b/recipe/claude_code/run_claudecode_tool_e2e.sh new file mode 100755 index 0000000000..8d2dd87774 --- /dev/null +++ b/recipe/claude_code/run_claudecode_tool_e2e.sh @@ -0,0 +1,52 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" + +XTUNER_USE_LMDEPLOY="${XTUNER_USE_LMDEPLOY:-1}" +XTUNER_CLAUDECODE_TOOL_MAX_TURNS="${XTUNER_CLAUDECODE_TOOL_MAX_TURNS:-4}" +XTUNER_CLAUDECODE_TOOL_MAX_TOKENS="${XTUNER_CLAUDECODE_TOOL_MAX_TOKENS:-1024}" +XTUNER_CLAUDECODE_TOOL_CONTEXT_LENGTH="${XTUNER_CLAUDECODE_TOOL_CONTEXT_LENGTH:-32768}" +XTUNER_CLAUDECODE_TOOL_TIMEOUT_S="${XTUNER_CLAUDECODE_TOOL_TIMEOUT_S:-600}" +XTUNER_CLAUDECODE_TOOL_OUTPUT_DIR="${XTUNER_CLAUDECODE_TOOL_OUTPUT_DIR:-/tmp/xtuner_claudecode_tool_e2e}" +XTUNER_CLAUDECODE_TOOL_CALL_PARSER="${XTUNER_CLAUDECODE_TOOL_CALL_PARSER:-qwen3p5}" +XTUNER_CLAUDECODE_REASONING_PARSER="${XTUNER_CLAUDECODE_REASONING_PARSER:-qwen3}" + +if [[ -z "${ROLLOUT_MODEL_PATH:-}" ]]; then + echo "ROLLOUT_MODEL_PATH must be set." >&2 + exit 1 +fi + +if [[ ! -e "${ROLLOUT_MODEL_PATH}" ]]; then + echo "ROLLOUT_MODEL_PATH does not exist: ${ROLLOUT_MODEL_PATH}" >&2 + exit 1 +fi + +mkdir -p "${XTUNER_CLAUDECODE_TOOL_OUTPUT_DIR}" + +export PYTHONPATH="${REPO_ROOT}:${PYTHONPATH:-}" + +export ROLLOUT_MODEL_PATH +export XTUNER_USE_LMDEPLOY +export XTUNER_CLAUDECODE_TOOL_MAX_TURNS +export XTUNER_CLAUDECODE_TOOL_MAX_TOKENS +export XTUNER_CLAUDECODE_TOOL_CONTEXT_LENGTH +export XTUNER_CLAUDECODE_TOOL_TIMEOUT_S +export XTUNER_CLAUDECODE_TOOL_OUTPUT_DIR +export XTUNER_CLAUDECODE_TOOL_CALL_PARSER +export XTUNER_CLAUDECODE_REASONING_PARSER + +cd "${REPO_ROOT}" + +echo "Running Claude Code calculator tool E2E" +echo " repo root: ${REPO_ROOT}" +echo " model: ${ROLLOUT_MODEL_PATH}" +echo " output: ${XTUNER_CLAUDECODE_TOOL_OUTPUT_DIR}" +echo " max_turns: ${XTUNER_CLAUDECODE_TOOL_MAX_TURNS}" +echo " max_tokens: ${XTUNER_CLAUDECODE_TOOL_MAX_TOKENS}" +echo " context_length: ${XTUNER_CLAUDECODE_TOOL_CONTEXT_LENGTH}" +echo " tool_call_parser: ${XTUNER_CLAUDECODE_TOOL_CALL_PARSER}" +echo " reasoning_parser: ${XTUNER_CLAUDECODE_REASONING_PARSER}" + +python "${SCRIPT_DIR}/test_claude_code_with_calculator.py" diff --git a/recipe/claude_code/test_claude_code_with_calculator.py b/recipe/claude_code/test_claude_code_with_calculator.py new file mode 100644 index 0000000000..c894753e40 --- /dev/null +++ b/recipe/claude_code/test_claude_code_with_calculator.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +import asyncio +import json +import os +import tempfile +from pathlib import Path +from typing import Any +from uuid import uuid4 + +from calculator_tool import ( + CalculatorJudger, + CALCULATOR_PROMPT, + CALCULATOR_SYSTEM_PROMPT, + CALCULATOR_TOOL_NAME, + normalize_answer, + write_calculator_mcp_server, +) +from claudecode_agent_loop import ClaudeCodeAgentLoopConfig +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status +from xtuner.v1.rl.gateway import wait_for_gateway_ready +from xtuner.v1.rl.utils import find_free_ports + + +RESOURCE_MAP = { + "npu": "NPU", + "cuda": "GPU", +} + + +async def test_claude_code_with_calculator(model_path: str) -> list[RolloutState]: + import ray + + os.environ.setdefault("XTUNER_USE_FA3", "1") + os.environ.setdefault("LMD_SKIP_WARMUP", "1") + os.environ.pop("RAY_ADDRESS", None) + ray.init(address="local", ignore_reinit_error=True) + + temp_dir = tempfile.TemporaryDirectory() + 
work_dir = Path(temp_dir.name) + worker_log_dir = work_dir / "work_dirs" + output_dir = Path(os.environ.get("XTUNER_CLAUDECODE_TOOL_OUTPUT_DIR", work_dir / "outputs")) + output_dir.mkdir(parents=True, exist_ok=True) + controller = None + placement_group = None + + try: + _, mcp_config_path = write_calculator_mcp_server(work_dir) + gateway_url, controller, placement_group = _start_rollout_controller_and_gateway( + model_path=model_path, + worker_log_dir=worker_log_dir, + ) + cfg = ClaudeCodeAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams( + max_tokens=int(os.environ.get("XTUNER_CLAUDECODE_TOOL_MAX_TOKENS", "1024")), + temperature=0.0, + ), + claude_command=[os.environ.get("XTUNER_CLAUDE_BIN", str(Path.home() / ".local" / "bin" / "claude"))], + cwd=str(work_dir), + timeout_s=float(os.environ.get("XTUNER_CLAUDECODE_TOOL_TIMEOUT_S", "600")), + api_timeout_ms=int(os.environ.get("XTUNER_CLAUDECODE_TOOL_API_TIMEOUT_MS", "600000")), + max_turns=int(os.environ.get("XTUNER_CLAUDECODE_TOOL_MAX_TURNS", "4")), + output_format="json", + permission_mode=os.environ.get("XTUNER_CLAUDECODE_PERMISSION_MODE", "bypassPermissions"), + tools=None, + allowed_tools=CALCULATOR_TOOL_NAME, + disallowed_tools="Bash,Edit,Read,Grep,Glob,LS,WebFetch,WebSearch", + mcp_config=[str(mcp_config_path)], + strict_mcp_config=True, + system_prompt=CALCULATOR_SYSTEM_PROMPT, + readonly_instruction="", + ) + agent_loop = cfg.build(rollout_controller=controller, judger=CalculatorJudger()) + rollout_state = RolloutState( + message=[{"role": "user", "content": CALCULATOR_PROMPT}], + task_name="calculator_tool_call", + extra_fields={"gateway_url": gateway_url}, + ) + + states = await agent_loop.generate_sample(rollout_state) + _dump_rollout_states(output_dir, "calculator", states) + failed = [state for state in states if state.status == Status.FAILED] + if failed: + raise AssertionError(f"ClaudeCodeAgentLoop returned failed states: {[state.error_msg for state in failed]}") + + completed = [state for state in states if state.status == Status.COMPLETED] + if len(completed) < 2: + raise AssertionError("Expected one tool-call turn and one final-answer turn.") + + api_keys = {state.extra_fields["claudecode_api_key"] for state in completed} + if len(api_keys) != 1: + raise AssertionError(f"Expected one Claude Code api key, got {api_keys}.") + for index, state in enumerate(completed): + if state.prompt_ids is None: + raise AssertionError(f"State {index} is missing prompt_ids.") + if state.tokens != state.prompt_ids: + raise AssertionError(f"State {index} tokens must equal prompt_ids.") + if not state.response_ids: + raise AssertionError(f"State {index} is missing response_ids.") + if state.logprobs is None: + raise AssertionError(f"State {index} is missing logprobs.") + if len(state.logprobs) != len(state.response_ids): + raise AssertionError(f"State {index} logprobs length does not match response_ids length.") + if state.response_mask != [1] * len(state.response_ids): + raise AssertionError(f"State {index} response_mask does not match response_ids length.") + if state.response is None: + raise AssertionError(f"State {index} is missing response text.") + if "gateway_trace_records" not in state.extra_fields: + raise AssertionError(f"State {index} is missing gateway_trace_records.") + if state.extra_fields["gateway_trace_count"] != len(states): + raise AssertionError(f"State {index} gateway_trace_count does not match trace count.") + if state.extra_fields["claudecode_cli_returncode"] != 0: + raise AssertionError(f"State 
{index} Claude Code returncode is not 0.") + + tool_blocks = [] + for state in completed: + snapshot = state.extra_fields.get("gateway_response_snapshot") or {} + for block in snapshot.get("content") or []: + if block.get("type") == "tool_use": + tool_blocks.append(block) + if not tool_blocks: + raise AssertionError("Expected at least one calculator tool_use block.") + calculator_blocks = [block for block in tool_blocks if "calculator" in str(block.get("name"))] + if not calculator_blocks: + raise AssertionError(f"Expected calculator tool call, got: {tool_blocks}") + if calculator_blocks[0].get("input", {}).get("expression") != "23 + 19": + raise AssertionError(f"Unexpected calculator input: {calculator_blocks[0].get('input')}") + + final_answer = normalize_answer(completed[-1].response) + if final_answer != "42": + raise AssertionError(f"Expected final answer 42, got {final_answer!r}.") + if completed[-1].reward != {"score": 1.0, "answer": "42"}: + raise AssertionError(f"Expected reward score 1.0 for answer 42, got {completed[-1].reward!r}.") + return states + finally: + _cleanup_ray(controller=controller, placement_group=placement_group) + temp_dir.cleanup() + + +def _start_rollout_controller_and_gateway( + *, + model_path: str, + worker_log_dir: Path, +) -> tuple[str, Any, Any]: + import ray + import torch + + from xtuner.v1.rl.gateway.config import GatewayConfig + from xtuner.v1.rl.rollout.worker import RolloutConfig + from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers + + accelerator = RESOURCE_MAP[torch.accelerator.current_accelerator().type] + tensor_parallel_size = int(os.environ.get("XTUNER_CLAUDECODE_TOOL_TP", "1")) + num_workers = int(os.environ.get("XTUNER_CLAUDECODE_TOOL_NUM_WORKERS", str(tensor_parallel_size))) + resource_config = AcceleratorResourcesConfig( + accelerator=accelerator, + num_workers=num_workers, + num_cpus_per_worker=int(os.environ.get("XTUNER_CLAUDECODE_TOOL_CPUS_PER_WORKER", "8")), + cpu_memory_per_worker=int(os.environ.get("XTUNER_CLAUDECODE_TOOL_CPU_MEMORY", str(16 * 1024**3))), + ) + placement_group = AutoAcceleratorWorkers.build_placement_group( + resource_config, + name=f"claudecode_tool_pg_{uuid4().hex[:8]}", + ) + rollout_config = RolloutConfig( + env=f"claudecode_tool_{uuid4().hex[:8]}", + model_path=model_path, + model_name=os.path.basename(model_path).lower(), + tokenizer_path=model_path, + context_length=int(os.environ.get("XTUNER_CLAUDECODE_TOOL_CONTEXT_LENGTH", "32768")), + worker_log_dir=worker_log_dir / "rollout", + tensor_parallel_size=tensor_parallel_size, + expert_parallel_size=1, + dist_port_base=int( + os.environ.get("XTUNER_CLAUDECODE_TOOL_DIST_PORT_BASE", str(find_free_ports(nums=8, contiguous=True)[0])) + ), + tool_call_parser=os.environ.get("XTUNER_CLAUDECODE_TOOL_CALL_PARSER", "qwen3p5"), + reasoning_parser=os.environ.get("XTUNER_CLAUDECODE_REASONING_PARSER", "qwen3"), + api_host="127.0.0.1", + api_port=find_free_ports()[0], + ) + controller = rollout_config.build(placement_group) + gateway_host = ray.util.get_node_ip_address() + gateway_config = GatewayConfig( + host=gateway_host, + port=find_free_ports(host=gateway_host)[0], + capture_folder=str(worker_log_dir / "gateway_captures"), + ) + gateway_url = ray.get(controller.start_gateway.remote(gateway_config), timeout=1800) + wait_for_gateway_ready(gateway_url) + return gateway_url, controller, placement_group + + +def _dump_rollout_states(output_dir: Path, case_name: str, states: list[RolloutState]) -> None: + output_path = output_dir / 
f"rollout-states-{case_name}.json" + payload = [_redact_rollout_state_for_dump(state) for state in states] + output_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + print(f"Claude Code rollout states written to {output_path}") + + +def _redact_rollout_state_for_dump(state: RolloutState) -> dict: + payload = state.model_dump(mode="json") + for key in ("prompt_ids", "response_ids"): + payload.pop(key, None) + extra_fields = payload.get("extra_fields") + if isinstance(extra_fields, dict): + records = extra_fields.get("gateway_trace_records") + if isinstance(records, list): + for record in records: + if isinstance(record, dict): + record.pop("prompt_ids", None) + record.pop("response_ids", None) + record.pop("tokens", None) + return payload + + +def _cleanup_ray(*, controller: Any, placement_group: Any) -> None: + import ray + + if controller is not None: + try: + ray.get(controller.shutdown.remote(), timeout=300) + except Exception: + pass + try: + ray.kill(controller, no_restart=True) + except Exception: + pass + if placement_group is not None: + ray.util.remove_placement_group(placement_group) + if ray.is_initialized(): + ray.shutdown() + + +async def _main_async() -> None: + states = await test_claude_code_with_calculator(model_path=os.environ["ROLLOUT_MODEL_PATH"]) + completed_count = sum(state.status == Status.COMPLETED for state in states) + print(f"Claude Code calculator tool E2E passed with {completed_count} completed rollout states.") + + +def main() -> None: + asyncio.run(_main_async()) + + +if __name__ == "__main__": + main() diff --git a/examples/v1/config/agent_rl_interns1_pro_mini_grpo.py b/recipe/lagent/agent_rl_interns1_pro_mini_grpo.py similarity index 100% rename from examples/v1/config/agent_rl_interns1_pro_mini_grpo.py rename to recipe/lagent/agent_rl_interns1_pro_mini_grpo.py diff --git a/xtuner/v1/train/agent_rl_trainer.py b/recipe/lagent/agent_rl_trainer.py similarity index 100% rename from xtuner/v1/train/agent_rl_trainer.py rename to recipe/lagent/agent_rl_trainer.py diff --git a/xtuner/v1/ray/environment/agent_env.py b/recipe/lagent/environment/agent_env.py similarity index 100% rename from xtuner/v1/ray/environment/agent_env.py rename to recipe/lagent/environment/agent_env.py diff --git a/xtuner/v1/ray/environment/composed_env.py b/recipe/lagent/environment/composed_env.py similarity index 100% rename from xtuner/v1/ray/environment/composed_env.py rename to recipe/lagent/environment/composed_env.py diff --git a/xtuner/v1/ray/environment/lagent/__init__.py b/recipe/lagent/environment/lagent/__init__.py similarity index 100% rename from xtuner/v1/ray/environment/lagent/__init__.py rename to recipe/lagent/environment/lagent/__init__.py diff --git a/xtuner/v1/ray/environment/lagent/agents/__init__.py b/recipe/lagent/environment/lagent/agents/__init__.py similarity index 100% rename from xtuner/v1/ray/environment/lagent/agents/__init__.py rename to recipe/lagent/environment/lagent/agents/__init__.py diff --git a/xtuner/v1/ray/environment/lagent/agents/env_agent.py b/recipe/lagent/environment/lagent/agents/env_agent.py similarity index 100% rename from xtuner/v1/ray/environment/lagent/agents/env_agent.py rename to recipe/lagent/environment/lagent/agents/env_agent.py diff --git a/xtuner/v1/ray/environment/lagent/agents/jugder_wrapper.py b/recipe/lagent/environment/lagent/agents/jugder_wrapper.py similarity index 100% rename from xtuner/v1/ray/environment/lagent/agents/jugder_wrapper.py rename to 
recipe/lagent/environment/lagent/agents/jugder_wrapper.py diff --git a/xtuner/v1/ray/environment/lagent/agents/tito_agent.py b/recipe/lagent/environment/lagent/agents/tito_agent.py similarity index 100% rename from xtuner/v1/ray/environment/lagent/agents/tito_agent.py rename to recipe/lagent/environment/lagent/agents/tito_agent.py diff --git a/xtuner/v1/ray/environment/lagent/llms/__init__.py b/recipe/lagent/environment/lagent/llms/__init__.py similarity index 100% rename from xtuner/v1/ray/environment/lagent/llms/__init__.py rename to recipe/lagent/environment/lagent/llms/__init__.py diff --git a/xtuner/v1/ray/environment/lagent/llms/controller_wrapper.py b/recipe/lagent/environment/lagent/llms/controller_wrapper.py similarity index 100% rename from xtuner/v1/ray/environment/lagent/llms/controller_wrapper.py rename to recipe/lagent/environment/lagent/llms/controller_wrapper.py diff --git a/xtuner/v1/ray/environment/lagent/parsers.py b/recipe/lagent/environment/lagent/parsers.py similarity index 100% rename from xtuner/v1/ray/environment/lagent/parsers.py rename to recipe/lagent/environment/lagent/parsers.py diff --git a/xtuner/v1/ray/environment/lagent/schema.py b/recipe/lagent/environment/lagent/schema.py similarity index 100% rename from xtuner/v1/ray/environment/lagent/schema.py rename to recipe/lagent/environment/lagent/schema.py diff --git a/xtuner/v1/ray/environment/lagent/tokenize.py b/recipe/lagent/environment/lagent/tokenize.py similarity index 100% rename from xtuner/v1/ray/environment/lagent/tokenize.py rename to recipe/lagent/environment/lagent/tokenize.py diff --git a/tests/ray/test_auto.py b/recipe/verl_agent/__init__.py similarity index 100% rename from tests/ray/test_auto.py rename to recipe/verl_agent/__init__.py diff --git a/recipe/verl_agent/common/__init__.py b/recipe/verl_agent/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/recipe/verl_agent/common/agent_loop_verl_tool.py b/recipe/verl_agent/common/agent_loop_verl_tool.py new file mode 100644 index 0000000000..b607ff7f04 --- /dev/null +++ b/recipe/verl_agent/common/agent_loop_verl_tool.py @@ -0,0 +1,154 @@ +from typing import Any, Optional + +from omegaconf import DictConfig +from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, DictConfigWrap +from verl.experimental.agent_loop.tool_agent_loop import ToolAgentLoop +from verl.utils.dataset.rl_dataset import get_dataset_class +from verl.workers.rollout.replica import TokenOutput + +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status +from xtuner.v1.rl.judger import Judger +from xtuner.v1.rl.rollout.controller import RolloutControllerProxy +from xtuner.v1.rl.agent_loop import AgentLoop, AgentLoopConfig + + +class VerlToolAgentLoopConfig(AgentLoopConfig): + config: DictConfig + + def build( + self, + rollout_controller: RolloutControllerProxy, + judger: Judger | None = None, + logger=None, + ) -> "VerlToolAgentLoop": + verl_tool_agent_loop = VerlToolAgentLoop( + rollout_controller=rollout_controller, + sample_params=self.sample_params, + hf_checkpoint=self.hf_checkpoint, + config=self.config, + judger=judger, + ) + return verl_tool_agent_loop + + +class XtunerAsyncLLMServerManager: + def __init__(self, rollout_controller: RolloutControllerProxy): + self.rollout_controller = rollout_controller + + async def generate( + self, + request_id: str, + *, + prompt_ids: list[int], + sampling_params: dict[str, Any], + image_data: Optional[list[Any]] = None, + video_data: Optional[list[Any]] = None, + ) -> 
TokenOutput: + sample_params = SampleParams( + return_token_ids=True, + temperature=sampling_params.get("temperature", 1.0), + top_p=sampling_params.get("top_p", 1.0), + top_k=sampling_params.get("top_k", 0), + repetition_penalty=sampling_params.get("repetition_penalty", 1.0), + return_logprob=bool(sampling_params.get("logprobs", True)), + ) + + # session_id is set in the VerlToolAgentLoop.generate_sample + # and ignore request_id generated by verl.ToolAgentLoop.run + session_uid = sampling_params.get("session_uid", -1) + + rollout_state = RolloutState( + message=[], + tokens=prompt_ids, + session_uid=session_uid, + sample_params=sample_params, + ) + + response: RolloutState = await self.rollout_controller.generate.remote( + rollout_state=rollout_state, + ) + + finish_reason = response.finish_reason + + return TokenOutput( + token_ids=response.response_ids or [], + log_probs=response.logprobs, + routed_experts=response.routed_experts, + stop_reason=finish_reason, + ) + + +class VerlToolAgentLoop(AgentLoop): + def __init__( + self, + rollout_controller: RolloutControllerProxy, + sample_params: SampleParams, + hf_checkpoint: str, + config: DictConfig, + judger: Judger | None = None, + logger=None, + ): + super().__init__(rollout_controller, sample_params, hf_checkpoint, judger, logger) + + server_manager = XtunerAsyncLLMServerManager(rollout_controller) + + dataset_cls = get_dataset_class(config.data) + + self.verl_tool_agent_loop = ToolAgentLoop( + trainer_config=DictConfigWrap(config=config), + server_manager=server_manager, + tokenizer=self.tokenizer, + processor=self.processor, + dataset_cls=dataset_cls, + data_config=DictConfigWrap(config.data), + ) + + async def generate_sample(self, rollout_state: RolloutState) -> RolloutState: + assert rollout_state.sample_params is not None, "sample_params must be set in rollout_state" + + # convert rollout_state to verl_tool_agent_loop input + sp = rollout_state.sample_params + sampling_params = dict( + temperature=sp.temperature, + top_p=sp.top_p, + top_k=sp.top_k, + repetition_penalty=sp.repetition_penalty, + logprobs=sp.return_logprob, + # session_id is used to identify the session in the server manager + session_uid=rollout_state.session_uid, + ) + + input_kwargs = { + "raw_prompt": rollout_state.message, + "tools_kwargs": rollout_state.extra_fields.get("tools_kwargs", {}), + } + + # run verl_tool_agent_loop + try: + output: AgentLoopOutput = await self.verl_tool_agent_loop.run(sampling_params, **input_kwargs) + except Exception as e: + rollout_state.status = Status.FAILED + rollout_state.error_msg = str(e) + self.logger.error(f"[VerlToolAgentLoop][{rollout_state.session_uid}] generate_sample failed: {e}") + return rollout_state + # TODO: handle samples with corrupted tool tokens ? 
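+        # Field mapping note (based on verl's AgentLoopOutput): prompt_ids and
+        # response_ids are token ids; response_logprobs aligns one-to-one with
+        # response_ids; response_mask distinguishes model-generated tokens from
+        # tool-returned tokens; extra_fields carries auxiliary trace data through.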
+ + # convert verl_tool_agent_loop output to rollout_state + rollout_state.prompt_ids = output.prompt_ids + rollout_state.response_ids = output.response_ids + rollout_state.logprobs = output.response_logprobs + rollout_state.routed_experts = output.routed_experts + rollout_state.response_mask = output.response_mask + rollout_state.status = Status.COMPLETED + rollout_state.extra_fields.update(output.extra_fields) + # judger needs response in text format + rollout_state.response = self.tokenizer.decode(rollout_state.response_ids) + # for trajectory dump, we need to add raw_prompt to extra_fields + # raw_prompt is updated in tool_agent_loop: apply_chat_template of tools + rollout_state.extra_fields["raw_prompt"] = self.tokenizer.decode(rollout_state.prompt_ids) + + # judge rollout_state + if self.judger is not None: + rollout_state = await self.judger.judge(rollout_state) + + return rollout_state diff --git a/recipe/verl_agent/gsm8k_tool_example/gsm8k_tool_grpo_config.py b/recipe/verl_agent/gsm8k_tool_example/gsm8k_tool_grpo_config.py new file mode 100644 index 0000000000..924ed08514 --- /dev/null +++ b/recipe/verl_agent/gsm8k_tool_example/gsm8k_tool_grpo_config.py @@ -0,0 +1,219 @@ +import os +from pathlib import Path + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto.rl_data import SampleParams +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, create_task +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, SamplerConfig, SyncProduceStrategyConfig +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig +from recipe.verl_agent.common.agent_loop_verl_tool import VerlToolAgentLoopConfig +# env +work_dir = os.environ["WORK_DIR"] +model_path = os.environ["MODEL_PATH"] +data_path = os.environ["DATA_PATH"] +eval_data_path = os.environ["EVAL_DATA_PATH"] +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") +WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1")) + +# basic settings +experimental_name = "verl_gsm8k_tool" +total_train_steps = 45 +evaluate_step = 45 +train_optimizer_steps = 1 +train_batch_size = 64 * train_optimizer_steps +prompt_repeat_k = 5 +rollout_tp_size = 1 +rollout_ep_size = 1 +max_prompt_length = 1024 +max_response_length = 1024 +pack_max_length = 32 * 1024 + +# 1. resources +resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8 * WORLD_SIZE, + num_cpus_per_worker=12, + cpu_memory_per_worker=16 * 1024**3, # 16 GB +) + +# 2. rollout +rollout_config = RolloutConfig( + env=experimental_name, + device=resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=0.8, + context_length=max_response_length + max_prompt_length, + enable_return_routed_experts=(enable_return_routed_experts == "1"), +) + +# 3. judger +judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + +# 4. 
train worker +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) +model_cfg = get_model_config_from_hf(Path(model_path)) +if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None +if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None +optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type=os.environ.get("LOSS_TYPE", "vanilla"), + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode=os.environ.get("LOSS_MODE", "chunk"), + chunk_size=512, +) +train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=int(os.environ.get("SP_SIZE", "1")), + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + +# 5.0 verl config +# gsm8k tool config +tool_config_path = "recipe/verl_agent/gsm8k_tool_example/tool_config/gsm8k_tool_config.yaml" +tool_call_parser_name = "hermes" + +from hydra import compose, initialize_config_dir +import verl + +verl_config_dir = os.path.join(os.path.dirname(verl.__file__), "trainer/config") +with initialize_config_dir(config_dir=verl_config_dir): + verl_config = compose( + config_name="ppo_trainer", + overrides=[ + "data.max_prompt_length=" + str(max_prompt_length), # also set rollout.prompt_length by OmegaConf's oc.select + "data.max_response_length=" + str(max_response_length), # also set rollout.response_length + "+data.apply_chat_template_kwargs.enable_thinking=False", + "actor_rollout_ref.rollout.multi_turn.format=" + tool_call_parser_name, + "actor_rollout_ref.rollout.multi_turn.tool_config_path=" + tool_config_path, + "actor_rollout_ref.rollout.multi_turn.max_tool_response_length=" + str(max_response_length), + "actor_rollout_ref.rollout.multi_turn.max_assistant_turns=5", + "actor_rollout_ref.rollout.multi_turn.enable=True", + ], + ) + +# 5.1 train agent loop +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +verl_tool_agent_loop_config = VerlToolAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, + config=verl_config, +) + +# 5.2 train agent loop manager +train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) +train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +produce_strategy_config = SyncProduceStrategyConfig() +agent_loop_manager_cfg = AgentLoopManagerConfig( + task_name="train_task", + agent_loop_config=verl_tool_agent_loop_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, +) + +# 6.1 eval agent loop +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, +) +eval_verl_tool_agent_loop_config = VerlToolAgentLoopConfig( + hf_checkpoint=model_path, + 
sample_params=evaluation_sample_params,
+    config=verl_config,
+)
+
+# 6.2 eval agent loop manager
+eval_dataset = DatasetConfig(
+    name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0
+)
+eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}]
+eval_dataloader_cfg = DataloaderConfig(
+    dataset_config_list=eval_dataset_cfg,
+    pack_max_length=pack_max_length,
+    collator="fake_collator",
+    pack_level="none",
+)
+eval_sampler_config = SamplerConfig(
+    dataloader_cfg=eval_dataloader_cfg,
+    prompt_repeat_k=1,
+)
+eval_agent_loop_manager_cfg = AgentLoopManagerConfig(
+    task_name="eval_task",
+    agent_loop_config=eval_verl_tool_agent_loop_config,
+    sampler_config=eval_sampler_config,
+)
+
+# 7. evaluator
+evaluator_config = EvaluatorConfig(compute_metric_func=None)
+
+# 8. RL Colocate Trainer Config (the CLI obtains the Trainer via config["trainer"].build())
+trainer = RLColocateTrainerConfig(
+    resources=resources,
+    train_worker_cfg=train_worker_cfg,  # TODO: unify naming of cfg and config
+    rollout_config=rollout_config,
+    judger_config=judger_config,
+    tokenizer_path=model_path,
+    replay_buffer_config=SyncReplayBufferConfig(),
+    agent_loop_manager_cfg=agent_loop_manager_cfg,
+    eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
+    evaluator_config=evaluator_config,
+    load_from=model_path,
+    total_train_steps=total_train_steps,
+    train_batch_size=train_batch_size,
+    enable_evaluate=True,
+    enable_initial_evaluate=False,
+    evaluate_step=evaluate_step,
+    work_dir=work_dir,
+    seed=123,
+    debug_rollout=False,
+)
diff --git a/recipe/verl_agent/gsm8k_tool_example/tool_config/gsm8k_tool_config.yaml b/recipe/verl_agent/gsm8k_tool_example/tool_config/gsm8k_tool_config.yaml
new file mode 100644
index 0000000000..a4197baabf
--- /dev/null
+++ b/recipe/verl_agent/gsm8k_tool_example/tool_config/gsm8k_tool_config.yaml
@@ -0,0 +1,16 @@
+tools:
+  - class_name: "verl.tools.gsm8k_tool.Gsm8kTool"
+    config:
+      type: native
+    tool_schema:
+      type: "function"
+      function:
+        name: "calc_gsm8k_reward"
+        description: "A tool for calculating the reward of gsm8k. (1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)"
+        parameters:
+          type: "object"
+          properties:
+            answer:
+              type: "string"
+              description: "The model's answer to the GSM8K math problem, must be digits"
+          required: ["answer"]
diff --git a/recipe/verl_agent/sandbox_example/__init__.py b/recipe/verl_agent/sandbox_example/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/recipe/verl_agent/sandbox_example/sandbox.py b/recipe/verl_agent/sandbox_example/sandbox.py
new file mode 100644
index 0000000000..81a9ac2543
--- /dev/null
+++ b/recipe/verl_agent/sandbox_example/sandbox.py
@@ -0,0 +1,53 @@
+import re
+
+import aiohttp
+from transformers.utils import get_json_schema
+
+from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema, ToolResponse
+
+
+class SandboxTool(BaseTool):
+    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
+        super().__init__(config, tool_schema)
+        self.code_pattern = re.compile(r"```py(.*?)```", re.DOTALL)
+
+    async def code_interpreter(self, code: str) -> str:
+        """Execute the code in the sandbox.
+
+        Args:
+            code: The code to be executed.
+
+        Returns:
+            str: The output of the code execution.
+        """
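+        # Request/response contract (mirrors the toy Sandbox actor in
+        # sandbox_grpo_config.py): the endpoint accepts {"code": "<python source>"}
+        # and replies with {"status": ..., "run_result": {"status": ..., "stdout": ...,
+        # "stderr": ..., "return_code": ...}}. stdout and stderr are concatenated
+        # below so the model also sees errors as part of the tool observation.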
+ """ + async with aiohttp.ClientSession() as session: + async with session.post( + self.config.get("sandbox_fusion_url"), + json={"code": code}, + ) as resp: + resp.raise_for_status() + result = await resp.json() + stdout, stderr = result["run_result"]["stdout"], result["run_result"]["stderr"] + return stdout + stderr + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + schema = get_json_schema(self.code_interpreter) + return OpenAIFunctionToolSchema(**schema) + + async def execute(self, instance_id: str, parameters: dict, **kwargs) -> tuple[str, float, dict]: + code = parameters["code"] + matches = self.code_pattern.findall(code) + if matches: + code = matches[0].strip() + + lines = code.split("\n") + for i, line in reversed(list(enumerate(lines))): + if line == "": + continue + if not lines[i].startswith("print"): + lines[i] = f"print({line})" + break + code = "\n".join(lines) + + result = await self.code_interpreter(code) + return ToolResponse(text=result), 0.0, {} diff --git a/recipe/verl_agent/sandbox_example/sandbox_grpo_config.py b/recipe/verl_agent/sandbox_example/sandbox_grpo_config.py new file mode 100644 index 0000000000..0da0720bdc --- /dev/null +++ b/recipe/verl_agent/sandbox_example/sandbox_grpo_config.py @@ -0,0 +1,313 @@ +"""RL Colocate Trainer 示例配置(GRPO + GSM8K)。 + +用法:通过环境变量传入路径后,由 CLI 加载本配置并 trainer_cfg.build().fit()。 +需设置: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH +可选: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE +""" +import os +from pathlib import Path + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto.rl_data import SampleParams +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.utils import AcceleratorResourcesConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig +from xtuner.v1.rl.utils import create_task +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, SamplerConfig, SyncProduceStrategyConfig +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig +from recipe.verl_agent.common.agent_loop_verl_tool import VerlToolAgentLoopConfig +# env +work_dir = os.environ["WORK_DIR"] +model_path = os.environ["MODEL_PATH"] +data_path = os.environ["DATA_PATH"] +eval_data_path = os.environ["EVAL_DATA_PATH"] +enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0") +WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1")) + +# basic settings +experimental_name = "grpo_gsm8k_verl_tool" +total_train_steps = 45 +evaluate_step = 45 +train_optimizer_steps = 1 +train_batch_size = 64 * train_optimizer_steps +prompt_repeat_k = 5 +rollout_tp_size = 1 +rollout_ep_size = 1 +max_prompt_length = 512 +max_response_length = 2048 +pack_max_length = 32 * 1024 + +# 1. resources +resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8 * WORLD_SIZE, + num_cpus_per_worker=12, + cpu_memory_per_worker=16 * 1024**3, # 16 GB +) + +# 2. 
# 2. rollout
+rollout_config = RolloutConfig(
+    env=experimental_name,
+    device=resources.accelerator,
+    model_path=model_path,
+    dtype="bfloat16",
+    tensor_parallel_size=rollout_tp_size,
+    expert_parallel_size=rollout_ep_size,
+    gpu_memory_utilization=0.8,
+    context_length=max_response_length + max_prompt_length,
+    enable_return_routed_experts=(enable_return_routed_experts == "1"),
+)
+
+# 3. judger
+judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1)
+
+# 4. train worker
+lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
+fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1)
+model_cfg = get_model_config_from_hf(Path(model_path))
+if hasattr(model_cfg, "balancing_loss_cfg"):
+    model_cfg.balancing_loss_cfg = None
+if hasattr(model_cfg, "z_loss_cfg"):
+    model_cfg.z_loss_cfg = None
+optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1)
+loss_cfg = GRPOLossConfig(
+    policy_loss_cfg=dict(
+        cliprange_high=0.28,
+        cliprange_low=0.2,
+        loss_type=os.environ.get("LOSS_TYPE", "vanilla"),
+        clip_ratio_c=10.0,
+        log_prob_diff_min=-20.0,
+        log_prob_diff_max=20.0,
+    ),
+    ignore_idx=-100,
+    use_kl_loss=False,
+    kl_loss_coef=0.0,
+    kl_loss_type="low_var_kl",
+    mode=os.environ.get("LOSS_MODE", "chunk"),
+    chunk_size=512,
+)
+train_worker_cfg = WorkerConfig(
+    model_cfg=model_cfg,
+    load_from=model_path,
+    optim_cfg=optim_cfg,
+    loss_cfg=loss_cfg,
+    lr_cfg=lr_cfg,
+    fsdp_cfg=fsdp_cfg,
+    sp_size=int(os.environ.get("SP_SIZE", "1")),
+    optimizer_steps=train_optimizer_steps,
+    pack_max_length=pack_max_length,
+)
+
+# Toy code sandbox, for this example only.
+import ray
+import asyncio
+import socket
+import tempfile
+import sys
+import fastapi
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse
+import uvicorn
+import json
+
+@ray.remote(num_cpus=1)
+class Sandbox:
+    """Sandbox to execute python code."""
+
+    def __init__(self):
+        self.address = ray._private.services.get_node_ip_address()
+        self.port = self._get_free_port()
+        create_task(self._start_fastapi_server())
+
+    async def code_execution(self, request: Request):
+        request_json = await request.json()
+        code = request_json["code"]
+        # print(f"execute code:\n{code}")
+
+        _, temp_file = tempfile.mkstemp(suffix=".py", prefix="temp_code", dir=None, text=True)
+        with open(temp_file, "w") as f:
+            f.write(code)
+
+        try:
+            process = await asyncio.create_subprocess_exec(
+                sys.executable, temp_file, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
+            )
+
+            stdout, stderr = await process.communicate()
+
+            response = {
+                "status": "Success" if process.returncode == 0 else "Failed",
+                "run_result": {
+                    "status": "Finished",
+                    "stdout": stdout.decode(),
+                    "stderr": stderr.decode(),
+                    "return_code": process.returncode,
+                },
+            }
+            return JSONResponse(content=response)
+        finally:
+            try:
+                os.unlink(temp_file)
+            except Exception:
+                pass
+
+    def _get_free_port(self):
+        with socket.socket() as sock:
+            sock.bind(("", 0))
+            return sock.getsockname()[1]
+
+    async def _start_fastapi_server(self):
+        app = fastapi.FastAPI()
+        app.router.add_api_route("/run_code", self.code_execution, methods=["POST"])
+
+        config = uvicorn.Config(app, host="0.0.0.0", port=self.port, log_level="warning")
+        server = uvicorn.Server(config)
+        await server.serve()
+
+    async def get_server_address(self) -> str:
+        """Get FastAPI server address."""
+        return f"{self.address}:{self.port}"
+
+sandbox = Sandbox.remote()
+sandbox_address = ray.get(sandbox.get_server_address.remote())
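+
+# Optional smoke test for the sandbox endpoint (illustrative sketch only, not
+# called anywhere in this config). It exercises the same JSON contract that
+# SandboxTool.code_interpreter relies on; run it manually if tool calls fail.
+def _smoke_test_sandbox(address: str = sandbox_address) -> None:
+    import urllib.request
+
+    req = urllib.request.Request(
+        f"http://{address}/run_code",
+        data=json.dumps({"code": "print(1 + 1)"}).encode("utf-8"),
+        headers={"Content-Type": "application/json"},
+    )
+    with urllib.request.urlopen(req) as resp:
+        result = json.loads(resp.read().decode("utf-8"))
+    # Expected: status "Success" with "2" on stdout.
+    assert result["run_result"]["stdout"].strip() == "2", result
+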
+print(f"Sandbox server address: {sandbox_address}") +tool_config = { + "tools": [ + { + "class_name": "recipe.verl_agent.sandbox_example.sandbox.SandboxTool", + "config": { + "type": "native", + "sandbox_fusion_url": f"http://{sandbox_address}/run_code", + }, + }, + ], +} + +tool_config_path = "tool_config.json" +with open(tool_config_path, "w") as f: + json.dump(tool_config, f) + +# 5.0 verl config +tool_call_parser_name = "hermes" + +from hydra import compose, initialize_config_dir +import verl + +verl_config_dir = os.path.join(os.path.dirname(verl.__file__), "trainer/config") +with initialize_config_dir(config_dir=verl_config_dir): + verl_config = compose( + config_name="ppo_trainer", + overrides=[ + "data.max_prompt_length=" + str(max_prompt_length), # also set rollout.prompt_length by OmegaConf's oc.select + "data.max_response_length=" + str(max_response_length), # also set rollout.response_length + "+data.apply_chat_template_kwargs.enable_thinking=False", + "actor_rollout_ref.rollout.multi_turn.format=" + tool_call_parser_name, + "actor_rollout_ref.rollout.multi_turn.tool_config_path=" + tool_config_path, + "actor_rollout_ref.rollout.multi_turn.max_tool_response_length=" + str(max_response_length), + "actor_rollout_ref.rollout.multi_turn.max_assistant_turns=5", + "actor_rollout_ref.rollout.multi_turn.enable=True", + ], + ) + +# 5.1 train agent loop +training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, +) +verl_tool_agent_loop_config = VerlToolAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, + config=verl_config, +) + +# 5.2 train agent loop manager +train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path) +tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) +train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] +dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, +) +produce_strategy_config = SyncProduceStrategyConfig() +agent_loop_manager_cfg = AgentLoopManagerConfig( + task_name="train_task", + agent_loop_config=verl_tool_agent_loop_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, +) + +# 6.1 eval agent loop +evaluation_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=1, + top_p=1.0, + temperature=0.0, + min_tokens=0, +) +eval_verl_tool_agent_loop_config = VerlToolAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=evaluation_sample_params, + config=verl_config, +) + +# 6.2 eval agent loop manager +eval_dataset = DatasetConfig( + name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0 +) +eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] +eval_dataloader_cfg = DataloaderConfig( + dataset_config_list=eval_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", +) +eval_sampler_config = SamplerConfig( + dataloader_cfg=eval_dataloader_cfg, + prompt_repeat_k=1, +) +eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + task_name="eval_task", + agent_loop_config=eval_verl_tool_agent_loop_config, + sampler_config=eval_sampler_config, +) + +# 7. evaluator +evaluator_config = EvaluatorConfig(compute_metric_func=None) + +# 8. 
RL Colocate Trainer Config (the CLI obtains the Trainer via config["trainer"].build())
+trainer = RLColocateTrainerConfig(
+    resources=resources,
+    train_worker_cfg=train_worker_cfg,  # TODO: unify naming of cfg and config
+    rollout_config=rollout_config,
+    judger_config=judger_config,
+    tokenizer_path=model_path,
+    replay_buffer_config=SyncReplayBufferConfig(),
+    agent_loop_manager_cfg=agent_loop_manager_cfg,
+    eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
+    evaluator_config=evaluator_config,
+    load_from=model_path,
+    total_train_steps=total_train_steps,
+    train_batch_size=train_batch_size,
+    enable_evaluate=True,
+    enable_initial_evaluate=False,
+    evaluate_step=evaluate_step,
+    work_dir=work_dir,
+    seed=123,
+    debug_rollout=False,
+)
diff --git a/recipe/verl_agent/sandbox_example/test_verl_tool_agent_loop.py b/recipe/verl_agent/sandbox_example/test_verl_tool_agent_loop.py
new file mode 100644
index 0000000000..fa2272fc7f
--- /dev/null
+++ b/recipe/verl_agent/sandbox_example/test_verl_tool_agent_loop.py
@@ -0,0 +1,406 @@
+import os
+import sys
+import json
+import socket
+import asyncio
+import tempfile
+import unittest
+
+import ray
+import torch
+import fastapi
+import uvicorn
+from fastapi import Request
+from fastapi.responses import JSONResponse
+from transformers import AutoTokenizer
+
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers
+from recipe.verl_agent.common.agent_loop_verl_tool import VerlToolAgentLoopConfig
+from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, SamplerConfig, SyncProduceStrategyConfig
+from xtuner.v1.data_proto.rl_data import RolloutState, Status, SampleParams
+from xtuner.v1.rl.rollout import RolloutController
+from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig
+from xtuner.v1.rl.utils import create_task
+from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig
+
+MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
+TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"]
+TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"]
+VERL_TRAIN_DATA_PATH = "/fake/path/to/train.parquet"
+VERL_TEST_DATA_PATH = "/fake/path/to/test.parquet"
+
+FAKE_INPUT_ITEM = RolloutState(
+    message=[{
+        'role': 'user',
+        'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Let\'s think step by step and output the final answer after "####"'
+    }],
+    reward_model={'ground_truth': '72', 'style': 'rule'},
+)
+
+resource_map = {
+    "npu": "NPU",
+    "cuda": "GPU",
+}
+
+
+@ray.remote(num_cpus=1)
+class Sandbox:
+    """Sandbox to execute python code for tool-calling agent tests."""
+
+    def __init__(self):
+        self.address = ray._private.services.get_node_ip_address()
+        self.port = self._get_free_port()
+        create_task(self._start_fastapi_server())
+
+    async def code_execution(self, request: Request):
+        request_json = await request.json()
+        code = request_json["code"]
+
+        _, temp_file = tempfile.mkstemp(
+            suffix=".py", prefix="temp_code", dir=None, text=True
+        )
+        with open(temp_file, "w") as f:
+            f.write(code)
+
+        try:
+            process = await asyncio.create_subprocess_exec(
+                sys.executable,
+                temp_file,
+                stdout=asyncio.subprocess.PIPE,
+                stderr=asyncio.subprocess.PIPE,
+            )
+            stdout, stderr = await process.communicate()
+            response = {
+                "status": "Success" if process.returncode == 0 else "Failed",
+                "run_result": {
+                    "status": "Finished",
+                    "stdout": stdout.decode(),
+                    "stderr": stderr.decode(),
+                    "return_code": process.returncode,
+                },
+            }
+            return JSONResponse(content=response)
+        finally:
+            try:
+                os.unlink(temp_file)
+            except Exception:
+                pass
+
+    def _get_free_port(self):
+        with socket.socket() as sock:
+            sock.bind(("", 0))
+            return sock.getsockname()[1]
+
+    async def _start_fastapi_server(self):
+        app = fastapi.FastAPI()
+        app.router.add_api_route(
+            "/run_code", self.code_execution, methods=["POST"]
+        )
+        config = uvicorn.Config(
+            app, host="0.0.0.0", port=self.port, log_level="warning"
+        )
+        server = uvicorn.Server(config)
+        await server.serve()
+
+    async def get_server_address(self) -> str:
+        return f"{self.address}:{self.port}"
+
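+
+# NOTE: VERL_TRAIN_DATA_PATH / VERL_TEST_DATA_PATH above are deliberate
+# placeholders. verl's ppo_trainer config schema requires train/val files, but
+# these tests feed prompts in directly (FAKE_INPUT_ITEM or the xtuner sampler),
+# so the parquet files are never actually read.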
+
+def _build_verl_config(
+    model_path: str,
+    train_file: str,
+    test_file: str,
+    tool_config_path: str,
+    max_prompt_length: int,
+    max_response_length: int,
+    rollout_name: str = "sglang",
+    tool_call_parser_name: str = "hermes",
+):
+    from hydra import compose, initialize_config_dir
+    import verl
+
+    verl_config_dir = os.path.join(
+        os.path.dirname(verl.__file__), "trainer/config"
+    )
+    with initialize_config_dir(config_dir=verl_config_dir):
+        verl_config = compose(
+            config_name="ppo_trainer",
+            overrides=[
+                "algorithm.adv_estimator=grpo",
+                "data.train_files=" + train_file,
+                "data.val_files=" + test_file,
+                "data.return_raw_chat=True",
+                "data.train_batch_size=32",
+                "data.max_prompt_length=" + str(max_prompt_length),
+                "data.max_response_length=" + str(max_response_length),
+                "+data.apply_chat_template_kwargs.enable_thinking=False",
+                "actor_rollout_ref.model.path=" + model_path,
+                "actor_rollout_ref.actor.ppo_mini_batch_size=8",
+                "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8",
+                "actor_rollout_ref.actor.fsdp_config.param_offload=True",
+                "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
+                "actor_rollout_ref.rollout.name=" + rollout_name,
+                "actor_rollout_ref.rollout.mode=async",
+                "actor_rollout_ref.rollout.tensor_model_parallel_size=1",
+                "actor_rollout_ref.rollout.n=8",
+                "actor_rollout_ref.rollout.response_length=" + str(max_response_length),
+                "actor_rollout_ref.rollout.skip_tokenizer_init=False",
+                "+actor_rollout_ref.rollout.engine_kwargs.vllm.enable_auto_tool_choice=True",
+                "+actor_rollout_ref.rollout.engine_kwargs.vllm.tool_call_parser=hermes",
+                "+actor_rollout_ref.rollout.engine_kwargs.sglang.tool_call_parser=qwen25",
+                "actor_rollout_ref.rollout.multi_turn.format=" + tool_call_parser_name,
+                "actor_rollout_ref.rollout.multi_turn.tool_config_path=" + tool_config_path,
+                "+actor_rollout_ref.rollout.multi_turn.max_tool_response_length=" + str(max_response_length),
+                "actor_rollout_ref.rollout.agent.default_agent_loop=tool_agent",
+                "actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8",
+                "trainer.val_before_train=True",
+                "trainer.log_val_generations=10",
+                "trainer.n_gpus_per_node=8",
+                "trainer.test_freq=-1",
+                "trainer.total_training_steps=5",
+                "trainer.logger=['console','tensorboard']",
+                "trainer.project_name=verl",
+                "trainer.experiment_name=test_verl_tool_agent_loop",
+            ],
+        )
+    return verl_config
+
+
+class TestVerlToolAgentLoop(unittest.IsolatedAsyncioTestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        os.environ["XTUNER_USE_FA3"] = "1"
+        os.environ["LMD_SKIP_WARMUP"] = "1"
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        del os.environ["XTUNER_USE_FA3"]
+        del os.environ["LMD_SKIP_WARMUP"]
+
+    def init_config(self):
+        self.resources_cfg = AcceleratorResourcesConfig(
+            accelerator=resource_map[torch.accelerator.current_accelerator().type],
+            num_workers=1,
+            num_cpus_per_worker=8,
+            cpu_memory_per_worker=16 * 1024**3,
+        )
+        self.max_prompt_length = 512
+        self.max_response_length = 4096
+        self.context_length = self.max_prompt_length + self.max_response_length
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+
+    def setUp(self):
+        ray.init(num_cpus=80, ignore_reinit_error=True)
+        self.model_path = MODEL_PATH
+        self.data_path = TRAIN_DATA_PATH
+        self.temp_dir = tempfile.TemporaryDirectory()
+        self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs")
+
+    def tearDown(self):
+        ray.shutdown()
+        self.temp_dir.cleanup()
+
+    def _setup_sandbox_and_verl_config(self):
+        """Create the sandbox actor, write its tool config, and build the verl config."""
+        sandbox = Sandbox.remote()
+        self._sandbox = sandbox
+        # TODO: replace with a real sandbox server address
+        sandbox_address = ray.get(sandbox.get_server_address.remote())
+
+        tool_config = {
+            "tools": [
+                {
+                    "class_name": "recipe.verl_agent.sandbox_example.sandbox.SandboxTool",
+                    "config": {
+                        "type": "native",
+                        "sandbox_fusion_url": f"http://{sandbox_address}/run_code",
+                    },
+                },
+            ],
+        }
+        tool_config_path = os.path.join(self.temp_dir.name, "tool_config.json")
+        with open(tool_config_path, "w") as f:
+            json.dump(tool_config, f)
+
+        verl_config = _build_verl_config(
+            model_path=self.model_path,
+            train_file=VERL_TRAIN_DATA_PATH,
+            test_file=VERL_TEST_DATA_PATH,
+            tool_config_path=tool_config_path,
+            max_prompt_length=self.max_prompt_length,
+            max_response_length=self.max_response_length,
+        )
+        return verl_config
+
+    async def test_verl_tool_agent_loop(self):
+        # 1. Initialize config
+        self.init_config()
+        verl_config = self._setup_sandbox_and_verl_config()
+
+        rollout_config = RolloutConfig(
+            env="test_verl_tool_agent_loop",
+            model_path=self.model_path,
+            model_name=os.path.basename(self.model_path).lower(),
+            tokenizer_path=self.model_path,
+            context_length=self.context_length,
+            worker_log_dir=self.worker_log_dir,
+        )
+        judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1)
+
+        training_sample_params = SampleParams(
+            max_tokens=self.max_response_length,
+            top_k=0,
+            top_p=1.0,
+            temperature=1.0,
+            min_tokens=0,
+            return_token_ids=True,
+            return_logprob=True,
+        )
+        agent_loop_cfg = VerlToolAgentLoopConfig(
+            hf_checkpoint=self.model_path,
+            sample_params=training_sample_params,
+            config=verl_config,
+        )
+
+        # 2. Create rollout_controller and judger
+        pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)
+        rollout_controller = ray.remote(RolloutController).remote(
+            rollout_config, pg
+        )
+        gsm8k_judger = judger_config.build()
+
+        # 3. Create the VerlToolAgentLoop
+        agent_loop = agent_loop_cfg.build(
+            rollout_controller=rollout_controller, judger=gsm8k_judger
+        )
+
+        # 4. Construct the input data
+        prompt_repeat_k = 4
+        rollout_state = FAKE_INPUT_ITEM.model_copy(deep=True)
+        group_in_rollout_state = [
+            FAKE_INPUT_ITEM.model_copy(deep=True) for _ in range(prompt_repeat_k)
+        ]
+
+        # 5. Run generate_group and generate_sample
+        group_rollout_state = await agent_loop.generate_group(group_in_rollout_state)
+        single_rollout_state = await agent_loop.generate_sample(rollout_state)
+
+        print(f"prompt: {single_rollout_state.extra_fields['raw_prompt']}")
+        print(f"response: {single_rollout_state.response}")
+
+        # 6. Verify the results
+        self.assertEqual(len(group_rollout_state), prompt_repeat_k)
+        for state in group_rollout_state:
+            self.assertEqual(state.status, Status.COMPLETED)
+            self.assertIsNotNone(state.response_ids)
+            self.assertGreater(len(state.response_ids), 0)
+            self.assertIsNotNone(state.prompt_ids)
+            self.assertIsNotNone(state.logprobs)
+            self.assertIsNotNone(state.loss_mask)
+
+        self.assertEqual(single_rollout_state.status, Status.COMPLETED)
+        self.assertIsNotNone(single_rollout_state.response_ids)
+        self.assertGreater(len(single_rollout_state.response_ids), 0)
+        self.assertIsNotNone(single_rollout_state.prompt_ids)
+        self.assertIsNotNone(single_rollout_state.logprobs)
+        self.assertIsNotNone(single_rollout_state.loss_mask)
+
+    async def test_verl_tool_agent_loop_manager(self):
+        # 1. Initialize config
+        self.init_config()
+        verl_config = self._setup_sandbox_and_verl_config()
+
+        rollout_config = RolloutConfig(
+            env="test_verl_tool_agent_loop_manager",
+            model_path=self.model_path,
+            model_name=os.path.basename(self.model_path).lower(),
+            tokenizer_path=self.model_path,
+            context_length=self.context_length,
+            worker_log_dir=self.worker_log_dir,
+        )
+        judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1)
+
+        training_sample_params = SampleParams(
+            max_tokens=self.max_response_length,
+            top_k=0,
+            top_p=1.0,
+            temperature=1.0,
+            min_tokens=0,
+        )
+        agent_loop_cfg = VerlToolAgentLoopConfig(
+            hf_checkpoint=self.model_path,
+            sample_params=training_sample_params,
+            config=verl_config,
+        )
+
+        prompt_repeat_k = 2
+        sampler_config = SamplerConfig(
+            dataloader_cfg=DataloaderConfig(
+                dataset_config_list=[
+                    {
+                        "dataset": DatasetConfig(
+                            name="gsm8k",
+                            anno_path=TRAIN_DATA_PATH,
+                            sample_ratio=1.0,
+                        ),
+                        "tokenize_fn": RLTextTokenizeFnConfig(
+                            max_length=self.max_prompt_length
+                        ),
+                    },
+                ],
+                collator="fake_collator",
+                pack_level="none",
+                group_by_length=False,
+            ),
+            prompt_repeat_k=prompt_repeat_k,
+        )
+        agent_loop_manager_cfg = AgentLoopManagerConfig(
+            task_name="test_verl_tool",
+            agent_loop_config=agent_loop_cfg,
+            produce_strategy_config=SyncProduceStrategyConfig(),
+            sampler_config=sampler_config,
+        )
+
+        # 2. Create rollout_controller and judger
+        pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)
+        rollout_controller = ray.remote(RolloutController).remote(
+            rollout_config, pg
+        )
+        gsm8k_judger = judger_config.build()
+
+        # 3. Create the AgentLoopManager
+        replay_buffer_cfg = SyncReplayBufferConfig()
+        replay_buffer = replay_buffer_cfg.build()
+        agent_loop_manager = agent_loop_manager_cfg.build(
+            rollout_controller=rollout_controller,
+            judger=gsm8k_judger,
+            tokenizer=self.tokenizer,
+            replay_buffer=replay_buffer,
+        )
+
+        # 4. Run produce_batch
+        results = await agent_loop_manager.produce_batch(batch_size=4, train_step=0, model_step=0)
+        batch_rollout_states = results.rollout_states
+
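+        # produce_batch returns one group per sampled prompt: batch_size groups,
+        # each holding prompt_repeat_k RolloutState objects that share the same
+        # message (the repeated prompt), which is what the assertions below check.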
+        # 5. Verify the results
+        self.assertEqual(len(batch_rollout_states), 4)
+        for group_state in batch_rollout_states:
+            self.assertEqual(len(group_state), prompt_repeat_k)
+            group_message = group_state[0].message
+            for state in group_state:
+                self.assertEqual(state.status, Status.COMPLETED)
+                self.assertIsNotNone(state.response_ids)
+                self.assertGreater(len(state.response_ids), 0)
+                self.assertEqual(state.message, group_message)
+                self.assertIsNotNone(state.prompt_ids)
+                self.assertIsNotNone(state.logprobs)
+                self.assertIsNotNone(state.loss_mask)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/loss/test_grpo_loss.py b/tests/loss/test_grpo_loss.py
index 7a1747a125..b007c08c2c 100644
--- a/tests/loss/test_grpo_loss.py
+++ b/tests/loss/test_grpo_loss.py
@@ -7,10 +7,9 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
-from xtuner.v1.rl.grpo import GRPOLossConfig, GRPOLossContext
+from xtuner.v1.rl.loss import GRPOLossConfig, GRPOLossContext, kl_penalty
 from xtuner.v1.data_proto import SequenceContext
 from xtuner.v1.rl.utils import gather_logprobs
-from xtuner.v1.rl.loss_fn import kl_penalty
 from xtuner.v1.utils.test_utils import init_data_mesh
diff --git a/tests/loss/test_oreal_loss.py b/tests/loss/test_oreal_loss.py
index 2ceae417d0..7d6e545eb9 100644
--- a/tests/loss/test_oreal_loss.py
+++ b/tests/loss/test_oreal_loss.py
@@ -11,10 +11,9 @@ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 import torch.nn as nn
 import torch.nn.functional as F
-from xtuner.v1.rl.oreal.loss import OrealLossConfig, OrealLossContext
+from xtuner.v1.rl.loss import OrealLossConfig, OrealLossContext, kl_penalty
 from xtuner.v1.data_proto import SequenceContext
 from xtuner.v1.rl.utils import gather_logprobs
-from xtuner.v1.rl.loss_fn import kl_penalty
 from xtuner.v1.data_proto.utils import unpack_sequence
 from xtuner.v1.utils.test_utils import init_data_mesh
diff --git a/tests/ray/test_evaluator.py b/tests/ray/test_evaluator.py
deleted file mode 100644
index 321070f878..0000000000
--- a/tests/ray/test_evaluator.py
+++ /dev/null
@@ -1,112 +0,0 @@
-import os
-import unittest
-import ray
-import tempfile
-from transformers import AutoTokenizer
-
-from xtuner.v1.ray.config.worker import RolloutConfig
-from xtuner.v1.ray.judger.controller import JudgerConfig
-from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
-from xtuner.v1.ray.environment import SingleTurnEnvironment
-from xtuner.v1.ray.evaluator import Evaluator, EvaluatorConfig
-from xtuner.v1.data_proto.rl_data import SampleParams
-from xtuner.v1.datasets import RLTokenizeFnConfig, DatasetConfig, OpenaiTokenizeFunctionConfig
-
-
-MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
-TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"]
-
-
-class TestEvaluator(unittest.TestCase):
-
-    @classmethod
-    def setUpClass(cls) -> None:
-        os.environ["XTUNER_USE_FA3"] = "1"
-        os.environ["LMD_SKIP_WARMUP"] = "1"
-
-    @classmethod
-    def tearDownClass(cls) -> None:
-        del os.environ["XTUNER_USE_FA3"]
-        del os.environ["LMD_SKIP_WARMUP"]
-
-    def init_config(self):
-        self.resources_cfg = AcceleratorResourcesConfig(
-            accelerator="GPU",
-            num_workers=8,
-            num_cpus_per_worker=8,
-            cpu_memory_per_worker=16 * 1024**3,  # 16 GB
-        )
-        self.max_prompt_length = 512
-        self.max_response_length = 1024
-        self.rollout_cfg = RolloutConfig(
-            env="test_rollout",
-            model_path=MODEL_PATH,
-            model_name=os.path.basename(MODEL_PATH).lower(),
-            tokenizer_path=MODEL_PATH,
-            tensor_parallel_size=8,
-            context_length=self.max_prompt_length +
self.max_response_length, - worker_log_dir=self.worker_log_dir - ) - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") - self.judger_cfg = JudgerConfig( - reward_judger_configs=[gsm8k_judger_config], - worker_log_dir=self.worker_log_dir - ) - self.eval_dataset_cfg = [ - { - "dataset": DatasetConfig(name="gsm8k", - anno_path=TEST_DATA_PATH, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length) - }, - ] - self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) - self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) - self.test_env = SingleTurnEnvironment.remote( - "test_env", - self.pg, - self.rollout_cfg, - None, - self.judger_cfg - ) - self.sample_params = SampleParams( - top_p=1.0, - temperature=0.0, - max_tokens=self.max_response_length, - top_k=1 - ) - - def setUp(self): - ray.init(num_cpus=80) - self.model_path = MODEL_PATH - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - self.init_config() - - def tearDown(self): - ray.shutdown() - self.temp_dir.cleanup() - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_lmdeploy_evaluator(self): - def custom_compute_metric(samples): - return {"custom_accuracy": sum(s.env.judger.reward["score"] > 0 for s in samples) / len(samples)} - - evaluator_cfg = EvaluatorConfig( - dataset_cfg=self.eval_dataset_cfg, - tokenizer=self.tokenizer, - max_concurrent=16, - eval_sample_ratio=0.004, # generate 5 samples - compute_metric_func=custom_compute_metric, - sample_params=self.sample_params, - worker_log_dir=self.worker_log_dir - ) - evaluator = Evaluator.remote(evaluator_cfg, self.test_env) - try: - ray.get(evaluator.run.remote()) - except Exception as e: - self.fail(f"evaluator.run.remote() raised an exception: {e}") - -if __name__ == '__main__': - unittest.main() diff --git a/tests/ray/test_judger.py b/tests/ray/test_judger.py deleted file mode 100644 index fec3c8b8c9..0000000000 --- a/tests/ray/test_judger.py +++ /dev/null @@ -1,231 +0,0 @@ -import os -import copy -import json -import ray -import unittest -import tempfile -import numpy as np -from uuid import uuid4 -from xtuner.v1.ray.judger.controller import JudgerController, JudgerConfig -from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLDatasetItem, RLEnvDataItem, RLRolloutResponseItem, RLUIDItem -from xtuner.v1.ray.base import AutoCPUWorkers, CPUResourcesConfig -MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] -DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] -GEO_ROLLOUT_DATA_PATH = os.environ["GEO_ROLLOUT_DATA_PATH"] -VERL_ROLLOUT_DATA_PATH = os.environ["VERL_ROLLOUT_DATA_PATH"] -DAPO_DATA_PATH = os.environ.get("ROLLOUT_DAPO_DATA_PATH") - -FAKE_JUDGER_INPUT_ITEM = RLDataFlowItem( - uid=RLUIDItem(action_id=uuid4().int, - observation_id=uuid4().int), - data=RLDatasetItem( - messages=[{ - 'role': 'user', - 'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let\'s think step by step and output the final answer after "####"' - }], - num_tokens=62, - reward_model={'ground_truth': '72', 'style': 'rule'}, - ability='math', - data_source={'openai/gsm8k': 1.0} - ), - env=RLEnvDataItem( - rollout=RLRolloutResponseItem( - response="\nOkay, let's see. Natalia sold clips to 48 friends in April. 
Then in May, she sold half as many. So first, I need to figure out how many she sold in May. Half of 48 is 24, right? Because 48 divided by 2 is 24. So in May, she sold 24 clips.\n\nNow, to find the total number of clips sold in both months, I need to add the number from April and May together. That would be 48 (April) plus 24 (May). Let me do the addition: 48 + 24. Hmm, 40 + 20 is 60, and 8 + 4 is 12. So 60 + 12 is 72. So altogether, she sold 72 clips.\n\nWait, let me check that again. 48 plus 24. Yes, 48 + 20 is 68, then plus 4 more is 72. Yep, that seems right. So the total is 72.\n\n\nNatalia sold 48 clips in April. In May, she sold half as many, which is 48 ÷ 2 = 24 clips. Adding both months together: 48 + 24 = 72. \n\n#### 72<|im_end|>", - ) - ) -) -FAKE_JUDGER_INPUT_ITEM_1 = copy.deepcopy(FAKE_JUDGER_INPUT_ITEM) -FAKE_JUDGER_INPUT_ITEM_1.uid.observation_id = uuid4().int -FAKE_JUDGER_INPUT_ITEM_MULTI_DATA = [FAKE_JUDGER_INPUT_ITEM, FAKE_JUDGER_INPUT_ITEM_1] # 用action_id来标识是不同的输入数据 -FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE = copy.deepcopy(FAKE_JUDGER_INPUT_ITEM) -FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE.data.data_source = {'openai/gsm8k-1': 0.5, 'openai/gsm8k-2': 0.5} - - -def construct_judger_data(data_path): - dataitem = [] - with open(data_path, 'r', encoding='utf-8') as f: - for line_num, line in enumerate(f, 1): - data = json.loads(line.strip()) - data_item = RLDataFlowItem( - uid=RLUIDItem( - action_id=uuid4().int, - observation_id=uuid4().int - ), - data=RLDatasetItem( - messages=[{ - 'role': 'user', - 'content': data["input"][5:-11] - }], - reward_model={"ground_truth": data["gts"]}, - data_source={"openai/gsm8k": 1.0} - ), - env=RLEnvDataItem( - rollout=RLRolloutResponseItem(response=data['output']) - ) - ) - dataitem.append(data_item) - return dataitem - - -def construct_new_judger_data(data_path, judger_name='dapo_math'): - data_item_list = [] - save_reward = [] - with open(data_path, 'r', encoding='utf-8') as f: - lines = f.readlines() - for i in range(0, len(lines), 7): - group = ''.join(lines[i:i + 7]).strip() - if group: - try: - item = json.loads(group) - data_item = RLDataFlowItem( - uid=RLUIDItem( - action_id=uuid4().int, - observation_id=uuid4().int - ), - data=RLDatasetItem( - messages=[{ - 'role': 'user', - 'content': "" - }], - reward_model={"ground_truth": item["label"]}, - data_source={judger_name: 1.0} - ), - env=RLEnvDataItem( - rollout=RLRolloutResponseItem(response=item['response']) - ) - ) - data_item_list.append(data_item) - save_reward.append(item["reward"]) - except Exception as e: - print(f"Error parsing group starting at line {i + 12}: {e}") - return data_item_list, save_reward - - -class TestJudgerController(unittest.TestCase): - - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - - def tearDown(self): - ray.shutdown() - self.temp_dir.cleanup() - - def test_gsm8k_judger(self): - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") - judger_cfg = JudgerConfig( - reward_judger_configs=[gsm8k_judger_config], - worker_log_dir=self.worker_log_dir - ) - judger_controller = JudgerController.remote(judger_cfg) - # 返回的形式为:RLJudgerResponseItem(uid=112750990920317762694895938380669501546, reward={'openai/gsm8k': 1}, extra_info={}) - res1 = ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM)) - self.assertEqual(res1.reward["score"], 1.0) - res2 = 
ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM_MULTI_DATA)) - self.assertEqual(res2[0].reward["score"], 1.0) - self.assertEqual(res2[1].reward["score"], 1.0) - - def test_dapo_judger(self): - from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig - from xtuner.v1.utils.rl_test_utils import get_eos_token - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) - eos_token = get_eos_token(MODEL_PATH) - eos_token_str = tokenizer.convert_ids_to_tokens(eos_token) - - dapo_judger_config = DapoMathJudgerConfig( - judger_name="dapo_math", - eos_token=eos_token_str, - enable_overlong_buffer=True, - max_response_len=32768, - overlong_buffer_len=4096, - overlong_penalty_factor=1.0, - tokenizer=tokenizer - - ) - judger_cfg = JudgerConfig( - reward_judger_configs=[dapo_judger_config], - worker_log_dir=self.worker_log_dir - ) - judger_controller = JudgerController.remote(judger_cfg) - judger_data, save_reward = construct_new_judger_data(DAPO_DATA_PATH) - group_data = ray.get(judger_controller.run.remote(judger_data)) - reward = [data.reward["score"] for data in group_data] - self.assertEqual(np.mean(reward), np.mean(save_reward)) - - def test_geo_judger(self): - from xtuner.v1.ray.judger.geo3k import GEO3KJudgerConfig - geo_judger_config = GEO3KJudgerConfig() - judger_cfg = JudgerConfig( - reward_judger_configs=[geo_judger_config], - worker_log_dir=self.worker_log_dir - ) - judger_controller = JudgerController.remote(judger_cfg) - judger_data, save_reward = construct_new_judger_data(GEO_ROLLOUT_DATA_PATH, judger_name="hiyouga/geometry3k") - group_data = ray.get(judger_controller.run.remote(judger_data)) - reward = [data.reward["score"] for data in group_data] - self.assertEqual(np.mean(reward), np.mean(save_reward)) - - def test_gsm8k_multi_judger(self): - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - # 支持一个GSM8KJudgerConfig创建多个实例 - gsm8k_judger_config_1 = GSM8KJudgerConfig(judger_name="openai/gsm8k-1") - gsm8k_judger_config_2 = GSM8KJudgerConfig(judger_name="openai/gsm8k-2") - judger_cfg = JudgerConfig( - reward_judger_configs=[ - gsm8k_judger_config_1, - gsm8k_judger_config_2 - ], - enable_weighted_judgers=True, - worker_log_dir=self.worker_log_dir, - ) - cpu_resources_config = CPUResourcesConfig.from_total( - total_cpus=2, - total_memory=2 * 1024**3, - num_workers=2 - ) - pg = AutoCPUWorkers.build_placement_group(cpu_resources_config) - judger_controller = JudgerController.remote(judger_cfg, pg) - res3 = ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE)) - self.assertEqual(res3.reward["score"], 1.0) - - def test_gsm8k_judger_score(self): - """Test the judger functionality with single and multiple data sources.""" - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") - judger_cfg = JudgerConfig( - reward_judger_configs=[gsm8k_judger_config], - worker_log_dir=self.worker_log_dir - ) - judger_controller = JudgerController.remote(judger_cfg) - judger_data = construct_judger_data(VERL_ROLLOUT_DATA_PATH) - group_data = ray.get(judger_controller.run.remote(judger_data)) - reward = [data.reward["score"] for data in group_data] - verl_score = 0.2418 - self.assertEqual(round(np.mean(reward), 4), verl_score) - - def test_gsm8k_remote_judger(self): - from xtuner.v1.utils.rl_test_utils import JudgerServer, GSM8KRemoteJudgerConfig - - server = JudgerServer(port=8018) - server.start() - try: - remote_judger_config = 
GSM8KRemoteJudgerConfig(judger_name="openai/gsm8k", remote_url=server.url) - judger_cfg = JudgerConfig( - reward_judger_configs=[remote_judger_config], - worker_log_dir=self.worker_log_dir - ) - judger_controller = JudgerController.remote(judger_cfg) - judger_data = construct_judger_data(VERL_ROLLOUT_DATA_PATH) - group_data = ray.get(judger_controller.run.remote(judger_data)) - reward = [data.reward["score"] for data in group_data] - verl_score = 0.2418 - self.assertEqual(round(np.mean(reward), 4), verl_score) - finally: - server.stop() - - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py deleted file mode 100644 index c57ece9076..0000000000 --- a/tests/ray/test_mock_rollout.py +++ /dev/null @@ -1,138 +0,0 @@ -import os -import asyncio -import unittest -import ray -from transformers import AutoTokenizer -import torch -import tempfile -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.environment import SingleTurnEnvironment -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.ray.rollout.controller import RolloutController -from xtuner.v1.utils.rl_test_utils import MockTimeoutRolloutWorker, MockRequestErrorRolloutWorker, MockClientErrorRolloutWorker, MockServerErrorRolloutWorker - -MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] -TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] -resource_map = {"npu": "NPU", "cuda": "GPU"} -@ray.remote -class MockTimeoutRolloutController(RolloutController): - def _get_worker_cls(self): - return ray.remote(MockTimeoutRolloutWorker) - def deactivate_worker_by_url(self, url): - pass -@ray.remote -class MockRequestErrorRolloutController(RolloutController): - def _get_worker_cls(self): - return ray.remote(MockRequestErrorRolloutWorker) - def deactivate_worker_by_url(self, url): - pass -@ray.remote -class MockClientErrorRolloutController(RolloutController): - def _get_worker_cls(self): - return ray.remote(MockClientErrorRolloutWorker) - def deactivate_worker_by_url(self, url): - pass -@ray.remote -class MockServerErrorRolloutController(RolloutController): - def _get_worker_cls(self): - return ray.remote(MockServerErrorRolloutWorker) - - def deactivate_worker_by_url(self, url): - pass - -class TestMockRollout(unittest.TestCase): - @classmethod - def setUpClass(cls): - os.environ["XTUNER_USE_FA3"] = "1" - - @classmethod - def tearDownClass(cls): - del os.environ["XTUNER_USE_FA3"] - - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - self.global_batch_size = 3 - self.max_prompt_length = 4096 - self.max_response_length = 128 - self.max_concurrent = 3 - self.max_retry_times = 3 - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - self.rollout_cfg = RolloutConfig( - env="test_mock_rollout", - model_path=MODEL_PATH, - model_name=os.path.basename(MODEL_PATH).lower(), - tokenizer_path=MODEL_PATH, - tensor_parallel_size=1, - context_length=self.max_prompt_length + self.max_response_length, - max_retry_per_worker=2, - worker_log_dir=self.worker_log_dir - ) - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) - - self.dataflow_cfg = DataFlowConfig( - max_concurrent=self.max_concurrent, - 
global_batch_size=self.global_batch_size, - max_retry_times=self.max_retry_times, - worker_log_dir=self.worker_log_dir - ) - train_dataset_cfg = [{ - "dataset": DatasetConfig(name="mock_data", anno_path=TRAIN_DATA_PATH), - "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length), - }] - dataloader_cfg = DataloaderConfig( - collator='fake_collator', - pack_level='none', - group_by_length=False, - ) - self.replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, - dataloader_cfg=dataloader_cfg, - tokenizer=tokenizer, - worker_log_dir=self.worker_log_dir - ) - - def tearDown(self): - ray.shutdown() - self.temp_dir.cleanup() - - async def _run_mock_test(self, mock_controller_cls, error_name, pg): - rollout_controller = mock_controller_cls.remote(self.rollout_cfg, pg) - self.test_env = SingleTurnEnvironment.remote("env", pg, self.rollout_cfg, rollout_controller=rollout_controller) - self.test_dataflow = DataFlow.remote("dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env) - - result = await self.test_dataflow.run.remote(num=3) - completed_rollouts = result["data_groups"] - status = await self.test_dataflow.get_replaybuffer_status.remote() - self.assertEqual(len(completed_rollouts), 0, f"[{error_name}] Expected no rollouts to complete successfully.") - self.assertEqual(status["remain_completed_samples_count"], 0, f"[{error_name}] Completed count in buffer should be 0.") - self.assertEqual(status["remain_aborted_samples_count"], 0, f"[{error_name}] Expected no rollouts to be interrupted.") - await self.test_env.shutdown.remote() - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_parallel_mock_rollout(self): - async def run_parallel(): - res_cfg_small = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=2, - num_cpus_per_worker=2, - ) - - pgs = [AutoAcceleratorWorkers.build_placement_group(res_cfg_small, name=f"pg_{i}") for i in range(4)] - await asyncio.gather(*[pg.ready() for pg in pgs]) - - tasks = [ - self._run_mock_test(MockTimeoutRolloutController, "timeout", pgs[0]), - self._run_mock_test(MockRequestErrorRolloutController, "request_error", pgs[1]), - self._run_mock_test(MockClientErrorRolloutController, "client_error", pgs[2]), - self._run_mock_test(MockServerErrorRolloutController, "server_error", pgs[3]), - ] - await asyncio.gather(*tasks) - - asyncio.run(run_parallel()) - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/tests/ray/test_rl_train_with_sft.py b/tests/ray/test_rl_train_with_sft.py deleted file mode 100644 index 2ddfcf01d1..0000000000 --- a/tests/ray/test_rl_train_with_sft.py +++ /dev/null @@ -1,179 +0,0 @@ -import os -import unittest -from transformers import AutoTokenizer -import shutil -import tempfile -import json -import torch -from xtuner.v1.data_proto.sequence_context import SequenceContext -import ray -from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.rl.base import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker -from xtuner.v1.rl.grpo.loss import GRPOLossConfig as LossConfig -from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig -from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig -from xtuner.v1.loss import CELossConfig -from xtuner.v1.datasets.config import DataloaderConfig, 
DatasetConfig -from xtuner.v1.train.trainer import LoadCheckpointConfig - -QWEN3_PATH = os.environ["QWEN3_PATH"] -ALPACA_PATH = os.environ["ALPACA_PATH"] - - -class TestRLTrainWithSFT(unittest.TestCase): - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - - resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_accelerators_per_worker=1, - num_cpus_per_worker=8, - num_workers=8, - cpu_memory_per_worker=16 * 1024 ** 3, # 16 GB - ) - - pg = AutoAcceleratorWorkers.build_placement_group(resources) - self.pg = pg - - self.temp_dir = tempfile.mkdtemp() - tokenizer = AutoTokenizer.from_pretrained(QWEN3_PATH, trust_remote_code=True) - self.tokenizer = tokenizer - self.prompt_repeat_k = 8 - file = './tests/ray/rollout_output.jsonl' - with open(file, 'r') as f: - data = [json.loads(line) for line in f] - data_groups = [data[i:i + self.prompt_repeat_k] for i in range(0, len(data), self.prompt_repeat_k)] - data_groups = data_groups[:8] - data_batches = [] - for group in data_groups: - prompt_ids = tokenizer(group[0]['prompt'], return_tensors='pt')['input_ids'].flatten().tolist() - rewards = [item['reward'] for item in group] - rewards = torch.tensor(rewards, dtype=torch.float32) - advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8) - - for i in range(self.prompt_repeat_k): - item = group[i] - response_ids = tokenizer(item['response'], return_tensors='pt')['input_ids'].flatten().tolist() - input_ids = prompt_ids + response_ids - shifted_labels = [-100] * (len(prompt_ids) - 1) + response_ids + [-100] - input_ids = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0) - shifted_labels = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0) - data_batches.append( - dict( - seq_ctx=SequenceContext.from_input_ids((input_ids,), device="cpu"), - shifted_labels=shifted_labels, - advantage=advantages[i].item(), - ) - ) - self.data_batches = data_batches - - def tearDown(self): - shutil.rmtree(self.temp_dir) - ray.shutdown() - - def build_train_controller(self): - model_cfg = Qwen3Dense8BConfig() - optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False) - fsdp_cfg: FSDPConfig = FSDPConfig( - cpu_offload=False, - ep_size=1, - ) - lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7) - - dataset_config = [] - _data_cfg = {"dataset": DatasetConfig(name='apach', - anno_path=ALPACA_PATH), - "tokenize_fn": OpenaiTokenizeFunctionConfig( - chat_template='qwen3', - max_length=32768 - ) - } - dataset_config.append(_data_cfg) - - sft_dataloader_cfg = DataloaderConfig( - dataset_config_list=dataset_config, - pack_max_length=32768, - pack_to_max_length=True, - num_workers=0, - ) - sft_global_batch_size = 8 - loss_reduction = "square" - sft_loss_cfg = CELossConfig(mode="chunk", chunk_size=1024, loss_reduction=loss_reduction) - - worker_cfg: WorkerConfig = WorkerConfig( - sft_dataloader_cfg=sft_dataloader_cfg, - sft_global_batch_size=sft_global_batch_size, - sft_loss_cfg=sft_loss_cfg, - seed=42, - model_cfg=model_cfg, - optim_cfg=optim_cfg, - loss_cfg=LossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="eager"), - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - load_from=QWEN3_PATH, - sp_size=1, - pack_max_length=8192, - ) - - TrainingWorker = ray.remote( - runtime_env={ - "env_vars": { - "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", - "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", - } - }, - 
)(BaseTrainingWorker) - train_workers, _ = AutoAcceleratorWorkers.from_placement_group( - TrainingWorker, worker_cfg, self.pg - ) - futures = [worker.test_all_reduce.remote() for worker in train_workers] - print(ray.get(futures)) - train_controller = TrainingController.remote( - workers=train_workers, - ) - ray.get(train_controller.__ray_ready__.remote()) - return train_controller - - def test_rl_train_with_sft(self): - train_controller = self.build_train_controller() - - ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=0)) - ray.get(train_controller.save.remote(os.path.join(self.temp_dir, "save_test"), no_save_optimizer=True)) - - log_infos = ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=1)) - efficient_attn_ratio_list = [] - for log_info in log_infos: - efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio']) - assert all([efficient_attn_ratio > 0 for efficient_attn_ratio in efficient_attn_ratio_list]) - - ray.kill(train_controller) - train_controller = self.build_train_controller() - load_checkpoint_cfg = LoadCheckpointConfig(checkpoint_path=os.path.join(self.temp_dir, "save_test"), - load_optimizer_states=False, - load_optimizer_args=False - ) - ray.get(train_controller.resume.remote(load_checkpoint_cfg)) - - log_infos = ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=1)) - new_efficient_attn_ratio_list = [] - for log_info in log_infos: - new_efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio']) - - efficient_attn_ratio_list.sort() - new_efficient_attn_ratio_list.sort() - self.assertEqual(efficient_attn_ratio_list, new_efficient_attn_ratio_list) diff --git a/tests/ray/test_rl_trainer.py b/tests/ray/test_rl_trainer.py deleted file mode 100644 index d008d465dc..0000000000 --- a/tests/ray/test_rl_trainer.py +++ /dev/null @@ -1,253 +0,0 @@ -import os -import tempfile -import unittest -from pathlib import Path - -import ray -import torch - -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model import get_model_config_from_hf -from xtuner.v1.ray.base import AcceleratorResourcesConfig, CPUResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.rl.base import WorkerConfig -from xtuner.v1.rl.grpo import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainer, RLTrainerConfig - - -MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] -TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] -TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] -resource_map = { - "npu": "NPU", - "cuda": "GPU", -} - - -class TestRLTrainer(unittest.TestCase): - @classmethod - def setUpClass(cls): - os.environ["XTUNER_USE_FA3"] = "1" - os.environ["LMD_SKIP_WARMUP"] = "1" - - @classmethod - def tearDownClass(cls): - del os.environ["XTUNER_USE_FA3"] - del os.environ["LMD_SKIP_WARMUP"] - - def init_traine_worker_config(self, train_optimizer_steps, pack_max_length): - model_cfg = get_model_config_from_hf(Path(MODEL_PATH)) - model_cfg.compile_cfg = False - optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, 
weight_decay=0.1, foreach=False) - loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - clip_ratio_c=10.0, - log_prob_diff_min=-20.0, - log_prob_diff_max=20.0, - ), - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.0, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, - ) - lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) - fsdp_cfg = FSDPConfig(cpu_offload=False, ep_size=1) - train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=MODEL_PATH, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, - ) - return train_worker_cfg - - def init_replay_buffer_config(self, max_prompt_length): - train_dataset_cfg = [ - { - "dataset": DatasetConfig(name="gsm8k", anno_path=TRAIN_DATA_PATH, sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length), - }, - ] - dataloader_cfg = DataloaderConfig( - collator="fake_collator", - pack_level="none", - group_by_length=False, - ) - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) - replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, - dataloader_cfg=dataloader_cfg, - tokenizer=tokenizer, - worker_log_dir=self.worker_log_dir, - ) - return replay_buffer_cfg - - def init_resources_config(self, num_workers, num_cpus_per_worker, cpu_memory_per_worker): - resources = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=num_workers, - num_cpus_per_worker=num_cpus_per_worker, - cpu_memory_per_worker=cpu_memory_per_worker, - ) - return resources - - def init_cpu_resources_config(self, num_cpus_per_worker, cpu_memory_per_worker): - cpu_resources = CPUResourcesConfig( - num_cpus_per_worker=num_cpus_per_worker, - cpu_memory_per_worker=cpu_memory_per_worker, - ) - return cpu_resources - - def init_rollout_config(self, max_prompt_length, max_response_length): - rollout_config = RolloutConfig( - env="test_rl_trainer", - model_path=MODEL_PATH, - worker_log_dir=self.worker_log_dir, - rollout_max_batch_size_per_instance=1024, - context_length=max_response_length + max_prompt_length, - ) - return rollout_config - - def init_dataflow_config(self, max_response_length, global_batch_size, prompt_repeat_k, enable_partial_rollout): - sample_params = SampleParams( - max_tokens=max_response_length, - ) - dataflow_config = DataFlowConfig( - env="test_rl_trainer", - global_batch_size=global_batch_size, - prompt_repeat_k=prompt_repeat_k, - worker_log_dir=self.worker_log_dir, - sample_params=sample_params, - enable_partial_rollout=enable_partial_rollout, - max_concurrent=1024, - ) - return dataflow_config - - def init_judger_config(self): - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") - judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config], worker_log_dir=self.worker_log_dir) - return judger_cfg - - def init_multi_judger_config(self): - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - - # A single GSM8KJudgerConfig can be used to create multiple judger instances - gsm8k_judger_config_1 = GSM8KJudgerConfig(judger_name="openai/gsm8k-1") - gsm8k_judger_config_2 = GSM8KJudgerConfig(judger_name="openai/gsm8k-2") - judger_cfg = JudgerConfig( - reward_judger_configs=[gsm8k_judger_config_1, gsm8k_judger_config_2], - worker_log_dir=self.worker_log_dir, - ) -
return judger_cfg - - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - - train_optimizer_steps = 2 - pack_max_length = 32768 - max_prompt_length = 2048 - max_response_length = 1024 - global_batch_size = 4 - prompt_repeat_k = 4 - enable_partial_rollout = False - - self.train_worker_cfg = self.init_traine_worker_config(train_optimizer_steps, pack_max_length) - self.replay_buffer_cfg = self.init_replay_buffer_config(max_prompt_length) - self.resources_cfg = self.init_resources_config( - num_workers=8, num_cpus_per_worker=8, cpu_memory_per_worker=8 * 1024**3 - ) - self.cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=1, cpu_memory_per_worker=1 * 1024**3) - self.rollout_config = self.init_rollout_config( - max_response_length=max_response_length, max_prompt_length=max_prompt_length - ) - self.dataflow_config = self.init_dataflow_config( - max_response_length=max_response_length, - global_batch_size=global_batch_size, - prompt_repeat_k=prompt_repeat_k, - enable_partial_rollout=enable_partial_rollout, - ) - self.judger_config = self.init_judger_config() - - def tearDown(self): - self.temp_dir.cleanup() - ray.shutdown() - - def test_rl_trainer(self): - multi_judger_config = self.init_multi_judger_config() - cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=2, cpu_memory_per_worker=2 * 1024**3) - trainer_config = RLTrainerConfig( - load_from=MODEL_PATH, - resources=self.resources_cfg, - cpu_resources=cpu_resources, - rollout_config=self.rollout_config, - dataflow_config=self.dataflow_config, - judger_config=multi_judger_config, - replay_buffer_config=self.replay_buffer_cfg, - train_worker_config=self.train_worker_cfg, - work_dir=self.worker_log_dir, - tokenizer_path=MODEL_PATH, - total_epochs=1, - rollout_steps=1, - ) - trainer = RLTrainer.from_config(trainer_config) - self.assertIsNotNone(trainer, "Trainer should be created successfully") - try: - trainer.fit() - except Exception as e: - self.fail(f"trainer.fit() raised unexpected exception: {e}") - # assure all writers are closed before checking log files - del trainer - log_files = list(Path(self.worker_log_dir).rglob("*.log")) - self.assertGreater(len(log_files), 0, "Should generate log files") - trajectory_files = list(Path(self.worker_log_dir).rglob("*_trajectory.jsonl")) - self.assertGreater(len(trajectory_files), 0, "Should generate trajectory files") - - def test_judger_cpu_pg_creation_with_error(self): - """Test RLTrainer judger_cpu_pg creation.""" - multi_judger_config = self.init_multi_judger_config() - # error resource with multi-judger - cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=1, cpu_memory_per_worker=1 * 1024**3) - trainer_config = RLTrainerConfig( - load_from=MODEL_PATH, - resources=self.resources_cfg, - cpu_resources=cpu_resources, - rollout_config=self.rollout_config, - dataflow_config=self.dataflow_config, - judger_config=multi_judger_config, - replay_buffer_config=self.replay_buffer_cfg, - train_worker_config=self.train_worker_cfg, - work_dir=self.worker_log_dir, - tokenizer_path=MODEL_PATH, - total_epochs=1, - rollout_steps=1, - ) - with self.assertRaises(AssertionError) as cm: - trainer = RLTrainer.from_config(trainer_config) - - print(f"Expected AssertionError caught: {cm.exception}") - -if __name__ == "__main__": - unittest.main() diff --git a/tests/ray/test_rollout.py b/tests/ray/test_rollout.py deleted file mode 100644 index 
31f8542d3c..0000000000 --- a/tests/ray/test_rollout.py +++ /dev/null @@ -1,401 +0,0 @@ -import os -import subprocess -from functools import wraps -import unittest -import tempfile -import ray -import torch -from pathlib import Path -from transformers import AutoTokenizer -import tempfile -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.ray.environment import SingleTurnEnvironment -from xtuner.v1.ray.rollout import RolloutController -from xtuner.v1.ray.judger import JudgerController -from xtuner.v1.datasets import RLTokenizeFnConfig, build_datasets, build_dataloader -from xtuner.v1.datasets.config import ( - DataloaderConfig, - DatasetConfig, -) -import asyncio - -TEST_TEXT_MESSAGES=[{"role": "user", "content": "Hello!"}] -MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] -MOE_MODEL_PATH = os.environ["QWEN3_MOE_PATH"] -TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] -TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] -resource_map = { - "npu": "NPU", - "cuda": "GPU", -} -class TestRollout(unittest.TestCase): - - @classmethod - def setUpClass(cls) -> None: - os.environ["XTUNER_USE_FA3"] = "1" - os.environ["LMD_SKIP_WARMUP"] = "1" - - @classmethod - def tearDownClass(cls) -> None: - del os.environ["XTUNER_USE_FA3"] - del os.environ["LMD_SKIP_WARMUP"] - - def init_config(self): - self.resources_cfg = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=8, - num_cpus_per_worker=8, - cpu_memory_per_worker=16 * 1024**3, # 16 GB - ) - self.max_prompt_length = 512 - self.max_response_length = 1024 - self.context_length = self.max_prompt_length + self.max_response_length - from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") - self.judger_cfg = JudgerConfig( - reward_judger_configs=[gsm8k_judger_config], - worker_log_dir=self.worker_log_dir, - ) - self.dataflow_cfg = DataFlowConfig( - env="test", - prompt_repeat_k=1, - global_batch_size=1, - enable_partial_rollout=0, - max_retry_times=1, - worker_log_dir=self.worker_log_dir, - ) - self.train_dataset_cfg = [ - { - "dataset": DatasetConfig(name="gsm8k", - anno_path=TRAIN_DATA_PATH, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length), - }, - ] - self.dataloader_cfg = DataloaderConfig( - collator='fake_collator', - pack_level='none', - group_by_length=False, - ) - self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) - self.replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=self.train_dataset_cfg, - dataloader_cfg=self.dataloader_cfg, - tokenizer=self.tokenizer, - worker_log_dir=self.worker_log_dir, - ) - - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - self.data_path = TRAIN_DATA_PATH - self.model_path = MODEL_PATH - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - self.init_config() - - def tearDown(self): - ray.shutdown() - # When lmdeploy enables ep>1, it uses deep_ep. Implicit buffer destruction can leave some Ray actors stuck. - # As a workaround, use pkill to clean up ray::WorkerWrapper processes after closing the Ray cluster connection.
- # TODO(chenchiyu): add explicit deep_ep destroy in lmdeploy. - self._cleanup_lmdeploy_ray_worker_wrapper() - self.temp_dir.cleanup() - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_parallel_rollout(self): - resource_config = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=4, - num_cpus_per_worker=4, - cpu_memory_per_worker=8 * 1024**3, # 8 GB - ) - pg1 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="tp_pg") - pg2 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="ep_pg") - dense_model_path = MODEL_PATH - moe_model_path = MOE_MODEL_PATH - dist_port_base = 38000 - async def run_both(): - return await asyncio.gather( - self._run_rollout(model_path=dense_model_path, tp_size=4, ep_size=1, pg=pg1, dist_port_base=dist_port_base), - self._run_rollout(model_path=moe_model_path, tp_size=1, ep_size=4, pg=pg2, dist_port_base=dist_port_base + 1024 * 4), - return_exceptions=False - ) - - asyncio.run(run_both()) - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_parallel_model_save_and_resume(self): - resource_config = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=4, - num_cpus_per_worker=4, - cpu_memory_per_worker=8 * 1024**3, # 8 GB - ) - pg1 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="dense_pg") - pg2 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="moe_pg") - - async def run_both(): - return await asyncio.wait_for( - asyncio.gather( - self._run_dense_save_resume_sync_async(pg1), - self._run_moe_save_resume_with_r3(pg2), - return_exceptions=False - ), - timeout=300 - ) - try: - asyncio.run(run_both()) - except asyncio.TimeoutError: - self.fail("test_parallel_model_save_and_resume timed out after 300s") - - def _cleanup_lmdeploy_ray_worker_wrapper(self): - try: - result = subprocess.run(["pkill", "-f", "ray::RayWorkerWrapper*"], capture_output=True, text=True, timeout=10) - if result.returncode != 0: - print(f"pkill command failed with return code {result.returncode}: {result.stderr}." - " Maybe no lmdeploy ray::RayWorkerWrapper processes found.") - except Exception as e: - print(f"Error stopping ray::RayWorkerWrapper cluster: {e}") - - async def _run_rollout(self, model_path, tp_size, ep_size, pg, dist_port_base): - rollout_config = RolloutConfig( - env="test_rollout", - model_path=model_path, - model_name=os.path.basename(model_path).lower(), - tokenizer_path=model_path, - tensor_parallel_size=tp_size, - expert_parallel_size=ep_size, - context_length=self.context_length, - worker_log_dir=self.worker_log_dir, - dist_port_base=dist_port_base, - ) - rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) - try: - result = await asyncio.wait_for(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES), timeout=300) - self.assertEqual(result.finish_reason, "stop") - except asyncio.TimeoutError: - self.fail("TP Rollout timed out!") - finally: - await asyncio.wait_for(rollout_controller.shutdown.remote(), timeout=300) - - async def _run_dataflow_save_resume_test(self, test_env, dataflow_cfg: DataFlowConfig, replay_buffer_cfg: ReplayBufferConfig): - """ - Generic driver for dataflow save/resume tests. - """ - # 1.
Initialize Environment and DataFlow - is_partial_rollout = dataflow_cfg.enable_partial_rollout == 1 - test_flow = DataFlow.remote("test_env", dataflow_cfg, replay_buffer_cfg, test_env) - - # 2. Initial Run - await test_flow.run.remote() - - # Capture status before saving (critical for partial rollout consistency check) - rl_status_before_save = await test_flow.get_replaybuffer_status.remote() - - # 3. Save - save_dir = Path(self.temp_dir.name) / 'checkpoints' / 'ckpt-step-2' - save_dir.mkdir(parents=True, exist_ok=True) - await test_flow.save.remote(save_dir) - - # Define run logic based on mode - async def run_continuation(status_ref): - if is_partial_rollout: - remain = status_ref["remain_aborted_samples_count"] + status_ref["remain_completed_samples_count"] - # Finish the remaining paused samples - result = await test_flow.run.remote(num=remain, enable_partial_rollout=0) - return result["data_groups"] - else: - # Normal run - result = await test_flow.run.remote() - return result["data_groups"] - - # continue running after save - responses_old = await run_continuation(rl_status_before_save) - rb_status_old = await test_flow.get_replaybuffer_status.remote() - - # resume from saved checkpoint - await test_flow.resume.remote(save_dir) - rl_status_resume = await test_flow.get_replaybuffer_status.remote() - responses_new = await run_continuation(rl_status_resume) - rb_status_new = await test_flow.get_replaybuffer_status.remote() - - # Compare Data - ids_old = self._get_sorted_input_ids(responses_old) - ids_new = self._get_sorted_input_ids(responses_new) - self.assertEqual(ids_old, ids_new) - - # Compare ReplayBuffer Status (Old run vs New run) - for key in rb_status_old: - self.assertEqual(rb_status_old[key], rb_status_new[key]) - - # For partial rollout, verify the resumed state matches the saved state - if is_partial_rollout: - for key in rl_status_before_save: - self.assertEqual(rl_status_before_save[key], rl_status_resume[key]) - - async def _run_dense_save_resume_sync_async(self, pg): - model_path = MODEL_PATH - worker_log_dir = os.path.join(self.worker_log_dir, "test_dense") - rollout_config = RolloutConfig( - env="test_rollout", - model_path=model_path, - model_name=os.path.basename(model_path).lower(), - tokenizer_path=model_path, - context_length=self.context_length, - worker_log_dir=worker_log_dir, - dist_port_base=37000, - ) - test_env = SingleTurnEnvironment.remote( - "test_env", - pg, - rollout_cfg=rollout_config, - ) - sync_dataflow_cfg = DataFlowConfig( - env="test", - prompt_repeat_k=2, - global_batch_size=2, - enable_partial_rollout=0, - max_concurrent=2, - max_retry_times=1, - worker_log_dir=worker_log_dir, - ) - async_dataflow_cfg = DataFlowConfig( - env="test", - prompt_repeat_k=2, - global_batch_size=2, - enable_partial_rollout=1, - staleness_threshold=1, - max_retry_times=1, - worker_log_dir=self.worker_log_dir, - ) - replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=self.train_dataset_cfg, - dataloader_cfg=self.dataloader_cfg, - tokenizer=self.tokenizer, - worker_log_dir=worker_log_dir, - ) - await self._run_dataflow_save_resume_test(test_env, sync_dataflow_cfg, replay_buffer_cfg) - await self._run_dataflow_save_resume_test(test_env, async_dataflow_cfg, replay_buffer_cfg) - - async def _run_moe_save_resume_with_r3(self, pg): - model_path = MOE_MODEL_PATH - worker_log_dir = os.path.join(self.worker_log_dir, "test_moe_r3") - rollout_config = RolloutConfig( - env="test_rollout", - model_path=model_path, - model_name=os.path.basename(model_path).lower(), -
tokenizer_path=model_path, - expert_parallel_size=2, - context_length=self.context_length, - worker_log_dir=worker_log_dir, - dist_port_base=36000, - ) - test_env = SingleTurnEnvironment.remote( - "test_env", - pg, - rollout_cfg=rollout_config, - ) - async_dataflow_cfg = DataFlowConfig( - env="test", - prompt_repeat_k=2, - global_batch_size=2, - enable_partial_rollout=1, - max_concurrent=4, - max_retry_times=1, - worker_log_dir=worker_log_dir, - ) - replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=self.train_dataset_cfg, - dataloader_cfg=self.dataloader_cfg, - tokenizer=self.tokenizer, - worker_log_dir=worker_log_dir, - ) - await self._run_dataflow_save_resume_test(test_env, async_dataflow_cfg, replay_buffer_cfg) - - def _get_sorted_input_ids(self, responses): - """Helper to extract and sort input_ids from responses.""" - all_ids = [] - for data_items in responses[0]: - for data_item in data_items: - all_ids.extend(data_item.data.input_ids) - all_ids.sort() - return all_ids - - @unittest.skip("skip lmdeploy turbomind generate test due to ci environment issue") - def test_lmdeploy_turbomind_generate(self): - from xtuner.v1.ray.rollout import LMDeployWorker - rollout_config = RolloutConfig( - env="test_rollout", - model_path=MODEL_PATH, - model_name=os.path.basename(MODEL_PATH).lower(), - tokenizer_path=MODEL_PATH, - context_length=self.context_length, - worker_log_dir=self.worker_log_dir, - extra_rollout_config={"lmdeploy_backend": "turbomind"}, - ) - sample_params = SampleParams(temperature=0.0) - pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) - rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) # type: ignore[attr-defined] - res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) - res2 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) - self.assertEqual(res1, res2, f"res1 != res2, res1={res1}, res2={res2}") - ray.get(rollout_controller.shutdown.remote(), timeout=300) - - @unittest.skipIf(os.environ.get("XTUNER_USE_SGLANG", "0") == "0", "sglang backend is not enabled") - def test_sglang_generate(self): - from xtuner.v1.ray.rollout import SGLangWorker - self.rollout_cfg.launch_server_method="multiprocessing" - rollout_config = RolloutConfig( - env="test_rollout", - model_path=MODEL_PATH, - model_name=os.path.basename(MODEL_PATH).lower(), - tokenizer_path=MODEL_PATH, - context_length=self.context_length, - worker_log_dir=self.worker_log_dir, - launch_server_method="multiprocessing" - ) - sample_params = SampleParams(temperature=0.0) - pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) - rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) # type: ignore[attr-defined] - res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) - self.assertEqual(res1.finish_reason, "stop") - print("Response from SGLang infer:", res1) - ray.get(rollout_controller.shutdown.remote(), timeout=300) - - @unittest.skipIf(os.environ.get("XTUNER_USE_SGLANG", "0") == "0", "sglang backend is not enabled") - def test_sglang_dataflow(self): - self.dataflow_cfg.enable_partial_rollout = 0 - rollout_config = RolloutConfig( - env="test_rollout", - model_path=MODEL_PATH, - model_name=os.path.basename(MODEL_PATH).lower(), - tokenizer_path=MODEL_PATH, - context_length=self.context_length, - worker_log_dir=self.worker_log_dir, - launch_server_method="multiprocessing" - ) - pg =
AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) - test_env = SingleTurnEnvironment.remote( - "test_env", - pg, - rollout_cfg=rollout_config, - ) - test_flow = DataFlow.remote("test_env", - self.dataflow_cfg, - self.replay_buffer_cfg, - test_env - ) - responses = ray.get(test_flow.run.remote(), timeout=300)["data_groups"] - finished_samples_count = sum(1 for data in responses for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length") - self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size) - ray.get(test_env.shutdown.remote(), timeout=300) - print("responses: ", responses) - -if __name__ == "__main__": - unittest.main() diff --git a/tests/ray/test_vl_rollout.py b/tests/ray/test_vl_rollout.py deleted file mode 100644 index 81621e9d21..0000000000 --- a/tests/ray/test_vl_rollout.py +++ /dev/null @@ -1,207 +0,0 @@ -import os -import subprocess -from functools import wraps -import unittest -import tempfile -import ray -import torch -from pathlib import Path -from transformers import AutoTokenizer -import tempfile -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.judger.controller import JudgerConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.ray.environment import SingleTurnEnvironment -from xtuner.v1.ray.rollout import RolloutController -from xtuner.v1.ray.judger import JudgerController -from xtuner.v1.datasets import RLTokenizeFnConfig, build_datasets, build_dataloader -from xtuner.v1.datasets.config import ( - DataloaderConfig, - DatasetConfig, -) - -MODEL_PATH=os.getenv("QWEN3_VL_DENSE_PATH") -TRAIN_DATA_PATH=os.getenv("GEO3K_TRAIN_DATA_PATH") -MEDIA_ROOT=os.getenv("GEO3K_MEDIA_ROOT") - -resource_map = { - "npu": "NPU", - "cuda": "GPU", -} -class TestRollout(unittest.TestCase): - - @classmethod - def setUpClass(cls) -> None: - os.environ["XTUNER_USE_FA3"] = "1" - os.environ["LMD_SKIP_WARMUP"] = "1" - - @classmethod - def tearDownClass(cls) -> None: - del os.environ["XTUNER_USE_FA3"] - del os.environ["LMD_SKIP_WARMUP"] - - def init_config(self): - self.resources_cfg = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=8, - num_cpus_per_worker=8, - cpu_memory_per_worker=16 * 1024**3, # 16 GB - ) - self.max_prompt_length = 2048 - self.max_response_length = 2048 - self.rollout_cfg = RolloutConfig( - env="test_rollout", - model_path=self.model_path, - model_name=os.path.basename(self.model_path).lower(), - tokenizer_path=self.model_path, - rollout_cross_node_comm=False, - tensor_parallel_size=2, - expert_parallel_size=1, - gpus_per_node=8, # gpu: 8, npu: 16 - dtype="bfloat16", - launch_server_method="ray", - context_length=self.max_prompt_length + self.max_response_length, - worker_log_dir=self.worker_log_dir, - ) - from xtuner.v1.ray.judger.geo3k import GEO3KJudgerConfig - geo3k_judger_config = GEO3KJudgerConfig() - self.judger_cfg = JudgerConfig(reward_judger_configs=[geo3k_judger_config]) - - self.dataflow_cfg = DataFlowConfig( - env="test", - prompt_repeat_k=2, - global_batch_size=2, - enable_partial_rollout=0, - max_retry_times=1, - worker_log_dir=self.worker_log_dir, - ) - self.training_sample_params = SampleParams( - max_tokens=self.max_response_length, - ) - self.evaluation_sample_params 
= SampleParams( - max_tokens=self.max_response_length, - top_p=1.0, - temperature=0.0, - top_k=1, - ) - self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) - from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig - tokenize_fn_cfg = Qwen3VLTokenizeFnConfig(processor_path=self.model_path) - train_dataset_cfg = [ - { - "dataset": DatasetConfig(name="geo3k", - anno_path=self.data_path, - class_name='VLMJsonlDataset', - media_root=self.media_root, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length, - tokenize_fn_cfg=tokenize_fn_cfg), - } - ] - dataloader_config = DataloaderConfig(num_workers=8, - collator="fake_collator", - pack_level="none") - - self.replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, - dataloader_cfg=dataloader_config, - tokenizer=self.tokenizer, - ) - - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - self.data_path = TRAIN_DATA_PATH - self.model_path = MODEL_PATH - self.media_root = MEDIA_ROOT - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - self.init_config() - self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) - - def tearDown(self): - ray.shutdown() - # When lmdeploy enables ep>1, it uses deep_ep. Implicit buffer destruction can leave some Ray actors stuck. - # As a workaround, use pkill to clean up ray::WorkerWrapper processes after closing the Ray cluster connection. - # TODO(chenchiyu): add explicit deep_ep destroy in lmdeploy. - self._cleanup_lmdeploy_ray_worker_wrapper() - self.temp_dir.cleanup() - - def _cleanup_lmdeploy_ray_worker_wrapper(self): - try: - result = subprocess.run(["pkill", "-f", "ray::RayWorkerWrapper*"], capture_output=True, text=True, timeout=10) - if result.returncode != 0: - print(f"pkill command failed with return code {result.returncode}: {result.stderr}."
- " Maybe no lmdeploy ray::RayWorkerWrapper processes found.") - except Exception as e: - print(f"Error stopping ray::RayWorkerWrapper cluster: {e}") - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_vl_resume_with_partial_rollout(self): - rollout_cfg = self.rollout_cfg - # rollout_cfg.enable_return_routed_experts = True - self.test_env = SingleTurnEnvironment.remote( - "test_env", - self.pg, - rollout_cfg=rollout_cfg, - ) - dataflow_cfg = self.dataflow_cfg - dataflow_cfg.global_batch_size = 2 - dataflow_cfg.staleness_threshold = 1 - dataflow_cfg.enable_partial_rollout = 1 - self.test_flow = DataFlow.remote("test_env", - dataflow_cfg, - self.replay_buffer_cfg, - self.test_env - ) - ray.get(self.test_flow.run.remote(), timeout=300) - rl_status_save = ray.get(self.test_flow.get_replaybuffer_status.remote()) - save_dir = Path(self.temp_dir.name) / 'checkpoints' / 'ckpt-step-2' - save_dir.mkdir(parents=True, exist_ok=True) - - ray.get(self.test_flow.save.remote(save_dir)) - remain_paused_samples_old = rl_status_save["remain_aborted_samples_count"] + rl_status_save["remain_completed_samples_count"] - responses_old = ray.get(self.test_flow.run.remote(num=remain_paused_samples_old, staleness_threshold=0), timeout=300) - rb_status_old = ray.get(self.test_flow.get_replaybuffer_status.remote()) - - mm_info_old = [] - for multimodal_train_infos in responses_old["mm_train_infos"]: - image_grid_thw = multimodal_train_infos["image_grid_thw"].numpy().flatten() - mm_info_old.extend(image_grid_thw) - - ray.get(self.test_flow.resume.remote(save_dir)) - rl_status_resume = ray.get(self.test_flow.get_replaybuffer_status.remote()) - remain_paused_samples_new = rl_status_resume["remain_aborted_samples_count"] + rl_status_resume["remain_completed_samples_count"] - responses_new = ray.get(self.test_flow.run.remote(num=remain_paused_samples_new, staleness_threshold=0), timeout=300) - rb_status_new = ray.get(self.test_flow.get_replaybuffer_status.remote()) - - mm_info_new = [] - for multimodal_train_infos in responses_new["mm_train_infos"]: - image_grid_thw = multimodal_train_infos["image_grid_thw"].numpy().flatten() - mm_info_new.extend(image_grid_thw) - - all_train_prompt_ids_old = [] - for data_items in responses_old["data_groups"]: - for data_item in data_items: - all_train_prompt_ids_old.extend(data_item.data.input_ids) - - all_train_prompt_ids_new = [] - for data_items in responses_new["data_groups"]: - for data_item in data_items: - all_train_prompt_ids_new.extend(data_item.data.input_ids) - - all_train_prompt_ids_old.sort() - all_train_prompt_ids_new.sort() - mm_info_old.sort() - mm_info_new.sort() - self.assertEqual(all_train_prompt_ids_old, all_train_prompt_ids_new) - self.assertEqual(mm_info_old, mm_info_new) - for key in rb_status_old: - self.assertEqual(rb_status_old[key], rb_status_new[key]) - for key in rl_status_save: - self.assertEqual(rl_status_save[key], rl_status_resume[key]) - ray.get(self.test_env.shutdown.remote(), timeout=300) - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/tests/ray/rollout_output.jsonl b/tests/rl/rollout_output.jsonl similarity index 100% rename from tests/ray/rollout_output.jsonl rename to tests/rl/rollout_output.jsonl diff --git a/tests/rl/test_agent_loop.py b/tests/rl/test_agent_loop.py new file mode 100644 index 0000000000..1395b358d5 --- /dev/null +++ b/tests/rl/test_agent_loop.py @@ -0,0 +1,236 @@ +import os +import unittest +import copy +import ray +import 
tempfile +import torch +from transformers import AutoTokenizer +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers +from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig +from xtuner.v1.rl.agent_loop_manager import ( + AgentLoopManagerConfig, + SamplerConfig, + SyncProduceStrategyConfig, + TaskSpecConfig, +) +from xtuner.v1.data_proto.rl_data import RolloutState, Status, SampleParams +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig + +MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] +MOE_MODEL_PATH = os.environ["QWEN3_MOE_PATH"] +TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] +TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] +TEST_DIST_PORT_BASE = int(os.environ.get("XTUNER_DIST_PORT_BASE", "35000")) +TEST_NUM_WORKERS = int(os.environ.get("XTUNER_TEST_NUM_WORKERS", "1")) +FAKE_INPUT_ITEM = RolloutState( + message=[{ + 'role': 'user', + 'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let\'s think step by step and output the final answer after "####"' + }], + reward_model={'ground_truth': '72', 'style': 'rule'}, +) +resource_map = { + "npu": "NPU", + "cuda": "GPU", +} + +class TestAgentLoop(unittest.IsolatedAsyncioTestCase): + @classmethod + def setUpClass(cls) -> None: + os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" + + @classmethod + def tearDownClass(cls) -> None: + del os.environ["XTUNER_USE_FA3"] + del os.environ["LMD_SKIP_WARMUP"] + + def init_config(self): + self.resources_cfg = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=TEST_NUM_WORKERS, + num_cpus_per_worker=8, + cpu_memory_per_worker=16 * 1024**3, # 16 GB + ) + self.max_prompt_length = 512 + self.max_response_length = 1024 + self.context_length = self.max_prompt_length + self.max_response_length + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) + + def setUp(self): + ray.init(num_cpus=80, ignore_reinit_error=True) + self.data_path = TRAIN_DATA_PATH + self.model_path = MODEL_PATH + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + + def tearDown(self): + ray.shutdown() + self.temp_dir.cleanup() + + async def test_gsm8k_agent_loop(self): + # 1. Initialize the configs + self.init_config() + rollout_config = RolloutConfig( + env="test_agent_loop", + model_path=self.model_path, + model_name=os.path.basename(self.model_path).lower(), + tokenizer_path=self.model_path, + context_length=self.context_length, + dist_port_base=TEST_DIST_PORT_BASE, + tensor_parallel_size=TEST_NUM_WORKERS, + worker_log_dir=self.worker_log_dir, + ) + judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + agent_loop_cfg = SingleTurnAgentLoopConfig( + hf_checkpoint=self.model_path, + sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0), + ) + # 2. Create the rollout_controller + pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + # 3.
Build the AgentLoop + agent_loop = agent_loop_cfg.build( + rollout_controller=rollout_controller, + judger=judger_config.build(), + ) + # 4. Construct the input data + prompt_repeat_k = 4 + rollout_state = FAKE_INPUT_ITEM + group_in_rollout_state = [FAKE_INPUT_ITEM] * prompt_repeat_k + # 5. Run generate_group and generate_sample + group_rollout_state = await agent_loop.generate_group(group_in_rollout_state) + single_rollout_state = await agent_loop.generate_sample(rollout_state) + # 6. Verify the results + self.assertEqual(len(group_rollout_state), 4) + for state in group_rollout_state: + self.assertEqual(state.status, Status.COMPLETED) + self.assertGreater(len(state.response_ids), 0) + self.assertEqual(state.reward["score"], 1) + self.assertEqual(single_rollout_state.status, Status.COMPLETED) + self.assertGreater(len(single_rollout_state.response_ids), 0) + self.assertEqual(single_rollout_state.reward["score"], 1) + + async def test_gsm8k_agent_loop_with_ray_actor_judger(self): + self.init_config() + rollout_config = RolloutConfig( + env="test_agent_loop_ray_actor", + model_path=self.model_path, + model_name=os.path.basename(self.model_path).lower(), + tokenizer_path=self.model_path, + context_length=self.context_length, + dist_port_base=TEST_DIST_PORT_BASE, + tensor_parallel_size=TEST_NUM_WORKERS, + worker_log_dir=self.worker_log_dir, + ) + judger_config = GSM8KJudgerConfig( + judger_name="openai/gsm8k", + num_ray_actors=1, + num_cpus_per_actor=1, + ) + agent_loop_cfg = SingleTurnAgentLoopConfig( + hf_checkpoint=self.model_path, + sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0), + num_ray_actors=1, + num_cpus=1, + ) + + pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + agent_loop = agent_loop_cfg.build( + rollout_controller=rollout_controller, + judger=judger_config.build(), + ) + + prompt_repeat_k = 2 + rollout_state = copy.deepcopy(FAKE_INPUT_ITEM) + group_in_rollout_state = [copy.deepcopy(FAKE_INPUT_ITEM) for _ in range(prompt_repeat_k)] + + group_rollout_state = await agent_loop.generate_group.remote(group_in_rollout_state) + single_rollout_state = await agent_loop.generate_sample.remote(rollout_state) + + self.assertEqual(len(group_rollout_state), prompt_repeat_k) + for state in group_rollout_state: + self.assertEqual(state.status, Status.COMPLETED) + self.assertGreater(len(state.response_ids), 0) + self.assertEqual(state.reward["score"], 1) + self.assertEqual(single_rollout_state.status, Status.COMPLETED) + self.assertGreater(len(single_rollout_state.response_ids), 0) + self.assertEqual(single_rollout_state.reward["score"], 1)
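The two tests above drive the same loop through two call conventions; the only visible difference at the call sites is whether `build()` hands back an in-process object or a Ray actor handle. A minimal sketch of that difference, reusing the fixtures constructed in the tests above (that `num_ray_actors` switches `build()` to returning an actor handle is inferred from these call sites, not from documented API):

```python
async def drive_agent_loop() -> None:
    # Sketch only: rollout_controller and judger_config are assumed to be
    # built exactly as in the tests above.

    # In-process loop: build() returns a plain object whose methods are
    # coroutines, so they are awaited directly.
    local_loop = SingleTurnAgentLoopConfig(
        hf_checkpoint=MODEL_PATH,
        sample_params=SampleParams(max_tokens=1024, temperature=0.0),
    ).build(rollout_controller=rollout_controller, judger=judger_config.build())
    group = await local_loop.generate_group([FAKE_INPUT_ITEM] * 4)

    # Ray-actor loop: with num_ray_actors set, the same methods are called
    # through .remote() and the returned ObjectRefs are awaited.
    actor_loop = SingleTurnAgentLoopConfig(
        hf_checkpoint=MODEL_PATH,
        sample_params=SampleParams(max_tokens=1024, temperature=0.0),
        num_ray_actors=1,
        num_cpus=1,
    ).build(rollout_controller=rollout_controller, judger=judger_config.build())
    group = await actor_loop.generate_group.remote([FAKE_INPUT_ITEM] * 4)
    print(len(group), group[0].status)
```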
+ + async def test_gsm8k_agent_loop_manager(self): + # 1. Initialize the configs + self.init_config() + rollout_config = RolloutConfig( + env="test_agent_loop", + model_path=self.model_path, + model_name=os.path.basename(self.model_path).lower(), + tokenizer_path=self.model_path, + context_length=self.context_length, + dist_port_base=TEST_DIST_PORT_BASE, + tensor_parallel_size=TEST_NUM_WORKERS, + worker_log_dir=self.worker_log_dir, + ) + judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + agent_loop_cfg = SingleTurnAgentLoopConfig( + hf_checkpoint=self.model_path, + sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0), + ) + sampler_config = SamplerConfig( + dataloader_cfg=DataloaderConfig( + dataset_config_list=[ + { + "dataset": DatasetConfig(name="gsm8k", + anno_path=TRAIN_DATA_PATH, + sample_ratio=1.0), + "tokenize_fn": RLTextTokenizeFnConfig(max_length=self.max_prompt_length), + }, + ], + collator='fake_collator', + pack_level='none', + group_by_length=False, + ), + prompt_repeat_k=2, + ) + agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=[ + TaskSpecConfig( + task_name="test_gsm8k", + agent_loop_config=agent_loop_cfg, + judger_config=judger_config, + produce_strategy_config=SyncProduceStrategyConfig(), + sampler_config=sampler_config, + ) + ], + ) + # 2. Create the rollout_controller + pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + # 3. Build the AgentLoopManager + replay_buffer_cfg = SyncReplayBufferConfig() + replay_buffer = replay_buffer_cfg.build() + agent_loop_manager = agent_loop_manager_cfg.build( + rollout_controller=rollout_controller, + tokenizer=self.tokenizer, + replay_buffer=replay_buffer, + ) + # 4. Run produce_batch + result = await agent_loop_manager.produce_batch(batch_size=4, train_step=0, model_step=0) + batch_rollout_states = result.rollout_states + # 5. Verify the results + self.assertEqual(len(batch_rollout_states), 4) + for group_state in batch_rollout_states: + self.assertEqual(len(group_state), 2) + group_message = group_state[0].message + for state in group_state: + self.assertEqual(state.status, Status.COMPLETED) + self.assertGreater(len(state.response_ids), 0) + self.assertEqual(state.message, group_message) + +if __name__ == "__main__": + unittest.main()
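For reference, the produce/consume cycle exercised by `test_gsm8k_agent_loop_manager` can also be driven from a plain asyncio entry point. A minimal sketch under the same environment variables; `build_manager()` is a hypothetical helper standing in for steps 1-3 of the test above:

```python
import asyncio

import ray


async def main() -> None:
    ray.init(ignore_reinit_error=True)
    # build_manager() is hypothetical: it bundles the config, placement-group,
    # RolloutController, and AgentLoopManager construction shown in the test.
    manager = build_manager()
    try:
        result = await manager.produce_batch(batch_size=4, train_step=0, model_step=0)
        for group in result.rollout_states:
            # Each group carries prompt_repeat_k completed samples for one prompt.
            print(len(group), [state.status for state in group])
    finally:
        ray.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```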
diff --git a/tests/rl/test_agent_loop_package_layout.py b/tests/rl/test_agent_loop_package_layout.py new file mode 100644 index 0000000000..c4c4b97196 --- /dev/null +++ b/tests/rl/test_agent_loop_package_layout.py @@ -0,0 +1,18 @@ +from xtuner.v1.rl import agent_loop, agent_loop_manager + + +def test_agent_loop_package_only_exports_loop_objects(): + # Make the package boundary explicit: agent_loop only holds the single-sample rollout loop and no longer re-exports manager objects for backward compatibility. + assert hasattr(agent_loop, "AgentLoop") + assert hasattr(agent_loop, "SingleTurnAgentLoopConfig") + assert not hasattr(agent_loop, "AgentLoopManagerConfig") + assert not hasattr(agent_loop, "SamplerConfig") + assert not hasattr(agent_loop, "ProduceBatchStatus") + + +def test_agent_loop_manager_package_exports_manager_objects(): + # The agent_loop_manager package is the single home for batch scheduling, sampling, and producer-related types. + assert hasattr(agent_loop_manager, "AgentLoopManagerConfig") + assert hasattr(agent_loop_manager, "SamplerConfig") + assert hasattr(agent_loop_manager, "ProduceBatchStatus") + assert hasattr(agent_loop_manager, "ProduceBatchResult") diff --git a/tests/rl/test_agent_loop_utils.py b/tests/rl/test_agent_loop_utils.py new file mode 100644 index 0000000000..75ee89c9c2 --- /dev/null +++ b/tests/rl/test_agent_loop_utils.py @@ -0,0 +1,120 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status, refresh_seq_staleness +from xtuner.v1.rl.agent_loop.single_turn_agent_loop import SingleTurnAgentLoop +from xtuner.v1.rl.agent_loop.utils import PartialRolloutHandler + + +def _make_rollout_state( + response_ids: list[int], + response_model_steps: list[int] | None = None, + seq_staleness: int = 0, + status: Status = Status.ABORTED, + extra_fields: dict | None = None, +): + return RolloutState( + uid=1, + message=[{"role": "user", "content": "hello"}], + prompt_ids=[101, 102], + response_ids=response_ids, + response="resp", + logprobs=[0.0] * len(response_ids), + response_mask=[1] * len(response_ids), + response_model_steps=response_model_steps, + seq_staleness=seq_staleness, + sample_params=SampleParams(max_tokens=8), + status=status, + extra_fields=extra_fields or {}, + ) + + +class TestAgentLoopUtils(unittest.TestCase): + def test_refresh_seq_staleness_recomputes_from_response_model_steps(self): + group = [_make_rollout_state(response_ids=[1, 2], response_model_steps=[3, 4], seq_staleness=0)] + + refresh_seq_staleness(group, current_train_step=8) + + self.assertEqual(group[0].seq_staleness, 4) + + def test_refresh_seq_staleness_resets_without_response_model_steps(self): + group = [_make_rollout_state(response_ids=[1, 2], response_model_steps=None, seq_staleness=5)] + + refresh_seq_staleness(group, current_train_step=8) + + self.assertEqual(group[0].seq_staleness, 0) + + def test_partial_rollout_postprocess_only_concatenates_history(self): + handler = PartialRolloutHandler(max_tokens=8) + rollout_state = _make_rollout_state( + response_ids=[30, 31], + response_model_steps=[2, 2], + seq_staleness=0, + extra_fields={ + "history_response_dict": { + "response_ids": [10, 11], + "response": "hi", + "logprobs": [0.1, 0.2], + "response_mask": [1, 1], +
"routed_experts": None, + } + }, + ) + + result = handler.postprocess(rollout_state) + + self.assertEqual(result.response_ids, [10, 11, 30, 31]) + self.assertEqual(result.response_model_steps, [2, 2]) + self.assertEqual(result.seq_staleness, 0) + + +class TestSingleTurnAgentLoop(unittest.IsolatedAsyncioTestCase): + def _build_agent_loop(self): + rollout_ctl = MagicMock() + rollout_ctl.generate.remote = AsyncMock() + with ( + patch("xtuner.v1.rl.agent_loop.agent_loop.load_tokenizer", return_value=MagicMock()), + patch("xtuner.v1.rl.agent_loop.agent_loop.load_processor", return_value=MagicMock()), + ): + return SingleTurnAgentLoop( + rollout_ctl=rollout_ctl, + sample_params=SampleParams(max_tokens=8), + hf_checkpoint="dummy", + judger=None, + logger=MagicMock(), + ) + + async def test_generate_sample_does_not_update_staleness(self): + agent_loop = self._build_agent_loop() + rollout_state = _make_rollout_state(response_ids=[], status=Status.ABORTED) + generated_state = _make_rollout_state(response_ids=[30, 31], seq_staleness=7, status=Status.ABORTED) + agent_loop.rollout_ctl.generate.remote.return_value = generated_state + + result = await agent_loop.generate_sample( + rollout_state, + ) + + self.assertIsNone(result.response_model_steps) + self.assertEqual(result.seq_staleness, 7) + + async def test_generate_sample_does_not_update_sample_version(self): + agent_loop = self._build_agent_loop() + rollout_state = _make_rollout_state(response_ids=[], status=Status.ABORTED) + generated_state = _make_rollout_state(response_ids=[30, 31], status=Status.ABORTED) + agent_loop.rollout_ctl.generate.remote.return_value = generated_state + + result = await agent_loop.generate_sample(rollout_state) + + self.assertIsNone(result.response_model_steps) + self.assertEqual(result.seq_staleness, 0) + + async def test_generate_sample_does_not_require_model_step(self): + agent_loop = self._build_agent_loop() + rollout_state = _make_rollout_state(response_ids=[], status=Status.ABORTED) + generated_state = _make_rollout_state(response_ids=[30, 31], status=Status.ABORTED) + agent_loop.rollout_ctl.generate.remote.return_value = generated_state + + result = await agent_loop.generate_sample(rollout_state) + + self.assertIsNone(result.response_model_steps) + self.assertEqual(result.seq_staleness, 0) diff --git a/tests/rl/test_async_rollout.py b/tests/rl/test_async_rollout.py new file mode 100644 index 0000000000..6edb46878c --- /dev/null +++ b/tests/rl/test_async_rollout.py @@ -0,0 +1,720 @@ +from __future__ import annotations + +import os +import unittest + +import ray +import torch + +from transformers import AutoTokenizer + +from xtuner.v1.data_proto.rl_data import SampleParams, Status +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig +from xtuner.v1.rl.agent_loop_manager import ( + AgentLoopManagerConfig, + AsyncProduceStrategyConfig, + SamplerConfig, + TaskSpecConfig, +) +from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers + +MODEL_PATH = os.environ.get("ROLLOUT_MODEL_PATH", "") +DATA_PATH = os.environ.get("ROLLOUT_DATA_PATH", "") +MAX_PROMPT_LENGTH = 512 +MAX_RESPONSE_LENGTH = 512 +PACK_MAX_LENGTH = MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH +EXPERIMENTAL_NAME = 
"async_rl_integration_test" + +_RESOURCE_MAP = {"npu": "NPU", "cuda": "GPU"} + + +def _accelerator_type() -> str: + return _RESOURCE_MAP[torch.accelerator.current_accelerator().type] + + +def _build_rollout_controller(): + """Build a RolloutController backed by a real inference engine.""" + resources_cfg = AcceleratorResourcesConfig( + accelerator=_accelerator_type(), + num_workers=1, + num_cpus_per_worker=8, + cpu_memory_per_worker=16 * 1024**3, + ) + rollout_config = RolloutConfig( + env=EXPERIMENTAL_NAME, + device=resources_cfg.accelerator, + model_path=MODEL_PATH, + gpu_memory_utilization=0.8, + context_length=MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH, + rollout_max_batch_size_per_instance=16, + max_retry_per_sample=0, + ) + pg = AutoAcceleratorWorkers.build_placement_group(resources_cfg) + rollout_ctl = ray.remote(RolloutController).remote(rollout_config, pg) + return rollout_ctl + + +def _build_agent_loop_manager( + rollout_ctl, + task_name: str, + over_sample_threshold: float = 0.0, + enable_partial_rollout: bool = False, + max_staleness: int = 0, + tail_batch_trigger_size: int = 0, + prompt_repeat_k: int = 1, + max_tokens: int = MAX_RESPONSE_LENGTH, +): + """Build an AgentLoopManager backed by a fresh AsyncReplayBuffer.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + + dataset_cfg = DatasetConfig(name=EXPERIMENTAL_NAME, anno_path=DATA_PATH) + tokenizer_fn_cfg = RLTextTokenizeFnConfig(max_length=MAX_PROMPT_LENGTH) + dataloader_cfg = DataloaderConfig( + dataset_config_list=[{"dataset": dataset_cfg, "tokenize_fn": tokenizer_fn_cfg}], + pack_max_length=PACK_MAX_LENGTH, + collator="fake_collator", + pack_level="none", + ) + sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, + ) + + sample_params = SampleParams( + max_tokens=max_tokens, + temperature=1.0, + top_k=0, + top_p=1.0, + return_token_ids=True, + ) + agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=MODEL_PATH, + sample_params=sample_params, + ) + + produce_strategy_config = AsyncProduceStrategyConfig( + over_sample_threshold=over_sample_threshold, + enable_partial_rollout=enable_partial_rollout, + max_staleness=max_staleness, + tail_batch_trigger_size=tail_batch_trigger_size, + ) + + replay_buffer = AsyncReplayBufferConfig().build() + + manager_cfg = AgentLoopManagerConfig( + tasks=[ + TaskSpecConfig( + task_name=task_name, + agent_loop_config=agent_loop_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ) + ], + ) + manager = manager_cfg.build( + rollout_controller=rollout_ctl, + tokenizer=tokenizer, + replay_buffer=replay_buffer, + logger=None, + ) + return manager + +class TestOversampling(unittest.IsolatedAsyncioTestCase): + """Oversampling tests (mirrors debug_rollout=True: rollout only, no training). 
+
+    Why ABORTED samples are guaranteed:
+    - over_sample_threshold=2.0 => data_concurrency = 3 * batch_size = 6 tasks
+    - max_tokens=512 => long responses; most tasks still in-flight
+      when the first batch_size completions arrive
+    - _cleanup_pending_tasks() => remaining tasks get abort-signalled and
+      stored as ABORTED in the replay buffer
+    """
+
+    OVER_SAMPLE_THRESHOLD = 2.0  # data_concurrency = 3 * batch_size
+    BATCH_SIZE = 2
+    INITIAL_DATA_CONCURRENCY = int((1 + OVER_SAMPLE_THRESHOLD) * BATCH_SIZE)  # = 6
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        os.environ.setdefault("XTUNER_USE_FA3", "1")
+        os.environ.setdefault("LMD_SKIP_WARMUP", "1")
+
+    def setUp(self):
+        ray.init(num_cpus=32, ignore_reinit_error=True)
+        self.rollout_ctl = _build_rollout_controller()
+
+    def tearDown(self):
+        ray.shutdown()
+
+    async def test_1_1_total_count_after_first_rollout(self):
+        """1.1: After produce_batch round 1:
+
+        remain_completed + remain_aborted == INITIAL_DATA_CONCURRENCY - BATCH_SIZE
+
+        Flow:
+        1. strategy starts INITIAL_DATA_CONCURRENCY tasks concurrently.
+        2. As soon as BATCH_SIZE completions are collected, the while-loop
+           exits; remaining pending tasks go through _cleanup_pending_tasks
+           and are stored as ABORTED.
+        3. produce_batch() then calls replay_buffer.get(BATCH_SIZE, COMPLETED),
+           which consumes exactly BATCH_SIZE items.
+        4. Any extras that completed during the abort window remain as
+           COMPLETED in the buffer.
+
+        Therefore:
+            remain_completed + remain_aborted == INITIAL_DATA_CONCURRENCY - BATCH_SIZE
+
+        because every task ends up either COMPLETED or ABORTED in the buffer,
+        and exactly BATCH_SIZE items are consumed by replay_buffer.get().
+        """
+        manager = _build_agent_loop_manager(
+            self.rollout_ctl,
+            task_name="test_1_1",
+            over_sample_threshold=self.OVER_SAMPLE_THRESHOLD,
+        )
+        replay_buffer = manager.replay_buffer
+
+        await manager.produce_batch(batch_size=self.BATCH_SIZE, train_step=1, model_step=0)
+
+        remain_completed = await replay_buffer.count(
+            task_name="test_1_1", group_status=Status.COMPLETED
+        )
+        remain_aborted = await replay_buffer.count(
+            task_name="test_1_1", group_status=Status.ABORTED
+        )
+
+        # Primary assertion: items remaining in buffer after produce_batch consumes
+        # BATCH_SIZE completed samples == INITIAL_DATA_CONCURRENCY - BATCH_SIZE
+        expected_remaining = self.INITIAL_DATA_CONCURRENCY - self.BATCH_SIZE
+        self.assertEqual(
+            remain_completed + remain_aborted,
+            expected_remaining,
+            msg=(
+                f"remain_completed={remain_completed}, remain_aborted={remain_aborted}, "
+                f"expected total={expected_remaining} "
+                f"(= INITIAL_DATA_CONCURRENCY {self.INITIAL_DATA_CONCURRENCY} "
+                f"- BATCH_SIZE {self.BATCH_SIZE})"
+            ),
+        )
+
+    async def test_1_2_second_rollout_does_not_convert_completed_leftovers(self):
+        """1.2: Round 2 no longer destructively converts COMPLETED leftovers.
+
+        AsyncProduceStrategy v2.2 keeps completed leftovers in the fresh window.
+        Only existing ABORTED samples may be re-sampled through the ABORTED pool;
+        completed samples are either consumed as completed or refreshed/expired by
+        the manager consumer entry.
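+
+        Worked example with this class's constants: round 1 launches 6 tasks,
+        produce_batch consumes 2 COMPLETED items, and the remaining 4 stay in
+        the buffer split between COMPLETED and ABORTED; only the ABORTED ones
+        are eligible for re-sampling in round 2.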
+ """ + manager = _build_agent_loop_manager( + self.rollout_ctl, + task_name="test_1_2", + over_sample_threshold=self.OVER_SAMPLE_THRESHOLD, + ) + replay_buffer = manager.replay_buffer + original_sample = manager.data_sampler.sample + + sampled_from_aborted = 0 + + async def instrumented_sample(task_name, group_status=None, **kwargs): + nonlocal sampled_from_aborted + result = await original_sample( + task_name=task_name, group_status=group_status, **kwargs + ) + # Items fetched from the ABORTED pool still carry status==ABORTED. + if result and result[0].status == Status.ABORTED: + sampled_from_aborted += 1 + return result + + manager.data_sampler.sample = instrumented_sample + + # --- Round 1 --- + await manager.produce_batch(batch_size=self.BATCH_SIZE, train_step=1, model_step=0) + + # After round 1: produce_batch consumed BATCH_SIZE completed items. + # The leftover items (completed but not consumed) stay in the buffer. + round1_remain_completed = await replay_buffer.count( + task_name="test_1_2", group_status=Status.COMPLETED + ) + round1_remain_aborted = await replay_buffer.count( + task_name="test_1_2", group_status=Status.ABORTED + ) + # Total leftover == INITIAL_DATA_CONCURRENCY - BATCH_SIZE + expected_leftover = self.INITIAL_DATA_CONCURRENCY - self.BATCH_SIZE + self.assertEqual( + round1_remain_completed + round1_remain_aborted, + expected_leftover, + msg=( + f"Round 1 leftover: completed={round1_remain_completed}, " + f"aborted={round1_remain_aborted}, expected total={expected_leftover}" + ), + ) + # --- Round 2: reset counter then run --- + sampled_from_aborted = 0 + await manager.produce_batch(batch_size=self.BATCH_SIZE, train_step=2, model_step=1) + + self.assertLessEqual( + sampled_from_aborted, + round1_remain_aborted, + msg=( + "Round 2 should not convert round-1 COMPLETED leftovers into the ABORTED queue: " + f"sampled_from_aborted={sampled_from_aborted}, round1_remain_aborted={round1_remain_aborted}" + ), + ) + + +class TestPartialRollout(unittest.IsolatedAsyncioTestCase): + """Partial-rollout tests. + + All tests inject pre-constructed ABORTED samples directly into the + replay buffer so that Sampler.sample(group_status=[ABORTED]) picks them + up without any mocking. The real AgentLoopManager.produce_batch() is + used throughout. + + Key configuration: + - over_sample_threshold=2.0 → data_concurrency = 3; guarantees concurrent + tasks so the genuine oversampling + partial- + rollout path is exercised (not just injected + into a single-task environment). + - enable_partial_rollout=True → ABORTED samples resume from existing + response_ids instead of starting over. + """ + + BATCH_SIZE = 1 + OVER_SAMPLE = 2.0 # data_concurrency = int((1+2.0)*1) = 3; genuine oversampling + # Short max_tokens for the max-exhausted short-circuit test; medium for multi-round. 
MAX_TOKENS_SHORT = 8
+    MAX_TOKENS_MULTI = 32
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        os.environ.setdefault("XTUNER_USE_FA3", "1")
+        os.environ.setdefault("LMD_SKIP_WARMUP", "1")
+
+    def setUp(self):
+        ray.init(num_cpus=32, ignore_reinit_error=True)
+        self.rollout_ctl = _build_rollout_controller()
+        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
+
+    def tearDown(self):
+        ray.shutdown()
+
+    def _make_aborted_state(self, uid: int, prompt: str, response_ids: list[int],
+                            response_model_steps: list[int] | None = None,
+                            max_tokens: int = MAX_RESPONSE_LENGTH) -> "RolloutState":
+        """Helper: build an ABORTED RolloutState with the given response_ids."""
+        from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status
+        prompt_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
+        state = RolloutState(
+            uid=uid,
+            message=[{"role": "user", "content": prompt}],
+            prompt_ids=prompt_ids,
+            sample_params=SampleParams(
+                max_tokens=max_tokens,
+                temperature=1.0,
+                top_k=0,
+                top_p=1.0,
+                return_token_ids=True,
+            ),
+            status=Status.ABORTED,
+            response_ids=response_ids,
+            response="placeholder",
+            logprobs=[0.0] * len(response_ids),
+            response_mask=[1] * len(response_ids),
+            response_model_steps=response_model_steps if response_model_steps is not None else [0] * len(response_ids),
+            seq_staleness=0,
+            extra_fields={},
+        )
+        return state
+
+    async def test_2_1_partial_rollout_response_ids_are_concatenated(self):
+        """2.1: Partial rollout must keep the response_ids prefix intact.
+
+        Setup:
+        - over_sample_threshold=2.0 → 3 concurrent tasks; the injected ABORTED
+          sample runs alongside 2 fresh dataloader samples, genuinely exercising
+          the oversampling + partial-rollout path.
+        - Inject an ABORTED sample with uid=9001 and response_ids=[1000,1001,1002,1003].
+        - Because the tasks compete, the injected sample may be aborted several
+          times and continued in later rounds; each preprocess uses the existing
+          response_ids as the prefix and postprocess appends the new content.
+
+        Assertion: the finally completed uid=9001 sample has response_ids that
+        start with the initial 4 tokens and are longer than 4 (new content was
+        actually generated).
+        """
+        task_name = "test_2_1"
+        initial_response_ids = [1000, 1001, 1002, 1003]
+        injected_uid = 9001
+
+        manager = _build_agent_loop_manager(
+            self.rollout_ctl,
+            task_name=task_name,
+            over_sample_threshold=self.OVER_SAMPLE,
+            enable_partial_rollout=True,
+            max_tokens=MAX_RESPONSE_LENGTH,
+        )
+        replay_buffer = manager.replay_buffer
+
+        state = self._make_aborted_state(
+            uid=injected_uid,
+            prompt="Count from one.",
+            response_ids=initial_response_ids,
+            max_tokens=MAX_RESPONSE_LENGTH,
+        )
+        await replay_buffer.put([state], task_name)
+
+        # Loop: with oversampling the injected sample may be aborted multiple times
+        # before completing. Search by uid across rounds.
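+        # range(1, 15) below gives up to 14 rounds, enough abort/continue
+        # cycles for the injected sample to finish under task competition.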
+        target_sample = None
+        for train_step in range(1, 15):
+            completed_groups = await manager.produce_batch(
+                batch_size=self.BATCH_SIZE,
+                train_step=train_step,
+                model_step=train_step - 1,
+            )
+            for group in completed_groups.rollout_states:
+                for sample in group:
+                    if sample.uid == injected_uid:
+                        target_sample = sample
+            if target_sample is not None:
+                break
+
+        self.assertIsNotNone(
+            target_sample,
+            msg=f"Injected sample (uid={injected_uid}) never completed within 14 rounds",
+        )
+        final_response_ids = target_sample.response_ids
+        self.assertGreater(
+            len(final_response_ids), len(initial_response_ids),
+            msg="Partial rollout should have appended new tokens",
+        )
+        self.assertEqual(
+            final_response_ids[: len(initial_response_ids)],
+            initial_response_ids,
+            msg="response_ids must start with the original injected prefix",
+        )
+
+    async def test_2_2_eos_in_response_skips_inference_engine(self):
+        """2.2: An ABORTED sample ending with the EOS token → worker short-circuits, response_ids unchanged.
+
+        The EOS short-circuit never calls the inference engine, so the injected
+        sample completes almost instantly and is guaranteed to finish first among
+        the 3 concurrent tasks; completed_groups[0][0] is therefore the injected
+        sample.
+
+        Assertion: the returned sample's response_ids are exactly the injected ones.
+        """
+        from xtuner.v1.data_proto.rl_data import Status
+        from xtuner.v1.rl.rollout.worker import get_eos_token
+
+        task_name = "test_2_2"
+        eos = get_eos_token(MODEL_PATH)
+        eos_id = eos[0] if isinstance(eos, list) else eos
+        initial_response_ids = [1000, 1001, eos_id]
+        injected_uid = 9002
+
+        manager = _build_agent_loop_manager(
+            self.rollout_ctl,
+            task_name=task_name,
+            over_sample_threshold=self.OVER_SAMPLE,
+            enable_partial_rollout=True,
+            max_tokens=MAX_RESPONSE_LENGTH,
+        )
+        replay_buffer = manager.replay_buffer
+
+        state = self._make_aborted_state(
+            uid=injected_uid,
+            prompt="Say hello.",
+            response_ids=initial_response_ids,
+            max_tokens=MAX_RESPONSE_LENGTH,
+        )
+        await replay_buffer.put([state], task_name)
+
+        # EOS short-circuit completes with no LLM call → always wins the race.
+        completed_groups = await manager.produce_batch(
+            batch_size=self.BATCH_SIZE,
+            train_step=1,
+            model_step=0,
+        )
+        completed_groups = completed_groups.rollout_states
+
+        self.assertEqual(len(completed_groups), self.BATCH_SIZE)
+        final = completed_groups[0][0]
+        self.assertEqual(final.uid, injected_uid,
+                         msg="EOS short-circuit sample should be the first to complete")
+        self.assertEqual(final.status, Status.COMPLETED)
+        self.assertEqual(
+            final.response_ids,
+            initial_response_ids,
+            msg="EOS short-circuit: response_ids must be identical to the injected ones",
+        )
+
+    async def test_2_3_max_tokens_exhausted_skips_inference_engine(self):
+        """2.3: len(response_ids) == max_tokens → remaining_tokens == 0 → worker short-circuits, response_ids unchanged.
+
+        As in test_2_2, the short-circuit never calls the inference engine, so
+        the injected sample is guaranteed to finish first among the 3 concurrent
+        tasks.
+
+        Assertion: the returned sample's response_ids are exactly the injected ones.
+        """
+        from xtuner.v1.data_proto.rl_data import Status
+        task_name = "test_2_3"
+        max_tokens = self.MAX_TOKENS_SHORT
+        initial_response_ids = list(range(1010, 1010 + max_tokens))  # len == max_tokens
+        injected_uid = 9003
+
+        manager = _build_agent_loop_manager(
+            self.rollout_ctl,
+            task_name=task_name,
+            over_sample_threshold=self.OVER_SAMPLE,
+            enable_partial_rollout=True,
+            max_tokens=max_tokens,
+        )
+        replay_buffer = manager.replay_buffer
+
+        state = self._make_aborted_state(
+            uid=injected_uid,
+            prompt="Say hello.",
+            response_ids=initial_response_ids,
+            max_tokens=max_tokens,
+        )
+        await replay_buffer.put([state], task_name)
+
+        # max_tokens short-circuit completes with no LLM call → always wins the race.
+        completed_groups = await manager.produce_batch(
+            batch_size=self.BATCH_SIZE,
+            train_step=1,
+            model_step=0,
+        )
+        completed_groups = completed_groups.rollout_states
+
+        self.assertEqual(len(completed_groups), self.BATCH_SIZE)
+        final = completed_groups[0][0]
+        self.assertEqual(final.uid, injected_uid,
+                         msg="max_tokens short-circuit sample should be the first to complete")
+        self.assertEqual(final.status, Status.COMPLETED)
+        self.assertEqual(
+            final.response_ids,
+            initial_response_ids,
+            msg="max_tokens exhausted: response_ids must be identical to the injected ones",
+        )
+
+    async def test_2_4_multi_round_response_ids_never_exceed_max_tokens(self):
+        """2.4: After multiple partial-rollout rounds, len(response_ids) <= max_tokens.
+
+        over_sample_threshold=2.0 → 3 concurrent tasks per round; the injected
+        sample may need several abort + continue cycles before completing. No
+        matter how many rounds it takes, the final response_ids length never
+        exceeds max_tokens.
+
+        Search for the target sample by uid, running at most 14 rounds.
+        """
+        task_name = "test_2_4"
+        max_tokens = self.MAX_TOKENS_MULTI
+        injected_uid = 9004
+
+        manager = _build_agent_loop_manager(
+            self.rollout_ctl,
+            task_name=task_name,
+            over_sample_threshold=self.OVER_SAMPLE,
+            enable_partial_rollout=True,
+            max_tokens=max_tokens,
+        )
+        replay_buffer = manager.replay_buffer
+
+        state = self._make_aborted_state(
+            uid=injected_uid,
+            prompt="Count from one.",
+            response_ids=[1020, 1021],  # 2 tokens initially; max_tokens=32
+            max_tokens=max_tokens,
+        )
+        await replay_buffer.put([state], task_name)
+
+        target_sample = None
+        for train_step in range(1, 15):
+            completed_groups = await manager.produce_batch(
+                batch_size=self.BATCH_SIZE,
+                train_step=train_step,
+                model_step=train_step - 1,
+            )
+            for group in completed_groups.rollout_states:
+                for sample in group:
+                    if sample.uid == injected_uid:
+                        self.assertLessEqual(
+                            len(sample.response_ids),
+                            max_tokens,
+                            msg=(
+                                f"Step {train_step}: accumulated response_ids length "
+                                f"{len(sample.response_ids)} exceeds max_tokens {max_tokens}"
+                            ),
+                        )
+                        target_sample = sample
+            if target_sample is not None:
+                break
+
+        self.assertIsNotNone(
+            target_sample,
+            msg=f"Injected sample (uid={injected_uid}) never completed within 14 rounds",
+        )
+        self.assertLessEqual(
+            len(target_sample.response_ids),
+            max_tokens,
+            msg=f"Final response_ids length {len(target_sample.response_ids)} > max_tokens {max_tokens}",
+        )
+
+
+class TestTailBatch(unittest.IsolatedAsyncioTestCase):
+    BATCH_SIZE = 2
+    # The real lmdeploy backend is prone to session-cleanup exceptions under
+    # heavy concurrent aborts; reuse the concurrency scale already validated as
+    # stable in the oversampling coverage, which still yields enough leftovers.
+    OVER_SAMPLE = 2.0  # data_concurrency = (1 + 2.0) * BATCH_SIZE = 6
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        os.environ.setdefault("XTUNER_USE_FA3", "1")
+        os.environ.setdefault("LMD_SKIP_WARMUP", "1")
+
+    def setUp(self):
+        ray.init(num_cpus=32, ignore_reinit_error=True)
+        self.rollout_ctl = _build_rollout_controller()
+
+    def tearDown(self):
+        ray.shutdown()
+
+    async def test_3_1_max_staleness_0_marks_expired(self):
+        """3.1a: max_staleness=0 — EXPIRED only becomes observable in the buffer after 3 rounds.
+
+        Staleness accumulation path (enable_partial_rollout=True):
+        Round 1 (step=1): 6 concurrent tasks; once 2 complete, the rest are
+            aborted. Aborted samples carry the partial response generated at
+            step=1, with response_model_steps=[1, ...].
+        Round 2 (step=2): the round-1 ABORTED samples are continued; most of
+            them complete within round 2 (COMPLETED). postprocess only
+            concatenates the partial-rollout history; before putting them back,
+            the producer records response_model_steps and refreshes staleness.
+            COMPLETED samples not consumed in this round stay in the buffer.
+        Round 3 (step=3): produce_batch, acting as the consumer entry, first
+            refreshes the COMPLETED samples in the buffer, sees
+            seq_staleness=1 >= stale_threshold=1 → marks them EXPIRED and puts
+            them back. Because trigger_size=0, EXPIRED samples are not consumed
+            in this round.
+
+        Assertion: after round 3 the buffer contains expired > 0.
+        """
+        from xtuner.v1.data_proto.rl_data import Status
+
+        MAX_STALENESS = 0
+        task_name = "test_3_1a"
+
+        manager = _build_agent_loop_manager(
+            self.rollout_ctl,
+            task_name=task_name,
+            over_sample_threshold=self.OVER_SAMPLE,
+            enable_partial_rollout=True,
+            max_staleness=MAX_STALENESS,
+            tail_batch_trigger_size=0,  # only test the EXPIRED marking; do not trigger tail-batch mode
+            # The test rollout context_length is MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH;
+            # max_tokens must stay within this test configuration, otherwise lmdeploy
+            # takes the over-long-request exception/abort path.
+            max_tokens=MAX_RESPONSE_LENGTH,
+        )
+        replay_buffer = manager.replay_buffer
+
+        # Three rounds is the minimum for staleness to accumulate naturally and be
+        # flagged by the produce_batch entry refresh; the loop below runs four:
+        # round 1 produces ABORTED samples (step=1 tokens) → round 2 continues
+        # them to completion and leaves them as COMPLETED → round 3 refreshes the
+        # completed samples at its start and marks them EXPIRED (staleness=1 >= 1)
+        for train_step in range(1, 5):
+            await manager.produce_batch(
+                batch_size=self.BATCH_SIZE,
+                train_step=train_step,
+                model_step=train_step - 1,
+            )
+
+        expired_count = await replay_buffer.count(
+            task_name=task_name, group_status=Status.EXPIRED
+        )
+        aborted_count = await replay_buffer.count(
+            task_name=task_name, group_status=Status.ABORTED
+        )
+
+        self.assertGreater(
+            expired_count, 0,
+            msg=(
+                f"max_staleness=0: after 4 rounds (steps 1→4), leftover COMPLETED samples "
+                f"with seq_staleness=1 should be marked EXPIRED by the produce_batch entry refresh. "
+                f"expired={expired_count}, aborted={aborted_count}"
+            ),
+        )
+
+    async def test_3_2_tail_batch_mode_resets_staleness_to_zero(self):
+        """3.2: A real multi-round loop naturally triggers tail-batch mode; verify seq_staleness is reset to 0.
+
+        Configuration:
+            over_sample_threshold=2.0 → every round leaves plenty of leftover samples
+            max_staleness=0 → stale_threshold=1, a single step triggers EXPIRED
+            tail_batch_trigger_size = BATCH_SIZE // 2 = 1 → expired >= 1 enters tail-batch
+
+        Flow (at most 10 rounds):
+        - Read expired_before before each produce_batch call.
+        - If expired_before >= trigger_size, the strategy enters tail-batch mode
+          this round: sample from the EXPIRED pool → preprocess resets
+          response_ids=[] and response_model_steps=[] → generate from scratch
+          → postprocess concatenates the history
+          → before putting the sample back, the producer records
+            response_model_steps and resets staleness to 0.
+        - Exit the loop once the first completed tail-batch sample is obtained.
+
+        Assertions:
+        1. Tail-batch mode is triggered within 10 rounds.
+        2. The COMPLETED sample returned in that round has seq_staleness == 0.
+        """
+        from xtuner.v1.data_proto.rl_data import Status
+
+        MAX_STALENESS = 0
+        TRIGGER_SIZE = self.BATCH_SIZE // 2  # = 1
+        task_name = "test_3_2"
+
+        manager = _build_agent_loop_manager(
+            self.rollout_ctl,
+            task_name=task_name,
+            over_sample_threshold=self.OVER_SAMPLE,
+            enable_partial_rollout=True,
+            max_staleness=MAX_STALENESS,
+            tail_batch_trigger_size=TRIGGER_SIZE,
+            # Stay within the test context_length so the tail-batch test is not
+            # disturbed by over-long generation requests.
+            max_tokens=MAX_RESPONSE_LENGTH,
+        )
+        replay_buffer = manager.replay_buffer
+
+        tail_batch_triggered = False
+        completed_from_tail_batch = None
+
+        for train_step in range(1, 11):
+            expired_before = await replay_buffer.count(
+                task_name=task_name, group_status=Status.EXPIRED
+            )
+
+            completed_groups = await manager.produce_batch(
+                batch_size=self.BATCH_SIZE,
+                train_step=train_step,
+                model_step=train_step - 1,
+            )
+            completed_groups = completed_groups.rollout_states
+
+            # expired >= trigger_size before this round → this round is the tail-batch round
+            if expired_before >= TRIGGER_SIZE:
+                tail_batch_triggered = True
+                if completed_groups:
+                    completed_from_tail_batch = completed_groups[0][0]
+                break
+
+        self.assertTrue(
+            tail_batch_triggered,
+            msg="Tail-batch mode was never triggered within 10 rollout rounds.",
+        )
+        self.assertIsNotNone(
+            completed_from_tail_batch,
+            msg="Tail-batch round produced no completed samples.",
+        )
+        self.assertEqual(
+            completed_from_tail_batch.seq_staleness, 0,
+            msg=(
+                f"Tail-batch sample must have seq_staleness=0 (fresh generation), "
+                f"got seq_staleness={completed_from_tail_batch.seq_staleness}"
+            ),
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/rl/test_auto.py b/tests/rl/test_auto.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/ray/test_cpu_pg.py b/tests/rl/test_cpu_pg.py
similarity index 98%
rename from tests/ray/test_cpu_pg.py
rename to tests/rl/test_cpu_pg.py
index c66a5e4954..773e4f3d5a 100644
--- a/tests/ray/test_cpu_pg.py
+++ b/tests/rl/test_cpu_pg.py
@@ -6,7 +6,7 @@
 import httpx
 import ray
 
-from xtuner.v1.ray.base import AutoCPUWorkers, BaseCPUWorker, CPUResourcesConfig
+from xtuner.v1.rl.utils import AutoCPUWorkers, BaseCPUWorker, CPUResourcesConfig
 
 
 @ray.remote(num_cpus=1)
diff --git a/tests/rl/test_gateway.py b/tests/rl/test_gateway.py
new file mode 100644
index 0000000000..93bceab474
--- /dev/null
+++ b/tests/rl/test_gateway.py
@@ -0,0 +1,691 @@
+import json
+import os
+import socket
+import subprocess
+import tempfile
+import threading
+import time
+import unittest
+from pathlib import Path
+from typing import Any
+from uuid import uuid4
+
+import httpx
+import ray
+import torch
+
+from xtuner.v1.rl.gateway.adapters import build_api_key_trace_key
+from xtuner.v1.rl.gateway.config import GatewayConfig
+from xtuner.v1.rl.gateway.server import build_local_gateway_app, serve_gateway
+from xtuner.v1.rl.rollout import RolloutController
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers
+
+
+MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
+RESOURCE_MAP = {
+    "npu": "NPU",
+    "cuda": "GPU",
+}
+
+
+@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
+class TestGatewayProtocolChain(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        os.environ["XTUNER_USE_FA3"] = "1"
+        os.environ["LMD_SKIP_WARMUP"] = "1"
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        del os.environ["XTUNER_USE_FA3"]
+        del os.environ["LMD_SKIP_WARMUP"]
+
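+    # Each test below boots a real RolloutController plus an in-process Gateway;
+    # setUp/tearDown keep Ray state, actors, and temp artifacts isolated per test.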
+ def setUp(self): + ray.init(address="local", ignore_reinit_error=True) + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + self.capture_output_path = Path(self.temp_dir.name) / "gateway_capture_output" + self.openai_body_output_path = Path(self.temp_dir.name) / "openai_body.json" + self.anthropic_body_output_path = Path(self.temp_dir.name) / "anthropic_body.json" + self.responses_body_output_path = Path(self.temp_dir.name) / "responses_body.json" + self.controller = None + self.placement_group = None + self.test_run_id = uuid4().hex[:8] + + def tearDown(self): + if self.controller is not None: + try: + ray.get(self.controller.shutdown.remote(), timeout=300) + except Exception: + pass + try: + ray.kill(self.controller, no_restart=True) + except Exception: + pass + if self.placement_group is not None: + ray.util.remove_placement_group(self.placement_group) + ray.shutdown() + self._cleanup_lmdeploy_ray_worker_wrapper() + self.temp_dir.cleanup() + + def _cleanup_lmdeploy_ray_worker_wrapper(self): + try: + subprocess.run( + ["pkill", "-f", "ray::RayWorkerWrapper*"], + capture_output=True, + text=True, + timeout=10, + check=False, + ) + except Exception: + return + + def _build_controller(self): + resource_config = AcceleratorResourcesConfig( + accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type], + num_workers=4, + num_cpus_per_worker=16, + cpu_memory_per_worker=8 * 1024**3, + ) + self.placement_group = AutoAcceleratorWorkers.build_placement_group( + resource_config, + name=f"gateway_protocol_pg_{self.test_run_id}", + ) + rollout_config = RolloutConfig( + env=f"test_gateway_protocol_{self.test_run_id}", + model_path=MODEL_PATH, + model_name=os.path.basename(MODEL_PATH).lower(), + tokenizer_path=MODEL_PATH, + tool_call_parser="qwen3", + reasoning_parser="qwen3", + tensor_parallel_size=4, + expert_parallel_size=1, + context_length=1536, + worker_log_dir=os.path.join(self.worker_log_dir, "gateway"), + dist_port_base=42000, + api_host="127.0.0.1", + api_port=30080, + ) + return ray.remote(RolloutController).remote(rollout_config, self.placement_group) + + def _get_rollout_config(self) -> RolloutConfig: + rollout_metadata = ray.get(self.controller.get_rollout_metadata.remote()) + return rollout_metadata["rollout_config"] + + def _read_capture_records(self) -> list[dict]: + if not self.capture_output_path.exists(): + return [] + if self.capture_output_path.is_file(): + with self.capture_output_path.open("r", encoding="utf-8") as f: + return [json.loads(line) for line in f] + records = [] + for capture_file in sorted(self.capture_output_path.glob("*.jsonl")): + with capture_file.open("r", encoding="utf-8") as f: + records.extend(json.loads(line) for line in f) + return records + + def _write_json_output(self, path: Path, payload: dict) -> None: + path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + + def _capture_records_by_protocol(self, capture_records: list[dict]) -> dict[str, dict]: + return {record["protocol"]: record for record in capture_records} + + def _assert_trace_record_matches_capture( + self, + trace_record, + capture_record: dict, + *, + expected_request_field: str, + expected_request_role: str | None, + expected_response_field: str, + ) -> None: + self.assertIsNotNone(trace_record) + self.assertEqual(trace_record.request_id, capture_record["request_id"]) + self.assertEqual(trace_record.finish_reason, capture_record["rollout_finish_reason"] or 
capture_record["finish_reason"]) + self.assertEqual(trace_record.status.value, capture_record["status"]) + self.assertGreater(len(trace_record.prompt_ids), 0) + self.assertGreater(len(trace_record.response_ids), 0) + self.assertTrue(trace_record.input_text) + self.assertTrue(trace_record.output_text) + self.assertGreater(capture_record["prompt_tokens"], 0) + self.assertGreater(capture_record["completion_tokens"], 0) + self.assertTrue(capture_record["input_text"]) + self.assertIn(expected_request_field, trace_record.request_snapshot) + if expected_request_role is not None: + self.assertEqual(trace_record.request_snapshot[expected_request_field][0]["role"], expected_request_role) + self.assertIn(expected_response_field, trace_record.response_snapshot) + + def _find_free_port(self) -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + def _wait_for_gateway_ready(self, base_url: str, *, timeout_seconds: float = 120.0) -> None: + deadline = time.time() + timeout_seconds + last_error = None + while time.time() < deadline: + try: + response = httpx.get(f"{base_url}/livez", timeout=5.0) + if response.status_code == 200: + return + except Exception as exc: + last_error = exc + time.sleep(1.0) + if last_error is not None: + raise AssertionError(f"Gateway did not become ready at {base_url}: {last_error}") from last_error + raise AssertionError(f"Gateway did not become ready at {base_url}") + + def _serve_gateway_blocking_in_background(self, app, config: GatewayConfig) -> tuple[str, threading.Thread]: + thread = threading.Thread( + target=serve_gateway, + args=(app, config), + daemon=True, + name=f"gateway-blocking-{config.port}", + ) + thread.start() + base_url = self._wait_for_gateway_ready_from_config(config) + return base_url, thread + + def _wait_for_gateway_ready_from_config(self, config: GatewayConfig, *, timeout_seconds: float = 120.0) -> str: + deadline = time.time() + timeout_seconds + last_error = None + while time.time() < deadline: + base_url = f"http://127.0.0.1:{config.port}" + try: + response = httpx.get(f"{base_url}/livez", timeout=5.0) + if response.status_code == 200: + return base_url + except Exception as exc: + last_error = exc + time.sleep(1.0) + if last_error is not None: + raise AssertionError(f"Gateway did not become ready for config port {config.port}: {last_error}") from last_error + raise AssertionError(f"Gateway did not become ready for config port {config.port}") + + def _post_json( + self, + base_url: str, + path: str, + payload: dict, + *, + api_key: str | None = None, + ) -> httpx.Response: + headers = {"Authorization": f"Bearer {api_key}"} if api_key else None + return httpx.post(f"{base_url}{path}", json=payload, headers=headers, timeout=120.0) + + def _get_json(self, base_url: str, path: str) -> httpx.Response: + return httpx.get(f"{base_url}{path}", timeout=30.0) + + def start_rollout_controller_and_gateway(self) -> tuple[RolloutConfig, GatewayConfig, str, Any]: + self.controller = self._build_controller() + rollout_config = self._get_rollout_config() + gateway_config = GatewayConfig(port=self._find_free_port(), capture_folder=str(self.capture_output_path)) + app = build_local_gateway_app(self.controller, config=gateway_config) + base_url, _ = self._serve_gateway_blocking_in_background(app, gateway_config) + return rollout_config, gateway_config, base_url, app + + def test_gateway_runtime_endpoints(self): + rollout_config, _, base_url, _ = 
self.start_rollout_controller_and_gateway() + + livez_response = self._get_json(base_url, "/livez") + self.assertEqual(livez_response.status_code, 200, livez_response.text) + self.assertEqual(livez_response.json(), {"status": "ok"}) + + readyz_response = self._get_json(base_url, "/readyz") + self.assertEqual(readyz_response.status_code, 200, readyz_response.text) + readyz_body = readyz_response.json() + self.assertTrue(readyz_body["ready"]) + self.assertEqual(readyz_body["status"], "ready") + self.assertIsInstance(readyz_body["details"], dict) + + capabilities_response = self._get_json(base_url, "/capabilities") + self.assertEqual(capabilities_response.status_code, 200, capabilities_response.text) + capabilities_body = capabilities_response.json() + self.assertEqual(capabilities_body["model"], rollout_config.model_name) + self.assertEqual(capabilities_body["backend"], rollout_config.rollout_backend) + self.assertEqual(capabilities_body["context_length"], rollout_config.context_length) + self.assertTrue(capabilities_body["supports_stream"]) + self.assertTrue(capabilities_body["supports_tools"]) + self.assertFalse(capabilities_body["supports_cancel"]) + self.assertTrue(capabilities_body["supports_parallel_tool_calls"]) + self.assertTrue(capabilities_body["supports_reasoning"]) + + def test_gateway_messages(self): + rollout_config, _, base_url, app = self.start_rollout_controller_and_gateway() + + openai_payload = { + "model": rollout_config.model_name, + "messages": [ + {"role": "user", "content": "你好,请用一句话介绍自己。"}, + ], + "max_tokens": 256, + } + anthropic_payload = { + "model": rollout_config.model_name, + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "今天北京天气怎么样?"}]}, + ], + "tools": [ + { + "name": "get_weather", + "description": "查询指定城市的实时天气", + "input_schema": { + "type": "object", + "properties": {"city": {"type": "string", "description": "城市名称"}}, + "required": ["city"], + }, + } + ], + "tool_choice": {"type": "auto"}, + "max_tokens": 512, + } + responses_payload = { + "model": rollout_config.model_name, + "instructions": "你是一个数学助手,回答要简洁。", + "input": [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "1 + 1 等于几?"}], + }, + ], + "max_output_tokens": 1024, + } + + openai_adapter = app.state.gateway_openai_adapter + anthropic_adapter = app.state.gateway_anthropic_adapter + responses_adapter = app.state.gateway_responses_adapter + openai_api_key = "trace-openai" + anthropic_api_key = "trace-anthropic" + responses_api_key = "trace-responses" + openai_trace_key = build_api_key_trace_key(openai_api_key) + anthropic_trace_key = build_api_key_trace_key(anthropic_api_key) + responses_trace_key = build_api_key_trace_key(responses_api_key) + + openai_response = self._post_json( + base_url, + "/v1/chat/completions", + openai_payload, + api_key=openai_api_key, + ) + self.assertEqual(openai_response.status_code, 200, openai_response.text) + openai_body = openai_response.json() + self._write_json_output(self.openai_body_output_path, openai_body) + self.assertEqual(openai_body["model"], rollout_config.model_name) + self.assertEqual(openai_body["choices"][0]["message"]["role"], "assistant") + self.assertIn(openai_body["choices"][0]["finish_reason"], {"stop", "length"}) + self.assertGreater(openai_body["usage"]["prompt_tokens"], 0) + self.assertTrue(openai_body["choices"][0]["message"].get("content")) + + anthropic_response = self._post_json( + base_url, + "/v1/messages", + anthropic_payload, + api_key=anthropic_api_key, + ) + 
self.assertEqual(anthropic_response.status_code, 200, anthropic_response.text) + anthropic_body = anthropic_response.json() + self._write_json_output(self.anthropic_body_output_path, anthropic_body) + self.assertEqual(anthropic_body["type"], "message") + self.assertEqual(anthropic_body["role"], "assistant") + self.assertEqual(anthropic_body["model"], rollout_config.model_name) + self.assertGreater(anthropic_body["usage"]["input_tokens"], 0) + self.assertTrue(anthropic_body["content"]) + + responses_response = self._post_json( + base_url, + "/v1/responses", + responses_payload, + api_key=responses_api_key, + ) + self.assertEqual(responses_response.status_code, 200, responses_response.text) + responses_body = responses_response.json() + self._write_json_output(self.responses_body_output_path, responses_body) + self.assertEqual(responses_body["object"], "response") + self.assertEqual(responses_body["model"], rollout_config.model_name) + self.assertGreater(responses_body["usage"]["input_tokens"], 0) + self.assertTrue(responses_body["output"]) + + openai_traces = openai_adapter.get_trace_records(openai_trace_key) + anthropic_traces = anthropic_adapter.get_trace_records(anthropic_trace_key) + responses_traces = responses_adapter.get_trace_records(responses_trace_key) + self.assertEqual(len(openai_traces), 1) + self.assertEqual(len(anthropic_traces), 1) + self.assertEqual(len(responses_traces), 1) + openai_trace = openai_traces[0] + anthropic_trace = anthropic_traces[0] + responses_trace = responses_traces[0] + self.assertEqual(openai_trace.trace_key, openai_trace_key) + self.assertEqual(anthropic_trace.trace_key, anthropic_trace_key) + self.assertEqual(responses_trace.trace_key, responses_trace_key) + self.assertNotEqual(openai_trace.trace_key, openai_api_key) + self.assertNotEqual(anthropic_trace.trace_key, anthropic_api_key) + self.assertNotEqual(responses_trace.trace_key, responses_api_key) + self.assertEqual(openai_trace.sequence, 0) + self.assertEqual(anthropic_trace.sequence, 0) + self.assertEqual(responses_trace.sequence, 0) + self.assertGreater(openai_trace.created_at, 0.0) + self.assertGreater(anthropic_trace.created_at, 0.0) + self.assertGreater(responses_trace.created_at, 0.0) + + capture_records = self._read_capture_records() + self.assertGreaterEqual(len(capture_records), 3) + protocol_records = {record["protocol"]: record for record in capture_records[-3:]} + self.assertIn("OpenAIChatAdapter", protocol_records) + self.assertIn("AnthropicChatAdapter", protocol_records) + self.assertIn("OpenAIResponsesAdapter", protocol_records) + + openai_record = protocol_records["OpenAIChatAdapter"] + self.assertTrue(openai_record["internal_messages"]) + self.assertEqual(openai_record["request_id"], openai_trace.request_id) + self.assertEqual(openai_record["output_messages"][0]["role"], "assistant") + self.assertTrue(openai_record["input_text"]) + self._assert_trace_record_matches_capture( + openai_trace, + openai_record, + expected_request_field="messages", + expected_request_role="user", + expected_response_field="choices", + ) + self.assertEqual(openai_trace.response_snapshot["choices"][0]["message"]["role"], "assistant") + + anthropic_record = protocol_records["AnthropicChatAdapter"] + self.assertEqual(anthropic_record["request_id"], anthropic_trace.request_id) + self.assertTrue(anthropic_record["output_messages"][0]["content"]) + self._assert_trace_record_matches_capture( + anthropic_trace, + anthropic_record, + expected_request_field="messages", + expected_request_role="user", + 
expected_response_field="content", + ) + self.assertEqual(anthropic_trace.request_snapshot["messages"][0]["role"], "user") + self.assertEqual(anthropic_trace.response_snapshot["role"], "assistant") + + responses_record = protocol_records["OpenAIResponsesAdapter"] + self.assertTrue(responses_record["output_messages"]) + self.assertEqual(responses_record["request_id"], responses_trace.request_id) + self.assertTrue(responses_record["input_text"]) + self._assert_trace_record_matches_capture( + responses_trace, + responses_record, + expected_request_field="input", + expected_request_role=None, + expected_response_field="output", + ) + self.assertEqual(responses_trace.response_snapshot["status"], "completed") + + openai_trace_get_response = httpx.get( + f"{base_url}/trace_store", + headers={"Authorization": f"Bearer {openai_api_key}"}, + timeout=30.0, + ) + self.assertEqual(openai_trace_get_response.status_code, 200, openai_trace_get_response.text) + openai_trace_get_body = openai_trace_get_response.json() + self.assertEqual(openai_trace_get_body["trace_key"], openai_trace_key) + self.assertEqual(openai_trace_get_body["count"], 1) + self.assertEqual(openai_trace_get_body["records"][0]["request_id"], openai_trace.request_id) + self.assertEqual(openai_trace_get_body["records"][0]["status"], openai_trace.status.value) + self.assertEqual(openai_trace_get_body["records"][0]["sequence"], openai_trace.sequence) + + openai_trace_pop_response = httpx.post( + f"{base_url}/trace_store/pop", + headers={"Authorization": f"Bearer {openai_api_key}"}, + timeout=30.0, + ) + self.assertEqual(openai_trace_pop_response.status_code, 200, openai_trace_pop_response.text) + openai_trace_pop_body = openai_trace_pop_response.json() + self.assertEqual(openai_trace_pop_body["trace_key"], openai_trace_key) + self.assertEqual(openai_trace_pop_body["count"], 1) + self.assertEqual(openai_trace_pop_body["records"][0]["request_id"], openai_trace.request_id) + + anthropic_trace_pop_response = httpx.post( + f"{base_url}/trace_store/pop", + params={"trace_key": anthropic_trace_key}, + timeout=30.0, + ) + self.assertEqual(anthropic_trace_pop_response.status_code, 200, anthropic_trace_pop_response.text) + anthropic_trace_pop_body = anthropic_trace_pop_response.json() + self.assertEqual(anthropic_trace_pop_body["trace_key"], anthropic_trace_key) + self.assertEqual(anthropic_trace_pop_body["count"], 1) + self.assertEqual(anthropic_trace_pop_body["records"][0]["request_id"], anthropic_trace.request_id) + + responses_trace_clear_response = httpx.post( + f"{base_url}/trace_store/clear", + params={"trace_key": responses_trace_key}, + timeout=30.0, + ) + self.assertEqual(responses_trace_clear_response.status_code, 200, responses_trace_clear_response.text) + responses_trace_clear_body = responses_trace_clear_response.json() + self.assertEqual(responses_trace_clear_body["trace_key"], responses_trace_key) + self.assertTrue(responses_trace_clear_body["cleared"]) + + self.assertEqual(openai_adapter.get_trace_records(openai_trace_key), []) + self.assertEqual(anthropic_adapter.get_trace_records(anthropic_trace_key), []) + self.assertEqual(responses_adapter.get_trace_records(responses_trace_key), []) + + def test_gateway_ir_fallback_behavior(self): + self.controller = self._build_controller() + rollout_config = self._get_rollout_config() + gateway_config = GatewayConfig(port=self._find_free_port(), capture_folder=str(self.capture_output_path)) + app = build_local_gateway_app(self.controller, config=gateway_config) + base_url, _ = 
self._serve_gateway_blocking_in_background(app, gateway_config) + + openai_payload = { + "model": rollout_config.model_name, + "messages": [ + {"role": "user", "content": "Call the search tool if you need it."}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_bad_openai", + "type": "function", + "function": { + "name": "search", + "arguments": "not-json", + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_bad_openai", + "content": "Sunny, 26C", + }, + {"role": "user", "content": "Finish the answer in one sentence. DONE"}, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search the latest weather.", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + }, + } + ], + "tool_choice": { + "type": "function", + "function": {"name": "search"}, + }, + "temperature": 0.2, + "top_p": 0.9, + "presence_penalty": 0.6, + "frequency_penalty": 0.4, + "stop": ["DONE"], + "max_tokens": 32, + } + openai_invalid_n_payload = { + **openai_payload, + "n": 2, + } + anthropic_payload = { + "model": rollout_config.model_name, + "system": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": "abc", + }, + } + ], + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "hello"}]}, + ], + "max_tokens": 32, + } + responses_payload = { + "model": rollout_config.model_name, + "instructions": "Follow the system rule.", + "input": [ + { + "type": "message", + "role": "developer", + "content": [{"type": "input_text", "text": "Use concise answers."}], + }, + { + "type": "reasoning", + "summary": [{"type": "summary_text", "text": "Need private reasoning first."}], + }, + { + "type": "function_call", + "call_id": "call_bad_responses", + "name": "search", + "arguments": "not-json", + }, + { + "type": "function_call_output", + "call_id": "call_bad_responses", + "output": [{"type": "text", "text": "Sunny, 26C"}], + }, + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "Answer now."}], + }, + ], + "tools": [ + { + "type": "function", + "name": "search", + "description": "Search the latest weather.", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + }, + { + "type": "web_search_preview", + "name": "web_search_preview", + }, + ], + "tool_choice": {"type": "function", "name": "search"}, + "parallel_tool_calls": True, + "store": True, + "include": ["reasoning.encrypted_content"], + "reasoning": {"effort": "high"}, + "temperature": 0.1, + "top_p": 0.8, + "max_output_tokens": 32, + } + responses_invalid_content_payload = { + **responses_payload, + "input": [ + { + "type": "message", + "role": "developer", + "content": [ + {"type": "input_text", "text": "Use concise answers."}, + {"type": "image", "image_url": "https://example.com/ignored.png"}, + ], + } + ], + } + responses_stream_payload = { + **responses_payload, + "stream": True, + } + openai_response = self._post_json(base_url, "/v1/chat/completions", openai_payload) + self.assertEqual(openai_response.status_code, 200, openai_response.text) + + openai_invalid_n_response = self._post_json(base_url, "/v1/chat/completions", openai_invalid_n_payload) + self.assertEqual(openai_invalid_n_response.status_code, 400, openai_invalid_n_response.text) + openai_invalid_n_body = openai_invalid_n_response.json() + self.assertEqual(openai_invalid_n_body["error"]["type"], "invalid_request_error") + 
self.assertEqual(openai_invalid_n_body["error"]["code"], "n_not_supported") + + anthropic_response = self._post_json(base_url, "/v1/messages", anthropic_payload) + self.assertEqual(anthropic_response.status_code, 400, anthropic_response.text) + anthropic_error_body = anthropic_response.json() + self.assertEqual(anthropic_error_body["type"], "error") + self.assertEqual(anthropic_error_body["error"]["type"], "invalid_request_error") + self.assertIn("Unsupported Anthropic content block type(s) in system: image", anthropic_error_body["error"]["message"]) + + responses_response = self._post_json(base_url, "/v1/responses", responses_payload) + self.assertEqual(responses_response.status_code, 200, responses_response.text) + + responses_invalid_content_response = self._post_json(base_url, "/v1/responses", responses_invalid_content_payload) + self.assertEqual(responses_invalid_content_response.status_code, 400, responses_invalid_content_response.text) + responses_invalid_content_body = responses_invalid_content_response.json() + self.assertEqual(responses_invalid_content_body["error"]["type"], "invalid_request_error") + self.assertEqual(responses_invalid_content_body["error"]["code"], "unsupported_content_block") + + responses_stream_response = self._post_json(base_url, "/v1/responses", responses_stream_payload) + self.assertEqual(responses_stream_response.status_code, 200, responses_stream_response.text) + self.assertEqual( + responses_stream_response.headers.get("content-type"), + "text/event-stream; charset=utf-8", + ) + self.assertIn("event: response.created", responses_stream_response.text) + self.assertIn("event: response.completed", responses_stream_response.text) + + capture_records = self._read_capture_records() + protocol_records = self._capture_records_by_protocol(capture_records) + self.assertIn("OpenAIChatAdapter", protocol_records) + self.assertIn("OpenAIResponsesAdapter", protocol_records) + self.assertNotIn("AnthropicChatAdapter", protocol_records) + + openai_record = protocol_records["OpenAIChatAdapter"] + self.assertEqual(openai_record["rollout_tool_choice"], {"type": "function", "function": {"name": "search"}}) + self.assertEqual(len(openai_record["rollout_tools"]), 1) + self.assertEqual(openai_record["rollout_tools"][0]["function"]["name"], "search") + self.assertEqual(openai_record["rollout_sample_params"]["presence_penalty"], 0.6) + self.assertEqual(openai_record["rollout_sample_params"]["frequency_penalty"], 0.4) + self.assertEqual(openai_record["rollout_sample_params"]["temperature"], 0.2) + self.assertEqual(openai_record["rollout_sample_params"]["top_p"], 0.9) + self.assertEqual(openai_record["rollout_sample_params"]["stops"], ["DONE"]) + self.assertEqual( + openai_record["internal_messages"][1]["tool_calls"][0]["function"]["arguments"], + {"raw": "not-json"}, + ) + + responses_record = protocol_records["OpenAIResponsesAdapter"] + self.assertEqual(responses_record["rollout_tool_choice"], {"type": "function", "function": {"name": "search"}}) + self.assertEqual(len(responses_record["rollout_tools"]), 1) + self.assertEqual(responses_record["rollout_tools"][0]["function"]["name"], "search") + self.assertTrue(responses_record["rollout_sample_params"]["max_tokens"] <= 32) + self.assertEqual(responses_record["rollout_sample_params"]["temperature"], 0.1) + self.assertEqual(responses_record["rollout_sample_params"]["top_p"], 0.8) + self.assertNotIn("store", responses_record["rollout_sample_params"]) + self.assertNotIn("include", responses_record["rollout_sample_params"]) + 
self.assertEqual(responses_record["internal_messages"][0]["role"], "system") + self.assertEqual(responses_record["internal_messages"][0]["content"], "Follow the system rule.") + self.assertEqual(responses_record["internal_messages"][1]["role"], "system") + self.assertEqual(responses_record["internal_messages"][1]["content"], "Use concise answers.") + self.assertEqual( + responses_record["internal_messages"][3]["tool_calls"][0]["function"]["arguments"], + {"raw": "not-json"}, + ) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rl/test_judger.py b/tests/rl/test_judger.py new file mode 100644 index 0000000000..3261651b3f --- /dev/null +++ b/tests/rl/test_judger.py @@ -0,0 +1,216 @@ +import os +import json +import ray +import unittest +import tempfile +import numpy as np +import asyncio +from xtuner.v1.rl.utils import AutoCPUWorkers, CPUResourcesConfig +from xtuner.v1.data_proto.rl_data import RolloutState + +MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] +DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] +GEO_ROLLOUT_DATA_PATH = os.environ["GEO_ROLLOUT_DATA_PATH"] +VERL_ROLLOUT_DATA_PATH = os.environ["VERL_ROLLOUT_DATA_PATH"] +DAPO_DATA_PATH = os.environ.get("ROLLOUT_DAPO_DATA_PATH") +FAKE_JUDGER_INPUT_ITEM = RolloutState( + message=[{ + 'role': 'user', + 'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let\'s think step by step and output the final answer after "####"' + }], + reward_model={'ground_truth': '72', 'style': 'rule'}, + response="\nOkay, let's see. Natalia sold clips to 48 friends in April. Then in May, she sold half as many. So first, I need to figure out how many she sold in May. Half of 48 is 24, right? Because 48 divided by 2 is 24. So in May, she sold 24 clips.\n\nNow, to find the total number of clips sold in both months, I need to add the number from April and May together. That would be 48 (April) plus 24 (May). Let me do the addition: 48 + 24. Hmm, 40 + 20 is 60, and 8 + 4 is 12. So 60 + 12 is 72. So altogether, she sold 72 clips.\n\nWait, let me check that again. 48 plus 24. Yes, 48 + 20 is 68, then plus 4 more is 72. Yep, that seems right. So the total is 72.\n\n\nNatalia sold 48 clips in April. In May, she sold half as many, which is 48 ÷ 2 = 24 clips. Adding both months together: 48 + 24 = 72. 
\n\n#### 72<|im_end|>"
+)
+
+def construct_gsm8k_judger_data(data_path) -> tuple[list[RolloutState], list[float]]:
+    states = []
+    history_reward = []
+    if not data_path or not os.path.exists(data_path):
+        return states, history_reward
+    with open(data_path, 'r', encoding='utf-8') as f:
+        for line in f:
+            item = json.loads(line.strip())
+            # Strip the fixed wrapper around the prompt in this dump
+            # (5-char prefix, 11-char suffix).
+            prompt = item["input"][5:-11]
+            response = item["output"]
+            gt = item["gts"]
+            states.append(
+                RolloutState(
+                    message=[{"role": "user", "content": prompt}],
+                    response=response,
+                    reward_model={"ground_truth": str(gt)}
+                )
+            )
+            history_reward.append(item["reward"])
+    return states, history_reward
+
+def construct_geo3k_dapo_judger_data(data_path) -> tuple[list[RolloutState], list[float]]:
+    states = []
+    history_reward = []
+    if not data_path or not os.path.exists(data_path):
+        return states, history_reward
+    with open(data_path, 'r', encoding='utf-8') as f:
+        lines = f.readlines()
+    # Each record spans 7 physical lines in this dump; join them back into one JSON object.
+    for i in range(0, len(lines), 7):
+        group = ''.join(lines[i:i + 7]).strip()
+        if not group:
+            continue
+        item = json.loads(group)
+        states.append(
+            RolloutState(
+                message=[{"role": "user", "content": ""}],
+                response=item['response'],
+                reward_model={"ground_truth": str(item["label"])}
+            )
+        )
+        history_reward.append(item["reward"])
+    return states, history_reward
+
+class TestJudgerController(unittest.TestCase):
+
+    def setUp(self):
+        ray.init(num_cpus=80, ignore_reinit_error=True)
+        self.temp_dir = tempfile.TemporaryDirectory()
+        self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs")
+
+    def tearDown(self):
+        ray.shutdown()
+        self.temp_dir.cleanup()
+
+    async def _judger_batch(self, judger_router, states):
+        return await asyncio.gather(*(judger_router.judge(s) for s in states))
+
+    def test_gsm8k_judger(self):
+        from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig
+
+        gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1)
+        # Test Case 1: NativeJudger
+        native_judger = GSM8KJudgerConfig(judger_name="openai/gsm8k").build()
+        res1 = asyncio.run(native_judger.judge(FAKE_JUDGER_INPUT_ITEM))
+        self.assertEqual(res1.reward["score"], 1.0)
+
+        # Test Case 2: remote judger with a given placement group
+        cpu_cfg = CPUResourcesConfig(num_workers=1, num_cpus_per_worker=1)
+        pg = AutoCPUWorkers.build_placement_group(cpu_cfg)
+        ray.get(pg.ready())
+        native_judger_actors = gsm8k_judger_config.build(pg, 0)
+        res2 = asyncio.run(native_judger_actors.judge(FAKE_JUDGER_INPUT_ITEM))
+        self.assertEqual(res2.reward["score"], 1.0)
+        del native_judger_actors
+
+        # Test Case 3: JudgerPool — verify the scores over a batch of data are correct
+        judger_router = gsm8k_judger_config.build(pg)
+        states, history_reward = construct_gsm8k_judger_data(VERL_ROLLOUT_DATA_PATH)
+        rollout_states = asyncio.run(self._judger_batch(judger_router, states))
+        rewards = [s.reward["score"] for s in rollout_states]
+        expected_avg_score = np.mean(history_reward)
+        self.assertEqual(round(np.mean(rewards), 4), round(expected_avg_score, 4))
+
+    def test_dapo_batch_judge_score(self):
+        # Verify the dapo judger with a single-instance pool scores correctly
+        from xtuner.v1.rl.judger.dapo_math import DapoMathJudgerConfig
+        from xtuner.v1.utils.rl_test_utils import get_eos_token
+        from transformers import AutoTokenizer
+        # Build the data
+        states, history_reward = construct_geo3k_dapo_judger_data(DAPO_DATA_PATH)
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
+        eos_token = get_eos_token(MODEL_PATH)
+        eos_token_str = tokenizer.convert_ids_to_tokens(eos_token)
+        # Define the judger config
+        config = DapoMathJudgerConfig(
+            judger_name="dapo_math",
+            num_ray_actors=1,
+            eos_token=eos_token_str,
+            enable_overlong_buffer=True,
+            max_response_len=32768,
+            overlong_buffer_len=4096,
+            overlong_penalty_factor=1.0,
+            tokenizer=tokenizer
+        )
+        router = config.build()
+        rollout_states = asyncio.run(self._judger_batch(router, states))
+        rewards = [s.reward["score"] for s in rollout_states]
+        expected_avg_score = np.mean(history_reward)
+        self.assertEqual(round(np.mean(rewards), 4), round(expected_avg_score, 4))
+
+    def test_geo_batch_judge_score(self):
+        # Verify the geo judger with a 4-instance pool scores correctly
+        from xtuner.v1.rl.judger.geo3k import GEO3KJudgerConfig
+        config = GEO3KJudgerConfig(judger_name="geo3k", num_ray_actors=4)
+        states, history_reward = construct_geo3k_dapo_judger_data(GEO_ROLLOUT_DATA_PATH)
+        router = config.build()
+        rollout_states = asyncio.run(self._judger_batch(router, states))
+        rewards = [s.reward["score"] for s in rollout_states]
+        expected_avg_score = np.mean(history_reward)
+        self.assertEqual(round(np.mean(rewards), 4), round(expected_avg_score, 4))
+        # Verify that the router really has 4 worker instances running
+        self.assertEqual(len(router.get_worker_status()), 4)
+
+    def test_multi_judger_router(self):
+        from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig
+
+        gsm8k_config_1 = GSM8KJudgerConfig(
+            judger_name="openai/gsm8k_1",
+            num_ray_actors=2,
+            num_cpus_per_actor=1,
+        )
+        gsm8k_config_2 = GSM8KJudgerConfig(
+            judger_name="openai/gsm8k_2",
+            num_ray_actors=8,
+            num_cpus_per_actor=2,
+        )
+
+        gsm8k_router_1 = gsm8k_config_1.build()
+        gsm8k_router_2 = gsm8k_config_2.build()
+
+        states, history_reward = construct_gsm8k_judger_data(VERL_ROLLOUT_DATA_PATH)
+        gsm8k_results_1 = asyncio.run(self._judger_batch(gsm8k_router_1, states))
+        gsm8k_results_2 = asyncio.run(self._judger_batch(gsm8k_router_2, states))
+
+        gsm8k_rewards_1 = [s.reward["score"] for s in gsm8k_results_1]
+        gsm8k_rewards_2 = [s.reward["score"] for s in gsm8k_results_2]
+
+        expected_avg_score = np.mean(history_reward)
+        self.assertEqual(round(np.mean(gsm8k_rewards_1), 4), round(expected_avg_score, 4))
+        self.assertEqual(round(np.mean(gsm8k_rewards_2), 4), round(expected_avg_score, 4))
+        self.assertEqual(len(gsm8k_router_1.get_worker_status()), 2)
+        self.assertEqual(len(gsm8k_router_2.get_worker_status()), 8)
+
+    def test_gsm8k_remote_judger(self):
+        # Verify scoring is correct with a remote_url, a single instance, and a bare NativeJudger
+        from xtuner.v1.utils.rl_test_utils import JudgerServer, GSM8KRemoteJudgerConfig
+
+        server = JudgerServer(port=8018)
+        server.start()
+        try:
+            remote_judger_config = GSM8KRemoteJudgerConfig(judger_name="openai/gsm8k", reward_handler=server.url)
+            native_remote_judger = remote_judger_config.build()
+            res = asyncio.run(native_remote_judger.judge(FAKE_JUDGER_INPUT_ITEM))
+            self.assertEqual(res.reward["score"], 1.0)
+        finally:
+            server.stop()
+
+    def test_composed_judger_config(self):
+        from xtuner.v1.rl.judger import ComposedJudgerConfig, JudgerConfig
+
+        def reward_a(response, label, extra_info):
+            return {"score": 1.0, "source": "a"}
+
+        def reward_b(response, label, extra_info):
+            return {"score": 0.25, "source": "b"}
+
+        judger_config = ComposedJudgerConfig(
+            branches={
+                "correctness": JudgerConfig(judger_name="correctness", reward_handler=reward_a),
+                "format": JudgerConfig(judger_name="format", reward_handler=reward_b),
+            },
+            select_fn=lambda state, branches: ["correctness", "format"],
+        )
+
+        judger = judger_config.build()
+        rollout_state = asyncio.run(judger.judge(FAKE_JUDGER_INPUT_ITEM.model_copy(deep=True)))
+
+        self.assertEqual(rollout_state.reward["correctness"], 1.0)
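+        # Both branches selected by select_fn contribute their own reward key.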
self.assertEqual(rollout_state.reward["format"], 0.25) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rl/test_mock_rollout.py b/tests/rl/test_mock_rollout.py new file mode 100644 index 0000000000..4ac1ec4b39 --- /dev/null +++ b/tests/rl/test_mock_rollout.py @@ -0,0 +1,193 @@ +import os +import asyncio +import unittest +import ray +from transformers import AutoTokenizer +import torch +import tempfile +import httpx +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.rollout.lmdeploy import LMDeployWorker +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers +from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.rl.rollout.controller import RolloutController +from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult + +TEST_TEXT_MESSAGES=[{"role": "user", "content": "Hello!"}] +MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] +TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] +resource_map = {"npu": "NPU", "cuda": "GPU"} + +class MockTimeoutRolloutWorker(LMDeployWorker): + async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: + try: + raise httpx.TimeoutException("Mocked timeout error") + except Exception as e: + error_type = HttpRequestErrorType.from_exception(e) + result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) + self.logger.info(f"Caught mocked timeout exception: {e.__class__.__name__}") + return result + + def _launch_server(self): + pass # Override + + +class MockRequestErrorRolloutWorker(LMDeployWorker): + async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: + try: + req = httpx.Request("POST", url) + raise httpx.RequestError("Mocked httpx request error", request=req) + except Exception as e: + error_type = HttpRequestErrorType.from_exception(e) + result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) + self.logger.info(f"Caught mocked request error exception: {e.__class__.__name__}") + return result + + def _launch_server(self): + pass # Override + + +class MockClientErrorRolloutWorker(LMDeployWorker): + async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: + try: + req = httpx.Request("POST", url) + res = httpx.Response(400, request=req) + raise httpx.HTTPStatusError("Mocked client error", request=req, response=res) + except Exception as e: + error_type = HttpRequestErrorType.from_exception(e) + result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) + self.logger.info(f"Caught mocked client exception: {e.__class__.__name__}") + return result + + def _launch_server(self): + pass # Override + + +class MockServerErrorRolloutWorker(LMDeployWorker): + async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: + try: + req = httpx.Request("POST", url) + res = httpx.Response(500, request=req) + raise httpx.HTTPStatusError("Mocked server error", request=req, response=res) + except Exception as e: + error_type = HttpRequestErrorType.from_exception(e) + result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) + self.logger.info(f"Caught mocked server exception: {e.__class__.__name__}") + return result + + def _launch_server(self): + pass # Override + +class MockInvalidResponseRolloutWorker(LMDeployWorker): + async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: + mock_rollout_state = 
RolloutState(message=TEST_TEXT_MESSAGES, status=Status.FAILED) + result = HttpRequestResult(response=mock_rollout_state) + return result + + async def _safe_handle_response(self, rollout_state, http_response) -> RolloutState: + mock_rollout_state = RolloutState(message=TEST_TEXT_MESSAGES, status=Status.FAILED) + return mock_rollout_state + + def _launch_server(self): + pass # Override + +@ray.remote +class MockTimeoutRolloutController(RolloutController): + def _get_worker_cls(self): return ray.remote(MockTimeoutRolloutWorker) + +@ray.remote +class MockRequestErrorRolloutController(RolloutController): + def _get_worker_cls(self): return ray.remote(MockRequestErrorRolloutWorker) + +@ray.remote +class MockClientErrorRolloutController(RolloutController): + def _get_worker_cls(self): return ray.remote(MockClientErrorRolloutWorker) + +@ray.remote +class MockServerErrorRolloutController(RolloutController): + def _get_worker_cls(self): return ray.remote(MockServerErrorRolloutWorker) + +@ray.remote +class MockInvalidResponseRolloutController(RolloutController): + def _get_worker_cls(self): return ray.remote(MockInvalidResponseRolloutWorker) + +class TestMockRollout(unittest.TestCase): + @classmethod + def setUpClass(cls): + os.environ["XTUNER_USE_FA3"] = "1" + + @classmethod + def tearDownClass(cls): + del os.environ["XTUNER_USE_FA3"] + + def setUp(self): + current_dir = os.path.abspath(os.path.dirname(__file__)) + python_path = f"{current_dir}:{os.environ.get('PYTHONPATH', '')}" + + ray.init(num_cpus=80, ignore_reinit_error=True, runtime_env={"env_vars": {"PYTHONPATH": python_path}}) + self.global_batch_size = 3 + self.max_prompt_length = 4096 + self.max_response_length = 128 + self.max_concurrent = 3 + self.max_retry_times = 3 + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + self.rollout_cfg = RolloutConfig( + env="test_mock_rollout", + model_path=MODEL_PATH, + model_name=os.path.basename(MODEL_PATH).lower(), + tokenizer_path=MODEL_PATH, + tensor_parallel_size=1, + context_length=self.max_prompt_length + self.max_response_length, + max_retry_per_worker=2, + max_retry_per_sample=3, + worker_log_dir=self.worker_log_dir, + ) + self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + + def tearDown(self): + ray.shutdown() + self.temp_dir.cleanup() + + async def _run_mock_test(self, mock_controller_cls, error_name, pg): + rollout_controller = mock_controller_cls.remote(self.rollout_cfg, pg) + input_state = RolloutState(message=TEST_TEXT_MESSAGES) + result_state = await rollout_controller.generate.remote(rollout_state=input_state) + self.assertEqual(result_state.status, Status.FAILED, f"Expected rollout to fail due to {error_name}, but it succeeded.") + self.assertIsNotNone(result_state.error_msg, f"Expected an error message for {error_name} case, but got None.") + if error_name == "server_error": + self.assertIn("Server error", result_state.error_msg, f"Expected error message to indicate a server error for {error_name} case, but got: {result_state.error_msg}") + elif error_name == "client_error": + self.assertIn("Client error", result_state.error_msg, f"Expected error message to indicate a client error for {error_name} case, but got: {result_state.error_msg}") + elif error_name in ["request_error", "timeout"]: + self.assertIn("Request failed", result_state.error_msg, f"Expected error message to indicate a request error for {error_name} case, but got: {result_state.error_msg}") + 
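# Request-level failures are retried per sample; once retries are exhausted + # the error message should surface the configured retry budget. +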
self.assertIn(str(self.rollout_cfg.max_retry_per_sample), result_state.error_msg, f"Expected error message to include max retry times for {error_name} case, but got: {result_state.error_msg}") + elif error_name == "invalid_response": + self.assertIn("Invalid rollout response", result_state.error_msg, f"Expected error message to indicate an invalid response for {error_name} case, but got: {result_state.error_msg}") + self.assertIn(str(self.rollout_cfg.max_retry_per_sample), result_state.error_msg, f"Expected error message to include max retry times for {error_name} case, but got: {result_state.error_msg}") + + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + def test_parallel_mock_rollout(self): + async def run_parallel(): + res_cfg_small = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=1, + num_cpus_per_worker=2, + ) + + pgs = [AutoAcceleratorWorkers.build_placement_group(res_cfg_small, name=f"pg_{i}") for i in range(5)] + await asyncio.gather(*[pg.ready() for pg in pgs]) + + tasks = [ + self._run_mock_test(MockTimeoutRolloutController, "timeout", pgs[0]), + self._run_mock_test(MockRequestErrorRolloutController, "request_error", pgs[1]), + self._run_mock_test(MockClientErrorRolloutController, "client_error", pgs[2]), + self._run_mock_test(MockServerErrorRolloutController, "server_error", pgs[3]), + self._run_mock_test(MockInvalidResponseRolloutController, "invalid_response", pgs[4]), + ] + await asyncio.gather(*tasks) + + asyncio.run(run_parallel()) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/rl/test_multi_task_agent_loop_manager.py b/tests/rl/test_multi_task_agent_loop_manager.py new file mode 100644 index 0000000000..bf14324d7f --- /dev/null +++ b/tests/rl/test_multi_task_agent_loop_manager.py @@ -0,0 +1,728 @@ +import asyncio +import json +import tempfile +import unittest +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +from xtuner.v1.rl.agent_loop_manager.agent_loop_manager import ( + AgentLoopManager, + AgentLoopManagerConfig, + AgentLoopManagerStatus, + TaskSpecConfig, + _TaskRunner, +) +from xtuner.v1.rl.agent_loop_manager.producer import GROUP_GENERATE_TIME_KEY, ProduceBatchStatus +from xtuner.v1.data_proto.rl_data import Status +from xtuner.v1.rl.utils import calculate_seq_staleness + + +class _FakeSampler: + def __init__(self, size: int = 1): + self._size = size + self.saved_paths: list[Path] = [] + self.resumed_paths: list[Path] = [] + + def __len__(self) -> int: + return self._size + + def save(self, checkpoint_path): + self.saved_paths.append(Path(checkpoint_path)) + return None + + def resume(self, checkpoint_path): + self.resumed_paths.append(Path(checkpoint_path)) + return None + + +class _FakeProduceStrategy: + def __init__( + self, + status: ProduceBatchStatus = ProduceBatchStatus.NORMAL, + cleanup_pause_time_s: float = 0.0, + stale_threshold: int = 1, + ): + self.status = status + self.cleanup_pause_time_s = cleanup_pause_time_s + self.stale_threshold = stale_threshold + self.called_batch_sizes: list[int] = [] + self.called_train_steps: list[int] = [] + self.called_model_steps: list[int] = [] + self.called_update_events: list[object | None] = [] + self.called_update_event_states: list[bool | None] = [] + self.called_progresses: list[object] = [] + self.called_target_cumulatives: list[int | None] = [] + self.cleanup_progresses: list[object | None] = [] + 
self.cleanup_call_count = 0 + + async def produce_batch( + self, + agent_loop, + sampler, + replay_buffer, + batch_size: int, + task_name: str, + train_step: int = 0, + update_event=None, + *, + model_step: int, + progress, + target_cumulative: int | None = None, + ) -> ProduceBatchStatus: + self.called_batch_sizes.append(batch_size) + self.called_train_steps.append(train_step) + self.called_model_steps.append(model_step) + self.called_update_events.append(update_event) + self.called_update_event_states.append(None if update_event is None else update_event.is_set()) + self.called_progresses.append(progress) + self.called_target_cumulatives.append(target_cumulative) + return self.status + + async def pause_produce(self, agent_loop, replay_buffer, task_name: str, *, progress) -> float: + self.cleanup_call_count += 1 + self.cleanup_progresses.append(progress) + return self.cleanup_pause_time_s + + +class _FakeStatusProduceStrategy: + def __init__(self, status: ProduceBatchStatus, pause_time_s: float): + self.status = status + self.pause_time_s = pause_time_s + self.cleanup_call_count = 0 + self.called_train_steps: list[int] = [] + self.called_model_steps: list[int] = [] + self.called_update_events: list[object | None] = [] + self.called_update_event_states: list[bool | None] = [] + self.called_progresses: list[object] = [] + self.called_target_cumulatives: list[int | None] = [] + self.cleanup_progresses: list[object | None] = [] + + async def produce_batch( + self, + agent_loop, + sampler, + replay_buffer, + batch_size: int, + task_name: str, + train_step: int = 0, + update_event=None, + *, + model_step: int, + progress, + target_cumulative: int | None = None, + ) -> ProduceBatchStatus: + self.called_train_steps.append(train_step) + self.called_model_steps.append(model_step) + self.called_update_events.append(update_event) + self.called_update_event_states.append(None if update_event is None else update_event.is_set()) + self.called_progresses.append(progress) + self.called_target_cumulatives.append(target_cumulative) + return self.status + + async def pause_produce(self, agent_loop, replay_buffer, task_name: str, *, progress) -> float: + self.cleanup_call_count += 1 + self.cleanup_progresses.append(progress) + return self.pause_time_s + + +class _FakeRolloutState: + def __init__(self, uid: str, group_generate_time_s: float): + self.uid = uid + self.extra_fields = {GROUP_GENERATE_TIME_KEY: group_generate_time_s} + + +class _FakeStalenessRolloutState(_FakeRolloutState): + def __init__(self, uid: str, group_generate_time_s: float, response_model_steps: list[int], seq_staleness: int = 0): + super().__init__(uid, group_generate_time_s) + self.response_model_steps = response_model_steps + self.seq_staleness = seq_staleness + + +class _SequencedProduceStrategy(_FakeProduceStrategy): + def __init__(self, statuses: list[ProduceBatchStatus], cleanup_pause_time_s: float = 0.0): + super().__init__(status=ProduceBatchStatus.NORMAL, cleanup_pause_time_s=cleanup_pause_time_s) + self._statuses = list(statuses) + + async def produce_batch( + self, + agent_loop, + sampler, + replay_buffer, + batch_size: int, + task_name: str, + train_step: int = 0, + update_event=None, + *, + model_step: int, + progress, + target_cumulative: int | None = None, + ) -> ProduceBatchStatus: + self.called_batch_sizes.append(batch_size) + self.called_train_steps.append(train_step) + self.called_model_steps.append(model_step) + self.called_update_events.append(update_event) + self.called_update_event_states.append(None if 
update_event is None else update_event.is_set()) + self.called_progresses.append(progress) + self.called_target_cumulatives.append(target_cumulative) + return self._statuses.pop(0) if self._statuses else ProduceBatchStatus.NORMAL + + +class _FakeReplayBuffer: + def __init__(self, rollout_states_by_task: dict[str, list[list[str]]], leftover_counts: dict[tuple[str, Status], int]): + self._rollout_states_by_task = rollout_states_by_task + self._leftover_counts = leftover_counts + self.saved_paths: list[Path] = [] + self.resumed_paths: list[Path] = [] + self.refresh_staleness_calls: list[tuple[str, int, int, tuple[Status, ...]]] = [] + + async def get(self, batch_size: int, task_name: str, group_status: Status): + assert group_status == Status.COMPLETED + groups = self._rollout_states_by_task.get(task_name, []) + selected = groups[:batch_size] + self._rollout_states_by_task[task_name] = groups[batch_size:] + return selected + + async def count(self, task_name: str, group_status: Status): + return self._leftover_counts.get((task_name, group_status), 0) + + async def refresh_staleness( + self, + task_name: str, + current_train_step: int, + stale_threshold: int, + statuses: list[Status] | None = None, + ): + self.refresh_staleness_calls.append((task_name, current_train_step, stale_threshold, tuple(statuses or ()))) + for group in self._rollout_states_by_task.get(task_name, []): + for state in group: + response_model_steps = getattr(state, "response_model_steps", None) or [] + if response_model_steps and hasattr(state, "seq_staleness"): + state.seq_staleness = calculate_seq_staleness( + min(response_model_steps), current_train_step + ) + return 0 + + async def save(self, checkpoint_path: Path | str): + self.saved_paths.append(Path(checkpoint_path)) + + async def resume(self, checkpoint_path: Path | str): + self.resumed_paths.append(Path(checkpoint_path)) + + +class _SequencedCompletedReplayBuffer(_FakeReplayBuffer): + def __init__(self, completed_counts: list[int], rollout_states_by_task: dict[str, list[list[str]]]): + super().__init__(rollout_states_by_task=rollout_states_by_task, leftover_counts={}) + self._completed_counts = list(completed_counts) + self.get_calls: list[tuple[int, str, Status]] = [] + self.completed_count_call_count = 0 + + async def get(self, batch_size: int, task_name: str, group_status: Status): + self.get_calls.append((batch_size, task_name, group_status)) + return await super().get(batch_size, task_name, group_status) + + async def count(self, task_name: str, group_status: Status): + if group_status == Status.COMPLETED: + self.completed_count_call_count += 1 + if self._completed_counts: + return self._completed_counts.pop(0) + return await super().count(task_name, group_status) + + +def _fake_agent_loop(): + rollout_ctl = MagicMock() + rollout_ctl.continue_generation.remote = AsyncMock() + rollout_ctl.pause_generation.remote = AsyncMock() + rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + agent_loop = MagicMock() + agent_loop.rollout_ctl = rollout_ctl + return agent_loop + + +class TestMultiTaskAgentLoopManager(unittest.IsolatedAsyncioTestCase): + async def _wait_until(self, predicate, timeout_s: float = 1.0): + deadline = asyncio.get_running_loop().time() + timeout_s + while asyncio.get_running_loop().time() < deadline: + if predicate(): + return + await asyncio.sleep(0.01) + self.fail("Timed out waiting for condition.") + + def test_manager_config_accepts_single_task_spec(self): + task = TaskSpecConfig.model_construct( + 
task_name="single_task", + agent_loop_config=MagicMock(), + produce_strategy_config=MagicMock(), + sampler_config=MagicMock(), + weight=1.0, + ) + + manager_config = AgentLoopManagerConfig(tasks=task) + + self.assertEqual(manager_config.tasks.task_name, "single_task") + + async def test_produce_batch_allocates_by_weight_and_returns_task_sorted_results(self): + strategy_a = _FakeProduceStrategy() + strategy_b = _FakeProduceStrategy() + strategy_c = _FakeProduceStrategy() + replay_buffer = _FakeReplayBuffer( + rollout_states_by_task={ + "task_a": [["a-0"], ["a-1"]], + "task_b": [["b-0"], ["b-1"], ["b-2"]], + "task_c": [], + }, + leftover_counts={ + ("task_a", Status.COMPLETED): 1, + ("task_b", Status.ABORTED): 2, + }, + ) + + multi_task_manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_b", + agent_loop=_fake_agent_loop(), + produce_strategy=strategy_b, + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=strategy_a, + sampler=_FakeSampler(), + weight=2.0, + order=1, + ), + _TaskRunner( + task_name="task_c", + agent_loop=_fake_agent_loop(), + produce_strategy=strategy_c, + sampler=_FakeSampler(), + weight=0.0, + order=2, + ), + ], + replay_buffer=replay_buffer, + ) + multi_task_manager._status = AgentLoopManagerStatus.UPDATE_ABORT + multi_task_manager._update_event.set() + + result = await multi_task_manager.produce_batch(batch_size=7, train_step=3, model_step=2) + + self.assertEqual(result.task_batch_sizes, {"task_a": 5, "task_b": 2, "task_c": 0}) + # sync produce_batch 在本轮入口恢复 NORMAL,收尾 pause 后保留 UPDATE_ABORT 到下一轮入口再清理。 + self.assertEqual(multi_task_manager._status, AgentLoopManagerStatus.UPDATE_ABORT) + self.assertTrue(multi_task_manager._update_event.is_set()) + self.assertEqual(multi_task_manager._model_step, 2) + self.assertEqual(strategy_a.called_batch_sizes, [5]) + self.assertEqual(strategy_b.called_batch_sizes, [2]) + self.assertEqual(strategy_c.called_batch_sizes, []) + self.assertEqual(strategy_a.called_train_steps, [3]) + self.assertEqual(strategy_b.called_train_steps, [3]) + self.assertEqual(strategy_a.called_model_steps, [2]) + self.assertEqual(strategy_b.called_model_steps, [2]) + self.assertEqual(len(strategy_a.called_update_events), 1) + self.assertFalse(strategy_a.called_update_event_states[0]) + self.assertEqual(len(strategy_b.called_update_events), 1) + self.assertFalse(strategy_b.called_update_event_states[0]) + self.assertEqual(result.rollout_states, [["a-0"], ["a-1"], ["b-0"], ["b-1"]]) + self.assertEqual(result.leftover_completed, 1) + self.assertEqual(result.leftover_aborted, 2) + self.assertEqual(result.leftover_expired, 0) + self.assertIn("task_a", result.task_results) + self.assertIn("task_b", result.task_results) + self.assertIn("task_c", result.task_results) + + def test_save_and_resume_roundtrip_restores_paused_manager_state(self): + sampler = _FakeSampler() + replay_buffer = _FakeReplayBuffer({}, {}) + manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=_FakeProduceStrategy(stale_threshold=5), + sampler=sampler, + weight=1.0, + order=0, + ) + ], + replay_buffer=replay_buffer, + ) + manager._status = AgentLoopManagerStatus.EXPIRED_BATCH + manager._model_step = 2 + manager._pause_time_s = 1.5 + + with tempfile.TemporaryDirectory() as tmp_dir: + checkpoint_path = Path(tmp_dir) + manager.save(checkpoint_path, model_step=7) + + state_path = checkpoint_path / 
manager._MANAGER_STATE_PATH + with state_path.open("r") as f: + state = json.load(f) + + self.assertEqual(state["status"], "EXPIRED_BATCH") + self.assertEqual(state["model_step"], 7) + self.assertNotIn("model_step_override", state) + self.assertEqual(state["next_consumer_step"], 1) + self.assertEqual(state["producer_future_step"], 1) + self.assertEqual(state["consumed_samples"], {"task_a": 0}) + self.assertEqual(state["target_samples"], {"task_a": 0}) + self.assertEqual(state["target_upto_future_step"], 0) + + restored_step = manager.resume(checkpoint_path) + + self.assertEqual(restored_step, 7) + self.assertEqual(manager._status, AgentLoopManagerStatus.UPDATE_ABORT) + self.assertTrue(manager._update_event.is_set()) + self.assertFalse(manager._finish_event.is_set()) + self.assertEqual(manager._pause_time_s, 0.0) + self.assertEqual(manager._model_step, 7) + self.assertEqual(sampler.saved_paths, [Path(tmp_dir) / manager._TASK_CHECKPOINT_DIR / "task_a"]) + self.assertEqual(sampler.resumed_paths, [Path(tmp_dir) / manager._TASK_CHECKPOINT_DIR / "task_a"]) + self.assertEqual(replay_buffer.saved_paths, [Path(tmp_dir)]) + self.assertEqual(replay_buffer.resumed_paths, [Path(tmp_dir)]) + + def test_save_rejects_pending_async_tasks(self): + strategy = _FakeProduceStrategy() + strategy._pending_tasks = {object()} + manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=strategy, + sampler=_FakeSampler(), + weight=1.0, + order=0, + ) + ], + replay_buffer=_FakeReplayBuffer({}, {}), + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + with self.assertRaisesRegex(RuntimeError, "pending rollout tasks"): + manager.save(tmp_dir, model_step=0) + + async def test_custom_get_task_batch_sizes_can_disable_tasks(self): + strategy_a = _FakeProduceStrategy() + strategy_b = _FakeProduceStrategy() + replay_buffer = _FakeReplayBuffer( + rollout_states_by_task={ + "task_a": [["a-0"]], + "task_b": [["b-0"], ["b-1"]], + }, + leftover_counts={}, + ) + + class _CustomBatchManager(AgentLoopManager): + def get_task_batch_sizes(self, global_batch_size: int, train_step: int) -> dict[str, int]: + self.observed_train_step = train_step + return {"task_a": 0, "task_b": global_batch_size} + + multi_task_manager = _CustomBatchManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=strategy_a, + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + _TaskRunner( + task_name="task_b", + agent_loop=_fake_agent_loop(), + produce_strategy=strategy_b, + sampler=_FakeSampler(), + weight=1.0, + order=1, + ), + ], + replay_buffer=replay_buffer, + ) + + result = await multi_task_manager.produce_batch(batch_size=2, train_step=9, model_step=8) + + self.assertEqual(multi_task_manager.observed_train_step, 9) + self.assertEqual(result.task_batch_sizes, {"task_a": 0, "task_b": 2}) + self.assertEqual(strategy_a.called_batch_sizes, []) + self.assertEqual(strategy_b.called_batch_sizes, [2]) + self.assertEqual(result.rollout_states, [["b-0"], ["b-1"]]) + + async def test_status_returning_strategy_uses_cleanup_and_reconstructs_group_timing_stats(self): + strategy = _FakeStatusProduceStrategy(status=ProduceBatchStatus.NORMAL, pause_time_s=1.25) + replay_buffer = _FakeReplayBuffer( + rollout_states_by_task={ + "task_a": [ + [_FakeRolloutState("a-0", 0.5)], + [_FakeRolloutState("a-1", 1.0)], + ], + }, + leftover_counts={}, + ) + + manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + 
agent_loop=_fake_agent_loop(), + produce_strategy=strategy, + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + ], + replay_buffer=replay_buffer, + ) + + result = await manager.produce_batch(batch_size=2, train_step=7, model_step=6) + + self.assertEqual(strategy.cleanup_call_count, 1) + self.assertEqual(len(strategy.cleanup_progresses), 1) + self.assertIsNotNone(strategy.cleanup_progresses[0]) + self.assertEqual(strategy.cleanup_progresses[0].consumed_samples["task_a"], 2) + self.assertEqual(manager._model_step, 6) + self.assertEqual(strategy.called_train_steps, [7]) + self.assertEqual(strategy.called_model_steps, [6]) + self.assertEqual(len(strategy.called_update_events), 1) + self.assertFalse(strategy.called_update_event_states[0]) + self.assertEqual(manager._status, AgentLoopManagerStatus.UPDATE_ABORT) + self.assertTrue(manager._update_event.is_set()) + self.assertEqual(result.group_gen_count, 2) + self.assertAlmostEqual(result.group_gen_mean_s, 0.75) + self.assertAlmostEqual(result.group_gen_p50_s, 1.0) + self.assertAlmostEqual(result.group_gen_p99_s, 1.0) + self.assertAlmostEqual(result.group_gen_pause_time_s, 1.25) + + async def test_produce_batch_requires_non_empty_rollout_states(self): + manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=_FakeStatusProduceStrategy(status=ProduceBatchStatus.NORMAL, pause_time_s=0.0), + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + ], + replay_buffer=_FakeReplayBuffer(rollout_states_by_task={}, leftover_counts={}), + ) + + with self.assertRaisesRegex(AssertionError, "must return non-empty rollout_states"): + await manager.produce_batch(batch_size=1, train_step=3, model_step=2) + + async def test_pause_produce_from_async_produce_loop_sets_status_and_pause_time(self): + strategy = _FakeProduceStrategy(cleanup_pause_time_s=2.5) + manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=strategy, + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + ], + replay_buffer=_FakeReplayBuffer({}, {}), + ) + + pause_time_s = await manager.pause_produce(use_global_progress=True) + + self.assertEqual(pause_time_s, 2.5) + self.assertEqual(strategy.cleanup_call_count, 1) + self.assertEqual(len(strategy.cleanup_progresses), 1) + self.assertIs(strategy.cleanup_progresses[0], manager._produce_progress) + self.assertTrue(manager._update_event.is_set()) + self.assertEqual(manager._status, AgentLoopManagerStatus.UPDATE_ABORT) + self.assertEqual(manager._pause_time_s, 2.5) + + async def test_pause_produce_validates_progress_selection_before_state_change(self): + strategy = _FakeProduceStrategy(cleanup_pause_time_s=2.5) + manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=strategy, + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + ], + replay_buffer=_FakeReplayBuffer({}, {}), + ) + + with self.assertRaisesRegex(ValueError, "progress must not be provided"): + await manager.pause_produce(use_global_progress=True, progress=object()) + self.assertFalse(manager._update_event.is_set()) + self.assertEqual(manager._status, AgentLoopManagerStatus.NORMAL) + + with self.assertRaisesRegex(ValueError, "progress must be provided"): + await manager.pause_produce(use_global_progress=False) + self.assertFalse(manager._update_event.is_set()) + self.assertEqual(manager._status, AgentLoopManagerStatus.NORMAL) + 
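# A rejected call must also leave the produce strategy untouched: no cleanup pass may have run. +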
self.assertEqual(strategy.cleanup_call_count, 0) + + async def test_get_batch_returns_expired_batch_when_manager_is_expired(self): + manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=_FakeProduceStrategy(), + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + ], + replay_buffer=_FakeReplayBuffer({}, {}), + ) + manager._status = AgentLoopManagerStatus.EXPIRED_BATCH + + result = await manager.get_batch(batch_size=2, train_step=11) + + self.assertEqual(result.status, ProduceBatchStatus.EXPIRED_BATCH) + self.assertEqual(result.rollout_states, []) + + async def test_get_batch_refreshes_staleness_at_entry(self): + replay_buffer = _FakeReplayBuffer( + rollout_states_by_task={ + "task_a": [[_FakeStalenessRolloutState("a-0", 0.2, response_model_steps=[4], seq_staleness=0)]], + }, + leftover_counts={("task_a", Status.COMPLETED): 1}, + ) + manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=_FakeProduceStrategy(stale_threshold=5), + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + ], + replay_buffer=replay_buffer, + ) + + result = await manager.get_batch(batch_size=1, train_step=9) + + self.assertEqual( + replay_buffer.refresh_staleness_calls, + [ + ("task_a", 9, 5, (Status.COMPLETED, Status.ABORTED)), + ("task_a", 10, 5, (Status.COMPLETED, Status.ABORTED)), + ], + ) + self.assertEqual(result.rollout_states[0][0].seq_staleness, 4) + self.assertEqual(manager._produce_progress.next_consumer_step, 10) + self.assertEqual(manager._produce_progress.consumed_samples["task_a"], 1) + + async def test_get_batch_waits_until_requested_batch_size_is_ready(self): + replay_buffer = _SequencedCompletedReplayBuffer( + completed_counts=[0, 1, 2], + rollout_states_by_task={ + "task_a": [ + [_FakeStalenessRolloutState("a-0", 0.2, response_model_steps=[4], seq_staleness=0)], + [_FakeStalenessRolloutState("a-1", 0.3, response_model_steps=[4], seq_staleness=0)], + ], + }, + ) + manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=_FakeProduceStrategy(), + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + ], + replay_buffer=replay_buffer, + ) + manager._STATUS_POLL_INTERVAL_S = 0.01 + + result = await asyncio.wait_for(manager.get_batch(batch_size=2, train_step=9), timeout=1.0) + + self.assertEqual([group[0].uid for group in result.rollout_states], ["a-0", "a-1"]) + self.assertEqual(replay_buffer.get_calls, [(2, "task_a", Status.COMPLETED)]) + self.assertGreaterEqual(replay_buffer.completed_count_call_count, 3) + self.assertEqual(manager._produce_progress.consumed_samples["task_a"], 2) + self.assertEqual(manager._produce_progress.next_consumer_step, 10) + + async def test_produce_batch_to_buffer_aggregates_status_with_update_abort_priority(self): + manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=_FakeProduceStrategy(status=ProduceBatchStatus.NORMAL), + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + _TaskRunner( + task_name="task_b", + agent_loop=_fake_agent_loop(), + produce_strategy=_FakeProduceStrategy(status=ProduceBatchStatus.EXPIRED_BATCH), + sampler=_FakeSampler(), + weight=1.0, + order=1, + ), + _TaskRunner( + task_name="task_c", + agent_loop=_fake_agent_loop(), + produce_strategy=_FakeProduceStrategy(status=ProduceBatchStatus.UPDATE_ABORT), + 
sampler=_FakeSampler(), + weight=1.0, + order=2, + ), + ], + replay_buffer=_FakeReplayBuffer({}, {}), + ) + + manager._model_step = 5 + manager._produce_progress.producer_future_step = 5 + status = await manager._produce_batch_to_buffer(batch_size=3, progress=manager._produce_progress) + + self.assertEqual(status, ProduceBatchStatus.UPDATE_ABORT) + + async def test_produce_loop_waits_for_continue_produce_and_stops_on_finish(self): + strategy = _SequencedProduceStrategy( + statuses=[ProduceBatchStatus.NORMAL, ProduceBatchStatus.EXPIRED_BATCH, ProduceBatchStatus.NORMAL], + ) + manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=strategy, + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + ], + replay_buffer=_FakeReplayBuffer({}, {}), + ) + manager._STATUS_POLL_INTERVAL_S = 0.01 + + manager._produce_progress.producer_future_step = 3 + loop_task = asyncio.create_task(manager.produce_loop(batch_size=1)) + await self._wait_until(lambda: manager._status == AgentLoopManagerStatus.EXPIRED_BATCH) + self.assertEqual(manager._status, AgentLoopManagerStatus.EXPIRED_BATCH) + self.assertEqual(strategy.called_train_steps[:2], [3, 4]) + + manager.continue_produce(model_step=9) + await self._wait_until(lambda: len(strategy.called_train_steps) >= 3) + self.assertEqual(manager._status, AgentLoopManagerStatus.NORMAL) + self.assertEqual(strategy.called_train_steps[:3], [3, 4, 4]) + self.assertEqual(strategy.called_model_steps[2], 9) + + manager._status = AgentLoopManagerStatus.FINISH + manager._finish_event.set() + await asyncio.wait_for(loop_task, timeout=1.0) diff --git a/tests/rl/test_producer.py b/tests/rl/test_producer.py new file mode 100644 index 0000000000..b9ba0ec441 --- /dev/null +++ b/tests/rl/test_producer.py @@ -0,0 +1,575 @@ +import unittest +import asyncio +from unittest.mock import AsyncMock, MagicMock + +from xtuner.v1.rl.agent_loop_manager import ( + AsyncProduceStrategyConfig, + ProduceBatchStatus, + ProduceProgress, + SamplerConfig, + SyncProduceStrategyConfig, +) +from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig +from xtuner.v1.data_proto.rl_data import Status + + +class MockRolloutState: + def __init__(self, id, seq_staleness=1, status=Status.COMPLETED): + self.id = id + self.uid = id + self.status = status + self.seq_staleness = seq_staleness + self.response_ids = [] + self.extra_fields = {} + + +class TestProducer(unittest.IsolatedAsyncioTestCase): + def setUp(self): + # 1. Mock the DataloaderConfig and the Dataloader + self.mock_dataloader_cfg = MagicMock() + self.mock_dataloader = MagicMock() + # Mock next(dataloader_iter) returning [RolloutState] + self.mock_dataloader.__iter__.return_value = iter([[MockRolloutState(i)] for i in range(100)]) + self.mock_dataloader_cfg.build.return_value = self.mock_dataloader + + # 2. Mock the tokenizer + self.mock_tokenizer = MagicMock() + + # 3. Prepare the ReplayBuffer +
replay_buffer_cfg = AsyncReplayBufferConfig() + self.replay_buffer = replay_buffer_cfg.build() + + def _build_sampler(self): + sampler_cfg = SamplerConfig.model_construct(dataloader_cfg=self.mock_dataloader_cfg) + return sampler_cfg.build(self.mock_tokenizer, self.replay_buffer) + + def _build_progress( + self, + task_name: str, + target: int, + train_step: int = 0, + consumed: int = 0, + ) -> ProduceProgress: + return ProduceProgress( + next_consumer_step=train_step, + producer_future_step=train_step, + consumed_samples={task_name: consumed}, + target_samples={task_name: target}, + target_upto_future_step=train_step, + ) + + def _build_agent_loop(self, sleep_by_id: dict[int, float] | None = None): + mock_agent_loop = MagicMock() + mock_agent_loop.rollout_ctl.continue_generation.remote = AsyncMock(return_value=None) + mock_agent_loop.rollout_ctl.pause_generation.remote = AsyncMock(return_value=None) + mock_agent_loop.rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + + sleep_by_id = sleep_by_id or {} + + async def mock_gen(rs, **kwargs): + await asyncio.sleep(sleep_by_id.get(rs[0].id, 0.0)) + for r in rs: + r.seq_staleness = kwargs.get("model_step", kwargs.get("train_step", 0)) + r.status = Status.COMPLETED + return rs + + mock_agent_loop.generate_group = mock_gen + return mock_agent_loop + + async def test_sampler_with_replay_buffer(self): + task_name = "test_task" + sampler = self._build_sampler() + + # Scenario A: the ReplayBuffer is empty, so the sample comes from the Dataloader + data = await sampler.sample(task_name) + self.assertEqual(data[0].id, 0) + + # Scenario B: the ReplayBuffer holds several candidate statuses; they are taken in the order listed in group_status + aborted_item = MockRolloutState(999, status=Status.ABORTED) + expired_item = MockRolloutState(1000, status=Status.EXPIRED) + await self.replay_buffer.put([aborted_item], task_name) + await self.replay_buffer.put([expired_item], task_name) + + data = await sampler.sample(task_name, group_status=[Status.EXPIRED, Status.ABORTED]) + self.assertEqual(data[0].id, 1000) + + data = await sampler.sample(task_name, group_status=[Status.EXPIRED, Status.ABORTED]) + self.assertEqual(data[0].id, 999) + + # Scenario C: the requested status pools are empty again, so sampling falls back to the Dataloader + data = await sampler.sample(task_name, group_status=[Status.EXPIRED, Status.ABORTED]) + self.assertEqual(data[0].id, 1) + + async def test_sync_produce_strategy(self): + task_name = "test_task" + mock_agent_loop = self._build_agent_loop({0: 0.0, 1: 0.01}) + produce_strategy_cfg = SyncProduceStrategyConfig() + sampler = self._build_sampler() + strategy = produce_strategy_cfg.build() + + # Run: produce a batch of size 2 + status = await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=2, + task_name=task_name, + train_step=4, + model_step=3, + progress=self._build_progress(task_name, target=2, train_step=4), + ) + self.assertEqual(status, ProduceBatchStatus.NORMAL) + + # Verify: the ReplayBuffer should now hold 2 COMPLETED groups + final_data = await self.replay_buffer.get(10, task_name, Status.COMPLETED) + print(final_data[0][0].id, final_data[0][0].status) + print(final_data[1][0].id, final_data[1][0].status) + self.assertEqual(len(final_data), 2) + self.assertEqual(final_data[0][0].id, 0) + self.assertEqual(final_data[1][0].id, 1) + + async def test_async_produce_strategy(self): + # This async_produce_strategy test mainly verifies the over-sampling logic and the staleness-priority get logic. + # Other async features such as partial_rollout and tail_batch are not verified here. + mock_agent_loop = MagicMock() + mock_agent_loop.pause = AsyncMock() + mock_agent_loop.rollout_ctl.continue_generation.remote =
AsyncMock(return_value=None) + task_name = "test_task" + call_count = 0 + async def mock_gen(rs, **kwargs): + nonlocal call_count + call_count += 1 + for r in rs: + if r.id == 999: + r.seq_staleness = 5 + else: + r.seq_staleness = call_count + r.status = Status.COMPLETED + print(r.id, r.seq_staleness, r.status) + return rs + + mock_agent_loop.generate_group = mock_gen + + sampler_cfg = SamplerConfig.model_construct(dataloader_cfg=self.mock_dataloader_cfg) + produce_strategy_cfg = AsyncProduceStrategyConfig(over_sample_threshold=1) + sampler = sampler_cfg.build(self.mock_tokenizer, self.replay_buffer) + strategy = produce_strategy_cfg.build() + # Setup: seed one ABORTED group into the buffer + aborted_item = MockRolloutState(999, status=Status.ABORTED) + await self.replay_buffer.put([aborted_item], task_name) + # Run + status = await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=2, + task_name=task_name, + model_step=0, + progress=self._build_progress(task_name, target=2), + ) + self.assertEqual(status, ProduceBatchStatus.NORMAL) + + # Verify: the ReplayBuffer should now hold 4 COMPLETED groups. + final_data = await self.replay_buffer.get(10, task_name, Status.COMPLETED) + self.assertEqual(len(final_data), 4) + self.assertEqual(sorted(group[0].id for group in final_data), [0, 1, 2, 999]) + + async def test_async_produce_strategy_uses_live_consumed_progress(self): + task_name = "test_live_consumed" + call_count = 0 + + async def mock_gen(rs, **kwargs): + nonlocal call_count + call_count += 1 + for r in rs: + r.status = Status.COMPLETED + return rs + + mock_agent_loop = self._build_agent_loop() + mock_agent_loop.generate_group = mock_gen + sampler = self._build_sampler() + # This case verifies live consumed-sample accounting; relax the stale policy so produce_batch does not return early at its entry check. + strategy = AsyncProduceStrategyConfig(over_sample_threshold=0.0, max_staleness=3).build() + progress = ProduceProgress( + next_consumer_step=1, + producer_future_step=2, + consumed_samples={task_name: 1}, + target_samples={task_name: 2}, + target_upto_future_step=2, + ) + + status = await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=1, + task_name=task_name, + train_step=2, + model_step=1, + target_cumulative=2, + progress=progress, + ) + + self.assertEqual(status, ProduceBatchStatus.NORMAL) + self.assertEqual(call_count, 1) + self.assertEqual(await self.replay_buffer.count(task_name, Status.COMPLETED), 1) + + async def test_async_produce_strategy_uses_fixed_batch_oversample_budget(self): + task_name = "test_fixed_oversample" + sampler = MagicMock() + sample_ids = iter(range(100, 200)) + + async def sample(task_name, group_status=None): + self.assertEqual(group_status, [Status.ABORTED]) + return [MockRolloutState(next(sample_ids), status=Status.ABORTED)] + + sampler.sample = AsyncMock(side_effect=sample) + mock_agent_loop = self._build_agent_loop() + strategy = AsyncProduceStrategyConfig(over_sample_threshold=1.0).build() + progress = self._build_progress(task_name, target=10, consumed=9) + + status = await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=4, + task_name=task_name, + model_step=0, + progress=progress, + ) + + self.assertEqual(status, ProduceBatchStatus.NORMAL) + # Only 1 sample is still missing, but the over-sample budget is fixed at over_sample_threshold * batch_size = 4, + # so this round may schedule up to target + 4, which means 5 tasks are launched initially. + self.assertEqual(sampler.sample.await_count, 5) + self.assertEqual(await self.replay_buffer.count(task_name, Status.COMPLETED), 5) + + async def test_async_produce_strategy_tail_batch_is_static_and_no_oversample(self): + task_name = "test_tail_static" +
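# Seed two EXPIRED groups so the tail-batch path has leftovers to drain first. +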
for sample_id in (900, 901): + await self.replay_buffer.put([MockRolloutState(sample_id, status=Status.EXPIRED)], task_name) + + sampler = self._build_sampler() + original_sample = sampler.sample + sampled_statuses: list[list[Status] | None] = [] + + async def instrumented_sample(task_name, group_status=None): + sampled_statuses.append(group_status) + return await original_sample(task_name=task_name, group_status=group_status) + + sampler.sample = instrumented_sample + mock_agent_loop = self._build_agent_loop() + strategy = AsyncProduceStrategyConfig( + over_sample_threshold=1.0, + tail_batch_trigger_size=1, + ).build() + + status = await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=2, + task_name=task_name, + model_step=0, + progress=self._build_progress(task_name, target=2), + ) + + self.assertEqual(status, ProduceBatchStatus.NORMAL) + # In tail-batch mode this round drains the EXPIRED pool first and does not launch extra over-sample tasks. + self.assertEqual(sampled_statuses, [[Status.EXPIRED, Status.ABORTED], [Status.EXPIRED, Status.ABORTED]]) + completed = await self.replay_buffer.get(10, task_name, Status.COMPLETED) + self.assertEqual(sorted(group[0].id for group in completed), [900, 901]) + + async def test_async_produce_strategy_fails_fast_on_invalid_progress(self): + task_name = "test_invalid_progress" + strategy = AsyncProduceStrategyConfig(over_sample_threshold=0.0).build() + mock_agent_loop = self._build_agent_loop() + sampler = MagicMock() + sampler.sample = AsyncMock(side_effect=AssertionError("sampler.sample should not be called")) + + missing_consumed = ProduceProgress( + next_consumer_step=1, + producer_future_step=1, + consumed_samples={}, + target_samples={task_name: 1}, + target_upto_future_step=1, + ) + with self.assertRaisesRegex(KeyError, "consumed_samples"): + await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=1, + task_name=task_name, + train_step=1, + model_step=0, + target_cumulative=1, + progress=missing_consumed, + ) + + mismatched_target = ProduceProgress( + next_consumer_step=1, + producer_future_step=1, + consumed_samples={task_name: 0}, + target_samples={task_name: 2}, + target_upto_future_step=1, + ) + with self.assertRaisesRegex(ValueError, "target_cumulative"): + await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=1, + task_name=task_name, + train_step=1, + model_step=0, + target_cumulative=1, + progress=mismatched_target, + ) + + async def test_async_produce_strategy_records_sample_version_before_staleness_refresh(self): + task_name = "test_sample_version" + + async def mock_gen(rs, **kwargs): + self.assertNotIn("model_step", kwargs) + for r in rs: + r.response_ids = [10, 11] + r.status = Status.COMPLETED + return rs + + mock_agent_loop = self._build_agent_loop() + mock_agent_loop.generate_group = mock_gen + sampler = self._build_sampler() + # This case verifies the version recording order; relax the stale policy so produce_batch does not return early at its entry check. + strategy = AsyncProduceStrategyConfig(over_sample_threshold=0.0, max_staleness=3).build() + + status = await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=1, + task_name=task_name, + train_step=5, + model_step=3, + progress=self._build_progress(task_name, target=1, train_step=5), + ) + + self.assertEqual(status, ProduceBatchStatus.NORMAL) + completed = await self.replay_buffer.get(1, task_name, Status.COMPLETED) + self.assertEqual(completed[0][0].response_model_steps, [3, 3]) + self.assertEqual(completed[0][0].seq_staleness, 1) + + async def
test_async_produce_strategy_preserves_partial_rollout_old_versions(self): + task_name = "test_partial_rollout_versions" + partial_item = MockRolloutState(700, status=Status.ABORTED) + partial_item.response_ids = [10] + partial_item.response_model_steps = [1] + await self.replay_buffer.put([partial_item], task_name) + + async def mock_gen(rs, **kwargs): + self.assertNotIn("model_step", kwargs) + # The partial rollout's historical tokens already carry a version; new tokens should be stamped with the model version at this round's scheduling time. + rs[0].response_ids = [10, 11, 12] + rs[0].status = Status.COMPLETED + return rs + + mock_agent_loop = self._build_agent_loop() + mock_agent_loop.generate_group = mock_gen + sampler = self._build_sampler() + # This case verifies partial-rollout version concatenation; relax the stale policy so the old segment is kept. + strategy = AsyncProduceStrategyConfig( + over_sample_threshold=0.0, + enable_partial_rollout=True, + max_staleness=3, + ).build() + + status = await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=1, + task_name=task_name, + train_step=5, + model_step=3, + progress=self._build_progress(task_name, target=1, train_step=5), + ) + + self.assertEqual(status, ProduceBatchStatus.NORMAL) + completed = await self.replay_buffer.get(1, task_name, Status.COMPLETED) + self.assertEqual(completed[0][0].response_model_steps, [1, 3, 3]) + self.assertEqual(completed[0][0].seq_staleness, 3) + + async def test_async_produce_strategy_reclaims_cross_call_pending_and_records_timing(self): + task_name = "test_task" + mock_agent_loop = self._build_agent_loop({0: 0.01, 1: 0.05, 2: 0.05}) + produce_strategy_cfg = AsyncProduceStrategyConfig(over_sample_threshold=2.0, enable_partial_rollout=True) + sampler = self._build_sampler() + strategy = produce_strategy_cfg.build() + progress = self._build_progress(task_name, target=1) + + status = await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=1, + task_name=task_name, + model_step=0, + progress=progress, + ) + self.assertEqual(status, ProduceBatchStatus.NORMAL) + self.assertGreater(len(strategy._pending_tasks), 0) + + await asyncio.sleep(0.08) + + status = await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=1, + task_name=task_name, + model_step=0, + progress=progress, + ) + self.assertEqual(status, ProduceBatchStatus.NORMAL) + self.assertEqual(len(strategy._pending_tasks), 0) + + final_data = await self.replay_buffer.get(10, task_name, Status.COMPLETED) + self.assertEqual(len(final_data), 3) + self.assertEqual(sorted(group[0].id for group in final_data), [0, 1, 2]) + for group in final_data: + self.assertIn("group_generate_time_s", group[0].extra_fields) + self.assertGreater(group[0].extra_fields["group_generate_time_s"], 0.0) + + async def test_async_produce_strategy_pause_produce_is_explicit(self): + task_name = "test_cleanup" + mock_agent_loop = self._build_agent_loop({0: 0.01, 1: 0.2, 2: 0.2}) + produce_strategy_cfg = AsyncProduceStrategyConfig(over_sample_threshold=2.0, enable_partial_rollout=True) + sampler = self._build_sampler() + strategy = produce_strategy_cfg.build() + progress = self._build_progress(task_name, target=1) + + await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=1, + task_name=task_name, + model_step=0, + progress=progress, + ) + self.assertGreater(len(strategy._pending_tasks), 0) + + pause_time_s = await strategy.pause_produce( + mock_agent_loop, + self.replay_buffer, + task_name, + progress=progress, + ) + + self.assertGreaterEqual(pause_time_s, 0.0) +
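# pause_produce must reclaim every in-flight task rather than drop it: all + # three scheduled groups land back in one of the buffer pools. +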
self.assertEqual(len(strategy._pending_tasks), 0) + completed = await self.replay_buffer.count(task_name, Status.COMPLETED) + aborted = await self.replay_buffer.count(task_name, Status.ABORTED) + expired = await self.replay_buffer.count(task_name, Status.EXPIRED) + self.assertEqual(completed + aborted + expired, 3) + + async def test_async_produce_strategy_returns_update_abort_without_sampling(self): + task_name = "test_update_abort" + strategy = AsyncProduceStrategyConfig(over_sample_threshold=1.0).build() + mock_agent_loop = self._build_agent_loop() + sampler = MagicMock() + sampler.sample = AsyncMock(side_effect=AssertionError("sampler.sample should not be called")) + update_event = asyncio.Event() + update_event.set() + + status = await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=1, + task_name=task_name, + train_step=1, + model_step=1, + update_event=update_event, + progress=self._build_progress(task_name, target=1, train_step=1), + ) + + self.assertEqual(status, ProduceBatchStatus.UPDATE_ABORT) + self.assertEqual(await self.replay_buffer.count(task_name, Status.COMPLETED), 0) + + async def test_async_produce_strategy_returns_update_abort_after_schedule_pause(self): + task_name = "test_update_abort_after_schedule" + strategy = AsyncProduceStrategyConfig(over_sample_threshold=0.0).build() + mock_agent_loop = self._build_agent_loop({0: 0.05}) + sampler = MagicMock() + update_event = asyncio.Event() + progress = self._build_progress(task_name, target=1) + + async def sample(task_name, group_status=None): + # Simulate the manager triggering a pause mid-way through the scheduling critical section; the current sample enters pending and no further scheduling should happen. + update_event.set() + return [MockRolloutState(0, status=Status.ABORTED)] + + sampler.sample = AsyncMock(side_effect=sample) + + status = await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=1, + task_name=task_name, + update_event=update_event, + model_step=0, + progress=progress, + ) + + self.assertEqual(status, ProduceBatchStatus.UPDATE_ABORT) + self.assertEqual(sampler.sample.await_count, 1) + + await strategy.pause_produce( + mock_agent_loop, + self.replay_buffer, + task_name, + progress=progress, + ) + self.assertEqual(len(strategy._pending_tasks), 0) + + async def test_async_produce_strategy_returns_expired_batch_before_processing_leftovers(self): + task_name = "test_expired_batch" + strategy = AsyncProduceStrategyConfig(max_staleness=0).build() + mock_agent_loop = self._build_agent_loop() + sampler = MagicMock() + sampler.sample = AsyncMock(side_effect=AssertionError("sampler.sample should not be called")) + await self.replay_buffer.put([MockRolloutState(999, status=Status.COMPLETED)], task_name) + + status = await strategy.produce_batch( + mock_agent_loop, + sampler, + self.replay_buffer, + batch_size=1, + task_name=task_name, + train_step=3, + model_step=1, + progress=self._build_progress(task_name, target=1, train_step=3), + ) + + self.assertEqual(status, ProduceBatchStatus.EXPIRED_BATCH) + self.assertEqual(await self.replay_buffer.count(task_name, Status.COMPLETED), 1) + self.assertEqual(await self.replay_buffer.count(task_name, Status.ABORTED), 0) + + async def test_refresh_staleness_refreshes_before_expire_check(self): + task_name = "test_refresh_leftover" + stale_item = MockRolloutState(1000, seq_staleness=0, status=Status.COMPLETED) + stale_item.response_model_steps = [3] + await self.replay_buffer.put([stale_item], task_name) + + expired_count = await self.replay_buffer.refresh_staleness( + task_name=task_name, +
current_train_step=6, + stale_threshold=2, + ) + expired_groups = await self.replay_buffer.get(10, task_name, Status.EXPIRED) + + self.assertEqual(expired_count, 1) + self.assertEqual(len(expired_groups), 1) + self.assertEqual(expired_groups[0][0].seq_staleness, 2) diff --git a/tests/rl/test_replay_buffer.py b/tests/rl/test_replay_buffer.py new file mode 100644 index 0000000000..85c13d949b --- /dev/null +++ b/tests/rl/test_replay_buffer.py @@ -0,0 +1,221 @@ +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory + +from xtuner.v1.data_proto.rl_data import Status +from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig, SyncReplayBufferConfig + + +class MockState: + def __init__( + self, + state_id, + staleness=0, + input_ids=None, + status=Status.COMPLETED, + response_model_steps=None, + ): + self.id = state_id + self.seq_staleness = staleness + self.status = status + self.input_ids = input_ids if input_ids is not None else [state_id] + self.response_model_steps = response_model_steps + + +class TestReplayBuffer(unittest.IsolatedAsyncioTestCase): + @staticmethod + def _get_sorted_input_ids(data_groups): + return sorted(tuple(state.input_ids) for group in data_groups for state in group) + + async def _run_roundtrip_input_ids_case(self, replay_buffer_cfg, put_groups, task_name, sample_size): + with TemporaryDirectory() as tmp_dir: + save_path = Path(tmp_dir) + original = replay_buffer_cfg.build() + for group in put_groups: + await original.put(group, task_name) + await original.save(save_path) + + old_sampled = await original.get(sample_size, task_name, Status.COMPLETED) + + resumed = replay_buffer_cfg.build() + await resumed.resume(save_path) + new_sampled = await resumed.get(sample_size, task_name, Status.COMPLETED) + + self.assertEqual(self._get_sorted_input_ids(old_sampled), self._get_sorted_input_ids(new_sampled)) + + async def test_basic_ordering_and_task_isolation(self): + fifo_cfg = SyncReplayBufferConfig() + fifo = fifo_cfg.build() + await fifo.put([MockState(1), MockState(2)], "task1") + await fifo.put([MockState(3)], "task1") + await fifo.put([MockState(200)], "task2") + + res_task1 = await fifo.get(2, "task1", Status.COMPLETED) + res_task2 = await fifo.get(1, "task2", Status.COMPLETED) + self.assertEqual([s.id for s in res_task1[0]], [1, 2]) + self.assertEqual([s.id for s in res_task1[1]], [3]) + self.assertEqual([s.id for s in res_task2[0]], [200]) + + staleness_cfg = AsyncReplayBufferConfig() + staleness = staleness_cfg.build() + await staleness.put([MockState("low", staleness=1)], "task") + await staleness.put([MockState("high", staleness=5)], "task") + sampled = await staleness.get(2, "task", Status.COMPLETED) + self.assertEqual(sampled[0][0].id, "high") + self.assertEqual(sampled[1][0].id, "low") + + async def test_save_resume_keeps_query_behavior_fifo(self): + replay_buffer_cfg = SyncReplayBufferConfig() + with TemporaryDirectory() as tmp_dir: + save_path = Path(tmp_dir) + buffer = replay_buffer_cfg.build() + await buffer.put([MockState("a1", status=Status.COMPLETED, input_ids=[11, 12])], "task_a") + await buffer.put([MockState("a2", status=Status.FAILED, input_ids=[21])], "task_a") + await buffer.put([MockState("b1", status=Status.COMPLETED, input_ids=[31])], "task_b") + await buffer.save(save_path) + + resumed = replay_buffer_cfg.build() + await resumed.resume(save_path) + + self.assertEqual(await resumed.count("task_a", Status.COMPLETED), 1) + self.assertEqual(await resumed.count("task_a", Status.FAILED), 1) + self.assertEqual(await 
resumed.count("task_b", Status.COMPLETED), 1) + self.assertEqual(await resumed.count("task_b", Status.FAILED), 0) + + completed = await resumed.get(5, "task_a", Status.COMPLETED) + failed = await resumed.get(5, "task_a", Status.FAILED) + self.assertEqual([s.id for s in completed[0]], ["a1"]) + self.assertEqual([s.id for s in failed[0]], ["a2"]) + + await resumed.put([MockState("a3", input_ids=[41])], "task_a") + next_completed = await resumed.get(1, "task_a", Status.COMPLETED) + self.assertEqual([s.id for s in next_completed[0]], ["a3"]) + + async def test_save_resume_keeps_query_behavior_staleness(self): + replay_buffer_cfg = AsyncReplayBufferConfig() + with TemporaryDirectory() as tmp_dir: + save_path = Path(tmp_dir) + buffer = replay_buffer_cfg.build() + await buffer.put([MockState("done_low", staleness=1, status=Status.COMPLETED, input_ids=[101])], "task") + await buffer.put([MockState("failed_high", staleness=10, status=Status.FAILED, input_ids=[201])], "task") + await buffer.put([MockState("done_mid", staleness=5, status=Status.COMPLETED, input_ids=[301, 302])], "task") + await buffer.save(save_path) + + resumed = replay_buffer_cfg.build() + await resumed.resume(save_path) + + self.assertEqual(await resumed.count("task", Status.COMPLETED), 2) + self.assertEqual(await resumed.count("task", Status.FAILED), 1) + + completed = await resumed.get(2, "task", Status.COMPLETED) + failed = await resumed.get(1, "task", Status.FAILED) + self.assertEqual(completed[0][0].id, "done_mid") + self.assertEqual(completed[1][0].id, "done_low") + self.assertEqual(failed[0][0].id, "failed_high") + + async def test_save_resume_sample_keeps_input_ids_fifo(self): + await self._run_roundtrip_input_ids_case( + replay_buffer_cfg=SyncReplayBufferConfig(), + put_groups=[ + [MockState(1, input_ids=[101, 102]), MockState(2, input_ids=[201])], + [MockState(3, input_ids=[301, 302, 303])], + ], + task_name="task", + sample_size=2, + ) + + async def test_save_resume_sample_keeps_input_ids_staleness(self): + await self._run_roundtrip_input_ids_case( + replay_buffer_cfg=AsyncReplayBufferConfig(), + put_groups=[ + [MockState("mid", staleness=3, input_ids=[301, 302])], + [MockState("high", staleness=5, input_ids=[501])], + [MockState("low", staleness=1, input_ids=[101, 102, 103])], + ], + task_name="task", + sample_size=3, + ) + + async def test_refresh_staleness_expires_completed_in_place(self): + replay_buffer = AsyncReplayBufferConfig().build() + await replay_buffer.put( + [MockState("stale", response_model_steps=[3], status=Status.COMPLETED)], + "task", + ) + await replay_buffer.put( + [MockState("fresh", response_model_steps=[6], status=Status.COMPLETED)], + "task", + ) + + expired_count = await replay_buffer.refresh_staleness( + task_name="task", + current_train_step=6, + stale_threshold=2, + ) + + self.assertEqual(expired_count, 1) + self.assertEqual(await replay_buffer.count("task", Status.COMPLETED), 1) + self.assertEqual(await replay_buffer.count("task", Status.EXPIRED), 1) + expired = await replay_buffer.get(1, "task", Status.EXPIRED) + completed = await replay_buffer.get(1, "task", Status.COMPLETED) + self.assertEqual(expired[0][0].id, "stale") + self.assertEqual(expired[0][0].seq_staleness, 2) + self.assertEqual(completed[0][0].id, "fresh") + + async def test_refresh_staleness_expires_aborted_in_place(self): + replay_buffer = AsyncReplayBufferConfig().build() + await replay_buffer.put( + [MockState("stale-aborted", response_model_steps=[3], status=Status.ABORTED)], + "task", + ) + await replay_buffer.put( + 
[MockState("fresh-aborted", response_model_steps=[5], status=Status.ABORTED)], + "task", + ) + + expired_count = await replay_buffer.refresh_staleness( + task_name="task", + current_train_step=6, + stale_threshold=2, + ) + + self.assertEqual(expired_count, 1) + self.assertEqual(await replay_buffer.count("task", Status.ABORTED), 1) + self.assertEqual(await replay_buffer.count("task", Status.EXPIRED), 1) + expired = await replay_buffer.get(1, "task", Status.EXPIRED) + aborted = await replay_buffer.get(1, "task", Status.ABORTED) + self.assertEqual(expired[0][0].id, "stale-aborted") + self.assertEqual(expired[0][0].seq_staleness, 2) + self.assertEqual(expired[0][0].status, Status.EXPIRED) + self.assertEqual(aborted[0][0].id, "fresh-aborted") + self.assertEqual(aborted[0][0].seq_staleness, 0) + self.assertEqual(aborted[0][0].status, Status.ABORTED) + + async def test_refresh_staleness_respects_status_filter(self): + replay_buffer = AsyncReplayBufferConfig().build() + await replay_buffer.put( + [MockState("stale-completed", response_model_steps=[3], status=Status.COMPLETED)], + "task", + ) + await replay_buffer.put( + [MockState("stale-aborted", response_model_steps=[3], status=Status.ABORTED)], + "task", + ) + + expired_count = await replay_buffer.refresh_staleness( + task_name="task", + current_train_step=6, + stale_threshold=2, + statuses=[Status.ABORTED], + ) + + self.assertEqual(expired_count, 1) + self.assertEqual(await replay_buffer.count("task", Status.COMPLETED), 1) + self.assertEqual(await replay_buffer.count("task", Status.ABORTED), 0) + self.assertEqual(await replay_buffer.count("task", Status.EXPIRED), 1) + completed = await replay_buffer.get(1, "task", Status.COMPLETED) + expired = await replay_buffer.get(1, "task", Status.EXPIRED) + self.assertEqual(completed[0][0].id, "stale-completed") + self.assertEqual(completed[0][0].status, Status.COMPLETED) + self.assertEqual(expired[0][0].id, "stale-aborted") + self.assertEqual(expired[0][0].status, Status.EXPIRED) diff --git a/tests/rl/test_rl_colocate_trainer.py b/tests/rl/test_rl_colocate_trainer.py new file mode 100644 index 0000000000..da5bdd45db --- /dev/null +++ b/tests/rl/test_rl_colocate_trainer.py @@ -0,0 +1,295 @@ +import asyncio +import tempfile +import unittest +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +from xtuner.v1.data_proto.rl_data import Status +from xtuner.v1.rl.agent_loop_manager import AsyncProduceStrategyConfig, ProduceBatchResult +from xtuner.v1.rl.agent_loop_manager.agent_loop_manager import AgentLoopManager, _TaskRunner +from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig +from xtuner.v1.train.rl_trainer import RLColocateTrainer + + +class _FakeRolloutState: + def __init__(self, uid: int): + self.id = uid + self.uid = str(uid) + self.status = Status.INIT + self.seq_staleness = 0 + self.response_ids = [] + self.response = None + self.reward = None + self.extra_fields = {} + self.response_model_steps = [] + + +class _FakeSampler: + def __init__(self): + self._next_id = 0 + + def __len__(self): + return 8 + + def save(self, checkpoint_path): + return None + + def resume(self, checkpoint_path): + return None + + async def sample(self, task_name, group_status=None, **kwargs): + item = _FakeRolloutState(self._next_id) + self._next_id += 1 + return [item] + + +def _build_fake_agent_loop(): + rollout_ctl = MagicMock() + rollout_ctl.continue_generation.remote = AsyncMock(return_value=None) + rollout_ctl.pause_generation.remote = 
AsyncMock(return_value=None) + rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + agent_loop = MagicMock() + agent_loop.rollout_ctl = rollout_ctl + + async def generate_group(rollout_states, **kwargs): + model_step = kwargs.get("model_step", kwargs.get("train_step", 0)) + for state in rollout_states: + state.status = Status.COMPLETED + state.response_ids = [1, 2, 3] + state.response = "ok" + state.reward = {"score": 1.0} + state.response_model_steps = [model_step] + return rollout_states + + agent_loop.generate_group = generate_group + return agent_loop + + +class TestRLColocateTrainer(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def _make_trainer(self, agent_loop_manager, *, total_train_steps: int = 1, sync_weights_interval: int = 1): + trainer = RLColocateTrainer.__new__(RLColocateTrainer) + trainer.logger = MagicMock() + trainer._total_train_steps = total_train_steps + trainer._cur_step = 0 + trainer._global_train_step = 0 + trainer.train_batch_size = 1 + trainer._sync_weights_interval = sync_weights_interval + trainer._debug_rollout = False + trainer._enable_evaluate = False + trainer._enable_initial_evaluate = False + trainer._evaluate_step = 1 + trainer._train_worker_cfg = SimpleNamespace(pack_max_length=16) + trainer._meta = SimpleNamespace( + latest_exp=SimpleNamespace(exp_dir=str(Path(self.temp_dir.name) / "exp")), + ) + Path(trainer.exp_dir).mkdir(parents=True, exist_ok=True) + trainer.agent_loop_manager = agent_loop_manager + trainer.eval_agent_loop_manager = MagicMock() + trainer.evaluator = MagicMock(eval_batch_size=1) + trainer.tokenizer = MagicMock() + trainer._exp_tracker = MagicMock() + trainer._display_all_workers_log = False + trainer._save_trajectories = MagicMock() + trainer._sync_weights_and_save = MagicMock( + side_effect=lambda train_step, step_timer_dict: train_step % trainer._sync_weights_interval == 0 + ) + trainer._log_step = MagicMock() + trainer._prepare_train_data = MagicMock( + return_value=([{"seq_ctx": "fake"}], {"batch_size": 1, "rewards/mean": 1.0}) + ) + + trainer.rollout_controller = SimpleNamespace( + offload=SimpleNamespace(remote=MagicMock(return_value="rollout_offloaded")), + ) + trainer.train_controller = SimpleNamespace( + onload=MagicMock(return_value="train_onloaded"), + fit=MagicMock( + return_value=[ + { + "rollout_is_metrics": {}, + "mismatch_metrics": {}, + "rollout_entropy": 0.0, + "train_entropy": 0.0, + "train_metrics": [], + "sft_train_metrics": {}, + } + ] + ), + ) + return trainer + + def test_fit_accepts_async_strategy_manager_on_colocate_path(self): + replay_buffer = AsyncReplayBufferConfig().build() + manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="train_task", + agent_loop=_build_fake_agent_loop(), + produce_strategy=AsyncProduceStrategyConfig(over_sample_threshold=0.0).build(), + sampler=_FakeSampler(), + weight=1.0, + order=0, + ) + ], + replay_buffer=replay_buffer, + ) + trainer = self._make_trainer(manager) + + with ( + patch("xtuner.v1.train.rl_trainer.asyncio_run", side_effect=asyncio.run), + patch("xtuner.v1.train.rl_trainer.ray.get", side_effect=lambda obj: obj), + ): + trainer.fit() + + trainer.rollout_controller.offload.remote.assert_called_once_with() + trainer.train_controller.onload.assert_called_once_with(target="all") + trainer.train_controller.fit.assert_called_once() + trainer._prepare_train_data.assert_called_once() + 
trainer._save_trajectories.assert_called_once() + trainer._sync_weights_and_save.assert_called_once() + trainer._log_step.assert_called_once() + self.assertEqual(trainer._cur_step, 1) + + def test_fit_requires_non_empty_batch_from_manager(self): + async def _produce_empty(batch_size, train_step, **kwargs): + return ProduceBatchResult(rollout_states=[]) + + empty_manager = SimpleNamespace(produce_batch=_produce_empty) + trainer = self._make_trainer(empty_manager) + + with ( + patch("xtuner.v1.train.rl_trainer.asyncio_run", side_effect=asyncio.run), + patch("xtuner.v1.train.rl_trainer.ray.get", side_effect=lambda obj: obj), + ): + with self.assertRaisesRegex(AssertionError, "return non-empty rollout_states"): + trainer.fit() + + trainer.rollout_controller.offload.remote.assert_not_called() + trainer.train_controller.onload.assert_not_called() + trainer.train_controller.fit.assert_not_called() + trainer._prepare_train_data.assert_not_called() + trainer._save_trajectories.assert_not_called() + trainer._sync_weights_and_save.assert_not_called() + trainer._log_step.assert_not_called() + self.assertEqual(trainer._cur_step, 0) + + def test_fit_uses_sync_interval_and_passes_rollout_model_step(self): + produce_calls = [] + + async def _produce_batch(batch_size, train_step, *, model_step): + produce_calls.append((batch_size, train_step, model_step)) + return ProduceBatchResult(rollout_states=[[f"sample-{train_step}"]]) + + trainer = self._make_trainer( + SimpleNamespace(produce_batch=_produce_batch), + total_train_steps=3, + sync_weights_interval=2, + ) + with ( + patch("xtuner.v1.train.rl_trainer.asyncio_run", side_effect=asyncio.run), + patch("xtuner.v1.train.rl_trainer.ray.get", side_effect=lambda obj: obj), + ): + trainer.fit() + + self.assertEqual(produce_calls, [(1, 1, 0), (1, 2, 0), (1, 3, 2)]) + self.assertEqual( + [call.args[0] for call in trainer._sync_weights_and_save.call_args_list], + [1, 2, 3], + ) + self.assertEqual(trainer._cur_step, 3) + + def test_sync_weights_and_save_can_skip_weight_update_and_restore_rollout(self): + trainer = RLColocateTrainer.__new__(RLColocateTrainer) + events = [] + trainer._sync_weights_interval = 2 + trainer._maybe_save_checkpoint = MagicMock(side_effect=lambda step: events.append(f"save:{step}")) + trainer._maybe_save_hf = MagicMock(side_effect=lambda step: events.append(f"hf:{step}")) + trainer.train_controller = SimpleNamespace( + update_weights=MagicMock(side_effect=lambda: events.append("update_weights")), + offload=MagicMock(side_effect=lambda target="all": events.append(("train_offload", target))), + ) + trainer.rollout_controller = SimpleNamespace( + recover_failed_workers=SimpleNamespace( + remote=MagicMock(side_effect=lambda: events.append("recover_rollout")) + ), + onload_weights=SimpleNamespace(remote=MagicMock(side_effect=lambda: events.append("onload_weights"))), + onload_kvcache=SimpleNamespace(remote=MagicMock(side_effect=lambda: events.append("onload_kvcache"))), + ) + + with ( + patch("xtuner.v1.train.rl_trainer.ray.get", side_effect=lambda obj: obj), + patch( + "xtuner.v1.train.rl_trainer.bind_train_rollout", + side_effect=lambda train_controller, rollout_controller: events.append("bind"), + ), + ): + synced = trainer._sync_weights_and_save(train_step=1, step_timer_dict={}) + + self.assertFalse(synced) + self.assertEqual( + events, + [ + ("train_offload", "optimizer"), + "save:1", + "hf:1", + "recover_rollout", + ("train_offload", "model"), + "onload_weights", + "onload_kvcache", + ], + ) + + def 
test_sync_weights_and_save_updates_weights_on_interval_step(self): + trainer = RLColocateTrainer.__new__(RLColocateTrainer) + events = [] + trainer.logger = MagicMock() + trainer._sync_weights_interval = 2 + trainer._maybe_save_checkpoint = MagicMock(side_effect=lambda step: events.append(f"save:{step}")) + trainer._maybe_save_hf = MagicMock(side_effect=lambda step: events.append(f"hf:{step}")) + trainer.train_controller = SimpleNamespace( + update_weights=MagicMock(side_effect=lambda: events.append("update_weights")), + offload=MagicMock(side_effect=lambda target="all": events.append(("train_offload", target))), + ) + trainer.rollout_controller = SimpleNamespace( + recover_failed_workers=SimpleNamespace( + remote=MagicMock(side_effect=lambda: events.append("recover_rollout")) + ), + onload_weights=SimpleNamespace(remote=MagicMock(side_effect=lambda: events.append("onload_weights"))), + onload_kvcache=SimpleNamespace(remote=MagicMock(side_effect=lambda: events.append("onload_kvcache"))), + ) + + with ( + patch("xtuner.v1.train.rl_trainer.ray.get", side_effect=lambda obj: obj), + patch( + "xtuner.v1.train.rl_trainer.bind_train_rollout", + side_effect=lambda train_controller, rollout_controller: events.append("bind"), + ), + ): + synced = trainer._sync_weights_and_save(train_step=2, step_timer_dict={}) + + self.assertTrue(synced) + self.assertEqual( + events, + [ + ("train_offload", "optimizer"), + "save:2", + "hf:2", + "recover_rollout", + "bind", + "onload_weights", + "update_weights", + ("train_offload", "model"), + "onload_kvcache", + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rl/test_rl_colocate_trainer_integration.py b/tests/rl/test_rl_colocate_trainer_integration.py new file mode 100644 index 0000000000..7332c371de --- /dev/null +++ b/tests/rl/test_rl_colocate_trainer_integration.py @@ -0,0 +1,325 @@ +import os +import unittest +import shutil +import tempfile +import ray +from pathlib import Path + +from xtuner.v1.rl.utils import AcceleratorResourcesConfig +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.train.trainer import LoadCheckpointConfig +from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.judger import GSM8KJudgerConfig +from xtuner.v1.loss import CELossConfig +from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig +from xtuner.v1.rl.agent_loop_manager import ( + AgentLoopManagerConfig, + SamplerConfig, + SyncProduceStrategyConfig, + TaskSpecConfig, +) +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.data_proto.rl_data import SampleParams +from xtuner.v1.data_proto.sequence_context import SequenceContext +from transformers import AutoTokenizer +import torch + +QWEN3_PATH = os.environ["QWEN3_PATH"] +ALPACA_PATH = os.environ["ALPACA_PATH"] +ROLLOUT_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] + + +class TestRLColocateTrainerIntegration(unittest.TestCase): + """Integration test for RLColocateTrainer with checkpoint save/resume.""" + + def setUp(self): + ray.init(num_cpus=80, num_gpus=8, 
ignore_reinit_error=True) + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.temp_dir, ignore_errors=True) + ray.shutdown() + + def build_trainer_config(self, work_dir, checkpoint_interval=1, checkpoint_maxkeep=2, auto_resume=False): + """Build RLColocateTrainerConfig for testing.""" + model_path = QWEN3_PATH + data_path = ALPACA_PATH + + # Resources + resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=8, + num_cpus_per_worker=4, + cpu_memory_per_worker=8 * 1024**3, + ) + + # Rollout config + rollout_config = RolloutConfig( + env="test_rl", + device="GPU", + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=1, + expert_parallel_size=1, + gpu_memory_utilization=0.5, + context_length=1536, + ) + + # Judger + judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", num_ray_actors=1) + + # Train worker + lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) + fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) + model_cfg = get_model_config_from_hf(Path(model_path)) + if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None + if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None + + optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1) + loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type="vanilla", + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode="chunk", + chunk_size=512, + ) + + # SFT configs for WorkerConfig + sft_dataset_config = [{ + "dataset": DatasetConfig(name='alpaca', anno_path=data_path), + "tokenize_fn": OpenaiTokenizeFunctionConfig( + chat_template='qwen3', + max_length=32768 + ) + }] + sft_dataloader_cfg = DataloaderConfig( + dataset_config_list=sft_dataset_config, + pack_max_length=32768, + pack_to_max_length=True, + num_workers=0, + ) + sft_global_batch_size = 8 + sft_loss_cfg = CELossConfig(mode="chunk", chunk_size=1024, loss_reduction="square") + + train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=1, + optimizer_steps=1, + pack_max_length=2048, + sft_dataloader_cfg=sft_dataloader_cfg, + sft_global_batch_size=sft_global_batch_size, + sft_loss_cfg=sft_loss_cfg, + ) + + # Agent loop manager + train_dataset = DatasetConfig(name="test_rl", anno_path=ROLLOUT_DATA_PATH) + tokenizer_config = RLTextTokenizeFnConfig(max_length=512) + train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] + dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=2048, + collator="fake_collator", + pack_level="none", + ) + sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=2, + ) + training_sample_params = SampleParams( + max_tokens=512, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, + ) + agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, + ) + produce_strategy_config = SyncProduceStrategyConfig() + agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=[ + TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ) + ], + ) + + # Eval 
agent loop manager (minimal) + eval_sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=1, + ) + eval_agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=SampleParams(max_tokens=512, top_k=1, temperature=0.0), + ) + eval_agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=[ + TaskSpecConfig( + task_name="eval_task", + agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, + sampler_config=eval_sampler_config, + ) + ], + ) + + # Evaluator + evaluator_config = EvaluatorConfig(compute_metric_func=None) + + return RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, + rollout_config=rollout_config, + tokenizer_path=model_path, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, + evaluator_config=evaluator_config, + load_from=model_path, + total_train_steps=2, + train_batch_size=4, + enable_evaluate=False, + enable_initial_evaluate=False, + work_dir=work_dir, + checkpoint_interval=checkpoint_interval, + checkpoint_maxkeep=checkpoint_maxkeep, + auto_resume=auto_resume, + seed=42, + debug_rollout=False, + ) + + def test_rl_train_with_sft(self): + """Test train_controller save/resume with efficient_attn_ratio verification.""" + work_dir = Path(self.temp_dir) / "work_dir_sft" + work_dir.mkdir(parents=True, exist_ok=True) + + # Build trainer to get train_controller + trainer_cfg = self.build_trainer_config( + work_dir=str(work_dir), + checkpoint_interval=1, + checkpoint_maxkeep=2, + auto_resume=False, + ) + trainer = trainer_cfg.build() + train_controller = trainer.train_controller + + # Prepare synthetic data batches + tokenizer = AutoTokenizer.from_pretrained(QWEN3_PATH, trust_remote_code=True) + + # Create simple prompts and responses + prompts = ["What is 2+2?", "What is the capital of France?"] + responses = [ + ["4", "Four", "2+2=4", "The answer is 4"], + ["Paris", "The capital is Paris", "Paris, France", "It's Paris"] + ] + + data_batches = [] + for prompt, response_list in zip(prompts, responses): + prompt_ids = tokenizer(prompt, return_tensors='pt')['input_ids'].flatten().tolist() + rewards = torch.tensor([1.0, 0.8, 0.9, 0.7], dtype=torch.float32) + advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8) + + for i, response in enumerate(response_list): + response_ids = tokenizer(response, return_tensors='pt')['input_ids'].flatten().tolist() + # Align with RLColocateTrainer._prepare_train_data(): + # - input_ids excludes last token (usually eos) of response_ids + # - shifted_labels aligns to input_ids length + input_ids = prompt_ids + response_ids[:-1] + shifted_labels = [-100] * (len(prompt_ids) - 1) + response_ids + input_ids_tensor = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0) + shifted_labels_tensor = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0) + + adv_val = advantages[i].item() + # Controller._packing expects `advantage` as a list and will flatten it. + # Keep the length consistent with shifted_labels/input_ids. + advantage_list = [adv_val] * (len(prompt_ids) - 1) + [adv_val] * len(response_ids) + + data_batches.append(dict( + seq_ctx=SequenceContext.from_input_ids((input_ids_tensor,), device="cpu"), + shifted_labels=shifted_labels_tensor, + advantage=advantage_list, + )) + + # RLColocateTrainer initializes by offloading train workers to CPU. + # Align with RLColocateTrainer.fit() which onloads before training. 
+ train_controller.onload(target="all") + + # First fit and save + train_controller.fit(data_batches, pack_max_length=1024, rollout_idx=0) + checkpoint_path = str(work_dir / "save_test") + train_controller.save(checkpoint_path, no_save_optimizer=True) + + # Second fit and collect metrics + train_controller.onload(target="all") + log_infos = train_controller.fit(data_batches, pack_max_length=1024, rollout_idx=1) + efficient_attn_ratio_list = [] + for log_info in log_infos: + efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio']) + self.assertTrue(all([ratio > 0 for ratio in efficient_attn_ratio_list])) + + # Kill and rebuild + del trainer + ray.shutdown() + # Re-init Ray with enough resources for AcceleratorResourcesConfig(num_workers=8, num_cpus_per_worker=4). + ray.init(num_cpus=80, num_gpus=8, ignore_reinit_error=True) + + trainer_cfg = self.build_trainer_config( + work_dir=str(work_dir), + checkpoint_interval=1, + checkpoint_maxkeep=2, + auto_resume=False, + ) + trainer = trainer_cfg.build() + train_controller = trainer.train_controller + + # Resume and verify + load_checkpoint_cfg = LoadCheckpointConfig( + checkpoint_path=checkpoint_path, + load_optimizer_states=False, + load_optimizer_args=False + ) + train_controller.resume(load_checkpoint_cfg) + + train_controller.onload(target="all") + log_infos = train_controller.fit(data_batches, pack_max_length=1024, rollout_idx=1) + new_efficient_attn_ratio_list = [] + for log_info in log_infos: + new_efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio']) + + efficient_attn_ratio_list.sort() + new_efficient_attn_ratio_list.sort() + self.assertEqual(efficient_attn_ratio_list, new_efficient_attn_ratio_list) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rl/test_rl_disaggregated_trainer.py b/tests/rl/test_rl_disaggregated_trainer.py new file mode 100644 index 0000000000..70f9ea2e19 --- /dev/null +++ b/tests/rl/test_rl_disaggregated_trainer.py @@ -0,0 +1,290 @@ +import asyncio +import tempfile +import unittest +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +from xtuner.v1.rl.agent_loop_manager import ( + AgentLoopManagerStatus, + ProduceBatchResult, + ProduceBatchStatus, +) +from xtuner.v1.train.rl_trainer import RLDisaggregatedTrainer, _validate_sync_intervals + + +class _FakeManager: + def __init__(self, get_batch_results): + self._results = list(get_batch_results) + self._status = AgentLoopManagerStatus.NORMAL + self._finish_event = asyncio.Event() + self.calls: list[object] = [] + + async def produce_loop(self, batch_size: int): + self.calls.append(("produce_loop_start", batch_size)) + await self._finish_event.wait() + self.calls.append("produce_loop_exit") + + async def get_batch(self, batch_size: int, train_step: int): + self.calls.append(("get_batch", batch_size, train_step)) + return self._results.pop(0) + + async def pause_produce(self, *, use_global_progress: bool): + self.calls.append(("pause_produce", use_global_progress)) + return 0.25 + + def continue_produce(self, model_step: int): + self.calls.append(("continue_produce", model_step)) + + +class TestRLDisaggregatedTrainer(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def _make_trainer(self, agent_loop_manager): + trainer = RLDisaggregatedTrainer.__new__(RLDisaggregatedTrainer) + trainer.logger = MagicMock() + trainer._cur_step = 0 + 
trainer._total_train_steps = 1 + trainer._global_train_step = 0 + trainer.train_batch_size = 2 + trainer._sync_weights_interval = 1 + trainer._enable_evaluate = False + trainer._enable_initial_evaluate = False + trainer._evaluate_step = 1 + trainer._debug_rollout = False + trainer._display_all_workers_log = False + trainer._train_worker_cfg = SimpleNamespace(pack_max_length=16) + trainer._meta = SimpleNamespace( + latest_exp=SimpleNamespace(exp_dir=str(Path(self.temp_dir.name) / "exp")), + ) + Path(trainer.exp_dir).mkdir(parents=True, exist_ok=True) + trainer.agent_loop_manager = agent_loop_manager + trainer.eval_agent_loop_manager = SimpleNamespace(produce_batch=AsyncMock()) + trainer.evaluator = MagicMock(eval_batch_size=1, run=MagicMock(return_value={"acc": 1.0})) + trainer._exp_tracker = MagicMock() + trainer._prepare_train_data = MagicMock( + return_value=([{"seq_ctx": "fake"}], {"batch_size": 1, "rewards/mean": 1.0}) + ) + trainer._save_trajectories = MagicMock() + trainer._log_step = MagicMock() + trainer._maybe_save_checkpoint = MagicMock() + trainer._maybe_save_hf = MagicMock() + trainer.fake_update_weights = MagicMock() + trainer.train_controller = SimpleNamespace( + fit=MagicMock(return_value=[{"train_metrics": [], "sft_train_metrics": {}}]), + onload=MagicMock(return_value="onload"), + offload=MagicMock(return_value="offload"), + update_weights=MagicMock(return_value="update"), + ) + trainer.rollout_controller = SimpleNamespace( + recover_failed_workers=SimpleNamespace(remote=MagicMock(return_value="recover")), + onload_weights=SimpleNamespace(remote=MagicMock(return_value="onload_weights")), + onload_kvcache=SimpleNamespace(remote=MagicMock(return_value="onload_kvcache")), + ) + return trainer + + def test_sync_weights_and_save_saves_before_fake_update(self): + manager = _FakeManager([]) + trainer = self._make_trainer(manager) + events: list[str] = [] + trainer._maybe_save_checkpoint = MagicMock(side_effect=lambda step: events.append(f"save:{step}")) + trainer._maybe_save_hf = MagicMock(side_effect=lambda step: events.append(f"hf:{step}")) + trainer.fake_update_weights = MagicMock(side_effect=lambda: events.append("fake_update")) + + with ( + patch("xtuner.v1.train.rl_trainer.ray.get", side_effect=lambda obj, timeout=None: obj), + patch( + "xtuner.v1.train.rl_trainer.bind_train_rollout", + side_effect=lambda train_controller, rollout_controller: events.append("bind"), + ), + ): + asyncio.run(trainer._sync_weights_and_save(train_step=3, step_timer_dict={})) + + self.assertEqual(events, ["save:3", "hf:3", "bind", "fake_update"]) + trainer.train_controller.offload.assert_not_called() + + def test_fit_skips_train_when_batch_is_expired(self): + manager = _FakeManager( + [ProduceBatchResult(rollout_states=[], status=ProduceBatchStatus.EXPIRED_BATCH)] + ) + trainer = self._make_trainer(manager) + trainer._sync_weights_and_save = AsyncMock() + + asyncio.run(trainer._fit()) + + trainer._prepare_train_data.assert_not_called() + trainer.train_controller.fit.assert_not_called() + trainer._sync_weights_and_save.assert_awaited_once() + self.assertIn(("continue_produce", 1), manager.calls) + self.assertIn("produce_loop_exit", manager.calls) + + def test_fit_runs_eval_before_reset_and_stops_producer(self): + manager = _FakeManager( + [ProduceBatchResult(rollout_states=[["sample"]], status=ProduceBatchStatus.NORMAL)] + ) + trainer = self._make_trainer(manager) + trainer._enable_evaluate = True + events: list[str] = [] + + async def sync_weights_and_save(train_step: int, step_timer_dict: dict): 
+ events.append("sync") + + async def eval_produce_batch(batch_size: int, train_step: int, model_step: int): + events.append("eval") + return ProduceBatchResult(rollout_states=[["eval"]]) + + def continue_produce(model_step: int): + events.append("continue_produce") + manager.calls.append(("continue_produce", model_step)) + + trainer._sync_weights_and_save = AsyncMock(side_effect=sync_weights_and_save) + trainer.eval_agent_loop_manager.produce_batch = AsyncMock(side_effect=eval_produce_batch) + trainer.evaluator.run = MagicMock(return_value={"acc": 1.0}) + manager.continue_produce = continue_produce + + with patch("xtuner.v1.train.rl_trainer.ray.get", side_effect=lambda obj, timeout=None: obj): + asyncio.run(trainer._fit()) + + trainer._prepare_train_data.assert_called_once() + trainer.train_controller.fit.assert_called_once() + trainer.train_controller.onload.assert_not_called() + self.assertEqual(events, ["sync", "eval", "continue_produce"]) + self.assertTrue(manager._finish_event.is_set()) + self.assertIn("produce_loop_exit", manager.calls) + + def test_fake_update_weights_does_not_onload_rollout(self): + manager = _FakeManager([]) + trainer = self._make_trainer(manager) + + with patch("xtuner.v1.train.rl_trainer.ray.get", side_effect=lambda obj, timeout=None: obj): + trainer.fake_update_weights = RLDisaggregatedTrainer.fake_update_weights.__get__( + trainer, RLDisaggregatedTrainer + ) + trainer.fake_update_weights() + + trainer.train_controller.update_weights.assert_called_once_with() + trainer.rollout_controller.onload_weights.remote.assert_not_called() + trainer.rollout_controller.onload_kvcache.remote.assert_not_called() + + def test_resume_from_checkpoint_syncs_weights_then_resets_manager(self): + trainer = RLDisaggregatedTrainer.__new__(RLDisaggregatedTrainer) + trainer.logger = MagicMock() + trainer._load_checkpoint_cfg = SimpleNamespace(checkpoint_path=Path(self.temp_dir.name)) + trainer.train_controller = SimpleNamespace(resume=MagicMock(return_value="resume")) + trainer.rollout_controller = SimpleNamespace() + events: list[str] = [] + + def manager_resume(checkpoint_path): + events.append(f"manager_resume:{Path(checkpoint_path).name}") + return 3 + + def manager_continue_produce(model_step: int): + events.append(f"continue_produce:{model_step}") + + trainer.agent_loop_manager = SimpleNamespace( + resume=MagicMock(side_effect=manager_resume), + continue_produce=MagicMock(side_effect=manager_continue_produce), + ) + trainer.fake_update_weights = MagicMock(side_effect=lambda: events.append("fake_update")) + + train_state_path = Path(self.temp_dir.name) / trainer._SAVE_TRAIN_STATE_PATH + train_state_path.write_text('{"cur_step": 3}') + + with ( + patch("xtuner.v1.train.rl_trainer.ray.get", side_effect=lambda obj, timeout=None: obj), + patch( + "xtuner.v1.train.rl_trainer.bind_train_rollout", + side_effect=lambda train_controller, rollout_controller: events.append("bind"), + ), + ): + trainer._resume_from_checkpoint(self.temp_dir.name) + + trainer.train_controller.resume.assert_called_once_with(trainer._load_checkpoint_cfg) + self.assertEqual(trainer._cur_step, 3) + trainer.agent_loop_manager.resume.assert_called_once_with(Path(self.temp_dir.name)) + self.assertTrue(events[0].startswith("manager_resume:")) + self.assertEqual(events[1:], ["bind", "fake_update", "continue_produce:3"]) + + def test_validate_sync_schedule_accepts_multiples(self): + _validate_sync_intervals(sync_weights_interval=2, checkpoint_interval=4, hf_interval=6) + _validate_sync_intervals(sync_weights_interval=2, 
checkpoint_interval=-1, hf_interval=None) + _validate_sync_intervals( + sync_weights_interval=2, + checkpoint_interval=-1, + hf_interval=None, + evaluate_step=4, + enable_evaluate=True, + ) + + def test_validate_sync_schedule_rejects_non_multiple_checkpoint_interval(self): + with self.assertRaisesRegex(ValueError, "checkpoint_interval=5.*sync_weights_interval=2"): + _validate_sync_intervals(sync_weights_interval=2, checkpoint_interval=5, hf_interval=-1) + + def test_validate_sync_schedule_rejects_non_multiple_hf_interval(self): + with self.assertRaisesRegex(ValueError, "hf_interval=5.*sync_weights_interval=2"): + _validate_sync_intervals(sync_weights_interval=2, checkpoint_interval=4, hf_interval=5) + + def test_validate_sync_schedule_rejects_non_multiple_evaluate_step(self): + with self.assertRaisesRegex(ValueError, "evaluate_step=5.*sync_weights_interval=2"): + _validate_sync_intervals( + sync_weights_interval=2, + checkpoint_interval=4, + hf_interval=6, + evaluate_step=5, + enable_evaluate=True, + ) + + def test_build_disaggregated_placement_groups_uses_distinct_names(self): + trainer = RLDisaggregatedTrainer.__new__(RLDisaggregatedTrainer) + trainer.logger = MagicMock() + trainer._meta = SimpleNamespace( + latest_exp=SimpleNamespace(exp_dir=str(Path(self.temp_dir.name) / "20260416130000")), + ) + train_pg = SimpleNamespace(id="train-pg-id") + rollout_pg = SimpleNamespace(id="rollout-pg-id") + + with patch( + "xtuner.v1.train.rl_trainer.AutoAcceleratorWorkers.build_placement_group", + side_effect=[train_pg, rollout_pg], + ) as build_pg: + built_train_pg, built_rollout_pg = trainer._build_disaggregated_placement_groups( + train_resources=object(), + rollout_resources=object(), + ) + + self.assertIs(built_train_pg, train_pg) + self.assertIs(built_rollout_pg, rollout_pg) + self.assertEqual( + build_pg.call_args_list[0].kwargs["name"], + "xtuner_rl_disagg_20260416130000_train", + ) + self.assertEqual( + build_pg.call_args_list[1].kwargs["name"], + "xtuner_rl_disagg_20260416130000_rollout", + ) + + def test_build_disaggregated_placement_groups_rejects_reused_pg(self): + trainer = RLDisaggregatedTrainer.__new__(RLDisaggregatedTrainer) + trainer.logger = MagicMock() + trainer._meta = SimpleNamespace( + latest_exp=SimpleNamespace(exp_dir=str(Path(self.temp_dir.name) / "20260416130000")), + ) + shared_pg = SimpleNamespace(id="shared-pg-id") + + with patch( + "xtuner.v1.train.rl_trainer.AutoAcceleratorWorkers.build_placement_group", + side_effect=[shared_pg, shared_pg], + ): + with self.assertRaisesRegex(RuntimeError, "distinct placement groups"): + trainer._build_disaggregated_placement_groups( + train_resources=object(), + rollout_resources=object(), + ) + + +if __name__ == "__main__": + unittest.main()
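The schedule contract these tests pin down can be stated once: checkpoint_interval, hf_interval, and (when evaluation is enabled) evaluate_step must each be a positive multiple of sync_weights_interval, while -1 or None mean "disabled" and are exempt. Below is a minimal sketch of a validator with exactly that behaviour; it only mirrors what the tests above assert and is not the actual _validate_sync_intervals implementation.

def validate_sync_intervals(sync_weights_interval, checkpoint_interval, hf_interval,
                            evaluate_step=None, enable_evaluate=False):
    # Sketch only: -1 (or None) marks a knob as disabled and exempts it from the check.
    checks = [("checkpoint_interval", checkpoint_interval), ("hf_interval", hf_interval)]
    if enable_evaluate:
        checks.append(("evaluate_step", evaluate_step))
    for name, value in checks:
        if value is not None and value > 0 and value % sync_weights_interval != 0:
            raise ValueError(f"{name}={value} must be a multiple of "
                             f"sync_weights_interval={sync_weights_interval}")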
os.environ["ROLLOUT_DATA_PATH"] +TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] +resource_map = { + "npu": "NPU", + "cuda": "GPU", +} +class TestRollout(unittest.IsolatedAsyncioTestCase): + + @classmethod + def setUpClass(cls) -> None: + os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" + + @classmethod + def tearDownClass(cls) -> None: + del os.environ["XTUNER_USE_FA3"] + del os.environ["LMD_SKIP_WARMUP"] + + def init_config(self): + self.resources_cfg = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=8, + num_cpus_per_worker=8, + cpu_memory_per_worker=16 * 1024**3, # 16 GB + ) + self.max_prompt_length = 512 + self.max_response_length = 1024 + self.context_length = self.max_prompt_length + self.max_response_length + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) + + def setUp(self): + ray.init(num_cpus=80, ignore_reinit_error=True) + self.data_path = TRAIN_DATA_PATH + self.model_path = MODEL_PATH + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + self.init_config() + + def tearDown(self): + ray.shutdown() + # When lmdeploy enable ep>1, it uses deep_ep. Buffer implicit destroy would cause some ray actor stucked. + # Use pkill cleen up ray::WorkerWrapper process after close ray cluster connection as workaround. + # TODO(chenchiyu): add excplicit deep_ep destroy in lmdeploy. + self._cleanup_lmdeploy_ray_worker_wrapper() + self.temp_dir.cleanup() + + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + def test_parallel_rollout(self): + resource_config = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=4, + num_cpus_per_worker=4, + cpu_memory_per_worker=8 * 1024**3, # 8 GB + ) + pg1 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="tp_pg") + pg2 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="ep_pg") + dense_model_path = MODEL_PATH + moe_model_path = MOE_MODEL_PATH + dist_port_base = 38000 + async def run_both(): + return await asyncio.gather( + self._run_rollout(model_path=dense_model_path, tp_size=4, ep_size=1, pg=pg1, dist_port_base=dist_port_base), + self._run_rollout(model_path=moe_model_path, tp_size=1, ep_size=4, pg=pg2, dist_port_base=dist_port_base + 1024 * 4), + return_exceptions=False + ) + + asyncio.run(run_both()) + + def _cleanup_lmdeploy_ray_worker_wrapper(self): + try: + result = subprocess.run(["pkill", "-f", "ray::RayWorkerWrapper*"], capture_output=True, text=True, timeout=10) + if result.returncode != 0: + print(f"pkill command failed with return code {result.returncode}: {result.stderr}." 
+ " Maybe no lmdeploy ray::RayWorkerWrapper processes found.") + except Exception as e: + print(f"Error stopping ray::RayWorkerWrapper cluster: {e}") + + async def _run_rollout(self, model_path, tp_size, ep_size, pg, dist_port_base): + rollout_config = RolloutConfig( + env="test_rollout", + model_path=model_path, + model_name=os.path.basename(model_path).lower(), + tokenizer_path=model_path, + tensor_parallel_size=tp_size, + expert_parallel_size=ep_size, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + dist_port_base=dist_port_base, + enable_return_routed_experts=ep_size > 1, # ep_size > 1 默认打开r3 + ) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + result_refs = [] + + # Test Case 1: 文本输入 + 文本输出 + # TODO(@duanyanhui): test prompt in and prompt out with v1/chat/completion api + # sample_params1 = SampleParams(return_token_ids=False) + # input1 = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params1) + # result1_ref = rollout_controller.generate.remote(rollout_state=input1) + # result_refs.append(result1_ref) + + # Test Case 2: 文本输入 + Token 输出 + sample_params2 = SampleParams(return_token_ids=True) + input2 = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params2) + result2_ref = rollout_controller.generate.remote(rollout_state=input2) + result_refs.append(result2_ref) + + # Test Case 3: Token 输入 + Token 输出 + text_prompt = self.tokenizer.apply_chat_template(TEST_TEXT_MESSAGES, tokenize=False, add_generation_prompt=True) + input_tokens = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"] + sample_params3 = SampleParams(return_token_ids=True) + input3 = RolloutState(message=TEST_TEXT_MESSAGES, tokens=input_tokens, sample_params=sample_params3) + result3_ref = rollout_controller.generate.remote(rollout_state=input3) + result_refs.append(result3_ref) + + try: + results = await asyncio.wait_for(asyncio.gather(*result_refs), timeout=300) + for i, result in enumerate(results): + case_id = f"Case {i+1}" + self.assertEqual(result.status, Status.COMPLETED, + msg=f"{case_id} failed: Expected status COMPLETED but got {result.status}") + self.assertEqual(result.finish_reason, 'stop', + msg=f"{case_id} failed: Expected finish_reason 'stop' but got {result.finish_reason}") + + if result.sample_params.return_token_ids: + self.assertGreater(len(result.response_ids), 0, + msg=f"{case_id} failed: response_ids should not be empty when return_token_ids is True") + + if result.sample_params.return_logprob: + self.assertEqual(len(result.logprobs), len(result.response_ids), + msg=f"{case_id} failed: logprobs length ({len(result.logprobs)}) " + f"does not match response_ids length ({len(result.response_ids)})") + + except asyncio.TimeoutError: + if tp_size > 1 and ep_size == 1: + self.fail("TP and Dense Rollout timed out!") + if ep_size > 1 and tp_size == 1: + self.fail("EP and MoE Rollout timed out!") + finally: + await asyncio.wait_for(rollout_controller.shutdown.remote(), timeout=300) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rl/test_rollout_utils.py b/tests/rl/test_rollout_utils.py new file mode 100644 index 0000000000..749dd16892 --- /dev/null +++ b/tests/rl/test_rollout_utils.py @@ -0,0 +1,98 @@ +import ray +import torch +import threading +import time +import unittest +import os +import tempfile +from types import SimpleNamespace +from unittest.mock import patch + +from xtuner.v1.data_proto.rl_data import Status, RolloutState, SampleParams +from xtuner.v1.rl.rollout.worker import 
RolloutConfig +from xtuner.v1.rl.rollout.controller import RolloutController, WorkerInfo +from xtuner.v1.rl.rollout.utils import RolloutHealthChecker, SessionRouter +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers, asyncio_run + +MODEL_PATH = os.environ.get("ROLLOUT_MODEL_PATH", "") +RESOURCE_MAP = {"npu": "NPU", "cuda": "GPU"} +TEST_TEXT_MESSAGES=[{"role": "user", "content": "Hello!"}] + +class TestRolloutControllerRecover(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" + + @classmethod + def tearDownClass(cls) -> None: + del os.environ["XTUNER_USE_FA3"] + del os.environ["LMD_SKIP_WARMUP"] + + def setUp(self): + ray.init(num_cpus=80, address="local", ignore_reinit_error=True) + self.model_path = MODEL_PATH + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + + def tearDown(self): + ray.shutdown() + self.temp_dir.cleanup() + + def init_rollout_controller(self): + resource_cfg = AcceleratorResourcesConfig( + accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type], + num_workers=1, + num_cpus_per_worker=4, + cpu_memory_per_worker=8 * 1024**3, + ) + pg = AutoAcceleratorWorkers.build_placement_group(resource_cfg, name="recover_test_pg") + rollout_cfg = RolloutConfig( + env="test_rollout_utils", + model_path=self.model_path, + model_name=os.path.basename(self.model_path).lower(), + tokenizer_path=self.model_path, + tensor_parallel_size=1, + expert_parallel_size=1, + worker_log_dir=self.temp_dir.name, + context_length=8192, + health_check_interval_seconds=10, + health_check_failure_threshold=1, + ) + controller = RolloutController(rollout_cfg, pg) + return controller + + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + def test_healthcheck_deactivate_and_recover(self): + controller = self.init_rollout_controller() + ranks = list(controller.rank2info.keys()) + rank0 = ranks[0] + actor0 = controller.rank2info[rank0].actor + ray.get(actor0.shutdown.remote()) + time.sleep(3) # wait for the actor to be fully killed + health_before_recover = ray.get(actor0.check_health.remote()) + url = controller.rank2info[rank0].url + self.assertFalse(health_before_recover) + + controller.health_checker.run_once() + + self.assertFalse(controller.rank2info[rank0].is_active) + rollout_state = RolloutState( + message=TEST_TEXT_MESSAGES, + sample_params=SampleParams(return_token_ids=True), + ) + out = asyncio_run(controller.generate(rollout_state)) + self.assertEqual(out.status, Status.FAILED) + + controller.recover_failed_workers() + + self.assertTrue(controller.rank2info[rank0].is_active) + self.assertEqual(url, controller.rank2info[rank0].url) + health_after_recover = ray.get(actor0.check_health.remote()) + self.assertTrue(health_after_recover) + out = asyncio_run(controller.generate(rollout_state)) + self.assertNotEqual(out.status, Status.FAILED) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rl/test_staleness_policy.py b/tests/rl/test_staleness_policy.py new file mode 100644 index 0000000000..b379d830a7 --- /dev/null +++ b/tests/rl/test_staleness_policy.py @@ -0,0 +1,47 @@ +import unittest + +from pydantic import ValidationError + +from xtuner.v1.data_proto.rl_data import Status +from xtuner.v1.rl.agent_loop_manager import AsyncProduceStrategyConfig, calculate_stale_threshold +from xtuner.v1.rl.agent_loop_manager.producer import 
expire_group_if_needed + + +class _State: + def __init__(self, seq_staleness: int, status: Status = Status.COMPLETED): + self.uid = "state" + self.seq_staleness = seq_staleness + self.status = status + + +class TestStalenessPolicy(unittest.TestCase): + def test_max_staleness_zero_uses_sync_interval_as_threshold(self): + # max_staleness=0 means only the minimal lag naturally present within one sync interval is accepted. + self.assertEqual(calculate_stale_threshold(max_staleness=0, sync_weights_interval=4), 4) + strategy = AsyncProduceStrategyConfig(max_staleness=0).build(sync_weights_interval=4) + + self.assertFalse(strategy.is_model_expired(train_step=8, model_step=4)) + self.assertTrue(strategy.is_model_expired(train_step=9, model_step=4)) + + def test_max_staleness_one_allows_one_extra_sync_interval(self): + self.assertEqual(calculate_stale_threshold(max_staleness=1, sync_weights_interval=4), 8) + strategy = AsyncProduceStrategyConfig(max_staleness=1).build(sync_weights_interval=4) + + self.assertFalse(strategy.is_model_expired(train_step=12, model_step=4)) + self.assertTrue(strategy.is_model_expired(train_step=13, model_step=4)) + + def test_negative_max_staleness_is_invalid(self): + with self.assertRaises(ValidationError): + AsyncProduceStrategyConfig(max_staleness=-1) + + def test_expire_group_requires_positive_step_threshold(self): + with self.assertRaisesRegex(ValueError, "stale_threshold must be positive"): + expire_group_if_needed([_State(seq_staleness=0)], stale_threshold=0) + + group = [_State(seq_staleness=4)] + expire_group_if_needed(group, stale_threshold=4) + self.assertEqual(group[0].status, Status.EXPIRED) + + +if __name__ == "__main__": + unittest.main()
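The first two tests above pin down a single relation: a sample generated at model_step is expired once train_step - model_step exceeds (max_staleness + 1) * sync_weights_interval, i.e. max_staleness counts extra sync intervals on top of the one interval of lag that always exists between weight syncs. A self-contained sketch of that rule, using illustrative names rather than the real calculate_stale_threshold / is_model_expired signatures:

def stale_threshold(max_staleness: int, sync_weights_interval: int) -> int:
    # One interval of lag is unavoidable; max_staleness buys extra intervals.
    return (max_staleness + 1) * sync_weights_interval

def model_expired(train_step: int, model_step: int, threshold: int) -> bool:
    return train_step - model_step > threshold

assert stale_threshold(0, 4) == 4 and stale_threshold(1, 4) == 8
assert not model_expired(8, 4, 4) and model_expired(9, 4, 4)
assert not model_expired(12, 4, 8) and model_expired(13, 4, 8)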
diff --git a/tests/ray/test_update_weight.py b/tests/rl/test_update_weight.py similarity index 79% rename from tests/ray/test_update_weight.py rename to tests/rl/test_update_weight.py index fa008d3d71..ce2897918b 100644 --- a/tests/ray/test_update_weight.py +++ b/tests/rl/test_update_weight.py @@ -3,17 +3,17 @@ import tempfile import ray -from xtuner.v1.ray.rollout import RolloutController -from xtuner.v1.data_proto.rl_data import SampleParams +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.data_proto.rl_data import SampleParams, RolloutState from xtuner.v1.config import ( AdamWConfig, FSDPConfig, LRConfig, ) -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.rl.base import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker -from xtuner.v1.rl.grpo.loss import GRPOLossConfig as LossConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers +from xtuner.v1.rl.trainer import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker +from xtuner.v1.rl.loss import GRPOLossConfig as LossConfig from xtuner.v1.model.compose.qwen3_vl import Qwen3VLDense4BConfig TEST_TEXT_MESSAGES = [{"role": "user", "content": "Hello!"}] @@ -106,11 +106,9 @@ def test_lmdeploy_update_weight_and_generate(self): ) futures = [ worker.test_all_reduce.remote() for worker in train_workers ] ray.get(futures) - train_controller = TrainingController.remote( + train_controller = TrainingController( workers=train_workers, ) - ray.get(train_controller.__ray_ready__.remote()) - # fixed sample params sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) @@ -121,22 +119,23 @@ def test_lmdeploy_update_weight_and_generate(self): self.pg, ) - res_baseline = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) + input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) + res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) # start update weight test - info_dict = ray.get(rollout_controller.get_rollout_info.remote()) - ray.get(train_controller.update_rollout_info.remote(info_dict)) + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) + train_controller.update_rollout_info(info_dict) # update weights ray.get(rollout_controller.offload.remote()) - ray.get(train_controller.onload.remote(target="all")) - ray.get(train_controller.offload.remote(["optimizer"])) + train_controller.onload(target="all") + train_controller.offload("optimizer") ray.get(rollout_controller.onload_weights.remote()) - ray.get(train_controller.update_weights.remote()) - ray.get(train_controller.offload.remote(["model"])) + train_controller.update_weights() + train_controller.offload("model") ray.get(rollout_controller.onload_kvcache.remote()) - res_update_weight = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) + res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) self.assertEqual(res_update_weight.response, res_baseline.response) ray.get(rollout_controller.shutdown.remote(), timeout=60) diff --git a/tests/ray/test_utils.py b/tests/rl/test_utils.py similarity index 97% rename from tests/ray/test_utils.py rename to tests/rl/test_utils.py index ca516469b8..127198440e 100644 --- a/tests/ray/test_utils.py +++ b/tests/rl/test_utils.py @@ -5,7 +5,7 @@ -from xtuner.v1.ray.utils import find_master_addr_and_port, get_accelerator_ids, get_ray_accelerator +from xtuner.v1.rl.utils.ray_utils import find_master_addr_and_port, get_accelerator_ids, get_ray_accelerator import parametrize diff --git a/tests/rl/test_vl_rollout.py b/tests/rl/test_vl_rollout.py new file mode 100644 index 0000000000..4c194a179c --- /dev/null +++ b/tests/rl/test_vl_rollout.py @@ -0,0 +1,162 @@ +import os +import subprocess +import unittest +import tempfile +import ray +import torch +from transformers import AutoTokenizer +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers +from xtuner.v1.data_proto.rl_data import Status +from xtuner.v1.datasets.rl_tokenize_fn import RLQwen3VLTokenizeFnConfig +import asyncio +from xtuner.v1.rl.rollout import RolloutController + + +MODEL_PATH=os.getenv("QWEN3_VL_DENSE_PATH") +MOE_MODEL_PATH=os.getenv("QWEN3_VL_MOE_PATH") +MEDIA_ROOT=os.getenv("GEO3K_MEDIA_ROOT") + +resource_map = { + "npu": "NPU", + "cuda": "GPU", +} +class TestVLMRollout(unittest.IsolatedAsyncioTestCase): + + @classmethod + def setUpClass(cls) -> None: + os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" + + @classmethod + def tearDownClass(cls) -> None: + del os.environ["XTUNER_USE_FA3"] + del os.environ["LMD_SKIP_WARMUP"] + + def init_config(self): + self.resources_cfg = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=8, + num_cpus_per_worker=8, + cpu_memory_per_worker=16 * 1024**3, # 16 GB + ) + self.max_prompt_length = 1024 + self.max_response_length = 2048 + self.context_length = self.max_prompt_length + self.max_response_length + + tokenizer = AutoTokenizer.from_pretrained(self.model_path, 
trust_remote_code=True) + tokenize_fn = RLQwen3VLTokenizeFnConfig(processor_path=self.model_path, max_length=self.max_prompt_length) + self.tokenize_fn = tokenize_fn.build(tokenizer) + + def setUp(self): + ray.init(num_cpus=80, ignore_reinit_error=True) + self.model_path = MODEL_PATH + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + self.init_config() + + def tearDown(self): + ray.shutdown() + # When lmdeploy enables ep>1, it uses deep_ep. Implicit buffer destruction can leave some ray actors stuck. + # As a workaround, use pkill to clean up ray::RayWorkerWrapper processes after closing the ray cluster connection. + # TODO(chenchiyu): add explicit deep_ep destroy in lmdeploy. + self._cleanup_lmdeploy_ray_worker_wrapper() + self.temp_dir.cleanup() + + def _cleanup_lmdeploy_ray_worker_wrapper(self): + try: + result = subprocess.run(["pkill", "-f", "ray::RayWorkerWrapper*"], capture_output=True, text=True, timeout=10) + if result.returncode != 0: + print(f"pkill command failed with return code {result.returncode}: {result.stderr}." + " Maybe no lmdeploy ray::RayWorkerWrapper processes found.") + except Exception as e: + print(f"Error stopping ray::RayWorkerWrapper cluster: {e}") + + + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + def test_parallel_rollout(self): + resource_config = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=4, + num_cpus_per_worker=4, + cpu_memory_per_worker=8 * 1024**3, # 8 GB + ) + pg1 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="tp_pg") + pg2 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="ep_pg") + dense_model_path = MODEL_PATH + moe_model_path = MOE_MODEL_PATH + dist_port_base = 38000 + async def run_both(): + return await asyncio.gather( + self._run_rollout(model_path=dense_model_path, tp_size=4, ep_size=1, pg=pg1, dist_port_base=dist_port_base), + # self._run_rollout(model_path=moe_model_path, tp_size=1, ep_size=4, pg=pg2, dist_port_base=dist_port_base + 1024 * 4), # TODO: enable once the lmdeploy fix lands + return_exceptions=False + ) + + asyncio.run(run_both()) + + async def _run_rollout(self, model_path, tp_size, ep_size, pg, dist_port_base): + rollout_config = RolloutConfig( + env="test_rollout", + model_path=model_path, + model_name=os.path.basename(model_path).lower(), + tokenizer_path=model_path, + tensor_parallel_size=tp_size, + expert_parallel_size=ep_size, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + dist_port_base=dist_port_base, + enable_return_routed_experts=ep_size > 1, # r3 is enabled by default when ep_size > 1 + ) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + result_refs = [] + + # Test Case 1: text-only + rollout_state = self.tokenize_fn({ + "prompt": [{"role": "user", "content": "Hello!"}], + "data_source": "test/text", + }) + result1_ref = rollout_controller.generate.remote(rollout_state=rollout_state) + result_refs.append(result1_ref) + + # Test Case 2: with image + input_data = {"prompt": [{"content": [{"image_url": {"image_wh": [297, 265], "url": "images/test_0.jpg"}, "type": "image_url"}, {"text": "Chords $\\overline{A C}$ and $\\overline{D F}$ are equidistant from the center. If the radius of $\\odot G$ is 26 find $A C$ You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within tags. 
The final answer MUST BE put in \\boxed{}.", "type": "text"}], "role": "user"}], "data_source": "hiyouga/geometry3k", "ability": "math", "reward_model": {"ground_truth": "48", "style": "rule"}} + rollout_state = self.tokenize_fn(input_data, media_root=MEDIA_ROOT) + rollout_state.tokens = rollout_state.prompt_ids + result2_ref = rollout_controller.generate.remote(rollout_state=rollout_state) + result_refs.append(result2_ref) + + try: + results = await asyncio.wait_for(asyncio.gather(*result_refs), timeout=300) + for i, result in enumerate(results): + case_id = f"Case {i+1}" + self.assertEqual(result.status, Status.COMPLETED, + msg=f"{case_id} failed: Expected status COMPLETED but got {result.status} and error_msg {result.error_msg}") + self.assertEqual(result.finish_reason, 'stop', + msg=f"{case_id} failed: Expected finish_reason 'stop' but got {result.finish_reason}") + + if result.sample_params.return_token_ids: + self.assertGreater(len(result.response_ids), 0, + msg=f"{case_id} failed: response_ids should not be empty when return_token_ids is True") + + if result.sample_params.return_logprob: + self.assertEqual(len(result.logprobs), len(result.response_ids), + msg=f"{case_id} failed: logprobs length ({len(result.logprobs)}) " + f"does not match response_ids length ({len(result.response_ids)})") + + except asyncio.TimeoutError: + if tp_size > 1 and ep_size == 1: + self.fail("TP and Dense Rollout timed out!") + if ep_size > 1 and tp_size == 1: + self.fail("EP and MoE Rollout timed out!") + finally: + await asyncio.wait_for(rollout_controller.shutdown.remote(), timeout=300) + + # @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + # def test_vl_resume_with_partial_rollout(self): + # # TODO: implement later + # pass + + +if __name__ == "__main__": + unittest.main() diff --git a/xtuner/v1/data_proto/__init__.py b/xtuner/v1/data_proto/__init__.py index c30af9de46..027eaa22ff 100644 --- a/xtuner/v1/data_proto/__init__.py +++ b/xtuner/v1/data_proto/__init__.py @@ -1,6 +1,8 @@ +from .cache_item import CacheItem from .sequence_context import SequenceContext __all__ = [ + "CacheItem", "SequenceContext", ] diff --git a/xtuner/v1/data_proto/cache_item.py b/xtuner/v1/data_proto/cache_item.py new file mode 100644 index 0000000000..0193a1d48f --- /dev/null +++ b/xtuner/v1/data_proto/cache_item.py @@ -0,0 +1,10 @@ +from typing_extensions import NotRequired, TypedDict + + +class CacheItem(TypedDict): + num_tokens: int + num_img_tokens: NotRequired[list[int]] + proxy_attn_flops: NotRequired[float] + + +__all__ = ["CacheItem"]
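Because CacheItem is a plain TypedDict, a cache entry is an ordinary dict that is only checked statically; num_tokens is the sole required key. A small usage sketch (the values are made up):

from xtuner.v1.data_proto import CacheItem

text_item: CacheItem = {"num_tokens": 512}
vl_item: CacheItem = {
    "num_tokens": 512,
    "num_img_tokens": [256],     # NotRequired: present only for image samples
    "proxy_attn_flops": 1.5e9,   # NotRequired: estimated attention FLOPs for the sample
}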
diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py index 8470f8b3bf..7b02fec16e 100644 --- a/xtuner/v1/data_proto/rl_data.py +++ b/xtuner/v1/data_proto/rl_data.py @@ -1,453 +1,304 @@ from __future__ import annotations -import copy -from typing import TYPE_CHECKING, Any, TypeAlias +import base64 +from enum import Enum +from typing import TYPE_CHECKING, Any, Literal, TypeAlias +import numpy as np import torch -from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict, Field -from typing_extensions import Annotated, NotRequired, Self, TypedDict - -from xtuner.v1.utils import StrEnum +from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator +from typing_extensions import NotRequired, TypedDict # ==================================== # ====== DataFlow 数据流 ============== # ==================================== +from xtuner.v1.data_proto.utils import calculate_seq_staleness from xtuner.v1.utils.logger import get_logger if TYPE_CHECKING: - import ray - - RayObjectRef = ray.ObjectRef + from ray import ObjectRef as RayObjectRef else: RayObjectRef: TypeAlias = Any logger = get_logger() -class RolloutState(StrEnum): - """ - - 1. State Transitions from finish_reason and RolloutState: - - A new task starts as `INIT`. - - A successful generation (finish_reason 'stop' or 'length') becomes `COMPLETED`. - - A generation stopped by the dataflow (e.g., for partial rollout) becomes `ABORTED`. - - A generation that fails due to an inference server error becomes `FAILED`. - - A generation skipped due to client errors or timeout errors (e.g., invalid input) becomes `SKIPPED`. - - Data used for training is marked as `ARCHIVED`. - - Old data (rollout for morn than expiration step) in the replay buffer is marked as `EXPIRED`. - - 2. Dataflow Handling Based on RolloutState: - - `INIT`: Data is in progress; no special handling. - - `COMPLETED`: Data is valid for filtering, replay buffer insertion and training. - - `ABORTED`: Data may be partially valid; It's valid for replay buffer insertion but not for filtering and training. - - `FAILED`: Data is invalid; not used for filtering, replay buffer or training. - - `SKIPPED`: Data is invalid; not used for filtering, replay buffer or training. - - `ARCHIVED`: Data is stored for historical purposes; not used for training. - - `EXPIRED`: Data is removed from the replay buffer; not used for training. - """ - +class SampleParams(BaseModel): + model_config = ConfigDict(extra="forbid") + n: int = 1 + top_k: int = 0 + top_p: float = 1.0 + temperature: float = 1.0 + repetition_penalty: float = 1.0 + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + min_tokens: int = 0 + max_tokens: int = 2048 + stops: list[str] = [] + stop_token_ids: list[int] = [] + skip_special_tokens: bool = True + sampling_seed: int | None = None + stream: bool = False + return_logprob: bool = True + top_logprobs: int = 1 + return_token_ids: bool = True + include_stop_str_in_output: bool = True + no_stop_trim: bool = True + spaces_between_special_tokens: bool = False + return_routed_experts: bool = False + + +class Status(Enum): INIT = "init" COMPLETED = "completed" ABORTED = "aborted" + EXPIRED = "expired" FAILED = "failed" + FILTERED = "filtered" + # Archived: keep this status for now (whether it gets used is TBD); it marks data already consumed by one training pass but retained in the store for later queries ARCHIVED = "archived" - EXPIRED = "expired" - SKIPPED = "skipped" - - @staticmethod - def from_str(state_str: str) -> RolloutState: - for state in RolloutState: - if state.value == state_str: - return state - raise ValueError(f"Unknown ReplayState string: {state_str}") -class RLUIDItem(BaseModel): - """A unique identifier for tracking data items within the dataflow. +class MultimodalInfo(TypedDict): + # Use a TypedDict to provide the type hint for pixel_values + pixel_values: NotRequired[np.ndarray | RayObjectRef | None] + image_grid_thw: NotRequired[torch.Tensor] - Attributes: - env (str): The environment name. - root_id (int): The root ID for grouping related data items. - action_id (int): The ID for a specific action in prompt. - observation_id (int): The ID for a specific observation in response. - version (int): The version number of the data item. 
- """ +class RolloutFunctionCall(BaseModel): model_config = ConfigDict(extra="forbid") - env: str = "" - root_id: int = -1 - action_id: int = -1 - observation_id: int = -1 - version: int = 0 + name: str + arguments: Any = Field(default_factory=dict) + raw_arguments_text: str | None = None -class MultimodalTrainInfo(TypedDict): - pixel_values: NotRequired[torch.Tensor | RayObjectRef | None] # type: ignore[valid-type] - image_grid_thw: NotRequired[torch.Tensor] - position_ids: NotRequired[torch.Tensor | None] +class RolloutToolCall(BaseModel): + model_config = ConfigDict(extra="forbid") -class RLDatasetItem(BaseModel): - """Represents the data structure output from the dataset. + id: str + type: Literal["function"] = "function" + function: RolloutFunctionCall - Attributes: - messages (Optional[List[Dict[str, Any]]]): The message list for the prompt. - input_ids (Optional[List[int]]): The tokenized input IDs. - num_tokens (Optional[int]): The number of tokens in the input. - proxy_attn_flops (Optional[float]): The estimated proxy attention FLOPs for the sample. unused for RL - ability (Optional[str]): The ability or category of the data. - reward_model (Optional[Dict[str, Any]]): Data required by the reward model, like ground truth. - data_source (Optional[Dict[str, Any]]): The source of the data, used for weighting rewards. - extra_info (Dict[str, Any]): Additional user-defined information. - """ +class RolloutState(BaseModel): model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - messages: list[dict[str, Any]] | None = None - input_ids: list[int] | None = None + + # --- 数据 --- + message_uid: int | None = None # 通过计算原始的message的哈希值得到的id,一组的数据为同一个prompt_id + message: list[dict[str, Any]] # dataset输出,需要在AgentLoop中转换成input_ids + prompt_ids: list[int] | None = None # 原始 prompt的token ids num_tokens: int | None = None proxy_attn_flops: float | None = None - ability: str | None = None + data_source: dict[str, Any] | str | None = None + mm_info: MultimodalInfo | None = None reward_model: dict[str, Any] | None = None - data_source: dict[str, Any] | None = None - extra_info: dict[str, Any] = dict() - multimodal_train_info: MultimodalTrainInfo | None = None + # --- InferEngine 输入 --- + session_uid: int | None = None + tokens: list[int] | None = None # 每一次推理引擎的实际输入 + tools: list | None = None + tool_choice: str | dict[str, Any] | None = None + sample_params: SampleParams = SampleParams() -class RolloutExtraInfo(TypedDict): - routed_experts: NotRequired[list[int] | RayObjectRef] # type: ignore[valid-type] - partial_rollout_input_ids: NotRequired[list[int]] - - -class RLRolloutResponseItem(BaseModel): - """Represents the data structure output from the rollout process. - - Attributes: - response (Optional[str]): The generated text response from the model. - response_ids (Optional[List[int]]): The token IDs of the generated response. - num_return_tokens (Optional[int]): The number of tokens in the response. - finish_reason (Optional[str]): The reason why the generation finished (e.g., 'stop', 'length'). - logprobs (Optional[List[float]]): The log probabilities of the generated tokens. - extra_info (Dict[str, Any]): Additional user-defined information. 
- """ - - model_config = ConfigDict(extra="forbid") + # --- InferEngine 输出 --- + # 每一次推理引擎的实际输出, 在rollout worker中被覆盖写 response: str | None = None + tool_calls: list[RolloutToolCall] | None = None response_ids: list[int] | None = None logprobs: list[float] | None = None - num_return_tokens: int | None = None - versioned_response: list[str] = Field(default_factory=list) - versioned_response_ids: list[list[int]] = Field(default_factory=list) - versioned_logprobs: list[list[float]] = Field(default_factory=list) - versioned_num_return_tokens: list[int] = Field(default_factory=list) - finish_reason: str | None = None # "stop", "length", "abort", "failed", "skipped" - extra_info: RolloutExtraInfo = Field(default_factory=dict) - state: RolloutState = RolloutState.INIT - - def _update_by_append(self, other: Self) -> None: - other_ids_copy = copy.deepcopy(other.response_ids) - other_logprobs_copy = copy.deepcopy(other.logprobs) - other_response_copy = copy.deepcopy(other.response) - if other_response_copy is not None: - assert self.response is not None, "response must not be None when updating partial data." - self.response += other_response_copy - self.versioned_response.append(other_response_copy) - - if other_ids_copy is not None: - assert self.response_ids is not None, "response_ids must not be None when updating partial data." - self.response_ids.extend(other_ids_copy.copy()) - self.versioned_response_ids.append(other_ids_copy) - self.versioned_num_return_tokens.append(len(other_ids_copy)) - - if other_logprobs_copy is not None: - assert self.logprobs is not None, "logprobs must not be None when updating partial data." - self.logprobs.extend(other_logprobs_copy.copy()) - self.versioned_logprobs.append(other_logprobs_copy) - - self.num_return_tokens = len(self.response_ids) if self.response_ids is not None else 0 - self.finish_reason = other.finish_reason - self.extra_info.update(other.extra_info) - self.state = other.state - return - - def update(self, other: Self) -> None: - """Updates this RLRolloutResponseItem with data from another one. - - If partial_rollout is True, concat other response to this RLRolloutResponseItem's response. - """ - if not isinstance(other, RLRolloutResponseItem): - raise TypeError("Can only update with another RLRolloutResponseItem instance.") - - if other.response_ids is None and other.logprobs is None and other.response is None: - self.finish_reason = other.finish_reason - self.state = other.state - self.extra_info.update(other.extra_info) - return - - if self.response_ids is None: - assert self.response is None and self.logprobs is None, ( - "Inconsistent state: if response_ids is None, response and logprobs must also be None." - ) - self.response = "" - self.response_ids = [] - self.logprobs = [] - self.num_return_tokens = 0 - else: - assert self.response is not None and self.logprobs is not None, ( - "Inconsistent state: if response_ids is not None, response and logprobs must also be not None." - ) - - self._update_by_append(other) - - -class RLJudgerResponseItem(BaseModel): - """Represents the data structure output from the judger. - - Attributes: - uid (Optional[int]): A unique ID to identify which input the result corresponds to. - reward (Dict[str, Any]): A dictionary of reward scores, e.g., {"score": score}. - extra_info (Dict[str, Any]): Additional user-defined information. 
- """ - - model_config = ConfigDict(extra="forbid") + routed_experts: list[int] | RayObjectRef | None = None + finish_reason: str | None = None + # response_mask: 记录response_ids中哪个token算loss, 与response_ids长度相同,每轮rollout在 agent_loop.generate 中覆盖写 + response_mask: list[int] | None = None + # response_model_steps:记录 response_ids 中每个 token 来自哪个 model_step,与 response_ids 长度相同。 + response_model_steps: list[int] | None = None + # 记录该样本过期程度,即最早生成 token 的模型版本与当前训练步数的差值,数值越大表示越过期。 + seq_staleness: int = 0 + + # --- Judger 输出 --- + reward: dict[str, Any] | None = None + + # --- 状态 --- uid: int | None = None - reward: dict[str, Any] = Field(default_factory=lambda: {"score": 0.0, "val": 0.0}) - extra_info: dict[str, Any] = dict() - - -class RLAgentDataItem(BaseModel): - # todo: define agent output data structure - model_config = ConfigDict(extra="forbid") - extra_info: dict[str, Any] = dict() - - -class RLEnvDataItem(BaseModel): - """Contains the internal data structures of the environment, stored as an - observation. - - Attributes: - rollout (RLRolloutResponseItem): Data from the rollout stage. - judger (RLJudgerResponseItem): Data from the judger stage. - agent (RLAgentDataItem): Data from the agent stage. - extra_info (Dict[str, Any]): Additional user-defined information. - """ - - model_config = ConfigDict(extra="forbid") - rollout: RLRolloutResponseItem = RLRolloutResponseItem() - judger: RLJudgerResponseItem = RLJudgerResponseItem() - agent: RLAgentDataItem = RLAgentDataItem() - extra_info: dict[str, Any] = dict() - - -class RLExtraDataItem(BaseModel): - """Reserved for data that does not belong to a specific stage of the - dataflow. - - Attributes: - retry_times (int): The number of times the data processing has been retried. - extra_info (Dict[str, Any]): Additional user-defined information. - """ - - model_config = ConfigDict(extra="forbid") - retry_times: int = 0 - extra_info: dict[str, Any] = dict() - - -class RLDataFlowItem(BaseModel): - """The core data structure that flows through the dataflow and environment. - - It encapsulates all information related to a single data point, including its - unique ID, the original data, environment outputs, and extra metadata. - - Attributes: - uid (RLUIDItem): The unique identifier for the data item. - data (RLDatasetItem): The original data from the dataset. - env (RLEnvDataItem): The collected outputs from the environment stages. - extra_info (RLExtraDataItem): Additional reserved information. - """ - - model_config = ConfigDict(extra="forbid") - uid: RLUIDItem = RLUIDItem() - data: RLDatasetItem = RLDatasetItem() - env: RLEnvDataItem = RLEnvDataItem() - extra_info: RLExtraDataItem = RLExtraDataItem() - - -def is_valid_for_replaybuffer(group_data_items: list[RLDataFlowItem]) -> bool: - """Checks if a group of data items is valid for insertion into the replay - buffer. - - Args: - group_data_items: A list of RLDataFlowItem objects. - - Returns: - True if the group is valid, False otherwise. - - NOTE: Why this check is needed: - - For system fault tolerance, this check is performed at rollout / dataflow - time, but we still do it here to ensure replay buffer data integrity. - - 'skipped' or 'failed' states indicate that the rollout process did not - complete successfully or was intentionally bypassed. - - 'aborted' states may still contain useful data for the replay buffer, - as the rollout was started but not finished. - - 'completed' states are valid and should be included in the replay buffer. 
- """ - is_skipped = any(item.env.rollout.state == RolloutState.SKIPPED for item in group_data_items) - is_failed = any(item.env.rollout.state == RolloutState.FAILED for item in group_data_items) - if is_skipped or is_failed: - logger.warning( - "Invalid dataflow group found during replay buffer insertion, skipped: {is_skipped}, failed: {is_failed}." - ) - return False - return True - - -def is_valid_for_training(group_data_items: list[RLDataFlowItem]) -> bool: - """Checks if a group of data items is valid for a training step. - - Args: - group_data_items: A list of RLDataFlowItem objects. - - Returns: - True if the group is valid, False otherwise. - - NOTE: Why this check is needed: - - For system fault tolerance, this check is performed at rollout / dataflow - time, but we still do it here to ensure training data integrity. - - 'skipped'/'failed': These items are fundamentally broken or incomplete and - should not be used for training. - - 'aborted': These items represent rollouts that were stopped - prematurely. Using such partial data could lead the model to learn - undesirable behaviors (e.g., stopping generation too early). - - Empty response/response_ids: The model's generated response is the core - of the training data for RL algorithms like PPO. If the response is - missing, there is nothing to compute rewards on or to train the model with. - """ - is_abort = any(item.env.rollout.state == RolloutState.ABORTED for item in group_data_items) - is_skipped = any(item.env.rollout.state == RolloutState.SKIPPED for item in group_data_items) - is_failed = any(item.env.rollout.state == RolloutState.FAILED for item in group_data_items) - if is_skipped or is_failed or is_abort: - logger.debug( - f"Invalid dataflow group found during training, rollout state skipped: {is_skipped}, failed: {is_failed}, aborted: {is_abort}." - ) - return False - for item in group_data_items: - rollout_info = item.env.rollout - response_valid = True if rollout_info.response is not None and len(rollout_info.response) > 0 else False - ids_valid = True if rollout_info.response_ids is not None and len(rollout_info.response_ids) > 0 else False - if not ids_valid: - # NOTE: `response_ids` is the critical field for token-in-token-out mode, so we ensure it's not empty. - logger.warning( - "Invalid dataflow item found during training: no response or response_ids and skip this item." - ) - return False - if not response_valid: - # NOTE: check valid response string for judger inputs - logger.warning("Invalid dataflow item found during training: empty response string and skip this item.") - return False - return True - - -def update_rollout_item(group_data_items, target_value): - """Update a list of RLDataFlowItem objects by merging another - RLRolloutResponseItem into each item's env.rollout attribute. 
+    task_name: str | None = None
+    status: Status = Status.INIT
+    error_msg: str | None = None
+    position_ids: torch.Tensor | None = None
+    extra_fields: dict[str, Any] = {}
+
+    @field_serializer("routed_experts")
+    def _serialize_routed_experts(self, value: list[int] | RayObjectRef | None) -> list[int] | str | None:
+        """Serialize the routed_experts field:
+
+        - None -> None
+        - list[int] -> list[int] (kept as-is)
+        - RayObjectRef -> base64-encoded string (serialized via ray.cloudpickle)
+        """
+        import ray
+
+        if value is None:
+            return None
+        if isinstance(value, ray.ObjectRef):
+            data = ray.cloudpickle.dumps(value)
+            return base64.b64encode(data).decode("utf-8")
+        return value
+
+    @field_validator("routed_experts", mode="before")
+    @classmethod
+    def _deserialize_routed_experts(cls, value: Any) -> list[int] | RayObjectRef | None:
+        """Deserialize the routed_experts field:
+
+        - None -> None
+        - list[int] -> list[int] (kept as-is)
+        - str (base64-encoded) -> RayObjectRef (deserialized via ray.cloudpickle)
+        - RayObjectRef -> RayObjectRef (kept as-is)
+        """
+        import ray
+
+        if value is None:
+            return None
+        if isinstance(value, ray.ObjectRef):
+            return value
+        if isinstance(value, str):
+            data = base64.b64decode(value)
+            return ray.cloudpickle.loads(data)
+        if isinstance(value, list):
+            return value
+        return value
+
+    @field_serializer("mm_info")
+    def _serialize_mm_info(self, value: MultimodalInfo | None) -> MultimodalInfo | None:
+        # TODO: Not currently needed
+        return None
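Taken together, a `RolloutState` can round-trip through Pydantic serialization; a small sketch (illustrative only; the field serializers import `ray`, so a working install is assumed even when `routed_experts` is a plain list or `None`):

from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status

state = RolloutState(
    message=[{"role": "user", "content": "What is 6 x 8?"}],
    prompt_ids=[1, 2, 3],
    sample_params=SampleParams(max_tokens=64),
)
payload = state.model_dump()                     # mm_info serializes to None, routed_experts passes through
restored = RolloutState.model_validate(payload)
assert restored.status is Status.INIT and restored.prompt_ids == [1, 2, 3]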
+
+
+def update_status_from_finish_reason(finish_reason: str | None) -> Status:
+    """Updates the internal status based on the inference engine's finish
+    reason.
+
+    State Transition Logic:
+    -------------------------------------------------------------
+    | Finish Reason (Input)           | Internal Status (Output) |
+    | :------------------------------ | :----------------------- |
+    | `stop`, `length`, `tool_calls`  | `Status.COMPLETED`       |
+    | `abort`                         | `Status.ABORTED`         |
+    | `error` or `None`               | `Status.FAILED`          |
+    | *Others*                        | `Status.FAILED` (logged) |
+    -------------------------------------------------------------
 
     Args:
-        group_data_items (List[RLDataFlowItem]): List of data items to update.
-        target_value (List[RLRolloutResponseItem]): The rollout response item to merge into each data item.
+        finish_reason (str | None): The raw finish reason string returned by
+            the inference engine (e.g., vLLM, LMDeploy).
 
-    Returns:
-        List[RLDataFlowItem]: The updated list of data items.
-
-    Example:
-        >>> # Suppose you want to update the rollout response for each item
-        >>> items = [RLDataFlowItem(), RLDataFlowItem()]
-        >>> rollout_response = RLRolloutResponseItem(response="new response", response_ids=[1,2,3])
-        >>> update_rollout_item(items, rollout_response)
-        # Now each item's env.rollout has been updated with the new response and response_ids
+    Note:
+        An unknown ``finish_reason`` is logged as an error and mapped to
+        ``Status.FAILED``; no exception is raised.
     """
+    if finish_reason is None:
+        logger.error("finish_reason is None, setting status to FAILED.")
+        return Status.FAILED
+
+    reason = finish_reason.lower()
+    if reason in ("stop", "length", "tool_calls"):
+        return Status.COMPLETED
+    elif reason == "abort":
+        return Status.ABORTED
+    elif reason == "error":
+        logger.warning("finish_reason is 'error', setting status to FAILED.")
+        return Status.FAILED
+    else:
+        logger.error(f"finish_reason '{finish_reason}' is unknown, setting status to FAILED.")
+        return Status.FAILED
+
+
+def update_group_status(rollout_states: list[RolloutState]) -> Status:
+    """Updates the group status based on the individual rollout states.
+
+    Group Status Logic:
+    -------------------------------------------------------------
+    | Individual Rollout States       | Group Status (Output)    |
+    | :------------------------------ | :----------------------- |
+    | All `Status.COMPLETED`          | `Status.COMPLETED`       |
+    | Any `Status.FAILED`             | `Status.FAILED`          |
+    | Any `Status.ABORTED`            | `Status.ABORTED`         |
+    | Any `Status.EXPIRED`            | `Status.EXPIRED`         |
+    | Any `Status.FILTERED`           | `Status.FILTERED`        |
+    | *Others*                        | *Determined by priority* |
+    -------------------------------------------------------------
+
+    Priority Order (from highest to lowest):
+    1. FAILED
+    2. ABORTED
+    3. EXPIRED
+    4. FILTERED
+    5. COMPLETED
+
     Args:
-        group_data_items (List[RLDataFlowItem]): List of data items to update.
-        target_key (str): Dot-separated path to the attribute to update (e.g., 'env.rollout.response').
-        target_value (List[Any]): List of values to set, one for each data item.
+        rollout_states (list[RolloutState]): A list of individual rollout states.
 
     Returns:
-        List[RLDataFlowItem]: The updated list of data items.
-
-    Example:
-        >>> # Suppose you want to update the 'response' field in env.rollout for each item
-        >>> items = [RLDataFlowItem(), RLDataFlowItem()]
-        >>> responses = ["hello", "world"]
-        >>> update_dataflow_item(items, "env.rollout.response", responses)
-        # Now items[0].env.rollout.response == "hello", items[1].env.rollout.response == "world"
+        Status: The aggregated group status based on the individual states.
     """
+    if all(state.status == Status.COMPLETED for state in rollout_states):
+        return Status.COMPLETED
+    elif any(state.status == Status.FAILED for state in rollout_states):
+        return Status.FAILED
+    elif any(state.status == Status.ABORTED for state in rollout_states):
+        return Status.ABORTED
+    elif any(state.status == Status.EXPIRED for state in rollout_states):
+        return Status.EXPIRED
+    elif any(state.status == Status.FILTERED for state in rollout_states):
+        return Status.FILTERED
+    else:
+        # Any remaining statuses would be resolved by the priority order above;
+        # for now we default to COMPLETED when none of the conditions match.
+        return Status.COMPLETED
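A quick illustration of the two helpers above (a sketch under the definitions in this file; message lists are left empty for brevity):

from xtuner.v1.data_proto.rl_data import RolloutState, Status, update_group_status, update_status_from_finish_reason

done = RolloutState(message=[], status=update_status_from_finish_reason("stop"))   # -> COMPLETED
cut = RolloutState(message=[], status=update_status_from_finish_reason("abort"))  # -> ABORTED
# Not all samples completed and none failed, so the ABORTED sample decides the group.
assert update_group_status([done, cut]) is Status.ABORTED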
+
+
+def update_sample_version(rollout_state: RolloutState, model_step: int) -> RolloutState:
+    """Append token source model version for newly generated response
+    tokens."""
+    response_len = len(rollout_state.response_ids or [])
+    response_model_steps = list(getattr(rollout_state, "response_model_steps", None) or [])
+    missing_response_steps = max(0, response_len - len(response_model_steps))
+    if missing_response_steps:
+        response_model_steps.extend([model_step] * missing_response_steps)
+    rollout_state.response_model_steps = response_model_steps
+    return rollout_state
+
+
+def refresh_seq_staleness(group: list[RolloutState], current_train_step: int) -> list[RolloutState]:
+    for rollout_state in group:
+        # response_model_steps records the model version of every response token;
+        # the earliest version determines how stale the whole sample is.
+        response_model_steps = getattr(rollout_state, "response_model_steps", None) or []
+        if response_model_steps:
+            rollout_state.seq_staleness = calculate_seq_staleness(min(response_model_steps), current_train_step)
+        else:
+            rollout_state.seq_staleness = 0
+    return group
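Worked through on a toy sample, the bookkeeping above behaves as follows (a sketch; the step numbers are arbitrary):

from xtuner.v1.data_proto.rl_data import RolloutState, refresh_seq_staleness, update_sample_version

s = RolloutState(message=[], response_ids=[11, 12, 13])
update_sample_version(s, model_step=4)      # all three tokens are tagged with the step-4 model
s.response_ids = s.response_ids + [14]
update_sample_version(s, model_step=5)      # only the newly generated token is tagged with step 5
assert s.response_model_steps == [4, 4, 4, 5]

refresh_seq_staleness([s], current_train_step=6)
assert s.seq_staleness == 6 - 4 - 1         # the earliest token version drives the staleness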
-    group_length = len(group_data_items)
-    assert group_length == len(target_value)
-
-    keys = target_key.split(".")
-
-    for i in range(group_length):
-        parent_obj = group_data_items[i]
-        for key in keys[:-1]:
-            parent_obj = getattr(parent_obj, key)
-        setattr(parent_obj, keys[-1], target_value[i])
-
-    return group_data_items
 
+def update_expired_status(samples: list[RolloutState], stale_threshold: int) -> list[RolloutState]:
+    if stale_threshold <= 0:
+        raise ValueError(f"stale_threshold must be positive, got {stale_threshold}.")
+
+    is_group_expired = False
 
-# ==============================================
-# ====== Rollout API Server 数据流 ==============
-# ==============================================
+    # 1. Check whether the group contains any expired sample
+    for sample in samples:
+        if sample.status == Status.ABORTED and sample.seq_staleness >= stale_threshold:
+            logger.debug(
+                f"Sample {sample.uid} (seq_staleness: {sample.seq_staleness}) exceeded threshold ({stale_threshold}). Triggering group expiration."
+            )
+            is_group_expired = True
+            break  # once an expired sample is found, stop scanning the remaining samples
+    # 2. If an expired sample exists, mark every sample in the group as expired
+    if is_group_expired:
+        # NOTE: Once one sample in a group is marked expired, the group may still contain aborted
+        # samples that have not yet crossed the staleness threshold. Those samples should not be
+        # resumed in later generation rounds either, so mark them all as expired; only then can
+        # preprocess clear their earlier responses.
+        for sample in samples:
+            sample.status = Status.EXPIRED
 
-class SampleParams(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-    n: Annotated[int, Parameter(help="Number of samples to generate.")] = 1
-    top_k: Annotated[
-        int, Parameter(help="The number of highest probability vocabulary tokens to keep for top-k-filtering.")
-    ] = 0
-    top_p: Annotated[float, Parameter(help="The cumulative probability for nucleus sampling.")] = 1.0
-    temperature: Annotated[float, Parameter(help="The value used to module the next token probabilities.")] = 1.0
-    repetition_penalty: Annotated[float, Parameter(help="The parameter for repetition penalty.")] = 1.0
-    presence_penalty: Annotated[float, Parameter(help="The parameter for presence penalty.")] = 0.0
-    frequency_penalty: Annotated[float, Parameter(help="The parameter for frequency penalty.")] = 0.0
-    min_tokens: Annotated[int, Parameter(help="Minimum number of tokens to generate.")] = 0
-    max_tokens: Annotated[int, Parameter(help="Maximum number of tokens to generate.")] = 2048
-    stops: Annotated[list[str], Parameter(help="List of stop sequences.")] = []
-    stop_token_ids: Annotated[list[int], Parameter(help="List of stop token IDs.")] = []
-    skip_special_tokens: Annotated[bool, Parameter(help="Whether to skip special tokens.")] = True
-    sampling_seed: Annotated[int | None, Parameter(help="The seed for random number generator in sampling.")] = None
-
-
-class RolloutExtraParams(TypedDict):
-    stream: bool
-    return_logprob: bool
-    top_logprobs: int
-    return_token_ids: bool
-    include_stop_str_in_output: bool
-    no_stop_trim: bool
-    skip_special_tokens: bool
-    spaces_between_special_tokens: bool
-
-
-# 说明: 这里没定义API server情况数据格式,因为直接使用openai server的格式
-class RLRolloutRequestItem(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-    messages: list[dict[str, Any]]
-    tools: list = Field(default_factory=list)
-    tool_choice: str = "auto"
-    sample_params: SampleParams = Field(default_factory=SampleParams)
-    extra_params: dict[str, Any] = Field(default_factory=dict)
+    return samples
diff --git a/xtuner/v1/data_proto/utils.py b/xtuner/v1/data_proto/utils.py
index a65b2d6e6d..71ce473084 100644
--- a/xtuner/v1/data_proto/utils.py
+++ b/xtuner/v1/data_proto/utils.py
@@ -11,6 +11,11 @@
 # TODO: (yehaochen) Missing typehint here
+def calculate_seq_staleness(model_step: int, current_train_step: int) -> int:
+    # model_step refers to the model obtained after that train_step; in the fully
+    # synchronous case, current_train_step is exactly one step ahead.
+    return current_train_step - model_step - 1
+
+
 def pad_to_multiple_of(sequence, padding_value, multiple_of, dim=-1):
     length = sequence.shape[dim]
     if length % multiple_of == 0:
diff --git a/xtuner/v1/datasets/__init__.py b/xtuner/v1/datasets/__init__.py
index 94a8f52a79..4f28d18955 100644
--- a/xtuner/v1/datasets/__init__.py
+++ b/xtuner/v1/datasets/__init__.py
@@ -25,7 +25,7 @@
     PretrainTokenizeFunction,
     PretrainTokenizeFunctionConfig,
 )
-from .rl_tokenize_fn import RLTokenizeFnConfig
+from .rl_tokenize_fn import RLTextTokenizeFnConfig
 from .sampler import LengthGroupedSampler, ParallelSampler
 from .sft_tokenize_fn import OpenaiTokenizeFunction, OpenaiTokenizeFunctionConfig
 from .utils import CachableTokenizeFunction, calculate_file_sha256, calculate_xxhash, tokenizer_hash
@@ -56,6 +56,7 @@
     "build_datasets",
     "build_dataloader",
     "sft_llm_collator",
+    "fake_collator",
     "intern_s1_vl_sft_collator",
"qwen3_vl_sft_collator", "FtdpTokenizeFunction", @@ -65,7 +66,6 @@ "VLMJsonlDataset", "FTDPTokenizeFnConfig", "InternS1VLTokenizeFnConfig", - "fake_collator", "RLTokenizeFnConfig", "DatasetConfigList", "DataloaderConfig", diff --git a/xtuner/v1/datasets/_hardcode_patch.py b/xtuner/v1/datasets/_hardcode_patch.py index de223e7a3d..6c83a9ea08 100644 --- a/xtuner/v1/datasets/_hardcode_patch.py +++ b/xtuner/v1/datasets/_hardcode_patch.py @@ -26,9 +26,8 @@ from xtuner.v1.utils import get_logger from .ftdp import FtdpTokenizeFunction -from .mllm_tokenize_fn import Qwen3VLTokenizeFunction +from .mllm_tokenize_fn import InternS1VLTokenizeFunction, Qwen3VLTokenizeFunction from .pt_tokenize_fn import PretrainTokenizeFunction -from .rl_tokenize_fn.rl_tokenize_fn import InternS1VLTokenizeFunction from .sft_tokenize_fn import OpenaiTokenizeFunction diff --git a/xtuner/v1/datasets/config.py b/xtuner/v1/datasets/config.py index f9783369f4..bd8a0993de 100644 --- a/xtuner/v1/datasets/config.py +++ b/xtuner/v1/datasets/config.py @@ -355,7 +355,7 @@ def build_collator(self): elif self.collator == "qwen3_vl_sft_collator": return qwen3_vl_sft_collator elif self.collator == "fake_collator": - return fake_collator # for RL + return fake_collator else: collator = pydoc.locate(self.collator) if collator is None: diff --git a/xtuner/v1/datasets/data_item.py b/xtuner/v1/datasets/data_item.py index ddd0860eed..fbd900599a 100644 --- a/xtuner/v1/datasets/data_item.py +++ b/xtuner/v1/datasets/data_item.py @@ -1,11 +1,6 @@ import torch -from typing_extensions import NotRequired, TypedDict - -class CacheItem(TypedDict): - num_tokens: int - num_img_tokens: NotRequired[list[int]] - proxy_attn_flops: NotRequired[float] +from xtuner.v1.data_proto.cache_item import CacheItem class DataItem(CacheItem): diff --git a/xtuner/v1/datasets/jsonl.py b/xtuner/v1/datasets/jsonl.py index 8d95d716b2..1a5ac403b0 100644 --- a/xtuner/v1/datasets/jsonl.py +++ b/xtuner/v1/datasets/jsonl.py @@ -17,18 +17,18 @@ from multiprocessing import Process, Queue from pathlib import Path from threading import Lock -from typing import Callable, Dict, TypeVar, cast +from typing import Any, Callable, Dict, TypeVar, cast import numpy as np import torch from mmengine import mkdir_or_exist from mmengine.dist import barrier, get_rank +from pydantic import BaseModel from torch import distributed as dist from tqdm import tqdm from xtuner.v1.datasets.data_item import CacheItem from xtuner.v1.datasets.pt_tokenize_fn.long_text import LongTextPretrainTokenizeFunction -from xtuner.v1.datasets.rl_tokenize_fn.rl_tokenize_fn import RLTokenizeFn from xtuner.v1.utils import SharedMemory, get_logger from xtuner.v1.utils.dist_utils import get_local_process_group, get_local_world_size, is_local_rank0 @@ -439,11 +439,7 @@ def __init__( ################################## Post-processing of offsets, num_tokens and _meta ####################################### tok_hash_str = "" - if isinstance( - tokenize_fn, RLTokenizeFn - ): # RLTokenizeFn is CachableTokenizeFunction, but it does not have a hash method - tok_hash_str = "RLTokenizeFn" - elif isinstance(tokenize_fn, CachableTokenizeFunction): + if isinstance(tokenize_fn, CachableTokenizeFunction): tok_hash_str = tokenize_fn.hash() job_discriminator = os.environ.get("MASTER_PORT", "") @@ -624,14 +620,32 @@ def count_offsets(self, cache_dir=None): @staticmethod def _tokenize_by_offset( data: bytes, - tokenize_fn: Callable[[dict], CacheItem], + tokenize_fn: Callable[[dict], CacheItem | BaseModel], ) -> dict: line = data.decode() - 
-        tokenized: dict = tokenize_fn(json.loads(line))  # type: ignore[assignment]
-        res = {"num_tokens": tokenized["num_tokens"], "proxy_attn_flops": tokenized["proxy_attn_flops"]}
-        if "chunks" in tokenized:
-            res["chunks"] = tokenized["chunks"]
-        return res
+        tokenized = tokenize_fn(json.loads(line))
+        if isinstance(tokenized, dict):
+            res = {"num_tokens": tokenized["num_tokens"], "proxy_attn_flops": tokenized["proxy_attn_flops"]}
+            if "chunks" in tokenized:
+                tokenized = cast(dict[str, Any], tokenized)
+                res["chunks"] = tokenized["chunks"]
+            return res
+        if isinstance(tokenized, BaseModel):
+            # RL tokenize functions return RolloutState, a Pydantic model,
+            # during cache building. Extract cache metadata here so those
+            # tokenizers do not need to return a separate CacheItem dict.
+            num_tokens = getattr(tokenized, "num_tokens", None)
+            if num_tokens is None:
+                raise TypeError(f"{type(tokenized).__name__} must provide `num_tokens` for dataset cache.")
+            proxy_attn_flops = getattr(tokenized, "proxy_attn_flops", None)
+            if proxy_attn_flops is None:
+                proxy_attn_flops = float(num_tokens)
+            res = {"num_tokens": num_tokens, "proxy_attn_flops": proxy_attn_flops}
+            chunks = getattr(tokenized, "chunks", None)
+            if chunks is not None:
+                res["chunks"] = chunks
+            return res
+        raise TypeError(f"{type(tokenized).__name__} must be a CacheItem-like dict or a Pydantic model.")
 
     def count_tokens(self, offsets, cache_dir=None):
         self.tokenize_fn.set_state("cache")
diff --git a/xtuner/v1/datasets/resume.py b/xtuner/v1/datasets/resume.py
new file mode 100644
index 0000000000..9c73111ec6
--- /dev/null
+++ b/xtuner/v1/datasets/resume.py
@@ -0,0 +1,50 @@
+from torch.utils.data import DataLoader
+from typing_extensions import TypedDict
+
+from xtuner.v1.utils import get_logger
+
+from .packing import ExpandSoftPackDataset, _LegacySoftPackDataset
+from .sampler import LengthGroupedSampler, ParallelSampler
+
+
+logger = get_logger()
+
+
+class DataloaderState(TypedDict):
+    sampler: dict
+    dataset: dict
+
+
+def get_dataloader_state(dataloader: DataLoader, consumed_samples: int) -> DataloaderState:
+    sampler: ParallelSampler | LengthGroupedSampler = dataloader.sampler  # type: ignore[assignment]
+    dataset: ExpandSoftPackDataset | _LegacySoftPackDataset = dataloader.dataset  # type: ignore[assignment]
+    dataloader_state = DataloaderState(sampler={}, dataset={})
+
+    if not hasattr(sampler, "load_state_dict") or not hasattr(sampler, "get_state_dict"):
+        logger.warning(f"Resuming from {type(sampler)} is risky.")
+    else:
+        dataloader_state["sampler"].update(sampler.get_state_dict(total_consumed_steps=consumed_samples))
+
+    if not hasattr(dataset, "load_state_dict") or not hasattr(dataset, "get_state_dict"):
+        logger.warning(f"Resuming from {type(dataset)} is risky.")
+    else:
+        dataloader_state["dataset"].update(dataset.get_state_dict())
+
+    return dataloader_state
+
+
+def load_dataloader_state(dataloader: DataLoader, state: dict):
+    sampler = dataloader.sampler
+    dataset = dataloader.dataset
+
+    # The sampler requires `load_state_dict` to restore training progress, since the sampler state
+    # records the consumed samples.
+    if not hasattr(sampler, "load_state_dict"):
+        logger.warning(f"Resuming from {type(sampler)} is risky.")
+
+    if hasattr(sampler, "load_state_dict"):
+        sampler.load_state_dict(state["sampler"])
+
+    # If the dataset records the training progress, we also restore it.
+    if hasattr(dataset, "load_state_dict"):
+        dataset.load_state_dict(state["dataset"])
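The intended save/restore cycle looks roughly like this (a sketch; `dataloader` is assumed to be built with one of the samplers/datasets named above, and the checkpoint path is illustrative):

import torch
from xtuner.v1.datasets.resume import get_dataloader_state, load_dataloader_state

state = get_dataloader_state(dataloader, consumed_samples=120)
torch.save(state, "dataloader_state.pt")          # persisted alongside the model checkpoint

# ... later, after rebuilding an identical dataloader ...
load_dataloader_state(dataloader, torch.load("dataloader_state.pt"))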
diff --git a/xtuner/v1/datasets/rl_tokenize_fn/__init__.py b/xtuner/v1/datasets/rl_tokenize_fn/__init__.py
index 2ecf6f3f61..83eb5f8b7f 100644
--- a/xtuner/v1/datasets/rl_tokenize_fn/__init__.py
+++ b/xtuner/v1/datasets/rl_tokenize_fn/__init__.py
@@ -1,6 +1,5 @@
-from .rl_tokenize_fn import RLTokenizeFnConfig
+from .qwen3_vl_tokenize_fn import RLQwen3VLTokenizeFnConfig
+from .text_tokenize_fn import RLTextTokenizeFnConfig
 
-__all__ = [
-    "RLTokenizeFnConfig",
-]
+__all__ = ["RLTextTokenizeFnConfig", "RLQwen3VLTokenizeFnConfig"]
diff --git a/xtuner/v1/datasets/rl_tokenize_fn/qwen3_vl_tokenize_fn.py b/xtuner/v1/datasets/rl_tokenize_fn/qwen3_vl_tokenize_fn.py
new file mode 100644
index 0000000000..9ee98b5356
--- /dev/null
+++ b/xtuner/v1/datasets/rl_tokenize_fn/qwen3_vl_tokenize_fn.py
@@ -0,0 +1,122 @@
+from typing import cast
+
+from xtuner.v1.data_proto.rl_data import RolloutState
+
+from ...data_proto.rl_data import MultimodalInfo
+from ..mllm_tokenize_fn.qwen3_vl_tokenize_fn import Qwen3VLTokenizeFnConfig, Qwen3VLTokenizeFunction, QwenVL3DataItem
+from ..utils import replace_image_context_and_collect_media_data
+
+
+def remove_consecutive_img_context_tokens(tokens: list[int], img_context_id: int) -> list[int]:
+    if not tokens:
+        return tokens
+
+    new_tokens = [tokens[0]]
+    for i in range(1, len(tokens)):
+        if tokens[i] == img_context_id and tokens[i - 1] == img_context_id:
+            continue  # skip consecutive occurrences of img_context_id
+        else:
+            new_tokens.append(tokens[i])
+    return new_tokens
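The helper collapses each run of image-context placeholders down to a single token, leaving everything else untouched; for example (a sketch; the token id is hypothetical):

from xtuner.v1.datasets.rl_tokenize_fn.qwen3_vl_tokenize_fn import remove_consecutive_img_context_tokens

IMG = 151655  # hypothetical image-context token id
assert remove_consecutive_img_context_tokens([101, IMG, IMG, IMG, 42, IMG, 7], IMG) == [101, IMG, 42, IMG, 7]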
+
+
+class RLQwen3VLTokenizeFunction(Qwen3VLTokenizeFunction):
+    def __init__(self, *args, ignore_multimodal_info: bool = False, data_judger_mapping: dict | None = None, **kwargs):
+        self.ignore_multimodal_info = ignore_multimodal_info
+        self.data_judger_mapping = data_judger_mapping
+        super().__init__(*args, **kwargs)
+
+    # TODO: tool call
+    def __call__(self, item: dict, media_root: str = "", **kwargs) -> RolloutState:
+        extra_info = item.get("extra_info", {})
+        message = item["prompt"]
+        system_prompt = getattr(self, "system_prompt", None)
+        if system_prompt:
+            if message[0]["role"] == "system":
+                message = message[1:]
+            message = [{"role": "system", "content": system_prompt}] + message
+
+        data = super().__call__({"messages": message}, media_root=media_root)
+
+        if self.state == "cache":
+            return RolloutState(
+                message=message,
+                num_tokens=data["num_tokens"],
+                proxy_attn_flops=data.get("proxy_attn_flops", float(data["num_tokens"])),
+            )
+        else:
+            data = cast(QwenVL3DataItem, data)
+            image_data, _ = replace_image_context_and_collect_media_data(message, media_root, True)
+            if image_data:
+                extra_info["image_data"] = image_data
+
+            # The SFT tokenize fn may not be fully aligned with the jinja template used by
+            # apply_chat_template, especially for the system prompt. To stay consistent, the
+            # prompt_token_ids must be obtained through the tokenizer_fn itself.
+            # raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+            # prompt_token_ids = self.tokenizer(raw_prompt, add_special_tokens=False)["input_ids"]
+            prompt_token_ids = remove_consecutive_img_context_tokens(data["input_ids"], self.img_context_token_id)
+            raw_prompt = self.tokenizer.decode(prompt_token_ids)  # Just for logging
+            extra_info["raw_prompt"] = raw_prompt
+            # prompt token ids used at training time, including the consecutive img_context_token_id runs
+            extra_info["train_prompt_ids"] = data["input_ids"]
+
+            mm_info = None
+            if not self.ignore_multimodal_info:
+                mm_info = MultimodalInfo()
+                if "pixel_values" in data:
+                    # converted to numpy so Ray can put it into shared memory
+                    mm_info["pixel_values"] = data["pixel_values"].numpy()
+                if "image_grid_thw" in data:
+                    mm_info["image_grid_thw"] = data["image_grid_thw"]
+
+            data_source = item.get("data_source")
+            assert data_source is not None, "data_source is required in item"
+            extra_info["origin_data_source"] = data_source
+            data_judger_mapping = getattr(self, "data_judger_mapping", None)
+            if data_judger_mapping is not None:
+                mapped_judger_name_and_weight = data_judger_mapping.get(data_source)
+            else:
+                mapped_judger_name_and_weight = {data_source: 1.0}
+
+            return RolloutState(
+                message=message,
+                num_tokens=data["num_tokens"],
+                proxy_attn_flops=data.get("proxy_attn_flops", float(data["num_tokens"])),
+                prompt_ids=prompt_token_ids,
+                position_ids=data["position_ids"],
+                data_source=mapped_judger_name_and_weight,
+                reward_model=item.get("reward_model", {}),
+                mm_info=mm_info,
+                extra_fields=extra_info,
+            )
+
+    def hash(self) -> str:
+        return "RLQwen3VLTokenizeFunction"
+
+
+class RLQwen3VLTokenizeFnConfig(Qwen3VLTokenizeFnConfig):
+    ignore_multimodal_info: bool = False  # set True for evaluation
+
+    def build(
+        self, tokenizer, tokenizer_hash: str | None = None, anno_name: str = "", **kwargs
+    ) -> RLQwen3VLTokenizeFunction:
+        return RLQwen3VLTokenizeFunction(
+            tokenizer,
+            self.processor_path,
+            anno_name,
+            chat_template=self.chat_template,
+            min_pixels=self.min_pixels,
+            max_pixels=self.max_pixels,
+            oss_loader_cfg=self.oss_loader_cfg,
+            video_min_total_pixels=self.video_min_total_pixels,
+            video_max_total_pixels=self.video_max_total_pixels,
+            video_min_frames=self.video_min_frames,
+            video_max_frames=self.video_max_frames,
+            rand_video_max_frames=self.rand_video_max_frames,
+            fps=self.fps,
+            enable_3d_rope=self.enable_3d_rope,
+            add_vision_id=self.add_vision_id,
+            max_length=self.max_length,
+            system_message=self.system_message,
+            tokenizer_hash=tokenizer_hash,
+            ignore_multimodal_info=self.ignore_multimodal_info,
+        )
diff --git a/xtuner/v1/datasets/rl_tokenize_fn/rl_tokenize_fn.py b/xtuner/v1/datasets/rl_tokenize_fn/rl_tokenize_fn.py
deleted file mode 100644
index 1dfbca52aa..0000000000
--- a/xtuner/v1/datasets/rl_tokenize_fn/rl_tokenize_fn.py
+++ /dev/null
@@ -1,197 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import cast - -from pydantic import BaseModel, ConfigDict - -from transformers import PreTrainedTokenizer -from xtuner.v1.data_proto.rl_data import RLDatasetItem -from xtuner.v1.utils import get_logger - -from ..data_item import OmniDataItem -from ..mllm_tokenize_fn.intern_s1_vl_tokenize_fn import InternS1VLTokenizeFunction -from ..mllm_tokenize_fn.qwen3_vl_tokenize_fn import Qwen3VLTokenizeFunction -from ..utils import ( - CachableTokenizeFunction, - replace_image_context_and_collect_media_data, -) - - -logger = get_logger() - - -def remove_consecutive_twos(tokens, img_context_id): - if not tokens: - return tokens - - new_tokens = [tokens[0]] - for i in range(1, len(tokens)): - if tokens[i] == img_context_id and tokens[i - 1] == img_context_id: - continue # 跳过连续的 img_context_id - else: - new_tokens.append(tokens[i]) - return new_tokens - - -class RLTokenizeFn(CachableTokenizeFunction[RLDatasetItem]): - def __init__( - self, - tokenizer_fn: CachableTokenizeFunction | None, - tokenizer: PreTrainedTokenizer, - max_length: int | None = None, - ignore_multimodal_info: bool = False, - data_judger_mapping: dict | None = None, - system_prompt: str | None = None, - ): - super().__init__(tokenizer) - self.tokenizer_fn = tokenizer_fn - self.max_length = max_length - - self.img_context_id = None - self.ignore_multimodal_info = ignore_multimodal_info - self.data_judger_mapping = data_judger_mapping - self.system_prompt = system_prompt - self.model_name = "default" - if self.tokenizer_fn: - if isinstance(self.tokenizer_fn, Qwen3VLTokenizeFunction): - self.model_name = "qwen3_vl" - elif isinstance(self.tokenizer_fn, InternS1VLTokenizeFunction): - self.model_name = "intern_s1_vl" - else: - raise ValueError(f"Unsupported tokenizer_fn type: {type(self.tokenizer_fn)}") - self.img_context_id = tokenizer.convert_tokens_to_ids(self.tokenizer_fn.chat_template.image_context_token) - - def __call__(self, item: dict, **kwargs) -> RLDatasetItem: - """example: - item = { - "data_source": data_source, - "prompt": [ - { - "role": "user", - "content": question, - } - ], - "ability": "math", - "reward_model": {"style": "rule", "ground_truth": solution}, - "extra_info": { - "split": split, - "index": idx, - "answer": answer_raw, - "question": question_raw, - }, - } - """ - - extra_info = item.get("extra_info", {}) - messages = item["prompt"] - if self.system_prompt: - if messages[0]["role"] == "system": - messages = messages[1:] - messages = [{"role": "system", "content": self.system_prompt}] + messages - if self.tokenizer_fn is None: - # pure text - raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) - data = self.tokenizer(raw_prompt, add_special_tokens=False) - prompt_token_ids = data["input_ids"] - num_tokens = len(data["input_ids"]) - else: - # mllm - self.tokenizer_fn.state = self.state - data = self.tokenizer_fn({"messages": messages}, **kwargs) - data = cast(OmniDataItem, data) - num_tokens = data["num_tokens"] - - media_root = kwargs.get("media_root", "") - if self.model_name == "qwen3_vl": - image_data, _ = replace_image_context_and_collect_media_data(messages, media_root, True) - elif self.model_name == "intern_s1_vl": - image_data, _ = replace_image_context_and_collect_media_data(messages, media_root, False) - else: - raise ValueError(f"Unsupported model_name: {self.model_name}") - if image_data: - extra_info["image_data"] = image_data - - # 不能用下面的逻辑得到 rollout 的 prompt_token_ids - # 因为 sft tokenizer fn 可能并没有完全和 apply_chat_template 中的 jinja 
模块对齐,特别是 system prompt - # 为了确保一致,必须要通过 tokenizer_fn 得到 prompt_token_ids - # raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) - # prompt_token_ids = self.tokenizer(raw_prompt, add_special_tokens=False)["input_ids"] - if self.state != "cache": - prompt_token_ids = remove_consecutive_twos(data["input_ids"], self.img_context_id) - else: - prompt_token_ids = [1] # Just a placeholder - raw_prompt = self.tokenizer.decode(prompt_token_ids) # Just for logging - - multimodal_train_info = {} - extra_info["raw_prompt"] = raw_prompt - - if self.state == "cache": - if self.max_length is not None and num_tokens > self.max_length: - num_tokens = 0 # will be filtered out by the dataset filter - else: - if self.max_length is not None: - assert num_tokens <= self.max_length, f"num_tokens {num_tokens} > max_length {self.max_length}" - if not self.ignore_multimodal_info: - if "pixel_values" in data: - multimodal_train_info["pixel_values"] = data["pixel_values"] - if "image_flags" in data: - multimodal_train_info["image_flags"] = data["image_flags"] # intern-s1 or intern-vl - if "image_grid_thw" in data: - multimodal_train_info["image_grid_thw"] = data["image_grid_thw"] # qwen3-vl - if "position_ids" in data: - multimodal_train_info["position_ids"] = data["position_ids"] # qwen3-vl - - # 在多模态场景下,训练和 rollout 的 prompt ids 是不一样的 - # 为了统一训练处理逻辑,额外保存 train_prompt_ids - extra_info["train_prompt_ids"] = data["input_ids"] - - data_source = item.get("data_source") - assert data_source is not None, "data_source is required in item" - extra_info["origin_data_source"] = data_source - if self.data_judger_mapping is not None: - mapped_judger_name_and_weight = self.data_judger_mapping.get(data_source) - else: - mapped_judger_name_and_weight = {data_source: 1.0} - rl_out_data = { - "messages": messages, - "input_ids": prompt_token_ids, - "num_tokens": num_tokens, - "proxy_attn_flops": float(num_tokens), # unused for RL. 
for comatibility of jsonldataset - "reward_model": item["reward_model"], - "ability": item.get("ability", None), - "data_source": mapped_judger_name_and_weight, - "extra_info": extra_info, - "multimodal_train_info": multimodal_train_info, - } - return rl_out_data # type: ignore - - def hash(self) -> str: - raise ValueError("不应该触发这个方法, 因为 RLTokenizeFn 不需要缓存。") - - -class RLTokenizeFnConfig(BaseModel): - model_config = ConfigDict(title="Base RL dataset config for xtuner", extra="forbid") - tokenize_fn_cfg: BaseModel | None = None - max_length: int | None = None - ignore_multimodal_info: bool = False # eval is True - data_judger_mapping: dict | None = None # {data_source: (judger_name, judger_weight)} - system_prompt: str | None = None - - def build( - self, tokenizer: PreTrainedTokenizer, tokenizer_hash: str | None = None, anno_name: str | None = None, **kwargs - ) -> RLTokenizeFn: - tokenizer_fn = None - if self.tokenize_fn_cfg: - tokenizer_fn = self.tokenize_fn_cfg.build( - tokenizer=tokenizer, - tokenizer_hash=tokenizer_hash, - anno_name=anno_name, - **kwargs, - ) - return RLTokenizeFn( - tokenizer_fn, - tokenizer=tokenizer, - max_length=self.max_length, - ignore_multimodal_info=self.ignore_multimodal_info, - data_judger_mapping=self.data_judger_mapping, - system_prompt=self.system_prompt, - ) diff --git a/xtuner/v1/datasets/rl_tokenize_fn/text_tokenize_fn.py b/xtuner/v1/datasets/rl_tokenize_fn/text_tokenize_fn.py new file mode 100644 index 0000000000..4ad576f635 --- /dev/null +++ b/xtuner/v1/datasets/rl_tokenize_fn/text_tokenize_fn.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pydantic import BaseModel, ConfigDict + +from transformers import PreTrainedTokenizer +from xtuner.v1.data_proto.rl_data import RolloutState +from xtuner.v1.utils import get_logger + +from ..utils import CachableTokenizeFunction + + +logger = get_logger() + + +class RLTextTokenizeFn(CachableTokenizeFunction[RolloutState]): + def __init__( + self, + tokenizer: PreTrainedTokenizer, + max_length: int | None = None, + tools_schema: list | None = None, + data_judger_mapping: dict | None = None, + system_prompt: str | None = None, + ): + super().__init__(tokenizer) + self.max_length = max_length + self.tools_schema = tools_schema if tools_schema is not None else [] + self.data_judger_mapping = data_judger_mapping + self.system_prompt = system_prompt + + def __call__(self, item: dict, **kwargs) -> RolloutState: + """example: + item = { + "data_source": data_source, + "prompt": [ + { + "role": "user", + "content": question, + } + ], + "ability": "math", + "reward_model": {"style": "rule", "ground_truth": solution}, + "extra_info": { + "split": split, + "index": idx, + "answer": answer_raw, + "question": question_raw, + }, + } + """ + + extra_info = item.get("extra_info", {}) + message = item["prompt"] + + if self.system_prompt: + if message[0]["role"] == "system": + message = message[1:] + message = [{"role": "system", "content": self.system_prompt}] + message + raw_prompt = self.tokenizer.apply_chat_template( + message, tools=self.tools_schema, add_generation_prompt=True, tokenize=False + ) + extra_info["raw_prompt"] = raw_prompt + data = self.tokenizer(raw_prompt, add_special_tokens=False) + prompt_token_ids = data["input_ids"] + num_tokens = len(data["input_ids"]) + + if self.state == "cache": + if self.max_length is not None and num_tokens > self.max_length: + num_tokens = 0 # will be filtered out by the dataset filter + else: + if self.max_length is not None: + assert num_tokens <= 
self.max_length, f"num_tokens {num_tokens} > max_length {self.max_length}" + + mapped_judger_name_and_weight = None + if self.state != "cache": + data_source = item.get("data_source") + assert data_source is not None, "data_source is required in item" + extra_info["origin_data_source"] = data_source + if self.data_judger_mapping is not None: + mapped_judger_name_and_weight = self.data_judger_mapping.get(data_source) + else: + mapped_judger_name_and_weight = {data_source: 1.0} + + rollout_state = RolloutState( + prompt_ids=prompt_token_ids, + message=message, + reward_model=item.get("reward_model", {}), + num_tokens=num_tokens, + proxy_attn_flops=float(num_tokens), + data_source=mapped_judger_name_and_weight, + extra_fields=extra_info, + ) + return rollout_state + + def hash(self) -> str: + return "RLTextTokenizeFn" + + +class RLTextTokenizeFnConfig(BaseModel): + model_config = ConfigDict(title="Text RL dataset config for xtuner", extra="forbid") + max_length: int | None = None + tools_schema: list | None = None + + def build(self, tokenizer: PreTrainedTokenizer, **kwargs) -> RLTextTokenizeFn: + return RLTextTokenizeFn(tokenizer=tokenizer, max_length=self.max_length, tools_schema=self.tools_schema) diff --git a/xtuner/v1/loss/ce_loss.py b/xtuner/v1/loss/ce_loss.py index 4c29f58d1d..0d1ad15ccd 100644 --- a/xtuner/v1/loss/ce_loss.py +++ b/xtuner/v1/loss/ce_loss.py @@ -8,11 +8,11 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.nn.functional import all_reduce -from xtuner.v1.loss import BaseLossConfig, BaseLossContext, BaseLossKwargs -from xtuner.v1.loss.chunk_loss import ChunkLoss from xtuner.v1.utils.device import get_device # from xtuner.v1.profiler.prober import ProberList +from .base_loss_ctx import BaseLossConfig, BaseLossContext, BaseLossKwargs +from .chunk_loss import ChunkLoss from .utils import sp_gather, sp_split diff --git a/xtuner/v1/loss/rl_loss.py b/xtuner/v1/loss/rl_loss.py index c603d406f4..b6d15291d1 100644 --- a/xtuner/v1/loss/rl_loss.py +++ b/xtuner/v1/loss/rl_loss.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from torch.distributed.device_mesh import DeviceMesh -from xtuner.v1.rl.utils import gather_logprobs +from xtuner.v1.rl.utils.misc import gather_logprobs from xtuner.v1.utils.device import get_device from .base_loss_ctx import BaseLossConfig, BaseLossContext, BaseLossKwargs diff --git a/xtuner/v1/module/lm_head/lm_head.py b/xtuner/v1/module/lm_head/lm_head.py index 3e8e48350f..67e912d47c 100644 --- a/xtuner/v1/module/lm_head/lm_head.py +++ b/xtuner/v1/module/lm_head/lm_head.py @@ -6,7 +6,7 @@ from torch.distributed.tensor import DTensor from typing_extensions import overload -from xtuner.v1.loss import LMHeadLossContext +from xtuner.v1.loss.ce_loss import LMHeadLossContext Loss: TypeAlias = torch.Tensor diff --git a/xtuner/v1/ray/__init__.py b/xtuner/v1/ray/__init__.py deleted file mode 100644 index 1d2ccfb3c9..0000000000 --- a/xtuner/v1/ray/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .base import AcceleratorResourcesConfig, AutoAcceleratorWorkers, SingleAcceleratorWorker -from .utils import ( - find_master_addr_and_port, - get_accelerator_ids, - get_ray_accelerator, - load_function, -) diff --git a/xtuner/v1/ray/base/__init__.py b/xtuner/v1/ray/base/__init__.py deleted file mode 100644 index c66e30c7fa..0000000000 --- a/xtuner/v1/ray/base/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .accelerator import AcceleratorResourcesConfig, AutoAcceleratorWorkers, SingleAcceleratorWorker -from .cpu import AutoCPUWorkers, BaseCPUWorker, 
CPUResourcesConfig diff --git a/xtuner/v1/ray/base/cpu.py b/xtuner/v1/ray/base/cpu.py deleted file mode 100644 index ca08f5d859..0000000000 --- a/xtuner/v1/ray/base/cpu.py +++ /dev/null @@ -1,191 +0,0 @@ -from typing import Any, Dict, TypeVar - -import ray -from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict, field_validator -from ray.util.placement_group import VALID_PLACEMENT_GROUP_STRATEGIES, PlacementGroup, placement_group -from typing_extensions import Annotated - - -PG_READY_TIMEOUT = 30 # seconds -T = TypeVar("T") - - -class CPUResourcesConfig(BaseModel): - """Configuration for CPU resources in a placement group for XTuner. - - This class provide specific configuration options for CPU-based workers in Ray placement groups. - - Args: - num_cpus_per_worker (float): Number of CPUs to allocate per worker in the - placement group. Defaults to 8. - cpu_memory_per_worker (int): Amount of CPU memory (in bytes) to allocate - for each worker in the placement group. - num_workers (int): Total number of workers in the placement group. - """ - - model_config = ConfigDict(extra="forbid") - num_workers: Annotated[int, Parameter(help="Number of workers in the placement group.")] = 1 - num_cpus_per_worker: Annotated[float, Parameter(help="Number of CPUs to allocate for the placement group.")] = 1 - cpu_memory_per_worker: Annotated[ - int, Parameter(help="Amount of memory (in bytes) to allocate for the placement group.") - ] = 1024**3 # 1 GB - pg_pack_strategy: Annotated[ - str, - Parameter(help="Placement group packing strategy, options: " + ", ".join(VALID_PLACEMENT_GROUP_STRATEGIES)), - ] = "SPREAD" - - @field_validator("pg_pack_strategy") - @classmethod - def check_pg_pack_strategy(cls, v): - if v not in VALID_PLACEMENT_GROUP_STRATEGIES: - raise ValueError(f"pg_pack_strategy must be one of {VALID_PLACEMENT_GROUP_STRATEGIES}") - return v - - def model_post_init(self, __context: Any) -> None: - assert ray.is_initialized(), "Ray must be initialized before creating CPUResourcesConfig." - available_resources = ray.available_resources() - available_cpus = available_resources.get("CPU", 0) - available_memory = available_resources.get("memory", 0) - # TODO: manage single controller's cpu resource to replace "10" here - needed_cpus = (self.num_cpus_per_worker * self.num_workers) + 10 - assert needed_cpus <= available_cpus, ( - f"Not enough available CPUs in Ray cluster, available_cpus is {available_cpus} but xtuner needs {needed_cpus}." - ) - needed_memory = self.cpu_memory_per_worker * self.num_workers + 10 * 1024**3 - assert needed_memory <= available_memory, ( - f"Not enough available memory in Ray cluster, available_memory is {available_memory} but xtuner needs {needed_memory}." - ) - # TODO: check all resources sum in cluster to avoid over allocation - - @classmethod - def from_total( - cls, total_cpus: float | int, total_memory: int, num_workers: int, pg_pack_strategy: str = "SPREAD" - ): - """Create a CPUResourcesConfig from total CPU and memory resources. - - Args: - total_cpus (float | int): Total number of CPUs to allocate across all workers. - total_memory (int): Total amount of memory (in bytes) to allocate across all workers. - num_workers (int): Number of workers in the placement group. - - Returns: - CPUResourcesConfig: The created CPUResourcesConfig object. - """ - assert num_workers > 0, "Number of workers must be positive." 
- return cls( - num_workers=num_workers, - num_cpus_per_worker=total_cpus / num_workers, - cpu_memory_per_worker=total_memory / num_workers, - pg_pack_strategy=pg_pack_strategy, - ) - - -class BaseCPUWorker: - """The BaseCPUWorker class serves as a foundational structure for CPU-based - workers within the XTuner framework. - - This class is designed to be extended by specific CPU worker implementations. - It provides a constructor that accepts a configuration object, allowing - subclasses to initialize with custom settings. - - Args: - config: The configuration object for the CPU worker. - num_cpus (float | int): The number of CPUs allocated to this worker. - Defaults to 1. - """ - - def __init__(self, config, num_cpus: float | int = 1): - self.config = config - self.num_cpus = num_cpus - - -class AutoCPUWorkers: - """A utility class for automatically creating and managing cpu actors - within a Ray PlacementGroup.""" - - @staticmethod - def build_placement_group(resources_config: CPUResourcesConfig): - """Build a Ray PlacementGroup based on the provided resource - configuration. - - Args: - resources_config (CPUResourcesConfig): The configuration - specifying the resources for each worker bundle. - - Returns: - PlacementGroup: The created Ray PlacementGroup. - """ - bundles = [ - { - "CPU": resources_config.num_cpus_per_worker, - "memory": resources_config.cpu_memory_per_worker, - } - ] * resources_config.num_workers - - pg = placement_group(bundles=bundles, strategy=resources_config.pg_pack_strategy) - - ray.get(pg.ready(), timeout=PG_READY_TIMEOUT) - return pg - - @staticmethod - def get_pg_options(pg: PlacementGroup, num_cpus: int | float = -1) -> Dict: - """Provide a dictionary of resource requests for Ray tasks or actors - with specific cpu requirements. - - Args: - pg (PlacementGroup): The placement group to get options for. - num_cpus (float): The number of CPUs to request. If set to -1, - the default CPU allocation from the placement group bundle - will be used. Defaults to -1. - - Returns: - Dict: A dictionary of Ray resource options for `task.options()`. - """ - assert len(pg.bundle_specs) > 0, "Placement group has no bundles defined." - default_cpu = pg.bundle_specs[0].get("CPU", 1) - return {"num_cpus": num_cpus if num_cpus >= 0 else default_cpu} - - @classmethod - def from_config(cls, worker_cls, worker_config, cpu_config: CPUResourcesConfig): - """Create workers and a placement group from configuration objects. - - Args: - worker_cls: The class of the worker to instantiate. - worker_config: The configuration for each worker instance. - cpu_config (CPUResourcesConfig): The configuration - for the cpu resources. - - Returns: - List[T]: List of created worker instances. - """ - pg = AutoCPUWorkers.build_placement_group(cpu_config) - workers_list = cls.from_placement_group(worker_cls, worker_config, pg) - - return workers_list, pg - - @classmethod - def from_placement_group(cls, worker_cls, worker_config, pg: PlacementGroup, num_workers: int = -1): - """Create workers from an existing placement group. - - Args: - worker_cls: The class of the worker to instantiate. - worker_config: The configuration for each worker instance. - pg (PlacementGroup): The existing placement group to use. - num_workers (int): The number of workers to create. Defaults to -1, - the number of bundles in the placement group will be used. - - Returns: - List[T]: List of created worker instances. 
- """ - pg_options = cls.get_pg_options(pg) - - num_workers = num_workers if num_workers > 0 else pg.bundle_count - workers_list = [] - for _ in range(num_workers): - worker = worker_cls.options(placement_group=pg, **pg_options).remote( - worker_config, num_cpus=pg_options.get("num_cpus", 1) - ) # type: ignore[attr-defined] - workers_list.append(worker) - - return workers_list diff --git a/xtuner/v1/ray/config/__init__.py b/xtuner/v1/ray/config/__init__.py deleted file mode 100644 index 22f86fcb56..0000000000 --- a/xtuner/v1/ray/config/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .worker import ( - RolloutConfig, - TrainingWorkerConfig, -) diff --git a/xtuner/v1/ray/config/worker.py b/xtuner/v1/ray/config/worker.py deleted file mode 100644 index 55d28fc1f2..0000000000 --- a/xtuner/v1/ray/config/worker.py +++ /dev/null @@ -1,359 +0,0 @@ -import json -import os -import socket -from pathlib import Path -from typing import Any, List, Literal, Optional, Union - -from cyclopts import Group, Parameter -from pydantic import BaseModel, ConfigDict, PrivateAttr -from typing_extensions import Annotated - -from xtuner.v1.utils import get_logger - - -worker_group = Group("worker", help="Types of workers available.") -train_group = Group("Training", sort_key=90, help="Training worker configuration.") -infer_group = Group("inference", help="Inference worker configuration.") - - -class TrainingWorkerConfig(BaseModel): - """Configuration for the TrainingWorker.""" - - model_config = ConfigDict(extra="forbid") - type: Literal["train"] = "train" - train_model_path: Annotated[str, Parameter(group=train_group, help="Path to the training model.")] - - -class RolloutConfig(BaseModel): - """Rollout worker configuration for XTuner. - - This class defines comprehensive configuration parameters for rollout workers in XTuner, - supporting multiple inference backends with distributed computing and optimization features. - - Args: - env (str): Environment variables for the rollout worker. Defaults to "". - backend (str): Backend framework ('vllm', 'lmdeploy', etc.). Defaults to "lmdeploy". - model_path (str | Path): Path to the inference model. - model_name (str): Model name for the backend engine. - tokenizer_path (str): Path to the model tokenizer. Defaults to "". - api_key (Optional[Union[List[str], str]]): API keys for rollout service. Supports single key or list of keys. Defaults to None. - api_port (Optional[int]): Port number for the rollout API server. If not set, it will find an available port starting from 8000. Defaults to 8000. - gpus_per_node (int): Number of GPUs per node. Defaults to 8. - dtype (str): Model data type ('bfloat16', 'float16', 'int8'). Defaults to "bfloat16". - gpu_memory_utilization (float): GPU memory utilization ratio. Defaults to 0.85. - random_seed (int): Random seed for reproducible generation. Defaults to 1024. - rollout_cross_node_comm (bool): Enable cross-node communication. Defaults to False. - rollout_max_batch_size_per_instance (int): Maximum batch size for the rollout worker. If not set, it will be determined automatically based on `context_length`. Defaults to 512. - allow_over_concurrency_ratio (float): Factor to allow over-concurrency in HTTP requests for the rollout worker to improve GPU utilization. Defaults to 1.2. - tensor_parallel_size (int): GPUs per inference engine (tensor parallelism). Defaults to 1. - expert_parallel_size (int): Experts per inference engine (expert parallelism). Defaults to 1. 
- enable_chunked_prefill (bool): Enable chunked prefill for memory efficiency. Defaults to False. - chunked_prefill_size (int): Chunk size for prefill operations. Defaults to 128. - skip_load_weights (bool): Skip weight loading for rollout worker. Defaults to False. - rollout_timeout (float): Timeout duration in seconds for rollout requests. Defaults to 3600.0. - context_length (int): Context length for the rollout worker. - launch_server_method (Literal["ray", "multiprocessing"]): Server launch method. Defaults to "ray". - system_prompt (Optional[str]): System prompt to guide generation behavior. Defaults to None. - extra_rollout_config (Optional[dict]): Backend-specific configurations using engine prefixes - (e.g., 'vllm_enable_chunked_prefill', 'lmdeploy_max_batch_size'). Defaults to empty dict. - - **Examples:** - - Example configuration with LMDeploy backend:: - - config = RolloutConfig( - env="test_env", - model_path="Qwen/Qwen3-8B", - model_name="Qwen3-8B", - tensor_parallel_size=2, - gpu_memory_utilization=0.6, - gpus_per_node=8, - backend="lmdeploy", - ) - """ - - model_config = ConfigDict(extra="forbid") - - # base config - env: Annotated[ - str, - Parameter(group=infer_group, help="Environment variables to set for the rollout."), - ] = "" - device: Annotated[str, Parameter(group=infer_group, help="Device to be used for the rollout worker.")] = "GPU" - model_path: Annotated[str | Path, Parameter(group=infer_group, help="Path to the SGLang model.")] - model_name: Annotated[ - str | None, Parameter(group=infer_group, help="Name of the model to be used in the LMDeploy.") - ] = None - tokenizer_path: Annotated[ - str | None, Parameter(group=infer_group, help="Path to the tokenizer for the model.") - ] = None - api_key: Annotated[ - Optional[Union[List[str], str]], - Parameter( - group=infer_group, - help="API keys for the rollout service. Can be a single key or a list of keys.", - ), - ] = None - api_port: Annotated[ - int, - Parameter(group=infer_group, help="Port number for the rollout API server. If not set, 8000 will be used."), - ] = 8000 - gpus_per_node: Annotated[int, Parameter(group=infer_group, help="Number of GPUs allocated per node.")] = 8 - dtype: Annotated[ - str, - Parameter(group=infer_group, help="Data type for the model, e.g., 'bfloat16', 'float16', 'int8'."), - ] = "bfloat16" - gpu_memory_utilization: Annotated[ - float, Parameter(group=infer_group, help="GPU memory utilization for the rollout worker.") - ] = 0.85 - random_seed: Annotated[int, Parameter(group=infer_group, help="Random seed for the rollout worker.")] = 1024 - # distributed config - rollout_cross_node_comm: Annotated[ - bool, - Parameter( - group=infer_group, - help="Whether to enable cross-node communication for the rollout worker.", - ), - ] = False - dist_port_base: Annotated[ - int, - Parameter( - group=infer_group, - help="Base port number for distributed communication among rollout workers.", - ), - ] = 35000 - rollout_max_batch_size_per_instance: Annotated[ - Optional[int], - Parameter( - group=infer_group, - help="Maximum batch size for the rollout worker. 
If not set, it will be determined automatically based on the model and GPU memory.", - ), - ] = None - allow_over_concurrency_ratio: Annotated[ - float, - Parameter( - group=infer_group, - help="Factor to allow over concurrency in the http request for rollout worker to improve GPU utilization.", - ), - ] = 1.2 - tensor_parallel_size: Annotated[ - int, - Parameter( - group=infer_group, - help="Number of GPUs allocated for each inference engine in the rollout worker.", - ), - ] = 1 - data_parallel_size: Annotated[ - int, - Parameter( - group=infer_group, - help="Number of GPUs allocated for processing data batches in parallel (Data Parallelism).", - ), - ] = 1 - expert_parallel_size: Annotated[ - int, - Parameter( - group=infer_group, - help="Number of experts allocated for each inference engine in the rollout worker.", - ), - ] = 1 - # optimization config - enable_chunked_prefill: Annotated[ - bool, - Parameter( - group=infer_group, - help="Whether to enable chunked prefill for the rollout worker.", - ), - ] = False - chunked_prefill_size: Annotated[ - int, - Parameter( - group=infer_group, - help="Chunked prefill size for the rollout worker.", - ), - ] = 128 - skip_load_weights: Annotated[ - bool, - Parameter( - group=infer_group, - help="Whether to skip loading weights for the rollout worker.", - ), - ] = False - enable_return_routed_experts: Annotated[ - bool, - Parameter( - group=infer_group, - help="Whether to enable returning routed experts for the rollout worker.", - ), - ] = False - launch_server_method: Annotated[ - Literal["ray", "multiprocessing"], - Parameter( - group=infer_group, - help="Method to launch the rollout server, either 'ray' or 'multiprocessing'.", - ), - ] = "ray" - rollout_timeout: Annotated[ - float, - Parameter( - group=infer_group, - help="Timeout duration (in seconds) for rollout requests.", - ), - ] = 1200.0 - context_length: Annotated[ - Optional[int], - Parameter( - group=infer_group, - help="Context length for the rollout worker.", - ), - ] = None - enable_float8: Annotated[ - bool, - Parameter( - group=infer_group, - help="Whether to enable float8 quantization for the rollout worker.", - ), - ] = False - extra_rollout_config: Annotated[ - dict, - Parameter( - group=infer_group, - help='Extra configuration for different rollout worker. vllm parameters will start with prefix "vllm", etc.', - ), - ] = {} - max_retry_per_worker: Annotated[ - Optional[int], - Parameter( - group=infer_group, - help="Maximum number of retries per rollout worker before deactivation.", - ), - ] = None - max_retry_per_sample: Annotated[ - int, - Parameter( - group=infer_group, - help="Maximum number of retries per sample before marking it as failed.", - ), - ] = 1 - max_prefill_token_num: Annotated[ - Optional[int], - Parameter( - group=infer_group, - help="The number of tokens each iteration during prefill.", - ), - ] = None - router_n_groups: Annotated[ - Optional[int], - Parameter( - group=infer_group, - help="The number of groups in MoE model with group router, e.g. 
Intern-S1-Pro.", - ), - ] = None - fp32_lm_head: Annotated[ - bool, - Parameter( - group=infer_group, - help="Use float32 for language model head.", - ), - ] = False - worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir" - _logged_server_urls_per_engine: bool = PrivateAttr(default=False) - - @property - def rollout_backend(self) -> str: - backend = "" - if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": - backend = "sglang" - elif os.environ.get("XTUNER_USE_VLLM", "0") == "1": - backend = "vllm" - elif os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "1": - backend = "lmdeploy" - - assert backend in ["sglang", "vllm", "lmdeploy"], ( - f"Unsupported rollout backend: {backend}. Please set XTUNER_USE_SGLANG, XTUNER_USE_VLLM, or XTUNER_USE_LMDEPLOY to 1." - ) - return backend - - @property - def server_urls_per_engine(self) -> int: - # server_urls_per_engine is introduced for lmdeploy ep settings - # for now only lmdeploy pytorch backend with ep > 1 requires multiple server urls per engine - if self.rollout_backend == "lmdeploy" and self.expert_parallel_size > 1: - # when expert parallelism is used, lmdeploy requires `expert_parallel_size` server instances per engine - if not self._logged_server_urls_per_engine: - self._logged_server_urls_per_engine = True - get_logger().info( - f"Setting server_urls_per_engine={self.expert_parallel_size} due to expert parallelism in LMDeploy." - ) - return self.expert_parallel_size - else: - return 1 - - def model_post_init(self, __context: Any) -> None: - if self.model_name is None: - model_name_from_config = None - config_json_path = Path(self.model_path) / "config.json" - try: - with open(config_json_path, encoding="utf-8") as f: - config_data = json.load(f) - model_name_from_config = config_data.get("model_type") - except (json.JSONDecodeError, OSError): - pass - self.model_name = model_name_from_config or Path(self.model_path).name - - if self.tokenizer_path is None: - self.tokenizer_path = str(self.model_path) - - port = self.api_port - while True: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - try: - s.bind(("localhost", port)) - break - except OSError: - port += 1 - self.api_port = port - - if self.device == "NPU": - self.gpus_per_node = 16 - - if self.rollout_backend == "sglang": - self.launch_server_method = "multiprocessing" - self.rollout_cross_node_comm = False - else: - self.launch_server_method = "ray" - self.rollout_cross_node_comm = True - - if self.rollout_max_batch_size_per_instance is None: - assert self.context_length is not None, ( - "context_length must be set if rollout_max_batch_size_per_instance is not provided." - ) - # TODO(@duanyanhui): Provide better suggestions for different models/input-output lengths - if self.context_length <= 4096: - self.rollout_max_batch_size_per_instance = 1024 - elif self.context_length <= 8192: - self.rollout_max_batch_size_per_instance = 512 - else: - self.rollout_max_batch_size_per_instance = 128 - - if self.max_retry_per_worker is None: - self.max_retry_per_worker = self.rollout_max_batch_size_per_instance - - self.worker_log_dir.mkdir(parents=True, exist_ok=True) - - -if __name__ == "__main__": - from cyclopts import App, Group, Parameter - - app = App() - - @app.default - def test_command(*, config: RolloutConfig): - """A test command to verify the command line interface. - - Args: - config: The rollout configuration. 
- """ - print("This is a test command.") - - app() diff --git a/xtuner/v1/ray/dataflow/__init__.py b/xtuner/v1/ray/dataflow/__init__.py deleted file mode 100644 index 3f2fbaf630..0000000000 --- a/xtuner/v1/ray/dataflow/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .flow import DataFlow, DataFlowConfig, DataFlowProxy -from .replay_buffer import ReplayBuffer, ReplayBufferConfig diff --git a/xtuner/v1/ray/dataflow/flow.py b/xtuner/v1/ray/dataflow/flow.py deleted file mode 100644 index 71e4921bbf..0000000000 --- a/xtuner/v1/ray/dataflow/flow.py +++ /dev/null @@ -1,605 +0,0 @@ -import asyncio -import math -import time -from pathlib import Path -from typing import Any, Dict, List, Optional, TypedDict - -import httpx -import ray -from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict -from ray.actor import ActorProxy -from tqdm.auto import tqdm -from typing_extensions import Annotated - -from xtuner.v1.data_proto.rl_data import ( - MultimodalTrainInfo, - RLDataFlowItem, - RolloutState, -) -from xtuner.v1.ray.environment import SingleTurnEnvironment -from xtuner.v1.ray.rollout.controller import SampleParams -from xtuner.v1.ray.utils import create_task -from xtuner.v1.utils import get_logger, ray_method - -from .replay_buffer import ReplayBuffer, ReplayBufferConfig, determine_group_state - - -class DataFlowResult(TypedDict): - data_groups: List[List[RLDataFlowItem]] - mm_train_infos: List[MultimodalTrainInfo] - metrics: Dict[str, Any] - - -class DataFlowConfig(BaseModel): - """Data flow configuration for XTuner. - - Simple configuration for managing concurrent data generation workflows - in reinforcement learning training. - - Args: - env (str): Environment identifier. Defaults to "". - max_concurrent (int): Maximum concurrent tasks. Defaults to 8. - prompt_repeat_k (int): Times to repeat each prompt. Defaults to 1. - global_batch_size (int): Target samples to collect. Defaults to 8. - max_retry_times (int): Maximum retry attempts. Defaults to 1. - enable_partial_rollout (int): Enable async mode (1) or disable (0). Defaults to 0. - sample_params (SampleParams): Model sampling parameters. Defaults to SampleParams(). - - **Examples:** - - Example configuration for dataflow:: - - config = DataFlowConfig( - env="test_env", - max_concurrent=256, - global_batch_size=1024, - prompt_repeat_k=8, - sample_params=SampleParams(max_tokens=2048), - ) - """ - - model_config = ConfigDict(extra="forbid") - - env: Annotated[ - str, - Parameter(help="Environment name to set for the dataflow."), - ] = "" - max_concurrent: Annotated[ - Optional[int], - Parameter(help="Maximum number of concurrent tasks."), - ] = None - max_retry_times: Annotated[ - int, - Parameter(help="Maximum number of retry task for failed samples."), - ] = 3 - prompt_repeat_k: Annotated[ - int, - Parameter(help="Number of times to repeat each prompt."), - ] = 1 - global_batch_size: Annotated[ - int, - Parameter(help="Target number of samples to collect before stopping."), - ] = 8 - sample_params: Annotated[SampleParams, Parameter(help="Parameters for sampling from the model.")] = SampleParams() - extra_params: Annotated[Dict, Parameter(help="Extra parameters for rollout.")] = {} - # async params - staleness_threshold: Annotated[ - float, - Parameter( - help="The maximum allowed threshold of stale (expired) samples in a training batch. Must be between 0.0 and 1.0." 
- class RawDataFlow:
- """A Ray actor that manages the data flow for reinforcement learning.
-
- This class is responsible for sampling prompts, interacting with the environment to generate responses,
- processing the results, and storing them in a replay buffer. It orchestrates the asynchronous generation of
- training data.
- """
-
- def __init__(
- self,
- env: str,
- dataflow_cfg: DataFlowConfig,
- replay_buffer_cfg: ReplayBufferConfig,
- environment: SingleTurnEnvironment,
- ):
- """Initializes the DataFlow actor.
-
- Args:
- env (str): The name of the environment.
- dataflow_cfg (DataFlowConfig): Configuration for the data flow.
- replay_buffer_cfg (ReplayBufferConfig): Configuration for the
- replay buffer.
- environment (SingleTurnEnvironment): The environment controller actor.
- """
- self.logger = get_logger(log_dir=dataflow_cfg.worker_log_dir, tag="DataFlow")
- self.env = env
- self.config = dataflow_cfg
- replay_buffer_cfg.worker_log_dir = self.config.worker_log_dir
- self.replay_buffer = ReplayBuffer(replay_buffer_cfg)
- self.replay_buffer.setup_storage_config(
- enable_partial_rollout=self.config.enable_partial_rollout,
- tail_batch_candidate_steps=self.config.tail_batch_candidate_steps,
- tail_batch_trigger_size=self.config.tail_batch_trigger_size, # type: ignore
- )
- self.staleness_threshold = self.config.staleness_threshold
- self.env_controller = environment
- self.finished_samples_count = 0
- self.skipped_sample_count = 0
- self.failed_sample_count = 0
- self.filtered_samples_count = 0
- self._raw_reward_sum = 0.0
- self._raw_reward_count = 0
- self.tb_metrics: Dict[str, Any] = {}
- self.target_batch_size = self.config.global_batch_size
- rollout_info = ray.get(self.env_controller.get_rollout_info.remote()) # type: ignore[attr-defined]
- self.worker_url_list = list(rollout_info["server_url_dict"].values())
- self.logger.info(f"DataFlow connected to active rollout workers url: {self.worker_url_list}")
- rollout_config = rollout_info["rollout_config"]
- max_concurrent = int(
- (
- rollout_config.rollout_max_batch_size_per_instance
- * len(self.worker_url_list)
- / self.config.prompt_repeat_k
- )
- * rollout_config.allow_over_concurrency_ratio
- )
-
- if self.config.max_concurrent is None:
- self.config.max_concurrent = max_concurrent
- self.logger.info(
- f"Set Dataflow max_concurrent to {self.config.max_concurrent} based on rollout max batch size and number of workers."
- )
- else:
- self.logger.warning(
- f"Dataflow max_concurrent is set to {self.config.max_concurrent}, we propose setting max_concurrent to {max_concurrent} based on rollout_max_batch_size_per_instance."
- )
- self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}")
- self.cleanup_task_time = 5 * 60 # 5 minutes
- self.cancel_response_timeout = 5.0
-
- def _reset_internal_states(
- self,
- global_batch_size: Optional[int] = None,
- sample_params: Optional[SampleParams] = None,
- extra_params: Optional[Dict] = None,
- staleness_threshold: Optional[float] = None,
- ):
- """Resets all internal state variables of DataFlow."""
- self.skipped_sample_count = 0
- self.failed_sample_count = 0
- self.filtered_samples_count = 0
- self._raw_reward_sum = 0.0
- self._raw_reward_count = 0
- self.tb_metrics = {}
- if global_batch_size and global_batch_size > 0:
- self.target_batch_size = global_batch_size
- else:
- self.target_batch_size = self.config.global_batch_size
-
- if staleness_threshold is not None:
- self.staleness_threshold = staleness_threshold
- else:
- self.staleness_threshold = self.config.staleness_threshold
-
- self.sample_from_expired_storage, self.finished_samples_count = self.replay_buffer.get_prerun_state()
- ray.get(self.env_controller.restart.remote()) # type: ignore[attr-defined]
- # Reset the judger abort state for the next round
- try:
- ray.get(self.env_controller.restart_judger.remote()) # type: ignore[attr-defined]
- except Exception as e:
- self.logger.error(f"Failed to restart judger (next round may be affected): {e}")
- self.sample_params = sample_params if sample_params else self.config.sample_params
- self.extra_params = extra_params if extra_params else self.config.extra_params
- logger_msg = (
- f"DataFlow states for new generations: target_batch_size={self.target_batch_size}, "
- f"sample_params: {self.sample_params}, extra_params: {self.extra_params}, "
- f"sample_from_expired_storage={self.sample_from_expired_storage}, finished_samples_count={self.finished_samples_count}, "
- )
- self.logger.info(logger_msg)
-
- @ray_method
- def get_train_dataset_length(self):
- """Gets the length of the training dataset from the replay buffer."""
- return self.replay_buffer.get_train_dataset_length()
-
- @ray_method
- async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowItem]] = None):
- """A single worker task to generate and process a group of samples.
-
- This task performs the following steps:
- 1. Samples a prompt from the replay buffer (or uses a sample for retry).
- 2. Calls the environment controller or rollout controller to generate a response.
- 3. Post-processes the generated samples using the default and custom postprocessors.
- 4. Adds the filtered samples to the replay buffer.
-
- Args:
- group_samples_for_retry (Optional[List[RLDataFlowItem]]): A group
- of samples to retry if a previous attempt failed. Defaults to
- None.
-
- Returns:
- Optional[List[RLDataFlowItem]]: The group of samples if the task
- fails and needs to be retried, otherwise None.
- """
- task_start_time = time.perf_counter()
- # step 1: sample
- # TODO(@duanyanhui): More fine-grained control over group data generation:
- # Pass n to the inference engine to ensure that the same data is processed by the same server, improving efficiency.
- group_data_items = self.replay_buffer.sample(self.env, self.config.prompt_repeat_k)
- assert len(group_data_items) > 0, "Sampled empty group data items from replay buffer."
- action_id = group_data_items[0].uid.action_id - # step 2: env generate - env_run_ref = self.env_controller.run.remote( # type: ignore[attr-defined] - group_data_items, - sample_params=self.sample_params, - extra_params=self.extra_params, - ) - try: - group_data_items = await asyncio.shield(env_run_ref) - except asyncio.CancelledError as exc: - ray.cancel(env_run_ref, recursive=True) - try: - group_data_items = await asyncio.wait_for( - asyncio.shield(env_run_ref), - timeout=self.cancel_response_timeout, - ) - except BaseException: - raise exc - - # Step 3: Determine the sample's state and act accordingly. - group_state = determine_group_state(group_data_items) - self.logger.debug(f"Determined replay state for {action_id}: {group_state}") - if group_state == RolloutState.COMPLETED: - # Accumulate raw rewards before post_processor filters samples out. - for item in group_data_items: - reward_data = getattr(item.env, "judger", None) - if reward_data is not None: - reward_dict = reward_data.reward if hasattr(reward_data, "reward") else reward_data - score = reward_dict.get("score") if isinstance(reward_dict, dict) else None - if score is not None: - self._raw_reward_sum += score - self._raw_reward_count += 1 - if not self.sample_from_expired_storage: - group_data_items = self.replay_buffer.post_processor(group_data_items) # type: ignore[attr-defined] - if len(group_data_items) > 0: - self.replay_buffer.add(group_data_items) # type: ignore[attr-defined] - else: - self.filtered_samples_count += 1 - self.logger.debug(f"Worker task completed successfully for {action_id}.") - elif group_state == RolloutState.ABORTED: - self.replay_buffer.add(group_data_items) # type: ignore[attr-defined] - self.logger.debug(f"Adding aborted sample {action_id} to aborted storage") - elif group_state == RolloutState.SKIPPED: - self.skipped_sample_count += 1 - self.logger.info(f"Total skipped samples count: {self.skipped_sample_count}") - elif group_state == RolloutState.FAILED: - self.failed_sample_count += 1 - self.logger.info(f"Total failed samples count: {self.failed_sample_count}") - else: - self.logger.error(f"Unexpected group state '{group_state}' for action_id {action_id}.") - - return time.perf_counter() - task_start_time - - async def concurrent_task_runner(self): - """Orchestrates the concurrent execution of worker tasks. - - This method manages a pool of asynchronous worker tasks to collect a - target number of samples (`self.target_batch_size`). It dynamically - adjusts the number of concurrent tasks based on progress and a - staleness threshold, ensuring efficient data generation. - - The process is as follows: - 1. Initializes a set of worker tasks, potentially over-provisioning - based on `self.config.staleness_threshold` to account for - variability in task completion times. - 2. Enters a main loop that continues until `target_batch_size` - samples are collected. - 3. Inside the loop, it periodically checks the number of pending - tasks and launches new ones if the current number is insufficient - to meet the target, maintaining a steady flow of data generation. - 4. Uses `asyncio.wait` with a short timeout to efficiently monitor - for completed tasks without blocking execution. - 5. A progress bar (`tqdm`) is updated as samples are collected. - 6. Once `target_batch_size` is reached, it sends a pause/abort - signal to all rollout workers to prevent them from starting new - computations. - 7. 
It then waits for any remaining in-flight tasks to complete, with
- a configurable timeout to prevent indefinite hanging. Tasks that
- do not finish within the timeout are forcefully cancelled.
- """
- waiting_tasks = set()
- dataflow_start_time = time.perf_counter()
- task_completion_times = []
- with tqdm(total=self.target_batch_size, desc="rollout_controller for training samples", miniters=10) as pbar:
- last_pbar_n = self.finished_samples_count
- init_finished_samples_count = self.finished_samples_count
-
- if self.sample_from_expired_storage:
- # When sampling from expired storage, the staleness_threshold must be disabled
- data_concurrency = self.target_batch_size - self.finished_samples_count
- self.logger.info(
- f"Sampling from expired storage, starting {data_concurrency} worker tasks from expired samples."
- )
- else:
- data_concurrency = math.ceil(
- (1 + self.staleness_threshold) * (self.target_batch_size - self.finished_samples_count)
- )
- self.logger.info(
- f"Starting dataflow concurrent task runner with data_concurrency: {data_concurrency}, target_batch_size: {self.target_batch_size}, finished_samples_count: {self.finished_samples_count}, staleness_threshold: {self.staleness_threshold}"
- )
-
- for _ in range(data_concurrency):
- task = create_task(self.worker_task())
- waiting_tasks.add(task)
-
- while (
- self.finished_samples_count < self.target_batch_size
- and self.failed_sample_count < self.target_batch_size
- and self.skipped_sample_count < self.target_batch_size
- ):
- done_tasks, pending_tasks = await asyncio.wait(
- waiting_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED
- )
-
- for done_task in done_tasks:
- task_time = done_task.result()
- task_completion_times.append(task_time)
-
- self.finished_samples_count = self.replay_buffer.get_completed_samples_count()
- pbar.update(self.finished_samples_count - last_pbar_n)
- last_pbar_n = self.finished_samples_count
-
- waiting_tasks = pending_tasks
-
- while (
- len(waiting_tasks) + self.finished_samples_count < data_concurrency + init_finished_samples_count
- ):
- # Backfill with new tasks when samples have been filtered out
- task = create_task(self.worker_task())
- waiting_tasks.add(task)
-
- pbar.n = self.finished_samples_count
- pbar.refresh()
-
- if self.finished_samples_count >= self.target_batch_size:
- self.logger.info(
- f"Target batch size {self.target_batch_size} reached with finished_samples_count: {self.finished_samples_count}."
- )
- elif self.skipped_sample_count >= self.target_batch_size:
- self.logger.info(
- f"Stopping data generation as skipped samples {self.skipped_sample_count} reached target batch size {self.target_batch_size}."
- )
- elif self.failed_sample_count >= self.target_batch_size:
- self.logger.info(
- f"Stopping data generation as failed samples {self.failed_sample_count} reached target batch size {self.target_batch_size}."
- )
- generation_time = time.perf_counter() - dataflow_start_time
- pause_start_time = time.perf_counter()
-
- if len(waiting_tasks) > 0:
- self.logger.info(f"Start pausing env controller for remaining worker tasks {len(waiting_tasks)}.")
- await self.pause()
- cleanup_start_time = time.perf_counter()
- while len(waiting_tasks) > 0:
- elapsed_time = time.perf_counter() - cleanup_start_time
- if elapsed_time > self.cleanup_task_time:
- self.logger.warning(
- f"Cleanup timeout of {self.cleanup_task_time}s reached. "
- f"Forcefully cancelling {len(waiting_tasks)} remaining tasks."
- ) - for task in waiting_tasks: - task.cancel() - # Wait for cancellations to complete - await asyncio.gather(*waiting_tasks, return_exceptions=True) - break # Exit the cleanup loop - # NOTE: Keep sending pause requests because the inference engine only marks currently running requests as aborted. - # When a waiting request starts running, it still needs another pause request to be marked as aborted. - _, pending_tasks = await asyncio.wait(waiting_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED) - if len(pending_tasks) > 0: - await self.pause() - await asyncio.sleep(1) - self.logger.debug( - f"Waiting for {len(pending_tasks)} remaining worker tasks to complete after pausing env controller." - ) - waiting_tasks = pending_tasks - self.logger.info("All worker tasks have completed after pausing env controller.") - - pause_time = time.perf_counter() - pause_start_time - dataflow_time = time.perf_counter() - dataflow_start_time - self.logger.info( - f"dataflow task finished, generation_time: {generation_time:.2f}s, pause_time: {pause_time:.2f}s, total_time: {dataflow_time:.2f}s" - ) - self.tb_metrics["time/generation_time"] = generation_time - self.tb_metrics["time/pause_time"] = pause_time - - task_completion_dict = self._log_task_completion_stats(task_completion_times, "Task Completion Time Stats:\n") - for k, v in task_completion_dict.items(): - self.tb_metrics[f"task_time/{k}"] = v - - @ray_method - async def pause(self, timeout: float = 60.0): - """Asynchronously sends abort requests to all rollout workers and - judgers. - - Args: - timeout: HTTP request timeout in seconds. - """ - if not self.worker_url_list: - self.logger.info("No active rollout workers to pause.") - return - - async with httpx.AsyncClient() as client: - tasks = [self._send_abort_request(client, url, timeout=timeout) for url in self.worker_url_list] - results = await asyncio.gather(*tasks) - - failed_workers = [url for url, success in results if not success] - succeeded_count = len(self.worker_url_list) - len(failed_workers) - - if failed_workers: - self.logger.warning( - f"Abort requests completed. Succeeded: {succeeded_count}, " - f"Failed: {len(failed_workers)}. Failed workers: {failed_workers}" - ) - else: - self.logger.info(f"All {succeeded_count} abort requests sent successfully.") - - # Abort judger actors whose config has abort_on_pause=True - try: - await self.env_controller.abort_judger.remote() # type: ignore[attr-defined] - self.logger.info("Judger abort signal sent successfully.") - except Exception as e: - self.logger.warning(f"Failed to send judger abort signal: {e}") - - @ray_method - async def run( - self, - num: Optional[int] = None, - sample_params: Optional[SampleParams] = None, - extra_params: Optional[Dict] = None, - staleness_threshold: Optional[float] = None, - ) -> DataFlowResult: - """Starts the data generation process. - - This method resets the internal state and runs the concurrent task - runner to collect a new batch of samples from the environment. - - Args: - num (Optional[int]): The target number of samples to collect for this run. - Overrides the existing global_batch_size in DataFlowConfig if provided. - sample_params (Optional[SampleParams]): Parameters for model sampling. - Overrides the existing sample_params in DataFlowConfig if provided. - extra_params (Optional[Dict]): Additional parameters for rollout. - Overrides the existing extra_params in DataFlowConfig if provided. - staleness_threshold (Optional[float]): Override for staleness threshold. 
- - Returns: - DataFlowResult: The collected training samples and metadata. - """ - self._reset_internal_states( - global_batch_size=num, - sample_params=sample_params, - extra_params=extra_params, - staleness_threshold=staleness_threshold, - ) - self.logging_replaybuffer_state("DataFlow run started. ") - await self.concurrent_task_runner() - self.logging_replaybuffer_state("DataFlow run completed. ") - - get_start_time = time.perf_counter() - return_samples = self.replay_buffer.get_samples(self.target_batch_size) # type: ignore[attr-defined] - self.logger.info( - f"Getting {self.target_batch_size} samples from replay buffer took {time.perf_counter() - get_start_time:.2f}s" - ) - self.tb_metrics["time/get_samples_time"] = time.perf_counter() - get_start_time - dataflow_result = DataFlowResult( - data_groups=return_samples[0], - mm_train_infos=return_samples[1], - metrics=self.tb_metrics, - ) - return dataflow_result - - def logging_replaybuffer_state(self, logging_msg: Optional[str] = None): - status = self.get_replaybuffer_status() - logging_msg = logging_msg if logging_msg else "" - logging_msg += f"ReplayBuffer Status: {status}" - logging_msg += f", finished_samples_count: {self.finished_samples_count}, " - logging_msg += f"skipped_samples_count: {self.skipped_sample_count}, " - logging_msg += f"failed_samples_count: {self.failed_sample_count}, " - logging_msg += f"filtered_samples_count: {self.filtered_samples_count}, " - if self._raw_reward_count > 0: - avg_raw_reward = self._raw_reward_sum / self._raw_reward_count - logging_msg += f"avg_raw_reward: {avg_raw_reward:.6f} (n={self._raw_reward_count}), " - self.tb_metrics["reward/avg_raw_reward"] = avg_raw_reward - self.tb_metrics["reward/raw_reward_count"] = self._raw_reward_count - self.logger.info(logging_msg) - - def get_replaybuffer_status(self): - return self.replay_buffer.status() - - async def _send_abort_request(self, client, url, timeout): - worker_url = f"{url}/abort_request" - try: - response = await client.post(worker_url, json={"abort_all": True}, timeout=timeout) - response.raise_for_status() - self.logger.debug(f"Successfully sent abort request to {url}") - return url, True - except Exception as e: - self.logger.error(f"Failed to send abort request to {url}: {e}") - return url, False - - def _log_task_completion_stats(self, task_times: List[float], logger_msg: Optional[str] = None): - if not task_times: - self.logger.info("No task completion times to report.") - return {} - - import numpy as np - - stats_dict = { - "p50": np.percentile(task_times, 50), - "p90": np.percentile(task_times, 90), - "p95": np.percentile(task_times, 95), - "p99": np.percentile(task_times, 99), - "max": np.max(task_times), - "avg": np.mean(task_times), - "std": np.std(task_times), - } - stats_dict["p99_p50_ratio"] = stats_dict["p99"] / stats_dict["p50"] if stats_dict["p50"] > 0 else float("inf") - - task_completions_report = ( - f" - Avg Time: {stats_dict['avg']:.2f}s, Std: {stats_dict['std']:.2f}s\n" - f" - P50 (Median): {stats_dict['p50']:.2f}s, P90: {stats_dict['p90']:.2f}s, P95: {stats_dict['p95']:.2f}s, P99: {stats_dict['p99']:.2f}s\n" - f" - Max Time: {stats_dict['max']:.2f}s, Ratio (P99 / P50): {stats_dict['p99_p50_ratio']:.2f}\n" - ) - logger_msg = logger_msg if logger_msg else "" - logger_msg += task_completions_report - self.logger.info(logger_msg) - return stats_dict - - def save(self, save_path: Path | str): - """Saves the replay buffer to the specified path. - - Args: - save_path (str): The path to the checkpoint file to save to. 
- """ - self.replay_buffer.save(save_path) - - def resume(self, resume_path: Path | str): - """Resumes the replay buffer from the specified path. - - Args: - resume_path (str): The path to the checkpoint file to resume from. - """ - self.replay_buffer.resume(resume_path) - - -DataFlow = ray.remote(RawDataFlow) -DataFlowProxy = ActorProxy[RawDataFlow] diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py deleted file mode 100644 index 8596ae7e74..0000000000 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ /dev/null @@ -1,1108 +0,0 @@ -import itertools -import time -from collections import defaultdict -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from uuid import uuid4 - -import numpy -import ray -import torch -from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict, Field -from ray import ObjectRef -from typing_extensions import Annotated - -from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast -from xtuner.v1.data_proto.rl_data import ( - MultimodalTrainInfo, - RLDataFlowItem, - RLDatasetItem, - RLEnvDataItem, - RLExtraDataItem, - RLUIDItem, - RolloutState, - is_valid_for_replaybuffer, -) -from xtuner.v1.datasets.config import DataloaderConfig -from xtuner.v1.ray.utils import free_object_refs -from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger -from xtuner.v1.utils.device import get_device - - -DEVICE = get_device() -logger = get_logger() - - -@dataclass -class ReplayMeta: - """ReplayMeta aggregates all versions of data related to a single prompt in - the replay buffer. - - Attributes: - env (str): Name or identifier of the environment. - root_id (int): Identifier for grouping related prompts (e.g., for GRPO or multi-turn scenarios). - action_id (int): Unique identifier for the prompt. If the prompt changes (such as in a multi-turn scenario), a new action_id is assigned. - action_ref (ObjectRef): Ray object reference to the prompt data (corresponds to RLDatasetItem in RLDataFlowItem). - observation_ids (List[int]): IDs for different responses to the same prompt. Each response has a unique observation_id. - observation_refs (List[ObjectRef]): Ray object references to environment data for each observation (corresponds to RLEnvDataItem in RLDataFlowItem). - observation_versions (List[int]): Version numbers for each observation, supporting async rollout. - state (str): Overall state of the prompt (e.g., "paused" for partial rollout, or other rollout states). - extra_info (Dict[str, Any]): Additional metadata or information. 
- """ - - env: str = "" - root_id: int = 0 - action_id: int = 0 # same prompt share the same action_id - action_ref: ObjectRef = None - observation_ids: List[int] = field(default_factory=list) - observation_refs: List[ObjectRef] = field(default_factory=list) - observation_versions: List[int] = field(default_factory=list) # 目前发数据为按组下发,暂时用不到 - observation_extra_infos: List[RLExtraDataItem] = field(default_factory=list) - state: RolloutState = RolloutState.INIT - version: int = 0 # version for partial rollout - extra_info: Dict[str, Any] = field(default_factory=dict) - - -@dataclass -class SerializedRayObjectRef: - """Snapshot marker that preserves where a ray.ObjectRef originally - lived.""" - - value: Any - - -def determine_group_state(group_data_items: List[RLDataFlowItem]) -> RolloutState: - """Determines the processing strategy for a group of rollout samples based - on their state.""" - # TODO(@duanyanhui): remove this function when send one request instead of group requests. - if not group_data_items: - return RolloutState.SKIPPED - group_states = {item.env.rollout.state for item in group_data_items} - if RolloutState.SKIPPED in group_states: - return RolloutState.SKIPPED - elif RolloutState.FAILED in group_states: - return RolloutState.FAILED - elif RolloutState.ABORTED in group_states: - return RolloutState.ABORTED - elif all(state == RolloutState.COMPLETED for state in group_states): - return RolloutState.COMPLETED - else: - raise ValueError(f"Unknown group states: {group_states}") - - -def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> ReplayMeta: - assert len(grouped_dataitem) > 0 - - env_str = grouped_dataitem[0].uid.env - root_id = grouped_dataitem[0].uid.root_id - action_id = grouped_dataitem[0].uid.action_id - # !!! 注意:这里放的是第一个dataitem的data,因为一组数据的data是一样的 !!! 
- def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> ReplayMeta:
- assert len(grouped_dataitem) > 0
-
- env_str = grouped_dataitem[0].uid.env
- root_id = grouped_dataitem[0].uid.root_id
- action_id = grouped_dataitem[0].uid.action_id
- # !!! NOTE: we store the first data item's data here, because every item in a group shares the same data !!!
- data = grouped_dataitem[0].data
- # Requests are sent per group, so all data items in a group share the same version; even if an item produced no output in a given rollout step, its version is still incremented
- group_version = grouped_dataitem[0].uid.version
- observation_ids = []
- observation_refs = []
- observation_versions = []
- observation_extra_infos = []
-
- for item in grouped_dataitem:
- observation_ids.append(item.uid.observation_id)
- observation_refs.append(ray.put(item.env))
- observation_versions.append(item.uid.version)
- observation_extra_infos.append(item.extra_info.model_copy(deep=True))
-
- group_state = determine_group_state(grouped_dataitem)
- logger.debug(
- f"Mapping data items to ReplayMeta {action_id} with group_state: {group_state}, group_version: {group_version}"
- )
-
- replay_meta = ReplayMeta(
- env=env_str,
- root_id=root_id,
- action_id=action_id,
- action_ref=ray.put(data),
- observation_ids=observation_ids,
- observation_refs=observation_refs,
- observation_versions=observation_versions,
- observation_extra_infos=observation_extra_infos,
- state=group_state,
- version=group_version,
- extra_info={},
- )
- return replay_meta
-
-
- def mapping_replaymeta_to_dataitem(replay_meta: ReplayMeta, consume_refs: bool = True) -> List[RLDataFlowItem]:
- env_str = replay_meta.env
- root_id = replay_meta.root_id
- action_id = replay_meta.action_id
-
- action_ref = replay_meta.action_ref
- observation_refs = list(replay_meta.observation_refs)
-
- data_value = ray.get(action_ref) if action_ref is not None else None
-
- env_values = [ray.get(obs_ref) for obs_ref in observation_refs]
-
- if consume_refs:
- refs_to_free: List[ObjectRef] = []
- if isinstance(action_ref, ObjectRef):
- refs_to_free.append(action_ref)
- refs_to_free.extend([ref for ref in observation_refs if isinstance(ref, ObjectRef)])
- free_object_refs(refs_to_free)
- replay_meta.action_ref = None
- replay_meta.observation_refs.clear()
-
- group_data_item = []
- observation_versions = replay_meta.observation_versions or [replay_meta.version] * len(replay_meta.observation_ids)
- observation_extra_infos = replay_meta.observation_extra_infos or [
- RLExtraDataItem() for _ in replay_meta.observation_ids
- ]
- for idx, (obs_id, env_data) in enumerate(zip(replay_meta.observation_ids, env_values)):
- observation_version = observation_versions[idx] if idx < len(observation_versions) else replay_meta.version
- extra_info = (
- observation_extra_infos[idx].model_copy(deep=True)
- if idx < len(observation_extra_infos)
- else RLExtraDataItem()
- )
- if env_data.rollout.state == RolloutState.INIT and replay_meta.state != RolloutState.INIT:
- env_data.rollout.state = replay_meta.state
- item = RLDataFlowItem(
- uid=RLUIDItem(
- env=env_str,
- root_id=root_id,
- action_id=action_id,
- observation_id=obs_id,
- version=observation_version,
- ),
- data=data_value,
- env=env_data,
- extra_info=extra_info,
- )
- group_data_item.append(item)
- return group_data_item
-
-
- class ReplayBufferConfig(BaseModel):
- """Replay buffer configuration for XTuner.
-
- This class defines configuration parameters for the replay buffer system in XTuner,
- managing dataset handling, data loading, text processing, and post-processing
- operations for reinforcement learning experience replay.
-
- Args:
- dataset_cfg (List): Configuration for datasets used to sample initial prompts.
- dataloader_cfg (DataloaderConfig): Configuration for the PyTorch DataLoader
- that iterates over the dataset.
- tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast): Tokenizer for
- processing text data, including support for partial rollouts.
- postprocessor_func (Optional[Callable]): Optional function to filter or - modify data groups after generation. Defaults to None. - replay_ratio (float): Ratio of samples to replay from the buffer versus - sampling new data. Defaults to 0. - replay_weights (dict): Weights for different states in the replay buffer - to control sampling priorities. Defaults to empty dict. - - **Examples:** - - Example configuration for ReplayBuffer with GSM8K dataset config and base dataloader config:: - - from transformers import AutoTokenizer - - config = ReplayBufferConfig( - dataset_cfg=[{ - "dataset": DatasetConfig(name="gsm8k", anno_path="path/to/data"), - "tokenize_fn": RLTokenizeFnConfig(max_length=512) - }], - dataloader_cfg=DataloaderConfig(collator='fake_collator'), - tokenizer=AutoTokenizer.from_pretrained("model_path"), - postprocessor_func=None, - ) - """ - - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") - - dataset_cfg: Annotated[List, Parameter(help="The dataset object to sample initial prompts from.")] - - dataloader_cfg: Annotated[ - Optional[DataloaderConfig], Parameter(help="The PyTorch DataLoader for iterating over the dataset.") - ] = None - - tokenizer: Annotated[ - Union[PreTrainedTokenizer, PreTrainedTokenizerFast, str], - Field(exclude=True), - Parameter(help="The tokenizer for processing text data, e.g., for partial rollouts."), - ] - postprocessor_func: Annotated[ - Optional[Callable], - Field(exclude=True), - Parameter(help="An optional function to filter or modify data groups after they are generated."), - ] = None - replay_ratio: Annotated[ - float, - Parameter(help="Ratio of samples to replay from the buffer."), - ] = 0 - replay_weights: Annotated[ - dict, - Parameter(help="Weights for different states in the replay buffer."), - ] = {} - worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir" - - -class DatasetSampler: - """Sampler for drawing new prompts from the configured dataset. - - This class is responsible for building a dataloader from the provided dataset configurations and sampling fresh - data prompts upon request. - """ - - def __init__(self, dataset_cfg, dataloader_cfg, tokenizer): - """Initializes the DatasetSampler. - - Args: - dataset_cfg (List): Configuration for the datasets to sample from. - dataloader_cfg (Optional[DataloaderConfig]): Configuration for the - PyTorch DataLoader. - tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast, str]): - The tokenizer for processing text data. Can be a path or an object. - """ - self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - if isinstance(tokenizer, str): - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) - else: - self.tokenizer = tokenizer - if dataloader_cfg is not None: - self.dataloader_cfg = dataloader_cfg - self.dataloader_cfg.dataset_config_list = dataset_cfg - else: - self.dataloader_cfg = DataloaderConfig( - dataset_config_list=dataset_cfg, - collator="fake_collator", - pack_level="none", - num_workers=1, - ) - self.dataloader = self.dataloader_cfg.build( - tokenizer=self.tokenizer, dp_mesh=None, global_batch_size=1, micro_batch_size=1, seed=1 - ) - self.dataloader_iter = iter(self.dataloader) - self.cur_epoch = 0 - self.reduced_consumed_samples = 0 - self._next_root_id = 0 - self.logger = get_logger() - - def sample(self, env: str, prompt_repeat_k: int) -> List[RLDataFlowItem]: - """Samples a new prompt from the dataset and prepares it as a group. 
- - This method fetches the next item from the dataloader, assigns new - unique IDs (root_id, action_id, observation_id), and formats it into - a list of RLDataFlowItem objects, repeated `prompt_repeat_k` times. - - Args: - env (str): The environment name to be associated with the new samples. - prompt_repeat_k (int): The number of times to repeat the sampled - prompt in the returned group. - - Returns: - List[RLDataFlowItem]: A list of newly created data items for a rollout. - """ - if XTUNER_DETERMINISTIC: - root_id = max(self._next_root_id, self.reduced_consumed_samples * prompt_repeat_k) - action_id = root_id - self._next_root_id = root_id + prompt_repeat_k - else: - root_id = uuid4().int - action_id = uuid4().int - group_data_item: List[RLDataFlowItem] = [RLDataFlowItem() for _ in range(prompt_repeat_k)] - try: - data = next(self.dataloader_iter)[0] - except StopIteration: - self.cur_epoch += 1 - self.dataloader.set_epoch(self.cur_epoch) - self.dataloader_iter = iter(self.dataloader) - data = next(self.dataloader_iter)[0] - self.reduced_consumed_samples += 1 - - multimodal_train_info = data.pop("multimodal_train_info", {}) - if "pixel_values" in multimodal_train_info: - multimodal_train_info["pixel_values"] = ray.put(multimodal_train_info["pixel_values"]) - # If it is a mixture of pure text and image data, there will be position_id but no pixel_values - data["multimodal_train_info"] = multimodal_train_info - - for item_idx, data_item in enumerate(group_data_item): - data_item.uid = RLUIDItem( - env=env, - root_id=root_id, - action_id=action_id, - observation_id=root_id + item_idx if XTUNER_DETERMINISTIC else uuid4().int, - ) - data_item.data = RLDatasetItem(**data) - data_item.extra_info = RLExtraDataItem(retry_times=0) - self.logger.debug(f"Sampling new prompt with action_id: {action_id} in env: {env}") - return group_data_item - - def resume(self, dataloader_path): - dataloader_state = torch.load(dataloader_path, map_location=DEVICE) - self.dataloader.load_state_dict(dataloader_state) - self.dataloader_iter = iter(self.dataloader) - self.reduced_consumed_samples = int(dataloader_state["total_consumed_samples"]) - self.cur_epoch = dataloader_state["sampler"]["epoch"] - - -class ReplayBufferStorage: - """Handles the storage of experiences for the replay buffer.""" - - def __init__(self, replay_buffer_cfg): - """Initializes the data structures for storing replay data.""" - self.enable_partial_rollout: bool = False - self.tail_batch_candidate_steps: int = 0 - self.tail_batch_trigger_size: int = 0 - - self._completed_actions: Dict[int, List[int]] = defaultdict(list) - self._aborted_actions: Dict[int, List[int]] = defaultdict(list) - self._expired_actions: List[int] = [] - self._actions: Dict[int, ReplayMeta] = {} - self._root2actions: Dict[int, List[int]] = {} - self._observations: Dict[int, ReplayMeta] = {} - self._observations2states: Dict[int, str] = {} - self._states: Dict[str, List[int]] = defaultdict(list) - self._action2observations: Dict[int, List[int]] = defaultdict(list) - self._multimodal_train_infos: Dict[int, Dict[str, Any]] = {} - self.logger = get_logger(log_dir=replay_buffer_cfg.worker_log_dir, tag="ReplayBuffer") - self.sample_from_aborted_count = 0 - self.sample_from_expired_count = 0 - - def _update_replay_meta_state(self, replay_meta: ReplayMeta, new_state: RolloutState): - for observation_id in replay_meta.observation_ids: - old_state = self._observations2states.get(observation_id) - if old_state and observation_id in self._states.get(old_state, []): - 
self._states[old_state].remove(observation_id)
- self._observations2states[observation_id] = new_state
- if observation_id not in self._states[new_state]:
- self._states[new_state].append(observation_id)
- replay_meta.state = new_state
-
- def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: RolloutState):
- """Keep prompt refs only and drop rollout outputs that will not be
- reused."""
- old_obs_refs = [ref for ref in replay_meta.observation_refs if ref is not None]
- if old_obs_refs:
- ray.internal.free(old_obs_refs, local_only=False)
- replay_meta.observation_refs = [ray.put(RLEnvDataItem()) for _ in replay_meta.observation_ids]
- self._update_replay_meta_state(replay_meta, new_state)
-
- def add(self, grouped_dataitem: List[RLDataFlowItem]):
- """Adds a group of data items to the storage.
-
- Args:
- grouped_dataitem (List[RLDataFlowItem]): A list of data items
- belonging to the same group.
- """
- if (
- grouped_dataitem is None
- or len(grouped_dataitem) == 0
- or is_valid_for_replaybuffer(grouped_dataitem) is False
- ):
- return
-
- replay_meta = mapping_dataitem_to_replaymeta(grouped_dataitem)
- root_id = replay_meta.root_id
- action_id = replay_meta.action_id
-
- # 1. Update the version
- if root_id in self._root2actions:
- # TODO: In the non-colocated setting, whether to bump the version should depend on whether update_weights was performed
- replay_meta.version += 1
- self._root2actions[root_id].append(action_id)
- self.logger.debug(
- f"Existing root_id: {root_id} with action_id {action_id} found. Incrementing version to {replay_meta.version}."
- )
- else:
- self._root2actions[root_id] = [action_id]
- self._actions[action_id] = replay_meta
-
- # 2. Update the completed/aborted/expired mappings according to the rollout state
- self._check_rollout_state_and_insert(replay_meta)
-
- # 3. Update the observation mappings
- for observation_id in replay_meta.observation_ids:
- self._observations[observation_id] = replay_meta
- self._observations2states[observation_id] = replay_meta.state
- if observation_id not in self._action2observations[action_id]:
- self._action2observations[action_id].append(observation_id)
- if observation_id not in self._states[replay_meta.state]:
- self._states[replay_meta.state].append(observation_id)
-
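`add` leans on the same invariant that `_update_replay_meta_state` maintains: the forward map (observation id → state) and the reverse index (state → observation ids) must always move together. A toy, dependency-free mirror of that bookkeeping (names are illustrative, not from the codebase):

```python
from collections import defaultdict

states: defaultdict[str, list[int]] = defaultdict(list)  # state -> observation ids
obs2state: dict[int, str] = {}  # observation id -> current state


def move(observation_id: int, new_state: str) -> None:
    # Remove the id from its old bucket, then record the new state in both maps.
    old_state = obs2state.get(observation_id)
    if old_state and observation_id in states.get(old_state, []):
        states[old_state].remove(observation_id)
    obs2state[observation_id] = new_state
    if observation_id not in states[new_state]:
        states[new_state].append(observation_id)


move(7, "aborted")
move(7, "completed")
assert states["aborted"] == [] and states["completed"] == [7]
```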
- def get(self, global_batch_size: int) -> Tuple[List[List[RLDataFlowItem]], List[MultimodalTrainInfo | None]]:
- """Retrieves a batch of finished sample groups from the buffer.
-
- Args:
- global_batch_size (int): The number of sample groups to retrieve.
- If fewer finished groups are available, only the available
- groups are returned.
-
- Returns:
- Tuple[List[List[RLDataFlowItem]], List[MultimodalTrainInfo | None]]:
- A list of sample groups plus the optional multimodal train info
- for each group. Each inner list contains a group of data items
- generated from the same initial prompt, repeated
- `prompt_repeat_k` times.
- """
- samples = []
- multimodal_train_infos = []
- target_batch_size = min(global_batch_size, self.completed_samples_count)
- self.logger.info(f"Retrieving {target_batch_size} completed samples from the replay buffer.")
- task_time = []
- for _ in range(target_batch_size):
- task_start_time = time.perf_counter()
- action_id = self._pop_highest_version_action(self._completed_actions)
- if action_id is None:
- self.logger.warning("Got action_id None from completed_actions; skipping this iteration.")
- continue
- replay_meta = self._actions.pop(action_id)
- group_samples = mapping_replaymeta_to_dataitem(replay_meta)
- # Remove this sample completely; the action_ids for this root_id no longer need to be tracked
- self._clear_meta_for_root(replay_meta)
- multimodal_train_info = None
- # TODO: Should deduplicated multimodal_train_infos be returned separately?
- for data_item in group_samples:
- if hasattr(data_item.data, "multimodal_train_info"):
- multimodal_train_info = data_item.data.multimodal_train_info
- del data_item.data.multimodal_train_info
- if "partial_rollout_input_ids" in data_item.env.rollout.extra_info:
- del data_item.env.rollout.extra_info["partial_rollout_input_ids"]
- samples.append(group_samples)
- multimodal_train_infos.append(multimodal_train_info)
- task_end_time = time.perf_counter()
- task_time.append(task_end_time - task_start_time)
- # Check whether any completed samples remain in the buffer and whether they have expired
- avg_time = sum(task_time) / len(task_time) if len(task_time) > 0 else 0
- self.logger.info(
- f"Remaining completed samples in buffer: {self.completed_samples_count}, task_time: {sum(task_time)}s, avg_time: {avg_time}s"
- )
- self._check_completed_samples_expired()
- self._check_completed_samples_aborted()
- return samples, multimodal_train_infos
-
- def sample(self, sample_from_expired_states) -> List[RLDataFlowItem]:
- if sample_from_expired_states and self.expired_samples_count > 0:
- self.sample_from_expired_count += 1
- return self._sample_from_expired_storage()
- if self.aborted_samples_count > 0:
- self.sample_from_aborted_count += 1
- return self._sample_from_aborted_storage()
- return []
-
- def clear(self):
- for replay_meta in list(self._actions.values()):
- self._release_replay_meta_refs(replay_meta)
-
- attrs_to_clear = [
- "_aborted_actions",
- "_completed_actions",
- "_expired_actions",
- "_actions",
- "_root2actions",
- "_observations",
- "_observations2states",
- "_states",
- "_action2observations",
- "_multimodal_train_infos",
- ]
- for attr in attrs_to_clear:
- getattr(self, attr).clear()
- self.sample_from_aborted_count = 0
- self.sample_from_expired_count = 0
-
- def snapshot_ray_objects(self, data_item: RLDataFlowItem):
- """Replaces nested ray.ObjectRefs with serializable markers."""
- self._snapshot_nested_objectrefs(data_item)
-
- def restore_ray_objects(self, data_item: RLDataFlowItem):
- """Restores nested ray.ObjectRefs from serialized snapshot markers."""
- self._restore_nested_objectrefs(data_item)
-
- def _snapshot_nested_objectrefs(self, obj: Any):
- if isinstance(obj, ObjectRef):
- value = ray.get(obj)
- return SerializedRayObjectRef(self._snapshot_nested_objectrefs(value))
- if isinstance(obj, BaseModel):
- for field_name in type(obj).model_fields:
- setattr(obj, field_name, self._snapshot_nested_objectrefs(getattr(obj, field_name)))
- return obj
- if isinstance(obj, list):
- for idx, value in enumerate(obj):
- obj[idx] = self._snapshot_nested_objectrefs(value)
- return obj
- if isinstance(obj, tuple):
- return tuple(self._snapshot_nested_objectrefs(value) for value in obj)
- if isinstance(obj, set):
- return {self._snapshot_nested_objectrefs(value) for value in obj}
- if isinstance(obj, dict):
- for key, value in list(obj.items()):
- obj[key] = self._snapshot_nested_objectrefs(value)
- return obj
- return obj
-
- def _restore_nested_objectrefs(self, obj: Any):
- if isinstance(obj, SerializedRayObjectRef):
- return ray.put(self._restore_nested_objectrefs(obj.value))
- if isinstance(obj, BaseModel):
- for field_name in type(obj).model_fields:
- setattr(obj, field_name, self._restore_nested_objectrefs(getattr(obj, field_name)))
- return obj
- if isinstance(obj, list):
- for idx, value in enumerate(obj):
- obj[idx] = self._restore_nested_objectrefs(value)
- return obj
- if isinstance(obj, tuple):
- return tuple(self._restore_nested_objectrefs(value) for value in obj)
- if isinstance(obj, set):
- return {self._restore_nested_objectrefs(value) for value in obj}
- if isinstance(obj, dict):
- for key, value in list(obj.items()):
- obj[key] = self._restore_nested_objectrefs(value)
- return obj
- return obj
-
- def convert_to_ray_objref(self, data_item: RLDataFlowItem):
- """Converts large tensors in RLDataFlowItem to ray.ObjectRefs.
-
- Args:
- data_item (RLDataFlowItem): The data item containing large tensors.
- Returns:
- RLDataFlowItem: The data item with large tensors converted to ray.ObjectRefs.
- """
- # convert data.multimodal_train_info to ray.ObjectRef
- if hasattr(data_item.data, "multimodal_train_info"):
- multimodal_info = data_item.data.multimodal_train_info
- if multimodal_info and "pixel_values" in multimodal_info:
- # A group shares the same data_item.data, so this conversion only needs to happen once
- if not isinstance(multimodal_info["pixel_values"], ray.ObjectRef):
- pixel_values_ref = ray.put(multimodal_info["pixel_values"])
- del multimodal_info["pixel_values"]
- data_item.data.multimodal_train_info["pixel_values"] = pixel_values_ref # type: ignore[index]
- # convert rollout.extra_info.router_experts to ray.ObjectRef
- if "routed_experts" in data_item.env.rollout.extra_info:
- if not isinstance(data_item.env.rollout.extra_info["routed_experts"], ray.ObjectRef):
- routed_experts_ref = ray.put(data_item.env.rollout.extra_info["routed_experts"])
- del data_item.env.rollout.extra_info["routed_experts"]
- data_item.env.rollout.extra_info["routed_experts"] = routed_experts_ref
-
- def has_objectref(self, item: RLDataFlowItem) -> bool:
- def check(obj):
- if isinstance(obj, ray.ObjectRef):
- return True
- if isinstance(obj, SerializedRayObjectRef):
- return check(obj.value)
- if isinstance(obj, BaseModel):
- return any(check(getattr(obj, f)) for f in type(obj).model_fields)
- if isinstance(obj, (list, tuple, set)):
- return any(check(x) for x in obj)
- if isinstance(obj, dict):
- return any(check(v) for v in obj.values())
- if isinstance(obj, (str, int, float, bool, type(None), torch.Tensor, numpy.ndarray)):
- return False
- # Raise an error for any other type to guard against unexpected problems
- raise TypeError(
- f"Unsupported type: {type(obj)} in {obj} "
- f"Expected ray.ObjectRef, SerializedRayObjectRef, BaseModel, list/tuple/set, dict, or primitive types."
- )
-
- return check(item)
-
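The snapshot/restore pair above makes a nested structure picklable by swapping each live `ray.ObjectRef` for a marker that holds its resolved value, then re-publishing those values on load. A condensed sketch of the same round-trip, assuming a local Ray runtime (`Marker`, `snapshot`, and `restore` are illustrative stand-ins for `SerializedRayObjectRef` and the two helpers):

```python
import ray
from dataclasses import dataclass
from typing import Any


@dataclass
class Marker:
    """Hypothetical stand-in for SerializedRayObjectRef (illustration only)."""
    value: Any


def snapshot(obj: Any) -> Any:
    # Resolve any ObjectRef to its value and wrap it, so the whole structure
    # can be pickled (e.g. via torch.save) without live refs.
    if isinstance(obj, ray.ObjectRef):
        return Marker(snapshot(ray.get(obj)))
    if isinstance(obj, list):
        return [snapshot(v) for v in obj]
    if isinstance(obj, dict):
        return {k: snapshot(v) for k, v in obj.items()}
    return obj


def restore(obj: Any) -> Any:
    # Re-publish wrapped values into the object store after loading.
    if isinstance(obj, Marker):
        return ray.put(restore(obj.value))
    if isinstance(obj, list):
        return [restore(v) for v in obj]
    if isinstance(obj, dict):
        return {k: restore(v) for k, v in obj.items()}
    return obj


ray.init(ignore_reinit_error=True)
payload = {"pixel_values": ray.put([1, 2, 3])}
roundtrip = restore(snapshot(payload))
assert ray.get(roundtrip["pixel_values"]) == [1, 2, 3]
```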
- """ - all_data_items = [] - for replay_meta in self._actions.values(): - # dump 仅用于序列化快照,这里可直接消费 refs,避免长时间占用 object store - data_items = mapping_replaymeta_to_dataitem(replay_meta, consume_refs=False) - for item in data_items: - self.snapshot_ray_objects(item) - res = self.has_objectref(item) - assert not res, "ReplayBufferStorage.dump found unresolved ray.ObjectRef in RLDataFlowItem" - all_data_items.append(data_items) - - state = { - "_completed_actions": self._completed_actions, - "_aborted_actions": self._aborted_actions, - "_expired_actions": self._expired_actions, - "_actions": all_data_items, - "_root2actions": dict(self._root2actions), - "_observations2states": self._observations2states, - "_states": dict(self._states), - "_action2observations": dict(self._action2observations), - "_multimodal_train_infos": self._multimodal_train_infos, - "sample_from_aborted_count": self.sample_from_aborted_count, - "sample_from_expired_count": self.sample_from_expired_count, - } - - torch.save(state, file_path) - self.logger.info(f"ReplayBufferStorage state dumped to {file_path}") - - def resume(self, file_path: Path): - """Resumes the replay buffer storage from a single file. - - Args: - file_path (str): The path to the file from which to restore the - state. - """ - - self.clear() - - state = torch.load(file_path, map_location="cpu", weights_only=False) - - self._completed_actions = defaultdict(list, state["_completed_actions"]) - self._aborted_actions = defaultdict(list, state["_aborted_actions"]) - self._expired_actions = state["_expired_actions"] - self._root2actions = defaultdict(list, state["_root2actions"]) - self._observations2states = state["_observations2states"] - self._states = defaultdict(list, state["_states"]) - self._action2observations = defaultdict(list, state["_action2observations"]) - self._multimodal_train_infos = state.get("_multimodal_train_infos", {}) - self.sample_from_aborted_count = state.get("sample_from_aborted_count", 0) - self.sample_from_expired_count = state.get("sample_from_expired_count", 0) - - dump_actions = state["_actions"] - # 重建 _actions 和 _observations: 与replaymeta相关 - for group_dataitem in dump_actions: - for data_item in group_dataitem: - self.restore_ray_objects(data_item) - replay_meta = mapping_dataitem_to_replaymeta(group_dataitem) - action_id = replay_meta.action_id - self._actions[action_id] = replay_meta - for observation_id in self._action2observations[action_id]: - self._observations[observation_id] = replay_meta - - self.logger.info(f"ReplayBufferStorage state successfully resumed from {file_path}") - self.logger.info( - f"ReplayBuffer Storage status: completed_samples_count={self.completed_samples_count}, aborted_samples_count={self.aborted_samples_count}, expired_samples_count={self.expired_samples_count}" - ) - - @property - def completed_samples_count(self) -> int: - return sum(len(bucket) for bucket in self._completed_actions.values()) - - @property - def aborted_samples_count(self): - return sum(len(bucket) for bucket in self._aborted_actions.values()) - - @property - def expired_samples_count(self): - return len(self._expired_actions) - - def _sample_from_expired_storage(self) -> List[RLDataFlowItem]: - """Samples an item from the expired storage for re-rollout. - - This method takes an action_id from the expired queue, retrieves its - original prompt data, cleans up all its previous rollout outputs, and - prepares it as a new sample group with a fresh action_id and reset - version (0) to be sent for a new generation attempt. 
-    def _sample_from_expired_storage(self) -> List[RLDataFlowItem]:
-        """Samples an item from the expired storage for re-rollout.
-
-        This method takes an action_id from the expired queue, retrieves its
-        original prompt data, cleans up all its previous rollout outputs, and
-        prepares it as a new sample group with its version reset to 0 to be
-        sent for a new generation attempt.
-
-        Returns:
-            List[RLDataFlowItem]: A list of data items ready for a new rollout.
-        """
-        assert len(self._expired_actions) > 0
-        action_id = self._expired_actions.pop()
-        replay_meta = self._actions.pop(action_id)
-        group_samples = mapping_replaymeta_to_dataitem(replay_meta)
-        # Drop every record from this sample's previous attempts and restart the
-        # rollout from scratch; the root2actions entry has to be cleared as well.
-        self._clear_meta_for_root(replay_meta)
-
-        for sample in group_samples:
-            assert sample.data.input_ids and sample.data.num_tokens, "input_ids or num_tokens is empty!"
-            if "routed_experts" in sample.env.rollout.extra_info:
-                ray.internal.free(sample.env.rollout.extra_info["routed_experts"], local_only=False)
-                del sample.env.rollout.extra_info["routed_experts"]
-            del sample.env
-            sample.env = RLEnvDataItem()  # reset the env data
-            sample.uid.action_id = action_id
-            sample.uid.version = 0
-
-        self.logger.debug(
-            f"Sampling expired action_id: {action_id} from replay buffer, remain expired samples: {len(self._expired_actions)}"
-        )
-        return group_samples
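The partial-rollout branch in the method below stitches earlier partial generations back onto the prompt and later (in SingleTurnEnvironment.generate) asks the engine only for the remaining token budget. A minimal sketch of that arithmetic, with illustrative values:

import itertools

prompt_ids = [1, 2, 3]
versioned_response_ids = [[10, 11], [12]]  # one list per earlier rollout attempt

history = list(itertools.chain.from_iterable(versioned_response_ids))
partial_rollout_input_ids = prompt_ids + history

# The engine is then asked only for what is left of the original budget:
max_tokens = 8
remaining = max_tokens - (len(partial_rollout_input_ids) - len(prompt_ids))
assert partial_rollout_input_ids == [1, 2, 3, 10, 11, 12]
assert remaining == 5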
-    def _sample_from_aborted_storage(self) -> List[RLDataFlowItem]:
-        """Samples an item from the aborted storage for re-rollout.
-
-        This method retrieves an action with the highest version (i.e., the
-        oldest samples) from the aborted buckets. It then cleans up its
-        previous (aborted) rollout outputs and prepares it as a new sample
-        group with a fresh action_id. The original version number is preserved
-        to track its retry history.
-
-        Returns:
-            List[RLDataFlowItem]: A list of data items ready for a new rollout.
-        """
-        assert self.aborted_samples_count > 0
-        action_id = self._pop_highest_version_action(self._aborted_actions)
-        # aborted_samples_count > 0 was checked above, so this never returns None here.
-        replay_meta = self._actions.pop(action_id)  # type: ignore[arg-type]
-        replay_meta_version = replay_meta.version
-        group_samples = mapping_replaymeta_to_dataitem(replay_meta)
-        # Drop the records of the outputs produced by the previous rollout;
-        # that data has already been recorded in RLEnvDataItem.
-        self._clear_meta_for_actions(replay_meta)
-
-        sample_action_id = uuid4().int
-        for sample in group_samples:
-            assert sample.data.input_ids and sample.data.num_tokens, "input_ids or num_tokens is empty!"
-            if not self.enable_partial_rollout:
-                # Clear the env data (response_ids, etc.) from the previous attempt.
-                if "routed_experts" in sample.env.rollout.extra_info:
-                    ray.internal.free(sample.env.rollout.extra_info["routed_experts"], local_only=False)
-                    del sample.env.rollout.extra_info["routed_experts"]
-                del sample.env
-                sample.env = RLEnvDataItem()
-                sample.uid.version = 0
-                sample.uid.action_id = action_id if action_id is not None else sample_action_id
-            else:
-                # Keep as much of the async bookkeeping as possible inside the replay
-                # buffer, rather than handling it in env/rollout.
-                history_response_ids = list(itertools.chain.from_iterable(sample.env.rollout.versioned_response_ids))
-                sample.env.rollout.extra_info["partial_rollout_input_ids"] = (
-                    sample.data.input_ids + history_response_ids
-                )
-                self.logger.debug(
-                    f"partial rollout enabled, {sample_action_id} pass response_ids {len(history_response_ids)} to input_ids {len(sample.data.input_ids)} to data extra info when sampling."
-                )
-                sample.uid.version = replay_meta_version
-                sample.uid.action_id = int(sample_action_id)
-
-        self.logger.debug(
-            f"Sampling aborted action_id: {sample_action_id}, root_id: {group_samples[0].uid.root_id} from replay buffer, remain aborted samples: {self.aborted_samples_count}"
-        )
-        return group_samples
-
-    def _pop_highest_version_action(self, buckets: Dict[int, List[int]]) -> Optional[int]:
-        if not buckets:
-            return None
-
-        highest_version = sorted(buckets.keys())[-1]
-        action_list = buckets[highest_version]
-        action_id = action_list.pop()
-        if not action_list:
-            del buckets[highest_version]
-
-        return action_id
-
-    def _check_completed_samples_expired(self):
-        """Moves samples from completed buckets to the expired list if they
-        have grown too old after the target completed samples were fetched
-        from the replay buffer.
-
-        This method iterates through the `_completed_actions` buckets. If a
-        bucket's version index is greater than or equal to the configured
-        `tail_batch_candidate_steps`, all action_ids within that bucket are
-        moved to the `_expired_actions` list, marking them as expired.
-        """
-        if self.tail_batch_candidate_steps <= 0:
-            return
-
-        expired_versions = [v for v in self._completed_actions if v >= self.tail_batch_candidate_steps]
-
-        for version in expired_versions:
-            bucket = self._completed_actions.pop(version)
-            for action_id in bucket:
-                replay_meta = self._actions.get(action_id)
-                if replay_meta is not None:
-                    self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.EXPIRED)
-            self._expired_actions.extend(bucket)
-            self.logger.info(
-                f"Moved {len(bucket)} completed samples with version {version} to expired samples due to exceeding tail_batch_candidate_steps."
-            )
-
-    def _check_completed_samples_aborted(self):
-        if self.enable_partial_rollout:
-            return
-
-        for version, bucket in self._completed_actions.items():
-            for action_id in bucket:
-                replay_meta = self._actions.get(action_id)
-                if replay_meta is not None:
-                    self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.ABORTED)
-            self._aborted_actions[0].extend(bucket)
-            self.logger.info(
-                f"Moved {len(bucket)} completed samples with version {version} to aborted samples because partial rollout is disabled."
-            )
-        self._completed_actions.clear()
-
-    def _release_replay_meta_refs(self, replay_meta: ReplayMeta):
-        refs_to_free: List[ObjectRef] = []
-        if isinstance(replay_meta.action_ref, ObjectRef):
-            refs_to_free.append(replay_meta.action_ref)
-        refs_to_free.extend([ref for ref in replay_meta.observation_refs if isinstance(ref, ObjectRef)])
-        free_object_refs(refs_to_free)
-        replay_meta.action_ref = None
-        replay_meta.observation_refs.clear()
-
-    def _clear_meta_for_actions(self, replay_meta: ReplayMeta):
-        """Completely removes an action and all its associated data from the
-        storage.
-
-        This is the single source of truth for deleting an action.
-        """
-        action_id = replay_meta.action_id
-
-        self._release_replay_meta_refs(replay_meta)
-
-        for observation_id in replay_meta.observation_ids:
-            self._observations.pop(observation_id, None)
-            state = self._observations2states.pop(observation_id, None)
-            if state and observation_id in self._states.get(state, []):
-                self._states[state].remove(observation_id)
-
-        self._actions.pop(action_id, None)
-        self._action2observations.pop(action_id, None)
-        del replay_meta
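`_pop_highest_version_action` above is a small priority-pop over the version buckets; an equivalent standalone sketch, using `max()` instead of sorting the whole key set:

from typing import Dict, List, Optional


def pop_highest_version(buckets: Dict[int, List[int]]) -> Optional[int]:
    if not buckets:
        return None
    highest = max(buckets)  # equivalent to sorted(buckets.keys())[-1], without the sort
    action_id = buckets[highest].pop()
    if not buckets[highest]:
        del buckets[highest]  # drop empty buckets so max() stays meaningful
    return action_id


buckets = {0: [1, 2], 2: [7]}
assert pop_highest_version(buckets) == 7
assert 2 not in buckets
assert pop_highest_version(buckets) == 2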
-    def _clear_meta_for_root(self, replay_meta: ReplayMeta):
-        """Clears all actions and associated metadata linked to the same
-        root_id.
-
-        This function is called after a sample group is successfully retrieved
-        for training. It ensures that all historical versions of a prompt
-        (linked by root_id) are purged from the storage to prevent them from
-        being re-sampled or replayed.
-
-        Args:
-            replay_meta (ReplayMeta): The metadata of the action that was just
-                retrieved. The root_id from this object will be used to find
-                and clear all related actions.
-        """
-        root_id = replay_meta.root_id
-        current_action_id = replay_meta.action_id
-
-        self._clear_meta_for_actions(replay_meta)
-
-        if root_id in self._root2actions:
-            for action_id in self._root2actions[root_id]:
-                if action_id == current_action_id:
-                    continue
-                new_replay_meta = self._actions.pop(action_id, None)
-                if new_replay_meta:
-                    self._clear_meta_for_actions(new_replay_meta)
-            del self._root2actions[root_id]
-
-    def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta):
-        """Checks the rollout state of a ReplayMeta object and inserts its
-        action_id into the appropriate state bucket.
-
-        This method acts as a router, directing action_ids to different storage
-        lists (_aborted_actions, _completed_actions, _expired_actions) based on
-        their final rollout state and version. It also handles the logic for
-        when an aborted sample becomes expired due to too many retries.
-
-        Args:
-            replay_meta (ReplayMeta): The metadata object containing the final
-                state and version of a rollout action.
-        """
-        state = replay_meta.state
-        root_id = replay_meta.root_id
-        action_id = replay_meta.action_id
-
-        if state == RolloutState.ABORTED:
-            if self.tail_batch_candidate_steps > 0 and replay_meta.version >= self.tail_batch_candidate_steps:
-                # Expired samples need their state reset.
-                self._expired_actions.append(action_id)
-                self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.EXPIRED)
-                self.logger.debug(
-                    f"Add expired sample with action_id: {action_id} to _expired_actions because version: {replay_meta.version} >= tail_batch_candidate_steps: {self.tail_batch_candidate_steps}."
-                )
-            else:
-                if not self.enable_partial_rollout:
-                    self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.ABORTED)
-                self._aborted_actions[replay_meta.version].append(action_id)
-                self.logger.debug(
-                    f"Add aborted sample with action_id: {action_id} version: {replay_meta.version} to _aborted_actions."
-                )
-        elif state == RolloutState.COMPLETED:
-            self._completed_actions[replay_meta.version].append(action_id)
-            self.logger.debug(f"Add sample with root_id: {root_id}, action_id: {action_id} to _completed_actions.")
-        else:
-            raise AssertionError(
-                f"Unsupported rollout state {state} for action_id {action_id} in ReplayBufferStorage."
-            )
-
-
-class ReplayBuffer:
-    """A Ray actor that manages experience replay for reinforcement
-    learning."""
-
-    def __init__(
-        self,
-        config: ReplayBufferConfig,
-    ):
-        """Initializes the ReplayBuffer actor.
-
-        Args:
-            config (ReplayBufferConfig): The configuration object.
-        """
-        self.config = config
-        self.storage = ReplayBufferStorage(config)
-        self.sampler = DatasetSampler(config.dataset_cfg, config.dataloader_cfg, config.tokenizer)
-        self.post_processor_func = config.postprocessor_func
-        self.sample_from_expired_states = False
-        self.sample_from_dataset_count = 0
-        self.logger = get_logger(log_dir=config.worker_log_dir, tag="ReplayBuffer")
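The storage class above routes each finished rollout by its (state, version) pair; a condensed sketch of just that routing decision (the enum, thresholds, and bucket shapes are illustrative, and only the two states the router accepts are handled):

from collections import defaultdict
from enum import Enum


class State(Enum):
    ABORTED = "aborted"
    COMPLETED = "completed"


def route(state: State, version: int, action_id: int, tail_steps: int,
          expired: list, aborted: dict, completed: dict) -> str:
    if state is State.ABORTED:
        if 0 < tail_steps <= version:
            expired.append(action_id)       # retried too often: expire it
            return "expired"
        aborted[version].append(action_id)  # keep for retry, bucketed by version
        return "aborted"
    completed[version].append(action_id)    # ready to be served for training
    return "completed"


expired, aborted, completed = [], defaultdict(list), defaultdict(list)
assert route(State.ABORTED, 5, 42, tail_steps=3, expired=expired, aborted=aborted, completed=completed) == "expired"
assert route(State.ABORTED, 1, 43, tail_steps=3, expired=expired, aborted=aborted, completed=completed) == "aborted"
assert route(State.COMPLETED, 0, 44, tail_steps=3, expired=expired, aborted=aborted, completed=completed) == "completed"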
-    def setup_storage_config(
-        self, enable_partial_rollout: bool, tail_batch_candidate_steps: int, tail_batch_trigger_size: int
-    ):
-        """Sets up the storage configuration for the replay buffer.
-
-        Args:
-            enable_partial_rollout (bool): Flag to enable partial rollouts.
-            tail_batch_candidate_steps (int): Number of steps to consider for
-                tail batch sampling.
-            tail_batch_trigger_size (int): Threshold size to trigger tail batch
-                sampling.
-        """
-        self.storage.enable_partial_rollout = enable_partial_rollout
-        self.storage.tail_batch_candidate_steps = tail_batch_candidate_steps
-        self.storage.tail_batch_trigger_size = tail_batch_trigger_size
-
-    def get_prerun_state(self) -> Tuple[bool, int]:
-        if (
-            self.storage.tail_batch_trigger_size > 0
-            and self.storage.expired_samples_count > self.storage.tail_batch_trigger_size
-        ):
-            self.sample_from_expired_states = True
-            self.logger.info(
-                f"Enable sampling from expired states. Expired samples: {self.storage.expired_samples_count}, threshold: {self.storage.tail_batch_trigger_size}"
-            )
-        else:
-            self.sample_from_expired_states = False
-        return self.sample_from_expired_states, self.storage.completed_samples_count
-
-    def get_train_dataset_length(self):
-        """Returns the length of the training dataloader."""
-        return len(self.sampler.dataloader)
-
-    def post_processor(self, group_samples):
-        """Applies a post-processing function to a group of samples.
-
-        Args:
-            group_samples: A list of samples to process.
-
-        Returns:
-            The processed group of samples.
-        """
-        if self.post_processor_func:
-            group_samples = self.post_processor_func(group_samples)
-        return group_samples
-
-    def sample(self, env, prompt_repeat_k) -> List[RLDataFlowItem]:
-        """Samples a batch of experiences from the replay buffer.
-
-        Args:
-            env: The environment name.
-            prompt_repeat_k (int): Number of times to repeat a prompt.
-
-        Returns:
-            A list of sampled data items.
-        """
-        storage_samples = self.storage.sample(self.sample_from_expired_states)
-        if storage_samples:
-            return storage_samples
-        else:
-            self.sample_from_dataset_count += 1
-            return self.sampler.sample(env, prompt_repeat_k)
-
-    def get_samples(
-        self,
-        global_batch_size: int,
-    ):
-        """Gets a batch of finished samples from the storage.
-
-        Args:
-            global_batch_size (int): The number of sample groups to retrieve.
-
-        Returns:
-            A list of sample groups.
-        """
-        return self.storage.get(global_batch_size)
-
-    def add(self, grouped_dataitem: List[RLDataFlowItem]):
-        """Adds a group of data items to the replay buffer storage.
-
-        Args:
-            grouped_dataitem (List[RLDataFlowItem]): A list of data items
-                from the same group.
-        """
-        self.storage.add(grouped_dataitem)
-
-    def status(self):
-        return {
-            "remain_completed_samples_count": self.storage.completed_samples_count,
-            "remain_aborted_samples_count": self.storage.aborted_samples_count,
-            "remain_expired_samples_count": self.storage.expired_samples_count,
-            "sample_from_dataset_count": self.sample_from_dataset_count,
-            "sample_from_aborted_count": self.storage.sample_from_aborted_count,
-            "sample_from_expired_count": self.storage.sample_from_expired_count,
-        }
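`get_prerun_state` and `sample` together implement a storage-first policy: once enough expired work piles up, re-serve it ahead of fresh dataset prompts. The trigger condition in isolation (parameter names are illustrative):

def should_sample_expired(expired_count: int, trigger_size: int) -> bool:
    # Tail-batch sampling turns on only once enough expired samples accumulate.
    return trigger_size > 0 and expired_count > trigger_size


assert should_sample_expired(expired_count=9, trigger_size=8)
assert not should_sample_expired(expired_count=3, trigger_size=8)
assert not should_sample_expired(expired_count=3, trigger_size=0)  # feature disabled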
- """ - if isinstance(file_path, str): - file_path = Path(file_path) - file_path.mkdir(parents=True, exist_ok=True) - - # save dataloader - dataloader_path = file_path / "dataloader" - dataloader_state = self.sampler.dataloader.get_state_dict() - torch.save(dataloader_state, dataloader_path) - - # save storage - rb_storage_path = file_path / "replay_buffer_storage.pth" - self.storage.dump(rb_storage_path) - - def resume(self, file_path: Path | str): - """Resumes the replay buffer's storage from a file. - - Args: - file_path (str): The path to the file from which to restore the - state. - """ - if isinstance(file_path, str): - file_path = Path(file_path) - dataloader_path = file_path / "dataloader" - if dataloader_path.exists(): - self.sampler.resume(dataloader_path) - self.sample_from_dataset_count = self.sampler.reduced_consumed_samples - self.logger.info( - f"Dataloader state successfully resumed from {dataloader_path} and set to epoch {self.sampler.cur_epoch} and step {self.sampler.reduced_consumed_samples}." - ) - else: - self.logger.warning(f"Dataloader state file {dataloader_path} does not exist. Skipping dataloader resume.") - # resume storage - rb_storage_path = file_path / "replay_buffer_storage.pth" - if rb_storage_path.exists(): - self.storage.resume(rb_storage_path) - else: - self.logger.warning( - f"ReplayBufferStorage state file {rb_storage_path} does not exist. Skipping storage resume." - ) - - def get_completed_samples_count(self) -> int: - """Returns the count of completed samples in the replay buffer. - - Returns: - int: The number of completed samples. - """ - return self.storage.completed_samples_count - - def clear(self): - """Clears the replay buffer storage.""" - self.storage.clear() diff --git a/xtuner/v1/ray/environment/__init__.py b/xtuner/v1/ray/environment/__init__.py deleted file mode 100644 index 02112e66a7..0000000000 --- a/xtuner/v1/ray/environment/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .base_env import BaseEnvironment -from .single_turn_env import SingleTurnEnvironment, SingleTurnEnvironmentProxy diff --git a/xtuner/v1/ray/environment/base_env.py b/xtuner/v1/ray/environment/base_env.py deleted file mode 100644 index 4f73af9004..0000000000 --- a/xtuner/v1/ray/environment/base_env.py +++ /dev/null @@ -1,247 +0,0 @@ -import os -from abc import ABC, abstractmethod -from typing import Any, List - -import ray - -from xtuner.v1.data_proto.rl_data import RLDataFlowItem -from xtuner.v1.utils import ray_method - - -class BaseEnvironment(ABC): - """The BaseEnvironment class provides a foundational structure for managing - rollout and judger controllers for single-turn generation or multi-turn - generation. - - This class is responsible for initializing the necessary controllers based on the provided - configurations and placement group. It defines abstract methods for generation and - execution, which must be implemented by subclasses. - - Args: - environment (str): The name or identifier of the environment. - rollout_pg (Any): The placement group for scheduling rollout Ray actors. - rollout_cfg (Any, optional): The configuration for the rollout controller. Defaults to None. - judger_pg (Any): The placement group for scheduling judger Ray actors. - Defaults to None indicates using the rollout_pg. - judger_cfg (Any, optional): The configuration for the judger controller. Defaults to None. 
- """ - - def __init__( - self, - environment: str, - rollout_pg: Any, - rollout_cfg: Any, - judger_pg: Any = None, - judger_cfg: Any = None, - rollout_controller=None, - judger_controller=None, - ): - # judger_pg = judger_pg if judger_pg else rollout_pg - self.environment = environment - if rollout_controller: - self.rollout_controller = rollout_controller - else: - self.rollout_controller = self.init_rollout_controller(rollout_cfg, rollout_pg) - if judger_controller: - self.judger_controller = judger_controller - else: - self.judger_controller = self.init_judger_controller(judger_cfg, judger_pg) - - def init_rollout_controller(self, rollout_cfg: Any, placement_group: Any): - """Initializes the rollout controller with the appropriate worker - backend. - - Based on the `rollout_cfg`, this method selects and initializes the corresponding - rollout worker (e.g., `LMDeployWorker` or `vLLMWorker`). It then creates and - returns a `RolloutController` to manage these workers. - - Args: - rollout_cfg (Any): The configuration for the rollout controller. - placement_group (Any): The placement group for scheduling Ray actors. - - Returns: - The initialized rollout controller, or None if `rollout_cfg` is not provided. - - Raises: - NotImplementedError: If the specified rollout backend is not supported. - """ - - rollout_controller = None - if rollout_cfg is None: - return rollout_controller - - from xtuner.v1.ray.rollout.controller import RolloutController - - rollout_controller = ( - ray.remote(RolloutController) - .options(max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000))) - .remote(rollout_cfg, placement_group) - ) # type: ignore[attr-defined] - return rollout_controller - - def init_judger_controller(self, judger_cfg: Any, placement_group: Any): - """Initializes the judger controller. - - If a `judger_cfg` is provided, this method creates and returns a `JudgerController` - to handle evaluation and judging tasks. - - Args: - judger_cfg (Any): The configuration for the judger controller. - placement_group (Any): The placement group for scheduling Ray actors. - - Returns: - The initialized judger controller, or None if `judger_cfg` is not provided. - """ - judger_controller = None - if judger_cfg: - from xtuner.v1.ray.judger.controller import JudgerController - - judger_controller = JudgerController.remote(judger_cfg, placement_group) # type: ignore[attr-defined] - return judger_controller - - @abstractmethod - @ray_method - async def generate( - self, data: List[RLDataFlowItem], sample_params: Any, extra_params: Any - ) -> List[RLDataFlowItem]: - """Generates responses from the model for the given data using the - inference engine. This method is primarily used for single-step - inference. - - Args: - data: The input data, which can be a single prompt, RLTextDataItem, or a list of RLTextDataItem. - sample_params: Sampling parameters for the generation process. - - Returns: - A list of generated samples, each populated with 'response_str' and 'state' - """ - pass - - @abstractmethod - @ray_method - async def run(self, data: List[RLDataFlowItem], sample_params: Any, extra_params: Any) -> List[RLDataFlowItem]: - """Executes a full cycle of generation and interpretation, such as - generating a response and then evaluating it with a judger. This method - can be extended to support complex interactions like multi-turn - dialogues. - - Args: - data: The input data for the generation process. - sample_params: Sampling parameters for generation. 
- - Returns: - A list of generated samples - """ - pass - - def _call_rollout_func(self, method_name: str, block: bool): - """A helper function to dynamically call a method on the rollout - controller. - - Args: - method_name (str): The name of the method to call. - block (bool): Whether to block until the call completes. - - Returns: - The result of the method call. - """ - import ray - - assert self.rollout_controller, "Rollout controller is not initialized." - if block: - return ray.get(getattr(self.rollout_controller, method_name).remote()) - return getattr(self.rollout_controller, method_name).remote() - - @ray_method - def pause(self, block=True) -> None: - """Pauses the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("pause", block) - - @ray_method - def shutdown(self, block=True) -> None: - """Shuts down the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("shutdown", block) - - @ray_method - def restart(self, block=True) -> None: - """Restarts the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("restart", block) - - @ray_method - def get_rollout_info(self, block=True) -> dict[str, Any]: - """Gets information about the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("get_rollout_info", block) - - @ray_method - def onload_weights(self, block=True) -> None: - """Loads weights onto the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("onload_weights", block) - - @ray_method - def onload_kvcache(self, block=True) -> str: - """Loads the KV cache onto the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("onload_kvcache", block) - - @ray_method - def offload(self, block=True) -> str: - """Offloads weights and the KV cache from the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("offload", block) - - @ray_method - def update_active_workers(self, block=True) -> None: - """Checks the status of active rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - return self._call_rollout_func("update_active_workers", block) - - @ray_method - def get_rollout_stats(self, block=True) -> dict[str, Any]: - """Gets statistics from the rollout workers. - - Args: - block (bool): Whether to block until the operation completes. 
- """ - return self._call_rollout_func("get_rollout_stats", block) - - @ray_method - async def abort_judger(self): - """Abort in-flight judger requests for judgers with - ``abort_on_pause=True``.""" - if self.judger_controller: - await self.judger_controller.abort.remote() - - @ray_method - async def restart_judger(self): - """Clear abort state on judgers with ``abort_on_pause=True``.""" - if self.judger_controller: - await self.judger_controller.restart_judger.remote() diff --git a/xtuner/v1/ray/environment/single_turn_env.py b/xtuner/v1/ray/environment/single_turn_env.py deleted file mode 100644 index 43fcad7544..0000000000 --- a/xtuner/v1/ray/environment/single_turn_env.py +++ /dev/null @@ -1,221 +0,0 @@ -import asyncio -import copy -import os -from pathlib import Path -from typing import List, cast - -import ray -from ray.actor import ActorClass, ActorProxy - -from xtuner.v1.data_proto.rl_data import ( - RLDataFlowItem, - RLJudgerResponseItem, - RLRolloutResponseItem, - RolloutState, - is_valid_for_training, - update_dataflow_item, - update_rollout_item, -) -from xtuner.v1.ray.environment.base_env import BaseEnvironment -from xtuner.v1.ray.utils import build_deterministic_session_id, deterministic_item_sort_key -from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger, ray_method - - -class RawSingleTurnEnvironment(BaseEnvironment): - """A single-turn environment for handling generation and evaluation tasks. - - This class extends `BaseEnvironment` to provide a concrete implementation for - single-turn interactions. It manages the rollout process for generating responses - and can coordinate with a judger for evaluation. - - Args: - environment (str): The name of the environment. - rollout_pg: The placement group for scheduling rollout Ray actors. - rollout_cfg (optional): Configuration for the rollout controller. Defaults to None. - judger_pg (Any): The placement group for scheduling judger Ray actors. - Defaults to None indicates using the rollout_pg. - judger_cfg (optional): Configuration for the judger controller. Defaults to None. - rollout_controller (optional): An instance of the rollout controller. Defaults to None. - judger_controller (optional): An instance of the judger controller. Defaults to None. - """ - - def __init__( - self, - environment: str, - rollout_pg, - rollout_cfg=None, - judger_pg=None, - judger_cfg=None, - rollout_controller=None, - judger_controller=None, - ): - super().__init__( - environment, rollout_pg, rollout_cfg, judger_pg, judger_cfg, rollout_controller, judger_controller - ) - if rollout_cfg: - worker_log_dir = rollout_cfg.worker_log_dir - elif judger_cfg: - worker_log_dir = judger_cfg.worker_log_dir - else: - worker_log_dir = Path.cwd() / "work_dir" - self.logger = get_logger(log_dir=worker_log_dir, tag="SingleTurnEnv") - if rollout_cfg and rollout_cfg.enable_return_routed_experts: - self.logger.info("!!! Enable `return routed experts` in rollout controller. !!!") - self.rollout_timeout = rollout_cfg.rollout_timeout if rollout_cfg else 1200.0 - self.judger_timeout = judger_cfg.judger_timeout if judger_cfg else 1200.0 - # The timeout for the environment to wait for the rollout controller's response. - # This should be longer than the controller's internal timeout (`rollout_timeout`) - # to account for potential queuing delays and other overheads. 
- self.timeout_multiplier = 2.0 - self.cancel_response_timeout = 5.0 - self.rollout_cfg = rollout_cfg - - async def generate( # type: ignore[override] - self, group_data_items: List[RLDataFlowItem], sample_params=None, extra_params=None - ) -> List[RLDataFlowItem]: - """Generate responses for a batch of RLTextDataItem using the rollout - controller. - - Each item in `group_data_items` will be sent to the rollout controller for response generation - with the provided sampling parameters. The generated response string and state will be - added to each RLTextDataItem in-place as `response_str` and `state` fields. - - Args: - group_data_items (List[RLTextDataItem]): - A list of RLTextDataItem objects containing the prompts/messages for generation. - sample_params: Sampling parameters for the generation process. The type should match - the rollout controller's expected sampling parameter type (e.g., SampleParams or dict). - extra_params: Extra parameters for generation. If contains "disable_routed_experts=True", - will force disable return_routed_experts regardless of config. - - Returns: - List[RLTextDataItem]: - The same list of RLTextDataItem, with each item enriched with the generated response - and state from the rollout controller. - """ - if extra_params is None: - extra_params = {} - if self.rollout_controller: - if XTUNER_DETERMINISTIC: - group_data_items = sorted(group_data_items, key=deterministic_item_sort_key) - response_future = [] - for i, sample in enumerate(group_data_items): - rollout_extra_info = dict(sample.data.extra_info) - rollout_extra_info["root_id"] = sample.uid.root_id - rollout_extra_info["action_id"] = sample.uid.action_id - rollout_extra_info["observation_id"] = sample.uid.observation_id - update_sample_params = sample_params - session_id = None - if XTUNER_DETERMINISTIC: - update_sample_params = copy.deepcopy(sample_params) - update_sample_params.sampling_seed = self.rollout_cfg.random_seed + i - session_id = build_deterministic_session_id(self.environment, sample) - - if "partial_rollout_input_ids" in sample.env.rollout.extra_info: - input_ids_length = len(sample.data.input_ids) if sample.data.input_ids is not None else 0 - current_partial_length = len(sample.env.rollout.extra_info["partial_rollout_input_ids"]) - rollout_extra_info["partial_rollout_input_ids"] = sample.env.rollout.extra_info[ - "partial_rollout_input_ids" - ] - assert sample_params is not None, "sample_params should not be None when using partial rollout." 
- update_sample_params = copy.deepcopy(sample_params) - update_sample_params.max_tokens = sample_params.max_tokens - ( - current_partial_length - input_ids_length - ) - self.logger.debug( - f"root_id: {sample.uid.root_id}, action_id {sample.uid.action_id} pass current_partial_length {current_partial_length}, input_ids_length {input_ids_length} to rollout and set max_tokens to {update_sample_params.max_tokens}" - ) - - if "routed_experts" in sample.env.rollout.extra_info: - rollout_extra_info["routed_experts"] = sample.env.rollout.extra_info["routed_experts"] - - fut = self.rollout_controller.rollout.remote( - prompt=sample.data.messages, - input_ids=sample.data.input_ids, - sample_params=update_sample_params, - extra_params=extra_params, - session_id=session_id, - extra_info=rollout_extra_info, - ) - del rollout_extra_info - - response_future.append(fut) - try: - response_gather = asyncio.gather(*response_future) - rollout_responses = await asyncio.wait_for( - asyncio.shield(response_gather), timeout=self.rollout_timeout * self.timeout_multiplier - ) - except asyncio.CancelledError as exc: - for fut in response_future: - ray.cancel(fut, recursive=True) - try: - rollout_responses = await asyncio.wait_for( - asyncio.gather(*response_future, return_exceptions=True), - timeout=self.cancel_response_timeout, - ) - except BaseException: - raise exc - if not all(isinstance(response, RLRolloutResponseItem) for response in rollout_responses): - raise exc - except asyncio.TimeoutError: - for fut in response_future: - ray.cancel(fut, recursive=True) - self.logger.error("Get rollout controller response timeout and return the failed response.") - rollout_responses = [RLRolloutResponseItem(state="skipped") for _ in group_data_items] - group_data_items = update_rollout_item(group_data_items, rollout_responses) - return group_data_items - - @ray_method - async def run( # type: ignore[override] - self, group_data_items: List[RLDataFlowItem], sample_params=None, extra_params=None - ) -> List[RLDataFlowItem]: - """Runs a full generation and judger cycle. - - This method first generates responses using the `generate` method and then, - if a judger controller is available, it uses the judger to evaluate the - generated responses. - - Args: - data: The input data for the cycle. Can be a list of strings, - a single `RLTextDataItem`, or a list of `RLTextDataItem`. - sample_params: Sampling parameters for the generation process. - - Returns: - The data enriched with generated responses and evaluation results. - The format of the return value matches the format of the input `data`. 
- """ - group_data_items = await self.generate(group_data_items, sample_params, extra_params) # type: ignore[assignment] - continue_judger = is_valid_for_training(group_data_items) - if self.judger_controller and continue_judger: - judger_response_ref = self.judger_controller.run.remote(group_data_items) - try: - judger_responses: List[RLJudgerResponseItem] = await asyncio.wait_for( - judger_response_ref, - timeout=self.judger_timeout * self.timeout_multiplier, - ) - except asyncio.CancelledError: - ray.cancel(judger_response_ref, recursive=True) - raise - except asyncio.TimeoutError: - ray.cancel(judger_response_ref, recursive=True) - self.logger.error("Get judger controller response timeout and return the failed response.") - judger_responses = [ - RLJudgerResponseItem( - extra_info={"state": "failed"}, - ) - for _ in group_data_items - ] - group_data_items = update_dataflow_item(group_data_items, "env.judger", judger_responses) - # Mark items whose judger was aborted so they are treated as ABORTED - # rather than COMPLETED in downstream state determination. - for item, judger_resp in zip(group_data_items, judger_responses): - if isinstance(judger_resp.extra_info, dict) and judger_resp.extra_info.get("state") == "aborted": - item.env.rollout.state = RolloutState.ABORTED - return group_data_items - - -SingleTurnEnvironment = cast( - ActorClass[RawSingleTurnEnvironment], - ray.remote(max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000)))(RawSingleTurnEnvironment), -) -SingleTurnEnvironmentProxy = ActorProxy[RawSingleTurnEnvironment] diff --git a/xtuner/v1/ray/evaluator.py b/xtuner/v1/ray/evaluator.py deleted file mode 100644 index 90215bc5ef..0000000000 --- a/xtuner/v1/ray/evaluator.py +++ /dev/null @@ -1,290 +0,0 @@ -import asyncio -from pathlib import Path -from typing import Callable, List, Optional, Sized, TypeVar, Union -from uuid import uuid4 - -import ray -from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict, Field -from ray.actor import ActorProxy -from tqdm.auto import tqdm -from typing_extensions import Annotated - -from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast -from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLDatasetItem, RLUIDItem, SampleParams -from xtuner.v1.datasets import build_dataloader, build_datasets -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfigList -from xtuner.v1.ray.environment import BaseEnvironment -from xtuner.v1.ray.utils import create_task -from xtuner.v1.utils import get_logger -from xtuner.v1.utils.type_helper import ray_method - - -T = TypeVar("T") -Ret = TypeVar("Ret") - - -class EvaluatorConfig(BaseModel): - """Configuration for the Evaluator in XTuner. - - This class defines the configuration parameters for model evaluation in XTuner, including four main aspects: - - - Dataset configuration: Specifies the evaluation dataset and tokenizer for text processing - - - Evaluator control logic: Manages concurrency levels and retry mechanisms for robust evaluation - - - Evaluation scheduling: Controls evaluation step intervals and sample size (either by ratio or absolute count) - - - Custom metric computation: Supports user-defined functions for specialized metric calculations - - Args: - dataset_cfg (DatasetConfigList): Configuration for the evaluation dataset. - tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast): Tokenizer used for text processing. - evaluate_step (int): Step interval for triggering evaluation. Defaults to 1. 
- eval_sample_ratio (float): Ratio of samples to evaluate from the generated samples. If > 0, overrides eval_sample_num. Defaults to 0 (use all samples). - eval_sample_num (int): Number of samples to evaluate from the generated samples. Used if eval_sample_ratio is 0. Defaults to 0 (use all samples). - max_concurrent (int): Maximum number of concurrent evaluation tasks. Defaults to 8. - max_retry_times (int): Maximum number of retry attempts for failed evaluation tasks. Defaults to 2. - compute_metric_func (Optional[Callable]): Optional function to compute or filter metrics for generated data groups. If None, uses default metric computation. - - **Examples:** - - Example configuration for evaluator with GSM8K dataset:: - - from transformers import AutoTokenizer - - config = EvaluatorConfig( - dataset_cfg=[{ - "dataset": DatasetConfig(name="gsm8k", anno_path="test_data.json"), - "tokenize_fn": RLTokenizeFnConfig(max_length=512) - }], - tokenizer=AutoTokenizer.from_pretrained("model_path"), - max_concurrent=32, - eval_sample_ratio=0.8, # Use 80% of samples - evaluate_step=10, - compute_metric_func=custom_accuracy_metric - ) - """ - - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") - - enable_evaluate: Annotated[ - bool, - Parameter(help="Flag to enable or disable evaluation during training."), - ] = False - enable_initial_evaluate: Annotated[ - bool, - Parameter(help="Flag to enable or disable initial evaluation before training starts."), - ] = False - dataset_cfg: Annotated[ - DatasetConfigList, - Parameter(help="Configuration for the dataset."), - ] - dataloader_cfg: Annotated[ - Optional[DataloaderConfig], Parameter(help="The PyTorch DataLoader for iterating over the dataset.") - ] = None - - tokenizer: Annotated[ - Union[PreTrainedTokenizer, PreTrainedTokenizerFast, str], - Field(exclude=True), - Parameter(help="Tokenizer for text processing."), - ] - max_concurrent: Annotated[ - int, - Parameter(help="Maximum number of concurrent tasks."), - ] = 512 - eval_sample_ratio: Annotated[ - float, - Parameter(help="Ratio of samples to evaluate from the generated samples."), - ] = 0 - eval_sample_num: Annotated[ - int, - Parameter(help="Number of samples to evaluate from the generated samples."), - ] = 0 - max_retry_times: Annotated[int, Parameter(help="Maximum number of retry attempts for failed tasks.")] = 2 - evaluate_step: Annotated[int, Parameter(help="Step interval for evaluation.")] = 1 - compute_metric_func: Annotated[ - Optional[Callable], - Field(exclude=True), - Parameter(help="An optional function to filter or modify data groups after they are generated."), - ] = None - sample_params: Annotated[ - SampleParams, - Parameter(help="Sampling parameters for evaluation."), - ] = SampleParams() - worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir" - - -class RawEvaluator: - """A Ray actor for evaluating a model's performance on a given dataset. - - The Evaluator generates responses using an environment controller or rollout controller, then it use default or - custom computes metrics function to compute scores for generated samples. It returns the evaluation scores and - generated samples. - """ - - def __init__(self, config: EvaluatorConfig, env_controller: BaseEnvironment): - """Initialize the Evaluator. - - Args: - config (EvaluatorConfig): The configuration for the evaluator. - env_controller (EnvController): The environment controller used for - generating responses. 
- """ - self.config = config - self.sample_params = self.config.sample_params - self.dataset = ( - build_datasets(config.dataset_cfg, config.tokenizer) - if isinstance(config.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)) - else build_datasets( - config.dataset_cfg, AutoTokenizer.from_pretrained(config.tokenizer, trust_remote_code=True) - ) - ) - - if config.dataloader_cfg is not None: - self.dataloader_cfg = config.dataloader_cfg - else: - self.dataloader_cfg = DataloaderConfig( - collator="fake_collator", - pack_level="none", - ) - self.dataloader = build_dataloader( - dataloader_config=self.dataloader_cfg, - datasets=self.dataset, - global_batch_size=1, - micro_batch_size=1, - seed=1, - ) - assert isinstance(self.dataloader, Sized) - - self.env_controller = env_controller - self.return_list: List[RLDataFlowItem] = [] - if self.config.eval_sample_ratio > 0: - self.eval_batch_size = int(len(self.dataloader) * self.config.eval_sample_ratio) - elif self.config.eval_sample_num > 0: - self.eval_batch_size = self.config.eval_sample_num - else: - self.eval_batch_size = len(self.dataloader) - if self.config.compute_metric_func is not None: - self.compute_metric = self.config.compute_metric_func - else: - self.compute_metric = self.default_compute_metric - self.logger = get_logger(log_dir=config.worker_log_dir, tag="Evaluator") - - def default_compute_metric(self, samples): - """Default metric computation function. - - Calculates accuracy based on whether the reward is positive. - - Args: - samples (list): A list of RLDataFlowItem samples. - - Returns: - dict: A dictionary containing the accuracy score. - """ - return {"accuracy": sum(s.env.judger.reward["score"] > 0 for s in samples) / len(samples)} - - async def eval_worker_task(self, sample: RLDataFlowItem): - """A single worker task to evaluate one sample. - - This task calls the environment controller to run the model on a - sample. If it fails, it returns the sample with an incremented - retry count. - - Args: - sample (RLDataFlowItem): The data item to evaluate. - - Returns: - RLDataFlowItem or None: The sample with retry information if it - failed, or None if it succeeded or failed without a sample. - """ - # Force disable return_routed_experts for evaluator to reduce overhead - extra_params = {"disable_routed_experts": True} - group_sample = await self.env_controller.run.remote( - [sample], sample_params=self.sample_params, extra_params=extra_params - ) # type: ignore[attr-defined] - self.return_list.append(group_sample[0]) - - async def concurrent_eval_task_runner(self): - """Runs evaluation tasks concurrently to generate a batch of samples. - - This method orchestrates the evaluation process by creating and managing - a pool of asynchronous worker tasks. It continuously fetches data from - the dataloader and submits evaluation tasks until the desired number of - samples (`self.eval_batch_size`) has been successfully processed. 
- """ - waiting_tasks = set() - self.logger.info(f"Start to generate {self.eval_batch_size} samples for evaluate") - self.logger.info(f"Evaluate sample parameters set to {self.sample_params}.") - data_iter = iter(self.dataloader) - with tqdm(total=self.eval_batch_size, desc="Rollout for eval samples") as pbar: - update_step = max(1, int(self.eval_batch_size * 0.1)) - next_update_threshold = update_step - while len(self.return_list) < self.eval_batch_size: - if len(self.return_list) >= next_update_threshold: - pbar.n = len(self.return_list) - pbar.refresh() - next_update_threshold += update_step - while len(waiting_tasks) < self.config.max_concurrent: - if len(self.return_list) + len(waiting_tasks) >= self.eval_batch_size: - break - try: - data = next(data_iter) - except StopIteration: - data_iter = iter(self.dataloader) - data = next(data_iter) - self.logger.warning("Restarting the evaluation dataset.") - uid = RLUIDItem(action_id=uuid4().int, observation_id=uuid4().int) - data_item = RLDataFlowItem(data=RLDatasetItem(**data[0]), uid=uid) - task = create_task(self.eval_worker_task(data_item)) - waiting_tasks.add(task) - - if len(waiting_tasks) == 0: - break - - _, pending_tasks = await asyncio.wait(waiting_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED) - waiting_tasks = pending_tasks - - pbar.n = len(self.return_list) - pbar.refresh() - - self.logger.info("Target batch size reached.") - if waiting_tasks: - await asyncio.wait_for(asyncio.gather(*waiting_tasks, return_exceptions=True), timeout=10) - - rollout_stats = await self.env_controller.get_rollout_stats.remote() # type: ignore[attr-defined] - self.logger.info(rollout_stats) - - @ray_method - async def run(self, return_samples=False): - """Run the full evaluation process. - - This method resets the state, runs the concurrent task runner, - computes the final metrics, and returns the results. - - Args: - sample_params (Optional[SampleParams]): Sampling parameters for - generation. Defaults to a greedy strategy. - return_samples (bool): Whether to return the generated samples - along with the scores. Defaults to False. - - Returns: - dict or Tuple[dict, list]: The evaluation scores, and optionally - the generated samples. 
- """ - self.return_list = [] - await self.env_controller.restart.remote() # type: ignore[attr-defined] - await self.concurrent_eval_task_runner() - if len(self.return_list) == 0: - self.logger.warning("No valid samples were generated during evaluation.") - return {} if not return_samples else ({}, []) - scores = self.compute_metric(self.return_list) - # To match the training format : each group's data is a list - self.eval_samples = [[sample] for sample in self.return_list] - if return_samples: - return scores, self.eval_samples - return scores - - -Evaluator = ray.remote(RawEvaluator) -EvaluatorProxy = ActorProxy[RawEvaluator] diff --git a/xtuner/v1/ray/judger/__init__.py b/xtuner/v1/ray/judger/__init__.py deleted file mode 100644 index 6a194b0245..0000000000 --- a/xtuner/v1/ray/judger/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .controller import JudgerConfig, JudgerController diff --git a/xtuner/v1/ray/judger/controller.py b/xtuner/v1/ray/judger/controller.py deleted file mode 100644 index a024d7c8d7..0000000000 --- a/xtuner/v1/ray/judger/controller.py +++ /dev/null @@ -1,297 +0,0 @@ -import asyncio -import random -from pathlib import Path -from typing import List, Optional - -import ray -from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict, computed_field -from ray.util.placement_group import PlacementGroup, placement_group -from typing_extensions import Annotated - -from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLJudgerResponseItem - -from .native import NativeJudgerConfig - - -PG_READY_TIMEOUT = 30 - - -class JudgerConfig(BaseModel): - """Judger configuration for XTuner. - - Configuration for the judging system managing batch processing and custom judger - implementations for model evaluation and reward computation. - - Args: - enable_batch_reward (bool): Enable calculate reward within the data group of repeat_prompt_k. Defaults to False. - - reward_judger_configs (Dict[str, BaseModel]): Dictionary mapping judger names - to their configuration objects. We provided the example GSM8KJudgerConfig - for GSM8K mathematical reasoning tasks (see ``xtuner/v1/ray/judger/gsm8k.py``). Defaults to empty dict. - - **Examples:** - - Example configuration for single judger:: - - config = JudgerConfig( - enable_batch_reward=False, - reward_judger_configs={ - "gsm8k": GSM8KJudgerConfig(...) - } - ) - - Example configuration for multiple judgers:: - - config = JudgerConfig( - reward_judger_configs={ - "gsm8k": GSM8KJudgerConfig(...), - "math_qa": MathQAJudgerConfig(...), - "custom_eval": CustomJudgerConfig(...) - } - ) - - .. 
diff --git a/xtuner/v1/ray/judger/controller.py b/xtuner/v1/ray/judger/controller.py
deleted file mode 100644
index a024d7c8d7..0000000000
--- a/xtuner/v1/ray/judger/controller.py
+++ /dev/null
@@ -1,297 +0,0 @@
-import asyncio
-import random
-from pathlib import Path
-from typing import List, Optional
-
-import ray
-from cyclopts import Parameter
-from pydantic import BaseModel, ConfigDict, computed_field
-from ray.util.placement_group import PlacementGroup, placement_group
-from typing_extensions import Annotated
-
-from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLJudgerResponseItem
-
-from .native import NativeJudgerConfig
-
-
-PG_READY_TIMEOUT = 30
-
-
-class JudgerConfig(BaseModel):
-    """Judger configuration for XTuner.
-
-    Configuration for the judging system managing batch processing and custom judger
-    implementations for model evaluation and reward computation.
-
-    Args:
-        enable_batch_reward (bool): Enable reward calculation within the data group of repeat_prompt_k. Defaults to False.
-        reward_judger_configs (List[NativeJudgerConfig]): A list of judger configuration objects.
-            An example GSM8KJudgerConfig for GSM8K mathematical reasoning tasks is provided
-            (see ``xtuner/v1/ray/judger/gsm8k.py``). Defaults to an empty list.
-
-    **Examples:**
-
-    Example configuration for a single judger::
-
-        config = JudgerConfig(
-            enable_batch_reward=False,
-            reward_judger_configs=[GSM8KJudgerConfig(...)]
-        )
-
-    Example configuration for multiple judgers::
-
-        config = JudgerConfig(
-            reward_judger_configs=[
-                GSM8KJudgerConfig(...),
-                MathQAJudgerConfig(...),
-                CustomJudgerConfig(...)
-            ]
-        )
-
-    .. note::
-        You should ensure each dataset item specifies data_source as a dictionary mapping judger names to their weight ratios.
-
-        Example dataset item::
-
-            data_item = {
-                "data_source": {"gsm8k": 0.7, "math_qa": 0.3},
-                "response_str": "...",
-                "reward_model": {"ground_truth": "..."}
-            }
-    """
-
-    model_config = ConfigDict(extra="forbid")
-
-    enable_batch_reward: Annotated[
-        bool, Parameter(help="Whether to enable batch reward calculation for multiple samples at once.")
-    ] = False
-    enable_weighted_judgers: Annotated[
-        bool, Parameter(help="Whether to enable weighted reward calculation on multi judgers.")
-    ] = False
-    reward_judger_configs: Annotated[
-        List[NativeJudgerConfig],
-        Parameter(help="Configurations of the judgers used for reward computation."),
-    ] = []
-    judger_timeout: Annotated[float, Parameter(help="Timeout for each judger request in seconds.")] = 1200.0
-    worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir"
-
-    @computed_field
-    def total_bundles_needed(self) -> list[dict]:
-        judger_total_bundles = [
-            {"CPU": cfg.num_cpus_per_actor, "memory": cfg.num_cpus_per_actor * 1024**3}
-            for cfg in self.reward_judger_configs
-            for _ in range(cfg.num_ray_actors)
-        ]
-        return judger_total_bundles
-
-    @computed_field
-    def total_cpus_needed(self) -> int:
-        judger_total_cpus = sum(cfg.num_cpus_per_actor * cfg.num_ray_actors for cfg in self.reward_judger_configs)
-        return judger_total_cpus
-
-    @computed_field
-    def total_memory_needed(self) -> int:
-        judger_total_memory = sum(
-            cfg.num_cpus_per_actor * 1024**3 * cfg.num_ray_actors for cfg in self.reward_judger_configs
-        )
-        return judger_total_memory
-
-
-@ray.remote
-class JudgerController:
-    """Controller for judging model outputs and calculating rewards."""
-
-    def __init__(self, judger_config: JudgerConfig, pg: Optional[PlacementGroup] = None):
-        """Initialize the JudgerController.
-
-        Args:
-            judger_config (JudgerConfig): The configuration for the judger.
-            pg: The Ray placement group for resource allocation.
-                Defaults to None.
-        """
-        self.judger_config = judger_config
-        # note: the placement group is used to control the placement of Ray tasks.
-        # It will be implemented when a GPU judger is needed.
-        if pg is None:
-            assert len(self.judger_config.reward_judger_configs) == 1, (
-                "If no placement group is provided, there should be only one judger config."
-            )
-            default_placement_group = placement_group(bundles=[{"CPU": 1, "memory": 1024**3}], strategy="PACK")
-            ray.get([default_placement_group.ready()], timeout=PG_READY_TIMEOUT)
-            self.pg = default_placement_group
-        else:
-            assert len(pg.bundle_specs) >= sum(
-                config.num_ray_actors for config in self.judger_config.reward_judger_configs
-            ), "The provided placement group does not have enough bundles for all judger actors."
-            self.pg = pg
-        self.reward_judger: List[List[ray.actor.ActorHandle]] = []
-        self.reward_judger_names: List[str] = []
-        self.judger_instance_count = 0
-
-        for idx, config in enumerate(self.judger_config.reward_judger_configs):
-            # start_bundle_idx specifies which bundle of the placement group
-            # resource allocation starts from.
-            judger = config.build_actor(pg=self.pg, start_bundle_idx=self.judger_instance_count)
-            # One judger type may have several instances (i.e., several Ray actors);
-            # instances of the same judger type are stored as one row.
-            self.reward_judger.append(judger)
-            self.reward_judger_names.append(config.judger_name)
-            self.judger_instance_count += len(judger)
-        self._abort_on_pause_mask: list[bool] = [
-            cfg.abort_on_pause for cfg in self.judger_config.reward_judger_configs
-        ]
-        self.enable_weighted_judgers = (
-            False if len(self.reward_judger) == 1 else self.judger_config.enable_weighted_judgers
-        )
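Actor placement above walks a running bundle cursor through the placement group, giving each judger type a contiguous slice of bundles (cf. `judger_instance_count`). The bookkeeping in isolation:

def assign_bundles(actors_per_judger: list[int]) -> list[range]:
    # Each judger type gets a contiguous run of bundle indices,
    # starting where the previous type stopped.
    slices, cursor = [], 0
    for count in actors_per_judger:
        slices.append(range(cursor, cursor + count))
        cursor += count
    return slices


assert [list(r) for r in assign_bundles([2, 1, 3])] == [[0, 1], [2], [3, 4, 5]]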
-    async def _call_single_reward_judger(
-        self, judger: List[ray.actor.ActorHandle], group_data_item: List[RLDataFlowItem]
-    ):
-        """Call a single custom reward judger to calculate rewards.
-
-        Args:
-            judger (List[ray.actor.ActorHandle]): The actor instances of one
-                judger type.
-            group_data_item (List[RLDataFlowItem]): The samples to judge.
-
-        Returns:
-            A list of pending Ray task futures, one per dispatched judge call.
-        """
-        tasks = []
-        judger_input_data = (
-            [group_data_item] if self.judger_config.enable_batch_reward else [[item] for item in group_data_item]
-        )
-
-        if self.judger_config.enable_batch_reward:
-            # Randomly pick a judger instance for batch evaluation to balance the load.
-            tasks.append(random.choice(judger).judge.remote(group_data_item))
-        else:
-            tasks.extend([judger[idx % len(judger)].judge.remote(item) for idx, item in enumerate(judger_input_data)])
-        return tasks
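Two dispatch shapes are used above: one batched call on a randomly chosen instance, or one call per sample round-robined across instances. A standalone sketch with plain functions standing in for the Ray actors:

import random


def fan_out(instances: list, items: list, batch: bool) -> list:
    if batch:
        # One task for the whole group, on a randomly chosen instance.
        return [random.choice(instances)(items)]
    # One task per item, spread round-robin over the instances.
    return [instances[idx % len(instances)]([item]) for idx, item in enumerate(items)]


judge_a = lambda xs: ("a", xs)
judge_b = lambda xs: ("b", xs)
calls = fan_out([judge_a, judge_b], ["s0", "s1", "s2"], batch=False)
assert [c[0] for c in calls] == ["a", "b", "a"]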
- """ - active_judgers_len = len(active_judgers) - task_len_list = [0] - all_tasks = [] - assert active_judgers_len == len(active_reward_judger_names), ( - f"Expected {active_judgers_len} active judgers, but got {len(active_reward_judger_names)}" - ) - for judger in active_judgers: - tasks = await self._call_single_reward_judger(judger, group_data_item) - all_tasks.extend(tasks) - task_len_list.append(task_len_list[-1] + len(tasks)) - - all_results = await asyncio.gather(*all_tasks) - - assert len(all_results) == len(group_data_item) * len(active_judgers), ( - f"Expected {len(group_data_item) * len(active_judgers)} results, but got {len(all_results)}" - ) - - active_judger_results = {} - for i in range(active_judgers_len): - active_judger_results[active_reward_judger_names[i]] = all_results[task_len_list[i] : task_len_list[i + 1]] - - # 为每个样本创建一个 RLJudgerResponseItem,不同judger的结果放在同一个item中 - uid_list = [item.uid.observation_id for item in group_data_item] - judger_response_items_dict = {uid: RLJudgerResponseItem(uid=uid) for uid in uid_list} - for judger_name, results in active_judger_results.items(): - for result in results: - for data in result: - return_uid = data.uid - judger_response_items_dict[return_uid].reward.update(data.reward) - judger_response_items_dict[return_uid].reward.update({judger_name: data.reward}) - judger_response_items_dict[return_uid].extra_info.update(data.extra_info) - return list(judger_response_items_dict.values()) - - async def run( - self, group_data_item: RLDataFlowItem | List[RLDataFlowItem] - ) -> RLJudgerResponseItem | List[RLJudgerResponseItem]: - """Run the judging process for a group of data items. - - Args: - group_data_item (List[RLTextDataItem]): A list of RLTextDataItem, - each containing the response and other relevant information. - - Returns: - List[float]: A list of final calculated rewards for each data item. - """ - input_type_is_list = True - if not isinstance(group_data_item, list): - input_type_is_list = False - group_data_item = [group_data_item] - - if self.enable_weighted_judgers: - data_source = group_data_item[0].data.data_source - # 如果要使用多个judger并且进行加权打分,则必须在数据集中指定data_source的分数 - assert data_source, "No data source found for the given datasets when multiple judgers are provided." - active_reward_judger = [] - active_reward_judger_names = [] - for idx, judger in enumerate(self.reward_judger): - judger_name = self.reward_judger_names[idx] - if judger_name in data_source: - active_reward_judger.append(judger) - active_reward_judger_names.append(judger_name) - assert active_reward_judger, ( - f"No active reward judger in {self.reward_judger_names} found for the given data source {data_source}." 
-
-    async def run(
-        self, group_data_item: RLDataFlowItem | List[RLDataFlowItem]
-    ) -> RLJudgerResponseItem | List[RLJudgerResponseItem]:
-        """Run the judging process for a group of data items.
-
-        Args:
-            group_data_item (RLDataFlowItem | List[RLDataFlowItem]): One data
-                item or a list of data items, each containing the response and
-                other relevant information.
-
-        Returns:
-            RLJudgerResponseItem | List[RLJudgerResponseItem]: The judging
-                result(s), matching the input form (single item or list).
-        """
-        input_type_is_list = True
-        if not isinstance(group_data_item, list):
-            input_type_is_list = False
-            group_data_item = [group_data_item]
-
-        if self.enable_weighted_judgers:
-            data_source = group_data_item[0].data.data_source
-            # To combine several judgers with weighted scoring, the dataset must provide per-judger weights in data_source
-            assert data_source, "No data source found for the given datasets when multiple judgers are provided."
-            active_reward_judger = []
-            active_reward_judger_names = []
-            for idx, judger in enumerate(self.reward_judger):
-                judger_name = self.reward_judger_names[idx]
-                if judger_name in data_source:
-                    active_reward_judger.append(judger)
-                    active_reward_judger_names.append(judger_name)
-            assert active_reward_judger, (
-                f"No active reward judger in {self.reward_judger_names} found for the given data source {data_source}."
-            )
-            judger_response_item = await self._call_custom_reward_judger(
-                active_reward_judger, active_reward_judger_names, group_data_item
-            )
-
-            # NOTE: Only the weighted sum of each judger's "score" field is computed
-            for item in judger_response_item:
-                final_reward = 0
-                for name, weight in data_source.items():
-                    if name in item.reward:
-                        final_reward += item.reward[name]["score"] * weight
-                item.reward["score"] = final_reward
-        else:
-            judger_response_item = await self._call_custom_reward_judger(
-                self.reward_judger, self.reward_judger_names, group_data_item
-            )
-        if input_type_is_list is False:
-            return judger_response_item[0]
-        return judger_response_item
-
-    async def abort(self):
-        """Abort running judger requests whose config has
-        ``abort_on_pause=True``."""
-        tasks = []
-        for idx, judger_group in enumerate(self.reward_judger):
-            if not self._abort_on_pause_mask[idx]:
-                continue
-            for actor in judger_group:
-                tasks.append(actor.abort.remote())
-        if tasks:
-            await asyncio.gather(*tasks, return_exceptions=True)
-
-    async def restart_judger(self):
-        """Clear abort state on judger actors whose config has
-        ``abort_on_pause=True``."""
-        tasks = []
-        for idx, judger_group in enumerate(self.reward_judger):
-            if not self._abort_on_pause_mask[idx]:
-                continue
-            for actor in judger_group:
-                tasks.append(actor.restart.remote())
-        if tasks:
-            await asyncio.gather(*tasks, return_exceptions=True)
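A minimal sketch of the weighted-sum scoring performed in `run` above, assuming `data_source` maps judger names to weights and each judger reports a `"score"` field (the names are illustrative):

    # Weighted combination of per-judger scores, as in the `run` branch above.
    # `per_judger_rewards` mirrors item.reward: {judger_name: {"score": float}, ...}.
    def weighted_score(per_judger_rewards: dict[str, dict], data_source: dict[str, float]) -> float:
        return sum(
            per_judger_rewards[name]["score"] * weight
            for name, weight in data_source.items()
            if name in per_judger_rewards
        )


    assert weighted_score({"math": {"score": 1.0}, "format": {"score": 0.5}}, {"math": 0.5, "format": 0.5}) == 0.75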
- """ - - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") - judger_name: str - num_ray_actors: int = 1 - num_cpus_per_actor: int = 1 - cpu_memory_per_actor: int = 1024**3 - reward_func: Optional[Callable] = Field(default=None, exclude=True) - remote_url: Optional[str] = None - preprocess_func: Optional[Callable] = Field(default=None, exclude=True) - postprocess_func: Optional[Callable] = Field(default=None, exclude=True) - request_timeout: float = 30.0 - extra_info: dict = Field(default={}, exclude=True) - abort_on_pause: bool = False - - def build_actor(self, pg: PlacementGroup, start_bundle_idx: int) -> List[ray.actor.ActorClass]: - """Create and launch Ray actor instances for the GSM8K judger. - - This method instantiates multiple NativeJudger Ray actors according to `num_ray_actors`, - assigning each to a specific bundle in the provided placement group for resource isolation. - Each actor is initialized with the judger's configuration and reward function. - - Args: - pg: The Ray PlacementGroup used to allocate resources for the actors. - start_bundle_idx: The starting bundle index in the placement group for actor placement. - - Returns: - List[ActorClass]: A list of Ray actor handles representing the launched judger workers. - """ - workers_list = [] - for idx in range(self.num_ray_actors): - bundle_idx = start_bundle_idx + idx - pg_options = {"num_cpus": self.num_cpus_per_actor, "memory": self.cpu_memory_per_actor} - assert pg.bundle_specs[bundle_idx].get("CPU", 1) >= self.num_cpus_per_actor, ( - f"Placement group bundle {bundle_idx} does not have enough CPU resources." - ) - assert pg.bundle_specs[bundle_idx].get("memory", 0) >= self.cpu_memory_per_actor, ( - f"Placement group bundle {bundle_idx} does not have enough memory resources." - ) - worker = ( - ray.remote(NativeJudger) - .options( - placement_group=pg, - placement_group_bundle_index=bundle_idx, - **pg_options, - ) - .remote( - judger_name=self.judger_name, - reward_func=self.reward_func, - remote_url=self.remote_url, - preprocess_func=self.preprocess_func, - postprocess_func=self.postprocess_func, - request_timeout=self.request_timeout, - extra_info=self.extra_info, - ) - ) - workers_list.append(worker) - return workers_list - - -class NativeJudger: - """Base class for judgers, providing a standard interface for executing a - judging process, which can be either a local function or a remote service. - - The judger orchestrates a three-step pipeline: - 1. Pre-process the input data. - 2. Execute the core logic (local function or remote HTTP call). - 3. Post-process the result. - """ - - def __init__( - self, - judger_name: str = "native_judger", - reward_func: Optional[Callable] = None, - remote_url: Optional[str] = None, - preprocess_func: Optional[Callable] = None, - postprocess_func: Optional[Callable] = None, - request_timeout: float = 30.0, - extra_info: dict = {}, - ): - """Initialize the NativeJudger. - - Args: - reward_func (Optional[Callable]): A local function to compute the - reward. Exactly one of `reward_func` or `remote_url` must be - provided. Defaults to None. - remote_url (Optional[str]): The URL of a remote service for - judging. Exactly one of `reward_func` or `remote_url` must be - provided. Defaults to None. - preprocess_func (Optional[Callable]): A function to preprocess the - input data before judger execution. Defaults to None. - postprocess_func (Optional[Callable]): A function to postprocess - the judger result. Defaults to None. 
-
-
-class NativeJudger:
-    """Base class for judgers, providing a standard interface for executing a
-    judging process, which can be either a local function or a remote service.
-
-    The judger orchestrates a three-step pipeline:
-    1. Pre-process the input data.
-    2. Execute the core logic (local function or remote HTTP call).
-    3. Post-process the result.
-    """
-
-    def __init__(
-        self,
-        judger_name: str = "native_judger",
-        reward_func: Optional[Callable] = None,
-        remote_url: Optional[str] = None,
-        preprocess_func: Optional[Callable] = None,
-        postprocess_func: Optional[Callable] = None,
-        request_timeout: float = 30.0,
-        extra_info: dict = {},
-    ):
-        """Initialize the NativeJudger.
-
-        Args:
-            reward_func (Optional[Callable]): A local function to compute the
-                reward. Exactly one of `reward_func` or `remote_url` must be
-                provided. Defaults to None.
-            remote_url (Optional[str]): The URL of a remote service for
-                judging. Exactly one of `reward_func` or `remote_url` must be
-                provided. Defaults to None.
-            preprocess_func (Optional[Callable]): A function to preprocess the
-                input data before judger execution. Defaults to None.
-            postprocess_func (Optional[Callable]): A function to postprocess
-                the judger result. Defaults to None.
-            request_timeout (float): Timeout for remote requests in seconds.
-                Defaults to 30.0.
-            extra_info (dict): Extra information to be passed to the reward
-                function. Defaults to {}.
-
-        Raises:
-            ValueError: If both or neither of `reward_func` and `remote_url`
-                are provided.
-        """
-        if (reward_func is None and remote_url is None) or (reward_func is not None and remote_url is not None):
-            raise ValueError("Exactly one of 'reward_func' or 'remote_url' must be provided.")
-        self.judger_name = judger_name
-        self.extra_info = extra_info
-        self.reward_func = reward_func
-        self.remote_url = remote_url
-
-        self.preprocess_func = preprocess_func or self._default_preprocess
-        self.postprocess_func = postprocess_func or self._default_postprocess
-
-        self.http_client = None
-        self.execute_func = None
-
-        if self.reward_func:
-            self.execute_func = self._local_executor
-        elif self.remote_url:
-            self.http_client = httpx.AsyncClient(timeout=request_timeout)
-            self.execute_func = self._remote_executor
-
-    def _default_preprocess(self, data_item: List[RLDataFlowItem], extra_info: dict) -> Any:
-        """Default preprocessing function.
-
-        Args:
-            data_item (List[RLDataFlowItem]): The data item(s) to preprocess.
-            extra_info (dict): Extra information forwarded to the reward function.
-
-        Returns:
-            dict: A dictionary containing the response, label, and extra info.
-        """
-        assert len(data_item) == 1, "Default preprocess only supports a single data item."
-        # TODO: Support batch reward calculation via API server
-        response = data_item[0].env.rollout.response
-        assert data_item[0].data.reward_model is not None
-        label = data_item[0].data.reward_model["ground_truth"]
-        return {
-            "response": response,
-            "label": label,
-            "extra_info": extra_info,
-        }
-
-    def _default_postprocess(self, result: Any) -> List[RLJudgerResponseItem]:
-        """Default postprocessing function: wrap raw results into
-        RLJudgerResponseItem objects.
-
-        Args:
-            result (Any): The result from the execution step.
-
-        Returns:
-            List[RLJudgerResponseItem]: One response item per raw result.
-        """
-        if not isinstance(result, list):
-            result = [result]
-        # TODO: support returning results from multiple judgers
-        judger_response_item = [RLJudgerResponseItem(reward=result[i]) for i in range(len(result))]
-        return judger_response_item
-
-    async def _local_executor(self, data_item: List[RLDataFlowItem]) -> List[RLJudgerResponseItem]:
-        """Executes the reward function locally.
-
-        Args:
-            data_item (List[RLDataFlowItem]): The data item(s) to judge.
-
-        Returns:
-            List[RLJudgerResponseItem]: The postprocessed results of the reward function.
-        """
-        assert self.reward_func is not None, "reward_func cannot be None for local execution."
-        # Record the uid of each judger request so the results can be merged later
-        uid_list = [item.uid.observation_id for item in data_item]
-        kwargs = self.preprocess_func(data_item, self.extra_info)
-        if inspect.iscoroutinefunction(self.reward_func):
-            json_result = await self.reward_func(**kwargs)
-        else:
-            json_result = self.reward_func(**kwargs)
-
-        # transform json to RLJudgerResponseItem
-        result = self.postprocess_func(json_result)
-        for i in range(len(result)):
-            result[i].uid = uid_list[i]
-        return result
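The three-step pipeline that `_local_executor` implements can be summarized in isolation. This sketch (hypothetical helper, simplified types) mirrors the preprocess -> execute -> postprocess flow, including the sync/async dispatch:

    # Illustrative miniature of the judging pipeline; not the class itself.
    import asyncio
    import inspect


    async def judge_once(reward_func, preprocess, postprocess, data_item, extra_info):
        kwargs = preprocess(data_item, extra_info)
        raw = await reward_func(**kwargs) if inspect.iscoroutinefunction(reward_func) else reward_func(**kwargs)
        return postprocess(raw)


    result = asyncio.run(
        judge_once(
            lambda response, label, extra_info: {"score": float(response == label)},
            lambda items, info: {"response": items[0], "label": "42", "extra_info": info},
            lambda raw: [raw],
            ["42"],
            {},
        )
    )
    assert result == [{"score": 1.0}]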
- """ - assert self.remote_url is not None and self.http_client is not None, ( - "remote_url cannot be None for remote execution." - ) - payload = self.preprocess_func(data_item, self.extra_info) - try: - response = await self.http_client.post(self.remote_url, json=payload) - response.raise_for_status() - json_result = response.json() - # 重要,必须加 - json_result["uid"] = data_item[0].uid.observation_id - # transform json to RLJudgerResponseItem - return self.postprocess_func(json_result) - except httpx.RequestError as exc: - get_logger().error(f"An error occurred while requesting {exc.request.url}: {exc}") - return [] - - async def judge(self, data_item: List[RLDataFlowItem]) -> List[RLJudgerResponseItem]: - """The main public method to run the judging pipeline. - - Args: - responses (str | List[str]): The model's response(s) to be judged. - labels (str | List[str]): The ground-truth label(s). - - Returns: - Any: The final result after the full - preprocess-execute-postprocess pipeline. - - Raises: - RuntimeError: If the judger is not properly initialized. - """ - if self.execute_func is None: - raise RuntimeError("Judger is not properly initialized.") - return await self.execute_func(data_item) - - def get_judger_name(self) -> str: - """Get the name of the judger. - - Returns: - str: The name of the judger. - """ - return self.judger_name - - async def abort(self): - """No-op abort for judgers that don't support cancellation.""" - pass - - def restart(self): - """No-op restart for judgers that don't support cancellation.""" - pass diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py deleted file mode 100644 index 0cc1923f94..0000000000 --- a/xtuner/v1/ray/rollout/controller.py +++ /dev/null @@ -1,609 +0,0 @@ -import asyncio -import os -import socket -import threading -import time -from collections import OrderedDict -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union -from uuid import uuid4 - -import ray -import uvicorn -from fastapi import FastAPI -from ray.util.placement_group import PlacementGroup - -from transformers import AutoTokenizer -from xtuner.v1.data_proto.rl_data import ( - RLRolloutRequestItem, - RLRolloutResponseItem, - RolloutExtraParams, - SampleParams, -) -from xtuner.v1.ray.base import AutoAcceleratorWorkers -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.utils import get_logger - -from .worker import RolloutWorker - - -ROLLOUT_RAY_GET_TIMEOUT = os.getenv("XTUNER_ROLLOUT_RAY_GET_TIMEOUT", 5 * 3600) # default 5 hours - - -@dataclass -class WorkerInfo: - """A data class to hold all state information for a single worker.""" - - actor: RolloutWorker - rank: int = -1 - is_active: bool = True - failure_count: int = 0 - running_count: int = 0 - success_count: int = 0 - - -class SessionRouter: - def __init__( - self, - worker_status: Dict[Any, bool], # worker: worker_status - max_sessions: int = 10000, - max_idle_seconds: Optional[float] = 3600.0, - ): - assert len(worker_status) > 0 - self._workers = list(worker_status.items()) - self._max_sessions = max_sessions - self._max_idle = max_idle_seconds - - # OrderedDict: key=session_id -> value=(worker, last_used_ts) - self._map: OrderedDict[int, tuple[Any, float]] = OrderedDict() - self._lock = asyncio.Lock() - self.logger = get_logger() - - def _now(self) -> float: - return time.time() - - def _evict_expired(self): - if self._max_idle is None: - return - now = self._now() - - to_delete = [] - for sid, (_, last_used) in self._map.items(): - 
diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py
deleted file mode 100644
index 0cc1923f94..0000000000
--- a/xtuner/v1/ray/rollout/controller.py
+++ /dev/null
@@ -1,609 +0,0 @@
-import asyncio
-import os
-import socket
-import threading
-import time
-from collections import OrderedDict
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Union
-from uuid import uuid4
-
-import ray
-import uvicorn
-from fastapi import FastAPI
-from ray.util.placement_group import PlacementGroup
-
-from transformers import AutoTokenizer
-from xtuner.v1.data_proto.rl_data import (
-    RLRolloutRequestItem,
-    RLRolloutResponseItem,
-    RolloutExtraParams,
-    SampleParams,
-)
-from xtuner.v1.ray.base import AutoAcceleratorWorkers
-from xtuner.v1.ray.config.worker import RolloutConfig
-from xtuner.v1.utils import get_logger
-
-from .worker import RolloutWorker
-
-
-ROLLOUT_RAY_GET_TIMEOUT = int(os.getenv("XTUNER_ROLLOUT_RAY_GET_TIMEOUT", 5 * 3600))  # default 5 hours
-
-
-@dataclass
-class WorkerInfo:
-    """A data class to hold all state information for a single worker."""
-
-    actor: RolloutWorker
-    rank: int = -1
-    is_active: bool = True
-    failure_count: int = 0
-    running_count: int = 0
-    success_count: int = 0
-
-
-class SessionRouter:
-    def __init__(
-        self,
-        worker_status: Dict[Any, bool],  # worker -> health status
-        max_sessions: int = 10000,
-        max_idle_seconds: Optional[float] = 3600.0,
-    ):
-        assert len(worker_status) > 0
-        self._workers = list(worker_status.items())
-        self._max_sessions = max_sessions
-        self._max_idle = max_idle_seconds
-
-        # OrderedDict: key=session_id -> value=(worker, last_used_ts)
-        self._map: OrderedDict[int, tuple[Any, float]] = OrderedDict()
-        self._lock = asyncio.Lock()
-        self.logger = get_logger()
-
-    def _now(self) -> float:
-        return time.time()
-
-    def _evict_expired(self):
-        if self._max_idle is None:
-            return
-        now = self._now()
-
-        to_delete = []
-        for sid, (_, last_used) in self._map.items():
-            if now - last_used > self._max_idle:
-                to_delete.append(sid)
-            else:
-                break
-        for sid in to_delete:
-            self._map.pop(sid, None)
-
-    def _evict_lru_to_capacity(self):
-        while len(self._map) > self._max_sessions:
-            self._map.popitem(last=False)
-
-    def update_active_workers(self, worker_status: Dict[Any, bool]):
-        self._workers = list(worker_status.items())
-        self.logger.debug(f"SessionRouter update active workers: {self._workers}")
-
-    def _get_healthy_workers(self) -> List[tuple[Any, bool]]:
-        return [worker for worker in self._workers if worker[1]]
-
-    def _select_worker_for_session(self, session_id: int) -> tuple[Any, bool]:
-        healthy_workers = self._get_healthy_workers()
-        if not healthy_workers:
-            raise RuntimeError("No healthy rollout workers available for SessionRouter.")
-        worker_idx = session_id % len(healthy_workers)
-        return healthy_workers[worker_idx]
-
-    async def get_worker(self, session_id: int) -> Any:
-        async with self._lock:
-            self._evict_expired()
-
-            if session_id in self._map:
-                worker, _ = self._map.pop(session_id)
-                self._map[session_id] = (worker, self._now())
-                if worker[1]:  # worker is healthy
-                    return worker[0]
-
-            worker = self._select_worker_for_session(session_id)
-            self._map[session_id] = (worker, self._now())
-
-            self._evict_lru_to_capacity()
-            return worker[0]
-
-
-class RolloutController:
-    """Controller for managing and coordinating multiple RolloutWorker
-    actors."""
-
-    def __init__(
-        self,
-        infer_config: RolloutConfig,
-        placement_group: PlacementGroup,
-    ):
-        """Initialize the RolloutController.
-
-        Args:
-            infer_config (RolloutConfig): The configuration for the rollout.
-            placement_group (PlacementGroup): The placement group for the
-                RolloutWorker actors.
-        """
-        self.config = infer_config
-        self.num_gpus_per_engine = (
-            self.config.expert_parallel_size
-            if self.config.expert_parallel_size > 1
-            else self.config.tensor_parallel_size
-        )
-        self.logger = get_logger(log_dir=infer_config.worker_log_dir, tag="RolloutController")
-        self.num_workers = 0
-        self.workers_info: Dict[str, WorkerInfo] = {}  # url -> WorkerInfo
-        self.active_rollout_workers: List[RolloutWorker] = []
-        self.tokenizer = AutoTokenizer.from_pretrained(infer_config.tokenizer_path, trust_remote_code=True)
-        self.workers, self.rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group(
-            self._get_worker_cls(), infer_config, placement_group
-        )
-        self.engine_rank_mesh_array, self.worker_server_urls_map = self.init_workers()
-        self.start_api_server()
-        # todo(@duanyanhui): add router to replace native round robin
-        self.router = SessionRouter(self._get_worker_status_for_router())
-        self.sample_params = SampleParams().dict()
-        self.extra_params = dict(
-            RolloutExtraParams(
-                stream=False,
-                include_stop_str_in_output=True,
-                no_stop_trim=True,
-                return_logprob=True,
-                return_token_ids=True,
-                skip_special_tokens=False,
-                spaces_between_special_tokens=False,
-                top_logprobs=1,
-            )
-        )
-        self.print_params_flag = True
-        # The timeout for the environment to wait for the rollout controller's response.
-        # This should be longer than the controller's internal timeout (`rollout_timeout`)
-        # to account for potential queuing delays and other overheads.
-        self.timeout_multiplier = 2.0
-        self.cancel_response_timeout = 5.0
-
-    def _get_worker_status_for_router(self) -> Dict[RolloutWorker, bool]:
-        """Helper to generate the status dict required by the SessionRouter."""
-        return {info.actor: info.is_active for info in self.workers_info.values()}
-
-    def _get_worker_cls(self):
-        if os.environ.get("XTUNER_USE_LMDEPLOY") == "1":
-            from .lmdeploy import LMDeployWorker
-
-            return ray.remote(LMDeployWorker)
-        elif os.environ.get("XTUNER_USE_VLLM") == "1":
-            from .vllm import vLLMWorker
-
-            return ray.remote(vLLMWorker)
-        elif os.environ.get("XTUNER_USE_SGLANG") == "1":
-            from .sglang import SGLangWorker
-
-            return ray.remote(SGLangWorker)
-        else:
-            raise NotImplementedError(
-                "Rollout backend is not supported. "
-                "Please set the XTUNER_USE_LMDEPLOY, XTUNER_USE_VLLM,"
-                " or XTUNER_USE_SGLANG environment variable."
-            )
-
-    def _get_active_worker_to_url_map(self):
-        """Get a mapping of active workers to their server URLs."""
-        return {info.actor: url for url, info in self.workers_info.items()}
-
-    def _is_port_in_use(self, host: str, port: int) -> bool:
-        """Check if a port is in use on the given host."""
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            try:
-                s.bind((host, port))
-                return False
-            except OSError:
-                return True
-
-    def _update_active_workers_and_urls_map(self, active_rollout_workers, worker_server_urls_map):
-        """Update the list of active rollout workers and their server URLs.
-
-        When the inference engine is launched across nodes (rollout_cross_node_comm=True), only the worker with
-        tp_rank=0 in each engine is responsible for receiving input data. Other tp_ranks do not accept input.
-        Therefore, this function updates active_rollout_workers and worker_server_urls_map to keep only the tp_rank=0
-        workers and their corresponding URLs.
-        """
-        if self.config.rollout_cross_node_comm or self.num_gpus_per_engine < self.config.gpus_per_node:
-            return active_rollout_workers, worker_server_urls_map
-        else:
-            active_worker_interval = self.num_gpus_per_engine // self.config.gpus_per_node
-            active_rank = list(worker_server_urls_map.keys())[::active_worker_interval]
-            active_worker_server_urls = list(worker_server_urls_map.values())[::active_worker_interval]
-            return active_rollout_workers[::active_worker_interval], dict(zip(active_rank, active_worker_server_urls))
-
-    def get_rollout_info(self):
-        """Get information about the current rollout setup.
-
-        Returns:
-            dict: A dictionary containing the engine mesh list, server URL
-                dictionary, and the rollout configuration.
-        """
-        worker_server_urls_status = {url: info.is_active for url, info in self.workers_info.items()}
-        return dict(
-            engine_rank_mesh_array=self.engine_rank_mesh_array,
-            server_url_dict=self.worker_server_urls_map,
-            rollout_config=self.config,
-            worker_server_urls_status=worker_server_urls_status,
-        )
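A worked example of the head-worker stride computed in `_update_active_workers_and_urls_map` above (numbers are illustrative):

    # With tensor_parallel_size = 16, gpus_per_node = 8 and
    # rollout_cross_node_comm = False, each engine spans two nodes, so every
    # node runs a server but only the tp_rank=0 server of each engine accepts traffic.
    num_gpus_per_engine = 16
    gpus_per_node = 8
    active_worker_interval = num_gpus_per_engine // gpus_per_node  # -> 2
    workers = ["w0", "w1", "w2", "w3"]  # one server per node, two engines total
    assert workers[::active_worker_interval] == ["w0", "w2"]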
-
-    def init_workers(self):
-        """Initializes and configures the pool of RolloutWorker actors.
-
-        This method configures distributed inference engines by grouping
-        workers, where each group forms a tensor-parallel inference engine. It
-        determines the `active_workers` to act as the head of each engine,
-        constructs the `engine_rank_mesh_array` to define engine topology, acquires
-        necessary distributed communication ports, and finally launches servers
-        on the `active_workers` to get their addresses.
-
-        Returns:
-            Tuple[List, Dict]: A tuple where the first element is
-                `engine_rank_mesh_array`, a list of lists containing the ranks of workers
-                in each engine, and the second element is `worker_server_urls_map`,
-                a dictionary mapping the ID of each active worker to its
-                corresponding server URL.
-        """
-        active_servers_count, nodes_per_engine = self._get_active_servers_count(self.config, len(self.workers))
-        interval = len(self.workers) // active_servers_count
-        active_rollout_workers = self.workers[::interval]
-        self.num_workers = len(active_rollout_workers)
-        server_urls_per_engine = self.config.server_urls_per_engine
-
-        set_bundle_idxs_objectref = []
-        engine_rank_mesh_array = []
-        active_worker_idx = 0
-        for active_worker in active_rollout_workers:
-            head_rank, _ = self.rank_bundle_idx_list[active_worker_idx]
-            engine_workers_meta = self.rank_bundle_idx_list[head_rank : head_rank + interval]
-            engine_bundle_idxs = [meta[1] for meta in engine_workers_meta]  # meta: (rank, bundle_idx)
-            set_bundle_idxs_objectref.append(active_worker.set_engine_bundle_idxs.remote(engine_bundle_idxs))  # type: ignore[attr-defined]
-            engine_rank_mesh_array.append([meta[0] for meta in engine_workers_meta])
-            active_worker_idx += interval
-        ray.get(set_bundle_idxs_objectref)
-        # set engine mesh list for each worker
-        ray.get(
-            [worker.set_engine_rank_mesh_array.remote(engine_rank_mesh_array) for worker in active_rollout_workers]
-        )  # type: ignore[attr-defined]
-        # init dist_init_addr for each worker according to parallel settings
-        init_dist_init_addrs = ray.get([worker.init_dist_port.remote() for worker in active_rollout_workers])  # type: ignore[attr-defined]
-        dist_init_addrs = self._update_dist_init_addr(
-            nodes_per_engine, server_urls_per_engine, init_dist_init_addrs, self.num_gpus_per_engine
-        )
-        # launch rollout servers
-        worker_server_urls_map = dict(  # rank -> url
-            ray.get([worker.init.remote(dist_init_addrs[i]) for i, worker in enumerate(active_rollout_workers)])
-        )
-        active_rollout_workers, worker_server_urls_map = self._update_active_workers_and_urls_map(
-            active_rollout_workers, worker_server_urls_map
-        )
-        self.workers_info = {}
-        for i in range(len(active_rollout_workers)):
-            rank = list(worker_server_urls_map.keys())[i]
-            url = worker_server_urls_map[rank]
-            self.workers_info[url] = WorkerInfo(rank=rank, actor=active_rollout_workers[i])
-        self.logger.info(f"Rollout worker server URLs: {list(self.workers_info.keys())}")
-        return engine_rank_mesh_array, worker_server_urls_map
-
-    def _deactivate_worker(self, url: str):
-        """A helper function to deactivate a worker, update all related states,
-        and shut it down."""
-        worker_info = self.workers_info.get(url)
-        if not worker_info or not worker_info.is_active:
-            return
-
-        self.logger.warning(f"Deactivating rollout worker {worker_info.actor} with URL {url} due to failures.")
-        worker_info.is_active = False
-        self.router.update_active_workers(self._get_worker_status_for_router())
-
-        ray.get(worker_info.actor.offload.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT)  # type: ignore[attr-defined]
-        ray.get(worker_info.actor.shutdown.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT)  # type: ignore[attr-defined]
-
-    def update_active_workers(self):
-        """Check the health of all active rollout workers and deactivate any
-        worker that fails its health check."""
-        active_workers = [(url, info) for url, info in self.workers_info.items() if info.is_active]
-        if not active_workers:
-            return
-
-        urls, infos = zip(*active_workers)
-        actors = [info.actor for info in infos]
-
-        health_statuses = ray.get([actor.check_health.remote() for actor in actors], timeout=ROLLOUT_RAY_GET_TIMEOUT)
-
-        for url, is_healthy in zip(urls, health_statuses):
-            if not is_healthy:
-                self.logger.warning(f"Rollout worker {url} is unhealthy.")
-                self._deactivate_worker(url)
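For intuition, a worked example of the `engine_rank_mesh_array` built by `init_workers` above: 8 workers grouped into engines of 4 give two engines headed by ranks 0 and 4.

    # Simplified: ranks equal bundle indices here; the real list comes from the placement group.
    rank_bundle_idx_list = [(rank, rank) for rank in range(8)]  # (rank, bundle_idx)
    interval = 4
    engine_rank_mesh_array = [
        [meta[0] for meta in rank_bundle_idx_list[head : head + interval]]
        for head in range(0, 8, interval)
    ]
    assert engine_rank_mesh_array == [[0, 1, 2, 3], [4, 5, 6, 7]]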
- """ - active_workers = [(url, info) for url, info in self.workers_info.items() if info.is_active] - if not active_workers: - return - - urls, infos = zip(*active_workers) - actors = [info.actor for info in infos] - - health_statuses = ray.get([actor.check_health.remote() for actor in actors], timeout=ROLLOUT_RAY_GET_TIMEOUT) - - for url, is_healthy in zip(urls, health_statuses): - if not is_healthy: - self.logger.warning(f"Rollout worker {url} is unhealthy.") - self._deactivate_worker(url) - - def deactivate_worker_by_url(self, url: str): - """Deactivates a worker identified by its URL after it exceeds the - maximum retry count.""" - worker_info = self.workers_info.get(url) - if not worker_info or not worker_info.is_active: - return - - worker_info.failure_count += 1 - if ( - self.config.max_retry_per_worker is not None - and worker_info.failure_count < self.config.max_retry_per_worker - ): - self.logger.warning( - f"Rollout worker {url} failed {worker_info.failure_count} times, but not deactivated yet." - ) - return - - self._deactivate_worker(url) - - async def rollout( - self, - prompt: Union[str, List[Dict[str, Any]]] | None = None, - input_ids: Optional[List[int]] | None = None, - tools: List = [], - tool_choice: str = "auto", - sample_params: Optional[SampleParams] = None, - extra_params: dict = dict(), - format: str = "openai", - session_id: Optional[int] = None, - extra_info: dict = dict(), - ) -> RLRolloutResponseItem: - # 这个函数接受标准的openapi chat create接口,所以不需要再额外定义输入的形式 - """Perform a rollout using one of the workers in a round-robin fashion. - - Args: - prompt (List[str]): The prompt to send to the model. - tools (List, optional): A list of tools the model can call. - Defaults to []. - tool_choice (str, optional): The tool choice strategy. - Defaults to "auto". - sample_params (Optional[SampleParams], optional): The sampling - parameters for generation. If None, the default `sample_params` - of the controller will be used. Defaults to None. - extra_params (dict, optional): Extra parameters for the worker. - Defaults to dict(). - format (str, optional): The format of the response. - Defaults to "openai". - - Returns: - The response from the rollout worker. - """ - session_id = session_id if session_id else uuid4().int - worker = await self.router.get_worker(session_id) - # update sample params and extra params (use copy to avoid modifying global state) - current_sample_params = {**self.sample_params, **(sample_params.dict() if sample_params else {})} - current_extra_params = {**self.extra_params, **(extra_params if extra_params else {})} - if self.print_params_flag: - self.logger.info( - f"Rollout with sample params: {current_sample_params}, extra params: {current_extra_params}" - ) - self.print_params_flag = False - assert prompt is not None or input_ids is not None, "Either prompt or input_ids must be provided." 
-
-    def get_rollout_stats(self) -> str:
-        """Get statistics about the rollout workers.
-
-        Returns:
-            str: A formatted string containing statistics about each rollout worker.
-        """
-        log_parts = ["Rollout Worker Stats:"]
-        for url, info in self.workers_info.items():
-            log_parts.append(
-                f"  - URL: {url} | Rank: {info.rank} | Active: {info.is_active} | "
-                f"Running: {info.running_count} | Success: {info.success_count} | "
-                f"Failures: {info.failure_count}"
-            )
-        log_msg = "\n".join(log_parts)
-        return log_msg
-
-    def start_api_server(self, host: str = "0.0.0.0", port: int = 8000):
-        """Starts the API server to expose the rollout functionality."""
-        app = FastAPI()
-        port = self.config.api_port if self.config.api_port else port
-
-        original_port = port
-        while self._is_port_in_use(host, port):
-            self.logger.warning(f"Port {port} is in use, trying port {port + 1}")
-            port += 1
-
-        if original_port != port:
-            self.logger.info(f"API server will use port {port} instead of the originally configured {original_port}.")
-
-        @app.post("/v1/chat/completions")
-        async def chat_completions(request: RLRolloutRequestItem) -> RLRolloutResponseItem:
-            response = await self.rollout(
-                prompt=request.messages,
-                tools=request.tools,
-                tool_choice=request.tool_choice,
-                sample_params=request.sample_params,
-                extra_params=request.extra_params,
-            )
-            return response
-
-        config = uvicorn.Config(app, host=host, port=port)
-        server = uvicorn.Server(config)
-        server_thread = threading.Thread(target=server.run, daemon=True)
-        server_thread.start()
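Once `start_api_server` is listening, the endpoint accepts the fields of RLRolloutRequestItem used in the handler above. A minimal client call might look like this (host and port are assumptions; the default port is 8000 unless `api_port` is set or the port was bumped because it was busy):

    import httpx

    payload = {
        "messages": [{"role": "user", "content": "Say hi."}],
        "tools": [],
        "tool_choice": "auto",
    }
    resp = httpx.post("http://127.0.0.1:8000/v1/chat/completions", json=payload, timeout=60)
    resp.raise_for_status()
    print(resp.json())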
-
-    # internal functions
-    def _update_dist_init_addr(self, nodes_per_engine, server_urls_per_engine, dist_init_addrs, tp_size):
-        """Update the distributed initialization addresses for workers.
-
-        This is used to group workers that belong to the same inference engine.
-
-        Args:
-            nodes_per_engine (int): The number of nodes per inference engine.
-            server_urls_per_engine (int): The number of server urls per inference engine.
-            dist_init_addrs (list): The list of initial addresses.
-            tp_size (int): The tensor parallel size.
-
-        Returns:
-            list: The updated list of distributed initialization addresses.
-        """
-        # lmdeploy pytorch ep: server_urls_per_engine > 1
-        # sglang cross node engine: nodes_per_engine > 1
-        assert server_urls_per_engine == 1 or nodes_per_engine == 1
-        if nodes_per_engine > 1:
-            index = list(range(0, self.num_workers + 1, tp_size)) + [self.num_workers]
-            for i in range(1, len(index)):
-                dist_init_addrs[index[i - 1] : index[i]] = [dist_init_addrs[index[i - 1]]] * (index[i] - index[i - 1])
-        if server_urls_per_engine > 1:
-            active_servers = len(dist_init_addrs)
-            for i in range(0, active_servers, server_urls_per_engine):
-                dist_init_addrs[i : i + server_urls_per_engine] = [dist_init_addrs[i]] * server_urls_per_engine
-        return dist_init_addrs
-
-    def _get_active_servers_count(self, infer_config: RolloutConfig, gpu_nums: int):
-        """Calculate the number of active servers and nodes per engine.
-
-        This calculation depends on the inference backend and parallelism settings.
-
-        Args:
-            infer_config (RolloutConfig): The rollout configuration.
-            gpu_nums (int): The total number of GPUs available.
-
-        Returns:
-            Tuple[int, int]: A tuple containing the number of active servers
-                and the number of nodes per engine.
-        """
-        # NOTE: Since different inference engines have different launch methods,
-        # the number of nodes contained in each engine is not consistent.
-        # For example: sglang requires starting an inference engine on each node,
-        # while lmdeploy and vllm do not. Therefore, we calculate the number
-        # of active servers based on the configuration.
-        support_cross_node_comm = infer_config.rollout_cross_node_comm
-        gpus_per_node = infer_config.gpus_per_node
-        nodes_per_engine = (
-            1
-            if support_cross_node_comm or self.num_gpus_per_engine < gpus_per_node
-            else self.num_gpus_per_engine // gpus_per_node
-        )
-
-        active_servers_count = int(
-            (gpu_nums // self.num_gpus_per_engine) * nodes_per_engine * infer_config.server_urls_per_engine
-        )
-        return active_servers_count, nodes_per_engine
-
-    def _broadcast_to_active_workers(self, method_name: str, block: bool):
-        """Helper function to call a method on all active workers.
-
-        Args:
-            method_name (str): The name of the method to call.
-            block (bool): Whether to block until the call completes.
-
-        Returns:
-            A list of futures if `block` is False, otherwise a list of results.
-        """
-        futures = []
-        for info in self.workers_info.values():
-            if info.is_active:
-                futures.append(getattr(info.actor, method_name).remote())
-            else:
-                self.logger.warning(f"Skipping {method_name} for inactive worker {info.actor}.")
-
-        if not block:
-            return futures
-
-        results = ray.get(futures, timeout=ROLLOUT_RAY_GET_TIMEOUT)
-        return results
-
-    def pause(self, block=True):
-        """Pauses all active rollout workers.
-
-        Args:
-            block (bool): Whether to block until the operation completes.
-        """
-        return self._broadcast_to_active_workers("pause", block)
-
-    def restart(self, block=True):
-        """Restarts all active rollout workers.
-
-        Args:
-            block (bool): Whether to block until the operation completes.
-        """
-        return self._broadcast_to_active_workers("restart", block)
-
-    def reset_prefix_cache(self, block=True):
-        """Resets the prefix cache on all active rollout workers.
-
-        Args:
-            block (bool): Whether to block until the operation completes.
-        """
-        return self._broadcast_to_active_workers("reset_prefix_cache", block)
-
-    def offload(self, block=True):
-        """Offloads model weights and KV cache on all active rollout workers.
-
-        Args:
-            block (bool): Whether to block until the operation completes.
-        """
-        return self._broadcast_to_active_workers("offload", block)
-
-    def onload_weights(self, block=True):
-        """Onloads model weights on all active rollout workers.
-
-        Args:
-            block (bool): Whether to block until the operation completes.
-        """
-        return self._broadcast_to_active_workers("onload_weights", block)
-
-    def onload_kvcache(self, block=True):
-        """Onloads KV cache on all active rollout workers.
-
-        Args:
-            block (bool): Whether to block until the operation completes.
-        """
-        return self._broadcast_to_active_workers("onload_kvcache", block)
-
-    def shutdown(self, block=True):
-        """Shuts down all active rollout workers.
-
-        Args:
-            block (bool): Whether to block until the operation completes.
-        """
-        return self._broadcast_to_active_workers("shutdown", block)
diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py
deleted file mode 100644
index dc3780caaf..0000000000
--- a/xtuner/v1/ray/rollout/worker.py
+++ /dev/null
@@ -1,859 +0,0 @@
-import asyncio
-import copy
-import json
-import multiprocessing
-import os
-import time
-import traceback
-import uuid
-from abc import abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Union
-
-import httpx
-import numpy as np
-import ray
-import requests  # type: ignore[import-untyped]
-from packaging.version import Version
-from ray import ObjectRef
-from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
-
-from transformers import AutoTokenizer
-from xtuner.v1.data_proto.rl_data import RLRolloutResponseItem, RolloutState
-from xtuner.v1.ray import find_master_addr_and_port
-from xtuner.v1.ray.base import AutoAcceleratorWorkers, SingleAcceleratorWorker
-from xtuner.v1.ray.config import RolloutConfig
-from xtuner.v1.utils import get_logger
-from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult
-
-
-def get_eos_token(model_path: str) -> int | List[int]:
-    from xtuner.v1.utils.logger import get_logger
-
-    logger = get_logger()
-    generation_config_path = os.path.join(model_path, "generation_config.json")
-    if not os.path.exists(generation_config_path):
-        logger.warning(
-            f"Config {generation_config_path} does not exist, so eos_token cannot be inferred. You must provide eos_token manually."
-        )
-        return []
-    with open(generation_config_path) as f:
-        generation_config = json.load(f)
-    eos_token_id = generation_config.get("eos_token_id")
-    return eos_token_id
- """ - self.config = config - self.rank = rank - self.master_addr = master_addr # ray master - self.master_port = master_port - self.world_size = world_size - self.accelerator = accelerator - self.server_func: Callable - self.endpoints: dict[str, str] = dict() - self.engine_rank_mesh_array: list[list[int]] - # http_concurrency is calculated based on the max batch size per engine and the total number of engines - assert config.rollout_max_batch_size_per_instance, ( - "rollout_max_batch_size_per_instance must be set in RolloutConfig" - ) - http_concurrency = config.rollout_max_batch_size_per_instance * config.allow_over_concurrency_ratio - limits = httpx.Limits(max_connections=http_concurrency, max_keepalive_connections=100) - self.client = httpx.AsyncClient(limits=limits, timeout=self.config.rollout_timeout) - self.paused = False - self.server_task = None - self.engine_bundle_idxs: list[int] = [] - self.server_process: Optional[multiprocessing.Process] = None - self.logger = get_logger(log_dir=config.worker_log_dir, tag="RolloutWorker") - self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_path, trust_remote_code=True) - self.check_flag = True # only print once - self.enable_return_routed_experts = self.config.enable_return_routed_experts - if self.rank == 0: - self.logger.info(f"RolloutConfig:\n{self.config.model_dump_json(indent=2)}") - eos_token = get_eos_token(self.config.model_path) - self.logger.info(f"Using eos_token: {eos_token} for model at {self.config.model_path}") - self.eos_token: List[int] = [eos_token] if isinstance(eos_token, int) else eos_token - self.receive_abort_request = asyncio.Event() - self.abort_timeout = 5.0 - - def init_dist_port(self): - """Initialize distributed communication ports. - - This method acquires three free ports for the distributed setup: - one for the inference server, one for NCCL, and one for Ray's - distributed communication. - - Returns: - str: The distributed initialization address (host:port). - """ - scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=ray.util.get_current_placement_group(), - placement_group_capture_child_tasks=True, - placement_group_bundle_index=self.engine_bundle_idxs[0], - ) - - local_rank = int(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) - interval = 1024 - start_port = self.config.dist_port_base + local_rank * interval - end_port = start_port + interval - self.host, self.ports = ray.get( - find_master_addr_and_port.options(scheduling_strategy=scheduling_strategy).remote( - nums=3, - start_port=start_port, - end_port=end_port, - ) - ) - - self.dist_port = self.ports[0] - self.server_port = self.ports[1] - self.nccl_port = self.ports[2] - self.dist_init_addr = f"{self.host}:{self.dist_port}" - self.server_url = f"http://{self.host}:{self.server_port}" - return self.dist_init_addr - - def init(self, dist_init_addr: str = ""): - """Initialize the worker and launch the server. - - Args: - dist_init_addr (str): The distributed initialization address. - If not provided, the one generated by `init_dist_port` is used. - - Returns: - Tuple[int, str]: A tuple containing the worker's rank and its - server URL. 
- """ - self.dist_init_addr = dist_init_addr if dist_init_addr else self.dist_init_addr - self.receive_abort_request.clear() - self.launch_server() - return (self.rank, self.server_url) - - def _decode_routed_experts(self, routed_experts: Any) -> Any: - return routed_experts - - def set_engine_rank_mesh_array(self, engine_rank_mesh_array: list[list[int]]): - self.engine_rank_mesh_array = engine_rank_mesh_array - - def set_engine_bundle_idxs(self, engine_bundle_idxs: list[int]): - """Set the bundle indices for the inference engine. - - This is used by some backends (like LMDeploy with Ray executor) to - know which bundles in the placement group belong to this engine. - - Args: - engine_bundle_idxs (list[int]): A list of bundle indices. - """ - self.engine_bundle_idxs = engine_bundle_idxs - - def launch_server(self): - """Launch the inference server as a separate process or Ray task. - - It waits for the server to become healthy before returning. - - Raises: - TimeoutError: If the server fails to start within the specified - timeout. - Exception: If the server task terminates unexpectedly. - """ - server_configs = self._transform_rollout_config_to_server_configs() - timeout = 3600.0 # Increased timeout to 5 minutes for downloading large models - start_time = time.perf_counter() - last_log_time = start_time - headers = { - "Content-Type": "application/json; charset=utf-8", - "Authorization": f"Bearer {server_configs.api_key}", - } - - self.logger.info(f"Launch server task on server_url: {self.server_url}") - - # note(@duanyanhui): launch server as multiprocessing for sglang temporarily - if self.config.launch_server_method == "multiprocessing": - mp_ctx = multiprocessing.get_context("spawn") - process = mp_ctx.Process(target=self.server_func, args=(server_configs,)) - process.start() - self.server_process = process - time.sleep(60) # Wait for the server to start - with requests.Session() as session: - while time.perf_counter() - start_time < timeout: - try: - response = session.get( - f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers - ) - if response.status_code == 200: - return - except requests.RequestException as e: - self.logger.error( - f"can't connect to server url {self.server_url}/{self.endpoints['health_generate']} because {e}" - ) - - current_time = time.perf_counter() - if current_time - last_log_time >= 15: - self.logger.info( - f"Waiting for server to start, Elapsed time: {current_time - start_time:.2f}s" - ) - last_log_time = current_time - - time.sleep(5) - process.terminate() - raise TimeoutError("Server failed to start within the timeout period.") - else: - # launch the server as ray task - # so that the lmdeploy backend could get externl pg - current_pg = ray.util.get_current_placement_group() - scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=current_pg, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=self.engine_bundle_idxs[0], - ) - assert ray.is_initialized() - ray_kwargs = ( - {"runtime_env": server_configs.ray_runtime_env} if hasattr(server_configs, "ray_runtime_env") else {} - ) - self.server_task = ( - ray.remote(self.server_func) - .options( - scheduling_strategy=scheduling_strategy, - **AutoAcceleratorWorkers.get_pg_options(current_pg), - **ray_kwargs, - ) - .remote(server_configs) - ) - - with requests.Session() as session: - while time.perf_counter() - start_time < timeout: - try: - response = session.get( - f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers - ) - 
-                        if response.status_code == 200:
-                            return
-                    except requests.RequestException:
-                        pass
-
-                    try:
-                        ray.get(self.server_task, timeout=0.1)
-                        raise Exception("Server task terminated unexpectedly.")
-                    except ray.exceptions.GetTimeoutError:
-                        pass
-                    except Exception as e:
-                        raise e
-
-                    current_time = time.perf_counter()
-                    if current_time - last_log_time >= 15:
-                        self.logger.info(
-                            f"Waiting for server to start... Elapsed time: {current_time - start_time:.2f}s"
-                        )
-                        last_log_time = current_time
-
-            ray.cancel(self.server_task)
-            raise TimeoutError("Server failed to start within the timeout period.")
-
-    def _adapt_input_to_openai_spec(self, prompts, tools, tool_choice):
-        openai_prompts = []
-        openai_tools = []
-        # Transform the Claude (Anthropic) spec to the OpenAI spec
-        # 1. transform system prompt: prepend the configured system_prompt to the input messages
-        system_prompt = self.config.system_prompt
-        if system_prompt:
-            system_prompt_json = {"role": "system", "content": f"{system_prompt}"}
-            prompts.insert(0, system_prompt_json)
-        # 2. transform multi-modal usage
-        for prompt in prompts:
-            content = prompt["content"]
-            openai_content = []
-            for item in content:
-                if item["type"] == "image":
-                    if item["source"]["type"] == "base64":
-                        openai_url = f"data:{item['source']['media_type']};base64,{item['source']['data']}"
-                    if item["source"]["type"] == "url":
-                        openai_url = item["source"]["url"]
-                    new_prompt = {"type": "image_url", "image_url": {"url": openai_url}}
-                    openai_content.append(new_prompt)
-                elif item["type"] == "text":
-                    openai_content.append(item)
-            new_prompt = copy.deepcopy(prompt)
-            new_prompt["content"] = openai_content
-            openai_prompts.append(new_prompt)
-        # 3. transform tool use
-        for tool in tools:
-            openai_tool = {
-                "type": "function",
-                "function": {
-                    "name": tool["name"],
-                    "description": tool["description"],
-                    "parameters": tool["input_schema"],
-                },
-            }
-            openai_tools.append(openai_tool)
-        return openai_prompts, openai_tools
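A worked example of the message-shape conversion performed by `_adapt_input_to_openai_spec` above: an Anthropic-style base64 image block becomes an OpenAI-style image_url block (payload values are illustrative):

    claude_msg = {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "iVBORw0..."}},
        ],
    }
    # After conversion:
    openai_msg = {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0..."}},
        ],
    }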
-
-    def _check_infer_engine_version(self, return_token_ids: bool):
-        # TODO(@duanyanhui): remove this check when all backends support return_token_ids
-        if self.check_flag:
-            if os.environ.get("XTUNER_USE_VLLM", "0") == "1":
-                if return_token_ids:
-                    self.logger.error(
-                        "The vLLM backend does not yet support return_token_ids or generating from input_ids in XTuner"
-                    )
-            elif os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "1":
-                import lmdeploy
-
-                lmdeploy_version = lmdeploy.__version__
-                if return_token_ids and Version(lmdeploy_version) < Version("0.10.2"):
-                    self.logger.error(
-                        f"You should use lmdeploy >= v0.10.2 to support return_token_ids, but the current version is {lmdeploy_version}"
-                    )
-            self.check_flag = False
-
-    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
-        send_task = None
-        try:
-            if self.receive_abort_request.is_set():
-                self.logger.debug(f"Request to {url} was cancelled before sending due to an abort signal.")
-                return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload)
-            req = self.client.build_request(
-                "POST",
-                url,
-                headers=headers,
-                json=payload,
-            )
-            send_task = asyncio.create_task(self.client.send(req))
-            r = await send_task
-            r.raise_for_status()
-            return HttpRequestResult(response=r)
-
-        except asyncio.CancelledError:
-            self.logger.debug(f"Request to {url} was cancelled while waiting for the response.")
-            if send_task is not None and not send_task.done():
-                send_task.cancel()
-                await asyncio.gather(send_task, return_exceptions=True)
-            self.receive_abort_request.set()
-            return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload)
-        except Exception as e:
-            error_type = HttpRequestErrorType.from_exception(e)
-            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
-            return result
-
-    async def rollout_task(
-        self,
-        prompts: Union[str, List[Dict[str, Any]]] | None,
-        input_ids: List[int] | None,
-        tools: List,
-        tool_choice: str,
-        sample_params: dict,
-        extra_params: dict,
-        format: str,
-        extra_info: dict,
-    ) -> RLRolloutResponseItem:
-        action_id = extra_info.get("action_id", str(uuid.uuid4()))
-        uid = action_id
-        root_id = extra_info.get("root_id", str(uuid.uuid4()))
-        response = None
-        cur_retry_times = 0
-
-        if format == "openai":
-            openai_prompts, openai_tools = prompts, tools
-        else:
-            openai_prompts, openai_tools = self._adapt_input_to_openai_spec(prompts, tools, tool_choice)
-
-        if "return_token_ids" in extra_params and extra_params["return_token_ids"]:
-            endpoint_url = f"{self.server_url}/{self.endpoints['generate']}"
-        else:
-            endpoint_url = f"{self.server_url}/{self.endpoints['v1/chat/completions']}"
-
-        while True:
-            # If the accumulated response_ids have already reached max_tokens, skip the request and return directly
-            if extra_info.get("partial_rollout_input_ids", None) is not None:
-                if sample_params["max_tokens"] == 0:
-                    self.logger.debug(
-                        f"Request {uid} reached max context length {self.config.context_length}, no need to rollout more."
-                    )
-                    return RLRolloutResponseItem(
-                        response=None,
-                        response_ids=None,
-                        logprobs=None,
-                        num_return_tokens=0,
-                        finish_reason="length",
-                        state=RolloutState.COMPLETED,
-                    )
-                if extra_info["partial_rollout_input_ids"][-1] in self.eos_token:
-                    self.logger.debug(
-                        f"Request {uid} already ends with eos token {extra_info['partial_rollout_input_ids'][-1]}, no need to rollout more"
-                    )
-                    return RLRolloutResponseItem(
-                        response=None,
-                        response_ids=None,
-                        logprobs=None,
-                        num_return_tokens=0,
-                        finish_reason="stop",
-                        state=RolloutState.COMPLETED,
-                    )
-
-            http_result = await self._create_request(
-                endpoint_url,
-                openai_prompts,
-                input_ids,
-                openai_tools,
-                tool_choice,
-                sample_params=sample_params,
-                extra_params=extra_params,
-                extra_info=extra_info,
-            )
-            # Case 1: the request succeeded. Inference is finished (COMPLETED) with finish_reason abort/stop/length.
-            if http_result.response is not None:
-                response = await self._handle_non_stream_response(
-                    root_id, action_id, sample_params, extra_params, http_result.response, extra_info
-                )
-                if response.state == RolloutState.SKIPPED:
-                    # retry
-                    cur_retry_times += 1
-                    if cur_retry_times < self.config.max_retry_per_sample:
-                        self.logger.warning(
-                            f"Invalid rollout response for request {uid}, retrying {cur_retry_times}/{self.config.max_retry_per_sample}."
-                        )
-                        await asyncio.sleep(0.1)
-                        continue
-                    else:
-                        return RLRolloutResponseItem(state=RolloutState.SKIPPED)
-                return response
-
-            # Case 2: return an aborted response if an abort signal was received
-            if http_result.error_type == HttpRequestErrorType.REQUEST_ABORTED:
-                return RLRolloutResponseItem(finish_reason="abort", state=RolloutState.ABORTED)
-
-            # Case 3: a fatal, non-retryable error occurred
-            elif http_result.is_unknown_error:
-                raise RuntimeError(
-                    f"Unexpected error during rollout request {uid} to {http_result.url}: {http_result.exception}"
-                )
-
-            # Case 4: a retryable error occurred, and we still have retries left
-            elif http_result.is_retryable and cur_retry_times < self.config.max_retry_per_sample:
-                cur_retry_times += 1
-                self.logger.warning(
-                    f"Retrying rollout request {uid} to {http_result.url} due to {http_result.error_type} with {http_result.error_msg}. "
-                    f"Retry {cur_retry_times}/{self.config.max_retry_per_sample}."
-                )
-                await asyncio.sleep(0.1)
-                continue
-
-            elif http_result.is_retryable and cur_retry_times >= self.config.max_retry_per_sample:
-                self.logger.warning(
-                    f"rollout request {uid} to {http_result.url} was skipped because the maximum number of retries was reached"
-                )
-                return RLRolloutResponseItem(state=RolloutState.SKIPPED)
-            elif http_result.is_client_error:
-                self.logger.warning(
-                    f"rollout request {uid} to {http_result.url} was skipped due to client error {http_result.error_type} with {http_result.error_msg}"
-                )
-                return RLRolloutResponseItem(state=RolloutState.SKIPPED)
-            elif http_result.is_server_error:
-                self.logger.warning(
-                    f"rollout request {uid} to {http_result.url} failed due to server error {http_result.error_type} with {http_result.error_msg}"
-                )
-                return RLRolloutResponseItem(state=RolloutState.FAILED)
-            else:
-                raise RuntimeError(
-                    f"Unhandled error case for rollout request {uid} to {http_result.url}: {http_result.exception}"
-                )
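The four-way branch in `rollout_task` above amounts to a small retry policy. Schematically (names simplified, purely illustrative):

    # Schematic of the retry policy implemented by rollout_task.
    from enum import Enum, auto


    class Outcome(Enum):
        RETURN_RESPONSE = auto()  # Case 1: success (retried internally if marked SKIPPED)
        RETURN_ABORTED = auto()   # Case 2: abort signal observed
        RAISE = auto()            # Case 3: unknown, non-retryable error
        RETRY = auto()            # Case 4: retryable error, budget remaining
        RETURN_SKIPPED = auto()   # retry budget exhausted, or client error
        RETURN_FAILED = auto()    # server error


    def classify(has_response: bool, aborted: bool, unknown: bool, retryable: bool,
                 retries_left: bool, client_error: bool, server_error: bool) -> Outcome:
        if has_response:
            return Outcome.RETURN_RESPONSE
        if aborted:
            return Outcome.RETURN_ABORTED
        if unknown:
            return Outcome.RAISE
        if retryable:
            return Outcome.RETRY if retries_left else Outcome.RETURN_SKIPPED
        if client_error:
            return Outcome.RETURN_SKIPPED
        if server_error:
            return Outcome.RETURN_FAILED
        return Outcome.RAISE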
-
-    async def _handle_stream_response(self, uid, sample_params, extra_params, response) -> RLRolloutResponseItem:
-        last_trajectory = ""
-        last_token_ids = []
-        last_logprobs = []
-        finish_reason = ""
-        async for chunk in response.aiter_lines():
-            if not chunk.startswith("data:"):
-                continue
-            try:
-                chunk_data_str = chunk[len("data:") :].strip()
-                if self.paused or chunk_data_str == "[DONE]":
-                    finish_reason = "paused" if self.paused else finish_reason
-                    break
-                if not (chunk_data_str.startswith("{") and chunk_data_str.endswith("}")):
-                    continue
-
-                chunk_data = json.loads(chunk_data_str)
-
-                if "return_token_ids" in extra_params and extra_params["return_token_ids"]:
-                    last_trajectory = last_trajectory + chunk_data.get("text", "")
-                    finish_reason = chunk_data["meta_info"].get("finish_reason")
-                    if finish_reason is not None:
-                        finish_reason = finish_reason["type"]
-
-                    output_token_logprobs = chunk_data["meta_info"].get("output_token_logprobs")
-                    if output_token_logprobs is not None:
-                        for token_logprob in output_token_logprobs:
-                            last_logprobs.append(token_logprob[0])
-                            last_token_ids.append(token_logprob[1])
-                else:
-                    delta_content = chunk_data["choices"][0]["delta"].get("content")
-                    last_trajectory = last_trajectory + delta_content if delta_content else last_trajectory
-                    last_token_id = chunk_data["choices"][0]["delta"].get("gen_tokens")
-                    if last_token_id is not None:
-                        last_token_ids.extend(last_token_id)
-                    finish_reason = chunk_data["choices"][0].get("finish_reason")
-                    logprobs_content = chunk_data["choices"][0]["logprobs"]
-                    if logprobs_content is not None:
-                        for content_item in logprobs_content["content"]:
-                            last_logprobs.append(content_item["logprob"])
-
-            except json.JSONDecodeError as e:
-                self.logger.error(f"JSON decode error for chunk in request {uid}: {chunk}, error: {e}")
-                continue
-            except Exception as e:
-                self.logger.error(f"Error processing chunk for {uid}: {chunk}, error: {e}")
-                return RLRolloutResponseItem(
-                    response="",
-                    finish_reason="failed",
-                )
-
-        assert finish_reason in ["stop", "length", "tool_call", "abort"], f"Unexpected finish_reason: {finish_reason}"
-        rollout_response = RLRolloutResponseItem(
-            response=last_trajectory,
-            response_ids=last_token_ids if len(last_token_ids) > 0 else None,
-            num_return_tokens=len(last_token_ids) if len(last_token_ids) > 0 else None,
-            finish_reason=finish_reason,
-            logprobs=last_logprobs,
-        )
-        return rollout_response
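The non-stream handler below concatenates routed-experts arrays across partial rollouts. Since a partial rollout re-sends the full prefix, the freshly returned array overlaps the stored history and must be sliced before concatenation. A small shape sketch with illustrative numbers:

    import numpy as np

    num_layers, num_experts = 2, 4
    history = np.zeros((9, num_layers, num_experts))   # prompt(6) + previous response(4) - 1
    current = np.zeros((14, num_layers, num_experts))  # full prefix re-run: 10 + new(5) - 1
    new_part = current[history.shape[0]:, :, :]        # keep only the 5 newly generated steps
    merged = np.concatenate((history, new_part), axis=0)
    assert merged.shape[0] == 14  # prompt_tokens + completion_tokens - 1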
- f"Existing routed_experts shape: {history_routed_experts.shape}, current routed_experts shape: {cur_routed_experts.shape}" - ) - init_cur_roued_experts = cur_routed_experts.shape[0] - cur_routed_experts = cur_routed_experts[history_routed_experts.shape[0] :, :, :] - concat_routed_experts = np.concatenate((history_routed_experts, cur_routed_experts), axis=0) - prompt_tokens = response["meta_info"].get("prompt_tokens", 0) - response_tokens = response["meta_info"].get("completion_tokens", 0) - assert concat_routed_experts.shape[0] == prompt_tokens + response_tokens - 1, ( - f"Routed experts shape {concat_routed_experts.shape[0]} does not match total tokens {prompt_tokens + response_tokens - 1}" - ) - self.logger.debug( - f"[{root_id}/{action_id}] Partial Rollout Stats: " - f"Tokens(prompt={prompt_tokens}, response={response_tokens}, total={prompt_tokens + response_tokens}) | " - f"Experts(exist={history_routed_experts.shape}, init_cur={init_cur_roued_experts}, cur={cur_routed_experts.shape}, concat={concat_routed_experts.shape})" - ) - extra_info["routed_experts"] = ray.put(concat_routed_experts) - del history_routed_experts - del cur_routed_experts - else: - assert finish_reason == "abort", ( - f"routed_experts is None, but finish_reason is {finish_reason}, expected abort. response: {response}" - ) - # NOTE: When set return_token_ids = True, the response must contain valid token_ids/logprobs. - # If not, we consider it as an invalid response and retry it. - # NOTE: !!! When finish_reason is abort, some queries may not return token_ids or logprobs. !!! - if finish_reason != "abort" and (len(last_token_ids) == 0 or len(last_logprobs) == 0): - self.logger.error(f"Invalid rollout response for request {uid}: {response}") - return RLRolloutResponseItem(state=RolloutState.SKIPPED) - else: - rollout_response = RLRolloutResponseItem( - response=response["text"], - response_ids=last_token_ids, - num_return_tokens=len(last_token_ids), - finish_reason=finish_reason, - logprobs=last_logprobs, - extra_info=extra_info, - state=RolloutState.ABORTED if finish_reason == "abort" else RolloutState.COMPLETED, - ) - # self.logger.info(f"Rollout response for request {uid}: finish_reason={finish_reason}, num_return_tokens={len(last_token_ids)}") - return rollout_response - except KeyError as e: - error_msg = f"Missing expected key {e} in response {response} for {uid}" - raise RuntimeError(error_msg) - except IndexError as e: - error_msg = f"Index error {e} while processing response {response} for {uid}" - raise RuntimeError(error_msg) - except AssertionError as e: - error_msg = f"AssertionError: {e} when processing response {response} for {uid}" - raise RuntimeError(error_msg) - except json.JSONDecodeError as e: - error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}" - raise RuntimeError(error_msg) - except TypeError as e: - error_msg = f"TypeError: {e} when processing response {response} for {uid}" - raise RuntimeError(error_msg) - except Exception as e: - error_msg = f"Unexpected error: {e} when processing response {response} for {uid}\nTraceback: {traceback.format_exc()}" - raise RuntimeError(error_msg) - else: - # v1/chat/completions API response - try: - last_trajectory = response["choices"][0]["message"]["content"] - finish_reason = response["choices"][0]["finish_reason"] - rollout_response = RLRolloutResponseItem( - response=last_trajectory, - finish_reason=finish_reason, - num_return_tokens=response["usage"]["completion_tokens"], - ) - return rollout_response - except KeyError 
as e: - error_msg = f"Missing expected key {e} in response {response} for {uid}" - raise RuntimeError(error_msg) - except IndexError as e: - error_msg = f"Index error {e} while processing response {response} for {uid}" - raise RuntimeError(error_msg) - except AssertionError as e: - error_msg = f"AssertionError: {e} when processing response {response} for {uid}" - raise RuntimeError(error_msg) - except json.JSONDecodeError as e: - error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}" - raise RuntimeError(error_msg) - except TypeError as e: - error_msg = f"TypeError: {e} when processing response {response} for {uid}" - raise RuntimeError(error_msg) - except Exception as e: - error_msg = f"Unexpected error: {e} when processing response {response} for {uid}\nTraceback: {traceback.format_exc()}" - raise RuntimeError(error_msg) - - async def rollout( - self, - prompt: Union[str, List[Dict[str, Any]]] | None = None, - input_ids: Optional[List[int]] | None = None, - tools: List = [], - tool_choice: str = "auto", - sample_params: dict = dict(), - extra_params: dict = dict(), - format: str = "openai", - extra_info: dict = dict(), - ) -> RLRolloutResponseItem: - """Public method to initiate a rollout. - - Args: - prompt (str): The input prompt for generation. - sample_params (dict): Parameters for sampling. - - Returns: - The result of the `rollout_task`. - """ - return await self.rollout_task( - prompt, input_ids, tools, tool_choice, sample_params, extra_params, format=format, extra_info=extra_info - ) - - def pause(self): - """Pause the worker's generation process.""" - self.paused = True - - def restart(self): - """Resume the worker's generation process.""" - self.receive_abort_request.clear() - - def check_health(self) -> bool: - """Check the health of the worker's server. - - Returns: - bool: True if the server is healthy, False otherwise. - """ - try: - headers = { - "Content-Type": "application/json; charset=utf-8", - "Authorization": f"Bearer {self.config.api_key}", - } - response = requests.get( - f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers, timeout=5.0 - ) - return response.status_code == 200 - except requests.RequestException as e: - self.logger.error(f"Health check failed for server {self.server_url}: {e}") - return False - - def shutdown(self): - """Shut down the worker, its server task, and any child processes.""" - if self.server_task is not None: - ray.cancel(self.server_task) - return - - if self.server_process is not None: - import psutil - - parent = psutil.Process(self.server_process.pid) - children = parent.children(recursive=True) - for child in children: - child.terminate() - gone, alive = psutil.wait_procs(children, timeout=5) - for child in alive: - child.kill() - parent.terminate() - parent.wait(timeout=5) - self.logger.debug(f"Worker {self.rank} server process and its children terminated.") - return - - @abstractmethod - async def _create_request( - self, - url: str, - prompt: Union[str, List[Dict[str, Any]]] | None, - input_ids: List[int] | None, - tools: List, - tool_choice: str, - sample_params: dict, - extra_params: dict, - extra_info: dict, - ): - """Abstract method to create a generation request. - - Must be implemented by subclasses. - """ - raise NotImplementedError("_create_request must be implemented in subclass") - - @abstractmethod - def _transform_rollout_config_to_server_configs(self): - """Abstract method to transform rollout config to server configs. - - Must be implemented by subclasses. 
- """ - raise NotImplementedError("_transform_rollout_config_to_server_configs must be implemented in subclass") - - @abstractmethod - def _transform_sample_params(self, sample_params: Dict): - """Abstract method to transform rollout config to server configs. - - Must be implemented by subclasses. - """ - raise NotImplementedError("_transform_rollout_config_to_server_configs must be implemented in subclass") - - @abstractmethod - def get_logprobs(self, input_ids, sampling_params): - """Abstract method to get log probabilities. - - Must be implemented by subclasses. - """ - raise NotImplementedError("get_logprobs must be implemented in subclass") - - @abstractmethod - def update_weights(self): - """Abstract method to update model weights. - - Must be implemented by subclasses. - """ - raise NotImplementedError("update_weights must be implemented in subclass") - - @abstractmethod - def reset_prefix_cache(self): - """Abstract method to reset the prefix cache. - - Must be implemented by subclasses. - """ - raise NotImplementedError("reset_prefix_cache must be implemented in subclass") - - @abstractmethod - def offload(self): - """Abstract method to offload the model and KVcache. - - Must be implemented by subclasses. - """ - raise NotImplementedError("reset_prefix_cache must be implemented in subclass") - - @abstractmethod - def onload_weights(self): - """Abstract method to onload the model weights. - - Must be implemented by subclasses. - """ - pass - - @abstractmethod - def onload_kvcache(self): - """Abstract method to onload the KV cache. - - Must be implemented by subclasses. - """ - pass - - @abstractmethod - def pause_generation(self): - """Abstract method to pause the generation process. - - Must be implemented by subclasses. - """ - raise NotImplementedError("pause_generation must be implemented in subclass") - - @abstractmethod - def continue_generation(self): - """Abstract method to continue the generation process. - - Must be implemented by subclasses. - """ - raise NotImplementedError("continue_generation must be implemented in subclass") diff --git a/xtuner/v1/ray/utils.py b/xtuner/v1/ray/utils.py deleted file mode 100644 index c3549da849..0000000000 --- a/xtuner/v1/ray/utils.py +++ /dev/null @@ -1,239 +0,0 @@ -import asyncio -import hashlib -import importlib -import socket -from asyncio import AbstractEventLoop, Task -from typing import TYPE_CHECKING, Callable, Coroutine, List, Optional, cast - -import ray -from ray import ObjectRef - - -if TYPE_CHECKING: - import ray.actor - - from xtuner.v1.ray.base.accelerator import AcceleratorType - - -def get_ray_accelerator() -> "AcceleratorType": - from xtuner.v1.utils.device import get_device - - """Get the type of accelerator available in the Ray environment. - - This function checks for the availability of CUDA and NPU devices and - returns the corresponding accelerator type. - - Returns: - AcceleratorType: The type of accelerator ("GPU" or "NPU"). - - Raises: - NotImplementedError: If neither CUDA nor NPU is available. - """ - accelerator = None - if get_device() == "cuda": - accelerator = "GPU" - return "GPU" - else: - try: - import torch_npu # noqa: F401 - - accelerator = "NPU" - except ImportError: - pass - - if accelerator is None: - raise NotImplementedError( - "Supports only CUDA or NPU. If your device is CUDA or NPU, " - "please make sure that your environmental settings are " - "configured correctly." - ) - - return cast("AcceleratorType", accelerator) - - -def load_function(path): - """Load a function from a module. 
- - :param path: The path to the function, e.g. "module.submodule.function". - :return: The function object. - """ - module_path, _, attr = path.rpartition(".") - module = importlib.import_module(module_path) - return getattr(module, attr) - - -def _is_port_available(check_socket: socket.socket, port: int) -> bool: - try: - check_socket.bind(("", port)) - check_socket.listen(1) - return True - except OSError: - return False - - -@ray.remote -def find_master_addr_and_port( - nums: int = 1, start_port: Optional[int] = None, end_port: Optional[int] = None -) -> tuple[str, int] | tuple[str, list[int]]: - """Finds an available master address and a specified number of ports. - - This remote function gets the node's IP address and binds to one or more - available ports, which can be used for distributed communication. - - Args: - nums (int): The number of ports to find. Defaults to 1. - start_port (Optional[int]): The starting port to search from. - If None, random available ports will be used. Defaults to None. - end_port (Optional[int]): The ending port to search to (exclusive). - If start_port is None, this parameter is ignored. Defaults to None. - - Returns: - A tuple containing the address and a single port if `nums` is 1, - or a list of ports if `nums` is greater than 1. - """ - addr = ray.util.get_node_ip_address() - ports: list[int] = [] - sockets: list[socket.socket] = [] - - if start_port is None: - for _ in range(nums): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # if the port is binded and listened by this socket and then we close it, - # socket.SO_REUSEADDR would make the port be reusable even it's in TIME_WAIT state. - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sockets.append(s) - if _is_port_available(check_socket=s, port=0): - ports.append(s.getsockname()[1]) - else: - assert isinstance(start_port, int), "If start_port isn't None, it must be an integer." - assert isinstance(end_port, int), "If start_port isn't None, end_port must be an integer." - assert end_port - start_port >= nums, ( - "If start_port isn't None, the range between start_port and end_port must be at least nums." - ) - - for candidate_port in range(start_port, end_port): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # if the port is binded and listened by this socket and then we close it, - # socket.SO_REUSEADDR would make the port be reusable even it's in TIME_WAIT state. - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sockets.append(s) - if _is_port_available(check_socket=s, port=candidate_port): - ports.append(candidate_port) - # enough ports found - if len(ports) >= nums: - break - - if len(ports) < nums: - raise RuntimeError(f"Could not find {nums} available ports starting from port {start_port} to {end_port}.") - - # close all sockets, no matter available or not - for s in sockets: - s.close() - - if len(ports) == 1: - return addr, ports[0] - else: - return addr, ports - - -@ray.remote -def get_accelerator_ids(accelerator: str) -> list: - """Get the IDs of the available accelerators (GPUs, NPUs, etc.) in the Ray - cluster.""" - return ray.get_runtime_context().get_accelerator_ids()[accelerator] - - -def bind_train_rollout( - train_workers, - rollout_controller, -) -> None: - """Bind the training and rollout workers for updating weights. - - This function retrieves rollout information from the rollout controller - and distributes it to the training workers, enabling them to update the - rollout models' weights. 
- - Args: - train_workers: A list of training worker actors. - rollout_controller: The rollout controller actor. - """ - info_dict = ray.get(rollout_controller.get_rollout_info.remote()) # type: ignore[attr-defined] - ray.get([worker.update_rollout_info.remote(**info_dict) for worker in train_workers]) # type: ignore[attr-defined] - return - - -def handle_task_exception(task: Task): - """Handles exceptions from an asyncio Task. - - This function checks if a task has raised an exception and, if so, - re-raises it. It ignores `asyncio.CancelledError`. - - Args: - task (Task): The asyncio task to check for exceptions. - - Raises: - Exception: The exception raised by the task. - """ - try: - exc = task.exception() - if exc is not None: - raise exc - except asyncio.CancelledError: - pass # Task was cancelled, ignore - - -def create_task( - coro: Coroutine, - loop: Optional[AbstractEventLoop] = None, - done_callbacks: Optional[List[Callable[[Task], object]]] = None, -) -> asyncio.tasks.Task: - """Creates and configures an asyncio Task. - - This function creates a task from a coroutine and attaches specified - done callbacks. By default, it includes a callback to handle exceptions. - - Args: - coro (Coroutine): The coroutine to wrap in a task. - loop (Optional[AbstractEventLoop], optional): The event loop to run - the task in. If None, the current event loop is used. - Defaults to None. - done_callbacks (Optional[List[Callable[[Task], object]]], optional): - A list of callbacks to add to the task. If None, a default - exception handler is used. Defaults to None. - - Returns: - asyncio.tasks.Task: The created asyncio task. - """ - if loop is None: - loop = asyncio.get_event_loop() - if done_callbacks is None: - done_callbacks = [handle_task_exception] - task = loop.create_task(coro) - for callback in done_callbacks: - task.add_done_callback(callback) - return task - - -def free_object_refs(refs: List[ObjectRef]) -> None: - valid_refs = [ref for ref in refs if isinstance(ref, ObjectRef)] - if not valid_refs: - return - try: - ray._private.internal_api.free(valid_refs, local_only=False) - except Exception: - ray.internal.free(valid_refs, local_only=False) - - -def deterministic_item_sort_key(sample) -> tuple[int, int, int, int]: - return ( - sample.uid.root_id, - sample.uid.action_id, - sample.uid.observation_id, - sample.uid.version, - ) - - -def build_deterministic_session_id(environment: str, sample) -> int: - session_key = f"{environment}|{sample.uid.root_id}|{sample.uid.action_id}|{sample.uid.observation_id}" - session_id = int.from_bytes(hashlib.sha256(session_key.encode("utf-8")).digest()[:8], "big") - return session_id or 1 diff --git a/xtuner/v1/rl/advantage/__init__.py b/xtuner/v1/rl/advantage/__init__.py index c7f29c1537..9e03878114 100644 --- a/xtuner/v1/rl/advantage/__init__.py +++ b/xtuner/v1/rl/advantage/__init__.py @@ -1,4 +1,12 @@ from xtuner.v1.rl.advantage.base import AdvantageEstimator +from xtuner.v1.rl.advantage.config import ( + BaseAdvantageConfig, + DrGRPOAdvantageConfig, + GRPOAdvantageConfig, + OPOAdvantageConfig, + PassKAdvantageConfig, + RLOOAdvantageConfig, +) from xtuner.v1.rl.advantage.grpo import DrGRPOEstimator, GRPOEstimator from xtuner.v1.rl.advantage.opo import OPOEstimator from xtuner.v1.rl.advantage.passk import PassKEstimator @@ -7,6 +15,12 @@ __all__ = [ "AdvantageEstimator", + "BaseAdvantageConfig", + "GRPOAdvantageConfig", + "DrGRPOAdvantageConfig", + "RLOOAdvantageConfig", + "OPOAdvantageConfig", + "PassKAdvantageConfig", "GRPOEstimator", 
"DrGRPOEstimator", "RLOOEstimator", diff --git a/xtuner/v1/rl/advantage/base.py b/xtuner/v1/rl/advantage/base.py index faffe0fc78..44d279abbe 100644 --- a/xtuner/v1/rl/advantage/base.py +++ b/xtuner/v1/rl/advantage/base.py @@ -19,15 +19,11 @@ def compute(self, rewards: torch.Tensor, group: list[RLDataFlowItem]) -> torch.T from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import Any import torch -if TYPE_CHECKING: - from xtuner.v1.data_proto.rl_data import RLDataFlowItem - - class AdvantageEstimator(ABC): """Abstract base class for advantage estimation. @@ -46,7 +42,7 @@ def compute(self, rewards, group): """ @abstractmethod - def compute(self, rewards: torch.Tensor, group: list[RLDataFlowItem]) -> torch.Tensor: + def compute(self, rewards: torch.Tensor, group: list[Any]) -> torch.Tensor: """Compute advantages from rewards for a single prompt group. Args: diff --git a/xtuner/v1/rl/config/advantage.py b/xtuner/v1/rl/advantage/config.py similarity index 97% rename from xtuner/v1/rl/config/advantage.py rename to xtuner/v1/rl/advantage/config.py index 05a4b8c0eb..5a9c78f0fd 100644 --- a/xtuner/v1/rl/config/advantage.py +++ b/xtuner/v1/rl/advantage/config.py @@ -1,7 +1,7 @@ from typing import Annotated from cyclopts import Group, Parameter -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from xtuner.v1.rl.advantage.base import AdvantageEstimator @@ -12,6 +12,8 @@ class BaseAdvantageConfig(BaseModel): """Intermediate base for discriminated union.""" + model_config = ConfigDict(extra="forbid") + def build(self) -> AdvantageEstimator: raise NotImplementedError("Subclasses must implement this method.") diff --git a/xtuner/v1/rl/advantage/grpo.py b/xtuner/v1/rl/advantage/grpo.py index b9c63cc787..e2f911e7fe 100644 --- a/xtuner/v1/rl/advantage/grpo.py +++ b/xtuner/v1/rl/advantage/grpo.py @@ -1,14 +1,10 @@ -from typing import TYPE_CHECKING +from typing import Any import torch from xtuner.v1.rl.advantage.base import AdvantageEstimator -if TYPE_CHECKING: - from xtuner.v1.data_proto.rl_data import RLDataFlowItem - - class GRPOEstimator(AdvantageEstimator): """Group Relative Policy Optimization (GRPO) advantage estimator. 
https://arxiv.org/abs/2402.03300 @@ -27,7 +23,7 @@ class GRPOEstimator(AdvantageEstimator): def __init__(self, eps: float = 1e-8) -> None: self.eps = eps - def compute(self, rewards: torch.Tensor, group: list["RLDataFlowItem"]) -> torch.Tensor: + def compute(self, rewards: torch.Tensor, group: list[Any]) -> torch.Tensor: mean = rewards.mean() std = rewards.std() + self.eps return (rewards - mean) / std @@ -56,13 +52,13 @@ def __init__(self, max_length: float, eps: float = 1e-8) -> None: self.max_length = max_length self.eps = eps - def compute(self, rewards: torch.Tensor, group: list["RLDataFlowItem"]) -> torch.Tensor: + def compute(self, rewards: torch.Tensor, group: list[Any]) -> torch.Tensor: mean = rewards.mean() std = rewards.std() + self.eps z = (rewards - mean) / std lengths = torch.tensor( - [len(d.env.rollout.response_ids) for d in group], # type: ignore + [_response_len(data) for data in group], dtype=torch.float32, device=rewards.device, ) @@ -70,3 +66,9 @@ def compute(self, rewards: torch.Tensor, group: list["RLDataFlowItem"]) -> torch def __repr__(self) -> str: return f"DrGRPOEstimator(max_length={self.max_length}, eps={self.eps})" + + +def _response_len(data: Any) -> int: + if hasattr(data, "response_ids"): + return len(data.response_ids or []) + return len(data.env.rollout.response_ids or []) diff --git a/xtuner/v1/rl/advantage/opo.py b/xtuner/v1/rl/advantage/opo.py index bf5272e1ec..0aa2ff54f9 100644 --- a/xtuner/v1/rl/advantage/opo.py +++ b/xtuner/v1/rl/advantage/opo.py @@ -1,14 +1,10 @@ -from typing import TYPE_CHECKING +from typing import Any import torch from xtuner.v1.rl.advantage.base import AdvantageEstimator -if TYPE_CHECKING: - from xtuner.v1.data_proto.rl_data import RLDataFlowItem - - class OPOEstimator(AdvantageEstimator): """OPO advantage estimator. @@ -33,9 +29,9 @@ class OPOEstimator(AdvantageEstimator): def __init__(self, eps: float = 1e-8) -> None: self.eps = eps - def compute(self, rewards: torch.Tensor, group: list["RLDataFlowItem"]) -> torch.Tensor: + def compute(self, rewards: torch.Tensor, group: list[Any]) -> torch.Tensor: lengths = torch.tensor( - [len(d.env.rollout.response_ids) for d in group], # type: ignore + [_response_len(data) for data in group], dtype=torch.float32, device=rewards.device, ) @@ -44,3 +40,9 @@ def compute(self, rewards: torch.Tensor, group: list["RLDataFlowItem"]) -> torch def __repr__(self) -> str: return "OPOEstimator()" + + +def _response_len(data: Any) -> int: + if hasattr(data, "response_ids"): + return len(data.response_ids or []) + return len(data.env.rollout.response_ids or []) diff --git a/xtuner/v1/rl/advantage/passk.py b/xtuner/v1/rl/advantage/passk.py index b7ae5a5a2e..ef5359aea5 100644 --- a/xtuner/v1/rl/advantage/passk.py +++ b/xtuner/v1/rl/advantage/passk.py @@ -1,14 +1,10 @@ -from typing import TYPE_CHECKING +from typing import Any import torch from xtuner.v1.rl.advantage.base import AdvantageEstimator -if TYPE_CHECKING: - from xtuner.v1.data_proto.rl_data import RLDataFlowItem - - class PassKEstimator(AdvantageEstimator): """Pass@k Training for Adaptively Balancing Exploration and Exploitation of Large Reasoning Models. 
https://arxiv.org/pdf/2508.10751 @@ -43,7 +39,7 @@ def _comb(self, n: int, r: int) -> float: return float(comb(n, r)) - def compute(self, rewards: torch.Tensor, group: list["RLDataFlowItem"]) -> torch.Tensor: + def compute(self, rewards: torch.Tensor, group: list[Any]) -> torch.Tensor: import numpy as np n = len(rewards) diff --git a/xtuner/v1/rl/advantage/rloo.py b/xtuner/v1/rl/advantage/rloo.py index 9d8ed83975..0d2657416a 100644 --- a/xtuner/v1/rl/advantage/rloo.py +++ b/xtuner/v1/rl/advantage/rloo.py @@ -1,14 +1,10 @@ -from typing import TYPE_CHECKING +from typing import Any import torch from xtuner.v1.rl.advantage.base import AdvantageEstimator -if TYPE_CHECKING: - from xtuner.v1.data_proto.rl_data import RLDataFlowItem - - class RLOOEstimator(AdvantageEstimator): """REINFORCE Leave-One-Out (RLOO) advantage estimator. https://arxiv.org/abs/2402.14740 @@ -23,7 +19,7 @@ class RLOOEstimator(AdvantageEstimator): When K=1, returns the raw reward as the advantage. """ - def compute(self, rewards: torch.Tensor, group: list["RLDataFlowItem"]) -> torch.Tensor: + def compute(self, rewards: torch.Tensor, group: list[Any]) -> torch.Tensor: k = len(rewards) if k == 1: return rewards diff --git a/xtuner/v1/rl/agent_loop/__init__.py b/xtuner/v1/rl/agent_loop/__init__.py new file mode 100644 index 0000000000..c3ab93eb50 --- /dev/null +++ b/xtuner/v1/rl/agent_loop/__init__.py @@ -0,0 +1,25 @@ +from .agent_loop import ( + AgentLoop, + AgentLoopActor, + AgentLoopConfig, + AgentLoopSpec, + RayAgentLoop, + RayAgentLoopProxy, + RouterAgentLoop, + get_agent_loop_rollout_ctl, +) +from .single_turn_agent_loop import SingleTurnAgentLoop, SingleTurnAgentLoopConfig + + +__all__ = [ + "AgentLoopConfig", + "SingleTurnAgentLoopConfig", + "AgentLoop", + "AgentLoopSpec", + "AgentLoopActor", + "RouterAgentLoop", + "RayAgentLoop", + "RayAgentLoopProxy", + "SingleTurnAgentLoop", + "get_agent_loop_rollout_ctl", +] diff --git a/xtuner/v1/rl/agent_loop/agent_loop.py b/xtuner/v1/rl/agent_loop/agent_loop.py new file mode 100644 index 0000000000..59cf739c96 --- /dev/null +++ b/xtuner/v1/rl/agent_loop/agent_loop.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from typing import TypeAlias, cast + +from pydantic import BaseModel, ConfigDict, Field, model_validator +from ray.actor import ActorClass, ActorProxy +from ray.util.placement_group import PlacementGroup + +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams +from xtuner.v1.rl.judger import Judger +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.utils import CPUActorLauncher, create_task +from xtuner.v1.utils import get_logger, ray_method +from xtuner.v1.utils.processing_utils import load_processor, load_tokenizer + + +class AgentLoopConfig(ABC, BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + hf_checkpoint: str + sample_params: SampleParams + num_ray_actors: int = Field( + default=0, + ge=0, + description="Number of AgentLoop Ray actor instances. 
0 means local mode.", + ) + num_cpus: float = Field(default=1, gt=0, description="CPU cores required by the AgentLoop actor itself.") + cpu_memory: int = Field(default=1024**3, gt=0, description="CPU memory in bytes required by AgentLoop.") + + @model_validator(mode="after") + def _validate_ray_actor_config(self) -> AgentLoopConfig: + if self.num_ray_actors == 0 and (self.num_cpus != 1 or self.cpu_memory != 1024**3): + logger = get_logger() + logger.warning("num_cpus and cpu_memory are ignored when AgentLoop runs in local mode.") + return self + + def build(self, rollout_controller, judger: Judger | None = None, logger=None) -> AgentLoopSpec: + if self.num_ray_actors == 0: + return self.build_local( + rollout_controller=rollout_controller, + judger=judger, + logger=logger, + ) + if self.num_ray_actors > 1: + return self._build_router( + rollout_controller=rollout_controller, + judger=judger, + logger=logger, + ) + return self._build_ray_actor( + rollout_controller=rollout_controller, + judger=judger, + logger=logger, + ) + + @abstractmethod + def build_local( + self, + rollout_controller, + judger: Judger | None = None, + logger=None, + ) -> AgentLoop: ... + + def _build_ray_actor( + self, + rollout_controller: RolloutController, + pg: PlacementGroup | None = None, + judger: Judger | None = None, + logger=None, + ) -> RayAgentLoopProxy: + return cast( + "RayAgentLoopProxy", + CPUActorLauncher.build_actor( + AgentLoopActor, + self, + rollout_controller, + judger, + logger, + pg=pg, + bundle_idx=0, + actor_num_cpus=self.num_cpus, + actor_memory=self.cpu_memory, + capture_child_tasks=True, + ), + ) + + def _build_ray_actors( + self, + rollout_controller: RolloutController, + num_actors: int, + pg: PlacementGroup | None = None, + judger: Judger | None = None, + logger=None, + start_bundle_idx: int = 0, + ) -> list[RayAgentLoopProxy]: + return cast( + list["RayAgentLoopProxy"], + CPUActorLauncher.build_actors( + AgentLoopActor, + self, + rollout_controller, + judger, + logger, + pg=pg, + start_bundle_idx=start_bundle_idx, + num_workers=num_actors, + actor_num_cpus_per_worker=self.num_cpus, + actor_memory_per_worker=self.cpu_memory, + capture_child_tasks=True, + ), + ) + + def _build_router( + self, + rollout_controller: RolloutController, + pg: PlacementGroup | None = None, + judger: Judger | None = None, + logger=None, + start_bundle_idx: int = 0, + ) -> RouterAgentLoop: + return RouterAgentLoop( + workers=self._build_ray_actors( + rollout_controller=rollout_controller, + num_actors=self.num_ray_actors, + pg=pg, + judger=judger, + logger=logger, + start_bundle_idx=start_bundle_idx, + ), + rollout_ctl=rollout_controller, + ) + + +class AgentLoop(ABC): + def __init__( + self, + rollout_ctl: RolloutController, + sample_params: SampleParams, + hf_checkpoint: str, + judger: Judger | None = None, + logger=None, + ) -> None: + self.rollout_ctl = rollout_ctl + self.hf_checkpoint = hf_checkpoint + self.tokenizer = load_tokenizer(hf_checkpoint, trust_remote_code=True) + self.processor = load_processor(hf_checkpoint, trust_remote_code=True) + self.sample_params = sample_params + self.judger = judger + if logger is None: + self.logger = get_logger() + else: + self.logger = logger + + @abstractmethod + async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: ... 
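+
+    # A minimal subclass sketch (hypothetical, illustration only): implementations fill
+    # rollout_state.tokens, call the rollout controller, and return the mutated state.
+    #
+    #     class EchoAgentLoop(AgentLoop):
+    #         async def generate_sample(self, rollout_state, **kwargs):
+    #             rollout_state.tokens = rollout_state.prompt_ids
+    #             return await self.rollout_ctl.generate.remote(rollout_state)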
+
+    async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
+        pending_tasks = []
+        for state in rollout_state:
+            state.sample_params = self.sample_params
+            task = create_task(self.generate_sample(state, **kwargs))
+            pending_tasks.append(task)
+        return await asyncio.gather(*pending_tasks)
+
+
+class RouterAgentLoop:
+    def __init__(self, workers: list[RayAgentLoopProxy], rollout_ctl: RolloutController):
+        self.workers = workers
+        self.rollout_ctl = rollout_ctl
+        self._worker_loads = dict.fromkeys(workers, 0)
+        self._rr_index = 0
+        self._lock = asyncio.Lock()
+
+    async def _pick_worker(self) -> RayAgentLoopProxy:
+        async with self._lock:
+            min_load = min(self._worker_loads.values())
+            candidates = [worker for worker in self.workers if self._worker_loads[worker] == min_load]
+            worker = candidates[self._rr_index % len(candidates)]
+            self._rr_index = (self._rr_index + 1) % len(self.workers)
+            self._worker_loads[worker] += 1
+            return worker
+
+    async def _release_worker(self, worker: RayAgentLoopProxy) -> None:
+        async with self._lock:
+            self._worker_loads[worker] -= 1
+
+    async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
+        worker = await self._pick_worker()
+        try:
+            return await worker.generate_sample.remote(rollout_state, **kwargs)
+        finally:
+            await self._release_worker(worker)
+
+    async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
+        worker = await self._pick_worker()
+        try:
+            return await worker.generate_group.remote(rollout_state, **kwargs)
+        finally:
+            await self._release_worker(worker)
+
+    def get_worker_status(self) -> dict[str, int]:
+        return {str(worker): load for worker, load in self._worker_loads.items()}
+
+
+async def get_agent_loop_rollout_ctl(agent_loop: AgentLoopSpec) -> RolloutController:
+    rollout_ctl = getattr(agent_loop, "rollout_ctl", None)
+    if rollout_ctl is not None:
+        return rollout_ctl
+
+    get_rollout_ctl = getattr(agent_loop, "get_rollout_ctl", None)
+    if get_rollout_ctl is None or not hasattr(get_rollout_ctl, "remote"):
+        raise AttributeError(f"Agent loop {type(agent_loop)} does not expose rollout_ctl or get_rollout_ctl().")
+    return await get_rollout_ctl.remote()
+
+
+class AgentLoopActor:
+    def __init__(
+        self,
+        agent_loop_config: AgentLoopConfig,
+        rollout_controller: RolloutController,
+        judger: Judger | None = None,
+        logger=None,
+    ):
+        self.agent_loop = agent_loop_config.build_local(
+            rollout_controller=rollout_controller,
+            judger=judger,
+            logger=logger,
+        )
+
+    @ray_method
+    async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
+        return await self.agent_loop.generate_sample(rollout_state, **kwargs)
+
+    @ray_method
+    async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
+        return await self.agent_loop.generate_group(rollout_state, **kwargs)
+
+    @ray_method
+    async def get_rollout_ctl(self):
+        return self.agent_loop.rollout_ctl
+
+
+RayAgentLoop = cast(ActorClass[AgentLoopActor], CPUActorLauncher.to_actor_class(AgentLoopActor))
+RayAgentLoopProxy: TypeAlias = ActorProxy[AgentLoopActor]
+AgentLoopSpec: TypeAlias = AgentLoop | RayAgentLoopProxy | RouterAgentLoop
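+
+# Routing sketch: with worker loads {A: 2, B: 1, C: 1}, RouterAgentLoop._pick_worker selects
+# among the min-load set {B, C} and alternates between them via the round-robin index, so
+# equal-load workers share traffic evenly until the finally blocks release them.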
diff --git a/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py b/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py
new file mode 100644
index 0000000000..feb1a7c9ce
--- /dev/null
+++ b/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py
@@ -0,0 +1,157 @@
+import copy
+import json
+import re
+from typing import cast
+
+from pydantic import BaseModel, ConfigDict
+
+from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams
+from xtuner.v1.rl.agent_loop import AgentLoop, AgentLoopConfig
+from xtuner.v1.rl.judger import Judger
+from xtuner.v1.rl.rollout import RolloutController
+from xtuner.v1.utils import get_logger
+
+
+logger = get_logger()
+
+
+class GSM8KToolAgentLoopConfig(AgentLoopConfig):
+    max_turns: int
+
+    def build_local(self, rollout_controller, judger: Judger | None = None, logger=None) -> "GSM8KToolAgentLoop":
+        return GSM8KToolAgentLoop(
+            max_turns=self.max_turns,
+            rollout_ctl=rollout_controller,
+            hf_checkpoint=self.hf_checkpoint,
+            sample_params=self.sample_params,
+            judger=judger,
+        )
+
+
+class FunctionCall(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    name: str
+    arguments: dict
+
+
+class GSM8KToolAgentLoop(AgentLoop):
+    def __init__(
+        self,
+        max_turns: int,
+        rollout_ctl: RolloutController,
+        hf_checkpoint: str,
+        sample_params: SampleParams,
+        judger: Judger | None = None,
+    ):
+        super().__init__(
+            rollout_ctl=rollout_ctl, hf_checkpoint=hf_checkpoint, sample_params=sample_params, judger=judger
+        )
+        self.max_turns = max_turns
+        self.tool_call_pattern = re.compile(r"<tool_call>\n*(.*?)</tool_call>", re.DOTALL)
+        self.tool_call_start_token: str = "<tool_call>"
+        self.tool_call_end_token: str = "</tool_call>"
+
+    def calc_gsm8k_reward(self, answer: dict, ground_truth: str) -> float:
+        from xtuner.v1.rl.judger.gsm8k import compute_reward
+
+        extra_info = {"score": 1.0, "format_score": 0}
+        actual_answer = answer.get("answer", "")
+        if not actual_answer.startswith("#### "):
+            actual_answer = "#### " + actual_answer
+        return compute_reward(actual_answer, ground_truth, extra_info)
+
+    def extract_tool_calls(self, rollout_state: RolloutState) -> tuple[str, list[FunctionCall]]:
+        text = self.tokenizer.decode(rollout_state.response_ids)
+        if self.tool_call_start_token not in text or self.tool_call_end_token not in text:
+            return text, []
+
+        matches = self.tool_call_pattern.findall(text)
+        function_calls = []
+        for match in matches:
+            try:
+                function_call = json.loads(match)
+                name, arguments = function_call["name"], function_call["arguments"]
+                function_calls.append(FunctionCall(name=name, arguments=arguments))
+            except Exception as e:
+                logger.error(f"Error parsing tool call JSON: {e}")
+                continue
+
+        content = self.tool_call_pattern.sub("", text)
+        return content, function_calls
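+
+    # Example: for a response such as
+    #     '<tool_call>{"name": "calc_gsm8k_reward", "arguments": {"answer": "42"}}</tool_call>'
+    # extract_tool_calls returns the text with the tagged span removed plus
+    # [FunctionCall(name="calc_gsm8k_reward", arguments={"answer": "42"})] ("42" is illustrative).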
+
+    async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
+        # Respect state passed from preprocess for partial rollout continuation.
+        base_sample_params = copy.deepcopy(rollout_state.sample_params or self.sample_params)
+        final_response_mask: list[int] = []
+        final_response_ids: list[int] = []
+        final_logprobs: list[float] = []
+
+        max_len = base_sample_params.max_tokens
+        cur_turn_tokens = list(rollout_state.tokens or rollout_state.prompt_ids or [])
+        remaining_max_tokens = max_len - len(final_response_ids)
+        cur_turn = 0
+        while True:
+            if cur_turn >= self.max_turns or len(final_response_ids) >= max_len or remaining_max_tokens <= 0:
+                break
+
+            rollout_state.tokens = cur_turn_tokens
+            rollout_state.sample_params = copy.deepcopy(base_sample_params)
+            rollout_state.sample_params.max_tokens = remaining_max_tokens
+
+            rollout_state = await self.rollout_ctl.generate.remote(rollout_state)  # type: ignore[attr-defined]
+            cur_turn += 1
+            response_ids = cast(list[int], rollout_state.response_ids)
+            cur_turn_tokens.extend(response_ids)
+
+            # Fold the rollout_controller output into the accumulated response
+            final_response_ids.extend(response_ids)
+            final_logprobs.extend(cast(list[float], rollout_state.logprobs))
+            final_response_mask.extend([1] * len(response_ids))
+            # TODO: handle routed_experts; mind whether the ObjectRef needs dereferencing here
+
+            if len(final_response_ids) >= max_len:
+                break
+
+            _, function_calls = self.extract_tool_calls(rollout_state)
+            if not function_calls:
+                break
+
+            tool_messages = []
+            for function_call in function_calls:
+                tool_name = function_call.name
+                tool_args = function_call.arguments
+                if tool_name == "calc_gsm8k_reward":
+                    answer = tool_args
+                    ground_truth = cast(dict, rollout_state.reward_model).get("ground_truth", "")
+                    function_results = self.calc_gsm8k_reward(answer, ground_truth)
+                    tool_message = {
+                        "role": "tool",
+                        "content": json.dumps({"result": function_results}, ensure_ascii=False),
+                    }
+                    tool_messages.append(tool_message)
+
+            # Append the tool-call outputs to the response
+            tools_response_ids = self.tokenizer.apply_chat_template(tool_messages, remove_system_prompt=True)
+            final_response_ids.extend(tools_response_ids)
+            final_logprobs.extend([0.0] * len(tools_response_ids))
+            final_response_mask.extend([0] * len(tools_response_ids))
+
+            # Build the input tokens for the next generation turn
+            cur_turn_tokens.extend(tools_response_ids)
+            remaining_max_tokens = max_len - len(final_response_ids)
+
+        final_response_ids = final_response_ids[:max_len]
+        final_response_mask = final_response_mask[:max_len]
+        final_logprobs = final_logprobs[:max_len]
+
+        rollout_state.response_ids = final_response_ids
+        rollout_state.response_mask = final_response_mask
+        rollout_state.logprobs = final_logprobs
+        rollout_state.response = self.tokenizer.decode(rollout_state.response_ids)
+        assert len(rollout_state.response_ids) == len(rollout_state.response_mask) == len(rollout_state.logprobs), (
+            f"{len(rollout_state.response_ids)} vs {len(rollout_state.response_mask)} vs {len(rollout_state.logprobs)}"
+        )
+        if self.judger is not None:
+            rollout_state = await self.judger.judge(rollout_state)
+        return rollout_state
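+
+# Budget sketch: with max_tokens=1024, a 200-token model turn followed by a 50-token tool
+# message leaves remaining_max_tokens = 1024 - 250 = 774 for the next turn; response_mask
+# marks model tokens 1 and tool tokens 0 so the trainer can exclude tool text from the loss.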
logger=None) -> "SingleTurnAgentLoop": + return SingleTurnAgentLoop( + rollout_ctl=rollout_controller, + sample_params=self.sample_params, + hf_checkpoint=self.hf_checkpoint, + judger=judger, + logger=logger, + ) + + +class SingleTurnAgentLoop(AgentLoop): + def __init__( + self, + rollout_ctl: RolloutController, + sample_params: SampleParams, + hf_checkpoint: str, + judger: Judger | None = None, + logger=None, + ): + super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger) + self.max_tokens = self.sample_params.max_tokens + self.partial_rollout_handler = PartialRolloutHandler(max_tokens=self.max_tokens) + + async def generate_sample( + self, + rollout_state: RolloutState, + **kwargs, + ) -> RolloutState: + enable_partial_rollout = kwargs.get("enable_partial_rollout", False) + + # rollout state 预处理, enable_partial_rollout = True 会在这里拼接 token 和修正 max_token + rollout_state = self.partial_rollout_handler.preprocess(rollout_state, enable_partial_rollout) + if not rollout_state.tokens: + rollout_state.tokens = rollout_state.prompt_ids + + # 推理引擎generate, 生成的结果会覆盖到 rollout_state.response_ids 上 + rollout_state = await self.rollout_ctl.generate.remote(rollout_state) # type: ignore[attr-defined] + # rollout state 后处理: 合并 partial rollout 的历史上下文 + rollout_state = self.partial_rollout_handler.postprocess(rollout_state) + # 非 COMPLETED 状态(如被截断、放弃等)直接早退,不触发打分 + if rollout_state.status != Status.COMPLETED: + return rollout_state + if self.judger is not None: + rollout_state = await self.judger.judge(rollout_state) + return rollout_state diff --git a/xtuner/v1/rl/agent_loop/utils.py b/xtuner/v1/rl/agent_loop/utils.py new file mode 100644 index 0000000000..9d88bd8244 --- /dev/null +++ b/xtuner/v1/rl/agent_loop/utils.py @@ -0,0 +1,117 @@ +import time + +import ray + +from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.rl.utils import clear_rollout_response_for_rerun, free_object_refs +from xtuner.v1.utils import get_logger + + +logger = get_logger() + + +def _resolve_routed_experts(routed_experts: list[int] | ray.ObjectRef) -> list[int]: + if isinstance(routed_experts, ray.ObjectRef): + routed_experts = ray.get(routed_experts) + if hasattr(routed_experts, "tolist"): + routed_experts = routed_experts.tolist() + assert isinstance(routed_experts, list), f"Unexpected routed_experts type: {type(routed_experts)}" + return routed_experts + + +class PartialRolloutHandler: + """Handle preprocessing and postprocessing for partial rollout + continuation.""" + + def __init__(self, max_tokens: int) -> None: + self.max_tokens = max_tokens + + def preprocess(self, rollout_state: RolloutState, enable_partial_rollout: bool = False) -> RolloutState: + if rollout_state.status == Status.EXPIRED or ( + not enable_partial_rollout and rollout_state.status == Status.ABORTED + ): + rollout_state = clear_rollout_response_for_rerun(rollout_state) + rollout_state.sample_params = rollout_state.sample_params.model_copy( + update={"max_tokens": self.max_tokens} + ) + rollout_state.response = "" + rollout_state.status = Status.INIT + + if not rollout_state.response_ids or rollout_state.status == Status.COMPLETED: + return rollout_state + + # Set up token and length variable + response_ids = rollout_state.response_ids + prompt_ids = list(rollout_state.prompt_ids or []) + response_len = len(response_ids) + prompt_len = len(prompt_ids) + + rollout_state.tokens = prompt_ids + response_ids # concatenate for partial rollout continuation + remaining_tokens = self.max_tokens - response_len # compute 
+
+    def postprocess(self, rollout_state: RolloutState) -> RolloutState:
+        # TODO: if not enable partial rollout, return directly?
+
+        # Concatenate history response fields
+        history_dict = rollout_state.extra_fields.pop("history_response_dict", None)
+        if not history_dict:
+            return rollout_state
+
+        rollout_state.response_ids = history_dict.get("response_ids", []) + (rollout_state.response_ids or [])
+        rollout_state.response = history_dict.get("response", "") + (rollout_state.response or "")
+        rollout_state.logprobs = history_dict.get("logprobs", []) + (rollout_state.logprobs or [])
+        rollout_state.response_mask = history_dict.get("response_mask", []) + (rollout_state.response_mask or [])
+        history_routed_experts_ref = history_dict.get("routed_experts")
+        cur_routed_experts_ref = rollout_state.routed_experts
+        if history_routed_experts_ref is not None and cur_routed_experts_ref is not None:
+            start_time = time.time()
+            history_routed_experts = _resolve_routed_experts(history_routed_experts_ref)
+            cur_routed_experts = _resolve_routed_experts(cur_routed_experts_ref)
+            cur_routed_experts_len = len(cur_routed_experts)
+            history_routed_experts_len = len(history_routed_experts)
+            assert history_routed_experts_len - 1 <= cur_routed_experts_len, (
+                f"Existing routed_experts len: {history_routed_experts_len}, current routed_experts len: {cur_routed_experts_len}"
+            )
+            cur_routed_experts = cur_routed_experts[history_routed_experts_len:]
+            concat_routed_experts = history_routed_experts + cur_routed_experts
+
+            prompt_ids = rollout_state.prompt_ids or []
+            response_ids = rollout_state.response_ids or []
+            expect_tokens_num = len(prompt_ids) + len(response_ids) - 1
+            assert len(concat_routed_experts) == expect_tokens_num, (
+                f"After concatenation, routed_experts len: {len(concat_routed_experts)}, expected tokens num: {expect_tokens_num}"
+            )
+            logger.info(
+                f"[PartialRolloutHandler] Postprocess rollout {rollout_state.uid}: "
+                f"concat routed_experts len={len(concat_routed_experts)} "
+                f"(history={history_routed_experts_len}, new={cur_routed_experts_len}), "
+                f"prompt={len(prompt_ids)}, response={len(response_ids)}"
+            )
+            rollout_state.routed_experts = ray.put(concat_routed_experts)
+            free_object_refs(
+                [ref for ref in (history_routed_experts_ref, cur_routed_experts_ref) if isinstance(ref, ray.ObjectRef)]
+            )
+            end_time = time.time()
+            logger.info(
+                f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds"
+            )
+        elif history_routed_experts_ref is None and cur_routed_experts_ref is not None:
+            rollout_state.routed_experts = cur_routed_experts_ref
+        elif history_routed_experts_ref is not None and cur_routed_experts_ref is None:
+            rollout_state.routed_experts = history_routed_experts_ref
+
+        return rollout_state
diff --git a/xtuner/v1/rl/agent_loop_manager/__init__.py b/xtuner/v1/rl/agent_loop_manager/__init__.py
new file mode 100644
index 0000000000..aa51d7f0c8
--- /dev/null
+++ b/xtuner/v1/rl/agent_loop_manager/__init__.py
@@ -0,0 +1,40 @@
+from .agent_loop_manager import (
+    AgentLoopManager,
+    AgentLoopManagerConfig,
+    AgentLoopManagerStatus,
+    ProduceBatchResult,
+    TaskSpecConfig,
+)
+from .producer import (
+    AsyncProduceStrategy,
+    AsyncProduceStrategyConfig,
+    ProduceBatchStatus,
+    ProduceProgress,
+    ProduceStrategy,
+    ProduceStrategyConfig,
+    SyncProduceStrategy,
+    SyncProduceStrategyConfig,
+    calculate_stale_threshold,
+)
+from .sampler import Sampler, SamplerConfig
+
+
+# The manager package only exposes batch scheduling, sampling, and produce strategies;
+# per-sample agent loops stay in the agent_loop package.
+__all__ = [
+    "AgentLoopManagerConfig",
+    "AgentLoopManager",
+    "AgentLoopManagerStatus",
+    "TaskSpecConfig",
+    "ProduceBatchResult",
+    "ProduceStrategyConfig",
+    "SyncProduceStrategyConfig",
+    "AsyncProduceStrategyConfig",
+    "ProduceBatchStatus",
+    "ProduceProgress",
+    "ProduceStrategy",
+    "SyncProduceStrategy",
+    "AsyncProduceStrategy",
+    "calculate_stale_threshold",
+    "SamplerConfig",
+    "Sampler",
+]
diff --git a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py
new file mode 100644
index 0000000000..714c84727c
--- /dev/null
+++ b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py
@@ -0,0 +1,926 @@
+import asyncio
+import json
+import math
+import os
+import time
+from dataclasses import dataclass
+from enum import Enum, auto
+from pathlib import Path
+
+from pydantic import BaseModel, ConfigDict, Field
+
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from xtuner.v1.data_proto.rl_data import RolloutState, Status
+from xtuner.v1.rl.agent_loop import AgentLoopConfig, AgentLoopSpec, get_agent_loop_rollout_ctl
+from xtuner.v1.rl.judger import ComposedJudgerConfig, JudgerConfig, build_judger
+from xtuner.v1.rl.replay_buffer import ReplayBuffer
+from xtuner.v1.rl.rollout import RolloutController, continue_generation, pause_generation
+from xtuner.v1.rl.utils import asyncio_run
+from xtuner.v1.utils import get_logger
+
+from .producer import (
+    GROUP_GENERATE_TIME_KEY,
+    AsyncProduceStrategy,
+    ProduceBatchStatus,
+    ProduceProgress,
+    ProduceStrategy,
+    ProduceStrategyConfig,
+    SyncProduceStrategyConfig,
+)
+from .sampler import Sampler, SamplerConfig
+
+
+@dataclass
+class ProduceBatchResult:
+    """Result of a single ``produce_batch`` call.
+
+    Args:
+        rollout_states (list[list[RolloutState]]): Completed rollout groups retrieved from the replay buffer for training.
+        group_gen_count (int | None): Number of generate-group calls finished in this batch (None if no generations ran).
+        group_gen_mean_s (float | None): Mean wall-clock time per generate-group call, in seconds.
+        group_gen_p50_s (float | None): Median (p50) generate-group time, in seconds.
+        group_gen_p99_s (float | None): 99th percentile generate-group time, in seconds.
+        group_gen_p99_p50_ratio (float | None): Ratio of p99 to p50, indicating tail-latency skew.
+        group_gen_pause_time_s (float | None): Time spent in the pause/cleanup phase (async strategy only), in seconds.
+        leftover_completed (int): Number of completed groups remaining in the replay buffer after this batch.
+        leftover_aborted (int): Number of aborted groups remaining in the replay buffer.
+        leftover_expired (int): Number of expired groups remaining in the replay buffer.
+    """
+
+    rollout_states: list[list[RolloutState]]
+    status: ProduceBatchStatus = ProduceBatchStatus.NORMAL
+    # per-group generation timing stats (all None if no generations occurred)
+    group_gen_count: int | None = None
+    group_gen_mean_s: float | None = None
+    group_gen_p50_s: float | None = None
+    group_gen_p99_s: float | None = None
+    group_gen_p99_p50_ratio: float | None = None
+    group_gen_pause_time_s: float | None = None
+    # leftover samples remaining in replay buffer after batch retrieval
+    leftover_completed: int = 0
+    leftover_aborted: int = 0
+    leftover_expired: int = 0
+    task_batch_sizes: dict[str, int] | None = None
+    task_results: dict[str, "ProduceBatchResult"] | None = None
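+
+# Consumption sketch (hypothetical call site): a trainer typically iterates
+# result.rollout_states group by group for training, and logs
+# result.group_gen_p99_p50_ratio to watch tail-latency skew across generate-group calls.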
+
+
+@dataclass(frozen=True)
+class _TaskRunner:
+    task_name: str
+    agent_loop: AgentLoopSpec
+    produce_strategy: ProduceStrategy
+    sampler: Sampler
+    weight: float = 1.0
+    order: int = 0
+
+
+class _TaskSamplerView:
+    def __init__(self, samplers: list[Sampler]):
+        self._samplers = samplers
+
+    def __len__(self) -> int:
+        return sum(len(sampler) for sampler in self._samplers)
+
+
+class AgentLoopManagerStatus(Enum):
+    """Global status of the AgentLoopManager.
+
+    Transitions follow this path:
+    - The initial state is NORMAL
+    - NORMAL -> UPDATE_ABORT
+      - triggered before the trainer starts a weight sync
+    - UPDATE_ABORT -> NORMAL
+      - continue_produce() is called after the weight sync completes
+    - NORMAL -> EXPIRED_BATCH
+      - the current rollout model has become stale
+    - EXPIRED_BATCH -> UPDATE_ABORT
+      - after detecting the expiry, the trainer enters the weight-sync phase
+    - any state -> FINISH
+      - training has finished
+
+    One important distinction:
+    - AgentLoopManagerStatus is the global run state of the background producer
+    - ProduceBatchStatus is the local result of a single scheduling call
+    """
+
+    NORMAL = auto()
+    UPDATE_ABORT = auto()
+    EXPIRED_BATCH = auto()
+    FINISH = auto()
+
+
+def _fill_produce_timing_stats(
+    result: ProduceBatchResult, generate_times_s: list[float], pause_time_s: float = 0.0
+) -> None:
+    if not generate_times_s:
+        if pause_time_s > 0:
+            result.group_gen_pause_time_s = pause_time_s
+        return
+    sorted_times = sorted(generate_times_s)
+    n = len(sorted_times)
+    mean_s = sum(sorted_times) / n
+    p50_s = sorted_times[n // 2]
+    p99_s = sorted_times[int(n * 0.99)]
+    ratio = p99_s / p50_s if p50_s > 0 else float("inf")
+    result.group_gen_count = n
+    result.group_gen_mean_s = mean_s
+    result.group_gen_p50_s = p50_s
+    result.group_gen_p99_s = p99_s
+    result.group_gen_p99_p50_ratio = ratio
+    result.group_gen_pause_time_s = pause_time_s
+
+
+def _fill_group_timing_stats(
+    result: ProduceBatchResult, rollout_states: list[list[RolloutState]], pause_time_s: float = 0.0
+) -> None:
+    generate_times: list[float] = []
+    for group in rollout_states:
+        if not group:
+            continue
+        group_time = getattr(group[0], "extra_fields", {}).get(GROUP_GENERATE_TIME_KEY)
+        if group_time is not None:
+            generate_times.append(group_time)
+
+    _fill_produce_timing_stats(result, generate_times, pause_time_s=pause_time_s)
+
+
+def _aggregate_status(statuses: list[ProduceBatchStatus]) -> ProduceBatchStatus:
+    if any(status == ProduceBatchStatus.UPDATE_ABORT for status in statuses):
+        return ProduceBatchStatus.UPDATE_ABORT
+    if any(status == ProduceBatchStatus.EXPIRED_BATCH for status in statuses):
+        return ProduceBatchStatus.EXPIRED_BATCH
+    return ProduceBatchStatus.NORMAL
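+
+# Example: statuses [NORMAL, EXPIRED_BATCH, UPDATE_ABORT] aggregate to UPDATE_ABORT;
+# UPDATE_ABORT takes precedence over EXPIRED_BATCH, which takes precedence over NORMAL.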
+
+
+async def _produce_single_task_to_buffer(
+    task_runner: _TaskRunner,
+    replay_buffer: ReplayBuffer,
+    batch_size: int,
+    train_step: int,
+    model_step: int,
+    update_event: asyncio.Event | None,
+    progress: ProduceProgress,
+    target_cumulative: int | None = None,
+) -> ProduceBatchStatus:
+    return await task_runner.produce_strategy.produce_batch(
+        task_runner.agent_loop,
+        task_runner.sampler,
+        replay_buffer,
+        batch_size,
+        task_runner.task_name,
+        train_step=train_step,
+        model_step=model_step,
+        update_event=update_event,
+        progress=progress,
+        target_cumulative=target_cumulative,
+    )
+
+
+class TaskSpecConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+
+    task_name: str
+    weight: float = Field(default=1.0, ge=0.0)
+    agent_loop_config: AgentLoopConfig
+    judger_config: JudgerConfig | ComposedJudgerConfig | None = None
+    produce_strategy_config: ProduceStrategyConfig = SyncProduceStrategyConfig()
+    sampler_config: SamplerConfig
+
+
+class AgentLoopManagerConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+
+    tasks: list[TaskSpecConfig] | TaskSpecConfig
+
+    def build(
+        self,
+        rollout_controller: RolloutController,
+        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
+        replay_buffer: ReplayBuffer,
+        logger=None,
+        sync_weights_interval: int = 1,
+    ) -> "AgentLoopManager":
+        tasks = self.tasks if isinstance(self.tasks, list) else [self.tasks]
+        if not tasks:
+            raise ValueError("AgentLoopManagerConfig requires at least one task config.")
+
+        seen_task_names: set[str] = set()
+        task_runners: list[_TaskRunner] = []
+        for order, task_cfg in enumerate(tasks):
+            if task_cfg.task_name in seen_task_names:
+                raise ValueError(f"Duplicate task_name found in AgentLoopManagerConfig: {task_cfg.task_name}")
+            seen_task_names.add(task_cfg.task_name)
+
+            agent_loop = task_cfg.agent_loop_config.build(
+                rollout_controller=rollout_controller,
+                judger=build_judger(task_cfg.judger_config) if task_cfg.judger_config is not None else None,
+                logger=logger,
+            )
+            produce_strategy = task_cfg.produce_strategy_config.build(sync_weights_interval=sync_weights_interval)
+            sampler = task_cfg.sampler_config.build(tokenizer=tokenizer, replay_buffer=replay_buffer)
+            task_runners.append(
+                _TaskRunner(
+                    task_name=task_cfg.task_name,
+                    agent_loop=agent_loop,
+                    produce_strategy=produce_strategy,
+                    sampler=sampler,
+                    weight=task_cfg.weight,
+                    order=order,
+                )
+            )
+
+        return AgentLoopManager(
+            task_runners=task_runners,
+            replay_buffer=replay_buffer,
+            logger=logger,
+        )
+
+
+class AgentLoopManager:
+    _TASK_CHECKPOINT_DIR = "tasks"
+    _MANAGER_STATE_PATH = "agent_loop_manager_state.json"
+    _STATUS_POLL_INTERVAL_S = 1.0
+
+    def __init__(
+        self,
+        task_runners: list[_TaskRunner],
+        replay_buffer: ReplayBuffer,
+        logger=None,
+    ):
+        if not task_runners:
+            raise ValueError("AgentLoopManager requires at least one task runner.")
+        if sum(task.weight for task in task_runners) <= 0:
+            raise ValueError("At least one task weight must be positive for AgentLoopManager.")
+
+        self.task_runners = task_runners
+        self.replay_buffer = replay_buffer
+        self.data_sampler = (
+            task_runners[0].sampler
+            if len(task_runners) == 1
+            else _TaskSamplerView([task.sampler for task in task_runners])
+        )
+        self.name = task_runners[0].task_name if len(task_runners) == 1 else "multi_task"
+        if logger is None:
+            self.logger = get_logger()
+        else:
+            self.logger = logger
+
+        # Concurrency-control signal for the non-colocated setup: the consumer sets it before
+        # syncing weights; producers / strategies should observe the event state directly and
+        # stop issuing new rollouts as soon as possible. Do not substitute an extra boolean
+        # snapshot for this event.
+        self._update_event = asyncio.Event()
+
+        self._finish_event = asyncio.Event()
+
+        # model_step read by the non-colocated producer: which train_step's synced model the
+        # rollout side is currently serving. The consumer updates it via continue_produce after
+        # finishing a weight sync; already-scheduled pending tasks must bind the model_step
+        # captured at launch time inside the strategy, not read the latest value when the task
+        # completes.
+class AgentLoopManager:
+    _TASK_CHECKPOINT_DIR = "tasks"
+    _MANAGER_STATE_PATH = "agent_loop_manager_state.json"
+    _STATUS_POLL_INTERVAL_S = 1.0
+
+    def __init__(
+        self,
+        task_runners: list[_TaskRunner],
+        replay_buffer: ReplayBuffer,
+        logger=None,
+    ):
+        if not task_runners:
+            raise ValueError("AgentLoopManager requires at least one task runner.")
+        if sum(task.weight for task in task_runners) <= 0:
+            raise ValueError("At least one task weight must be positive for AgentLoopManager.")
+
+        self.task_runners = task_runners
+        self.replay_buffer = replay_buffer
+        self.data_sampler = (
+            task_runners[0].sampler
+            if len(task_runners) == 1
+            else _TaskSamplerView([task.sampler for task in task_runners])
+        )
+        self.name = task_runners[0].task_name if len(task_runners) == 1 else "multi_task"
+        if logger is None:
+            self.logger = get_logger()
+        else:
+            self.logger = logger
+
+        # Disaggregated concurrency-control signal: the consumer sets it before syncing
+        # weights; producers / strategies should observe the event state directly and stop
+        # scheduling new rollouts as soon as possible. Do not substitute an extra boolean
+        # snapshot for this event.
+        self._update_event = asyncio.Event()
+
+        self._finish_event = asyncio.Event()
+
+        # model_step read by the disaggregated producer: which train_step's synced model
+        # the rollout side is currently serving. The consumer updates it via
+        # continue_produce after finishing a weight sync; already-scheduled pending tasks
+        # must bind the model_step at launch time inside the strategy instead of
+        # re-reading the latest value when the task completes.
+        self._model_step = 0
+
+        # Control state shared by the disaggregated producer / consumer. produce_loop /
+        # get_batch should read self._status directly rather than caching a local snapshot
+        # across awaits, so sync, expiry, and finish transitions are never missed.
+        self._status = AgentLoopManagerStatus.NORMAL
+
+        # Timing metric written by pause_produce and read (then reset) by the next
+        # get-batch call. Used only for consumer-side logging / metrics; reads and writes
+        # carry no production-correctness dependency.
+        self._pause_time_s = 0.0
+
+        # Absolute cumulative progress shared by the disaggregated producer / consumer.
+        # The object reference must stay stable: the consumer updates fields in place, and
+        # producers / strategies read progress.xxx directly when they need a value instead
+        # of copying fields into local snapshots used across awaits.
+        self._produce_progress = ProduceProgress(
+            next_consumer_step=1,
+            producer_future_step=1,
+            consumed_samples={task.task_name: 0 for task in self.task_runners},
+            target_samples={task.task_name: 0 for task in self.task_runners},
+            target_upto_future_step=0,
+        )
+
+    def get_task_batch_sizes(self, global_batch_size: int, train_step: int) -> dict[str, int]:
+        """Return the per-task batch sizes for the current train step.
+
+        Subclasses may override this method to implement custom dynamic batch allocation policies. Returning 0 for a
+        task effectively disables that task for the current produce_batch call.
+        """
+        if global_batch_size < 0:
+            raise ValueError(f"global_batch_size must be non-negative, got {global_batch_size}")
+
+        total_weight = sum(task.weight for task in self.task_runners)
+        if total_weight <= 0:
+            raise ValueError("Sum of task weights must be positive.")
+        if global_batch_size == 0:
+            return {task.task_name: 0 for task in self.task_runners}
+
+        raw_allocations = [global_batch_size * task.weight / total_weight for task in self.task_runners]
+        floor_allocations = [math.floor(raw) for raw in raw_allocations]
+        remaining = global_batch_size - sum(floor_allocations)
+
+        task_batch_sizes = {task.task_name: floor_allocations[idx] for idx, task in enumerate(self.task_runners)}
+        if remaining <= 0:
+            return task_batch_sizes
+
+        ranked_tasks = sorted(
+            enumerate(self.task_runners),
+            key=lambda item: (
+                -(raw_allocations[item[0]] - floor_allocations[item[0]]),
+                item[1].order,
+            ),
+        )
+        for idx, task in ranked_tasks[:remaining]:
+            task_batch_sizes[task.task_name] += 1
+        return task_batch_sizes
+
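+    # Worked example of the largest-remainder split above (hypothetical numbers):
+    # two tasks weighted 2.0 and 1.0 with global_batch_size=8 give raw allocations
+    # [5.33, 2.67] -> floors [5, 2]; the one remaining sample goes to the largest
+    # fractional remainder (0.67), so the final split is {"task_a": 5, "task_b": 3}.
+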
+    def _validate_task_batch_sizes(self, task_batch_sizes: dict[str, int], global_batch_size: int) -> None:
+        expected_task_names = {task.task_name for task in self.task_runners}
+        actual_task_names = set(task_batch_sizes.keys())
+        if actual_task_names != expected_task_names:
+            missing_task_names = expected_task_names - actual_task_names
+            extra_task_names = actual_task_names - expected_task_names
+            raise ValueError(
+                "Invalid task batch sizes returned by get_task_batch_sizes: "
+                f"missing={sorted(missing_task_names)}, extra={sorted(extra_task_names)}"
+            )
+
+        negative_batch_sizes = {
+            task_name: task_batch_size
+            for task_name, task_batch_size in task_batch_sizes.items()
+            if task_batch_size < 0
+        }
+        if negative_batch_sizes:
+            raise ValueError(f"Task batch sizes must be non-negative, got {negative_batch_sizes}")
+
+        total_batch_size = sum(task_batch_sizes.values())
+        if total_batch_size != global_batch_size:
+            raise ValueError(
+                "Task batch sizes must sum to the requested global batch size, "
+                f"got total={total_batch_size}, expected={global_batch_size}"
+            )
+
+    def _ensure_target_upto(self, batch_size: int, current_future_step: int) -> None:
+        progress = self._produce_progress
+        if current_future_step <= progress.target_upto_future_step:
+            return
+
+        for future_step in range(progress.target_upto_future_step + 1, current_future_step + 1):
+            if len(self.task_runners) == 1:
+                progress.target_samples[self.task_runners[0].task_name] += batch_size
+            else:
+                task_batch_sizes = self.get_task_batch_sizes(batch_size, future_step)
+                self._validate_task_batch_sizes(task_batch_sizes, batch_size)
+                for task_name, task_batch_size in task_batch_sizes.items():
+                    progress.target_samples[task_name] += task_batch_size
+
+        progress.target_upto_future_step = current_future_step
+
+    def _any_task_model_expired(self, current_future_step: int) -> bool:
+        expired_tasks = [
+            task.task_name
+            for task in self.task_runners
+            if isinstance(task.produce_strategy, AsyncProduceStrategy)
+            and task.produce_strategy.is_model_expired(current_future_step, self._model_step)
+        ]
+        if expired_tasks:
+            self.logger.info(f"Expired future_step={current_future_step}, tasks={expired_tasks}")
+            return True
+        return False
+
+    async def _refresh_for_all_tasks(self, train_step: int, statuses: list[Status]) -> None:
+        XTUNER_REFRESH_STALENESS = os.environ.get("XTUNER_REFRESH_STALENESS", "1") == "1"
+        self.logger.debug(f"[AgentLoopManager][{self.name}] XTUNER_REFRESH_STALENESS={XTUNER_REFRESH_STALENESS}")
+        for task in self.task_runners:
+            if XTUNER_REFRESH_STALENESS:
+                # TODO: synchronous colocate training must take this branch to keep
+                # accuracy correct, but how does its logic differ from the branch below?
+                # Remove this conditional once the root cause is understood.
+                stale_threshold = getattr(task.produce_strategy, "stale_threshold", 1)
+            else:
+                # Synchronous production has no background samples spanning weight
+                # versions; only async strategies need to refresh and expire historical samples.
+                stale_threshold = getattr(task.produce_strategy, "stale_threshold", None)
+            if stale_threshold is None:
+                expired_count = await self.replay_buffer.count(
+                    task_name=task.task_name, group_status=Status.EXPIRED
+                )
+                self.logger.info(
+                    f"[AgentLoopManager][{self.name}] Skip refresh staleness for task {task.task_name}: expired_count={expired_count}"
+                )
+                continue
+
+            expired_count = await self.replay_buffer.refresh_staleness(
+                task_name=task.task_name,
+                current_train_step=train_step,
+                stale_threshold=stale_threshold,
+                statuses=statuses,
+            )
+            self.logger.info(
+                f"[AgentLoopManager][{self.name}] Refresh staleness for task {task.task_name}: expired_count={expired_count}"
+            )
+
+    def _get_task_batch_sizes_for_step(self, batch_size: int, train_step: int) -> dict[str, int]:
+        if len(self.task_runners) == 1:
+            return {self.task_runners[0].task_name: batch_size}
+
+        task_batch_sizes = self.get_task_batch_sizes(batch_size, train_step)
+        self._validate_task_batch_sizes(task_batch_sizes, batch_size)
+        return task_batch_sizes
+
+    def _build_local_produce_progress(
+        self,
+        task_batch_sizes: dict[str, int],
+        train_step: int,
+    ) -> ProduceProgress:
+        return ProduceProgress(
+            next_consumer_step=train_step,
+            producer_future_step=train_step,
+            consumed_samples={task.task_name: 0 for task in self.task_runners},
+            target_samples=dict(task_batch_sizes),
+            target_upto_future_step=train_step,
+        )
+
+    @staticmethod
+    def _aggregate_task_results(
+        ordered_tasks: list[_TaskRunner], task_results: dict[str, ProduceBatchResult]
+    ) -> ProduceBatchResult:
+        rollout_states: list[list[RolloutState]] = []
+        leftover_completed = 0
+        leftover_aborted = 0
+        leftover_expired = 0
+        total_group_count = 0
+        weighted_group_mean_sum = 0.0
+        weighted_group_p50_sum = 0.0
+        weighted_group_p99_sum = 0.0
+        weighted_group_ratio_sum = 0.0
+        total_pause_time_s = 0.0
+
+        for task in ordered_tasks:
+            result = task_results[task.task_name]
+            rollout_states.extend(result.rollout_states)
+            leftover_completed += result.leftover_completed
+            leftover_aborted += result.leftover_aborted
+            leftover_expired += result.leftover_expired
+            if result.group_gen_count is not None and result.group_gen_mean_s is not None:
+                total_group_count += result.group_gen_count
+                weighted_group_mean_sum += result.group_gen_count * result.group_gen_mean_s
+                weighted_group_p50_sum += result.group_gen_count * (result.group_gen_p50_s or 0.0)
+                weighted_group_p99_sum += result.group_gen_count * (result.group_gen_p99_s or 0.0)
+                weighted_group_ratio_sum += result.group_gen_count * (result.group_gen_p99_p50_ratio or 0.0)
+            total_pause_time_s += result.group_gen_pause_time_s or 0.0
+
+        aggregated = ProduceBatchResult(
+            rollout_states=rollout_states,
+            leftover_completed=leftover_completed,
+            leftover_aborted=leftover_aborted,
+            leftover_expired=leftover_expired,
+            task_results={task.task_name: task_results[task.task_name] for task in ordered_tasks},
+        )
+        if total_group_count > 0:
+            aggregated.group_gen_count = total_group_count
+            aggregated.group_gen_mean_s = weighted_group_mean_sum / total_group_count
+            aggregated.group_gen_p50_s = weighted_group_p50_sum / total_group_count
+            aggregated.group_gen_p99_s = weighted_group_p99_sum / total_group_count
+            aggregated.group_gen_p99_p50_ratio = weighted_group_ratio_sum / total_group_count
+        aggregated.group_gen_pause_time_s = total_pause_time_s
+        return aggregated
+
+    async def _produce_batch_to_buffer(
+        self,
+        batch_size: int,
+        progress: ProduceProgress,
+        *,
+        task_batch_sizes: dict[str, int] | None = None,
+    ) -> ProduceBatchStatus:
+        current_future_step = progress.producer_future_step
+        model_step = self._model_step
+        current_sizes = (
+            self._get_task_batch_sizes_for_step(batch_size, current_future_step)
+            if task_batch_sizes is None
+            else task_batch_sizes
+        )
+        self._validate_task_batch_sizes(current_sizes, batch_size)
+
+        if progress is self._produce_progress:
+            # Only the background produce loop uses the global progress, so the cumulative
+            # target must be advanced here; the colocate path passes a one-shot local
+            # progress and must not pollute the global counters.
+            self._ensure_target_upto(batch_size, current_future_step)
+
+        if self._any_task_model_expired(current_future_step):
+            self.logger.info(
+                f"[AgentLoopManager][{self.name}] EXPIRED_BATCH: any task model expired at future_step {current_future_step}"
+            )
+            return ProduceBatchStatus.EXPIRED_BATCH
+
+        if len(self.task_runners) == 1:
+            task = self.task_runners[0]
+            self.logger.info(f"[AgentLoopManager][{self.name}] produce_to_buffer start batch={batch_size}")
+            return await _produce_single_task_to_buffer(
+                task_runner=task,
+                replay_buffer=self.replay_buffer,
+                batch_size=current_sizes[task.task_name],
+                train_step=current_future_step,
+                model_step=model_step,
+                update_event=self._update_event,
+                progress=progress,
+            )
+
+        active_tasks = [task for task in self.task_runners if progress.target_samples[task.task_name] > 0]
+        assert active_tasks, "No active tasks found"
+
+        statuses = await asyncio.gather(
+            *[
+                _produce_single_task_to_buffer(
+                    task_runner=task,
+                    replay_buffer=self.replay_buffer,
+                    batch_size=current_sizes[task.task_name],
+                    train_step=current_future_step,
+                    model_step=model_step,
+                    update_event=self._update_event,
+                    progress=progress,
+                )
+                for task in active_tasks
+            ]
+        )
+        return _aggregate_status(statuses)
+
+    async def pause_produce(
+        self,
+        *,
+        use_global_progress: bool,
+        progress: ProduceProgress | None = None,
+    ) -> float:
+        # This is the producer's explicit brake.
+        #
+        # Design motivation:
+        # - Under the old colocate semantics, one produce_batch() call wound itself down
+        #   naturally once it finished;
+        # - After disaggregation, the producer may keep running in the background, so the
+        #   trainer must control explicitly when it stops.
+        #
+        # The caller therefore has to state explicitly whether the global progress is used:
+        # - use_global_progress=True: the disaggregated background produce loop pauses
+        #   before a weight-sync point;
+        # - use_global_progress=False: the colocate synchronous produce_batch winds down
+        #   the current call, using a local progress.
+        # The returned `pause_time_s` carries no business semantics; it is logging /
+        # diagnostic information for the training side to report when it consumes the
+        # next batch.
+        # The use_global_progress=False mode is resumed at the next produce_batch entry
+        # via continue_produce; the use_global_progress=True mode is resumed explicitly
+        # by the trainer once weight sync and evaluation are done.
+        if use_global_progress:
+            if progress is not None:
+                raise ValueError("progress must not be provided when use_global_progress=True.")
+            pause_progress = self._produce_progress
+        else:
+            if progress is None:
+                raise ValueError("progress must be provided when use_global_progress=False.")
+            pause_progress = progress
+
+        # Once the arguments are validated, raise the manager-level pause signal so any
+        # still-running produce_batch stops scheduling new rollouts.
+        self._update_event.set()
+        self._status = AgentLoopManagerStatus.UPDATE_ABORT
+        pause_time_s = 0.0
+        for task in self.task_runners:
+            pause_time_s += await task.produce_strategy.pause_produce(
+                task.agent_loop,
+                self.replay_buffer,
+                task.task_name,
+                progress=pause_progress,
+            )
+        self._pause_time_s = pause_time_s
+        return pause_time_s
+
+    async def _get_single_task_batch_from_buffer(
+        self,
+        task_runner: _TaskRunner,
+        batch_size: int,
+        train_step: int,
+        consume_progress: ProduceProgress | None = None,
+    ) -> ProduceBatchResult:
+        result = ProduceBatchResult(rollout_states=[])
+        batch_rollout_states: list[list[RolloutState]] = await self.replay_buffer.get(
+            batch_size, task_runner.task_name, Status.COMPLETED
+        )
+        result.rollout_states = batch_rollout_states
+        if consume_progress is not None:
+            # get has already removed the samples from the buffer; update consumed right
+            # away so the producer does not briefly misread the gap.
+            consume_progress.consumed_samples[task_runner.task_name] += len(batch_rollout_states)
+        completed_sample_count, aborted_sample_count, expired_sample_count = await asyncio.gather(
+            self.replay_buffer.count(task_name=task_runner.task_name, group_status=Status.COMPLETED),
+            self.replay_buffer.count(task_name=task_runner.task_name, group_status=Status.ABORTED),
+            self.replay_buffer.count(task_name=task_runner.task_name, group_status=Status.EXPIRED),
+        )
+        result.leftover_completed = completed_sample_count
+        result.leftover_aborted = aborted_sample_count
+        result.leftover_expired = expired_sample_count
+        return result
+
+    async def _get_batch_from_buffer(
+        self,
+        batch_size: int,
+        train_step: int,
+        consume_progress: ProduceProgress | None = None,
+        task_batch_sizes: dict[str, int] | None = None,
+    ) -> ProduceBatchResult:
+        pause_time_s = self._pause_time_s
+        self._pause_time_s = 0.0
+
+        if len(self.task_runners) == 1:
+            task = self.task_runners[0]
+            result = await self._get_single_task_batch_from_buffer(
+                task,
+                batch_size,
+                train_step,
+                consume_progress=consume_progress,
+            )
+            _fill_group_timing_stats(result, result.rollout_states, pause_time_s=pause_time_s)
+            return result
+
+        if task_batch_sizes is None:
+            task_batch_sizes = self._get_task_batch_sizes_for_step(batch_size, train_step)
+        else:
+            self._validate_task_batch_sizes(task_batch_sizes, batch_size)
+        active_tasks = [task for task in self.task_runners if task_batch_sizes[task.task_name] > 0]
+        results = (
+            await asyncio.gather(
+                *[
+                    self._get_single_task_batch_from_buffer(
+                        task,
+                        task_batch_sizes[task.task_name],
+                        train_step,
+                        consume_progress=consume_progress,
+                    )
+                    for task in active_tasks
+                ]
+            )
+            if active_tasks
+            else []
+        )
+
+        task_results = {task.task_name: result for task, result in zip(active_tasks, results)}
+        for task in self.task_runners:
+            if task.task_name not in task_results:
+                task_results[task.task_name] = ProduceBatchResult(rollout_states=[])
+
+        ordered_tasks = sorted(self.task_runners, key=lambda task: (task.task_name, task.order))
+        aggregated = self._aggregate_task_results(ordered_tasks, task_results)
+        aggregated.task_batch_sizes = {task.task_name: task_batch_sizes[task.task_name] for task in ordered_tasks}
+        _fill_group_timing_stats(aggregated, aggregated.rollout_states, pause_time_s=pause_time_s)
+        return aggregated
+
+    async def _is_batch_ready(self, batch_size: int, train_step: int) -> bool:
+        if len(self.task_runners) == 1:
+            task = self.task_runners[0]
+            completed_count = await self.replay_buffer.count(task_name=task.task_name, group_status=Status.COMPLETED)
+            return completed_count >= batch_size
+
+        task_batch_sizes = self._get_task_batch_sizes_for_step(batch_size, train_step)
+        active_tasks = [task for task in self.task_runners if task_batch_sizes[task.task_name] > 0]
+        if not active_tasks:
+            return True
+
+        completed_counts = await asyncio.gather(
+            *[
+                self.replay_buffer.count(task_name=task.task_name, group_status=Status.COMPLETED)
+                for task in active_tasks
+            ]
+        )
+        return all(
+            completed_count >= task_batch_sizes[task.task_name]
+            for task, completed_count in zip(active_tasks, completed_counts)
+        )
+
+    # The semantics of continue_produce are "the producer may resume work".
+
+    def continue_produce(self, model_step: int) -> None:
+        #
+        # It pairs with pause_produce(use_global_progress=True):
+        # - pause_produce(...) is responsible for stopping the producer;
+        # - continue_produce(...) lifts the pause once the sync / evaluation completes.
+        #
+        # `_model_step` is updated here as well: samples the rollout side generates from
+        # now on should be recorded as coming from this weight version.
+        self._status = AgentLoopManagerStatus.NORMAL
+        self._model_step = model_step
+        self._update_event.clear()
+
+    async def _wait_for_status_exit(self, blocked_status: AgentLoopManagerStatus) -> None:
+        while not self._finish_event.is_set() and self._status == blocked_status:
+            await asyncio.sleep(self._STATUS_POLL_INTERVAL_S)
+
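+    # Hypothetical trainer-side sequence around a weight sync (illustrative only;
+    # sync_weights below is a placeholder, not an API of this module):
+    #   await manager.pause_produce(use_global_progress=True)
+    #   sync_weights()                                   # trainer pushes the new weights
+    #   manager.continue_produce(model_step=new_step)    # producers resume on the new version
+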
+    async def produce_batch(
+        self,
+        batch_size: int,
+        train_step: int,
+        *,
+        model_step: int,
+    ) -> ProduceBatchResult:
+        # `produce_batch()` is the synchronous entry point kept for the colocate path.
+        #
+        # The name is unchanged, but internally it is now a three-phase pipeline:
+        # 1. `_produce_batch_to_buffer()` only produces, writing results into the replay buffer
+        # 2. `pause_produce()` explicitly winds down pending rollouts
+        # 3. `_get_batch_from_buffer()` then pulls the training batch back out
+        #
+        # This is also why a non-empty batch is required here:
+        # - under colocate semantics, the call exists to obtain a batch of trainable completed groups
+        # - if legally returning an empty batch plus a special status is needed, use the disagg `get_batch()`
+        if batch_size <= 0:
+            raise ValueError(f"produce_batch expects batch_size > 0, got {batch_size}")
+        start = time.perf_counter()
+        self.logger.info(
+            f"[AgentLoopManager][{self.name}] Start produce_batch: train_step={train_step} model_step={model_step} batch_size={batch_size}"
+        )
+        current_sizes = self._get_task_batch_sizes_for_step(batch_size, train_step)
+        active_tasks = [task for task in self.task_runners if current_sizes[task.task_name] > 0]
+        assert active_tasks, "No active tasks found"
+
+        rollout_ctl = await get_agent_loop_rollout_ctl(active_tasks[0].agent_loop)
+        await continue_generation(rollout_ctl)
+        try:
+            # The colocate path does not reuse the disaggregated paused-producer state
+            # machine. Even if the manager was restored via resume() and is still in
+            # UPDATE_ABORT, produce_batch() should be treated as an independent
+            # synchronous production pass that starts from a clean state.
+            #
+            # On the colocate path, produce_batch() corresponds to the weight version the
+            # rollout workers currently hold.
+            self.continue_produce(model_step=model_step)
+            # Colocate produce_batch is also the consumption entry; refresh the
+            # completed / aborted samples already in the buffer before producing.
+            await self._refresh_for_all_tasks(train_step, [Status.COMPLETED, Status.ABORTED])
+            local_progress = self._build_local_produce_progress(current_sizes, train_step)
+            status = await self._produce_batch_to_buffer(
+                batch_size=batch_size,
+                progress=local_progress,
+                task_batch_sizes=current_sizes,
+            )
+            await self.pause_produce(
+                use_global_progress=False,
+                progress=local_progress,
+            )
+            result = await self._get_batch_from_buffer(
+                batch_size=batch_size,
+                train_step=train_step,
+                consume_progress=local_progress,
+                task_batch_sizes=current_sizes,
+            )
+            result.status = status
+            assert result.rollout_states, (
+                "AgentLoopManager.produce_batch() must return non-empty rollout_states for colocated training. "
+                "Use get_batch() for disaggregated empty/expired reads."
+            )
+        finally:
+            await pause_generation(rollout_ctl)
+
+        self.logger.info(
+            f"[AgentLoopManager][{self.name}] produce_batch done "
+            f"elapsed={time.perf_counter() - start:.3f}, completed_groups={len(result.rollout_states)}"
+        )
+        return result
+
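+    # Hypothetical disaggregated wiring (illustrative only; 32 is an arbitrary batch size):
+    #   producer = asyncio.create_task(manager.produce_loop(batch_size=32))
+    #   result = await manager.get_batch(batch_size=32, train_step=step)
+    # The background task keeps feeding the replay buffer while the trainer consumes.
+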
+    async def produce_loop(self, batch_size: int) -> None:
+        # `produce_loop()` is the background production loop added for the disaggregated
+        # path. batch_size is the target production size per future train_step; the
+        # producer needs it to advance the cumulative target, which is why the parameter
+        # stays on the background production entry instead of being inferred from
+        # get_batch() consumption requests.
+        #
+        # The biggest differences from colocate:
+        # - it does not hand batches back to the trainer directly
+        # - it just keeps feeding samples into the replay buffer
+        # - the trainer consumes asynchronously in the foreground through `get_batch()`
+        #
+        # Its core duty is therefore not to assemble one training batch, but to decide,
+        # from the manager's global state machine, when to keep producing, when to pause
+        # and wait, and when to exit for good.
+        while not self._finish_event.is_set():
+            if self._status == AgentLoopManagerStatus.FINISH:
+                break
+            if self._status == AgentLoopManagerStatus.UPDATE_ABORT:
+                # The trainer has signalled that a weight sync is imminent. The producer
+                # blocks here waiting for continue_produce() instead of resuming on its own.
+                await self._wait_for_status_exit(AgentLoopManagerStatus.UPDATE_ABORT)
+                continue
+            if self._status == AgentLoopManagerStatus.EXPIRED_BATCH:
+                # The current rollout weights are stale. Keep waiting for the trainer to
+                # finish syncing, then resume via continue_produce().
+                await self._wait_for_status_exit(AgentLoopManagerStatus.EXPIRED_BATCH)
+                continue
+
+            rollout_ctl = await get_agent_loop_rollout_ctl(self.task_runners[0].agent_loop)
+            await continue_generation(rollout_ctl)
+            produce_status = await self._produce_batch_to_buffer(
+                batch_size=batch_size,
+                progress=self._produce_progress,
+            )
+
+            if produce_status == ProduceBatchStatus.EXPIRED_BATCH:
+                # Note:
+                # - EXPIRED_BATCH is a stop-right-now signal the producer detects itself during production
+                # - UPDATE_ABORT is set proactively by the trainer via pause_produce() before a sync
+                self._status = AgentLoopManagerStatus.EXPIRED_BATCH
+            elif produce_status == ProduceBatchStatus.NORMAL:
+                # The producer's own train_step only advances when a production round completes normally.
+                self._produce_progress.producer_future_step += 1
+
+            # Yield the event loop proactively so fake strategies / very fast paths do
+            # not busy-spin in tests.
+            await asyncio.sleep(0)
+
+    async def get_batch(self, batch_size: int, train_step: int) -> ProduceBatchResult:
+        # `get_batch()` is the consumption interface the disaggregated path exposes to the trainer.
+        #
+        # By design it splits responsibilities with `produce_batch()`:
+        # - `produce_batch()`: colocate, one call covers produce + wind-down + fetch
+        # - `get_batch()`: disagg, waits until the replay buffer holds the batch the
+        #   current training step needs, then fetches it
+        #
+        # The only legitimate case for returning an empty batch here therefore remains:
+        # - the manager has entered EXPIRED_BATCH, so return an empty batch plus the status signal
+        # - on seeing it, the trainer should skip training and prioritize the weight sync
+        progress = self._produce_progress
+        progress.next_consumer_step = train_step
+        await self._refresh_for_all_tasks(train_step, [Status.COMPLETED, Status.ABORTED])
+
+        while not self._finish_event.is_set():
+            if self._status == AgentLoopManagerStatus.EXPIRED_BATCH:
+                return ProduceBatchResult(
+                    rollout_states=[],
+                    status=ProduceBatchStatus.EXPIRED_BATCH,
+                )
+            # TODO: call self.get_task_batch_sizes once before the while loop instead of inside the two calls below
+            if await self._is_batch_ready(batch_size=batch_size, train_step=train_step):
+                result = await self._get_batch_from_buffer(
+                    batch_size=batch_size,
+                    train_step=train_step,
+                    consume_progress=progress,
+                )
+                if result.rollout_states:
+                    progress.next_consumer_step = train_step + 1
+                    await self._refresh_for_all_tasks(train_step + 1, [Status.COMPLETED, Status.ABORTED])
+                    return result
+            await asyncio.sleep(self._STATUS_POLL_INTERVAL_S)
+
+        return ProduceBatchResult(rollout_states=[])
+
+    def _task_checkpoint_path(self, checkpoint_path: Path | str, task_name: str) -> Path:
+        checkpoint_path = Path(checkpoint_path)
+        return checkpoint_path / self._TASK_CHECKPOINT_DIR / task_name
+
+    def _manager_state_path(self, checkpoint_path: Path | str) -> Path:
+        checkpoint_path = Path(checkpoint_path)
+        return checkpoint_path / self._MANAGER_STATE_PATH
+
+    def _get_pending_task_counts(self) -> dict[str, int]:
+        pending_task_counts: dict[str, int] = {}
+        for task in self.task_runners:
+            pending_tasks = getattr(task.produce_strategy, "_pending_tasks", None)
+            if pending_tasks:
+                pending_task_counts[task.task_name] = len(pending_tasks)
+        return pending_task_counts
+
+    def save(self, checkpoint_path: Path | str, model_step: int) -> None:
+        """Save all task sampler states and the shared replay buffer."""
+        checkpoint_path = Path(checkpoint_path)
+        checkpoint_path.mkdir(parents=True, exist_ok=True)
+        pending_task_counts = self._get_pending_task_counts()
+        if pending_task_counts:
+            raise RuntimeError(
+                "Cannot save AgentLoopManager while pending rollout tasks still exist: "
+                f"{pending_task_counts}. Call pause_produce() first."
+            )
+        # Record the model step this checkpoint corresponds to before saving, so resume
+        # restores exactly this state.
+        self._model_step = model_step
+        for task in self.task_runners:
+            task_checkpoint_path = self._task_checkpoint_path(checkpoint_path, task.task_name)
+            task_checkpoint_path.mkdir(parents=True, exist_ok=True)
+            task.sampler.save(task_checkpoint_path)
+        asyncio_run(self.replay_buffer.save(checkpoint_path))
+        manager_state_path = self._manager_state_path(checkpoint_path)
+        progress = self._produce_progress
+        with manager_state_path.open("w") as f:
+            json.dump(
+                {
+                    "status": self._status.name,
+                    "model_step": self._model_step,
+                    "next_consumer_step": progress.next_consumer_step,
+                    "producer_future_step": progress.producer_future_step,
+                    "consumed_samples": progress.consumed_samples,
+                    "target_samples": progress.target_samples,
+                    "target_upto_future_step": progress.target_upto_future_step,
+                },
+                f,
+            )
+
+    def resume(self, checkpoint_path: Path | str) -> int:
+        """Resume all task sampler states and the shared replay buffer."""
+        checkpoint_path = Path(checkpoint_path)
+        for task in self.task_runners:
+            task.sampler.resume(self._task_checkpoint_path(checkpoint_path, task.task_name))
+        asyncio_run(self.replay_buffer.resume(checkpoint_path))
+
+        manager_state_path = self._manager_state_path(checkpoint_path)
+        with manager_state_path.open("r") as f:
+            manager_state = json.load(f)
+        saved_model_step = manager_state["model_step"]
+        progress = self._produce_progress
+        progress.next_consumer_step = manager_state["next_consumer_step"]
+        progress.producer_future_step = manager_state["producer_future_step"]
+        progress.target_upto_future_step = manager_state["target_upto_future_step"]
+
+        # Update the dicts in place so strategies holding old references stay valid.
+        progress.consumed_samples.clear()
+        progress.consumed_samples.update(manager_state["consumed_samples"])
+        progress.target_samples.clear()
+        progress.target_samples.update(manager_state["target_samples"])
+
+        self._update_event = asyncio.Event()
+        self._finish_event = asyncio.Event()
+        self._update_event.set()
+        self._status = AgentLoopManagerStatus.UPDATE_ABORT
+        self._pause_time_s = 0.0
+        self._model_step = saved_model_step
+        return saved_model_step
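+
+
+# Hypothetical checkpoint round-trip (illustrative only):
+#   manager.save("/ckpt/step_100", model_step=100)    # requires no pending rollouts
+#   ...
+#   model_step = manager.resume("/ckpt/step_100")     # manager wakes up in UPDATE_ABORT
+#   manager.continue_produce(model_step=model_step)   # call once weights are restored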
diff --git a/xtuner/v1/rl/agent_loop_manager/producer.py b/xtuner/v1/rl/agent_loop_manager/producer.py
new file mode 100644
index 0000000000..002cf0b2d9
--- /dev/null
+++ b/xtuner/v1/rl/agent_loop_manager/producer.py
@@ -0,0 +1,608 @@
+import asyncio
+import math
+import time
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from enum import Enum, auto
+from typing import Protocol, runtime_checkable
+
+import ray
+from pydantic import BaseModel, ConfigDict, Field
+
+from xtuner.v1.data_proto.rl_data import (
+    RolloutState,
+    Status,
+    refresh_seq_staleness,
+    update_group_status,
+    update_sample_version,
+)
+from xtuner.v1.rl.agent_loop import AgentLoopSpec, get_agent_loop_rollout_ctl
+from xtuner.v1.rl.replay_buffer import ReplayBuffer
+from xtuner.v1.rl.rollout.utils import pause_generation
+from xtuner.v1.rl.utils import calculate_seq_staleness, create_task
+from xtuner.v1.utils import get_logger
+
+from .sampler import Sampler
+
+
+logger = get_logger()
+GROUP_GENERATE_TIME_KEY = "group_generate_time_s"
+
+
+@dataclass
+class ProduceProgress:
+    """Live progress object shared by the producer and the consumer.
+
+    Design goals:
+    - The manager / caller initializes and updates this object in place; strategies only
+      receive the reference and read the latest progress from it.
+    - target / consumed use globally absolute cumulative counts, so that after the
+      consumer takes completed samples out of the buffer, the producer does not misread
+      already-consumed samples as a gap and re-produce them.
+    - The same semantics serve both the disaggregated global progress and the local
+      progress of a colocate produce_batch.
+
+    Usage notes:
+    - Do not backfill keys in strategies or fall back with dict.get(..., 0); a missing
+      task key should fail fast.
+    - Unless the semantics explicitly require freezing this produce_batch round's
+      target / scheduled_target, do not copy field values into local snapshots used
+      across awaits; read progress.xxx directly when a value is needed, so concurrent
+      updates to next_consumer_step / consumed_samples take effect as early as possible.
+    - Never replace the whole ProduceProgress object at runtime; resume should also
+      update fields in place so old references stay valid.
+
+    Field meanings:
+    - next_consumer_step: the training step the producer should target when writing new
+      samples. get_batch(i) sets it to i on entry and to i + 1 after a non-empty batch
+      is fetched.
+    - producer_future_step: the future step the producer is currently producing for.
+    - consumed_samples: per task, the absolute cumulative number of groups the consumer
+      has taken out of the replay buffer.
+    - target_samples: per task, the absolute cumulative number of groups that should be
+      produced up to target_upto_future_step.
+    - target_upto_future_step: the highest future step already covered by target_samples.
+    """
+
+    next_consumer_step: int = 1
+    producer_future_step: int = 1
+    consumed_samples: dict[str, int] = field(default_factory=dict)
+    target_samples: dict[str, int] = field(default_factory=dict)
+    target_upto_future_step: int = 0
+
+
+class ProduceBatchStatus(Enum):
+    NORMAL = auto()
+    UPDATE_ABORT = auto()
+    EXPIRED_BATCH = auto()
+
+
+async def _timed_generate_group(
+    agent_loop: AgentLoopSpec,
+    rollout_state: list[RolloutState],
+    enable_partial_rollout: bool = False,
+) -> list[RolloutState]:
+    start = time.perf_counter()
+    if isinstance(agent_loop, ray.actor.ActorHandle):
+        result = await agent_loop.generate_group.remote(
+            rollout_state,
+            enable_partial_rollout=enable_partial_rollout,
+        )
+    else:
+        result = await agent_loop.generate_group(
+            rollout_state,
+            enable_partial_rollout=enable_partial_rollout,
+        )
+    elapsed = time.perf_counter() - start
+    for item in result:
+        extra_fields = getattr(item, "extra_fields", None)
+        if extra_fields is None:
+            extra_fields = {}
+            setattr(item, "extra_fields", extra_fields)
+        extra_fields[GROUP_GENERATE_TIME_KEY] = elapsed
+    return result
+
+
+def default_is_valid_sample_fn(samples: list[RolloutState]) -> bool:
+    return all(sample.status == Status.COMPLETED for sample in samples)
+
+
+def default_should_continue_fn(completed_count: int, batch_size: int, **kwargs) -> bool:
+    return completed_count < batch_size
+
+
+def calculate_stale_threshold(max_staleness: int, sync_weights_interval: int) -> int:
+    if max_staleness < 0:
+        raise ValueError(f"max_staleness must be non-negative, got {max_staleness}.")
+    if sync_weights_interval <= 0:
+        raise ValueError(f"sync_weights_interval must be positive, got {sync_weights_interval}.")
+
+    # max_staleness counts in sync periods; the +1 is the current sync period of lag
+    # that training inherently has to accept.
+    return (max_staleness + 1) * sync_weights_interval
+
+
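+# Worked example (hypothetical numbers): max_staleness=1 with sync_weights_interval=2
+# gives stale_threshold = (1 + 1) * 2 = 4, i.e. a sample counts as expired once its
+# seq_staleness reaches four train steps.
+
+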
+def expire_group_if_needed(group: list[RolloutState], stale_threshold: int) -> list[RolloutState]:
+    if stale_threshold <= 0:
+        raise ValueError(f"stale_threshold must be positive, got {stale_threshold}.")
+
+    group_status = update_group_status(group)
+    if group_status not in (Status.COMPLETED, Status.ABORTED):
+        return group
+    if any(getattr(sample, "seq_staleness", 0) >= stale_threshold for sample in group):
+        # For completed / aborted groups: if any sample in the group is stale, the whole
+        # group becomes EXPIRED.
+        for sample in group:
+            sample.status = Status.EXPIRED
+    return group
+
+
+def _validate_progress_for_task(
+    progress: ProduceProgress,
+    task_name: str,
+    target_cumulative: int | None,
+) -> None:
+    if task_name not in progress.consumed_samples:
+        raise KeyError(f"ProduceProgress.consumed_samples missing task_name={task_name!r}")
+    if task_name not in progress.target_samples:
+        raise KeyError(f"ProduceProgress.target_samples missing task_name={task_name!r}")
+
+    if target_cumulative is not None and target_cumulative != progress.target_samples[task_name]:
+        raise ValueError(
+            "target_cumulative must match progress.target_samples when progress is provided, "
+            f"got target_cumulative={target_cumulative}, "
+            f"progress.target_samples[{task_name!r}]={progress.target_samples[task_name]}"
+        )
+
+
+@runtime_checkable
+class IsValidSampleFn(Protocol):
+    def __call__(self, samples: list[RolloutState]) -> bool: ...
+
+
+@runtime_checkable
+class ShouldContinueFn(Protocol):
+    def __call__(self, completed_count: int, batch_size: int, **kwargs) -> bool: ...
+
+
+class ProduceStrategyConfig(ABC, BaseModel):
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+    is_valid_sample_fn: IsValidSampleFn = default_is_valid_sample_fn
+    should_continue_fn: ShouldContinueFn = default_should_continue_fn
+
+    @abstractmethod
+    def build(self, *, sync_weights_interval: int = 1) -> "ProduceStrategy": ...
+
+
+class SyncProduceStrategyConfig(ProduceStrategyConfig):
+    def build(self, *, sync_weights_interval: int = 1) -> "SyncProduceStrategy":
+        return SyncProduceStrategy(
+            is_valid_sample_fn=self.is_valid_sample_fn, should_continue_fn=self.should_continue_fn
+        )
+
+
+class AsyncProduceStrategyConfig(ProduceStrategyConfig):
+    over_sample_threshold: float = 0.0
+    enable_partial_rollout: bool = False
+    max_staleness: int = Field(default=0, ge=0)
+    tail_batch_trigger_size: int = 0
+
+    def build(self, *, sync_weights_interval: int = 1) -> "AsyncProduceStrategy":
+        return AsyncProduceStrategy(
+            over_sample_threshold=self.over_sample_threshold,
+            enable_partial_rollout=self.enable_partial_rollout,
+            max_staleness=self.max_staleness,
+            sync_weights_interval=sync_weights_interval,
+            tail_batch_trigger_size=self.tail_batch_trigger_size,
+            is_valid_sample_fn=self.is_valid_sample_fn,
+            should_continue_fn=self.should_continue_fn,
+        )
+
+
+class ProduceStrategy(ABC):
+    def __init__(
+        self,
+        is_valid_sample_fn: IsValidSampleFn,
+        should_continue_fn: ShouldContinueFn,
+    ):
+        self.is_valid_sample_fn = is_valid_sample_fn
+        self.should_continue_fn = should_continue_fn
+
+    @abstractmethod
+    async def produce_batch(
+        self,
+        agent_loop: AgentLoopSpec,
+        sampler: Sampler,
+        replay_buffer: ReplayBuffer,
+        batch_size: int,
+        task_name: str,
+        train_step: int = 0,
+        update_event: asyncio.Event | None = None,
+        *,
+        model_step: int,
+        progress: ProduceProgress,
+        target_cumulative: int | None = None,
+    ) -> ProduceBatchStatus: ...
+
+    async def pause_produce(
+        self,
+        agent_loop: AgentLoopSpec,
+        replay_buffer: ReplayBuffer,
+        task_name: str,
+        *,
+        progress: ProduceProgress,
+    ) -> float:
+        return 0.0
+
+
+class SyncProduceStrategy(ProduceStrategy):
+    async def produce_batch(
+        self,
+        agent_loop: AgentLoopSpec,
+        sampler: Sampler,
+        replay_buffer: ReplayBuffer,
+        batch_size: int,
+        task_name: str,
+        train_step: int = 0,
+        update_event: asyncio.Event | None = None,
+        *,
+        model_step: int,
+        progress: ProduceProgress,
+        target_cumulative: int | None = None,
+    ) -> ProduceBatchStatus:
+        pending_tasks = set()
+        completed_sample_count = await replay_buffer.count(task_name=task_name, group_status=Status.COMPLETED)
+        # TODO: should SyncProduceStrategy be usable in disaggregated runs? If so, keep the line below commented out?
+        # assert completed_sample_count == 0, "SyncProduceStrategy assumes no completed samples at the start."
+
+        for _ in range(batch_size):
+            rollout_state = await sampler.sample(task_name=task_name)
+            task = create_task(
+                _timed_generate_group(
+                    agent_loop,
+                    rollout_state,
+                )
+            )
+            pending_tasks.add(task)
+
+        logger.info(f"[SyncProduceStrategy] Started {len(pending_tasks)} initial tasks.")
+
+        while self.should_continue_fn(completed_sample_count, batch_size):
+            if not pending_tasks:
+                logger.warning("[SyncProduceStrategy] All tasks are done but not enough samples collected.")
+                break
+            done_tasks, pending_tasks = await asyncio.wait(
+                pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED
+            )
+            # If filtering is needed, handle it here before adding to the replay buffer;
+            # data that gets filtered out would go into a put_to_filtered pool.
+            for task in done_tasks:
+                items = task.result()
+                is_valid = self.is_valid_sample_fn(items)
+                for item in items:
+                    update_sample_version(item, model_step)
+                refresh_seq_staleness(items, train_step)
+                await replay_buffer.put(items, task_name)
+                if not is_valid:
+                    continue
+
+                completed_sample_count += 1
+                if progress.target_samples[task_name] > 0:
+                    logger.info(
+                        f"[{self.__class__.__name__}] Collected "
+                        f"{min(progress.target_samples[task_name], max(0, completed_sample_count))}/"
+                        f"{progress.target_samples[task_name]} "
+                        f"valid samples for task {task_name}."
+                    )
+
+            while len(pending_tasks) + completed_sample_count < batch_size and self.should_continue_fn(
+                completed_sample_count, batch_size
+            ):
+                rollout_state = await sampler.sample(task_name=task_name)
+                task = create_task(
+                    _timed_generate_group(
+                        agent_loop,
+                        rollout_state,
+                    )
+                )
+                pending_tasks.add(task)
+
+        return ProduceBatchStatus.NORMAL
+
+
+class AsyncProduceStrategy(ProduceStrategy):
+    def __init__(
+        self,
+        over_sample_threshold: float,
+        enable_partial_rollout: bool,
+        tail_batch_trigger_size: int,
+        max_staleness: int,
+        sync_weights_interval: int,
+        is_valid_sample_fn: IsValidSampleFn,
+        should_continue_fn: ShouldContinueFn,
+    ):
+        super().__init__(is_valid_sample_fn, should_continue_fn)
+
+        # TODO: add tail_batch_max_tries.
+        # Purpose: if a sample is retried many times, mark it with a special MAX_TRIES
+        # status; such samples trigger the tail-batch logic together with expired samples.
+        # This depends on RolloutState gaining and maintaining a num_tries attribute,
+        # incremented on every interruption and switched to MAX_TRIES at max_tries.
+        # With enable_partial_rollout=True this logic never triggers, so it is unaffected.
+        # With enable_partial_rollout=False there are two cases:
+        # 1) staleness = 0, i.e. expired samples are not allowed; expiry-triggered tail
+        #    batching already covers the tail-batch logic
+        # 2) staleness > 0, where retry-based tail batching is needed, otherwise samples
+        #    retried many times hurt rollout efficiency
+        if not enable_partial_rollout and max_staleness > 0:
+            logger.warning(
+                "max_staleness > 0 with enable_partial_rollout=False will hurt rollout efficiency "
+                "because the tail_batch_max_tries logic is not supported yet"
+            )
+
+        self.over_sample_threshold = over_sample_threshold
+        self.enable_partial_rollout = enable_partial_rollout
+        self.max_staleness = max_staleness
+        self.sync_weights_interval = sync_weights_interval
+        self.stale_threshold = calculate_stale_threshold(max_staleness, sync_weights_interval)
+        self.tail_batch_trigger_size = tail_batch_trigger_size
+        self._pending_tasks: set[asyncio.Task] = set()
+        self._pending_task_model_steps: dict[asyncio.Task, int] = {}
+        self._pending_lock = asyncio.Lock()
+
+    def is_model_expired(self, train_step: int, model_step: int) -> bool:
+        staleness = calculate_seq_staleness(model_step, train_step)
+        return staleness >= self.stale_threshold
+
+    def _is_model_expired(self, train_step: int, model_step: int) -> bool:
+        return self.is_model_expired(train_step, model_step)
+
+    async def _snapshot_pending(self) -> set[asyncio.Task]:
+        async with self._pending_lock:
+            return set(self._pending_tasks)
+
+    async def _pending_count(self) -> int:
+        async with self._pending_lock:
+            return len(self._pending_tasks)
+
+    async def _claim_done(self, done: set[asyncio.Task]) -> set[asyncio.Task]:
+        async with self._pending_lock:
+            claimed = done & self._pending_tasks
+            self._pending_tasks.difference_update(claimed)
+            return claimed
+
+    async def _claim_already_done(self) -> set[asyncio.Task]:
+        async with self._pending_lock:
+            done = {task for task in self._pending_tasks if task.done()}
+            self._pending_tasks.difference_update(done)
+            return done
+
+    async def _put_generated_group(
+        self,
+        items: list[RolloutState],
+        replay_buffer: ReplayBuffer,
+        task_name: str,
+        current_train_step: int,
+        model_step: int,
+    ) -> bool:
+        for item in items:
+            update_sample_version(item, model_step)
+        refresh_seq_staleness(items, current_train_step)
+        items = expire_group_if_needed(items, self.stale_threshold)
+        is_valid = self.is_valid_sample_fn(items)
+        await replay_buffer.put(items, task_name)
+        return is_valid
+
+    async def _put_claimed_tasks(
+        self,
+        claimed_tasks: set[asyncio.Task],
+        replay_buffer: ReplayBuffer,
+        task_name: str,
+        progress: ProduceProgress,
+        available_base: int | None = None,
+    ) -> int:
+        valid_completed_count = 0
+        for task in claimed_tasks:
+            # Every pending task must be bound to the model version at scheduling time;
+            # a missing entry means the scheduling state is corrupted, so surface it directly.
+            task_model_step = self._pending_task_model_steps.pop(task)
+            is_valid = await self._put_generated_group(
+                task.result(),
+                replay_buffer,
+                task_name,
+                current_train_step=progress.next_consumer_step,
+                model_step=task_model_step,
+            )
+            if is_valid:
+                valid_completed_count += 1
+            if is_valid and available_base is not None:
+                if progress.target_samples[task_name] > 0:
+                    logger.info(
+                        f"[{self.__class__.__name__}] Collected "
+                        f"{min(progress.target_samples[task_name], max(0, available_base + valid_completed_count))}/"
+                        f"{progress.target_samples[task_name]} "
+                        f"valid samples for task {task_name}."
+                    )
+        return valid_completed_count
+
+    async def _schedule_one(
+        self,
+        agent_loop: AgentLoopSpec,
+        sampler: Sampler,
+        desired_pending: int,
+        sample_from_expired: bool,
+        task_name: str,
+        model_step: int,
+        update_event: asyncio.Event | None,
+    ) -> bool:
+        async with self._pending_lock:
+            # update_event is the manager-level pause signal; check it inside the
+            # scheduling critical section so no new tasks are added after a pause fires.
+            if update_event is not None and update_event.is_set():
+                return False
+            if len(self._pending_tasks) >= desired_pending:
+                return False
+            group_status = [Status.EXPIRED, Status.ABORTED] if sample_from_expired else [Status.ABORTED]
+            rollout_state = await sampler.sample(task_name=task_name, group_status=group_status)
+            task = create_task(
+                _timed_generate_group(
+                    agent_loop,
+                    rollout_state,
+                    enable_partial_rollout=self.enable_partial_rollout,
+                )
+            )
+            self._pending_tasks.add(task)
+            self._pending_task_model_steps[task] = model_step
+            return True
+
+    async def _schedule_tasks_until(
+        self,
+        agent_loop: AgentLoopSpec,
+        sampler: Sampler,
+        task_name: str,
+        desired_pending: int,
+        sample_from_expired: bool,
+        model_step: int,
+        update_event: asyncio.Event | None,
+    ) -> None:
+        while await self._schedule_one(
+            agent_loop=agent_loop,
+            sampler=sampler,
+            desired_pending=desired_pending,
+            sample_from_expired=sample_from_expired,
+            task_name=task_name,
+            model_step=model_step,
+            update_event=update_event,
+        ):
+            pass
+
+    async def pause_produce(
+        self,
+        agent_loop: AgentLoopSpec,
+        replay_buffer: ReplayBuffer,
+        task_name: str,
+        *,
+        progress: ProduceProgress,
+    ) -> float:
+        pause_start = time.perf_counter()
+        if await self._pending_count() == 0:
+            return 0.0
+
+        rollout_ctl = await get_agent_loop_rollout_ctl(agent_loop)
+        await pause_generation(rollout_ctl)
+        while True:
+            pending_snapshot = await self._snapshot_pending()
+            if not pending_snapshot:
+                break
+
+            done_tasks, _ = await asyncio.wait(
+                pending_snapshot,
+                timeout=1,
+                return_when=asyncio.FIRST_COMPLETED,
+            )
+            claimed_done = await self._claim_done(done_tasks)
+            for task in claimed_done:
+                paused_items = task.result()
+                # The pause may happen after a weight sync, but the model version bound
+                # when the task was launched must still be used here.
+                task_model_step = self._pending_task_model_steps.pop(task)
+                for item in paused_items:
+                    logger.debug(
+                        f"[{self.__class__.__name__}] Task {task_name} | "
+                        f"Collecting paused sample (uid: {item.uid}, status: {item.status}, "
+                        f"length: {len(item.response_ids or [])}) after pausing generation."
+                    )
+                await self._put_generated_group(
+                    paused_items,
+                    replay_buffer,
+                    task_name,
+                    current_train_step=progress.next_consumer_step,
+                    model_step=task_model_step,
+                )
+            if await self._pending_count() > 0:
+                await pause_generation(rollout_ctl)
+                await asyncio.sleep(1)
+        return time.perf_counter() - pause_start
+
+    async def produce_batch(
+        self,
+        agent_loop: AgentLoopSpec,
+        sampler: Sampler,
+        replay_buffer: ReplayBuffer,
+        batch_size: int,
+        task_name: str,
+        train_step: int = 0,
+        update_event: asyncio.Event | None = None,
+        *,
+        model_step: int,
+        progress: ProduceProgress,
+        target_cumulative: int | None = None,
+    ) -> ProduceBatchStatus:
+        if update_event is None:
+            update_event = asyncio.Event()
+        _validate_progress_for_task(progress, task_name, target_cumulative)
+
+        if progress.target_samples[task_name] <= 0:
+            return ProduceBatchStatus.NORMAL
+
+        if update_event.is_set():
+            return ProduceBatchStatus.UPDATE_ABORT
+        if self.is_model_expired(train_step, model_step):
+            return ProduceBatchStatus.EXPIRED_BATCH
+
+        # First reclaim completed tasks left over from earlier produce_batch calls, so
+        # done tasks do not linger in the pending set.
+        claimed_done = await self._claim_already_done()
+        await self._put_claimed_tasks(
+            claimed_done,
+            replay_buffer,
+            task_name,
+            progress,
+        )
+
+        if update_event.is_set():
+            return ProduceBatchStatus.UPDATE_ABORT
+        if self.is_model_expired(train_step, model_step):
+            return ProduceBatchStatus.EXPIRED_BATCH
+
+        expired_count = await replay_buffer.count(task_name=task_name, group_status=Status.EXPIRED)
+        sample_from_expired = self.tail_batch_trigger_size > 0 and expired_count >= self.tail_batch_trigger_size
+        if sample_from_expired:
+            logger.info(
+                f"Tail batch trigger condition met: {expired_count} expired samples "
+                f"(threshold: {self.tail_batch_trigger_size}). Enabling tail batch mode."
+            )
+
+        # The required cumulative target for this produce_batch round is fixed; normal
+        # mode only adds a fixed over-sampling budget on top of the current task batch.
+        # Tail-batch mode only fills the required gap: new tasks are drawn from the
+        # EXPIRED pool, and the over-sampling window no longer grows.
+        target_abs = progress.target_samples[task_name]
+        oversample_budget = 0 if sample_from_expired else math.ceil(self.over_sample_threshold * batch_size)
+        scheduled_target = target_abs + oversample_budget
+
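+        # For example (hypothetical numbers): with target_abs=32, batch_size=32 and
+        # over_sample_threshold=0.25, normal mode schedules up to 32 + ceil(0.25 * 32) = 40
+        # groups, while tail-batch mode schedules exactly the 32 required.
+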
+        while True:
+            if update_event.is_set():
+                return ProduceBatchStatus.UPDATE_ABORT
+            if self.is_model_expired(train_step, model_step):
+                return ProduceBatchStatus.EXPIRED_BATCH
+
+            fresh = await replay_buffer.count(task_name=task_name, group_status=Status.COMPLETED)
+            available = progress.consumed_samples[task_name] + fresh
+            # if available >= target_abs:
+            if not self.should_continue_fn(available, target_abs):
+                return ProduceBatchStatus.NORMAL
+
+            pending_count = await self._pending_count()
+            desired_pending = max(0, scheduled_target - available)
+            if available + pending_count < scheduled_target:
+                await self._schedule_tasks_until(
+                    agent_loop=agent_loop,
+                    sampler=sampler,
+                    task_name=task_name,
+                    desired_pending=desired_pending,
+                    sample_from_expired=sample_from_expired,
+                    model_step=model_step,
+                    update_event=update_event,
+                )
+                if update_event.is_set():
+                    return ProduceBatchStatus.UPDATE_ABORT
+
+            pending_snapshot = await self._snapshot_pending()
+            if update_event.is_set():
+                return ProduceBatchStatus.UPDATE_ABORT
+            if not pending_snapshot:
+                logger.warning("All tasks are done but not enough samples collected.")
+                return ProduceBatchStatus.NORMAL
+
+            done_tasks, _ = await asyncio.wait(
+                pending_snapshot,
+                timeout=1,
+                return_when=asyncio.FIRST_COMPLETED,
+            )
+            claimed_done = await self._claim_done(done_tasks)
+            await self._put_claimed_tasks(
+                claimed_done,
+                replay_buffer,
+                task_name,
+                progress,
+                available_base=available,
+            )
diff --git a/xtuner/v1/rl/agent_loop_manager/sampler.py b/xtuner/v1/rl/agent_loop_manager/sampler.py
new file mode 100644
index 0000000000..a897b72c48
--- /dev/null
+++ b/xtuner/v1/rl/agent_loop_manager/sampler.py
@@ -0,0 +1,131 @@
+import copy
+from pathlib import Path
+from typing import Iterator, Optional, cast
+from uuid import uuid4
+
+import ray
+import torch
+from pydantic import BaseModel, ConfigDict
+
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
+from xtuner.v1.data_proto.rl_data import RolloutState, Status
+from xtuner.v1.datasets.config import DataloaderConfig
+from xtuner.v1.datasets.dataloader import Dataloader
+from xtuner.v1.rl.replay_buffer import ReplayBuffer
+from xtuner.v1.utils import XTUNER_DETERMINISTIC
+from xtuner.v1.utils.logger import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class SamplerConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+    dataloader_cfg: DataloaderConfig
+    prompt_repeat_k: int = 1
+
+    def build(
+        self, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str, replay_buffer: ReplayBuffer
+    ) -> "Sampler":
+        if isinstance(tokenizer, str):
+            tokenizer_obj = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
+        else:
+            tokenizer_obj = tokenizer
+        dataloader = self.dataloader_cfg.build(
+            tokenizer=tokenizer_obj, dp_mesh=None, global_batch_size=1, micro_batch_size=1, seed=1
+        )
+        return Sampler(dataloader=dataloader, prompt_repeat_k=self.prompt_repeat_k, replay_buffer=replay_buffer)
+
+
+# TODO: The best solution is to put it in the fake_collator,
+# but it will cause a deadlock problem, so it is temporarily placed here.
+# The best solution should be to start the dataloader using spawn. +def put_to_ray(data: RolloutState) -> RolloutState: + if hasattr(data, "mm_info") and data.mm_info is not None: + pixel_values = data.mm_info.get("pixel_values", None) + if pixel_values is not None: + data.mm_info["pixel_values"] = ray.put(pixel_values) + return data + + +class _DatasetSampler: + def __init__(self, dataloader: Dataloader, prompt_repeat_k: int): + self.dataloader = dataloader + self.dataloader_iter: Optional[Iterator] = None + self.cur_epoch = 0 + self.prompt_repeat_k = prompt_repeat_k + self._consumed_samples: int = 0 + + def __len__(self) -> int: + return len(self.dataloader) + + def sample_from_dataloader(self) -> list[RolloutState]: + if self.dataloader_iter is None: + self.dataloader_iter = iter(self.dataloader) + assert self.dataloader_iter is not None + try: + data = cast(RolloutState, next(self.dataloader_iter)[0]) + data = put_to_ray(data) + + except StopIteration: + self.cur_epoch += 1 + self.dataloader.set_epoch(self.cur_epoch) + self.dataloader_iter = iter(self.dataloader) + data = cast(RolloutState, next(self.dataloader_iter)[0]) + data = put_to_ray(data) + + if XTUNER_DETERMINISTIC: + message_uid = self._consumed_samples + uid_base = self._consumed_samples * self.prompt_repeat_k + + group_data = [] + for item_idx in range(self.prompt_repeat_k): + new_data = copy.deepcopy(data) + if XTUNER_DETERMINISTIC: + new_data.message_uid = message_uid + new_data.uid = uid_base + item_idx + new_data.session_uid = new_data.uid + else: + new_data.uid = uuid4().int + group_data.append(new_data) + self._consumed_samples += 1 + return cast(list[RolloutState], group_data) + + +class Sampler(_DatasetSampler): + _DATALOADER_FILE = "dataloader" + + def __init__( + self, + dataloader: Dataloader, + prompt_repeat_k: int, + replay_buffer: ReplayBuffer, + ): + super().__init__(dataloader, prompt_repeat_k) + self.replay_buffer = replay_buffer + + async def sample(self, task_name: str, group_status: list[Status] | None = None) -> list[RolloutState]: + for status in group_status or []: + buffer_data = await self.replay_buffer.get(1, task_name=task_name, group_status=status) + if buffer_data: + return buffer_data[0] + return self.sample_from_dataloader() + + def save(self, checkpoint_path: Path | str) -> None: + """Save the sampler's dataloader state to checkpoint.""" + checkpoint_path = Path(checkpoint_path) + dataloader_state = self.dataloader.get_state_dict() + torch.save(dataloader_state, checkpoint_path / self._DATALOADER_FILE) + + def resume(self, checkpoint_path: Path | str) -> None: + """Resume the sampler's dataloader state from checkpoint.""" + checkpoint_path = Path(checkpoint_path) + dataloader_path = checkpoint_path / self._DATALOADER_FILE + if not dataloader_path.exists(): + logger.warning(f"Dataloader state {dataloader_path} not found, skipping resume.") + return + state = torch.load(dataloader_path, map_location="cpu") + self.dataloader.load_state_dict(state) + self.dataloader_iter = iter(self.dataloader) + self._consumed_samples = state["sampler"]["step"] + self.cur_epoch = state["sampler"]["epoch"] diff --git a/xtuner/v1/rl/base.py b/xtuner/v1/rl/base.py deleted file mode 100644 index a299434804..0000000000 --- a/xtuner/v1/rl/base.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Literal - -from cyclopts import Parameter -from pydantic import BaseModel, Field -from typing_extensions import Annotated - - -class BaseTrainerConfig(BaseModel): - type: Annotated[ - Literal["xtuner", "lmdeploy", "sglang", 
"vllm"], - Parameter(group="Worker Types", description="Type of the worker."), - ] = Field(..., discriminator="type") diff --git a/xtuner/v1/rl/base/__init__.py b/xtuner/v1/rl/base/__init__.py deleted file mode 100644 index 7141d58260..0000000000 --- a/xtuner/v1/rl/base/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from .controller import TrainingController, TrainingControllerProxy -from .loss import BaseRLLossConfig, BaseRLLossContext, BaseRLLossKwargs, compute_kl_loss_weight -from .worker import TrainingWorker, TrainingWorkerClass, TrainingWorkerProxy, WorkerConfig, WorkerLogItem - - -__all__ = [ - "TrainingController", - "TrainingControllerProxy", - "TrainingWorkerClass", - "TrainingWorkerProxy", - "TrainingWorker", - "WorkerConfig", - "BaseRLLossConfig", - "BaseRLLossKwargs", - "BaseRLLossContext", - "compute_kl_loss_weight", - "WorkerLogItem", -] diff --git a/xtuner/v1/rl/config/__init__.py b/xtuner/v1/rl/config/__init__.py deleted file mode 100644 index ad4e0eb5a1..0000000000 --- a/xtuner/v1/rl/config/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .advantage import ( - DrGRPOAdvantageConfig, - GRPOAdvantageConfig, - OPOAdvantageConfig, - PassKAdvantageConfig, - RLOOAdvantageConfig, -) -from .trainer import GRPOTrainerConfig diff --git a/xtuner/v1/rl/config/loss.py b/xtuner/v1/rl/config/loss.py deleted file mode 100644 index 1cf4af1230..0000000000 --- a/xtuner/v1/rl/config/loss.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Literal - -from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict -from typing_extensions import Annotated - - -class BaseLossConfig(BaseModel): - """Base configuration for loss function.""" - - model_config = ConfigDict(extra="forbid") - type: Annotated[ - Literal["grpo", "ppo"], - Parameter(group="Loss Types", help="Type of the loss function."), - ] diff --git a/xtuner/v1/rl/config/trainer.py b/xtuner/v1/rl/config/trainer.py deleted file mode 100644 index 058dacf0a0..0000000000 --- a/xtuner/v1/rl/config/trainer.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Optional - -from cyclopts import Group, Parameter -from pydantic import BaseModel, ConfigDict, Field -from typing_extensions import Annotated - -from xtuner.v1.engine.config import EngineConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig - - -grpo_group = Group("GRPO", sort_key=1, help="GRPO Trainer Configuration") - -actor_worker_group = Group("Actor Workers", sort_key=90, help="Configuration for the rollout worker.") -actor_resources_group = Group("Actor Resources", sort_key=90, help="Configuration for the actor resources.") -rollout_worker_group = Group("Rollout Workers", sort_key=90, help="Configuration for the rollout worker.") -rollout_resources_group = Group("Rollout Resources", sort_key=90, help="Configuration for the rollout resources.") - - -class GRPOTrainerConfig(BaseModel): - """Configuration for the GRPO Ray Trainer.""" - - model_config = ConfigDict(extra="forbid") - actor: Annotated[ - EngineConfig, - Parameter(group=actor_worker_group, help="Configuration for the rollout worker."), - ] - - critic: Annotated[ - EngineConfig, - Parameter(group=actor_worker_group, help="Configuration for the rollout worker."), - ] - - actor_resources: Annotated[ - AcceleratorResourcesConfig, Parameter(group=actor_resources_group, help="Resources allocated for the actor.") - ] - - rollout: Annotated[ - RolloutConfig, - Parameter(group=rollout_worker_group, help="Configuration for the rollout worker."), - # 
Discriminator('type')
-    ]
-    rollout_resources: Annotated[
-        Optional[AcceleratorResourcesConfig],
-        Parameter(group=rollout_resources_group, help="Resources allocated for the rollout."),
-    ] = None
-
-    enrionment: Annotated[str, Parameter(group=grpo_group, help="Environment for the GRPO training.")] = "default"
-
-    global_batch_size: Annotated[int, Parameter(group=grpo_group, help="Batch size for training.")] = Field(
-        32, help="Batch size for training."
-    )
-
-    micro_batch_size: Annotated[int, Parameter(group=grpo_group, help="Micro batch size for training.")] = Field(
-        8, help="Micro batch size for training."
-    )
-
-    num_mini_batches: Annotated[int, Parameter(group=grpo_group, help="Number of mini-batches for training.")] = Field(
-        4, help="Number of mini-batches for training."
-    )
-
-    total_steps: Annotated[int, Parameter(group=grpo_group, help="Total number of training steps.")] = Field(
-        100000, help="Total number of training steps."
-    )
diff --git a/xtuner/v1/rl/evaluator.py b/xtuner/v1/rl/evaluator.py
new file mode 100644
index 0000000000..4c12f08d08
--- /dev/null
+++ b/xtuner/v1/rl/evaluator.py
@@ -0,0 +1,80 @@
+from collections.abc import Mapping
+from typing import Annotated, Protocol, cast, runtime_checkable
+
+from cyclopts import Parameter
+from pydantic import BaseModel, ConfigDict, Field
+
+from xtuner.v1.data_proto.rl_data import RolloutState
+
+
+@runtime_checkable
+class ComputeMetricProtocol(Protocol):
+    def __call__(self, samples: list[RolloutState]) -> dict[str, float]: ...
+
+
+def default_compute_metric_func(samples: list[RolloutState]) -> dict[str, float]:
+    if not samples:
+        return {"accuracy": 0.0}
+
+    positive = 0
+    for s in samples:
+        reward = s.reward
+        assert isinstance(reward, Mapping)
+        score = reward["score"]
+        if score > 0:
+            positive += 1
+    return {"accuracy": positive / len(samples)}
+
+
+class Evaluator:
+    def __init__(
+        self,
+        compute_metric_func: ComputeMetricProtocol | None = None,
+        eval_batch_size: int = 0,
+    ):
+        self.compute_metric_func = compute_metric_func or default_compute_metric_func
+        self.eval_batch_size = eval_batch_size
+
+    def run(self, samples: list[RolloutState] | list[list[RolloutState]]) -> dict[str, float]:
+        # Flatten list[list[RolloutState]] into list[RolloutState]
+        if samples and isinstance(samples[0], list):
+            flat_samples = [sample for batch in cast(list[list[RolloutState]], samples) for sample in batch]
+        else:
+            flat_samples = cast(list[RolloutState], samples)
+        return self.compute_metric_func(flat_samples)
+
+
+class EvaluatorConfig(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
+
+    eval_sample_ratio: Annotated[
+        float,
+        Parameter(help="Ratio of samples to evaluate from the generated samples."),
+    ] = 0
+    eval_sample_num: Annotated[
+        int,
+        Parameter(help="Number of samples to evaluate from the generated samples."),
+    ] = 0
+
+    compute_metric_func: Annotated[
+        ComputeMetricProtocol | None,
+        Field(exclude=True),
+        Parameter(help="An optional metric computation function."),
+    ] = None
+
+    def build(self, total_eval_samples: int = 0) -> "Evaluator":
+        if self.eval_sample_num > 0:
+            eval_batch_size = self.eval_sample_num
+        else:
+            assert total_eval_samples > 0, (
+                "Total eval samples must be greater than 0 if eval sample num is not provided"
+            )
+            if self.eval_sample_ratio > 0:
+                eval_batch_size = int(total_eval_samples * self.eval_sample_ratio)
+            else:
+                eval_batch_size = total_eval_samples
+
+        return Evaluator(
+            compute_metric_func=self.compute_metric_func,
eval_batch_size=eval_batch_size, + ) diff --git a/xtuner/v1/rl/gateway/__init__.py b/xtuner/v1/rl/gateway/__init__.py new file mode 100644 index 0000000000..e5a8dfafdb --- /dev/null +++ b/xtuner/v1/rl/gateway/__init__.py @@ -0,0 +1,20 @@ +from .backend.local_backend import LocalRolloutBackend +from .config import GatewayConfig +from .server import ( + build_gateway_app, + build_local_gateway_app, + serve_gateway, + serve_gateway_in_thread, + wait_for_gateway_ready, +) + + +__all__ = [ + "GatewayConfig", + "LocalRolloutBackend", + "build_gateway_app", + "build_local_gateway_app", + "serve_gateway", + "serve_gateway_in_thread", + "wait_for_gateway_ready", +] diff --git a/xtuner/v1/rl/gateway/adapters/__init__.py b/xtuner/v1/rl/gateway/adapters/__init__.py new file mode 100644 index 0000000000..4c68e142e8 --- /dev/null +++ b/xtuner/v1/rl/gateway/adapters/__init__.py @@ -0,0 +1,43 @@ +from .anthropic import ( + AnthropicChatAdapter, + AnthropicChatAdapterError, + AnthropicCountTokensRequest, + AnthropicCountTokensResponse, + AnthropicMessagesRequest, + AnthropicMessagesResponse, +) +from .base import BaseChatAPIAdapter +from .openai import ( + ChatCompletionRequest, + ChatCompletionResponse, + OpenAIChatAdapter, + OpenAIChatAdapterError, +) +from .responses import ResponsesRequest, ResponsesResponse +from .trace import ( + DEFAULT_CHAT_TRACE_KEY, + ChatTraceRecord, + ChatTraceStore, + build_api_key_trace_key, +) + + +__all__ = [ + "AnthropicChatAdapter", + "AnthropicChatAdapterError", + "AnthropicCountTokensRequest", + "AnthropicCountTokensResponse", + "AnthropicMessagesRequest", + "AnthropicMessagesResponse", + "ChatCompletionRequest", + "ChatCompletionResponse", + "OpenAIChatAdapter", + "OpenAIChatAdapterError", + "ResponsesRequest", + "ResponsesResponse", + "BaseChatAPIAdapter", + "DEFAULT_CHAT_TRACE_KEY", + "ChatTraceRecord", + "ChatTraceStore", + "build_api_key_trace_key", +] diff --git a/xtuner/v1/rl/gateway/adapters/anthropic.py b/xtuner/v1/rl/gateway/adapters/anthropic.py new file mode 100644 index 0000000000..81705986e3 --- /dev/null +++ b/xtuner/v1/rl/gateway/adapters/anthropic.py @@ -0,0 +1,655 @@ +import json +from collections.abc import AsyncIterator +from typing import Any, Literal +from uuid import uuid4 + +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, ConfigDict + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + +from ..core.models import ( + CanonicalGenerateRequest, + CanonicalGenerateResponse, + CanonicalMessage, + CanonicalReasoning, + CanonicalReasoningBlock, + CanonicalReasoningStep, + CanonicalTextBlock, + CanonicalToolCall, + CanonicalToolCallBlock, + CanonicalToolChoice, + CanonicalToolDefinition, + CanonicalToolResult, + CanonicalToolResultBlock, +) +from .base import BaseChatAPIAdapter +from .streaming import build_sse_response, encode_sse_event +from .trace import ChatTraceStore, normalize_trace_payload + + +class AnthropicTextContent(BaseModel): + model_config = ConfigDict(extra="allow") + + type: str = "text" + text: str + + +AnthropicContentBlock = dict[str, Any] + + +class AnthropicMessage(BaseModel): + model_config = ConfigDict(extra="allow") + + role: Literal["user", "assistant"] + content: str | list[AnthropicContentBlock] + + +class AnthropicMessagesRequest(BaseModel): + model_config = ConfigDict(extra="allow") + + session_uid: int | None = None + model: str | None = None + system: str | list[dict[str, Any]] | None = None + messages: list[AnthropicMessage] + max_tokens: int + stream: 
bool = False + temperature: float | None = None + top_p: float | None = None + stop_sequences: list[str] | None = None + tools: list[dict[str, Any]] | None = None + tool_choice: str | dict[str, Any] | None = None + + +class AnthropicCountTokensRequest(BaseModel): + model_config = ConfigDict(extra="allow") + + model: str | None = None + system: str | list[dict[str, Any]] | None = None + messages: list[AnthropicMessage] + tools: list[dict[str, Any]] | None = None + + +class AnthropicCountTokensResponse(BaseModel): + model_config = ConfigDict(extra="allow") + + input_tokens: int + + +class AnthropicUsage(BaseModel): + model_config = ConfigDict(extra="allow") + + input_tokens: int + output_tokens: int + + +class AnthropicMessagesResponse(BaseModel): + model_config = ConfigDict(extra="allow") + + id: str + type: Literal["message"] = "message" + role: Literal["assistant"] = "assistant" + content: list[dict[str, Any]] + model: str + stop_reason: str | None = None + stop_sequence: str | None = None + usage: AnthropicUsage + + +class AnthropicChatAdapterError(RuntimeError): + def __init__(self, message: str, error_type: str, request_id: str | None = None): + super().__init__(message) + self.message = message + self.error_type = error_type + self.request_id = request_id + + +class AnthropicChatAdapter(BaseChatAPIAdapter[AnthropicMessagesRequest, AnthropicMessagesResponse]): + def __init__( + self, + generate_handler, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str | None, + default_model_name: str | None = None, + context_length: int | None = None, + capture_folder: str | None = None, + trace_store: ChatTraceStore | None = None, + ): + if isinstance(tokenizer, str): + tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + super().__init__(generate_handler, tokenizer=tokenizer, capture_folder=capture_folder, trace_store=trace_store) + self._default_model_name = default_model_name + self._context_length = context_length + + async def messages( + self, + request: AnthropicMessagesRequest, + *, + api_key: str | None = None, + ) -> AnthropicMessagesResponse | StreamingResponse: + if request.stream: + response = await self.handle_request(request, api_key=api_key) + return build_sse_response(self.iter_stream_events(response)) + return await self.handle_request(request, api_key=api_key) + + async def count_tokens(self, request: AnthropicCountTokensRequest) -> AnthropicCountTokensResponse: + internal_messages = self._build_internal_messages(request) + tokenizer_tools = self._normalize_tools_for_backend(request.tools) + if self._tokenizer is None: + return AnthropicCountTokensResponse(input_tokens=0) + raw_prompt_ids = self._tokenizer.apply_chat_template( + internal_messages, + tools=tokenizer_tools, + tokenize=True, + add_generation_prompt=True, + ) + prompt_ids = raw_prompt_ids.get("input_ids") if hasattr(raw_prompt_ids, "get") else list(raw_prompt_ids) + return AnthropicCountTokensResponse(input_tokens=len(prompt_ids)) + + def validate_request(self, request: AnthropicMessagesRequest) -> None: + return None + + def request_to_canonical_request(self, request: AnthropicMessagesRequest) -> CanonicalGenerateRequest: + messages: list[CanonicalMessage] = [] + if request.system: + messages.append(self._anthropic_system_to_canonical_message(request.system)) + messages.extend(self._anthropic_messages_to_canonical_messages(request.messages)) + return CanonicalGenerateRequest( + request_id=f"anthropic_req_{uuid4().hex}", + model=request.model or self._default_model_name or 
"rollout-controller", + messages=messages, + tools=self._anthropic_tools_to_canonical(request.tools), + tool_choice=self._anthropic_tool_choice_to_canonical(request.tool_choice), + temperature=request.temperature, + top_p=request.top_p, + max_tokens=request.max_tokens, + stop=list(request.stop_sequences or []), + stream=False, + metadata={ + key: value + for key, value in { + "source_protocol": "anthropic_messages", + "client_stream": bool(request.stream), + "session_uid": request.session_uid, + }.items() + if value is not None + }, + ) + + def normalize_request(self, request: AnthropicMessagesRequest) -> dict[str, Any]: + return normalize_trace_payload(request.model_dump(mode="python", exclude_none=True)) + + def normalize_response(self, response: AnthropicMessagesResponse) -> dict[str, Any]: + return normalize_trace_payload(response.model_dump(mode="python", exclude_none=True)) + + async def iter_stream_events( + self, + response: AnthropicMessagesResponse, + ) -> AsyncIterator[str]: + yield encode_sse_event( + { + "type": "message_start", + "message": { + "id": response.id, + "type": response.type, + "role": response.role, + "content": [], + "model": response.model, + "stop_reason": None, + "stop_sequence": None, + "usage": { + "input_tokens": response.usage.input_tokens, + "output_tokens": 0, + }, + }, + }, + event="message_start", + ) + + for index, block in enumerate(response.content): + block_type = block.get("type") + start_block: dict[str, Any] + delta: dict[str, Any] + if block_type == "reasoning": + start_block = {"type": "thinking", "thinking": ""} + delta = {"type": "thinking_delta", "thinking": str(block.get("text", ""))} + elif block_type == "tool_use": + start_block = { + "type": "tool_use", + "id": block.get("id"), + "name": block.get("name"), + "input": {}, + } + delta = { + "type": "input_json_delta", + "partial_json": json.dumps(block.get("input", {}), ensure_ascii=False), + } + else: + start_block = {"type": "text", "text": ""} + delta = {"type": "text_delta", "text": str(block.get("text", ""))} + + yield encode_sse_event( + { + "type": "content_block_start", + "index": index, + "content_block": start_block, + }, + event="content_block_start", + ) + yield encode_sse_event( + { + "type": "content_block_delta", + "index": index, + "delta": delta, + }, + event="content_block_delta", + ) + yield encode_sse_event( + { + "type": "content_block_stop", + "index": index, + }, + event="content_block_stop", + ) + + yield encode_sse_event( + { + "type": "message_delta", + "delta": { + "stop_reason": self._stream_stop_reason(response.stop_reason), + "stop_sequence": response.stop_sequence, + }, + "usage": { + "output_tokens": response.usage.output_tokens, + }, + }, + event="message_delta", + ) + yield encode_sse_event({"type": "message_stop"}, event="message_stop") + + def canonical_response_to_protocol_response( + self, + canonical_response: CanonicalGenerateResponse, + request: AnthropicMessagesRequest, + ) -> AnthropicMessagesResponse: + content = self._canonical_response_to_anthropic_blocks( + canonical_response, + tools=self._anthropic_tools_to_canonical(request.tools), + ) + stop_reason = canonical_response.finish_reason or "stop" + if any(block.get("type") == "tool_use" for block in content): + stop_reason = "tool_use" + return AnthropicMessagesResponse( + id=f"msg_{canonical_response.request_id}", + content=content, + model=canonical_response.model or self._default_model_name or "rollout-controller", + stop_reason=stop_reason, + usage=AnthropicUsage( + 
input_tokens=canonical_response.usage.prompt_tokens, + output_tokens=canonical_response.usage.completion_tokens, + ), + ) + + def _build_internal_messages(self, request: AnthropicCountTokensRequest) -> list[dict[str, Any]]: + messages: list[dict[str, Any]] = [] + if request.system: + if isinstance(request.system, str): + system_text = request.system + else: + system_text = self._join_text_blocks(request.system, context="system") + messages.append({"role": "system", "content": system_text}) + + for message in request.messages: + if isinstance(message.content, str): + messages.append({"role": message.role, "content": message.content}) + else: + messages.extend(self._convert_content_blocks_to_backend_messages(message.role, message.content)) + return messages + + def _join_text_blocks(self, blocks: list[dict[str, Any]], context: str) -> str: + unsupported_types = [str(block.get("type")) for block in blocks if block.get("type") != "text"] + if unsupported_types: + unsupported_str = ", ".join(sorted(set(unsupported_types))) + raise AnthropicChatAdapterError( + f"Unsupported Anthropic content block type(s) in {context}: {unsupported_str}", + "invalid_request_error", + ) + return "\n".join(str(block.get("text", "")) for block in blocks) + + def _convert_content_blocks_to_backend_messages( + self, + role: str, + blocks: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + backend_messages: list[dict[str, Any]] = [] + text_chunks: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + + def flush_text_chunks() -> None: + if text_chunks: + backend_messages.append({"role": role, "content": "\n".join(text_chunks)}) + text_chunks.clear() + + for block in blocks: + block_type = block.get("type") + if block_type == "text": + text_value = str(block.get("text", "")) + if role == "assistant": + text_value = self._sanitize_assistant_text(text_value) + text_chunks.append(text_value) + elif block_type == "tool_use": + tool_calls.append( + { + "id": block.get("id") or f"toolu_{uuid4().hex}", + "type": "function", + "function": { + "name": str(block.get("name", "")), + "arguments": normalize_trace_payload(block.get("input", {})), + }, + } + ) + elif block_type == "tool_result": + flush_text_chunks() + backend_messages.append( + { + "role": "tool", + "content": self._serialize_tool_result_content(block.get("content")), + "tool_call_id": block.get("tool_use_id"), + } + ) + else: + raise AnthropicChatAdapterError( + f"Unsupported Anthropic content block type in messages[{role}]: {block_type}", + "invalid_request_error", + ) + + if tool_calls: + backend_messages.append( + { + "role": role, + "content": "\n".join(text_chunks), + "tool_calls": tool_calls, + } + ) + text_chunks.clear() + flush_text_chunks() + return backend_messages + + def _serialize_tool_result_content(self, content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + if all(isinstance(item, dict) and item.get("type") == "text" for item in content): + return "\n".join(str(item.get("text", "")) for item in content) + return json.dumps(content, ensure_ascii=False) + if isinstance(content, dict): + return json.dumps(content, ensure_ascii=False) + return str(content) + + def _normalize_tools_for_backend(self, tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None: + if not tools: + return None + normalized_tools = [] + for tool in tools: + if tool.get("type") == "function": + normalized_tools.append(normalize_trace_payload(tool)) + else: + normalized_tools.append( + 
{
+                        "type": "function",
+                        "function": {
+                            "name": tool["name"],
+                            "description": tool.get("description", ""),
+                            "parameters": tool["input_schema"],
+                        },
+                    }
+                )
+        return normalize_trace_payload(normalized_tools)
+
+    def _sanitize_assistant_text(self, text: str) -> str:
+        # Strip special tokens and reasoning-tag markers that can leak from the
+        # rollout backend into replayed assistant turns.
+        cleaned = text.replace("<|im_end|>", "")
+        cleaned = cleaned.replace("<think>", "")
+        cleaned = cleaned.replace("</think>", "")
+        return cleaned.strip()
+
+    def _anthropic_system_to_canonical_message(
+        self,
+        system: str | list[dict[str, Any]],
+    ) -> CanonicalMessage:
+        if isinstance(system, str):
+            content = [CanonicalTextBlock(text=system)] if system else []
+        else:
+            content = []
+            for block in system:
+                if block.get("type") != "text":
+                    raise AnthropicChatAdapterError(
+                        f"Unsupported Anthropic content block type(s) in system: {block.get('type')}",
+                        "invalid_request_error",
+                    )
+                text = str(block.get("text", ""))
+                if text:
+                    content.append(CanonicalTextBlock(text=text))
+        return CanonicalMessage(
+            role="system",
+            content=content,
+            metadata={"source_protocol": "anthropic_messages"},
+        )
+
+    def _anthropic_messages_to_canonical_messages(
+        self,
+        messages: list[AnthropicMessage],
+    ) -> list[CanonicalMessage]:
+        canonical_messages = []
+        for message in messages:
+            if isinstance(message.content, str):
+                content_blocks = [CanonicalTextBlock(text=message.content)] if message.content else []
+            else:
+                content_blocks = self._anthropic_content_blocks_to_canonical(message.content)
+            canonical_messages.append(
+                CanonicalMessage(
+                    role=message.role,
+                    content=content_blocks,
+                    metadata={"source_protocol": "anthropic_messages"},
+                )
+            )
+        return canonical_messages
+
+    def _anthropic_content_blocks_to_canonical(
+        self,
+        blocks: list[dict[str, Any]],
+    ) -> list[Any]:
+        canonical_blocks: list[Any] = []
+        for block in blocks:
+            block_type = block.get("type")
+            if block_type == "text":
+                canonical_blocks.append(CanonicalTextBlock(text=str(block.get("text", ""))))
+            elif block_type == "tool_use":
+                canonical_blocks.append(
+                    CanonicalToolCallBlock(
+                        tool_call=CanonicalToolCall(
+                            id=str(block.get("id") or f"toolu_{uuid4().hex}"),
+                            name=str(block.get("name", "")),
+                            arguments=normalize_trace_payload(block.get("input", {})),
+                            metadata={"source_protocol": "anthropic_messages"},
+                        )
+                    )
+                )
+            elif block_type == "tool_result":
+                content = block.get("content")
+                canonical_blocks.append(
+                    CanonicalToolResultBlock(
+                        tool_result=CanonicalToolResult(
+                            tool_call_id=str(block.get("tool_use_id") or ""),
+                            output=content,
+                            output_text=self._serialize_tool_result_content(content),
+                            is_error=bool(block.get("is_error", False)),
+                            metadata={"source_protocol": "anthropic_messages"},
+                        )
+                    )
+                )
+            elif block_type in {"reasoning", "thinking"}:
+                reasoning_text = str(block.get("text", ""))
+                canonical_blocks.append(
+                    CanonicalReasoningBlock(
+                        reasoning=CanonicalReasoning(
+                            steps=[CanonicalReasoningStep(text=reasoning_text)] if reasoning_text else [],
+                            metadata={"source_protocol": "anthropic_messages"},
+                        )
+                    )
+                )
+            else:
+                raise AnthropicChatAdapterError(
+                    f"Unsupported Anthropic content block type in canonical mapping: {block_type}",
+                    "invalid_request_error",
+                )
+        return canonical_blocks
+
+    def _anthropic_tools_to_canonical(
+        self,
+        tools: list[dict[str, Any]] | None,
+    ) -> list[CanonicalToolDefinition]:
+        if not tools:
+            return []
+        canonical_tools = []
+        for tool in tools:
+            if tool.get("type") == "function":
+                function_spec = tool.get("function", {})
+                name = function_spec.get("name")
+                description = function_spec.get("description")
+                
parameters = function_spec.get("parameters", {}) + else: + name = tool.get("name") + description = tool.get("description") + parameters = tool.get("input_schema", {}) + canonical_tools.append( + CanonicalToolDefinition( + name=str(name or ""), + description=description, + parameters_json_schema=parameters, + metadata={"source_protocol": "anthropic_messages"}, + ) + ) + return canonical_tools + + def _anthropic_tool_choice_to_canonical( + self, + tool_choice: str | dict[str, Any] | None, + ) -> CanonicalToolChoice | None: + if tool_choice is None: + return None + if isinstance(tool_choice, str): + mapped_type = "required" if tool_choice == "any" else tool_choice + return CanonicalToolChoice(type=mapped_type) + choice_type = tool_choice.get("type") + if choice_type == "tool": + return CanonicalToolChoice( + type="specific", + tool_name=tool_choice.get("name"), + metadata={"source_protocol": "anthropic_messages"}, + ) + mapped_type = "required" if choice_type == "any" else str(choice_type or "auto") + return CanonicalToolChoice( + type=mapped_type, + metadata={"source_protocol": "anthropic_messages"}, + ) + + def _canonical_response_to_anthropic_blocks( + self, + response: CanonicalGenerateResponse, + tools: list[CanonicalToolDefinition] | None = None, + ) -> list[dict[str, Any]]: + blocks: list[dict[str, Any]] = [] + for block in response.output.content: + if isinstance(block, CanonicalTextBlock): + if block.text: + blocks.append({"type": "text", "text": block.text}) + elif isinstance(block, CanonicalToolCallBlock): + tool_call = self._sanitize_tool_call_for_request(block.tool_call, tools=tools or []) + blocks.append( + { + "type": "tool_use", + "id": tool_call.id, + "name": tool_call.name, + "input": tool_call.arguments if tool_call.arguments is not None else {}, + } + ) + elif isinstance(block, CanonicalToolResultBlock): + tool_result_content: Any = block.tool_result.output + if tool_result_content is None: + tool_result_content = block.tool_result.output_text or "" + blocks.append( + { + "type": "tool_result", + "tool_use_id": block.tool_result.tool_call_id, + "content": tool_result_content, + "is_error": block.tool_result.is_error, + } + ) + elif isinstance(block, CanonicalReasoningBlock): + reasoning_text = self._reasoning_to_text(block.reasoning) + if reasoning_text: + blocks.append({"type": "thinking", "thinking": reasoning_text}) + return blocks or [{"type": "text", "text": ""}] + + def _sanitize_tool_call_for_request( + self, + tool_call: CanonicalToolCall, + *, + tools: list[CanonicalToolDefinition], + ) -> CanonicalToolCall: + tool_definition = next((tool for tool in tools if tool.name == tool_call.name), None) + if tool_definition is None: + return tool_call + + properties = tool_definition.parameters_json_schema.get("properties") + if not isinstance(properties, dict): + return tool_call + + arguments = tool_call.arguments + normalized_arguments = False + if not isinstance(arguments, dict): + normalized_arguments = True + if tool_call.raw_arguments_text is not None: + try: + decoded = json.loads(tool_call.raw_arguments_text) + except Exception: + decoded = {"raw": tool_call.raw_arguments_text} + arguments = decoded if isinstance(decoded, dict) else {"value": decoded} + elif arguments is None: + arguments = {} + elif isinstance(arguments, str): + try: + decoded = json.loads(arguments) + except Exception: + decoded = {"raw": arguments} + arguments = decoded if isinstance(decoded, dict) else {"value": decoded} + else: + arguments = {"value": arguments} + + allowed_keys = 
set(properties) + cleaned_arguments = {key: value for key, value in arguments.items() if key in allowed_keys} + if cleaned_arguments == arguments and not normalized_arguments: + return tool_call + + dropped_keys = sorted(set(arguments) - set(cleaned_arguments)) + metadata = dict(tool_call.metadata) + if dropped_keys: + metadata["dropped_arguments"] = dropped_keys + return CanonicalToolCall( + id=tool_call.id, + name=tool_call.name, + arguments=cleaned_arguments, + raw_arguments_text=None, + metadata=metadata, + ) + + def _reasoning_to_text(self, reasoning: CanonicalReasoning) -> str: + return "\n".join(step.text for step in reasoning.steps if step.text).strip() + + def _stream_stop_reason(self, stop_reason: str | None) -> str | None: + if stop_reason == "stop": + return "end_turn" + if stop_reason == "length": + return "max_tokens" + return stop_reason diff --git a/xtuner/v1/rl/gateway/adapters/base.py b/xtuner/v1/rl/gateway/adapters/base.py new file mode 100644 index 0000000000..f0ff991438 --- /dev/null +++ b/xtuner/v1/rl/gateway/adapters/base.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +import json +import logging +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable +from typing import Any, Generic, TypeVar + +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.data_proto.rl_data import Status + +from ..core.models import ( + CanonicalGenerateRequest, + CanonicalGenerateResponse, + CanonicalReasoningBlock, + CanonicalTextBlock, + CanonicalToolCall, + CanonicalToolCallBlock, + CanonicalToolResultBlock, +) +from .capture import append_gateway_capture_record, render_blocks_as_text +from .trace import ( + ChatTraceRecord, + ChatTraceStore, + build_api_key_trace_key, + normalize_trace_payload, + snapshot_routed_experts, +) + + +GenerateHandler = Callable[[CanonicalGenerateRequest], Awaitable[CanonicalGenerateResponse]] +RequestT = TypeVar("RequestT") +ResponseT = TypeVar("ResponseT") +logger = logging.getLogger(__name__) + + +def coerce_content_to_text(content: Any) -> str | None: + """Coerce arbitrary content (str, list of blocks, dict) to a plain + string.""" + if content is None: + return None + if isinstance(content, str): + return content + if isinstance(content, list): + text_chunks = [] + for item in content: + if isinstance(item, dict) and item.get("type") in {"text", "input_text", "output_text"}: + text_chunks.append(str(item.get("text", ""))) + joined = "\n".join(chunk for chunk in text_chunks if chunk) + return joined or None + if isinstance(content, dict) and "text" in content: + return str(content["text"]) + return str(content) + + +def stringify_tool_arguments(tool_call: CanonicalToolCall) -> str: + if tool_call.raw_arguments_text is not None: + return tool_call.raw_arguments_text + if isinstance(tool_call.arguments, str): + return tool_call.arguments + return json.dumps(tool_call.arguments if tool_call.arguments is not None else {}, ensure_ascii=False) + + +class BaseChatAPIAdapter(ABC, Generic[RequestT, ResponseT]): + def __init__( + self, + generate_handler: GenerateHandler, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None, + *, + capture_folder: str | None = None, + trace_store: ChatTraceStore | None = None, + trace_store_max_entries: int = 10000, + ): + self._generate_handler = generate_handler + self._tokenizer = tokenizer + self._capture_folder = capture_folder + self._trace_store = trace_store or ChatTraceStore(max_entries=trace_store_max_entries) + + async def 
handle_request(self, request: RequestT, *, api_key: str | None = None) -> ResponseT: + self.validate_request(request) + canonical_request = self.request_to_canonical_request(request) + canonical_response = await self._generate_handler(canonical_request) + response = self.canonical_response_to_protocol_response(canonical_response, request) + record_trace_key = build_api_key_trace_key(api_key) + self._trace_store.append( + self._build_trace_record( + record_trace_key, + request, + response, + canonical_response, + ) + ) + self._write_capture_record( + request=request, + response=response, + canonical_response=canonical_response, + api_key=api_key, + ) + return response + + def get_trace_records(self, trace_key: str) -> list[ChatTraceRecord]: + return self._trace_store.get(trace_key) + + def pop_trace_records(self, trace_key: str) -> list[ChatTraceRecord]: + return self._trace_store.pop(trace_key) + + def clear_trace_records(self, trace_key: str) -> None: + self._trace_store.clear(trace_key) + + def _build_trace_record( + self, + trace_key: str, + request: RequestT, + response: ResponseT, + canonical_response: CanonicalGenerateResponse, + ) -> ChatTraceRecord: + request_snapshot = self.normalize_request(request) + response_snapshot = self.normalize_response(response) + rollout_trace = self._get_rollout_trace(canonical_response) + status = rollout_trace.get("status", Status.COMPLETED.value) + output_text = rollout_trace.get("output_text") or render_blocks_as_text( + self._build_output_message_list(canonical_response) + ) + return ChatTraceRecord( + trace_key=trace_key, + request_snapshot=request_snapshot, + response_snapshot=response_snapshot, + prompt_ids=list(rollout_trace.get("prompt_ids") or []), + response_ids=list(rollout_trace.get("response_ids") or []), + input_text=rollout_trace.get("input_text", ""), + output_text=output_text, + logprobs=rollout_trace.get("logprobs"), + routed_experts=snapshot_routed_experts(rollout_trace.get("routed_experts")), + finish_reason=rollout_trace.get("rollout_finish_reason") or canonical_response.finish_reason, + status=Status(status) if isinstance(status, str) else status, + request_id=canonical_response.request_id, + ) + + def _write_capture_record( + self, + request: RequestT, + response: ResponseT, + canonical_response: CanonicalGenerateResponse, + api_key: str | None = None, + ) -> None: + if self._capture_folder is None: + return + rollout_trace = self._get_rollout_trace(canonical_response) + try: + response_snapshot = self.normalize_response(response) + response_finish_reason = ( + response_snapshot.get("stop_reason") + or response_snapshot.get("finish_reason") + or canonical_response.finish_reason + ) + output_messages = self._build_output_message_list(canonical_response) + append_gateway_capture_record( + self._capture_folder, + { + "protocol": self.__class__.__name__, + "request_id": canonical_response.request_id, + "session_uid": rollout_trace.get("session_uid"), + "status": rollout_trace.get("status", Status.COMPLETED.value), + "finish_reason": response_finish_reason, + "rollout_finish_reason": rollout_trace.get("rollout_finish_reason"), + "prompt_tokens": canonical_response.usage.prompt_tokens, + "completion_tokens": canonical_response.usage.completion_tokens, + "request": self.normalize_request(request), + "response": response_snapshot, + "internal_messages": rollout_trace.get("internal_messages"), + "rollout_tools": rollout_trace.get("rollout_tools"), + "rollout_tool_choice": rollout_trace.get("rollout_tool_choice"), + 
"rollout_sample_params": rollout_trace.get("rollout_sample_params"), + "output_messages": output_messages, + "input_text": rollout_trace.get("input_text", ""), + "output_text": render_blocks_as_text(output_messages), + }, + api_key=api_key, + ) + except Exception: + logger.warning(f"Failed to write gateway capture record to {self._capture_folder}", exc_info=True) + return + + def _get_rollout_trace(self, canonical_response: CanonicalGenerateResponse) -> dict[str, Any]: + trace_payload = canonical_response.metadata.get("rollout_trace", {}) + if not isinstance(trace_payload, dict): + return {} + return trace_payload + + def _build_output_message_list( + self, + canonical_response: CanonicalGenerateResponse, + ) -> list[dict[str, Any]]: + content: list[dict[str, Any]] = [] + for block in canonical_response.output.content: + if isinstance(block, CanonicalTextBlock): + content.append({"type": "text", "text": block.text}) + elif isinstance(block, CanonicalReasoningBlock): + reasoning_text = "\n".join(step.text for step in block.reasoning.steps if step.text).strip() + if reasoning_text: + content.append({"type": "reasoning", "text": reasoning_text}) + elif isinstance(block, CanonicalToolCallBlock): + content.append( + { + "type": "tool_use", + "id": block.tool_call.id, + "name": block.tool_call.name, + "input": normalize_trace_payload(block.tool_call.arguments), + } + ) + elif isinstance(block, CanonicalToolResultBlock): + tool_result_content = block.tool_result.output + if tool_result_content is None: + tool_result_content = block.tool_result.output_text or "" + content.append( + { + "type": "tool_result", + "tool_use_id": block.tool_result.tool_call_id, + "content": normalize_trace_payload(tool_result_content), + } + ) + return [{"role": "assistant", "content": content or ""}] + + @abstractmethod + def validate_request(self, request: RequestT) -> None: + raise NotImplementedError + + @abstractmethod + def request_to_canonical_request(self, request: RequestT) -> CanonicalGenerateRequest: + raise NotImplementedError + + @abstractmethod + def normalize_request(self, request: RequestT) -> dict[str, Any]: + raise NotImplementedError + + @abstractmethod + def normalize_response(self, response: ResponseT) -> dict[str, Any]: + raise NotImplementedError + + @abstractmethod + def canonical_response_to_protocol_response( + self, + canonical_response: CanonicalGenerateResponse, + request: RequestT, + ) -> ResponseT: + raise NotImplementedError diff --git a/xtuner/v1/rl/gateway/adapters/capture.py b/xtuner/v1/rl/gateway/adapters/capture.py new file mode 100644 index 0000000000..00f6817415 --- /dev/null +++ b/xtuner/v1/rl/gateway/adapters/capture.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import json +import threading +from datetime import datetime, timezone +from hashlib import sha256 +from pathlib import Path +from typing import Any + + +_CAPTURE_LOCK = threading.RLock() +_NO_API_KEY_CAPTURE_FILE_NAME = "api_key_none.jsonl" + + +def resolve_capture_output_path(folder: str | Path, api_key: str | None = None) -> Path: + if not api_key: + return Path(folder) / _NO_API_KEY_CAPTURE_FILE_NAME + api_key_hash = sha256(api_key.encode("utf-8")).hexdigest()[:16] + return Path(folder) / f"api_key_{api_key_hash}.jsonl" + + +def append_gateway_capture_record(folder: str | Path, record: dict[str, Any], api_key: str | None = None) -> None: + capture_path = resolve_capture_output_path(folder, api_key=api_key) + capture_path.parent.mkdir(parents=True, exist_ok=True) + payload = { + "type": "gateway_turn", + 
"timestamp": datetime.now(timezone.utc).isoformat(), + **record, + } + with _CAPTURE_LOCK: + with capture_path.open("a", encoding="utf-8") as f: + f.write(json.dumps(payload, ensure_ascii=False) + "\n") + + +def render_blocks_as_text(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, list): + rendered_parts = [render_blocks_as_text(item) for item in value] + return "\n".join(part for part in rendered_parts if part) + if isinstance(value, dict): + block_type = value.get("type") + if block_type == "text": + return str(value.get("text", "")) + if block_type == "tool_use": + name = value.get("name", "") + input_payload = json.dumps(value.get("input", {}), ensure_ascii=False, sort_keys=True) + return f"{input_payload}" + if block_type == "tool_result": + tool_use_id = value.get("tool_use_id", "") + content = render_blocks_as_text(value.get("content")) + return f"{content}" + if "content" in value: + return render_blocks_as_text(value["content"]) + return json.dumps(value, ensure_ascii=False, sort_keys=True) + return str(value) diff --git a/xtuner/v1/rl/gateway/adapters/openai.py b/xtuner/v1/rl/gateway/adapters/openai.py new file mode 100644 index 0000000000..4fccebc92a --- /dev/null +++ b/xtuner/v1/rl/gateway/adapters/openai.py @@ -0,0 +1,511 @@ +import json +import time +from collections.abc import AsyncIterator +from typing import Any, Literal +from uuid import uuid4 + +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, ConfigDict, Field + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + +from ..core.models import ( + CanonicalGenerateRequest, + CanonicalGenerateResponse, + CanonicalMessage, + CanonicalReasoningBlock, + CanonicalTextBlock, + CanonicalToolCall, + CanonicalToolCallBlock, + CanonicalToolChoice, + CanonicalToolDefinition, + CanonicalToolResult, + CanonicalToolResultBlock, +) +from .base import BaseChatAPIAdapter, coerce_content_to_text, stringify_tool_arguments +from .streaming import build_sse_response, encode_sse_event +from .trace import ChatTraceStore, normalize_trace_payload + + +class ChatCompletionStreamOptions(BaseModel): + model_config = ConfigDict(extra="allow") + + include_usage: bool = False + continuous_usage_stats: bool = False + + +class ChatCompletionRequest(BaseModel): + model_config = ConfigDict(extra="allow") + + session_uid: int | str | None = None + session_id: int | str | None = None + model: str | None = None + messages: list[dict[str, Any]] + tools: list[dict[str, Any]] | None = None + tool_choice: str | dict[str, Any] | None = None + parallel_tool_calls: bool | None = None + stream: bool = False + stream_options: ChatCompletionStreamOptions | None = None + n: int | None = None + temperature: float | None = None + top_p: float | None = None + top_k: int | None = None + max_tokens: int | None = None + max_completion_tokens: int | None = None + min_tokens: int | None = None + stop: str | list[str] | None = None + stop_token_ids: list[int] | None = None + presence_penalty: float | None = None + frequency_penalty: float | None = None + repetition_penalty: float | None = None + skip_special_tokens: bool | None = None + no_stop_trim: bool | None = None + seed: int | None = None + user: str | None = None + return_routed_experts: bool | None = None + chat_template_kwargs: dict[str, Any] | None = None + + +class UsageInfo(BaseModel): + model_config = ConfigDict(extra="allow") + + prompt_tokens: int = 0 + completion_tokens: int = 0 
+ total_tokens: int = 0 + + +class ChatMessage(BaseModel): + model_config = ConfigDict(extra="allow") + + role: str + content: str | None = None + reasoning_content: str | None = None + tool_calls: list[dict[str, Any]] | None = None + + +class DeltaMessage(BaseModel): + model_config = ConfigDict(extra="allow") + + role: str | None = None + content: str | None = None + reasoning_content: str | None = None + tool_calls: list[dict[str, Any]] | None = None + + +class ChatCompletionResponseChoice(BaseModel): + model_config = ConfigDict(extra="allow") + + index: int + message: ChatMessage + finish_reason: str | None = None + + +class ChatCompletionResponse(BaseModel): + model_config = ConfigDict(extra="allow") + + id: str + object: Literal["chat.completion"] = "chat.completion" + created: int + model: str + choices: list[ChatCompletionResponseChoice] + usage: UsageInfo + + +class ChatCompletionResponseStreamChoice(BaseModel): + model_config = ConfigDict(extra="allow") + + index: int + delta: DeltaMessage = Field(default_factory=DeltaMessage) + finish_reason: str | None = None + + +class ChatCompletionStreamResponse(BaseModel): + model_config = ConfigDict(extra="allow") + + id: str + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int + model: str + choices: list[ChatCompletionResponseStreamChoice] = Field(default_factory=list) + usage: UsageInfo | None = None + + +class OpenAIChatAdapterError(RuntimeError): + def __init__( + self, + message: str, + error_type: str, + code: str, + request_id: str | None = None, + ): + super().__init__(message) + self.message = message + self.error_type = error_type + self.code = code + self.request_id = request_id + + +class OpenAIChatAdapter(BaseChatAPIAdapter[ChatCompletionRequest, ChatCompletionResponse]): + def __init__( + self, + generate_handler, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str, + default_model_name: str | None = None, + context_length: int | None = None, + capture_folder: str | None = None, + trace_store: ChatTraceStore | None = None, + trace_store_max_entries: int = 10000, + ): + if isinstance(tokenizer, str): + tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + super().__init__( + generate_handler, + tokenizer=tokenizer, + capture_folder=capture_folder, + trace_store=trace_store, + trace_store_max_entries=trace_store_max_entries, + ) + self._default_model_name = default_model_name + self._context_length = context_length + + async def chat( + self, + request: ChatCompletionRequest, + *, + api_key: str | None = None, + ) -> ChatCompletionResponse | StreamingResponse: + if request.stream: + response = await self.handle_request(request, api_key=api_key) + return build_sse_response(self.iter_stream_events(response, request)) + return await self.handle_request(request, api_key=api_key) + + def validate_request(self, request: ChatCompletionRequest) -> None: + if request.n not in (None, 1): + raise OpenAIChatAdapterError( + "n>1 is not supported yet", + "invalid_request_error", + "n_not_supported", + ) + + def request_to_canonical_request(self, request: ChatCompletionRequest) -> CanonicalGenerateRequest: + normalized_messages = normalize_trace_payload(request.messages) + normalized_tools = normalize_trace_payload(request.tools) + normalized_tool_choice = normalize_trace_payload(request.tool_choice) + stop = [] if request.stop is None else [request.stop] if isinstance(request.stop, str) else list(request.stop) + chat_template_kwargs = request.chat_template_kwargs or {} + 
return CanonicalGenerateRequest(
+            request_id=f"chatcmpl_req_{uuid4().hex}",
+            model=request.model or self._default_model_name or "rollout-controller",
+            messages=[self._openai_message_to_canonical_message(message) for message in normalized_messages],
+            tools=self._openai_tools_to_canonical(normalized_tools),
+            tool_choice=self._openai_tool_choice_to_canonical(normalized_tool_choice),
+            temperature=request.temperature,
+            top_p=request.top_p,
+            max_tokens=request.max_completion_tokens
+            if request.max_completion_tokens is not None
+            else request.max_tokens,
+            stop=stop,
+            stream=False,
+            metadata={
+                key: value
+                for key, value in {
+                    "source_protocol": "openai_chat_completions",
+                    "client_stream": bool(request.stream),
+                    # Prefer session_uid, falling back to the session_id alias.
+                    "session_uid": request.session_uid if request.session_uid is not None else request.session_id,
+                    "n": request.n,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
+                    "top_k": request.top_k,
+                    "repetition_penalty": request.repetition_penalty,
+                    "min_tokens": request.min_tokens,
+                    "stop_token_ids": request.stop_token_ids,
+                    "skip_special_tokens": request.skip_special_tokens,
+                    "no_stop_trim": request.no_stop_trim,
+                    "spaces_between_special_tokens": chat_template_kwargs.get("spaces_between_special_tokens"),
+                    "sampling_seed": request.seed,
+                    "user": request.user,
+                    "return_routed_experts": request.return_routed_experts,
+                }.items()
+                if value is not None
+            },
+        )
+
+    def canonical_response_to_chat_completion_response(
+        self,
+        response: CanonicalGenerateResponse,
+    ) -> ChatCompletionResponse:
+        message_content = self._render_openai_response_text(response)
+        reasoning_content = self._render_openai_reasoning_text(response)
+        tool_calls = self._canonical_tool_calls_to_openai(response)
+        finish_reason = response.finish_reason or ("tool_calls" if tool_calls else "stop")
+        return ChatCompletionResponse(
+            id=response.request_id,
+            created=int(time.time()),
+            model=response.model or self._default_model_name or "rollout-controller",
+            choices=[
+                ChatCompletionResponseChoice(
+                    index=0,
+                    message=ChatMessage(
+                        role="assistant",
+                        content=None if tool_calls and not message_content else message_content,
+                        reasoning_content=reasoning_content,
+                        tool_calls=tool_calls or None,
+                    ),
+                    finish_reason=finish_reason,
+                )
+            ],
+            usage=UsageInfo(
+                prompt_tokens=response.usage.prompt_tokens,
+                completion_tokens=response.usage.completion_tokens,
+                total_tokens=response.usage.total_tokens,
+            ),
+        )
+
+    def canonical_response_to_protocol_response(
+        self,
+        canonical_response: CanonicalGenerateResponse,
+        request: ChatCompletionRequest,
+    ) -> ChatCompletionResponse:
+        return self.canonical_response_to_chat_completion_response(canonical_response)
+
+    def normalize_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
+        return normalize_trace_payload(
+            {
+                "messages": request.messages,
+                "tools": request.tools,
+                "tool_choice": request.tool_choice,
+            }
+        )
+
+    def normalize_response(self, response: ChatCompletionResponse) -> dict[str, Any]:
+        normalized_choices = []
+        for choice in response.choices:
+            normalized_choices.append(
+                {
+                    "message": choice.message.model_dump(mode="python", exclude_none=True)
+                    if choice.message is not None
+                    else None,
+                    "finish_reason": choice.finish_reason,
+                }
+            )
+        return normalize_trace_payload({"choices": normalized_choices})
+
+    async def iter_stream_events(
+        self,
+        response: ChatCompletionResponse,
+        request: ChatCompletionRequest,
+    ) -> AsyncIterator[str]:
+        
choice = response.choices[0] + include_usage = bool(getattr(request.stream_options, "include_usage", False)) + + initial_chunk = ChatCompletionStreamResponse( + id=response.id, + created=response.created, + model=response.model, + choices=[ + ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role="assistant"), + ) + ], + ) + yield encode_sse_event(initial_chunk.model_dump(mode="json", exclude_none=True)) + + if choice.message.reasoning_content: + reasoning_chunk = ChatCompletionStreamResponse( + id=response.id, + created=response.created, + model=response.model, + choices=[ + ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(reasoning_content=choice.message.reasoning_content), + ) + ], + ) + yield encode_sse_event(reasoning_chunk.model_dump(mode="json", exclude_none=True)) + + if choice.message.content: + content_chunk = ChatCompletionStreamResponse( + id=response.id, + created=response.created, + model=response.model, + choices=[ + ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(content=choice.message.content), + ) + ], + ) + yield encode_sse_event(content_chunk.model_dump(mode="json", exclude_none=True)) + + for index, tool_call in enumerate(choice.message.tool_calls or []): + tool_call_id = tool_call.get("id") if isinstance(tool_call, dict) else getattr(tool_call, "id", None) + tool_call_type = ( + tool_call.get("type", "function") + if isinstance(tool_call, dict) + else getattr(tool_call, "type", "function") + ) + function_payload = ( + tool_call.get("function") if isinstance(tool_call, dict) else getattr(tool_call, "function", None) + ) + if isinstance(function_payload, BaseModel): + function_payload = function_payload.model_dump(mode="json", exclude_none=True) + tool_call_chunk = ChatCompletionStreamResponse( + id=response.id, + created=response.created, + model=response.model, + choices=[ + ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage( + tool_calls=[ + { + "index": index, + "id": tool_call_id, + "type": tool_call_type, + "function": function_payload, + } + ] + ), + ) + ], + ) + yield encode_sse_event(tool_call_chunk.model_dump(mode="json", exclude_none=True)) + + final_chunk = ChatCompletionStreamResponse( + id=response.id, + created=response.created, + model=response.model, + choices=[ + ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(), + finish_reason=choice.finish_reason, + ) + ], + usage=response.usage if include_usage else None, + ) + yield encode_sse_event(final_chunk.model_dump(mode="json", exclude_none=True)) + yield encode_sse_event("[DONE]") + + def _openai_message_to_canonical_message(self, message: dict[str, Any]) -> CanonicalMessage: + role = str(message.get("role", "user")) + content_blocks: list[Any] = [] + if role == "tool": + content_blocks.append( + CanonicalToolResultBlock( + tool_result=CanonicalToolResult( + tool_call_id=str(message.get("tool_call_id") or message.get("name") or ""), + name=message.get("name"), + output=message.get("content"), + output_text=coerce_content_to_text(message.get("content")), + metadata={"source_protocol": "openai_chat_completions"}, + ) + ) + ) + else: + content_text = coerce_content_to_text(message.get("content")) + if content_text: + content_blocks.append(CanonicalTextBlock(text=content_text)) + for tool_call in message.get("tool_calls") or []: + content_blocks.append(CanonicalToolCallBlock(tool_call=self._openai_tool_call_to_canonical(tool_call))) + return CanonicalMessage( + role=role if role in {"system", "user", "assistant", 
"tool"} else "user", + content=content_blocks, + name=message.get("name"), + metadata={ + key: value + for key, value in { + "source_protocol": "openai_chat_completions", + "tool_call_id": message.get("tool_call_id"), + }.items() + if value is not None + }, + ) + + def _openai_tools_to_canonical(self, tools: list[dict[str, Any]] | None) -> list[CanonicalToolDefinition]: + if not tools: + return [] + canonical_tools = [] + for tool in tools: + function_spec = tool.get("function", tool) + canonical_tools.append( + CanonicalToolDefinition( + name=str(function_spec.get("name", "")), + description=function_spec.get("description"), + parameters_json_schema=function_spec.get("parameters", {}), + metadata={"source_protocol": "openai_chat_completions"}, + ) + ) + return canonical_tools + + def _openai_tool_choice_to_canonical(self, tool_choice: Any) -> CanonicalToolChoice | None: + if tool_choice is None: + return None + if isinstance(tool_choice, str): + return CanonicalToolChoice(type=tool_choice) + function_spec = tool_choice.get("function") or {} + return CanonicalToolChoice( + type="specific", + tool_name=function_spec.get("name"), + metadata={"source_protocol": "openai_chat_completions"}, + ) + + def _openai_tool_call_to_canonical(self, tool_call: dict[str, Any]) -> CanonicalToolCall: + function_spec = tool_call.get("function") or {} + raw_arguments = function_spec.get("arguments") + parsed_arguments = self._parse_tool_arguments(raw_arguments) + metadata: dict[str, Any] = {"source_protocol": "openai_chat_completions"} + if isinstance(parsed_arguments, dict) and parsed_arguments.pop("__parse_error__", False): + metadata["arguments_parse_error"] = True + return CanonicalToolCall( + id=str(tool_call.get("id") or f"call_{uuid4().hex}"), + name=str(function_spec.get("name", "")), + arguments=parsed_arguments, + raw_arguments_text=raw_arguments if isinstance(raw_arguments, str) else None, + metadata=metadata, + ) + + def _canonical_tool_calls_to_openai(self, response: CanonicalGenerateResponse) -> list[dict[str, Any]]: + tool_calls = [] + for block in response.output.content: + if isinstance(block, CanonicalToolCallBlock): + tool_calls.append( + { + "id": block.tool_call.id, + "type": "function", + "function": { + "name": block.tool_call.name, + "arguments": stringify_tool_arguments(block.tool_call), + }, + } + ) + return tool_calls + + def _render_openai_response_text(self, response: CanonicalGenerateResponse) -> str | None: + text_chunks = [] + for block in response.output.content: + if isinstance(block, CanonicalTextBlock): + text_chunks.append(block.text) + joined = "".join(text_chunks).strip() + return joined or None + + def _render_openai_reasoning_text(self, response: CanonicalGenerateResponse) -> str | None: + reasoning_chunks: list[str] = [] + for block in response.output.content: + if isinstance(block, CanonicalReasoningBlock): + reasoning_chunks.extend(step.text for step in block.reasoning.steps if step.text) + joined = "\n".join(chunk for chunk in reasoning_chunks if chunk).strip() + return joined or None + + def _parse_tool_arguments(self, raw_arguments: Any) -> Any: + if not isinstance(raw_arguments, str): + return raw_arguments + try: + return json.loads(raw_arguments) + except Exception: + return {"__parse_error__": True, "raw": raw_arguments} diff --git a/xtuner/v1/rl/gateway/adapters/responses.py b/xtuner/v1/rl/gateway/adapters/responses.py new file mode 100644 index 0000000000..e0c1e74495 --- /dev/null +++ b/xtuner/v1/rl/gateway/adapters/responses.py @@ -0,0 +1,587 @@ +from 
__future__ import annotations + +import json +import re +import time +from collections.abc import AsyncIterator +from typing import Any, Literal +from uuid import uuid4 + +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, ConfigDict + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + +from ..core.models import ( + CanonicalGenerateRequest, + CanonicalGenerateResponse, + CanonicalMessage, + CanonicalReasoning, + CanonicalReasoningBlock, + CanonicalReasoningStep, + CanonicalTextBlock, + CanonicalToolCall, + CanonicalToolCallBlock, + CanonicalToolChoice, + CanonicalToolDefinition, + CanonicalToolResult, + CanonicalToolResultBlock, +) +from .base import BaseChatAPIAdapter, stringify_tool_arguments +from .openai import OpenAIChatAdapterError +from .streaming import build_sse_response, encode_sse_event +from .trace import ChatTraceStore, normalize_trace_payload + + +class ResponsesRequest(BaseModel): + model_config = ConfigDict(extra="allow") + + session_uid: int | None = None + model: str | None = None + instructions: str | None = None + input: str | list[dict[str, Any]] | None = None + tools: list[dict[str, Any]] | None = None + tool_choice: str | dict[str, Any] | None = None + stream: bool = False + store: bool = False + parallel_tool_calls: bool | None = None + include: list[Any] | None = None + reasoning: dict[str, Any] | None = None + max_output_tokens: int | None = None + temperature: float | None = None + top_p: float | None = None + + +class ResponsesUsage(BaseModel): + model_config = ConfigDict(extra="allow") + + input_tokens: int + output_tokens: int + total_tokens: int + + +class ResponsesResponse(BaseModel): + model_config = ConfigDict(extra="allow") + + id: str + object: Literal["response"] = "response" + created_at: int + status: Literal["completed"] = "completed" + model: str + output: list[dict[str, Any]] + output_text: str = "" + parallel_tool_calls: bool = False + store: bool = False + text: dict[str, Any] = {"format": {"type": "text"}} + usage: ResponsesUsage + + +class OpenAIResponsesAdapter(BaseChatAPIAdapter[ResponsesRequest, ResponsesResponse]): + _disabled_tool_names = { + "list_mcp_resources", + "list_mcp_resource_templates", + "read_mcp_resource", + "request_user_input", + } + + def __init__( + self, + generate_handler, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str | None, + default_model_name: str | None = None, + context_length: int | None = None, + capture_folder: str | None = None, + trace_store: ChatTraceStore | None = None, + ): + if isinstance(tokenizer, str): + tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + super().__init__(generate_handler, tokenizer=tokenizer, capture_folder=capture_folder, trace_store=trace_store) + self._default_model_name = default_model_name + self._context_length = context_length + + async def responses( + self, + request: ResponsesRequest, + *, + api_key: str | None = None, + ) -> ResponsesResponse | StreamingResponse: + if request.stream: + response = await self.handle_request(request, api_key=api_key) + return build_sse_response(self.iter_stream_events(response)) + return await self.handle_request(request, api_key=api_key) + + def validate_request(self, request: ResponsesRequest) -> None: + return None + + def request_to_canonical_request(self, request: ResponsesRequest) -> CanonicalGenerateRequest: + return CanonicalGenerateRequest( + request_id=f"responses_req_{uuid4().hex}", + model=request.model or 
self._default_model_name or "rollout-controller", + messages=self._responses_input_to_canonical_messages(request), + tools=self._responses_tools_to_canonical(request.tools), + tool_choice=self._responses_tool_choice_to_canonical(request.tool_choice), + parallel_tool_calls=request.parallel_tool_calls, + temperature=request.temperature, + top_p=request.top_p, + max_tokens=request.max_output_tokens, + stream=False, + metadata={ + key: value + for key, value in { + "source_protocol": "openai_responses", + "client_stream": bool(request.stream), + "session_uid": request.session_uid, + "store": request.store, + "include": request.include, + "reasoning": request.reasoning, + }.items() + if value is not None + }, + ) + + def normalize_request(self, request: ResponsesRequest) -> dict[str, Any]: + return normalize_trace_payload(request.model_dump(mode="python", exclude_none=True)) + + def normalize_response(self, response: ResponsesResponse) -> dict[str, Any]: + return normalize_trace_payload(response.model_dump(mode="python", exclude_none=True)) + + async def iter_stream_events( + self, + response: ResponsesResponse, + ) -> AsyncIterator[str]: + created_response = response.model_dump(mode="json", exclude_none=True) + created_response["status"] = "in_progress" + + yield encode_sse_event( + { + "type": "response.created", + "response": created_response, + }, + event="response.created", + ) + yield encode_sse_event( + { + "type": "response.in_progress", + "response": created_response, + }, + event="response.in_progress", + ) + + for output_index, item in enumerate(response.output): + yield encode_sse_event( + { + "type": "response.output_item.added", + "output_index": output_index, + "item": item, + }, + event="response.output_item.added", + ) + + if item.get("type") == "message": + for content_index, part in enumerate(item.get("content", [])): + yield encode_sse_event( + { + "type": "response.content_part.added", + "output_index": output_index, + "content_index": content_index, + "item_id": item.get("id"), + "part": part, + }, + event="response.content_part.added", + ) + if part.get("type") == "output_text": + yield encode_sse_event( + { + "type": "response.output_text.delta", + "output_index": output_index, + "content_index": content_index, + "item_id": item.get("id"), + "delta": part.get("text", ""), + }, + event="response.output_text.delta", + ) + yield encode_sse_event( + { + "type": "response.output_text.done", + "output_index": output_index, + "content_index": content_index, + "item_id": item.get("id"), + "text": part.get("text", ""), + }, + event="response.output_text.done", + ) + yield encode_sse_event( + { + "type": "response.content_part.done", + "output_index": output_index, + "content_index": content_index, + "item_id": item.get("id"), + "part": part, + }, + event="response.content_part.done", + ) + + if item.get("type") == "function_call": + yield encode_sse_event( + { + "type": "response.function_call_arguments.delta", + "output_index": output_index, + "item_id": item.get("id"), + "delta": item.get("arguments", ""), + }, + event="response.function_call_arguments.delta", + ) + yield encode_sse_event( + { + "type": "response.function_call_arguments.done", + "output_index": output_index, + "item_id": item.get("id"), + "arguments": item.get("arguments", ""), + }, + event="response.function_call_arguments.done", + ) + + yield encode_sse_event( + { + "type": "response.output_item.done", + "output_index": output_index, + "item": item, + }, + event="response.output_item.done", + ) + + yield 
encode_sse_event( + { + "type": "response.completed", + "response": response.model_dump(mode="json", exclude_none=True), + }, + event="response.completed", + ) + + def canonical_response_to_protocol_response( + self, + canonical_response: CanonicalGenerateResponse, + request: ResponsesRequest, + ) -> ResponsesResponse: + output_items = self._canonical_response_to_responses_output_items(canonical_response) + output_text = "".join( + block.text for block in canonical_response.output.content if isinstance(block, CanonicalTextBlock) + ).strip() + return ResponsesResponse( + id=f"resp_{canonical_response.request_id}", + created_at=int(time.time()), + model=canonical_response.model or self._default_model_name or "rollout-controller", + output=output_items, + output_text=output_text, + parallel_tool_calls=bool( + request.parallel_tool_calls + if request is not None + else canonical_response.metadata.get("parallel_tool_calls") + ), + store=bool(request.store) if request is not None else False, + usage=ResponsesUsage( + input_tokens=canonical_response.usage.prompt_tokens, + output_tokens=canonical_response.usage.completion_tokens, + total_tokens=canonical_response.usage.total_tokens, + ), + ) + + def _normalize_input_role(self, role: Any) -> str: + if role in {"developer", "system"}: + return "system" + if role in {"assistant", "tool"}: + return str(role) + return "user" + + def _extract_message_item_text(self, content: Any) -> str: + if isinstance(content, str): + return content + if not isinstance(content, list): + return str(content) + text_chunks: list[str] = [] + for part in content: + part_type = part.get("type") + if part_type in {"input_text", "output_text", "text", "summary_text", "reasoning_text"}: + text_chunks.append(str(part.get("text", ""))) + return "\n".join(chunk for chunk in text_chunks if chunk) + + def _serialize_tool_output(self, output: Any, tool_name: str | None = None) -> str: + if output is None: + return "" + if isinstance(output, str): + return self._sanitize_tool_output_text(output, tool_name=tool_name) + if isinstance(output, list): + text_chunks = [str(part.get("text", "")) for part in output if isinstance(part, dict) and "text" in part] + if text_chunks: + return self._sanitize_tool_output_text("\n".join(text_chunks), tool_name=tool_name) + return json.dumps(output, ensure_ascii=False) + if isinstance(output, dict): + return json.dumps(output, ensure_ascii=False) + return str(output) + + def _sanitize_tool_output_text(self, text: str, tool_name: str | None = None) -> str: + if tool_name not in {"exec_command", "write_stdin"}: + return text + marker = "\nOutput:\n" + if marker in text: + prefix, body = text.split(marker, 1) + exit_code = self._extract_exec_exit_code(prefix) + body = body.strip() + if exit_code is None: + return body + if body: + return f"[exit_code={exit_code}]\n{body}" + return f"[exit_code={exit_code}]" + return text + + def _extract_exec_exit_code(self, text: str) -> int | None: + match = re.search(r"Process exited with code (\d+)", text) + if match is not None: + return int(match.group(1)) + return None + + def _responses_input_to_canonical_messages(self, request: ResponsesRequest) -> list[CanonicalMessage]: + messages: list[CanonicalMessage] = [] + if request.instructions: + messages.append( + CanonicalMessage( + role="system", + content=[CanonicalTextBlock(text=request.instructions)], + metadata={"source_protocol": "openai_responses"}, + ) + ) + if request.input is None: + return messages + if isinstance(request.input, str): + messages.append( + 
CanonicalMessage( + role="user", + content=[CanonicalTextBlock(text=request.input)], + metadata={"source_protocol": "openai_responses"}, + ) + ) + return messages + + tool_name_by_call_id: dict[str, str] = {} + for item in request.input: + item_type = item.get("type", "message") + if item_type == "message": + role = self._normalize_input_role(item.get("role")) + content_blocks = self._responses_message_content_to_canonical(item.get("content")) + messages.append( + CanonicalMessage( + role=role if role in {"system", "user", "assistant", "tool"} else "user", + content=content_blocks, + metadata={"source_protocol": "openai_responses"}, + ) + ) + elif item_type == "function_call": + call_id = str(item.get("call_id") or f"call_{uuid4().hex}") + tool_name = str(item.get("name", "")) + tool_name_by_call_id[call_id] = tool_name + messages.append( + CanonicalMessage( + role="assistant", + content=[ + CanonicalToolCallBlock( + tool_call=CanonicalToolCall( + id=call_id, + name=tool_name, + arguments=self._parse_json_string_or_mapping(item.get("arguments")), + raw_arguments_text=item.get("arguments") + if isinstance(item.get("arguments"), str) + else None, + metadata={"source_protocol": "openai_responses"}, + ) + ) + ], + metadata={"source_protocol": "openai_responses"}, + ) + ) + elif item_type == "function_call_output": + call_id = str(item.get("call_id") or "") + output = item.get("output") + messages.append( + CanonicalMessage( + role="tool", + content=[ + CanonicalToolResultBlock( + tool_result=CanonicalToolResult( + tool_call_id=call_id, + name=tool_name_by_call_id.get(call_id), + output=output, + output_text=self._serialize_tool_output( + output, tool_name=tool_name_by_call_id.get(call_id) + ), + metadata={"source_protocol": "openai_responses"}, + ) + ) + ], + metadata={"source_protocol": "openai_responses"}, + ) + ) + elif item_type == "reasoning": + reasoning_text = self._responses_reasoning_item_to_text(item) + messages.append( + CanonicalMessage( + role="assistant", + content=[ + CanonicalReasoningBlock( + reasoning=CanonicalReasoning( + steps=[CanonicalReasoningStep(text=reasoning_text)] if reasoning_text else [], + metadata={"source_protocol": "openai_responses"}, + ) + ) + ], + metadata={"source_protocol": "openai_responses"}, + ) + ) + return messages + + def _responses_message_content_to_canonical(self, content: Any) -> list[Any]: + if isinstance(content, str): + return [CanonicalTextBlock(text=content)] if content else [] + if not isinstance(content, list): + return [CanonicalTextBlock(text=str(content))] + + blocks: list[Any] = [] + unsupported_types: list[str] = [] + for part in content: + part_type = part.get("type") + if part_type in {"input_text", "output_text", "text"}: + text = str(part.get("text", "")) + if text: + blocks.append(CanonicalTextBlock(text=text)) + elif part_type in {"summary_text", "reasoning_text"}: + reasoning_text = str(part.get("text", "")) + if reasoning_text: + blocks.append( + CanonicalReasoningBlock( + reasoning=CanonicalReasoning( + steps=[CanonicalReasoningStep(text=reasoning_text)], + metadata={"source_protocol": "openai_responses"}, + ) + ) + ) + else: + unsupported_types.append(str(part_type)) + if unsupported_types: + unsupported_str = ", ".join(sorted(set(unsupported_types))) + raise OpenAIChatAdapterError( + f"Unsupported Responses content block type(s): {unsupported_str}", + "invalid_request_error", + "unsupported_content_block", + ) + return blocks + + def _responses_reasoning_item_to_text(self, item: dict[str, Any]) -> str: + content = 
item.get("content") + if isinstance(content, list): + chunks = [] + for part in content: + if isinstance(part, dict) and part.get("type") in {"reasoning_text", "summary_text", "text"}: + chunks.append(str(part.get("text", ""))) + if chunks: + return "\n".join(chunk for chunk in chunks if chunk) + summary = item.get("summary") + if isinstance(summary, list): + chunks = [str(part.get("text", "")) for part in summary if isinstance(part, dict)] + if chunks: + return "\n".join(chunk for chunk in chunks if chunk) + return str(item.get("text", "")) + + def _responses_tools_to_canonical(self, tools: list[dict[str, Any]] | None) -> list[CanonicalToolDefinition]: + if not tools: + return [] + canonical_tools = [] + for tool in tools: + if tool.get("type") != "function": + continue + tool_name = str(tool.get("name", "")) + if tool_name in self._disabled_tool_names: + continue + canonical_tools.append( + CanonicalToolDefinition( + name=tool_name, + description=tool.get("description"), + parameters_json_schema=tool.get("parameters", {}), + metadata={"source_protocol": "openai_responses"}, + ) + ) + return canonical_tools + + def _responses_tool_choice_to_canonical( + self, tool_choice: str | dict[str, Any] | None + ) -> CanonicalToolChoice | None: + if tool_choice is None: + return None + if isinstance(tool_choice, str): + return CanonicalToolChoice(type=tool_choice) + if tool_choice.get("type") == "function": + return CanonicalToolChoice( + type="specific", + tool_name=tool_choice.get("name"), + metadata={"source_protocol": "openai_responses"}, + ) + return CanonicalToolChoice( + type=str(tool_choice.get("type", "auto")), + metadata={"source_protocol": "openai_responses"}, + ) + + def _canonical_response_to_responses_output_items( + self, + response: CanonicalGenerateResponse, + ) -> list[dict[str, Any]]: + output_items: list[dict[str, Any]] = [] + for block in response.output.content: + if isinstance(block, CanonicalTextBlock): + output_items.append( + { + "id": f"msg_{uuid4().hex}", + "type": "message", + "status": "completed", + "role": "assistant", + "content": [{"type": "output_text", "text": block.text, "annotations": []}], + } + ) + elif isinstance(block, CanonicalToolCallBlock): + output_items.append( + { + "id": f"fc_{uuid4().hex}", + "type": "function_call", + "status": "completed", + "call_id": block.tool_call.id, + "name": block.tool_call.name, + "arguments": stringify_tool_arguments(block.tool_call), + } + ) + elif isinstance(block, CanonicalToolResultBlock): + output_items.append( + { + "id": f"fco_{uuid4().hex}", + "type": "function_call_output", + "call_id": block.tool_result.tool_call_id, + "output": block.tool_result.output + if block.tool_result.output is not None + else block.tool_result.output_text, + } + ) + elif isinstance(block, CanonicalReasoningBlock): + reasoning_text = "\n".join(step.text for step in block.reasoning.steps if step.text).strip() + if reasoning_text: + output_items.append( + { + "id": f"rs_{uuid4().hex}", + "type": "reasoning", + "summary": [{"type": "summary_text", "text": reasoning_text}], + } + ) + return output_items + + def _parse_json_string_or_mapping(self, value: Any) -> Any: + if isinstance(value, str): + try: + return json.loads(value) + except Exception: + return {"raw": value} + return value or {} diff --git a/xtuner/v1/rl/gateway/adapters/streaming.py b/xtuner/v1/rl/gateway/adapters/streaming.py new file mode 100644 index 0000000000..41fad73cb7 --- /dev/null +++ b/xtuner/v1/rl/gateway/adapters/streaming.py @@ -0,0 +1,34 @@ +from __future__ import 
annotations + +import json +from typing import Any + +from fastapi.responses import StreamingResponse + + +def encode_sse_event(data: Any, *, event: str | None = None) -> str: + if isinstance(data, str): + payload = data + else: + payload = json.dumps(data, ensure_ascii=False) + + lines: list[str] = [] + if event is not None: + lines.append(f"event: {event}") + if payload: + lines.extend(f"data: {line}" for line in payload.splitlines()) + else: + lines.append("data:") + return "\n".join(lines) + "\n\n" + + +def build_sse_response(event_iterator) -> StreamingResponse: + return StreamingResponse( + event_iterator, + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/xtuner/v1/rl/gateway/adapters/trace.py b/xtuner/v1/rl/gateway/adapters/trace.py new file mode 100644 index 0000000000..dca5ff4e3b --- /dev/null +++ b/xtuner/v1/rl/gateway/adapters/trace.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import hashlib +import threading +import time +from collections import OrderedDict +from collections.abc import Sequence +from copy import deepcopy +from dataclasses import dataclass +from typing import Any + +from pydantic import BaseModel + +from xtuner.v1.data_proto.rl_data import Status + + +DEFAULT_CHAT_TRACE_KEY = "__default__" + + +def build_api_key_trace_key(api_key: str | None) -> str: + if not api_key: + return DEFAULT_CHAT_TRACE_KEY + api_key_hash = hashlib.sha256(api_key.encode("utf-8")).hexdigest()[:16] + return f"api_key_{api_key_hash}" + + +def normalize_trace_payload(value: Any) -> Any: + if isinstance(value, BaseModel): + return normalize_trace_payload(value.model_dump(mode="python", exclude_none=True)) + if isinstance(value, dict): + return { + str(key): normalize_trace_payload(val) + for key, val in sorted(value.items(), key=lambda item: str(item[0])) + if val is not None + } + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + return [normalize_trace_payload(item) for item in value] + return value + + +def snapshot_routed_experts(routed_experts: Any) -> Any: + if routed_experts is None: + return None + try: + import ray + + if isinstance(routed_experts, ray.ObjectRef): + return routed_experts + except Exception: + pass + return deepcopy(routed_experts) + + +@dataclass +class ChatTraceRecord: + trace_key: str + request_snapshot: dict[str, Any] + response_snapshot: dict[str, Any] + prompt_ids: list[int] + response_ids: list[int] + input_text: str + output_text: str + logprobs: list[float] | None + routed_experts: Any + finish_reason: str | None + status: Status + sequence: int = -1 + created_at: float = 0.0 + request_id: str | None = None + + +class ChatTraceStore: + def __init__(self, max_entries: int = 10000): + self._max_entries = max_entries + self._records: OrderedDict[str, OrderedDict[int, ChatTraceRecord]] = OrderedDict() + self._record_order: OrderedDict[tuple[str, int], None] = OrderedDict() + self._next_sequence: dict[str, int] = {} + self._lock = threading.RLock() + + def append(self, record: ChatTraceRecord) -> ChatTraceRecord: + with self._lock: + sequence = self._next_sequence.get(record.trace_key, 0) + self._next_sequence[record.trace_key] = sequence + 1 + record.sequence = sequence + record.created_at = time.time() + records = self._records.setdefault(record.trace_key, OrderedDict()) + records[sequence] = record + self._record_order[(record.trace_key, sequence)] = None + self._evict_if_needed() + return record + + def 
get(self, trace_key: str) -> list[ChatTraceRecord]: + with self._lock: + records = self._records.get(trace_key) + if records is None: + return [] + return list(records.values()) + + def pop(self, trace_key: str) -> list[ChatTraceRecord]: + with self._lock: + records = self._records.pop(trace_key, None) + self._next_sequence.pop(trace_key, None) + if records is None: + return [] + for sequence in records: + self._record_order.pop((trace_key, sequence), None) + return list(records.values()) + + def clear(self, trace_key: str) -> None: + with self._lock: + records = self._records.pop(trace_key, None) + self._next_sequence.pop(trace_key, None) + if records is None: + return + for sequence in records: + self._record_order.pop((trace_key, sequence), None) + + def _evict_if_needed(self) -> None: + while len(self._record_order) > self._max_entries: + (trace_key, sequence), _ = self._record_order.popitem(last=False) + records = self._records.get(trace_key) + if records is None: + continue + records.pop(sequence, None) + if not records: + self._records.pop(trace_key, None) + self._next_sequence.pop(trace_key, None) diff --git a/xtuner/v1/rl/gateway/backend/__init__.py b/xtuner/v1/rl/gateway/backend/__init__.py new file mode 100644 index 0000000000..00a867413f --- /dev/null +++ b/xtuner/v1/rl/gateway/backend/__init__.py @@ -0,0 +1,8 @@ +from .local_backend import LocalRolloutBackend +from .protocol import GatewayBackend + + +__all__ = [ + "GatewayBackend", + "LocalRolloutBackend", +] diff --git a/xtuner/v1/rl/gateway/backend/local_backend.py b/xtuner/v1/rl/gateway/backend/local_backend.py new file mode 100644 index 0000000000..ae8e773cb9 --- /dev/null +++ b/xtuner/v1/rl/gateway/backend/local_backend.py @@ -0,0 +1,429 @@ +from __future__ import annotations + +import json +from typing import Any +from uuid import uuid4 + +import ray +from ray.actor import ActorHandle + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.data_proto.rl_data import RolloutState, RolloutToolCall, SampleParams, Status +from xtuner.v1.rl.rollout.parser.factory import build_tool_call_parser +from xtuner.v1.rl.rollout.worker import RolloutConfig + +from ..adapters.base import coerce_content_to_text +from ..adapters.trace import normalize_trace_payload +from ..core.exceptions import ContextLengthExceededError, ToolCallParseError +from ..core.models import ( + BackendHealth, + CanonicalAssistantTurn, + CanonicalGenerateRequest, + CanonicalGenerateResponse, + CanonicalReasoning, + CanonicalReasoningBlock, + CanonicalReasoningStep, + CanonicalTextBlock, + CanonicalToolCall, + CanonicalToolCallBlock, + CanonicalToolChoice, + CanonicalToolDefinition, + CanonicalToolResultBlock, + CanonicalUsage, + ModelCapabilities, + ModelCard, +) + + +class LocalRolloutBackend: + def __init__( + self, + controller: ActorHandle, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str | None = None, + rollout_config: RolloutConfig | None = None, + ): + self._controller = controller + self._config = rollout_config or self._resolve_rollout_config(controller) + if isinstance(tokenizer, str): + tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + resolved_tokenizer = tokenizer + if resolved_tokenizer is None: + resolved_tokenizer = AutoTokenizer.from_pretrained( + self._config.tokenizer_path, + trust_remote_code=True, + ) + self._tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = resolved_tokenizer + self._tool_call_parser = 
build_tool_call_parser(self._config.tool_call_parser) + + async def generate(self, request: CanonicalGenerateRequest) -> CanonicalGenerateResponse: + rollout_state = self._canonical_request_to_rollout_state(request) + rollout_state = await self._controller.generate.remote(rollout_state) + self._raise_for_failed_rollout(rollout_state, request_id=str(rollout_state.uid)) + return self._rollout_state_to_canonical_response(rollout_state, request) + + async def health(self) -> BackendHealth: + ready, details = await self._controller.get_ready_status.remote() + return BackendHealth( + ready=ready, + status="ready" if ready else "unavailable", + details=details, + ) + + async def list_models(self) -> list[ModelCard]: + return [ + ModelCard( + id=self._model_name, + backend=self._config.rollout_backend, + context_length=self._config.context_length, + ) + ] + + async def get_capabilities(self) -> ModelCapabilities: + return ModelCapabilities( + model=self._model_name, + backend=self._config.rollout_backend, + context_length=self._config.context_length, + supports_stream=True, + supports_tools=True, + supports_cancel=False, + supports_parallel_tool_calls=True, + supports_reasoning=True, + ) + + async def cancel(self, request_id: str) -> dict[str, Any]: + return { + "request_id": request_id, + "cancelled": False, + "status": "not_supported", + } + + @property + def _model_name(self) -> str: + return self._config.model_name or "rollout-controller" + + def _resolve_rollout_config(self, controller: ActorHandle) -> RolloutConfig: + rollout_metadata = ray.get(controller.get_rollout_metadata.remote()) + return rollout_metadata["rollout_config"] + + def _canonical_request_to_rollout_state(self, canonical_request: CanonicalGenerateRequest) -> RolloutState: + internal_messages = self._canonical_messages_to_backend_messages(canonical_request.messages) + rollout_tools = self._canonical_tools_to_backend(canonical_request.tools) + rollout_tool_choice = self._canonical_tool_choice_to_backend(canonical_request.tool_choice) + prompt_ids = self._render_prompt_ids(internal_messages, rollout_tools) + max_tokens = self._fit_max_tokens_to_context(prompt_ids, canonical_request.max_tokens) + return RolloutState( + uid=uuid4().int, + message=internal_messages, + prompt_ids=prompt_ids, + tokens=prompt_ids, + session_uid=canonical_request.metadata.get("session_uid"), + tools=rollout_tools, + tool_choice=rollout_tool_choice, + sample_params=self._build_sample_params(canonical_request, max_tokens=max_tokens), + ) + + def _raise_for_failed_rollout(self, rollout_state: RolloutState, request_id: str) -> None: + if rollout_state.status == Status.FAILED: + raise RuntimeError(rollout_state.error_msg or f"Rollout generation failed for request {request_id}") + + def _rollout_state_to_canonical_response( + self, + rollout_state: RolloutState, + canonical_request: CanonicalGenerateRequest, + ) -> CanonicalGenerateResponse: + request_id = str(rollout_state.uid) + normal_text = rollout_state.response + tool_calls = [ + self._rollout_tool_call_to_canonical(tool_call) for tool_call in (rollout_state.tool_calls or []) + ] + self._raise_for_unparsed_tool_call_markup( + canonical_request=canonical_request, + normal_text=normal_text, + tool_calls=tool_calls, + ) + reasoning_text = None + if isinstance(rollout_state.extra_fields.get("reasoning_text"), str): + reasoning_text = rollout_state.extra_fields.get("reasoning_text") + content_blocks: list[Any] = [] + if reasoning_text: + content_blocks.append( + CanonicalReasoningBlock( + 
reasoning=CanonicalReasoning( + steps=[CanonicalReasoningStep(text=reasoning_text)], + metadata={"source_backend": "local_rollout"}, + ) + ) + ) + if normal_text: + content_blocks.append(CanonicalTextBlock(text=normal_text)) + for tool_call in tool_calls: + content_blocks.append(CanonicalToolCallBlock(tool_call=tool_call)) + + finish_reason = rollout_state.finish_reason or "stop" + if tool_calls and finish_reason == "stop": + finish_reason = "tool_calls" + + prompt_tokens = len(rollout_state.prompt_ids or []) + completion_tokens = self._count_completion_tokens(rollout_state) + metadata = { + "rollout_trace": self._build_rollout_trace_snapshot(rollout_state), + "parallel_tool_calls": canonical_request.parallel_tool_calls, + "source_backend": "local_rollout", + } + return CanonicalGenerateResponse( + request_id=request_id, + model=canonical_request.model or self._model_name, + output=CanonicalAssistantTurn(content=content_blocks), + finish_reason=finish_reason, + usage=CanonicalUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + metadata=metadata, + ) + + def _raise_for_unparsed_tool_call_markup( + self, + *, + canonical_request: CanonicalGenerateRequest, + normal_text: str | None, + tool_calls: list[CanonicalToolCall], + ) -> None: + if self._tool_call_parser is None: + return + if self._tool_call_parser.should_reject_unparsed_markup( + has_tools=bool(canonical_request.tools), + text=normal_text, + parsed_tool_calls=tool_calls, + ): + raise ToolCallParseError( + "Tool-enabled generation returned tool-call markup that could not be parsed into structured " + "tool calls." + ) + + def _canonical_messages_to_backend_messages(self, messages: list[Any]) -> list[dict[str, Any]]: + backend_messages: list[dict[str, Any]] = [] + for message in messages: + if message.role == "tool": + for block in message.content: + if isinstance(block, CanonicalToolResultBlock): + backend_messages.append( + { + "role": "tool", + "content": block.tool_result.output_text + if block.tool_result.output_text is not None + else coerce_content_to_text(block.tool_result.output), + "tool_call_id": block.tool_result.tool_call_id, + } + ) + continue + + text_chunks: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + for block in message.content: + if isinstance(block, CanonicalTextBlock): + if block.text: + text_chunks.append(block.text) + elif isinstance(block, CanonicalReasoningBlock): + reasoning_text = "\n".join(step.text for step in block.reasoning.steps if step.text).strip() + if reasoning_text: + text_chunks.append(reasoning_text) + elif isinstance(block, CanonicalToolCallBlock): + tool_calls.append( + { + "id": block.tool_call.id, + "type": "function", + "function": { + "name": block.tool_call.name, + "arguments": self._render_tool_arguments_for_template(block.tool_call), + }, + } + ) + payload: dict[str, Any] = {"role": message.role, "content": "\n".join(text_chunks)} + if message.name: + payload["name"] = message.name + if tool_calls: + payload["tool_calls"] = tool_calls + backend_messages.append(self._normalize_backend_message(payload)) + return backend_messages + + def _canonical_tools_to_backend(self, tools: list[CanonicalToolDefinition]) -> list[dict[str, Any]] | None: + if not tools: + return None + return normalize_trace_payload( + [ + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description or "", + "parameters": tool.parameters_json_schema, + }, + } + for tool in tools + ] + ) + + def 
_canonical_tool_choice_to_backend(self, tool_choice: CanonicalToolChoice | None) -> Any: + if tool_choice is None: + return None + if tool_choice.type == "specific": + return { + "type": "function", + "function": {"name": tool_choice.tool_name}, + } + return tool_choice.type + + def _render_prompt_ids( + self, + internal_messages: list[dict[str, Any]], + rollout_tools: list[dict[str, Any]] | None, + ) -> list[int] | None: + raw_prompt_ids = self._tokenizer.apply_chat_template( + internal_messages, + tools=rollout_tools, + tokenize=True, + add_generation_prompt=True, + ) + if hasattr(raw_prompt_ids, "get"): + return raw_prompt_ids.get("input_ids") + return list(raw_prompt_ids) + + def _build_sample_params( + self, + canonical_request: CanonicalGenerateRequest, + *, + max_tokens: int | None, + ) -> SampleParams: + kwargs = { + "return_token_ids": True, + "return_logprob": True, + "stream": canonical_request.stream, + "stops": canonical_request.stop, + **{ + key: value + for key, value in { + "n": canonical_request.metadata.get("n"), + "max_tokens": max_tokens if max_tokens is not None else canonical_request.max_tokens, + "temperature": canonical_request.temperature, + "top_p": canonical_request.top_p, + "top_k": canonical_request.metadata.get("top_k"), + "repetition_penalty": canonical_request.metadata.get("repetition_penalty"), + "presence_penalty": canonical_request.metadata.get("presence_penalty"), + "frequency_penalty": canonical_request.metadata.get("frequency_penalty"), + "min_tokens": canonical_request.metadata.get("min_tokens"), + "stop_token_ids": canonical_request.metadata.get("stop_token_ids"), + "skip_special_tokens": canonical_request.metadata.get("skip_special_tokens"), + "no_stop_trim": canonical_request.metadata.get("no_stop_trim"), + "spaces_between_special_tokens": canonical_request.metadata.get("spaces_between_special_tokens"), + "sampling_seed": canonical_request.metadata.get("sampling_seed"), + "return_routed_experts": canonical_request.metadata.get("return_routed_experts"), + }.items() + if value is not None + }, + } + return SampleParams(**kwargs) + + def _fit_max_tokens_to_context( + self, + prompt_ids: list[int] | None, + requested_max_tokens: int | None, + ) -> int | None: + context_length = self._config.context_length + if context_length is None or prompt_ids is None or requested_max_tokens is None: + return requested_max_tokens + prompt_tokens = len(prompt_ids) + available_completion_tokens = context_length - prompt_tokens + if available_completion_tokens <= 0: + raise ContextLengthExceededError(prompt_tokens=prompt_tokens, context_length=context_length) + return min(requested_max_tokens, available_completion_tokens) + + def _count_completion_tokens(self, rollout_state: RolloutState) -> int: + if rollout_state.response_ids is not None: + return len(rollout_state.response_ids) + if rollout_state.response: + return len(self._tokenizer(rollout_state.response, add_special_tokens=False)["input_ids"]) + return 0 + + def _rollout_tool_call_to_canonical(self, tool_call: RolloutToolCall) -> CanonicalToolCall: + return CanonicalToolCall( + id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments, + raw_arguments_text=tool_call.function.raw_arguments_text, + ) + + def _build_rollout_trace_snapshot(self, rollout_state: RolloutState) -> dict[str, Any]: + return { + "session_uid": rollout_state.session_uid, + "status": rollout_state.status.value, + "rollout_finish_reason": rollout_state.finish_reason, + "prompt_ids": 
list(rollout_state.prompt_ids or []),
+            "response_ids": list(rollout_state.response_ids or []),
+            "logprobs": None if rollout_state.logprobs is None else list(rollout_state.logprobs),
+            "routed_experts": normalize_trace_payload(rollout_state.routed_experts),
+            "internal_messages": normalize_trace_payload(rollout_state.message),
+            "rollout_tools": normalize_trace_payload(rollout_state.tools),
+            "rollout_tool_choice": normalize_trace_payload(rollout_state.tool_choice),
+            "rollout_sample_params": normalize_trace_payload(
+                rollout_state.sample_params.model_dump(mode="python", exclude_none=True)
+            ),
+            "input_text": self._decode_prompt_ids(rollout_state),
+            "output_text": self._render_rollout_output_text(rollout_state),
+        }
+
+    def _render_rollout_output_text(self, rollout_state: RolloutState) -> str:
+        parts: list[str] = []
+        if rollout_state.response:
+            parts.append(rollout_state.response)
+        for rollout_tool_call in rollout_state.tool_calls or []:
+            tool_call = self._rollout_tool_call_to_canonical(rollout_tool_call)
+            arguments = self._stringify_tool_arguments(tool_call)
+            # Wrap the stringified call in tool-call markup so output_text mirrors the raw completion.
+            parts.append(f"<tool_call>{arguments}</tool_call>")
+        return "\n".join(parts)
+
+    def _decode_prompt_ids(self, rollout_state: RolloutState) -> str:
+        """Decode prompt token IDs to text without re-running the chat template."""
+        try:
+            return self._tokenizer.decode(rollout_state.prompt_ids or [], skip_special_tokens=False)
+        except Exception:
+            return ""
+
+    def _stringify_tool_arguments(self, tool_call: CanonicalToolCall) -> str:
+        if tool_call.raw_arguments_text is not None:
+            return tool_call.raw_arguments_text
+        if isinstance(tool_call.arguments, str):
+            return tool_call.arguments
+        return json.dumps(tool_call.arguments if tool_call.arguments is not None else {}, ensure_ascii=False)
+
+    def _render_tool_arguments_for_template(self, tool_call: CanonicalToolCall) -> dict[str, Any]:
+        arguments = tool_call.arguments
+        if isinstance(arguments, dict):
+            return arguments
+        if tool_call.raw_arguments_text is not None:
+            try:
+                decoded = json.loads(tool_call.raw_arguments_text)
+            except Exception:
+                return {"raw": tool_call.raw_arguments_text}
+            if isinstance(decoded, dict):
+                return decoded
+            return {"value": decoded}
+        if arguments is None:
+            return {}
+        if isinstance(arguments, str):
+            try:
+                decoded = json.loads(arguments)
+            except Exception:
+                return {"raw": arguments}
+            if isinstance(decoded, dict):
+                return decoded
+            return {"value": decoded}
+        return {"value": arguments}
+
+    def _normalize_backend_message(self, payload: dict[str, Any]) -> dict[str, Any]:
+        """Normalize a backend message dict: remove None values and sort keys."""
+        return {
+            str(key): val for key, val in sorted(payload.items(), key=lambda item: str(item[0])) if val is not None
+        }
diff --git a/xtuner/v1/rl/gateway/backend/protocol.py b/xtuner/v1/rl/gateway/backend/protocol.py
new file mode 100644
index 0000000000..5fbb2f1446
--- /dev/null
+++ b/xtuner/v1/rl/gateway/backend/protocol.py
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from typing import Any, Protocol
+
+from ..core.models import (
+    BackendHealth,
+    CanonicalGenerateRequest,
+    CanonicalGenerateResponse,
+    ModelCapabilities,
+    ModelCard,
+)
+
+
+class GatewayBackend(Protocol):
+    async def generate(self, request: CanonicalGenerateRequest) -> CanonicalGenerateResponse: ...
+
+    async def health(self) -> BackendHealth: ...
+
+    async def list_models(self) -> list[ModelCard]: ...
+
+    async def get_capabilities(self) -> ModelCapabilities: ...
+
+    async def cancel(self, request_id: str) -> dict[str, Any]: ...
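Note: ``GatewayBackend`` is a structural ``typing.Protocol``, so any object that
provides these five async methods can be plugged into the gateway; no
subclassing is required. As a rough sketch (the ``EchoBackend`` class and the
"echo" model id are illustrative only, not part of this diff), a minimal
backend for protocol-level testing might look like::

    from typing import Any

    from xtuner.v1.rl.gateway.core.models import (
        BackendHealth,
        CanonicalAssistantTurn,
        CanonicalGenerateRequest,
        CanonicalGenerateResponse,
        CanonicalTextBlock,
        ModelCapabilities,
        ModelCard,
    )


    class EchoBackend:
        """Toy backend: replies with the last text block seen in the request."""

        async def generate(self, request: CanonicalGenerateRequest) -> CanonicalGenerateResponse:
            last_text = ""
            for message in reversed(request.messages):
                texts = [block.text for block in message.content if isinstance(block, CanonicalTextBlock)]
                if texts:
                    last_text = texts[-1]
                    break
            return CanonicalGenerateResponse(
                request_id=request.request_id,
                model=request.model,
                output=CanonicalAssistantTurn(content=[CanonicalTextBlock(text=last_text)]),
            )

        async def health(self) -> BackendHealth:
            return BackendHealth(ready=True, status="ready")

        async def list_models(self) -> list[ModelCard]:
            return [ModelCard(id="echo", backend="echo")]

        async def get_capabilities(self) -> ModelCapabilities:
            return ModelCapabilities(model="echo", backend="echo")

        async def cancel(self, request_id: str) -> dict[str, Any]:
            return {"request_id": request_id, "cancelled": False, "status": "not_supported"}

Such a backend can then be wired into the full protocol stack with
``build_gateway_app`` (see ``server/app.py`` below).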
diff --git a/xtuner/v1/rl/gateway/config.py b/xtuner/v1/rl/gateway/config.py
new file mode 100644
index 0000000000..235e56b8d2
--- /dev/null
+++ b/xtuner/v1/rl/gateway/config.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+
+
+@dataclass
+class GatewayConfig:
+    """Configuration for the XTuner gateway HTTP server.
+
+    Examples::
+
+        # Auto-start with RolloutController:
+        cfg = GatewayConfig(port=8080)
+
+        # Opt-out of auto-start (start manually later):
+        cfg = GatewayConfig(port=8080, auto_start=False)
+
+        # With request capture (writes one JSONL file per API key):
+        cfg = GatewayConfig(port=8080, capture_folder="/tmp/gateway_captures")
+    """
+
+    _CAPTURE_PATH_FOLDER = "gateway_captures"
+
+    port: int
+    """TCP port to bind the server on."""
+
+    host: str = "0.0.0.0"
+    """Interface to bind the server on."""
+
+    auto_start: bool = True
+    """Whether to start the gateway automatically when the RolloutController
+    initialises.
+
+    Set to False if you want to start the gateway manually via
+    :func:`~xtuner.v1.rl.gateway.serve_gateway` or
+    :meth:`~xtuner.v1.rl.rollout.controller.RolloutController.start_gateway`.
+    """
+
+    capture_folder: str | None = None
+    """Optional folder for writing per-request trace records.
+
+    The gateway writes one JSONL file per API key inside this folder. If
+    omitted, this resolves to ``./worker_dirs/gateway_captures``; when started
+    by :class:`~xtuner.v1.rl.rollout.controller.RolloutController`, an omitted
+    value resolves relative to ``RolloutConfig.worker_log_dir`` instead.
+    """
+
+    title: str = "XTuner Gateway"
+    """FastAPI application title shown in /docs."""
+
+    version: str = "0.1.0"
+    """FastAPI application version string."""
+
+    log_level: str = "warning"
+    """Uvicorn log level (debug/info/warning/error/critical)."""
+
+    def __post_init__(self) -> None:
+        if self.capture_folder is None:
+            self.capture_folder = str(Path.cwd() / "worker_dirs" / self._CAPTURE_PATH_FOLDER)
+            print(f"GatewayConfig.capture_folder is not specified, using default capture_folder: {self.capture_folder}")
diff --git a/xtuner/v1/rl/gateway/core/__init__.py b/xtuner/v1/rl/gateway/core/__init__.py
new file mode 100644
index 0000000000..7e0210c672
--- /dev/null
+++ b/xtuner/v1/rl/gateway/core/__init__.py
@@ -0,0 +1,49 @@
+from .exceptions import ContextLengthExceededError, GatewayError, GatewayStateError, ModelNotFoundError
+from .models import (
+    BackendHealth,
+    CanonicalAssistantTurn,
+    CanonicalContentBlock,
+    CanonicalGenerateRequest,
+    CanonicalGenerateResponse,
+    CanonicalMessage,
+    CanonicalReasoning,
+    CanonicalReasoningBlock,
+    CanonicalReasoningStep,
+    CanonicalTextBlock,
+    CanonicalToolCall,
+    CanonicalToolCallBlock,
+    CanonicalToolChoice,
+    CanonicalToolDefinition,
+    CanonicalToolResult,
+    CanonicalToolResultBlock,
+    CanonicalUsage,
+    ModelCapabilities,
+    ModelCard,
+)
+
+
+__all__ = [
+    "BackendHealth",
+    "CanonicalAssistantTurn",
+    "CanonicalContentBlock",
+    "CanonicalGenerateRequest",
+    "CanonicalGenerateResponse",
+    "CanonicalMessage",
+    "CanonicalReasoning",
+    "CanonicalReasoningBlock",
+    "CanonicalReasoningStep",
+    "CanonicalTextBlock",
+    "CanonicalToolCall",
+    "CanonicalToolCallBlock",
+    "CanonicalToolChoice",
+    "CanonicalToolDefinition",
+    "CanonicalToolResult",
+    "CanonicalToolResultBlock",
+    "CanonicalUsage",
+    "ContextLengthExceededError",
+    "GatewayError",
+    "GatewayStateError",
+    "ModelCapabilities",
+    "ModelCard",
+    "ModelNotFoundError",
+]
diff --git a/xtuner/v1/rl/gateway/core/exceptions.py
b/xtuner/v1/rl/gateway/core/exceptions.py new file mode 100644 index 0000000000..7d52f2e626 --- /dev/null +++ b/xtuner/v1/rl/gateway/core/exceptions.py @@ -0,0 +1,28 @@ +class GatewayError(RuntimeError): + """Base exception for gateway failures.""" + + +class GatewayStateError(GatewayError): + """Raised when the gateway app is missing required runtime state.""" + + +class ModelNotFoundError(GatewayError): + """Raised when a requested model is not exposed by the backend.""" + + def __init__(self, model: str): + super().__init__(f"Model '{model}' is not available.") + self.model = model + + +class ContextLengthExceededError(GatewayError): + """Raised when the prompt is too long for the model's context window.""" + + def __init__(self, prompt_tokens: int, context_length: int): + super().__init__(f"Input is too long: prompt_tokens={prompt_tokens}, context_length={context_length}.") + self.prompt_tokens = prompt_tokens + self.context_length = context_length + + +class ToolCallParseError(GatewayError): + """Raised when a tool-enabled response contains tool-call markup that could + not be parsed into structured tool calls.""" diff --git a/xtuner/v1/rl/gateway/core/models.py b/xtuner/v1/rl/gateway/core/models.py new file mode 100644 index 0000000000..976011752a --- /dev/null +++ b/xtuner/v1/rl/gateway/core/models.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +from typing import Annotated, Any, Literal, TypeAlias + +from pydantic import BaseModel, ConfigDict, Field, model_validator + + +class GatewayCoreModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class CanonicalToolDefinition(GatewayCoreModel): + name: str + description: str | None = None + parameters_json_schema: dict[str, Any] = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalToolChoice(GatewayCoreModel): + type: Literal["auto", "none", "required", "specific"] = "auto" + tool_name: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode="after") + def validate_specific_choice(self) -> CanonicalToolChoice: + if self.type == "specific" and not self.tool_name: + raise ValueError("tool_name is required when tool choice type is 'specific'.") + return self + + +class CanonicalToolCall(GatewayCoreModel): + id: str + name: str + arguments: Any = None + raw_arguments_text: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalToolResult(GatewayCoreModel): + tool_call_id: str + name: str | None = None + output: Any = None + output_text: str | None = None + is_error: bool = False + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalReasoningStep(GatewayCoreModel): + text: str + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalReasoning(GatewayCoreModel): + steps: list[CanonicalReasoningStep] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalTextBlock(GatewayCoreModel): + type: Literal["text"] = "text" + text: str + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalToolCallBlock(GatewayCoreModel): + type: Literal["tool_call"] = "tool_call" + tool_call: CanonicalToolCall + + +class CanonicalToolResultBlock(GatewayCoreModel): + type: Literal["tool_result"] = "tool_result" + tool_result: CanonicalToolResult + + +class CanonicalReasoningBlock(GatewayCoreModel): + type: Literal["reasoning"] = "reasoning" + reasoning: CanonicalReasoning + + 
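+# The four block models above are combined below into a discriminated union
+# keyed on ``type``. As an illustration only (the weather-related values are
+# made up), an assistant turn that reasons, answers, and issues a tool call
+# would be represented as:
+#
+#   CanonicalAssistantTurn(content=[
+#       CanonicalReasoningBlock(reasoning=CanonicalReasoning(
+#           steps=[CanonicalReasoningStep(text="Need the forecast first.")])),
+#       CanonicalTextBlock(text="Let me check the weather."),
+#       CanonicalToolCallBlock(tool_call=CanonicalToolCall(
+#           id="call_1", name="get_weather", arguments={"city": "Paris"})),
+#   ])
+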
+CanonicalContentBlock: TypeAlias = Annotated[ + CanonicalTextBlock | CanonicalToolCallBlock | CanonicalToolResultBlock | CanonicalReasoningBlock, + Field(discriminator="type"), +] + + +class CanonicalMessage(GatewayCoreModel): + role: Literal["system", "user", "assistant", "tool"] + content: list[CanonicalContentBlock] = Field(default_factory=list) + name: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalGenerateRequest(GatewayCoreModel): + request_id: str + model: str + messages: list[CanonicalMessage] = Field(default_factory=list) + tools: list[CanonicalToolDefinition] = Field(default_factory=list) + tool_choice: CanonicalToolChoice | None = None + parallel_tool_calls: bool | None = None + temperature: float | None = None + top_p: float | None = None + max_tokens: int | None = None + stop: list[str] = Field(default_factory=list) + stream: bool = False + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalUsage(GatewayCoreModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + + +class CanonicalAssistantTurn(GatewayCoreModel): + role: Literal["assistant"] = "assistant" + content: list[CanonicalContentBlock] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class CanonicalGenerateResponse(GatewayCoreModel): + request_id: str + model: str + output: CanonicalAssistantTurn + finish_reason: str = "stop" + usage: CanonicalUsage = Field(default_factory=CanonicalUsage) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ModelCard(GatewayCoreModel): + id: str + backend: str + context_length: int | None = None + owned_by: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ModelCapabilities(GatewayCoreModel): + model: str + backend: str + context_length: int | None = None + supports_stream: bool = False + supports_tools: bool = False + supports_cancel: bool = False + supports_parallel_tool_calls: bool = False + supports_reasoning: bool = False + metadata: dict[str, Any] = Field(default_factory=dict) + + +class BackendHealth(GatewayCoreModel): + ready: bool + status: Literal["ready", "degraded", "unavailable"] + details: dict[str, Any] = Field(default_factory=dict) + reason: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) diff --git a/xtuner/v1/rl/gateway/server/__init__.py b/xtuner/v1/rl/gateway/server/__init__.py new file mode 100644 index 0000000000..722863cd83 --- /dev/null +++ b/xtuner/v1/rl/gateway/server/__init__.py @@ -0,0 +1,16 @@ +from .app import ( + build_gateway_app, + build_local_gateway_app, + serve_gateway, + serve_gateway_in_thread, + wait_for_gateway_ready, +) + + +__all__ = [ + "build_gateway_app", + "build_local_gateway_app", + "serve_gateway", + "serve_gateway_in_thread", + "wait_for_gateway_ready", +] diff --git a/xtuner/v1/rl/gateway/server/app.py b/xtuner/v1/rl/gateway/server/app.py new file mode 100644 index 0000000000..a7ca63389a --- /dev/null +++ b/xtuner/v1/rl/gateway/server/app.py @@ -0,0 +1,274 @@ +from __future__ import annotations + +import socket +import threading +import time +from typing import Union + +import httpx +import ray +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from ray.actor import ActorHandle + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.rl.rollout.worker import RolloutConfig + +from ..adapters import AnthropicChatAdapter, 
ChatTraceStore, OpenAIChatAdapter +from ..adapters.responses import OpenAIResponsesAdapter +from ..backend.local_backend import LocalRolloutBackend +from ..backend.protocol import GatewayBackend +from ..config import GatewayConfig +from ..core.exceptions import ContextLengthExceededError, GatewayError, ToolCallParseError +from .routes import ( + build_anthropic_router, + build_openai_router, + build_responses_router, + build_runtime_router, + build_trace_store_router, +) + + +TokenizerArg = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, str] + + +# --------------------------------------------------------------------------- +# Internal base builder +# --------------------------------------------------------------------------- + + +def _create_base_gateway_app( + backend: GatewayBackend, + *, + title: str = "XTuner Gateway", + version: str = "0.1.0", +) -> FastAPI: + """Create the base FastAPI app with runtime routes and global error + handlers. + + This is an internal builder used by higher-level factory functions. The returned app exposes /livez, /readyz, and + /capabilities but no protocol-specific endpoints. + """ + app = FastAPI(title=title, version=version) + app.state.gateway_backend = backend + app.include_router(build_runtime_router()) + + @app.exception_handler(ContextLengthExceededError) + async def context_length_error_handler(request: Request, exc: ContextLengthExceededError) -> JSONResponse: + return JSONResponse( + status_code=400, + content={"error": {"message": str(exc), "type": "context_length_exceeded", "code": "context_too_long"}}, + ) + + @app.exception_handler(GatewayError) + async def gateway_error_handler(request: Request, exc: GatewayError) -> JSONResponse: + return JSONResponse( + status_code=500, + content={"error": {"message": str(exc), "type": type(exc).__name__, "code": "gateway_error"}}, + ) + + @app.exception_handler(ToolCallParseError) + async def tool_call_parse_error_handler(request: Request, exc: ToolCallParseError) -> JSONResponse: + return JSONResponse( + status_code=400, + content={"error": {"message": str(exc), "type": "tool_call_parse_error", "code": "tool_call_parse_error"}}, + ) + + @app.exception_handler(Exception) + async def generic_error_handler(request: Request, exc: Exception) -> JSONResponse: + return JSONResponse( + status_code=500, + content={"error": {"message": str(exc), "type": "internal_error", "code": "internal_server_error"}}, + ) + + return app + + +# --------------------------------------------------------------------------- +# Generic public factory (works with any GatewayBackend) +# --------------------------------------------------------------------------- + + +def build_gateway_app( + backend: GatewayBackend, + *, + tokenizer: TokenizerArg, + model_name: str, + context_length: int, + config: GatewayConfig | None = None, +) -> FastAPI: + """Build a gateway FastAPI app wired to *any* :class:`GatewayBackend`. + + This is the lowest-level public factory. Use this when you have a custom + backend (e.g. a future ``RemoteRolloutBackend``) and want to wire it into + the full gateway stack (OpenAI / Anthropic / Responses endpoints). + + Args: + backend: An object that satisfies the :class:`~xtuner.v1.rl.gateway.backend.protocol.GatewayBackend` protocol. + tokenizer: Tokenizer used for prompt encoding and token-count helpers. 
+ Accepts a :class:`~transformers.PreTrainedTokenizer`, + :class:`~transformers.PreTrainedTokenizerFast`, or a **string** + path/identifier which will be loaded via + :func:`~transformers.AutoTokenizer.from_pretrained`. + model_name: Default model name reported by the ``/capabilities`` endpoint. + context_length: Maximum context length enforced by the gateway. + config: Gateway configuration (title, version, capture_folder, ...). + Defaults to a bare :class:`~xtuner.v1.rl.gateway.config.GatewayConfig` + with ``port=8080`` when not provided. + + Returns: + A fully-configured :class:`fastapi.FastAPI` instance ready to serve. + """ + if isinstance(tokenizer, str): + tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + cfg = config or GatewayConfig(port=8080) + app = _create_base_gateway_app(backend, title=cfg.title, version=cfg.version) + app.state.gateway_trace_store = ChatTraceStore() + adapter_kwargs = { + "generate_handler": backend.generate, + "tokenizer": tokenizer, + "default_model_name": model_name, + "context_length": context_length, + "capture_folder": cfg.capture_folder, + "trace_store": app.state.gateway_trace_store, + } + app.state.gateway_openai_adapter = OpenAIChatAdapter(**adapter_kwargs) + app.state.gateway_anthropic_adapter = AnthropicChatAdapter(**adapter_kwargs) + app.state.gateway_responses_adapter = OpenAIResponsesAdapter(**adapter_kwargs) + app.include_router(build_openai_router()) + app.include_router(build_anthropic_router()) + app.include_router(build_responses_router()) + app.include_router(build_trace_store_router()) + return app + + +# --------------------------------------------------------------------------- +# LocalRolloutBackend convenience factory +# --------------------------------------------------------------------------- + + +def build_local_gateway_app( + controller: ActorHandle, + config: GatewayConfig | None = None, + rollout_config: RolloutConfig | None = None, +) -> FastAPI: + """Build a gateway app backed by a Ray-actor RolloutController.""" + cfg = config or GatewayConfig(port=8080) + if rollout_config is None: + rollout_metadata = ray.get(controller.get_rollout_metadata.remote()) + rollout_config = rollout_metadata["rollout_config"] + tokenizer = AutoTokenizer.from_pretrained(rollout_config.tokenizer_path, trust_remote_code=True) + + model_name = rollout_config.model_name + if model_name is None: + raise ValueError("controller.config.model_name must be set when building a local gateway app") + context_length = rollout_config.context_length + if context_length is None: + raise ValueError("controller.config.context_length must be set when building a local gateway app") + + backend = LocalRolloutBackend(controller, tokenizer=tokenizer, rollout_config=rollout_config) + return build_gateway_app( + backend, + tokenizer=tokenizer, + model_name=model_name, + context_length=context_length, + config=cfg, + ) + + +# --------------------------------------------------------------------------- +# Standalone serve helpers +# --------------------------------------------------------------------------- + + +def serve_gateway(app: FastAPI, config: GatewayConfig) -> None: + """Start the gateway server in the **current thread** (blocking). 
+ + Use this for a fully standalone gateway process:: + + from xtuner.v1.rl.gateway import ( + GatewayConfig, build_local_gateway_app, serve_gateway + ) + + config = GatewayConfig(port=8080, auto_start=False) + app = build_local_gateway_app(controller, config) + serve_gateway(app, config) # blocks until interrupted + + For a custom backend:: + + from xtuner.v1.rl.gateway import ( + GatewayConfig, build_gateway_app, serve_gateway + ) + + config = GatewayConfig(port=8080) + app = build_gateway_app( + my_backend, + tokenizer=tokenizer, + model_name="my-model", + context_length=32768, + config=config, + ) + serve_gateway(app, config) + + Args: + app: A FastAPI application previously built by :func:`build_gateway_app` + or :func:`build_local_gateway_app`. + config: Gateway configuration supplying ``host``, ``port``, and + ``log_level``. + """ + _ensure_gateway_port_available(config) + uvicorn.run(app, host=config.host, port=config.port, log_level=config.log_level) + + +def serve_gateway_in_thread(app: FastAPI, config: GatewayConfig) -> threading.Thread: + """Start the gateway server in a **daemon thread** (non-blocking). + + Returns the :class:`threading.Thread` that is running uvicorn so callers + can monitor it if needed. The thread is daemonised so it will not prevent + the process from exiting. + + Args: + app: A FastAPI application previously built by :func:`build_gateway_app` + or :func:`build_local_gateway_app`. + config: Gateway configuration supplying ``host``, ``port``, and + ``log_level``. + + Returns: + The started daemon thread. + """ + thread = threading.Thread( + target=serve_gateway, + args=(app, config), + daemon=True, + name="gateway-server", + ) + thread.start() + return thread + + +def wait_for_gateway_ready(base_url: str, *, timeout_seconds: float = 180.0) -> None: + """Block until a gateway server responds successfully on ``/livez``.""" + deadline = time.time() + timeout_seconds + last_error = None + while time.time() < deadline: + try: + response = httpx.get(f"{base_url}/livez", timeout=5.0) + if response.status_code == 200: + return + last_error = response.text + except Exception as exc: + last_error = repr(exc) + time.sleep(1.0) + raise AssertionError(f"Gateway did not become ready at {base_url}: {last_error}") + + +def _ensure_gateway_port_available(config: GatewayConfig) -> None: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind((config.host, config.port)) + return + except OSError: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind((config.host, 0)) + config.port = int(sock.getsockname()[1]) diff --git a/xtuner/v1/rl/gateway/server/routes.py b/xtuner/v1/rl/gateway/server/routes.py new file mode 100644 index 0000000000..b80217170f --- /dev/null +++ b/xtuner/v1/rl/gateway/server/routes.py @@ -0,0 +1,287 @@ +from __future__ import annotations + +from enum import Enum +from typing import Any, cast + +from fastapi import APIRouter, Depends, Query, Request +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from ..adapters import ( + AnthropicChatAdapter, + AnthropicChatAdapterError, + AnthropicMessagesRequest, + ChatCompletionRequest, + ChatTraceRecord, + ChatTraceStore, + OpenAIChatAdapter, + OpenAIChatAdapterError, + ResponsesRequest, + build_api_key_trace_key, +) +from ..adapters.responses import OpenAIResponsesAdapter +from ..backend.protocol import GatewayBackend +from ..core.exceptions import GatewayStateError + + +def get_openai_adapter(request: Request) -> OpenAIChatAdapter: + 
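+    """FastAPI dependency returning the OpenAI adapter that
+    ``build_gateway_app`` stored on ``app.state.gateway_openai_adapter``."""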
adapter = getattr(request.app.state, "gateway_openai_adapter", None) + if adapter is None: + raise GatewayStateError("Gateway OpenAI adapter is not configured.") + return cast(OpenAIChatAdapter, adapter) + + +def get_anthropic_adapter(request: Request) -> AnthropicChatAdapter: + adapter = getattr(request.app.state, "gateway_anthropic_adapter", None) + if adapter is None: + raise GatewayStateError("Gateway Anthropic adapter is not configured.") + return cast(AnthropicChatAdapter, adapter) + + +def get_responses_adapter(request: Request) -> OpenAIResponsesAdapter: + adapter = getattr(request.app.state, "gateway_responses_adapter", None) + if adapter is None: + raise GatewayStateError("Gateway Responses adapter is not configured.") + return cast(OpenAIResponsesAdapter, adapter) + + +def extract_api_key(request: Request) -> str | None: + authorization = request.headers.get("authorization") + if authorization: + scheme, _, credentials = authorization.partition(" ") + if scheme.lower() == "bearer" and credentials.strip(): + return credentials.strip() + if authorization.strip(): + return authorization.strip() + + api_key = request.headers.get("x-api-key") or request.headers.get("api-key") + if api_key and api_key.strip(): + return api_key.strip() + return None + + +# --------------------------------------------------------------------------- +# Runtime router (/livez, /readyz, /capabilities) +# --------------------------------------------------------------------------- + + +def build_runtime_router() -> APIRouter: + router = APIRouter() + + @router.get("/livez") + async def livez() -> dict[str, str]: + return {"status": "ok"} + + @router.get("/readyz") + async def readyz(request: Request): + backend = _get_backend(request) + health = await backend.health() + payload = health.model_dump(mode="json") + if health.ready: + return payload + return JSONResponse(status_code=503, content=payload) + + @router.get("/capabilities") + async def get_capabilities(request: Request): + backend = _get_backend(request) + capabilities = await backend.get_capabilities() + return capabilities.model_dump(mode="json") + + return router + + +def _get_backend(request: Request) -> GatewayBackend: + backend = getattr(request.app.state, "gateway_backend", None) + if backend is None: + raise GatewayStateError("Gateway backend is not configured.") + return cast(GatewayBackend, backend) + + +# --------------------------------------------------------------------------- +# Trace store router (/trace_store) +# --------------------------------------------------------------------------- + + +def build_trace_store_router() -> APIRouter: + router = APIRouter() + + @router.get("/trace_store") + async def get_trace_records( + request: Request, + trace_key: str | None = Query(default=None), + ) -> dict: + trace_store = _get_trace_store(request) + resolved_trace_key = _resolve_trace_key(request, trace_key) + records = trace_store.get(resolved_trace_key) + return _build_trace_store_response(resolved_trace_key, records) + + @router.post("/trace_store/pop") + async def pop_trace_records( + request: Request, + trace_key: str | None = Query(default=None), + ) -> dict: + trace_store = _get_trace_store(request) + resolved_trace_key = _resolve_trace_key(request, trace_key) + records = trace_store.pop(resolved_trace_key) + return _build_trace_store_response(resolved_trace_key, records) + + @router.post("/trace_store/clear") + async def clear_trace_records( + request: Request, + trace_key: str | None = Query(default=None), + ) -> dict: + 
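+        """Drop all trace records for the resolved trace key: the explicit
+        ``trace_key`` query parameter if given, else the caller's API-key hash."""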
trace_store = _get_trace_store(request) + resolved_trace_key = _resolve_trace_key(request, trace_key) + trace_store.clear(resolved_trace_key) + return { + "trace_key": resolved_trace_key, + "cleared": True, + } + + return router + + +def _get_trace_store(request: Request) -> ChatTraceStore: + trace_store = getattr(request.app.state, "gateway_trace_store", None) + if trace_store is None: + raise GatewayStateError("Gateway trace store is not configured.") + return cast(ChatTraceStore, trace_store) + + +def _resolve_trace_key(request: Request, trace_key: str | None) -> str: + if trace_key: + return trace_key + return build_api_key_trace_key(extract_api_key(request)) + + +def _build_trace_store_response(trace_key: str, records: list[ChatTraceRecord]) -> dict[str, Any]: + return { + "trace_key": trace_key, + "count": len(records), + "records": [_serialize_trace_record(record) for record in records], + } + + +def _serialize_trace_record(record: ChatTraceRecord) -> dict[str, Any]: + return { + "trace_key": record.trace_key, + "request_snapshot": _serialize_trace_value(record.request_snapshot), + "response_snapshot": _serialize_trace_value(record.response_snapshot), + "prompt_ids": list(record.prompt_ids), + "response_ids": list(record.response_ids), + "input_text": record.input_text, + "output_text": record.output_text, + "logprobs": _serialize_trace_value(record.logprobs), + "routed_experts": _serialize_trace_value(record.routed_experts), + "finish_reason": record.finish_reason, + "status": _serialize_trace_value(record.status), + "sequence": record.sequence, + "created_at": record.created_at, + "request_id": record.request_id, + } + + +def _serialize_trace_value(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, BaseModel): + try: + return _serialize_trace_value(value.model_dump(mode="json", exclude_none=True)) + except Exception: + return _serialize_trace_value(value.model_dump(mode="python", exclude_none=True)) + if isinstance(value, Enum): + return value.value + if isinstance(value, dict): + return {str(key): _serialize_trace_value(val) for key, val in value.items() if val is not None} + if isinstance(value, (list, tuple, set)): + return [_serialize_trace_value(item) for item in value] + if isinstance(value, bytes): + return value.decode("utf-8", errors="replace") + try: + import ray + + if isinstance(value, ray.ObjectRef): + return str(value) + except Exception: + pass + if hasattr(value, "tolist"): + try: + return _serialize_trace_value(value.tolist()) + except Exception: + pass + return str(value) + + +# --------------------------------------------------------------------------- +# OpenAI Chat Completions router (/v1/chat/completions) +# --------------------------------------------------------------------------- + + +def build_openai_router() -> APIRouter: + router = APIRouter() + + @router.post("/v1/chat/completions") + async def chat_completions( + request_body: ChatCompletionRequest, + request: Request, + adapter: OpenAIChatAdapter = Depends(get_openai_adapter), + ): + try: + return await adapter.chat(request_body, api_key=extract_api_key(request)) + except OpenAIChatAdapterError as exc: + return JSONResponse( + status_code=400 if exc.error_type == "invalid_request_error" else 500, + content={"error": {"message": exc.message, "type": exc.error_type, "code": exc.code}}, + ) + + return router + + +# --------------------------------------------------------------------------- +# Anthropic Messages router 
(/v1/messages) +# --------------------------------------------------------------------------- + + +def build_anthropic_router() -> APIRouter: + router = APIRouter() + + @router.post("/v1/messages") + async def messages( + request_body: AnthropicMessagesRequest, + request: Request, + adapter: AnthropicChatAdapter = Depends(get_anthropic_adapter), + ): + try: + return await adapter.messages(request_body, api_key=extract_api_key(request)) + except AnthropicChatAdapterError as exc: + return JSONResponse( + status_code=400 if exc.error_type == "invalid_request_error" else 500, + content={"type": "error", "error": {"type": exc.error_type, "message": exc.message}}, + ) + + return router + + +# --------------------------------------------------------------------------- +# OpenAI Responses router (/v1/responses) +# --------------------------------------------------------------------------- + + +def build_responses_router() -> APIRouter: + router = APIRouter() + + @router.post("/v1/responses") + async def responses( + request_body: ResponsesRequest, + request: Request, + adapter: OpenAIResponsesAdapter = Depends(get_responses_adapter), + ): + try: + return await adapter.responses(request_body, api_key=extract_api_key(request)) + except OpenAIChatAdapterError as exc: + return JSONResponse( + status_code=400 if exc.error_type == "invalid_request_error" else 500, + content={"error": {"message": exc.message, "type": exc.error_type, "code": exc.code}}, + ) + + return router diff --git a/xtuner/v1/rl/grpo/__init__.py b/xtuner/v1/rl/grpo/__init__.py deleted file mode 100644 index cdcff8f252..0000000000 --- a/xtuner/v1/rl/grpo/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .loss import GRPOLossConfig, GRPOLossContext - - -__all__ = [ - "GRPOLossConfig", - "GRPOLossContext", -] diff --git a/xtuner/v1/rl/judger/__init__.py b/xtuner/v1/rl/judger/__init__.py new file mode 100644 index 0000000000..92c5c59cf1 --- /dev/null +++ b/xtuner/v1/rl/judger/__init__.py @@ -0,0 +1,22 @@ +from .compass_verifier_v2 import CompassVerifierV2Config +from .composed import ( + ComposedJudger, + ComposedJudgerConfig, + default_merge_fn, + default_select_fn, +) +from .dapo_math import DapoMathJudgerConfig +from .factory import ( + build_judger, +) +from .geo3k import GEO3KJudgerConfig +from .gsm8k import GSM8KJudgerConfig +from .native import ( + Judger, + JudgerConfig, + JudgerPool, + NativeJudger, + RayJudger, + RayJudgerProxy, + RemoteJudger, +) diff --git a/xtuner/v1/ray/judger/compass_verifier_v2.py b/xtuner/v1/rl/judger/compass_verifier_v2.py similarity index 60% rename from xtuner/v1/ray/judger/compass_verifier_v2.py rename to xtuner/v1/rl/judger/compass_verifier_v2.py index 52fe5648ce..75ae6d9077 100644 --- a/xtuner/v1/ray/judger/compass_verifier_v2.py +++ b/xtuner/v1/rl/judger/compass_verifier_v2.py @@ -1,14 +1,13 @@ import asyncio import random -from typing import List import aiohttp -import ray import requests # type: ignore[import-untyped] -from ray.util.placement_group import PlacementGroup +from pydantic import ConfigDict -from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLJudgerResponseItem -from xtuner.v1.ray.judger.native import NativeJudgerConfig +from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.rl.judger.native import Judger, JudgerConfig +from xtuner.v1.utils.type_helper import ray_method verify_prompt = """ @@ -38,64 +37,55 @@ """ -class CompassVerifierV2: - """Base class for judgers, providing a standard interface for executing a - judging process, which can be either a 
local function or a remote service.
-
-    The judger orchestrates a three-step pipeline:
-    1. Pre-process the input data.
-    2. Execute the core logic (local function or remote HTTP call).
-    3. Post-process the result.
-    """
-
+class CompassVerifierV2(Judger):
     def __init__(
         self,
-        hosts=[],
+        hosts: list[str],
         request_timeout: float = 30.0,
         max_retries: int = 3,
-        thinking_finish_words=["</think>", "**Final Answer**", "</answer>"],
+        thinking_finish_words: list[str] | None = None,
     ):
+        if not hosts:
+            raise ValueError("CompassVerifierV2 requires at least one host.")
         self.hosts = hosts
         self.request_timeout = request_timeout
         self.max_retries = max_retries
-        self.thinking_finish_words = thinking_finish_words
+        # NOTE: the "</answer>" marker is an assumed default; adjust the list to
+        # match the reasoning markers your model actually emits.
+        self.thinking_finish_words = thinking_finish_words or ["</think>", "**Final Answer**", "</answer>"]
         self.model_name = requests.get(
             f"http://{self.hosts[0]}/v1/models",
             headers={"Authorization": "Bearer "},
+            timeout=request_timeout,
         ).json()["data"][0]["id"]
         self.judger_name = "compass_verifier_v2"
 
-    async def judge(self, data_item: List[RLDataFlowItem]) -> List[RLJudgerResponseItem]:
-        response_future = [self.judge_single(d) for d in data_item]
-        judger_responses = await asyncio.gather(*response_future)
-        return judger_responses
-
-    async def judge_single(self, data_item: RLDataFlowItem) -> RLJudgerResponseItem:
-        # print(f"[Judger]: input {data_item}")
-        if data_item.env.rollout.finish_reason not in ["finished", "stop"]:
-            return RLJudgerResponseItem(uid=data_item.uid.observation_id, reward={"score": -1})
-        question = data_item.data.messages[-1]["content"]  # type: ignore[index]
-        model_answer = (data_item.env.rollout.response or "").replace("<|im_end|>", "").strip()
+    @ray_method
+    async def judge(self, rollout_state: RolloutState) -> RolloutState:  # type: ignore[override]
+        if rollout_state.status != Status.COMPLETED or rollout_state.response is None:
+            rollout_state.reward = {"score": -1}
+            return rollout_state
+
+        question = rollout_state.message[-1]["content"]
+        model_answer = rollout_state.response.replace("<|im_end|>", "").strip()
         for thinking_finish_word in self.thinking_finish_words:
             if thinking_finish_word in model_answer:
                 model_answer = model_answer.split(thinking_finish_word)[-1]
-        # only keep last 10 lines
-        num_lines = len(model_answer.split("\n"))
-        if num_lines > 10:
-            model_answer = "\n".join(model_answer.split("\n")[-10:])
-
+        answer_lines = model_answer.split("\n")
+        if len(answer_lines) > 10:
+            model_answer = "\n".join(answer_lines[-10:])
         if len(model_answer) > 1000:
             model_answer = model_answer[-1000:]
-        label = data_item.data.reward_model["ground_truth"]  # type: ignore[index]
-        outcome_reward = await self._judge_with_llm(question, model_answer, label)
-        # print(f"[Judger]: final reward {final_reward}")
-        return RLJudgerResponseItem(uid=data_item.uid.observation_id, reward={"score": outcome_reward})
+        assert rollout_state.reward_model is not None and "ground_truth" in rollout_state.reward_model, (
+            "RolloutState must have reward_model with 'ground_truth' for CompassVerifierV2."
+        )
+        outcome_reward = await self._judge_with_llm(question, model_answer, rollout_state.reward_model["ground_truth"])
+        rollout_state.reward = {"score": outcome_reward}
+        return rollout_state
 
     async def _judge_with_llm(self, question: str, model_response: str, label: str):
         headers = {"Content-Type": "application/json"}
-        prompt = verify_prompt.format("", "", question=question, llm_response=model_response, gold_answer=label)
+        prompt = verify_prompt.format(question=question, llm_response=model_response, gold_answer=label)
         data = {
             "model": self.model_name,
             "messages": [{"role": "user", "content": prompt}],
@@ -113,45 +103,31 @@ async def _judge_with_llm(self, question: str, model_response: str, label: str):
                     response_json = await response.json()
                     if response.status != 200:
                         error_msg = response_json.get("error", {}).get("message", "Unknown error")
-                        raise Exception(f"API request failed with status {response.status}: {error_msg}")
+                        raise RuntimeError(f"API request failed with status {response.status}: {error_msg}")
                     res_str = response_json["choices"][0]["message"]["content"]
-                    if res_str.strip() == "A":
-                        return 1
-                    else:
-                        return -1
+                    return 1 if res_str.strip() == "A" else -1
             except Exception as e:
                 await asyncio.sleep(1)
                 print(f"[Judger]: Error try {i}: {str(e)}")
         raise RuntimeError(f"Cannot connect to judger service for {self.max_retries} times.")
 
     def get_judger_name(self) -> str:
-        """Get the name of the judger.
-
-        Returns:
-            str: The name of the judger.
-        """
         return self.judger_name
 
 
-class CompassVerifierV2Config(NativeJudgerConfig):
-    """Configuration for the CompassVerifierV2 judger."""
+class CompassVerifierV2Config(JudgerConfig):
+    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
 
-    hosts: list
+    hosts: list[str]
     judger_name: str = "compass_verifier_v2"
-
-    def build_actor(self, pg: PlacementGroup, start_bundle_idx: int) -> List[ray.actor.ActorClass]:
-        workers_list = []
-        for idx in range(self.num_ray_actors):
-            bundle_idx = start_bundle_idx + idx
-            pg_options = {"num_cpus": pg.bundle_specs[bundle_idx].get("CPU", 1)}
-            worker = (
-                ray.remote(CompassVerifierV2)
-                .options(
-                    placement_group=pg,
-                    placement_group_bundle_index=bundle_idx,
-                    **pg_options,
-                )
-                .remote(hosts=self.hosts)
-            )
-            workers_list.append(worker)
-        return workers_list  # type: ignore[return-value]
+    request_timeout: float = 30.0
+    max_retries: int = 3
+    thinking_finish_words: list[str] = ["</think>", "**Final Answer**", "</answer>"]
+
+    def build_local(self) -> CompassVerifierV2:
+        return CompassVerifierV2(
+            hosts=self.hosts,
+            request_timeout=self.request_timeout,
+            max_retries=self.max_retries,
+            thinking_finish_words=self.thinking_finish_words,
+        )
diff --git a/xtuner/v1/rl/judger/composed.py b/xtuner/v1/rl/judger/composed.py
new file mode 100644
index 0000000000..c4aff7e78a
--- /dev/null
+++ b/xtuner/v1/rl/judger/composed.py
@@ -0,0 +1,128 @@
+from __future__ import annotations
+
+from typing import Callable, TypeAlias
+
+from pydantic import BaseModel, ConfigDict, Field
+from ray.util.placement_group import PlacementGroup
+
+from xtuner.v1.data_proto.rl_data import RolloutState
+
+from .native import Judger, JudgerConfig
+
+
+SelectedJudgerKeys: TypeAlias = str | list[str] | None
+JudgerSelectFn: TypeAlias = Callable[[RolloutState, dict[str, Judger]], SelectedJudgerKeys]
+JudgerMergeFn: TypeAlias = Callable[[RolloutState, dict[str, RolloutState]], RolloutState]
+
+
+def default_select_fn(rollout_state: RolloutState, branches: dict[str, Judger]) -> SelectedJudgerKeys:
+    """Default branch selector for
``ComposedJudgerConfig``. + + Selection order is intentionally simple: + 1. If ``rollout_state.data_source`` is a string and matches a branch key, use it. + 2. Otherwise return ``None`` and let ``default_key`` or the single-branch fallback decide. + + Users with task-specific routing logic should pass a custom ``select_fn`` instead of extending + this default heuristic. + """ + data_source = rollout_state.data_source + if isinstance(data_source, str) and data_source in branches: + return data_source + + return None + + +def default_merge_fn(original: RolloutState, judged: dict[str, RolloutState]) -> RolloutState: + """Default merger for ``ComposedJudgerConfig``. + + This merger intentionally does not combine multiple judger scores into a single aggregated value. + It writes the merged reward as ``{branch_name: score}``, where ``branch_name`` is the selected + key from ``ComposedJudgerConfig.branches`` and ``score`` is taken from each child judger's + ``reward["score"]``. + + Users who need weighted sums, richer reward payloads, or custom post-processing should provide + their own ``merge_fn``. + """ + merged = original.model_copy(deep=True) + merged.reward = {} + + for name, state in judged.items(): + if state.reward is None or "score" not in state.reward: + raise KeyError(f"Default merge_fn requires reward['score'] for branch {name!r}.") + merged.reward[name] = state.reward["score"] + + return merged + + +class ComposedJudger(Judger): + def __init__( + self, + branches: dict[str, Judger], + select_fn: JudgerSelectFn = default_select_fn, + merge_fn: JudgerMergeFn = default_merge_fn, + default_key: str | None = "default", + ): + if not branches: + raise ValueError("ComposedJudger requires at least one branch.") + self.branches = branches + self.select_fn = select_fn + self.merge_fn = merge_fn + self.default_key = default_key + + def _resolve_selected_keys(self, rollout_state: RolloutState) -> list[str]: + selected = self.select_fn(rollout_state, self.branches) + + if selected is None: + selected_keys: list[str] = [] + elif isinstance(selected, str): + selected_keys = [selected] + else: + selected_keys = list(dict.fromkeys(selected)) + + if not selected_keys: + if self.default_key is not None and self.default_key in self.branches: + return [self.default_key] + if len(self.branches) == 1: + return [next(iter(self.branches))] + raise KeyError( + f"ComposedJudger could not select a branch for task_name={rollout_state.task_name!r}, " + f"data_source={rollout_state.data_source!r}, available={sorted(self.branches)}" + ) + return selected_keys + + async def judge(self, rollout_state: RolloutState) -> RolloutState: + selected_keys = self._resolve_selected_keys(rollout_state) + + judged: dict[str, RolloutState] = {} + for key in selected_keys: + if key not in self.branches: + raise KeyError(f"Unknown judger branch: {key}, available={sorted(self.branches)}") + judged[key] = await self.branches[key].judge(rollout_state.model_copy(deep=True)) + return self.merge_fn(rollout_state, judged) + + +class ComposedJudgerConfig(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + branches: dict[str, JudgerConfigLike] + # ``select_fn`` chooses which branch keys should be executed for one sample. + # Return a single string for single-judger routing, a list of strings for multi-judger execution, + # or ``None`` to fall back to ``default_key`` / single-branch implicit fallback. 
+ select_fn: JudgerSelectFn = Field(default=default_select_fn, exclude=True) + # ``merge_fn`` merges the judged rollout states back into one rollout state. + # The default implementation does not aggregate scores; it writes ``{branch_name: score}``. + merge_fn: JudgerMergeFn | None = Field(default=None, exclude=True) + default_key: str | None = "default" + + def get_num_placement_group_bundles(self) -> int: + return sum(branch.get_num_placement_group_bundles() for branch in self.branches.values()) + + def build(self, pg: PlacementGroup | None = None, start_bundle_idx: int = 0) -> Judger: + from .factory import build_judger + + return build_judger(self, pg=pg, start_bundle_idx=start_bundle_idx) + + +JudgerConfigLike: TypeAlias = JudgerConfig | ComposedJudgerConfig + +ComposedJudgerConfig.model_rebuild() diff --git a/xtuner/v1/ray/judger/dapo_math.py b/xtuner/v1/rl/judger/dapo_math.py similarity index 81% rename from xtuner/v1/ray/judger/dapo_math.py rename to xtuner/v1/rl/judger/dapo_math.py index 285649c1c3..5ddcfdead8 100644 --- a/xtuner/v1/ray/judger/dapo_math.py +++ b/xtuner/v1/rl/judger/dapo_math.py @@ -1,9 +1,9 @@ import re from typing import Any, Callable, List, Optional, Tuple -from pydantic import ConfigDict, Field +from pydantic import Field, model_validator -from .native import NativeJudgerConfig +from .native import JudgerConfig # Adapted from https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/math_dapo.py @@ -290,8 +290,7 @@ def compute_reward(response, label, extra_info): return {"score": reward, "acc": out["acc"]} -class DapoMathJudgerConfig(NativeJudgerConfig): - model_config = ConfigDict(extra="forbid") +class DapoMathJudgerConfig(JudgerConfig): eos_token: List[str] | str enable_overlong_buffer: bool score: int = 1 @@ -300,62 +299,42 @@ class DapoMathJudgerConfig(NativeJudgerConfig): overlong_buffer_len: Optional[int] = None overlong_penalty_factor: Optional[float] = None tokenizer: Any = Field(default=None, exclude=True) - reward_func: Callable = Field(default=compute_reward, exclude=True) - - def __init__( - self, - judger_name: str, - eos_token: List[str] | str, - enable_overlong_buffer: bool, - max_response_len: Optional[int], - overlong_buffer_len: Optional[int], - overlong_penalty_factor: Optional[float], - tokenizer: Any, - score: int = 1, - format_score: int = 0, - ): - if isinstance(eos_token, str): - assert eos_token.strip() != "", "eos_token string must not be empty" - elif isinstance(eos_token, list): - assert all(isinstance(e, str) and e.strip() != "" for e in eos_token), ( + reward_handler: Callable | str = Field(default=compute_reward, exclude=True) + extra_info: dict = Field(default_factory=dict, exclude=True) + + @model_validator(mode="after") + def _pack_extra_info(self) -> "DapoMathJudgerConfig": + if isinstance(self.eos_token, str): + assert self.eos_token.strip() != "", "eos_token string must not be empty" + elif isinstance(self.eos_token, list): + assert all(isinstance(e, str) and e.strip() != "" for e in self.eos_token), ( "All eos_token list elements must be non-empty strings" ) - assert len(eos_token) > 0, "eos_token list must not be empty" + assert len(self.eos_token) > 0, "eos_token list must not be empty" else: raise TypeError("eos_token must be a non-empty string or a non-empty list of strings") - # Initialize the base class - super().__init__( - judger_name=judger_name, - eos_token=eos_token, - enable_overlong_buffer=enable_overlong_buffer, - score=score, - format_score=format_score, - max_response_len=max_response_len, -
overlong_buffer_len=overlong_buffer_len, - overlong_penalty_factor=overlong_penalty_factor, - tokenizer=tokenizer, - ) - - self.extra_info.update( + self.extra_info.update( # type: ignore[attr-defined] { - "eos_token": eos_token, - "score": score, - "format_score": format_score, + "eos_token": self.eos_token, + "score": self.score, + "format_score": self.format_score, } ) - if enable_overlong_buffer: - assert max_response_len is not None - assert overlong_buffer_len is not None - assert overlong_penalty_factor is not None - assert tokenizer is not None - self.extra_info.update( + if self.enable_overlong_buffer: + assert self.max_response_len is not None, "max_response_len is required." + assert self.overlong_buffer_len is not None, "overlong_buffer_len is required." + assert self.overlong_penalty_factor is not None, "overlong_penalty_factor is required." + assert self.tokenizer is not None, "tokenizer is required." + self.extra_info.update( # type: ignore[attr-defined] { - "enable_overlong_buffer": enable_overlong_buffer, - "max_response_len": max_response_len, - "overlong_buffer_len": overlong_buffer_len, - "overlong_penalty_factor": overlong_penalty_factor, - "tokenizer": tokenizer, + "enable_overlong_buffer": self.enable_overlong_buffer, + "max_response_len": self.max_response_len, + "overlong_buffer_len": self.overlong_buffer_len, + "overlong_penalty_factor": self.overlong_penalty_factor, + "tokenizer": self.tokenizer, } ) + + return self diff --git a/xtuner/v1/rl/judger/factory.py b/xtuner/v1/rl/judger/factory.py new file mode 100644 index 0000000000..0cb0cdd997 --- /dev/null +++ b/xtuner/v1/rl/judger/factory.py @@ -0,0 +1,47 @@ +from ray.util.placement_group import PlacementGroup + +from .composed import ComposedJudger, ComposedJudgerConfig, JudgerConfigLike, default_merge_fn +from .native import Judger, JudgerConfig, JudgerPool + + +# +# Use ``JudgerConfig`` when one sample only needs one concrete judger implementation: +# one reward handler, one judger_name, and one execution mode (local or Ray actors). +# +# Use ``ComposedJudgerConfig`` when one sample may need to be routed to different child +# judgers by ``select_fn``, or when you want to run multiple child judgers and merge their +# outputs with ``merge_fn``. 
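# ---------------------------------------------------------------------------
# [Illustrative sketch, not part of this diff] A composed config with two
# hypothetical branches: both run for "mixed" samples and their scores are
# averaged by a custom merge_fn; everything else falls back to default_key.
# ---------------------------------------------------------------------------
from xtuner.v1.data_proto.rl_data import RolloutState
from xtuner.v1.rl.judger import ComposedJudgerConfig, GSM8KJudgerConfig, build_judger

def mean_merge(original: RolloutState, judged: dict[str, RolloutState]) -> RolloutState:
    # Average child scores instead of the default {branch_name: score} payload.
    merged = original.model_copy(deep=True)
    scores = [state.reward["score"] for state in judged.values() if state.reward]
    merged.reward = {"score": sum(scores) / len(scores)}
    return merged

config = ComposedJudgerConfig(
    branches={
        "default": GSM8KJudgerConfig(),  # num_ray_actors=0 -> local NativeJudger
        "strict": GSM8KJudgerConfig(extra_info={"score": 1, "format_score": -1}),
    },
    select_fn=lambda state, branches: list(branches) if state.data_source == "mixed" else None,
    merge_fn=mean_merge,
)
judger = build_judger(config)  # ComposedJudger that routes via select_fn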
+# +def build_judger(config: JudgerConfigLike, pg: PlacementGroup | None = None, start_bundle_idx: int = 0) -> Judger: + if isinstance(config, ComposedJudgerConfig): + return _build_composite_judger(config, pg=pg, start_bundle_idx=start_bundle_idx) + return _build_replicated_judger(config, pg=pg, start_bundle_idx=start_bundle_idx) + + +def _build_replicated_judger(config: JudgerConfig, pg: PlacementGroup | None, start_bundle_idx: int) -> Judger: + if config.num_ray_actors == 0: + return config.build_local() + if config.num_ray_actors == 1: + return config._build_remote_judger(pg=pg, bundle_idx=start_bundle_idx) + return JudgerPool( + replicas=config._build_remote_judgers(pg=pg, start_bundle_idx=start_bundle_idx), + judger_name=config.judger_name, + ) + + +def _build_composite_judger( + config: ComposedJudgerConfig, + pg: PlacementGroup | None, + start_bundle_idx: int, +) -> Judger: + branches: dict[str, Judger] = {} + bundle_idx = start_bundle_idx + for key, branch_config in config.branches.items(): + branches[key] = build_judger(branch_config, pg=pg, start_bundle_idx=bundle_idx) + bundle_idx += branch_config.get_num_placement_group_bundles() + return ComposedJudger( + branches=branches, + select_fn=config.select_fn, + merge_fn=config.merge_fn or default_merge_fn, + default_key=config.default_key, + ) diff --git a/xtuner/v1/ray/judger/geo3k.py b/xtuner/v1/rl/judger/geo3k.py similarity index 90% rename from xtuner/v1/ray/judger/geo3k.py rename to xtuner/v1/rl/judger/geo3k.py index 71e5dd2592..3449c8852b 100644 --- a/xtuner/v1/ray/judger/geo3k.py +++ b/xtuner/v1/rl/judger/geo3k.py @@ -8,7 +8,7 @@ extract_boxed_content = None grade_answer = None -from .native import NativeJudgerConfig +from .native import JudgerConfig def format_reward(predict_str: str) -> float: @@ -35,9 +35,9 @@ def compute_reward(response, label, extra_info) -> dict: return {"score": score, "acc": acc} -class GEO3KJudgerConfig(NativeJudgerConfig): +class GEO3KJudgerConfig(JudgerConfig): """Configuration for the GEO3K judger.""" judger_name: str = "hiyouga/geometry3k" extra_info: dict = {"format_score": 0.1, "use_boxed": True} - reward_func: Callable = compute_reward + reward_handler: Callable | str = compute_reward diff --git a/xtuner/v1/ray/judger/gsm8k.py b/xtuner/v1/rl/judger/gsm8k.py similarity index 95% rename from xtuner/v1/ray/judger/gsm8k.py rename to xtuner/v1/rl/judger/gsm8k.py index 3a22d83783..c125c3f610 100644 --- a/xtuner/v1/ray/judger/gsm8k.py +++ b/xtuner/v1/rl/judger/gsm8k.py @@ -1,7 +1,7 @@ import re from typing import Callable -from .native import NativeJudgerConfig +from .native import JudgerConfig _SOLUTION_CLIP_CHARS = 300 @@ -77,9 +77,9 @@ def compute_reward(response, label, extra_info): return {"score": extra_info["format_score"]} -class GSM8KJudgerConfig(NativeJudgerConfig): +class GSM8KJudgerConfig(JudgerConfig): """Configuration for the GSM8K judger.""" judger_name: str = "openai/gsm8k" extra_info: dict = {"score": 1, "format_score": 0} - reward_func: Callable = compute_reward + reward_handler: Callable | str = compute_reward diff --git a/xtuner/v1/rl/judger/native.py b/xtuner/v1/rl/judger/native.py new file mode 100644 index 0000000000..d0b3699333 --- /dev/null +++ b/xtuner/v1/rl/judger/native.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import asyncio +import inspect +from abc import ABC, abstractmethod +from typing import Callable, TypeAlias, cast + +import httpx +from pydantic import BaseModel, ConfigDict, Field, model_validator +from ray.actor import ActorClass, ActorProxy 
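# ---------------------------------------------------------------------------
# [Illustrative sketch, not part of this diff] Minimal local use of the
# NativeJudger/JudgerConfig pair defined below; the handler name is made up.
# ---------------------------------------------------------------------------
from xtuner.v1.rl.judger import JudgerConfig

def exact_match(response: str, label: str, extra_info: dict) -> dict:
    # NativeJudger.judge() calls the handler with exactly these keyword
    # arguments and stores the returned dict in rollout_state.reward.
    return {"score": 1 if response.strip() == label.strip() else 0}

config = JudgerConfig(judger_name="exact_match", reward_handler=exact_match)
judger = config.build()  # num_ray_actors=0 (default) -> local NativeJudger
# reward_handler may also be a URL string, in which case NativeJudger POSTs
# the same kwargs as JSON and uses the JSON response as the reward dict.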
+from ray.util.placement_group import PlacementGroup + +from xtuner.v1.data_proto.rl_data import RolloutState +from xtuner.v1.rl.utils import CPUActorLauncher +from xtuner.v1.utils.logger import get_logger +from xtuner.v1.utils.type_helper import ray_method + + +logger = get_logger() + + +class Judger(ABC): + @abstractmethod + async def judge(self, rollout_state: RolloutState) -> RolloutState: ... + + +class NativeJudger(Judger): + """Local judger implementation backed by a Python callable or HTTP + endpoint.""" + + def __init__( + self, + judger_name: str = "native_judger", + reward_handler: Callable | str | None = None, + extra_info: dict | None = None, + request_timeout: float = 30.0, + ): + self._judger_name = judger_name + self.extra_info = extra_info or {} + self.reward_handler = reward_handler + self.request_timeout = request_timeout + + @ray_method + async def judge(self, rollout_state: RolloutState) -> RolloutState: # type: ignore[override] + assert rollout_state.response is not None, ( + "RolloutState must have a response for judging. You should detokenize the response_ids in AgentLoop" + ) + assert rollout_state.reward_model is not None and "ground_truth" in rollout_state.reward_model, ( + "RolloutState must have reward_model with 'ground_truth' for judging. You should set reward_model in AgentLoop" + ) + + input_kwargs = { + "response": rollout_state.response, + "label": rollout_state.reward_model["ground_truth"], + "extra_info": {**self.extra_info}, + } + + judger_response = None + if isinstance(self.reward_handler, str): + async with httpx.AsyncClient(timeout=self.request_timeout) as client: + response = await client.post(self.reward_handler, json=input_kwargs) + response.raise_for_status() + judger_response = response.json() + elif callable(self.reward_handler): + if inspect.iscoroutinefunction(self.reward_handler): + judger_response = await self.reward_handler(**input_kwargs) + else: + judger_response = self.reward_handler(**input_kwargs) + + assert judger_response is not None, "Reward handler did not return a response." + assert isinstance(judger_response, dict), ( + f"Reward handler must return a dict, but got {type(judger_response)}." 
+ ) + rollout_state.reward = judger_response + return rollout_state + + def get_judger_name(self) -> str: + return self._judger_name + + +class RemoteJudger(Judger): + def __init__(self, actor: RayJudgerProxy, judger_name: str): + self.actor = actor + self._judger_name = judger_name + + @ray_method + async def judge(self, rollout_state: RolloutState) -> RolloutState: # type: ignore[override] + return await self.actor.judge.remote(rollout_state) + + def get_judger_name(self) -> str: + return self._judger_name + + +class JudgerPool(Judger): + """Round-robin dispatch across replicas of the same judger type.""" + + def __init__(self, replicas: list[Judger], judger_name: str): + if not replicas: + raise ValueError("JudgerPool requires at least one replica.") + self.replicas = replicas + self._judger_name = judger_name + self._rr_index = 0 + self._lock = asyncio.Lock() + self._worker_loads = dict.fromkeys(range(len(replicas)), 0) + + async def _pick_replica(self) -> tuple[int, Judger]: + async with self._lock: + replica_idx = self._rr_index % len(self.replicas) + self._rr_index = (self._rr_index + 1) % len(self.replicas) + self._worker_loads[replica_idx] += 1 + return replica_idx, self.replicas[replica_idx] + + async def _release_replica(self, replica_idx: int) -> None: + async with self._lock: + self._worker_loads[replica_idx] -= 1 + + @ray_method + async def judge(self, rollout_state: RolloutState) -> RolloutState: # type: ignore[override] + replica_idx, replica = await self._pick_replica() + try: + return await replica.judge(rollout_state) + finally: + await self._release_replica(replica_idx) + + def get_worker_status(self) -> dict[str, int]: + return {f"{self._judger_name}[{idx}]": load for idx, load in self._worker_loads.items()} + + def get_judger_name(self) -> str: + return self._judger_name + + +class JudgerConfig(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + judger_name: str + reward_handler: Callable | str | None = Field(default=None, exclude=True) + request_timeout: float = 30.0 + extra_info: dict = Field(default_factory=dict, exclude=True) + num_ray_actors: int = Field(default=0, ge=0, description="0 means local mode, >0 means remote Ray actors.") + num_cpus_per_actor: int = Field(default=1, gt=0, description="CPU cores per remote judger actor.") + cpu_memory_per_actor: int = Field( + default=1024**3, gt=0, description="CPU memory in bytes per remote judger actor." + ) + + @model_validator(mode="after") + def _validate_ray_actor_config(self) -> JudgerConfig: + if self.num_ray_actors == 0: + if self.num_cpus_per_actor != 1 or self.cpu_memory_per_actor != 1024**3: + logger.warning( + "num_cpus_per_actor and cpu_memory_per_actor are ignored when Judger runs in local mode." 
+ ) + return self + + def get_num_placement_group_bundles(self) -> int: + return self.num_ray_actors + + def get_cpu_bundles(self) -> list[dict[str, float | int]]: + return [ + { + "CPU": self.num_cpus_per_actor, + "memory": self.cpu_memory_per_actor, + } + for _ in range(self.get_num_placement_group_bundles()) + ] + + def build_local(self) -> Judger: + return NativeJudger( + judger_name=self.judger_name, + reward_handler=self.reward_handler, + request_timeout=self.request_timeout, + extra_info=self.extra_info, + ) + + def _build_remote_actor(self, pg: PlacementGroup | None = None, bundle_idx: int = 0) -> RayJudgerProxy: + return CPUActorLauncher.build_actor( + JudgerActor, + self, + pg=pg, + bundle_idx=bundle_idx, + actor_num_cpus=self.num_cpus_per_actor, + actor_memory=self.cpu_memory_per_actor, + ) + + def _build_remote_actors( + self, + pg: PlacementGroup | None = None, + start_bundle_idx: int = 0, + num_ray_actors: int | None = None, + ) -> list[RayJudgerProxy]: + actor_count = self.num_ray_actors if num_ray_actors is None else num_ray_actors + return CPUActorLauncher.build_actors( + JudgerActor, + self, + pg=pg, + start_bundle_idx=start_bundle_idx, + num_workers=actor_count, + actor_num_cpus_per_worker=self.num_cpus_per_actor, + actor_memory_per_worker=self.cpu_memory_per_actor, + ) + + def _build_remote_judger(self, pg: PlacementGroup | None = None, bundle_idx: int = 0) -> Judger: + return RemoteJudger(self._build_remote_actor(pg=pg, bundle_idx=bundle_idx), judger_name=self.judger_name) + + def _build_remote_judgers( + self, + pg: PlacementGroup | None = None, + start_bundle_idx: int = 0, + num_ray_actors: int | None = None, + ) -> list[Judger]: + return [ + RemoteJudger(actor, judger_name=self.judger_name) + for actor in self._build_remote_actors( + pg=pg, + start_bundle_idx=start_bundle_idx, + num_ray_actors=num_ray_actors, + ) + ] + + def build(self, pg: PlacementGroup | None = None, start_bundle_idx: int = 0) -> Judger: + from .factory import build_judger + + return build_judger(self, pg=pg, start_bundle_idx=start_bundle_idx) + + +class JudgerActor: + def __init__(self, judger_config: JudgerConfig): + self.judger = judger_config.build_local() + + @ray_method + async def judge(self, rollout_state: RolloutState) -> RolloutState: + return await self.judger.judge(rollout_state) + + +RayJudger = cast(ActorClass[JudgerActor], CPUActorLauncher.to_actor_class(JudgerActor)) +RayJudgerProxy: TypeAlias = ActorProxy[JudgerActor] diff --git a/xtuner/v1/rl/loss/__init__.py b/xtuner/v1/rl/loss/__init__.py new file mode 100644 index 0000000000..550338e9e0 --- /dev/null +++ b/xtuner/v1/rl/loss/__init__.py @@ -0,0 +1,4 @@ +from .base_loss import BaseRLLossConfig, BaseRLLossContext, BaseRLLossKwargs, compute_kl_loss_weight +from .grpo_loss import GRPOLossConfig, GRPOLossContext, GRPOLossKwargs +from .loss_fn import check_config, get_policy_loss_fn, kl_penalty, pg_loss_fn, register_policy_loss, sft_loss_fn +from .oreal_loss import OrealLossConfig, OrealLossContext, OrealLossKwargs diff --git a/xtuner/v1/rl/base/loss.py b/xtuner/v1/rl/loss/base_loss.py similarity index 99% rename from xtuner/v1/rl/base/loss.py rename to xtuner/v1/rl/loss/base_loss.py index 006f30ca75..8ee859a09f 100644 --- a/xtuner/v1/rl/base/loss.py +++ b/xtuner/v1/rl/loss/base_loss.py @@ -6,10 +6,10 @@ from xtuner.v1.loss.ce_loss import CELossConfig, CELossContext, CELossKwargs from xtuner.v1.loss.utils import sp_gather, sp_split -from xtuner.v1.utils.device import get_device # from ..utils import sp_split -from .rollout_is import 
RolloutImportanceSampling +from xtuner.v1.utils.device import get_device DEVICE = get_device() diff --git a/xtuner/v1/rl/grpo/loss.py b/xtuner/v1/rl/loss/grpo_loss.py similarity index 98% rename from xtuner/v1/rl/grpo/loss.py rename to xtuner/v1/rl/loss/grpo_loss.py index ec34c0b156..e9aa607b55 100644 --- a/xtuner/v1/rl/grpo/loss.py +++ b/xtuner/v1/rl/loss/grpo_loss.py @@ -7,14 +7,14 @@ from xtuner.v1.utils import get_logger -from ..base import ( +from ..utils import gather_logprobs +from .base_loss import ( BaseRLLossConfig, BaseRLLossContext, BaseRLLossKwargs, compute_kl_loss_weight, ) -from ..loss_fn import get_policy_loss_fn, kl_penalty -from ..utils import gather_logprobs +from .loss_fn import get_policy_loss_fn, kl_penalty logger = get_logger() diff --git a/xtuner/v1/rl/loss_fn.py b/xtuner/v1/rl/loss/loss_fn.py similarity index 100% rename from xtuner/v1/rl/loss_fn.py rename to xtuner/v1/rl/loss/loss_fn.py diff --git a/xtuner/v1/rl/oreal/loss.py b/xtuner/v1/rl/loss/oreal_loss.py similarity index 98% rename from xtuner/v1/rl/oreal/loss.py rename to xtuner/v1/rl/loss/oreal_loss.py index 4d39ccf0cf..f115e7762b 100644 --- a/xtuner/v1/rl/oreal/loss.py +++ b/xtuner/v1/rl/loss/oreal_loss.py @@ -5,14 +5,14 @@ import torch.distributed as dist import torch.nn.functional as F -from ..base import ( +from ..utils import gather_logprobs +from .base_loss import ( BaseRLLossConfig, BaseRLLossContext, BaseRLLossKwargs, compute_kl_loss_weight, ) -from ..loss_fn import get_policy_loss_fn, kl_penalty, sft_loss_fn -from ..utils import gather_logprobs +from .loss_fn import get_policy_loss_fn, kl_penalty, sft_loss_fn class OrealLossConfig(BaseRLLossConfig): diff --git a/xtuner/v1/rl/oreal/__init__.py b/xtuner/v1/rl/oreal/__init__.py deleted file mode 100644 index 1a15ec3944..0000000000 --- a/xtuner/v1/rl/oreal/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .loss import OrealLossConfig, OrealLossContext - - -__all__ = [ - "OrealLossConfig", - "OrealLossContext", -] diff --git a/xtuner/v1/rl/replay_buffer.py b/xtuner/v1/rl/replay_buffer.py new file mode 100644 index 0000000000..c24c9b9972 --- /dev/null +++ b/xtuner/v1/rl/replay_buffer.py @@ -0,0 +1,551 @@ +import asyncio +from abc import ABC, abstractmethod +from dataclasses import dataclass, fields, is_dataclass, replace +from itertools import count +from pathlib import Path +from typing import Any, List, TypeAlias, Union + +import pandas as pd +import ray +import torch +from pydantic import BaseModel, ConfigDict + +from xtuner.v1.data_proto.rl_data import RolloutState, Status, refresh_seq_staleness, update_group_status +from xtuner.v1.rl.utils import ( + BetweenNode, + ConditionNode, + LogicNode, + LogicOperator, + Operators, + QueryNode, + ScalarNode, + SetNode, + clear_rollout_response_for_rerun, + parse_query, ) +from xtuner.v1.utils import get_logger + + +logger = get_logger(__name__) + + +@dataclass +class StorageItem: + # Stored payload: one group of rollout states. + item: List[RolloutState] + uid: int + timestamp_id: int + task_name: str + status: Status + staleness: int + + +@dataclass +class SerializedRayObjectRef: + value: Any + + +def _snapshot_nested_objectrefs(obj: Any) -> Any: + if isinstance(obj, ray.ObjectRef): + return SerializedRayObjectRef(_snapshot_nested_objectrefs(ray.get(obj))) + if isinstance(obj, BaseModel): + snapshot = obj.model_copy(deep=False) + for field_name in type(obj).model_fields: + setattr(snapshot, field_name, _snapshot_nested_objectrefs(getattr(obj, field_name))) + return snapshot
+ if is_dataclass(obj) and not isinstance(obj, type): + return replace( + obj, + **{field.name: _snapshot_nested_objectrefs(getattr(obj, field.name)) for field in fields(obj)}, + ) + if isinstance(obj, list): + return [_snapshot_nested_objectrefs(value) for value in obj] + if isinstance(obj, tuple): + return tuple(_snapshot_nested_objectrefs(value) for value in obj) + if isinstance(obj, set): + return {_snapshot_nested_objectrefs(value) for value in obj} + if isinstance(obj, dict): + return {key: _snapshot_nested_objectrefs(value) for key, value in obj.items()} + return obj + + +def _restore_nested_objectrefs(obj: Any) -> Any: + if isinstance(obj, SerializedRayObjectRef): + return ray.put(_restore_nested_objectrefs(obj.value)) + if isinstance(obj, BaseModel): + restored = obj.model_copy(deep=False) + for field_name in type(obj).model_fields: + setattr(restored, field_name, _restore_nested_objectrefs(getattr(obj, field_name))) + return restored + if is_dataclass(obj) and not isinstance(obj, type): + return replace( + obj, + **{field.name: _restore_nested_objectrefs(getattr(obj, field.name)) for field in fields(obj)}, + ) + if isinstance(obj, list): + return [_restore_nested_objectrefs(value) for value in obj] + if isinstance(obj, tuple): + return tuple(_restore_nested_objectrefs(value) for value in obj) + if isinstance(obj, set): + return {_restore_nested_objectrefs(value) for value in obj} + if isinstance(obj, dict): + return {key: _restore_nested_objectrefs(value) for key, value in obj.items()} + return obj + + +QUERY_KEYS = [f.name for f in fields(StorageItem)] +QueryKey = Union[str, LogicOperator] # str is a StorageItem field name; LogicOperator is a logical operator such as "$and" or "$or" + +# Query type: +QueryDict: TypeAlias = dict[ + QueryKey, + Union[ + Any, # direct value match, e.g. {"task_name": "math"} + dict[Operators, Any], # operator match, e.g. {"uid": {"$gt": 10}} + List["QueryDict"], # logical combination, e.g. {"$and": [{"a": 1}, {"b": 2}]} + ], +] +QueryType = Union[QueryDict, QueryNode]
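# ---------------------------------------------------------------------------
# [Illustrative sketch, not part of this diff] Query shapes accepted by the
# DSL above; field names come from StorageItem, concrete values are made up.
# ---------------------------------------------------------------------------
q_eq: QueryDict = {"task_name": "math"}                       # direct value match
q_op: QueryDict = {"staleness": {"$gte": 2}}                  # operator match
q_and: QueryDict = {"$and": [{"task_name": "math"}, {"uid": {"$gt": 10}}]}  # logical combination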
+ + +class StorageBackend(ABC): + @abstractmethod + async def put(self, item: StorageItem) -> int: ... + + @abstractmethod + async def get(self, query: QueryType) -> List[StorageItem]: ... + + @abstractmethod + async def count(self, query: QueryType) -> int: ... + + @abstractmethod + async def delete(self, uids: list[int]) -> None: ... + + @abstractmethod + async def update(self, items: list[StorageItem]) -> None: ... + + @abstractmethod + def __len__(self) -> int: ... + + @abstractmethod + def state_dict(self) -> dict[str, Any]: ... + + @abstractmethod + def load_state_dict(self, state: dict[str, Any]) -> None: ... + + +class ReplayPolicy(ABC): + @abstractmethod + async def put(self, item: StorageItem, storage_backend: StorageBackend) -> None: ... + + @abstractmethod + async def get(self, count: int, query: QueryType, storage_backend: StorageBackend) -> list[list[RolloutState]]: ... + + async def count(self, query: QueryType, storage_backend: StorageBackend) -> int: + return await storage_backend.count(query) + + +class NaiveStorage(StorageBackend): + def __init__(self): + self._uid_gen = count(1) + self._timestamp_id_gen = count(1) + self._items: dict[int, StorageItem] = {} + + async def put(self, item: StorageItem) -> int: + uid = next(self._uid_gen) + stored = replace(item, uid=uid, timestamp_id=next(self._timestamp_id_gen)) + self._items[uid] = stored + return uid + + def _evaluate(self, item: StorageItem, query_node: QueryNode) -> bool: + """Filter-tree traversal over native Python objects, as implemented by NaiveStorage.""" + if isinstance(query_node, LogicNode): + if not query_node.conditions: + return query_node.relation == "$and" + + if query_node.relation == "$and": + return all(self._evaluate(item, child) for child in query_node.conditions) + else: + return any(self._evaluate(item, child) for child in query_node.conditions) + + elif isinstance(query_node, ConditionNode): + if query_node.field not in QUERY_KEYS: + raise ValueError(f"Invalid query field: attribute '{query_node.field}' not found. Available attributes: {QUERY_KEYS}") + val = getattr(item, query_node.field) + + if isinstance(query_node, ScalarNode): + if query_node.op == "$eq": + return val == query_node.value + if query_node.op == "$ne": + return val != query_node.value + if query_node.op == "$gt": + return val > query_node.value + if query_node.op == "$gte": + return val >= query_node.value + if query_node.op == "$lt": + return val < query_node.value + if query_node.op == "$lte": + return val <= query_node.value + + elif isinstance(query_node, SetNode): + if query_node.op == "$in": + return val in query_node.value + if query_node.op == "$not_in": + return val not in query_node.value + + elif isinstance(query_node, BetweenNode): + return query_node.lower <= val <= query_node.upper + + return False + + async def get(self, query: QueryType) -> list[StorageItem]: + ast = parse_query(query) + return [item for item in self._items.values() if self._evaluate(item, ast)] + + async def count(self, query: QueryType) -> int: + ast = parse_query(query) + return sum(1 for item in self._items.values() if self._evaluate(item, ast)) + + async def delete(self, uids: list[int]) -> None: + if not uids: + return + for uid in uids: + self._items.pop(uid, None) + + async def update(self, items: list[StorageItem]) -> None: + for item in items: + old_item = self._items.get(item.uid) + if old_item is None: + continue + # Update in place while keeping uid/timestamp_id, so refreshing staleness does not change replay order. + self._items[item.uid] = replace(item, uid=old_item.uid, timestamp_id=old_item.timestamp_id) + + def __len__(self) -> int: + return len(self._items) + + def state_dict(self) -> dict[str, Any]: + max_uid = max(self._items, default=0) + max_timestamp_id = max((item.timestamp_id for item in self._items.values()), default=0) + return { + "items": [_snapshot_nested_objectrefs(item) for item in self._items.values()], + "next_uid": max_uid + 1, + "next_timestamp_id": max_timestamp_id + 1, + } + + def load_state_dict(self, state: dict[str, Any]) -> None: + items: list[StorageItem] = [_restore_nested_objectrefs(item) for item in state["items"]] + self._items = {item.uid: item for item in items} + self._uid_gen = count(state["next_uid"]) + self._timestamp_id_gen = count(state["next_timestamp_id"])
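# ---------------------------------------------------------------------------
# [Illustrative sketch, not part of this diff] Round-trip through NaiveStorage;
# the uid/timestamp_id placeholders passed to put() are replaced by generated ones.
# ---------------------------------------------------------------------------
import asyncio

async def _naive_storage_demo() -> None:
    storage = NaiveStorage()
    await storage.put(
        StorageItem(item=[], uid=0, timestamp_id=0, task_name="math",
                    status=Status.COMPLETED, staleness=0)
    )
    assert await storage.count({"task_name": "math"}) == 1
    rows = await storage.get({"$and": [{"task_name": "math"}, {"staleness": {"$lte": 1}}]})
    await storage.delete([row.uid for row in rows])

asyncio.run(_naive_storage_demo())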
+ + +class PandasStorage(StorageBackend): + def __init__(self): + self._uid_gen = count(1) + self._timestamp_id_gen = count(1) + self._df = pd.DataFrame(columns=["uid", "timestamp_id", "task_name", "status", "staleness", "item"]) + self._buffer: list[dict] = [] + + def _flush_buffer(self): + if self._buffer: + new_df = pd.DataFrame(self._buffer) + self._df = new_df if self._df.empty else pd.concat([self._df, new_df], ignore_index=True) + self._buffer.clear() + + async def put(self, item: StorageItem) -> int: + uid = next(self._uid_gen) + row = { + "uid": uid, + "timestamp_id": next(self._timestamp_id_gen), + "task_name": item.task_name, + "status": item.status, + "staleness": item.staleness, + "item": item.item, + } + self._buffer.append(row) + return uid + + def _evaluate(self, query_node: QueryNode, df: pd.DataFrame) -> pd.Series: + """Vectorized DataFrame filter-tree traversal, as implemented by PandasStorage.""" + if isinstance(query_node, LogicNode): + if not query_node.conditions: + return ( + pd.Series(True, index=df.index) + if query_node.relation == "$and" + else pd.Series(False, index=df.index) + ) + + mask = self._evaluate(query_node.conditions[0], df) + for child in query_node.conditions[1:]: + child_mask = self._evaluate(child, df) + if query_node.relation == "$and": + mask = mask & child_mask + else: + mask = mask | child_mask + return mask + + elif isinstance(query_node, ConditionNode): + field = query_node.field + if field not in QUERY_KEYS: + raise ValueError(f"Invalid query field: attribute '{query_node.field}' not found. Available attributes: {QUERY_KEYS}") + series = df[query_node.field] + + if isinstance(query_node, ScalarNode): + if query_node.op == "$eq": + return series == query_node.value + if query_node.op == "$ne": + return series != query_node.value + if query_node.op == "$gt": + return series > query_node.value + if query_node.op == "$gte": + return series >= query_node.value + if query_node.op == "$lt": + return series < query_node.value + if query_node.op == "$lte": + return series <= query_node.value + + elif isinstance(query_node, SetNode): + if query_node.op == "$in": + return series.isin(query_node.value) + if query_node.op == "$not_in": + return ~series.isin(query_node.value) + + elif isinstance(query_node, BetweenNode): + return series.between(query_node.lower, query_node.upper) + else: + raise ValueError(f"Unsupported query node type: {type(query_node)}") + + async def get(self, query: QueryType) -> list[StorageItem]: + self._flush_buffer() + if self._df.empty: + return [] + + ast = parse_query(query) + filtered_df = self._df[self._evaluate(ast, self._df)] + return [ + StorageItem( + item=row["item"], + uid=row["uid"], + timestamp_id=row["timestamp_id"], + task_name=row["task_name"], + status=row["status"], + staleness=row["staleness"], + ) + for _, row in filtered_df.iterrows() + ] + + async def count(self, query: QueryType) -> int: + self._flush_buffer() + if self._df.empty: + return 0 + ast = parse_query(query) + return int(self._evaluate(ast, self._df).sum()) + + async def delete(self, uids: list[int]) -> None: + self._flush_buffer() + if not uids or self._df.empty: + return + self._df = self._df[~self._df["uid"].isin(uids)] + + async def update(self, items: list[StorageItem]) -> None: + self._flush_buffer() + if not items or self._df.empty: + return + for item in items: + mask = self._df["uid"] == item.uid + if not mask.any(): + continue + for row_idx in self._df.index[mask]: + self._df.at[row_idx, "status"] = item.status + self._df.at[row_idx, "staleness"] = item.staleness + self._df.at[row_idx, "item"] = item.item + + def __len__(self) -> int: + return len(self._df) + len(self._buffer) + + def state_dict(self) -> dict[str, Any]: + self._flush_buffer() + max_uid = int(self._df["uid"].max()) if not self._df.empty else 0 + max_timestamp_id = int(self._df["timestamp_id"].max()) if not self._df.empty else 0 + df = self._df.copy(deep=True) + if not df.empty: + df["item"] = df["item"].map(_snapshot_nested_objectrefs) + return { + "df": df, + "next_uid": max_uid + 1, + "next_timestamp_id": max_timestamp_id + 1, + } + + def load_state_dict(self, state: dict[str, Any]) -> None: + self._df = state["df"].copy(deep=True) + if not self._df.empty: + self._df["item"] = self._df["item"].map(_restore_nested_objectrefs) + self._buffer = [] + self._uid_gen = count(state["next_uid"]) + self._timestamp_id_gen = count(state["next_timestamp_id"])
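# ---------------------------------------------------------------------------
# Design note (not part of this diff): PandasStorage batches put() rows in a
# plain Python list and only concatenates them into the DataFrame when a read
# path (get/count/delete/update/state_dict) calls _flush_buffer(); __len__
# therefore counts both DataFrame rows and still-buffered rows. A sketch:
#
#     storage = PandasStorage()
#     await storage.put(StorageItem(item=[], uid=0, timestamp_id=0,
#                                   task_name="geo", status=Status.COMPLETED, staleness=0))
#     len(storage)                               # 1, counted while still buffered
#     await storage.count({"task_name": "geo"})  # 1, and the buffer is flushed
# ---------------------------------------------------------------------------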
+ + +class FIFOReplayPolicy(ReplayPolicy): + async def put(self, item: StorageItem, storage_backend: StorageBackend) -> None: + if not item.item: + return + await storage_backend.put(item) + + async def get(self, count: int, query: QueryType, storage_backend: StorageBackend) -> list[list[RolloutState]]: + if count <= 0: + return [] + records = await storage_backend.get(query) + records.sort(key=lambda r: r.timestamp_id) + selected = records[:count] + if selected: + await storage_backend.delete([record.uid for record in selected]) + return [record.item for record in selected] + + +class StalenessReplayPolicy(ReplayPolicy): + async def put(self, item: StorageItem, storage_backend: StorageBackend) -> None: + if not item.item: + return + await storage_backend.put(item) + + async def get(self, count: int, query: QueryType, storage_backend: StorageBackend) -> list[list[RolloutState]]: + if count <= 0: + return [] + + records = await storage_backend.get(query) + records.sort(key=lambda r: (-r.staleness, r.timestamp_id)) + selected = records[:count] + if selected: + await storage_backend.delete([record.uid for record in selected]) + return [record.item for record in selected] + + async def count(self, query: QueryType, storage_backend: StorageBackend) -> int: + return await storage_backend.count(query)
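# ---------------------------------------------------------------------------
# Worked example (not part of this diff): given stored groups with
# (uid, staleness, timestamp_id) = (1, 0, 1), (2, 3, 2), (3, 3, 3):
#   FIFOReplayPolicy.get(2, ...)      -> groups 1, 2  (oldest first)
#   StalenessReplayPolicy.get(2, ...) -> groups 2, 3  (most stale first, ties by age)
# Both policies delete what they return, so each group is consumed at most once.
# ---------------------------------------------------------------------------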
+ + +class ReplayBuffer: + _SAVE_PATH = "replay_buffer.pth" + + def __init__(self, policy: ReplayPolicy, storage_backend: StorageBackend): + self._policy = policy + self._storage = storage_backend + self._lock = asyncio.Lock() + + async def put(self, items: list[RolloutState], task_name: str) -> None: + if not items: + return + status = update_group_status(items) + if status == Status.EXPIRED: + for item in items: + clear_rollout_response_for_rerun(item) + storage_item = StorageItem( + item=items, + uid=0, + timestamp_id=0, + task_name=task_name, + status=status, + staleness=max(item.seq_staleness for item in items), + ) + async with self._lock: + await self._policy.put(storage_item, self._storage) + + async def get(self, batch_size: int, task_name: str, group_status: Status) -> list[list[RolloutState]]: + # Query via the DSL dict. + query_dsl: QueryDict = {"$and": [{"task_name": task_name}, {"status": group_status}]} + async with self._lock: + return await self._policy.get(batch_size, query_dsl, self._storage) + + async def count(self, task_name: str, group_status: Status) -> int: + # Query via the DSL dict. + query_dsl: QueryDict = {"$and": [{"task_name": task_name}, {"status": group_status}]} + async with self._lock: + return await self._policy.count(query_dsl, self._storage) + + async def refresh_staleness( + self, + task_name: str, + current_train_step: int, + stale_threshold: int, + statuses: list[Status] | None = None, + ) -> int: + # Refresh the staleness of reusable samples; both completed and aborted groups may come from stale weights and must be retired by train_step. + if stale_threshold <= 0: + raise ValueError(f"stale_threshold must be positive, got {stale_threshold}.") + if statuses is None: + statuses = [Status.COMPLETED, Status.ABORTED] + query_dsl: QueryDict = { + "$and": [ + {"task_name": task_name}, + {"status": {"$in": statuses}}, + ] + } + async with self._lock: + records = await self._storage.get(query_dsl) + updated_records: list[StorageItem] = [] + expired_count = 0 + for record in records: + refresh_seq_staleness(record.item, current_train_step) + staleness = max((getattr(item, "seq_staleness", 0) for item in record.item), default=0) + should_expire = any(getattr(item, "seq_staleness", 0) >= stale_threshold for item in record.item) + if should_expire: + # When a completed / aborted group exceeds the step-level threshold, flip the whole group so the sampler can re-draw it as EXPIRED. + for item in record.item: + clear_rollout_response_for_rerun(item) + item.status = Status.EXPIRED + status = Status.EXPIRED + expired_count += 1 + else: + status = update_group_status(record.item) + updated_records.append(replace(record, status=status, staleness=staleness)) + await self._storage.update(updated_records) + return expired_count + + def __len__(self) -> int: + return len(self._storage) + + async def save(self, path: str | Path) -> None: + file_path = Path(path) + file_path.mkdir(parents=True, exist_ok=True) + replay_buffer_path = file_path / self._SAVE_PATH + async with self._lock: + state = { + "policy": type(self._policy).__name__, + "storage": type(self._storage).__name__, + "storage_state": self._storage.state_dict(), + } + await asyncio.to_thread(torch.save, state, replay_buffer_path) + logger.info(f"Replay buffer saved to {replay_buffer_path}") + + async def resume(self, path: str | Path) -> None: + if len(self._storage) > 0: + raise RuntimeError("Cannot resume into a non-empty buffer") + + file_path = Path(path) + replay_buffer_path = file_path / self._SAVE_PATH + state = await asyncio.to_thread(torch.load, replay_buffer_path, map_location="cpu", weights_only=False) + if state["policy"] != type(self._policy).__name__: + raise ValueError(f"Replay policy mismatch: expected {type(self._policy).__name__}, got {state['policy']}") + + if state["storage"] != type(self._storage).__name__: + raise ValueError( + f"Storage backend mismatch: expected {type(self._storage).__name__}, got {state['storage']}" + ) + + async with self._lock: + self._storage.load_state_dict(state["storage_state"]) + logger.info(f"Replay buffer resumed from {replay_buffer_path}") + + +class SyncReplayBufferConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + def build(self): + return ReplayBuffer(policy=FIFOReplayPolicy(), storage_backend=NaiveStorage()) + + +class AsyncReplayBufferConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + def build(self): + policy = StalenessReplayPolicy() + return ReplayBuffer(policy=policy, storage_backend=NaiveStorage())
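# ---------------------------------------------------------------------------
# [Illustrative sketch, not part of this diff] End-to-end use of the buffer
# above; the task name, step numbers, and checkpoint path are hypothetical.
# ---------------------------------------------------------------------------
import asyncio
from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig

async def _replay_demo(groups: list[list[RolloutState]]) -> None:
    buffer = AsyncReplayBufferConfig().build()  # StalenessReplayPolicy + NaiveStorage
    for group in groups:
        await buffer.put(group, task_name="math")
    # Retire groups whose responses were sampled from weights that are too old.
    await buffer.refresh_staleness("math", current_train_step=10, stale_threshold=4)
    batch = await buffer.get(batch_size=8, task_name="math", group_status=Status.COMPLETED)
    await buffer.save("/tmp/ckpt")  # writes /tmp/ckpt/replay_buffer.pth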
diff --git a/xtuner/v1/ray/rollout/__init__.py b/xtuner/v1/rl/rollout/__init__.py similarity index 75% rename from xtuner/v1/ray/rollout/__init__.py rename to xtuner/v1/rl/rollout/__init__.py index f09429134a..349cd2fad7 100644 --- a/xtuner/v1/ray/rollout/__init__.py +++ b/xtuner/v1/rl/rollout/__init__.py @@ -1,6 +1,6 @@ import os -from .controller import RolloutController, SampleParams +from .controller import RolloutController from .worker import RolloutWorker @@ -10,3 +10,5 @@ from .vllm import vLLMWorker if os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "1": from .lmdeploy import LMDeployWorker + +from .utils import continue_generation, pause_generation diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py new file mode 100644 index 0000000000..3bbc3e9796 --- /dev/null +++ b/xtuner/v1/rl/rollout/controller.py @@ -0,0 +1,485 @@ +import asyncio +import os +import threading +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeAlias, TypedDict +from uuid import uuid4 + +import ray +from ray.actor import ActorProxy +from ray.util.placement_group import PlacementGroup + +from transformers import AutoTokenizer +from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.rl.utils import AutoAcceleratorWorkers +from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger + +from .parser.factory import build_reasoning_parser, build_tool_call_parser +from .parser.reasoning_parser import ReasoningParser +from .parser.tool_parser import ToolCallParser +from .utils import ROLLOUT_RAY_GET_TIMEOUT, RolloutHealthChecker, SessionRouter +from .worker import RolloutConfig, RolloutWorker + + +if TYPE_CHECKING: + from xtuner.v1.rl.gateway.config import GatewayConfig + + +@dataclass +class WorkerInfo: + """A data class to hold all state information for a single worker.""" + + actor: RolloutWorker + url: str + is_active: bool = True + + +class RolloutWorkerMetadata(TypedDict): + """Metadata for rollout workers and their configuration. + + This data structure encapsulates all necessary information about the rollout worker infrastructure, including engine topology, server addresses, and worker status. Used for communication between training processes and rollout workers. """ + + # Topology of the inference engines; each sub-list holds all worker ranks of one engine. + # Example: [[0, 1, 2, 3], [4, 5, 6, 7]] means 2 inference engines with 4 workers each. + # Used to derive the parallel-group partitioning for distributed inference. + engine_rank_mesh_array: List[List[int]] + + # Mapping from worker rank to server URLs, used by training processes to communicate with rollout workers. + # Key: the worker's rank ID (an integer in string form). + # Value: the list of server addresses (usually one URL per rank). + server_url_dict: Dict[str, List[str]] + + # Rollout configuration object holding all inference-engine parameters, + # including parallelism strategy (TP/EP), timeout settings, and the backend type (LMDeploy/vLLM/SGLang). + rollout_config: RolloutConfig + + # Current liveness of each worker server URL. + # Key: the server URL string. + # Value: True if the worker is active, False if it has failed or been deactivated. + worker_server_urls_status: Dict[str, bool] + + # Gateway HTTP server URL (e.g. "http://1.2.3.4:8080"). + # Set after start_gateway() is called; None if the gateway has not been started. + api_server_url: Optional[str] + + +# Keep this as a Ray actor because Ray AgentLoop actors need a shared, cross-process handle to the same controller +# state; passing a normal Python object would serialize a separate copy into each actor. +class RolloutController: + """Controller for managing and coordinating multiple RolloutWorker + actors.""" + + def __init__( + self, + infer_config: RolloutConfig, + placement_group: PlacementGroup, + ): + """Initialize the RolloutController. + + Args: + infer_config (RolloutConfig): The configuration for the rollout. + placement_group (PlacementGroup): The placement group for the + RolloutWorker actors.
+ """ + self.config = infer_config + self.num_gpus_per_engine = ( + self.config.expert_parallel_size + if self.config.expert_parallel_size > 1 + else self.config.tensor_parallel_size + ) + self.logger = get_logger(log_dir=infer_config.worker_log_dir, tag="RolloutController") + self.engine_rank_mesh_array: List[List[int]] = [] + self.worker_server_urls_map: dict[str, List[str]] = {} + self.rank2info: dict[int, WorkerInfo] = {} + self.engine_rank_mesh_array, self.worker_server_urls_map, self.rank2info = self._init_workers(placement_group) + self.num_active_workers = len(self.rank2info) + self.worker_info_lock = threading.RLock() + # The timeout for the environment to wait for the rollout controller's response. + # This should be longer than the controller's internal timeout (`rollout_timeout`) + # to account for potential queuing delays and other overheads. + self.timeout_multiplier = 2.0 + self.router = SessionRouter(self.rank2info, worker_infos_lock=self.worker_info_lock) + self.health_checker = RolloutHealthChecker( + config=self.config, + workers_info=self.rank2info, + worker_infos_lock=self.worker_info_lock, + ) + self.health_checker.start() + self._tool_call_parser, self._reasoning_parser = self._build_output_parsers() + self._gateway_url: str | None = None + + def start_gateway(self, config: "GatewayConfig") -> str | None: + """Start the gateway HTTP server in a daemon thread and return its URL. + + The gateway exposes OpenAI-compatible endpoints that forward requests to + this controller via :class:`~xtuner.v1.rl.gateway.backend.local_backend.LocalRolloutBackend`. + Agent loops (e.g. CamelAgentLoop) discover the URL via :meth:`get_rollout_metadata`. + + Args: + config: Gateway configuration. ``port`` and ``host`` control where + the server binds; ``capture_folder`` enables per-request trace files. + + Returns: + The base URL of the gateway, e.g. ``"http://1.2.3.4:8080"``, or + ``None`` when the configured rollout backend does not support the + gateway. + """ + if self.config.rollout_backend == "sglang": + self.logger.error("XTuner gateway is not supported for SGLang rollout backend yet; skip starting gateway.") + return None + + from xtuner.v1.rl.gateway import build_local_gateway_app, serve_gateway_in_thread + + config.capture_folder = str(Path(self.config.worker_log_dir) / config._CAPTURE_PATH_FOLDER) + app = build_local_gateway_app( + ray.get_runtime_context().current_actor, config=config, rollout_config=self.config + ) + serve_gateway_in_thread(app, config) + host = ray.util.get_node_ip_address() if config.host in ("", "0.0.0.0") else config.host + url = f"http://{host}:{config.port}" + self._gateway_url = url + self.logger.info(f"Gateway server started at {url}, capture_folder: {config.capture_folder}") + return url + + def get_rollout_metadata(self) -> RolloutWorkerMetadata: + """Get information about the current rollout setup. + + Returns: + dict: A dictionary containing the engine mesh list, server URL + dictionary, and the rollout configuration. 
+ """ + with self.worker_info_lock: + worker_server_urls_status = {info.url: info.is_active for info in self.rank2info.values()} + rollout_metadata: RolloutWorkerMetadata = { + "engine_rank_mesh_array": self.engine_rank_mesh_array, + "server_url_dict": self.worker_server_urls_map, + "rollout_config": self.config, + "worker_server_urls_status": worker_server_urls_status, + "api_server_url": self._gateway_url, + } + return rollout_metadata + + def _build_output_parsers(self) -> tuple[ToolCallParser | None, ReasoningParser | None]: + tool_call_parser = None + reasoning_parser = None + + if self.config.tool_call_parser != "none": + tool_call_parser = build_tool_call_parser(self.config.tool_call_parser) + + if self.config.reasoning_parser != "none": + tokenizer_path = self.config.tokenizer_path or self.config.model_path + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) + reasoning_parser = build_reasoning_parser(self.config.reasoning_parser, tokenizer) + + return tool_call_parser, reasoning_parser + + def get_ready_status(self) -> tuple[bool, dict[str, Any]]: + with self.worker_info_lock: + active_workers = sum(1 for info in self.rank2info.values() if info.is_active) + total_workers = len(self.rank2info) + return active_workers > 0, { + "active_workers": active_workers, + "total_workers": total_workers, + } + + async def generate(self, rollout_state: RolloutState) -> RolloutState: + if XTUNER_DETERMINISTIC: + sample_params = rollout_state.sample_params.model_copy(deep=True) + sample_params.sampling_seed = self.config.random_seed + ( + (rollout_state.uid or 0) - (rollout_state.message_uid or 0) + ) + rollout_state.sample_params = sample_params + + session_id = rollout_state.session_uid if rollout_state.session_uid is not None else uuid4().int + worker = await self.router.get_worker(session_id) + if worker is None: + rollout_state.status = Status.FAILED + rollout_state.error_msg = "No active rollout worker available." + return rollout_state + + response_ref = worker.generate.remote(rollout_state=rollout_state) # type: ignore[attr-defined] + try: + response_rollout_state = await asyncio.wait_for( + response_ref, timeout=self.config.rollout_timeout * self.timeout_multiplier + ) + self._apply_output_parsers(response_rollout_state) + return response_rollout_state + except asyncio.TimeoutError: + self.logger.error(f"Rollout timeout for worker {worker}. Skipping sample.") + rollout_state.status = Status.FAILED + rollout_state.error_msg = ( + f"Rollout request timed out after {self.config.rollout_timeout * self.timeout_multiplier} seconds." 
+ ) + return rollout_state + + def _apply_output_parsers(self, rollout_state: RolloutState) -> None: + """Apply tool-call and reasoning parsers to the rollout state in- + place.""" + if self._tool_call_parser is not None: + parsed = self._tool_call_parser.parse(rollout_state) + rollout_state.tool_calls = parsed.tool_calls + rollout_state.response = parsed.remaining_text or None + if self._reasoning_parser is not None: + parsed_reasoning = self._reasoning_parser.parse(rollout_state) + rollout_state.response = parsed_reasoning.remaining_text + if parsed_reasoning.reasoning_text: + rollout_state.extra_fields["reasoning_text"] = parsed_reasoning.reasoning_text + else: + rollout_state.extra_fields.pop("reasoning_text", None) + + def pause_generation(self): + self.health_checker.pause() + + def continue_generation(self): + self.health_checker.resume() + self._broadcast_to_active_workers("continue_generation") + + def offload(self): + self._broadcast_to_active_workers("offload") + + def onload(self): + self._broadcast_to_active_workers("onload_weights") + self._broadcast_to_active_workers("onload_kvcache") + + def onload_weights(self): + self._broadcast_to_active_workers("onload_weights") + + def onload_kvcache(self): + self._broadcast_to_active_workers("onload_kvcache") + + def shutdown(self): + """Shuts down all active rollout workers. + + Args: + block (bool): Whether to block until the operation completes. + """ + self.health_checker.stop() + self._broadcast_to_active_workers("shutdown") + + def recover_failed_workers(self): + """Recovers from worker failures by restarting failed workers and + reinitializing the rollout setup.""" + self.health_checker.pause() + with self.worker_info_lock: + failed_workers = [info for info in self.rank2info.values() if not info.is_active] + if not failed_workers: + self.logger.info("No failed workers detected during recovery.") + return + + self.logger.warning(f"Detected {len(failed_workers)} failed workers. Initiating recovery process.") + for worker in failed_workers: + if self._restart_failed_workers(worker.actor): + with self.worker_info_lock: + rank = self._get_rank_by_actor(worker.actor) + if rank is not None: + self.rank2info[rank].is_active = True + self.health_checker.resume() + + def _restart_failed_workers(self, worker: RolloutWorker) -> bool: + try: + dist_init_addr = ray.get(worker.init_dist_port.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] + _, url = ray.get(worker.init.remote(dist_init_addr), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] + is_healthy = ray.get(worker.check_health.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] + if is_healthy: + self.logger.info(f"Successfully restarted worker {worker} with URL {url}.") + return True + else: + self.logger.error(f"Worker {worker} is still unhealthy after restart.") + return False + except Exception as e: + self.logger.error(f"Failed to restart worker: {e}") + return False + + def _update_dist_init_addr(self, nodes_per_engine, server_urls_per_engine, dist_init_addrs, tp_size): + """Update the distributed initialization addresses for workers. + + This is used to group workers that belong to the same inference engine. + + Args: + nodes_per_engine (int): The number of nodes per inference engine. + server_urls_per_engine (int): The number of server urls per inference engine. + dist_init_addrs (list): The list of initial addresses. + tp_size (int): The tensor parallel size. 
+
+        Returns:
+            list: The updated list of distributed initialization addresses.
+        """
+        # lmdeploy pytorch ep: server_urls_per_engine > 1
+        # sglang cross node engine: nodes_per_engine > 1
+        assert server_urls_per_engine == 1 or nodes_per_engine == 1
+        if nodes_per_engine > 1:
+            index = list(range(0, self.num_active_workers + 1, tp_size)) + [self.num_active_workers]
+            for i in range(1, len(index)):
+                dist_init_addrs[index[i - 1] : index[i]] = [dist_init_addrs[index[i - 1]]] * (index[i] - index[i - 1])
+        if server_urls_per_engine > 1:
+            active_servers = len(dist_init_addrs)
+            for i in range(0, active_servers, server_urls_per_engine):
+                dist_init_addrs[i : i + server_urls_per_engine] = [dist_init_addrs[i]] * server_urls_per_engine
+        return dist_init_addrs
+
+    def _get_active_servers_count(self, infer_config: RolloutConfig, gpu_nums: int):
+        """Calculate the number of active servers and nodes per engine.
+
+        This calculation depends on the inference backend and parallelism settings.
+
+        Args:
+            infer_config (RolloutConfig): The rollout configuration.
+            gpu_nums (int): The total number of GPUs available.
+
+        Returns:
+            Tuple[int, int]: A tuple containing the number of active servers
+            and the number of nodes per engine.
+        """
+        # NOTE: Since different inference engines have different launch methods,
+        # the number of nodes contained in each engine is not consistent.
+        # For example, sglang requires starting an inference engine on each node,
+        # while lmdeploy and vllm do not. Therefore, we calculate the number
+        # of active servers based on the configuration.
+        support_cross_node_comm = infer_config.rollout_cross_node_comm
+        gpus_per_node = infer_config.gpus_per_node
+        nodes_per_engine = (
+            1
+            if support_cross_node_comm or self.num_gpus_per_engine < gpus_per_node
+            else self.num_gpus_per_engine // gpus_per_node
+        )
+
+        active_servers_count = int(
+            (gpu_nums // self.num_gpus_per_engine) * nodes_per_engine * infer_config.server_urls_per_engine
+        )
+        return active_servers_count, nodes_per_engine
+
+    def _broadcast_to_active_workers(self, method_name: str):
+        """Call a method on all active workers and block until every call completes.
+
+        Args:
+            method_name (str): The name of the method to call.
+
+        Returns:
+            A list of results, one per active worker.
+        """
+        futures = []
+        with self.worker_info_lock:
+            active_actors = [info.actor for info in self.rank2info.values() if info.is_active]
+            futures = [getattr(actor, method_name).remote() for actor in active_actors]
+        results = ray.get(futures, timeout=ROLLOUT_RAY_GET_TIMEOUT)
+        return results
+
+    def _get_worker_cls(self):
+        if os.environ.get("XTUNER_USE_LMDEPLOY") == "1":
+            from .lmdeploy import LMDeployWorker
+
+            return ray.remote(LMDeployWorker)
+        elif os.environ.get("XTUNER_USE_VLLM") == "1":
+            from .vllm import vLLMWorker
+
+            return ray.remote(vLLMWorker)
+        elif os.environ.get("XTUNER_USE_SGLANG") == "1":
+            from .sglang import SGLangWorker
+
+            return ray.remote(SGLangWorker)
+        else:
+            raise NotImplementedError(
+                "Rollout backend is not supported. "
+                "Please set the XTUNER_USE_LMDEPLOY, XTUNER_USE_VLLM,"
+                " or XTUNER_USE_SGLANG environment variable."
+            )
+
+    def _get_rank_by_actor(self, actor: RolloutWorker) -> Optional[int]:
+        """Get rank by actor object.
+
+        Args:
+            actor: The RolloutWorker actor object.
+
+        Returns:
+            The rank of the worker, or None if not found.
+        """
+        for rank, info in self.rank2info.items():
+            if info.actor == actor:
+                return rank
+        return None
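# ----------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the patch): the
# active-server arithmetic from `_get_active_servers_count` above, with the
# controller state passed as a plain argument. The numbers are hypothetical.
def active_servers_count(
    gpu_nums: int,
    num_gpus_per_engine: int,
    gpus_per_node: int,
    cross_node_comm: bool,
    server_urls_per_engine: int = 1,
) -> tuple[int, int]:
    # One server per engine, unless an engine spans several nodes without
    # cross-node communication -- then each node needs its own server.
    nodes_per_engine = (
        1
        if cross_node_comm or num_gpus_per_engine < gpus_per_node
        else num_gpus_per_engine // gpus_per_node
    )
    servers = (gpu_nums // num_gpus_per_engine) * nodes_per_engine * server_urls_per_engine
    return servers, nodes_per_engine

# 16 GPUs, one tp=16 engine across two 8-GPU nodes, no cross-node comm:
# the engine needs a server per node -> (2 servers, 2 nodes per engine).
assert active_servers_count(16, 16, 8, False) == (2, 2)
# 16 GPUs, tp=4 engines fit inside one node -> 4 engines, 1 server each.
assert active_servers_count(16, 4, 8, False) == (4, 1)
# ----------------------------------------------------------------------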
+ """ + for rank, info in self.rank2info.items(): + if info.actor == actor: + return rank + return None + + def _update_active_workers_and_urls_map(self, active_rollout_workers, worker_server_urls_map): + """Update the list of active rollout workers and their server URLs. + + When the inference engine is launched across nodes (rollout_cross_node_comm=True), only the worker with + tp_rank=0 in each engine is responsible for receiving input data. Other tp_ranks do not accept input. + Therefore, this function updates active_rollout_workers and worker_server_urls_map to keep only the tp_rank=0 + workers and their corresponding URLs. + """ + if self.config.rollout_cross_node_comm or self.num_gpus_per_engine < self.config.gpus_per_node: + return active_rollout_workers, worker_server_urls_map + else: + active_worker_interval = self.num_gpus_per_engine // self.config.gpus_per_node + active_rank = list(worker_server_urls_map.keys())[::active_worker_interval] + active_worker_server_urls = list(worker_server_urls_map.values())[::active_worker_interval] + return active_rollout_workers[::active_worker_interval], dict(zip(active_rank, active_worker_server_urls)) + + def _init_workers(self, placement_group: PlacementGroup): + """Initializes and configures the pool of RolloutWorker actors. + + This method creates workers from the placement group, configures distributed + inference engines by grouping workers, where each group forms a tensor-parallel + inference engine. It determines the `active_workers` to act as the head of each + engine, constructs the `engine_rank_mesh_array` to define engine topology, + acquires necessary distributed communication ports, and finally launches servers + on the `active_workers` to get their addresses. + + Returns: + Tuple[List, Dict]: A tuple where the first element is + `engine_rank_mesh_array`, a list of lists containing the ranks of workers + in each engine, and the second element is `worker_server_urls_map`, + a dictionary mapping the rank of each active worker to its + corresponding server URL. 
+ """ + # Create workers from placement group + workers, rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group( + self._get_worker_cls(), self.config, placement_group + ) + active_servers_count, nodes_per_engine = self._get_active_servers_count(self.config, len(workers)) + interval = len(workers) // active_servers_count + active_rollout_workers = workers[::interval] + server_urls_per_engine = self.config.server_urls_per_engine + + set_bundle_idxs_objectref = [] + engine_rank_mesh_array = [] + activate_worker_idx = 0 + for active_worker in active_rollout_workers: + head_rank, _ = rank_bundle_idx_list[activate_worker_idx] + engine_workers_meta = rank_bundle_idx_list[head_rank : head_rank + interval] + engine_bundle_idxs = [meta[1] for meta in engine_workers_meta] # meta: (rank, bundle_idx) + set_bundle_idxs_objectref.append(active_worker._set_engine_bundle_idxs.remote(engine_bundle_idxs)) # type: ignore[attr-defined] + engine_rank_mesh_array.append([meta[0] for meta in engine_workers_meta]) + activate_worker_idx += interval + ray.get(set_bundle_idxs_objectref) + # set engine mesh list for each worker + ray.get( + [worker._set_engine_rank_mesh_array.remote(engine_rank_mesh_array) for worker in active_rollout_workers] + ) # type: ignore[attr-defined] + # init dist_init_addr for each worker according to parallel settings + init_dist_init_addrs = ray.get([worker.init_dist_port.remote() for worker in active_rollout_workers]) # type: ignore[attr-defined] + dist_init_addrs = self._update_dist_init_addr( + nodes_per_engine, server_urls_per_engine, init_dist_init_addrs, self.num_gpus_per_engine + ) + # launch rollout servers + worker_server_urls_map = dict( # rank -> url + ray.get([worker.init.remote(dist_init_addrs[i]) for i, worker in enumerate(active_rollout_workers)]) + ) + active_rollout_workers, worker_server_urls_map = self._update_active_workers_and_urls_map( + active_rollout_workers, worker_server_urls_map + ) + workers_info = {} + for i in range(len(active_rollout_workers)): + rank = list(worker_server_urls_map.keys())[i] + url = worker_server_urls_map[rank] + workers_info[rank] = WorkerInfo(actor=active_rollout_workers[i], url=url) + self.logger.info(f"Rollout worker server URLs: {[info.url for info in workers_info.values()]}") + return engine_rank_mesh_array, worker_server_urls_map, workers_info + + +RayRolloutController = ray.remote(RolloutController) +RolloutControllerProxy: TypeAlias = ActorProxy[RayRolloutController] diff --git a/xtuner/v1/ray/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py similarity index 76% rename from xtuner/v1/ray/rollout/lmdeploy.py rename to xtuner/v1/rl/rollout/lmdeploy.py index a0ec513557..f26fe72c3a 100644 --- a/xtuner/v1/ray/rollout/lmdeploy.py +++ b/xtuner/v1/rl/rollout/lmdeploy.py @@ -1,8 +1,7 @@ -import copy import os from argparse import Namespace from itertools import chain -from typing import Any, Dict, List, Union +from typing import Any, Dict, List import ray import requests @@ -10,9 +9,9 @@ from ray.util.placement_group import placement_group_table from transformers import AutoTokenizer -from xtuner.v1.ray.config import RolloutConfig +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams -from .worker import RolloutWorker +from .worker import RolloutConfig, RolloutWorker SHARED_STORE = "shared_store" @@ -81,81 +80,69 @@ def __init__( self.enable_return_routed_experts = self.config.enable_return_routed_experts self.lmdeploy_actor = None - async def _create_request( - self, - url: str, - prompt: Union[str, 
diff --git a/xtuner/v1/ray/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py
similarity index 76%
rename from xtuner/v1/ray/rollout/lmdeploy.py
rename to xtuner/v1/rl/rollout/lmdeploy.py
index a0ec513557..f26fe72c3a 100644
--- a/xtuner/v1/ray/rollout/lmdeploy.py
+++ b/xtuner/v1/rl/rollout/lmdeploy.py
@@ -1,8 +1,7 @@
-import copy
 import os
 from argparse import Namespace
 from itertools import chain
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List
 
 import ray
 import requests
@@ -10,9 +9,9 @@
 from ray.util.placement_group import placement_group_table
 from transformers import AutoTokenizer
 
-from xtuner.v1.ray.config import RolloutConfig
+from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams
 
-from .worker import RolloutWorker
+from .worker import RolloutConfig, RolloutWorker
 
 
 SHARED_STORE = "shared_store"
@@ -81,81 +80,69 @@ def __init__(
         self.enable_return_routed_experts = self.config.enable_return_routed_experts
         self.lmdeploy_actor = None
 
-    async def _create_request(
-        self,
-        url: str,
-        prompt: Union[str, List[Dict[str, Any]]] | None,
-        input_ids: List[int] | None,
-        tools: List,  # reserved for agent tool use
-        tool_choice: str,  # reserved for agent tool use
-        sample_params: dict,
-        extra_params: dict,
-        extra_info: dict,
-    ):
-        """Create and send a streaming generation request to the server.
+    def offload(self):
+        """Offloads the model weights and KV cache."""
+        return self._sleep(level=2)
 
-        Args:
-            url (str): The URL of the generation endpoint.
-            prompt (List[Dict[str, str]]): The input prompt for generation,
-                formatted as a list of messages.
-            tools (List): A list of tools the model can call.
-            tool_choice (str): The tool choice strategy.
-            sample_params (dict): Parameters for sampling. Defaults to {}.
-            extra_params (dict): Extra parameters for the request.
-                Defaults to {}.
+    def onload_weights(self):
+        """Onloads the model weights by waking up the model."""
+        return self._wake_up(tags=["weights"])
 
-        Returns:
-            An httpx.Response object for streaming the response.
-        """
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {self.api_keys}",  # if authentication is required
-        }
-        payload = {
-            "model": self.model_name,
-            "tools": tools if len(tools) > 0 else None,
-            "tool_choice": tool_choice if tool_choice else None,
-        }
-        if "return_token_ids" in extra_params and extra_params["return_token_ids"]:
-            if "image_data" in extra_info:
-                assert input_ids is not None, "input_ids is required when image_data is provided."
-
-            if input_ids is not None:
-                payload["input_ids"] = input_ids
-            if "image_data" in extra_info:
-                payload["image_data"] = extra_info["image_data"]
+    def onload_kvcache(self):
+        """Onloads the KV cache by waking up the model."""
+        return self._wake_up(tags=["kv_cache"])
+
+    def _get_request_payload(self, rollout_state: RolloutState) -> dict:
+        tools = rollout_state.tools
+        tool_choice = rollout_state.tool_choice
+        sample_params = rollout_state.sample_params
+        message = rollout_state.message
+        input_tokens = rollout_state.tokens
+
+        optional_fields: dict[str, object] = {}
+        if tools is not None:
+            optional_fields["tools"] = tools
+        if tool_choice is not None:
+            optional_fields["tool_choice"] = tool_choice
+
+        if sample_params.return_token_ids:
+            payload = {"model": self.model_name, **optional_fields}
+
+            if "image_data" in rollout_state.extra_fields:
+                assert input_tokens is not None, "input_tokens is required when image_data is provided."
+                payload["image_data"] = rollout_state.extra_fields["image_data"]
+
+            if input_tokens is not None:
+                payload["input_ids"] = input_tokens
             else:
-                text_prompt = self.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
+                text_prompt = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
                 prompt_token_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"]
                 payload["input_ids"] = prompt_token_ids
+            sample_params.return_routed_experts = True if self.enable_return_routed_experts else False
+            lmdeploy_sample_params = self._transform_sample_params(sample_params)
+            payload.update(lmdeploy_sample_params)
         else:
-            payload["messages"] = prompt
-
-        if "partial_rollout_input_ids" in extra_info:
-            assert "return_token_ids" in extra_params and extra_params["return_token_ids"], (
-                "concat response_ids and input_ids is only compatible with return_token_ids=True."
- ) - payload["input_ids"] = extra_info["partial_rollout_input_ids"] - assert len(payload["input_ids"]) <= self.config.context_length, ( - f"Total input length {len(payload['input_ids'])} exceeds context length {self.config.context_length}." - ) - - if self.enable_return_routed_experts and not extra_params.get("disable_routed_experts", False): - extra_params["return_routed_experts"] = True - - lmdeploy_sample_params = self._transform_sample_params(sample_params, extra_params) - payload.update(lmdeploy_sample_params) - return await self._safe_post_request(url, headers, payload) - - def get_logprobs(self, input_ids, sampling_params): - """This method will be implemented for the LMDeploy worker in the - future.""" - pass - - def generate(self, input_ids, sampling_params): - """This method will be implemented for the LMDeploy worker in the - future.""" - pass + payload = { + "model": self.model_name, + "messages": rollout_state.message, + **optional_fields, + } + lmdeploy_sample_params = { + "temperature": sample_params.temperature, + "top_p": sample_params.top_p, + "n": sample_params.n, + "stream": sample_params.stream, + "max_tokens": sample_params.max_tokens, + "repetition_penalty": sample_params.repetition_penalty, + "top_k": sample_params.top_k, + "skip_special_tokens": sample_params.skip_special_tokens, + } + if sample_params.stops: + lmdeploy_sample_params["stop"] = sample_params.stops + if sample_params.min_tokens > 0: + lmdeploy_sample_params["min_new_tokens"] = sample_params.min_tokens + payload.update(lmdeploy_sample_params) + return payload def _sleep(self, level: int = 1): """Put the model into a sleep state to save resources. @@ -173,11 +160,7 @@ def _sleep(self, level: int = 1): assert response.status_code == 200, response.status_code return response.text - def offload(self): - """Offloads the model weights and KV cache.""" - return self._sleep(level=2) - - def wake_up(self, tags: List[str] | None = None): + def _wake_up(self, tags: List[str] | None = None): """Wakes up the model from a sleep state. Args: @@ -194,33 +177,12 @@ def wake_up(self, tags: List[str] | None = None): assert response.status_code == 200, response.status_code return response.text - def onload_weights(self): - """Onloads the model weights by waking up the model.""" - return self.wake_up(tags=["weights"]) - - def onload_kvcache(self): - """Onloads the KV cache by waking up the model.""" - return self.wake_up(tags=["kv_cache"]) - - def pause_generation(self): - """It will implemented for LMDeploy worker in the future.""" - pass - - def continue_generation(self): - """It will implemented for LMDeploy worker in the future.""" - pass - - def reset_prefix_cache(self): - """It will implemented for LMDeploy worker in the future.""" - pass - - def _decode_routed_experts(self, routed_experts: Any): + def _decode_routed_experts(self, routed_experts: Any) -> Any: if isinstance(routed_experts, str): if self.lmdeploy_actor is None: self.lmdeploy_actor = ray.get_actor(SHARED_STORE, namespace=SHARED_STORE_NAMESPACE) assert self.lmdeploy_actor is not None, "LMDeploy actor should be available in the shared store." 
- routed_experts_ref = self.lmdeploy_actor.get.remote(routed_experts) - return routed_experts_ref + return self.lmdeploy_actor.get.remote(routed_experts) return torch.tensor(routed_experts) def _transform_rollout_config_to_server_configs(self) -> Namespace: @@ -253,7 +215,7 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: # Therefore, each server only needs to handle 1 / dp_size of the total requests max_batch_size = self.config.rollout_max_batch_size_per_instance // dp_size distributed_executor_backend = lmdeploy_config_kwargs.get("distributed_executor_backend", "ray") - lmdeploy_config_kwargs["log_level"] = lmdeploy_config_kwargs.pop("log_level", "WARNING") + lmdeploy_config_kwargs["log_level"] = lmdeploy_config_kwargs.pop("log_level", "ERROR") lmdeploy_config_kwargs["uvicorn_log_level"] = lmdeploy_config_kwargs.pop("uvicorn_log_level", "ERROR") lmdeploy_config_kwargs["tm_log_level"] = lmdeploy_config_kwargs.pop("tm_log_level", "ERROR") @@ -396,8 +358,5 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: **lmdeploy_config_kwargs, ) - def _transform_sample_params(self, sample_params: Dict, extra_params: Dict = {}): - lmdeploy_sample_params = copy.deepcopy(sample_params) - if extra_params: - lmdeploy_sample_params.update(extra_params) - return lmdeploy_sample_params + def _transform_sample_params(self, sample_params: SampleParams) -> dict: + return sample_params.model_dump(exclude_none=True) diff --git a/xtuner/v1/rl/rollout/parser/__init__.py b/xtuner/v1/rl/rollout/parser/__init__.py new file mode 100644 index 0000000000..a7c3cdf594 --- /dev/null +++ b/xtuner/v1/rl/rollout/parser/__init__.py @@ -0,0 +1,19 @@ +from .factory import build_reasoning_parser, build_tool_call_parser +from .qwen3_reasoning_parser import Qwen3ReasoningParser +from .qwen3_tool_parser import Qwen3ToolCallParser +from .qwen3p5_tool_parser import Qwen3p5ToolCallParser +from .reasoning_parser import ParsedReasoningResult, ReasoningParser +from .tool_parser import ParsedToolCallResult, ToolCallParser + + +__all__ = [ + "ParsedReasoningResult", + "ParsedToolCallResult", + "Qwen3ReasoningParser", + "Qwen3p5ToolCallParser", + "Qwen3ToolCallParser", + "ReasoningParser", + "ToolCallParser", + "build_reasoning_parser", + "build_tool_call_parser", +] diff --git a/xtuner/v1/rl/rollout/parser/factory.py b/xtuner/v1/rl/rollout/parser/factory.py new file mode 100644 index 0000000000..86cf37e4ce --- /dev/null +++ b/xtuner/v1/rl/rollout/parser/factory.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import Literal + +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +from .qwen3_reasoning_parser import Qwen3ReasoningParser, extract_qwen3_reasoning_strip_tokens +from .qwen3_tool_parser import Qwen3ToolCallParser +from .qwen3p5_tool_parser import Qwen3p5ToolCallParser +from .reasoning_parser import ReasoningParser +from .tool_parser import ToolCallParser + + +ToolCallParserName = Literal["none", "qwen3", "qwen3p5"] +ReasoningParserName = Literal["none", "qwen3"] + + +def build_tool_call_parser(parser_name: ToolCallParserName) -> ToolCallParser | None: + if parser_name == "none": + return None + if parser_name == "qwen3": + return Qwen3ToolCallParser() + if parser_name == "qwen3p5": + return Qwen3p5ToolCallParser() + raise ValueError(f"Unsupported tool_call_parser: {parser_name}") + + +def build_reasoning_parser( + parser_name: ReasoningParserName, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, +) -> ReasoningParser | None: + if 
parser_name == "none":
+        return None
+    if parser_name == "qwen3":
+        return Qwen3ReasoningParser(strip_tokens=extract_qwen3_reasoning_strip_tokens(tokenizer))
+    raise ValueError(f"Unsupported reasoning_parser: {parser_name}")
diff --git a/xtuner/v1/rl/rollout/parser/qwen3_reasoning_parser.py b/xtuner/v1/rl/rollout/parser/qwen3_reasoning_parser.py
new file mode 100644
index 0000000000..7ee7dad8b4
--- /dev/null
+++ b/xtuner/v1/rl/rollout/parser/qwen3_reasoning_parser.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+import re
+
+from xtuner.v1.data_proto.rl_data import RolloutState
+
+from .reasoning_parser import ParsedReasoningResult, ReasoningParser
+
+
+class Qwen3ReasoningParser(ReasoningParser):
+    _reasoning_pattern = re.compile(r"<think>\s*(.*?)\s*</think>", re.DOTALL)
+
+    def __init__(self, strip_tokens: list[str] | None = None):
+        self._strip_tokens = strip_tokens or []
+
+    def parse(self, rollout_state: RolloutState) -> ParsedReasoningResult:
+        text = rollout_state.response or ""
+        if not text:
+            return ParsedReasoningResult()
+        cleaned = text
+        for token in self._strip_tokens:
+            cleaned = cleaned.replace(token, "")
+        reasoning_chunks = [
+            match.group(1).strip() for match in self._reasoning_pattern.finditer(cleaned) if match.group(1).strip()
+        ]
+        content = self._reasoning_pattern.sub("", cleaned).strip()
+        if not reasoning_chunks and "<think>" in cleaned:
+            prefix, suffix = cleaned.split("<think>", 1)
+            content = prefix.strip()
+            truncated_reasoning = suffix.replace("</think>", "").strip()
+            if truncated_reasoning:
+                reasoning_chunks.append(truncated_reasoning)
+        elif not reasoning_chunks and "</think>" in cleaned:
+            reasoning_text, content = cleaned.split("</think>", 1)
+            reasoning_text = reasoning_text.strip()
+            if reasoning_text:
+                reasoning_chunks.append(reasoning_text)
+            content = content.strip()
+        reasoning = "\n".join(reasoning_chunks).strip() or None
+        return ParsedReasoningResult(reasoning_text=reasoning, remaining_text=content or None)
+
+
+def extract_qwen3_reasoning_strip_tokens(
+    tokenizer,
+) -> list[str]:
+    strip_tokens: list[str] = []
+
+    eos_token = getattr(tokenizer, "eos_token", None)
+    if isinstance(eos_token, str) and eos_token:
+        strip_tokens.append(eos_token)
+
+    for token in getattr(tokenizer, "additional_special_tokens", []) or []:
+        if not isinstance(token, str):
+            continue
+        lowered = token.lower()
+        if any(marker in lowered for marker in ("im_end", "eot", "end_of_turn", "turn_end")):
+            strip_tokens.append(token)
+
+    return list(dict.fromkeys(strip_tokens))
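# ----------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the patch): what the
# reasoning parser above does on a toy string. The <think> tag literals were
# mangled in this rendering of the patch and are restored above; this demo
# uses the same reconstructed pattern.
import re

pattern = re.compile(r"<think>\s*(.*?)\s*</think>", re.DOTALL)
text = "<think>check units first</think>The answer is 42."
reasoning = pattern.findall(text)
content = pattern.sub("", text).strip()
assert reasoning == ["check units first"]
assert content == "The answer is 42."
# ----------------------------------------------------------------------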
diff --git a/xtuner/v1/rl/rollout/parser/qwen3_tool_parser.py b/xtuner/v1/rl/rollout/parser/qwen3_tool_parser.py
new file mode 100644
index 0000000000..f7eab82539
--- /dev/null
+++ b/xtuner/v1/rl/rollout/parser/qwen3_tool_parser.py
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+import re
+from typing import Any
+from uuid import uuid4
+
+from xtuner.v1.data_proto.rl_data import RolloutToolCall
+
+from .tool_parser import (
+    ParsedToolCallResult,
+    ToolCallParser,
+    build_rollout_tool_call,
+    coerce_parameter_value,
+    parse_json_or_python_mapping,
+)
+
+
+class Qwen3ToolCallParser(ToolCallParser):
+    _tool_call_pattern = re.compile(r"<tool_call>\n*(.*?)</tool_call>", re.DOTALL)
+    _qwen_function_pattern = re.compile(r"<function=([^>\n]+)>(.*?)</function>", re.DOTALL)
+    _qwen_parameter_pattern = re.compile(r"<parameter=([^>\n]+)>(.*?)</parameter>", re.DOTALL)
+    _xml_tag_pattern = re.compile(r"<([a-zA-Z_][^>\n/]*)>(.*?)</\1>", re.DOTALL)
+
+    def parse_text(self, text: str) -> ParsedToolCallResult:
+        if not text:
+            return ParsedToolCallResult()
+        cleaned_text, tool_calls = self._extract_tool_call_tags(text)
+        cleaned_text, qwen_tool_calls = self._extract_qwen_function_calls(cleaned_text)
+        tool_calls.extend(qwen_tool_calls)
+        return ParsedToolCallResult(remaining_text=cleaned_text.strip(), tool_calls=tool_calls)
+
+    def should_reject_unparsed_markup(
+        self,
+        *,
+        has_tools: bool,
+        text: str | None,
+        parsed_tool_calls: list[Any] | None,
+    ) -> bool:
+        if not has_tools:
+            return False
+        if parsed_tool_calls:
+            return False
+        if not text:
+            return False
+        return any(marker in text for marker in ("<tool_call>", "</tool_call>", "<function="))
+
+    def _extract_qwen_function_calls(self, text: str) -> tuple[str, list[RolloutToolCall]]:
+        tool_calls: list[RolloutToolCall] = []
+        text_parts: list[str] = []
+        last_end = 0
+        for match in self._qwen_function_pattern.finditer(text):
+            if match.start() > last_end:
+                text_parts.append(text[last_end : match.start()])
+            parsed_tool_call = self._parse_qwen_function_call(match.group(1).strip(), match.group(2))
+            if parsed_tool_call is None:
+                text_parts.append(match.group(0))
+            else:
+                tool_calls.append(parsed_tool_call)
+            last_end = match.end()
+        if last_end < len(text):
+            text_parts.append(text[last_end:])
+        return "".join(text_parts), tool_calls
+
+    def _parse_single_textual_tool_call(self, raw_payload: str) -> RolloutToolCall | None:
+        payload = parse_json_or_python_mapping(raw_payload)
+        if isinstance(payload, dict) and payload.get("name"):
+            arguments = payload.get("arguments", payload.get("parameters", {}))
+            return build_rollout_tool_call(
+                name=str(payload["name"]),
+                arguments=arguments,
+                call_id=str(payload.get("id") or f"call_{uuid4().hex}"),
+            )
+        function_match = self._qwen_function_pattern.search(raw_payload)
+        if function_match is None:
+            return None
+        return self._parse_qwen_function_call(function_match.group(1).strip(), function_match.group(2))
+
+    def _parse_qwen_function_call(self, function_name: str, function_body: str) -> RolloutToolCall | None:
+        arguments: dict[str, Any] = {}
+        for parameter_match in self._qwen_parameter_pattern.finditer(function_body):
+            param_name = parameter_match.group(1).strip()
+            param_value = parameter_match.group(2).strip()
+            arguments[param_name] = coerce_parameter_value(param_value)
+        if not arguments:
+            for tag_match in self._xml_tag_pattern.finditer(function_body):
+                tag_name = tag_match.group(1).strip()
+                if tag_name.startswith("function="):
+                    continue
+                tag_value = tag_match.group(2).strip()
+                if tag_name in {"path", "file_path"}:
+                    arguments[tag_name] = tag_value
+                else:
+                    arguments[tag_name] = coerce_parameter_value(tag_value)
+        return build_rollout_tool_call(
+            name=function_name,
+            arguments=arguments,
+            call_id=f"call_{uuid4().hex}",
+        )
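# ----------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the patch): the
# reconstructed <tool_call> regex above, exercised on a toy completion.
# (Tag literals were eaten in this rendering of the patch; the standard
# Qwen3 markup is assumed.)
import json
import re

tool_call_pattern = re.compile(r"<tool_call>\n*(.*?)</tool_call>", re.DOTALL)
text = 'Sure.\n<tool_call>\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n</tool_call>'
payloads = [json.loads(m) for m in tool_call_pattern.findall(text)]
assert payloads[0]["name"] == "get_weather"
assert tool_call_pattern.sub("", text).strip() == "Sure."
# ----------------------------------------------------------------------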
diff --git a/xtuner/v1/rl/rollout/parser/qwen3p5_tool_parser.py b/xtuner/v1/rl/rollout/parser/qwen3p5_tool_parser.py
new file mode 100644
index 0000000000..eb39733058
--- /dev/null
+++ b/xtuner/v1/rl/rollout/parser/qwen3p5_tool_parser.py
@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+import re
+from typing import Any
+from uuid import uuid4
+
+from xtuner.v1.data_proto.rl_data import RolloutToolCall
+
+from .tool_parser import ParsedToolCallResult, ToolCallParser, build_rollout_tool_call, coerce_parameter_value
+
+
+class Qwen3p5ToolCallParser(ToolCallParser):
+    _tool_call_pattern = re.compile(r"<tool_call>\n*(.*?)</tool_call>", re.DOTALL)
+    _parameter_pattern = re.compile(r"<parameter=([^>\n]+)>(.*?)</parameter>", re.DOTALL)
+
+    def parse_text(self, text: str) -> ParsedToolCallResult:
+        if not text:
+            return ParsedToolCallResult()
+        cleaned_text, tool_calls = self._extract_tool_call_tags(text)
+        return ParsedToolCallResult(remaining_text=cleaned_text.strip(), tool_calls=tool_calls)
+
+    def should_reject_unparsed_markup(
+        self,
+        *,
+        has_tools: bool,
+        text: str | None,
+        parsed_tool_calls: list[Any] | None,
+    ) -> bool:
+        if not has_tools:
+            return False
+        if parsed_tool_calls:
+            return False
+        if not text:
+            return False
+        return any(marker in text for marker in ("<tool_call>", "</tool_call>", "<function="))
+
+    def _parse_single_textual_tool_call(self, raw_payload: str) -> RolloutToolCall | None:
+        function_name = self._extract_function_name(raw_payload)
+        if not function_name:
+            return None
+
+        arguments: dict[str, Any] = {}
+        for parameter_match in self._parameter_pattern.finditer(raw_payload):
+            parameter_name = parameter_match.group(1).strip()
+            parameter_value = parameter_match.group(2).strip()
+            arguments[parameter_name] = coerce_parameter_value(parameter_value)
+
+        return build_rollout_tool_call(
+            name=function_name,
+            arguments=arguments,
+            call_id=f"call_{uuid4().hex}",
+        )
+
+    def _extract_function_name(self, raw_payload: str) -> str | None:
+        function_start = raw_payload.find("<function=")
+        if function_start == -1:
+            return None
+        name_start = function_start + len("<function=")
+        terminators = [
+            index
+            for index in (raw_payload.find(">", name_start), raw_payload.find("\n", name_start))
+            if index != -1
+        ]
+        if not terminators:
+            return None
+
+        function_name = raw_payload[name_start : min(terminators)].strip()
+        return function_name or None
diff --git a/xtuner/v1/rl/rollout/parser/reasoning_parser.py b/xtuner/v1/rl/rollout/parser/reasoning_parser.py
new file mode 100644
index 0000000000..a6ec85b542
--- /dev/null
+++ b/xtuner/v1/rl/rollout/parser/reasoning_parser.py
@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+from pydantic import BaseModel, ConfigDict
+
+from xtuner.v1.data_proto.rl_data import RolloutState
+
+
+class ParsedReasoningResult(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    reasoning_text: str | None = None
+    remaining_text: str | None = None
+
+
+class ReasoningParser(ABC):
+    @abstractmethod
+    def parse(self, rollout_state: RolloutState) -> ParsedReasoningResult:
+        """Return parsed reasoning and remaining text for a rollout
+        response."""
diff --git a/xtuner/v1/rl/rollout/parser/tool_parser.py b/xtuner/v1/rl/rollout/parser/tool_parser.py
new file mode 100644
index 0000000000..0bf042c5cf
--- /dev/null
+++ b/xtuner/v1/rl/rollout/parser/tool_parser.py
@@ -0,0 +1,121 @@
+from __future__ import annotations
+
+import ast
+import json
+from abc import ABC, abstractmethod
+from typing import Any
+
+from pydantic import BaseModel, ConfigDict, Field
+
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from xtuner.v1.data_proto.rl_data import RolloutFunctionCall, RolloutState, RolloutToolCall
+
+
+class ParsedToolCallResult(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    tool_calls: list[RolloutToolCall] = Field(default_factory=list)
+    remaining_text: str = ""
+
+
+class ToolCallParser(ABC):
+    def parse(self, rollout_state: RolloutState) -> ParsedToolCallResult:
+        return self.parse_text(rollout_state.response or "")
+
+    def should_reject_unparsed_markup(
+        self,
+        *,
+        has_tools: bool,
+        text: str | None,
+        parsed_tool_calls: list[Any] | None,
+    ) -> bool:
+        """Whether the remaining assistant text should be rejected as a
+        malformed tool call.
+
+        Most parsers do not use textual tool-call markup, so the default behavior is to accept the text. Parsers with
+        format-specific markup can override this and reject outputs that still contain unparsed tool-call fragments.
+ """ + return False + + @abstractmethod + def parse_text(self, text: str) -> ParsedToolCallResult: + raise NotImplementedError + + +def extract_tokenizer_token_contents( + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | Any, +) -> set[str]: + token_contents: set[str] = set() + + for token in getattr(tokenizer, "additional_special_tokens", []) or []: + if isinstance(token, str): + token_contents.add(token) + + added_tokens_decoder = getattr(tokenizer, "added_tokens_decoder", None) + if isinstance(added_tokens_decoder, dict): + for token_info in added_tokens_decoder.values(): + if isinstance(token_info, str): + token_contents.add(token_info) + elif isinstance(token_info, dict): + content = token_info.get("content") + if isinstance(content, str): + token_contents.add(content) + else: + content = getattr(token_info, "content", None) + if isinstance(content, str): + token_contents.add(content) + + get_vocab = getattr(tokenizer, "get_vocab", None) + if callable(get_vocab): + try: + vocab = get_vocab() + except Exception: + vocab = None + if isinstance(vocab, dict): + token_contents.update(token for token in vocab if isinstance(token, str)) + + return token_contents + + +def parse_json_or_python_mapping(raw_payload: str) -> Any: + try: + return json.loads(raw_payload) + except Exception: + try: + return ast.literal_eval(raw_payload) + except Exception: + return None + + +def coerce_parameter_value(value: str) -> Any: + stripped = value.strip() + if not stripped: + return "" + try: + return json.loads(stripped) + except Exception: + try: + return ast.literal_eval(stripped) + except Exception: + return stripped + + +def build_rollout_tool_call( + *, + name: str, + arguments: Any, + call_id: str, +) -> RolloutToolCall: + raw_arguments_text = arguments if isinstance(arguments, str) else None + parsed_arguments = arguments + if isinstance(arguments, str): + decoded = parse_json_or_python_mapping(arguments) + parsed_arguments = decoded if decoded is not None else {"raw": arguments} + return RolloutToolCall( + id=call_id, + function=RolloutFunctionCall( + name=name, + arguments=parsed_arguments, + raw_arguments_text=raw_arguments_text, + ), + ) diff --git a/xtuner/v1/ray/rollout/sglang.py b/xtuner/v1/rl/rollout/sglang.py similarity index 83% rename from xtuner/v1/ray/rollout/sglang.py rename to xtuner/v1/rl/rollout/sglang.py index bdedd16bb2..9110eeff67 100644 --- a/xtuner/v1/ray/rollout/sglang.py +++ b/xtuner/v1/rl/rollout/sglang.py @@ -8,10 +8,10 @@ from urllib3.exceptions import NewConnectionError from transformers import AutoConfig, AutoTokenizer -from xtuner.v1.ray.config import RolloutConfig +from xtuner.v1.data_proto.rl_data import RolloutState from xtuner.v1.utils import XTUNER_DETERMINISTIC -from .worker import RolloutWorker +from .worker import RolloutConfig, RolloutWorker class SGLangWorker(RolloutWorker): @@ -41,6 +41,47 @@ def __init__( self.model_name = self.config.model_name self.enable_return_routed_experts = self.config.enable_return_routed_experts + def _get_request_payload(self, rollout_state: RolloutState) -> dict: + sample_params = rollout_state.sample_params + payload: dict[str, Any] = {"model": self.model_name} + + if rollout_state.tools is not None: + payload["tools"] = rollout_state.tools + if rollout_state.tool_choice is not None: + payload["tool_choice"] = rollout_state.tool_choice + + sglang_sample_params = self._transform_sample_params(sample_params.model_dump()) + sglang_extra_params = self._transform_extra_params(sample_params.model_dump()) + 
payload.update(sglang_extra_params)
+
+        if self.enable_return_routed_experts and not rollout_state.extra_fields.get("disable_routed_experts", False):
+            payload["return_routed_experts"] = True
+
+        if sample_params.return_token_ids:
+            if "image_data" in rollout_state.extra_fields:
+                assert rollout_state.tokens is not None, "input_ids is required when image_data is provided."
+                payload["image_data"] = rollout_state.extra_fields["image_data"]
+
+            if rollout_state.tokens is not None:
+                payload["input_ids"] = rollout_state.tokens
+            else:
+                text_prompt = self.tokenizer.apply_chat_template(
+                    rollout_state.message, tokenize=False, add_generation_prompt=True
+                )
+                payload["input_ids"] = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"]
+
+            payload["sampling_params"] = sglang_sample_params
+            return payload
+
+        payload["messages"] = rollout_state.message
+        payload.update(sglang_sample_params)
+        # The chat-completions API uses OpenAI-style names.
+        payload["max_tokens"] = sglang_sample_params["max_new_tokens"]
+        payload["min_tokens"] = sglang_sample_params["min_new_tokens"]
+        payload.pop("max_new_tokens", None)
+        payload.pop("min_new_tokens", None)
+        return payload
+
     async def _create_request(
         self,
         url: str,
@@ -249,8 +290,11 @@ def _transform_sample_params(self, sample_params: Dict):
             "stop_token_ids": sample_params["stop_token_ids"],
             "skip_special_tokens": sample_params["skip_special_tokens"],
         }
-        if XTUNER_DETERMINISTIC:
-            sglang_sample_params["sampling_seed"] = sample_params["sampling_seed"]
+        sampling_seed = sample_params.get("sampling_seed")
+        if sampling_seed is None and XTUNER_DETERMINISTIC:
+            sampling_seed = self.config.random_seed
+        if sampling_seed is not None:
+            sglang_sample_params["sampling_seed"] = sampling_seed
         return sglang_sample_params
 
     def _transform_extra_params(self, extra_params: Dict):
diff --git a/xtuner/v1/rl/rollout/utils.py b/xtuner/v1/rl/rollout/utils.py
new file mode 100644
index 0000000000..2fd75da190
--- /dev/null
+++ b/xtuner/v1/rl/rollout/utils.py
@@ -0,0 +1,279 @@
+import asyncio
+import os
+import threading
+import time
+from collections import OrderedDict
+from itertools import cycle
+from typing import TYPE_CHECKING, Any, Optional
+
+import httpx
+import ray
+
+from xtuner.v1.rl.utils import asyncio_run
+from xtuner.v1.utils import get_logger
+
+
+if TYPE_CHECKING:
+    from .controller import RolloutControllerProxy, WorkerInfo
+    from .worker import RolloutConfig, RolloutWorker
+
+ROLLOUT_RAY_GET_TIMEOUT = int(os.getenv("XTUNER_ROLLOUT_RAY_GET_TIMEOUT", str(5 * 3600)))  # default 5 hours
+logger = get_logger()
+
+
+class SessionRouter:
+    def __init__(
+        self,
+        worker_infos: dict[int, "WorkerInfo"],  # rank -> worker info/status
+        worker_infos_lock: Optional[threading.RLock] = None,
+        max_sessions: int = 10000,
+        max_idle_seconds: Optional[float] = 3600.0,
+    ):
+        self._worker_infos = worker_infos
+        self._worker_infos_lock = worker_infos_lock
+        self._max_sessions = max_sessions
+        self._max_idle = max_idle_seconds
+
+        # OrderedDict: key=session_id -> value=(worker_rank, last_used_ts)
+        self._map: OrderedDict[int, tuple[int, float]] = OrderedDict()
+
+        self._worker_cycler = cycle(worker_infos.keys())
+        self._lock = asyncio.Lock()
+        self.logger = get_logger()
+
+    def _now(self) -> float:
+        return time.time()
+
+    def _evict_expired(self):
+        if self._max_idle is None:
+            return
+        now = self._now()
+
+        to_delete = []
+        for sid, (_, last_used) in self._map.items():
+            if now - last_used > self._max_idle:
+                to_delete.append(sid)
+            else:
+                break
+        for sid in to_delete:
+            self._map.pop(sid, None)
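# ----------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the patch): the
# OrderedDict discipline behind `_evict_expired` above. Because entries are
# re-inserted on each hit, iteration order is oldest-first, so the expiry
# scan can stop at the first fresh entry. Values are hypothetical.
import time
from collections import OrderedDict

sessions: "OrderedDict[int, float]" = OrderedDict()
sessions[1] = time.time() - 7200  # stale
sessions[2] = time.time()         # fresh

for sid, last_used in list(sessions.items()):
    if time.time() - last_used > 3600.0:
        sessions.pop(sid, None)
    else:
        break  # everything after this point is fresher

assert list(sessions) == [2]
# ----------------------------------------------------------------------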
+
+    def _evict_lru_to_capacity(self):
+        while len(self._map) > self._max_sessions:
+            self._map.popitem(last=False)
+
+    def _choose_next_active_worker(self) -> tuple[int, Any]:
+        n = len(self._worker_infos)
+        for _ in range(n):
+            rank = next(self._worker_cycler)
+            if self._worker_infos_lock is None:
+                info = self._worker_infos[rank]
+                if info and info.is_active:
+                    return rank, info.actor
+            else:
+                with self._worker_infos_lock:
+                    info = self._worker_infos[rank]
+                    if info and info.is_active:
+                        return rank, info.actor
+        return -1, None
+
+    async def get_worker(self, session_id: int) -> Optional[Any]:
+        async with self._lock:
+            self._evict_expired()
+
+            if session_id in self._map:
+                worker_rank, _ = self._map.pop(session_id)
+                if self._worker_infos_lock is None:
+                    info = self._worker_infos.get(worker_rank)
+                else:
+                    with self._worker_infos_lock:
+                        info = self._worker_infos.get(worker_rank)
+                if info and info.is_active:
+                    self._map[session_id] = (worker_rank, self._now())
+                    return info.actor
+
+            rank, worker = self._choose_next_active_worker()
+            if rank == -1:
+                return None
+            self._map[session_id] = (rank, self._now())
+            self._evict_lru_to_capacity()
+            return worker
+
+
+class RolloutHealthChecker:
+    def __init__(
+        self,
+        config: "RolloutConfig",
+        workers_info: dict[int, "WorkerInfo"],
+        worker_infos_lock: Optional[threading.RLock] = None,
+    ):
+        self._workers_info = workers_info
+        self._worker_infos_lock = worker_infos_lock
+        self._check_interval = config.health_check_interval_seconds
+        self._check_failure_threshold = config.health_check_failure_threshold
+        self._stop_event: Optional[threading.Event] = None
+        self._pause_event: Optional[threading.Event] = None
+        self._thread: Optional[threading.Thread] = None
+
+    def start(self) -> None:
+        if self._thread and self._thread.is_alive():
+            return
+
+        self._stop_event = threading.Event()
+        self._pause_event = threading.Event()
+        self._pause_event.set()  # Start paused; resume() is called once generation begins.
+        self._thread = threading.Thread(target=self._run_loop, daemon=True)
+        self._thread.start()
+        logger.info("RolloutHealthChecker started.")
+
+    def stop(self) -> None:
+        if not self._thread:
+            return
+
+        assert self._stop_event is not None
+        self._stop_event.set()
+        if self._pause_event:
+            self._pause_event.clear()
+        self._thread.join(timeout=5)
+        self._thread = None
+        self._stop_event = None
+        logger.info("RolloutHealthChecker stopped.")
+
+    def pause(self) -> None:
+        if self._pause_event is None:
+            return
+        self._pause_event.set()
+        logger.info("RolloutHealthChecker paused.")
+
+    def resume(self) -> None:
+        if self._pause_event is None:
+            return
+        self._pause_event.clear()
+        logger.info("RolloutHealthChecker resumed.")
+
+    def run_once(self) -> None:
+        logger.debug("RolloutHealthChecker running health checks for all workers.")
+        if self._worker_infos_lock is None:
+            workers_snapshot = {
+                rank: (info.actor, info.url, info.is_active) for rank, info in self._workers_info.items()
+            }
+        else:
+            with self._worker_infos_lock:
+                workers_snapshot = {
+                    rank: (info.actor, info.url, info.is_active) for rank, info in self._workers_info.items()
+                }
+
+        tasks = [
+            check_worker_health(
+                actor,
+                rank,
+                url,
+                is_active,
+                self._check_failure_threshold,
+            )
+            for rank, (actor, url, is_active) in workers_snapshot.items()
+        ]
+
+        async def _run_checks() -> list[bool]:
+            return await asyncio.gather(*tasks)
+
+        check_results = asyncio_run(_run_checks())
+        inactive_workers = []
+        for rank, is_healthy in zip(workers_snapshot.keys(),
check_results): + if not is_healthy: + logger.warning(f"Worker {rank} failed health check. Marking as inactive.") + if self._worker_infos_lock is None: + self._workers_info[rank].is_active = False + inactive_worker = self._workers_info[rank].actor + else: + with self._worker_infos_lock: + self._workers_info[rank].is_active = False + inactive_worker = self._workers_info[rank].actor + if inactive_worker is None: + logger.error(f"[RolloutHealthChecker] Worker {rank} has no actor reference. Skipping shutdown.") + continue + inactive_workers.append((rank, inactive_worker)) + else: + logger.debug(f"[RolloutHealthChecker] Worker {rank} passed health check.") + + for rank, inactive_worker in inactive_workers: + try: + ray.get(inactive_worker.offload.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] + ray.get(inactive_worker.shutdown.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] + except Exception as e: + logger.error(f"Exception while shutting down worker {rank}: {e}") + + def _run_loop(self) -> None: + assert self._stop_event is not None and self._pause_event is not None + logger.info("RolloutHealthChecker loop started.") + + while not self._stop_event.is_set(): + while self._pause_event.is_set() and not self._stop_event.is_set(): + self._stop_event.wait(timeout=0.5) + + if self._stop_event.is_set(): + break + + if not self._pause_event.is_set() and not self._stop_event.is_set(): + self.run_once() + + if self._stop_event.wait(self._check_interval): + break + + +async def send_abort_request(client: httpx.AsyncClient, url: str, timeout: float = 60.0) -> tuple[str, bool]: + worker_url = f"{url}/abort_request" + try: + response = await client.post(worker_url, json={"abort_all": True}, timeout=timeout) + response.raise_for_status() + logger.debug(f"Successfully sent abort request to {url}") + return url, True + except Exception as e: + logger.error(f"Failed to send abort request to {url}: {e}") + return url, False + + +async def pause_generation(rollout_ctl: "RolloutControllerProxy", pause_time_out: float = 60.0) -> None: + await rollout_ctl.pause_generation.remote() # type: ignore[attr-defined] + rollout_ctl_metadata = await rollout_ctl.get_rollout_metadata.remote() # type: ignore[attr-defined] + infer_server_url = list(rollout_ctl_metadata["server_url_dict"].values()) + async with httpx.AsyncClient() as client: + tasks = [send_abort_request(client, url, timeout=pause_time_out) for url in infer_server_url] + results = await asyncio.gather(*tasks) + + failed_workers = [url for url, success in results if not success] + succeeded_count = len(infer_server_url) - len(failed_workers) + + if failed_workers: + logger.warning( + f"Abort requests completed. Succeeded: {succeeded_count}, " + f"Failed: {len(failed_workers)}. 
Failed workers: {failed_workers}" + ) + else: + logger.info(f"All {succeeded_count} abort requests sent successfully.") + + +async def continue_generation(rollout_ctl: "RolloutControllerProxy") -> None: + return await rollout_ctl.continue_generation.remote() # type: ignore[attr-defined] + + +async def check_worker_health( + worker: "RolloutWorker", rank: int, url: str, is_active: bool, failure_threshold: int = 3 +) -> bool: + if worker is None or not is_active: + logger.warning("Worker has no actor reference or is marked inactive.") + return False + failing_count = 0 + while failing_count < failure_threshold: + try: + health_status = await worker.check_health.remote() # type: ignore[attr-defined] + if health_status: + return True + failing_count += 1 + logger.warning(f"Health check failed for worker {rank} at {url}. Failure count: {failing_count}") + except Exception as e: + failing_count += 1 + logger.error( + f"Exception during health check for worker {rank} at {url}: {e}. Failure count: {failing_count}" + ) + return False diff --git a/xtuner/v1/ray/rollout/vllm.py b/xtuner/v1/rl/rollout/vllm.py similarity index 79% rename from xtuner/v1/ray/rollout/vllm.py rename to xtuner/v1/rl/rollout/vllm.py index fb44fac0a1..85db410aa7 100644 --- a/xtuner/v1/ray/rollout/vllm.py +++ b/xtuner/v1/rl/rollout/vllm.py @@ -12,11 +12,11 @@ from vllm.entrypoints.utils import cli_env_setup from vllm.utils import FlexibleArgumentParser -from xtuner.v1.data_proto.rl_data import RLRolloutResponseItem, RolloutState -from xtuner.v1.ray.config import RolloutConfig -from xtuner.v1.ray.rollout.worker import RolloutWorker +from xtuner.v1.data_proto.rl_data import RolloutState, Status, update_status_from_finish_reason from xtuner.v1.utils.device import get_device, get_torch_device_module +from .worker import RolloutConfig, RolloutWorker + DEVICE = get_device() DEVICE_MODULE = get_torch_device_module() @@ -301,7 +301,7 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: args["max_model_len"] = self.config.context_length args["enforce_eager"] = False args["enable_sleep_mode"] = True - args["worker_extension_cls"] = "xtuner.v1.ray.rollout.vllm.WorkerWrap" + args["worker_extension_cls"] = "xtuner.v1.rl.rollout.vllm.WorkerWrap" args["trust_remote_code"] = True args["enable_prefix_caching"] = False args["allowed_local_media_path"] = "/" @@ -360,41 +360,75 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: ray_runtime_env={"env_vars": env}, ) - async def _handle_stream_response(self, uid, sample_params, extra_params, response) -> RLRolloutResponseItem: + async def _safe_handle_response(self, rollout_state: RolloutState, http_response) -> RolloutState: + if rollout_state.sample_params.stream: + return await self._handle_stream_response(rollout_state, http_response) + return await self._handle_non_stream_response(rollout_state, http_response) + + async def _handle_stream_response(self, rollout_state: RolloutState, response) -> RolloutState: raise NotImplementedError - async def _handle_non_stream_response( - self, root_id, action_id, sample_params, extra_params, response, input_extra_info - ) -> RLRolloutResponseItem: - uid = action_id - last_token_ids = [] - last_logprobs = [] - - response = response.json()["choices"][0] - if "logprobs" in response: - last_token_ids = response["token_ids"] - last_logprobs = [item["logprob"] for item in response["logprobs"]["content"]] + async def _handle_non_stream_response(self, rollout_state: RolloutState, response) -> RolloutState: + uid = 
rollout_state.uid or rollout_state.message_uid + sample_params = rollout_state.sample_params + last_token_ids: list[int] = [] + last_logprobs: list[float] = [] + routed_experts = None + + response_json = response.json() + response_choice = response_json["choices"][0] + if response_choice.get("logprobs") is not None: + last_token_ids = response_choice.get("token_ids", response_json.get("token_ids", [])) + last_logprobs = [ + item["logprob"] for item in response_choice["logprobs"].get("content", []) if "logprob" in item + ] assert len(last_token_ids) == len(last_logprobs) - assert len(last_token_ids) <= sample_params["max_tokens"], ( - f"Generation length exceeds limit: generated {len(last_token_ids)}, limit {sample_params['max_tokens']}" + assert len(last_token_ids) <= sample_params.max_tokens, ( + f"Generation length exceeds limit: generated {len(last_token_ids)}, limit {sample_params.max_tokens}" ) - last_trajectory = response["message"]["content"] - finish_reason = response["finish_reason"] + + last_trajectory = response_choice["message"].get("content") or "" + finish_reason = response_choice.get("finish_reason") if finish_reason == "abort" and self.receive_abort_request.is_set() is False: self.receive_abort_request.set() self.logger.info(f"Setting receive_abort_request to True for rank {self.rank}") - if finish_reason != "abort" and (len(last_token_ids) == 0 or len(last_logprobs) == 0): - self.logger.error(f"Invalid rollout response for request {uid}: {response}") - return RLRolloutResponseItem(state=RolloutState.SKIPPED) - - rollout_response = RLRolloutResponseItem( - response=last_trajectory, - response_ids=last_token_ids if len(last_token_ids) > 0 else None, - num_return_tokens=len(last_token_ids) if len(last_token_ids) > 0 else None, - finish_reason=finish_reason, - logprobs=last_logprobs, - state=RolloutState.ABORTED if finish_reason == "abort" else RolloutState.COMPLETED, - ) - - return rollout_response + if self.enable_return_routed_experts: + routed_experts = response_choice.get("routed_experts", response_json.get("routed_experts")) + if routed_experts is not None: + if isinstance(routed_experts, str): + import base64 + + data = base64.b64decode(routed_experts) + routed_experts = ray.cloudpickle.loads(data) + else: + routed_experts = torch.tensor(routed_experts) + routed_experts = ray.put(routed_experts) + + rollout_status = update_status_from_finish_reason(finish_reason) + if rollout_status == Status.COMPLETED: + validation_errors = [] + if sample_params.return_token_ids and len(last_token_ids) == 0: + validation_errors.append("empty response_ids") + if sample_params.return_logprob and len(last_logprobs) == 0: + validation_errors.append("missing logprobs") + if not last_trajectory: + validation_errors.append("empty response text") + if self.enable_return_routed_experts and routed_experts is None: + validation_errors.append("missing routed_experts") + + if validation_errors: + error_msg = f"Incomplete rollout data for request {uid}: {', '.join(validation_errors)}" + self.logger.error(f"{error_msg}. 
Raw response: {response_json}")
+                rollout_state.status = Status.FAILED
+                rollout_state.error_msg = error_msg
+                return rollout_state
+
+        rollout_state.response = last_trajectory
+        rollout_state.response_ids = last_token_ids if len(last_token_ids) > 0 else None
+        rollout_state.logprobs = last_logprobs if len(last_logprobs) > 0 else None
+        rollout_state.routed_experts = routed_experts
+        rollout_state.finish_reason = finish_reason
+        rollout_state.status = rollout_status
+
+        return rollout_state
diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py
new file mode 100644
index 0000000000..5b3869e3e8
--- /dev/null
+++ b/xtuner/v1/rl/rollout/worker.py
@@ -0,0 +1,1075 @@
+import asyncio
+import copy
+import json
+import multiprocessing
+import os
+import socket
+import time
+import traceback
+from abc import abstractmethod
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union
+
+import httpx
+import ray
+import requests  # type: ignore[import-untyped]
+from cyclopts import Group, Parameter
+from packaging.version import Version
+from pydantic import BaseModel, ConfigDict, PrivateAttr
+from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+from typing_extensions import Annotated
+
+from transformers import AutoTokenizer
+from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status, update_status_from_finish_reason
+from xtuner.v1.rl.utils import (
+    AutoAcceleratorWorkers,
+    SingleAcceleratorWorker,
+    find_master_addr_and_port,
+    get_eos_token,
+)
+from xtuner.v1.utils import get_logger
+from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult
+
+
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
+
+
+infer_group = Group("inference", help="Inference worker configuration.")
+
+
+class RolloutConfig(BaseModel):
+    """Rollout worker configuration for XTuner.
+
+    This class defines comprehensive configuration parameters for rollout workers in XTuner,
+    supporting multiple inference backends with distributed computing and optimization features.
+    The backend ('lmdeploy', 'vllm', or 'sglang') is selected via the corresponding
+    XTUNER_USE_* environment variable rather than a config field.
+
+    Args:
+        env (str): Environment variables for the rollout worker. Defaults to "".
+        model_path (str | Path): Path to the inference model.
+        model_name (str): Model name for the backend engine.
+        tokenizer_path (str): Path to the model tokenizer. Defaults to "".
+        api_key (Optional[Union[List[str], str]]): API keys for rollout service. Supports single key or
+            list of keys. Defaults to None.
+        api_port (int): Port number for the rollout API server. If the port is occupied, the next
+            available port is used. Defaults to 8000.
+        gpus_per_node (int): Number of GPUs per node. Defaults to 8.
+        dtype (str): Model data type ('bfloat16', 'float16', 'int8'). Defaults to "bfloat16".
+        gpu_memory_utilization (float): GPU memory utilization ratio. Defaults to 0.85.
+        random_seed (int): Random seed for reproducible generation. Defaults to 1024.
+        rollout_cross_node_comm (bool): Enable cross-node communication. Defaults to False.
+        rollout_max_batch_size_per_instance (Optional[int]): Maximum batch size for the rollout worker.
+            If not set, it will be determined automatically based on `context_length`. Defaults to None.
+        allow_over_concurrency_ratio (float): Factor to allow over-concurrency in HTTP requests for the
+            rollout worker to improve GPU utilization. Defaults to 1.2.
+        tensor_parallel_size (int): GPUs per inference engine (tensor parallelism). Defaults to 1.
+        expert_parallel_size (int): Experts per inference engine (expert parallelism). Defaults to 1.
+        enable_chunked_prefill (bool): Enable chunked prefill for memory efficiency. Defaults to False.
+        chunked_prefill_size (int): Chunk size for prefill operations. Defaults to 128.
+        skip_load_weights (bool): Skip weight loading for rollout worker. Defaults to False.
+        rollout_timeout (float): Timeout duration in seconds for rollout requests. Defaults to 1200.0.
+        context_length (Optional[int]): Context length for the rollout worker.
+        launch_server_method (Literal["ray", "multiprocessing"]): Server launch method. Defaults to "ray".
+        extra_rollout_config (Optional[dict]): Backend-specific configurations using engine prefixes
+            (e.g., 'vllm_enable_chunked_prefill', 'lmdeploy_max_batch_size'). Defaults to empty dict.
+
+    **Examples:**
+
+    Example configuration with the LMDeploy backend (selected via the
+    ``XTUNER_USE_LMDEPLOY=1`` environment variable)::
+
+        config = RolloutConfig(
+            env="test_env",
+            model_path="Qwen/Qwen3-8B",
+            model_name="Qwen3-8B",
+            tensor_parallel_size=2,
+            gpu_memory_utilization=0.6,
+            gpus_per_node=8,
+            context_length=8192,
+        )
+    """
+
+    model_config = ConfigDict(extra="forbid")
+
+    # base config
+    env: Annotated[
+        str,
+        Parameter(group=infer_group, help="Environment variables to set for the rollout."),
+    ] = ""
+    device: Annotated[str, Parameter(group=infer_group, help="Device to be used for the rollout worker.")] = "GPU"
+    model_path: Annotated[str | Path, Parameter(group=infer_group, help="Path to the model to serve.")]
+    model_name: Annotated[
+        str | None, Parameter(group=infer_group, help="Name of the model exposed by the inference engine.")
+    ] = None
+    tokenizer_path: Annotated[
+        str | None, Parameter(group=infer_group, help="Path to the tokenizer for the model.")
+    ] = None
+    api_key: Annotated[
+        Optional[Union[List[str], str]],
+        Parameter(
+            group=infer_group,
+            help="API keys for the rollout service. Can be a single key or a list of keys.",
+        ),
+    ] = None
+    api_port: Annotated[
+        int,
+        Parameter(group=infer_group, help="Port number for the rollout API server. If occupied, the next free port is used."),
+    ] = 8000
+    api_host: Annotated[
+        str,
+        Parameter(group=infer_group, help="Host for the rollout API server."),
+    ] = "0.0.0.0"
+    gpus_per_node: Annotated[int, Parameter(group=infer_group, help="Number of GPUs allocated per node.")] = 8
+    dtype: Annotated[
+        str,
+        Parameter(group=infer_group, help="Data type for the model, e.g., 'bfloat16', 'float16', 'int8'."),
+    ] = "bfloat16"
+    gpu_memory_utilization: Annotated[
+        float, Parameter(group=infer_group, help="GPU memory utilization for the rollout worker.")
+    ] = 0.85
+    random_seed: Annotated[int, Parameter(group=infer_group, help="Random seed for the rollout worker.")] = 1024
+    # distributed config
+    rollout_cross_node_comm: Annotated[
+        bool,
+        Parameter(
+            group=infer_group,
+            help="Whether to enable cross-node communication for the rollout worker.",
+        ),
+    ] = False
+    dist_port_base: Annotated[
+        int,
+        Parameter(
+            group=infer_group,
+            help="Base port number for distributed communication among rollout workers.",
+        ),
+    ] = 35000
+    rollout_max_batch_size_per_instance: Annotated[
+        Optional[int],
+        Parameter(
+            group=infer_group,
+            help="Maximum batch size for the rollout worker. 
If not set, it will be determined automatically based on the model and GPU memory.", + ), + ] = None + allow_over_concurrency_ratio: Annotated[ + float, + Parameter( + group=infer_group, + help="Factor to allow over concurrency in the http request for rollout worker to improve GPU utilization.", + ), + ] = 1.2 + tensor_parallel_size: Annotated[ + int, + Parameter( + group=infer_group, + help="Number of GPUs allocated for each inference engine in the rollout worker.", + ), + ] = 1 + data_parallel_size: Annotated[ + int, + Parameter( + group=infer_group, + help="Number of GPUs allocated for processing data batches in parallel (Data Parallelism).", + ), + ] = 1 + expert_parallel_size: Annotated[ + int, + Parameter( + group=infer_group, + help="Number of experts allocated for each inference engine in the rollout worker.", + ), + ] = 1 + # optimization config + enable_chunked_prefill: Annotated[ + bool, + Parameter( + group=infer_group, + help="Whether to enable chunked prefill for the rollout worker.", + ), + ] = False + chunked_prefill_size: Annotated[ + int, + Parameter( + group=infer_group, + help="Chunked prefill size for the rollout worker.", + ), + ] = 128 + skip_load_weights: Annotated[ + bool, + Parameter( + group=infer_group, + help="Whether to skip loading weights for the rollout worker.", + ), + ] = False + enable_return_routed_experts: Annotated[ + bool, + Parameter( + group=infer_group, + help="Whether to enable returning routed experts for the rollout worker.", + ), + ] = False + launch_server_method: Annotated[ + Literal["ray", "multiprocessing"], + Parameter( + group=infer_group, + help="Method to launch the rollout server, either 'ray' or 'multiprocessing'.", + ), + ] = "ray" + rollout_timeout: Annotated[ + float, + Parameter( + group=infer_group, + help="Timeout duration (in seconds) for rollout requests.", + ), + ] = 1200.0 + context_length: Annotated[ + Optional[int], + Parameter( + group=infer_group, + help="Context length for the rollout worker.", + ), + ] = None + tool_call_parser: Annotated[ + Literal["none", "qwen3", "qwen3p5"], + Parameter( + group=infer_group, + help='Structured tool-call parser to apply to rollout output. Use "none" to disable parsing, "qwen3" to enable Qwen3 tool-call parsing, or "qwen3p5" to enable Qwen3.5 coder-style tool-call parsing.', + ), + ] = "none" + reasoning_parser: Annotated[ + Literal["none", "qwen3"], + Parameter( + group=infer_group, + help='Reasoning parser to apply to rollout output. Use "none" to disable parsing or "qwen3" to enable Qwen3 parsing.', + ), + ] = "none" + enable_float8: Annotated[ + bool, + Parameter( + group=infer_group, + help="Whether to enable float8 quantization for the rollout worker.", + ), + ] = False + extra_rollout_config: Annotated[ + dict, + Parameter( + group=infer_group, + help='Extra configuration for different rollout worker. 
Backend-specific keys are prefixed with the engine name, e.g. "vllm_enable_chunked_prefill" or "lmdeploy_max_batch_size".',
+        ),
+    ] = {}
+    max_retry_per_worker: Annotated[
+        Optional[int],
+        Parameter(
+            group=infer_group,
+            help="Maximum number of retries per rollout worker before deactivation.",
+        ),
+    ] = None
+    max_retry_per_sample: Annotated[
+        int,
+        Parameter(
+            group=infer_group,
+            help="Maximum number of retries per sample before marking it as failed.",
+        ),
+    ] = 1
+    max_prefill_token_num: Annotated[
+        Optional[int],
+        Parameter(
+            group=infer_group,
+            help="Number of tokens processed per iteration during prefill.",
+        ),
+    ] = None
+    router_n_groups: Annotated[
+        Optional[int],
+        Parameter(
+            group=infer_group,
+            help="Number of router groups in MoE models with a grouped router, e.g. Intern-S1-Pro.",
+        ),
+    ] = None
+    fp32_lm_head: Annotated[
+        bool,
+        Parameter(
+            group=infer_group,
+            help="Use float32 for the language model head.",
+        ),
+    ] = False
+    worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir"
+    health_check_interval_seconds: Annotated[
+        float,
+        Parameter(
+            group=infer_group,
+            help="Interval in seconds between rollout worker health checks.",
+        ),
+    ] = 30.0
+    health_check_failure_threshold: Annotated[
+        int,
+        Parameter(
+            group=infer_group,
+            help="Number of consecutive health check failures required before marking a worker inactive.",
+        ),
+    ] = 3
+    _logged_server_urls_per_engine: bool = PrivateAttr(default=False)
+
+    @property
+    def rollout_backend(self) -> str:
+        backend = ""
+        if os.environ.get("XTUNER_USE_SGLANG", "0") == "1":
+            backend = "sglang"
+        elif os.environ.get("XTUNER_USE_VLLM", "0") == "1":
+            backend = "vllm"
+        elif os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "1":
+            backend = "lmdeploy"
+
+        assert backend in ["sglang", "vllm", "lmdeploy"], (
+            f"Unsupported rollout backend: {backend}. Please set XTUNER_USE_SGLANG, XTUNER_USE_VLLM, or XTUNER_USE_LMDEPLOY to 1."
+        )
+        return backend
+
+    @property
+    def server_urls_per_engine(self) -> int:
+        # server_urls_per_engine is introduced for lmdeploy ep settings
+        # for now only the lmdeploy pytorch backend with ep > 1 requires multiple server urls per engine
+        if self.rollout_backend == "lmdeploy" and self.expert_parallel_size > 1:
+            # when expert parallelism is used, lmdeploy requires `expert_parallel_size` server instances per engine
+            if not self._logged_server_urls_per_engine:
+                self._logged_server_urls_per_engine = True
+                get_logger().info(
+                    f"Setting server_urls_per_engine={self.expert_parallel_size} due to expert parallelism in LMDeploy."
+ ) + return self.expert_parallel_size + else: + return 1 + + def model_post_init(self, __context: Any) -> None: + if self.model_name is None: + model_name_from_config = None + config_json_path = Path(self.model_path) / "config.json" + try: + with open(config_json_path, encoding="utf-8") as f: + config_data = json.load(f) + model_name_from_config = config_data.get("model_type") + except (json.JSONDecodeError, OSError): + pass + self.model_name = model_name_from_config or Path(self.model_path).name + + if self.tokenizer_path is None: + self.tokenizer_path = str(self.model_path) + + port = self.api_port + while True: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind((self.api_host if self.api_host != "0.0.0.0" else "localhost", port)) + break + except OSError: + port += 1 + self.api_port = port + + if self.device == "NPU": + self.gpus_per_node = 16 + + if self.rollout_backend == "sglang": + self.launch_server_method = "multiprocessing" + self.rollout_cross_node_comm = False + else: + self.launch_server_method = "ray" + self.rollout_cross_node_comm = True + + if self.rollout_max_batch_size_per_instance is None: + assert self.context_length is not None, ( + "context_length must be set if rollout_max_batch_size_per_instance is not provided." + ) + # TODO(@duanyanhui): Provide better suggestions for different models/input-output lengths + if self.context_length <= 4096: + self.rollout_max_batch_size_per_instance = 1024 + elif self.context_length <= 8192: + self.rollout_max_batch_size_per_instance = 512 + else: + self.rollout_max_batch_size_per_instance = 128 + + if self.max_retry_per_worker is None: + self.max_retry_per_worker = self.rollout_max_batch_size_per_instance + + self.worker_log_dir.mkdir(parents=True, exist_ok=True) + + def build(self, placement_group: "PlacementGroup"): + """Build and return a Ray remote RolloutController from this config. + + Args: + placement_group: The placement group for scheduling RolloutWorker actors. + + Returns: + A Ray actor handle (proxy) of RolloutController. + """ + import ray + + from xtuner.v1.rl.rollout.controller import RolloutController + + return ( + ray.remote(RolloutController) + .options(max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000))) + .remote(self, placement_group) + ) + + +class RolloutWorker(SingleAcceleratorWorker): + """Base class for a rollout worker that runs an inference server. + + This class manages the lifecycle of a distributed inference server, including initialization, launching, and + handling generation requests. It is designed to be subclassed for specific inference backends like LMDeploy, vLLM + or SGLang. + """ + + def __init__( + self, + config: RolloutConfig, + rank: int, + master_addr: str, + master_port: int, + world_size: int, + accelerator: str = "GPU", + ): + """Initialize the RolloutWorker. + + Args: + config (RolloutConfig): The configuration for the rollout. + rank (int): The rank of this worker in the distributed setup. + master_addr (str): The address of the Ray master node. + master_port (int): The port of the Ray master node. + world_size (int): The total number of workers. + accelerator (str): The type of accelerator to use. + Defaults to "GPU". 
+ """ + self.config = config + self.rank = rank + self.master_addr = master_addr # ray master + self.master_port = master_port + self.world_size = world_size + self.accelerator = accelerator + self.server_func: Callable + self.endpoints: dict[str, str] = dict() + self.engine_rank_mesh_array: list[list[int]] + # http_concurrency is calculated based on the max batch size per engine and the total number of engines + assert config.rollout_max_batch_size_per_instance, ( + "rollout_max_batch_size_per_instance must be set in RolloutConfig" + ) + http_concurrency = config.rollout_max_batch_size_per_instance * config.allow_over_concurrency_ratio + limits = httpx.Limits(max_connections=http_concurrency, max_keepalive_connections=100) + self.client = httpx.AsyncClient(limits=limits, timeout=self.config.rollout_timeout) + self.paused = False + self.server_task = None + self.engine_bundle_idxs: list[int] = [] + self.server_process: Optional[multiprocessing.Process] = None + self.logger = get_logger(log_dir=config.worker_log_dir, tag="RolloutWorker") + self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_path, trust_remote_code=True) + self.check_flag = True # only print once + self.enable_return_routed_experts = self.config.enable_return_routed_experts + if self.rank == 0: + self.logger.info(f"RolloutConfig:\n{self.config.model_dump_json(indent=2)}") + eos_token = get_eos_token(self.config.model_path) + self.logger.info(f"Using eos_token: {eos_token} for model at {self.config.model_path}") + self.eos_token: List[int] = [eos_token] if isinstance(eos_token, int) else eos_token + self.receive_abort_request = asyncio.Event() + self.abort_timeout = 5.0 + self.dist_init_addr: str = "" + self.serverl_url: str = "" + + def init(self, dist_init_addr: str) -> tuple[int, str]: + """Initialize the worker and launch the server. + + Args: + dist_init_addr (str): The distributed initialization address. + If not provided, the one generated by `init_dist_port` is used. + + Returns: + Tuple[int, str]: A tuple containing the worker's rank and its + server URL. + """ + self.dist_init_addr = dist_init_addr if dist_init_addr else self.dist_init_addr + self.receive_abort_request.clear() + self._launch_server() + return (self.rank, self.server_url) + + def init_dist_port(self) -> str: + """Initialize distributed communication ports. + + This method acquires three free ports for the distributed setup: + one for the inference server, one for NCCL, and one for Ray's + distributed communication. + + Returns: + str: The distributed initialization address (host:port). 
+ """ + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=ray.util.get_current_placement_group(), + placement_group_capture_child_tasks=True, + placement_group_bundle_index=self.engine_bundle_idxs[0], + ) + + local_rank = int(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) + interval = 1024 + start_port = self.config.dist_port_base + local_rank * interval + end_port = start_port + interval + self.host, self.ports = ray.get( + find_master_addr_and_port.options(scheduling_strategy=scheduling_strategy).remote( + nums=3, + start_port=start_port, + end_port=end_port, + ) + ) + + self.dist_port = self.ports[0] + self.server_port = self.ports[1] + self.nccl_port = self.ports[2] + self.dist_init_addr = f"{self.host}:{self.dist_port}" + self.server_url = f"http://{self.host}:{self.server_port}" + return self.dist_init_addr + + def shutdown(self): + """Shut down the worker, its server task, and any child processes.""" + if self.server_task is not None: + ray.cancel(self.server_task, force=True) + return + + if self.server_process is not None: + import psutil + + parent = psutil.Process(self.server_process.pid) + children = parent.children(recursive=True) + for child in children: + child.terminate() + gone, alive = psutil.wait_procs(children, timeout=5) + for child in alive: + child.kill() + parent.terminate() + parent.wait(timeout=5) + self.logger.debug(f"Worker {self.rank} server process and its children terminated.") + return + + def pause_generation(self): + """Pause the worker's generation process.""" + self.paused = True + + def continue_generation(self): + """Resume the worker's generation process.""" + self.receive_abort_request.clear() + + def check_health(self) -> bool: + """Check the health of the worker's server. + + Returns: + bool: True if the server is healthy, False otherwise. + """ + try: + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {self.config.api_key}", + } + response = requests.get( + f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers, timeout=5.0 + ) + return response.status_code == 200 + except requests.RequestException as e: + self.logger.error(f"Health check failed for server {self.server_url}: {e}") + return False + + def _decode_routed_experts(self, routed_experts: Any) -> Any: + return routed_experts + + async def generate(self, rollout_state: RolloutState) -> RolloutState: + # TODO(@duanyanhui): + # 1. support claude format input + # 2. 需要看下新的输入输出(RolloutState)怎么适配PartialRollout的逻辑,先跑起来 + # 3. 
streaming responses are removed for now; they are not needed yet and can be added back later
+
+        uid = rollout_state.uid
+        sample_params: SampleParams = rollout_state.sample_params
+
+        if sample_params.return_token_ids:
+            endpoint_url = f"{self.server_url}/{self.endpoints['generate']}"
+        else:
+            endpoint_url = f"{self.server_url}/{self.endpoints['v1/chat/completions']}"
+
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.config.api_key}",
+        }
+
+        max_retries = self.config.max_retry_per_sample
+        payload = self._get_request_payload(rollout_state)
+
+        # Early exit 1: skip requests already marked as completed
+        if rollout_state.status == Status.COMPLETED:
+            self.logger.debug(f"Request {uid} is already marked as COMPLETED, skipping generation.")
+            return rollout_state
+
+        # Early exit 2: check whether the input still needs generation (fetch variables defensively)
+        input_ids = payload.get("input_ids", [])
+        max_tokens = payload.get("max_tokens")
+
+        last_id = input_ids[-1] if len(input_ids) > 0 else "None"
+        is_max_tokens_zero = max_tokens is not None and max_tokens <= 0
+        is_eos_reached = len(input_ids) > 0 and input_ids[-1] in self.eos_token
+
+        if is_max_tokens_zero or is_eos_reached:
+            self.logger.debug(
+                f"No generation needed for request {uid}: max_tokens={max_tokens} or last input_id={last_id} is in eos_token."
+            )
+            rollout_state.status = Status.COMPLETED
+            rollout_state.response_ids = []
+            rollout_state.response = ""
+            rollout_state.logprobs = []
+            rollout_state.response_mask = []
+            rollout_state.response_model_steps = []
+            rollout_state.finish_reason = "stop" if is_eos_reached else "length"
+            return rollout_state
+
+        for attempt in range(max_retries + 1):
+            is_last_attempt = attempt == max_retries
+            http_result = await self._safe_post_request(endpoint_url, headers=headers, payload=payload)
+
+            # Case 1: the HTTP request succeeded
+            if http_result.response:
+                # Case 1.1: Valid rollout response
+                rollout_state = await self._safe_handle_response(rollout_state, http_result.response)
+                if rollout_state.status in [Status.COMPLETED, Status.ABORTED]:
+                    return rollout_state
+
+                if is_last_attempt:
+                    # Case 1.2: Invalid rollout response and no retries left, so we return FAILED
+                    self.logger.warning(
+                        f"Invalid rollout response for request {uid} after {max_retries} attempts, marking as FAILED."
+                    )
+                    rollout_state.status = Status.FAILED
+                    rollout_state.error_msg = f"Invalid rollout response after {max_retries} attempts."
+                    return rollout_state
+
+                # Case 1.3: Invalid rollout response but we have retries left
+                self.logger.warning(
+                    f"Invalid rollout response for request {uid}, retrying {attempt + 1}/{max_retries}."
+                )
+                await asyncio.sleep(0.1)
+                continue
+
+            # Case 2: an error occurred during the HTTP request
+            if http_result.error_type == HttpRequestErrorType.REQUEST_ABORTED:
+                # Case 2.1: The request was aborted due to a signal set by `receive_abort_request`
+                rollout_state.finish_reason = "abort"
+                rollout_state.status = update_status_from_finish_reason("abort")
+                return rollout_state
+
+            if http_result.is_client_error:
+                # Case 2.2: A non-retryable client error occurred (such as a 4xx HTTP status)
+                self.logger.warning(
+                    f"rollout request {uid} to {http_result.url} was skipped due to client error {http_result.error_type} with {http_result.error_msg}"
+                )
+                rollout_state.error_msg = (
+                    f"Client error {http_result.error_type} with message: {http_result.error_msg}"
+                )
+                rollout_state.status = Status.FAILED
+                return rollout_state
+
+            if http_result.is_server_error:
+                # Case 2.3: A non-retryable server error occurred (such as a 5xx HTTP status)
+                self.logger.warning(
+                    f"rollout request {uid} to {http_result.url} failed due to server error {http_result.error_type} with {http_result.error_msg}"
+                )
+                rollout_state.error_msg = (
+                    f"Server error {http_result.error_type} with message: {http_result.error_msg}"
+                )
+                rollout_state.status = Status.FAILED
+                return rollout_state
+
+            # Case 3: a retryable error occurred during the HTTP request
+            if http_result.is_retryable:
+                if is_last_attempt:
+                    self.logger.warning(
+                        f"rollout request {uid} to {http_result.url} failed after {max_retries} attempts due to retryable error {http_result.error_type} with {http_result.error_msg}"
+                    )
+                    rollout_state.error_msg = f"Request failed after {max_retries} attempts due to retryable error {http_result.error_type} with message: {http_result.error_msg}"
+                    rollout_state.status = Status.FAILED
+                    return rollout_state
+
+                self.logger.warning(
+                    f"rollout request {uid} to {http_result.url} failed due to retryable error {http_result.error_type} with {http_result.error_msg}, retrying {attempt + 1}/{max_retries}."
+                )
+                await asyncio.sleep(0.1)
+                continue
+
+            # Case 4: an unknown error occurred during the HTTP request; stop the rollout
+            if http_result.is_unknown_error:
+                raise RuntimeError(
+                    f"Unexpected error during rollout request {uid} to {http_result.url}: {http_result.exception}"
+                )
+        return rollout_state
+
+    def _launch_server(self):
+        """Launch the inference server as a separate process or Ray task.
+
+        It waits for the server to become healthy before returning.
+
+        Raises:
+            TimeoutError: If the server fails to start within the specified
+                timeout.
+            Exception: If the server task terminates unexpectedly.
+ """ + server_configs = self._transform_rollout_config_to_server_configs() + timeout = 3600.0 # Increased timeout to 5 minutes for downloading large models + start_time = time.perf_counter() + last_log_time = start_time + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {server_configs.api_key}", + } + + self.logger.info(f"Launch server task on server_url: {self.server_url}") + + # note(@duanyanhui): launch server as multiprocessing for sglang temporarily + if self.config.launch_server_method == "multiprocessing": + mp_ctx = multiprocessing.get_context("spawn") + process = mp_ctx.Process(target=self.server_func, args=(server_configs,)) + process.start() + self.server_process = process + time.sleep(60) # Wait for the server to start + with requests.Session() as session: + while time.perf_counter() - start_time < timeout: + try: + response = session.get( + f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers + ) + if response.status_code == 200: + return + except requests.RequestException as e: + self.logger.error( + f"can't connect to server url {self.server_url}/{self.endpoints['health_generate']} because {e}" + ) + + current_time = time.perf_counter() + if current_time - last_log_time >= 15: + self.logger.info( + f"Waiting for server to start, Elapsed time: {current_time - start_time:.2f}s" + ) + last_log_time = current_time + + time.sleep(5) + process.terminate() + raise TimeoutError("Server failed to start within the timeout period.") + else: + # launch the server as ray task + # so that the lmdeploy backend could get externl pg + current_pg = ray.util.get_current_placement_group() + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=current_pg, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=self.engine_bundle_idxs[0], + ) + assert ray.is_initialized() + ray_kwargs = ( + {"runtime_env": server_configs.ray_runtime_env} if hasattr(server_configs, "ray_runtime_env") else {} + ) + self.server_task = ( + ray.remote(self.server_func) + .options( + scheduling_strategy=scheduling_strategy, + **AutoAcceleratorWorkers.get_pg_options(current_pg), + **ray_kwargs, + ) + .remote(server_configs) + ) + + with requests.Session() as session: + while time.perf_counter() - start_time < timeout: + try: + response = session.get( + f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers + ) + if response.status_code == 200: + return + except requests.RequestException: + pass + + try: + ray.get(self.server_task, timeout=0.1) + raise Exception("Server task terminated unexpectedly.") + except ray.exceptions.GetTimeoutError: + pass + except Exception as e: + raise e + + current_time = time.perf_counter() + if current_time - last_log_time >= 15: + self.logger.info( + f"Waiting for server to start... 
Elapsed time: {current_time - start_time:.2f}s"
+                        )
+                        last_log_time = current_time
+
+            ray.cancel(self.server_task)
+            raise TimeoutError("Server failed to start within the timeout period.")
+
+    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
+        try:
+            if self.receive_abort_request.is_set():
+                self.logger.debug(f"Request to {url} was cancelled before sending due to an abort signal.")
+                return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload)
+            req = self.client.build_request(
+                "POST",
+                url,
+                headers=headers,
+                json=payload,
+            )
+            r = await self.client.send(req)
+            r.raise_for_status()
+            return HttpRequestResult(response=r)
+
+        except Exception as e:
+            error_type = HttpRequestErrorType.from_exception(e)
+            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
+            return result
+
+    async def _safe_handle_response(self, rollout_state: RolloutState, http_response: httpx.Response) -> RolloutState:
+        uid = rollout_state.message_uid
+        sample_params = rollout_state.sample_params
+        is_token_out = sample_params.return_token_ids
+        response = http_response.json()
+        if is_token_out:
+            response_ids: list[int] = []
+            logprobs: list[float] = []
+            routed_experts = None
+            returned_response = ""
+            finish_reason = response["meta_info"]["finish_reason"]["type"]
+            if finish_reason == "abort" and not self.receive_abort_request.is_set():
+                self.receive_abort_request.set()
+                self.logger.info(f"Setting receive_abort_request to True for rank {self.rank}")
+            try:
+                returned_response = response.get("text", "")
+                # Extract response_ids and logprobs
+                if (
+                    "output_token_logprobs" in response["meta_info"]
+                    and response["meta_info"]["output_token_logprobs"] is not None
+                ):
+                    response_ids = [item[1] for item in response["meta_info"]["output_token_logprobs"]]
+                    logprobs = [item[0] for item in response["meta_info"]["output_token_logprobs"]]
+                else:
+                    num_return_tokens = response["meta_info"].get("completion_tokens", 0)
+                    response_ids = response["output_ids"][-num_return_tokens:] if num_return_tokens > 0 else []
+
+                # Extract routed_experts
+                if self.enable_return_routed_experts:
+                    assert "routed_experts" in response["meta_info"], (
+                        "enable_return_routed_experts is True, but routed_experts is not in meta_info"
+                    )
+                    routed_experts = response["meta_info"]["routed_experts"]  # token[layer[expert]]
+                    if routed_experts is not None:
+                        routed_experts = self._decode_routed_experts(routed_experts)
+                        if not isinstance(routed_experts, ray.ObjectRef):
+                            routed_experts = ray.put(routed_experts)
+
+                # Derive the rollout status
+                rollout_status = update_status_from_finish_reason(finish_reason)
+
+                # Validate the output
+                if rollout_status == Status.COMPLETED:
+                    validation_errors = []
+
+                    if not response_ids:
+                        validation_errors.append("empty response_ids")
+
+                    if not response:
+                        validation_errors.append("empty response text")
+
+                    if sample_params.return_logprob and not logprobs:
+                        validation_errors.append("missing logprobs")
+
+                    if self.enable_return_routed_experts and routed_experts is None:
+                        validation_errors.append("missing routed_experts")
+
+                    if validation_errors:
+                        error_msg = f"Incomplete rollout data for msg {uid}: {', '.join(validation_errors)}"
+                        self.logger.error(error_msg)
+                        rollout_state.status = Status.FAILED
+                        rollout_state.error_msg = error_msg
+                        return rollout_state
+                elif rollout_status == Status.FAILED:
+                    error_msg = f"Rollout failed for msg {uid} with finish_reason {finish_reason}"
+                    self.logger.error(error_msg)
+                    rollout_state.status =
Status.FAILED + rollout_state.error_msg = error_msg + return rollout_state + + rollout_state.response = returned_response + rollout_state.response_ids = response_ids + rollout_state.logprobs = logprobs + rollout_state.routed_experts = routed_experts + rollout_state.finish_reason = finish_reason + rollout_state.status = rollout_status + return rollout_state + except KeyError as e: + error_msg = f"Missing expected key {e} in response {response} for {uid}" + raise RuntimeError(error_msg) + except IndexError as e: + error_msg = f"Index error {e} while processing response {response} for {uid}" + raise RuntimeError(error_msg) + except AssertionError as e: + error_msg = f"AssertionError: {e} when processing response {response} for {uid}" + raise RuntimeError(error_msg) + except json.JSONDecodeError as e: + error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}" + raise RuntimeError(error_msg) + except TypeError as e: + error_msg = f"TypeError: {e} when processing response {response} for {uid}" + raise RuntimeError(error_msg) + except Exception as e: + error_msg = f"Unexpected error: {e} when processing response {response} for {uid}\nTraceback: {traceback.format_exc()}" + raise RuntimeError(error_msg) + else: + # v1/chat/completions API response + try: + returned_response = response["choices"][0]["message"]["content"] + finish_reason = response["choices"][0]["finish_reason"] + rollout_status = update_status_from_finish_reason(finish_reason) + if rollout_status == Status.COMPLETED and not returned_response: + self.logger.error(f"Empty response text for msg {uid} with finish_reason {finish_reason}") + rollout_state.status = Status.FAILED + rollout_state.error_msg = "Empty response text" + return rollout_state + + rollout_state.response = returned_response + rollout_state.finish_reason = finish_reason + rollout_state.status = rollout_status + return rollout_state + except KeyError as e: + error_msg = f"Missing expected key {e} in response {response} for {uid}" + raise RuntimeError(error_msg) + except IndexError as e: + error_msg = f"Index error {e} while processing response {response} for {uid}" + raise RuntimeError(error_msg) + except AssertionError as e: + error_msg = f"AssertionError: {e} when processing response {response} for {uid}" + raise RuntimeError(error_msg) + except json.JSONDecodeError as e: + error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}" + raise RuntimeError(error_msg) + except TypeError as e: + error_msg = f"TypeError: {e} when processing response {response} for {uid}" + raise RuntimeError(error_msg) + except Exception as e: + error_msg = f"Unexpected error: {e} when processing response {response} for {uid}\nTraceback: {traceback.format_exc()}" + raise RuntimeError(error_msg) + + def _adapt_input_to_openai_spec(self, prompts, tools, tool_choice): + openai_prompts = [] + openai_tools = [] + # transform claude spec to openai spec + # 1. transform system prompt: concat provided system_prompt to input prompt + system_prompt = self.config.system_prompt + if system_prompt: + system_prompt_json = {"role": "system", "content": f"{system_prompt}"} + prompts.insert(0, system_prompt_json) + # 2. 
transform multi-modal usage
+        for prompt in prompts:
+            content = prompt["content"]
+            openai_content = []
+            for item in content:
+                if item["type"] == "image":
+                    if item["source"]["type"] == "base64":
+                        openai_url = f"data:{item['source']['media_type']};base64,{item['source']['data']}"
+                    elif item["source"]["type"] == "url":
+                        openai_url = item["source"]["url"]
+                    new_prompt = {"type": "image_url", "image_url": {"url": openai_url}}
+                    openai_content.append(new_prompt)
+                elif item["type"] == "text":
+                    openai_content.append(item)
+            new_prompt = copy.deepcopy(prompt)
+            new_prompt["content"] = openai_content
+            openai_prompts.append(new_prompt)
+        # 3. transform tool use
+        for tool in tools:
+            openai_tool = {
+                "type": "function",
+                "function": {
+                    "name": tool["name"],
+                    "description": tool["description"],
+                    "parameters": tool["input_schema"],
+                },
+            }
+            openai_tools.append(openai_tool)
+        return openai_prompts, openai_tools
+
+    def _check_infer_engine_version(self, return_token_ids: bool):
+        # TODO(@duanyanhui): remove this check when all backends support return_token_ids
+        if self.check_flag:
+            if os.environ.get("XTUNER_USE_VLLM", "0") == "1":
+                if return_token_ids:
+                    self.logger.error(
+                        "The vLLM backend does not currently support return_token_ids or generation from input_ids in XTuner"
+                    )
+            elif os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "1":
+                import lmdeploy
+
+                lmdeploy_version = lmdeploy.__version__
+                if return_token_ids and Version(lmdeploy_version) < Version("0.10.2"):
+                    self.logger.error(
+                        f"You should use lmdeploy >= v0.10.2 to support return_token_ids, but current version is {lmdeploy_version}"
+                    )
+            self.check_flag = False
+
+    def _set_engine_rank_mesh_array(self, engine_rank_mesh_array: list[list[int]]):
+        self.engine_rank_mesh_array = engine_rank_mesh_array
+
+    def _set_engine_bundle_idxs(self, engine_bundle_idxs: list[int]):
+        """Set the bundle indices for the inference engine.
+
+        This is used by some backends (like LMDeploy with Ray executor) to
+        know which bundles in the placement group belong to this engine.
+
+        Args:
+            engine_bundle_idxs (list[int]): A list of bundle indices.
+        """
+        self.engine_bundle_idxs = engine_bundle_idxs
+
+    @abstractmethod
+    def _get_request_payload(self, rollout_state: RolloutState) -> dict:
+        """Abstract method to create a generation request.
+
+        Must be implemented by subclasses.
+        """
+        raise NotImplementedError("_get_request_payload must be implemented in subclass")
+
+    @abstractmethod
+    def _transform_rollout_config_to_server_configs(self):
+        """Abstract method to transform rollout config to server configs.
+
+        Must be implemented by subclasses.
+        """
+        raise NotImplementedError("_transform_rollout_config_to_server_configs must be implemented in subclass")
+
+    @abstractmethod
+    def _transform_sample_params(self, sample_params: SampleParams) -> dict:
+        """Abstract method to transform sample parameters into backend-specific generation parameters.
+
+        Must be implemented by subclasses.
+        """
+        raise NotImplementedError("_transform_sample_params must be implemented in subclass")
+
+    @abstractmethod
+    def offload(self):
+        """Abstract method to offload the model and KV cache.
+
+        Must be implemented by subclasses.
+        """
+        raise NotImplementedError("offload must be implemented in subclass")
+
+    @abstractmethod
+    def onload_weights(self):
+        """Abstract method to onload the model weights.
+
+        Must be implemented by subclasses.
+        """
+        pass
+
+    @abstractmethod
+    def onload_kvcache(self):
+        """Abstract method to onload the KV cache.
+ + Must be implemented by subclasses. + """ + pass diff --git a/xtuner/v1/rl/base/rollout_is.py b/xtuner/v1/rl/rollout_is.py similarity index 100% rename from xtuner/v1/rl/base/rollout_is.py rename to xtuner/v1/rl/rollout_is.py diff --git a/xtuner/v1/rl/trainer/__init__.py b/xtuner/v1/rl/trainer/__init__.py new file mode 100644 index 0000000000..2b7c95c235 --- /dev/null +++ b/xtuner/v1/rl/trainer/__init__.py @@ -0,0 +1,27 @@ +from ..rollout_is import ( + RolloutImportanceSampling, + compute_is_metrics, + compute_mismatch_metrics, + compute_rollout_importance_weights, + merge_rollout_is_metrics, +) +from .controller import ColateItem, TrainingController +from .worker import TrainingWorker, WorkerConfig, WorkerInputItem, WorkerLogItem, WorkerTrainLogItem +from .update_weighter import UpdateWeighter + + +__all__ = [ + "ColateItem", + "TrainingController", + "RolloutImportanceSampling", + "compute_rollout_importance_weights", + "compute_is_metrics", + "compute_mismatch_metrics", + "merge_rollout_is_metrics", + "UpdateWeighter", + "WorkerConfig", + "WorkerInputItem", + "WorkerTrainLogItem", + "WorkerLogItem", + "TrainingWorker", +] diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/trainer/controller.py similarity index 93% rename from xtuner/v1/rl/base/controller.py rename to xtuner/v1/rl/trainer/controller.py index 8f036ca991..af22760a60 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/trainer/controller.py @@ -4,13 +4,12 @@ import ray import torch -from ray.actor import ActorProxy from xtuner.v1.data_proto.sequence_context import SequenceContext from xtuner.v1.model.compose.base import BaseComposeConfig -from xtuner.v1.ray.utils import free_object_refs +from xtuner.v1.rl.utils import free_object_refs from xtuner.v1.train.trainer import LoadCheckpointConfig -from xtuner.v1.utils import ray_method +from xtuner.v1.utils import get_logger from .worker import TrainingWorker, WorkerLogItem @@ -25,9 +24,10 @@ class ColateItem(TypedDict): rollout_logprobs: torch.Tensor | None -class RawTrainingController: +class TrainingController: def __init__(self, workers: list[TrainingWorker]) -> None: self.workers = workers + self.logger = get_logger() # TODO(hha): 这个逻辑不够通用,应该复用 sft 函数,从而支持 expand soft pack def _get_pack_infos(self, dataset, num_tokens, target, random=None): @@ -116,6 +116,7 @@ def _packing(self, data_batches, pack_max_length, language_cfg): dtype=data_batches[0]["shifted_labels"].dtype, device=data_batches[0]["shifted_labels"].device, ) + pad_advantages = [-100] * pad_len if is_qwen3_vl: _position_ids_list = [] for pad_token in pad_tokens: @@ -129,10 +130,7 @@ def _packing(self, data_batches, pack_max_length, language_cfg): seq_ctx_list.append(pad_seq_ctx) label_list.append(pad_labels) - advantage_list.extend( - [-100] * math.ceil(pad_len / 1024) - ) # can be any number, pad tokens are excluded from the calculation of the loss function. 
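The hunk above and below changes how advantages are packed: instead of one scalar per sample expanded later with `repeat_interleave` over `cu_seq_lens_q`, each sample now contributes a per-token advantage list, pads contribute `[-100] * pad_len`, and everything is flattened at the end. A minimal sketch of the new flattening, using hypothetical per-sample values and a hypothetical `pad_len`:

```python
import torch

# Per-token advantages for two real samples plus one pad block; pad positions
# get -100, which is masked out of the loss, so any constant would work.
advantage_list = [
    [0.5] * 3,   # sample 1: 3 tokens
    [-0.2] * 2,  # sample 2: 2 tokens
    [-100] * 4,  # pad block: pad_len = 4
]

# Flatten to one advantage per packed token, matching the (1, max_len) layout.
advantage_flat = [item for sublist in advantage_list for item in sublist]
advantages = torch.tensor(advantage_flat, dtype=torch.float32).unsqueeze(0)
assert advantages.shape == (1, 9)
```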
- + advantage_list.append(pad_advantages) if rollout_logprobs_list is not None: pad_rollout_logprobs = torch.zeros( 1, @@ -144,10 +142,8 @@ def _packing(self, data_batches, pack_max_length, language_cfg): seq_ctx = SequenceContext.cat(seq_ctx_list) shifted_labels = torch.cat(label_list, dim=1) # (1, max_len) - advantages = torch.tensor(advantage_list).float().unsqueeze(0) # (1, num_samples) - cu_seq_lens_q = seq_ctx.cu_seq_lens_q - num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1] - advantages = torch.repeat_interleave(advantages, num_tokens, dim=1) # (1, max_len) + advantage_flat = [item for sublist in advantage_list for item in sublist] + advantages = torch.tensor(advantage_flat, dtype=torch.float32).unsqueeze(0) rollout_logprobs = None if rollout_logprobs_list is not None: @@ -169,7 +165,6 @@ def _grouped_by_max_length(self, packed_data_batches): # 排序后这条 pack 会被放在最前面,导致 rank0 的第一个 step 消耗的有效 token 数往往少于其他 rank,是正常现象。 return sorted(packed_data_batches, key=lambda x: x["seq_ctx"].max_length_q, reverse=True) - @ray_method def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int) -> list[WorkerLogItem]: has_rollout_routed_experts = False language_cfg = None @@ -274,7 +269,6 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: del packed_data_batches return log_infos - @ray_method def offload(self, target: Literal["model", "optimizer", "all"] = "all"): if target == "model": ray.get([worker.offload_model.remote() for worker in self.workers], timeout=TRAIN_RAY_GET_TIMEOUT) # type: ignore @@ -285,7 +279,6 @@ def offload(self, target: Literal["model", "optimizer", "all"] = "all"): ray.get([worker.offload_optimizer.remote() for worker in self.workers], timeout=TRAIN_RAY_GET_TIMEOUT) # type: ignore return - @ray_method def onload(self, target: Literal["model", "optimizer", "all"] = "all"): """Onload the model or optimizer of the training workers.""" if target == "model": @@ -297,41 +290,28 @@ def onload(self, target: Literal["model", "optimizer", "all"] = "all"): ray.get([worker.onload_optimizer.remote() for worker in self.workers], timeout=TRAIN_RAY_GET_TIMEOUT) # type: ignore return - @ray_method def update_rollout_info(self, info_dict): ray.get([worker.update_rollout_info.remote(**info_dict) for worker in self.workers]) # type: ignore[attr-defined] - @ray_method def update_weights(self): """Update the weights of the training workers.""" handles = [worker.update_weights.remote() for worker in self.workers] ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT) return - @ray_method def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16): handles = [worker.save_hf.remote(hf_dir, save_dtype) for worker in self.workers] # type: ignore ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT) return - @ray_method def resume(self, load_checkpoint_cfg: LoadCheckpointConfig): """Resume the training workers from the checkpoint.""" handles = [worker.resume.remote(load_checkpoint_cfg) for worker in self.workers] # type: ignore ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT) return - @ray_method def save(self, dcp_dir: str, no_save_optimizer: bool = False): """Save the DCP checkpoint of the training workers.""" handles = [worker.save.remote(dcp_dir, no_save_optimizer) for worker in self.workers] # type: ignore ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT) return - - @ray_method - def ready(self) -> bool: - return True - - -TrainingController = ray.remote(RawTrainingController) -TrainingControllerProxy = ActorProxy[RawTrainingController] diff 
--git a/xtuner/v1/rl/trainer/update_weighter.py b/xtuner/v1/rl/trainer/update_weighter.py new file mode 100644 index 0000000000..3a85ec97e0 --- /dev/null +++ b/xtuner/v1/rl/trainer/update_weighter.py @@ -0,0 +1,525 @@
+import os
+from itertools import chain
+from typing import Dict, List, TypeAlias, cast
+
+import requests
+import torch
+import torch.distributed as dist
+import tqdm
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import DTensor
+
+from xtuner.v1.model.compose.base import BaseComposeConfig
+from xtuner.v1.model.compose.qwen3_vl import Qwen3VLForConditionalGeneration
+from xtuner.v1.model.moe.moe import MoE
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.utils import (
+    get_device,
+    get_torch_device_module,
+    monkey_unpatch_torch_reductions,
+    ray_method,
+)
+from xtuner.v1.utils.load_spec import LoadEnum, LoadSpec
+
+
+DeviceMeshRaw: TypeAlias = List[List[int]]  # A list of lists representing device mesh indices
+ServiceUrlMap: TypeAlias = Dict[int, str]  # Maps worker ranks to their server URLs
+DEVICE = get_device()
+DEVICE_MODULE = get_torch_device_module()
+
+
+class UpdateWeighter:
+    def _init_update_weighter(self):
+        self.rollout_device_mesh: DeviceMesh | None = None
+        self.rollout_url: str | None = None
+        self.rollout_cfg_info: dict = dict()
+        self.endpoints: dict[str, str] = dict()
+        self.endpoints["update_weights"] = "update_weights"
+        self._global_hf_keys_mapping_cache: dict[str, list[str]] = dict()
+        self._ipc_tensor_bytes: int = int(self.config.update_weight_bucket_size_in_gb * 1024**3)
+        self._update_params_ipc_tensor = None
+        self._update_params_ipc_event = None
+
+    @ray_method
+    def update_rollout_info(
+        self,
+        engine_rank_mesh_array: DeviceMeshRaw,
+        server_url_dict: ServiceUrlMap,
+        rollout_config: RolloutConfig,
+        worker_server_urls_status: Dict[str, bool],
+        api_server_url: str | None = None,
+    ):
+        tp = rollout_config.tensor_parallel_size
+        ep = rollout_config.expert_parallel_size
+        assert tp == 1 or ep == 1, "Either tensor parallel size or expert parallel size must be 1."
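+        # Note: the weight-update path assumes a one-dimensional engine-parallel
+        # layout: an engine is sharded by TP or by EP, never both at once.
+        # `engine_rank_mesh_array` maps each engine instance to the ranks that
+        # back it, e.g. [[0, 1], [2, 3]] for two engines of two ranks each.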
+        if self.rollout_device_mesh is None:
+            self.rollout_device_mesh = DeviceMesh(
+                "cpu", mesh=engine_rank_mesh_array, mesh_dim_names=("engine_instance", "engine_parallel")
+            )
+        rollout_server_url = server_url_dict.get(self.rank, "")
+        if worker_server_urls_status.get(rollout_server_url, False) is False:
+            self.logger.error(f"Rollout server url {rollout_server_url} is not available.")
+            self.rollout_url = None
+        else:
+            self.rollout_url = rollout_server_url
+        self.rollout_cfg_info["tp"] = tp
+        self.rollout_cfg_info["ep"] = ep
+        self.rollout_cfg_info["api_key"] = rollout_config.api_key
+        if os.environ.get("XTUNER_USE_SGLANG", "0") == "1":
+            self.rollout_cfg_info["backend"] = "sglang"
+        else:
+            self.rollout_cfg_info["backend"] = (rollout_config.extra_rollout_config or dict()).get(
+                "lmdeploy_backend", "pytorch"
+            )
+
+    @ray_method
+    def update_weights(self):
+        DEVICE_MODULE.empty_cache()
+        self._update_params_ipc_event = DEVICE_MODULE.Event(interprocess=True)
+        if self.rollout_cfg_info.get("backend") == "turbomind":
+            self._update_weights_by_layer()
+        else:
+            if isinstance(self.config.model_cfg, BaseComposeConfig):
+                self._update_weights_hf_generator(submodule="language_model", final_update=False)
+                self._update_weights_hf_generator(submodule="vision_tower", final_update=False)
+                self._update_weights_hf_generator(submodule="multi_modal_projector", final_update=True)
+            else:
+                self._update_weights_hf_generator(final_update=True)
+        self._update_params_ipc_tensor = None
+        self._update_params_ipc_event = None
+        DEVICE_MODULE.empty_cache()
+
+    def _rl_get_fused_ep_hf_param(self, model: MoE, target_ep_rank: int, target_ep_size: int, bucket_size: int):
+        fused_param_groups: list[tuple[torch.Tensor, LoadSpec]] = model._group_param_by_load_spec(LoadEnum.FUSED)
+        model_ep_size = 1 if model.fsdp_config is None else model.fsdp_config.ep_size
+        if not fused_param_groups:
+            return
+
+        def _get_hf_params(
+            fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]],
+        ) -> tuple[list[torch.Tensor], list[str]]:
+            hf_keys_list: list[str] = []
+            hf_tensor_list: list[torch.Tensor] = []
+
+            for fsdp_tensor, load_spec in fsdp_tensor_list:
+                hf_keys = load_spec.hf_keys
+                if model_ep_size > 1 and model.ep_mesh is not None:
+                    if load_spec.name not in self._global_hf_keys_mapping_cache:
+                        global_hf_keys: list[list[str] | None] = [None] * model_ep_size
+                        dist.all_gather_object(global_hf_keys, hf_keys, group=model.ep_mesh.get_group())
+                        global_hf_keys_gathered = cast(list[list[str]], global_hf_keys)
+                        self._global_hf_keys_mapping_cache[load_spec.name] = list(
+                            chain.from_iterable(global_hf_keys_gathered)
+                        )
+                    hf_keys = self._global_hf_keys_mapping_cache[load_spec.name]
+
+                fused_full_tensor = fsdp_tensor.bfloat16()
+                if isinstance(fused_full_tensor, DTensor):
+                    fused_full_tensor = fused_full_tensor.full_tensor()
+                dim = cast(int, load_spec.dim)
+                num_split = len(hf_keys)
+                hf_tensor_size = fused_full_tensor.shape[dim] / num_split
+                assert hf_tensor_size.is_integer(), "Internal Error, hf_tensor_size is not an integer"
+                hf_tensor_size = int(hf_tensor_size)
+
+                hf_tensor = fused_full_tensor.split([hf_tensor_size] * num_split, dim=dim)
+                assert num_split % target_ep_size == 0, (
+                    f"len(hf_keys) of '{hf_keys}' is {num_split}, it must be divisible by target_ep_size {target_ep_size}"
+                )
+                start_idx = (num_split // target_ep_size) * target_ep_rank
+                end_idx = (num_split // target_ep_size) * (target_ep_rank + 1)
+
+                hf_keys_list.extend(hf_keys[start_idx:end_idx])
+                hf_tensor_list.extend(hf_tensor[start_idx:end_idx])
+
+
hf_tensor_list = [ + model.param_to_safetensor(safetensor, name) for safetensor, name in zip(hf_tensor_list, hf_keys_list) + ] + + return hf_tensor_list, hf_keys_list + + safetensor_size = 0 + dtype = torch.bfloat16 + tensor_list: list[tuple[torch.Tensor, LoadSpec]] = [] + + for param, load_spec in fused_param_groups: + tensor_size = dtype.itemsize * param.numel() // target_ep_size + if safetensor_size + tensor_size > bucket_size and tensor_list: + hf_params, name_list = _get_hf_params(tensor_list) + yield name_list, hf_params + safetensor_size = tensor_size + name_list = load_spec.hf_keys.copy() + tensor_list = [(param, load_spec)] + continue + safetensor_size += tensor_size + tensor_list.append((param, load_spec)) + + if tensor_list: + hf_params, name_list = _get_hf_params(tensor_list) + yield name_list, hf_params + + @torch.no_grad() + def _update_weights_hf_generator(self, submodule=None, final_update=False): + self.endpoints["update_weights"] = "update_weights" + assert self.rollout_device_mesh is not None + + model = self._engine.model + if submodule: + model = getattr(model, submodule) + + dtype = torch.bfloat16 + bucket_size = int(self.config.update_weight_bucket_size_in_gb * 1024**3) + same_gen = model._get_same_hf_param( + model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE, bucket_size=bucket_size + ) + + train_enable_ep = model.fsdp_config is not None and model.fsdp_config.ep_size > 1 + if train_enable_ep: + if self.rollout_cfg_info["ep"] > 1: + fused_gen = self._rl_get_fused_ep_hf_param( + model, + target_ep_rank=self.rollout_device_mesh["engine_parallel"].get_coordinate()[0], + target_ep_size=self.rollout_device_mesh["engine_parallel"].size(), + bucket_size=bucket_size, + ) + else: + fused_gen = self._rl_get_fused_ep_hf_param( + model, + target_ep_rank=0, + target_ep_size=1, + bucket_size=bucket_size, + ) + else: + fused_gen = model._get_fused_hf_param( + model._group_param_by_load_spec(LoadEnum.FUSED), + dtype=dtype, + device=DEVICE, + bucket_size=bucket_size, + update_weights_for_rl=True, + ) + shard_gen = model._get_shard_hf_param( + model._group_param_by_load_spec(LoadEnum.SHARD), dtype=dtype, device=DEVICE, bucket_size=bucket_size + ) + + for name_list, fused_param_list in fused_gen: + state_dict = {name: param.detach() for name, param in zip(name_list, fused_param_list)} + self.request_update_params(state_dict, train_enable_ep=train_enable_ep, finished=False) + del state_dict, name_list, fused_param_list + + for name_list, param_list in chain(same_gen, shard_gen): + state_dict = {name: param.detach() for name, param in zip(name_list, param_list)} + self.request_update_params(state_dict, train_enable_ep=train_enable_ep, finished=False) + del state_dict, name_list, param_list + + if self.rollout_cfg_info["backend"] == "pytorch" and final_update: + self.request_update_params({}, train_enable_ep=train_enable_ep, finished=True) + + dist.barrier() + return + + def _update_weights_by_layer(self): + self.endpoints["update_weights"] = "update_weights" + assert self.rollout_device_mesh is not None + + model = self._engine.model + DEVICE_MODULE.empty_cache() + + if isinstance(model.config, BaseComposeConfig): + dtype = torch.bfloat16 + else: + if (model.config.float8_cfg is not None) and (model.config.float8_cfg.enable_float8): + dtype = torch.float8_e4m3fn + else: + dtype = torch.bfloat16 + + def get_params(tensor_list, name_list, save_dtype): + _tensor_list, _spec_list = list(zip(*tensor_list)) + fsdp_unshard_tensor_list = 
model._fsdp_foreach_allgather(_tensor_list, _spec_list) + if save_dtype == torch.float8_e4m3fn: + fsdp_unshard_tensor_list, name_list = model._to_float8( + fsdp_unshard_tensor_list, name_list, _tensor_list, save_dtype + ) + return fsdp_unshard_tensor_list, name_list + + saved_list = [] + is_qwen3vl = False + if isinstance(model.config, BaseComposeConfig): + language_model = model.language_model + if isinstance(model, Qwen3VLForConditionalGeneration): + is_qwen3vl = True + else: + language_model = model + + if is_qwen3vl: + vision_hf_prefix = "model.visual." + projector_hf_prefix = "model.visual." + else: + vision_hf_prefix = "model.vision_tower." + projector_hf_prefix = "model.multi_modal_projector." + + for i, layer in tqdm.tqdm(language_model.layers.items(), desc="[gather weight]"): + tensor_list = [] + name_list = [] + for sub_name, param in layer.state_dict().items(): + if isinstance(model.config, BaseComposeConfig): + saved_list.append(f"language_model.layers.{i}.{sub_name}") + else: + saved_list.append(f"layers.{i}.{sub_name}") + local_tensor = param._local_tensor if isinstance(param, DTensor) else param + local_tensor = local_tensor.bfloat16() + load_spec = language_model.load_spec_mapping.get(f"layers.{i}.{sub_name}") + + if isinstance(model.config, BaseComposeConfig): + name = f"model.language_model.layers.{i}.{sub_name}" + else: + name = f"model.layers.{i}.{sub_name}" + + if ".experts." in name and ".mlp.experts." not in name: + name = name.replace(".experts.", ".mlp.experts.") + if ".gate." in name and ".mlp.gate." not in name: + name = name.replace(".gate.", ".mlp.gate.") + name_list.append(name) + tensor_list.append((local_tensor, load_spec)) + fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) + state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) + self.request_update_params(state_dict) + + for name, param in model.state_dict().items(): + if name in saved_list: + continue + local_tensor = param._local_tensor if isinstance(param, DTensor) else param + local_tensor = local_tensor.bfloat16() + load_spec = model.load_spec_mapping.get(name) + + if isinstance(model.config, BaseComposeConfig): + if "vision_tower." in name: + name = name.replace("vision_tower.", vision_hf_prefix) + elif "multi_modal_projector." 
in name: + name = name.replace("multi_modal_projector.", projector_hf_prefix) + elif name == "language_model.norm.weight": + name = "model.language_model.norm.weight" + elif name == "language_model.embed_tokens.weight": + name = "model.language_model.embed_tokens.weight" + elif name == "language_model.lm_head.weight": + name = "lm_head.weight" + else: + if name == "norm.weight": + name = "model.norm.weight" + elif name == "embed_tokens.weight": + name = "model.embed_tokens.weight" + tensor_list = [(local_tensor, load_spec)] + name_list = [name] + fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) + state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) + self.request_update_params(state_dict) + + if self.rollout_cfg_info["backend"] == "pytorch": + self.request_update_params({}, finished=True) + + dist.barrier() + DEVICE_MODULE.empty_cache() + return + + @staticmethod + def _compute_state_dict_bytes(state_dict: Dict[str, torch.Tensor]) -> int: + total_bytes = 0 + for tensor in state_dict.values(): + total_bytes += tensor.numel() * tensor.element_size() + return total_bytes + + @staticmethod + def _create_ipc_tensor(size_in_bytes: int, dtype: torch.dtype): + return torch.empty(size_in_bytes, dtype=torch.uint8, device=DEVICE).view(dtype) + + @ray_method + def request_update_params(self, state_dict, train_enable_ep=False, finished=False): + cpu_mesh = self.rollout_device_mesh["engine_parallel"] + cpu_group = cpu_mesh.get_group() + head_rank = cpu_mesh.mesh[0].item() + if self.rollout_url is None: + self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip") + return + + if self.rollout_cfg_info["backend"] == "pytorch": + from lmdeploy.utils import serialize_state_dict + + try: + from lmdeploy.utils import FlattenedTensorBucket + + use_flattened_tensor_bucket = True + except Exception: + use_flattened_tensor_bucket = False + + if self.rollout_cfg_info["backend"] == "pytorch" and self.rollout_cfg_info["tp"] > 1: + serialized_data = [None] * self.rollout_cfg_info["tp"] + if use_flattened_tensor_bucket and state_dict: + state_dict_bytes = self._compute_state_dict_bytes(state_dict) + send_ipc_tensor = ( + state_dict_bytes > self._ipc_tensor_bytes or self._update_params_ipc_tensor is None + ) + if send_ipc_tensor: + self._ipc_tensor_bytes = max(self._ipc_tensor_bytes, state_dict_bytes) + if self._update_params_ipc_tensor is not None: + self._update_params_ipc_event.wait() + torch.cuda.synchronize() + self._update_params_ipc_tensor = self._create_ipc_tensor( + self._ipc_tensor_bytes, state_dict[next(iter(state_dict))].dtype + ) + else: + self._update_params_ipc_event.wait() + + flattened_tensor_bucket = FlattenedTensorBucket( + named_tensors=list(state_dict.items()), + flattened_tensor=self._update_params_ipc_tensor, + ) + metadata = flattened_tensor_bucket.get_metadata() + flattened_tensor_data = { + "metadata": metadata, + "require_clone": False, + } + self._update_params_ipc_event.record() + + if send_ipc_tensor: + flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor() + flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle() + tp_serialized_data = serialize_state_dict(flattened_tensor_data) + else: + tp_serialized_data = serialize_state_dict(state_dict) + dist.gather_object( + tp_serialized_data, + serialized_data if dist.get_rank() == head_rank else None, + dst=head_rank, + group=cpu_group, + ) + elif self.rollout_cfg_info["backend"] == "pytorch": + if 
use_flattened_tensor_bucket and state_dict: + state_dict_bytes = self._compute_state_dict_bytes(state_dict) + send_ipc_tensor = ( + state_dict_bytes > self._ipc_tensor_bytes or self._update_params_ipc_tensor is None + ) + if send_ipc_tensor: + self._ipc_tensor_bytes = max(self._ipc_tensor_bytes, state_dict_bytes) + if self._update_params_ipc_tensor is not None: + self._update_params_ipc_event.wait() + torch.cuda.synchronize() + self._update_params_ipc_tensor = self._create_ipc_tensor( + self._ipc_tensor_bytes, state_dict[next(iter(state_dict))].dtype + ) + else: + self._update_params_ipc_event.wait() + + flattened_tensor_bucket = FlattenedTensorBucket( + named_tensors=list(state_dict.items()), + flattened_tensor=self._update_params_ipc_tensor, + ) + metadata = flattened_tensor_bucket.get_metadata() + flattened_tensor_data = { + "metadata": metadata, + "require_clone": False, + } + self._update_params_ipc_event.record() + + if send_ipc_tensor: + flattened_tensor_data["flattened_tensor"] = flattened_tensor_bucket.get_flattened_tensor() + flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle() + serialized_data = serialize_state_dict(flattened_tensor_data) + else: + serialized_data = serialize_state_dict(state_dict) + else: + serialized_data = serialize_state_dict(state_dict) if dist.get_rank() == head_rank else None + else: + from sglang.srt.utils import MultiprocessingSerializer + from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions + + try: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket + + use_flattened_tensor_bucket = True + except Exception: + use_flattened_tensor_bucket = False + + monkey_patch_torch_reductions() + state_dict = state_dict.items() + if self.rollout_cfg_info["tp"] == 1: + if use_flattened_tensor_bucket: + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_dict) + metadata = flattened_tensor_bucket.get_metadata() + + flattened_tensor_data = { + "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), + "metadata": metadata, + } + serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) + else: + serialized_data = MultiprocessingSerializer.serialize(state_dict, output_str=True) + + serialized_data = [serialized_data] + else: + serialized_data = [None] * self.rollout_cfg_info["tp"] + if use_flattened_tensor_bucket: + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_dict) + metadata = flattened_tensor_bucket.get_metadata() + + flattened_tensor_data = { + "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), + "metadata": metadata, + } + tp_serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) + dist.gather_object( + tp_serialized_data, + serialized_data if dist.get_rank() == head_rank else None, + dst=head_rank, + group=cpu_group, + ) + else: + tp_serialized_data = MultiprocessingSerializer.serialize(state_dict, output_str=True) + dist.gather_object( + tp_serialized_data, + serialized_data if dist.get_rank() == head_rank else None, + dst=head_rank, + group=cpu_group, + ) + + if dist.get_rank() == head_rank: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.rollout_cfg_info['api_key']}", + } + if self.rollout_cfg_info["backend"] == "sglang": + payload = { + "serialized_named_tensors": serialized_data, + "flush_cache": False, + } + try: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket + + 
use_flattened_tensor_bucket = True + except Exception: + use_flattened_tensor_bucket = False + if use_flattened_tensor_bucket: + payload["load_format"] = "flattened_bucket" + + url = f"{self.rollout_url}/update_weights_from_tensor" + response = requests.post(url, json=payload or {}) + response.raise_for_status() + else: + data = dict(serialized_named_tensors=serialized_data, finished=finished) + try: + from lmdeploy.utils import FlattenedTensorBucket + + use_flattened_tensor_bucket = True + except Exception: + use_flattened_tensor_bucket = False + + if use_flattened_tensor_bucket: + data["load_format"] = "flattened_bucket" + response = requests.post( + f"{self.rollout_url}/{self.endpoints['update_weights']}", headers=headers, json=data + ) + assert response.status_code == 200, f"response.status_code = {response.status_code}" + + if finished or (self.rollout_cfg_info["backend"] == "pytorch" and train_enable_ep and self.rollout_cfg_info["tp"] > 1): + dist.barrier(group=cpu_group) + + monkey_unpatch_torch_reductions() + return diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/trainer/worker.py similarity index 54% rename from xtuner/v1/rl/base/worker.py rename to xtuner/v1/rl/trainer/worker.py index fb374472a0..ba495beef2 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/trainer/worker.py @@ -2,20 +2,21 @@ import math import os import time -from itertools import chain from pathlib import Path -from typing import Dict, Iterable, List, Sequence, TypeAlias, TypedDict, cast +from typing import TYPE_CHECKING, Iterable, List, Sequence, TypedDict, cast + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +import numpy as np import ray -import requests import torch import torch.distributed as dist -import tqdm from mmengine.runner import set_random_seed from pydantic import BaseModel, ConfigDict from ray.actor import ActorClass, ActorProxy -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh -from torch.distributed.tensor import DTensor +from torch.distributed.device_mesh import init_device_mesh from typing_extensions import NotRequired from transformers import AutoTokenizer @@ -31,11 +32,9 @@ from xtuner.v1.model.base import BaseModel as XtunerBaseModel from xtuner.v1.model.base import ModelItem, TransformerConfig from xtuner.v1.model.compose.base import BaseComposeConfig, BaseComposeModel -from xtuner.v1.model.compose.qwen3_vl import Qwen3VLForConditionalGeneration from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo -from xtuner.v1.ray.base import SingleAcceleratorWorker -from xtuner.v1.ray.config import RolloutConfig -from xtuner.v1.rl.base.loss import BaseRLLossContext +from xtuner.v1.rl.loss import BaseRLLossConfig, BaseRLLossContext, kl_penalty +from xtuner.v1.rl.utils import SingleAcceleratorWorker, gather_logprobs from xtuner.v1.train.trainer import LoadCheckpointConfig from xtuner.v1.utils import ( XTUNER_DETERMINISTIC, @@ -43,47 +42,17 @@ get_device, get_logger, get_torch_device_module, - monkey_unpatch_torch_reductions, ray_method, ) -from xtuner.v1.utils.load_spec import LoadEnum -from ..loss_fn import kl_penalty -from .loss import BaseRLLossConfig -from .rollout_is import merge_rollout_is_metrics +from ..rollout_is import merge_rollout_is_metrics +from .update_weighter import UpdateWeighter -DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices -ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping service names to their URLs DEVICE = get_device() DEVICE_MODULE = 
get_torch_device_module() -def serialize_state_dict(state_dict: dict) -> str: - """Serialize state dict to str. - - The consumer should use it on same node. As the producer and consumer may - have different GPU visibility, we use reduce_tensor instead of ForkingPickler.dumps - to fix the device_id when loading the serialized tensor. - - Args: - state_dict (dict[str, torch.Tensor]): state dict to serialize. - Returns: - str: serialized state dict. - """ - import base64 - from io import BytesIO - from multiprocessing.reduction import ForkingPickler - - from torch.multiprocessing.reductions import reduce_tensor - - data = [(k, reduce_tensor(v)) for k, v in state_dict.items()] - buf = BytesIO() - ForkingPickler(buf).dump(data) - buf.seek(0) - return base64.b64encode(buf.read()).decode("utf-8") - - def calculate_entropy( shifted_labels_list: Sequence[torch.Tensor], old_logprobs_list: Sequence[torch.Tensor | None], @@ -166,6 +135,26 @@ class WorkerConfig(BaseModel): rollout_steps_per_sft: int = 1 sft_loss_cfg: CELossConfig = CELossConfig() + def build(self, placement_group: "PlacementGroup"): + """Build training workers and controller from this config and placement + group.""" + # import here to avoid circular import + from xtuner.v1.rl.trainer.controller import TrainingController + from xtuner.v1.rl.utils import AutoAcceleratorWorkers + + TrainingWorkerCls = ray.remote( + runtime_env={ + "env_vars": { + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", + "HCCL_NPU_SOCKET_PORT_RANGE": "auto", + } + } + )(TrainingWorker) + train_workers, _ = AutoAcceleratorWorkers.from_placement_group(TrainingWorkerCls, self, placement_group) + ray.wait([w.ready.remote() for w in train_workers]) + return TrainingController(workers=train_workers) + class WorkerInputItem(TypedDict): seq_ctx: SequenceContext @@ -189,7 +178,7 @@ class WorkerLogItem(TypedDict): sft_train_metrics: NotRequired[dict[str, float]] -class TrainingWorker(SingleAcceleratorWorker): +class TrainingWorker(SingleAcceleratorWorker, UpdateWeighter): _SAVE_OPTIMIZER_DIR = "optimizer" _SAVE_MODEL_DIR = "model" _SAVE_SFT_DATALOADER_DIR = "sft_dataloader" @@ -240,18 +229,12 @@ def __init__( ) self._optimizer_steps = worker_cfg.optimizer_steps - - # Used to update weight to rollout engine - self.rollout_device_mesh: DeviceMesh | None = None - self.rollout_url: str | None = None - self.rollout_cfg_info: dict = dict() - self.endpoints: dict[str, str] = dict() - self.endpoints["update_weights"] = "update_weights" if worker_cfg.loss_cfg.chunk_size is not None: mode = "chunk" else: mode = "eager" self.logprob_cfg = LogProbConfig(chunk_size=worker_cfg.loss_cfg.chunk_size, mode=mode) + self._init_update_weighter() def _init_sft(self, worker_cfg: WorkerConfig): self._sft_dataloader_config = worker_cfg.sft_dataloader_cfg @@ -483,13 +466,16 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo seq_ctx = data["seq_ctx"] pixel_values = seq_ctx.pixel_values if pixel_values is not None: - if not isinstance(pixel_values, torch.Tensor): + if not isinstance(pixel_values, np.ndarray): assert isinstance(pixel_values, list), ( f"pixel_values should be list of tensor, got {type(pixel_values)}" ) - pixel_value_refs = list(pixel_values) - pixel_values = torch.cat(ray.get(pixel_value_refs), dim=0) + pixel_values = [ray.get(pixel_obf) for pixel_obf in pixel_values] + pixel_values = [torch.as_tensor(pixel_value) for pixel_value in pixel_values] + pixel_values = torch.cat(pixel_values, 
dim=0) seq_ctx.pixel_values = pixel_values + else: + raise NotImplementedError("The case where pixel_values is a numpy array is not implemented yet.") rollout_routed_experts = seq_ctx.rollout_routed_experts if rollout_routed_experts is not None: @@ -816,605 +802,6 @@ def onload_model(self): def onload_optimizer(self): self._engine.put_optimizer_to_device(DEVICE) - @ray_method - def update_rollout_info( - self, - engine_rank_mesh_array: DeviceMeshRaw, - server_url_dict: ServiceUrlMap, - rollout_config: RolloutConfig, - worker_server_urls_status: Dict[str, bool], - ): - """Update the rollout information for the training worker.""" - tp = rollout_config.tensor_parallel_size - ep = rollout_config.expert_parallel_size - assert tp == 1 or ep == 1, "Either tensor parallel size or engine parallel size must be 1." - if self.rollout_device_mesh is None: - self.rollout_device_mesh = DeviceMesh( - "cpu", mesh=engine_rank_mesh_array, mesh_dim_names=("engine_instance", "engine_parallel") - ) - rollout_server_url = server_url_dict.get(self.rank, "") - if worker_server_urls_status.get(rollout_server_url, "False") is False: - self.logger.error(f"Rollout server url {rollout_server_url} is not available.") - self.rollout_url = None - else: - self.rollout_url = rollout_server_url - self.rollout_cfg_info["tp"] = tp - self.rollout_cfg_info["ep"] = ep - self.rollout_cfg_info["api_key"] = rollout_config.api_key - if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": - self.rollout_cfg_info["backend"] = "sglang" - elif os.environ.get("XTUNER_USE_VLLM", "0") == "1": - self.rollout_cfg_info["backend"] = "vllm" - else: - self.rollout_cfg_info["backend"] = (rollout_config.extra_rollout_config or dict()).get( - "lmdeploy_backend", "pytorch" - ) - - @ray_method - def update_weights(self): - """Update the model weights.""" - if self.rollout_cfg_info.get("backend") == "turbomind": - self._update_weights_by_layer() - else: - if isinstance(self.config.model_cfg, BaseComposeConfig): - self._update_weights_hf_generator(submodule="vision_tower", final_update=False) - self._update_weights_hf_generator(submodule="multi_modal_projector", final_update=False) - self._update_weights_hf_generator(submodule="language_model", final_update=True) - else: - self._update_weights_hf_generator() - - def _update_weights_hf_generator(self, submodule=None, final_update=True): - """Update the model weights.""" - self.endpoints["update_weights"] = "update_weights" - assert self.rollout_device_mesh is not None - - model = self._engine.model - if submodule: - model = getattr(model, submodule) - - DEVICE_MODULE.empty_cache() - - # TODO: force bfloat16 dtype for now - dtype = torch.bfloat16 - - bucket_size = int(self.config.update_weight_bucket_size_in_gb * 1024**3) - same_gen = model._get_same_hf_param( - model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE, bucket_size=bucket_size - ) - fused_gen = model._get_fused_hf_param( - model._group_param_by_load_spec(LoadEnum.FUSED), - dtype=dtype, - device=DEVICE, - bucket_size=bucket_size, - update_weights_for_rl=True, - ) - shard_gen = model._get_shard_hf_param( - model._group_param_by_load_spec(LoadEnum.SHARD), dtype=dtype, device=DEVICE, bucket_size=bucket_size - ) - - for name_list, fused_param_list in fused_gen: - state_dict = {name: param.detach() for name, param in zip(name_list, fused_param_list)} - if model.fsdp_config.ep_size > 1: - # When ep_size > 1, generator generates part of the fused param on each ep rank in one ep_group. 
- # We can all gather them to get full fused param but it would lead to a larger memory usage. - # So we broadcast the part fused param from each ep rank in ep_group sequentially, - # and update the part of the fused param sequentially to reduce memory usage. - if isinstance(model.config, BaseComposeConfig): - ep_mesh: DeviceMesh = model.language_model.ep_mesh - else: - ep_mesh: DeviceMesh = model.ep_mesh - ep_group = ep_mesh.get_group() - global_rank = dist.get_rank() - for src_global_rank in dist.get_process_group_ranks(ep_group): - broadcast_state_dict = dict() - for key, tensor in state_dict.items(): - obj_to_broadcast = [key, tensor.to("meta")] if global_rank == src_global_rank else [None, None] - dist.broadcast_object_list(obj_to_broadcast, src=src_global_rank, group=ep_group) - real_key, meta_tensor = obj_to_broadcast - buffer = ( - state_dict[real_key] - if global_rank == src_global_rank - else torch.empty_like(meta_tensor, device=DEVICE) - ) - dist.broadcast(buffer, src=src_global_rank, group=ep_group) - broadcast_state_dict[real_key] = buffer - self.request_update_params(broadcast_state_dict, finished=False) - del broadcast_state_dict, buffer - else: - self.request_update_params(state_dict, finished=False) - del state_dict, name_list, fused_param_list - - for name_list, param_list in chain(same_gen, shard_gen): - state_dict = {name: param.detach() for name, param in zip(name_list, param_list)} - self.request_update_params(state_dict, finished=False) - del state_dict, name_list, param_list - - if self.rollout_cfg_info["backend"] in ("pytorch", "vllm") and final_update: - self.request_update_params({}, finished=True) - - dist.barrier() - DEVICE_MODULE.empty_cache() - return - - def _update_weights_by_layer(self): - """Update the model weights.""" - self.endpoints["update_weights"] = "update_weights" - assert self.rollout_device_mesh is not None - - model = self._engine.model - DEVICE_MODULE.empty_cache() - - if isinstance(model.config, BaseComposeConfig): - # TODO: support float8 for vision compose model - dtype = torch.bfloat16 - else: - if (model.config.float8_cfg is not None) and (model.config.float8_cfg.enable_float8): - dtype = torch.float8_e4m3fn - else: - dtype = torch.bfloat16 - - def get_params(tensor_list, name_list, save_dtype): - _tensor_list, _spec_list = list(zip(*tensor_list)) - fsdp_unshard_tensor_list = model._fsdp_foreach_allgather(_tensor_list, _spec_list) - if save_dtype == torch.float8_e4m3fn: - fsdp_unshard_tensor_list, name_list = model._to_float8( - fsdp_unshard_tensor_list, name_list, _tensor_list, save_dtype - ) - return fsdp_unshard_tensor_list, name_list - - saved_list = [] - is_qwen3vl = False - if isinstance(model.config, BaseComposeConfig): - language_model = model.language_model - if isinstance(model, Qwen3VLForConditionalGeneration): - is_qwen3vl = True - else: - language_model = model - - if is_qwen3vl: - vision_hf_prefix = "model.visual." - projector_hf_prefix = "model.visual." - else: - vision_hf_prefix = "model.vision_tower." - projector_hf_prefix = "model.multi_modal_projector." 
- - for i, layer in tqdm.tqdm(language_model.layers.items(), desc="[gather weight]"): - tensor_list = [] - name_list = [] - for sub_name, param in layer.state_dict().items(): - if isinstance(model.config, BaseComposeConfig): - saved_list.append(f"language_model.layers.{i}.{sub_name}") - else: - saved_list.append(f"layers.{i}.{sub_name}") - local_tensor = param._local_tensor if isinstance(param, DTensor) else param - local_tensor = local_tensor.bfloat16() - load_spec = language_model.load_spec_mapping.get(f"layers.{i}.{sub_name}") - - if isinstance(model.config, BaseComposeConfig): - name = f"model.language_model.layers.{i}.{sub_name}" - else: - name = f"model.layers.{i}.{sub_name}" - - if ".experts." in name and ".mlp.experts." not in name: - name = name.replace(".experts.", ".mlp.experts.") - if ".gate." in name and ".mlp.gate." not in name: - name = name.replace(".gate.", ".mlp.gate.") - name_list.append(name) - tensor_list.append((local_tensor, load_spec)) - fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) - state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) - self.request_update_params(state_dict) - - for name, param in model.state_dict().items(): - if name in saved_list: - continue - local_tensor = param._local_tensor if isinstance(param, DTensor) else param - local_tensor = local_tensor.bfloat16() - load_spec = model.load_spec_mapping.get(name) - - if isinstance(model.config, BaseComposeConfig): - if "vision_tower." in name: - name = name.replace("vision_tower.", vision_hf_prefix) - elif "multi_modal_projector." in name: - name = name.replace("multi_modal_projector.", projector_hf_prefix) - elif name == "language_model.norm.weight": - name = "model.language_model.norm.weight" - elif name == "language_model.embed_tokens.weight": - name = "model.language_model.embed_tokens.weight" - elif name == "language_model.lm_head.weight": - name = "lm_head.weight" - else: - if name == "norm.weight": - name = "model.norm.weight" - elif name == "embed_tokens.weight": - name = "model.embed_tokens.weight" - tensor_list = [(local_tensor, load_spec)] - name_list = [name] - fsdp_unshard_tensor_list, name_list = get_params(tensor_list, name_list, dtype) - state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) - self.request_update_params(state_dict) - - if self.rollout_cfg_info["backend"] in ("pytorch", "vllm"): - self.request_update_params({}, finished=True) - - dist.barrier() - DEVICE_MODULE.empty_cache() - return - - # def update_weights1(self): - # """Update the model weights.""" - # self.endpoints["update_weights"] = "update_weights" - # assert self.rollout_device_mesh is not None - # time1 = time.time() - - # model = self._engine.model - # DEVICE_MODULE.empty_cache() - - # if (model.config.float8_cfg is not None) and (model.config.float8_cfg.enable_float8): - # dtype = torch.float8_e4m3fn - # else: - # dtype = torch.bfloat16 - - # fused_params = [] - # for name, param in model.state_dict().items(): - # load_spec = model.load_spec_mapping.get(name) - # if load_spec.load_enum == LoadEnum.FUSED: - # fused_params.append((name, param, load_spec)) - - # # TODO: decouple update_weights from the model structure - # bucket_size = 1024**3 - # safetensor_size = 0 - # tensor_list: list[tuple[torch.Tensor, LoadSpec]] = [] - # name_list: list[str] = [] - # for name, param, load_spec in fused_params: - # local_tensor = param._local_tensor if isinstance(param, DTensor) else param - # local_tensor = local_tensor.bfloat16() - # if safetensor_size + 
model._get_tensor_size(param, dtype) > bucket_size: - # _tensor_list, _spec_list = list(zip(*tensor_list)) - # fsdp_unshard_tensor_list = model._fsdp_foreach_allgather(_tensor_list, _spec_list) - # if dtype == torch.float8_e4m3fn: - # fsdp_unshard_tensor_list, name_list = model._to_float8( - # fsdp_unshard_tensor_list, name_list, _tensor_list, dtype - # ) - # state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) - # self.request_update_params(state_dict) - # safetensor_size = 0 - # tensor_list = [(local_tensor, load_spec)] - # name_list = ["model." + name.replace(".experts.", ".mlp.experts.")] - # continue - # safetensor_size += model._get_tensor_size(param, dtype) - # tensor_list.append((local_tensor, load_spec)) - # name_list.append("model." + name.replace(".experts.", ".mlp.experts.")) - - # if tensor_list: - # assert len(name_list) == len(tensor_list) - # _tensor_list, _spec_list = list(zip(*tensor_list)) - # fsdp_unshard_tensor_list = model._fsdp_foreach_allgather(_tensor_list, _spec_list) - # if dtype == torch.float8_e4m3fn: - # fsdp_unshard_tensor_list, name_list = model._to_float8( - # fsdp_unshard_tensor_list, name_list, _tensor_list, dtype - # ) - # state_dict = dict(zip(name_list, fsdp_unshard_tensor_list)) - # self.request_update_params(state_dict) - - # same_gen = model._get_same_hf_param( - # model._group_param_by_load_spec(LoadEnum.SAME), - # dtype=dtype, - # device="cuda", - # bucket_size=1024**3, - # ) - # for name_list, gathered_tensor_list in tqdm.tqdm(same_gen, desc="[update dense weights]"): - # state_dict = dict(zip(name_list, gathered_tensor_list)) - # self.request_update_params(state_dict) - # del state_dict - - # self.request_update_params({}, finished=True) - - # dist.barrier() - # logger.info(f"update weights time: {time.time() - time1}") - # DEVICE_MODULE.empty_cache() - # return - - # def update_weights(self): - # """Update the model weights.""" - # self.endpoints["update_weights"] = "update_weights" - # assert self.rollout_device_mesh is not None - - # model = self._engine.model - # DEVICE_MODULE.empty_cache() - - # saved_keys = [] - # gather_duration = [] - # weight_duration = [] - # reshard_duration = [] - - # # update decoder layers - # for i, layer in tqdm.tqdm(model.layers.items(), desc="[gather weight]"): - # start = time.perf_counter() - # layer.unshard() - # layer_state_dict = {} - - # for sub_name, param in layer.named_parameters(): - # if "_checkpoint_wrapped_module." in sub_name: - # sub_name = sub_name.replace("_checkpoint_wrapped_module.", "") - # if isinstance(param, DTensor): - # param = param.to_local() - - # if isinstance(param, WeightWithDynamicTilewiseFloat8CastTensor): - # param = param._tensor - - # if isinstance(param, Float8Tensor): - # scale_name = f"model.layers.{i}.{sub_name}_scale_inv" - # assert "fused_w1w3" in sub_name or "fused_w2" in sub_name - # # save scale_inv parameter to state_dict - # scale_tensor = param._scale - # quant_tensor = param._data - # ep_mesh = model.ep_mesh - # if ep_mesh.size() > 1: - # scale_tensor = torch.cat(dist.nn.all_gather(scale_tensor, group=ep_mesh.get_group()), dim=0) - # quant_tensor = torch.cat(dist.nn.all_gather(quant_tensor, group=ep_mesh.get_group()), dim=0) - # layer_state_dict[scale_name] = scale_tensor.detach() - # # set `param` which will be added to state_dict at the bottom of the for-block - # param = quant_tensor - - # param = param.to(DEVICE) - # name = f"model.layers.{i}.{sub_name}" - # saved_keys.append(name.replace("model.", "")) - # if ".experts." in name and ".mlp." 
not in name: - # name = name.replace(".experts.", ".mlp.experts.") - # if ".gate." in name and ".mlp." not in name: - # name = name.replace(".gate.", ".mlp.gate.") - # layer_state_dict[name] = param.detach() - # gather_duration.append(time.perf_counter() - start) - # start = time.perf_counter() - # self.request_update_params(layer_state_dict, finished=True) - # breakpoint() - # weight_duration.append(time.perf_counter() - start) - - # start = time.perf_counter() - # del layer_state_dict - # layer.reshard() - # reshard_duration.append(time.perf_counter() - start) - - # if dist.get_rank() == 0: - # logger.debug( - # f"Rank 0 Gather decoder layers done, total {sum(gather_duration):.2f}s, avg " - # f"{sum(gather_duration) / len(gather_duration):.2f}s" - # ) - # logger.debug( - # f"Rank 0 migrate/save decoder layers done, total {sum(weight_duration):.2f}s, avg " - # f"{sum(weight_duration) / len(weight_duration):.2f}s" - # ) - # logger.debug( - # f"Rank 0 reshard decoder layers done, total {sum(reshard_duration):.2f}s, avg " - # f"{sum(reshard_duration) / len(reshard_duration):.2f}s" - # ) - - # # update other params - # model.norm.unshard() - # model.lm_head.unshard() - # model.embed_tokens.unshard() - # others_state_dict = {} - # for name, param in model.named_parameters(): - # if "_checkpoint_wrapped_module." in name: - # continue - # if name not in saved_keys: - # saved_keys.append(name) - # if name == "norm.weight": - # name = "model.norm.weight" - # if name == "embed_tokens.weight": - # name = "model.embed_tokens.weight" - # if isinstance(param, DTensor): - # param = param.to_local() - # others_state_dict[name] = param.detach() - # self.request_update_params(others_state_dict, finished=True) - # model.norm.reshard() - # model.lm_head.reshard() - # model.embed_tokens.reshard() - # del others_state_dict - # del param - - # dist.barrier() - # DEVICE_MODULE.empty_cache() - # return - - @ray_method - def request_update_params(self, state_dict, finished=False): - """Send a request to update the parameters on the rollout workers. - - This method serializes the state dictionary and sends it to the - appropriate rollout worker via an HTTP request. - - Args: - state_dict (dict | list): The state dictionary containing the model - parameters to update. - finished (bool): A flag indicating whether this is the final - batch of updates. Defaults to False. 
- """ - cpu_mesh = self.rollout_device_mesh["engine_parallel"] - cpu_group = cpu_mesh.get_group() - head_rank = cpu_mesh.mesh[0].item() - if self.rollout_url is None: - self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip") - return - - if self.rollout_cfg_info["backend"] == "vllm": - - def serialize_state_dict(state_dict: dict) -> str: - import base64 - from io import BytesIO - from multiprocessing.reduction import ForkingPickler - - from torch.multiprocessing.reductions import reduce_tensor - - data = [(k, reduce_tensor(v)) for k, v in state_dict.items()] - buf = BytesIO() - ForkingPickler(buf).dump(data) - buf.seek(0) - return base64.b64encode(buf.read()).decode("utf-8") - - serialized_data = [None] * self.rollout_cfg_info["tp"] - dist.gather_object( - serialize_state_dict(state_dict), - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - if dist.get_rank() == head_rank: - headers = { - "Content-Type": "application/json", - } - data_ = json.dumps(dict(serialized_named_tensors=serialized_data, finished=finished)) - data = dict(method="update_weight_npu_ipc", args=[data_]) - response = requests.post(f"{self.rollout_url}/collective_rpc", headers=headers, json=data) - assert response.status_code == 200, f"response.status_code = {response.status_code}" - - if finished: - dist.barrier(group=cpu_group) - return - - if self.rollout_cfg_info["backend"] == "pytorch": - # TODO(chenchiyu): remove lmdeploy related code - from lmdeploy.utils import serialize_state_dict - - try: - from lmdeploy.utils import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - - if self.rollout_cfg_info["backend"] == "pytorch" and self.rollout_cfg_info["tp"] > 1: - serialized_data = [None] * self.rollout_cfg_info["tp"] - if use_flattened_tensor_bucket: - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=list(state_dict.items())) - metadata = flattened_tensor_bucket.get_metadata() - flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - "metadata": metadata, - } - tp_serialized_data = serialize_state_dict(flattened_tensor_data) - else: - tp_serialized_data = serialize_state_dict(state_dict) - dist.gather_object( - tp_serialized_data, - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - elif self.rollout_cfg_info["backend"] == "pytorch": - if use_flattened_tensor_bucket: - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=list(state_dict.items())) - metadata = flattened_tensor_bucket.get_metadata() - flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - "metadata": metadata, - } - serialized_data = serialize_state_dict(flattened_tensor_data) - else: - serialized_data = serialize_state_dict(state_dict) - else: - # for turbomind backend, only head_rank should serialize data - serialized_data = serialize_state_dict(state_dict) if dist.get_rank() == head_rank else None - else: - # sglang - from sglang.srt.utils import MultiprocessingSerializer - from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions - - try: - from sglang.srt.model_executor.model_runner import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - - # NOTE: xtuner目前去掉sglang的patch也不会出问题,但为了保险起见,还是保留patch逻辑,并且在update_weights结束后unpatch - monkey_patch_torch_reductions() - 
state_dict = state_dict.items() - if self.rollout_cfg_info["tp"] == 1: - if use_flattened_tensor_bucket: - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_dict) - metadata = flattened_tensor_bucket.get_metadata() - - flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - "metadata": metadata, - } - serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) - else: - serialized_data = MultiprocessingSerializer.serialize(state_dict, output_str=True) - - serialized_data = [serialized_data] - else: - serialized_data = [None] * self.rollout_cfg_info["tp"] - if use_flattened_tensor_bucket: - flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=state_dict) - metadata = flattened_tensor_bucket.get_metadata() - - flattened_tensor_data = { - "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(), - "metadata": metadata, - } - tp_serialized_data = MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True) - dist.gather_object( - tp_serialized_data, - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - else: - tp_serialized_data = MultiprocessingSerializer.serialize(state_dict, output_str=True) - dist.gather_object( - tp_serialized_data, - serialized_data if dist.get_rank() == head_rank else None, - dst=head_rank, - group=cpu_group, - ) - - if dist.get_rank() == head_rank: - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.rollout_cfg_info['api_key']}", - } - if self.rollout_cfg_info["backend"] == "sglang": - payload = { - "serialized_named_tensors": serialized_data, - "flush_cache": False, - } - try: - from sglang.srt.model_executor.model_runner import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - if use_flattened_tensor_bucket: - payload["load_format"] = "flattened_bucket" - - url = f"{self.rollout_url}/update_weights_from_tensor" - response = requests.post(url, json=payload or {}) - response.raise_for_status() - else: - data = dict(serialized_named_tensors=serialized_data, finished=finished) - try: - from lmdeploy.utils import FlattenedTensorBucket - - use_flattened_tensor_bucket = True - except Exception: - use_flattened_tensor_bucket = False - - if use_flattened_tensor_bucket: - data["load_format"] = "flattened_bucket" - response = requests.post( - f"{self.rollout_url}/{self.endpoints['update_weights']}", headers=headers, json=data - ) - assert response.status_code == 200, f"response.status_code = {response.status_code}" - - if finished: - dist.barrier(group=cpu_group) - - monkey_unpatch_torch_reductions() - return - @ray_method def save(self, checkpoint_path: Path | str, no_save_optimizer: bool = False): """Save the DCP checkpoint of the training worker.""" diff --git a/xtuner/v1/rl/utils.py b/xtuner/v1/rl/utils.py deleted file mode 100644 index 8da313a360..0000000000 --- a/xtuner/v1/rl/utils.py +++ /dev/null @@ -1,60 +0,0 @@ -import atexit -import signal -import subprocess - -import torch.nn.functional as F - -from xtuner.v1.utils.logger import get_logger - - -def gather_logprobs(logits, shifted_labels): - logprobs = F.log_softmax(logits, dim=-1) - logprobs = logprobs.gather(dim=-1, index=shifted_labels.clip(min=0).unsqueeze(-1)).squeeze(-1) - return logprobs - - -logger = get_logger() - - -def close_ray(): - """Clean up the ray resource.""" - import ray - - # 1. 
Shutdown ray if initialized - try: - if ray.is_initialized(): - ray.shutdown() - logger.info("Ray shutdown successfully") - except Exception as e: - logger.warning(f"Error during ray.shutdown(): {e}") - - # 2. Stop ray launched by CLI - try: - result = subprocess.run(["ray", "stop", "--force"], capture_output=True, text=True, timeout=10) - if result.returncode != 0: - logger.warning(f"Ray stop failed: {result.stderr}") - except Exception as e: - logger.warning(f"Error stopping ray cluster: {e}") - - -def register_cleanup(): - """Register cleanup handlers for Ray on exit and signals.""" - _cleaned = False - - def cleanup_once(): - nonlocal _cleaned - if not _cleaned: - _cleaned = True - close_ray() - - def signal_handler(signum, frame): - logger.info(f"Received signal {signum}, cleaning up...") - cleanup_once() - import sys - - sys.exit(128 + signum) - - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - atexit.register(cleanup_once) diff --git a/xtuner/v1/rl/utils/__init__.py b/xtuner/v1/rl/utils/__init__.py new file mode 100644 index 0000000000..8e42e94cab --- /dev/null +++ b/xtuner/v1/rl/utils/__init__.py @@ -0,0 +1,81 @@ +from .async_utils import asyncio_run, create_task, handle_task_exception +from .misc import ( + BetweenNode, + BetweenOperator, + ConditionNode, + LogicNode, + LogicOperator, + Operators, + QueryNode, + ScalarNode, + ScalarOperator, + SetNode, + SetOperator, + calculate_seq_staleness, + chat_trace_records_to_rollout_states, + find_free_ports, + gather_logprobs, + get_eos_token, + load_function, + parse_query, + sort_rollout_state_for_deterministic, +) +from .ray_utils import ( + bind_train_rollout, + clear_rollout_response_for_rerun, + close_ray, + find_master_addr_and_port, + free_object_refs, + get_accelerator_ids, + get_ray_accelerator, + register_cleanup, +) +from .ray_worker import ( + AcceleratorResourcesConfig, + AutoAcceleratorWorkers, + AutoCPUWorkers, + BaseCPUWorker, + CPUActorLauncher, + CPUResourcesConfig, + SingleAcceleratorWorker, +) + + +__all__ = [ + "AcceleratorResourcesConfig", + "SingleAcceleratorWorker", + "AutoAcceleratorWorkers", + "CPUResourcesConfig", + "CPUActorLauncher", + "BaseCPUWorker", + "AutoCPUWorkers", + "get_ray_accelerator", + "load_function", + "find_master_addr_and_port", + "get_accelerator_ids", + "free_object_refs", + "clear_rollout_response_for_rerun", + "bind_train_rollout", + "handle_task_exception", + "create_task", + "QueryNode", + "ConditionNode", + "ScalarNode", + "SetNode", + "BetweenNode", + "LogicNode", + "parse_query", + "gather_logprobs", + "close_ray", + "register_cleanup", + "ScalarOperator", + "SetOperator", + "BetweenOperator", + "LogicOperator", + "Operators", + "get_eos_token", + "calculate_seq_staleness", + "chat_trace_records_to_rollout_states", + "sort_rollout_state_for_deterministic", + "find_free_ports", +] diff --git a/xtuner/v1/rl/utils/async_utils.py b/xtuner/v1/rl/utils/async_utils.py new file mode 100644 index 0000000000..ca6b25d4ee --- /dev/null +++ b/xtuner/v1/rl/utils/async_utils.py @@ -0,0 +1,109 @@ +import asyncio +from asyncio import AbstractEventLoop, Task +from typing import Any, Callable, Coroutine, List, Optional + + +_ASYNCIO_RUN_LOOP: AbstractEventLoop | None = None + + +def handle_task_exception(task: Task): + """Handles exceptions from an asyncio Task. + + This function checks if a task has raised an exception and, if so, + re-raises it. It ignores `asyncio.CancelledError`. + + Args: + task (Task): The asyncio task to check for exceptions. 
+
+    Raises:
+        Exception: The exception raised by the task.
+    """
+    try:
+        exc = task.exception()
+        if exc is not None:
+            raise exc
+    except asyncio.CancelledError:
+        pass  # Task was cancelled, ignore
+
+
+def create_task(
+    coro: Coroutine,
+    loop: Optional[AbstractEventLoop] = None,
+    done_callbacks: Optional[List[Callable[[Task], object]]] = None,
+) -> asyncio.tasks.Task:
+    """Creates and configures an asyncio Task.
+
+    This function creates a task from a coroutine and attaches specified
+    done callbacks. By default, it includes a callback to handle exceptions.
+
+    Args:
+        coro (Coroutine): The coroutine to wrap in a task.
+        loop (Optional[AbstractEventLoop], optional): The event loop to run
+            the task in. If None, the current event loop is used.
+            Defaults to None.
+        done_callbacks (Optional[List[Callable[[Task], object]]], optional):
+            A list of callbacks to add to the task. If None, a default
+            exception handler is used. Defaults to None.
+
+    Returns:
+        asyncio.tasks.Task: The created asyncio task.
+    """
+    if loop is None:
+        loop = asyncio.get_event_loop()
+    if done_callbacks is None:
+        done_callbacks = [handle_task_exception]
+    task = loop.create_task(coro)
+    for callback in done_callbacks:
+        task.add_done_callback(callback)
+    return task
+
+
+def _get_default_asyncio_loop() -> AbstractEventLoop:
+    """Get a module-level event loop reused by ``asyncio_run``."""
+    global _ASYNCIO_RUN_LOOP
+    if _ASYNCIO_RUN_LOOP is not None and not _ASYNCIO_RUN_LOOP.is_closed():
+        return _ASYNCIO_RUN_LOOP
+
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        try:
+            loop = asyncio.get_event_loop()
+        except RuntimeError:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+    _ASYNCIO_RUN_LOOP = loop
+    return _ASYNCIO_RUN_LOOP
+
+
+def asyncio_run(coro: Coroutine, loop: Optional[AbstractEventLoop] = None) -> Any:
+    """Synchronously run a coroutine on a shared/explicit event loop.
+
+    This helper is used by `RLColocateTrainer.fit` for rollout collection:
+    1) Trainer runs in sync code and repeatedly calls:
+       - self.eval_agent_loop_manager.produce_batch(...)
+       - self.agent_loop_manager.produce_batch(...)
+    2) `produce_batch` is async, and internally runs `ProduceStrategy.produce_batch`,
+       which launches many nested async tasks (`create_task`) and ultimately calls
+       `AgentLoop.generate_group -> generate_sample`.
+    3) In `VerlToolAgentLoop`, `generate_sample` awaits `self.verl_tool_agent_loop.run()`,
+       where the tool loop stays on the same loop.
+
+    In this pattern, if the sync code called `asyncio.run` on every invocation,
+    each call would create and close a fresh loop, while `VerlToolAgentLoop`
+    keeps its internal work on one loop; the wrapped
+    `generate_sample -> run -> Ray futures` chain can then see mismatched loop
+    ownership and trigger ``Future attached to a different loop``.
+
+    `asyncio_run` keeps calls bound to a stable loop instance so nested task/future
+    chains stay compatible across repeated rollout phases.
+
+    This helper is for sync-to-async boundaries only and should not be used from
+    within an already running event loop.
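+
+    Example (an illustrative sketch; ``rollout_once`` is a stand-in coroutine,
+    not an XTuner API)::
+
+        async def rollout_once() -> int:
+            return 1
+
+        # Repeated sync-side calls reuse one shared loop instead of creating
+        # and closing a fresh loop per call, unlike ``asyncio.run``.
+        result = asyncio_run(rollout_once())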
+    """
+    if loop is None:
+        loop = _get_default_asyncio_loop()
+    if loop.is_running():
+        raise RuntimeError("asyncio_run does not support being called from a running event loop.")
+    return loop.run_until_complete(coro)
diff --git a/xtuner/v1/rl/utils/misc.py b/xtuner/v1/rl/utils/misc.py
new file mode 100644
index 0000000000..7ee9f35cd9
--- /dev/null
+++ b/xtuner/v1/rl/utils/misc.py
@@ -0,0 +1,338 @@
+import importlib
+import json
+import random
+import socket
+import typing
+from abc import ABC
+from copy import deepcopy
+from dataclasses import asdict, is_dataclass
+from pathlib import Path
+from typing import Any, List, Literal, Union
+
+import torch.nn.functional as F
+
+from xtuner.v1.data_proto.rl_data import RolloutState, Status
+from xtuner.v1.data_proto.utils import calculate_seq_staleness as calculate_seq_staleness
+from xtuner.v1.utils.logger import get_logger
+
+
+logger = get_logger()
+ScalarOperator = Literal["$eq", "$ne", "$gt", "$gte", "$lt", "$lte"]
+SetOperator = Literal["$in", "$not_in"]
+BetweenOperator = Literal["$between"]
+Operators = Union[ScalarOperator, SetOperator, BetweenOperator]
+LogicOperator = Literal["$and", "$or"]
+
+
+class QueryNode(ABC):
+    """Base class of the query syntax tree; serves only as a structural marker."""
+
+    pass
+
+
+class ConditionNode(QueryNode):
+    """Represents a single concrete query condition."""
+
+    field: str
+
+
+class ScalarNode(ConditionNode):
+    def __init__(self, field: str, op: ScalarOperator, value: Any):
+        self.field = field
+        self.op = op
+        self.value = value
+
+
+class SetNode(ConditionNode):
+    def __init__(self, field: str, op: SetOperator, value: list[Any] | tuple[Any, ...]):
+        self.field = field
+        self.op = op
+        self.value = value
+
+
+class BetweenNode(ConditionNode):
+    def __init__(self, field: str, lower: Any, upper: Any):
+        if lower > upper:
+            raise ValueError("lower bound must be less than or equal to upper bound")
+        self.field = field
+        self.op = "$between"
+        self.lower = lower
+        self.upper = upper
+
+
+class LogicNode(QueryNode):
+    """A composite logic group combining child conditions with ``$and``/``$or``."""
+
+    def __init__(self, relation: LogicOperator, conditions: List[QueryNode]):
+        self.relation = relation
+        self.conditions = conditions
+
+
+def parse_query(expr: Union[dict, QueryNode]) -> QueryNode:
+    """Parse the dict-based DSL into a pure AST of nodes (ConditionNode, LogicNode)."""
+    if isinstance(expr, QueryNode):
+        return expr
+
+    if isinstance(expr, dict):
+        conditions: list[QueryNode] = []
+        for key, value in expr.items():
+            if key in ("$and", "$or"):
+                if isinstance(value, list):
+                    sub_asts = [parse_query(sub_expr) for sub_expr in value]
+                    conditions.append(LogicNode(key, sub_asts))  # type: ignore
+                else:
+                    raise ValueError(f"The value of logic operator {key} must be a list")
+            else:
+                if isinstance(value, dict):
+                    # e.g. {"staleness": {"$lt": 5, "$gt": 0}}
+                    for op, op_val in value.items():
+                        if op in typing.get_args(ScalarOperator):
+                            conditions.append(ScalarNode(field=key, op=op, value=op_val))
+                        elif op in typing.get_args(SetOperator):
+                            if not isinstance(op_val, (list, tuple)):
+                                raise ValueError(f"Operator '{op}' requires a list or tuple")
+                            conditions.append(SetNode(field=key, op=op, value=op_val))
+                        elif op == "$between":
+                            if not isinstance(op_val, (list, tuple)) or len(op_val) != 2:
+                                raise ValueError("Operator '$between' requires a list or tuple with exactly 2 elements")
+                            conditions.append(BetweenNode(field=key, lower=op_val[0], upper=op_val[1]))
+                        else:
+                            raise ValueError(f"Unsupported operator: {op}")
+                else:
+                    # Implicit equality, e.g. {"task_name": "math"} -> "$eq"
+                    conditions.append(ScalarNode(field=key, op="$eq", value=value))
+
+        if len(conditions) > 1:
+            # Multiple conditions default to an AND relation, e.g. {"uid": "123", "status": {"$in": ["pending", "running"]}}
+            return LogicNode("$and", conditions)  # type: ignore
+        return conditions[0] if conditions else LogicNode("$and", [])
+
+    raise ValueError(f"Unsupported query expression format: {expr}")
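+
+# Illustrative sketch of the dict DSL accepted by ``parse_query`` (the field
+# names below are sample data, not a fixed schema):
+#
+#     ast = parse_query({"task_name": "math", "staleness": {"$between": [0, 4]}})
+#     assert isinstance(ast, LogicNode) and ast.relation == "$and"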
+
+
+def gather_logprobs(logits, shifted_labels):
+    logprobs = F.log_softmax(logits, dim=-1)
+    logprobs = logprobs.gather(dim=-1, index=shifted_labels.clip(min=0).unsqueeze(-1)).squeeze(-1)
+    return logprobs
+
+
+def sort_rollout_state_for_deterministic(data_groups: list[list[RolloutState]]) -> list[list[RolloutState]]:
+    def sort_key(sample: RolloutState) -> tuple[int, int]:
+        return (sample.message_uid or 0, sample.uid or 0)
+
+    sorted_groups = [sorted(group, key=sort_key) for group in data_groups]
+    sorted_groups.sort(key=lambda group: min((sort_key(item) for item in group), default=(-1, -1)))
+    return sorted_groups
+
+
+def load_function(path):
+    """Load a function from a module.
+
+    :param path: The path to the function, e.g. "module.submodule.function".
+    :return: The function object.
+    """
+    module_path, _, attr = path.rpartition(".")
+    module = importlib.import_module(module_path)
+    return getattr(module, attr)
+
+
+def find_free_ports(
+    *,
+    nums: int = 1,
+    host: str = "127.0.0.1",
+    start_port: int | None = None,
+    end_port: int | None = None,
+    contiguous: bool = False,
+) -> list[int]:
+    """Return available TCP ports on the given host.
+
+    The candidate sockets are kept open until all requested ports are found so
+    one call cannot return duplicate ports. Set ``contiguous=True`` to require
+    the returned ports to be a continuous range.
+    """
+    if nums < 1:
+        raise ValueError("nums must be greater than 0.")
+    if start_port is not None:
+        if end_port is None:
+            raise ValueError("end_port must be set when start_port is set.")
+        if end_port - start_port < nums:
+            raise ValueError("The port range must contain at least nums ports.")
+
+    def try_bind_ports(candidate_ports: list[int]) -> list[int] | None:
+        ports: list[int] = []
+        sockets: list[socket.socket] = []
+        try:
+            for candidate_port in candidate_ports:
+                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                try:
+                    sock.bind((host, candidate_port))
+                    sock.listen(1)
+                except OSError:
+                    sock.close()
+                    return None
+
+                sockets.append(sock)
+                ports.append(int(sock.getsockname()[1]))
+            return ports
+        finally:
+            for sock in sockets:
+                sock.close()
+
+    if contiguous:
+        if start_port is None:
+            for _ in range(100):
+                candidate = random.randint(20000, 60000 - nums)
+                bound_ports = try_bind_ports(list(range(candidate, candidate + nums)))
+                if bound_ports is not None:
+                    return bound_ports
+        else:
+            assert end_port is not None
+            for candidate in range(start_port, end_port - nums + 1):
+                bound_ports = try_bind_ports(list(range(candidate, candidate + nums)))
+                if bound_ports is not None:
+                    return bound_ports
+    else:
+        available_ports: list[int] = []
+        sockets: list[socket.socket] = []
+        try:
+            if start_port is None:
+                candidates: range | list[int] = [0] * nums
+            else:
+                assert end_port is not None
+                candidates = range(start_port, end_port)
+
+            for candidate_port in candidates:
+                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                try:
+                    sock.bind((host, candidate_port))
+                    sock.listen(1)
+                except OSError:
+                    sock.close()
+                    continue
+
+                sockets.append(sock)
+                available_ports.append(int(sock.getsockname()[1]))
+                if len(available_ports) >= nums:
+                    return available_ports
+        finally:
+            for sock in sockets:
+                sock.close()
+
+    if start_port is None:
+        raise RuntimeError(f"Could not find
{nums} available ports.") + raise RuntimeError(f"Could not find {nums} available ports from {start_port} to {end_port}.") + + +def get_eos_token(model_path: str) -> int | List[int]: + generation_config_path = Path(model_path) / "generation_config.json" + if not generation_config_path.exists(): + logger.warning( + f"Config {generation_config_path} does not exist and thus cannot get eos_token. You must provide eos_token manually." + ) + return [] + with open(generation_config_path) as f: + generation_config = json.load(f) + eos_token_id = generation_config.get("eos_token_id") + if eos_token_id is None: + raise ValueError( + f"eos_token_id is not found in {generation_config_path}. You must provide eos_token manually." + ) + return eos_token_id + + +def chat_trace_records_to_rollout_states( + rollout_state: RolloutState, + records: list[Any], + *, + tokenizer: Any | None = None, + extra_fields: dict[str, Any] | None = None, +) -> list[RolloutState]: + """Convert Gateway chat trace records into trainable rollout states. + + The records may be ``ChatTraceRecord`` dataclass instances or serialized + dictionaries returned by ``/trace_store``. + """ + normalized_records = [] + for record in records: + if isinstance(record, dict): + normalized_records.append(record) + elif not isinstance(record, type) and is_dataclass(record): + normalized_records.append(asdict(record)) + elif hasattr(record, "__dict__"): + normalized_records.append(dict(record.__dict__)) + else: + raise TypeError(f"Unsupported chat trace record type: {type(record)}") + + trace_count = len(normalized_records) + trace_summary = [ + { + "request_id": record.get("request_id"), + "finish_reason": record.get("finish_reason"), + "status": record.get("status"), + "prompt_ids": record.get("prompt_ids", []), + "response_ids": record.get("response_ids", []), + } + for record in normalized_records + ] + + states: list[RolloutState] = [] + for index, record in enumerate(normalized_records): + prompt_ids = record.get("prompt_ids") + response_ids = record.get("response_ids") + if not prompt_ids or not response_ids: + raise RuntimeError(f"Gateway trace record {index} is missing prompt_ids or response_ids.") + + logprobs = record.get("logprobs") + if not isinstance(logprobs, list) or len(logprobs) != len(response_ids): + logprobs = None + + status_value = record.get("status") + if isinstance(status_value, Status): + status = status_value + elif isinstance(status_value, str): + try: + status = Status(status_value) + except ValueError: + status = Status.FAILED + else: + status = Status.FAILED + + request_id = record.get("request_id") + try: + uid = int(request_id) if request_id is not None else None + except (TypeError, ValueError): + uid = None + + response = record.get("output_text") + if response is None and tokenizer is not None: + try: + response = tokenizer.decode(response_ids) + except Exception: + response = None + + normalized = rollout_state.model_copy(deep=True) + normalized.uid = uid + normalized.prompt_ids = list(prompt_ids) + normalized.tokens = list(prompt_ids) + normalized.response_ids = list(response_ids) + normalized.response_mask = [1] * len(response_ids) + normalized.logprobs = logprobs + normalized.response = response + normalized.finish_reason = record.get("finish_reason") + normalized.status = status + normalized.error_msg = None if status == Status.COMPLETED else f"Gateway trace status={status.value}" + normalized.reward = None + normalized.extra_fields = { + **deepcopy(rollout_state.extra_fields), + "gateway_trace_index": 
index,
+            "gateway_trace_count": trace_count,
+            "gateway_trace_records": deepcopy(trace_summary),
+            "gateway_request_id": record.get("request_id"),
+            "gateway_request_snapshot": record.get("request_snapshot"),
+            "gateway_response_snapshot": record.get("response_snapshot"),
+            **deepcopy(extra_fields or {}),
+        }
+        states.append(normalized)
+    return states
diff --git a/xtuner/v1/rl/utils/ray_utils.py b/xtuner/v1/rl/utils/ray_utils.py
new file mode 100644
index 0000000000..98f85708d3
--- /dev/null
+++ b/xtuner/v1/rl/utils/ray_utils.py
@@ -0,0 +1,181 @@
+import atexit
+import signal
+import subprocess
+from typing import TYPE_CHECKING, Optional, cast
+
+import ray
+from ray import ObjectRef
+
+from xtuner.v1.utils.logger import get_logger
+
+from .misc import find_free_ports
+
+
+if TYPE_CHECKING:
+    from xtuner.v1.data_proto.rl_data import RolloutState
+
+    from .ray_worker import AcceleratorType
+
+
+logger = get_logger()
+
+
+@ray.remote
+def find_master_addr_and_port(
+    nums: int = 1, start_port: Optional[int] = None, end_port: Optional[int] = None
+) -> tuple[str, int] | tuple[str, list[int]]:
+    """Finds an available master address and a specified number of ports.
+
+    This remote function gets the node's IP address and binds to one or more
+    available ports, which can be used for distributed communication.
+
+    Args:
+        nums (int): The number of ports to find. Defaults to 1.
+        start_port (Optional[int]): The starting port to search from.
+            If None, random available ports will be used. Defaults to None.
+        end_port (Optional[int]): The ending port to search to (exclusive).
+            If start_port is None, this parameter is ignored. Defaults to None.
+
+    Returns:
+        A tuple containing the address and a single port if `nums` is 1,
+        or a list of ports if `nums` is greater than 1.
+    """
+    addr = ray.util.get_node_ip_address()
+    ports = find_free_ports(nums=nums, host="", start_port=start_port, end_port=end_port)
+
+    if len(ports) == 1:
+        return addr, ports[0]
+    else:
+        return addr, ports
+
+
+@ray.remote
+def get_accelerator_ids(accelerator: str) -> list:
+    """Get the IDs of the available accelerators (GPUs, NPUs, etc.) in the Ray
+    cluster."""
+    return ray.get_runtime_context().get_accelerator_ids()[accelerator]
+
+
+def get_ray_accelerator() -> "AcceleratorType":
+    """Get the type of accelerator available in the Ray environment.
+
+    This function checks for the availability of CUDA and NPU devices and
+    returns the corresponding accelerator type.
+
+    Returns:
+        AcceleratorType: The type of accelerator ("GPU" or "NPU").
+
+    Raises:
+        NotImplementedError: If neither CUDA nor NPU is available.
+    """
+    from xtuner.v1.utils.device import get_device
+
+    accelerator = None
+    if get_device() == "cuda":
+        accelerator = "GPU"
+    else:
+        try:
+            import torch_npu  # noqa: F401
+
+            accelerator = "NPU"
+        except ImportError:
+            pass
+
+    if accelerator is None:
+        raise NotImplementedError(
+            "Supports only CUDA or NPU. If your device is CUDA or NPU, "
+            "please make sure that your environment settings are "
+            "configured correctly."
+ ) + + return cast("AcceleratorType", accelerator) + + +def free_object_refs(refs: list[ObjectRef]) -> None: + valid_refs = [ref for ref in refs if isinstance(ref, ObjectRef)] + if not valid_refs: + return + + try: + ray._private.internal_api.free(valid_refs, local_only=False) + except Exception: + ray.internal.free(valid_refs, local_only=False) + + +def clear_rollout_response_for_rerun(rollout_state: "RolloutState") -> "RolloutState": + routed_experts = getattr(rollout_state, "routed_experts", None) + if isinstance(routed_experts, ObjectRef): + free_object_refs([routed_experts]) + rollout_state.tokens = getattr(rollout_state, "prompt_ids", None) + rollout_state.response = None + rollout_state.response_ids = [] + rollout_state.logprobs = [] + rollout_state.routed_experts = None + rollout_state.finish_reason = None + rollout_state.response_mask = [] + rollout_state.response_model_steps = [] + rollout_state.reward = None + rollout_state.error_msg = None + return rollout_state + + +def close_ray(): + """Clean up the ray resource.""" + # 1. Shutdown ray if initialized + try: + if ray.is_initialized(): + ray.shutdown() + logger.info("Ray shutdown successfully") + except Exception as e: + logger.warning(f"Error during ray.shutdown(): {e}") + + # 2. Stop ray launched by CLI + try: + result = subprocess.run(["ray", "stop", "--force"], capture_output=True, text=True, timeout=10) + if result.returncode != 0: + logger.warning(f"Ray stop failed: {result.stderr}") + except Exception as e: + logger.warning(f"Error stopping ray cluster: {e}") + + +def register_cleanup(): + """Register cleanup handlers for Ray on exit and signals.""" + _cleaned = False + + def cleanup_once(): + nonlocal _cleaned + if not _cleaned: + _cleaned = True + close_ray() + + def signal_handler(signum, frame): + logger.info(f"Received signal {signum}, cleaning up...") + cleanup_once() + import sys + + sys.exit(128 + signum) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + atexit.register(cleanup_once) + + +def bind_train_rollout( + train_workers, + rollout_controller, +) -> None: + """Bind the training and rollout workers for updating weights. + + This function retrieves rollout information from the rollout controller + and distributes it to the training workers, enabling them to update the + rollout models' weights. + + Args: + train_workers: A list of training worker actors. + rollout_controller: The rollout controller actor. 
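+
+    Example (an illustrative sketch; both handles come from your own setup)::
+
+        bind_train_rollout(train_workers, rollout_controller)
+        # Afterwards each training worker can push updated weights to its
+        # bound rollout server via ``update_weights``.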
+ """ + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) # type: ignore[attr-defined] + ray.get([worker.update_rollout_info.remote(**info_dict) for worker in train_workers]) # type: ignore[attr-defined] + return diff --git a/xtuner/v1/ray/base/accelerator.py b/xtuner/v1/rl/utils/ray_worker.py similarity index 56% rename from xtuner/v1/ray/base/accelerator.py rename to xtuner/v1/rl/utils/ray_worker.py index cf05509525..859d6cc649 100644 --- a/xtuner/v1/ray/base/accelerator.py +++ b/xtuner/v1/rl/utils/ray_worker.py @@ -1,4 +1,5 @@ import os +import threading from typing import Any, Dict, List, Literal, Tuple, TypeVar import ray @@ -13,9 +14,10 @@ placement_group, placement_group_table, ) +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from typing_extensions import Annotated -from ..utils import find_master_addr_and_port, get_accelerator_ids +from .ray_utils import find_master_addr_and_port, get_accelerator_ids PG_READY_TIMEOUT = os.getenv("XTUNER_PG_READY_TIMEOUT", 30) # default 30 seconds @@ -23,6 +25,84 @@ T = TypeVar("T") +class CPUResourcesConfig(BaseModel): + """Configuration for CPU resources in a placement group for XTuner. + + This class provide specific configuration options for CPU-based workers in Ray placement groups. + + Args: + num_cpus_per_worker (float): Number of CPUs to allocate per worker in the + placement group. Defaults to 8. + cpu_memory_per_worker (int): Amount of CPU memory (in bytes) to allocate + for each worker in the placement group. + num_workers (int): Total number of workers in the placement group. + """ + + model_config = ConfigDict(extra="forbid") + num_workers: Annotated[int, Parameter(help="Number of workers in the placement group.")] = 1 + num_cpus_per_worker: Annotated[float, Parameter(help="Number of CPUs to allocate for the placement group.")] = 1 + cpu_memory_per_worker: Annotated[ + int, Parameter(help="Amount of memory (in bytes) to allocate for the placement group.") + ] = 1024**3 # 1 GB + pg_pack_strategy: Annotated[ + str, + Parameter(help="Placement group packing strategy, options: " + ", ".join(VALID_PLACEMENT_GROUP_STRATEGIES)), + ] = "SPREAD" + + @field_validator("pg_pack_strategy") + @classmethod + def check_pg_pack_strategy(cls, v): + if v not in VALID_PLACEMENT_GROUP_STRATEGIES: + raise ValueError(f"pg_pack_strategy must be one of {VALID_PLACEMENT_GROUP_STRATEGIES}") + return v + + def model_post_init(self, __context: Any) -> None: + assert ray.is_initialized(), "Ray must be initialized before creating CPUResourcesConfig." + available_resources = ray.available_resources() + available_cpus = available_resources.get("CPU", 0) + available_memory = available_resources.get("memory", 0) + # TODO: manage single controller's cpu resource to replace "10" here + needed_cpus = (self.num_cpus_per_worker * self.num_workers) + 10 + assert needed_cpus <= available_cpus, ( + f"Not enough available CPUs in Ray cluster, available_cpus is {available_cpus} but xtuner needs {needed_cpus}." + ) + needed_memory = self.cpu_memory_per_worker * self.num_workers + 10 * 1024**3 + assert needed_memory <= available_memory, ( + f"Not enough available memory in Ray cluster, available_memory is {available_memory} but xtuner needs {needed_memory}." 
+        )
+        # TODO: check all resources sum in cluster to avoid over allocation
+
+    @classmethod
+    def from_total(
+        cls, total_cpus: float | int, total_memory: int, num_workers: int, pg_pack_strategy: str = "SPREAD"
+    ):
+        """Create a CPUResourcesConfig from total CPU and memory resources.
+
+        Args:
+            total_cpus (float | int): Total number of CPUs to allocate across all workers.
+            total_memory (int): Total amount of memory (in bytes) to allocate across all workers.
+            num_workers (int): Number of workers in the placement group.
+            pg_pack_strategy (str): Placement group packing strategy. Defaults to "SPREAD".
+
+        Returns:
+            CPUResourcesConfig: The created CPUResourcesConfig object.
+        """
+        assert num_workers > 0, "Number of workers must be positive."
+        return cls(
+            num_workers=num_workers,
+            num_cpus_per_worker=total_cpus / num_workers,
+            cpu_memory_per_worker=total_memory // num_workers,
+            pg_pack_strategy=pg_pack_strategy,
+        )
+
+    def build_placement_group(self) -> PlacementGroup:
+        """Build a Ray PlacementGroup based on this resource configuration.
+
+        Returns:
+            PlacementGroup: The created Ray PlacementGroup.
+        """
+        return CPUActorLauncher.build_placement_group(self)
+
+
 class AcceleratorResourcesConfig(BaseModel):
     """Configuration for accelerator resources in a placement group for XTuner.
@@ -235,6 +315,7 @@ def setup_distributed(self, rank: int, master_addr: str, master_port: int, world
         os.environ["RANK"] = str(rank)
         os.environ["WORLD_SIZE"] = str(world_size)
         os.environ["LOCAL_RANK"] = str(self.get_logical_local_rank())
+
         # The `backend` argument specifies the communication backend; it is not read from environment variables.
         # - 'nccl': communication between NVIDIA GPUs (recommended for GPU)
         # - 'gloo': CPU communication, or cross-platform
@@ -462,3 +543,291 @@ def from_placement_group(
             rank_bundle_idx_list.append((rank, bundle_idx))

     return workers_list, rank_bundle_idx_list
+
+
+class BaseCPUWorker:
+    """The BaseCPUWorker class serves as a foundational structure for CPU-based
+    workers within the XTuner framework.
+
+    This class is designed to be extended by specific CPU worker implementations.
+    It provides a constructor that accepts a configuration object, allowing
+    subclasses to initialize with custom settings.
+
+    Args:
+        config: The configuration object for the CPU worker.
+        num_cpus (float | int): The number of CPUs allocated to this worker.
+            Defaults to 1.
+    """
+
+    def __init__(self, config, num_cpus: float | int = 1):
+        self.config = config
+        self.num_cpus = num_cpus
+
+
+class CPUActorLauncher:
+    """Infrastructure for launching CPU Ray actors from plain Python classes.
+
+    This class owns the generic actorization flow for CPU-only components:
+    building homogeneous CPU placement groups, converting plain classes into
+    Ray actor classes, validating bundle resources, and launching one or more
+    actors on specific bundles.
+    """
+
+    _ACTOR_CLASS_CACHE: dict[type, ActorClass] = {}
+
+    @staticmethod
+    def build_placement_group(resources_config: CPUResourcesConfig):
+        """Build a Ray PlacementGroup based on the provided resource
+        configuration.
+
+        Args:
+            resources_config (CPUResourcesConfig): The configuration
+                specifying the resources for each worker bundle.
+
+        Returns:
+            PlacementGroup: The created Ray PlacementGroup.
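+
+        Example (an illustrative sketch; assumes Ray is already initialized)::
+
+            cfg = CPUResourcesConfig(num_workers=2, num_cpus_per_worker=2)
+            pg = CPUActorLauncher.build_placement_group(cfg)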
+ """ + bundles = [ + { + "CPU": resources_config.num_cpus_per_worker, + "memory": resources_config.cpu_memory_per_worker, + } + ] * resources_config.num_workers + + pg = placement_group(bundles=bundles, strategy=resources_config.pg_pack_strategy) + + ray.get(pg.ready(), timeout=PG_READY_TIMEOUT) + return pg + + @staticmethod + def get_pg_options(pg: PlacementGroup, num_cpus: int | float = -1) -> Dict: + """Provide a dictionary of resource requests for Ray tasks or actors + with specific cpu requirements. + + Args: + pg (PlacementGroup): The placement group to get options for. + num_cpus (float): The number of CPUs to request. If set to -1, + the default CPU allocation from the placement group bundle + will be used. Defaults to -1. + + Returns: + Dict: A dictionary of Ray resource options for `task.options()`. + """ + assert len(pg.bundle_specs) > 0, "Placement group has no bundles defined." + default_cpu = pg.bundle_specs[0].get("CPU", 1) + return {"num_cpus": num_cpus if num_cpus >= 0 else default_cpu} + + @classmethod + def to_actor_class(cls, worker_cls): + """Convert a plain Python class into a Ray actor class. + + If ``worker_cls`` is already a Ray actor class, it is returned as-is. + """ + if hasattr(worker_cls, "remote") and hasattr(worker_cls, "options"): + return worker_cls + + if worker_cls not in cls._ACTOR_CLASS_CACHE: + cls._ACTOR_CLASS_CACHE[worker_cls] = ray.remote(worker_cls) + return cls._ACTOR_CLASS_CACHE[worker_cls] + + @staticmethod + def _get_bundle_resources(pg: PlacementGroup, bundle_idx: int) -> dict[str, float | int]: + assert len(pg.bundle_specs) > bundle_idx, f"Placement group does not have bundle index {bundle_idx}." + return pg.bundle_specs[bundle_idx] + + @classmethod + def _resolve_actor_resources( + cls, + pg: PlacementGroup, + bundle_idx: int, + actor_num_cpus: int | float | None = None, + actor_memory: int | None = None, + ) -> tuple[float | int, int]: + bundle = cls._get_bundle_resources(pg, bundle_idx) + resolved_num_cpus = actor_num_cpus if actor_num_cpus is not None else bundle.get("CPU", 1) + resolved_memory = actor_memory if actor_memory is not None else int(bundle.get("memory", 0)) + assert bundle.get("CPU", 1) >= resolved_num_cpus, ( + f"Placement group bundle {bundle_idx} does not have enough CPU resources." + ) + assert bundle.get("memory", 0) >= resolved_memory, ( + f"Placement group bundle {bundle_idx} does not have enough memory resources." 
+ ) + return resolved_num_cpus, resolved_memory + + @classmethod + def build_actor( + cls, + worker_cls, + *init_args, + pg: PlacementGroup | None = None, + bundle_idx: int = 0, + actor_num_cpus: int | float | None = None, + actor_memory: int | None = None, + capture_child_tasks: bool = False, + **init_kwargs, + ): + """Build a single CPU actor from a plain class or Ray actor class.""" + resolved_num_cpus = 1 if actor_num_cpus is None else actor_num_cpus + resolved_memory = actor_memory + + actor_cls = cls.to_actor_class(worker_cls) + actor_options = { + "num_cpus": resolved_num_cpus, + } + if resolved_memory is not None and resolved_memory > 0: + actor_options["memory"] = resolved_memory + + if pg is None: + return actor_cls.options(**actor_options).remote(*init_args, **init_kwargs) + + resolved_num_cpus, resolved_memory = cls._resolve_actor_resources( + pg=pg, + bundle_idx=bundle_idx, + actor_num_cpus=actor_num_cpus, + actor_memory=actor_memory, + ) + actor_options["num_cpus"] = resolved_num_cpus + if resolved_memory > 0: + actor_options["memory"] = resolved_memory + actor_options["scheduling_strategy"] = PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=bundle_idx, + placement_group_capture_child_tasks=capture_child_tasks, + ) + return actor_cls.options(**actor_options).remote(*init_args, **init_kwargs) + + @classmethod + def build_actors( + cls, + worker_cls, + *init_args, + pg: PlacementGroup | None = None, + start_bundle_idx: int = 0, + num_workers: int = 1, + actor_num_cpus_per_worker: int | float | None = None, + actor_memory_per_worker: int | None = None, + capture_child_tasks: bool = False, + **init_kwargs, + ): + """Build multiple homogeneous CPU actors from a plain class or Ray + actor class.""" + workers_list = [] + for idx in range(num_workers): + workers_list.append( + cls.build_actor( + worker_cls, + *init_args, + pg=pg, + bundle_idx=start_bundle_idx + idx, + actor_num_cpus=actor_num_cpus_per_worker, + actor_memory=actor_memory_per_worker, + capture_child_tasks=capture_child_tasks, + **init_kwargs, + ) + ) + return workers_list + + +class AutoCPUWorkers(CPUActorLauncher): + """Convenience wrapper for BaseCPUWorker-style homogeneous worker pools. + + `CPUActorLauncher` is the generic actorization layer. `AutoCPUWorkers` + keeps the legacy worker-centric API that instantiates one worker per bundle + using the conventional `(worker_config, num_cpus=...)` constructor shape. + """ + + _PG_NEXT_BUNDLE_INDEX: dict[str, int] = {} + _PG_NEXT_BUNDLE_INDEX_LOCK = threading.Lock() + + @staticmethod + def _get_pg_key(pg: PlacementGroup) -> str: + """Build a stable placement-group identifier for local bundle + tracking.""" + return str(pg.id) + + @classmethod + def _reserve_bundle_range( + cls, + pg: PlacementGroup, + num_workers: int, + start_bundle_idx: int | None, + ) -> tuple[int, int]: + """Reserve a contiguous bundle range for worker creation. + + When ``start_bundle_idx`` is omitted, the next unconsumed bundle range + in this process is used. Explicit bundle reservations still advance the + local cursor so later auto-allocation does not reuse the same bundles. 
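+
+        For example (hypothetical numbers): with an 8-bundle placement group
+        and a fresh cursor, requesting 4 workers returns ``(0, 4)`` and
+        advances the cursor to 4, so the next auto-allocated request for 4
+        workers returns ``(4, 4)``.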
+ """ + pg_key = cls._get_pg_key(pg) + + with cls._PG_NEXT_BUNDLE_INDEX_LOCK: + current_cursor = cls._PG_NEXT_BUNDLE_INDEX.get(pg_key, 0) + resolved_start_bundle_idx = current_cursor if start_bundle_idx is None else start_bundle_idx + resolved_num_workers = num_workers if num_workers > 0 else pg.bundle_count - resolved_start_bundle_idx + + assert resolved_num_workers > 0, "At least one worker must be created from the placement group." + assert resolved_start_bundle_idx >= 0, "start_bundle_idx must be non-negative." + assert resolved_start_bundle_idx + resolved_num_workers <= pg.bundle_count, ( + "Placement group does not have enough remaining bundles for the requested CPU workers." + ) + + cls._PG_NEXT_BUNDLE_INDEX[pg_key] = max(current_cursor, resolved_start_bundle_idx + resolved_num_workers) + + return resolved_start_bundle_idx, resolved_num_workers + + @classmethod + def from_config(cls, worker_cls, worker_config, cpu_config: CPUResourcesConfig): + """Create workers and a placement group from configuration objects. + + Args: + worker_cls: The class of the worker to instantiate. + worker_config: The configuration for each worker instance. + cpu_config (CPUResourcesConfig): The configuration + for the cpu resources. + + Returns: + List[T]: List of created worker instances. + """ + pg = cls.build_placement_group(cpu_config) + workers_list = cls.from_placement_group(worker_cls, worker_config, pg) + + return workers_list, pg + + @classmethod + def from_placement_group( + cls, + worker_cls, + worker_config, + pg: PlacementGroup, + num_workers: int = -1, + start_bundle_idx: int | None = None, + ): + """Create workers from an existing placement group. + + Args: + worker_cls: The class of the worker to instantiate. + worker_config: The configuration for each worker instance. + pg (PlacementGroup): The existing placement group to use. + num_workers (int): The number of workers to create. Defaults to -1, + the remaining bundles in the placement group will be used. + start_bundle_idx (int | None): Bundle index to start from. If + omitted, the next unconsumed local bundle range for this + placement group will be used. + + Returns: + List[T]: List of created worker instances. 
+ """ + start_bundle_idx, num_workers = cls._reserve_bundle_range( + pg=pg, num_workers=num_workers, start_bundle_idx=start_bundle_idx + ) + default_cpu = cls._get_bundle_resources(pg, start_bundle_idx).get("CPU", 1) + return cls.build_actors( + worker_cls, + worker_config, + num_cpus=default_cpu, + pg=pg, + start_bundle_idx=start_bundle_idx, + num_workers=num_workers, + actor_num_cpus_per_worker=default_cpu, + actor_memory_per_worker=None, + ) diff --git a/xtuner/v1/train/cli/rl.py b/xtuner/v1/train/cli/rl.py index 0a91ee1edb..444c6aa0fb 100644 --- a/xtuner/v1/train/cli/rl.py +++ b/xtuner/v1/train/cli/rl.py @@ -10,7 +10,6 @@ from cyclopts.group import Group from xtuner.v1.rl.utils import register_cleanup -from xtuner.v1.train.rl_trainer import RLTrainer from xtuner.v1.utils import Config from xtuner.v1.utils.track_rl_mem import monitor_actor_memory @@ -56,7 +55,7 @@ def main( track_thread.start() trainer_cfg = Config.fromfile(config)["trainer"] - trainer = RLTrainer.from_config(trainer_cfg) + trainer = trainer_cfg.build() trainer.fit() if dist.is_initialized(): diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index eff55b7b3c..998e359f9b 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -1,832 +1,674 @@ import json import os import random -from datetime import datetime from pathlib import Path from shutil import rmtree -from typing import List, cast +from typing import Any, List, cast import ray import torch -from mmengine import load from mmengine.dist import get_rank from mmengine.runner import set_random_seed -from pydantic import BaseModel, ConfigDict, model_validator -from ray.util.placement_group import placement_group -from typing_extensions import Literal, Self, TypedDict +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing_extensions import Literal, TypedDict from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from xtuner.v1._writer import get_writer -from xtuner.v1.data_proto.rl_data import MultimodalTrainInfo, RLDataFlowItem, is_valid_for_training +from xtuner.v1.data_proto.rl_data import RolloutState, Status from xtuner.v1.data_proto.sequence_context import SequenceContext from xtuner.v1.patch import patch_default_save_plan -from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers, CPUResourcesConfig -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, DataFlowProxy, ReplayBufferConfig -from xtuner.v1.ray.environment import SingleTurnEnvironment, SingleTurnEnvironmentProxy -from xtuner.v1.ray.evaluator import Evaluator, EvaluatorConfig -from xtuner.v1.ray.judger import JudgerConfig -from xtuner.v1.rl.base import ( - TrainingController, - TrainingControllerProxy, - TrainingWorkerClass, - TrainingWorkerProxy, - WorkerConfig, - WorkerLogItem, +from xtuner.v1.rl.advantage import BaseAdvantageConfig, GRPOAdvantageConfig +from xtuner.v1.rl.agent_loop_manager import ( + AgentLoopManagerConfig, + AgentLoopManagerStatus, + ProduceBatchResult, + ProduceBatchStatus, ) -from xtuner.v1.rl.base import TrainingWorker as BaseTrainingWorker -from xtuner.v1.rl.config.advantage import BaseAdvantageConfig, GRPOAdvantageConfig -from xtuner.v1.train import ResumeConfig -from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger, is_hf_model_path, record_git_info, timer +from xtuner.v1.rl.agent_loop_manager.producer import default_should_continue_fn +from xtuner.v1.rl.evaluator import 
EvaluatorConfig
+from xtuner.v1.rl.gateway.config import GatewayConfig
+from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig, SyncReplayBufferConfig
+from xtuner.v1.rl.rollout.controller import RolloutControllerProxy
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.trainer.controller import TrainingController
+from xtuner.v1.rl.trainer.worker import WorkerConfig, WorkerLogItem
+from xtuner.v1.rl.utils import (
+    AcceleratorResourcesConfig,
+    AutoAcceleratorWorkers,
+    asyncio_run,
+    create_task,
+    sort_rollout_state_for_deterministic,
+)
+from xtuner.v1.train.trainer import LoadCheckpointConfig, XTunerMeta
+from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger, is_hf_model_path, set_deterministic, timer
 from xtuner.v1.utils.device import get_device, get_torch_device_module
-from xtuner.v1.utils.env_check import get_rollout_engine_version
-
-from .trainer import ExpHistory, ExpInfo, GitInfo, LoadCheckpointConfig, XTunerMeta
 
 
 # TODO: Move DEVICE to `xtuner.utils.device`
 PG_READY_TIMEOUT = 30
-TRAINER_RAY_GET_TIMEOUT = 5 * 3600  # 5 hour
 DEVICE = get_device()
 DEVICE_MODULE = get_torch_device_module()
 
 
+def check_fa3():
+    if os.environ.get("XTUNER_USE_FA3", "0") != "1":
+        return
+
+    try:
+        from xtuner.v1.ops.flash_attn import get_flash_attn_varlen
+
+        get_flash_attn_varlen()
+    except RuntimeError as e:
+        raise RuntimeError(f"Flash attention v3 runtime error {e}. Please install it first or set XTUNER_USE_FA3=0.")
+
+
 def bind_train_rollout(
-    train_controller,
-    env_controller,
+    train_controller: TrainingController,
+    rollout_controller: RolloutControllerProxy,
 ) -> None:
     """Bind the training and rollout workers for weight updates."""
-    info_dict = ray.get(env_controller.get_rollout_info.remote())  # type: ignore[attr-defined]
-    ray.get(train_controller.update_rollout_info.remote(info_dict))
+    info_dict = ray.get(rollout_controller.get_rollout_metadata.remote())  # type: ignore[attr-defined]
+    train_controller.update_rollout_info(info_dict)
     return
 
 
-class RolloutInfo(TypedDict):
-    data_groups: list[list[RLDataFlowItem]]
-    multimodal_train_infos: list[MultimodalTrainInfo]
-    task_time: dict[str, float]
-    replay_buffer_info: dict[str, float]
-
-
-class TrainInfo(TypedDict):
+class TrainInfo(TypedDict, total=False):
     data_info: dict[str, float]
     workers_log_item: list[WorkerLogItem]
 
 
-class RLTrainerConfig(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-    load_from: str | Path
-    resources: AcceleratorResourcesConfig
-    cpu_resources: CPUResourcesConfig | None = None
+def get_train_seq_ctx(
+    input_ids: torch.LongTensor,
+    position_ids: torch.Tensor | None = None,
+    multimodal_train_info: dict | None = None,
+    len_response_ids: int = 0,
+):
+    seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu")
+    if position_ids is not None and len(position_ids.shape) == 3:
+        # qwen3vl needs special handling here; other models need no extra processing
+        max_value = position_ids.max(dim=-1).values  # (3,1)
+        response_position_ids = max_value.unsqueeze(-1).expand(-1, -1, len_response_ids) + torch.arange(
+            1, len_response_ids + 1, device=max_value.device
+        )
+        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
+        seq_ctx.position_ids = position_ids  # type: ignore[assignment]
+        assert position_ids.size(-1) == input_ids.size(-1)
+
+    if multimodal_train_info:
+        seq_ctx.pixel_values = multimodal_train_info.get("pixel_values")
+        seq_ctx.image_grid_thw = multimodal_train_info.get("image_grid_thw")
+    return seq_ctx
+
+
+def is_valid_for_training(group_data_items: list[RolloutState], logger) -> bool:
+    """Check whether a group of rollout states is valid for a training step.
+
+    Args:
+        group_data_items: A list of RolloutState objects.
+        logger: Logger used to report why a group was rejected.
+
+    Returns:
+        True if the group is valid, False otherwise.
+
+    NOTE: Why this check is needed:
+    - For system fault tolerance, this check is already performed at rollout /
+      dataflow time, but we repeat it here to ensure training data integrity.
+    - 'filtered'/'failed': These items are fundamentally broken or incomplete and
+      should not be used for training.
+    - 'aborted': These items represent rollouts that were stopped
+      prematurely. Using such partial data could lead the model to learn
+      undesirable behaviors (e.g., stopping generation too early).
+    - Empty response/response_ids: The model's generated response is the core
+      of the training data for RL algorithms like PPO. If the response is
+      missing, there is nothing to compute rewards on or to train the model with.
+    """
+    is_abort = any(item.status == Status.ABORTED for item in group_data_items)
+    is_filtered = any(item.status == Status.FILTERED for item in group_data_items)
+    is_failed = any(item.status == Status.FAILED for item in group_data_items)
+    if is_filtered or is_failed or is_abort:
+        logger.warning(
+            f"Invalid dataflow group found during training, rollout state filtered: {is_filtered}, failed: {is_failed}, aborted: {is_abort}."
+        )
+        return False
+    for item in group_data_items:
+        response_valid = item.response is not None and len(item.response) > 0
+        ids_valid = item.response_ids is not None and len(item.response_ids) > 0
+        if not ids_valid:
+            # NOTE: `response_ids` is the critical field for token-in-token-out mode, so we ensure it's not empty.
+            logger.warning(
+                "Invalid dataflow item found during training: missing response_ids; skipping this item."
+            )
+            return False
+        if not response_valid:
+            # NOTE: a valid response string is required for judger inputs
+            logger.warning("Invalid dataflow item found during training: empty response string; skipping this item.")
+            return False
+    return True
+
+
+def _validate_sync_intervals(
+    sync_weights_interval: int,
+    checkpoint_interval: int | None,
+    hf_interval: int | None,
+    evaluate_step: int | None = None,
+    enable_evaluate: bool = False,
+) -> None:
+    if sync_weights_interval <= 0:
+        raise ValueError(f"sync_weights_interval must be positive, got {sync_weights_interval}.")
+
+    for name, interval in (
+        ("checkpoint_interval", checkpoint_interval),
+        ("hf_interval", hf_interval),
+    ):
+        if interval is None or interval == -1:
+            continue
+        if interval <= 0:
+            raise ValueError(f"{name} must be positive or -1/None to disable it, got {interval}.")
+        if interval % sync_weights_interval != 0:
+            raise ValueError(
+                f"{name}={interval} must be a multiple of sync_weights_interval={sync_weights_interval}, "
+                "because checkpoint/HF saves only run on weight-sync steps."
+            )
+
+    if enable_evaluate:
+        if evaluate_step is None or evaluate_step <= 0:
+            raise ValueError(f"evaluate_step must be positive when evaluation is enabled, got {evaluate_step}.")
+        if evaluate_step % sync_weights_interval != 0:
+            raise ValueError(
+                f"evaluate_step={evaluate_step} must be a multiple of "
+                f"sync_weights_interval={sync_weights_interval}, because evaluation only runs on weight-sync steps."
+ ) + + +class BaseRLTrainerConfig(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + train_worker_cfg: WorkerConfig rollout_config: RolloutConfig - dataflow_config: DataFlowConfig - judger_config: JudgerConfig - replay_buffer_config: ReplayBufferConfig - train_worker_config: WorkerConfig - evaluator_config: EvaluatorConfig | None = None tokenizer_path: str | Path + replay_buffer_config: SyncReplayBufferConfig | AsyncReplayBufferConfig = SyncReplayBufferConfig() + agent_loop_manager_cfg: AgentLoopManagerConfig + eval_agent_loop_manager_cfg: AgentLoopManagerConfig + evaluator_config: EvaluatorConfig + load_from: str | Path + total_train_steps: int | None = None + total_epochs: int | None = None + train_batch_size: int + advantage_estimator_config: BaseAdvantageConfig = Field(default_factory=GRPOAdvantageConfig) + sync_weights_interval: int = 1 + gateway_config: GatewayConfig | None = None + + enable_evaluate: bool = True + enable_initial_evaluate: bool = False + evaluate_step: int = 1 work_dir: Path | str | None = None - log_dir: Path | str | None = None - total_epochs: int - resume_config: ResumeConfig | None = None auto_resume: bool = False load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig() - strict_load: bool = True checkpoint_interval: int | None = -1 checkpoint_maxkeep: int | None = -1 + hf_interval: int | None = -1 + hf_max_keep: int | None = -1 checkpoint_no_save_optimizer: bool = False - skip_checkpoint_validation: bool = False # Suggest enabled if fsdp_size is larger than 512 - hf_interval: int | None = None - hf_max_keep: int | None = None - seed: int = 42 - debug: bool = False + log_dir: Path | str | None = None + seed: int = 66 debug_rollout: bool = False - rollout_steps: int | None = None - display_all_workers_log: bool = False + skip_checkpoint_validation: bool = False exp_tracker: Literal["tensorboard", "jsonl"] = "tensorboard" - advantage_estimator_config: BaseAdvantageConfig = GRPOAdvantageConfig() @model_validator(mode="after") - def _convert_work_dir(self): - if isinstance(self.work_dir, str): - self.work_dir = Path(self.work_dir) - elif self.work_dir is None: - self.work_dir = Path.cwd() + def _validate_sync_intervals(self): + if self.total_train_steps is None and self.total_epochs is None: + raise ValueError("Either total_train_steps or total_epochs must be provided.") + if self.total_train_steps is not None and self.total_train_steps <= 0: + raise ValueError(f"total_train_steps must be positive, got {self.total_train_steps}.") + if self.total_epochs is not None and self.total_epochs <= 0: + raise ValueError(f"total_epochs must be positive, got {self.total_epochs}.") + _validate_sync_intervals( + sync_weights_interval=self.sync_weights_interval, + checkpoint_interval=self.checkpoint_interval, + hf_interval=self.hf_interval, + evaluate_step=self.evaluate_step, + enable_evaluate=self.enable_evaluate, + ) return self -def get_train_seq_ctx( - input_ids: torch.LongTensor, multimodal_train_info: dict | None = None, len_response_ids: int = 0 -): - seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu") - if multimodal_train_info and len(multimodal_train_info) > 0: - position_ids = multimodal_train_info.get("position_ids") # (1,n) or (3,1,n) - if position_ids is not None and len(position_ids.shape) == 3: - # qwen3vl 需要特殊处理,其余的不需要额外处理 - max_value = position_ids.max(dim=-1).values # (3,1) - response_position_ids = max_value.unsqueeze(-1).expand(-1, -1, len_response_ids) + torch.arange( - 1, len_response_ids + 1, 
device=max_value.device - ) - position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - seq_ctx.position_ids = position_ids # type: ignore[assignment] - assert position_ids.size(-1) == input_ids.size(-1) - seq_ctx.pixel_values = multimodal_train_info.get("pixel_values") - seq_ctx.image_grid_thw = multimodal_train_info.get("image_grid_thw") - return seq_ctx +class RLColocateTrainerConfig(BaseRLTrainerConfig): + resources: AcceleratorResourcesConfig + def build(self) -> "RLColocateTrainer": + return RLColocateTrainer(self) -class RLTrainer: - """Universal Reinforcement Learning Trainer for XTuner. - A flexible RL training orchestrator that supports multiple RL algorithms - through pluggable training workers and controllers. Manages the complete - RL training workflow including rollout generation, policy updates, - evaluation, and checkpoint management. +class RLDisaggregatedTrainerConfig(BaseRLTrainerConfig): + train_resources: AcceleratorResourcesConfig + rollout_resources: AcceleratorResourcesConfig - **Training Workflow:** - 1. Initialize distributed workers and rollout environment - 2. Generate experiences using current policy - 3. Update policy using algorithm-specific training logic - 4. Synchronize weights between training and rollout workers - 5. Evaluate model performance and save checkpoints + def build(self) -> "RLDisaggregatedTrainer": + return RLDisaggregatedTrainer(self) - Args: - load_from (str | Path): Path to the base model to load. Should be a HuggingFace - model path (e.g., "meta-llama/Llama-2-7b-hf") or local model directory. - resources (AcceleratorResourcesConfig): Configuration for distributed computing - resources including number of workers, GPU allocation, and placement groups. - rollout_config (RolloutConfig): Configuration for rollout workers that generate - experiences by interacting with the environment. - dataflow_config (DataFlowConfig): Data orchestration configuration controlling - experience collection, batch formation, and data distribution across workers. - judger_config (JudgerConfig): Configuration for the reward model or scoring system - that evaluates generated responses and provides training signals. - replay_buffer_config (ReplayBufferConfig): Settings for experience replay buffer - including capacity, sampling strategy, and data retention policies. - evaluator_config (EvaluatorConfig | None): Evaluation configuration specifying metrics, - evaluation datasets, and assessment frequency for monitoring training progress. Defaults to None. - train_worker_cfg (WorkerConfig): Configuration for distributed training workers - including model architecture, optimizer settings, loss functions, and parallelism. - tokenizer_path (str | Path): Path to the tokenizer for text preprocessing. - Should be compatible with the base model specified in load_from. - work_dir (Path | str | None): Working directory for experiment outputs, - checkpoints, and logs. Defaults to None. - log_dir (Path | str | None): Directory for training logs and monitoring outputs. - Defaults to None. - total_epochs (int): Total number of training epochs to execute. - enable_evaluate (bool): Whether to perform periodic evaluation during training. - resume_config (ResumeConfig | None): Configuration for resuming training from - a previous checkpoint. Defaults to None. - auto_resume (bool): Whether to automatically resume training. Defaults to False. - load_checkpoint_cfg (LoadCheckpointConfig): Configuration for loading checkpoints. 
- strict_load (bool): Whether to strictly enforce checkpoint loading compatibility. - Defaults to True. - hf_interval (int | None): Interval (in epochs) for saving HuggingFace format - checkpoints. Defaults to None. - hf_max_keep (int | None): Maximum number of HuggingFace checkpoints to retain. - Defaults to None. - seed (int): Random seed for reproducible training. Defaults to 42. - debug (bool): Enable debug mode with additional logging. Defaults to False. - debug_rollout (bool): Enable debug mode for rollout workers. Defaults to False. - rollout_steps (int | None): Total number of rollout steps to perform. - If specified, overrides total_epochs. Defaults to None. - display_all_workers_log (bool): Whether to display logs from all workers. Defaults to False. - exp_tracker (Literal["tensorboard", "jsonl"]): Type of experiment tracker to use. - Options are "tensorboard" or "jsonl". Defaults to "tensorboard". - - **Examples:** - - Example configuration for GRPO RL training setup:: - - trainer = RLTrainer( - load_from="Qwen3-8B", - resources=resources_config, - rollout_config=rollout_cfg, - dataflow_config=dataflow_cfg, - judger_config=judger_cfg, - replay_buffer_config=buffer_cfg, - evaluator_config=eval_cfg, - train_worker_cfg=worker_cfg, - tokenizer_path="Qwen3-8B", - total_epochs=10, - enable_evaluate=True - ) - trainer.fit() - """ - META_PATH = ".xtuner_grpo" +class BaseRLTrainer: _EXP_TRACKING_PATH = "exp_tracking" _CHECKPOINT_DIR = "checkpoints" + _HF_DIR = "hf" _SAVE_TRAIN_STATE_PATH = "train_state.json" - def __init__( - self, - *, - load_from: str | Path, # Huggingface model path or saved trainer_path - resources: AcceleratorResourcesConfig, - cpu_resources: CPUResourcesConfig | None = None, - rollout_config: RolloutConfig, - dataflow_config: DataFlowConfig, - judger_config: JudgerConfig, - replay_buffer_config: ReplayBufferConfig, - train_worker_cfg: WorkerConfig, - evaluator_config: EvaluatorConfig | None = None, - tokenizer_path: str | Path, - work_dir: Path | str | None = None, - log_dir: Path | str | None = None, - total_epochs: int, - auto_resume: bool = False, - load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig(), - strict_load: bool = True, - checkpoint_interval: int | None = -1, - checkpoint_maxkeep: int | None = -1, - checkpoint_no_save_optimizer: bool = False, - skip_checkpoint_validation: bool = False, # Suggest enabled if fsdp_size is larger than 512 - hf_interval: int | None = None, - hf_max_keep: int | None = None, - seed: int = 42, - debug: bool = False, - debug_rollout: bool = False, - rollout_steps: int | None = None, - exp_tracker: Literal["tensorboard", "jsonl"] = "tensorboard", - display_all_workers_log: bool = False, - trainer_cfg: RLTrainerConfig | None = None, - advantage_estimator_config: BaseAdvantageConfig = GRPOAdvantageConfig(), - ): - """Initialize the RL training system.""" - if os.environ.get("XTUNER_USE_FA3", "0") == "1": - try: - from xtuner.v1.ops.flash_attn import get_flash_attn_varlen - - get_flash_attn_varlen() - except RuntimeError as e: - raise RuntimeError( - f"Flash attention v3 runtime error {e}, Please install it first or set XTUNER_USE_FA3=0." 
- ) - train_worker_cfg.load_from = load_from - - self._total_epochs = total_epochs - self._cur_step = 0 - self._global_train_step = 1 - - if skip_checkpoint_validation: - patch_default_save_plan() - - self._rl_trainer_cfg = trainer_cfg - self._load_from = Path(load_from) if isinstance(load_from, str) else load_from + train_controller: TrainingController + rollout_controller: RolloutControllerProxy + + def _init_common(self, cfg: BaseRLTrainerConfig, *, meta_path: str, logger_tag: str) -> None: + check_fa3() + self._init_work_dir_and_meta(cfg, meta_path) + self._init_load_source(cfg) + self._init_save_config(cfg) + log_dir = self._init_logger(cfg, logger_tag) + self._init_train_state(cfg) + self._init_train_worker_config(cfg, log_dir) + self._init_rollout_config(cfg, log_dir) + self._init_runtime_flags(cfg) + self._advantage_estimator = cfg.advantage_estimator_config.build() + + self._exp_tracker = get_writer(writer_type=cfg.exp_tracker, log_dir=log_dir / self._EXP_TRACKING_PATH) + self._display_all_workers_log = False + + def _init_work_dir_and_meta(self, cfg: BaseRLTrainerConfig, meta_path: str) -> None: + work_dir = Path(cfg.work_dir) if cfg.work_dir else Path.cwd() / "work_dirs" + if get_rank() == 0: + work_dir.mkdir(parents=True, exist_ok=True) + self._meta = XTunerMeta.build(work_dir, meta_path, cfg.auto_resume) + self._meta_path = meta_path - is_hf_path, error_info = is_hf_model_path(load_from) if load_from is not None else False, "" + def _init_load_source(self, cfg: BaseRLTrainerConfig) -> None: + self._load_from = Path(cfg.load_from) if isinstance(cfg.load_from, str) else cfg.load_from + is_hf_path, error_info = is_hf_model_path(cfg.load_from) if cfg.load_from is not None else (False, "") self._load_from_hf = is_hf_path - if not self._load_from_hf: raise NotImplementedError(error_info) - self._hf_max_keep = hf_max_keep - self._hf_interval = hf_interval - self._checkpoint_interval = checkpoint_interval - self._checkpoint_maxkeep = checkpoint_maxkeep - self._checkpoint_no_save_optimizer = checkpoint_no_save_optimizer - - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) - - self._debug = debug - self._debug_rollout = debug_rollout - self._seed = seed - self._set_deterministic() - self._set_random_seed(seed) + def _init_save_config(self, cfg: BaseRLTrainerConfig) -> None: + self._hf_max_keep = cfg.hf_max_keep + self._hf_interval = cfg.hf_interval - if work_dir is None: - work_dir = Path.cwd() / "work_dir" + self._checkpoint_interval = cfg.checkpoint_interval + self._checkpoint_maxkeep = cfg.checkpoint_maxkeep + self._checkpoint_no_save_optimizer = cfg.checkpoint_no_save_optimizer + self._load_checkpoint_cfg = self._resolve_load_checkpoint_cfg(cfg.auto_resume, cfg.load_checkpoint_cfg) - if isinstance(work_dir, str): - work_dir = Path(work_dir) + def _init_logger(self, cfg: BaseRLTrainerConfig, logger_tag: str) -> Path: + log_dir = self.exp_dir / "logs" + self.logger = get_logger(log_dir=log_dir, tag=logger_tag) - if get_rank() == 0: - work_dir.mkdir(parents=True, exist_ok=True) - - self._work_dir = work_dir - self._auto_resume = auto_resume - self._meta = self._init_xtuner_meta(work_dir, self._auto_resume) - - if log_dir is None: - log_dir = self.exp_dir - if isinstance(log_dir, str): - log_dir = Path(log_dir) - - self.logger = self._init_logger(log_dir) - - self._load_checkpoint_cfg = self._resolve_load_checkpoint_cfg(self._auto_resume, load_checkpoint_cfg) - - if train_worker_cfg.seed is None: - self.logger.warning(f"RLTrainer seed {seed} is used 
as train worker seed.") - train_worker_cfg.seed = seed - - train_worker_cfg.log_dir = log_dir - dataflow_config.worker_log_dir = log_dir - rollout_config.worker_log_dir = log_dir - self._enable_evaluate = False - self._enable_initial_evaluate = False - if evaluator_config: - evaluator_config.worker_log_dir = log_dir - self._enable_evaluate = evaluator_config.enable_evaluate - self._enable_initial_evaluate = evaluator_config.enable_initial_evaluate - self._pg = AutoAcceleratorWorkers.build_placement_group(resources) - - if cpu_resources is not None: - # NOTE: Here we only check CPU and memory for judger actors because only judger actors use CPU resources currently. - assert judger_config.total_cpus_needed <= cpu_resources.num_cpus_per_worker * cpu_resources.num_workers, ( - f"Not enough CPU resources for judger actors, " - f"required {judger_config.total_cpus_needed}, but got {cpu_resources.num_cpus_per_worker * cpu_resources.num_workers}." - ) - assert ( - judger_config.total_memory_needed <= cpu_resources.cpu_memory_per_worker * cpu_resources.num_workers - ), ( - f"Not enough memory resources for judger actors, " - f"required {judger_config.total_memory_needed}, but got {cpu_resources.cpu_memory_per_worker * cpu_resources.num_workers}." - ) - - self._judger_cpu_pg = placement_group(bundles=judger_config.total_bundles_needed, strategy="SPREAD") - ray.get(self._judger_cpu_pg.ready(), timeout=PG_READY_TIMEOUT) - - # We need to build train controller first, and then build rollout dataflow to make - # inference engines know how much memory they can utilize. - self._train_controller = self._build_train_controller(train_worker_cfg) + if cfg.skip_checkpoint_validation: + patch_default_save_plan() + return log_dir + def _init_train_state(self, cfg: BaseRLTrainerConfig) -> None: + self._total_train_steps = cfg.total_train_steps or 0 + self._total_epochs = cfg.total_epochs + self._cur_step = 0 + self._global_train_step = 0 + self._seed = cfg.seed + self.train_batch_size = cfg.train_batch_size + self._sync_weights_interval = cfg.sync_weights_interval + set_deterministic() + set_random_seed(cfg.seed) + + def _init_train_worker_config(self, cfg: BaseRLTrainerConfig, log_dir: Path) -> None: + if cfg.train_worker_cfg.seed is None: + self.logger.warning(f"RLTrainer seed {cfg.seed} is used as train worker seed.") + cfg.train_worker_cfg.seed = cfg.seed + cfg.train_worker_cfg.load_from = cfg.load_from + cfg.train_worker_cfg.log_dir = log_dir + self._train_worker_cfg = cfg.train_worker_cfg + + def _init_rollout_config(self, cfg: BaseRLTrainerConfig, log_dir: Path) -> None: + cfg.rollout_config.worker_log_dir = log_dir if self._load_checkpoint_cfg.checkpoint_path is not None: - rollout_config.skip_load_weights = True + cfg.rollout_config.skip_load_weights = True self.logger.info( f"Skip load rollout weights due to resume from checkpoint {self._load_checkpoint_cfg.checkpoint_path}" ) + self._rollout_config = cfg.rollout_config - # resume train worker - ray.get(self._train_controller.resume.remote(self._load_checkpoint_cfg)) - - train_state_path = Path(self._load_checkpoint_cfg.checkpoint_path) / self._SAVE_TRAIN_STATE_PATH - with train_state_path.open("r") as f: - train_state = json.load(f) - self._cur_step = train_state["cur_step"] + def _init_runtime_flags(self, cfg: BaseRLTrainerConfig) -> None: + self._enable_evaluate = cfg.enable_evaluate + self._enable_initial_evaluate = cfg.enable_initial_evaluate + self._evaluate_step = cfg.evaluate_step + self._debug_rollout = cfg.debug_rollout - 
self._rollout_env_controller, self._rollout_dataflow = self._build_rollout_dataflow(
-            dataflow_cfg=dataflow_config,
-            rollout_cfg=rollout_config,
-            judger_cfg=judger_config,
-            replay_buffer_config=replay_buffer_config,
+    def _maybe_start_gateway(self, cfg: BaseRLTrainerConfig) -> None:
+        if cfg.gateway_config is None or not cfg.gateway_config.auto_start:
+            return
+        # The gateway depends on the rollout controller, so it is started only after the rollout controller has been built.
+        ray.get(self.rollout_controller.start_gateway.remote(cfg.gateway_config))
+
+    def _build_agent_loop_components(self, cfg: BaseRLTrainerConfig, replay_buffer) -> None:
+        self.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path, trust_remote_code=True)
+        self.agent_loop_manager = cfg.agent_loop_manager_cfg.build(
+            rollout_controller=self.rollout_controller,
+            tokenizer=self.tokenizer,
+            replay_buffer=replay_buffer,
+            logger=self.logger,
+            sync_weights_interval=cfg.sync_weights_interval,
         )
-        self._dataflow_partial_rollout_step = dataflow_config.tail_batch_candidate_steps
-        if self._load_checkpoint_cfg.checkpoint_path is not None:
-            # resume rollout dataflow
-            self.logger.info(f"Resume rollout dataflow from checkpoint {self._load_checkpoint_cfg.checkpoint_path}")
-            ray.get(self._rollout_dataflow.resume.remote(self._load_checkpoint_cfg.checkpoint_path))
-
-        if self._enable_evaluate and evaluator_config:
-            self._evaluator = Evaluator.remote(evaluator_config, self._rollout_env_controller)  # type: ignore[attr-defined]
-            self._eval_step = evaluator_config.evaluate_step
-        else:
-            pass
-
-        self._global_batch_size = dataflow_config.global_batch_size
-        self._rollout_steps = (
-            ray.get(self._rollout_dataflow.get_train_dataset_length.remote())  # type: ignore[attr-defined]
-            // dataflow_config.global_batch_size
-            * total_epochs
+        self.eval_agent_loop_manager = cfg.eval_agent_loop_manager_cfg.build(
+            rollout_controller=self.rollout_controller,
+            tokenizer=self.tokenizer,
+            replay_buffer=replay_buffer,
+            logger=self.logger,
+            sync_weights_interval=cfg.sync_weights_interval,
         )
-        if rollout_steps is not None:
-            self._rollout_steps = rollout_steps
-            self.logger.info(f"Set rollout steps to {self._rollout_steps} according to rollout_steps arg")
-
-        bind_train_rollout(train_controller=self._train_controller, env_controller=self._rollout_env_controller)
-        # update weights if rollout_config.skip_load_weights == True
-        if rollout_config.skip_load_weights:
-            self.logger.info("Rollout workers skip load weights, update weights from train workers.")
-            ray.get(self._train_controller.offload.remote(target="optimizer"))
-            ray.get(self._rollout_env_controller.offload.remote())
-            ray.get(self._rollout_env_controller.onload_weights.remote())
-            ray.get(self._train_controller.update_weights.remote())
-            ray.get(self._train_controller.offload.remote(target="model"))
-            ray.get(self._rollout_env_controller.onload_kvcache.remote())
-            self.logger.info("Rollout workers has updated weights from train workers.")
-        else:
-            ray.get(self._train_controller.offload.remote(target="all"))
-        self._train_worker_cfg = train_worker_cfg
+        total_eval_samples = len(self.eval_agent_loop_manager.data_sampler)
+        self.evaluator = cfg.evaluator_config.build(total_eval_samples=total_eval_samples)
+        self._resolve_total_train_steps(cfg)
 
-        if self._rl_trainer_cfg is not None and get_rank() == 0:
-            config_path = log_dir / "rl_trainer_config.json"
-            with config_path.open("w") as f:
-                f.write(self._rl_trainer_cfg.model_dump_json(indent=2))
-
-            env_path = log_dir / "env.json"
-            environment_variables = dict(os.environ)
-
infer_engine_version = get_rollout_engine_version() - environment_variables.update(infer_engine_version) - with env_path.open("w") as f: - json.dump(environment_variables, f, indent=2) + def _resolve_total_train_steps(self, cfg: BaseRLTrainerConfig) -> None: + if cfg.total_train_steps is not None: + self._total_train_steps = cfg.total_train_steps + return - self._ray_get_timeout = max( - TRAINER_RAY_GET_TIMEOUT, rollout_config.rollout_timeout, judger_config.judger_timeout + assert cfg.total_epochs is not None + dataset_size = len(self.agent_loop_manager.data_sampler) + self._total_train_steps = dataset_size // cfg.train_batch_size * cfg.total_epochs + self.logger.info( + "Resolved total_train_steps from total_epochs: " + f"dataset_size={dataset_size}, train_batch_size={cfg.train_batch_size}, " + f"total_epochs={cfg.total_epochs}, total_train_steps={self._total_train_steps}" ) - self._exp_tracker = self._init_tracker(exp_tracker, log_dir / self._EXP_TRACKING_PATH) - self._display_all_workers_log = display_all_workers_log - self._advantage_estimator = advantage_estimator_config.build() + @property + def exp_dir(self) -> Path: + return Path(self._meta.latest_exp.exp_dir) def _resolve_load_checkpoint_cfg( self, auto_resume: bool, load_checkpoint_cfg: LoadCheckpointConfig ) -> LoadCheckpointConfig: - # auto_resume优先级高,如果有latest ckp,则说明走auto_resume逻辑 - # 此时,覆盖load checkpoint path - latest_checkpoint = self.meta.latest_exp.latest_checkpoint + """Resolve checkpoint path for auto-resume.""" + latest_checkpoint = self._meta.latest_exp.latest_checkpoint if latest_checkpoint is not None and auto_resume: load_checkpoint_cfg.checkpoint_path = Path(latest_checkpoint) return load_checkpoint_cfg - def _init_tracker(self, exp_tracker: Literal["tensorboard", "jsonl"], work_dir: Path): - writer = get_writer(writer_type=exp_tracker, log_dir=work_dir) - return writer - - @classmethod - def from_config(cls, config: RLTrainerConfig) -> Self: - """Create a Trainer instance from a TrainerConfig. - - Args: - config (TrainerConfig): TrainerConfig instance containing all configuration parameters. - - Returns: - Self: Trainer instance initialized with the provided config. 
- """ - self = cls( - load_from=config.load_from, - resources=config.resources, - cpu_resources=config.cpu_resources, - rollout_config=config.rollout_config, - dataflow_config=config.dataflow_config, - judger_config=config.judger_config, - replay_buffer_config=config.replay_buffer_config, - train_worker_cfg=config.train_worker_config, - evaluator_config=config.evaluator_config, - tokenizer_path=config.tokenizer_path, - work_dir=config.work_dir, - log_dir=config.log_dir, - total_epochs=config.total_epochs, - auto_resume=config.auto_resume, - load_checkpoint_cfg=config.load_checkpoint_cfg, - strict_load=config.strict_load, - checkpoint_interval=config.checkpoint_interval, - checkpoint_maxkeep=config.checkpoint_maxkeep, - checkpoint_no_save_optimizer=config.checkpoint_no_save_optimizer, - hf_interval=config.hf_interval, - hf_max_keep=config.hf_max_keep, - skip_checkpoint_validation=config.skip_checkpoint_validation, - seed=config.seed, - debug=config.debug, - debug_rollout=config.debug_rollout, - rollout_steps=config.rollout_steps, - exp_tracker=config.exp_tracker, - trainer_cfg=config, - advantage_estimator_config=config.advantage_estimator_config, - ) - return self + def _resume_train_controller_and_state(self, checkpoint_path: Path | str) -> Path: + # 子类只复用训练 worker 和 train_state 恢复,权重同步流程各自维护。 + checkpoint_path = Path(checkpoint_path) + self.train_controller.resume(self._load_checkpoint_cfg) - def _build_rollout_dataflow( - self, - dataflow_cfg: DataFlowConfig, - rollout_cfg: RolloutConfig, - judger_cfg: JudgerConfig, - replay_buffer_config: ReplayBufferConfig, - ) -> tuple[SingleTurnEnvironmentProxy, DataFlowProxy]: - env = SingleTurnEnvironment.remote("grpo", self._pg, rollout_cfg, self._judger_cpu_pg, judger_cfg) - flow = DataFlow.remote("grpo", dataflow_cfg, replay_buffer_config, env) - return env, flow - - def _build_train_controller(self, train_worker_cfg: WorkerConfig) -> TrainingControllerProxy: - TrainingWorker = cast( - TrainingWorkerClass, - ray.remote( - runtime_env={ - "env_vars": { - "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", - "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", - "HCCL_NPU_SOCKET_PORT_RANGE": "auto", - } - }, - )(BaseTrainingWorker), - ) - train_workers: list[TrainingWorkerProxy] - train_workers, _ = AutoAcceleratorWorkers.from_placement_group(TrainingWorker, train_worker_cfg, self._pg) - ray.wait([worker.ready.remote() for worker in train_workers]) - train_controller = TrainingController.remote(workers=train_workers) - return train_controller - - def _initial_evaluate(self): - """Performs an initial evaluation before the training loop starts.""" - if self._debug_rollout: + train_state_path = checkpoint_path / self._SAVE_TRAIN_STATE_PATH + with train_state_path.open("r") as f: + train_state = json.load(f) + self._cur_step = train_state["cur_step"] + return checkpoint_path + + def _maybe_save_checkpoint(self, cur_step: int) -> None: + """Save checkpoint if interval condition is met.""" + ckp_interval = self._checkpoint_interval + if ckp_interval is None or ckp_interval == -1: + return + if cur_step % ckp_interval != 0: return - if self._enable_initial_evaluate and self._enable_evaluate and self._evaluator: - ray.get(self._rollout_env_controller.update_active_workers.remote()) - scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True)) - trajectory_save_path = self.exp_dir / "eval_0_trajectory.jsonl" - self._save_trajectories(eval_data_groups, trajectory_save_path) - self.logger.info(f"Initial rollout evaluate scores 
{scores} and start training") - tb_scores = {f"eval/{k}": v for k, v in scores.items()} - self._exp_tracker.add_scalars( - tag_scalar_dict=tb_scores, - global_step=0, - ) - @staticmethod - def _group_item_sort_key(item): - return item.uid.action_id, item.uid.observation_id + checkpoint_path = self.exp_dir / self._CHECKPOINT_DIR / f"ckpt-step-{cur_step}" + checkpoint_path.mkdir(parents=True, exist_ok=True) - def _sort_rollout_outputs(self, data_groups, multimodal_train_infos=None): - sorted_data_groups = [sorted(group, key=self._group_item_sort_key) for group in data_groups] + # 1. Save sampler (dataloader) state + self.logger.info(f"Saving sampler state to {checkpoint_path}") + self.agent_loop_manager.save(checkpoint_path, model_step=cur_step) - if multimodal_train_infos: - grouped_items = list(zip(sorted_data_groups, multimodal_train_infos)) - grouped_items.sort(key=lambda item: item[0][0].uid.root_id if item[0] else -1) - return [item[0] for item in grouped_items], [item[1] for item in grouped_items] + # 2. Save DCP checkpoint (model + optimizer) + self.logger.info(f"Saving DCP checkpoint to {checkpoint_path}") + self.train_controller.save(str(checkpoint_path), self._checkpoint_no_save_optimizer) - sorted_data_groups.sort(key=lambda group: group[0].uid.root_id if group else -1) - return sorted_data_groups, multimodal_train_infos + # 3. Save train state JSON + train_state_path = checkpoint_path / self._SAVE_TRAIN_STATE_PATH + with train_state_path.open("w") as f: + json.dump({"cur_step": cur_step}, f) - def _rollout_step(self, rollout_idx: int, step_timer_dict: dict) -> RolloutInfo: - """Performs a single rollout step to generate experience.""" - with timer("generation", step_timer_dict): - ray.get(self._rollout_env_controller.update_active_workers.remote()) - dataflow_result = ray.get(self._rollout_dataflow.run.remote()) + # 4. Update meta + current_exp = self._meta.latest_exp + current_exp.checkpoint_list.append(str(checkpoint_path)) - if XTUNER_DETERMINISTIC: - data_groups, multimodal_train_infos = self._sort_rollout_outputs( - dataflow_result["data_groups"], dataflow_result.get("mm_train_infos", None) - ) - dataflow_result["data_groups"] = data_groups - dataflow_result["mm_train_infos"] = multimodal_train_infos - - with timer("save_trajectory", step_timer_dict): - trajectory_save_path = self.exp_dir / f"rollout_idx_{rollout_idx}_trajectory.jsonl" - self._save_trajectories(dataflow_result["data_groups"], trajectory_save_path) - self.logger.info(f"Rollout_idx {rollout_idx} finished, saved trajectories to {trajectory_save_path}") - - if not self._debug_rollout: - with timer("rollout_offload", step_timer_dict): - ray.get(self._rollout_dataflow.pause.remote()) - ray.get(self._rollout_env_controller.offload.remote()) - - rollout_info: RolloutInfo = { - "data_groups": dataflow_result["data_groups"], - "multimodal_train_infos": dataflow_result.get("mm_train_infos", None), - "task_time": dataflow_result.get("metrics", {}), - "replay_buffer_info": ray.get(self._rollout_dataflow.get_replaybuffer_status.remote()), - } - return rollout_info + # 5. 
Prune old checkpoints
+        ckp_maxkeep = self._checkpoint_maxkeep
+        ckp_list = current_exp.checkpoint_list
+        if ckp_maxkeep is not None and ckp_maxkeep > 0 and len(ckp_list) > ckp_maxkeep:
+            for deleted in ckp_list[:-ckp_maxkeep]:
+                if Path(deleted).exists():
+                    rmtree(deleted, ignore_errors=True)
+            current_exp.checkpoint_list = ckp_list[-ckp_maxkeep:]
 
-    def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, step_timer_dict: dict) -> TrainInfo:
-        """Performs a single training step on the generated experience."""
-        with timer("onload", step_timer_dict):
-            ray.get(self._train_controller.onload.remote(target="all"))
-            self.logger.info("Training controller loaded")
+        # 6. Persist meta to disk
+        meta_path = self.exp_dir.parent / self._meta_path
+        with meta_path.open("w") as f:
+            f.write(self._meta.model_dump_json(indent=2))
 
-        with timer("prepare_data", step_timer_dict):
-            data_batches, data_info = self._prepare_train_data(
-                data_groups, self._train_worker_cfg.pack_max_length, multimodal_train_infos
-            )
-            self.logger.info(f"Prepared {len(data_batches)} training data batches")
+    def _maybe_save_hf(self, cur_step: int):
+        if self._hf_interval is None or self._hf_interval == -1:
+            return
 
-        with timer("training", step_timer_dict):
-            workers_log_item: List[WorkerLogItem] = ray.get(
-                self._train_controller.fit.remote(
-                    data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx
-                )
+        if not self._load_from_hf:
+            raise RuntimeError(
+                "Saving to Huggingface format is only supported when loading from Huggingface! "
+                "Hitting this error means the trainer's `load_from` is not a Huggingface model path."
             )
-        train_log_info: TrainInfo = {
-            "data_info": data_info,
-            "workers_log_item": workers_log_item,
-        }
-        return train_log_info
-
-    def _sync_weights_and_save(self, rollout_idx: int, step_timer_dict: dict):
-        """Synchronizes weights and saves checkpoints."""
-        with timer("save_ckpt", step_timer_dict):
-            ray.get(self._train_controller.offload.remote(target="optimizer"))
-            self._maybe_save_hf()
-            self._maybe_save_checkpoint()
-        with timer("sync_weight", step_timer_dict):
-            bind_train_rollout(train_controller=self._train_controller, env_controller=self._rollout_env_controller)
-            ray.get(self._rollout_env_controller.onload_weights.remote())
-            ray.get(self._train_controller.update_weights.remote())
-            self.logger.info("Model weights synchronized successfully.")
-            ray.get(self._train_controller.offload.remote(target="model"))
-            ray.get(self._rollout_env_controller.onload_kvcache.remote())
-
-    def _evaluate_step(self, rollout_idx: int, step_timer_dict: dict) -> dict[str, float]:
-        """Performs an evaluation step."""
-        eval_log_info = {}
-        if self._enable_evaluate and self._evaluator and rollout_idx % self._eval_step == 0:
-            with timer("evaluation", step_timer_dict):
-                scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True))
-                trajectory_save_path = self.exp_dir / f"eval_{rollout_idx}_trajectory.jsonl"
-                self._save_trajectories(eval_data_groups, trajectory_save_path)
-                eval_log_info.update(scores)
-        return eval_log_info
-
-    def fit(self):
-        """Run the RL training loop.
-
-        This method executes the main rl training loop, iterating generating through the dataset and performing
-        training steps. It handles rollout, prepare training data, update policy , synchronize model weights, and
-        evaluation.
- """ - self.logger.info("Start RL training") - if self._cur_step >= self._rollout_steps: - self.logger.info(f"Rollout steps {self._rollout_steps} reached, stop training") + if cur_step % self._hf_interval != 0 and cur_step != self._total_train_steps: return - self._initial_evaluate() - - for rollout_idx in range(self._cur_step + 1, self._rollout_steps + 1): - self.logger.info(f"Rollout {rollout_idx}/{self._rollout_steps} start") - step_timer_dict = {} - with timer("step", step_timer_dict): - # 1. Rollout to generate experience - rollout_info = self._rollout_step(rollout_idx, step_timer_dict) - - train_log_info = {} - eval_log_info = {} - if not self._debug_rollout: - # 2. Train on the generated experience - train_log_info = self._train_step( - rollout_idx, - rollout_info["data_groups"], - rollout_info["multimodal_train_infos"], - step_timer_dict, - ) + save_hf_path = self.exp_dir / self._HF_DIR / f"hf-step-{cur_step}" + save_hf_path.mkdir(parents=True, exist_ok=True) - # 3. Synchronize weights and save checkpoints - self._sync_weights_and_save(rollout_idx, step_timer_dict) + # update meta + current_exp = self._meta.latest_exp + current_exp.hf_checkpoint_list.append(str(save_hf_path)) - # 4. Evaluate model performance - eval_log_info = self._evaluate_step(rollout_idx, step_timer_dict) + # save hf + self.logger.info(f"Saving Huggingface checkpoint to {save_hf_path}") + hf_list = self._meta.latest_exp.hf_checkpoint_list + if self._hf_max_keep is not None and self._hf_max_keep > 0 and len(hf_list) > self._hf_max_keep: + for deleted in hf_list[: -self._hf_max_keep]: + if Path(deleted).exists(): + rmtree(deleted, ignore_errors=True) + current_exp.hf_checkpoint_list = hf_list[-self._hf_max_keep :] + self.train_controller.save_hf(str(save_hf_path)) - self._log_step(rollout_idx, step_timer_dict, rollout_info, train_log_info, eval_log_info) - self._cur_step = rollout_idx + # save tokenizer + if isinstance(self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + self.tokenizer.save_pretrained(str(save_hf_path)) - self._exp_tracker.close() + async def _run_initial_evaluate(self) -> None: + eval_produce_result = await self.eval_agent_loop_manager.produce_batch( + self.evaluator.eval_batch_size, + train_step=1, + model_step=0, + ) + if XTUNER_DETERMINISTIC: + eval_produce_result.rollout_states = sort_rollout_state_for_deterministic( + eval_produce_result.rollout_states + ) + eval_metrics = self.evaluator.run(eval_produce_result.rollout_states) + self.logger.info(f"Initial rollout evaluate scores {eval_metrics} and start training") + tb_scores = {f"eval/{k}": v for k, v in eval_metrics.items()} + self._exp_tracker.add_scalars(tag_scalar_dict=tb_scores, global_step=0) - def _log_step( + def _train_one_batch( self, - rollout_idx: int, + train_batch: list[list[RolloutState]], + train_step: int, step_timer_dict: dict, - rollout_info: RolloutInfo, - train_info: TrainInfo, - eval_info: dict[str, float], - ): - all_scalars = {} - log_time_str = "" - trajectory_str = "" - eval_str = "" - if step_timer_dict: - all_scalars.update({f"time/{k}": v for k, v in step_timer_dict.items()}) - log_time_str = f"\nRollout {rollout_idx} finished and timing listed:\n" - log_time_str += "\n".join([f" - {k:<25}: {v:.2f}s" for k, v in step_timer_dict.items()]) - - if rollout_info: - all_scalars.update(rollout_info.get("task_time", {})) - all_scalars.update({f"async/{k}": v for k, v in rollout_info.get("replay_buffer_info", {}).items()}) - - if train_info: - all_scalars.update({f"response/{k}": v for k, v in 
train_info.get("data_info", {}).items()}) - trajectory_str = f"\nRollout {rollout_idx} data statistics:\n" - trajectory_str += "\n".join([f"- {k:<25}: {v:.4f}" for k, v in train_info.get("data_info", {}).items()]) - rank0_log_item = train_info["workers_log_item"][0] - rank0_rollout_is_metrics = rank0_log_item.get("rollout_is_metrics", {}) - rank0_mismatch_metrics = rank0_log_item.get("mismatch_metrics", {}) - rank0_rollout_entropy = rank0_log_item.get("rollout_entropy", 0.0) - all_scalars.update({f"rollout_is/{k}": v for k, v in rank0_rollout_is_metrics.items()}) - all_scalars.update({f"{k}": v for k, v in rank0_mismatch_metrics.items()}) - all_scalars.update({"entropy/rollout": rank0_rollout_entropy}) - all_scalars.update({"entropy/train": rank0_log_item["train_entropy"]}) - for worker_idx, log_item in enumerate(train_info["workers_log_item"]): - if not self._display_all_workers_log and worker_idx > 0: - break - mini_batch_metrics: dict[str, List[float]] = {} - for mini_batch_log in log_item["train_metrics"]: - for k, v in mini_batch_log.items(): - mini_batch_metrics.setdefault(k, []).append(cast(float, v)) - - for key, value in mini_batch_metrics.items(): - avg_value = sum(value) / len(value) - all_scalars.update({f"train_metrics/worker_{worker_idx}/step_avg_{key}": avg_value}) - - rank_sft_log = log_item["sft_train_metrics"] - for k, v in rank_sft_log.items(): - all_scalars.update({f"sft_train_metrics/worker_{worker_idx}/{k}": v}) - - if eval_info: - all_scalars.update({f"eval/{k}": v for k, v in eval_info.items()}) - eval_str = " ".join([f"{k}: {v:.4f}" for k, v in eval_info.items()]) - - self.logger.info(f"Rollout {rollout_idx}/{self._rollout_steps}{log_time_str} {trajectory_str} ") - if eval_str: - self.logger.info(f"Eval: {eval_str}") - self._exp_tracker.add_scalars(tag_scalar_dict=all_scalars, global_step=rollout_idx) + *, + offload_rollout_before_train: bool = False, + onload_train_before_train: bool = False, + ) -> TrainInfo: + train_sample_count = sum(len(group) for group in train_batch) + self.logger.info(f"generate {train_sample_count} samples for training") + + train_trajectory_dir = self.exp_dir / "train_rollout" + train_trajectory_dir.mkdir(parents=True, exist_ok=True) + train_trajectory_path = train_trajectory_dir / f"train_rollout_{train_step}.jsonl" + self._save_trajectories(train_batch, train_trajectory_path) + self.logger.info(f"Train step {train_step} train trajectories saved to {train_trajectory_path}") + + # 共卡需要先释放 rollout,再把训练 worker onload;非共卡不走这两个动作。 + if offload_rollout_before_train: + ray.get(self.rollout_controller.offload.remote()) + if onload_train_before_train: + with timer("onload", step_timer_dict): + self.train_controller.onload(target="all") + self.logger.info("Training controller loaded") - def _log_mini_batch_metrics(self, workers_log_item: List[WorkerLogItem]): - train_start_step = self._global_train_step - for worker_idx, log_item in enumerate(workers_log_item): - for step_idx, metrics in enumerate(log_item["train_metrics"]): - if not self._display_all_workers_log and worker_idx > 0: - break - current_global_step = train_start_step + step_idx + with timer("prepare_data", step_timer_dict): + data_batches, data_info = self._prepare_train_data(train_batch, self._train_worker_cfg.pack_max_length) + self.logger.info(f"Prepared {len(data_batches)} training data batches") - self._exp_tracker.add_scalars( - tag_scalar_dict={f"train_metrics/worker_{worker_idx}/{k}": v for k, v in metrics.items()}, - global_step=current_global_step, - ) - 
self._global_train_step += len(workers_log_item[0]["train_metrics"]) + with timer("training", step_timer_dict): + workers_log_item: list[WorkerLogItem] = self.train_controller.fit( + data_batches, + pack_max_length=self._train_worker_cfg.pack_max_length, + rollout_idx=train_step, + ) + return { + "data_info": data_info, + "workers_log_item": workers_log_item, + } - # TODO: advantage 是在 DataFlow 里算好,还是在 train controller 里算? - # 因为可能有根据 advantage 来判断数据能否进 rl 训练的需求。暂时先放在这 - def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_infos=None): + async def _run_evaluation(self, train_step: int) -> dict[str, float]: + eval_produce_result = await self.eval_agent_loop_manager.produce_batch( + self.evaluator.eval_batch_size, + train_step=1, + model_step=0, + ) + if XTUNER_DETERMINISTIC: + eval_produce_result.rollout_states = sort_rollout_state_for_deterministic( + eval_produce_result.rollout_states + ) + eval_batch = eval_produce_result.rollout_states + eval_metrics = self.evaluator.run(eval_batch) + eval_trajectory_dir = self.exp_dir / "eval_rollout" + eval_trajectory_dir.mkdir(parents=True, exist_ok=True) + eval_trajectory_path = eval_trajectory_dir / f"eval_rollout_{train_step}.jsonl" + self._save_trajectories(eval_batch, eval_trajectory_path) + self.logger.info(f"Train step {train_step} eval trajectories saved to {eval_trajectory_path}") + return eval_metrics + + # TODO: simplify with Packer.pack_pad_dispatch() + def _prepare_train_data(self, data_groups: list[list[RolloutState]], pack_max_length: int): rewards_list = [] advantages_list = [] prompt_len_list = [] response_len_list = [] data_batches = [] - is_multimodal = False - if multimodal_train_infos and len(multimodal_train_infos) > 0: - assert len(multimodal_train_infos) == len(data_groups), ( - f"{len(multimodal_train_infos)} vs {len(data_groups)}" - ) - is_multimodal = True for j, group in enumerate(data_groups): - if not is_valid_for_training(group): + if not is_valid_for_training(group, self.logger): self.logger.error(f"Skip one data group {group} due to rollout failed or empty response.") continue - if is_multimodal: - multimodal_train_info = multimodal_train_infos[j] + + is_vlm_model = "train_prompt_ids" in group[0].extra_fields + if is_vlm_model: + # TODO(hha): VLM, 不好的设计,后续要去掉 + prompt_ids = group[0].extra_fields["train_prompt_ids"] else: - multimodal_train_info = None + prompt_ids = group[0].prompt_ids + assert prompt_ids is not None and len(prompt_ids) > 0, ( + f"Prompt ids cannot be None or empty in data: {group[0]}" + ) + rewards = [] + for data in group: + assert data.reward is not None and "score" in data.reward, ( + f"Reward is missing or does not contain 'score' key in data: {data}" + ) + rewards.append(data.reward["score"]) - prompt_ids = group[0].data.extra_info["train_prompt_ids"] - rewards = [data.env.judger.reward["score"] for data in group] rewards_list.extend(rewards) - rewards = torch.tensor(rewards, dtype=torch.float32) - prompt_repeat_k = len(group) - - advantages = self._advantage_estimator.compute(rewards, group) # TODO: support PPO advantage estimation + rewards_tensor = torch.tensor(rewards, dtype=torch.float32) + advantages = self._advantage_estimator.compute(rewards_tensor, group) + prompt_repeat_k = len(group) for i in range(prompt_repeat_k): - item = group[i].env.rollout.response - logprobs = None - if group[i].env.rollout.response_ids is not None: - response_ids = group[i].env.rollout.response_ids - if isinstance(response_ids, torch.Tensor): - response_ids = 
response_ids.flatten().tolist() - logprobs = group[i].env.rollout.logprobs - assert len(logprobs) == len(response_ids), f"{len(logprobs)} vs {len(response_ids)}" - # 只有 response 部分有 logprobs, 需要前面追加 - logprobs = [0] * (len(prompt_ids) - 1) + logprobs + item = group[i].response + logprobs: list[float] | None = None + + response_ids: List[int] = [] + if group[i].response_ids is not None: + resp_ids_raw = group[i].response_ids + if isinstance(resp_ids_raw, torch.Tensor): + response_ids = resp_ids_raw.flatten().tolist() + else: + response_ids = cast(List[int], resp_ids_raw) + + logprobs = group[i].logprobs + if logprobs is not None: + assert len(logprobs) == len(response_ids), ( + f"{len(logprobs)} vs {len(response_ids)}, data: {group[i]}" + ) + # Only the response part has logprobs, so prepend placeholders for the prompt part + logprobs = [0.0] * (len(prompt_ids) - 1) + logprobs # type: ignore[arg-type] else: + assert item is not None, "response item cannot be None" response_ids = self.tokenizer(item, return_tensors="pt")["input_ids"].flatten().tolist() + # The returned routed_experts does not include the eos position (it is not needed anyway), hence the minus one + # TODO: does the verl tool agent loop need this? input_ids = prompt_ids + response_ids[:-1] prompt_len_list.append(len(prompt_ids)) response_len_list.append(len(response_ids)) - advantages_list.extend([advantages[i]] * len(response_ids)) - shifted_labels = [-100] * (len(prompt_ids) - 1) + response_ids + # Compute the shifted_labels for response_ids according to response_mask + if not group[i].response_mask: + response_mask = [1] * len(response_ids) + response_labels = response_ids + else: + assert len(group[i].response_mask) == len(response_ids), ( # type: ignore[arg-type] + f"{len(group[i].response_mask)} vs {len(response_ids)}" # type: ignore[arg-type] + ) + response_mask = cast(list[int], group[i].response_mask) + response_labels = [ + response_id if mask_id != 0 else -100 + for response_id, mask_id in zip(response_ids, response_mask) + ] + shifted_labels = [-100] * (len(prompt_ids) - 1) + response_labels + shifted_labels_t = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0) + + # Recompute the advantages according to response_mask + advantages_val = advantages[i].item() + actual_advantages = [advantages_val] * len(prompt_ids) + [ + 0.0 if mask == 0 else advantages_val for mask in response_mask + ] + advantages_list.extend(actual_advantages[:-1]) + assert len(input_ids) <= pack_max_length, f"{len(input_ids)} vs {pack_max_length}" - input_ids = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0) - shifted_labels = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0) + input_ids_t = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0) if logprobs is not None: rollout_logprobs = torch.tensor(logprobs, dtype=torch.float32).unsqueeze(0) - assert rollout_logprobs.size() == shifted_labels.size(), ( - f"{rollout_logprobs.size()} vs {shifted_labels.size()}" + assert rollout_logprobs.size() == shifted_labels_t.size(), ( + f"{rollout_logprobs.size()} vs {shifted_labels_t.size()}" ) else: rollout_logprobs = None - seq_ctx = get_train_seq_ctx(input_ids, multimodal_train_info, len(response_ids) - 1) + position_ids = group[i].position_ids + multimodal_train_info = group[i].mm_info + multi_info_cast = cast(dict | None, multimodal_train_info) + seq_ctx = get_train_seq_ctx(input_ids_t, position_ids, multi_info_cast, len(response_ids) - 1) # type: ignore[arg-type] + data_dict = { "seq_ctx": seq_ctx, - "shifted_labels": shifted_labels, - "advantage": advantages[i].item(), + "shifted_labels": shifted_labels_t, + "advantage": actual_advantages, "rollout_logprobs": rollout_logprobs, } - 
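As a worked illustration of the masking above (token ids and mask values are invented for the example, not taken from this diff):

prompt_ids = [11, 12, 13]          # 3 prompt tokens
response_ids = [21, 22, 23, 24]    # 4 response tokens
response_mask = [1, 1, 0, 1]       # third response token is masked out (e.g. tool output)

# Masked-out response tokens become -100 so the loss ignores them.
response_labels = [t if m != 0 else -100 for t, m in zip(response_ids, response_mask)]
shifted_labels = [-100] * (len(prompt_ids) - 1) + response_labels
# -> [-100, -100, 21, 22, -100, 24], same length as prompt_ids + response_ids[:-1]

adv = 0.5
actual_advantages = [adv] * len(prompt_ids) + [0.0 if m == 0 else adv for m in response_mask]
# The trailing position is dropped so advantages align with the shifted targets.
assert len(actual_advantages[:-1]) == len(shifted_labels)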
if "routed_experts" in group[i].env.rollout.extra_info: - routed_experts = group[i].env.rollout.extra_info.pop("routed_experts") # n,layer*expert - seq_ctx.rollout_routed_experts = routed_experts # n,layer,expert + seq_ctx.rollout_routed_experts = group[i].routed_experts # n,layer*expert data_batches.append(data_dict) - if multimodal_train_info is not None: - del multimodal_train_info - random.shuffle(data_batches) + if not XTUNER_DETERMINISTIC: + random.shuffle(data_batches) rewards_t = torch.tensor(rewards_list).float() if rewards_list else torch.tensor([0.0]).float() advantages_t = torch.tensor(advantages_list).float() if advantages_list else torch.tensor([0.0]).float() @@ -851,294 +693,427 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf } return data_batches, info_dict - def _save_trajectories(self, data_groups, save_path): - rewards = [] + def _log_step( + self, + train_step: int, + step_timer_dict: dict, + produce_result: ProduceBatchResult, + train_info: TrainInfo, + eval_info: dict[str, float], + ): + all_scalars = {} + log_time_str = "" + trajectory_str = "" + eval_str = "" + if step_timer_dict: + all_scalars.update({f"time/{k}": v for k, v in step_timer_dict.items()}) + log_time_str = f"\nTrain step {train_step} finished and timing listed:\n" + log_time_str += "\n".join([f" - {k:<25}: {v:.2f}s" for k, v in step_timer_dict.items()]) + + if produce_result.group_gen_count is not None: + all_scalars["timing/task_n"] = produce_result.group_gen_count + all_scalars["timing/task_mean_s"] = produce_result.group_gen_mean_s + all_scalars["timing/task_p50_s"] = produce_result.group_gen_p50_s + all_scalars["timing/task_p99_s"] = produce_result.group_gen_p99_s + all_scalars["timing/task_p99_p50_ratio"] = produce_result.group_gen_p99_p50_ratio + all_scalars["timing/pause_s"] = produce_result.group_gen_pause_time_s + all_scalars["async/completed_samples"] = produce_result.leftover_completed + all_scalars["async/aborted_samples"] = produce_result.leftover_aborted + all_scalars["async/expired_samples"] = produce_result.leftover_expired + + if train_info: + all_scalars.update({f"response/{k}": v for k, v in train_info.get("data_info", {}).items()}) + trajectory_str = f"\nTrain step {train_step} data statistics:\n" + trajectory_str += "\n".join([f"- {k:<25}: {v:.4f}" for k, v in train_info.get("data_info", {}).items()]) + rank0_log_item = train_info["workers_log_item"][0] + rank0_rollout_is_metrics = rank0_log_item.get("rollout_is_metrics", {}) + rank0_mismatch_metrics = rank0_log_item.get("mismatch_metrics", {}) + rank0_rollout_entropy = rank0_log_item.get("rollout_entropy", 0.0) + all_scalars.update({f"rollout_is/{k}": v for k, v in rank0_rollout_is_metrics.items()}) + all_scalars.update({f"{k}": v for k, v in rank0_mismatch_metrics.items()}) + all_scalars.update({"entropy/rollout": rank0_rollout_entropy}) + all_scalars.update({"entropy/train": rank0_log_item["train_entropy"]}) + for worker_idx, log_item in enumerate(train_info["workers_log_item"]): + if not self._display_all_workers_log and worker_idx > 0: + break + mini_batch_metrics: dict[str, List[float]] = {} + for mini_batch_log in log_item["train_metrics"]: + for k, v in mini_batch_log.items(): + mini_batch_metrics.setdefault(k, []).append(cast(float, v)) + + for key, value in mini_batch_metrics.items(): + avg_value = sum(value) / len(value) + all_scalars.update({f"train_metrics/worker_{worker_idx}/step_avg_{key}": avg_value}) + + rank_sft_log = log_item["sft_train_metrics"] + for k, v in 
rank_sft_log.items(): + all_scalars.update({f"sft_train_metrics/worker_{worker_idx}/{k}": v}) + + self._log_mini_batch_metrics(train_info["workers_log_item"]) - rollout_response_len_list = [] - version_dict = {i: 0 for i in range(self._dataflow_partial_rollout_step + 1)} + if eval_info: + all_scalars.update({f"eval/{k}": v for k, v in eval_info.items()}) + eval_str = " ".join([f"{k}: {v:.4f}" for k, v in eval_info.items()]) + + self.logger.info(f"Train step {train_step}/{self._total_train_steps}{log_time_str} {trajectory_str} ") + if eval_str: + self.logger.info(f"Eval: {eval_str}") + self._exp_tracker.add_scalars(tag_scalar_dict=all_scalars, global_step=train_step) + + def _save_trajectories(self, data_groups: list[list[RolloutState]], save_path: Path) -> None: + rewards = [] + response_len_list = [] - # NOTE: Since we currently default to token-in token-out, the code for checking whether response_ids have Retokenization Drift is commented out. - # If you need to debug, you can uncomment it. - # mismatch_token_ids_count = 0 - # response_len_list = [] for group in data_groups: - if not is_valid_for_training(group): - self.logger.error(f"Skip one data group {group} due to rollout failed or empty response.") + if not is_valid_for_training(group, self.logger): continue for data in group: - rewards.append(data.env.judger.reward["score"]) - if data.env.rollout.response_ids is not None: - if isinstance(data.env.rollout.response_ids, torch.Tensor): - response_ids = data.env.rollout.response_ids.flatten().tolist() + assert data.reward is not None + rewards.append(data.reward["score"]) + if data.response_ids is not None: + if isinstance(data.response_ids, torch.Tensor): + response_ids = data.response_ids.flatten().tolist() else: - response_ids = data.env.rollout.response_ids - rollout_response_len_list.append(len(response_ids)) - # response_str = self.tokenizer.decode(response_ids, skip_special_tokens=False) - # revert_encode_response_ids = self.tokenizer.encode(response_str, add_special_tokens=False) - - # response_str_to_ids = self.tokenizer.encode(data.env.rollout.response, add_special_tokens=False) - # response_len_list.append(len(response_str_to_ids)) + response_ids = data.response_ids + response_len_list.append(len(response_ids)) + elif data.response is not None: + response_ids = self.tokenizer.encode(data.response, add_special_tokens=False) + response_len_list.append(len(response_ids)) - # if response_ids != revert_encode_response_ids or response_ids != response_str_to_ids: - # mismatch_token_ids_count += 1 - else: - response_ids = self.tokenizer.encode(data.env.rollout.response, add_special_tokens=False) - rollout_response_len_list.append(len(response_ids)) - - version = data.uid.version - if version not in version_dict: - version_dict[version] = 0 - version_dict[version] += 1 - - rewards_tensor = torch.tensor(rewards).float() - rollout_response_lens: torch.Tensor = torch.tensor([0.0]).float() - if len(rollout_response_len_list) > 0: - rollout_response_lens = torch.tensor(rollout_response_len_list).float() + rewards_tensor = torch.tensor(rewards).float() if rewards else torch.tensor([0.0]).float() + response_lens = torch.tensor(response_len_list).float() if response_len_list else torch.tensor([0.0]).float() _count = 0 with open(save_path, "w", encoding="utf-8") as f: - item = { + summary = { "reward_mean": rewards_tensor.mean().item(), "reward_std": rewards_tensor.std().item(), "reward_max": rewards_tensor.max().item(), "reward_min": rewards_tensor.min().item(), - "response_len_mean": 
rollout_response_lens.mean().item(), - "response_len_std": rollout_response_lens.std().item(), - "response_len_max": rollout_response_lens.max().item(), - "response_len_min": rollout_response_lens.min().item(), + "response_len_mean": response_lens.mean().item(), + "response_len_std": response_lens.std().item(), + "response_len_max": response_lens.max().item(), + "response_len_min": response_lens.min().item(), "total_len": len(rewards), - "versions": version_dict, - # "mismatch_token_ids_count": mismatch_token_ids_count, } - json.dump(item, f, ensure_ascii=False, indent=2) + json.dump(summary, f, ensure_ascii=False, indent=2) f.write("\n") for group in data_groups: - if not is_valid_for_training(group): - self.logger.error(f"Skip one data group {group} due to rollout failed or empty response.") + if not is_valid_for_training(group, self.logger): continue for data in group: + assert data.reward is not None + ground_truth = None + if data.reward_model is not None: + ground_truth = data.reward_model.get("ground_truth") item = { - "action_id": data.uid.action_id, - "prompt": data.data.extra_info["raw_prompt"], - "response": data.env.rollout.response, - "versioned_response": data.env.rollout.versioned_response, - # "response_ids": str(data.env.rollout.response_ids), - # "versioned_response_ids": str(data.env.rollout.versioned_response_ids), - "response_len": rollout_response_len_list[_count], - "versioned_response_len": data.env.rollout.versioned_num_return_tokens, - "label": data.data.reward_model.get("ground_truth", ""), - "reward": data.env.judger.reward["score"], - "version": data.uid.version, - "finish_reason": data.env.rollout.finish_reason, + "prompt": data.message, + "raw_prompt": data.extra_fields.get("raw_prompt", None), + "response": data.response, + "response_len": response_len_list[_count], + "label": ground_truth, + "reward": data.reward["score"], + "finish_reason": data.finish_reason, } json.dump(item, f, ensure_ascii=False, indent=2) f.write("\n") _count += 1 - def _load_trajectories(self, save_path): - data_groups = [] - with open(save_path) as f: - for line in f: - item = json.loads(line) - messages = item["messages"] - responses = item["response"] - rewards = item["reward"] - group = [] - for response, reward in zip(responses, rewards): - group.append( - { - "messages": messages, - "response_str": response, - "reward": reward, - } - ) - data_groups.append(group) - return data_groups - - def _compute_metrics(self, data_groups): - correctness = [1 if data[0]["reward"] > 0 else 0 for data in data_groups] - acc = sum(correctness) / len(correctness) - return acc + def _log_mini_batch_metrics(self, workers_log_item: List[WorkerLogItem]): + train_start_step = self._global_train_step + 1 + for worker_idx, log_item in enumerate(workers_log_item): + for step_idx, mini_batch_log in enumerate(log_item["train_metrics"]): + if not self._display_all_workers_log and worker_idx > 0: + break + current_global_step = train_start_step + step_idx - def _maybe_save_hf(self): - if self._hf_interval is None: - return + metrics: dict[str, Any] = dict(mini_batch_log) - assert self._load_from_hf, ( - "Only support saving to Huggingface format when loading from Huggingface! " - "You meet this error means `load_from` of trainer is not a Huggingface model path." 
- ) + self._exp_tracker.add_scalars( + tag_scalar_dict={f"train_metrics/worker_{worker_idx}/{k}": float(v) for k, v in metrics.items()}, + global_step=current_global_step, + ) + self._global_train_step += len(workers_log_item[0]["train_metrics"]) - if (self.cur_step + 1) % self._hf_interval != 0 and (self.cur_step + 1) != self._rollout_steps: - return - save_hf_path = self.exp_dir / f"hf-{self.cur_step + 1}" - self.logger.info(f"Saving step {self.cur_step + 1} hf checkpoints to: {save_hf_path}") - self.meta.latest_exp.hf_checkpoint_list.append(str(save_hf_path)) +class RLColocateTrainer(BaseRLTrainer): + _META_PATH = ".xtuner_rl_colocate_trainer" - if self._hf_max_keep is not None and len(self.meta.latest_exp.hf_checkpoint_list) > self._hf_max_keep: - deleted_hf_checkpoints = self.meta.latest_exp.hf_checkpoint_list[: -self._hf_max_keep] - self.meta.latest_exp.hf_checkpoint_list = self.meta.latest_exp.hf_checkpoint_list[-self._hf_max_keep :] - for hf_dir in deleted_hf_checkpoints: - rmtree(hf_dir) + # The colocated trainer keeps its own resource orchestration, resume, main loop, and weight sync; shared saving and logging stay in BaseRLTrainer. + def __init__(self, cfg: RLColocateTrainerConfig): + self._init_common(cfg, meta_path=self._META_PATH, logger_tag="RLTrainer") - ray.get(self._train_controller.save_hf.remote(str(save_hf_path)), timeout=self._ray_get_timeout) - if isinstance(self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): - self.tokenizer.save_pretrained(str(save_hf_path)) + self._pg = AutoAcceleratorWorkers.build_placement_group(cfg.resources) + self.train_controller = self._train_worker_cfg.build(self._pg) + self.rollout_controller = self._rollout_config.build(self._pg) + self._maybe_start_gateway(cfg) - def _maybe_save_checkpoint(self): - ckp_interval = self._checkpoint_interval - if ckp_interval is None: - return + replay_buffer = cfg.replay_buffer_config.build() + self._build_agent_loop_components(cfg, replay_buffer) - if ckp_interval == -1: - return + if self._load_checkpoint_cfg.checkpoint_path is not None: + self._resume_from_checkpoint(self._load_checkpoint_cfg.checkpoint_path) else: - if (self.cur_step + 1) % ckp_interval != 0 or (self.cur_step + 1) == self._rollout_steps: - return + self.train_controller.offload(target="all") - checkpoint_path = self.exp_dir / self._CHECKPOINT_DIR / f"ckpt-step-{self.cur_step + 1}" - checkpoint_path.mkdir(parents=True, exist_ok=True) + if self._debug_rollout: + self.logger.warning("Debug rollout mode is enabled; rollout will not be offloaded.") - self.logger.info(f"Saving step {self.cur_step + 1} rollout dataflow to: {checkpoint_path}") - ray.get(self._rollout_dataflow.save.remote(str(checkpoint_path)), timeout=self._ray_get_timeout) - self.logger.info(f"Saving step {self.cur_step + 1} dcp checkpoints to: {checkpoint_path}") - ray.get( - self._train_controller.save.remote(str(checkpoint_path), self._checkpoint_no_save_optimizer), - timeout=self._ray_get_timeout, - ) + def _resume_from_checkpoint(self, checkpoint_path: Path | str) -> None: + checkpoint_path = self._resume_train_controller_and_state(checkpoint_path) - # Update meta - current_exp = self.meta.latest_exp - ckp_list = current_exp.checkpoint_list - ckp_list.append(str(checkpoint_path)) - current_exp.cur_step = self.cur_step + 1 - current_exp.history[-1]["end"] = self.cur_step + 1 + self.logger.info(f"Resume sampler from {checkpoint_path}") + self.agent_loop_manager.resume(checkpoint_path) - train_state_path = checkpoint_path / self._SAVE_TRAIN_STATE_PATH - with train_state_path.open("w") as f: - f.write( - json.dumps( - { - "cur_step": 
self.cur_step + 1, - } - ) - ) + bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller) + self.logger.info("Rollout workers skip loading weights; weights are updated from the train workers.") + self.train_controller.offload(target="optimizer") + ray.get(self.rollout_controller.offload.remote()) + ray.get(self.rollout_controller.onload_weights.remote()) + self.train_controller.update_weights() + self.train_controller.offload(target="model") + ray.get(self.rollout_controller.onload_kvcache.remote()) + self.logger.info("Rollout workers updated weights from train workers.") - # Delete checkpoints and update meta's checkpoint_list - ckp_maxkeep = self._checkpoint_maxkeep - if ckp_maxkeep is not None and ckp_maxkeep > 0 and len(ckp_list) > ckp_maxkeep: - ckp_pop_num = len(ckp_list) - ckp_maxkeep - for _ in range(ckp_pop_num): - deleted_ckp = ckp_list.pop(0) - if Path(deleted_ckp).exists(): - rmtree(deleted_ckp, ignore_errors=True) + def fit(self): + self.logger.info("Start RL training") + if self._cur_step >= self._total_train_steps: + self.logger.info(f"Train steps {self._total_train_steps} reached; stopping training") + return - meta_path = self.work_dir / self.META_PATH - with meta_path.open("w") as f: - f.write(self.meta.model_dump_json(indent=2)) + if self._enable_initial_evaluate and not self._debug_rollout: + asyncio_run(self._run_initial_evaluate()) - def _init_logger(self, work_dir: Path): - # Logging system maybe need better design - logger = get_logger(log_dir=work_dir, tag="RLTrainer") - return logger + init_train_step = self._cur_step + 1 + model_step = self._get_colocate_rollout_model_step(init_train_step) + for train_step in range(init_train_step, self._total_train_steps + 1): + self.logger.info(f"Train step {train_step}/{self._total_train_steps} start") + step_timer_dict = {} + with timer("step", step_timer_dict): + # The colocated path produces rollouts and consumes the replay buffer within a single call. + self.logger.info( + f"[Step {train_step}] start to generate rollout experience for train step {train_step} with model step {model_step}" + ) + produce_result: ProduceBatchResult = asyncio_run( + self.agent_loop_manager.produce_batch( + self.train_batch_size, + train_step=train_step, + model_step=model_step, + ) + ) + if XTUNER_DETERMINISTIC: + produce_result.rollout_states = sort_rollout_state_for_deterministic(produce_result.rollout_states) + train_batch = produce_result.rollout_states + assert train_batch, ( + "RLColocateTrainer expects agent_loop_manager.produce_batch() to return non-empty rollout_states." 
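The resume path above follows a strict on/offload order; restated as a compact sketch using the same controller calls as _resume_from_checkpoint:

import ray

def push_weights_on_resume(train_controller, rollout_controller):
    train_controller.offload(target="optimizer")         # free optimizer state first
    ray.get(rollout_controller.offload.remote())         # drop stale rollout weights
    ray.get(rollout_controller.onload_weights.remote())  # allocate fresh weight buffers
    train_controller.update_weights()                    # push train -> rollout weights
    train_controller.offload(target="model")             # hand the GPUs back to rollout
    ray.get(rollout_controller.onload_kvcache.remote())  # rebuild the kv cache last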
+ ) - def _set_deterministic(self): - if XTUNER_DETERMINISTIC: - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" - torch.use_deterministic_algorithms(True, warn_only=True) + if not self._debug_rollout: + train_log_info = self._train_one_batch( + train_batch, + train_step, + step_timer_dict, + offload_rollout_before_train=True, + onload_train_before_train=True, + ) - def _set_random_seed(self, seed: int): - set_random_seed(seed) + weights_synced = self._sync_weights_and_save(train_step, step_timer_dict) + if weights_synced: + model_step = train_step - def _init_xtuner_meta(self, work_dir: Path, resume: bool) -> XTunerMeta: - if not work_dir.exists(): - work_dir.mkdir(parents=True, exist_ok=True) + eval_log_info = {} + if weights_synced and self._enable_evaluate and train_step % self._evaluate_step == 0: + with timer("evaluation", step_timer_dict): + eval_log_info.update(asyncio_run(self._run_evaluation(train_step))) + else: + train_log_info = {} + eval_log_info = {} - meta_path = work_dir / self.META_PATH - if not meta_path.exists(): - meta = XTunerMeta(exps=[]) - with open(meta_path, "w") as f: - f.write(meta.model_dump_json(indent=2)) + self._log_step(train_step, step_timer_dict, produce_result, train_log_info, eval_log_info) + self._cur_step = train_step - meta = cast(XTunerMeta, XTunerMeta.model_validate(load(meta_path, file_format="json"))) + def _get_colocate_rollout_model_step(self, train_step: int) -> int: + previous_step = train_step - 1 + return previous_step - (previous_step % self._sync_weights_interval) - resume = resume and bool(meta.exps) + def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict) -> bool: + """Save state and switch colocated resources back to rollout + workers.""" + should_sync_weights = train_step % self._sync_weights_interval == 0 + with timer("save_ckpt", step_timer_dict): + self.train_controller.offload(target="optimizer") + self._maybe_save_checkpoint(train_step) + self._maybe_save_hf(train_step) + + ray.get(self.rollout_controller.recover_failed_workers.remote()) + timer_name = "sync_weight" if should_sync_weights else "switch_to_rollout" + with timer(timer_name, step_timer_dict): + if should_sync_weights: + bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller) + ray.get(self.rollout_controller.onload_weights.remote()) + self.train_controller.update_weights() + self.logger.info("Model weights synchronized successfully.") + self.train_controller.offload(target="model") + else: + self.train_controller.offload(target="model") + ray.get(self.rollout_controller.onload_weights.remote()) + ray.get(self.rollout_controller.onload_kvcache.remote()) + return should_sync_weights - if resume and meta.exps: - latest_exp = meta.exps[-1] - latest_exp_history = latest_exp.history[-1] - begin = cast(int, latest_exp_history.get("end") or latest_exp_history["begin"]) - exp_dir = Path(latest_exp.exp_dir) - git_dir = exp_dir / f"git-info-begin-{begin}" +class RLDisaggregatedTrainer(BaseRLTrainer): + _META_PATH = ".xtuner_rl_disaggregated_trainer" - if not git_dir: - git_dir.mkdir(parents=True, exist_ok=True) + def __init__(self, cfg: RLDisaggregatedTrainerConfig): + self._init_common(cfg, meta_path=self._META_PATH, logger_tag="RLDisaggTrainer") - staged_path, unstaged_path = git_dir / "staged.diff", git_dir / "unstaged.diff" + self._train_pg, self._rollout_pg = self._build_disaggregated_placement_groups( + train_resources=cfg.train_resources, + rollout_resources=cfg.rollout_resources, + ) + 
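A quick sanity check of the _get_colocate_rollout_model_step arithmetic (the interval value is chosen for illustration):

# With _sync_weights_interval = 4, the rollout model lags training by up to 3 steps:
def model_step(train_step: int, interval: int = 4) -> int:
    prev = train_step - 1
    return prev - (prev % interval)

assert [model_step(s) for s in (1, 2, 4, 5, 8, 9)] == [0, 0, 0, 4, 4, 8]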
self.train_controller = self._train_worker_cfg.build(self._train_pg) + self.rollout_controller = self._rollout_config.build(self._rollout_pg) + self._maybe_start_gateway(cfg) + + replay_buffer = cfg.replay_buffer_config.build() + self._build_agent_loop_components(cfg, replay_buffer) + # In disaggregated mode, the producer and the consumer run concurrently. + # To keep them coordinated, no early-stopping mechanism may be introduced into production; otherwise production falls short and the consumer blocks. + # Hence should_continue_fn must be default_should_continue_fn. + for task_runner in self.agent_loop_manager.task_runners: + if task_runner.produce_strategy.should_continue_fn is not default_should_continue_fn: + raise ValueError( + "In disaggregated mode, should_continue_fn must be default_should_continue_fn, " + "because early stopping during rollout production is not allowed." ) - if not git_dir.exists(): - git_dir.mkdir(parents=True, exist_ok=True) - commit = record_git_info(staged_path, unstaged_path) - git_info = GitInfo( - commit=commit, - staged=str(staged_path), - unstaged=str(unstaged_path), - ) + if self._load_checkpoint_cfg.checkpoint_path is not None: + self._resume_from_checkpoint(self._load_checkpoint_cfg.checkpoint_path) - timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - new_exp_history = ExpHistory( - begin=begin, - timestamp=timestamp, - git_info=git_info, + if self._debug_rollout: + self.logger.warning( + "Debug rollout mode is enabled. Disaggregated training keeps rollout workers resident." ) - latest_exp.history.append(new_exp_history) - else: - timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - exp_dir = work_dir / timestamp - git_dir = Path(f"{exp_dir}/git-info-begin-{0}") - - if not git_dir.exists(): - git_dir.mkdir(parents=True, exist_ok=True) - - staged_path, unstaged_path = git_dir / "staged.diff", git_dir / "unstaged.diff" - commit = record_git_info(staged_path, unstaged_path) - git_info = GitInfo( - commit=commit, - staged=str(staged_path), - unstaged=str(unstaged_path), + + def _build_disaggregated_placement_groups( + self, + train_resources: AcceleratorResourcesConfig, + rollout_resources: AcceleratorResourcesConfig, + ): + pg_name_prefix = f"xtuner_rl_disagg_{self.exp_dir.name}" + train_pg_name = f"{pg_name_prefix}_train" + rollout_pg_name = f"{pg_name_prefix}_rollout" + + train_pg = AutoAcceleratorWorkers.build_placement_group(train_resources, name=train_pg_name) + rollout_pg = AutoAcceleratorWorkers.build_placement_group(rollout_resources, name=rollout_pg_name) + if train_pg.id == rollout_pg.id: + raise RuntimeError( + "RLDisaggregatedTrainer requires distinct placement groups for train and rollout, " + f"but both resolved to the same placement group id={train_pg.id}. " + "Please check placement-group naming and stale Ray cluster state." 
) - new_exp_history = ExpHistory( - begin=0, - timestamp=timestamp, - git_info=git_info, + self.logger.info( + "Created disaggregated placement groups: " + f"train={train_pg_name}(id={train_pg.id}), " + f"rollout={rollout_pg_name}(id={rollout_pg.id})" + ) + return train_pg, rollout_pg + + def _resume_from_checkpoint(self, checkpoint_path: Path | str) -> None: + checkpoint_path = self._resume_train_controller_and_state(checkpoint_path) + + self.logger.info(f"Resume sampler from {checkpoint_path}") + saved_model_step = self.agent_loop_manager.resume(checkpoint_path) + assert self._cur_step == saved_model_step + + bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller) + self.logger.info("Rollout workers skip loading weights; weights are updated from the train workers.") + self.fake_update_weights() + self.agent_loop_manager.continue_produce(model_step=saved_model_step) + + def fit(self): + # Keep the synchronous fit interface externally; internally the producer/consumer pair is organized with an async loop. + return asyncio_run(self._fit()) + + async def _fit(self): + self.logger.info("Start RL disaggregated training") + if self._cur_step >= self._total_train_steps: + self.logger.info(f"Train steps {self._total_train_steps} reached; stopping training") + return + + if self._enable_initial_evaluate: + await self._run_initial_evaluate() + + # The background producer only keeps writing data into the replay buffer; the foreground trainer consumes it via get_batch. + producer_task = create_task( + self.agent_loop_manager.produce_loop( + batch_size=self.train_batch_size, ) - new_exp = ExpInfo(history=[new_history], exp_dir=str(exp_dir)) - meta.exps.append(new_exp) - return meta + ) + try: + for train_step in range(self._cur_step + 1, self._total_train_steps + 1): + self.logger.info(f"Train step {train_step}/{self._total_train_steps} start") + step_timer_dict: dict[str, float] = {} + train_log_info = {} + eval_log_info = {} + with timer("step", step_timer_dict): + produce_result = await self.agent_loop_manager.get_batch( + self.train_batch_size, train_step=train_step + ) + if XTUNER_DETERMINISTIC: + produce_result.rollout_states = sort_rollout_state_for_deterministic( + produce_result.rollout_states + ) + need_sync = ( + produce_result.status == ProduceBatchStatus.EXPIRED_BATCH + or train_step % self._sync_weights_interval == 0 + or train_step == self._total_train_steps + ) + if produce_result.status != ProduceBatchStatus.EXPIRED_BATCH: + train_batch = produce_result.rollout_states + assert train_batch, ( + "RLDisaggregatedTrainer expects get_batch() to return non-empty rollout_states " + "unless status is EXPIRED_BATCH." + ) + train_log_info = self._train_one_batch(train_batch, train_step, step_timer_dict) + else: + self.logger.info( + "Skip train step because rollout model is expired; prioritize weight sync first." 
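The per-step ordering in _fit condenses to the following sketch (method names as used above; error handling and timers omitted):

async def quiesce_sync_resume(self, train_step: int, step_timer_dict: dict):
    # 1. Quiesce: pause the producer so nothing writes the buffer during save/sync.
    await self.agent_loop_manager.pause_produce(use_global_progress=True)
    # 2. Save checkpoints and push weights while the buffer is static.
    await self._sync_weights_and_save(train_step, step_timer_dict)
    # 3. (Evaluation runs here, before production resumes, so it is not
    #    starved of rollout resources.)
    # 4. Resume production against the freshly synced model step.
    self.agent_loop_manager.continue_produce(model_step=train_step)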
+ ) - @property - def work_dir(self) -> Path: - return self._work_dir + if need_sync: + # Pause the background producer before syncing so no pending rollouts keep writing the buffer during save/sync. + with timer("pause_produce", step_timer_dict): + await self.agent_loop_manager.pause_produce(use_global_progress=True) - @property - def exp_dir(self) -> Path: - return Path(self._meta.latest_exp.exp_dir) + await self._sync_weights_and_save(train_step, step_timer_dict) - @property - def meta(self) -> XTunerMeta: - return self._meta + if self._enable_evaluate and train_step % self._evaluate_step == 0: + # Run eval before resuming the producer so background production does not compete for rollout resources. + with timer("evaluation", step_timer_dict): + eval_log_info.update(await self._run_evaluation(train_step)) - @property - def cur_step(self): - return self._cur_step + self.agent_loop_manager.continue_produce(model_step=train_step) - @property - def total_epoch(self): - return self._total_epochs + self._log_step(train_step, step_timer_dict, produce_result, train_log_info, eval_log_info) + self._cur_step = train_step + finally: + self.agent_loop_manager._status = AgentLoopManagerStatus.FINISH + self.agent_loop_manager._finish_event.set() + await producer_task - @property - def rollout_steps(self): - return self._rollout_steps + async def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict): + # In disaggregated mode the producer is already paused in _fit; keep the quiescent save -> bind -> update order here. + with timer("save_ckpt", step_timer_dict): + self._maybe_save_checkpoint(train_step) + self._maybe_save_hf(train_step) + + ray.get(self.rollout_controller.recover_failed_workers.remote()) + with timer("sync_weight", step_timer_dict): + bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller) + self.fake_update_weights() + + def fake_update_weights(self): + self.train_controller.update_weights() + self.logger.info("Rollout workers updated weights through fake disaggregated sync.") diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 26ec589191..0fcf64435c 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -159,6 +159,69 @@ def get_exp_by_checkpoint(self, checkpoint: str) -> ExpInfo | None: return exp return None + @classmethod + def build(cls, work_dir: Path, meta_filename: str, resume: bool) -> "XTunerMeta": + """Create or load meta from work_dir and optionally start a new exp or + resume. + + Single-process helper (e.g. for rl_trainer). For distributed training, use the trainer's _init_xtuner_meta. 
+ """ + if not work_dir.exists(): + work_dir.mkdir(parents=True, exist_ok=True) + + meta_path = work_dir / meta_filename + if not meta_path.exists(): + meta = cls(exps=[]) + with open(meta_path, "w") as f: + f.write(meta.model_dump_json(indent=2)) + + meta = cast(XTunerMeta, cls.model_validate(load(meta_path, file_format="json"))) + resume = resume and bool(meta.exps) + + if resume and meta.exps: + latest_exp = meta.exps[-1] + latest_exp_history = latest_exp.history[-1] + begin = cast(int, latest_exp_history.get("end") or latest_exp_history["begin"]) + exp_dir = Path(latest_exp.exp_dir) + git_dir = exp_dir / f"git-info-begin-{begin}" + if not git_dir.exists(): + git_dir.mkdir(parents=True, exist_ok=True) + staged_path, unstaged_path = git_dir / "staged.diff", git_dir / "unstaged.diff" + commit = record_git_info(staged_path, unstaged_path) + git_info = GitInfo( + commit=commit, + staged=str(staged_path), + unstaged=str(unstaged_path), + ) + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + new_exp_history = ExpHistory( + begin=begin, + timestamp=timestamp, + git_info=git_info, + ) + latest_exp.history.append(new_exp_history) + else: + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + exp_dir = work_dir / timestamp + git_dir = Path(f"{exp_dir}/git-info-begin-0") + if not git_dir.exists(): + git_dir.mkdir(parents=True, exist_ok=True) + staged_path, unstaged_path = git_dir / "staged.diff", git_dir / "unstaged.diff" + commit = record_git_info(staged_path, unstaged_path) + git_info = GitInfo( + commit=commit, + staged=str(staged_path), + unstaged=str(unstaged_path), + ) + new_history = ExpHistory( + begin=0, + timestamp=timestamp, + git_info=git_info, + ) + new_exp = ExpInfo(history=[new_history], exp_dir=str(exp_dir)) + meta.exps.append(new_exp) + return meta + class ResumeConfig(BaseModel): model_config = ConfigDict(extra="forbid") @@ -1338,6 +1401,7 @@ def _init_dist(self, backend: str | None = None): dist.all_reduce(warmup_tensor) def _init_xtuner_meta(self, work_dir: Path, auto_resume: bool) -> XTunerMeta: + # TODO: simplify with XTunerMeta.build() of dist version if not work_dir.exists(): if self.rank == 0: work_dir.mkdir(parents=True, exist_ok=True) diff --git a/xtuner/v1/utils/__init__.py b/xtuner/v1/utils/__init__.py index 107bb0dbed..d23b22358e 100644 --- a/xtuner/v1/utils/__init__.py +++ b/xtuner/v1/utils/__init__.py @@ -17,6 +17,7 @@ get_padding_length, is_hf_model_path, record_git_info, + set_deterministic, trim_memory, ) from .pad import pad_to_max_length, pad_to_multiple_of @@ -62,5 +63,6 @@ "ray_method", "profile_time", "clean_param_name", + "set_deterministic", "trim_memory", ] diff --git a/xtuner/v1/utils/convert_gsm8k_with_tool.py b/xtuner/v1/utils/convert_gsm8k_with_tool.py new file mode 100644 index 0000000000..d5ba33f403 --- /dev/null +++ b/xtuner/v1/utils/convert_gsm8k_with_tool.py @@ -0,0 +1,87 @@ +"""Preprocess the GSM8k dataset to parquet format.""" + +import argparse +import os +import re + +import datasets + + +def extract_solution(solution_str): + solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) + assert solution is not None + final_solution = solution.group(0) + final_solution = final_solution.split("#### ")[1].replace(",", "") + return final_solution + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-dir", default="openai/gsm8k") + parser.add_argument("--out-dir") + + args = parser.parse_args() + + dataset = datasets.load_dataset(args.input_dir, "default") + + train_dataset = dataset["train"] + 
test_dataset = dataset["test"] + + instruction_following = 'Let\'s think step by step and output the final answer after "####".' + + # add a row to each data item that represents a unique id + # Adapted from https://github.com/verl-project/verl/blob/c37d4d53850906aced4c071666340ec26966d707/examples/data_preprocess/gsm8k_tool_agent_loop.py#L62 + def make_map_fn(split): + def process_fn(example, idx): + question_raw = example.pop("question") + + question = question_raw + " " + instruction_following + + answer_raw = example.pop("answer") + solution = extract_solution(answer_raw) + data = { + "data_source": "openai/gsm8k", + "agent_name": "tool_agent", + "prompt": [ + { + "role": "system", + "content": ( + "You are a math expert. You are given a question and you need to solve it step by step. " + "Reasoning step by step before any tool call. " + "You should use the `calc_gsm8k_reward` tool after step by step solving the question, " + "before generate final answer at least once and refine your answer if necessary. " + "Put your final answer in the format of `#### `." + ), + }, + { + "role": "user", + "content": question, + }, + ], + "ability": "math", + "reward_model": {"style": "rule", "ground_truth": solution}, + "extra_info": { + "split": split, + "index": idx, + "answer": answer_raw, + "question": question_raw, + "need_tools_kwargs": True, + "tools_kwargs": { + "calc_gsm8k_reward": { + "create_kwargs": {"ground_truth": solution}, + }, + }, + }, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + out_dir = args.out_dir + + os.makedirs(out_dir, exist_ok=True) + train_dataset.to_json(os.path.join(out_dir, "train.jsonl"), orient="records", lines=True) + test_dataset.to_json(os.path.join(out_dir, "test.jsonl"), orient="records", lines=True) diff --git a/xtuner/v1/utils/misc.py b/xtuner/v1/utils/misc.py index 311b0d0a46..c2360f2540 100644 --- a/xtuner/v1/utils/misc.py +++ b/xtuner/v1/utils/misc.py @@ -10,6 +10,7 @@ from types import FunctionType from typing import Annotated +import torch from huggingface_hub import constants from mmengine import is_installed @@ -25,6 +26,13 @@ logger = get_logger() XTUNER_DETERMINISTIC = os.getenv("XTUNER_DETERMINISTIC") == "true" + +def set_deterministic(): + if XTUNER_DETERMINISTIC: + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True, warn_only=True) + + # https://github.com/python/cpython/issues/82300#issuecomment-2169035092 if sys.version_info >= (3, 13): SharedMemory = _mpshm.SharedMemory diff --git a/xtuner/v1/utils/processing_utils.py b/xtuner/v1/utils/processing_utils.py new file mode 100644 index 0000000000..4e378be9ca --- /dev/null +++ b/xtuner/v1/utils/processing_utils.py @@ -0,0 +1,23 @@ +from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase, ProcessorMixin + +from .logger import get_logger + + +logger = get_logger() + + +def load_tokenizer(name_or_path: str, **kwargs): + return AutoTokenizer.from_pretrained(name_or_path, **kwargs) + + +def load_processor(name_or_path: str, **kwargs): + try: + proc = AutoProcessor.from_pretrained(name_or_path, **kwargs) + except (OSError, ValueError) as e: + logger.warning(f"Failed to load processor from {name_or_path}: {e}") + proc = None + + if isinstance(proc, PreTrainedTokenizerBase) or not isinstance(proc, ProcessorMixin): + proc = None + + return proc diff --git 
a/xtuner/v1/utils/rl_test_utils.py b/xtuner/v1/utils/rl_test_utils.py index b4677b5638..d6fec9c82a 100644 --- a/xtuner/v1/utils/rl_test_utils.py +++ b/xtuner/v1/utils/rl_test_utils.py @@ -2,82 +2,14 @@ import multiprocessing import os import time -from typing import Any, Callable, Dict, List +from typing import Any, Dict, List -import httpx import requests import uvicorn from fastapi import FastAPI from pydantic import BaseModel, ConfigDict, Field -from xtuner.v1.ray.judger.native import NativeJudgerConfig - -# try: -from xtuner.v1.ray.rollout.lmdeploy import LMDeployWorker -from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult - - -# except ImportError: -# LMDeployWorker = object -class MockTimeoutRolloutWorker(LMDeployWorker): - async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: - try: - raise httpx.TimeoutException("Mocked timeout error") - except Exception as e: - error_type = HttpRequestErrorType.from_exception(e) - result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked timeout exception: {e.__class__.__name__}") - return result - - def launch_server(self): - pass # Override - - -class MockRequestErrorRolloutWorker(LMDeployWorker): - async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: - try: - req = httpx.Request("POST", url) - raise httpx.RequestError("Mocked httpx request error", request=req) - except Exception as e: - error_type = HttpRequestErrorType.from_exception(e) - result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked request error exception: {e.__class__.__name__}") - return result - - def launch_server(self): - pass # Override - - -class MockClientErrorRolloutWorker(LMDeployWorker): - async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: - try: - req = httpx.Request("POST", url) - res = httpx.Response(400, request=req) - raise httpx.HTTPStatusError("Mocked client error", request=req, response=res) - except Exception as e: - error_type = HttpRequestErrorType.from_exception(e) - result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked client exception: {e.__class__.__name__}") - return result - - def launch_server(self): - pass # Override - - -class MockServerErrorRolloutWorker(LMDeployWorker): - async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: - try: - req = httpx.Request("POST", url) - res = httpx.Response(500, request=req) - raise httpx.HTTPStatusError("Mocked server error", request=req, response=res) - except Exception as e: - error_type = HttpRequestErrorType.from_exception(e) - result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked server exception: {e.__class__.__name__}") - return result - - def launch_server(self): - pass # Override +from xtuner.v1.rl.judger.native import JudgerConfig app = FastAPI() @@ -113,7 +45,7 @@ class JudgeResponse(BaseModel): @app.post("/judge", response_model=JudgeResponse) async def judge(request: JudgeRequest): - from xtuner.v1.ray.judger.gsm8k import compute_reward + from xtuner.v1.rl.judger.gsm8k import compute_reward """Endpoint to compute reward for a given response and label.""" # The compute_reward function returns a float, we wrap it in a dict @@ -158,17 +90,7 @@ def stop(self): print("Server stopped.") -def 
custom_postprocessor_for_gsm8k(result): - from xtuner.v1.data_proto.rl_data import RLJudgerResponseItem - - if not isinstance(result, list): - result = [result] - judger_response_item = [RLJudgerResponseItem(uid=result[i]["uid"], reward=result[i]) for i in range(len(result))] - return judger_response_item - - -class GSM8KRemoteJudgerConfig(NativeJudgerConfig): +class GSM8KRemoteJudgerConfig(JudgerConfig): judger_name: str - remote_url: str - extra_info: dict = {"score": 1, "format_score": 0} - postprocess_func: Callable = custom_postprocessor_for_gsm8k + reward_handler: str + extra_info: dict = Field(default_factory=lambda: {"score": 1, "format_score": 0}) diff --git a/xtuner/v1/utils/type_helper.py b/xtuner/v1/utils/type_helper.py index cca7fcf7de..3b86ebc3fd 100644 --- a/xtuner/v1/utils/type_helper.py +++ b/xtuner/v1/utils/type_helper.py @@ -40,7 +40,21 @@ def ray_method(f: Callable[Concatenate[C, P], Awaitable[T]]) -> RemoteMethod[P, def ray_method(f: Callable[Concatenate[C, P], T]) -> RemoteMethod[P, T]: ... -def ray_method(f): +def ray_method(f=None, *, num_returns=1, concurrency_group=None): + """Decorator for Ray actor methods. + + Compatible with Ray versions that require at least one of num_returns or concurrency_group. ray.method() must be + called with keyword arguments only and then applied to the function: ray.method(num_returns=1)(f). + """ import ray - return ray.method(f) # type: ignore[ret-type] + kwargs = {"num_returns": num_returns} + if concurrency_group is not None: + kwargs["concurrency_group"] = concurrency_group + + if f is None: + # Called as @ray_method(num_returns=...) or @ray_method(concurrency_group=...) + return lambda fn: ray.method(**kwargs)(fn) + + # Called as @ray_method + return ray.method(**kwargs)(f) # type: ignore[ret-type]
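For illustration, both decorator forms supported by the rewritten ray_method (the actor below is hypothetical):

import ray
from xtuner.v1.utils.type_helper import ray_method

@ray.remote
class EchoActor:
    @ray_method                 # bare form: defaults to num_returns=1
    def ping(self) -> str:
        return "pong"

    @ray_method(num_returns=2)  # parameterized form
    def pair(self):
        return 1, 2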