From 702bfcded754c9bb5e9766f1bdbd531ad6ce7b42 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 3 Jun 2026 20:52:46 -0700 Subject: [PATCH 01/10] MAINT: Evict I/O modules from pyrit.models into pyrit.io Move StorageIO/DiskStorageIO/AzureBlobStorageIO, the data-type serializers, and data_serializer_factory out of pyrit.models into a new pyrit.io package. Leave deprecation shims (removed_in 0.17.0) at the old pyrit.models locations and method-level shims (Seed.set_sha256_value_async, MessagePiece.set_sha256_values_async) that delegate to the new free functions in pyrit.io.serializers. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/services/attack_service.py | 2 +- pyrit/backend/services/converter_service.py | 2 +- pyrit/common/data_url_converter.py | 2 +- pyrit/common/display_response.py | 3 +- .../seed_datasets/remote/_image_cache.py | 2 +- .../seed_datasets/remote/msts_dataset.py | 3 +- pyrit/io/__init__.py | 54 ++ pyrit/io/serializers.py | 792 ++++++++++++++++++ pyrit/io/storage.py | 507 +++++++++++ pyrit/memory/azure_sql_memory.py | 7 +- pyrit/memory/memory_interface.py | 11 +- pyrit/memory/sqlite_memory.py | 3 +- .../chat_message_normalizer.py | 3 +- pyrit/models/__init__.py | 44 +- pyrit/models/data_type_serializer.py | 773 +---------------- pyrit/models/messages/message_piece.py | 24 +- pyrit/models/seeds/seed.py | 21 +- pyrit/models/seeds/seed_prompt.py | 3 +- pyrit/models/storage_io.py | 519 +----------- pyrit/output/conversation/pretty.py | 2 +- .../add_image_text_converter.py | 3 +- .../add_image_to_video_converter.py | 3 +- .../add_text_image_converter.py | 3 +- .../prompt_converter/audio_echo_converter.py | 3 +- .../audio_frequency_converter.py | 3 +- .../prompt_converter/audio_speed_converter.py | 3 +- .../audio_volume_converter.py | 3 +- .../audio_white_noise_converter.py | 3 +- .../azure_speech_audio_to_text_converter.py | 3 +- .../azure_speech_text_to_audio_converter.py | 3 +- .../base_image_to_image_converter.py | 3 +- .../image_compression_converter.py | 3 +- .../image_overlay_converter.py | 4 +- pyrit/prompt_converter/pdf_converter.py | 4 +- pyrit/prompt_converter/qr_code_converter.py | 3 +- .../transparency_attack_converter.py | 3 +- pyrit/prompt_converter/word_doc_converter.py | 6 +- pyrit/prompt_normalizer/prompt_normalizer.py | 3 +- .../openai/openai_chat_target.py | 11 +- .../openai/openai_image_target.py | 8 +- .../openai/openai_realtime_target.py | 8 +- .../prompt_target/openai/openai_tts_target.py | 8 +- .../openai/openai_video_target.py | 10 +- .../playwright_copilot_target.py | 9 +- .../prompt_target/websocket_copilot_target.py | 3 +- .../azure_content_filter_scorer.py | 10 +- .../test_convert_local_image_to_data_url.py | 2 +- tests/unit/common/test_display_response.py | 4 +- tests/unit/io/test_deprecation_shims.py | 181 ++++ .../test_serializers.py} | 41 +- .../test_storage_io.py => io/test_storage.py} | 2 +- tests/unit/models/test_import_boundary.py | 9 +- tests/unit/output/test_blur_images.py | 4 +- .../prompt_converter/test_pdf_converter.py | 3 +- 54 files changed, 1757 insertions(+), 1392 deletions(-) create mode 100644 pyrit/io/__init__.py create mode 100644 pyrit/io/serializers.py create mode 100644 pyrit/io/storage.py create mode 100644 tests/unit/io/test_deprecation_shims.py rename tests/unit/{models/test_data_type_serializer.py => io/test_serializers.py} (95%) rename tests/unit/{models/test_storage_io.py => io/test_storage.py} (99%) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 0e13919855..3bea0fa55c 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -50,6 +50,7 @@ from pyrit.backend.models.common import PaginationInfo from pyrit.backend.services.converter_service import get_converter_service from pyrit.backend.services.target_service import get_target_service +from pyrit.io import data_serializer_factory from pyrit.memory import CentralMemory from pyrit.models import ( AttackOutcome, @@ -60,7 +61,6 @@ MessagePiece, PromptDataType, build_atomic_attack_identifier, - data_serializer_factory, ) from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 66cb8bdc31..4df2f1b1eb 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -36,8 +36,8 @@ CreateConverterResponse, PreviewStep, ) +from pyrit.io import data_serializer_factory from pyrit.models import PromptDataType -from pyrit.models.data_type_serializer import data_serializer_factory from pyrit.prompt_converter import PromptConverter from pyrit.prompt_target import PromptTarget from pyrit.registry.object_registries import ConverterRegistry diff --git a/pyrit/common/data_url_converter.py b/pyrit/common/data_url_converter.py index 20ff008332..0193b2a5fb 100644 --- a/pyrit/common/data_url_converter.py +++ b/pyrit/common/data_url_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from pyrit.common.deprecation import print_deprecation_message -from pyrit.models import DataTypeSerializer, data_serializer_factory +from pyrit.io import DataTypeSerializer, data_serializer_factory # Supported image formats for Azure OpenAI GPT-4o, # https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/use-your-image-data diff --git a/pyrit/common/display_response.py b/pyrit/common/display_response.py index 0c40e5b5aa..11093a6cf8 100644 --- a/pyrit/common/display_response.py +++ b/pyrit/common/display_response.py @@ -8,8 +8,9 @@ from pyrit.common.deprecation import print_deprecation_message from pyrit.common.notebook_utils import is_in_ipython_session +from pyrit.io import AzureBlobStorageIO, DiskStorageIO from pyrit.memory import CentralMemory -from pyrit.models import AzureBlobStorageIO, DiskStorageIO, MessagePiece +from pyrit.models import MessagePiece logger = logging.getLogger(__name__) diff --git a/pyrit/datasets/seed_datasets/remote/_image_cache.py b/pyrit/datasets/seed_datasets/remote/_image_cache.py index dbc866a47b..9b5cb807ae 100644 --- a/pyrit/datasets/seed_datasets/remote/_image_cache.py +++ b/pyrit/datasets/seed_datasets/remote/_image_cache.py @@ -17,7 +17,7 @@ from typing import Any, Optional from pyrit.common.net_utility import make_request_and_raise_if_error_async -from pyrit.models import data_serializer_factory +from pyrit.io import data_serializer_factory logger = logging.getLogger(__name__) diff --git a/pyrit/datasets/seed_datasets/remote/msts_dataset.py b/pyrit/datasets/seed_datasets/remote/msts_dataset.py index e1ada75a29..8fa876e6da 100644 --- a/pyrit/datasets/seed_datasets/remote/msts_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/msts_dataset.py @@ -12,7 +12,8 @@ from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) -from pyrit.models import SeedDataset, SeedPrompt, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import SeedDataset, SeedPrompt if TYPE_CHECKING: from PIL.Image import Image as PILImage diff --git a/pyrit/io/__init__.py b/pyrit/io/__init__.py new file mode 100644 index 0000000000..5333f788cb --- /dev/null +++ b/pyrit/io/__init__.py @@ -0,0 +1,54 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +I/O layer for PyRIT: storage backends and multi-modal data serializers. + +Provides the disk and blob storage adapters (``StorageIO`` and its +implementations) and the data-type serializers (``data_serializer_factory`` and +the per-type ``*DataTypeSerializer`` classes) used to read and write prompt +payloads such as text, images, audio, and video. + +Unlike ``pyrit.models``, modules in this package may depend on ``pyrit.memory`` +and ``pyrit.auth`` (resolved lazily to avoid import cycles). +""" + +from pyrit.io.serializers import ( + AllowedCategories, + AudioPathDataTypeSerializer, + BinaryPathDataTypeSerializer, + DataTypeSerializer, + ErrorDataTypeSerializer, + ImagePathDataTypeSerializer, + TextDataTypeSerializer, + URLDataTypeSerializer, + VideoPathDataTypeSerializer, + data_serializer_factory, + set_message_piece_sha256_async, + set_seed_sha256_async, +) +from pyrit.io.storage import ( + AzureBlobStorageIO, + DiskStorageIO, + StorageIO, + SupportedContentType, +) + +__all__ = [ + "AllowedCategories", + "AudioPathDataTypeSerializer", + "AzureBlobStorageIO", + "BinaryPathDataTypeSerializer", + "DataTypeSerializer", + "data_serializer_factory", + "DiskStorageIO", + "ErrorDataTypeSerializer", + "ImagePathDataTypeSerializer", + "set_message_piece_sha256_async", + "set_seed_sha256_async", + "StorageIO", + "SupportedContentType", + "TextDataTypeSerializer", + "URLDataTypeSerializer", + "VideoPathDataTypeSerializer", +] diff --git a/pyrit/io/serializers.py b/pyrit/io/serializers.py new file mode 100644 index 0000000000..486f3a6f3b --- /dev/null +++ b/pyrit/io/serializers.py @@ -0,0 +1,792 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import abc +import asyncio +import base64 +import hashlib +import time +import wave +from mimetypes import guess_type +from pathlib import Path +from typing import TYPE_CHECKING, Literal, Optional, Union, get_args +from urllib.parse import urlparse + +import aiofiles + +from pyrit.common.deprecation import print_deprecation_message +from pyrit.common.path import DB_DATA_PATH +from pyrit.io.storage import DiskStorageIO, StorageIO + +if TYPE_CHECKING: + from pyrit.memory import MemoryInterface + from pyrit.models.literals import PromptDataType + from pyrit.models.messages.message_piece import MessagePiece + from pyrit.models.seeds.seed import Seed + +# Define allowed categories for validation +AllowedCategories = Literal["seed-prompt-entries", "prompt-memory-entries"] + + +def _write_wav_sync( + path: str, + *, + num_channels: int, + sample_width: int, + sample_rate: int, + data: bytes, +) -> None: + """Write PCM audio bytes to a WAV file synchronously.""" + with wave.open(path, "wb") as wav_file: + wav_file.setnchannels(num_channels) + wav_file.setsampwidth(sample_width) + wav_file.setframerate(sample_rate) + wav_file.writeframes(data) + + +def data_serializer_factory( + *, + data_type: PromptDataType, + value: Optional[str] = None, + extension: Optional[str] = None, + category: AllowedCategories, +) -> DataTypeSerializer: + """ + Create a DataTypeSerializer instance. + + Args: + data_type (str): The type of the data (e.g., 'text', 'image_path', 'audio_path'). + value (str): The data value to be serialized. + extension (Optional[str]): The file extension, if applicable. + category (AllowedCategories): The category or context for the data (e.g., 'seed-prompt-entries'). + + Returns: + DataTypeSerializer: An instance of the appropriate serializer. + + Raises: + ValueError: If the category is not provided or invalid. + + """ + if not category: + raise ValueError( + f"The 'category' argument is mandatory and must be one of the following: {get_args(AllowedCategories)}." + ) + if value is not None: + if data_type in ["text", "reasoning", "function_call", "tool_call", "function_call_output"]: + return TextDataTypeSerializer(prompt_text=value, data_type=data_type) + if data_type == "image_path": + return ImagePathDataTypeSerializer(category=category, prompt_text=value, extension=extension) + if data_type == "audio_path": + return AudioPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) + if data_type == "video_path": + return VideoPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) + if data_type == "binary_path": + return BinaryPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) + if data_type == "error": + return ErrorDataTypeSerializer(prompt_text=value) + if data_type == "url": + return URLDataTypeSerializer(category=category, prompt_text=value, extension=extension) + raise ValueError(f"Data type {data_type} not supported") + if data_type == "image_path": + return ImagePathDataTypeSerializer(category=category, extension=extension) + if data_type == "audio_path": + return AudioPathDataTypeSerializer(category=category, extension=extension) + if data_type == "video_path": + return VideoPathDataTypeSerializer(category=category, extension=extension) + if data_type == "binary_path": + return BinaryPathDataTypeSerializer(category=category, extension=extension) + if data_type == "error": + return ErrorDataTypeSerializer(prompt_text="") + raise ValueError(f"Data type {data_type} without prompt text not supported") + + +class DataTypeSerializer(abc.ABC): + """ + Abstract base class for data type normalizers. + + Responsible for reading and saving multi-modal data types to local disk or Azure Storage Account. + """ + + data_type: PromptDataType + value: str + category: str + data_sub_directory: str + file_extension: str + + _file_path: Union[Path, str] | None = None + + @property + def _memory(self) -> MemoryInterface: + from pyrit.memory import CentralMemory + + return CentralMemory.get_memory_instance() + + def _get_storage_io(self) -> StorageIO: + """ + Retrieve the input datasets storage handle. + + Returns: + StorageIO: An instance of DiskStorageIO or AzureBlobStorageIO based on the storage configuration. + + Raises: + ValueError: If the Azure Storage URL is detected but the datasets storage handle is not set. + RuntimeError: If results_storage_io is not configured but Azure storage URL was detected. + + """ + if self._is_azure_storage_url(self.value): + # Scenarios where a user utilizes an in-memory DuckDB but also needs to interact + # with an Azure Storage Account, ex., XPIAWorkflow. + if self._memory.results_storage_io is None: + raise RuntimeError("results_storage_io is not configured but Azure storage URL was detected") + return self._memory.results_storage_io + return DiskStorageIO() + + @abc.abstractmethod + def data_on_disk(self) -> bool: + """ + Indicate whether the data is stored on disk. + + Returns: + bool: True when data is persisted on disk. + + """ + + async def save_data_async(self, data: bytes, output_filename: Optional[str] = None) -> None: + """ + Save data to storage. + + Arguments: + data: bytes: The data to be saved. + output_filename (optional, str): filename to store data as. Defaults to UUID if not provided + + Raises: + RuntimeError: If storage IO is not initialized. + """ + file_path = await self.get_data_filename_async(file_name=output_filename) + if self._memory.results_storage_io is None: + raise RuntimeError("Storage IO not initialized") + await self._memory.results_storage_io.write_file_async(file_path, data) + self.value = str(file_path) + + async def save_b64_image_async(self, data: str | bytes, output_filename: str | None = None) -> None: + """ + Save a base64-encoded image to storage. + + Arguments: + data: string or bytes with base64 data + output_filename (optional, str): filename to store image as. Defaults to UUID if not provided + + Raises: + RuntimeError: If storage IO is not initialized. + """ + file_path = await self.get_data_filename_async(file_name=output_filename) + image_bytes = base64.b64decode(data) + if self._memory.results_storage_io is None: + raise RuntimeError("Storage IO not initialized") + await self._memory.results_storage_io.write_file_async(file_path, image_bytes) + self.value = str(file_path) + + async def save_formatted_audio_async( + self, + data: bytes, + num_channels: int = 1, + sample_width: int = 2, + sample_rate: int = 16000, + output_filename: Optional[str] = None, + ) -> None: + """ + Save PCM16 or similarly formatted audio data to storage. + + Arguments: + data: bytes with audio data + output_filename (optional, str): filename to store audio as. Defaults to UUID if not provided + num_channels (optional, int): number of channels in audio data. Defaults to 1 + sample_width (optional, int): sample width in bytes. Defaults to 2 + sample_rate (optional, int): sample rate in Hz. Defaults to 16000 + + Raises: + RuntimeError: If storage IO is not initialized. + """ + file_path = await self.get_data_filename_async(file_name=output_filename) + + # save audio file locally first if in AzureStorageBlob so we can use wave.open to set audio parameters + if self._is_azure_storage_url(str(file_path)): + local_temp_path = Path(DB_DATA_PATH, "temp_audio.wav") + await asyncio.to_thread( + _write_wav_sync, + str(local_temp_path), + num_channels=num_channels, + sample_width=sample_width, + sample_rate=sample_rate, + data=data, + ) + + async with aiofiles.open(local_temp_path, "rb") as f: + audio_data = await f.read() + if self._memory.results_storage_io is None: + raise RuntimeError("self._memory.results_storage_io is not initialized") + await self._memory.results_storage_io.write_file_async(file_path, audio_data) + local_temp_path.unlink() + + # If local, we can just save straight to disk and do not need to delete temp file after + else: + await asyncio.to_thread( + _write_wav_sync, + str(file_path), + num_channels=num_channels, + sample_width=sample_width, + sample_rate=sample_rate, + data=data, + ) + + self.value = str(file_path) + + async def read_data_async(self) -> bytes: + """ + Read data from storage. + + Returns: + bytes: The data read from storage. + + Raises: + TypeError: If the serializer does not represent on-disk data. + RuntimeError: If no value is set. + FileNotFoundError: If the referenced file does not exist. + + """ + if not self.data_on_disk(): + raise TypeError(f"Data for data Type {self.data_type} is not stored on disk") + + if not self.value: + raise RuntimeError("Prompt text not set") + + storage_io = self._get_storage_io() + # Check if path exists + file_exists = await storage_io.path_exists_async(path=self.value) + if not file_exists: + raise FileNotFoundError(f"File not found: {self.value}") + # Read the contents from the path + return await storage_io.read_file_async(self.value) + + async def read_data_base64_async(self) -> str: + """ + Read data from storage and return it as a base64 string. + + Returns: + str: Base64-encoded data. + + """ + byte_array = await self.read_data_async() + return base64.b64encode(byte_array).decode("utf-8") + + async def get_sha256_async(self) -> str: + """ + Compute SHA256 hash for this serializer's current value. + + Returns: + str: Hex digest of the computed SHA256 hash. + + Raises: + FileNotFoundError: If on-disk data path does not exist. + ValueError: If in-memory data cannot be converted to bytes. + + """ + input_bytes: bytes | None = None + + if self.data_on_disk(): + storage_io = self._get_storage_io() + file_exists = await storage_io.path_exists_async(self.value) + if not file_exists: + raise FileNotFoundError(f"File not found: {self.value}") + + # Read the data from storage + input_bytes = await storage_io.read_file_async(self.value) + else: + if isinstance(self.value, str): + input_bytes = self.value.encode("utf-8") + else: + raise ValueError(f"Invalid data type {self.value}, expected str data type.") + + hash_object = hashlib.sha256(input_bytes) + return hash_object.hexdigest() + + async def get_data_filename_async(self, file_name: Optional[str] = None) -> Union[Path, str]: + """ + Generate or retrieve a unique filename for the data file. + + Args: + file_name (Optional[str]): Optional file name override. + + Returns: + Union[Path, str]: Full storage path for the generated data file. + + Raises: + TypeError: If the serializer is not configured for on-disk data. + RuntimeError: If required data subdirectory information is missing. + + """ + if self._file_path: + return self._file_path + + if not self.data_on_disk(): + raise TypeError("Data is not stored on disk") + + if not self.data_sub_directory: + raise RuntimeError("Data sub directory not set") + + ticks = int(time.time() * 1_000_000) + if self._memory.results_path: + results_path = str(self._memory.results_path) + else: + from pyrit.common.path import DB_DATA_PATH + + results_path = str(DB_DATA_PATH) + file_name = file_name if file_name else str(ticks) + + if self._is_azure_storage_url(results_path): + full_data_directory_path = results_path + self.data_sub_directory + self._file_path = full_data_directory_path + f"/{file_name}.{self.file_extension}" + else: + full_data_directory_path = results_path + self.data_sub_directory + if self._memory.results_storage_io is None: + raise RuntimeError("self._memory.results_storage_io is not initialized") + await self._memory.results_storage_io.create_directory_if_not_exists_async(Path(full_data_directory_path)) + self._file_path = Path(full_data_directory_path, f"{file_name}.{self.file_extension}") + + return self._file_path + + async def save_data( # pyrit-async-suffix-exempt + self, data: bytes, output_filename: Optional[str] = None + ) -> None: + """ + Save data to storage (deprecated alias of ``save_data_async``). + + Args: + data: The data to be saved. + output_filename: Optional filename to store data as. + """ + print_deprecation_message( + old_item="pyrit.io.serializers.DataTypeSerializer.save_data", + new_item="pyrit.io.serializers.DataTypeSerializer.save_data_async", + removed_in="0.16.0", + ) + await self.save_data_async(data, output_filename) + + async def save_b64_image( # pyrit-async-suffix-exempt + self, data: str | bytes, output_filename: str | None = None + ) -> None: + """ + Save a base64-encoded image to storage (deprecated alias of ``save_b64_image_async``). + + Args: + data: String or bytes with base64 data. + output_filename: Optional filename to store image as. + """ + print_deprecation_message( + old_item="pyrit.io.serializers.DataTypeSerializer.save_b64_image", + new_item="pyrit.io.serializers.DataTypeSerializer.save_b64_image_async", + removed_in="0.16.0", + ) + await self.save_b64_image_async(data, output_filename) + + async def save_formatted_audio( # pyrit-async-suffix-exempt + self, + data: bytes, + num_channels: int = 1, + sample_width: int = 2, + sample_rate: int = 16000, + output_filename: Optional[str] = None, + ) -> None: + """ + Save formatted audio data to storage (deprecated alias of ``save_formatted_audio_async``). + + Args: + data: Audio data bytes. + num_channels: Number of channels in audio data. + sample_width: Sample width in bytes. + sample_rate: Sample rate in Hz. + output_filename: Optional filename to store audio as. + """ + print_deprecation_message( + old_item="pyrit.io.serializers.DataTypeSerializer.save_formatted_audio", + new_item="pyrit.io.serializers.DataTypeSerializer.save_formatted_audio_async", + removed_in="0.16.0", + ) + await self.save_formatted_audio_async(data, num_channels, sample_width, sample_rate, output_filename) + + async def read_data(self) -> bytes: # pyrit-async-suffix-exempt + """ + Read data from storage (deprecated alias of ``read_data_async``). + + Returns: + bytes: The data read from storage. + """ + print_deprecation_message( + old_item="pyrit.io.serializers.DataTypeSerializer.read_data", + new_item="pyrit.io.serializers.DataTypeSerializer.read_data_async", + removed_in="0.16.0", + ) + return await self.read_data_async() + + async def read_data_base64(self) -> str: # pyrit-async-suffix-exempt + """ + Read data and return it as a base64 string (deprecated alias of ``read_data_base64_async``). + + Returns: + str: Base64-encoded data. + """ + print_deprecation_message( + old_item="pyrit.io.serializers.DataTypeSerializer.read_data_base64", + new_item="pyrit.io.serializers.DataTypeSerializer.read_data_base64_async", + removed_in="0.16.0", + ) + return await self.read_data_base64_async() + + async def get_sha256(self) -> str: # pyrit-async-suffix-exempt + """ + Compute SHA256 hash for this serializer's current value (deprecated alias of ``get_sha256_async``). + + Returns: + str: Hex digest of the computed SHA256 hash. + """ + print_deprecation_message( + old_item="pyrit.io.serializers.DataTypeSerializer.get_sha256", + new_item="pyrit.io.serializers.DataTypeSerializer.get_sha256_async", + removed_in="0.16.0", + ) + return await self.get_sha256_async() + + async def get_data_filename( # pyrit-async-suffix-exempt + self, file_name: Optional[str] = None + ) -> Union[Path, str]: + """ + Generate or retrieve a unique filename for the data file (deprecated alias of ``get_data_filename_async``). + + Args: + file_name: Optional file name override. + + Returns: + Union[Path, str]: Full storage path for the generated data file. + """ + print_deprecation_message( + old_item="pyrit.io.serializers.DataTypeSerializer.get_data_filename", + new_item="pyrit.io.serializers.DataTypeSerializer.get_data_filename_async", + removed_in="0.16.0", + ) + return await self.get_data_filename_async(file_name) + + @staticmethod + def get_extension(file_path: str) -> str | None: + """ + Get the file extension from the file path. + + Args: + file_path (str): Input file path. + + Returns: + str | None: File extension (including dot) or None if unavailable. + + """ + ext = Path(file_path).suffix + return ext or None + + @staticmethod + def get_mime_type(file_path: str) -> str | None: + """ + Get the MIME type of the file path. + + Args: + file_path (str): Input file path. + + Returns: + str | None: MIME type if detectable; otherwise None. + + """ + mime_type, _ = guess_type(file_path) + return mime_type + + def _is_azure_storage_url(self, path: str) -> bool: + """ + Validate whether the given path is an Azure Storage URL. + + Args: + path (str): Path or URL to check. + + Returns: + bool: True if the path is an Azure Blob Storage URL. + + """ + parsed = urlparse(path) + return parsed.scheme in ("http", "https") and "blob.core.windows.net" in parsed.netloc + + +class TextDataTypeSerializer(DataTypeSerializer): + """Serializer for text and text-like prompt values that stay in-memory.""" + + def __init__(self, *, prompt_text: str, data_type: PromptDataType = "text") -> None: + """ + Initialize a text serializer. + + Args: + prompt_text (str): Prompt value. + data_type (PromptDataType): Text-like prompt data type. + + """ + self.data_type = data_type + self.value = prompt_text + + def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always False for text serializers. + + """ + return False + + +class ErrorDataTypeSerializer(DataTypeSerializer): + """Serializer for error payloads stored as in-memory text.""" + + def __init__(self, *, prompt_text: str) -> None: + """ + Initialize an error serializer. + + Args: + prompt_text (str): Error payload text. + + """ + self.data_type = "error" + self.value = prompt_text + + def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always False for error serializers. + + """ + return False + + +class URLDataTypeSerializer(DataTypeSerializer): + """Serializer for URL values and URL-backed local file references.""" + + def __init__(self, *, category: str, prompt_text: str, extension: Optional[str] = None) -> None: + """ + Initialize a URL serializer. + + Args: + category (str): Data category folder name. + prompt_text (str): URL or path value. + extension (Optional[str]): Optional extension for persisted content. + + """ + self.data_type = "url" + self.value = prompt_text + self.data_sub_directory = f"/{category}/urls" + self.file_extension = extension if extension else "txt" + self.on_disk = not (prompt_text.startswith(("http://", "https://"))) + + def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: True for non-http values, False for URL values. + + """ + return self.on_disk + + +class ImagePathDataTypeSerializer(DataTypeSerializer): + """Serializer for image path values stored on disk.""" + + def __init__(self, *, category: str, prompt_text: Optional[str] = None, extension: Optional[str] = None) -> None: + """ + Initialize an image-path serializer. + + Args: + category (str): Data category folder name. + prompt_text (Optional[str]): Optional existing image path. + extension (Optional[str]): Optional image extension. + + """ + self.data_type = "image_path" + self.data_sub_directory = f"/{category}/images" + self.file_extension = extension if extension else "png" + + if prompt_text: + self.value = prompt_text + + def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always True for image path serializers. + + """ + return True + + +class AudioPathDataTypeSerializer(DataTypeSerializer): + """Serializer for audio path values stored on disk.""" + + def __init__( + self, + *, + category: str, + prompt_text: Optional[str] = None, + extension: Optional[str] = None, + ) -> None: + """ + Initialize an audio-path serializer. + + Args: + category (str): Data category folder name. + prompt_text (Optional[str]): Optional existing audio path. + extension (Optional[str]): Optional audio extension. + + """ + self.data_type = "audio_path" + self.data_sub_directory = f"/{category}/audio" + self.file_extension = extension if extension else "mp3" + + if prompt_text: + self.value = prompt_text + + def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always True for audio path serializers. + + """ + return True + + +class VideoPathDataTypeSerializer(DataTypeSerializer): + """Serializer for video path values stored on disk.""" + + def __init__( + self, + *, + category: str, + prompt_text: Optional[str] = None, + extension: Optional[str] = None, + ) -> None: + """ + Initialize a video-path serializer. + + Args: + category (str): The category or context for the data. + prompt_text (Optional[str]): The video path or identifier. + extension (Optional[str]): The file extension, defaults to 'mp4'. + + """ + self.data_type = "video_path" + self.data_sub_directory = f"/{category}/videos" + self.file_extension = extension if extension else "mp4" + + if prompt_text: + self.value = prompt_text + + def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always True for video path serializers. + + """ + return True + + +class BinaryPathDataTypeSerializer(DataTypeSerializer): + """Serializer for generic binary path values stored on disk.""" + + def __init__( + self, + *, + category: str, + prompt_text: Optional[str] = None, + extension: Optional[str] = None, + ) -> None: + """ + Initialize a generic binary-path serializer. + + This serializer handles generic binary data that doesn't fit into specific + categories like images, audio, or video. Useful for XPIA attacks and + storing files like PDFs, documents, or other binary formats. + + Args: + category (str): The category or context for the data. + prompt_text (Optional[str]): The binary file path or identifier. + extension (Optional[str]): The file extension, defaults to 'bin'. + + """ + self.data_type = "binary_path" + self.data_sub_directory = f"/{category}/binaries" + self.file_extension = extension if extension else "bin" + + if prompt_text: + self.value = prompt_text + + def data_on_disk(self) -> bool: + """ + Indicate whether this serializer persists data on disk. + + Returns: + bool: Always True for binary path serializers. + + """ + return True + + +async def set_message_piece_sha256_async(message_piece: MessagePiece) -> None: + """ + Compute and assign SHA256 hash values for a message piece's original and converted payloads. + + Async because blob payloads may need to be fetched. Must be called explicitly after + the message piece is constructed and its values are finalized. + + Args: + message_piece (MessagePiece): The message piece to populate with SHA256 values. + """ + original_serializer = data_serializer_factory( + category="prompt-memory-entries", + data_type=message_piece.original_value_data_type, + value=message_piece.original_value, + ) + message_piece.original_value_sha256 = await original_serializer.get_sha256_async() + + converted_serializer = data_serializer_factory( + category="prompt-memory-entries", + data_type=message_piece.converted_value_data_type, + value=message_piece.converted_value, + ) + message_piece.converted_value_sha256 = await converted_serializer.get_sha256_async() + + +async def set_seed_sha256_async(seed: Seed) -> None: + """ + Compute and assign the SHA256 hash value for a seed's value. + + Should be called after the seed ``value`` is serialized to text, as file paths used in + the ``value`` may have changed from local to memory storage paths. Async due to blob retrieval. + + Args: + seed (Seed): The seed to populate with its SHA256 value. + """ + serializer = data_serializer_factory( + category="seed-prompt-entries", + data_type=seed.data_type, + value=seed.value, + ) + seed.value_sha256 = await serializer.get_sha256_async() diff --git a/pyrit/io/storage.py b/pyrit/io/storage.py new file mode 100644 index 0000000000..e42e01f1e1 --- /dev/null +++ b/pyrit/io/storage.py @@ -0,0 +1,507 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Optional, Union +from urllib.parse import urlparse + +import aiofiles + +from pyrit.common.deprecation import print_deprecation_message + +if TYPE_CHECKING: + from azure.storage.blob.aio import ContainerClient as AsyncContainerClient + +logger = logging.getLogger(__name__) + + +class SupportedContentType(Enum): + """ + All supported content types for uploading blobs to provided storage account container. + See all options here: https://www.iana.org/assignments/media-types/media-types.xhtml. + """ + + # TODO, add other media supported types + PLAIN_TEXT = "text/plain" + + +class StorageIO(ABC): + """ + Abstract interface for storage systems (local disk, Azure Storage Account, etc.). + """ + + @abstractmethod + async def read_file_async(self, path: Union[Path, str]) -> bytes: + """ + Asynchronously reads the file (or blob) from the given path. + """ + + @abstractmethod + async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: + """ + Asynchronously writes data to the given path. + """ + + @abstractmethod + async def path_exists_async(self, path: Union[Path, str]) -> bool: + """ + Asynchronously checks if a file or blob exists at the given path. + """ + + @abstractmethod + async def is_file_async(self, path: Union[Path, str]) -> bool: + """ + Asynchronously checks if the path refers to a file (not a directory or container). + """ + + @abstractmethod + async def create_directory_if_not_exists_async(self, path: Union[Path, str]) -> None: + """ + Asynchronously creates a directory or equivalent in the storage system if it doesn't exist. + """ + + async def read_file(self, path: Union[Path, str]) -> bytes: # pyrit-async-suffix-exempt + """ + Read a file from storage (deprecated alias of ``read_file_async``). + + Args: + path (Union[Path, str]): The path to the file. + + Returns: + bytes: The content of the file. + """ + print_deprecation_message( + old_item="pyrit.io.storage.StorageIO.read_file", + new_item="pyrit.io.storage.StorageIO.read_file_async", + removed_in="0.16.0", + ) + return await self.read_file_async(path) + + async def write_file(self, path: Union[Path, str], data: bytes) -> None: # pyrit-async-suffix-exempt + """ + Write data to storage (deprecated alias of ``write_file_async``). + + Args: + path (Union[Path, str]): The path to the file. + data (bytes): The content to write to the file. + """ + print_deprecation_message( + old_item="pyrit.io.storage.StorageIO.write_file", + new_item="pyrit.io.storage.StorageIO.write_file_async", + removed_in="0.16.0", + ) + await self.write_file_async(path, data) + + async def path_exists(self, path: Union[Path, str]) -> bool: # pyrit-async-suffix-exempt + """ + Check whether a path exists (deprecated alias of ``path_exists_async``). + + Args: + path (Union[Path, str]): The path to check. + + Returns: + bool: True if the path exists, False otherwise. + """ + print_deprecation_message( + old_item="pyrit.io.storage.StorageIO.path_exists", + new_item="pyrit.io.storage.StorageIO.path_exists_async", + removed_in="0.16.0", + ) + return await self.path_exists_async(path) + + async def is_file(self, path: Union[Path, str]) -> bool: # pyrit-async-suffix-exempt + """ + Check whether the given path is a file (deprecated alias of ``is_file_async``). + + Args: + path (Union[Path, str]): The path to check. + + Returns: + bool: True if the path is a file, False otherwise. + """ + print_deprecation_message( + old_item="pyrit.io.storage.StorageIO.is_file", + new_item="pyrit.io.storage.StorageIO.is_file_async", + removed_in="0.16.0", + ) + return await self.is_file_async(path) + + async def create_directory_if_not_exists(self, path: Union[Path, str]) -> None: # pyrit-async-suffix-exempt + """ + Create a directory if it does not exist (deprecated alias of ``create_directory_if_not_exists_async``). + + Args: + path (Union[Path, str]): The directory path to create. + """ + print_deprecation_message( + old_item="pyrit.io.storage.StorageIO.create_directory_if_not_exists", + new_item="pyrit.io.storage.StorageIO.create_directory_if_not_exists_async", + removed_in="0.16.0", + ) + await self.create_directory_if_not_exists_async(path) + + +class DiskStorageIO(StorageIO): + """ + Implementation of StorageIO for local disk storage. + """ + + async def read_file_async(self, path: Union[Path, str]) -> bytes: + """ + Asynchronously reads a file from the local disk. + + Args: + path (Union[Path, str]): The path to the file. + + Returns: + bytes: The content of the file. + + """ + path = self._convert_to_path(path) + async with aiofiles.open(path, "rb") as file: + return await file.read() + + async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: + """ + Asynchronously writes data to a file on the local disk. + + Args: + path (Path): The path to the file. + data (bytes): The content to write to the file. + + """ + path = self._convert_to_path(path) + async with aiofiles.open(path, "wb") as file: + await file.write(data) + + async def path_exists_async(self, path: Union[Path, str]) -> bool: + """ + Check whether a path exists on the local disk. + + Args: + path (Path): The path to check. + + Returns: + bool: True if the path exists, False otherwise. + + """ + path = self._convert_to_path(path) + return path.exists() + + async def is_file_async(self, path: Union[Path, str]) -> bool: + """ + Check whether the given path is a file (not a directory). + + Args: + path (Path): The path to check. + + Returns: + bool: True if the path is a file, False otherwise. + + """ + path = self._convert_to_path(path) + return path.is_file() + + async def create_directory_if_not_exists_async(self, path: Union[Path, str]) -> None: + """ + Asynchronously creates a directory if it doesn't exist on the local disk. + + Args: + path (Path): The directory path to create. + + """ + directory_path = self._convert_to_path(path) + if not directory_path.exists(): + directory_path.mkdir(parents=True, exist_ok=True) + + def _convert_to_path(self, path: Union[Path, str]) -> Path: + """ + Convert an input path to a Path object. + + Args: + path (Union[Path, str]): Input path value. + + Returns: + Path: Normalized Path instance. + + """ + return Path(path) if isinstance(path, str) else path + + +class AzureBlobStorageIO(StorageIO): + """ + Implementation of StorageIO for Azure Blob Storage. + """ + + def __init__( + self, + *, + container_url: Optional[str] = None, + sas_token: Optional[str] = None, + blob_content_type: SupportedContentType = SupportedContentType.PLAIN_TEXT, + ) -> None: + """ + Initialize an Azure Blob Storage I/O adapter. + + Args: + container_url (Optional[str]): Azure Blob container URL. + sas_token (Optional[str]): Optional SAS token. + blob_content_type (SupportedContentType): Blob content type for uploads. + + Raises: + ValueError: If container_url is missing. + + """ + self._blob_content_type: str = blob_content_type.value + if not container_url: + raise ValueError("Invalid Azure Storage Account Container URL.") + + self._container_url: str = container_url + self._sas_token = sas_token + self._client_async: AsyncContainerClient | None = None + + async def _create_container_client_async(self) -> AsyncContainerClient: + """ + Create an asynchronous ContainerClient for Azure Storage. + + If a SAS token is provided via the + AZURE_STORAGE_ACCOUNT_SAS_TOKEN environment variable or the init sas_token parameter, it will be used + for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication. + + Returns: + AsyncContainerClient: The initialized container client. + """ + from azure.storage.blob.aio import ContainerClient as AsyncContainerClient + + from pyrit.auth import AzureStorageAuth + + sas_token = self._sas_token + if not self._sas_token: + logger.info("SAS token not provided. Creating a delegation SAS token using Entra ID authentication.") + sas_token = await AzureStorageAuth.get_sas_token_async(self._container_url) + + self._client_async = AsyncContainerClient.from_container_url( + container_url=self._container_url, + credential=sas_token, + ) + return self._client_async + + async def _upload_blob_async(self, file_name: str, data: bytes, content_type: str) -> None: + """ + (Async) Handles uploading blob to given storage container. + + Args: + file_name (str): File name to assign to uploaded blob. + data (bytes): Byte representation of content to upload to container. + content_type (str): Content type to upload. + + Raises: + RuntimeError: If the Azure container client is not initialized. + """ + from azure.core.exceptions import ClientAuthenticationError + from azure.storage.blob import ContentSettings + + content_settings = ContentSettings(content_type=f"{content_type}") + logger.info(msg="\nUploading to Azure Storage as blob:\n\t" + file_name) + + try: + if self._client_async is None: + raise RuntimeError("Azure container client not initialized") + await self._client_async.upload_blob( + name=file_name, + data=data, + content_settings=content_settings, + overwrite=True, + ) + except Exception as exc: + if isinstance(exc, ClientAuthenticationError): + logger.exception( + msg="Authentication failed. Please check that the container existence in the " + "Azure Storage Account and ensure the validity of the provided SAS token. If you " + "haven't set the SAS token as an environment variable use `az login` to " + "enable delegation-based SAS authentication to connect to the storage account" + ) + raise + logger.exception(msg=f"An unexpected error occurred: {exc}") + raise + + def parse_blob_url(self, file_path: str) -> tuple[str, str]: + """ + Parse a blob URL to extract the container and blob name. + + Args: + file_path (str): Full blob URL. + + Returns: + tuple[str, str]: Container name and blob name. + + Raises: + ValueError: If file_path is not a valid blob URL. + + """ + parsed_url = urlparse(file_path) + if parsed_url.scheme and parsed_url.netloc: + container_name = parsed_url.path.split("/")[1] + blob_name = "/".join(parsed_url.path.split("/")[2:]) + return container_name, blob_name + raise ValueError("Invalid blob URL") + + def _resolve_blob_name(self, path: Union[Path, str]) -> str: + """ + Resolve a blob name from either a full blob URL or a relative blob path. + + When a full URL is provided the blob name is extracted from it. The container + name embedded in the URL is intentionally discarded — operations always run + against the container configured in the constructor. + + Backslashes are normalized to forward slashes so that ``Path`` objects + created on Windows still produce valid blob names. + + Args: + path (Union[Path, str]): Blob URL or relative blob path. + + Returns: + str: The resolved blob name. + + """ + path_str = str(path).replace("\\", "/") + try: + # parse_blob_url validates scheme + netloc internally + _, blob_name = self.parse_blob_url(path_str) + return blob_name + except ValueError: + return path_str + + async def read_file_async(self, path: Union[Path, str]) -> bytes: + """ + Asynchronously reads the content of a file (blob) from Azure Blob Storage. + + If the provided ``path`` is a full URL + (e.g., ``https://account.blob.core.windows.net/container/dir1/dir2/sample.png``), + it extracts the relative blob path (e.g., ``dir1/dir2/sample.png``) to correctly access the blob. + If a relative path is provided, it will use it as-is. + + Args: + path (str): The path to the file (blob) in Azure Blob Storage. + This can be either a full URL or a relative path. + + Returns: + bytes: The content of the file (blob) as bytes. + + Example: + ``file_content = await read_file_async("https://account.blob.core.windows.net/container/dir2/1726627689003831.png")`` + + Or using a relative path: + + ``file_content = await read_file_async("dir1/dir2/1726627689003831.png")`` + + """ + if not self._client_async: + self._client_async = await self._create_container_client_async() + + blob_name = self._resolve_blob_name(path) + + try: + blob_client = self._client_async.get_blob_client(blob=blob_name) + + # Download the blob + blob_stream = await blob_client.download_blob() + return bytes(await blob_stream.readall()) + + except Exception as exc: + logger.exception(f"Failed to read file at {blob_name}: {exc}") + raise + finally: + await self._client_async.close() + self._client_async = None + + async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: + """ + Write data to Azure Blob Storage at the specified path. + + If the provided ``path`` is a full URL, the blob name is extracted from it. + If a relative path is provided, it is used as the blob name directly. + + Args: + path (Union[Path, str]): Full blob URL or relative blob path. + data (bytes): The data to write. + """ + if not self._client_async: + self._client_async = await self._create_container_client_async() + blob_name = self._resolve_blob_name(path) + try: + await self._upload_blob_async(file_name=blob_name, data=data, content_type=self._blob_content_type) + except Exception as exc: + logger.exception(f"Failed to write file at {blob_name}: {exc}") + raise + finally: + await self._client_async.close() + self._client_async = None + + async def path_exists_async(self, path: Union[Path, str]) -> bool: + """ + Check whether a given path exists in the Azure Blob Storage container. + + Args: + path (Union[Path, str]): Blob URL or path to test. + + Returns: + bool: True when the path exists. + """ + from azure.core.exceptions import ResourceNotFoundError + + if not self._client_async: + self._client_async = await self._create_container_client_async() + try: + blob_name = self._resolve_blob_name(path) + blob_client = self._client_async.get_blob_client(blob=blob_name) + await blob_client.get_blob_properties() + return True + except ResourceNotFoundError: + return False + finally: + await self._client_async.close() + self._client_async = None + + async def is_file_async(self, path: Union[Path, str]) -> bool: + """ + Check whether the path refers to a file (blob) in Azure Blob Storage. + + Args: + path (Union[Path, str]): Blob URL or path to test. + + Returns: + bool: True when the blob exists and has non-zero content size. + """ + from azure.core.exceptions import ResourceNotFoundError + + if not self._client_async: + self._client_async = await self._create_container_client_async() + try: + blob_name = self._resolve_blob_name(path) + blob_client = self._client_async.get_blob_client(blob=blob_name) + blob_properties = await blob_client.get_blob_properties() + return bool(blob_properties.size > 0) + except ResourceNotFoundError: + return False + finally: + await self._client_async.close() + self._client_async = None + + async def create_directory_if_not_exists_async(self, directory_path: Union[Path, str]) -> None: # type: ignore[ty:invalid-method-override] + """ + Log a no-op directory creation for Azure Blob Storage. + + Args: + directory_path (Union[Path, str]): Requested directory path. + + """ + logger.info( + f"Directory creation is handled automatically during upload operations in Azure Blob Storage. " + f"Directory path: {directory_path}" + ) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 6723ae2842..ce4050040f 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -19,6 +19,7 @@ from pyrit.auth.azure_auth import AzureAuth from pyrit.common import default_values from pyrit.common.singleton import Singleton +from pyrit.io import AzureBlobStorageIO from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.memory_models import ( AttackResultEntry, @@ -26,11 +27,7 @@ EmbeddingDataEntry, PromptMemoryEntry, ) -from pyrit.models import ( - AzureBlobStorageIO, - ConversationStats, - MessagePiece, -) +from pyrit.models import ConversationStats, MessagePiece if TYPE_CHECKING: from azure.core.credentials import AccessToken diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 26448f5b6c..32f6661d2a 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -23,6 +23,12 @@ from pyrit.common.deprecation import print_deprecation_message from pyrit.common.path import DB_DATA_PATH +from pyrit.io import ( + DataTypeSerializer, + StorageIO, + data_serializer_factory, + set_seed_sha256_async, +) from pyrit.memory.memory_exporter import MemoryExporter from pyrit.memory.memory_models import ( AttackResultEntry, @@ -36,7 +42,6 @@ from pyrit.models import ( AttackResult, ConversationStats, - DataTypeSerializer, IdentifierFilter, IdentifierType, Message, @@ -47,8 +52,6 @@ SeedDataset, SeedGroup, SeedType, - StorageIO, - data_serializer_factory, group_conversation_message_pieces_by_sequence, sort_message_pieces, ) @@ -1397,7 +1400,7 @@ async def add_seeds_to_memory_async(self, *, seeds: Sequence[Seed], added_by: Op serialized_prompt_value = await self._serialize_seed_value_async(prompt=prompt) prompt.value = serialized_prompt_value - await prompt.set_sha256_value_async() + await set_seed_sha256_async(prompt) if prompt.value_sha256 and not self.get_seeds( value_sha256=[prompt.value_sha256], dataset_name=prompt.dataset_name diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 461d2b871b..57202ffec2 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -21,6 +21,7 @@ from pyrit.common.deprecation import print_deprecation_message from pyrit.common.path import DB_DATA_PATH from pyrit.common.singleton import Singleton +from pyrit.io import DiskStorageIO from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.memory_models import ( AttackResultEntry, @@ -29,7 +30,7 @@ PromptMemoryEntry, ScenarioResultEntry, ) -from pyrit.models import ConversationStats, DiskStorageIO, MessagePiece +from pyrit.models import ConversationStats, MessagePiece logger = logging.getLogger(__name__) diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 4fc11fbd0a..09eef1fdeb 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -9,13 +9,14 @@ import aiofiles from pyrit.common.data_url_converter import convert_local_image_to_data_url_async +from pyrit.io import DataTypeSerializer from pyrit.message_normalizer.message_normalizer import ( MessageListNormalizer, MessageStringNormalizer, SystemMessageBehavior, apply_system_message_behavior_async, ) -from pyrit.models import ChatMessage, DataTypeSerializer, Message +from pyrit.models import ChatMessage, Message from pyrit.models.messages.message_piece import MessagePiece if TYPE_CHECKING: diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index f3256e5c98..3a6a0a49bf 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -17,6 +17,7 @@ a deprecation shim through ``0.16.0``. """ +import importlib from typing import TYPE_CHECKING, Any from pyrit.common.deprecation import print_deprecation_message @@ -27,17 +28,6 @@ ) from pyrit.models.conversation_reference import ConversationReference, ConversationType from pyrit.models.conversation_stats import ConversationStats -from pyrit.models.data_type_serializer import ( - AllowedCategories, - AudioPathDataTypeSerializer, - BinaryPathDataTypeSerializer, - DataTypeSerializer, - ErrorDataTypeSerializer, - ImagePathDataTypeSerializer, - TextDataTypeSerializer, - VideoPathDataTypeSerializer, - data_serializer_factory, -) from pyrit.models.embeddings import EmbeddingData, EmbeddingResponse, EmbeddingSupport, EmbeddingUsageInformation from pyrit.models.harm_definition import HarmDefinition, ScaleDescription, get_all_harm_definitions from pyrit.models.identifiers import ( @@ -101,10 +91,6 @@ SimulatedTargetSystemPromptPaths, ) -# Keep old module-level imports working (deprecated, will be removed) -# These are re-exported from the seeds submodule -from pyrit.models.storage_io import AzureBlobStorageIO, DiskStorageIO, StorageIO - __all__ = [ "ALLOWED_CHAT_MESSAGE_ROLES", "AllowedCategories", @@ -202,6 +188,24 @@ "ScorerIdentifier": ComponentIdentifier, } +# Names that moved to ``pyrit.io`` in Phase 9. Served lazily via importlib so that +# importing ``pyrit.models`` stays import-boundary clean and fires no warning until a +# moved name is actually accessed. Will be removed in 0.17.0. +_MOVED_TO_PYRIT_IO: dict[str, str] = { + "AllowedCategories": "pyrit.io.serializers", + "AudioPathDataTypeSerializer": "pyrit.io.serializers", + "BinaryPathDataTypeSerializer": "pyrit.io.serializers", + "DataTypeSerializer": "pyrit.io.serializers", + "ErrorDataTypeSerializer": "pyrit.io.serializers", + "ImagePathDataTypeSerializer": "pyrit.io.serializers", + "TextDataTypeSerializer": "pyrit.io.serializers", + "VideoPathDataTypeSerializer": "pyrit.io.serializers", + "data_serializer_factory": "pyrit.io.serializers", + "AzureBlobStorageIO": "pyrit.io.storage", + "DiskStorageIO": "pyrit.io.storage", + "StorageIO": "pyrit.io.storage", +} + _warned: set[str] = set() @@ -216,4 +220,14 @@ def __getattr__(name: str) -> Any: ) _warned.add(name) return target + if name in _MOVED_TO_PYRIT_IO: + target_module = _MOVED_TO_PYRIT_IO[name] + if name not in _warned: + print_deprecation_message( + old_item=f"{__name__}.{name}", + new_item=f"{target_module}.{name}", + removed_in="0.17.0", + ) + _warned.add(name) + return getattr(importlib.import_module(target_module), name) raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index 3ee43eed62..369cc405d5 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -1,747 +1,38 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from __future__ import annotations - -import abc -import asyncio -import base64 -import hashlib -import time -import wave -from mimetypes import guess_type -from pathlib import Path -from typing import TYPE_CHECKING, Literal, Optional, Union, get_args -from urllib.parse import urlparse - -import aiofiles - -from pyrit.common.deprecation import print_deprecation_message -from pyrit.common.path import DB_DATA_PATH -from pyrit.models.storage_io import DiskStorageIO, StorageIO - -if TYPE_CHECKING: - from pyrit.memory import MemoryInterface - from pyrit.models.literals import PromptDataType - -# Define allowed categories for validation -AllowedCategories = Literal["seed-prompt-entries", "prompt-memory-entries"] - - -def _write_wav_sync( - path: str, - *, - num_channels: int, - sample_width: int, - sample_rate: int, - data: bytes, -) -> None: - """Write PCM audio bytes to a WAV file synchronously.""" - with wave.open(path, "wb") as wav_file: - wav_file.setnchannels(num_channels) - wav_file.setsampwidth(sample_width) - wav_file.setframerate(sample_rate) - wav_file.writeframes(data) - - -def data_serializer_factory( - *, - data_type: PromptDataType, - value: Optional[str] = None, - extension: Optional[str] = None, - category: AllowedCategories, -) -> DataTypeSerializer: - """ - Create a DataTypeSerializer instance. - - Args: - data_type (str): The type of the data (e.g., 'text', 'image_path', 'audio_path'). - value (str): The data value to be serialized. - extension (Optional[str]): The file extension, if applicable. - category (AllowedCategories): The category or context for the data (e.g., 'seed-prompt-entries'). - - Returns: - DataTypeSerializer: An instance of the appropriate serializer. - - Raises: - ValueError: If the category is not provided or invalid. - - """ - if not category: - raise ValueError( - f"The 'category' argument is mandatory and must be one of the following: {get_args(AllowedCategories)}." - ) - if value is not None: - if data_type in ["text", "reasoning", "function_call", "tool_call", "function_call_output"]: - return TextDataTypeSerializer(prompt_text=value, data_type=data_type) - if data_type == "image_path": - return ImagePathDataTypeSerializer(category=category, prompt_text=value, extension=extension) - if data_type == "audio_path": - return AudioPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) - if data_type == "video_path": - return VideoPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) - if data_type == "binary_path": - return BinaryPathDataTypeSerializer(category=category, prompt_text=value, extension=extension) - if data_type == "error": - return ErrorDataTypeSerializer(prompt_text=value) - if data_type == "url": - return URLDataTypeSerializer(category=category, prompt_text=value, extension=extension) - raise ValueError(f"Data type {data_type} not supported") - if data_type == "image_path": - return ImagePathDataTypeSerializer(category=category, extension=extension) - if data_type == "audio_path": - return AudioPathDataTypeSerializer(category=category, extension=extension) - if data_type == "video_path": - return VideoPathDataTypeSerializer(category=category, extension=extension) - if data_type == "binary_path": - return BinaryPathDataTypeSerializer(category=category, extension=extension) - if data_type == "error": - return ErrorDataTypeSerializer(prompt_text="") - raise ValueError(f"Data type {data_type} without prompt text not supported") - - -class DataTypeSerializer(abc.ABC): - """ - Abstract base class for data type normalizers. - - Responsible for reading and saving multi-modal data types to local disk or Azure Storage Account. - """ - - data_type: PromptDataType - value: str - category: str - data_sub_directory: str - file_extension: str - - _file_path: Union[Path, str] | None = None - - @property - def _memory(self) -> MemoryInterface: - from pyrit.memory import CentralMemory - - return CentralMemory.get_memory_instance() - - def _get_storage_io(self) -> StorageIO: - """ - Retrieve the input datasets storage handle. - - Returns: - StorageIO: An instance of DiskStorageIO or AzureBlobStorageIO based on the storage configuration. - - Raises: - ValueError: If the Azure Storage URL is detected but the datasets storage handle is not set. - RuntimeError: If results_storage_io is not configured but Azure storage URL was detected. - - """ - if self._is_azure_storage_url(self.value): - # Scenarios where a user utilizes an in-memory DuckDB but also needs to interact - # with an Azure Storage Account, ex., XPIAWorkflow. - if self._memory.results_storage_io is None: - raise RuntimeError("results_storage_io is not configured but Azure storage URL was detected") - return self._memory.results_storage_io - return DiskStorageIO() - - @abc.abstractmethod - def data_on_disk(self) -> bool: - """ - Indicate whether the data is stored on disk. - - Returns: - bool: True when data is persisted on disk. - - """ - - async def save_data_async(self, data: bytes, output_filename: Optional[str] = None) -> None: - """ - Save data to storage. - - Arguments: - data: bytes: The data to be saved. - output_filename (optional, str): filename to store data as. Defaults to UUID if not provided - - Raises: - RuntimeError: If storage IO is not initialized. - """ - file_path = await self.get_data_filename_async(file_name=output_filename) - if self._memory.results_storage_io is None: - raise RuntimeError("Storage IO not initialized") - await self._memory.results_storage_io.write_file_async(file_path, data) - self.value = str(file_path) - - async def save_b64_image_async(self, data: str | bytes, output_filename: str | None = None) -> None: - """ - Save a base64-encoded image to storage. - - Arguments: - data: string or bytes with base64 data - output_filename (optional, str): filename to store image as. Defaults to UUID if not provided - - Raises: - RuntimeError: If storage IO is not initialized. - """ - file_path = await self.get_data_filename_async(file_name=output_filename) - image_bytes = base64.b64decode(data) - if self._memory.results_storage_io is None: - raise RuntimeError("Storage IO not initialized") - await self._memory.results_storage_io.write_file_async(file_path, image_bytes) - self.value = str(file_path) - - async def save_formatted_audio_async( - self, - data: bytes, - num_channels: int = 1, - sample_width: int = 2, - sample_rate: int = 16000, - output_filename: Optional[str] = None, - ) -> None: - """ - Save PCM16 or similarly formatted audio data to storage. - - Arguments: - data: bytes with audio data - output_filename (optional, str): filename to store audio as. Defaults to UUID if not provided - num_channels (optional, int): number of channels in audio data. Defaults to 1 - sample_width (optional, int): sample width in bytes. Defaults to 2 - sample_rate (optional, int): sample rate in Hz. Defaults to 16000 - - Raises: - RuntimeError: If storage IO is not initialized. - """ - file_path = await self.get_data_filename_async(file_name=output_filename) - - # save audio file locally first if in AzureStorageBlob so we can use wave.open to set audio parameters - if self._is_azure_storage_url(str(file_path)): - local_temp_path = Path(DB_DATA_PATH, "temp_audio.wav") - await asyncio.to_thread( - _write_wav_sync, - str(local_temp_path), - num_channels=num_channels, - sample_width=sample_width, - sample_rate=sample_rate, - data=data, - ) - - async with aiofiles.open(local_temp_path, "rb") as f: - audio_data = await f.read() - if self._memory.results_storage_io is None: - raise RuntimeError("self._memory.results_storage_io is not initialized") - await self._memory.results_storage_io.write_file_async(file_path, audio_data) - local_temp_path.unlink() - - # If local, we can just save straight to disk and do not need to delete temp file after - else: - await asyncio.to_thread( - _write_wav_sync, - str(file_path), - num_channels=num_channels, - sample_width=sample_width, - sample_rate=sample_rate, - data=data, - ) - - self.value = str(file_path) - - async def read_data_async(self) -> bytes: - """ - Read data from storage. - - Returns: - bytes: The data read from storage. - - Raises: - TypeError: If the serializer does not represent on-disk data. - RuntimeError: If no value is set. - FileNotFoundError: If the referenced file does not exist. - - """ - if not self.data_on_disk(): - raise TypeError(f"Data for data Type {self.data_type} is not stored on disk") - - if not self.value: - raise RuntimeError("Prompt text not set") - - storage_io = self._get_storage_io() - # Check if path exists - file_exists = await storage_io.path_exists_async(path=self.value) - if not file_exists: - raise FileNotFoundError(f"File not found: {self.value}") - # Read the contents from the path - return await storage_io.read_file_async(self.value) - - async def read_data_base64_async(self) -> str: - """ - Read data from storage and return it as a base64 string. - - Returns: - str: Base64-encoded data. - - """ - byte_array = await self.read_data_async() - return base64.b64encode(byte_array).decode("utf-8") - - async def get_sha256_async(self) -> str: - """ - Compute SHA256 hash for this serializer's current value. - - Returns: - str: Hex digest of the computed SHA256 hash. - - Raises: - FileNotFoundError: If on-disk data path does not exist. - ValueError: If in-memory data cannot be converted to bytes. - - """ - input_bytes: bytes | None = None - - if self.data_on_disk(): - storage_io = self._get_storage_io() - file_exists = await storage_io.path_exists_async(self.value) - if not file_exists: - raise FileNotFoundError(f"File not found: {self.value}") - - # Read the data from storage - input_bytes = await storage_io.read_file_async(self.value) - else: - if isinstance(self.value, str): - input_bytes = self.value.encode("utf-8") - else: - raise ValueError(f"Invalid data type {self.value}, expected str data type.") - - hash_object = hashlib.sha256(input_bytes) - return hash_object.hexdigest() - - async def get_data_filename_async(self, file_name: Optional[str] = None) -> Union[Path, str]: - """ - Generate or retrieve a unique filename for the data file. - - Args: - file_name (Optional[str]): Optional file name override. - - Returns: - Union[Path, str]: Full storage path for the generated data file. - - Raises: - TypeError: If the serializer is not configured for on-disk data. - RuntimeError: If required data subdirectory information is missing. - - """ - if self._file_path: - return self._file_path - - if not self.data_on_disk(): - raise TypeError("Data is not stored on disk") - - if not self.data_sub_directory: - raise RuntimeError("Data sub directory not set") - - ticks = int(time.time() * 1_000_000) - if self._memory.results_path: - results_path = str(self._memory.results_path) - else: - from pyrit.common.path import DB_DATA_PATH - - results_path = str(DB_DATA_PATH) - file_name = file_name if file_name else str(ticks) - - if self._is_azure_storage_url(results_path): - full_data_directory_path = results_path + self.data_sub_directory - self._file_path = full_data_directory_path + f"/{file_name}.{self.file_extension}" - else: - full_data_directory_path = results_path + self.data_sub_directory - if self._memory.results_storage_io is None: - raise RuntimeError("self._memory.results_storage_io is not initialized") - await self._memory.results_storage_io.create_directory_if_not_exists_async(Path(full_data_directory_path)) - self._file_path = Path(full_data_directory_path, f"{file_name}.{self.file_extension}") - - return self._file_path - - async def save_data( # pyrit-async-suffix-exempt - self, data: bytes, output_filename: Optional[str] = None - ) -> None: - """ - Save data to storage (deprecated alias of ``save_data_async``). - - Args: - data: The data to be saved. - output_filename: Optional filename to store data as. - """ - print_deprecation_message( - old_item="pyrit.models.data_type_serializer.DataTypeSerializer.save_data", - new_item="pyrit.models.data_type_serializer.DataTypeSerializer.save_data_async", - removed_in="0.16.0", - ) - await self.save_data_async(data, output_filename) - - async def save_b64_image( # pyrit-async-suffix-exempt - self, data: str | bytes, output_filename: str | None = None - ) -> None: - """ - Save a base64-encoded image to storage (deprecated alias of ``save_b64_image_async``). - - Args: - data: String or bytes with base64 data. - output_filename: Optional filename to store image as. - """ - print_deprecation_message( - old_item="pyrit.models.data_type_serializer.DataTypeSerializer.save_b64_image", - new_item="pyrit.models.data_type_serializer.DataTypeSerializer.save_b64_image_async", - removed_in="0.16.0", - ) - await self.save_b64_image_async(data, output_filename) - - async def save_formatted_audio( # pyrit-async-suffix-exempt - self, - data: bytes, - num_channels: int = 1, - sample_width: int = 2, - sample_rate: int = 16000, - output_filename: Optional[str] = None, - ) -> None: - """ - Save formatted audio data to storage (deprecated alias of ``save_formatted_audio_async``). - - Args: - data: Audio data bytes. - num_channels: Number of channels in audio data. - sample_width: Sample width in bytes. - sample_rate: Sample rate in Hz. - output_filename: Optional filename to store audio as. - """ - print_deprecation_message( - old_item="pyrit.models.data_type_serializer.DataTypeSerializer.save_formatted_audio", - new_item="pyrit.models.data_type_serializer.DataTypeSerializer.save_formatted_audio_async", - removed_in="0.16.0", - ) - await self.save_formatted_audio_async(data, num_channels, sample_width, sample_rate, output_filename) - - async def read_data(self) -> bytes: # pyrit-async-suffix-exempt - """ - Read data from storage (deprecated alias of ``read_data_async``). - - Returns: - bytes: The data read from storage. - """ - print_deprecation_message( - old_item="pyrit.models.data_type_serializer.DataTypeSerializer.read_data", - new_item="pyrit.models.data_type_serializer.DataTypeSerializer.read_data_async", - removed_in="0.16.0", - ) - return await self.read_data_async() - - async def read_data_base64(self) -> str: # pyrit-async-suffix-exempt - """ - Read data and return it as a base64 string (deprecated alias of ``read_data_base64_async``). +""" +Deprecation shim — the data-type serializers moved to ``pyrit.io.serializers``. - Returns: - str: Base64-encoded data. - """ - print_deprecation_message( - old_item="pyrit.models.data_type_serializer.DataTypeSerializer.read_data_base64", - new_item="pyrit.models.data_type_serializer.DataTypeSerializer.read_data_base64_async", - removed_in="0.16.0", - ) - return await self.read_data_base64_async() +Importing names from ``pyrit.models.data_type_serializer`` still works for one +release but emits a one-time ``DeprecationWarning`` per name. Import from +``pyrit.io`` instead. This shim will be removed in 0.17.0. +""" - async def get_sha256(self) -> str: # pyrit-async-suffix-exempt - """ - Compute SHA256 hash for this serializer's current value (deprecated alias of ``get_sha256_async``). - - Returns: - str: Hex digest of the computed SHA256 hash. - """ - print_deprecation_message( - old_item="pyrit.models.data_type_serializer.DataTypeSerializer.get_sha256", - new_item="pyrit.models.data_type_serializer.DataTypeSerializer.get_sha256_async", - removed_in="0.16.0", - ) - return await self.get_sha256_async() - - async def get_data_filename( # pyrit-async-suffix-exempt - self, file_name: Optional[str] = None - ) -> Union[Path, str]: - """ - Generate or retrieve a unique filename for the data file (deprecated alias of ``get_data_filename_async``). - - Args: - file_name: Optional file name override. - - Returns: - Union[Path, str]: Full storage path for the generated data file. - """ - print_deprecation_message( - old_item="pyrit.models.data_type_serializer.DataTypeSerializer.get_data_filename", - new_item="pyrit.models.data_type_serializer.DataTypeSerializer.get_data_filename_async", - removed_in="0.16.0", - ) - return await self.get_data_filename_async(file_name) - - @staticmethod - def get_extension(file_path: str) -> str | None: - """ - Get the file extension from the file path. - - Args: - file_path (str): Input file path. - - Returns: - str | None: File extension (including dot) or None if unavailable. - - """ - ext = Path(file_path).suffix - return ext or None - - @staticmethod - def get_mime_type(file_path: str) -> str | None: - """ - Get the MIME type of the file path. - - Args: - file_path (str): Input file path. - - Returns: - str | None: MIME type if detectable; otherwise None. - - """ - mime_type, _ = guess_type(file_path) - return mime_type - - def _is_azure_storage_url(self, path: str) -> bool: - """ - Validate whether the given path is an Azure Storage URL. - - Args: - path (str): Path or URL to check. - - Returns: - bool: True if the path is an Azure Blob Storage URL. - - """ - parsed = urlparse(path) - return parsed.scheme in ("http", "https") and "blob.core.windows.net" in parsed.netloc - - -class TextDataTypeSerializer(DataTypeSerializer): - """Serializer for text and text-like prompt values that stay in-memory.""" - - def __init__(self, *, prompt_text: str, data_type: PromptDataType = "text") -> None: - """ - Initialize a text serializer. - - Args: - prompt_text (str): Prompt value. - data_type (PromptDataType): Text-like prompt data type. - - """ - self.data_type = data_type - self.value = prompt_text - - def data_on_disk(self) -> bool: - """ - Indicate whether this serializer persists data on disk. - - Returns: - bool: Always False for text serializers. - - """ - return False - - -class ErrorDataTypeSerializer(DataTypeSerializer): - """Serializer for error payloads stored as in-memory text.""" - - def __init__(self, *, prompt_text: str) -> None: - """ - Initialize an error serializer. - - Args: - prompt_text (str): Error payload text. - - """ - self.data_type = "error" - self.value = prompt_text - - def data_on_disk(self) -> bool: - """ - Indicate whether this serializer persists data on disk. - - Returns: - bool: Always False for error serializers. - - """ - return False - - -class URLDataTypeSerializer(DataTypeSerializer): - """Serializer for URL values and URL-backed local file references.""" - - def __init__(self, *, category: str, prompt_text: str, extension: Optional[str] = None) -> None: - """ - Initialize a URL serializer. - - Args: - category (str): Data category folder name. - prompt_text (str): URL or path value. - extension (Optional[str]): Optional extension for persisted content. - - """ - self.data_type = "url" - self.value = prompt_text - self.data_sub_directory = f"/{category}/urls" - self.file_extension = extension if extension else "txt" - self.on_disk = not (prompt_text.startswith(("http://", "https://"))) - - def data_on_disk(self) -> bool: - """ - Indicate whether this serializer persists data on disk. - - Returns: - bool: True for non-http values, False for URL values. - - """ - return self.on_disk - - -class ImagePathDataTypeSerializer(DataTypeSerializer): - """Serializer for image path values stored on disk.""" - - def __init__(self, *, category: str, prompt_text: Optional[str] = None, extension: Optional[str] = None) -> None: - """ - Initialize an image-path serializer. - - Args: - category (str): Data category folder name. - prompt_text (Optional[str]): Optional existing image path. - extension (Optional[str]): Optional image extension. - - """ - self.data_type = "image_path" - self.data_sub_directory = f"/{category}/images" - self.file_extension = extension if extension else "png" - - if prompt_text: - self.value = prompt_text - - def data_on_disk(self) -> bool: - """ - Indicate whether this serializer persists data on disk. - - Returns: - bool: Always True for image path serializers. - - """ - return True - - -class AudioPathDataTypeSerializer(DataTypeSerializer): - """Serializer for audio path values stored on disk.""" - - def __init__( - self, - *, - category: str, - prompt_text: Optional[str] = None, - extension: Optional[str] = None, - ) -> None: - """ - Initialize an audio-path serializer. - - Args: - category (str): Data category folder name. - prompt_text (Optional[str]): Optional existing audio path. - extension (Optional[str]): Optional audio extension. - - """ - self.data_type = "audio_path" - self.data_sub_directory = f"/{category}/audio" - self.file_extension = extension if extension else "mp3" - - if prompt_text: - self.value = prompt_text - - def data_on_disk(self) -> bool: - """ - Indicate whether this serializer persists data on disk. - - Returns: - bool: Always True for audio path serializers. - - """ - return True - - -class VideoPathDataTypeSerializer(DataTypeSerializer): - """Serializer for video path values stored on disk.""" - - def __init__( - self, - *, - category: str, - prompt_text: Optional[str] = None, - extension: Optional[str] = None, - ) -> None: - """ - Initialize a video-path serializer. - - Args: - category (str): The category or context for the data. - prompt_text (Optional[str]): The video path or identifier. - extension (Optional[str]): The file extension, defaults to 'mp4'. - - """ - self.data_type = "video_path" - self.data_sub_directory = f"/{category}/videos" - self.file_extension = extension if extension else "mp4" - - if prompt_text: - self.value = prompt_text - - def data_on_disk(self) -> bool: - """ - Indicate whether this serializer persists data on disk. - - Returns: - bool: Always True for video path serializers. - - """ - return True - - -class BinaryPathDataTypeSerializer(DataTypeSerializer): - """Serializer for generic binary path values stored on disk.""" - - def __init__( - self, - *, - category: str, - prompt_text: Optional[str] = None, - extension: Optional[str] = None, - ) -> None: - """ - Initialize a generic binary-path serializer. - - This serializer handles generic binary data that doesn't fit into specific - categories like images, audio, or video. Useful for XPIA attacks and - storing files like PDFs, documents, or other binary formats. - - Args: - category (str): The category or context for the data. - prompt_text (Optional[str]): The binary file path or identifier. - extension (Optional[str]): The file extension, defaults to 'bin'. - - """ - self.data_type = "binary_path" - self.data_sub_directory = f"/{category}/binaries" - self.file_extension = extension if extension else "bin" - - if prompt_text: - self.value = prompt_text - - def data_on_disk(self) -> bool: - """ - Indicate whether this serializer persists data on disk. - - Returns: - bool: Always True for binary path serializers. +from __future__ import annotations - """ - return True +from pyrit.common.deprecation import module_deprecation_getattr + +__all__ = [ + "AllowedCategories", + "AudioPathDataTypeSerializer", + "BinaryPathDataTypeSerializer", + "DataTypeSerializer", + "data_serializer_factory", + "ErrorDataTypeSerializer", + "ImagePathDataTypeSerializer", + "TextDataTypeSerializer", + "URLDataTypeSerializer", + "VideoPathDataTypeSerializer", +] + +__getattr__ = module_deprecation_getattr( + old_module="pyrit.models.data_type_serializer", + target_module="pyrit.io.serializers", + names=__all__, + removed_in="0.17.0", +) + + +def __dir__() -> list[str]: + return sorted(__all__) diff --git a/pyrit/models/messages/message_piece.py b/pyrit/models/messages/message_piece.py index 131919ecf8..44c64eb453 100644 --- a/pyrit/models/messages/message_piece.py +++ b/pyrit/models/messages/message_piece.py @@ -17,7 +17,6 @@ ) from pyrit.common.deprecation import print_deprecation_message -from pyrit.models.data_type_serializer import data_serializer_factory from pyrit.models.literals import ( # noqa: TC001 (runtime-required by Pydantic field annotations) ChatMessageRole, PromptDataType, @@ -306,22 +305,19 @@ async def set_sha256_values_async(self) -> None: """ Compute SHA256 hash values for original and converted payloads. - Async because blob payloads may need to be fetched. Must be called - explicitly after construction. + .. deprecated:: 0.15.0 + Use ``pyrit.io.serializers.set_message_piece_sha256_async`` instead. + This method will be removed in 0.17.0. """ - original_serializer = data_serializer_factory( - category="prompt-memory-entries", - data_type=self.original_value_data_type, - value=self.original_value, - ) - self.original_value_sha256 = await original_serializer.get_sha256_async() + import importlib - converted_serializer = data_serializer_factory( - category="prompt-memory-entries", - data_type=self.converted_value_data_type, - value=self.converted_value, + print_deprecation_message( + old_item="pyrit.models.messages.message_piece.MessagePiece.set_sha256_values_async", + new_item="pyrit.io.serializers.set_message_piece_sha256_async", + removed_in="0.17.0", ) - self.converted_value_sha256 = await converted_serializer.get_sha256_async() + serializers = importlib.import_module("pyrit.io.serializers") + await serializers.set_message_piece_sha256_async(self) def sort_message_pieces(message_pieces: list[MessagePiece]) -> list[MessagePiece]: diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index 353d5313c7..7fc6da745c 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -232,19 +232,22 @@ def render_template_value_silent(self, **kwargs: Any) -> str: async def set_sha256_value_async(self) -> None: """ Compute the SHA256 hash value asynchronously. - It should be called after prompt `value` is serialized to text, - as file paths used in the `value` may have changed from local to memory storage paths. - Note, this method is async due to the blob retrieval. And because of that, we opted - to take it out of main and setter functions. The disadvantage is that it must be explicitly called. + .. deprecated:: 0.15.0 + Use ``pyrit.io.serializers.set_seed_sha256_async`` instead. + This method will be removed in 0.17.0. """ - from pyrit.models.data_type_serializer import data_serializer_factory + import importlib - original_serializer = data_serializer_factory( - category="seed-prompt-entries", data_type=self.data_type, value=self.value - ) + from pyrit.common.deprecation import print_deprecation_message - self.value_sha256 = await original_serializer.get_sha256_async() + print_deprecation_message( + old_item="pyrit.models.seeds.seed.Seed.set_sha256_value_async", + new_item="pyrit.io.serializers.set_seed_sha256_async", + removed_in="0.17.0", + ) + serializers = importlib.import_module("pyrit.io.serializers") + await serializers.set_seed_sha256_async(self) @staticmethod def escape_for_jinja(value: str) -> str: diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index 9656211be8..ecb3e5b4f6 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -15,7 +15,6 @@ from tinytag import TinyTag from pyrit.common.path import PATHS_DICT -from pyrit.models.data_type_serializer import DataTypeSerializer from pyrit.models.literals import ( # noqa: TC001 (runtime-required by Pydantic field annotations) ChatMessageRole, PromptDataType, @@ -106,7 +105,7 @@ def set_encoding_metadata(self) -> None: return if self.metadata is None: self.metadata = {} - extension = DataTypeSerializer.get_extension(self.value) + extension = Path(self.value).suffix or None if extension: extension = extension.lstrip(".") self.metadata.update({"format": extension}) diff --git a/pyrit/models/storage_io.py b/pyrit/models/storage_io.py index 5b610f80d8..c2057935f0 100644 --- a/pyrit/models/storage_io.py +++ b/pyrit/models/storage_io.py @@ -1,507 +1,32 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from __future__ import annotations - -import logging -from abc import ABC, abstractmethod -from enum import Enum -from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union -from urllib.parse import urlparse - -import aiofiles - -from pyrit.common.deprecation import print_deprecation_message - -if TYPE_CHECKING: - from azure.storage.blob.aio import ContainerClient as AsyncContainerClient - -logger = logging.getLogger(__name__) - - -class SupportedContentType(Enum): - """ - All supported content types for uploading blobs to provided storage account container. - See all options here: https://www.iana.org/assignments/media-types/media-types.xhtml. - """ - - # TODO, add other media supported types - PLAIN_TEXT = "text/plain" - - -class StorageIO(ABC): - """ - Abstract interface for storage systems (local disk, Azure Storage Account, etc.). - """ - - @abstractmethod - async def read_file_async(self, path: Union[Path, str]) -> bytes: - """ - Asynchronously reads the file (or blob) from the given path. - """ - - @abstractmethod - async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: - """ - Asynchronously writes data to the given path. - """ - - @abstractmethod - async def path_exists_async(self, path: Union[Path, str]) -> bool: - """ - Asynchronously checks if a file or blob exists at the given path. - """ - - @abstractmethod - async def is_file_async(self, path: Union[Path, str]) -> bool: - """ - Asynchronously checks if the path refers to a file (not a directory or container). - """ - - @abstractmethod - async def create_directory_if_not_exists_async(self, path: Union[Path, str]) -> None: - """ - Asynchronously creates a directory or equivalent in the storage system if it doesn't exist. - """ - - async def read_file(self, path: Union[Path, str]) -> bytes: # pyrit-async-suffix-exempt - """ - Read a file from storage (deprecated alias of ``read_file_async``). - - Args: - path (Union[Path, str]): The path to the file. - - Returns: - bytes: The content of the file. - """ - print_deprecation_message( - old_item="pyrit.models.storage_io.StorageIO.read_file", - new_item="pyrit.models.storage_io.StorageIO.read_file_async", - removed_in="0.16.0", - ) - return await self.read_file_async(path) - - async def write_file(self, path: Union[Path, str], data: bytes) -> None: # pyrit-async-suffix-exempt - """ - Write data to storage (deprecated alias of ``write_file_async``). - - Args: - path (Union[Path, str]): The path to the file. - data (bytes): The content to write to the file. - """ - print_deprecation_message( - old_item="pyrit.models.storage_io.StorageIO.write_file", - new_item="pyrit.models.storage_io.StorageIO.write_file_async", - removed_in="0.16.0", - ) - await self.write_file_async(path, data) - - async def path_exists(self, path: Union[Path, str]) -> bool: # pyrit-async-suffix-exempt - """ - Check whether a path exists (deprecated alias of ``path_exists_async``). - - Args: - path (Union[Path, str]): The path to check. - - Returns: - bool: True if the path exists, False otherwise. - """ - print_deprecation_message( - old_item="pyrit.models.storage_io.StorageIO.path_exists", - new_item="pyrit.models.storage_io.StorageIO.path_exists_async", - removed_in="0.16.0", - ) - return await self.path_exists_async(path) - - async def is_file(self, path: Union[Path, str]) -> bool: # pyrit-async-suffix-exempt - """ - Check whether the given path is a file (deprecated alias of ``is_file_async``). - - Args: - path (Union[Path, str]): The path to check. - - Returns: - bool: True if the path is a file, False otherwise. - """ - print_deprecation_message( - old_item="pyrit.models.storage_io.StorageIO.is_file", - new_item="pyrit.models.storage_io.StorageIO.is_file_async", - removed_in="0.16.0", - ) - return await self.is_file_async(path) - - async def create_directory_if_not_exists(self, path: Union[Path, str]) -> None: # pyrit-async-suffix-exempt - """ - Create a directory if it does not exist (deprecated alias of ``create_directory_if_not_exists_async``). - - Args: - path (Union[Path, str]): The directory path to create. - """ - print_deprecation_message( - old_item="pyrit.models.storage_io.StorageIO.create_directory_if_not_exists", - new_item="pyrit.models.storage_io.StorageIO.create_directory_if_not_exists_async", - removed_in="0.16.0", - ) - await self.create_directory_if_not_exists_async(path) - - -class DiskStorageIO(StorageIO): - """ - Implementation of StorageIO for local disk storage. - """ - - async def read_file_async(self, path: Union[Path, str]) -> bytes: - """ - Asynchronously reads a file from the local disk. - - Args: - path (Union[Path, str]): The path to the file. - - Returns: - bytes: The content of the file. - - """ - path = self._convert_to_path(path) - async with aiofiles.open(path, "rb") as file: - return await file.read() - - async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: - """ - Asynchronously writes data to a file on the local disk. - - Args: - path (Path): The path to the file. - data (bytes): The content to write to the file. - - """ - path = self._convert_to_path(path) - async with aiofiles.open(path, "wb") as file: - await file.write(data) - - async def path_exists_async(self, path: Union[Path, str]) -> bool: - """ - Check whether a path exists on the local disk. - - Args: - path (Path): The path to check. - - Returns: - bool: True if the path exists, False otherwise. - - """ - path = self._convert_to_path(path) - return path.exists() - - async def is_file_async(self, path: Union[Path, str]) -> bool: - """ - Check whether the given path is a file (not a directory). - - Args: - path (Path): The path to check. - - Returns: - bool: True if the path is a file, False otherwise. - - """ - path = self._convert_to_path(path) - return path.is_file() - - async def create_directory_if_not_exists_async(self, path: Union[Path, str]) -> None: - """ - Asynchronously creates a directory if it doesn't exist on the local disk. - - Args: - path (Path): The directory path to create. - - """ - directory_path = self._convert_to_path(path) - if not directory_path.exists(): - directory_path.mkdir(parents=True, exist_ok=True) - - def _convert_to_path(self, path: Union[Path, str]) -> Path: - """ - Convert an input path to a Path object. - - Args: - path (Union[Path, str]): Input path value. +""" +Deprecation shim — the storage I/O classes moved to ``pyrit.io.storage``. - Returns: - Path: Normalized Path instance. +Importing names from ``pyrit.models.storage_io`` still works for one release but +emits a one-time ``DeprecationWarning`` per name. Import from ``pyrit.io`` instead. +This shim will be removed in 0.17.0. +""" - """ - return Path(path) if isinstance(path, str) else path - - -class AzureBlobStorageIO(StorageIO): - """ - Implementation of StorageIO for Azure Blob Storage. - """ - - def __init__( - self, - *, - container_url: Optional[str] = None, - sas_token: Optional[str] = None, - blob_content_type: SupportedContentType = SupportedContentType.PLAIN_TEXT, - ) -> None: - """ - Initialize an Azure Blob Storage I/O adapter. - - Args: - container_url (Optional[str]): Azure Blob container URL. - sas_token (Optional[str]): Optional SAS token. - blob_content_type (SupportedContentType): Blob content type for uploads. - - Raises: - ValueError: If container_url is missing. - - """ - self._blob_content_type: str = blob_content_type.value - if not container_url: - raise ValueError("Invalid Azure Storage Account Container URL.") - - self._container_url: str = container_url - self._sas_token = sas_token - self._client_async: AsyncContainerClient | None = None - - async def _create_container_client_async(self) -> AsyncContainerClient: - """ - Create an asynchronous ContainerClient for Azure Storage. - - If a SAS token is provided via the - AZURE_STORAGE_ACCOUNT_SAS_TOKEN environment variable or the init sas_token parameter, it will be used - for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication. - - Returns: - AsyncContainerClient: The initialized container client. - """ - from azure.storage.blob.aio import ContainerClient as AsyncContainerClient - - from pyrit.auth import AzureStorageAuth - - sas_token = self._sas_token - if not self._sas_token: - logger.info("SAS token not provided. Creating a delegation SAS token using Entra ID authentication.") - sas_token = await AzureStorageAuth.get_sas_token_async(self._container_url) - - self._client_async = AsyncContainerClient.from_container_url( - container_url=self._container_url, - credential=sas_token, - ) - return self._client_async - - async def _upload_blob_async(self, file_name: str, data: bytes, content_type: str) -> None: - """ - (Async) Handles uploading blob to given storage container. - - Args: - file_name (str): File name to assign to uploaded blob. - data (bytes): Byte representation of content to upload to container. - content_type (str): Content type to upload. - - Raises: - RuntimeError: If the Azure container client is not initialized. - """ - from azure.core.exceptions import ClientAuthenticationError - from azure.storage.blob import ContentSettings - - content_settings = ContentSettings(content_type=f"{content_type}") - logger.info(msg="\nUploading to Azure Storage as blob:\n\t" + file_name) - - try: - if self._client_async is None: - raise RuntimeError("Azure container client not initialized") - await self._client_async.upload_blob( - name=file_name, - data=data, - content_settings=content_settings, - overwrite=True, - ) - except Exception as exc: - if isinstance(exc, ClientAuthenticationError): - logger.exception( - msg="Authentication failed. Please check that the container existence in the " - "Azure Storage Account and ensure the validity of the provided SAS token. If you " - "haven't set the SAS token as an environment variable use `az login` to " - "enable delegation-based SAS authentication to connect to the storage account" - ) - raise - logger.exception(msg=f"An unexpected error occurred: {exc}") - raise - - def parse_blob_url(self, file_path: str) -> tuple[str, str]: - """ - Parse a blob URL to extract the container and blob name. - - Args: - file_path (str): Full blob URL. - - Returns: - tuple[str, str]: Container name and blob name. - - Raises: - ValueError: If file_path is not a valid blob URL. - - """ - parsed_url = urlparse(file_path) - if parsed_url.scheme and parsed_url.netloc: - container_name = parsed_url.path.split("/")[1] - blob_name = "/".join(parsed_url.path.split("/")[2:]) - return container_name, blob_name - raise ValueError("Invalid blob URL") - - def _resolve_blob_name(self, path: Union[Path, str]) -> str: - """ - Resolve a blob name from either a full blob URL or a relative blob path. - - When a full URL is provided the blob name is extracted from it. The container - name embedded in the URL is intentionally discarded — operations always run - against the container configured in the constructor. - - Backslashes are normalized to forward slashes so that ``Path`` objects - created on Windows still produce valid blob names. - - Args: - path (Union[Path, str]): Blob URL or relative blob path. - - Returns: - str: The resolved blob name. - - """ - path_str = str(path).replace("\\", "/") - try: - # parse_blob_url validates scheme + netloc internally - _, blob_name = self.parse_blob_url(path_str) - return blob_name - except ValueError: - return path_str - - async def read_file_async(self, path: Union[Path, str]) -> bytes: - """ - Asynchronously reads the content of a file (blob) from Azure Blob Storage. - - If the provided ``path`` is a full URL - (e.g., ``https://account.blob.core.windows.net/container/dir1/dir2/sample.png``), - it extracts the relative blob path (e.g., ``dir1/dir2/sample.png``) to correctly access the blob. - If a relative path is provided, it will use it as-is. - - Args: - path (str): The path to the file (blob) in Azure Blob Storage. - This can be either a full URL or a relative path. - - Returns: - bytes: The content of the file (blob) as bytes. - - Example: - ``file_content = await read_file_async("https://account.blob.core.windows.net/container/dir2/1726627689003831.png")`` - - Or using a relative path: - - ``file_content = await read_file_async("dir1/dir2/1726627689003831.png")`` - - """ - if not self._client_async: - self._client_async = await self._create_container_client_async() - - blob_name = self._resolve_blob_name(path) - - try: - blob_client = self._client_async.get_blob_client(blob=blob_name) - - # Download the blob - blob_stream = await blob_client.download_blob() - return bytes(await blob_stream.readall()) - - except Exception as exc: - logger.exception(f"Failed to read file at {blob_name}: {exc}") - raise - finally: - await self._client_async.close() - self._client_async = None - - async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: - """ - Write data to Azure Blob Storage at the specified path. - - If the provided ``path`` is a full URL, the blob name is extracted from it. - If a relative path is provided, it is used as the blob name directly. - - Args: - path (Union[Path, str]): Full blob URL or relative blob path. - data (bytes): The data to write. - """ - if not self._client_async: - self._client_async = await self._create_container_client_async() - blob_name = self._resolve_blob_name(path) - try: - await self._upload_blob_async(file_name=blob_name, data=data, content_type=self._blob_content_type) - except Exception as exc: - logger.exception(f"Failed to write file at {blob_name}: {exc}") - raise - finally: - await self._client_async.close() - self._client_async = None - - async def path_exists_async(self, path: Union[Path, str]) -> bool: - """ - Check whether a given path exists in the Azure Blob Storage container. - - Args: - path (Union[Path, str]): Blob URL or path to test. - - Returns: - bool: True when the path exists. - """ - from azure.core.exceptions import ResourceNotFoundError - - if not self._client_async: - self._client_async = await self._create_container_client_async() - try: - blob_name = self._resolve_blob_name(path) - blob_client = self._client_async.get_blob_client(blob=blob_name) - await blob_client.get_blob_properties() - return True - except ResourceNotFoundError: - return False - finally: - await self._client_async.close() - self._client_async = None - - async def is_file_async(self, path: Union[Path, str]) -> bool: - """ - Check whether the path refers to a file (blob) in Azure Blob Storage. - - Args: - path (Union[Path, str]): Blob URL or path to test. +from __future__ import annotations - Returns: - bool: True when the blob exists and has non-zero content size. - """ - from azure.core.exceptions import ResourceNotFoundError +from pyrit.common.deprecation import module_deprecation_getattr - if not self._client_async: - self._client_async = await self._create_container_client_async() - try: - blob_name = self._resolve_blob_name(path) - blob_client = self._client_async.get_blob_client(blob=blob_name) - blob_properties = await blob_client.get_blob_properties() - return bool(blob_properties.size > 0) - except ResourceNotFoundError: - return False - finally: - await self._client_async.close() - self._client_async = None +__all__ = [ + "AzureBlobStorageIO", + "DiskStorageIO", + "StorageIO", + "SupportedContentType", +] - async def create_directory_if_not_exists_async(self, directory_path: Union[Path, str]) -> None: # type: ignore[ty:invalid-method-override] - """ - Log a no-op directory creation for Azure Blob Storage. +__getattr__ = module_deprecation_getattr( + old_module="pyrit.models.storage_io", + target_module="pyrit.io.storage", + names=__all__, + removed_in="0.17.0", +) - Args: - directory_path (Union[Path, str]): Requested directory path. - """ - logger.info( - f"Directory creation is handled automatically during upload operations in Azure Blob Storage. " - f"Directory path: {directory_path}" - ) +def __dir__() -> list[str]: + return sorted(__all__) diff --git a/pyrit/output/conversation/pretty.py b/pyrit/output/conversation/pretty.py index 7af6250c1f..93347fb398 100644 --- a/pyrit/output/conversation/pretty.py +++ b/pyrit/output/conversation/pretty.py @@ -338,7 +338,7 @@ async def _display_image_async(self, piece: MessagePiece) -> None: if not is_in_ipython_session(): return - from pyrit.models.data_type_serializer import ImagePathDataTypeSerializer + from pyrit.io import ImagePathDataTypeSerializer try: serializer = ImagePathDataTypeSerializer(category="", prompt_text=piece.converted_value) diff --git a/pyrit/prompt_converter/add_image_text_converter.py b/pyrit/prompt_converter/add_image_text_converter.py index d4ee0809c8..2e0fd3acc9 100644 --- a/pyrit/prompt_converter/add_image_text_converter.py +++ b/pyrit/prompt_converter/add_image_text_converter.py @@ -10,7 +10,8 @@ from PIL.ImageFont import FreeTypeFont from pyrit.common.deprecation import print_deprecation_message -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.base_image_text_converter import _BaseImageTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult diff --git a/pyrit/prompt_converter/add_image_to_video_converter.py b/pyrit/prompt_converter/add_image_to_video_converter.py index 148f4d5e03..9468b30f7d 100644 --- a/pyrit/prompt_converter/add_image_to_video_converter.py +++ b/pyrit/prompt_converter/add_image_to_video_converter.py @@ -10,7 +10,8 @@ import numpy as np from pyrit.common.path import DB_DATA_PATH -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/add_text_image_converter.py b/pyrit/prompt_converter/add_text_image_converter.py index 57964c8e75..84b6890797 100644 --- a/pyrit/prompt_converter/add_text_image_converter.py +++ b/pyrit/prompt_converter/add_text_image_converter.py @@ -11,7 +11,8 @@ from PIL.ImageFont import FreeTypeFont from pyrit.common.deprecation import print_deprecation_message -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.base_image_text_converter import _BaseImageTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult diff --git a/pyrit/prompt_converter/audio_echo_converter.py b/pyrit/prompt_converter/audio_echo_converter.py index 73a40385d4..dd1a7ce91c 100644 --- a/pyrit/prompt_converter/audio_echo_converter.py +++ b/pyrit/prompt_converter/audio_echo_converter.py @@ -8,7 +8,8 @@ import numpy as np from scipy.io import wavfile -from pyrit.models import PromptDataType, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/audio_frequency_converter.py b/pyrit/prompt_converter/audio_frequency_converter.py index 65050ea1f1..07054b90fd 100644 --- a/pyrit/prompt_converter/audio_frequency_converter.py +++ b/pyrit/prompt_converter/audio_frequency_converter.py @@ -8,7 +8,8 @@ import numpy as np from scipy.io import wavfile -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/audio_speed_converter.py b/pyrit/prompt_converter/audio_speed_converter.py index 9a7a8053e6..31475d2ee0 100644 --- a/pyrit/prompt_converter/audio_speed_converter.py +++ b/pyrit/prompt_converter/audio_speed_converter.py @@ -9,7 +9,8 @@ from scipy.interpolate import interp1d from scipy.io import wavfile -from pyrit.models import PromptDataType, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/audio_volume_converter.py b/pyrit/prompt_converter/audio_volume_converter.py index 40e8e2a340..f6039f2bbd 100644 --- a/pyrit/prompt_converter/audio_volume_converter.py +++ b/pyrit/prompt_converter/audio_volume_converter.py @@ -8,7 +8,8 @@ import numpy as np from scipy.io import wavfile -from pyrit.models import PromptDataType, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/audio_white_noise_converter.py b/pyrit/prompt_converter/audio_white_noise_converter.py index 63726ce356..854b81f2cf 100644 --- a/pyrit/prompt_converter/audio_white_noise_converter.py +++ b/pyrit/prompt_converter/audio_white_noise_converter.py @@ -8,7 +8,8 @@ import numpy as np from scipy.io import wavfile -from pyrit.models import PromptDataType, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py index 9f355ce158..1b9a3deb08 100644 --- a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py +++ b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py @@ -13,7 +13,8 @@ from pyrit.auth.azure_auth import get_speech_config, get_speech_config_async from pyrit.common import default_values from pyrit.common.deprecation import print_deprecation_message -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py index ffb8934b36..fda0876fd4 100644 --- a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py +++ b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py @@ -11,7 +11,8 @@ from pyrit.auth.azure_auth import get_speech_config_async from pyrit.common import default_values from pyrit.common.deprecation import print_deprecation_message -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/base_image_to_image_converter.py b/pyrit/prompt_converter/base_image_to_image_converter.py index 0351e9a6ad..6214d41665 100644 --- a/pyrit/prompt_converter/base_image_to_image_converter.py +++ b/pyrit/prompt_converter/base_image_to_image_converter.py @@ -11,7 +11,8 @@ import aiohttp from PIL import Image -from pyrit.models import PromptDataType, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/image_compression_converter.py b/pyrit/prompt_converter/image_compression_converter.py index 57fc13b856..9c33cc7470 100644 --- a/pyrit/prompt_converter/image_compression_converter.py +++ b/pyrit/prompt_converter/image_compression_converter.py @@ -10,7 +10,8 @@ import aiohttp from PIL import Image -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/image_overlay_converter.py b/pyrit/prompt_converter/image_overlay_converter.py index cd0d1d0c3a..7578f61e55 100644 --- a/pyrit/prompt_converter/image_overlay_converter.py +++ b/pyrit/prompt_converter/image_overlay_converter.py @@ -7,8 +7,8 @@ from PIL import Image -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory -from pyrit.models.data_type_serializer import DataTypeSerializer +from pyrit.io import DataTypeSerializer, data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/pdf_converter.py b/pyrit/prompt_converter/pdf_converter.py index 4017f426e9..352b9b8a5b 100644 --- a/pyrit/prompt_converter/pdf_converter.py +++ b/pyrit/prompt_converter/pdf_converter.py @@ -13,8 +13,8 @@ from reportlab.pdfgen import canvas from pyrit.common.logger import logger -from pyrit.models import ComponentIdentifier, PromptDataType, SeedPrompt, data_serializer_factory -from pyrit.models.data_type_serializer import DataTypeSerializer +from pyrit.io import DataTypeSerializer, data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType, SeedPrompt from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/qr_code_converter.py b/pyrit/prompt_converter/qr_code_converter.py index cc1424d14f..b2519600b7 100644 --- a/pyrit/prompt_converter/qr_code_converter.py +++ b/pyrit/prompt_converter/qr_code_converter.py @@ -5,7 +5,8 @@ import segno -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/transparency_attack_converter.py b/pyrit/prompt_converter/transparency_attack_converter.py index c02f1be008..4da612d93a 100644 --- a/pyrit/prompt_converter/transparency_attack_converter.py +++ b/pyrit/prompt_converter/transparency_attack_converter.py @@ -10,7 +10,8 @@ import numpy as np from PIL import Image -from pyrit.models import ComponentIdentifier, PromptDataType, data_serializer_factory +from pyrit.io import data_serializer_factory +from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/word_doc_converter.py b/pyrit/prompt_converter/word_doc_converter.py index 1dc5d00b5c..81ef353d18 100644 --- a/pyrit/prompt_converter/word_doc_converter.py +++ b/pyrit/prompt_converter/word_doc_converter.py @@ -12,14 +12,14 @@ from docx import Document from pyrit.common.logger import logger -from pyrit.models import PromptDataType, SeedPrompt, data_serializer_factory +from pyrit.io import data_serializer_factory from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter if TYPE_CHECKING: from pathlib import Path - from pyrit.models import ComponentIdentifier - from pyrit.models.data_type_serializer import DataTypeSerializer + from pyrit.io import DataTypeSerializer + from pyrit.models import ComponentIdentifier, PromptDataType, SeedPrompt @dataclass diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index 9c23f91cf3..87a517d493 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -19,6 +19,7 @@ execution_context, get_execution_context, ) +from pyrit.io import set_message_piece_sha256_async from pyrit.memory import CentralMemory, MemoryInterface from pyrit.models import ( ComponentIdentifier, @@ -377,7 +378,7 @@ async def convert_audio_async( async def _calc_hash_async(self, request: Message) -> None: """Add a request to the memory.""" - tasks = [asyncio.create_task(piece.set_sha256_values_async()) for piece in request.message_pieces] + tasks = [asyncio.create_task(set_message_piece_sha256_async(piece)) for piece in request.message_pieces] await asyncio.gather(*tasks) async def hash_and_persist_message_async(self, *, message: Message) -> None: diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index f82fc40e29..0a04bfaef6 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -13,15 +13,8 @@ PyritException, pyrit_target_retry, ) -from pyrit.models import ( - ChatMessage, - ComponentIdentifier, - DataTypeSerializer, - Message, - MessagePiece, - construct_response_from_request, - data_serializer_factory, -) +from pyrit.io import DataTypeSerializer, data_serializer_factory +from pyrit.models import ChatMessage, ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index 7ae0383998..ae1ef6b593 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -11,12 +11,8 @@ EmptyResponseException, pyrit_target_retry, ) -from pyrit.models import ( - ComponentIdentifier, - Message, - construct_response_from_request, - data_serializer_factory, -) +from pyrit.io import data_serializer_factory +from pyrit.models import ComponentIdentifier, Message, construct_response_from_request from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 0197d5ba64..26d61c758e 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -15,12 +15,8 @@ pyrit_target_retry, ) from pyrit.exceptions.exception_classes import ServerErrorException -from pyrit.models import ( - ComponentIdentifier, - Message, - construct_response_from_request, - data_serializer_factory, -) +from pyrit.io import data_serializer_factory +from pyrit.models import ComponentIdentifier, Message, construct_response_from_request from pyrit.prompt_target.common.realtime_audio import ( RealtimeTargetResult, ServerVadConfig, diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index d9c66128cb..71c71c2cb2 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -7,12 +7,8 @@ from pyrit.exceptions import ( pyrit_target_retry, ) -from pyrit.models import ( - ComponentIdentifier, - Message, - construct_response_from_request, - data_serializer_factory, -) +from pyrit.io import data_serializer_factory +from pyrit.models import ComponentIdentifier, Message, construct_response_from_request from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 37dedc9ae0..9dfac1a356 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -11,14 +11,8 @@ from pyrit.exceptions import ( pyrit_target_retry, ) -from pyrit.models import ( - ComponentIdentifier, - DataTypeSerializer, - Message, - MessagePiece, - construct_response_from_request, - data_serializer_factory, -) +from pyrit.io import DataTypeSerializer, data_serializer_factory +from pyrit.models import ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index a2dfc796b2..aff5195ee0 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -9,13 +9,8 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Optional, Union -from pyrit.models import ( - ComponentIdentifier, - Message, - MessagePiece, - construct_response_from_request, - data_serializer_factory, -) +from pyrit.io import data_serializer_factory +from pyrit.models import ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.models.literals import PromptDataType from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.target_capabilities import TargetCapabilities diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 6a5de15f60..4e92628856 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -19,7 +19,8 @@ EmptyResponseException, pyrit_target_retry, ) -from pyrit.models import ComponentIdentifier, DataTypeSerializer, Message, MessagePiece, construct_response_from_request +from pyrit.io import DataTypeSerializer +from pyrit.models import ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.prompt_target import PromptTarget, limit_requests_per_minute from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 2ef3b412fb..a4b4e8e030 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -19,14 +19,8 @@ from pyrit.auth import AsyncTokenProviderCredential, ensure_async_token_provider, get_azure_async_token_provider from pyrit.common import default_values -from pyrit.models import ( - ComponentIdentifier, - DataTypeSerializer, - Message, - MessagePiece, - Score, - data_serializer_factory, -) +from pyrit.io import DataTypeSerializer, data_serializer_factory +from pyrit.models import ComponentIdentifier, Message, MessagePiece, Score from pyrit.score.float_scale.float_scale_score_aggregator import ( FloatScaleScorerByCategory, ) diff --git a/tests/unit/common/test_convert_local_image_to_data_url.py b/tests/unit/common/test_convert_local_image_to_data_url.py index bebbd6c67e..adff0f4890 100644 --- a/tests/unit/common/test_convert_local_image_to_data_url.py +++ b/tests/unit/common/test_convert_local_image_to_data_url.py @@ -50,7 +50,7 @@ async def test_convert_local_image_to_data_url_missing_file(): @patch("os.path.exists", return_value=True) @patch("mimetypes.guess_type", return_value=("image/jpg", None)) -@patch("pyrit.models.data_type_serializer.ImagePathDataTypeSerializer") +@patch("pyrit.io.serializers.ImagePathDataTypeSerializer") @patch("pyrit.memory.CentralMemory.get_memory_instance", return_value=SQLiteMemory(db_path=":memory:")) async def test_convert_image_to_data_url_success( mock_get_memory_instance, mock_serializer_class, mock_guess_type, mock_exists diff --git a/tests/unit/common/test_display_response.py b/tests/unit/common/test_display_response.py index e06f8ee6d6..d32e559709 100644 --- a/tests/unit/common/test_display_response.py +++ b/tests/unit/common/test_display_response.py @@ -98,7 +98,7 @@ async def test_display_image_logs_error_when_storage_io_is_none(mock_ipython, ca @patch("pyrit.common.display_response.display", create=True) async def test_display_image_azure_fallback_to_disk(mock_display, mock_image, mock_disk_io_cls, mock_ipython): """Test that when AzureBlobStorageIO read fails, it falls back to DiskStorageIO.""" - from pyrit.models import AzureBlobStorageIO + from pyrit.io import AzureBlobStorageIO mock_memory = MagicMock() mock_azure_io = MagicMock(spec=AzureBlobStorageIO) @@ -126,7 +126,7 @@ async def test_display_image_azure_fallback_to_disk(mock_display, mock_image, mo @patch("pyrit.common.display_response.DiskStorageIO") async def test_display_image_azure_and_disk_both_fail(mock_disk_io_cls, mock_ipython, caplog): """Test that when both AzureBlobStorageIO and DiskStorageIO fail, error is logged and returns.""" - from pyrit.models import AzureBlobStorageIO + from pyrit.io import AzureBlobStorageIO mock_memory = MagicMock() mock_azure_io = MagicMock(spec=AzureBlobStorageIO) diff --git a/tests/unit/io/test_deprecation_shims.py b/tests/unit/io/test_deprecation_shims.py new file mode 100644 index 0000000000..dda25dbbf9 --- /dev/null +++ b/tests/unit/io/test_deprecation_shims.py @@ -0,0 +1,181 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for the Phase 9 deprecation shims. + +``pyrit.models.storage_io`` and ``pyrit.models.data_type_serializer`` moved to +``pyrit.io.storage`` / ``pyrit.io.serializers``. The old module paths, the +``pyrit.models`` package-root re-exports, and the +``MessagePiece.set_sha256_values_async`` / ``Seed.set_sha256_value_async`` +method shims all still work but emit a ``DeprecationWarning`` pointing at the +new ``pyrit.io`` location. These tests pin that contract. The shims will be +removed in 0.17.0. +""" + +from __future__ import annotations + +import importlib +import subprocess +import sys +import warnings +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +import pyrit.io.serializers as new_serializers +import pyrit.io.storage as new_storage +import pyrit.models as models_pkg +import pyrit.models.data_type_serializer as serializer_shim +import pyrit.models.storage_io as storage_shim +from pyrit.models.messages.message_piece import MessagePiece +from pyrit.models.seeds.seed import Seed + +MODULE_SHIM_PAIRS = [ + (storage_shim, new_storage, "pyrit.models.storage_io", "pyrit.io.storage"), + (serializer_shim, new_serializers, "pyrit.models.data_type_serializer", "pyrit.io.serializers"), +] + + +@pytest.fixture(autouse=True) +def _reset_models_warned(): + """Reset the ``pyrit.models`` package-root warn-once cache so each test starts clean.""" + saved = set(models_pkg._warned) + models_pkg._warned.clear() + try: + yield + finally: + models_pkg._warned.clear() + models_pkg._warned.update(saved) + + +@pytest.mark.parametrize("shim_mod, new_mod, old_path, new_path", MODULE_SHIM_PAIRS) +def test_module_shim_forwards_every_name(shim_mod, new_mod, old_path, new_path): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + for name in shim_mod.__all__: + assert getattr(shim_mod, name) is getattr(new_mod, name), f"{old_path}.{name} did not forward" + + +@pytest.mark.parametrize("shim_mod, new_mod, old_path, new_path", MODULE_SHIM_PAIRS) +def test_module_shim_warns_once_per_name(shim_mod, new_mod, old_path, new_path): + # Reload the shim to reset its internal warn-once closure for a clean count. + shim_mod = importlib.reload(shim_mod) + for name in shim_mod.__all__: + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", DeprecationWarning) + getattr(shim_mod, name) + getattr(shim_mod, name) + + dep = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(dep) == 1, f"Expected 1 DeprecationWarning for {old_path}.{name}, got {len(dep)}" + message = str(dep[0].message) + assert f"{old_path}.{name}" in message + assert f"{new_path}.{name}" in message + assert "0.17.0" in message + + +@pytest.mark.parametrize("shim_mod, new_mod, old_path, new_path", MODULE_SHIM_PAIRS) +def test_module_shim_attribute_error_for_unknown_name(shim_mod, new_mod, old_path, new_path): + with pytest.raises(AttributeError, match=f"module {old_path!r} has no attribute"): + _ = shim_mod.definitely_not_a_real_name + + +@pytest.mark.parametrize("shim_mod, new_mod, old_path, new_path", MODULE_SHIM_PAIRS) +def test_module_shim_dir_returns_sorted_all(shim_mod, new_mod, old_path, new_path): + assert dir(shim_mod) == sorted(shim_mod.__all__) + + +def test_moved_to_pyrit_io_contains_expected_root_exports(): + # Guards against accidentally dropping a previously root-importable name from the + # forwarding table. These are exactly the names that used to be importable from + # ``pyrit.models`` and now live in ``pyrit.io``. URLDataTypeSerializer and + # SupportedContentType were never root-exported, so they are intentionally absent. + expected = { + "AllowedCategories", + "AudioPathDataTypeSerializer", + "BinaryPathDataTypeSerializer", + "DataTypeSerializer", + "ErrorDataTypeSerializer", + "ImagePathDataTypeSerializer", + "TextDataTypeSerializer", + "VideoPathDataTypeSerializer", + "data_serializer_factory", + "AzureBlobStorageIO", + "DiskStorageIO", + "StorageIO", + } + assert set(models_pkg._MOVED_TO_PYRIT_IO) == expected + + +@pytest.mark.parametrize("name", sorted(models_pkg._MOVED_TO_PYRIT_IO)) +def test_models_package_root_forwards_and_warns_once(name): + target_module = models_pkg._MOVED_TO_PYRIT_IO[name] + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", DeprecationWarning) + first = getattr(models_pkg, name) + second = getattr(models_pkg, name) + + assert first is second + assert first is getattr(importlib.import_module(target_module), name) + + dep = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(dep) == 1, f"Expected 1 DeprecationWarning for pyrit.models.{name}, got {len(dep)}" + message = str(dep[0].message) + assert f"pyrit.models.{name}" in message + assert f"{target_module}.{name}" in message + assert "0.17.0" in message + + +def test_importing_pyrit_models_does_not_warn(): + # Use a subprocess so the import is genuinely fresh and reloading the core + # package can't contaminate other tests in this worker. Filter to warnings + # that reference the moved paths so unrelated third-party DeprecationWarnings + # emitted at import time don't make this flaky. + script = ( + "import warnings\n" + "with warnings.catch_warnings(record=True) as caught:\n" + " warnings.simplefilter('always')\n" + " import pyrit.models\n" + "offenders = [str(w.message) for w in caught\n" + " if issubclass(w.category, DeprecationWarning)\n" + " and ('pyrit.io' in str(w.message) or 'pyrit.models.storage_io' in str(w.message)\n" + " or 'pyrit.models.data_type_serializer' in str(w.message))]\n" + "assert not offenders, offenders\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True) + assert result.returncode == 0, f"Importing pyrit.models warned about moved names:\n{result.stderr}" + + +async def test_message_piece_method_shim_warns_and_delegates(): + fake_self = MagicMock(spec=MessagePiece) + delegate = AsyncMock() + with patch.object(new_serializers, "set_message_piece_sha256_async", delegate): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", DeprecationWarning) + await MessagePiece.set_sha256_values_async(fake_self) + + delegate.assert_awaited_once_with(fake_self) + dep = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(dep) == 1 + message = str(dep[0].message) + assert "MessagePiece.set_sha256_values_async" in message + assert "pyrit.io.serializers.set_message_piece_sha256_async" in message + assert "0.17.0" in message + + +async def test_seed_method_shim_warns_and_delegates(): + fake_self = MagicMock(spec=Seed) + delegate = AsyncMock() + with patch.object(new_serializers, "set_seed_sha256_async", delegate): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", DeprecationWarning) + await Seed.set_sha256_value_async(fake_self) + + delegate.assert_awaited_once_with(fake_self) + dep = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(dep) == 1 + message = str(dep[0].message) + assert "Seed.set_sha256_value_async" in message + assert "pyrit.io.serializers.set_seed_sha256_async" in message + assert "0.17.0" in message diff --git a/tests/unit/models/test_data_type_serializer.py b/tests/unit/io/test_serializers.py similarity index 95% rename from tests/unit/models/test_data_type_serializer.py rename to tests/unit/io/test_serializers.py index fdcc204f10..bd28721600 100644 --- a/tests/unit/models/test_data_type_serializer.py +++ b/tests/unit/io/test_serializers.py @@ -11,7 +11,7 @@ import pytest from PIL import Image -from pyrit.models import ( +from pyrit.io import ( AllowedCategories, BinaryPathDataTypeSerializer, DataTypeSerializer, @@ -19,7 +19,10 @@ ImagePathDataTypeSerializer, TextDataTypeSerializer, data_serializer_factory, + set_message_piece_sha256_async, + set_seed_sha256_async, ) +from pyrit.models import MessagePiece, SeedPrompt def test_allowed_categories(): @@ -252,7 +255,7 @@ async def test_read_data_local_file_with_dummy_image(sqlite_instance): with open(image_path, "rb") as f: mock_storage_io.read_file_async.return_value = f.read() - with patch("pyrit.models.data_type_serializer.DiskStorageIO", return_value=mock_storage_io): + with patch("pyrit.io.serializers.DiskStorageIO", return_value=mock_storage_io): serializer = data_serializer_factory( category="prompt-memory-entries", data_type="image_path", value=image_path ) @@ -385,7 +388,7 @@ async def test_save_b64_image_raises_when_results_storage_io_none(): async def test_save_formatted_audio_raises_when_results_storage_io_none(): - from pyrit.models import data_serializer_factory as factory + from pyrit.io import data_serializer_factory as factory serializer = factory(category="prompt-memory-entries", data_type="audio_path") mock_memory = MagicMock() @@ -408,7 +411,7 @@ async def test_save_formatted_audio_writes_local_wav_via_to_thread(sqlite_instan """save_formatted_audio (local-disk path) should produce a readable WAV via _write_wav_sync.""" import wave - from pyrit.models import data_serializer_factory as factory + from pyrit.io import data_serializer_factory as factory serializer = factory(category="prompt-memory-entries", data_type="audio_path") output_path = tmp_path / "out.wav" @@ -434,7 +437,7 @@ def test_write_wav_sync_produces_readable_wav(tmp_path): """_write_wav_sync should produce a WAV file readable by wave.open with the same metadata and frames.""" import wave - from pyrit.models.data_type_serializer import _write_wav_sync + from pyrit.io.serializers import _write_wav_sync out_path = tmp_path / "direct.wav" pcm = b"\x10\x00\x20\x00\x30\x00\x40\x00" @@ -459,7 +462,7 @@ async def test_save_formatted_audio_writes_azure_wav_via_storage_io(sqlite_insta import wave from pyrit.common import path as common_path - from pyrit.models import data_serializer_factory as factory + from pyrit.io import data_serializer_factory as factory captured: dict[str, bytes] = {} @@ -480,7 +483,7 @@ async def _capture_write(file_path, data): with patch.object(serializer, "get_data_filename_async", new_callable=AsyncMock, return_value=azure_url): # Redirect DB_DATA_PATH so the temp_audio.wav write lands in tmp_path with patch.object(common_path, "DB_DATA_PATH", str(tmp_path)): - from pyrit.models import data_type_serializer as dts_module + from pyrit.io import serializers as dts_module with patch.object(dts_module, "DB_DATA_PATH", str(tmp_path)): await serializer.save_formatted_audio_async( @@ -600,7 +603,7 @@ async def test_get_data_filename_emits_deprecation_warning_and_delegates(sqlite_ async def test_save_formatted_audio_azure_storage_unlinks_local_temp(tmp_path): """save_formatted_audio_async cleans up the local temp WAV after writing to Azure storage.""" - from pyrit.models import data_serializer_factory as factory + from pyrit.io import data_serializer_factory as factory serializer = factory(category="prompt-memory-entries", data_type="audio_path") mock_memory = MagicMock() @@ -611,7 +614,7 @@ async def test_save_formatted_audio_azure_storage_unlinks_local_temp(tmp_path): with ( patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory), patch.object(serializer, "get_data_filename_async", new_callable=AsyncMock, return_value=azure_url), - patch("pyrit.models.data_type_serializer.DB_DATA_PATH", tmp_path), + patch("pyrit.io.serializers.DB_DATA_PATH", tmp_path), ): await serializer.save_formatted_audio_async(data=b"\x00\x01\x02\x03") @@ -620,3 +623,23 @@ async def test_save_formatted_audio_azure_storage_unlinks_local_temp(tmp_path): mock_storage_io.write_file_async.assert_awaited_once() assert mock_storage_io.write_file_async.call_args[0][0] == azure_url assert serializer.value == azure_url + + +async def test_set_message_piece_sha256_async_sets_text_hashes(sqlite_instance): + piece = MessagePiece(role="user", original_value="Hello") + piece.original_value = "newvalue" + piece.converted_value = "newvalue" + + await set_message_piece_sha256_async(piece) + + expected = "70e01503173b8e904d53b40b3ebb3bded5e5d3add087d3463a4b1abe92f1a8ca" + assert piece.original_value_sha256 == expected + assert piece.converted_value_sha256 == expected + + +async def test_set_seed_sha256_async_sets_text_hash(sqlite_instance): + seed = SeedPrompt(value="Hello1", data_type="text") + + await set_seed_sha256_async(seed) + + assert seed.value_sha256 == "948edbe7ede5aa7423476ae29dcd7d61e7711a071aea0d83698377effa896525" diff --git a/tests/unit/models/test_storage_io.py b/tests/unit/io/test_storage.py similarity index 99% rename from tests/unit/models/test_storage_io.py rename to tests/unit/io/test_storage.py index 0adde24a75..af5eceb767 100644 --- a/tests/unit/models/test_storage_io.py +++ b/tests/unit/io/test_storage.py @@ -6,7 +6,7 @@ import pytest -from pyrit.models.storage_io import ( +from pyrit.io.storage import ( AzureBlobStorageIO, DiskStorageIO, SupportedContentType, diff --git a/tests/unit/models/test_import_boundary.py b/tests/unit/models/test_import_boundary.py index 981012711c..76cd97dce7 100644 --- a/tests/unit/models/test_import_boundary.py +++ b/tests/unit/models/test_import_boundary.py @@ -56,24 +56,19 @@ # ratchet, tracked separately so the phase that removes the lazy workaround is # explicit. KNOWN_LAZY_VIOLATIONS: dict[str, dict[str, str]] = { - "pyrit.models.data_type_serializer": { - "pyrit.memory": "phase-9", - }, "pyrit.models.identifiers.evaluation_identifier": { "pyrit.executor.attack.core.attack_strategy": "phase-7", }, - "pyrit.models.storage_io": { - "pyrit.auth": "phase-9", - }, } # Reverse-guard violations: pyrit.common modules that still reach up into higher # layers. These are slated to relocate; the ratchet forces them to shrink. KNOWN_COMMON_VIOLATIONS: dict[str, dict[str, str]] = { "pyrit.common.data_url_converter": { - "pyrit.models": "relocate", + "pyrit.io": "relocate", }, "pyrit.common.display_response": { + "pyrit.io": "relocate", "pyrit.memory": "relocate", "pyrit.models": "relocate", }, diff --git a/tests/unit/output/test_blur_images.py b/tests/unit/output/test_blur_images.py index 538e3fb8a0..2e8e20cc60 100644 --- a/tests/unit/output/test_blur_images.py +++ b/tests/unit/output/test_blur_images.py @@ -53,7 +53,7 @@ async def test_pretty_blurs_image_bytes_before_display(tmp_path, patch_central_d with ( patch("pyrit.common.notebook_utils.is_in_ipython_session", return_value=True), patch( - "pyrit.models.data_type_serializer.ImagePathDataTypeSerializer", + "pyrit.io.serializers.ImagePathDataTypeSerializer", return_value=fake_serializer, ), patch( @@ -93,7 +93,7 @@ async def test_pretty_does_not_blur_by_default(tmp_path, patch_central_database) with ( patch("pyrit.common.notebook_utils.is_in_ipython_session", return_value=True), patch( - "pyrit.models.data_type_serializer.ImagePathDataTypeSerializer", + "pyrit.io.serializers.ImagePathDataTypeSerializer", return_value=fake_serializer, ), patch( diff --git a/tests/unit/prompt_converter/test_pdf_converter.py b/tests/unit/prompt_converter/test_pdf_converter.py index b40daf88c5..4bd72c2326 100644 --- a/tests/unit/prompt_converter/test_pdf_converter.py +++ b/tests/unit/prompt_converter/test_pdf_converter.py @@ -11,7 +11,8 @@ from reportlab.lib.pagesizes import A4 from reportlab.pdfgen import canvas -from pyrit.models import DataTypeSerializer, SeedPrompt +from pyrit.io import DataTypeSerializer +from pyrit.models import SeedPrompt from pyrit.prompt_converter import ConverterResult, PDFConverter From aaa3c0f37a444a331c52a7d6cd85f840a4537af4 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 4 Jun 2026 12:37:06 -0700 Subject: [PATCH 02/10] MAINT: Move pyrit.io into pyrit.memory.storage Relocate the storage/serialization layer from the top-level pyrit.io package into pyrit.memory.storage, resolving the naming collision with pyrit.output and reflecting that the serializers are a facade over CentralMemory (results_path / results_storage_io). Updates all importers, deprecation shims, and the import-boundary allowlist. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/services/attack_service.py | 2 +- pyrit/backend/services/converter_service.py | 2 +- pyrit/common/data_url_converter.py | 2 +- pyrit/common/display_response.py | 2 +- .../seed_datasets/remote/_image_cache.py | 2 +- .../seed_datasets/remote/msts_dataset.py | 2 +- pyrit/memory/azure_sql_memory.py | 2 +- pyrit/memory/memory_interface.py | 12 +++---- pyrit/memory/sqlite_memory.py | 2 +- pyrit/{io => memory/storage}/__init__.py | 12 ++++--- pyrit/{io => memory/storage}/serializers.py | 30 ++++++++--------- pyrit/{io => memory/storage}/storage.py | 20 ++++++------ .../chat_message_normalizer.py | 2 +- pyrit/models/__init__.py | 32 +++++++++---------- pyrit/models/data_type_serializer.py | 7 ++-- pyrit/models/messages/message_piece.py | 6 ++-- pyrit/models/seeds/seed.py | 6 ++-- pyrit/models/storage_io.py | 9 +++--- pyrit/output/conversation/pretty.py | 2 +- .../add_image_text_converter.py | 2 +- .../add_image_to_video_converter.py | 2 +- .../add_text_image_converter.py | 2 +- .../prompt_converter/audio_echo_converter.py | 2 +- .../audio_frequency_converter.py | 2 +- .../prompt_converter/audio_speed_converter.py | 2 +- .../audio_volume_converter.py | 2 +- .../audio_white_noise_converter.py | 2 +- .../azure_speech_audio_to_text_converter.py | 2 +- .../azure_speech_text_to_audio_converter.py | 2 +- .../base_image_to_image_converter.py | 2 +- .../image_compression_converter.py | 2 +- .../image_overlay_converter.py | 2 +- pyrit/prompt_converter/pdf_converter.py | 2 +- pyrit/prompt_converter/qr_code_converter.py | 2 +- .../transparency_attack_converter.py | 2 +- pyrit/prompt_converter/word_doc_converter.py | 4 +-- pyrit/prompt_normalizer/prompt_normalizer.py | 2 +- .../openai/openai_chat_target.py | 2 +- .../openai/openai_image_target.py | 2 +- .../openai/openai_realtime_target.py | 2 +- .../prompt_target/openai/openai_tts_target.py | 2 +- .../openai/openai_video_target.py | 2 +- .../playwright_copilot_target.py | 2 +- .../prompt_target/websocket_copilot_target.py | 2 +- .../azure_content_filter_scorer.py | 2 +- .../test_convert_local_image_to_data_url.py | 2 +- tests/unit/common/test_display_response.py | 4 +-- tests/unit/memory/storage/__init__.py | 2 ++ .../storage}/test_deprecation_shims.py | 28 ++++++++-------- .../storage}/test_serializers.py | 18 +++++------ .../{io => memory/storage}/test_storage.py | 2 +- tests/unit/models/test_import_boundary.py | 4 +-- tests/unit/output/test_blur_images.py | 4 +-- .../prompt_converter/test_pdf_converter.py | 2 +- 54 files changed, 140 insertions(+), 134 deletions(-) rename pyrit/{io => memory/storage}/__init__.py (73%) rename pyrit/{io => memory/storage}/serializers.py (95%) rename pyrit/{io => memory/storage}/storage.py (95%) create mode 100644 tests/unit/memory/storage/__init__.py rename tests/unit/{io => memory/storage}/test_deprecation_shims.py (87%) rename tests/unit/{io => memory/storage}/test_serializers.py (97%) rename tests/unit/{io => memory/storage}/test_storage.py (99%) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 3bea0fa55c..d10fca5de7 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -50,8 +50,8 @@ from pyrit.backend.models.common import PaginationInfo from pyrit.backend.services.converter_service import get_converter_service from pyrit.backend.services.target_service import get_target_service -from pyrit.io import data_serializer_factory from pyrit.memory import CentralMemory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import ( AttackOutcome, AttackResult, diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 4df2f1b1eb..96daa3d72e 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -36,7 +36,7 @@ CreateConverterResponse, PreviewStep, ) -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import PromptDataType from pyrit.prompt_converter import PromptConverter from pyrit.prompt_target import PromptTarget diff --git a/pyrit/common/data_url_converter.py b/pyrit/common/data_url_converter.py index 0193b2a5fb..9a81323c81 100644 --- a/pyrit/common/data_url_converter.py +++ b/pyrit/common/data_url_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from pyrit.common.deprecation import print_deprecation_message -from pyrit.io import DataTypeSerializer, data_serializer_factory +from pyrit.memory.storage import DataTypeSerializer, data_serializer_factory # Supported image formats for Azure OpenAI GPT-4o, # https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/use-your-image-data diff --git a/pyrit/common/display_response.py b/pyrit/common/display_response.py index 11093a6cf8..87e819d613 100644 --- a/pyrit/common/display_response.py +++ b/pyrit/common/display_response.py @@ -8,8 +8,8 @@ from pyrit.common.deprecation import print_deprecation_message from pyrit.common.notebook_utils import is_in_ipython_session -from pyrit.io import AzureBlobStorageIO, DiskStorageIO from pyrit.memory import CentralMemory +from pyrit.memory.storage import AzureBlobStorageIO, DiskStorageIO from pyrit.models import MessagePiece logger = logging.getLogger(__name__) diff --git a/pyrit/datasets/seed_datasets/remote/_image_cache.py b/pyrit/datasets/seed_datasets/remote/_image_cache.py index 9b5cb807ae..4575f900ea 100644 --- a/pyrit/datasets/seed_datasets/remote/_image_cache.py +++ b/pyrit/datasets/seed_datasets/remote/_image_cache.py @@ -17,7 +17,7 @@ from typing import Any, Optional from pyrit.common.net_utility import make_request_and_raise_if_error_async -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory logger = logging.getLogger(__name__) diff --git a/pyrit/datasets/seed_datasets/remote/msts_dataset.py b/pyrit/datasets/seed_datasets/remote/msts_dataset.py index 8fa876e6da..264e00d4d2 100644 --- a/pyrit/datasets/seed_datasets/remote/msts_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/msts_dataset.py @@ -12,7 +12,7 @@ from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import SeedDataset, SeedPrompt if TYPE_CHECKING: diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index ce4050040f..fa93cdefa3 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -19,7 +19,6 @@ from pyrit.auth.azure_auth import AzureAuth from pyrit.common import default_values from pyrit.common.singleton import Singleton -from pyrit.io import AzureBlobStorageIO from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.memory_models import ( AttackResultEntry, @@ -27,6 +26,7 @@ EmbeddingDataEntry, PromptMemoryEntry, ) +from pyrit.memory.storage import AzureBlobStorageIO from pyrit.models import ConversationStats, MessagePiece if TYPE_CHECKING: diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 32f6661d2a..8f87b50be0 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -23,12 +23,6 @@ from pyrit.common.deprecation import print_deprecation_message from pyrit.common.path import DB_DATA_PATH -from pyrit.io import ( - DataTypeSerializer, - StorageIO, - data_serializer_factory, - set_seed_sha256_async, -) from pyrit.memory.memory_exporter import MemoryExporter from pyrit.memory.memory_models import ( AttackResultEntry, @@ -39,6 +33,12 @@ ScoreEntry, SeedEntry, ) +from pyrit.memory.storage import ( + DataTypeSerializer, + StorageIO, + data_serializer_factory, + set_seed_sha256_async, +) from pyrit.models import ( AttackResult, ConversationStats, diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 57202ffec2..2d13001a81 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -21,7 +21,6 @@ from pyrit.common.deprecation import print_deprecation_message from pyrit.common.path import DB_DATA_PATH from pyrit.common.singleton import Singleton -from pyrit.io import DiskStorageIO from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.memory_models import ( AttackResultEntry, @@ -30,6 +29,7 @@ PromptMemoryEntry, ScenarioResultEntry, ) +from pyrit.memory.storage import DiskStorageIO from pyrit.models import ConversationStats, MessagePiece logger = logging.getLogger(__name__) diff --git a/pyrit/io/__init__.py b/pyrit/memory/storage/__init__.py similarity index 73% rename from pyrit/io/__init__.py rename to pyrit/memory/storage/__init__.py index 5333f788cb..b10fcb1d35 100644 --- a/pyrit/io/__init__.py +++ b/pyrit/memory/storage/__init__.py @@ -2,18 +2,20 @@ # Licensed under the MIT license. """ -I/O layer for PyRIT: storage backends and multi-modal data serializers. +Storage layer for PyRIT: storage backends and multi-modal data serializers. Provides the disk and blob storage adapters (``StorageIO`` and its implementations) and the data-type serializers (``data_serializer_factory`` and the per-type ``*DataTypeSerializer`` classes) used to read and write prompt payloads such as text, images, audio, and video. -Unlike ``pyrit.models``, modules in this package may depend on ``pyrit.memory`` -and ``pyrit.auth`` (resolved lazily to avoid import cycles). +These serializers write payload files into the location configured on the active +memory instance (``results_path`` / ``results_storage_io``), which is why they +live alongside ``pyrit.memory``: the database holds the records and this package +holds the blob payloads those records point to. """ -from pyrit.io.serializers import ( +from pyrit.memory.storage.serializers import ( AllowedCategories, AudioPathDataTypeSerializer, BinaryPathDataTypeSerializer, @@ -27,7 +29,7 @@ set_message_piece_sha256_async, set_seed_sha256_async, ) -from pyrit.io.storage import ( +from pyrit.memory.storage.storage import ( AzureBlobStorageIO, DiskStorageIO, StorageIO, diff --git a/pyrit/io/serializers.py b/pyrit/memory/storage/serializers.py similarity index 95% rename from pyrit/io/serializers.py rename to pyrit/memory/storage/serializers.py index 486f3a6f3b..514623a66d 100644 --- a/pyrit/io/serializers.py +++ b/pyrit/memory/storage/serializers.py @@ -18,7 +18,7 @@ from pyrit.common.deprecation import print_deprecation_message from pyrit.common.path import DB_DATA_PATH -from pyrit.io.storage import DiskStorageIO, StorageIO +from pyrit.memory.storage.storage import DiskStorageIO, StorageIO if TYPE_CHECKING: from pyrit.memory import MemoryInterface @@ -368,8 +368,8 @@ async def save_data( # pyrit-async-suffix-exempt output_filename: Optional filename to store data as. """ print_deprecation_message( - old_item="pyrit.io.serializers.DataTypeSerializer.save_data", - new_item="pyrit.io.serializers.DataTypeSerializer.save_data_async", + old_item="pyrit.memory.storage.serializers.DataTypeSerializer.save_data", + new_item="pyrit.memory.storage.serializers.DataTypeSerializer.save_data_async", removed_in="0.16.0", ) await self.save_data_async(data, output_filename) @@ -385,8 +385,8 @@ async def save_b64_image( # pyrit-async-suffix-exempt output_filename: Optional filename to store image as. """ print_deprecation_message( - old_item="pyrit.io.serializers.DataTypeSerializer.save_b64_image", - new_item="pyrit.io.serializers.DataTypeSerializer.save_b64_image_async", + old_item="pyrit.memory.storage.serializers.DataTypeSerializer.save_b64_image", + new_item="pyrit.memory.storage.serializers.DataTypeSerializer.save_b64_image_async", removed_in="0.16.0", ) await self.save_b64_image_async(data, output_filename) @@ -410,8 +410,8 @@ async def save_formatted_audio( # pyrit-async-suffix-exempt output_filename: Optional filename to store audio as. """ print_deprecation_message( - old_item="pyrit.io.serializers.DataTypeSerializer.save_formatted_audio", - new_item="pyrit.io.serializers.DataTypeSerializer.save_formatted_audio_async", + old_item="pyrit.memory.storage.serializers.DataTypeSerializer.save_formatted_audio", + new_item="pyrit.memory.storage.serializers.DataTypeSerializer.save_formatted_audio_async", removed_in="0.16.0", ) await self.save_formatted_audio_async(data, num_channels, sample_width, sample_rate, output_filename) @@ -424,8 +424,8 @@ async def read_data(self) -> bytes: # pyrit-async-suffix-exempt bytes: The data read from storage. """ print_deprecation_message( - old_item="pyrit.io.serializers.DataTypeSerializer.read_data", - new_item="pyrit.io.serializers.DataTypeSerializer.read_data_async", + old_item="pyrit.memory.storage.serializers.DataTypeSerializer.read_data", + new_item="pyrit.memory.storage.serializers.DataTypeSerializer.read_data_async", removed_in="0.16.0", ) return await self.read_data_async() @@ -438,8 +438,8 @@ async def read_data_base64(self) -> str: # pyrit-async-suffix-exempt str: Base64-encoded data. """ print_deprecation_message( - old_item="pyrit.io.serializers.DataTypeSerializer.read_data_base64", - new_item="pyrit.io.serializers.DataTypeSerializer.read_data_base64_async", + old_item="pyrit.memory.storage.serializers.DataTypeSerializer.read_data_base64", + new_item="pyrit.memory.storage.serializers.DataTypeSerializer.read_data_base64_async", removed_in="0.16.0", ) return await self.read_data_base64_async() @@ -452,8 +452,8 @@ async def get_sha256(self) -> str: # pyrit-async-suffix-exempt str: Hex digest of the computed SHA256 hash. """ print_deprecation_message( - old_item="pyrit.io.serializers.DataTypeSerializer.get_sha256", - new_item="pyrit.io.serializers.DataTypeSerializer.get_sha256_async", + old_item="pyrit.memory.storage.serializers.DataTypeSerializer.get_sha256", + new_item="pyrit.memory.storage.serializers.DataTypeSerializer.get_sha256_async", removed_in="0.16.0", ) return await self.get_sha256_async() @@ -471,8 +471,8 @@ async def get_data_filename( # pyrit-async-suffix-exempt Union[Path, str]: Full storage path for the generated data file. """ print_deprecation_message( - old_item="pyrit.io.serializers.DataTypeSerializer.get_data_filename", - new_item="pyrit.io.serializers.DataTypeSerializer.get_data_filename_async", + old_item="pyrit.memory.storage.serializers.DataTypeSerializer.get_data_filename", + new_item="pyrit.memory.storage.serializers.DataTypeSerializer.get_data_filename_async", removed_in="0.16.0", ) return await self.get_data_filename_async(file_name) diff --git a/pyrit/io/storage.py b/pyrit/memory/storage/storage.py similarity index 95% rename from pyrit/io/storage.py rename to pyrit/memory/storage/storage.py index e42e01f1e1..bf79a048e0 100644 --- a/pyrit/io/storage.py +++ b/pyrit/memory/storage/storage.py @@ -76,8 +76,8 @@ async def read_file(self, path: Union[Path, str]) -> bytes: # pyrit-async-suffi bytes: The content of the file. """ print_deprecation_message( - old_item="pyrit.io.storage.StorageIO.read_file", - new_item="pyrit.io.storage.StorageIO.read_file_async", + old_item="pyrit.memory.storage.storage.StorageIO.read_file", + new_item="pyrit.memory.storage.storage.StorageIO.read_file_async", removed_in="0.16.0", ) return await self.read_file_async(path) @@ -91,8 +91,8 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None: # pyri data (bytes): The content to write to the file. """ print_deprecation_message( - old_item="pyrit.io.storage.StorageIO.write_file", - new_item="pyrit.io.storage.StorageIO.write_file_async", + old_item="pyrit.memory.storage.storage.StorageIO.write_file", + new_item="pyrit.memory.storage.storage.StorageIO.write_file_async", removed_in="0.16.0", ) await self.write_file_async(path, data) @@ -108,8 +108,8 @@ async def path_exists(self, path: Union[Path, str]) -> bool: # pyrit-async-suff bool: True if the path exists, False otherwise. """ print_deprecation_message( - old_item="pyrit.io.storage.StorageIO.path_exists", - new_item="pyrit.io.storage.StorageIO.path_exists_async", + old_item="pyrit.memory.storage.storage.StorageIO.path_exists", + new_item="pyrit.memory.storage.storage.StorageIO.path_exists_async", removed_in="0.16.0", ) return await self.path_exists_async(path) @@ -125,8 +125,8 @@ async def is_file(self, path: Union[Path, str]) -> bool: # pyrit-async-suffix-e bool: True if the path is a file, False otherwise. """ print_deprecation_message( - old_item="pyrit.io.storage.StorageIO.is_file", - new_item="pyrit.io.storage.StorageIO.is_file_async", + old_item="pyrit.memory.storage.storage.StorageIO.is_file", + new_item="pyrit.memory.storage.storage.StorageIO.is_file_async", removed_in="0.16.0", ) return await self.is_file_async(path) @@ -139,8 +139,8 @@ async def create_directory_if_not_exists(self, path: Union[Path, str]) -> None: path (Union[Path, str]): The directory path to create. """ print_deprecation_message( - old_item="pyrit.io.storage.StorageIO.create_directory_if_not_exists", - new_item="pyrit.io.storage.StorageIO.create_directory_if_not_exists_async", + old_item="pyrit.memory.storage.storage.StorageIO.create_directory_if_not_exists", + new_item="pyrit.memory.storage.storage.StorageIO.create_directory_if_not_exists_async", removed_in="0.16.0", ) await self.create_directory_if_not_exists_async(path) diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 09eef1fdeb..454045794b 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -9,7 +9,7 @@ import aiofiles from pyrit.common.data_url_converter import convert_local_image_to_data_url_async -from pyrit.io import DataTypeSerializer +from pyrit.memory.storage import DataTypeSerializer from pyrit.message_normalizer.message_normalizer import ( MessageListNormalizer, MessageStringNormalizer, diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 3a6a0a49bf..da4dafb9cc 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -188,22 +188,22 @@ "ScorerIdentifier": ComponentIdentifier, } -# Names that moved to ``pyrit.io`` in Phase 9. Served lazily via importlib so that +# Names that moved to ``pyrit.memory.storage``. Served lazily via importlib so that # importing ``pyrit.models`` stays import-boundary clean and fires no warning until a # moved name is actually accessed. Will be removed in 0.17.0. -_MOVED_TO_PYRIT_IO: dict[str, str] = { - "AllowedCategories": "pyrit.io.serializers", - "AudioPathDataTypeSerializer": "pyrit.io.serializers", - "BinaryPathDataTypeSerializer": "pyrit.io.serializers", - "DataTypeSerializer": "pyrit.io.serializers", - "ErrorDataTypeSerializer": "pyrit.io.serializers", - "ImagePathDataTypeSerializer": "pyrit.io.serializers", - "TextDataTypeSerializer": "pyrit.io.serializers", - "VideoPathDataTypeSerializer": "pyrit.io.serializers", - "data_serializer_factory": "pyrit.io.serializers", - "AzureBlobStorageIO": "pyrit.io.storage", - "DiskStorageIO": "pyrit.io.storage", - "StorageIO": "pyrit.io.storage", +_MOVED_TO_MEMORY_STORAGE: dict[str, str] = { + "AllowedCategories": "pyrit.memory.storage.serializers", + "AudioPathDataTypeSerializer": "pyrit.memory.storage.serializers", + "BinaryPathDataTypeSerializer": "pyrit.memory.storage.serializers", + "DataTypeSerializer": "pyrit.memory.storage.serializers", + "ErrorDataTypeSerializer": "pyrit.memory.storage.serializers", + "ImagePathDataTypeSerializer": "pyrit.memory.storage.serializers", + "TextDataTypeSerializer": "pyrit.memory.storage.serializers", + "VideoPathDataTypeSerializer": "pyrit.memory.storage.serializers", + "data_serializer_factory": "pyrit.memory.storage.serializers", + "AzureBlobStorageIO": "pyrit.memory.storage.storage", + "DiskStorageIO": "pyrit.memory.storage.storage", + "StorageIO": "pyrit.memory.storage.storage", } _warned: set[str] = set() @@ -220,8 +220,8 @@ def __getattr__(name: str) -> Any: ) _warned.add(name) return target - if name in _MOVED_TO_PYRIT_IO: - target_module = _MOVED_TO_PYRIT_IO[name] + if name in _MOVED_TO_MEMORY_STORAGE: + target_module = _MOVED_TO_MEMORY_STORAGE[name] if name not in _warned: print_deprecation_message( old_item=f"{__name__}.{name}", diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index 369cc405d5..a2659204a3 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -2,11 +2,12 @@ # Licensed under the MIT license. """ -Deprecation shim — the data-type serializers moved to ``pyrit.io.serializers``. +Deprecation shim — the data-type serializers now live in +``pyrit.memory.storage``. Importing names from ``pyrit.models.data_type_serializer`` still works for one release but emits a one-time ``DeprecationWarning`` per name. Import from -``pyrit.io`` instead. This shim will be removed in 0.17.0. +``pyrit.memory.storage`` instead. This shim will be removed in 0.17.0. """ from __future__ import annotations @@ -28,7 +29,7 @@ __getattr__ = module_deprecation_getattr( old_module="pyrit.models.data_type_serializer", - target_module="pyrit.io.serializers", + target_module="pyrit.memory.storage.serializers", names=__all__, removed_in="0.17.0", ) diff --git a/pyrit/models/messages/message_piece.py b/pyrit/models/messages/message_piece.py index 44c64eb453..4978be8d94 100644 --- a/pyrit/models/messages/message_piece.py +++ b/pyrit/models/messages/message_piece.py @@ -306,17 +306,17 @@ async def set_sha256_values_async(self) -> None: Compute SHA256 hash values for original and converted payloads. .. deprecated:: 0.15.0 - Use ``pyrit.io.serializers.set_message_piece_sha256_async`` instead. + Use ``pyrit.memory.storage.serializers.set_message_piece_sha256_async`` instead. This method will be removed in 0.17.0. """ import importlib print_deprecation_message( old_item="pyrit.models.messages.message_piece.MessagePiece.set_sha256_values_async", - new_item="pyrit.io.serializers.set_message_piece_sha256_async", + new_item="pyrit.memory.storage.serializers.set_message_piece_sha256_async", removed_in="0.17.0", ) - serializers = importlib.import_module("pyrit.io.serializers") + serializers = importlib.import_module("pyrit.memory.storage.serializers") await serializers.set_message_piece_sha256_async(self) diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index 7fc6da745c..c8ee9588bf 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -234,7 +234,7 @@ async def set_sha256_value_async(self) -> None: Compute the SHA256 hash value asynchronously. .. deprecated:: 0.15.0 - Use ``pyrit.io.serializers.set_seed_sha256_async`` instead. + Use ``pyrit.memory.storage.serializers.set_seed_sha256_async`` instead. This method will be removed in 0.17.0. """ import importlib @@ -243,10 +243,10 @@ async def set_sha256_value_async(self) -> None: print_deprecation_message( old_item="pyrit.models.seeds.seed.Seed.set_sha256_value_async", - new_item="pyrit.io.serializers.set_seed_sha256_async", + new_item="pyrit.memory.storage.serializers.set_seed_sha256_async", removed_in="0.17.0", ) - serializers = importlib.import_module("pyrit.io.serializers") + serializers = importlib.import_module("pyrit.memory.storage.serializers") await serializers.set_seed_sha256_async(self) @staticmethod diff --git a/pyrit/models/storage_io.py b/pyrit/models/storage_io.py index c2057935f0..ba4b284e44 100644 --- a/pyrit/models/storage_io.py +++ b/pyrit/models/storage_io.py @@ -2,11 +2,12 @@ # Licensed under the MIT license. """ -Deprecation shim — the storage I/O classes moved to ``pyrit.io.storage``. +Deprecation shim — the storage I/O classes now live in +``pyrit.memory.storage``. Importing names from ``pyrit.models.storage_io`` still works for one release but -emits a one-time ``DeprecationWarning`` per name. Import from ``pyrit.io`` instead. -This shim will be removed in 0.17.0. +emits a one-time ``DeprecationWarning`` per name. Import from +``pyrit.memory.storage`` instead. This shim will be removed in 0.17.0. """ from __future__ import annotations @@ -22,7 +23,7 @@ __getattr__ = module_deprecation_getattr( old_module="pyrit.models.storage_io", - target_module="pyrit.io.storage", + target_module="pyrit.memory.storage.storage", names=__all__, removed_in="0.17.0", ) diff --git a/pyrit/output/conversation/pretty.py b/pyrit/output/conversation/pretty.py index 93347fb398..f77e751d80 100644 --- a/pyrit/output/conversation/pretty.py +++ b/pyrit/output/conversation/pretty.py @@ -338,7 +338,7 @@ async def _display_image_async(self, piece: MessagePiece) -> None: if not is_in_ipython_session(): return - from pyrit.io import ImagePathDataTypeSerializer + from pyrit.memory.storage import ImagePathDataTypeSerializer try: serializer = ImagePathDataTypeSerializer(category="", prompt_text=piece.converted_value) diff --git a/pyrit/prompt_converter/add_image_text_converter.py b/pyrit/prompt_converter/add_image_text_converter.py index 2e0fd3acc9..5a539a0362 100644 --- a/pyrit/prompt_converter/add_image_text_converter.py +++ b/pyrit/prompt_converter/add_image_text_converter.py @@ -10,7 +10,7 @@ from PIL.ImageFont import FreeTypeFont from pyrit.common.deprecation import print_deprecation_message -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.base_image_text_converter import _BaseImageTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult diff --git a/pyrit/prompt_converter/add_image_to_video_converter.py b/pyrit/prompt_converter/add_image_to_video_converter.py index 9468b30f7d..b3ef5c8731 100644 --- a/pyrit/prompt_converter/add_image_to_video_converter.py +++ b/pyrit/prompt_converter/add_image_to_video_converter.py @@ -10,7 +10,7 @@ import numpy as np from pyrit.common.path import DB_DATA_PATH -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/add_text_image_converter.py b/pyrit/prompt_converter/add_text_image_converter.py index 84b6890797..0261fd9212 100644 --- a/pyrit/prompt_converter/add_text_image_converter.py +++ b/pyrit/prompt_converter/add_text_image_converter.py @@ -11,7 +11,7 @@ from PIL.ImageFont import FreeTypeFont from pyrit.common.deprecation import print_deprecation_message -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.base_image_text_converter import _BaseImageTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult diff --git a/pyrit/prompt_converter/audio_echo_converter.py b/pyrit/prompt_converter/audio_echo_converter.py index dd1a7ce91c..b854bf57ac 100644 --- a/pyrit/prompt_converter/audio_echo_converter.py +++ b/pyrit/prompt_converter/audio_echo_converter.py @@ -8,7 +8,7 @@ import numpy as np from scipy.io import wavfile -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/audio_frequency_converter.py b/pyrit/prompt_converter/audio_frequency_converter.py index 07054b90fd..dc4e42c36b 100644 --- a/pyrit/prompt_converter/audio_frequency_converter.py +++ b/pyrit/prompt_converter/audio_frequency_converter.py @@ -8,7 +8,7 @@ import numpy as np from scipy.io import wavfile -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/audio_speed_converter.py b/pyrit/prompt_converter/audio_speed_converter.py index 31475d2ee0..038ee6e387 100644 --- a/pyrit/prompt_converter/audio_speed_converter.py +++ b/pyrit/prompt_converter/audio_speed_converter.py @@ -9,7 +9,7 @@ from scipy.interpolate import interp1d from scipy.io import wavfile -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/audio_volume_converter.py b/pyrit/prompt_converter/audio_volume_converter.py index f6039f2bbd..dc0acc1acd 100644 --- a/pyrit/prompt_converter/audio_volume_converter.py +++ b/pyrit/prompt_converter/audio_volume_converter.py @@ -8,7 +8,7 @@ import numpy as np from scipy.io import wavfile -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/audio_white_noise_converter.py b/pyrit/prompt_converter/audio_white_noise_converter.py index 854b81f2cf..9769934d24 100644 --- a/pyrit/prompt_converter/audio_white_noise_converter.py +++ b/pyrit/prompt_converter/audio_white_noise_converter.py @@ -8,7 +8,7 @@ import numpy as np from scipy.io import wavfile -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py index 1b9a3deb08..84e15a389f 100644 --- a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py +++ b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py @@ -13,7 +13,7 @@ from pyrit.auth.azure_auth import get_speech_config, get_speech_config_async from pyrit.common import default_values from pyrit.common.deprecation import print_deprecation_message -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py index fda0876fd4..6071d25cf9 100644 --- a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py +++ b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py @@ -11,7 +11,7 @@ from pyrit.auth.azure_auth import get_speech_config_async from pyrit.common import default_values from pyrit.common.deprecation import print_deprecation_message -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/base_image_to_image_converter.py b/pyrit/prompt_converter/base_image_to_image_converter.py index 6214d41665..8ed42b5685 100644 --- a/pyrit/prompt_converter/base_image_to_image_converter.py +++ b/pyrit/prompt_converter/base_image_to_image_converter.py @@ -11,7 +11,7 @@ import aiohttp from PIL import Image -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/image_compression_converter.py b/pyrit/prompt_converter/image_compression_converter.py index 9c33cc7470..11ec45dacb 100644 --- a/pyrit/prompt_converter/image_compression_converter.py +++ b/pyrit/prompt_converter/image_compression_converter.py @@ -10,7 +10,7 @@ import aiohttp from PIL import Image -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/image_overlay_converter.py b/pyrit/prompt_converter/image_overlay_converter.py index 7578f61e55..b73c9c519b 100644 --- a/pyrit/prompt_converter/image_overlay_converter.py +++ b/pyrit/prompt_converter/image_overlay_converter.py @@ -7,7 +7,7 @@ from PIL import Image -from pyrit.io import DataTypeSerializer, data_serializer_factory +from pyrit.memory.storage import DataTypeSerializer, data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/pdf_converter.py b/pyrit/prompt_converter/pdf_converter.py index 352b9b8a5b..52bfefad00 100644 --- a/pyrit/prompt_converter/pdf_converter.py +++ b/pyrit/prompt_converter/pdf_converter.py @@ -13,7 +13,7 @@ from reportlab.pdfgen import canvas from pyrit.common.logger import logger -from pyrit.io import DataTypeSerializer, data_serializer_factory +from pyrit.memory.storage import DataTypeSerializer, data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType, SeedPrompt from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/qr_code_converter.py b/pyrit/prompt_converter/qr_code_converter.py index b2519600b7..042ed54341 100644 --- a/pyrit/prompt_converter/qr_code_converter.py +++ b/pyrit/prompt_converter/qr_code_converter.py @@ -5,7 +5,7 @@ import segno -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/transparency_attack_converter.py b/pyrit/prompt_converter/transparency_attack_converter.py index 4da612d93a..cd8833c4cb 100644 --- a/pyrit/prompt_converter/transparency_attack_converter.py +++ b/pyrit/prompt_converter/transparency_attack_converter.py @@ -10,7 +10,7 @@ import numpy as np from PIL import Image -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/word_doc_converter.py b/pyrit/prompt_converter/word_doc_converter.py index 81ef353d18..47bb8bd010 100644 --- a/pyrit/prompt_converter/word_doc_converter.py +++ b/pyrit/prompt_converter/word_doc_converter.py @@ -12,13 +12,13 @@ from docx import Document from pyrit.common.logger import logger -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter if TYPE_CHECKING: from pathlib import Path - from pyrit.io import DataTypeSerializer + from pyrit.memory.storage import DataTypeSerializer from pyrit.models import ComponentIdentifier, PromptDataType, SeedPrompt diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index 87a517d493..21ba187be3 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -19,8 +19,8 @@ execution_context, get_execution_context, ) -from pyrit.io import set_message_piece_sha256_async from pyrit.memory import CentralMemory, MemoryInterface +from pyrit.memory.storage import set_message_piece_sha256_async from pyrit.models import ( ComponentIdentifier, Message, diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 0a04bfaef6..5c954fb83d 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -13,7 +13,7 @@ PyritException, pyrit_target_retry, ) -from pyrit.io import DataTypeSerializer, data_serializer_factory +from pyrit.memory.storage import DataTypeSerializer, data_serializer_factory from pyrit.models import ChatMessage, ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.target_capabilities import TargetCapabilities diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index ae1ef6b593..f47c173270 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -11,7 +11,7 @@ EmptyResponseException, pyrit_target_retry, ) -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import ComponentIdentifier, Message, construct_response_from_request from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 26d61c758e..bcfc616061 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -15,7 +15,7 @@ pyrit_target_retry, ) from pyrit.exceptions.exception_classes import ServerErrorException -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import ComponentIdentifier, Message, construct_response_from_request from pyrit.prompt_target.common.realtime_audio import ( RealtimeTargetResult, diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index 71c71c2cb2..4d6b031824 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -7,7 +7,7 @@ from pyrit.exceptions import ( pyrit_target_retry, ) -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import ComponentIdentifier, Message, construct_response_from_request from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 9dfac1a356..9e63bed485 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -11,7 +11,7 @@ from pyrit.exceptions import ( pyrit_target_retry, ) -from pyrit.io import DataTypeSerializer, data_serializer_factory +from pyrit.memory.storage import DataTypeSerializer, data_serializer_factory from pyrit.models import ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index aff5195ee0..a92ac37f49 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -9,7 +9,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Optional, Union -from pyrit.io import data_serializer_factory +from pyrit.memory.storage import data_serializer_factory from pyrit.models import ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.models.literals import PromptDataType from pyrit.prompt_target.common.prompt_target import PromptTarget diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 4e92628856..53aa7aa23b 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -19,7 +19,7 @@ EmptyResponseException, pyrit_target_retry, ) -from pyrit.io import DataTypeSerializer +from pyrit.memory.storage import DataTypeSerializer from pyrit.models import ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.prompt_target import PromptTarget, limit_requests_per_minute from pyrit.prompt_target.common.target_capabilities import TargetCapabilities diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index a4b4e8e030..33e74929b5 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -19,7 +19,7 @@ from pyrit.auth import AsyncTokenProviderCredential, ensure_async_token_provider, get_azure_async_token_provider from pyrit.common import default_values -from pyrit.io import DataTypeSerializer, data_serializer_factory +from pyrit.memory.storage import DataTypeSerializer, data_serializer_factory from pyrit.models import ComponentIdentifier, Message, MessagePiece, Score from pyrit.score.float_scale.float_scale_score_aggregator import ( FloatScaleScorerByCategory, diff --git a/tests/unit/common/test_convert_local_image_to_data_url.py b/tests/unit/common/test_convert_local_image_to_data_url.py index adff0f4890..502cc1df95 100644 --- a/tests/unit/common/test_convert_local_image_to_data_url.py +++ b/tests/unit/common/test_convert_local_image_to_data_url.py @@ -50,7 +50,7 @@ async def test_convert_local_image_to_data_url_missing_file(): @patch("os.path.exists", return_value=True) @patch("mimetypes.guess_type", return_value=("image/jpg", None)) -@patch("pyrit.io.serializers.ImagePathDataTypeSerializer") +@patch("pyrit.memory.storage.serializers.ImagePathDataTypeSerializer") @patch("pyrit.memory.CentralMemory.get_memory_instance", return_value=SQLiteMemory(db_path=":memory:")) async def test_convert_image_to_data_url_success( mock_get_memory_instance, mock_serializer_class, mock_guess_type, mock_exists diff --git a/tests/unit/common/test_display_response.py b/tests/unit/common/test_display_response.py index d32e559709..ee944bc859 100644 --- a/tests/unit/common/test_display_response.py +++ b/tests/unit/common/test_display_response.py @@ -98,7 +98,7 @@ async def test_display_image_logs_error_when_storage_io_is_none(mock_ipython, ca @patch("pyrit.common.display_response.display", create=True) async def test_display_image_azure_fallback_to_disk(mock_display, mock_image, mock_disk_io_cls, mock_ipython): """Test that when AzureBlobStorageIO read fails, it falls back to DiskStorageIO.""" - from pyrit.io import AzureBlobStorageIO + from pyrit.memory.storage import AzureBlobStorageIO mock_memory = MagicMock() mock_azure_io = MagicMock(spec=AzureBlobStorageIO) @@ -126,7 +126,7 @@ async def test_display_image_azure_fallback_to_disk(mock_display, mock_image, mo @patch("pyrit.common.display_response.DiskStorageIO") async def test_display_image_azure_and_disk_both_fail(mock_disk_io_cls, mock_ipython, caplog): """Test that when both AzureBlobStorageIO and DiskStorageIO fail, error is logged and returns.""" - from pyrit.io import AzureBlobStorageIO + from pyrit.memory.storage import AzureBlobStorageIO mock_memory = MagicMock() mock_azure_io = MagicMock(spec=AzureBlobStorageIO) diff --git a/tests/unit/memory/storage/__init__.py b/tests/unit/memory/storage/__init__.py new file mode 100644 index 0000000000..9a0454564d --- /dev/null +++ b/tests/unit/memory/storage/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/tests/unit/io/test_deprecation_shims.py b/tests/unit/memory/storage/test_deprecation_shims.py similarity index 87% rename from tests/unit/io/test_deprecation_shims.py rename to tests/unit/memory/storage/test_deprecation_shims.py index dda25dbbf9..3d5a748508 100644 --- a/tests/unit/io/test_deprecation_shims.py +++ b/tests/unit/memory/storage/test_deprecation_shims.py @@ -5,11 +5,11 @@ Tests for the Phase 9 deprecation shims. ``pyrit.models.storage_io`` and ``pyrit.models.data_type_serializer`` moved to -``pyrit.io.storage`` / ``pyrit.io.serializers``. The old module paths, the +``pyrit.memory.storage.storage`` / ``pyrit.memory.storage.serializers``. The old module paths, the ``pyrit.models`` package-root re-exports, and the ``MessagePiece.set_sha256_values_async`` / ``Seed.set_sha256_value_async`` method shims all still work but emit a ``DeprecationWarning`` pointing at the -new ``pyrit.io`` location. These tests pin that contract. The shims will be +new ``pyrit.memory.storage`` location. These tests pin that contract. The shims will be removed in 0.17.0. """ @@ -23,8 +23,8 @@ import pytest -import pyrit.io.serializers as new_serializers -import pyrit.io.storage as new_storage +import pyrit.memory.storage.serializers as new_serializers +import pyrit.memory.storage.storage as new_storage import pyrit.models as models_pkg import pyrit.models.data_type_serializer as serializer_shim import pyrit.models.storage_io as storage_shim @@ -32,8 +32,8 @@ from pyrit.models.seeds.seed import Seed MODULE_SHIM_PAIRS = [ - (storage_shim, new_storage, "pyrit.models.storage_io", "pyrit.io.storage"), - (serializer_shim, new_serializers, "pyrit.models.data_type_serializer", "pyrit.io.serializers"), + (storage_shim, new_storage, "pyrit.models.storage_io", "pyrit.memory.storage.storage"), + (serializer_shim, new_serializers, "pyrit.models.data_type_serializer", "pyrit.memory.storage.serializers"), ] @@ -86,10 +86,10 @@ def test_module_shim_dir_returns_sorted_all(shim_mod, new_mod, old_path, new_pat assert dir(shim_mod) == sorted(shim_mod.__all__) -def test_moved_to_pyrit_io_contains_expected_root_exports(): +def test_moved_to_memory_storage_contains_expected_root_exports(): # Guards against accidentally dropping a previously root-importable name from the # forwarding table. These are exactly the names that used to be importable from - # ``pyrit.models`` and now live in ``pyrit.io``. URLDataTypeSerializer and + # ``pyrit.models`` and now live in ``pyrit.memory.storage``. URLDataTypeSerializer and # SupportedContentType were never root-exported, so they are intentionally absent. expected = { "AllowedCategories", @@ -105,12 +105,12 @@ def test_moved_to_pyrit_io_contains_expected_root_exports(): "DiskStorageIO", "StorageIO", } - assert set(models_pkg._MOVED_TO_PYRIT_IO) == expected + assert set(models_pkg._MOVED_TO_MEMORY_STORAGE) == expected -@pytest.mark.parametrize("name", sorted(models_pkg._MOVED_TO_PYRIT_IO)) +@pytest.mark.parametrize("name", sorted(models_pkg._MOVED_TO_MEMORY_STORAGE)) def test_models_package_root_forwards_and_warns_once(name): - target_module = models_pkg._MOVED_TO_PYRIT_IO[name] + target_module = models_pkg._MOVED_TO_MEMORY_STORAGE[name] with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always", DeprecationWarning) first = getattr(models_pkg, name) @@ -139,7 +139,7 @@ def test_importing_pyrit_models_does_not_warn(): " import pyrit.models\n" "offenders = [str(w.message) for w in caught\n" " if issubclass(w.category, DeprecationWarning)\n" - " and ('pyrit.io' in str(w.message) or 'pyrit.models.storage_io' in str(w.message)\n" + " and ('pyrit.memory.storage' in str(w.message) or 'pyrit.models.storage_io' in str(w.message)\n" " or 'pyrit.models.data_type_serializer' in str(w.message))]\n" "assert not offenders, offenders\n" ) @@ -160,7 +160,7 @@ async def test_message_piece_method_shim_warns_and_delegates(): assert len(dep) == 1 message = str(dep[0].message) assert "MessagePiece.set_sha256_values_async" in message - assert "pyrit.io.serializers.set_message_piece_sha256_async" in message + assert "pyrit.memory.storage.serializers.set_message_piece_sha256_async" in message assert "0.17.0" in message @@ -177,5 +177,5 @@ async def test_seed_method_shim_warns_and_delegates(): assert len(dep) == 1 message = str(dep[0].message) assert "Seed.set_sha256_value_async" in message - assert "pyrit.io.serializers.set_seed_sha256_async" in message + assert "pyrit.memory.storage.serializers.set_seed_sha256_async" in message assert "0.17.0" in message diff --git a/tests/unit/io/test_serializers.py b/tests/unit/memory/storage/test_serializers.py similarity index 97% rename from tests/unit/io/test_serializers.py rename to tests/unit/memory/storage/test_serializers.py index bd28721600..53af26ba60 100644 --- a/tests/unit/io/test_serializers.py +++ b/tests/unit/memory/storage/test_serializers.py @@ -11,7 +11,7 @@ import pytest from PIL import Image -from pyrit.io import ( +from pyrit.memory.storage import ( AllowedCategories, BinaryPathDataTypeSerializer, DataTypeSerializer, @@ -255,7 +255,7 @@ async def test_read_data_local_file_with_dummy_image(sqlite_instance): with open(image_path, "rb") as f: mock_storage_io.read_file_async.return_value = f.read() - with patch("pyrit.io.serializers.DiskStorageIO", return_value=mock_storage_io): + with patch("pyrit.memory.storage.serializers.DiskStorageIO", return_value=mock_storage_io): serializer = data_serializer_factory( category="prompt-memory-entries", data_type="image_path", value=image_path ) @@ -388,7 +388,7 @@ async def test_save_b64_image_raises_when_results_storage_io_none(): async def test_save_formatted_audio_raises_when_results_storage_io_none(): - from pyrit.io import data_serializer_factory as factory + from pyrit.memory.storage import data_serializer_factory as factory serializer = factory(category="prompt-memory-entries", data_type="audio_path") mock_memory = MagicMock() @@ -411,7 +411,7 @@ async def test_save_formatted_audio_writes_local_wav_via_to_thread(sqlite_instan """save_formatted_audio (local-disk path) should produce a readable WAV via _write_wav_sync.""" import wave - from pyrit.io import data_serializer_factory as factory + from pyrit.memory.storage import data_serializer_factory as factory serializer = factory(category="prompt-memory-entries", data_type="audio_path") output_path = tmp_path / "out.wav" @@ -437,7 +437,7 @@ def test_write_wav_sync_produces_readable_wav(tmp_path): """_write_wav_sync should produce a WAV file readable by wave.open with the same metadata and frames.""" import wave - from pyrit.io.serializers import _write_wav_sync + from pyrit.memory.storage.serializers import _write_wav_sync out_path = tmp_path / "direct.wav" pcm = b"\x10\x00\x20\x00\x30\x00\x40\x00" @@ -462,7 +462,7 @@ async def test_save_formatted_audio_writes_azure_wav_via_storage_io(sqlite_insta import wave from pyrit.common import path as common_path - from pyrit.io import data_serializer_factory as factory + from pyrit.memory.storage import data_serializer_factory as factory captured: dict[str, bytes] = {} @@ -483,7 +483,7 @@ async def _capture_write(file_path, data): with patch.object(serializer, "get_data_filename_async", new_callable=AsyncMock, return_value=azure_url): # Redirect DB_DATA_PATH so the temp_audio.wav write lands in tmp_path with patch.object(common_path, "DB_DATA_PATH", str(tmp_path)): - from pyrit.io import serializers as dts_module + from pyrit.memory.storage import serializers as dts_module with patch.object(dts_module, "DB_DATA_PATH", str(tmp_path)): await serializer.save_formatted_audio_async( @@ -603,7 +603,7 @@ async def test_get_data_filename_emits_deprecation_warning_and_delegates(sqlite_ async def test_save_formatted_audio_azure_storage_unlinks_local_temp(tmp_path): """save_formatted_audio_async cleans up the local temp WAV after writing to Azure storage.""" - from pyrit.io import data_serializer_factory as factory + from pyrit.memory.storage import data_serializer_factory as factory serializer = factory(category="prompt-memory-entries", data_type="audio_path") mock_memory = MagicMock() @@ -614,7 +614,7 @@ async def test_save_formatted_audio_azure_storage_unlinks_local_temp(tmp_path): with ( patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory), patch.object(serializer, "get_data_filename_async", new_callable=AsyncMock, return_value=azure_url), - patch("pyrit.io.serializers.DB_DATA_PATH", tmp_path), + patch("pyrit.memory.storage.serializers.DB_DATA_PATH", tmp_path), ): await serializer.save_formatted_audio_async(data=b"\x00\x01\x02\x03") diff --git a/tests/unit/io/test_storage.py b/tests/unit/memory/storage/test_storage.py similarity index 99% rename from tests/unit/io/test_storage.py rename to tests/unit/memory/storage/test_storage.py index af5eceb767..6a29821e5f 100644 --- a/tests/unit/io/test_storage.py +++ b/tests/unit/memory/storage/test_storage.py @@ -6,7 +6,7 @@ import pytest -from pyrit.io.storage import ( +from pyrit.memory.storage.storage import ( AzureBlobStorageIO, DiskStorageIO, SupportedContentType, diff --git a/tests/unit/models/test_import_boundary.py b/tests/unit/models/test_import_boundary.py index 76cd97dce7..14dd6084c7 100644 --- a/tests/unit/models/test_import_boundary.py +++ b/tests/unit/models/test_import_boundary.py @@ -65,10 +65,10 @@ # layers. These are slated to relocate; the ratchet forces them to shrink. KNOWN_COMMON_VIOLATIONS: dict[str, dict[str, str]] = { "pyrit.common.data_url_converter": { - "pyrit.io": "relocate", + "pyrit.memory.storage": "relocate", }, "pyrit.common.display_response": { - "pyrit.io": "relocate", + "pyrit.memory.storage": "relocate", "pyrit.memory": "relocate", "pyrit.models": "relocate", }, diff --git a/tests/unit/output/test_blur_images.py b/tests/unit/output/test_blur_images.py index 2e8e20cc60..1e60046c8e 100644 --- a/tests/unit/output/test_blur_images.py +++ b/tests/unit/output/test_blur_images.py @@ -53,7 +53,7 @@ async def test_pretty_blurs_image_bytes_before_display(tmp_path, patch_central_d with ( patch("pyrit.common.notebook_utils.is_in_ipython_session", return_value=True), patch( - "pyrit.io.serializers.ImagePathDataTypeSerializer", + "pyrit.memory.storage.serializers.ImagePathDataTypeSerializer", return_value=fake_serializer, ), patch( @@ -93,7 +93,7 @@ async def test_pretty_does_not_blur_by_default(tmp_path, patch_central_database) with ( patch("pyrit.common.notebook_utils.is_in_ipython_session", return_value=True), patch( - "pyrit.io.serializers.ImagePathDataTypeSerializer", + "pyrit.memory.storage.serializers.ImagePathDataTypeSerializer", return_value=fake_serializer, ), patch( diff --git a/tests/unit/prompt_converter/test_pdf_converter.py b/tests/unit/prompt_converter/test_pdf_converter.py index 4bd72c2326..d2121baff7 100644 --- a/tests/unit/prompt_converter/test_pdf_converter.py +++ b/tests/unit/prompt_converter/test_pdf_converter.py @@ -11,7 +11,7 @@ from reportlab.lib.pagesizes import A4 from reportlab.pdfgen import canvas -from pyrit.io import DataTypeSerializer +from pyrit.memory.storage import DataTypeSerializer from pyrit.models import SeedPrompt from pyrit.prompt_converter import ConverterResult, PDFConverter From 98f166a647b900e0e8d85703a56fa86769b65756 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 4 Jun 2026 14:20:24 -0700 Subject: [PATCH 03/10] MAINT: Expose pyrit.memory.storage public API from pyrit.memory Re-export the storage IO classes and data-type serializers from the pyrit.memory package root so callers can use the shorthand (e.g. 'from pyrit.memory import AzureBlobStorageIO'). Update all external consumers to the shorthand. Internal pyrit.memory.* modules keep importing from pyrit.memory.storage directly to avoid an import cycle during package init, and the pyrit.common foundation-layer files keep the precise submodule import. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/services/attack_service.py | 3 +- pyrit/backend/services/converter_service.py | 2 +- .../seed_datasets/remote/_image_cache.py | 2 +- .../seed_datasets/remote/msts_dataset.py | 2 +- pyrit/memory/__init__.py | 36 ++++++++++++++++++- .../chat_message_normalizer.py | 2 +- pyrit/output/conversation/pretty.py | 2 +- .../add_image_text_converter.py | 2 +- .../add_image_to_video_converter.py | 2 +- .../add_text_image_converter.py | 2 +- .../prompt_converter/audio_echo_converter.py | 2 +- .../audio_frequency_converter.py | 2 +- .../prompt_converter/audio_speed_converter.py | 2 +- .../audio_volume_converter.py | 2 +- .../audio_white_noise_converter.py | 2 +- .../azure_speech_audio_to_text_converter.py | 2 +- .../azure_speech_text_to_audio_converter.py | 2 +- .../base_image_to_image_converter.py | 2 +- .../image_compression_converter.py | 2 +- .../image_overlay_converter.py | 2 +- pyrit/prompt_converter/pdf_converter.py | 2 +- pyrit/prompt_converter/qr_code_converter.py | 2 +- .../transparency_attack_converter.py | 2 +- pyrit/prompt_converter/word_doc_converter.py | 4 +-- pyrit/prompt_normalizer/prompt_normalizer.py | 3 +- .../openai/openai_chat_target.py | 2 +- .../openai/openai_image_target.py | 2 +- .../openai/openai_realtime_target.py | 2 +- .../prompt_target/openai/openai_tts_target.py | 2 +- .../openai/openai_video_target.py | 2 +- .../playwright_copilot_target.py | 2 +- .../prompt_target/websocket_copilot_target.py | 2 +- .../azure_content_filter_scorer.py | 2 +- tests/unit/common/test_display_response.py | 4 +-- .../prompt_converter/test_pdf_converter.py | 2 +- 35 files changed, 71 insertions(+), 39 deletions(-) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index d4257c0002..23d58bf623 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -50,8 +50,7 @@ from pyrit.backend.models.common import PaginationInfo from pyrit.backend.services.converter_service import get_converter_service from pyrit.backend.services.target_service import get_target_service -from pyrit.memory import CentralMemory -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import CentralMemory, data_serializer_factory from pyrit.models import ( AttackOutcome, AttackResult, diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 96daa3d72e..dca9a774dc 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -36,7 +36,7 @@ CreateConverterResponse, PreviewStep, ) -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import PromptDataType from pyrit.prompt_converter import PromptConverter from pyrit.prompt_target import PromptTarget diff --git a/pyrit/datasets/seed_datasets/remote/_image_cache.py b/pyrit/datasets/seed_datasets/remote/_image_cache.py index 4575f900ea..dd69bd5284 100644 --- a/pyrit/datasets/seed_datasets/remote/_image_cache.py +++ b/pyrit/datasets/seed_datasets/remote/_image_cache.py @@ -17,7 +17,7 @@ from typing import Any, Optional from pyrit.common.net_utility import make_request_and_raise_if_error_async -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory logger = logging.getLogger(__name__) diff --git a/pyrit/datasets/seed_datasets/remote/msts_dataset.py b/pyrit/datasets/seed_datasets/remote/msts_dataset.py index 264e00d4d2..d0b4a83b11 100644 --- a/pyrit/datasets/seed_datasets/remote/msts_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/msts_dataset.py @@ -12,7 +12,7 @@ from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import SeedDataset, SeedPrompt if TYPE_CHECKING: diff --git a/pyrit/memory/__init__.py b/pyrit/memory/__init__.py index 70acf720f5..c423f2588b 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -15,18 +15,52 @@ from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.memory_models import AttackResultEntry, EmbeddingDataEntry, PromptMemoryEntry, SeedEntry from pyrit.memory.sqlite_memory import SQLiteMemory +from pyrit.memory.storage import ( + AllowedCategories, + AudioPathDataTypeSerializer, + AzureBlobStorageIO, + BinaryPathDataTypeSerializer, + DataTypeSerializer, + DiskStorageIO, + ErrorDataTypeSerializer, + ImagePathDataTypeSerializer, + StorageIO, + SupportedContentType, + TextDataTypeSerializer, + URLDataTypeSerializer, + VideoPathDataTypeSerializer, + data_serializer_factory, + set_message_piece_sha256_async, + set_seed_sha256_async, +) __all__ = [ + "AllowedCategories", "AttackResultEntry", + "AudioPathDataTypeSerializer", + "AzureBlobStorageIO", "AzureSQLMemory", + "BinaryPathDataTypeSerializer", "CentralMemory", - "SQLiteMemory", + "DataTypeSerializer", + "data_serializer_factory", + "DiskStorageIO", "EmbeddingDataEntry", + "ErrorDataTypeSerializer", + "ImagePathDataTypeSerializer", "MemoryInterface", "MemoryEmbedding", "MemoryExporter", "PromptMemoryEntry", "SeedEntry", + "set_message_piece_sha256_async", + "set_seed_sha256_async", + "SQLiteMemory", + "StorageIO", + "SupportedContentType", + "TextDataTypeSerializer", + "URLDataTypeSerializer", + "VideoPathDataTypeSerializer", ] diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 454045794b..1613ec61cf 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -9,7 +9,7 @@ import aiofiles from pyrit.common.data_url_converter import convert_local_image_to_data_url_async -from pyrit.memory.storage import DataTypeSerializer +from pyrit.memory import DataTypeSerializer from pyrit.message_normalizer.message_normalizer import ( MessageListNormalizer, MessageStringNormalizer, diff --git a/pyrit/output/conversation/pretty.py b/pyrit/output/conversation/pretty.py index f77e751d80..10aff03afb 100644 --- a/pyrit/output/conversation/pretty.py +++ b/pyrit/output/conversation/pretty.py @@ -338,7 +338,7 @@ async def _display_image_async(self, piece: MessagePiece) -> None: if not is_in_ipython_session(): return - from pyrit.memory.storage import ImagePathDataTypeSerializer + from pyrit.memory import ImagePathDataTypeSerializer try: serializer = ImagePathDataTypeSerializer(category="", prompt_text=piece.converted_value) diff --git a/pyrit/prompt_converter/add_image_text_converter.py b/pyrit/prompt_converter/add_image_text_converter.py index 5a539a0362..e030930cca 100644 --- a/pyrit/prompt_converter/add_image_text_converter.py +++ b/pyrit/prompt_converter/add_image_text_converter.py @@ -10,7 +10,7 @@ from PIL.ImageFont import FreeTypeFont from pyrit.common.deprecation import print_deprecation_message -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.base_image_text_converter import _BaseImageTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult diff --git a/pyrit/prompt_converter/add_image_to_video_converter.py b/pyrit/prompt_converter/add_image_to_video_converter.py index b3ef5c8731..e9b162b4fd 100644 --- a/pyrit/prompt_converter/add_image_to_video_converter.py +++ b/pyrit/prompt_converter/add_image_to_video_converter.py @@ -10,7 +10,7 @@ import numpy as np from pyrit.common.path import DB_DATA_PATH -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/add_text_image_converter.py b/pyrit/prompt_converter/add_text_image_converter.py index 0261fd9212..a4576f8e41 100644 --- a/pyrit/prompt_converter/add_text_image_converter.py +++ b/pyrit/prompt_converter/add_text_image_converter.py @@ -11,7 +11,7 @@ from PIL.ImageFont import FreeTypeFont from pyrit.common.deprecation import print_deprecation_message -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.base_image_text_converter import _BaseImageTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult diff --git a/pyrit/prompt_converter/audio_echo_converter.py b/pyrit/prompt_converter/audio_echo_converter.py index b854bf57ac..176b8fa219 100644 --- a/pyrit/prompt_converter/audio_echo_converter.py +++ b/pyrit/prompt_converter/audio_echo_converter.py @@ -8,7 +8,7 @@ import numpy as np from scipy.io import wavfile -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/audio_frequency_converter.py b/pyrit/prompt_converter/audio_frequency_converter.py index dc4e42c36b..cada5a407e 100644 --- a/pyrit/prompt_converter/audio_frequency_converter.py +++ b/pyrit/prompt_converter/audio_frequency_converter.py @@ -8,7 +8,7 @@ import numpy as np from scipy.io import wavfile -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/audio_speed_converter.py b/pyrit/prompt_converter/audio_speed_converter.py index 038ee6e387..1eb81538cd 100644 --- a/pyrit/prompt_converter/audio_speed_converter.py +++ b/pyrit/prompt_converter/audio_speed_converter.py @@ -9,7 +9,7 @@ from scipy.interpolate import interp1d from scipy.io import wavfile -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/audio_volume_converter.py b/pyrit/prompt_converter/audio_volume_converter.py index dc0acc1acd..5bc088af33 100644 --- a/pyrit/prompt_converter/audio_volume_converter.py +++ b/pyrit/prompt_converter/audio_volume_converter.py @@ -8,7 +8,7 @@ import numpy as np from scipy.io import wavfile -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/audio_white_noise_converter.py b/pyrit/prompt_converter/audio_white_noise_converter.py index 9769934d24..187a0de019 100644 --- a/pyrit/prompt_converter/audio_white_noise_converter.py +++ b/pyrit/prompt_converter/audio_white_noise_converter.py @@ -8,7 +8,7 @@ import numpy as np from scipy.io import wavfile -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py index 84e15a389f..6f972812d1 100644 --- a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py +++ b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py @@ -13,7 +13,7 @@ from pyrit.auth.azure_auth import get_speech_config, get_speech_config_async from pyrit.common import default_values from pyrit.common.deprecation import print_deprecation_message -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py index 6071d25cf9..2ff7be3ef1 100644 --- a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py +++ b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py @@ -11,7 +11,7 @@ from pyrit.auth.azure_auth import get_speech_config_async from pyrit.common import default_values from pyrit.common.deprecation import print_deprecation_message -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/base_image_to_image_converter.py b/pyrit/prompt_converter/base_image_to_image_converter.py index 8ed42b5685..87434ec0e4 100644 --- a/pyrit/prompt_converter/base_image_to_image_converter.py +++ b/pyrit/prompt_converter/base_image_to_image_converter.py @@ -11,7 +11,7 @@ import aiohttp from PIL import Image -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/image_compression_converter.py b/pyrit/prompt_converter/image_compression_converter.py index 11ec45dacb..c0a1fbd27b 100644 --- a/pyrit/prompt_converter/image_compression_converter.py +++ b/pyrit/prompt_converter/image_compression_converter.py @@ -10,7 +10,7 @@ import aiohttp from PIL import Image -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/image_overlay_converter.py b/pyrit/prompt_converter/image_overlay_converter.py index b73c9c519b..ef416cfb02 100644 --- a/pyrit/prompt_converter/image_overlay_converter.py +++ b/pyrit/prompt_converter/image_overlay_converter.py @@ -7,7 +7,7 @@ from PIL import Image -from pyrit.memory.storage import DataTypeSerializer, data_serializer_factory +from pyrit.memory import DataTypeSerializer, data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/pdf_converter.py b/pyrit/prompt_converter/pdf_converter.py index 52bfefad00..3c2b54a8d6 100644 --- a/pyrit/prompt_converter/pdf_converter.py +++ b/pyrit/prompt_converter/pdf_converter.py @@ -13,7 +13,7 @@ from reportlab.pdfgen import canvas from pyrit.common.logger import logger -from pyrit.memory.storage import DataTypeSerializer, data_serializer_factory +from pyrit.memory import DataTypeSerializer, data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType, SeedPrompt from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/qr_code_converter.py b/pyrit/prompt_converter/qr_code_converter.py index 042ed54341..5098266d37 100644 --- a/pyrit/prompt_converter/qr_code_converter.py +++ b/pyrit/prompt_converter/qr_code_converter.py @@ -5,7 +5,7 @@ import segno -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/transparency_attack_converter.py b/pyrit/prompt_converter/transparency_attack_converter.py index cd8833c4cb..7301504450 100644 --- a/pyrit/prompt_converter/transparency_attack_converter.py +++ b/pyrit/prompt_converter/transparency_attack_converter.py @@ -10,7 +10,7 @@ import numpy as np from PIL import Image -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter diff --git a/pyrit/prompt_converter/word_doc_converter.py b/pyrit/prompt_converter/word_doc_converter.py index 47bb8bd010..3c005cfb88 100644 --- a/pyrit/prompt_converter/word_doc_converter.py +++ b/pyrit/prompt_converter/word_doc_converter.py @@ -12,13 +12,13 @@ from docx import Document from pyrit.common.logger import logger -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter if TYPE_CHECKING: from pathlib import Path - from pyrit.memory.storage import DataTypeSerializer + from pyrit.memory import DataTypeSerializer from pyrit.models import ComponentIdentifier, PromptDataType, SeedPrompt diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index 21ba187be3..7901a49954 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -19,8 +19,7 @@ execution_context, get_execution_context, ) -from pyrit.memory import CentralMemory, MemoryInterface -from pyrit.memory.storage import set_message_piece_sha256_async +from pyrit.memory import CentralMemory, MemoryInterface, set_message_piece_sha256_async from pyrit.models import ( ComponentIdentifier, Message, diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 5c954fb83d..a771f6e33f 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -13,7 +13,7 @@ PyritException, pyrit_target_retry, ) -from pyrit.memory.storage import DataTypeSerializer, data_serializer_factory +from pyrit.memory import DataTypeSerializer, data_serializer_factory from pyrit.models import ChatMessage, ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.target_capabilities import TargetCapabilities diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index f47c173270..e54e239051 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -11,7 +11,7 @@ EmptyResponseException, pyrit_target_retry, ) -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import ComponentIdentifier, Message, construct_response_from_request from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index bcfc616061..9d226490e6 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -15,7 +15,7 @@ pyrit_target_retry, ) from pyrit.exceptions.exception_classes import ServerErrorException -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import ComponentIdentifier, Message, construct_response_from_request from pyrit.prompt_target.common.realtime_audio import ( RealtimeTargetResult, diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index 4d6b031824..f7033d62ea 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -7,7 +7,7 @@ from pyrit.exceptions import ( pyrit_target_retry, ) -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import ComponentIdentifier, Message, construct_response_from_request from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 9e63bed485..40cffc8494 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -11,7 +11,7 @@ from pyrit.exceptions import ( pyrit_target_retry, ) -from pyrit.memory.storage import DataTypeSerializer, data_serializer_factory +from pyrit.memory import DataTypeSerializer, data_serializer_factory from pyrit.models import ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index a92ac37f49..d83a70b75c 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -9,7 +9,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Optional, Union -from pyrit.memory.storage import data_serializer_factory +from pyrit.memory import data_serializer_factory from pyrit.models import ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.models.literals import PromptDataType from pyrit.prompt_target.common.prompt_target import PromptTarget diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 53aa7aa23b..96d22f05bb 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -19,7 +19,7 @@ EmptyResponseException, pyrit_target_retry, ) -from pyrit.memory.storage import DataTypeSerializer +from pyrit.memory import DataTypeSerializer from pyrit.models import ComponentIdentifier, Message, MessagePiece, construct_response_from_request from pyrit.prompt_target import PromptTarget, limit_requests_per_minute from pyrit.prompt_target.common.target_capabilities import TargetCapabilities diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 33e74929b5..ec244b313c 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -19,7 +19,7 @@ from pyrit.auth import AsyncTokenProviderCredential, ensure_async_token_provider, get_azure_async_token_provider from pyrit.common import default_values -from pyrit.memory.storage import DataTypeSerializer, data_serializer_factory +from pyrit.memory import DataTypeSerializer, data_serializer_factory from pyrit.models import ComponentIdentifier, Message, MessagePiece, Score from pyrit.score.float_scale.float_scale_score_aggregator import ( FloatScaleScorerByCategory, diff --git a/tests/unit/common/test_display_response.py b/tests/unit/common/test_display_response.py index 1b83f1b530..faac0e90ce 100644 --- a/tests/unit/common/test_display_response.py +++ b/tests/unit/common/test_display_response.py @@ -98,7 +98,7 @@ async def test_display_image_logs_error_when_storage_io_is_none(mock_ipython, ca @patch("pyrit.common.display_response.display", create=True) async def test_display_image_azure_fallback_to_disk(mock_display, mock_image, mock_disk_io_cls, mock_ipython): """Test that when AzureBlobStorageIO read fails, it falls back to DiskStorageIO.""" - from pyrit.memory.storage import AzureBlobStorageIO + from pyrit.memory import AzureBlobStorageIO mock_memory = MagicMock() mock_azure_io = MagicMock(spec=AzureBlobStorageIO) @@ -126,7 +126,7 @@ async def test_display_image_azure_fallback_to_disk(mock_display, mock_image, mo @patch("pyrit.common.display_response.DiskStorageIO") async def test_display_image_azure_and_disk_both_fail(mock_disk_io_cls, mock_ipython, caplog): """Test that when both AzureBlobStorageIO and DiskStorageIO fail, error is logged and returns.""" - from pyrit.memory.storage import AzureBlobStorageIO + from pyrit.memory import AzureBlobStorageIO mock_memory = MagicMock() mock_azure_io = MagicMock(spec=AzureBlobStorageIO) diff --git a/tests/unit/prompt_converter/test_pdf_converter.py b/tests/unit/prompt_converter/test_pdf_converter.py index d2121baff7..9ba93066d4 100644 --- a/tests/unit/prompt_converter/test_pdf_converter.py +++ b/tests/unit/prompt_converter/test_pdf_converter.py @@ -11,7 +11,7 @@ from reportlab.lib.pagesizes import A4 from reportlab.pdfgen import canvas -from pyrit.memory.storage import DataTypeSerializer +from pyrit.memory import DataTypeSerializer from pyrit.models import SeedPrompt from pyrit.prompt_converter import ConverterResult, PDFConverter From c0e8365bea1396ba9353e8d332cf255ea0fa04d6 Mon Sep 17 00:00:00 2001 From: Richard Lundeen <137218279+rlundeen2@users.noreply.github.com> Date: Fri, 5 Jun 2026 13:13:51 -0700 Subject: [PATCH 04/10] Update pyrit/common/data_url_converter.py Co-authored-by: Roman Lutz --- pyrit/common/data_url_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/common/data_url_converter.py b/pyrit/common/data_url_converter.py index 9a81323c81..6fef6337d6 100644 --- a/pyrit/common/data_url_converter.py +++ b/pyrit/common/data_url_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from pyrit.common.deprecation import print_deprecation_message -from pyrit.memory.storage import DataTypeSerializer, data_serializer_factory +from pyrit.memory import DataTypeSerializer, data_serializer_factory # Supported image formats for Azure OpenAI GPT-4o, # https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/use-your-image-data From 439601fe39857906abe8535e8d1797c73e641ee6 Mon Sep 17 00:00:00 2001 From: Richard Lundeen <137218279+rlundeen2@users.noreply.github.com> Date: Fri, 5 Jun 2026 13:14:39 -0700 Subject: [PATCH 05/10] Update pyrit/prompt_target/openai/openai_chat_target.py Co-authored-by: Roman Lutz --- pyrit/prompt_target/openai/openai_chat_target.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index a771f6e33f..79efbc3703 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -20,7 +20,8 @@ from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p from pyrit.prompt_target.openai.openai_chat_audio_config import OpenAIChatAudioConfig -from pyrit.prompt_target.openai.openai_target import OpenAITarget +from pyrit.prompt_target import ( +TargetCapabilities, TargetConfiguration, limit_requests_per_minute, validate_temperature, validate_top_p, OpenAIChatAudioConfig, OpenAITarget) logger = logging.getLogger(__name__) From 79cbb42d94b94d8e666303c61e8c565fd205aa6e Mon Sep 17 00:00:00 2001 From: Richard Lundeen <137218279+rlundeen2@users.noreply.github.com> Date: Fri, 5 Jun 2026 13:14:55 -0700 Subject: [PATCH 06/10] Update pyrit/message_normalizer/chat_message_normalizer.py Co-authored-by: Roman Lutz --- pyrit/message_normalizer/chat_message_normalizer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 1613ec61cf..bf092a569a 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -16,8 +16,7 @@ SystemMessageBehavior, apply_system_message_behavior_async, ) -from pyrit.models import ChatMessage, Message -from pyrit.models.messages.message_piece import MessagePiece +from pyrit.models import ChatMessage, Message, MessagePiece if TYPE_CHECKING: from pyrit.models.literals import ChatMessageRole From b466d6bf12c6a2bef0ebac9cfcf526cf02c45ac9 Mon Sep 17 00:00:00 2001 From: Richard Lundeen <137218279+rlundeen2@users.noreply.github.com> Date: Fri, 5 Jun 2026 13:15:05 -0700 Subject: [PATCH 07/10] Update pyrit/message_normalizer/chat_message_normalizer.py Co-authored-by: Roman Lutz --- pyrit/message_normalizer/chat_message_normalizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index bf092a569a..383b7410a0 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -10,7 +10,7 @@ from pyrit.common.data_url_converter import convert_local_image_to_data_url_async from pyrit.memory import DataTypeSerializer -from pyrit.message_normalizer.message_normalizer import ( +from pyrit.message_normalizer import ( MessageListNormalizer, MessageStringNormalizer, SystemMessageBehavior, From 584789272c4fa73204ca7ad35bd542c83b67fb1a Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 5 Jun 2026 13:25:01 -0700 Subject: [PATCH 08/10] Fix import shortenings: prefer module-level, fall back to submodule for cycles - display_response: consolidate to single 'from pyrit.memory import' shorthand - chat_message_normalizer/openai_chat_target: revert cycle-inducing package-root self-imports to direct submodule imports - test_import_boundary: update KNOWN_COMMON_VIOLATIONS allowlist to match shortened pyrit.memory imports Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/common/display_response.py | 3 +-- pyrit/message_normalizer/chat_message_normalizer.py | 2 +- pyrit/prompt_target/openai/openai_chat_target.py | 3 +-- tests/unit/models/test_import_boundary.py | 3 +-- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/pyrit/common/display_response.py b/pyrit/common/display_response.py index aa67997e6d..aa00ff279f 100644 --- a/pyrit/common/display_response.py +++ b/pyrit/common/display_response.py @@ -8,8 +8,7 @@ from pyrit.common.deprecation import print_deprecation_message from pyrit.common.notebook_utils import is_in_ipython_session -from pyrit.memory import CentralMemory -from pyrit.memory.storage import AzureBlobStorageIO, DiskStorageIO +from pyrit.memory import AzureBlobStorageIO, CentralMemory, DiskStorageIO from pyrit.models import MessagePiece logger = logging.getLogger(__name__) diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 383b7410a0..bf092a569a 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -10,7 +10,7 @@ from pyrit.common.data_url_converter import convert_local_image_to_data_url_async from pyrit.memory import DataTypeSerializer -from pyrit.message_normalizer import ( +from pyrit.message_normalizer.message_normalizer import ( MessageListNormalizer, MessageStringNormalizer, SystemMessageBehavior, diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 79efbc3703..a771f6e33f 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -20,8 +20,7 @@ from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p from pyrit.prompt_target.openai.openai_chat_audio_config import OpenAIChatAudioConfig -from pyrit.prompt_target import ( -TargetCapabilities, TargetConfiguration, limit_requests_per_minute, validate_temperature, validate_top_p, OpenAIChatAudioConfig, OpenAITarget) +from pyrit.prompt_target.openai.openai_target import OpenAITarget logger = logging.getLogger(__name__) diff --git a/tests/unit/models/test_import_boundary.py b/tests/unit/models/test_import_boundary.py index f9721e5e2f..bcb52c1cd6 100644 --- a/tests/unit/models/test_import_boundary.py +++ b/tests/unit/models/test_import_boundary.py @@ -65,10 +65,9 @@ # layers. These are slated to relocate; the ratchet forces them to shrink. KNOWN_COMMON_VIOLATIONS: dict[str, dict[str, str]] = { "pyrit.common.data_url_converter": { - "pyrit.memory.storage": "relocate", + "pyrit.memory": "relocate", }, "pyrit.common.display_response": { - "pyrit.memory.storage": "relocate", "pyrit.memory": "relocate", "pyrit.models": "relocate", }, From d603e65dd319a32670d325206321eca9b9d9d7e9 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 5 Jun 2026 16:10:46 -0700 Subject: [PATCH 09/10] Modernize storage typing to PEP 604 after merging main - Apply ruff UP007/UP045 (Optional/Union -> X | None) to the relocated serializers.py and storage.py to match the repo-wide modernization (origin/main #1884) - Add ty pragma for azure readall() str|bytes stub mismatch surfaced by the post-merge ty hardening (#1931) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/memory/storage/serializers.py | 38 ++++++++++----------- pyrit/memory/storage/storage.py | 52 ++++++++++++++--------------- uv.lock | 2 +- 3 files changed, 46 insertions(+), 46 deletions(-) diff --git a/pyrit/memory/storage/serializers.py b/pyrit/memory/storage/serializers.py index c9367284a3..7a4e84ff14 100644 --- a/pyrit/memory/storage/serializers.py +++ b/pyrit/memory/storage/serializers.py @@ -12,7 +12,7 @@ import wave from mimetypes import guess_type from pathlib import Path -from typing import TYPE_CHECKING, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Literal, get_args from urllib.parse import urlparse import aiofiles @@ -50,8 +50,8 @@ def _write_wav_sync( def data_serializer_factory( *, data_type: PromptDataType, - value: Optional[str] = None, - extension: Optional[str] = None, + value: str | None = None, + extension: str | None = None, category: AllowedCategories, ) -> DataTypeSerializer: """ @@ -116,7 +116,7 @@ class DataTypeSerializer(abc.ABC): data_sub_directory: str file_extension: str - _file_path: Union[Path, str] | None = None + _file_path: Path | str | None = None @property def _memory(self) -> MemoryInterface: @@ -154,7 +154,7 @@ def data_on_disk(self) -> bool: """ - async def save_data_async(self, data: bytes, output_filename: Optional[str] = None) -> None: + async def save_data_async(self, data: bytes, output_filename: str | None = None) -> None: """ Save data to storage. @@ -195,7 +195,7 @@ async def save_formatted_audio_async( num_channels: int = 1, sample_width: int = 2, sample_rate: int = 16000, - output_filename: Optional[str] = None, + output_filename: str | None = None, ) -> None: """ Save PCM16 or similarly formatted audio data to storage. @@ -315,7 +315,7 @@ async def get_sha256_async(self) -> str: hash_object = hashlib.sha256(input_bytes) return hash_object.hexdigest() - async def get_data_filename_async(self, file_name: Optional[str] = None) -> Union[Path, str]: + async def get_data_filename_async(self, file_name: str | None = None) -> Path | str: """ Generate or retrieve a unique filename for the data file. @@ -361,7 +361,7 @@ async def get_data_filename_async(self, file_name: Optional[str] = None) -> Unio return self._file_path async def save_data( # pyrit-async-suffix-exempt - self, data: bytes, output_filename: Optional[str] = None + self, data: bytes, output_filename: str | None = None ) -> None: """ Save data to storage (deprecated alias of ``save_data_async``). @@ -400,7 +400,7 @@ async def save_formatted_audio( # pyrit-async-suffix-exempt num_channels: int = 1, sample_width: int = 2, sample_rate: int = 16000, - output_filename: Optional[str] = None, + output_filename: str | None = None, ) -> None: """ Save formatted audio data to storage (deprecated alias of ``save_formatted_audio_async``). @@ -462,8 +462,8 @@ async def get_sha256(self) -> str: # pyrit-async-suffix-exempt return await self.get_sha256_async() async def get_data_filename( # pyrit-async-suffix-exempt - self, file_name: Optional[str] = None - ) -> Union[Path, str]: + self, file_name: str | None = None + ) -> Path | str: """ Generate or retrieve a unique filename for the data file (deprecated alias of ``get_data_filename_async``). @@ -579,7 +579,7 @@ def data_on_disk(self) -> bool: class URLDataTypeSerializer(DataTypeSerializer): """Serializer for URL values and URL-backed local file references.""" - def __init__(self, *, category: str, prompt_text: str, extension: Optional[str] = None) -> None: + def __init__(self, *, category: str, prompt_text: str, extension: str | None = None) -> None: """ Initialize a URL serializer. @@ -609,7 +609,7 @@ def data_on_disk(self) -> bool: class ImagePathDataTypeSerializer(DataTypeSerializer): """Serializer for image path values stored on disk.""" - def __init__(self, *, category: str, prompt_text: Optional[str] = None, extension: Optional[str] = None) -> None: + def __init__(self, *, category: str, prompt_text: str | None = None, extension: str | None = None) -> None: """ Initialize an image-path serializer. @@ -644,8 +644,8 @@ def __init__( self, *, category: str, - prompt_text: Optional[str] = None, - extension: Optional[str] = None, + prompt_text: str | None = None, + extension: str | None = None, ) -> None: """ Initialize an audio-path serializer. @@ -681,8 +681,8 @@ def __init__( self, *, category: str, - prompt_text: Optional[str] = None, - extension: Optional[str] = None, + prompt_text: str | None = None, + extension: str | None = None, ) -> None: """ Initialize a video-path serializer. @@ -718,8 +718,8 @@ def __init__( self, *, category: str, - prompt_text: Optional[str] = None, - extension: Optional[str] = None, + prompt_text: str | None = None, + extension: str | None = None, ) -> None: """ Initialize a generic binary-path serializer. diff --git a/pyrit/memory/storage/storage.py b/pyrit/memory/storage/storage.py index bf79a048e0..aeb084c9c6 100644 --- a/pyrit/memory/storage/storage.py +++ b/pyrit/memory/storage/storage.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from urllib.parse import urlparse import aiofiles @@ -36,36 +36,36 @@ class StorageIO(ABC): """ @abstractmethod - async def read_file_async(self, path: Union[Path, str]) -> bytes: + async def read_file_async(self, path: Path | str) -> bytes: """ Asynchronously reads the file (or blob) from the given path. """ @abstractmethod - async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: + async def write_file_async(self, path: Path | str, data: bytes) -> None: """ Asynchronously writes data to the given path. """ @abstractmethod - async def path_exists_async(self, path: Union[Path, str]) -> bool: + async def path_exists_async(self, path: Path | str) -> bool: """ Asynchronously checks if a file or blob exists at the given path. """ @abstractmethod - async def is_file_async(self, path: Union[Path, str]) -> bool: + async def is_file_async(self, path: Path | str) -> bool: """ Asynchronously checks if the path refers to a file (not a directory or container). """ @abstractmethod - async def create_directory_if_not_exists_async(self, path: Union[Path, str]) -> None: + async def create_directory_if_not_exists_async(self, path: Path | str) -> None: """ Asynchronously creates a directory or equivalent in the storage system if it doesn't exist. """ - async def read_file(self, path: Union[Path, str]) -> bytes: # pyrit-async-suffix-exempt + async def read_file(self, path: Path | str) -> bytes: # pyrit-async-suffix-exempt """ Read a file from storage (deprecated alias of ``read_file_async``). @@ -82,7 +82,7 @@ async def read_file(self, path: Union[Path, str]) -> bytes: # pyrit-async-suffi ) return await self.read_file_async(path) - async def write_file(self, path: Union[Path, str], data: bytes) -> None: # pyrit-async-suffix-exempt + async def write_file(self, path: Path | str, data: bytes) -> None: # pyrit-async-suffix-exempt """ Write data to storage (deprecated alias of ``write_file_async``). @@ -97,7 +97,7 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None: # pyri ) await self.write_file_async(path, data) - async def path_exists(self, path: Union[Path, str]) -> bool: # pyrit-async-suffix-exempt + async def path_exists(self, path: Path | str) -> bool: # pyrit-async-suffix-exempt """ Check whether a path exists (deprecated alias of ``path_exists_async``). @@ -114,7 +114,7 @@ async def path_exists(self, path: Union[Path, str]) -> bool: # pyrit-async-suff ) return await self.path_exists_async(path) - async def is_file(self, path: Union[Path, str]) -> bool: # pyrit-async-suffix-exempt + async def is_file(self, path: Path | str) -> bool: # pyrit-async-suffix-exempt """ Check whether the given path is a file (deprecated alias of ``is_file_async``). @@ -131,7 +131,7 @@ async def is_file(self, path: Union[Path, str]) -> bool: # pyrit-async-suffix-e ) return await self.is_file_async(path) - async def create_directory_if_not_exists(self, path: Union[Path, str]) -> None: # pyrit-async-suffix-exempt + async def create_directory_if_not_exists(self, path: Path | str) -> None: # pyrit-async-suffix-exempt """ Create a directory if it does not exist (deprecated alias of ``create_directory_if_not_exists_async``). @@ -151,7 +151,7 @@ class DiskStorageIO(StorageIO): Implementation of StorageIO for local disk storage. """ - async def read_file_async(self, path: Union[Path, str]) -> bytes: + async def read_file_async(self, path: Path | str) -> bytes: """ Asynchronously reads a file from the local disk. @@ -166,7 +166,7 @@ async def read_file_async(self, path: Union[Path, str]) -> bytes: async with aiofiles.open(path, "rb") as file: return await file.read() - async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: + async def write_file_async(self, path: Path | str, data: bytes) -> None: """ Asynchronously writes data to a file on the local disk. @@ -179,7 +179,7 @@ async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: async with aiofiles.open(path, "wb") as file: await file.write(data) - async def path_exists_async(self, path: Union[Path, str]) -> bool: + async def path_exists_async(self, path: Path | str) -> bool: """ Check whether a path exists on the local disk. @@ -193,7 +193,7 @@ async def path_exists_async(self, path: Union[Path, str]) -> bool: path = self._convert_to_path(path) return path.exists() - async def is_file_async(self, path: Union[Path, str]) -> bool: + async def is_file_async(self, path: Path | str) -> bool: """ Check whether the given path is a file (not a directory). @@ -207,7 +207,7 @@ async def is_file_async(self, path: Union[Path, str]) -> bool: path = self._convert_to_path(path) return path.is_file() - async def create_directory_if_not_exists_async(self, path: Union[Path, str]) -> None: + async def create_directory_if_not_exists_async(self, path: Path | str) -> None: """ Asynchronously creates a directory if it doesn't exist on the local disk. @@ -219,7 +219,7 @@ async def create_directory_if_not_exists_async(self, path: Union[Path, str]) -> if not directory_path.exists(): directory_path.mkdir(parents=True, exist_ok=True) - def _convert_to_path(self, path: Union[Path, str]) -> Path: + def _convert_to_path(self, path: Path | str) -> Path: """ Convert an input path to a Path object. @@ -241,8 +241,8 @@ class AzureBlobStorageIO(StorageIO): def __init__( self, *, - container_url: Optional[str] = None, - sas_token: Optional[str] = None, + container_url: str | None = None, + sas_token: str | None = None, blob_content_type: SupportedContentType = SupportedContentType.PLAIN_TEXT, ) -> None: """ @@ -351,7 +351,7 @@ def parse_blob_url(self, file_path: str) -> tuple[str, str]: return container_name, blob_name raise ValueError("Invalid blob URL") - def _resolve_blob_name(self, path: Union[Path, str]) -> str: + def _resolve_blob_name(self, path: Path | str) -> str: """ Resolve a blob name from either a full blob URL or a relative blob path. @@ -377,7 +377,7 @@ def _resolve_blob_name(self, path: Union[Path, str]) -> str: except ValueError: return path_str - async def read_file_async(self, path: Union[Path, str]) -> bytes: + async def read_file_async(self, path: Path | str) -> bytes: """ Asynchronously reads the content of a file (blob) from Azure Blob Storage. @@ -411,7 +411,7 @@ async def read_file_async(self, path: Union[Path, str]) -> bytes: # Download the blob blob_stream = await blob_client.download_blob() - return bytes(await blob_stream.readall()) + return bytes(await blob_stream.readall()) # type: ignore[ty:invalid-argument-type] except Exception as exc: logger.exception(f"Failed to read file at {blob_name}: {exc}") @@ -420,7 +420,7 @@ async def read_file_async(self, path: Union[Path, str]) -> bytes: await self._client_async.close() self._client_async = None - async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: + async def write_file_async(self, path: Path | str, data: bytes) -> None: """ Write data to Azure Blob Storage at the specified path. @@ -443,7 +443,7 @@ async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: await self._client_async.close() self._client_async = None - async def path_exists_async(self, path: Union[Path, str]) -> bool: + async def path_exists_async(self, path: Path | str) -> bool: """ Check whether a given path exists in the Azure Blob Storage container. @@ -468,7 +468,7 @@ async def path_exists_async(self, path: Union[Path, str]) -> bool: await self._client_async.close() self._client_async = None - async def is_file_async(self, path: Union[Path, str]) -> bool: + async def is_file_async(self, path: Path | str) -> bool: """ Check whether the path refers to a file (blob) in Azure Blob Storage. @@ -493,7 +493,7 @@ async def is_file_async(self, path: Union[Path, str]) -> bool: await self._client_async.close() self._client_async = None - async def create_directory_if_not_exists_async(self, directory_path: Union[Path, str]) -> None: # type: ignore[ty:invalid-method-override] + async def create_directory_if_not_exists_async(self, directory_path: Path | str) -> None: # type: ignore[ty:invalid-method-override] """ Log a no-op directory creation for Azure Blob Storage. diff --git a/uv.lock b/uv.lock index 31454866e2..e4ecfb9887 100644 --- a/uv.lock +++ b/uv.lock @@ -5150,7 +5150,7 @@ wheels = [ [[package]] name = "pyrit" -version = "0.14.0.dev0" +version = "0.15.0.dev0" source = { editable = "." } dependencies = [ { name = "aiofiles" }, From 231f7b68524953b73488cd6f01946e8266d9ecce Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 5 Jun 2026 16:27:44 -0700 Subject: [PATCH 10/10] Suppress pre-existing ty errors surfaced by import relocation These 6 type errors are pre-existing on main (third-party stub and SQLAlchemy Base typing mismatches from the recent dep audit, #1931). They were ungated on main because that PR did not touch these files; relocating their imports to pyrit.memory newly subjects them to the per-file ty gate. Add minimal pragmas consistent with the existing codebase convention: - sqlite_memory/azure_sql_memory: session.get(type(entry), entry.id) - sqlite_memory: MemoryExporter.export_data(list(data), ...) - add_image_text_converter: ImageFont.truetype(str | None, ...) - openai_tts_target: AsyncSpeech.create response_format/speed - azure_content_filter_scorer: ContentSafetyClient credential Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/memory/azure_sql_memory.py | 2 +- pyrit/memory/sqlite_memory.py | 4 ++-- pyrit/prompt_converter/add_image_text_converter.py | 2 +- pyrit/prompt_target/openai/openai_tts_target.py | 4 ++-- pyrit/score/float_scale/azure_content_filter_scorer.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index a027888d7f..7d62ab5bd9 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -868,7 +868,7 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict # attributes from the (potentially stale) detached object # and silently overwrite concurrent updates to columns # that are NOT in update_fields. - entry_in_session = session.get(type(entry), entry.id) + entry_in_session = session.get(type(entry), entry.id) # type: ignore[ty:unresolved-attribute] if entry_in_session is None: entry_in_session = session.merge(entry) for field, value in update_fields.items(): diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 3bb347bfe1..61f556aee2 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -441,7 +441,7 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict # attributes from the (potentially stale) detached object # and silently overwrite concurrent updates to columns # that are NOT in update_fields. - entry_in_session = session.get(type(entry), entry.id) + entry_in_session = session.get(type(entry), entry.id) # type: ignore[ty:unresolved-attribute] if entry_in_session is None: entry_in_session = session.merge(entry) for field, value in update_fields.items(): @@ -615,7 +615,7 @@ def export_all_tables(self, *, export_type: str = "json") -> None: file_extension = f".{export_type}" file_path = DB_DATA_PATH / f"{table_name}{file_extension}" # Convert to list for exporter compatibility - self.exporter.export_data(list(data), file_path=file_path, export_type=export_type) + self.exporter.export_data(list(data), file_path=file_path, export_type=export_type) # type: ignore[ty:invalid-argument-type] def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any: """ diff --git a/pyrit/prompt_converter/add_image_text_converter.py b/pyrit/prompt_converter/add_image_text_converter.py index e030930cca..2f5678b021 100644 --- a/pyrit/prompt_converter/add_image_text_converter.py +++ b/pyrit/prompt_converter/add_image_text_converter.py @@ -193,7 +193,7 @@ def _load_font_at_size(self, size: int) -> FreeTypeFont: if self._font_load_failed: return cast("FreeTypeFont", ImageFont.load_default(size=size)) try: - return ImageFont.truetype(self._font_name, size) + return ImageFont.truetype(self._font_name, size) # type: ignore[ty:invalid-argument-type] except OSError: logger.warning(f"Cannot open font resource: {self._font_name}. Using Pillow built-in default font.") self._font_load_failed = True diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index e86b6cb9c4..03602c31a4 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -141,8 +141,8 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me model=str(body_parameters["model"]), voice=str(body_parameters["voice"]), input=str(body_parameters["input"]), - response_format=body_parameters.get("response_format"), - speed=body_parameters.get("speed"), + response_format=body_parameters.get("response_format"), # type: ignore[ty:invalid-argument-type] + speed=body_parameters.get("speed"), # type: ignore[ty:invalid-argument-type] ), request=message, ) diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 4a267322f0..99d24ef06a 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -144,7 +144,7 @@ def __init__( if callable(self._api_key): # Token provider - create an AsyncTokenCredential wrapper credential = AsyncTokenProviderCredential(self._api_key) # type: ignore[ty:invalid-argument-type] - self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential) + self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential) # type: ignore[ty:invalid-argument-type] else: # String API key if not isinstance(self._api_key, str):