872 lines
32 KiB
Python
872 lines
32 KiB
Python
import asyncio
|
|
import json
|
|
import re
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
import httpx
|
|
import structlog
|
|
from auth import require_auth, user_can_access_privacy_services
|
|
from config import (
|
|
LLM_API_KEY,
|
|
LLM_API_VERSION,
|
|
LLM_BASE_URL,
|
|
LLM_MAX_EVENTS,
|
|
LLM_MODEL,
|
|
LLM_TIMEOUT_SECONDS,
|
|
PRIVACY_SENSITIVE_OPERATIONS,
|
|
PRIVACY_SERVICES,
|
|
)
|
|
from database import events_collection
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from jobs import get_cached_ask, get_cached_explain, set_cached_ask, set_cached_explain
|
|
from models.api import AskRequest, AskResponse
|
|
from redis_client import get_arq_pool
|
|
|
|
# Every route registered on this router runs the `require_auth` dependency.
router = APIRouter(dependencies=[Depends(require_auth)])

# Structured logger shared by the ask / explain endpoints in this module.
logger = structlog.get_logger("aoc.ask")
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Intent extraction — map question keywords to relevant audit services
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_SERVICE_INTENTS = {
|
|
"intune": ["Intune"],
|
|
"device": ["Intune", "Device"],
|
|
"laptop": ["Intune", "Device"],
|
|
"mobile": ["Intune", "Device"],
|
|
"phone": ["Intune", "Device"],
|
|
"ipad": ["Intune", "Device"],
|
|
"app": ["Intune", "ApplicationManagement"],
|
|
"application": ["Intune", "ApplicationManagement"],
|
|
"policy": ["Intune", "Policy"],
|
|
"compliance": ["Intune", "Policy"],
|
|
"user": ["Directory", "UserManagement"],
|
|
"group": ["Directory", "GroupManagement"],
|
|
"role": ["Directory", "RoleManagement"],
|
|
"permission": ["Directory", "RoleManagement"],
|
|
"license": ["Directory", "License"],
|
|
"email": ["Exchange"],
|
|
"mailbox": ["Exchange"],
|
|
"mail": ["Exchange"],
|
|
"message": ["Exchange", "Teams"],
|
|
"file": ["SharePoint"],
|
|
"sharepoint": ["SharePoint"],
|
|
"site": ["SharePoint"],
|
|
"document": ["SharePoint"],
|
|
"team": ["Teams"],
|
|
"channel": ["Teams"],
|
|
"meeting": ["Teams"],
|
|
"call": ["Teams"],
|
|
}
|
|
|
|
# Services that are extremely noisy for typical admin questions.
|
|
# We exclude them by default on broad questions unless the user explicitly mentions them.
|
|
_NOISY_SERVICES = {"Exchange", "SharePoint", "Teams"}
|
|
|
|
# Services that are generally admin-relevant and kept by default.
|
|
_DEFAULT_ADMIN_SERVICES = {
|
|
"Directory",
|
|
"UserManagement",
|
|
"GroupManagement",
|
|
"RoleManagement",
|
|
"ApplicationManagement",
|
|
"Intune",
|
|
"Device",
|
|
"Policy",
|
|
"Teams",
|
|
"License",
|
|
}
|
|
|
|
|
|
def _extract_intent_services(question: str) -> tuple[list[str] | None, bool]:
|
|
"""
|
|
Extract relevant services from the question.
|
|
|
|
Returns:
|
|
(services, is_explicit):
|
|
- services: list of service names to query, or None for default admin set
|
|
- is_explicit: True if the user explicitly mentioned a noisy service
|
|
"""
|
|
q_lower = question.lower()
|
|
tokens = set(re.findall(r"\b[a-z]+\b", q_lower))
|
|
|
|
matched_services = set()
|
|
for token, services in _SERVICE_INTENTS.items():
|
|
if token in tokens:
|
|
matched_services.update(services)
|
|
|
|
if matched_services:
|
|
# User asked something specific — return exactly what they asked for
|
|
is_explicit = not matched_services.isdisjoint(_NOISY_SERVICES)
|
|
return sorted(matched_services), is_explicit
|
|
|
|
# Broad question with no clear intent — default to admin-relevant services only
|
|
return None, False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Smart sampling — stratified by importance so the LLM sees signal, not noise
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _smart_sample(events: list[dict], max_events: int = 200) -> list[dict]:
|
|
"""
|
|
Return a curated subset that preserves diversity and prioritises signal.
|
|
|
|
Tiers:
|
|
1. Failures (always valuable)
|
|
2. High-admin-value services (Intune, Device, Directory, etc.)
|
|
3. Everything else
|
|
"""
|
|
if len(events) <= max_events:
|
|
return events
|
|
|
|
high_value = {
|
|
"Directory",
|
|
"UserManagement",
|
|
"GroupManagement",
|
|
"RoleManagement",
|
|
"Intune",
|
|
"Device",
|
|
"Policy",
|
|
"ApplicationManagement",
|
|
}
|
|
|
|
failures = [e for e in events if str(e.get("result") or "").lower() in ("failure", "failed")]
|
|
high_val = [e for e in events if e.get("service") in high_value and e not in failures]
|
|
rest = [e for e in events if e not in failures and e not in high_val]
|
|
|
|
# Allocate slots: half to failures+high-value, half to rest (but never let rest dominate)
|
|
slots = max_events
|
|
failure_cap = min(len(failures), max(10, slots // 4))
|
|
high_cap = min(len(high_val), max(20, slots // 4))
|
|
rest_cap = slots - failure_cap - high_cap
|
|
|
|
sampled = failures[:failure_cap] + high_val[:high_cap] + rest[:rest_cap]
|
|
# Sort back to chronological order
|
|
sampled.sort(key=lambda e: e.get("timestamp") or "", reverse=True)
|
|
return sampled
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Time-range extraction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_TIME_PATTERNS = [
|
|
(r"\blast\s+(\d+)\s+days?\b", "days"),
|
|
(r"\blast\s+(\d+)\s+hours?\b", "hours"),
|
|
(r"\blast\s+(\d+)\s+minutes?\b", "minutes"),
|
|
(r"\blast\s+week\b", "week"),
|
|
(r"\byesterday\b", "yesterday"),
|
|
(r"\btoday\b", "today"),
|
|
(r"\bin\s+the\s+last\s+(\d+)\s+days?\b", "days"),
|
|
(r"\bin\s+the\s+last\s+(\d+)\s+hours?\b", "hours"),
|
|
]
|
|
|
|
|
|
def _extract_time_range(question: str) -> tuple[str | None, str | None]:
|
|
"""Return (start_iso, end_iso) or (None, None) if no time detected."""
|
|
now = datetime.now(UTC)
|
|
q_lower = question.lower()
|
|
|
|
for pattern, unit in _TIME_PATTERNS:
|
|
m = re.search(pattern, q_lower)
|
|
if not m:
|
|
continue
|
|
if unit == "week":
|
|
start = now - timedelta(days=7)
|
|
elif unit == "yesterday":
|
|
start = now - timedelta(days=1)
|
|
elif unit == "today":
|
|
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
|
else:
|
|
num = int(m.group(1))
|
|
delta = {"days": timedelta(days=num), "hours": timedelta(hours=num), "minutes": timedelta(minutes=num)}[
|
|
unit
|
|
]
|
|
start = now - delta
|
|
return start.isoformat().replace("+00:00", "Z"), now.isoformat().replace("+00:00", "Z")
|
|
|
|
return None, None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Entity extraction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_ENTITY_HINTS = [
|
|
r"device\s+['\"]?([^'\"\s]+)['\"]?",
|
|
r"user\s+['\"]?([^'\"\s]+)['\"]?",
|
|
r"laptop\s+['\"]?([^'\"\s]+)['\"]?",
|
|
r"vm\s+['\"]?([^'\"\s]+)['\"]?",
|
|
r"server\s+['\"]?([^'\"\s]+)['\"]?",
|
|
r"computer\s+['\"]?([^'\"\s]+)['\"]?",
|
|
]
|
|
|
|
_EMAIL_RE = re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+")
|
|
|
|
|
|
def _extract_entity(question: str) -> str | None:
|
|
"""Best-effort extraction of the device / user / entity name."""
|
|
q_lower = question.lower()
|
|
|
|
# Look for explicit hints: "device ABC123"
|
|
for pattern in _ENTITY_HINTS:
|
|
m = re.search(pattern, q_lower)
|
|
if m:
|
|
# Extract from the original question to preserve case
|
|
start, end = m.span(1)
|
|
return question[start:end].strip().rstrip("?.!,;:")
|
|
|
|
# Look for quoted strings
|
|
m = re.search(r'"([^"]{2,50})"', question)
|
|
if m:
|
|
return m.group(1).strip()
|
|
m = re.search(r"'([^']{2,50})'", question)
|
|
if m:
|
|
return m.group(1).strip()
|
|
|
|
# Look for email addresses
|
|
m = _EMAIL_RE.search(question)
|
|
if m:
|
|
return m.group(0)
|
|
|
|
return None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MongoDB query builder
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _build_event_query(
|
|
entity: str | None,
|
|
start: str | None,
|
|
end: str | None,
|
|
services: list[str] | None = None,
|
|
actor: str | None = None,
|
|
operation: str | None = None,
|
|
result: str | None = None,
|
|
include_tags: list[str] | None = None,
|
|
exclude_tags: list[str] | None = None,
|
|
) -> dict:
|
|
filters = []
|
|
|
|
if start or end:
|
|
time_filter = {}
|
|
if start:
|
|
time_filter["$gte"] = start
|
|
if end:
|
|
time_filter["$lte"] = end
|
|
filters.append({"timestamp": time_filter})
|
|
|
|
if entity:
|
|
entity_safe = re.escape(entity)
|
|
filters.append(
|
|
{
|
|
"$or": [
|
|
{"target_displays": {"$elemMatch": {"$regex": entity_safe, "$options": "i"}}},
|
|
{"actor_display": {"$regex": entity_safe, "$options": "i"}},
|
|
{"actor_upn": {"$regex": entity_safe, "$options": "i"}},
|
|
{"raw_text": {"$regex": entity_safe, "$options": "i"}},
|
|
]
|
|
}
|
|
)
|
|
|
|
if services:
|
|
filters.append({"service": {"$in": services}})
|
|
if actor:
|
|
actor_safe = re.escape(actor)
|
|
filters.append(
|
|
{
|
|
"$or": [
|
|
{"actor_display": {"$regex": actor_safe, "$options": "i"}},
|
|
{"actor_upn": {"$regex": actor_safe, "$options": "i"}},
|
|
{"actor.user.userPrincipalName": {"$regex": actor_safe, "$options": "i"}},
|
|
]
|
|
}
|
|
)
|
|
if operation:
|
|
filters.append({"operation": {"$regex": re.escape(operation), "$options": "i"}})
|
|
if result:
|
|
filters.append({"result": {"$regex": re.escape(result), "$options": "i"}})
|
|
if include_tags:
|
|
filters.append({"tags": {"$all": include_tags}})
|
|
if exclude_tags:
|
|
filters.append({"tags": {"$not": {"$all": exclude_tags}}})
|
|
|
|
return {"$and": filters} if filters else {}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# LLM summarisation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_SYSTEM_PROMPT = """You are an IT operations assistant. An administrator has asked a question about audit logs.
|
|
Your job is to read the data below and write a concise, plain-language answer.
|
|
|
|
The input may be either:
|
|
- A small list of individual audit events (numbered Event #1, #2, etc.), or
|
|
- An aggregated overview with counts by service, action, result, and actor, plus sample events.
|
|
|
|
Rules:
|
|
- Assume the reader is a non-expert admin.
|
|
- For aggregated overviews: summarise the scale, top patterns, and highlight anomalies or failures.
|
|
- For small event lists: group related events together and tell a coherent story.
|
|
- Highlight anything unusual, failed actions, or privilege escalations.
|
|
- Reference specific event numbers (e.g., "Event #3") when making claims so the user can verify.
|
|
- If the data is an aggregated subset of a larger result set, acknowledge the scale (e.g., "847 events occurred — the top pattern was...").
|
|
- If there are no events, say so clearly.
|
|
- Keep the answer under 300 words.
|
|
- Do not invent events or patterns that are not supported by the data.
|
|
"""
|
|
|
|
|
|
def _aggregate_counts(events: list[dict]) -> dict:
|
|
"""Build lightweight aggregation tables for large result sets."""
|
|
from collections import Counter
|
|
|
|
svc_counts = Counter(e.get("service") or "Unknown" for e in events)
|
|
op_counts = Counter(e.get("operation") or "Unknown" for e in events)
|
|
result_counts = Counter(e.get("result") or "Unknown" for e in events)
|
|
actor_counts = Counter(e.get("actor_display") or "Unknown" for e in events)
|
|
return {
|
|
"services": svc_counts.most_common(10),
|
|
"operations": op_counts.most_common(10),
|
|
"results": result_counts.most_common(5),
|
|
"actors": actor_counts.most_common(10),
|
|
}
|
|
|
|
|
|
def _format_events_for_llm(
|
|
events: list[dict], total: int | None = None, excluded_services: list[str] | None = None
|
|
) -> str:
|
|
lines = []
|
|
|
|
# If we have a large result set, send aggregation + samples instead of raw dump
|
|
if total is not None and total > len(events) and len(events) >= 50:
|
|
lines.append(f"Result set overview: {total} total events (showing a curated sample of {len(events)}).\n")
|
|
if excluded_services:
|
|
lines.append(f"Note: high-volume services excluded by default: {', '.join(excluded_services)}.\n")
|
|
agg = _aggregate_counts(events)
|
|
lines.append("Breakdown by service:")
|
|
for svc, cnt in agg["services"]:
|
|
lines.append(f" {svc}: {cnt}")
|
|
lines.append("\nBreakdown by action:")
|
|
for op, cnt in agg["operations"]:
|
|
lines.append(f" {op}: {cnt}")
|
|
lines.append("\nBreakdown by result:")
|
|
for res, cnt in agg["results"]:
|
|
lines.append(f" {res}: {cnt}")
|
|
lines.append("\nTop actors:")
|
|
for actor, cnt in agg["actors"]:
|
|
lines.append(f" {actor}: {cnt}")
|
|
# Include failures and a few recent samples
|
|
failures = [e for e in events if str(e.get("result") or "").lower() in ("failure", "failed")]
|
|
if failures:
|
|
lines.append(f"\nFailures ({len(failures)}):")
|
|
for e in failures[:10]:
|
|
ts = e.get("timestamp", "?")[:16].replace("T", " ")
|
|
op = e.get("operation", "unknown")
|
|
actor = e.get("actor_display", "unknown")
|
|
lines.append(f" {ts} — {op} by {actor}")
|
|
lines.append("\nMost recent sample events:")
|
|
else:
|
|
if total is not None and total > len(events):
|
|
lines.append(f"Showing {len(events)} of {total} total matching events (most recent first):\n")
|
|
|
|
# Always include the first N raw events as detail (up to 50)
|
|
for i, e in enumerate(events[:50], 1):
|
|
ts = e.get("timestamp") or "unknown time"
|
|
op = e.get("operation") or "unknown action"
|
|
actor = e.get("actor_display") or "unknown actor"
|
|
targets = ", ".join(e.get("target_displays") or []) or "unknown target"
|
|
svc = e.get("service") or "unknown service"
|
|
result = e.get("result") or "unknown result"
|
|
summary = e.get("display_summary") or ""
|
|
lines.append(
|
|
f"Event #{i}\n"
|
|
f" Time: {ts}\n"
|
|
f" Service: {svc}\n"
|
|
f" Action: {op}\n"
|
|
f" Actor: {actor}\n"
|
|
f" Target: {targets}\n"
|
|
f" Result: {result}\n"
|
|
f" Summary: {summary}\n"
|
|
)
|
|
return "\n".join(lines)
|
|
|
|
|
|
def _build_chat_url(base_url: str, api_version: str) -> str:
|
|
"""Construct the chat completions URL, handling Azure OpenAI endpoints."""
|
|
base = base_url.rstrip("/")
|
|
url = base if base.endswith("/chat/completions") else f"{base}/chat/completions"
|
|
if api_version:
|
|
url = f"{url}?api-version={api_version}"
|
|
return url
|
|
|
|
|
|
async def _call_llm(
    question: str,
    events: list[dict],
    total: int | None = None,
    excluded_services: list[str] | None = None,
) -> str:
    """Ask the configured LLM to answer *question* from the given events.

    Args:
        question: the administrator's question, passed through verbatim.
        events: events rendered into the prompt via ``_format_events_for_llm``.
        total: overall number of matching events, if known.
        excluded_services: services omitted by default, mentioned in the prompt.

    Returns:
        The model's answer with surrounding whitespace removed (never empty).

    Raises:
        RuntimeError: if ``LLM_API_KEY`` is unset, the API responds with a
            status >= 400, or the response carries no usable answer text.
    """
    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY not configured")

    context = _format_events_for_llm(events, total=total, excluded_services=excluded_services)
    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Question: {question}\n\nAudit events:\n{context}\n\nPlease answer the question based only on the events above.",
        },
    ]

    url = _build_chat_url(LLM_BASE_URL, LLM_API_VERSION)
    headers = {"Content-Type": "application/json"}
    # Azure OpenAI uses api-key header; standard OpenAI uses Bearer token
    base_lower = LLM_BASE_URL.lower()
    if "azure" in base_lower or "cognitiveservices" in base_lower:
        headers["api-key"] = LLM_API_KEY
    else:
        headers["Authorization"] = f"Bearer {LLM_API_KEY}"

    payload = {
        "model": LLM_MODEL,
        "messages": messages,
        "max_completion_tokens": 800,
    }

    async with httpx.AsyncClient(timeout=LLM_TIMEOUT_SECONDS) as client:
        resp = await client.post(url, headers=headers, json=payload)
        if resp.status_code >= 400:
            body = resp.text
            logger.error("LLM API error", status_code=resp.status_code, url=url, response_body=body)
            raise RuntimeError(f"LLM API error {resp.status_code}: {body[:500]}")
        data = resp.json()

    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise RuntimeError("LLM response did not contain a completion") from exc

    # The API may return null or empty content. Surface that as an error so
    # callers fall back instead of treating "" as a successful answer.
    if not isinstance(content, str) or not content.strip():
        raise RuntimeError("LLM returned an empty answer")
    return content.strip()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# API endpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _to_event_ref(e: dict) -> dict:
|
|
return {
|
|
"id": e.get("id"),
|
|
"timestamp": e.get("timestamp"),
|
|
"operation": e.get("operation"),
|
|
"actor_display": e.get("actor_display"),
|
|
"target_displays": e.get("target_displays"),
|
|
"display_summary": e.get("display_summary"),
|
|
"service": e.get("service"),
|
|
"result": e.get("result"),
|
|
}
|
|
|
|
|
|
_EXPLAIN_SYSTEM_PROMPT = """You are a Microsoft 365 security and compliance expert.
|
|
An administrator needs help understanding an audit event.
|
|
|
|
Your task:
|
|
1. Explain what happened in plain language (1-2 sentences).
|
|
2. Identify who performed the action and what was the target.
|
|
3. Assess whether this is typical admin activity or something to investigate.
|
|
4. Highlight any security implications (privilege escalation, unusual actor, after-hours activity, etc.).
|
|
5. Suggest what the admin should do next, if anything.
|
|
|
|
Keep the answer under 200 words. Use bullet points for readability.
|
|
Do not invent facts that are not in the data.
|
|
"""
|
|
|
|
|
|
_GUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
|
|
|
|
|
|
def _extract_guids(obj: dict | list | str) -> set[str]:
|
|
"""Recursively extract UUID-like strings from a JSON structure."""
|
|
guids = set()
|
|
if isinstance(obj, dict):
|
|
for k, v in obj.items():
|
|
if k.lower() in ("id", "groupid", "userid", "targetid") and isinstance(v, str) and _GUID_RE.match(v):
|
|
guids.add(v)
|
|
guids.update(_extract_guids(v))
|
|
elif isinstance(obj, list):
|
|
for item in obj:
|
|
guids.update(_extract_guids(item))
|
|
elif isinstance(obj, str) and _GUID_RE.match(obj):
|
|
guids.add(obj)
|
|
return guids
|
|
|
|
|
|
async def _resolve_guids_for_event(event: dict) -> dict[str, str]:
    """Try to resolve GUIDs in an event to human-readable names via Graph API.

    GUIDs are gathered from the event's ``raw`` payload. Resolution is
    best-effort: any failure is logged and the names resolved up to that
    point are returned, so this function does not raise for lookup errors.

    Returns:
        Mapping of GUID -> resolved name; empty when nothing could be resolved.
    """
    raw = event.get("raw") or {}
    guids = _extract_guids(raw)

    # Also include any GUIDs in targetResources that might not have displayName.
    # The isinstance guards keep an unexpectedly shaped payload from raising
    # here, outside the best-effort try block below.
    if isinstance(raw, dict):
        for tr in raw.get("targetResources") or []:
            tid = tr.get("id") if isinstance(tr, dict) else None
            if isinstance(tid, str) and _GUID_RE.match(tid):
                guids.add(tid)
        for prop in raw.get("modifiedProperties") or []:
            if not isinstance(prop, dict):
                continue
            for key in ("oldValue", "newValue"):
                val = prop.get(key)
                if isinstance(val, str) and _GUID_RE.match(val):
                    guids.add(val)

    if not guids:
        return {}

    resolved: dict[str, str] = {}
    try:
        from graph.auth import get_access_token
        from graph.resolve import resolve_directory_object

        # The Graph helpers are synchronous; run them in worker threads so
        # the event loop is not blocked.
        token = await asyncio.to_thread(get_access_token)
        cache: dict[str, dict] = {}
        # Sorted so the lookup order (and the resulting prompt text) is
        # deterministic; set iteration order is not.
        for gid in sorted(guids):
            result = await asyncio.to_thread(resolve_directory_object, gid, token, cache)
            if result and result.get("name"):
                resolved[gid] = result["name"]
    except Exception as exc:
        # Stop at the first failure but keep whatever was already resolved.
        logger.warning("GUID resolution failed", error=str(exc))
    return resolved
|
|
|
|
|
|
async def _explain_event(event: dict, related: list[dict]) -> str:
    """Ask the configured LLM to explain a single audit event.

    Args:
        event: the event document; serialised to JSON for the prompt.
        related: other events given as context; at most the first 10 are used.

    Returns:
        The model's explanation with surrounding whitespace removed
        (never empty).

    Raises:
        RuntimeError: if ``LLM_API_KEY`` is unset, the API responds with a
            status >= 400, or the response carries no usable text.
    """
    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY not configured")

    # Resolve GUIDs to names before sending to LLM
    resolved = await _resolve_guids_for_event(event)

    event_text = json.dumps(event, indent=2, default=str)

    resolution_text = ""
    if resolved:
        resolution_text = "\nResolved GUIDs:\n" + "".join(f" {gid} → {name}\n" for gid, name in resolved.items())

    related_text = ""
    if related:
        rows = []
        for i, e in enumerate(related[:10], 1):
            # `or` (not a .get default) so explicit None values are handled too
            ts = str(e.get("timestamp") or "?")[:16].replace("T", " ")
            op = e.get("operation") or "unknown"
            actor = e.get("actor_display") or "unknown"
            targets = ", ".join(str(t) for t in e.get("target_displays") or []) or "—"
            result = e.get("result") or "—"
            rows.append(f"{i}. {ts} — {op} by {actor} on {targets} ({result})\n")
        related_text = "\n\nRelated events in the last 24 hours:\n" + "".join(rows)

    messages = [
        {"role": "system", "content": _EXPLAIN_SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Audit event:\n{event_text}{resolution_text}{related_text}\n\nPlease explain this event.",
        },
    ]

    url = _build_chat_url(LLM_BASE_URL, LLM_API_VERSION)
    headers = {"Content-Type": "application/json"}
    # Azure OpenAI uses api-key header; standard OpenAI uses Bearer token
    base_lower = LLM_BASE_URL.lower()
    if "azure" in base_lower or "cognitiveservices" in base_lower:
        headers["api-key"] = LLM_API_KEY
    else:
        headers["Authorization"] = f"Bearer {LLM_API_KEY}"

    payload = {
        "model": LLM_MODEL,
        "messages": messages,
        "max_completion_tokens": 600,
    }

    async with httpx.AsyncClient(timeout=LLM_TIMEOUT_SECONDS) as client:
        resp = await client.post(url, headers=headers, json=payload)
        if resp.status_code >= 400:
            body = resp.text
            logger.error("LLM API error", status_code=resp.status_code, url=url, response_body=body)
            raise RuntimeError(f"LLM API error {resp.status_code}: {body[:500]}")
        data = resp.json()

    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise RuntimeError("LLM response did not contain a completion") from exc

    # Null or empty content must not be returned (and then cached by the
    # caller) as if it were a real explanation.
    if not isinstance(content, str) or not content.strip():
        raise RuntimeError("LLM returned an empty explanation")
    return content.strip()
|
|
|
|
|
|
@router.post("/events/{event_id}/explain")
async def explain_event(event_id: str, user: dict = Depends(require_auth)):
    """Return an LLM-generated explanation of a single audit event.

    Responds 404 if the event does not exist and 403 if it belongs to a
    privacy-restricted service/operation the caller may not access. The
    response dict contains ``explanation``, ``llm_used``, ``llm_error`` and
    ``related_count``; LLM failures are reported in-band instead of raising.
    """
    # The collection API is synchronous; keep it off the event loop.
    event = await asyncio.to_thread(events_collection.find_one, {"id": event_id})
    if not event:
        raise HTTPException(status_code=404, detail="Event not found")

    can_view_private = user_can_access_privacy_services(user)
    is_private = event.get("service") in PRIVACY_SERVICES or event.get("operation") in PRIVACY_SENSITIVE_OPERATIONS
    if is_private and not can_view_private:
        raise HTTPException(status_code=403, detail="Access to this event is restricted")

    event.pop("_id", None)

    # Fetch related events for context (same actor or target in last 24h)
    related = []
    since = (datetime.now(UTC) - timedelta(hours=24)).isoformat().replace("+00:00", "Z")
    actor = event.get("actor_upn") or event.get("actor_display")
    targets = event.get("target_displays") or []
    target = targets[0] if targets else None

    # Either a matching actor OR a matching target makes an event related.
    # (These criteria used to be ANDed together, which required both.)
    match_any = []
    if actor:
        match_any.extend([{"actor_upn": actor}, {"actor_display": actor}])
    if target:
        match_any.append({"target_displays": target})

    if match_any:
        criteria = [
            {"timestamp": {"$gte": since}},
            {"id": {"$ne": event_id}},
            {"$or": match_any},
        ]
        # Related events are fed to the LLM, so apply the same privacy
        # restrictions as for the primary event.
        if not can_view_private:
            if PRIVACY_SERVICES:
                criteria.append({"service": {"$nin": list(PRIVACY_SERVICES)}})
            if PRIVACY_SENSITIVE_OPERATIONS:
                criteria.append({"operation": {"$nin": list(PRIVACY_SENSITIVE_OPERATIONS)}})

        def _fetch_related() -> list[dict]:
            cursor = events_collection.find({"$and": criteria}).sort("timestamp", -1).limit(10)
            return list(cursor)

        try:
            related = await asyncio.to_thread(_fetch_related)
            for r in related:
                r.pop("_id", None)
                r.pop("raw", None)
        except Exception as exc:
            # Context is optional; explain the event without it.
            logger.warning("Failed to fetch related events", error=str(exc))

    if not LLM_API_KEY:
        return {
            "explanation": "LLM is not configured. Set LLM_API_KEY in your environment to enable event explanations.",
            "llm_used": False,
            "llm_error": "LLM_API_KEY not configured",
            "related_count": len(related),
        }

    # Check cache first
    redis = await get_arq_pool()
    cached = await get_cached_explain(redis, event_id)
    if cached:
        cached["related_count"] = len(related)
        return cached

    try:
        explanation = await _explain_event(event, related)
        result = {
            "explanation": explanation,
            "llm_used": True,
            "llm_error": None,
            "related_count": len(related),
        }
        await set_cached_explain(redis, event_id, result)
        return result
    except Exception as exc:
        logger.warning("Event explanation failed", error=str(exc))
        return {
            "explanation": "Unable to generate an explanation at this time. Please check the raw event details.",
            "llm_used": False,
            "llm_error": str(exc),
            "related_count": len(related),
        }
|
|
|
|
|
|
def _iso_or_none(value) -> str | None:
    """Normalise an optional request timestamp to the string form used in queries.

    NOTE(review): the AskRequest field types are not visible here; strings are
    passed through unchanged and datetimes are ISO-formatted — confirm
    against the model definition.
    """
    if not value:
        return None
    if isinstance(value, datetime):
        return value.isoformat().replace("+00:00", "Z")
    return str(value)


@router.post("/ask", response_model=AskResponse)
async def ask_question(body: AskRequest, user: dict = Depends(require_auth)):
    """Answer a natural-language question about audit events.

    Steps:
      1. Derive time range, entity and service intent from the question;
         explicit request filters (``body.start`` / ``body.end`` /
         ``body.services``) take precedence over what the text implies.
      2. Query matching events (privacy-restricted data is filtered out for
         callers without access) and sample them for the LLM.
      3. Produce the answer from cache, via a background job
         (``body.async_mode``) or a direct LLM call; if no LLM answer is
         available, fall back to a structured listing of the first events.

    Raises:
        HTTPException: 400 for an empty question, 500 if the database query fails.
    """
    question = body.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="Question is required")

    nl_start, nl_end = _extract_time_range(question)
    entity = _extract_entity(question)
    intent_services, _explicit_noisy = _extract_intent_services(question)

    # An explicit range in the request wins over one parsed from the text.
    # (body.start / body.end were already part of the cache key below but
    # were never applied to the query.)
    start = _iso_or_none(body.start) or nl_start
    end = _iso_or_none(body.end) or nl_end

    # Default to last 7 days if no time range detected
    now = datetime.now(UTC)
    if not start:
        start = (now - timedelta(days=7)).isoformat().replace("+00:00", "Z")
    if not end:
        end = now.isoformat().replace("+00:00", "Z")

    # -----------------------------------------------------------------------
    # Decide which services to query
    # -----------------------------------------------------------------------
    excluded_services: list[str] = []
    if body.services:
        # User explicitly filtered via UI — respect that exactly
        query_services = body.services
    elif intent_services is not None:
        # NL question implies specific services
        query_services = intent_services
    else:
        # Broad question with no intent — exclude noisy services by default.
        # Subtract explicitly so what is queried always agrees with what is
        # reported as excluded.
        query_services = sorted(_DEFAULT_ADMIN_SERVICES - _NOISY_SERVICES)
        excluded_services = sorted(_NOISY_SERVICES)

    # -----------------------------------------------------------------------
    # Build and run query
    # -----------------------------------------------------------------------
    can_view_private = user_can_access_privacy_services(user)
    query = _build_event_query(
        entity,
        start,
        end,
        services=query_services,
        actor=body.actor,
        operation=body.operation,
        result=body.result,
        include_tags=body.include_tags,
        exclude_tags=body.exclude_tags,
    )
    extra_filters = []
    if not can_view_private:
        if PRIVACY_SERVICES:
            extra_filters.append({"service": {"$nin": list(PRIVACY_SERVICES)}})
        if PRIVACY_SENSITIVE_OPERATIONS:
            extra_filters.append({"operation": {"$nin": list(PRIVACY_SENSITIVE_OPERATIONS)}})
    if extra_filters:
        query["$and"] = query.get("$and", []) + extra_filters

    def _run_query() -> tuple[int, list[dict]]:
        matched = events_collection.count_documents(query)
        # Fetch a generous window so we can apply smart sampling in Python
        cursor = events_collection.find(query).sort([("timestamp", -1)]).limit(1000)
        return matched, list(cursor)

    try:
        # The collection API is synchronous; keep it off the event loop.
        total, raw_events = await asyncio.to_thread(_run_query)
    except Exception as exc:
        logger.error("Failed to query events for ask", error=str(exc))
        raise HTTPException(status_code=500, detail=f"Database query failed: {exc}") from exc

    for e in raw_events:
        e["_id"] = str(e.get("_id", ""))

    # Apply smart sampling (preserves failures, prioritises admin-relevant services)
    events = _smart_sample(raw_events, max_events=LLM_MAX_EVENTS)

    def _query_info(event_count: int, *, with_query: bool = True) -> dict:
        info = {
            "entity": entity,
            "start": start,
            "end": end,
            "event_count": event_count,
            "total_matched": total,
            "services_queried": query_services,
            "excluded_services": excluded_services,
        }
        if with_query:
            info["mongo_query"] = json.dumps(query, default=str)
        return info

    # If no events, return early
    if not events:
        return AskResponse(
            answer="I couldn't find any audit events matching your question. Try broadening the time range or checking the spelling of the device/user name.",
            events=[],
            query_info=_query_info(0, with_query=False),
            llm_used=False,
            llm_error="LLM not used — no events found." if not LLM_API_KEY else None,
        )

    # Try LLM summarisation (with caching + optional async)
    answer = ""
    llm_used = False
    llm_error = None
    job_id = None

    filters_snapshot = {
        "services": body.services,
        "actor": body.actor,
        "operation": body.operation,
        "result": body.result,
        "start": body.start,
        "end": body.end,
        "include_tags": body.include_tags,
        "exclude_tags": body.exclude_tags,
    }

    if LLM_API_KEY:
        redis = await get_arq_pool()
        cached = await get_cached_ask(redis, question, filters_snapshot, events)
        if cached:
            answer = cached.get("answer", "")
            llm_used = cached.get("llm_used", False)
            llm_error = cached.get("llm_error")
        elif body.async_mode:
            job = await redis.enqueue_job(
                "process_ask_question",
                question,
                filters_snapshot,
                events,
                total,
                excluded_services,
            )
            job_id = job.job_id if job else None
            # Interpolate the real job id (the message used to contain a
            # literal "{job_id}" because the f-prefix was missing).
            pending_answer = (
                f"Your question is being processed. Poll /api/jobs/{job_id} for the result."
                if job_id
                else "Your question is being processed."
            )
            return AskResponse(
                answer=pending_answer,
                events=[_to_event_ref(e) for e in events],
                query_info=_query_info(len(events)),
                llm_used=False,
                llm_error=None,
                job_id=job_id,
            )
        else:
            try:
                answer = await _call_llm(question, events, total=total, excluded_services=excluded_services)
            except Exception as exc:
                answer = ""
                llm_error = f"LLM call failed: {exc}"
                logger.warning("LLM call failed, falling back to structured summary", error=str(exc))
            else:
                if answer:
                    llm_used = True
                    try:
                        await set_cached_ask(
                            redis,
                            question,
                            filters_snapshot,
                            events,
                            {
                                "answer": answer,
                                "llm_used": True,
                                "llm_error": None,
                            },
                        )
                    except Exception as exc:
                        # A cache write failure must not be reported as an
                        # LLM failure — the answer itself is fine.
                        logger.warning("Failed to cache LLM answer", error=str(exc))
                else:
                    # Never cache or report an empty answer as a success.
                    llm_error = "LLM call failed: empty answer"
    else:
        llm_error = "LLM_API_KEY is not configured. Set it in your .env to enable AI narrative summarisation."

    # Fallback: structured summary if LLM unavailable or failed
    if not answer:
        parts = [f"Found {total} event(s)"]
        if entity:
            parts.append(f"related to **{entity}**")
        if excluded_services:
            parts.append(f"(excluding {', '.join(excluded_services)})")
        parts.append(f"between {start[:10]} and {end[:10]}.\n")

        for i, e in enumerate(events[:10], 1):
            # `or` (not a .get default) so explicit None values are handled too
            ts = str(e.get("timestamp") or "?")[:16].replace("T", " ")
            op = e.get("operation") or "unknown action"
            actor = e.get("actor_display") or "unknown"
            targets = ", ".join(str(t) for t in e.get("target_displays") or []) or "—"
            result = e.get("result") or "—"
            parts.append(f"{i}. **{ts}** — {op} by {actor} on {targets} ({result})")

        if len(events) > 10:
            parts.append(f"\n...and {len(events) - 10} more events.")

        answer = "\n".join(parts)

    return AskResponse(
        answer=answer,
        events=[_to_event_ref(e) for e in events],
        query_info=_query_info(len(events)),
        llm_used=llm_used,
        llm_error=llm_error,
        job_id=job_id,
    )
|