diff --git a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py
index 3edac1190..0e99bea20 100644
--- a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py
+++ b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py
@@ -6,7 +6,6 @@
 from xtuner.v1.rl.utils import create_task
 
 from .agent_loop import AgentLoop, AgentLoopConfig
-from .utils import PartialRolloutHandler
 
 
 class SingleTurnAgentLoopConfig(AgentLoopConfig):
@@ -34,8 +33,6 @@ def __init__(
         enable_batch_judge: bool = False,
     ):
         super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger)
-        self.max_tokens = self.sample_params.max_tokens
-        self.partial_rollout_handler = PartialRolloutHandler(max_tokens=self.max_tokens)
         self.enable_batch_judge = enable_batch_judge
 
     async def generate_sample(
@@ -43,17 +40,11 @@
         rollout_state: RolloutState,
         **kwargs,
     ) -> RolloutState:
-        enable_partial_rollout = kwargs.get("enable_partial_rollout", False)
-
-        # Preprocess the rollout state; when enable_partial_rollout = True, tokens are concatenated and max_tokens is adjusted here
-        rollout_state = self.partial_rollout_handler.preprocess(rollout_state, enable_partial_rollout)
         if not rollout_state.tokens:
             rollout_state.tokens = rollout_state.prompt_ids
         # Inference-engine generate; the result is written onto rollout_state.response_ids
         rollout_state = await self.rollout_ctl.generate.remote(rollout_state)  # type: ignore[attr-defined]
-        # Postprocess the rollout state: merge the history context of the partial rollout
-        rollout_state = self.partial_rollout_handler.postprocess(rollout_state)
 
         # Non-COMPLETED statuses (e.g. truncated or abandoned) return early without triggering judging
         if rollout_state.status != Status.COMPLETED:
             return rollout_state
diff --git a/xtuner/v1/rl/agent_loop/utils.py b/xtuner/v1/rl/agent_loop/utils.py
deleted file mode 100644
index 6da4c6504..000000000
--- a/xtuner/v1/rl/agent_loop/utils.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import time
-
-import ray
-
-from xtuner.v1.data_proto.rl_data import RolloutState, Status
-from xtuner.v1.rl.utils import clear_rollout_response_for_rerun
-from xtuner.v1.utils import get_logger
-
-
-def _resolve_routed_experts(routed_experts: list[int] | ray.ObjectRef) -> list[int]:
-    if isinstance(routed_experts, ray.ObjectRef):
-        routed_experts = ray.get(routed_experts)
-    if hasattr(routed_experts, "tolist"):
-        routed_experts = routed_experts.tolist()
-    assert isinstance(routed_experts, list), f"Unexpected routed_experts type: {type(routed_experts)}"
-    return routed_experts
-
-
-class PartialRolloutHandler:
-    """Handle preprocessing and postprocessing for partial rollout
-    continuation."""
-
-    def __init__(self, max_tokens: int) -> None:
-        self.logger = get_logger(self.__class__.__name__)
-        self.max_tokens = max_tokens
-
-    def preprocess(self, rollout_state: RolloutState, enable_partial_rollout: bool = False) -> RolloutState:
-        if rollout_state.status == Status.EXPIRED or (
-            not enable_partial_rollout and rollout_state.status == Status.ABORTED
-        ):
-            rollout_state = clear_rollout_response_for_rerun(rollout_state)
-            rollout_state.sample_params = rollout_state.sample_params.model_copy(
-                update={"max_tokens": self.max_tokens}
-            )
-            rollout_state.response = ""
-            rollout_state.status = Status.INIT
-
-        if not rollout_state.response_ids or rollout_state.status == Status.COMPLETED:
-            return rollout_state
-
-        # Set up token and length variable
-        response_ids = rollout_state.response_ids
-        prompt_ids = list(rollout_state.prompt_ids or [])
-        response_len = len(response_ids)
-        prompt_len = len(prompt_ids)
-
-        rollout_state.tokens = prompt_ids + response_ids  # concatenate for partial rollout continuation
-        remaining_tokens = self.max_tokens - response_len  # compute remaining max_tokens budget
-        rollout_state.sample_params = rollout_state.sample_params.copy(update={"max_tokens": remaining_tokens})
-
-        self.logger.debug(
-            f"[PartialRolloutHandler] Sample {rollout_state.uid} continue rollout | Remaining tokens allowed: {remaining_tokens} | Status: {rollout_state.status} | Prompt len: {prompt_len} | Response len: {response_len} | Staleness: {rollout_state.seq_staleness} | Total tokens: {len(rollout_state.tokens)}"
-        )
-        # TODO: handle routed_experts
-        rollout_state.extra_fields["history_response_dict"] = {
-            "response_ids": rollout_state.tokens[prompt_len:] if rollout_state.tokens else [],
-            "response": rollout_state.response or "",
-            "logprobs": rollout_state.logprobs or [],
-            "response_mask": rollout_state.response_mask or [],
-            "routed_experts": rollout_state.routed_experts,
-        }
-        return rollout_state
-
-    def postprocess(self, rollout_state: RolloutState) -> RolloutState:
-        # TODO: if not enable partial rollout, return directly?
-
-        # Concatenate history response fields
-        history_dict = rollout_state.extra_fields.pop("history_response_dict", None)
-        if not history_dict:
-            return rollout_state
-
-        rollout_state.response_ids = history_dict.get("response_ids", []) + (rollout_state.response_ids or [])
-        rollout_state.response = history_dict.get("response", "") + (rollout_state.response or "")
-        rollout_state.logprobs = history_dict.get("logprobs", []) + (rollout_state.logprobs or [])
-        rollout_state.response_mask = history_dict.get("response_mask", []) + (rollout_state.response_mask or [])
-        history_routed_experts_ref = history_dict.get("routed_experts")
-        cur_routed_experts_ref = rollout_state.routed_experts
-        if history_routed_experts_ref is not None and cur_routed_experts_ref is not None:
-            start_time = time.time()
-            history_routed_experts = _resolve_routed_experts(history_routed_experts_ref)
-            cur_routed_experts = _resolve_routed_experts(cur_routed_experts_ref)
-            cur_routed_experts_len = len(cur_routed_experts)
-            history_routed_experts_len = len(history_routed_experts)
-            assert history_routed_experts_len - 1 <= cur_routed_experts_len, (
-                f"Existing routed_experts len: {history_routed_experts_len}, current routed_experts len: {cur_routed_experts_len}"
-            )
-            cur_routed_experts = cur_routed_experts[history_routed_experts_len:]
-            concat_routed_experts = history_routed_experts + cur_routed_experts
-            rollout_state.routed_experts = ray.put(concat_routed_experts)
-            # free_object_refs(
-            #     [ref for ref in (history_routed_experts_ref, cur_routed_experts_ref) if isinstance(ref, ray.ObjectRef)]
-            # )
-            end_time = time.time()
-            self.logger.info(
-                f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds"
-            )
-        elif history_routed_experts_ref is None and cur_routed_experts_ref is not None:
-            rollout_state.routed_experts = cur_routed_experts_ref
-        elif history_routed_experts_ref is not None and cur_routed_experts_ref is None:
-            rollout_state.routed_experts = history_routed_experts_ref
-
-        return rollout_state
diff --git a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py
index 00f6f8231..cc3cca84e 100644
--- a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py
+++ b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py
@@ -224,7 +224,10 @@ def build(
             judger=build_judger(task_cfg.judger_config) if task_cfg.judger_config is not None else None,
             logger=logger,
         )
-        produce_strategy = task_cfg.produce_strategy_config.build(sync_weights_interval=sync_weights_interval)
+        produce_strategy = task_cfg.produce_strategy_config.build(
+            sync_weights_interval=sync_weights_interval,
+            rollout_controller=rollout_controller,
+        )
         sampler = task_cfg.sampler_config.build(tokenizer=tokenizer, replay_buffer=replay_buffer)
         task_runners.append(
             _TaskRunner(
diff --git a/xtuner/v1/rl/agent_loop_manager/producer.py b/xtuner/v1/rl/agent_loop_manager/producer.py
index d4f82e02c..e4efe5bf5 100644
--- a/xtuner/v1/rl/agent_loop_manager/producer.py
+++ b/xtuner/v1/rl/agent_loop_manager/producer.py
@@ -4,7 +4,11 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from enum import Enum, auto
-from typing import Protocol, runtime_checkable
+from typing import TYPE_CHECKING, Optional, Protocol, runtime_checkable
+
+
+if TYPE_CHECKING:
+    from xtuner.v1.rl.rollout.controller import RolloutControllerProxy
 
 import ray
 from pydantic import BaseModel, ConfigDict, Field
@@ -71,19 +75,12 @@ class ProduceBatchStatus(Enum):
 async def _timed_generate_group(
     agent_loop: AgentLoopSpec,
     rollout_state: list[RolloutState],
-    enable_partial_rollout: bool = False,
 ) -> list[RolloutState]:
     start = time.perf_counter()
     if isinstance(agent_loop, ray.actor.ActorHandle):
-        result = await agent_loop.generate_group.remote(
-            rollout_state,
-            enable_partial_rollout=enable_partial_rollout,
-        )
+        result = await agent_loop.generate_group.remote(rollout_state)
     else:
-        result = await agent_loop.generate_group(
-            rollout_state,
-            enable_partial_rollout=enable_partial_rollout,
-        )
+        result = await agent_loop.generate_group(rollout_state)
     elapsed = time.perf_counter() - start
     for item in result:
         extra_fields = getattr(item, "extra_fields", None)
@@ -156,11 +153,21 @@ class ProduceStrategyConfig(ABC, BaseModel):
     should_continue_fn: ShouldContinueFn = default_should_continue_fn
 
     @abstractmethod
-    def build(self, *, sync_weights_interval: int = 1) -> "ProduceStrategy": ...
+    def build(
+        self,
+        *,
+        sync_weights_interval: int = 1,
+        rollout_controller: "Optional[RolloutControllerProxy]" = None,
+    ) -> "ProduceStrategy": ...
 
 
 class SyncProduceStrategyConfig(ProduceStrategyConfig):
-    def build(self, *, sync_weights_interval: int = 1) -> "SyncProduceStrategy":
+    def build(
+        self,
+        *,
+        sync_weights_interval: int = 1,
+        rollout_controller: "Optional[RolloutControllerProxy]" = None,
+    ) -> "SyncProduceStrategy":
         return SyncProduceStrategy(
             is_valid_sample_fn=self.is_valid_sample_fn, should_continue_fn=self.should_continue_fn
         )
@@ -172,7 +179,16 @@ class AsyncProduceStrategyConfig(ProduceStrategyConfig):
     max_staleness: int = Field(default=0, ge=0)
     tail_batch_trigger_size: int = 0
 
-    def build(self, *, sync_weights_interval: int = 1) -> "AsyncProduceStrategy":
+    def build(
+        self,
+        *,
+        sync_weights_interval: int = 1,
+        rollout_controller: "Optional[RolloutControllerProxy]" = None,
+    ) -> "AsyncProduceStrategy":
+        if rollout_controller is not None:
+            import ray
+
+            ray.get(rollout_controller.set_enable_partial_rollout.remote(self.enable_partial_rollout))
         return AsyncProduceStrategy(
             over_sample_threshold=self.over_sample_threshold,
             enable_partial_rollout=self.enable_partial_rollout,
@@ -435,13 +451,7 @@ async def _schedule_one(
             return False
         group_status = [Status.EXPIRED, Status.ABORTED] if sample_from_expired else [Status.ABORTED]
         rollout_state = await sampler.sample(task_name=task_name, group_status=group_status)
-        task = create_task(
-            _timed_generate_group(
-                agent_loop,
-                rollout_state,
-                enable_partial_rollout=self.enable_partial_rollout,
-            )
-        )
+        task = create_task(_timed_generate_group(agent_loop, rollout_state))
         self._pending_tasks.add(task)
         return True
 
diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py
index 3bbc3e979..dc6da1b50 100644
--- a/xtuner/v1/rl/rollout/controller.py
+++ b/xtuner/v1/rl/rollout/controller.py
@@ -231,6 +231,12 @@ def _apply_output_parsers(self, rollout_state: RolloutState) -> None:
         else:
             rollout_state.extra_fields.pop("reasoning_text", None)
 
+    def set_enable_partial_rollout(self, enable: bool) -> None:
+        """Propagate enable_partial_rollout flag to all active workers."""
+        with self.worker_info_lock:
+            active_actors = [info.actor for info in self.rank2info.values() if info.is_active]
+            ray.get([actor.set_enable_partial_rollout.remote(enable) for actor in active_actors])  # type: ignore[attr-defined]
+
     def pause_generation(self):
         self.health_checker.pause()
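Note (not part of the patch): the controller method above fans a single flag out to every active worker actor and blocks until all of them have applied it. A toy, self-contained reproduction of that pattern, with a hypothetical `Worker` actor standing in for the real rollout worker actors:

```python
import ray


# Hypothetical stand-in for a rollout worker actor; only the flag setter is modeled.
@ray.remote
class Worker:
    def __init__(self) -> None:
        self.enable_partial_rollout = False

    def set_enable_partial_rollout(self, enable: bool) -> None:
        self.enable_partial_rollout = enable


ray.init()
workers = [Worker.remote() for _ in range(4)]
# One remote call per worker, then a blocking ray.get, so the flag is guaranteed
# to be in place on every worker before any rollout request is scheduled.
ray.get([w.set_enable_partial_rollout.remote(True) for w in workers])
```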
diff --git a/xtuner/v1/rl/rollout/utils.py b/xtuner/v1/rl/rollout/utils.py
index 2fd75da19..7c45f7a24 100644
--- a/xtuner/v1/rl/rollout/utils.py
+++ b/xtuner/v1/rl/rollout/utils.py
@@ -8,8 +8,10 @@
 
 import httpx
 import ray
+from ray import ObjectRef as RayObjectRef
 
-from xtuner.v1.rl.utils import asyncio_run
+from xtuner.v1.data_proto.rl_data import RolloutState, Status
+from xtuner.v1.rl.utils import asyncio_run, clear_rollout_response_for_rerun
 from xtuner.v1.utils import get_logger
 
 
@@ -277,3 +279,107 @@ async def check_worker_health(
                 f"Exception during health check for worker {rank} at {url}: {e}. Failure count: {failing_count}"
             )
     return False
+
+
+def _resolve_routed_experts(routed_experts: list[int] | ray.ObjectRef) -> list[int]:
+    if isinstance(routed_experts, ray.ObjectRef):
+        routed_experts = ray.get(routed_experts)
+    if hasattr(routed_experts, "tolist"):
+        routed_experts = routed_experts.tolist()
+    assert isinstance(routed_experts, list), f"Unexpected routed_experts type: {type(routed_experts)}"
+    return routed_experts
+
+
+class PartialRolloutHandler:
+    """Handle preprocessing and postprocessing for partial rollout
+    continuation."""
+
+    def __init__(self) -> None:
+        self.logger = get_logger(self.__class__.__name__)
+
+    def preprocess(
+        self, rollout_state: RolloutState, max_tokens: int, enable_partial_rollout: bool = False
+    ) -> RolloutState:
+        if rollout_state.status == Status.EXPIRED or (
+            not enable_partial_rollout and rollout_state.status == Status.ABORTED
+        ):
+            rollout_state = clear_rollout_response_for_rerun(rollout_state)
+            rollout_state.sample_params = rollout_state.sample_params.model_copy(update={"max_tokens": max_tokens})
+            rollout_state.response = ""
+            rollout_state.status = Status.INIT
+
+        if not rollout_state.response_ids or rollout_state.status == Status.COMPLETED:
+            return rollout_state
+
+        # Set up token and length variables
+        response_ids = rollout_state.response_ids
+        prompt_ids = list(rollout_state.prompt_ids or [])
+        response_len = len(response_ids)
+        prompt_len = len(prompt_ids)
+
+        rollout_state.tokens = prompt_ids + response_ids  # concatenate for partial rollout continuation
+        remaining_tokens = max_tokens - response_len  # compute the remaining max_tokens budget
+        rollout_state.sample_params = rollout_state.sample_params.model_copy(update={"max_tokens": remaining_tokens})
+
+        self.logger.info(
+            f"[PartialRolloutHandler] Sample {rollout_state.uid} continue rollout | Remaining tokens allowed: {remaining_tokens} | Status: {rollout_state.status} | Prompt len: {prompt_len} | Response len: {response_len} | Staleness: {rollout_state.seq_staleness} | Total tokens: {len(rollout_state.tokens)}"
+        )
+        return rollout_state
+
+    def postprocess(
+        self,
+        rollout_state: RolloutState,
+        *,
+        response: str,
+        response_ids: list[int],
+        logprobs: list[float],
+        routed_experts: list[int] | RayObjectRef | None,
+        finish_reason: str,
+        status: Status,
+        enable_partial_rollout: bool = False,
+    ) -> RolloutState:
+        if not enable_partial_rollout:
+            rollout_state.response = response
+            rollout_state.response_ids = response_ids
+            rollout_state.logprobs = logprobs
+            rollout_state.routed_experts = routed_experts
+            rollout_state.status = status
+            rollout_state.finish_reason = finish_reason
+            return rollout_state
+
+        else:
+            rollout_state.finish_reason = finish_reason
+            rollout_state.status = status
+            history_response = rollout_state.response or ""
+            history_response_ids = rollout_state.response_ids or []
+            history_logprobs = rollout_state.logprobs or []
+            rollout_state.response = history_response + response
+            rollout_state.response_ids = history_response_ids + response_ids
+            rollout_state.logprobs = history_logprobs + logprobs
+
+            # Handle routed experts
+            history_routed_experts = rollout_state.routed_experts or None
+            if history_routed_experts is not None and routed_experts is not None:
+                start_time = time.time()
+                history_routed_experts = _resolve_routed_experts(history_routed_experts)
+                cur_routed_experts = _resolve_routed_experts(routed_experts)
+                cur_routed_experts_len = len(cur_routed_experts)
+                history_routed_experts_len = len(history_routed_experts)
+                assert history_routed_experts_len - 1 <= cur_routed_experts_len, (
+                    f"Existing routed_experts len: {history_routed_experts_len}, current routed_experts len: {cur_routed_experts_len}, history_response_ids len: {len(history_response_ids)}, current response_ids len: {len(response_ids)}"
+                )
+                cur_routed_experts = cur_routed_experts[history_routed_experts_len:]
+                concat_routed_experts = history_routed_experts + cur_routed_experts
+                rollout_state.routed_experts = ray.put(concat_routed_experts)
+                # free_object_refs(
+                #     [ref for ref in (history_routed_experts_ref, cur_routed_experts_ref) if isinstance(ref, ray.ObjectRef)]
+                # )
+                end_time = time.time()
+                self.logger.info(
+                    f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds"
+                )
+            elif history_routed_experts is None and routed_experts is not None:
+                rollout_state.routed_experts = routed_experts
+            elif history_routed_experts is not None and routed_experts is None:
+                rollout_state.routed_experts = history_routed_experts
+        return rollout_state
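For orientation, a minimal sketch of how the relocated handler is driven (not part of the patch; the engine call is stubbed with hypothetical values, and in the patch it is the worker's `generate` path that calls these two hooks):

```python
from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.rl.rollout.utils import PartialRolloutHandler

handler = PartialRolloutHandler()


def continue_rollout(state: RolloutState, max_tokens: int) -> RolloutState:
    # preprocess concatenates prompt_ids + response_ids into state.tokens and
    # shrinks sample_params.max_tokens by the number of tokens already generated.
    state = handler.preprocess(state, max_tokens, enable_partial_rollout=True)

    # ... the inference engine would generate a continuation from state.tokens ...
    new_text, new_ids, new_logprobs = " continued", [42], [-0.1]  # hypothetical engine output

    # postprocess prepends the history response/response_ids/logprobs to the new chunk.
    return handler.postprocess(
        state,
        response=new_text,
        response_ids=new_ids,
        logprobs=new_logprobs,
        routed_experts=None,
        finish_reason="stop",
        status=Status.COMPLETED,
        enable_partial_rollout=True,
    )
```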
diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py
index dc101d71c..71e83d314 100644
--- a/xtuner/v1/rl/rollout/worker.py
+++ b/xtuner/v1/rl/rollout/worker.py
@@ -8,7 +8,7 @@
 import traceback
 from abc import abstractmethod
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union, cast
 
 import httpx
 import ray
@@ -30,6 +30,8 @@
 from xtuner.v1.utils import get_logger
 from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult
 
+from .utils import PartialRolloutHandler
+
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -472,6 +474,11 @@ def __init__(
         self.abort_timeout = 5.0
         self.dist_init_addr: str = ""
         self.serverl_url: str = ""
+        self.partial_rollout_handler = PartialRolloutHandler()
+        self.enable_partial_rollout: bool = False
+
+    def set_enable_partial_rollout(self, enable: bool) -> None:
+        self.enable_partial_rollout = enable
 
     def init(self, dist_init_addr: str) -> tuple[int, str]:
         """Initialize the worker and launch the server.
@@ -583,7 +590,8 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState:
         uid = rollout_state.uid
         sample_params: SampleParams = rollout_state.sample_params
-
+        max_tokens = sample_params.max_tokens
+        enable_partial_rollout = self.enable_partial_rollout
         if sample_params.return_token_ids:
             endpoint_url = f"{self.server_url}/{self.endpoints['generate']}"
         else:
@@ -594,8 +602,9 @@
             "Authorization": f"Bearer {self.config.api_key}",
         }
 
-        max_retries = self.config.max_retry_per_sample
+        rollout_state = self.partial_rollout_handler.preprocess(rollout_state, max_tokens, enable_partial_rollout)
         payload = self._get_request_payload(rollout_state)
+        max_retries = self.config.max_retry_per_sample
 
         # Early exit 1: check whether the sample is already marked as completed
         if rollout_state.status == Status.COMPLETED:
@@ -604,7 +613,7 @@
 
         # Early exit 2: check whether the input still needs generation (fetch variables safely)
         input_ids = payload.get("input_ids", [])
-        max_tokens = payload.get("max_tokens")
+        max_tokens = cast(int, payload.get("max_tokens"))
         last_id = input_ids[-1] if len(input_ids) > 0 else "None"
 
         is_max_tokens_zero = max_tokens is not None and max_tokens <= 0
@@ -830,6 +839,7 @@ async def _safe_handle_response(self, rollout_state: RolloutState, http_response
         sample_params = rollout_state.sample_params
         is_token_out = sample_params.return_token_ids
         response = http_response.json()
+
         if is_token_out:
             response_ids: list[int] = []
             logprobs: list[float] = []
@@ -914,30 +924,39 @@
                 rollout_state.error_msg = error_msg
                 return rollout_state
 
-            rollout_state.response = returned_response
-            rollout_state.response_ids = response_ids
-            rollout_state.logprobs = logprobs
-            rollout_state.routed_experts = routed_experts
-            rollout_state.finish_reason = finish_reason
-            rollout_state.status = rollout_status
+            rollout_state = self.partial_rollout_handler.postprocess(
+                rollout_state,
+                response=returned_response,
+                response_ids=response_ids,
+                logprobs=logprobs,
+                routed_experts=routed_experts,
+                finish_reason=finish_reason,
+                status=rollout_status,
+                enable_partial_rollout=self.enable_partial_rollout,
+            )
             return rollout_state
         except KeyError as e:
-            error_msg = f"Missing expected key {e} in response {response} for {uid}"
+            response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
+            error_msg = f"Missing expected key {e} in response {response_for_log} for {uid}"
             raise RuntimeError(error_msg)
         except IndexError as e:
-            error_msg = f"Index error {e} while processing response {response} for {uid}"
+            response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
+            error_msg = f"Index error {e} while processing response {response_for_log} for {uid}"
             raise RuntimeError(error_msg)
         except AssertionError as e:
-            error_msg = f"AssertionError: {e} when processing response {response} for {uid}"
+            response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
+            error_msg = f"AssertionError: {e} when processing response {response_for_log} for {uid}"
             raise RuntimeError(error_msg)
         except json.JSONDecodeError as e:
             error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}"
             raise RuntimeError(error_msg)
         except TypeError as e:
-            error_msg = f"TypeError: {e} when processing response {response} for {uid}"
+            response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
+            error_msg = f"TypeError: {e} when processing response {response_for_log} for {uid}"
             raise RuntimeError(error_msg)
         except Exception as e:
-            error_msg = f"Unexpected error: {e} when processing response {response} for {uid}\nTraceback: {traceback.format_exc()}"
+            response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
+            error_msg = f"Unexpected error: {e} when processing response {response_for_log} for {uid}\nTraceback: {traceback.format_exc()}"
             raise RuntimeError(error_msg)
         else:
             # v1/chat/completions API response
@@ -956,22 +975,27 @@
             rollout_state.status = rollout_status
             return rollout_state
         except KeyError as e:
-            error_msg = f"Missing expected key {e} in response {response} for {uid}"
+            response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
+            error_msg = f"Missing expected key {e} in response {response_for_log} for {uid}"
             raise RuntimeError(error_msg)
         except IndexError as e:
-            error_msg = f"Index error {e} while processing response {response} for {uid}"
+            response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
+            error_msg = f"Index error {e} while processing response {response_for_log} for {uid}"
             raise RuntimeError(error_msg)
         except AssertionError as e:
-            error_msg = f"AssertionError: {e} when processing response {response} for {uid}"
+            response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
+            error_msg = f"AssertionError: {e} when processing response {response_for_log} for {uid}"
             raise RuntimeError(error_msg)
         except json.JSONDecodeError as e:
             error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}"
             raise RuntimeError(error_msg)
         except TypeError as e:
-            error_msg = f"TypeError: {e} when processing response {response} for {uid}"
+            response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
+            error_msg = f"TypeError: {e} when processing response {response_for_log} for {uid}"
             raise RuntimeError(error_msg)
         except Exception as e:
-            error_msg = f"Unexpected error: {e} when processing response {response} for {uid}\nTraceback: {traceback.format_exc()}"
+            response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")}
+            error_msg = f"Unexpected error: {e} when processing response {response_for_log} for {uid}\nTraceback: {traceback.format_exc()}"
             raise RuntimeError(error_msg)
 
     def _adapt_input_to_openai_spec(self, prompts, tools, tool_choice):
diff --git a/xtuner/v1/rl/trainer/controller.py b/xtuner/v1/rl/trainer/controller.py
index 3ce836bef..0c13e8fdb 100644
--- a/xtuner/v1/rl/trainer/controller.py
+++ b/xtuner/v1/rl/trainer/controller.py
@@ -264,7 +264,7 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx:
                 if data["seq_ctx"].pixel_values is not None:
                     free_pixel_value_refs.extend(data["seq_ctx"].pixel_values)
         # if len(free_pixel_value_refs) > 0:
-        #    free_object_refs(free_pixel_value_refs)
+        #     free_object_refs(free_pixel_value_refs)
 
         del packed_data_batches
         return log_infos