Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,706 Bytes
1c76709 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import numpy as np
from keras import ops
from skimage import filters, morphology
from zea.utils import translate
def L1(x):
    """Return the L1 norm of ``x``: the sum of the absolute values of all elements.

    Definition: https://mathworld.wolfram.com/L1-Norm.html
    """
    magnitudes = ops.abs(x)
    total = ops.sum(magnitudes)
    return total
def smooth_L1(x, beta=0.4):
    """Smooth L1 loss, summed over all elements of ``x``.

    Element-wise the loss is::

        0.5 * x**2 / beta    if |x| < beta
        |x| - 0.5 * beta     otherwise

    ``beta`` is the width of the quadratic region around zero. Small beta values
    make the loss approach the plain L1 loss. Larger beta values widen the
    quadratic (L2-like, scaled by ``1 / beta``) region.

    Args:
        x: Tensor of residuals.
        beta: Width of the quadratic region. When ``beta`` is a Python number
            equal to 0, the L1 loss is returned directly.

    Returns:
        Scalar tensor holding the summed loss.
    """
    abs_x = ops.abs(x)
    if isinstance(beta, (int, float)) and beta == 0:
        # The quadratic branch would divide by zero. ``where`` evaluates both
        # branches, so inf/NaN there can leak into gradients. The beta -> 0
        # limit of the loss is plain L1, which is also what the forward pass
        # of the general formula yields.
        return ops.sum(abs_x)
    loss = ops.where(abs_x < beta, 0.5 * x**2 / beta, abs_x - 0.5 * beta)
    return ops.sum(loss)
def postprocess(data, normalization_range):
    """Convert model output back to uint8 image(s).

    The values are clipped to ``normalization_range`` and mapped onto [0, 255]
    with ``translate``. The result is converted to NumPy, its trailing axis is
    dropped (it must have size 1), and it is rounded and cast to uint8.

    Args:
        data: Model output tensor whose last axis has size 1.
        normalization_range: (min, max) value range of the model output.

    Returns:
        ``np.ndarray`` of dtype uint8 with the trailing axis removed.
    """
    data = ops.clip(data, *normalization_range)
    data = translate(data, normalization_range, (0, 255))
    data = ops.convert_to_numpy(data)
    data = np.squeeze(data, axis=-1)
    # Round to the nearest integer before casting. ``astype`` truncates, so
    # float round-off (e.g. 2.9999998) would otherwise become 2 instead of 3
    # and break preprocess -> postprocess round trips.
    return np.clip(np.rint(data), 0, 255).astype("uint8")
def preprocess(data, normalization_range):
    """Prepare image(s) for the model.

    ``data`` is expected to hold uint8 values in [0, 255]. It is cast to a
    float32 tensor and mapped from (0, 255) onto ``normalization_range`` via
    ``translate``. A trailing channel axis of size 1 is then appended.
    """
    as_float = ops.convert_to_tensor(data, dtype="float32")
    rescaled = translate(as_float, (0, 255), normalization_range)
    return ops.expand_dims(rescaled, axis=-1)
