feat: MCP server over SSE with OIDC auth
All checks were successful
CI / lint-and-test (push) Successful in 36s
All checks were successful
CI / lint-and-test (push) Successful in 36s
- Extract shared MCP tool handlers to mcp_common.py - mcp_server.py now uses shared handlers (stdio transport for local dev) - New routes/mcp.py: SSE transport behind existing OIDC Bearer auth - Mount MCP ASGI app at /mcp in main.py when AI_FEATURES_ENABLED - /mcp/sse -> establishes SSE stream (requires valid token when auth enabled) - /mcp/messages/ -> receives MCP client messages - Update README with SSE MCP docs - Add tests for mount existence, auth, and message routing
This commit is contained in:
@@ -116,6 +116,9 @@ if AI_FEATURES_ENABLED:
|
||||
from routes.ask import router as ask_router
|
||||
|
||||
app.include_router(ask_router, prefix="/api")
|
||||
from routes.mcp import mcp_asgi
|
||||
|
||||
app.mount("/mcp", mcp_asgi)
|
||||
app.include_router(rules_router, prefix="/api")
|
||||
|
||||
|
||||
|
||||
187
backend/mcp_common.py
Normal file
187
backend/mcp_common.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""Shared MCP tool handlers used by both stdio and SSE transports."""
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from database import events_collection
|
||||
from mcp.types import TextContent
|
||||
|
||||
|
||||
async def handle_search_events(arguments: dict) -> list[TextContent]:
|
||||
days = arguments.get("days", 7)
|
||||
limit = min(arguments.get("limit", 20), 100)
|
||||
since = (datetime.now(UTC) - timedelta(days=days)).isoformat().replace("+00:00", "Z")
|
||||
|
||||
filters = [{"timestamp": {"$gte": since}}]
|
||||
|
||||
services = arguments.get("services")
|
||||
if services:
|
||||
filters.append({"service": {"$in": services}})
|
||||
|
||||
operation = arguments.get("operation")
|
||||
if operation:
|
||||
filters.append({"operation": {"$regex": operation, "$options": "i"}})
|
||||
|
||||
result = arguments.get("result")
|
||||
if result:
|
||||
filters.append({"result": {"$regex": result, "$options": "i"}})
|
||||
|
||||
entity = arguments.get("entity")
|
||||
if entity:
|
||||
entity_safe = entity.replace(".", "\\.").replace("(", "\\(").replace(")", "\\)")
|
||||
filters.append(
|
||||
{
|
||||
"$or": [
|
||||
{"target_displays": {"$elemMatch": {"$regex": entity_safe, "$options": "i"}}},
|
||||
{"actor_display": {"$regex": entity_safe, "$options": "i"}},
|
||||
{"actor_upn": {"$regex": entity_safe, "$options": "i"}},
|
||||
{"raw_text": {"$regex": entity_safe, "$options": "i"}},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
query = {"$and": filters}
|
||||
cursor = events_collection.find(query).sort("timestamp", -1).limit(limit)
|
||||
events = list(cursor)
|
||||
|
||||
if not events:
|
||||
return [TextContent(type="text", text="No matching events found.")]
|
||||
|
||||
lines = [f"Found {len(events)} event(s):\n"]
|
||||
for e in events:
|
||||
ts = e.get("timestamp", "?")[:16].replace("T", " ")
|
||||
svc = e.get("service", "?")
|
||||
op = e.get("operation", "?")
|
||||
actor = e.get("actor_display", "?")
|
||||
result_str = e.get("result", "?")
|
||||
lines.append(f"{ts} | {svc} | {op} | {actor} | {result_str}")
|
||||
|
||||
return [TextContent(type="text", text="\n".join(lines))]
|
||||
|
||||
|
||||
async def handle_get_event(arguments: dict) -> list[TextContent]:
|
||||
event_id = arguments["event_id"]
|
||||
event = events_collection.find_one({"id": event_id})
|
||||
if not event:
|
||||
return [TextContent(type="text", text=f"Event {event_id} not found.")]
|
||||
event.pop("_id", None)
|
||||
return [TextContent(type="text", text=json.dumps(event, indent=2, default=str))]
|
||||
|
||||
|
||||
async def handle_get_summary(arguments: dict) -> list[TextContent]:
|
||||
days = arguments.get("days", 7)
|
||||
since = (datetime.now(UTC) - timedelta(days=days)).isoformat().replace("+00:00", "Z")
|
||||
query = {"timestamp": {"$gte": since}}
|
||||
|
||||
total = events_collection.count_documents(query)
|
||||
if total == 0:
|
||||
return [TextContent(type="text", text="No events in the specified period.")]
|
||||
|
||||
svc_pipeline = [
|
||||
{"$match": query},
|
||||
{"$group": {"_id": "$service", "count": {"$sum": 1}}},
|
||||
{"$sort": {"count": -1}},
|
||||
{"$limit": 10},
|
||||
]
|
||||
op_pipeline = [
|
||||
{"$match": query},
|
||||
{"$group": {"_id": "$operation", "count": {"$sum": 1}}},
|
||||
{"$sort": {"count": -1}},
|
||||
{"$limit": 10},
|
||||
]
|
||||
result_pipeline = [
|
||||
{"$match": query},
|
||||
{"$group": {"_id": "$result", "count": {"$sum": 1}}},
|
||||
{"$sort": {"count": -1}},
|
||||
]
|
||||
actor_pipeline = [
|
||||
{"$match": query},
|
||||
{"$group": {"_id": "$actor_display", "count": {"$sum": 1}}},
|
||||
{"$sort": {"count": -1}},
|
||||
{"$limit": 10},
|
||||
]
|
||||
|
||||
svc_counts = list(events_collection.aggregate(svc_pipeline))
|
||||
op_counts = list(events_collection.aggregate(op_pipeline))
|
||||
result_counts = list(events_collection.aggregate(result_pipeline))
|
||||
actor_counts = list(events_collection.aggregate(actor_pipeline))
|
||||
|
||||
lines = [f"Summary for the last {days} days ({total} total events)\n"]
|
||||
|
||||
lines.append("By service:")
|
||||
for row in svc_counts:
|
||||
lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}")
|
||||
|
||||
lines.append("\nBy action:")
|
||||
for row in op_counts:
|
||||
lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}")
|
||||
|
||||
lines.append("\nBy result:")
|
||||
for row in result_counts:
|
||||
lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}")
|
||||
|
||||
lines.append("\nTop actors:")
|
||||
for row in actor_counts:
|
||||
lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}")
|
||||
|
||||
return [TextContent(type="text", text="\n".join(lines))]
|
||||
|
||||
|
||||
async def handle_ask(arguments: dict) -> list[TextContent]:
|
||||
"""For now, returns recent events + guidance. In the future this could call the LLM backend."""
|
||||
question = arguments["question"]
|
||||
days = arguments.get("days", 7)
|
||||
|
||||
result = await handle_search_events({"entity": "", "days": days, "limit": 50})
|
||||
base_text = result[0].text if result else ""
|
||||
|
||||
text = (
|
||||
f"You asked: '{question}'\n\n"
|
||||
f"Here are the most recent events from the last {days} days:\n\n"
|
||||
f"{base_text}\n\n"
|
||||
f"Tip: Use the 'search_events' tool with specific filters "
|
||||
f"to narrow down the dataset before asking follow-up questions."
|
||||
)
|
||||
return [TextContent(type="text", text=text)]
|
||||
|
||||
|
||||
# JSON schemas for tool definitions
|
||||
SEARCH_EVENTS_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity": {"type": "string", "description": "Device name, user UPN, or email to search for"},
|
||||
"services": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Filter by service (e.g. Intune, Directory, Exchange)",
|
||||
},
|
||||
"operation": {"type": "string", "description": "Filter by operation name"},
|
||||
"result": {"type": "string", "description": "Filter by result (success, failure)"},
|
||||
"days": {"type": "integer", "description": "Number of days to look back (default 7)"},
|
||||
"limit": {"type": "integer", "description": "Max events to return (default 20)"},
|
||||
},
|
||||
}
|
||||
|
||||
GET_EVENT_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"event_id": {"type": "string", "description": "The event ID to retrieve"},
|
||||
},
|
||||
"required": ["event_id"],
|
||||
}
|
||||
|
||||
GET_SUMMARY_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"days": {"type": "integer", "description": "Number of days to summarise (default 7)"},
|
||||
},
|
||||
}
|
||||
|
||||
ASK_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"question": {"type": "string", "description": "Natural language question about audit logs"},
|
||||
"days": {"type": "integer", "description": "Number of days to look back (default 7)"},
|
||||
},
|
||||
"required": ["question"],
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
AOC MCP Server
|
||||
AOC MCP Server — stdio transport
|
||||
|
||||
Standalone MCP server that exposes audit log search tools for Claude Desktop,
|
||||
Cursor, and other MCP clients.
|
||||
Standalone MCP server for local use (Claude Desktop, Cursor, etc.).
|
||||
For the HTTP/SSE version (production, behind auth), see routes/mcp.py.
|
||||
|
||||
Usage:
|
||||
python mcp_server.py
|
||||
@@ -21,65 +21,28 @@ Claude Desktop config (~/.config/claude/claude_desktop_config.json):
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
# Ensure backend modules are importable
|
||||
# Ensure backend modules are importable when run standalone
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from database import events_collection
|
||||
from mcp.server import Server
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.types import TextContent, Tool
|
||||
from mcp_common import (
|
||||
ASK_SCHEMA,
|
||||
GET_EVENT_SCHEMA,
|
||||
GET_SUMMARY_SCHEMA,
|
||||
SEARCH_EVENTS_SCHEMA,
|
||||
handle_ask,
|
||||
handle_get_event,
|
||||
handle_get_summary,
|
||||
handle_search_events,
|
||||
)
|
||||
|
||||
app = Server("aoc")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool definitions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SEARCH_EVENTS_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity": {"type": "string", "description": "Device name, user UPN, or email to search for"},
|
||||
"services": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Filter by service (e.g. Intune, Directory, Exchange)",
|
||||
},
|
||||
"operation": {"type": "string", "description": "Filter by operation name"},
|
||||
"result": {"type": "string", "description": "Filter by result (success, failure)"},
|
||||
"days": {"type": "integer", "description": "Number of days to look back (default 7)"},
|
||||
"limit": {"type": "integer", "description": "Max events to return (default 20)"},
|
||||
},
|
||||
}
|
||||
|
||||
_GET_EVENT_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"event_id": {"type": "string", "description": "The event ID to retrieve"},
|
||||
},
|
||||
"required": ["event_id"],
|
||||
}
|
||||
|
||||
_GET_SUMMARY_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"days": {"type": "integer", "description": "Number of days to summarise (default 7)"},
|
||||
},
|
||||
}
|
||||
|
||||
_ASK_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"question": {"type": "string", "description": "Natural language question about audit logs"},
|
||||
"days": {"type": "integer", "description": "Number of days to look back (default 7)"},
|
||||
},
|
||||
"required": ["question"],
|
||||
}
|
||||
|
||||
|
||||
@app.list_tools()
|
||||
async def list_tools() -> list[Tool]:
|
||||
@@ -87,186 +50,35 @@ async def list_tools() -> list[Tool]:
|
||||
Tool(
|
||||
name="search_events",
|
||||
description="Search audit events by entity, service, operation, or result.",
|
||||
inputSchema=_SEARCH_EVENTS_SCHEMA,
|
||||
inputSchema=SEARCH_EVENTS_SCHEMA,
|
||||
),
|
||||
Tool(name="get_event", description="Retrieve a single audit event by its ID.", inputSchema=_GET_EVENT_SCHEMA),
|
||||
Tool(name="get_event", description="Retrieve a single audit event by its ID.", inputSchema=GET_EVENT_SCHEMA),
|
||||
Tool(
|
||||
name="get_summary",
|
||||
description="Get an aggregated summary of audit activity for the last N days.",
|
||||
inputSchema=_GET_SUMMARY_SCHEMA,
|
||||
inputSchema=GET_SUMMARY_SCHEMA,
|
||||
),
|
||||
Tool(
|
||||
name="ask",
|
||||
description="Ask a natural language question about audit logs. Returns a narrative answer.",
|
||||
inputSchema=_ASK_SCHEMA,
|
||||
inputSchema=ASK_SCHEMA,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.call_tool()
|
||||
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
if name == "search_events":
|
||||
return await _handle_search_events(arguments)
|
||||
return await handle_search_events(arguments)
|
||||
if name == "get_event":
|
||||
return await _handle_get_event(arguments)
|
||||
return await handle_get_event(arguments)
|
||||
if name == "get_summary":
|
||||
return await _handle_get_summary(arguments)
|
||||
return await handle_get_summary(arguments)
|
||||
if name == "ask":
|
||||
return await _handle_ask(arguments)
|
||||
return await handle_ask(arguments)
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
|
||||
async def _handle_search_events(arguments: dict) -> list[TextContent]:
|
||||
days = arguments.get("days", 7)
|
||||
limit = min(arguments.get("limit", 20), 100)
|
||||
since = (datetime.now(UTC) - timedelta(days=days)).isoformat().replace("+00:00", "Z")
|
||||
|
||||
filters = [{"timestamp": {"$gte": since}}]
|
||||
|
||||
services = arguments.get("services")
|
||||
if services:
|
||||
filters.append({"service": {"$in": services}})
|
||||
|
||||
operation = arguments.get("operation")
|
||||
if operation:
|
||||
filters.append({"operation": {"$regex": operation, "$options": "i"}})
|
||||
|
||||
result = arguments.get("result")
|
||||
if result:
|
||||
filters.append({"result": {"$regex": result, "$options": "i"}})
|
||||
|
||||
entity = arguments.get("entity")
|
||||
if entity:
|
||||
entity_safe = entity.replace(".", "\\.").replace("(", "\\(").replace(")", "\\)")
|
||||
filters.append(
|
||||
{
|
||||
"$or": [
|
||||
{"target_displays": {"$elemMatch": {"$regex": entity_safe, "$options": "i"}}},
|
||||
{"actor_display": {"$regex": entity_safe, "$options": "i"}},
|
||||
{"actor_upn": {"$regex": entity_safe, "$options": "i"}},
|
||||
{"raw_text": {"$regex": entity_safe, "$options": "i"}},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
query = {"$and": filters}
|
||||
cursor = events_collection.find(query).sort("timestamp", -1).limit(limit)
|
||||
events = list(cursor)
|
||||
|
||||
if not events:
|
||||
return [TextContent(type="text", text="No matching events found.")]
|
||||
|
||||
lines = [f"Found {len(events)} event(s):\n"]
|
||||
for e in events:
|
||||
ts = e.get("timestamp", "?")[:16].replace("T", " ")
|
||||
svc = e.get("service", "?")
|
||||
op = e.get("operation", "?")
|
||||
actor = e.get("actor_display", "?")
|
||||
result_str = e.get("result", "?")
|
||||
lines.append(f"{ts} | {svc} | {op} | {actor} | {result_str}")
|
||||
|
||||
return [TextContent(type="text", text="\n".join(lines))]
|
||||
|
||||
|
||||
async def _handle_get_event(arguments: dict) -> list[TextContent]:
|
||||
event_id = arguments["event_id"]
|
||||
event = events_collection.find_one({"id": event_id})
|
||||
if not event:
|
||||
return [TextContent(type="text", text=f"Event {event_id} not found.")]
|
||||
event.pop("_id", None)
|
||||
return [TextContent(type="text", text=json.dumps(event, indent=2, default=str))]
|
||||
|
||||
|
||||
async def _handle_get_summary(arguments: dict) -> list[TextContent]:
|
||||
days = arguments.get("days", 7)
|
||||
since = (datetime.now(UTC) - timedelta(days=days)).isoformat().replace("+00:00", "Z")
|
||||
query = {"timestamp": {"$gte": since}}
|
||||
|
||||
total = events_collection.count_documents(query)
|
||||
if total == 0:
|
||||
return [TextContent(type="text", text="No events in the specified period.")]
|
||||
|
||||
# Aggregation pipelines
|
||||
svc_pipeline = [
|
||||
{"$match": query},
|
||||
{"$group": {"_id": "$service", "count": {"$sum": 1}}},
|
||||
{"$sort": {"count": -1}},
|
||||
{"$limit": 10},
|
||||
]
|
||||
op_pipeline = [
|
||||
{"$match": query},
|
||||
{"$group": {"_id": "$operation", "count": {"$sum": 1}}},
|
||||
{"$sort": {"count": -1}},
|
||||
{"$limit": 10},
|
||||
]
|
||||
result_pipeline = [
|
||||
{"$match": query},
|
||||
{"$group": {"_id": "$result", "count": {"$sum": 1}}},
|
||||
{"$sort": {"count": -1}},
|
||||
]
|
||||
actor_pipeline = [
|
||||
{"$match": query},
|
||||
{"$group": {"_id": "$actor_display", "count": {"$sum": 1}}},
|
||||
{"$sort": {"count": -1}},
|
||||
{"$limit": 10},
|
||||
]
|
||||
|
||||
svc_counts = list(events_collection.aggregate(svc_pipeline))
|
||||
op_counts = list(events_collection.aggregate(op_pipeline))
|
||||
result_counts = list(events_collection.aggregate(result_pipeline))
|
||||
actor_counts = list(events_collection.aggregate(actor_pipeline))
|
||||
|
||||
lines = [f"Summary for the last {days} days ({total} total events)\n"]
|
||||
|
||||
lines.append("By service:")
|
||||
for row in svc_counts:
|
||||
lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}")
|
||||
|
||||
lines.append("\nBy action:")
|
||||
for row in op_counts:
|
||||
lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}")
|
||||
|
||||
lines.append("\nBy result:")
|
||||
for row in result_counts:
|
||||
lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}")
|
||||
|
||||
lines.append("\nTop actors:")
|
||||
for row in actor_counts:
|
||||
lines.append(f" {row['_id'] or 'Unknown'}: {row['count']}")
|
||||
|
||||
return [TextContent(type="text", text="\n".join(lines))]
|
||||
|
||||
|
||||
async def _handle_ask(arguments: dict) -> list[TextContent]:
|
||||
"""For now, the MCP 'ask' tool returns a helpful message directing the user to the web UI,
|
||||
since the full NLQ pipeline requires LLM configuration that may not be available in the MCP context."""
|
||||
question = arguments["question"]
|
||||
days = arguments.get("days", 7)
|
||||
|
||||
# Perform a search to give the user something useful immediately
|
||||
result = await _handle_search_events({"entity": "", "days": days, "limit": 50})
|
||||
base_text = result[0].text if result else ""
|
||||
|
||||
text = (
|
||||
f"You asked: '{question}'\n\n"
|
||||
f"Here are the most recent {min(50, base_text.count(chr(10)) - 1)} events from the last {days} days:\n\n"
|
||||
f"{base_text}\n\n"
|
||||
f"Tip: Use the 'search_events' tool with specific filters (services, operation, result) "
|
||||
f"to narrow down the dataset before asking follow-up questions."
|
||||
)
|
||||
return [TextContent(type="text", text=text)]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main():
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
await app.run(read_stream, write_stream, app.create_initialization_options())
|
||||
|
||||
124
backend/routes/mcp.py
Normal file
124
backend/routes/mcp.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""MCP server over SSE (HTTP) transport, mounted inside FastAPI with OIDC auth."""
|
||||
|
||||
import structlog
|
||||
from auth import (
|
||||
AUTH_ALLOWED_GROUPS,
|
||||
AUTH_ALLOWED_ROLES,
|
||||
AUTH_ENABLED,
|
||||
_allowed,
|
||||
_decode_token,
|
||||
_get_jwks,
|
||||
)
|
||||
from mcp.server import Server
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.types import TextContent, Tool
|
||||
from mcp_common import (
|
||||
ASK_SCHEMA,
|
||||
GET_EVENT_SCHEMA,
|
||||
GET_SUMMARY_SCHEMA,
|
||||
SEARCH_EVENTS_SCHEMA,
|
||||
handle_ask,
|
||||
handle_get_event,
|
||||
handle_get_summary,
|
||||
handle_search_events,
|
||||
)
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
logger = structlog.get_logger("aoc.mcp")
|
||||
|
||||
mcp_app = Server("aoc")
|
||||
transport = SseServerTransport("/messages/")
|
||||
|
||||
|
||||
@mcp_app.list_tools()
|
||||
async def list_tools() -> list[Tool]:
|
||||
return [
|
||||
Tool(
|
||||
name="search_events",
|
||||
description="Search audit events by entity, service, operation, or result.",
|
||||
inputSchema=SEARCH_EVENTS_SCHEMA,
|
||||
),
|
||||
Tool(name="get_event", description="Retrieve a single audit event by its ID.", inputSchema=GET_EVENT_SCHEMA),
|
||||
Tool(
|
||||
name="get_summary",
|
||||
description="Get an aggregated summary of audit activity for the last N days.",
|
||||
inputSchema=GET_SUMMARY_SCHEMA,
|
||||
),
|
||||
Tool(
|
||||
name="ask",
|
||||
description="Ask a natural language question about audit logs. Returns a narrative answer.",
|
||||
inputSchema=ASK_SCHEMA,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@mcp_app.call_tool()
|
||||
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
if name == "search_events":
|
||||
return await handle_search_events(arguments)
|
||||
if name == "get_event":
|
||||
return await handle_get_event(arguments)
|
||||
if name == "get_summary":
|
||||
return await handle_get_summary(arguments)
|
||||
if name == "ask":
|
||||
return await handle_ask(arguments)
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
|
||||
async def _validate_auth(request: Request) -> dict | None:
|
||||
"""Validate Bearer token. Returns claims dict or None on failure."""
|
||||
if not AUTH_ENABLED:
|
||||
return {"sub": "anonymous"}
|
||||
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
return None
|
||||
|
||||
token = auth_header.split(" ", 1)[1]
|
||||
try:
|
||||
jwks = _get_jwks()
|
||||
claims = _decode_token(token, jwks)
|
||||
except Exception as exc:
|
||||
logger.warning("MCP auth failed", error=str(exc))
|
||||
return None
|
||||
|
||||
if not _allowed(claims, AUTH_ALLOWED_ROLES, AUTH_ALLOWED_GROUPS):
|
||||
logger.warning("MCP auth forbidden", sub=claims.get("sub"))
|
||||
return None
|
||||
|
||||
return claims
|
||||
|
||||
|
||||
async def mcp_asgi(scope: dict, receive, send):
|
||||
"""ASGI application for MCP over SSE, mounted under /mcp in FastAPI."""
|
||||
if scope["type"] != "http":
|
||||
return
|
||||
|
||||
request = Request(scope, receive)
|
||||
|
||||
# Auth check
|
||||
claims = await _validate_auth(request)
|
||||
if claims is None:
|
||||
response = Response("Unauthorized", status_code=401)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
path = scope.get("path", "")
|
||||
root_path = scope.get("root_path", "")
|
||||
relative_path = path[len(root_path) :] if path.startswith(root_path) else path
|
||||
method = scope.get("method", "")
|
||||
|
||||
if relative_path == "/sse" and method == "GET":
|
||||
logger.info("MCP SSE connection established", sub=claims.get("sub", "unknown"))
|
||||
async with transport.connect_sse(scope, receive, send) as (read_stream, write_stream):
|
||||
await mcp_app.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
mcp_app.create_initialization_options(),
|
||||
)
|
||||
elif relative_path == "/messages/" and method == "POST":
|
||||
await transport.handle_post_message(scope, receive, send)
|
||||
else:
|
||||
response = Response("Not found", status_code=404)
|
||||
await response(scope, receive, send)
|
||||
@@ -30,6 +30,7 @@ def client(mock_events_collection, mock_watermarks_collection, monkeypatch):
|
||||
monkeypatch.setattr("routes.fetch.get_watermark", lambda source: None)
|
||||
monkeypatch.setattr("routes.fetch.set_watermark", lambda source, ts: None)
|
||||
monkeypatch.setattr("auth.AUTH_ENABLED", False)
|
||||
monkeypatch.setattr("routes.mcp.AUTH_ENABLED", False)
|
||||
monkeypatch.setattr("database.db.command", lambda cmd: {"ok": 1} if cmd == "ping" else {})
|
||||
|
||||
# Mock audit trail and rules collections so tests don't wait on real MongoDB
|
||||
|
||||
@@ -36,6 +36,25 @@ print('OK')
|
||||
assert "OK" in result.stdout
|
||||
|
||||
|
||||
def test_mcp_sse_mount_exists():
|
||||
from main import app
|
||||
|
||||
mcp_mounts = [r for r in app.routes if getattr(r, "path", "") == "/mcp"]
|
||||
assert len(mcp_mounts) == 1, "MCP mount not found in app routes"
|
||||
|
||||
|
||||
def test_mcp_messages_no_session(client):
|
||||
response = client.post("/mcp/messages/")
|
||||
# MCP transport returns 400 when session_id is missing, 404 when session not found
|
||||
assert response.status_code in (400, 404)
|
||||
|
||||
|
||||
def test_mcp_sse_auth_required_when_enabled(client, monkeypatch):
|
||||
monkeypatch.setattr("routes.mcp.AUTH_ENABLED", True)
|
||||
response = client.get("/mcp/sse")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_health(client):
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
Reference in New Issue
Block a user