import os
import sys
import spaces
import gradio as gr
import torch
import argparse
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from moviepy.editor import VideoFileClip
from diffusers.utils import load_image, load_video
from tqdm import tqdm
from image_gen_aux import DepthPreprocessor

project_root = os.path.dirname(os.path.abspath(__file__))
os.environ["GRADIO_TEMP_DIR"] = os.path.join(project_root, "tmp", "gradio")
sys.path.append(project_root)

try:
    sys.path.append(os.path.join(project_root, "submodules/MoGe"))
    sys.path.append(os.path.join(project_root, "submodules/vggt"))
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
except Exception:
    print("Warning: MoGe submodule not found, motion transfer will not be available")

HERE_PATH = os.path.normpath(os.path.dirname(__file__))
sys.path.insert(0, HERE_PATH)

from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="EXCAI/Diffusion-As-Shader", filename='spatracker/spaT_final.pth', local_dir=f'{HERE_PATH}/checkpoints/')

from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
from submodules.MoGe.moge.model import MoGeModel
from submodules.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from submodules.vggt.vggt.models.vggt import VGGT

import torch._dynamo
torch._dynamo.config.suppress_errors = True

# Parse command line arguments
parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI")
parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
parser.add_argument("--share", action="store_true", help="Share the web UI")
parser.add_argument("--gpu", type=int, default=0, help="GPU device ID")
parser.add_argument("--model_path", type=str, default="EXCAI/Diffusion-As-Shader", help="Path to model checkpoint")
parser.add_argument("--output_dir", type=str, default="tmp", help="Output directory")
args = parser.parse_args()
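
# Example invocation (illustrative; the script filename is assumed, the flag defaults
# are the ones defined above):
#
#   python app.py --port 7860 --gpu 0 --model_path EXCAI/Diffusion-As-Shader --output_dir tmp
#
# Passing --share would additionally create a public Gradio link.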

# Use the original GPU ID throughout the entire code for consistency
GPU_ID = args.gpu
DEFAULT_MODEL_PATH = args.model_path
OUTPUT_DIR = args.output_dir

# Create necessary directories
os.makedirs("outputs", exist_ok=True)
# Create a project tmp directory instead of using the system temp directory
os.makedirs(os.path.join(project_root, "tmp"), exist_ok=True)
os.makedirs(os.path.join(project_root, "tmp", "gradio"), exist_ok=True)

def load_media(media_path, max_frames=49, transform=None):
    """Load video or image frames and convert them to a tensor.

    Args:
        media_path (str): Path to video or image file
        max_frames (int): Maximum number of frames to load
        transform (callable): Transform to apply to each frame

    Returns:
        Tuple[torch.Tensor, float, bool]: Video tensor [T, C, H, W], FPS, and is_video flag
    """
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((480, 720)),
            transforms.ToTensor()
        ])

    # Determine whether the input is a video or an image based on its extension
    ext = os.path.splitext(media_path)[1].lower()
    is_video = ext in ['.mp4', '.avi', '.mov']

    if is_video:
        # Load video file info
        video_clip = VideoFileClip(media_path)
        duration = video_clip.duration
        original_fps = video_clip.fps

        # Case 1: Video longer than 6 seconds, sample the first 6 seconds + 1 frame
        if duration > 6.0:
            # Use the max_frames parameter instead of a sampling fps
            frames = load_video(media_path, max_frames=max_frames)
            fps = max_frames / 6.0  # Equivalent fps for the sampled clip
        # Cases 2 and 3: Video shorter than 6 seconds
        else:
            # Load all frames
            frames = load_video(media_path)

            # Case 2: Fewer frames than max_frames, repeat frames to reach max_frames
            if len(frames) < max_frames:
                fps = len(frames) / duration  # Keep the original fps
                # Evenly resample (with repetition) up to max_frames
                indices = np.linspace(0, len(frames) - 1, max_frames)
                new_frames = []
                for i in indices:
                    idx = int(i)
                    new_frames.append(frames[idx])
                frames = new_frames
            # Case 3: More frames than max_frames but video shorter than 6 seconds
            else:
                # Evenly subsample down to max_frames
                indices = np.linspace(0, len(frames) - 1, max_frames)
                new_frames = []
                for i in indices:
                    idx = int(i)
                    new_frames.append(frames[idx])
                frames = new_frames
                fps = max_frames / duration  # New fps that preserves the original duration
    else:
        # Handle an image as a single frame
        image = load_image(media_path)
        frames = [image]
        fps = 8  # Default fps for images
        # Duplicate the frame up to max_frames
        while len(frames) < max_frames:
            frames.append(frames[0].copy())

    # Convert frames to a tensor
    video_tensor = torch.stack([transform(frame) for frame in frames])

    return video_tensor, fps, is_video
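
# Usage sketch (illustrative, not executed here). For a 10-second clip, only the first
# 6 seconds are kept and resampled to 49 frames:
#
#   video_tensor, fps, is_video = load_media("example_clip.mp4")  # hypothetical path
#   # video_tensor.shape == (49, 3, 480, 720); fps ~= 8.17 (= 49 / 6.0); is_video is True
#
# A single image yields the same shape, with the frame repeated 49 times and fps = 8.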

def save_uploaded_file(file):
    if file is None:
        return None

    # Use the project tmp directory instead of the system temp directory
    temp_dir = os.path.join(project_root, "tmp")

    if hasattr(file, 'name'):
        filename = file.name
    else:
        # Generate a unique filename if the name attribute is missing
        import uuid
        ext = ".tmp"
        if hasattr(file, 'content_type'):
            if "image" in file.content_type:
                ext = ".png"
            elif "video" in file.content_type:
                ext = ".mp4"
        filename = f"{uuid.uuid4()}{ext}"

    temp_path = os.path.join(temp_dir, filename)

    try:
        # Check whether file is a file-like object or already a path
        if hasattr(file, 'save'):
            file.save(temp_path)
        elif isinstance(file, str):
            # It is already a path
            return file
        else:
            # Try to read and save the file contents
            with open(temp_path, 'wb') as f:
                f.write(file.read() if hasattr(file, 'read') else file)
    except Exception as e:
        print(f"Error saving file: {e}")
        return None

    return temp_path
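
# Behavior sketch (illustrative): a plain string path is returned unchanged, while a
# file-like object is written into <project_root>/tmp and that new path is returned:
#
#   save_uploaded_file("/data/clip.mp4")        # -> "/data/clip.mp4" (hypothetical path)
#   save_uploaded_file(open("mask.png", "rb"))  # -> "<project_root>/tmp/mask.png"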

das_pipeline = None
moge_model = None
vggt_model = None

def get_das_pipeline():
    global das_pipeline
    if das_pipeline is None:
        das_pipeline = DiffusionAsShaderPipeline(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
    return das_pipeline

def get_moge_model():
    global moge_model
    if moge_model is None:
        das = get_das_pipeline()
        moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)
    return moge_model

def get_vggt_model():
    global vggt_model
    if vggt_model is None:
        das = get_das_pipeline()
        vggt_model = VGGT.from_pretrained("facebook/VGGT-1B").to(das.device)
    return vggt_model
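
# The three getters above lazily build each model once and cache it in a module-level
# global, so repeated Gradio callbacks reuse the same weights. A minimal sketch of the
# same pattern for any expensive object (names here are illustrative only):
#
#   _expensive = None
#   def get_expensive():
#       global _expensive
#       if _expensive is None:
#           _expensive = build_expensive()  # hypothetical constructor
#       return _expensive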

def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
    """Process a motion transfer task."""
    try:
        # Save the uploaded file
        input_video_path = save_uploaded_file(source)
        if input_video_path is None:
            return None, None, None, None, None

        print(f"DEBUG: Repaint option: {mt_repaint_option}")
        print(f"DEBUG: Repaint image: {mt_repaint_image}")

        das = get_das_pipeline()
        video_tensor, fps, is_video = load_media(input_video_path)
        das.fps = fps  # Use the fps returned by load_media

        if not is_video:
            tracking_method = "moge"
            print("Image input detected, using MoGe for tracking video generation.")
        else:
            tracking_method = "cotracker"

        repaint_img_tensor = None
        if mt_repaint_image is not None:
            repaint_path = save_uploaded_file(mt_repaint_image)
            repaint_img_tensor, _, _ = load_media(repaint_path)
            repaint_img_tensor = repaint_img_tensor[0]
        elif mt_repaint_option == "Yes":
            repainter = FirstFrameRepainter(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
            repaint_img_tensor = repainter.repaint(
                video_tensor[0],
                prompt=prompt,
                depth_path=None
            )

        tracking_tensor = None
        tracking_path = None
        if tracking_method == "moge":
            moge = get_moge_model()
            infer_result = moge.infer(video_tensor[0].to(das.device))  # input: [C, H, W] in range [0, 1]
            H, W = infer_result["points"].shape[0:2]
            pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1)  # [T, H, W, 3]
            poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
            pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H * W, 3)
            cam_motion = CameraMotionGenerator(None)
            cam_motion.set_intr(infer_result["intrinsics"])
            pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3])  # [T, H, W, 3]
            tracking_path, tracking_tensor = das.visualize_tracking_moge(
                pred_tracks.cpu().numpy(),
                infer_result["mask"].cpu().numpy()
            )
            print('Exported tracking video via MoGe')
        else:
            # Use CoTracker
            pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
            tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
            print('Exported tracking video via CoTracker')

        # Return the intermediate results without applying tracking yet
        return tracking_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
    except Exception as e:
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None, None, None, None, None
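
# Illustrative call (not executed here): the process_* helpers can also be driven with a
# local file path, because save_uploaded_file passes plain string paths through unchanged.
# All of them return the same 5-tuple that is later mapped onto the Gradio outputs
# [tracking_video, video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]:
#
#   tracking_path, video_t, tracking_t, repaint_t, fps = process_motion_transfer(
#       "samples/sample1_raw.mp4",   # hypothetical source path
#       "a rocket lifts off from the table...",
#       "No",                        # mt_repaint_option
#       None,                        # mt_repaint_image
#   )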

def process_camera_control(source, prompt, camera_motion, tracking_method):
    """Process a camera control task."""
    try:
        # Save the uploaded file
        input_media_path = save_uploaded_file(source)
        if input_media_path is None:
            return None, None, None, None, None

        print(f"DEBUG: Camera motion: '{camera_motion}'")
        print(f"DEBUG: Tracking method: '{tracking_method}'")

        das = get_das_pipeline()
        video_tensor, fps, is_video = load_media(input_media_path)
        das.fps = fps  # Use the fps returned by load_media

        if not is_video:
            tracking_method = "moge"
            print("Image input detected, switching to MoGe")

        cam_motion = CameraMotionGenerator(camera_motion)
        repaint_img_tensor = None
        tracking_tensor = None
        tracking_path = None

        if tracking_method == "moge":
            moge = get_moge_model()
            infer_result = moge.infer(video_tensor[0].to(das.device))  # input: [C, H, W] in range [0, 1]
            H, W = infer_result["points"].shape[0:2]
            pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1)  # [T, H, W, 3]
            cam_motion.set_intr(infer_result["intrinsics"])

            if camera_motion:
                poses = cam_motion.get_default_motion()  # shape: [49, 4, 4]
                print("Camera motion applied")
            else:
                poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)

            pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H * W, 3)
            pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3])  # [T, H, W, 3]
            tracking_path, tracking_tensor = das.visualize_tracking_moge(
                pred_tracks.cpu().numpy(),
                infer_result["mask"].cpu().numpy()
            )
            print('Exported tracking video via MoGe')
        else:
            # Use CoTracker (runs on CPU)
            pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
            # Use the wrapped VGGT helper to estimate camera parameters
            extr, intr = process_vggt(video_tensor)
            cam_motion.set_intr(intr)
            cam_motion.set_extr(extr)

            if camera_motion:
                poses = cam_motion.get_default_motion()  # shape: [49, 4, 4]
                pred_tracks_world = cam_motion.s2w_vggt(pred_tracks, extr, intr)
                pred_tracks = cam_motion.w2s_vggt(pred_tracks_world, extr, intr, poses)  # [T, N, 3]
                print("Camera motion applied")

            tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
            print('Exported tracking video via CoTracker')

        # Return the intermediate results without applying tracking yet
        return tracking_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
    except Exception as e:
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None, None, None, None, None

def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
    """Process an object manipulation task."""
    try:
        # Save the uploaded files
        input_image_path = save_uploaded_file(source)
        if input_image_path is None:
            return None, None, None, None, None

        object_mask_path = save_uploaded_file(object_mask)
        if object_mask_path is None:
            print("Object mask not provided")
            return None, None, None, None, None

        das = get_das_pipeline()
        video_tensor, fps, is_video = load_media(input_image_path)
        das.fps = fps  # Use the fps returned by load_media

        if not is_video:
            tracking_method = "moge"
            print("Image input detected, switching to MoGe")

        mask_image = Image.open(object_mask_path).convert('L')
        mask_image = transforms.Resize((480, 720))(mask_image)
        mask = torch.from_numpy(np.array(mask_image) > 127)

        motion_generator = ObjectMotionGenerator(device=das.device)
        repaint_img_tensor = None
        tracking_tensor = None
        tracking_path = None

        if tracking_method == "moge":
            moge = get_moge_model()
            infer_result = moge.infer(video_tensor[0].to(das.device))  # input: [C, H, W] in range [0, 1]
            H, W = infer_result["points"].shape[0:2]
            pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1)  # [T, H, W, 3]

            pred_tracks = motion_generator.apply_motion(
                pred_tracks=pred_tracks,
                mask=mask,
                motion_type=object_motion,
                distance=50,
                num_frames=49,
                tracking_method="moge"
            )
            print(f"Object motion '{object_motion}' applied using the provided mask")

            poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
            pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H * W, 3)
            cam_motion = CameraMotionGenerator(None)
            cam_motion.set_intr(infer_result["intrinsics"])
            pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3])  # [T, H, W, 3]
            tracking_path, tracking_tensor = das.visualize_tracking_moge(
                pred_tracks.cpu().numpy(),
                infer_result["mask"].cpu().numpy()
            )
            print('Exported tracking video via MoGe')
        else:
            # Use CoTracker (runs on CPU)
            pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
            # Use the wrapped VGGT helper to estimate camera parameters
            extr, intr = process_vggt(video_tensor)

            pred_tracks = motion_generator.apply_motion(
                pred_tracks=pred_tracks.squeeze(),
                mask=mask,
                motion_type=object_motion,
                distance=50,
                num_frames=49,
                tracking_method="cotracker"
            )
            print(f"Object motion '{object_motion}' applied using the provided mask")

            tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks.unsqueeze(0), pred_visibility)
            print('Exported tracking video via CoTracker')

        # Return the intermediate results without applying tracking yet
        return tracking_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
    except Exception as e:
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None, None, None, None, None

def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
    """Process a mesh animation task."""
    try:
        # Save the uploaded files
        input_video_path = save_uploaded_file(source)
        if input_video_path is None:
            return None, None, None, None, None

        tracking_video_path = save_uploaded_file(tracking_video)
        if tracking_video_path is None:
            return None, None, None, None, None

        das = get_das_pipeline()
        video_tensor, fps, is_video = load_media(input_video_path)
        das.fps = fps  # Use the fps returned by load_media
        tracking_tensor, tracking_fps, _ = load_media(tracking_video_path)

        repaint_img_tensor = None
        if ma_repaint_image is not None:
            repaint_path = save_uploaded_file(ma_repaint_image)
            repaint_img_tensor, _, _ = load_media(repaint_path)
            repaint_img_tensor = repaint_img_tensor[0]  # Take the first frame
        elif ma_repaint_option == "Yes":
            repainter = FirstFrameRepainter(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
            repaint_img_tensor = repainter.repaint(
                video_tensor[0],
                prompt=prompt,
                depth_path=None
            )

        # Return the uploaded tracking video path directly instead of generating a new one
        return tracking_video_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
    except Exception as e:
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None, None, None, None, None

def generate_tracking_cotracker(video_tensor, density=30):
    """Generate a tracking video on the CPU, using only the first frame's depth and batched tensor ops.

    Args:
        video_tensor (torch.Tensor): Input video tensor [T, C, H, W]
        density (int): Grid size controlling the density of tracked points

    Returns:
        tuple: (pred_tracks, pred_visibility)
    """
    cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to("cpu")
    depth_preprocessor = DepthPreprocessor.from_pretrained("Intel/zoedepth-nyu-kitti").to("cpu")

    video = video_tensor.unsqueeze(0).to("cpu")

    # Process only the first frame to obtain a depth map
    print("estimating depth for first frame...")
    frame = (video_tensor[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    depth = depth_preprocessor(Image.fromarray(frame))[0]
    depth_tensor = transforms.ToTensor()(depth)  # [1, H, W]

    # Get tracked points and their visibility
    print("tracking on CPU...")
    pred_tracks, pred_visibility = cotracker(video, grid_size=density)  # [B, T, N, 2], [B, T, N, 1]

    # Extract dimensions
    B, T, N, _ = pred_tracks.shape
    H, W = depth_tensor.shape[1], depth_tensor.shape[2]

    # Create the output tensor with an extra depth channel
    pred_tracks_with_depth = torch.zeros((B, T, N, 3), device="cpu")
    pred_tracks_with_depth[:, :, :, :2] = pred_tracks  # Copy the x, y coordinates

    # Look up depth for all frames and points at once with tensor indexing
    # Flatten pred_tracks to [B*T*N, 2] for easier indexing
    flat_tracks = pred_tracks.reshape(-1, 2)

    # Clamp coordinates to the valid image bounds
    x_coords = flat_tracks[:, 0].clamp(0, W - 1).long()
    y_coords = flat_tracks[:, 1].clamp(0, H - 1).long()

    # Read depth values for all points from the first frame's depth map
    depths = depth_tensor[0, y_coords, x_coords]

    # Reshape back and assign to the output tensor
    pred_tracks_with_depth[:, :, :, 2] = depths.reshape(B, T, N)

    del cotracker, depth_preprocessor

    # Return the results without the batch dimension
    return pred_tracks_with_depth.squeeze(0), pred_visibility.squeeze(0)
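
# Shape sketch (illustrative): with the default 49-frame, 480x720 input and density=30,
# CoTracker tracks roughly a 30x30 grid of points, so
#
#   pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
#   # pred_tracks.shape ~= (49, 900, 3)  -> (x, y, depth-from-first-frame) per point
#   # pred_visibility has the matching (49, N, ...) layout, depending on the CoTracker version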

def apply_tracking_unified(video_tensor, tracking_tensor, repaint_img_tensor, prompt, fps):
    """Unified apply-tracking step shared by all tasks."""
    try:
        if video_tensor is None or tracking_tensor is None:
            return None

        das = get_das_pipeline()
        output_path = das.apply_tracking(
            video_tensor=video_tensor,
            fps=fps,
            tracking_tensor=tracking_tensor,
            img_cond_tensor=repaint_img_tensor,
            prompt=prompt,
            checkpoint_path=DEFAULT_MODEL_PATH
        )
        print(f"Generated video path: {output_path}")

        # Make sure an absolute path is returned
        if output_path and not os.path.isabs(output_path):
            output_path = os.path.abspath(output_path)

        # Check that the file exists
        if output_path and os.path.exists(output_path):
            print(f"File exists, size: {os.path.getsize(output_path)} bytes")
            return output_path
        else:
            print(f"Warning: output file missing or path invalid: {output_path}")
            return None
    except Exception as e:
        import traceback
        print(f"Apply tracking failed: {str(e)}\n{traceback.format_exc()}")
        return None

# Defined after apply_tracking_unified and before the Gradio interface definition
def enable_apply_button(tracking_result):
    """Enable the apply button once a tracking video has been generated."""
    if tracking_result is not None:
        return gr.update(interactive=True)
    return gr.update(interactive=False)

def process_vggt(video_tensor):
    """Estimate per-frame camera extrinsics and intrinsics with VGGT."""
    vggt_model = get_vggt_model()

    t, c, h, w = video_tensor.shape
    new_width = 518
    new_height = round(h * (new_width / w) / 14) * 14
    resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
    video_vggt = resize_transform(video_tensor)  # [T, C, H, W]

    if new_height > 518:
        start_y = (new_height - 518) // 2
        video_vggt = video_vggt[:, :, start_y:start_y + 518, :]

    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=torch.float16):
            video_vggt = video_vggt.unsqueeze(0)  # [1, T, C, H, W]
            aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to("cuda"))
            extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])

    return extr, intr
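
# Sizing sketch (illustrative): for the default 480x720 frames,
#   new_width  = 518
#   new_height = round(480 * (518 / 720) / 14) * 14 = 25 * 14 = 350
# so frames become 350x518 (a multiple of the 14-pixel patch size) and the vertical
# center crop above is only triggered for tall inputs where new_height > 518.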

def load_examples():
    """Load example file paths from the samples directory."""
    samples_dir = os.path.join(project_root, "samples")
    if not os.path.exists(samples_dir):
        print(f"Warning: Samples directory not found at {samples_dir}")
        return []

    examples_list = []

    # Build one example item per sample set
    # Example 1
    example1 = [None] * 5  # [source, repaint_image, prompt, tracking_video, result_video]
    for filename in os.listdir(samples_dir):
        if filename.startswith("sample1_"):
            if filename.endswith("_raw.mp4"):
                example1[0] = os.path.join(samples_dir, filename)
            elif filename.endswith("_repaint.png"):
                example1[1] = os.path.join(samples_dir, filename)
            elif filename.endswith("_tracking.mp4"):
                example1[3] = os.path.join(samples_dir, filename)
            elif filename.endswith("_result.mp4"):
                example1[4] = os.path.join(samples_dir, filename)
    # Prompt text for example 1
    example1[2] = "a rocket lifts off from the table and smoke erupts from its bottom."

    # Example 2
    example2 = [None] * 5  # [source, repaint_image, prompt, tracking_video, result_video]
    for filename in os.listdir(samples_dir):
        if filename.startswith("sample2_"):
            if filename.endswith("_raw.mp4"):
                example2[0] = os.path.join(samples_dir, filename)
            elif filename.endswith("_repaint.png"):
                example2[1] = os.path.join(samples_dir, filename)
            elif filename.endswith("_tracking.mp4"):
                example2[3] = os.path.join(samples_dir, filename)
            elif filename.endswith("_result.mp4"):
                example2[4] = os.path.join(samples_dir, filename)
    # Prompt text for example 2
    example2[2] = "A wonderful bright old-fashioned red car is driving from left to right. Sunlight is shining on the car, its reflection glittering. In the background is a deserted city at noon; the roads and buildings are covered with green vegetation."

    # Add the examples to the list
    if example1[0] is not None and example1[3] is not None:
        examples_list.append(example1)
    if example2[0] is not None and example2[3] is not None:
        examples_list.append(example2)

    # Add any other sample sets found in the directory
    sample_prefixes = set()
    for filename in os.listdir(samples_dir):
        if filename.endswith(('.mp4', '.png')):
            prefix = filename.split('_')[0]
            if prefix not in ["sample1", "sample2"]:
                sample_prefixes.add(prefix)

    for prefix in sorted(sample_prefixes):
        example = [None] * 5  # [source, repaint_image, prompt, tracking_video, result_video]
        for filename in os.listdir(samples_dir):
            if filename.startswith(f"{prefix}_"):
                if filename.endswith("_raw.mp4"):
                    example[0] = os.path.join(samples_dir, filename)
                elif filename.endswith("_repaint.png"):
                    example[1] = os.path.join(samples_dir, filename)
                elif filename.endswith("_tracking.mp4"):
                    example[3] = os.path.join(samples_dir, filename)
                elif filename.endswith("_result.mp4"):
                    example[4] = os.path.join(samples_dir, filename)
        # Default prompt text
        example[2] = "A beautiful scene"
        # Only add the example if at least a source file and a tracking video are present
        if example[0] is not None and example[3] is not None:
            examples_list.append(example)

    return examples_list
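
# Expected samples/ layout (illustrative): each example set shares a prefix and uses the
# suffixes matched above, e.g.
#
#   samples/sample1_raw.mp4       -> source video
#   samples/sample1_repaint.png   -> optional repainted first frame
#   samples/sample1_tracking.mp4  -> tracking video
#   samples/sample1_result.mp4    -> cached result video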

# Create the Gradio interface with the updated layout
with gr.Blocks(title="Diffusion as Shader") as demo:
    gr.Markdown("# Diffusion as Shader Web UI")
    gr.Markdown("### [Project Page](https://igl-hkust.github.io/das/) | [GitHub](https://github.com/IGL-HKUST/DiffusionAsShader)")

    # Hidden state variables for intermediate results
    video_tensor_state = gr.State(None)
    tracking_tensor_state = gr.State(None)
    repaint_img_tensor_state = gr.State(None)
    fps_state = gr.State(None)

    with gr.Row():
        left_column = gr.Column(scale=1)
        right_column = gr.Column(scale=1)

    with right_column:
        tracking_video = gr.Video(label="Tracking Video")
        # The button starts disabled until a tracking video is available
        apply_tracking_btn = gr.Button("Generate Video", variant="primary", size="lg", interactive=False)
        output_video = gr.Video(label="Generated Video")

    with left_column:
        source_upload = gr.UploadButton("1. Upload Source", file_types=["image", "video"])
        source_preview = gr.Video(label="Source Preview")
        gr.Markdown("Upload a video or image; we will extract the motion and spatial structure from it.")

        # Update the preview after a file is uploaded
        def update_source_preview(file):
            if file is None:
                return None
            path = save_uploaded_file(file)
            return path

        source_upload.upload(
            fn=update_source_preview,
            inputs=[source_upload],
            outputs=[source_preview]
        )

        common_prompt = gr.Textbox(label="2. Prompt: Describe the scene and the motion you want to create", lines=2)
        gr.Markdown(f"**Using GPU: {GPU_ID}**")

        with gr.Tabs() as task_tabs:
            # Motion Transfer tab
            with gr.TabItem("Motion Transfer"):
                gr.Markdown("## Motion Transfer")

                # Simplified controls: a Yes/No radio button plus a separate file upload
                with gr.Row():
                    mt_repaint_option = gr.Radio(
                        label="Repaint First Frame",
                        choices=["No", "Yes"],
                        value="No"
                    )
                gr.Markdown("### Note: If you want to use your own image as the repainted first frame, upload it below.")
                mt_repaint_upload = gr.UploadButton("3. Upload Repaint Image (Optional)", file_types=["image"])
                mt_repaint_preview = gr.Image(label="Repaint Image Preview")

                # Update the preview after a file is uploaded
                mt_repaint_upload.upload(
                    fn=update_source_preview,  # Reuse the same helper
                    inputs=[mt_repaint_upload],
                    outputs=[mt_repaint_preview]
                )

                # Run button for the Motion Transfer tab
                mt_run_btn = gr.Button("Generate Tracking", variant="primary", size="lg")

                # Connect to the processing function; tracking is not applied yet
                mt_run_btn.click(
                    fn=process_motion_transfer,
                    inputs=[
                        source_upload, common_prompt,
                        mt_repaint_option, mt_repaint_upload
                    ],
                    outputs=[tracking_video, video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
                ).then(
                    fn=enable_apply_button,
                    inputs=[tracking_video],
                    outputs=[apply_tracking_btn]
                )

            # # Camera Control tab
            # with gr.TabItem("Camera Control"):
            #     gr.Markdown("## Camera Control")
            #     cc_camera_motion = gr.Textbox(
            #         label="Current Camera Motion Sequence",
            #         placeholder="Your camera motion sequence will appear here...",
            #         interactive=False
            #     )
            #     # Use tabs for different motion types
            #     with gr.Tabs() as cc_motion_tabs:
            #         # Translation tab
            #         with gr.TabItem("Translation (trans)"):
            #             with gr.Row():
            #                 cc_trans_x = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="X-axis Movement")
            #                 cc_trans_y = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Y-axis Movement")
            #                 cc_trans_z = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Z-axis Movement (depth)")
            #             with gr.Row():
            #                 cc_trans_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
            #                 cc_trans_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)
            #             cc_trans_note = gr.Markdown("""
            #             **Translation Notes:**
            #             - Positive X: Move right, Negative X: Move left
            #             - Positive Y: Move down, Negative Y: Move up
            #             - Positive Z: Zoom in, Negative Z: Zoom out
            #             """)
            #             # Add translation button in the Translation tab
            #             cc_add_trans = gr.Button("Add Camera Translation", variant="secondary")
            #             # Function to add translation motion
            #             def add_translation_motion(current_motion, trans_x, trans_y, trans_z, trans_start, trans_end):
            #                 # Format: trans dx dy dz [start_frame end_frame]
            #                 frame_range = f" {int(trans_start)} {int(trans_end)}" if trans_start != 0 or trans_end != 48 else ""
            #                 new_motion = f"trans {trans_x:.2f} {trans_y:.2f} {trans_z:.2f}{frame_range}"
            #                 # Append to the existing motion string with a semicolon separator if needed
            #                 if current_motion and current_motion.strip():
            #                     updated_motion = f"{current_motion}; {new_motion}"
            #                 else:
            #                     updated_motion = new_motion
            #                 return updated_motion
            #             # Connect translation button
            #             cc_add_trans.click(
            #                 fn=add_translation_motion,
            #                 inputs=[
            #                     cc_camera_motion,
            #                     cc_trans_x, cc_trans_y, cc_trans_z, cc_trans_start, cc_trans_end
            #                 ],
            #                 outputs=[cc_camera_motion]
            #             )
            #         # Rotation tab
            #         with gr.TabItem("Rotation (rot)"):
            #             with gr.Row():
            #                 cc_rot_axis = gr.Dropdown(choices=["x", "y", "z"], value="y", label="Rotation Axis")
            #                 cc_rot_angle = gr.Slider(minimum=-30, maximum=30, value=5, step=1, label="Rotation Angle (degrees)")
            #             with gr.Row():
            #                 cc_rot_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
            #                 cc_rot_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)
            #             cc_rot_note = gr.Markdown("""
            #             **Rotation Notes:**
            #             - X-axis rotation: Tilt camera up/down
            #             - Y-axis rotation: Pan camera left/right
            #             - Z-axis rotation: Roll camera
            #             """)
            #             # Add rotation button in the Rotation tab
            #             cc_add_rot = gr.Button("Add Camera Rotation", variant="secondary")
            #             # Function to add rotation motion
            #             def add_rotation_motion(current_motion, rot_axis, rot_angle, rot_start, rot_end):
            #                 # Format: rot axis angle [start_frame end_frame]
            #                 frame_range = f" {int(rot_start)} {int(rot_end)}" if rot_start != 0 or rot_end != 48 else ""
            #                 new_motion = f"rot {rot_axis} {rot_angle}{frame_range}"
            #                 # Append to the existing motion string with a semicolon separator if needed
            #                 if current_motion and current_motion.strip():
            #                     updated_motion = f"{current_motion}; {new_motion}"
            #                 else:
            #                     updated_motion = new_motion
            #                 return updated_motion
            #             # Connect rotation button
            #             cc_add_rot.click(
            #                 fn=add_rotation_motion,
            #                 inputs=[
            #                     cc_camera_motion,
            #                     cc_rot_axis, cc_rot_angle, cc_rot_start, cc_rot_end
            #                 ],
            #                 outputs=[cc_camera_motion]
            #             )
            #     # Add a clear button to reset the motion sequence
            #     cc_clear_motion = gr.Button("Clear All Motions", variant="stop")
            #     def clear_camera_motion():
            #         return ""
            #     cc_clear_motion.click(
            #         fn=clear_camera_motion,
            #         inputs=[],
            #         outputs=[cc_camera_motion]
            #     )
            #     cc_tracking_method = gr.Radio(
            #         label="Tracking Method",
            #         choices=["moge", "cotracker"],
            #         value="cotracker"
            #     )
            #     # Add run button for the Camera Control tab
            #     cc_run_btn = gr.Button("Generate Tracking", variant="primary", size="lg")
            #     # Connect to the processing function; tracking is not applied yet
            #     cc_run_btn.click(
            #         fn=process_camera_control,
            #         inputs=[
            #             source_upload, common_prompt,
            #             cc_camera_motion, cc_tracking_method
            #         ],
            #         outputs=[tracking_video, video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
            #     ).then(
            #         fn=enable_apply_button,
            #         inputs=[tracking_video],
            #         outputs=[apply_tracking_btn]
            #     )

            # # Object Manipulation tab
            # with gr.TabItem("Object Manipulation"):
            #     gr.Markdown("## Object Manipulation")
            #     om_object_mask = gr.File(
            #         label="Object Mask Image",
            #         file_types=["image"]
            #     )
            #     gr.Markdown("Upload a binary mask image; white areas indicate the object to manipulate")
            #     om_object_motion = gr.Dropdown(
            #         label="Object Motion Type",
            #         choices=["up", "down", "left", "right", "front", "back", "rot"],
            #         value="up"
            #     )
            #     om_tracking_method = gr.Radio(
            #         label="Tracking Method",
            #         choices=["moge", "cotracker"],
            #         value="cotracker"
            #     )
            #     # Add run button for the Object Manipulation tab
            #     om_run_btn = gr.Button("Generate Tracking", variant="primary", size="lg")
            #     # Connect to the processing function; tracking is not applied yet
            #     om_run_btn.click(
            #         fn=process_object_manipulation,
            #         inputs=[
            #             source_upload, common_prompt,
            #             om_object_motion, om_object_mask, om_tracking_method
            #         ],
            #         outputs=[tracking_video, video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
            #     ).then(
            #         fn=enable_apply_button,
            #         inputs=[tracking_video],
            #         outputs=[apply_tracking_btn]
            #     )

            # # Animating meshes to video tab
            # with gr.TabItem("Animating meshes to video"):
            #     gr.Markdown("## Mesh Animation to Video")
            #     gr.Markdown("""
            #     Note: Currently only supports tracking videos generated with Blender (version > 4.0).
            #     Please run the script `scripts/blender.py` in your Blender project to generate tracking videos.
            #     """)
            #     ma_tracking_video = gr.File(
            #         label="Tracking Video",
            #         file_types=["video"],
            #         # A change handler below activates the Generate Video button when a file is uploaded
            #         elem_id="ma_tracking_video"
            #     )
            #     gr.Markdown("Tracking video needs to be generated from Blender")
            #     # Simplified controls: a Yes/No radio button plus a separate file upload
            #     with gr.Row():
            #         ma_repaint_option = gr.Radio(
            #             label="Repaint First Frame",
            #             choices=["No", "Yes"],
            #             value="No"
            #         )
            #     gr.Markdown("### Note: If you want to use your own image as the repainted first frame, upload it below.")
            #     # Custom image uploader (always visible)
            #     ma_repaint_image = gr.File(
            #         label="Custom Repaint Image",
            #         file_types=["image"]
            #     )
            #     # Button renamed to "Apply Repaint"
            #     ma_run_btn = gr.Button("Apply Repaint", variant="primary", size="lg")
            #     # Handle the tracking video upload
            #     def handle_tracking_upload(file):
            #         if file is not None:
            #             tracking_path = save_uploaded_file(file)
            #             if tracking_path:
            #                 return tracking_path, gr.update(interactive=True)
            #         return None, gr.update(interactive=False)
            #     # When a tracking video is uploaded, show it directly and enable the Generate Video button
            #     ma_tracking_video.change(
            #         fn=handle_tracking_upload,
            #         inputs=[ma_tracking_video],
            #         outputs=[tracking_video, apply_tracking_btn]
            #     )
            #     # Variant of process_mesh_animation that only handles repainting
            #     def process_mesh_animation_repaint(source, prompt, ma_repaint_option, ma_repaint_image):
            #         """Handle only the repaint step, not the tracking video."""
            #         try:
            #             # Save the uploaded file
            #             input_video_path = save_uploaded_file(source)
            #             if input_video_path is None:
            #                 return None, None, None, None
            #             das = get_das_pipeline()
            #             video_tensor, fps, is_video = load_media(input_video_path)
            #             das.fps = fps
            #             repaint_img_tensor = None
            #             if ma_repaint_image is not None:
            #                 repaint_path = save_uploaded_file(ma_repaint_image)
            #                 repaint_img_tensor, _, _ = load_media(repaint_path)
            #                 repaint_img_tensor = repaint_img_tensor[0]
            #             elif ma_repaint_option == "Yes":
            #                 repainter = FirstFrameRepainter(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
            #                 repaint_img_tensor = repainter.repaint(
            #                     video_tensor[0],
            #                     prompt=prompt,
            #                     depth_path=None
            #                 )
            #             # Return the processed results, excluding the tracking video path
            #             return video_tensor, None, repaint_img_tensor, fps
            #         except Exception as e:
            #             import traceback
            #             print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
            #             return None, None, None, None
            #     # Connect to the modified processing function
            #     ma_run_btn.click(
            #         fn=process_mesh_animation_repaint,
            #         inputs=[
            #             source_upload, common_prompt,
            #             ma_repaint_option, ma_repaint_image
            #         ],
            #         outputs=[video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
            #     )
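
    # Note: this listing does not show a click handler for apply_tracking_btn. A minimal
    # sketch of how it could be wired to the shared state (an assumption, kept commented out):
    #
    #   apply_tracking_btn.click(
    #       fn=apply_tracking_unified,
    #       inputs=[video_tensor_state, tracking_tensor_state,
    #               repaint_img_tensor_state, common_prompt, fps_state],
    #       outputs=[output_video]
    #   )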

    # After all UI elements are defined, add the Examples component
    examples_list = load_examples()
    if examples_list:
        with gr.Blocks() as examples_block:
            gr.Examples(
                examples=examples_list,
                inputs=[source_preview, mt_repaint_preview, common_prompt, tracking_video, output_video],
                outputs=[source_preview, mt_repaint_preview, common_prompt, tracking_video, output_video],
                fn=lambda *args: args,  # Simply echo the inputs as outputs
                cache_examples=True,
                label="Examples"
            )

# Launch the interface
if __name__ == "__main__":
    print(f"Using GPU: {GPU_ID}")
    print(f"Web UI will start on port {args.port}")
    if args.share:
        print("Creating public link for remote access")

    demo.launch(share=args.share, server_port=args.port)