Files
aoc/backend/auth.py
Tomas Kracmar d01e7801ed
All checks were successful
CI / lint-and-test (push) Successful in 51s
Release / build-and-push (push) Successful in 1m57s
security: v1.7.7 hardening release
- Add WEBHOOK_CLIENT_SECRET validation for Graph webhooks
- Add Redis-backed rate limiting (fetch/ask/write/default tiers)
- Validate LLM_BASE_URL to prevent SSRF (HTTPS only, block private IPs)
- Enforce non-wildcard CORS when AUTH_ENABLED=true
- Add Content-Security-Policy headers
- Fix audit middleware to use verified JWT claims via contextvars
- Cap bulk_tags updates to 10,000 documents
- Return generic error messages to clients (no internal detail leakage)
- Strict AlertCondition Pydantic model for alert rules
- Security warning on MCP stdio server startup
- Remove MongoDB/Redis host ports from docker-compose
- Remove mongo_query from /ask API response
2026-04-27 09:16:57 +02:00

117 lines
4.2 KiB
Python

import contextvars
import time
import requests
import structlog
from config import (
AUTH_ALLOWED_GROUPS,
AUTH_ALLOWED_ROLES,
AUTH_CLIENT_ID,
AUTH_ENABLED,
AUTH_TENANT_ID,
PRIVACY_SERVICE_ROLES,
PRIVACY_SERVICES,
)
from fastapi import Header, HTTPException
from jwt import ExpiredSignatureError, InvalidTokenError, decode
from jwt.algorithms import RSAAlgorithm
# Thread-/task-local storage for verified auth claims (used by audit middleware).
# Defaults to None; require_auth() stores the caller's claims here, or the
# {"sub": "anonymous"} placeholder when AUTH_ENABLED is false.
_auth_context: contextvars.ContextVar[dict | None] = contextvars.ContextVar("auth_context", default=None)
# Module-level cache of the tenant's signing keys, filled by _get_jwks():
#   "keys" - the "keys" list taken from the JWKS document
#   "exp"  - time.time() value after which the cached keys are refetched
JWKS_CACHE = {"exp": 0, "keys": []}
# Structured logger used for token verification warnings.
logger = structlog.get_logger("aoc.auth")
def _get_jwks():
    """Return the tenant's JSON Web Key Set, refreshing the module cache when stale.

    The keys are cached in ``JWKS_CACHE`` for one hour. On a cache miss the
    OpenID Connect discovery document for ``AUTH_TENANT_ID`` is fetched, then
    the JWKS document it points to via ``jwks_uri``.

    Returns:
        The list stored under ``"keys"`` in the JWKS document.

    Raises:
        requests.HTTPError: if either endpoint answers with an error status.
        requests.RequestException: on connection problems or timeouts.
        KeyError: if ``jwks_uri`` or ``keys`` is missing from a response.
    """
    now = time.time()
    if JWKS_CACHE["keys"] and JWKS_CACHE["exp"] > now:
        return JWKS_CACHE["keys"]

    discovery = requests.get(
        f"https://login.microsoftonline.com/{AUTH_TENANT_ID}/v2.0/.well-known/openid-configuration",
        timeout=10,
    )
    # Fail loudly on 4xx/5xx instead of trying to parse an error body as metadata.
    discovery.raise_for_status()
    jwks_uri = discovery.json()["jwks_uri"]

    jwks_response = requests.get(jwks_uri, timeout=10)
    jwks_response.raise_for_status()
    keys = jwks_response.json()["keys"]

    # Only a fully successful refresh replaces the cached key set.
    JWKS_CACHE["keys"] = keys
    JWKS_CACHE["exp"] = now + 60 * 60  # cache 1h
    return keys
