Khushi Dahiya commited on
Commit
3705396
·
1 Parent(s): 1e137e7

debugging melodyflow api

Browse files
Files changed (1) hide show
  1. demos/melodyflow_api.py +34 -4
demos/melodyflow_api.py CHANGED
@@ -7,9 +7,25 @@ This version focuses on high-throughput API serving with batching
7
  """
8
 
9
  import os
10
- # Fix OpenMP threading issues
11
- os.environ.setdefault('OMP_NUM_THREADS', '1')
12
- os.environ.setdefault('MKL_NUM_THREADS', '1')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  import spaces
15
  import asyncio
@@ -31,6 +47,10 @@ from audiocraft.data.audio_utils import convert_audio
31
  from audiocraft.data.audio import audio_read, audio_write
32
  from audiocraft.models import MelodyFlow
33
 
 
 
 
 
34
 
35
  # Configuration
36
  MODEL_PREFIX = "facebook/"
@@ -160,6 +180,9 @@ class OptimizedBatchProcessor:
160
  def _process_batch(self, batch: tp.List[GenerationRequest]):
161
  """Process a batch of requests on GPU"""
162
  try:
 
 
 
163
  logging.info(f"Processing batch of {len(batch)} requests")
164
  start_time = time.time()
165
 
@@ -200,6 +223,9 @@ class OptimizedBatchProcessor:
200
 
201
  def _load_model(self, version: str):
202
  """Thread-safe model loading"""
 
 
 
203
  with self.model_lock:
204
  if self.model is None or self.model.name != version:
205
  if self.model is not None:
@@ -397,7 +423,8 @@ def create_optimized_interface():
397
  ],
398
  inputs=[model, text, solver, steps, gr.State(0.0),
399
  gr.State(False), gr.State(0.0), duration, melody],
400
- outputs=output
 
401
  )
402
 
403
  return interface
@@ -418,6 +445,9 @@ if __name__ == "__main__":
418
  format='%(asctime)s - %(levelname)s - %(message)s'
419
  )
420
 
 
 
 
421
  # Start batch processor
422
  batch_processor.start()
423
 
 
7
  """
8
 
9
  import os
10
+ import sys
11
+
12
+ # Fix OpenMP threading issues - ensure they're set early and correctly
13
+ os.environ['OMP_NUM_THREADS'] = '1'
14
+ os.environ['MKL_NUM_THREADS'] = '1'
15
+ os.environ['NUMEXPR_NUM_THREADS'] = '1'
16
+ os.environ['OPENBLAS_NUM_THREADS'] = '1'
17
+
18
+ # Additional protection against environment variable corruption
19
+ def ensure_thread_env():
20
+ """Ensure threading environment variables stay set"""
21
+ for key, value in [('OMP_NUM_THREADS', '1'), ('MKL_NUM_THREADS', '1'),
22
+ ('NUMEXPR_NUM_THREADS', '1'), ('OPENBLAS_NUM_THREADS', '1')]:
23
+ if os.environ.get(key) != value:
24
+ os.environ[key] = value
25
+ print(f"Reset {key} to {value}")
26
+
27
+ # Call it immediately
28
+ ensure_thread_env()
29
 
30
  import spaces
31
  import asyncio
 
47
  from audiocraft.data.audio import audio_read, audio_write
48
  from audiocraft.models import MelodyFlow
49
 
50
+ # Fix CSV field size limit for large audio data
51
+ import csv
52
+ csv.field_size_limit(1000000) # Increase field size limit
53
+
54
 
55
  # Configuration
56
  MODEL_PREFIX = "facebook/"
 
180
  def _process_batch(self, batch: tp.List[GenerationRequest]):
181
  """Process a batch of requests on GPU"""
182
  try:
183
+ # Ensure environment variables are still set before processing
184
+ ensure_thread_env()
185
+
186
  logging.info(f"Processing batch of {len(batch)} requests")
187
  start_time = time.time()
188
 
 
223
 
224
  def _load_model(self, version: str):
225
  """Thread-safe model loading"""
226
+ # Ensure environment variables are still set
227
+ ensure_thread_env()
228
+
229
  with self.model_lock:
230
  if self.model is None or self.model.name != version:
231
  if self.model is not None:
 
423
  ],
424
  inputs=[model, text, solver, steps, gr.State(0.0),
425
  gr.State(False), gr.State(0.0), duration, melody],
426
+ outputs=output,
427
+ cache_examples=False # Disable caching to avoid CSV field size errors
428
  )
429
 
430
  return interface
 
445
  format='%(asctime)s - %(levelname)s - %(message)s'
446
  )
447
 
448
+ # Ensure environment variables one more time before starting
449
+ ensure_thread_env()
450
+
451
  # Start batch processor
452
  batch_processor.start()
453