import json
import re
from datetime import UTC, datetime, timedelta

import httpx
import structlog
from auth import require_auth
from config import (
    LLM_API_KEY,
    LLM_API_VERSION,
    LLM_BASE_URL,
    LLM_MAX_EVENTS,
    LLM_MODEL,
    LLM_TIMEOUT_SECONDS,
)
from database import events_collection
from fastapi import APIRouter, Depends, HTTPException
from models.api import AskRequest, AskResponse

# Every route registered on this router runs the ``require_auth`` dependency.
router = APIRouter(dependencies=[Depends(require_auth)])
logger = structlog.get_logger("aoc.ask")

# ---------------------------------------------------------------------------
# Intent extraction — map question keywords to relevant audit services
# ---------------------------------------------------------------------------

# (services to query, question keywords that imply them). Flattened below into
# a keyword -> list-of-services lookup; group order preserves keyword order.
_INTENT_GROUPS: tuple[tuple[tuple[str, ...], tuple[str, ...]], ...] = (
    (("Intune",), ("intune",)),
    (("Intune", "Device"), ("device", "laptop", "mobile", "phone", "ipad")),
    (("Intune", "ApplicationManagement"), ("app", "application")),
    (("Intune", "Policy"), ("policy", "compliance")),
    (("Directory", "UserManagement"), ("user",)),
    (("Directory", "GroupManagement"), ("group",)),
    (("Directory", "RoleManagement"), ("role", "permission")),
    (("Directory", "License"), ("license",)),
    (("Exchange",), ("email", "mailbox", "mail")),
    (("Exchange", "Teams"), ("message",)),
    (("SharePoint",), ("file", "sharepoint", "site", "document")),
    (("Teams",), ("team", "channel", "meeting", "call")),
)

# Each keyword gets its own fresh list so no two keys share a list object.
_SERVICE_INTENTS = {
    keyword: list(services)
    for services, keywords in _INTENT_GROUPS
    for keyword in keywords
}

# High-volume services that would drown out signal on broad questions; they are
# only queried by default when the question explicitly mentions them.
_NOISY_SERVICES = {"Exchange", "SharePoint"}

# Services that are generally admin-relevant and kept by default.
_DEFAULT_ADMIN_SERVICES = {
    "Directory",
    "UserManagement",
    "GroupManagement",
    "RoleManagement",
    "ApplicationManagement",
    "Intune",
    "Device",
    "Policy",
    "Teams",
    "License",
}


def _keyword_candidates(token: str) -> set[str]:
    """Return *token* plus naive singular forms so plurals hit the intent table.

    ``"devices"`` also yields ``"device"``, ``"policies"`` yields ``"policy"`` and
    ``"mailboxes"`` yields ``"mailbox"``. Spurious stems (``"devic"``) are harmless:
    only exact keys of ``_SERVICE_INTENTS`` are ever looked up.
    """
    candidates = {token}
    if len(token) > 3 and token.endswith("s"):
        candidates.add(token[:-1])
        if token.endswith("es"):
            candidates.add(token[:-2])
        if token.endswith("ies"):
            candidates.add(token[:-3] + "y")
    return candidates


def _extract_intent_services(question: str) -> tuple[list[str] | None, bool]:
    """
    Extract relevant services from the question.

    Words are matched case-insensitively against ``_SERVICE_INTENTS``; simple
    plural forms ("devices", "emails", "policies") count as their singular.

    Returns:
        (services, is_explicit):
        - services: sorted list of service names to query, or None when no
          keyword matched (caller should use the default admin set)
        - is_explicit: True if the matched services include a noisy service
    """
    tokens: set[str] = set()
    for word in re.findall(r"\b[a-z]+\b", question.lower()):
        tokens |= _keyword_candidates(word)

    matched_services: set[str] = set()
    for keyword, services in _SERVICE_INTENTS.items():
        if keyword in tokens:
            matched_services.update(services)

    if not matched_services:
        # Broad question with no clear intent — caller falls back to the admin set
        return None, False

    # User asked something specific — return exactly what they asked for
    is_explicit = not matched_services.isdisjoint(_NOISY_SERVICES)
    return sorted(matched_services), is_explicit


