# Source: Hugging Face file view - author tristan-deep, commit 969f59e
# ("replaced load_image"), 5.45 kB.
from pathlib import Path
import numpy as np
from keras import ops
from PIL import Image
from skimage import filters, morphology
from zea.utils import translate
def L1(x):
    """Return the L1 norm of ``x``: the sum of the absolute values of all elements.

    Definition: https://mathworld.wolfram.com/L1-Norm.html
    """
    magnitudes = ops.abs(x)
    return ops.sum(magnitudes)
def smooth_L1(x, beta=0.4):
    """Smooth L1 (Huber-style) loss, summed over all elements of ``x``.

    Each element contributes ``0.5 * x**2 / beta`` when ``|x| < beta`` and
    ``|x| - 0.5 * beta`` otherwise. The two pieces agree in value and slope at
    ``|x| == beta``.

    As ``beta`` approaches 0 the quadratic region vanishes and the loss becomes
    the L1 loss. Larger ``beta`` values widen the quadratic region, giving a
    (scaled, ``1 / (2 * beta)``) L2-like response over a wider range of inputs.

    Args:
        x: Tensor of residuals.
        beta: Half-width of the quadratic region. A Python scalar ``0`` selects
            plain L1 directly.

    Returns:
        Scalar tensor with the summed loss.
    """
    abs_x = ops.abs(x)
    if isinstance(beta, (int, float)) and beta == 0:
        # The quadratic branch would divide by zero. ops.where evaluates both
        # branches, so inf/NaN could reach gradients; the loss is L1 here anyway.
        return ops.sum(abs_x)
    loss = ops.where(abs_x < beta, 0.5 * x**2 / beta, abs_x - 0.5 * beta)
    return ops.sum(loss)
def postprocess(data, normalization_range):
    """Postprocess model output into uint8 image(s).

    The output is clipped to ``normalization_range`` and rescaled from that
    range to [0, 255] with ``translate``. It is then moved to NumPy and the
    trailing channel axis is squeezed away; that axis must have size 1.

    Args:
        data: Model output tensor whose last axis has size 1.
        normalization_range: ``(low, high)`` value range of the model output.

    Returns:
        np.ndarray of dtype uint8 with the trailing axis removed.
    """
    data = ops.clip(data, *normalization_range)
    data = translate(data, normalization_range, (0, 255))
    data = np.squeeze(ops.convert_to_numpy(data), axis=-1)
    # Round before casting. astype truncates toward zero, so a rescaled value
    # that lands just below an integer (e.g. 99.99998 after a float32 round
    # trip through preprocess) would otherwise become 99 instead of 100.
    return np.clip(np.rint(data), 0, 255).astype("uint8")
def preprocess(data, normalization_range):
    """Prepare uint8 image(s) in [0, 255] as model input.

    The input is cast to a float32 tensor and rescaled from [0, 255] to
    ``normalization_range`` with ``translate``. A trailing channel axis of
    size 1 is then appended.
    """
    as_float = ops.convert_to_tensor(data, dtype="float32")
    rescaled = translate(as_float, (0, 255), normalization_range)
    return ops.expand_dims(rescaled, axis=-1)
def _bottom_blend_profile(height, preserve_bottom_percent, transition_width):
    """Return per-row blend weights of shape (height,): 0 keeps output, 1 keeps input."""
    preserve_height = int(height * preserve_bottom_percent / 100.0)
    transition_height = int(height * transition_width / 100.0)
    transition_start = max(0, height - preserve_height - transition_height)
    preserve_start = min(height, height - preserve_height)
    rows = np.arange(height, dtype=np.float32)

    if transition_start >= preserve_start:
        # No room for a transition: hard switch at preserve_start.
        return (rows >= preserve_start).astype(np.float32)

    # Ramp over the span that actually fits. When transition_start was clamped
    # to row 0, this is shorter than transition_height. Dividing by the
    # requested height would stop the ramp short of 1 and leave a visible jump
    # at preserve_start.
    span = preserve_start - transition_start
    progress = np.clip((rows - transition_start) / span, 0.0, 1.0)
    # Cosine ease: 0.5 * (1 - cos(pi * t)) rises from 0 to 1 with zero slope at both ends.
    ramp = 0.5 * (1.0 - np.cos(np.pi * progress))
    return np.where(rows >= preserve_start, 1.0, ramp).astype(np.float32)


