diff --git a/app/worker.py b/app/worker.py
index 202191a..008dde0 100644
--- a/app/worker.py
+++ b/app/worker.py
@@ -73,6 +73,12 @@ WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", "auto").strip()
 WHISPER_DEVICE = os.getenv("WHISPER_DEVICE", "auto").strip()
 WHISPER_DEVICE_INDEX = int(os.getenv("WHISPER_DEVICE_INDEX", "0"))
 WHISPER_CPU_THREADS = int(os.getenv("WHISPER_CPU_THREADS", "4"))
+# Decoding beam size (higher = more memory and quality). On GPU, 1 is typically fine.
+try:
+    _beam_env = os.getenv("WHISPER_BEAM_SIZE", "")
+    WHISPER_BEAM_SIZE = int(_beam_env) if _beam_env.strip() else (1 if WHISPER_DEVICE.lower() == "cuda" else 2)
+except Exception:
+    WHISPER_BEAM_SIZE = 1
 
 # --- Host load guards / thread limits ---
 # Limit ffmpeg threads (helps keep CPU in check when multiple workers run)
@@ -192,16 +198,17 @@ def get_model():
     return _model
 
 # --- Helper: Reset model with new device and device_index ---
-def reset_model(device: str, device_index: int | None = None):
+def reset_model(device: str, device_index: int | None = None, compute_type: str | None = None):
     """Reset the global _model to a new WhisperModel with the given device and device_index."""
     global _model
     idx = device_index if device_index is not None else WHISPER_DEVICE_INDEX
-    print(f"[whisper] resetting model='{MODEL_NAME}' device='{device}' idx={idx} compute='{COMPUTE}' threads={WHISPER_CPU_THREADS}", flush=True)
+    ctype = compute_type or COMPUTE
+    print(f"[whisper] resetting model='{MODEL_NAME}' device='{device}' idx={idx} compute='{ctype}' threads={WHISPER_CPU_THREADS}", flush=True)
     _model = WhisperModel(
         MODEL_NAME,
         device=device,
         device_index=idx,
-        compute_type=COMPUTE,
+        compute_type=ctype,
         cpu_threads=WHISPER_CPU_THREADS,
     )
 
@@ -211,23 +218,48 @@ def run_transcribe_with_fallback(wav_path: Path, lang):
     """
     Try to transcribe with current model; on GPU/CUDA/HIP/ROCm/OOM errors, reset to CPU and retry once.
     Returns (segments, info) or raises exception.
     """
+    # First attempt with current settings
     try:
         model = get_model()
-        return model.transcribe(str(wav_path), vad_filter=True, language=lang)
+        return model.transcribe(str(wav_path), vad_filter=True, language=lang, beam_size=WHISPER_BEAM_SIZE)
     except Exception as e:
         msg = str(e)
-        gpu_errs = [
-            "CUDA", "cublas", "out of memory", "HIP", "ROCm", "device-side assert", "CUDNN", "cudaError", "cuda runtime", "cudaMalloc"
-        ]
-        if any(err.lower() in msg.lower() for err in gpu_errs):
-            print(f"[whisper] GPU error detected: '{msg}'. Retrying on CPU...", flush=True)
-            reset_model("cpu", 0)
+        lower_msg = msg.lower()
+        oom_markers = ["out of memory", "cudaerrormemoryallocation", "cublas", "cuda", "cudnn"]
+        # If we encountered a GPU-related error, attempt progressive fallbacks on GPU before CPU
+        if any(m in lower_msg for m in oom_markers):
+            # Decide GPU compute fallback ladder
+            compute_chain = []
+            base = (COMPUTE or "float16").lower()
+            if base not in compute_chain:
+                compute_chain.append(base)
+            for c in ("int8_float16", "int8"):
+                if c not in compute_chain:
+                    compute_chain.append(c)
+
+            for ctype in compute_chain[1:]:  # skip the first (already tried via get_model())
+                try:
+                    print(f"[whisper] GPU error '{msg}'. Retrying with compute_type='{ctype}' and beam_size=1...", flush=True)
+                    reset_model("cuda", WHISPER_DEVICE_INDEX, compute_type=ctype)
+                    model = get_model()
+                    return model.transcribe(str(wav_path), vad_filter=True, language=lang, beam_size=1)
+                except Exception as e2:
+                    emsg = str(e2).lower()
+                    if not any(m in emsg for m in oom_markers):
+                        # Non-GPU/oom error — rethrow
+                        print(f"[whisper] GPU retry failed: {e2}", flush=True)
+                        raise
+                    # else continue to next fallback
+            # Fall back to CPU if all GPU fallbacks failed
+            print(f"[whisper] GPU attempts exhausted ('{msg}'); falling back to CPU int8", flush=True)
+            reset_model("cpu", 0, compute_type="int8")
             try:
                 model = get_model()
-                return model.transcribe(str(wav_path), vad_filter=True, language=lang)
-            except Exception as e2:
-                print(f"[whisper] CPU fallback also failed: {e2}", flush=True)
+                return model.transcribe(str(wav_path), vad_filter=True, language=lang, beam_size=1)
+            except Exception as e3:
+                print(f"[whisper] CPU fallback also failed: {e3}", flush=True)
                 raise
+        # Non-GPU error — rethrow
         raise
 