Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	
		Zilong-Zhang003
		
	commited on
		
		
					Commit 
							
							·
						
						7318bea
	
1
								Parent(s):
							
							c8a655a
								
NameError
Browse files- app.py +23 -21
- dino_feature_extractor.py +3 -2
    	
        app.py
    CHANGED
    
    | @@ -1,4 +1,3 @@ | |
| 1 | 
            -
             | 
| 2 | 
             
            import os
         | 
| 3 | 
             
            import io
         | 
| 4 | 
             
            import cv2
         | 
| @@ -11,6 +10,7 @@ from functools import lru_cache | |
| 11 | 
             
            from huggingface_hub import hf_hub_download, snapshot_download
         | 
| 12 | 
             
            from torchvision.transforms.functional import normalize
         | 
| 13 | 
             
            import glob
         | 
|  | |
| 14 |  | 
| 15 |  | 
| 16 | 
             
            from restormerRFR_arch import RestormerRFR
         | 
| @@ -83,39 +83,41 @@ def get_model_and_device(): | |
| 83 | 
             
                return model, device
         | 
| 84 |  | 
| 85 |  | 
| 86 | 
            -
            @spaces.GPU(duration= | 
| 87 | 
             
            def restore_image(pil_img: Image.Image) -> Image.Image:
         | 
| 88 | 
             
                """
         | 
| 89 | 
             
                输入一张图片,输出复原后的图片(与 RAM++ RestormerRFR + DINO 特征推理一致)
         | 
| 90 | 
             
                """
         | 
| 91 | 
            -
                 | 
| 92 | 
            -
             | 
| 93 | 
            -
             | 
| 94 |  | 
| 95 | 
            -
                img_bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR).astype(np.float32) / 255.0
         | 
| 96 | 
            -
                img = torch.from_numpy(np.transpose(img_bgr[:, :, [2, 1, 0]], (2, 0, 1))).float()  # (3,H,W), RGB
         | 
| 97 | 
            -
                img = img.unsqueeze(0).to(device)  # (1,3,H,W)
         | 
| 98 |  | 
|  | |
|  | |
|  | |
| 99 |  | 
| 100 | 
            -
                mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
         | 
| 101 | 
            -
                std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
         | 
| 102 | 
            -
                normalize(img, mean, std, inplace=True)
         | 
| 103 |  | 
| 104 | 
            -
             | 
| 105 | 
            -
                     | 
| 106 | 
            -
                     | 
| 107 |  | 
|  | |
|  | |
|  | |
| 108 |  | 
| 109 | 
            -
                output = normalize(output, -1 * mean / std, 1 / std)
         | 
| 110 | 
            -
                output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()  # (3,H,W)
         | 
| 111 | 
            -
                output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # (H,W,RGB)
         | 
| 112 | 
            -
                output = (output * 255.0).round().astype(np.uint8)
         | 
| 113 | 
            -
                out_pil = Image.fromarray(output, mode="RGB")
         | 
| 114 | 
            -
                return out_pil
         | 
| 115 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 116 |  | 
| 117 | 
             
            DESCRIPTION = """
         | 
| 118 | 
            -
            # RAM | 
| 119 | 
             
            """
         | 
| 120 |  | 
| 121 | 
             
            with gr.Blocks(title="RAM++ ZeroGPU Demo") as demo:
         | 
|  | |
|  | |
| 1 | 
             
            import os
         | 
| 2 | 
             
            import io
         | 
| 3 | 
             
            import cv2
         | 
|  | |
| 10 | 
             
            from huggingface_hub import hf_hub_download, snapshot_download
         | 
| 11 | 
             
            from torchvision.transforms.functional import normalize
         | 
| 12 | 
             
            import glob
         | 
| 13 | 
            +
            import traceback
         | 
| 14 |  | 
| 15 |  | 
| 16 | 
             
            from restormerRFR_arch import RestormerRFR
         | 
|  | |
| 83 | 
             
                return model, device
         | 
| 84 |  | 
| 85 |  | 
| 86 | 
            +
            @spaces.GPU(duration=240)
         | 
| 87 | 
             
            def restore_image(pil_img: Image.Image) -> Image.Image:
         | 
| 88 | 
             
                """
         | 
| 89 | 
             
                输入一张图片,输出复原后的图片(与 RAM++ RestormerRFR + DINO 特征推理一致)
         | 
| 90 | 
             
                """
         | 
| 91 | 
            +
                try:
         | 
| 92 | 
            +
                    model, device = get_model_and_device()
         | 
