Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 141 additions & 16 deletions comfy/ldm/ideogram4/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,27 +226,85 @@ def _tokens_to_img(self, tokens, gh, gw):
x = x.permute(0, 5, 3, 4, 1, 2) # (B, c, pi, pj, gh, gw)
return x.reshape(B, C, gh, gw)

# Build one (t, h, w) position-id row per cell of a gh x gw grid, stack the three
# coordinates on dim=1 and shift the result by IMAGE_POSITION_OFFSET.
# NOTE(review): this span is an unmarked diff hunk with indentation lost. The first
# `def` and the three assignments under it look like the pre-change version; the
# second `def` and its three assignments look like the post-change version. Both
# share the final `return` -- confirm against the real file before applying.
def _image_position_ids(self, gh, gw, device):
# Row index of every cell, flattened row-major to length gh*gw.
h_idx = torch.arange(gh, device=device).view(-1, 1).expand(gh, gw).reshape(-1)
# Column index of every cell, flattened the same way.
w_idx = torch.arange(gw, device=device).view(1, -1).expand(gh, gw).reshape(-1)
# First coordinate is always 0 in this variant.
t_idx = torch.zeros_like(h_idx)
# Variant with three extra keyword parameters: `index` fills the first coordinate and
# h_offset / w_offset are added to the row / column indices. With the defaults
# (0, 0, 0) the produced values match the variant above.
def _image_position_ids(self, gh, gw, device, index=0, h_offset=0, w_offset=0):
h_idx = torch.arange(gh, device=device).view(-1, 1).expand(gh, gw).reshape(-1) + h_offset
w_idx = torch.arange(gw, device=device).view(1, -1).expand(gh, gw).reshape(-1) + w_offset
# Same shape/dtype/device as h_idx, every entry equal to `index`.
t_idx = torch.full_like(h_idx, index)
return torch.stack([t_idx, h_idx, w_idx], dim=1) + IMAGE_POSITION_OFFSET # (L_img, 3)

def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options):
def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options, ref_latents=None, ref_latents_method="index"):
B = x_chunk.shape[0]
device = x_chunk.device
img_tokens = self._img_to_tokens(x_chunk)
L_img = img_tokens.shape[1]
L_text = context_chunk.shape[1]
L = L_text + L_img
latent_dim = img_tokens.shape[-1]

ref_tokens_list = []
ref_pos_ids_list = []
ref_num_tokens = []

if ref_latents is not None:
h = 0
w = 0
index = 0
index_ref_method = (ref_latents_method == "index") or (ref_latents_method == "index_timestep_zero")
negative_ref_method = ref_latents_method == "negative_index"

for ref in ref_latents:
ref_b, ref_c, ref_h, ref_w = ref.shape
ref_gh = ref_h
ref_gw = ref_w

if index_ref_method:
index += 1
gh_offset = 0
gw_offset = 0
elif negative_ref_method:
index -= 1
gh_offset = 0
gw_offset = 0
else: # offset/default
index = 1
gh_offset = 0
gw_offset = 0
if ref_gh + h > ref_gw + w:
gw_offset = w
else:
gh_offset = h
h = max(h, ref_gh + gh_offset)
w = max(w, ref_gw + gw_offset)

ref_tokens = self._img_to_tokens(ref)
ref_tokens_list.append(ref_tokens)
ref_num_tokens.append(ref_tokens.shape[1])

ref_pos = self._image_position_ids(ref_gh, ref_gw, device, index=index, h_offset=gh_offset, w_offset=gw_offset)
ref_pos_ids_list.append(ref_pos)

transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens

L_ref = sum(t.shape[1] for t in ref_tokens_list) if ref_tokens_list else 0
L = L_text + L_img + L_ref

x_full = torch.zeros(B, L, latent_dim, dtype=img_tokens.dtype, device=device)
x_full[:, L_text:] = img_tokens
x_full[:, L_text:L_text+L_img] = img_tokens

curr_idx = L_text + L_img
for ref_tokens in ref_tokens_list:
ref_len = ref_tokens.shape[1]
x_full[:, curr_idx:curr_idx+ref_len] = ref_tokens
curr_idx += ref_len

text_pos = torch.arange(L_text, device=device).view(-1, 1).expand(L_text, 3)
img_pos = self._image_position_ids(gh, gw, device)
position_ids = torch.cat([text_pos, img_pos], dim=0).unsqueeze(0).expand(B, L, 3)

