-
Notifications
You must be signed in to change notification settings - Fork 13.8k
feat: Ideogram structured-caption nodes (CORE-292) #14537
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| def hex_to_rgb(value: str) -> tuple[int, int, int]: | ||
|
jtydhr88 marked this conversation as resolved.
|
||
| h = value.lstrip("#") | ||
| if len(h) != 6: | ||
| return (255, 255, 255) | ||
| try: | ||
| return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)) | ||
| except ValueError: | ||
| return (255, 255, 255) | ||
|
|
||
|
|
||
| def readable_color(rgb: tuple[int, int, int]) -> tuple[int, int, int]: | ||
| r, g, b = rgb | ||
| lum = 0.299 * r + 0.587 * g + 0.114 * b | ||
| if lum >= 130: | ||
| return (r, g, b) | ||
| t = (130 - lum) / (255 - lum) | ||
| return (round(r + (255 - r) * t), round(g + (255 - g) * t), round(b + (255 - b) * t)) | ||
|
|
||
|
|
||
| def normalize_palette(colors) -> list[str]: | ||
| if isinstance(colors, dict): | ||
| colors = colors.values() | ||
| return [c.upper() for c in colors if isinstance(c, str) and c] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,253 @@ | ||
| import numpy as np | ||
| import torch | ||
| from PIL import Image, ImageDraw, ImageEnhance, ImageFont | ||
| from typing_extensions import override | ||
|
|
||
| from comfy_api.latest import ComfyExtension, io | ||
| from comfy_extras.color_util import hex_to_rgb, normalize_palette, readable_color | ||
|
|
||
| _PREVIEW_LONG_EDGE = 1024 | ||
| _PREVIEW_DIM = 0.25 | ||
|
|
||
|
|
||
| def pixels_to_fractions(box: dict, width: int, height: int) -> dict: | ||
| w = width or 1 | ||
| h = height or 1 | ||
| return { | ||
| "x": box.get("x", 0) / w, | ||
| "y": box.get("y", 0) / h, | ||
| "w": box.get("width", 0) / w, | ||
| "h": box.get("height", 0) / h, | ||
| } | ||
|
|
||
|
|
||
| def fractions_to_pixels(box: dict, width: int, height: int) -> dict: | ||
| x, y = box.get("x", 0.0), box.get("y", 0.0) | ||
| w, h = box.get("w", 0.0), box.get("h", 0.0) | ||
| if w < 0: | ||
| x, w = x + w, -w | ||
| if h < 0: | ||
| y, h = y + h, -h | ||
| return { | ||
| "x": round(x * width), | ||
| "y": round(y * height), | ||
| "width": round(w * width), | ||
| "height": round(h * height), | ||
| } | ||
|
|
||
|
|
||
| def fractions_to_bbox_frame(boxes: list, width: int, height: int) -> list: | ||
| pixels = [ | ||
| fractions_to_pixels(box, width, height) | ||
| for box in boxes | ||
| if isinstance(box, dict) | ||
| ] | ||
| return [pixels] if pixels else [] | ||
|
|
||
|
|
||
| def _font(size: int): | ||
| try: | ||
| return ImageFont.load_default(size) | ||
| except Exception: | ||
| return ImageFont.load_default() | ||
|
|
||
|
|
||
| def _wrap(draw, text: str, font, max_w: float) -> list[str]: | ||
| lines = [] | ||
| for para in text.split("\n"): | ||
| line = "" | ||
| for word in para.split(): | ||
| test = word if not line else line + " " + word | ||
| if line and draw.textlength(test, font=font) > max_w: | ||
| lines.append(line) | ||
| line = word | ||
| else: | ||
| line = test | ||
| lines.append(line) | ||
| return lines | ||
|
|
||
|
|
||
| def _bg_from_image(image) -> Image.Image | None: | ||
| if image is None: | ||
| return None | ||
| try: | ||
| arr = (image[0].detach().cpu().numpy() * 255).clip(0, 255).astype(np.uint8) | ||
| return Image.fromarray(arr) | ||
| except Exception: | ||
| return None | ||
|
|
||
|
|
||
| def render_preview(regions, width, height, bg=None): | ||
| if bg is not None: | ||
| iw, ih = bg.size | ||
| long_edge = max(iw, ih) or 1 | ||
| scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge) | ||
| rw, rh = max(1, round(iw * scale)), max(1, round(ih * scale)) | ||
| base = bg.convert("RGB").resize((rw, rh), Image.LANCZOS) | ||
| base = ImageEnhance.Brightness(base).enhance(_PREVIEW_DIM) | ||
| img = base.convert("RGBA") | ||
| else: | ||
| long_edge = max(width, height) or 1 | ||
| scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge) | ||
| rw, rh = max(1, round(width * scale)), max(1, round(height * scale)) | ||
| grey = round(_PREVIEW_DIM * 128) | ||
| img = Image.new("RGBA", (rw, rh), (grey, grey, grey, 255)) | ||
|
|
||
| overlay = Image.new("RGBA", (rw, rh), (0, 0, 0, 0)) | ||
| draw = ImageDraw.Draw(overlay) | ||
| fs = max(10, round(rh / 64)) | ||
| font = _font(fs) | ||
| tag_font = _font(max(9, fs - 2)) | ||
| line_h = fs + 2 | ||
|
|
||
| for i, region in enumerate(regions): | ||
| if not isinstance(region, dict): | ||
| continue | ||
| palette = [c for c in (region.get("palette") or []) if c] | ||
| r, g, b = hex_to_rgb(palette[0]) if palette else (140, 140, 140) | ||
| x1 = max(0, min(rw, round(region.get("x", 0) * rw))) | ||
| y1 = max(0, min(rh, round(region.get("y", 0) * rh))) | ||
| x2 = max(0, min(rw, round((region.get("x", 0) + region.get("w", 0)) * rw))) | ||
| y2 = max(0, min(rh, round((region.get("y", 0) + region.get("h", 0)) * rh))) | ||
| if x2 < x1: | ||
| x1, x2 = x2, x1 | ||
| if y2 < y1: | ||
| y1, y2 = y2, y1 | ||
|
|
||
| draw.rectangle([x1, y1, x2, y2], outline=(r, g, b, 255), width=2) | ||
|
|
||
| swatches = palette[:5] | ||
| if swatches and (x2 - x1) > 2: | ||
| sh = max(5, fs // 2) | ||
| seg = (x2 - x1) / len(swatches) | ||
| for p, hexc in enumerate(swatches): | ||
| sx = x1 + round(p * seg) | ||
| draw.rectangle([sx, y1, x1 + round((p + 1) * seg), y1 + sh], fill=hex_to_rgb(hexc)) | ||
|
|
||
| etype = "text" if region.get("type") == "text" else "obj" | ||
| tag = str(i + 1).zfill(2) | ||
| tw = draw.textlength(tag, font=tag_font) | ||
| draw.rectangle([x1, y1, x1 + tw + 6, y1 + fs + 2], fill=(r, g, b, 255)) | ||
| tag_fill = (0, 0, 0, 255) if (0.299 * r + 0.587 * g + 0.114 * b) > 140 else (255, 255, 255, 255) | ||
| draw.text((x1 + 3, y1 + 1), tag, fill=tag_fill, font=tag_font) | ||
|
|
||
| body = region.get("desc", "") or "" | ||
| if etype == "text" and region.get("text"): | ||
| body = '"%s"%s' % (region["text"], " — " + body if body else "") | ||
| if body and (x2 - x1) > 8: | ||
| ty = y1 + fs + 5 | ||
| for line in _wrap(draw, body, font, x2 - x1 - 8): | ||
| if ty > y2: | ||
| break | ||
| draw.text((x1 + 4, ty), line, fill=readable_color((r, g, b)) + (255,), font=font) | ||
| ty += line_h | ||
|
|
||
| composed = Image.alpha_composite(img, overlay).convert("RGB") | ||
| arr = np.asarray(composed, dtype=np.float32) / 255.0 | ||
| return torch.from_numpy(arr).unsqueeze(0) | ||
|
|
||
|
|
||
| def boxes_to_regions(boxes, width: int, height: int) -> list: | ||
| regions: list = [] | ||
| if not isinstance(boxes, list): | ||
| return regions | ||
| for box in boxes: | ||
| if not isinstance(box, dict): | ||
| continue | ||
| meta = box.get("metadata") | ||
| meta = meta if isinstance(meta, dict) else {} | ||
| regions.append({ | ||
| **pixels_to_fractions(box, width, height), | ||
| "type": meta.get("type", "obj"), | ||
| "text": meta.get("text", ""), | ||
| "desc": meta.get("desc", ""), | ||
| "palette": meta.get("palette", []), | ||
| }) | ||
| return regions | ||
|
|
||
|
|
||
| def _norm_bbox(region: dict) -> list[int]: | ||
| def grid(value: float) -> int: | ||
| return max(0, min(1000, round(value * 1000))) | ||
|
|
||
| x, y = region.get("x", 0.0), region.get("y", 0.0) | ||
| w, h = region.get("w", 0.0), region.get("h", 0.0) | ||
| ymin, xmin, ymax, xmax = grid(y), grid(x), grid(y + h), grid(x + w) | ||
| if ymin > ymax: | ||
| ymin, ymax = ymax, ymin | ||
| if xmin > xmax: | ||
| xmin, xmax = xmax, xmin | ||
| return [ymin, xmin, ymax, xmax] | ||
|
|
||
|
|
||
| def build_elements(regions: list) -> list: | ||
| elements = [] | ||
| for region in regions: | ||
| if not isinstance(region, dict): | ||
| continue | ||
| etype = "text" if region.get("type") == "text" else "obj" | ||
| element = {"type": etype} | ||
| element["bbox"] = _norm_bbox(region) | ||
| if etype == "text": | ||
| element["text"] = region.get("text", "") | ||
| element["desc"] = region.get("desc", "") | ||
| palette = normalize_palette(region.get("palette", [])) | ||
| if palette: | ||
| element["color_palette"] = palette[:5] | ||
| elements.append(element) | ||
| return elements | ||
|
|
||
|
|
||
| class CreateBoundingBoxes(io.ComfyNode): | ||
| @classmethod | ||
| def define_schema(cls): | ||
| editor_state = io.BoundingBoxes.Input( | ||
| "editor_state", | ||
| socketless=False, | ||
| tooltip="Draw bounding boxes and set each box type, text, description, color palette. Start with background element first and foreground last.", | ||
| ) | ||
| return io.Schema( | ||
| node_id="CreateBoundingBoxes", | ||
| display_name="Create Bounding Boxes", | ||
| category="utilities", | ||
| description="Draw bounding boxes in a canvas. Outputs Ideogram prompt elements, pixel-space bounding boxes, and a preview image.", | ||
| inputs=[ | ||
| io.Image.Input( | ||
| "background", | ||
| optional=True, | ||
| tooltip="Optional image used as background in the canvas and preview.", | ||
| ), | ||
| io.Int.Input("width", default=1024, min=64, max=16384, step=16, | ||
| tooltip="Width of the canvas and the pixel grid for the bounding boxes."), | ||
| io.Int.Input("height", default=1024, min=64, max=16384, step=16, | ||
| tooltip="Height of the canvas and the pixel grid for the bounding boxes."), | ||
| editor_state, | ||
| ], | ||
| outputs=[ | ||
| io.Image.Output(display_name="preview"), | ||
| io.BoundingBox.Output(display_name="bboxes"), | ||
| io.Array.Output(display_name="elements"), | ||
| ], | ||
| is_experimental=True, | ||
| ) | ||
|
|
||
| @classmethod | ||
| def execute(cls, width, height, editor_state=None, background=None) -> io.NodeOutput: | ||
| regions = boxes_to_regions(editor_state, width, height) | ||
| preview = render_preview(regions, width, height, _bg_from_image(background)) | ||
| return io.NodeOutput( | ||
| preview, | ||
| fractions_to_bbox_frame(regions, width, height), | ||
| build_elements(regions), | ||
| ui={"dims": [width, height]}, | ||
| ) | ||
|
|
||
|
|
||
| class BoundingBoxesExtension(ComfyExtension): | ||
| @override | ||
| async def get_node_list(self) -> list[type[io.ComfyNode]]: | ||
| return [CreateBoundingBoxes] | ||
|
|
||
|
|
||
| async def comfy_entrypoint() -> BoundingBoxesExtension: | ||
| return BoundingBoxesExtension() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| from typing_extensions import override | ||
| from comfy_api.latest import ComfyExtension, io | ||
| from comfy_extras.color_util import hex_to_rgb | ||
|
|
||
|
|
||
| class ColorToRGBInt(io.ComfyNode): | ||
|
|
@@ -24,9 +25,7 @@ def execute(cls, color: str) -> io.NodeOutput: | |
| # expect format #RRGGBB | ||
| if len(color) != 7 or color[0] != "#": | ||
| raise ValueError("Color must be in format #RRGGBB") | ||
| r = int(color[1:3], 16) | ||
| g = int(color[3:5], 16) | ||
| b = int(color[5:7], 16) | ||
| r, g, b = hex_to_rgb(color) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win Preserve strict invalid-hex rejection in On Line 28, switching to 🔧 Proposed fix def execute(cls, color: str) -> io.NodeOutput:
# expect format `#RRGGBB`
if len(color) != 7 or color[0] != "#":
raise ValueError("Color must be in format `#RRGGBB`")
+ try:
+ int(color[1:], 16)
+ except ValueError:
+ raise ValueError("Color must be in format `#RRGGBB`")
r, g, b = hex_to_rgb(color)
rgb_int = r * 256 * 256 + g * 256 + b
return io.NodeOutput(rgb_int, color)As per path instructions, 🤖 Prompt for AI AgentsSource: Path instructions |
||
|
|
||
| rgb_int = r * 256 * 256 + g * 256 + b | ||
| return io.NodeOutput(rgb_int, color) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it make sense and would it be possible to reuse the existing type
@comfytype(io_type="BBOX")?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same situation as COLORS, plus a shape mismatch.
the BOUNDING_BOXES is a concrete list[{x, y, width, height, metadata: {type, text, desc, palette}}], and it has to map to our canvas editor widget.
Two reasons we can't reuse BBOX:
So BOUNDING_BOXES reuses the existing BOUNDING_BOX shape where it fits and exists as a distinct list type purely to drive the editor widget.