root committed
Commit a3f7aaa · 1 Parent(s): ba71a6b
ss
app.py CHANGED
@@ -28,12 +28,33 @@ if "HF_TOKEN" in os.environ:
 
 # Constants
 GENRE_MODEL_NAME = "dima806/music_genres_classification"
+MUSIC_DETECTION_MODEL = "MIT/ast-finetuned-audioset-10-10-0.4593"
 LLM_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
 SAMPLE_RATE = 22050 # Standard sample rate for audio processing
 
 # Check CUDA availability (for informational purposes)
 CUDA_AVAILABLE = ensure_cuda_availability()
 
+# Create music detection pipeline
+print(f"Loading music detection model: {MUSIC_DETECTION_MODEL}")
+try:
+    music_detector = pipeline(
+        "audio-classification",
+        model=MUSIC_DETECTION_MODEL,
+        device=0 if CUDA_AVAILABLE else -1
+    )
+    print("Successfully loaded music detection pipeline")
+except Exception as e:
+    print(f"Error creating music detection pipeline: {str(e)}")
+    # Fallback to manual loading
+    try:
+        music_processor = AutoFeatureExtractor.from_pretrained(MUSIC_DETECTION_MODEL)
+        music_model = AutoModelForAudioClassification.from_pretrained(MUSIC_DETECTION_MODEL)
+        print("Successfully loaded music detection model and feature extractor")
+    except Exception as e2:
+        print(f"Error loading music detection model components: {str(e2)}")
+        raise RuntimeError(f"Could not load music detection model: {str(e2)}")
+
 # Create genre classification pipeline
 print(f"Loading audio classification model: {GENRE_MODEL_NAME}")
 try:
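Note on the loading pattern above: the code first tries the high-level transformers pipeline and only falls back to loading the feature extractor and model by hand if that fails. A minimal standalone sketch of the pipeline path (assuming transformers and torch are installed; "example.wav" is a placeholder path):

    # Sketch only: the model id comes from this diff; the audio file is a placeholder.
    from transformers import pipeline

    detector = pipeline(
        "audio-classification",
        model="MIT/ast-finetuned-audioset-10-10-0.4593",
        device=-1,  # CPU; pass 0 to use the first CUDA device, as the app does
    )

    # An audio-classification pipeline returns a list of {"label", "score"} dicts,
    # e.g. [{"label": "Music", "score": 0.97}, ...], which detect_music scans below.
    print(detector("example.wav"))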
@@ -209,6 +230,55 @@ Your lyrics:
 
     return lyrics
 
+def detect_music(audio_data):
+    """Detect if the audio is music using the MIT AST model."""
+    try:
+        # First attempt: Try using the pipeline if available
+        if 'music_detector' in globals():
+            results = music_detector(audio_data["path"])
+            # Look for music-related classes in the results
+            music_confidence = 0.0
+            for result in results:
+                label = result["label"].lower()
+                if any(music_term in label for music_term in ["music", "song", "singing", "instrument"]):
+                    music_confidence = max(music_confidence, result["score"])
+            return music_confidence >= 0.5
+
+        # Second attempt: Use manually loaded model components
+        elif 'music_processor' in globals() and 'music_model' in globals():
+            # Process audio input with feature extractor
+            inputs = music_processor(
+                audio_data["waveform"],
+                sampling_rate=audio_data["sample_rate"],
+                return_tensors="pt"
+            )
+
+            with torch.no_grad():
+                outputs = music_model(**inputs)
+                predictions = outputs.logits.softmax(dim=-1)
+
+            # Get the top predictions
+            values, indices = torch.topk(predictions, 5)
+
+            # Map indices to labels
+            labels = music_model.config.id2label
+
+            # Check for music-related classes
+            music_confidence = 0.0
+            for value, index in zip(values[0], indices[0]):
+                label = labels[index.item()].lower()
+                if any(music_term in label for music_term in ["music", "song", "singing", "instrument"]):
+                    music_confidence = max(music_confidence, value.item())
+
+            return music_confidence >= 0.5
+
+        else:
+            raise ValueError("No music detection model available")
+
+    except Exception as e:
+        print(f"Error in music detection: {str(e)}")
+        return False
+
 def process_audio(audio_file):
     """Main function to process audio file, classify genre, and generate lyrics."""
     if audio_file is None:
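Note on detect_music: both branches reduce the classifier output to a single music_confidence score, taking the maximum score over any AudioSet label containing "music", "song", "singing", or "instrument", and gate on a 0.5 threshold. The function expects an audio_data dict with "path", "waveform", and "sample_rate" keys. extract_audio_features is not part of this diff, so the sketch below is only an assumption about that dict's shape (make_audio_data is a hypothetical helper; librosa is assumed available):

    # Hypothetical helper mirroring the keys detect_music reads above.
    import librosa

    def make_audio_data(path, sample_rate=22050):
        # librosa.load resamples to the requested rate and returns a mono float array
        waveform, sr = librosa.load(path, sr=sample_rate)
        return {"path": path, "waveform": waveform, "sample_rate": sr}

    audio_data = make_audio_data("example.wav")
    print(detect_music(audio_data))  # True when a music-related label scores >= 0.5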
@@ -218,6 +288,11 @@ def process_audio(audio_file):
     # Extract audio features
     audio_data = extract_audio_features(audio_file)
 
+    # First check if it's music
+    is_music = detect_music(audio_data)
+    if not is_music:
+        return "The uploaded audio does not appear to be music. Please upload a music file.", None
+
     # Classify genre
     top_genres = classify_genre(audio_data)
 
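Note on the gate in process_audio: detect_music now runs before genre classification, so non-music uploads short-circuit with a user-facing message instead of producing lyrics. The early return suggests process_audio yields a two-element tuple; a sketch under that assumption ("speech_sample.wav" is a placeholder path):

    # Sketch only: assumes process_audio returns (message_or_lyrics, secondary_output).
    message, extra = process_audio("speech_sample.wav")
    if extra is None:
        # The music gate rejected the file before any genre/lyrics work ran.
        print(message)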
|
|
| 28 |
|
| 29 |
# Constants
|
| 30 |
GENRE_MODEL_NAME = "dima806/music_genres_classification"
|
| 31 |
+
MUSIC_DETECTION_MODEL = "MIT/ast-finetuned-audioset-10-10-0.4593"
|
| 32 |
LLM_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
|
| 33 |
SAMPLE_RATE = 22050 # Standard sample rate for audio processing
|
| 34 |
|
| 35 |
# Check CUDA availability (for informational purposes)
|
| 36 |
CUDA_AVAILABLE = ensure_cuda_availability()
|
| 37 |
|
| 38 |
+
# Create music detection pipeline
|
| 39 |
+
print(f"Loading music detection model: {MUSIC_DETECTION_MODEL}")
|
| 40 |
+
try:
|
| 41 |
+
music_detector = pipeline(
|
| 42 |
+
"audio-classification",
|
| 43 |
+
model=MUSIC_DETECTION_MODEL,
|
| 44 |
+
device=0 if CUDA_AVAILABLE else -1
|
| 45 |
+
)
|
| 46 |
+
print("Successfully loaded music detection pipeline")
|
| 47 |
+
except Exception as e:
|
| 48 |
+
print(f"Error creating music detection pipeline: {str(e)}")
|
| 49 |
+
# Fallback to manual loading
|
| 50 |
+
try:
|
| 51 |
+
music_processor = AutoFeatureExtractor.from_pretrained(MUSIC_DETECTION_MODEL)
|
| 52 |
+
music_model = AutoModelForAudioClassification.from_pretrained(MUSIC_DETECTION_MODEL)
|
| 53 |
+
print("Successfully loaded music detection model and feature extractor")
|
| 54 |
+
except Exception as e2:
|
| 55 |
+
print(f"Error loading music detection model components: {str(e2)}")
|
| 56 |
+
raise RuntimeError(f"Could not load music detection model: {str(e2)}")
|
| 57 |
+
|
| 58 |
# Create genre classification pipeline
|
| 59 |
print(f"Loading audio classification model: {GENRE_MODEL_NAME}")
|
| 60 |
try:
|
|
|
|
| 230 |
|
| 231 |
return lyrics
|
| 232 |
|
| 233 |
+
def detect_music(audio_data):
|
| 234 |
+
"""Detect if the audio is music using the MIT AST model."""
|
| 235 |
+
try:
|
| 236 |
+
# First attempt: Try using the pipeline if available
|
| 237 |
+
if 'music_detector' in globals():
|
| 238 |
+
results = music_detector(audio_data["path"])
|
| 239 |
+
# Look for music-related classes in the results
|
| 240 |
+
music_confidence = 0.0
|
| 241 |
+
for result in results:
|
| 242 |
+
label = result["label"].lower()
|
| 243 |
+
if any(music_term in label for music_term in ["music", "song", "singing", "instrument"]):
|
| 244 |
+
music_confidence = max(music_confidence, result["score"])
|
| 245 |
+
return music_confidence >= 0.5
|
| 246 |
+
|
| 247 |
+
# Second attempt: Use manually loaded model components
|
| 248 |
+
elif 'music_processor' in globals() and 'music_model' in globals():
|
| 249 |
+
# Process audio input with feature extractor
|
| 250 |
+
inputs = music_processor(
|
| 251 |
+
audio_data["waveform"],
|
| 252 |
+
sampling_rate=audio_data["sample_rate"],
|
| 253 |
+
return_tensors="pt"
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
with torch.no_grad():
|
| 257 |
+
outputs = music_model(**inputs)
|
| 258 |
+
predictions = outputs.logits.softmax(dim=-1)
|
| 259 |
+
|
| 260 |
+
# Get the top predictions
|
| 261 |
+
values, indices = torch.topk(predictions, 5)
|
| 262 |
+
|
| 263 |
+
# Map indices to labels
|
| 264 |
+
labels = music_model.config.id2label
|
| 265 |
+
|
| 266 |
+
# Check for music-related classes
|
| 267 |
+
music_confidence = 0.0
|
| 268 |
+
for i, (value, index) in enumerate(zip(values[0], indices[0])):
|
| 269 |
+
label = labels[index.item()].lower()
|
| 270 |
+
if any(music_term in label for music_term in ["music", "song", "singing", "instrument"]):
|
| 271 |
+
music_confidence = max(music_confidence, value.item())
|
| 272 |
+
|
| 273 |
+
return music_confidence >= 0.5
|
| 274 |
+
|
| 275 |
+
else:
|
| 276 |
+
raise ValueError("No music detection model available")
|
| 277 |
+
|
| 278 |
+
except Exception as e:
|
| 279 |
+
print(f"Error in music detection: {str(e)}")
|
| 280 |
+
return False
|
| 281 |
+
|
| 282 |
def process_audio(audio_file):
|
| 283 |
"""Main function to process audio file, classify genre, and generate lyrics."""
|
| 284 |
if audio_file is None:
|
|
|
|
| 288 |
# Extract audio features
|
| 289 |
audio_data = extract_audio_features(audio_file)
|
| 290 |
|
| 291 |
+
# First check if it's music
|
| 292 |
+
is_music = detect_music(audio_data)
|
| 293 |
+
if not is_music:
|
| 294 |
+
return "The uploaded audio does not appear to be music. Please upload a music file.", None
|
| 295 |
+
|
| 296 |
# Classify genre
|
| 297 |
top_genres = classify_genre(audio_data)
|
| 298 |
|