"""Natural-language Q&A over audit events.

Extracts a time range and an entity from the admin's question, queries the
events collection in MongoDB, and asks a chat-completions LLM to summarise
the matching events.  Falls back to a structured plain-text summary when the
LLM is unconfigured or fails.
"""

import json
import re
from collections import Counter
from datetime import UTC, datetime, timedelta

import httpx
import structlog
from auth import require_auth
from config import (
    LLM_API_KEY,
    LLM_API_VERSION,
    LLM_BASE_URL,
    LLM_MAX_EVENTS,
    LLM_MODEL,
    LLM_TIMEOUT_SECONDS,
)
from database import events_collection
from fastapi import APIRouter, Depends, HTTPException
from models.api import AskRequest, AskResponse

router = APIRouter(dependencies=[Depends(require_auth)])
logger = structlog.get_logger("aoc.ask")

# ---------------------------------------------------------------------------
# Time-range extraction
# ---------------------------------------------------------------------------

# Ordered (pattern, unit) pairs; the first match wins.  Note the bare
# "last N days/hours" patterns also match inside "in the last N days/hours",
# so separate "in the last ..." variants would be dead code and are omitted.
_TIME_PATTERNS = [
    (r"\blast\s+(\d+)\s+days?\b", "days"),
    (r"\blast\s+(\d+)\s+hours?\b", "hours"),
    (r"\blast\s+(\d+)\s+minutes?\b", "minutes"),
    (r"\blast\s+week\b", "week"),
    (r"\byesterday\b", "yesterday"),
    (r"\btoday\b", "today"),
]


def _extract_time_range(question: str) -> tuple[str | None, str | None]:
    """Return (start_iso, end_iso) or (None, None) if no time detected.

    Timestamps are UTC ISO-8601 with a trailing "Z".  The end of the range is
    always "now"; only the start varies with the phrase that matched.
    """
    now = datetime.now(UTC)
    q_lower = question.lower()
    for pattern, unit in _TIME_PATTERNS:
        m = re.search(pattern, q_lower)
        if not m:
            continue
        if unit == "week":
            start = now - timedelta(days=7)
        elif unit == "yesterday":
            # NOTE(review): "yesterday" is treated as a rolling 24 hours, not
            # the previous calendar day (midnight-to-midnight) — confirm intent.
            start = now - timedelta(days=1)
        elif unit == "today":
            start = now.replace(hour=0, minute=0, second=0, microsecond=0)
        else:
            num = int(m.group(1))
            delta = {
                "days": timedelta(days=num),
                "hours": timedelta(hours=num),
                "minutes": timedelta(minutes=num),
            }[unit]
            start = now - delta
        return (
            start.isoformat().replace("+00:00", "Z"),
            now.isoformat().replace("+00:00", "Z"),
        )
    return None, None


# ---------------------------------------------------------------------------
# Entity extraction
# ---------------------------------------------------------------------------

_ENTITY_HINTS = [
    r"device\s+['\"]?([^'\"\s]+)['\"]?",
    r"user\s+['\"]?([^'\"\s]+)['\"]?",
    r"laptop\s+['\"]?([^'\"\s]+)['\"]?",
    r"vm\s+['\"]?([^'\"\s]+)['\"]?",
    r"server\s+['\"]?([^'\"\s]+)['\"]?",
    r"computer\s+['\"]?([^'\"\s]+)['\"]?",
]

_EMAIL_RE = re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+")


def _extract_entity(question: str) -> str | None:
    """Best-effort extraction of the device / user / entity name.

    Tries, in order: keyword hints ("device ABC123"), quoted strings
    (double then single quotes, 2-50 chars), then an email address.
    Returns None when nothing plausible is found.
    """
    q_lower = question.lower()

    # Keyword hints: match against the lowercased question, but slice the
    # original string so the entity keeps its casing.
    for pattern in _ENTITY_HINTS:
        m = re.search(pattern, q_lower)
        if m:
            start, end = m.span(1)
            return question[start:end].strip().rstrip("?.!,;:")

    # Quoted strings.
    m = re.search(r'"([^"]{2,50})"', question)
    if m:
        return m.group(1).strip()
    m = re.search(r"'([^']{2,50})'", question)
    if m:
        return m.group(1).strip()

    # Email addresses.
    m = _EMAIL_RE.search(question)
    if m:
        return m.group(0)
    return None


# ---------------------------------------------------------------------------
# MongoDB query builder
# ---------------------------------------------------------------------------


