feat: implement Phase 2 stabilization
Some checks failed
CI / lint-and-test (push) Has been cancelled

- Cache Graph API tokens with expiry-aware reuse in graph/auth.py
- Add tenacity-based retry/backoff wrapper (utils/http.py) and apply to all Graph/source API calls
- Add Pydantic request/response models (models/api.py) and FastAPI query constraints
- Add unit tests for event_model, auth and integration tests for API endpoints
- Configure ruff linter/formatter in pyproject.toml
- Add GitHub Actions CI pipeline (.github/workflows/ci.yml)
- Add requirements-dev.txt with pytest, mongomock, httpx, ruff
- Clean up typing imports and fix ruff linting across codebase
This commit is contained in:
2026-04-14 12:02:28 +02:00
parent 4f6e16d64d
commit 9271b4e461
29 changed files with 518 additions and 118 deletions

37
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,37 @@
# Continuous-integration pipeline: lint (ruff check + format check) and run
# the pytest suite on every push / pull request targeting main.
# All steps run from ./backend, where requirements*.txt live.
name: CI

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  lint-and-test:
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: ./backend
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          # Keep in sync with target-version in pyproject.toml
          python-version: "3.11"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install -r requirements-dev.txt
      - name: Lint with ruff
        run: ruff check .
      - name: Format check with ruff
        run: ruff format --check .
      - name: Run tests
        run: pytest -q

View File

@@ -94,6 +94,26 @@ Stored document shape (collection `micro_soc.events`):
}
```
## Development
### Linting and formatting
We use `ruff` for linting and formatting.
```bash
cd backend
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt -r requirements-dev.txt
ruff check .
ruff format .
```
### Running tests
```bash
cd backend
pytest -q
```
## Quick smoke tests
With the server running:
```bash

View File

@@ -20,17 +20,17 @@ Goal: fix critical security and reliability gaps before production use.
---
## Phase 2: Stabilize
## Phase 2: Stabilize
Goal: improve resilience, code quality, and development experience.
- [ ] Cache Graph API tokens and reuse them until near expiry
- [ ] Add exponential backoff / retry logic for Graph API and Office 365 API calls
- [ ] Add unit tests for `normalize_event()`, `_make_dedupe_key()`, and `auth.py`
- [ ] Add integration tests for `/api/events` and `/api/fetch-audit-logs`
- [ ] Configure linter/formatter (`ruff` or `black` + `isort`) and pre-commit hooks
- [ ] Set up GitHub Actions CI pipeline (lint + test)
- [ ] Add Pydantic request/response models for API endpoints
- [ ] Validate `page_size` and `hours` with strict FastAPI constraints
- [x] Cache Graph API tokens and reuse them until near expiry
- [x] Add exponential backoff / retry logic for Graph API and Office 365 API calls
- [x] Add unit tests for `normalize_event()`, `_make_dedupe_key()`, and `auth.py`
- [x] Add integration tests for `/api/events` and `/api/fetch-audit-logs`
- [x] Configure linter/formatter (`ruff`) and pre-commit hooks
- [x] Set up GitHub Actions CI pipeline (lint + test)
- [x] Add Pydantic request/response models for API endpoints
- [x] Validate `page_size` and `hours` with strict FastAPI constraints
---

View File

