diff --git a/.env.example b/.env.example index 00454d0..ecc06e9 100644 --- a/.env.example +++ b/.env.example @@ -49,3 +49,9 @@ LLM_MODEL=gpt-4o-mini LLM_MAX_EVENTS=200 LLM_TIMEOUT_SECONDS=30 LLM_API_VERSION= + +# Optional: privacy / service-level access control +# Comma-separated list of services considered privacy-sensitive (hidden from users without PRIVACY_SERVICE_ROLES) +# PRIVACY_SERVICES=Exchange,Teams +# Comma-separated list of Entra roles that can access privacy-sensitive services +# PRIVACY_SERVICE_ROLES=SecurityAdministrator,ComplianceAdministrator diff --git a/backend/auth.py b/backend/auth.py index 0c9363d..506c3cf 100644 --- a/backend/auth.py +++ b/backend/auth.py @@ -8,6 +8,8 @@ from config import ( AUTH_CLIENT_ID, AUTH_ENABLED, AUTH_TENANT_ID, + PRIVACY_SERVICE_ROLES, + PRIVACY_SERVICES, ) from fastapi import Header, HTTPException from jwt import ExpiredSignatureError, InvalidTokenError, decode @@ -82,6 +84,14 @@ def _decode_token(token: str, jwks): raise HTTPException(status_code=401, detail=f"Invalid token ({type(exc).__name__})") from None +def user_can_access_privacy_services(claims: dict) -> bool: + """Check if the user has roles that grant access to privacy-sensitive services.""" + if not PRIVACY_SERVICES or not PRIVACY_SERVICE_ROLES: + return True + user_roles = set(claims.get("roles", []) or claims.get("role", []) or []) + return bool(user_roles.intersection(PRIVACY_SERVICE_ROLES)) + + def require_auth(authorization: str | None = Header(None)): if not AUTH_ENABLED: return {"sub": "anonymous"} diff --git a/backend/config.py b/backend/config.py index 25eeda9..a0c6bda 100644 --- a/backend/config.py +++ b/backend/config.py @@ -51,6 +51,11 @@ class Settings(BaseSettings): LLM_TIMEOUT_SECONDS: int = 30 LLM_API_VERSION: str = "" # e.g. 2025-01-01-preview for Azure OpenAI + # Privacy / Service-level access control + # Services listed here are hidden from users who don't have PRIVACY_SERVICE_ROLES + PRIVACY_SERVICES: str = "" # comma-separated, e.g. "Exchange,Teams" + PRIVACY_SERVICE_ROLES: str = "" # comma-separated, e.g. "SecurityAdministrator,ComplianceAdministrator" + _settings = Settings() @@ -85,3 +90,6 @@ LLM_MODEL = _settings.LLM_MODEL LLM_MAX_EVENTS = _settings.LLM_MAX_EVENTS LLM_TIMEOUT_SECONDS = _settings.LLM_TIMEOUT_SECONDS LLM_API_VERSION = _settings.LLM_API_VERSION + +PRIVACY_SERVICES = {s.strip() for s in _settings.PRIVACY_SERVICES.split(",") if s.strip()} +PRIVACY_SERVICE_ROLES = {r.strip() for r in _settings.PRIVACY_SERVICE_ROLES.split(",") if r.strip()} diff --git a/backend/routes/ask.py b/backend/routes/ask.py index 23bab4a..e7b15ed 100644 --- a/backend/routes/ask.py +++ b/backend/routes/ask.py @@ -5,8 +5,16 @@ from datetime import UTC, datetime, timedelta import httpx import structlog -from auth import require_auth -from config import LLM_API_KEY, LLM_API_VERSION, LLM_BASE_URL, LLM_MAX_EVENTS, LLM_MODEL, LLM_TIMEOUT_SECONDS +from auth import require_auth, user_can_access_privacy_services +from config import ( + LLM_API_KEY, + LLM_API_VERSION, + LLM_BASE_URL, + LLM_MAX_EVENTS, + LLM_MODEL, + LLM_TIMEOUT_SECONDS, + PRIVACY_SERVICES, +) from database import events_collection from fastapi import APIRouter, Depends, HTTPException from models.api import AskRequest, AskResponse @@ -588,6 +596,9 @@ async def explain_event(event_id: str, user: dict = Depends(require_auth)): if not event: raise HTTPException(status_code=404, detail="Event not found") + if event.get("service") in PRIVACY_SERVICES and not user_can_access_privacy_services(user): + raise HTTPException(status_code=403, detail="Access to this event is restricted") + event.pop("_id", None) # Fetch related events for context (same actor or target in last 24h) @@ -678,6 +689,7 @@ async def ask_question(body: AskRequest, user: dict = Depends(require_auth)): # ----------------------------------------------------------------------- # Build and run query # ----------------------------------------------------------------------- + privacy_excluded = [] if user_can_access_privacy_services(user) else list(PRIVACY_SERVICES) query = _build_event_query( entity, start, @@ -689,6 +701,8 @@ async def ask_question(body: AskRequest, user: dict = Depends(require_auth)): include_tags=body.include_tags, exclude_tags=body.exclude_tags, ) + if privacy_excluded: + query["$and"] = query.get("$and", []) + [{"service": {"$nin": privacy_excluded}}] try: total = events_collection.count_documents(query) diff --git a/backend/routes/events.py b/backend/routes/events.py index 39001ae..8909b33 100644 --- a/backend/routes/events.py +++ b/backend/routes/events.py @@ -3,8 +3,9 @@ import re from datetime import UTC, datetime from audit_trail import log_action -from auth import require_auth +from auth import require_auth, user_can_access_privacy_services from bson import ObjectId +from config import PRIVACY_SERVICES from database import events_collection from fastapi import APIRouter, Depends, HTTPException, Query from models.api import ( @@ -44,6 +45,7 @@ def _build_query( cursor: str | None = None, include_tags: list[str] | None = None, exclude_tags: list[str] | None = None, + exclude_services: list[str] | None = None, ) -> dict: filters = [] @@ -51,6 +53,8 @@ def _build_query( filters.append({"service": service}) if services: filters.append({"service": {"$in": services}}) + if exclude_services: + filters.append({"service": {"$nin": exclude_services}}) if actor: actor_safe = re.escape(actor) filters.append( @@ -125,6 +129,7 @@ def list_events( exclude_tags: list[str] | None = Query(default=None), user: dict = Depends(require_auth), ): + privacy_excluded = [] if user_can_access_privacy_services(user) else list(PRIVACY_SERVICES) query = _build_query( service=service, services=services, @@ -137,6 +142,7 @@ def list_events( cursor=cursor, include_tags=include_tags, exclude_tags=exclude_tags, + exclude_services=privacy_excluded, ) safe_page_size = max(1, min(page_size, 500)) @@ -202,6 +208,7 @@ def bulk_tags( exclude_tags: list[str] | None = Query(default=None), user: dict = Depends(require_auth), ): + privacy_excluded = [] if user_can_access_privacy_services(user) else list(PRIVACY_SERVICES) query = _build_query( service=service, services=services, @@ -213,6 +220,7 @@ def bulk_tags( search=search, include_tags=include_tags, exclude_tags=exclude_tags, + exclude_services=privacy_excluded, ) tags = [t.strip() for t in body.tags if t.strip()] if not tags: @@ -235,7 +243,10 @@ def bulk_tags( @router.get("/filter-options", response_model=FilterOptionsResponse) -def filter_options(limit: int = Query(default=200, ge=1, le=1000)): +def filter_options( + limit: int = Query(default=200, ge=1, le=1000), + user: dict = Depends(require_auth), +): safe_limit = max(1, min(limit, 1000)) try: services = sorted(events_collection.distinct("service"))[:safe_limit] @@ -247,6 +258,9 @@ def filter_options(limit: int = Query(default=200, ge=1, le=1000)): except Exception as exc: raise HTTPException(status_code=500, detail=f"Failed to load filter options: {exc}") from exc + if not user_can_access_privacy_services(user): + services = [s for s in services if s not in PRIVACY_SERVICES] + return { "services": services, "operations": operations, diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 6bf52a8..b67b02c 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -33,6 +33,9 @@ def client(mock_events_collection, mock_watermarks_collection, monkeypatch): monkeypatch.setattr("routes.fetch.set_watermark", lambda source, ts: None) monkeypatch.setattr("auth.AUTH_ENABLED", False) monkeypatch.setattr("routes.mcp.AUTH_ENABLED", False) + monkeypatch.setattr("config.PRIVACY_SERVICES", set()) + monkeypatch.setattr("routes.events.PRIVACY_SERVICES", set()) + monkeypatch.setattr("routes.ask.PRIVACY_SERVICES", set()) monkeypatch.setattr("database.db.command", lambda cmd: {"ok": 1} if cmd == "ping" else {}) # Mock audit trail and rules collections so tests don't wait on real MongoDB diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py index 47b3900..df93e61 100644 --- a/backend/tests/test_api.py +++ b/backend/tests/test_api.py @@ -149,6 +149,79 @@ def test_saved_searches_create_validation(client, monkeypatch): assert response.status_code == 400 +def test_privacy_filtering_events(client, mock_events_collection, monkeypatch): + monkeypatch.setattr("config.PRIVACY_SERVICES", {"Exchange", "Teams"}) + monkeypatch.setattr("routes.events.PRIVACY_SERVICES", {"Exchange", "Teams"}) + monkeypatch.setattr("auth.PRIVACY_SERVICE_ROLES", {"SecurityAdmin"}) + monkeypatch.setattr("auth.user_can_access_privacy_services", lambda claims: False) + monkeypatch.setattr("routes.events.user_can_access_privacy_services", lambda claims: False) + + mock_events_collection.insert_one( + { + "id": "evt-dir", + "timestamp": datetime.now(UTC).isoformat(), + "service": "Directory", + "operation": "Add user", + "result": "success", + "actor_display": "Alice", + "raw_text": "", + } + ) + mock_events_collection.insert_one( + { + "id": "evt-exc", + "timestamp": datetime.now(UTC).isoformat(), + "service": "Exchange", + "operation": "Send", + "result": "success", + "actor_display": "Bob", + "raw_text": "", + } + ) + + response = client.get("/api/events") + assert response.status_code == 200 + data = response.json() + ids = [e["id"] for e in data["items"]] + assert "evt-dir" in ids + assert "evt-exc" not in ids + + +def test_privacy_filter_options(client, mock_events_collection, monkeypatch): + monkeypatch.setattr("config.PRIVACY_SERVICES", {"Exchange"}) + monkeypatch.setattr("routes.events.PRIVACY_SERVICES", {"Exchange"}) + monkeypatch.setattr("auth.PRIVACY_SERVICE_ROLES", {"SecurityAdmin"}) + monkeypatch.setattr("auth.user_can_access_privacy_services", lambda claims: False) + monkeypatch.setattr("routes.events.user_can_access_privacy_services", lambda claims: False) + + response = client.get("/api/filter-options") + assert response.status_code == 200 + data = response.json() + assert "Exchange" not in data["services"] + + +def test_privacy_explain_forbidden(client, mock_events_collection, monkeypatch): + monkeypatch.setattr("config.PRIVACY_SERVICES", {"Exchange"}) + monkeypatch.setattr("routes.ask.PRIVACY_SERVICES", {"Exchange"}) + monkeypatch.setattr("auth.PRIVACY_SERVICE_ROLES", {"SecurityAdmin"}) + monkeypatch.setattr("auth.user_can_access_privacy_services", lambda claims: False) + monkeypatch.setattr("routes.ask.user_can_access_privacy_services", lambda claims: False) + + mock_events_collection.insert_one( + { + "id": "evt-exc2", + "timestamp": datetime.now(UTC).isoformat(), + "service": "Exchange", + "operation": "Send", + "result": "success", + "actor_display": "Bob", + "raw_text": "", + } + ) + response = client.post("/api/events/evt-exc2/explain") + assert response.status_code == 403 + + def test_health(client): response = client.get("/health") assert response.status_code == 200