Khushi Dahiya committed
Commit 1e137e7 · 1 Parent(s): 7edd7b4

trying out melodyflow api implementation

Files changed (2)
  1. README.md +1 -1
  2. demos/melodyflow_api.py +439 -0
README.md CHANGED
@@ -5,7 +5,7 @@ tags:
  - music generation
  - music editing
  - flow matching
- app_file: demos/melodyflow_app.py
+ app_file: demos/melodyflow_api.py
  emoji: 🎵
  colorFrom: gray
  colorTo: blue
demos/melodyflow_api.py ADDED
@@ -0,0 +1,439 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ """
+ Optimized MelodyFlow API for concurrent request handling on a T4 GPU.
+ This version focuses on high-throughput API serving with batching.
+ """
+
+ import os
+ # Fix OpenMP threading issues
+ os.environ.setdefault('OMP_NUM_THREADS', '1')
+ os.environ.setdefault('MKL_NUM_THREADS', '1')
+
+ import spaces
+ import asyncio
+ import threading
+ import time
+ import uuid
+ import base64
+ import logging
+ # TimeoutError here is concurrent.futures.TimeoutError, raised by Future.result()
+ from concurrent.futures import ThreadPoolExecutor, Future, TimeoutError
+ from queue import Queue, Empty, Full
+ from tempfile import NamedTemporaryFile
+ from pathlib import Path
+ import typing as tp
+ from dataclasses import dataclass
+
+ import torch
+ import gradio as gr
+ from audiocraft.data.audio_utils import convert_audio
+ from audiocraft.data.audio import audio_read, audio_write
+ from audiocraft.models import MelodyFlow
+
+
+ # Configuration
+ MODEL_PREFIX = "facebook/"
+ BATCH_SIZE = 4  # Optimal for T4 GPU memory
+ BATCH_TIMEOUT = 1.5  # Seconds to wait for batch formation
+ MAX_QUEUE_SIZE = 100
+ MAX_CONCURRENT_BATCHES = 2  # Number of concurrent batch processors
+
+
+ @dataclass
+ class GenerationRequest:
+     """Represents a single generation request"""
+     request_id: str
+     text: str
+     melody: tp.Optional[str]
+     solver: str
+     steps: int
+     target_flowstep: float
+     regularize: bool
+     regularization_strength: float
+     duration: float
+     model: str
+     future: Future
+     created_at: float
+
+
+ class OptimizedBatchProcessor:
+     """Highly optimized batch processor for T4 GPU"""
+
+     def __init__(self):
+         self.model = None
+         self.model_lock = threading.Lock()
+         self.request_queue = Queue(maxsize=MAX_QUEUE_SIZE)
+         self.current_batch = []
+         self.batch_start_time = None
+         self.processing = False
+         self.stop_event = threading.Event()
+         self.executor = ThreadPoolExecutor(max_workers=MAX_CONCURRENT_BATCHES)
+
+     def start(self):
+         """Start the batch processing service"""
+         self.thread = threading.Thread(target=self._batch_loop, daemon=True)
+         self.thread.start()
+         logging.info("Batch processor started")
+
+     def stop(self):
+         """Stop the batch processing service"""
+         self.stop_event.set()
+         self.executor.shutdown(wait=True)
+
+     def submit_request(self, text: str, melody: tp.Optional[str],
+                        solver: str, steps: int, target_flowstep: float,
+                        regularize: bool, regularization_strength: float,
+                        duration: float, model: str) -> Future:
+         """Submit a generation request and return a future"""
+
+         request = GenerationRequest(
+             request_id=str(uuid.uuid4()),
+             text=text,
+             melody=melody,
+             solver=solver,
+             steps=steps,
+             target_flowstep=target_flowstep,
+             regularize=regularize,
+             regularization_strength=regularization_strength,
+             duration=duration,
+             model=model,
+             future=Future(),
+             created_at=time.time()
+         )
+
+         try:
+             self.request_queue.put_nowait(request)
+             return request.future
+         except Full:
+             # Queue is full
+             request.future.set_exception(Exception("Server is busy, please try again"))
+             return request.future
+
+     def _batch_loop(self):
+         """Main batch processing loop"""
+         while not self.stop_event.is_set():
+             try:
+                 # Try to get a request
+                 try:
+                     request = self.request_queue.get(timeout=0.1)
+                     self.current_batch.append(request)
+
+                     if self.batch_start_time is None:
+                         self.batch_start_time = time.time()
+
+                 except Empty:
+                     # No new requests, check if we should process current batch
+                     if self._should_process_batch():
+                         self._submit_batch()
+                     continue
+
+                 # Check if we should process the batch
+                 if self._should_process_batch():
+                     self._submit_batch()
+
+             except Exception as e:
+                 logging.error(f"Error in batch loop: {e}")
+
+     def _should_process_batch(self) -> bool:
+         """Determine if current batch should be processed"""
+         if not self.current_batch:
+             return False
+
+         batch_age = time.time() - (self.batch_start_time or time.time())
+         return (len(self.current_batch) >= BATCH_SIZE or
+                 batch_age >= BATCH_TIMEOUT)
+
+     def _submit_batch(self):
+         """Submit current batch for processing"""
+         if not self.current_batch:
+             return
+
+         batch = self.current_batch.copy()
+         self.current_batch = []
+         self.batch_start_time = None
+
+         # Submit to thread pool
+         self.executor.submit(self._process_batch, batch)
+
+     @spaces.GPU(duration=60)  # Longer duration for batch processing
+     def _process_batch(self, batch: tp.List[GenerationRequest]):
+         """Process a batch of requests on GPU"""
+         try:
+             logging.info(f"Processing batch of {len(batch)} requests")
+             start_time = time.time()
+
+             # Load model (assume all requests use same model for simplicity)
+             model_version = batch[0].model
+             self._load_model(model_version)
+
+             # Separate generation vs editing requests
+             gen_requests = [req for req in batch if req.melody is None]
+             edit_requests = [req for req in batch if req.melody is not None]
+
+             results = {}
+
+             # Process generation requests in batch
+             if gen_requests:
+                 gen_results = self._process_generation_batch(gen_requests)
+                 results.update(gen_results)
+
+             # Process editing requests individually (due to melody constraints)
+             if edit_requests:
+                 edit_results = self._process_editing_batch(edit_requests)
+                 results.update(edit_results)
+
+             # Set results for all requests
+             for request in batch:
+                 if request.request_id in results:
+                     request.future.set_result(results[request.request_id])
+                 else:
+                     request.future.set_exception(Exception("Processing failed"))
+
+             processing_time = time.time() - start_time
+             logging.info(f"Batch processed in {processing_time:.2f}s")
+
+         except Exception as e:
+             logging.error(f"Batch processing error: {e}")
+             for request in batch:
+                 request.future.set_exception(e)
+
+     def _load_model(self, version: str):
+         """Thread-safe model loading"""
+         with self.model_lock:
+             if self.model is None or self.model.name != version:
+                 if self.model is not None:
+                     del self.model
+                     if torch.cuda.is_available():
+                         torch.cuda.empty_cache()
+                 self.model = MelodyFlow.get_pretrained(version)
+                 logging.info(f"Model {version} loaded")
+
+     def _process_generation_batch(self, requests: tp.List[GenerationRequest]) -> dict:
+         """Process generation requests in batch"""
+         if not requests:
+             return {}
+
+         # Use parameters from first request (assuming similar params for batch)
+         params = requests[0]
+         self.model.set_generation_params(
+             solver=params.solver,
+             steps=params.steps,
+             duration=params.duration
+         )
+
+         # Extract texts
+         texts = [req.text for req in requests]
+
+         # Generate
+         outputs = self.model.generate(texts, progress=False, return_tokens=False)
+         outputs = outputs.detach().cpu().float()
+
+         # Create results
+         results = {}
+         for i, request in enumerate(requests):
+             audio_base64 = self._audio_to_base64(outputs[i])
+             results[request.request_id] = {
+                 "audio": audio_base64,
+                 "format": "wav"
+             }
+
+         return results
+
+     def _process_editing_batch(self, requests: tp.List[GenerationRequest]) -> dict:
+         """Process editing requests individually"""
+         results = {}
+
+         for request in requests:
+             try:
+                 self.model.set_editing_params(
+                     solver=request.solver,
+                     steps=request.steps,
+                     target_flowstep=request.target_flowstep,
+                     regularize=request.regularize,
+                     lambda_kl=request.regularization_strength
+                 )
+
+                 # Process melody
+                 melody, sr = audio_read(request.melody)
+                 if melody.dim() == 2:
+                     melody = melody[None]
+                 if melody.shape[-1] > int(sr * self.model.duration):
+                     melody = melody[..., :int(sr * self.model.duration)]
+
+                 melody = convert_audio(melody, sr, 48000, 2)
+                 melody = self.model.encode_audio(melody.to(self.model.device))
+
+                 # Edit
+                 output = self.model.edit(
+                     prompt_tokens=melody,
+                     descriptions=[request.text],
+                     src_descriptions=[""],
+                     progress=False,
+                     return_tokens=False
+                 )
+
+                 output = output.detach().cpu().float()[0]
+                 audio_base64 = self._audio_to_base64(output)
+
+                 results[request.request_id] = {
+                     "audio": audio_base64,
+                     "format": "wav"
+                 }
+
+             except Exception as e:
+                 logging.error(f"Error processing edit request {request.request_id}: {e}")
+                 # Will be handled by batch processor
+
+         return results
+
+     def _audio_to_base64(self, audio_tensor: torch.Tensor) -> str:
+         """Convert audio tensor to base64 string"""
+         with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+             audio_write(
+                 file.name, audio_tensor, self.model.sample_rate,
+                 strategy="loudness", loudness_headroom_db=16,
+                 loudness_compressor=True, add_suffix=False
+             )
+
+         with open(file.name, 'rb') as f:
+             audio_bytes = f.read()
+
+         # Clean up temp file
+         Path(file.name).unlink()
+
+         return base64.b64encode(audio_bytes).decode('utf-8')
+
+
+ # Global batch processor
+ batch_processor = OptimizedBatchProcessor()
+
+
+ def predict_concurrent(model: str, text: str, solver: str = "euler",
+                        steps: int = 50, target_flowstep: float = 0.0,
+                        regularize: bool = False, regularization_strength: float = 0.0,
+                        duration: float = 10.0, melody: tp.Optional[str] = None) -> dict:
+     """
+     Queue a generation/editing request for batched processing and wait for its result.
+     Optimized for many concurrent callers sharing one GPU.
+     """
+
+     # Adjust steps for melody editing
+     if melody is not None:
+         steps = steps // 2 if solver == "midpoint" else steps // 5
+
+     # Submit request to batch processor
+     future = batch_processor.submit_request(
+         text=text,
+         melody=melody,
+         solver=solver,
+         steps=steps,
+         target_flowstep=target_flowstep,
+         regularize=regularize,
+         regularization_strength=regularization_strength,
+         duration=duration,
+         model=model
+     )
+
+     # Wait for result with timeout
+     try:
+         result = future.result(timeout=120)  # 2 minute timeout
+         return result
+     except TimeoutError:  # concurrent.futures.TimeoutError, imported above
+         raise gr.Error("Request timeout - server is overloaded")
+     except Exception as e:
+         raise gr.Error(f"Generation failed: {str(e)}")
+
+
+ def create_optimized_interface():
+     """Create Gradio interface optimized for concurrent usage"""
+
+     with gr.Blocks(title="MelodyFlow - Concurrent API") as interface:
+         gr.Markdown("""
+         # MelodyFlow - Optimized for Concurrent Requests
+
+         This version is optimized for handling multiple concurrent requests efficiently.
+         Requests are automatically batched for optimal GPU utilization.
+         """)
+
+         with gr.Row():
+             with gr.Column():
+                 text = gr.Text(label="Text Description", placeholder="Describe the music you want to generate...")
+                 melody = gr.Audio(label="Reference Audio (optional)", type="filepath")
+
+                 with gr.Row():
+                     solver = gr.Radio(["euler", "midpoint"], label="Solver", value="euler")
+                     steps = gr.Slider(1, 128, value=50, label="Steps")
+
+                 with gr.Row():
+                     duration = gr.Slider(1, 30, value=10, label="Duration (s)")
+                     model = gr.Dropdown(
+                         [f"{MODEL_PREFIX}melodyflow-t24-30secs"],
+                         value=f"{MODEL_PREFIX}melodyflow-t24-30secs",
+                         label="Model"
+                     )
+
+                 generate_btn = gr.Button("Generate", variant="primary")
+
+             with gr.Column():
+                 output = gr.JSON(label="Generated Audio")
+
+         generate_btn.click(
+             fn=predict_concurrent,
+             inputs=[model, text, solver, steps, gr.State(0.0),
+                     gr.State(False), gr.State(0.0), duration, melody],
+             outputs=output,
+             concurrency_limit=20  # Set concurrency limit on the event listener
+         )
+
+         gr.Examples(
+             fn=predict_concurrent,
+             examples=[
+                 [f"{MODEL_PREFIX}melodyflow-t24-30secs",
+                  "80s electronic track with melodic synthesizers",
+                  "euler", 50, 0.0, False, 0.0, 10.0, None],
+                 [f"{MODEL_PREFIX}melodyflow-t24-30secs",
+                  "Cheerful country song with acoustic guitars",
+                  "euler", 50, 0.0, False, 0.0, 15.0, None]
+             ],
+             inputs=[model, text, solver, steps, gr.State(0.0),
+                     gr.State(False), gr.State(0.0), duration, melody],
+             outputs=output
+         )
+
+     return interface
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
+     parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
+     parser.add_argument("--share", action="store_true", help="Create public link")
+     args = parser.parse_args()
+
+     # Setup logging
+     logging.basicConfig(
+         level=logging.INFO,
+         format='%(asctime)s - %(levelname)s - %(message)s'
+     )
+
+     # Start batch processor
+     batch_processor.start()
+
+     # Create and launch interface
+     interface = create_optimized_interface()
+
+     try:
+         interface.queue(
+             max_size=200,  # Large queue
+             api_open=True
+         ).launch(
+             server_name=args.host,
+             server_port=args.port,
+             share=args.share,
+             show_api=True,
+             max_threads=40  # Configure worker threads in launch()
+         )
+     finally:
+         batch_processor.stop()
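
For reference, a minimal client-side sketch (not part of this commit) of how the new module could be exercised: it fires several prompts in parallel so the processor can group them into one GPU batch, then decodes the base64 WAV payload returned by predict_concurrent. It assumes the module is importable as demos.melodyflow_api, that the audiocraft/torch/gradio/spaces dependencies and a GPU are available, and that spaces.GPU is inert outside Hugging Face Spaces; the save_wav helper and the prompt list are illustrative only.

# Illustrative client sketch; assumptions noted above.
import base64
from concurrent.futures import ThreadPoolExecutor

from demos.melodyflow_api import predict_concurrent, batch_processor


def save_wav(payload: dict, path: str) -> None:
    # predict_concurrent returns {"audio": <base64-encoded WAV>, "format": "wav"}
    with open(path, "wb") as f:
        f.write(base64.b64decode(payload["audio"]))


if __name__ == "__main__":
    batch_processor.start()  # the batch loop must be running before requests are submitted
    prompts = [
        "80s electronic track with melodic synthesizers",
        "Cheerful country song with acoustic guitars",
        "Lo-fi hip hop beat with vinyl crackle",  # hypothetical prompt
    ]
    try:
        # Submitting the calls concurrently lets the processor batch them together
        with ThreadPoolExecutor(max_workers=len(prompts)) as pool:
            futures = [
                pool.submit(predict_concurrent,
                            "facebook/melodyflow-t24-30secs",  # model
                            prompt,                            # text
                            duration=10.0)
                for prompt in prompts
            ]
            for i, fut in enumerate(futures):
                save_wav(fut.result(), f"sample_{i}.wav")
    finally:
        batch_processor.stop()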