From cac4c7c4acb66e3d4d8dfe273dd06e7e6b0c22a1 Mon Sep 17 00:00:00 2001 From: mdokukin Date: Wed, 27 May 2026 16:21:32 -0700 Subject: [PATCH] Fixed gpt-oss response truncation by raising token limit, adding reasoning-effort control, properly extracting the final channel, and detecting incomplete generations --- .../pytorch-rocm-llms/assets/summarizer.py | 117 +++++++++++++----- 1 file changed, 84 insertions(+), 33 deletions(-) diff --git a/playbooks/core/pytorch-rocm-llms/assets/summarizer.py b/playbooks/core/pytorch-rocm-llms/assets/summarizer.py index 11f41213..0a1519f7 100644 --- a/playbooks/core/pytorch-rocm-llms/assets/summarizer.py +++ b/playbooks/core/pytorch-rocm-llms/assets/summarizer.py @@ -28,8 +28,8 @@ class DocumentSummarizer: """Summarize documents using Large Language Models""" - - def __init__(self, model="mistral"): + + def __init__(self, model="gptoss"): """ Initialize the summarizer with specified model. @@ -38,10 +38,11 @@ def __init__(self, model="mistral"): """ if model not in MODELS: raise ValueError(f"Model must be one of: {list(MODELS.keys())}") - + + self.model_key = model self.model_name = MODELS[model] print(f"Loading {model.upper()} ({self.model_name})...") - + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) self.model = AutoModelForCausalLM.from_pretrained( self.model_name, @@ -81,26 +82,40 @@ def _build_messages(self, text): ), } ] - - def summarize(self, text, max_length=150, temperature=0.3): + + def summarize(self, text, max_length=1024, temperature=0.3, reasoning_effort="low"): """ Summarize the given text. - + Args: text: Text to summarize max_length: Maximum number of tokens to generate temperature: Sampling temperature (0.1-1.0) - + reasoning_effort: For Harmony-format models (gpt-oss): + "low" | "medium" | "high". Ignored by other models. + Returns: Summary string """ messages = self._build_messages(text) - prompt = self.tokenizer.apply_chat_template( - messages, + + # gpt-oss supports a `reasoning_effort` argument in its chat template. + # For non-Harmony models this kwarg will be silently ignored by most + # templates, but we guard it just in case. + template_kwargs = dict( tokenize=False, add_generation_prompt=True, ) - + if self.model_key == "gptoss": + template_kwargs["reasoning_effort"] = reasoning_effort + + try: + prompt = self.tokenizer.apply_chat_template(messages, **template_kwargs) + except TypeError: + # Template doesn't accept reasoning_effort -- retry without it + template_kwargs.pop("reasoning_effort", None) + prompt = self.tokenizer.apply_chat_template(messages, **template_kwargs) + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) outputs = self.model.generate( **inputs, @@ -111,33 +126,63 @@ def summarize(self, text, max_length=150, temperature=0.3): ) full = self.tokenizer.decode(outputs[0], skip_special_tokens=False) - + # Clean up Harmony-style outputs - # This is needed for summarization because Harmony-style models (like gpt-oss) - # often include tool/analysis outputs and multiple assistant channels. We only want the final output for summarization. - - if "<|start|>assistant<|channel|>final<|message|>" in full: - final_part = full.split("<|start|>assistant<|channel|>final<|message|>", 1)[1] - # Stop at <|return|> or <|end|> if present - for stop_tok in ("<|return|>", "<|end|>"): + # Harmony-format models (like gpt-oss) emit multiple channels -- + # `analysis` (chain-of-thought) and `final` (user-facing answer). + # We only want the `final` channel for summarization. + + final_marker = "<|channel|>final<|message|>" + + if final_marker in full: + final_part = full.split(final_marker, 1)[1] + # Stop at end-of-turn markers if present + for stop_tok in ("<|return|>", "<|end|>", "<|start|>"): if stop_tok in final_part: final_part = final_part.split(stop_tok, 1)[0] break summary = final_part.strip() - else: - # Fallback for models without Harmony format (e.g. Mistral) - summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip() - - # Collapse to the last non-empty line to avoid echoed prompts - lines = [l.strip() for l in summary.splitlines() if l.strip()] - return lines[-1] if lines else "" + + if not summary: + return ( + "[Warning: model produced an empty final answer. " + "Try increasing --max-length or lowering --reasoning-effort.]" + ) + return summary + + # No final-channel marker found. + if "<|channel|>analysis" in full or "assistantanalysis" in full: + # Model ran out of tokens while still reasoning. + return ( + "[Warning: model did not produce a final answer before hitting " + "the token limit. Try increasing --max-length (e.g. 2048) or " + "setting --reasoning-effort low.]" + ) + + # Fallback for non-Harmony models (e.g. Mistral): decode without + # special tokens and strip the echoed prompt. + decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + input_decoded = self.tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True) + if decoded.startswith(input_decoded): + decoded = decoded[len(input_decoded):] + return decoded.strip() + def main(): parser = argparse.ArgumentParser(description="Summarize documents using LLMs") - parser.add_argument("--model", default="gptoss", choices=["mistral", "gptoss"], help="Model to use (default: gptoss)") + parser.add_argument("--model", default="gptoss", choices=list(MODELS.keys()), + help="Model to use (default: gptoss)") parser.add_argument("--file", default=None, help="Path to .txt file to summarize") - parser.add_argument("--max-length", type=int, default=250, help="Maximum tokens to generate (default: 150)") - parser.add_argument("--temperature", type=float, default=0.3, help="Sampling temperature 0.1-1.0 (default: 0.3)") + parser.add_argument("--max-length", type=int, default=1024, + help="Maximum tokens to generate (default: 1024). " + "gpt-oss needs headroom for its reasoning channel.") + parser.add_argument("--temperature", type=float, default=0.3, + help="Sampling temperature 0.1-1.0 (default: 0.3)") + parser.add_argument("--reasoning-effort", default="low", + choices=["low", "medium", "high"], + help="Reasoning effort for Harmony models like gpt-oss " + "(default: low). Lower = less thinking, faster, " + "more likely to fit in --max-length.") args = parser.parse_args() summarizer = DocumentSummarizer(model=args.model) @@ -168,11 +213,17 @@ def main(): # Generate summary print("Generating summary...") - summary = summarizer.summarize(document, args.max_length, args.temperature) + summary = summarizer.summarize( + document, + max_length=args.max_length, + temperature=args.temperature, + reasoning_effort=args.reasoning_effort, + ) print(summary) - print(f"\n[OK] Done! (max_length={args.max_length}, temperature={args.temperature})\n") - - # Cleanup + print(f"\n[OK] Done! (max_length={args.max_length}, " + f"temperature={args.temperature}, " + f"reasoning_effort={args.reasoning_effort})\n") + summarizer.cleanup()