kunkk commited on
Commit
48275ef
·
verified ·
1 Parent(s): dad1346

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -10
app.py CHANGED
@@ -27,18 +27,21 @@ def load_model():
27
 
28
  def process_image(image, model, testsize=256):
29
  """处理图像并返回显著性检测结果"""
 
 
 
30
  # 预处理图像
31
- image = Image.fromarray(image).convert('RGB')
32
- image = transform_image(image, testsize)
33
- image = image.unsqueeze(0)
34
- image = image.to(device)
35
 
36
  # 计时
37
  time_start = time.time()
38
 
39
  # 推理
40
  with torch.no_grad():
41
- x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = model(image)
42
 
43
  time_end = time.time()
44
  inference_time = time_end - time_start
@@ -48,7 +51,6 @@ def process_image(image, model, testsize=256):
48
  res = (res - res.min()) / (res.max() - res.min() + 1e-8)
49
 
50
  # 将输出调整为原始图像大小
51
- original_image = np.array(Image.fromarray(image.cpu().squeeze().permute(1, 2, 0).numpy()))
52
  h, w = original_image.shape[:2]
53
  res_resized = cv2.resize(res, (w, h))
54
 
@@ -60,13 +62,23 @@ def process_image(image, model, testsize=256):
60
 
61
  # 将热力图与原始图像混合
62
  alpha = 0.5
63
- overlayed = cv2.addWeighted(original_image, 1-alpha, heatmap, alpha, 0)
 
 
 
 
 
 
64
 
65
  # 二值化结果用于分割
66
  _, binary_mask = cv2.threshold(res_vis, 127, 255, cv2.THRESH_BINARY)
67
- segmented = cv2.bitwise_and(original_image, original_image, mask=binary_mask)
 
 
 
 
68
 
69
- return original_image, res_vis, heatmap, overlayed, segmented, f"推理时间: {inference_time:.4f}秒"
70
 
71
  def run_demo(input_image):
72
  """Gradio界面的主函数"""
@@ -112,4 +124,5 @@ with gr.Blocks(title="显著性目标检测Demo") as demo:
112
  gr.Markdown("3. 系统将显示原始图像、显著性图、热力图、叠加结果和分割结果")
113
 
114
  # 启动Gradio应用
115
- demo.launch()
 
 
27
 
28
  def process_image(image, model, testsize=256):
29
  """处理图像并返回显著性检测结果"""
30
+ # 保存原始图像用于后续处理
31
+ original_image = image.copy()
32
+
33
  # 预处理图像
34
+ image_pil = Image.fromarray(image).convert('RGB')
35
+ image_tensor = transform_image(image_pil, testsize)
36
+ image_tensor = image_tensor.unsqueeze(0)
37
+ image_tensor = image_tensor.to(device)
38
 
39
  # 计时
40
  time_start = time.time()
41
 
42
  # 推理
43
  with torch.no_grad():
44
+ x1, res, s1_sig, edg1, edg_s, s2, e2, s2_sig, e2_sig, s3, e3, s3_sig, e3_sig, s4, e4, s4_sig, e4_sig, s5, e5, s5_sig, e5_sig, sk1, sk1_sig, sk2, sk2_sig, sk3, sk3_sig, sk4, sk4_sig, sk5, sk5_sig = model(image_tensor)
45
 
46
  time_end = time.time()
47
  inference_time = time_end - time_start
 
51
  res = (res - res.min()) / (res.max() - res.min() + 1e-8)
52
 
53
  # 将输出调整为原始图像大小
 
54
  h, w = original_image.shape[:2]
55
  res_resized = cv2.resize(res, (w, h))
56
 
 
62
 
63
  # 将热力图与原始图像混合
64
  alpha = 0.5
65
+ # 确保原始图像是BGR格式用于OpenCV操作
66
+ if len(original_image.shape) == 3 and original_image.shape[2] == 3:
67
+ original_bgr = cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR)
68
+ else:
69
+ original_bgr = cv2.cvtColor(original_image, cv2.COLOR_GRAY2BGR)
70
+
71
+ overlayed = cv2.addWeighted(original_bgr, 1-alpha, heatmap, alpha, 0)
72
 
73
  # 二值化结果用于分割
74
  _, binary_mask = cv2.threshold(res_vis, 127, 255, cv2.THRESH_BINARY)
75
+ segmented = cv2.bitwise_and(original_bgr, original_bgr, mask=binary_mask)
76
+
77
+ # 转回RGB格式用于显示
78
+ overlayed_rgb = cv2.cvtColor(overlayed, cv2.COLOR_BGR2RGB)
79
+ segmented_rgb = cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB)
80
 
81
+ return original_image, res_vis, heatmap, overlayed_rgb, segmented_rgb, f"推理时间: {inference_time:.4f}秒"
82
 
83
  def run_demo(input_image):
84
  """Gradio界面的主函数"""
 
124
  gr.Markdown("3. 系统将显示原始图像、显著性图、热力图、叠加结果和分割结果")
125
 
126
  # 启动Gradio应用
127
+ if __name__ == "__main__":
128
+ demo.launch(share=True)