diff --git a/.gitignore b/.gitignore index 80be2dd..adb4551 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,8 @@ /venv/ /run/ +/build/ __pycache__/ +*.egg-info/ +CLAUDE.md diff --git a/src/mgds/pipelineModules/EncodeLensText.py b/src/mgds/pipelineModules/EncodeLensText.py new file mode 100644 index 0000000..61ab992 --- /dev/null +++ b/src/mgds/pipelineModules/EncodeLensText.py @@ -0,0 +1,77 @@ +from contextlib import nullcontext + +import torch +from lens.text_encoder import LensGptOssEncoder +from mgds.PipelineModule import PipelineModule +from mgds.pipelineModuleTypes.RandomAccessPipelineModule import RandomAccessPipelineModule + + +# This module is Lens-specific: it relies on LensGptOssEncoder.encode_layers() rather than the +# standard output_hidden_states=True API. The standard API cannot be used because transformers' +# @capture_outputs applies tie_last_hidden_states, which norms hidden_states[-1] — corrupting +# the last selected layer (GPT-OSS layer 23 is the final layer in the model). +# text_encoder.set_selected_layers() must have been called before this module is used; +# LensModelLoader does this immediately after loading the text encoder. +class EncodeLensText( + PipelineModule, + RandomAccessPipelineModule, +): + def __init__( + self, + tokens_name: str, + tokens_attention_mask_in_name: str | None, + hidden_state_out_name: str, + tokens_attention_mask_out_name: str | None, + text_encoder: LensGptOssEncoder, + crop_start: int | None = None, + autocast_contexts: list[torch.autocast | None] = None, + dtype: torch.dtype | None = None, + ): + super(EncodeLensText, self).__init__() + self.tokens_name = tokens_name + self.tokens_attention_mask_in_name = tokens_attention_mask_in_name + self.hidden_state_out_name = hidden_state_out_name + self.tokens_attention_mask_out_name = tokens_attention_mask_out_name + self.text_encoder = text_encoder + self.crop_start = crop_start + + self.autocast_contexts = [nullcontext()] if autocast_contexts is None else autocast_contexts + self.dtype = dtype + + def length(self) -> int: + return self._get_previous_length(self.tokens_name) + + def get_inputs(self) -> list[str]: + return [self.tokens_name, self.tokens_attention_mask_in_name] + + def get_outputs(self) -> list[str]: + return [self.tokens_name, self.hidden_state_out_name, self.tokens_attention_mask_out_name] + + def get_item(self, variation: int, index: int, requested_name: str = None) -> dict: + tokens = self._get_previous_item(variation, self.tokens_name, index) + tokens = tokens.unsqueeze(0) + + if self.tokens_attention_mask_in_name is not None: + tokens_attention_mask = self._get_previous_item(variation, self.tokens_attention_mask_in_name, index) + tokens_attention_mask = tokens_attention_mask.unsqueeze(0) + else: + tokens_attention_mask = None + + with self._all_contexts(self.autocast_contexts): + layer_outputs = self.text_encoder.encode_layers(tokens, tokens_attention_mask) + + hidden_state = torch.cat(layer_outputs, dim=-1) + tokens = tokens.squeeze(dim=0) + hidden_state = hidden_state.squeeze(dim=0) + tokens_attention_mask = tokens_attention_mask.squeeze(dim=0) + + if self.crop_start is not None: + tokens = tokens[self.crop_start:] + tokens_attention_mask = tokens_attention_mask[self.crop_start:] + hidden_state = hidden_state[self.crop_start:] + + return { + self.tokens_name: tokens, + self.hidden_state_out_name: hidden_state, + self.tokens_attention_mask_out_name: tokens_attention_mask, + } diff --git a/src/mgds/pipelineModules/Tokenize.py b/src/mgds/pipelineModules/Tokenize.py index b3ff90c..d800a83 100644 --- a/src/mgds/pipelineModules/Tokenize.py +++ b/src/mgds/pipelineModules/Tokenize.py @@ -21,6 +21,7 @@ def __init__( additional_format_text_tokens: int | None = None, apply_chat_template: Callable | None = None, apply_chat_template_kwargs = {}, + apply_chat_template_post_process: Callable[[str], str] | None = None, expand_mask: int = 0, ): super(Tokenize, self).__init__() @@ -32,6 +33,7 @@ def __init__( self.format_text = format_text self.apply_chat_template = apply_chat_template self.apply_chat_template_kwargs = apply_chat_template_kwargs + self.apply_chat_template_post_process = apply_chat_template_post_process self.additional_format_text_tokens = additional_format_text_tokens self.expand_mask = expand_mask @@ -60,6 +62,8 @@ def get_item(self, variation: int, index: int, requested_name: str = None) -> di tokenize=False, **self.apply_chat_template_kwargs, ) + if self.apply_chat_template_post_process is not None: + text = self.apply_chat_template_post_process(text) tokenizer_output = self.tokenizer( text,