dalybuilds committed on
Commit
9ae2699
·
verified ·
1 Parent(s): 3ed4e4e

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +72 -159
model_utils.py CHANGED
@@ -4,88 +4,63 @@ from transformers import ViTForImageClassification, AutoFeatureExtractor
4
  import numpy as np
5
  from PIL import Image
6
  import cv2
7
- from scipy.special import softmax
8
 
9
  class BugClassifier:
10
  def __init__(self):
11
  try:
12
- # Initialize model and feature extractor
13
  self.model = ViTForImageClassification.from_pretrained(
14
- "microsoft/beit-base-patch16-224-pt22k-ft22k",
15
  num_labels=10,
16
  ignore_mismatched_sizes=True
17
  )
18
-
19
- # Add custom classification head
20
- self.model.classifier = torch.nn.Sequential(
21
- torch.nn.Linear(768, 512),
22
- torch.nn.ReLU(),
23
- torch.nn.Dropout(0.2),
24
- torch.nn.Linear(512, 10) # 10 classes
25
- )
26
-
27
- self.feature_extractor = AutoFeatureExtractor.from_pretrained(
28
- "microsoft/beit-base-patch16-224-pt22k-ft22k"
29
- )
30
 
31
  # Set model to evaluation mode
32
  self.model.eval()
33
 
34
- # Define detailed class labels
35
  self.labels = [
36
  "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant",
37
  "Japanese Beetle", "Garden Spider", "Green Grasshopper",
38
  "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp"
39
  ]
40
 
41
- # Create a mapping of general categories for better classification
42
- self.category_mapping = {
43
- "Seven-spotted Ladybug": ["ladybug", "ladybird", "coccinellidae"],
44
- "Monarch Butterfly": ["butterfly", "lepidoptera"],
45
- "Carpenter Ant": ["ant", "formicidae"],
46
- "Japanese Beetle": ["beetle", "coleoptera"],
47
- "Garden Spider": ["spider", "arachnid"],
48
- "Green Grasshopper": ["grasshopper", "orthoptera"],
49
- "Luna Moth": ["moth", "lepidoptera"],
50
- "Common Dragonfly": ["dragonfly", "odonata"],
51
- "Honey Bee": ["bee", "apidae"],
52
- "Paper Wasp": ["wasp", "vespidae"]
53
- }
54
-
55
- # Detailed species information database
56
  self.species_info = {
57
  "Seven-spotted Ladybug": """
58
- The Seven-spotted Ladybug (Coccinella septempunctata) is one of the most common ladybug species.
59
- These beneficial insects are natural predators of garden pests like aphids and scale insects.
60
- Each ladybug can eat up to 5,000 aphids during its lifetime, making them excellent natural pest controllers.
61
  Their distinct red coloring with seven black spots serves as a warning to predators.
62
  """,
63
  "Monarch Butterfly": """
64
- The Monarch Butterfly (Danaus plexippus) is known for its spectacular annual migration.
65
- These butterflies play a crucial role in pollination and are indicators of ecosystem health.
66
- They have a unique relationship with milkweed plants, which their caterpillars exclusively feed on.
67
- Their orange and black wings serve as warning colors to predators about their toxicity.
68
  """,
69
- "Carpenter Ant": """
70
- Carpenter Ants (Camponotus spp.) are large ants that build nests in wood.
71
- While they don't eat wood like termites, they can cause structural damage to buildings.
72
- These social insects live in colonies and play important roles in forest ecosystems,
73
- helping to break down dead wood and maintain soil health.
74
- """,
75
- "Japanese Beetle": """
76
- The Japanese Beetle (Popillia japonica) is recognized by its metallic green body.
77
- While beautiful, these beetles can be significant garden pests, feeding on many plant species.
78
- They are most active in summer months and can be managed through various natural control methods.
79
- Their presence often indicates a healthy soil ecosystem, though their feeding can damage plants.
80
- """,
81
- # Add other species info here...
82
  }
83
-
84
  except Exception as e:
 
85
  raise RuntimeError(f"Error initializing BugClassifier: {str(e)}")
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def predict(self, image):
88
- """Make a prediction on the input image with improved confidence handling"""
89
  try:
90
  if not isinstance(image, Image.Image):
91
  raise ValueError("Input must be a PIL Image")
@@ -98,81 +73,20 @@ class BugClassifier:
98
  outputs = self.model(image_tensor)
