301 lines
10 KiB
Python
301 lines
10 KiB
Python
import asyncio
|
|
import ipaddress
|
|
import logging
|
|
import os
|
|
import time
|
|
from contextlib import suppress
|
|
from pathlib import Path
|
|
|
|
import structlog
|
|
from audit_trail import log_action
|
|
from config import (
|
|
AI_FEATURES_ENABLED,
|
|
AUTH_ALLOWED_GROUPS,
|
|
AUTH_ALLOWED_ROLES,
|
|
AUTH_ENABLED,
|
|
CORS_ORIGINS,
|
|
DOCS_ENABLED,
|
|
ENABLE_PERIODIC_FETCH,
|
|
FETCH_INTERVAL_MINUTES,
|
|
METRICS_ALLOWED_IPS,
|
|
)
|
|
from database import setup_indexes
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import Response
|
|
from fastapi.staticfiles import StaticFiles
|
|
from metrics import observe_request, prometheus_metrics
|
|
from middleware import CorrelationIdMiddleware
|
|
from routes.alerts import router as alerts_router
|
|
from routes.config import router as config_router
|
|
from routes.events import router as events_router
|
|
from routes.fetch import router as fetch_router
|
|
from routes.fetch import run_fetch
|
|
from routes.health import router as health_router
|
|
from routes.jobs import router as jobs_router
|
|
from routes.rules import router as rules_router
|
|
from routes.saved_searches import router as saved_searches_router
|
|
from routes.webhooks import router as webhooks_router
|
|
|
|
|
|
def configure_logging():
    """Configure structlog to emit JSON log lines via the stdlib logging module."""
    # Processor pipeline order matters: enrichment first, rendering last.
    processor_chain = [
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
        structlog.processors.JSONRenderer(),
    ]
    structlog.configure(
        processors=processor_chain,
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )
    # Stdlib root logger just passes the already-rendered JSON message through.
    logging.basicConfig(format="%(message)s", level=logging.INFO)
|
|
|
|
|
|
configure_logging()
logger = structlog.get_logger("aoc.fetcher")

# Disable OpenAPI docs in production by default
app = FastAPI(
    docs_url="/docs" if DOCS_ENABLED else None,
    redoc_url="/redoc" if DOCS_ENABLED else None,
    openapi_url="/openapi.json" if DOCS_ENABLED else None,
)

# CORS: when auth is enabled, never allow credentials with wildcard origins
_cors_allow_credentials = True
if AUTH_ENABLED and "*" in CORS_ORIGINS:
    logger.warning(
        "CORS wildcard (*) is insecure with AUTH_ENABLED=true and allow_credentials. "
        "Disabling credentials. Set CORS_ORIGINS to your actual origin(s)."
    )
    _cors_allow_credentials = False

app.add_middleware(CorrelationIdMiddleware)
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=_cors_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
@app.middleware("http")
async def prometheus_middleware(request: Request, call_next):
    """Record method, path, status, and latency for each request."""
    started_at = time.time()
    response = await call_next(request)
    elapsed = time.time() - started_at
    # Prefer the matched route template over the raw URL path — this keeps
    # metric label cardinality bounded when paths carry IDs.
    matched_route = request.scope.get("route")
    path_label = getattr(matched_route, "path", request.url.path)
    observe_request(request.method, path_label, response.status_code, elapsed)
    return response
|
|
|
|
|
|
@app.middleware("http")
async def security_headers_middleware(request: Request, call_next):
    """Attach no-cache, CSP, and hardening headers to outgoing responses.

    Cache-control and CSP headers apply only to API responses and the UI
    entry points; the remaining hardening headers apply to every response.
    """
    response = await call_next(request)
    path = request.url.path
    # The same condition was previously evaluated twice in a row for the
    # cache headers and the CSP header — compute it once and merge branches.
    if path.startswith("/api/") or path in ("/", "/index.html"):
        # Prevent caching of HTML and API responses by default
        response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
        response.headers["Pragma"] = "no-cache"
        response.headers["Expires"] = "0"
        # Basic CSP for the UI and API (allows MSAL auth flows)
        response.headers["Content-Security-Policy"] = (
            "default-src 'self'; "
            "script-src 'self' 'unsafe-inline' 'unsafe-eval' cdn.jsdelivr.net alcdn.msauth.net; "
            "style-src 'self' 'unsafe-inline'; "
            "connect-src 'self' https://login.microsoftonline.com; "
            "frame-src 'self' https://login.microsoftonline.com; "
            "form-action 'self' https://login.microsoftonline.com; "
            "img-src 'self' data:; "
            "font-src 'self' data:;"
        )
    # Additional security headers (applied unconditionally)
    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["X-Frame-Options"] = "DENY"
    response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
    response.headers["Permissions-Policy"] = (
        "accelerometer=(), camera=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), payment=(), usb=()"
    )
    return response
