qfuxa commited on
Commit
a5772a2
·
1 Parent(s): 7f61611

cuda or cpu auto detection

Browse files
Files changed (1) hide show
  1. src/whisper_streaming/backends.py +6 -4
src/whisper_streaming/backends.py CHANGED
@@ -4,7 +4,7 @@ import logging
4
  import io
5
  import soundfile as sf
6
  import math
7
-
8
 
9
  logger = logging.getLogger(__name__)
10
 
@@ -102,11 +102,13 @@ class FasterWhisperASR(ASRBase):
102
  else:
103
  raise ValueError("modelsize or model_dir parameter must be set")
104
 
105
- # this worked fast and reliably on NVIDIA L40
 
 
106
  model = WhisperModel(
107
  model_size_or_path,
108
- device="cuda",
109
- compute_type="float16",
110
  download_root=cache_dir,
111
  )
112
 
 
4
  import io
5
  import soundfile as sf
6
  import math
7
+ import torch
8
 
9
  logger = logging.getLogger(__name__)
10
 
 
102
  else:
103
  raise ValueError("modelsize or model_dir parameter must be set")
104
 
105
+ device = "cuda" if torch.cuda.is_available() else "cpu"
106
+ compute_type = "float16" if device == "cuda" else "float32"
107
+
108
  model = WhisperModel(
109
  model_size_or_path,
110
+ device=device,
111
+ compute_type=compute_type,
112
  download_root=cache_dir,
113
  )
114