Khushi Dahiya
commited on
Commit
·
3705396
1
Parent(s):
1e137e7
debugging melodyflow api
Browse files- demos/melodyflow_api.py +34 -4
demos/melodyflow_api.py
CHANGED
|
@@ -7,9 +7,25 @@ This version focuses on high-throughput API serving with batching
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
import os
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
import spaces
|
| 15 |
import asyncio
|
|
@@ -31,6 +47,10 @@ from audiocraft.data.audio_utils import convert_audio
|
|
| 31 |
from audiocraft.data.audio import audio_read, audio_write
|
| 32 |
from audiocraft.models import MelodyFlow
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
# Configuration
|
| 36 |
MODEL_PREFIX = "facebook/"
|
|
@@ -160,6 +180,9 @@ class OptimizedBatchProcessor:
|
|
| 160 |
def _process_batch(self, batch: tp.List[GenerationRequest]):
|
| 161 |
"""Process a batch of requests on GPU"""
|
| 162 |
try:
|
|
|
|
|
|
|
|
|
|
| 163 |
logging.info(f"Processing batch of {len(batch)} requests")
|
| 164 |
start_time = time.time()
|
| 165 |
|
|
@@ -200,6 +223,9 @@ class OptimizedBatchProcessor:
|
|
| 200 |
|
| 201 |
def _load_model(self, version: str):
|
| 202 |
"""Thread-safe model loading"""
|
|
|
|
|
|
|
|
|
|
| 203 |
with self.model_lock:
|
| 204 |
if self.model is None or self.model.name != version:
|
| 205 |
if self.model is not None:
|
|
@@ -397,7 +423,8 @@ def create_optimized_interface():
|
|
| 397 |
],
|
| 398 |
inputs=[model, text, solver, steps, gr.State(0.0),
|
| 399 |
gr.State(False), gr.State(0.0), duration, melody],
|
| 400 |
-
outputs=output
|
|
|
|
| 401 |
)
|
| 402 |
|
| 403 |
return interface
|
|
@@ -418,6 +445,9 @@ if __name__ == "__main__":
|
|
| 418 |
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 419 |
)
|
| 420 |
|
|
|
|
|
|
|
|
|
|
| 421 |
# Start batch processor
|
| 422 |
batch_processor.start()
|
| 423 |
|
|
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
import os
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
# Fix OpenMP threading issues - ensure they're set early and correctly
|
| 13 |
+
os.environ['OMP_NUM_THREADS'] = '1'
|
| 14 |
+
os.environ['MKL_NUM_THREADS'] = '1'
|
| 15 |
+
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
| 16 |
+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
| 17 |
+
|
| 18 |
+
# Additional protection against environment variable corruption
|
| 19 |
+
def ensure_thread_env():
|
| 20 |
+
"""Ensure threading environment variables stay set"""
|
| 21 |
+
for key, value in [('OMP_NUM_THREADS', '1'), ('MKL_NUM_THREADS', '1'),
|
| 22 |
+
('NUMEXPR_NUM_THREADS', '1'), ('OPENBLAS_NUM_THREADS', '1')]:
|
| 23 |
+
if os.environ.get(key) != value:
|
| 24 |
+
os.environ[key] = value
|
| 25 |
+
print(f"Reset {key} to {value}")
|
| 26 |
+
|
| 27 |
+
# Call it immediately
|
| 28 |
+
ensure_thread_env()
|
| 29 |
|
| 30 |
import spaces
|
| 31 |
import asyncio
|
|
|
|
| 47 |
from audiocraft.data.audio import audio_read, audio_write
|
| 48 |
from audiocraft.models import MelodyFlow
|
| 49 |
|
| 50 |
+
# Fix CSV field size limit for large audio data
|
| 51 |
+
import csv
|
| 52 |
+
csv.field_size_limit(1000000) # Increase field size limit
|
| 53 |
+
|
| 54 |
|
| 55 |
# Configuration
|
| 56 |
MODEL_PREFIX = "facebook/"
|
|
|
|
| 180 |
def _process_batch(self, batch: tp.List[GenerationRequest]):
|
| 181 |
"""Process a batch of requests on GPU"""
|
| 182 |
try:
|
| 183 |
+
# Ensure environment variables are still set before processing
|
| 184 |
+
ensure_thread_env()
|
| 185 |
+
|
| 186 |
logging.info(f"Processing batch of {len(batch)} requests")
|
| 187 |
start_time = time.time()
|
| 188 |
|
|
|
|
| 223 |
|
| 224 |
def _load_model(self, version: str):
|
| 225 |
"""Thread-safe model loading"""
|
| 226 |
+
# Ensure environment variables are still set
|
| 227 |
+
ensure_thread_env()
|
| 228 |
+
|
| 229 |
with self.model_lock:
|
| 230 |
if self.model is None or self.model.name != version:
|
| 231 |
if self.model is not None:
|
|
|
|
| 423 |
],
|
| 424 |
inputs=[model, text, solver, steps, gr.State(0.0),
|
| 425 |
gr.State(False), gr.State(0.0), duration, melody],
|
| 426 |
+
outputs=output,
|
| 427 |
+
cache_examples=False # Disable caching to avoid CSV field size errors
|
| 428 |
)
|
| 429 |
|
| 430 |
return interface
|
|
|
|
| 445 |
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 446 |
)
|
| 447 |
|
| 448 |
+
# Ensure environment variables one more time before starting
|
| 449 |
+
ensure_thread_env()
|
| 450 |
+
|
| 451 |
# Start batch processor
|
| 452 |
batch_processor.start()
|
| 453 |
|