Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ you will need to set up the variables for validating the token via cognito:

- `export COGNITO_AWS_REGION=eu-west-2` - This is unlikely to change
- `export COGNITO_USER_POOL=eu-west-2_a123bc4DE` - Can be found be checking the `User pool ID` value for your environment on the [AWS console] (https://eu-west-2.console.aws.amazon.com/cognito/v2/idp/user-pools?region=eu-west-2)
- `export COGNITO_JWT_AUTH_HEADER=HTTP_X_UHD_AUTH` - This is unlikely to change
- `export JWT_AUTH_HEADER=HTTP_X_UHD_AUTH` - This is unlikely to change

---

Expand Down
37 changes: 0 additions & 37 deletions common/auth/cognito_jwt/user_manager.py

This file was deleted.

File renamed without changes.
81 changes: 62 additions & 19 deletions common/auth/cognito_jwt/backend.py → common/auth/jwt/backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

import jwt
from django.apps import apps as django_apps
from django.conf import settings
from django.utils.encoding import force_str
Expand All @@ -8,7 +9,7 @@
from rest_framework import HTTP_HEADER_ENCODING, exceptions
from rest_framework.authentication import BaseAuthentication

from .validator import TokenError, TokenValidator
from .validator import CognitoTokenValidator, EntraTokenValidator, TokenError

logger = logging.getLogger(__name__)

Expand All @@ -22,8 +23,9 @@ def get_authorization_header(request):

Hide some test client ickyness where the header can be unicode.
"""
auth_header = getattr(settings, "COGNITO_JWT_AUTH_HEADER", "Authorization")
auth_header = getattr(settings, "JWT_AUTH_HEADER", "HTTP_AUTHORIZATION")
auth = request.META.get(auth_header, b"")

if isinstance(auth, str):
# Work around django test client oddness
auth = auth.encode(HTTP_HEADER_ENCODING)
Expand All @@ -44,38 +46,58 @@ def authenticate(self, request):

# Authenticate token
try:
token_validator = self.get_token_validator(request)
token_validator, provider_name = self.get_token_validator(jwt_token)
except TokenError as e:
logger.debug("Failed to identify token provider: %s", e)
raise exceptions.AuthenticationFailed(
_("Unknown or malformed token issuer.")
) from e

try:
jwt_payload = token_validator.validate(jwt_token)
except TokenError:
except TokenError as e:
logger.debug(
"%s token validation failed: %s", provider_name.capitalize(), e
)
raise exceptions.AuthenticationFailed from None

custom_user_manager = self.get_custom_user_manager()
custom_user_manager = self.get_custom_user_manager(provider_name)

if custom_user_manager:
user = custom_user_manager.get_or_create_for_cognito(jwt_payload)
user = custom_user_manager.get_or_create(jwt_payload)
else:
user_model = self.get_user_model()
user = user_model.objects.get_or_create_for_cognito(jwt_payload)
user = user_model.objects.get_or_create(jwt_payload)
if not user:
logger.debug(
"Unable to create user from JWT, defaulting to unauthenticated"
)
return None

return (user, jwt_token)

@staticmethod
def get_custom_user_manager():
"""If COGNITO_USER_MANAGER is set, then the user object is obtained
via get_or_create_for_cognito on the user manager, this allows use
def get_custom_user_manager(provider="cognito"):
"""If COGNITO_USER_MANAGER or ENTRA_USER_MANAGER is set, then the user object is obtained
via get_or_create_for_cognito (or get_or_create_for_entra) on the user manager, this allows use
of the default unmodified Django User model"""
result = None
custom_user_manager_path = getattr(settings, "COGNITO_USER_MANAGER", False)
custom_user_manager_path = (
getattr(settings, "ENTRA_USER_MANAGER", False)
if provider == "entra"
else getattr(settings, "COGNITO_USER_MANAGER", False)
)
if custom_user_manager_path:
result = import_string(custom_user_manager_path)()
return result

@staticmethod
def get_user_model():
user_model = getattr(settings, "COGNITO_USER_MODEL", settings.AUTH_USER_MODEL)
def get_user_model(provider="cognito"):
user_model = (
getattr(settings, "ENTRA_USER_MODEL", settings.AUTH_USER_MODEL)
if provider == "entra"
else getattr(settings, "COGNITO_USER_MODEL", settings.AUTH_USER_MODEL)
)
return django_apps.get_model(user_model, require_ready=False)

@staticmethod
Expand All @@ -97,12 +119,33 @@ def get_jwt_token(request):
return auth[1]

@staticmethod
def get_token_validator(request):
return TokenValidator(
settings.COGNITO_AWS_REGION,
settings.COGNITO_USER_POOL,
settings.COGNITO_AUDIENCE,
)
def get_token_validator(jwt_token):
try:
# Decode without verifying signature just to read the header/payload
unverified_payload = jwt.decode(
jwt_token, options={"verify_signature": False} # noqa: S5659
)
issuer = unverified_payload.get("iss", "")
except jwt.PyJWTError as e:
raise exceptions.AuthenticationFailed(_("Malformed JWT.")) from e

if "cognito-idp" in issuer:
validator = CognitoTokenValidator(
settings.COGNITO_AWS_REGION,
settings.COGNITO_USER_POOL,
settings.COGNITO_AUDIENCE,
)
return validator, "cognito"

if "sts.windows.net" in issuer:
validator = EntraTokenValidator(
settings.ENTRA_TENANT_ID,
settings.ENTRA_AUDIENCE,
settings.ENTRA_ALLOWED_APP_IDS,
)
return validator, "entra"

raise exceptions.AuthenticationFailed(_("Invalid or unsupported token issuer."))

@staticmethod
def authenticate_header(request):
Expand Down
74 changes: 74 additions & 0 deletions common/auth/jwt/user_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import logging

from django.contrib.auth import get_user_model
from django.contrib.auth.models import BaseUserManager
from rest_framework import exceptions

from cms.auth_content.models.users import User
from metrics.data.managers.rbac_models.user import UserManager
from metrics.utils.permission_hierarchy import build_permission_hierarchy

logger = logging.getLogger(__name__)


def get_user_permission_set(user_id: str):
permissions = UserManager.get_permission_sets_for_user(user_id)
return build_permission_hierarchy(permissions)


class CognitoManager(BaseUserManager):

@staticmethod
def get_or_create(jwt_payload):
"""Create an ephemeral user instance for this request.
If the permissions aren't present in the JWT, queries for them in
the database based on the entraObjectId in the token
"""
try:
username = jwt_payload["entraObjectId"]
# Check if the JWT already includes permissionSets
# Use if found, if not grab user permissions from the database
if "permissionSets" in jwt_payload:
permission_sets = jwt_payload["permissionSets"]
else:
permission_sets = get_user_permission_set(username)
except KeyError:
logger.debug(
"Error getting entraObjectId and/or permissionSets field(s)"
" from jwt payload: '%s'",
jwt_payload,
)
return None

user_class = get_user_model()
user = user_class(username=username)
user.permission_sets = permission_sets
return user


class EntraManager(BaseUserManager):
Comment thread
itsthatianguy marked this conversation as resolved.

@staticmethod
def get_or_create(jwt_payload):
"""Create an ephemeral user instance for this request.
If the provided appid isn't present in the database, raises
AuthenticationFailed exception
"""
try:
username = jwt_payload["appid"]
if not User.objects.filter(user_id=username).exists():
msg = "Application not found."
raise exceptions.AuthenticationFailed(msg)
permission_sets = get_user_permission_set(username)
except KeyError:
logger.info(
"Error getting entraObjectId and/or permissionSets field(s)"
" from jwt payload: '%s'",
jwt_payload,
)
return None

user_class = get_user_model()
user = user_class(username=username)
user.permission_sets = permission_sets
return user
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class TokenError(Exception):
pass


class TokenValidator:
class CognitoTokenValidator:
def __init__(self, aws_region, aws_user_pool, audience):
self.aws_region = aws_region
self.aws_user_pool = aws_user_pool
Expand Down Expand Up @@ -86,3 +86,78 @@ def validate(self, token):
) as exc:
raise TokenError(str(exc)) from exc
return jwt_data


