import os
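
# The Keras backend must be selected before Keras itself is imported, which is
# why this assignment deliberately sits between the imports.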
os.environ["KERAS_BACKEND"] = "jax"

import subprocess

import gradio as gr
import jax
import keras
import numpy as np
import spaces
import torch
from PIL import Image
from zea import init_device

from main import Config, init, run
from utils import load_image

CONFIG_PATH = "configs/semantic_dps.yaml"
SLIDER_CONFIG_PATH = "configs/slider_params.yaml"
ASSETS_DIR = "assets"
DEVICE = None
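
# Inline CSS for the HTML status banner; the LOAD variant adds extra top
# padding for the initial "loading" message.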
STATUS_STYLE_LOAD = "display:flex;align-items:center;justify-content:center;padding:40px 10px 18px 10px;border-radius:8px;font-weight:bold;font-size:1.15em;line-height:1.5;"

STATUS_STYLE = "display:flex;align-items:center;justify-content:center;padding:18px 18px 18px 10px;border-radius:8px;font-weight:bold;font-size:1.15em;line-height:1.1;"

description = """
# Cardiac Ultrasound Dehazing with Semantic Diffusion

Select an example image below to see the dehazing algorithm in action. The algorithm was tuned for the DehazingEcho2025 challenge dataset, so results on other datasets may be unreliable.

Tip: Adjust "Omega Ventricle" and "Eta (haze prior)" to control the strength of the dehazing effect.
"""
config, diffusion_model = None, None
model_loaded = False


def initialize_model():
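    """Load the config and diffusion model, then run a one-step warm-up pass
    on a dummy image so the first real request is not slowed by one-time
    compilation."""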
    global config, diffusion_model, model_loaded
    if config is None or diffusion_model is None:
        config = Config.from_yaml(CONFIG_PATH)
        diffusion_model = init(config)

    h, w = diffusion_model.input_shape[:2]
    dummy_img = np.zeros((1, h, w), dtype=np.float32)
    params = config.params
    guidance_kwargs = {
        "omega": params["guidance_kwargs"]["omega"],
        "omega_vent": params["guidance_kwargs"].get("omega_vent", 1.0),
        "omega_sept": params["guidance_kwargs"].get("omega_sept", 1.0),
        "eta": params["guidance_kwargs"].get("eta", 1.0),
        "smooth_l1_beta": params["guidance_kwargs"]["smooth_l1_beta"],
    }
    seed = jax.random.PRNGKey(config.seed)
    run(
        hazy_images=dummy_img,
        diffusion_model=diffusion_model,
        seed=seed,
        guidance_kwargs=guidance_kwargs,
        mask_params=params["mask_params"],
        fixed_mask_params=params["fixed_mask_params"],
        skeleton_params=params["skeleton_params"],
        batch_size=1,
        diffusion_steps=1,
        verbose=False,
    )

    model_loaded = True
    return config, diffusion_model


@spaces.GPU(duration=10)
def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta):
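    """Dehaze one ultrasound frame on a GPU worker (`@spaces.GPU`).

    Implemented as a generator so status banners can be streamed to the UI;
    the final yield carries the (input, dehazed) pair for the image slider."""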
    global config, diffusion_model, model_loaded
    if not model_loaded:
        yield (
            gr.update(
                value=f'<div style="background:#ffeeba;{STATUS_STYLE}color:#856404;">⏳ Model is still loading. Please wait...</div>'
            ),
            None,
        )
        return

    if input_img is None:
        yield (
            gr.update(
                value=f'<div style="background:#ffeeba;{STATUS_STYLE}color:#856404;">⚠️ No input image was provided. Please select or upload an image before running.</div>'
            ),
            None,
        )
        return

    params = config.params

    def _prepare_image(image):
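        """Convert to grayscale and resize to the model input size.

        Returns a float32 batch of shape (1, h, w), whether the image was
        resized, and the original (h, w) shape."""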
        resized = False
        if image.mode != "L":
            image = image.convert("L")
        orig_shape = image.size[::-1]  # PIL `size` is (w, h); store (h, w)
        h, w = diffusion_model.input_shape[:2]
        if image.size != (w, h):
            image = image.resize((w, h), Image.BILINEAR)
            resized = True
        image = np.array(image)
        image = image.astype(np.float32)
        image = image[None, ...]
        return image, resized, orig_shape

    try:
        image, resized, orig_shape = _prepare_image(input_img)
    except Exception as e:
        yield (
            gr.update(
                value=f'<div style="background:#f8d7da;{STATUS_STYLE}color:#721c24;">❌ Error preparing input image: {e}</div>'
            ),
            None,
        )
        return

    # The slider values override the YAML defaults for the guidance weights;
    # smooth_l1_beta stays fixed from the config.
    guidance_kwargs = {
        "omega": omega,
        "omega_vent": omega_vent,
        "omega_sept": omega_sept,
        "eta": eta,
        "smooth_l1_beta": params["guidance_kwargs"]["smooth_l1_beta"],
    }

    seed = jax.random.PRNGKey(config.seed)

    try:
        yield (
            gr.update(
                value=f'<div style="background:#cce5ff;{STATUS_STYLE}color:#004085;">🌀 Running dehazing algorithm...</div>'
            ),
            None,
        )
        # `run` returns several arrays; the predicted tissue (dehazed) images
        # are the second element.
        _, pred_tissue_images, *_ = run(
            hazy_images=image,
            diffusion_model=diffusion_model,
            seed=seed,
            guidance_kwargs=guidance_kwargs,
            mask_params=params["mask_params"],
            fixed_mask_params=params["fixed_mask_params"],
            skeleton_params=params["skeleton_params"],
            batch_size=1,
            diffusion_steps=diffusion_steps,
            threshold_output_quantile=params.get("threshold_output_quantile", None),
            preserve_bottom_percent=params.get("preserve_bottom_percent", 30.0),
            bottom_transition_width=params.get("bottom_transition_width", 10.0),
            verbose=False,
        )
    except Exception as e:
        yield (
            gr.update(
                value=f'<div style="background:#f8d7da;{STATUS_STYLE}color:#721c24;">❌ The algorithm failed to process the image: {e}</div>'
            ),
            None,
        )
        return

    # Convert the prediction back to an 8-bit image (output is assumed to lie
    # in [0, 255]), restoring the original resolution if the input was resized.
    out_img = np.squeeze(pred_tissue_images[0])
    out_img = np.clip(out_img, 0, 255).astype(np.uint8)
    out_pil = Image.fromarray(out_img)
    if resized and out_pil.size != (orig_shape[1], orig_shape[0]):
        out_pil = out_pil.resize((orig_shape[1], orig_shape[0]), Image.BILINEAR)
    yield (
        gr.update(
            value=f'<div style="background:#d4edda;{STATUS_STYLE}color:#155724;">✅ Done!</div>'
        ),
        (input_img, out_pil),
    )

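
# Slider ranges and defaults live in a separate YAML file so they can be tuned
# without touching the app code.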
slider_params = Config.from_yaml(SLIDER_CONFIG_PATH)

diffusion_steps_default = slider_params["diffusion_steps"]["default"]
diffusion_steps_min = slider_params["diffusion_steps"]["min"]
diffusion_steps_max = slider_params["diffusion_steps"]["max"]
diffusion_steps_step = slider_params["diffusion_steps"]["step"]

omega_default = slider_params["omega"]["default"]
omega_min = slider_params["omega"]["min"]
omega_max = slider_params["omega"]["max"]
omega_step = slider_params["omega"]["step"]

omega_vent_default = slider_params["omega_vent"]["default"]
omega_vent_min = slider_params["omega_vent"]["min"]
omega_vent_max = slider_params["omega_vent"]["max"]
omega_vent_step = slider_params["omega_vent"]["step"]

omega_sept_default = slider_params["omega_sept"]["default"]
omega_sept_min = slider_params["omega_sept"]["min"]
omega_sept_max = slider_params["omega_sept"]["max"]
omega_sept_step = slider_params["omega_sept"]["step"]

eta_default = slider_params["eta"]["default"]
eta_min = slider_params["eta"]["min"]
eta_max = slider_params["eta"]["max"]
eta_step = slider_params["eta"]["step"]

# Pre-load the example PNGs shipped in the assets directory; sorted so the
# example order is deterministic across runs.
example_image_paths = sorted(
    os.path.join(ASSETS_DIR, f)
    for f in os.listdir(ASSETS_DIR)
    if f.lower().endswith(".png")
)
example_images = [load_image(p) for p in example_image_paths]
examples = [[img] for img in example_images]


with gr.Blocks() as demo:
    gr.Markdown(description)
    status = gr.Markdown(
        f'<div style="background:#ffeeba;{STATUS_STYLE_LOAD}color:#856404;">⏳ Loading model...</div>',
        visible=True,
    )
    with gr.Row():
        with gr.Column():
            img1 = gr.Image(
                label="Input Image",
                type="pil",
                webcam_options=False,
                value=example_images[0] if example_images else None,
            )
            gr.Examples(examples=examples, inputs=[img1])
        with gr.Column():
            img2 = gr.ImageSlider(label="Dehazed Image", type="pil")
    with gr.Row():
        diffusion_steps_slider = gr.Slider(
            minimum=diffusion_steps_min,
            maximum=diffusion_steps_max,
            step=diffusion_steps_step,
            value=diffusion_steps_default,
            label="Diffusion Steps",
        )
        omega_slider = gr.Slider(
            minimum=omega_min,
            maximum=omega_max,
            step=omega_step,
            value=omega_default,
            label="Omega (background)",
        )
        omega_vent_slider = gr.Slider(
            minimum=omega_vent_min,
            maximum=omega_vent_max,
            step=omega_vent_step,
            value=omega_vent_default,
            label="Omega Ventricle",
        )
        omega_sept_slider = gr.Slider(
            minimum=omega_sept_min,
            maximum=omega_sept_max,
            step=omega_sept_step,
            value=omega_sept_default,
            label="Omega Septum",
        )
        eta_slider = gr.Slider(
            minimum=eta_min,
            maximum=eta_max,
            step=eta_step,
            value=eta_default,
            label="Eta (haze prior)",
        )
    run_btn = gr.Button("Run", interactive=False)

    run_btn.click(
        process_image,
        inputs=[
            img1,
            diffusion_steps_slider,
            omega_slider,
            omega_vent_slider,
            omega_sept_slider,
            eta_slider,
        ],
        outputs=[status, img2],
        queue=True,
    )

    def load_model_event():
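        """Run once at start-up via `demo.load`: print environment diagnostics,
        initialize the diffusion model, and enable the Run button on success."""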
        global config, diffusion_model, model_loaded, DEVICE
        try:
            if DEVICE is None:
                try:
                    DEVICE = init_device()
                except Exception:
                    print("Could not initialize device using `zea.init_device()`")
            print(f"Keras version: {keras.__version__}")
            try:
                print(f"JAX version: {jax.__version__}")
                print(f"JAX devices: {jax.devices()}")
            except Exception as e:
                print(f"Could not get JAX info: {e}")

            try:
                print(f"PyTorch version: {torch.__version__}")
                print(f"PyTorch CUDA available: {torch.cuda.is_available()}")
                print(f"PyTorch CUDA device count: {torch.cuda.device_count()}")
                print(
                    f"PyTorch devices: {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}"
                )
                print(f"PyTorch CUDA version: {torch.version.cuda}")
                print(f"PyTorch cuDNN version: {torch.backends.cudnn.version()}")
            except Exception as e:
                print(f"Could not get PyTorch info: {e}")

            try:
                cuda_version = subprocess.getoutput("nvcc --version")
                print(f"nvcc version:\n{cuda_version}")
                nvidia_smi = subprocess.getoutput("nvidia-smi")
                print(f"nvidia-smi output:\n{nvidia_smi}")
            except Exception as e:
                print(f"Could not get CUDA/nvidia-smi info: {e}")

            config, diffusion_model = initialize_model()
            ready_msg = gr.update(
                value=f'<div style="background:#d4edda;{STATUS_STYLE}color:#155724;">✅ Model loaded! You can now press Run.</div>'
            )
            return ready_msg, gr.update(interactive=True)
        except Exception as e:
            return gr.update(
                value=f'<div style="background:#f8d7da;{STATUS_STYLE}color:#721c24;">❌ Error loading model: {e}</div>'
            ), gr.update(interactive=False)

    demo.load(
        load_model_event,
        inputs=None,
        outputs=[status, run_btn],
    )


if __name__ == "__main__":
    demo.launch()