From cc3c1d8fd2890bf8e0b4f5aa26c2e21430d9a620 Mon Sep 17 00:00:00 2001 From: Faux Date: Tue, 31 Mar 2026 22:12:48 +0200 Subject: [PATCH] support token chunking for Clip --- src/mgds/pipelineModules/EncodeClipText.py | 148 +++++++++++++++++++-- 1 file changed, 140 insertions(+), 8 deletions(-) diff --git a/src/mgds/pipelineModules/EncodeClipText.py b/src/mgds/pipelineModules/EncodeClipText.py index 213226e..ab3a5af 100644 --- a/src/mgds/pipelineModules/EncodeClipText.py +++ b/src/mgds/pipelineModules/EncodeClipText.py @@ -1,11 +1,13 @@ from contextlib import nullcontext - -import torch -from transformers import CLIPTextModel, CLIPTextModelWithProjection +from typing import Any from mgds.PipelineModule import PipelineModule from mgds.pipelineModuleTypes.RandomAccessPipelineModule import RandomAccessPipelineModule +import torch + +from transformers import CLIPTextModel, CLIPTextModelWithProjection + class EncodeClipText( PipelineModule, @@ -22,19 +24,31 @@ def __init__( hidden_state_output_index: int | None = None, autocast_contexts: list[torch.autocast | None] = None, dtype: torch.dtype | None = None, + chunk_if_needed: bool = False, + pooled_output_handling: str = 'FIRST', + tokenizer: Any | None = None, + split_on_comma: bool = False, + max_chunks: int | None = None, ): - super(EncodeClipText, self).__init__() + super().__init__() self.in_name = in_name self.tokens_attention_mask_in_name = tokens_attention_mask_in_name self.hidden_state_out_name = hidden_state_out_name self.pooled_out_name = pooled_out_name self.text_encoder = text_encoder self.add_layer_norm = add_layer_norm - self.hidden_state_output_index = hidden_state_output_index + self.hidden_state_output_index = -1 if hidden_state_output_index is None else hidden_state_output_index self.autocast_contexts = [nullcontext()] if autocast_contexts is None else autocast_contexts self.dtype = dtype + self.chunk_if_needed = chunk_if_needed + self.chunk_size = text_encoder.config.max_position_embeddings - 2 + self.pooled_output_handling = pooled_output_handling + self.tokenizer = tokenizer + self.split_on_comma = split_on_comma + self.max_chunks = max_chunks + def length(self) -> int: return self._get_previous_length(self.in_name) @@ -49,14 +63,22 @@ def get_outputs(self) -> list[str]: def get_item(self, variation: int, index: int, requested_name: str = None) -> dict: tokens = self._get_previous_item(variation, self.in_name, index) - tokens = tokens.unsqueeze(0) if self.tokens_attention_mask_in_name is not None: tokens_attention_mask = self._get_previous_item(variation, self.tokens_attention_mask_in_name, index) - tokens_attention_mask = tokens_attention_mask.unsqueeze(0) else: tokens_attention_mask = None + if self.chunk_if_needed and tokens.shape[0] > self.chunk_size + 2: + return self._get_item_chunked(tokens, tokens_attention_mask, self.chunk_size) + else: + return self._get_item_single(tokens, tokens_attention_mask) + + def _get_item_single(self, tokens: torch.Tensor, tokens_attention_mask: torch.Tensor | None) -> dict: + tokens = tokens.unsqueeze(0) + if tokens_attention_mask is not None: + tokens_attention_mask = tokens_attention_mask.unsqueeze(0) + with self._all_contexts(self.autocast_contexts): if tokens_attention_mask is not None and self.dtype: tokens_attention_mask = tokens_attention_mask.to(dtype=self.dtype) @@ -70,11 +92,11 @@ def get_item(self, variation: int, index: int, requested_name: str = None) -> di hidden_states = text_encoder_output.hidden_states if self.pooled_out_name: + pooled_state = None if hasattr(text_encoder_output, "text_embeds"): pooled_state = text_encoder_output.text_embeds if hasattr(text_encoder_output, "pooler_output"): pooled_state = text_encoder_output.pooler_output - else: pooled_state = None @@ -94,3 +116,113 @@ def get_item(self, variation: int, index: int, requested_name: str = None) -> di self.hidden_state_out_name: hidden_state, self.pooled_out_name: pooled_state, } + + def _get_item_chunked(self, tokens: torch.Tensor, tokens_attention_mask: torch.Tensor | None, chunk_size: int) -> dict: + def get_splits(tokens_row, mask_row, chunk_size, comma_id=None): + tokens_splits = [] + mask_splits = [] + start = 0 + while start < len(tokens_row): + if len(tokens_row) - start <= chunk_size: + tokens_splits.append(tokens_row[start:]) + if mask_row is not None: + mask_splits.append(mask_row[start:]) + break + end = start + chunk_size + if comma_id is not None: + found_comma = -1 + for i in range(end - 1, start - 1, -1): + if tokens_row[i] == comma_id: + found_comma = i + break + if found_comma != -1: + end = found_comma + 1 + tokens_splits.append(tokens_row[start:end]) + if mask_row is not None: + mask_splits.append(mask_row[start:end]) + start = end + return tokens_splits, mask_splits + + # split tokens into chunks of chunk_size tokens, and add BOS/EOS to each + bos_token = tokens[0] + eos_token = tokens[-1] + + # remove BOS and EOS + tokens = tokens[1:-1] + if tokens_attention_mask is not None: + tokens_attention_mask = tokens_attention_mask[1:-1] + + comma_id = None + if self.split_on_comma and self.tokenizer: + comma_id = self.tokenizer.convert_tokens_to_ids(",") + if isinstance(comma_id, list): + comma_id = comma_id[0] + + input_id_chunks, attention_mask_chunks = get_splits(tokens, tokens_attention_mask, chunk_size, comma_id) + + if self.max_chunks is not None: + input_id_chunks = input_id_chunks[:self.max_chunks] + attention_mask_chunks = attention_mask_chunks[:self.max_chunks] + + # pad each chunk with EOS + new_input_id_chunks = [] + new_attention_mask_chunks = [] + for i in range(len(input_id_chunks)): + chunk = input_id_chunks[i] + padding_len = chunk_size - len(chunk) + + if padding_len > 0: + chunk = torch.cat([ + chunk, + torch.full((padding_len,), eos_token, dtype=tokens.dtype, device=tokens.device) + ]) + if tokens_attention_mask is not None: + mask_chunk = attention_mask_chunks[i] + mask_chunk = torch.cat([ + mask_chunk, + torch.full((padding_len,), 0, dtype=tokens_attention_mask.dtype, device=tokens_attention_mask.device) + ]) + new_attention_mask_chunks.append(mask_chunk) + else: + if tokens_attention_mask is not None: + new_attention_mask_chunks.append(attention_mask_chunks[i]) + + new_input_id_chunks.append(chunk) + + input_id_chunks = new_input_id_chunks + if tokens_attention_mask is not None: + attention_mask_chunks = new_attention_mask_chunks + else: + attention_mask_chunks = [None] * len(input_id_chunks) + + # add BOS and EOS to each chunk + input_id_chunks = [torch.cat([bos_token.unsqueeze(0), chunk, eos_token.unsqueeze(0)]) for chunk in input_id_chunks] + if tokens_attention_mask is not None: + attention_mask_chunks = [torch.cat([torch.ones(1, dtype=tokens_attention_mask.dtype, device=tokens_attention_mask.device), chunk, torch.ones(1, dtype=tokens_attention_mask.dtype, device=tokens_attention_mask.device)]) for chunk in attention_mask_chunks] + + hidden_states = [] + pooled_states = [] + + for chunk_tokens, chunk_attention_mask in zip(input_id_chunks, attention_mask_chunks, strict=True): + res = self._get_item_single(chunk_tokens, chunk_attention_mask) + hidden_states.append(res[self.hidden_state_out_name]) + if self.pooled_out_name: + pooled_states.append(res[self.pooled_out_name]) + + hidden_state = torch.cat(hidden_states, dim=0) + + if self.pooled_out_name: + handling = str(self.pooled_output_handling) + if handling == 'FIRST': + pooled_state = pooled_states[0] + elif handling == 'AVERAGE': + pooled_state = torch.mean(torch.stack(pooled_states), dim=0) + else: + pooled_state = pooled_states[0] + else: + pooled_state = None + + return { + self.hidden_state_out_name: hidden_state, + self.pooled_out_name: pooled_state, + }