Files
aoc/backend/routes/mcp.py
Tomas Kracmar 5122739c01
All checks were successful
CI / lint-and-test (push) Successful in 36s
feat: MCP server over SSE with OIDC auth
- Extract shared MCP tool handlers to mcp_common.py
- mcp_server.py now uses shared handlers (stdio transport for local dev)
- New routes/mcp.py: SSE transport behind existing OIDC Bearer auth
- Mount MCP ASGI app at /mcp in main.py when AI_FEATURES_ENABLED
- /mcp/sse  -> establishes SSE stream (requires valid token when auth enabled)
- /mcp/messages/ -> receives MCP client messages
- Update README with SSE MCP docs
- Add tests for mount existence, auth, and message routing
2026-04-21 07:38:12 +02:00

125 lines
3.9 KiB
Python

"""MCP server over SSE (HTTP) transport, mounted inside FastAPI with OIDC auth."""
import asyncio

import structlog
from auth import (
    AUTH_ALLOWED_GROUPS,
    AUTH_ALLOWED_ROLES,
    AUTH_ENABLED,
    _allowed,
    _decode_token,
    _get_jwks,
)
from mcp.server import Server
from mcp.server.sse import SseServerTransport
from mcp.types import TextContent, Tool
from mcp_common import (
    ASK_SCHEMA,
    GET_EVENT_SCHEMA,
    GET_SUMMARY_SCHEMA,
    SEARCH_EVENTS_SCHEMA,
    handle_ask,
    handle_get_event,
    handle_get_summary,
    handle_search_events,
)
from starlette.requests import Request
from starlette.responses import Response
# Structured logger for MCP auth and connection events in this module.
logger = structlog.get_logger("aoc.mcp")
# Low-level MCP server; the tool handlers below register on it via its decorators.
mcp_app = Server("aoc")
# SSE transport shared by all connections: mcp_asgi opens streams with it on GET /sse
# and forwards client POSTs on /messages/ to it.
# NOTE(review): "/messages/" is relative to the /mcp mount point. Confirm that the
# installed MCP SDK prefixes the request's root_path when it advertises this endpoint
# to clients; otherwise clients would POST to /messages/ at the host root.
transport = SseServerTransport("/messages/")
@mcp_app.list_tools()
async def list_tools() -> list[Tool]:
    """Return the catalogue of audit-log tools advertised to MCP clients.

    Each tool pairs a name with a human-readable description and the JSON input
    schema imported from ``mcp_common``. Order is preserved as listed below.
    """
    catalogue = (
        ("search_events", "Search audit events by entity, service, operation, or result.", SEARCH_EVENTS_SCHEMA),
        ("get_event", "Retrieve a single audit event by its ID.", GET_EVENT_SCHEMA),
        ("get_summary", "Get an aggregated summary of audit activity for the last N days.", GET_SUMMARY_SCHEMA),
        ("ask", "Ask a natural language question about audit logs. Returns a narrative answer.", ASK_SCHEMA),
    )
    return [
        Tool(name=tool_name, description=blurb, inputSchema=schema)
        for tool_name, blurb, schema in catalogue
    ]
@mcp_app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
    """Dispatch an MCP tool invocation to the matching shared handler.

    The handler is awaited with the raw ``arguments`` mapping and its result is
    returned unchanged.

    Raises:
        ValueError: if ``name`` is not one of the four registered tools.
    """
    # Built per call so handlers are resolved from module globals at call time.
    dispatch = {
        "search_events": handle_search_events,
        "get_event": handle_get_event,
        "get_summary": handle_get_summary,
        "ask": handle_ask,
    }
    handler = dispatch.get(name)
    if handler is None:
        raise ValueError(f"Unknown tool: {name}")
    return await handler(arguments)
async def _validate_auth(request: Request) -> dict | None:
    """Validate the request's ``Authorization: Bearer`` token.

    Returns:
        ``{"sub": "anonymous"}`` when auth is disabled; otherwise the decoded token
        claims, or ``None`` when the header is missing or malformed, the token is
        empty or fails verification, or the claims lack an allowed role/group.
        Unauthenticated and forbidden callers both yield ``None``.
    """
    if not AUTH_ENABLED:
        return {"sub": "anonymous"}

    scheme, _, token = request.headers.get("authorization", "").partition(" ")
    token = token.strip()
    # The auth-scheme is case-insensitive (RFC 7235 §2.1). An empty token can never
    # verify, so reject it before doing any key lookup.
    if scheme.lower() != "bearer" or not token:
        return None

    def _verify() -> dict:
        # Same order as before: fetch the key set, then decode/verify the token.
        jwks = _get_jwks()
        return _decode_token(token, jwks)

    try:
        # Both helpers are synchronous. Run them in a worker thread so a slow key fetch
        # cannot stall the event loop that serves every open SSE stream. (Whether
        # _get_jwks hits the network is not visible here; confirm in auth.py.)
        claims = await asyncio.to_thread(_verify)
    except Exception as exc:  # auth boundary: any verification failure means "not authenticated"
        logger.warning("MCP auth failed", error=str(exc))
        return None

    if not _allowed(claims, AUTH_ALLOWED_ROLES, AUTH_ALLOWED_GROUPS):
        logger.warning("MCP auth forbidden", sub=claims.get("sub"))
        return None
    return claims
async def mcp_asgi(scope: dict, receive, send):
    """ASGI application for MCP over SSE, mounted under /mcp in FastAPI.

    Every HTTP request is authenticated with ``_validate_auth`` first. A failure
    gets a 401 carrying a ``WWW-Authenticate: Bearer`` challenge. Authenticated
    requests are routed on the path relative to the mount point:

    * ``GET /sse``: open the SSE stream and run ``mcp_app`` on it until the
      stream ends.
    * ``POST /messages/``: hand the client message to the SSE transport.
    * anything else: 404.

    Non-HTTP scopes are ignored (the function returns without responding).
    """
    if scope["type"] != "http":
        return

    claims = await _validate_auth(Request(scope, receive))
    if claims is None:
        # RFC 7235 §3.1: a 401 response MUST include a WWW-Authenticate challenge,
        # which tells the client to retry with a Bearer token.
        response = Response("Unauthorized", status_code=401, headers={"WWW-Authenticate": "Bearer"})
        await response(scope, receive, send)
        return

    # Accept both a full path (mount prefix carried in root_path) and a path that
    # has already had the prefix stripped.
    relative_path = scope.get("path", "").removeprefix(scope.get("root_path", ""))
    method = scope.get("method", "")

    if relative_path == "/sse" and method == "GET":
        sub = claims.get("sub", "unknown")
        logger.info("MCP SSE connection established", sub=sub)
        try:
            async with transport.connect_sse(scope, receive, send) as (read_stream, write_stream):
                await mcp_app.run(read_stream, write_stream, mcp_app.create_initialization_options())
        finally:
            # Runs on normal completion, errors and cancellation (client disconnect).
            logger.info("MCP SSE connection closed", sub=sub)
    elif relative_path == "/messages/" and method == "POST":
        # NOTE(review): the POST is authenticated but not checked against the identity
        # that opened the SSE session; this presumably relies on the transport's
        # session IDs being unguessable. Confirm.
        await transport.handle_post_message(scope, receive, send)
    else:
        response = Response("Not found", status_code=404)
        await response(scope, receive, send)