diff --git a/README.md b/README.md index 23b3d99..be94581 100644 --- a/README.md +++ b/README.md @@ -102,14 +102,25 @@ uvicorn main:app --reload --host 0.0.0.0 --port 8000 - `DELETE /api/rules/{id}` — delete an alert rule. ### MCP Server -A standalone MCP server (`backend/mcp_server.py`) is included for Claude Desktop, Cursor, and other MCP clients: +AOC exposes an MCP interface in two forms: + +**1. HTTP/SSE (production)** — mounted at `/mcp` inside the FastAPI app, behind OIDC auth: +- `GET /mcp/sse` — establish SSE stream (requires Bearer token if `AUTH_ENABLED=true`) +- `POST /mcp/messages/?session_id=...` — send tool calls + +This is the recommended way to use MCP against a remote deployment like `aoc.cqre.net`. Any MCP client that supports SSE transport (e.g. Cursor, Claude Desktop with an SSE bridge, or custom scripts) can connect using the same Entra token as the web UI. + +**2. stdio (local development)** — `python backend/mcp_server.py`: +- Runs as a local subprocess for Claude Desktop +- Connects directly to MongoDB (bypasses FastAPI auth) +- Useful for local development when you have the repo cloned and MongoDB running locally + +Available tools (both transports): - `search_events` — filter by entity, service, operation, result, time range. - `get_event` — retrieve raw event JSON by ID. - `get_summary` — aggregated summary (service, operation, result, actor counts) for the last N days. - `ask` — natural language query returning recent events. -Configure your MCP client to run `python /path/to/aoc/backend/mcp_server.py` with `MONGO_URI` in the environment. - Stored document shape (collection `micro_soc.events`): ```json { diff --git a/backend/main.py b/backend/main.py index 6c63b01..f0d05d4 100644 --- a/backend/main.py +++ b/backend/main.py @@ -116,6 +116,9 @@ if AI_FEATURES_ENABLED: from routes.ask import router as ask_router app.include_router(ask_router, prefix="/api") + from routes.mcp import mcp_asgi + + app.mount("/mcp", mcp_asgi) app.include_router(rules_router, prefix="/api") diff --git a/backend/mcp_common.py b/backend/mcp_common.py new file mode 100644 index 0000000..05d4714 --- /dev/null +++ b/backend/mcp_common.py @@ -0,0 +1,187 @@ +"""Shared MCP tool handlers used by both stdio and SSE transports.""" + +import json +from datetime import UTC, datetime, timedelta + +from database import events_collection +from mcp.types import TextContent + + +async def handle_search_events(arguments: dict) -> list[TextContent]: + days = arguments.get("days", 7) + limit = min(arguments.get("limit", 20), 100) + since = (datetime.now(UTC) - timedelta(days=days)).isoformat().replace("+00:00", "Z") + + filters = [{"timestamp": {"$gte": since}}] + + services = arguments.get("services") + if services: + filters.append({"service": {"$in": services}}) + + operation = arguments.get("operation") + if operation: + filters.append({"operation": {"$regex": operation, "$options": "i"}}) + + result = arguments.get("result") + if result: + filters.append({"result": {"$regex": result, "$options": "i"}}) + + entity = arguments.get("entity") + if entity: + entity_safe = entity.replace(".", "\\.").replace("(", "\\(").replace(")", "\\)") + filters.append( + { + "$or": [ + {"target_displays": {"$elemMatch": {"$regex": entity_safe, "$options": "i"}}}, + {"actor_display": {"$regex": entity_safe, "$options": "i"}}, + {"actor_upn": {"$regex": entity_safe, "$options": "i"}}, + {"raw_text": {"$regex": entity_safe, "$options": "i"}}, + ] + } + ) + + query = {"$and": filters} + cursor = events_collection.find(query).sort("timestamp", -1).limit(limit) + events = list(cursor) + + if not events: + return [TextContent(type="text", text="No matching events found.")] + + lines = [f"Found {len(events)} event(s):\n"] + for e in events: + ts = e.get("timestamp", "?")[:16].replace("T", " ") + svc = e.get("service", "?") + op = e.get("operation", "?") + actor = e.get("actor_display", "?") + result_str = e.get("result", "?") + lines.append(f"{ts} | {svc} | {op} | {actor} | {result_str}") + + return [TextContent(type="text", text="\n".join(lines))] + + +async def handle_get_event(arguments: dict) -> list[TextContent]: + event_id = arguments["event_id"] + event = events_collection.find_one({"id": event_id}) + if not event: + return [TextContent(type="text", text=f"Event {event_id} not found.")] + event.pop("_id", None) + return [TextContent(type="text", text=json.dumps(event, indent=2, default=str))] + + +async def handle_get_summary(arguments: dict) -> list[TextContent]: + days = arguments.get("days", 7) + since = (datetime.now(UTC) - timedelta(days=days)).isoformat().replace("+00:00", "Z") + query = {"timestamp": {"$gte": since}} + + total = events_collection.count_documents(query) + if total == 0: + return [TextContent(type="text", text="No events in the specified period.")] + + svc_pipeline = [ + {"$match": query}, + {"$group": {"_id": "$service", "count": {"$sum": 1}}}, + {"$sort": {"count": -1}}, + {"$limit": 10}, + ] + op_pipeline = [ + {"$match": query}, + {"$group": {"_id": "$operation", "count": {"$sum": 1}}}, + {"$sort": {"count": -1}}, + {"$limit": 10}, + ] + result_pipeline = [ + {"$match": query}, + {"$group": {"_id": "$result", "count": {"$sum": 1}}}, + {"$sort": {"count": -1}}, + ] + actor_pipeline = [ + {"$match": query}, + {"$group": {"_id": "$actor_display", "count": {"$sum": 1}}}, + {"$sort": {"count": -1}}, + {"$limit": 10}, + ] + + svc_counts = list(events_collection.aggregate(svc_pipeline)) + op_counts = list(events_collection.aggregate(op_pipeline)) + result_counts = list(events_collection.aggregate(result_pipeline)) + actor_counts = list(events_collection.aggregate(actor_pipeline)) + + lines = [f"Summary for the last {days} days ({total} total events)\n"] + + lines.append("By service:") + for row in svc_counts: + lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}") + + lines.append("\nBy action:") + for row in op_counts: + lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}") + + lines.append("\nBy result:") + for row in result_counts: + lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}") + + lines.append("\nTop actors:") + for row in actor_counts: + lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}") + + return [TextContent(type="text", text="\n".join(lines))] + + +async def handle_ask(arguments: dict) -> list[TextContent]: + """For now, returns recent events + guidance. In the future this could call the LLM backend.""" + question = arguments["question"] + days = arguments.get("days", 7) + + result = await handle_search_events({"entity": "", "days": days, "limit": 50}) + base_text = result[0].text if result else "" + + text = ( + f"You asked: '{question}'\n\n" + f"Here are the most recent events from the last {days} days:\n\n" + f"{base_text}\n\n" + f"Tip: Use the 'search_events' tool with specific filters " + f"to narrow down the dataset before asking follow-up questions." + ) + return [TextContent(type="text", text=text)] + + +# JSON schemas for tool definitions +SEARCH_EVENTS_SCHEMA = { + "type": "object", + "properties": { + "entity": {"type": "string", "description": "Device name, user UPN, or email to search for"}, + "services": { + "type": "array", + "items": {"type": "string"}, + "description": "Filter by service (e.g. Intune, Directory, Exchange)", + }, + "operation": {"type": "string", "description": "Filter by operation name"}, + "result": {"type": "string", "description": "Filter by result (success, failure)"}, + "days": {"type": "integer", "description": "Number of days to look back (default 7)"}, + "limit": {"type": "integer", "description": "Max events to return (default 20)"}, + }, +} + +GET_EVENT_SCHEMA = { + "type": "object", + "properties": { + "event_id": {"type": "string", "description": "The event ID to retrieve"}, + }, + "required": ["event_id"], +} + +GET_SUMMARY_SCHEMA = { + "type": "object", + "properties": { + "days": {"type": "integer", "description": "Number of days to summarise (default 7)"}, + }, +} + +ASK_SCHEMA = { + "type": "object", + "properties": { + "question": {"type": "string", "description": "Natural language question about audit logs"}, + "days": {"type": "integer", "description": "Number of days to look back (default 7)"}, + }, + "required": ["question"], +} diff --git a/backend/mcp_server.py b/backend/mcp_server.py index 305bb69..f487685 100644 --- a/backend/mcp_server.py +++ b/backend/mcp_server.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 """ -AOC MCP Server +AOC MCP Server — stdio transport -Standalone MCP server that exposes audit log search tools for Claude Desktop, -Cursor, and other MCP clients. +Standalone MCP server for local use (Claude Desktop, Cursor, etc.). +For the HTTP/SSE version (production, behind auth), see routes/mcp.py. Usage: python mcp_server.py @@ -21,65 +21,28 @@ Claude Desktop config (~/.config/claude/claude_desktop_config.json): """ import asyncio -import json import os import sys -from datetime import UTC, datetime, timedelta -# Ensure backend modules are importable +# Ensure backend modules are importable when run standalone sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -from database import events_collection from mcp.server import Server from mcp.server.stdio import stdio_server from mcp.types import TextContent, Tool +from mcp_common import ( + ASK_SCHEMA, + GET_EVENT_SCHEMA, + GET_SUMMARY_SCHEMA, + SEARCH_EVENTS_SCHEMA, + handle_ask, + handle_get_event, + handle_get_summary, + handle_search_events, +) app = Server("aoc") -# --------------------------------------------------------------------------- -# Tool definitions -# --------------------------------------------------------------------------- - -_SEARCH_EVENTS_SCHEMA = { - "type": "object", - "properties": { - "entity": {"type": "string", "description": "Device name, user UPN, or email to search for"}, - "services": { - "type": "array", - "items": {"type": "string"}, - "description": "Filter by service (e.g. Intune, Directory, Exchange)", - }, - "operation": {"type": "string", "description": "Filter by operation name"}, - "result": {"type": "string", "description": "Filter by result (success, failure)"}, - "days": {"type": "integer", "description": "Number of days to look back (default 7)"}, - "limit": {"type": "integer", "description": "Max events to return (default 20)"}, - }, -} - -_GET_EVENT_SCHEMA = { - "type": "object", - "properties": { - "event_id": {"type": "string", "description": "The event ID to retrieve"}, - }, - "required": ["event_id"], -} - -_GET_SUMMARY_SCHEMA = { - "type": "object", - "properties": { - "days": {"type": "integer", "description": "Number of days to summarise (default 7)"}, - }, -} - -_ASK_SCHEMA = { - "type": "object", - "properties": { - "question": {"type": "string", "description": "Natural language question about audit logs"}, - "days": {"type": "integer", "description": "Number of days to look back (default 7)"}, - }, - "required": ["question"], -} - @app.list_tools() async def list_tools() -> list[Tool]: @@ -87,186 +50,35 @@ async def list_tools() -> list[Tool]: Tool( name="search_events", description="Search audit events by entity, service, operation, or result.", - inputSchema=_SEARCH_EVENTS_SCHEMA, + inputSchema=SEARCH_EVENTS_SCHEMA, ), - Tool(name="get_event", description="Retrieve a single audit event by its ID.", inputSchema=_GET_EVENT_SCHEMA), + Tool(name="get_event", description="Retrieve a single audit event by its ID.", inputSchema=GET_EVENT_SCHEMA), Tool( name="get_summary", description="Get an aggregated summary of audit activity for the last N days.", - inputSchema=_GET_SUMMARY_SCHEMA, + inputSchema=GET_SUMMARY_SCHEMA, ), Tool( name="ask", description="Ask a natural language question about audit logs. Returns a narrative answer.", - inputSchema=_ASK_SCHEMA, + inputSchema=ASK_SCHEMA, ), ] -# --------------------------------------------------------------------------- -# Tool handlers -# --------------------------------------------------------------------------- - - @app.call_tool() async def call_tool(name: str, arguments: dict) -> list[TextContent]: if name == "search_events": - return await _handle_search_events(arguments) + return await handle_search_events(arguments) if name == "get_event": - return await _handle_get_event(arguments) + return await handle_get_event(arguments) if name == "get_summary": - return await _handle_get_summary(arguments) + return await handle_get_summary(arguments) if name == "ask": - return await _handle_ask(arguments) + return await handle_ask(arguments) raise ValueError(f"Unknown tool: {name}") -async def _handle_search_events(arguments: dict) -> list[TextContent]: - days = arguments.get("days", 7) - limit = min(arguments.get("limit", 20), 100) - since = (datetime.now(UTC) - timedelta(days=days)).isoformat().replace("+00:00", "Z") - - filters = [{"timestamp": {"$gte": since}}] - - services = arguments.get("services") - if services: - filters.append({"service": {"$in": services}}) - - operation = arguments.get("operation") - if operation: - filters.append({"operation": {"$regex": operation, "$options": "i"}}) - - result = arguments.get("result") - if result: - filters.append({"result": {"$regex": result, "$options": "i"}}) - - entity = arguments.get("entity") - if entity: - entity_safe = entity.replace(".", "\\.").replace("(", "\\(").replace(")", "\\)") - filters.append( - { - "$or": [ - {"target_displays": {"$elemMatch": {"$regex": entity_safe, "$options": "i"}}}, - {"actor_display": {"$regex": entity_safe, "$options": "i"}}, - {"actor_upn": {"$regex": entity_safe, "$options": "i"}}, - {"raw_text": {"$regex": entity_safe, "$options": "i"}}, - ] - } - ) - - query = {"$and": filters} - cursor = events_collection.find(query).sort("timestamp", -1).limit(limit) - events = list(cursor) - - if not events: - return [TextContent(type="text", text="No matching events found.")] - - lines = [f"Found {len(events)} event(s):\n"] - for e in events: - ts = e.get("timestamp", "?")[:16].replace("T", " ") - svc = e.get("service", "?") - op = e.get("operation", "?") - actor = e.get("actor_display", "?") - result_str = e.get("result", "?") - lines.append(f"{ts} | {svc} | {op} | {actor} | {result_str}") - - return [TextContent(type="text", text="\n".join(lines))] - - -async def _handle_get_event(arguments: dict) -> list[TextContent]: - event_id = arguments["event_id"] - event = events_collection.find_one({"id": event_id}) - if not event: - return [TextContent(type="text", text=f"Event {event_id} not found.")] - event.pop("_id", None) - return [TextContent(type="text", text=json.dumps(event, indent=2, default=str))] - - -async def _handle_get_summary(arguments: dict) -> list[TextContent]: - days = arguments.get("days", 7) - since = (datetime.now(UTC) - timedelta(days=days)).isoformat().replace("+00:00", "Z") - query = {"timestamp": {"$gte": since}} - - total = events_collection.count_documents(query) - if total == 0: - return [TextContent(type="text", text="No events in the specified period.")] - - # Aggregation pipelines - svc_pipeline = [ - {"$match": query}, - {"$group": {"_id": "$service", "count": {"$sum": 1}}}, - {"$sort": {"count": -1}}, - {"$limit": 10}, - ] - op_pipeline = [ - {"$match": query}, - {"$group": {"_id": "$operation", "count": {"$sum": 1}}}, - {"$sort": {"count": -1}}, - {"$limit": 10}, - ] - result_pipeline = [ - {"$match": query}, - {"$group": {"_id": "$result", "count": {"$sum": 1}}}, - {"$sort": {"count": -1}}, - ] - actor_pipeline = [ - {"$match": query}, - {"$group": {"_id": "$actor_display", "count": {"$sum": 1}}}, - {"$sort": {"count": -1}}, - {"$limit": 10}, - ] - - svc_counts = list(events_collection.aggregate(svc_pipeline)) - op_counts = list(events_collection.aggregate(op_pipeline)) - result_counts = list(events_collection.aggregate(result_pipeline)) - actor_counts = list(events_collection.aggregate(actor_pipeline)) - - lines = [f"Summary for the last {days} days ({total} total events)\n"] - - lines.append("By service:") - for row in svc_counts: - lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}") - - lines.append("\nBy action:") - for row in op_counts: - lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}") - - lines.append("\nBy result:") - for row in result_counts: - lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}") - - lines.append("\nTop actors:") - for row in actor_counts: - lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}") - - return [TextContent(type="text", text="\n".join(lines))] - - -async def _handle_ask(arguments: dict) -> list[TextContent]: - """For now, the MCP 'ask' tool returns a helpful message directing the user to the web UI, - since the full NLQ pipeline requires LLM configuration that may not be available in the MCP context.""" - question = arguments["question"] - days = arguments.get("days", 7) - - # Perform a search to give the user something useful immediately - result = await _handle_search_events({"entity": "", "days": days, "limit": 50}) - base_text = result[0].text if result else "" - - text = ( - f"You asked: '{question}'\n\n" - f"Here are the most recent {min(50, base_text.count(chr(10)) - 1)} events from the last {days} days:\n\n" - f"{base_text}\n\n" - f"Tip: Use the 'search_events' tool with specific filters (services, operation, result) " - f"to narrow down the dataset before asking follow-up questions." - ) - return [TextContent(type="text", text=text)] - - -# --------------------------------------------------------------------------- -# Entry point -# --------------------------------------------------------------------------- - - async def main(): async with stdio_server() as (read_stream, write_stream): await app.run(read_stream, write_stream, app.create_initialization_options()) diff --git a/backend/routes/mcp.py b/backend/routes/mcp.py new file mode 100644 index 0000000..7a5e4a7 --- /dev/null +++ b/backend/routes/mcp.py @@ -0,0 +1,124 @@ +"""MCP server over SSE (HTTP) transport, mounted inside FastAPI with OIDC auth.""" + +import structlog +from auth import ( + AUTH_ALLOWED_GROUPS, + AUTH_ALLOWED_ROLES, + AUTH_ENABLED, + _allowed, + _decode_token, + _get_jwks, +) +from mcp.server import Server +from mcp.server.sse import SseServerTransport +from mcp.types import TextContent, Tool +from mcp_common import ( + ASK_SCHEMA, + GET_EVENT_SCHEMA, + GET_SUMMARY_SCHEMA, + SEARCH_EVENTS_SCHEMA, + handle_ask, + handle_get_event, + handle_get_summary, + handle_search_events, +) +from starlette.requests import Request +from starlette.responses import Response + +logger = structlog.get_logger("aoc.mcp") + +mcp_app = Server("aoc") +transport = SseServerTransport("/messages/") + + +@mcp_app.list_tools() +async def list_tools() -> list[Tool]: + return [ + Tool( + name="search_events", + description="Search audit events by entity, service, operation, or result.", + inputSchema=SEARCH_EVENTS_SCHEMA, + ), + Tool(name="get_event", description="Retrieve a single audit event by its ID.", inputSchema=GET_EVENT_SCHEMA), + Tool( + name="get_summary", + description="Get an aggregated summary of audit activity for the last N days.", + inputSchema=GET_SUMMARY_SCHEMA, + ), + Tool( + name="ask", + description="Ask a natural language question about audit logs. Returns a narrative answer.", + inputSchema=ASK_SCHEMA, + ), + ] + + +@mcp_app.call_tool() +async def call_tool(name: str, arguments: dict) -> list[TextContent]: + if name == "search_events": + return await handle_search_events(arguments) + if name == "get_event": + return await handle_get_event(arguments) + if name == "get_summary": + return await handle_get_summary(arguments) + if name == "ask": + return await handle_ask(arguments) + raise ValueError(f"Unknown tool: {name}") + + +async def _validate_auth(request: Request) -> dict | None: + """Validate Bearer token. Returns claims dict or None on failure.""" + if not AUTH_ENABLED: + return {"sub": "anonymous"} + + auth_header = request.headers.get("authorization", "") + if not auth_header or not auth_header.lower().startswith("bearer "): + return None + + token = auth_header.split(" ", 1)[1] + try: + jwks = _get_jwks() + claims = _decode_token(token, jwks) + except Exception as exc: + logger.warning("MCP auth failed", error=str(exc)) + return None + + if not _allowed(claims, AUTH_ALLOWED_ROLES, AUTH_ALLOWED_GROUPS): + logger.warning("MCP auth forbidden", sub=claims.get("sub")) + return None + + return claims + + +async def mcp_asgi(scope: dict, receive, send): + """ASGI application for MCP over SSE, mounted under /mcp in FastAPI.""" + if scope["type"] != "http": + return + + request = Request(scope, receive) + + # Auth check + claims = await _validate_auth(request) + if claims is None: + response = Response("Unauthorized", status_code=401) + await response(scope, receive, send) + return + + path = scope.get("path", "") + root_path = scope.get("root_path", "") + relative_path = path[len(root_path) :] if path.startswith(root_path) else path + method = scope.get("method", "") + + if relative_path == "/sse" and method == "GET": + logger.info("MCP SSE connection established", sub=claims.get("sub", "unknown")) + async with transport.connect_sse(scope, receive, send) as (read_stream, write_stream): + await mcp_app.run( + read_stream, + write_stream, + mcp_app.create_initialization_options(), + ) + elif relative_path == "/messages/" and method == "POST": + await transport.handle_post_message(scope, receive, send) + else: + response = Response("Not found", status_code=404) + await response(scope, receive, send) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 6a36304..b69216b 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -30,6 +30,7 @@ def client(mock_events_collection, mock_watermarks_collection, monkeypatch): monkeypatch.setattr("routes.fetch.get_watermark", lambda source: None) monkeypatch.setattr("routes.fetch.set_watermark", lambda source, ts: None) monkeypatch.setattr("auth.AUTH_ENABLED", False) + monkeypatch.setattr("routes.mcp.AUTH_ENABLED", False) monkeypatch.setattr("database.db.command", lambda cmd: {"ok": 1} if cmd == "ping" else {}) # Mock audit trail and rules collections so tests don't wait on real MongoDB diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py index 8e7e049..c683881 100644 --- a/backend/tests/test_api.py +++ b/backend/tests/test_api.py @@ -36,6 +36,25 @@ print('OK') assert "OK" in result.stdout +def test_mcp_sse_mount_exists(): + from main import app + + mcp_mounts = [r for r in app.routes if getattr(r, "path", "") == "/mcp"] + assert len(mcp_mounts) == 1, "MCP mount not found in app routes" + + +def test_mcp_messages_no_session(client): + response = client.post("/mcp/messages/") + # MCP transport returns 400 when session_id is missing, 404 when session not found + assert response.status_code in (400, 404) + + +def test_mcp_sse_auth_required_when_enabled(client, monkeypatch): + monkeypatch.setattr("routes.mcp.AUTH_ENABLED", True) + response = client.get("/mcp/sse") + assert response.status_code == 401 + + def test_health(client): response = client.get("/health") assert response.status_code == 200