| 93 | 
            +
                    dino_extractor = get_dino_extractor(device)
         | 
| 94 |  | 
|  | |
|  | |
|  | |
| 95 |  | 
| 96 | 
            +
                    img_bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR).astype(np.float32) / 255.0
         | 
| 97 | 
            +
                    img = torch.from_numpy(np.transpose(img_bgr[:, :, [2, 1, 0]], (2, 0, 1))).float()  # (3,H,W), RGB
         | 
| 98 | 
            +
                    img = img.unsqueeze(0).to(device)  # (1,3,H,W)
         | 
| 99 |  | 
|  | |
|  | |
|  | |
| 100 |  | 
| 101 | 
            +
                    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
         | 
| 102 | 
            +
                    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
         | 
| 103 | 
            +
                    normalize(img, mean, std, inplace=True)
         | 
| 104 |  | 
| 105 | 
            +
                    with torch.no_grad():
         | 
| 106 | 
            +
                        dino_features = dino_extractor(img)
         | 
| 107 | 
            +
                        output = model(img, dino_features)
         | 
| 108 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 109 |  | 
| 110 | 
            +
                    output = normalize(output, -1 * mean / std, 1 / std)
         | 
| 111 | 
            +
                    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()  # (3,H,W)
         | 
| 112 | 
            +
                    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # (H,W,RGB)
         | 
| 113 | 
            +
                    output = (output * 255.0).round().astype(np.uint8)
         | 
| 114 | 
            +
                    out_pil = Image.fromarray(output, mode="RGB")
         | 
| 115 | 
            +
                    return out_pil
         | 
| 116 | 
            +
                except Exception as e:
         | 
| 117 | 
            +
                    raise gr.Error(f"{e}\n{traceback.format_exc()}")
         | 
| 118 |  | 
| 119 | 
             
            DESCRIPTION = """
         | 
| 120 | 
            +
            # RAM++: Robust Representation Learning via  Adaptive Mask for All-in-One Image Restoration
         | 
| 121 | 
             
            """
         | 
| 122 |  | 
| 123 | 
             
            with gr.Blocks(title="RAM++ ZeroGPU Demo") as demo:
         | 
    	
        dino_feature_extractor.py
    CHANGED
    
    | @@ -10,8 +10,9 @@ class DinoFeatureModule(nn.Module): | |
| 10 | 
             
                def __init__(self, model_id: str = "facebook/dinov2-giant"):
         | 
| 11 | 
             
                    super(DinoFeatureModule, self).__init__()
         | 
| 12 | 
             
                    dtype = torch.float32
         | 
|  | |
| 13 | 
             
                    self.dino = AutoModel.from_pretrained(
         | 
| 14 | 
            -
                        model_id,
         | 
| 15 | 
             
                        torch_dtype=dtype
         | 
| 16 | 
             
                    )
         | 
| 17 |  | 
| @@ -110,7 +111,7 @@ class DinoFeatureModule(nn.Module): | |
| 110 |  | 
| 111 | 
             
                    shortest_edge = min(target_h, target_w)
         | 
| 112 | 
             
                    processor = AutoImageProcessor.from_pretrained(
         | 
| 113 | 
            -
                        model_id,
         | 
| 114 | 
             
                        local_files_only=False,
         | 
| 115 | 
             
                        do_rescale=False,
         | 
| 116 | 
             
                        do_center_crop=False,  
         | 
|  | |
| 10 | 
             
                def __init__(self, model_id: str = "facebook/dinov2-giant"):
         | 
| 11 | 
             
                    super(DinoFeatureModule, self).__init__()
         | 
| 12 | 
             
                    dtype = torch.float32
         | 
| 13 | 
            +
                    self.model_id = model_id
         | 
| 14 | 
             
                    self.dino = AutoModel.from_pretrained(
         | 
| 15 | 
            +
                        self.model_id,
         | 
| 16 | 
             
                        torch_dtype=dtype
         | 
| 17 | 
             
                    )
         | 
| 18 |  | 
|  | |
| 111 |  | 
| 112 | 
             
                    shortest_edge = min(target_h, target_w)
         | 
| 113 | 
             
                    processor = AutoImageProcessor.from_pretrained(
         | 
| 114 | 
            +
                        self.model_id,
         | 
| 115 | 
             
                        local_files_only=False,
         | 
| 116 | 
             
                        do_rescale=False,
         | 
| 117 | 
             
                        do_center_crop=False,  
         | 
