Update tasks/image.py
Browse files- tasks/image.py +22 -20
tasks/image.py
CHANGED
|
@@ -35,13 +35,16 @@ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
|
|
| 35 |
model.eval()
|
| 36 |
|
| 37 |
def preprocess(image):
|
|
|
|
| 38 |
image = image.resize((512, 512))
|
| 39 |
-
|
|
|
|
|
|
|
| 40 |
image = np.array(image, dtype=np.float32) / 255.0
|
| 41 |
|
| 42 |
-
#
|
| 43 |
-
|
| 44 |
-
|
| 45 |
|
| 46 |
|
| 47 |
def get_bounding_boxes_from_mask(mask):
|
|
@@ -145,16 +148,11 @@ async def evaluate_image(request: ImageEvaluationRequest):
|
|
| 145 |
# Extract image and annotations
|
| 146 |
image = example["image"]
|
| 147 |
|
| 148 |
-
original_shape = image.size
|
| 149 |
-
image = preprocess(image)
|
| 150 |
-
|
| 151 |
annotation = example.get("annotations", "").strip()
|
| 152 |
-
|
| 153 |
-
|
| 154 |
has_smoke = len(annotation) > 0
|
| 155 |
true_labels.append(1 if has_smoke else 0)
|
| 156 |
|
| 157 |
-
|
| 158 |
if has_smoke:
|
| 159 |
image_true_boxes = parse_boxes(annotation)
|
| 160 |
if image_true_boxes:
|
|
@@ -165,26 +163,30 @@ async def evaluate_image(request: ImageEvaluationRequest):
|
|
| 165 |
true_boxes_list.append([])
|
| 166 |
|
| 167 |
# Model Inference
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
with torch.no_grad():
|
| 172 |
outputs = model(pixel_values=image_input)
|
| 173 |
logits = outputs.logits
|
| 174 |
-
|
|
|
|
| 175 |
probabilities = torch.sigmoid(logits)
|
| 176 |
predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
|
| 177 |
-
# predicted_mask_resized = cv2.resize(predicted_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
|
| 178 |
predicted_mask_resized = cv2.resize(predicted_mask, original_shape[::-1], interpolation=cv2.INTER_NEAREST)
|
| 179 |
-
|
| 180 |
|
| 181 |
-
# Extract
|
| 182 |
predicted_boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
|
| 183 |
pred_boxes.append(predicted_boxes)
|
| 184 |
-
|
| 185 |
-
#
|
|
|
|
| 186 |
print(f"Prediction : {1 if len(predicted_boxes) > 0 else 0}")
|
| 187 |
-
predictions.append(1 if len(predicted_boxes) > 0 else 0)
|
| 188 |
|
| 189 |
|
| 190 |
# Filter only valid box pairs
|
|
|
|
| 35 |
model.eval()
|
| 36 |
|
| 37 |
def preprocess(image):
|
| 38 |
+
# Ensure input image is resized to a fixed size (512, 512)
|
| 39 |
image = image.resize((512, 512))
|
| 40 |
+
|
| 41 |
+
# Convert to NumPy and ensure BGR normalization
|
| 42 |
+
image = np.array(image)[:, :, ::-1] # Convert RGB to BGR
|
| 43 |
image = np.array(image, dtype=np.float32) / 255.0
|
| 44 |
|
| 45 |
+
# Return as a PIL Image for feature extractor compatibility
|
| 46 |
+
return Image.fromarray((image * 255).astype(np.uint8))
|
| 47 |
+
|
| 48 |
|
| 49 |
|
| 50 |
def get_bounding_boxes_from_mask(mask):
|
|
|
|
| 148 |
# Extract image and annotations
|
| 149 |
image = example["image"]
|
| 150 |
|
| 151 |
+
original_shape = image.size
|
|
|
|
|
|
|
| 152 |
annotation = example.get("annotations", "").strip()
|
|
|
|
|
|
|
| 153 |
has_smoke = len(annotation) > 0
|
| 154 |
true_labels.append(1 if has_smoke else 0)
|
| 155 |
|
|
|
|
| 156 |
if has_smoke:
|
| 157 |
image_true_boxes = parse_boxes(annotation)
|
| 158 |
if image_true_boxes:
|
|
|
|
| 163 |
true_boxes_list.append([])
|
| 164 |
|
| 165 |
# Model Inference
|
| 166 |
+
|
| 167 |
+
# Preprocess image
|
| 168 |
+
image = preprocess(image)
|
| 169 |
+
|
| 170 |
+
# Ensure correct feature extraction
|
| 171 |
+
image_input = feature_extractor(images=image, return_tensors="pt").pixel_values
|
| 172 |
+
|
| 173 |
+
# Perform inference
|
| 174 |
with torch.no_grad():
|
| 175 |
outputs = model(pixel_values=image_input)
|
| 176 |
logits = outputs.logits
|
| 177 |
+
|
| 178 |
+
# Threshold and process the segmentation mask
|
| 179 |
probabilities = torch.sigmoid(logits)
|
| 180 |
predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
|
|
|
|
| 181 |
predicted_mask_resized = cv2.resize(predicted_mask, original_shape[::-1], interpolation=cv2.INTER_NEAREST)
|
|
|
|
| 182 |
|
| 183 |
+
# Extract bounding boxes
|
| 184 |
predicted_boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
|
| 185 |
pred_boxes.append(predicted_boxes)
|
| 186 |
+
|
| 187 |
+
# Smoke prediction based on bounding box presence
|
| 188 |
+
predictions.append(1 if len(predicted_boxes) > 0 else 0)
|
| 189 |
print(f"Prediction : {1 if len(predicted_boxes) > 0 else 0}")
|
|
|
|
| 190 |
|
| 191 |
|
| 192 |
# Filter only valid box pairs
|