diff --git a/src/sentry/api/urls.py b/src/sentry/api/urls.py index b4c2c8944b3954..aece38e155df06 100644 --- a/src/sentry/api/urls.py +++ b/src/sentry/api/urls.py @@ -554,6 +554,10 @@ from sentry.seer.endpoints.organization_seer_workflows import OrganizationSeerWorkflowsEndpoint from sentry.seer.endpoints.project_seer_night_shift import ProjectSeerNightShiftEndpoint from sentry.seer.endpoints.project_seer_preferences import ProjectSeerPreferencesEndpoint +from sentry.seer.endpoints.project_seer_repos import ( + OrganizationSeerProjectRepoDetailsEndpoint, + OrganizationSeerProjectReposEndpoint, +) from sentry.seer.endpoints.search_agent_start import SearchAgentStartEndpoint from sentry.seer.endpoints.search_agent_state import SearchAgentStateEndpoint from sentry.seer.endpoints.seer_rpc import SeerRpcServiceEndpoint @@ -2462,6 +2466,16 @@ def create_group_urls(name_prefix: str) -> list[URLPattern | URLResolver]: OrganizationSeerOnboardingCheck.as_view(), name="sentry-api-0-organization-seer-onboarding-check", ), + re_path( + r"^(?P[^/]+)/seer/projects/(?P\d+)/repos/$", + OrganizationSeerProjectReposEndpoint.as_view(), + name="sentry-api-0-organization-seer-project-repos", + ), + re_path( + r"^(?P[^/]+)/seer/projects/(?P\d+)/repos/(?P\d+)/$", + OrganizationSeerProjectRepoDetailsEndpoint.as_view(), + name="sentry-api-0-organization-seer-project-repo-details", + ), re_path( r"^(?P[^/]+)/autofix/automation-settings/$", OrganizationAutofixAutomationSettingsEndpoint.as_view(), diff --git a/src/sentry/seer/autofix/utils.py b/src/sentry/seer/autofix/utils.py index 7f04bdbc53e308..5d2bff369c0a2a 100644 --- a/src/sentry/seer/autofix/utils.py +++ b/src/sentry/seer/autofix/utils.py @@ -788,6 +788,107 @@ def _set_if_not_default(key: str, value: Any, default: Any) -> None: ) +class BranchOverrideData(TypedDict): + tag_name: str + tag_value: str + branch_name: str + + +def _write_branch_overrides( + project_repo: SeerProjectRepository, branch_overrides: list[BranchOverrideData] +) -> None: + """Replace all branch overrides for the given project repo.""" + SeerProjectRepositoryBranchOverride.objects.filter( + seer_project_repository=project_repo + ).delete() + if branch_overrides: + SeerProjectRepositoryBranchOverride.objects.bulk_create( + [ + SeerProjectRepositoryBranchOverride( + seer_project_repository=project_repo, + tag_name=override["tag_name"], + tag_value=override["tag_value"], + branch_name=override["branch_name"], + ) + for override in branch_overrides + ] + ) + + +class ProjectRepoCreateData(TypedDict, total=False): + repository_id: int + branch_name: str | None + instructions: str | None + branch_overrides: list[BranchOverrideData] + + +def add_seer_project_repos(project: Project, repos_data: list[ProjectRepoCreateData]) -> list[int]: + """Connect repos to the given project. Raises ValueError if any repo is already connected.""" + repo_ids = [d["repository_id"] for d in repos_data] + + connected_ids = set( + SeerProjectRepository.objects.filter( + project=project, repository_id__in=repo_ids + ).values_list("repository_id", flat=True) + ) + if connected_ids: + raise ValueError(connected_ids) + + created_ids = [] + with transaction.atomic(router.db_for_write(SeerProjectRepository)): + list(Project.objects.select_for_update().filter(id=project.id)) + + for data in repos_data: + project_repo = SeerProjectRepository.objects.create( + project=project, + repository_id=data["repository_id"], + branch_name=data.get("branch_name"), + instructions=data.get("instructions"), + ) + _write_branch_overrides(project_repo, data.get("branch_overrides", [])) + created_ids.append(project_repo.id) + + return created_ids + + +def replace_all_seer_project_repos( + project: Project, repos_data: list[ProjectRepoCreateData] +) -> None: + """Replace all repos for the given project.""" + with transaction.atomic(router.db_for_write(SeerProjectRepository)): + list(Project.objects.select_for_update().filter(id=project.id)) + SeerProjectRepository.objects.filter(project=project).delete() + for data in repos_data: + project_repo = SeerProjectRepository.objects.create( + project=project, + repository_id=data["repository_id"], + branch_name=data.get("branch_name"), + instructions=data.get("instructions"), + ) + _write_branch_overrides(project_repo, data.get("branch_overrides", [])) + + +class ProjectRepoUpdateData(TypedDict, total=False): + branch_name: str | None + instructions: str | None + branch_overrides: list[BranchOverrideData] + + +def update_seer_project_repo( + project_repo: SeerProjectRepository, data: ProjectRepoUpdateData +) -> None: + """Update a given project repo. Raises DatabaseError if the row doesn't exist by the time we save.""" + with transaction.atomic(router.db_for_write(SeerProjectRepository)): + list(Project.objects.select_for_update().filter(id=project_repo.project_id)) + if "branch_name" in data: + project_repo.branch_name = data["branch_name"] + if "instructions" in data: + project_repo.instructions = data["instructions"] + project_repo.save(force_update=True) + if "branch_overrides" in data: + _write_branch_overrides(project_repo, data["branch_overrides"]) + + def has_project_connected_repos(organization: Organization, project: Project) -> bool: """Check if a project has connected repositories for Seer automation.""" return SeerProjectRepository.objects.filter( diff --git a/src/sentry/seer/endpoints/project_seer_repos.py b/src/sentry/seer/endpoints/project_seer_repos.py new file mode 100644 index 00000000000000..8c622c26f78fb6 --- /dev/null +++ b/src/sentry/seer/endpoints/project_seer_repos.py @@ -0,0 +1,349 @@ +from __future__ import annotations + +from collections.abc import Sequence +from functools import partial +from typing import TypedDict + +from django.db import DatabaseError, router, transaction +from django.db.models import Value +from django.db.models.functions import Replace +from rest_framework import serializers +from rest_framework.request import Request +from rest_framework.response import Response + +from sentry.api.api_owners import ApiOwner +from sentry.api.api_publish_status import ApiPublishStatus +from sentry.api.base import cell_silo_endpoint +from sentry.api.bases.organization import OrganizationEndpoint, OrganizationPermission +from sentry.api.event_search import QueryToken, SearchConfig, SearchFilter +from sentry.api.event_search import parse_search_query as base_parse_search_query +from sentry.api.paginator import OffsetPaginator +from sentry.api.serializers.rest_framework import CamelSnakeSerializer +from sentry.constants import ObjectStatus +from sentry.exceptions import InvalidSearchQuery +from sentry.models.organization import Organization +from sentry.models.project import Project +from sentry.models.repository import Repository +from sentry.seer.autofix.utils import ( + add_seer_project_repos, + replace_all_seer_project_repos, + update_seer_project_repo, +) +from sentry.seer.constants import SEER_SUPPORTED_SCM_PROVIDERS +from sentry.seer.models.project_repository import SeerProjectRepository + +SORT_FIELDS_MAPPING: dict[str, str] = { + "name": "repository__name", + "-name": "-repository__name", + "provider": "provider_normalized", + "-provider": "-provider_normalized", +} + +search_config = SearchConfig.create_from( + SearchConfig(), allowed_keys={"name", "provider"}, allow_boolean=False, free_text_key="name" +) +parse_search_query = partial(base_parse_search_query, config=search_config) + + +class BranchOverrideResponse(TypedDict): + tagName: str + tagValue: str + branchName: str + + +class ProjectRepoResponse(TypedDict): + repositoryId: str + provider: str + owner: str + name: str + externalId: str + integrationId: str | None + branchName: str | None + branchOverrides: list[BranchOverrideResponse] + instructions: str | None + + +def _serialize_project_repo(project_repo: SeerProjectRepository) -> ProjectRepoResponse: + repo = project_repo.repository + name_parts = repo.name.split("/", 1) + owner = name_parts[0] if len(name_parts) > 1 else "" + name = name_parts[1] if len(name_parts) > 1 else repo.name + + return ProjectRepoResponse( + repositoryId=str(repo.id), + provider=repo.provider or "", + owner=owner, + name=name, + externalId=repo.external_id or "", + integrationId=str(repo.integration_id) if repo.integration_id is not None else None, + branchName=project_repo.branch_name, + branchOverrides=[ + BranchOverrideResponse( + tagName=bo.tag_name, + tagValue=bo.tag_value, + branchName=bo.branch_name, + ) + for bo in project_repo.branch_overrides.all() + ], + instructions=project_repo.instructions, + ) + + +def _apply_search_filters(queryset, filters: Sequence[QueryToken]): + for f in filters: + if not isinstance(f, SearchFilter): + continue + + key = f.key.name + op = f.operator + value = f.value.value + + if key == "name": + if op == "=": + queryset = queryset.filter(repository__name__icontains=value) + elif op == "!=": + queryset = queryset.exclude(repository__name__icontains=value) + elif op == "IN": + queryset = queryset.filter(repository__name__in=value) + elif op == "NOT IN": + queryset = queryset.exclude(repository__name__in=value) + + elif key == "provider": + normalize = lambda v: v.removeprefix("integrations:") + if op == "=": + queryset = queryset.filter(provider_normalized=normalize(value)) + elif op == "!=": + queryset = queryset.exclude(provider_normalized=normalize(value)) + elif op == "IN": + queryset = queryset.filter(provider_normalized__in=[normalize(v) for v in value]) + elif op == "NOT IN": + queryset = queryset.exclude(provider_normalized__in=[normalize(v) for v in value]) + + return queryset + + +def _get_valid_repo_ids(repo_ids: list[int], organization: Organization) -> set[int]: + """Return a subset of active repo ids with Seer-supported providers belonging to the given org.""" + return set( + Repository.objects.filter( + id__in=repo_ids, + organization_id=organization.id, + status=ObjectStatus.ACTIVE, + provider__in=SEER_SUPPORTED_SCM_PROVIDERS, + ).values_list("id", flat=True) + ) + + +def _get_project_repos_queryset(project: Project): + return ( + SeerProjectRepository.objects.filter( + project=project, repository__status=ObjectStatus.ACTIVE + ) + .select_related("repository") + .prefetch_related("branch_overrides") + ) + + +class BranchOverrideSerializer(CamelSnakeSerializer): + tag_name = serializers.CharField(required=True) + tag_value = serializers.CharField(required=True) + branch_name = serializers.CharField(required=True) + + +def _validate_branch_overrides(value): + if not value: + return value + seen: set[tuple[str, str]] = set() + for override in value: + key = (override["tag_name"], override["tag_value"]) + if key in seen: + raise serializers.ValidationError( + f"Duplicate branch override for tag {key[0]}={key[1]}" + ) + seen.add(key) + return value + + +class SeerProjectRepoSerializer(CamelSnakeSerializer): + repository_id = serializers.IntegerField(required=True) + branch_name = serializers.CharField(required=False, allow_null=True, allow_blank=True) + instructions = serializers.CharField(required=False, allow_null=True, allow_blank=True) + branch_overrides = BranchOverrideSerializer( + many=True, required=False, default=list, allow_null=False + ) + + def validate_branch_overrides(self, value): + return _validate_branch_overrides(value) + + +class SeerProjectRepoUpdateSerializer(CamelSnakeSerializer): + branch_name = serializers.CharField(required=False, allow_null=True, allow_blank=True) + instructions = serializers.CharField(required=False, allow_null=True, allow_blank=True) + branch_overrides = BranchOverrideSerializer(many=True, required=False, allow_null=False) + + def validate_branch_overrides(self, value): + return _validate_branch_overrides(value) + + +class SeerProjectReposRequestSerializer(CamelSnakeSerializer): + repos = SeerProjectRepoSerializer(many=True, required=True, allow_empty=True) + + +@cell_silo_endpoint +class OrganizationSeerProjectReposEndpoint(OrganizationEndpoint): + owner = ApiOwner.ML_AI + publish_status = { + "GET": ApiPublishStatus.EXPERIMENTAL, + "POST": ApiPublishStatus.EXPERIMENTAL, + "PUT": ApiPublishStatus.EXPERIMENTAL, + } + permission_classes = (OrganizationPermission,) + + def get(self, request: Request, organization: Organization, project_id: int) -> Response: + project = self.get_projects(request, organization, project_ids={int(project_id)})[0] + + queryset = _get_project_repos_queryset(project).annotate( + # Strip the provider prefix if present, so we can order by it. + provider_normalized=Replace("repository__provider", Value("integrations:"), Value("")) + ) + + search_query = request.GET.get("query", "") + if search_query: + try: + filters = parse_search_query(search_query) + queryset = _apply_search_filters(queryset, filters) + except (InvalidSearchQuery, ValueError): + return Response({"detail": "Invalid search query"}, status=400) + + sort_by = request.GET.get("sortBy", "name") + order_by = SORT_FIELDS_MAPPING.get(sort_by) + if order_by is None: + return Response({"detail": f"Invalid sortBy: {sort_by}"}, status=400) + + return self.paginate( + request=request, + queryset=queryset, + order_by=order_by, + on_results=lambda results: [_serialize_project_repo(r) for r in results], + paginator_cls=OffsetPaginator, + ) + + def post(self, request: Request, organization: Organization, project_id: int) -> Response: + project = self.get_projects(request, organization, project_ids={int(project_id)})[0] + + serializer = SeerProjectReposRequestSerializer(data=request.data) + if not serializer.is_valid(): + return Response(serializer.errors, status=400) + + repos_data = serializer.validated_data["repos"] + if not repos_data: + return Response({"detail": "repos must not be empty."}, status=400) + + repo_ids = [r["repository_id"] for r in repos_data] + valid_repo_ids = _get_valid_repo_ids(repo_ids, organization) + invalid_repo_ids = set(repo_ids) - valid_repo_ids + if invalid_repo_ids: + return Response( + {"detail": f"Invalid repository IDs: {sorted(invalid_repo_ids)}"}, status=400 + ) + + try: + created_ids = add_seer_project_repos(project, repos_data) + except ValueError as e: + connected_ids = e.args[0] + return Response( + {"detail": f"Repositories already connected: {sorted(connected_ids)}"}, + status=409, + ) + + result = _get_project_repos_queryset(project).filter(id__in=created_ids) + return Response([_serialize_project_repo(r) for r in result], status=201) + + def put(self, request: Request, organization: Organization, project_id: int) -> Response: + project = self.get_projects(request, organization, project_ids={int(project_id)})[0] + + serializer = SeerProjectReposRequestSerializer(data=request.data) + if not serializer.is_valid(): + return Response(serializer.errors, status=400) + + repos_data = serializer.validated_data["repos"] + + if repos_data: + repo_ids = [r["repository_id"] for r in repos_data] + valid_repo_ids = _get_valid_repo_ids(repo_ids, organization) + invalid_repo_ids = set(repo_ids) - valid_repo_ids + if invalid_repo_ids: + return Response( + {"detail": f"Invalid repository IDs: {sorted(invalid_repo_ids)}"}, + status=400, + ) + + replace_all_seer_project_repos(project, repos_data) + + result = _get_project_repos_queryset(project) + return Response([_serialize_project_repo(r) for r in result]) + + +@cell_silo_endpoint +class OrganizationSeerProjectRepoDetailsEndpoint(OrganizationEndpoint): + owner = ApiOwner.ML_AI + publish_status = { + "GET": ApiPublishStatus.EXPERIMENTAL, + "PUT": ApiPublishStatus.EXPERIMENTAL, + "DELETE": ApiPublishStatus.EXPERIMENTAL, + } + permission_classes = (OrganizationPermission,) + + def _get_project_repo(self, project: Project, repo_id: int) -> SeerProjectRepository | None: + return _get_project_repos_queryset(project).filter(repository_id=repo_id).first() + + def get( + self, request: Request, organization: Organization, project_id: int, repo_id: int + ) -> Response: + project = self.get_projects(request, organization, project_ids={int(project_id)})[0] + + project_repo = self._get_project_repo(project, repo_id) + if project_repo is None: + return Response(status=404) + + return Response(_serialize_project_repo(project_repo)) + + def put( + self, request: Request, organization: Organization, project_id: int, repo_id: int + ) -> Response: + project = self.get_projects(request, organization, project_ids={int(project_id)})[0] + + serializer = SeerProjectRepoUpdateSerializer(data=request.data) + if not serializer.is_valid(): + return Response(serializer.errors, status=400) + + project_repo = self._get_project_repo(project, repo_id) + if project_repo is None: + return Response(status=404) + + try: + update_seer_project_repo(project_repo, serializer.validated_data) + except DatabaseError: + return Response(status=404) + + project_repo = self._get_project_repo(project, repo_id) + if project_repo is None: + return Response(status=404) + return Response(_serialize_project_repo(project_repo)) + + def delete( + self, request: Request, organization: Organization, project_id: int, repo_id: int + ) -> Response: + project = self.get_projects(request, organization, project_ids={int(project_id)})[0] + + with transaction.atomic(router.db_for_write(SeerProjectRepository)): + deleted_count, _ = SeerProjectRepository.objects.filter( + project=project, + repository_id=repo_id, + repository__status=ObjectStatus.ACTIVE, + ).delete() + + if deleted_count == 0: + return Response(status=404) + + return Response(status=204) diff --git a/static/app/utils/api/knownSentryApiUrls.generated.ts b/static/app/utils/api/knownSentryApiUrls.generated.ts index 9158d3bc68fd84..21c45625f91f19 100644 --- a/static/app/utils/api/knownSentryApiUrls.generated.ts +++ b/static/app/utils/api/knownSentryApiUrls.generated.ts @@ -562,6 +562,8 @@ export type KnownSentryApiUrls = | '/organizations/$organizationIdOrSlug/seer/explorer-runs/' | '/organizations/$organizationIdOrSlug/seer/explorer-update/$runId/' | '/organizations/$organizationIdOrSlug/seer/onboarding-check/' + | '/organizations/$organizationIdOrSlug/seer/projects/$projectId/repos/' + | '/organizations/$organizationIdOrSlug/seer/projects/$projectId/repos/$repoId/' | '/organizations/$organizationIdOrSlug/seer/setup-check/' | '/organizations/$organizationIdOrSlug/seer/supergroups/$supergroupId/' | '/organizations/$organizationIdOrSlug/seer/supergroups/by-group/' diff --git a/tests/sentry/seer/autofix/test_autofix_utils.py b/tests/sentry/seer/autofix/test_autofix_utils.py index 85d8a96b60cb0d..69078122cd7db5 100644 --- a/tests/sentry/seer/autofix/test_autofix_utils.py +++ b/tests/sentry/seer/autofix/test_autofix_utils.py @@ -3,6 +3,7 @@ import orjson import pytest +from django.db import DatabaseError from sentry.constants import ( SEER_AUTOMATED_RUN_STOPPING_POINT_DEFAULT, @@ -22,6 +23,7 @@ AutomationCodingAgent, CodingAgentProviderType, CodingAgentStatus, + add_seer_project_repos, bulk_read_preferences_from_sentry_db, bulk_write_preferences_to_sentry_db, clear_preference_automation_handoff, @@ -33,6 +35,8 @@ has_project_connected_repos, is_seer_seat_based_tier_enabled, read_preference_from_sentry_db, + replace_all_seer_project_repos, + update_seer_project_repo, update_seer_project_settings, write_preference_to_sentry_db, ) @@ -1726,3 +1730,241 @@ def test_deletes_option_when_value_is_default(self) -> None: assert not ProjectOption.objects.filter( project=self.project, key="sentry:seer_scanner_automation" ).exists() + + +class TestAddSeerProjectRepos(TestCase): + def setUp(self) -> None: + super().setUp() + self.project = self.create_project(organization=self.organization) + self.repo1 = self.create_repo( + project=self.project, + name="getsentry/sentry", + provider="integrations:github", + external_id="111", + ) + self.repo2 = self.create_repo( + project=self.project, + name="getsentry/relay", + provider="integrations:github", + external_id="222", + ) + + def test_creates_project_repos(self): + ids = add_seer_project_repos( + self.project, + [ + {"repository_id": self.repo1.id, "branch_name": "main", "instructions": "hello"}, + {"repository_id": self.repo2.id}, + ], + ) + assert len(ids) == 2 + pr1 = SeerProjectRepository.objects.get(project=self.project, repository=self.repo1) + assert pr1.branch_name == "main" + assert pr1.instructions == "hello" + pr2 = SeerProjectRepository.objects.get(project=self.project, repository=self.repo2) + assert pr2.branch_name is None + + def test_creates_branch_overrides(self): + add_seer_project_repos( + self.project, + [ + { + "repository_id": self.repo1.id, + "branch_overrides": [ + {"tag_name": "environment", "tag_value": "prod", "branch_name": "release"}, + ], + }, + ], + ) + project_repo = SeerProjectRepository.objects.get( + project=self.project, repository=self.repo1 + ) + overrides = list(project_repo.branch_overrides.all()) + assert len(overrides) == 1 + assert overrides[0].tag_name == "environment" + assert overrides[0].tag_value == "prod" + assert overrides[0].branch_name == "release" + + def test_raises_if_already_connected(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + + with pytest.raises(ValueError) as exc_info: + add_seer_project_repos(self.project, [{"repository_id": self.repo1.id}]) + assert self.repo1.id in exc_info.value.args[0] + + def test_returns_created_ids(self): + created_ids = add_seer_project_repos(self.project, [{"repository_id": self.repo1.id}]) + assert created_ids == list( + SeerProjectRepository.objects.filter(project=self.project).values_list("id", flat=True) + ) + + +class TestReplaceAllSeerProjectRepos(TestCase): + def setUp(self) -> None: + super().setUp() + self.project = self.create_project(organization=self.organization) + self.repo1 = self.create_repo( + project=self.project, + name="getsentry/sentry", + provider="integrations:github", + external_id="111", + ) + self.repo2 = self.create_repo( + project=self.project, + name="getsentry/relay", + provider="integrations:github", + external_id="222", + ) + + def test_replaces_existing_repos(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + + replace_all_seer_project_repos( + self.project, [{"repository_id": self.repo2.id, "branch_name": "develop"}] + ) + + assert not SeerProjectRepository.objects.filter( + project=self.project, repository=self.repo1 + ).exists() + pr2 = SeerProjectRepository.objects.get(project=self.project, repository=self.repo2) + assert pr2.branch_name == "develop" + + def test_clears_all_when_empty(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + + replace_all_seer_project_repos(self.project, []) + + assert SeerProjectRepository.objects.filter(project=self.project).count() == 0 + + def test_creates_branch_overrides(self): + replace_all_seer_project_repos( + self.project, + [ + { + "repository_id": self.repo1.id, + "branch_overrides": [ + {"tag_name": "env", "tag_value": "staging", "branch_name": "staging"}, + ], + }, + ], + ) + project_repo = SeerProjectRepository.objects.get( + project=self.project, repository=self.repo1 + ) + assert project_repo.branch_overrides.count() == 1 + + def test_clears_old_branch_overrides(self): + project_repo = SeerProjectRepository.objects.create( + project=self.project, repository=self.repo1 + ) + SeerProjectRepositoryBranchOverride.objects.create( + seer_project_repository=project_repo, + tag_name="env", + tag_value="prod", + branch_name="old", + ) + + replace_all_seer_project_repos(self.project, [{"repository_id": self.repo1.id}]) + + project_repo = SeerProjectRepository.objects.get( + project=self.project, repository=self.repo1 + ) + assert project_repo.branch_overrides.count() == 0 + + +class TestUpdateSeerProjectRepo(TestCase): + def setUp(self) -> None: + super().setUp() + self.project = self.create_project(organization=self.organization) + self.repo = self.create_repo( + project=self.project, + name="getsentry/sentry", + provider="integrations:github", + external_id="111", + ) + + def test_updates_branch_name(self): + project_repo = SeerProjectRepository.objects.create( + project=self.project, repository=self.repo, branch_name="main" + ) + + update_seer_project_repo(project_repo, {"branch_name": "develop"}) + + project_repo.refresh_from_db() + assert project_repo.branch_name == "develop" + + def test_updates_instructions(self): + project_repo = SeerProjectRepository.objects.create( + project=self.project, repository=self.repo + ) + + update_seer_project_repo(project_repo, {"instructions": "new instructions"}) + + project_repo.refresh_from_db() + assert project_repo.instructions == "new instructions" + + def test_partial_update_preserves_other_fields(self): + project_repo = SeerProjectRepository.objects.create( + project=self.project, + repository=self.repo, + branch_name="main", + instructions="original", + ) + + update_seer_project_repo(project_repo, {"branch_name": "develop"}) + + project_repo.refresh_from_db() + assert project_repo.branch_name == "develop" + assert project_repo.instructions == "original" + + def test_replaces_branch_overrides(self): + project_repo = SeerProjectRepository.objects.create( + project=self.project, repository=self.repo, branch_name="main" + ) + SeerProjectRepositoryBranchOverride.objects.create( + seer_project_repository=project_repo, + tag_name="env", + tag_value="prod", + branch_name="old", + ) + + update_seer_project_repo( + project_repo, + { + "branch_overrides": [ + {"tag_name": "env", "tag_value": "staging", "branch_name": "new"}, + ] + }, + ) + + branch_overrides = list(project_repo.branch_overrides.all()) + assert len(branch_overrides) == 1 + assert branch_overrides[0].tag_value == "staging" + + project_repo.refresh_from_db() + assert project_repo.branch_name == "main" + + def test_clears_branch_overrides(self): + project_repo = SeerProjectRepository.objects.create( + project=self.project, repository=self.repo + ) + SeerProjectRepositoryBranchOverride.objects.create( + seer_project_repository=project_repo, + tag_name="env", + tag_value="prod", + branch_name="old", + ) + + update_seer_project_repo(project_repo, {"branch_overrides": []}) + + project_repo.refresh_from_db() + assert project_repo.branch_overrides.count() == 0 + + def test_raises_database_error_if_row_deleted(self): + project_repo = SeerProjectRepository.objects.create( + project=self.project, repository=self.repo + ) + SeerProjectRepository.objects.filter(id=project_repo.id).delete() + + with pytest.raises(DatabaseError): + update_seer_project_repo(project_repo, {"branch_name": "develop"}) diff --git a/tests/sentry/seer/endpoints/test_project_seer_repos.py b/tests/sentry/seer/endpoints/test_project_seer_repos.py new file mode 100644 index 00000000000000..ee81db899e40ec --- /dev/null +++ b/tests/sentry/seer/endpoints/test_project_seer_repos.py @@ -0,0 +1,571 @@ +from django.urls import reverse + +from sentry.constants import ObjectStatus +from sentry.models.repository import Repository +from sentry.seer.models.project_repository import ( + SeerProjectRepository, + SeerProjectRepositoryBranchOverride, +) +from sentry.testutils.cases import APITestCase + + +class OrganizationSeerProjectReposGetTest(APITestCase): + endpoint = "sentry-api-0-organization-seer-project-repos" + + def reverse_url(self): + return reverse( + self.endpoint, + kwargs={ + "organization_id_or_slug": self.organization.slug, + "project_id": self.project.id, + }, + ) + + def setUp(self) -> None: + super().setUp() + self.login_as(user=self.user) + self.integration = self.create_integration( + organization=self.organization, provider="github", external_id="ext123" + ) + self.repo1 = self.create_repo( + project=self.project, + name="getsentry/sentry", + provider="integrations:github", + external_id="111", + integration_id=self.integration.id, + ) + self.repo2 = self.create_repo( + project=self.project, + name="getsentry/relay", + provider="integrations:github", + external_id="222", + integration_id=self.integration.id, + ) + + def test_empty(self): + response = self.get_success_response() + assert len(response.data) == 0 + + def test_returns_connected_repos(self): + SeerProjectRepository.objects.create( + project=self.project, + repository=self.repo1, + branch_name="main", + instructions="use pytest", + ) + SeerProjectRepository.objects.create(project=self.project, repository=self.repo2) + + response = self.get_success_response() + assert len(response.data) == 2 + + project_repos_by_name = {r["name"]: r for r in response.data} + project_repo_sentry = project_repos_by_name["sentry"] + assert project_repo_sentry["repositoryId"] == str(self.repo1.id) + assert project_repo_sentry["provider"] == "integrations:github" + assert project_repo_sentry["owner"] == "getsentry" + assert project_repo_sentry["externalId"] == "111" + assert project_repo_sentry["integrationId"] == str(self.integration.id) + assert project_repo_sentry["branchName"] == "main" + assert project_repo_sentry["instructions"] == "use pytest" + + project_repo_relay = project_repos_by_name["relay"] + assert project_repo_relay["repositoryId"] == str(self.repo2.id) + assert project_repo_relay["branchName"] is None + assert project_repo_relay["instructions"] is None + + def test_excludes_inactive_repos(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + self.repo1.status = ObjectStatus.HIDDEN + self.repo1.save() + + response = self.get_success_response() + assert len(response.data) == 0 + + def test_returns_branch_overrides(self): + project_repo = SeerProjectRepository.objects.create( + project=self.project, repository=self.repo1 + ) + SeerProjectRepositoryBranchOverride.objects.create( + seer_project_repository=project_repo, + tag_name="environment", + tag_value="production", + branch_name="release", + ) + + response = self.get_success_response() + assert len(response.data[0]["branchOverrides"]) == 1 + branch_overrides = response.data[0]["branchOverrides"][0] + assert branch_overrides["tagName"] == "environment" + assert branch_overrides["tagValue"] == "production" + assert branch_overrides["branchName"] == "release" + + def test_search_by_name(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + SeerProjectRepository.objects.create(project=self.project, repository=self.repo2) + + response = self.get_success_response(qs_params={"query": "relay"}) + assert len(response.data) == 1 + assert response.data[0]["name"] == "relay" + + def test_search_by_name_exclude(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + SeerProjectRepository.objects.create(project=self.project, repository=self.repo2) + + response = self.get_success_response(qs_params={"query": "!name:relay"}) + assert len(response.data) == 1 + assert response.data[0]["name"] == "sentry" + + def test_search_by_provider(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + SeerProjectRepository.objects.create(project=self.project, repository=self.repo2) + + response = self.get_success_response(qs_params={"query": "provider:github"}) + assert len(response.data) == 2 + + response = self.get_success_response(qs_params={"query": "provider:integrations:github"}) + assert len(response.data) == 2 + + def test_sort_by_name_ascending(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + SeerProjectRepository.objects.create(project=self.project, repository=self.repo2) + + response = self.get_success_response(qs_params={"sortBy": "name"}) + names = [r["name"] for r in response.data] + assert names == ["relay", "sentry"] + + def test_sort_by_name_descending(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + SeerProjectRepository.objects.create(project=self.project, repository=self.repo2) + + response = self.get_success_response(qs_params={"sortBy": "-name"}) + names = [r["name"] for r in response.data] + assert names == ["sentry", "relay"] + + def test_invalid_sort_field(self): + response = self.get_error_response(qs_params={"sortBy": "invalid"}, status_code=400) + assert "Invalid sortBy" in response.data["detail"] + + def test_invalid_search_query(self): + self.get_error_response(qs_params={"query": "invalid:field:value"}, status_code=400) + + +class OrganizationSeerProjectReposPostTest(APITestCase): + endpoint = "sentry-api-0-organization-seer-project-repos" + method = "post" + + def reverse_url(self): + return reverse( + self.endpoint, + kwargs={ + "organization_id_or_slug": self.organization.slug, + "project_id": self.project.id, + }, + ) + + def setUp(self) -> None: + super().setUp() + self.login_as(user=self.user) + self.repo1 = self.create_repo( + project=self.project, + name="getsentry/sentry", + provider="integrations:github", + external_id="111", + ) + self.repo2 = self.create_repo( + project=self.project, + name="getsentry/relay", + provider="integrations:github", + external_id="222", + ) + + def test_add_repos(self): + response = self.get_success_response( + repos=[ + { + "repositoryId": self.repo1.id, + "branchName": "main", + "instructions": "run tests", + }, + {"repositoryId": self.repo2.id}, + ], + status_code=201, + ) + assert len(response.data) == 2 + + project_repos_by_id = {r["repositoryId"]: r for r in response.data} + assert project_repos_by_id[str(self.repo1.id)]["branchName"] == "main" + assert project_repos_by_id[str(self.repo1.id)]["instructions"] == "run tests" + assert project_repos_by_id[str(self.repo2.id)]["branchName"] is None + + assert SeerProjectRepository.objects.filter(project=self.project).count() == 2 + + def test_add_repos_with_branch_overrides(self): + response = self.get_success_response( + repos=[ + { + "repositoryId": self.repo1.id, + "branchOverrides": [ + { + "tagName": "environment", + "tagValue": "production", + "branchName": "release", + } + ], + } + ], + status_code=201, + ) + assert len(response.data[0]["branchOverrides"]) == 1 + assert response.data[0]["branchOverrides"][0]["branchName"] == "release" + + def test_empty_repos_returns_400(self): + response = self.get_error_response(repos=[], status_code=400) + assert "repos must not be empty" in response.data["detail"] + + def test_invalid_repo_id_returns_400(self): + response = self.get_error_response(repos=[{"repositoryId": 99999}], status_code=400) + assert "Invalid repository IDs" in response.data["detail"] + + def test_repo_from_other_org_returns_400(self): + other_org = self.create_organization(owner=self.user) + other_repo = Repository.objects.create( + organization_id=other_org.id, name="other/repo", provider="github", external_id="999" + ) + + self.get_error_response(repos=[{"repositoryId": other_repo.id}], status_code=400) + + def test_unsupported_provider_returns_400(self): + unsupported_repo = Repository.objects.create( + organization_id=self.organization.id, + name="getsentry/unsupported", + provider="integrations:gitlab", + external_id="999", + ) + + self.get_error_response(repos=[{"repositoryId": unsupported_repo.id}], status_code=400) + + def test_already_connected_repo_returns_409(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + + response = self.get_error_response(repos=[{"repositoryId": self.repo1.id}], status_code=409) + assert "Repositories already connected" in response.data["detail"] + + def test_inactive_repo_returns_400(self): + self.repo1.status = ObjectStatus.HIDDEN + self.repo1.save() + + self.get_error_response(repos=[{"repositoryId": self.repo1.id}], status_code=400) + + def test_duplicate_branch_override_returns_400(self): + self.get_error_response( + repos=[ + { + "repositoryId": self.repo1.id, + "branchOverrides": [ + { + "tagName": "environment", + "tagValue": "production", + "branchName": "release", + }, + { + "tagName": "environment", + "tagValue": "production", + "branchName": "hotfix", + }, + ], + } + ], + status_code=400, + ) + + def test_missing_repository_id_returns_400(self): + self.get_error_response(repos=[{"branchName": "main"}], status_code=400) + + +class OrganizationSeerProjectReposPutTest(APITestCase): + endpoint = "sentry-api-0-organization-seer-project-repos" + method = "put" + + def reverse_url(self): + return reverse( + self.endpoint, + kwargs={ + "organization_id_or_slug": self.organization.slug, + "project_id": self.project.id, + }, + ) + + def setUp(self) -> None: + super().setUp() + self.login_as(user=self.user) + self.repo1 = self.create_repo( + project=self.project, + name="getsentry/sentry", + provider="integrations:github", + external_id="111", + ) + self.repo2 = self.create_repo( + project=self.project, + name="getsentry/relay", + provider="integrations:github", + external_id="222", + ) + + def test_replace_all_repos(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + + response = self.get_success_response( + repos=[{"repositoryId": self.repo2.id, "branchName": "develop"}], + ) + assert len(response.data) == 1 + assert response.data[0]["repositoryId"] == str(self.repo2.id) + assert response.data[0]["branchName"] == "develop" + + assert not SeerProjectRepository.objects.filter( + project=self.project, repository=self.repo1 + ).exists() + + def test_replace_with_empty_clears_all(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + SeerProjectRepository.objects.create(project=self.project, repository=self.repo2) + + response = self.get_success_response(repos=[]) + assert response.data == [] + assert SeerProjectRepository.objects.filter(project=self.project).count() == 0 + + def test_replace_invalid_repo_returns_400(self): + self.get_error_response(repos=[{"repositoryId": 99999}], status_code=400) + + def test_replace_with_branch_overrides(self): + response = self.get_success_response( + repos=[ + { + "repositoryId": self.repo1.id, + "branchOverrides": [ + { + "tagName": "environment", + "tagValue": "staging", + "branchName": "staging-branch", + } + ], + } + ], + ) + assert len(response.data[0]["branchOverrides"]) == 1 + + +class OrganizationSeerProjectRepoDetailsGetTest(APITestCase): + endpoint = "sentry-api-0-organization-seer-project-repo-details" + + def detail_url(self, repo_id): + return reverse( + self.endpoint, + kwargs={ + "organization_id_or_slug": self.organization.slug, + "project_id": self.project.id, + "repo_id": repo_id, + }, + ) + + def reverse_url(self): + return self.detail_url(self.repo1.id) + + def setUp(self) -> None: + super().setUp() + self.login_as(user=self.user) + self.repo1 = self.create_repo( + project=self.project, + name="getsentry/sentry", + provider="integrations:github", + external_id="111", + ) + + def test_get_repo(self): + project_repo = SeerProjectRepository.objects.create( + project=self.project, repository=self.repo1, branch_name="main", instructions="hello" + ) + SeerProjectRepositoryBranchOverride.objects.create( + seer_project_repository=project_repo, + tag_name="environment", + tag_value="production", + branch_name="release", + ) + + response = self.get_success_response() + assert response.data["repositoryId"] == str(self.repo1.id) + assert response.data["branchName"] == "main" + assert response.data["instructions"] == "hello" + assert len(response.data["branchOverrides"]) == 1 + assert response.data["branchOverrides"][0]["tagName"] == "environment" + assert response.data["branchOverrides"][0]["tagValue"] == "production" + assert response.data["branchOverrides"][0]["branchName"] == "release" + + def test_not_connected_returns_404(self): + self.get_error_response(status_code=404) + + def test_inactive_repo_returns_404(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + self.repo1.status = ObjectStatus.HIDDEN + self.repo1.save() + + self.get_error_response(status_code=404) + + def test_nonexistent_repo_returns_404(self): + response = self.client.get(self.detail_url(99999)) + assert response.status_code == 404 + + +class OrganizationSeerProjectRepoDetailsPutTest(APITestCase): + endpoint = "sentry-api-0-organization-seer-project-repo-details" + method = "put" + + def reverse_url(self): + return reverse( + self.endpoint, + kwargs={ + "organization_id_or_slug": self.organization.slug, + "project_id": self.project.id, + "repo_id": self.repo1.id, + }, + ) + + def setUp(self) -> None: + super().setUp() + self.login_as(user=self.user) + self.repo1 = self.create_repo( + project=self.project, + name="getsentry/sentry", + provider="integrations:github", + external_id="111", + ) + + def test_update_branch_name(self): + SeerProjectRepository.objects.create( + project=self.project, repository=self.repo1, branch_name="main" + ) + + response = self.get_success_response(branchName="develop") + assert response.data["branchName"] == "develop" + + def test_update_instructions(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + + response = self.get_success_response(instructions="new instructions") + assert response.data["instructions"] == "new instructions" + + def test_update_branch_overrides(self): + pr = SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + SeerProjectRepositoryBranchOverride.objects.create( + seer_project_repository=pr, + tag_name="environment", + tag_value="production", + branch_name="old-branch", + ) + + response = self.get_success_response( + branchOverrides=[ + { + "tagName": "environment", + "tagValue": "staging", + "branchName": "staging-branch", + } + ], + ) + assert len(response.data["branchOverrides"]) == 1 + assert response.data["branchOverrides"][0]["tagValue"] == "staging" + + assert ( + SeerProjectRepositoryBranchOverride.objects.filter(seer_project_repository=pr).count() + == 1 + ) + + def test_partial_update_preserves_other_fields(self): + SeerProjectRepository.objects.create( + project=self.project, repository=self.repo1, branch_name="main", instructions="original" + ) + + response = self.get_success_response(branchName="develop") + assert response.data["branchName"] == "develop" + assert response.data["instructions"] == "original" + + def test_not_connected_returns_404(self): + self.get_error_response(branchName="main", status_code=404) + + def test_set_null_branch_name(self): + SeerProjectRepository.objects.create( + project=self.project, repository=self.repo1, branch_name="main" + ) + + response = self.get_success_response(branchName=None) + assert response.data["branchName"] is None + + def test_clear_branch_overrides(self): + pr = SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + SeerProjectRepositoryBranchOverride.objects.create( + seer_project_repository=pr, + tag_name="environment", + tag_value="production", + branch_name="release", + ) + + response = self.get_success_response(branchOverrides=[]) + assert response.data["branchOverrides"] == [] + assert ( + SeerProjectRepositoryBranchOverride.objects.filter(seer_project_repository=pr).count() + == 0 + ) + + +class OrganizationSeerProjectRepoDetailsDeleteTest(APITestCase): + endpoint = "sentry-api-0-organization-seer-project-repo-details" + method = "delete" + + def reverse_url(self): + return reverse( + self.endpoint, + kwargs={ + "organization_id_or_slug": self.organization.slug, + "project_id": self.project.id, + "repo_id": self.repo1.id, + }, + ) + + def setUp(self) -> None: + super().setUp() + self.login_as(user=self.user) + self.repo1 = self.create_repo( + project=self.project, + name="getsentry/sentry", + provider="integrations:github", + external_id="111", + ) + + def test_delete_repo(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + + self.get_success_response() + assert not SeerProjectRepository.objects.filter( + project=self.project, repository=self.repo1 + ).exists() + + def test_delete_not_connected_returns_404(self): + self.get_error_response(status_code=404) + + def test_delete_inactive_repo_returns_404(self): + SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + self.repo1.status = ObjectStatus.HIDDEN + self.repo1.save() + + self.get_error_response(status_code=404) + + def test_delete_cascades_branch_overrides(self): + pr = SeerProjectRepository.objects.create(project=self.project, repository=self.repo1) + SeerProjectRepositoryBranchOverride.objects.create( + seer_project_repository=pr, + tag_name="environment", + tag_value="production", + branch_name="release", + ) + + self.get_success_response() + assert SeerProjectRepositoryBranchOverride.objects.count() == 0