|
|
|
|
|
|
@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
    """Apply Redis-backed rate limiting before processing the request."""
    path = request.url.path
    # Exempt config and health endpoints from rate limiting
    exempt_paths = {"/api/config/auth", "/api/config/features", "/health", "/metrics"}
    if path.startswith("/api/") and path not in exempt_paths:
        # Imported lazily so the Redis client is only loaded when needed.
        from rate_limiter import check_rate_limit

        await check_rate_limit(request)
    return await call_next(request)
|
|
|
|
|
|
@app.middleware("http")
async def audit_middleware(request: Request, call_next):
    """Record an audit-trail entry for every mutating API request."""
    response = await call_next(request)
    is_mutation = request.method in ("POST", "PATCH", "PUT", "DELETE")
    if request.url.path.startswith("/api/") and is_mutation:
        user = "anonymous"
        if AUTH_ENABLED:
            # Lazy import: auth module is only needed when auth is enabled.
            from auth import _auth_context

            claims = _auth_context.get(None)
            if isinstance(claims, dict):
                user = claims.get("sub", "unknown")
        log_action(
            action=request.method.lower(),
            resource=request.url.path,
            details={"status_code": response.status_code},
            user=user,
        )
    return response
|
|
|
|
|
|
# Register the API surface. All routers share the /api prefix.
app.include_router(fetch_router, prefix="/api")
app.include_router(events_router, prefix="/api")
app.include_router(config_router, prefix="/api")
app.include_router(webhooks_router, prefix="/api")
app.include_router(health_router, prefix="/api")
# AI-dependent routes are imported lazily so deployments without AI features
# never load those modules.
if AI_FEATURES_ENABLED:
    from routes.ask import router as ask_router

    app.include_router(ask_router, prefix="/api")
    # NOTE(review): assuming the MCP mount is gated behind AI_FEATURES_ENABLED
    # together with /ask (original indentation was ambiguous) — confirm.
    from routes.mcp import mcp_asgi

    app.mount("/mcp", mcp_asgi)