99
  probs = F.softmax(outputs.logits, dim=-1).numpy()[0]
100
 
101
- # Get top 3 predictions
102
- top3_idx = np.argsort(probs)[-3:][::-1]
103
- top3_probs = probs[top3_idx]
104
 
105
- # Use confidence threshold
106
- CONFIDENCE_THRESHOLD = 0.4 # 40% confidence threshold
 
107
 
108
- if top3_probs[0] < CONFIDENCE_THRESHOLD:
109
- # If confidence is too low, return "Unknown"
110
- return "Unknown Insect", float(top3_probs[0] * 100)
111
-
112
- # Check if there's a clear winner (significantly higher than second best)
113
- if (top3_probs[0] - top3_probs[1]) > 0.2: # 20% margin
114
- pred_idx = top3_idx[0]
115
- else:
116
- # If it's close, consider image quality and features
117
- image_quality = self.assess_image_quality(image)
118
- if image_quality < 0.5:
119
- return "Image Unclear", 0.0
120
- pred_idx = top3_idx[0]
121
 
122
- return self.labels[pred_idx], float(probs[pred_idx] * 100)
123
-
124
  except Exception as e:
125
  print(f"Prediction error: {str(e)}")
126
  return "Error Processing Image", 0.0
127
 
128
- def preprocess_image(self, image):
129
- """Preprocess image for model input"""
130
- try:
131
- # Convert RGBA to RGB if necessary
132
- if image.mode == 'RGBA':
133
- image = image.convert('RGB')
134
-
135
- # Resize image if needed
136
- if image.size != (224, 224):
137
- image = image.resize((224, 224), Image.Resampling.LANCZOS)
138
-
139
- # Process image using feature extractor
140
- inputs = self.feature_extractor(images=image, return_tensors="pt")
141
- return inputs.pixel_values
142
-
143
- except Exception as e:
144
- raise ValueError(f"Error preprocessing image: {str(e)}")
145
-
146
- def assess_image_quality(self, image):
147
- """Assess the quality of the input image"""
148
- try:
149
- # Convert to numpy array
150
- img_array = np.array(image)
151
-
152
- # Check brightness
153
- brightness = np.mean(img_array)
154
-
155
- # Check contrast
156
- contrast = np.std(img_array)
157
-
158
- # Check blur
159
- if len(img_array.shape) == 3:
160
- gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
161
- else:
162
- gray = img_array
163
- blur_score = cv2.Laplacian(gray, cv2.CV_64F).var()
164
-
165
- # Normalize and combine scores
166
- brightness_score = 1 - abs(brightness - 128) / 128
167
- contrast_score = min(contrast / 50, 1)
168
- blur_score = min(blur_score / 1000, 1)
169
-
170
- return (brightness_score + contrast_score + blur_score) / 3
171
-
172
- except Exception as e:
173
- print(f"Error assessing image quality: {str(e)}")
174
- return 0.5 # Return middle value if assessment fails
175
-
176
  def get_species_info(self, species):
177
  """Return information about a species"""
178
  default_info = f"""
@@ -182,59 +96,58 @@ class BugClassifier:
182
  """
183
  return self.species_info.get(species, default_info)
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  def get_gradcam(self, image):
186
- """Generate Grad-CAM visualization for the image"""
187
  try:
188
- # Preprocess image
189
  image_tensor = self.preprocess_image(image)
190
 
191
- # Get model attention weights
192
  with torch.no_grad():
193
  outputs = self.model(image_tensor, output_attentions=True)
194
- attention = outputs.attentions[-1]
 
195
 
196
- # Convert attention to heatmap
197
- attention_map = attention.mean(dim=1).mean(dim=1).numpy()[0]
198
-
199
- # Resize attention map to image size
200
  attention_map = cv2.resize(attention_map, (224, 224))
201
 
202
- # Normalize attention map
203
  attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
204
 
205
- # Convert to heatmap
206
- heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
 
207
 
208
- # Convert original image to RGB numpy array
209
- original_image = np.array(image.resize((224, 224)))
210
- if len(original_image.shape) == 2:
211
- original_image = cv2.cvtColor(original_image, cv2.COLOR_GRAY2RGB)
212
 
213
  # Overlay heatmap on original image
214
- overlay = cv2.addWeighted(original_image, 0.7, heatmap, 0.3, 0)
215
 
