Spaces:
Runtime error
Runtime error
| import inspect | |
| import math | |
| from typing import Any, Dict, List | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import ultralytics | |
| if hasattr(ultralytics, "FastSAM"): | |
| from ultralytics import FastSAM as YOLO | |
| else: | |
| from ultralytics import YOLO | |
class FastSAM:
    """Small convenience wrapper around an ultralytics model.

    The wrapped model is built with ``YOLO``, which the module-level import
    resolves to ``ultralytics.FastSAM`` when that class exists and to
    ``ultralytics.YOLO`` otherwise.
    """

    def __init__(
        self,
        checkpoint: str,
    ) -> None:
        """Load the model from ``checkpoint``.

        Args:
            checkpoint: Path (or identifier) handed straight to ``YOLO``.
        """
        self.model_path = checkpoint
        self.model = YOLO(checkpoint)

        # Make sure ``torch.nn.Upsample`` exposes ``recompute_scale_factor``;
        # presumably needed for checkpoints saved with another torch
        # version -- TODO confirm.
        upsample_cls = torch.nn.Upsample
        if not hasattr(upsample_cls, "recompute_scale_factor"):
            upsample_cls.recompute_scale_factor = None

    def to(self, device) -> None:
        """Move the wrapped model to ``device``; nothing is returned."""
        wrapped = self.model
        wrapped.to(device)

    def device(self) -> Any:
        """Return the ``device`` attribute of the wrapped model."""
        wrapped = self.model
        return wrapped.device

    def __call__(self, source=None, stream=False, **kwargs) -> Any:
        """Run the wrapped model and return whatever it returns.

        Args:
            source: Input forwarded as the ``source`` keyword.
            stream: Forwarded as the ``stream`` keyword.
            **kwargs: Any further keyword arguments, forwarded unchanged.
        """
        predictor = self.model
        return predictor(source=source, stream=stream, **kwargs)
class FastSamAutomaticMaskGenerator:
    """Generate boolean segmentation masks for an image with a FastSAM model.

    ``points_per_batch`` and ``pred_iou_thresh`` are stored on the instance
    but are not used by this class. ``stability_score_thresh`` only selects
    the detection confidence passed to the model (see ``__init__``).
    """

    def __init__(
        self,
        model: "FastSAM",
        points_per_batch: int = None,
        pred_iou_thresh: float = None,
        stability_score_thresh: float = None,
    ) -> None:
        """Store the model and derive the confidence threshold.

        Args:
            model: Callable model wrapper used for inference; it must expose
                ``device`` either as an attribute or as a zero-argument
                method.
            points_per_batch: Stored only; unused here.
            pred_iou_thresh: Stored only; unused here.
            stability_score_thresh: When ``>= 0.95`` the detector confidence
                is 0.25, otherwise 0.15. ``None`` (the default) is treated
                as "below 0.95" instead of raising ``TypeError`` on the
                comparison.
        """
        self.model = model
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        strict = (
            stability_score_thresh is not None
            and stability_score_thresh >= 0.95
        )
        self.conf = 0.25 if strict else 0.15

    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """Segment ``image`` and return one dict per detected mask.

        Args:
            image: Array whose first two dimensions are height and width.

        Returns:
            A list of ``{"segmentation": mask}`` dicts, where ``mask`` is a
            boolean array resized to the input's ``(height, width)``. The
            list is empty when the model reports no masks.
        """
        height, width = image.shape[:2]
        # Round both sides up to the next multiple of 32 before inference.
        new_height = math.ceil(height / 32) * 32
        new_width = math.ceil(width / 32) * 32
        resize_image = cv2.resize(
            image, (new_width, new_height), interpolation=cv2.INTER_CUBIC
        )

        # The FastSAM wrapper in this file defines ``device`` as a plain
        # method, so resolve it to an actual device value before passing it
        # on; a property or attribute is used as-is.
        device = self.model.device
        if callable(device):
            device = device()

        # Temporarily hide non-class ``torch.nn`` attributes whose name
        # contains "Norm" while the model runs. Presumably this protects
        # code in ultralytics that collects every such attribute for an
        # ``isinstance`` check -- TODO confirm. The ``finally`` guarantees
        # they are put back even if inference raises.
        backup_nn_dict = {}
        try:
            for key in list(torch.nn.__dict__):
                if "Norm" in key and not inspect.isclass(torch.nn.__dict__[key]):
                    backup_nn_dict[key] = torch.nn.__dict__.pop(key)
            results = self.model(
                source=resize_image,
                stream=False,
                imgsz=max(new_height, new_width),
                device=device,
                retina_masks=True,
                iou=0.7,
                conf=self.conf,
                max_det=256,
            )
        finally:
            for key, value in backup_nn_dict.items():
                setattr(torch.nn, key, value)

        # Nothing detected: there is no mask container to read from.
        if not results or results[0].masks is None:
            return []

        mask_data = results[0].masks.data
        if isinstance(mask_data, torch.Tensor):
            mask_data = mask_data.detach().cpu().numpy()

        # Kernels are loop-invariant, so build them once.
        close_kernel = np.ones((3, 3), np.uint8)
        open_kernel = np.ones((7, 7), np.uint8)

        annotations_list = []
        for mask in mask_data:
            # Close small holes first, then remove small specks.
            mask = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel
            )
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, open_kernel)
            # Map the mask back to the caller's original resolution.
            mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_AREA)
            annotations_list.append({"segmentation": mask.astype(bool)})
        return annotations_list