# ---------------------------------------------------------------------------
# Smart sampling — stratified by importance so the LLM sees signal, not noise
# ---------------------------------------------------------------------------

_HIGH_VALUE_SERVICES = frozenset(
    {
        "Directory",
        "UserManagement",
        "GroupManagement",
        "RoleManagement",
        "Intune",
        "Device",
        "Policy",
        "ApplicationManagement",
    }
)

_FAILURE_RESULTS = ("failure", "failed")


def _is_failure(event: dict) -> bool:
    """Return True when the event's ``result`` field reads as a failure."""
    return str(event.get("result") or "").lower() in _FAILURE_RESULTS


def _smart_sample(events: list[dict], max_events: int = 200) -> list[dict]:
    """
    Return a curated subset of at most *max_events* events that preserves
    diversity and prioritises signal.

    Tiers (each event belongs to exactly one, first match wins):
      1. Failures (always valuable)
      2. High-admin-value services (Intune, Device, Directory, etc.)
      3. Everything else

    Failures and high-value events each get a reserved share of the slots
    (at least 10 / 20, or a quarter each), "everything else" gets the rest,
    and any slots left unused are handed back to failures, then high-value.
    Within a tier, the input order decides which events are kept.

    The input list is returned unchanged when it already fits; otherwise
    the sample is returned newest-first by ``timestamp``. A non-positive
    *max_events* yields an empty sample.
    """
    if len(events) <= max_events:
        return events
    if max_events <= 0:
        return []

    # Single pass partition (the classification is a pure function of the event,
    # so this matches list-membership filtering without its O(n²) dict compares).
    failures: list[dict] = []
    high_val: list[dict] = []
    rest: list[dict] = []
    for e in events:
        if _is_failure(e):
            failures.append(e)
        elif e.get("service") in _HIGH_VALUE_SERVICES:
            high_val.append(e)
        else:
            rest.append(e)

    # Clamp every cap to the remaining budget so the sample never exceeds
    # max_events (a negative rest cap would otherwise slice from the end).
    failure_cap = min(len(failures), max(10, max_events // 4), max_events)
    high_cap = min(len(high_val), max(20, max_events // 4), max_events - failure_cap)
    rest_cap = min(len(rest), max_events - failure_cap - high_cap)
    sampled = failures[:failure_cap] + high_val[:high_cap] + rest[:rest_cap]

    # If "everything else" could not fill its share, give the spare slots back
    # to the higher tiers instead of sending the LLM fewer events than allowed.
    spare = max_events - len(sampled)
    for tier, used in ((failures, failure_cap), (high_val, high_cap)):
        if spare <= 0:
            break
        extra = tier[used : used + spare]
        sampled.extend(extra)
        spare -= len(extra)

    # Present newest first, matching the database query's sort order
    sampled.sort(key=lambda e: e.get("timestamp") or "", reverse=True)
    return sampled


# ---------------------------------------------------------------------------
# Time-range extraction
# ---------------------------------------------------------------------------

# Length of one unit for relative phrases like "last 3 days". A month is
# approximated as 30 days.
_RELATIVE_UNITS = {
    "minute": timedelta(minutes=1),
    "hour": timedelta(hours=1),
    "day": timedelta(days=1),
    "week": timedelta(days=7),
    "month": timedelta(days=30),
}

# Checked in order; the first pattern that matches anywhere in the question wins.
_TIME_PATTERNS = [
    # "last 3 days", "in the past 12 hours", "last 2 weeks"
    (re.compile(r"\b(?:last|past)\s+(\d+)\s+(minute|hour|day|week|month)s?\b"), "count"),
    # "last week", "past hour", "last month"
    (re.compile(r"\b(?:last|past)\s+(minute|hour|day|week|month)\b"), "single"),
    (re.compile(r"\byesterday\b"), "yesterday"),
    (re.compile(r"\btoday\b"), "today"),
]


def _to_iso_z(dt: datetime) -> str:
    """Format an aware UTC datetime as ISO-8601 with a trailing ``Z``."""
    return dt.isoformat().replace("+00:00", "Z")


def _extract_time_range(question: str) -> tuple[str | None, str | None]:
    """Return (start_iso, end_iso) or (None, None) if no time detected.

    Relative phrases end at the current time. "today" and "yesterday" are
    UTC calendar days. Ranges too large to represent as a datetime are
    ignored, as if no time phrase had been given.
    """
    now = datetime.now(UTC)
    midnight = now.replace(hour=0, minute=0, second=0, microsecond=0)
    q_lower = question.lower()

    for pattern, kind in _TIME_PATTERNS:
        m = pattern.search(q_lower)
        if not m:
            continue

        end = now
        if kind == "yesterday":
            # The previous calendar day, not "the last 24 hours up to now".
            start, end = midnight - timedelta(days=1), midnight
        elif kind == "today":
            start = midnight
        else:
            if kind == "count":
                count, unit = int(m.group(1)), m.group(2)
            else:
                count, unit = 1, m.group(1)
            try:
                start = now - _RELATIVE_UNITS[unit] * count
            except OverflowError:
                # e.g. "last 99999999999 days" — not representable; ignore it
                continue
        return _to_iso_z(start), _to_iso_z(end)

    return None, None


# ---------------------------------------------------------------------------
# Entity extraction
# ---------------------------------------------------------------------------

_ENTITY_KEYWORDS = ("device", "user", "laptop", "vm", "server", "computer")

# After a keyword: a double- or single-quoted name (may contain spaces), or a
# single bare word optionally preceded by a quote.
_ENTITY_HINT_TAIL = r"""\s+(?:"([^"]{2,50})"|'([^']{2,50})'|['"]?([^'"\s]+))"""

# One pattern per keyword, tried in keyword order. Matching the ORIGINAL
# question case-insensitively keeps spans aligned with the text we slice from
# (str.lower() can change string length for some Unicode characters).
_ENTITY_HINTS = [
    re.compile(r"\b" + keyword + _ENTITY_HINT_TAIL, re.IGNORECASE) for keyword in _ENTITY_KEYWORDS
]

# Words that commonly follow a hint keyword in ordinary phrasing ("which user
# changed…", "device compliance policies") and are almost never the entity.
_ENTITY_STOPWORDS = frozenset(
    {
        "a", "an", "the", "this", "that", "these", "those", "my", "our", "their",
        "any", "all", "some", "each", "every", "who", "what", "which", "when", "where",
        "is", "was", "were", "are", "be", "been", "has", "have", "had", "do", "does", "did",
        "can", "could", "should", "would", "will",
        "and", "or", "of", "in", "on", "at", "to", "for", "from", "with", "by", "about",
        "account", "accounts", "activity", "activities", "action", "actions",
        "change", "changes", "changed", "event", "events", "log", "logs",
        "created", "deleted", "updated", "added", "removed", "modified",
        "management", "permission", "permissions", "role", "roles", "group", "groups",
        "policy", "policies", "compliance", "settings", "sign", "signin", "signins",
        "login", "logins", "last", "past", "yesterday", "today", "recently",
    }
)

_DOUBLE_QUOTED_RE = re.compile(r'"([^"]{2,50})"')
# Not preceded/followed by a word character, so possessives and contractions
# ("bob's", "don't") are not mistaken for quotes.
_SINGLE_QUOTED_RE = re.compile(r"(?<!\w)'([^']{2,50})'(?!\w)")

_EMAIL_RE = re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+")


def _extract_entity(question: str) -> str | None:
    """Best-effort extraction of the device / user / entity name.

    Tried in order: a value after a hint keyword ("device ABC123",
    'user "Jane Doe"'), then any quoted string, then an email address.
    The original casing of the question is preserved. Returns None when
    nothing suitable is found.
    """
    for pattern in _ENTITY_HINTS:
        for m in pattern.finditer(question):
            raw = next((group for group in m.groups() if group), "")
            candidate = raw.strip().rstrip("?.!,;:")
            if candidate and candidate.lower() not in _ENTITY_STOPWORDS:
                return candidate

    for quoted_re in (_DOUBLE_QUOTED_RE, _SINGLE_QUOTED_RE):
        m = quoted_re.search(question)
        if m:
            value = m.group(1).strip()
            if value:
                return value

    m = _EMAIL_RE.search(question)
    return m.group(0) if m else None


# ---------------------------------------------------------------------------
# MongoDB query builder
# ---------------------------------------------------------------------------


def _ci_regex(value: str) -> dict:
    """Case-insensitive MongoDB substring match for *value* (regex-escaped)."""
    return {"$regex": re.escape(value), "$options": "i"}


def _build_event_query(
    entity: str | None,
    start: str | None,
    end: str | None,
    services: list[str] | None = None,
    actor: str | None = None,
    operation: str | None = None,
    result: str | None = None,
    include_tags: list[str] | None = None,
    exclude_tags: list[str] | None = None,
) -> dict:
    """Build a MongoDB filter document ANDing together every supplied criterion.

    Free-text criteria (entity, actor, operation, result) are escaped and
    matched case-insensitively as substrings. ``start``/``end`` are compared
    as strings, so they must use the same ISO format as the stored
    ``timestamp`` values. Returns ``{}`` (match everything) when no
    criteria are given.
    """
    filters: list[dict] = []

    if start or end:
        time_filter = {}
        if start:
            time_filter["$gte"] = start
        if end:
            time_filter["$lte"] = end
        filters.append({"timestamp": time_filter})

    if entity:
        filters.append(
            {
                "$or": [
                    {"target_displays": {"$elemMatch": _ci_regex(entity)}},
                    {"actor_display": _ci_regex(entity)},
                    {"actor_upn": _ci_regex(entity)},
                    {"raw_text": _ci_regex(entity)},
                ]
            }
        )

    if services:
        filters.append({"service": {"$in": services}})

    if actor:
        filters.append(
            {
                "$or": [
                    {"actor_display": _ci_regex(actor)},
                    {"actor_upn": _ci_regex(actor)},
                    {"actor.user.userPrincipalName": _ci_regex(actor)},
                ]
            }
        )

    if operation:
        filters.append({"operation": _ci_regex(operation)})

    if result:
        filters.append({"result": _ci_regex(result)})

    if include_tags:
        filters.append({"tags": {"$all": include_tags}})

    if exclude_tags:
        # NOTE(review): this excludes only events carrying ALL of exclude_tags.
        # If "exclude any of these tags" is intended, use {"$nin": exclude_tags}.
        filters.append({"tags": {"$not": {"$all": exclude_tags}}})

    return {"$and": filters} if filters else {}


# ---------------------------------------------------------------------------
# LLM summarisation
# ---------------------------------------------------------------------------

_SYSTEM_PROMPT = """You are an IT operations assistant. An administrator has asked a question about audit logs. Your job is to read the data below and write a concise, plain-language answer.

The input may be either:
- A small list of individual audit events (numbered Event #1, #2, etc.), or
- An aggregated overview with counts by service, action, result, and actor, plus sample events.

Rules:
- Assume the reader is a non-expert admin.
- For aggregated overviews: summarise the scale, top patterns, and highlight anomalies or failures.
- For small event lists: group related events together and tell a coherent story.
- Highlight anything unusual, failed actions, or privilege escalations.
- Reference specific event numbers (e.g., "Event #3") when making claims so the user can verify.
- If the data is an aggregated subset of a larger result set, acknowledge the scale (e.g., "847 events occurred — the top pattern was...").
- If there are no events, say so clearly.
- Keep the answer under 300 words.
- Do not invent events or patterns that are not supported by the data.
"""


def _aggregate_counts(events: list[dict]) -> dict:
    """Build lightweight aggregation tables (top-N counts) over *events*."""
    from collections import Counter

    svc_counts = Counter(e.get("service") or "Unknown" for e in events)
    op_counts = Counter(e.get("operation") or "Unknown" for e in events)
    result_counts = Counter(e.get("result") or "Unknown" for e in events)
    actor_counts = Counter(e.get("actor_display") or "Unknown" for e in events)
    return {
        "services": svc_counts.most_common(10),
        "operations": op_counts.most_common(10),
        "results": result_counts.most_common(5),
        "actors": actor_counts.most_common(10),
    }


def _format_events_for_llm(
    events: list[dict], total: int | None = None, excluded_services: list[str] | None = None
) -> str:
    """Render *events* as the plain-text context block sent to the LLM.

    When *total* exceeds the sample and the sample has at least 50 events,
    an aggregated overview is emitted first: breakdowns computed over the
    sample, plus up to 10 failures. Up to 50 numbered events then follow
    either way.
    """
    lines: list[str] = []

    # If we have a large result set, send aggregation + samples instead of raw dump
    if total is not None and total > len(events) and len(events) >= 50:
        # Breakdowns are computed over the sample only; say so, otherwise the
        # model tends to report sample counts as if they covered the full total.
        lines.append(
            f"Result set overview: {total} total events (showing a curated sample of {len(events)}; "
            "the breakdowns below are counted over the sample only).\n"
        )
        if excluded_services:
            lines.append(f"Note: high-volume services excluded by default: {', '.join(excluded_services)}.\n")

        agg = _aggregate_counts(events)
        lines.append("Breakdown by service:")
        lines.extend(f"  {svc}: {cnt}" for svc, cnt in agg["services"])
        lines.append("\nBreakdown by action:")
        lines.extend(f"  {op}: {cnt}" for op, cnt in agg["operations"])
        lines.append("\nBreakdown by result:")
        lines.extend(f"  {res}: {cnt}" for res, cnt in agg["results"])
        lines.append("\nTop actors:")
        lines.extend(f"  {actor}: {cnt}" for actor, cnt in agg["actors"])

        failures = [e for e in events if _is_failure(e)]
        if failures:
            lines.append(f"\nFailures ({len(failures)}):")
            for e in failures[:10]:
                # `or` (not a .get default) so a stored null doesn't crash the slice
                ts = str(e.get("timestamp") or "?")[:16].replace("T", " ")
                op = e.get("operation") or "unknown"
                actor = e.get("actor_display") or "unknown"
                lines.append(f"  {ts} — {op} by {actor}")

        lines.append("\nMost recent sample events:")
    elif total is not None and total > len(events):
        lines.append(f"Showing {len(events)} of {total} total matching events (most recent first):\n")

    # Always include the first N raw events as detail (up to 50)
    for i, e in enumerate(events[:50], 1):
        ts = e.get("timestamp") or "unknown time"
        op = e.get("operation") or "unknown action"
        actor = e.get("actor_display") or "unknown actor"
        targets = ", ".join(e.get("target_displays") or []) or "unknown target"
        svc = e.get("service") or "unknown service"
        result = e.get("result") or "unknown result"
        summary = e.get("display_summary") or ""
        lines.append(
            f"Event #{i}\n"
            f"  Time: {ts}\n"
            f"  Service: {svc}\n"
            f"  Action: {op}\n"
            f"  Actor: {actor}\n"
            f"  Target: {targets}\n"
            f"  Result: {result}\n"
            f"  Summary: {summary}\n"
        )

    return "\n".join(lines)


def _build_chat_url(base_url: str, api_version: str) -> str:
    """Construct the chat completions URL, handling Azure OpenAI endpoints.

    Appends ``/chat/completions`` unless the path already ends with it, and
    adds ``api-version`` as a query parameter when given. Any query string
    already on *base_url* is kept and merged rather than doubled.
    """
    path, _, existing_query = base_url.partition("?")
    path = path.rstrip("/")
    if not path.endswith("/chat/completions"):
        path = f"{path}/chat/completions"

    params = [existing_query] if existing_query else []
    if api_version:
        params.append(f"api-version={api_version}")
    return f"{path}?{'&'.join(params)}" if params else path


async def _call_llm(
    question: str,
    events: list[dict],
    total: int | None = None,
    excluded_services: list[str] | None = None,
) -> str:
    """Ask the configured chat-completions endpoint to summarise *events*.

    Returns the stripped answer text.

    Raises:
        RuntimeError: if no API key is configured, the API returns an HTTP
            error, the response is not the expected shape, or the
            completion is empty.
        httpx.HTTPError: on transport failures (timeouts, connection errors).
    """
    if not LLM_API_KEY:
        raise RuntimeError("LLM_API_KEY not configured")

    context = _format_events_for_llm(events, total=total, excluded_services=excluded_services)
    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Question: {question}\n\nAudit events:\n{context}\n\nPlease answer the question based only on the events above.",
        },
    ]

    url = _build_chat_url(LLM_BASE_URL, LLM_API_VERSION)
    headers = {
        "Content-Type": "application/json",
    }
    # Azure OpenAI uses api-key header; standard OpenAI uses Bearer token
    base_lower = LLM_BASE_URL.lower()
    if "azure" in base_lower or "cognitiveservices" in base_lower:
        headers["api-key"] = LLM_API_KEY
    else:
        headers["Authorization"] = f"Bearer {LLM_API_KEY}"

    payload = {
        "model": LLM_MODEL,
        "messages": messages,
        "max_completion_tokens": 800,
    }

    async with httpx.AsyncClient(timeout=LLM_TIMEOUT_SECONDS) as client:
        resp = await client.post(url, headers=headers, json=payload)

    if resp.status_code >= 400:
        body = resp.text
        logger.error("LLM API error", status_code=resp.status_code, url=url, response_body=body)
        raise RuntimeError(f"LLM API error {resp.status_code}: {body[:500]}")

    try:
        content = resp.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        logger.error("LLM API returned an unexpected payload", url=url, response_body=resp.text[:2000])
        raise RuntimeError("LLM API returned an unexpected response format") from exc

    # content can be null (e.g. filtered output) or blank (e.g. token budget
    # exhausted); treat both as failures so the caller reports llm_used=False.
    if not isinstance(content, str) or not content.strip():
        raise RuntimeError("LLM API returned an empty answer")
    return content.strip()


# ---------------------------------------------------------------------------
# API endpoint
# ---------------------------------------------------------------------------

# How many of the newest matching events are pulled for in-Python sampling.
_FETCH_WINDOW = 1000


def _to_event_ref(e: dict) -> dict:
    """Project a stored event onto the fields returned to the client."""
    return {
        "id": e.get("id"),
        "timestamp": e.get("timestamp"),
        "operation": e.get("operation"),
        "actor_display": e.get("actor_display"),
        "target_displays": e.get("target_displays"),
        "display_summary": e.get("display_summary"),
        "service": e.get("service"),
        "result": e.get("result"),
    }


def _resolve_services(
    requested: list[str] | None, intent_services: list[str] | None
) -> tuple[list[str], list[str]]:
    """Decide which services to query; returns (query_services, excluded_services).

    Priority: an explicit UI filter (*requested*), then services implied by the
    question, then the default admin set with the noisy services excluded.
    """
    if requested:
        # User explicitly filtered via UI — respect that exactly
        return requested, []
    if intent_services is not None:
        # NL question implies specific services
        return intent_services, []
    # Broad question with no intent — exclude noisy services by default
    return sorted(_DEFAULT_ADMIN_SERVICES), sorted(_NOISY_SERVICES)


def _fetch_events(query: dict, limit: int = _FETCH_WINDOW) -> tuple[int, list[dict]]:
    """Count all matches and fetch the newest *limit* of them (blocking I/O)."""
    total = events_collection.count_documents(query)
    cursor = events_collection.find(query).sort([("timestamp", -1)]).limit(limit)
    return total, list(cursor)


def _fallback_summary(
    events: list[dict],
    total: int,
    entity: str | None,
    excluded_services: list[str],
    start: str,
    end: str,
) -> str:
    """Build a plain markdown summary used when no LLM answer is available."""
    header = [f"Found {total} event(s)"]
    if entity:
        header.append(f"related to **{entity}**")
    if excluded_services:
        header.append(f"(excluding {', '.join(excluded_services)})")
    header.append(f"between {start[:10]} and {end[:10]}.\n")

    # The header is one sentence; only the event list goes on separate lines.
    parts = [" ".join(header)]
    shown = events[:10]
    for i, e in enumerate(shown, 1):
        ts = str(e.get("timestamp") or "?")[:16].replace("T", " ")
        op = e.get("operation") or "unknown action"
        actor = e.get("actor_display") or "unknown"
        targets = ", ".join(e.get("target_displays") or []) or "—"
        result = e.get("result") or "—"
        parts.append(f"{i}. **{ts}** — {op} by {actor} on {targets} ({result})")

    # Count against the matching total (announced in the header), not just the sample.
    remaining = max(total, len(events)) - len(shown)
    if remaining > 0:
        parts.append(f"\n...and {remaining} more events.")
    return "\n".join(parts)


@router.post("/ask", response_model=AskResponse)
async def ask_question(body: AskRequest, user: dict = Depends(require_auth)):
    """Answer a natural-language question about audit events.

    Parses a time range, entity and service intent from the question and
    queries MongoDB. It then down-samples the results with
    ``_smart_sample`` and asks the LLM for a narrative. If no API key is
    configured or the call fails, a structured summary is returned instead
    and ``llm_error`` explains why.

    Raises:
        HTTPException: 400 for a blank question, 500 if the database query fails.
    """
    import asyncio  # local import, as _aggregate_counts does for Counter

    question = body.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="Question is required")

    start, end = _extract_time_range(question)
    entity = _extract_entity(question)
    intent_services, _explicit_noisy = _extract_intent_services(question)

    # Default to last 7 days if no time range detected
    if not start:
        now = datetime.now(UTC)
        start, end = _to_iso_z(now - timedelta(days=7)), _to_iso_z(now)

    query_services, excluded_services = _resolve_services(body.services, intent_services)

    query = _build_event_query(
        entity,
        start,
        end,
        services=query_services,
        actor=body.actor,
        operation=body.operation,
        result=body.result,
        include_tags=body.include_tags,
        exclude_tags=body.exclude_tags,
    )

    try:
        # The collection API is synchronous (results are used without await);
        # run it in a worker thread so a slow query doesn't stall the event loop.
        total, raw_events = await asyncio.to_thread(_fetch_events, query)
    except Exception as exc:
        logger.error("Failed to query events for ask", error=str(exc))
        # Details stay in the server log; don't echo driver internals to clients.
        raise HTTPException(status_code=500, detail="Database query failed") from exc

    for e in raw_events:
        e["_id"] = str(e.get("_id", ""))

    # Apply smart sampling (preserves failures, prioritises admin-relevant services)
    events = _smart_sample(raw_events, max_events=LLM_MAX_EVENTS)

    query_info = {
        "entity": entity,
        "start": start,
        "end": end,
        "event_count": len(events),
        "total_matched": total,
        "services_queried": query_services,
        "excluded_services": excluded_services,
        "mongo_query": json.dumps(query, default=str),
    }

    if not events:
        return AskResponse(
            answer="I couldn't find any audit events matching your question. Try broadening the time range or checking the spelling of the device/user name.",
            events=[],
            query_info=query_info,
            llm_used=False,
            llm_error="LLM not used — no events found." if not LLM_API_KEY else None,
        )

    answer = ""
    llm_used = False
    llm_error = None
    if not LLM_API_KEY:
        llm_error = "LLM_API_KEY is not configured. Set it in your .env to enable AI narrative summarisation."
    else:
        try:
            answer = await _call_llm(question, events, total=total, excluded_services=excluded_services)
            llm_used = True
        except Exception as exc:
            llm_error = f"LLM call failed: {exc}"
            logger.warning("LLM call failed, falling back to structured summary", error=str(exc))

    # Fallback: structured summary if LLM unavailable or failed
    if not answer:
        answer = _fallback_summary(events, total, entity, excluded_services, start, end)

    return AskResponse(
        answer=answer,
        events=[_to_event_ref(e) for e in events],
        query_info=query_info,
        llm_used=llm_used,
        llm_error=llm_error,
    )