Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 84 additions & 33 deletions playbooks/core/pytorch-rocm-llms/assets/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

class DocumentSummarizer:
"""Summarize documents using Large Language Models"""
def __init__(self, model="mistral"):

def __init__(self, model="gptoss"):
"""
Initialize the summarizer with specified model.

Expand All @@ -38,10 +38,11 @@ def __init__(self, model="mistral"):
"""
if model not in MODELS:
raise ValueError(f"Model must be one of: {list(MODELS.keys())}")


self.model_key = model
self.model_name = MODELS[model]
print(f"Loading {model.upper()} ({self.model_name})...")

self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
Expand Down Expand Up @@ -81,26 +82,40 @@ def _build_messages(self, text):
),
}
]
def summarize(self, text, max_length=150, temperature=0.3):

def summarize(self, text, max_length=1024, temperature=0.3, reasoning_effort="low"):
"""
Summarize the given text.

Args:
text: Text to summarize
max_length: Maximum number of tokens to generate
temperature: Sampling temperature (0.1-1.0)

reasoning_effort: For Harmony-format models (gpt-oss):
"low" | "medium" | "high". Ignored by other models.

Returns:
Summary string
"""
messages = self._build_messages(text)
prompt = self.tokenizer.apply_chat_template(
messages,

# gpt-oss supports a `reasoning_effort` argument in its chat template.
# For non-Harmony models this kwarg will be silently ignored by most
# templates, but we guard it just in case.
template_kwargs = dict(
tokenize=False,
add_generation_prompt=True,
)

if self.model_key == "gptoss":
template_kwargs["reasoning_effort"] = reasoning_effort

try:
prompt = self.tokenizer.apply_chat_template(messages, **template_kwargs)
except TypeError:
# Template doesn't accept reasoning_effort -- retry without it
template_kwargs.pop("reasoning_effort", None)
prompt = self.tokenizer.apply_chat_template(messages, **template_kwargs)

inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
outputs = self.model.generate(
**inputs,
Expand All @@ -111,33 +126,63 @@ def summarize(self, text, max_length=150, temperature=0.3):
)

full = self.tokenizer.decode(outputs[0], skip_special_tokens=False)

# Clean up Harmony-style outputs
# This is needed for summarization because Harmony-style models (like gpt-oss)
# often include tool/analysis outputs and multiple assistant channels. We only want the final output for summarization.

if "<|start|>assistant<|channel|>final<|message|>" in full:
final_part = full.split("<|start|>assistant<|channel|>final<|message|>", 1)[1]
# Stop at <|return|> or <|end|> if present
for stop_tok in ("<|return|>", "<|end|>"):
# Harmony-format models (like gpt-oss) emit multiple channels --
# `analysis` (chain-of-thought) and `final` (user-facing answer).
# We only want the `final` channel for summarization.

final_marker = "<|channel|>final<|message|>"

if final_marker in full:
final_part = full.split(final_marker, 1)[1]
# Stop at end-of-turn markers if present
for stop_tok in ("<|return|>", "<|end|>", "<|start|>"):
if stop_tok in final_part:
final_part = final_part.split(stop_tok, 1)[0]
break
summary = final_part.strip()
else:
# Fallback for models without Harmony format (e.g. Mistral)
summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

# Collapse to the last non-empty line to avoid echoed prompts
lines = [l.strip() for l in summary.splitlines() if l.strip()]
return lines[-1] if lines else ""

if not summary:
return (
"[Warning: model produced an empty final answer. "
"Try increasing --max-length or lowering --reasoning-effort.]"
)
return summary

# No final-channel marker found.
if "<|channel|>analysis" in full or "assistantanalysis" in full:
# Model ran out of tokens while still reasoning.
return (
"[Warning: model did not produce a final answer before hitting "
"the token limit. Try increasing --max-length (e.g. 2048) or "
"setting --reasoning-effort low.]"
)

# Fallback for non-Harmony models (e.g. Mistral): decode without
# special tokens and strip the echoed prompt.
decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
input_decoded = self.tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
if decoded.startswith(input_decoded):
decoded = decoded[len(input_decoded):]
return decoded.strip()


def main():
parser = argparse.ArgumentParser(description="Summarize documents using LLMs")
parser.add_argument("--model", default="gptoss", choices=["mistral", "gptoss"], help="Model to use (default: gptoss)")
parser.add_argument("--model", default="gptoss", choices=list(MODELS.keys()),
help="Model to use (default: gptoss)")
parser.add_argument("--file", default=None, help="Path to .txt file to summarize")
parser.add_argument("--max-length", type=int, default=250, help="Maximum tokens to generate (default: 150)")
parser.add_argument("--temperature", type=float, default=0.3, help="Sampling temperature 0.1-1.0 (default: 0.3)")
parser.add_argument("--max-length", type=int, default=1024,
help="Maximum tokens to generate (default: 1024). "
"gpt-oss needs headroom for its reasoning channel.")
parser.add_argument("--temperature", type=float, default=0.3,
help="Sampling temperature 0.1-1.0 (default: 0.3)")
parser.add_argument("--reasoning-effort", default="low",
choices=["low", "medium", "high"],
help="Reasoning effort for Harmony models like gpt-oss "
"(default: low). Lower = less thinking, faster, "
"more likely to fit in --max-length.")
args = parser.parse_args()

summarizer = DocumentSummarizer(model=args.model)
Expand Down Expand Up @@ -168,11 +213,17 @@ def main():

# Generate summary
print("Generating summary...")
summary = summarizer.summarize(document, args.max_length, args.temperature)
summary = summarizer.summarize(
document,
max_length=args.max_length,
temperature=args.temperature,
reasoning_effort=args.reasoning_effort,
)
print(summary)
print(f"\n[OK] Done! (max_length={args.max_length}, temperature={args.temperature})\n")

# Cleanup
print(f"\n[OK] Done! (max_length={args.max_length}, "
f"temperature={args.temperature}, "
f"reasoning_effort={args.reasoning_effort})\n")

summarizer.cleanup()


Expand Down
Loading