root committed on
Commit
bb9a8b1
·
1 Parent(s): 7dfa01d
Files changed (1) hide show
  1. app.py +65 -43
app.py CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
4
  import torch
5
  import numpy as np
6
  from transformers import (
7
- AutoModelForSequenceClassification,
8
  AutoFeatureExtractor,
9
  AutoTokenizer,
10
  pipeline,
@@ -18,7 +18,8 @@ from utils import (
18
  extract_mfcc_features,
19
  calculate_lyrics_length,
20
  format_genre_results,
21
- ensure_cuda_availability
 
22
  )
23
 
24
  # Login to Hugging Face Hub if token is provided
@@ -33,17 +34,25 @@ SAMPLE_RATE = 22050 # Standard sample rate for audio processing
33
  # Check CUDA availability (for informational purposes)
34
  CUDA_AVAILABLE = ensure_cuda_availability()
35
 
36
- # Load genre classification model
 
37
  try:
38
- # Try to load feature extractor first (for audio models)
39
- genre_processor = AutoFeatureExtractor.from_pretrained(GENRE_MODEL_NAME)
40
- print(f"Loaded feature extractor for genre classification model: {GENRE_MODEL_NAME}")
 
 
 
41
  except Exception as e:
42
- print(f"Error loading feature extractor, using basic processing: {str(e)}")
43
- genre_processor = None
44
-
45
- # Load the model
46
- genre_model = AutoModelForSequenceClassification.from_pretrained(GENRE_MODEL_NAME)
 
 
 
 
47
 
48
  # Load LLM with appropriate quantization for T4 GPU
49
  bnb_config = BitsAndBytesConfig(
@@ -76,48 +85,61 @@ def extract_audio_features(audio_file):
76
  # Get audio duration in seconds
77
  duration = extract_audio_duration(y, sr)
78
 
79
- # Extract MFCCs for genre classification
80
  mfccs_mean = extract_mfcc_features(y, sr, n_mfcc=20)
81
 
82
  return {
83
  "features": mfccs_mean,
84
  "duration": duration,
85
  "waveform": y,
86
- "sample_rate": sr
 
87
  }
88
 
89
  def classify_genre(audio_data):
90
  """Classify the genre of the audio using the loaded model."""
91
- if genre_processor is not None:
92
- # Use the feature extractor if available
93
- inputs = genre_processor(
94
- audio_data["waveform"],
95
- sampling_rate=audio_data["sample_rate"],
96
- return_tensors="pt"
97
- )
98
- else:
99
- # Fallback to basic feature processing
100
- # Convert MFCC features to tensor and reshape appropriately
101
- features_tensor = torch.tensor(audio_data["features"]).unsqueeze(0)
102
- inputs = {"input_features": features_tensor}
103
-
104
- with torch.no_grad():
105
- outputs = genre_model(**inputs)
106
- predictions = outputs.logits.softmax(dim=-1)
107
-
108
- # Get the top 3 genres
109
- values, indices = torch.topk(predictions, 3)
110
-
111
- # Map indices to genre labels
112
- genre_labels = genre_model.config.id2label
113
-
114
- top_genres = []
115
- for i, (value, index) in enumerate(zip(values[0], indices[0])):
116
- genre = genre_labels[index.item()]
117
- confidence = value.item()
118
- top_genres.append((genre, confidence))
119
-
120
- return top_genres
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  def generate_lyrics(genre, duration):
123
  """Generate lyrics based on the genre and with appropriate length."""
 
4
  import torch
5
  import numpy as np
6
  from transformers import (
7
+ AutoModelForAudioClassification,
8
  AutoFeatureExtractor,
9
  AutoTokenizer,
10
  pipeline,
 
18
  extract_mfcc_features,
19
  calculate_lyrics_length,
20
  format_genre_results,
21
+ ensure_cuda_availability,
22
+ preprocess_audio_for_model
23
  )
24
 
25
  # Login to Hugging Face Hub if token is provided
 
34
  # Check CUDA availability (for informational purposes)
35
  CUDA_AVAILABLE = ensure_cuda_availability()
36
 
37
+ # Create genre classification pipeline
38
+ print(f"Loading audio classification model: {GENRE_MODEL_NAME}")
39
  try:
40
+ genre_classifier = pipeline(
41
+ "audio-classification",
42
+ model=GENRE_MODEL_NAME,
43
+ device=0 if CUDA_AVAILABLE else -1
44
+ )
45
+ print("Successfully loaded audio classification pipeline")
46
  except Exception as e:
47
+ print(f"Error creating pipeline: {str(e)}")
48
+ # Fallback to manual loading
49
+ try:
50
+ genre_processor = AutoFeatureExtractor.from_pretrained(GENRE_MODEL_NAME)
51
+ genre_model = AutoModelForAudioClassification.from_pretrained(GENRE_MODEL_NAME)
52
+ print("Successfully loaded audio classification model and feature extractor")
53
+ except Exception as e2:
54
+ print(f"Error loading model components: {str(e2)}")
55
+ raise RuntimeError(f"Could not load genre classification model: {str(e2)}")
56
 
57
  # Load LLM with appropriate quantization for T4 GPU
58
  bnb_config = BitsAndBytesConfig(
 
85
  # Get audio duration in seconds
86
  duration = extract_audio_duration(y, sr)
87
 
88
+ # Extract MFCCs for genre classification (may not be needed with the pipeline)
89
  mfccs_mean = extract_mfcc_features(y, sr, n_mfcc=20)
90
 
91
  return {
92
  "features": mfccs_mean,
93
  "duration": duration,
94
  "waveform": y,
95
+ "sample_rate": sr,
96
+ "path": audio_file # Keep path for the pipeline
97
  }
98
 
99
  def classify_genre(audio_data):
100
  """Classify the genre of the audio using the loaded model."""
101
+ try:
102
+ # First attempt: Try using the pipeline if available
103
+ if 'genre_classifier' in globals():
104
+ results = genre_classifier(audio_data["path"])
105
+ # Transform pipeline results to our expected format
106
+ top_genres = [(result["label"], result["score"]) for result in results[:3]]
107
+ return top_genres
108
+
109
+ # Second attempt: Use manually loaded model components
110
+ elif 'genre_processor' in globals() and 'genre_model' in globals():
111
+ # Process audio input with feature extractor
112
+ inputs = genre_processor(
113
+ audio_data["waveform"],
114
+ sampling_rate=audio_data["sample_rate"],
115
+ return_tensors="pt"
116
+ )
117
+
118
+ with torch.no_grad():
119
+ outputs = genre_model(**inputs)
120
+ predictions = outputs.logits.softmax(dim=-1)
121
+
122
+ # Get the top 3 genres
123
+ values, indices = torch.topk(predictions, 3)
124
+
125
+ # Map indices to genre labels
126
+ genre_labels = genre_model.config.id2label
127
+
128
+ top_genres = []
129
+ for i, (value, index) in enumerate(zip(values[0], indices[0])):
130
+ genre = genre_labels[index.item()]
131
+ confidence = value.item()
132
+ top_genres.append((genre, confidence))
133
+
134
+ return top_genres
135
+
136
+ else:
137
+ raise ValueError("No genre classification model available")
138
+
139
+ except Exception as e:
140
+ print(f"Error in genre classification: {str(e)}")
141
+ # Fallback: return a default genre if everything fails
142
+ return [("rock", 1.0)]
143
 
144
  def generate_lyrics(genre, duration):
145
  """Generate lyrics based on the genre and with appropriate length."""