app.include_router(saved_searches_router, prefix="/api")
app.include_router(rules_router, prefix="/api")
app.include_router(alerts_router, prefix="/api")
app.include_router(jobs_router, prefix="/api")
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Health probe: 200 when the database answers a ping, 503 otherwise."""
    from database import db

    try:
        db.command("ping")
    except Exception as exc:
        logger.error("Health check failed", error=str(exc))
        raise HTTPException(status_code=503, detail="Database unavailable") from exc
    return {"status": "ok", "database": "connected"}
|
|
|
|
|
|
def _client_ip(request: Request) -> str:
    """Best-effort client IP: X-Forwarded-For first hop, or direct client host."""
    forwarded_for = request.headers.get("x-forwarded-for")
    if forwarded_for:
        # Only the first hop is the original client; later hops are proxies.
        first_hop, _, _ = forwarded_for.partition(",")
        return first_hop.strip()
    if request.client:
        return request.client.host
    return ""
|
|
|
|
|
|
def _is_metrics_allowed(ip: str) -> bool:
|
|
"""Check if IP is in the configured metrics allowlist."""
|
|
if not METRICS_ALLOWED_IPS:
|
|
return True
|
|
try:
|
|
client_addr = ipaddress.ip_address(ip)
|
|
except ValueError:
|
|
return False
|
|
for network in METRICS_ALLOWED_IPS.split(","):
|
|
network = network.strip()
|
|
if not network:
|
|
continue
|
|
try:
|
|
if client_addr in ipaddress.ip_network(network, strict=False):
|
|
return True
|
|
except ValueError:
|
|
continue
|
|
return False
|
|
|
|
|
|
@app.get("/metrics")
async def metrics(request: Request):
    """Serve Prometheus metrics, restricted to the configured IP allowlist."""
    if not _is_metrics_allowed(_client_ip(request)):
        raise HTTPException(status_code=403, detail="Forbidden")
    return Response(content=prometheus_metrics(), media_type="text/plain")
|
|
|
|
|
|
@app.get("/api/version")
async def version():
    """Report the deployed application version (VERSION env var, or 'unknown')."""
    build = os.environ.get("VERSION", "unknown")
    return {"version": build}
|
|
|
|
|
|
@app.exception_handler(Exception)
async def generic_exception_handler(request: Request, exc: Exception):
    """Return generic error messages for unhandled exceptions to avoid info leakage."""
    if isinstance(exc, HTTPException):
        from fastapi.responses import JSONResponse

        # HTTPException carries an intentional, caller-facing detail message,
        # so it is safe to pass through along with any attached headers.
        extra_headers = getattr(exc, "headers", None) or {}
        return JSONResponse(
            status_code=exc.status_code,
            content={"detail": exc.detail},
            headers=extra_headers,
        )
    logger.error("Unhandled exception", path=request.url.path, error=str(exc))
    # Everything else becomes a generic 500 so internals never leak to clients.
    return Response(
        content='{"detail":"Internal server error"}',
        status_code=500,
        media_type="application/json",
    )
|
|
|
|
|
|
# Serve the bundled single-page frontend from ./frontend next to this module;
# html=True makes index.html the fallback document. Registered after the API
# routers so the catch-all "/" mount does not shadow them.
frontend_dir = Path(__file__).parent / "frontend"
app.mount("/", StaticFiles(directory=frontend_dir, html=True), name="frontend")
|
|
|
|
|
|
async def _periodic_fetch():
    """Run run_fetch forever on a fixed interval, logging each outcome."""
    while True:
        try:
            # run_fetch is blocking, so hand it to a worker thread to keep
            # the event loop responsive.
            await asyncio.to_thread(run_fetch)
        except Exception as exc:
            logger.error("Periodic fetch failed", error=str(exc))
        else:
            logger.info("Periodic fetch completed.")
        await asyncio.sleep(FETCH_INTERVAL_MINUTES * 60)
|
|
|
|
|
|
@app.on_event("startup")
async def start_periodic_fetch():
    """Initialize the database, seed default rules, and start background work."""
    setup_indexes()
    from rules import seed_default_rules

    seed_default_rules()
    logger.info(
        "AOC startup",
        version=os.environ.get("VERSION", "unknown"),
        auth_enabled=AUTH_ENABLED,
        ai_enabled=AI_FEATURES_ENABLED,
    )
    # Warn when auth is enabled but no role/group restrictions are configured
    unrestricted_auth = AUTH_ENABLED and not (AUTH_ALLOWED_ROLES or AUTH_ALLOWED_GROUPS)
    if unrestricted_auth:
        logger.warning(
            "AUTH_ENABLED is true but no AUTH_ALLOWED_ROLES or AUTH_ALLOWED_GROUPS are configured. "
            "Any Entra user in the tenant can authenticate and access AOC. "
            "Set AUTH_ALLOWED_ROLES or AUTH_ALLOWED_GROUPS to restrict access."
        )
    if ENABLE_PERIODIC_FETCH:
        app.state.fetch_task = asyncio.create_task(_periodic_fetch())
|
|
|
|
|
|
@app.on_event("shutdown")
async def stop_periodic_fetch():
    """Cancel the background fetch task and close Redis connections."""
    fetch_task = getattr(app.state, "fetch_task", None)
    if fetch_task is not None:
        fetch_task.cancel()
        # Awaiting a cancelled task raises CancelledError; suppress it (and
        # any error raised while the task unwinds).
        with suppress(Exception):
            await fetch_task
    from redis_client import close_redis_connections

    await close_redis_connections()
|