- Replace `max_tokens` with `max_completion_tokens` (required by newer Azure models).
- Remove the hardcoded `temperature` (not supported by all model types).
- Log the response body on LLM API errors for easier debugging.
325 lines
11 KiB
Python
325 lines
11 KiB
Python
import asyncio
import json
import re
from datetime import UTC, datetime, timedelta
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit

import httpx
import structlog
from auth import require_auth
from config import LLM_API_KEY, LLM_API_VERSION, LLM_BASE_URL, LLM_MAX_EVENTS, LLM_MODEL, LLM_TIMEOUT_SECONDS
from database import events_collection
from fastapi import APIRouter, Depends, HTTPException
from models.api import AskRequest, AskResponse
|
|
|
|
# Router-level dependency: require_auth runs for every route registered on this router.
router = APIRouter(dependencies=[Depends(require_auth)])

# Structured logger for this module (logger name "aoc.ask").
logger = structlog.get_logger("aoc.ask")
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Time-range extraction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_TIME_PATTERNS = [
|
|
(r"\blast\s+(\d+)\s+days?\b", "days"),
|
|
(r"\blast\s+(\d+)\s+hours?\b", "hours"),
|
|
(r"\blast\s+(\d+)\s+minutes?\b", "minutes"),
|
|
(r"\blast\s+week\b", "week"),
|
|
(r"\byesterday\b", "yesterday"),
|
|
(r"\btoday\b", "today"),
|
|
(r"\bin\s+the\s+last\s+(\d+)\s+days?\b", "days"),
|
|
(r"\bin\s+the\s+last\s+(\d+)\s+hours?\b", "hours"),
|
|
]
|
|
|
|
|
|
def _extract_time_range(question: str) -> tuple[str | None, str | None]:
|
|
"""Return (start_iso, end_iso) or (None, None) if no time detected."""
|
|
now = datetime.now(UTC)
|
|
q_lower = question.lower()
|
|
|
|
for pattern, unit in _TIME_PATTERNS:
|
|
m = re.search(pattern, q_lower)
|
|
if not m:
|
|
continue
|
|
if unit == "week":
|
|
start = now - timedelta(days=7)
|
|
elif unit == "yesterday":
|
|
start = now - timedelta(days=1)
|
|
elif unit == "today":
|
|
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
|
else:
|
|
num = int(m.group(1))
|
|
delta = {"days": timedelta(days=num), "hours": timedelta(hours=num), "minutes": timedelta(minutes=num)}[
|
|
unit
|
|
]
|
|
start = now - delta
|
|
return start.isoformat().replace("+00:00", "Z"), now.isoformat().replace("+00:00", "Z")
|
|
|
|
return None, None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Entity extraction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_ENTITY_HINTS = [
|
|
r"device\s+['\"]?([^'\"\s]+)['\"]?",
|
|
r"user\s+['\"]?([^'\"\s]+)['\"]?",
|
|
r"laptop\s+['\"]?([^'\"\s]+)['\"]?",
|
|
r"vm\s+['\"]?([^'\"\s]+)['\"]?",
|
|
r"server\s+['\"]?([^'\"\s]+)['\"]?",
|
|
r"computer\s+['\"]?([^'\"\s]+)['\"]?",
|
|
]
|
|
|
|
_EMAIL_RE = re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+")
|
|
|
|
|
|
def _extract_entity(question: str) -> str | None:
|
|
"""Best-effort extraction of the device / user / entity name."""
|
|
q_lower = question.lower()
|
|
|
|
# Look for explicit hints: "device ABC123"
|
|
for pattern in _ENTITY_HINTS:
|
|
m = re.search(pattern, q_lower)
|
|
if m:
|
|
# Extract from the original question to preserve case
|
|
start, end = m.span(1)
|
|
return question[start:end].strip().rstrip("?.!,;:")
|
|
|
|
# Look for quoted strings
|
|
m = re.search(r'"([^"]{2,50})"', question)
|
|
if m:
|
|
return m.group(1).strip()
|
|
m = re.search(r"'([^']{2,50})'", question)
|
|
if m:
|
|
return m.group(1).strip()
|
|
|
|
# Look for email addresses
|
|
m = _EMAIL_RE.search(question)
|
|
if m:
|
|
return m.group(0)
|
|
|
|
return None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MongoDB query builder
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _build_event_query(entity: str | None, start: str | None, end: str | None) -> dict:
|
|
filters = []
|
|
|
|
if start or end:
|
|
time_filter = {}
|
|
if start:
|
|
time_filter["$gte"] = start
|
|
if end:
|
|
time_filter["$lte"] = end
|
|
filters.append({"timestamp": time_filter})
|
|
|
|
if entity:
|
|
entity_safe = re.escape(entity)
|
|
filters.append(
|
|
{
|
|
"$or": [
|
|
{"target_displays": {"$elemMatch": {"$regex": entity_safe, "$options": "i"}}},
|
|
{"actor_display": {"$regex": entity_safe, "$options": "i"}},
|
|
{"actor_upn": {"$regex": entity_safe, "$options": "i"}},
|
|
{"raw_text": {"$regex": entity_safe, "$options": "i"}},
|
|
]
|
|
}
|
|
)
|
|
|
|
return {"$and": filters} if filters else {}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# LLM summarisation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_SYSTEM_PROMPT = """You are an IT operations assistant. An administrator has asked a question about audit logs.
|
|
Your job is to read the list of audit events below and write a concise, plain-language answer.
|
|
|
|
Rules:
|
|
- Assume the reader is a non-expert admin.
|
|
- Group related events together and tell a coherent story.
|
|
- Highlight anything unusual, failed actions, or privilege escalations.
|
|
- Reference specific event numbers (e.g., "Event #3") when making claims so the user can verify.
|
|
- If there are no events, say so clearly.
|
|
- Keep the answer under 300 words.
|
|
- Do not invent events that are not in the list.
|
|
"""
|
|
|
|
|
|
def _format_events_for_llm(events: list[dict]) -> str:
    """Render *events* as numbered plain-text records for the LLM prompt.

    Each event becomes an ``Event #N`` header (1-based) followed by one
    indented line per field. Missing or falsy fields get an "unknown ..."
    placeholder (the summary defaults to empty). Records are separated by a
    blank line; an empty list yields ``""``.

    Field values come from audit data, so each is flattened to a single line.
    Embedded newlines therefore cannot start a line of their own, for example
    to forge an extra ``Event #N`` record.
    """

    def one_line(value: object) -> str:
        # Collapse every whitespace run (including newlines) to one space.
        return " ".join(str(value).split())

    def field(e: dict, key: str, default: str) -> str:
        return one_line(e.get(key) or default)

    lines = []
    for i, e in enumerate(events, 1):
        raw_targets = e.get("target_displays") or []
        if isinstance(raw_targets, str):
            # Tolerate a scalar instead of a list (joining a str would split it into characters).
            raw_targets = [raw_targets]
        # Skip None/empty entries; str.join would raise on non-strings.
        targets = ", ".join(filter(None, (one_line(t) for t in raw_targets if t))) or "unknown target"

        lines.append(
            f"Event #{i}\n"
            f" Time: {field(e, 'timestamp', 'unknown time')}\n"
            f" Service: {field(e, 'service', 'unknown service')}\n"
            f" Action: {field(e, 'operation', 'unknown action')}\n"
            f" Actor: {field(e, 'actor_display', 'unknown actor')}\n"
            f" Target: {targets}\n"
            f" Result: {field(e, 'result', 'unknown result')}\n"
            f" Summary: {field(e, 'display_summary', '')}\n"
        )
    return "\n".join(lines)
|
|
|
|
|
|
def _build_chat_url(base_url: str, api_version: str) -> str:
    """Construct the chat-completions URL, handling Azure OpenAI endpoints.

    Appends ``/chat/completions`` to the path unless it is already there.
    When *api_version* is non-empty, it is set as the ``api-version`` query
    parameter, replacing any existing one. Other query parameters already in
    *base_url* are preserved: the URL is parsed rather than string-appended,
    so a base URL with its own query string does not produce
    ``...?x=y/chat/completions``.
    """
    parts = urlsplit(base_url.strip())

    path = parts.path.rstrip("/")
    if not path.endswith("/chat/completions"):
        path = f"{path}/chat/completions"

    query = [
        (key, value)
        for key, value in parse_qsl(parts.query, keep_blank_values=True)
        if not (api_version and key == "api-version")
    ]
    if api_version:
        query.append(("api-version", api_version))

    return urlunsplit((parts.scheme, parts.netloc, path, urlencode(query), parts.fragment))
|
|
|
|
|
|
async def _call_llm(question: str, events: list[dict]) -> str:
    """Ask the configured chat-completions endpoint to answer *question* from *events*.

    Sends ``_SYSTEM_PROMPT`` plus a user message containing the question and
    the numbered event list, then returns the stripped text of the first choice.

    Authentication uses the ``api-key`` header when LLM_BASE_URL contains
    "azure" or "cognitiveservices"; otherwise it uses ``Authorization: Bearer``.

    Raises:
        RuntimeError: LLM_API_KEY is unset; the API answered with HTTP >= 400;
            the body is not JSON or lacks ``choices[0].message.content``; or
            that content is null/blank (the message includes the choice's
            ``finish_reason`` to help diagnose why).
        httpx.HTTPError: transport failures and timeouts are not caught here.
    """
    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY not configured")

    context = _format_events_for_llm(events)
    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Question: {question}\n\nAudit events:\n{context}\n\nPlease answer the question based only on the events above.",
        },
    ]

    url = _build_chat_url(LLM_BASE_URL, LLM_API_VERSION)
    headers = {"Content-Type": "application/json"}
    # Azure OpenAI uses api-key header; standard OpenAI uses Bearer token
    base_lower = LLM_BASE_URL.lower()
    if "azure" in base_lower or "cognitiveservices" in base_lower:
        headers["api-key"] = LLM_API_KEY
    else:
        headers["Authorization"] = f"Bearer {LLM_API_KEY}"

    payload = {
        "model": LLM_MODEL,
        "messages": messages,
        # max_completion_tokens (not max_tokens) for newer Azure models; no
        # temperature, since not every model type accepts it.
        "max_completion_tokens": 800,
    }

    async with httpx.AsyncClient(timeout=LLM_TIMEOUT_SECONDS) as client:
        resp = await client.post(url, headers=headers, json=payload)
        if resp.status_code >= 400:
            body = resp.text
            logger.error("LLM API error", status_code=resp.status_code, url=url, response_body=body)
            raise RuntimeError(f"LLM API error {resp.status_code}: {body[:500]}")
        try:
            data = resp.json()
            choice = data["choices"][0]
            content = choice["message"]["content"]
        except (ValueError, KeyError, IndexError, TypeError) as exc:
            # ValueError covers a non-JSON body (json.JSONDecodeError).
            logger.error("LLM API returned an unexpected response", url=url, response_body=resp.text[:2000])
            raise RuntimeError("LLM API returned an unexpected response format") from exc

    if not isinstance(content, str) or not content.strip():
        finish_reason = choice.get("finish_reason") if isinstance(choice, dict) else None
        raise RuntimeError(f"LLM returned an empty answer (finish_reason={finish_reason!r})")
    return content.strip()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# API endpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _to_event_ref(e: dict) -> dict:
    """Project an event document onto the fields exposed in the /ask response.

    Every listed field is present in the result; fields missing from *e* map
    to ``None``. All other keys of *e* (e.g. ``_id``, ``raw_text``) are dropped.
    """
    exposed = (
        "id",
        "timestamp",
        "operation",
        "actor_display",
        "target_displays",
        "display_summary",
        "service",
        "result",
    )
    return {name: e.get(name) for name in exposed}
|
|
|
|
|
|
def _fetch_events(query: dict) -> list[dict]:
    """Return up to LLM_MAX_EVENTS events matching *query*, newest ``timestamp`` first.

    Iterates the cursor synchronously (blocking), so async callers should run
    this in a worker thread.
    """
    cursor = events_collection.find(query).sort([("timestamp", -1)]).limit(LLM_MAX_EVENTS)
    return list(cursor)


def _structured_summary(events: list[dict], entity: str | None, start: str, end: str) -> str:
    """Render a Markdown fallback answer: a one-line header plus up to 10 events.

    Used by ask_question when the LLM is not configured, fails, or returns no text.
    """
    header = f"Found {len(events)} event(s)"
    if entity:
        header += f" related to **{entity}**"
    header += f" between {start[:10]} and {end[:10]}."

    lines = [header, ""]
    for i, e in enumerate(events[:10], 1):
        # `or` defaults (not .get defaults): keys may be present with a None value,
        # and the timestamp may not be a str, so coerce before slicing.
        ts = str(e.get("timestamp") or "?")[:16].replace("T", " ")
        op = e.get("operation") or "unknown action"
        actor = e.get("actor_display") or "unknown"
        raw_targets = e.get("target_displays") or []
        if isinstance(raw_targets, str):
            raw_targets = [raw_targets]
        targets = ", ".join(str(t) for t in raw_targets if t) or "—"
        result = e.get("result") or "—"
        lines.append(f"{i}. **{ts}** — {op} by {actor} on {targets} ({result})")

    if len(events) > 10:
        lines.append(f"\n...and {len(events) - 10} more events.")
    return "\n".join(lines)


@router.post("/ask", response_model=AskResponse)
async def ask_question(body: AskRequest, user: dict = Depends(require_auth)):
    """Answer a natural-language question about audit events.

    Steps:
      1. Extract a time range (default: the last 7 days) and an entity name
         from the question.
      2. Query matching events.
      3. Summarise them with the LLM when LLM_API_KEY is set, otherwise (or
         on LLM failure or empty output) fall back to a structured Markdown list.

    Raises:
        HTTPException: 400 for a blank question; 500 if the database query fails.
    """
    question = body.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="Question is required")

    start, end = _extract_time_range(question)
    entity = _extract_entity(question)

    # Default to last 7 days if no time range detected
    if not start:
        now = datetime.now(UTC)
        start = (now - timedelta(days=7)).isoformat().replace("+00:00", "Z")
        end = now.isoformat().replace("+00:00", "Z")

    query = _build_event_query(entity, start, end)

    try:
        # The driver call blocks; run it off the event loop so a slow query
        # does not stall every other request served by this worker.
        events = await asyncio.to_thread(_fetch_events, query)
    except Exception as exc:
        logger.error("Failed to query events for ask", error=str(exc))
        # Keep driver/connection details in the server log, not in the client response.
        raise HTTPException(status_code=500, detail="Database query failed") from exc

    for e in events:
        e["_id"] = str(e.get("_id", ""))

    # If no events, return early
    if not events:
        return AskResponse(
            answer="I couldn't find any audit events matching your question. Try broadening the time range or checking the spelling of the device/user name.",
            events=[],
            query_info={
                "entity": entity,
                "start": start,
                "end": end,
                "event_count": 0,
                "mongo_query": json.dumps(query, default=str),
            },
            llm_used=False,
            llm_error="LLM not used — no events found." if not LLM_API_KEY else None,
        )

    # Try LLM summarisation
    answer = ""
    llm_used = False
    llm_error = None
    if not LLM_API_KEY:
        llm_error = "LLM_API_KEY is not configured. Set it in your .env to enable AI narrative summarisation."
    else:
        try:
            answer = await _call_llm(question, events)
        except Exception as exc:
            llm_error = f"LLM call failed: {exc}"
            logger.warning("LLM call failed, falling back to structured summary", error=str(exc))
        else:
            # Only report the LLM as used when it actually produced the answer text.
            if answer:
                llm_used = True
            else:
                llm_error = "LLM returned an empty answer."

    # Fallback: structured summary if LLM unavailable or failed
    if not answer:
        answer = _structured_summary(events, entity, start, end)

    return AskResponse(
        answer=answer,
        events=[_to_event_ref(e) for e in events],
        query_info={
            "entity": entity,
            "start": start,
            "end": end,
            "event_count": len(events),
            "mongo_query": json.dumps(query, default=str),
        },
        llm_used=llm_used,
        llm_error=llm_error,
    )
|