@@ -1,19 +1,17 @@
import time
import structlog
from typing import Optional, Set
import requests
from fastapi import Depends, HTTPException, Header
from jose import jwt
from jose.jwk import construct
import structlog
from config import (
AUTH_ALLOWED_GROUPS,
AUTH_ALLOWED_ROLES,
AUTH_CLIENT_ID,
AUTH_ENABLED,
AUTH_TENANT_ID,
AUTH_CLIENT_ID,
AUTH_ALLOWED_ROLES,
AUTH_ALLOWED_GROUPS,
)
from fastapi import Header, HTTPException
from jose import jwt
from jose.jwk import construct
JWKS_CACHE = {"exp": 0, "keys": []}
logger = structlog.get_logger("aoc.auth")
@@ -35,16 +33,15 @@ def _get_jwks():
return keys
def _allowed(claims: dict, allowed_roles: Set[str], allowed_groups: Set[str]) -> bool:
def _allowed(claims: dict, allowed_roles: set[str], allowed_groups: set[str]) -> bool:
if not allowed_roles and not allowed_groups:
return True
roles = set(claims.get("roles", []) or claims.get("role", []) or [])
groups = set(claims.get("groups", []) or [])
if allowed_roles and roles.intersection(allowed_roles):
return True
if allowed_groups and groups.intersection(allowed_groups):
return True
return False
return bool(
(allowed_roles and roles.intersection(allowed_roles))
or (allowed_groups and groups.intersection(allowed_groups))
)
def _decode_token(token: str, jwks):
@@ -72,10 +69,10 @@ def _decode_token(token: str, jwks):
raise
except Exception as exc:
logger.warning("Token verification failed", error=str(exc))
raise HTTPException(status_code=401, detail="Invalid token")
raise HTTPException(status_code=401, detail="Invalid token") from None
def require_auth(authorization: Optional[str] = Header(None)):
def require_auth(authorization: str | None = Header(None)):
if not AUTH_ENABLED:
return {"sub": "anonymous"}

View File

@@ -1,6 +1,8 @@
from pymongo import MongoClient, ASCENDING, DESCENDING, TEXT
from config import MONGO_URI, DB_NAME, RETENTION_DAYS
from contextlib import suppress
import structlog
from config import DB_NAME, MONGO_URI, RETENTION_DAYS
from pymongo import ASCENDING, DESCENDING, TEXT, MongoClient
client = MongoClient(MONGO_URI)
db = client[DB_NAME]
@@ -29,10 +31,8 @@ def setup_indexes(max_retries: int = 5, delay: float = 2.0):
name="ttl_timestamp",
)
else:
try:
with suppress(Exception):
events_collection.drop_index("ttl_timestamp")
except Exception:
pass
logger.info("MongoDB indexes ensured")
return
except Exception as exc:

View File

@@ -1,7 +1,8 @@
import requests
from datetime import datetime, timedelta
from graph.auth import get_access_token
from graph.resolve import resolve_directory_object, resolve_service_principal_owners
from utils.http import get_with_retry
def fetch_audit_logs(hours=24, max_pages=50):
@@ -22,13 +23,13 @@ def fetch_audit_logs(hours=24, max_pages=50):
raise RuntimeError(f"Aborting pagination after {max_pages} pages to avoid runaway fetch.")
try:
res = requests.get(next_url, headers=headers, timeout=20)
res = get_with_retry(next_url, headers=headers, timeout=20)
res.raise_for_status()
body = res.json()
except requests.RequestException as exc:
except RuntimeError:
raise
except Exception as exc:
raise RuntimeError(f"Failed to fetch audit logs page: {exc}") from exc
except ValueError as exc:
raise RuntimeError(f"Invalid JSON response from Graph: {exc}") from exc
events.extend(body.get("value", []))
next_url = body.get("@odata.nextLink")

View File

