Update app.py
app.py CHANGED
@@ -301,7 +301,7 @@ def get_mask_sam_process(
 
     return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
 
-def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, progress=gr.Progress(track_tqdm=True)):
+def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, progress=gr.Progress(track_tqdm=True)):
    #### PROPAGATION ####
    sam2_checkpoint, model_cfg = load_model(checkpoint)
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
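The core of the commit is visible in this first hunk: `propagate_to_all` now receives the `available_frames_to_check` list explicitly rather than relying on whatever the dropdown was originally seeded with. The sketch below, with illustrative names that are not from app.py, shows the general Gradio pattern in play: a list held in `gr.State` arrives in a handler as a plain Python list, and the handler must return it so the updated value is written back for subsequent events.

```python
# Minimal sketch of list-valued gr.State round-tripping
# (illustrative names, not from app.py).
import gradio as gr

def track(new_item, seen):
    # `seen` arrives as a plain Python list; same dedupe guard the
    # commit adds inside propagate_to_all.
    if new_item not in seen:
        seen.append(new_item)
    return seen, seen  # first value persists the state, second is displayed

with gr.Blocks() as sketch:
    seen_state = gr.State([])  # per-session list of frame names
    item = gr.Textbox(label="frame name")
    shown = gr.JSON(label="frames seen so far")
    item.submit(track, inputs=[item, seen_state], outputs=[seen_state, shown])

if __name__ == "__main__":
    sketch.launch()
```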
@@ -349,11 +349,14 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
         # Append the file path to the list
         jpeg_images.append(output_filename)
 
+        if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
+            available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
+
     torch.cuda.empty_cache()
     print(f"JPEG_IMAGES: {jpeg_images}")
 
     if vis_frame_type == "check":
-        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=
+        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=None, visible=True), available_frames_to_check
     elif vis_frame_type == "render":
         # Create a video clip from the image sequence
         original_fps = get_video_fps(video_in)
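In the "check" branch, the new return statement both updates the working-frame dropdown (`choices=available_frames_to_check, value=None, visible=True`) and returns the raw list, so the `gr.State` backing it stays in sync with the choices shown. A hedged sketch of that combination, again with illustrative names:

```python
# Hedged sketch: one handler returning a component update plus a raw
# value for a gr.State output, as in the "check" branch above
# (illustrative names, not from app.py).
import gradio as gr

def finish_check(found_frames):
    # Rebuild the dropdown options, clear the selection, reveal the
    # widget, and persist the list itself alongside the UI update.
    return gr.update(choices=found_frames, value=None, visible=True), found_frames

with gr.Blocks() as sketch:
    frames_state = gr.State(["frame_0.jpg", "frame_4.jpg"])
    working_frame = gr.Dropdown(label="working frame", visible=False)
    check_btn = gr.Button("check")
    check_btn.click(finish_check, inputs=[frames_state],
                    outputs=[working_frame, frames_state])
```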
@@ -369,7 +372,7 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
             codec='libx264'
         )
 
-        return gr.update(value=None), gr.update(value=final_vid_output_path), None
+        return gr.update(value=None), gr.update(value=final_vid_output_path), None, available_frames_to_check
 
 def update_ui(vis_frame_type):
     if vis_frame_type == "check":
@@ -396,7 +399,7 @@ def reset_propagation(predictor, stored_inference_state):
 
     predictor.reset_state(stored_inference_state)
     print(f"RESET State: {stored_inference_state} ")
-    return gr.update(value=None, visible=False), stored_inference_state, None
+    return gr.update(value=None, visible=False), stored_inference_state, None, gr.State([])
 
 with gr.Blocks() as demo:
     first_frame_path = gr.State()
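The reset handler gains a fourth return value to empty the tracked frame list. One caveat worth flagging: the commit returns a fresh `gr.State([])` instance, while the usual way to set a `gr.State` output is to return the plain value (here, `[]`); whether a component instance is unwrapped correctly may depend on the Gradio version. A hedged sketch of the plain-value form, with illustrative names:

```python
# Hedged sketch: clearing a list-valued gr.State from a reset handler by
# returning the plain empty list (illustrative names, not from app.py).
import gradio as gr

def reset(frames):
    print(f"clearing {len(frames)} tracked frames")
    # Hide the gallery and empty the stored list.
    return gr.update(value=None, visible=False), []

with gr.Blocks() as sketch:
    frames_state = gr.State(["frame_0.jpg"])
    output_gallery = gr.Gallery(label="propagated frames")
    reset_btn = gr.Button("reset")
    reset_btn.click(reset, inputs=[frames_state],
                    outputs=[output_gallery, frames_state], queue=False)
```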
@@ -437,7 +440,7 @@ with gr.Blocks() as demo:
 
         with gr.Row():
             checkpoint = gr.Dropdown(label="Checkpoint", choices=["tiny", "small", "base-plus", "large"], value="tiny")
-            submit_btn = gr.Button("
+            submit_btn = gr.Button("Get Mask", size="lg")
 
         with gr.Accordion("Your video IN", open=True) as video_in_drawer:
             video_in = gr.Video(label="Video IN")
@@ -541,7 +544,7 @@ with gr.Blocks() as demo:
     reset_prpgt_brn.click(
         fn = reset_propagation,
         inputs = [loaded_predictor, stored_inference_state],
-        outputs = [output_propagated, stored_inference_state, output_result],
+        outputs = [output_propagated, stored_inference_state, output_result, available_frames_to_check],
         queue=False
     )
 
@@ -552,8 +555,8 @@ with gr.Blocks() as demo:
         queue=False
     ).then(
         fn = propagate_to_all,
-        inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type],
-        outputs = [output_propagated, output_video, working_frame]
+        inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check],
+        outputs = [output_propagated, output_video, working_frame, available_frames_to_check]
     )
 
 demo.launch(show_api=False, show_error=True)
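Both event chains are rewired so `available_frames_to_check` flows into and back out of `propagate_to_all`, keeping the stored list current across the `.then()` chain. A hedged sketch of that wiring pattern, with illustrative names that are not the components in app.py:

```python
# Hedged sketch of a .click(...).then(...) chain that threads a shared
# list state through the second step (illustrative names, not from app.py).
import gradio as gr

def prepare():
    return "propagating..."

def propagate(status, frames):
    frames = frames + [f"frame_{i}.jpg" for i in range(3)]  # stand-in for real work
    return f"{status} done ({len(frames)} frames)", frames

with gr.Blocks() as sketch:
    frames_state = gr.State([])
    status = gr.Textbox(label="status")
    run_btn = gr.Button("propagate")
    run_btn.click(prepare, outputs=[status], queue=False).then(
        propagate,
        inputs=[status, frames_state],
        outputs=[status, frames_state],
    )
```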