Some checks failed
CI / lint-and-test (push) Has been cancelled
- Cache Graph API tokens with expiry-aware reuse in graph/auth.py
- Add a tenacity-based retry/backoff wrapper (utils/http.py) and apply it to all Graph/source API calls
- Add Pydantic request/response models (models/api.py) and FastAPI query constraints
- Add unit tests for event_model and auth, plus integration tests for the API endpoints
- Configure the ruff linter/formatter in pyproject.toml
- Add a GitHub Actions CI pipeline (.github/workflows/ci.yml)
- Add requirements-dev.txt with pytest, mongomock, httpx and ruff
- Clean up typing imports and fix ruff lint findings across the codebase
90 lines
2.9 KiB
Python
90 lines
2.9 KiB
Python
import time
|
|
|
|
import requests
|
|
import structlog
|
|
from config import (
|
|
AUTH_ALLOWED_GROUPS,
|
|
AUTH_ALLOWED_ROLES,
|
|
AUTH_CLIENT_ID,
|
|
AUTH_ENABLED,
|
|
AUTH_TENANT_ID,
|
|
)
|
|
from fastapi import Header, HTTPException
|
|
from jose import jwt
|
|
from jose.jwk import construct
|
|
|
|
# Module-level JWKS cache shared by every request in this process:
#   "keys" - the JWK list from the last successful fetch ([] = nothing cached)
#   "exp"  - time.time() timestamp after which _get_jwks fetches again
JWKS_CACHE = dict(exp=0, keys=[])

logger = structlog.get_logger("aoc.auth")
|
|
|
|
|
|
def _get_jwks() -> list[dict]:
    """Return the tenant's JSON Web Key Set, cached in ``JWKS_CACHE`` for one hour.

    The JWKS URI is discovered from the tenant's OpenID Connect metadata
    document on login.microsoftonline.com, then the key set is downloaded.

    Returns:
        The list of JWK dicts taken from the ``keys`` member of the JWKS document.

    Raises:
        HTTPException: 503 if the keys cannot be fetched (network error, non-2xx
            response, or malformed JSON) and no previously cached keys exist.
    """
    now = time.time()
    cached = JWKS_CACHE["keys"]
    if cached and JWKS_CACHE["exp"] > now:
        return cached

    try:
        oidc_resp = requests.get(
            f"https://login.microsoftonline.com/{AUTH_TENANT_ID}/v2.0/.well-known/openid-configuration",
            timeout=10,
        )
        # Without a status check an error page surfaces as a confusing
        # JSONDecodeError/KeyError and becomes an unhandled 500.
        oidc_resp.raise_for_status()
        jwks_uri = oidc_resp.json()["jwks_uri"]

        jwks_resp = requests.get(jwks_uri, timeout=10)
        jwks_resp.raise_for_status()
        keys = jwks_resp.json()["keys"]
    except (requests.RequestException, ValueError, KeyError, TypeError) as exc:
        if cached:
            # Degrade gracefully during an identity-provider/network outage:
            # keep serving the last good keys and retry at most once a minute
            # rather than on every incoming request.
            logger.warning("JWKS refresh failed; using cached keys", error=str(exc))
            JWKS_CACHE["exp"] = now + 60
            return cached
        logger.error("JWKS fetch failed", error=str(exc))
        raise HTTPException(status_code=503, detail="Unable to fetch token signing keys") from None

    JWKS_CACHE["keys"] = keys
    JWKS_CACHE["exp"] = now + 60 * 60  # cache 1h
    return keys
