kunkk's picture
Update app.py
9b80d35 verified
import torch
import torch.nn.functional as F
import numpy as np
import os
import time
import gradio as gr
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import concurrent.futures
from model.CyueNet_models import MMS
from utils1.data import transform_image
from datetime import datetime
import io
import base64
# Pick the compute device once at import time: the first CUDA GPU when one is
# available, otherwise the CPU. All tensors and the model are moved here.
device = (
    torch.device("cuda", 0)
    if torch.cuda.is_available()
    else torch.device("cpu")
)
# Custom CSS passed to gr.Blocks(css=...): CSS variables plus "glass"-style
# panels, image cards, buttons and statistic cards. The class names are
# attached to components via elem_classes=... or emitted in the statistics
# HTML that run_demo builds; `.progress-container` is not referenced anywhere
# in this file.
custom_css = """
:root {
--primary-color: #2196F3;
--secondary-color: #21CBF3;
--background-color: #f6f8fa;
--text-color: #333;
--border-radius: 10px;
--glass-bg: rgba(255, 255, 255, 0.25);
--shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37);
}
.gradio-container {
background: linear-gradient(135deg, var(--background-color), #ffffff);
max-width: 1400px !important;
margin: auto !important;
backdrop-filter: blur(10px);
}
.output-image, .input-image {
border-radius: var(--border-radius);
box-shadow: var(--shadow);
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 255, 255, 0.18);
}
.output-image:hover, .input-image:hover {
transform: scale(1.02) translateY(-2px);
box-shadow: 0 12px 40px 0 rgba(31, 38, 135, 0.5);
}
.custom-button {
background: linear-gradient(45deg, var(--primary-color), var(--secondary-color));
border: none;
color: white;
padding: 12px 24px;
border-radius: var(--border-radius);
cursor: pointer;
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
font-weight: bold;
text-transform: uppercase;
letter-spacing: 1px;
box-shadow: var(--shadow);
}
.custom-button:hover {
transform: translateY(-3px);
box-shadow: 0 12px 30px rgba(33, 150, 243, 0.4);
}
.advanced-controls {
background: var(--glass-bg);
border-radius: 20px;
padding: 25px;
box-shadow: var(--shadow);
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 255, 255, 0.18);
}
.result-container {
background: var(--glass-bg);
border-radius: 20px;
padding: 20px;
backdrop-filter: blur(15px);
border: 1px solid rgba(255, 255, 255, 0.18);
box-shadow: var(--shadow);
}
.interactive-viz {
border-radius: 15px;
overflow: hidden;
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
box-shadow: var(--shadow);
}
.interactive-viz:hover {
transform: translateY(-5px);
box-shadow: 0 15px 35px rgba(0,0,0,0.15);
}
.statistics-container {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 15px;
margin-top: 15px;
}
.statistic-card {
background: var(--glass-bg);
padding: 20px;
border-radius: var(--border-radius);
text-align: center;
box-shadow: var(--shadow);
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 255, 255, 0.18);
transition: all 0.3s ease;
}
.statistic-card:hover {
transform: translateY(-2px);
box-shadow: 0 10px 25px rgba(0,0,0,0.1);
}
.progress-container {
background: var(--glass-bg);
border-radius: 10px;
padding: 15px;
margin: 10px 0;
backdrop-filter: blur(10px);
}
.comparison-slider {
background: var(--glass-bg);
border-radius: 15px;
padding: 20px;
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 255, 255, 0.18);
}
"""
class ImageProcessor:
    """Saliency-detection pipeline around the CyueNet ``MMS`` model.

    Owns one model instance (loaded in ``__init__``), a bounded result cache
    keyed by image content plus processing parameters, and the most recent
    full-analysis result (``last_results``) consulted by ``export_results``.
    """

    def __init__(self, cache_size=16):
        """Load the model and initialise the cache.

        Args:
            cache_size: maximum number of entries kept in ``self.cache``;
                least-recently-used entries are evicted first. ``None`` keeps
                every result (the previous, unbounded behaviour). Full-analysis
                entries hold several full-resolution images each, so keep
                this small.
        """
        self.model = None
        self.cache_size = cache_size
        self.load_model()
        self.last_results = None
        # Plain dict used as an LRU: insertion order == recency
        # (see _cache_get / _cache_put).
        self.cache = {}

    def load_model(self):
        """Build ``MMS`` and load the pretrained checkpoint onto ``device``.

        Loading is best-effort: a missing file or a state-dict mismatch is
        only printed, and the model is still moved to ``device`` and put in
        eval mode -- it then runs with whatever weights ``MMS()`` created.
        """
        self.model = MMS()
        try:
            self.model.load_state_dict(torch.load('models/CyueNet_EORSSD6.pth.54', map_location=device))
            print("模型加载成功")
        except RuntimeError as e:
            print(f"模型加载错误: {e}")
        except FileNotFoundError:
            print("未找到模型文件,请检查路径。")
        self.model.to(device)
        self.model.eval()

    # ------------------------------------------------------------------ cache

    @staticmethod
    def _image_key(image):
        """Return a cache-key fragment identifying *image* by bytes, shape and dtype.

        Hashing ``tobytes()`` alone lets arrays with identical bytes but a
        different shape (e.g. 2x3 vs 3x2) or dtype share a key.
        """
        return f"{hash(image.tobytes())}_{image.shape}_{image.dtype}"

    def _cache_get(self, key):
        """Return the cached value for *key* and mark it most recently used, or ``None``."""
        if key not in self.cache:
            return None
        value = self.cache.pop(key)
        self.cache[key] = value  # re-insert at the end == most recently used
        return value

    def _cache_put(self, key, value):
        """Store *value* under *key* and evict the oldest entries beyond ``cache_size``."""
        self.cache[key] = value
        limit = getattr(self, "cache_size", None)
        if limit is not None:
            while len(self.cache) > limit:
                self.cache.pop(next(iter(self.cache)))

    # -------------------------------------------------------- preprocessing

    def adjust_brightness_contrast(self, image, brightness=0, contrast=0):
        """Return *image* with brightness and contrast adjusted; 0 leaves it unchanged.

        Args:
            image: image array; presumably uint8 as delivered by the Gradio
                input (``cv2.addWeighted`` saturates to the input dtype).
            brightness: additive shift; positive raises the shadows, negative
                lowers the highlights. The UI slider spans [-100, 100].
            contrast: contrast amount; the formula has a pole at 131, so keep
                it below that. The UI slider spans [-100, 100].
        """
        if brightness != 0:
            if brightness > 0:
                shadow = brightness
                highlight = 255
            else:
                shadow = 0
                highlight = 255 + brightness
            alpha_b = (highlight - shadow) / 255
            gamma_b = shadow
            image = cv2.addWeighted(image, alpha_b, image, 0, gamma_b)
        if contrast != 0:
            f = 131 * (contrast + 127) / (127 * (131 - contrast))
            alpha_c = f
            gamma_c = 127 * (1 - f)
            image = cv2.addWeighted(image, alpha_c, image, 0, gamma_c)
        return image

    def apply_filters(self, image, filter_type):
        """Apply the named filter ("锐化", "模糊", "边缘增强"); other names return *image* unchanged."""
        if filter_type == "锐化":
            # 8-neighbour sharpening kernel.
            kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
            return cv2.filter2D(image, -1, kernel)
        elif filter_type == "模糊":
            return cv2.GaussianBlur(image, (5, 5), 0)
        elif filter_type == "边缘增强":
            # 4-neighbour sharpening kernel (milder than the one above).
            kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
            return cv2.filter2D(image, -1, kernel)
        return image

    @staticmethod
    def _ensure_rgb(image):
        """Return a 3-channel RGB copy of *image*; grayscale/RGBA inputs are converted via PIL."""
        if image.ndim == 3 and image.shape[2] == 3:
            return image.copy()
        return np.array(Image.fromarray(image).convert('RGB'))

    def _preprocess(self, image, enhance_contrast, denoise, brightness, contrast, filter_type):
        """Run the optional denoise -> brightness/contrast -> filter -> CLAHE chain.

        Returns a new RGB array; *image* itself is never modified.
        """
        processed = self._ensure_rgb(image)
        if denoise:
            processed = cv2.fastNlMeansDenoisingColored(processed, None, 10, 10, 7, 21)
        processed = self.adjust_brightness_contrast(processed, brightness, contrast)
        processed = self.apply_filters(processed, filter_type)
        if enhance_contrast:
            # CLAHE on the L (lightness) channel only, so colours are preserved.
            l, a, b = cv2.split(cv2.cvtColor(processed, cv2.COLOR_RGB2LAB))
            clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
            processed = cv2.cvtColor(cv2.merge((clahe.apply(l), a, b)), cv2.COLOR_LAB2RGB)
        return processed

    # ------------------------------------------------------------ inference

    @staticmethod
    def _saliency_to_map(res, h, w):
        """Turn raw model logits into a min-max normalised float map resized to (h, w)."""
        res = res.to(torch.float32).sigmoid().cpu().numpy().squeeze()
        # Rescale to [0, 1]; the epsilon avoids dividing by zero on a constant map.
        res = (res - res.min()) / (res.max() - res.min() + 1e-8)
        return cv2.resize(res, (w, h))  # cv2 takes (width, height)

    @staticmethod
    def _stop_clock(time_start):
        """Return seconds since *time_start*, after waiting for queued CUDA work."""
        if device.type == 'cuda':
            # CUDA kernels launch asynchronously; without a sync the timer
            # would stop before the GPU has finished.
            torch.cuda.synchronize()
        return time.time() - time_start

    def generate_analysis_plots(self, saliency_map):
        """Render a 2x2 statistics figure for a continuous (pre-threshold) saliency map.

        Panels: histogram with mean/median lines, empirical CDF, box plot, and
        the intensity profile along the horizontal centre line.

        Args:
            saliency_map: 2-D float array; ``process_image`` passes the
                normalised map with values in [0, 1].

        Returns:
            The rendered PNG decoded by PIL into a NumPy array.
        """
        saliency_values = saliency_map.flatten()
        mean_val = np.mean(saliency_values)
        median_val = np.median(saliency_values)
        sorted_vals = np.sort(saliency_values)
        cumulative = np.arange(1, len(sorted_vals) + 1) / len(sorted_vals)
        center_row = saliency_map[saliency_map.shape[0] // 2, :]

        buf = io.BytesIO()
        # Scope the style to this figure instead of mutating global rcParams.
        with plt.style.context('seaborn-v0_8'):
            fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
            try:
                ax1.hist(saliency_values, bins=50, color='#2196F3', alpha=0.7, edgecolor='black')
                ax1.set_title('Histogram of Saliency Distribution', fontsize=12, pad=15)
                ax1.set_xlabel('Saliency Value', fontsize=10)
                ax1.set_ylabel('Frequency', fontsize=10)
                ax1.grid(True, alpha=0.3)
                ax1.axvline(mean_val, color='red', linestyle='--', alpha=0.7, label=f'Mean: {mean_val:.3f}')
                ax1.axvline(median_val, color='green', linestyle='--', alpha=0.7, label=f'Median: {median_val:.3f}')
                ax1.legend()

                ax2.plot(sorted_vals, cumulative, color='#FF6B35', linewidth=2)
                ax2.set_title('Cumulative Distribution Function', fontsize=12)
                ax2.set_xlabel('Saliency Value', fontsize=10)
                ax2.set_ylabel('Cumulative Probability', fontsize=10)
                ax2.grid(True, alpha=0.3)

                ax3.boxplot(saliency_values, patch_artist=True,
                            boxprops=dict(facecolor='#21CBF3', alpha=0.7))
                ax3.set_title('Boxplot of Saliency Distribution', fontsize=12)
                ax3.set_ylabel('Saliency Value', fontsize=10)
                ax3.grid(True, alpha=0.3)

                ax4.plot(center_row, color='#9C27B0', linewidth=2)
                ax4.set_title('Intensity Profile along Center Line', fontsize=12)
                ax4.set_xlabel('Pixel Position', fontsize=10)
                ax4.set_ylabel('Saliency Value', fontsize=10)
                ax4.grid(True, alpha=0.3)

                fig.tight_layout()
                fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
            finally:
                # Close this figure explicitly; plt.close() closes whichever
                # figure is "current", which may belong to another request.
                plt.close(fig)
        buf.seek(0)
        return np.array(Image.open(buf))

    def quick_process(self, image, threshold=0.5, testsize=256):
        """Run the lightweight forward pass and return only the saliency map.

        Args:
            image: input image array, or ``None``.
            threshold: only part of the cache key; the quick path does not threshold.
            testsize: side length passed to ``transform_image`` for the model input.

        Returns:
            ``(saliency_uint8, message)`` with the map resized to the input's
            height/width, or ``(None, error_message)`` if *image* is ``None``.
        """
        if image is None:
            return None, "请提供有效的图像"
        cache_key = f"{self._image_key(image)}_{threshold}_{testsize}_quick"
        cached = self._cache_get(cache_key)
        if cached is not None:
            return cached

        image_pil = Image.fromarray(image).convert('RGB')
        image_tensor = transform_image(image_pil, testsize).unsqueeze(0).to(device)
        time_start = time.time()
        with torch.no_grad(), torch.amp.autocast(device_type=device.type):
            # Reduced forward pass; its second output is the saliency logits.
            _, res = self.model.forward_quick(image_tensor)
        elapsed = self._stop_clock(time_start)

        h, w = image.shape[:2]
        res_vis = (self._saliency_to_map(res, h, w) * 255).astype(np.uint8)
        result = (res_vis, f"快速处理完成,耗时 {elapsed:.3f}秒")
        self._cache_put(cache_key, result)
        return result

    def process_image(self, image, threshold=0.5, testsize=256,
                      enhance_contrast=False, denoise=False,
                      brightness=0, contrast=0, filter_type="无",
                      process_mode="完整分析"):
        """Preprocess, run the full model and build every visualisation.

        Args:
            image: input image array (grayscale, RGB or RGBA), or ``None``.
            threshold: binarisation threshold in [0, 1] for the saliency map.
            testsize: side length passed to ``transform_image``.
            enhance_contrast: apply CLAHE to the lightness channel first.
            denoise: apply OpenCV non-local-means colour denoising first.
            brightness, contrast: forwarded to ``adjust_brightness_contrast``.
            filter_type: filter name understood by ``apply_filters``.
            process_mode: "快速模式" delegates to ``quick_process``; anything
                else runs the full pipeline.

        Returns:
            A 9-tuple ``(original, saliency_uint8, heatmap_rgb, overlay_rgb,
            segmented_rgb, comparison_rgb, time_text, stats_dict,
            analysis_plot)``. Slots not produced (quick mode, missing input)
            are ``None``; any message goes in the ``time_text`` slot. Full
            results also update ``self.last_results``; quick mode does not.
        """
        if image is None:
            # Same 9-slot layout as every other return path.
            return (None, None, None, None, None, None, "请提供有效的图像", None, None)

        if process_mode == "快速模式":
            saliency_map, time_info = self.quick_process(image, threshold, testsize)
            return (image, saliency_map, None, None, None, None, time_info, None, None)

        cache_key = (f"{self._image_key(image)}_{threshold}_{testsize}_{enhance_contrast}_"
                     f"{denoise}_{brightness}_{contrast}_{filter_type}_full")
        cached = self._cache_get(cache_key)
        if cached is not None:
            return cached

        original_image = self._preprocess(image, enhance_contrast, denoise,
                                          brightness, contrast, filter_type)

        # Model inference
        image_pil = Image.fromarray(original_image).convert('RGB')
        image_tensor = transform_image(image_pil, testsize).unsqueeze(0).to(device)
        time_start = time.time()
        with torch.no_grad(), torch.amp.autocast(device_type=device.type):
            # The model returns many side outputs; only the second one (the
            # final saliency logits) is used here.
            _, res, *_ = self.model(image_tensor)
        inference_time = self._stop_clock(time_start)

        h, w = original_image.shape[:2]
        res_resized = self._saliency_to_map(res, h, w)

        # Visualisations (OpenCV colour maps are BGR, so work in BGR and convert back).
        res_vis = (res_resized * 255).astype(np.uint8)
        heatmap = cv2.applyColorMap(res_vis, cv2.COLORMAP_JET)
        _, binary_mask = cv2.threshold(res_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
        alpha = 0.5  # heatmap weight in the overlay
        original_bgr = cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR)
        overlayed = cv2.addWeighted(original_bgr, 1 - alpha, heatmap, alpha, 0)
        segmented = cv2.bitwise_and(original_bgr, original_bgr, mask=binary_mask)
        overlayed_rgb = cv2.cvtColor(overlayed, cv2.COLOR_BGR2RGB)
        segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB)
        heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Plots use the continuous map, not the binarised mask.
        analysis_plot = self.generate_analysis_plots(res_resized)

        # [-2] selects the contour list on both OpenCV 3.x (3 return values) and 4.x (2).
        contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        coverage_ratio = cv2.countNonZero(binary_mask) / (w * h)
        stats = {
            "处理分辨率": f"{w}x{h}",
            "检测到对象数": str(len(contours)),
            "平均置信度": f"{np.mean(res_resized):.2%}",
            "最大置信度": f"{np.max(res_resized):.2%}",
            "覆盖率": f"{coverage_ratio:.2%}",
            "处理时间": f"{inference_time:.3f}秒"
        }

        comparison_img = self.create_comparison_image(original_image, overlayed_rgb)

        self.last_results = {
            'saliency_map': res_resized,
            'binary_mask': binary_mask,
            'stats': stats
        }
        result = (original_image, res_vis, heatmap_rgb, overlayed_rgb, segmented_rgb,
                  comparison_img, f"处理时间: {inference_time:.4f}秒", stats, analysis_plot)
        self._cache_put(cache_key, result)
        return result

    def create_comparison_image(self, original, processed):
        """Place *original* and *processed* (same height/width, 3 channels) side by side.

        A 2-px white vertical line marks the boundary between the two halves.
        """
        h, w = original.shape[:2]
        comparison = np.zeros((h, w * 2, 3), dtype=np.uint8)
        comparison[:, :w] = original
        comparison[:, w:] = processed
        cv2.line(comparison, (w, 0), (w, h), (255, 255, 255), 2)
        return comparison

    def export_results(self, format_type="PNG"):
        """Return a status message about exporting the last full-analysis result.

        NOTE(review): this is a stub -- no file is written in any branch, even
        though the PDF message names a file. Implement real saving before
        relying on these messages.
        """
        if self.last_results is None:
            return "没有结果可供导出"
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        if format_type == "PDF报告":
            return f"PDF报告已保存为 saliency_report_{timestamp}.pdf"
        return f"结果已导出为 {format_type.lower()} 文件"
