Update app.py
app.py CHANGED
@@ -28,7 +28,7 @@ from PIL import Image, ImageFilter
 from sam2.build_sam import build_sam2_video_predictor
 
 def preprocess_image(image):
-    return image, gr.State([]), gr.State([]), image
+    return image, gr.State([]), gr.State([]), image, gr.State([])
 
 def preprocess_video_in(video_path):
 
@@ -70,7 +70,7 @@ def preprocess_video_in(video_path):
     cap.release()
 
     # 'image' is the first frame extracted from video_in
-    return first_frame, gr.State([]), gr.State([]), first_frame, first_frame, output_dir
+    return first_frame, gr.State([]), gr.State([]), first_frame, first_frame, output_dir, gr.State([]), gr.State([])
 
 def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
     print(f"You selected {evt.value} at {evt.index} from {evt.target}")
@@ -184,12 +184,7 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
 
     return combined_images, mask_images
 
-
-def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_input_label, frames_output_dir):
-    # 1. We need to preprocess the video and store frames in the right directory
-    # — Remember to use a unique ID for the folder
-
-
+def load_model(checkpoint):
     # Load model according to user's choice
     if checkpoint == "tiny":
         sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
@@ -203,13 +198,20 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
     elif checkpoint == "large":
         sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
         model_cfg = "sam2_hiera_l.yaml"
+
+    return sam2_checkpoint, model_cfg
 
+def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_input_label, video_frames_dir):
+    # 1. We need to preprocess the video and store frames in the right directory
+    # — Remember to use a unique ID for the folder
+
+    sam2_checkpoint, model_cfg = load_model(checkpoint)
     predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
 
 
     # `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
-    print(f"STATE FRAME OUTPUT DIRECTORY: {frames_output_dir}")
-    video_dir = frames_output_dir
+    print(f"STATE FRAME OUTPUT DIRECTORY: {video_frames_dir}")
+    video_dir = video_frames_dir
 
     # scan all the JPEG frame names in this directory
     frame_names = [
@@ -248,13 +250,18 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
     show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
 
     # Save the plot as a JPG file
-
-    plt.savefig(
+    first_frame_output_filename = "output_first_frame.jpg"
+    plt.savefig(first_frame_output_filename, format='jpg')
     plt.close()
 
-
-    #### PROPAGATION ####
+    return "output_first_frame.jpg", frame_names, inference_state
 
+def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names):
+    #### PROPAGATION ####
+    sam2_checkpoint, model_cfg = load_model(checkpoint)
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+    inference_state = stored_inference_state.value
+    frame_names = stored_frame_names.value
     # Define a directory to save the JPEG images
     frames_output_dir = "frames_output_images"
     os.makedirs(frames_output_dir, exist_ok=True)
@@ -289,16 +296,16 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
 
     # Close the plot
     plt.close()
-
-    # OLD
 
-    return
+    return jpeg_images
 
 with gr.Blocks() as demo:
     first_frame_path = gr.State()
     tracking_points = gr.State([])
     trackings_input_label = gr.State([])
-
+    video_frames_dir = gr.State()
+    stored_inference_state = gr.State([])
+    stored_frame_names = gr.State([])
     with gr.Column():
         gr.Markdown("# SAM2 Video Predictor")
         gr.Markdown("This is a simple demo for video segmentation with SAM2.")
@@ -325,20 +332,21 @@ with gr.Blocks() as demo:
             submit_btn = gr.Button("Submit")
         with gr.Column():
             output_result = gr.Image()
+            propagate_btn = gr.Button("Propagate")
             output_propagated = gr.Gallery()
             # output_result_mask = gr.Image()
 
     clear_points_btn.click(
         fn = preprocess_image,
         inputs = input_first_frame_image,
-        outputs = [first_frame_path, tracking_points, trackings_input_label, points_map],
+        outputs = [first_frame_path, tracking_points, trackings_input_label, points_map, stored_inference_state],
         queue=False
     )
 
     video_in.upload(
         fn = preprocess_video_in,
         inputs = [video_in],
-        outputs = [first_frame_path, tracking_points, trackings_input_label, input_first_frame_image, points_map,
+        outputs = [first_frame_path, tracking_points, trackings_input_label, input_first_frame_image, points_map, video_frames_dir, stored_inference_state, stored_frame_names],
         queue = False
     )
 
@@ -351,8 +359,14 @@ with gr.Blocks() as demo:
 
     submit_btn.click(
         fn = sam_process,
-        inputs = [input_first_frame_image, checkpoint, tracking_points, trackings_input_label,
-        outputs = [output_result,
+        inputs = [input_first_frame_image, checkpoint, tracking_points, trackings_input_label, video_frames_dir],
+        outputs = [output_result, stored_frame_names, stored_inference_state]
+    )
+
+    propagate_btn.click(
+        fn = propagate_to_all,
+        inputs = [checkpoint, stored_inference_state, stored_frame_names],
+        outputs = [output_propagated]
    )
 
demo.launch(show_api=False, show_error=True)
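For context on the pattern this commit introduces: the monolithic handler is split so that one click runs the expensive first-frame segmentation and stashes its results (the frame list and the SAM2 inference state) in gr.State, and a second click picks those values up to run propagation. Below is a minimal, self-contained sketch of that two-step gr.State hand-off; the handler bodies and component names are illustrative stand-ins, not code from this repository.

import gradio as gr

def segment_first_frame(prompt):
    # Stand-in for sam_process(): do the expensive setup once, then return
    # a displayable result plus everything the propagation step will need.
    inference_state = {"prompt": prompt}      # placeholder for SAM2's inference_state
    frame_names = ["00000.jpg", "00001.jpg"]  # placeholder frame list
    return f"segmented using: {prompt}", frame_names, inference_state

def propagate(inference_state, frame_names):
    # Stand-in for propagate_to_all(): Gradio hands the handler the values
    # stored in the gr.State inputs, so they can be used directly here.
    return [f"{name}: propagated {inference_state['prompt']}" for name in frame_names]

with gr.Blocks() as demo:
    stored_inference_state = gr.State()  # survives between the two clicks
    stored_frame_names = gr.State([])

    prompt = gr.Textbox(label="Prompt")
    result = gr.Textbox(label="First-frame result")
    propagated = gr.JSON(label="Propagated frames")
    segment_btn = gr.Button("Segment first frame")
    propagate_btn = gr.Button("Propagate")

    segment_btn.click(
        fn=segment_first_frame,
        inputs=[prompt],
        outputs=[result, stored_frame_names, stored_inference_state],
    )
    propagate_btn.click(
        fn=propagate,
        inputs=[stored_inference_state, stored_frame_names],
        outputs=[propagated],
    )

demo.launch()

One caveat on the committed code: Gradio passes a gr.State input to a handler as the stored value, not as the State wrapper, so propagate_to_all's stored_inference_state.value only resolves because the earlier handlers return literal gr.State(...) objects as the stored values. Returning plain values, as in the sketch above, is the more conventional use of gr.State.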