All checks were successful
CI / lint-and-test (push) Successful in 36s
- Extract shared MCP tool handlers to mcp_common.py
- mcp_server.py now uses the shared handlers (stdio transport for local dev)
- New routes/mcp.py: SSE transport behind the existing OIDC Bearer auth
- Mount the MCP ASGI app at /mcp in main.py when AI_FEATURES_ENABLED
  - /mcp/sse -> establishes the SSE stream (requires a valid token when auth is enabled)
  - /mcp/messages/ -> receives MCP client messages
- Update README with SSE MCP docs
- Add tests for mount existence, auth, and message routing
125 lines
3.9 KiB
Python
125 lines
3.9 KiB
Python
"""MCP server over SSE (HTTP) transport, mounted inside FastAPI with OIDC auth."""
|
|
|
|
import structlog
|
|
from auth import (
|
|
AUTH_ALLOWED_GROUPS,
|
|
AUTH_ALLOWED_ROLES,
|
|
AUTH_ENABLED,
|
|
_allowed,
|
|
_decode_token,
|
|
_get_jwks,
|
|
)
|
|
from mcp.server import Server
|
|
from mcp.server.sse import SseServerTransport
|
|
from mcp.types import TextContent, Tool
|
|
from mcp_common import (
|
|
ASK_SCHEMA,
|
|
GET_EVENT_SCHEMA,
|
|
GET_SUMMARY_SCHEMA,
|
|
SEARCH_EVENTS_SCHEMA,
|
|
handle_ask,
|
|
handle_get_event,
|
|
handle_get_summary,
|
|
handle_search_events,
|
|
)
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
|
|
# The one MCP server instance. The tool handlers below attach themselves to it
# through its decorators, and mcp_asgi runs it once per SSE connection.
mcp_app: Server = Server("aoc")

# SSE transport. "/messages/" must match the POST route that mcp_asgi
# dispatches to transport.handle_post_message.
# NOTE(review): clients reach it under the /mcp mount only if the SDK
# prefixes the request's root_path when advertising the endpoint. Confirm
# against the installed mcp version.
transport: SseServerTransport = SseServerTransport("/messages/")

# Structured logger shared by the auth and connection handlers in this module.
logger = structlog.get_logger("aoc.mcp")
|
|
|
|
|
|
@mcp_app.list_tools()
async def list_tools() -> list[Tool]:
    """Advertise the audit-log tools this server exposes to MCP clients.

    Each tool pairs a name and a human-readable description with its JSON
    input schema imported from ``mcp_common``.
    """
    # (tool name, description, input schema) in the order clients see them.
    catalogue = (
        (
            "search_events",
            "Search audit events by entity, service, operation, or result.",
            SEARCH_EVENTS_SCHEMA,
        ),
        (
            "get_event",
            "Retrieve a single audit event by its ID.",
            GET_EVENT_SCHEMA,
        ),
        (
            "get_summary",
            "Get an aggregated summary of audit activity for the last N days.",
            GET_SUMMARY_SCHEMA,
        ),
        (
            "ask",
            "Ask a natural language question about audit logs. Returns a narrative answer.",
            ASK_SCHEMA,
        ),
    )
    return [
        Tool(name=tool_name, description=blurb, inputSchema=schema)
        for tool_name, blurb, schema in catalogue
    ]
|
|
|
|
|
|
@mcp_app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
    """Route an MCP tool call to the matching shared handler.

    Args:
        name: Tool name, one of those advertised by ``list_tools``.
        arguments: Tool arguments, passed through unchanged to the handler.

    Returns:
        Whatever the selected ``mcp_common`` handler returns.

    Raises:
        ValueError: If ``name`` does not match a known tool.
    """
    # Built per call so handler names are resolved at call time, as before.
    dispatch = {
        "search_events": handle_search_events,
        "get_event": handle_get_event,
        "get_summary": handle_get_summary,
        "ask": handle_ask,
    }
    handler = dispatch.get(name)
    if handler is None:
        raise ValueError(f"Unknown tool: {name}")
    return await handler(arguments)
