Fixing GPU OOMs
This commit is contained in:
@@ -73,6 +73,12 @@ WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", "auto").strip()
|
||||
WHISPER_DEVICE = os.getenv("WHISPER_DEVICE", "auto").strip()
|
||||
WHISPER_DEVICE_INDEX = int(os.getenv("WHISPER_DEVICE_INDEX", "0"))
|
||||
WHISPER_CPU_THREADS = int(os.getenv("WHISPER_CPU_THREADS", "4"))
|
||||
# Decoding beam size (higher = more memory and quality). On GPU, 1 is typically fine.
|
||||
try:
|
||||
_beam_env = os.getenv("WHISPER_BEAM_SIZE", "")
|
||||
WHISPER_BEAM_SIZE = int(_beam_env) if _beam_env.strip() else (1 if (os.getenv("WHISPER_DEVICE", "auto").strip().lower() == "cuda") else 2)
|
||||
except Exception:
|
||||
WHISPER_BEAM_SIZE = 1
|
||||
|
||||
# --- Host load guards / thread limits ---
|
||||
# Limit ffmpeg threads (helps keep CPU in check when multiple workers run)
|
||||
@@ -192,16 +198,17 @@ def get_model():
|
||||
return _model
|
||||
|
||||
# --- Helper: Reset model with new device and device_index ---
|
||||
def reset_model(device: str, device_index: int | None = None):
|
||||
def reset_model(device: str, device_index: int | None = None, compute_type: str | None = None):
|
||||
"""Reset the global _model to a new WhisperModel with the given device and device_index."""
|
||||
global _model
|
||||
idx = device_index if device_index is not None else WHISPER_DEVICE_INDEX
|
||||
print(f"[whisper] resetting model='{MODEL_NAME}' device='{device}' idx={idx} compute='{COMPUTE}' threads={WHISPER_CPU_THREADS}", flush=True)
|
||||
ctype = compute_type or COMPUTE
|
||||
print(f"[whisper] resetting model='{MODEL_NAME}' device='{device}' idx={idx} compute='{ctype}' threads={WHISPER_CPU_THREADS}", flush=True)
|
||||
_model = WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=device,
|
||||
device_index=idx,
|
||||
compute_type=COMPUTE,
|
||||
compute_type=ctype,
|
||||
cpu_threads=WHISPER_CPU_THREADS,
|
||||
)
|
||||
|
||||
@@ -211,23 +218,48 @@ def run_transcribe_with_fallback(wav_path: Path, lang):
|
||||
Try to transcribe with current model; on GPU/CUDA/HIP/ROCm/OOM errors, reset to CPU and retry once.
|
||||
Returns (segments, info) or raises exception.
|
||||
"""
|
||||
# First attempt with current settings
|
||||
try:
|
||||
model = get_model()
|
||||
return model.transcribe(str(wav_path), vad_filter=True, language=lang)
|
||||
return model.transcribe(str(wav_path), vad_filter=True, language=lang, beam_size=WHISPER_BEAM_SIZE)
|
||||
except Exception as e:
|
||||
msg = str(e)
|
||||
gpu_errs = [
|
||||
"CUDA", "cublas", "out of memory", "HIP", "ROCm", "device-side assert", "CUDNN", "cudaError", "cuda runtime", "cudaMalloc"
|
||||
]
|
||||
if any(err.lower() in msg.lower() for err in gpu_errs):
|
||||
print(f"[whisper] GPU error detected: '{msg}'. Retrying on CPU...", flush=True)
|
||||
reset_model("cpu", 0)
|
||||
lower_msg = msg.lower()
|
||||
oom_markers = ["out of memory", "cudaerrormemoryallocation", "cublas", "cuda", "cudnn"]
|
||||
# If we encountered a GPU-related error, attempt progressive fallbacks on GPU before CPU
|
||||
if any(m in lower_msg for m in oom_markers):
|
||||
# Decide GPU compute fallback ladder
|
||||
compute_chain = []
|
||||
base = (COMPUTE or "float16").lower()
|
||||
if base not in compute_chain:
|
||||
compute_chain.append(base)
|
||||
for c in ("int8_float16", "int8"):
|
||||
if c not in compute_chain:
|
||||
compute_chain.append(c)
|
||||
|
||||
for ctype in compute_chain[1:]: # skip the first (already tried via get_model())
|
||||
try:
|
||||
print(f"[whisper] GPU error '{msg}'. Retrying with compute_type='{ctype}' and beam_size=1...", flush=True)
|
||||
reset_model("cuda", WHISPER_DEVICE_INDEX, compute_type=ctype)
|
||||
model = get_model()
|
||||
return model.transcribe(str(wav_path), vad_filter=True, language=lang, beam_size=1)
|
||||
except Exception as e2:
|
||||
emsg = str(e2).lower()
|
||||
if not any(m in emsg for m in oom_markers):
|
||||
# Non-GPU/oom error — rethrow
|
||||
print(f"[whisper] GPU retry failed: {e2}", flush=True)
|
||||
raise
|
||||
# else continue to next fallback
|
||||
# Fall back to CPU if all GPU fallbacks failed
|
||||
print(f"[whisper] GPU attempts exhausted ('{msg}'); falling back to CPU int8", flush=True)
|
||||
reset_model("cpu", 0, compute_type="int8")
|
||||
try:
|
||||
model = get_model()
|
||||
return model.transcribe(str(wav_path), vad_filter=True, language=lang)
|
||||
except Exception as e2:
|
||||
print(f"[whisper] CPU fallback also failed: {e2}", flush=True)
|
||||
return model.transcribe(str(wav_path), vad_filter=True, language=lang, beam_size=1)
|
||||
except Exception as e3:
|
||||
print(f"[whisper] CPU fallback also failed: {e3}", flush=True)
|
||||
raise
|
||||
# Non-GPU error — rethrow
|
||||
raise
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user