import warnings
from glob import glob
from pathlib import Path

import numpy as np
import torch
import tyro
from PIL import Image
from scipy.ndimage import binary_erosion, distance_transform_edt
from scipy.stats import ks_2samp
from zea import log

import fid_score
from plots import plot_metrics
from utils import load_image

def calculate_fid_score(denoised_image_dirs, ground_truth_dir):
    """Compute the FID between denoised images and the clean reference set."""
    if isinstance(denoised_image_dirs, (str, Path)):
        denoised_image_dirs = [denoised_image_dirs]
    elif not isinstance(denoised_image_dirs, list):
        raise ValueError("Input must be a path or list of paths")

    clean_images_folder = glob(str(ground_truth_dir) + "/*.png")
    print(f"Looking for clean images in: {ground_truth_dir}")
    print(f"Found {len(clean_images_folder)} clean images")

    # Determine batch size based on the number of available images
    num_denoised = len(denoised_image_dirs)
    num_clean = len(clean_images_folder)
    optimal_batch_size = min(8, num_denoised, num_clean)
    print(f"Using batch size: {optimal_batch_size}")

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="os.fork.*JAX is multithreaded")
        fid_value = fid_score.calculate_fid_with_cached_ground_truth(
            denoised_image_dirs,
            clean_images_folder,
            batch_size=optimal_batch_size,
            device="cuda" if torch.cuda.is_available() else "cpu",
            num_workers=2 if torch.cuda.is_available() else 0,
            dims=2048,
        )
    return fid_value

def gcnr(img1, img2):
    """Generalized Contrast-to-Noise Ratio"""
    # Shared bins so the two histograms are directly comparable
    _, bins = np.histogram(np.concatenate((img1, img2)), bins=256)
    f, _ = np.histogram(img1, bins=bins, density=True)
    g, _ = np.histogram(img2, bins=bins, density=True)
    f /= f.sum()
    g /= g.sum()
    # One minus the overlap between the two intensity distributions
    return 1 - np.sum(np.minimum(f, g))
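
# gCNR summarizes histogram overlap: with shared bins, gCNR = 1 - sum_k min(f_k, g_k),
# so fully separable ROI intensity distributions score near 1 and identical
# distributions score near 0. A minimal sanity check with synthetic arrays
# (illustrative only, not part of the evaluation pipeline):
#
#     gcnr(np.zeros(1000), np.full(1000, 255.0))        # expected ~1.0
#     gcnr(np.full(1000, 128.0), np.full(1000, 128.0))  # expected ~0.0
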
def cnr(img1, img2):
    """Contrast-to-Noise Ratio"""
    return (img1.mean() - img2.mean()) / np.sqrt(img1.var() + img2.var())

def calculate_cnr_gcnr(result_dehazed_cardiac_ultrasound, mask_path):
    """
    Evaluate gCNR and CNR for a dehazed image using its paired ROI mask.

    The mask encodes the foreground ROI as 255 and the background/noise ROI as 128.
    Returns a list containing one [CNR, gCNR] pair.
    """
    results = []
    mask = np.array(Image.open(mask_path).convert("L"))
    roi1_pixels = result_dehazed_cardiac_ultrasound[mask == 255]  # Foreground ROI
    roi2_pixels = result_dehazed_cardiac_ultrasound[mask == 128]  # Background/noise ROI
    gcnr_val = gcnr(roi1_pixels, roi2_pixels)
    cnr_val = cnr(roi1_pixels, roi2_pixels)
    results.append([cnr_val, gcnr_val])
    return results

def calculate_ks_statistics(
    result_hazy_cardiac_ultrasound, result_dehazed_cardiac_ultrasound, mask_path
):
    """Two-sample KS statistics between hazy and dehazed pixels in each ROI."""
    mask = np.array(Image.open(mask_path).convert("L"))
    roi1_original = result_hazy_cardiac_ultrasound[mask == 255]  # region A
    roi1_denoised = result_dehazed_cardiac_ultrasound[mask == 255]
    roi2_original = result_hazy_cardiac_ultrasound[mask == 128]  # region B
    roi2_denoised = result_dehazed_cardiac_ultrasound[mask == 128]

    roi1_ks_stat, roi1_ks_p_value = (None, None)
    roi2_ks_stat, roi2_ks_p_value = (None, None)
    if roi1_original.size > 0 and roi1_denoised.size > 0:
        roi1_ks_stat, roi1_ks_p_value = ks_2samp(roi1_original, roi1_denoised)
    if roi2_original.size > 0 and roi2_denoised.size > 0:
        roi2_ks_stat, roi2_ks_p_value = ks_2samp(roi2_original, roi2_denoised)
    return roi1_ks_stat, roi1_ks_p_value, roi2_ks_stat, roi2_ks_p_value
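
# Interpretation note: region A (mask == 255) and region B (mask == 128) enter the
# final score in opposite directions (see calculate_final_score below): a small KS
# statistic in region A rewards preserving the original intensity distribution, while
# a large KS statistic in region B rewards changing (dehazing) it.
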
def calculate_dice_asd(image_path, label_path, checkpoint_path, image_size=224):
    """Dice coefficient and Average Surface Distance (ASD) against a segmentation label."""
    try:
        from test import inference  # Our segmentation method
    except ImportError:
        raise ImportError(
            "Segmentation method not available, cannot compute Dice/ASD"
        )

    pred_img = inference(image_path, checkpoint_path, image_size)
    pred = np.array(pred_img) > 127
    label = Image.open(label_path).convert("L")
    label = label.resize((image_size, image_size), Image.NEAREST)
    label = np.array(label) > 127

    # Dice coefficient
    intersection = np.logical_and(pred, label).sum()
    dice = 2 * intersection / (pred.sum() + label.sum() + 1e-8)

    # ASD: mean symmetric distance (in pixels) between the two boundary surfaces
    if pred.sum() == 0 or label.sum() == 0:
        asd = np.nan
    else:
        pred_dt = distance_transform_edt(~pred)
        label_dt = distance_transform_edt(~label)
        surface_pred = pred ^ binary_erosion(pred)
        surface_label = label ^ binary_erosion(label)
        d1 = pred_dt[surface_label].mean()
        d2 = label_dt[surface_pred].mean()
        asd = (d1 + d2) / 2
    return dice, asd
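
# Sketch of a standalone call (hypothetical paths and checkpoint, shown for
# illustration; this script's evaluate() does not invoke calculate_dice_asd itself):
#
#     dice, asd = calculate_dice_asd(
#         "outputs/patient_0001.png",    # dehazed image (hypothetical)
#         "labels/patient_0001.png",     # segmentation label (hypothetical)
#         "checkpoints/segmenter.pth",   # segmentation checkpoint (hypothetical)
#         image_size=224,
#     )
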
def calculate_final_score(aggregates):
    """Combine normalized metrics into the final weighted score.

    Group weights: (FID + CNR + gCNR) : (KS^A + KS^B) : (Dice + ASD) = 5 : 3 : 2.
    Each metric is min-max normalized to [0, 1] over a fixed range, then scaled to 0-100.
    """
    try:
        group1_score = 0  # FID + CNR + gCNR
        if aggregates.get("fid") is not None:
            fid_min = 60.0
            fid_max = 150.0
            fid_norm = (fid_max - aggregates["fid"]) / (fid_max - fid_min)
            fid_norm = max(0, min(1, fid_norm))
            group1_score += fid_norm * 100 * 0.33
        if aggregates.get("cnr_mean") is not None:
            cnr_min = 1.0
            cnr_max = 1.5
            cnr_score = (aggregates["cnr_mean"] - cnr_min) / (cnr_max - cnr_min)
            cnr_score = max(0, min(1, cnr_score))
            group1_score += cnr_score * 100 * 0.33
        if aggregates.get("gcnr_mean") is not None:
            gcnr_min = 0.5
            gcnr_max = 0.8
            gcnr_score = (aggregates["gcnr_mean"] - gcnr_min) / (gcnr_max - gcnr_min)
            gcnr_score = max(0, min(1, gcnr_score))
            group1_score += gcnr_score * 100 * 0.34

        group2_score = 0  # KS^A + KS^B
        if aggregates.get("ks_roi1_ksstatistic_mean") is not None:
            ks1_min = 0.1
            ks1_max = 0.3
            ks1_score = (ks1_max - aggregates["ks_roi1_ksstatistic_mean"]) / (
                ks1_max - ks1_min
            )
            ks1_score = max(0, min(1, ks1_score))
            group2_score += ks1_score * 100 * 0.5
        if aggregates.get("ks_roi2_ksstatistic_mean") is not None:
            ks2_min = 0.0
            ks2_max = 0.5
            ks2_score = (aggregates["ks_roi2_ksstatistic_mean"] - ks2_min) / (
                ks2_max - ks2_min
            )
            ks2_score = max(0, min(1, ks2_score))
            group2_score += ks2_score * 100 * 0.5

        group3_score = 0  # Dice + ASD
        if aggregates.get("dice_mean") is not None:
            dice_min = 0.85
            dice_max = 0.95
            dice_score = (aggregates["dice_mean"] - dice_min) / (dice_max - dice_min)
            dice_score = max(0, min(1, dice_score))
            group3_score += dice_score * 100 * 0.5
        if aggregates.get("asd_mean") is not None:
            asd_min = 0.7
            asd_max = 2.0
            asd_score = (asd_max - aggregates["asd_mean"]) / (asd_max - asd_min)
            asd_score = max(0, min(1, asd_score))
            group3_score += asd_score * 100 * 0.5

        # Weighted combination of the three groups
        final_score = (group1_score * 5 + group2_score * 3 + group3_score * 2) / 10
        return final_score
    except Exception as e:
        print(f"Error calculating final score: {str(e)}")
        return 0
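
# Worked example with illustrative mid-range values (not measured results):
# FID = 100, CNR mean = 1.25, gCNR mean = 0.65, KS_A mean = 0.2, KS_B mean = 0.25,
# and no Dice/ASD entries in `aggregates` (as produced by evaluate() below):
#   group1 ~ (50/90)*33 + 0.5*33 + 0.5*34 ~ 18.3 + 16.5 + 17.0 = 51.8
#   group2 = 0.5*50 + 0.5*50 = 50.0
#   group3 = 0 (dice_mean/asd_mean absent)
#   final  = (51.8*5 + 50.0*3 + 0*2) / 10 ~ 40.9
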
def evaluate(folder: str, noisy_folder: str, roi_folder: str, reference_folder: str):
    """Evaluate the dehazing algorithm.

    Args:
        folder (str): Path to the folder containing the dehazed images.
            Used for evaluating all metrics.
        noisy_folder (str): Path to the folder containing the noisy images.
            Only used for KS statistics.
        roi_folder (str): Path to the folder containing the ROI images.
            Used for contrast and KS statistic metrics.
        reference_folder (str): Path to the folder containing the reference images.
            Used only for FID calculation.
    """
    folder = Path(folder)
    noisy_folder = Path(noisy_folder)
    roi_folder = Path(roi_folder)
    reference_folder = Path(reference_folder)

    folder_files = set(f.name for f in folder.glob("*.png"))
    noisy_files = set(f.name for f in noisy_folder.glob("*.png"))
    roi_files = set(f.name for f in roi_folder.glob("*.png"))
    print(f"Found {len(folder_files)} .png files in output folder: {folder}")
    print(f"Found {len(noisy_files)} .png files in noisy folder: {noisy_folder}")
    print(f"Found {len(roi_files)} .png files in ROI folder: {roi_folder}")

    # Keep only filenames present in all three folders
    common_files = sorted(folder_files & roi_files & noisy_files)
    print(f"Found {len(common_files)} matching images in noisy/dehazed/roi folders")
    assert len(common_files) > 0, (
        "No matching .png files in all folders. Cannot proceed."
    )
    metrics = {"CNR": [], "gCNR": [], "KS_A": [], "KS_B": []}
    limits = {
        "CNR": [1.0, 1.5],
        "gCNR": [0.5, 0.8],
        "KS_A": [0.1, 0.3],
        "KS_B": [0.0, 0.5],
    }
    for name in common_files:
        dehazed_path = folder / name
        noisy_path = noisy_folder / name
        roi_path = roi_folder / name
        try:
            img_dehazed = np.array(load_image(str(dehazed_path)))
            img_noisy = np.array(load_image(str(noisy_path)))
        except Exception as e:
            print(f"Error loading image {name}: {e}")
            continue
        # CNR/gCNR
        cnr_gcnr = calculate_cnr_gcnr(img_dehazed, str(roi_path))
        metrics["CNR"].append(cnr_gcnr[0][0])
        metrics["gCNR"].append(cnr_gcnr[0][1])
        # KS statistics
        ks_a, _, ks_b, _ = calculate_ks_statistics(
            img_noisy, img_dehazed, str(roi_path)
        )
        metrics["KS_A"].append(ks_a)
        metrics["KS_B"].append(ks_b)

    # Compute summary statistics
    stats = {
        k: (np.mean(v), np.std(v), np.min(v), np.max(v)) for k, v in metrics.items()
    }
    print("Contrast statistics:")
    for k, (mean, std, minv, maxv) in stats.items():
        print(f"{k}: mean={mean:.3f}, std={std:.3f}, min={minv:.3f}, max={maxv:.3f}")

    fig = plot_metrics(metrics, limits, "contrast_metrics.png")
    path = Path("contrast_metrics.png")
    save_kwargs = {"bbox_inches": "tight", "dpi": 300}
    fig.savefig(path, **save_kwargs)
    fig.savefig(path.with_suffix(".pdf"), **save_kwargs)
    log.success(f"Metrics plot saved to {log.yellow(path)}")
    # Compute FID against the reference folder
    fid_image_paths = [str(folder / name) for name in common_files]
    fid_value = calculate_fid_score(fid_image_paths, str(reference_folder))
    print(f"FID between {folder} and {reference_folder}: {fid_value:.3f}")

    # Collect aggregates for the final score calculation
    aggregates = {
        "fid": float(fid_value),
        "cnr_mean": float(np.mean(metrics["CNR"])),
        "cnr_std": float(np.std(metrics["CNR"])),
        "gcnr_mean": float(np.mean(metrics["gCNR"])),
        "gcnr_std": float(np.std(metrics["gCNR"])),
        "ks_roi1_ksstatistic_mean": float(np.mean(metrics["KS_A"])),
        "ks_roi1_ksstatistic_std": float(np.std(metrics["KS_A"])),
        "ks_roi2_ksstatistic_mean": float(np.mean(metrics["KS_B"])),
        "ks_roi2_ksstatistic_std": float(np.std(metrics["KS_B"])),
    }

    # Calculate final score
    final_score = calculate_final_score(aggregates)
    aggregates["final_score"] = float(final_score)
    return aggregates

if __name__ == "__main__":
    tyro.cli(evaluate)
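
# Example invocation (a sketch; the script name and paths are placeholders, and the
# flag names assume tyro's default convention of deriving kebab-case flags from the
# evaluate() arguments):
#
#   python evaluate.py \
#       --folder ./dehazed \
#       --noisy-folder ./noisy \
#       --roi-folder ./rois \
#       --reference-folder ./clean_reference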