diff --git a/.env.example b/.env.example
index 43392b4..0de03e2 100644
--- a/.env.example
+++ b/.env.example
@@ -50,6 +50,11 @@ LLM_MAX_EVENTS=200
 LLM_TIMEOUT_SECONDS=30
 LLM_API_VERSION=
 
+# Valkey (caching + async job queue for LLM calls)
+# In Docker Compose, this is set automatically to redis://redis:6379/0
+# For local dev, start Valkey with: docker run -d -p 6379:6379 valkey/valkey:8-alpine
+REDIS_URL=redis://localhost:6379/0
+
 # Optional: privacy / access control
 # Hide entire services from users without PRIVACY_SERVICE_ROLES
 # PRIVACY_SERVICES=Exchange,Teams
diff --git a/ROADMAP.md b/ROADMAP.md
index 872a51f..1bcec00 100644
--- a/ROADMAP.md
+++ b/ROADMAP.md
@@ -65,9 +65,10 @@ Goal: add AI-powered analysis and external tool integration.
 - [x] AI feature flag (`AI_FEATURES_ENABLED`) to gate LLM-dependent features
 - [x] Natural language query endpoint (`/api/ask`) with intent extraction and smart sampling
 - [x] MCP (Model Context Protocol) server for Claude Desktop / Cursor integration
+- [x] Valkey caching for LLM responses and frequent queries
+- [x] Async queue (arq) for LLM requests to prevent timeout/cost explosions at scale
 - [ ] Advanced analytics dashboard (trending operations, anomaly detection)
-- [ ] Redis caching for LLM responses and frequent queries
-- [ ] Async queue for LLM requests to prevent timeout/cost explosions at scale
 
 ## Completed in this PR
-All Phase 5 items marked done were implemented in v1.3.0.
+All Phase 5 items marked done were implemented in v1.3.0–v1.5.0.
+The Redis caching and async queue items were implemented in v1.6.0, using Valkey as the Redis-compatible backend.
diff --git a/VERSION b/VERSION
index 3e1ad72..dc1e644 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-1.5.0
\ No newline at end of file
+1.6.0
diff --git a/backend/config.py b/backend/config.py
index 9343b00..c349b00 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -57,6 +57,9 @@ class Settings(BaseSettings):
     PRIVACY_SENSITIVE_OPERATIONS: str = ""  # comma-separated, e.g. "MailItemsAccessed,Search-Mailbox,Send"
     PRIVACY_SERVICE_ROLES: str = ""  # comma-separated, e.g. "SecurityAdministrator,ComplianceAdministrator"
 
+    # Redis (caching + async job queue)
+    REDIS_URL: str = "redis://localhost:6379/0"
+
 
 _settings = Settings()
 
@@ -95,3 +98,5 @@ LLM_API_VERSION = _settings.LLM_API_VERSION
 PRIVACY_SERVICES = {s.strip() for s in _settings.PRIVACY_SERVICES.split(",") if s.strip()}
 PRIVACY_SENSITIVE_OPERATIONS = {o.strip() for o in _settings.PRIVACY_SENSITIVE_OPERATIONS.split(",") if o.strip()}
 PRIVACY_SERVICE_ROLES = {r.strip() for r in _settings.PRIVACY_SERVICE_ROLES.split(",") if r.strip()}
+
+REDIS_URL = _settings.REDIS_URL
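REDIS_URL is read once in config.py and re-exported as a module-level constant, mirroring how the other settings are exposed. For local development it can be worth confirming the URL actually reaches a running Valkey before starting the app; a minimal sketch, assuming the `redis` package from requirements.txt (the `check_valkey` helper is illustrative, not part of this diff):

```python
# check_valkey.py -- hypothetical helper, not part of this PR
import asyncio

import redis.asyncio as aioredis

from config import REDIS_URL


async def check_valkey() -> bool:
    """Return True if the configured Valkey/Redis instance answers PING."""
    client = aioredis.from_url(REDIS_URL, decode_responses=True)
    try:
        return await client.ping()
    finally:
        await client.close()


if __name__ == "__main__":
    print("Valkey reachable:", asyncio.run(check_valkey()))
```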
"SecurityAdministrator,ComplianceAdministrator" + # Redis (caching + async job queue) + REDIS_URL: str = "redis://localhost:6379/0" + _settings = Settings() @@ -95,3 +98,5 @@ LLM_API_VERSION = _settings.LLM_API_VERSION PRIVACY_SERVICES = {s.strip() for s in _settings.PRIVACY_SERVICES.split(",") if s.strip()} PRIVACY_SENSITIVE_OPERATIONS = {o.strip() for o in _settings.PRIVACY_SENSITIVE_OPERATIONS.split(",") if o.strip()} PRIVACY_SERVICE_ROLES = {r.strip() for r in _settings.PRIVACY_SERVICE_ROLES.split(",") if r.strip()} + +REDIS_URL = _settings.REDIS_URL diff --git a/backend/jobs.py b/backend/jobs.py new file mode 100644 index 0000000..19c2ee6 --- /dev/null +++ b/backend/jobs.py @@ -0,0 +1,113 @@ +"""arq job functions for async LLM processing.""" + +import hashlib +import json + +import structlog +from arq.connections import RedisSettings +from config import REDIS_URL + +logger = structlog.get_logger("aoc.jobs") + +# --------------------------------------------------------------------------- +# Cache helpers +# --------------------------------------------------------------------------- + +CACHE_TTL_ASK = 3600 # 1 hour +CACHE_TTL_EXPLAIN = 86400 # 24 hours + + +def _ask_cache_key(question: str, filters: dict, events: list) -> str: + payload = json.dumps({"q": question, "f": filters, "e": [e.get("id") for e in events]}, sort_keys=True) + return f"aoc:cache:ask:{hashlib.md5(payload.encode()).hexdigest()}" + + +def _explain_cache_key(event_id: str) -> str: + return f"aoc:cache:explain:{event_id}" + + +async def get_cached_ask(redis, question: str, filters: dict, events: list) -> dict | None: + key = _ask_cache_key(question, filters, events) + raw = await redis.get(key) + if raw: + return json.loads(raw) + return None + + +async def set_cached_ask(redis, question: str, filters: dict, events: list, result: dict): + key = _ask_cache_key(question, filters, events) + await redis.setex(key, CACHE_TTL_ASK, json.dumps(result, default=str)) + + +async def get_cached_explain(redis, event_id: str) -> dict | None: + key = _explain_cache_key(event_id) + raw = await redis.get(key) + if raw: + return json.loads(raw) + return None + + +async def set_cached_explain(redis, event_id: str, result: dict): + key = _explain_cache_key(event_id) + await redis.setex(key, CACHE_TTL_EXPLAIN, json.dumps(result, default=str)) + + +# --------------------------------------------------------------------------- +# arq job functions +# --------------------------------------------------------------------------- + +async def process_ask_question(ctx, question: str, filters: dict, events: list, total: int, excluded_services: list | None): + """Background job: call LLM for /api/ask and cache result.""" + from routes.ask import _call_llm + + redis = ctx["redis"] + try: + answer = await _call_llm(question, events, total=total, excluded_services=excluded_services) + result = {"status": "completed", "answer": answer, "llm_used": True, "llm_error": None} + except Exception as exc: + logger.warning("Async ask LLM failed", error=str(exc)) + result = {"status": "failed", "answer": "", "llm_used": False, "llm_error": str(exc)} + + await set_cached_ask(redis, question, filters, events, result) + return result + + +async def process_explain_event(ctx, event_id: str, event: dict, related: list): + """Background job: call LLM for /api/events/{id}/explain and cache result.""" + from routes.ask import _explain_event + + redis = ctx["redis"] + try: + explanation = await _explain_event(event, related) + result = {"status": "completed", 
"explanation": explanation, "llm_used": True, "llm_error": None} + except Exception as exc: + logger.warning("Async explain LLM failed", error=str(exc)) + result = {"status": "failed", "explanation": "", "llm_used": False, "llm_error": str(exc)} + + await set_cached_explain(redis, event_id, result) + return result + + +# --------------------------------------------------------------------------- +# arq worker configuration +# --------------------------------------------------------------------------- + +async def startup(ctx): + from redis.asyncio import Redis + + ctx["redis"] = Redis.from_url(REDIS_URL, decode_responses=True) + + +async def shutdown(ctx): + await ctx["redis"].close() + + +class WorkerSettings: + functions = [process_ask_question, process_explain_event] + redis_settings = RedisSettings.from_dsn(REDIS_URL) + on_startup = startup + on_shutdown = shutdown + max_jobs = 10 + job_timeout = 120 + keep_result = 3600 + keep_result_forever = False diff --git a/backend/main.py b/backend/main.py index 53851e5..8236e1e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -19,6 +19,7 @@ from routes.events import router as events_router from routes.fetch import router as fetch_router from routes.fetch import run_fetch from routes.health import router as health_router +from routes.jobs import router as jobs_router from routes.rules import router as rules_router from routes.saved_searches import router as saved_searches_router from routes.webhooks import router as webhooks_router @@ -122,6 +123,7 @@ if AI_FEATURES_ENABLED: app.mount("/mcp", mcp_asgi) app.include_router(saved_searches_router, prefix="/api") app.include_router(rules_router, prefix="/api") +app.include_router(jobs_router, prefix="/api") @app.get("/health") @@ -176,3 +178,6 @@ async def stop_periodic_fetch(): task.cancel() with suppress(Exception): await task + from redis_client import close_redis_connections + + await close_redis_connections() diff --git a/backend/models/api.py b/backend/models/api.py index f495b20..a52b704 100644 --- a/backend/models/api.py +++ b/backend/models/api.py @@ -82,6 +82,7 @@ class AskRequest(BaseModel): end: str | None = None include_tags: list[str] | None = None exclude_tags: list[str] | None = None + async_mode: bool = False # enqueue async job instead of waiting class AskEventRef(BaseModel): @@ -101,3 +102,4 @@ class AskResponse(BaseModel): query_info: dict llm_used: bool llm_error: str | None = None + job_id: str | None = None diff --git a/backend/redis_client.py b/backend/redis_client.py new file mode 100644 index 0000000..2419f7d --- /dev/null +++ b/backend/redis_client.py @@ -0,0 +1,36 @@ +"""Async Redis client singleton for caching and job queue.""" + +import redis.asyncio as aioredis +from arq import create_pool +from arq.connections import ArqRedis, RedisSettings +from config import REDIS_URL + +_arq_pool: ArqRedis | None = None +_plain_redis: aioredis.Redis | None = None + + +async def get_arq_pool() -> ArqRedis: + """Return a shared arq pool (ArqRedis extends redis.asyncio.Redis).""" + global _arq_pool + if _arq_pool is None: + _arq_pool = await create_pool(RedisSettings.from_dsn(REDIS_URL)) + return _arq_pool + + +async def get_redis() -> aioredis.Redis: + """Return a shared plain async Redis client.""" + global _plain_redis + if _plain_redis is None: + _plain_redis = aioredis.from_url(REDIS_URL, decode_responses=True) + return _plain_redis + + +async def close_redis_connections(): + """Close all Redis connections (call on shutdown).""" + global _arq_pool, _plain_redis + if 
diff --git a/backend/requirements.txt b/backend/requirements.txt
index d61cd4f..0adcdf5 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -14,3 +14,5 @@ prometheus-client
 httpx
 gunicorn
 mcp
+redis
+arq
diff --git a/backend/routes/ask.py b/backend/routes/ask.py
index d52fd8c..591ef20 100644
--- a/backend/routes/ask.py
+++ b/backend/routes/ask.py
@@ -18,7 +18,9 @@ from config import (
 )
 from database import events_collection
 from fastapi import APIRouter, Depends, HTTPException
+from jobs import get_cached_ask, get_cached_explain, set_cached_ask, set_cached_explain
 from models.api import AskRequest, AskResponse
+from redis_client import get_arq_pool
 
 router = APIRouter(dependencies=[Depends(require_auth)])
 logger = structlog.get_logger("aoc.ask")
@@ -640,14 +642,23 @@ async def explain_event(event_id: str, user: dict = Depends(require_auth)):
             "llm_error": "LLM_API_KEY not configured",
         }
 
+    # Check cache first
+    redis = await get_arq_pool()
+    cached = await get_cached_explain(redis, event_id)
+    if cached:
+        cached["related_count"] = len(related)
+        return cached
+
     try:
         explanation = await _explain_event(event, related)
-        return {
+        result = {
             "explanation": explanation,
             "llm_used": True,
             "llm_error": None,
             "related_count": len(related),
         }
+        await set_cached_explain(redis, event_id, result)
+        return result
     except Exception as exc:
         logger.warning("Event explanation failed", error=str(exc))
         return {
@@ -746,19 +757,70 @@ async def ask_question(body: AskRequest, user: dict = Depends(require_auth)):
             llm_error="LLM not used — no events found." if not LLM_API_KEY else None,
         )
 
-    # Try LLM summarisation
+    # Try LLM summarisation (with caching + optional async)
     answer = ""
     llm_used = False
     llm_error = None
-    if not LLM_API_KEY:
-        llm_error = "LLM_API_KEY is not configured. Set it in your .env to enable AI narrative summarisation."
+    job_id = None
+
+    filters_snapshot = {
+        "services": body.services,
+        "actor": body.actor,
+        "operation": body.operation,
+        "result": body.result,
+        "start": body.start,
+        "end": body.end,
+        "include_tags": body.include_tags,
+        "exclude_tags": body.exclude_tags,
+    }
+
+    if LLM_API_KEY:
+        redis = await get_arq_pool()
+        cached = await get_cached_ask(redis, question, filters_snapshot, events)
+        if cached:
+            answer = cached.get("answer", "")
+            llm_used = cached.get("llm_used", False)
+            llm_error = cached.get("llm_error")
+        elif body.async_mode:
+            pool = await get_arq_pool()
+            job = await pool.enqueue_job(
+                "process_ask_question",
+                question,
+                filters_snapshot,
+                events,
+                total,
+                excluded_services,
+            )
+            job_id = job.job_id if job else None
+            return AskResponse(
+                answer="Your question is being processed. Poll /api/jobs/{job_id} for the result.",
+                events=[_to_event_ref(e) for e in events],
+                query_info={
+                    "entity": entity,
+                    "start": start,
+                    "end": end,
+                    "event_count": len(events),
+                    "total_matched": total,
+                    "services_queried": query_services,
+                    "excluded_services": excluded_services,
+                    "mongo_query": json.dumps(query, default=str),
+                },
+                llm_used=False,
+                llm_error=None,
+                job_id=job_id,
+            )
+        else:
+            try:
+                answer = await _call_llm(question, events, total=total, excluded_services=excluded_services)
+                llm_used = True
+                await set_cached_ask(redis, question, filters_snapshot, events, {
+                    "answer": answer, "llm_used": True, "llm_error": None,
+                })
+            except Exception as exc:
+                llm_error = f"LLM call failed: {exc}"
+                logger.warning("LLM call failed, falling back to structured summary", error=str(exc))
     else:
-        try:
-            answer = await _call_llm(question, events, total=total, excluded_services=excluded_services)
-            llm_used = True
-        except Exception as exc:
-            llm_error = f"LLM call failed: {exc}"
-            logger.warning("LLM call failed, falling back to structured summary", error=str(exc))
+        llm_error = "LLM_API_KEY is not configured. Set it in your .env to enable AI narrative summarisation."
 
     # Fallback: structured summary if LLM unavailable or failed
     if not answer:
@@ -797,4 +859,5 @@ async def ask_question(body: AskRequest, user: dict = Depends(require_auth)):
         },
         llm_used=llm_used,
         llm_error=llm_error,
+        job_id=job_id,
     )
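From a client's point of view the new flow is: send `async_mode: true`, read `job_id` from the response, then poll `/api/jobs/{job_id}` until the worker finishes. A cache hit skips the queue entirely, in which case the answer comes back immediately and `job_id` stays null. A rough client-side sketch using httpx; the base URL, token, and question are placeholders:

```python
import time

import httpx

BASE_URL = "http://localhost:8000"             # placeholder
HEADERS = {"Authorization": "Bearer <token>"}  # placeholder for whatever require_auth expects

with httpx.Client(base_url=BASE_URL, headers=HEADERS) as client:
    resp = client.post("/api/ask", json={"question": "What happened to USER-001?", "async_mode": True})
    job_id = resp.json()["job_id"]

    if job_id:  # None means the answer came from cache (or the LLM is not configured)
        while True:
            status = client.get(f"/api/jobs/{job_id}").json()
            if status["status"] == "complete":
                print(status["result"]["answer"])
                break
            time.sleep(2)  # the worker's job_timeout is 120s; results are kept for an hour
```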
diff --git a/backend/routes/jobs.py b/backend/routes/jobs.py
new file mode 100644
index 0000000..841c7e1
--- /dev/null
+++ b/backend/routes/jobs.py
@@ -0,0 +1,43 @@
+"""Job status endpoints for async LLM operations."""
+
+from arq.jobs import Job, JobStatus
+from auth import require_auth
+from fastapi import APIRouter, Depends, HTTPException
+from pydantic import BaseModel
+from redis_client import get_redis
+
+router = APIRouter(dependencies=[Depends(require_auth)])
+
+
+class JobStatusResponse(BaseModel):
+    job_id: str
+    status: str  # queued, in_progress, complete, not_found, deferred
+    result: dict | None = None
+    error: str | None = None
+
+
+@router.get("/jobs/{job_id}", response_model=JobStatusResponse)
+async def get_job_status(job_id: str, user: dict = Depends(require_auth)):
+    """Poll for the result of an async LLM job."""
+    redis = await get_redis()
+    job = Job(job_id, redis)
+    status = await job.status()
+
+    if status == JobStatus.not_found:
+        raise HTTPException(status_code=404, detail="Job not found")
+
+    result = None
+    error = None
+    if status == JobStatus.complete:
+        try:
+            result_data = await job.result(timeout=0)
+            result = result_data if isinstance(result_data, dict) else {"data": str(result_data)}
+        except Exception as exc:
+            error = str(exc)
+
+    return JobStatusResponse(
+        job_id=job_id,
+        status=status.value,
+        result=result,
+        error=error,
+    )
diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py
index 482a62b..8711fe8 100644
--- a/backend/tests/conftest.py
+++ b/backend/tests/conftest.py
@@ -49,6 +49,20 @@ def client(mock_events_collection, mock_watermarks_collection, monkeypatch):
     monkeypatch.setattr("rules.rules_collection", audit_db["alert_rules"])
     monkeypatch.setattr("routes.rules.rules_collection", audit_db["alert_rules"])
 
+    # Mock Redis so tests don't require a running Redis server
+    class FakeRedis:
+        async def get(self, key):
+            return None
+        async def setex(self, key, ttl, value):
+            pass
+
+    async def fake_get_arq_pool():
+        return FakeRedis()
+
+    monkeypatch.setattr("redis_client.get_arq_pool", fake_get_arq_pool)
+    monkeypatch.setattr("routes.ask.get_arq_pool", fake_get_arq_pool)
+    monkeypatch.setattr("routes.jobs.get_redis", fake_get_arq_pool)
+
     from main import app
 
     return TestClient(app)
monkeypatch.setattr("routes.ask.get_arq_pool", fake_get_arq_pool) + monkeypatch.setattr("routes.jobs.get_redis", fake_get_arq_pool) + from main import app return TestClient(app) diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py index 1d0f89f..02d4e44 100644 --- a/backend/tests/test_api.py +++ b/backend/tests/test_api.py @@ -89,6 +89,17 @@ def test_explain_event_with_llm_mock(client, mock_events_collection, monkeypatch monkeypatch.setattr("routes.ask._explain_event", fake_explain) + class FakeRedis: + async def get(self, key): + return None + async def setex(self, key, ttl, value): + pass + + async def fake_get_arq_pool(): + return FakeRedis() + + monkeypatch.setattr("routes.ask.get_arq_pool", fake_get_arq_pool) + mock_events_collection.insert_one( { "id": "evt-explain2", diff --git a/backend/tests/test_ask.py b/backend/tests/test_ask.py index ddcaff7..8c46814 100644 --- a/backend/tests/test_ask.py +++ b/backend/tests/test_ask.py @@ -350,3 +350,124 @@ class TestAskEndpoint: data = response.json() assert data["query_info"]["event_count"] == 1 assert data["events"][0]["id"] == "evt-bob" + + +class TestAskCaching: + def test_ask_cache_hit_returns_cached_answer(self, client, mock_events_collection, monkeypatch): + """If the answer is cached, the LLM should not be called.""" + now = datetime.now(UTC) + mock_events_collection.insert_one( + { + "id": "evt-cache", + "timestamp": now.isoformat(), + "service": "Directory", + "operation": "Add user", + "result": "success", + "actor_display": "Alice", + "target_displays": ["USER-001"], + "display_summary": "summary", + "raw_text": "raw", + } + ) + + llm_called = False + + async def fake_llm(question, events, total=None, excluded_services=None): + nonlocal llm_called + llm_called = True + return "This should NOT appear." + + monkeypatch.setattr("routes.ask.LLM_API_KEY", "fake-key") + monkeypatch.setattr("routes.ask._call_llm", fake_llm) + + # Pre-populate cache with a specific answer + class CachingFakeRedis: + def __init__(self): + self.store = {} + + async def get(self, key): + return self.store.get(key) + + async def setex(self, key, ttl, value): + self.store[key] = value + + redis = CachingFakeRedis() + # Seed cache with the exact filters the endpoint will generate + import asyncio + from jobs import set_cached_ask + filters_snapshot = { + "services": None, + "actor": None, + "operation": None, + "result": None, + "start": None, + "end": None, + "include_tags": None, + "exclude_tags": None, + } + asyncio.run(set_cached_ask(redis, "What happened to USER-001?", filters_snapshot, [{"id": "evt-cache"}], {"answer": "Cached answer!", "llm_used": True, "llm_error": None})) + + async def fake_get_arq_pool(): + return redis + + monkeypatch.setattr("routes.ask.get_arq_pool", fake_get_arq_pool) + + response = client.post("/api/ask", json={"question": "What happened to USER-001?"}) + assert response.status_code == 200 + data = response.json() + assert data["answer"] == "Cached answer!" 
diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml
index 4a65383..83daf9a 100644
--- a/docker-compose.prod.yml
+++ b/docker-compose.prod.yml
@@ -1,4 +1,19 @@
 services:
+  redis:
+    image: valkey/valkey:8-alpine
+    container_name: aoc-redis
+    restart: always
+    volumes:
+      - redis_data:/data
+    networks:
+      - aoc-internal
+    healthcheck:
+      test: ["CMD", "redis-cli", "ping"]
+      interval: 10s
+      timeout: 3s
+      retries: 5
+      start_period: 5s
+
   mongo:
     image: mongo:7
     container_name: aoc-mongo
@@ -27,9 +42,12 @@ services:
       - .env
     environment:
       MONGO_URI: mongodb://${MONGO_ROOT_USERNAME}:${MONGO_ROOT_PASSWORD}@mongo:27017/
+      REDIS_URL: redis://redis:6379/0
    depends_on:
      mongo:
        condition: service_healthy
+      redis:
+        condition: service_healthy
    networks:
      - aoc-internal
    healthcheck:
@@ -39,6 +57,24 @@ services:
      retries: 3
      start_period: 10s
 
+  worker:
+    image: git.cqre.net/cqrenet/aoc-backend:${AOC_VERSION:-latest}
+    container_name: aoc-worker
+    restart: always
+    env_file:
+      - .env
+    environment:
+      MONGO_URI: mongodb://${MONGO_ROOT_USERNAME}:${MONGO_ROOT_PASSWORD}@mongo:27017/
+      REDIS_URL: redis://redis:6379/0
+    command: ["arq", "jobs.WorkerSettings"]
+    depends_on:
+      redis:
+        condition: service_healthy
+      mongo:
+        condition: service_healthy
+    networks:
+      - aoc-internal
+
   nginx:
     image: nginx:alpine
     container_name: aoc-nginx
@@ -58,6 +94,7 @@ services:
 
 volumes:
   mongo_data:
+  redis_data:
 
 networks:
   aoc-internal:
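In the production stack Valkey is only attached to the aoc-internal network, so ad-hoc inspection has to happen from inside that network (for example via `docker exec` into the aoc-redis container, or a one-off container attached to it). One useful check is whether LLM jobs are piling up faster than the single worker drains them; a rough sketch, assuming arq's default queue key `arq:queue` (a sorted set of pending job ids) and a URL that actually resolves from wherever the script runs:

```python
import asyncio

import redis.asyncio as aioredis


async def queue_depth(url: str = "redis://redis:6379/0") -> int:
    """Number of jobs currently waiting in arq's default queue."""
    client = aioredis.from_url(url)
    try:
        return await client.zcard("arq:queue")  # assumption: arq's default queue name
    finally:
        await client.close()


if __name__ == "__main__":
    print("pending LLM jobs:", asyncio.run(queue_depth()))
```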
diff --git a/docker-compose.yml b/docker-compose.yml
index 3ecde35..bb841e8 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -1,4 +1,13 @@
 services:
+  redis:
+    image: valkey/valkey:8-alpine
+    container_name: aoc-redis
+    restart: always
+    ports:
+      - "6379:6379"
+    volumes:
+      - redis_data:/data
+
   mongo:
     image: mongo:7
     container_name: aoc-mongo
@@ -21,10 +30,27 @@ services:
       - .env
     environment:
       MONGO_URI: mongodb://${MONGO_ROOT_USERNAME}:${MONGO_ROOT_PASSWORD}@mongo:${MONGO_PORT}/
+      REDIS_URL: redis://redis:6379/0
     depends_on:
       - mongo
+      - redis
     ports:
       - "8000:8000"
 
+  worker:
+    build: ./backend
+    container_name: aoc-worker
+    restart: always
+    env_file:
+      - .env
+    environment:
+      MONGO_URI: mongodb://${MONGO_ROOT_USERNAME}:${MONGO_ROOT_PASSWORD}@mongo:${MONGO_PORT}/
+      REDIS_URL: redis://redis:6379/0
+    command: ["arq", "jobs.WorkerSettings"]
+    depends_on:
+      - redis
+      - mongo
+
 volumes:
   mongo_data:
+  redis_data:
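Both worker containers run the arq CLI (`arq jobs.WorkerSettings`). When developing outside Docker, the same worker can be started programmatically; a sketch assuming the backend directory is the working directory, its dependencies are installed, and REDIS_URL points at a reachable Valkey (the file name is illustrative):

```python
# run_worker_local.py -- hypothetical convenience script, not part of this diff
from arq.worker import run_worker

from jobs import WorkerSettings

if __name__ == "__main__":
    # Blocks and processes process_ask_question / process_explain_event jobs
    # until interrupted, honouring max_jobs, job_timeout and keep_result
    # from WorkerSettings.
    run_worker(WorkerSettings)
```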