Update app.py
Browse files
app.py
CHANGED
|
@@ -219,6 +219,7 @@ def get_mask_sam_process(
|
|
| 219 |
video_frames_dir, # extracted_frames_output_dir defined in 'preprocess_video_in' function
|
| 220 |
scanned_frames,
|
| 221 |
working_frame: str = None, # current frame being added points
|
|
|
|
| 222 |
progress=gr.Progress(track_tqdm=True)
|
| 223 |
):
|
| 224 |
|
|
@@ -287,8 +288,12 @@ def get_mask_sam_process(
|
|
| 287 |
plt.savefig(first_frame_output_filename, format='jpg')
|
| 288 |
plt.close()
|
| 289 |
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
-
return "output_first_frame.jpg", frame_names, inference_state, gr.update(value=working_frame, visible=True)
|
| 292 |
|
| 293 |
def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, progress=gr.Progress(track_tqdm=True)):
|
| 294 |
#### PROPAGATION ####
|
|
@@ -389,6 +394,7 @@ with gr.Blocks() as demo:
|
|
| 389 |
scanned_frames = gr.State()
|
| 390 |
stored_inference_state = gr.State()
|
| 391 |
stored_frame_names = gr.State()
|
|
|
|
| 392 |
with gr.Column():
|
| 393 |
gr.Markdown("# SAM2 Video Predictor")
|
| 394 |
gr.Markdown("This is a simple demo for video segmentation with SAM2.")
|
|
@@ -424,6 +430,7 @@ with gr.Blocks() as demo:
|
|
| 424 |
video_in = gr.Video(label="Video IN")
|
| 425 |
|
| 426 |
with gr.Column():
|
|
|
|
| 427 |
working_frame = gr.Dropdown(label="working frame ID", choices=[""], value=None, visible=False, allow_custom_value=False, interactive=True)
|
| 428 |
output_result = gr.Image(label="current working mask ref")
|
| 429 |
with gr.Row():
|
|
@@ -505,6 +512,7 @@ with gr.Blocks() as demo:
|
|
| 505 |
video_frames_dir,
|
| 506 |
scanned_frames,
|
| 507 |
working_frame,
|
|
|
|
| 508 |
],
|
| 509 |
outputs = [
|
| 510 |
output_result,
|
|
|
|
| 219 |
video_frames_dir, # extracted_frames_output_dir defined in 'preprocess_video_in' function
|
| 220 |
scanned_frames,
|
| 221 |
working_frame: str = None, # current frame being added points
|
| 222 |
+
available_frames_to_check,
|
| 223 |
progress=gr.Progress(track_tqdm=True)
|
| 224 |
):
|
| 225 |
|
|
|
|
| 288 |
plt.savefig(first_frame_output_filename, format='jpg')
|
| 289 |
plt.close()
|
| 290 |
torch.cuda.empty_cache()
|
| 291 |
+
|
| 292 |
+
# Assuming available_frames_to_check.value is a list
|
| 293 |
+
if working_frame not in available_frames_to_check.value:
|
| 294 |
+
available_frames_to_check.value.append(working_frame)
|
| 295 |
|
| 296 |
+
return "output_first_frame.jpg", frame_names, inference_state, gr.update(choices=available_frames_to_check.value, value=working_frame, visible=True)
|
| 297 |
|
| 298 |
def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, progress=gr.Progress(track_tqdm=True)):
|
| 299 |
#### PROPAGATION ####
|
|
|
|
| 394 |
scanned_frames = gr.State()
|
| 395 |
stored_inference_state = gr.State()
|
| 396 |
stored_frame_names = gr.State()
|
| 397 |
+
available_frames_to_check = gr.State([])
|
| 398 |
with gr.Column():
|
| 399 |
gr.Markdown("# SAM2 Video Predictor")
|
| 400 |
gr.Markdown("This is a simple demo for video segmentation with SAM2.")
|
|
|
|
| 430 |
video_in = gr.Video(label="Video IN")
|
| 431 |
|
| 432 |
with gr.Column():
|
| 433 |
+
|
| 434 |
working_frame = gr.Dropdown(label="working frame ID", choices=[""], value=None, visible=False, allow_custom_value=False, interactive=True)
|
| 435 |
output_result = gr.Image(label="current working mask ref")
|
| 436 |
with gr.Row():
|
|
|
|
| 512 |
video_frames_dir,
|
| 513 |
scanned_frames,
|
| 514 |
working_frame,
|
| 515 |
+
available_frames_to_check,
|
| 516 |
],
|
| 517 |
outputs = [
|
| 518 |
output_result,
|