All checks were successful
CI / lint-and-test (push) Successful in 25s
- Add PRIVACY_SERVICES and PRIVACY_SERVICE_ROLES config variables
- Add user_can_access_privacy_services(claims) helper in auth.py
- /api/events filters out privacy services for users without required roles
- /api/filter-options excludes privacy services from dropdown options
- /api/ask excludes privacy services from NLQ queries
- /api/events/{id}/explain returns 403 for privacy events if unauthorized
- Teams added to default noisy service exclusion (frontend + backend)
- Update .env.example with privacy config documentation
- Add tests for event filtering, filter-options exclusion, and explain 403
110 lines
3.9 KiB
Python
110 lines
3.9 KiB
Python
import time
|
|
|
|
import requests
|
|
import structlog
|
|
from config import (
|
|
AUTH_ALLOWED_GROUPS,
|
|
AUTH_ALLOWED_ROLES,
|
|
AUTH_CLIENT_ID,
|
|
AUTH_ENABLED,
|
|
AUTH_TENANT_ID,
|
|
PRIVACY_SERVICE_ROLES,
|
|
PRIVACY_SERVICES,
|
|
)
|
|
from fastapi import Header, HTTPException
|
|
from jwt import ExpiredSignatureError, InvalidTokenError, decode
|
|
from jwt.algorithms import RSAAlgorithm
|
|
|
|
# Module-level cache for the tenant's token-signing keys, filled lazily by
# _get_jwks(). "keys" holds the JWK list; "exp" is the epoch time (seconds)
# after which _get_jwks() fetches the keys again.
JWKS_CACHE = dict(exp=0, keys=[])

# Structured logger shared by the authentication helpers in this module.
logger = structlog.get_logger("aoc.auth")
|
|
|
|
|
|
def _get_jwks():
    """Return the tenant's JSON Web Key Set, fetching it at most once per hour.

    The OpenID Connect discovery document for ``AUTH_TENANT_ID`` is read to
    locate ``jwks_uri``. The ``keys`` list found there is stored in
    ``JWKS_CACHE`` with a one-hour expiry.

    If a refresh fails (network error, non-2xx status, malformed JSON or a
    missing field), any previously cached keys are returned even though they
    have expired. The next attempt is then scheduled 60 seconds later, so a
    provider outage neither locks every user out nor stalls each request on
    network timeouts. With nothing cached to fall back on, a 503 is raised.

    NOTE(review): a token signed with a key published after the last fetch is
    rejected until this cache expires; consider a rate-limited forced refresh
    when the token's ``kid`` is unknown.

    Returns:
        list[dict]: JWK dictionaries as published at ``jwks_uri``.

    Raises:
        HTTPException: 503 if the keys cannot be fetched and none are cached.
    """
    now = time.time()
    if JWKS_CACHE["keys"] and JWKS_CACHE["exp"] > now:
        return JWKS_CACHE["keys"]

    try:
        oidc_resp = requests.get(
            f"https://login.microsoftonline.com/{AUTH_TENANT_ID}/v2.0/.well-known/openid-configuration",
            timeout=10,
        )
        oidc_resp.raise_for_status()
        jwks_uri = oidc_resp.json()["jwks_uri"]

        jwks_resp = requests.get(jwks_uri, timeout=10)
        jwks_resp.raise_for_status()
        keys = jwks_resp.json()["keys"]
    except (requests.RequestException, ValueError, KeyError, TypeError) as exc:
        # ValueError covers JSON decoding errors; KeyError/TypeError cover a
        # response body that is not the expected JSON object.
        logger.warning("JWKS fetch failed", error_type=type(exc).__name__, error=str(exc))
        if JWKS_CACHE["keys"]:
            # Serve last-known-good keys and retry shortly instead of on every request.
            JWKS_CACHE["exp"] = now + 60
            return JWKS_CACHE["keys"]
        raise HTTPException(status_code=503, detail="Unable to fetch token signing keys") from None

    JWKS_CACHE["keys"] = keys
    JWKS_CACHE["exp"] = now + 60 * 60  # cache 1h
    return keys
