import torch
import numpy as np
import time
import gradio as gr
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import concurrent.futures
from model.CyueNet_models import MMS
from utils1.data import transform_image
from datetime import datetime
import io

# GPU/CPU selection
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# CSS styling
custom_css = """
:root {
    --primary-color: #2196F3;
    --secondary-color: #21CBF3;
    --background-color: #f6f8fa;
    --text-color: #333;
    --border-radius: 10px;
    --glass-bg: rgba(255, 255, 255, 0.25);
    --shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37);
}
.gradio-container {
    background: linear-gradient(135deg, var(--background-color), #ffffff);
    max-width: 1400px !important;
    margin: auto !important;
    backdrop-filter: blur(10px);
}
.output-image, .input-image {
    border-radius: var(--border-radius);
    box-shadow: var(--shadow);
    transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
    backdrop-filter: blur(10px);
    border: 1px solid rgba(255, 255, 255, 0.18);
}
.output-image:hover, .input-image:hover {
    transform: scale(1.02) translateY(-2px);
    box-shadow: 0 12px 40px 0 rgba(31, 38, 135, 0.5);
}
.custom-button {
    background: linear-gradient(45deg, var(--primary-color), var(--secondary-color));
    border: none;
    color: white;
    padding: 12px 24px;
    border-radius: var(--border-radius);
    cursor: pointer;
    transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
    font-weight: bold;
    text-transform: uppercase;
    letter-spacing: 1px;
    box-shadow: var(--shadow);
}
.custom-button:hover {
    transform: translateY(-3px);
    box-shadow: 0 12px 30px rgba(33, 150, 243, 0.4);
}
.advanced-controls {
    background: var(--glass-bg);
    border-radius: 20px;
    padding: 25px;
    box-shadow: var(--shadow);
    backdrop-filter: blur(10px);
    border: 1px solid rgba(255, 255, 255, 0.18);
}
.result-container {
    background: var(--glass-bg);
    border-radius: 20px;
    padding: 20px;
    backdrop-filter: blur(15px);
    border: 1px solid rgba(255, 255, 255, 0.18);
    box-shadow: var(--shadow);
}
.interactive-viz {
    border-radius: 15px;
    overflow: hidden;
    transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
    box-shadow: var(--shadow);
}
.interactive-viz:hover {
    transform: translateY(-5px);
    box-shadow: 0 15px 35px rgba(0,0,0,0.15);
}
.statistics-container {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
    gap: 15px;
    margin-top: 15px;
}
.statistic-card {
    background: var(--glass-bg);
    padding: 20px;
    border-radius: var(--border-radius);
    text-align: center;
    box-shadow: var(--shadow);
    backdrop-filter: blur(10px);
    border: 1px solid rgba(255, 255, 255, 0.18);
    transition: all 0.3s ease;
}
.statistic-card:hover {
    transform: translateY(-2px);
    box-shadow: 0 10px 25px rgba(0,0,0,0.1);
}
.progress-container {
    background: var(--glass-bg);
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
    backdrop-filter: blur(10px);
}
.comparison-slider {
    background: var(--glass-bg);
    border-radius: 15px;
    padding: 20px;
    backdrop-filter: blur(10px);
    border: 1px solid rgba(255, 255, 255, 0.18);
}
"""


class ImageProcessor:
    def __init__(self):
        self.model = None
        self.load_model()
        self.last_results = None
        self.cache = {}

    def load_model(self):
        """Load the pretrained model."""
        self.model = MMS()
        try:
            self.model.load_state_dict(
                torch.load('models/CyueNet_EORSSD6.pth.54', map_location=device))
            print("Model loaded successfully")
        except RuntimeError as e:
            print(f"Model loading error: {e}")
        except FileNotFoundError:
            print("Model file not found; please check the path.")
        self.model.to(device)
        self.model.eval()
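    # Optional warm-up helper: a minimal sketch, not part of the original
    # pipeline (the method name, dummy input, and run count are assumptions).
    # The first CUDA inference pays one-time kernel/cuDNN setup costs, so the
    # reported timings are more representative if a dummy batch runs first.
    def warmup(self, testsize=256, runs=2):
        if device.type != 'cuda':
            return
        dummy = torch.zeros(1, 3, testsize, testsize, device=device)
        with torch.no_grad():
            for _ in range(runs):
                self.model(dummy)
        torch.cuda.synchronize()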
    def adjust_brightness_contrast(self, image, brightness=0, contrast=0):
        """Adjust image brightness and contrast."""
        if brightness != 0:
            if brightness > 0:
                shadow = brightness
                highlight = 255
            else:
                shadow = 0
                highlight = 255 + brightness
            alpha_b = (highlight - shadow) / 255
            gamma_b = shadow
            image = cv2.addWeighted(image, alpha_b, image, 0, gamma_b)
        if contrast != 0:
            # Standard contrast curve: maps contrast in [-127, 127] to a gain factor
            f = 131 * (contrast + 127) / (127 * (131 - contrast))
            alpha_c = f
            gamma_c = 127 * (1 - f)
            image = cv2.addWeighted(image, alpha_c, image, 0, gamma_c)
        return image

    def apply_filters(self, image, filter_type):
        """Apply an image filter effect."""
        if filter_type == "Sharpen":
            kernel = np.array([[-1, -1, -1],
                               [-1,  9, -1],
                               [-1, -1, -1]])
            return cv2.filter2D(image, -1, kernel)
        elif filter_type == "Blur":
            return cv2.GaussianBlur(image, (5, 5), 0)
        elif filter_type == "Edge enhance":
            kernel = np.array([[ 0, -1,  0],
                               [-1,  5, -1],
                               [ 0, -1,  0]])
            return cv2.filter2D(image, -1, kernel)
        return image

    def generate_analysis_plots(self, saliency_map):
        """Generate analysis plots from the raw saliency values (before binarization)."""
        plt.style.use('seaborn-v0_8')
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))

        # Use the raw saliency values (before binarization)
        saliency_values = saliency_map.flatten()

        # Histogram
        ax1.hist(saliency_values, bins=50, color='#2196F3', alpha=0.7, edgecolor='black')
        ax1.set_title('Histogram of Saliency Distribution', fontsize=12, pad=15)
        ax1.set_xlabel('Saliency Value', fontsize=10)
        ax1.set_ylabel('Frequency', fontsize=10)
        ax1.grid(True, alpha=0.3)

        # Summary statistics
        mean_val = np.mean(saliency_values)
        median_val = np.median(saliency_values)
        ax1.axvline(mean_val, color='red', linestyle='--', alpha=0.7,
                    label=f'Mean: {mean_val:.3f}')
        ax1.axvline(median_val, color='green', linestyle='--', alpha=0.7,
                    label=f'Median: {median_val:.3f}')
        ax1.legend()

        # Cumulative distribution
        sorted_vals = np.sort(saliency_values)
        cumulative = np.arange(1, len(sorted_vals) + 1) / len(sorted_vals)
        ax2.plot(sorted_vals, cumulative, color='#FF6B35', linewidth=2)
        ax2.set_title('Cumulative Distribution Function', fontsize=12)
        ax2.set_xlabel('Saliency Value', fontsize=10)
        ax2.set_ylabel('Cumulative Probability', fontsize=10)
        ax2.grid(True, alpha=0.3)

        # Boxplot
        ax3.boxplot(saliency_values, patch_artist=True,
                    boxprops=dict(facecolor='#21CBF3', alpha=0.7))
        ax3.set_title('Boxplot of Saliency Distribution', fontsize=12)
        ax3.set_ylabel('Saliency Value', fontsize=10)
        ax3.grid(True, alpha=0.3)

        # Intensity profile along the horizontal center line
        center_row = saliency_map[saliency_map.shape[0] // 2, :]
        ax4.plot(center_row, color='#9C27B0', linewidth=2)
        ax4.set_title('Intensity Profile along Center Line', fontsize=12)
        ax4.set_xlabel('Pixel Position', fontsize=10)
        ax4.set_ylabel('Saliency Value', fontsize=10)
        ax4.grid(True, alpha=0.3)

        plt.tight_layout()

        # Render the figure to a numpy array via an in-memory PNG buffer
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        buf.seek(0)
        img_array = np.array(Image.open(buf))
        plt.close()
        return img_array
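    # Hypothetical alternative to the manual threshold slider (a sketch, not
    # wired into the UI): Otsu's method derives the binarization threshold
    # from the saliency histogram itself. Assumes `saliency_vis` is a uint8
    # saliency map like the `res_vis` arrays produced below.
    def auto_threshold(self, saliency_vis):
        thresh_val, binary = cv2.threshold(
            saliency_vis, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # Return the chosen threshold normalized to [0, 1] plus the binary mask
        return thresh_val / 255.0, binary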
    def quick_process(self, image, threshold=0.5, testsize=256):
        """Fast path: compute only the saliency map."""
        if image is None:
            return None, "Please provide a valid image"

        # Check the cache
        image_hash = hash(image.tobytes())
        cache_key = f"{image_hash}_{threshold}_{testsize}_quick"
        if cache_key in self.cache:
            return self.cache[cache_key]

        image_pil = Image.fromarray(image).convert('RGB')
        image_tensor = transform_image(image_pil, testsize)
        image_tensor = image_tensor.unsqueeze(0).to(device)

        time_start = time.time()
        with torch.no_grad():
            # Key optimization: run the reduced forward pass that computes only
            # the outputs needed for the saliency map, not the full model
            with torch.amp.autocast(device_type=device.type):
                _, res = self.model.forward_quick(image_tensor)
        time_end = time.time()

        # Cast back to float32 before post-processing
        res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
        res = (res - res.min()) / (res.max() - res.min() + 1e-8)

        h, w = image.shape[:2]
        res_resized = cv2.resize(res, (w, h))
        res_vis = (res_resized * 255).astype(np.uint8)

        result = (res_vis, f"Quick processing finished in {time_end - time_start:.3f} s")
        self.cache[cache_key] = result
        return result

    def process_image(self, image, threshold=0.5, testsize=256, enhance_contrast=False,
                      denoise=False, brightness=0, contrast=0, filter_type="None",
                      process_mode="Full analysis"):
        """Enhanced image-processing pipeline."""
        if image is None:
            return [None] * 6 + ["Please provide a valid image", None, None]

        # Quick-mode shortcut
        if process_mode == "Quick mode":
            saliency_map, time_info = self.quick_process(image, threshold, testsize)
            return (image, saliency_map, None, None, None, None, time_info, None, None)

        # Check the cache for a previous full run with identical parameters
        image_hash = hash(image.tobytes())
        cache_key = (f"{image_hash}_{threshold}_{testsize}_{enhance_contrast}_{denoise}_"
                     f"{brightness}_{contrast}_{filter_type}_full")
        if cache_key in self.cache:
            return self.cache[cache_key]

        # Run preprocessing in a worker thread
        def preprocess_image():
            processed_image = image.copy()
            if denoise:
                processed_image = cv2.fastNlMeansDenoisingColored(processed_image, None, 10, 10, 7, 21)
            processed_image = self.adjust_brightness_contrast(processed_image, brightness, contrast)
            processed_image = self.apply_filters(processed_image, filter_type)
            if enhance_contrast:
                # CLAHE on the L channel in LAB space
                lab = cv2.cvtColor(processed_image, cv2.COLOR_RGB2LAB)
                l, a, b = cv2.split(lab)
                clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
                l = clahe.apply(l)
                lab = cv2.merge((l, a, b))
                processed_image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
            return processed_image

        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_preprocess = executor.submit(preprocess_image)
            processed_image = future_preprocess.result()

        original_image = processed_image.copy()

        # Model inference
        image_pil = Image.fromarray(processed_image).convert('RGB')
        image_tensor = transform_image(image_pil, testsize)
        image_tensor = image_tensor.unsqueeze(0).to(device)

        time_start = time.time()
        with torch.no_grad():
            with torch.amp.autocast(device_type=device.type):
                (x1, res, s1_sig, edg1, edg_s,
                 s2, e2, s2_sig, e2_sig,
                 s3, e3, s3_sig, e3_sig,
                 s4, e4, s4_sig, e4_sig,
                 s5, e5, s5_sig, e5_sig,
                 sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig,
                 sk4, sk4_sig, sk5, sk5_sig) = self.model(image_tensor)
        time_end = time.time()
        inference_time = time_end - time_start

        # Cast back to float32 before post-processing
        res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
        res = (res - res.min()) / (res.max() - res.min() + 1e-8)

        h, w = original_image.shape[:2]
        res_resized = cv2.resize(res, (w, h))

        # Build visualizations
        res_vis = (res_resized * 255).astype(np.uint8)
        heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)
        _, binary_mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)

        # Overlay the heatmap on the original image
        alpha = 0.5
        original_bgr = cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR)
        overlayed = cv2.addWeighted(original_bgr, 1 - alpha, heatmap, alpha, 0)
        segmented = cv2.bitwise_and(original_bgr, original_bgr, mask=binary_mask)

        # Convert back to RGB
        overlayed_rgb = cv2.cvtColor(overlayed, cv2.COLOR_BGR2RGB)
        segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB)
        heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Analysis plots from the raw saliency values (before binarization)
        analysis_plot = self.generate_analysis_plots(res_resized)

        # Statistics
        contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
        total_area = w * h
        detected_area = cv2.countNonZero(binary_mask)
        coverage_ratio = detected_area / total_area

        stats = {
            "Resolution": f"{w}x{h}",
            "Objects detected": str(len(contours)),
            "Mean confidence": f"{np.mean(res_resized):.2%}",
            "Max confidence": f"{np.max(res_resized):.2%}",
            "Coverage ratio": f"{coverage_ratio:.2%}",
            "Processing time": f"{inference_time:.3f} s"
        }

        # Side-by-side comparison image
        comparison_img = self.create_comparison_image(original_image, overlayed_rgb)

        # Keep results for export
        self.last_results = {
            'saliency_map': res_resized,
            'binary_mask': binary_mask,
            'stats': stats
        }

        result = (original_image, res_vis, heatmap_rgb, overlayed_rgb, segmented_rgb,
                  comparison_img, f"Processing time: {inference_time:.4f} s",
                  stats, analysis_plot)

        # Cache the result
        self.cache[cache_key] = result
        return result
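    # The caches above grow without bound as inputs and parameters vary. A
    # minimal bounded-eviction sketch (FIFO order and the size of 32 are
    # assumptions, not from the original code); to adopt it, replace the direct
    # `self.cache[cache_key] = result` assignments with `self._cache_put(...)`.
    def _cache_put(self, key, value, max_entries=32):
        if len(self.cache) >= max_entries:
            # dicts preserve insertion order (Python 3.7+), so this evicts the oldest entry
            self.cache.pop(next(iter(self.cache)))
        self.cache[key] = value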
    def create_comparison_image(self, original, processed):
        """Create a side-by-side comparison image."""
        h, w = original.shape[:2]
        comparison = np.zeros((h, w * 2, 3), dtype=np.uint8)
        comparison[:, :w] = original
        comparison[:, w:] = processed
        # Divider line
        cv2.line(comparison, (w, 0), (w, h), (255, 255, 255), 2)
        return comparison

    def export_results(self, format_type="PNG"):
        """Export the most recent results."""
        if self.last_results is None:
            return "No results available for export"
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        if format_type == "PDF report":
            # PDF report generation logic goes here
            return f"PDF report saved as saliency_report_{timestamp}.pdf"
        else:
            return f"Results exported as a {format_type.lower()} file"


# Create processor instance
processor = ImageProcessor()


def run_demo(input_image, threshold, enhance_contrast, denoise, show_contours,
             brightness, contrast, filter_type, process_mode):
    """Main processing entry point."""
    if input_image is None:
        return [None] * 6 + ["Please upload an image", None, None]

    # Process the image
    results = processor.process_image(
        input_image,
        threshold=threshold / 100.0,
        enhance_contrast=enhance_contrast,
        denoise=denoise,
        brightness=brightness,
        contrast=contrast,
        filter_type=filter_type,
        process_mode=process_mode
    )

    (original, saliency_map, heatmap, overlayed, segmented,
     comparison, time_info, stats, analysis_plot) = results

    # Optionally draw contours on the overlay
    if show_contours and saliency_map is not None and overlayed is not None:
        _, binary = cv2.threshold(saliency_map, 127, 255, cv2.THRESH_BINARY)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        overlay_with_contours = overlayed.copy()
        cv2.drawContours(overlay_with_contours, contours, -1, (0, 255, 0), 2)
        overlayed = overlay_with_contours
    # Build the statistics HTML (styled by the .statistics-container and
    # .statistic-card CSS classes defined above)
    if stats:
        stats_html = "<div class='statistics-container'>"
        for key, value in stats.items():
            stats_html += (
                f"<div class='statistic-card'>"
                f"<h4>{key}</h4><p>{value}</p>"
                f"</div>"
            )
        stats_html += "</div>"
    else:
        stats_html = "<div>No statistics available</div>"

    return (original, saliency_map, heatmap, overlayed, segmented,
            comparison, time_info, stats_html, analysis_plot)


def create_comparison_view(original, result, slider_value):
    """Create a slider-based comparison view."""
    if original is None or result is None:
        return None
    h, w = original.shape[:2]
    split_point = int(w * slider_value)
    comparison = original.copy()
    comparison[:, split_point:] = result[:, split_point:]
    # Vertical divider line
    cv2.line(comparison, (split_point, 0), (split_point, h), (255, 255, 0), 3)
    return comparison
# Create Gradio interface
with gr.Blocks(title="Advanced Salient Object Detection System", css=custom_css) as demo:
    gr.Markdown(
        """
        # 🎯 Advanced Salient Object Detection System
        ### An AI-powered tool for image saliency detection and analysis
        """
    )

    with gr.Tabs() as tabs:
        with gr.TabItem("🔍 Main"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Input controls
                    with gr.Group(elem_classes="advanced-controls"):
                        input_image = gr.Image(
                            label="Input Image",
                            type="numpy",
                            elem_classes="input-image"
                        )

                        # Processing mode selection
                        process_mode = gr.Radio(
                            choices=["Full analysis", "Quick mode"],
                            value="Full analysis",
                            label="Processing Mode",
                            info="Quick mode outputs only the saliency map and runs faster"
                        )

                        with gr.Accordion("Basic Settings", open=True):
                            threshold_slider = gr.Slider(
                                minimum=0,
                                maximum=100,
                                value=50,
                                step=1,
                                label="Detection Threshold",
                                info="Adjusts detection sensitivity"
                            )
                            enhance_contrast = gr.Checkbox(
                                label="Enhance Contrast",
                                value=False
                            )
                            denoise = gr.Checkbox(
                                label="Denoise",
                                value=False
                            )
                            show_contours = gr.Checkbox(
                                label="Show Contours",
                                value=True
                            )

                        with gr.Accordion("Image Adjustments", open=False):
                            brightness = gr.Slider(
                                minimum=-100,
                                maximum=100,
                                value=0,
                                step=1,
                                label="Brightness"
                            )
                            contrast = gr.Slider(
                                minimum=-100,
                                maximum=100,
                                value=0,
                                step=1,
                                label="Contrast"
                            )
                            filter_type = gr.Radio(
                                choices=["None", "Sharpen", "Blur", "Edge enhance"],
                                value="None",
                                label="Image Filter"
                            )

                        with gr.Accordion("Export Options", open=False):
                            export_format = gr.Dropdown(
                                choices=["PNG", "JPEG", "PDF report"],
                                value="PNG",
                                label="Export Format"
                            )
                            export_btn = gr.Button(
                                "Export Results",
                                elem_classes="custom-button"
                            )
                            # Defined here so the status box renders inside the layout
                            export_status = gr.Textbox(label="Export Status")

                        with gr.Row():
                            submit_btn = gr.Button(
                                "Run Detection",
                                variant="primary",
                                elem_classes="custom-button"
                            )
                            reset_btn = gr.Button(
                                "Reset Parameters",
                                elem_classes="custom-button"
                            )

                with gr.Column(scale=2):
                    # Result display
                    with gr.Tabs():
                        with gr.TabItem("Detection Results"):
                            with gr.Row(elem_classes="result-container"):
                                original_output = gr.Image(
                                    label="Original Image",
                                    elem_classes="output-image"
                                )
                                saliency_output = gr.Image(
                                    label="Saliency Map",
                                    elem_classes="output-image"
                                )
                            with gr.Row(elem_classes="result-container"):
                                heatmap_output = gr.Image(
                                    label="Heatmap Analysis",
                                    elem_classes="output-image"
                                )
                                overlayed_output = gr.Image(
                                    label="Overlay",
                                    elem_classes="output-image"
                                )
                            with gr.Row(elem_classes="result-container"):
                                segmented_output = gr.Image(
                                    label="Object Segmentation",
                                    elem_classes="output-image"
                                )
                                comparison_output = gr.Image(
                                    label="Side-by-Side Comparison",
                                    elem_classes="output-image"
                                )

                        with gr.TabItem("Interactive Comparison"):
                            with gr.Group(elem_classes="comparison-slider"):
                                comparison_slider = gr.Slider(
                                    minimum=0,
                                    maximum=1,
                                    value=0.5,
                                    step=0.01,
                                    label="Original ← → Result",
                                    info="Drag the slider to compare the original image with the result"
                                )
                                interactive_comparison = gr.Image(
                                    label="Interactive Comparison View",
                                    elem_classes="interactive-viz"
                                )

                        with gr.TabItem("Analysis Report"):
                            with gr.Group(elem_classes="result-container"):
                                time_info = gr.Textbox(
                                    label="Processing Time",
                                    show_label=True
                                )
                                stats_output = gr.HTML(
                                    label="Statistics"
                                )
                                analysis_plot = gr.Image(
                                    label="Detailed Analysis Plots",
                                    elem_classes="output-image"
                                )

        with gr.TabItem("📖 User Guide"):
            gr.Markdown(
                """
                ## Instructions

                1. **Upload an image**: click the "Input Image" area to upload your image
                2. **Choose a mode**: select "Full analysis" or "Quick mode"
                   - Full analysis: the complete pipeline with all visualizations
                   - Quick mode: fast processing that outputs only the saliency map
                3. **Adjust parameters**:
                   - Use the threshold slider to tune detection sensitivity
                   - Enable contrast enhancement or denoising as needed
                   - Fine-tune brightness, contrast, and filters in the advanced settings
                4. **Run detection**: click the "Run Detection" button to start the analysis
                5. **Review results**: browse the visualizations in the different tabs
                6. **Export**: use the export options to save your results

                ## Features

                - **Saliency map**: shows the saliency distribution across image regions
                - **Heatmap**: color-coded intensity visualization
                - **Overlay**: detection results superimposed on the original image
                - **Object segmentation**: extracts the key object regions
                - **Interactive comparison**: slide to compare the original image and the result
                - **Analysis report**: detailed statistics and analysis plots

                ## Performance Tips

                - Use Quick mode when only the saliency map is needed
                - Lower-resolution images process faster
                - Enable a GPU for better performance
                """
            )

        with gr.TabItem("ℹ️ About"):
            gr.Markdown(
                """
                ## Project Information

                - **Version**: 3.0.0
                - **Stack**: PyTorch + Gradio
                - **Model**: CyueNet
                - **Language**: multilingual support

                ## Highlights

                - Real-time image processing and analysis
                - Multi-dimensional result visualization
                - Rich image adjustment options
                - Detailed statistical analysis reports
                - Interactive comparison tool
                - Export functionality
                - Cache-optimized performance

                ## Changelog

                - ✅ Added Quick mode for faster processing
                - ✅ Expanded image preprocessing options
                - ✅ Added statistical analysis features
                - ✅ Improved the UI with a glassmorphism design
                - ✅ Added an interactive comparison slider
                - ✅ Optimized performance with caching and threading
                - ✅ Multilingual chart support
                - ✅ Export functionality

                ## System Requirements

                - Python 3.8+
                - PyTorch 1.9+
                - CUDA (optional, for GPU acceleration)
                - 4 GB+ RAM recommended
                """
            )

    # Event handlers
    def reset_params():
        return {
            threshold_slider: 50,
            brightness: 0,
            contrast: 0,
            filter_type: "None",
            enhance_contrast: False,
            denoise: False,
            show_contours: True,
            process_mode: "Full analysis"
        }

    # Wire up events
    submit_btn.click(
        fn=run_demo,
        inputs=[
            input_image, threshold_slider, enhance_contrast, denoise,
            show_contours, brightness, contrast, filter_type, process_mode
        ],
        outputs=[
            original_output, saliency_output, heatmap_output, overlayed_output,
            segmented_output, comparison_output, time_info, stats_output, analysis_plot
        ]
    )

    reset_btn.click(
        fn=reset_params,
        inputs=[],
        outputs=[
            threshold_slider, brightness, contrast, filter_type,
            enhance_contrast, denoise, show_contours, process_mode
        ]
    )

    # Interactive comparison
    comparison_slider.change(
        fn=create_comparison_view,
        inputs=[original_output, overlayed_output, comparison_slider],
        outputs=[interactive_comparison]
    )

    # Export
    export_btn.click(
        fn=processor.export_results,
        inputs=[export_format],
        outputs=[export_status]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )
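# Deployment note: share=True opens a public Gradio tunnel. A local-only
# variant with request queuing (a sketch; the queue size and host are
# assumptions, adjust to your deployment):
#
#   demo.queue(max_size=8)   # serialize concurrent GPU inference requests
#   demo.launch(server_name="127.0.0.1", server_port=7860, show_error=True)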