diff --git a/pyproject.toml b/pyproject.toml index 87b1e93b3f73..692ff425a5f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -816,7 +816,6 @@ module = [ "sentry.incidents.logic", "sentry.incidents.models.*", "sentry.incidents.receivers", - "sentry.incidents.serializers.*", "sentry.incidents.subscription_processor", "sentry.incidents.tasks", "sentry.incidents.typings.*", diff --git a/src/sentry/incidents/serializers/alert_rule.py b/src/sentry/incidents/serializers/alert_rule.py index 3968572b1cd4..04ad03b44816 100644 --- a/src/sentry/incidents/serializers/alert_rule.py +++ b/src/sentry/incidents/serializers/alert_rule.py @@ -3,10 +3,11 @@ import logging import operator from datetime import timedelta -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from sentry.models.environment import Environment + from sentry.snuba.models import SnubaQuery import sentry_sdk from django import forms @@ -132,7 +133,7 @@ def validate_environment(self, value: str | None) -> Environment | None: field.bind("environment", self) return field.to_internal_value(value) - def validate_threshold_type(self, threshold_type): + def validate_threshold_type(self, threshold_type: int) -> AlertRuleThresholdType: try: return AlertRuleThresholdType(threshold_type) except ValueError: @@ -141,7 +142,7 @@ def validate_threshold_type(self, threshold_type): % [item.value for item in AlertRuleThresholdType] ) - def validate_aggregate(self, aggregate): + def validate_aggregate(self, aggregate: str) -> str: """ Validate aggregate field and reject upsampled_count() from user input. @@ -155,17 +156,17 @@ def validate_aggregate(self, aggregate): ) return aggregate - def validate_eap_rule(self, data): + def validate_eap_rule(self, data: dict[str, Any]) -> None: """ Validate EAP rule data. """ event_types = data.get("event_types", []) if SnubaQueryEventType.EventType.TRACE_ITEM_METRIC in event_types: - aggregate = data.get("aggregate") + aggregate: str = data.get("aggregate", "") validate_trace_metrics_aggregate(aggregate) - def validate_deprecated_transactions_datasets(self, data): + def validate_deprecated_transactions_datasets(self, data: dict[str, Any]) -> None: new_dataset = data.get("dataset") organization = self.context.get("organization") if organization and features.has( @@ -176,7 +177,7 @@ def validate_deprecated_transactions_datasets(self, data): "Updating transaction-based alerts is disabled as we migrate to the spans dataset. Update the dataset to events_analytics_platform with the is_transaction:true filter instead." ) - def validate(self, data): + def validate(self, data: dict[str, Any]) -> dict[str, Any]: """ Performs validation on an alert rule's data. This includes ensuring there is either 1 or 2 triggers, which each have @@ -225,7 +226,13 @@ def validate(self, data): return data - def _translate_thresholds(self, threshold_type, comparison_delta, triggers, data): + def _translate_thresholds( + self, + threshold_type: AlertRuleThresholdType, + comparison_delta: int | None, + triggers: list[dict[str, Any]], + data: dict[str, Any], + ) -> None: """ Performs transformations on the thresholds used in the alert. Currently this is used to translate thresholds for comparison alerts. The frontend will pass in the delta percent @@ -248,7 +255,12 @@ def _translate_thresholds(self, threshold_type, comparison_delta, triggers, data for trigger in triggers: trigger["alert_threshold"] = translator(trigger["alert_threshold"]) - def _validate_trigger_thresholds(self, threshold_type, trigger, resolve_threshold): + def _validate_trigger_thresholds( + self, + threshold_type: AlertRuleThresholdType, + trigger: dict[str, Any], + resolve_threshold: int | float | None, + ) -> None: if trigger.get("alert_threshold") is None: raise serializers.ValidationError("Trigger must have an alertThreshold") @@ -274,20 +286,25 @@ def _validate_trigger_thresholds(self, threshold_type, trigger, resolve_threshol f"{trigger['label']} alert threshold must be {threshold_type.name.lower()} resolution threshold" ) - def _validate_critical_warning_triggers(self, threshold_type, critical, warning): + def _validate_critical_warning_triggers( + self, + threshold_type: AlertRuleThresholdType, + critical: dict[str, Any], + warning: dict[str, Any], + ) -> None: if threshold_type == AlertRuleThresholdType.ABOVE: alert_op = operator.lt - threshold_type = "above" + threshold_name = "above" else: alert_op = operator.gt - threshold_type = "below" + threshold_name = "below" if alert_op(critical["alert_threshold"], warning["alert_threshold"]): raise serializers.ValidationError( - f"Critical trigger must have an alert threshold {threshold_type} warning trigger" + f"Critical trigger must have an alert threshold {threshold_name} warning trigger" ) - def create(self, validated_data): + def create(self, validated_data: dict[str, Any]) -> AlertRule: # type: ignore[override] org_subscription_count = QuerySubscription.objects.filter( project__organization_id=self.context["organization"].id, status__in=( @@ -342,7 +359,7 @@ def create(self, validated_data): return alert_rule - def _apply_error_upsampling_if_needed(self, validated_data): + def _apply_error_upsampling_if_needed(self, validated_data: dict[str, Any]) -> None: """ Automatically convert count() to upsampled_count() for error alerts on upsampled projects. """ @@ -356,7 +373,7 @@ def _apply_error_upsampling_if_needed(self, validated_data): if are_any_projects_error_upsampled(project_ids): validated_data["aggregate"] = "upsampled_count()" - def update(self, instance, validated_data): + def update(self, instance: AlertRule, validated_data: dict[str, Any]) -> AlertRule: # type: ignore[override] triggers = validated_data.pop("triggers") if "id" in validated_data: validated_data.pop("id") @@ -399,7 +416,7 @@ def update(self, instance, validated_data): return alert_rule - def _handle_triggers(self, alert_rule, triggers): + def _handle_triggers(self, alert_rule: AlertRule, triggers: list[dict[str, Any]]) -> None: channel_lookup_timeout_error = None if triggers is not None: # Delete triggers we don't have present in the incoming data @@ -447,7 +464,7 @@ def _handle_triggers(self, alert_rule, triggers): if channel_lookup_timeout_error: raise channel_lookup_timeout_error - def _mark_query_as_user_updated(self, snuba_query): + def _mark_query_as_user_updated(self, snuba_query: SnubaQuery) -> None: """ Mark the snuba query as user-updated in the query_snapshot field. This is used to skip automatic migrations for queries that users have already modified. diff --git a/src/sentry/incidents/serializers/alert_rule_trigger.py b/src/sentry/incidents/serializers/alert_rule_trigger.py index e3252b86057c..574a8c7ad353 100644 --- a/src/sentry/incidents/serializers/alert_rule_trigger.py +++ b/src/sentry/incidents/serializers/alert_rule_trigger.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Any + from django import forms from django.db import router, transaction from rest_framework import serializers @@ -20,7 +24,7 @@ from .alert_rule_trigger_action import AlertRuleTriggerActionSerializer -class AlertRuleTriggerSerializer(CamelSnakeModelSerializer): +class AlertRuleTriggerSerializer(CamelSnakeModelSerializer[AlertRuleTrigger]): """ Serializer for creating/updating an alert rule trigger. Required context: - `alert_rule`: The alert_rule related to this trigger. @@ -41,7 +45,7 @@ class Meta: fields = ["id", "label", "alert_threshold", "excluded_projects", "actions"] extra_kwargs = {"label": {"min_length": 1, "max_length": 64}} - def create(self, validated_data): + def create(self, validated_data: dict[str, Any]) -> AlertRuleTrigger: with transaction.atomic(router.db_for_write(AlertRuleTrigger)): try: actions = validated_data.pop("actions", None) @@ -60,7 +64,9 @@ def create(self, validated_data): self._handle_actions(alert_rule_trigger, actions) return alert_rule_trigger - def update(self, instance, validated_data): + def update( + self, instance: AlertRuleTrigger, validated_data: dict[str, Any] + ) -> AlertRuleTrigger: actions = validated_data.pop("actions") if "id" in validated_data: validated_data.pop("id") @@ -74,7 +80,9 @@ def update(self, instance, validated_data): except AlertRuleTriggerLabelAlreadyUsedError: raise serializers.ValidationError("This label is already in use for this alert rule") - def _handle_actions(self, alert_rule_trigger, actions): + def _handle_actions( + self, alert_rule_trigger: AlertRuleTrigger, actions: list[dict[str, Any]] | None + ) -> None: channel_lookup_timeout_error = None if actions is not None: # Delete actions we don't have present in the updated data. diff --git a/src/sentry/incidents/serializers/alert_rule_trigger_action.py b/src/sentry/incidents/serializers/alert_rule_trigger_action.py index ffc53fc41a2a..7e9eaf322347 100644 --- a/src/sentry/incidents/serializers/alert_rule_trigger_action.py +++ b/src/sentry/incidents/serializers/alert_rule_trigger_action.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Any + import sentry_sdk from django.forms import ValidationError from django.utils.encoding import force_str @@ -21,11 +25,11 @@ from sentry.integrations.slack.utils.channel import validate_slack_entity_id from sentry.models.organizationmember import OrganizationMember from sentry.models.team import Team -from sentry.notifications.models.notificationaction import ActionService +from sentry.notifications.models.notificationaction import ActionService, ActionTarget from sentry.shared_integrations.exceptions import ApiRateLimitedError -class AlertRuleTriggerActionSerializer(CamelSnakeModelSerializer): +class AlertRuleTriggerActionSerializer(CamelSnakeModelSerializer[AlertRuleTriggerAction]): """ Serializer for creating/updating a trigger action. Required context: - `trigger`: The trigger related to this action. @@ -74,7 +78,7 @@ def validate_type(self, type: str) -> ActionService: raise serializers.ValidationError(f"Invalid type, valid values are {valid_slugs!r}") return factory.service_type - def validate_target_type(self, target_type): + def validate_target_type(self, target_type: str) -> ActionTarget: if target_type not in STRING_TO_ACTION_TARGET_TYPE: raise serializers.ValidationError( "Invalid targetType, valid values are [%s]" @@ -82,7 +86,7 @@ def validate_target_type(self, target_type): ) return STRING_TO_ACTION_TARGET_TYPE[target_type] - def validate(self, attrs): + def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: if ("type" in attrs) != ("target_type" in attrs) != ("target_identifier" in attrs): raise serializers.ValidationError( "type, targetType and targetIdentifier must be passed together" @@ -90,7 +94,7 @@ def validate(self, attrs): type = attrs.get("type") target_type = attrs.get("target_type") access: Access = self.context["access"] - identifier = attrs.get("target_identifier") + identifier: str = attrs.get("target_identifier", "") # Validate that target_identifier is an integer for USER and TEAM target types if target_type in ( @@ -212,7 +216,7 @@ def validate(self, attrs): ) return attrs - def create(self, validated_data): + def create(self, validated_data: dict[str, Any]) -> AlertRuleTriggerAction: for key in ("id", "sentry_app_installation_uuid"): validated_data.pop(key, None) try: @@ -238,7 +242,9 @@ def create(self, validated_data): return action - def update(self, instance, validated_data): + def update( + self, instance: AlertRuleTriggerAction, validated_data: dict[str, Any] + ) -> AlertRuleTriggerAction: for key in ("id", "sentry_app_installation_uuid"): validated_data.pop(key, None) diff --git a/tests/sentry/incidents/endpoints/test_serializers.py b/tests/sentry/incidents/endpoints/test_serializers.py index d68f3fc98530..84e6f57c0a77 100644 --- a/tests/sentry/incidents/endpoints/test_serializers.py +++ b/tests/sentry/incidents/endpoints/test_serializers.py @@ -1199,6 +1199,7 @@ def test_pagerduty_valid_priority(self, mock_get: MagicMock) -> None: serializer = AlertRuleTriggerActionSerializer(data=params, context=self.context) assert serializer.is_valid() action = serializer.save() + assert isinstance(action.sentry_app_config, dict) assert action.sentry_app_config["priority"] == "critical" @patch( @@ -1217,6 +1218,7 @@ def test_opsgenie_valid_priority(self, mock_get: MagicMock) -> None: serializer = AlertRuleTriggerActionSerializer(data=params, context=self.context) assert serializer.is_valid() action = serializer.save() + assert isinstance(action.sentry_app_config, dict) assert action.sentry_app_config["priority"] == "P1" def test_discord(self) -> None: