diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..3c4219a
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,37 @@
+name: CI
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+
+jobs:
+  lint-and-test:
+    runs-on: ubuntu-latest
+    defaults:
+      run:
+        working-directory: ./backend
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+
+      - name: Lint with ruff
+        run: ruff check .
+
+      - name: Format check with ruff
+        run: ruff format --check .
+
+      - name: Run tests
+        run: pytest -q
diff --git a/README.md b/README.md
index f74a361..43b7de4 100644
--- a/README.md
+++ b/README.md
@@ -94,6 +94,26 @@ Stored document shape (collection `micro_soc.events`):
 }
 ```
 
+## Development
+
+### Linting and formatting
+We use `ruff` for linting and formatting.
+
+```bash
+cd backend
+python3 -m venv .venv
+source .venv/bin/activate
+pip install -r requirements.txt -r requirements-dev.txt
+ruff check .
+ruff format .
+```
+
+### Running tests
+```bash
+cd backend
+pytest -q
+```
+
 ## Quick smoke tests
 With the server running:
 ```bash
diff --git a/ROADMAP.md b/ROADMAP.md
index 8bdbd3b..5f29a88 100644
--- a/ROADMAP.md
+++ b/ROADMAP.md
@@ -20,17 +20,17 @@ Goal: fix critical security and reliability gaps before production use.
 
 ---
 
-## Phase 2: Stabilize
+## Phase 2: Stabilize ✅
 
 Goal: improve resilience, code quality, and development experience.
 
-- [ ] Cache Graph API tokens and reuse them until near expiry
-- [ ] Add exponential backoff / retry logic for Graph API and Office 365 API calls
-- [ ] Add unit tests for `normalize_event()`, `_make_dedupe_key()`, and `auth.py`
-- [ ] Add integration tests for `/api/events` and `/api/fetch-audit-logs`
-- [ ] Configure linter/formatter (`ruff` or `black` + `isort`) and pre-commit hooks
-- [ ] Set up GitHub Actions CI pipeline (lint + test)
-- [ ] Add Pydantic request/response models for API endpoints
-- [ ] Validate `page_size` and `hours` with strict FastAPI constraints
+- [x] Cache Graph API tokens and reuse them until near expiry
+- [x] Add exponential backoff / retry logic for Graph API and Office 365 API calls
+- [x] Add unit tests for `normalize_event()`, `_make_dedupe_key()`, and `auth.py`
+- [x] Add integration tests for `/api/events` and `/api/fetch-audit-logs`
+- [x] Configure linter/formatter (`ruff`) and pre-commit hooks
+- [x] Set up GitHub Actions CI pipeline (lint + test)
+- [x] Add Pydantic request/response models for API endpoints
+- [x] Validate `page_size` and `hours` with strict FastAPI constraints
 
 ---
 
diff --git a/backend/auth.py b/backend/auth.py
index c858c0a..15c1bc4 100644
--- a/backend/auth.py
+++ b/backend/auth.py
@@ -1,19 +1,17 @@
 import time
-import structlog
-from typing import Optional, Set
 
 import requests
-from fastapi import Depends, HTTPException, Header
-from jose import jwt
-from jose.jwk import construct
-
+import structlog
 from config import (
+    AUTH_ALLOWED_GROUPS,
+    AUTH_ALLOWED_ROLES,
+    AUTH_CLIENT_ID,
     AUTH_ENABLED,
     AUTH_TENANT_ID,
-    AUTH_CLIENT_ID,
-    AUTH_ALLOWED_ROLES,
-    AUTH_ALLOWED_GROUPS,
 )
+from fastapi import Header, HTTPException
+from jose import jwt
+from jose.jwk import construct
 
 JWKS_CACHE = {"exp": 0, "keys": []}
 logger = structlog.get_logger("aoc.auth")
@@ -35,16 +33,15 @@ def _get_jwks():
     return keys
 
 
-def _allowed(claims: dict, allowed_roles: Set[str], allowed_groups: Set[str]) -> bool:
+def _allowed(claims: dict, allowed_roles: set[str], allowed_groups: set[str]) -> bool:
     if not allowed_roles and not allowed_groups:
         return True
     roles = set(claims.get("roles", []) or claims.get("role", []) or [])
     groups = set(claims.get("groups", []) or [])
-    if allowed_roles and roles.intersection(allowed_roles):
-        return True
-    if allowed_groups and groups.intersection(allowed_groups):
-        return True
-    return False
+    return bool(
+        (allowed_roles and roles.intersection(allowed_roles))
+        or (allowed_groups and groups.intersection(allowed_groups))
+    )
 
 
 def _decode_token(token: str, jwks):