# Module-level processor shared by every Gradio callback below. Constructing it
# calls load_model(), so the model weights are loaded once, at import time.
processor = ImageProcessor()
def run_demo(input_image, threshold, enhance_contrast, denoise, show_contours,
             brightness, contrast, filter_type, process_mode):
    """Gradio callback for the detection button.

    Args:
        input_image: uploaded image array, or ``None``.
        threshold: detection threshold on the UI's 0-100 scale.
        enhance_contrast, denoise: toggle the matching preprocessing steps.
        show_contours: outline the thresholded saliency regions on the overlay.
        brightness, contrast: slider values forwarded to the processor.
        filter_type: filter name forwarded to the processor.
        process_mode: "完整分析" (full) or "快速模式" (quick).

    Returns:
        A 9-tuple matching the wired outputs: original, saliency map, heatmap,
        overlay, segmentation, side-by-side comparison, time/info text,
        statistics HTML and analysis plot.
    """
    if input_image is None:
        # Must match the 9 outputs; the message goes into the time/info textbox.
        return (None, None, None, None, None, None, "请上传图像", None, None)

    threshold_ratio = threshold / 100.0
    results = processor.process_image(
        input_image,
        threshold=threshold_ratio,
        enhance_contrast=enhance_contrast,
        denoise=denoise,
        brightness=brightness,
        contrast=contrast,
        filter_type=filter_type,
        process_mode=process_mode
    )
    original, saliency_map, heatmap, overlayed, segmented, comparison, time_info, stats, analysis_plot = results

    # Optionally draw contours. The overlay only exists in full-analysis mode.
    if show_contours and saliency_map is not None and overlayed is not None:
        # Use the same cut-off as the processor's segmentation mask so the
        # outline matches the segmented region (previously hard-coded to 127).
        _, binary = cv2.threshold(saliency_map, int(255 * threshold_ratio), 255, cv2.THRESH_BINARY)
        # [-2] selects the contour list on both OpenCV 3.x and 4.x.
        contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        overlayed = overlayed.copy()  # never draw on the processor's cached array
        cv2.drawContours(overlayed, contours, -1, (0, 255, 0), 2)

    # Statistics as HTML cards (styled by the .statistic-card CSS class).
    if stats:
        cards = "".join(
            f"<div class='statistic-card'><h4>{key}</h4><p>{value}</p></div>"
            for key, value in stats.items()
        )
        stats_html = f"<div class='statistics-container'>{cards}</div>"
    else:
        stats_html = "<p>无可用统计信息</p>"

    return (original, saliency_map, heatmap, overlayed, segmented,
            comparison, time_info, stats_html, analysis_plot)
