Files
aoc/backend/auth.py
Tomas Kracmar 4f6e16d64d feat: implement Phase 1 hardening
- Verify JWT signatures via JWKS in auth.py
- Fix broken frontend auth button references
- Add Pydantic Settings for env validation (RETENTION_DAYS, CORS_ORIGINS)
- Create MongoDB indexes + TTL on startup
- Add /health endpoint and CORS middleware
- Escape regex input in event queries
- Fix dedupe() return calculation in maintenance.py
- Replace basic logging with structured structlog JSON logs
- Update README and add ROADMAP.md
2026-04-14 11:48:29 +02:00

93 lines
2.9 KiB
Python

import time
import structlog
from typing import Optional, Set
import requests
from fastapi import Depends, HTTPException, Header
from jose import jwt
from jose.jwk import construct
from config import (
AUTH_ENABLED,
AUTH_TENANT_ID,
AUTH_CLIENT_ID,
AUTH_ALLOWED_ROLES,
AUTH_ALLOWED_GROUPS,
)
# Process-wide cache for the tenant's signing keys, read and written by
# _get_jwks(): "keys" is the last fetched list of JWK dicts and "exp" is the
# time.time() value after which they are fetched again. It starts empty and
# expired so the first authenticated request triggers a fetch.
JWKS_CACHE = dict(exp=0, keys=[])

# Structured (structlog) logger for authentication events.
logger = structlog.get_logger("aoc.auth")
def _get_jwks():
    """Return the tenant's JSON Web Key Set, fetching it at most once per hour.

    The ``jwks_uri`` is resolved from the tenant's OpenID Connect discovery
    document, and the signing keys are downloaded from it. A successful result
    is cached in ``JWKS_CACHE`` for one hour.

    If a refresh fails while a previously fetched key set exists, that key set
    keeps being served, and the next fetch is retried after one minute.

    Returns:
        The list of JWK dicts from the ``keys`` member of the JWKS document.

    Raises:
        HTTPException: 503 when the keys cannot be fetched and no previously
            fetched key set is available to fall back on.
    """
    now = time.time()
    cached = JWKS_CACHE["keys"]
    if cached and JWKS_CACHE["exp"] > now:
        return cached

    discovery_url = (
        f"https://login.microsoftonline.com/{AUTH_TENANT_ID}"
        "/v2.0/.well-known/openid-configuration"
    )
    try:
        oidc_resp = requests.get(discovery_url, timeout=10)
        # Don't try to parse an error page as a discovery document.
        oidc_resp.raise_for_status()
        jwks_uri = oidc_resp.json()["jwks_uri"]
        jwks_resp = requests.get(jwks_uri, timeout=10)
        jwks_resp.raise_for_status()
        keys = jwks_resp.json()["keys"]
    except (requests.RequestException, ValueError, KeyError, TypeError) as exc:
        # ValueError: body is not JSON; KeyError/TypeError: unexpected shape.
        if cached:
            # Identity provider unreachable: keep using the last good key set
            # instead of locking everyone out, and back off for a minute
            # rather than retrying (with 10 s timeouts) on every request.
            logger.warning("JWKS refresh failed; using cached keys", error=str(exc))
            JWKS_CACHE["exp"] = now + 60
            return cached
        logger.error("JWKS fetch failed", error=str(exc))
        raise HTTPException(
            status_code=503, detail="Authentication service unavailable"
        ) from exc

    JWKS_CACHE["keys"] = keys
    # Measure the 1h lifetime from when the fetch completed.
    JWKS_CACHE["exp"] = time.time() + 60 * 60
    return keys
def _claim_values(value) -> Set[str]:
    """Normalize a role/group claim value to a set of strings.

    A single string is one value (not a sequence of characters). A list,
    tuple or set contributes its string items. Anything else, including a
    missing claim, yields an empty set.
    """
    if isinstance(value, str):
        return {value} if value else set()
    if isinstance(value, (list, tuple, set, frozenset)):
        return {item for item in value if isinstance(item, str)}
    return set()


def _allowed(claims: dict, allowed_roles: Set[str], allowed_groups: Set[str]) -> bool:
    """Return True if the token's roles or groups satisfy the allow-lists.

    When both allow-lists are empty, authorization is open and this returns
    True. Otherwise the caller is allowed if any of its roles appears in
    ``allowed_roles`` or any of its groups appears in ``allowed_groups``.

    Roles come from the ``roles`` claim, falling back to ``role`` when
    ``roles`` is missing or empty. Groups come from the ``groups`` claim.
    A missing claim counts as no values.
    NOTE(review): if the identity provider omits ``groups`` for some users
    (e.g. very large group memberships), those users fail group checks here —
    confirm this is acceptable for the deployment.
    """
    if not allowed_roles and not allowed_groups:
        return True
    roles = _claim_values(claims.get("roles") or claims.get("role"))
    groups = _claim_values(claims.get("groups"))
    if allowed_roles and roles.intersection(allowed_roles):
        return True
    if allowed_groups and groups.intersection(allowed_groups):
        return True
    return False
