Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions comfy/latent_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,9 @@ class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2

class SeedVR2(LatentFormat):
    # Only the channel count is overridden here; every other attribute is
    # inherited unchanged from LatentFormat.
    latent_channels = 16

class ACEAudio15(LatentFormat):
latent_channels = 64
latent_dimensions = 1
Expand Down
6 changes: 4 additions & 2 deletions comfy/ldm/modules/diffusionmodules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def torch_cat_if_needed(xl, dim):
else:
return None

def get_timestep_embedding(timesteps, embedding_dim):
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Expand All @@ -33,11 +33,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
assert len(timesteps.shape) == 1

half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = math.log(10000) / (half_dim - downscale_freq_shift)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
Comment on lines 35 to 37

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Guard the new frequency-shift parameter.

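downscale_freq_shift now feeds the divisor directly. With plain Python numbers, a value equal to embedding_dim // 2 makes the divisor zero (the division raises ZeroDivisionError), while a value greater than embedding_dim // 2 makes the divisor negative, which inverts the frequency scale so this helper silently returns broken timestep embeddings instead of failing fast.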

Suggested guard
 def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
     assert len(timesteps.shape) == 1

     half_dim = embedding_dim // 2
+    if half_dim > 0 and downscale_freq_shift >= half_dim:
+        raise ValueError("downscale_freq_shift must be smaller than embedding_dim // 2")
     emb = math.log(10000) / (half_dim - downscale_freq_shift)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = math.log(10000) / (half_dim - downscale_freq_shift)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
if half_dim > 0 and downscale_freq_shift >= half_dim:
raise ValueError("downscale_freq_shift must be smaller than embedding_dim // 2")
emb = math.log(10000) / (half_dim - downscale_freq_shift)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@comfy/ldm/modules/diffusionmodules/model.py` around lines 35 - 37, Guard the
new downscale_freq_shift before it's used to avoid division by zero/negative
scale: check that downscale_freq_shift is an integer >= 0 and strictly less than
half_dim (where half_dim = embedding_dim // 2) so that (half_dim -
downscale_freq_shift) > 0; if the check fails, raise a clear ValueError
explaining that downscale_freq_shift must be in [0, half_dim-1]. Apply this
validation immediately before the existing emb computation that uses emb =
math.log(10000) / (half_dim - downscale_freq_shift) so the function (timestep
embedding helper) fails fast instead of producing invalid tensors.

emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0))
return emb
Expand Down
77 changes: 77 additions & 0 deletions comfy/ldm/seedvr/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import torch

from comfy.ldm.modules import attention as _attention


def _var_attention_qkv(q, k, v, heads, skip_reshape):
    """Return ``(q, k, v, head_dim)``, splitting out the head axis unless skipped.

    When ``skip_reshape`` is true the three tensors are handed back untouched
    and ``head_dim`` is read from the last axis of ``q``.

    Otherwise ``q`` must be 2D ``(tokens, embed_dim)``. ``head_dim`` is
    ``embed_dim // heads`` and each tensor is viewed as
    ``(its own token count, heads, head_dim)``; ``k`` and ``v`` reuse the
    ``head_dim`` derived from ``q``.
    """
    if not skip_reshape:
        q_tokens, width = q.shape
        per_head = width // heads
        q = q.view(q_tokens, heads, per_head)
        # k and v may carry a different number of tokens than q.
        k = k.view(k.shape[0], heads, per_head)
        v = v.view(v.shape[0], heads, per_head)
        return q, k, v, per_head

    return q, k, v, q.shape[-1]


def _var_attention_output(out, heads, head_dim, skip_output_reshape):
    """Merge the head axes of ``out`` back into one feature axis.

    Returns ``out`` unchanged when ``skip_output_reshape`` is true; otherwise
    returns it reshaped to ``(-1, heads * head_dim)``.
    """
    return out if skip_output_reshape else out.reshape(-1, heads * head_dim)


def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
    """Validate a cumulative sequence-offset tensor before it is used to split tokens.

    Checks run in the order below and the first failure wins.

    Args:
        name: Label used as the prefix of every error message.
        cu_seqlens: Tensor of cumulative offsets to validate.
        token_count: Number of tokens the final offset must equal.

    Raises:
        ValueError: If ``cu_seqlens`` is not ``int32``/``int64``; is not 1D
            with at least two entries; does not start at 0; is not strictly
            increasing; or does not end at ``token_count``.
    """
    if cu_seqlens.dtype not in (torch.int32, torch.int64):
        raise ValueError(f"{name} must use an integer dtype")
    if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
        raise ValueError(f"{name} must be a 1D tensor with at least two offsets")

    # Copy the offsets to Python ints once. The value checks below would
    # otherwise each pull a scalar back with .item(), i.e. one host read per
    # check (and two temporary tensors for the monotonicity test).
    offsets = cu_seqlens.tolist()

    if offsets[0] != 0:
        raise ValueError(f"{name} must start at 0")
    if any(nxt <= prev for prev, nxt in zip(offsets, offsets[1:])):
        raise ValueError(f"{name} must be strictly increasing")
    if offsets[-1] != token_count:
        raise ValueError(f"{name} does not match token count")


def _split_indices(cu_seqlens):
    """Return the interior offsets of ``cu_seqlens`` as a CPU ``torch.long`` tensor.

    The first and last entries are dropped; what remains are the positions at
    which a packed token tensor is cut into per-sequence chunks.
    """
    interior = cu_seqlens[1:-1]
    return interior.to(device="cpu", dtype=torch.long)


def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Attend over packed variable-length sequences, one sequence at a time.

    ``q``, ``k`` and ``v`` hold several sequences concatenated along dim 0.
    ``cu_seqlens_q`` / ``cu_seqlens_k`` are cumulative offsets marking the
    sequence boundaries (``v`` is cut with the ``k`` offsets). Each sequence
    is sliced out, passed through ``_attention.optimized_attention`` and the
    per-sequence results are concatenated back in the same order.

    Args:
        q, k, v: Packed tensors. 2D ``(tokens, embed_dim)`` unless
            ``skip_reshape`` is set, in which case they are used as given
            (assumed to already be ``(tokens, heads, head_dim)`` - confirm
            against callers).
        heads: Number of attention heads.
        cu_seqlens_q, cu_seqlens_k: Cumulative offsets for ``q`` and ``k``/``v``.
        *args, **kwargs: Accepted and ignored.
        skip_reshape: Skip splitting the head axis out of the inputs.
        skip_output_reshape: Return the concatenated per-head output without
            flattening it to ``(tokens, heads * head_dim)``.

    Raises:
        ValueError: If either offset tensor fails validation, if the ``k``
            offsets do not end at the ``v`` token count, or if the two offset
            tensors describe a different number of sequences.
    """
    q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)

    _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
    _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
    if cu_seqlens_k[-1].item() != v.shape[0]:
        raise ValueError("cu_seqlens_k does not match v token count")

    query_cuts = _split_indices(cu_seqlens_q)
    key_cuts = _split_indices(cu_seqlens_k)
    query_chunks = torch.tensor_split(q, query_cuts, dim=0)
    key_chunks = torch.tensor_split(k, key_cuts, dim=0)
    value_chunks = torch.tensor_split(v, key_cuts, dim=0)
    if not (len(query_chunks) == len(key_chunks) == len(value_chunks)):
        raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")

    half_precision = (torch.float16, torch.bfloat16)
    per_sequence = []
    for q_seq, k_seq, v_seq in zip(query_chunks, key_chunks, value_chunks):
        # Swap the first two axes and prepend a batch axis of size 1:
        # (tokens, heads, dim) -> (1, heads, tokens, dim).
        q_seq, k_seq, v_seq = (
            chunk.permute(1, 0, 2).unsqueeze(0) for chunk in (q_seq, k_seq, v_seq)
        )
        result_dtype = q_seq.dtype

        # When the active backend is attention_sage and q is not fp16/bf16,
        # feed it bf16 inputs (presumably a backend requirement - confirm)
        # and convert the result back afterwards.
        using_sage = _attention.optimized_attention is _attention.attention_sage
        if using_sage and result_dtype not in half_precision:
            q_seq, k_seq, v_seq = (
                chunk.to(torch.bfloat16) for chunk in (q_seq, k_seq, v_seq)
            )

        attended = _attention.optimized_attention(
            q_seq, k_seq, v_seq, heads, skip_reshape=True, skip_output_reshape=True
        )
        if attended.dtype != result_dtype:
            attended = attended.to(result_dtype)

        # Undo the layout change: drop the batch axis, swap the axes back.
        per_sequence.append(attended.squeeze(0).permute(1, 0, 2))

    packed = torch.cat(per_sequence, dim=0)
    return _var_attention_output(packed, heads, head_dim, skip_output_reshape)


# Module-level alias: `optimized_var_attention` is bound to the split-based
# implementation defined above.
optimized_var_attention = var_attention_optimized_split
Loading
Loading