From b4e504a87b7475fb0e6476ed76fc46de442b4a49 Mon Sep 17 00:00:00 2001 From: Tomas Kracmar Date: Mon, 20 Apr 2026 17:41:21 +0200 Subject: [PATCH] feat: intent-aware querying + smart sampling for large audit datasets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add keyword-based intent extraction: 'device' → Intune, 'user' → Directory, etc. - Broad questions without intent auto-exclude noisy services (Exchange, SharePoint) - Smart stratified sampling: failures always included, high-value services prioritised - Fetch up to 1000 events from MongoDB, then curate best 200 for the LLM - Excluded services noted in LLM prompt and query_info so the admin knows the scope --- VERSION | 2 +- backend/routes/ask.py | 189 +++++++++++++++++++++++++++++++++++--- backend/tests/test_ask.py | 2 +- 3 files changed, 180 insertions(+), 13 deletions(-) diff --git a/VERSION b/VERSION index 7e099ec..a77d7d9 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.2.6 \ No newline at end of file +1.2.7 \ No newline at end of file diff --git a/backend/routes/ask.py b/backend/routes/ask.py index 0a7f4dd..68b92fd 100644 --- a/backend/routes/ask.py +++ b/backend/routes/ask.py @@ -13,6 +13,129 @@ from models.api import AskRequest, AskResponse router = APIRouter(dependencies=[Depends(require_auth)]) logger = structlog.get_logger("aoc.ask") +# --------------------------------------------------------------------------- +# Intent extraction — map question keywords to relevant audit services +# --------------------------------------------------------------------------- + +_SERVICE_INTENTS = { + "intune": ["Intune"], + "device": ["Intune", "Device"], + "laptop": ["Intune", "Device"], + "mobile": ["Intune", "Device"], + "phone": ["Intune", "Device"], + "ipad": ["Intune", "Device"], + "app": ["Intune", "ApplicationManagement"], + "application": ["Intune", "ApplicationManagement"], + "policy": ["Intune", "Policy"], + "compliance": ["Intune", "Policy"], + "user": ["Directory", "UserManagement"], + "group": ["Directory", "GroupManagement"], + "role": ["Directory", "RoleManagement"], + "permission": ["Directory", "RoleManagement"], + "license": ["Directory", "License"], + "email": ["Exchange"], + "mailbox": ["Exchange"], + "mail": ["Exchange"], + "message": ["Exchange", "Teams"], + "file": ["SharePoint"], + "sharepoint": ["SharePoint"], + "site": ["SharePoint"], + "document": ["SharePoint"], + "team": ["Teams"], + "channel": ["Teams"], + "meeting": ["Teams"], + "call": ["Teams"], +} + +# Services that are extremely noisy for typical admin questions. +# We exclude them by default on broad questions unless the user explicitly mentions them. +_NOISY_SERVICES = {"Exchange", "SharePoint"} + +# Services that are generally admin-relevant and kept by default. +_DEFAULT_ADMIN_SERVICES = { + "Directory", + "UserManagement", + "GroupManagement", + "RoleManagement", + "ApplicationManagement", + "Intune", + "Device", + "Policy", + "Teams", + "License", +} + + +def _extract_intent_services(question: str) -> tuple[list[str] | None, bool]: + """ + Extract relevant services from the question. + + Returns: + (services, is_explicit): + - services: list of service names to query, or None for default admin set + - is_explicit: True if the user explicitly mentioned a noisy service + """ + q_lower = question.lower() + tokens = set(re.findall(r"\b[a-z]+\b", q_lower)) + + matched_services = set() + for token, services in _SERVICE_INTENTS.items(): + if token in tokens: + matched_services.update(services) + + if matched_services: + # User asked something specific — return exactly what they asked for + is_explicit = not matched_services.isdisjoint(_NOISY_SERVICES) + return sorted(matched_services), is_explicit + + # Broad question with no clear intent — default to admin-relevant services only + return None, False + + +# --------------------------------------------------------------------------- +# Smart sampling — stratified by importance so the LLM sees signal, not noise +# --------------------------------------------------------------------------- + + +def _smart_sample(events: list[dict], max_events: int = 200) -> list[dict]: + """ + Return a curated subset that preserves diversity and prioritises signal. + + Tiers: + 1. Failures (always valuable) + 2. High-admin-value services (Intune, Device, Directory, etc.) + 3. Everything else + """ + if len(events) <= max_events: + return events + + high_value = { + "Directory", + "UserManagement", + "GroupManagement", + "RoleManagement", + "Intune", + "Device", + "Policy", + "ApplicationManagement", + } + + failures = [e for e in events if str(e.get("result") or "").lower() in ("failure", "failed")] + high_val = [e for e in events if e.get("service") in high_value and e not in failures] + rest = [e for e in events if e not in failures and e not in high_val] + + # Allocate slots: half to failures+high-value, half to rest (but never let rest dominate) + slots = max_events + failure_cap = min(len(failures), max(10, slots // 4)) + high_cap = min(len(high_val), max(20, slots // 4)) + rest_cap = slots - failure_cap - high_cap + + sampled = failures[:failure_cap] + high_val[:high_cap] + rest[:rest_cap] + # Sort back to chronological order + sampled.sort(key=lambda e: e.get("timestamp") or "", reverse=True) + return sampled + + # --------------------------------------------------------------------------- # Time-range extraction # --------------------------------------------------------------------------- @@ -203,12 +326,16 @@ def _aggregate_counts(events: list[dict]) -> dict: } -def _format_events_for_llm(events: list[dict], total: int | None = None) -> str: +def _format_events_for_llm( + events: list[dict], total: int | None = None, excluded_services: list[str] | None = None +) -> str: lines = [] # If we have a large result set, send aggregation + samples instead of raw dump if total is not None and total > len(events) and len(events) >= 50: - lines.append(f"Result set overview: {total} total events (showing the {len(events)} most recent).\n") + lines.append(f"Result set overview: {total} total events (showing a curated sample of {len(events)}).\n") + if excluded_services: + lines.append(f"Note: high-volume services excluded by default: {', '.join(excluded_services)}.\n") agg = _aggregate_counts(events) lines.append("Breakdown by service:") for svc, cnt in agg["services"]: @@ -267,11 +394,16 @@ def _build_chat_url(base_url: str, api_version: str) -> str: return url -async def _call_llm(question: str, events: list[dict], total: int | None = None) -> str: +async def _call_llm( + question: str, + events: list[dict], + total: int | None = None, + excluded_services: list[str] | None = None, +) -> str: if not LLM_API_KEY: raise RuntimeError("LLM_API_KEY not configured") - context = _format_events_for_llm(events, total=total) + context = _format_events_for_llm(events, total=total, excluded_services=excluded_services) messages = [ {"role": "system", "content": _SYSTEM_PROMPT}, { @@ -332,6 +464,7 @@ async def ask_question(body: AskRequest, user: dict = Depends(require_auth)): start, end = _extract_time_range(question) entity = _extract_entity(question) + intent_services, explicit_noisy = _extract_intent_services(question) # Default to last 7 days if no time range detected if not start: @@ -339,11 +472,29 @@ async def ask_question(body: AskRequest, user: dict = Depends(require_auth)): start = (now - timedelta(days=7)).isoformat().replace("+00:00", "Z") end = now.isoformat().replace("+00:00", "Z") + # ----------------------------------------------------------------------- + # Decide which services to query + # ----------------------------------------------------------------------- + excluded_services: list[str] = [] + if body.services: + # User explicitly filtered via UI — respect that exactly + query_services = body.services + elif intent_services is not None: + # NL question implies specific services + query_services = intent_services + else: + # Broad question with no intent — exclude noisy services by default + query_services = sorted(_DEFAULT_ADMIN_SERVICES) + excluded_services = sorted(_NOISY_SERVICES) + + # ----------------------------------------------------------------------- + # Build and run query + # ----------------------------------------------------------------------- query = _build_event_query( entity, start, end, - services=body.services, + services=query_services, actor=body.actor, operation=body.operation, result=body.result, @@ -353,21 +504,33 @@ async def ask_question(body: AskRequest, user: dict = Depends(require_auth)): try: total = events_collection.count_documents(query) - cursor = events_collection.find(query).sort([("timestamp", -1)]).limit(LLM_MAX_EVENTS) - events = list(cursor) + # Fetch a generous window so we can apply smart sampling in Python + cursor = events_collection.find(query).sort([("timestamp", -1)]).limit(1000) + raw_events = list(cursor) except Exception as exc: logger.error("Failed to query events for ask", error=str(exc)) raise HTTPException(status_code=500, detail=f"Database query failed: {exc}") from exc - for e in events: + for e in raw_events: e["_id"] = str(e.get("_id", "")) + # Apply smart sampling (preserves failures, prioritises admin-relevant services) + events = _smart_sample(raw_events, max_events=LLM_MAX_EVENTS) + # If no events, return early if not events: return AskResponse( answer="I couldn't find any audit events matching your question. Try broadening the time range or checking the spelling of the device/user name.", events=[], - query_info={"entity": entity, "start": start, "end": end, "event_count": 0}, + query_info={ + "entity": entity, + "start": start, + "end": end, + "event_count": 0, + "total_matched": total, + "services_queried": query_services, + "excluded_services": excluded_services, + }, llm_used=False, llm_error="LLM not used — no events found." if not LLM_API_KEY else None, ) @@ -380,7 +543,7 @@ async def ask_question(body: AskRequest, user: dict = Depends(require_auth)): llm_error = "LLM_API_KEY is not configured. Set it in your .env to enable AI narrative summarisation." else: try: - answer = await _call_llm(question, events, total=total) + answer = await _call_llm(question, events, total=total, excluded_services=excluded_services) llm_used = True except Exception as exc: llm_error = f"LLM call failed: {exc}" @@ -388,9 +551,11 @@ async def ask_question(body: AskRequest, user: dict = Depends(require_auth)): # Fallback: structured summary if LLM unavailable or failed if not answer: - parts = [f"Found {len(events)} event(s)"] + parts = [f"Found {total} event(s)"] if entity: parts.append(f"related to **{entity}**") + if excluded_services: + parts.append(f"(excluding {', '.join(excluded_services)})") parts.append(f"between {start[:10]} and {end[:10]}.\n") for i, e in enumerate(events[:10], 1): @@ -415,6 +580,8 @@ async def ask_question(body: AskRequest, user: dict = Depends(require_auth)): "end": end, "event_count": len(events), "total_matched": total, + "services_queried": query_services, + "excluded_services": excluded_services, "mongo_query": json.dumps(query, default=str), }, llm_used=llm_used, diff --git a/backend/tests/test_ask.py b/backend/tests/test_ask.py index 84b8561..ddcaff7 100644 --- a/backend/tests/test_ask.py +++ b/backend/tests/test_ask.py @@ -236,7 +236,7 @@ class TestAskEndpoint: } ) - async def fake_llm(question, events, total=None): + async def fake_llm(question, events, total=None, excluded_services=None): return "The device had a failed wipe attempt." monkeypatch.setattr("routes.ask.LLM_API_KEY", "fake-key")