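"""Gradio demo for previewing r/GamePhysics clips.

Loads a Reddit post's video from the asgaardlab Hugging Face datasets, samples
evenly spaced frames with decord, and offers them as a preview gallery or a
downloadable zip.
"""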
import concurrent.futures
import os
import re
import tempfile
import time
from datetime import datetime
from zipfile import ZipFile

import gradio as gr
import numpy as np
import pandas as pd
import requests
from decord import VideoReader, cpu
from PIL import Image

# In-memory caches for the Reddit listings, refreshed at most every 10 minutes
cached_latest_posts_df = None
cached_top_posts = None
last_fetched = None
last_fetched_top = None

def resize_image(image):
    # Downscale to 35% of the original size to keep the preview gallery light
    width, height = image.size
    new_width = int(width * 0.35)
    new_height = int(height * 0.35)
    return image.resize((new_width, new_height), Image.BILINEAR)

def get_reddit_id(url):
    # Accept either a full r/GamePhysics URL or a bare post ID; the bare-ID
    # branch is anchored so malformed URLs don't match on a prefix like "https"
    pattern = r"https://www\.reddit\.com/r/GamePhysics/comments/([0-9a-zA-Z]+).*|^([0-9a-zA-Z]+)$"
    match = re.match(pattern, url)
    if match:
        # Take whichever alternative matched: URL capture or bare ID
        post_id = match.group(1) or match.group(2)
        print(f"Valid GamePhysics post ID: {post_id}")
    else:
        post_id = url
    return post_id

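# Example (hypothetical post ID, for illustration only):
#   get_reddit_id("https://www.reddit.com/r/GamePhysics/comments/abc123/title/")  # -> "abc123"
#   get_reddit_id("abc123")                                                       # -> "abc123"
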
def download_samples(url, video_url, num_frames):
    frames = extract_frames_decord(video_url, num_frames)
    # Save each frame as a JPEG image in a temporary directory
    with tempfile.TemporaryDirectory() as temp_dir:
        for i, frame in enumerate(frames):
            frame_path = os.path.join(temp_dir, f"frame_{i}.jpg")
            frame.save(frame_path, format="JPEG", quality=85)  # Adjust quality as needed
        # Zip the frames into a persistent location before the temp dir is deleted
        post_id = get_reddit_id(url)
        print(f"Creating zip file for post {post_id}")
        zip_path = f"frames-{post_id}.zip"
        with ZipFile(zip_path, "w") as zipf:
            for i in range(num_frames):
                frame_path = os.path.join(temp_dir, f"frame_{i}.jpg")
                zipf.write(frame_path, os.path.basename(frame_path))
    # Return the path of the zip file so Gradio can offer it for download
    return zip_path

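# Note: the zip lands in the app's working directory and is never cleaned up,
# so one frames-<post_id>.zip accumulates per sampled post.
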
def extract_frames_decord(video_path, num_frames=10):
    try:
        start_time = time.time()
        print(f"Extracting {num_frames} frames from {video_path}")
        # Load the video
        vr = VideoReader(video_path, ctx=cpu(0))
        # Calculate evenly spaced indices for the frames to be extracted
        total_frames = len(vr)
        frame_indices = np.linspace(
            0, total_frames - 1, num_frames, dtype=int, endpoint=False
        )
        # Extract all requested frames in a single batched read
        batch_frames = vr.get_batch(frame_indices).asnumpy()
        # Convert frames to PIL Images
        frame_images = [
            Image.fromarray(batch_frames[i]) for i in range(batch_frames.shape[0])
        ]
        end_time = time.time()
        print(f"Decord extraction took {end_time - start_time:.2f} seconds")
        return frame_images
    except Exception as e:
        raise Exception(f"Error extracting frames from video: {e}")

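# Minimal usage sketch (hypothetical local file, for illustration):
#   frames = extract_frames_decord("clip.mp4", num_frames=10)
#   frames[0].save("first_frame.jpg")
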
def extract_frames_decord_preview(video_path, num_frames=10):
    # Same extraction as extract_frames_decord, plus a downscale pass so the
    # gallery preview uses less bandwidth (aspect ratio is preserved)
    frame_images = extract_frames_decord(video_path, num_frames)
    with concurrent.futures.ThreadPoolExecutor() as executor:
        frame_images = list(executor.map(resize_image, frame_images))
    return frame_images

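# The resize runs in a thread pool on the assumption that Pillow's C-level
# resize largely releases the GIL, so the per-frame downscales can overlap.
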
def get_top_posts():
    global cached_top_posts
    global last_fetched_top
    # Rate-limit Reddit fetches to at most one request per 10 minutes
    now_time = datetime.now()
    if (
        last_fetched_top is not None
        and (now_time - last_fetched_top).total_seconds() < 600
    ):
        print("Using cached data")
        return cached_top_posts
    last_fetched_top = now_time
    # Appending .json to a subreddit URL returns the listing as JSON;
    # t=month scopes "top" to the past month
    url = "https://www.reddit.com/r/GamePhysics/top/.json?t=month"
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, headers=headers)
    if response.status_code != 200:
        return []
    data = response.json()
    # Build a [post_id, title] table from the listing
    posts = data["data"]["children"]
    examples = [[post["data"]["id"], post["data"]["title"]] for post in posts]
    examples = pd.DataFrame(examples, columns=["post_id", "title"])
    cached_top_posts = examples
    return examples

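# The listing JSON indexed above has (abridged) the shape:
#   {"data": {"children": [{"data": {"id": "...", "title": "...", ...}}, ...]}}
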
def get_latest_posts():
    global cached_latest_posts_df
    global last_fetched
    # Rate-limit Reddit fetches to at most one request per 10 minutes
    now_time = datetime.now()
    if last_fetched is not None and (now_time - last_fetched).total_seconds() < 600:
        print("Using cached data")
        return cached_latest_posts_df
    last_fetched = now_time
    url = "https://www.reddit.com/r/GamePhysics/.json"
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, headers=headers)
    if response.status_code != 200:
        return []
    data = response.json()
    # Build a [post_id, title] table from the listing
    posts = data["data"]["children"]
    examples = [[post["data"]["id"], post["data"]["title"]] for post in posts]
    examples = pd.DataFrame(examples, columns=["post_id", "title"])
    cached_latest_posts_df = examples
    return examples

def row_selected_top(evt: gr.SelectData):
    global cached_top_posts
    # evt.index is [row, column]; map the clicked row back to its post ID
    row = evt.index[0]
    return cached_top_posts.iloc[row]["post_id"]


def row_selected_latest(evt: gr.SelectData):
    global cached_latest_posts_df
    row = evt.index[0]
    return cached_latest_posts_df.iloc[row]["post_id"]

def load_video(url):
    post_id = get_reddit_id(url)
    video_url = f"https://huggingface.co/datasets/asgaardlab/GamePhysicsDailyDump/resolve/main/data/videos/{post_id}.mp4?download=true"
    video_url2 = f"https://huggingface.co/datasets/asgaardlab/GamePhysics-FullResolution/resolve/main/videos/{post_id}/{post_id}.mp4?download=true"
    # Check that the file exists without downloading it; requests.head() does not
    # follow redirects by default, so a Hub 302 also counts as "found"
    r1 = requests.head(video_url)
    r2 = requests.head(video_url2)
    if r1.status_code not in (200, 302) and r2.status_code not in (200, 302):
        raise gr.Error(
            f"Video is not in the repo, please try another post. - {r1.status_code}/{r2.status_code}"
        )
    if r1.status_code in (200, 302):
        return video_url
    return video_url2

with gr.Blocks() as demo:
    gr.Markdown("## Preview GamePhysics")
    dummy_title = gr.Textbox(visible=False)  # hidden placeholder, currently unused
    with gr.Row():
        with gr.Column():
            reddit_id = gr.Textbox(
                lines=1, placeholder="Post url or id here", label="URL or Post ID"
            )
            load_btn = gr.Button("Load")
            video_player = gr.Video(interactive=False)
        with gr.Column():
            gr.Markdown("## Sample frames")
            num_frames = gr.Slider(minimum=1, maximum=60, step=1, value=10)
            sample_decord_btn = gr.Button("Sample frames")
            sampled_frames = gr.Gallery(label="Sampled frames preview")
            download_samples_btn = gr.Button("Download Samples")
            output_files = gr.File()
            download_samples_btn.click(
                download_samples,
                inputs=[reddit_id, video_player, num_frames],
                outputs=[output_files],
            )
        with gr.Column():
            gr.Markdown("## Reddit Posts")
            with gr.Tab("Latest Posts"):
                latest_post_dataframe = gr.Dataframe()
                latest_posts_btn = gr.Button("Refresh Latest Posts")
            with gr.Tab("Top Monthly Posts"):
                top_posts_dataframe = gr.Dataframe()
                top_posts_btn = gr.Button("Refresh Top Posts")

    sample_decord_btn.click(
        extract_frames_decord_preview,
        inputs=[video_player, num_frames],
        outputs=[sampled_frames],
    )
    load_btn.click(load_video, inputs=[reddit_id], outputs=[video_player])
    latest_posts_btn.click(get_latest_posts, outputs=[latest_post_dataframe])
    top_posts_btn.click(get_top_posts, outputs=[top_posts_dataframe])

    # Populate both post tables when the app first loads
    demo.load(get_latest_posts, outputs=[latest_post_dataframe])
    demo.load(get_top_posts, outputs=[top_posts_dataframe])

    # Clicking a table row fills in the post ID, then chains into loading the video
    latest_post_dataframe.select(fn=row_selected_latest, outputs=[reddit_id]).then(
        load_video, inputs=[reddit_id], outputs=[video_player]
    )
    top_posts_dataframe.select(fn=row_selected_top, outputs=[reddit_id]).then(
        load_video, inputs=[reddit_id], outputs=[video_player]
    )

demo.launch()
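# On a Hugging Face Space this file is executed automatically; locally, run
# `python app.py` (assuming the file is saved as app.py) to serve the demo.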