Commit: update

Files changed:
- .gitignore (+2 -1)
- app.py (+469 -281)
.gitignore
CHANGED

@@ -197,4 +197,5 @@ slurm-*.out
 .vscode
 
 data/
-tmp/
+tmp/
+.gradio/
app.py
CHANGED

@@ -33,6 +33,9 @@ from submodules.MoGe.moge.model import MoGeModel
 from submodules.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
 from submodules.vggt.vggt.models.vggt import VGGT
 
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
+
 # Parse command line arguments
 parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI")
 parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
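The new `torch._dynamo.config.suppress_errors = True` setting makes TorchDynamo log and fall back to eager execution when compilation fails instead of raising, a common workaround on Spaces hardware. A minimal, self-contained sketch of the behaviour (not part of the app):

```python
# Minimal sketch: with suppress_errors enabled, a TorchDynamo compilation
# failure inside torch.compile falls back to eager mode instead of raising.
import torch
import torch._dynamo

torch._dynamo.config.suppress_errors = True

@torch.compile
def double(x):
    return x * 2

print(double(torch.ones(3)))  # tensor([2., 2., 2.]), eagerly if compilation failed
```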
@@ -82,9 +85,9 @@ def load_media(media_path, max_frames=49, transform=None):
 
     # Case 1: Video longer than 6 seconds, sample first 6 seconds + 1 frame
     if duration > 6.0:
-        [removed line lost in the page extraction]
-        frames = load_video(media_path,
-        fps =
+        # use the max_frames parameter instead of sampling_fps
+        frames = load_video(media_path, max_frames=max_frames)
+        fps = max_frames / 6.0  # compute the equivalent fps
     # Cases 2 and 3: Video shorter than 6 seconds
     else:
         # Load all frames
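With the default `max_frames=49` from the hunk header, the new branch therefore reports an equivalent frame rate of about 8.17 fps for any clip longer than six seconds:

```python
# Worked example of the new fps computation for a >6 s video:
max_frames = 49
fps = max_frames / 6.0
print(round(fps, 2))  # 8.17 -- 49 sampled frames spanning 6 seconds
```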
@@ -195,10 +198,10 @@ def get_vggt_model():
 def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
     """Process video motion transfer task"""
     try:
-        #
+        # save the uploaded file
         input_video_path = save_uploaded_file(source)
         if input_video_path is None:
-            return None, None
+            return None, None, None
 
         print(f"DEBUG: Repaint option: {mt_repaint_option}")
         print(f"DEBUG: Repaint image: {mt_repaint_image}")
@@ -253,28 +256,20 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image)
         tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
         print('Export tracking video via cotracker')
 
-        output_path = das.apply_tracking(
-            video_tensor=video_tensor,
-            fps=fps,  # use the fps returned by load_media
-            tracking_tensor=tracking_tensor,
-            img_cond_tensor=repaint_img_tensor,
-            prompt=prompt,
-            checkpoint_path=DEFAULT_MODEL_PATH
-        )
-
-        return tracking_path, output_path
+        # return the processing results without applying tracking yet
+        return tracking_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None, None
+        return None, None, None, None, None
 
 def process_camera_control(source, prompt, camera_motion, tracking_method):
     """Process camera control task"""
     try:
-        #
+        # save the uploaded file
         input_media_path = save_uploaded_file(source)
         if input_media_path is None:
-            return None, None
+            return None, None, None
 
         print(f"DEBUG: Camera motion: '{camera_motion}'")
         print(f"DEBUG: Tracking method: '{tracking_method}'")
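After this hunk, each `process_*` handler stops short of the diffusion step: it returns the tracking preview path plus the raw tensors, and the heavy generation is deferred to `apply_tracking_unified` (added further below in this diff). A hedged sketch of the resulting call contract, with `source` and `prompt` standing in as hypothetical inputs:

```python
# Hypothetical two-step usage of the refactored handlers (names from this diff):
tracking_path, video_tensor, tracking_tensor, repaint_img_tensor, fps = \
    process_motion_transfer(source, prompt, "No", None)

if tracking_path is not None:
    # the GPU-heavy diffusion step now runs separately, on demand
    output_path = apply_tracking_unified(
        video_tensor, tracking_tensor, repaint_img_tensor, prompt, fps)
```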
@@ -317,24 +312,8 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
         # use the cotracker that runs on the CPU
         pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
 
-        t, c, h, w = video_tensor.shape
-        new_width = 518
-        new_height = round(h * (new_width / w) / 14) * 14
-        resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
-        video_vggt = resize_transform(video_tensor)  # [T, C, H, W]
-
-        if new_height > 518:
-            start_y = (new_height - 518) // 2
-            video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
-
-        vggt_model = get_vggt_model()
-
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(dtype=das.dtype):
-                video_vggt = video_vggt.unsqueeze(0)  # [1, T, C, H, W]
-                aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
-
-        extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
+        # use the wrapped VGGT processing function
+        extr, intr = process_vggt(video_tensor)
 
         cam_motion.set_intr(intr)
         cam_motion.set_extr(extr)
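The removed block (now wrapped in `process_vggt`) resizes the clip to a 518-pixel width and snaps the height to a multiple of 14, presumably the ViT patch size VGGT expects. A worked example of the arithmetic, assuming a hypothetical 854x480 input:

```python
# Worked example of the VGGT resize arithmetic (854x480 input assumed):
h, w, new_width = 480, 854, 518
new_height = round(h * (new_width / w) / 14) * 14
print(new_height)  # 294 -- a multiple of 14; since 294 <= 518, no center crop is taken
```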
@@ -345,23 +324,15 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
         pred_tracks = cam_motion.w2s_vggt(pred_tracks_world, extr, intr, poses)  # [T, N, 3]
         print("Camera motion applied")
 
-        tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks,
+        tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
         print('Export tracking video via cotracker')
 
-        output_path = das.apply_tracking(
-            video_tensor=video_tensor,
-            fps=fps,  # use the fps returned by load_media
-            tracking_tensor=tracking_tensor,
-            img_cond_tensor=repaint_img_tensor,
-            prompt=prompt,
-            checkpoint_path=DEFAULT_MODEL_PATH
-        )
-
-        return tracking_path, output_path
+        # return the processing results without applying tracking yet
+        return tracking_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None, None
+        return None, None, None, None, None
 
 def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
     """Process object manipulation task"""
@@ -369,12 +340,12 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
         # Save uploaded files
         input_image_path = save_uploaded_file(source)
         if input_image_path is None:
-            return None, None
+            return None, None, None, None, None
 
         object_mask_path = save_uploaded_file(object_mask)
         if object_mask_path is None:
             print("Object mask not provided")
-            return None, None
+            return None, None, None, None, None
 
         das = get_das_pipeline()
         video_tensor, fps, is_video = load_media(input_image_path)
@@ -424,24 +395,8 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
         # use the cotracker that runs on the CPU
         pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
 
-        t, c, h, w = video_tensor.shape
-        new_width = 518
-        new_height = round(h * (new_width / w) / 14) * 14
-        resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
-        video_vggt = resize_transform(video_tensor)  # [T, C, H, W]
-
-        if new_height > 518:
-            start_y = (new_height - 518) // 2
-            video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
-
-        vggt_model = get_vggt_model()
-
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(dtype=das.dtype):
-                video_vggt = video_vggt.unsqueeze(0)  # [1, T, C, H, W]
-                aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
-
-        extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
+        # use the wrapped VGGT processing function
+        extr, intr = process_vggt(video_tensor)
 
         pred_tracks = motion_generator.apply_motion(
             pred_tracks=pred_tracks.squeeze(),
@@ -453,23 +408,15 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
         )
         print(f"Object motion '{object_motion}' applied using provided mask")
 
-        tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks.unsqueeze(0),
+        tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks.unsqueeze(0), pred_visibility)
         print('Export tracking video via cotracker')
 
-        output_path = das.apply_tracking(
-            video_tensor=video_tensor,
-            fps=fps,  # use the fps returned by load_media
-            tracking_tensor=tracking_tensor,
-            img_cond_tensor=repaint_img_tensor,
-            prompt=prompt,
-            checkpoint_path=DEFAULT_MODEL_PATH
-        )
-
-        return tracking_path, output_path
+        # return the processing results without applying tracking yet
+        return tracking_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None, None
+        return None, None, None, None, None
 
 def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
     """Process mesh animation task"""
@@ -477,11 +424,11 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma
         # Save uploaded files
         input_video_path = save_uploaded_file(source)
         if input_video_path is None:
-            return None, None
+            return None, None, None, None, None
 
         tracking_video_path = save_uploaded_file(tracking_video)
         if tracking_video_path is None:
-            return None, None
+            return None, None, None, None, None
 
         das = get_das_pipeline()
         video_tensor, fps, is_video = load_media(input_video_path)
@@ -494,7 +441,6 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma
             repaint_img_tensor, _, _ = load_media(repaint_path)
             repaint_img_tensor = repaint_img_tensor[0]  # take the first frame
         elif ma_repaint_option == "Yes":
-
             repainter = FirstFrameRepainter(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
             repaint_img_tensor = repainter.repaint(
                 video_tensor[0],
@@ -502,20 +448,12 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma
                 depth_path=None
             )
 
-        output_path = das.apply_tracking(
-            video_tensor=video_tensor,
-            fps=fps,  # use the fps returned by load_media
-            tracking_tensor=tracking_tensor,
-            img_cond_tensor=repaint_img_tensor,
-            prompt=prompt,
-            checkpoint_path=DEFAULT_MODEL_PATH
-        )
-
-        return tracking_video_path, output_path
+        # return the uploaded tracking-video path directly instead of generating a new one
+        return tracking_video_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None, None
+        return None, None, None, None, None
 
 def generate_tracking_cotracker(video_tensor, density=30):
     """Generate the tracking video on the CPU, using only the first frame's depth information and matrix operations for efficiency
@@ -569,22 +507,192 @@ def generate_tracking_cotracker(video_tensor, density=30):
     # return the results
     return pred_tracks_with_depth.squeeze(0), pred_visibility.squeeze(0)
 
+@spaces.GPU(duration=240)
+def apply_tracking_unified(video_tensor, tracking_tensor, repaint_img_tensor, prompt, fps):
+    """Unified apply-tracking function"""
+    try:
+        if video_tensor is None or tracking_tensor is None:
+            return None
+
+        das = get_das_pipeline()
+        output_path = das.apply_tracking(
+            video_tensor=video_tensor,
+            fps=fps,
+            tracking_tensor=tracking_tensor,
+            img_cond_tensor=repaint_img_tensor,
+            prompt=prompt,
+            checkpoint_path=DEFAULT_MODEL_PATH
+        )
+
+        print(f"Generated video path: {output_path}")
+
+        # make sure the returned path is absolute
+        if output_path and not os.path.isabs(output_path):
+            output_path = os.path.abspath(output_path)
+
+        # check that the file exists
+        if output_path and os.path.exists(output_path):
+            print(f"File exists, size: {os.path.getsize(output_path)} bytes")
+            return output_path
+        else:
+            print(f"Warning: output file does not exist or the path is invalid: {output_path}")
+            return None
+    except Exception as e:
+        import traceback
+        print(f"Apply tracking failed: {str(e)}\n{traceback.format_exc()}")
+        return None
+
+# added after the apply_tracking_unified function and before the Gradio interface definition
+
+def enable_apply_button(tracking_result):
+    """Enable the apply button once the tracking video has been generated"""
+    if tracking_result is not None:
+        return gr.update(interactive=True)
+    return gr.update(interactive=False)
+
+@spaces.GPU
+def process_vggt(video_tensor):
+    vggt_model = get_vggt_model()
+
+    t, c, h, w = video_tensor.shape
+    new_width = 518
+    new_height = round(h * (new_width / w) / 14) * 14
+    resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
+    video_vggt = resize_transform(video_tensor)  # [T, C, H, W]
+
+    if new_height > 518:
+        start_y = (new_height - 518) // 2
+        video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
+
+    with torch.no_grad():
+        with torch.cuda.amp.autocast(dtype=torch.float16):
+            video_vggt = video_vggt.unsqueeze(0)  # [1, T, C, H, W]
+            aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to("cuda"))
+
+            extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
+
+    return extr, intr
+
+def load_examples():
+    """Load example file paths"""
+    samples_dir = os.path.join(project_root, "samples")
+    if not os.path.exists(samples_dir):
+        print(f"Warning: Samples directory not found at {samples_dir}")
+        return []
+
+    examples_list = []
+
+    # create one example item per sample set
+    # example 1
+    example1 = [None] * 5  # [source, repaint_image, prompt, tracking_video, result_video]
+    for filename in os.listdir(samples_dir):
+        if filename.startswith("sample1_"):
+            if filename.endswith("_raw.mp4"):
+                example1[0] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_repaint.png"):
+                example1[1] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_tracking.mp4"):
+                example1[3] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_result.mp4"):
+                example1[4] = os.path.join(samples_dir, filename)
+
+    # prompt text for example 1
+    example1[2] = "a rocket lifts off from the table and smoke erupt from its bottom."
+
+    # example 2
+    example2 = [None] * 5  # [source, repaint_image, prompt, tracking_video, result_video]
+    for filename in os.listdir(samples_dir):
+        if filename.startswith("sample2_"):
+            if filename.endswith("_raw.mp4"):
+                example2[0] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_repaint.png"):
+                example2[1] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_tracking.mp4"):
+                example2[3] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_result.mp4"):
+                example2[4] = os.path.join(samples_dir, filename)
+
+    # prompt text for example 2
+    example2[2] = "A wonderful bright old-fasion red car is riding from left to right sun light is shining on the car, its reflection glittering. In the background is a deserted city in the noon, the roads and buildings are covered with green vegetation."
+
+    # add the examples to the list
+    if example1[0] is not None and example1[3] is not None:
+        examples_list.append(example1)
+
+    if example2[0] is not None and example2[3] is not None:
+        examples_list.append(example2)
+
+    # add any other examples, if present
+    sample_prefixes = set()
+    for filename in os.listdir(samples_dir):
+        if filename.endswith(('.mp4', '.png')):
+            prefix = filename.split('_')[0]
+            if prefix not in ["sample1", "sample2"]:
+                sample_prefixes.add(prefix)
+
+    for prefix in sorted(sample_prefixes):
+        example = [None] * 5  # [source, repaint_image, prompt, tracking_video, result_video]
+        for filename in os.listdir(samples_dir):
+            if filename.startswith(f"{prefix}_"):
+                if filename.endswith("_raw.mp4"):
+                    example[0] = os.path.join(samples_dir, filename)
+                elif filename.endswith("_repaint.png"):
+                    example[1] = os.path.join(samples_dir, filename)
+                elif filename.endswith("_tracking.mp4"):
+                    example[3] = os.path.join(samples_dir, filename)
+                elif filename.endswith("_result.mp4"):
+                    example[4] = os.path.join(samples_dir, filename)
+
+        # default prompt text
+        example[2] = "A beautiful scene"
+
+        # only add an example that has at least a source file and a tracking video
+        if example[0] is not None and example[3] is not None:
+            examples_list.append(example)
+
+    return examples_list
+
 # Create Gradio interface with updated layout
 with gr.Blocks(title="Diffusion as Shader") as demo:
     gr.Markdown("# Diffusion as Shader Web UI")
     gr.Markdown("### [Project Page](https://igl-hkust.github.io/das/) | [GitHub](https://github.com/IGL-HKUST/DiffusionAsShader)")
 
+    # hidden state variables to store intermediate results
+    video_tensor_state = gr.State(None)
+    tracking_tensor_state = gr.State(None)
+    repaint_img_tensor_state = gr.State(None)
+    fps_state = gr.State(None)
+
     with gr.Row():
         left_column = gr.Column(scale=1)
         right_column = gr.Column(scale=1)
 
     with right_column:
-        output_video = gr.Video(label="Generated Video")
         tracking_video = gr.Video(label="Tracking Video")
+
+        # the button is disabled initially
+        apply_tracking_btn = gr.Button("Generate Video", variant="primary", size="lg", interactive=False)
+        output_video = gr.Video(label="Generated Video")
 
     with left_column:
-        [two removed lines lost in the page extraction]
+        source_upload = gr.UploadButton("1. Upload Source", file_types=["image", "video"])
+        source_preview = gr.Video(label="Source Preview")
+        gr.Markdown("Upload a video or image, We will extract the motion and space structure from it")
+
+        # update the preview after a file is uploaded
+        def update_source_preview(file):
+            if file is None:
+                return None
+            path = save_uploaded_file(file)
+            return path
+
+        source_upload.upload(
+            fn=update_source_preview,
+            inputs=[source_upload],
+            outputs=[source_preview]
+        )
+
+        common_prompt = gr.Textbox(label="2. Prompt: Describe the scene and the motion you want to create", lines=2)
         gr.Markdown(f"**Using GPU: {GPU_ID}**")
 
         with gr.Tabs() as task_tabs:
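`@spaces.GPU(duration=240)` is the Hugging Face ZeroGPU decorator from the `spaces` package: the Space holds a GPU only while the decorated call runs, which is why the cheap tracking step and the expensive generation step are now separate functions. On the UI side, the commit stores intermediates in hidden `gr.State` holders and unlocks the second button through a `.then()` callback. A minimal, self-contained sketch of that pattern (toy functions, not the app's):

```python
import gradio as gr

def step1(text):
    # cheap preprocessing: produce a preview and stash intermediates in state
    return f"preview of {text!r}", {"payload": text}

def enable_btn(preview):
    # unlock step 2 only once step 1 produced something
    return gr.update(interactive=preview is not None)

def step2(state):
    # expensive step, run on demand from the stored state
    return f"final result from {state['payload']!r}"

with gr.Blocks() as sketch:
    inp = gr.Textbox(label="Input")
    preview = gr.Textbox(label="Preview")
    state = gr.State(None)
    run2 = gr.Button("Step 2", interactive=False)
    out = gr.Textbox(label="Output")

    gr.Button("Step 1").click(step1, [inp], [preview, state]).then(
        enable_btn, [preview], [run2]
    )
    run2.click(step2, [state], [out])

if __name__ == "__main__":
    sketch.launch()
```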
@@ -600,228 +708,308 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
                     value="No"
                 )
                 gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image in below.")
-                [four removed lines lost in the page extraction]
+
+                mt_repaint_upload = gr.UploadButton("3. Upload Repaint Image (Optional)", file_types=["image"])
+                mt_repaint_preview = gr.Image(label="Repaint Image Preview")
+
+                # update the preview after a file is uploaded
+                mt_repaint_upload.upload(
+                    fn=update_source_preview,  # reuse the same function
+                    inputs=[mt_repaint_upload],
+                    outputs=[mt_repaint_preview]
                 )
 
                 # Add run button for Motion Transfer tab
-                mt_run_btn = gr.Button("
+                mt_run_btn = gr.Button("Generate Tracking", variant="primary", size="lg")
 
-                # Connect to process function
+                # Connect to process function, but don't apply tracking
                 mt_run_btn.click(
                     fn=process_motion_transfer,
                     inputs=[
-                        [removed line lost in the page extraction]
-                        mt_repaint_option,
+                        source_upload, common_prompt,
+                        mt_repaint_option, mt_repaint_upload
                     ],
-                    outputs=[tracking_video,
+                    outputs=[tracking_video, video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
+                ).then(
+                    fn=enable_apply_button,
+                    inputs=[tracking_video],
+                    outputs=[apply_tracking_btn]
                 )
 
-            # Camera Control tab
-            with gr.TabItem("Camera Control"):
-                [the rest of the old Camera Control tab is lost in the page extraction]
+            # # Camera Control tab
+            # with gr.TabItem("Camera Control"):
+            #     gr.Markdown("## Camera Control")
 
+            #     cc_camera_motion = gr.Textbox(
+            #         label="Current Camera Motion Sequence",
+            #         placeholder="Your camera motion sequence will appear here...",
+            #         interactive=False
+            #     )
 
+            #     # Use tabs for different motion types
+            #     with gr.Tabs() as cc_motion_tabs:
+            #         # Translation tab
+            #         with gr.TabItem("Translation (trans)"):
+            #             with gr.Row():
+            #                 cc_trans_x = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="X-axis Movement")
+            #                 cc_trans_y = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Y-axis Movement")
+            #                 cc_trans_z = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Z-axis Movement (depth)")
 
+            #             with gr.Row():
+            #                 cc_trans_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
+            #                 cc_trans_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)
 
+            #             cc_trans_note = gr.Markdown("""
+            #             **Translation Notes:**
+            #             - Positive X: Move right, Negative X: Move left
+            #             - Positive Y: Move down, Negative Y: Move up
+            #             - Positive Z: Zoom in, Negative Z: Zoom out
+            #             """)
 
+            #             # Add translation button in the Translation tab
+            #             cc_add_trans = gr.Button("Add Camera Translation", variant="secondary")
 
+            #             # Function to add translation motion
+            #             def add_translation_motion(current_motion, trans_x, trans_y, trans_z, trans_start, trans_end):
+            #                 # Format: trans dx dy dz [start_frame end_frame]
+            #                 frame_range = f" {int(trans_start)} {int(trans_end)}" if trans_start != 0 or trans_end != 48 else ""
+            #                 new_motion = f"trans {trans_x:.2f} {trans_y:.2f} {trans_z:.2f}{frame_range}"
 
+            #                 # Append to existing motion string with semicolon separator if needed
+            #                 if current_motion and current_motion.strip():
+            #                     updated_motion = f"{current_motion}; {new_motion}"
+            #                 else:
+            #                     updated_motion = new_motion
 
+            #                 return updated_motion
 
+            #             # Connect translation button
+            #             cc_add_trans.click(
+            #                 fn=add_translation_motion,
+            #                 inputs=[
+            #                     cc_camera_motion,
+            #                     cc_trans_x, cc_trans_y, cc_trans_z, cc_trans_start, cc_trans_end
+            #                 ],
+            #                 outputs=[cc_camera_motion]
+            #             )
 
+            #         # Rotation tab
+            #         with gr.TabItem("Rotation (rot)"):
+            #             with gr.Row():
+            #                 cc_rot_axis = gr.Dropdown(choices=["x", "y", "z"], value="y", label="Rotation Axis")
+            #                 cc_rot_angle = gr.Slider(minimum=-30, maximum=30, value=5, step=1, label="Rotation Angle (degrees)")
 
+            #             with gr.Row():
+            #                 cc_rot_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
+            #                 cc_rot_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)
 
+            #             cc_rot_note = gr.Markdown("""
+            #             **Rotation Notes:**
+            #             - X-axis rotation: Tilt camera up/down
+            #             - Y-axis rotation: Pan camera left/right
+            #             - Z-axis rotation: Roll camera
+            #             """)
 
+            #             # Add rotation button in the Rotation tab
+            #             cc_add_rot = gr.Button("Add Camera Rotation", variant="secondary")
 
+            #             # Function to add rotation motion
+            #             def add_rotation_motion(current_motion, rot_axis, rot_angle, rot_start, rot_end):
+            #                 # Format: rot axis angle [start_frame end_frame]
+            #                 frame_range = f" {int(rot_start)} {int(rot_end)}" if rot_start != 0 or rot_end != 48 else ""
+            #                 new_motion = f"rot {rot_axis} {rot_angle}{frame_range}"
 
+            #                 # Append to existing motion string with semicolon separator if needed
+            #                 if current_motion and current_motion.strip():
+            #                     updated_motion = f"{current_motion}; {new_motion}"
+            #                 else:
+            #                     updated_motion = new_motion
 
+            #                 return updated_motion
 
+            #             # Connect rotation button
+            #             cc_add_rot.click(
+            #                 fn=add_rotation_motion,
+            #                 inputs=[
+            #                     cc_camera_motion,
+            #                     cc_rot_axis, cc_rot_angle, cc_rot_start, cc_rot_end
+            #                 ],
+            #                 outputs=[cc_camera_motion]
+            #             )
 
+            #     # Add a clear button to reset the motion sequence
+            #     cc_clear_motion = gr.Button("Clear All Motions", variant="stop")
 
+            #     def clear_camera_motion():
+            #         return ""
 
+            #     cc_clear_motion.click(
+            #         fn=clear_camera_motion,
+            #         inputs=[],
+            #         outputs=[cc_camera_motion]
+            #     )
+
+            #     cc_tracking_method = gr.Radio(
+            #         label="Tracking Method",
+            #         choices=["moge", "cotracker"],
+            #         value="cotracker"
+            #     )
 
+            #     # Add run button for Camera Control tab
+            #     cc_run_btn = gr.Button("Generate Tracking", variant="primary", size="lg")
 
+            #     # Connect to process function, but don't apply tracking
+            #     cc_run_btn.click(
+            #         fn=process_camera_control,
+            #         inputs=[
+            #             source_upload, common_prompt,
+            #             cc_camera_motion, cc_tracking_method
+            #         ],
+            #         outputs=[tracking_video, video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
+            #     ).then(
+            #         fn=enable_apply_button,
+            #         inputs=[tracking_video],
+            #         outputs=[apply_tracking_btn]
+            #     )
 
-            # Object Manipulation tab
-            with gr.TabItem("Object Manipulation"):
-                [the rest of the old Object Manipulation tab is lost in the page extraction]
+            # # Object Manipulation tab
+            # with gr.TabItem("Object Manipulation"):
+            #     gr.Markdown("## Object Manipulation")
+            #     om_object_mask = gr.File(
+            #         label="Object Mask Image",
+            #         file_types=["image"]
+            #     )
+            #     gr.Markdown("Upload a binary mask image, white areas indicate the object to manipulate")
+            #     om_object_motion = gr.Dropdown(
+            #         label="Object Motion Type",
+            #         choices=["up", "down", "left", "right", "front", "back", "rot"],
+            #         value="up"
+            #     )
+            #     om_tracking_method = gr.Radio(
+            #         label="Tracking Method",
+            #         choices=["moge", "cotracker"],
+            #         value="cotracker"
+            #     )
 
+            #     # Add run button for Object Manipulation tab
+            #     om_run_btn = gr.Button("Generate Tracking", variant="primary", size="lg")
 
+            #     # Connect to process function, but don't apply tracking
+            #     om_run_btn.click(
+            #         fn=process_object_manipulation,
+            #         inputs=[
+            #             source_upload, common_prompt,
+            #             om_object_motion, om_object_mask, om_tracking_method
+            #         ],
+            #         outputs=[tracking_video, video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
+            #     ).then(
+            #         fn=enable_apply_button,
+            #         inputs=[tracking_video],
+            #         outputs=[apply_tracking_btn]
+            #     )
 
-            # Animating meshes to video tab
-            with gr.TabItem("Animating meshes to video"):
-                [the rest of the old mesh-animation tab is lost in the page extraction]
+            # # Animating meshes to video tab
+            # with gr.TabItem("Animating meshes to video"):
+            #     gr.Markdown("## Mesh Animation to Video")
+            #     gr.Markdown("""
+            #     Note: Currently only supports tracking videos generated with Blender (version > 4.0).
+            #     Please run the script `scripts/blender.py` in your Blender project to generate tracking videos.
+            #     """)
+            #     ma_tracking_video = gr.File(
+            #         label="Tracking Video",
+            #         file_types=["video"],
+            #         # add a change handler so uploading a file automatically activates the Generate Video button
+            #         elem_id="ma_tracking_video"
+            #     )
+            #     gr.Markdown("Tracking video needs to be generated from Blender")
 
+            #     # Simplified controls - Radio buttons for Yes/No and separate file upload
+            #     with gr.Row():
+            #         ma_repaint_option = gr.Radio(
+            #             label="Repaint First Frame",
+            #             choices=["No", "Yes"],
+            #             value="No"
+            #         )
+            #     gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image in below.")
+            #     # Custom image uploader (always visible)
+            #     ma_repaint_image = gr.File(
+            #         label="Custom Repaint Image",
+            #         file_types=["image"]
+            #     )
 
+            #     # rename the button to "Apply Repaint"
+            #     ma_run_btn = gr.Button("Apply Repaint", variant="primary", size="lg")
 
+            #     # handle tracking-video uploads
+            #     def handle_tracking_upload(file):
+            #         if file is not None:
+            #             tracking_path = save_uploaded_file(file)
+            #             if tracking_path:
+            #                 return tracking_path, gr.update(interactive=True)
+            #         return None, gr.update(interactive=False)
+
+            #     # when a tracking video is uploaded, show it directly and activate the Generate Video button
+            #     ma_tracking_video.change(
+            #         fn=handle_tracking_upload,
+            #         inputs=[ma_tracking_video],
+            #         outputs=[tracking_video, apply_tracking_btn]
+            #     )
+
+            #     # modified behaviour of process_mesh_animation
+            #     def process_mesh_animation_repaint(source, prompt, ma_repaint_option, ma_repaint_image):
+            #         """Handle only the repaint step, not the tracking video"""
+            #         try:
+            #             # save the uploaded file
+            #             input_video_path = save_uploaded_file(source)
+            #             if input_video_path is None:
+            #                 return None, None, None, None
+
+            #             das = get_das_pipeline()
+            #             video_tensor, fps, is_video = load_media(input_video_path)
+            #             das.fps = fps
+
+            #             repaint_img_tensor = None
+            #             if ma_repaint_image is not None:
+            #                 repaint_path = save_uploaded_file(ma_repaint_image)
+            #                 repaint_img_tensor, _, _ = load_media(repaint_path)
+            #                 repaint_img_tensor = repaint_img_tensor[0]
+            #             elif ma_repaint_option == "Yes":
+            #                 repainter = FirstFrameRepainter(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
+            #                 repaint_img_tensor = repainter.repaint(
+            #                     video_tensor[0],
+            #                     prompt=prompt,
+            #                     depth_path=None
+            #                 )
+
+            #             # return the results, excluding the tracking-video path
+            #             return video_tensor, None, repaint_img_tensor, fps
+            #         except Exception as e:
+            #             import traceback
+            #             print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
+            #             return None, None, None, None
+
+            #     # connect to the modified handler
+            #     ma_run_btn.click(
+            #         fn=process_mesh_animation_repaint,
+            #         inputs=[
+            #             source_upload, common_prompt,
+            #             ma_repaint_option, ma_repaint_image
+            #         ],
+            #         outputs=[video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
+            #     )
+
+    # add the Examples component after all UI elements are defined
+    examples_list = load_examples()
+    if examples_list:
+        with gr.Blocks() as examples_block:
+            gr.Examples(
+                examples=examples_list,
+                inputs=[source_preview, mt_repaint_preview, common_prompt, tracking_video, output_video],
+                outputs=[source_preview, mt_repaint_preview, common_prompt, tracking_video, output_video],
+                fn=lambda *args: args,  # simply return the inputs as outputs
+                cache_examples=True,
+                label="Examples"
+            )
 
 # Launch interface
 if __name__ == "__main__":
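One caveat with the `gr.Examples` block above: with `cache_examples=True`, Gradio runs `fn` over every example at startup and caches the outputs, so the identity `lambda *args: args` keeps that pass cheap, but every referenced sample file must exist at launch. A stripped-down sketch of the same wiring:

```python
import gradio as gr

with gr.Blocks() as examples_sketch:
    a = gr.Textbox(label="Prompt")
    b = gr.Textbox(label="Result")
    gr.Examples(
        examples=[["hello", "world"]],
        inputs=[a, b],
        outputs=[a, b],
        fn=lambda *args: args,  # identity: echo the inputs back as outputs
        cache_examples=True,    # outputs are computed and cached at startup
        label="Examples",
    )
```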
@@ -831,4 +1019,4 @@ if __name__ == "__main__":
     print("Creating public link for remote access")
 
     # Launch interface
-    demo.launch(share=args.share, server_port=args.port)
+    demo.launch(share=args.share, server_port=args.port)