|
|
|
|
|
|
def _allowed(claims: dict, allowed_roles: set[str], allowed_groups: set[str]) -> bool:
|
|
if not allowed_roles and not allowed_groups:
|
|
return True
|
|
roles = set(claims.get("roles", []) or claims.get("role", []) or [])
|
|
groups = set(claims.get("groups", []) or [])
|
|
return bool(
|
|
(allowed_roles and roles.intersection(allowed_roles))
|
|
or (allowed_groups and groups.intersection(allowed_groups))
|
|
)
|
|
|
|
|
|
def _decode_token(token: str, jwks):
    """Verify *token* against *jwks* and return its claims.

    Steps:
      1. Pick the JWK whose ``kid`` matches the token header.
      2. Verify the RS256 signature, and the audience when ``AUTH_CLIENT_ID``
         is set.
      3. Require an ``exp`` claim.
      4. When ``AUTH_TENANT_ID`` is set, check that ``tid`` (if present)
         equals it and that ``iss`` contains it.

    Args:
        token: Raw JWT string, without the ``Bearer`` prefix.
        jwks: List of JWK dicts, as returned by ``_get_jwks()``.

    Returns:
        dict: The verified claims.

    Raises:
        HTTPException: 401 for any verification failure.
    """
    # Imported outside the try block so that a missing/broken dependency
    # surfaces as a server error instead of being reported as a 401.
    import json

    from jwt import get_unverified_header

    try:
        header = get_unverified_header(token)
        kid = header.get("kid")
        key_dict = next((k for k in jwks if k.get("kid") == kid), None)
        if not key_dict:
            raise HTTPException(status_code=401, detail="Invalid token: signing key not found")

        pub_key = RSAAlgorithm.from_jwk(json.dumps(key_dict))
        decode_kwargs = {
            "algorithms": ["RS256"],
            # PyJWT only validates exp when the claim is present; require it
            # so a signed token without an expiry is never accepted forever.
            "options": {"require": ["exp"]},
        }
        if AUTH_CLIENT_ID:
            decode_kwargs["audience"] = AUTH_CLIENT_ID
        claims = decode(token, pub_key, **decode_kwargs)

        tid = claims.get("tid")
        iss = claims.get("iss", "")
        if AUTH_TENANT_ID and tid and tid != AUTH_TENANT_ID:
            raise HTTPException(status_code=401, detail="Invalid tenant")
        if AUTH_TENANT_ID and AUTH_TENANT_ID not in iss:
            raise HTTPException(status_code=401, detail="Invalid issuer")
        return claims
    except HTTPException:
        raise
    except ExpiredSignatureError as exc:
        logger.warning("Token verification failed", error_type="ExpiredSignatureError", error=str(exc))
        raise HTTPException(status_code=401, detail="Token expired") from None
    except InvalidTokenError as exc:
        # Also covers MissingRequiredClaimError raised for a token without exp.
        logger.warning("Token verification failed", error_type=type(exc).__name__, error=str(exc))
        raise HTTPException(status_code=401, detail=f"Invalid token ({type(exc).__name__})") from None
    except Exception as exc:
        # Any other failure (e.g. an unparsable JWK) is treated as an invalid token.
        logger.warning("Token verification failed", error_type=type(exc).__name__, error=str(exc))
        raise HTTPException(status_code=401, detail=f"Invalid token ({type(exc).__name__})") from None
|
|
|
|
|
|
def user_can_access_privacy_services(claims: dict) -> bool:
    """Return True if *claims* grant access to privacy-sensitive services.

    If either ``PRIVACY_SERVICES`` or ``PRIVACY_SERVICE_ROLES`` is empty,
    access is unrestricted. Otherwise the user needs at least one role in
    ``PRIVACY_SERVICE_ROLES``. Roles are read from the ``roles`` claim or, if
    that is missing/empty, the ``role`` claim; either may be a list or a
    single string. Group membership is not considered here.

    NOTE(review): with PRIVACY_SERVICES set but no roles configured this
    fails open (everyone gets access) — confirm that is intended.
    """
    if not PRIVACY_SERVICES or not PRIVACY_SERVICE_ROLES:
        return True

    raw_roles = claims.get("roles") or claims.get("role") or []
    # A single-valued claim may be a bare string; set("Admin") would split it
    # into characters, so wrap it as one element instead.
    user_roles = {raw_roles} if isinstance(raw_roles, str) else set(raw_roles)
    return bool(user_roles.intersection(PRIVACY_SERVICE_ROLES))
|
|
|
|
|
|
def require_auth(authorization: str | None = Header(None)):
    """FastAPI dependency that authenticates the request's bearer token.

    When ``AUTH_ENABLED`` is false, every request is treated as the anonymous
    user. Otherwise:
      - the ``Authorization`` header must be ``Bearer <jwt>``;
      - the JWT is verified against the tenant's signing keys;
      - the caller must satisfy ``AUTH_ALLOWED_ROLES`` / ``AUTH_ALLOWED_GROUPS``.

    Returns:
        dict: The verified token claims, or ``{"sub": "anonymous"}`` when auth
        is disabled.

    Raises:
        HTTPException: 401 for a missing or invalid token; 403 when the token
            is valid but the caller is in no allowed role or group.
    """
    if not AUTH_ENABLED:
        return {"sub": "anonymous"}

    scheme, _, credentials = (authorization or "").partition(" ")
    # RFC 6750 allows one or more spaces between the scheme and the token, so
    # strip them; an empty token is rejected before any JWKS network fetch.
    token = credentials.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(status_code=401, detail="Missing bearer token")

    claims = _decode_token(token, _get_jwks())
    if not _allowed(claims, AUTH_ALLOWED_ROLES, AUTH_ALLOWED_GROUPS):
        raise HTTPException(status_code=403, detail="Forbidden")
    return claims
|