pos_ids_all = [text_pos, img_pos]
for ref_pos in ref_pos_ids_list:
pos_ids_all.append(ref_pos)

position_ids = torch.cat(pos_ids_all, dim=0).unsqueeze(0).expand(B, L, 3)

indicator = torch.empty(B, L, dtype=torch.long, device=device)
indicator[:, :L_text] = LLM_TOKEN_INDICATOR
Expand All @@ -263,20 +321,84 @@ def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh,

out = self._backbone(context_chunk, x_full, t_chunk, position_ids, attn_mask, indicator,
transformer_options=transformer_options)
return self._tokens_to_img(out[:, L_text:], gh, gw)
return self._tokens_to_img(out[:, L_text:L_text+L_img], gh, gw)

# Run the backbone on image tokens only (context argument is None, mask argument is None)
# and convert the output tokens back to an image with _tokens_to_img.
# NOTE(review): this span is an unmarked diff hunk with indentation lost. Adjacent
# near-duplicate lines (the two `def` lines, the two position_ids/indicator pairs, the
# two `_backbone` calls and the two `return`s) look like before/after pairs of the
# change that adds reference-latent support -- confirm against the real file.
def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options):
# Signature variant with reference latents: `ref_latents` is iterated as a sequence of
# 4-D tensors (each is unpacked into four dims below); `ref_latents_method` selects how
# their position ids are assigned.
def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options, ref_latents=None, ref_latents_method="index"):
B = x_chunk.shape[0]
device = x_chunk.device
img_tokens = self._img_to_tokens(x_chunk)
L_img = img_tokens.shape[1]
latent_dim = img_tokens.shape[-1]

# Variant without references: position ids and indicator cover exactly L_img tokens.
position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3)
indicator = torch.full((B, L_img), OUTPUT_IMAGE_INDICATOR, dtype=torch.long, device=device)
# Per-reference accumulators: tokens, position ids, and token counts, in input order.
ref_tokens_list = []
ref_pos_ids_list = []
ref_num_tokens = []

if ref_latents is not None:
# h / w: running extents used only by the offset (default) placement method.
h = 0
w = 0
# index: running value for the first position-id coordinate of each reference.
index = 0
# "index" and "index_timestep_zero" take the same branch in this method.
# NOTE(review): nothing visible here treats "index_timestep_zero" differently
# (t_chunk is passed through unchanged) -- confirm that is intended.
index_ref_method = (ref_latents_method == "index") or (ref_latents_method == "index_timestep_zero")
negative_ref_method = ref_latents_method == "negative_index"

for ref in ref_latents:
ref_b, ref_c, ref_h, ref_w = ref.shape
# The reference's last two dims are used directly as its grid height / width.
ref_gh = ref_h
ref_gw = ref_w

if index_ref_method:
# References get index 1, 2, 3, ... with no spatial offset.
index += 1
gh_offset = 0
gw_offset = 0
elif negative_ref_method:
# References get index -1, -2, -3, ... with no spatial offset.
index -= 1
gh_offset = 0
gw_offset = 0
else: # offset/default
# Every reference shares index 1; each one is shifted along w or along h
# depending on which of (ref_gh + h) and (ref_gw + w) is larger.
index = 1
gh_offset = 0
gw_offset = 0
if ref_gh + h > ref_gw + w:
gw_offset = w
else:
gh_offset = h
# Grow the running extents to include the reference just placed.
h = max(h, ref_gh + gh_offset)
w = max(w, ref_gw + gw_offset)

ref_tokens = self._img_to_tokens(ref)
ref_tokens_list.append(ref_tokens)
ref_num_tokens.append(ref_tokens.shape[1])

ref_pos = self._image_position_ids(ref_gh, ref_gw, device, index=index, h_offset=gh_offset, w_offset=gw_offset)
ref_pos_ids_list.append(ref_pos)

# Copy before writing so the new key is not added to the object the caller passed in.
# NOTE(review): presumably this sits inside the `if ref_latents is not None` branch;
# nesting cannot be read from this view -- confirm.
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens

# Total reference token count (0 when there are no references).
L_ref = sum(t.shape[1] for t in ref_tokens_list) if ref_tokens_list else 0
L_img_total = L_img + L_ref

# Sequence layout: [image tokens | ref 0 tokens | ref 1 tokens | ...].
x_full = torch.zeros(B, L_img_total, latent_dim, dtype=img_tokens.dtype, device=device)
x_full[:, :L_img] = img_tokens

