fix: replace python-jose with PyJWT for robust JWKS signature verification
Some checks failed
CI / lint-and-test (push) Has been cancelled

python-jose failed to correctly construct RSA public keys from Microsoft
JWKS entries lacking an explicit alg field, causing signature verification
failures. Switch auth.py to PyJWT + jwt.algorithms.RSAAlgorithm.from_jwk()
which handles Entra JWKS correctly. Add cryptography explicitly to deps.
Update auth tests to remove unused python-jose fixture code.
This commit is contained in:
2026-04-14 16:47:54 +02:00
parent c22c637511
commit ed310a06de
3 changed files with 17 additions and 20 deletions

View File

@@ -10,8 +10,8 @@ from config import (
AUTH_TENANT_ID, AUTH_TENANT_ID,
) )
from fastapi import Header, HTTPException from fastapi import Header, HTTPException
from jose import jwt from jwt import ExpiredSignatureError, InvalidTokenError, decode
from jose.jwk import construct from jwt.algorithms import RSAAlgorithm
JWKS_CACHE = {"exp": 0, "keys": []} JWKS_CACHE = {"exp": 0, "keys": []}
logger = structlog.get_logger("aoc.auth") logger = structlog.get_logger("aoc.auth")
@@ -46,17 +46,21 @@ def _allowed(claims: dict, allowed_roles: set[str], allowed_groups: set[str]) ->
def _decode_token(token: str, jwks): def _decode_token(token: str, jwks):
try: try:
header = jwt.get_unverified_header(token) import json
from jwt import get_unverified_header
header = get_unverified_header(token)
kid = header.get("kid") kid = header.get("kid")
key_dict = next((k for k in jwks if k.get("kid") == kid), None) key_dict = next((k for k in jwks if k.get("kid") == kid), None)
if not key_dict: if not key_dict:
raise HTTPException(status_code=401, detail="Invalid token: signing key not found") raise HTTPException(status_code=401, detail="Invalid token: signing key not found")
key = construct(key_dict, algorithm="RS256") pub_key = RSAAlgorithm.from_jwk(json.dumps(key_dict))
decode_kwargs = {"algorithms": ["RS256"]} decode_kwargs = {"algorithms": ["RS256"]}
if AUTH_CLIENT_ID: if AUTH_CLIENT_ID:
decode_kwargs["audience"] = AUTH_CLIENT_ID decode_kwargs["audience"] = AUTH_CLIENT_ID
claims = jwt.decode(token, key, **decode_kwargs) claims = decode(token, pub_key, **decode_kwargs)
tid = claims.get("tid") tid = claims.get("tid")
iss = claims.get("iss", "") iss = claims.get("iss", "")
@@ -67,6 +71,12 @@ def _decode_token(token: str, jwks):
return claims return claims
except HTTPException: except HTTPException:
raise raise
except ExpiredSignatureError as exc:
logger.warning("Token verification failed", error_type="ExpiredSignatureError", error=str(exc))
raise HTTPException(status_code=401, detail="Token expired") from None
except InvalidTokenError as exc:
logger.warning("Token verification failed", error_type=type(exc).__name__, error=str(exc))
raise HTTPException(status_code=401, detail=f"Invalid token ({type(exc).__name__})") from None
except Exception as exc: except Exception as exc:
logger.warning("Token verification failed", error_type=type(exc).__name__, error=str(exc)) logger.warning("Token verification failed", error_type=type(exc).__name__, error=str(exc))
raise HTTPException(status_code=401, detail=f"Invalid token ({type(exc).__name__})") from None raise HTTPException(status_code=401, detail=f"Invalid token ({type(exc).__name__})") from None

View File

@@ -4,7 +4,9 @@ pymongo
python-dotenv python-dotenv
requests requests
PyYAML PyYAML
PyJWT
python-jose[cryptography] python-jose[cryptography]
cryptography
pydantic-settings pydantic-settings
structlog structlog
tenacity tenacity

View File

@@ -12,21 +12,6 @@ def reset_cache():
auth.JWKS_CACHE["exp"] = 0 auth.JWKS_CACHE["exp"] = 0
@pytest.fixture
def mock_jwks():
from Crypto.PublicKey import RSA
from jose.jwk import RSAKey
key = RSA.generate(2048)
rsa_key = RSAKey(key)
jwk_dict = {
"kty": "RSA",
"kid": "test-kid",
"n": rsa_key._key.n,
"e": rsa_key._key.e,
}
return rsa_key, jwk_dict
def test_allowed_no_restrictions(): def test_allowed_no_restrictions():
assert _allowed({}, set(), set()) is True assert _allowed({}, set(), set()) is True