def _decode_token(token: str, jwks):
    """Verify a bearer token's signature and claims, and return its claims.

    The signing key is chosen from ``jwks`` by the token header's ``kid``, and
    the signature must be RS256. When ``AUTH_CLIENT_ID`` is set, it is passed
    to ``jwt.decode`` as the expected audience. When ``AUTH_TENANT_ID`` is
    set, a present ``tid`` claim must equal it, and ``iss`` must be exactly
    one of that tenant's issuer URLs.
    NOTE(review): when ``AUTH_CLIENT_ID`` is unset, no audience is passed.
    Confirm how the installed python-jose treats tokens that carry an ``aud``
    claim in that case (its default ``verify_aud`` option may reject them).

    Args:
        token: The raw JWT, without the ``Bearer`` prefix.
        jwks: List of JWK dicts, as returned by ``_get_jwks()``.

    Returns:
        The decoded claims dict.

    Raises:
        HTTPException: 401 for any verification failure.
    """
    try:
        header = jwt.get_unverified_header(token)
        kid = header.get("kid")
        # Require a kid. Otherwise a token without one would match any JWKS
        # entry that also lacks a kid (None == None).
        key_dict = next((k for k in jwks if kid and k.get("kid") == kid), None)
        if not key_dict:
            raise HTTPException(status_code=401, detail="Invalid token: signing key not found")
        # Pin the algorithm explicitly. JWKS entries may omit "alg", and
        # jose.jwk.construct() refuses to build a key without an algorithm.
        key = construct(key_dict, algorithm="RS256")
        decode_kwargs = {"algorithms": ["RS256"]}
        if AUTH_CLIENT_ID:
            decode_kwargs["audience"] = AUTH_CLIENT_ID
        claims = jwt.decode(token, key, **decode_kwargs)
        tid = claims.get("tid")
        if AUTH_TENANT_ID and tid and tid != AUTH_TENANT_ID:
            raise HTTPException(status_code=401, detail="Invalid tenant")
        if AUTH_TENANT_ID:
            # Exact match instead of a substring test; accepts the issuer
            # forms of both v2.0 and v1.0 access tokens for this tenant.
            expected_issuers = (
                f"https://login.microsoftonline.com/{AUTH_TENANT_ID}/v2.0",
                f"https://sts.windows.net/{AUTH_TENANT_ID}/",
            )
            if claims.get("iss") not in expected_issuers:
                raise HTTPException(status_code=401, detail="Invalid issuer")
        return claims
    except HTTPException:
        raise
    except Exception as exc:
        # Malformed tokens, bad signatures, expired/invalid claims, key errors.
        logger.warning(
            "Token verification failed",
            error=str(exc),
            error_type=type(exc).__name__,
        )
        raise HTTPException(status_code=401, detail="Invalid token") from exc
def require_auth(authorization: Optional[str] = Header(None)):
    """FastAPI dependency that authenticates and authorizes the request.

    When ``AUTH_ENABLED`` is false, every request is let through as
    ``{"sub": "anonymous"}``. Otherwise the ``Authorization`` header must be
    ``Bearer <token>``, where the scheme is case-insensitive and one or more
    spaces may follow it. The token is verified against the tenant's JWKS and
    then checked against the configured role/group allow-lists.

    Returns:
        The verified token claims, or the anonymous placeholder.

    Raises:
        HTTPException: 401 when the token is missing or invalid, and 403 when
            it is valid but carries none of the allowed roles/groups. Errors
            raised while fetching the signing keys propagate unchanged.
    """
    if not AUTH_ENABLED:
        return {"sub": "anonymous"}
    scheme, _, credentials = (authorization or "").partition(" ")
    # RFC 6750 allows one or more spaces after "Bearer", so strip them.
    token = credentials.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(status_code=401, detail="Missing bearer token")
    jwks = _get_jwks()
    claims = _decode_token(token, jwks)
    if not _allowed(claims, AUTH_ALLOWED_ROLES, AUTH_ALLOWED_GROUPS):
        raise HTTPException(status_code=403, detail="Forbidden")
    return claims