diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py index 8596ae7e7..e936e23ec 100644 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ b/xtuner/v1/ray/dataflow/replay_buffer.py @@ -84,6 +84,8 @@ def determine_group_state(group_data_items: List[RLDataFlowItem]) -> RolloutStat return RolloutState.SKIPPED elif RolloutState.FAILED in group_states: return RolloutState.FAILED + elif RolloutState.EXPIRED in group_states: + return RolloutState.EXPIRED elif RolloutState.ABORTED in group_states: return RolloutState.ABORTED elif all(state == RolloutState.COMPLETED for state in group_states): @@ -391,6 +393,34 @@ def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: R ray.internal.free(old_obs_refs, local_only=False) replay_meta.observation_refs = [ray.put(RLEnvDataItem()) for _ in replay_meta.observation_ids] self._update_replay_meta_state(replay_meta, new_state) + if new_state == RolloutState.EXPIRED and self.tail_batch_trigger_size <= 0: + self._clear_multimodal_objectrefs(replay_meta) + + def _clear_multimodal_objectrefs(self, replay_meta: ReplayMeta): + if replay_meta.action_ref is None: + return + + data_item = ray.get(replay_meta.action_ref) + multimodal_info = getattr(data_item, "multimodal_train_info", None) + if not multimodal_info: + return + + refs_to_free: List[ObjectRef] = [] + changed = False + for key, value in list(multimodal_info.items()): + if isinstance(value, ObjectRef): + refs_to_free.append(value) + multimodal_info[key] = None + changed = True + + if not changed: + return + + old_action_ref = replay_meta.action_ref + replay_meta.action_ref = ray.put(data_item) + if isinstance(old_action_ref, ObjectRef): + refs_to_free.append(old_action_ref) + free_object_refs(refs_to_free) def add(self, grouped_dataitem: List[RLDataFlowItem]): """Adds a group of data items to the storage. @@ -848,6 +878,7 @@ def _clear_meta_for_actions(self, replay_meta: ReplayMeta): This is the single source of truth for deleting an action. """ action_id = replay_meta.action_id + root_id = replay_meta.root_id self._release_replay_meta_refs(replay_meta) @@ -859,6 +890,12 @@ def _clear_meta_for_actions(self, replay_meta: ReplayMeta): self._actions.pop(action_id, None) self._action2observations.pop(action_id, None) + if root_id in self._root2actions: + self._root2actions[root_id] = [ + stored_action_id for stored_action_id in self._root2actions[root_id] if stored_action_id != action_id + ] + if not self._root2actions[root_id]: + del self._root2actions[root_id] del replay_meta def _clear_meta_for_root(self, replay_meta: ReplayMeta):