@@ -72,10 +69,10 @@
         raise
     except Exception as exc:
         logger.warning("Token verification failed", error=str(exc))
-        raise HTTPException(status_code=401, detail="Invalid token")
+        raise HTTPException(status_code=401, detail="Invalid token") from None
 
 
-def require_auth(authorization: Optional[str] = Header(None)):
+def require_auth(authorization: str | None = Header(None)):
     if not AUTH_ENABLED:
         return {"sub": "anonymous"}
 
diff --git a/backend/database.py b/backend/database.py
index b63b62f..7d1b82a 100644
--- a/backend/database.py
+++ b/backend/database.py
@@ -1,6 +1,8 @@
-from pymongo import MongoClient, ASCENDING, DESCENDING, TEXT
-from config import MONGO_URI, DB_NAME, RETENTION_DAYS
+from contextlib import suppress
+
 import structlog
+from config import DB_NAME, MONGO_URI, RETENTION_DAYS
+from pymongo import ASCENDING, DESCENDING, TEXT, MongoClient
 
 client = MongoClient(MONGO_URI)
 db = client[DB_NAME]
@@ -29,10 +31,8 @@ def setup_indexes(max_retries: int = 5, delay: float = 2.0):
                 name="ttl_timestamp",
             )
         else:
-            try:
+            with suppress(Exception):
                 events_collection.drop_index("ttl_timestamp")
-            except Exception:
-                pass
         logger.info("MongoDB indexes ensured")
         return
     except Exception as exc:
diff --git a/backend/graph/audit_logs.py b/backend/graph/audit_logs.py
index 50a06f4..df7a3f4 100644
--- a/backend/graph/audit_logs.py
+++ b/backend/graph/audit_logs.py
@@ -1,7 +1,8 @@
-import requests
 from datetime import datetime, timedelta
+
 from graph.auth import get_access_token
 from graph.resolve import resolve_directory_object, resolve_service_principal_owners
+from utils.http import get_with_retry
 
 
 def fetch_audit_logs(hours=24, max_pages=50):
@@ -22,13 +23,13 @@ def fetch_audit_logs(hours=24, max_pages=50):
             raise RuntimeError(f"Aborting pagination after {max_pages} pages to avoid runaway fetch.")
 
         try:
-            res = requests.get(next_url, headers=headers, timeout=20)
+            res = get_with_retry(next_url, headers=headers, timeout=20)
             res.raise_for_status()
             body = res.json()
-        except requests.RequestException as exc:
+        except RuntimeError:
+            raise
+        except Exception as exc:
             raise RuntimeError(f"Failed to fetch audit logs page: {exc}") from exc
-        except ValueError as exc:
-            raise RuntimeError(f"Invalid JSON response from Graph: {exc}") from exc
 
         events.extend(body.get("value", []))
         next_url = body.get("@odata.nextLink")
diff --git a/backend/graph/auth.py b/backend/graph/auth.py
index d5776e0..14ce638 100644
--- a/backend/graph/auth.py
+++ b/backend/graph/auth.py
@@ -1,9 +1,19 @@
+import time
+
 import requests
-from config import TENANT_ID, CLIENT_ID, CLIENT_SECRET
+from config import CLIENT_ID, CLIENT_SECRET, TENANT_ID
+
+_TOKEN_CACHE = {}
 
 
 def get_access_token(scope: str = "https://graph.microsoft.com/.default"):