def _build_event_query(
    entity: str | None,
    start: str | None,
    end: str | None,
    services: list[str] | None = None,
    actor: str | None = None,
    operation: str | None = None,
    result: str | None = None,
    include_tags: list[str] | None = None,
    exclude_tags: list[str] | None = None,
) -> dict:
    """Build a MongoDB filter document from the extracted/requested criteria.

    Every free-text criterion is regex-escaped before being embedded in a
    case-insensitive ``$regex`` clause, so user input cannot inject regex
    metacharacters.  Returns ``{}`` when no criteria apply (match all).
    """
    filters: list[dict] = []

    if start or end:
        # Timestamps are stored as ISO strings, so string comparison works.
        time_filter: dict[str, str] = {}
        if start:
            time_filter["$gte"] = start
        if end:
            time_filter["$lte"] = end
        filters.append({"timestamp": time_filter})

    if entity:
        # The entity may appear as a target, an actor, or only in raw text.
        entity_safe = re.escape(entity)
        filters.append(
            {
                "$or": [
                    {"target_displays": {"$elemMatch": {"$regex": entity_safe, "$options": "i"}}},
                    {"actor_display": {"$regex": entity_safe, "$options": "i"}},
                    {"actor_upn": {"$regex": entity_safe, "$options": "i"}},
                    {"raw_text": {"$regex": entity_safe, "$options": "i"}},
                ]
            }
        )

    if services:
        filters.append({"service": {"$in": services}})

    if actor:
        actor_safe = re.escape(actor)
        filters.append(
            {
                "$or": [
                    {"actor_display": {"$regex": actor_safe, "$options": "i"}},
                    {"actor_upn": {"$regex": actor_safe, "$options": "i"}},
                    {"actor.user.userPrincipalName": {"$regex": actor_safe, "$options": "i"}},
                ]
            }
        )

    if operation:
        filters.append({"operation": {"$regex": re.escape(operation), "$options": "i"}})

    if result:
        filters.append({"result": {"$regex": re.escape(result), "$options": "i"}})

    if include_tags:
        filters.append({"tags": {"$all": include_tags}})

    if exclude_tags:
        # NOTE(review): $not/$all only excludes events carrying *all* of the
        # exclude tags at once; if the intent is "exclude events that have any
        # of these tags", this should be {"tags": {"$nin": exclude_tags}} —
        # confirm against the API contract before changing.
        filters.append({"tags": {"$not": {"$all": exclude_tags}}})

    return {"$and": filters} if filters else {}


# ---------------------------------------------------------------------------
# LLM summarisation
# ---------------------------------------------------------------------------

_SYSTEM_PROMPT = """You are an IT operations assistant. An administrator has asked a question about audit logs.
Your job is to read the data below and write a concise, plain-language answer.

The input may be either:
- A small list of individual audit events (numbered Event #1, #2, etc.), or
- An aggregated overview with counts by service, action, result, and actor, plus sample events.

Rules:
- Assume the reader is a non-expert admin.
- For aggregated overviews: summarise the scale, top patterns, and highlight anomalies or failures.
- For small event lists: group related events together and tell a coherent story.
- Highlight anything unusual, failed actions, or privilege escalations.
- Reference specific event numbers (e.g., "Event #3") when making claims so the user can verify.
- If the data is an aggregated subset of a larger result set, acknowledge the scale (e.g., "847 events occurred — the top pattern was...").
- If there are no events, say so clearly.
- Keep the answer under 300 words.
- Do not invent events or patterns that are not supported by the data.
"""


def _aggregate_counts(events: list[dict]) -> dict:
    """Build lightweight aggregation tables for large result sets."""
    svc_counts = Counter(e.get("service") or "Unknown" for e in events)
    op_counts = Counter(e.get("operation") or "Unknown" for e in events)
    result_counts = Counter(e.get("result") or "Unknown" for e in events)
    actor_counts = Counter(e.get("actor_display") or "Unknown" for e in events)
    return {
        "services": svc_counts.most_common(10),
        "operations": op_counts.most_common(10),
        "results": result_counts.most_common(5),
        "actors": actor_counts.most_common(10),
    }


