feat: MCP server over SSE with OIDC auth
All checks were successful
CI / lint-and-test (push) Successful in 36s
All checks were successful
CI / lint-and-test (push) Successful in 36s
- Extract shared MCP tool handlers to mcp_common.py - mcp_server.py now uses shared handlers (stdio transport for local dev) - New routes/mcp.py: SSE transport behind existing OIDC Bearer auth - Mount MCP ASGI app at /mcp in main.py when AI_FEATURES_ENABLED - /mcp/sse -> establishes SSE stream (requires valid token when auth enabled) - /mcp/messages/ -> receives MCP client messages - Update README with SSE MCP docs - Add tests for mount existence, auth, and message routing
This commit is contained in:
124
backend/routes/mcp.py
Normal file
124
backend/routes/mcp.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""MCP server over SSE (HTTP) transport, mounted inside FastAPI with OIDC auth."""
|
||||
|
||||
import structlog
|
||||
from auth import (
|
||||
AUTH_ALLOWED_GROUPS,
|
||||
AUTH_ALLOWED_ROLES,
|
||||
AUTH_ENABLED,
|
||||
_allowed,
|
||||
_decode_token,
|
||||
_get_jwks,
|
||||
)
|
||||
from mcp.server import Server
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.types import TextContent, Tool
|
||||
from mcp_common import (
|
||||
ASK_SCHEMA,
|
||||
GET_EVENT_SCHEMA,
|
||||
GET_SUMMARY_SCHEMA,
|
||||
SEARCH_EVENTS_SCHEMA,
|
||||
handle_ask,
|
||||
handle_get_event,
|
||||
handle_get_summary,
|
||||
handle_search_events,
|
||||
)
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
logger = structlog.get_logger("aoc.mcp")
|
||||
|
||||
mcp_app = Server("aoc")
|
||||
transport = SseServerTransport("/messages/")
|
||||
|
||||
|
||||
@mcp_app.list_tools()
|
||||
async def list_tools() -> list[Tool]:
|
||||
return [
|
||||
Tool(
|
||||
name="search_events",
|
||||
description="Search audit events by entity, service, operation, or result.",
|
||||
inputSchema=SEARCH_EVENTS_SCHEMA,
|
||||
),
|
||||
Tool(name="get_event", description="Retrieve a single audit event by its ID.", inputSchema=GET_EVENT_SCHEMA),
|
||||
Tool(
|
||||
name="get_summary",
|
||||
description="Get an aggregated summary of audit activity for the last N days.",
|
||||
inputSchema=GET_SUMMARY_SCHEMA,
|
||||
),
|
||||
Tool(
|
||||
name="ask",
|
||||
description="Ask a natural language question about audit logs. Returns a narrative answer.",
|
||||
inputSchema=ASK_SCHEMA,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@mcp_app.call_tool()
|
||||
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
if name == "search_events":
|
||||
return await handle_search_events(arguments)
|
||||
if name == "get_event":
|
||||
return await handle_get_event(arguments)
|
||||
if name == "get_summary":
|
||||
return await handle_get_summary(arguments)
|
||||
if name == "ask":
|
||||
return await handle_ask(arguments)
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
|
||||
async def _validate_auth(request: Request) -> dict | None:
|
||||
"""Validate Bearer token. Returns claims dict or None on failure."""
|
||||
if not AUTH_ENABLED:
|
||||
return {"sub": "anonymous"}
|
||||
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
return None
|
||||
|
||||
token = auth_header.split(" ", 1)[1]
|
||||
try:
|
||||
jwks = _get_jwks()
|
||||
claims = _decode_token(token, jwks)
|
||||
except Exception as exc:
|
||||
logger.warning("MCP auth failed", error=str(exc))
|
||||
return None
|
||||
|
||||
if not _allowed(claims, AUTH_ALLOWED_ROLES, AUTH_ALLOWED_GROUPS):
|
||||
logger.warning("MCP auth forbidden", sub=claims.get("sub"))
|
||||
return None
|
||||
|
||||
return claims
|
||||
|
||||
|
||||
async def mcp_asgi(scope: dict, receive, send):
|
||||
"""ASGI application for MCP over SSE, mounted under /mcp in FastAPI."""
|
||||
if scope["type"] != "http":
|
||||
return
|
||||
|
||||
request = Request(scope, receive)
|
||||
|
||||
# Auth check
|
||||
claims = await _validate_auth(request)
|
||||
if claims is None:
|
||||
response = Response("Unauthorized", status_code=401)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
path = scope.get("path", "")
|
||||
root_path = scope.get("root_path", "")
|
||||
relative_path = path[len(root_path) :] if path.startswith(root_path) else path
|
||||
method = scope.get("method", "")
|
||||
|
||||
if relative_path == "/sse" and method == "GET":
|
||||
logger.info("MCP SSE connection established", sub=claims.get("sub", "unknown"))
|
||||
async with transport.connect_sse(scope, receive, send) as (read_stream, write_stream):
|
||||
await mcp_app.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
mcp_app.create_initialization_options(),
|
||||
)
|
||||
elif relative_path == "/messages/" and method == "POST":
|
||||
await transport.handle_post_message(scope, receive, send)
|
||||
else:
|
||||
response = Response("Not found", status_code=404)
|
||||
await response(scope, receive, send)
|
||||
Reference in New Issue
Block a user