Spaces:
Running
on
Zero
Running
on
Zero
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision.models.optical_flow import Raft_Large_Weights, raft_large | |
| from typing import List, Tuple, Dict | |
| import argparse | |
| from pathlib import Path | |
| from sklearn.cluster import KMeans | |
| from tqdm import tqdm | |
| import os | |
| os.environ['OPENBLAS_NUM_THREADS'] = '64' | |
| class OpticalFlowAnalyzer: | |
| def __init__(self, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'): | |
| self.device = device | |
| self.model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device) | |
| self.model.eval() | |
| def preprocess_frame(self, frame: np.ndarray) -> torch.Tensor: | |
| """Preprocess a frame for RAFT model.""" | |
| frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| frame = torch.from_numpy(frame).permute(2, 0, 1).float() | |
| frame = frame.unsqueeze(0) / 255.0 | |
| return frame.to(self.device) | |
| def compute_optical_flow(self, frame1: np.ndarray, frame2: np.ndarray) -> np.ndarray: | |
| """Compute optical flow between two consecutive frames.""" | |
| with torch.no_grad(): | |
| frame1_tensor = self.preprocess_frame(frame1) | |
| frame2_tensor = self.preprocess_frame(frame2) | |
| flow = self.model(frame1_tensor, frame2_tensor)[-1] | |
| flow = flow[0].permute(1, 2, 0).cpu().numpy() | |
| return flow | |
| def analyze_motion_regions(self, flow: np.ndarray, num_clusters: int = 3) -> Tuple[np.ndarray, Dict]: | |
| """Cluster motion regions based on optical flow magnitude and direction.""" | |
| h, w = flow.shape[:2] | |
| magnitude = np.sqrt(flow[..., 0]**2 + flow[..., 1]**2) | |
| direction = np.arctan2(flow[..., 1], flow[..., 0]) | |
| # Create feature matrix for clustering | |
| features = np.zeros((h * w, 3)) | |
| features[:, 0] = magnitude.ravel() | |
| features[:, 1] = np.cos(direction).ravel() | |
| features[:, 2] = np.sin(direction).ravel() | |
| # Normalize features | |
| features = (features - features.mean(axis=0)) / features.std(axis=0) | |
| # Perform clustering | |
| kmeans = KMeans(n_clusters=num_clusters, random_state=42,) | |
| labels = kmeans.fit_predict(features) | |
| labels = labels.reshape(h, w) | |
| # Analyze clusters | |
| cluster_stats = {} | |
| for i in range(num_clusters): | |
| cluster_mask = (labels == i) | |
| cluster_magnitude = magnitude[cluster_mask] | |
| cluster_stats[i] = { | |
| 'mean_magnitude': np.mean(cluster_magnitude), | |
| 'std_magnitude': np.std(cluster_magnitude), | |
| 'pixel_count': np.sum(cluster_mask), | |
| 'is_static': np.mean(cluster_magnitude) < 0.1 # Threshold for static regions | |
| } | |
| return labels, cluster_stats | |
| def process_video(self, video_path: str, output_path: str = None) -> List[Tuple[np.ndarray, Dict]]: | |
| """Process a video and return motion analysis results for each frame pair.""" | |
| cap = cv2.VideoCapture(video_path) | |
| if not cap.isOpened(): | |
| raise ValueError(f"Could not open video: {video_path}") | |
| results = [] | |
| ret, prev_frame = cap.read() | |
| if not ret: | |
| raise ValueError("Could not read first frame") | |
| total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| pbar = tqdm(total=total_frames-1, desc="Processing video") | |
| while True: | |
| ret, curr_frame = cap.read() | |
| if not ret: | |
| break | |
| flow = self.compute_optical_flow(prev_frame, curr_frame) | |
| labels, stats = self.analyze_motion_regions(flow) | |
| if output_path: | |
| # Visualize results | |
| vis_frame = curr_frame.copy() | |
| for i, stat in stats.items(): | |
| if not stat['is_static']: | |
| mask = (labels == i).astype(np.uint8) * 255 | |
| print("mask:",mask.shape) | |
| print("vis_frame:",vis_frame.shape) | |
| mask = np.expand_dims(mask, axis=-1).repeat(3, axis=-1) | |
| print("mask:",mask.shape) | |
| vis_frame[mask > 0] = cv2.addWeighted(vis_frame[mask > 0], 0.7, 255, 0.3, 0) | |
| cv2.imwrite(f"{output_path}/frame_{len(results):04d}.jpg", vis_frame) | |
| results.append((labels, stats)) | |
| prev_frame = curr_frame | |
| pbar.update(1) | |
| cap.release() | |
| pbar.close() | |
| return results | |
| def main(): | |
| parser = argparse.ArgumentParser(description='Analyze motion regions in a video using RAFT optical flow') | |
| parser.add_argument('--video', type=str, required=True, help='Path to input video') | |
| parser.add_argument('--output', type=str, help='Path to output directory for visualization') | |
| parser.add_argument('--clusters', type=int, default=3, help='Number of motion clusters') | |
| args = parser.parse_args() | |
| analyzer = OpticalFlowAnalyzer() | |
| results = analyzer.process_video(args.video, args.output) | |
| # Print summary statistics | |
| print("\nMotion Analysis Summary:") | |
| for i, (_, stats) in enumerate(results): | |
| print(f"\nFrame {i+1}:") | |
| for cluster_id, stat in stats.items(): | |
| motion_type = "Static" if stat['is_static'] else "Moving" | |
| print(f" Cluster {cluster_id} ({motion_type}):") | |
| print(f" Mean magnitude: {stat['mean_magnitude']:.4f}") | |
| print(f" Pixel count: {stat['pixel_count']}") | |
| if __name__ == "__main__": | |
| main() | |