Update app.py
Browse files
app.py
CHANGED
|
@@ -177,6 +177,19 @@ def show_mask(mask, ax, obj_id=None, random_color=False):
|
|
| 177 |
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
| 178 |
ax.imshow(mask_image)
|
| 179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
def show_points(coords, labels, ax, marker_size=200):
|
| 182 |
pos_points = coords[labels==1]
|
|
@@ -319,7 +332,7 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
|
|
| 319 |
|
| 320 |
# Initialize a list to store file paths of saved images
|
| 321 |
jpeg_images = []
|
| 322 |
-
|
| 323 |
|
| 324 |
# run propagation throughout the video and collect the results in a dict
|
| 325 |
video_segments = {} # video_segments contains the per-frame segmentation results
|
|
@@ -343,20 +356,6 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
|
|
| 343 |
for out_obj_id, out_mask in video_segments[out_frame_idx].items():
|
| 344 |
show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
|
| 345 |
|
| 346 |
-
# Save the raw binary mask as a separate image
|
| 347 |
-
mask_filename = os.path.join(mask_frames_output_dir, f"mask_{out_frame_idx}.jpg")
|
| 348 |
-
binary_mask = np.squeeze(out_mask) # Ensure the mask is 2D
|
| 349 |
-
binary_mask = (binary_mask * 255).astype(np.uint8) # Scale mask to 0-255
|
| 350 |
-
|
| 351 |
-
if binary_mask.ndim != 2: # Ensure it's 2D for PIL
|
| 352 |
-
raise ValueError(f"Mask has invalid dimensions: {binary_mask.shape}")
|
| 353 |
-
|
| 354 |
-
mask_image = Image.fromarray(binary_mask)
|
| 355 |
-
mask_image.save(mask_filename) # Save the mask as a JPEG
|
| 356 |
-
masks_frames.append(mask_filename) # Append to the list of masks
|
| 357 |
-
|
| 358 |
-
print(f"MASKS FRAMES: {masks_frames}")
|
| 359 |
-
|
| 360 |
# Define the output filename and save the figure as a JPEG file
|
| 361 |
output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
|
| 362 |
plt.savefig(output_filename, format='jpg')
|
|
@@ -370,6 +369,23 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
|
|
| 370 |
if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
|
| 371 |
available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
|
| 372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
|
| 375 |
torch.cuda.empty_cache()
|
|
@@ -392,18 +408,30 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
|
|
| 392 |
codec='libx264'
|
| 393 |
)
|
| 394 |
|
| 395 |
-
|
| 396 |
-
# Create the video clip
|
| 397 |
-
mask_clip = ImageSequenceClip(masks_frames, fps=fps)
|
| 398 |
|
| 399 |
-
#
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
#
|
| 403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
|
| 406 |
-
return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True),
|
| 407 |
|
| 408 |
def update_ui(vis_frame_type):
|
| 409 |
if vis_frame_type == "check":
|
|
|
|
| 177 |
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
| 178 |
ax.imshow(mask_image)
|
| 179 |
|
| 180 |
+
def show_white_mask(mask, ax):
|
| 181 |
+
# Ensure mask is binary (values 0 or 1)
|
| 182 |
+
mask = (mask > 0).astype(float) # Convert to binary mask
|
| 183 |
+
h, w = mask.shape[-2:]
|
| 184 |
+
|
| 185 |
+
# Create a white mask (RGBA: [1, 1, 1, alpha])
|
| 186 |
+
alpha = 1.0 # Fully opaque
|
| 187 |
+
color = np.array([1, 1, 1, alpha])
|
| 188 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
| 189 |
+
|
| 190 |
+
# Display black background
|
| 191 |
+
ax.imshow(np.zeros((h, w, 3), dtype=float)) # Black background
|
| 192 |
+
ax.imshow(mask_image) # Overlay white mask
|
| 193 |
|
| 194 |
def show_points(coords, labels, ax, marker_size=200):
|
| 195 |
pos_points = coords[labels==1]
|
|
|
|
| 332 |
|
| 333 |
# Initialize a list to store file paths of saved images
|
| 334 |
jpeg_images = []
|
| 335 |
+
masks_images = []
|
| 336 |
|
| 337 |
# run propagation throughout the video and collect the results in a dict
|
| 338 |
video_segments = {} # video_segments contains the per-frame segmentation results
|
|
|
|
| 356 |
for out_obj_id, out_mask in video_segments[out_frame_idx].items():
|
| 357 |
show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
|
| 358 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
# Define the output filename and save the figure as a JPEG file
|
| 360 |
output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
|
| 361 |
plt.savefig(output_filename, format='jpg')
|
|
|
|
| 369 |
if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
|
| 370 |
available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
|
| 371 |
|
| 372 |
+
# Step 2: Create and store a black-and-white mask image using show_white_mask
|
| 373 |
+
# Create a figure without displaying it for the white mask
|
| 374 |
+
fig, ax = plt.subplots(figsize=(6, 4))
|
| 375 |
+
ax.axis("off") # Remove axes for a clean mask
|
| 376 |
+
|
| 377 |
+
# Overlay each mask as white on a black background
|
| 378 |
+
for out_mask in video_segments[out_frame_idx].values():
|
| 379 |
+
show_white_mask(out_mask, ax)
|
| 380 |
+
|
| 381 |
+
# Save the white mask figure to an image in memory
|
| 382 |
+
mask_filename = os.path.join(masks_output_dir, f"mask_{out_frame_idx}.jpg")
|
| 383 |
+
fig.savefig(mask_filename, format='jpg', bbox_inches="tight", pad_inches=0)
|
| 384 |
+
plt.close(fig)
|
| 385 |
+
|
| 386 |
+
# Add the saved mask image to the masks_images array
|
| 387 |
+
masks_images.append(mask_filename)
|
| 388 |
+
|
| 389 |
|
| 390 |
|
| 391 |
torch.cuda.empty_cache()
|
|
|
|
| 408 |
codec='libx264'
|
| 409 |
)
|
| 410 |
|
| 411 |
+
print("MAKING MASK VIDEO ...")
|
|
|
|
|
|
|
| 412 |
|
| 413 |
+
# Create a video from the masks_images array
|
| 414 |
+
mask_video_filename = "final_masks_video.mp4"
|
| 415 |
+
|
| 416 |
+
# Get the dimensions of the first mask image
|
| 417 |
+
frame = cv2.imread(masks_images[0])
|
| 418 |
+
height, width, _ = frame.shape
|
| 419 |
+
|
| 420 |
+
# Define the video writer
|
| 421 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
| 422 |
+
fps = original_fps # Frames per second
|
| 423 |
+
video_writer = cv2.VideoWriter(mask_video_filename, fourcc, fps, (width, height))
|
| 424 |
+
|
| 425 |
+
# Write each mask image to the video
|
| 426 |
+
for mask_path in masks_images:
|
| 427 |
+
frame = cv2.imread(mask_path)
|
| 428 |
+
video_writer.write(frame)
|
| 429 |
+
|
| 430 |
+
video_writer.release()
|
| 431 |
+
print(f"Mask Video saved at {mask_video_filename}")
|
| 432 |
|
| 433 |
|
| 434 |
+
return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True), mask_video_filename
|
| 435 |
|
| 436 |
def update_ui(vis_frame_type):
|
| 437 |
if vis_frame_type == "check":
|