Mahmoud Elsamadony committed
Commit ce8875c · 1 Parent(s): cf179b4

Update GPU Usage

Files changed (3):
  1. api_client.py +1 -1
  2. app.py +41 -6
  3. spaces.yml +7 -0
api_client.py CHANGED
@@ -125,4 +125,4 @@ if __name__ == "__main__":
     # Install gradio_client first:
     # pip install gradio_client
 
-    main()
+    main()
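For reference, a minimal sketch of what a client call against this Space looks like with gradio_client; the Space id, the audio argument, and the /predict endpoint name are assumptions for illustration, not taken from this repo:

# Minimal gradio_client sketch (Space id and endpoint name are assumed)
from gradio_client import Client, handle_file

client = Client("username/space-name")    # hypothetical Space id
result = client.predict(
    handle_file("sample.wav"),            # audio file to transcribe
    api_name="/predict",                  # assumed endpoint name; check the Space's API page
)
print(result)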
app.py CHANGED
@@ -1,6 +1,7 @@
+from __future__ import annotations
 import os
 import tempfile
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Any
 
 import gradio as gr
 import torch
@@ -24,8 +25,22 @@ load_dotenv()
 # Whisper model: use same model names as Django app (tiny, base, small, medium, large-v3)
 # faster-whisper will download these automatically from Hugging Face on first run
 WHISPER_MODEL_SIZE = os.environ.get("WHISPER_MODEL_SIZE", "large-v3")
-WHISPER_DEVICE = os.environ.get("WHISPER_DEVICE", "cpu")
-WHISPER_COMPUTE_TYPE = os.environ.get("WHISPER_COMPUTE_TYPE", "int8_float32")
+
+# Prefer GPU on Hugging Face Spaces if available, but allow override via env
+def _default_device() -> str:
+    try:
+        return "cuda" if torch.cuda.is_available() else "cpu"
+    except Exception:
+        return "cpu"
+
+WHISPER_DEVICE = os.environ.get("WHISPER_DEVICE") or _default_device()
+
+# Choose a sensible default compute type based on device (can be overridden by env)
+# - GPU: float16 is fastest and fits T4 for small/medium; use int8_float16 to save VRAM for large-v3
+# - CPU: int8_float32 works well
+WHISPER_COMPUTE_TYPE = os.environ.get("WHISPER_COMPUTE_TYPE") or (
+    "float16" if WHISPER_DEVICE == "cuda" else "int8_float32"
+)
 
 # Diarization: NVIDIA NeMo Sortformer model
 DIARIZATION_MODEL_NAME = os.environ.get(
@@ -57,7 +72,7 @@ expected_speakers_default = int(os.environ.get("EXPECTED_SPEAKERS", 2))
 # Lazy singletons for the heavy models
 # ---------------------------------------------------------------------------
 _whisper_model: Optional[WhisperModel] = None
-_diarization_model: Optional[SortformerEncLabelModel] = None
+_diarization_model: Optional[Any] = None
 
 
 def _ensure_snapshot(repo_id: str, local_dir: str, allow_patterns: Optional[List[str]] = None) -> str:
@@ -92,7 +107,7 @@ def _load_whisper_model() -> WhisperModel:
     return _whisper_model
 
 
-def _load_diarization_model() -> Optional[SortformerEncLabelModel]:
+def _load_diarization_model() -> Optional[Any]:
     """Load NVIDIA NeMo Sortformer diarization model lazily (singleton)"""
     global _diarization_model
    if _diarization_model is None:
@@ -111,6 +126,19 @@ def _load_diarization_model() -> Optional[SortformerEncLabelModel]:
 
         # Switch to evaluation mode
         _diarization_model.eval()
+
+        # Move to GPU if available on Spaces
+        if torch.cuda.is_available():
+            try:
+                _diarization_model.to("cuda")
+                print("[DEBUG] Moved Sortformer model to CUDA device")
+            except Exception:
+                # Fallback for modules exposing .cuda()
+                try:
+                    _diarization_model.cuda()
+                    print("[DEBUG] Moved Sortformer model to CUDA via .cuda()")
+                except Exception as _e:
+                    print(f"[WARN] Could not move Sortformer model to GPU: {_e}")
 
         # Configure streaming parameters (high latency preset for better accuracy)
         # See: https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2#setting-up-streaming-configuration
@@ -443,6 +471,10 @@ def build_interface() -> gr.Blocks:
             """
         )
 
+        gr.Markdown(
+            f"Running on device: `{WHISPER_DEVICE}` with compute type: `{WHISPER_COMPUTE_TYPE}`"
+        )
+
         with gr.Row():
             audio_input = gr.Audio(type="filepath", label="Upload audio (mp3, wav, m4a, ...)")
             options = gr.Column()
@@ -512,10 +544,13 @@
             """
         )
 
+    # Use a queue to serialize work on GPU and avoid OOM on Spaces free/shared GPUs
+    demo.queue(concurrency_count=1, max_size=16)
+
     return demo
 
 
 demo = build_interface()
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
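For context, a minimal sketch (assumed, not part of this commit) of how the device and compute-type defaults above typically reach faster-whisper; the model size, device, and compute type mirror the env-driven constants, and the file name is a placeholder:

# Sketch: passing the env-driven defaults to faster-whisper
from faster_whisper import WhisperModel

model = WhisperModel(
    "large-v3",              # WHISPER_MODEL_SIZE
    device="cuda",           # WHISPER_DEVICE as resolved by _default_device()
    compute_type="float16",  # GPU default chosen above; int8_float16 trims VRAM for large-v3
)
segments, info = model.transcribe("sample.wav")
for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")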
spaces.yml CHANGED
@@ -1,3 +1,10 @@
 sdk: gradio
 sdk_version: 4.42.0
 python_version: 3.10
+
+# Request a GPU on Hugging Face Spaces. Common options include:
+# - t4-small (free/shared tier)
+# - a10g-small
+# - a100-large
+# Adjust as needed in the Space settings UI.
+hardware: t4-small
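Once the Space restarts on the requested hardware, a quick check like the following (the same torch.cuda probe that _default_device() relies on) confirms the GPU is actually visible; the printed name depends on the assigned card:

# Sanity check: is the requested GPU visible to PyTorch inside the Space?
import torch

if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name(0))  # e.g. "Tesla T4" on t4-small
else:
    print("No GPU visible; transcription will fall back to CPU")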