diff --git a/examples/v1/config/agent_rl_qwen35_30b_grpo.py b/examples/v1/config/agent_rl_qwen35_30b_grpo.py index f94ebeff78..a26176240f 100644 --- a/examples/v1/config/agent_rl_qwen35_30b_grpo.py +++ b/examples/v1/config/agent_rl_qwen35_30b_grpo.py @@ -48,7 +48,7 @@ max_concurrent_groups = 512 max_prompt_length = 4096 -pack_max_length = 68 * 1024 +pack_max_length = 256 * 1024 max_response_length = 64 * 1024 train_ep_size = 1 @@ -66,7 +66,7 @@ lr = 1e-6 hf_interval = 5 total_epochs = 10 -sp_size = 1 +sp_size = 4 # evaluation settings enable_evaluate = False enable_initial_evaluate = False @@ -218,8 +218,14 @@ def convert_rollout_tractory_to_train(env, group_data_items): model_cfg.text_config.z_loss_cfg = None model_cfg.text_config.balancing_loss_cfg = None model_cfg.text_config.freeze_routers = True +# model_cfg.text_config.mtp_config = MTPConfig( +# num_layers=1, +# loss_scaling_factor=1.0, +# detach_mtp_lm_head_weight=True, +# detach_mtp_inputs=True, +# share_weights=False, +# ) model_cfg.text_config.vocab_size = 251392 -# model_cfg.text_config.embed_grad_max_token_id = 251173 optim_cfg = AdamWConfig( lr=lr, diff --git a/examples/v1/config/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-agentrl-tb0429rc0.py b/examples/v1/config/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-agentrl-tb0429rc0.py new file mode 100644 index 0000000000..2c5f3b521e --- /dev/null +++ b/examples/v1/config/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-agentrl-tb0429rc0.py @@ -0,0 +1,788 @@ +import os + +os.environ["XTUNER_USE_LMDEPLOY"] = "1" + +import hashlib +import json + +# os.environ["HF_HOME"] = "/mnt/shared-storage-user/liukuikun/.cache/huggingface" +# os.environ["TRANSFORMERS_OFFLINE"] = "1" +from copy import deepcopy +from functools import partial + +import ray +from lagent.actions.mcp_client import AsyncMCPClient +from lagent.actions.web_visitor import WebVisitor +from lagent.agents.fc_agent import FunctionCallAgent, get_tool_prompt +from ray.util.placement_group import placement_group + +from projects.claw_bench.claw_tokenize_fn import RLClawTokenizeFnConfig +from projects.tb2_eval.tb2_eval_tokenize_fn import RLTB2EvalTokenizeFnConfig +from projects.tb2_rl.tb2_rl_tokenize_fn import RLTB2RLTokenizeFnConfig +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto.rl_data import ( + RLAgentDataItem, + RLDataFlowItem, + RLJudgerResponseItem, + RLRolloutResponseItem, + RolloutState, + SampleParams, + update_dataflow_item, +) +from xtuner.v1.datasets import DatasetConfig, Qwen3VLTokenizeFnConfig +from xtuner.v1.datasets.config import DataloaderConfig +from xtuner.v1.model import Qwen3_5_VLMoE35BA3Config +from xtuner.v1.ray.base import ( + AcceleratorResourcesConfig, + AutoAcceleratorWorkers, + CPUResourcesConfig, +) +from xtuner.v1.ray.config.worker import RolloutConfig +from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig +from xtuner.v1.ray.environment.agent_env import AgentEnvironment +from xtuner.v1.ray.environment.composed_env import ComposedEnvironment +from xtuner.v1.ray.environment.install_agent_env import InstallAgentEnvironment +from xtuner.v1.ray.environment.lagent.agents import ( + AsyncTokenInOutAgent, + EnvAgent, + JudgerWrapper, + finish_condition_func, +) +from xtuner.v1.ray.environment.lagent.llms.controller_wrapper import ControllerWrapper +from xtuner.v1.ray.environment.lagent.schema import AgentMessage +from xtuner.v1.ray.evaluator import EvaluatorConfig +from xtuner.v1.ray.judger.compass_verifier_v2 import 
CompassVerifierV2Config
+from xtuner.v1.ray.judger.controller import JudgerConfig
+from xtuner.v1.ray.rollout import RolloutController
+from xtuner.v1.rl.base import WorkerConfig
+from xtuner.v1.rl.base.rollout_is import RolloutImportanceSampling
+from xtuner.v1.rl.grpo import GRPOLossConfig
+from xtuner.v1.train.agent_rl_trainer import AgentRLTrainerConfig
+from xtuner.v1.train.trainer import LoadCheckpointConfig
+from xtuner.v1.utils.compute_metric import compute_metric
+
+if not ray.is_initialized():
+    ray.init(ignore_reinit_error=True, runtime_env={"env_vars": {"RAY_DEBUG_POST_MORTEM": "0"}})
+
+experimental_name = os.path.basename(__file__).split(".py")[0]
+# base_work_dir = "/mnt/shared-storage-user/llmit1/user/wangziyi/exp/mindcopilot_rl/work_dirs"
+base_work_dir = '/mnt/shared-storage-user/llmit1/user/liukuikun/delivery/interns2_preview_0430rc1'
+work_dir = os.path.join(base_work_dir, experimental_name)
+
+model_name = os.environ["RL_LLM_MODEL"]
+# model_path = "/mnt/shared-storage-user/llmit1/user/liujiangning/exp/s2_preview/agent_rl/s2-preview-thinker_sft_0228b_rl0312rc3_fix_klmismatch/20260327042049/hf-40"
+model_path = '/mnt/shared-storage-user/llmit1/user/wangziyi/exp/mindcopilot_rl/work_dirs/ckpt/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-badword-mtp4/20260426021137/hf-140'
+stop_word = "<|im_end|>"
+
+# basic settings
+global_batch_size = 128
+prompt_repeat_k = 8
+max_concurrent_groups = 512
+
+max_prompt_length = 4096
+pack_max_length = 130 * 1024
+max_response_length = 128 * 1024
+
+train_ep_size = 1
+train_sp_size = 4
+rollout_tp_size = 4
+rollout_ep_size = 1
+fp32_lm_head = True
+enable_float8_rollout = False
+rollout_max_batch_size = 128
+max_prefill_token_num = 1024
+enable_return_routed_experts = True
+enable_partial_rollout = False
+staleness_threshold = 0.0
+tail_batch_candidate_steps = 0
+auto_resume = True
+skip_load_weights = True
+lr = 1e-6
+train_optimizer_steps = 2 # mini batch steps
+hf_interval = 10
+total_epochs = 10
+# evaluation settings
+enable_evaluate = True
+enable_initial_evaluate = True
+evaluate_step = 5
+
+# agent setting
+max_turn = 5 # maximum number of conversation turns
+lower_tool_turn_bound = 3 # penalize samples whose tool-call turns fall below this value; None disables the penalty
+enable_repeated_tool_call_penalty = True # whether to penalize samples that repeat the same tool call
+enable_no_thinking_penalty = False
+max_tool_response_length = 8192
+
+
+# 1. resources
+resources = AcceleratorResourcesConfig(
+    accelerator="GPU",
+    num_workers=64,
+    num_cpus_per_worker=12,
+    cpu_memory_per_worker=16 * 1024**3, # 16 GB
+)
+judger_cpu_resources = CPUResourcesConfig.from_total(total_cpus=16, num_workers=16, total_memory=64 * 1024**3)
+
+# 2. 
rollout +rollout_config = RolloutConfig( + env=experimental_name, + device=resources.accelerator, + model_name=model_name, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=rollout_tp_size, + expert_parallel_size=rollout_ep_size, + gpu_memory_utilization=0.7, + enable_float8=enable_float8_rollout, + skip_load_weights=skip_load_weights, + context_length=max_response_length, + rollout_max_batch_size_per_instance=rollout_max_batch_size, + chunked_prefill_size=4096, + allow_over_concurrency_ratio=2.0, + rollout_timeout=36000, + enable_return_routed_experts=enable_return_routed_experts, + # return_routed_experts_key=True, + # max_prefill_token_num=max_prefill_token_num, + extra_rollout_config=dict(lmdeploy_log_level="ERROR", lmdeploy_uvicorn_log_level="ERROR"), + fp32_lm_head=fp32_lm_head, +) + +# sampling params +training_sample_params = SampleParams( + max_tokens=max_response_length, top_k=0, top_p=0.999, temperature=1.0, min_tokens=0 +) +evaluation_sample_params = deepcopy(training_sample_params) +evaluation_sample_params.temperature = 0.8 +# evaluation_sample_params.max_tokens = max_response_length + +tokenize_fn_cfg = Qwen3VLTokenizeFnConfig( + processor_path=model_path, + min_pixels=None, + # max_pixels=None, + # max_pixels=2097152, + video_min_total_pixels=None, + video_max_total_pixels=None, + video_min_frames=None, + video_max_frames=None, + fps=None, + rand_video_max_frames=24, + add_vision_id=True, + system_message=None, + hash=None, + enable_3d_rope=False, + oss_loader_cfg=None, + debug=True, + oss_time_log_thr=10, +) + +# 2. dataset +from intern_s1_delivery.dataset.xpuyu_dataset_vl import RLTokenizeFnConfig +from intern_s1_delivery.dataset.xpuyu_dataset_vl import ( + parse_xpuyu_json_cfg as parse_xpuyu_json_cfg_vl, +) + +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig + + +def parse_xpuyu_json_cfg(path, max_prompt_length): + with open(path, "r") as f: + json_cfg = json.load(f) + converted_cfg = [] + for ds_name, ds_cfg in json_cfg.items(): + annotation = ds_cfg["annotation"] + tokenize_fn = ds_cfg["tokenize_fn"] + + if tokenize_fn == "RLClawTokenizeFnConfig": + tokenize_fn_cfg = RLClawTokenizeFnConfig + elif tokenize_fn == "RLTB2RLTokenizeFnConfig": + tokenize_fn_cfg = RLTB2RLTokenizeFnConfig + elif tokenize_fn == "RLTB2EvalTokenizeFnConfig": + tokenize_fn_cfg = RLTB2EvalTokenizeFnConfig + + if isinstance(annotation, str): + annotation = [annotation] + for ann in annotation: + converted_cfg.append( + { + "dataset": DatasetConfig( + name=ds_name, + anno_path=ann, + sample_ratio=ds_cfg["sample_ratio"], + class_name='JsonlDataset', + ), + "tokenize_fn": tokenize_fn_cfg( + root_path=ds_cfg.get("root_path", None), max_length=max_prompt_length + ), + } + ) + return converted_cfg + + +TRAIN_DATA_PATH = ( + '/mnt/shared-storage-user/llmit1/user/liujiangning/data/s1_1_rl_delivery_agent/exp_rl/rl_data_260126.json' +) +TEST_DATA_PATH = ( + '/mnt/shared-storage-user/llmit/user/wangziyi/projs/crg_rl_projects/data/deep_research_data/val_v2.jsonl' +) + +# data_path = os.environ["DATA_PATH"] +data_path = TRAIN_DATA_PATH +# eval_data_path = os.environ["EVAL_DATA_PATH"] +eval_data_path = TEST_DATA_PATH +train_dataset_cfg = parse_xpuyu_json_cfg( + '/mnt/shared-storage-user/llmit/user/wangziyi/projs/xtuner_agent_dev/examples/demo_data/agent_dev/tb2_rl/meta.json', + max_prompt_length, +) +eval_dataset_cfg = ( + [ + { + "dataset": DatasetConfig( + name="tb2-eval", + anno_path="/mnt/shared-storage-user/llmit1/user/liukuikun/delivery/data/tb2_eval_tasks.jsonl", + 
sample_ratio=1.0, + media_root=None, + class_name='JsonlDataset', + ), + "tokenize_fn": RLTB2EvalTokenizeFnConfig( + root_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/terminalbench2-harbor-p-cluster/terminal-bench-2", + max_length=max_prompt_length, + ), + }, + ] + if enable_evaluate + else [] +) +dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") + + +# 3. judger +judger_cfg = JudgerConfig( + reward_judger_configs=[ + CompassVerifierV2Config( + hosts=[ + "10.102.251.61:23333", + "10.102.251.61:23334", + "10.102.251.61:23335", + "10.102.251.61:23336", + "10.102.251.61:23337", + "10.102.251.61:23338", + "10.102.251.61:23339", + "10.102.251.61:23340", + "10.102.216.52:23333", + "10.102.216.52:23334", + "10.102.216.52:23335", + "10.102.216.52:23336", + "10.102.216.52:23337", + "10.102.216.52:23338", + "10.102.216.52:23339", + "10.102.216.52:23340", + "10.102.238.19:23333", + "10.102.238.19:23334", + "10.102.238.19:23335", + "10.102.238.19:23336", + "10.102.238.19:23337", + "10.102.238.19:23338", + "10.102.238.19:23339", + "10.102.238.19:23340", + "10.102.239.68:23333", + "10.102.239.68:23334", + "10.102.239.68:23335", + "10.102.239.68:23336", + "10.102.239.68:23337", + "10.102.239.68:23338", + "10.102.239.68:23339", + "10.102.239.68:23340", + ] + ) + ] +) + +from xtuner.v1.ray.environment.lagent.parsers import Qwen3_5FunctionCallParser + + +def prepare_agent_inputs_for_search(env, group_data_item: RLDataFlowItem): + env_agent = group_data_item.env.agent.extra_info['agent'].env_agent + user_prompt = group_data_item.data.messages[-1]['content'] + env_message = AgentMessage( + sender="env", content=user_prompt, uid=hashlib.md5(user_prompt.encode('utf-8')).hexdigest() + ) + if not env_agent.memory.get_memory(): + set_env_message = AgentMessage(sender="env", content=group_data_item) + env_agent.memory and env_agent.memory.add(set_env_message) # type: ignore[union-attr] + return (env_message,) + + +def convert_rollout_tractory_to_train_for_search(env, group_data_items): + agent_data_items, rollout_response_items, judger_response_items = [], [], [] + for i in range(len(group_data_items)): + history = group_data_items[i].env.rollout.extra_info['agent_state_dict']['policy_agent.memory'] + env_history = group_data_items[i].env.rollout.extra_info['agent_state_dict']['env_agent.memory'] + messages = group_data_items[i].env.rollout.extra_info['agent_message_dict']['policy_agent.messages'] + agent_data_items.append(RLAgentDataItem(extra_info=dict(messages=messages, state={"history": history}))) + rollout_response_items.append( + RLRolloutResponseItem( + response=history[-1]['raw_content'], + response_ids=history[-1]['raw_content_ids'], + logprobs=history[-1]['raw_content_logprobs'], + state=RolloutState.COMPLETED, + ) + ) + judger_response_items.append(RLJudgerResponseItem(reward=dict(score=env_history[-1]['reward']))) + group_data_items = update_dataflow_item(group_data_items, "env.agent", agent_data_items) + group_data_items = update_dataflow_item(group_data_items, "env.rollout", rollout_response_items) + group_data_items = update_dataflow_item(group_data_items, "env.judger", judger_response_items) + return group_data_items + + +def prepare_agent_inputs_for_tb2rl(env, group_data_item: RLDataFlowItem): + return group_data_item + + +def convert_rollout_tractory_to_train_for_tb2rl(env, group_data_items): + agent_data_items, rollout_response_items, judger_response_items = [], [], [] + for i in range(len(group_data_items)): + messages 
= group_data_items[i].env.agent.extra_info['message_dict']['policy_agent.messages']
+        tools = group_data_items[i].env.agent.extra_info['message_dict'].get('policy_agent.tools')
+        agent_data_items.append(RLAgentDataItem(extra_info=dict(messages=messages, tools=tools)))
+        # breakpoint()
+        rollout_response_items.append(
+            RLRolloutResponseItem(
+                response=messages[-1]['raw_content'],
+                response_ids=messages[-1]['raw_content_ids'],
+                logprobs=messages[-1]['raw_content_logprobs'],
+                state=RolloutState.COMPLETED,
+            )
+        )
+        reward_payload = group_data_items[i].env.judger.extra_info['total']
+        judger_response_items.append(RLJudgerResponseItem(reward=dict(score=reward_payload)))
+    group_data_items = update_dataflow_item(group_data_items, "env.agent", agent_data_items)
+    group_data_items = update_dataflow_item(group_data_items, "env.rollout", rollout_response_items)
+    group_data_items = update_dataflow_item(group_data_items, "env.judger", judger_response_items)
+    # breakpoint()
+    return group_data_items
+
+
+pg = AutoAcceleratorWorkers.build_placement_group(resources)
+rollout_controller = ray.remote(max_concurrency=1000)(RolloutController).remote(rollout_config, pg)
+load_checkpoint_cfg = LoadCheckpointConfig(load_optimizer_states=False, load_optimizer_args=False)
+
+search_tool = dict(
+    type=AsyncMCPClient,
+    name='SerperSearch',
+    server_type='http',
+    rate_limit=500.0,
+    # max_concurrency=128,
+    url=[
+        'http://10.102.103.157:8091/mcp',
+        'http://10.102.103.155:8096/mcp',
+        'http://10.102.103.155:8092/mcp',
+        'http://10.102.103.155:8095/mcp',
+        'http://10.102.103.155:8097/mcp',
+        'http://10.102.103.155:8098/mcp',
+        'http://10.102.103.155:8094/mcp',
+        'http://10.102.103.155:8093/mcp',
+    ],
+)
+browse_tool = dict(
+    type=AsyncMCPClient,
+    name='JinaBrowse',
+    server_type='http',
+    rate_limit=100.0,
+    max_concurrency=40,
+    url=[
+        'http://10.102.103.155:8104/mcp',
+        'http://10.102.103.155:8100/mcp',
+        'http://10.102.103.148:8101/mcp',
+        'http://10.102.103.157:8105/mcp',
+        'http://10.102.103.155:8099/mcp',
+        'http://10.102.103.155:8102/mcp',
+        'http://10.102.103.155:8103/mcp',
+        'http://10.102.103.155:8106/mcp',
+    ],
+)
+visit_tool = dict(
+    type=WebVisitor,
+    browse_tool=dict(
+        type=AsyncMCPClient,
+        name='JinaBrowse',
+        server_type='http',
+        rate_limit=100.0,
+        max_concurrency=40,
+        url=[
+            'http://10.102.103.155:8104/mcp',
+            'http://10.102.103.155:8100/mcp',
+            'http://10.102.103.148:8101/mcp',
+            'http://10.102.103.157:8105/mcp',
+            'http://10.102.103.155:8099/mcp',
+            'http://10.102.103.155:8102/mcp',
+            'http://10.102.103.155:8103/mcp',
+            'http://10.102.103.155:8106/mcp',
+        ],
+    ),
+    llm=dict(
+        type=ControllerWrapper,
+        rollout_controller=rollout_controller,
+        sample_params=SampleParams(max_tokens=max_response_length),
+        tool_call_parser=Qwen3_5FunctionCallParser(),
+    ),
+    truncate_browse_response_length=60000,
+    tokenizer_path=model_path,
+)
+
+tool_template = """# Tools
+
+You have access to the following functions:
+
+
+{tools}
+
+
+If you choose to call a function ONLY reply in the following format with NO suffix:
+
+
+
+value_1
+
+
+This is the value for the second parameter
+that can span
+multiple lines
+
+
+
+
+
+Reminder:
+- Function calls MUST follow the specified format: an inner block must be nested within XML tags
+- Required parameters MUST be specified
+- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
+- If there is no function call available, answer the question like normal with your current knowledge and do 
not tell the user about function calls +""" + +search_browse_tool_prompt = get_tool_prompt([search_tool, browse_tool], template=tool_template) +search_visit_tool_prompt = get_tool_prompt([search_tool, visit_tool], template=tool_template) + +train_agent_with_search_browse = dict( + type=FunctionCallAgent, + policy_agent=dict( + type=AsyncTokenInOutAgent, + llm=dict( + type=ControllerWrapper, + rollout_controller=rollout_controller, + sample_params=SampleParams(max_tokens=max_response_length), + tool_call_parser=Qwen3_5FunctionCallParser(), + ), + template=search_browse_tool_prompt, + ), + env_agent=dict( + type=EnvAgent, + actions=[search_tool, browse_tool], + judger=dict( + type=JudgerWrapper, + judger_cfg=judger_cfg, + placement_group=ray.get( + placement_group(bundles=[{"CPU": 1, "memory": 1024**3}], strategy="PACK").ready(), timeout=30 + ), + ), + max_turn=max_turn, + lower_tool_turn_bound=lower_tool_turn_bound, + enable_repeated_tool_call_penalty=enable_repeated_tool_call_penalty, + enable_no_thinking_penalty=enable_no_thinking_penalty, + max_tool_response_length=max_tool_response_length, + ), + finish_condition=finish_condition_func, + initialize_input=False, +) +train_agent_with_search_visit = dict( + type=FunctionCallAgent, + policy_agent=dict( + type=AsyncTokenInOutAgent, + llm=dict( + type=ControllerWrapper, + rollout_controller=rollout_controller, + sample_params=SampleParams(max_tokens=max_response_length), + tool_call_parser=Qwen3_5FunctionCallParser(), + ), + template=search_visit_tool_prompt, + ), + env_agent=dict( + type=EnvAgent, + actions=[search_tool, visit_tool], + judger=dict( + type=JudgerWrapper, + judger_cfg=judger_cfg, + placement_group=ray.get( + placement_group(bundles=[{"CPU": 1, "memory": 1024**3}], strategy="PACK").ready(), timeout=30 + ), + ), + max_turn=max_turn, + lower_tool_turn_bound=lower_tool_turn_bound, + enable_repeated_tool_call_penalty=enable_repeated_tool_call_penalty, + enable_no_thinking_penalty=enable_no_thinking_penalty, + max_tool_response_length=max_tool_response_length, + ), + finish_condition=finish_condition_func, + initialize_input=False, +) +eval_agent = dict( + type=FunctionCallAgent, + policy_agent=dict( + type=AsyncTokenInOutAgent, + llm=dict( + type=ControllerWrapper, + rollout_controller=rollout_controller, + sample_params=SampleParams(max_tokens=max_response_length), + tool_call_parser=Qwen3_5FunctionCallParser(), + ), + template=search_visit_tool_prompt, + ), + env_agent=dict( + type=EnvAgent, + actions=[search_tool, visit_tool], + judger=dict( + type=JudgerWrapper, + judger_cfg=judger_cfg, + placement_group=ray.get( + placement_group(bundles=[{"CPU": 1, "memory": 1024**3}], strategy="PACK").ready(), timeout=30 + ), + ), + max_turn=max_turn, + lower_tool_turn_bound=None, + enable_repeated_tool_call_penalty=False, + enable_no_thinking_penalty=False, + max_tool_response_length=max_tool_response_length, + ), + finish_condition=finish_condition_func, + initialize_input=False, +) + + +def bucket(key: str, n: int) -> int: + hashkey = hashlib.md5(key.encode('utf-8')).hexdigest() + return sum(ord(c) for c in hashkey) % n + + +def rollout_env_router_fn(item: RLDataFlowItem): + if item.data.extra_info.get('origin_data_source', '').startswith('gaia') or item.data.extra_info.get( + 'origin_data_source' + ) in [ + 'BrowseComp-ZH', + 'HLE', + 'browsecomp', + ]: + return 'eval' + + if 'claw-bench' in item.data.data_source: + return 'train_clawbench' + if 'tb2-rl' in item.data.data_source: + return 'train_tb2rl' + if 'tb2-eval' in 
item.data.data_source: + return 'eval_tb2eval' + + match bucket(item.data.messages[-1]['content'], 2): + case 0: + return 'train_agent_with_search_browse' + case 1: + return 'train_agent_with_search_visit' + + +environment_config = dict( + type=ComposedEnvironment, + environment=experimental_name, + rollout_controller=rollout_controller, + environments={ + 'train_agent_with_search_browse': dict( + type=AgentEnvironment, + environment='train_agent_with_search_browse', + agent_cfg=train_agent_with_search_browse, + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs_for_search, + postprocess_func=convert_rollout_tractory_to_train_for_search, + ), + 'train_agent_with_search_visit': dict( + type=AgentEnvironment, + environment='train_agent_with_search_visit', + agent_cfg=train_agent_with_search_visit, + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs_for_search, + postprocess_func=convert_rollout_tractory_to_train_for_search, + ), + 'eval': dict( + type=AgentEnvironment, + environment='eval', + agent_cfg=eval_agent, + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs_for_search, + postprocess_func=convert_rollout_tractory_to_train_for_search, + ), + 'train_tb2rl': dict( + type=InstallAgentEnvironment, + environment='train_tb2rl', + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs_for_tb2rl, + postprocess_func=convert_rollout_tractory_to_train_for_tb2rl, + ), + 'eval_tb2eval': dict( + type=InstallAgentEnvironment, + environment='eval_tb2eval', + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs_for_tb2rl, + postprocess_func=convert_rollout_tractory_to_train_for_tb2rl, + ), + }, + router=rollout_env_router_fn, +) + +# 4. dataflow and evaluator +dataflow_config = DataFlowConfig( + env=experimental_name, + max_concurrent=max_concurrent_groups, + enable_partial_rollout=enable_partial_rollout, + tail_batch_candidate_steps=tail_batch_candidate_steps, + staleness_threshold=staleness_threshold, + max_retry_times=3, + prompt_repeat_k=prompt_repeat_k, + global_batch_size=global_batch_size, + sample_params=training_sample_params, +) + +evaluator_cfg = ( + EvaluatorConfig( + enable_evaluate=enable_evaluate, + enable_initial_evaluate=enable_initial_evaluate, + dataset_cfg=eval_dataset_cfg, + tokenizer=model_path, + eval_sample_ratio=1.0, + evaluate_step=evaluate_step, + compute_metric_func=partial( + compute_metric, + source_normalizer={ + 'miroRL': 'websearch', + 'musique': 'websearch', + 'websailor': 'websearch', + 'webdancer': 'websearch', + 'gaia-level1': ('gaia-level1', 'gaia'), + 'gaia-level2': ('gaia-level2', 'gaia'), + 'gaia-level3': ('gaia-level3', 'gaia'), + }, + ), + sample_params=evaluation_sample_params, + max_concurrent=8192, + ) + if enable_evaluate + else None +) + + +def group_sample_filter_func(group_samples): + # filter all correct or all wrong sample + group_samples = [s for s in group_samples if s.env.rollout.response_ids is not None] + + # filter all same reward sample + rewards = [d.env.judger.reward["score"] for d in group_samples] + if len(set(rewards)) == 1: + print(f"filter all same reward sample: {rewards}") + return [] + return group_samples + + +replay_buffer_cfg = ReplayBufferConfig( + dataset_cfg=train_dataset_cfg, + dataloader_cfg=dataloader_config, + tokenizer=model_path, + # postprocessor_func=group_sample_filter_func, +) + +# # 5. 
Train worker +model_cfg = Qwen3_5_VLMoE35BA3Config( + freeze_vision=True, + freeze_projector=True, +) +model_cfg.float8_cfg = None +model_cfg.text_config.ep_size = 1 +model_cfg.text_config.z_loss_cfg = None +model_cfg.text_config.balancing_loss_cfg = None +model_cfg.text_config.freeze_routers = True +model_cfg.text_config.vocab_size = 251392 +# model_cfg.text_config.embed_grad_max_token_id = 251173 + +optim_cfg = AdamWConfig( + lr=lr, + betas=(0.9, 0.95), + max_grad_norm=1.0, + weight_decay=0.1, + foreach=False, + skip_grad_norm_threshold=0.9, + eps=1e-15, +) +loss_cfg = GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.2, + cliprange_low=0.2, + loss_type="intern_s1_delivery.modules.pg_loss.pg_loss_fn", + clip_ratio_c=5.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + kl_loss_type="low_var_kl", + mode="chunk", + chunk_size=512, + rollout_is=RolloutImportanceSampling( + rollout_is_level="token", + rollout_is_mode="both", + rollout_is_threshold=(5, 0), + rollout_is_mask_threshold=(5, 0.5), + rollout_is_veto_threshold=(20, 0), + ), +) +lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=lr) +fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=train_ep_size) +train_worker_cfg: WorkerConfig = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=optim_cfg, + loss_cfg=loss_cfg, + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=train_sp_size, + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, +) + + +# 6. RL Trainer +trainer = AgentRLTrainerConfig( + load_from=model_path, + pg=pg, + environment_config=environment_config, + dataflow_config=dataflow_config, + replay_buffer_config=replay_buffer_cfg, + train_worker_cfg=train_worker_cfg, + evaluator_config=evaluator_cfg, + tokenizer_path=model_path, + work_dir=work_dir, + total_epochs=total_epochs, + hf_interval=hf_interval, + skip_load_weights=skip_load_weights, + auto_resume=auto_resume, + checkpoint_interval=1, + checkpoint_maxkeep=1, + load_checkpoint_cfg=load_checkpoint_cfg, + checkpoint_no_save_optimizer=True, + skip_checkpoint_validation=True, +) + + +import torch.distributed as dist + +from xtuner.v1.train.agent_rl_trainer import AgentRLTrainer + +trainer = AgentRLTrainer.from_config(trainer) +trainer.fit() + +if dist.is_initialized(): + dist.destroy_process_group() diff --git a/examples/v1/config/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-badword-mtp4_agenticrl_tb2_mtp4_0503rc1.py b/examples/v1/config/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-badword-mtp4_agenticrl_tb2_mtp4_0503rc1.py new file mode 100644 index 0000000000..ea278af8f7 --- /dev/null +++ b/examples/v1/config/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-badword-mtp4_agenticrl_tb2_mtp4_0503rc1.py @@ -0,0 +1,1442 @@ +import os + +os.environ["XTUNER_USE_LMDEPLOY"] = "1" + +import hashlib +import json + +# os.environ["TRANSFORMERS_OFFLINE"] = "1" +from copy import deepcopy +from functools import partial + +import ray +from intern_s1_delivery.advantage.rloo_entropy_badword import ( + OverlongRLOOGroupEntropyBadwordAdvantageConfig, +) +from lagent.actions.mcp_client import AsyncMCPClient +from lagent.actions.web_visitor import WebVisitor +from lagent.agents.fc_agent import FunctionCallAgent, get_tool_prompt +from ray.util.placement_group import placement_group + +from claw_bench.claw_tokenize_fn import RLClawTokenizeFnConfig +from tb2_eval.tb2_eval_tokenize_fn import RLTB2EvalTokenizeFnConfig +from 
tb2_rl.tb2_rl_tokenize_fn import RLTB2RLTokenizeFnConfig
+from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
+from xtuner.v1.data_proto.rl_data import (
+    RLAgentDataItem,
+    RLDataFlowItem,
+    RLJudgerResponseItem,
+    RLRolloutResponseItem,
+    RolloutState,
+    SampleParams,
+    update_dataflow_item,
+)
+from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig
+from xtuner.v1.datasets.config import DataloaderConfig
+from xtuner.v1.model import Qwen3_5_VLMoE35BA3Config
+from xtuner.v1.module.mtp import MTPConfig
+from xtuner.v1.ray.base import (
+    AcceleratorResourcesConfig,
+    AutoAcceleratorWorkers,
+    CPUResourcesConfig,
+)
+from xtuner.v1.ray.config.worker import RolloutConfig
+from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig
+from xtuner.v1.ray.environment.agent_env import AgentEnvironment
+from xtuner.v1.ray.environment.composed_env import ComposedEnvironment
+from xtuner.v1.ray.environment.install_agent_env import InstallAgentEnvironment
+from xtuner.v1.ray.environment.lagent.agents import (
+    AsyncTokenInOutAgent,
+    EnvAgent,
+    JudgerWrapper,
+    finish_condition_func,
+)
+from xtuner.v1.ray.environment.lagent.llms.controller_wrapper import ControllerWrapper
+from xtuner.v1.ray.environment.lagent.schema import AgentMessage
+from xtuner.v1.ray.evaluator import EvaluatorConfig
+from xtuner.v1.ray.judger.compass_verifier_v2 import CompassVerifierV2Config
+from xtuner.v1.ray.judger.controller import JudgerConfig
+from xtuner.v1.ray.judger.frontierscience_judger import FrontierScienceJudgerConfig
+from xtuner.v1.ray.judger.review import ReviewJudgerConfig
+from xtuner.v1.ray.judger.sgi_judger import SGIJudgerConfig
+from xtuner.v1.ray.rollout import RolloutController
+from xtuner.v1.rl.base import WorkerConfig
+from xtuner.v1.rl.base.rollout_is import RolloutImportanceSampling
+from xtuner.v1.rl.grpo import GRPOLossConfig
+from xtuner.v1.train.agent_rl_trainer import AgentRLTrainerConfig
+from xtuner.v1.train.trainer import LoadCheckpointConfig
+from xtuner.v1.utils.compute_metric import compute_metric
+
+if not ray.is_initialized():
+    ray.init(ignore_reinit_error=True, runtime_env={"env_vars": {"RAY_DEBUG_POST_MORTEM": "0"}})
+
+experimental_name = os.path.basename(__file__).split(".py")[0]
+base_work_dir = "/mnt/shared-storage-user/llmit1/user/liukuikun/delivery/interns2_preview_0430rc9"
+work_dir = os.path.join(base_work_dir, experimental_name)
+model_name = os.environ["RL_LLM_MODEL"]
+model_path = '/mnt/shared-storage-user/llmit1/user/wangziyi/exp/mindcopilot_rl/work_dirs/ckpt/interns2-35ba3-base05-20260424a-rl-data260428rc0-56k-badword-mtp4-resume800/20260430074140/hf-40'
+stop_word = "<|im_end|>"
+
+# basic settings
+global_batch_size = 256
+prompt_repeat_k = 16
+max_concurrent_groups = 512
+
+max_prompt_length = 16 * 1024
+pack_max_length = 130 * 1024
+max_response_length = 128 * 1024
+
+train_ep_size = 1
+train_sp_size = 2
+rollout_tp_size = 4
+rollout_ep_size = 1
+fp32_lm_head = True
+enable_float8_rollout = False
+rollout_max_batch_size = 128
+max_prefill_token_num = 1024
+enable_return_routed_experts = True
+enable_partial_rollout = False
+staleness_threshold = 0.0
+tail_batch_candidate_steps = 0
+auto_resume = True
+skip_load_weights = True
+lr = 1e-6
+train_optimizer_steps = 8 # mini batch steps
+hf_interval = 5
+total_epochs = 10
+# evaluation settings
+enable_evaluate = True
+enable_initial_evaluate = True
+evaluate_step = 5
+
+# agent setting
+max_turn = 50 # maximum number of conversation turns
+lower_tool_turn_bound = 12 # penalize samples whose tool-call turns fall below this value; None disables the penalty
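+# science tasks use a smaller minimum tool-turn threshold than general search (8 vs. 12 turns)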
+lower_tool_turn_bound_science = 8
+enable_repeated_tool_call_penalty = True # whether to penalize samples that repeat the same tool call
+enable_no_thinking_penalty = False
+max_tool_response_length = 8192
+
+
+# 1. resources
+resources = AcceleratorResourcesConfig(
+    accelerator="GPU",
+    num_workers=64,
+    num_cpus_per_worker=12,
+    cpu_memory_per_worker=16 * 1024**3, # 16 GB
+)
+judger_cpu_resources = CPUResourcesConfig.from_total(total_cpus=64, num_workers=64, total_memory=512 * 1024**3)
+
+# 2. rollout
+rollout_config = RolloutConfig(
+    env=experimental_name,
+    device=resources.accelerator,
+    model_name=model_name,
+    model_path=model_path,
+    dtype="bfloat16",
+    tensor_parallel_size=rollout_tp_size,
+    expert_parallel_size=rollout_ep_size,
+    gpu_memory_utilization=0.7,
+    enable_float8=enable_float8_rollout,
+    skip_load_weights=skip_load_weights,
+    context_length=max_response_length,
+    rollout_max_batch_size_per_instance=rollout_max_batch_size,
+    # chunked_prefill_size=4096,
+    allow_over_concurrency_ratio=1.2,
+    rollout_timeout=1800,
+    enable_return_routed_experts=enable_return_routed_experts,
+    # max_prefill_token_num=max_prefill_token_num,
+    extra_rollout_config=dict(
+        lmdeploy_log_level="ERROR",
+        lmdeploy_uvicorn_log_level="ERROR",
+        lmdeploy_speculative_algorithm='qwen3_5_mtp',
+        lmdeploy_speculative_num_draft_tokens=4,
+    ),
+    fp32_lm_head=fp32_lm_head,
+)
+
+# sampling params
+training_sample_params = SampleParams(
+    max_tokens=max_response_length, top_k=0, top_p=0.999, temperature=1.0, min_tokens=0
+)
+evaluation_sample_params = deepcopy(training_sample_params)
+evaluation_sample_params.temperature = 0.8
+# evaluation_sample_params.max_tokens = max_response_length
+
+data_judger_mapping = {
+    "agent": {"compass_verifier_v2": 1.0},
+    "agent_science": {"compass_verifier_v2": 1.0},
+    'GAIA_sft_1229': {"compass_verifier_v2": 1.0},
+    'gaia-level1': {"compass_verifier_v2": 1.0},
+    'gaia-level2': {"compass_verifier_v2": 1.0},
+    'gaia-level3': {"compass_verifier_v2": 1.0},
+    'BrowseComp-ZH': {"compass_verifier_v2": 1.0},
+    'HLE': {"compass_verifier_v2": 1.0},
+    'browsecomp': {"compass_verifier_v2": 1.0},
+    'math': {"compass_verifier_v2": 1.0},
+    'AIME2024': {"compass_verifier_v2": 1.0},
+    'AIME2025': {"compass_verifier_v2": 1.0},
+    'aime2026': {"compass_verifier_v2": 1.0},
+    'hmmt26': {"compass_verifier_v2": 1.0},
+    'IMO-Bench-AnswerBench': {"compass_verifier_v2": 1.0},
+    'UGD_hard': {"compass_verifier_v2": 1.0},
+    'openreview': {"openreview": 1.0},
+    'openreview_test': {"openreview": 1.0},
+    'sgi-deep-research': {"sgi_judger": 1.0},
+    'frontierscience': {"frontierscience_judger": 1.0},
+}
+tokenize_fn_cfg = Qwen3VLTokenizeFnConfig(
+    processor_path=model_path,
+    min_pixels=None,
+    # max_pixels=None,
+    # max_pixels=2097152,
+    video_min_total_pixels=None,
+    video_max_total_pixels=None,
+    video_min_frames=None,
+    video_max_frames=None,
+    fps=None,
+    rand_video_max_frames=24,
+    add_vision_id=True,
+    system_message=None,
+    hash=None,
+    enable_3d_rope=False,
+    oss_loader_cfg=None,
+    debug=True,
+    oss_time_log_thr=10,
+    chat_template='qwen3-vl-rl',
+)
+
+from intern_s1_delivery.dataset.xpuyu_dataset_vl import (
+    RLTokenizeFnConfig,
+)
+from intern_s1_delivery.dataset.xpuyu_dataset_vl import (
+    parse_xpuyu_json_cfg as parse_xpuyu_json_cfg_vl,
+)
+
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+
+# 2. 
dataset +# from lagent_rl.datasets.parse_cfg import parse_xpuyu_json_cfg + + +def parse_xpuyu_json_cfg(path, max_prompt_length): + with open(path, "r") as f: + json_cfg = json.load(f) + converted_cfg = [] + for ds_name, ds_cfg in json_cfg.items(): + annotation = ds_cfg["annotation"] + tokenize_fn = ds_cfg["tokenize_fn"] + + if tokenize_fn == "RLClawTokenizeFnConfig": + tokenize_fn_cfg = RLClawTokenizeFnConfig + elif tokenize_fn == "RLTB2RLTokenizeFnConfig": + tokenize_fn_cfg = RLTB2RLTokenizeFnConfig + elif tokenize_fn == "RLTB2EvalTokenizeFnConfig": + tokenize_fn_cfg = RLTB2EvalTokenizeFnConfig + + if isinstance(annotation, str): + annotation = [annotation] + for ann in annotation: + converted_cfg.append( + { + "dataset": DatasetConfig( + name=ds_name, + anno_path=ann, + sample_ratio=ds_cfg["sample_ratio"], + class_name='JsonlDataset', + ), + "tokenize_fn": tokenize_fn_cfg( + root_path=ds_cfg.get("root_path", None), max_length=max_prompt_length + ), + } + ) + return converted_cfg + + +TRAIN_DATA_PATH_SCIENCE_SEARCH = "/mnt/shared-storage-user/llmit/user/liujiangning/projects/crg_rl_projects/src/lagent_rl/scripts/s2_preview_35b_agentrl/interns2_35ba3_b03_0413a_reasoningRL_scienceSearch0421a/scienceSearch0421a.json" +TRAIN_DATA_PATH_INTERNLM_SCIENCE = "/mnt/shared-storage-user/llmit/user/liujiangning/projects/crg_rl_projects/src/lagent_rl/scripts/s2_preview_35b_agentrl/interns2_35ba3_b03_0413a_reasoningRL_scienceSearch0423a/scienceSearch0423a.json" + +TRAIN_DATA_PATH_SEARCH = ( + '/mnt/shared-storage-user/llmit1/user/liujiangning/data/s1_1_rl_delivery_agent/exp_rl/rl_data_260126.json' +) +TRAIN_DATA_PATH_MATH = '/mnt/shared-storage-user/llmit1/user/liujiangning/data/s1_1_rl_delivery_agent/exp_rl/train_interns1-1_260124rc0_pure_math.json' +TEST_DATA_PATH_MATH = '/mnt/shared-storage-user/llmit1/user/liujiangning/data/s1_1_rl_delivery_agent/math_benchmark/val_python_toolcall.json' + + +train_dataset_cfg_science_search = parse_xpuyu_json_cfg_vl( + TRAIN_DATA_PATH_SCIENCE_SEARCH, tokenize_fn_cfg, max_prompt_length, data_judger_mapping +) +train_dataset_cfg_internlm_science = parse_xpuyu_json_cfg_vl( + TRAIN_DATA_PATH_INTERNLM_SCIENCE, tokenize_fn_cfg, max_prompt_length, data_judger_mapping +) + +train_dataset_cfg_search = parse_xpuyu_json_cfg_vl( + TRAIN_DATA_PATH_SEARCH, tokenize_fn_cfg, max_prompt_length, data_judger_mapping +) +train_dataset_cfg_math = parse_xpuyu_json_cfg_vl( + TRAIN_DATA_PATH_MATH, tokenize_fn_cfg, max_prompt_length, data_judger_mapping +) +train_dataset_cfg_review = [ + { + "dataset": DatasetConfig( + name='openreview', + anno_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/demo/xtuner/examples/demo_data/agent/openreview/train.jsonl", + sample_ratio=0.1, + media_root=None, + class_name='VLMJsonlDataset', + ), + "tokenize_fn": RLTokenizeFnConfig( + tokenize_fn_cfg=tokenize_fn_cfg, + system_prompt=None, + max_length=max_prompt_length, + data_judger_mapping=data_judger_mapping, + ignore_multimodal_info=False, + ), + } +] +train_dataset_cfg_tb2rl = parse_xpuyu_json_cfg( + '/mnt/shared-storage-user/llmit/user/wangziyi/projs/xtuner_agent_dev/examples/demo_data/agent_dev/tb2_rl/meta.json', + max_prompt_length, +) +train_dataset_cfg = ( + train_dataset_cfg_science_search + + train_dataset_cfg_internlm_science + + train_dataset_cfg_search + + train_dataset_cfg_math + + train_dataset_cfg_review + + train_dataset_cfg_tb2rl +) +eval_dataset_cfg_search = [ + { + "dataset": DatasetConfig( + name="gaia", + 
anno_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/crg_rl_projects/data/gaia_text_103.jsonl", + sample_ratio=4.0, + media_root=None, + class_name='VLMJsonlDataset', + ), + "tokenize_fn": RLTokenizeFnConfig( + tokenize_fn_cfg=tokenize_fn_cfg, + system_prompt=None, + max_length=max_prompt_length, + data_judger_mapping=data_judger_mapping, + ignore_multimodal_info=True, + ), + }, + { + "dataset": DatasetConfig( + name="browsecomp-zh", + anno_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/crg_rl_projects/data/browsecomp-zh.jsonl", + sample_ratio=0.0, + media_root=None, + class_name='VLMJsonlDataset', + ), + "tokenize_fn": RLTokenizeFnConfig( + tokenize_fn_cfg=tokenize_fn_cfg, + system_prompt=None, + max_length=max_prompt_length, + data_judger_mapping=data_judger_mapping, + ignore_multimodal_info=True, + ), + }, + { + "dataset": DatasetConfig( + name="browsecomp", + anno_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/crg_rl_projects/data/browsecomp.jsonl", + sample_ratio=1.0, + media_root=None, + class_name='VLMJsonlDataset', + ), + "tokenize_fn": RLTokenizeFnConfig( + tokenize_fn_cfg=tokenize_fn_cfg, + system_prompt=None, + max_length=max_prompt_length, + data_judger_mapping=data_judger_mapping, + ignore_multimodal_info=True, + ), + }, + { + "dataset": DatasetConfig( + name="hle", + anno_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/crg_rl_projects/data/hle.jsonl", + sample_ratio=1.0, + media_root=None, + class_name='VLMJsonlDataset', + ), + "tokenize_fn": RLTokenizeFnConfig( + tokenize_fn_cfg=tokenize_fn_cfg, + system_prompt=None, + max_length=max_prompt_length, + data_judger_mapping=data_judger_mapping, + ignore_multimodal_info=True, + ), + }, + { + "dataset": DatasetConfig( + name="sgi-deep-research", + anno_path="/mnt/shared-storage-user/llmit1/user/liujiangning/data/eval_benchmark_testset/sgi_deep_research_gaia_format.jsonl", + sample_ratio=1.0, + media_root=None, + class_name='VLMJsonlDataset', + ), + "tokenize_fn": RLTokenizeFnConfig( + tokenize_fn_cfg=tokenize_fn_cfg, + system_prompt=None, + max_length=max_prompt_length, + data_judger_mapping=data_judger_mapping, + ignore_multimodal_info=True, + ), + }, + { + "dataset": DatasetConfig( + name="frontierscience", + anno_path="/mnt/shared-storage-user/llmit1/user/liujiangning/data/eval_benchmark_testset/frontierscience_gaia_format.jsonl", + sample_ratio=1.0, + media_root=None, + class_name='VLMJsonlDataset', + ), + "tokenize_fn": RLTokenizeFnConfig( + tokenize_fn_cfg=tokenize_fn_cfg, + system_prompt=None, + max_length=max_prompt_length, + data_judger_mapping=data_judger_mapping, + ignore_multimodal_info=True, + ), + }, +] +eval_dataset_cfg_math = parse_xpuyu_json_cfg_vl( + TEST_DATA_PATH_MATH, tokenize_fn_cfg, max_prompt_length, data_judger_mapping, ignore_multimodal_info=True +) +eval_dataset_cfg_review = [ + { + "dataset": DatasetConfig( + name="openreview", + anno_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/demo/xtuner/examples/demo_data/agent/openreview/test.jsonl", + sample_ratio=1.0, + media_root=None, + class_name='VLMJsonlDataset', + ), + "tokenize_fn": RLTokenizeFnConfig( + tokenize_fn_cfg=tokenize_fn_cfg, + system_prompt=None, + max_length=max_prompt_length, + data_judger_mapping=data_judger_mapping, + ignore_multimodal_info=True, + ), + }, +] +eval_data_cfg_tb2eval = [ + { + "dataset": DatasetConfig( + name="tb2-eval", + anno_path="/mnt/shared-storage-user/llmit1/user/liukuikun/delivery/data/tb2_eval_tasks.jsonl", + sample_ratio=1.0, + media_root=None, + 
class_name='JsonlDataset', + ), + "tokenize_fn": RLTB2EvalTokenizeFnConfig( + root_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/terminalbench2-harbor-p-cluster/terminal-bench-2", + max_length=max_prompt_length, + ), + }, +] +# eval_dataset_cfg = ( +# (eval_dataset_cfg_search + eval_dataset_cfg_math + eval_dataset_cfg_review + eval_data_cfg_tb2eval) +# if enable_evaluate +# else [] +# ) +eval_dataset_cfg = eval_data_cfg_tb2eval +dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none") + + +# 3. judger +compass_judger_cfg = JudgerConfig( + enable_weighted_judgers=True, + reward_judger_configs=[ + CompassVerifierV2Config( + hosts=[ + "10.102.251.61:23333", + "10.102.251.61:23334", + "10.102.251.61:23335", + "10.102.251.61:23336", + "10.102.251.61:23337", + "10.102.251.61:23338", + "10.102.251.61:23339", + "10.102.251.61:23340", + "10.102.216.52:23333", + "10.102.216.52:23334", + "10.102.216.52:23335", + "10.102.216.52:23336", + "10.102.216.52:23337", + "10.102.216.52:23338", + "10.102.216.52:23339", + "10.102.216.52:23340", + "10.102.238.19:23333", + "10.102.238.19:23334", + "10.102.238.19:23335", + "10.102.238.19:23336", + "10.102.238.19:23337", + "10.102.238.19:23338", + "10.102.238.19:23339", + "10.102.238.19:23340", + "10.102.239.68:23333", + "10.102.239.68:23334", + "10.102.239.68:23335", + "10.102.239.68:23336", + "10.102.239.68:23337", + "10.102.239.68:23338", + "10.102.239.68:23339", + "10.102.239.68:23340", + ] + ), + SGIJudgerConfig( + hosts=[ + "10.102.213.32:30030", + "10.102.213.32:30031", + "10.102.213.32:30032", + "10.102.213.32:30033", + "10.102.213.32:30034", + "10.102.213.32:30035", + "10.102.213.32:30036", + "10.102.213.32:30037", + ], + model_name="/mnt/shared-storage-user/gpfs2-shared-public/huggingface/hub/models--Qwen--Qwen3.5-35B-A3B/snapshots/ec2d4ece1ffb563322cbee9a48fe0e3fcbce0307", + num_ray_actors=1, + request_timeout=60, + ), + FrontierScienceJudgerConfig( + hosts=[ + "10.102.213.32:30030", + "10.102.213.32:30031", + "10.102.213.32:30032", + "10.102.213.32:30033", + "10.102.213.32:30034", + "10.102.213.32:30035", + "10.102.213.32:30036", + "10.102.213.32:30037", + ], + model_name="/mnt/shared-storage-user/gpfs2-shared-public/huggingface/hub/models--Qwen--Qwen3.5-35B-A3B/snapshots/ec2d4ece1ffb563322cbee9a48fe0e3fcbce0307", + num_ray_actors=1, + request_timeout=60, + ), + ], +) +review_judger_cfg = JudgerConfig(reward_judger_configs=[ReviewJudgerConfig(judger_name="openreview")]) + +from xtuner.v1.ray.judger.controller import JudgerController + +compass_judger_controller = JudgerController.remote( + compass_judger_cfg, + ray.get( + placement_group( + bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs), + strategy="PACK", + ).ready(), + timeout=30, + ) +) +review_judger_controller = JudgerController.remote( + review_judger_cfg, + ray.get( + placement_group( + bundles=[{"CPU": 1, "memory": 1024**3}] * len(review_judger_cfg.reward_judger_configs), + strategy="PACK", + ).ready(), + timeout=30, + ) +) + + +from xtuner.v1.ray.environment.lagent.parsers import Qwen3_5FunctionCallParser + + +def prepare_agent_inputs(env, group_data_item: RLDataFlowItem): + env_agent = group_data_item.env.agent.extra_info.pop('agent').env_agent + user_prompt = group_data_item.data.messages[-1]['content'] + env_message = AgentMessage( + sender="env", content=user_prompt, uid=hashlib.md5(user_prompt.encode('utf-8')).hexdigest() + ) + if not env_agent.memory.get_memory(): + set_env_message 
= AgentMessage(sender="env", content=group_data_item) + env_agent.memory and env_agent.memory.add(set_env_message) # type: ignore[union-attr] + return (env_message,) + + +def convert_rollout_tractory_to_train(env, group_data_items): + agent_data_items, rollout_response_items, judger_response_items = [], [], [] + for i in range(len(group_data_items)): + history = group_data_items[i].env.rollout.extra_info['agent_state_dict']['policy_agent.memory'] + env_history = group_data_items[i].env.rollout.extra_info['agent_state_dict']['env_agent.memory'] + messages = group_data_items[i].env.rollout.extra_info['agent_message_dict']['policy_agent.messages'] + agent_data_items.append(RLAgentDataItem(extra_info=dict(messages=messages, state={"history": history}))) + rollout_response_items.append( + RLRolloutResponseItem( + response=history[-1]['raw_content'], + response_ids=history[-1]['raw_content_ids'], + logprobs=history[-1]['raw_content_logprobs'], + state=RolloutState.COMPLETED, + ) + ) + reward_payload = env_history[-1]['reward'] + if isinstance(reward_payload, dict): + if 'score' in reward_payload and group_data_items[i].data.extra_info.get('origin_data_source') in [ + 'openreview', + ]: # scale down the reward for review data + reward_payload['score'] = reward_payload['score'] / 10 + judger_response_items.append(RLJudgerResponseItem(reward=reward_payload)) + else: + judger_response_items.append(RLJudgerResponseItem(reward=dict(score=reward_payload))) + group_data_items = update_dataflow_item(group_data_items, "env.agent", agent_data_items) + group_data_items = update_dataflow_item(group_data_items, "env.rollout", rollout_response_items) + group_data_items = update_dataflow_item(group_data_items, "env.judger", judger_response_items) + return group_data_items + + +def prepare_agent_inputs_for_tb2rl(env, group_data_item: RLDataFlowItem): + return group_data_item + + +def convert_rollout_tractory_to_train_for_tb2rl(env, group_data_items): + agent_data_items, rollout_response_items, judger_response_items = [], [], [] + for i in range(len(group_data_items)): + messages = group_data_items[i].env.agent.extra_info['message_dict']['policy_agent.messages'] + tools = group_data_items[i].env.agent.extra_info['message_dict'].get('policy_agent.tools') + agent_data_items.append(RLAgentDataItem(extra_info=dict(messages=messages, tools=tools))) + # breakpoint() + rollout_response_items.append( + RLRolloutResponseItem( + response=messages[-1]['raw_content'], + response_ids=messages[-1]['raw_content_ids'], + logprobs=messages[-1]['raw_content_logprobs'], + state=RolloutState.COMPLETED, + ) + ) + reward_payload = group_data_items[i].env.judger.extra_info['total'] + judger_response_items.append(RLJudgerResponseItem(reward=dict(score=reward_payload))) + group_data_items = update_dataflow_item(group_data_items, "env.agent", agent_data_items) + group_data_items = update_dataflow_item(group_data_items, "env.rollout", rollout_response_items) + group_data_items = update_dataflow_item(group_data_items, "env.judger", judger_response_items) + # breakpoint() + return group_data_items + + +pg = AutoAcceleratorWorkers.build_placement_group(resources) +rollout_controller = ray.remote(max_concurrency=1000)(RolloutController).remote(rollout_config, pg) +load_checkpoint_cfg = LoadCheckpointConfig(load_optimizer_states=False, load_optimizer_args=False) + +search_tool = AsyncMCPClient( + # type=AsyncMCPClient, + name='SerperSearch', + server_type='http', + rate_limit=500.0, + # max_concurrency=128, + url=[ + 
'http://10.102.103.157:8091/mcp',
+        'http://10.102.103.155:8096/mcp',
+        'http://10.102.103.155:8092/mcp',
+        'http://10.102.103.155:8095/mcp',
+        'http://10.102.103.155:8097/mcp',
+        'http://10.102.103.155:8098/mcp',
+        'http://10.102.103.155:8094/mcp',
+        'http://10.102.103.155:8093/mcp',
+    ],
+)
+browse_tool = AsyncMCPClient(
+    # type=AsyncMCPClient,
+    name='JinaBrowse',
+    server_type='http',
+    rate_limit=100.0,
+    max_concurrency=40,
+    url=[
+        'http://10.102.103.155:8104/mcp',
+        'http://10.102.103.155:8100/mcp',
+        'http://10.102.103.148:8101/mcp',
+        'http://10.102.103.157:8105/mcp',
+        'http://10.102.103.155:8099/mcp',
+        'http://10.102.103.155:8102/mcp',
+        'http://10.102.103.155:8103/mcp',
+        'http://10.102.103.155:8106/mcp',
+    ],
+)
+visit_tool = WebVisitor(
+    # type=WebVisitor,
+    browse_tool=dict(
+        type=AsyncMCPClient,
+        name='JinaBrowse',
+        server_type='http',
+        rate_limit=100.0,
+        max_concurrency=40,
+        url=[
+            'http://10.102.103.155:8104/mcp',
+            'http://10.102.103.155:8100/mcp',
+            'http://10.102.103.148:8101/mcp',
+            'http://10.102.103.157:8105/mcp',
+            'http://10.102.103.155:8099/mcp',
+            'http://10.102.103.155:8102/mcp',
+            'http://10.102.103.155:8103/mcp',
+            'http://10.102.103.155:8106/mcp',
+        ],
+    ),
+    llm=dict(
+        type=ControllerWrapper,
+        rollout_controller=rollout_controller,
+        sample_params=SampleParams(max_tokens=max_response_length),
+        tool_call_parser=Qwen3_5FunctionCallParser(),
+    ),
+    truncate_browse_response_length=60000,
+    tokenizer_path=model_path,
+)
+arxiv_tool = AsyncMCPClient(
+    # type=AsyncMCPClient,
+    name='arxiv_search',
+    server_type='http',
+    rate_limit=50.0,
+    max_concurrency=20,
+    url=[
+        'http://10.102.252.176:2364/mcp',
+    ],
+)
+
+
+tool_template = """# Tools
+
+You have access to the following functions:
+
+
+{tools}
+
+
+If you choose to call a function ONLY reply in the following format with NO suffix:
+
+
+
+value_1
+
+
+This is the value for the second parameter
+that can span
+multiple lines
+
+
+
+
+
+Reminder:
+- Function calls MUST follow the specified format: an inner block must be nested within XML tags
+- Required parameters MUST be specified
+- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
+"""
+
+review_sys_prompt = r"""You are an expert reviewer in the field of pre-training, 3D human pose and shape estimation, and self-supervised representation learning for the ICLR 2023 conference. Your responsibility is to conduct an initial review. You must follow a strict reasoning process.
+
+### INTERACTION PROTOCOL
+You must perform a **Thought-Action** loop for every step. Do not rush to the final review.
+1. **THINK (``)**
+   - **Before Tool Use**: Analyze the paper step-by-step. Identify gaps in your knowledge. Formulate search queries. Explicitly check if the `end_date` constraint is met.
+   - **After Tool Response**: Analyze the search results. Synthesize the information. Decide if more searches are needed or if you have sufficient context to write the review.
+   - The reasoning content must be enclosed with `` and `` tags.
+
+2. **ACTION (``)**:
+   - If you need external information, output a tool call enclosed in `` and `` tags.
+   - The content inside the tags must be valid JSON format.
+   - **CRITICAL**: When calling `arxiv_search`, you MUST set the `end_date` argument to '20230603'.
+
+3. **FINALIZE (``)**:
+   - Only when you have completed all necessary research and reasoning, output the final review inside `` and `` tags.
+   - The review must include: Summary, Strengths, Weaknesses, Questions, and References.
+
+### CITATION & REFERENCE STANDARDS (CRITICAL)
+You must adhere to the following strict formatting rules for the final review:
+1. **Sequential Numbering**: Citations in the text must be numbered sequentially starting from [1] based on the order they first appear (e.g., [1], [2], [3]). **Do NOT use the ID returned by the search tool (e.g., do not use [81], [28]).**
+2. **Inline Citation**: Every external claim must have an inline citation. Example: 'Recent studies [1] have shown that...'
+3. **Reference List**: The 'References' section at the end must strictly match the inline citations. Each entry must contain:
+   - Format: `[ID] Authors. **Title**. Venue, Year. URL`
+   - Example: `[1] J. Smith et al. **Deep Learning**. NeurIPS, 2023. https://arxiv.org/...`
+   - Ensure the Title and URL are complete. The URL field in references must be the EXACT URL returned by the search tool. Do not use placeholders like '...'.
+
+### OUTPUT FORMAT (``)
+
+## Summary
+...
+## Strengths
+...
+## Weaknesses
+...
+## Questions
+...
+## References
+[1] ...
+[2] ...
+
+
+### CONSTRAINTS
+- Never output `` and `` in the same turn.
+- Strictly adhere to the submission deadline: do not use knowledge published after 20230603.
+- Do NOT hallucinate citations. You can ONLY cite papers that explicitly appear in the search results from the output. If you cannot find a relevant paper, do not cite a fake one.
+"""
+
+
+from lagent_rl.environment.lagent_ext.actions.python_executor import PythonExecutor
+
+python_action = PythonExecutor(
+    # type=PythonExecutor,
+    rate_limit_qps=500.0,
+    burst=20,
+    retries=5,
+    connect_timeout=5.0,
+    read_timeout=30.0
+)
+
+# tool prompts with python (for science search)
+search_browse_python_tool_prompt = get_tool_prompt([search_tool, browse_tool, python_action], template=tool_template)
+search_visit_python_tool_prompt = get_tool_prompt([search_tool, visit_tool, python_action], template=tool_template)
+# tool prompts without python (for pure search)
+search_browse_tool_prompt = get_tool_prompt([search_tool, browse_tool], template=tool_template)
+search_visit_tool_prompt = get_tool_prompt([search_tool, visit_tool], template=tool_template)
+# other tool prompts
+review_tool_prompt = get_tool_prompt([arxiv_tool], template=tool_template)
+python_tool_prompt = get_tool_prompt([python_action], template=tool_template)
+
+
+# ============================================================
+# Science search agents (with python_action) - for agent_science data
+# ============================================================
+train_science_search_browse_agent = dict(
+    type=FunctionCallAgent,
+    policy_agent=dict(
+        type=AsyncTokenInOutAgent,
+        llm=dict(
+            type=ControllerWrapper,
+            rollout_controller=rollout_controller,
+            sample_params=SampleParams(max_tokens=max_response_length),
+            tool_call_parser=Qwen3_5FunctionCallParser(),
+        ),
+        template=search_browse_python_tool_prompt,
+    ),
+    env_agent=dict(
+        type=EnvAgent,
+        actions=[search_tool, browse_tool, python_action],
+        judger=JudgerWrapper(
+            # type=JudgerWrapper,
+            # judger_cfg=compass_judger_cfg,
+            # placement_group=ray.get(
+            #     placement_group(
+            #         bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs),
+            #         strategy="PACK",
+            #     ).ready(),
+            #     timeout=30,
+            # ),
+            
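+            # NOTE: agents now share the single JudgerController actor built above, instead of constructing one judger per agent from judger_cfg + placement_group (the commented-out arguments).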
judger_controller=compass_judger_controller, + ), + max_turn=max_turn, + lower_tool_turn_bound=lower_tool_turn_bound_science, + enable_repeated_tool_call_penalty=enable_repeated_tool_call_penalty, + enable_no_thinking_penalty=enable_no_thinking_penalty, + max_tool_response_length=max_tool_response_length, + ), + finish_condition=finish_condition_func, + initialize_input=False, +) +train_science_search_visit_agent = dict( + type=FunctionCallAgent, + policy_agent=dict( + type=AsyncTokenInOutAgent, + llm=dict( + type=ControllerWrapper, + rollout_controller=rollout_controller, + sample_params=SampleParams(max_tokens=max_response_length), + tool_call_parser=Qwen3_5FunctionCallParser(), + ), + template=search_visit_python_tool_prompt, + ), + env_agent=dict( + type=EnvAgent, + actions=[search_tool, visit_tool, python_action], + judger=JudgerWrapper( + # type=JudgerWrapper, + # judger_cfg=compass_judger_cfg, + # placement_group=ray.get( + # placement_group( + # bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs), + # strategy="PACK", + # ).ready(), + # timeout=30, + # ), + judger_controller=compass_judger_controller, + ), + max_turn=max_turn, + lower_tool_turn_bound=lower_tool_turn_bound_science, + enable_repeated_tool_call_penalty=enable_repeated_tool_call_penalty, + enable_no_thinking_penalty=enable_no_thinking_penalty, + max_tool_response_length=max_tool_response_length, + ), + finish_condition=finish_condition_func, + initialize_input=False, +) + + +train_agent_with_search_browse = dict( + type=FunctionCallAgent, + policy_agent=dict( + type=AsyncTokenInOutAgent, + llm=dict( + type=ControllerWrapper, + rollout_controller=rollout_controller, + sample_params=SampleParams(max_tokens=max_response_length), + tool_call_parser=Qwen3_5FunctionCallParser(), + ), + template=search_browse_tool_prompt, + ), + env_agent=dict( + type=EnvAgent, + actions=[search_tool, browse_tool], + judger=JudgerWrapper( + # type=JudgerWrapper, + # judger_cfg=compass_judger_cfg, + # placement_group=ray.get( + # placement_group( + # bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs), + # strategy="PACK", + # ).ready(), + # timeout=30, + # ), + judger_controller=compass_judger_controller, + ), + max_turn=max_turn, + lower_tool_turn_bound=lower_tool_turn_bound, + enable_repeated_tool_call_penalty=enable_repeated_tool_call_penalty, + enable_no_thinking_penalty=enable_no_thinking_penalty, + max_tool_response_length=max_tool_response_length, + ), + finish_condition=finish_condition_func, + initialize_input=False, +) +train_agent_with_search_visit = dict( + type=FunctionCallAgent, + policy_agent=dict( + type=AsyncTokenInOutAgent, + llm=dict( + type=ControllerWrapper, + rollout_controller=rollout_controller, + sample_params=SampleParams(max_tokens=max_response_length), + tool_call_parser=Qwen3_5FunctionCallParser(), + ), + template=search_visit_tool_prompt, + ), + env_agent=dict( + type=EnvAgent, + actions=[search_tool, visit_tool], + judger=JudgerWrapper( + # type=JudgerWrapper, + # judger_cfg=compass_judger_cfg, + # placement_group=ray.get( + # placement_group( + # bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs), + # strategy="PACK", + # ).ready(), + # timeout=30, + # ), + judger_controller=compass_judger_controller, + ), + max_turn=max_turn, + lower_tool_turn_bound=lower_tool_turn_bound, + enable_repeated_tool_call_penalty=enable_repeated_tool_call_penalty, + enable_no_thinking_penalty=enable_no_thinking_penalty, + 
max_tool_response_length=max_tool_response_length, + ), + finish_condition=finish_condition_func, + initialize_input=False, +) +eval_search_agent = dict( + type=FunctionCallAgent, + policy_agent=dict( + type=AsyncTokenInOutAgent, + llm=dict( + type=ControllerWrapper, + rollout_controller=rollout_controller, + sample_params=SampleParams(max_tokens=max_response_length), + tool_call_parser=Qwen3_5FunctionCallParser(), + ), + template=search_visit_python_tool_prompt, + ), + env_agent=dict( + type=EnvAgent, + actions=[search_tool, visit_tool, python_action], + judger=JudgerWrapper( + # type=JudgerWrapper, + # judger_cfg=compass_judger_cfg, + # placement_group=ray.get( + # placement_group( + # bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs), + # strategy="PACK", + # ).ready(), + # timeout=30, + # ), + judger_controller=compass_judger_controller, + ), + max_turn=max_turn, + lower_tool_turn_bound=None, + enable_repeated_tool_call_penalty=False, + enable_no_thinking_penalty=False, + max_tool_response_length=max_tool_response_length, + ), + finish_condition=finish_condition_func, + initialize_input=False, +) + +train_math_agent = dict( + type=FunctionCallAgent, + policy_agent=dict( + type=AsyncTokenInOutAgent, + llm=dict( + type=ControllerWrapper, + rollout_controller=rollout_controller, + sample_params=SampleParams(max_tokens=max_response_length), + tool_call_parser=Qwen3_5FunctionCallParser(), + ), + template=python_tool_prompt, + ), + env_agent=dict( + type=EnvAgent, + actions=[python_action], + judger=JudgerWrapper( + # type=JudgerWrapper, + # judger_cfg=compass_judger_cfg, + # placement_group=ray.get( + # placement_group( + # bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs), + # strategy="PACK", + # ).ready(), + # timeout=30, + # ), + judger_controller=compass_judger_controller, + ), + max_turn=25, + lower_tool_turn_bound=5, + enable_repeated_tool_call_penalty=False, + enable_no_thinking_penalty=False, + max_tool_response_length=4096, + ), + finish_condition=finish_condition_func, + initialize_input=False, +) +eval_math_agent = dict( + type=FunctionCallAgent, + policy_agent=dict( + type=AsyncTokenInOutAgent, + llm=dict( + type=ControllerWrapper, + rollout_controller=rollout_controller, + sample_params=SampleParams(max_tokens=max_response_length), + tool_call_parser=Qwen3_5FunctionCallParser(), + ), + template=python_tool_prompt, + ), + env_agent=dict( + type=EnvAgent, + actions=[python_action], + judger=JudgerWrapper( + # type=JudgerWrapper, + # judger_cfg=compass_judger_cfg, + # placement_group=ray.get( + # placement_group( + # bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs), + # strategy="PACK", + # ).ready(), + # timeout=30, + # ), + judger_controller=compass_judger_controller, + ), + max_turn=25, + lower_tool_turn_bound=None, + enable_repeated_tool_call_penalty=False, + enable_no_thinking_penalty=False, + max_tool_response_length=4096, + ), + finish_condition=finish_condition_func, + initialize_input=False, +) + +train_review_agent = dict( + type=FunctionCallAgent, + policy_agent=dict( + type=AsyncTokenInOutAgent, + llm=dict( + type=ControllerWrapper, + rollout_controller=rollout_controller, + sample_params=SampleParams(max_tokens=max_response_length), + tool_call_parser=Qwen3_5FunctionCallParser(argument_type={'end_date': str}), + ), + template=review_tool_prompt + "\n\n" + review_sys_prompt, + ), + env_agent=dict( + type=EnvAgent, + actions=[arxiv_tool], + 
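# Note the parser above: argument_type={'end_date': str} forces the arxiv
+        # tool's end_date back to a string even when ast.literal_eval would turn
+        # "20230603" into an int. Simplified sketch of that coercion (mirrors
+        # parsers.py further down in this diff):
+        #
+        #   import ast
+        #   def coerce(name: str, raw: str, argument_type: dict):
+        #       try:
+        #           value = ast.literal_eval(raw)
+        #       except (ValueError, SyntaxError):
+        #           value = raw
+        #       if name in argument_type:
+        #           value = argument_type[name](value)  # e.g. 20230603 -> "20230603"
+        #       return value
+        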
judger=JudgerWrapper( + # type=JudgerWrapper, + # judger_cfg=review_judger_cfg, + # placement_group=ray.get( + # placement_group(bundles=[{"CPU": 1, "memory": 1024**3}], strategy="PACK").ready(), timeout=30 + # ), + judger_controller=review_judger_controller, + reward_key=None, + ), + max_turn=25, + lower_tool_turn_bound=None, + enable_repeated_tool_call_penalty=False, + enable_no_thinking_penalty=False, + max_tool_response_length=max_tool_response_length, + ), + finish_condition=finish_condition_func, + initialize_input=False, +) +eval_review_agent = dict( + type=FunctionCallAgent, + policy_agent=dict( + type=AsyncTokenInOutAgent, + llm=dict( + type=ControllerWrapper, + rollout_controller=rollout_controller, + sample_params=SampleParams(max_tokens=max_response_length), + tool_call_parser=Qwen3_5FunctionCallParser(argument_type={'end_date': str}), + ), + template=review_tool_prompt + "\n\n" + review_sys_prompt, + ), + env_agent=dict( + type=EnvAgent, + actions=[arxiv_tool], + judger=JudgerWrapper( + # type=JudgerWrapper, + # judger_cfg=review_judger_cfg, + # placement_group=ray.get( + # placement_group(bundles=[{"CPU": 1, "memory": 1024**3}], strategy="PACK").ready(), timeout=30 + # ), + judger_controller=review_judger_controller, + reward_key=None, + ), + max_turn=25, + lower_tool_turn_bound=None, + enable_repeated_tool_call_penalty=False, + enable_no_thinking_penalty=False, + max_tool_response_length=max_tool_response_length, + ), + finish_condition=finish_condition_func, + initialize_input=False, +) + + +def bucket(key: str, n: int) -> int: + hashkey = hashlib.md5(key.encode('utf-8')).hexdigest() + return sum(ord(c) for c in hashkey) % n + + +def rollout_env_router_fn(item: RLDataFlowItem): + source = item.data.extra_info.get('origin_data_source', '') + + # 1. Routing for Evaluation Environments + if source in ['openreview_test']: + return 'eval_review_agent' + if source.startswith('gaia') or source in [ + 'BrowseComp-ZH', + 'HLE', + 'browsecomp', + 'GAIA_sft_1229', + 'sgi-deep-research', + 'frontierscience', + ]: + return 'eval_search_agent' + if source in ['AIME2024', 'AIME2025', 'aime2026', 'hmmt26', 'IMO-Bench-AnswerBench', 'UGD_hard']: + return 'eval_math_agent' + if source == 'tb2-eval': + return 'eval_tb2eval' + + # 2. 
Routing for Train Environments + if source == 'openreview': + return 'train_review_agent' + elif source == 'math': + return 'train_math_agent' + elif source == 'claw-bench': + return 'train_clawbench' + elif source == 'tb2-rl': + return 'train_tb2rl' + elif source == 'agent_science': + # Science search data (with python_action) + match bucket(item.data.messages[-1]['content'], 2): + case 0: + return 'train_science_search_browse_agent' + case 1: + return 'train_science_search_visit_agent' + else: + # Default fallback to Search Train Environments + match bucket(item.data.messages[-1]['content'], 2): + case 0: + return 'train_agent_with_search_browse' + case 1: + return 'train_agent_with_search_visit' + + +environment_config = dict( + type=ComposedEnvironment, + environment=experimental_name, + rollout_controller=rollout_controller, + environments={ + 'train_science_search_browse_agent': dict( + type=AgentEnvironment, + environment='train_science_search_browse_agent', + agent_cfg=train_science_search_browse_agent, + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs, + postprocess_func=convert_rollout_tractory_to_train, + ), + 'train_science_search_visit_agent': dict( + type=AgentEnvironment, + environment='train_science_search_visit_agent', + agent_cfg=train_science_search_visit_agent, + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs, + postprocess_func=convert_rollout_tractory_to_train, + ), + 'train_agent_with_search_browse': dict( + type=AgentEnvironment, + environment='train_agent_with_search_browse', + agent_cfg=train_agent_with_search_browse, + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs, + postprocess_func=convert_rollout_tractory_to_train, + ), + 'train_agent_with_search_visit': dict( + type=AgentEnvironment, + environment='train_agent_with_search_visit', + agent_cfg=train_agent_with_search_visit, + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs, + postprocess_func=convert_rollout_tractory_to_train, + ), + 'eval_search_agent': dict( + type=AgentEnvironment, + environment='eval_search_agent', + agent_cfg=eval_search_agent, + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs, + postprocess_func=convert_rollout_tractory_to_train, + ), + 'train_math_agent': dict( + type=AgentEnvironment, + environment='train_math_agent', + agent_cfg=train_math_agent, + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs, + postprocess_func=convert_rollout_tractory_to_train, + ), + 'eval_math_agent': dict( + type=AgentEnvironment, + environment='eval_math_agent', + agent_cfg=eval_math_agent, + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs, + postprocess_func=convert_rollout_tractory_to_train, + ), + 'train_review_agent': dict( + type=AgentEnvironment, + environment='train_review_agent', + agent_cfg=train_review_agent, + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs, + postprocess_func=convert_rollout_tractory_to_train, + ), + 'eval_review_agent': dict( + type=AgentEnvironment, + environment='eval_review_agent', + agent_cfg=eval_review_agent, + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs, + postprocess_func=convert_rollout_tractory_to_train, + ), + 'train_tb2rl': dict( + type=InstallAgentEnvironment, + environment='train_tb2rl', + rollout_controller=rollout_controller, + preprocess_func=prepare_agent_inputs_for_tb2rl, + 
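# Aside on rollout_env_router_fn above: bucket() hashes the last user message
+            # with md5, so the browse/visit tool-variant split is deterministic per
+            # prompt: the same question lands in the same environment across retries
+            # and epochs. Quick standalone check:
+            #
+            #   import hashlib
+            #   def bucket(key: str, n: int) -> int:
+            #       hashkey = hashlib.md5(key.encode('utf-8')).hexdigest()
+            #       return sum(ord(c) for c in hashkey) % n
+            #   assert bucket("same question", 2) == bucket("same question", 2)
+            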
postprocess_func=convert_rollout_tractory_to_train_for_tb2rl,
+        ),
+        'eval_tb2eval': dict(
+            type=InstallAgentEnvironment,
+            environment='eval_tb2eval',
+            rollout_controller=rollout_controller,
+            preprocess_func=prepare_agent_inputs_for_tb2rl,
+            postprocess_func=convert_rollout_tractory_to_train_for_tb2rl,
+        ),
+    },
+    router=rollout_env_router_fn,
+)
+
+# 4. dataflow and evaluator
+dataflow_config = DataFlowConfig(
+    env=experimental_name,
+    max_concurrent=max_concurrent_groups,
+    enable_partial_rollout=enable_partial_rollout,
+    tail_batch_candidate_steps=tail_batch_candidate_steps,
+    staleness_threshold=staleness_threshold,
+    max_retry_times=3,
+    prompt_repeat_k=prompt_repeat_k,
+    global_batch_size=global_batch_size,
+    sample_params=training_sample_params,
+)
+
+evaluator_cfg = (
+    EvaluatorConfig(
+        enable_evaluate=enable_evaluate,
+        enable_initial_evaluate=enable_initial_evaluate,
+        dataset_cfg=eval_dataset_cfg,
+        tokenizer=model_path,
+        eval_sample_ratio=1.0,
+        evaluate_step=evaluate_step,
+        compute_metric_func=partial(
+            compute_metric,
+            source_normalizer={
+                'miroRL': 'websearch',
+                'musique': 'websearch',
+                'websailor': 'websearch',
+                'webdancer': 'websearch',
+                'gaia-level1': ('gaia-level1', 'gaia'),
+                'gaia-level2': ('gaia-level2', 'gaia'),
+                'gaia-level3': ('gaia-level3', 'gaia'),
+            },
+        ),
+        sample_params=evaluation_sample_params,
+        max_concurrent=8192,
+    )
+    if enable_evaluate
+    else None
+)
+
+
+def group_sample_filter_func(group_samples):
+    # drop samples whose rollout produced no response_ids (failed / aborted rollouts)
+    group_samples = [s for s in group_samples if s.env.rollout.response_ids is not None]
+
+    # drop groups whose rewards are all identical (all correct or all wrong):
+    # such groups carry no advantage signal for group-relative policy optimization
+    rewards = [d.env.judger.reward["score"] for d in group_samples]
+    if len(set(rewards)) == 1:
+        print(f"filter all same reward sample: {rewards}")
+        return []
+    return group_samples
+
+
+replay_buffer_cfg = ReplayBufferConfig(
+    dataset_cfg=train_dataset_cfg,
+    dataloader_cfg=dataloader_config,
+    tokenizer=model_path,
+    # postprocessor_func=group_sample_filter_func,
+)
+
+# 5. 
Train worker
+model_cfg = Qwen3_5_VLMoE35BA3Config(
+    freeze_vision=True,
+    freeze_projector=True,
+)
+model_cfg.float8_cfg = None
+model_cfg.text_config.ep_size = 1
+model_cfg.text_config.z_loss_cfg = None
+model_cfg.text_config.balancing_loss_cfg = None
+model_cfg.text_config.freeze_routers = True
+model_cfg.text_config.mtp_config = MTPConfig(
+    num_layers=4,
+    loss_scaling_factor=1.0,
+    detach_mtp_lm_head_weight=True,
+    detach_mtp_inputs=True,
+    share_weights=True,
+)
+model_cfg.text_config.vocab_size = 251392
+# model_cfg.text_config.embed_grad_max_token_id = 251173
+
+optim_cfg = AdamWConfig(
+    lr=lr,
+    betas=(0.9, 0.95),
+    max_grad_norm=1.0,
+    weight_decay=0.1,
+    foreach=False,
+    skip_grad_norm_threshold=5,
+    eps=1e-15,
+)
+loss_cfg = GRPOLossConfig(
+    policy_loss_cfg=dict(
+        cliprange_high=0.2,
+        cliprange_low=0.2,
+        loss_type="intern_s1_delivery.modules.pg_loss.pg_loss_fn",
+        clip_ratio_c=5.0,
+        log_prob_diff_min=-20.0,
+        log_prob_diff_max=20.0,
+    ),
+    ignore_idx=-100,
+    use_kl_loss=False,
+    kl_loss_coef=0.0,
+    kl_loss_type="low_var_kl",
+    mode="chunk",
+    chunk_size=512,
+    rollout_is=RolloutImportanceSampling(
+        rollout_is_level="token",
+        rollout_is_mode="both",
+        rollout_is_threshold=(5, 0),
+        rollout_is_mask_threshold=(5, 0.5),
+        rollout_is_veto_threshold=(20, 0),
+    ),
+)
+lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=lr)
+fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=train_ep_size, fp32_lm_head=fp32_lm_head)
+train_worker_cfg: WorkerConfig = WorkerConfig(
+    model_cfg=model_cfg,
+    load_from=model_path,
+    optim_cfg=optim_cfg,
+    loss_cfg=loss_cfg,
+    lr_cfg=lr_cfg,
+    fsdp_cfg=fsdp_cfg,
+    sp_size=train_sp_size,
+    optimizer_steps=train_optimizer_steps,
+    pack_max_length=pack_max_length,
+)
+
+
+# 6. RL Trainer
+trainer = AgentRLTrainerConfig(
+    load_from=model_path,
+    pg=pg,
+    environment_config=environment_config,
+    dataflow_config=dataflow_config,
+    replay_buffer_config=replay_buffer_cfg,
+    train_worker_cfg=train_worker_cfg,
+    evaluator_config=evaluator_cfg,
+    tokenizer_path=model_path,
+    work_dir=work_dir,
+    total_epochs=total_epochs,
+    hf_interval=hf_interval,
+    skip_load_weights=skip_load_weights,
+    auto_resume=auto_resume,
+    checkpoint_interval=2,
+    checkpoint_maxkeep=1,
+    load_checkpoint_cfg=load_checkpoint_cfg,
+    checkpoint_no_save_optimizer=True,
+    skip_checkpoint_validation=True,
+    advantage_estimator_config=OverlongRLOOGroupEntropyBadwordAdvantageConfig(
+        entropy_upper_bound=0.65,
+        entropy_lower_bound=0.25,
+        tau_upper=0.0,
+        tau_lower=0.0,
+        coeff_min_upper=0.2,
+        coeff_min_lower=0.5,
+        overlong_filer=True,
+        badword_ratio_cost_factor=1.0,
+        tokenizer_path=model_path,
+    ),
+)
+
+
+import torch.distributed as dist
+
+from xtuner.v1.train.agent_rl_trainer import AgentRLTrainer
+
+trainer = AgentRLTrainer.from_config(trainer)
+trainer.fit()
+
+if dist.is_initialized():
+    dist.destroy_process_group()
diff --git a/examples/v1/scripts/rjob_run_train_interns1_1.sh b/examples/v1/scripts/rjob_run_train_interns1_1.sh
new file mode 100755
index 0000000000..e1e285d9bf
--- /dev/null
+++ b/examples/v1/scripts/rjob_run_train_interns1_1.sh
@@ -0,0 +1,20 @@
+JOB_NAME=xtuner_router_reply_interns2_preview_train
+NUM_NODES=8
+WORKER_GPU=8
+CPU_NUMS=12
+NODE_MEMS=960
+IMAGE=registry.h.pjlab.org.cn/ailab-puyu/xpuyu:torch-2.7.0-076676dd-0708
+# rjob delete ${JOB_NAME}
+clusterx run \
+--job-name=${JOB_NAME} \
+--gpus-per-task=${WORKER_GPU} \
+--memory-per-task=$NODE_MEMS \
+--cpus-per-task=$((CPU_NUMS * (WORKER_GPU > 0 ? 
WORKER_GPU : 1))) \ +--image=${IMAGE} \ +--num-nodes ${NUM_NODES} \ +--priority 9 \ +--partition puyullm_gpu \ +--project-name ailab-puyullmgpu \ +--no-env \ +/mnt/shared-storage-user/llmit/user/liukuikun/workspace/xtuner/examples/v1/scripts/train_agentrl.sh +# "zsh -exc '/mnt/shared-storage-user/llmit/user/liukuikun/workspace/crg_rl_projects_router_reply/scripts/run_train_interns1_1.sh'" diff --git a/examples/v1/scripts/train_agentrl.sh b/examples/v1/scripts/train_agentrl.sh index a19cd67e47..8afeb8fc54 100755 --- a/examples/v1/scripts/train_agentrl.sh +++ b/examples/v1/scripts/train_agentrl.sh @@ -11,13 +11,14 @@ export TRANSFORMERS_OFFLINE=1 export HF_EVALUATE_OFFLINE=1 export HF_HUB_OFFLINE=1 +# lmdeploy_dir=/mnt/shared-storage-user/llmit/user/lvchengqi/projects/interns2_rl/mtp_rl_dev/lmdeploy lmdeploy_dir=/mnt/shared-storage-user/llmit/user/lvchengqi/projects/interns2_rl/mtp_rl_dev/lmdeploy xtuner_dir=/mnt/shared-storage-user/llmit/user/liukuikun/workspace/xtuner -intern_s2_delivery_dir=/mnt/shared-storage-user/llmit/user/lvchengqi/projects/interns2_rl/crg_rl_projects/src +intern_s2_delivery_dir=/mnt/shared-storage-user/llmit/user/liujiangning/projects/interns2_preview_agentrl_mtp/crg_rl_projects/src lagent_dir=/mnt/shared-storage-user/llmit/user/liukuikun/workspace/lagent xtuner_project_dir=/mnt/shared-storage-user/llmit/user/liukuikun/workspace/xtuner/projects - -export PYTHONPATH=$intern_s2_delivery_dir:$lmdeploy_dir:$xtuner_dir:$lagent_dir:$xtuner_project_dir:$PYTHONPATH +transformer_model=/mnt/shared-storage-user/llmit1/user/wangziyi/.cache/huggingface/modules +export PYTHONPATH=$intern_s2_delivery_dir:$lmdeploy_dir:$xtuner_dir:$lagent_dir:$xtuner_project_dir:$transformer_model:$PYTHONPATH export NLTK_DATA=/mnt/shared-storage-user/llmit/user/lishuaibin/mv2yidian/nltk_data export XTUNER_USE_FA3=1 @@ -29,7 +30,7 @@ export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' export XTUNER_USE_SGLANG=0 export XTUNER_USE_LMDEPLOY=1 export XTUNER_USE_VLLM=0 -export RL_LLM_MODEL=train_lkk_test_0429rc1 +export RL_LLM_MODEL=train_lkk_test_$(date "+%m%d%H%M%S") export MASTER_PORT=6000 export WORLD_SIZE=1 export LMD_SKIP_WARMUP=1 @@ -51,8 +52,8 @@ export TRAIN_OPTIMIZER_STEPS=8 current_time=$(date "+%m%d%H") -export CONFIG_PATH='/mnt/shared-storage-user/llmit/user/liukuikun/workspace/xtuner/examples/v1/config/interns2-35ba3-base03-20260413a-websearch-rl0415rc1_local.py' -export WORK_DIR='/mnt/shared-storage-user/llmit1/user/liukuikun/delivery/interns2_preview_0429' +export CONFIG_PATH='/mnt/shared-storage-user/llmit/user/liukuikun/workspace/xtuner/examples/v1/config/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-badword-mtp4_agenticrl_tb2_mtp4_0503rc1.py' +export WORK_DIR='/mnt/shared-storage-user/llmit1/user/liukuikun/delivery/interns2_preview_0430rc9' if [ ! 
-d "$WORK_DIR" ]; then diff --git a/my_changes.patch b/my_changes.patch new file mode 100644 index 0000000000..28c89d7ed9 --- /dev/null +++ b/my_changes.patch @@ -0,0 +1,62 @@ +diff --git a/xtuner/v1/ray/environment/lagent/tokenize.py b/xtuner/v1/ray/environment/lagent/tokenize.py +index 41ae95ec..27074cf5 100644 +--- a/xtuner/v1/ray/environment/lagent/tokenize.py ++++ b/xtuner/v1/ray/environment/lagent/tokenize.py +@@ -101,7 +101,7 @@ def tokenize( + f"Expected routed_experts_ref to be a base64 string, but got {type(routed_experts_ref)}" + ) + ref_bytes = base64.b64decode(routed_experts_ref.encode("utf-8")) +- routed_experts = cloudpickle.loads(ref_bytes) ++ routed_experts = ref_bytes + else: + routed_experts = None + +diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py +index fff43245..ad128f17 100644 +--- a/xtuner/v1/ray/rollout/worker.py ++++ b/xtuner/v1/ray/rollout/worker.py +@@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, List, Optional, Union + import httpx + import numpy as np + import ray ++from ray import ObjectRef, cloudpickle + import requests # type: ignore[import-untyped] + from packaging.version import Version + from ray import ObjectRef +@@ -584,8 +585,13 @@ class RolloutWorker(SingleAcceleratorWorker): + else: + cur_routed_experts = routed_experts + +- history_routed_experts = await input_extra_info["routed_experts"] # n, layer, expert +- ray.internal.free(input_extra_info["routed_experts"], local_only=False) ++ history_routed_experts_key = input_extra_info["routed_experts"] ++ if isinstance(history_routed_experts_key, ObjectRef): ++ history_routed_experts = await input_extra_info["routed_experts"] # n, layer, expert ++ elif isinstance(history_routed_experts_key, str): ++ history_routed_experts_key = cloudpickle.loads(history_routed_experts_key) ++ history_routed_experts = await history_routed_experts_key ++ ray.internal.free(history_routed_experts_key, local_only=False) + del input_extra_info + + assert (history_routed_experts.shape[0] - 1) > 0 and history_routed_experts.shape[ +diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py +index 6ab83524..57a4feae 100644 +--- a/xtuner/v1/rl/base/worker.py ++++ b/xtuner/v1/rl/base/worker.py +@@ -7,6 +7,7 @@ from pathlib import Path + from typing import Dict, Iterable, List, Sequence, TypeAlias, TypedDict, cast + + import ray ++from ray import cloudpickle + import requests + import torch + import torch.distributed as dist +@@ -422,7 +423,7 @@ class TrainingWorker(SingleAcceleratorWorker): + ) + out_rollout_routed_expert.append(rollout_routed_experts_tensor) + else: +- rollout_routed_expert_refs = rollout_routed_expert ++ rollout_routed_expert_refs = cloudpickle.loads(rollout_routed_expert) + rollout_routed_expert = ray.get(rollout_routed_expert_refs) + # free obj store explicitly + if self.sp_mesh is None or self.sp_mesh.size() == 1: diff --git a/projects/claw_bench/agents/internclaw/config.py b/projects/claw_bench/agents/internclaw/config.py index c59f241f8b..d90dc424e1 100644 --- a/projects/claw_bench/agents/internclaw/config.py +++ b/projects/claw_bench/agents/internclaw/config.py @@ -2,7 +2,7 @@ import os -workspace = os.environ.get("TASK_WORKSPACE", "/workspace") +workspace = os.environ.get("TASK_WORKSPACE", "") skills_root = f"{workspace}/skills" model = dict( @@ -10,7 +10,7 @@ model=dict( model=os.environ.get( "RL_LLM_MODEL", - "sft_interns2_pre_base03_20260413a_lr2e5_128gpu_hf5646", + "train_lkk_test", ), base_url=os.environ.get( "RL_LLM_BASE_URL", diff --git 
a/projects/claw_bench/pipeline.py b/projects/claw_bench/pipeline.py index 9033ec6dd8..9bef56309b 100644 --- a/projects/claw_bench/pipeline.py +++ b/projects/claw_bench/pipeline.py @@ -13,13 +13,14 @@ from xtuner.v1.ray.environment.rl_task.hooks import ( BenchEnv, + DumpDaemonLogOnFailure, InstallLagent, ParseJudgerStdout, PickAgent, RenderInstruction, RunAgentInstallDeps, + UploadAgentConfigSource, UploadChosenAgent, - WriteAgentConfig, ) from xtuner.v1.ray.environment.rl_task.judgers import Judger from xtuner.v1.ray.environment.rl_task.runner import Runner @@ -63,7 +64,7 @@ PATHS = SimpleNamespace( wrappers_bench="/tmp/wrappers/claw_bench", wrappers_lagent="/tmp/wrappers/lagent", - agent_config="/tmp/agent_config.json", + agent_config="/tmp/agent_config.py", trajectory="/tmp/trajectory.json", message="/tmp/message.json", verifier="/tmp/verifier", @@ -204,8 +205,8 @@ def claw_pipeline( ExecHook(f"mkdir -p {ws}"), # 7. Render instruction.md: `workspace/` → abs path + {{KEY}} → env. # RenderInstruction(rewrites=CLAW_INSTRUCTION_REWRITES), - # 8. Exec chosen agent's config.py on host → upload resulting JSON. - WriteAgentConfig(dst=PATHS.agent_config), + # 8. Upload chosen agent's config.py — daemon execs it in-sandbox. + UploadAgentConfigSource(dst=PATHS.agent_config), # 9. Run install-deps.sh if the chosen agent template has one. RunAgentInstallDeps(workspace=ws), ], @@ -216,7 +217,7 @@ def claw_pipeline( extras={"WORKSPACE": ws, "CLAW_WORKSPACE": ws}, ), timeout=1800, - post=[DownloadHook(["/workspace", "/tmp/agent_response.txt"]), ReadFileHook("/tmp/message.json", "message")], + post=[DownloadHook(["/workspace", "/tmp/agent_response.txt"]), ReadFileHook("/tmp/message.json", "message"), DumpDaemonLogOnFailure()], ) return Runner( diff --git a/projects/claw_bench/wrappers/lagent/lagent_entry.sh b/projects/claw_bench/wrappers/lagent/lagent_entry.sh index b92ec92476..5c940ddc4f 100755 --- a/projects/claw_bench/wrappers/lagent/lagent_entry.sh +++ b/projects/claw_bench/wrappers/lagent/lagent_entry.sh @@ -14,7 +14,8 @@ # state_dict → kills daemon. Exits 0 on success, non-zero on failure. # # Relies on /tmp/lagent-py wrapper (shared conda python + PYTHONPATH), which -# runner bootstrap writes. Agent config JSON is what daemon consumes. +# runner bootstrap writes. --config is a Python file defining agent_config; +# daemon execs it in-sandbox so os.environ lookups resolve here, not on host. # --------------------------------------------------------------------------- set -uo pipefail @@ -55,6 +56,26 @@ fi LOG=/tmp/agent_daemon.log : > "$LOG" +# If a daemon call returned {"error": "..."} over socket (exception during +# dispatch — e.g. LLM 4xx/5xx, timeout), exit non-zero so the host side's +# _dump_daemon_log fires and the traceback surfaces in xtuner logs. +_die_on_daemon_error() { + local resp="$1" phase="$2" code="$3" + local err + err=$(printf '%s' "$resp" | "$LAGENT_PY" -c ' +import json, sys +try: + print(json.loads(sys.stdin.read() or "{}").get("error", "")) +except Exception: + pass +') + if [ -n "$err" ]; then + echo "daemon error in ${phase}: ${err}" >&2 + tail -n 500 "$LOG" >&2 || true + exit "$code" + fi +} + # ── 1. Start AgentDaemon ────────────────────────────────────────────── nohup "$LAGENT_PY" -m lagent.serving.sandbox.daemon start \ --mode agent \ @@ -102,6 +123,7 @@ CHAT_RESP=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \ tail -n 100 "$LOG" >&2 || true exit 5 } +_die_on_daemon_error "$CHAT_RESP" chat 5 # Extract final response content (plain text) to RESPONSE_OUT. 
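+# (Protocol note on _die_on_daemon_error above: a failed dispatch is assumed
+# to come back as a one-line JSON envelope over the socket, e.g.
+#     {"error": "RuntimeError: LLM call failed with 503"}
+# rather than as a transport error, while a healthy call returns a JSON body
+# without an "error" key. The helper only inspects that key, so any envelope
+# lacking it passes through untouched.)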
printf '%s' "$CHAT_RESP" | "$LAGENT_PY" -c ' @@ -121,6 +143,7 @@ STATE_RESP=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \ tail -n 100 "$LOG" >&2 || true exit 6 } +_die_on_daemon_error "$STATE_RESP" state_dict 6 # Wrap lagent's native state into {"trajectory": [...]} if it isn't already. printf '%s' "$STATE_RESP" | "$LAGENT_PY" -c ' @@ -153,6 +176,7 @@ POLICY_AGENT_MESSAGES=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \ tail -n 100 "$LOG" >&2 || true exit 7 } +_die_on_daemon_error "$POLICY_AGENT_MESSAGES" get_messages 7 printf '%s' "$POLICY_AGENT_MESSAGES" | "$LAGENT_PY" -c ' import json, sys diff --git a/projects/tb2_eval/agents/interndp/config.py b/projects/tb2_eval/agents/interndp/config.py index 71037ae0da..b0efb194f4 100644 --- a/projects/tb2_eval/agents/interndp/config.py +++ b/projects/tb2_eval/agents/interndp/config.py @@ -43,7 +43,7 @@ model=dict( model=os.environ.get( "RL_LLM_MODEL", - "train_lkk_test", + "", ), base_url=os.environ.get( "RL_LLM_BASE_URL", @@ -52,8 +52,8 @@ api_key=os.environ.get("RL_LLM_API_KEY", "sk-admin"), ), sample_params=dict(temperature=0.7, top_p=1.0, top_k=50), - timeout=600, - max_retry=500, + timeout=900, + max_retry=1, sleep_interval=5, extra_body=dict(spaces_between_special_tokens=False), ) @@ -70,8 +70,10 @@ env_agent = dict( type="lagent.agents.env_agent.RLEnvAgent", actions=base_actions, - max_turn=25, + max_turn=100, enable_no_thinking_penalty=False, + max_tool_response_length=4096, + tool_response_truncate_side="left", ) agent_config = dict( diff --git a/projects/tb2_eval/pipeline.py b/projects/tb2_eval/pipeline.py index 008e2cc5e6..5d958f6a2b 100644 --- a/projects/tb2_eval/pipeline.py +++ b/projects/tb2_eval/pipeline.py @@ -27,12 +27,13 @@ from xtuner.v1.ray.environment.rl_task.hooks import ( BenchEnv, + DumpDaemonLogOnFailure, InstallLagent, ParseJudgerStdout, PickAgent, RunAgentInstallDeps, + UploadAgentConfigSource, UploadChosenAgent, - WriteAgentConfig, ) from xtuner.v1.ray.environment.rl_task.judgers import Judger from xtuner.v1.ray.environment.rl_task.runner import Runner @@ -58,7 +59,7 @@ PATHS = SimpleNamespace( wrappers_bench="/tmp/wrappers/tb2_eval", wrappers_lagent="/tmp/wrappers/lagent", - agent_config="/tmp/agent_config.json", + agent_config="/tmp/agent_config.py", trajectory="/tmp/trajectory.json", message="/tmp/message.json", tests="/tests", @@ -91,7 +92,7 @@ ] # Placeholder image — each task overrides this via sandbox_spec in extra_info. -DEFAULT_SANDBOX = SandboxSpec(image="tb2-eval-placeholder", ttl_seconds=1800, workspace_path="/app") +DEFAULT_SANDBOX = SandboxSpec(image="tb2-eval-placeholder", ttl_seconds=11700, workspace_path="/app") # ───────────────────────────────────────────────────────────────── @@ -205,8 +206,8 @@ def tb2_eval_pipeline( UploadChosenAgent(target_dir=f"{ws}/agent/"), # 7. Ensure workspace dir exists. ExecHook(f"mkdir -p {ws}"), - # 8. Exec chosen agent's config.py on host → upload resulting JSON. - WriteAgentConfig(dst=PATHS.agent_config), + # 8. Upload chosen agent's config.py — daemon execs it in-sandbox. + UploadAgentConfigSource(dst=PATHS.agent_config), # 9. Run install-deps.sh if the chosen agent template has one. 
RunAgentInstallDeps(workspace=ws), ], @@ -215,10 +216,11 @@ def tb2_eval_pipeline( workspace=ws, extras={"WORKSPACE": ws}, ), - timeout=900, + timeout=10800, post=[ DownloadHook([ws, "/tmp/agent_response.txt"]), ReadFileHook("/tmp/message.json", "message"), + DumpDaemonLogOnFailure(), ], ) diff --git a/projects/tb2_eval/wrappers/lagent/lagent_entry.sh b/projects/tb2_eval/wrappers/lagent/lagent_entry.sh index b92ec92476..5c940ddc4f 100644 --- a/projects/tb2_eval/wrappers/lagent/lagent_entry.sh +++ b/projects/tb2_eval/wrappers/lagent/lagent_entry.sh @@ -14,7 +14,8 @@ # state_dict → kills daemon. Exits 0 on success, non-zero on failure. # # Relies on /tmp/lagent-py wrapper (shared conda python + PYTHONPATH), which -# runner bootstrap writes. Agent config JSON is what daemon consumes. +# runner bootstrap writes. --config is a Python file defining agent_config; +# daemon execs it in-sandbox so os.environ lookups resolve here, not on host. # --------------------------------------------------------------------------- set -uo pipefail @@ -55,6 +56,26 @@ fi LOG=/tmp/agent_daemon.log : > "$LOG" +# If a daemon call returned {"error": "..."} over socket (exception during +# dispatch — e.g. LLM 4xx/5xx, timeout), exit non-zero so the host side's +# _dump_daemon_log fires and the traceback surfaces in xtuner logs. +_die_on_daemon_error() { + local resp="$1" phase="$2" code="$3" + local err + err=$(printf '%s' "$resp" | "$LAGENT_PY" -c ' +import json, sys +try: + print(json.loads(sys.stdin.read() or "{}").get("error", "")) +except Exception: + pass +') + if [ -n "$err" ]; then + echo "daemon error in ${phase}: ${err}" >&2 + tail -n 500 "$LOG" >&2 || true + exit "$code" + fi +} + # ── 1. Start AgentDaemon ────────────────────────────────────────────── nohup "$LAGENT_PY" -m lagent.serving.sandbox.daemon start \ --mode agent \ @@ -102,6 +123,7 @@ CHAT_RESP=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \ tail -n 100 "$LOG" >&2 || true exit 5 } +_die_on_daemon_error "$CHAT_RESP" chat 5 # Extract final response content (plain text) to RESPONSE_OUT. printf '%s' "$CHAT_RESP" | "$LAGENT_PY" -c ' @@ -121,6 +143,7 @@ STATE_RESP=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \ tail -n 100 "$LOG" >&2 || true exit 6 } +_die_on_daemon_error "$STATE_RESP" state_dict 6 # Wrap lagent's native state into {"trajectory": [...]} if it isn't already. printf '%s' "$STATE_RESP" | "$LAGENT_PY" -c ' @@ -153,6 +176,7 @@ POLICY_AGENT_MESSAGES=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \ tail -n 100 "$LOG" >&2 || true exit 7 } +_die_on_daemon_error "$POLICY_AGENT_MESSAGES" get_messages 7 printf '%s' "$POLICY_AGENT_MESSAGES" | "$LAGENT_PY" -c ' import json, sys diff --git a/projects/tb2_eval/wrappers/tb2_eval/emit_judger_result_from_ctrf.py b/projects/tb2_eval/wrappers/tb2_eval/emit_judger_result_from_ctrf.py index e31aab8afc..674236272e 100644 --- a/projects/tb2_eval/wrappers/tb2_eval/emit_judger_result_from_ctrf.py +++ b/projects/tb2_eval/wrappers/tb2_eval/emit_judger_result_from_ctrf.py @@ -1,8 +1,12 @@ #!/usr/bin/env python3 -"""Parse a CTRF JSON report + test log → emit a ``JudgerResult`` line to stdout. - -Honors ``@pytest.mark.weight(N)`` via pytest-json-ctrf's ``extra``/``metadata`` -section. Tests with no explicit weight get 1.0. +"""Emit a ``JudgerResult`` line to stdout that matches official TB2 scoring. 
+ +The bench's ``tests/test.sh`` writes the authoritative binary outcome to +``/logs/verifier/reward.txt`` (``1`` iff every pytest invocation in that +script exited 0, else ``0``) — including the multi-pytest case in tasks +like ``fix-code-vulnerability``. We read that file as the source of truth +for ``total``. CTRF is parsed only for per-test observability in the +``criteria`` field and never used for scoring. """ from __future__ import annotations @@ -13,26 +17,6 @@ from pathlib import Path -def _extract_weight(test: dict) -> float: - for section in ("extra", "metadata"): - extras = test.get(section) or [] - if isinstance(extras, list): - for e in extras: - if isinstance(e, dict) and e.get("key") == "weight": - try: - return float(e.get("value", 1.0)) - except (TypeError, ValueError): - return 1.0 - elif isinstance(extras, dict): - w = extras.get("weight") - if w is not None: - try: - return float(w) - except (TypeError, ValueError): - return 1.0 - return 1.0 - - def _log_tail(path: Path, bytes_: int = 800) -> str: try: return path.read_text(errors="replace")[-bytes_:] @@ -40,60 +24,81 @@ def _log_tail(path: Path, bytes_: int = 800) -> str: return "" -def main() -> int: - ap = argparse.ArgumentParser() - ap.add_argument("--ctrf", required=True) - ap.add_argument("--log", required=True) - ap.add_argument("--pytest-rc", type=int, required=True) - ap.add_argument("--judger-name", default="rule_grader") - args = ap.parse_args() +def _read_reward(path: Path) -> float | None: + try: + raw = path.read_text().strip() + except Exception: + return None + if not raw: + return None + try: + return float(raw) + except ValueError: + return None - ctrf_path = Path(args.ctrf) - log_path = Path(args.log) +def _parse_criteria(ctrf_path: Path) -> tuple[dict[str, dict[str, float]], int, str | None]: + """Parse CTRF into per-test criteria for observability only. + + Returns: + tuple[dict[str, dict[str, float]], int, str | None]: ``(criteria, test_count, + error)``. ``criteria`` maps test name to ``{"score": 0.0|1.0}``. ``error`` is + ``None`` on success or a message describing why CTRF was unreadable. + """ try: data = json.loads(ctrf_path.read_text()) except Exception as exc: - print( - json.dumps( - { - "judger_name": args.judger_name, - "total": 0.0, - "error": f"ctrf missing/parse failed: {exc}. 
log tail: {_log_tail(log_path)}", - }, - ensure_ascii=False, - ) - ) - return 0 - + return {}, 0, f"ctrf missing/parse failed: {exc}" tests = (data.get("results", {}) or {}).get("tests", []) or [] criteria: dict[str, dict[str, float]] = {} for t in tests: name = t.get("name", "unknown") passed = t.get("status") == "passed" - weight = _extract_weight(t) - criteria[name] = {"score": 1.0 if passed else 0.0, "weight": weight} + criteria[name] = {"score": 1.0 if passed else 0.0} + return criteria, len(tests), None - total_w = sum(c["weight"] for c in criteria.values()) - if total_w <= 0: - total = 0.0 + +def main() -> int: + ap = argparse.ArgumentParser() + ap.add_argument("--ctrf", required=True) + ap.add_argument("--log", required=True) + ap.add_argument("--reward-file", required=True) + ap.add_argument("--pytest-rc", type=int, required=True) + ap.add_argument("--judger-name", default="rule_grader") + args = ap.parse_args() + + ctrf_path = Path(args.ctrf) + log_path = Path(args.log) + reward_path = Path(args.reward_file) + + reward = _read_reward(reward_path) + criteria, test_count, ctrf_error = _parse_criteria(ctrf_path) + + result: dict = { + "judger_name": args.judger_name, + "criteria": criteria, + "metadata": { + "pytest_rc": args.pytest_rc, + "test_count": test_count, + "reward_source": "reward.txt" if reward is not None else "pytest_rc", + }, + } + + if reward is not None: + result["total"] = round(reward, 4) else: - total = sum(c["score"] * c["weight"] for c in criteria.values()) / total_w - - print( - json.dumps( - { - "judger_name": args.judger_name, - "total": round(total, 4), - "criteria": criteria, - "metadata": { - "pytest_rc": args.pytest_rc, - "test_count": len(tests), - }, - }, - ensure_ascii=False, + # reward.txt missing/unreadable: fall back to test.sh exit code. + result["total"] = 1.0 if args.pytest_rc == 0 else 0.0 + result["error"] = ( + f"reward file unreadable at {reward_path}; fell back to pytest_rc. " + f"log tail: {_log_tail(log_path)}" ) - ) + + if ctrf_error is not None: + # CTRF is observability-only; surface the parse error but don't change total. + result.setdefault("error", ctrf_error) + + print(json.dumps(result, ensure_ascii=False)) return 0 diff --git a/projects/tb2_eval/wrappers/tb2_eval/run_tests.sh b/projects/tb2_eval/wrappers/tb2_eval/run_tests.sh index 9a523d19de..6a7bd318a6 100644 --- a/projects/tb2_eval/wrappers/tb2_eval/run_tests.sh +++ b/projects/tb2_eval/wrappers/tb2_eval/run_tests.sh @@ -5,10 +5,13 @@ # The bench ships its own test harness at /tests/test.sh which: # - installs pytest + pytest-json-ctrf + /tests/test_requirements.txt # - runs /tests/test_outputs.py with --ctrf /logs/verifier/ctrf.json +# - writes the authoritative 0/1 outcome to /logs/verifier/reward.txt +# (for multi-pytest tasks like fix-code-vulnerability, this is the +# AND of all pytest exit codes — matching official TB2 scoring) # -# We invoke it and hand the resulting CTRF to our shared emitter so -# the stage's stdout is a single JudgerResult JSON line (what -# ParseJudgerStdout expects). +# We invoke it and hand both reward.txt and the resulting CTRF to our +# shared emitter. reward.txt drives the JudgerResult `total`; CTRF is +# parsed for per-test observability only. 
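+# A sketch of the resulting contract (assumed file contents, not verbatim
+# bench output): /logs/verifier/reward.txt holds a bare "1" or "0", and the
+# emitter folds it together with the CTRF into one JSON line such as
+#   {"judger_name": "rule_grader", "total": 1.0,
+#    "criteria": {"test_outputs.py::test_ok": {"score": 1.0}},
+#    "metadata": {"pytest_rc": 0, "test_count": 1, "reward_source": "reward.txt"}}
+# which ParseJudgerStdout then reads off the stage's stdout.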
# # Env: # $TESTS_DIR tests directory inside the sandbox (default: /tests) @@ -36,5 +39,6 @@ fi "$PY" "$WRAPPER_DIR/emit_judger_result_from_ctrf.py" \ --ctrf /logs/verifier/ctrf.json \ --log "$TEST_LOG" \ + --reward-file /logs/verifier/reward.txt \ --pytest-rc "$TEST_RC" \ --judger-name "$JUDGER_NAME" \ No newline at end of file diff --git a/projects/tb2_rl/agents/interndp/config.py b/projects/tb2_rl/agents/interndp/config.py index 71037ae0da..5758037749 100644 --- a/projects/tb2_rl/agents/interndp/config.py +++ b/projects/tb2_rl/agents/interndp/config.py @@ -43,7 +43,7 @@ model=dict( model=os.environ.get( "RL_LLM_MODEL", - "train_lkk_test", + "", ), base_url=os.environ.get( "RL_LLM_BASE_URL", @@ -52,8 +52,8 @@ api_key=os.environ.get("RL_LLM_API_KEY", "sk-admin"), ), sample_params=dict(temperature=0.7, top_p=1.0, top_k=50), - timeout=600, - max_retry=500, + timeout=1800, + max_retry=1, sleep_interval=5, extra_body=dict(spaces_between_special_tokens=False), ) @@ -70,7 +70,9 @@ env_agent = dict( type="lagent.agents.env_agent.RLEnvAgent", actions=base_actions, - max_turn=25, + max_turn=100, + max_tool_response_length=4096, + tool_response_truncate_side="left", enable_no_thinking_penalty=False, ) diff --git a/projects/tb2_rl/pipeline.py b/projects/tb2_rl/pipeline.py index 8e45ee1a29..62ec8fe4bb 100644 --- a/projects/tb2_rl/pipeline.py +++ b/projects/tb2_rl/pipeline.py @@ -24,12 +24,13 @@ from xtuner.v1.ray.environment.rl_task.hooks import ( BenchEnv, + DumpDaemonLogOnFailure, InstallLagent, ParseJudgerStdout, PickAgent, RunAgentInstallDeps, + UploadAgentConfigSource, UploadChosenAgent, - WriteAgentConfig, ) from xtuner.v1.ray.environment.rl_task.judgers import Judger from xtuner.v1.ray.environment.rl_task.runner import Runner @@ -55,7 +56,7 @@ PATHS = SimpleNamespace( wrappers_bench="/tmp/wrappers/tb2_rl", wrappers_lagent="/tmp/wrappers/lagent", - agent_config="/tmp/agent_config.json", + agent_config="/tmp/agent_config.py", trajectory="/tmp/trajectory.json", message="/tmp/message.json", tests="/tests", @@ -87,7 +88,7 @@ ), ] -DEFAULT_SANDBOX = SandboxSpec(image="t-data-processing-v1", ttl_seconds=1800, workspace_path="/app") +DEFAULT_SANDBOX = SandboxSpec(image="t-data-processing-v1", ttl_seconds=11700, workspace_path="/app") # ───────────────────────────────────────────────────────────────── @@ -201,8 +202,8 @@ def tb2_rl_pipeline( UploadChosenAgent(target_dir=f"{ws}/agent/"), # 7. Ensure workspace dir exists. ExecHook(f"mkdir -p {ws}"), - # 8. Exec chosen agent's config.py on host → upload resulting JSON. - WriteAgentConfig(dst=PATHS.agent_config), + # 8. Upload chosen agent's config.py — daemon execs it in-sandbox. + UploadAgentConfigSource(dst=PATHS.agent_config), # 9. Run install-deps.sh if the chosen agent template has one. RunAgentInstallDeps(workspace=ws), ], @@ -211,10 +212,11 @@ def tb2_rl_pipeline( workspace=ws, extras={"WORKSPACE": ws}, ), - timeout=900, + timeout=10800, post=[ DownloadHook([ws, "/tmp/agent_response.txt"]), ReadFileHook("/tmp/message.json", "message"), + DumpDaemonLogOnFailure(), ], ) diff --git a/projects/tb2_rl/wrappers/lagent/lagent_entry.sh b/projects/tb2_rl/wrappers/lagent/lagent_entry.sh index b92ec92476..5c940ddc4f 100755 --- a/projects/tb2_rl/wrappers/lagent/lagent_entry.sh +++ b/projects/tb2_rl/wrappers/lagent/lagent_entry.sh @@ -14,7 +14,8 @@ # state_dict → kills daemon. Exits 0 on success, non-zero on failure. # # Relies on /tmp/lagent-py wrapper (shared conda python + PYTHONPATH), which -# runner bootstrap writes. 
Agent config JSON is what daemon consumes. +# runner bootstrap writes. --config is a Python file defining agent_config; +# daemon execs it in-sandbox so os.environ lookups resolve here, not on host. # --------------------------------------------------------------------------- set -uo pipefail @@ -55,6 +56,26 @@ fi LOG=/tmp/agent_daemon.log : > "$LOG" +# If a daemon call returned {"error": "..."} over socket (exception during +# dispatch — e.g. LLM 4xx/5xx, timeout), exit non-zero so the host side's +# _dump_daemon_log fires and the traceback surfaces in xtuner logs. +_die_on_daemon_error() { + local resp="$1" phase="$2" code="$3" + local err + err=$(printf '%s' "$resp" | "$LAGENT_PY" -c ' +import json, sys +try: + print(json.loads(sys.stdin.read() or "{}").get("error", "")) +except Exception: + pass +') + if [ -n "$err" ]; then + echo "daemon error in ${phase}: ${err}" >&2 + tail -n 500 "$LOG" >&2 || true + exit "$code" + fi +} + # ── 1. Start AgentDaemon ────────────────────────────────────────────── nohup "$LAGENT_PY" -m lagent.serving.sandbox.daemon start \ --mode agent \ @@ -102,6 +123,7 @@ CHAT_RESP=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \ tail -n 100 "$LOG" >&2 || true exit 5 } +_die_on_daemon_error "$CHAT_RESP" chat 5 # Extract final response content (plain text) to RESPONSE_OUT. printf '%s' "$CHAT_RESP" | "$LAGENT_PY" -c ' @@ -121,6 +143,7 @@ STATE_RESP=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \ tail -n 100 "$LOG" >&2 || true exit 6 } +_die_on_daemon_error "$STATE_RESP" state_dict 6 # Wrap lagent's native state into {"trajectory": [...]} if it isn't already. printf '%s' "$STATE_RESP" | "$LAGENT_PY" -c ' @@ -153,6 +176,7 @@ POLICY_AGENT_MESSAGES=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \ tail -n 100 "$LOG" >&2 || true exit 7 } +_die_on_daemon_error "$POLICY_AGENT_MESSAGES" get_messages 7 printf '%s' "$POLICY_AGENT_MESSAGES" | "$LAGENT_PY" -c ' import json, sys diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py index 8cf71ca134..5e80f6d3c9 100644 --- a/xtuner/v1/data_proto/rl_data.py +++ b/xtuner/v1/data_proto/rl_data.py @@ -17,11 +17,14 @@ if TYPE_CHECKING: + import numpy as np import ray RayObjectRef = ray.ObjectRef + NumpyArray = np.ndarray else: RayObjectRef: TypeAlias = Any + NumpyArray: TypeAlias = Any logger = get_logger() @@ -116,7 +119,7 @@ class RLDatasetItem(BaseModel): class RolloutExtraInfo(TypedDict): - routed_experts: NotRequired[list[int] | str | RayObjectRef] # type: ignore[valid-type] + routed_experts: NotRequired[list[int] | RayObjectRef | NumpyArray] # type: ignore[valid-type] partial_rollout_input_ids: NotRequired[list[int]] diff --git a/xtuner/v1/ray/base/accelerator.py b/xtuner/v1/ray/base/accelerator.py index cf05509525..866ae51777 100644 --- a/xtuner/v1/ray/base/accelerator.py +++ b/xtuner/v1/ray/base/accelerator.py @@ -1,4 +1,5 @@ import os +from datetime import timedelta from typing import Any, Dict, List, Literal, Tuple, TypeVar import ray @@ -246,9 +247,15 @@ def setup_distributed(self, rank: int, master_addr: str, master_port: int, world else: raise ValueError(f"Unsupported accelerator architecture: {self.accelerator}") # 使用环境变量初始化 + # NCCL watchdog aborts collectives after this timeout. The default + # (10 min) is too tight for long-context MoE RL jobs where a single + # slow rank on a heavy allgather can burn 15+ min. Override via + # XTUNER_DIST_TIMEOUT_MIN env var. 
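+        # Standalone equivalent of the two lines below, for quick experiments
+        # outside the worker (sketch; backend hard-coded to "nccl" here):
+        #
+        #   import os
+        #   from datetime import timedelta
+        #   import torch.distributed as dist
+        #   dist.init_process_group(
+        #       backend="nccl",
+        #       init_method="env://",
+        #       timeout=timedelta(minutes=int(os.environ.get("XTUNER_DIST_TIMEOUT_MIN", 60))),
+        #   )
+        #
+        # e.g. launch with XTUNER_DIST_TIMEOUT_MIN=120 for jobs whose collectives
+        # may legitimately exceed the 60-minute default.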
+ timeout_min = int(os.environ.get("XTUNER_DIST_TIMEOUT_MIN", 60)) dist.init_process_group( backend=backend, init_method="env://", # 这告诉 PyTorch 从环境变量读取配置 + timeout=timedelta(minutes=timeout_min), ) def test_all_reduce(self): diff --git a/xtuner/v1/ray/config/worker.py b/xtuner/v1/ray/config/worker.py index 279849d41c..55d28fc1f2 100644 --- a/xtuner/v1/ray/config/worker.py +++ b/xtuner/v1/ray/config/worker.py @@ -186,13 +186,6 @@ class RolloutConfig(BaseModel): help="Whether to enable returning routed experts for the rollout worker.", ), ] = False - return_routed_experts_key: Annotated[ - bool, - Parameter( - group=infer_group, - help="Whether to keep LMDeploy routed experts as shared-store string keys before training.", - ), - ] = False launch_server_method: Annotated[ Literal["ray", "multiprocessing"], Parameter( diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py index 94206d7b76..30144d2769 100644 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ b/xtuner/v1/ray/dataflow/replay_buffer.py @@ -22,12 +22,10 @@ RLEnvDataItem, RLExtraDataItem, RLUIDItem, - RolloutExtraInfo, RolloutState, is_valid_for_replaybuffer, ) from xtuner.v1.datasets.config import DataloaderConfig -from xtuner.v1.ray.rollout.lmdeploy import get_lmdeploy_routed_experts_ref from xtuner.v1.ray.utils import free_object_refs from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger from xtuner.v1.utils.device import get_device @@ -375,24 +373,6 @@ def __init__(self, replay_buffer_cfg): self.sample_from_aborted_count = 0 self.sample_from_expired_count = 0 - def _pop_routed_experts_from_extra_info( - self, extra_info: RolloutExtraInfo, *, free_ref: bool = False - ) -> ObjectRef | None: - if "routed_experts" not in extra_info: - return None - - routed_experts = extra_info["routed_experts"] - if isinstance(routed_experts, str): - routed_experts = get_lmdeploy_routed_experts_ref(routed_experts) - elif not isinstance(routed_experts, ObjectRef): - routed_experts = ray.put(routed_experts) - - del extra_info["routed_experts"] - if free_ref: - free_object_refs([routed_experts]) - return None - return routed_experts - def _update_replay_meta_state(self, replay_meta: ReplayMeta, new_state: RolloutState): for observation_id in replay_meta.observation_ids: old_state = self._observations2states.get(observation_id) @@ -408,10 +388,6 @@ def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: R reused.""" old_obs_refs = [ref for ref in replay_meta.observation_refs if ref is not None] if old_obs_refs: - for old_obs_ref in old_obs_refs: - old_env = ray.get(old_obs_ref) - if hasattr(old_env, "rollout"): - self._pop_routed_experts_from_extra_info(old_env.rollout.extra_info, free_ref=True) ray.internal.free(old_obs_refs, local_only=False) replay_meta.observation_refs = [ray.put(RLEnvDataItem()) for _ in replay_meta.observation_ids] self._update_replay_meta_state(replay_meta, new_state) @@ -590,6 +566,33 @@ def _restore_nested_objectrefs(self, obj: Any): return obj return obj + def convert_to_ray_objref(self, data_item: RLDataFlowItem): + """Converts large tensors in RLDataFlowItem to ray.ObjectRefs. + + Args: + data_item (RLDataFlowItem): The data item containing large tensors. + Returns: + RLDataFlowItem: The data item with large tensors converted to ray.ObjectRefs. 
+ """ + # convert data.multimodal_train_info to ray.ObjectRef + if hasattr(data_item.data, "multimodal_train_info"): + multimodal_info = data_item.data.multimodal_train_info + if multimodal_info and "pixel_values" in multimodal_info: + # 一组数据共享同一个data_item.data,所以只需要转换一次 + if not isinstance(multimodal_info["pixel_values"], ray.ObjectRef): + pixel_values_ref = ray.put(multimodal_info["pixel_values"]) + del multimodal_info["pixel_values"] + data_item.data.multimodal_train_info["pixel_values"] = pixel_values_ref # type: ignore[index] + # convert rollout.extra_info.router_experts to ray.ObjectRef + if "routed_experts" in data_item.env.rollout.extra_info: + val = data_item.env.rollout.extra_info["routed_experts"] + # str = uuid key into RoutedExpertStore; ObjectRef = legacy path. + # Either is left untouched; only raw tensors get ray.put'd. + if not isinstance(val, (ray.ObjectRef, str)): + routed_experts_ref = ray.put(val) + del data_item.env.rollout.extra_info["routed_experts"] + data_item.env.rollout.extra_info["routed_experts"] = routed_experts_ref + def has_objectref(self, item: RLDataFlowItem) -> bool: def check(obj): if isinstance(obj, ray.ObjectRef): @@ -718,7 +721,11 @@ def _sample_from_expired_storage(self) -> List[RLDataFlowItem]: for sample in group_samples: assert sample.data.input_ids and sample.data.num_tokens, "input_ids or num_tokens is empty!" - self._pop_routed_experts_from_extra_info(sample.env.rollout.extra_info, free_ref=True) + if "routed_experts" in sample.env.rollout.extra_info: + val = sample.env.rollout.extra_info["routed_experts"] + if isinstance(val, ray.ObjectRef): + ray.internal.free(val, local_only=False) + del sample.env.rollout.extra_info["routed_experts"] del sample.env sample.env = RLEnvDataItem() # 重置env数据 sample.uid.action_id = action_id @@ -754,7 +761,11 @@ def _sample_from_aborted_storage(self) -> List[RLDataFlowItem]: assert sample.data.input_ids and sample.data.num_tokens, "input_ids or num_tokens is empty!" if not self.enable_partial_rollout: # 清除上次的response_ids等env数据 - self._pop_routed_experts_from_extra_info(sample.env.rollout.extra_info, free_ref=True) + if "routed_experts" in sample.env.rollout.extra_info: + val = sample.env.rollout.extra_info["routed_experts"] + if isinstance(val, ray.ObjectRef): + ray.internal.free(val, local_only=False) + del sample.env.rollout.extra_info["routed_experts"] del sample.env sample.env = RLEnvDataItem() sample.uid.version = 0 diff --git a/xtuner/v1/ray/environment/agent_env.py b/xtuner/v1/ray/environment/agent_env.py index 6a3b964cae..a1f7b1a184 100644 --- a/xtuner/v1/ray/environment/agent_env.py +++ b/xtuner/v1/ray/environment/agent_env.py @@ -1,6 +1,7 @@ import asyncio import inspect import os +import pickle import traceback from copy import deepcopy from typing import Callable, List, Self, Tuple @@ -20,6 +21,11 @@ from .base_env import BaseEnvironment +# Memory diagnostics thresholds (bytes). Set via env var to disable (0). +_ITEM_SIZE_LOG_THRESHOLD = int(os.environ.get("XTUNER_ITEM_SIZE_LOG_MB", "1")) * 1024 * 1024 +_RSS_MONITOR_INTERVAL = int(os.environ.get("XTUNER_RSS_MONITOR_SEC", "60")) + + def check_dead_actors(): # 获取所有 Actor 的列表 from ray.util.state import list_actors @@ -35,6 +41,38 @@ def check_dead_actors(): return dead_actors +def _log_item_size_if_large(item) -> None: + """Diagnostic: pickle-measure item + per-field breakdown when total exceeds threshold. + + Runs on every ``_inner_agent_call`` iteration; on fresh samples the total is a few KB + so the early-exit keeps steady-state overhead negligible. 
On resume / abort paths the + item carries ``agent_state_dict`` / ``agent_message_dict`` from prior turns — we want + visibility into which field dominates. + """ + try: + total_size = len(pickle.dumps(item)) + except Exception as exc: + get_logger().debug(f"[item-size] pickle.dumps failed: {exc}") + return + if total_size < _ITEM_SIZE_LOG_THRESHOLD: + return + field_sizes: dict = {} + try: + for key, val in item.env.rollout.extra_info.items(): + try: + field_sizes[key] = len(pickle.dumps(val)) + except Exception: + field_sizes[key] = -1 # type: ignore[assignment] + except Exception: + pass + get_logger().warning( + f"[item-size] sample={getattr(item.uid, 'observation_id', '?')} " + f"total={total_size / 1e6:.1f}MB " + f"state={item.env.rollout.state} " + f"fields={ {k: f'{v / 1e6:.1f}MB' for k, v in field_sizes.items()} }" + ) + + @ray.remote(max_concurrency=int(os.environ.get("XTUNER_MAX_CONCURRENCY", 2000))) # type: ignore[call-overload] class AgentEnvironment(BaseEnvironment): def __init__( @@ -54,6 +92,24 @@ def __init__( self.agent_cfg = agent_cfg self.preprocess_func = preprocess_func self.postprocess_func = postprocess_func + if _RSS_MONITOR_INTERVAL > 0: + self._rss_task = asyncio.get_event_loop().create_task(self._rss_monitor()) + + async def _rss_monitor(self): + """Diagnostic: periodically log this actor's RSS so cross-sample leaks surface.""" + try: + import psutil + except ImportError: + get_logger().warning("[actor-rss] psutil not available, disabling RSS monitor") + return + proc = psutil.Process() + while True: + try: + rss_gb = proc.memory_info().rss / 1e9 + get_logger().info(f"[actor-rss] env={self.environment} pid={proc.pid} rss={rss_gb:.2f}GB") + except Exception as exc: + get_logger().warning(f"[actor-rss] read failed: {exc}") + await asyncio.sleep(_RSS_MONITOR_INTERVAL) async def generate( # type: ignore[override] self, group_data_items: List[RLDataFlowItem], sample_params=None, extra_params=None @@ -68,6 +124,8 @@ async def _inner_agent_call(item): if item.env.rollout.state == RolloutState.ABORTED: agent.load_state_dict(item.env.rollout.extra_info["agent_state_dict"]) # type: ignore[operator] + if _ITEM_SIZE_LOG_THRESHOLD > 0: + _log_item_size_if_large(item) _item = deepcopy(item) _item.env.agent.extra_info["agent"] = agent agent_inputs = self.preprocess_func(self, _item) diff --git a/xtuner/v1/ray/environment/install_agent_env.py b/xtuner/v1/ray/environment/install_agent_env.py index 809567fcb5..59b0baf375 100644 --- a/xtuner/v1/ray/environment/install_agent_env.py +++ b/xtuner/v1/ray/environment/install_agent_env.py @@ -112,8 +112,9 @@ async def _inner_agent_call(item): uid, provider=self.provider, lagent_src_dir=DEFAULT_LAGENT_SRC, - llm_base_url=None, - llm_api_key=None, + llm_model=os.environ.get("RL_LLM_MODEL"), + llm_base_url=os.environ.get("RL_LLM_BASE_URL"), + llm_api_key=os.environ.get("RL_LLM_API_KEY"), ) except BaseException as exc: get_logger().error( @@ -144,7 +145,26 @@ async def _inner_agent_call(item): # passed_data_items.append(sample) continue else: - sample.env.agent.extra_info["message_dict"] = result["env"]["agent"]["message_dict"] + # Defend against silent-pass / truncated trajectory: rc=0 but + # last message in policy_agent.messages lacks the fields + # postprocess will read. Same heuristic as + # DumpDaemonLogOnFailure so log signal + filter stay + # consistent. 
+ msg_dict = (result["env"]["agent"] or {}).get("message_dict") or {} + messages = msg_dict.get("policy_agent.messages") or [] + last = messages[-1] if messages else {} + required = ("raw_content", "raw_content_ids", "raw_content_logprobs") + missing = [k for k in required if not last.get(k)] + if missing: + get_logger().warning( + f"silent-pass rollout skipped: " + f"uid={sample.uid.observation_id} " + f"task_id={result.get('data', {}).get('extra_info', {}).get('task_id')} " + f"missing={missing}" + ) + continue + sample.env.agent.extra_info["message_dict"] = msg_dict + sample.env.agent.extra_info["daemon_log"] = result["env"]["agent"].get("daemon_log", "") sample.env.judger.extra_info.update(result["env"]["judger"]) completed_data_items.append(sample) completed_data_items_result = self.postprocess_func(self, completed_data_items) # type: ignore[arg-type] diff --git a/xtuner/v1/ray/environment/lagent/parsers.py b/xtuner/v1/ray/environment/lagent/parsers.py index 20717db26d..3ab2a2fb1b 100644 --- a/xtuner/v1/ray/environment/lagent/parsers.py +++ b/xtuner/v1/ray/environment/lagent/parsers.py @@ -109,11 +109,21 @@ def parse_response(self, data: AgentMessage) -> AgentMessage: parameters = {} for p_match in param_matches: p_name = p_match.group(1).strip() - p_value = p_match.group(2).strip() + # Strip exactly one leading and one trailing newline — those + # are the formatting newlines inserted by the chat template + # around the value. Preserving additional newlines lets the + # model express trailing whitespace when it matters (e.g. a + # '\n' at the end of terminal keystrokes that triggers bash + # to execute the command). + p_raw = p_match.group(2) + if p_raw.startswith("\n"): + p_raw = p_raw[1:] + if p_raw.endswith("\n"): + p_raw = p_raw[:-1] try: - parsed_value = ast.literal_eval(p_value) + parsed_value = ast.literal_eval(p_raw) except (ValueError, SyntaxError): - parsed_value = p_value + parsed_value = p_raw if p_name in self.argument_type: try: parsed_value = self.argument_type[p_name](parsed_value) diff --git a/xtuner/v1/ray/environment/lagent/tokenize.py b/xtuner/v1/ray/environment/lagent/tokenize.py index e8c6a1e6c8..865af2ad57 100644 --- a/xtuner/v1/ray/environment/lagent/tokenize.py +++ b/xtuner/v1/ray/environment/lagent/tokenize.py @@ -2,6 +2,9 @@ import re from typing import Any, Dict, List +import numpy as np +import ray + from xtuner.v1.utils import get_logger @@ -25,6 +28,7 @@ def tokenize( thinking_start_ids = tokenizer.encode("", add_special_tokens=False) thinking_end_ids = tokenizer.encode("", add_special_tokens=False) routed_experts = None + previous_routed_experts_tasks = set() def get_content_index(content_ids) -> int: content_ids_str = " ".join([str(content_id) for content_id in content_ids]) @@ -81,7 +85,33 @@ def split_conversation(messages: List[Dict[str, Any]]) -> List[List[Dict[str, An and "routed_experts" in msg[0]["extra_info"] and msg[0]["extra_info"]["routed_experts"] is not None ): - routed_experts = msg[0]["extra_info"]["routed_experts"] + routed_experts_ref = msg[0]["extra_info"]["routed_experts"] + if isinstance(routed_experts_ref, np.ndarray): + # Inline path: numpy array carried directly. Same array object + # across turns → same id(); dedup on id avoids re-submitting the + # identical history on rollout retries. 
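+                    # Worth noting: id() dedups object identity, not value: an
+                    # equal copy of the array gets a fresh id and is treated as new.
+                    # That is the intent here (dedup re-sent references, not values):
+                    #
+                    #   import numpy as np
+                    #   a = np.zeros((2, 3)); b = a; c = a.copy()
+                    #   assert id(a) == id(b)  # same object -> deduped
+                    #   assert id(a) != id(c)  # equal copy -> treated as new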
+ dedup_key = id(routed_experts_ref)
+ passthrough: Any = routed_experts_ref
+ elif isinstance(routed_experts_ref, ray.ObjectRef):
+ dedup_key = routed_experts_ref.hex()
+ passthrough = routed_experts_ref
+ elif isinstance(routed_experts_ref, str):
+ # Legacy paths: the str is either a uuid key into RoutedExpertStore
+ # or base64(cloudpickle(ObjectRef)). Forward it unchanged and let
+ # the rollout worker tell the two encodings apart via isinstance
+ # checks downstream.
+ dedup_key = routed_experts_ref
+ passthrough = routed_experts_ref
+ else:
+ raise TypeError(f"Unexpected type for routed_experts_ref: {type(routed_experts_ref)}")
+ if dedup_key in previous_routed_experts_tasks:
+ logger.warning(
+ "[tokenize_fn] Detected repeated routed_experts_ref, setting routed_experts to None to avoid errors."
+ )
+ routed_experts = None
+ else:
+ routed_experts = passthrough
+ previous_routed_experts_tasks.add(dedup_key)
 else:
 routed_experts = None
diff --git a/xtuner/v1/ray/environment/rl_task/hooks.py b/xtuner/v1/ray/environment/rl_task/hooks.py
index 5f2d2ca8a5..14ff30fc19 100644
--- a/xtuner/v1/ray/environment/rl_task/hooks.py
+++ b/xtuner/v1/ray/environment/rl_task/hooks.py
@@ -39,6 +39,7 @@
 walk_files,
)
from xtuner.v1.ray.environment.rl_task.schemas import AgentSpec, CriterionScore, JudgerResult
+from xtuner.v1.utils import get_logger


# ─────────────────────────────────────────────────────────────────
@@ -52,7 +53,7 @@ class PickAgent(Hook):
 Selection is deterministic on ``ctx["uid"]["root_id"]`` so the
 same rollout always picks the same agent. Also stores
 ``template_root`` in ``ctx["agent_template_root"]`` so downstream hooks
- (:class:`UploadChosenAgent`, :class:`WriteAgentConfig`,
+ (:class:`UploadChosenAgent`, :class:`UploadAgentConfigSource`,
 :class:`RunAgentInstallDeps`) know where the agent's files live on
 the host.
 """
@@ -204,30 +205,33 @@ async def __call__(self, client: Any, ctx: dict[str, Any]) -> None:
# ─────────────────────────────────────────────────────────────────
-# Agent config: exec config.py on host, upload as JSON
+# Agent config: upload config.py source (daemon execs it in-sandbox)
# ─────────────────────────────────────────────────────────────────
-class WriteAgentConfig(Hook):
- """Exec the chosen agent's ``config.py`` on the host; upload the resulting
- ``agent_config`` dict as JSON to ``dst``.
+class UploadAgentConfigSource(Hook):
+ """Upload the chosen agent's ``config.py`` source file to ``dst``.
+
+ The lagent daemon execs this file in the sandbox to build the agent
+ dict — so ``os.environ`` lookups inside ``config.py`` resolve against
+ the sandbox's own env (populated by :class:`BenchEnv`), not the host's.

 Agent template lives at ``ctx["agent_template_root"] / chosen.name /``
 (populated by :class:`PickAgent`).
""" - name = "write_agent_config" + name = "upload_agent_config_source" - def __init__(self, dst: str = "/tmp/agent_config.json"): + def __init__(self, dst: str = "/tmp/agent_config.py"): self.dst = dst async def __call__(self, client: Any, ctx: dict[str, Any]) -> None: chosen: AgentSpec = ctx["chosen_agent"] template_root: Path = ctx["agent_template_root"] cfg_path = template_root / chosen.name / chosen.config - cfg = _exec_python_ns(cfg_path, "agent_config") - blob = json.dumps(cfg, ensure_ascii=False).encode() - await http_upload(client, self.dst, base64.b64encode(blob).decode()) + if not cfg_path.is_file(): + raise FileNotFoundError(f"agent config {cfg_path!r} not found") + await upload_tar_and_extract(client, {self.dst: cfg_path}, "/") # ───────────────────────────────────────────────────────────────── @@ -316,15 +320,86 @@ def __call__(self, ctx: dict[str, Any]) -> dict[str, str]: "TASK_WORKSPACE": self.workspace, "TASK_INSTRUCTION": f"{self.workspace}/{data.instruction}", } - if runtime.get("llm_base_url"): - env["RL_LLM_BASE_URL"] = runtime["llm_base_url"] - if runtime.get("llm_api_key"): - env["RL_LLM_API_KEY"] = runtime["llm_api_key"] + for env_key, runtime_key in ( + ("RL_LLM_MODEL", "llm_model"), + ("RL_LLM_BASE_URL", "llm_base_url"), + ("RL_LLM_API_KEY", "llm_api_key"), + ): + val = runtime.get(runtime_key) + if val: + env[env_key] = val env.update(self.extras) ctx["env_vars_for_instruction"] = env return env +# ───────────────────────────────────────────────────────────────── +# Daemon log retrieval (post-hook) +# ───────────────────────────────────────────────────────────────── + + +class DumpDaemonLogOnFailure(Hook): + """Post-hook: pull ``/tmp/agent_daemon.log`` and log its tail on failure. + + Two triggers: + - Stage's entry returned non-zero (``rc != 0``) — usual sandbox error. + - Silent-pass: ``rc == 0`` but the pulled ``message_key`` contents show + the last ``policy_agent.messages`` entry lacks ``raw_content_ids`` + (LLM call somehow produced no token ids — typically an exception + swallowed by the agent layer). Disable by passing ``message_key=None``. + + Always stores the full daemon log at ``ctx["pulled"][key]`` for + downstream consumers regardless of whether we log. 
+ """ + + name = "dump_daemon_log_on_failure" + + def __init__( + self, + path: str = "/tmp/agent_daemon.log", + *, + tail_lines: int = 500, + key: str = "daemon_log", + message_key: str | None = "message", + ): + self.path = path + self.tail_lines = tail_lines + self.key = key + self.message_key = message_key + + async def __call__(self, client: Any, ctx: dict[str, Any]) -> None: + try: + blob = await client.download_file(self.path) + except Exception as exc: + get_logger().warning(f"could not download daemon log at {self.path}: {exc}") + return + text = blob.decode(errors="replace") + ctx.setdefault("pulled", {})[self.key] = text + + result = ctx.get("result") + rc = getattr(result, "return_code", 0) if result else 0 + + should_dump = rc != 0 + reason = f"rc={rc}" + if not should_dump and self.message_key: + raw = (ctx.get("pulled") or {}).get(self.message_key) or "" + try: + msgs = json.loads(raw).get("policy_agent.messages", []) if raw else [] + except Exception: + msgs = [] + last = msgs[-1] if msgs else {} + required = ("raw_content", "raw_content_ids", "raw_content_logprobs") + missing = [k for k in required if not last.get(k)] + if missing: + should_dump = True + reason = f"silent-pass (last message missing {missing})" + + if should_dump: + lines = text.splitlines() + tail = "\n".join(lines[-self.tail_lines :]) if len(lines) > self.tail_lines else text + get_logger().error(f"daemon log tail [{reason}] ({self.path}):\n{tail}") + + # ───────────────────────────────────────────────────────────────── # Judger result parsing # ───────────────────────────────────────────────────────────────── @@ -353,14 +428,6 @@ async def __call__(self, client: Any, ctx: dict[str, Any]) -> None: # ───────────────────────────────────────────────────────────────── -def _exec_python_ns(path: Path, expected_name: str) -> Any: - ns: dict[str, Any] = {} - exec(compile(path.read_text(encoding="utf-8"), str(path), "exec"), ns) - if expected_name not in ns: - raise KeyError(f"{expected_name!r} not defined in {path}") - return ns[expected_name] - - def _parse_stage_stdout(name: str, result: StageResult) -> JudgerResult: if result.return_code != 0: return JudgerResult( diff --git a/xtuner/v1/ray/environment/rl_task/runner.py b/xtuner/v1/ray/environment/rl_task/runner.py index 5c5075ae26..6e62c9351e 100644 --- a/xtuner/v1/ray/environment/rl_task/runner.py +++ b/xtuner/v1/ray/environment/rl_task/runner.py @@ -68,6 +68,7 @@ async def run_single( *, provider: Any, lagent_src_dir: str | Path | None = None, + llm_model: str | None = None, llm_base_url: str | None = None, llm_api_key: str | None = None, ) -> dict[str, Any]: @@ -81,10 +82,12 @@ async def run_single( "uid": uid, "runtime": { "lagent_src_dir": lagent_src_dir, + "llm_model": llm_model, "llm_base_url": llm_base_url, "llm_api_key": llm_api_key, }, "workspace": self.infer.sandbox.workspace_path, + "sandbox_image": self.infer.sandbox.image, } client = None @@ -104,7 +107,6 @@ async def run_single( infer_result = await self.infer.run(client, ctx) get_logger().info(f"[{tid}] infer: done rc={infer_result.return_code} ({time.monotonic() - t1:.1f}s)") if not infer_result.ok: - await _dump_daemon_log(client) return _mark_failed( data, uid, @@ -157,14 +159,6 @@ def _infer_metadata(ctx: dict[str, Any]) -> dict[str, Any]: return md -async def _dump_daemon_log(client) -> None: - try: - data = await client.download_file("/tmp/agent_daemon.log") - get_logger().error(f"agent daemon log tail:\n{data.decode(errors='replace')[-4000:]}") - except Exception as exc: - 
get_logger().warning(f"could not download daemon log: {exc}") - - _ACQUIRE_MAX_ATTEMPTS = 3 # Sandbox cold-start can take 30-60s under load; pathological boots run # longer. With async waits this budget is cheap (no thread tied up), so @@ -264,7 +258,10 @@ def _mark_completed( "step_rewards": [sr.model_dump() for sr in judge.step_rewards], "failed": judge.failed, }, - "agent": {"message_dict": json.loads(infer.pulled.get("message", "{}"))}, + "agent": { + "message_dict": json.loads(infer.pulled.get("message", "{}")), + "daemon_log": infer.pulled.get("daemon_log", ""), + }, }, } @@ -305,8 +302,8 @@ def _mark_failed( DEFAULT_GATEWAY = "http://env-gateway.ailab.ailab.ai" -# DEFAULT_LAGENT_SRC = "/mnt/shared-storage-user/llmit/user/liukuikun/workspace/lagent" -DEFAULT_LAGENT_SRC = "/mnt/shared-storage-user/llmit/user/wangziyi/projs/lagent" +DEFAULT_LAGENT_SRC = "/mnt/shared-storage-user/llmit/user/liukuikun/workspace/lagent" +# DEFAULT_LAGENT_SRC = "/mnt/shared-storage-user/llmit/user/wangziyi/projs/lagent" def _load_dataset_from_config(config_path: Path) -> Any: @@ -338,6 +335,7 @@ async def _run_one( *, uid: dict[str, int], lagent_src_dir: str | None, + llm_model: str | None, llm_base_url: str | None, llm_api_key: str | None, ) -> dict[str, Any]: @@ -350,6 +348,7 @@ async def _run_one( uid, provider=provider, lagent_src_dir=lagent_src_dir, + llm_model=llm_model, llm_base_url=llm_base_url, llm_api_key=llm_api_key, ) @@ -406,6 +405,7 @@ async def _guarded(job_idx: int, td: Path, uid: dict[str, int]) -> dict[str, Any provider, uid=uid, lagent_src_dir=lagent_src, + llm_model=args.llm_model, llm_base_url=args.llm_base_url, llm_api_key=args.llm_api_key, ) @@ -586,6 +586,7 @@ def main() -> int: default=DEFAULT_LAGENT_SRC, help="Local path to lagent source. Pass '' to skip upload.", ) + parser.add_argument("--llm-model", default=None) parser.add_argument("--llm-base-url", default=None) parser.add_argument("--llm-api-key", default=None) parser.add_argument( diff --git a/xtuner/v1/ray/environment/rl_task/sandbox.py b/xtuner/v1/ray/environment/rl_task/sandbox.py index f8a04f781f..6212eec601 100644 --- a/xtuner/v1/ray/environment/rl_task/sandbox.py +++ b/xtuner/v1/ray/environment/rl_task/sandbox.py @@ -376,8 +376,12 @@ async def exec_in( """ if env: command = _expand_vars(command, env) - prefix = " ".join(f'{k}="{v}"' for k, v in env.items()) - command = f"{prefix} {command}" + # Use `export` so vars carry across chained commands (`bash A && bash B`). + # Inline `VAR=val cmd1 && cmd2` scopes VAR to cmd1 only, which bites + # when entry runs pre_entry.sh && lagent_entry.sh — daemon subprocess + # wouldn't see RL_LLM_MODEL etc. + exports = "; ".join(f'export {k}="{v}"' for k, v in env.items()) + command = f"{exports}; {command}" result = await client.execute(command, cwd, timeout_sec) rc = _result_code(result) if raise_on_error and rc != 0: @@ -487,17 +491,27 @@ def _result_code(exec_res: dict[str, Any]) -> int: return int(rc) +_HOOK_STUCK_WARN_SEC = 30.0 + + async def _run_hook(hook: Hook, client: Any, ctx: dict[str, Any], *, phase: str) -> None: """Run one hook; on failure, log + stash + re-raise with a label that names the hook class so the traceback says which one blew up.""" name = getattr(hook, "name", None) or type(hook).__name__ - label = f"{phase}-hook {type(hook).__name__}({name!r})" tid = (ctx.get("data") and getattr(ctx["data"], "id", None)) or "?" - get_logger().debug(f"[{tid}] {label} start") + image = ctx.get("sandbox_image") or "?" 
+ label = f"{phase}-hook {type(hook).__name__}({name!r}) image={image}" + get_logger().info(f"[{tid}] {label} start") t0 = time.monotonic() + hook_task = asyncio.create_task(hook(client, ctx)) try: - await hook(client, ctx) - get_logger().debug(f"[{tid}] {label} done ({time.monotonic() - t0:.2f}s)") + while True: + done, _ = await asyncio.wait([hook_task], timeout=_HOOK_STUCK_WARN_SEC) + if done: + break + get_logger().warning(f"[{tid}] {label} still running ({time.monotonic() - t0:.0f}s)") + await hook_task # re-raise if hook failed + get_logger().info(f"[{tid}] {label} done ({time.monotonic() - t0:.2f}s)") except Exception as exc: import traceback as _tb diff --git a/xtuner/v1/ray/evaluator.py b/xtuner/v1/ray/evaluator.py index 90215bc5ef..3ab169b862 100644 --- a/xtuner/v1/ray/evaluator.py +++ b/xtuner/v1/ray/evaluator.py @@ -203,7 +203,29 @@ async def eval_worker_task(self, sample: RLDataFlowItem): group_sample = await self.env_controller.run.remote( [sample], sample_params=self.sample_params, extra_params=extra_params ) # type: ignore[attr-defined] - self.return_list.append(group_sample[0]) + if not group_sample: + # Rollout failed upstream (HTTP 500 / actor death / cancellation); + # skip instead of IndexError'ing the append path. + return + sample_out = group_sample[0] + # Avoid `return_list` accumulation. _save_trajectories only reads: + # env.agent.extra_info["messages"] (keep — needed for jsonl) + # env.agent.extra_info["daemon_log"] (keep — eval-only) + # env.rollout.response_ids / .response (keep) + # env.judger.reward (keep) + # Everything else in extra_info is trainer-side state (per-turn lagent + # memory / routed_experts bookkeeping) and at MB-per-sample scale. + # Popping here prevents a single eval round from pushing the actor's + # Python heap into the 100+GB range (observed: 318GB per actor) and + # triggering node-level OOM on the rollout head node. + if sample_out.env.agent.extra_info: + sample_out.env.agent.extra_info.pop("state", None) + sample_out.env.agent.extra_info.pop("agent", None) + if sample_out.env.rollout.extra_info: + sample_out.env.rollout.extra_info.pop("agent_state_dict", None) + sample_out.env.rollout.extra_info.pop("agent_message_dict", None) + sample_out.env.rollout.extra_info.pop("routed_experts", None) + self.return_list.append(sample_out) async def concurrent_eval_task_runner(self): """Runs evaluation tasks concurrently to generate a batch of samples. 
diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py index b44513a901..1e9c3f2ee0 100644 --- a/xtuner/v1/ray/rollout/controller.py +++ b/xtuner/v1/ray/rollout/controller.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union from uuid import uuid4 +import numpy as np import ray import uvicorn from fastapi import FastAPI @@ -136,9 +137,7 @@ def __init__( self.num_workers = 0 self.workers_info: Dict[str, WorkerInfo] = {} # url -> WorkerInfo self.active_rollout_workers: List[RolloutWorker] = [] - tokenizer_path = infer_config.tokenizer_path - assert tokenizer_path is not None, "tokenizer_path must be set before creating RolloutController" - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained(infer_config.tokenizer_path, trust_remote_code=True) self.workers, self.rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group( self._get_worker_cls(), infer_config, placement_group ) @@ -171,7 +170,7 @@ def __init__( Qwen3TokenReasonParser, ) - self.reasoning_parser = Qwen3TokenReasonParser(tokenizer_path) + self.reasoning_parser = Qwen3TokenReasonParser(infer_config.tokenizer_path) self.tool_call_parser = Qwen3_5FunctionCallParser() def _get_worker_status_for_router(self) -> Dict[RolloutWorker, bool]: @@ -486,6 +485,10 @@ def start_api_server(self, host: str = "0.0.0.0", port: int = 8000): @app.post("/v1/chat/completions") async def chat_completions(request: LagentChatCompletionRequest): + import base64 + + from ray import cloudpickle + inputs = tokenize(self.tokenizer, request.messages, request.tools) response: RLRolloutResponseItem = await self.rollout( prompt=request.messages, @@ -498,11 +501,28 @@ async def chat_completions(request: LagentChatCompletionRequest): {"routed_experts": inputs["routed_experts"]} if inputs["routed_experts"] is not None else {} ), ) + # HTTP boundary needs JSON-serializable response. In-cluster flow + # uses inline numpy (zero-copy via Ray RPC) but FastAPI's + # jsonable_encoder can't handle ndarray. Round-trip via xtuner's + # RoutedExpertStore: register the numpy, carry the str key over HTTP, + # and the client's next-turn request returns it through worker.py's + # `isinstance(history, str)` branch. + re_val = response.extra_info.get("routed_experts") + if isinstance(re_val, np.ndarray): + from xtuner.v1.ray.rollout.routed_expert_store import get_store + + store = get_store() + local_ref = ray.put(re_val) + key = await store.put_ref.remote([local_ref]) + response.extra_info["routed_experts"] = key + elif isinstance(re_val, ray.ObjectRef): + # Legacy path kept as a defensive fallback — encode the same + # way as before so older clients still decode correctly. 
+ response.extra_info["routed_experts"] = base64.b64encode(cloudpickle.dumps(re_val)).decode("utf-8") message = AgentMessage.from_model_response(response, "assistant") message = self.reasoning_parser.parse_response(message) message = self.tool_call_parser.parse_response(message) completion_message = LagentChatCompletionMessage.from_agent_message(message) - completion_tokens = response.num_return_tokens or 0 return LagentChatCompletion( model=request.model, choices=[ @@ -516,8 +536,8 @@ async def chat_completions(request: LagentChatCompletionRequest): ], usage=CompletionUsage( prompt_tokens=len(inputs["input_ids"]), - completion_tokens=completion_tokens, - total_tokens=len(inputs["input_ids"]) + completion_tokens, + completion_tokens=response.num_return_tokens, + total_tokens=len(inputs["input_ids"]) + response.num_return_tokens, ), ).model_dump() diff --git a/xtuner/v1/ray/rollout/lmdeploy.py b/xtuner/v1/ray/rollout/lmdeploy.py index f4d3cab87b..c556a09a12 100644 --- a/xtuner/v1/ray/rollout/lmdeploy.py +++ b/xtuner/v1/ray/rollout/lmdeploy.py @@ -10,7 +10,6 @@ from ray.util.placement_group import placement_group_table from transformers import AutoTokenizer -from xtuner.v1.data_proto.rl_data import RolloutExtraInfo from xtuner.v1.ray.config import RolloutConfig from .worker import RolloutWorker @@ -18,27 +17,6 @@ SHARED_STORE = "shared_store" SHARED_STORE_NAMESPACE = "lmdeploy" -_LMDEPLOY_ACTOR = None - - -def get_lmdeploy_routed_experts_ref(routed_experts: Any): - global _LMDEPLOY_ACTOR - if isinstance(routed_experts, str): - if _LMDEPLOY_ACTOR is None: - _LMDEPLOY_ACTOR = ray.get_actor(SHARED_STORE, namespace=SHARED_STORE_NAMESPACE) - return _LMDEPLOY_ACTOR.get.remote(routed_experts) - if isinstance(routed_experts, ray.ObjectRef): - return routed_experts - return ray.put(torch.as_tensor(routed_experts)) - - -def put_lmdeploy_routed_experts_ref(routed_experts: Any, return_key: bool): - global _LMDEPLOY_ACTOR - if return_key: - if _LMDEPLOY_ACTOR is None: - _LMDEPLOY_ACTOR = ray.get_actor(SHARED_STORE, namespace=SHARED_STORE_NAMESPACE) - return ray.get(_LMDEPLOY_ACTOR.put.remote(routed_experts)) - return ray.put(routed_experts) def run_lmdeploy_server_wrapper(lmdeploy_config_namespace: Namespace): @@ -101,7 +79,7 @@ def __init__( self.api_keys = self.config.api_key self.model_name = self.config.model_name self.enable_return_routed_experts = self.config.enable_return_routed_experts - self.return_routed_experts_key = self.config.return_routed_experts_key + self.lmdeploy_actor = None async def _create_request( self, @@ -236,42 +214,14 @@ def reset_prefix_cache(self): """It will implemented for LMDeploy worker in the future.""" pass - async def _handle_routed_experts_response( - self, - root_id: Any, - action_id: Any, - response: dict, - input_extra_info: RolloutExtraInfo, - extra_info: RolloutExtraInfo, - finish_reason: str, - ) -> None: - exist_history_routed_experts = ( - "routed_experts" in input_extra_info and input_extra_info["routed_experts"] is not None - ) - routed_experts = response["meta_info"].pop("routed_experts") # token[layer[expert]] - if routed_experts is not None and not exist_history_routed_experts: - if isinstance(routed_experts, str) and self.return_routed_experts_key: - extra_info["routed_experts"] = routed_experts - else: - extra_info["routed_experts"] = get_lmdeploy_routed_experts_ref(routed_experts) - elif routed_experts is not None and exist_history_routed_experts: - cur_routed_experts_ref = get_lmdeploy_routed_experts_ref(routed_experts) - history_routed_experts_ref = 
get_lmdeploy_routed_experts_ref(input_extra_info["routed_experts"]) - cur_routed_experts = await cur_routed_experts_ref # n, layer, expert - history_routed_experts = await history_routed_experts_ref # n, layer, expert - ray.internal.free([cur_routed_experts_ref, history_routed_experts_ref], local_only=False) - concat_routed_experts = self._concat_partial_routed_experts( - root_id, action_id, response, history_routed_experts, cur_routed_experts - ) - extra_info["routed_experts"] = put_lmdeploy_routed_experts_ref( - concat_routed_experts, self.return_routed_experts_key - ) - del history_routed_experts - del cur_routed_experts - else: - assert finish_reason == "abort", ( - f"routed_experts is None, but finish_reason is {finish_reason}, expected abort. response: {response}" - ) + def _decode_routed_experts(self, routed_experts: Any): + if isinstance(routed_experts, str): + if self.lmdeploy_actor is None: + self.lmdeploy_actor = ray.get_actor(SHARED_STORE, namespace=SHARED_STORE_NAMESPACE) + assert self.lmdeploy_actor is not None, "LMDeploy actor should be available in the shared store." + routed_experts_ref = self.lmdeploy_actor.get.remote(routed_experts) + return routed_experts_ref + return torch.tensor(routed_experts) def _transform_rollout_config_to_server_configs(self) -> Namespace: """Transform the RolloutConfig into a Namespace suitable for the @@ -323,8 +273,6 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: extra_engine_config: Dict[str, Any] = {} if backend == "pytorch" and self.config.enable_return_routed_experts: extra_engine_config["enable_return_routed_experts"] = True - if self.config.return_routed_experts_key: - extra_engine_config["enable_transfer_obj_ref"] = True if backend == "pytorch" and self.config.router_n_groups: hf_overrides = extra_engine_config.setdefault("hf_overrides", {}) hf_overrides.update(router_n_groups=self.config.router_n_groups) diff --git a/xtuner/v1/ray/rollout/routed_expert_store.py b/xtuner/v1/ray/rollout/routed_expert_store.py new file mode 100644 index 0000000000..0d1f3169d4 --- /dev/null +++ b/xtuner/v1/ray/rollout/routed_expert_store.py @@ -0,0 +1,198 @@ +"""Named Ray actor that owns all ``routed_experts`` ObjectRefs. + +Motivation: + lmdeploy workers produce ``routed_experts`` per rollout and encode them as + Ray ``ObjectRef`` values. Those refs travel through HTTP/JSON (the agent + sandbox roundtrip) as cloudpickled hex strings, spending minutes to hours + outside any Python process. Ray's distributed refcount + cloudpickle + out-of-band tracking is fragile at that timescale — under heartbeat flakes + or owner restart the object can be evicted, and training later hits + ``ObjectFetchTimedOutError`` with "no locations were found". + + Routing the refs through a dedicated long-lived store decouples their + lifetime from lmdeploy's lifecycle. Transport between actors/processes + becomes a short uuid string (``put_tensor`` → ``get_ref`` → ``release``) + instead of a cloudpickled ObjectRef blob, so the failure mode is replaced + by a plain dict lookup that either hits or misses. +""" + +from __future__ import annotations + +import os +import time +import uuid +from typing import Any + +import ray +from ray import ObjectRef + +from xtuner.v1.utils import get_logger + + +_STORE_NAME = "routed_expert_store" +# Generous default — covers long rollouts that may not be consumed immediately +# (e.g. eval groups, aborted tasks awaiting GC, etc). Overridable via env var +# for quick tuning without re-editing the source. 
+_DEFAULT_TTL_SEC = int(os.environ.get("XTUNER_ROUTED_EXPERT_TTL_SEC", 24 * 3600)) +_DEFAULT_GC_INTERVAL_SEC = int(os.environ.get("XTUNER_ROUTED_EXPERT_GC_INTERVAL_SEC", 300)) + + +@ray.remote(num_cpus=0) +class RoutedExpertStore: + """Dedicated store actor for routed_experts ObjectRefs. + + Every ref in ``self._store`` has the store as its Ray owner (via + ``ray.put``), so its lifetime is bounded purely by the store's Python + dict — no distributed heartbeats or cloudpickle TTL involved. + + GC runs inline on each ``put_*`` call: if the last sweep happened more + than ``gc_interval_sec`` ago, stale keys older than ``ttl_sec`` are + evicted. Keeps the store honest without a background task. + """ + + def __init__(self, ttl_sec: int = _DEFAULT_TTL_SEC, gc_interval_sec: int = _DEFAULT_GC_INTERVAL_SEC): + self._store: dict[str, tuple[ObjectRef, float]] = {} + self._ttl = ttl_sec + self._gc_interval = gc_interval_sec + self._last_gc = time.monotonic() + # Diagnostic counters — help locate double-consume / leak symptoms + # after the fact without changing runtime behaviour. + self._n_put = 0 + self._n_get = 0 + self._n_release = 0 + self._n_missing_get = 0 + self._n_missing_release = 0 + + def put_tensor(self, tensor: Any) -> str: + """Take ownership of a tensor via ray.put; return a uuid key. + + NOTE: this places the tensor in the STORE actor's local plasma — + under heavy traffic this can cause one-node plasma saturation (all + rollouts across the cluster funnel to whichever node this actor + runs on). Prefer ``put_ref`` + worker-side ``ray.put`` for + multi-node scale. + """ + self._maybe_gc() + ref = ray.put(tensor) + self._n_put += 1 + return self._stash(ref) + + def put_ref(self, wrapped: list) -> str: + """Register an externally-owned ObjectRef, return a uuid key. + + The caller must wrap the ref in a single-element list. Ray + auto-dereferences bare ObjectRef args; the list wrapper bypasses + that so the store receives the ref itself (not the materialized + tensor). + + Ownership stays with the original ``ray.put`` caller (typically + the rollout worker); data lives in that caller's node plasma. + The store just holds a Python-level strong reference so Ray's + distributed refcount keeps the object alive across consumers. + + This distributes plasma pressure across all rollout nodes instead + of funneling it to the store's node. Trade-off: if the owner + worker dies, the ref's object is lost; with ``put_tensor`` only + the store dying would lose data. 
+ """ + self._maybe_gc() + if not (isinstance(wrapped, list) and len(wrapped) == 1): + raise TypeError(f"put_ref expects [ObjectRef] to bypass auto-deref, got {type(wrapped)}") + ref = wrapped[0] + if not isinstance(ref, ObjectRef): + raise TypeError(f"put_ref expects [ObjectRef], got [{type(ref)}]") + self._n_put += 1 + return self._stash(ref) + + def get_ref(self, key: str) -> ObjectRef: + """Return the stashed ObjectRef; caller runs ``ray.get`` to + materialize.""" + entry = self._store.get(key) + if entry is None: + self._n_missing_get += 1 + raise KeyError(f"RoutedExpertStore: key not found: {key}") + self._store[key] = (entry[0], time.monotonic()) + self._n_get += 1 + return entry[0] + + def release(self, key: str) -> None: + if self._store.pop(key, None) is None: + self._n_missing_release += 1 + else: + self._n_release += 1 + + def release_many(self, keys: list[str]) -> None: + for k in keys: + self.release(k) + + def stats(self) -> dict: + return { + "live": len(self._store), + "ttl_sec": self._ttl, + "n_put": self._n_put, + "n_get": self._n_get, + "n_release": self._n_release, + "n_missing_get": self._n_missing_get, + "n_missing_release": self._n_missing_release, + } + + def _stash(self, ref: ObjectRef) -> str: + key = uuid.uuid4().hex + self._store[key] = (ref, time.monotonic()) + return key + + def _maybe_gc(self) -> None: + now = time.monotonic() + if now - self._last_gc < self._gc_interval: + return + self._last_gc = now + stale = [k for k, (_, t) in self._store.items() if now - t > self._ttl] + for k in stale: + self._store.pop(k, None) + if stale: + get_logger().warning( + f"RoutedExpertStore GC: evicted {len(stale)} stale keys (ttl={self._ttl}s, remaining={len(self._store)})" + ) + + +_handle_cache: Any = None + + +def get_store(): + """Process-local cached handle to the singleton store actor. + + First tries ``ray.get_actor`` (fast path if another caller created it + already). On NotFound, tries to create; if creation races with another + caller (raises because name is taken), falls back to another lookup. + """ + global _handle_cache + if _handle_cache is not None: + return _handle_cache + + # Fast path: already exists. + try: + _handle_cache = ray.get_actor(_STORE_NAME) + return _handle_cache + except ValueError: + pass + + # Slow path: try to create. Race with concurrent callers is handled by + # catching the "name already taken" case and re-looking-up. + import time as _time + + for attempt in range(10): + try: + _handle_cache = RoutedExpertStore.options(name=_STORE_NAME).remote() + return _handle_cache + except ValueError as exc: + # Either "name already taken" (someone else won) or transient + # issue. Try to look up; if that also fails, back off. 
+ try: + _handle_cache = ray.get_actor(_STORE_NAME) + return _handle_cache + except ValueError: + get_logger().debug(f"RoutedExpertStore bootstrap retry {attempt}: {exc}") + _time.sleep(0.2 * (attempt + 1)) + continue + + raise RuntimeError(f"RoutedExpertStore: failed to acquire named actor {_STORE_NAME!r} after retries") diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py index 4f9e3e17c1..833b374697 100644 --- a/xtuner/v1/ray/rollout/worker.py +++ b/xtuner/v1/ray/rollout/worker.py @@ -14,11 +14,11 @@ import ray import requests # type: ignore[import-untyped] from packaging.version import Version -from ray import ObjectRef +from ray import ObjectRef, cloudpickle from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from transformers import AutoTokenizer -from xtuner.v1.data_proto.rl_data import RLRolloutResponseItem, RolloutExtraInfo, RolloutState +from xtuner.v1.data_proto.rl_data import RLRolloutResponseItem, RolloutState from xtuner.v1.ray import find_master_addr_and_port from xtuner.v1.ray.base import AutoAcceleratorWorkers, SingleAcceleratorWorker from xtuner.v1.ray.config import RolloutConfig @@ -156,80 +156,6 @@ def init(self, dist_init_addr: str = ""): def _decode_routed_experts(self, routed_experts: Any) -> Any: return routed_experts - def _concat_partial_routed_experts( - self, - root_id: Any, - action_id: Any, - response: dict, - history_routed_experts: Any, - cur_routed_experts: Any, - ) -> Any: - assert (history_routed_experts.shape[0] - 1) > 0 and history_routed_experts.shape[ - 0 - ] - 1 <= cur_routed_experts.shape[0], ( - f"Existing routed_experts shape: {history_routed_experts.shape}, current routed_experts shape: {cur_routed_experts.shape}" - ) - init_cur_roued_experts = cur_routed_experts.shape[0] - cur_routed_experts = cur_routed_experts[history_routed_experts.shape[0] :, :, :] - concat_routed_experts = np.concatenate((history_routed_experts, cur_routed_experts), axis=0) - prompt_tokens = response["meta_info"].get("prompt_tokens", 0) - response_tokens = response["meta_info"].get("completion_tokens", 0) - assert concat_routed_experts.shape[0] == prompt_tokens + response_tokens - 1, ( - f"Routed experts shape {concat_routed_experts.shape[0]} does not match total tokens {prompt_tokens + response_tokens - 1}" - ) - self.logger.debug( - f"[{root_id}/{action_id}] Partial Rollout Stats: " - f"Tokens(prompt={prompt_tokens}, response={response_tokens}, total={prompt_tokens + response_tokens}) | " - f"Experts(exist={history_routed_experts.shape}, init_cur={init_cur_roued_experts}, cur={cur_routed_experts.shape}, concat={concat_routed_experts.shape})" - ) - return concat_routed_experts - - async def _handle_routed_experts_response( - self, - root_id: Any, - action_id: Any, - response: dict, - input_extra_info: RolloutExtraInfo, - extra_info: RolloutExtraInfo, - finish_reason: str, - ) -> None: - exist_history_routed_experts = ( - "routed_experts" in input_extra_info and input_extra_info["routed_experts"] is not None - ) - routed_experts = response["meta_info"].pop("routed_experts") # token[layer[expert]] - if routed_experts is not None and not exist_history_routed_experts: - routed_experts = self._decode_routed_experts(routed_experts) - assert not isinstance(routed_experts, str), ( - "String routed_experts keys must be handled by the backend-specific rollout worker." 
- ) - if not isinstance(routed_experts, ObjectRef): - routed_experts = ray.put(routed_experts) - extra_info["routed_experts"] = routed_experts - elif routed_experts is not None and exist_history_routed_experts: - routed_experts = self._decode_routed_experts(routed_experts) - if isinstance(routed_experts, ObjectRef): - cur_routed_experts = await routed_experts # n, layer, expert - ray.internal.free(routed_experts, local_only=False) - else: - cur_routed_experts = routed_experts - - history_routed_experts_ref = input_extra_info["routed_experts"] - assert isinstance(history_routed_experts_ref, ObjectRef), ( - "Base rollout worker expects history routed_experts to be a Ray ObjectRef." - ) - history_routed_experts = await history_routed_experts_ref # n, layer, expert - ray.internal.free(history_routed_experts_ref, local_only=False) - concat_routed_experts = self._concat_partial_routed_experts( - root_id, action_id, response, history_routed_experts, cur_routed_experts - ) - extra_info["routed_experts"] = ray.put(concat_routed_experts) - del history_routed_experts - del cur_routed_experts - else: - assert finish_reason == "abort", ( - f"routed_experts is None, but finish_reason is {finish_reason}, expected abort. response: {response}" - ) - def set_engine_rank_mesh_array(self, engine_rank_mesh_array: list[list[int]]): self.engine_rank_mesh_array = engine_rank_mesh_array @@ -625,7 +551,7 @@ async def _handle_non_stream_response( if "return_token_ids" in extra_params and extra_params["return_token_ids"]: last_logprobs: list[float] = [] try: - extra_info: RolloutExtraInfo = {} + extra_info = {} finish_reason = response["meta_info"]["finish_reason"]["type"] if finish_reason == "abort" and self.receive_abort_request.is_set() is False: self.receive_abort_request.set() @@ -648,14 +574,79 @@ async def _handle_non_stream_response( assert "routed_experts" in response["meta_info"], ( "enable_return_routed_experts is True, but routed_experts is not in meta_info" ) - await self._handle_routed_experts_response( - root_id=root_id, - action_id=action_id, - response=response, - input_extra_info=input_extra_info, - extra_info=extra_info, - finish_reason=finish_reason, + exist_history_routed_experts = ( + "routed_experts" in input_extra_info and input_extra_info["routed_experts"] is not None ) + routed_experts = response["meta_info"].pop("routed_experts") # token[layer[expert]] + if routed_experts is not None and not exist_history_routed_experts: + # Turn 1: materialize tensor inline. lmdeploy's _SHARED_STORE.get is + # destructive (pop+ray.get) so awaiting `decoded` already releases + # lmdeploy's plasma copy; we just carry the numpy through extra_info. + decoded = self._decode_routed_experts(routed_experts) + if isinstance(decoded, ObjectRef): + tensor = await decoded + else: + tensor = decoded + extra_info["routed_experts"] = tensor + elif routed_experts is not None and exist_history_routed_experts: + # Turn 2+: concat history with new decoded chunk. History comes inline + # as np.ndarray (new path) or via legacy store/ref types (compat). + decoded = self._decode_routed_experts(routed_experts) + if isinstance(decoded, ObjectRef): + cur_routed_experts = await decoded + else: + cur_routed_experts = decoded + + history = input_extra_info["routed_experts"] + if isinstance(history, np.ndarray): + # New path: inline numpy history. + history_routed_experts = history + elif isinstance(history, str): + # Legacy path: uuid key into RoutedExpertStore. 
Release after use
+ # (there's no retry path that re-consumes the same key in inline mode).
+ from xtuner.v1.ray.rollout.routed_expert_store import get_store as _legacy_get_store
+
+ legacy_store = _legacy_get_store()
+ history_ref = await legacy_store.get_ref.remote(history)
+ history_routed_experts = await history_ref
+ legacy_store.release.remote(history)
+ elif isinstance(history, ObjectRef):
+ history_routed_experts = await history
+ ray.internal.free([history], local_only=False)
+ elif isinstance(history, (bytes, bytearray)):
+ # Legacy path: controller.py serialized ref as base64(cloudpickle(...)).
+ history_ref = cloudpickle.loads(history)
+ history_routed_experts = await history_ref
+ ray.internal.free([history_ref], local_only=False)
+ else:
+ raise TypeError(f"Unexpected type for input_extra_info['routed_experts']: {type(history)}")
+ del input_extra_info
+
+ assert (history_routed_experts.shape[0] - 1) > 0 and history_routed_experts.shape[
+ 0
+ ] - 1 <= cur_routed_experts.shape[0], (
+ f"Existing routed_experts shape: {history_routed_experts.shape}, current routed_experts shape: {cur_routed_experts.shape}"
+ )
+ init_cur_routed_experts = cur_routed_experts.shape[0]
+ cur_routed_experts = cur_routed_experts[history_routed_experts.shape[0] :, :, :]
+ concat_routed_experts = np.concatenate((history_routed_experts, cur_routed_experts), axis=0)
+ prompt_tokens = response["meta_info"].get("prompt_tokens", 0)
+ response_tokens = response["meta_info"].get("completion_tokens", 0)
+ assert concat_routed_experts.shape[0] == prompt_tokens + response_tokens - 1, (
+ f"Routed experts shape {concat_routed_experts.shape[0]} does not match total tokens {prompt_tokens + response_tokens - 1}"
+ )
+ self.logger.debug(
+ f"[{root_id}/{action_id}] Partial Rollout Stats: "
+ f"Tokens(prompt={prompt_tokens}, response={response_tokens}, total={prompt_tokens + response_tokens}) | "
+ f"Experts(exist={history_routed_experts.shape}, init_cur={init_cur_routed_experts}, cur={cur_routed_experts.shape}, concat={concat_routed_experts.shape})"
+ )
+ extra_info["routed_experts"] = concat_routed_experts
+ del history_routed_experts
+ del cur_routed_experts
+ else:
+ assert finish_reason == "abort", (
+ f"routed_experts is None, but finish_reason is {finish_reason}, expected abort. response: {response}"
+ )
 # NOTE: When set return_token_ids = True, the response must contain valid token_ids/logprobs.
 # If not, we consider it as an invalid response and retry it.
 # NOTE: !!! When finish_reason is abort, some queries may not return token_ids or logprobs. !!!
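Aside: the slice-then-concat bookkeeping above is easy to mis-read; a toy numpy check of the invariant the asserts enforce, with illustrative shapes (layer=2, top-k=2, prompt_tokens=10, completion_tokens=6 on the second turn):

    import numpy as np

    history = np.zeros((9, 2, 2), dtype=np.int64)  # turn 1: prompt + response - 1 rows
    cur = np.ones((15, 2, 2), dtype=np.int64)      # turn 2: engine re-emits routing for the whole sequence so far

    new_suffix = cur[history.shape[0]:]            # keep only the freshly generated rows
    concat = np.concatenate((history, new_suffix), axis=0)
    assert concat.shape[0] == cur.shape[0] == 15   # == prompt_tokens + completion_tokens - 1

Taking the suffix from `cur` but the prefix from `history` keeps the rows recorded on earlier turns authoritative; presumably this guards against the new pass re-emitting a slightly different prefix that training would then silently replay.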
diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index a16cfc77e7..db7eba77d5 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Dict, Iterable, List, Sequence, TypeAlias, TypedDict, cast +import numpy as np import ray import requests import torch @@ -13,6 +14,7 @@ import tqdm from mmengine.runner import set_random_seed from pydantic import BaseModel, ConfigDict +from ray import cloudpickle from ray.actor import ActorClass, ActorProxy from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.tensor import DTensor @@ -36,6 +38,7 @@ from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo from xtuner.v1.ray.base import SingleAcceleratorWorker from xtuner.v1.ray.config import RolloutConfig +from xtuner.v1.ray.rollout.routed_expert_store import get_store from xtuner.v1.rl.base.loss import BaseRLLossContext from xtuner.v1.train.trainer import LoadCheckpointConfig from xtuner.v1.utils import ( @@ -402,7 +405,9 @@ def compute_ref_logprobs( return ref_logprobs_list def _add_rollout_routed_experts( - self, seq_ctx: SequenceContext, rollout_routed_experts: torch.Tensor | list[torch.Tensor | ray.ObjectRef] + self, + seq_ctx: SequenceContext, + rollout_routed_experts: torch.Tensor | list[torch.Tensor | ray.ObjectRef | np.ndarray | str | bytes], ): language_cfg = ( self.config.model_cfg.text_config @@ -411,6 +416,12 @@ def _add_rollout_routed_experts( ) to_free_routed_expert_refs: list[ray.ObjectRef] = [] + # Store keys to release AFTER consumption. Training is 1:1 — each + # rollout's final key appears in this train step exactly once — + # so releasing here is safe (no double-consume from training side). + # SP-mesh: multiple ranks see the same data; only rank-0 releases, + # gated by dist.barrier() so other ranks finish their ray.get first. + to_release_pin_keys: list[str] = [] if isinstance(rollout_routed_experts, list): # list[n,l,e] out_rollout_routed_expert = [] @@ -426,8 +437,47 @@ def _add_rollout_routed_experts( ), ) out_rollout_routed_expert.append(rollout_routed_experts_tensor) + elif isinstance(rollout_routed_expert, np.ndarray): + # Inline path: numpy array carried directly in extra_info. + rollout_routed_expert = torch.as_tensor(rollout_routed_expert, dtype=torch.long) + rollout_routed_expert = rollout_routed_expert.reshape( + -1, language_cfg.num_hidden_layers, language_cfg.num_experts_per_tok + ) + out_rollout_routed_expert.append(rollout_routed_expert) + elif isinstance(rollout_routed_expert, ray.ObjectRef): + # Inline + replay_buffer wrap: ray.put(numpy) → ObjectRef. Owner is the + # replay_buffer actor (long-lived), so ray.get is safe. + rollout_routed_expert_ref = rollout_routed_expert + rollout_routed_expert = ray.get(rollout_routed_expert_ref) + if self.sp_mesh is None or self.sp_mesh.size() == 1: + ray.internal.free([rollout_routed_expert_ref], local_only=False) + elif self.sp_mesh.get_local_rank() == 0: + to_free_routed_expert_refs.append(rollout_routed_expert_ref) + rollout_routed_expert = torch.as_tensor(rollout_routed_expert, dtype=torch.long) + rollout_routed_expert = rollout_routed_expert.reshape( + -1, language_cfg.num_hidden_layers, language_cfg.num_experts_per_tok + ) + out_rollout_routed_expert.append(rollout_routed_expert) + elif isinstance(rollout_routed_expert, str): + # Legacy path: uuid key → RoutedExpertStore. Owner is the + # store actor, so ray.get is reliable regardless of + # lmdeploy worker state. 
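+ # Two-hop fetch: get_ref returns an ObjectRef, and Ray does not
+ # auto-deref refs returned from tasks, so the first ray.get yields
+ # the inner ref and the second materializes the tensor.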
+ store = get_store() + pin_key = rollout_routed_expert + pin_ref = ray.get(store.get_ref.remote(pin_key)) + rollout_routed_expert = ray.get(pin_ref) + if self.sp_mesh is None or self.sp_mesh.size() == 1: + store.release.remote(pin_key) + elif self.sp_mesh.get_local_rank() == 0: + to_release_pin_keys.append(pin_key) + rollout_routed_expert = torch.as_tensor(rollout_routed_expert, dtype=torch.long) + rollout_routed_expert = rollout_routed_expert.reshape( + -1, language_cfg.num_hidden_layers, language_cfg.num_experts_per_tok + ) + out_rollout_routed_expert.append(rollout_routed_expert) else: - rollout_routed_expert_refs = rollout_routed_expert + # Legacy path: bytes (base64-decoded cloudpickle of ObjectRef). + rollout_routed_expert_refs = cloudpickle.loads(rollout_routed_expert) rollout_routed_expert = ray.get(rollout_routed_expert_refs) # free obj store explicitly if self.sp_mesh is None or self.sp_mesh.size() == 1: @@ -465,7 +515,10 @@ def _add_rollout_routed_experts( dist.barrier() for free_routed_expert_refs in to_free_routed_expert_refs: ray.internal.free(free_routed_expert_refs, local_only=False) + if to_release_pin_keys: + get_store().release_many.remote(to_release_pin_keys) del to_free_routed_expert_refs + del to_release_pin_keys @ray_method def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLogItem: diff --git a/xtuner/v1/train/agent_rl_trainer.py b/xtuner/v1/train/agent_rl_trainer.py index 16159957dc..714d782abc 100644 --- a/xtuner/v1/train/agent_rl_trainer.py +++ b/xtuner/v1/train/agent_rl_trainer.py @@ -23,7 +23,6 @@ from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig from xtuner.v1.ray.environment.lagent.tokenize import tokenize from xtuner.v1.ray.evaluator import Evaluator, EvaluatorConfig -from xtuner.v1.ray.rollout.lmdeploy import get_lmdeploy_routed_experts_ref from xtuner.v1.rl.base import WorkerConfig from xtuner.v1.rl.config.advantage import BaseAdvantageConfig, GRPOAdvantageConfig from xtuner.v1.train.trainer import LoadCheckpointConfig @@ -43,6 +42,24 @@ DEVICE_MODULE = get_torch_device_module() +_COMPACT_MSG_KEYS = ("reasoning_content", "thinking", "content", "tool_calls", "raw_content", "name", "tool_call_id") + + +def _compact_message(msg: dict) -> dict: + """Project a memory message dict to the fields worth inspecting in a + trajectory. + + Keeps role, reasoning (under either schema name), content, tool_calls, raw_content (for debugging parser failures), + and tool identifiers. Drops token ids / logprobs / session metadata that would bloat the jsonl. 
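+
+ Example (illustrative)::
+
+ >>> _compact_message({"role": "tool", "content": "ok", "name": "bash", "raw_content_ids": [1, 2]})
+ {'role': 'tool', 'content': 'ok', 'name': 'bash'}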
+ """ + out: dict = {"role": msg.get("role", "assistant")} + for key in _COMPACT_MSG_KEYS: + val = msg.get(key) + if val not in (None, "", [], {}): + out[key] = val + return out + + class AgentRLTrainerConfig(BaseModel): model_config = ConfigDict(extra="forbid") @@ -456,24 +473,15 @@ def _save_trajectories(self, data_groups, save_path, rollout_idx=None, is_eval: f.write("\n") for group in data_groups: for data in group: + messages = data.env.agent.extra_info.get("messages", []) entry = { - # "raw_prompt": data.data.extra_info["raw_prompt"], - "prompt": [ - { - "role": msg["role"], - "content": msg["raw_content"] if "raw_content" in msg else msg["content"], - } - for msg in data.env.agent.extra_info.get("messages", [])[:-1] - ], - "response": data.env.rollout.response, + "prompt": [_compact_message(msg) for msg in messages[:-1]], + "response": _compact_message(messages[-1]) if messages else None, "response_len": len(data.env.rollout.response_ids or []), - # "label": data.data.reward_model["ground_truth"], "reward": data.env.judger.reward["score"], - # "round": sum(msg['role'] == 'assistant' for msg in data.env.agent.extra_info['messages'][:-1]), - # "judger_response": data.env.judger.extra_info, } - # if "completions" in data.env.agent.extra_info: - # entry["completions"] = data.env.agent.extra_info["completions"] + if is_eval: + entry["daemon_log"] = data.env.agent.extra_info.get("daemon_log", "") json.dump(entry, f, ensure_ascii=False, indent=2) f.write("\n") @@ -619,10 +627,7 @@ def _extract_score(value, default=0.0): ), f"{rollout_logprobs.size()} vs {shifted_labels.size()}" seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu") - routed_experts = inputs["routed_experts"] - if isinstance(routed_experts, str): - routed_experts = get_lmdeploy_routed_experts_ref(routed_experts) - seq_ctx.rollout_routed_experts = routed_experts + seq_ctx.rollout_routed_experts = inputs["routed_experts"] data_batches.append( dict( seq_ctx=seq_ctx, diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 7c80065120..986f7886f3 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -26,7 +26,6 @@ from xtuner.v1.ray.environment import SingleTurnEnvironment, SingleTurnEnvironmentProxy from xtuner.v1.ray.evaluator import Evaluator, EvaluatorConfig from xtuner.v1.ray.judger import JudgerConfig -from xtuner.v1.ray.rollout.lmdeploy import get_lmdeploy_routed_experts_ref from xtuner.v1.rl.base import ( TrainingController, TrainingControllerProxy, @@ -554,7 +553,26 @@ def _rollout_step(self, rollout_idx: int, step_timer_dict: dict) -> RolloutInfo: """Performs a single rollout step to generate experience.""" with timer("generation", step_timer_dict): ray.get(self._rollout_env_controller.update_active_workers.remote()) - dataflow_result = ray.get(self._rollout_dataflow.run.remote()) + # Ultimate fallback: without a timeout, dataflow.run can hang + # forever when inner rollout workers die and their pending tasks + # never resolve. 12h is deliberately loose — normal long-tail + # rollout batches run ~5h, so anything past 12h is almost + # certainly a real deadlock worth surfacing to the driver. + dataflow_ref = self._rollout_dataflow.run.remote() + try: + dataflow_result = ray.get(dataflow_ref, timeout=12 * 3600) + except ray.exceptions.GetTimeoutError: + self.logger.error( + f"rollout_idx {rollout_idx}: dataflow.run exceeded 12h, " + f"likely a stuck rollout/sandbox. cancelling task and raising." 
+ ) + try: + ray.cancel(dataflow_ref, force=True) + except Exception as exc: + self.logger.warning(f"ray.cancel of dataflow task failed: {exc}") + raise RuntimeError( + f"dataflow.run hung for 12h at rollout_idx={rollout_idx}; check ray dashboard for dead actors" + ) from None if XTUNER_DETERMINISTIC: data_groups, multimodal_train_infos = self._sort_rollout_outputs( @@ -822,8 +840,6 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf if "routed_experts" in group[i].env.rollout.extra_info: routed_experts = group[i].env.rollout.extra_info.pop("routed_experts") # n,layer*expert - if isinstance(routed_experts, str): - routed_experts = get_lmdeploy_routed_experts_ref(routed_experts) seq_ctx.rollout_routed_experts = routed_experts # n,layer,expert data_batches.append(data_dict)