Luigi committed on
Commit 1b9d615 · 1 Parent(s): 7300238

Add @spaces.GPU decorator and CUDA support for HF Spaces

Files changed (2)
  1. app.py +13 -7
  2. requirements.txt +1 -1
app.py CHANGED
@@ -9,6 +9,7 @@ import tempfile
 import gradio as gr
 import torch
 from pathlib import Path
+from spaces import GPU
 
 # Add current directory to Python path for local zipvoice package
 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
@@ -33,8 +34,8 @@ def load_models_and_components(model_name: str):
     """Load and cache models, tokenizer, vocoder, and feature extractor."""
     global _models_cache, _tokenizer_cache, _vocoder_cache, _feature_extractor_cache
 
-    # Set device (CPU for Spaces, but could be adapted for GPU)
-    device = torch.device("cpu")
+    # Set device (GPU if available, otherwise CPU)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     if model_name not in _models_cache:
         print(f"Loading {model_name} model...")
@@ -100,6 +101,7 @@ def load_models_and_components(model_name: str):
                             model_config["feature"]["sampling_rate"])
 
 
+@GPU
 def synthesize_speech_gradio(
     text: str,
     prompt_audio_file,
@@ -124,7 +126,7 @@ def synthesize_speech_gradio(
     # Load models and components
     model, tokenizer, vocoder, feature_extractor, sampling_rate = load_models_and_components(model_name)
 
-    device = torch.device("cpu")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     # Save uploaded audio to temporary file
     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
@@ -224,7 +226,8 @@ def create_gradio_interface():
        model_dropdown = gr.Dropdown(
            choices=["zipvoice", "zipvoice_distill"],
            value="zipvoice",
-           label="Model"
+           label="Model",
+           info="zipvoice_distill is faster but slightly less accurate"
        )
 
        speed_slider = gr.Slider(
@@ -232,19 +235,22 @@ def create_gradio_interface():
            maximum=2.0,
            value=1.0,
            step=0.1,
-           label="Speed"
+           label="Speed",
+           info="1.0 = normal speed, >1.0 = faster, <1.0 = slower"
        )
 
        prompt_audio = gr.File(
            label="Prompt Audio",
            file_types=["audio"],
-           type="binary"
+           type="binary",
+           info="Upload a short audio clip (1-3 seconds recommended) to mimic the voice style"
        )
 
        prompt_text = gr.Textbox(
            label="Prompt Transcription",
            placeholder="Enter the exact transcription of the prompt audio...",
-           lines=2
+           lines=2,
+           info="This should match what is spoken in the audio file"
        )
 
        generate_btn = gr.Button(
requirements.txt CHANGED
@@ -9,7 +9,7 @@ safetensors
 tensorboard
 vocos
 pydub
-gradio
+gradio>=4.44.0
 
 # Normalization
 cn2an
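
Taken together, the app.py changes follow one small pattern: import the GPU decorator from the `spaces` package, decorate the function that performs inference, and select CUDA only when it is available. Below is a minimal sketch of that pattern under stated assumptions: `run_tts` and its toy body are illustrative placeholders rather than code from app.py, and the `spaces` decorator only actually allocates a GPU when the app runs on ZeroGPU hardware on Hugging Face Spaces.

# Minimal sketch of the pattern this commit applies (assumes torch, gradio,
# and the Hugging Face `spaces` package are installed; `run_tts` is a
# placeholder function, not one from app.py).
import torch
import gradio as gr
from spaces import GPU


@GPU  # on ZeroGPU Spaces this requests a GPU for the duration of the call
def run_tts(text: str) -> str:
    # Inside the decorated function CUDA is visible on ZeroGPU hardware,
    # so this picks the GPU there and falls back to CPU everywhere else.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return f"Would synthesize {text!r} on {device}"


demo = gr.Interface(fn=run_tts, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.launch()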