"""MCP server over SSE (HTTP) transport, mounted inside FastAPI with OIDC auth.""" import structlog from auth import ( AUTH_ALLOWED_GROUPS, AUTH_ALLOWED_ROLES, AUTH_ENABLED, _allowed, _decode_token, _get_jwks, ) from mcp.server import Server from mcp.server.sse import SseServerTransport from mcp.types import TextContent, Tool from mcp_common import ( ASK_SCHEMA, GET_EVENT_SCHEMA, GET_SUMMARY_SCHEMA, SEARCH_EVENTS_SCHEMA, handle_ask, handle_get_event, handle_get_summary, handle_search_events, ) from starlette.requests import Request from starlette.responses import Response logger = structlog.get_logger("aoc.mcp") mcp_app = Server("aoc") transport = SseServerTransport("/messages/") @mcp_app.list_tools() async def list_tools() -> list[Tool]: return [ Tool( name="search_events", description="Search audit events by entity, service, operation, or result.", inputSchema=SEARCH_EVENTS_SCHEMA, ), Tool(name="get_event", description="Retrieve a single audit event by its ID.", inputSchema=GET_EVENT_SCHEMA), Tool( name="get_summary", description="Get an aggregated summary of audit activity for the last N days.", inputSchema=GET_SUMMARY_SCHEMA, ), Tool( name="ask", description="Ask a natural language question about audit logs. Returns a narrative answer.", inputSchema=ASK_SCHEMA, ), ] @mcp_app.call_tool() async def call_tool(name: str, arguments: dict) -> list[TextContent]: if name == "search_events": return await handle_search_events(arguments) if name == "get_event": return await handle_get_event(arguments) if name == "get_summary": return await handle_get_summary(arguments) if name == "ask": return await handle_ask(arguments) raise ValueError(f"Unknown tool: {name}") async def _validate_auth(request: Request) -> dict | None: """Validate Bearer token. Returns claims dict or None on failure.""" if not AUTH_ENABLED: return {"sub": "anonymous"} auth_header = request.headers.get("authorization", "") if not auth_header or not auth_header.lower().startswith("bearer "): return None token = auth_header.split(" ", 1)[1] try: jwks = _get_jwks() claims = _decode_token(token, jwks) except Exception as exc: logger.warning("MCP auth failed", error=str(exc)) return None if not _allowed(claims, AUTH_ALLOWED_ROLES, AUTH_ALLOWED_GROUPS): logger.warning("MCP auth forbidden", sub=claims.get("sub")) return None return claims async def mcp_asgi(scope: dict, receive, send): """ASGI application for MCP over SSE, mounted under /mcp in FastAPI.""" if scope["type"] != "http": return request = Request(scope, receive) # Auth check claims = await _validate_auth(request) if claims is None: response = Response("Unauthorized", status_code=401) await response(scope, receive, send) return path = scope.get("path", "") root_path = scope.get("root_path", "") relative_path = path[len(root_path) :] if path.startswith(root_path) else path method = scope.get("method", "") if relative_path == "/sse" and method == "GET": logger.info("MCP SSE connection established", sub=claims.get("sub", "unknown")) async with transport.connect_sse(scope, receive, send) as (read_stream, write_stream): await mcp_app.run( read_stream, write_stream, mcp_app.create_initialization_options(), ) elif relative_path == "/messages/" and method == "POST": await transport.handle_post_message(scope, receive, send) else: response = Response("Not found", status_code=404) await response(scope, receive, send)