def _format_events_for_llm(events: list[dict], total: int | None = None) -> str:
    """Render events as plain text for the LLM prompt.

    Large result sets (>= 50 events truncated from a bigger total) are sent
    as aggregation tables plus failure highlights; smaller sets are dumped
    event by event.  At most the first 50 events are listed in detail.
    """
    lines: list[str] = []

    aggregated = total is not None and total > len(events) and len(events) >= 50
    if aggregated:
        lines.append(f"Result set overview: {total} total events (showing the {len(events)} most recent).\n")
        agg = _aggregate_counts(events)
        lines.append("Breakdown by service:")
        for svc, cnt in agg["services"]:
            lines.append(f" {svc}: {cnt}")
        lines.append("\nBreakdown by action:")
        for op, cnt in agg["operations"]:
            lines.append(f" {op}: {cnt}")
        lines.append("\nBreakdown by result:")
        for res, cnt in agg["results"]:
            lines.append(f" {res}: {cnt}")
        lines.append("\nTop actors:")
        for actor, cnt in agg["actors"]:
            lines.append(f" {actor}: {cnt}")

        # Call out failures explicitly — they are the most likely anomaly.
        failures = [e for e in events if str(e.get("result") or "").lower() in ("failure", "failed")]
        if failures:
            lines.append(f"\nFailures ({len(failures)}):")
            for e in failures[:10]:
                # "or" rather than a .get() default: a stored None value would
                # otherwise reach the slice and raise TypeError.
                ts = (e.get("timestamp") or "?")[:16].replace("T", " ")
                op = e.get("operation", "unknown")
                actor = e.get("actor_display", "unknown")
                lines.append(f" {ts} — {op} by {actor}")
        # Emit the samples header whether or not there were failures
        # (previously it was skipped when the failure list was empty).
        lines.append("\nMost recent sample events:")
    elif total is not None and total > len(events):
        lines.append(f"Showing {len(events)} of {total} total matching events (most recent first):\n")

    # Always include the first N raw events as detail (up to 50)
    for i, e in enumerate(events[:50], 1):
        ts = e.get("timestamp") or "unknown time"
        op = e.get("operation") or "unknown action"
        actor = e.get("actor_display") or "unknown actor"
        targets = ", ".join(e.get("target_displays") or []) or "unknown target"
        svc = e.get("service") or "unknown service"
        result = e.get("result") or "unknown result"
        summary = e.get("display_summary") or ""
        lines.append(
            f"Event #{i}\n"
            f" Time: {ts}\n"
            f" Service: {svc}\n"
            f" Action: {op}\n"
            f" Actor: {actor}\n"
            f" Target: {targets}\n"
            f" Result: {result}\n"
            f" Summary: {summary}\n"
        )

    return "\n".join(lines)


def _build_chat_url(base_url: str, api_version: str) -> str:
    """Construct the chat completions URL, handling Azure OpenAI endpoints."""
    base = base_url.rstrip("/")
    url = base if base.endswith("/chat/completions") else f"{base}/chat/completions"
    if api_version:
        # Azure OpenAI requires an explicit api-version query parameter.
        url = f"{url}?api-version={api_version}"
    return url


async def _call_llm(question: str, events: list[dict], total: int | None = None) -> str:
    """Ask the configured chat-completions endpoint to summarise *events*.

    Returns the assistant's answer text.  Raises RuntimeError when the API
    key is missing or the endpoint returns an HTTP error status.
    """
    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY not configured")

    context = _format_events_for_llm(events, total=total)
    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Question: {question}\n\nAudit events:\n{context}\n\nPlease answer the question based only on the events above.",
        },
    ]

    url = _build_chat_url(LLM_BASE_URL, LLM_API_VERSION)
    headers = {
        "Content-Type": "application/json",
    }
    # Azure OpenAI uses api-key header; standard OpenAI uses Bearer token
    if "azure" in LLM_BASE_URL.lower() or "cognitiveservices" in LLM_BASE_URL.lower():
        headers["api-key"] = LLM_API_KEY
    else:
        headers["Authorization"] = f"Bearer {LLM_API_KEY}"

    payload = {
        "model": LLM_MODEL,
        "messages": messages,
        "max_completion_tokens": 800,
    }

    async with httpx.AsyncClient(timeout=LLM_TIMEOUT_SECONDS) as client:
        resp = await client.post(url, headers=headers, json=payload)
        if resp.status_code >= 400:
            body = resp.text
            logger.error("LLM API error", status_code=resp.status_code, url=url, response_body=body)
            raise RuntimeError(f"LLM API error {resp.status_code}: {body[:500]}")
        data = resp.json()
        return data["choices"][0]["message"]["content"].strip()


