Update app.py
Browse files
app.py
CHANGED
|
@@ -316,6 +316,7 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
|
|
| 316 |
|
| 317 |
# Initialize a list to store file paths of saved images
|
| 318 |
jpeg_images = []
|
|
|
|
| 319 |
|
| 320 |
# run propagation throughout the video and collect the results in a dict
|
| 321 |
video_segments = {} # video_segments contains the per-frame segmentation results
|
|
@@ -352,11 +353,18 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
|
|
| 352 |
if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
|
| 353 |
available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
|
| 354 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
torch.cuda.empty_cache()
|
| 356 |
print(f"JPEG_IMAGES: {jpeg_images}")
|
| 357 |
|
| 358 |
if vis_frame_type == "check":
|
| 359 |
-
return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True)
|
| 360 |
elif vis_frame_type == "render":
|
| 361 |
# Create a video clip from the image sequence
|
| 362 |
original_fps = get_video_fps(video_in)
|
|
@@ -371,8 +379,12 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
|
|
| 371 |
final_vid_output_path,
|
| 372 |
codec='libx264'
|
| 373 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
|
| 375 |
-
return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)
|
| 376 |
|
| 377 |
def update_ui(vis_frame_type):
|
| 378 |
if vis_frame_type == "check":
|
|
@@ -478,6 +490,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 478 |
reset_prpgt_brn = gr.Button("Reset", visible=False)
|
| 479 |
output_propagated = gr.Gallery(label="Propagated Mask samples gallery", columns=4, visible=False)
|
| 480 |
output_video = gr.Video(visible=False)
|
|
|
|
| 481 |
# output_result_mask = gr.Image()
|
| 482 |
|
| 483 |
|
|
@@ -581,7 +594,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 581 |
).then(
|
| 582 |
fn = propagate_to_all,
|
| 583 |
inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
|
| 584 |
-
outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn]
|
| 585 |
)
|
| 586 |
|
| 587 |
demo.launch(show_api=False, show_error=True)
|
|
|
|
| 316 |
|
| 317 |
# Initialize a list to store file paths of saved images
|
| 318 |
jpeg_images = []
|
| 319 |
+
masks_frames = []
|
| 320 |
|
| 321 |
# run propagation throughout the video and collect the results in a dict
|
| 322 |
video_segments = {} # video_segments contains the per-frame segmentation results
|
|
|
|
| 353 |
if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
|
| 354 |
available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
|
| 355 |
|
| 356 |
+
# Save the raw binary mask as a separate image
|
| 357 |
+
mask_filename = os.path.join(frames_output_dir, f"mask_{out_frame_idx}.jpg")
|
| 358 |
+
binary_mask = (out_mask * 255).astype(np.uint8) # Scale mask to 0-255
|
| 359 |
+
mask_image = Image.fromarray(binary_mask)
|
| 360 |
+
mask_image.save(mask_filename) # Save the mask as a JPEG
|
| 361 |
+
masks_frames.append(mask_filename) # Append to the list of masks
|
| 362 |
+
|
| 363 |
torch.cuda.empty_cache()
|
| 364 |
print(f"JPEG_IMAGES: {jpeg_images}")
|
| 365 |
|
| 366 |
if vis_frame_type == "check":
|
| 367 |
+
return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True), None
|
| 368 |
elif vis_frame_type == "render":
|
| 369 |
# Create a video clip from the image sequence
|
| 370 |
original_fps = get_video_fps(video_in)
|
|
|
|
| 379 |
final_vid_output_path,
|
| 380 |
codec='libx264'
|
| 381 |
)
|
| 382 |
+
|
| 383 |
+
mask_clip = ImageSequenceClip(masks_frames, fps=fps)
|
| 384 |
+
# Write the result to a file
|
| 385 |
+
mask_final_vid_output_path = "mask_output_video.mp4"
|
| 386 |
|
| 387 |
+
return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True), mask_final_vid_output_path
|
| 388 |
|
| 389 |
def update_ui(vis_frame_type):
|
| 390 |
if vis_frame_type == "check":
|
|
|
|
| 490 |
reset_prpgt_brn = gr.Button("Reset", visible=False)
|
| 491 |
output_propagated = gr.Gallery(label="Propagated Mask samples gallery", columns=4, visible=False)
|
| 492 |
output_video = gr.Video(visible=False)
|
| 493 |
+
mask_final_output = gr.Video(label="Mask Video")
|
| 494 |
# output_result_mask = gr.Image()
|
| 495 |
|
| 496 |
|
|
|
|
| 594 |
).then(
|
| 595 |
fn = propagate_to_all,
|
| 596 |
inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
|
| 597 |
+
outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn, mask_final_output]
|
| 598 |
)
|
| 599 |
|
| 600 |
demo.launch(show_api=False, show_error=True)
|