Fixing GPU OOMs
@@ -73,6 +73,12 @@ WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", "auto").strip()
 WHISPER_DEVICE = os.getenv("WHISPER_DEVICE", "auto").strip()
 WHISPER_DEVICE_INDEX = int(os.getenv("WHISPER_DEVICE_INDEX", "0"))
 WHISPER_CPU_THREADS = int(os.getenv("WHISPER_CPU_THREADS", "4"))
+# Decoding beam size (higher = more memory and quality). On GPU, 1 is typically fine.
+try:
+    _beam_env = os.getenv("WHISPER_BEAM_SIZE", "")
+    WHISPER_BEAM_SIZE = int(_beam_env) if _beam_env.strip() else (1 if (os.getenv("WHISPER_DEVICE", "auto").strip().lower() == "cuda") else 2)
+except Exception:
+    WHISPER_BEAM_SIZE = 1
 
 # --- Host load guards / thread limits ---
 # Limit ffmpeg threads (helps keep CPU in check when multiple workers run)
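Note (not part of the diff): the added default resolves as "explicit WHISPER_BEAM_SIZE wins; otherwise beam 1 on CUDA and beam 2 elsewhere; anything unparsable falls back to 1". A minimal, illustrative sketch of that resolution logic, using a hypothetical helper name:

    import os

    def resolve_beam_size() -> int:
        # Same precedence as the added block: explicit env var first, then a device-based default.
        raw = os.getenv("WHISPER_BEAM_SIZE", "").strip()
        try:
            if raw:
                return int(raw)
            return 1 if os.getenv("WHISPER_DEVICE", "auto").strip().lower() == "cuda" else 2
        except Exception:
            return 1

    # WHISPER_DEVICE=cuda, WHISPER_BEAM_SIZE unset -> 1
    # WHISPER_DEVICE=cpu,  WHISPER_BEAM_SIZE unset -> 2
    # WHISPER_BEAM_SIZE=5                          -> 5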
@@ -192,16 +198,17 @@ def get_model():
     return _model
 
 # --- Helper: Reset model with new device and device_index ---
-def reset_model(device: str, device_index: int | None = None):
+def reset_model(device: str, device_index: int | None = None, compute_type: str | None = None):
     """Reset the global _model to a new WhisperModel with the given device and device_index."""
     global _model
     idx = device_index if device_index is not None else WHISPER_DEVICE_INDEX
-    print(f"[whisper] resetting model='{MODEL_NAME}' device='{device}' idx={idx} compute='{COMPUTE}' threads={WHISPER_CPU_THREADS}", flush=True)
+    ctype = compute_type or COMPUTE
+    print(f"[whisper] resetting model='{MODEL_NAME}' device='{device}' idx={idx} compute='{ctype}' threads={WHISPER_CPU_THREADS}", flush=True)
     _model = WhisperModel(
         MODEL_NAME,
         device=device,
         device_index=idx,
-        compute_type=COMPUTE,
+        compute_type=ctype,
         cpu_threads=WHISPER_CPU_THREADS,
     )
 
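For context, illustrative calls to the extended helper, mirroring how the new parameter is used later in this commit:

    # Rebuild on the same GPU with a smaller compute type, keeping the configured index:
    reset_model("cuda", WHISPER_DEVICE_INDEX, compute_type="int8_float16")

    # Force a CPU rebuild with int8 weights:
    reset_model("cpu", 0, compute_type="int8")

    # Omitting compute_type keeps the previous behaviour and uses COMPUTE:
    reset_model("cpu", 0)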
@@ -211,23 +218,48 @@ def run_transcribe_with_fallback(wav_path: Path, lang):
     Try to transcribe with current model; on GPU/CUDA/HIP/ROCm/OOM errors, reset to CPU and retry once.
     Returns (segments, info) or raises exception.
     """
+    # First attempt with current settings
     try:
         model = get_model()
-        return model.transcribe(str(wav_path), vad_filter=True, language=lang)
+        return model.transcribe(str(wav_path), vad_filter=True, language=lang, beam_size=WHISPER_BEAM_SIZE)
     except Exception as e:
         msg = str(e)
-        gpu_errs = [
-            "CUDA", "cublas", "out of memory", "HIP", "ROCm", "device-side assert", "CUDNN", "cudaError", "cuda runtime", "cudaMalloc"
-        ]
-        if any(err.lower() in msg.lower() for err in gpu_errs):
-            print(f"[whisper] GPU error detected: '{msg}'. Retrying on CPU...", flush=True)
-            reset_model("cpu", 0)
+        lower_msg = msg.lower()
+        oom_markers = ["out of memory", "cudaerrormemoryallocation", "cublas", "cuda", "cudnn"]
+        # If we encountered a GPU-related error, attempt progressive fallbacks on GPU before CPU
+        if any(m in lower_msg for m in oom_markers):
+            # Decide GPU compute fallback ladder
+            compute_chain = []
+            base = (COMPUTE or "float16").lower()
+            if base not in compute_chain:
+                compute_chain.append(base)
+            for c in ("int8_float16", "int8"):
+                if c not in compute_chain:
+                    compute_chain.append(c)
+
+            for ctype in compute_chain[1:]:  # skip the first (already tried via get_model())
+                try:
+                    print(f"[whisper] GPU error '{msg}'. Retrying with compute_type='{ctype}' and beam_size=1...", flush=True)
+                    reset_model("cuda", WHISPER_DEVICE_INDEX, compute_type=ctype)
+                    model = get_model()
+                    return model.transcribe(str(wav_path), vad_filter=True, language=lang, beam_size=1)
+                except Exception as e2:
+                    emsg = str(e2).lower()
+                    if not any(m in emsg for m in oom_markers):
+                        # Non-GPU/OOM error: rethrow
+                        print(f"[whisper] GPU retry failed: {e2}", flush=True)
+                        raise
+                    # else continue to next fallback
+            # Fall back to CPU if all GPU fallbacks failed
+            print(f"[whisper] GPU attempts exhausted ('{msg}'); falling back to CPU int8", flush=True)
+            reset_model("cpu", 0, compute_type="int8")
             try:
                 model = get_model()
-                return model.transcribe(str(wav_path), vad_filter=True, language=lang)
-            except Exception as e2:
-                print(f"[whisper] CPU fallback also failed: {e2}", flush=True)
+                return model.transcribe(str(wav_path), vad_filter=True, language=lang, beam_size=1)
+            except Exception as e3:
+                print(f"[whisper] CPU fallback also failed: {e3}", flush=True)
                 raise
+        # Non-GPU error: rethrow
         raise
 
 
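In short (a sketch, assuming the configured COMPUTE is "float16" and using the module's own names): the handler builds the chain ["float16", "int8_float16", "int8"], skips the first entry because get_model() already tried it, retries on CUDA with "int8_float16" and then "int8" at beam_size=1, and only rebuilds on CPU with int8 once those are exhausted. A hypothetical caller, with a made-up path:

    from pathlib import Path

    # language=None lets faster-whisper auto-detect the language.
    segments, info = run_transcribe_with_fallback(Path("/tmp/example.wav"), lang=None)
    for seg in segments:
        print(seg.start, seg.end, seg.text)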