When a query matches more than 50 events, the LLM now receives:

- Aggregated counts by service, operation, result, and actor
- A list of failures (up to 10)
- The 50 most recent raw events as samples

This scales to thousands of events without blowing the token budget or losing signal: the LLM gets a bird's-eye view plus concrete examples. Also updates the system prompt to handle both individual event lists and aggregated overviews correctly.
423 lines
15 KiB
Python
423 lines
15 KiB
Python
import json
|
|
import re
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
import httpx
|
|
import structlog
|
|
from auth import require_auth
|
|
from config import LLM_API_KEY, LLM_API_VERSION, LLM_BASE_URL, LLM_MAX_EVENTS, LLM_MODEL, LLM_TIMEOUT_SECONDS
|
|
from database import events_collection
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from models.api import AskRequest, AskResponse
|
|
|
|
# Router for the natural-language "ask" endpoint. The router-level dependency
# runs require_auth for every route registered on it.
router = APIRouter(
    dependencies=[
        Depends(require_auth),
    ],
)
# Structured logger for this module, named "aoc.ask".
logger = structlog.get_logger(
    "aoc.ask",
)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Time-range extraction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Relative time expressions recognised in questions, as (regex, unit) pairs.
# _extract_time_range tries them in list order against the lower-cased
# question and the first match wins. For the counted forms ("days", "hours",
# "minutes") group 1 holds the count. "last" and "past" are synonyms, and
# phrasings such as "in the last 3 days" are matched by the same patterns
# because the search is unanchored.
_TIME_PATTERNS = [
    (r"\b(?:last|past)\s+(\d+)\s+days?\b", "days"),
    (r"\b(?:last|past)\s+(\d+)\s+hours?\b", "hours"),
    (r"\b(?:last|past)\s+(\d+)\s+minutes?\b", "minutes"),
    (r"\b(?:last|past)\s+week\b", "week"),
    (r"\byesterday\b", "yesterday"),
    (r"\btoday\b", "today"),
]
|
|
|
|
|
|
def _extract_time_range(question: str) -> tuple[str | None, str | None]:
|
|
"""Return (start_iso, end_iso) or (None, None) if no time detected."""
|
|
now = datetime.now(UTC)
|
|
q_lower = question.lower()
|
|
|
|
for pattern, unit in _TIME_PATTERNS:
|
|
m = re.search(pattern, q_lower)
|
|
if not m:
|
|
continue
|
|
if unit == "week":
|
|
start = now - timedelta(days=7)
|
|
elif unit == "yesterday":
|
|
start = now - timedelta(days=1)
|
|
elif unit == "today":
|
|
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
|
else:
|
|
num = int(m.group(1))
|
|
delta = {"days": timedelta(days=num), "hours": timedelta(hours=num), "minutes": timedelta(minutes=num)}[
|
|
unit
|
|
]
|
|
start = now - delta
|
|
return start.isoformat().replace("+00:00", "Z"), now.isoformat().replace("+00:00", "Z")
|
|
|
|
return None, None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Entity extraction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Keyword-anchored entity patterns, tried in order by _extract_entity. Each
# captures the single token after the keyword; an optional surrounding quote
# is consumed but not captured. The leading \b stops a keyword from matching
# inside a longer word (e.g. "user" in "superuser", "vm" in "kvm").
_ENTITY_HINTS = [
    rf"\b{keyword}\s+['\"]?([^'\"\s]+)['\"]?"
    for keyword in ("device", "user", "laptop", "vm", "server", "computer")
]
|
|
|
|
# Loose e-mail detector, the last-resort heuristic in _extract_entity.
_EMAIL_RE = re.compile(
    r"[\w.+-]+"  # local part: word characters, dots, plus and hyphen
    r"@"
    r"[\w-]+"  # first domain label
    r"\.[\w.-]+"  # a dot, then the remaining labels / TLD
)
|
|
|
|
|
|
def _extract_entity(question: str) -> str | None:
|
|
"""Best-effort extraction of the device / user / entity name."""
|
|
q_lower = question.lower()
|
|
|
|
# Look for explicit hints: "device ABC123"
|
|
for pattern in _ENTITY_HINTS:
|
|
m = re.search(pattern, q_lower)
|
|
if m:
|
|
# Extract from the original question to preserve case
|
|
start, end = m.span(1)
|
|
return question[start:end].strip().rstrip("?.!,;:")
|
|
|
|
# Look for quoted strings
|
|
m = re.search(r'"([^"]{2,50})"', question)
|
|
if m:
|
|
return m.group(1).strip()
|
|
m = re.search(r"'([^']{2,50})'", question)
|
|
if m:
|
|
return m.group(1).strip()
|
|
|
|
# Look for email addresses
|
|
m = _EMAIL_RE.search(question)
|
|
if m:
|
|
return m.group(0)
|
|
|
|
return None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MongoDB query builder
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _build_event_query(
|
|
entity: str | None,
|
|
start: str | None,
|
|
end: str | None,
|
|
services: list[str] | None = None,
|
|
actor: str | None = None,
|
|
operation: str | None = None,
|
|
result: str | None = None,
|
|
include_tags: list[str] | None = None,
|
|
exclude_tags: list[str] | None = None,
|
|
) -> dict:
|
|
filters = []
|
|
|
|
if start or end:
|
|
time_filter = {}
|
|
if start:
|
|
time_filter["$gte"] = start
|
|
if end:
|
|
time_filter["$lte"] = end
|
|
filters.append({"timestamp": time_filter})
|
|
|
|
if entity:
|
|
entity_safe = re.escape(entity)
|
|
filters.append(
|
|
{
|
|
"$or": [
|
|
{"target_displays": {"$elemMatch": {"$regex": entity_safe, "$options": "i"}}},
|
|
{"actor_display": {"$regex": entity_safe, "$options": "i"}},
|
|
{"actor_upn": {"$regex": entity_safe, "$options": "i"}},
|
|
{"raw_text": {"$regex": entity_safe, "$options": "i"}},
|
|
]
|
|
}
|
|
)
|
|
|
|
if services:
|
|
filters.append({"service": {"$in": services}})
|
|
if actor:
|
|
actor_safe = re.escape(actor)
|
|
filters.append(
|
|
{
|
|
"$or": [
|
|
{"actor_display": {"$regex": actor_safe, "$options": "i"}},
|
|
{"actor_upn": {"$regex": actor_safe, "$options": "i"}},
|
|
{"actor.user.userPrincipalName": {"$regex": actor_safe, "$options": "i"}},
|
|
]
|
|
}
|
|
)
|
|
if operation:
|
|
filters.append({"operation": {"$regex": re.escape(operation), "$options": "i"}})
|
|
if result:
|
|
filters.append({"result": {"$regex": re.escape(result), "$options": "i"}})
|
|
if include_tags:
|
|
filters.append({"tags": {"$all": include_tags}})
|
|
if exclude_tags:
|
|
filters.append({"tags": {"$not": {"$all": exclude_tags}}})
|
|
|
|
return {"$and": filters} if filters else {}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# LLM summarisation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_SYSTEM_PROMPT = """You are an IT operations assistant. An administrator has asked a question about audit logs.
|
|
Your job is to read the data below and write a concise, plain-language answer.
|
|
|
|
The input may be either:
|
|
- A small list of individual audit events (numbered Event #1, #2, etc.), or
|
|
- An aggregated overview with counts by service, action, result, and actor, plus sample events.
|
|
|
|
Rules:
|
|
- Assume the reader is a non-expert admin.
|
|
- For aggregated overviews: summarise the scale, top patterns, and highlight anomalies or failures.
|
|
- For small event lists: group related events together and tell a coherent story.
|
|
- Highlight anything unusual, failed actions, or privilege escalations.
|
|
- Reference specific event numbers (e.g., "Event #3") when making claims so the user can verify.
|
|
- If the data is an aggregated subset of a larger result set, acknowledge the scale (e.g., "847 events occurred — the top pattern was...").
|
|
- If there are no events, say so clearly.
|
|
- Keep the answer under 300 words.
|
|
- Do not invent events or patterns that are not supported by the data.
|
|
"""
|
|
|
|
|
|
def _aggregate_counts(events: list[dict]) -> dict:
    """Summarise *events* into top-N frequency tables for the LLM overview.

    Missing or falsy field values are counted under ``"Unknown"``.

    Returns:
        A dict with keys ``services``, ``operations``, ``results`` and
        ``actors`` (in that order). Each value is a list of ``(value, count)``
        pairs from ``Counter.most_common``: the top 5 for results and the top
        10 for the other tables.
    """
    from collections import Counter

    # (output key, event field, number of entries kept)
    tables = (
        ("services", "service", 10),
        ("operations", "operation", 10),
        ("results", "result", 5),
        ("actors", "actor_display", 10),
    )
    return {
        key: Counter(event.get(field) or "Unknown" for event in events).most_common(limit)
        for key, field, limit in tables
    }
|
|
|
|
|
|
def _format_events_for_llm(events: list[dict], total: int | None = None) -> str:
|
|
lines = []
|
|
|
|
# If we have a large result set, send aggregation + samples instead of raw dump
|
|
if total is not None and total > len(events) and len(events) >= 50:
|
|
lines.append(f"Result set overview: {total} total events (showing the {len(events)} most recent).\n")
|
|
agg = _aggregate_counts(events)
|
|
lines.append("Breakdown by service:")
|
|
for svc, cnt in agg["services"]:
|
|
lines.append(f" {svc}: {cnt}")
|
|
lines.append("\nBreakdown by action:")
|
|
for op, cnt in agg["operations"]:
|
|
lines.append(f" {op}: {cnt}")
|
|
lines.append("\nBreakdown by result:")
|
|
for res, cnt in agg["results"]:
|
|
lines.append(f" {res}: {cnt}")
|
|
lines.append("\nTop actors:")
|
|
for actor, cnt in agg["actors"]:
|
|
lines.append(f" {actor}: {cnt}")
|
|
# Include failures and a few recent samples
|
|
failures = [e for e in events if str(e.get("result") or "").lower() in ("failure", "failed")]
|
|
if failures:
|
|
lines.append(f"\nFailures ({len(failures)}):")
|
|
for e in failures[:10]:
|
|
ts = e.get("timestamp", "?")[:16].replace("T", " ")
|
|
op = e.get("operation", "unknown")
|
|
actor = e.get("actor_display", "unknown")
|
|
lines.append(f" {ts} — {op} by {actor}")
|
|
lines.append("\nMost recent sample events:")
|
|
else:
|
|
if total is not None and total > len(events):
|
|
lines.append(f"Showing {len(events)} of {total} total matching events (most recent first):\n")
|
|
|
|
# Always include the first N raw events as detail (up to 50)
|
|
for i, e in enumerate(events[:50], 1):
|
|
ts = e.get("timestamp") or "unknown time"
|
|
op = e.get("operation") or "unknown action"
|
|
actor = e.get("actor_display") or "unknown actor"
|
|
targets = ", ".join(e.get("target_displays") or []) or "unknown target"
|
|
svc = e.get("service") or "unknown service"
|
|
result = e.get("result") or "unknown result"
|
|
summary = e.get("display_summary") or ""
|
|
lines.append(
|
|
f"Event #{i}\n"
|
|
f" Time: {ts}\n"
|
|
f" Service: {svc}\n"
|
|
f" Action: {op}\n"
|
|
f" Actor: {actor}\n"
|
|
f" Target: {targets}\n"
|
|
f" Result: {result}\n"
|
|
f" Summary: {summary}\n"
|
|
)
|
|
return "\n".join(lines)
|
|
|
|
|
|
def _build_chat_url(base_url: str, api_version: str) -> str:
    """Construct the chat completions URL, handling Azure OpenAI endpoints.

    ``/chat/completions`` is appended to the path of *base_url* unless the
    path already ends with it (trailing slashes ignored). If *api_version* is
    set, ``api-version=<api_version>`` is added to the query string, unless
    *base_url* already carries an ``api-version`` parameter, in which case the
    URL's own value is kept. Other query parameters on *base_url* are
    preserved, so a full endpoint URL pasted into the config still works.

    Args:
        base_url: API base URL, or a complete chat-completions URL.
        api_version: Value for the ``api-version`` query parameter; falsy to omit.

    Returns:
        The URL to POST chat-completion requests to.
    """
    from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit

    parts = urlsplit(base_url)
    path = parts.path.rstrip("/")
    if not path.endswith("/chat/completions"):
        path = f"{path}/chat/completions"

    query = parts.query
    existing = {key for key, _ in parse_qsl(query, keep_blank_values=True)}
    if api_version and "api-version" not in existing:
        extra = urlencode({"api-version": api_version})
        query = f"{query}&{extra}" if query else extra

    return urlunsplit((parts.scheme, parts.netloc, path, query, parts.fragment))
|
|
|
|
|
|
async def _call_llm(question: str, events: list[dict], total: int | None = None) -> str:
    """Ask the configured chat-completions endpoint to answer *question*.

    The events are rendered with ``_format_events_for_llm`` and sent after
    ``_SYSTEM_PROMPT``. Endpoints whose base URL contains "azure" or
    "cognitiveservices" are authenticated with an ``api-key`` header; all
    others with a ``Bearer`` token.

    Args:
        question: The administrator's question, sent verbatim.
        events: Event documents the answer must be grounded in.
        total: Total number of matching events, forwarded to the formatter.

    Returns:
        The model's answer, stripped of surrounding whitespace (never empty).

    Raises:
        RuntimeError: ``LLM_API_KEY`` is unset, the API answers with an HTTP
            status >= 400, or the response carries no non-empty answer text.
        httpx.HTTPError: transport-level failures such as timeouts.
        ValueError: the response body is not valid JSON.
    """
    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY not configured")

    context = _format_events_for_llm(events, total=total)
    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Question: {question}\n\nAudit events:\n{context}\n\nPlease answer the question based only on the events above.",
        },
    ]

    url = _build_chat_url(LLM_BASE_URL, LLM_API_VERSION)
    headers = {
        "Content-Type": "application/json",
    }
    # Azure OpenAI uses api-key header; standard OpenAI uses Bearer token
    base_lower = LLM_BASE_URL.lower()
    if "azure" in base_lower or "cognitiveservices" in base_lower:
        headers["api-key"] = LLM_API_KEY
    else:
        headers["Authorization"] = f"Bearer {LLM_API_KEY}"

    payload = {
        "model": LLM_MODEL,
        "messages": messages,
        "max_completion_tokens": 800,
    }

    async with httpx.AsyncClient(timeout=LLM_TIMEOUT_SECONDS) as client:
        resp = await client.post(url, headers=headers, json=payload)
        if resp.status_code >= 400:
            body = resp.text
            logger.error("LLM API error", status_code=resp.status_code, url=url, response_body=body)
            raise RuntimeError(f"LLM API error {resp.status_code}: {body[:500]}")
        data = resp.json()

    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise RuntimeError(f"Unexpected LLM response format: {str(data)[:500]}") from exc

    # A null or blank answer used to surface as an AttributeError (None.strip())
    # or as an empty "successful" answer. Fail explicitly so the ask endpoint
    # records llm_error and uses its structured fallback summary instead.
    if not isinstance(content, str) or not content.strip():
        raise RuntimeError("LLM returned an empty answer")
    return content.strip()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# API endpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _to_event_ref(e: dict) -> dict:
    """Project an event document onto the fields returned to the client.

    Keys always appear in the same order; fields missing from *e* map to ``None``.
    """
    # NOTE(review): reads "id", not Mongo's "_id" (which ask_question
    # stringifies); presumably events carry their own "id" field — confirm.
    wanted = (
        "id",
        "timestamp",
        "operation",
        "actor_display",
        "target_displays",
        "display_summary",
        "service",
        "result",
    )
    return {name: e.get(name) for name in wanted}
|
|
|
|
|
|
def _default_time_window() -> tuple[str, str]:
    """Return ``(start, end)`` ISO-8601 UTC strings (``Z`` suffix) for the last 7 days."""
    now = datetime.now(UTC)
    return (
        (now - timedelta(days=7)).isoformat().replace("+00:00", "Z"),
        now.isoformat().replace("+00:00", "Z"),
    )


def _fetch_events(query: dict) -> tuple[int, list[dict]]:
    """Count all events matching *query* and load the newest ``LLM_MAX_EVENTS``.

    The collection calls here are synchronous (not awaited), so the endpoint
    runs this function in a worker thread.
    """
    total = events_collection.count_documents(query)
    cursor = events_collection.find(query).sort([("timestamp", -1)]).limit(LLM_MAX_EVENTS)
    return total, list(cursor)


def _fallback_summary(
    events: list[dict],
    entity: str | None,
    start: str,
    end: str,
    total: int | None,
) -> str:
    """Build a Markdown summary of *events* for when no LLM answer is available.

    Reports how many events matched (*total*, or ``len(events)`` if larger),
    the entity and the date window, then lists the first 10 events.
    """
    matched = max(total or 0, len(events))
    # Previously len(events) was reported, which is capped at LLM_MAX_EVENTS.
    header = [f"Found {matched} event(s)"]
    if entity:
        header.append(f"related to **{entity}**")
    header.append(f"between {start[:10]} and {end[:10]}.\n")
    parts = [" ".join(header)]

    for i, e in enumerate(events[:10], 1):
        # `or` rather than .get defaults: explicit None values would otherwise
        # crash the slice or print as "None".
        ts = str(e.get("timestamp") or "?")[:16].replace("T", " ")
        op = e.get("operation") or "unknown action"
        actor = e.get("actor_display") or "unknown"
        targets = ", ".join(str(t) for t in (e.get("target_displays") or []) if t is not None) or "—"
        result = e.get("result") or "—"
        parts.append(f"{i}. **{ts}** — {op} by {actor} on {targets} ({result})")

    if matched > 10:
        parts.append(f"\n...and {matched - 10} more events.")

    return "\n".join(parts)


@router.post("/ask", response_model=AskResponse)
async def ask_question(body: AskRequest, user: dict = Depends(require_auth)):
    """Answer a natural-language question about audit events.

    Extracts a time window (default: last 7 days) and an entity from the
    question, combines them with the explicit filters in *body*, and fetches
    the newest matching events. The answer is an LLM narrative when
    ``LLM_API_KEY`` is configured and the call succeeds, otherwise a
    structured Markdown summary.

    Raises:
        HTTPException: 400 for an empty question, 500 if the event query fails.
    """
    import asyncio

    question = body.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="Question is required")

    start, end = _extract_time_range(question)
    entity = _extract_entity(question)

    # Default to last 7 days if no time range detected
    if not start:
        start, end = _default_time_window()

    query = _build_event_query(
        entity,
        start,
        end,
        services=body.services,
        actor=body.actor,
        operation=body.operation,
        result=body.result,
        include_tags=body.include_tags,
        exclude_tags=body.exclude_tags,
    )

    try:
        # The collection calls are synchronous; run them off the event loop so
        # a slow query does not stall every other request on this worker.
        total, events = await asyncio.to_thread(_fetch_events, query)
    except Exception as exc:
        logger.error("Failed to query events for ask", error=str(exc))
        raise HTTPException(status_code=500, detail=f"Database query failed: {exc}") from exc

    for e in events:
        e["_id"] = str(e.get("_id", ""))

    # If no events, return early
    if not events:
        return AskResponse(
            answer="I couldn't find any audit events matching your question. Try broadening the time range or checking the spelling of the device/user name.",
            events=[],
            query_info={
                "entity": entity,
                "start": start,
                "end": end,
                "event_count": 0,
                "total_matched": total,
                "mongo_query": json.dumps(query, default=str),
            },
            llm_used=False,
            llm_error="LLM not used — no events found." if not LLM_API_KEY else None,
        )

    # Try LLM summarisation
    answer = ""
    llm_used = False
    llm_error = None
    if not LLM_API_KEY:
        llm_error = "LLM_API_KEY is not configured. Set it in your .env to enable AI narrative summarisation."
    else:
        try:
            answer = await _call_llm(question, events, total=total)
            # An empty reply is not a usable LLM answer; report it instead of
            # flagging llm_used while serving the fallback.
            llm_used = bool(answer)
            if not llm_used:
                llm_error = "LLM returned an empty answer."
        except Exception as exc:
            llm_error = f"LLM call failed: {exc}"
            logger.warning("LLM call failed, falling back to structured summary", error=str(exc))

    # Fallback: structured summary if LLM unavailable or failed
    if not answer:
        answer = _fallback_summary(events, entity, start, end, total)

    return AskResponse(
        answer=answer,
        events=[_to_event_ref(e) for e in events],
        query_info={
            "entity": entity,
            "start": start,
            "end": end,
            "event_count": len(events),
            "total_matched": total,
            "mongo_query": json.dumps(query, default=str),
        },
        llm_used=llm_used,
        llm_error=llm_error,
    )
|