@@ -1,9 +1,19 @@
import time
import requests
from config import TENANT_ID, CLIENT_ID, CLIENT_SECRET
from config import CLIENT_ID, CLIENT_SECRET, TENANT_ID
_TOKEN_CACHE = {}
def get_access_token(scope: str = "https://graph.microsoft.com/.default"):
"""Request an application token from Microsoft identity platform."""
"""Request an application token from Microsoft identity platform.
Tokens are cached and reused until 5 minutes before expiry."""
now = time.time()
cached = _TOKEN_CACHE.get(scope)
if cached and cached["exp"] > now + 300:
return cached["token"]
url = f"https://login.microsoftonline.com/{TENANT_ID}/oauth2/v2.0/token"
data = {
"grant_type": "client_credentials",
@@ -14,9 +24,12 @@ def get_access_token(scope: str = "https://graph.microsoft.com/.default"):
try:
res = requests.post(url, data=data, timeout=15)
res.raise_for_status()
token = res.json().get("access_token")
payload = res.json()
token = payload.get("access_token")
if not token:
raise RuntimeError("Token endpoint returned no access_token")
expires_in = payload.get("expires_in", 3600)
_TOKEN_CACHE[scope] = {"token": token, "exp": now + expires_in}
return token
except requests.RequestException as exc:
raise RuntimeError(f"Failed to obtain access token: {exc}") from exc

View File

@@ -1,6 +1,5 @@
from typing import Dict, List, Optional
import requests
from utils.http import get_with_retry
def _name_from_payload(payload: dict, kind: str) -> str:
@@ -26,18 +25,18 @@ def _name_from_payload(payload: dict, kind: str) -> str:
return payload.get("displayName") or payload.get("id") or "Unknown"
def _request_json(url: str, token: str) -> Optional[dict]:
def _request_json(url: str, token: str) -> dict | None:
try:
res = requests.get(url, headers={"Authorization": f"Bearer {token}"}, timeout=10)
res = get_with_retry(url, headers={"Authorization": f"Bearer {token}"}, timeout=10)
if res.status_code == 404:
return None
res.raise_for_status()
return res.json()
except requests.RequestException:
except Exception:
return None
def resolve_directory_object(object_id: str, token: str, cache: Dict[str, dict]) -> Optional[dict]:
def resolve_directory_object(object_id: str, token: str, cache: dict[str, dict]) -> dict | None:
"""
Resolve a directory object (user, servicePrincipal, group, device) to a readable name.
Uses a simple multi-endpoint probe with caching to avoid extra Graph traffic.
@@ -69,7 +68,7 @@ def resolve_directory_object(object_id: str, token: str, cache: Dict[str, dict])
return None
def resolve_service_principal_owners(sp_id: str, token: str, cache: Dict[str, List[str]]) -> List[str]:
def resolve_service_principal_owners(sp_id: str, token: str, cache: dict[str, list[str]]) -> list[str]:
"""Return a list of owner display names for a service principal."""
if not sp_id:
return []

View File

@@ -1,17 +1,18 @@
import asyncio
import logging
from contextlib import suppress
from pathlib import Path
import structlog
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from config import CORS_ORIGINS, ENABLE_PERIODIC_FETCH, FETCH_INTERVAL_MINUTES
from database import setup_indexes
from routes.fetch import router as fetch_router, run_fetch
from routes.events import router as events_router
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from routes.config import router as config_router
from config import ENABLE_PERIODIC_FETCH, FETCH_INTERVAL_MINUTES, CORS_ORIGINS
from routes.events import router as events_router
from routes.fetch import router as fetch_router
from routes.fetch import run_fetch
def configure_logging():
@@ -90,7 +91,5 @@ async def stop_periodic_fetch():
task = getattr(app.state, "fetch_task", None)
if task:
task.cancel()
try:
with suppress(Exception):
await task
except Exception:
pass

View File

@@ -7,14 +7,12 @@ new display fields. Example:
python maintenance.py renormalize --limit 500
"""
import argparse
from typing import List, Set
from pymongo import UpdateOne
from database import events_collection
from graph.auth import get_access_token
from graph.audit_logs import _enrich_events
from models.event_model import normalize_event, _make_dedupe_key
from graph.auth import get_access_token
from models.event_model import _make_dedupe_key, normalize_event
from pymongo import UpdateOne
def renormalize(limit: int = None, batch_size: int = 200) -> int:
@@ -29,7 +27,7 @@ def renormalize(limit: int = None, batch_size: int = 200) -> int:
cursor = cursor.limit(int(limit))
updated = 0
batch: List[UpdateOne] = []
batch: list[UpdateOne] = []
for doc in cursor:
raw = doc.get("raw") or {}
@@ -59,7 +57,7 @@ def dedupe(limit: int = None, batch_size: int = 500) -> int:
if limit:
cursor = cursor.limit(int(limit))
seen: Set[str] = set()
seen: set[str] = set()
to_delete = []
processed = 0

View File

@@ -1,11 +1,10 @@
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict
from typing import Any
import yaml
DEFAULT_MAPPING: Dict[str, Any] = {
DEFAULT_MAPPING: dict[str, Any] = {
"category_labels": {
"ApplicationManagement": "Application",
"UserManagement": "User",
@@ -38,7 +37,7 @@ DEFAULT_MAPPING: Dict[str, Any] = {
@lru_cache(maxsize=1)
def get_mapping() -> Dict[str, Any]:
def get_mapping() -> dict[str, Any]:
"""
Load mapping from mappings.yml if present; otherwise fall back to defaults.
Users can edit mappings.yml to change labels and summary templates.

View File

41
backend/models/api.py Normal file
View File

@@ -0,0 +1,41 @@
from pydantic import BaseModel, ConfigDict
class EventItem(BaseModel):
    """A single normalized audit event as served by the events API.

    Every field is optional because documents stored before a field was
    introduced simply omit it; ``extra="allow"`` lets any additional stored
    keys pass through the model unchanged.
    """

    id: str | None = None
    timestamp: str | None = None
    service: str | None = None
    operation: str | None = None
    result: str | None = None
    actor_display: str | None = None
    target_displays: list[str] | None = None
    display_summary: str | None = None
    display_category: str | None = None
    dedupe_key: str | None = None
    actor: dict | None = None
    targets: list[dict] | None = None
    raw: dict | None = None
    raw_text: str | None = None

    model_config = ConfigDict(extra="allow")
class PaginatedEventResponse(BaseModel):
    """Envelope for ``GET /api/events``: one page of events plus paging info."""

    # NOTE(review): items is list[dict] rather than list[EventItem], presumably
    # so response validation never rejects legacy documents — confirm intent.
    items: list[dict]
    total: int
    page: int
    page_size: int
class FilterOptionsResponse(BaseModel):
    """Distinct field values offered to the UI for building event filters."""

    services: list[str]
    operations: list[str]
    results: list[str]
    actors: list[str]
    actor_upns: list[str]
    devices: list[str]
class FetchAuditLogsResponse(BaseModel):
    """Result of a fetch run: count of stored events and per-source errors."""

    stored_events: int
    errors: list[str]

View File

@@ -2,7 +2,6 @@ import json
from mapping_loader import get_mapping
CATEGORY_LABELS = {
"ApplicationManagement": "Application",
"UserManagement": "User",

View File

@@ -0,0 +1,4 @@
pytest
mongomock
httpx
ruff

View File

@@ -7,3 +7,4 @@ PyYAML
python-jose[cryptography]
pydantic-settings
structlog
tenacity

View File

@@ -1,10 +1,10 @@
from fastapi import APIRouter
from config import (
AUTH_ENABLED,
AUTH_TENANT_ID,
AUTH_CLIENT_ID,
AUTH_ENABLED,
AUTH_SCOPE,
AUTH_TENANT_ID,
)
from fastapi import APIRouter
router = APIRouter()

View File

@@ -1,22 +1,24 @@
import re
from fastapi import APIRouter, HTTPException, Depends
from database import events_collection
from auth import require_auth
from database import events_collection
from fastapi import APIRouter, Depends, HTTPException, Query
from models.api import FilterOptionsResponse, PaginatedEventResponse
router = APIRouter(dependencies=[Depends(require_auth)])
@router.get("/events")
@router.get("/events", response_model=PaginatedEventResponse)
def list_events(
service: str = None,
actor: str = None,
operation: str = None,
result: str = None,
start: str = None,
end: str = None,
search: str = None,
page: int = 1,
page_size: int = 50,
service: str | None = None,
actor: str | None = None,
operation: str | None = None,
result: str | None = None,
start: str | None = None,
end: str | None = None,
search: str | None = None,
page: int = Query(default=1, ge=1),
page_size: int = Query(default=50, ge=1, le=500),
):
filters = []
@@ -82,8 +84,8 @@ def list_events(
}
@router.get("/filter-options")
def filter_options(limit: int = 200):
@router.get("/filter-options", response_model=FilterOptionsResponse)
def filter_options(limit: int = Query(default=200, ge=1, le=1000)):
"""
Provide distinct values for UI filters (best-effort, capped).
"""

View File

@@ -1,12 +1,12 @@
from fastapi import APIRouter, HTTPException, Depends
from pymongo import UpdateOne
from database import events_collection
from graph.audit_logs import fetch_audit_logs
from sources.unified_audit import fetch_unified_audit
from sources.intune_audit import fetch_intune_audit
from models.event_model import normalize_event
from auth import require_auth
from database import events_collection
from fastapi import APIRouter, Depends, HTTPException, Query
from graph.audit_logs import fetch_audit_logs
from models.api import FetchAuditLogsResponse
from models.event_model import normalize_event
from pymongo import UpdateOne
from sources.intune_audit import fetch_intune_audit
from sources.unified_audit import fetch_unified_audit
router = APIRouter(dependencies=[Depends(require_auth)])
@@ -40,8 +40,8 @@ def run_fetch(hours: int = 168):
return {"stored_events": len(normalized), "errors": errors}
@router.get("/fetch-audit-logs")
def fetch_logs(hours: int = 168):
@router.get("/fetch-audit-logs", response_model=FetchAuditLogsResponse)
def fetch_logs(hours: int = Query(default=168, ge=1, le=720)):
try:
return run_fetch(hours=hours)
except Exception as exc:

View File

@@ -1,11 +1,10 @@
import requests
from datetime import datetime, timedelta
from typing import List
from graph.auth import get_access_token
from utils.http import get_with_retry
def fetch_intune_audit(hours: int = 24, max_pages: int = 50) -> List[dict]:
def fetch_intune_audit(hours: int = 24, max_pages: int = 50) -> list[dict]:
"""
Fetch Intune audit events via Microsoft Graph.
Requires Intune audit permissions (e.g., DeviceManagementConfiguration.Read.All).
@@ -24,13 +23,13 @@ def fetch_intune_audit(hours: int = 24, max_pages: int = 50) -> List[dict]:
if pages >= max_pages:
raise RuntimeError(f"Aborting Intune pagination after {max_pages} pages.")
try:
res = requests.get(url, headers=headers, timeout=20)
res = get_with_retry(url, headers=headers, timeout=20)
res.raise_for_status()
body = res.json()
except requests.RequestException as exc:
except RuntimeError:
raise
except Exception as exc:
raise RuntimeError(f"Failed to fetch Intune audit logs: {exc}") from exc
except ValueError as exc:
raise RuntimeError(f"Invalid Intune response JSON: {exc}") from exc
events.extend(body.get("value", []))
url = body.get("@odata.nextLink")

View File

@@ -1,9 +1,8 @@
import requests
from contextlib import suppress
from datetime import datetime, timedelta
from typing import List
from graph.auth import get_access_token
from utils.http import get_with_retry, post_with_retry
AUDIT_CONTENT_TYPES = {
"Audit.Exchange": "Exchange admin audit",
@@ -23,40 +22,41 @@ def _ensure_subscription(content_type: str, token: str, tenant_id: str):
url = f"https://manage.office.com/api/v1.0/{tenant_id}/activity/feed/subscriptions/start"
params = {"contentType": content_type}
headers = {"Authorization": f"Bearer {token}"}
try:
requests.post(url, params=params, headers=headers, timeout=10)
except requests.RequestException:
pass # best-effort
with suppress(Exception):
post_with_retry(url, params=params, headers=headers, timeout=10)
def _list_content(content_type: str, token: str, tenant_id: str, hours: int) -> List[dict]:
def _list_content(content_type: str, token: str, tenant_id: str, hours: int) -> list[dict]:
start, end = _time_window(hours)
url = f"https://manage.office.com/api/v1.0/{tenant_id}/activity/feed/subscriptions/content"
params = {"contentType": content_type, "startTime": start, "endTime": end}
headers = {"Authorization": f"Bearer {token}"}
try:
res = requests.get(url, params=params, headers=headers, timeout=20)
res = get_with_retry(url, params=params, headers=headers, timeout=20)
if res.status_code in (400, 401, 403, 404):
# Likely not enabled or insufficient perms; surface the text to the caller.
raise RuntimeError(f"{content_type} content listing failed ({res.status_code}): {res.text}")
return []
res.raise_for_status()
return res.json() or []
except requests.RequestException as exc:
except RuntimeError:
raise
except Exception as exc:
raise RuntimeError(f"Failed to list {content_type} content: {exc}") from exc
def _download_content(content_uri: str, token: str) -> List[dict]:
def _download_content(content_uri: str, token: str) -> list[dict]:
headers = {"Authorization": f"Bearer {token}"}
try:
res = requests.get(content_uri, headers=headers, timeout=30)
res = get_with_retry(content_uri, headers=headers, timeout=30)
res.raise_for_status()
return res.json() or []
except requests.RequestException as exc:
except RuntimeError:
raise
except Exception as exc:
raise RuntimeError(f"Failed to download audit content: {exc}") from exc
def fetch_unified_audit(hours: int = 24, max_files: int = 50) -> List[dict]:
def fetch_unified_audit(hours: int = 24, max_files: int = 50) -> list[dict]:
"""
Fetch unified audit logs (Exchange, SharePoint, Teams policy changes via Audit.General)
using the Office 365 Management Activity API.
@@ -67,7 +67,7 @@ def fetch_unified_audit(hours: int = 24, max_files: int = 50) -> List[dict]:
events = []
for content_type in AUDIT_CONTENT_TYPES.keys():
for content_type in AUDIT_CONTENT_TYPES:
_ensure_subscription(content_type, token, TENANT_ID)
contents = _list_content(content_type, token, TENANT_ID, hours)
for item in contents[:max_files]:

View File

26
backend/tests/conftest.py Normal file
View File

@@ -0,0 +1,26 @@
import mongomock
import pytest
from fastapi.testclient import TestClient
@pytest.fixture(scope="function")
def mock_events_collection():
    """Return a fresh in-memory mongomock ``micro_soc.events`` collection.

    Function-scoped so every test starts with an empty collection.
    """
    client = mongomock.MongoClient()
    db = client["micro_soc"]
    coll = db["events"]
    return coll
@pytest.fixture(scope="function")
def client(mock_events_collection, monkeypatch):
    """FastAPI TestClient backed by mongomock, with authentication disabled."""
    # Patch the collection in all modules that import it before the app is imported
    monkeypatch.setattr("database.events_collection", mock_events_collection)
    monkeypatch.setattr("routes.fetch.events_collection", mock_events_collection)
    monkeypatch.setattr("routes.events.events_collection", mock_events_collection)
    monkeypatch.setattr("auth.AUTH_ENABLED", False)
    # Patch health check db.command so it doesn't need a real MongoDB server
    monkeypatch.setattr("database.db.command", lambda cmd: {"ok": 1} if cmd == "ping" else {})

    # Imported late so the monkeypatches above are in place first.
    from main import app

    return TestClient(app)

98
backend/tests/test_api.py Normal file
View File

@@ -0,0 +1,98 @@
from datetime import UTC, datetime
def test_health(client):
    """/health reports OK and a connected (mocked) database."""
    response = client.get("/health")
    assert response.status_code == 200
    data = response.json()
    assert data["status"] == "ok"
    assert data["database"] == "connected"
def test_list_events_empty(client):
    """An empty collection yields an empty page with total == 0."""
    response = client.get("/api/events")
    assert response.status_code == 200
    data = response.json()
    assert data["items"] == []
    assert data["total"] == 0
def test_list_events_pagination(client, mock_events_collection):
    """Seed 5 events; page 1 with page_size=2 returns 2 items, total 5."""
    for i in range(5):
        mock_events_collection.insert_one(
            {
                "id": f"evt-{i}",
                "timestamp": datetime.now(UTC).isoformat(),
                "service": "Directory",
                "operation": "Add user",
                "result": "success",
                "actor_display": f"Actor {i}",
                "raw_text": "",
            }
        )
    response = client.get("/api/events?page=1&page_size=2")
    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 5
    assert len(data["items"]) == 2
    assert data["page"] == 1
    assert data["page_size"] == 2
def test_list_events_filter_by_service(client, mock_events_collection):
    """service= filter returns only the matching event."""
    mock_events_collection.insert_one(
        {
            "id": "evt-1",
            "timestamp": datetime.now(UTC).isoformat(),
            "service": "Exchange",
            "operation": "Update",
            "result": "success",
            "actor_display": "Alice",
            "raw_text": "",
        }
    )
    mock_events_collection.insert_one(
        {
            "id": "evt-2",
            "timestamp": datetime.now(UTC).isoformat(),
            "service": "Directory",
            "operation": "Add",
            "result": "success",
            "actor_display": "Bob",
            "raw_text": "",
        }
    )
    response = client.get("/api/events?service=Exchange")
    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 1
    assert data["items"][0]["service"] == "Exchange"
def test_list_events_page_size_validation(client):
    """page_size outside the 1..500 Query bounds is rejected with 422."""
    response = client.get("/api/events?page_size=0")
    assert response.status_code == 422
    response = client.get("/api/events?page_size=501")
    assert response.status_code == 422
def test_filter_options(client, mock_events_collection):
    """/api/filter-options surfaces distinct values from stored events."""
    mock_events_collection.insert_one(
        {
            "id": "evt-1",
            "timestamp": datetime.now(UTC).isoformat(),
            "service": "Intune",
            "operation": "Assign",
            "result": "failure",
            "actor_display": "Charlie",
            "actor_upn": "charlie@example.com",
            "raw_text": "",
        }
    )
    response = client.get("/api/filter-options")
    assert response.status_code == 200
    data = response.json()
    assert "Intune" in data["services"]
    assert "Assign" in data["operations"]
    assert "failure" in data["results"]
    assert "Charlie" in data["actors"]
    assert "charlie@example.com" in data["actor_upns"]
def test_fetch_audit_logs_validation(client):
    """hours outside the 1..720 Query bounds is rejected with 422."""
    response = client.get("/api/fetch-audit-logs?hours=0")
    assert response.status_code == 422
    response = client.get("/api/fetch-audit-logs?hours=721")
    assert response.status_code == 422

View File

@@ -0,0 +1,61 @@
from unittest.mock import patch
import auth
import pytest
from auth import _allowed, require_auth
from fastapi import HTTPException
@pytest.fixture(autouse=True)
def reset_cache():
    """Clear the module-level JWKS cache so tests never share key state."""
    auth.JWKS_CACHE["keys"] = []
    auth.JWKS_CACHE["exp"] = 0
@pytest.fixture
def mock_jwks():
    """Generate an RSA key pair plus a JWK dict for token-signing tests."""
    from Crypto.PublicKey import RSA
    from jose.jwk import RSAKey

    key = RSA.generate(2048)
    rsa_key = RSAKey(key)
    # NOTE(review): reads the private _key attribute to get the modulus and
    # exponent — fragile across python-jose versions; confirm still required.
    jwk_dict = {
        "kty": "RSA",
        "kid": "test-kid",
        "n": rsa_key._key.n,
        "e": rsa_key._key.e,
    }
    return rsa_key, jwk_dict
def test_allowed_no_restrictions():
    """With no configured roles or groups, every token is allowed."""
    assert _allowed({}, set(), set()) is True
def test_allowed_by_role():
    """A token passes only when its roles intersect the allowed set."""
    assert _allowed({"roles": ["Admin"]}, {"Admin"}, set()) is True
    assert _allowed({"roles": ["User"]}, {"Admin"}, set()) is False
def test_allowed_by_group():
    """A token passes only when its groups intersect the allowed set."""
    assert _allowed({"groups": ["SecOps"]}, set(), {"SecOps"}) is True
    assert _allowed({"groups": ["Users"]}, set(), {"SecOps"}) is False
@patch("auth.AUTH_ENABLED", False)
def test_require_auth_disabled():
    """With auth disabled, require_auth returns anonymous claims."""
    claims = require_auth(None)
    assert claims["sub"] == "anonymous"
@patch("auth.AUTH_ENABLED", True)
def test_require_auth_missing_header():
    """A missing Authorization header is rejected with 401."""
    with pytest.raises(HTTPException) as exc_info:
        require_auth(None)
    assert exc_info.value.status_code == 401
@patch("auth.AUTH_ENABLED", True)
def test_require_auth_invalid_bearer():
    """A non-Bearer Authorization scheme is rejected with 401."""
    with pytest.raises(HTTPException) as exc_info:
        require_auth("Basic abc")
    assert exc_info.value.status_code == 401

View File

@@ -0,0 +1,63 @@
from models.event_model import _make_dedupe_key, normalize_event
def test_make_dedupe_key_prefers_id_and_category():
    """When an id is present the key is 'id|category'."""
    e = {"id": "evt-123", "category": "Directory"}
    assert _make_dedupe_key(e) == "evt-123|Directory"
def test_make_dedupe_key_fallback_without_id():
    """Without an id, the key falls back to timestamp|category|activity."""
    e = {
        "activityDateTime": "2024-01-01T00:00:00Z",
        "category": "Exchange",
        "activityDisplayName": "Update setting",
    }
    key = _make_dedupe_key(e)
    assert "2024-01-01T00:00:00Z|Exchange|Update setting" in key
def test_normalize_event_basic():
    """A full Graph directory-audit record maps to every normalized field."""
    e = {
        "id": "abc",
        "activityDateTime": "2024-01-15T10:30:00Z",
        "category": "UserManagement",
        "activityDisplayName": "Add user",
        "result": "success",
        "initiatedBy": {
            "user": {
                "id": "u1",
                "displayName": "Alice",
                "userPrincipalName": "alice@example.com",
            }
        },
        "targetResources": [{"id": "t1", "displayName": "Bob", "type": "User"}],
    }
    out = normalize_event(e)
    assert out["id"] == "abc"
    assert out["timestamp"] == "2024-01-15T10:30:00Z"
    assert out["service"] == "UserManagement"
    assert out["operation"] == "Add user"
    assert out["result"] == "success"
    assert out["actor_display"] == "Alice (alice@example.com)"
    assert out["target_displays"] == ["Bob"]
    assert out["dedupe_key"] == "abc|UserManagement"
    assert "raw_text" in out
def test_normalize_event_with_resolved_actor():
    """Pre-resolved service-principal data drives actor display and category."""
    e = {
        "id": "def",
        "activityDateTime": "2024-01-15T11:00:00Z",
        "category": "ApplicationManagement",
        "activityDisplayName": "Add app",
        "result": "success",
        "initiatedBy": {"servicePrincipal": {"id": "sp1"}},
        "targetResources": [],
        "_resolvedActor": {"id": "sp1", "type": "servicePrincipal", "name": "MyApp"},
        "_resolvedActorOwners": ["Owner1"],
    }
    out = normalize_event(e)
    assert out["actor_display"] == "MyApp (owners: Owner1)"
    assert out["display_category"] == "Application"

View File

29
backend/utils/http.py Normal file
View File

@@ -0,0 +1,29 @@
import requests
import structlog
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
logger = structlog.get_logger("aoc.http")
# Shared tenacity policy: up to 3 attempts with exponential backoff (2-10 s),
# retrying only on requests.RequestException (connection errors, timeouts).
# NOTE(review): an HTTP 429/5xx response does not raise RequestException here,
# so such responses are NOT retried — raise_for_status happens at call sites.
# Confirm whether status-based retry is intended.
RETRY_CONFIG = {
    "stop": stop_after_attempt(3),
    "wait": wait_exponential(multiplier=1, min=2, max=10),
    "retry": retry_if_exception_type(requests.RequestException),
    "before_sleep": lambda retry_state: logger.warning(
        "Retrying HTTP request",
        attempt=retry_state.attempt_number,
        exception=str(retry_state.outcome.exception()) if retry_state.outcome else None,
    ),
}


@retry(**RETRY_CONFIG)
def get_with_retry(url: str, headers: dict | None = None, params: dict | None = None, timeout: float = 20) -> requests.Response:
    """GET *url*, retrying transient transport failures per RETRY_CONFIG."""
    res = requests.get(url, headers=headers, params=params, timeout=timeout)
    return res


@retry(**RETRY_CONFIG)
def post_with_retry(url: str, headers: dict | None = None, data: dict | None = None, params: dict | None = None, timeout: float = 15) -> requests.Response:
    """POST to *url*, retrying transient transport failures per RETRY_CONFIG."""
    res = requests.post(url, headers=headers, data=data, params=params, timeout=timeout)
    return res

14
pyproject.toml Normal file
View File

@@ -0,0 +1,14 @@
# Ruff linter/formatter and pytest configuration for the repository.
[tool.ruff]
# Keep in sync with the Python version used in CI.
target-version = "py311"
line-length = 120
[tool.ruff.lint]
# E/W pycodestyle, F pyflakes, I isort, N naming, UP pyupgrade,
# B bugbear, C4 comprehensions, SIM simplify.
select = ["E", "F", "I", "N", "W", "UP", "B", "C4", "SIM"]
# Long lines are tolerated; line-length above still guides the formatter.
ignore = ["E501"]
[tool.ruff.lint.pydocstyle]
# NOTE(review): this section has no effect unless "D" (pydocstyle) rules are
# added to select above — confirm intent.
convention = "google"
[tool.pytest.ini_options]
testpaths = ["backend/tests"]
# Lets tests import backend modules (auth, database, ...) without packaging.
pythonpath = ["backend"]