curr_idx = L_img
for ref_tokens in ref_tokens_list:
ref_len = ref_tokens.shape[1]
# assumes ref_tokens' batch dim is 1 or B so the slice assignment broadcasts -- TODO confirm
x_full[:, curr_idx:curr_idx+ref_len] = ref_tokens
curr_idx += ref_len

img_pos = self._image_position_ids(gh, gw, device)

# Position ids follow the same order as x_full: image first, then each reference.
pos_ids_all = [img_pos]
for ref_pos in ref_pos_ids_list:
pos_ids_all.append(ref_pos)

position_ids = torch.cat(pos_ids_all, dim=0).unsqueeze(0).expand(B, L_img_total, 3)
# NOTE(review): reference tokens receive OUTPUT_IMAGE_INDICATOR, the same value as the
# image tokens being generated -- confirm a distinct reference indicator is not expected.
indicator = torch.full((B, L_img_total), OUTPUT_IMAGE_INDICATOR, dtype=torch.long, device=device)

# Image-only sequence is a single segment -> no mask, full attention, no LLM context.
# Variant without references: backbone sees img_tokens and the whole output is returned.
out = self._backbone(None, img_tokens, t_chunk, position_ids, None, indicator, transformer_options=transformer_options)
return self._tokens_to_img(out, gh, gw)
# Variant with references: backbone sees x_full; only the first L_img output tokens
# (the image positions) are converted back, so reference positions are dropped.
out = self._backbone(None, x_full, t_chunk, position_ids, None, indicator, transformer_options=transformer_options)
return self._tokens_to_img(out[:, :L_img], gh, gw)

def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
Expand All @@ -290,8 +412,11 @@ def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_

timesteps = 1.0 - timesteps

ref_latents = kwargs.get("ref_latents", None)
ref_latents_method = kwargs.get("ref_latents_method", "index")

# unconditional pass
if context is None:
return -self._run_image_only(x, timesteps, gh, gw, transformer_options)
return -self._run_image_only(x, timesteps, gh, gw, transformer_options, ref_latents=ref_latents, ref_latents_method=ref_latents_method)

return -self._run_conditional(x, context, attention_mask, timesteps, gh, gw, transformer_options)
return -self._run_conditional(x, context, attention_mask, timesteps, gh, gw, transformer_options, ref_latents=ref_latents, ref_latents_method=ref_latents_method)
19 changes: 19 additions & 0 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2266,6 +2266,7 @@ def extra_conds_shapes(self, **kwargs):
class Ideogram4(BaseModel):
# Initialise the Ideogram4 model wrapper: forwards model_config, model_type (default
# ModelType.FLOW) and device to the parent initializer, selecting
# comfy.ldm.ideogram4.model.Ideogram4Transformer2DModel as the unet_model.
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ideogram4.model.Ideogram4Transformer2DModel)
# Names the "ref_latents" cond, the same key that extra_conds emits and
# extra_conds_shapes reports a shape for.
# NOTE(review): presumably read by the base class when estimating memory use --
# confirm in BaseModel; the consumer is not visible here.
self.memory_usage_factor_conds = ("ref_latents",)

def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
Expand All @@ -2276,6 +2277,24 @@ def extra_conds(self, **kwargs):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)

ref_latents_method = kwargs.get("reference_latents_method", None)
if ref_latents_method is not None:
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
return out

def extra_conds_shapes(self, **kwargs):
    """Report the shape entry for the reference-latent cond, if any.

    Reads ``reference_latents`` from ``kwargs``. When it is present, returns
    ``{'ref_latents': [1, 128, total]}`` where ``total`` is the sum, over all
    references, of the product of every dimension after the first two
    (``size()[2:]``). Returns an empty dict when no references were passed.
    """
    ref_latents = kwargs.get("reference_latents", None)
    if ref_latents is None:
        return {}

    # A generator expression and a plain list literal replace the former
    # map/lambda and the redundant list([...]) wrapper; the value is unchanged.
    total = sum(math.prod(lat.size()[2:]) for lat in ref_latents)
    # NOTE(review): 128 is hard-coded; presumably the latent channel count of
    # this model -- confirm against the model config.
    return {'ref_latents': [1, 128, total]}

class HunyuanImage21(BaseModel):
Expand Down
Loading