Fixing GPU OOMs
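On CUDA out-of-memory errors, step the Whisper model down through lighter compute types (int8_float16, then int8) on the GPU before falling back to CPU, instead of jumping straight to CPU. The decode beam size becomes configurable via WHISPER_BEAM_SIZE (defaulting to 1 on CUDA, 2 elsewhere), and reset_model() now accepts an explicit compute_type.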

commit a48236753b
parent db8e98c6c3
Date: 2025-10-05 16:02:42 +02:00


@@ -73,6 +73,12 @@ WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", "auto").strip()
 WHISPER_DEVICE = os.getenv("WHISPER_DEVICE", "auto").strip()
 WHISPER_DEVICE_INDEX = int(os.getenv("WHISPER_DEVICE_INDEX", "0"))
 WHISPER_CPU_THREADS = int(os.getenv("WHISPER_CPU_THREADS", "4"))
+# Decoding beam size (higher = more memory and quality). On GPU, 1 is typically fine.
+try:
+    _beam_env = os.getenv("WHISPER_BEAM_SIZE", "")
+    WHISPER_BEAM_SIZE = int(_beam_env) if _beam_env.strip() else (1 if (os.getenv("WHISPER_DEVICE", "auto").strip().lower() == "cuda") else 2)
+except Exception:
+    WHISPER_BEAM_SIZE = 1
 # --- Host load guards / thread limits ---
 # Limit ffmpeg threads (helps keep CPU in check when multiple workers run)
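For reference, here is the new default-resolution logic written out as a standalone sketch; resolve_beam_size is a hypothetical name for illustration, not part of the commit:

import os

def resolve_beam_size() -> int:
    # An explicit WHISPER_BEAM_SIZE wins; otherwise default to 1 on CUDA
    # (lighter on VRAM) and 2 elsewhere, mirroring the block above.
    raw = os.getenv("WHISPER_BEAM_SIZE", "").strip()
    if raw:
        return int(raw)
    return 1 if os.getenv("WHISPER_DEVICE", "auto").strip().lower() == "cuda" else 2

# WHISPER_DEVICE=cuda, WHISPER_BEAM_SIZE unset -> 1
# WHISPER_DEVICE=auto, WHISPER_BEAM_SIZE=5     -> 5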
@@ -192,16 +198,17 @@ def get_model():
     return _model

 # --- Helper: Reset model with new device and device_index ---
-def reset_model(device: str, device_index: int | None = None):
+def reset_model(device: str, device_index: int | None = None, compute_type: str | None = None):
     """Reset the global _model to a new WhisperModel with the given device, device_index, and compute_type."""
     global _model
     idx = device_index if device_index is not None else WHISPER_DEVICE_INDEX
-    print(f"[whisper] resetting model='{MODEL_NAME}' device='{device}' idx={idx} compute='{COMPUTE}' threads={WHISPER_CPU_THREADS}", flush=True)
+    ctype = compute_type or COMPUTE
+    print(f"[whisper] resetting model='{MODEL_NAME}' device='{device}' idx={idx} compute='{ctype}' threads={WHISPER_CPU_THREADS}", flush=True)
     _model = WhisperModel(
         MODEL_NAME,
         device=device,
         device_index=idx,
-        compute_type=COMPUTE,
+        compute_type=ctype,
         cpu_threads=WHISPER_CPU_THREADS,
     )
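With the new compute_type parameter, callers can downgrade precision without touching the global COMPUTE. Hypothetical calls for illustration, using standard CTranslate2 compute types:

reset_model("cuda", 0, compute_type="int8_float16")  # int8 weights, fp16 compute on GPU 0
reset_model("cuda", 0)                               # compute_type=None falls back to COMPUTE
reset_model("cpu", 0, compute_type="int8")           # the last-resort CPU path used below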
@@ -211,23 +218,48 @@ def run_transcribe_with_fallback(wav_path: Path, lang):
     Try to transcribe with the current model; on GPU/CUDA/HIP/ROCm/OOM errors, step down to lighter GPU compute types, then reset to CPU and retry once.
     Returns (segments, info) or raises an exception.
     """
     # First attempt with current settings
     try:
         model = get_model()
-        return model.transcribe(str(wav_path), vad_filter=True, language=lang)
+        return model.transcribe(str(wav_path), vad_filter=True, language=lang, beam_size=WHISPER_BEAM_SIZE)
     except Exception as e:
         msg = str(e)
-        gpu_errs = [
-            "CUDA", "cublas", "out of memory", "HIP", "ROCm", "device-side assert", "CUDNN", "cudaError", "cuda runtime", "cudaMalloc"
-        ]
-        if any(err.lower() in msg.lower() for err in gpu_errs):
-            print(f"[whisper] GPU error detected: '{msg}'. Retrying on CPU...", flush=True)
-            reset_model("cpu", 0)
+        lower_msg = msg.lower()
+        oom_markers = ["out of memory", "cudaerrormemoryallocation", "cublas", "cuda", "cudnn"]
+        # If we hit a GPU-related error, attempt progressive fallbacks on GPU before CPU
+        if any(m in lower_msg for m in oom_markers):
+            # Build the GPU compute-type fallback ladder, starting from the configured type
+            compute_chain = []
+            base = (COMPUTE or "float16").lower()
+            if base not in compute_chain:
+                compute_chain.append(base)
+            for c in ("int8_float16", "int8"):
+                if c not in compute_chain:
+                    compute_chain.append(c)
+            for ctype in compute_chain[1:]:  # skip the first (already tried via get_model())
+                try:
+                    print(f"[whisper] GPU error '{msg}'. Retrying with compute_type='{ctype}' and beam_size=1...", flush=True)
+                    reset_model("cuda", WHISPER_DEVICE_INDEX, compute_type=ctype)
+                    model = get_model()
+                    return model.transcribe(str(wav_path), vad_filter=True, language=lang, beam_size=1)
+                except Exception as e2:
+                    emsg = str(e2).lower()
+                    if not any(m in emsg for m in oom_markers):
+                        # Non-GPU/OOM error: rethrow
+                        print(f"[whisper] GPU retry failed: {e2}", flush=True)
+                        raise
+                    # else continue to the next fallback
+            # Fall back to CPU if all GPU fallbacks failed
+            print(f"[whisper] GPU attempts exhausted ('{msg}'); falling back to CPU int8", flush=True)
+            reset_model("cpu", 0, compute_type="int8")
             try:
                 model = get_model()
-                return model.transcribe(str(wav_path), vad_filter=True, language=lang)
-            except Exception as e2:
-                print(f"[whisper] CPU fallback also failed: {e2}", flush=True)
+                return model.transcribe(str(wav_path), vad_filter=True, language=lang, beam_size=1)
+            except Exception as e3:
+                print(f"[whisper] CPU fallback also failed: {e3}", flush=True)
                 raise
         # Non-GPU error: rethrow
         raise
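The retry order the chain-building code produces, extracted into a hypothetical helper for illustration (fallback_ladder is not part of the commit):

def fallback_ladder(configured: str) -> list[str]:
    # Same dedup logic as the commit: the configured type comes first,
    # then int8_float16 and int8 are appended if not already present.
    chain = [configured.lower()]
    for c in ("int8_float16", "int8"):
        if c not in chain:
            chain.append(c)
    return chain

print(fallback_ladder("float16"))       # ['float16', 'int8_float16', 'int8']: two GPU retries
print(fallback_ladder("int8_float16"))  # ['int8_float16', 'int8']: one GPU retry before CPU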