Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,8 @@

/venv/
/run/
/build/

__pycache__/
*.egg-info/
CLAUDE.md
77 changes: 77 additions & 0 deletions src/mgds/pipelineModules/EncodeLensText.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from contextlib import nullcontext

import torch
from lens.text_encoder import LensGptOssEncoder
from mgds.PipelineModule import PipelineModule
from mgds.pipelineModuleTypes.RandomAccessPipelineModule import RandomAccessPipelineModule


# This module is Lens-specific: it relies on LensGptOssEncoder.encode_layers() rather than the
# standard output_hidden_states=True API. The standard API cannot be used because transformers'
# @capture_outputs applies tie_last_hidden_states, which applies the model's final norm to
# hidden_states[-1], corrupting the last selected layer (GPT-OSS layer 23 is the final layer in the model).
# text_encoder.set_selected_layers() must have been called before this module is used;
# LensModelLoader does this immediately after loading the text encoder.
class EncodeLensText(
    PipelineModule,
    RandomAccessPipelineModule,
):
    """Pipeline module that encodes tokens with a Lens GPT-OSS text encoder.

    For every item, the tokens (and, if configured, their attention mask) are
    fetched from the preceding pipeline modules and passed through
    ``text_encoder.encode_layers()``. The per-layer outputs it returns are
    concatenated along the last dimension to form the emitted hidden state.
    """

    def __init__(
        self,
        tokens_name: str,
        tokens_attention_mask_in_name: str | None,
        hidden_state_out_name: str,
        tokens_attention_mask_out_name: str | None,
        text_encoder: LensGptOssEncoder,
        crop_start: int | None = None,
        autocast_contexts: list[torch.autocast | None] | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Store the configuration of the module.

        Args:
            tokens_name: Name under which the tokens are read from the previous
                modules; the (possibly cropped) tokens are re-emitted under the
                same name.
            tokens_attention_mask_in_name: Name of the attention mask to read,
                or ``None`` to call the encoder without a mask.
            hidden_state_out_name: Name under which the concatenated hidden
                state is emitted.
            tokens_attention_mask_out_name: Name under which the attention mask
                is emitted, or ``None`` to not emit it.
            text_encoder: Encoder providing ``encode_layers(tokens, mask)``.
            crop_start: If not ``None``, tokens, mask and hidden state are
                sliced as ``[crop_start:]`` along their first dimension after
                the batch dimension has been removed.
            autocast_contexts: Context managers entered around the encoder
                call. Defaults to a single ``nullcontext()``.
            dtype: Stored on the instance; not read anywhere else in this class.
        """
        super().__init__()
        self.tokens_name = tokens_name
        self.tokens_attention_mask_in_name = tokens_attention_mask_in_name
        self.hidden_state_out_name = hidden_state_out_name
        self.tokens_attention_mask_out_name = tokens_attention_mask_out_name
        self.text_encoder = text_encoder
        self.crop_start = crop_start

        self.autocast_contexts = [nullcontext()] if autocast_contexts is None else autocast_contexts
        self.dtype = dtype

    def length(self) -> int:
        """Return the number of items, which equals the length of the token input."""
        return self._get_previous_length(self.tokens_name)

    def get_inputs(self) -> list[str]:
        """Return the names this module reads, skipping the mask name when it is unset."""
        names = [self.tokens_name, self.tokens_attention_mask_in_name]
        return [name for name in names if name is not None]

    def get_outputs(self) -> list[str]:
        """Return the names this module emits, skipping the mask name when it is unset."""
        names = [self.tokens_name, self.hidden_state_out_name, self.tokens_attention_mask_out_name]
        return [name for name in names if name is not None]

    def get_item(self, variation: int, index: int, requested_name: str | None = None) -> dict:
        """Encode one item and return its tokens, hidden state and attention mask.

        Args:
            variation: Variation index forwarded to the previous modules.
            index: Index of the item to encode.
            requested_name: Not used by this module.

        Returns:
            A dict holding the tokens under ``tokens_name`` and the hidden
            state under ``hidden_state_out_name``. When
            ``tokens_attention_mask_out_name`` is set, the mask is included
            under that name; its value is ``None`` if no input mask was
            configured.
        """
        tokens = self._get_previous_item(variation, self.tokens_name, index)
        # The encoder is called with a leading batch dimension of size one.
        tokens = tokens.unsqueeze(0)

        tokens_attention_mask = None
        if self.tokens_attention_mask_in_name is not None:
            tokens_attention_mask = self._get_previous_item(variation, self.tokens_attention_mask_in_name, index)
            tokens_attention_mask = tokens_attention_mask.unsqueeze(0)

        with self._all_contexts(self.autocast_contexts):
            layer_outputs = self.text_encoder.encode_layers(tokens, tokens_attention_mask)

        # Join the per-layer outputs along the feature dimension, then drop the batch dimension again.
        hidden_state = torch.cat(layer_outputs, dim=-1)
        tokens = tokens.squeeze(dim=0)
        hidden_state = hidden_state.squeeze(dim=0)
        # The mask is optional; only touch it when one was actually loaded.
        if tokens_attention_mask is not None:
            tokens_attention_mask = tokens_attention_mask.squeeze(dim=0)

        if self.crop_start is not None:
            tokens = tokens[self.crop_start:]
            hidden_state = hidden_state[self.crop_start:]
            if tokens_attention_mask is not None:
                tokens_attention_mask = tokens_attention_mask[self.crop_start:]

        item = {
            self.tokens_name: tokens,
            self.hidden_state_out_name: hidden_state,
        }
        # Avoid emitting an entry keyed by None when no mask output name is configured.
        if self.tokens_attention_mask_out_name is not None:
            item[self.tokens_attention_mask_out_name] = tokens_attention_mask
        return item
4 changes: 4 additions & 0 deletions src/mgds/pipelineModules/Tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(
additional_format_text_tokens: int | None = None,
apply_chat_template: Callable | None = None,
apply_chat_template_kwargs = {},
apply_chat_template_post_process: Callable[[str], str] | None = None,
expand_mask: int = 0,
):
super(Tokenize, self).__init__()
Expand All @@ -32,6 +33,7 @@ def __init__(
self.format_text = format_text
self.apply_chat_template = apply_chat_template
self.apply_chat_template_kwargs = apply_chat_template_kwargs
self.apply_chat_template_post_process = apply_chat_template_post_process
self.additional_format_text_tokens = additional_format_text_tokens
self.expand_mask = expand_mask

Expand Down Expand Up @@ -60,6 +62,8 @@ def get_item(self, variation: int, index: int, requested_name: str = None) -> di
tokenize=False,
**self.apply_chat_template_kwargs,
)
if self.apply_chat_template_post_process is not None:
text = self.apply_chat_template_post_process(text)

tokenizer_output = self.tokenizer(
text,
Expand Down