# ---------------------------------------------------------------------------
# API endpoint
# ---------------------------------------------------------------------------


def _to_event_ref(e: dict) -> dict:
    """Project a raw event document down to the fields the API response needs."""
    return {
        "id": e.get("id"),
        "timestamp": e.get("timestamp"),
        "operation": e.get("operation"),
        "actor_display": e.get("actor_display"),
        "target_displays": e.get("target_displays"),
        "display_summary": e.get("display_summary"),
        "service": e.get("service"),
        "result": e.get("result"),
    }


@router.post("/ask", response_model=AskResponse)
async def ask_question(body: AskRequest, user: dict = Depends(require_auth)):
    """Answer a natural-language question about audit events.

    Pipeline: extract time range and entity from the question, build and run
    a MongoDB query (capped at LLM_MAX_EVENTS most-recent events), then ask
    the LLM for a narrative answer — or fall back to a structured summary.
    """
    question = body.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="Question is required")

    start, end = _extract_time_range(question)
    entity = _extract_entity(question)

    # Default to last 7 days if no time range detected
    if not start:
        now = datetime.now(UTC)
        start = (now - timedelta(days=7)).isoformat().replace("+00:00", "Z")
        end = now.isoformat().replace("+00:00", "Z")

    query = _build_event_query(
        entity,
        start,
        end,
        services=body.services,
        actor=body.actor,
        operation=body.operation,
        result=body.result,
        include_tags=body.include_tags,
        exclude_tags=body.exclude_tags,
    )

    try:
        # NOTE(review): these look like synchronous PyMongo calls inside an
        # async endpoint — they block the event loop while the query runs.
        # Consider motor or run_in_executor; left unchanged here.
        total = events_collection.count_documents(query)
        cursor = events_collection.find(query).sort([("timestamp", -1)]).limit(LLM_MAX_EVENTS)
        events = list(cursor)
    except Exception as exc:
        logger.error("Failed to query events for ask", error=str(exc))
        raise HTTPException(status_code=500, detail=f"Database query failed: {exc}") from exc

    # ObjectId is not JSON-serialisable; stringify in place.
    for e in events:
        e["_id"] = str(e.get("_id", ""))

    # If no events, return early
    if not events:
        return AskResponse(
            answer="I couldn't find any audit events matching your question. Try broadening the time range or checking the spelling of the device/user name.",
            events=[],
            query_info={"entity": entity, "start": start, "end": end, "event_count": 0},
            llm_used=False,
            # The LLM is skipped because there is nothing to summarise —
            # report that regardless of whether an API key is configured.
            # (Previously this message was gated on the key being *missing*,
            # which did not match its wording.)
            llm_error="LLM not used — no events found.",
        )

    # Try LLM summarisation
    answer = ""
    llm_used = False
    llm_error = None
    if not LLM_API_KEY:
        llm_error = "LLM_API_KEY is not configured. Set it in your .env to enable AI narrative summarisation."
    else:
        try:
            answer = await _call_llm(question, events, total=total)
            llm_used = True
        except Exception as exc:
            llm_error = f"LLM call failed: {exc}"
            logger.warning("LLM call failed, falling back to structured summary", error=str(exc))

    # Fallback: structured summary if LLM unavailable or failed
    if not answer:
        parts = [f"Found {len(events)} event(s)"]
        if entity:
            parts.append(f"related to **{entity}**")
        parts.append(f"between {start[:10]} and {end[:10]}.\n")
        for i, e in enumerate(events[:10], 1):
            # "or" guards against a stored None timestamp (a .get() default
            # would not fire for an existing-but-None key).
            ts = (e.get("timestamp") or "?")[:16].replace("T", " ")
            op = e.get("operation", "unknown action")
            actor = e.get("actor_display", "unknown")
            targets = ", ".join(e.get("target_displays") or []) or "—"
            result = e.get("result", "—")
            parts.append(f"{i}. **{ts}** — {op} by {actor} on {targets} ({result})")
        if len(events) > 10:
            parts.append(f"\n...and {len(events) - 10} more events.")
        answer = "\n".join(parts)

    return AskResponse(
        answer=answer,
        events=[_to_event_ref(e) for e in events],
        query_info={
            "entity": entity,
            "start": start,
            "end": end,
            "event_count": len(events),
            "total_matched": total,
            "mongo_query": json.dumps(query, default=str),
        },
        llm_used=llm_used,
        llm_error=llm_error,
    )