def apply_bottom_preservation(
    output_images, input_images, preserve_bottom_percent=30.0, transition_width=10.0
):
    """Apply bottom preservation with smooth windowed transition.

    Rows above the transition come from ``output_images``. The bottom
    ``preserve_bottom_percent`` of rows comes from ``input_images``. In between,
    the two are mixed with a cosine ramp that spans up to ``transition_width``
    percent of the height; the ramp is shortened if it would start above the
    first row.

    Args:
        output_images: Model output images, (batch, height, width, channels)
        input_images: Original input images, (batch, height, width, channels)
        preserve_bottom_percent: Percentage of bottom to preserve from input (default 30%)
        transition_width: Percentage of image height for smooth transition (default 10%)

    Returns:
        Blended images with preserved bottom portion
    """
    # Unpacking enforces rank-4 input; only the (static) height is needed.
    _, height, _, _ = ops.shape(output_images)
    height = int(height)
    weights = _bottom_blend_profile(height, preserve_bottom_percent, transition_width)
    # Shape (1, height, 1, 1) broadcasts over batch, width and channels.
    blend_weight = ops.convert_to_tensor(
        weights.reshape(1, height, 1, 1), dtype="float32"
    )
    return (1.0 - blend_weight) * output_images + blend_weight * input_images
def _soft_skeleton(frame, threshold, sigma_pre, sigma_post, footprint):
    """Return a blurred, dilated skeleton of one 2-D image already scaled to [0, 1]."""
    # Zero out faint pixels (values below threshold) before smoothing.
    frame = np.where(frame < threshold, np.zeros_like(frame), frame)
    blurred = filters.gaussian(frame, sigma=sigma_pre)
    foreground = blurred > filters.threshold_otsu(blurred)
    thick = morphology.dilation(morphology.skeletonize(foreground), footprint)
    return filters.gaussian(thick.astype(np.float32), sigma=sigma_post)


def extract_skeleton(images, input_range, sigma_pre=4, sigma_post=4, threshold=0.3):
    """Extract soft skeleton masks from a batch of single-channel images.

    Processing steps:

    1. Clip each image to ``input_range``, rescale to [0, 1] with
       ``translate``, and zero pixels below ``threshold``.
    2. Blur with ``sigma_pre``, binarize with Otsu's threshold, and skeletonize.
    3. Dilate with a radius-2 disk and blur again with ``sigma_post``.
    4. Min-max normalize the stacked masks to [0, 1] jointly over the whole
       batch (not per image).

    Expects ``images`` shaped (batch, height, width, 1): the trailing axis is
    squeezed and every remaining 2-D slice is processed. Returns a tensor of
    the same layout with ``images.dtype``.
    """
    low, high = input_range[0], input_range[1]
    unit = translate(np.clip(ops.convert_to_numpy(images), low, high), input_range, (0, 1))
    frames = np.squeeze(unit, axis=-1)
    footprint = morphology.disk(2)
    stacked = np.array(
        [_soft_skeleton(f, threshold, sigma_pre, sigma_post, footprint) for f in frames]
    )
    stacked = stacked[..., np.newaxis]
    lowest, highest = stacked.min(), stacked.max()
    stacked = (stacked - lowest) / (highest - lowest + 1e-8)
    return ops.convert_to_tensor(stacked, dtype=images.dtype)
def load_image(filename, grayscale=True):
    """Load an image file and return a numpy array using PIL.

    Args:
        filename (str | Path): The path to the image file.
        grayscale (bool, optional): If True, convert to PIL mode "L" (8-bit
            grayscale); otherwise convert to mode "RGB". Defaults to True.

    Returns:
        numpy.ndarray: The image, shaped (height, width) for "L" or
        (height, width, 3) for "RGB".

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    filename = Path(filename)
    if not filename.exists():
        raise FileNotFoundError(f"File {filename} does not exist")
    # Use a context manager so the file handle is released deterministically.
    # Pillow may keep it open after loading (e.g. multi-frame formats) until
    # the Image object is garbage-collected. convert() returns an independent,
    # fully loaded image, so it stays valid after the source is closed.
    with Image.open(filename) as img:
        converted = img.convert("L" if grayscale else "RGB")
    return np.array(converted)