class EntraTokenValidator:
def __init__(self, tenant_id, audience, allowed_app_ids):
self.tenant_id = tenant_id
self.audience = audience
self.allowed_app_ids = allowed_app_ids
self.jwks_url = "https://login.microsoftonline.com/common/discovery/keys"

@cached_property
def expected_issuer(self):
return f"https://sts.windows.net/{self.tenant_id}/"

@cached_property
def _json_web_keys(self):
response = requests.get(self.jwks_url, timeout=10)
response.raise_for_status()
json_data = response.json()
return {item["kid"]: json.dumps(item) for item in json_data["keys"]}

def _get_public_key(self, token):
try:
headers = jwt.get_unverified_header(token)
except jwt.DecodeError as exc:
raise TokenError(str(exc)) from exc

if getattr(settings, "ENTRA_PUBLIC_KEYS_CACHING_ENABLED", False):
cache_key = "entra_jwt:{}".format(headers["kid"])
jwk_data = cache.get(cache_key)

if not jwk_data:
jwk_data = self._json_web_keys.get(headers["kid"])
timeout = getattr(settings, "ENTRA_PUBLIC_KEYS_CACHING_TIMEOUT", 300)
cache.set(cache_key, jwk_data, timeout=timeout)
else:
jwk_data = self._json_web_keys.get(headers["kid"])