def apply_bottom_preservation(
    output_images, input_images, preserve_bottom_percent=30.0, transition_width=10.0
):
    """Blend the bottom rows of ``input_images`` back into ``output_images``.

    The bottom ``preserve_bottom_percent`` of rows is taken from
    ``input_images``. Directly above that band, over ``transition_width``
    percent of the image height, the result fades from ``output_images`` to
    ``input_images`` with a cosine ramp ``0.5 * (1 - cos(pi * t))``. Rows above
    the ramp come from ``output_images``.

    Args:
        output_images: Model output images, (batch, height, width, channels).
        input_images: Original input images, (batch, height, width, channels).
        preserve_bottom_percent: Percentage of the image height, counted from the
            bottom, that is copied from ``input_images`` (default 30%).
        transition_width: Percentage of the image height used for the smooth
            transition above the preserved band (default 10%). If the ramp would
            extend past the top of the image, it is shortened so that it still
            ends exactly at the preserved band.

    Returns:
        Blended images, broadcast from ``output_images`` and ``input_images``.
    """
    _, height, _, _ = ops.shape(output_images)

    preserve_height = int(height * preserve_bottom_percent / 100.0)
    transition_height = int(height * transition_width / 100.0)

    # Rows [preserve_start, height) are copied from the input.
    preserve_start = min(height, height - preserve_height)
    # Rows [transition_start, preserve_start) fade from output to input.
    transition_start = max(0, height - preserve_height - transition_height)
    if transition_start >= preserve_start:
        transition_start = preserve_start
        transition_height = 0
    else:
        # Use the clamped ramp length. If the requested length were kept after
        # clamping transition_start to 0, the ramp would stop short of 1 and
        # leave a visible seam at preserve_start.
        transition_height = preserve_start - transition_start

    # Row index as a (1, height, 1, 1) tensor so it broadcasts over
    # batch, width and channels.
    rows = ops.reshape(ops.arange(height, dtype="float32"), (1, height, 1, 1))

    if transition_height > 0:
        progress = ops.clip((rows - transition_start) / transition_height, 0.0, 1.0)
        ramp = 0.5 * (1.0 - ops.cos(np.pi * progress))
        blend_weight = ops.where(
            rows < transition_start,
            0.0,
            ops.where(rows < preserve_start, ramp, 1.0),
        )
    else:
        # No room for a transition: hard switch at preserve_start.
        blend_weight = ops.where(rows >= preserve_start, 1.0, 0.0)

    return (1.0 - blend_weight) * output_images + blend_weight * input_images
def extract_skeleton(
    images, input_range, sigma_pre=4, sigma_post=4, threshold=0.3, dilation_radius=2
):
    """Extract soft skeleton masks from a batch of single-channel images.

    Processing steps:
    1. Each image is clipped to ``input_range`` and mapped to [0, 1].
    2. Pixels below ``threshold`` are zeroed.
    3. The result is Gaussian-smoothed (``sigma_pre``).
    4. It is binarized with Otsu's threshold and skeletonized.
    5. The skeleton is dilated with a disk of radius ``dilation_radius``.
    6. It is Gaussian-smoothed again (``sigma_post``).

    Args:
        images: Batch of images whose last axis has size 1,
            e.g. (batch, height, width, 1).
        input_range: (min, max) value range of ``images``.
        sigma_pre: Gaussian sigma applied before Otsu binarization.
        sigma_post: Gaussian sigma applied to the dilated skeleton.
        threshold: Value in the [0, 1]-mapped intensity scale below which
            pixels are zeroed before smoothing.
        dilation_radius: Radius of the disk used to thicken the skeleton
            (default 2, the previously hard-coded value).

    Returns:
        Tensor with the shape and dtype of ``images`` holding the masks,
        min-max normalized to [0, 1] jointly over the whole batch (not
        per image).
    """
    images_np = ops.convert_to_numpy(images)
    images_np = np.clip(images_np, input_range[0], input_range[1])
    images_np = translate(images_np, input_range, (0, 1))
    images_np = np.squeeze(images_np, axis=-1)

    if images_np.shape[0] == 0:
        # Nothing to process. np.min/np.max below would raise on an empty batch.
        return ops.convert_to_tensor(
            np.zeros(images_np.shape + (1,)), dtype=images.dtype
        )

    footprint = morphology.disk(dilation_radius)
    skeleton_masks = []
    for img in images_np:
        # Non-mutating equivalent of ``img[img < threshold] = 0``.
        img = np.where(img < threshold, 0.0, img)
        smoothed = filters.gaussian(img, sigma=sigma_pre)
        binary = smoothed > filters.threshold_otsu(smoothed)
        skeleton = morphology.skeletonize(binary)
        skeleton = morphology.dilation(skeleton, footprint)
        skeleton = filters.gaussian(skeleton.astype(np.float32), sigma=sigma_post)
        skeleton_masks.append(skeleton)

    skeleton_masks = np.expand_dims(np.array(skeleton_masks), axis=-1)

    # Normalize to [0, 1] using the min/max over the entire batch; the epsilon
    # guards against division by zero when all masks are constant.
    min_val, max_val = np.min(skeleton_masks), np.max(skeleton_masks)
    skeleton_masks = (skeleton_masks - min_val) / (max_val - min_val + 1e-8)
    return ops.convert_to_tensor(skeleton_masks, dtype=images.dtype)
|