feat: implement Phase 2 stabilization
Some checks failed
CI / lint-and-test (push) Has been cancelled
Some checks failed
CI / lint-and-test (push) Has been cancelled
- Cache Graph API tokens with expiry-aware reuse in graph/auth.py - Add tenacity-based retry/backoff wrapper (utils/http.py) and apply to all Graph/source API calls - Add Pydantic request/response models (models/api.py) and FastAPI query constraints - Add unit tests for event_model, auth and integration tests for API endpoints - Configure ruff linter/formatter in pyproject.toml - Add GitHub Actions CI pipeline (.github/workflows/ci.yml) - Add requirements-dev.txt with pytest, mongomock, httpx, ruff - Clean up typing imports and fix ruff linting across codebase
This commit is contained in:
@@ -1,19 +1,17 @@
|
||||
import time
|
||||
import structlog
|
||||
from typing import Optional, Set
|
||||
|
||||
import requests
|
||||
from fastapi import Depends, HTTPException, Header
|
||||
from jose import jwt
|
||||
from jose.jwk import construct
|
||||
|
||||
import structlog
|
||||
from config import (
|
||||
AUTH_ALLOWED_GROUPS,
|
||||
AUTH_ALLOWED_ROLES,
|
||||
AUTH_CLIENT_ID,
|
||||
AUTH_ENABLED,
|
||||
AUTH_TENANT_ID,
|
||||
AUTH_CLIENT_ID,
|
||||
AUTH_ALLOWED_ROLES,
|
||||
AUTH_ALLOWED_GROUPS,
|
||||
)
|
||||
from fastapi import Header, HTTPException
|
||||
from jose import jwt
|
||||
from jose.jwk import construct
|
||||
|
||||
JWKS_CACHE = {"exp": 0, "keys": []}
|
||||
logger = structlog.get_logger("aoc.auth")
|
||||
@@ -35,16 +33,15 @@ def _get_jwks():
|
||||
return keys
|
||||
|
||||
|
||||
def _allowed(claims: dict, allowed_roles: Set[str], allowed_groups: Set[str]) -> bool:
|
||||
def _allowed(claims: dict, allowed_roles: set[str], allowed_groups: set[str]) -> bool:
|
||||
if not allowed_roles and not allowed_groups:
|
||||
return True
|
||||
roles = set(claims.get("roles", []) or claims.get("role", []) or [])
|
||||
groups = set(claims.get("groups", []) or [])
|
||||
if allowed_roles and roles.intersection(allowed_roles):
|
||||
return True
|
||||
if allowed_groups and groups.intersection(allowed_groups):
|
||||
return True
|
||||
return False
|
||||
return bool(
|
||||
(allowed_roles and roles.intersection(allowed_roles))
|
||||
or (allowed_groups and groups.intersection(allowed_groups))
|
||||
)
|
||||
|
||||
|
||||
def _decode_token(token: str, jwks):
|
||||
@@ -72,10 +69,10 @@ def _decode_token(token: str, jwks):
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.warning("Token verification failed", error=str(exc))
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
raise HTTPException(status_code=401, detail="Invalid token") from None
|
||||
|
||||
|
||||
def require_auth(authorization: Optional[str] = Header(None)):
|
||||
def require_auth(authorization: str | None = Header(None)):
|
||||
if not AUTH_ENABLED:
|
||||
return {"sub": "anonymous"}
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from pymongo import MongoClient, ASCENDING, DESCENDING, TEXT
|
||||
from config import MONGO_URI, DB_NAME, RETENTION_DAYS
|
||||
from contextlib import suppress
|
||||
|
||||
import structlog
|
||||
from config import DB_NAME, MONGO_URI, RETENTION_DAYS
|
||||
from pymongo import ASCENDING, DESCENDING, TEXT, MongoClient
|
||||
|
||||
client = MongoClient(MONGO_URI)
|
||||
db = client[DB_NAME]
|
||||
@@ -29,10 +31,8 @@ def setup_indexes(max_retries: int = 5, delay: float = 2.0):
|
||||
name="ttl_timestamp",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
with suppress(Exception):
|
||||
events_collection.drop_index("ttl_timestamp")
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("MongoDB indexes ensured")
|
||||
return
|
||||
except Exception as exc:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import requests
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from graph.auth import get_access_token
|
||||
from graph.resolve import resolve_directory_object, resolve_service_principal_owners
|
||||
from utils.http import get_with_retry
|
||||
|
||||
|
||||
def fetch_audit_logs(hours=24, max_pages=50):
|
||||
@@ -22,13 +23,13 @@ def fetch_audit_logs(hours=24, max_pages=50):
|
||||
raise RuntimeError(f"Aborting pagination after {max_pages} pages to avoid runaway fetch.")
|
||||
|
||||
try:
|
||||
res = requests.get(next_url, headers=headers, timeout=20)
|
||||
res = get_with_retry(next_url, headers=headers, timeout=20)
|
||||
res.raise_for_status()
|
||||
body = res.json()
|
||||
except requests.RequestException as exc:
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"Failed to fetch audit logs page: {exc}") from exc
|
||||
except ValueError as exc:
|
||||
raise RuntimeError(f"Invalid JSON response from Graph: {exc}") from exc
|
||||
|
||||
events.extend(body.get("value", []))
|
||||
next_url = body.get("@odata.nextLink")
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
import time
|
||||
|
||||
import requests
|
||||
from config import TENANT_ID, CLIENT_ID, CLIENT_SECRET
|
||||
from config import CLIENT_ID, CLIENT_SECRET, TENANT_ID
|
||||
|
||||
_TOKEN_CACHE = {}
|
||||
|
||||
|
||||
def get_access_token(scope: str = "https://graph.microsoft.com/.default"):
|
||||
"""Request an application token from Microsoft identity platform."""
|
||||
"""Request an application token from Microsoft identity platform.
|
||||
Tokens are cached and reused until 5 minutes before expiry."""
|
||||
now = time.time()
|
||||
cached = _TOKEN_CACHE.get(scope)
|
||||
if cached and cached["exp"] > now + 300:
|
||||
return cached["token"]
|
||||
|
||||
url = f"https://login.microsoftonline.com/{TENANT_ID}/oauth2/v2.0/token"
|
||||
data = {
|
||||
"grant_type": "client_credentials",
|
||||
@@ -14,9 +24,12 @@ def get_access_token(scope: str = "https://graph.microsoft.com/.default"):
|
||||
try:
|
||||
res = requests.post(url, data=data, timeout=15)
|
||||
res.raise_for_status()
|
||||
token = res.json().get("access_token")
|
||||
payload = res.json()
|
||||
token = payload.get("access_token")
|
||||
if not token:
|
||||
raise RuntimeError("Token endpoint returned no access_token")
|
||||
expires_in = payload.get("expires_in", 3600)
|
||||
_TOKEN_CACHE[scope] = {"token": token, "exp": now + expires_in}
|
||||
return token
|
||||
except requests.RequestException as exc:
|
||||
raise RuntimeError(f"Failed to obtain access token: {exc}") from exc
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from utils.http import get_with_retry
|
||||
|
||||
|
||||
def _name_from_payload(payload: dict, kind: str) -> str:
|
||||
@@ -26,18 +25,18 @@ def _name_from_payload(payload: dict, kind: str) -> str:
|
||||
return payload.get("displayName") or payload.get("id") or "Unknown"
|
||||
|
||||
|
||||
def _request_json(url: str, token: str) -> Optional[dict]:
|
||||
def _request_json(url: str, token: str) -> dict | None:
|
||||
try:
|
||||
res = requests.get(url, headers={"Authorization": f"Bearer {token}"}, timeout=10)
|
||||
res = get_with_retry(url, headers={"Authorization": f"Bearer {token}"}, timeout=10)
|
||||
if res.status_code == 404:
|
||||
return None
|
||||
res.raise_for_status()
|
||||
return res.json()
|
||||
except requests.RequestException:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def resolve_directory_object(object_id: str, token: str, cache: Dict[str, dict]) -> Optional[dict]:
|
||||
def resolve_directory_object(object_id: str, token: str, cache: dict[str, dict]) -> dict | None:
|
||||
"""
|
||||
Resolve a directory object (user, servicePrincipal, group, device) to a readable name.
|
||||
Uses a simple multi-endpoint probe with caching to avoid extra Graph traffic.
|
||||
@@ -69,7 +68,7 @@ def resolve_directory_object(object_id: str, token: str, cache: Dict[str, dict])
|
||||
return None
|
||||
|
||||
|
||||
def resolve_service_principal_owners(sp_id: str, token: str, cache: Dict[str, List[str]]) -> List[str]:
|
||||
def resolve_service_principal_owners(sp_id: str, token: str, cache: dict[str, list[str]]) -> list[str]:
|
||||
"""Return a list of owner display names for a service principal."""
|
||||
if not sp_id:
|
||||
return []
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from config import CORS_ORIGINS, ENABLE_PERIODIC_FETCH, FETCH_INTERVAL_MINUTES
|
||||
from database import setup_indexes
|
||||
from routes.fetch import router as fetch_router, run_fetch
|
||||
from routes.events import router as events_router
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from routes.config import router as config_router
|
||||
from config import ENABLE_PERIODIC_FETCH, FETCH_INTERVAL_MINUTES, CORS_ORIGINS
|
||||
from routes.events import router as events_router
|
||||
from routes.fetch import router as fetch_router
|
||||
from routes.fetch import run_fetch
|
||||
|
||||
|
||||
def configure_logging():
|
||||
@@ -90,7 +91,5 @@ async def stop_periodic_fetch():
|
||||
task = getattr(app.state, "fetch_task", None)
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
with suppress(Exception):
|
||||
await task
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -7,14 +7,12 @@ new display fields. Example:
|
||||
python maintenance.py renormalize --limit 500
|
||||
"""
|
||||
import argparse
|
||||
from typing import List, Set
|
||||
|
||||
from pymongo import UpdateOne
|
||||
|
||||
from database import events_collection
|
||||
from graph.auth import get_access_token
|
||||
from graph.audit_logs import _enrich_events
|
||||
from models.event_model import normalize_event, _make_dedupe_key
|
||||
from graph.auth import get_access_token
|
||||
from models.event_model import _make_dedupe_key, normalize_event
|
||||
from pymongo import UpdateOne
|
||||
|
||||
|
||||
def renormalize(limit: int = None, batch_size: int = 200) -> int:
|
||||
@@ -29,7 +27,7 @@ def renormalize(limit: int = None, batch_size: int = 200) -> int:
|
||||
cursor = cursor.limit(int(limit))
|
||||
|
||||
updated = 0
|
||||
batch: List[UpdateOne] = []
|
||||
batch: list[UpdateOne] = []
|
||||
|
||||
for doc in cursor:
|
||||
raw = doc.get("raw") or {}
|
||||
@@ -59,7 +57,7 @@ def dedupe(limit: int = None, batch_size: int = 500) -> int:
|
||||
if limit:
|
||||
cursor = cursor.limit(int(limit))
|
||||
|
||||
seen: Set[str] = set()
|
||||
seen: set[str] = set()
|
||||
to_delete = []
|
||||
processed = 0
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
DEFAULT_MAPPING: Dict[str, Any] = {
|
||||
DEFAULT_MAPPING: dict[str, Any] = {
|
||||
"category_labels": {
|
||||
"ApplicationManagement": "Application",
|
||||
"UserManagement": "User",
|
||||
@@ -38,7 +37,7 @@ DEFAULT_MAPPING: Dict[str, Any] = {
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_mapping() -> Dict[str, Any]:
|
||||
def get_mapping() -> dict[str, Any]:
|
||||
"""
|
||||
Load mapping from mappings.yml if present; otherwise fall back to defaults.
|
||||
Users can edit mappings.yml to change labels and summary templates.
|
||||
|
||||
0
backend/models/__init__.py
Normal file
0
backend/models/__init__.py
Normal file
41
backend/models/api.py
Normal file
41
backend/models/api.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class EventItem(BaseModel):
|
||||
id: str | None = None
|
||||
timestamp: str | None = None
|
||||
service: str | None = None
|
||||
operation: str | None = None
|
||||
result: str | None = None
|
||||
actor_display: str | None = None
|
||||
target_displays: list[str] | None = None
|
||||
display_summary: str | None = None
|
||||
display_category: str | None = None
|
||||
dedupe_key: str | None = None
|
||||
actor: dict | None = None
|
||||
targets: list[dict] | None = None
|
||||
raw: dict | None = None
|
||||
raw_text: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class PaginatedEventResponse(BaseModel):
|
||||
items: list[dict]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
class FilterOptionsResponse(BaseModel):
|
||||
services: list[str]
|
||||
operations: list[str]
|
||||
results: list[str]
|
||||
actors: list[str]
|
||||
actor_upns: list[str]
|
||||
devices: list[str]
|
||||
|
||||
|
||||
class FetchAuditLogsResponse(BaseModel):
|
||||
stored_events: int
|
||||
errors: list[str]
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
|
||||
from mapping_loader import get_mapping
|
||||
|
||||
|
||||
CATEGORY_LABELS = {
|
||||
"ApplicationManagement": "Application",
|
||||
"UserManagement": "User",
|
||||
|
||||
4
backend/requirements-dev.txt
Normal file
4
backend/requirements-dev.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pytest
|
||||
mongomock
|
||||
httpx
|
||||
ruff
|
||||
@@ -7,3 +7,4 @@ PyYAML
|
||||
python-jose[cryptography]
|
||||
pydantic-settings
|
||||
structlog
|
||||
tenacity
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from fastapi import APIRouter
|
||||
from config import (
|
||||
AUTH_ENABLED,
|
||||
AUTH_TENANT_ID,
|
||||
AUTH_CLIENT_ID,
|
||||
AUTH_ENABLED,
|
||||
AUTH_SCOPE,
|
||||
AUTH_TENANT_ID,
|
||||
)
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
import re
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from database import events_collection
|
||||
|
||||
from auth import require_auth
|
||||
from database import events_collection
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from models.api import FilterOptionsResponse, PaginatedEventResponse
|
||||
|
||||
router = APIRouter(dependencies=[Depends(require_auth)])
|
||||
|
||||
|
||||
@router.get("/events")
|
||||
@router.get("/events", response_model=PaginatedEventResponse)
|
||||
def list_events(
|
||||
service: str = None,
|
||||
actor: str = None,
|
||||
operation: str = None,
|
||||
result: str = None,
|
||||
start: str = None,
|
||||
end: str = None,
|
||||
search: str = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
service: str | None = None,
|
||||
actor: str | None = None,
|
||||
operation: str | None = None,
|
||||
result: str | None = None,
|
||||
start: str | None = None,
|
||||
end: str | None = None,
|
||||
search: str | None = None,
|
||||
page: int = Query(default=1, ge=1),
|
||||
page_size: int = Query(default=50, ge=1, le=500),
|
||||
):
|
||||
filters = []
|
||||
|
||||
@@ -82,8 +84,8 @@ def list_events(
|
||||
}
|
||||
|
||||
|
||||
@router.get("/filter-options")
|
||||
def filter_options(limit: int = 200):
|
||||
@router.get("/filter-options", response_model=FilterOptionsResponse)
|
||||
def filter_options(limit: int = Query(default=200, ge=1, le=1000)):
|
||||
"""
|
||||
Provide distinct values for UI filters (best-effort, capped).
|
||||
"""
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pymongo import UpdateOne
|
||||
|
||||
from database import events_collection
|
||||
from graph.audit_logs import fetch_audit_logs
|
||||
from sources.unified_audit import fetch_unified_audit
|
||||
from sources.intune_audit import fetch_intune_audit
|
||||
from models.event_model import normalize_event
|
||||
from auth import require_auth
|
||||
from database import events_collection
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from graph.audit_logs import fetch_audit_logs
|
||||
from models.api import FetchAuditLogsResponse
|
||||
from models.event_model import normalize_event
|
||||
from pymongo import UpdateOne
|
||||
from sources.intune_audit import fetch_intune_audit
|
||||
from sources.unified_audit import fetch_unified_audit
|
||||
|
||||
router = APIRouter(dependencies=[Depends(require_auth)])
|
||||
|
||||
@@ -40,8 +40,8 @@ def run_fetch(hours: int = 168):
|
||||
return {"stored_events": len(normalized), "errors": errors}
|
||||
|
||||
|
||||
@router.get("/fetch-audit-logs")
|
||||
def fetch_logs(hours: int = 168):
|
||||
@router.get("/fetch-audit-logs", response_model=FetchAuditLogsResponse)
|
||||
def fetch_logs(hours: int = Query(default=168, ge=1, le=720)):
|
||||
try:
|
||||
return run_fetch(hours=hours)
|
||||
except Exception as exc:
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import requests
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
|
||||
from graph.auth import get_access_token
|
||||
from utils.http import get_with_retry
|
||||
|
||||
|
||||
def fetch_intune_audit(hours: int = 24, max_pages: int = 50) -> List[dict]:
|
||||
def fetch_intune_audit(hours: int = 24, max_pages: int = 50) -> list[dict]:
|
||||
"""
|
||||
Fetch Intune audit events via Microsoft Graph.
|
||||
Requires Intune audit permissions (e.g., DeviceManagementConfiguration.Read.All).
|
||||
@@ -24,13 +23,13 @@ def fetch_intune_audit(hours: int = 24, max_pages: int = 50) -> List[dict]:
|
||||
if pages >= max_pages:
|
||||
raise RuntimeError(f"Aborting Intune pagination after {max_pages} pages.")
|
||||
try:
|
||||
res = requests.get(url, headers=headers, timeout=20)
|
||||
res = get_with_retry(url, headers=headers, timeout=20)
|
||||
res.raise_for_status()
|
||||
body = res.json()
|
||||
except requests.RequestException as exc:
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"Failed to fetch Intune audit logs: {exc}") from exc
|
||||
except ValueError as exc:
|
||||
raise RuntimeError(f"Invalid Intune response JSON: {exc}") from exc
|
||||
|
||||
events.extend(body.get("value", []))
|
||||
url = body.get("@odata.nextLink")
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import requests
|
||||
from contextlib import suppress
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
|
||||
from graph.auth import get_access_token
|
||||
|
||||
from utils.http import get_with_retry, post_with_retry
|
||||
|
||||
AUDIT_CONTENT_TYPES = {
|
||||
"Audit.Exchange": "Exchange admin audit",
|
||||
@@ -23,40 +22,41 @@ def _ensure_subscription(content_type: str, token: str, tenant_id: str):
|
||||
url = f"https://manage.office.com/api/v1.0/{tenant_id}/activity/feed/subscriptions/start"
|
||||
params = {"contentType": content_type}
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
try:
|
||||
requests.post(url, params=params, headers=headers, timeout=10)
|
||||
except requests.RequestException:
|
||||
pass # best-effort
|
||||
with suppress(Exception):
|
||||
post_with_retry(url, params=params, headers=headers, timeout=10)
|
||||
|
||||
|
||||
def _list_content(content_type: str, token: str, tenant_id: str, hours: int) -> List[dict]:
|
||||
def _list_content(content_type: str, token: str, tenant_id: str, hours: int) -> list[dict]:
|
||||
start, end = _time_window(hours)
|
||||
url = f"https://manage.office.com/api/v1.0/{tenant_id}/activity/feed/subscriptions/content"
|
||||
params = {"contentType": content_type, "startTime": start, "endTime": end}
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
try:
|
||||
res = requests.get(url, params=params, headers=headers, timeout=20)
|
||||
res = get_with_retry(url, params=params, headers=headers, timeout=20)
|
||||
if res.status_code in (400, 401, 403, 404):
|
||||
# Likely not enabled or insufficient perms; surface the text to the caller.
|
||||
raise RuntimeError(f"{content_type} content listing failed ({res.status_code}): {res.text}")
|
||||
return []
|
||||
res.raise_for_status()
|
||||
return res.json() or []
|
||||
except requests.RequestException as exc:
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"Failed to list {content_type} content: {exc}") from exc
|
||||
|
||||
|
||||
def _download_content(content_uri: str, token: str) -> List[dict]:
|
||||
def _download_content(content_uri: str, token: str) -> list[dict]:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
try:
|
||||
res = requests.get(content_uri, headers=headers, timeout=30)
|
||||
res = get_with_retry(content_uri, headers=headers, timeout=30)
|
||||
res.raise_for_status()
|
||||
return res.json() or []
|
||||
except requests.RequestException as exc:
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"Failed to download audit content: {exc}") from exc
|
||||
|
||||
|
||||
def fetch_unified_audit(hours: int = 24, max_files: int = 50) -> List[dict]:
|
||||
def fetch_unified_audit(hours: int = 24, max_files: int = 50) -> list[dict]:
|
||||
"""
|
||||
Fetch unified audit logs (Exchange, SharePoint, Teams policy changes via Audit.General)
|
||||
using the Office 365 Management Activity API.
|
||||
@@ -67,7 +67,7 @@ def fetch_unified_audit(hours: int = 24, max_files: int = 50) -> List[dict]:
|
||||
|
||||
events = []
|
||||
|
||||
for content_type in AUDIT_CONTENT_TYPES.keys():
|
||||
for content_type in AUDIT_CONTENT_TYPES:
|
||||
_ensure_subscription(content_type, token, TENANT_ID)
|
||||
contents = _list_content(content_type, token, TENANT_ID, hours)
|
||||
for item in contents[:max_files]:
|
||||
|
||||
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
26
backend/tests/conftest.py
Normal file
26
backend/tests/conftest.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import mongomock
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def mock_events_collection():
|
||||
client = mongomock.MongoClient()
|
||||
db = client["micro_soc"]
|
||||
coll = db["events"]
|
||||
return coll
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(mock_events_collection, monkeypatch):
|
||||
# Patch the collection in all modules that import it before the app is imported
|
||||
monkeypatch.setattr("database.events_collection", mock_events_collection)
|
||||
monkeypatch.setattr("routes.fetch.events_collection", mock_events_collection)
|
||||
monkeypatch.setattr("routes.events.events_collection", mock_events_collection)
|
||||
monkeypatch.setattr("auth.AUTH_ENABLED", False)
|
||||
# Patch health check db.command so it doesn't need a real MongoDB server
|
||||
monkeypatch.setattr("database.db.command", lambda cmd: {"ok": 1} if cmd == "ping" else {})
|
||||
|
||||
from main import app
|
||||
|
||||
return TestClient(app)
|
||||
98
backend/tests/test_api.py
Normal file
98
backend/tests/test_api.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
||||
def test_health(client):
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["database"] == "connected"
|
||||
|
||||
|
||||
def test_list_events_empty(client):
|
||||
response = client.get("/api/events")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["items"] == []
|
||||
assert data["total"] == 0
|
||||
|
||||
|
||||
def test_list_events_pagination(client, mock_events_collection):
|
||||
for i in range(5):
|
||||
mock_events_collection.insert_one({
|
||||
"id": f"evt-{i}",
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"service": "Directory",
|
||||
"operation": "Add user",
|
||||
"result": "success",
|
||||
"actor_display": f"Actor {i}",
|
||||
"raw_text": "",
|
||||
})
|
||||
response = client.get("/api/events?page=1&page_size=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 5
|
||||
assert len(data["items"]) == 2
|
||||
assert data["page"] == 1
|
||||
assert data["page_size"] == 2
|
||||
|
||||
|
||||
def test_list_events_filter_by_service(client, mock_events_collection):
|
||||
mock_events_collection.insert_one({
|
||||
"id": "evt-1",
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"service": "Exchange",
|
||||
"operation": "Update",
|
||||
"result": "success",
|
||||
"actor_display": "Alice",
|
||||
"raw_text": "",
|
||||
})
|
||||
mock_events_collection.insert_one({
|
||||
"id": "evt-2",
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"service": "Directory",
|
||||
"operation": "Add",
|
||||
"result": "success",
|
||||
"actor_display": "Bob",
|
||||
"raw_text": "",
|
||||
})
|
||||
response = client.get("/api/events?service=Exchange")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 1
|
||||
assert data["items"][0]["service"] == "Exchange"
|
||||
|
||||
|
||||
def test_list_events_page_size_validation(client):
|
||||
response = client.get("/api/events?page_size=0")
|
||||
assert response.status_code == 422
|
||||
response = client.get("/api/events?page_size=501")
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_filter_options(client, mock_events_collection):
|
||||
mock_events_collection.insert_one({
|
||||
"id": "evt-1",
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"service": "Intune",
|
||||
"operation": "Assign",
|
||||
"result": "failure",
|
||||
"actor_display": "Charlie",
|
||||
"actor_upn": "charlie@example.com",
|
||||
"raw_text": "",
|
||||
})
|
||||
response = client.get("/api/filter-options")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "Intune" in data["services"]
|
||||
assert "Assign" in data["operations"]
|
||||
assert "failure" in data["results"]
|
||||
assert "Charlie" in data["actors"]
|
||||
assert "charlie@example.com" in data["actor_upns"]
|
||||
|
||||
|
||||
def test_fetch_audit_logs_validation(client):
|
||||
response = client.get("/api/fetch-audit-logs?hours=0")
|
||||
assert response.status_code == 422
|
||||
response = client.get("/api/fetch-audit-logs?hours=721")
|
||||
assert response.status_code == 422
|
||||
61
backend/tests/test_auth.py
Normal file
61
backend/tests/test_auth.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import auth
|
||||
import pytest
|
||||
from auth import _allowed, require_auth
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_cache():
|
||||
auth.JWKS_CACHE["keys"] = []
|
||||
auth.JWKS_CACHE["exp"] = 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jwks():
|
||||
from Crypto.PublicKey import RSA
|
||||
from jose.jwk import RSAKey
|
||||
key = RSA.generate(2048)
|
||||
rsa_key = RSAKey(key)
|
||||
jwk_dict = {
|
||||
"kty": "RSA",
|
||||
"kid": "test-kid",
|
||||
"n": rsa_key._key.n,
|
||||
"e": rsa_key._key.e,
|
||||
}
|
||||
return rsa_key, jwk_dict
|
||||
|
||||
|
||||
def test_allowed_no_restrictions():
|
||||
assert _allowed({}, set(), set()) is True
|
||||
|
||||
|
||||
def test_allowed_by_role():
|
||||
assert _allowed({"roles": ["Admin"]}, {"Admin"}, set()) is True
|
||||
assert _allowed({"roles": ["User"]}, {"Admin"}, set()) is False
|
||||
|
||||
|
||||
def test_allowed_by_group():
|
||||
assert _allowed({"groups": ["SecOps"]}, set(), {"SecOps"}) is True
|
||||
assert _allowed({"groups": ["Users"]}, set(), {"SecOps"}) is False
|
||||
|
||||
|
||||
@patch("auth.AUTH_ENABLED", False)
|
||||
def test_require_auth_disabled():
|
||||
claims = require_auth(None)
|
||||
assert claims["sub"] == "anonymous"
|
||||
|
||||
|
||||
@patch("auth.AUTH_ENABLED", True)
|
||||
def test_require_auth_missing_header():
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
require_auth(None)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
@patch("auth.AUTH_ENABLED", True)
|
||||
def test_require_auth_invalid_bearer():
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
require_auth("Basic abc")
|
||||
assert exc_info.value.status_code == 401
|
||||
63
backend/tests/test_event_model.py
Normal file
63
backend/tests/test_event_model.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from models.event_model import _make_dedupe_key, normalize_event
|
||||
|
||||
|
||||
def test_make_dedupe_key_prefers_id_and_category():
|
||||
e = {"id": "evt-123", "category": "Directory"}
|
||||
assert _make_dedupe_key(e) == "evt-123|Directory"
|
||||
|
||||
|
||||
def test_make_dedupe_key_fallback_without_id():
|
||||
e = {
|
||||
"activityDateTime": "2024-01-01T00:00:00Z",
|
||||
"category": "Exchange",
|
||||
"activityDisplayName": "Update setting",
|
||||
}
|
||||
key = _make_dedupe_key(e)
|
||||
assert "2024-01-01T00:00:00Z|Exchange|Update setting" in key
|
||||
|
||||
|
||||
def test_normalize_event_basic():
|
||||
e = {
|
||||
"id": "abc",
|
||||
"activityDateTime": "2024-01-15T10:30:00Z",
|
||||
"category": "UserManagement",
|
||||
"activityDisplayName": "Add user",
|
||||
"result": "success",
|
||||
"initiatedBy": {
|
||||
"user": {
|
||||
"id": "u1",
|
||||
"displayName": "Alice",
|
||||
"userPrincipalName": "alice@example.com",
|
||||
}
|
||||
},
|
||||
"targetResources": [
|
||||
{"id": "t1", "displayName": "Bob", "type": "User"}
|
||||
],
|
||||
}
|
||||
out = normalize_event(e)
|
||||
assert out["id"] == "abc"
|
||||
assert out["timestamp"] == "2024-01-15T10:30:00Z"
|
||||
assert out["service"] == "UserManagement"
|
||||
assert out["operation"] == "Add user"
|
||||
assert out["result"] == "success"
|
||||
assert out["actor_display"] == "Alice (alice@example.com)"
|
||||
assert out["target_displays"] == ["Bob"]
|
||||
assert out["dedupe_key"] == "abc|UserManagement"
|
||||
assert "raw_text" in out
|
||||
|
||||
|
||||
def test_normalize_event_with_resolved_actor():
|
||||
e = {
|
||||
"id": "def",
|
||||
"activityDateTime": "2024-01-15T11:00:00Z",
|
||||
"category": "ApplicationManagement",
|
||||
"activityDisplayName": "Add app",
|
||||
"result": "success",
|
||||
"initiatedBy": {"servicePrincipal": {"id": "sp1"}},
|
||||
"targetResources": [],
|
||||
"_resolvedActor": {"id": "sp1", "type": "servicePrincipal", "name": "MyApp"},
|
||||
"_resolvedActorOwners": ["Owner1"],
|
||||
}
|
||||
out = normalize_event(e)
|
||||
assert out["actor_display"] == "MyApp (owners: Owner1)"
|
||||
assert out["display_category"] == "Application"
|
||||
0
backend/utils/__init__.py
Normal file
0
backend/utils/__init__.py
Normal file
29
backend/utils/http.py
Normal file
29
backend/utils/http.py
Normal file
@@ -0,0 +1,29 @@
|
||||
|
||||
import requests
|
||||
import structlog
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
||||
|
||||
logger = structlog.get_logger("aoc.http")
|
||||
|
||||
RETRY_CONFIG = {
|
||||
"stop": stop_after_attempt(3),
|
||||
"wait": wait_exponential(multiplier=1, min=2, max=10),
|
||||
"retry": retry_if_exception_type(requests.RequestException),
|
||||
"before_sleep": lambda retry_state: logger.warning(
|
||||
"Retrying HTTP request",
|
||||
attempt=retry_state.attempt_number,
|
||||
exception=str(retry_state.outcome.exception()) if retry_state.outcome else None,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@retry(**RETRY_CONFIG)
|
||||
def get_with_retry(url: str, headers: dict | None = None, params: dict | None = None, timeout: float = 20) -> requests.Response:
|
||||
res = requests.get(url, headers=headers, params=params, timeout=timeout)
|
||||
return res
|
||||
|
||||
|
||||
@retry(**RETRY_CONFIG)
|
||||
def post_with_retry(url: str, headers: dict | None = None, data: dict | None = None, params: dict | None = None, timeout: float = 15) -> requests.Response:
|
||||
res = requests.post(url, headers=headers, data=data, params=params, timeout=timeout)
|
||||
return res
|
||||
Reference in New Issue
Block a user