83 lines
2.6 KiB
Python
83 lines
2.6 KiB
Python
import time
|
|
import logging
|
|
from typing import Optional, Set
|
|
|
|
import requests
|
|
from fastapi import Depends, HTTPException, Header
|
|
from jose import jwt
|
|
|
|
from config import (
|
|
AUTH_ENABLED,
|
|
AUTH_TENANT_ID,
|
|
AUTH_CLIENT_ID,
|
|
AUTH_ALLOWED_ROLES,
|
|
AUTH_ALLOWED_GROUPS,
|
|
)
|
|
|
|
# Logger for authentication events emitted by this module.
logger = logging.getLogger("aoc.auth")

# Process-wide JWKS cache shared by every request.
#   "keys": the JWK list last fetched by _get_jwks (empty until first fetch).
#   "exp":  time.time() value after which the keys must be fetched again.
JWKS_CACHE = dict(exp=0, keys=[])
|
|
|
|
|
|
def _get_jwks():
    """Return the tenant's JSON Web Key Set, cached in ``JWKS_CACHE`` for 1 hour.

    The key-set URL (``jwks_uri``) is discovered from the tenant's OpenID
    Connect metadata document on ``login.microsoftonline.com``.

    Returns:
        The list of JWK dicts from the ``keys`` member of the JWKS document.

    Raises:
        HTTPException: 503 when the keys cannot be fetched and no previously
            fetched keys are cached.
    """
    now = time.time()
    if JWKS_CACHE["keys"] and JWKS_CACHE["exp"] > now:
        return JWKS_CACHE["keys"]

    try:
        oidc_resp = requests.get(
            f"https://login.microsoftonline.com/{AUTH_TENANT_ID}/v2.0/.well-known/openid-configuration",
            timeout=10,
        )
        oidc_resp.raise_for_status()
        jwks_uri = oidc_resp.json()["jwks_uri"]

        jwks_resp = requests.get(jwks_uri, timeout=10)
        jwks_resp.raise_for_status()
        keys = jwks_resp.json()["keys"]
    except (requests.RequestException, ValueError, KeyError, TypeError) as exc:
        # ValueError: body is not JSON; KeyError/TypeError: unexpected shape.
        if JWKS_CACHE["keys"]:
            # The identity provider is unreachable: keep serving the last known
            # keys instead of failing every request, and retry in a minute
            # rather than on every call (each attempt can block up to 2x10s).
            logger.warning("JWKS refresh failed, using cached keys: %s", exc)
            JWKS_CACHE["exp"] = now + 60
            return JWKS_CACHE["keys"]
        logger.error("JWKS fetch failed: %s", exc)
        raise HTTPException(
            status_code=503, detail="Authentication keys unavailable"
        ) from exc

    JWKS_CACHE["keys"] = keys
    JWKS_CACHE["exp"] = now + 60 * 60  # cache 1h
    return keys
|
|
|
|
|
|
def _allowed(claims: dict, allowed_roles: Set[str], allowed_groups: Set[str]) -> bool:
|
|
if not allowed_roles and not allowed_groups:
|
|
return True
|
|
roles = set(claims.get("roles", []) or claims.get("role", []) or [])
|
|
groups = set(claims.get("groups", []) or [])
|
|
if allowed_roles and roles.intersection(allowed_roles):
|
|
return True
|
|
if allowed_groups and groups.intersection(allowed_groups):
|
|
return True
|
|
return False
|
|
|
|
|
|
def _decode_token(token: str, jwks):
    """Verify a bearer JWT's signature and claims and return the claims.

    The token must be signed with RS256 by the key in *jwks* whose ``kid``
    matches the token header. ``jwt.decode`` also enforces ``exp``/``nbf``/``iat``.
    When ``AUTH_TENANT_ID`` is set, the ``tid`` claim (if present) must equal it
    and ``iss`` must be one of that tenant's v1/v2 issuers. When
    ``AUTH_CLIENT_ID`` is set, ``aud`` must be the client id or ``api://<client id>``.

    Args:
        token: Raw JWT taken from the ``Authorization`` header.
        jwks: List of JWK dicts, as returned by ``_get_jwks``.

    Returns:
        The verified claims dict.

    Raises:
        HTTPException: 401 for any token that fails parsing or validation.
    """
    try:
        # The header is only used to pick the verification key; nothing in it
        # is trusted until the signature has been verified below.
        header = jwt.get_unverified_header(token)
        kid = header.get("kid")
        signing_key = next(
            (k for k in jwks if isinstance(k, dict) and k.get("kid") == kid),
            None,
        )
        if kid is None or signing_key is None:
            logger.warning("Token signed with unknown key id: %s", kid)
            raise HTTPException(status_code=401, detail="Invalid token")

        # Signature verification is mandatory: without it anyone could forge a
        # token with the right tenant/issuer. Clients must therefore send a
        # token issued for this app (ID token, or access token for
        # api://<client id>); tokens for other resources such as Microsoft
        # Graph cannot be verified with these keys and are rejected.
        claims = jwt.decode(
            token,
            signing_key,
            algorithms=["RS256"],  # pinned to prevent algorithm confusion
            options={
                # Audience is checked below so that both accepted forms pass.
                "verify_aud": False,
                # No access token is available here to compare with at_hash.
                "verify_at_hash": False,
            },
        )

        if AUTH_TENANT_ID:
            tid = claims.get("tid")
            if tid and tid != AUTH_TENANT_ID:
                raise HTTPException(status_code=401, detail="Invalid tenant")
            valid_issuers = {
                f"https://login.microsoftonline.com/{AUTH_TENANT_ID}/v2.0",
                f"https://sts.windows.net/{AUTH_TENANT_ID}/",
            }
            if claims.get("iss", "") not in valid_issuers:
                raise HTTPException(status_code=401, detail="Invalid issuer")

        if AUTH_CLIENT_ID:
            expected_audiences = {AUTH_CLIENT_ID, f"api://{AUTH_CLIENT_ID}"}
            aud = claims.get("aud")
            audiences = {aud} if isinstance(aud, str) else set(aud or [])
            if not audiences.intersection(expected_audiences):
                raise HTTPException(status_code=401, detail="Invalid audience")

        return claims

    except HTTPException:
        raise

    except Exception as exc:
        # Boundary: any parse/verification failure becomes a plain 401.
        logger.warning("Token validation failed: %s", exc)
        raise HTTPException(status_code=401, detail="Invalid token") from exc
|
|
|
|
|
|
def require_auth(authorization: Optional[str] = Header(None)):
    """FastAPI dependency that authenticates the request's bearer token.

    When ``AUTH_ENABLED`` is false, every request is treated as anonymous.
    Otherwise the ``Authorization`` header must be ``Bearer <token>`` (scheme
    case-insensitive, one or more spaces per RFC 6750), and the token must
    pass ``_decode_token`` and ``_allowed``.

    Returns:
        The verified claims dict, or ``{"sub": "anonymous"}`` when auth is off.

    Raises:
        HTTPException: 401 for a missing or invalid token; 403 when the token
            is valid but carries none of the allowed roles or groups.
    """
    if not AUTH_ENABLED:
        return {"sub": "anonymous"}

    scheme, _, token = (authorization or "").partition(" ")
    # RFC 6750 allows several spaces between the scheme and the token.
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(status_code=401, detail="Missing bearer token")

    jwks = _get_jwks()
    claims = _decode_token(token, jwks)

    if not _allowed(claims, AUTH_ALLOWED_ROLES, AUTH_ALLOWED_GROUPS):
        raise HTTPException(status_code=403, detail="Forbidden")

    return claims
|