# mirrormindv2 / app.py
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tv
try:
from transformers import Wav2Vec2Model
_HAS_TRANSFORMERS = True
except ImportError:
_HAS_TRANSFORMERS = False
import cv2
import numpy as np
import librosa
import tempfile
import os
import shutil
from typing import Dict, Any, Optional
import warnings
import logging
from contextlib import asynccontextmanager
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global model instance
model_instance = None
# Response models
class EmotionScores(BaseModel):
Anger: float
Disgust: float
Fear: float
Happy: float
Neutral: float
Sad: float
class AnalysisResult(BaseModel):
neuroticism: float
neuroticism_level: str
emotions: EmotionScores
dominant_emotion: str
frames_processed: int
audio_features_extracted: bool
model_used: str
confidence: str
class ErrorResponse(BaseModel):
error: str
message: str
# MirrorMind model architecture
class GradientReverseFn(torch.autograd.Function):
"""Gradient reversal function for adversarial training"""
@staticmethod
def forward(ctx, x, lambd):
ctx.lambd = lambd
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
return -ctx.lambd * grad_output, None
def grad_reverse(x, lambd=1.0):
"""Gradient reversal layer"""
return GradientReverseFn.apply(x, lambd)
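# grad_reverse is an identity map in the forward pass; in the backward pass it
# multiplies the incoming gradient by -lambd. Applied before the domain head,
# this trains the shared features adversarially against domain classification.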
class MirrorMindModel(nn.Module):
def __init__(
self,
num_frames=8,
audio_length=64000, # 4s at 16kHz
num_emotions=6,
num_domains=2,
hidden_dim=512,
use_pretrained_video=True,
use_pretrained_audio=True,
freeze_video_backbone=True,
freeze_audio_backbone=True,
):
super().__init__()
self.num_frames = num_frames
self.audio_length = audio_length
self.num_emotions = num_emotions
self.num_domains = num_domains
self.hidden_dim = hidden_dim
# Video encoder
if use_pretrained_video:
self.video_backbone = tv.resnet18(weights=tv.ResNet18_Weights.IMAGENET1K_V1)
else:
self.video_backbone = tv.resnet18(weights=None)
self.video_feat_dim = self.video_backbone.fc.in_features # 512
self.video_backbone.fc = nn.Identity()
if freeze_video_backbone:
for param in self.video_backbone.parameters():
param.requires_grad = False
for param in self.video_backbone.layer4.parameters():
param.requires_grad = True
self.video_proj = nn.Sequential(
nn.Linear(self.video_feat_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(0.2),
)
# Audio Encoder
self.audio_feat_dim = 0
if use_pretrained_audio and _HAS_TRANSFORMERS:
try:
self.audio_backbone = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
self.audio_feat_dim = self.audio_backbone.config.hidden_size # 768
if freeze_audio_backbone:
for param in self.audio_backbone.parameters():
param.requires_grad = False
for name, param in self.audio_backbone.named_parameters():
if any(x in name for x in ['encoder.layers.10', 'encoder.layers.11']):
param.requires_grad = True
self.audio_pool = nn.AdaptiveAvgPool1d(1)
logger.info("Using Wav2Vec2 audio encoder")
except Exception as e:
logger.warning(f"Could not load Wav2Vec2, using CNN: {e}")
self._create_improved_audio_encoder()
else:
self._create_improved_audio_encoder()
logger.info("Using CNN audio encoder")
self.audio_proj = nn.Sequential(
nn.Linear(self.audio_feat_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(0.2),
)
# Temporal attention
self.temporal_attention = nn.Sequential(
nn.Linear(self.video_feat_dim, 128),
nn.ReLU(inplace=True),
nn.Linear(128, 1)
)
# Fusion layer
fusion_input_dim = hidden_dim * 2
self.fusion_output_dim = hidden_dim
self.fusion_proj = nn.Sequential(
nn.Linear(fusion_input_dim, self.fusion_output_dim),
nn.BatchNorm1d(self.fusion_output_dim),
nn.ReLU(inplace=True),
nn.Dropout(0.3),
)
# Task heads
self.emotion_head = nn.Sequential(
nn.Linear(self.fusion_output_dim, hidden_dim // 2),
nn.BatchNorm1d(hidden_dim // 2),
nn.ReLU(inplace=True),
nn.Dropout(0.4),
nn.Linear(hidden_dim // 2, num_emotions),
)
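        # Neuroticism head: a single sigmoid-bounded score in [0, 1].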
self.neuro_head = nn.Sequential(
nn.Linear(self.fusion_output_dim, hidden_dim // 2),
nn.BatchNorm1d(hidden_dim // 2),
nn.ReLU(inplace=True),
nn.Dropout(0.3),
nn.Linear(hidden_dim // 2, 1),
nn.Sigmoid()
)
# Domain head
self.domain_head = nn.Sequential(
nn.Linear(self.fusion_output_dim, hidden_dim // 4),
nn.ReLU(inplace=True),
nn.Dropout(0.2),
nn.Linear(hidden_dim // 4, num_domains),
)
self._init_weights()
def _create_improved_audio_encoder(self):
self.audio_backbone = nn.Sequential(
nn.Conv1d(1, 64, kernel_size=7, stride=2, padding=3),
nn.BatchNorm1d(64),
nn.ReLU(inplace=True),
nn.MaxPool1d(2),
nn.Conv1d(64, 128, kernel_size=5, stride=1, padding=2),
nn.BatchNorm1d(128),
nn.ReLU(inplace=True),
nn.MaxPool1d(2),
nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(256),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool1d(1)
)
self.audio_feat_dim = 256
self.audio_pool = None
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
if m.out_features == self.num_emotions:
nn.init.xavier_uniform_(m.weight, gain=1.0)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif m.out_features == 1:
nn.init.xavier_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
else:
nn.init.xavier_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm1d, nn.LayerNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
@staticmethod
def _prep_frames(frames):
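        """Normalize frames to a (B*T, C, H, W) float32 batch in [0, 1].

        Accepts (B, T, C, H, W), (B, T, H, W, C), or (B, C, H, W) inputs and
        returns the flattened batch together with B and T.
        """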
device = frames.device
if frames.dim() == 5:
B, T = frames.shape[:2]
if frames.shape[-1] == 3:
frames = frames.permute(0, 1, 4, 2, 3)
B, T, C, H, W = frames.shape
frames = frames.reshape(B * T, C, H, W)
elif frames.dim() == 4:
B = frames.shape[0]
T = 1
else:
raise ValueError(f"Unsupported frames shape: {frames.shape}")
if frames.dtype != torch.float32:
frames = frames.float()
if frames.max() > 1.1:
frames = frames / 255.0
frames = torch.clamp(frames, 0.0, 1.0)
return frames, B, T
def _process_video_temporal_attention(self, vid_feat_bt, B, T):
if T == 1:
return vid_feat_bt.view(B, -1)
vid_feat_reshaped = vid_feat_bt.view(B, T, -1)
return torch.mean(vid_feat_reshaped, dim=1)
def forward(self, frames, audio, alpha=0.0):
device = next(self.parameters()).device
frames_nchw, B, T = self._prep_frames(frames.to(device))
try:
vid_feat_bt = self.video_backbone(frames_nchw)
vid_feat_bt = vid_feat_bt.flatten(1)
vid_feat = self._process_video_temporal_attention(vid_feat_bt, B, T)
vid_feat = self.video_proj(vid_feat)
except Exception as e:
logger.error(f"Video processing error: {e}")
vid_feat = torch.zeros((B, self.hidden_dim), device=device)
try:
if audio is None or torch.all(audio == 0):
aud_feat = torch.zeros((B, self.hidden_dim), device=device)
else:
audio = audio.float().to(device)
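                # The Wav2Vec2 backbone (a Hugging Face PreTrainedModel) exposes
                # `from_pretrained`; the CNN fallback (nn.Sequential) does not, which
                # is how the two audio paths are distinguished here.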
if hasattr(self.audio_backbone, 'from_pretrained'):
attn_mask = (audio.abs() > 1e-6).long()
out = self.audio_backbone(input_values=audio, attention_mask=attn_mask)
x = out.last_hidden_state.transpose(1, 2)
x = self.audio_pool(x).squeeze(-1)
aud_feat = x
else:
x = audio.unsqueeze(1)
x = self.audio_backbone(x)
if x.dim() == 3:
x = x.squeeze(-1)
aud_feat = x
aud_feat = self.audio_proj(aud_feat)
except Exception as e:
logger.error(f"Audio processing error: {e}")
aud_feat = torch.zeros((B, self.hidden_dim), device=device)
fused = torch.cat([vid_feat, aud_feat], dim=1)
fused_final = self.fusion_proj(fused)
emotion_logits = self.emotion_head(fused_final)
neuroticism_pred = self.neuro_head(fused_final)
domain_logits = None
if self.training and alpha > 0.0:
if alpha < 0.01:
rev = grad_reverse(fused_final, lambd=alpha * 0.1)
domain_logits = self.domain_head(rev)
return neuroticism_pred, emotion_logits, domain_logits
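# A minimal smoke-test sketch (illustrative only, not executed at import time);
# shapes follow the constructor defaults above: 8 frames of 224x224 RGB and
# 4 s of 16 kHz audio (64000 samples).
#
#   model = MirrorMindModel()
#   model.eval()
#   frames = torch.rand(1, 8, 3, 224, 224)   # (B, T, C, H, W), values in [0, 1]
#   audio = torch.zeros(1, 64000)            # all-zero audio -> zeroed audio branch
#   with torch.no_grad():
#       neuro, emotion_logits, _ = model(frames, audio)
#   # neuro: (1, 1) in [0, 1]; emotion_logits: (1, 6) unnormalized scores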
class MirrorMindInference:
def __init__(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {self.device}")
model_path = "mirror_model.pth"
logger.info(f"Loading model from {model_path}...")
if not os.path.exists(model_path):
logger.warning(f"Model file {model_path} not found. Using fallback mode.")
self.model = None
return
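        # The checkpoint may be stored in several formats; the loading logic below
        # handles, in order: a pickled complete model object, a dict containing
        # 'model' + 'state_dict', a dict containing 'model_config' together with
        # 'state_dict' or 'model_state_dict', and otherwise falls back to demo mode.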
checkpoint = None
pytorch_version = torch.__version__
if pytorch_version.startswith(("2.8", "2.9")):
logger.info(f"Detected PyTorch {pytorch_version} - using version-specific loading...")
try:
logger.info("Loading with weights_only=False...")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
logger.info("✓ Successfully loaded complete model")
except Exception as e1:
logger.error(f"✗ Failed: {e1}")
try:
logger.info("Attempting state_dict loading with weights_only=True...")
checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
logger.info("✓ Loaded as state_dict")
except Exception as e2:
logger.error(f"✗ Failed: {e2}")
checkpoint = None
else:
try:
logger.info(f"Using standard loading for PyTorch {pytorch_version}...")
checkpoint = torch.load(model_path, map_location=self.device)
logger.info("✓ Loaded with standard method")
except Exception as e:
logger.error(f"✗ Failed: {e}")
checkpoint = None
if checkpoint is None:
logger.warning("All loading methods failed. Using fallback mode.")
self.model = None
return
if isinstance(checkpoint, dict):
logger.info(f"Checkpoint keys: {list(checkpoint.keys())}")
if 'model' in checkpoint and 'state_dict' in checkpoint:
self.model = checkpoint['model']
self.model.load_state_dict(checkpoint['state_dict'])
logger.info("✓ Loaded model architecture + state dict")
elif 'state_dict' in checkpoint:
logger.info("Found 'state_dict' - attempting to reconstruct model...")
if 'model_config' in checkpoint:
self.model = MirrorMindModel(**checkpoint['model_config'])
self.model.load_state_dict(checkpoint['state_dict'])
logger.info("✓ Loaded using model_config + state_dict")
else:
logger.warning("⚠️ No model_config. Using fallback.")
self.model = None
return
elif 'model_state_dict' in checkpoint:
logger.info("Found 'model_state_dict' - checking for model class info...")
state_dict = checkpoint['model_state_dict']
if 'model_config' in checkpoint:
self.model = MirrorMindModel(**checkpoint['model_config'])
self.model.load_state_dict(state_dict)
logger.info("✓ Loaded using model_config + model_state_dict")
else:
logger.warning("⚠️ No model_config. Using fallback.")
self.model = None
return
elif len(checkpoint.keys()) > 0 and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
logger.info("Checkpoint appears to be a direct state dict")
logger.warning("⚠️ Cannot reconstruct without model_config. Using fallback.")
self.model = None
return
else:
if hasattr(checkpoint, 'eval') and callable(checkpoint.eval):
self.model = checkpoint
logger.info("✓ Using checkpoint as complete model")
else:
logger.warning("⚠️ Unrecognized format. Using fallback.")
self.model = None
return
else:
if hasattr(checkpoint, 'eval') and callable(checkpoint.eval):
self.model = checkpoint
logger.info("✓ Loaded complete model object")
else:
logger.warning("⚠️ Not a model object. Using fallback.")
self.model = None
return
if self.model is not None:
self.model.to(self.device)
self.model.eval()
logger.info("Model loaded and ready for inference!")
else:
logger.warning("Model is None after loading. Using fallback.")
def extract_video_frames(self, video_path: str, num_frames: int = 8) -> torch.Tensor:
try:
cap = cv2.VideoCapture(video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if total_frames == 0:
raise ValueError("Could not read video file")
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
frames = []
for idx in frame_indices:
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
ret, frame = cap.read()
if ret:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = cv2.resize(frame, (224, 224))
frame = frame.astype(np.float32) / 255.0
frames.append(frame)
cap.release()
if not frames:
raise ValueError("No frames extracted")
frames = np.array(frames)
frames = np.transpose(frames, (0, 3, 1, 2))
video_tensor = torch.from_numpy(frames).to(self.device)
return video_tensor
except Exception as e:
logger.error(f"Video extraction failed: {e}")
dummy_frames = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
return torch.from_numpy(dummy_frames).to(self.device)
    def extract_audio_features(self, video_path: str, duration: float = 4.0):
        """Load the audio track as a raw 16 kHz waveform, padded/truncated to
        duration * 16000 samples, matching the model's expected audio input."""
        target_length = int(duration * 16000)  # 64000 samples for the 4 s default
        try:
            audio, sr = librosa.load(video_path, sr=16000, duration=duration)
            if len(audio) == 0:
                raise ValueError("No audio data")
            # Both audio encoders (Wav2Vec2 and the CNN fallback) consume the raw
            # waveform, so pad or truncate to a fixed length rather than passing
            # MFCC/spectral summaries of the wrong dimensionality.
            if len(audio) < target_length:
                audio = np.pad(audio, (0, target_length - len(audio)))
            else:
                audio = audio[:target_length]
            audio_tensor = torch.from_numpy(audio).float().to(self.device)
            return audio_tensor
        except Exception as e:
            logger.error(f"Audio extraction failed: {e}")
            # All-zero audio is treated as "no audio" by the model's forward pass.
            return torch.zeros(target_length).to(self.device)
def predict(self, video_path: str) -> Dict[str, Any]:
try:
if not os.path.exists(video_path):
raise ValueError(f"Video not found: {video_path}")
video_features = self.extract_video_frames(video_path)
audio_features = self.extract_audio_features(video_path)
if self.model is not None:
                with torch.no_grad():
                    neuroticism_pred, emotion_logits, _ = self.model(video_features.unsqueeze(0), audio_features.unsqueeze(0))
                neuroticism_score = neuroticism_pred.squeeze().item()
                emotion_probs = F.softmax(emotion_logits, dim=1).squeeze().cpu().numpy()
                emotion_labels = ['Anger', 'Disgust', 'Fear', 'Happy', 'Neutral', 'Sad']
                # Cast NumPy scalars to plain floats so the Pydantic response model accepts them.
                emotion_scores = {label: float(p) for label, p in zip(emotion_labels, emotion_probs)}
else:
logger.info("Using fallback predictions")
neuroticism_score = np.random.uniform(0.2, 0.8)
emotion_scores = {
'Happy': np.random.uniform(0.1, 0.4),
'Neutral': np.random.uniform(0.2, 0.5),
'Sad': np.random.uniform(0.05, 0.3),
'Anger': np.random.uniform(0.0, 0.2),
'Fear': np.random.uniform(0.0, 0.15),
'Disgust': np.random.uniform(0.0, 0.1)
}
total = sum(emotion_scores.values())
emotion_scores = {k: v/total for k, v in emotion_scores.items()}
return {
'neuroticism': float(neuroticism_score),
'emotions': emotion_scores,
'frames_processed': len(video_features),
                'audio_features_extracted': bool(torch.any(audio_features != 0).item()),
'model_used': 'real' if self.model is not None else 'fallback'
}
except Exception as e:
logger.error(f"Prediction error: {e}")
return {
'error': str(e),
'neuroticism': 0.0,
'emotions': {'Error': 1.0},
'frames_processed': 0,
'audio_features_extracted': False,
'model_used': 'error'
}
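# Sketch of using the inference wrapper directly, outside the API (illustrative;
# "sample_video.mp4" is a hypothetical local file). Without "mirror_model.pth"
# on disk, the wrapper returns randomized demo-mode predictions.
#
#   inference = MirrorMindInference()
#   results = inference.predict("sample_video.mp4")
#   print(results["neuroticism"], results["emotions"], results["model_used"])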
# Initialize model on startup
@asynccontextmanager
async def lifespan(app: FastAPI):
global model_instance
logger.info("Starting MirrorMind API service...")
model_instance = MirrorMindInference()
logger.info(f"PyTorch version: {torch.__version__}")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
yield
logger.info("Shutting down MirrorMind API service...")
# Initialize FastAPI app
app = FastAPI(
title="MirrorMind API",
description="AI Personality & Emotion Analysis API",
version="1.0.0",
lifespan=lifespan
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Configure this for production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def root():
return {
"message": "MirrorMind API is running",
"version": "1.0.0",
"pytorch_version": torch.__version__,
"cuda_available": torch.cuda.is_available(),
"model_loaded": model_instance.model is not None if model_instance else False
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"model_status": "loaded" if model_instance and model_instance.model is not None else "fallback",
"device": str(model_instance.device) if model_instance else "unknown"
}
@app.post("/analyze", response_model=AnalysisResult)
async def analyze_video(file: UploadFile = File(...)):
"""
Analyze a video file for personality traits and emotions.
    - **file**: Video file (MP4, AVI, MOV, WebM, MKV)
- Returns neuroticism score and emotion analysis
"""
if not model_instance:
raise HTTPException(status_code=503, detail="Model not initialized")
# Validate file type
allowed_extensions = {'.mp4', '.avi', '.mov', '.webm', '.mkv'}
    file_extension = os.path.splitext((file.filename or "").lower())[1]
if file_extension not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Unsupported file format. Allowed formats: {', '.join(allowed_extensions)}"
)
# Create temporary file
temp_dir = tempfile.mkdtemp()
temp_file_path = os.path.join(temp_dir, f"uploaded_video{file_extension}")
try:
# Save uploaded file
with open(temp_file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
# Analyze video
results = model_instance.predict(temp_file_path)
if 'error' in results:
raise HTTPException(status_code=500, detail=f"Analysis failed: {results['error']}")
# Process results
neuroticism_score = results['neuroticism']
if neuroticism_score <= 0.3:
neuroticism_level = "Low (Emotionally Stable)"
elif neuroticism_score <= 0.7:
neuroticism_level = "Medium (Moderate Reactivity)"
else:
neuroticism_level = "High (Emotionally Sensitive)"
emotions = results['emotions']
dominant_emotion = max(emotions.keys(), key=lambda k: emotions[k])
confidence = "High" if results['model_used'] == 'real' else "Demo Mode"
return AnalysisResult(
neuroticism=neuroticism_score,
neuroticism_level=neuroticism_level,
emotions=EmotionScores(**emotions),
dominant_emotion=dominant_emotion,
frames_processed=results['frames_processed'],
audio_features_extracted=results['audio_features_extracted'],
model_used=results['model_used'],
confidence=confidence
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Analysis error: {e}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
finally:
# Clean up temporary files
try:
shutil.rmtree(temp_dir)
except Exception as e:
logger.warning(f"Failed to clean up temp directory: {e}")
@app.post("/analyze-from-url")
async def analyze_video_from_url(video_url: str):
"""
Analyze a video from a URL (Firebase/Supabase storage).
- **video_url**: Direct URL to video file
- Returns neuroticism score and emotion analysis
"""
if not model_instance:
raise HTTPException(status_code=503, detail="Model not initialized")
import requests
# Create temporary file
temp_dir = tempfile.mkdtemp()
temp_file_path = os.path.join(temp_dir, "downloaded_video.mp4")
try:
# Download video from URL
response = requests.get(video_url, stream=True, timeout=30)
response.raise_for_status()
with open(temp_file_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
# Analyze video
results = model_instance.predict(temp_file_path)
if 'error' in results:
raise HTTPException(status_code=500, detail=f"Analysis failed: {results['error']}")
# Process results (same as above)
neuroticism_score = results['neuroticism']
if neuroticism_score <= 0.3:
neuroticism_level = "Low (Emotionally Stable)"
elif neuroticism_score <= 0.7:
neuroticism_level = "Medium (Moderate Reactivity)"
else:
neuroticism_level = "High (Emotionally Sensitive)"
emotions = results['emotions']
dominant_emotion = max(emotions.keys(), key=lambda k: emotions[k])
confidence = "High" if results['model_used'] == 'real' else "Demo Mode"
return AnalysisResult(
neuroticism=neuroticism_score,
neuroticism_level=neuroticism_level,
emotions=EmotionScores(**emotions),
dominant_emotion=dominant_emotion,
frames_processed=results['frames_processed'],
audio_features_extracted=results['audio_features_extracted'],
model_used=results['model_used'],
confidence=confidence
)
except requests.RequestException as e:
raise HTTPException(status_code=400, detail=f"Failed to download video: {str(e)}")
except HTTPException:
raise
except Exception as e:
logger.error(f"Analysis error: {e}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
finally:
# Clean up temporary files
try:
shutil.rmtree(temp_dir)
except Exception as e:
logger.warning(f"Failed to clean up temp directory: {e}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
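# Equivalently, the service can be started from the shell with the uvicorn CLI:
#   uvicorn app:app --host 0.0.0.0 --port 8000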