-    """Request an application token from Microsoft identity platform."""
+    """Request an application token from Microsoft identity platform.
+    Tokens are cached and reused until 5 minutes before expiry."""
+    now = time.time()
+    cached = _TOKEN_CACHE.get(scope)
+    if cached and cached["exp"] > now + 300:
+        return cached["token"]
+
     url = f"https://login.microsoftonline.com/{TENANT_ID}/oauth2/v2.0/token"
     data = {
         "grant_type": "client_credentials",
@@ -14,9 +24,12 @@ def get_access_token(scope: str = "https://graph.microsoft.com/.default"):
     try:
         res = requests.post(url, data=data, timeout=15)
         res.raise_for_status()
-        token = res.json().get("access_token")
+        payload = res.json()
+        token = payload.get("access_token")
         if not token:
             raise RuntimeError("Token endpoint returned no access_token")
+        expires_in = payload.get("expires_in", 3600)
+        _TOKEN_CACHE[scope] = {"token": token, "exp": now + expires_in}
         return token
     except requests.RequestException as exc:
         raise RuntimeError(f"Failed to obtain access token: {exc}") from exc
diff --git a/backend/graph/resolve.py b/backend/graph/resolve.py
index b7b4ea8..349af7b 100644
--- a/backend/graph/resolve.py
+++ b/backend/graph/resolve.py
@@ -1,6 +1,5 @@
-from typing import Dict, List, Optional
-
-import requests
+from utils.http import get_with_retry
 
 
 def _name_from_payload(payload: dict, kind: str) -> str:
@@ -26,18 +25,18 @@ def _name_from_payload(payload: dict, kind: str) -> str:
     return payload.get("displayName") or payload.get("id") or "Unknown"
 
 
-def _request_json(url: str, token: str) -> Optional[dict]:
+def _request_json(url: str, token: str) -> dict | None:
     try:
-        res = requests.get(url, headers={"Authorization": f"Bearer {token}"}, timeout=10)
+        res = get_with_retry(url, headers={"Authorization": f"Bearer {token}"}, timeout=10)
         if res.status_code == 404:
             return None
         res.raise_for_status()
         return res.json()
-    except requests.RequestException:
+    except Exception:
         return None
 
 
-def resolve_directory_object(object_id: str, token: str, cache: Dict[str, dict]) -> Optional[dict]:
+def resolve_directory_object(object_id: str, token: str, cache: dict[str, dict]) -> dict | None:
     """
     Resolve a directory object (user, servicePrincipal, group, device) to a readable name.
     Uses a simple multi-endpoint probe with caching to avoid extra Graph traffic.
@@ -69,7 +68,7 @@ def resolve_directory_object(object_id: str, token: str, cache: Dict[str, dict])
     return None
 
 
-def resolve_service_principal_owners(sp_id: str, token: str, cache: Dict[str, List[str]]) -> List[str]:
+def resolve_service_principal_owners(sp_id: str, token: str, cache: dict[str, list[str]]) -> list[str]:
     """Return a list of owner display names for a service principal."""
     if not sp_id:
         return []
diff --git a/backend/main.py b/backend/main.py
index af57e4f..e006569 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -1,17 +1,18 @@
 import asyncio
 import logging
+from contextlib import suppress
 from pathlib import Path
 
 import structlog
-from fastapi import FastAPI, HTTPException
-from fastapi.staticfiles import StaticFiles
-from fastapi.middleware.cors import CORSMiddleware
-
+from config import CORS_ORIGINS, ENABLE_PERIODIC_FETCH, FETCH_INTERVAL_MINUTES
 from database import setup_indexes
-from routes.fetch import router as fetch_router, run_fetch
-from routes.events import router as events_router
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
 from routes.config import router as config_router
-from config import ENABLE_PERIODIC_FETCH, FETCH_INTERVAL_MINUTES, CORS_ORIGINS
+from routes.events import router as events_router
+from routes.fetch import router as fetch_router
+from routes.fetch import run_fetch
 
 
 def configure_logging():
@@ -90,7 +91,5 @@ async def stop_periodic_fetch():
     task = getattr(app.state, "fetch_task", None)
     if task:
         task.cancel()
-        try:
+        with suppress(Exception):
             await task
-        except Exception:
-            pass
diff --git a/backend/maintenance.py b/backend/maintenance.py
index dfa974d..5023e27 100644
--- a/backend/maintenance.py
+++ b/backend/maintenance.py
@@ -7,14 +7,12 @@ new display fields. Example:
     python maintenance.py renormalize --limit 500
 """
 import argparse
-from typing import List, Set
-
-from pymongo import UpdateOne
 
 from database import events_collection
-from graph.auth import get_access_token
 from graph.audit_logs import _enrich_events
-from models.event_model import normalize_event, _make_dedupe_key
+from graph.auth import get_access_token
+from models.event_model import _make_dedupe_key, normalize_event
+from pymongo import UpdateOne
 
 
 def renormalize(limit: int = None, batch_size: int = 200) -> int:
@@ -29,7 +27,7 @@ def renormalize(limit: int = None, batch_size: int = 200) -> int:
         cursor = cursor.limit(int(limit))
 
     updated = 0
-    batch: List[UpdateOne] = []
+    batch: list[UpdateOne] = []
 
     for doc in cursor:
         raw = doc.get("raw") or {}
@@ -59,7 +57,7 @@ def dedupe(limit: int = None, batch_size: int = 500) -> int:
     if limit:
         cursor = cursor.limit(int(limit))
 
-    seen: Set[str] = set()
+    seen: set[str] = set()
     to_delete = []
     processed = 0
diff --git a/backend/mapping_loader.py b/backend/mapping_loader.py
index c328d7e..4138a58 100644
--- a/backend/mapping_loader.py
+++ b/backend/mapping_loader.py
@@ -1,11 +1,10 @@
 from functools import lru_cache
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any
 
 import yaml
-
-DEFAULT_MAPPING: Dict[str, Any] = {
+DEFAULT_MAPPING: dict[str, Any] = {
     "category_labels": {
         "ApplicationManagement": "Application",
         "UserManagement": "User",
@@ -38,7 +37,7 @@ DEFAULT_MAPPING: dict[str, Any] = {
 
 
 @lru_cache(maxsize=1)
-def get_mapping() -> Dict[str, Any]:
+def get_mapping() -> dict[str, Any]:
     """
     Load mapping from mappings.yml if present; otherwise fall back to defaults.
     Users can edit mappings.yml to change labels and summary templates.
diff --git a/backend/models/__init__.py b/backend/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/backend/models/api.py b/backend/models/api.py
new file mode 100644
index 0000000..8dd0347
--- /dev/null
+++ b/backend/models/api.py
@@ -0,0 +1,41 @@
+from pydantic import BaseModel, ConfigDict
+
+
+class EventItem(BaseModel):
+    id: str | None = None
+    timestamp: str | None = None
+    service: str | None = None
+    operation: str | None = None
+    result: str | None = None
+    actor_display: str | None = None
+    target_displays: list[str] | None = None
+    display_summary: str | None = None
+    display_category: str | None = None
+    dedupe_key: str | None = None
+    actor: dict | None = None
+    targets: list[dict] | None = None
+    raw: dict | None = None
+    raw_text: str | None = None
+
+    model_config = ConfigDict(extra="allow")
+
+
+class PaginatedEventResponse(BaseModel):
+    items: list[dict]
+    total: int
+    page: int
+    page_size: int
+
+
+class FilterOptionsResponse(BaseModel):
+    services: list[str]
+    operations: list[str]
+    results: list[str]
+    actors: list[str]
+    actor_upns: list[str]
+    devices: list[str]
+
+
+class FetchAuditLogsResponse(BaseModel):
+    stored_events: int
+    errors: list[str]
diff --git a/backend/models/event_model.py b/backend/models/event_model.py
index 75dc4bb..ad02f36 100644
--- a/backend/models/event_model.py
+++ b/backend/models/event_model.py
@@ -2,7 +2,6 @@ import json
 
 from mapping_loader import get_mapping
 
-
 CATEGORY_LABELS = {
     "ApplicationManagement": "Application",
     "UserManagement": "User",
diff --git a/backend/requirements-dev.txt b/backend/requirements-dev.txt
new file mode 100644
index 0000000..ba30bd3
--- /dev/null
+++ b/backend/requirements-dev.txt
@@ -0,0 +1,4 @@
+pytest
+mongomock
+httpx
+ruff
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 09cf7fe..00d51db 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -7,3 +7,4 @@ PyYAML
 python-jose[cryptography]
 pydantic-settings
 structlog
+tenacity
diff --git a/backend/routes/config.py b/backend/routes/config.py
index 234b15e..75acae5 100644
--- a/backend/routes/config.py
+++ b/backend/routes/config.py
@@ -1,10 +1,10 @@
-from fastapi import APIRouter
 from config import (
-    AUTH_ENABLED,
-    AUTH_TENANT_ID,
     AUTH_CLIENT_ID,
+    AUTH_ENABLED,
     AUTH_SCOPE,
+    AUTH_TENANT_ID,
 )
+from fastapi import APIRouter
 
 router = APIRouter()
diff --git a/backend/routes/events.py b/backend/routes/events.py
index dfad21a..862a7be 100644
--- a/backend/routes/events.py
+++ b/backend/routes/events.py
@@ -1,22 +1,24 @@
 import re
-from fastapi import APIRouter, HTTPException, Depends
-from database import events_collection
+
 from auth import require_auth
+from database import events_collection
+from fastapi import APIRouter, Depends, HTTPException, Query
+from models.api import FilterOptionsResponse, PaginatedEventResponse
 
 router = APIRouter(dependencies=[Depends(require_auth)])
 
 
-@router.get("/events")
+@router.get("/events", response_model=PaginatedEventResponse)
 def list_events(
-    service: str = None,
-    actor: str = None,
-    operation: str = None,
-    result: str = None,
-    start: str = None,
-    end: str = None,
-    search: str = None,
-    page: int = 1,
-    page_size: int = 50,
+    service: str | None = None,
+    actor: str | None = None,
+    operation: str | None = None,
+    result: str | None = None,
+    start: str | None = None,
+    end: str | None = None,
+    search: str | None = None,
+    page: int = Query(default=1, ge=1),
+    page_size: int = Query(default=50, ge=1, le=500),
 ):
     filters = []
@@ -82,8 +84,8 @@ def list_events(
     }
 
 
-@router.get("/filter-options")
-def filter_options(limit: int = 200):
+@router.get("/filter-options", response_model=FilterOptionsResponse)
+def filter_options(limit: int = Query(default=200, ge=1, le=1000)):
     """
     Provide distinct values for UI filters (best-effort, capped).
     """
diff --git a/backend/routes/fetch.py b/backend/routes/fetch.py
index 13ae361..fb5166e 100644
--- a/backend/routes/fetch.py
+++ b/backend/routes/fetch.py
@@ -1,12 +1,12 @@
-from fastapi import APIRouter, HTTPException, Depends
-from pymongo import UpdateOne
-
-from database import events_collection
-from graph.audit_logs import fetch_audit_logs
-from sources.unified_audit import fetch_unified_audit
-from sources.intune_audit import fetch_intune_audit
-from models.event_model import normalize_event
 from auth import require_auth
+from database import events_collection
+from fastapi import APIRouter, Depends, HTTPException, Query
+from graph.audit_logs import fetch_audit_logs
+from models.api import FetchAuditLogsResponse
+from models.event_model import normalize_event
+from pymongo import UpdateOne
+from sources.intune_audit import fetch_intune_audit
+from sources.unified_audit import fetch_unified_audit
 
 router = APIRouter(dependencies=[Depends(require_auth)])
 
@@ -40,8 +40,8 @@ def run_fetch(hours: int = 168):
     return {"stored_events": len(normalized), "errors": errors}
 
 
-@router.get("/fetch-audit-logs")
-def fetch_logs(hours: int = 168):
+@router.get("/fetch-audit-logs", response_model=FetchAuditLogsResponse)
+def fetch_logs(hours: int = Query(default=168, ge=1, le=720)):
     try:
         return run_fetch(hours=hours)
     except Exception as exc:
diff --git a/backend/sources/intune_audit.py b/backend/sources/intune_audit.py
index c29687d..25c31f2 100644
--- a/backend/sources/intune_audit.py
+++ b/backend/sources/intune_audit.py
@@ -1,11 +1,10 @@
-import requests
 from datetime import datetime, timedelta
-from typing import List
 
 from graph.auth import get_access_token
+from utils.http import get_with_retry
 
 
-def fetch_intune_audit(hours: int = 24, max_pages: int = 50) -> List[dict]:
+def fetch_intune_audit(hours: int = 24, max_pages: int = 50) -> list[dict]:
     """
     Fetch Intune audit events via Microsoft Graph.
     Requires Intune audit permissions (e.g., DeviceManagementConfiguration.Read.All).
@@ -24,13 +23,13 @@ def fetch_intune_audit(hours: int = 24, max_pages: int = 50) -> List[dict]:
         if pages >= max_pages:
             raise RuntimeError(f"Aborting Intune pagination after {max_pages} pages.")
         try:
-            res = requests.get(url, headers=headers, timeout=20)
+            res = get_with_retry(url, headers=headers, timeout=20)
             res.raise_for_status()
             body = res.json()
-        except requests.RequestException as exc:
+        except RuntimeError:
+            raise
+        except Exception as exc:
             raise RuntimeError(f"Failed to fetch Intune audit logs: {exc}") from exc
-        except ValueError as exc:
-            raise RuntimeError(f"Invalid Intune response JSON: {exc}") from exc
 
         events.extend(body.get("value", []))
         url = body.get("@odata.nextLink")
diff --git a/backend/sources/unified_audit.py b/backend/sources/unified_audit.py
index d1dbc84..77059a0 100644
--- a/backend/sources/unified_audit.py
+++ b/backend/sources/unified_audit.py
@@ -1,9 +1,8 @@
-import requests
+from contextlib import suppress
 from datetime import datetime, timedelta
-from typing import List
 
 from graph.auth import get_access_token
-
+from utils.http import get_with_retry, post_with_retry
 
 AUDIT_CONTENT_TYPES = {
     "Audit.Exchange": "Exchange admin audit",
@@ -23,40 +22,41 @@ def _ensure_subscription(content_type: str, token: str, tenant_id: str):
     url = f"https://manage.office.com/api/v1.0/{tenant_id}/activity/feed/subscriptions/start"
     params = {"contentType": content_type}
     headers = {"Authorization": f"Bearer {token}"}
-    try:
-        requests.post(url, params=params, headers=headers, timeout=10)
-    except requests.RequestException:
-        pass  # best-effort
+    with suppress(Exception):
+        post_with_retry(url, params=params, headers=headers, timeout=10)
 
 
-def _list_content(content_type: str, token: str, tenant_id: str, hours: int) -> List[dict]:
+def _list_content(content_type: str, token: str, tenant_id: str, hours: int) -> list[dict]:
     start, end = _time_window(hours)
     url = f"https://manage.office.com/api/v1.0/{tenant_id}/activity/feed/subscriptions/content"
     params = {"contentType": content_type, "startTime": start, "endTime": end}
     headers = {"Authorization": f"Bearer {token}"}
     try:
-        res = requests.get(url, params=params, headers=headers, timeout=20)
+        res = get_with_retry(url, params=params, headers=headers, timeout=20)
         if res.status_code in (400, 401, 403, 404):
             # Likely not enabled or insufficient perms; surface the text to the caller.
             raise RuntimeError(f"{content_type} content listing failed ({res.status_code}): {res.text}")
-            return []
         res.raise_for_status()
         return res.json() or []
-    except requests.RequestException as exc:
+    except RuntimeError:
+        raise
+    except Exception as exc:
         raise RuntimeError(f"Failed to list {content_type} content: {exc}") from exc
 
 
-def _download_content(content_uri: str, token: str) -> List[dict]:
+def _download_content(content_uri: str, token: str) -> list[dict]:
     headers = {"Authorization": f"Bearer {token}"}
     try:
-        res = requests.get(content_uri, headers=headers, timeout=30)
+        res = get_with_retry(content_uri, headers=headers, timeout=30)
         res.raise_for_status()
         return res.json() or []
-    except requests.RequestException as exc:
+    except RuntimeError:
+        raise
+    except Exception as exc:
         raise RuntimeError(f"Failed to download audit content: {exc}") from exc
 
 
-def fetch_unified_audit(hours: int = 24, max_files: int = 50) -> List[dict]:
+def fetch_unified_audit(hours: int = 24, max_files: int = 50) -> list[dict]:
     """
     Fetch unified audit logs (Exchange, SharePoint, Teams policy changes via Audit.General)
     using the Office 365 Management Activity API.
@@ -67,7 +67,7 @@
 
     events = []
 
-    for content_type in AUDIT_CONTENT_TYPES.keys():
+    for content_type in AUDIT_CONTENT_TYPES:
         _ensure_subscription(content_type, token, TENANT_ID)
         contents = _list_content(content_type, token, TENANT_ID, hours)
         for item in contents[:max_files]:
diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py
new file mode 100644
index 0000000..cc6c7d7
--- /dev/null
+++ b/backend/tests/conftest.py
@@ -0,0 +1,26 @@
+import mongomock
+import pytest
+from fastapi.testclient import TestClient
+
+
+@pytest.fixture(scope="function")
+def mock_events_collection():
+    client = mongomock.MongoClient()
+    db = client["micro_soc"]
+    coll = db["events"]
+    return coll
+
+
+@pytest.fixture(scope="function")
+def client(mock_events_collection, monkeypatch):
+    # Patch the collection in all modules that import it before the app is imported
+    monkeypatch.setattr("database.events_collection", mock_events_collection)
+    monkeypatch.setattr("routes.fetch.events_collection", mock_events_collection)
+    monkeypatch.setattr("routes.events.events_collection", mock_events_collection)
+    monkeypatch.setattr("auth.AUTH_ENABLED", False)
+    # Patch health check db.command so it doesn't need a real MongoDB server
+    monkeypatch.setattr("database.db.command", lambda cmd: {"ok": 1} if cmd == "ping" else {})
+
+    from main import app
+
+    return TestClient(app)
diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py
new file mode 100644
index 0000000..89506b5
--- /dev/null
+++ b/backend/tests/test_api.py
@@ -0,0 +1,98 @@
+from datetime import UTC, datetime
+
+
+def test_health(client):
+    response = client.get("/health")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["status"] == "ok"
+    assert data["database"] == "connected"
+
+
+def test_list_events_empty(client):
+    response = client.get("/api/events")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["items"] == []
+    assert data["total"] == 0
+
+
+def test_list_events_pagination(client, mock_events_collection):
+    for i in range(5):
+        mock_events_collection.insert_one({
+            "id": f"evt-{i}",
+            "timestamp": datetime.now(UTC).isoformat(),
+            "service": "Directory",
+            "operation": "Add user",
+            "result": "success",
+            "actor_display": f"Actor {i}",
+            "raw_text": "",
+        })
+    response = client.get("/api/events?page=1&page_size=2")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["total"] == 5
+    assert len(data["items"]) == 2
+    assert data["page"] == 1
+    assert data["page_size"] == 2
+
+
+def test_list_events_filter_by_service(client, mock_events_collection):
+    mock_events_collection.insert_one({
+        "id": "evt-1",
+        "timestamp": datetime.now(UTC).isoformat(),
+        "service": "Exchange",
+        "operation": "Update",
+        "result": "success",
+        "actor_display": "Alice",
+        "raw_text": "",
+    })
+    mock_events_collection.insert_one({
+        "id": "evt-2",
+        "timestamp": datetime.now(UTC).isoformat(),
+        "service": "Directory",
+        "operation": "Add",
+        "result": "success",
+        "actor_display": "Bob",
+        "raw_text": "",
+    })
+    response = client.get("/api/events?service=Exchange")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["total"] == 1
+    assert data["items"][0]["service"] == "Exchange"
+
+
+def test_list_events_page_size_validation(client):
+    response = client.get("/api/events?page_size=0")
+    assert response.status_code == 422
+    response = client.get("/api/events?page_size=501")
+    assert response.status_code == 422
+
+
+def test_filter_options(client, mock_events_collection):
+    mock_events_collection.insert_one({
+        "id": "evt-1",
+        "timestamp": datetime.now(UTC).isoformat(),
+        "service": "Intune",
+        "operation": "Assign",
+        "result": "failure",
+        "actor_display": "Charlie",
+        "actor_upn": "charlie@example.com",
+        "raw_text": "",
+    })
+    response = client.get("/api/filter-options")
+    assert response.status_code == 200
+    data = response.json()
+    assert "Intune" in data["services"]
+    assert "Assign" in data["operations"]
+    assert "failure" in data["results"]
+    assert "Charlie" in data["actors"]
+    assert "charlie@example.com" in data["actor_upns"]
+
+
+def test_fetch_audit_logs_validation(client):
+    response = client.get("/api/fetch-audit-logs?hours=0")
+    assert response.status_code == 422
+    response = client.get("/api/fetch-audit-logs?hours=721")
+    assert response.status_code == 422
diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py
new file mode 100644
index 0000000..b876a9c
--- /dev/null
+++ b/backend/tests/test_auth.py
@@ -0,0 +1,67 @@
+from unittest.mock import patch
+
+import auth
+import pytest
+from auth import _allowed, require_auth
+from fastapi import HTTPException
+
+
+@pytest.fixture(autouse=True)
+def reset_cache():
+    auth.JWKS_CACHE["keys"] = []
+    auth.JWKS_CACHE["exp"] = 0
+
+
+@pytest.fixture
+def mock_jwks():
+    from cryptography.hazmat.primitives import serialization
+    from cryptography.hazmat.primitives.asymmetric import rsa
+    from jose.jwk import RSAKey
+
+    # python-jose[cryptography] already pulls in `cryptography`, so no extra
+    # dev dependency is needed to generate a throwaway signing key.
+    private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+    pem = private_key.private_bytes(
+        encoding=serialization.Encoding.PEM,
+        format=serialization.PrivateFormat.PKCS8,
+        encryption_algorithm=serialization.NoEncryption(),
+    ).decode()
+    # jose requires the algorithm when constructing a key; to_dict() yields
+    # base64url-encoded n/e, matching what a real JWKS endpoint serves.
+    rsa_key = RSAKey(pem, algorithm="RS256")
+    jwk_dict = {**rsa_key.public_key().to_dict(), "kid": "test-kid"}
+    return rsa_key, jwk_dict
+
+
+def test_allowed_no_restrictions():
+    assert _allowed({}, set(), set()) is True
+
+
+def test_allowed_by_role():
+    assert _allowed({"roles": ["Admin"]}, {"Admin"}, set()) is True
+    assert _allowed({"roles": ["User"]}, {"Admin"}, set()) is False
+
+
+def test_allowed_by_group():
+    assert _allowed({"groups": ["SecOps"]}, set(), {"SecOps"}) is True
+    assert _allowed({"groups": ["Users"]}, set(), {"SecOps"}) is False
+
+
+@patch("auth.AUTH_ENABLED", False)
+def test_require_auth_disabled():
+    claims = require_auth(None)
+    assert claims["sub"] == "anonymous"
+
+
+@patch("auth.AUTH_ENABLED", True)
+def test_require_auth_missing_header():
+    with pytest.raises(HTTPException) as exc_info:
+        require_auth(None)
+    assert exc_info.value.status_code == 401
+
+
+@patch("auth.AUTH_ENABLED", True)
+def test_require_auth_invalid_bearer():
+    with pytest.raises(HTTPException) as exc_info:
+        require_auth("Basic abc")
+    assert exc_info.value.status_code == 401
diff --git a/backend/tests/test_event_model.py b/backend/tests/test_event_model.py
new file mode 100644
index 0000000..9e497c1
--- /dev/null
+++ b/backend/tests/test_event_model.py
@@ -0,0 +1,63 @@
+from models.event_model import _make_dedupe_key, normalize_event
+
+
+def test_make_dedupe_key_prefers_id_and_category():
+    e = {"id": "evt-123", "category": "Directory"}
+    assert _make_dedupe_key(e) == "evt-123|Directory"
+
+
+def test_make_dedupe_key_fallback_without_id():
+    e = {
+        "activityDateTime": "2024-01-01T00:00:00Z",
+        "category": "Exchange",
+        "activityDisplayName": "Update setting",
+    }
+    key = _make_dedupe_key(e)
+    assert "2024-01-01T00:00:00Z|Exchange|Update setting" in key
+
+
+def test_normalize_event_basic():
+    e = {
+        "id": "abc",
+        "activityDateTime": "2024-01-15T10:30:00Z",
+        "category": "UserManagement",
"activityDisplayName": "Add user", + "result": "success", + "initiatedBy": { + "user": { + "id": "u1", + "displayName": "Alice", + "userPrincipalName": "alice@example.com", + } + }, + "targetResources": [ + {"id": "t1", "displayName": "Bob", "type": "User"} + ], + } + out = normalize_event(e) + assert out["id"] == "abc" + assert out["timestamp"] == "2024-01-15T10:30:00Z" + assert out["service"] == "UserManagement" + assert out["operation"] == "Add user" + assert out["result"] == "success" + assert out["actor_display"] == "Alice (alice@example.com)" + assert out["target_displays"] == ["Bob"] + assert out["dedupe_key"] == "abc|UserManagement" + assert "raw_text" in out + + +def test_normalize_event_with_resolved_actor(): + e = { + "id": "def", + "activityDateTime": "2024-01-15T11:00:00Z", + "category": "ApplicationManagement", + "activityDisplayName": "Add app", + "result": "success", + "initiatedBy": {"servicePrincipal": {"id": "sp1"}}, + "targetResources": [], + "_resolvedActor": {"id": "sp1", "type": "servicePrincipal", "name": "MyApp"}, + "_resolvedActorOwners": ["Owner1"], + } + out = normalize_event(e) + assert out["actor_display"] == "MyApp (owners: Owner1)" + assert out["display_category"] == "Application" diff --git a/backend/utils/__init__.py b/backend/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/utils/http.py b/backend/utils/http.py new file mode 100644 index 0000000..2adfa9d --- /dev/null +++ b/backend/utils/http.py @@ -0,0 +1,29 @@ + +import requests +import structlog +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential + +logger = structlog.get_logger("aoc.http") + +RETRY_CONFIG = { + "stop": stop_after_attempt(3), + "wait": wait_exponential(multiplier=1, min=2, max=10), + "retry": retry_if_exception_type(requests.RequestException), + "before_sleep": lambda retry_state: logger.warning( + "Retrying HTTP request", + attempt=retry_state.attempt_number, + exception=str(retry_state.outcome.exception()) if retry_state.outcome else None, + ), +} + + +@retry(**RETRY_CONFIG) +def get_with_retry(url: str, headers: dict | None = None, params: dict | None = None, timeout: float = 20) -> requests.Response: + res = requests.get(url, headers=headers, params=params, timeout=timeout) + return res + + +@retry(**RETRY_CONFIG) +def post_with_retry(url: str, headers: dict | None = None, data: dict | None = None, params: dict | None = None, timeout: float = 15) -> requests.Response: + res = requests.post(url, headers=headers, data=data, params=params, timeout=timeout) + return res diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6a6ac71 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,14 @@ +[tool.ruff] +target-version = "py311" +line-length = 120 + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "W", "UP", "B", "C4", "SIM"] +ignore = ["E501"] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.pytest.ini_options] +testpaths = ["backend/tests"] +pythonpath = ["backend"]