diff --git a/cms/auth_content/constants.py b/cms/auth_content/constants.py index 0920faa1f0..1a12935387 100644 --- a/cms/auth_content/constants.py +++ b/cms/auth_content/constants.py @@ -3,7 +3,6 @@ get_all_theme_names_and_ids, ) -WILDCARD_ID_VALUE = "-1" PERMISSION_SET_FIELDS = [ { "field_name": "theme", diff --git a/cms/auth_content/models/permission_sets.py b/cms/auth_content/models/permission_sets.py index ba5e2ca8bd..4b01792298 100644 --- a/cms/auth_content/models/permission_sets.py +++ b/cms/auth_content/models/permission_sets.py @@ -6,7 +6,7 @@ from wagtail.admin.panels import FieldPanel, mark_safe from cms.auth_content.auth_utils import _create_form_field -from cms.auth_content.constants import PERMISSION_SET_FIELDS, WILDCARD_ID_VALUE +from cms.auth_content.constants import PERMISSION_SET_FIELDS from cms.dynamic_content import help_texts from cms.metrics_interface.field_choices_callables import ( get_all_geography_names_and_codes, @@ -16,6 +16,7 @@ get_all_theme_names_and_ids, get_all_topic_names_and_ids, ) +from common.auth.permissions import WILDCARD_ID_VALUE class PermissionSetForm(WagtailAdminPageForm): diff --git a/cms/dashboard/viewsets.py b/cms/dashboard/viewsets.py index aee25c8d4e..38deec9ef1 100644 --- a/cms/dashboard/viewsets.py +++ b/cms/dashboard/viewsets.py @@ -1,4 +1,3 @@ -import logging from itertools import chain from django.urls import path @@ -10,43 +9,15 @@ from caching.private_api.decorators import cache_response from cms.auth_content.auth_utils import is_auth_enabled -from cms.auth_content.constants import WILDCARD_ID_VALUE from cms.dashboard.serializers import CMSDraftPagesSerializer, ListablePageSerializer from cms.metrics_documentation.models.child import MetricsDocumentationChildEntry from cms.topic.models import TopicPage +from common.auth.logging import log_user_permission_summary +from common.auth.permissions import check_page_permissions -logger = logging.getLogger(__name__) AUTH_ENABLED = is_auth_enabled() -def check_permissions(user_permissions, theme_id, sub_theme_id, topic_id) -> bool: - if not isinstance(user_permissions, list): - return False - - for permission in user_permissions: - permission_theme_id = permission.get("theme", {}).get("id") - permission_sub_theme_id = permission.get("sub_theme", {}).get("id") - permission_topic_id = permission.get("topic", {}).get("id") - - if permission_theme_id == WILDCARD_ID_VALUE: - return True - - if ( - permission_theme_id == theme_id - and permission_sub_theme_id == WILDCARD_ID_VALUE - ): - return True - - if ( - permission_theme_id == theme_id - and permission_sub_theme_id == sub_theme_id - and (permission_topic_id in {WILDCARD_ID_VALUE, topic_id}) - ): - return True - - return False - - @extend_schema(tags=["cms"]) class CMSPagesAPIViewSet(PagesAPIViewSet): # This is the /pages (or proxy/pages env dependent endpoint) @@ -109,19 +80,14 @@ def get_queryset(self): filtered_queryset = is_public_pages | pages_without_is_public else: - logger.info( - "User %s has total permission sets: %s", - req.user.username, - req.user.permission_sets["summary"]["total_permission_sets"], - ) + log_user_permission_summary(req.user) + has_global_access = req.user.permission_sets["summary"][ "has_global_access" ] if has_global_access: - logger.info("User %s has global access", req.user.username) filtered_queryset = queryset - else: user_permissions = req.user.permission_sets["permission_sets"] pages_to_check = chain( @@ -138,11 +104,11 @@ def get_queryset(self): page_id for page_id, page in pages_to_check if page.is_public - or check_permissions( - user_permissions, - page.theme, - page.sub_theme, - page.topic, + or check_page_permissions( + permission_sets=user_permissions, + theme_id=page.theme, + sub_theme_id=page.sub_theme, + topic_id=page.topic, ) ] diff --git a/common/auth/cognito_jwt/backend.py b/common/auth/cognito_jwt/backend.py index 1fed42fec9..86efd6f11b 100644 --- a/common/auth/cognito_jwt/backend.py +++ b/common/auth/cognito_jwt/backend.py @@ -50,16 +50,19 @@ def authenticate(self, request): raise exceptions.AuthenticationFailed from None custom_user_manager = self.get_custom_user_manager() + if custom_user_manager: user = custom_user_manager.get_or_create_for_cognito(jwt_payload) else: user_model = self.get_user_model() user = user_model.objects.get_or_create_for_cognito(jwt_payload) + if not user: logger.debug( "Unable to create user from JWT, defaulting to unauthenticated" ) return None + return (user, jwt_token) @staticmethod diff --git a/common/auth/logging.py b/common/auth/logging.py new file mode 100644 index 0000000000..6c63424b45 --- /dev/null +++ b/common/auth/logging.py @@ -0,0 +1,43 @@ +"""Utilities for logging authentication and permission information across the API.""" + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def log_user_permission_summary(user: Any) -> None: + """Log permission information for an authenticated user. + + This function logs the permission set summary and global access status. + It expects ``user.permission_sets`` to be a dict with the shape produced + by ``CognitoManager.get_or_create_for_cognito``: + + .. code-block:: python + + { + "permission_sets": [...], + "summary": {"total_permission_sets": 2, "has_global_access": False}, + } + + Args: + user: The authenticated user object that has a ``permission_sets`` dict. + """ + + if not hasattr(user, "username"): + return + if not hasattr(user, "permission_sets"): + return + + username = user.username + permission_sets = user.permission_sets + + if not isinstance(permission_sets, dict): + return + + log_msg = f'User {username} has total permission sets {permission_sets["summary"]["total_permission_sets"]}' + + if permission_sets["summary"]["has_global_access"]: + log_msg += " and global access" + + logger.info(log_msg) diff --git a/common/auth/permissions.py b/common/auth/permissions.py new file mode 100644 index 0000000000..3b9d2c91d8 --- /dev/null +++ b/common/auth/permissions.py @@ -0,0 +1,354 @@ +from typing import TypedDict + +from ingestion.metrics_interface.interface import MetricsAPIInterface + +WILDCARD_ID_VALUE = "-1" + +""" + A few classes with type hints to represent our complete + JWT permission set hierarchy. Please do import and use + this from other modules too to keep us safe & well: +""" + + +class PermissionRowType(TypedDict): + theme: dict[str, str] + sub_theme: dict[str, str] + topic: dict[str, str] + metric: dict[str, str] + geography_type: dict[str, str] + geography: dict[str, str] + + +class PermissionSetSummaryType(TypedDict): + has_global_access: bool + + +class PermissionSetsType(TypedDict): + permission_sets: list[PermissionRowType] + summary: PermissionSetSummaryType + + +def check_chart_permissions_by_name( + *, + permission_sets: PermissionSetsType, + theme_name: str, + sub_theme_name: str, + topic_name: str, + metric_name: str, + geography_type: str, + geography_name: str, +) -> bool: + """Convert permission resource names into ids (before checking CHART permissions).""" + + if not isinstance(permission_sets, dict): + return False + if not isinstance(permission_sets.get("permission_sets"), list): + return False + if not isinstance(permission_sets.get("summary"), dict): + return False + if not isinstance(permission_sets.get("summary").get("has_global_access"), bool): + return False + + if permission_sets.get("summary").get("has_global_access"): + return True + + topic_manager = MetricsAPIInterface.get_topic_manager() + metric_manager = MetricsAPIInterface.get_metric_manager() + geography_type_manager = MetricsAPIInterface.get_geography_type_manager() + geography_manager = MetricsAPIInterface.get_geography_manager() + + theme_id, sub_theme_id, topic_id = topic_manager.get_id_by_name( + theme_name, sub_theme_name, topic_name + ) + metric_id = metric_manager.get_id_by_name(metric_name) + geography_type_id = geography_type_manager.get_id_by_name(geography_type) + geography_id = geography_manager.get_code_by_name(geography_name, geography_type) + + # Sanity check, because front-end must always + # send content for any of these 6 requests + if any( + value is None + for value in ( + theme_id, + sub_theme_id, + topic_id, + metric_id, + geography_type_id, + geography_id, + ) + ): + return False + + return check_chart_permissions( + permission_sets=permission_sets.get("permission_sets"), + theme_id=theme_id, + sub_theme_id=sub_theme_id, + topic_id=topic_id, + metric_id=metric_id, + geography_type=geography_type_id, + geography_id=geography_id, + ) + + +def check_chart_permissions( # noqa: PLR0914 + *, + permission_sets: list[PermissionRowType], + theme_id: str, + sub_theme_id: str, + topic_id: str, + metric_id: str, + geography_type: str, + geography_id: str, +) -> bool: + """Check permissions whether the end-user can access a specific CHART through the API.""" + + if not isinstance(permission_sets, list): + return False + + resource_ids = _normalize_resource_ids( + theme_id, + sub_theme_id, + topic_id, + metric_id, + geography_type, + geography_id, + ) + if resource_ids is None: + return False + ( + theme_id, + sub_theme_id, + topic_id, + metric_id, + geography_type, + geography_id, + ) = resource_ids + + for permission_set in permission_sets: + if not isinstance(permission_set, dict): + return False + + permission_ids = _normalize_permission_ids( + "theme", + "sub_theme", + "topic", + "metric", + "geography_type", + "geography", + permission_set=permission_set, + ) + + # All permission fields must be present + if permission_ids is None: + return False + ( + permission_theme_id, + permission_sub_theme_id, + permission_topic_id, + permission_metric_id, + permission_geography_type, + permission_geography_id, + ) = permission_ids + + # Themes, sub themes, topics & metrics have their own + # dependency hierarchy (means wildcards can be at the end) + has_theme_sub_theme_topic_permissions = check_theme_sub_theme_topic_permissions( + permission_theme_id=permission_theme_id, + permission_sub_theme_id=permission_sub_theme_id, + permission_topic_id=permission_topic_id, + theme_id=theme_id, + sub_theme_id=sub_theme_id, + topic_id=topic_id, + ) + has_metric_permissions = check_metric_permissions( + permission_metric_id=permission_metric_id, + metric_id=metric_id, + ) + + # Geographies have their own dependency hierarchy too + has_geography_permissions = check_geography_permissions( + permission_geography_type=permission_geography_type, + permission_geography_id=permission_geography_id, + geography_type=geography_type, + geography_id=geography_id, + ) + + if ( + has_theme_sub_theme_topic_permissions + and has_metric_permissions + and has_geography_permissions + ): + return True + + return False + + +def check_page_permissions( + *, + permission_sets: list[PermissionRowType], + theme_id: str, + sub_theme_id: str, + topic_id: str, +) -> bool: + """Check permissions whether the end-user can access a specific CMS PAGE through the API.""" + + if not isinstance(permission_sets, list): + return False + + resource_ids = _normalize_resource_ids(theme_id, sub_theme_id, topic_id) + if resource_ids is None: + return False + theme_id, sub_theme_id, topic_id = resource_ids + + for permission_set in permission_sets: + if not isinstance(permission_set, dict): + return False + + # Theme must be present, but other permission fields are + # optional, as wildcard hierarchy allows early short-circuit + permission_theme_id = _normalize_permission_id( + field_name="theme", permission_set=permission_set + ) + if permission_theme_id is None: + return False + permission_sub_theme_id = ( + _normalize_permission_id( + field_name="sub_theme", permission_set=permission_set + ) + or "" + ) + permission_topic_id = ( + _normalize_permission_id(field_name="topic", permission_set=permission_set) + or "" + ) + + if check_theme_sub_theme_topic_permissions( + permission_theme_id=permission_theme_id, + permission_sub_theme_id=permission_sub_theme_id, + permission_topic_id=permission_topic_id, + theme_id=theme_id, + sub_theme_id=sub_theme_id, + topic_id=topic_id, + ): + return True + + return False + + +def check_theme_sub_theme_topic_permissions( + *, + permission_theme_id: str, + permission_sub_theme_id: str, + permission_topic_id: str, + theme_id: str, + sub_theme_id: str, + topic_id: str, +) -> bool: + """ + Evaluate the theme/sub-theme/topic portion of a permission row + with its own dependency hierarchy (means wildcards can be at the end) + """ + + if permission_theme_id == WILDCARD_ID_VALUE: + return True + + if permission_theme_id == theme_id and permission_sub_theme_id == WILDCARD_ID_VALUE: + return True + + if ( # noqa: SIM103 + permission_theme_id == theme_id + and permission_sub_theme_id == sub_theme_id + and (permission_topic_id in {WILDCARD_ID_VALUE, topic_id}) + ): + return True + + return False + + +def check_metric_permissions( + *, + permission_metric_id: str, + metric_id: str, +) -> bool: + """ + Evaluate the metric portion of a permission row + for it to be either a wildcard or a match. + """ + + if permission_metric_id in {WILDCARD_ID_VALUE, metric_id}: # noqa: SIM103 + return True + + return False + + +def check_geography_permissions( + *, + permission_geography_type: str, + permission_geography_id: str, + geography_type: str, + geography_id: str, +) -> bool: + """ + Evaluate the geography_type/geography portion of a permission row + with its own dependency hierarchy (means wildcards can be at the end) + """ + + if permission_geography_type == WILDCARD_ID_VALUE: + return True + + if ( # noqa: SIM103 + permission_geography_type == geography_type + and permission_geography_id in {WILDCARD_ID_VALUE, geography_id} + ): + return True + + return False + + +def _get_id_string_or_none(my_id: int | str | None) -> str | None: + """Normalize id to string whilst preserving None values""" + + return str(my_id) if my_id is not None else None + + +def _normalize_resource_ids(*ids: int | str | None) -> tuple[str, ...] | None: + """Normalize all resource ids and return them as tuple of strings.""" + + normalized_ids = tuple(_get_id_string_or_none(my_id) for my_id in ids) + + if _has_missing_ids(*normalized_ids): + return None + + return normalized_ids + + +def _normalize_permission_ids( + *field_names: str, + permission_set: PermissionRowType | dict, +) -> tuple[str, ...] | None: + """Extract and normalize permission ids as tuple of strings.""" + + normalized_ids = tuple( + _normalize_permission_id(field_name=field_name, permission_set=permission_set) + for field_name in field_names + ) + + if _has_missing_ids(*normalized_ids): + return None + + return normalized_ids + + +def _normalize_permission_id( + *, field_name: str, permission_set: PermissionRowType | dict +) -> str | None: + """Extract and normalize the permission id from a permission row.""" + + return _get_id_string_or_none(permission_set.get(field_name, {}).get("id")) + + +def _has_missing_ids(*ids: str | None) -> bool: + """Check if any required id is missing, and if so normalize it to be None.""" + + return any(my_id is None for my_id in ids) diff --git a/metrics/api/views/downloads/single_category_downloads.py b/metrics/api/views/downloads/single_category_downloads.py index ef02e6829b..d2ba555227 100644 --- a/metrics/api/views/downloads/single_category_downloads.py +++ b/metrics/api/views/downloads/single_category_downloads.py @@ -8,7 +8,6 @@ from rest_framework.views import APIView from caching.private_api.decorators import cache_response -from metrics.api.decorators.auth import require_authorisation from metrics.api.serializers import ( CoreHeadlineSerializer, CoreTimeSeriesSerializer, @@ -101,7 +100,6 @@ def _handle_csv( @extend_schema(request=DownloadsSerializer, tags=[DOWNLOADS_API_TAG]) @cache_response() - @require_authorisation def post(self, request, *args, **kwargs): """This endpoint will return the query output in json/csv format diff --git a/metrics/data/managers/core_models/geography.py b/metrics/data/managers/core_models/geography.py index 5ff2ce8cb6..f280df4cbd 100644 --- a/metrics/data/managers/core_models/geography.py +++ b/metrics/data/managers/core_models/geography.py @@ -47,6 +47,42 @@ def get_name_by_code(self, geography_code: str) -> str | None: .first() ) + def get_id_by_name( + self, geography_name: str, geography_type_name: str + ) -> int | None: + """ + Gets the geography ID for a given geography name. A geography_type_name + must also be provided because of geographies that share the same name + across different geography types (e.g. "Liverpool"). + + Returns: + The geography ID if found, or None otherwise + """ + record = self.filter( + name=geography_name, geography_type__name=geography_type_name + ).first() + + return int(record.id) if record else None + + def get_code_by_name( + self, geography_name: str, geography_type_name: str + ) -> str | None: + """ + Gets the geography code for a given geography name. A geography_type_name + must also be provided because of geographies that share the same name + across different geography types (e.g. "Liverpool"). + + The geography_type__name filter performs the inner join to GeographyType. + + Returns: + The geography_code (eg E10000011) if found, or None otherwise + """ + return ( + self.filter(name=geography_name, geography_type__name=geography_type_name) + .values_list("geography_code", flat=True) + .first() + ) + def get_all_geography_codes_by_geography_type( self, geography_type_name: str ) -> Self: @@ -167,6 +203,36 @@ def get_name_by_code(self, geography_code: int) -> str | None: """ return self.get_queryset().get_name_by_code(geography_code) + def get_id_by_name( + self, geography_name: str, geography_type_name: str + ) -> int | None: + """ + Gets the geography ID for a given geography name. A geography_type_name + must also be provided because of geographies that share the same name + across different geography types (e.g. "Liverpool"). + + Returns: + The geography ID if found, or None otherwise + """ + return self.get_queryset().get_id_by_name( + geography_name, geography_type_name=geography_type_name + ) + + def get_code_by_name( + self, geography_name: str, geography_type_name: str + ) -> str | None: + """ + Gets the geography code for a given geography name. A geography_type_name + must also be provided because of geographies that share the same name + across different geography types (e.g. "Liverpool"). + + Returns: + The geography_code (eg E10000011) if found, or None otherwise + """ + return self.get_queryset().get_code_by_name( + geography_name, geography_type_name=geography_type_name + ) + def get_all_names(self) -> GeographyQuerySet: """Gets all available deduplicated geography names as a flat list queryset. diff --git a/metrics/data/managers/core_models/geography_type.py b/metrics/data/managers/core_models/geography_type.py index ae45b36b72..8a93f13a09 100644 --- a/metrics/data/managers/core_models/geography_type.py +++ b/metrics/data/managers/core_models/geography_type.py @@ -41,6 +41,19 @@ def get_name_by_id(self, geography_type_id: int) -> str | None: """ return self.filter(id=geography_type_id).values_list("name", flat=True).first() + def get_id_by_name(self, geography_type_name: str) -> int | None: + """ + Gets the geography type ID for a given geography type name. + + Args: + geography_type_name: The name of the geography type to look up + + Returns: + The geography type ID if found, or None otherwise + """ + record = self.filter(name=geography_type_name).first() + return int(record.id) if record else None + def get_all_names_and_ids(self) -> models.QuerySet: """Gets all available geography_type names as a flat list queryset. @@ -77,6 +90,18 @@ def get_name_by_id(self, geography_type_id: int) -> str | None: """ return self.get_queryset().get_name_by_id(geography_type_id) + def get_id_by_name(self, geography_type_name: str) -> int | None: + """ + Gets the geography type ID for a given geography type name. + + Args: + geography_type_name: The name of the geography type to look up + + Returns: + The geography type ID if found, or None otherwise + """ + return self.get_queryset().get_id_by_name(geography_type_name) + def get_all_names(self) -> GeographyTypeQuerySet: """Gets all available geography_type names as a flat list queryset. diff --git a/metrics/data/managers/core_models/metric.py b/metrics/data/managers/core_models/metric.py index e9fcb961d9..a70fffc81d 100644 --- a/metrics/data/managers/core_models/metric.py +++ b/metrics/data/managers/core_models/metric.py @@ -29,6 +29,16 @@ def get_name_by_id(self, metric_id: int) -> str | None: """ return self.filter(id=metric_id).values_list("name", flat=True).first() + def get_id_by_name(self, metric_name: str) -> int | None: + """ + Gets the metric ID for a given metric name. + + Returns: + The metric ID if found, or None otherwise + """ + record = self.filter(name=metric_name).first() + return int(record.id) if record else None + def get_all_names(self) -> models.QuerySet: """Gets all available metric names as a flat list queryset. @@ -146,6 +156,15 @@ def get_name_by_id(self, metric_id: int) -> str | None: """ return self.get_queryset().get_name_by_id(metric_id) + def get_id_by_name(self, metric_name: str) -> int | None: + """ + Gets the metric ID for a given metric name. + + Returns: + The metric ID if found, or None otherwise + """ + return self.get_queryset().get_id_by_name(metric_name) + def get_all_names(self) -> MetricQuerySet: """Gets all available metric names as a flat list queryset. diff --git a/metrics/data/managers/core_models/time_series.py b/metrics/data/managers/core_models/time_series.py index aed8df7f66..d0bb965980 100644 --- a/metrics/data/managers/core_models/time_series.py +++ b/metrics/data/managers/core_models/time_series.py @@ -13,9 +13,9 @@ from django.db.models.query_utils import Q from django.utils import timezone +from common.auth.permissions import PermissionSetsType, check_chart_permissions_by_name from metrics.api.permissions.fluent_permissions import ( is_public_data_only_enforced, - validate_permissions_for_non_public, ) from metrics.data.models import RBACPermission @@ -157,7 +157,7 @@ def filter_for_audit_list_view( return self._ascending_order(queryset=queryset, field_name="date") - def query_for_data( + def query_for_data( # noqa: PLR0914 self, *, topic: str, @@ -171,8 +171,10 @@ def query_for_data( stratum: str | None = None, sex: str | None = None, age: str | None = None, + theme: str = "", + sub_theme: str = "", metric_value_ranges: list[tuple[str | float | int]] | None = None, - restrict_to_public: bool = True, + permission_sets: PermissionSetsType | None = None, ) -> models.QuerySet: """Filters for a N-item list of dicts by the given params if `fields_to_export` is used. @@ -212,14 +214,18 @@ def query_for_data( Note that options are `M`, `F`, or `ALL`. age: The age range to apply additional filtering to. E.g. `0_4` would be used to capture the age of 0-4 years old + theme: The name of the theme being queried. + This is only used to determine permissions for + the non-public portion of the requested dataset. + sub_theme: The name of the sub theme being queried. + This is only used to determine permissions for + the non-public portion of the requested dataset. metric_value_ranges: List of tuples whereby each tuple represents a permissible metric value range. i.e. to filter for all record with values between 0 -> 80 AND 90 -> 100, this can be provided as `[(0, 80), (90, 100)]`. - restrict_to_public: Boolean switch to restrict the query - to only return public records. - If False, then non-public records will be included. + permission_sets: The JWT permissions extracted from the Cognito token. Returns: QuerySet: An ordered queryset from lowest -> highest @@ -231,6 +237,7 @@ def query_for_data( ]>` """ + queryset = self.filter( metric__topic__name=topic, metric__name=metric, @@ -246,7 +253,18 @@ def query_for_data( age=age, ) - if restrict_to_public: + if permission_sets and check_chart_permissions_by_name( + permission_sets=permission_sets, + theme_name=theme, + sub_theme_name=sub_theme, + topic_name=topic, + metric_name=metric, + geography_type=geography_type, + geography_name=geography, + ): + # Keep both the public and non-public data + pass + else: queryset = queryset.filter(is_public=True) queryset = self._exclude_data_under_embargo(queryset=queryset) @@ -533,6 +551,7 @@ def query_for_data( sub_theme: str = "", metric_value_ranges: list[str | float | int] | None = None, rbac_permissions: Iterable[RBACPermission] | None = None, + permission_sets: PermissionSetsType | None = None, ) -> CoreTimeSeriesQuerySet: """Filters for a 2-item object by the given params. Slices all values older than the `date_from`. @@ -579,9 +598,7 @@ def query_for_data( i.e. to filter for all record with values between 0 -> 80 AND 90 -> 100, this can be provided as `[(0, 80), (90, 100)]`. - rbac_permissions: The RBAC permissions available - to the given request. This dictates whether the given - request is permitted access to non-public data or not. + permission_sets: The JWT permissions extracted from the Cognito token. Notes: If we have the following input `queryset`: @@ -611,20 +628,12 @@ def query_for_data( ]>` """ - rbac_permissions: Iterable[RBACPermission] = rbac_permissions or [] - has_access_to_non_public_data: bool = validate_permissions_for_non_public( - theme=theme, - sub_theme=sub_theme, - topic=topic, - metric=metric, - geography_type=geography_type, - geography=geography, - rbac_permissions=rbac_permissions, - ) return self.get_queryset().query_for_data( fields_to_export=fields_to_export, field_to_order_by=field_to_order_by, + theme=theme, + sub_theme=sub_theme, topic=topic, metric=metric, date_from=date_from, @@ -635,7 +644,7 @@ def query_for_data( sex=sex, age=age, metric_value_ranges=metric_value_ranges, - restrict_to_public=not has_access_to_non_public_data, + permission_sets=permission_sets, ) def query_for_superseded_data( diff --git a/metrics/data/managers/core_models/topic.py b/metrics/data/managers/core_models/topic.py index 00b6f06320..7a1342efa4 100644 --- a/metrics/data/managers/core_models/topic.py +++ b/metrics/data/managers/core_models/topic.py @@ -40,6 +40,35 @@ def get_name_by_id(self, topic_id: int) -> str | None: """ return self.filter(id=topic_id).values_list("name", flat=True).first() + def get_id_by_name( + self, theme_name: str, sub_theme_name: str, topic_name: str + ) -> tuple[int | None, int | None, int | None]: + """ + Gets the theme, sub-theme and topic IDs matching the given names. + + Returns: + A tuple of (theme_id, sub_theme_id, topic_id) if found, + or (None, None, None) if not found. + """ + record = self.filter( + sub_theme__theme__name=theme_name, + sub_theme__name=sub_theme_name, + name=topic_name, + ).first() + + if record: + return ( + int(record.sub_theme.theme_id), + int(record.sub_theme_id), + int(record.id), + ) + + return ( + None, + None, + None, + ) + def get_all_unique_names(self) -> models.QuerySet: """Gets all available unique topic names as a flat list queryset. @@ -113,6 +142,20 @@ def get_name_by_id(self, topic_id: int) -> str | None: """ return self.get_queryset().get_name_by_id(topic_id) + def get_id_by_name( + self, theme_name: str, sub_theme_name: str, topic_name: str + ) -> tuple[int | None, int | None, int | None]: + """ + Gets the theme, sub-theme and topic IDs matching the given names. + + Returns: + A tuple of (theme_id, sub_theme_id, topic_id) if found, + or (None, None, None) if not found. + """ + return self.get_queryset().get_id_by_name( + theme_name, sub_theme_name, topic_name + ) + def get_all_names(self) -> TopicQuerySet: """Gets all available topic names as a flat list queryset. diff --git a/metrics/domain/models/charts/common.py b/metrics/domain/models/charts/common.py index b610ceff1d..e3c1136306 100644 --- a/metrics/domain/models/charts/common.py +++ b/metrics/domain/models/charts/common.py @@ -1,12 +1,10 @@ -from collections.abc import Iterable from decimal import Decimal from typing import Literal -from pydantic.main import BaseModel -from rest_framework.request import Request +from metrics.domain.models.common import BaseRequestParams -class BaseChartRequestParams(BaseModel): +class ChartBaseRequestParams(BaseRequestParams): file_format: Literal["png", "svg", "jpg", "jpeg", "json", "csv"] chart_width: int chart_height: int @@ -17,15 +15,7 @@ class BaseChartRequestParams(BaseModel): y_axis_minimum_value: Decimal | int = 0 y_axis_maximum_value: Decimal | int | None = None legend_title: str | None = "" - request: Request | None = None confidence_intervals: bool | None = False confidence_colour: str | None = "" is_public: bool | None = True data_classification: str | None = None - - class Config: - arbitrary_types_allowed = True - - @property - def rbac_permissions(self) -> Iterable["RBACPermission"]: - return getattr(self.request, "rbac_permissions", []) diff --git a/metrics/domain/models/charts/dual_category_charts.py b/metrics/domain/models/charts/dual_category_charts.py index 09e6f73ccb..27f48c10c3 100644 --- a/metrics/domain/models/charts/dual_category_charts.py +++ b/metrics/domain/models/charts/dual_category_charts.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from metrics.domain.models.charts.common import BaseChartRequestParams +from metrics.domain.models.charts.common import ChartBaseRequestParams from metrics.domain.models.charts.segments import SegmentParameters @@ -18,7 +18,7 @@ class StaticFields(BaseModel): date_to: str -class DualCategoryChartRequestParams(BaseChartRequestParams): +class DualCategoryChartRequestParams(ChartBaseRequestParams): chart_type: str secondary_category: str primary_field_values: list[str] diff --git a/metrics/domain/models/charts/subplot_charts.py b/metrics/domain/models/charts/subplot_charts.py index 3d92f3322d..3b745a3067 100644 --- a/metrics/domain/models/charts/subplot_charts.py +++ b/metrics/domain/models/charts/subplot_charts.py @@ -1,4 +1,3 @@ -from collections.abc import Iterable from decimal import Decimal from typing import Literal @@ -6,24 +5,17 @@ from rest_framework.request import Request from metrics.domain.models import ChartRequestParams +from metrics.domain.models.common import BaseRequestParams from metrics.domain.models.plots import PlotParameters OPTIONAL_STRING = str | None -class Subplots(BaseModel): +class Subplots(BaseRequestParams): subplot_title: str x_axis: str y_axis: str plots: list[PlotParameters] - request: Request | None = None - - class Config: - arbitrary_types_allowed = True - - @property - def rbac_permissions(self) -> Iterable["RBACPermission"]: - return getattr(self.request, "rbac_permissions", []) """ diff --git a/metrics/domain/models/common.py b/metrics/domain/models/common.py new file mode 100644 index 0000000000..813d831126 --- /dev/null +++ b/metrics/domain/models/common.py @@ -0,0 +1,24 @@ +from collections.abc import Iterable + +from pydantic.main import BaseModel +from rest_framework.request import Request + +from common.auth.permissions import PermissionSetsType + + +class BaseRequestParams(BaseModel): + request: Request | None = None + + class Config: + arbitrary_types_allowed = True + + @property + def permission_sets(self) -> PermissionSetsType: + """Extract optional JWT permissions from the authenticated request""" + + request_user = getattr(self.request, "user", None) + return getattr(request_user, "permission_sets", {}) + + @property + def rbac_permissions(self) -> Iterable["RBACPermission"]: + return getattr(self.request, "rbac_permissions", []) diff --git a/metrics/domain/models/headline.py b/metrics/domain/models/headline.py index 30b737ef8a..dce65414c5 100644 --- a/metrics/domain/models/headline.py +++ b/metrics/domain/models/headline.py @@ -1,10 +1,7 @@ -from collections.abc import Iterable +from metrics.domain.models.common import BaseRequestParams -from pydantic.main import BaseModel -from rest_framework.request import Request - -class HeadlineParameters(BaseModel): +class HeadlineParameters(BaseRequestParams): topic: str metric: str stratum: str @@ -14,10 +11,6 @@ class HeadlineParameters(BaseModel): age: str is_public: bool | None = True data_classification: str | None = None - request: Request | None = None - - class Config: - arbitrary_types_allowed = True @property def topic_name(self) -> str: @@ -68,7 +61,3 @@ def to_dict_for_query(self) -> dict[str, str]: "sex": self.sex_name, "rbac_permissions": self.rbac_permissions, } - - @property - def rbac_permissions(self) -> Iterable["RBACPermission"]: - return getattr(self.request, "rbac_permissions", []) diff --git a/metrics/domain/models/map.py b/metrics/domain/models/map.py index c9bf5ea0f5..ea143e4c7a 100644 --- a/metrics/domain/models/map.py +++ b/metrics/domain/models/map.py @@ -1,8 +1,8 @@ import datetime -from collections.abc import Iterable from pydantic.main import BaseModel -from rest_framework.request import Request + +from metrics.domain.models.common import BaseRequestParams OPTIONAL_STRING = str | None @@ -37,17 +37,8 @@ class MapAccompanyingPoint(BaseModel): parameters: MapAccompanyingPointOptionalParameters -class MapsParameters(BaseModel): +class MapsParameters(BaseRequestParams): date_from: datetime.date date_to: datetime.date parameters: MapMainParameters accompanying_points: list[MapAccompanyingPoint] - - request: Request | None = None - - class Config: - arbitrary_types_allowed = True - - @property - def rbac_permissions(self) -> Iterable["RBACPermission"]: - return getattr(self.request, "rbac_permissions", []) diff --git a/metrics/domain/models/plots.py b/metrics/domain/models/plots.py index 07d614160b..5093c99915 100644 --- a/metrics/domain/models/plots.py +++ b/metrics/domain/models/plots.py @@ -13,7 +13,7 @@ DataSourceFileType, extract_metric_group_from_metric, ) -from metrics.domain.models.charts.common import BaseChartRequestParams +from metrics.domain.models.charts.common import ChartBaseRequestParams class PlotParameters(BaseModel): @@ -141,7 +141,7 @@ def line_colour_enum(self) -> RGBAChartLineColours: return RGBAChartLineColours.BLACK -class ChartRequestParams(BaseChartRequestParams): +class ChartRequestParams(ChartBaseRequestParams): """Holds all the request information / params for a chart in its entirety.""" metric_group: str | None = None diff --git a/metrics/interfaces/plots/access.py b/metrics/interfaces/plots/access.py index 6e4eac34d9..ed1832e91d 100644 --- a/metrics/interfaces/plots/access.py +++ b/metrics/interfaces/plots/access.py @@ -162,7 +162,9 @@ def get_queryset_from_core_model_manager( plot_params["fields_to_export"].append("lower_confidence") return self.core_model_manager.query_for_data( - **plot_params, rbac_permissions=self.chart_request_params.rbac_permissions + **plot_params, + rbac_permissions=self.chart_request_params.rbac_permissions, # old permissions (remove) + permission_sets=self.chart_request_params.permission_sets, # new permissions ) def build_plot_data_from_parameters_with_complete_queryset( diff --git a/tests/fakes/managers/time_series_manager.py b/tests/fakes/managers/time_series_manager.py index b1391facf0..5fb591571c 100644 --- a/tests/fakes/managers/time_series_manager.py +++ b/tests/fakes/managers/time_series_manager.py @@ -54,7 +54,9 @@ def query_for_data( sex: str | None = None, age: str | None = None, rbac_permissions: Iterable[FakeRBACPermission] | None = None, + permission_sets: dict | None = None, metric_value_ranges: list[tuple] | None = None, + **kwargs, ) -> FakeQuerySet: date_from = _convert_string_to_date(date_string=date_from) diff --git a/tests/integration/metrics/api/views/test_geographies.py b/tests/integration/metrics/api/views/test_geographies.py index 685595f630..21dcb006fc 100644 --- a/tests/integration/metrics/api/views/test_geographies.py +++ b/tests/integration/metrics/api/views/test_geographies.py @@ -4,7 +4,7 @@ from rest_framework.response import Response from rest_framework.test import APIClient -from cms.auth_content.constants import WILDCARD_ID_VALUE +from common.auth.permissions import WILDCARD_ID_VALUE from tests.factories.metrics.geography import GeographyFactory from tests.factories.metrics.time_series import CoreTimeSeriesFactory from validation.geography_code import UNITED_KINGDOM_GEOGRAPHY_CODE diff --git a/tests/integration/metrics/api/views/test_permission_sets.py b/tests/integration/metrics/api/views/test_permission_sets.py index 1bb4a90966..2c9b22f64d 100644 --- a/tests/integration/metrics/api/views/test_permission_sets.py +++ b/tests/integration/metrics/api/views/test_permission_sets.py @@ -4,7 +4,7 @@ from rest_framework.response import Response from rest_framework.test import APIClient -from cms.auth_content.constants import WILDCARD_ID_VALUE +from common.auth.permissions import WILDCARD_ID_VALUE from tests.factories.metrics.metric import MetricFactory from tests.factories.metrics.sub_theme import SubThemeFactory from tests.factories.metrics.topic import TopicFactory diff --git a/tests/integration/metrics/api/views/test_user.py b/tests/integration/metrics/api/views/test_user.py index 1264f058e1..e8ad4f5a0f 100644 --- a/tests/integration/metrics/api/views/test_user.py +++ b/tests/integration/metrics/api/views/test_user.py @@ -5,7 +5,7 @@ from rest_framework.response import Response from rest_framework.test import APIClient -from cms.auth_content.constants import WILDCARD_ID_VALUE +from common.auth.permissions import WILDCARD_ID_VALUE from tests.factories.auth_content.models.permission_sets import PermissionSetFactory from tests.factories.auth_content.models.users import UserFactory from tests.factories.metrics.metric import MetricFactory diff --git a/tests/integration/metrics/data/managers/core_models/test_geography.py b/tests/integration/metrics/data/managers/core_models/test_geography.py index 51eacad90f..77fa001e30 100644 --- a/tests/integration/metrics/data/managers/core_models/test_geography.py +++ b/tests/integration/metrics/data/managers/core_models/test_geography.py @@ -119,3 +119,70 @@ def test_get_name_by_code(self): # Access the dictionary returned by .first() result = get_name_by_code assert result == geography_two.name + + @pytest.mark.django_db + @pytest.mark.parametrize( + "lookup_name, expected_index", + [ + ("England", 0), + ("London", 1), + ("NON-EXISTENT", None), + ], + ) + def test_get_id_by_name(self, lookup_name: str, expected_index: int | None): + """ + Given some Geography records + When get_id_by_name() is called + Then the matching geography_id is returned, or None if no match + """ + + # Given + given_geographies = [ + GeographyFactory.create_with_geography_type( + name="England", + geography_code="DUMMY", + geography_type="DUMMY", + ), + GeographyFactory.create_with_geography_type( + name="London", + geography_code="DUMMY", + geography_type="DUMMY", + ), + ] + + # When + geography_id = Geography.objects.get_id_by_name(lookup_name, "DUMMY") + + # Then + expected_id = ( + given_geographies[expected_index].id if expected_index is not None else None + ) + assert geography_id == expected_id + + @pytest.mark.django_db + def test_get_code_by_name(self): + """ + Given some Geography records that share a name across geography types + When get_code_by_name() is called with a specific geography_type + Then the matching geography_code is returned + """ + + # Given + GeographyFactory.create_with_geography_type( + name="Liverpool", + geography_code="E08000012", + geography_type="Lower Tier Local Authority", + ) + liverpool_combined_authority = GeographyFactory.create_with_geography_type( + name="Liverpool", + geography_code="E47000004", + geography_type="Combined Authority", + ) + + # When + geography_code = Geography.objects.get_code_by_name( + "Liverpool", "Combined Authority" + ) + + # Then + assert geography_code == liverpool_combined_authority.geography_code diff --git a/tests/integration/metrics/data/managers/core_models/test_geography_types.py b/tests/integration/metrics/data/managers/core_models/test_geography_types.py index 39afb13802..a28d66b97a 100644 --- a/tests/integration/metrics/data/managers/core_models/test_geography_types.py +++ b/tests/integration/metrics/data/managers/core_models/test_geography_types.py @@ -61,3 +61,36 @@ def test_get_name_by_id(self): # Access the dictionary returned by .first() result = get_name_by_id assert result == "Region" + + @pytest.mark.django_db + @pytest.mark.parametrize( + "lookup_name, expected_index", + [ + ("Region", 0), + ("Nation", 1), + ("NON-EXISTENT", None), + ], + ) + def test_get_id_by_name(self, lookup_name: str, expected_index: int | None): + """ + Given some GeographyType records + When get_id_by_name() is called + Then the matching geography_type_id is returned, or None if no match + """ + + # Given + given_geography_types: list[GeographyType] = [ + GeographyTypeFactory(name="Region", with_geographies=["DUMMY"]), + GeographyTypeFactory(name="Nation", with_geographies=["DUMMY"]), + ] + + # When + geography_type_id = GeographyType.objects.get_id_by_name(lookup_name) + + # Then + expected_id = ( + given_geography_types[expected_index].id + if expected_index is not None + else None + ) + assert geography_type_id == expected_id diff --git a/tests/integration/metrics/data/managers/core_models/test_metric.py b/tests/integration/metrics/data/managers/core_models/test_metric.py index b3337e0e19..f5dea413d2 100644 --- a/tests/integration/metrics/data/managers/core_models/test_metric.py +++ b/tests/integration/metrics/data/managers/core_models/test_metric.py @@ -116,3 +116,40 @@ def test_get_name_by_id(self): # Then assert get_name_by_id == "COVID-19_headline_ONSdeaths_7DayChange" + + @pytest.mark.django_db + @pytest.mark.parametrize( + "lookup_name, expected_index", + [ + ("COVID-19_deaths_ONSByDay", 0), + ("COVID-19_deaths_ONSByWeek", 1), + ("NON-EXISTENT", None), + ], + ) + def test_get_id_by_name(self, lookup_name: str, expected_index: int | None): + """ + Given some Metric records + When get_id_by_name() is called + Then the matching metric id is returned, or None if no match + """ + + # Given + given_metrics = [ + Metric.objects.create( + name="COVID-19_deaths_ONSByDay", + metric_group=MetricGroup.objects.create(name="DUMMY"), + ), + Metric.objects.create( + name="COVID-19_deaths_ONSByWeek", + metric_group=MetricGroup.objects.create(name="DUMMY"), + ), + ] + + # When + metric_id = Metric.objects.get_id_by_name(lookup_name) + + # Then + expected_id = ( + given_metrics[expected_index].id if expected_index is not None else None + ) + assert metric_id == expected_id diff --git a/tests/integration/metrics/data/managers/core_models/test_time_series.py b/tests/integration/metrics/data/managers/core_models/test_time_series.py index 1d3f4b4bc2..57837ff89d 100644 --- a/tests/integration/metrics/data/managers/core_models/test_time_series.py +++ b/tests/integration/metrics/data/managers/core_models/test_time_series.py @@ -7,8 +7,6 @@ from metrics.data.managers.core_models.time_series import CoreTimeSeriesQuerySet from metrics.data.models.core_models import CoreTimeSeries -from metrics.domain.models import get_date_n_months_ago_from_timestamp -from tests.factories.metrics.rbac_models.rbac_permission import RBACPermissionFactory from tests.factories.metrics.time_series import CoreTimeSeriesFactory FAKE_DATES = ("2023-01-01", "2023-01-02", "2023-01-03") @@ -205,14 +203,14 @@ def test_query_for_data_returns_full_records_when_axes_not_provided( ) @pytest.mark.django_db - def test_query_for_data_excludes_non_public_records_when_restrict_to_public_is_true( + def test_query_for_data_excludes_non_public_records_without_permission_sets( self, ): """ Given public and non-public `CoreTimeSeries` records When `query_for_data()` is called from an instance of the `CoreTimeSeriesQueryset` - with `restrict_to_public` given as True + with no permission sets provided Then only the public record is returned """ # Given @@ -229,7 +227,6 @@ def test_query_for_data_excludes_non_public_records_when_restrict_to_public_is_t metric=public_record.metric.name, date_from="2020-01-01", date_to="2025-12-31", - restrict_to_public=True, ) # Then @@ -237,14 +234,14 @@ def test_query_for_data_excludes_non_public_records_when_restrict_to_public_is_t assert non_public_record not in retrieved_records @pytest.mark.django_db - def test_query_for_data_includes_non_public_records_when_restrict_to_public_is_false( + def test_query_for_data_includes_non_public_records_with_global_permission_sets( self, ): """ Given public and non-public `CoreTimeSeries` records When `query_for_data()` is called from an instance of the `CoreTimeSeriesQueryset` - with `restrict_to_public` given as False + with global permission sets provided Then the non-public record is also returned """ # Given @@ -255,19 +252,96 @@ def test_query_for_data_includes_non_public_records_when_restrict_to_public_is_f metric_value=2, date="2023-01-02", is_public=False ) + permission_sets = { + "permission_sets": [], + "summary": {"has_global_access": True}, + } + # When retrieved_records = CoreTimeSeries.objects.get_queryset().query_for_data( + theme=public_record.metric.topic.sub_theme.theme.name, + sub_theme=public_record.metric.topic.sub_theme.name, topic=public_record.metric.topic.name, metric=public_record.metric.name, date_from="2020-01-01", date_to="2025-12-31", - restrict_to_public=False, + permission_sets=permission_sets, ) # Then assert public_record in retrieved_records assert non_public_record in retrieved_records + @pytest.mark.django_db + @pytest.mark.parametrize("does_permission_match", [True, False]) + def test_query_for_data_method_handles_specific_permission_sets( + self, + does_permission_match: bool, + ): + """ + Given a public and a non-public CoreTimeSeries record + When query_for_data() is called with a specific permission set + Then the non-public record is only returned when the permission row matches + """ + + # Given + public_record = CoreTimeSeriesFactory.create_record( + metric_value=1, + date="2020-01-01", + is_public=True, + ) + non_public_record = CoreTimeSeriesFactory.create_record( + metric_value=2, + date="2020-01-01", + is_public=False, + ) + permission_sets = { + "permission_sets": [ + { + "theme": { + "id": str(non_public_record.metric.topic.sub_theme.theme.id) + }, + "sub_theme": { + "id": str(non_public_record.metric.topic.sub_theme.id) + }, + "topic": {"id": str(non_public_record.metric.topic.id)}, + "metric": { + "id": str( + # Tweak id to be wrong for the negative test + non_public_record.metric.id + if does_permission_match + else 999999 + ) + }, + "geography_type": { + "id": str(non_public_record.geography.geography_type.id) + }, + "geography": { + "id": str(non_public_record.geography.geography_code) + }, + } + ], + "summary": {"has_global_access": False}, + } + + # When + retrieved_records = CoreTimeSeries.objects.get_queryset().query_for_data( + theme=public_record.metric.topic.sub_theme.theme.name, + sub_theme=public_record.metric.topic.sub_theme.name, + topic=public_record.metric.topic.name, + metric=public_record.metric.name, + geography=public_record.geography.name, + geography_type=public_record.geography.geography_type.name, + date_from="2010-01-01", + date_to="2030-01-01", + fields_to_export=[], + permission_sets=permission_sets, + ) + + # Then + assert public_record in retrieved_records + assert (non_public_record in retrieved_records) is does_permission_match + class TestCoreTimeSeriesManager: @pytest.mark.django_db @@ -496,10 +570,10 @@ def test_get_available_geographies(self): "metrics.api.permissions.fluent_permissions.auth.ENFORCE_PUBLIC_DATA_ONLY", False, ) - def test_query_for_data_returns_non_public_record_with_acceptable_permissions(self): + def test_query_for_data_returns_non_public_record_with_global_permissions(self): """ Given public and non-public `CoreTimeSeries` records - And an `RBACPermission` which gives access to the non-public portion of the data + And global JWT permission sets And `ENFORCE_PUBLIC_DATA_ONLY` is disabled When `query_for_data()` is called from the `CoreTimeSeriesManager` Then the non-public record is included @@ -520,7 +594,10 @@ def test_query_for_data_returns_non_public_record_with_acceptable_permissions(se "geography": public_record.geography.name, "geography_type": public_record.geography.geography_type.name, } - rbac_permission = RBACPermissionFactory.create_record(**params) + permission_sets = { + "permission_sets": [], + "summary": {"has_global_access": True}, + } # When core_time_series_queryset = CoreTimeSeries.objects.query_for_data( @@ -528,7 +605,7 @@ def test_query_for_data_returns_non_public_record_with_acceptable_permissions(se date_from="2020-01-01", date_to="2025-12-31", fields_to_export=[], - rbac_permissions=[rbac_permission], + permission_sets=permission_sets, ) # Then @@ -536,10 +613,10 @@ def test_query_for_data_returns_non_public_record_with_acceptable_permissions(se assert non_public_record in core_time_series_queryset @pytest.mark.django_db - def test_query_for_data_excludes_non_public_record_without_permissions(self): + def test_query_for_data_excludes_non_public_record_without_permission_sets(self): """ Given public and non-public `CoreTimeSeries` records - And no `RBACPermission` which allows access to the non-public portion of this dataset + And no permission sets are provided When `query_for_data()` is called from the `CoreTimeSeriesManager` Then the non-public record is excluded """ @@ -550,15 +627,6 @@ def test_query_for_data_excludes_non_public_record_without_permissions(self): non_public_record = CoreTimeSeriesFactory.create_record( date="2023-01-02", metric_value=2, is_public=False ) - rbac_permission = RBACPermissionFactory.create_record( - theme="some_other_theme", - sub_theme=None, - topic=None, - metric=None, - geography=None, - geography_type=None, - ) - # When core_time_series_queryset = CoreTimeSeries.objects.query_for_data( theme=public_record.metric.topic.sub_theme.theme.name, @@ -568,7 +636,6 @@ def test_query_for_data_excludes_non_public_record_without_permissions(self): geography=public_record.geography.name, geography_type=public_record.geography.geography_type.name, fields_to_export=[], - rbac_permissions=[rbac_permission], date_from="2020-01-01", date_to="2025-12-31", ) diff --git a/tests/integration/metrics/data/managers/core_models/test_topic.py b/tests/integration/metrics/data/managers/core_models/test_topic.py index 893617c014..da01286b5a 100644 --- a/tests/integration/metrics/data/managers/core_models/test_topic.py +++ b/tests/integration/metrics/data/managers/core_models/test_topic.py @@ -1,6 +1,6 @@ import pytest -from metrics.data.models.core_models.supporting import Topic +from metrics.data.models.core_models.supporting import SubTheme, Theme, Topic from tests.factories.metrics.topic import TopicFactory @@ -72,3 +72,48 @@ def test_query_get_name_by_id(self): # Then assert get_name_by_id == fake_topic_name_three + + @pytest.mark.django_db + @pytest.mark.parametrize( + "theme_name, sub_theme_name, topic_name, expected_index", + [ + ("Infectious Diseases", "Respiratory", "COVID-19", 0), + ("NON-EXISTENT", "Respiratory", "COVID-19", None), + ("Infectious Diseases", "NON-EXISTENT", "COVID-19", None), + ("Infectious Diseases", "Respiratory", "NON-EXISTENT", None), + ], + ) + def test_get_id_by_name( + self, + theme_name: str, + sub_theme_name: str, + topic_name: str, + expected_index: int | None, + ): + """ + Given some theme, sub-theme and topic records + When get_id_by_name() is called + Then the matching 3 ids are returned, or 3 None values if no match + """ + + # Given + given_theme = Theme.objects.create(name="Infectious Diseases") + given_sub_theme = SubTheme.objects.create(name="Respiratory", theme=given_theme) + given_topics = [ + Topic.objects.create(name="COVID-19", sub_theme=given_sub_theme) + ] + + # When + ids = Topic.objects.get_id_by_name(theme_name, sub_theme_name, topic_name) + + # Then + expected_ids = ( + ( + given_theme.id, + given_sub_theme.id, + given_topics[expected_index].id, + ) + if expected_index is not None + else (None, None, None) + ) + assert ids == expected_ids diff --git a/tests/unit/cms/dashboard/test_viewsets.py b/tests/unit/cms/dashboard/test_viewsets.py index 165ccb266c..30bcdd8911 100644 --- a/tests/unit/cms/dashboard/test_viewsets.py +++ b/tests/unit/cms/dashboard/test_viewsets.py @@ -1,157 +1,10 @@ -import pytest - from cms.dashboard.serializers import CMSDraftPagesSerializer, ListablePageSerializer from cms.dashboard.viewsets import ( CMSDraftPagesViewSet, CMSPagesAPIViewSet, - check_permissions, ) -class TestCheckPermissions: - @pytest.mark.parametrize( - "user_permissions, theme_id, sub_theme_id, topic_id", - [ - ([{"theme": {"id": "-1"}}], "10", "20", "30"), - ([{"theme": {"id": "10"}, "sub_theme": {"id": "-1"}}], "10", "20", "30"), - ( - [ - { - "theme": {"id": "10"}, - "sub_theme": {"id": "20"}, - "topic": {"id": "-1"}, - } - ], - "10", - "20", - "30", - ), - ( - [ - { - "theme": {"id": "10"}, - "sub_theme": {"id": "20"}, - "topic": {"id": "30"}, - } - ], - "10", - "20", - "30", - ), - ( - [ - {"theme": {"id": "5"}, "sub_theme": {"id": "-1"}}, - { - "theme": {"id": "10"}, - "sub_theme": {"id": "20"}, - "topic": {"id": "30"}, - }, - ], - "10", - "20", - "30", - ), - ], - ) - def test_check_permissions_valid_access( - self, user_permissions, theme_id, sub_theme_id, topic_id - ): - """ - Given a permission set that does grant access to the provided ids - When the `check_permissions` function is called - Then the function returns true - """ - assert ( - check_permissions(user_permissions, theme_id, sub_theme_id, topic_id) - == True - ) - - @pytest.mark.parametrize( - "user_permissions, theme_id, sub_theme_id, topic_id", - [ - ([{"theme": {"id": "99"}, "sub_theme": {"id": "-1"}}], "10", "20", "30"), - ( - [ - { - "theme": {"id": "10"}, - "sub_theme": {"id": "99"}, - "topic": {"id": "-1"}, - } - ], - "10", - "20", - "30", - ), - ( - [ - { - "theme": {"id": "10"}, - "sub_theme": {"id": "20"}, - "topic": {"id": "99"}, - } - ], - "10", - "20", - "30", - ), - ([], "10", "20", "30"), - ], - ) - def test_check_permissions_invalid_access( - self, user_permissions, theme_id, sub_theme_id, topic_id - ): - """ - Given a permission set that does not grant access to the provided ids - When the `check_permissions` function is called - Then the function returns false - """ - assert ( - check_permissions(user_permissions, theme_id, sub_theme_id, topic_id) - == False - ) - - @pytest.mark.parametrize( - "user_permissions, theme_id, sub_theme_id, topic_id", - [ - ([{}], "10", "20", "30"), - (None, "10", "20", "30"), - ([{"sub_theme": {"id": "-1"}, "topic": {"id": "-1"}}], "10", "20", "30"), - ( - [{"theme": {}, "sub_theme": {"id": "-1"}, "topic": {"id": "-1"}}], - "10", - "20", - "30", - ), - ([{"theme": {"id": "10"}, "topic": {"id": "-1"}}], "10", "20", "30"), - ( - [{"theme": {"id": "10"}, "sub_theme": {}, "topic": {"id": "-1"}}], - "10", - "20", - "30", - ), - ([{"theme": {"id": "10"}, "sub_theme": {"id": "20"}}], "10", "20", "30"), - ( - [{"theme": {"id": "10"}, "sub_theme": {"id": "20"}, "topic": {}}], - "10", - "20", - "30", - ), - ], - ) - def test_check_permissions_with_missing_values( - self, user_permissions, theme_id, sub_theme_id, topic_id - ): - """ - Given a permission set that is missing values - When the `check_permissions` function is called - Then the function returns false - """ - assert ( - check_permissions(user_permissions, theme_id, sub_theme_id, topic_id) - == False - ) - - class TestCMSDraftPagesViewSet: def test_base_serializer_class_is_set_with_correct_serializer(self): """ diff --git a/tests/unit/common/auth/test_logging.py b/tests/unit/common/auth/test_logging.py new file mode 100644 index 0000000000..48b7865607 --- /dev/null +++ b/tests/unit/common/auth/test_logging.py @@ -0,0 +1,65 @@ +from unittest import mock + +import pytest + +from common.auth.logging import log_user_permission_summary + + +@pytest.mark.parametrize( + "permission_sets, expected_has_global_access", + [ + ( + { + "permission_sets": [], + "summary": {"total_permission_sets": 2, "has_global_access": True}, + }, + True, + ), + ( + { + "permission_sets": [], + "summary": {"total_permission_sets": 1, "has_global_access": False}, + }, + False, + ), + ( + { + "permission_sets": [], + "summary": {"total_permission_sets": 0, "has_global_access": False}, + }, + False, + ), + ], +) +@mock.patch("common.auth.logging.logger.info") +def test_log_user_permission_summary( + mocked_logger_info: mock.MagicMock, + permission_sets, + expected_has_global_access, +): + """ + Given different user permission-set payloads + When log_user_permission_summary() is called + Then the expected log calls are made + """ + + # Given + user = mock.Mock(username="user-1") + + if permission_sets is not None: + user.permission_sets = permission_sets + + # When + log_user_permission_summary(user) + + # Then + log_messages = [call.args[0] for call in mocked_logger_info.call_args_list] + expected_count = permission_sets["summary"]["total_permission_sets"] + + assert any( + f"total permission sets {expected_count}" in message for message in log_messages + ) + assert ( + any("global access" in message for message in log_messages) + is expected_has_global_access + ) diff --git a/tests/unit/common/auth/test_permissions.py b/tests/unit/common/auth/test_permissions.py new file mode 100644 index 0000000000..cf077e2b07 --- /dev/null +++ b/tests/unit/common/auth/test_permissions.py @@ -0,0 +1,1127 @@ +from contextlib import ExitStack +from unittest.mock import patch + +import pytest + +from common.auth.permissions import ( + check_chart_permissions, + check_chart_permissions_by_name, + check_page_permissions, + PermissionSetsType, + PermissionRowType, +) + + +class TestCheckPermissionsByName: + THEME = "Infectious disease" + SUB_THEME = "Respiratory" + TOPIC = "COVID-19" + METRIC = "COVID-19_metric" + GEOGRAPHY_TYPE = "Nation" + GEOGRAPHY = "England" + + THEME_ID = 1 + SUB_THEME_ID = 2 + TOPIC_ID = 3 + METRIC_ID = 4 + GEOGRAPHY_TYPE_ID = 5 + GEOGRAPHY_ID = 6 + + def _permissions_by_id(self) -> dict: + return { + "theme": {"id": str(self.THEME_ID)}, + "sub_theme": {"id": str(self.SUB_THEME_ID)}, + "topic": {"id": str(self.TOPIC_ID)}, + "metric": {"id": str(self.METRIC_ID)}, + "geography_type": {"id": str(self.GEOGRAPHY_TYPE_ID)}, + "geography": {"id": str(self.GEOGRAPHY_ID)}, + } + + def _build_permission_sets( + self, permission_rows: list[PermissionRowType], has_global_access: bool = False + ) -> dict: + return { + "permission_sets": permission_rows, + "summary": {"has_global_access": has_global_access}, + } + + def _check_permissions_by_name(self, permission_sets: PermissionSetsType) -> bool: + return check_chart_permissions_by_name( + permission_sets=permission_sets, + theme_name=self.THEME, + sub_theme_name=self.SUB_THEME, + topic_name=self.TOPIC, + metric_name=self.METRIC, + geography_type=self.GEOGRAPHY_TYPE, + geography_name=self.GEOGRAPHY, + ) + + def _patch_lookups( + self, + topic_result=None, + metric_result=None, + geography_type_result=None, + geography_result=None, + ): + """Return patches for all four DB manager methods.""" + + topic_result = topic_result or (self.THEME_ID, self.SUB_THEME_ID, self.TOPIC_ID) + metric_result = metric_result or self.METRIC_ID + geography_type_result = geography_type_result or self.GEOGRAPHY_TYPE_ID + geography_result = geography_result or self.GEOGRAPHY_ID + + stack = ExitStack() + stack.enter_context( + patch( + "metrics.data.managers.core_models.topic.TopicQuerySet.get_id_by_name", + return_value=topic_result, + ) + ) + stack.enter_context( + patch( + "metrics.data.managers.core_models.metric.MetricQuerySet.get_id_by_name", + return_value=metric_result, + ) + ) + stack.enter_context( + patch( + "metrics.data.managers.core_models.geography_type.GeographyTypeQuerySet.get_id_by_name", + return_value=geography_type_result, + ) + ) + stack.enter_context( + patch( + "metrics.data.managers.core_models.geography.GeographyQuerySet.get_code_by_name", + return_value=geography_result, + ) + ) + + return stack + + def test_returns_false_when_topic_lookup_fails(self): + with ( + patch( + "metrics.data.managers.core_models.topic.TopicQuerySet.get_id_by_name", + return_value=(None, None, None), + ), + patch( + "metrics.data.managers.core_models.metric.MetricQuerySet.get_id_by_name", + return_value=self.METRIC_ID, + ), + patch( + "metrics.data.managers.core_models.geography_type.GeographyTypeQuerySet.get_id_by_name", + return_value=self.GEOGRAPHY_TYPE_ID, + ), + patch( + "metrics.data.managers.core_models.geography.GeographyQuerySet.get_code_by_name", + return_value=self.GEOGRAPHY_ID, + ), + ): + assert not self._check_permissions_by_name( + self._build_permission_sets([self._permissions_by_id()]) + ) + + def test_returns_false_when_metric_lookup_fails(self): + with ( + patch( + "metrics.data.managers.core_models.topic.TopicQuerySet.get_id_by_name", + return_value=(self.THEME_ID, self.SUB_THEME_ID, self.TOPIC_ID), + ), + patch( + "metrics.data.managers.core_models.metric.MetricQuerySet.get_id_by_name", + return_value=None, + ), + patch( + "metrics.data.managers.core_models.geography_type.GeographyTypeQuerySet.get_id_by_name", + return_value=self.GEOGRAPHY_TYPE_ID, + ), + patch( + "metrics.data.managers.core_models.geography.GeographyQuerySet.get_code_by_name", + return_value=self.GEOGRAPHY_ID, + ), + ): + assert not self._check_permissions_by_name( + self._build_permission_sets([self._permissions_by_id()]) + ) + + def test_returns_false_when_geography_type_lookup_fails(self): + with ( + patch( + "metrics.data.managers.core_models.topic.TopicQuerySet.get_id_by_name", + return_value=(self.THEME_ID, self.SUB_THEME_ID, self.TOPIC_ID), + ), + patch( + "metrics.data.managers.core_models.metric.MetricQuerySet.get_id_by_name", + return_value=self.METRIC_ID, + ), + patch( + "metrics.data.managers.core_models.geography_type.GeographyTypeQuerySet.get_id_by_name", + return_value=None, + ), + patch( + "metrics.data.managers.core_models.geography.GeographyQuerySet.get_code_by_name", + return_value=self.GEOGRAPHY_ID, + ), + ): + assert not self._check_permissions_by_name( + self._build_permission_sets([self._permissions_by_id()]) + ) + + def test_returns_false_when_geography_lookup_fails(self): + with ( + patch( + "metrics.data.managers.core_models.topic.TopicQuerySet.get_id_by_name", + return_value=(self.THEME_ID, self.SUB_THEME_ID, self.TOPIC_ID), + ), + patch( + "metrics.data.managers.core_models.metric.MetricQuerySet.get_id_by_name", + return_value=self.METRIC_ID, + ), + patch( + "metrics.data.managers.core_models.geography_type.GeographyTypeQuerySet.get_id_by_name", + return_value=self.GEOGRAPHY_TYPE_ID, + ), + patch( + "metrics.data.managers.core_models.geography.GeographyQuerySet.get_code_by_name", + return_value=None, + ), + ): + assert not self._check_permissions_by_name( + self._build_permission_sets([self._permissions_by_id()]) + ) + + def test_returns_true_when_global_access_is_true(self): + with self._patch_lookups(): + assert self._check_permissions_by_name( + self._build_permission_sets([], has_global_access=True) + ) + + @pytest.mark.parametrize( + "permission_sets", + [ + pytest.param( + None, + id="test_none_permission_sets_denies_access", + ), + ], + ) + def test_returns_false_when_permission_sets_is_not_a_dict(self, permission_sets): + assert not check_chart_permissions_by_name( + permission_sets=permission_sets, + theme_name=self.THEME, + sub_theme_name=self.SUB_THEME, + topic_name=self.TOPIC, + metric_name=self.METRIC, + geography_type=self.GEOGRAPHY_TYPE, + geography_name=self.GEOGRAPHY, + ) + + def test_returns_false_when_permission_sets_key_is_not_a_list(self): + assert not check_chart_permissions_by_name( + permission_sets={"permission_sets": "not_a_list", "summary": {}}, + theme_name=self.THEME, + sub_theme_name=self.SUB_THEME, + topic_name=self.TOPIC, + metric_name=self.METRIC, + geography_type=self.GEOGRAPHY_TYPE, + geography_name=self.GEOGRAPHY, + ) + + def test_returns_false_when_summary_is_not_a_dict(self): + assert not check_chart_permissions_by_name( + permission_sets={"permission_sets": [], "summary": "not_a_dict"}, + theme_name=self.THEME, + sub_theme_name=self.SUB_THEME, + topic_name=self.TOPIC, + metric_name=self.METRIC, + geography_type=self.GEOGRAPHY_TYPE, + geography_name=self.GEOGRAPHY, + ) + + def test_returns_false_when_has_global_access_is_not_a_bool(self): + assert not check_chart_permissions_by_name( + permission_sets={ + "permission_sets": [], + "summary": {"has_global_access": "yes"}, + }, + theme_name=self.THEME, + sub_theme_name=self.SUB_THEME, + topic_name=self.TOPIC, + metric_name=self.METRIC, + geography_type=self.GEOGRAPHY_TYPE, + geography_name=self.GEOGRAPHY, + ) + + def test_returns_true_when_lookups_succeed_and_permission_matches(self): + permission_sets = self._build_permission_sets([self._permissions_by_id()]) + with self._patch_lookups(): + assert self._check_permissions_by_name(permission_sets) + + +class TestCheckPermissions: + @pytest.mark.parametrize( + ( + "permission_sets, theme_id, sub_theme_id, topic_id, " + "metric_id, geography_type, geography_id" + ), + [ + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography_type": {"id": "50"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_all_6_permission_resources_matching_grants_access"), + ), + pytest.param( + [ + { + "theme": {"id": "-1"}, + "sub_theme": {"id": "-1"}, + "topic": {"id": "-1"}, + "metric": {"id": "-1"}, + "geography_type": {"id": "-1"}, + "geography": {"id": "-1"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_theme_wildcard_with_wildcards_following_grants_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "-1"}, + "topic": {"id": "-1"}, + "metric": {"id": "-1"}, + "geography_type": {"id": "-1"}, + "geography": {"id": "-1"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_sub_theme_wildcard_with_wildcards_following_grants_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "-1"}, + "metric": {"id": "-1"}, + "geography_type": {"id": "-1"}, + "geography": {"id": "-1"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_topic_wildcard_with_wildcards_following_grants_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "-1"}, + "geography_type": {"id": "-1"}, + "geography": {"id": "-1"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_metric_wildcard_with_wildcards_following_grants_access"), + ), + pytest.param( + [ + { + "theme": {"id": "-1"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography_type": {"id": "50"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=( + "test_theme_wildcard_with_geography_type_and_geography_match_grants_access" + ), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "-1"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography_type": {"id": "50"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=( + "test_sub_theme_wildcard_with_geography_type_and_geography_match_grants_access" + ), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "-1"}, + "metric": {"id": "40"}, + "geography_type": {"id": "50"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=( + "test_topic_wildcard_with_geography_type_and_geography_match_grants_access" + ), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "-1"}, + "geography_type": {"id": "50"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=( + "test_metric_wildcard_with_geography_type_and_geography_match_grants_access" + ), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography_type": {"id": "-1"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=( + "test_first_4_matching_permissions_with_geography_type_wildcard_and_geography_match_grants_access" + ), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography_type": {"id": "50"}, + "geography": {"id": "-1"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=( + "test_first_4_matching_permissions_with_geography_type_match_and_geography_wildcard_grants_access" + ), + ), + pytest.param( + [ + { + "theme": {"id": "-1"}, + "sub_theme": {"id": "-1"}, + "topic": {"id": "-1"}, + "metric": {"id": "-1"}, + "geography_type": {"id": "-1"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=( + "test_first_4_permissions_wildcards_with_geography_type_wildcard_and_geography_match_grants_access" + ), + ), + pytest.param( + [ + { + "theme": {"id": "-1"}, + "sub_theme": {"id": "-1"}, + "topic": {"id": "-1"}, + "metric": {"id": "-1"}, + "geography_type": {"id": "50"}, + "geography": {"id": "-1"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=( + "test_first_4_permissions_wildcards_with_geography_type_match_and_geography_wildcard_grants_access" + ), + ), + pytest.param( + [ + { + "theme": {"id": "5"}, + "sub_theme": {"id": "-1"}, + "topic": {"id": "-1"}, + "metric": {"id": "-1"}, + "geography_type": {"id": "2"}, + "geography": {"id": "-1"}, + }, + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography_type": {"id": "50"}, + "geography": {"id": "60"}, + }, + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=( + "test_matching_permission_row_grants_access_even_if_a_prior_row_does_not_match" + ), + ), + ], + ) + def test_check_chart_permissions_valid_access( + self, + permission_sets, + theme_id, + sub_theme_id, + topic_id, + metric_id, + geography_type, + geography_id, + ): + """Test chart permissions succeed when match or wildcard.""" + + assert check_chart_permissions( + permission_sets=permission_sets, + theme_id=theme_id, + sub_theme_id=sub_theme_id, + topic_id=topic_id, + metric_id=metric_id, + geography_type=geography_type, + geography_id=geography_id, + ) + + @pytest.mark.parametrize( + ( + "permission_sets, theme_id, sub_theme_id, topic_id, " + "metric_id, geography_type, geography_id" + ), + [ + pytest.param( + [ + { + "theme": {"id": "99"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography_type": {"id": "50"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_theme_mismatch_denies_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "99"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography_type": {"id": "50"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_sub_theme_mismatch_denies_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "99"}, + "metric": {"id": "40"}, + "geography_type": {"id": "50"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_topic_mismatch_denies_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "99"}, + "geography_type": {"id": "50"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_metric_mismatch_denies_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography_type": {"id": "6"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_geography_type_mismatch_denies_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography_type": {"id": "50"}, + "geography": {"id": "99"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_geography_mismatch_denies_access"), + ), + pytest.param( + [], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_empty_permission_set_list_denies_access"), + ), + ], + ) + def test_check_chart_permissions_invalid_access( + self, + permission_sets, + theme_id, + sub_theme_id, + topic_id, + metric_id, + geography_type, + geography_id, + ): + """Test chart permissions fail when mismatch and no wildcard.""" + + assert not check_chart_permissions( + permission_sets=permission_sets, + theme_id=theme_id, + sub_theme_id=sub_theme_id, + topic_id=topic_id, + metric_id=metric_id, + geography_type=geography_type, + geography_id=geography_id, + ) + + @pytest.mark.parametrize( + ( + "permission_sets, theme_id, sub_theme_id, topic_id, " + "metric_id, geography_type, geography_id" + ), + [ + pytest.param( + [{}], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_empty_permission_row_denies_access"), + ), + pytest.param( + None, + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_non_list_permission_sets_denies_access"), + ), + pytest.param( + [{"sub_theme": {"id": "-1"}, "topic": {"id": "-1"}}], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_missing_theme_denies_access"), + ), + pytest.param( + [{"theme": {}, "sub_theme": {"id": "-1"}, "topic": {"id": "-1"}}], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_empty_theme_id_in_chart_permissions_denies_access"), + ), + pytest.param( + [{"theme": {"id": "10"}, "topic": {"id": "-1"}}], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_missing_sub_theme_denies_access"), + ), + pytest.param( + [{"theme": {"id": "10"}, "sub_theme": {}, "topic": {"id": "-1"}}], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_empty_sub_theme_denies_access"), + ), + pytest.param( + [{"theme": {"id": "10"}, "sub_theme": {"id": "20"}}], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_missing_topic_denies_access"), + ), + pytest.param( + [{"theme": {"id": "10"}, "sub_theme": {"id": "20"}, "topic": {}}], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_empty_topic_id_denies_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {}, + "geography_type": {"id": "50"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_empty_metric_denies_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "geography_type": {"id": "50"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_missing_metric_denies_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_missing_geography_type_denies_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography_type": {}, + "geography": {"id": "60"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_empty_geography_type_denies_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography_type": {"id": "50"}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_missing_geography_denies_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography": {"id": "50"}, + "geography_type": {}, + } + ], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_empty_geography_denies_access"), + ), + pytest.param( + [123], + "10", + "20", + "30", + "40", + "50", + "60", + id=("test_non_dict_item_in_permission_sets_list_denies_access"), + ), + pytest.param( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + "metric": {"id": "40"}, + "geography_type": {"id": "50"}, + "geography": {"id": "60"}, + } + ], + None, + "20", + "30", + "40", + "50", + "60", + id=("test_none_theme_resource_id_denies_access"), + ), + ], + ) + def test_check_chart_permissions_with_missing_values( + self, + permission_sets, + theme_id, + sub_theme_id, + topic_id, + metric_id, + geography_type, + geography_id, + ): + """Test chart permissions fail when being passed missing values""" + + assert not check_chart_permissions( + permission_sets=permission_sets, + theme_id=theme_id, + sub_theme_id=sub_theme_id, + topic_id=topic_id, + metric_id=metric_id, + geography_type=geography_type, + geography_id=geography_id, + ) + + +class TestCheckPagePermissions: + @pytest.mark.parametrize( + "user_permissions, theme_id, sub_theme_id, topic_id", + [ + ([{"theme": {"id": "-1"}}], "10", "20", "30"), + ([{"theme": {"id": "10"}, "sub_theme": {"id": "-1"}}], "10", "20", "30"), + ( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "-1"}, + } + ], + "10", + "20", + "30", + ), + ( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + } + ], + "10", + "20", + "30", + ), + ( + [ + {"theme": {"id": "5"}, "sub_theme": {"id": "-1"}}, + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + }, + ], + "10", + "20", + "30", + ), + ], + ) + def test_check_page_permissions_valid_access( + self, user_permissions, theme_id, sub_theme_id, topic_id + ): + """Test page permissions succeed when match or wildcard.""" + + assert check_page_permissions( + permission_sets=user_permissions, + theme_id=theme_id, + sub_theme_id=sub_theme_id, + topic_id=topic_id, + ) + + @pytest.mark.parametrize( + "user_permissions, theme_id, sub_theme_id, topic_id", + [ + ([{"theme": {"id": "99"}, "sub_theme": {"id": "-1"}}], "10", "20", "30"), + ( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "99"}, + "topic": {"id": "-1"}, + } + ], + "10", + "20", + "30", + ), + ( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "99"}, + } + ], + "10", + "20", + "30", + ), + ([], "10", "20", "30"), + ], + ) + def test_check_page_permissions_invalid_access( + self, user_permissions, theme_id, sub_theme_id, topic_id + ): + """Test page permissions fail when mismatch and no wildcard.""" + + assert not check_page_permissions( + permission_sets=user_permissions, + theme_id=theme_id, + sub_theme_id=sub_theme_id, + topic_id=topic_id, + ) + + @pytest.mark.parametrize( + "user_permissions, theme_id, sub_theme_id, topic_id", + [ + ([{}], "10", "20", "30"), + (None, "10", "20", "30"), + ([{"sub_theme": {"id": "-1"}, "topic": {"id": "-1"}}], "10", "20", "30"), + ( + [{"theme": {}, "sub_theme": {"id": "-1"}, "topic": {"id": "-1"}}], + "10", + "20", + "30", + ), + ([{"theme": {"id": "10"}, "topic": {"id": "-1"}}], "10", "20", "30"), + ( + [{"theme": {"id": "10"}, "sub_theme": {}, "topic": {"id": "-1"}}], + "10", + "20", + "30", + ), + ([{"theme": {"id": "10"}, "sub_theme": {"id": "20"}}], "10", "20", "30"), + ( + [{"theme": {"id": "10"}, "sub_theme": {"id": "20"}, "topic": {}}], + "10", + "20", + "30", + ), + ], + ) + def test_check_page_permissions_with_missing_values( + self, user_permissions, theme_id, sub_theme_id, topic_id + ): + """Test page permissions fail when being passed missing values""" + + assert not check_page_permissions( + permission_sets=user_permissions, + theme_id=theme_id, + sub_theme_id=sub_theme_id, + topic_id=topic_id, + ) diff --git a/tests/unit/metrics/api/serializers/test_geographies.py b/tests/unit/metrics/api/serializers/test_geographies.py index 108a7567e2..1a1a52e997 100644 --- a/tests/unit/metrics/api/serializers/test_geographies.py +++ b/tests/unit/metrics/api/serializers/test_geographies.py @@ -4,7 +4,7 @@ from rest_framework.exceptions import ValidationError -from cms.auth_content.constants import WILDCARD_ID_VALUE +from common.auth.permissions import WILDCARD_ID_VALUE from metrics.data.models.core_models.supporting import Geography from validation.geography_code import UNITED_KINGDOM_GEOGRAPHY_CODE from metrics.api.serializers.geographies import ( diff --git a/tests/unit/metrics/api/serializers/test_permission_sets.py b/tests/unit/metrics/api/serializers/test_permission_sets.py index 7d6e956270..4378c6162b 100644 --- a/tests/unit/metrics/api/serializers/test_permission_sets.py +++ b/tests/unit/metrics/api/serializers/test_permission_sets.py @@ -3,7 +3,7 @@ import pytest from rest_framework import serializers as drf_serializers -from cms.auth_content.constants import WILDCARD_ID_VALUE +from common.auth.permissions import WILDCARD_ID_VALUE from metrics.api.serializers.permission_sets import ( MetricRequestSerializer, PermissionSetResponseSerializer, diff --git a/tests/unit/metrics/domain/models/test_common.py b/tests/unit/metrics/domain/models/test_common.py new file mode 100644 index 0000000000..8a7f810e3c --- /dev/null +++ b/tests/unit/metrics/domain/models/test_common.py @@ -0,0 +1,102 @@ +import datetime +from types import SimpleNamespace + +import pytest +from django.http import HttpRequest +from rest_framework.request import Request + +from metrics.domain.models import ChartRequestParams, PlotParameters +from metrics.domain.models.charts.subplot_charts import Subplots +from metrics.domain.models.headline import HeadlineParameters +from metrics.domain.models.map import MapMainParameters, MapsParameters + +PERMISSION_SETS = { + "permission_sets": [{"theme": {"id": "1"}}], + "summary": {"has_global_access": False}, +} + +MODEL_FACTORIES = [ + pytest.param(lambda request: _build_models(request)[0], id="headline"), + pytest.param(lambda request: _build_models(request)[1], id="maps"), + pytest.param(lambda request: _build_models(request)[2], id="subplots"), + pytest.param(lambda request: _build_models(request)[3], id="chart_request_params"), +] + + +def _build_request(*, permission_sets=None) -> Request: + request = Request(HttpRequest()) + if permission_sets is not None: + request.user = SimpleNamespace(permission_sets=permission_sets) + return request + + +def _build_models(request: Request) -> tuple: + return ( + HeadlineParameters( + topic="COVID-19", + metric="COVID-19_metric", + geography="England", + geography_type="Nation", + stratum="default", + sex="all", + age="all", + request=request, + ), + MapsParameters( + date_from=datetime.date(2025, 1, 1), + date_to=datetime.date(2025, 12, 31), + parameters=MapMainParameters( + theme="infectious_disease", + sub_theme="respiratory", + topic="COVID-19", + metric="COVID-19_metric", + stratum="default", + age="all", + sex="all", + geography_type="Nation", + geographies=["England"], + ), + accompanying_points=[], + request=request, + ), + Subplots( + subplot_title="Test subplot", + x_axis="date", + y_axis="metric", + plots=[], + request=request, + ), + ChartRequestParams( + file_format="svg", + chart_width=930, + chart_height=220, + x_axis="date", + y_axis="metric", + plots=[ + PlotParameters( + chart_type="bar", + topic="COVID-19", + metric="COVID-19_metric", + ) + ], + request=request, + ), + ) + + +@pytest.mark.parametrize( + "model_factory", + MODEL_FACTORIES, +) +def test_permission_sets_from_request_user(model_factory): + request = _build_request(permission_sets=PERMISSION_SETS) + assert model_factory(request).permission_sets == PERMISSION_SETS + + +@pytest.mark.parametrize( + "model_factory", + MODEL_FACTORIES, +) +def test_permission_sets_default_to_empty_dict(model_factory): + request = _build_request(permission_sets=None) + assert model_factory(request).permission_sets == {} diff --git a/tests/unit/metrics/interfaces/plots/test_access.py b/tests/unit/metrics/interfaces/plots/test_access.py index 2f55016864..b794cb3810 100644 --- a/tests/unit/metrics/interfaces/plots/test_access.py +++ b/tests/unit/metrics/interfaces/plots/test_access.py @@ -546,6 +546,7 @@ def test_get_headline_data_calls_core_headline_manager_with_correct_args(self): sex=mocked_sex, age=mocked_age, rbac_permissions=mocked_chart_request_params.rbac_permissions, + permission_sets=mocked_chart_request_params.permission_sets, ) def test_get_headline_data_calls_core_headline_manager_with_confidence_intervals( @@ -613,6 +614,7 @@ def test_get_headline_data_calls_core_headline_manager_with_confidence_intervals sex=mocked_sex, age=mocked_age, rbac_permissions=mocked_chart_request_params.rbac_permissions, + permission_sets=mocked_chart_request_params.permission_sets, ) @mock.patch(f"{MODULE_PATH}.auth.AUTH_ENABLED", True) @@ -676,6 +678,7 @@ def test_get_queryset_from_core_model_manager_passes_theme_and_topic_into_query_ rbac_permissions=mocked_chart_request_params.rbac_permissions, theme=fake_metric.topic.sub_theme.theme.name, sub_theme=fake_metric.topic.sub_theme.name, + permission_sets=mocked_chart_request_params.permission_sets, ) def test_get_timeseries_calls_core_time_series_manager_with_correct_args(self): @@ -746,6 +749,7 @@ def test_get_timeseries_calls_core_time_series_manager_with_correct_args(self): sex=mocked_sex, age=mocked_age, rbac_permissions=mocked_chart_request_params.rbac_permissions, + permission_sets=mocked_chart_request_params.permission_sets, ) @mock.patch.object(PlotsInterface, "get_queryset_from_core_model_manager")