Commit 5277669 · Parent: ead33a6

fix: clip audio to max 2 mins

Files changed:
- main.py (+67 −19)
- models/age_and_gender_model.py (+3 −7)
- models/nationality_model.py (+0 −5)
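
In short: uploads longer than two minutes are now analyzed on their first 120 seconds only, every prediction endpoint gains an "audio_info" block in its JSON response, and a "warning" field is added when clipping occurred. An illustrative /predict_age_and_gender response, assembled from the field names in the diff below (the prediction and timing values are invented):

{
    "success": True,
    "predictions": {},  # actual model output omitted here
    "processing_time": 1.234,
    "audio_info": {
        "was_clipped": True,
        "max_duration_seconds": 120
    },
    "warning": "Audio was longer than 120 seconds and was automatically clipped to the first 120 seconds for analysis."
}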
main.py
CHANGED

@@ -20,6 +20,7 @@ logger = logging.getLogger(__name__)
 UPLOAD_FOLDER = 'uploads'
 ALLOWED_EXTENSIONS = {'wav', 'mp3', 'flac', 'm4a'}
 SAMPLING_RATE = 16000
+MAX_DURATION_SECONDS = 120  # 2 minutes maximum
 
 os.makedirs(UPLOAD_FOLDER, exist_ok=True)
 
@@ -31,6 +32,23 @@ def allowed_file(filename: str) -> bool:
     return '.' in filename and \
            filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
 
+def clip_audio_to_max_duration(audio_data: np.ndarray, sr: int, max_duration: int = MAX_DURATION_SECONDS) -> tuple[np.ndarray, bool]:
+    current_duration = len(audio_data) / sr
+
+    if current_duration <= max_duration:
+        logger.info(f"Audio duration ({current_duration:.2f}s) is within limit ({max_duration}s) - no clipping needed")
+        return audio_data, False
+
+    # Calculate how many samples we need for the max duration
+    max_samples = int(max_duration * sr)
+
+    # Clip to first max_duration seconds
+    clipped_audio = audio_data[:max_samples]
+
+    logger.info(f"Audio clipped from {current_duration:.2f}s to {max_duration}s ({len(audio_data)} samples → {len(clipped_audio)} samples)")
+
+    return clipped_audio, True
+
 async def load_models() -> bool:
     global age_gender_model, nationality_model
 
@@ -96,10 +114,11 @@ app = FastAPI(
     lifespan=lifespan
 )
 
-def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
+def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int, bool]:
     preprocess_start = time.time()
     original_shape = audio_data.shape
-    …
+    original_duration = len(audio_data) / sr
+    logger.info(f"Starting audio preprocessing. Sample rate: {sr}Hz, Duration: {original_duration:.2f}s")
 
     # Convert to mono if stereo
     if len(audio_data.shape) > 1:
@@ -115,20 +134,24 @@ def preprocess_audio(audio_data: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
         audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=SAMPLING_RATE)
         resample_end = time.time()
         logger.info(f"Resampling completed in {resample_end - resample_start:.3f} seconds")
+        sr = SAMPLING_RATE
     else:
         logger.info(f"No resampling needed - already at {SAMPLING_RATE}Hz")
 
+    # Clip audio to maximum duration if needed
+    audio_data, was_clipped = clip_audio_to_max_duration(audio_data, sr)
+
     # Convert to float32
     audio_data = audio_data.astype(np.float32)
 
     preprocess_end = time.time()
-    …
+    final_duration_seconds = len(audio_data) / sr
     logger.info(f"Audio preprocessing completed in {preprocess_end - preprocess_start:.3f} seconds")
-    logger.info(f"Final audio: {audio_data.shape} samples, {…
+    logger.info(f"Final audio: {audio_data.shape} samples, {final_duration_seconds:.2f} seconds duration")
 
-    return audio_data, …
+    return audio_data, sr, was_clipped
 
-async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int]:
+async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int, bool]:
     process_start = time.time()
     logger.info(f"Processing uploaded file: {file.filename}")
 
@@ -165,12 +188,12 @@ async def process_audio_file(file: UploadFile) -> tuple[np.ndarray, int]:
         load_end = time.time()
         logger.info(f"Audio loaded in {load_end - load_start:.3f} seconds")
 
-        processed_audio, processed_sr = preprocess_audio(audio_data, sr)
+        processed_audio, processed_sr, was_clipped = preprocess_audio(audio_data, sr)
 
         process_end = time.time()
         logger.info(f"Total file processing completed in {process_end - process_start:.3f} seconds")
 
-        return processed_audio, processed_sr
+        return processed_audio, processed_sr, was_clipped
 
     except Exception as e:
         logger.error(f"Error processing audio file {file.filename}: {str(e)}")
@@ -186,6 +209,7 @@ async def root() -> Dict[str, Any]:
     logger.info("Root endpoint accessed")
     return {
         "message": "Audio Analysis API - Age, Gender & Nationality Prediction",
+        "max_audio_duration": f"{MAX_DURATION_SECONDS} seconds (files longer than this will be automatically clipped)",
        "models_loaded": {
            "age_gender": age_gender_model is not None and hasattr(age_gender_model, 'model') and age_gender_model.model is not None,
            "nationality": nationality_model is not None and hasattr(nationality_model, 'model') and nationality_model.model is not None
@@ -206,7 +230,6 @@ async def health_check() -> Dict[str, str]:
 
 @app.post("/predict_age_and_gender")
 async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]:
-    """Predict age and gender from uploaded audio file."""
     endpoint_start = time.time()
     logger.info(f"Age & Gender prediction requested for file: {file.filename}")
 
@@ -215,7 +238,7 @@ async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]
         raise HTTPException(status_code=500, detail="Age & gender model not loaded")
 
     try:
-        processed_audio, processed_sr = await process_audio_file(file)
+        processed_audio, processed_sr, was_clipped = await process_audio_file(file)
 
         # Make prediction
         prediction_start = time.time()
@@ -230,12 +253,21 @@ async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]
         endpoint_end = time.time()
         logger.info(f"Total age & gender endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds")
 
-        return {
+        response = {
             "success": True,
             "predictions": predictions,
-            "processing_time": round(endpoint_end - endpoint_start, 3)
+            "processing_time": round(endpoint_end - endpoint_start, 3),
+            "audio_info": {
+                "was_clipped": was_clipped,
+                "max_duration_seconds": MAX_DURATION_SECONDS
+            }
         }
 
+        if was_clipped:
+            response["warning"] = f"Audio was longer than {MAX_DURATION_SECONDS} seconds and was automatically clipped to the first {MAX_DURATION_SECONDS} seconds for analysis."
+
+        return response
+
     except HTTPException:
         raise
     except Exception as e:
@@ -244,7 +276,6 @@ async def predict_age_and_gender(file: UploadFile = File(...)) -> Dict[str, Any]
 
 @app.post("/predict_nationality")
 async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
-    """Predict nationality/language from uploaded audio file."""
     endpoint_start = time.time()
     logger.info(f"Nationality prediction requested for file: {file.filename}")
 
@@ -253,7 +284,7 @@ async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
         raise HTTPException(status_code=500, detail="Nationality model not loaded")
 
     try:
-        processed_audio, processed_sr = await process_audio_file(file)
+        processed_audio, processed_sr, was_clipped = await process_audio_file(file)
 
         # Make prediction
         prediction_start = time.time()
@@ -268,12 +299,21 @@ async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
         endpoint_end = time.time()
         logger.info(f"Total nationality endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds")
 
-        return {
+        response = {
             "success": True,
             "predictions": predictions,
-            "processing_time": round(endpoint_end - endpoint_start, 3)
+            "processing_time": round(endpoint_end - endpoint_start, 3),
+            "audio_info": {
+                "was_clipped": was_clipped,
+                "max_duration_seconds": MAX_DURATION_SECONDS
+            }
         }
 
+        if was_clipped:
+            response["warning"] = f"Audio was longer than {MAX_DURATION_SECONDS} seconds and was automatically clipped to the first {MAX_DURATION_SECONDS} seconds for analysis."
+
+        return response
+
     except HTTPException:
         raise
     except Exception as e:
@@ -282,7 +322,6 @@ async def predict_nationality(file: UploadFile = File(...)) -> Dict[str, Any]:
 
 @app.post("/predict_all")
 async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
-    """Predict age, gender, and nationality from uploaded audio file."""
     endpoint_start = time.time()
     logger.info(f"Complete analysis requested for file: {file.filename}")
 
@@ -295,7 +334,7 @@ async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
         raise HTTPException(status_code=500, detail="Nationality model not loaded")
 
     try:
-        processed_audio, processed_sr = await process_audio_file(file)
+        processed_audio, processed_sr, was_clipped = await process_audio_file(file)
 
         # Get age & gender predictions
         age_prediction_start = time.time()
@@ -323,7 +362,7 @@ async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]
         logger.info(f"Total prediction time: {total_prediction_time:.3f} seconds")
         logger.info(f"Total complete analysis endpoint processing time: {endpoint_end - endpoint_start:.3f} seconds")
 
-        return {
+        response = {
             "success": True,
             "predictions": {
                 "demographics": age_gender_predictions,
@@ -333,9 +372,18 @@ async def predict_all(file: UploadFile = File(...)) -> Dict[str, Any]:
                 "total": round(endpoint_end - endpoint_start, 3),
                 "age_gender": round(age_prediction_end - age_prediction_start, 3),
                 "nationality": round(nationality_prediction_end - nationality_prediction_start, 3)
+            },
+            "audio_info": {
+                "was_clipped": was_clipped,
+                "max_duration_seconds": MAX_DURATION_SECONDS
+            }
         }
 
+        if was_clipped:
+            response["warning"] = f"Audio was longer than {MAX_DURATION_SECONDS} seconds and was automatically clipped to the first {MAX_DURATION_SECONDS} seconds for analysis."
+
+        return response
+
     except HTTPException:
         raise
     except Exception as e:
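Taken together, the preprocessing pipeline is now: mono mixdown → resample to 16 kHz (updating sr, which the removed code path did not do before the added `sr = SAMPLING_RATE`) → clip to 120 s → cast to float32. A minimal standalone sketch of the clipping step, condensed from clip_audio_to_max_duration above (logging dropped; comparing sample counts is equivalent to the duration comparison in the diff):

import numpy as np

SAMPLING_RATE = 16000        # target rate from main.py
MAX_DURATION_SECONDS = 120   # the new two-minute cap

def clip(audio: np.ndarray, sr: int) -> tuple[np.ndarray, bool]:
    # Keep at most the first MAX_DURATION_SECONDS worth of samples.
    max_samples = int(MAX_DURATION_SECONDS * sr)
    if len(audio) <= max_samples:
        return audio, False
    return audio[:max_samples], True

# A 3-minute signal at 16 kHz is 2,880,000 samples.
three_minutes = np.zeros(180 * SAMPLING_RATE, dtype=np.float32)
clipped, was_clipped = clip(three_minutes, SAMPLING_RATE)
print(len(clipped), was_clipped)  # 1920000 True — i.e. exactly 120 s retained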
models/age_and_gender_model.py
CHANGED

@@ -7,7 +7,6 @@ import librosa
 
 class AgeGenderModel:
     def __init__(self, model_path=None):
-        # Use persistent storage if available, fallback to local cache
         if model_path is None:
             if os.path.exists("/data"):
                 # HF Spaces persistent storage
@@ -34,7 +33,7 @@
         print("Age & gender model files not found. Downloading...")
 
         try:
-            # Use /data for cache if available, otherwise use local cache
+            # Use /data for cache if available, otherwise use local cache; this is in line with HF Spaces persistent storage
             if os.path.exists("/data"):
                 cache_root = '/data/cache'
             else:
@@ -72,16 +71,13 @@
 
     def load(self):
         try:
-            # Download model if needed
             if not self.download_model():
                 print("Failed to download age & gender model")
                 return False
 
-            # Load the audonnx model
             print(f"Loading age & gender model from {self.model_path}...")
             self.model = audonnx.load(self.model_path)
 
-            # Create the audinterface Feature interface
             outputs = ['logits_age', 'logits_gender']
             self.interface = audinterface.Feature(
                 self.model.labels(outputs),
@@ -91,7 +87,7 @@
                     'concat': True,
                 },
                 sampling_rate=self.sampling_rate,
-                resample=False,
+                resample=False,
                 verbose=False,
             )
             print("Age & gender model loaded successfully!")
@@ -105,7 +101,7 @@
         if self.model is None or self.interface is None:
             raise ValueError("Model not loaded. Call load() first.")
 
-        try:
+        try:
             result = self.interface.process_signal(audio_data, sr)
 
             # Extract and process results
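The changes here are mostly comment cleanup; the `resample=False,` and `try:` hunks remove and re-add a line verbatim, which is presumably a whitespace-only edit. Keeping resample=False is consistent with main.py, which now resamples everything to 16 kHz (and updates sr) before the model sees the audio. For context, the cache-selection pattern the updated comment describes amounts to this sketch (the local fallback value is an assumption; only the /data branch is visible in the hunk):

import os

def pick_cache_root() -> str:
    # Prefer HF Spaces persistent storage when the /data mount exists.
    if os.path.exists("/data"):
        return "/data/cache"
    # Assumed fallback; the actual local cache path is outside this hunk.
    return os.path.join(os.getcwd(), "cache")

print(pick_cache_root())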
models/nationality_model.py
CHANGED

@@ -9,7 +9,6 @@ SAMPLING_RATE = 16000
 
 class NationalityModel:
     def __init__(self, cache_dir=None):
-        # Use persistent storage if available, fallback to local cache
         if cache_dir is None:
             if os.path.exists("/data"):
                 # HF Spaces persistent storage
@@ -48,16 +47,13 @@
             raise ValueError("Model not loaded. Call load() first.")
 
         try:
-            # Ensure audio is properly formatted (float32, mono)
             if len(audio_data.shape) > 1:
                 audio_data = audio_data.mean(axis=0)
 
             audio_data = audio_data.astype(np.float32)
 
-            # Process audio with the feature extractor
             inputs = self.processor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
 
-            # Get model predictions
             with torch.no_grad():
                 outputs = self.model(**inputs).logits
 
@@ -65,7 +61,6 @@
             probabilities = torch.nn.functional.softmax(outputs, dim=-1)[0]
             top_k_values, top_k_indices = torch.topk(probabilities, k=5)
 
-            # Convert to language codes and probabilities
             top_languages = []
             for i, idx in enumerate(top_k_indices):
                 lang_id = idx.item()
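This file's changes are pure comment removal; the prediction core that remains is softmax plus top-5 selection. A self-contained sketch with dummy logits (the class count of 107 is an assumption; the real logits come from self.model(**inputs).logits):

import torch

logits = torch.randn(1, 107)  # dummy batch of logits; 107 classes is an assumed count
probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
top_k_values, top_k_indices = torch.topk(probabilities, k=5)

for rank, (p, idx) in enumerate(zip(top_k_values, top_k_indices), start=1):
    # In the model, idx.item() is then converted to a language code.
    print(f"{rank}. class {idx.item()}: p={p.item():.4f}")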