Zilong-Zhang003 committed on
Commit
7318bea
·
1 Parent(s): c8a655a
Files changed (2) hide show
  1. app.py +23 -21
  2. dino_feature_extractor.py +3 -2
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import os
3
  import io
4
  import cv2
@@ -11,6 +10,7 @@ from functools import lru_cache
11
  from huggingface_hub import hf_hub_download, snapshot_download
12
  from torchvision.transforms.functional import normalize
13
  import glob
 
14
 
15
 
16
  from restormerRFR_arch import RestormerRFR
@@ -83,39 +83,41 @@ def get_model_and_device():
83
  return model, device
84
 
85
 
86
- @spaces.GPU(duration=120)
87
  def restore_image(pil_img: Image.Image) -> Image.Image:
88
  """
89
  输入一张图片,输出复原后的图片(与 RAM++ RestormerRFR + DINO 特征推理一致)
90
  """
91
- model, device = get_model_and_device()
92
- dino_extractor = get_dino_extractor(device)
93
-
94
 
95
- img_bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR).astype(np.float32) / 255.0
96
- img = torch.from_numpy(np.transpose(img_bgr[:, :, [2, 1, 0]], (2, 0, 1))).float() # (3,H,W), RGB
97
- img = img.unsqueeze(0).to(device) # (1,3,H,W)
98
 
 
 
 
99
 
100
- mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
101
- std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
102
- normalize(img, mean, std, inplace=True)
103
 
104
- with torch.no_grad():
105
- dino_features = dino_extractor(img)
106
- output = model(img, dino_features)
107
 
 
 
 
108
 
109
- output = normalize(output, -1 * mean / std, 1 / std)
110
- output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() # (3,H,W)
111
- output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # (H,W,RGB)
112
- output = (output * 255.0).round().astype(np.uint8)
113
- out_pil = Image.fromarray(output, mode="RGB")
114
- return out_pil
115
 
 
 
 
 
 
 
 
 
116
 
117
  DESCRIPTION = """
118
- # RAM++ Demo
119
  """
120
 
121
  with gr.Blocks(title="RAM++ ZeroGPU Demo") as demo:
 
 
1
  import os
2
  import io
3
  import cv2
 
10
  from huggingface_hub import hf_hub_download, snapshot_download
11
  from torchvision.transforms.functional import normalize
12
  import glob
13
+ import traceback
14
 
15
 
16
  from restormerRFR_arch import RestormerRFR
 
83
  return model, device
84
 
85
 
86
+ @spaces.GPU(duration=240)
87
  def restore_image(pil_img: Image.Image) -> Image.Image:
88
  """
89
  输入一张图片,输出复原后的图片(与 RAM++ RestormerRFR + DINO 特征推理一致)
90
  """
91
+ try:
92
+ model, device = get_model_and_device()
93
+ dino_extractor = get_dino_extractor(device)
94
 
 
 
 
95
 
96
+ img_bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR).astype(np.float32) / 255.0
97
+ img = torch.from_numpy(np.transpose(img_bgr[:, :, [2, 1, 0]], (2, 0, 1))).float() # (3,H,W), RGB
98
+ img = img.unsqueeze(0).to(device) # (1,3,H,W)
99
 
 
 
 
100
 
101
+ mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
102
+ std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
103
+ normalize(img, mean, std, inplace=True)
104
 
105
+ with torch.no_grad():
106
+ dino_features = dino_extractor(img)
107
+ output = model(img, dino_features)
108
 
 
 
 
 
 
 
109
 
110
+ output = normalize(output, -1 * mean / std, 1 / std)
111
+ output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() # (3,H,W)
112
+ output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # (H,W,RGB)
113
+ output = (output * 255.0).round().astype(np.uint8)
114
+ out_pil = Image.fromarray(output, mode="RGB")
115
+ return out_pil
116
+ except Exception as e:
117
+ raise gr.Error(f"{e}\n{traceback.format_exc()}")
118
 
119
  DESCRIPTION = """
120
+ # RAM++: Robust Representation Learning via Adaptive Mask for All-in-One Image Restoration
121
  """
122
 
123
  with gr.Blocks(title="RAM++ ZeroGPU Demo") as demo:
dino_feature_extractor.py CHANGED
@@ -10,8 +10,9 @@ class DinoFeatureModule(nn.Module):
10
  def __init__(self, model_id: str = "facebook/dinov2-giant"):
11
  super(DinoFeatureModule, self).__init__()
12
  dtype = torch.float32
 
13
  self.dino = AutoModel.from_pretrained(
14
- model_id,
15
  torch_dtype=dtype
16
  )
17
 
@@ -110,7 +111,7 @@ class DinoFeatureModule(nn.Module):
110
 
111
  shortest_edge = min(target_h, target_w)
112
  processor = AutoImageProcessor.from_pretrained(
113
- model_id,
114
  local_files_only=False,
115
  do_rescale=False,
116
  do_center_crop=False,
 
10
  def __init__(self, model_id: str = "facebook/dinov2-giant"):
11
  super(DinoFeatureModule, self).__init__()
12
  dtype = torch.float32
13
+ self.model_id = model_id
14
  self.dino = AutoModel.from_pretrained(
15
+ self.model_id,
16
  torch_dtype=dtype
17
  )
18
 
 
111
 
112
  shortest_edge = min(target_h, target_w)
113
  processor = AutoImageProcessor.from_pretrained(
114
+ self.model_id,
115
  local_files_only=False,
116
  do_rescale=False,
117
  do_center_crop=False,