if jwk_data:
return RSAAlgorithm.from_jwk(jwk_data)
return None

def validate(self, token):
public_key = self._get_public_key(token)
if not public_key:
msg = "No key found for this token"
raise TokenError(msg)

params = {
"jwt": token,
"key": public_key,
"issuer": self.expected_issuer,
"audience": self.audience,
"algorithms": ["RS256"],
}

try:
payload = jwt.decode(**params)
except (
jwt.InvalidTokenError,
jwt.ExpiredSignatureError,
jwt.DecodeError,
) as exc:
raise TokenError(str(exc)) from exc

roles = payload.get("roles", [])
if "application.read" not in roles:
msg = "Missing required role: application.read"
raise TokenError(msg)

app_id_claim = payload.get("appid") or payload.get("azp")
if app_id_claim not in self.allowed_app_ids:
msg = "Invalid app_id claim"
raise TokenError(msg)

return payload
9 changes: 8 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,18 @@
# The name of the AWS profile to use for the AWS client used for ingestion
AWS_PROFILE_NAME = os.environ.get("AWS_PROFILE_NAME")

JWT_AUTH_HEADER = os.environ.get("JWT_AUTH_HEADER", "HTTP_AUTHORIZATION")

# Cognito configuration
COGNITO_AWS_REGION = os.environ.get("COGNITO_AWS_REGION")
COGNITO_JWT_AUTH_HEADER = os.environ.get("COGNITO_JWT_AUTH_HEADER")
COGNITO_USER_POOL = os.environ.get("COGNITO_USER_POOL")

# Entra configuration
ENTRA_AUDIENCE = os.environ.get("ENTRA_AUDIENCE")
ENTRA_APP_ID = os.environ.get("ENTRA_APP_ID")
ENTRA_ALLOWED_APP_IDS = os.environ.get("ENTRA_ALLOWED_APP_IDS", "")
ENTRA_TENANT_ID = os.environ.get("ENTRA_TENANT_ID")

# Database configuration
POSTGRES_DB = os.environ.get("POSTGRES_DB")
POSTGRES_USER = os.environ.get("POSTGRES_USER")
Expand Down
15 changes: 12 additions & 3 deletions metrics/api/settings/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,18 @@
},
]

COGNITO_USER_MANAGER = "common.auth.cognito_jwt.user_manager.CognitoManager"
JWT_AUTH_HEADER = config.JWT_AUTH_HEADER

ENTRA_USER_MANAGER = "common.auth.jwt.user_manager.EntraManager"
ENTRA_AUDIENCE = config.ENTRA_AUDIENCE
ENTRA_APP_ID = config.ENTRA_APP_ID
ENTRA_ALLOWED_APP_IDS = config.ENTRA_ALLOWED_APP_IDS.split(",")
ENTRA_TENANT_ID = config.ENTRA_TENANT_ID
ENTRA_PUBLIC_KEYS_CACHING_ENABLED = True
ENTRA_PUBLIC_KEYS_CACHING_TIMEOUT = 60 * 60 * 24 # 24h caching, default is 300s

COGNITO_USER_MANAGER = "common.auth.jwt.user_manager.CognitoManager"
COGNITO_AWS_REGION = config.COGNITO_AWS_REGION
COGNITO_JWT_AUTH_HEADER = config.COGNITO_JWT_AUTH_HEADER
COGNITO_USER_POOL = config.COGNITO_USER_POOL
COGNITO_AUDIENCE = None
COGNITO_PUBLIC_KEYS_CACHING_ENABLED = True
Expand All @@ -128,7 +137,7 @@
"DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
"DEFAULT_AUTHENTICATION_CLASSES": [
"rest_framework.authentication.SessionAuthentication",
"common.auth.cognito_jwt.JSONWebTokenAuthentication",
"common.auth.jwt.JSONWebTokenAuthentication",
],
}

Expand Down
Loading
Loading