import torch import torch.nn.functional as F import numpy as np import os import time import gradio as gr import cv2 from PIL import Image from model.CyueNet_models import MMS from utils1.data import transform_image # 设置GPU/CPU device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') def load_model(): """加载预训练的模型""" model = MMS() try: # 使用相对路径,模型文件将存储在HuggingFace Spaces上 model.load_state_dict(torch.load('models/CyueNet_EORSSD6.pth.54', map_location=device)) print("模型加载成功") except RuntimeError as e: print(f"加载状态字典时出现部分不匹配,错误信息: {e}") model.to(device) model.eval() return model def process_image(image, model, testsize=256): """处理图像并返回显著性检测结果""" # 预处理图像 image = Image.fromarray(image).convert('RGB') image = transform_image(image, testsize) image = image.unsqueeze(0) image = image.to(device) # 计时 time_start = time.time() # 推理 with torch.no_grad(): x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = model(image) time_end = time.time() inference_time = time_end - time_start # 处理输出结果 res = res.sigmoid().data.cpu().numpy().squeeze() res = (res - res.min()) / (res.max() - res.min() + 1e-8) # 将输出调整为原始图像大小 original_image = np.array(Image.fromarray(image.cpu().squeeze().permute(1, 2, 0).numpy())) h, w = original_image.shape[:2] res_resized = cv2.resize(res, (w, h)) # 转换为可视化图像 res_vis = (res_resized * 255).astype(np.uint8) # 创建热力图 heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET) # 将热力图与原始图像混合 alpha = 0.5 overlayed = cv2.addWeighted(original_image, 1-alpha, heatmap, alpha, 0) # 二值化结果用于分割 _, binary_mask = cv2.threshold(res_vis, 127, 255, cv2.THRESH_BINARY) segmented = cv2.bitwise_and(original_image, original_image, mask=binary_mask) return original_image, res_vis, heatmap, overlayed, segmented, f"推理时间: {inference_time:.4f}秒" def run_demo(input_image): """Gradio界面的主函数""" if input_image is None: return [None] * 5 + ["请上传图片"] # 处理图像 original, saliency_map, heatmap, overlayed, segmented, time_info = process_image(input_image, model) return original, saliency_map, heatmap, overlayed, segmented, time_info # 加载模型 print("正在加载模型...") model = load_model() # 创建Gradio界面 with gr.Blocks(title="显著性目标检测Demo") as demo: gr.Markdown("# 显著性目标检测Demo") gr.Markdown("上传一张图片,系统将自动检测显著性区域") with gr.Row(): with gr.Column(): input_image = gr.Image(label="输入图像", type="numpy") submit_btn = gr.Button("开始检测") with gr.Column(): original_output = gr.Image(label="原始图像") saliency_output = gr.Image(label="显著性图") heatmap_output = gr.Image(label="热力图") overlayed_output = gr.Image(label="叠加结果") segmented_output = gr.Image(label="分割结果") time_info = gr.Textbox(label="处理信息") submit_btn.click( fn=run_demo, inputs=input_image, outputs=[original_output, saliency_output, heatmap_output, overlayed_output, segmented_output, time_info] ) gr.Markdown("## 使用说明") gr.Markdown("1. 点击'输入图像'区域上传一张图片") gr.Markdown("2. 点击'开始检测'按钮进行显著性目标检测") gr.Markdown("3. 系统将显示原始图像、显著性图、热力图、叠加结果和分割结果") # 启动Gradio应用 demo.launch()