|
|
|
|
|
|
def _allowed(claims: dict, allowed_roles: set[str], allowed_groups: set[str]) -> bool:
|
|
if not allowed_roles and not allowed_groups:
|
|
return True
|
|
roles = set(claims.get("roles", []) or claims.get("role", []) or [])
|
|
groups = set(claims.get("groups", []) or [])
|
|
return bool(
|
|
(allowed_roles and roles.intersection(allowed_roles))
|
|
or (allowed_groups and groups.intersection(allowed_groups))
|
|
)
|
|
|
|
|
|
def _decode_token(token: str, jwks: list[dict]) -> dict:
    """Verify *token* against the key set *jwks* and return its claims.

    Verification steps:
      * select the JWK whose ``kid`` matches the token header;
      * check the RS256 signature via ``jose.jwt.decode`` (which also applies
        python-jose's default claim checks, e.g. ``exp``);
      * check the audience when ``AUTH_CLIENT_ID`` is configured;
      * when ``AUTH_TENANT_ID`` is configured, reject a mismatching ``tid`` and
        require ``iss`` to equal the tenant's v2.0 or v1.0 issuer URL.

    Raises:
        HTTPException: 401 for any verification failure. Unexpected errors are
            logged and reported as a generic "Invalid token".
    """
    try:
        header = jwt.get_unverified_header(token)
        kid = header.get("kid")
        key_dict = next((k for k in jwks if k.get("kid") == kid), None)
        if not key_dict:
            raise HTTPException(status_code=401, detail="Invalid token: signing key not found")

        # Microsoft identity platform JWKs carry no "alg" member, so the
        # algorithm must be passed explicitly or construct() raises JWKError.
        key = construct(key_dict, algorithm="RS256")

        decode_kwargs = {"algorithms": ["RS256"]}
        if AUTH_CLIENT_ID:
            decode_kwargs["audience"] = AUTH_CLIENT_ID
        # NOTE(review): without AUTH_CLIENT_ID, python-jose still compares a
        # present "aud" claim against None and rejects it - confirm intended.
        claims = jwt.decode(token, key, **decode_kwargs)

        if AUTH_TENANT_ID:
            tid = claims.get("tid")
            if tid and tid != AUTH_TENANT_ID:
                raise HTTPException(status_code=401, detail="Invalid tenant")
            # Exact match instead of a substring test: only the tenant's own
            # v2.0 and v1.0 issuer identifiers are accepted.
            expected_issuers = {
                f"https://login.microsoftonline.com/{AUTH_TENANT_ID}/v2.0",
                f"https://sts.windows.net/{AUTH_TENANT_ID}/",
            }
            if claims.get("iss") not in expected_issuers:
                raise HTTPException(status_code=401, detail="Invalid issuer")

        return claims

    except HTTPException:
        raise

    except Exception as exc:
        # Boundary catch: any jose/parsing error becomes a uniform 401 so no
        # verification details leak to the client.
        logger.warning("Token verification failed", error=str(exc))
        raise HTTPException(status_code=401, detail="Invalid token") from None
|
|
|
|
|
|
def require_auth(authorization: str | None = Header(None)):
    """FastAPI dependency that authenticates the request's bearer token.

    When ``AUTH_ENABLED`` is false every request is treated as the anonymous
    subject. Otherwise the ``Authorization`` header must be ``Bearer <token>``.
    The token must pass ``_decode_token`` against the keys from ``_get_jwks``,
    and it must satisfy the ``AUTH_ALLOWED_ROLES`` / ``AUTH_ALLOWED_GROUPS``
    policy.

    Returns:
        The verified token claims, or ``{"sub": "anonymous"}`` when auth is off.

    Raises:
        HTTPException: 401 if the bearer token is missing or invalid, 403 if it
            is valid but not authorised. Errors raised by ``_get_jwks``
            propagate unchanged.
    """
    if not AUTH_ENABLED:
        return {"sub": "anonymous"}

    scheme, _, credentials = (authorization or "").partition(" ")
    token = credentials.strip()
    if scheme.lower() != "bearer" or not token:
        # Reject empty/whitespace tokens before any network JWKS fetch.
        # RFC 6750 section 3 requires a 401 to advertise the Bearer scheme.
        raise HTTPException(
            status_code=401,
            detail="Missing bearer token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    claims = _decode_token(token, _get_jwks())

    if not _allowed(claims, AUTH_ALLOWED_ROLES, AUTH_ALLOWED_GROUPS):
        raise HTTPException(status_code=403, detail="Forbidden")

    return claims
|