Adding GPU support for transcription
This commit is contained in:
@@ -162,13 +162,33 @@ def get_model():
|
||||
global _model
|
||||
if _model is None:
|
||||
print(f"[whisper] loading model='{MODEL_NAME}' device='{WHISPER_DEVICE}' idx={WHISPER_DEVICE_INDEX} compute='{COMPUTE}' threads={WHISPER_CPU_THREADS}", flush=True)
|
||||
_model = WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=WHISPER_DEVICE,
|
||||
device_index=WHISPER_DEVICE_INDEX,
|
||||
compute_type=COMPUTE,
|
||||
cpu_threads=WHISPER_CPU_THREADS,
|
||||
)
|
||||
try:
|
||||
_model = WhisperModel(
|
||||
MODEL_NAME,
|
||||
device=WHISPER_DEVICE,
|
||||
device_index=WHISPER_DEVICE_INDEX,
|
||||
compute_type=COMPUTE,
|
||||
cpu_threads=WHISPER_CPU_THREADS,
|
||||
)
|
||||
except Exception as e:
|
||||
# If GPU is selected/auto-selected but not available, some environments try to load
|
||||
# CUDA/cuDNN and fail. Fall back to CPU automatically.
|
||||
msg = str(e).lower()
|
||||
gpu_markers = [
|
||||
"cuda", "cublas", "cudnn", "hip", "rocm", "nvrtc", "gpu",
|
||||
"unable to load any of {libcudnn", "cannot load symbol cudnncreatetensordescriptor",
|
||||
]
|
||||
if WHISPER_DEVICE.lower() != "cpu" and any(m in msg for m in gpu_markers):
|
||||
print(f"[whisper] model init failed on device '{WHISPER_DEVICE}': {e}. Falling back to CPU…", flush=True)
|
||||
_model = WhisperModel(
|
||||
MODEL_NAME,
|
||||
device="cpu",
|
||||
device_index=0,
|
||||
compute_type=COMPUTE,
|
||||
cpu_threads=WHISPER_CPU_THREADS,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
return _model
|
||||
|
||||
# --- Helper: Reset model with new device and device_index ---
|
||||
@@ -191,8 +211,8 @@ def run_transcribe_with_fallback(wav_path: Path, lang):
|
||||
Try to transcribe with current model; on GPU/CUDA/HIP/ROCm/OOM errors, reset to CPU and retry once.
|
||||
Returns (segments, info) or raises exception.
|
||||
"""
|
||||
model = get_model()
|
||||
try:
|
||||
model = get_model()
|
||||
return model.transcribe(str(wav_path), vad_filter=True, language=lang)
|
||||
except Exception as e:
|
||||
msg = str(e)
|
||||
|
Reference in New Issue
Block a user