diff --git a/examples/v1/config/agent_rl_qwen35_30b_grpo.py b/examples/v1/config/agent_rl_qwen35_30b_grpo.py
index f94ebeff78..a26176240f 100644
--- a/examples/v1/config/agent_rl_qwen35_30b_grpo.py
+++ b/examples/v1/config/agent_rl_qwen35_30b_grpo.py
@@ -48,7 +48,7 @@
max_concurrent_groups = 512
max_prompt_length = 4096
-pack_max_length = 68 * 1024
+pack_max_length = 256 * 1024
max_response_length = 64 * 1024
train_ep_size = 1
@@ -66,7 +66,7 @@
lr = 1e-6
hf_interval = 5
total_epochs = 10
-sp_size = 1
+sp_size = 4
# evaluation settings
enable_evaluate = False
enable_initial_evaluate = False
@@ -218,8 +218,14 @@ def convert_rollout_tractory_to_train(env, group_data_items):
model_cfg.text_config.z_loss_cfg = None
model_cfg.text_config.balancing_loss_cfg = None
model_cfg.text_config.freeze_routers = True
+# model_cfg.text_config.mtp_config = MTPConfig(
+# num_layers=1,
+# loss_scaling_factor=1.0,
+# detach_mtp_lm_head_weight=True,
+# detach_mtp_inputs=True,
+# share_weights=False,
+# )
model_cfg.text_config.vocab_size = 251392
-# model_cfg.text_config.embed_grad_max_token_id = 251173
optim_cfg = AdamWConfig(
lr=lr,
diff --git a/examples/v1/config/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-agentrl-tb0429rc0.py b/examples/v1/config/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-agentrl-tb0429rc0.py
new file mode 100644
index 0000000000..2c5f3b521e
--- /dev/null
+++ b/examples/v1/config/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-agentrl-tb0429rc0.py
@@ -0,0 +1,788 @@
+import os
+
+os.environ["XTUNER_USE_LMDEPLOY"] = "1"
+
+import hashlib
+import json
+
+# os.environ["HF_HOME"] = "/mnt/shared-storage-user/liukuikun/.cache/huggingface"
+# os.environ["TRANSFORMERS_OFFLINE"] = "1"
+from copy import deepcopy
+from functools import partial
+
+import ray
+from lagent.actions.mcp_client import AsyncMCPClient
+from lagent.actions.web_visitor import WebVisitor
+from lagent.agents.fc_agent import FunctionCallAgent, get_tool_prompt
+from ray.util.placement_group import placement_group
+
+from projects.claw_bench.claw_tokenize_fn import RLClawTokenizeFnConfig
+from projects.tb2_eval.tb2_eval_tokenize_fn import RLTB2EvalTokenizeFnConfig
+from projects.tb2_rl.tb2_rl_tokenize_fn import RLTB2RLTokenizeFnConfig
+from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
+from xtuner.v1.data_proto.rl_data import (
+ RLAgentDataItem,
+ RLDataFlowItem,
+ RLJudgerResponseItem,
+ RLRolloutResponseItem,
+ RolloutState,
+ SampleParams,
+ update_dataflow_item,
+)
+from xtuner.v1.datasets import DatasetConfig, Qwen3VLTokenizeFnConfig
+from xtuner.v1.datasets.config import DataloaderConfig
+from xtuner.v1.model import Qwen3_5_VLMoE35BA3Config
+from xtuner.v1.ray.base import (
+ AcceleratorResourcesConfig,
+ AutoAcceleratorWorkers,
+ CPUResourcesConfig,
+)
+from xtuner.v1.ray.config.worker import RolloutConfig
+from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig
+from xtuner.v1.ray.environment.agent_env import AgentEnvironment
+from xtuner.v1.ray.environment.composed_env import ComposedEnvironment
+from xtuner.v1.ray.environment.install_agent_env import InstallAgentEnvironment
+from xtuner.v1.ray.environment.lagent.agents import (
+ AsyncTokenInOutAgent,
+ EnvAgent,
+ JudgerWrapper,
+ finish_condition_func,
+)
+from xtuner.v1.ray.environment.lagent.llms.controller_wrapper import ControllerWrapper
+from xtuner.v1.ray.environment.lagent.schema import AgentMessage
+from xtuner.v1.ray.evaluator import EvaluatorConfig
+from xtuner.v1.ray.judger.compass_verifier_v2 import CompassVerifierV2Config
+from xtuner.v1.ray.judger.controller import JudgerConfig
+from xtuner.v1.ray.rollout import RolloutController
+from xtuner.v1.rl.base import WorkerConfig
+from xtuner.v1.rl.base.rollout_is import RolloutImportanceSampling
+from xtuner.v1.rl.grpo import GRPOLossConfig
+from xtuner.v1.train.agent_rl_trainer import AgentRLTrainerConfig
+from xtuner.v1.train.trainer import LoadCheckpointConfig
+from xtuner.v1.utils.compute_metric import compute_metric
+
+if not ray.is_initialized():
+ ray.init(ignore_reinit_error=True, runtime_env={"env_vars": {"RAY_DEBUG_POST_MORTEM": "0"}})
+
+experimental_name = os.path.basename(__file__).split(".py")[0]
+# base_work_dir = "/mnt/shared-storage-user/llmit1/user/wangziyi/exp/mindcopilot_rl/work_dirs"
+base_work_dir = '/mnt/shared-storage-user/llmit1/user/liukuikun/delivery/interns2_preview_0430rc1'
+work_dir = os.path.join(base_work_dir, experimental_name)
+
+model_name = os.environ["RL_LLM_MODEL"]
+# model_path = "/mnt/shared-storage-user/llmit1/user/liujiangning/exp/s2_preview/agent_rl/s2-preview-thinker_sft_0228b_rl0312rc3_fix_klmismatch/20260327042049/hf-40"
+model_path = '/mnt/shared-storage-user/llmit1/user/wangziyi/exp/mindcopilot_rl/work_dirs/ckpt/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-badword-mtp4/20260426021137/hf-140'
+stop_word = "<|im_end|>"
+
+# basic settings
+global_batch_size = 128
+prompt_repeat_k = 8
+max_concurrent_groups = 512
+
+max_prompt_length = 4096
+pack_max_length = 130 * 1024
+max_response_length = 128 * 1024
+
+train_ep_size = 1
+train_sp_size = 4
+rollout_tp_size = 4
+rollout_ep_size = 1
+fp32_lm_head = True
+enable_float8_rollout = False
+rollout_max_batch_size = 128
+max_prefill_token_num = 1024
+enable_return_routed_experts = True
+enable_partial_rollout = False
+staleness_threshold = 0.0
+tail_batch_candidate_steps = 0
+auto_resume = True
+skip_load_weights = True
+lr = 1e-6
+train_optimizer_steps = 2 # mini batch steps
+hf_interval = 10
+total_epochs = 10
+# evaluation settings
+enable_evaluate = True
+enable_initial_evaluate = True
+evaluate_step = 5
+
+# agent setting
+max_turn = 5  # maximum number of dialogue turns
+lower_tool_turn_bound = 3  # penalize samples whose tool-call turns fall below this value; None disables the penalty
+enable_repeated_tool_call_penalty = True  # penalize samples that repeat the same tool call
+enable_no_thinking_penalty = False
+max_tool_response_length = 8192
+
+
+# 1. resources
+resources = AcceleratorResourcesConfig(
+ accelerator="GPU",
+ num_workers=64,
+ num_cpus_per_worker=12,
+ cpu_memory_per_worker=16 * 1024**3, # 16 GB
+)
+judger_cpu_resources = CPUResourcesConfig.from_total(total_cpus=16, num_workers=16, total_memory=64 * 1024**3)
+
+# 2. rollout
+rollout_config = RolloutConfig(
+ env=experimental_name,
+ device=resources.accelerator,
+ model_name=model_name,
+ model_path=model_path,
+ dtype="bfloat16",
+ tensor_parallel_size=rollout_tp_size,
+ expert_parallel_size=rollout_ep_size,
+ gpu_memory_utilization=0.7,
+ enable_float8=enable_float8_rollout,
+ skip_load_weights=skip_load_weights,
+ context_length=max_response_length,
+ rollout_max_batch_size_per_instance=rollout_max_batch_size,
+ chunked_prefill_size=4096,
+ allow_over_concurrency_ratio=2.0,
+ rollout_timeout=36000,
+ enable_return_routed_experts=enable_return_routed_experts,
+ # return_routed_experts_key=True,
+ # max_prefill_token_num=max_prefill_token_num,
+ extra_rollout_config=dict(lmdeploy_log_level="ERROR", lmdeploy_uvicorn_log_level="ERROR"),
+ fp32_lm_head=fp32_lm_head,
+)
+
+# sampling params
+training_sample_params = SampleParams(
+ max_tokens=max_response_length, top_k=0, top_p=0.999, temperature=1.0, min_tokens=0
+)
+evaluation_sample_params = deepcopy(training_sample_params)
+evaluation_sample_params.temperature = 0.8
+# evaluation_sample_params.max_tokens = max_response_length
+
+tokenize_fn_cfg = Qwen3VLTokenizeFnConfig(
+ processor_path=model_path,
+ min_pixels=None,
+ # max_pixels=None,
+ # max_pixels=2097152,
+ video_min_total_pixels=None,
+ video_max_total_pixels=None,
+ video_min_frames=None,
+ video_max_frames=None,
+ fps=None,
+ rand_video_max_frames=24,
+ add_vision_id=True,
+ system_message=None,
+ hash=None,
+ enable_3d_rope=False,
+ oss_loader_cfg=None,
+ debug=True,
+ oss_time_log_thr=10,
+)
+
+# 2. dataset
+from intern_s1_delivery.dataset.xpuyu_dataset_vl import RLTokenizeFnConfig
+from intern_s1_delivery.dataset.xpuyu_dataset_vl import (
+ parse_xpuyu_json_cfg as parse_xpuyu_json_cfg_vl,
+)
+
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+
+
+def parse_xpuyu_json_cfg(path, max_prompt_length):
+ with open(path, "r") as f:
+ json_cfg = json.load(f)
+ converted_cfg = []
+ for ds_name, ds_cfg in json_cfg.items():
+ annotation = ds_cfg["annotation"]
+ tokenize_fn = ds_cfg["tokenize_fn"]
+
+        if tokenize_fn == "RLClawTokenizeFnConfig":
+            tokenize_fn_cfg = RLClawTokenizeFnConfig
+        elif tokenize_fn == "RLTB2RLTokenizeFnConfig":
+            tokenize_fn_cfg = RLTB2RLTokenizeFnConfig
+        elif tokenize_fn == "RLTB2EvalTokenizeFnConfig":
+            tokenize_fn_cfg = RLTB2EvalTokenizeFnConfig
+        else:
+            raise ValueError(f"Unknown tokenize_fn: {tokenize_fn}")
+
+ if isinstance(annotation, str):
+ annotation = [annotation]
+ for ann in annotation:
+ converted_cfg.append(
+ {
+ "dataset": DatasetConfig(
+ name=ds_name,
+ anno_path=ann,
+ sample_ratio=ds_cfg["sample_ratio"],
+ class_name='JsonlDataset',
+ ),
+ "tokenize_fn": tokenize_fn_cfg(
+ root_path=ds_cfg.get("root_path", None), max_length=max_prompt_length
+ ),
+ }
+ )
+ return converted_cfg
+
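+# Sketch of the meta.json layout this parser expects, inferred from the
+# parsing logic above (the dataset name and paths are illustrative):
+# {
+#     "tb2-rl": {
+#         "annotation": "/path/to/tasks.jsonl",      # str or list of str
+#         "tokenize_fn": "RLTB2RLTokenizeFnConfig",
+#         "sample_ratio": 1.0,
+#         "root_path": "/path/to/task_root"          # optional
+#     }
+# }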
+
+TRAIN_DATA_PATH = (
+ '/mnt/shared-storage-user/llmit1/user/liujiangning/data/s1_1_rl_delivery_agent/exp_rl/rl_data_260126.json'
+)
+TEST_DATA_PATH = (
+ '/mnt/shared-storage-user/llmit/user/wangziyi/projs/crg_rl_projects/data/deep_research_data/val_v2.jsonl'
+)
+
+# data_path = os.environ["DATA_PATH"]
+data_path = TRAIN_DATA_PATH
+# eval_data_path = os.environ["EVAL_DATA_PATH"]
+eval_data_path = TEST_DATA_PATH
+train_dataset_cfg = parse_xpuyu_json_cfg(
+ '/mnt/shared-storage-user/llmit/user/wangziyi/projs/xtuner_agent_dev/examples/demo_data/agent_dev/tb2_rl/meta.json',
+ max_prompt_length,
+)
+eval_dataset_cfg = (
+ [
+ {
+ "dataset": DatasetConfig(
+ name="tb2-eval",
+ anno_path="/mnt/shared-storage-user/llmit1/user/liukuikun/delivery/data/tb2_eval_tasks.jsonl",
+ sample_ratio=1.0,
+ media_root=None,
+ class_name='JsonlDataset',
+ ),
+ "tokenize_fn": RLTB2EvalTokenizeFnConfig(
+ root_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/terminalbench2-harbor-p-cluster/terminal-bench-2",
+ max_length=max_prompt_length,
+ ),
+ },
+ ]
+ if enable_evaluate
+ else []
+)
+dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none")
+
+
+# 3. judger
+judger_cfg = JudgerConfig(
+ reward_judger_configs=[
+ CompassVerifierV2Config(
+ hosts=[
+ "10.102.251.61:23333",
+ "10.102.251.61:23334",
+ "10.102.251.61:23335",
+ "10.102.251.61:23336",
+ "10.102.251.61:23337",
+ "10.102.251.61:23338",
+ "10.102.251.61:23339",
+ "10.102.251.61:23340",
+ "10.102.216.52:23333",
+ "10.102.216.52:23334",
+ "10.102.216.52:23335",
+ "10.102.216.52:23336",
+ "10.102.216.52:23337",
+ "10.102.216.52:23338",
+ "10.102.216.52:23339",
+ "10.102.216.52:23340",
+ "10.102.238.19:23333",
+ "10.102.238.19:23334",
+ "10.102.238.19:23335",
+ "10.102.238.19:23336",
+ "10.102.238.19:23337",
+ "10.102.238.19:23338",
+ "10.102.238.19:23339",
+ "10.102.238.19:23340",
+ "10.102.239.68:23333",
+ "10.102.239.68:23334",
+ "10.102.239.68:23335",
+ "10.102.239.68:23336",
+ "10.102.239.68:23337",
+ "10.102.239.68:23338",
+ "10.102.239.68:23339",
+ "10.102.239.68:23340",
+ ]
+ )
+ ]
+)
+
+from xtuner.v1.ray.environment.lagent.parsers import Qwen3_5FunctionCallParser
+
+
+def prepare_agent_inputs_for_search(env, group_data_item: RLDataFlowItem):
+ env_agent = group_data_item.env.agent.extra_info['agent'].env_agent
+ user_prompt = group_data_item.data.messages[-1]['content']
+ env_message = AgentMessage(
+ sender="env", content=user_prompt, uid=hashlib.md5(user_prompt.encode('utf-8')).hexdigest()
+ )
+ if not env_agent.memory.get_memory():
+ set_env_message = AgentMessage(sender="env", content=group_data_item)
+        if env_agent.memory is not None:
+            env_agent.memory.add(set_env_message)
+ return (env_message,)
+
+
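+# The finished agent run leaves its trajectory in each item's rollout
+# extra_info; pull out the final-turn response (text, token ids, logprobs)
+# and the env reward, and write them back into the dataflow items so the
+# trainer can consume them.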
+def convert_rollout_tractory_to_train_for_search(env, group_data_items):
+ agent_data_items, rollout_response_items, judger_response_items = [], [], []
+ for i in range(len(group_data_items)):
+ history = group_data_items[i].env.rollout.extra_info['agent_state_dict']['policy_agent.memory']
+ env_history = group_data_items[i].env.rollout.extra_info['agent_state_dict']['env_agent.memory']
+ messages = group_data_items[i].env.rollout.extra_info['agent_message_dict']['policy_agent.messages']
+ agent_data_items.append(RLAgentDataItem(extra_info=dict(messages=messages, state={"history": history})))
+ rollout_response_items.append(
+ RLRolloutResponseItem(
+ response=history[-1]['raw_content'],
+ response_ids=history[-1]['raw_content_ids'],
+ logprobs=history[-1]['raw_content_logprobs'],
+ state=RolloutState.COMPLETED,
+ )
+ )
+ judger_response_items.append(RLJudgerResponseItem(reward=dict(score=env_history[-1]['reward'])))
+ group_data_items = update_dataflow_item(group_data_items, "env.agent", agent_data_items)
+ group_data_items = update_dataflow_item(group_data_items, "env.rollout", rollout_response_items)
+ group_data_items = update_dataflow_item(group_data_items, "env.judger", judger_response_items)
+ return group_data_items
+
+
+def prepare_agent_inputs_for_tb2rl(env, group_data_item: RLDataFlowItem):
+ return group_data_item
+
+
+def convert_rollout_tractory_to_train_for_tb2rl(env, group_data_items):
+ agent_data_items, rollout_response_items, judger_response_items = [], [], []
+ for i in range(len(group_data_items)):
+ messages = group_data_items[i].env.agent.extra_info['message_dict']['policy_agent.messages']
+ tools = group_data_items[i].env.agent.extra_info['message_dict'].get('policy_agent.tools')
+ agent_data_items.append(RLAgentDataItem(extra_info=dict(messages=messages, tools=tools)))
+ rollout_response_items.append(
+ RLRolloutResponseItem(
+ response=messages[-1]['raw_content'],
+ response_ids=messages[-1]['raw_content_ids'],
+ logprobs=messages[-1]['raw_content_logprobs'],
+ state=RolloutState.COMPLETED,
+ )
+ )
+ reward_payload = group_data_items[i].env.judger.extra_info['total']
+ judger_response_items.append(RLJudgerResponseItem(reward=dict(score=reward_payload)))
+ group_data_items = update_dataflow_item(group_data_items, "env.agent", agent_data_items)
+ group_data_items = update_dataflow_item(group_data_items, "env.rollout", rollout_response_items)
+ group_data_items = update_dataflow_item(group_data_items, "env.judger", judger_response_items)
+ # breakpoint()
+ return group_data_items
+
+
+pg = AutoAcceleratorWorkers.build_placement_group(resources)
+rollout_controller = ray.remote(max_concurrency=1000)(RolloutController).remote(rollout_config, pg)
+load_checkpoint_cfg = LoadCheckpointConfig(load_optimizer_states=False, load_optimizer_args=False)
+
+search_tool = dict(
+ type=AsyncMCPClient,
+ name='SerperSearch',
+ server_type='http',
+ rate_limit=500.0,
+ # max_concurrency=128,
+ url=[
+ 'http://10.102.103.157:8091/mcp',
+ 'http://10.102.103.155:8096/mcp',
+ 'http://10.102.103.155:8092/mcp',
+ 'http://10.102.103.155:8095/mcp',
+ 'http://10.102.103.155:8097/mcp',
+ 'http://10.102.103.155:8098/mcp',
+ 'http://10.102.103.155:8094/mcp',
+ 'http://10.102.103.155:8093/mcp',
+ ],
+)
+browse_tool = dict(
+ type=AsyncMCPClient,
+ name='JinaBrowse',
+ server_type='http',
+ rate_limit=100.0,
+ max_concurrency=40,
+ url=[
+ 'http://10.102.103.155:8104/mcp',
+ 'http://10.102.103.155:8100/mcp',
+ 'http://10.102.103.148:8101/mcp',
+ 'http://10.102.103.157:8105/mcp',
+ 'http://10.102.103.155:8099/mcp',
+ 'http://10.102.103.155:8102/mcp',
+ 'http://10.102.103.155:8103/mcp',
+ 'http://10.102.103.155:8106/mcp',
+ ],
+)
+visit_tool = dict(
+ type=WebVisitor,
+ browse_tool=dict(
+ type=AsyncMCPClient,
+ name='JinaBrowse',
+ server_type='http',
+ rate_limit=100.0,
+ max_concurrency=40,
+ url=[
+ 'http://10.102.103.155:8104/mcp',
+ 'http://10.102.103.155:8100/mcp',
+ 'http://10.102.103.148:8101/mcp',
+ 'http://10.102.103.157:8105/mcp',
+ 'http://10.102.103.155:8099/mcp',
+ 'http://10.102.103.155:8102/mcp',
+ 'http://10.102.103.155:8103/mcp',
+ 'http://10.102.103.155:8106/mcp',
+ ],
+ ),
+ llm=dict(
+ type=ControllerWrapper,
+ rollout_controller=rollout_controller,
+ sample_params=SampleParams(max_tokens=max_response_length),
+ tool_call_parser=Qwen3_5FunctionCallParser(),
+ ),
+ truncate_browse_response_length=60000,
+ tokenizer_path=model_path,
+)
+
+tool_template = """# Tools
+
+You have access to the following functions:
+
+
+{tools}
+
+
+If you choose to call a function ONLY reply in the following format with NO suffix:
+
+<tool_call>
+<function=example_function_name>
+<parameter=example_parameter_1>
+value_1
+</parameter>
+<parameter=example_parameter_2>
+This is the value for the second parameter
+that can span
+multiple lines
+</parameter>
+</function>
+</tool_call>
+
+Reminder:
+- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
+- Required parameters MUST be specified
+- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
+"""
+
+search_browse_tool_prompt = get_tool_prompt([search_tool, browse_tool], template=tool_template)
+search_visit_tool_prompt = get_tool_prompt([search_tool, visit_tool], template=tool_template)
+
+train_agent_with_search_browse = dict(
+ type=FunctionCallAgent,
+ policy_agent=dict(
+ type=AsyncTokenInOutAgent,
+ llm=dict(
+ type=ControllerWrapper,
+ rollout_controller=rollout_controller,
+ sample_params=SampleParams(max_tokens=max_response_length),
+ tool_call_parser=Qwen3_5FunctionCallParser(),
+ ),
+ template=search_browse_tool_prompt,
+ ),
+ env_agent=dict(
+ type=EnvAgent,
+ actions=[search_tool, browse_tool],
+ judger=dict(
+ type=JudgerWrapper,
+ judger_cfg=judger_cfg,
+ placement_group=ray.get(
+ placement_group(bundles=[{"CPU": 1, "memory": 1024**3}], strategy="PACK").ready(), timeout=30
+ ),
+ ),
+ max_turn=max_turn,
+ lower_tool_turn_bound=lower_tool_turn_bound,
+ enable_repeated_tool_call_penalty=enable_repeated_tool_call_penalty,
+ enable_no_thinking_penalty=enable_no_thinking_penalty,
+ max_tool_response_length=max_tool_response_length,
+ ),
+ finish_condition=finish_condition_func,
+ initialize_input=False,
+)
+train_agent_with_search_visit = dict(
+ type=FunctionCallAgent,
+ policy_agent=dict(
+ type=AsyncTokenInOutAgent,
+ llm=dict(
+ type=ControllerWrapper,
+ rollout_controller=rollout_controller,
+ sample_params=SampleParams(max_tokens=max_response_length),
+ tool_call_parser=Qwen3_5FunctionCallParser(),
+ ),
+ template=search_visit_tool_prompt,
+ ),
+ env_agent=dict(
+ type=EnvAgent,
+ actions=[search_tool, visit_tool],
+ judger=dict(
+ type=JudgerWrapper,
+ judger_cfg=judger_cfg,
+ placement_group=ray.get(
+ placement_group(bundles=[{"CPU": 1, "memory": 1024**3}], strategy="PACK").ready(), timeout=30
+ ),
+ ),
+ max_turn=max_turn,
+ lower_tool_turn_bound=lower_tool_turn_bound,
+ enable_repeated_tool_call_penalty=enable_repeated_tool_call_penalty,
+ enable_no_thinking_penalty=enable_no_thinking_penalty,
+ max_tool_response_length=max_tool_response_length,
+ ),
+ finish_condition=finish_condition_func,
+ initialize_input=False,
+)
+eval_agent = dict(
+ type=FunctionCallAgent,
+ policy_agent=dict(
+ type=AsyncTokenInOutAgent,
+ llm=dict(
+ type=ControllerWrapper,
+ rollout_controller=rollout_controller,
+ sample_params=SampleParams(max_tokens=max_response_length),
+ tool_call_parser=Qwen3_5FunctionCallParser(),
+ ),
+ template=search_visit_tool_prompt,
+ ),
+ env_agent=dict(
+ type=EnvAgent,
+ actions=[search_tool, visit_tool],
+ judger=dict(
+ type=JudgerWrapper,
+ judger_cfg=judger_cfg,
+ placement_group=ray.get(
+ placement_group(bundles=[{"CPU": 1, "memory": 1024**3}], strategy="PACK").ready(), timeout=30
+ ),
+ ),
+ max_turn=max_turn,
+ lower_tool_turn_bound=None,
+ enable_repeated_tool_call_penalty=False,
+ enable_no_thinking_penalty=False,
+ max_tool_response_length=max_tool_response_length,
+ ),
+ finish_condition=finish_condition_func,
+ initialize_input=False,
+)
+
+
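+# Deterministic hash bucket: the same key always maps to the same bucket, so
+# repeated rollouts of one prompt are routed to the same tool configuration.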
+def bucket(key: str, n: int) -> int:
+ hashkey = hashlib.md5(key.encode('utf-8')).hexdigest()
+ return sum(ord(c) for c in hashkey) % n
+
+
+def rollout_env_router_fn(item: RLDataFlowItem):
+ if item.data.extra_info.get('origin_data_source', '').startswith('gaia') or item.data.extra_info.get(
+ 'origin_data_source'
+ ) in [
+ 'BrowseComp-ZH',
+ 'HLE',
+ 'browsecomp',
+ ]:
+ return 'eval'
+
+ if 'claw-bench' in item.data.data_source:
+ return 'train_clawbench'
+ if 'tb2-rl' in item.data.data_source:
+ return 'train_tb2rl'
+ if 'tb2-eval' in item.data.data_source:
+ return 'eval_tb2eval'
+
+ match bucket(item.data.messages[-1]['content'], 2):
+ case 0:
+ return 'train_agent_with_search_browse'
+ case 1:
+ return 'train_agent_with_search_visit'
+
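+# Routing summary: known eval sources go to the eval agent, tb2 data goes to
+# the install-based environments, and everything else is split 50/50 between
+# the search+browse and search+visit tool setups by prompt hash. Note that
+# 'train_clawbench' has no matching entry in environment_config below.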
+
+environment_config = dict(
+ type=ComposedEnvironment,
+ environment=experimental_name,
+ rollout_controller=rollout_controller,
+ environments={
+ 'train_agent_with_search_browse': dict(
+ type=AgentEnvironment,
+ environment='train_agent_with_search_browse',
+ agent_cfg=train_agent_with_search_browse,
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs_for_search,
+ postprocess_func=convert_rollout_tractory_to_train_for_search,
+ ),
+ 'train_agent_with_search_visit': dict(
+ type=AgentEnvironment,
+ environment='train_agent_with_search_visit',
+ agent_cfg=train_agent_with_search_visit,
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs_for_search,
+ postprocess_func=convert_rollout_tractory_to_train_for_search,
+ ),
+ 'eval': dict(
+ type=AgentEnvironment,
+ environment='eval',
+ agent_cfg=eval_agent,
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs_for_search,
+ postprocess_func=convert_rollout_tractory_to_train_for_search,
+ ),
+ 'train_tb2rl': dict(
+ type=InstallAgentEnvironment,
+ environment='train_tb2rl',
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs_for_tb2rl,
+ postprocess_func=convert_rollout_tractory_to_train_for_tb2rl,
+ ),
+ 'eval_tb2eval': dict(
+ type=InstallAgentEnvironment,
+ environment='eval_tb2eval',
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs_for_tb2rl,
+ postprocess_func=convert_rollout_tractory_to_train_for_tb2rl,
+ ),
+ },
+ router=rollout_env_router_fn,
+)
+
+# 4. dataflow and evaluator
+dataflow_config = DataFlowConfig(
+ env=experimental_name,
+ max_concurrent=max_concurrent_groups,
+ enable_partial_rollout=enable_partial_rollout,
+ tail_batch_candidate_steps=tail_batch_candidate_steps,
+ staleness_threshold=staleness_threshold,
+ max_retry_times=3,
+ prompt_repeat_k=prompt_repeat_k,
+ global_batch_size=global_batch_size,
+ sample_params=training_sample_params,
+)
+
+evaluator_cfg = (
+ EvaluatorConfig(
+ enable_evaluate=enable_evaluate,
+ enable_initial_evaluate=enable_initial_evaluate,
+ dataset_cfg=eval_dataset_cfg,
+ tokenizer=model_path,
+ eval_sample_ratio=1.0,
+ evaluate_step=evaluate_step,
+ compute_metric_func=partial(
+ compute_metric,
+ source_normalizer={
+ 'miroRL': 'websearch',
+ 'musique': 'websearch',
+ 'websailor': 'websearch',
+ 'webdancer': 'websearch',
+ 'gaia-level1': ('gaia-level1', 'gaia'),
+ 'gaia-level2': ('gaia-level2', 'gaia'),
+ 'gaia-level3': ('gaia-level3', 'gaia'),
+ },
+ ),
+ sample_params=evaluation_sample_params,
+ max_concurrent=8192,
+ )
+ if enable_evaluate
+ else None
+)
+
+
+def group_sample_filter_func(group_samples):
+    # drop samples whose rollout failed to return response ids
+    group_samples = [s for s in group_samples if s.env.rollout.response_ids is not None]
+
+    # drop groups where every sample got the same reward (e.g. all correct or
+    # all wrong): such groups carry zero advantage signal
+ rewards = [d.env.judger.reward["score"] for d in group_samples]
+ if len(set(rewards)) == 1:
+ print(f"filter all same reward sample: {rewards}")
+ return []
+ return group_samples
+
+
+replay_buffer_cfg = ReplayBufferConfig(
+ dataset_cfg=train_dataset_cfg,
+ dataloader_cfg=dataloader_config,
+ tokenizer=model_path,
+ # postprocessor_func=group_sample_filter_func,
+)
+
+# 5. Train worker
+model_cfg = Qwen3_5_VLMoE35BA3Config(
+ freeze_vision=True,
+ freeze_projector=True,
+)
+model_cfg.float8_cfg = None
+model_cfg.text_config.ep_size = 1
+model_cfg.text_config.z_loss_cfg = None
+model_cfg.text_config.balancing_loss_cfg = None
+model_cfg.text_config.freeze_routers = True
+model_cfg.text_config.vocab_size = 251392
+# model_cfg.text_config.embed_grad_max_token_id = 251173
+
+optim_cfg = AdamWConfig(
+ lr=lr,
+ betas=(0.9, 0.95),
+ max_grad_norm=1.0,
+ weight_decay=0.1,
+ foreach=False,
+ skip_grad_norm_threshold=0.9,
+ eps=1e-15,
+)
+loss_cfg = GRPOLossConfig(
+ policy_loss_cfg=dict(
+ cliprange_high=0.2,
+ cliprange_low=0.2,
+ loss_type="intern_s1_delivery.modules.pg_loss.pg_loss_fn",
+ clip_ratio_c=5.0,
+ log_prob_diff_min=-20.0,
+ log_prob_diff_max=20.0,
+ ),
+ ignore_idx=-100,
+ use_kl_loss=False,
+ kl_loss_coef=0.0,
+ kl_loss_type="low_var_kl",
+ mode="chunk",
+ chunk_size=512,
+ rollout_is=RolloutImportanceSampling(
+ rollout_is_level="token",
+ rollout_is_mode="both",
+ rollout_is_threshold=(5, 0),
+ rollout_is_mask_threshold=(5, 0.5),
+ rollout_is_veto_threshold=(20, 0),
+ ),
+)
+lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=lr)
+fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=train_ep_size)
+train_worker_cfg: WorkerConfig = WorkerConfig(
+ model_cfg=model_cfg,
+ load_from=model_path,
+ optim_cfg=optim_cfg,
+ loss_cfg=loss_cfg,
+ lr_cfg=lr_cfg,
+ fsdp_cfg=fsdp_cfg,
+ sp_size=train_sp_size,
+ optimizer_steps=train_optimizer_steps,
+ pack_max_length=pack_max_length,
+)
+
+
+# 6. RL Trainer
+trainer = AgentRLTrainerConfig(
+ load_from=model_path,
+ pg=pg,
+ environment_config=environment_config,
+ dataflow_config=dataflow_config,
+ replay_buffer_config=replay_buffer_cfg,
+ train_worker_cfg=train_worker_cfg,
+ evaluator_config=evaluator_cfg,
+ tokenizer_path=model_path,
+ work_dir=work_dir,
+ total_epochs=total_epochs,
+ hf_interval=hf_interval,
+ skip_load_weights=skip_load_weights,
+ auto_resume=auto_resume,
+ checkpoint_interval=1,
+ checkpoint_maxkeep=1,
+ load_checkpoint_cfg=load_checkpoint_cfg,
+ checkpoint_no_save_optimizer=True,
+ skip_checkpoint_validation=True,
+)
+
+
+import torch.distributed as dist
+
+from xtuner.v1.train.agent_rl_trainer import AgentRLTrainer
+
+trainer = AgentRLTrainer.from_config(trainer)
+trainer.fit()
+
+if dist.is_initialized():
+ dist.destroy_process_group()
diff --git a/examples/v1/config/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-badword-mtp4_agenticrl_tb2_mtp4_0503rc1.py b/examples/v1/config/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-badword-mtp4_agenticrl_tb2_mtp4_0503rc1.py
new file mode 100644
index 0000000000..ea278af8f7
--- /dev/null
+++ b/examples/v1/config/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-badword-mtp4_agenticrl_tb2_mtp4_0503rc1.py
@@ -0,0 +1,1442 @@
+import os
+
+os.environ["XTUNER_USE_LMDEPLOY"] = "1"
+
+import hashlib
+import json
+
+# os.environ["TRANSFORMERS_OFFLINE"] = "1"
+from copy import deepcopy
+from functools import partial
+
+import ray
+from intern_s1_delivery.advantage.rloo_entropy_badword import (
+ OverlongRLOOGroupEntropyBadwordAdvantageConfig,
+)
+from lagent.actions.mcp_client import AsyncMCPClient
+from lagent.actions.web_visitor import WebVisitor
+from lagent.agents.fc_agent import FunctionCallAgent, get_tool_prompt
+from ray.util.placement_group import placement_group
+
+from claw_bench.claw_tokenize_fn import RLClawTokenizeFnConfig
+from tb2_eval.tb2_eval_tokenize_fn import RLTB2EvalTokenizeFnConfig
+from tb2_rl.tb2_rl_tokenize_fn import RLTB2RLTokenizeFnConfig
+from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
+from xtuner.v1.data_proto.rl_data import (
+ RLAgentDataItem,
+ RLDataFlowItem,
+ RLJudgerResponseItem,
+ RLRolloutResponseItem,
+ RolloutState,
+ SampleParams,
+ update_dataflow_item,
+)
+from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig
+from xtuner.v1.datasets.config import DataloaderConfig
+from xtuner.v1.model import Qwen3_5_VLMoE35BA3Config
+from xtuner.v1.module.mtp import MTPConfig
+from xtuner.v1.ray.base import (
+ AcceleratorResourcesConfig,
+ AutoAcceleratorWorkers,
+ CPUResourcesConfig,
+)
+from xtuner.v1.ray.config.worker import RolloutConfig
+from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig
+from xtuner.v1.ray.environment.agent_env import AgentEnvironment
+from xtuner.v1.ray.environment.composed_env import ComposedEnvironment
+from xtuner.v1.ray.environment.install_agent_env import InstallAgentEnvironment
+from xtuner.v1.ray.environment.lagent.agents import (
+ AsyncTokenInOutAgent,
+ EnvAgent,
+ JudgerWrapper,
+ finish_condition_func,
+)
+from xtuner.v1.ray.environment.lagent.llms.controller_wrapper import ControllerWrapper
+from xtuner.v1.ray.environment.lagent.schema import AgentMessage
+from xtuner.v1.ray.evaluator import EvaluatorConfig
+from xtuner.v1.ray.judger.compass_verifier_v2 import CompassVerifierV2Config
+from xtuner.v1.ray.judger.controller import JudgerConfig
+from xtuner.v1.ray.judger.frontierscience_judger import FrontierScienceJudgerConfig
+from xtuner.v1.ray.judger.review import ReviewJudgerConfig
+from xtuner.v1.ray.judger.sgi_judger import SGIJudgerConfig
+from xtuner.v1.ray.rollout import RolloutController
+from xtuner.v1.rl.base import WorkerConfig
+from xtuner.v1.rl.base.rollout_is import RolloutImportanceSampling
+from xtuner.v1.rl.grpo import GRPOLossConfig
+from xtuner.v1.train.agent_rl_trainer import AgentRLTrainerConfig
+from xtuner.v1.train.trainer import LoadCheckpointConfig
+from xtuner.v1.utils.compute_metric import compute_metric
+
+if not ray.is_initialized():
+ ray.init(ignore_reinit_error=True, runtime_env={"env_vars": {"RAY_DEBUG_POST_MORTEM": "0"}})
+
+experimental_name = os.path.basename(__file__).split(".py")[0]
+base_work_dir = "/mnt/shared-storage-user/llmit1/user/liukuikun/delivery/interns2_preview_0430rc9"
+work_dir = os.path.join(base_work_dir, experimental_name)
+model_name = os.environ["RL_LLM_MODEL"]
+model_path = '/mnt/shared-storage-user/llmit1/user/wangziyi/exp/mindcopilot_rl/work_dirs/ckpt/interns2-35ba3-base05-20260424a-rl-data260428rc0-56k-badword-mtp4-resume800/20260430074140/hf-40'
+stop_word = "<|im_end|>"
+
+# basic settings
+global_batch_size = 256
+prompt_repeat_k = 16
+max_concurrent_groups = 512
+
+max_prompt_length = 16 * 1024
+pack_max_length = 130 * 1024
+max_response_length = 128 * 1024
+
+train_ep_size = 1
+train_sp_size = 2
+rollout_tp_size = 4
+rollout_ep_size = 1
+fp32_lm_head = True
+enable_float8_rollout = False
+rollout_max_batch_size = 128
+max_prefill_token_num = 1024
+enable_return_routed_experts = True
+enable_partial_rollout = False
+staleness_threshold = 0.0
+tail_batch_candidate_steps = 0
+auto_resume = True
+skip_load_weights = True
+lr = 1e-6
+train_optimizer_steps = 8 # mini batch steps
+hf_interval = 5
+total_epochs = 10
+# evaluation settings
+enable_evaluate = True
+enable_initial_evaluate = True
+evaluate_step = 5
+
+# agent setting
+max_turn = 50  # maximum number of dialogue turns
+lower_tool_turn_bound = 12  # penalize samples whose tool-call turns fall below this value; None disables the penalty
+lower_tool_turn_bound_science = 8
+enable_repeated_tool_call_penalty = True  # penalize samples that repeat the same tool call
+enable_no_thinking_penalty = False
+max_tool_response_length = 8192
+
+
+# 1. resources
+resources = AcceleratorResourcesConfig(
+ accelerator="GPU",
+ num_workers=64,
+ num_cpus_per_worker=12,
+ cpu_memory_per_worker=16 * 1024**3, # 16 GB
+)
+judger_cpu_resources = CPUResourcesConfig.from_total(total_cpus=64, num_workers=64, total_memory=512 * 1024**3)
+
+# 2. rollout
+rollout_config = RolloutConfig(
+ env=experimental_name,
+ device=resources.accelerator,
+ model_name=model_name,
+ model_path=model_path,
+ dtype="bfloat16",
+ tensor_parallel_size=rollout_tp_size,
+ expert_parallel_size=rollout_ep_size,
+ gpu_memory_utilization=0.7,
+ enable_float8=enable_float8_rollout,
+ skip_load_weights=skip_load_weights,
+ context_length=max_response_length,
+ rollout_max_batch_size_per_instance=rollout_max_batch_size,
+ # chunked_prefill_size=4096,
+ allow_over_concurrency_ratio=1.2,
+ rollout_timeout=1800,
+ enable_return_routed_experts=enable_return_routed_experts,
+ # max_prefill_token_num=max_prefill_token_num,
+ extra_rollout_config=dict(
+ lmdeploy_log_level="ERROR",
+ lmdeploy_uvicorn_log_level="ERROR",
+ lmdeploy_speculative_algorithm='qwen3_5_mtp',
+ lmdeploy_speculative_num_draft_tokens=4,
+ ),
+ fp32_lm_head=fp32_lm_head,
+)
+
+# sampling params
+training_sample_params = SampleParams(
+ max_tokens=max_response_length, top_k=0, top_p=0.999, temperature=1.0, min_tokens=0
+)
+evaluation_sample_params = deepcopy(training_sample_params)
+evaluation_sample_params.temperature = 0.8
+# evaluation_sample_params.max_tokens = max_response_length
+
+data_judger_mapping = {
+ "agent": {"compass_verifier_v2": 1.0},
+ "agent_science": {"compass_verifier_v2": 1.0},
+ 'GAIA_sft_1229': {"compass_verifier_v2": 1.0},
+ 'gaia-level1': {"compass_verifier_v2": 1.0},
+ 'gaia-level2': {"compass_verifier_v2": 1.0},
+ 'gaia-level3': {"compass_verifier_v2": 1.0},
+ 'BrowseComp-ZH': {"compass_verifier_v2": 1.0},
+ 'HLE': {"compass_verifier_v2": 1.0},
+ 'browsecomp': {"compass_verifier_v2": 1.0},
+ 'math': {"compass_verifier_v2": 1.0},
+ 'AIME2024': {"compass_verifier_v2": 1.0},
+ 'AIME2025': {"compass_verifier_v2": 1.0},
+ 'aime2026': {"compass_verifier_v2": 1.0},
+ 'hmmt26': {"compass_verifier_v2": 1.0},
+ 'IMO-Bench-AnswerBench': {"compass_verifier_v2": 1.0},
+ 'UGD_hard': {"compass_verifier_v2": 1.0},
+ 'openreview': {"openreview": 1.0},
+ 'openreview_test': {"openreview": 1.0},
+ 'sgi-deep-research': {"sgi_judger": 1.0},
+ 'frontierscience': {"frontierscience_judger": 1.0},
+}
+tokenize_fn_cfg = Qwen3VLTokenizeFnConfig(
+ processor_path=model_path,
+ min_pixels=None,
+ # max_pixels=None,
+ # max_pixels=2097152,
+ video_min_total_pixels=None,
+ video_max_total_pixels=None,
+ video_min_frames=None,
+ video_max_frames=None,
+ fps=None,
+ rand_video_max_frames=24,
+ add_vision_id=True,
+ system_message=None,
+ hash=None,
+ enable_3d_rope=False,
+ oss_loader_cfg=None,
+ debug=True,
+ oss_time_log_thr=10,
+ chat_template='qwen3-vl-rl',
+)
+
+from intern_s1_delivery.dataset.xpuyu_dataset_vl import (
+ RLTokenizeFnConfig,
+)
+from intern_s1_delivery.dataset.xpuyu_dataset_vl import (
+ parse_xpuyu_json_cfg as parse_xpuyu_json_cfg_vl,
+)
+
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+
+# 2. dataset
+# from lagent_rl.datasets.parse_cfg import parse_xpuyu_json_cfg
+
+
+def parse_xpuyu_json_cfg(path, max_prompt_length):
+ with open(path, "r") as f:
+ json_cfg = json.load(f)
+ converted_cfg = []
+ for ds_name, ds_cfg in json_cfg.items():
+ annotation = ds_cfg["annotation"]
+ tokenize_fn = ds_cfg["tokenize_fn"]
+
+        if tokenize_fn == "RLClawTokenizeFnConfig":
+            tokenize_fn_cfg = RLClawTokenizeFnConfig
+        elif tokenize_fn == "RLTB2RLTokenizeFnConfig":
+            tokenize_fn_cfg = RLTB2RLTokenizeFnConfig
+        elif tokenize_fn == "RLTB2EvalTokenizeFnConfig":
+            tokenize_fn_cfg = RLTB2EvalTokenizeFnConfig
+        else:
+            raise ValueError(f"Unknown tokenize_fn: {tokenize_fn}")
+
+ if isinstance(annotation, str):
+ annotation = [annotation]
+ for ann in annotation:
+ converted_cfg.append(
+ {
+ "dataset": DatasetConfig(
+ name=ds_name,
+ anno_path=ann,
+ sample_ratio=ds_cfg["sample_ratio"],
+ class_name='JsonlDataset',
+ ),
+ "tokenize_fn": tokenize_fn_cfg(
+ root_path=ds_cfg.get("root_path", None), max_length=max_prompt_length
+ ),
+ }
+ )
+ return converted_cfg
+
+
+TRAIN_DATA_PATH_SCIENCE_SEARCH = "/mnt/shared-storage-user/llmit/user/liujiangning/projects/crg_rl_projects/src/lagent_rl/scripts/s2_preview_35b_agentrl/interns2_35ba3_b03_0413a_reasoningRL_scienceSearch0421a/scienceSearch0421a.json"
+TRAIN_DATA_PATH_INTERNLM_SCIENCE = "/mnt/shared-storage-user/llmit/user/liujiangning/projects/crg_rl_projects/src/lagent_rl/scripts/s2_preview_35b_agentrl/interns2_35ba3_b03_0413a_reasoningRL_scienceSearch0423a/scienceSearch0423a.json"
+
+TRAIN_DATA_PATH_SEARCH = (
+ '/mnt/shared-storage-user/llmit1/user/liujiangning/data/s1_1_rl_delivery_agent/exp_rl/rl_data_260126.json'
+)
+TRAIN_DATA_PATH_MATH = '/mnt/shared-storage-user/llmit1/user/liujiangning/data/s1_1_rl_delivery_agent/exp_rl/train_interns1-1_260124rc0_pure_math.json'
+TEST_DATA_PATH_MATH = '/mnt/shared-storage-user/llmit1/user/liujiangning/data/s1_1_rl_delivery_agent/math_benchmark/val_python_toolcall.json'
+
+
+train_dataset_cfg_science_search = parse_xpuyu_json_cfg_vl(
+ TRAIN_DATA_PATH_SCIENCE_SEARCH, tokenize_fn_cfg, max_prompt_length, data_judger_mapping
+)
+train_dataset_cfg_internlm_science = parse_xpuyu_json_cfg_vl(
+ TRAIN_DATA_PATH_INTERNLM_SCIENCE, tokenize_fn_cfg, max_prompt_length, data_judger_mapping
+)
+
+train_dataset_cfg_search = parse_xpuyu_json_cfg_vl(
+ TRAIN_DATA_PATH_SEARCH, tokenize_fn_cfg, max_prompt_length, data_judger_mapping
+)
+train_dataset_cfg_math = parse_xpuyu_json_cfg_vl(
+ TRAIN_DATA_PATH_MATH, tokenize_fn_cfg, max_prompt_length, data_judger_mapping
+)
+train_dataset_cfg_review = [
+ {
+ "dataset": DatasetConfig(
+ name='openreview',
+ anno_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/demo/xtuner/examples/demo_data/agent/openreview/train.jsonl",
+ sample_ratio=0.1,
+ media_root=None,
+ class_name='VLMJsonlDataset',
+ ),
+ "tokenize_fn": RLTokenizeFnConfig(
+ tokenize_fn_cfg=tokenize_fn_cfg,
+ system_prompt=None,
+ max_length=max_prompt_length,
+ data_judger_mapping=data_judger_mapping,
+ ignore_multimodal_info=False,
+ ),
+ }
+]
+train_dataset_cfg_tb2rl = parse_xpuyu_json_cfg(
+ '/mnt/shared-storage-user/llmit/user/wangziyi/projs/xtuner_agent_dev/examples/demo_data/agent_dev/tb2_rl/meta.json',
+ max_prompt_length,
+)
+train_dataset_cfg = (
+ train_dataset_cfg_science_search
+ + train_dataset_cfg_internlm_science
+ + train_dataset_cfg_search
+ + train_dataset_cfg_math
+ + train_dataset_cfg_review
+ + train_dataset_cfg_tb2rl
+)
+eval_dataset_cfg_search = [
+ {
+ "dataset": DatasetConfig(
+ name="gaia",
+ anno_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/crg_rl_projects/data/gaia_text_103.jsonl",
+ sample_ratio=4.0,
+ media_root=None,
+ class_name='VLMJsonlDataset',
+ ),
+ "tokenize_fn": RLTokenizeFnConfig(
+ tokenize_fn_cfg=tokenize_fn_cfg,
+ system_prompt=None,
+ max_length=max_prompt_length,
+ data_judger_mapping=data_judger_mapping,
+ ignore_multimodal_info=True,
+ ),
+ },
+ {
+ "dataset": DatasetConfig(
+ name="browsecomp-zh",
+ anno_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/crg_rl_projects/data/browsecomp-zh.jsonl",
+ sample_ratio=0.0,
+ media_root=None,
+ class_name='VLMJsonlDataset',
+ ),
+ "tokenize_fn": RLTokenizeFnConfig(
+ tokenize_fn_cfg=tokenize_fn_cfg,
+ system_prompt=None,
+ max_length=max_prompt_length,
+ data_judger_mapping=data_judger_mapping,
+ ignore_multimodal_info=True,
+ ),
+ },
+ {
+ "dataset": DatasetConfig(
+ name="browsecomp",
+ anno_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/crg_rl_projects/data/browsecomp.jsonl",
+ sample_ratio=1.0,
+ media_root=None,
+ class_name='VLMJsonlDataset',
+ ),
+ "tokenize_fn": RLTokenizeFnConfig(
+ tokenize_fn_cfg=tokenize_fn_cfg,
+ system_prompt=None,
+ max_length=max_prompt_length,
+ data_judger_mapping=data_judger_mapping,
+ ignore_multimodal_info=True,
+ ),
+ },
+ {
+ "dataset": DatasetConfig(
+ name="hle",
+ anno_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/crg_rl_projects/data/hle.jsonl",
+ sample_ratio=1.0,
+ media_root=None,
+ class_name='VLMJsonlDataset',
+ ),
+ "tokenize_fn": RLTokenizeFnConfig(
+ tokenize_fn_cfg=tokenize_fn_cfg,
+ system_prompt=None,
+ max_length=max_prompt_length,
+ data_judger_mapping=data_judger_mapping,
+ ignore_multimodal_info=True,
+ ),
+ },
+ {
+ "dataset": DatasetConfig(
+ name="sgi-deep-research",
+ anno_path="/mnt/shared-storage-user/llmit1/user/liujiangning/data/eval_benchmark_testset/sgi_deep_research_gaia_format.jsonl",
+ sample_ratio=1.0,
+ media_root=None,
+ class_name='VLMJsonlDataset',
+ ),
+ "tokenize_fn": RLTokenizeFnConfig(
+ tokenize_fn_cfg=tokenize_fn_cfg,
+ system_prompt=None,
+ max_length=max_prompt_length,
+ data_judger_mapping=data_judger_mapping,
+ ignore_multimodal_info=True,
+ ),
+ },
+ {
+ "dataset": DatasetConfig(
+ name="frontierscience",
+ anno_path="/mnt/shared-storage-user/llmit1/user/liujiangning/data/eval_benchmark_testset/frontierscience_gaia_format.jsonl",
+ sample_ratio=1.0,
+ media_root=None,
+ class_name='VLMJsonlDataset',
+ ),
+ "tokenize_fn": RLTokenizeFnConfig(
+ tokenize_fn_cfg=tokenize_fn_cfg,
+ system_prompt=None,
+ max_length=max_prompt_length,
+ data_judger_mapping=data_judger_mapping,
+ ignore_multimodal_info=True,
+ ),
+ },
+]
+eval_dataset_cfg_math = parse_xpuyu_json_cfg_vl(
+ TEST_DATA_PATH_MATH, tokenize_fn_cfg, max_prompt_length, data_judger_mapping, ignore_multimodal_info=True
+)
+eval_dataset_cfg_review = [
+ {
+ "dataset": DatasetConfig(
+ name="openreview",
+ anno_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/demo/xtuner/examples/demo_data/agent/openreview/test.jsonl",
+ sample_ratio=1.0,
+ media_root=None,
+ class_name='VLMJsonlDataset',
+ ),
+ "tokenize_fn": RLTokenizeFnConfig(
+ tokenize_fn_cfg=tokenize_fn_cfg,
+ system_prompt=None,
+ max_length=max_prompt_length,
+ data_judger_mapping=data_judger_mapping,
+ ignore_multimodal_info=True,
+ ),
+ },
+]
+eval_data_cfg_tb2eval = [
+ {
+ "dataset": DatasetConfig(
+ name="tb2-eval",
+ anno_path="/mnt/shared-storage-user/llmit1/user/liukuikun/delivery/data/tb2_eval_tasks.jsonl",
+ sample_ratio=1.0,
+ media_root=None,
+ class_name='JsonlDataset',
+ ),
+ "tokenize_fn": RLTB2EvalTokenizeFnConfig(
+ root_path="/mnt/shared-storage-user/llmit/user/wangziyi/projs/terminalbench2-harbor-p-cluster/terminal-bench-2",
+ max_length=max_prompt_length,
+ ),
+ },
+]
+# eval_dataset_cfg = (
+# (eval_dataset_cfg_search + eval_dataset_cfg_math + eval_dataset_cfg_review + eval_data_cfg_tb2eval)
+# if enable_evaluate
+# else []
+# )
+eval_dataset_cfg = eval_data_cfg_tb2eval
+dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none")
+
+
+# 3. judger
+compass_judger_cfg = JudgerConfig(
+ enable_weighted_judgers=True,
+ reward_judger_configs=[
+ CompassVerifierV2Config(
+ hosts=[
+ "10.102.251.61:23333",
+ "10.102.251.61:23334",
+ "10.102.251.61:23335",
+ "10.102.251.61:23336",
+ "10.102.251.61:23337",
+ "10.102.251.61:23338",
+ "10.102.251.61:23339",
+ "10.102.251.61:23340",
+ "10.102.216.52:23333",
+ "10.102.216.52:23334",
+ "10.102.216.52:23335",
+ "10.102.216.52:23336",
+ "10.102.216.52:23337",
+ "10.102.216.52:23338",
+ "10.102.216.52:23339",
+ "10.102.216.52:23340",
+ "10.102.238.19:23333",
+ "10.102.238.19:23334",
+ "10.102.238.19:23335",
+ "10.102.238.19:23336",
+ "10.102.238.19:23337",
+ "10.102.238.19:23338",
+ "10.102.238.19:23339",
+ "10.102.238.19:23340",
+ "10.102.239.68:23333",
+ "10.102.239.68:23334",
+ "10.102.239.68:23335",
+ "10.102.239.68:23336",
+ "10.102.239.68:23337",
+ "10.102.239.68:23338",
+ "10.102.239.68:23339",
+ "10.102.239.68:23340",
+ ]
+ ),
+ SGIJudgerConfig(
+ hosts=[
+ "10.102.213.32:30030",
+ "10.102.213.32:30031",
+ "10.102.213.32:30032",
+ "10.102.213.32:30033",
+ "10.102.213.32:30034",
+ "10.102.213.32:30035",
+ "10.102.213.32:30036",
+ "10.102.213.32:30037",
+ ],
+ model_name="/mnt/shared-storage-user/gpfs2-shared-public/huggingface/hub/models--Qwen--Qwen3.5-35B-A3B/snapshots/ec2d4ece1ffb563322cbee9a48fe0e3fcbce0307",
+ num_ray_actors=1,
+ request_timeout=60,
+ ),
+ FrontierScienceJudgerConfig(
+ hosts=[
+ "10.102.213.32:30030",
+ "10.102.213.32:30031",
+ "10.102.213.32:30032",
+ "10.102.213.32:30033",
+ "10.102.213.32:30034",
+ "10.102.213.32:30035",
+ "10.102.213.32:30036",
+ "10.102.213.32:30037",
+ ],
+ model_name="/mnt/shared-storage-user/gpfs2-shared-public/huggingface/hub/models--Qwen--Qwen3.5-35B-A3B/snapshots/ec2d4ece1ffb563322cbee9a48fe0e3fcbce0307",
+ num_ray_actors=1,
+ request_timeout=60,
+ ),
+ ],
+)
+review_judger_cfg = JudgerConfig(reward_judger_configs=[ReviewJudgerConfig(judger_name="openreview")])
+
+from xtuner.v1.ray.judger.controller import JudgerController
+
+compass_judger_controller = JudgerController.remote(
+ compass_judger_cfg,
+ ray.get(
+ placement_group(
+ bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs),
+ strategy="PACK",
+ ).ready(),
+ timeout=30,
+ )
+)
+review_judger_controller = JudgerController.remote(
+ review_judger_cfg,
+ ray.get(
+ placement_group(
+ bundles=[{"CPU": 1, "memory": 1024**3}] * len(review_judger_cfg.reward_judger_configs),
+ strategy="PACK",
+ ).ready(),
+ timeout=30,
+ )
+)
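+# Shared judger controllers: the JudgerWrapper instances below reuse these
+# actors instead of spawning one judger per agent config (the commented-out
+# per-agent judger_cfg blocks show the earlier form).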
+
+
+from xtuner.v1.ray.environment.lagent.parsers import Qwen3_5FunctionCallParser
+
+
+def prepare_agent_inputs(env, group_data_item: RLDataFlowItem):
+ env_agent = group_data_item.env.agent.extra_info.pop('agent').env_agent
+ user_prompt = group_data_item.data.messages[-1]['content']
+ env_message = AgentMessage(
+ sender="env", content=user_prompt, uid=hashlib.md5(user_prompt.encode('utf-8')).hexdigest()
+ )
+ if not env_agent.memory.get_memory():
+ set_env_message = AgentMessage(sender="env", content=group_data_item)
+        if env_agent.memory is not None:
+            env_agent.memory.add(set_env_message)
+ return (env_message,)
+
+
+def convert_rollout_tractory_to_train(env, group_data_items):
+ agent_data_items, rollout_response_items, judger_response_items = [], [], []
+ for i in range(len(group_data_items)):
+ history = group_data_items[i].env.rollout.extra_info['agent_state_dict']['policy_agent.memory']
+ env_history = group_data_items[i].env.rollout.extra_info['agent_state_dict']['env_agent.memory']
+ messages = group_data_items[i].env.rollout.extra_info['agent_message_dict']['policy_agent.messages']
+ agent_data_items.append(RLAgentDataItem(extra_info=dict(messages=messages, state={"history": history})))
+ rollout_response_items.append(
+ RLRolloutResponseItem(
+ response=history[-1]['raw_content'],
+ response_ids=history[-1]['raw_content_ids'],
+ logprobs=history[-1]['raw_content_logprobs'],
+ state=RolloutState.COMPLETED,
+ )
+ )
+ reward_payload = env_history[-1]['reward']
+ if isinstance(reward_payload, dict):
+ if 'score' in reward_payload and group_data_items[i].data.extra_info.get('origin_data_source') in [
+ 'openreview',
+ ]: # scale down the reward for review data
+ reward_payload['score'] = reward_payload['score'] / 10
+ judger_response_items.append(RLJudgerResponseItem(reward=reward_payload))
+ else:
+ judger_response_items.append(RLJudgerResponseItem(reward=dict(score=reward_payload)))
+ group_data_items = update_dataflow_item(group_data_items, "env.agent", agent_data_items)
+ group_data_items = update_dataflow_item(group_data_items, "env.rollout", rollout_response_items)
+ group_data_items = update_dataflow_item(group_data_items, "env.judger", judger_response_items)
+ return group_data_items
+
+
+def prepare_agent_inputs_for_tb2rl(env, group_data_item: RLDataFlowItem):
+ return group_data_item
+
+
+def convert_rollout_tractory_to_train_for_tb2rl(env, group_data_items):
+ agent_data_items, rollout_response_items, judger_response_items = [], [], []
+ for i in range(len(group_data_items)):
+ messages = group_data_items[i].env.agent.extra_info['message_dict']['policy_agent.messages']
+ tools = group_data_items[i].env.agent.extra_info['message_dict'].get('policy_agent.tools')
+ agent_data_items.append(RLAgentDataItem(extra_info=dict(messages=messages, tools=tools)))
+ # breakpoint()
+ rollout_response_items.append(
+ RLRolloutResponseItem(
+ response=messages[-1]['raw_content'],
+ response_ids=messages[-1]['raw_content_ids'],
+ logprobs=messages[-1]['raw_content_logprobs'],
+ state=RolloutState.COMPLETED,
+ )
+ )
+ reward_payload = group_data_items[i].env.judger.extra_info['total']
+ judger_response_items.append(RLJudgerResponseItem(reward=dict(score=reward_payload)))
+ group_data_items = update_dataflow_item(group_data_items, "env.agent", agent_data_items)
+ group_data_items = update_dataflow_item(group_data_items, "env.rollout", rollout_response_items)
+ group_data_items = update_dataflow_item(group_data_items, "env.judger", judger_response_items)
+ # breakpoint()
+ return group_data_items
+
+
+pg = AutoAcceleratorWorkers.build_placement_group(resources)
+rollout_controller = ray.remote(max_concurrency=1000)(RolloutController).remote(rollout_config, pg)
+load_checkpoint_cfg = LoadCheckpointConfig(load_optimizer_states=False, load_optimizer_args=False)
+
+search_tool = AsyncMCPClient(
+ # type=AsyncMCPClient,
+ name='SerperSearch',
+ server_type='http',
+ rate_limit=500.0,
+ # max_concurrency=128,
+ url=[
+ 'http://10.102.103.157:8091/mcp',
+ 'http://10.102.103.155:8096/mcp',
+ 'http://10.102.103.155:8092/mcp',
+ 'http://10.102.103.155:8095/mcp',
+ 'http://10.102.103.155:8097/mcp',
+ 'http://10.102.103.155:8098/mcp',
+ 'http://10.102.103.155:8094/mcp',
+ 'http://10.102.103.155:8093/mcp',
+ ],
+)
+browse_tool = AsyncMCPClient(
+ # type=AsyncMCPClient,
+ name='JinaBrowse',
+ server_type='http',
+ rate_limit=100.0,
+ max_concurrency=40,
+ url=[
+ 'http://10.102.103.155:8104/mcp',
+ 'http://10.102.103.155:8100/mcp',
+ 'http://10.102.103.148:8101/mcp',
+ 'http://10.102.103.157:8105/mcp',
+ 'http://10.102.103.155:8099/mcp',
+ 'http://10.102.103.155:8102/mcp',
+ 'http://10.102.103.155:8103/mcp',
+ 'http://10.102.103.155:8106/mcp',
+ ],
+)
+visit_tool = WebVisitor(
+ # type=WebVisitor,
+ browse_tool=dict(
+ type=AsyncMCPClient,
+ name='JinaBrowse',
+ server_type='http',
+ rate_limit=100.0,
+ max_concurrency=40,
+ url=[
+ 'http://10.102.103.155:8104/mcp',
+ 'http://10.102.103.155:8100/mcp',
+ 'http://10.102.103.148:8101/mcp',
+ 'http://10.102.103.157:8105/mcp',
+ 'http://10.102.103.155:8099/mcp',
+ 'http://10.102.103.155:8102/mcp',
+ 'http://10.102.103.155:8103/mcp',
+ 'http://10.102.103.155:8106/mcp',
+ ],
+ ),
+ llm=dict(
+ type=ControllerWrapper,
+ rollout_controller=rollout_controller,
+ sample_params=SampleParams(max_tokens=max_response_length),
+ tool_call_parser=Qwen3_5FunctionCallParser(),
+ ),
+ truncate_browse_response_length=60000,
+ tokenizer_path=model_path,
+)
+arxiv_tool = AsyncMCPClient(
+ # type=AsyncMCPClient,
+ name='arxiv_search',
+ server_type='http',
+ rate_limit=50.0,
+ max_concurrency=20,
+ url=[
+ 'http://10.102.252.176:2364/mcp',
+ ],
+)
+
+
+tool_template = """# Tools
+
+You have access to the following functions:
+
+
+{tools}
+
+
+If you choose to call a function ONLY reply in the following format with NO suffix:
+
+<tool_call>
+<function=example_function_name>
+<parameter=example_parameter_1>
+value_1
+</parameter>
+<parameter=example_parameter_2>
+This is the value for the second parameter
+that can span
+multiple lines
+</parameter>
+</function>
+</tool_call>
+
+Reminder:
+- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
+- Required parameters MUST be specified
+- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
+"""
+
+review_sys_prompt = r"""You are an expert reviewer in the field of pre-training, 3D human pose and shape estimation, and self-supervised representation learning for the ICLR 2023 conference. Your responsibilitie is conducting an initial review. You must follow a strict reasoning process.
+
+### INTERACTION PROTOCOL
+You must perform a **Thought-Action** loop for every step. Do not rush to the final review.
+1. **THINK (``)** - **Before Tool Use**: Analyze the paper step-by-step. Identify gaps in your knowledge. Formulate search queries. Explicitly check if the `end_date` constraint is met.
+ - **After Tool Response**: Analyze the search results. synthesis the information. Decide if more searches are needed or if you have sufficient context to write the review.
+ - The reasoning content must be enclosed with `` and `` tags.
+
+2. **ACTION (``)**:
+ - If you need external information, output a tool call enclosed in `` and `` tags.
+ - The content inside the tags must be valid JSON format.
+ - **CRITICAL**: When calling `arxiv_search`, you MUST set the `end_date` argument to '20230603'.
+
+3. **FINALIZE (``)**:
+ - Only when you have completed all necessary research and reasoning, output the final review inside `` and `` tags.
+ - The review must include: Summary, Strengths, Weaknesses, Questions, and References.
+
+### CITATION & REFERENCE STANDARDS (CRITICAL)
+You must adhere to the following strict formatting rules for the final review:
+1. **Sequential Numbering**: Citations in the text must be numbered sequentially starting from [1] based on the order they first appear (e.g., [1], [2], [3]). **Do NOT use the ID returned by the search tool (e.g., do not use [81], [28]).**
+2. **Inline Citation**: Every external claim must have an inline citation. Example: 'Recent studies [1] have shown that...'
+3. **Reference List**: The 'References' section at the end must strictly match the inline citations. Each entry must contain:
+ - Format: `[ID] Authors. **Title**. Venue, Year. URL`
+ - Example: `[1] J. Smith et al. **Deep Learning**. NeurIPS, 2023. https://arxiv.org/...`
+ - Ensure the Title and URL are complete. The URL field in references must be the EXACT URL returned by the search tool. Do not use placeholders like '...'.
+
+### OUTPUT FORMAT (``)
+
+## Summary
+...
+## Strengths
+...
+## Weaknesses
+...
+## Questions
+...
+## References
+[1] ...
+[2] ...
+
+
+### CONSTRAINTS
+- Never output `` and `` in the same turn.
+- Strictly adhere to the submission deadline: do not use knowledge published after 20230603.- Do NOT hallucinate citations. You can ONLY cite papers that explicitly appear in the search results from the output. If you cannot find a relevant paper, do not cite a fake one.
+"""
+
+
+from lagent_rl.environment.lagent_ext.actions.python_executor import PythonExecutor
+
+python_action = PythonExecutor(
+ # type=PythonExecutor,
+ rate_limit_qps=500.0,
+ burst=20,
+ retries=5,
+ connect_timeout=5.0,
+ read_timeout=30.0
+)
+
+# tool prompts with python (for science search)
+search_browse_python_tool_prompt = get_tool_prompt([search_tool, browse_tool, python_action], template=tool_template)
+search_visit_python_tool_prompt = get_tool_prompt([search_tool, visit_tool, python_action], template=tool_template)
+# tool prompts without python (for pure search)
+search_browse_tool_prompt = get_tool_prompt([search_tool, browse_tool], template=tool_template)
+search_visit_tool_prompt = get_tool_prompt([search_tool, visit_tool], template=tool_template)
+# other tool prompts
+review_tool_prompt = get_tool_prompt([arxiv_tool], template=tool_template)
+python_tool_prompt = get_tool_prompt([python_action], template=tool_template)
+
+
+# ============================================================
+# Science search agents (with python_action) - for agent_science data
+# ============================================================
+train_science_search_browse_agent = dict(
+ type=FunctionCallAgent,
+ policy_agent=dict(
+ type=AsyncTokenInOutAgent,
+ llm=dict(
+ type=ControllerWrapper,
+ rollout_controller=rollout_controller,
+ sample_params=SampleParams(max_tokens=max_response_length),
+ tool_call_parser=Qwen3_5FunctionCallParser(),
+ ),
+ template=search_browse_python_tool_prompt,
+ ),
+ env_agent=dict(
+ type=EnvAgent,
+ actions=[search_tool, browse_tool, python_action],
+ judger=JudgerWrapper(
+ # type=JudgerWrapper,
+ # judger_cfg=compass_judger_cfg,
+ # placement_group=ray.get(
+ # placement_group(
+ # bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs),
+ # strategy="PACK",
+ # ).ready(),
+ # timeout=30,
+ # ),
+ judger_controller=compass_judger_controller,
+ ),
+ max_turn=max_turn,
+ lower_tool_turn_bound=lower_tool_turn_bound_science,
+ enable_repeated_tool_call_penalty=enable_repeated_tool_call_penalty,
+ enable_no_thinking_penalty=enable_no_thinking_penalty,
+ max_tool_response_length=max_tool_response_length,
+ ),
+ finish_condition=finish_condition_func,
+ initialize_input=False,
+)
+train_science_search_visit_agent = dict(
+ type=FunctionCallAgent,
+ policy_agent=dict(
+ type=AsyncTokenInOutAgent,
+ llm=dict(
+ type=ControllerWrapper,
+ rollout_controller=rollout_controller,
+ sample_params=SampleParams(max_tokens=max_response_length),
+ tool_call_parser=Qwen3_5FunctionCallParser(),
+ ),
+ template=search_visit_python_tool_prompt,
+ ),
+ env_agent=dict(
+ type=EnvAgent,
+ actions=[search_tool, visit_tool, python_action],
+ judger=JudgerWrapper(
+ # type=JudgerWrapper,
+ # judger_cfg=compass_judger_cfg,
+ # placement_group=ray.get(
+ # placement_group(
+ # bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs),
+ # strategy="PACK",
+ # ).ready(),
+ # timeout=30,
+ # ),
+ judger_controller=compass_judger_controller,
+ ),
+ max_turn=max_turn,
+ lower_tool_turn_bound=lower_tool_turn_bound_science,
+ enable_repeated_tool_call_penalty=enable_repeated_tool_call_penalty,
+ enable_no_thinking_penalty=enable_no_thinking_penalty,
+ max_tool_response_length=max_tool_response_length,
+ ),
+ finish_condition=finish_condition_func,
+ initialize_input=False,
+)
+
+
+train_agent_with_search_browse = dict(
+ type=FunctionCallAgent,
+ policy_agent=dict(
+ type=AsyncTokenInOutAgent,
+ llm=dict(
+ type=ControllerWrapper,
+ rollout_controller=rollout_controller,
+ sample_params=SampleParams(max_tokens=max_response_length),
+ tool_call_parser=Qwen3_5FunctionCallParser(),
+ ),
+ template=search_browse_tool_prompt,
+ ),
+ env_agent=dict(
+ type=EnvAgent,
+ actions=[search_tool, browse_tool],
+ judger=JudgerWrapper(
+ # type=JudgerWrapper,
+ # judger_cfg=compass_judger_cfg,
+ # placement_group=ray.get(
+ # placement_group(
+ # bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs),
+ # strategy="PACK",
+ # ).ready(),
+ # timeout=30,
+ # ),
+ judger_controller=compass_judger_controller,
+ ),
+ max_turn=max_turn,
+ lower_tool_turn_bound=lower_tool_turn_bound,
+ enable_repeated_tool_call_penalty=enable_repeated_tool_call_penalty,
+ enable_no_thinking_penalty=enable_no_thinking_penalty,
+ max_tool_response_length=max_tool_response_length,
+ ),
+ finish_condition=finish_condition_func,
+ initialize_input=False,
+)
+train_agent_with_search_visit = dict(
+ type=FunctionCallAgent,
+ policy_agent=dict(
+ type=AsyncTokenInOutAgent,
+ llm=dict(
+ type=ControllerWrapper,
+ rollout_controller=rollout_controller,
+ sample_params=SampleParams(max_tokens=max_response_length),
+ tool_call_parser=Qwen3_5FunctionCallParser(),
+ ),
+ template=search_visit_tool_prompt,
+ ),
+ env_agent=dict(
+ type=EnvAgent,
+ actions=[search_tool, visit_tool],
+ judger=JudgerWrapper(
+ # type=JudgerWrapper,
+ # judger_cfg=compass_judger_cfg,
+ # placement_group=ray.get(
+ # placement_group(
+ # bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs),
+ # strategy="PACK",
+ # ).ready(),
+ # timeout=30,
+ # ),
+ judger_controller=compass_judger_controller,
+ ),
+ max_turn=max_turn,
+ lower_tool_turn_bound=lower_tool_turn_bound,
+ enable_repeated_tool_call_penalty=enable_repeated_tool_call_penalty,
+ enable_no_thinking_penalty=enable_no_thinking_penalty,
+ max_tool_response_length=max_tool_response_length,
+ ),
+ finish_condition=finish_condition_func,
+ initialize_input=False,
+)
+eval_search_agent = dict(
+ type=FunctionCallAgent,
+ policy_agent=dict(
+ type=AsyncTokenInOutAgent,
+ llm=dict(
+ type=ControllerWrapper,
+ rollout_controller=rollout_controller,
+ sample_params=SampleParams(max_tokens=max_response_length),
+ tool_call_parser=Qwen3_5FunctionCallParser(),
+ ),
+ template=search_visit_python_tool_prompt,
+ ),
+ env_agent=dict(
+ type=EnvAgent,
+ actions=[search_tool, visit_tool, python_action],
+ judger=JudgerWrapper(
+ # type=JudgerWrapper,
+ # judger_cfg=compass_judger_cfg,
+ # placement_group=ray.get(
+ # placement_group(
+ # bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs),
+ # strategy="PACK",
+ # ).ready(),
+ # timeout=30,
+ # ),
+ judger_controller=compass_judger_controller,
+ ),
+ max_turn=max_turn,
+ lower_tool_turn_bound=None,
+ enable_repeated_tool_call_penalty=False,
+ enable_no_thinking_penalty=False,
+ max_tool_response_length=max_tool_response_length,
+ ),
+ finish_condition=finish_condition_func,
+ initialize_input=False,
+)
+
+train_math_agent = dict(
+ type=FunctionCallAgent,
+ policy_agent=dict(
+ type=AsyncTokenInOutAgent,
+ llm=dict(
+ type=ControllerWrapper,
+ rollout_controller=rollout_controller,
+ sample_params=SampleParams(max_tokens=max_response_length),
+ tool_call_parser=Qwen3_5FunctionCallParser(),
+ ),
+ template=python_tool_prompt,
+ ),
+ env_agent=dict(
+ type=EnvAgent,
+ actions=[python_action],
+ judger=JudgerWrapper(
+ # type=JudgerWrapper,
+ # judger_cfg=compass_judger_cfg,
+ # placement_group=ray.get(
+ # placement_group(
+ # bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs),
+ # strategy="PACK",
+ # ).ready(),
+ # timeout=30,
+ # ),
+ judger_controller=compass_judger_controller,
+ ),
+ max_turn=25,
+ lower_tool_turn_bound=5,
+ enable_repeated_tool_call_penalty=False,
+ enable_no_thinking_penalty=False,
+ max_tool_response_length=4096,
+ ),
+ finish_condition=finish_condition_func,
+ initialize_input=False,
+)
+eval_math_agent = dict(
+ type=FunctionCallAgent,
+ policy_agent=dict(
+ type=AsyncTokenInOutAgent,
+ llm=dict(
+ type=ControllerWrapper,
+ rollout_controller=rollout_controller,
+ sample_params=SampleParams(max_tokens=max_response_length),
+ tool_call_parser=Qwen3_5FunctionCallParser(),
+ ),
+ template=python_tool_prompt,
+ ),
+ env_agent=dict(
+ type=EnvAgent,
+ actions=[python_action],
+ judger=JudgerWrapper(
+ # type=JudgerWrapper,
+ # judger_cfg=compass_judger_cfg,
+ # placement_group=ray.get(
+ # placement_group(
+ # bundles=[{"CPU": 1, "memory": 1024**3}] * len(compass_judger_cfg.reward_judger_configs),
+ # strategy="PACK",
+ # ).ready(),
+ # timeout=30,
+ # ),
+ judger_controller=compass_judger_controller,
+ ),
+ max_turn=25,
+ lower_tool_turn_bound=None,
+ enable_repeated_tool_call_penalty=False,
+ enable_no_thinking_penalty=False,
+ max_tool_response_length=4096,
+ ),
+ finish_condition=finish_condition_func,
+ initialize_input=False,
+)
+
+train_review_agent = dict(
+ type=FunctionCallAgent,
+ policy_agent=dict(
+ type=AsyncTokenInOutAgent,
+ llm=dict(
+ type=ControllerWrapper,
+ rollout_controller=rollout_controller,
+ sample_params=SampleParams(max_tokens=max_response_length),
+ tool_call_parser=Qwen3_5FunctionCallParser(argument_type={'end_date': str}),
+ ),
+ template=review_tool_prompt + "\n\n" + review_sys_prompt,
+ ),
+ env_agent=dict(
+ type=EnvAgent,
+ actions=[arxiv_tool],
+ judger=JudgerWrapper(
+ # type=JudgerWrapper,
+ # judger_cfg=review_judger_cfg,
+ # placement_group=ray.get(
+ # placement_group(bundles=[{"CPU": 1, "memory": 1024**3}], strategy="PACK").ready(), timeout=30
+ # ),
+ judger_controller=review_judger_controller,
+ reward_key=None,
+ ),
+ max_turn=25,
+ lower_tool_turn_bound=None,
+ enable_repeated_tool_call_penalty=False,
+ enable_no_thinking_penalty=False,
+ max_tool_response_length=max_tool_response_length,
+ ),
+ finish_condition=finish_condition_func,
+ initialize_input=False,
+)
+eval_review_agent = dict(
+ type=FunctionCallAgent,
+ policy_agent=dict(
+ type=AsyncTokenInOutAgent,
+ llm=dict(
+ type=ControllerWrapper,
+ rollout_controller=rollout_controller,
+ sample_params=SampleParams(max_tokens=max_response_length),
+ tool_call_parser=Qwen3_5FunctionCallParser(argument_type={'end_date': str}),
+ ),
+ template=review_tool_prompt + "\n\n" + review_sys_prompt,
+ ),
+ env_agent=dict(
+ type=EnvAgent,
+ actions=[arxiv_tool],
+ judger=JudgerWrapper(
+ # type=JudgerWrapper,
+ # judger_cfg=review_judger_cfg,
+ # placement_group=ray.get(
+ # placement_group(bundles=[{"CPU": 1, "memory": 1024**3}], strategy="PACK").ready(), timeout=30
+ # ),
+ judger_controller=review_judger_controller,
+ reward_key=None,
+ ),
+ max_turn=25,
+ lower_tool_turn_bound=None,
+ enable_repeated_tool_call_penalty=False,
+ enable_no_thinking_penalty=False,
+ max_tool_response_length=max_tool_response_length,
+ ),
+ finish_condition=finish_condition_func,
+ initialize_input=False,
+)
+
+
+def bucket(key: str, n: int) -> int:
+ hashkey = hashlib.md5(key.encode('utf-8')).hexdigest()
+ return sum(ord(c) for c in hashkey) % n
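+# `bucket` gives a deterministic, process-independent hash: md5 the key, sum the
+# ordinals of the 32 hex digits, take mod n. Unlike the builtin hash(), which is
+# salted per interpreter unless PYTHONHASHSEED is pinned, bucket("q", 2) returns
+# the same value on every worker, so routing decisions stay stable across ranks.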
+
+
+def rollout_env_router_fn(item: RLDataFlowItem):
+ source = item.data.extra_info.get('origin_data_source', '')
+
+ # 1. Routing for Evaluation Environments
+ if source in ['openreview_test']:
+ return 'eval_review_agent'
+ if source.startswith('gaia') or source in [
+ 'BrowseComp-ZH',
+ 'HLE',
+ 'browsecomp',
+ 'GAIA_sft_1229',
+ 'sgi-deep-research',
+ 'frontierscience',
+ ]:
+ return 'eval_search_agent'
+ if source in ['AIME2024', 'AIME2025', 'aime2026', 'hmmt26', 'IMO-Bench-AnswerBench', 'UGD_hard']:
+ return 'eval_math_agent'
+ if source == 'tb2-eval':
+ return 'eval_tb2eval'
+
+ # 2. Routing for Train Environments
+ if source == 'openreview':
+ return 'train_review_agent'
+ elif source == 'math':
+ return 'train_math_agent'
+ elif source == 'claw-bench':
+ return 'train_clawbench'
+ elif source == 'tb2-rl':
+ return 'train_tb2rl'
+ elif source == 'agent_science':
+ # Science search data (with python_action)
+ match bucket(item.data.messages[-1]['content'], 2):
+ case 0:
+ return 'train_science_search_browse_agent'
+ case 1:
+ return 'train_science_search_visit_agent'
+ else:
+        # Default: fall back to the generic search train environments
+ match bucket(item.data.messages[-1]['content'], 2):
+ case 0:
+ return 'train_agent_with_search_browse'
+ case 1:
+ return 'train_agent_with_search_visit'
+
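+# Contract: every string the router returns must be a key of
+# environment_config["environments"] below; ComposedEnvironment dispatches each
+# sample to the matching sub-environment.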
+
+environment_config = dict(
+ type=ComposedEnvironment,
+ environment=experimental_name,
+ rollout_controller=rollout_controller,
+ environments={
+ 'train_science_search_browse_agent': dict(
+ type=AgentEnvironment,
+ environment='train_science_search_browse_agent',
+ agent_cfg=train_science_search_browse_agent,
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs,
+ postprocess_func=convert_rollout_tractory_to_train,
+ ),
+ 'train_science_search_visit_agent': dict(
+ type=AgentEnvironment,
+ environment='train_science_search_visit_agent',
+ agent_cfg=train_science_search_visit_agent,
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs,
+ postprocess_func=convert_rollout_tractory_to_train,
+ ),
+ 'train_agent_with_search_browse': dict(
+ type=AgentEnvironment,
+ environment='train_agent_with_search_browse',
+ agent_cfg=train_agent_with_search_browse,
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs,
+ postprocess_func=convert_rollout_tractory_to_train,
+ ),
+ 'train_agent_with_search_visit': dict(
+ type=AgentEnvironment,
+ environment='train_agent_with_search_visit',
+ agent_cfg=train_agent_with_search_visit,
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs,
+ postprocess_func=convert_rollout_tractory_to_train,
+ ),
+ 'eval_search_agent': dict(
+ type=AgentEnvironment,
+ environment='eval_search_agent',
+ agent_cfg=eval_search_agent,
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs,
+ postprocess_func=convert_rollout_tractory_to_train,
+ ),
+ 'train_math_agent': dict(
+ type=AgentEnvironment,
+ environment='train_math_agent',
+ agent_cfg=train_math_agent,
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs,
+ postprocess_func=convert_rollout_tractory_to_train,
+ ),
+ 'eval_math_agent': dict(
+ type=AgentEnvironment,
+ environment='eval_math_agent',
+ agent_cfg=eval_math_agent,
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs,
+ postprocess_func=convert_rollout_tractory_to_train,
+ ),
+ 'train_review_agent': dict(
+ type=AgentEnvironment,
+ environment='train_review_agent',
+ agent_cfg=train_review_agent,
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs,
+ postprocess_func=convert_rollout_tractory_to_train,
+ ),
+ 'eval_review_agent': dict(
+ type=AgentEnvironment,
+ environment='eval_review_agent',
+ agent_cfg=eval_review_agent,
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs,
+ postprocess_func=convert_rollout_tractory_to_train,
+ ),
+ 'train_tb2rl': dict(
+ type=InstallAgentEnvironment,
+ environment='train_tb2rl',
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs_for_tb2rl,
+ postprocess_func=convert_rollout_tractory_to_train_for_tb2rl,
+ ),
+ 'eval_tb2eval': dict(
+ type=InstallAgentEnvironment,
+ environment='eval_tb2eval',
+ rollout_controller=rollout_controller,
+ preprocess_func=prepare_agent_inputs_for_tb2rl,
+ postprocess_func=convert_rollout_tractory_to_train_for_tb2rl,
+ ),
+ },
+ router=rollout_env_router_fn,
+)
+
+# 4. dataflow and evaluator
+dataflow_config = DataFlowConfig(
+ env=experimental_name,
+ max_concurrent=max_concurrent_groups,
+ enable_partial_rollout=enable_partial_rollout,
+ tail_batch_candidate_steps=tail_batch_candidate_steps,
+ staleness_threshold=staleness_threshold,
+ max_retry_times=3,
+ prompt_repeat_k=prompt_repeat_k,
+ global_batch_size=global_batch_size,
+ sample_params=training_sample_params,
+)
+
+evaluator_cfg = (
+ EvaluatorConfig(
+ enable_evaluate=enable_evaluate,
+ enable_initial_evaluate=enable_initial_evaluate,
+ dataset_cfg=eval_dataset_cfg,
+ tokenizer=model_path,
+ eval_sample_ratio=1.0,
+ evaluate_step=evaluate_step,
+ compute_metric_func=partial(
+ compute_metric,
+ source_normalizer={
+ 'miroRL': 'websearch',
+ 'musique': 'websearch',
+ 'websailor': 'websearch',
+ 'webdancer': 'websearch',
+ 'gaia-level1': ('gaia-level1', 'gaia'),
+ 'gaia-level2': ('gaia-level2', 'gaia'),
+ 'gaia-level3': ('gaia-level3', 'gaia'),
+ },
+ ),
+ sample_params=evaluation_sample_params,
+ max_concurrent=8192,
+ )
+ if enable_evaluate
+ else None
+)
+
+
+def group_sample_filter_func(group_samples):
+    # drop samples whose rollout produced no response ids
+    group_samples = [s for s in group_samples if s.env.rollout.response_ids is not None]
+
+    # drop groups where every reward is identical (all correct or all wrong)
+    rewards = [d.env.judger.reward["score"] for d in group_samples]
+    if len(set(rewards)) == 1:
+        print(f"filtered group with identical rewards: {rewards}")
+ return []
+ return group_samples
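+# Rationale: under a group-relative baseline (GRPO/RLOO) a group whose rewards
+# are all identical has zero advantage for every sample, so it only adds padding
+# to the update. Note the filter is currently disabled: `postprocessor_func` is
+# commented out in ReplayBufferConfig below.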
+
+
+replay_buffer_cfg = ReplayBufferConfig(
+ dataset_cfg=train_dataset_cfg,
+ dataloader_cfg=dataloader_config,
+ tokenizer=model_path,
+ # postprocessor_func=group_sample_filter_func,
+)
+
+# 5. Train worker
+model_cfg = Qwen3_5_VLMoE35BA3Config(
+ freeze_vision=True,
+ freeze_projector=True,
+)
+model_cfg.float8_cfg = None
+model_cfg.text_config.ep_size = 1
+model_cfg.text_config.z_loss_cfg = None
+model_cfg.text_config.balancing_loss_cfg = None
+model_cfg.text_config.freeze_routers = True
+model_cfg.text_config.mtp_config = MTPConfig(
+ num_layers=4,
+ loss_scaling_factor=1.0,
+ detach_mtp_lm_head_weight=True,
+ detach_mtp_inputs=True,
+ share_weights=True,
+)
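+# A 4-layer multi-token-prediction head with shared weights; going by the flag
+# names, the detach_* options keep MTP gradients from flowing back into the LM
+# head weights and the trunk hidden states, so MTP trains as an auxiliary head
+# without perturbing the policy. (Interpretation of the config names, not a
+# statement of MTPConfig internals.)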
+model_cfg.text_config.vocab_size = 251392
+# model_cfg.text_config.embed_grad_max_token_id = 251173
+
+optim_cfg = AdamWConfig(
+ lr=lr,
+ betas=(0.9, 0.95),
+ max_grad_norm=1.0,
+ weight_decay=0.1,
+ foreach=False,
+ skip_grad_norm_threshold=5,
+ eps=1e-15,
+)
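+# skip_grad_norm_threshold reads as a spike guard: steps whose global grad norm
+# exceeds 5 are skipped outright rather than merely clipped to max_grad_norm=1.0.
+# (Assumption from the name; check AdamWConfig for the exact behavior.)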
+loss_cfg = GRPOLossConfig(
+ policy_loss_cfg=dict(
+ cliprange_high=0.2,
+ cliprange_low=0.2,
+ loss_type="intern_s1_delivery.modules.pg_loss.pg_loss_fn",
+ clip_ratio_c=5.0,
+ log_prob_diff_min=-20.0,
+ log_prob_diff_max=20.0,
+ ),
+ ignore_idx=-100,
+ use_kl_loss=False,
+ kl_loss_coef=0.0,
+ kl_loss_type="low_var_kl",
+ mode="chunk",
+ chunk_size=512,
+ rollout_is=RolloutImportanceSampling(
+ rollout_is_level="token",
+ rollout_is_mode="both",
+ rollout_is_threshold=(5, 0),
+ rollout_is_mask_threshold=(5, 0.5),
+ rollout_is_veto_threshold=(20, 0),
+ ),
+)
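+# cliprange_high/low give the usual PPO ratio clip at 1 +/- 0.2; clip_ratio_c=5.0
+# looks like the dual-clip PPO constant that bounds how far negative-advantage
+# ratios can grow, and log_prob_diff_min/max clamp the train-vs-rollout logprob
+# gap before it is exponentiated into an importance ratio. This is a hedged
+# reading of the knob names; the authoritative semantics are in
+# intern_s1_delivery's pg_loss_fn.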
+lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=lr)
+fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=train_ep_size, fp32_lm_head=fp32_lm_head)
+train_worker_cfg: WorkerConfig = WorkerConfig(
+ model_cfg=model_cfg,
+ load_from=model_path,
+ optim_cfg=optim_cfg,
+ loss_cfg=loss_cfg,
+ lr_cfg=lr_cfg,
+ fsdp_cfg=fsdp_cfg,
+ sp_size=train_sp_size,
+ optimizer_steps=train_optimizer_steps,
+ pack_max_length=pack_max_length,
+)
+
+
+# 6. RL Trainer
+trainer = AgentRLTrainerConfig(
+ load_from=model_path,
+ pg=pg,
+ environment_config=environment_config,
+ dataflow_config=dataflow_config,
+ replay_buffer_config=replay_buffer_cfg,
+ train_worker_cfg=train_worker_cfg,
+ evaluator_config=evaluator_cfg,
+ tokenizer_path=model_path,
+ work_dir=work_dir,
+ total_epochs=total_epochs,
+ hf_interval=hf_interval,
+ skip_load_weights=skip_load_weights,
+ auto_resume=auto_resume,
+ checkpoint_interval=2,
+ checkpoint_maxkeep=1,
+ load_checkpoint_cfg=load_checkpoint_cfg,
+ checkpoint_no_save_optimizer=True,
+ skip_checkpoint_validation=True,
+ advantage_estimator_config=OverlongRLOOGroupEntropyBadwordAdvantageConfig(
+ entropy_upper_bound=0.65,
+ entropy_lower_bound=0.25,
+ tau_upper=0.0,
+ tau_lower=0.0,
+ coeff_min_upper=0.2,
+ coeff_min_lower=0.5,
+ overlong_filer=True,
+ badword_ratio_cost_factor=1.0,
+ tokenizer_path=model_path,
+ ),
+)
+
+
+import torch.distributed as dist
+
+from xtuner.v1.train.agent_rl_trainer import AgentRLTrainer
+
+trainer = AgentRLTrainer.from_config(trainer)
+trainer.fit()
+
+if dist.is_initialized():
+ dist.destroy_process_group()
diff --git a/examples/v1/scripts/rjob_run_train_interns1_1.sh b/examples/v1/scripts/rjob_run_train_interns1_1.sh
new file mode 100755
index 0000000000..e1e285d9bf
--- /dev/null
+++ b/examples/v1/scripts/rjob_run_train_interns1_1.sh
@@ -0,0 +1,20 @@
+JOB_NAME=xtuner_router_reply_interns2_preview_train
+NUM_NODES=8
+WORKER_GPU=8
+CPU_NUMS=12
+NODE_MEMS=960
+IMAGE=registry.h.pjlab.org.cn/ailab-puyu/xpuyu:torch-2.7.0-076676dd-0708
+# rjob delete ${JOB_NAME}
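+# --cpus-per-task scales CPU_NUMS by the per-task GPU count; the bash arithmetic
+# ternary falls back to a single multiplier when WORKER_GPU=0.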
+clusterx run \
+--job-name=${JOB_NAME} \
+--gpus-per-task=${WORKER_GPU} \
+--memory-per-task=$NODE_MEMS \
+--cpus-per-task=$((CPU_NUMS * (WORKER_GPU > 0 ? WORKER_GPU : 1))) \
+--image=${IMAGE} \
+--num-nodes ${NUM_NODES} \
+--priority 9 \
+--partition puyullm_gpu \
+--project-name ailab-puyullmgpu \
+--no-env \
+/mnt/shared-storage-user/llmit/user/liukuikun/workspace/xtuner/examples/v1/scripts/train_agentrl.sh
+# "zsh -exc '/mnt/shared-storage-user/llmit/user/liukuikun/workspace/crg_rl_projects_router_reply/scripts/run_train_interns1_1.sh'"
diff --git a/examples/v1/scripts/train_agentrl.sh b/examples/v1/scripts/train_agentrl.sh
index a19cd67e47..8afeb8fc54 100755
--- a/examples/v1/scripts/train_agentrl.sh
+++ b/examples/v1/scripts/train_agentrl.sh
@@ -11,13 +11,14 @@ export TRANSFORMERS_OFFLINE=1
export HF_EVALUATE_OFFLINE=1
export HF_HUB_OFFLINE=1
lmdeploy_dir=/mnt/shared-storage-user/llmit/user/lvchengqi/projects/interns2_rl/mtp_rl_dev/lmdeploy
xtuner_dir=/mnt/shared-storage-user/llmit/user/liukuikun/workspace/xtuner
-intern_s2_delivery_dir=/mnt/shared-storage-user/llmit/user/lvchengqi/projects/interns2_rl/crg_rl_projects/src
+intern_s2_delivery_dir=/mnt/shared-storage-user/llmit/user/liujiangning/projects/interns2_preview_agentrl_mtp/crg_rl_projects/src
lagent_dir=/mnt/shared-storage-user/llmit/user/liukuikun/workspace/lagent
xtuner_project_dir=/mnt/shared-storage-user/llmit/user/liukuikun/workspace/xtuner/projects
-
-export PYTHONPATH=$intern_s2_delivery_dir:$lmdeploy_dir:$xtuner_dir:$lagent_dir:$xtuner_project_dir:$PYTHONPATH
+transformer_model=/mnt/shared-storage-user/llmit1/user/wangziyi/.cache/huggingface/modules
+export PYTHONPATH=$intern_s2_delivery_dir:$lmdeploy_dir:$xtuner_dir:$lagent_dir:$xtuner_project_dir:$transformer_model:$PYTHONPATH
export NLTK_DATA=/mnt/shared-storage-user/llmit/user/lishuaibin/mv2yidian/nltk_data
export XTUNER_USE_FA3=1
@@ -29,7 +30,7 @@ export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True'
export XTUNER_USE_SGLANG=0
export XTUNER_USE_LMDEPLOY=1
export XTUNER_USE_VLLM=0
-export RL_LLM_MODEL=train_lkk_test_0429rc1
+export RL_LLM_MODEL=train_lkk_test_$(date "+%m%d%H%M%S")
export MASTER_PORT=6000
export WORLD_SIZE=1
export LMD_SKIP_WARMUP=1
@@ -51,8 +52,8 @@ export TRAIN_OPTIMIZER_STEPS=8
current_time=$(date "+%m%d%H")
-export CONFIG_PATH='/mnt/shared-storage-user/llmit/user/liukuikun/workspace/xtuner/examples/v1/config/interns2-35ba3-base03-20260413a-websearch-rl0415rc1_local.py'
-export WORK_DIR='/mnt/shared-storage-user/llmit1/user/liukuikun/delivery/interns2_preview_0429'
+export CONFIG_PATH='/mnt/shared-storage-user/llmit/user/liukuikun/workspace/xtuner/examples/v1/config/interns2-35ba3-base05-20260424a-rl-data260426rc1-56k-badword-mtp4_agenticrl_tb2_mtp4_0503rc1.py'
+export WORK_DIR='/mnt/shared-storage-user/llmit1/user/liukuikun/delivery/interns2_preview_0430rc9'
if [ ! -d "$WORK_DIR" ]; then
diff --git a/my_changes.patch b/my_changes.patch
new file mode 100644
index 0000000000..28c89d7ed9
--- /dev/null
+++ b/my_changes.patch
@@ -0,0 +1,62 @@
+diff --git a/xtuner/v1/ray/environment/lagent/tokenize.py b/xtuner/v1/ray/environment/lagent/tokenize.py
+index 41ae95ec..27074cf5 100644
+--- a/xtuner/v1/ray/environment/lagent/tokenize.py
++++ b/xtuner/v1/ray/environment/lagent/tokenize.py
+@@ -101,7 +101,7 @@ def tokenize(
+ f"Expected routed_experts_ref to be a base64 string, but got {type(routed_experts_ref)}"
+ )
+ ref_bytes = base64.b64decode(routed_experts_ref.encode("utf-8"))
+- routed_experts = cloudpickle.loads(ref_bytes)
++ routed_experts = ref_bytes
+ else:
+ routed_experts = None
+
+diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py
+index fff43245..ad128f17 100644
+--- a/xtuner/v1/ray/rollout/worker.py
++++ b/xtuner/v1/ray/rollout/worker.py
+@@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
+ import httpx
+ import numpy as np
+ import ray
++from ray import ObjectRef, cloudpickle
+ import requests # type: ignore[import-untyped]
+ from packaging.version import Version
+ from ray import ObjectRef
+@@ -584,8 +585,13 @@ class RolloutWorker(SingleAcceleratorWorker):
+ else:
+ cur_routed_experts = routed_experts
+
+- history_routed_experts = await input_extra_info["routed_experts"] # n, layer, expert
+- ray.internal.free(input_extra_info["routed_experts"], local_only=False)
++ history_routed_experts_key = input_extra_info["routed_experts"]
++ if isinstance(history_routed_experts_key, ObjectRef):
++ history_routed_experts = await input_extra_info["routed_experts"] # n, layer, expert
++ elif isinstance(history_routed_experts_key, str):
++ history_routed_experts_key = cloudpickle.loads(history_routed_experts_key)
++ history_routed_experts = await history_routed_experts_key
++ ray.internal.free(history_routed_experts_key, local_only=False)
+ del input_extra_info
+
+ assert (history_routed_experts.shape[0] - 1) > 0 and history_routed_experts.shape[
+diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
+index 6ab83524..57a4feae 100644
+--- a/xtuner/v1/rl/base/worker.py
++++ b/xtuner/v1/rl/base/worker.py
+@@ -7,6 +7,7 @@ from pathlib import Path
+ from typing import Dict, Iterable, List, Sequence, TypeAlias, TypedDict, cast
+
+ import ray
++from ray import cloudpickle
+ import requests
+ import torch
+ import torch.distributed as dist
+@@ -422,7 +423,7 @@ class TrainingWorker(SingleAcceleratorWorker):
+ )
+ out_rollout_routed_expert.append(rollout_routed_experts_tensor)
+ else:
+- rollout_routed_expert_refs = rollout_routed_expert
++ rollout_routed_expert_refs = cloudpickle.loads(rollout_routed_expert)
+ rollout_routed_expert = ray.get(rollout_routed_expert_refs)
+ # free obj store explicitly
+ if self.sp_mesh is None or self.sp_mesh.size() == 1:
diff --git a/projects/claw_bench/agents/internclaw/config.py b/projects/claw_bench/agents/internclaw/config.py
index c59f241f8b..d90dc424e1 100644
--- a/projects/claw_bench/agents/internclaw/config.py
+++ b/projects/claw_bench/agents/internclaw/config.py
@@ -2,7 +2,7 @@
import os
-workspace = os.environ.get("TASK_WORKSPACE", "/workspace")
+workspace = os.environ.get("TASK_WORKSPACE", "")
skills_root = f"{workspace}/skills"
model = dict(
@@ -10,7 +10,7 @@
model=dict(
model=os.environ.get(
"RL_LLM_MODEL",
- "sft_interns2_pre_base03_20260413a_lr2e5_128gpu_hf5646",
+ "train_lkk_test",
),
base_url=os.environ.get(
"RL_LLM_BASE_URL",
diff --git a/projects/claw_bench/pipeline.py b/projects/claw_bench/pipeline.py
index 9033ec6dd8..9bef56309b 100644
--- a/projects/claw_bench/pipeline.py
+++ b/projects/claw_bench/pipeline.py
@@ -13,13 +13,14 @@
from xtuner.v1.ray.environment.rl_task.hooks import (
BenchEnv,
+ DumpDaemonLogOnFailure,
InstallLagent,
ParseJudgerStdout,
PickAgent,
RenderInstruction,
RunAgentInstallDeps,
+ UploadAgentConfigSource,
UploadChosenAgent,
- WriteAgentConfig,
)
from xtuner.v1.ray.environment.rl_task.judgers import Judger
from xtuner.v1.ray.environment.rl_task.runner import Runner
@@ -63,7 +64,7 @@
PATHS = SimpleNamespace(
wrappers_bench="/tmp/wrappers/claw_bench",
wrappers_lagent="/tmp/wrappers/lagent",
- agent_config="/tmp/agent_config.json",
+ agent_config="/tmp/agent_config.py",
trajectory="/tmp/trajectory.json",
message="/tmp/message.json",
verifier="/tmp/verifier",
@@ -204,8 +205,8 @@ def claw_pipeline(
ExecHook(f"mkdir -p {ws}"),
# 7. Render instruction.md: `workspace/` → abs path + {{KEY}} → env.
# RenderInstruction(rewrites=CLAW_INSTRUCTION_REWRITES),
- # 8. Exec chosen agent's config.py on host → upload resulting JSON.
- WriteAgentConfig(dst=PATHS.agent_config),
+ # 8. Upload chosen agent's config.py — daemon execs it in-sandbox.
+ UploadAgentConfigSource(dst=PATHS.agent_config),
# 9. Run install-deps.sh if the chosen agent template has one.
RunAgentInstallDeps(workspace=ws),
],
@@ -216,7 +217,7 @@ def claw_pipeline(
extras={"WORKSPACE": ws, "CLAW_WORKSPACE": ws},
),
timeout=1800,
- post=[DownloadHook(["/workspace", "/tmp/agent_response.txt"]), ReadFileHook("/tmp/message.json", "message")],
+ post=[DownloadHook(["/workspace", "/tmp/agent_response.txt"]), ReadFileHook("/tmp/message.json", "message"), DumpDaemonLogOnFailure()],
)
return Runner(
diff --git a/projects/claw_bench/wrappers/lagent/lagent_entry.sh b/projects/claw_bench/wrappers/lagent/lagent_entry.sh
index b92ec92476..5c940ddc4f 100755
--- a/projects/claw_bench/wrappers/lagent/lagent_entry.sh
+++ b/projects/claw_bench/wrappers/lagent/lagent_entry.sh
@@ -14,7 +14,8 @@
# state_dict → kills daemon. Exits 0 on success, non-zero on failure.
#
# Relies on /tmp/lagent-py wrapper (shared conda python + PYTHONPATH), which
-# runner bootstrap writes. Agent config JSON is what daemon consumes.
+# runner bootstrap writes. --config is a Python file defining agent_config;
+# daemon execs it in-sandbox so os.environ lookups resolve here, not on host.
# ---------------------------------------------------------------------------
set -uo pipefail
@@ -55,6 +56,26 @@ fi
LOG=/tmp/agent_daemon.log
: > "$LOG"
+# If a daemon call returned {"error": "..."} over socket (exception during
+# dispatch — e.g. LLM 4xx/5xx, timeout), exit non-zero so the host side's
+# _dump_daemon_log fires and the traceback surfaces in xtuner logs.
+_die_on_daemon_error() {
+ local resp="$1" phase="$2" code="$3"
+ local err
+ err=$(printf '%s' "$resp" | "$LAGENT_PY" -c '
+import json, sys
+try:
+ print(json.loads(sys.stdin.read() or "{}").get("error", ""))
+except Exception:
+ pass
+')
+ if [ -n "$err" ]; then
+ echo "daemon error in ${phase}: ${err}" >&2
+ tail -n 500 "$LOG" >&2 || true
+ exit "$code"
+ fi
+}
+
# ── 1. Start AgentDaemon ──────────────────────────────────────────────
nohup "$LAGENT_PY" -m lagent.serving.sandbox.daemon start \
--mode agent \
@@ -102,6 +123,7 @@ CHAT_RESP=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \
tail -n 100 "$LOG" >&2 || true
exit 5
}
+_die_on_daemon_error "$CHAT_RESP" chat 5
# Extract final response content (plain text) to RESPONSE_OUT.
printf '%s' "$CHAT_RESP" | "$LAGENT_PY" -c '
@@ -121,6 +143,7 @@ STATE_RESP=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \
tail -n 100 "$LOG" >&2 || true
exit 6
}
+_die_on_daemon_error "$STATE_RESP" state_dict 6
# Wrap lagent's native state into {"trajectory": [...]} if it isn't already.
printf '%s' "$STATE_RESP" | "$LAGENT_PY" -c '
@@ -153,6 +176,7 @@ POLICY_AGENT_MESSAGES=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \
tail -n 100 "$LOG" >&2 || true
exit 7
}
+_die_on_daemon_error "$POLICY_AGENT_MESSAGES" get_messages 7
printf '%s' "$POLICY_AGENT_MESSAGES" | "$LAGENT_PY" -c '
import json, sys
diff --git a/projects/tb2_eval/agents/interndp/config.py b/projects/tb2_eval/agents/interndp/config.py
index 71037ae0da..b0efb194f4 100644
--- a/projects/tb2_eval/agents/interndp/config.py
+++ b/projects/tb2_eval/agents/interndp/config.py
@@ -43,7 +43,7 @@
model=dict(
model=os.environ.get(
"RL_LLM_MODEL",
- "train_lkk_test",
+ "",
),
base_url=os.environ.get(
"RL_LLM_BASE_URL",
@@ -52,8 +52,8 @@
api_key=os.environ.get("RL_LLM_API_KEY", "sk-admin"),
),
sample_params=dict(temperature=0.7, top_p=1.0, top_k=50),
- timeout=600,
- max_retry=500,
+ timeout=900,
+ max_retry=1,
sleep_interval=5,
extra_body=dict(spaces_between_special_tokens=False),
)
@@ -70,8 +70,10 @@
env_agent = dict(
type="lagent.agents.env_agent.RLEnvAgent",
actions=base_actions,
- max_turn=25,
+ max_turn=100,
enable_no_thinking_penalty=False,
+ max_tool_response_length=4096,
+ tool_response_truncate_side="left",
)
agent_config = dict(
diff --git a/projects/tb2_eval/pipeline.py b/projects/tb2_eval/pipeline.py
index 008e2cc5e6..5d958f6a2b 100644
--- a/projects/tb2_eval/pipeline.py
+++ b/projects/tb2_eval/pipeline.py
@@ -27,12 +27,13 @@
from xtuner.v1.ray.environment.rl_task.hooks import (
BenchEnv,
+ DumpDaemonLogOnFailure,
InstallLagent,
ParseJudgerStdout,
PickAgent,
RunAgentInstallDeps,
+ UploadAgentConfigSource,
UploadChosenAgent,
- WriteAgentConfig,
)
from xtuner.v1.ray.environment.rl_task.judgers import Judger
from xtuner.v1.ray.environment.rl_task.runner import Runner
@@ -58,7 +59,7 @@
PATHS = SimpleNamespace(
wrappers_bench="/tmp/wrappers/tb2_eval",
wrappers_lagent="/tmp/wrappers/lagent",
- agent_config="/tmp/agent_config.json",
+ agent_config="/tmp/agent_config.py",
trajectory="/tmp/trajectory.json",
message="/tmp/message.json",
tests="/tests",
@@ -91,7 +92,7 @@
]
# Placeholder image — each task overrides this via sandbox_spec in extra_info.
-DEFAULT_SANDBOX = SandboxSpec(image="tb2-eval-placeholder", ttl_seconds=1800, workspace_path="/app")
+DEFAULT_SANDBOX = SandboxSpec(image="tb2-eval-placeholder", ttl_seconds=11700, workspace_path="/app")
# ─────────────────────────────────────────────────────────────────
@@ -205,8 +206,8 @@ def tb2_eval_pipeline(
UploadChosenAgent(target_dir=f"{ws}/agent/"),
# 7. Ensure workspace dir exists.
ExecHook(f"mkdir -p {ws}"),
- # 8. Exec chosen agent's config.py on host → upload resulting JSON.
- WriteAgentConfig(dst=PATHS.agent_config),
+ # 8. Upload chosen agent's config.py — daemon execs it in-sandbox.
+ UploadAgentConfigSource(dst=PATHS.agent_config),
# 9. Run install-deps.sh if the chosen agent template has one.
RunAgentInstallDeps(workspace=ws),
],
@@ -215,10 +216,11 @@ def tb2_eval_pipeline(
workspace=ws,
extras={"WORKSPACE": ws},
),
- timeout=900,
+ timeout=10800,
post=[
DownloadHook([ws, "/tmp/agent_response.txt"]),
ReadFileHook("/tmp/message.json", "message"),
+ DumpDaemonLogOnFailure(),
],
)
diff --git a/projects/tb2_eval/wrappers/lagent/lagent_entry.sh b/projects/tb2_eval/wrappers/lagent/lagent_entry.sh
index b92ec92476..5c940ddc4f 100644
--- a/projects/tb2_eval/wrappers/lagent/lagent_entry.sh
+++ b/projects/tb2_eval/wrappers/lagent/lagent_entry.sh
@@ -14,7 +14,8 @@
# state_dict → kills daemon. Exits 0 on success, non-zero on failure.
#
# Relies on /tmp/lagent-py wrapper (shared conda python + PYTHONPATH), which
-# runner bootstrap writes. Agent config JSON is what daemon consumes.
+# runner bootstrap writes. --config is a Python file defining agent_config;
+# daemon execs it in-sandbox so os.environ lookups resolve here, not on host.
# ---------------------------------------------------------------------------
set -uo pipefail
@@ -55,6 +56,26 @@ fi
LOG=/tmp/agent_daemon.log
: > "$LOG"
+# If a daemon call returned {"error": "..."} over socket (exception during
+# dispatch — e.g. LLM 4xx/5xx, timeout), exit non-zero so the host side's
+# _dump_daemon_log fires and the traceback surfaces in xtuner logs.
+_die_on_daemon_error() {
+ local resp="$1" phase="$2" code="$3"
+ local err
+ err=$(printf '%s' "$resp" | "$LAGENT_PY" -c '
+import json, sys
+try:
+ print(json.loads(sys.stdin.read() or "{}").get("error", ""))
+except Exception:
+ pass
+')
+ if [ -n "$err" ]; then
+ echo "daemon error in ${phase}: ${err}" >&2
+ tail -n 500 "$LOG" >&2 || true
+ exit "$code"
+ fi
+}
+
# ── 1. Start AgentDaemon ──────────────────────────────────────────────
nohup "$LAGENT_PY" -m lagent.serving.sandbox.daemon start \
--mode agent \
@@ -102,6 +123,7 @@ CHAT_RESP=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \
tail -n 100 "$LOG" >&2 || true
exit 5
}
+_die_on_daemon_error "$CHAT_RESP" chat 5
# Extract final response content (plain text) to RESPONSE_OUT.
printf '%s' "$CHAT_RESP" | "$LAGENT_PY" -c '
@@ -121,6 +143,7 @@ STATE_RESP=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \
tail -n 100 "$LOG" >&2 || true
exit 6
}
+_die_on_daemon_error "$STATE_RESP" state_dict 6
# Wrap lagent's native state into {"trajectory": [...]} if it isn't already.
printf '%s' "$STATE_RESP" | "$LAGENT_PY" -c '
@@ -153,6 +176,7 @@ POLICY_AGENT_MESSAGES=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \
tail -n 100 "$LOG" >&2 || true
exit 7
}
+_die_on_daemon_error "$POLICY_AGENT_MESSAGES" get_messages 7
printf '%s' "$POLICY_AGENT_MESSAGES" | "$LAGENT_PY" -c '
import json, sys
diff --git a/projects/tb2_eval/wrappers/tb2_eval/emit_judger_result_from_ctrf.py b/projects/tb2_eval/wrappers/tb2_eval/emit_judger_result_from_ctrf.py
index e31aab8afc..674236272e 100644
--- a/projects/tb2_eval/wrappers/tb2_eval/emit_judger_result_from_ctrf.py
+++ b/projects/tb2_eval/wrappers/tb2_eval/emit_judger_result_from_ctrf.py
@@ -1,8 +1,12 @@
#!/usr/bin/env python3
-"""Parse a CTRF JSON report + test log → emit a ``JudgerResult`` line to stdout.
-
-Honors ``@pytest.mark.weight(N)`` via pytest-json-ctrf's ``extra``/``metadata``
-section. Tests with no explicit weight get 1.0.
+"""Emit a ``JudgerResult`` line to stdout that matches official TB2 scoring.
+
+The bench's ``tests/test.sh`` writes the authoritative binary outcome to
+``/logs/verifier/reward.txt`` (``1`` iff every pytest invocation in that
+script exited 0, else ``0``) — including the multi-pytest case in tasks
+like ``fix-code-vulnerability``. We read that file as the source of truth
+for ``total``. CTRF is parsed only for per-test observability in the
+``criteria`` field and never used for scoring.
"""
from __future__ import annotations
@@ -13,26 +17,6 @@
from pathlib import Path
-def _extract_weight(test: dict) -> float:
- for section in ("extra", "metadata"):
- extras = test.get(section) or []
- if isinstance(extras, list):
- for e in extras:
- if isinstance(e, dict) and e.get("key") == "weight":
- try:
- return float(e.get("value", 1.0))
- except (TypeError, ValueError):
- return 1.0
- elif isinstance(extras, dict):
- w = extras.get("weight")
- if w is not None:
- try:
- return float(w)
- except (TypeError, ValueError):
- return 1.0
- return 1.0
-
-
def _log_tail(path: Path, bytes_: int = 800) -> str:
try:
return path.read_text(errors="replace")[-bytes_:]
@@ -40,60 +24,81 @@ def _log_tail(path: Path, bytes_: int = 800) -> str:
return ""
-def main() -> int:
- ap = argparse.ArgumentParser()
- ap.add_argument("--ctrf", required=True)
- ap.add_argument("--log", required=True)
- ap.add_argument("--pytest-rc", type=int, required=True)
- ap.add_argument("--judger-name", default="rule_grader")
- args = ap.parse_args()
+def _read_reward(path: Path) -> float | None:
+ try:
+ raw = path.read_text().strip()
+ except Exception:
+ return None
+ if not raw:
+ return None
+ try:
+ return float(raw)
+ except ValueError:
+ return None
- ctrf_path = Path(args.ctrf)
- log_path = Path(args.log)
+def _parse_criteria(ctrf_path: Path) -> tuple[dict[str, dict[str, float]], int, str | None]:
+ """Parse CTRF into per-test criteria for observability only.
+
+ Returns:
+ tuple[dict[str, dict[str, float]], int, str | None]: ``(criteria, test_count,
+ error)``. ``criteria`` maps test name to ``{"score": 0.0|1.0}``. ``error`` is
+ ``None`` on success or a message describing why CTRF was unreadable.
+ """
try:
data = json.loads(ctrf_path.read_text())
except Exception as exc:
- print(
- json.dumps(
- {
- "judger_name": args.judger_name,
- "total": 0.0,
- "error": f"ctrf missing/parse failed: {exc}. log tail: {_log_tail(log_path)}",
- },
- ensure_ascii=False,
- )
- )
- return 0
-
+ return {}, 0, f"ctrf missing/parse failed: {exc}"
tests = (data.get("results", {}) or {}).get("tests", []) or []
criteria: dict[str, dict[str, float]] = {}
for t in tests:
name = t.get("name", "unknown")
passed = t.get("status") == "passed"
- weight = _extract_weight(t)
- criteria[name] = {"score": 1.0 if passed else 0.0, "weight": weight}
+ criteria[name] = {"score": 1.0 if passed else 0.0}
+ return criteria, len(tests), None
- total_w = sum(c["weight"] for c in criteria.values())
- if total_w <= 0:
- total = 0.0
+
+def main() -> int:
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--ctrf", required=True)
+ ap.add_argument("--log", required=True)
+ ap.add_argument("--reward-file", required=True)
+ ap.add_argument("--pytest-rc", type=int, required=True)
+ ap.add_argument("--judger-name", default="rule_grader")
+ args = ap.parse_args()
+
+ ctrf_path = Path(args.ctrf)
+ log_path = Path(args.log)
+ reward_path = Path(args.reward_file)
+
+ reward = _read_reward(reward_path)
+ criteria, test_count, ctrf_error = _parse_criteria(ctrf_path)
+
+ result: dict = {
+ "judger_name": args.judger_name,
+ "criteria": criteria,
+ "metadata": {
+ "pytest_rc": args.pytest_rc,
+ "test_count": test_count,
+ "reward_source": "reward.txt" if reward is not None else "pytest_rc",
+ },
+ }
+
+ if reward is not None:
+ result["total"] = round(reward, 4)
else:
- total = sum(c["score"] * c["weight"] for c in criteria.values()) / total_w
-
- print(
- json.dumps(
- {
- "judger_name": args.judger_name,
- "total": round(total, 4),
- "criteria": criteria,
- "metadata": {
- "pytest_rc": args.pytest_rc,
- "test_count": len(tests),
- },
- },
- ensure_ascii=False,
+ # reward.txt missing/unreadable: fall back to test.sh exit code.
+ result["total"] = 1.0 if args.pytest_rc == 0 else 0.0
+ result["error"] = (
+ f"reward file unreadable at {reward_path}; fell back to pytest_rc. "
+ f"log tail: {_log_tail(log_path)}"
)
- )
+
+ if ctrf_error is not None:
+ # CTRF is observability-only; surface the parse error but don't change total.
+ result.setdefault("error", ctrf_error)
+
+ print(json.dumps(result, ensure_ascii=False))
return 0
diff --git a/projects/tb2_eval/wrappers/tb2_eval/run_tests.sh b/projects/tb2_eval/wrappers/tb2_eval/run_tests.sh
index 9a523d19de..6a7bd318a6 100644
--- a/projects/tb2_eval/wrappers/tb2_eval/run_tests.sh
+++ b/projects/tb2_eval/wrappers/tb2_eval/run_tests.sh
@@ -5,10 +5,13 @@
# The bench ships its own test harness at /tests/test.sh which:
# - installs pytest + pytest-json-ctrf + /tests/test_requirements.txt
# - runs /tests/test_outputs.py with --ctrf /logs/verifier/ctrf.json
+# - writes the authoritative 0/1 outcome to /logs/verifier/reward.txt
+# (for multi-pytest tasks like fix-code-vulnerability, this is the
+# AND of all pytest exit codes — matching official TB2 scoring)
#
-# We invoke it and hand the resulting CTRF to our shared emitter so
-# the stage's stdout is a single JudgerResult JSON line (what
-# ParseJudgerStdout expects).
+# We invoke it and hand both reward.txt and the resulting CTRF to our
+# shared emitter. reward.txt drives the JudgerResult `total`; CTRF is
+# parsed for per-test observability only.
#
# Env:
# $TESTS_DIR tests directory inside the sandbox (default: /tests)
@@ -36,5 +39,6 @@ fi
"$PY" "$WRAPPER_DIR/emit_judger_result_from_ctrf.py" \
--ctrf /logs/verifier/ctrf.json \
--log "$TEST_LOG" \
+ --reward-file /logs/verifier/reward.txt \
--pytest-rc "$TEST_RC" \
--judger-name "$JUDGER_NAME"
\ No newline at end of file
diff --git a/projects/tb2_rl/agents/interndp/config.py b/projects/tb2_rl/agents/interndp/config.py
index 71037ae0da..5758037749 100644
--- a/projects/tb2_rl/agents/interndp/config.py
+++ b/projects/tb2_rl/agents/interndp/config.py
@@ -43,7 +43,7 @@
model=dict(
model=os.environ.get(
"RL_LLM_MODEL",
- "train_lkk_test",
+ "",
),
base_url=os.environ.get(
"RL_LLM_BASE_URL",
@@ -52,8 +52,8 @@
api_key=os.environ.get("RL_LLM_API_KEY", "sk-admin"),
),
sample_params=dict(temperature=0.7, top_p=1.0, top_k=50),
- timeout=600,
- max_retry=500,
+ timeout=1800,
+ max_retry=1,
sleep_interval=5,
extra_body=dict(spaces_between_special_tokens=False),
)
@@ -70,7 +70,9 @@
env_agent = dict(
type="lagent.agents.env_agent.RLEnvAgent",
actions=base_actions,
- max_turn=25,
+ max_turn=100,
+ max_tool_response_length=4096,
+ tool_response_truncate_side="left",
enable_no_thinking_penalty=False,
)
diff --git a/projects/tb2_rl/pipeline.py b/projects/tb2_rl/pipeline.py
index 8e45ee1a29..62ec8fe4bb 100644
--- a/projects/tb2_rl/pipeline.py
+++ b/projects/tb2_rl/pipeline.py
@@ -24,12 +24,13 @@
from xtuner.v1.ray.environment.rl_task.hooks import (
BenchEnv,
+ DumpDaemonLogOnFailure,
InstallLagent,
ParseJudgerStdout,
PickAgent,
RunAgentInstallDeps,
+ UploadAgentConfigSource,
UploadChosenAgent,
- WriteAgentConfig,
)
from xtuner.v1.ray.environment.rl_task.judgers import Judger
from xtuner.v1.ray.environment.rl_task.runner import Runner
@@ -55,7 +56,7 @@
PATHS = SimpleNamespace(
wrappers_bench="/tmp/wrappers/tb2_rl",
wrappers_lagent="/tmp/wrappers/lagent",
- agent_config="/tmp/agent_config.json",
+ agent_config="/tmp/agent_config.py",
trajectory="/tmp/trajectory.json",
message="/tmp/message.json",
tests="/tests",
@@ -87,7 +88,7 @@
),
]
-DEFAULT_SANDBOX = SandboxSpec(image="t-data-processing-v1", ttl_seconds=1800, workspace_path="/app")
+DEFAULT_SANDBOX = SandboxSpec(image="t-data-processing-v1", ttl_seconds=11700, workspace_path="/app")
# ─────────────────────────────────────────────────────────────────
@@ -201,8 +202,8 @@ def tb2_rl_pipeline(
UploadChosenAgent(target_dir=f"{ws}/agent/"),
# 7. Ensure workspace dir exists.
ExecHook(f"mkdir -p {ws}"),
- # 8. Exec chosen agent's config.py on host → upload resulting JSON.
- WriteAgentConfig(dst=PATHS.agent_config),
+ # 8. Upload chosen agent's config.py — daemon execs it in-sandbox.
+ UploadAgentConfigSource(dst=PATHS.agent_config),
# 9. Run install-deps.sh if the chosen agent template has one.
RunAgentInstallDeps(workspace=ws),
],
@@ -211,10 +212,11 @@ def tb2_rl_pipeline(
workspace=ws,
extras={"WORKSPACE": ws},
),
- timeout=900,
+ timeout=10800,
post=[
DownloadHook([ws, "/tmp/agent_response.txt"]),
ReadFileHook("/tmp/message.json", "message"),
+ DumpDaemonLogOnFailure(),
],
)
diff --git a/projects/tb2_rl/wrappers/lagent/lagent_entry.sh b/projects/tb2_rl/wrappers/lagent/lagent_entry.sh
index b92ec92476..5c940ddc4f 100755
--- a/projects/tb2_rl/wrappers/lagent/lagent_entry.sh
+++ b/projects/tb2_rl/wrappers/lagent/lagent_entry.sh
@@ -14,7 +14,8 @@
# state_dict → kills daemon. Exits 0 on success, non-zero on failure.
#
# Relies on /tmp/lagent-py wrapper (shared conda python + PYTHONPATH), which
-# runner bootstrap writes. Agent config JSON is what daemon consumes.
+# runner bootstrap writes. --config is a Python file defining agent_config;
+# daemon execs it in-sandbox so os.environ lookups resolve here, not on host.
# ---------------------------------------------------------------------------
set -uo pipefail
@@ -55,6 +56,26 @@ fi
LOG=/tmp/agent_daemon.log
: > "$LOG"
+# If a daemon call returned {"error": "..."} over socket (exception during
+# dispatch — e.g. LLM 4xx/5xx, timeout), exit non-zero so the host side's
+# _dump_daemon_log fires and the traceback surfaces in xtuner logs.
+_die_on_daemon_error() {
+ local resp="$1" phase="$2" code="$3"
+ local err
+ err=$(printf '%s' "$resp" | "$LAGENT_PY" -c '
+import json, sys
+try:
+ print(json.loads(sys.stdin.read() or "{}").get("error", ""))
+except Exception:
+ pass
+')
+ if [ -n "$err" ]; then
+ echo "daemon error in ${phase}: ${err}" >&2
+ tail -n 500 "$LOG" >&2 || true
+ exit "$code"
+ fi
+}
+
# ── 1. Start AgentDaemon ──────────────────────────────────────────────
nohup "$LAGENT_PY" -m lagent.serving.sandbox.daemon start \
--mode agent \
@@ -102,6 +123,7 @@ CHAT_RESP=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \
tail -n 100 "$LOG" >&2 || true
exit 5
}
+_die_on_daemon_error "$CHAT_RESP" chat 5
# Extract final response content (plain text) to RESPONSE_OUT.
printf '%s' "$CHAT_RESP" | "$LAGENT_PY" -c '
@@ -121,6 +143,7 @@ STATE_RESP=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \
tail -n 100 "$LOG" >&2 || true
exit 6
}
+_die_on_daemon_error "$STATE_RESP" state_dict 6
# Wrap lagent's native state into {"trajectory": [...]} if it isn't already.
printf '%s' "$STATE_RESP" | "$LAGENT_PY" -c '
@@ -153,6 +176,7 @@ POLICY_AGENT_MESSAGES=$("$LAGENT_PY" -m lagent.serving.sandbox.daemon call \
tail -n 100 "$LOG" >&2 || true
exit 7
}
+_die_on_daemon_error "$POLICY_AGENT_MESSAGES" get_messages 7
printf '%s' "$POLICY_AGENT_MESSAGES" | "$LAGENT_PY" -c '
import json, sys
diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py
index 8cf71ca134..5e80f6d3c9 100644
--- a/xtuner/v1/data_proto/rl_data.py
+++ b/xtuner/v1/data_proto/rl_data.py
@@ -17,11 +17,14 @@
if TYPE_CHECKING:
+ import numpy as np
import ray
RayObjectRef = ray.ObjectRef
+ NumpyArray = np.ndarray
else:
RayObjectRef: TypeAlias = Any
+ NumpyArray: TypeAlias = Any
logger = get_logger()
@@ -116,7 +119,7 @@ class RLDatasetItem(BaseModel):
class RolloutExtraInfo(TypedDict):
- routed_experts: NotRequired[list[int] | str | RayObjectRef] # type: ignore[valid-type]
+ routed_experts: NotRequired[list[int] | RayObjectRef | NumpyArray] # type: ignore[valid-type]
partial_rollout_input_ids: NotRequired[list[int]]
diff --git a/xtuner/v1/ray/base/accelerator.py b/xtuner/v1/ray/base/accelerator.py
index cf05509525..866ae51777 100644
--- a/xtuner/v1/ray/base/accelerator.py
+++ b/xtuner/v1/ray/base/accelerator.py
@@ -1,4 +1,5 @@
import os
+from datetime import timedelta
from typing import Any, Dict, List, Literal, Tuple, TypeVar
import ray
@@ -246,9 +247,15 @@ def setup_distributed(self, rank: int, master_addr: str, master_port: int, world
else:
raise ValueError(f"Unsupported accelerator architecture: {self.accelerator}")
        # initialize the process group from environment variables
+ # NCCL watchdog aborts collectives after this timeout. The default
+ # (10 min) is too tight for long-context MoE RL jobs where a single
+ # slow rank on a heavy allgather can burn 15+ min. Override via
+ # XTUNER_DIST_TIMEOUT_MIN env var.
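+        # e.g. XTUNER_DIST_TIMEOUT_MIN=120 raises the watchdog budget to two hours.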
+ timeout_min = int(os.environ.get("XTUNER_DIST_TIMEOUT_MIN", 60))
dist.init_process_group(
backend=backend,
        init_method="env://",  # tells PyTorch to read the rendezvous settings from env vars
+ timeout=timedelta(minutes=timeout_min),
)
def test_all_reduce(self):
diff --git a/xtuner/v1/ray/config/worker.py b/xtuner/v1/ray/config/worker.py
index 279849d41c..55d28fc1f2 100644
--- a/xtuner/v1/ray/config/worker.py
+++ b/xtuner/v1/ray/config/worker.py
@@ -186,13 +186,6 @@ class RolloutConfig(BaseModel):
help="Whether to enable returning routed experts for the rollout worker.",
),
] = False
- return_routed_experts_key: Annotated[
- bool,
- Parameter(
- group=infer_group,
- help="Whether to keep LMDeploy routed experts as shared-store string keys before training.",
- ),
- ] = False
launch_server_method: Annotated[
Literal["ray", "multiprocessing"],
Parameter(
diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py
index 94206d7b76..30144d2769 100644
--- a/xtuner/v1/ray/dataflow/replay_buffer.py
+++ b/xtuner/v1/ray/dataflow/replay_buffer.py
@@ -22,12 +22,10 @@
RLEnvDataItem,
RLExtraDataItem,
RLUIDItem,
- RolloutExtraInfo,
RolloutState,
is_valid_for_replaybuffer,
)
from xtuner.v1.datasets.config import DataloaderConfig
-from xtuner.v1.ray.rollout.lmdeploy import get_lmdeploy_routed_experts_ref
from xtuner.v1.ray.utils import free_object_refs
from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger
from xtuner.v1.utils.device import get_device
@@ -375,24 +373,6 @@ def __init__(self, replay_buffer_cfg):
self.sample_from_aborted_count = 0
self.sample_from_expired_count = 0
- def _pop_routed_experts_from_extra_info(
- self, extra_info: RolloutExtraInfo, *, free_ref: bool = False
- ) -> ObjectRef | None:
- if "routed_experts" not in extra_info:
- return None
-
- routed_experts = extra_info["routed_experts"]
- if isinstance(routed_experts, str):
- routed_experts = get_lmdeploy_routed_experts_ref(routed_experts)
- elif not isinstance(routed_experts, ObjectRef):
- routed_experts = ray.put(routed_experts)
-
- del extra_info["routed_experts"]
- if free_ref:
- free_object_refs([routed_experts])
- return None
- return routed_experts
-
def _update_replay_meta_state(self, replay_meta: ReplayMeta, new_state: RolloutState):
for observation_id in replay_meta.observation_ids:
old_state = self._observations2states.get(observation_id)
@@ -408,10 +388,6 @@ def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: R
reused."""
old_obs_refs = [ref for ref in replay_meta.observation_refs if ref is not None]
if old_obs_refs:
- for old_obs_ref in old_obs_refs:
- old_env = ray.get(old_obs_ref)
- if hasattr(old_env, "rollout"):
- self._pop_routed_experts_from_extra_info(old_env.rollout.extra_info, free_ref=True)
ray.internal.free(old_obs_refs, local_only=False)
replay_meta.observation_refs = [ray.put(RLEnvDataItem()) for _ in replay_meta.observation_ids]
self._update_replay_meta_state(replay_meta, new_state)
@@ -590,6 +566,33 @@ def _restore_nested_objectrefs(self, obj: Any):
return obj
return obj
+ def convert_to_ray_objref(self, data_item: RLDataFlowItem):
+ """Converts large tensors in RLDataFlowItem to ray.ObjectRefs.
+
+ Args:
+ data_item (RLDataFlowItem): The data item containing large tensors.
+ Returns:
+ RLDataFlowItem: The data item with large tensors converted to ray.ObjectRefs.
+ """
+ # convert data.multimodal_train_info to ray.ObjectRef
+ if hasattr(data_item.data, "multimodal_train_info"):
+ multimodal_info = data_item.data.multimodal_train_info
+ if multimodal_info and "pixel_values" in multimodal_info:
+                # the whole group shares one data_item.data, so it only needs converting once
+ if not isinstance(multimodal_info["pixel_values"], ray.ObjectRef):
+ pixel_values_ref = ray.put(multimodal_info["pixel_values"])
+ del multimodal_info["pixel_values"]
+ data_item.data.multimodal_train_info["pixel_values"] = pixel_values_ref # type: ignore[index]
+ # convert rollout.extra_info.router_experts to ray.ObjectRef
+ if "routed_experts" in data_item.env.rollout.extra_info:
+ val = data_item.env.rollout.extra_info["routed_experts"]
+ # str = uuid key into RoutedExpertStore; ObjectRef = legacy path.
+ # Either is left untouched; only raw tensors get ray.put'd.
+ if not isinstance(val, (ray.ObjectRef, str)):
+ routed_experts_ref = ray.put(val)
+ del data_item.env.rollout.extra_info["routed_experts"]
+ data_item.env.rollout.extra_info["routed_experts"] = routed_experts_ref
+
def has_objectref(self, item: RLDataFlowItem) -> bool:
def check(obj):
if isinstance(obj, ray.ObjectRef):
@@ -718,7 +721,11 @@ def _sample_from_expired_storage(self) -> List[RLDataFlowItem]:
for sample in group_samples:
assert sample.data.input_ids and sample.data.num_tokens, "input_ids or num_tokens is empty!"
- self._pop_routed_experts_from_extra_info(sample.env.rollout.extra_info, free_ref=True)
+ if "routed_experts" in sample.env.rollout.extra_info:
+ val = sample.env.rollout.extra_info["routed_experts"]
+ if isinstance(val, ray.ObjectRef):
+ ray.internal.free(val, local_only=False)
+ del sample.env.rollout.extra_info["routed_experts"]
del sample.env
        sample.env = RLEnvDataItem()  # reset env data
sample.uid.action_id = action_id
@@ -754,7 +761,11 @@ def _sample_from_aborted_storage(self) -> List[RLDataFlowItem]:
assert sample.data.input_ids and sample.data.num_tokens, "input_ids or num_tokens is empty!"
if not self.enable_partial_rollout:
            # clear last round's response_ids and other env data
- self._pop_routed_experts_from_extra_info(sample.env.rollout.extra_info, free_ref=True)
+ if "routed_experts" in sample.env.rollout.extra_info:
+ val = sample.env.rollout.extra_info["routed_experts"]
+ if isinstance(val, ray.ObjectRef):
+ ray.internal.free(val, local_only=False)
+ del sample.env.rollout.extra_info["routed_experts"]
del sample.env
sample.env = RLEnvDataItem()
sample.uid.version = 0
diff --git a/xtuner/v1/ray/environment/agent_env.py b/xtuner/v1/ray/environment/agent_env.py
index 6a3b964cae..a1f7b1a184 100644
--- a/xtuner/v1/ray/environment/agent_env.py
+++ b/xtuner/v1/ray/environment/agent_env.py
@@ -1,6 +1,7 @@
import asyncio
import inspect
import os
+import pickle
import traceback
from copy import deepcopy
from typing import Callable, List, Self, Tuple
@@ -20,6 +21,11 @@
from .base_env import BaseEnvironment
+# Memory diagnostics thresholds (bytes). Set via env var to disable (0).
+_ITEM_SIZE_LOG_THRESHOLD = int(os.environ.get("XTUNER_ITEM_SIZE_LOG_MB", "1")) * 1024 * 1024
+_RSS_MONITOR_INTERVAL = int(os.environ.get("XTUNER_RSS_MONITOR_SEC", "60"))
+
+
def check_dead_actors():
    # fetch the list of all actors
from ray.util.state import list_actors
@@ -35,6 +41,38 @@ def check_dead_actors():
return dead_actors
+def _log_item_size_if_large(item) -> None:
+ """Diagnostic: pickle-measure item + per-field breakdown when total exceeds threshold.
+
+ Runs on every ``_inner_agent_call`` iteration; on fresh samples the total is a few KB
+ so the early-exit keeps steady-state overhead negligible. On resume / abort paths the
+ item carries ``agent_state_dict`` / ``agent_message_dict`` from prior turns — we want
+ visibility into which field dominates.
+ """
+ try:
+ total_size = len(pickle.dumps(item))
+ except Exception as exc:
+ get_logger().debug(f"[item-size] pickle.dumps failed: {exc}")
+ return
+ if total_size < _ITEM_SIZE_LOG_THRESHOLD:
+ return
+ field_sizes: dict = {}
+ try:
+ for key, val in item.env.rollout.extra_info.items():
+ try:
+ field_sizes[key] = len(pickle.dumps(val))
+ except Exception:
+ field_sizes[key] = -1 # type: ignore[assignment]
+ except Exception:
+ pass
+ get_logger().warning(
+ f"[item-size] sample={getattr(item.uid, 'observation_id', '?')} "
+ f"total={total_size / 1e6:.1f}MB "
+ f"state={item.env.rollout.state} "
+ f"fields={ {k: f'{v / 1e6:.1f}MB' for k, v in field_sizes.items()} }"
+ )
+
+
@ray.remote(max_concurrency=int(os.environ.get("XTUNER_MAX_CONCURRENCY", 2000))) # type: ignore[call-overload]
class AgentEnvironment(BaseEnvironment):
def __init__(
@@ -54,6 +92,24 @@ def __init__(
self.agent_cfg = agent_cfg
self.preprocess_func = preprocess_func
self.postprocess_func = postprocess_func
+ if _RSS_MONITOR_INTERVAL > 0:
+ self._rss_task = asyncio.get_event_loop().create_task(self._rss_monitor())
+
+ async def _rss_monitor(self):
+ """Diagnostic: periodically log this actor's RSS so cross-sample leaks surface."""
+ try:
+ import psutil
+ except ImportError:
+ get_logger().warning("[actor-rss] psutil not available, disabling RSS monitor")
+ return
+ proc = psutil.Process()
+ while True:
+ try:
+ rss_gb = proc.memory_info().rss / 1e9
+ get_logger().info(f"[actor-rss] env={self.environment} pid={proc.pid} rss={rss_gb:.2f}GB")
+ except Exception as exc:
+ get_logger().warning(f"[actor-rss] read failed: {exc}")
+ await asyncio.sleep(_RSS_MONITOR_INTERVAL)
async def generate( # type: ignore[override]
self, group_data_items: List[RLDataFlowItem], sample_params=None, extra_params=None
@@ -68,6 +124,8 @@ async def _inner_agent_call(item):
if item.env.rollout.state == RolloutState.ABORTED:
agent.load_state_dict(item.env.rollout.extra_info["agent_state_dict"]) # type: ignore[operator]
+ if _ITEM_SIZE_LOG_THRESHOLD > 0:
+ _log_item_size_if_large(item)
_item = deepcopy(item)
_item.env.agent.extra_info["agent"] = agent
agent_inputs = self.preprocess_func(self, _item)
diff --git a/xtuner/v1/ray/environment/install_agent_env.py b/xtuner/v1/ray/environment/install_agent_env.py
index 809567fcb5..59b0baf375 100644
--- a/xtuner/v1/ray/environment/install_agent_env.py
+++ b/xtuner/v1/ray/environment/install_agent_env.py
@@ -112,8 +112,9 @@ async def _inner_agent_call(item):
uid,
provider=self.provider,
lagent_src_dir=DEFAULT_LAGENT_SRC,
- llm_base_url=None,
- llm_api_key=None,
+ llm_model=os.environ.get("RL_LLM_MODEL"),
+ llm_base_url=os.environ.get("RL_LLM_BASE_URL"),
+ llm_api_key=os.environ.get("RL_LLM_API_KEY"),
)
except BaseException as exc:
get_logger().error(
@@ -144,7 +145,26 @@ async def _inner_agent_call(item):
# passed_data_items.append(sample)
continue
else:
- sample.env.agent.extra_info["message_dict"] = result["env"]["agent"]["message_dict"]
+ # Defend against silent-pass / truncated trajectory: rc=0 but
+ # last message in policy_agent.messages lacks the fields
+ # postprocess will read. Same heuristic as
+ # DumpDaemonLogOnFailure so log signal + filter stay
+ # consistent.
+ msg_dict = (result["env"]["agent"] or {}).get("message_dict") or {}
+ messages = msg_dict.get("policy_agent.messages") or []
+ last = messages[-1] if messages else {}
+ required = ("raw_content", "raw_content_ids", "raw_content_logprobs")
+ missing = [k for k in required if not last.get(k)]
+ if missing:
+ get_logger().warning(
+ f"silent-pass rollout skipped: "
+ f"uid={sample.uid.observation_id} "
+ f"task_id={result.get('data', {}).get('extra_info', {}).get('task_id')} "
+ f"missing={missing}"
+ )
+ continue
+ sample.env.agent.extra_info["message_dict"] = msg_dict
+ sample.env.agent.extra_info["daemon_log"] = result["env"]["agent"].get("daemon_log", "")
sample.env.judger.extra_info.update(result["env"]["judger"])
completed_data_items.append(sample)
completed_data_items_result = self.postprocess_func(self, completed_data_items) # type: ignore[arg-type]
diff --git a/xtuner/v1/ray/environment/lagent/parsers.py b/xtuner/v1/ray/environment/lagent/parsers.py
index 20717db26d..3ab2a2fb1b 100644
--- a/xtuner/v1/ray/environment/lagent/parsers.py
+++ b/xtuner/v1/ray/environment/lagent/parsers.py
@@ -109,11 +109,21 @@ def parse_response(self, data: AgentMessage) -> AgentMessage:
parameters = {}
for p_match in param_matches:
p_name = p_match.group(1).strip()
- p_value = p_match.group(2).strip()
+ # Strip exactly one leading and one trailing newline — those
+ # are the formatting newlines inserted by the chat template
+ # around the value. Preserving additional newlines lets the
+ # model express trailing whitespace when it matters (e.g. a
+ # '\n' at the end of terminal keystrokes that triggers bash
+ # to execute the command).
+ p_raw = p_match.group(2)
+ if p_raw.startswith("\n"):
+ p_raw = p_raw[1:]
+ if p_raw.endswith("\n"):
+ p_raw = p_raw[:-1]
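+            # e.g. a template-wrapped value "\nls -la\n\n" becomes "ls -la\n":
+            # one formatting newline stripped from each end; the model's own
+            # trailing "\n" (the Enter keystroke) survives. (Illustrative.)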
try:
- parsed_value = ast.literal_eval(p_value)
+ parsed_value = ast.literal_eval(p_raw)
except (ValueError, SyntaxError):
- parsed_value = p_value
+ parsed_value = p_raw
if p_name in self.argument_type:
try:
parsed_value = self.argument_type[p_name](parsed_value)
diff --git a/xtuner/v1/ray/environment/lagent/tokenize.py b/xtuner/v1/ray/environment/lagent/tokenize.py
index e8c6a1e6c8..865af2ad57 100644
--- a/xtuner/v1/ray/environment/lagent/tokenize.py
+++ b/xtuner/v1/ray/environment/lagent/tokenize.py
@@ -2,6 +2,9 @@
import re
from typing import Any, Dict, List
+import numpy as np
+import ray
+
from xtuner.v1.utils import get_logger
@@ -25,6 +28,7 @@ def tokenize(
         thinking_start_ids = tokenizer.encode("<think>", add_special_tokens=False)
         thinking_end_ids = tokenizer.encode("</think>", add_special_tokens=False)
routed_experts = None
+ previous_routed_experts_tasks = set()
def get_content_index(content_ids) -> int:
content_ids_str = " ".join([str(content_id) for content_id in content_ids])
@@ -81,7 +85,33 @@ def split_conversation(messages: List[Dict[str, Any]]) -> List[List[Dict[str, An
and "routed_experts" in msg[0]["extra_info"]
and msg[0]["extra_info"]["routed_experts"] is not None
):
- routed_experts = msg[0]["extra_info"]["routed_experts"]
+ routed_experts_ref = msg[0]["extra_info"]["routed_experts"]
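+                    # Dedup-key summary for the carrier shapes handled below
+                    # (sketch): np.ndarray -> id(array) (inline path);
+                    # ray.ObjectRef -> ref.hex(); str -> the str itself
+                    # (store uuid key or legacy base64(cloudpickle) blob).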
+ if isinstance(routed_experts_ref, np.ndarray):
+ # Inline path: numpy array carried directly. Same array object
+ # across turns → same id(); dedup on id avoids re-submitting the
+ # identical history on rollout retries.
+ dedup_key = id(routed_experts_ref)
+ passthrough: Any = routed_experts_ref
+ elif isinstance(routed_experts_ref, ray.ObjectRef):
+ dedup_key = routed_experts_ref.hex()
+ passthrough = routed_experts_ref
+ elif isinstance(routed_experts_ref, str):
+                        # Legacy paths: either a uuid key into RoutedExpertStore
+                        # or a base64(cloudpickle(ObjectRef)) blob; both arrive
+                        # as str. Forward unchanged and let the rollout worker
+                        # tell the two shapes apart via isinstance checks
+                        # downstream.
+ dedup_key = routed_experts_ref
+ passthrough = routed_experts_ref
+ else:
+ raise TypeError(f"Unexpected type for routed_experts_ref: {type(routed_experts_ref)}")
+ if dedup_key in previous_routed_experts_tasks:
+ logger.warning(
+                            "[tokenize_fn] Repeated routed_experts_ref detected; setting routed_experts to None instead of re-submitting the same history."
+ )
+ routed_experts = None
+ else:
+ routed_experts = passthrough
+ previous_routed_experts_tasks.add(dedup_key)
else:
routed_experts = None
diff --git a/xtuner/v1/ray/environment/rl_task/hooks.py b/xtuner/v1/ray/environment/rl_task/hooks.py
index 5f2d2ca8a5..14ff30fc19 100644
--- a/xtuner/v1/ray/environment/rl_task/hooks.py
+++ b/xtuner/v1/ray/environment/rl_task/hooks.py
@@ -39,6 +39,7 @@
walk_files,
)
from xtuner.v1.ray.environment.rl_task.schemas import AgentSpec, CriterionScore, JudgerResult
+from xtuner.v1.utils import get_logger
# ─────────────────────────────────────────────────────────────────
@@ -52,7 +53,7 @@ class PickAgent(Hook):
Selection is deterministic on ``ctx["uid"]["root_id"]`` so the same
rollout always picks the same agent. Also stores ``template_root`` in
``ctx["agent_template_root"]`` so downstream hooks
- (:class:`UploadChosenAgent`, :class:`WriteAgentConfig`,
+ (:class:`UploadChosenAgent`, :class:`UploadAgentConfigSource`,
:class:`RunAgentInstallDeps`) know where the agent's files live on
the host.
"""
@@ -204,30 +205,33 @@ async def __call__(self, client: Any, ctx: dict[str, Any]) -> None:
# ─────────────────────────────────────────────────────────────────
-# Agent config: exec config.py on host, upload as JSON
+# Agent config: upload config.py source (daemon execs it in-sandbox)
# ─────────────────────────────────────────────────────────────────
-class WriteAgentConfig(Hook):
- """Exec the chosen agent's ``config.py`` on the host; upload the resulting
- ``agent_config`` dict as JSON to ``dst``.
+class UploadAgentConfigSource(Hook):
+ """Upload the chosen agent's ``config.py`` source file to ``dst``.
+
+    The lagent daemon execs this file in the sandbox to build the agent
+ dict — so ``os.environ`` lookups inside ``config.py`` resolve against
+ the sandbox's own env (populated by :class:`BenchEnv`), not the host's.
Agent template lives at ``ctx["agent_template_root"] / chosen.name /``
(populated by :class:`PickAgent`).
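+
+    A template ``config.py`` might look like this (illustrative sketch; the
+    keys are hypothetical, the point is that ``os.environ`` resolves at
+    exec time, inside the sandbox)::
+
+        agent_config = {
+            "model": os.environ["RL_LLM_MODEL"],
+            "base_url": os.environ["RL_LLM_BASE_URL"],
+            "api_key": os.environ["RL_LLM_API_KEY"],
+        }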
"""
- name = "write_agent_config"
+ name = "upload_agent_config_source"
- def __init__(self, dst: str = "/tmp/agent_config.json"):
+ def __init__(self, dst: str = "/tmp/agent_config.py"):
self.dst = dst
async def __call__(self, client: Any, ctx: dict[str, Any]) -> None:
chosen: AgentSpec = ctx["chosen_agent"]
template_root: Path = ctx["agent_template_root"]
cfg_path = template_root / chosen.name / chosen.config
- cfg = _exec_python_ns(cfg_path, "agent_config")
- blob = json.dumps(cfg, ensure_ascii=False).encode()
- await http_upload(client, self.dst, base64.b64encode(blob).decode())
+ if not cfg_path.is_file():
+ raise FileNotFoundError(f"agent config {cfg_path!r} not found")
+ await upload_tar_and_extract(client, {self.dst: cfg_path}, "/")
# ─────────────────────────────────────────────────────────────────
@@ -316,15 +320,86 @@ def __call__(self, ctx: dict[str, Any]) -> dict[str, str]:
"TASK_WORKSPACE": self.workspace,
"TASK_INSTRUCTION": f"{self.workspace}/{data.instruction}",
}
- if runtime.get("llm_base_url"):
- env["RL_LLM_BASE_URL"] = runtime["llm_base_url"]
- if runtime.get("llm_api_key"):
- env["RL_LLM_API_KEY"] = runtime["llm_api_key"]
+ for env_key, runtime_key in (
+ ("RL_LLM_MODEL", "llm_model"),
+ ("RL_LLM_BASE_URL", "llm_base_url"),
+ ("RL_LLM_API_KEY", "llm_api_key"),
+ ):
+ val = runtime.get(runtime_key)
+ if val:
+ env[env_key] = val
env.update(self.extras)
ctx["env_vars_for_instruction"] = env
return env
+# ─────────────────────────────────────────────────────────────────
+# Daemon log retrieval (post-hook)
+# ─────────────────────────────────────────────────────────────────
+
+
+class DumpDaemonLogOnFailure(Hook):
+ """Post-hook: pull ``/tmp/agent_daemon.log`` and log its tail on failure.
+
+ Two triggers:
+ - Stage's entry returned non-zero (``rc != 0``) — usual sandbox error.
+ - Silent-pass: ``rc == 0`` but the pulled ``message_key`` contents show
+ the last ``policy_agent.messages`` entry lacks ``raw_content_ids``
+ (LLM call somehow produced no token ids — typically an exception
+ swallowed by the agent layer). Disable by passing ``message_key=None``.
+
+ Always stores the full daemon log at ``ctx["pulled"][key]`` for
+ downstream consumers regardless of whether we log.
+ """
+
+ name = "dump_daemon_log_on_failure"
+
+ def __init__(
+ self,
+ path: str = "/tmp/agent_daemon.log",
+ *,
+ tail_lines: int = 500,
+ key: str = "daemon_log",
+ message_key: str | None = "message",
+ ):
+ self.path = path
+ self.tail_lines = tail_lines
+ self.key = key
+ self.message_key = message_key
+
+ async def __call__(self, client: Any, ctx: dict[str, Any]) -> None:
+ try:
+ blob = await client.download_file(self.path)
+ except Exception as exc:
+ get_logger().warning(f"could not download daemon log at {self.path}: {exc}")
+ return
+ text = blob.decode(errors="replace")
+ ctx.setdefault("pulled", {})[self.key] = text
+
+ result = ctx.get("result")
+ rc = getattr(result, "return_code", 0) if result else 0
+
+ should_dump = rc != 0
+ reason = f"rc={rc}"
+ if not should_dump and self.message_key:
+ raw = (ctx.get("pulled") or {}).get(self.message_key) or ""
+ try:
+ msgs = json.loads(raw).get("policy_agent.messages", []) if raw else []
+ except Exception:
+ msgs = []
+ last = msgs[-1] if msgs else {}
+ required = ("raw_content", "raw_content_ids", "raw_content_logprobs")
+ missing = [k for k in required if not last.get(k)]
+ if missing:
+ should_dump = True
+ reason = f"silent-pass (last message missing {missing})"
+
+ if should_dump:
+ lines = text.splitlines()
+ tail = "\n".join(lines[-self.tail_lines :]) if len(lines) > self.tail_lines else text
+ get_logger().error(f"daemon log tail [{reason}] ({self.path}):\n{tail}")
+
+
# ─────────────────────────────────────────────────────────────────
# Judger result parsing
# ─────────────────────────────────────────────────────────────────
@@ -353,14 +428,6 @@ async def __call__(self, client: Any, ctx: dict[str, Any]) -> None:
# ─────────────────────────────────────────────────────────────────
-def _exec_python_ns(path: Path, expected_name: str) -> Any:
- ns: dict[str, Any] = {}
- exec(compile(path.read_text(encoding="utf-8"), str(path), "exec"), ns)
- if expected_name not in ns:
- raise KeyError(f"{expected_name!r} not defined in {path}")
- return ns[expected_name]
-
-
def _parse_stage_stdout(name: str, result: StageResult) -> JudgerResult:
if result.return_code != 0:
return JudgerResult(
diff --git a/xtuner/v1/ray/environment/rl_task/runner.py b/xtuner/v1/ray/environment/rl_task/runner.py
index 5c5075ae26..6e62c9351e 100644
--- a/xtuner/v1/ray/environment/rl_task/runner.py
+++ b/xtuner/v1/ray/environment/rl_task/runner.py
@@ -68,6 +68,7 @@ async def run_single(
*,
provider: Any,
lagent_src_dir: str | Path | None = None,
+ llm_model: str | None = None,
llm_base_url: str | None = None,
llm_api_key: str | None = None,
) -> dict[str, Any]:
@@ -81,10 +82,12 @@ async def run_single(
"uid": uid,
"runtime": {
"lagent_src_dir": lagent_src_dir,
+ "llm_model": llm_model,
"llm_base_url": llm_base_url,
"llm_api_key": llm_api_key,
},
"workspace": self.infer.sandbox.workspace_path,
+ "sandbox_image": self.infer.sandbox.image,
}
client = None
@@ -104,7 +107,6 @@ async def run_single(
infer_result = await self.infer.run(client, ctx)
get_logger().info(f"[{tid}] infer: done rc={infer_result.return_code} ({time.monotonic() - t1:.1f}s)")
if not infer_result.ok:
- await _dump_daemon_log(client)
return _mark_failed(
data,
uid,
@@ -157,14 +159,6 @@ def _infer_metadata(ctx: dict[str, Any]) -> dict[str, Any]:
return md
-async def _dump_daemon_log(client) -> None:
- try:
- data = await client.download_file("/tmp/agent_daemon.log")
- get_logger().error(f"agent daemon log tail:\n{data.decode(errors='replace')[-4000:]}")
- except Exception as exc:
- get_logger().warning(f"could not download daemon log: {exc}")
-
-
_ACQUIRE_MAX_ATTEMPTS = 3
# Sandbox cold-start can take 30-60s under load; pathological boots run
# longer. With async waits this budget is cheap (no thread tied up), so
@@ -264,7 +258,10 @@ def _mark_completed(
"step_rewards": [sr.model_dump() for sr in judge.step_rewards],
"failed": judge.failed,
},
- "agent": {"message_dict": json.loads(infer.pulled.get("message", "{}"))},
+ "agent": {
+ "message_dict": json.loads(infer.pulled.get("message", "{}")),
+ "daemon_log": infer.pulled.get("daemon_log", ""),
+ },
},
}
@@ -305,8 +302,8 @@ def _mark_failed(
DEFAULT_GATEWAY = "http://env-gateway.ailab.ailab.ai"
-# DEFAULT_LAGENT_SRC = "/mnt/shared-storage-user/llmit/user/liukuikun/workspace/lagent"
-DEFAULT_LAGENT_SRC = "/mnt/shared-storage-user/llmit/user/wangziyi/projs/lagent"
+DEFAULT_LAGENT_SRC = "/mnt/shared-storage-user/llmit/user/liukuikun/workspace/lagent"
+# DEFAULT_LAGENT_SRC = "/mnt/shared-storage-user/llmit/user/wangziyi/projs/lagent"
def _load_dataset_from_config(config_path: Path) -> Any:
@@ -338,6 +335,7 @@ async def _run_one(
*,
uid: dict[str, int],
lagent_src_dir: str | None,
+ llm_model: str | None,
llm_base_url: str | None,
llm_api_key: str | None,
) -> dict[str, Any]:
@@ -350,6 +348,7 @@ async def _run_one(
uid,
provider=provider,
lagent_src_dir=lagent_src_dir,
+ llm_model=llm_model,
llm_base_url=llm_base_url,
llm_api_key=llm_api_key,
)
@@ -406,6 +405,7 @@ async def _guarded(job_idx: int, td: Path, uid: dict[str, int]) -> dict[str, Any
provider,
uid=uid,
lagent_src_dir=lagent_src,
+ llm_model=args.llm_model,
llm_base_url=args.llm_base_url,
llm_api_key=args.llm_api_key,
)
@@ -586,6 +586,7 @@ def main() -> int:
default=DEFAULT_LAGENT_SRC,
help="Local path to lagent source. Pass '' to skip upload.",
)
+ parser.add_argument("--llm-model", default=None)
parser.add_argument("--llm-base-url", default=None)
parser.add_argument("--llm-api-key", default=None)
parser.add_argument(
diff --git a/xtuner/v1/ray/environment/rl_task/sandbox.py b/xtuner/v1/ray/environment/rl_task/sandbox.py
index f8a04f781f..6212eec601 100644
--- a/xtuner/v1/ray/environment/rl_task/sandbox.py
+++ b/xtuner/v1/ray/environment/rl_task/sandbox.py
@@ -376,8 +376,12 @@ async def exec_in(
"""
if env:
command = _expand_vars(command, env)
- prefix = " ".join(f'{k}="{v}"' for k, v in env.items())
- command = f"{prefix} {command}"
+ # Use `export` so vars carry across chained commands (`bash A && bash B`).
+ # Inline `VAR=val cmd1 && cmd2` scopes VAR to cmd1 only, which bites
+ # when entry runs pre_entry.sh && lagent_entry.sh — daemon subprocess
+ # wouldn't see RL_LLM_MODEL etc.
+ exports = "; ".join(f'export {k}="{v}"' for k, v in env.items())
+ command = f"{exports}; {command}"
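+            # Illustrative expansion for env={"RL_LLM_MODEL": "m"} and
+            # command='bash pre_entry.sh && bash lagent_entry.sh':
+            #   inline : RL_LLM_MODEL="m" bash pre_entry.sh && bash lagent_entry.sh
+            #            (lagent_entry.sh never sees the var)
+            #   export : export RL_LLM_MODEL="m"; bash pre_entry.sh && bash lagent_entry.sh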
result = await client.execute(command, cwd, timeout_sec)
rc = _result_code(result)
if raise_on_error and rc != 0:
@@ -487,17 +491,27 @@ def _result_code(exec_res: dict[str, Any]) -> int:
return int(rc)
+_HOOK_STUCK_WARN_SEC = 30.0
+
+
async def _run_hook(hook: Hook, client: Any, ctx: dict[str, Any], *, phase: str) -> None:
"""Run one hook; on failure, log + stash + re-raise with a label that names
the hook class so the traceback says which one blew up."""
name = getattr(hook, "name", None) or type(hook).__name__
- label = f"{phase}-hook {type(hook).__name__}({name!r})"
tid = (ctx.get("data") and getattr(ctx["data"], "id", None)) or "?"
- get_logger().debug(f"[{tid}] {label} start")
+ image = ctx.get("sandbox_image") or "?"
+ label = f"{phase}-hook {type(hook).__name__}({name!r}) image={image}"
+ get_logger().info(f"[{tid}] {label} start")
t0 = time.monotonic()
+ hook_task = asyncio.create_task(hook(client, ctx))
try:
- await hook(client, ctx)
- get_logger().debug(f"[{tid}] {label} done ({time.monotonic() - t0:.2f}s)")
+ while True:
+ done, _ = await asyncio.wait([hook_task], timeout=_HOOK_STUCK_WARN_SEC)
+ if done:
+ break
+ get_logger().warning(f"[{tid}] {label} still running ({time.monotonic() - t0:.0f}s)")
+ await hook_task # re-raise if hook failed
+ get_logger().info(f"[{tid}] {label} done ({time.monotonic() - t0:.2f}s)")
except Exception as exc:
import traceback as _tb
diff --git a/xtuner/v1/ray/evaluator.py b/xtuner/v1/ray/evaluator.py
index 90215bc5ef..3ab169b862 100644
--- a/xtuner/v1/ray/evaluator.py
+++ b/xtuner/v1/ray/evaluator.py
@@ -203,7 +203,29 @@ async def eval_worker_task(self, sample: RLDataFlowItem):
group_sample = await self.env_controller.run.remote(
[sample], sample_params=self.sample_params, extra_params=extra_params
) # type: ignore[attr-defined]
- self.return_list.append(group_sample[0])
+ if not group_sample:
+ # Rollout failed upstream (HTTP 500 / actor death / cancellation);
+ # skip instead of IndexError'ing the append path.
+ return
+ sample_out = group_sample[0]
+ # Avoid `return_list` accumulation. _save_trajectories only reads:
+ # env.agent.extra_info["messages"] (keep — needed for jsonl)
+ # env.agent.extra_info["daemon_log"] (keep — eval-only)
+ # env.rollout.response_ids / .response (keep)
+ # env.judger.reward (keep)
+ # Everything else in extra_info is trainer-side state (per-turn lagent
+ # memory / routed_experts bookkeeping) and at MB-per-sample scale.
+ # Popping here prevents a single eval round from pushing the actor's
+ # Python heap into the 100+GB range (observed: 318GB per actor) and
+ # triggering node-level OOM on the rollout head node.
+ if sample_out.env.agent.extra_info:
+ sample_out.env.agent.extra_info.pop("state", None)
+ sample_out.env.agent.extra_info.pop("agent", None)
+ if sample_out.env.rollout.extra_info:
+ sample_out.env.rollout.extra_info.pop("agent_state_dict", None)
+ sample_out.env.rollout.extra_info.pop("agent_message_dict", None)
+ sample_out.env.rollout.extra_info.pop("routed_experts", None)
+ self.return_list.append(sample_out)
async def concurrent_eval_task_runner(self):
"""Runs evaluation tasks concurrently to generate a batch of samples.
diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py
index b44513a901..1e9c3f2ee0 100644
--- a/xtuner/v1/ray/rollout/controller.py
+++ b/xtuner/v1/ray/rollout/controller.py
@@ -8,6 +8,7 @@
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4
+import numpy as np
import ray
import uvicorn
from fastapi import FastAPI
@@ -136,9 +137,7 @@ def __init__(
self.num_workers = 0
self.workers_info: Dict[str, WorkerInfo] = {} # url -> WorkerInfo
self.active_rollout_workers: List[RolloutWorker] = []
- tokenizer_path = infer_config.tokenizer_path
- assert tokenizer_path is not None, "tokenizer_path must be set before creating RolloutController"
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
+ self.tokenizer = AutoTokenizer.from_pretrained(infer_config.tokenizer_path, trust_remote_code=True)
self.workers, self.rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group(
self._get_worker_cls(), infer_config, placement_group
)
@@ -171,7 +170,7 @@ def __init__(
Qwen3TokenReasonParser,
)
- self.reasoning_parser = Qwen3TokenReasonParser(tokenizer_path)
+ self.reasoning_parser = Qwen3TokenReasonParser(infer_config.tokenizer_path)
self.tool_call_parser = Qwen3_5FunctionCallParser()
def _get_worker_status_for_router(self) -> Dict[RolloutWorker, bool]:
@@ -486,6 +485,10 @@ def start_api_server(self, host: str = "0.0.0.0", port: int = 8000):
@app.post("/v1/chat/completions")
async def chat_completions(request: LagentChatCompletionRequest):
+ import base64
+
+ from ray import cloudpickle
+
inputs = tokenize(self.tokenizer, request.messages, request.tools)
response: RLRolloutResponseItem = await self.rollout(
prompt=request.messages,
@@ -498,11 +501,28 @@ async def chat_completions(request: LagentChatCompletionRequest):
{"routed_experts": inputs["routed_experts"]} if inputs["routed_experts"] is not None else {}
),
)
+ # HTTP boundary needs JSON-serializable response. In-cluster flow
+ # uses inline numpy (zero-copy via Ray RPC) but FastAPI's
+ # jsonable_encoder can't handle ndarray. Round-trip via xtuner's
+ # RoutedExpertStore: register the numpy, carry the str key over HTTP,
+ # and the client's next-turn request returns it through worker.py's
+ # `isinstance(history, str)` branch.
+ re_val = response.extra_info.get("routed_experts")
+ if isinstance(re_val, np.ndarray):
+ from xtuner.v1.ray.rollout.routed_expert_store import get_store
+
+ store = get_store()
+ local_ref = ray.put(re_val)
+ key = await store.put_ref.remote([local_ref])
+ response.extra_info["routed_experts"] = key
+ elif isinstance(re_val, ray.ObjectRef):
+ # Legacy path kept as a defensive fallback — encode the same
+ # way as before so older clients still decode correctly.
+ response.extra_info["routed_experts"] = base64.b64encode(cloudpickle.dumps(re_val)).decode("utf-8")
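+        # Wire sketch: turn N's response carries extra_info["routed_experts"]
+        # = "<uuid key>"; the client echoes it back in turn N+1's request and
+        # the rollout worker's `isinstance(history, str)` branch redeems the
+        # key from the store.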
message = AgentMessage.from_model_response(response, "assistant")
message = self.reasoning_parser.parse_response(message)
message = self.tool_call_parser.parse_response(message)
completion_message = LagentChatCompletionMessage.from_agent_message(message)
- completion_tokens = response.num_return_tokens or 0
return LagentChatCompletion(
model=request.model,
choices=[
@@ -516,8 +536,8 @@ async def chat_completions(request: LagentChatCompletionRequest):
],
usage=CompletionUsage(
prompt_tokens=len(inputs["input_ids"]),
- completion_tokens=completion_tokens,
- total_tokens=len(inputs["input_ids"]) + completion_tokens,
+ completion_tokens=response.num_return_tokens,
+ total_tokens=len(inputs["input_ids"]) + response.num_return_tokens,
),
).model_dump()
diff --git a/xtuner/v1/ray/rollout/lmdeploy.py b/xtuner/v1/ray/rollout/lmdeploy.py
index f4d3cab87b..c556a09a12 100644
--- a/xtuner/v1/ray/rollout/lmdeploy.py
+++ b/xtuner/v1/ray/rollout/lmdeploy.py
@@ -10,7 +10,6 @@
from ray.util.placement_group import placement_group_table
from transformers import AutoTokenizer
-from xtuner.v1.data_proto.rl_data import RolloutExtraInfo
from xtuner.v1.ray.config import RolloutConfig
from .worker import RolloutWorker
@@ -18,27 +17,6 @@
SHARED_STORE = "shared_store"
SHARED_STORE_NAMESPACE = "lmdeploy"
-_LMDEPLOY_ACTOR = None
-
-
-def get_lmdeploy_routed_experts_ref(routed_experts: Any):
- global _LMDEPLOY_ACTOR
- if isinstance(routed_experts, str):
- if _LMDEPLOY_ACTOR is None:
- _LMDEPLOY_ACTOR = ray.get_actor(SHARED_STORE, namespace=SHARED_STORE_NAMESPACE)
- return _LMDEPLOY_ACTOR.get.remote(routed_experts)
- if isinstance(routed_experts, ray.ObjectRef):
- return routed_experts
- return ray.put(torch.as_tensor(routed_experts))
-
-
-def put_lmdeploy_routed_experts_ref(routed_experts: Any, return_key: bool):
- global _LMDEPLOY_ACTOR
- if return_key:
- if _LMDEPLOY_ACTOR is None:
- _LMDEPLOY_ACTOR = ray.get_actor(SHARED_STORE, namespace=SHARED_STORE_NAMESPACE)
- return ray.get(_LMDEPLOY_ACTOR.put.remote(routed_experts))
- return ray.put(routed_experts)
def run_lmdeploy_server_wrapper(lmdeploy_config_namespace: Namespace):
@@ -101,7 +79,7 @@ def __init__(
self.api_keys = self.config.api_key
self.model_name = self.config.model_name
self.enable_return_routed_experts = self.config.enable_return_routed_experts
- self.return_routed_experts_key = self.config.return_routed_experts_key
+ self.lmdeploy_actor = None
async def _create_request(
self,
@@ -236,42 +214,14 @@ def reset_prefix_cache(self):
"""It will implemented for LMDeploy worker in the future."""
pass
- async def _handle_routed_experts_response(
- self,
- root_id: Any,
- action_id: Any,
- response: dict,
- input_extra_info: RolloutExtraInfo,
- extra_info: RolloutExtraInfo,
- finish_reason: str,
- ) -> None:
- exist_history_routed_experts = (
- "routed_experts" in input_extra_info and input_extra_info["routed_experts"] is not None
- )
- routed_experts = response["meta_info"].pop("routed_experts") # token[layer[expert]]
- if routed_experts is not None and not exist_history_routed_experts:
- if isinstance(routed_experts, str) and self.return_routed_experts_key:
- extra_info["routed_experts"] = routed_experts
- else:
- extra_info["routed_experts"] = get_lmdeploy_routed_experts_ref(routed_experts)
- elif routed_experts is not None and exist_history_routed_experts:
- cur_routed_experts_ref = get_lmdeploy_routed_experts_ref(routed_experts)
- history_routed_experts_ref = get_lmdeploy_routed_experts_ref(input_extra_info["routed_experts"])
- cur_routed_experts = await cur_routed_experts_ref # n, layer, expert
- history_routed_experts = await history_routed_experts_ref # n, layer, expert
- ray.internal.free([cur_routed_experts_ref, history_routed_experts_ref], local_only=False)
- concat_routed_experts = self._concat_partial_routed_experts(
- root_id, action_id, response, history_routed_experts, cur_routed_experts
- )
- extra_info["routed_experts"] = put_lmdeploy_routed_experts_ref(
- concat_routed_experts, self.return_routed_experts_key
- )
- del history_routed_experts
- del cur_routed_experts
- else:
- assert finish_reason == "abort", (
- f"routed_experts is None, but finish_reason is {finish_reason}, expected abort. response: {response}"
- )
+ def _decode_routed_experts(self, routed_experts: Any):
+ if isinstance(routed_experts, str):
+ if self.lmdeploy_actor is None:
+ self.lmdeploy_actor = ray.get_actor(SHARED_STORE, namespace=SHARED_STORE_NAMESPACE)
+ assert self.lmdeploy_actor is not None, "LMDeploy actor should be available in the shared store."
+ routed_experts_ref = self.lmdeploy_actor.get.remote(routed_experts)
+ return routed_experts_ref
+ return torch.tensor(routed_experts)
def _transform_rollout_config_to_server_configs(self) -> Namespace:
"""Transform the RolloutConfig into a Namespace suitable for the
@@ -323,8 +273,6 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace:
extra_engine_config: Dict[str, Any] = {}
if backend == "pytorch" and self.config.enable_return_routed_experts:
extra_engine_config["enable_return_routed_experts"] = True
- if self.config.return_routed_experts_key:
- extra_engine_config["enable_transfer_obj_ref"] = True
if backend == "pytorch" and self.config.router_n_groups:
hf_overrides = extra_engine_config.setdefault("hf_overrides", {})
hf_overrides.update(router_n_groups=self.config.router_n_groups)
diff --git a/xtuner/v1/ray/rollout/routed_expert_store.py b/xtuner/v1/ray/rollout/routed_expert_store.py
new file mode 100644
index 0000000000..0d1f3169d4
--- /dev/null
+++ b/xtuner/v1/ray/rollout/routed_expert_store.py
@@ -0,0 +1,198 @@
+"""Named Ray actor that owns all ``routed_experts`` ObjectRefs.
+
+Motivation:
+ lmdeploy workers produce ``routed_experts`` per rollout and encode them as
+ Ray ``ObjectRef`` values. Those refs travel through HTTP/JSON (the agent
+ sandbox roundtrip) as cloudpickled hex strings, spending minutes to hours
+ outside any Python process. Ray's distributed refcount + cloudpickle
+ out-of-band tracking is fragile at that timescale — under heartbeat flakes
+ or owner restart the object can be evicted, and training later hits
+ ``ObjectFetchTimedOutError`` with "no locations were found".
+
+ Routing the refs through a dedicated long-lived store decouples their
+ lifetime from lmdeploy's lifecycle. Transport between actors/processes
+ becomes a short uuid string (``put_tensor`` → ``get_ref`` → ``release``)
+ instead of a cloudpickled ObjectRef blob, so the failure mode is replaced
+ by a plain dict lookup that either hits or misses.
+"""
+
+from __future__ import annotations
+
+import os
+import time
+import uuid
+from typing import Any
+
+import ray
+from ray import ObjectRef
+
+from xtuner.v1.utils import get_logger
+
+
+_STORE_NAME = "routed_expert_store"
+# Generous default — covers long rollouts that may not be consumed immediately
+# (e.g. eval groups, aborted tasks awaiting GC, etc.). Overridable via env var
+# for quick tuning without re-editing the source.
+_DEFAULT_TTL_SEC = int(os.environ.get("XTUNER_ROUTED_EXPERT_TTL_SEC", 24 * 3600))
+_DEFAULT_GC_INTERVAL_SEC = int(os.environ.get("XTUNER_ROUTED_EXPERT_GC_INTERVAL_SEC", 300))
+
+
+@ray.remote(num_cpus=0)
+class RoutedExpertStore:
+ """Dedicated store actor for routed_experts ObjectRefs.
+
+ Every ref in ``self._store`` has the store as its Ray owner (via
+ ``ray.put``), so its lifetime is bounded purely by the store's Python
+ dict — no distributed heartbeats or cloudpickle TTL involved.
+
+ GC runs inline on each ``put_*`` call: if the last sweep happened more
+ than ``gc_interval_sec`` ago, stale keys older than ``ttl_sec`` are
+ evicted. Keeps the store honest without a background task.
+ """
+
+ def __init__(self, ttl_sec: int = _DEFAULT_TTL_SEC, gc_interval_sec: int = _DEFAULT_GC_INTERVAL_SEC):
+ self._store: dict[str, tuple[ObjectRef, float]] = {}
+ self._ttl = ttl_sec
+ self._gc_interval = gc_interval_sec
+ self._last_gc = time.monotonic()
+ # Diagnostic counters — help locate double-consume / leak symptoms
+ # after the fact without changing runtime behaviour.
+ self._n_put = 0
+ self._n_get = 0
+ self._n_release = 0
+ self._n_missing_get = 0
+ self._n_missing_release = 0
+
+ def put_tensor(self, tensor: Any) -> str:
+ """Take ownership of a tensor via ray.put; return a uuid key.
+
+ NOTE: this places the tensor in the STORE actor's local plasma —
+ under heavy traffic this can cause one-node plasma saturation (all
+ rollouts across the cluster funnel to whichever node this actor
+ runs on). Prefer ``put_ref`` + worker-side ``ray.put`` for
+ multi-node scale.
+ """
+ self._maybe_gc()
+ ref = ray.put(tensor)
+ self._n_put += 1
+ return self._stash(ref)
+
+ def put_ref(self, wrapped: list) -> str:
+ """Register an externally-owned ObjectRef, return a uuid key.
+
+ The caller must wrap the ref in a single-element list. Ray
+ auto-dereferences bare ObjectRef args; the list wrapper bypasses
+ that so the store receives the ref itself (not the materialized
+ tensor).
+
+ Ownership stays with the original ``ray.put`` caller (typically
+ the rollout worker); data lives in that caller's node plasma.
+ The store just holds a Python-level strong reference so Ray's
+ distributed refcount keeps the object alive across consumers.
+
+ This distributes plasma pressure across all rollout nodes instead
+ of funneling it to the store's node. Trade-off: if the owner
+ worker dies, the ref's object is lost; with ``put_tensor`` only
+ the store dying would lose data.
+ """
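+        # Caller-side sketch; the single-element list is what keeps Ray from
+        # auto-dereferencing the argument:
+        #   ref = ray.put(arr)                          # owner: the caller
+        #   key = ray.get(store.put_ref.remote([ref]))  # store receives the ref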
+ self._maybe_gc()
+ if not (isinstance(wrapped, list) and len(wrapped) == 1):
+ raise TypeError(f"put_ref expects [ObjectRef] to bypass auto-deref, got {type(wrapped)}")
+ ref = wrapped[0]
+ if not isinstance(ref, ObjectRef):
+ raise TypeError(f"put_ref expects [ObjectRef], got [{type(ref)}]")
+ self._n_put += 1
+ return self._stash(ref)
+
+ def get_ref(self, key: str) -> ObjectRef:
+ """Return the stashed ObjectRef; caller runs ``ray.get`` to
+ materialize."""
+ entry = self._store.get(key)
+ if entry is None:
+ self._n_missing_get += 1
+ raise KeyError(f"RoutedExpertStore: key not found: {key}")
+ self._store[key] = (entry[0], time.monotonic())
+ self._n_get += 1
+ return entry[0]
+
+ def release(self, key: str) -> None:
+ if self._store.pop(key, None) is None:
+ self._n_missing_release += 1
+ else:
+ self._n_release += 1
+
+ def release_many(self, keys: list[str]) -> None:
+ for k in keys:
+ self.release(k)
+
+ def stats(self) -> dict:
+ return {
+ "live": len(self._store),
+ "ttl_sec": self._ttl,
+ "n_put": self._n_put,
+ "n_get": self._n_get,
+ "n_release": self._n_release,
+ "n_missing_get": self._n_missing_get,
+ "n_missing_release": self._n_missing_release,
+ }
+
+ def _stash(self, ref: ObjectRef) -> str:
+ key = uuid.uuid4().hex
+ self._store[key] = (ref, time.monotonic())
+ return key
+
+ def _maybe_gc(self) -> None:
+ now = time.monotonic()
+ if now - self._last_gc < self._gc_interval:
+ return
+ self._last_gc = now
+ stale = [k for k, (_, t) in self._store.items() if now - t > self._ttl]
+ for k in stale:
+ self._store.pop(k, None)
+ if stale:
+ get_logger().warning(
+ f"RoutedExpertStore GC: evicted {len(stale)} stale keys (ttl={self._ttl}s, remaining={len(self._store)})"
+ )
+
+
+_handle_cache: Any = None
+
+
+def get_store():
+ """Process-local cached handle to the singleton store actor.
+
+ First tries ``ray.get_actor`` (fast path if another caller created it
+ already). On NotFound, tries to create; if creation races with another
+ caller (raises because name is taken), falls back to another lookup.
+ """
+ global _handle_cache
+ if _handle_cache is not None:
+ return _handle_cache
+
+ # Fast path: already exists.
+ try:
+ _handle_cache = ray.get_actor(_STORE_NAME)
+ return _handle_cache
+ except ValueError:
+ pass
+
+ # Slow path: try to create. Race with concurrent callers is handled by
+ # catching the "name already taken" case and re-looking-up.
+ import time as _time
+
+ for attempt in range(10):
+ try:
+ _handle_cache = RoutedExpertStore.options(name=_STORE_NAME).remote()
+ return _handle_cache
+ except ValueError as exc:
+ # Either "name already taken" (someone else won) or transient
+ # issue. Try to look up; if that also fails, back off.
+ try:
+ _handle_cache = ray.get_actor(_STORE_NAME)
+ return _handle_cache
+ except ValueError:
+ get_logger().debug(f"RoutedExpertStore bootstrap retry {attempt}: {exc}")
+ _time.sleep(0.2 * (attempt + 1))
+ continue
+
+ raise RuntimeError(f"RoutedExpertStore: failed to acquire named actor {_STORE_NAME!r} after retries")
diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py
index 4f9e3e17c1..833b374697 100644
--- a/xtuner/v1/ray/rollout/worker.py
+++ b/xtuner/v1/ray/rollout/worker.py
@@ -14,11 +14,11 @@
import ray
import requests # type: ignore[import-untyped]
from packaging.version import Version
-from ray import ObjectRef
+from ray import ObjectRef, cloudpickle
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoTokenizer
-from xtuner.v1.data_proto.rl_data import RLRolloutResponseItem, RolloutExtraInfo, RolloutState
+from xtuner.v1.data_proto.rl_data import RLRolloutResponseItem, RolloutState
from xtuner.v1.ray import find_master_addr_and_port
from xtuner.v1.ray.base import AutoAcceleratorWorkers, SingleAcceleratorWorker
from xtuner.v1.ray.config import RolloutConfig
@@ -156,80 +156,6 @@ def init(self, dist_init_addr: str = ""):
def _decode_routed_experts(self, routed_experts: Any) -> Any:
return routed_experts
- def _concat_partial_routed_experts(
- self,
- root_id: Any,
- action_id: Any,
- response: dict,
- history_routed_experts: Any,
- cur_routed_experts: Any,
- ) -> Any:
- assert (history_routed_experts.shape[0] - 1) > 0 and history_routed_experts.shape[
- 0
- ] - 1 <= cur_routed_experts.shape[0], (
- f"Existing routed_experts shape: {history_routed_experts.shape}, current routed_experts shape: {cur_routed_experts.shape}"
- )
- init_cur_roued_experts = cur_routed_experts.shape[0]
- cur_routed_experts = cur_routed_experts[history_routed_experts.shape[0] :, :, :]
- concat_routed_experts = np.concatenate((history_routed_experts, cur_routed_experts), axis=0)
- prompt_tokens = response["meta_info"].get("prompt_tokens", 0)
- response_tokens = response["meta_info"].get("completion_tokens", 0)
- assert concat_routed_experts.shape[0] == prompt_tokens + response_tokens - 1, (
- f"Routed experts shape {concat_routed_experts.shape[0]} does not match total tokens {prompt_tokens + response_tokens - 1}"
- )
- self.logger.debug(
- f"[{root_id}/{action_id}] Partial Rollout Stats: "
- f"Tokens(prompt={prompt_tokens}, response={response_tokens}, total={prompt_tokens + response_tokens}) | "
- f"Experts(exist={history_routed_experts.shape}, init_cur={init_cur_roued_experts}, cur={cur_routed_experts.shape}, concat={concat_routed_experts.shape})"
- )
- return concat_routed_experts
-
- async def _handle_routed_experts_response(
- self,
- root_id: Any,
- action_id: Any,
- response: dict,
- input_extra_info: RolloutExtraInfo,
- extra_info: RolloutExtraInfo,
- finish_reason: str,
- ) -> None:
- exist_history_routed_experts = (
- "routed_experts" in input_extra_info and input_extra_info["routed_experts"] is not None
- )
- routed_experts = response["meta_info"].pop("routed_experts") # token[layer[expert]]
- if routed_experts is not None and not exist_history_routed_experts:
- routed_experts = self._decode_routed_experts(routed_experts)
- assert not isinstance(routed_experts, str), (
- "String routed_experts keys must be handled by the backend-specific rollout worker."
- )
- if not isinstance(routed_experts, ObjectRef):
- routed_experts = ray.put(routed_experts)
- extra_info["routed_experts"] = routed_experts
- elif routed_experts is not None and exist_history_routed_experts:
- routed_experts = self._decode_routed_experts(routed_experts)
- if isinstance(routed_experts, ObjectRef):
- cur_routed_experts = await routed_experts # n, layer, expert
- ray.internal.free(routed_experts, local_only=False)
- else:
- cur_routed_experts = routed_experts
-
- history_routed_experts_ref = input_extra_info["routed_experts"]
- assert isinstance(history_routed_experts_ref, ObjectRef), (
- "Base rollout worker expects history routed_experts to be a Ray ObjectRef."
- )
- history_routed_experts = await history_routed_experts_ref # n, layer, expert
- ray.internal.free(history_routed_experts_ref, local_only=False)
- concat_routed_experts = self._concat_partial_routed_experts(
- root_id, action_id, response, history_routed_experts, cur_routed_experts
- )
- extra_info["routed_experts"] = ray.put(concat_routed_experts)
- del history_routed_experts
- del cur_routed_experts
- else:
- assert finish_reason == "abort", (
- f"routed_experts is None, but finish_reason is {finish_reason}, expected abort. response: {response}"
- )
-
def set_engine_rank_mesh_array(self, engine_rank_mesh_array: list[list[int]]):
self.engine_rank_mesh_array = engine_rank_mesh_array
@@ -625,7 +551,7 @@ async def _handle_non_stream_response(
if "return_token_ids" in extra_params and extra_params["return_token_ids"]:
last_logprobs: list[float] = []
try:
- extra_info: RolloutExtraInfo = {}
+ extra_info = {}
finish_reason = response["meta_info"]["finish_reason"]["type"]
if finish_reason == "abort" and self.receive_abort_request.is_set() is False:
self.receive_abort_request.set()
@@ -648,14 +574,79 @@ async def _handle_non_stream_response(
assert "routed_experts" in response["meta_info"], (
"enable_return_routed_experts is True, but routed_experts is not in meta_info"
)
- await self._handle_routed_experts_response(
- root_id=root_id,
- action_id=action_id,
- response=response,
- input_extra_info=input_extra_info,
- extra_info=extra_info,
- finish_reason=finish_reason,
+ exist_history_routed_experts = (
+ "routed_experts" in input_extra_info and input_extra_info["routed_experts"] is not None
)
+ routed_experts = response["meta_info"].pop("routed_experts") # token[layer[expert]]
+ if routed_experts is not None and not exist_history_routed_experts:
+ # Turn 1: materialize tensor inline. lmdeploy's _SHARED_STORE.get is
+ # destructive (pop+ray.get) so awaiting `decoded` already releases
+ # lmdeploy's plasma copy; we just carry the numpy through extra_info.
+ decoded = self._decode_routed_experts(routed_experts)
+ if isinstance(decoded, ObjectRef):
+ tensor = await decoded
+ else:
+ tensor = decoded
+ extra_info["routed_experts"] = tensor
+ elif routed_experts is not None and exist_history_routed_experts:
+ # Turn 2+: concat history with new decoded chunk. History comes inline
+ # as np.ndarray (new path) or via legacy store/ref types (compat).
+ decoded = self._decode_routed_experts(routed_experts)
+ if isinstance(decoded, ObjectRef):
+ cur_routed_experts = await decoded
+ else:
+ cur_routed_experts = decoded
+
+ history = input_extra_info["routed_experts"]
+ if isinstance(history, np.ndarray):
+ # New path: inline numpy history.
+ history_routed_experts = history
+ elif isinstance(history, str):
+ # Legacy path: uuid key into RoutedExpertStore. Release after use
+ # (there's no retry path that re-consumes the same key in inline mode).
+ from xtuner.v1.ray.rollout.routed_expert_store import get_store as _legacy_get_store
+
+ legacy_store = _legacy_get_store()
+ history_ref = await legacy_store.get_ref.remote(history)
+ history_routed_experts = await history_ref
+ legacy_store.release.remote(history)
+ elif isinstance(history, ObjectRef):
+ history_routed_experts = await history
+ ray.internal.free([history], local_only=False)
+ elif isinstance(history, (bytes, bytearray)):
+ # Legacy path: controller.py serialized ref as base64(cloudpickle(...)).
+ history_ref = cloudpickle.loads(history)
+ history_routed_experts = await history_ref
+ ray.internal.free([history_ref], local_only=False)
+ else:
+ raise TypeError(f"Unexpected type for input_extra_info['routed_experts']: {type(history)}")
+ del input_extra_info
+
+                        assert (
+                            0 < history_routed_experts.shape[0] - 1 <= cur_routed_experts.shape[0]
+                        ), (
+                            f"Existing routed_experts shape: {history_routed_experts.shape}, "
+                            f"current routed_experts shape: {cur_routed_experts.shape}"
+                        )
+                        init_cur_routed_experts = cur_routed_experts.shape[0]
+ cur_routed_experts = cur_routed_experts[history_routed_experts.shape[0] :, :, :]
+ concat_routed_experts = np.concatenate((history_routed_experts, cur_routed_experts), axis=0)
+ prompt_tokens = response["meta_info"].get("prompt_tokens", 0)
+ response_tokens = response["meta_info"].get("completion_tokens", 0)
+ assert concat_routed_experts.shape[0] == prompt_tokens + response_tokens - 1, (
+ f"Routed experts shape {concat_routed_experts.shape[0]} does not match total tokens {prompt_tokens + response_tokens - 1}"
+ )
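+                        # Worked instance of the invariant above (illustrative
+                        # numbers): prompt_tokens=10, completion_tokens=5 gives
+                        # 10 + 5 - 1 = 14 expected rows; the -1 presumably
+                        # reflects the final sampled token, which is never
+                        # forwarded again.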
+ self.logger.debug(
+ f"[{root_id}/{action_id}] Partial Rollout Stats: "
+ f"Tokens(prompt={prompt_tokens}, response={response_tokens}, total={prompt_tokens + response_tokens}) | "
+                            f"Experts(exist={history_routed_experts.shape}, init_cur={init_cur_routed_experts}, cur={cur_routed_experts.shape}, concat={concat_routed_experts.shape})"
+ )
+ extra_info["routed_experts"] = concat_routed_experts
+ del history_routed_experts
+ del cur_routed_experts
+ else:
+ assert finish_reason == "abort", (
+ f"routed_experts is None, but finish_reason is {finish_reason}, expected abort. response: {response}"
+ )
# NOTE: When set return_token_ids = True, the response must contain valid token_ids/logprobs.
# If not, we consider it as an invalid response and retry it.
# NOTE: !!! When finish_reason is abort, some queries may not return token_ids or logprobs. !!!
diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
index a16cfc77e7..db7eba77d5 100644
--- a/xtuner/v1/rl/base/worker.py
+++ b/xtuner/v1/rl/base/worker.py
@@ -6,6 +6,7 @@
from pathlib import Path
from typing import Dict, Iterable, List, Sequence, TypeAlias, TypedDict, cast
+import numpy as np
import ray
import requests
import torch
@@ -13,6 +14,7 @@
import tqdm
from mmengine.runner import set_random_seed
from pydantic import BaseModel, ConfigDict
+from ray import cloudpickle
from ray.actor import ActorClass, ActorProxy
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor import DTensor
@@ -36,6 +38,7 @@
from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
from xtuner.v1.ray.base import SingleAcceleratorWorker
from xtuner.v1.ray.config import RolloutConfig
+from xtuner.v1.ray.rollout.routed_expert_store import get_store
from xtuner.v1.rl.base.loss import BaseRLLossContext
from xtuner.v1.train.trainer import LoadCheckpointConfig
from xtuner.v1.utils import (
@@ -402,7 +405,9 @@ def compute_ref_logprobs(
return ref_logprobs_list
def _add_rollout_routed_experts(
- self, seq_ctx: SequenceContext, rollout_routed_experts: torch.Tensor | list[torch.Tensor | ray.ObjectRef]
+ self,
+ seq_ctx: SequenceContext,
+ rollout_routed_experts: torch.Tensor | list[torch.Tensor | ray.ObjectRef | np.ndarray | str | bytes],
):
language_cfg = (
self.config.model_cfg.text_config
@@ -411,6 +416,12 @@ def _add_rollout_routed_experts(
)
to_free_routed_expert_refs: list[ray.ObjectRef] = []
+ # Store keys to release AFTER consumption. Training is 1:1 — each
+ # rollout's final key appears in this train step exactly once —
+ # so releasing here is safe (no double-consume from training side).
+ # SP-mesh: multiple ranks see the same data; only rank-0 releases,
+ # gated by dist.barrier() so other ranks finish their ray.get first.
+ to_release_pin_keys: list[str] = []
if isinstance(rollout_routed_experts, list):
# list[n,l,e]
out_rollout_routed_expert = []
@@ -426,8 +437,47 @@ def _add_rollout_routed_experts(
),
)
out_rollout_routed_expert.append(rollout_routed_experts_tensor)
+ elif isinstance(rollout_routed_expert, np.ndarray):
+ # Inline path: numpy array carried directly in extra_info.
+ rollout_routed_expert = torch.as_tensor(rollout_routed_expert, dtype=torch.long)
+ rollout_routed_expert = rollout_routed_expert.reshape(
+ -1, language_cfg.num_hidden_layers, language_cfg.num_experts_per_tok
+ )
+ out_rollout_routed_expert.append(rollout_routed_expert)
+ elif isinstance(rollout_routed_expert, ray.ObjectRef):
+ # Inline + replay_buffer wrap: ray.put(numpy) → ObjectRef. Owner is the
+ # replay_buffer actor (long-lived), so ray.get is safe.
+ rollout_routed_expert_ref = rollout_routed_expert
+ rollout_routed_expert = ray.get(rollout_routed_expert_ref)
+ if self.sp_mesh is None or self.sp_mesh.size() == 1:
+ ray.internal.free([rollout_routed_expert_ref], local_only=False)
+ elif self.sp_mesh.get_local_rank() == 0:
+ to_free_routed_expert_refs.append(rollout_routed_expert_ref)
+ rollout_routed_expert = torch.as_tensor(rollout_routed_expert, dtype=torch.long)
+ rollout_routed_expert = rollout_routed_expert.reshape(
+ -1, language_cfg.num_hidden_layers, language_cfg.num_experts_per_tok
+ )
+ out_rollout_routed_expert.append(rollout_routed_expert)
+ elif isinstance(rollout_routed_expert, str):
+ # Legacy path: uuid key → RoutedExpertStore. Owner is the
+ # store actor, so ray.get is reliable regardless of
+ # lmdeploy worker state.
+ store = get_store()
+ pin_key = rollout_routed_expert
+ pin_ref = ray.get(store.get_ref.remote(pin_key))
+ rollout_routed_expert = ray.get(pin_ref)
+ if self.sp_mesh is None or self.sp_mesh.size() == 1:
+ store.release.remote(pin_key)
+ elif self.sp_mesh.get_local_rank() == 0:
+ to_release_pin_keys.append(pin_key)
+ rollout_routed_expert = torch.as_tensor(rollout_routed_expert, dtype=torch.long)
+ rollout_routed_expert = rollout_routed_expert.reshape(
+ -1, language_cfg.num_hidden_layers, language_cfg.num_experts_per_tok
+ )
+ out_rollout_routed_expert.append(rollout_routed_expert)
else:
- rollout_routed_expert_refs = rollout_routed_expert
+ # Legacy path: bytes (base64-decoded cloudpickle of ObjectRef).
+ rollout_routed_expert_refs = cloudpickle.loads(rollout_routed_expert)
rollout_routed_expert = ray.get(rollout_routed_expert_refs)
# free obj store explicitly
if self.sp_mesh is None or self.sp_mesh.size() == 1:
@@ -465,7 +515,10 @@ def _add_rollout_routed_experts(
dist.barrier()
for free_routed_expert_refs in to_free_routed_expert_refs:
ray.internal.free(free_routed_expert_refs, local_only=False)
+ if to_release_pin_keys:
+ get_store().release_many.remote(to_release_pin_keys)
del to_free_routed_expert_refs
+ del to_release_pin_keys
@ray_method
def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLogItem:
diff --git a/xtuner/v1/train/agent_rl_trainer.py b/xtuner/v1/train/agent_rl_trainer.py
index 16159957dc..714d782abc 100644
--- a/xtuner/v1/train/agent_rl_trainer.py
+++ b/xtuner/v1/train/agent_rl_trainer.py
@@ -23,7 +23,6 @@
from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig
from xtuner.v1.ray.environment.lagent.tokenize import tokenize
from xtuner.v1.ray.evaluator import Evaluator, EvaluatorConfig
-from xtuner.v1.ray.rollout.lmdeploy import get_lmdeploy_routed_experts_ref
from xtuner.v1.rl.base import WorkerConfig
from xtuner.v1.rl.config.advantage import BaseAdvantageConfig, GRPOAdvantageConfig
from xtuner.v1.train.trainer import LoadCheckpointConfig
@@ -43,6 +42,24 @@
DEVICE_MODULE = get_torch_device_module()
+_COMPACT_MSG_KEYS = ("reasoning_content", "thinking", "content", "tool_calls", "raw_content", "name", "tool_call_id")
+
+
+def _compact_message(msg: dict) -> dict:
+ """Project a memory message dict to the fields worth inspecting in a
+ trajectory.
+
+    Keeps role, reasoning (under either schema name), content, tool_calls,
+    raw_content (for debugging parser failures), and tool identifiers. Drops
+    token ids / logprobs / session metadata that would bloat the jsonl.
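+
+    E.g. (illustrative)::
+
+        >>> _compact_message({"role": "assistant", "content": "hi",
+        ...                   "token_ids": [1, 2], "logprobs": [-0.1]})
+        {'role': 'assistant', 'content': 'hi'}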
+ """
+ out: dict = {"role": msg.get("role", "assistant")}
+ for key in _COMPACT_MSG_KEYS:
+ val = msg.get(key)
+ if val not in (None, "", [], {}):
+ out[key] = val
+ return out
+
+
class AgentRLTrainerConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
@@ -456,24 +473,15 @@ def _save_trajectories(self, data_groups, save_path, rollout_idx=None, is_eval:
f.write("\n")
for group in data_groups:
for data in group:
+ messages = data.env.agent.extra_info.get("messages", [])
entry = {
- # "raw_prompt": data.data.extra_info["raw_prompt"],
- "prompt": [
- {
- "role": msg["role"],
- "content": msg["raw_content"] if "raw_content" in msg else msg["content"],
- }
- for msg in data.env.agent.extra_info.get("messages", [])[:-1]
- ],
- "response": data.env.rollout.response,
+ "prompt": [_compact_message(msg) for msg in messages[:-1]],
+ "response": _compact_message(messages[-1]) if messages else None,
"response_len": len(data.env.rollout.response_ids or []),
- # "label": data.data.reward_model["ground_truth"],
"reward": data.env.judger.reward["score"],
- # "round": sum(msg['role'] == 'assistant' for msg in data.env.agent.extra_info['messages'][:-1]),
- # "judger_response": data.env.judger.extra_info,
}
- # if "completions" in data.env.agent.extra_info:
- # entry["completions"] = data.env.agent.extra_info["completions"]
+ if is_eval:
+ entry["daemon_log"] = data.env.agent.extra_info.get("daemon_log", "")
json.dump(entry, f, ensure_ascii=False, indent=2)
f.write("\n")
@@ -619,10 +627,7 @@ def _extract_score(value, default=0.0):
), f"{rollout_logprobs.size()} vs {shifted_labels.size()}"
seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu")
- routed_experts = inputs["routed_experts"]
- if isinstance(routed_experts, str):
- routed_experts = get_lmdeploy_routed_experts_ref(routed_experts)
- seq_ctx.rollout_routed_experts = routed_experts
+ seq_ctx.rollout_routed_experts = inputs["routed_experts"]
data_batches.append(
dict(
seq_ctx=seq_ctx,
diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py
index 7c80065120..986f7886f3 100644
--- a/xtuner/v1/train/rl_trainer.py
+++ b/xtuner/v1/train/rl_trainer.py
@@ -26,7 +26,6 @@
from xtuner.v1.ray.environment import SingleTurnEnvironment, SingleTurnEnvironmentProxy
from xtuner.v1.ray.evaluator import Evaluator, EvaluatorConfig
from xtuner.v1.ray.judger import JudgerConfig
-from xtuner.v1.ray.rollout.lmdeploy import get_lmdeploy_routed_experts_ref
from xtuner.v1.rl.base import (
TrainingController,
TrainingControllerProxy,
@@ -554,7 +553,26 @@ def _rollout_step(self, rollout_idx: int, step_timer_dict: dict) -> RolloutInfo:
"""Performs a single rollout step to generate experience."""
with timer("generation", step_timer_dict):
ray.get(self._rollout_env_controller.update_active_workers.remote())
- dataflow_result = ray.get(self._rollout_dataflow.run.remote())
+ # Ultimate fallback: without a timeout, dataflow.run can hang
+ # forever when inner rollout workers die and their pending tasks
+ # never resolve. 12h is deliberately loose — normal long-tail
+ # rollout batches run ~5h, so anything past 12h is almost
+ # certainly a real deadlock worth surfacing to the driver.
+ dataflow_ref = self._rollout_dataflow.run.remote()
+ try:
+ dataflow_result = ray.get(dataflow_ref, timeout=12 * 3600)
+ except ray.exceptions.GetTimeoutError:
+ self.logger.error(
+ f"rollout_idx {rollout_idx}: dataflow.run exceeded 12h, "
+ f"likely a stuck rollout/sandbox. cancelling task and raising."
+ )
+ try:
+ ray.cancel(dataflow_ref, force=True)
+ except Exception as exc:
+ self.logger.warning(f"ray.cancel of dataflow task failed: {exc}")
+ raise RuntimeError(
+ f"dataflow.run hung for 12h at rollout_idx={rollout_idx}; check ray dashboard for dead actors"
+ ) from None
if XTUNER_DETERMINISTIC:
data_groups, multimodal_train_infos = self._sort_rollout_outputs(
@@ -822,8 +840,6 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf
if "routed_experts" in group[i].env.rollout.extra_info:
routed_experts = group[i].env.rollout.extra_info.pop("routed_experts") # n,layer*expert
- if isinstance(routed_experts, str):
- routed_experts = get_lmdeploy_routed_experts_ref(routed_experts)
seq_ctx.rollout_routed_experts = routed_experts # n,layer,expert
data_batches.append(data_dict)