From b45f5de683a0f029343860062cc5da9fd93a4b11 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Sat, 20 Jun 2026 17:58:09 +0200 Subject: [PATCH 1/2] Add reference latent capabilities to Ideogram 4 --- comfy/ldm/ideogram4/model.py | 157 +++++++++++++++++++++++++++++++---- comfy/model_base.py | 19 +++++ 2 files changed, 160 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/ideogram4/model.py b/comfy/ldm/ideogram4/model.py index 4ea5b8aafbb2..c6d477e83595 100644 --- a/comfy/ldm/ideogram4/model.py +++ b/comfy/ldm/ideogram4/model.py @@ -226,27 +226,85 @@ def _tokens_to_img(self, tokens, gh, gw): x = x.permute(0, 5, 3, 4, 1, 2) # (B, c, pi, pj, gh, gw) return x.reshape(B, C, gh, gw) - def _image_position_ids(self, gh, gw, device): - h_idx = torch.arange(gh, device=device).view(-1, 1).expand(gh, gw).reshape(-1) - w_idx = torch.arange(gw, device=device).view(1, -1).expand(gh, gw).reshape(-1) - t_idx = torch.zeros_like(h_idx) + def _image_position_ids(self, gh, gw, device, index=0, h_offset=0, w_offset=0): + h_idx = torch.arange(gh, device=device).view(-1, 1).expand(gh, gw).reshape(-1) + h_offset + w_idx = torch.arange(gw, device=device).view(1, -1).expand(gh, gw).reshape(-1) + w_offset + t_idx = torch.full_like(h_idx, index) return torch.stack([t_idx, h_idx, w_idx], dim=1) + IMAGE_POSITION_OFFSET # (L_img, 3) - def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options): + def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options, ref_latents=None, ref_latents_method="index"): B = x_chunk.shape[0] device = x_chunk.device img_tokens = self._img_to_tokens(x_chunk) L_img = img_tokens.shape[1] L_text = context_chunk.shape[1] - L = L_text + L_img latent_dim = img_tokens.shape[-1] + ref_tokens_list = [] + ref_pos_ids_list = [] + ref_num_tokens = [] + + if ref_latents is not None: + h = 0 + w = 0 + index = 0 + index_ref_method = (ref_latents_method == "index") or (ref_latents_method == "index_timestep_zero") + negative_ref_method = ref_latents_method == "negative_index" + + for ref in ref_latents: + ref_b, ref_c, ref_h, ref_w = ref.shape + ref_gh = ref_h // self.patch_size + ref_gw = ref_w // self.patch_size + + if index_ref_method: + index += 1 + gh_offset = 0 + gw_offset = 0 + elif negative_ref_method: + index -= 1 + gh_offset = 0 + gw_offset = 0 + else: # offset/default + index = 1 + gh_offset = 0 + gw_offset = 0 + if ref_gh + h > ref_gw + w: + gw_offset = w + else: + gh_offset = h + h = max(h, ref_gh + gh_offset) + w = max(w, ref_gw + gw_offset) + + ref_tokens = self._img_to_tokens(ref) + ref_tokens_list.append(ref_tokens) + ref_num_tokens.append(ref_tokens.shape[1]) + + ref_pos = self._image_position_ids(ref_gh, ref_gw, device, index=index, h_offset=gh_offset, w_offset=gw_offset) + ref_pos_ids_list.append(ref_pos) + + transformer_options = transformer_options.copy() + transformer_options["reference_image_num_tokens"] = ref_num_tokens + + L_ref = sum(t.shape[1] for t in ref_tokens_list) if ref_tokens_list else 0 + L = L_text + L_img + L_ref + x_full = torch.zeros(B, L, latent_dim, dtype=img_tokens.dtype, device=device) - x_full[:, L_text:] = img_tokens + x_full[:, L_text:L_text+L_img] = img_tokens + + curr_idx = L_text + L_img + for ref_tokens in ref_tokens_list: + ref_len = ref_tokens.shape[1] + x_full[:, curr_idx:curr_idx+ref_len] = ref_tokens + curr_idx += ref_len text_pos = torch.arange(L_text, device=device).view(-1, 1).expand(L_text, 3) img_pos = self._image_position_ids(gh, gw, device) - position_ids = torch.cat([text_pos, img_pos], dim=0).unsqueeze(0).expand(B, L, 3) + + pos_ids_all = [text_pos, img_pos] + for ref_pos in ref_pos_ids_list: + pos_ids_all.append(ref_pos) + + position_ids = torch.cat(pos_ids_all, dim=0).unsqueeze(0).expand(B, L, 3) indicator = torch.empty(B, L, dtype=torch.long, device=device) indicator[:, :L_text] = LLM_TOKEN_INDICATOR @@ -263,20 +321,84 @@ def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, out = self._backbone(context_chunk, x_full, t_chunk, position_ids, attn_mask, indicator, transformer_options=transformer_options) - return self._tokens_to_img(out[:, L_text:], gh, gw) + return self._tokens_to_img(out[:, L_text:L_text+L_img], gh, gw) - def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options): + def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options, ref_latents=None, ref_latents_method="index"): B = x_chunk.shape[0] device = x_chunk.device img_tokens = self._img_to_tokens(x_chunk) L_img = img_tokens.shape[1] + latent_dim = img_tokens.shape[-1] - position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3) - indicator = torch.full((B, L_img), OUTPUT_IMAGE_INDICATOR, dtype=torch.long, device=device) + ref_tokens_list = [] + ref_pos_ids_list = [] + ref_num_tokens = [] + + if ref_latents is not None: + h = 0 + w = 0 + index = 0 + index_ref_method = (ref_latents_method == "index") or (ref_latents_method == "index_timestep_zero") + negative_ref_method = ref_latents_method == "negative_index" + + for ref in ref_latents: + ref_b, ref_c, ref_h, ref_w = ref.shape + ref_gh = ref_h // self.patch_size + ref_gw = ref_w // self.patch_size + + if index_ref_method: + index += 1 + gh_offset = 0 + gw_offset = 0 + elif negative_ref_method: + index -= 1 + gh_offset = 0 + gw_offset = 0 + else: # offset/default + index = 1 + gh_offset = 0 + gw_offset = 0 + if ref_gh + h > ref_gw + w: + gw_offset = w + else: + gh_offset = h + h = max(h, ref_gh + gh_offset) + w = max(w, ref_gw + gw_offset) + + ref_tokens = self._img_to_tokens(ref) + ref_tokens_list.append(ref_tokens) + ref_num_tokens.append(ref_tokens.shape[1]) + + ref_pos = self._image_position_ids(ref_gh, ref_gw, device, index=index, h_offset=gh_offset, w_offset=gw_offset) + ref_pos_ids_list.append(ref_pos) + + transformer_options = transformer_options.copy() + transformer_options["reference_image_num_tokens"] = ref_num_tokens + + L_ref = sum(t.shape[1] for t in ref_tokens_list) if ref_tokens_list else 0 + L_img_total = L_img + L_ref + + x_full = torch.zeros(B, L_img_total, latent_dim, dtype=img_tokens.dtype, device=device) + x_full[:, :L_img] = img_tokens + + curr_idx = L_img + for ref_tokens in ref_tokens_list: + ref_len = ref_tokens.shape[1] + x_full[:, curr_idx:curr_idx+ref_len] = ref_tokens + curr_idx += ref_len + + img_pos = self._image_position_ids(gh, gw, device) + + pos_ids_all = [img_pos] + for ref_pos in ref_pos_ids_list: + pos_ids_all.append(ref_pos) + + position_ids = torch.cat(pos_ids_all, dim=0).unsqueeze(0).expand(B, L_img_total, 3) + indicator = torch.full((B, L_img_total), OUTPUT_IMAGE_INDICATOR, dtype=torch.long, device=device) # Image-only sequence is a single segment -> no mask, full attention, no LLM context. - out = self._backbone(None, img_tokens, t_chunk, position_ids, None, indicator, transformer_options=transformer_options) - return self._tokens_to_img(out, gh, gw) + out = self._backbone(None, x_full, t_chunk, position_ids, None, indicator, transformer_options=transformer_options) + return self._tokens_to_img(out[:, :L_img], gh, gw) def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -290,8 +412,11 @@ def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_ timesteps = 1.0 - timesteps + ref_latents = kwargs.get("ref_latents", None) + ref_latents_method = kwargs.get("ref_latents_method", "index") + # unconditional pass if context is None: - return -self._run_image_only(x, timesteps, gh, gw, transformer_options) + return -self._run_image_only(x, timesteps, gh, gw, transformer_options, ref_latents=ref_latents, ref_latents_method=ref_latents_method) - return -self._run_conditional(x, context, attention_mask, timesteps, gh, gw, transformer_options) + return -self._run_conditional(x, context, attention_mask, timesteps, gh, gw, transformer_options, ref_latents=ref_latents, ref_latents_method=ref_latents_method) diff --git a/comfy/model_base.py b/comfy/model_base.py index 264dbb9b33df..8a568f35ce7c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -2266,6 +2266,7 @@ def extra_conds_shapes(self, **kwargs): class Ideogram4(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ideogram4.model.Ideogram4Transformer2DModel) + self.memory_usage_factor_conds = ("ref_latents",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -2276,6 +2277,24 @@ def extra_conds(self, **kwargs): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + latents = [] + for lat in ref_latents: + latents.append(self.process_latent_in(lat)) + out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_latents_method = kwargs.get("reference_latents_method", None) + if ref_latents_method is not None: + out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method) + return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['ref_latents'] = list([1, 128, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out class HunyuanImage21(BaseModel): From 6e434c785b762b0d9b20bd1ead406f92d6ec686f Mon Sep 17 00:00:00 2001 From: silveroxides Date: Sat, 20 Jun 2026 18:08:55 +0200 Subject: [PATCH 2/2] Fix missing position ids --- comfy/ldm/ideogram4/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/ideogram4/model.py b/comfy/ldm/ideogram4/model.py index c6d477e83595..769eb53a0a48 100644 --- a/comfy/ldm/ideogram4/model.py +++ b/comfy/ldm/ideogram4/model.py @@ -253,8 +253,8 @@ def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, for ref in ref_latents: ref_b, ref_c, ref_h, ref_w = ref.shape - ref_gh = ref_h // self.patch_size - ref_gw = ref_w // self.patch_size + ref_gh = ref_h + ref_gw = ref_w if index_ref_method: index += 1 @@ -343,8 +343,8 @@ def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options, ref_lat for ref in ref_latents: ref_b, ref_c, ref_h, ref_w = ref.shape - ref_gh = ref_h // self.patch_size - ref_gw = ref_w // self.patch_size + ref_gh = ref_h + ref_gw = ref_w if index_ref_method: index += 1