|
|
|
|
|
|
async def _validate_auth(request: Request) -> dict | None:
    """Validate the request's Bearer token.

    Args:
        request: Incoming request; only its ``Authorization`` header is read.

    Returns:
        The decoded token claims when the token decodes and the caller passes
        the role/group allow-list. Returns ``{"sub": "anonymous"}`` when auth
        is disabled. Returns ``None`` on any failure: missing or malformed
        header, empty token, decode/verification error, or not allowed.
    """
    if not AUTH_ENABLED:
        return {"sub": "anonymous"}

    # RFC 6750 §2.1: credentials = "Bearer" 1*SP b64token. The scheme is
    # case-insensitive and may be followed by more than one space, so split on
    # the first space and strip the remainder rather than assuming exactly one.
    scheme, _, token = request.headers.get("authorization", "").partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return None

    try:
        # NOTE(review): if _get_jwks() fetches keys synchronously over the
        # network, this blocks the event loop. Confirm that it is cached.
        jwks = _get_jwks()
        claims = _decode_token(token, jwks)
    except Exception as exc:
        # Any decode/verification error is an auth failure, not a server error.
        logger.warning("MCP auth failed", error=str(exc))
        return None

    if not _allowed(claims, AUTH_ALLOWED_ROLES, AUTH_ALLOWED_GROUPS):
        logger.warning("MCP auth forbidden", sub=claims.get("sub"))
        return None

    return claims
|
|
|
|
|
|
async def mcp_asgi(scope: dict, receive, send):
    """ASGI application for MCP over SSE, mounted under /mcp in FastAPI.

    Routes, relative to the mount point:

    * ``GET /sse`` opens the SSE stream via ``transport.connect_sse`` and runs
      ``mcp_app`` over the resulting read/write streams.
    * ``POST /messages/`` hands the request to
      ``transport.handle_post_message``.

    Every HTTP request must pass ``_validate_auth`` first. On failure the
    response is a 401 with a ``WWW-Authenticate: Bearer`` challenge. Any
    other path/method combination gets a 404. Non-HTTP scopes return without
    sending anything.
    """
    if scope["type"] != "http":
        return

    request = Request(scope, receive)

    claims = await _validate_auth(request)
    if claims is None:
        # RFC 6750 §3: a 401 from a Bearer-protected resource must include a
        # WWW-Authenticate challenge naming the expected scheme.
        response = Response(
            "Unauthorized",
            status_code=401,
            headers={"WWW-Authenticate": "Bearer"},
        )
        await response(scope, receive, send)
        return

    # Depending on how the app is mounted, ``path`` may still carry the mount
    # prefix recorded in ``root_path``. Strip it so routing sees "/sse" and
    # "/messages/" either way.
    path = scope.get("path", "")
    root_path = scope.get("root_path", "")
    relative_path = path[len(root_path):] if path.startswith(root_path) else path
    method = scope.get("method", "")

    if relative_path == "/sse" and method == "GET":
        # NOTE(review): the token is checked only here, when the stream opens.
        # Nothing in this module re-validates it for the life of the session.
        # Confirm that is acceptable given token lifetimes.
        logger.info("MCP SSE connection established", sub=claims.get("sub", "unknown"))
        async with transport.connect_sse(scope, receive, send) as (read_stream, write_stream):
            await mcp_app.run(
                read_stream,
                write_stream,
                mcp_app.create_initialization_options(),
            )
    elif relative_path == "/messages/" and method == "POST":
        # NOTE(review): the poster is authenticated, but nothing here checks
        # that it is the same principal that opened the target SSE session.
        await transport.handle_post_message(scope, receive, send)
    else:
        response = Response("Not found", status_code=404)
        await response(scope, receive, send)
|