fix: replace python-jose with PyJWT for robust JWKS signature verification
Some checks failed
CI / lint-and-test (push) Has been cancelled
Some checks failed
CI / lint-and-test (push) Has been cancelled
python-jose failed to correctly construct RSA public keys from Microsoft JWKS entries lacking an explicit alg field, causing signature verification failures. Switch auth.py to PyJWT + jwt.algorithms.RSAAlgorithm.from_jwk() which handles Entra JWKS correctly. Add cryptography explicitly to deps. Update auth tests to remove unused python-jose fixture code.
This commit is contained in:
@@ -10,8 +10,8 @@ from config import (
|
|||||||
AUTH_TENANT_ID,
|
AUTH_TENANT_ID,
|
||||||
)
|
)
|
||||||
from fastapi import Header, HTTPException
|
from fastapi import Header, HTTPException
|
||||||
from jose import jwt
|
from jwt import ExpiredSignatureError, InvalidTokenError, decode
|
||||||
from jose.jwk import construct
|
from jwt.algorithms import RSAAlgorithm
|
||||||
|
|
||||||
JWKS_CACHE = {"exp": 0, "keys": []}
|
JWKS_CACHE = {"exp": 0, "keys": []}
|
||||||
logger = structlog.get_logger("aoc.auth")
|
logger = structlog.get_logger("aoc.auth")
|
||||||
@@ -46,17 +46,21 @@ def _allowed(claims: dict, allowed_roles: set[str], allowed_groups: set[str]) ->
|
|||||||
|
|
||||||
def _decode_token(token: str, jwks):
|
def _decode_token(token: str, jwks):
|
||||||
try:
|
try:
|
||||||
header = jwt.get_unverified_header(token)
|
import json
|
||||||
|
|
||||||
|
from jwt import get_unverified_header
|
||||||
|
|
||||||
|
header = get_unverified_header(token)
|
||||||
kid = header.get("kid")
|
kid = header.get("kid")
|
||||||
key_dict = next((k for k in jwks if k.get("kid") == kid), None)
|
key_dict = next((k for k in jwks if k.get("kid") == kid), None)
|
||||||
if not key_dict:
|
if not key_dict:
|
||||||
raise HTTPException(status_code=401, detail="Invalid token: signing key not found")
|
raise HTTPException(status_code=401, detail="Invalid token: signing key not found")
|
||||||
|
|
||||||
key = construct(key_dict, algorithm="RS256")
|
pub_key = RSAAlgorithm.from_jwk(json.dumps(key_dict))
|
||||||
decode_kwargs = {"algorithms": ["RS256"]}
|
decode_kwargs = {"algorithms": ["RS256"]}
|
||||||
if AUTH_CLIENT_ID:
|
if AUTH_CLIENT_ID:
|
||||||
decode_kwargs["audience"] = AUTH_CLIENT_ID
|
decode_kwargs["audience"] = AUTH_CLIENT_ID
|
||||||
claims = jwt.decode(token, key, **decode_kwargs)
|
claims = decode(token, pub_key, **decode_kwargs)
|
||||||
|
|
||||||
tid = claims.get("tid")
|
tid = claims.get("tid")
|
||||||
iss = claims.get("iss", "")
|
iss = claims.get("iss", "")
|
||||||
@@ -67,6 +71,12 @@ def _decode_token(token: str, jwks):
|
|||||||
return claims
|
return claims
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
except ExpiredSignatureError as exc:
|
||||||
|
logger.warning("Token verification failed", error_type="ExpiredSignatureError", error=str(exc))
|
||||||
|
raise HTTPException(status_code=401, detail="Token expired") from None
|
||||||
|
except InvalidTokenError as exc:
|
||||||
|
logger.warning("Token verification failed", error_type=type(exc).__name__, error=str(exc))
|
||||||
|
raise HTTPException(status_code=401, detail=f"Invalid token ({type(exc).__name__})") from None
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Token verification failed", error_type=type(exc).__name__, error=str(exc))
|
logger.warning("Token verification failed", error_type=type(exc).__name__, error=str(exc))
|
||||||
raise HTTPException(status_code=401, detail=f"Invalid token ({type(exc).__name__})") from None
|
raise HTTPException(status_code=401, detail=f"Invalid token ({type(exc).__name__})") from None
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ pymongo
|
|||||||
python-dotenv
|
python-dotenv
|
||||||
requests
|
requests
|
||||||
PyYAML
|
PyYAML
|
||||||
|
PyJWT
|
||||||
python-jose[cryptography]
|
python-jose[cryptography]
|
||||||
|
cryptography
|
||||||
pydantic-settings
|
pydantic-settings
|
||||||
structlog
|
structlog
|
||||||
tenacity
|
tenacity
|
||||||
|
|||||||
@@ -12,21 +12,6 @@ def reset_cache():
|
|||||||
auth.JWKS_CACHE["exp"] = 0
|
auth.JWKS_CACHE["exp"] = 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_jwks():
|
|
||||||
from Crypto.PublicKey import RSA
|
|
||||||
from jose.jwk import RSAKey
|
|
||||||
key = RSA.generate(2048)
|
|
||||||
rsa_key = RSAKey(key)
|
|
||||||
jwk_dict = {
|
|
||||||
"kty": "RSA",
|
|
||||||
"kid": "test-kid",
|
|
||||||
"n": rsa_key._key.n,
|
|
||||||
"e": rsa_key._key.e,
|
|
||||||
}
|
|
||||||
return rsa_key, jwk_dict
|
|
||||||
|
|
||||||
|
|
||||||
def test_allowed_no_restrictions():
|
def test_allowed_no_restrictions():
|
||||||
assert _allowed({}, set(), set()) is True
|
assert _allowed({}, set(), set()) is True
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user