def create_comparison_view(original, result, slider_value):
    """Build a before/after view split vertically at a slider position.

    Columns left of ``int(width * slider_value)`` come from *original* and
    the rest from *result* (which must have the same height/width). A 3-px
    (255, 255, 0) vertical line marks the split. Returns ``None`` when
    either image is missing.
    """
    if result is None or original is None:
        return None

    height, width = original.shape[:2]
    split_x = int(width * slider_value)

    view = original.copy()
    view[:, split_x:] = result[:, split_x:]
    cv2.line(view, (split_x, 0), (split_x, height), (255, 255, 0), thickness=3)
    return view
# Build the Gradio interface. Components render in creation order, so the
# statement order inside this block defines the page layout.
with gr.Blocks(title="高级显著性对象检测系统", css=custom_css) as demo:
    gr.Markdown(
        """
# 🎯 高级显著性对象检测系统
### AI驱动的图像显著性检测与分析工具
"""
    )
    with gr.Tabs() as tabs:
        with gr.TabItem("🔍 主功能"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Input controls
                    with gr.Group(elem_classes="advanced-controls"):
                        input_image = gr.Image(
                            label="输入图像",
                            type="numpy",
                            elem_classes="input-image"
                        )
                        # Processing-mode selection
                        process_mode = gr.Radio(
                            choices=["完整分析", "快速模式"],
                            value="完整分析",
                            label="处理模式",
                            info="快速模式仅输出显著性图,处理速度更快"
                        )
                        with gr.Accordion("基本设置", open=True):
                            # 0-100 scale; run_demo divides by 100 before processing.
                            threshold_slider = gr.Slider(
                                minimum=0,
                                maximum=100,
                                value=50,
                                step=1,
                                label="检测阈值",
                                info="调整检测灵敏度"
                            )
                            enhance_contrast = gr.Checkbox(
                                label="增强对比度",
                                value=False
                            )
                            denoise = gr.Checkbox(
                                label="降噪",
                                value=False
                            )
                            show_contours = gr.Checkbox(
                                label="显示轮廓",
                                value=True
                            )
                        with gr.Accordion("图像调整", open=False):
                            brightness = gr.Slider(
                                minimum=-100,
                                maximum=100,
                                value=0,
                                step=1,
                                label="亮度"
                            )
                            contrast = gr.Slider(
                                minimum=-100,
                                maximum=100,
                                value=0,
                                step=1,
                                label="对比度"
                            )
                            filter_type = gr.Radio(
                                choices=["无", "锐化", "模糊", "边缘增强"],
                                value="无",
                                label="图像滤镜"
                            )
                        with gr.Accordion("导出选项", open=False):
                            export_format = gr.Dropdown(
                                choices=["PNG", "JPEG", "PDF报告"],
                                value="PNG",
                                label="导出格式"
                            )
                            export_btn = gr.Button(
                                "导出结果",
                                elem_classes="custom-button"
                            )
                            # Status box lives next to the export button instead of
                            # being created ad hoc (and rendered detached) in .click().
                            export_status = gr.Textbox(
                                label="导出状态",
                                interactive=False
                            )
                    with gr.Row():
                        submit_btn = gr.Button(
                            "开始检测",
                            variant="primary",
                            elem_classes="custom-button"
                        )
                        reset_btn = gr.Button(
                            "重置参数",
                            elem_classes="custom-button"
                        )
                with gr.Column(scale=2):
                    # Result display
                    with gr.Tabs():
                        with gr.TabItem("检测结果"):
                            with gr.Row(elem_classes="result-container"):
                                original_output = gr.Image(
                                    label="原始图像",
                                    elem_classes="output-image"
                                )
                                saliency_output = gr.Image(
                                    label="显著性图",
                                    elem_classes="output-image"
                                )
                            with gr.Row(elem_classes="result-container"):
                                heatmap_output = gr.Image(
                                    label="热力图分析",
                                    elem_classes="output-image"
                                )
                                overlayed_output = gr.Image(
                                    label="叠加效果",
                                    elem_classes="output-image"
                                )
                            with gr.Row(elem_classes="result-container"):
                                segmented_output = gr.Image(
                                    label="对象分割",
                                    elem_classes="output-image"
                                )
                                comparison_output = gr.Image(
                                    label="并排对比",
                                    elem_classes="output-image"
                                )
                        with gr.TabItem("交互式对比"):
                            with gr.Group(elem_classes="comparison-slider"):
                                comparison_slider = gr.Slider(
                                    minimum=0,
                                    maximum=1,
                                    value=0.5,
                                    step=0.01,
                                    label="原始 ← → 结果",
                                    info="拖动滑块对比原始图像和处理结果"
                                )
                                interactive_comparison = gr.Image(
                                    label="交互式对比视图",
                                    elem_classes="interactive-viz"
                                )
                        with gr.TabItem("分析报告"):
                            with gr.Group(elem_classes="result-container"):
                                time_info = gr.Textbox(
                                    label="处理时间",
                                    show_label=True
                                )
                                stats_output = gr.HTML(
                                    label="统计信息"
                                )
                                analysis_plot = gr.Image(
                                    label="详细分析图表",
                                    elem_classes="output-image"
                                )
        with gr.TabItem("📖 用户指南"):
            gr.Markdown(
                """
## 使用说明
1. **上传图像**:点击"输入图像"区域上传您的图像
2. **选择模式**:选择"完整分析"或"快速模式"
- 完整分析:完整处理流程,包含所有可视化结果
- 快速模式:快速处理,仅输出显著性图
3. **调整参数**:
- 使用阈值滑块调整检测灵敏度
- 根据需要启用对比度增强或降噪
- 在高级设置中微调亮度、对比度和滤镜
4. **开始检测**:点击"开始检测"按钮开始分析
5. **查看结果**:在不同标签页查看各种可视化结果
6. **导出**:使用导出选项保存您的结果
## 功能特点
- **显著性图**:显示图像区域的显著性分布
- **热力图**:彩色编码的强度可视化
- **叠加效果**:在原始图像上叠加检测结果
- **对象分割**:提取关键对象区域
- **交互式对比**:滑动比较原始图像和处理结果
- **分析报告**:详细的统计信息和分析图表
## 性能提示
- 当只需要显著性图时使用快速模式
- 分辨率较低的图像处理速度更快
- 启用GPU可获得更好的性能
"""
            )
        with gr.TabItem("ℹ️ 关于"):
            gr.Markdown(
                """
## 项目信息
- **版本**:3.0.0
- **架构**:PyTorch + Gradio
- **模型**:CyueNet
- **语言**:多语言支持
## 主要特点
- 实时图像处理和分析
- 多维结果可视化
- 丰富的图像调整选项
- 详细的数据分析报告
- 交互式对比工具
- 导出功能
- 缓存优化性能
## 更新日志
- ✅ 新增快速模式,提高处理速度
- ✅ 增强图像预处理选项
- ✅ 新增统计分析功能
- ✅ 改进用户界面,采用玻璃拟态设计
- ✅ 增加交互式对比滑块
- ✅ 使用缓存和线程优化性能
- ✅ 多语言图表支持
- ✅ 导出功能
## 系统要求
- Python 3.8+
- PyTorch 1.9+
- CUDA(可选,用于GPU加速)
- 推荐4GB以上内存
"""
            )

    # Event handlers
    def reset_params():
        """Restore every control to its initial value (dict keyed by component)."""
        return {
            threshold_slider: 50,
            brightness: 0,
            contrast: 0,
            filter_type: "无",
            enhance_contrast: False,
            denoise: False,
            show_contours: True,
            process_mode: "完整分析"
        }

    # Run detection, then refresh the interactive comparison so it reflects the
    # new result immediately instead of only after the slider is moved.
    submit_btn.click(
        fn=run_demo,
        inputs=[
            input_image,
            threshold_slider,
            enhance_contrast,
            denoise,
            show_contours,
            brightness,
            contrast,
            filter_type,
            process_mode
        ],
        outputs=[
            original_output,
            saliency_output,
            heatmap_output,
            overlayed_output,
            segmented_output,
            comparison_output,
            time_info,
            stats_output,
            analysis_plot
        ]
    ).then(
        fn=create_comparison_view,
        inputs=[original_output, overlayed_output, comparison_slider],
        outputs=[interactive_comparison]
    )
    reset_btn.click(
        fn=reset_params,
        inputs=[],
        outputs=[
            threshold_slider,
            brightness,
            contrast,
            filter_type,
            enhance_contrast,
            denoise,
            show_contours,
            process_mode
        ]
    )
    # Interactive comparison: re-render whenever the slider moves.
    comparison_slider.change(
        fn=create_comparison_view,
        inputs=[original_output, overlayed_output, comparison_slider],
        outputs=[interactive_comparison]
    )
    # Export
    export_btn.click(
        fn=processor.export_results,
        inputs=[export_format],
        outputs=[export_status]
    )
# Start the web server only when executed as a script, not on import.
# share=True asks Gradio for a public share link in addition to the local
# server on 0.0.0.0:7860 -- review this before running on a private machine.
if __name__ == "__main__":
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "show_error": True,
    }
    demo.launch(**launch_kwargs)