216
- return Image.fromarray(overlay)
217
 
218
  except Exception as e:
219
- print(f"Error generating Grad-CAM: {str(e)}")
220
- return image # Return original image if Grad-CAM fails
221
-
222
- def compare_species(self, species1, species2):
223
- """Generate comparison information between two species"""
224
- info1 = self.get_species_info(species1)
225
- info2 = self.get_species_info(species2)
226
-
227
- return f"""
228
- **Comparing {species1} and {species2}:**
229
-
230
- {species1}:
231
- {info1}
232
-
233
- {species2}:
234
- {info2}
235
-
236
- Both species contribute to their ecosystems in unique ways.
237
- """
238
 
239
  def get_severity_prediction(species):
240
  """Predict ecological severity/impact based on species"""
@@ -250,6 +163,6 @@ def get_severity_prediction(species):
250
  "Honey Bee": "Low",
251
  "Paper Wasp": "Medium",
252
  "Unknown Insect": "Unknown",
253
- "Image Unclear": "Unknown"
254
  }
255
- return severity_map.get(species, "Unknown")
 
4
  import numpy as np
5
  from PIL import Image
6
  import cv2
 
7
 
8
  class BugClassifier:
9
def __init__(self):
    """Load the ViT checkpoint and set up the static label and species data.

    The label list is defined first so the size of the classification head
    is derived from it instead of being a separate hard-coded number.

    Raises:
        RuntimeError: if the model or the feature extractor cannot be
            loaded. The underlying exception is chained as ``__cause__``.
    """
    checkpoint = "google/vit-base-patch16-224"

    # Define class labels
    self.labels = [
        "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant",
        "Japanese Beetle", "Garden Spider", "Green Grasshopper",
        "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp"
    ]

    try:
        # Use standard ViT model without modifications.
        # NOTE(review): num_labels differs from the checkpoint's own head and
        # ignore_mismatched_sizes=True is set, so the classifier layer is
        # presumably re-initialised here rather than loaded -- confirm that
        # fine-tuned weights are supplied before trusting predictions.
        self.model = ViTForImageClassification.from_pretrained(
            checkpoint,
            num_labels=len(self.labels),
            ignore_mismatched_sizes=True
        )
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)

        # Set model to evaluation mode
        self.model.eval()

    except Exception as e:
        print(f"Error initializing model: {str(e)}")
        # Chain the cause so the real loading failure stays in the traceback.
        raise RuntimeError(f"Error initializing BugClassifier: {str(e)}") from e

    # Species information database (plain data; cannot fail, so it lives
    # outside the try block).
    self.species_info = {
        "Seven-spotted Ladybug": """
            The Seven-spotted Ladybug is one of the most common ladybug species.
            These beneficial insects are natural predators of garden pests like aphids.
            Their distinct red coloring with seven black spots serves as a warning to predators.
            """,
        "Monarch Butterfly": """
            The Monarch Butterfly is known for its spectacular annual migration.
            They play a crucial role in pollination and are indicators of ecosystem health.
            Their orange and black wings serve as warning colors to predators.
            """,
        # Add other species info as needed...
    }
46
 
47
def preprocess_image(self, image):
    """Preprocess image for model input.

    Args:
        image: PIL image in any mode.

    Returns:
        The ``pixel_values`` produced by ``self.feature_extractor`` for the
        image (requested as PyTorch tensors).

    Raises:
        ValueError: if conversion or feature extraction fails. The original
            exception is chained as ``__cause__``.
    """
    try:
        # Normalise every non-RGB mode (RGBA, L, P, LA, CMYK, ...) to RGB,
        # not just RGBA, so the extractor always sees three channels.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Use feature extractor to handle resizing and normalization
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        return inputs.pixel_values

    except Exception as e:
        print(f"Preprocessing error: {str(e)}")
        raise ValueError(f"Error preprocessing image: {str(e)}") from e
61
+
62
  def predict(self, image):
63
+ """Make a prediction on the input image"""
64
  try:
65
  if not isinstance(image, Image.Image):
66
  raise ValueError("Input must be a PIL Image")
 
73
  outputs = self.model(image_tensor)
74
  probs = F.softmax(outputs.logits, dim=-1).numpy()[0]
75
 
76
+ # Get prediction with highest confidence
77
+ pred_idx = np.argmax(probs)
78
+ confidence = float(probs[pred_idx] * 100)
79
 
