Update tasks/image.py

tasks/image.py  +22 -4
@@ -13,7 +13,6 @@ from PIL import Image
 from transformers import MobileViTImageProcessor, MobileViTForSemanticSegmentation
 import cv2
 from tqdm import tqdm
-from dataset import WildfireSmokeDataset
 from torch.utils.data import DataLoader
 
 from dotenv import load_dotenv
@@ -30,6 +29,19 @@ model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobile
 model.load_state_dict(torch.load(model_path))
 model.eval()
 
+def preprocess(image):
+    image = image.resize((512,512))
+
+    # Convert to BGR
+    image = np.array(image)[:, :, ::-1]  # Convert RGB to BGR
+    image = Image.fromarray(image)
+    image = image.resize(self.image_size)
+
+    # Normalize pixel values to [0, 1]
+    image = np.array(image, dtype=np.float32) / 255.0
+
+    return image
+
 def get_bounding_boxes_from_mask(mask):
     """Extract bounding boxes from a binary mask."""
     pred_boxes = []
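Note: the new preprocess helper looks like it was lifted from a class method. It references self.image_size, which is undefined in a module-level function, and it resizes twice (first to (512, 512), then again to self.image_size). A minimal standalone sketch under that assumption, with the redundant second resize dropped, might look like:

    import numpy as np
    from PIL import Image

    def preprocess(image):
        # Resize to the model's expected input size
        image = image.resize((512, 512))
        # Convert RGB (PIL) channel order to BGR
        bgr = np.array(image)[:, :, ::-1]
        # Normalize pixel values to [0, 1]
        return bgr.astype(np.float32) / 255.0

This is a sketch, not the committed code; as written in the diff, the function will raise NameError on self.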
@@ -39,7 +51,7 @@ def get_bounding_boxes_from_mask(mask):
         x, y, w, h = cv2.boundingRect(contour)
         pred_boxes.append((x, y, x + w, y + h))
     return pred_boxes
-
+
 def parse_boxes(annotation_string):
     """Parse multiple boxes from a single annotation string.
     Each box has 5 values: class_id, x_center, y_center, width, height"""
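For context, this hunk only shows the loop body of get_bounding_boxes_from_mask. A conventional OpenCV 4.x implementation consistent with these lines (the cv2.findContours call here is an assumption, not taken from the file) would be:

    import cv2

    def get_bounding_boxes_from_mask(mask):
        """Extract (x1, y1, x2, y2) boxes from a binary uint8 mask."""
        pred_boxes = []
        # RETR_EXTERNAL keeps only outermost contours;
        # CHAIN_APPROX_SIMPLE compresses straight segments to endpoints
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            pred_boxes.append((x, y, x + w, y + h))
        return pred_boxes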
@@ -130,6 +142,10 @@ async def evaluate_image(request: ImageEvaluationRequest):
     for example in test_dataset:
         # Extract image and annotations
         image = example["image"]
+
+        original_shape = (len(image), len(image[0]))
+        image = preprocess(image)
+
         annotation = example.get("annotations", "").strip()
 
 
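One caveat with the added lines: (len(image), len(image[0])) assumes the image is indexable row-first. If example["image"] is a PIL Image, as the resize calls elsewhere suggest, len(image) raises TypeError; and even on a NumPy array the result is (height, width), while cv2.resize expects (width, height). A hedged alternative, assuming a PIL input:

        # PIL's .size is already (width, height), the order cv2.resize expects
        original_shape = image.size
        image = preprocess(image)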
@@ -154,8 +170,10 @@ async def evaluate_image(request: ImageEvaluationRequest):
 
         probabilities = torch.sigmoid(logits)
         predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
-        predicted_mask_resized = cv2.resize(predicted_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
-
+        # predicted_mask_resized = cv2.resize(predicted_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
+        predicted_mask_resized = cv2.resize(predicted_mask, original_shape, interpolation=cv2.INTER_NEAREST)
+
+
         # Extract predicted bounding boxes
         predicted_boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
         pred_boxes.append(predicted_boxes)
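Resizing the predicted mask back to original_shape, rather than a fixed 512x512, puts the extracted boxes in the source image's coordinate system, and INTER_NEAREST keeps the mask binary instead of introducing interpolated gray values. The same caveat as above applies: cv2.resize reads its dsize argument as (width, height), so on non-square images the axes come out swapped if original_shape was built as (height, width). A sketch assuming original_shape already holds (width, height):

        predicted_mask_resized = cv2.resize(
            predicted_mask,
            original_shape,  # dsize must be (width, height)
            interpolation=cv2.INTER_NEAREST,
        )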
