From ed310a06defc30880f4a252ad8a99941349f7d32 Mon Sep 17 00:00:00 2001 From: Tomas Kracmar Date: Tue, 14 Apr 2026 16:47:54 +0200 Subject: [PATCH] fix: replace python-jose with PyJWT for robust JWKS signature verification python-jose failed to correctly construct RSA public keys from Microsoft JWKS entries lacking an explicit alg field, causing signature verification failures. Switch auth.py to PyJWT + jwt.algorithms.RSAAlgorithm.from_jwk() which handles Entra JWKS correctly. Add cryptography explicitly to deps. Update auth tests to remove unused python-jose fixture code. --- backend/auth.py | 20 +++++++++++++++----- backend/requirements.txt | 2 ++ backend/tests/test_auth.py | 15 --------------- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/backend/auth.py b/backend/auth.py index 2d45510..0c9363d 100644 --- a/backend/auth.py +++ b/backend/auth.py @@ -10,8 +10,8 @@ from config import ( AUTH_TENANT_ID, ) from fastapi import Header, HTTPException -from jose import jwt -from jose.jwk import construct +from jwt import ExpiredSignatureError, InvalidTokenError, decode +from jwt.algorithms import RSAAlgorithm JWKS_CACHE = {"exp": 0, "keys": []} logger = structlog.get_logger("aoc.auth") @@ -46,17 +46,21 @@ def _allowed(claims: dict, allowed_roles: set[str], allowed_groups: set[str]) -> def _decode_token(token: str, jwks): try: - header = jwt.get_unverified_header(token) + import json + + from jwt import get_unverified_header + + header = get_unverified_header(token) kid = header.get("kid") key_dict = next((k for k in jwks if k.get("kid") == kid), None) if not key_dict: raise HTTPException(status_code=401, detail="Invalid token: signing key not found") - key = construct(key_dict, algorithm="RS256") + pub_key = RSAAlgorithm.from_jwk(json.dumps(key_dict)) decode_kwargs = {"algorithms": ["RS256"]} if AUTH_CLIENT_ID: decode_kwargs["audience"] = AUTH_CLIENT_ID - claims = jwt.decode(token, key, **decode_kwargs) + claims = decode(token, pub_key, **decode_kwargs) tid = claims.get("tid") iss = claims.get("iss", "") @@ -67,6 +71,12 @@ def _decode_token(token: str, jwks): return claims except HTTPException: raise + except ExpiredSignatureError as exc: + logger.warning("Token verification failed", error_type="ExpiredSignatureError", error=str(exc)) + raise HTTPException(status_code=401, detail="Token expired") from None + except InvalidTokenError as exc: + logger.warning("Token verification failed", error_type=type(exc).__name__, error=str(exc)) + raise HTTPException(status_code=401, detail=f"Invalid token ({type(exc).__name__})") from None except Exception as exc: logger.warning("Token verification failed", error_type=type(exc).__name__, error=str(exc)) raise HTTPException(status_code=401, detail=f"Invalid token ({type(exc).__name__})") from None diff --git a/backend/requirements.txt b/backend/requirements.txt index b7afea3..8a0af0c 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -4,7 +4,9 @@ pymongo python-dotenv requests PyYAML +PyJWT python-jose[cryptography] +cryptography pydantic-settings structlog tenacity diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py index b876a9c..fe733ed 100644 --- a/backend/tests/test_auth.py +++ b/backend/tests/test_auth.py @@ -12,21 +12,6 @@ def reset_cache(): auth.JWKS_CACHE["exp"] = 0 -@pytest.fixture -def mock_jwks(): - from Crypto.PublicKey import RSA - from jose.jwk import RSAKey - key = RSA.generate(2048) - rsa_key = RSAKey(key) - jwk_dict = { - "kty": "RSA", - "kid": "test-kid", - "n": rsa_key._key.n, - "e": rsa_key._key.e, - } - return rsa_key, jwk_dict - - def test_allowed_no_restrictions(): assert _allowed({}, set(), set()) is True