diff --git a/docs/providers/documentation/gke-provider.mdx b/docs/providers/documentation/gke-provider.mdx index 2aa1b342e4..e5c92e6c82 100644 --- a/docs/providers/documentation/gke-provider.mdx +++ b/docs/providers/documentation/gke-provider.mdx @@ -13,6 +13,8 @@ import AutoGeneratedSnippet from '/snippets/providers/gke-snippet-autogenerated. 2. Ensure your service account has the necessary permissions to manage GKE clusters (`roles/container.admin`). 3. Provide the `gcp_credentials`, `project_id`, and `zone` in your provider configuration. +Alternatively, leave `service_account_json` empty to authenticate with [Application Default Credentials](https://cloud.google.com/docs/authentication/application-default-credentials), for example a GKE [Workload Identity](https://cloud.google.com/kubernetes-engine/docs/concepts/workload-identity) service account. The project is resolved automatically and can be overridden with `project_id`. + ## Usefull Links -[Google Kubernetes Engine Documentation](https://cloud.google.com/kubernetes-engine/docs) diff --git a/docs/snippets/providers/gke-snippet-autogenerated.mdx b/docs/snippets/providers/gke-snippet-autogenerated.mdx index 926cae5d61..d96a811034 100644 --- a/docs/snippets/providers/gke-snippet-autogenerated.mdx +++ b/docs/snippets/providers/gke-snippet-autogenerated.mdx @@ -1,11 +1,12 @@ -{/* This snippet is automatically generated using scripts/docs_render_provider_snippets.py +{/* This snippet is automatically generated using scripts/docs_render_provider_snippets.py Do not edit it manually, as it will be overwritten */} ## Authentication This provider requires authentication. -- **service_account_json**: The service account JSON with container.viewer role (required: True, sensitive: True) - **cluster_name**: The name of the cluster (required: True, sensitive: False) +- **service_account_json**: The service account JSON with container.viewer role. Leave empty to use Application Default Credentials (e.g. GKE Workload Identity) (required: False, sensitive: True) - **region**: The GKE cluster region (required: False, sensitive: False) +- **project_id**: The GCP project id (defaults to the service account project or the Application Default Credentials project) (required: False, sensitive: False) Certain scopes may be required to perform specific actions or queries via the provider. Below is a summary of relevant scopes and their use cases: - **roles/container.viewer**: Read access to GKE resources (mandatory) diff --git a/keep/providers/gke_provider/gke_credentials.py b/keep/providers/gke_provider/gke_credentials.py new file mode 100644 index 0000000000..83e48115dc --- /dev/null +++ b/keep/providers/gke_provider/gke_credentials.py @@ -0,0 +1,49 @@ +import json + +from google.auth import default as google_auth_default +from google.auth.credentials import Credentials +from google.auth.exceptions import DefaultCredentialsError +from google.oauth2 import service_account + +GKE_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] + + +def resolve_service_account( + service_account_json: str | None, + project_id: str | None = None, + logger=None, +) -> tuple[dict | None, str | None]: + """Parse the optional service account JSON; data is None to fall back to ADC.""" + resolved_project = project_id or None + if not service_account_json: + return None, resolved_project + try: + data = json.loads(service_account_json) + except Exception: + if logger is not None: + logger.warning( + "Invalid service_account_json provided, falling back to " + "Application Default Credentials" + ) + return None, resolved_project + return data, resolved_project or data.get("project_id") + + +def build_gke_credentials( + service_account_data: dict | None = None, + project_id: str | None = None, +) -> tuple[Credentials, str | None]: + """Return (credentials, project_id) from the service account JSON, or ADC if none.""" + if service_account_data: + credentials = service_account.Credentials.from_service_account_info( + service_account_data, scopes=GKE_SCOPES + ) + return credentials, project_id or service_account_data.get("project_id") + + try: + credentials, default_project = google_auth_default(scopes=GKE_SCOPES) + except DefaultCredentialsError as exc: + raise DefaultCredentialsError( + "No service account JSON provided and no Application Default Credentials found" + ) from exc + return credentials, project_id or default_project diff --git a/keep/providers/gke_provider/gke_provider.py b/keep/providers/gke_provider/gke_provider.py index 372b998fa1..cf93116edb 100644 --- a/keep/providers/gke_provider/gke_provider.py +++ b/keep/providers/gke_provider/gke_provider.py @@ -5,13 +5,16 @@ import pydantic from google.auth.transport import requests from google.cloud.container_v1 import ClusterManagerClient -from google.oauth2 import service_account from kubernetes import client, config from kubernetes.stream import stream from keep.contextmanager.contextmanager import ContextManager from keep.exceptions.provider_exception import ProviderException from keep.providers.base.base_provider import BaseProvider +from keep.providers.gke_provider.gke_credentials import ( + build_gke_credentials, + resolve_service_account, +) from keep.providers.models.provider_config import ProviderConfig, ProviderScope from keep.providers.models.provider_method import ProviderMethod from keep.providers.providers_factory import ProvidersFactory @@ -21,18 +24,19 @@ class GkeProviderAuthConfig: """GKE authentication configuration.""" + cluster_name: str = dataclasses.field( + metadata={"required": True, "description": "The name of the cluster"} + ) service_account_json: str = dataclasses.field( + default="", metadata={ - "required": True, - "description": "The service account JSON with container.viewer role", + "required": False, + "description": "The service account JSON with container.viewer role. Leave empty to use Application Default Credentials (e.g. GKE Workload Identity)", "sensitive": True, "type": "file", "name": "service_account_json", "file_type": "application/json", - } - ) - cluster_name: str = dataclasses.field( - metadata={"required": True, "description": "The name of the cluster"} + }, ) region: str = dataclasses.field( default="us-central1", @@ -42,6 +46,14 @@ class GkeProviderAuthConfig: "hint": "us-central1", }, ) + project_id: str = dataclasses.field( + default="", + metadata={ + "required": False, + "description": "The GCP project id (defaults to the service account project or the Application Default Credentials project)", + "hint": "my-gcp-project", + }, + ) class GkeProvider(BaseProvider): @@ -157,14 +169,11 @@ def __init__( self, context_manager: ContextManager, provider_id: str, config: ProviderConfig ): super().__init__(context_manager, provider_id, config) - try: - self._service_account_data = json.loads( - self.authentication_config.service_account_json - ) - self._project_id = self._service_account_data.get("project_id") - except Exception: - self._service_account_data = None - self._project_id = None + self._service_account_data, self._project_id = resolve_service_account( + self.authentication_config.service_account_json, + self.authentication_config.project_id, + self.logger, + ) self._region = self.authentication_config.region self._cluster_name = self.authentication_config.cluster_name self._client = None @@ -174,23 +183,28 @@ def dispose(self): if self._client: self._client.api_client.rest_client.pool_manager.clear() + def _get_credentials(self): + credentials, project_id = build_gke_credentials( + self._service_account_data, self._project_id + ) + self._project_id = project_id + return credentials + def validate_config(self): """Validate the provided configuration.""" self.authentication_config = GkeProviderAuthConfig(**self.config.authentication) def validate_scopes(self) -> dict[str, bool | str]: """Validate if the service account has the required permissions.""" - if not self._service_account_data or not self._project_id: - return {"roles/container.viewer": "Service account JSON is invalid"} - scopes = {scope.name: False for scope in self.PROVIDER_SCOPES} try: # Test GKE API permissions - credentials = service_account.Credentials.from_service_account_info( - self._service_account_data, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) + credentials = self._get_credentials() + if not self._project_id: + raise ProviderException( + "Could not resolve the GCP project id; set project_id or provide a service account JSON" + ) auth_request = requests.Request() credentials.refresh(auth_request) gke_client = ClusterManagerClient(credentials=credentials) @@ -425,10 +439,7 @@ def __generate_client(self): """Generate a Kubernetes client configured for GKE.""" try: # Create GKE client with credentials - credentials = service_account.Credentials.from_service_account_info( - self._service_account_data, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) + credentials = self._get_credentials() auth_request = requests.Request() credentials.refresh(auth_request) gke_client = ClusterManagerClient(credentials=credentials) diff --git a/tests/test_gke_credentials.py b/tests/test_gke_credentials.py new file mode 100644 index 0000000000..bac19efc19 --- /dev/null +++ b/tests/test_gke_credentials.py @@ -0,0 +1,98 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest +from google.auth.exceptions import DefaultCredentialsError + +from keep.providers.gke_provider.gke_credentials import ( + build_gke_credentials, + resolve_service_account, +) + +MODULE = "keep.providers.gke_provider.gke_credentials" + + +def test_uses_service_account_when_json_provided(): + data = {"project_id": "sa-project", "client_email": "x@y.iam"} + with ( + patch(f"{MODULE}.service_account") as mock_sa, + patch(f"{MODULE}.google_auth_default") as mock_default, + ): + credentials, project_id = build_gke_credentials(data) + mock_sa.Credentials.from_service_account_info.assert_called_once() + mock_default.assert_not_called() + assert project_id == "sa-project" + assert credentials is mock_sa.Credentials.from_service_account_info.return_value + + +def test_falls_back_to_adc_when_no_service_account(): + with ( + patch(f"{MODULE}.service_account") as mock_sa, + patch( + f"{MODULE}.google_auth_default", + return_value=("adc-creds", "adc-project"), + ) as mock_default, + ): + credentials, project_id = build_gke_credentials(None) + mock_default.assert_called_once() + mock_sa.Credentials.from_service_account_info.assert_not_called() + assert credentials == "adc-creds" + assert project_id == "adc-project" + + +def test_explicit_project_id_overrides_adc(): + with patch(f"{MODULE}.google_auth_default", return_value=("c", "adc-project")): + _, project_id = build_gke_credentials(None, project_id="explicit") + assert project_id == "explicit" + + +def test_service_account_project_used_when_no_explicit_project(): + data = {"project_id": "sa-project"} + with patch(f"{MODULE}.service_account"): + _, project_id = build_gke_credentials(data, project_id="") + assert project_id == "sa-project" + + +def test_build_gke_credentials_clear_error_when_no_adc(): + with patch( + f"{MODULE}.google_auth_default", + side_effect=DefaultCredentialsError("raw"), + ): + with pytest.raises( + DefaultCredentialsError, match="No service account JSON provided" + ): + build_gke_credentials(None) + + +def test_resolve_parses_json_and_project(): + data = {"project_id": "sa-project"} + parsed, project_id = resolve_service_account(json.dumps(data)) + assert parsed == data + assert project_id == "sa-project" + + +def test_resolve_empty_json_returns_none(): + assert resolve_service_account("") == (None, None) + + +def test_resolve_explicit_project_overrides_sa_project(): + data = {"project_id": "sa-project"} + parsed, project_id = resolve_service_account( + json.dumps(data), project_id="explicit" + ) + assert parsed == data + assert project_id == "explicit" + + +def test_resolve_malformed_json_warns_and_falls_back(): + logger = MagicMock() + parsed, project_id = resolve_service_account("{not-json", logger=logger) + assert parsed is None + assert project_id is None + logger.warning.assert_called_once() + + +def test_resolve_malformed_json_keeps_explicit_project(): + parsed, project_id = resolve_service_account("{not-json", project_id="explicit") + assert parsed is None + assert project_id == "explicit"