Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions xtuner/v1/rl/agent_loop/single_turn_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from xtuner.v1.rl.utils import create_task

from .agent_loop import AgentLoop, AgentLoopConfig
from .utils import PartialRolloutHandler


class SingleTurnAgentLoopConfig(AgentLoopConfig):
Expand Down Expand Up @@ -34,26 +33,18 @@ def __init__(
enable_batch_judge: bool = False,
):
super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger)
self.max_tokens = self.sample_params.max_tokens
self.partial_rollout_handler = PartialRolloutHandler(max_tokens=self.max_tokens)
self.enable_batch_judge = enable_batch_judge

async def generate_sample(
self,
rollout_state: RolloutState,
**kwargs,
) -> RolloutState:
enable_partial_rollout = kwargs.get("enable_partial_rollout", False)

# rollout state 预处理, enable_partial_rollout = True 会在这里拼接 token 和修正 max_token
rollout_state = self.partial_rollout_handler.preprocess(rollout_state, enable_partial_rollout)
if not rollout_state.tokens:
rollout_state.tokens = rollout_state.prompt_ids

# 推理引擎generate, 生成的结果会覆盖到 rollout_state.response_ids 上
rollout_state = await self.rollout_ctl.generate.remote(rollout_state) # type: ignore[attr-defined]
# rollout state 后处理: 合并 partial rollout 的历史上下文
rollout_state = self.partial_rollout_handler.postprocess(rollout_state)
# 非 COMPLETED 状态(如被截断、放弃等)直接早退,不触发打分
if rollout_state.status != Status.COMPLETED:
return rollout_state
Expand Down
102 changes: 0 additions & 102 deletions xtuner/v1/rl/agent_loop/utils.py

This file was deleted.