80
+ # Check confidence threshold
81
+ if confidence < 40: # 40% threshold
82
+ return "Unknown Insect", confidence
83
 
84
+ return self.labels[pred_idx], confidence
 
 
 
 
 
 
 
 
 
 
 
 
85
 
 
 
86
  except Exception as e:
87
  print(f"Prediction error: {str(e)}")
88
  return "Error Processing Image", 0.0
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def get_species_info(self, species):
91
  """Return information about a species"""
92
  default_info = f"""
 
96
  """
97
  return self.species_info.get(species, default_info)
98
 
99
def compare_species(self, species1, species2):
    """Generate comparison information between two species.

    The text is assembled line by line so that no source-code indentation
    leaks into the result, and each species description is dedented and
    stripped before being embedded. Indented lines would otherwise risk
    being rendered as a code block by Markdown viewers.

    Args:
        species1: name of the first species.
        species2: name of the second species.

    Returns:
        A Markdown-formatted string with a heading, both descriptions
        (as returned by ``get_species_info``) and a closing sentence.
    """
    import textwrap

    def _clean(text):
        # Remove the common leading indent and surrounding blank lines.
        return textwrap.dedent(text).strip()

    info1 = _clean(self.get_species_info(species1))
    info2 = _clean(self.get_species_info(species2))

    return "\n".join([
        f"**Comparing {species1} and {species2}:**",
        "",
        f"{species1}:",
        info1,
        "",
        f"{species2}:",
        info2,
        "",
        "Both species contribute to their ecosystems in unique ways.",
    ])
115
+
116
def get_gradcam(self, image):
    """Generate a simple attention visualization.

    Despite the method name, no gradients are involved: the heatmap is the
    last layer's attention, averaged over heads and over query tokens, laid
    out on the patch grid and blended over the input image.

    Args:
        image: PIL image to visualize.

    Returns:
        A 224x224 RGB PIL image with the heatmap overlaid, or the unmodified
        input image if any step fails (best-effort; the error is printed).
    """
    try:
        image_tensor = self.preprocess_image(image)

        with torch.no_grad():
            outputs = self.model(image_tensor, output_attentions=True)
            # Last layer, averaged over heads (dim 1) and then over query
            # tokens: one score per key token for each batch item.
            attention = outputs.attentions[-1].mean(dim=1).mean(dim=1)

        token_scores = attention.detach().cpu().numpy()[0].astype(np.float32)

        # ViT prepends a [CLS] token; drop it so only image patches remain,
        # then restore the 2-D patch layout. Resizing the flat vector
        # directly would produce stripes, not a spatial map.
        patch_scores = token_scores[1:]
        grid_side = int(round(float(np.sqrt(patch_scores.size))))
        if grid_side == 0 or grid_side * grid_side != patch_scores.size:
            raise ValueError(
                f"Cannot arrange {patch_scores.size} patch tokens into a square grid"
            )
        attention_map = patch_scores.reshape(grid_side, grid_side)
        attention_map = cv2.resize(attention_map, (224, 224))

        # Normalize the attention map, guarding against a constant map
        # (which would otherwise divide by zero and yield NaNs).
        low = attention_map.min()
        value_range = float(attention_map.max() - low)
        if value_range > 0:
            attention_map = (attention_map - low) / value_range
        else:
            attention_map = np.zeros_like(attention_map)

        # Create heatmap. applyColorMap returns BGR, so convert it to RGB to
        # match the PIL-derived array it is blended with.
        heatmap = np.uint8(255 * attention_map)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Prepare original image: force 3-channel RGB so grayscale, palette
        # and RGBA inputs have the same shape as the heatmap.
        original_image = image.convert('RGB').resize((224, 224))
        original_array = np.array(original_image)

        # Overlay heatmap on original image
        output = cv2.addWeighted(original_array, 0.7, heatmap, 0.3, 0)

        return Image.fromarray(output)

    except Exception as e:
        print(f"Grad-CAM error: {str(e)}")
        return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  def get_severity_prediction(species):
153
  """Predict ecological severity/impact based on species"""
 
163
  "Honey Bee": "Low",
164
  "Paper Wasp": "Medium",
165
  "Unknown Insect": "Unknown",
166
+ "Error Processing Image": "Unknown"
167
  }
168
+ return severity_map.get(species, "Medium")