Skip to content
Open
29 changes: 22 additions & 7 deletions pyrit/message_normalizer/generic_system_squash.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


from pyrit.message_normalizer.message_normalizer import MessageListNormalizer
from pyrit.models import Message
from pyrit.models import Message, MessagePiece


class GenericSystemSquashNormalizer(MessageListNormalizer[Message]):
Expand Down Expand Up @@ -43,12 +42,28 @@ async def normalize_async(self, messages: list[Message]) -> list[Message]:
# Only system message, convert to user message
return [Message.from_prompt(prompt=first_piece.converted_value, role="user")]

# Combine system with first user message
user_message_index = next(
(i for i, message in enumerate(messages[1:], start=1) if message.api_role == "user"),
-1,
)
if user_message_index == -1:
# Preserve the instruction content without rewriting non-user messages.
return [Message.from_prompt(prompt=first_piece.converted_value, role="user")] + list(messages[1:])

# Combine system with the first user message
system_content = first_piece.converted_value
user_piece = messages[1].get_piece()
user_piece = messages[user_message_index].get_piece()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if the first piece in this message is non-text this won't work. Multipart messages are explicitly mentioned in the description of this PR. I think we should address it thoroughly if we already address most of it.

user_content = user_piece.converted_value

combined_content = f"### Instructions ###\n\n{system_content}\n\n######\n\n{user_content}"
squashed_message = Message.from_prompt(prompt=combined_content, role="user")
# Return the squashed message followed by remaining messages (skip first two)
return [squashed_message] + list(messages[2:])
combined_piece = MessagePiece(
role="user",
original_value=combined_content,
conversation_id=user_piece.conversation_id,
sequence=user_piece.sequence,
)
remaining_pieces = list(messages[user_message_index].message_pieces[1:])
squashed_message = Message(message_pieces=[combined_piece] + remaining_pieces)

# Remove system (index 0), replace the first user message with the squashed version, preserve all others
return list(messages[1:user_message_index]) + [squashed_message] + list(messages[user_message_index + 1 :])
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,73 @@ async def test_generic_squash_normalize_to_dicts_async():
assert "### Instructions ###" in result[0]["converted_value"]
assert "System message" in result[0]["converted_value"]
assert "User message" in result[0]["converted_value"]


@pytest.mark.asyncio
async def test_generic_squash_preserves_multipart_user_message():
"""Test that squashing keeps non-text user pieces instead of collapsing to plain text."""
conversation_id = "conv-1"
messages = [
_make_message("system", "System message"),
Message(
message_pieces=[
MessagePiece(
role="user",
original_value="User message",
conversation_id=conversation_id,
sequence=0,
),
MessagePiece(
role="user",
original_value="/tmp/example.png",
original_value_data_type="image_path",
conversation_id=conversation_id,
sequence=0,
),
]
),
]

result = await GenericSystemSquashNormalizer().normalize_async(messages)

assert len(result) == 1
assert result[0].api_role == "user"
assert len(result[0].message_pieces) == 2
assert result[0].get_value() == "### Instructions ###\n\nSystem message\n\n######\n\nUser message"
assert result[0].message_pieces[1].converted_value == "/tmp/example.png"
assert result[0].message_pieces[1].converted_value_data_type == "image_path"


@pytest.mark.asyncio
async def test_generic_squash_uses_first_user_message_instead_of_rewriting_assistant():
"""Test that squash targets the first user message even if assistant messages appear first."""
messages = [
_make_message("system", "System message"),
_make_message("assistant", "Assistant message"),
_make_message("user", "User message"),
]

result = await GenericSystemSquashNormalizer().normalize_async(messages)

assert len(result) == 2
assert result[0].api_role == "assistant"
assert result[0].get_value() == "Assistant message"
assert result[1].api_role == "user"
assert result[1].get_value() == "### Instructions ###\n\nSystem message\n\n######\n\nUser message"


@pytest.mark.asyncio
async def test_generic_squash_no_user_message_converts_system_to_user():
"""Test that system is converted to user when no user messages exist."""
messages = [
_make_message("system", "System message"),
_make_message("assistant", "Assistant message"),
]

result = await GenericSystemSquashNormalizer().normalize_async(messages)

assert len(result) == 2
assert result[0].api_role == "user"
assert result[0].get_value() == "System message"
assert result[1].api_role == "assistant"
assert result[1].get_value() == "Assistant message"
Loading