5 changes: 4 additions & 1 deletion xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,10 @@ def build(
judger=build_judger(task_cfg.judger_config) if task_cfg.judger_config is not None else None,
logger=logger,
)
produce_strategy = task_cfg.produce_strategy_config.build(sync_weights_interval=sync_weights_interval)
produce_strategy = task_cfg.produce_strategy_config.build(
sync_weights_interval=sync_weights_interval,
rollout_controller=rollout_controller,
)
sampler = task_cfg.sampler_config.build(tokenizer=tokenizer, replay_buffer=replay_buffer)
task_runners.append(
_TaskRunner(
Expand Down
50 changes: 30 additions & 20 deletions xtuner/v1/rl/agent_loop_manager/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Protocol, runtime_checkable
from typing import TYPE_CHECKING, Optional, Protocol, runtime_checkable


if TYPE_CHECKING:
from xtuner.v1.rl.rollout.controller import RolloutControllerProxy

import ray
from pydantic import BaseModel, ConfigDict, Field
Expand Down Expand Up @@ -71,19 +75,12 @@ class ProduceBatchStatus(Enum):
async def _timed_generate_group(
agent_loop: AgentLoopSpec,
rollout_state: list[RolloutState],
enable_partial_rollout: bool = False,
) -> list[RolloutState]:
start = time.perf_counter()
if isinstance(agent_loop, ray.actor.ActorHandle):
result = await agent_loop.generate_group.remote(
rollout_state,
enable_partial_rollout=enable_partial_rollout,
)
result = await agent_loop.generate_group.remote(rollout_state)
else:
result = await agent_loop.generate_group(
rollout_state,
enable_partial_rollout=enable_partial_rollout,
)
result = await agent_loop.generate_group(rollout_state)
elapsed = time.perf_counter() - start
for item in result:
extra_fields = getattr(item, "extra_fields", None)
Expand Down Expand Up @@ -156,11 +153,21 @@ class ProduceStrategyConfig(ABC, BaseModel):
should_continue_fn: ShouldContinueFn = default_should_continue_fn

@abstractmethod
def build(self, *, sync_weights_interval: int = 1) -> "ProduceStrategy": ...
def build(
    self,
    *,
    sync_weights_interval: int = 1,
    rollout_controller: "Optional[RolloutControllerProxy]" = None,
) -> "ProduceStrategy":
    """Construct the concrete :class:`ProduceStrategy` for this config.

    Args:
        sync_weights_interval: How often (in steps) weights are synced;
            concrete strategies may ignore it.
        rollout_controller: Handle to the rollout controller so a strategy
            can configure workers at build time (e.g. partial rollout);
            may be ``None``.
    """
    ...


class SyncProduceStrategyConfig(ProduceStrategyConfig):
def build(self, *, sync_weights_interval: int = 1) -> "SyncProduceStrategy":
def build(
    self,
    *,
    sync_weights_interval: int = 1,
    rollout_controller: "Optional[RolloutControllerProxy]" = None,
) -> "SyncProduceStrategy":
    """Build a synchronous produce strategy.

    Args:
        sync_weights_interval: Accepted for interface parity with the base
            config; not used by the synchronous strategy.
        rollout_controller: Accepted for interface parity with the async
            config; not used here (no worker-side flags to set).
            # NOTE(review): assumed intentional no-op — confirm with async variant.
    """
    return SyncProduceStrategy(
        is_valid_sample_fn=self.is_valid_sample_fn, should_continue_fn=self.should_continue_fn
    )
Expand All @@ -172,7 +179,16 @@ class AsyncProduceStrategyConfig(ProduceStrategyConfig):
max_staleness: int = Field(default=0, ge=0)
tail_batch_trigger_size: int = 0

def build(self, *, sync_weights_interval: int = 1) -> "AsyncProduceStrategy":
def build(
self,
*,
sync_weights_interval: int = 1,
rollout_controller: "Optional[RolloutControllerProxy]" = None,
) -> "AsyncProduceStrategy":
if rollout_controller is not None:
import ray

ray.get(rollout_controller.set_enable_partial_rollout.remote(self.enable_partial_rollout))
return AsyncProduceStrategy(
over_sample_threshold=self.over_sample_threshold,
enable_partial_rollout=self.enable_partial_rollout,
Expand Down Expand Up @@ -435,13 +451,7 @@ async def _schedule_one(
return False
group_status = [Status.EXPIRED, Status.ABORTED] if sample_from_expired else [Status.ABORTED]
rollout_state = await sampler.sample(task_name=task_name, group_status=group_status)
task = create_task(
_timed_generate_group(
agent_loop,
rollout_state,
enable_partial_rollout=self.enable_partial_rollout,
)
)
task = create_task(_timed_generate_group(agent_loop, rollout_state))
self._pending_tasks.add(task)
return True

Expand Down
6 changes: 6 additions & 0 deletions xtuner/v1/rl/rollout/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,12 @@ def _apply_output_parsers(self, rollout_state: RolloutState) -> None:
else:
rollout_state.extra_fields.pop("reasoning_text", None)

def set_enable_partial_rollout(self, enable: bool) -> None:
    """Propagate the partial-rollout flag to every currently active worker."""
    with self.worker_info_lock:
        # Collect the remote calls for active workers and wait for all of
        # them while still holding the lock, so the worker set cannot change
        # mid-broadcast.
        pending = []
        for info in self.rank2info.values():
            if info.is_active:
                pending.append(info.actor.set_enable_partial_rollout.remote(enable))  # type: ignore[attr-defined]
        ray.get(pending)

def pause_generation(self):
self.health_checker.pause()

Expand Down
108 changes: 107 additions & 1 deletion xtuner/v1/rl/rollout/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

import httpx
import ray
from ray import ObjectRef as RayObjectRef

from xtuner.v1.rl.utils import asyncio_run
from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.rl.utils import asyncio_run, clear_rollout_response_for_rerun
from xtuner.v1.utils import get_logger


Expand Down Expand Up @@ -277,3 +279,107 @@ async def check_worker_health(
f"Exception during health check for worker {rank} at {url}: {e}. Failure count: {failing_count}"
)
return False


def _resolve_routed_experts(routed_experts: list[int] | ray.ObjectRef) -> list[int]:
    """Materialize ``routed_experts`` into a plain Python list.

    Accepts either a list already, a Ray object ref wrapping one, or an
    array-like exposing ``tolist()`` (e.g. a tensor/ndarray).
    """
    value = routed_experts
    if isinstance(value, ray.ObjectRef):
        # Pull the payload out of the Ray object store.
        value = ray.get(value)
    tolist = getattr(value, "tolist", None)
    if tolist is not None:
        value = tolist()
    assert isinstance(value, list), f"Unexpected routed_experts type: {type(value)}"
    return value


class PartialRolloutHandler:
    """Handle preprocessing and postprocessing for partial rollout
    continuation.

    A partial rollout is a generation interrupted mid-flight (e.g. aborted
    for a weight sync) and resumed later: :meth:`preprocess` splices the
    partial response back onto the prompt and shrinks the remaining token
    budget, while :meth:`postprocess` merges the freshly generated segment
    with the stored history.
    """

    def __init__(self) -> None:
        self.logger = get_logger(self.__class__.__name__)

    def preprocess(
        self, rollout_state: RolloutState, max_tokens: int, enable_partial_rollout: bool = False
    ) -> RolloutState:
        """Prepare ``rollout_state`` for (re-)generation.

        EXPIRED samples — and ABORTED ones when partial rollout is disabled —
        are reset and regenerated from scratch with the full ``max_tokens``
        budget. Otherwise, if a partial response exists, prompt and response
        tokens are concatenated so the engine continues where it left off and
        the token budget is reduced by what was already generated.

        Args:
            rollout_state: Sample state to prepare; mutated in place.
            max_tokens: Full generation budget for a fresh attempt.
            enable_partial_rollout: Whether aborted samples may be continued
                instead of restarted.

        Returns:
            The (mutated) ``rollout_state``.
        """
        if rollout_state.status == Status.EXPIRED or (
            not enable_partial_rollout and rollout_state.status == Status.ABORTED
        ):
            # Discard any partial output and restart this sample from scratch.
            rollout_state = clear_rollout_response_for_rerun(rollout_state)
            rollout_state.sample_params = rollout_state.sample_params.model_copy(update={"max_tokens": max_tokens})
            rollout_state.response = ""
            rollout_state.status = Status.INIT

        # Nothing to continue from: fresh sample, or one that already finished.
        if not rollout_state.response_ids or rollout_state.status == Status.COMPLETED:
            return rollout_state

        response_ids = rollout_state.response_ids
        prompt_ids = list(rollout_state.prompt_ids or [])
        response_len = len(response_ids)
        prompt_len = len(prompt_ids)

        # Concatenate prompt + partial response so generation resumes in place.
        rollout_state.tokens = prompt_ids + response_ids
        # Shrink the budget by what was already generated.
        remaining_tokens = max_tokens - response_len
        # model_copy (not the deprecated Pydantic v1 `.copy`) — consistent with
        # the reset branch above.
        rollout_state.sample_params = rollout_state.sample_params.model_copy(update={"max_tokens": remaining_tokens})

        self.logger.info(
            f"[PartialRolloutHandler] Sample {rollout_state.uid} continue rollout | Remaining tokens allowed: {remaining_tokens} | Status: {rollout_state.status} | Prompt len: {prompt_len} | Response len: {response_len} | Staleness: {rollout_state.seq_staleness} | Total tokens: {len(rollout_state.tokens)}"
        )
        return rollout_state

    def postprocess(
        self,
        rollout_state: RolloutState,
        *,
        response: str,
        response_ids: list[int],
        logprobs: list[float],
        routed_experts: list[int] | RayObjectRef | None,
        finish_reason: str,
        status: Status,
        enable_partial_rollout: bool = False,
    ) -> RolloutState:
        """Write the engine output back onto ``rollout_state``.

        Without partial rollout the new output simply overwrites the previous
        fields. With partial rollout enabled, the new segment is appended to
        the stored history (text, token ids, logprobs) and routed-experts
        traces are deduplicated and concatenated.

        Args:
            rollout_state: Sample state to update; mutated in place.
            response: Newly generated text segment.
            response_ids: Newly generated token ids.
            logprobs: Log-probabilities for ``response_ids``.
            routed_experts: Expert-routing trace for the new segment, as a
                list or a Ray object ref; ``None`` if not tracked.
            finish_reason: Engine-reported finish reason.
            status: Resulting rollout status.
            enable_partial_rollout: Whether to merge with history instead of
                overwriting.

        Returns:
            The (mutated) ``rollout_state``.
        """
        if not enable_partial_rollout:
            rollout_state.response = response
            rollout_state.response_ids = response_ids
            rollout_state.logprobs = logprobs
            rollout_state.routed_experts = routed_experts
            rollout_state.status = status
            rollout_state.finish_reason = finish_reason
            return rollout_state

        rollout_state.finish_reason = finish_reason
        rollout_state.status = status
        history_response = rollout_state.response or ""
        history_response_ids = rollout_state.response_ids or []
        history_logprobs = rollout_state.logprobs or []
        rollout_state.response = history_response + response
        rollout_state.response_ids = history_response_ids + response_ids
        rollout_state.logprobs = history_logprobs + logprobs

        # Merge routed-experts traces from the previous partial segment and
        # the newly generated one.
        history_routed_experts = rollout_state.routed_experts or None
        if history_routed_experts is not None and routed_experts is not None:
            start_time = time.time()
            history_routed_experts = _resolve_routed_experts(history_routed_experts)
            cur_routed_experts = _resolve_routed_experts(routed_experts)
            cur_routed_experts_len = len(cur_routed_experts)
            history_routed_experts_len = len(history_routed_experts)
            # NOTE(review): the current trace appears to cover the sequence so
            # far, hence the `history - 1 <= current` length invariant — confirm
            # against the engine's routed_experts semantics.
            assert history_routed_experts_len - 1 <= cur_routed_experts_len, (
                f"Existing routed_experts len: {history_routed_experts_len}, current routed_experts len: {cur_routed_experts_len}, history_response_ids len: {len(history_response_ids)}, current response_ids len: {len(response_ids)}"
            )
            # Drop the prefix already covered by the history before appending.
            cur_routed_experts = cur_routed_experts[history_routed_experts_len:]
            concat_routed_experts = history_routed_experts + cur_routed_experts
            # Keep the merged trace in the Ray object store rather than inline.
            rollout_state.routed_experts = ray.put(concat_routed_experts)
            end_time = time.time()
            self.logger.info(
                f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds"
            )
        elif history_routed_experts is None and routed_experts is not None:
            rollout_state.routed_experts = routed_experts
        elif history_routed_experts is not None and routed_experts is None:
            rollout_state.routed_experts = history_routed_experts
        return rollout_state
Loading
Loading