def _allowed(claims: dict, allowed_roles: set[str], allowed_groups: set[str]) -> bool:
if not allowed_roles and not allowed_groups:
return True
roles = set(claims.get("roles", []) or claims.get("role", []) or [])
groups = set(claims.get("groups", []) or [])
return bool(
(allowed_roles and roles.intersection(allowed_roles))
or (allowed_groups and groups.intersection(allowed_groups))
)
def _decode_token(token: str, jwks):
    """Verify ``token`` against ``jwks`` and return its claims.

    The signing key is selected by the ``kid`` in the token's unverified
    header, the token is decoded with RS256 only, and the ``tid`` / ``iss``
    claims are compared with ``AUTH_TENANT_ID`` when that is configured.

    Args:
        token: The encoded JWT.
        jwks: Iterable of JWK dicts to search for the signing key.

    Returns:
        The decoded claims dict.

    Raises:
        HTTPException: 401 for every failure (unknown key, expired token,
            invalid token, wrong tenant or issuer, or any unexpected error).
    """
    try:
        import json
        from jwt import get_unverified_header

        kid = get_unverified_header(token).get("kid")

        # Pick the first published key whose id matches the token header.
        signing_jwk = None
        for candidate in jwks:
            if candidate.get("kid") == kid:
                signing_jwk = candidate
                break
        if not signing_jwk:
            raise HTTPException(status_code=401, detail="Invalid token: signing key not found")

        public_key = RSAAlgorithm.from_jwk(json.dumps(signing_jwk))

        # The audience is handed to the decoder only when a client id is configured.
        if AUTH_CLIENT_ID:
            claims = decode(token, public_key, algorithms=["RS256"], audience=AUTH_CLIENT_ID)
        else:
            claims = decode(token, public_key, algorithms=["RS256"])

        if AUTH_TENANT_ID:
            token_tenant = claims.get("tid")
            if token_tenant and token_tenant != AUTH_TENANT_ID:
                raise HTTPException(status_code=401, detail="Invalid tenant")
            if AUTH_TENANT_ID not in claims.get("iss", ""):
                raise HTTPException(status_code=401, detail="Invalid issuer")
        return claims
    except HTTPException:
        # Our own 401s pass through untouched.
        raise
    except ExpiredSignatureError as err:
        logger.warning("Token verification failed", error_type="ExpiredSignatureError", error=str(err))
        raise HTTPException(status_code=401, detail="Token expired") from None
    except Exception as err:
        # Covers InvalidTokenError and anything unexpected; both are reported the same way.
        failure = type(err).__name__
        logger.warning("Token verification failed", error_type=failure, error=str(err))
        raise HTTPException(status_code=401, detail=f"Invalid token ({failure})") from None
def user_can_access_privacy_services(claims: dict) -> bool:
    """Check if the user has roles that grant access to privacy-sensitive services.

    Args:
        claims: Decoded token claims. Roles are read from ``roles``, falling
            back to ``role``; the value may be a list or a single string.

    Returns:
        True when ``PRIVACY_SERVICES`` or ``PRIVACY_SERVICE_ROLES`` is empty
        (nothing to restrict), or when the user holds at least one role in
        ``PRIVACY_SERVICE_ROLES``; False otherwise.
    """
    if not PRIVACY_SERVICES or not PRIVACY_SERVICE_ROLES:
        return True

    raw_roles = claims.get("roles") or claims.get("role") or []
    if isinstance(raw_roles, str):
        # A single-string claim must stay one role; set("admin") would split it into letters.
        raw_roles = [raw_roles]
    return bool(set(raw_roles).intersection(PRIVACY_SERVICE_ROLES))
def require_auth(authorization: str | None = Header(None)):
    """FastAPI dependency that authenticates and authorises the caller.

    When ``AUTH_ENABLED`` is false every caller is treated as
    ``{"sub": "anonymous"}``. Otherwise the ``Authorization`` header must carry
    a bearer token, which is verified against the tenant's JWKS and checked
    against ``AUTH_ALLOWED_ROLES`` / ``AUTH_ALLOWED_GROUPS``. The resulting
    claims are also stored in ``_auth_context``.

    Args:
        authorization: Raw ``Authorization`` header value, if any.

    Returns:
        The verified claims dict, or the anonymous placeholder.

    Raises:
        HTTPException: 401 when the bearer token is missing, empty or fails
            verification; 403 when the claims match no allowed role or group.
    """
    # NOTE(review): this is a sync dependency; FastAPI presumably runs it in a
    # worker thread with a copied context, so these _auth_context.set() calls
    # may not be visible to middleware on the request task. Confirm.
    if not AUTH_ENABLED:
        user = {"sub": "anonymous"}
        _auth_context.set(user)
        return user

    scheme, _, credentials = (authorization or "").partition(" ")
    # Tolerate extra whitespace between the scheme and the token.
    token = credentials.strip()
    # Reject a missing or empty token before any JWKS network round-trip.
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(status_code=401, detail="Missing bearer token")

    jwks = _get_jwks()
    claims = _decode_token(token, jwks)
    if not _allowed(claims, AUTH_ALLOWED_ROLES, AUTH_ALLOWED_GROUPS):
        raise HTTPException(status_code=403, detail="Forbidden")
    _auth_context.set(claims)
    return claims