dalybuilds commited on
Commit
3ed4e4e
·
verified ·
1 Parent(s): 37f5146

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +154 -70
model_utils.py CHANGED
@@ -11,43 +11,120 @@ class BugClassifier:
11
  try:
12
  # Initialize model and feature extractor
13
  self.model = ViTForImageClassification.from_pretrained(
14
- "google/vit-base-patch16-224",
15
  num_labels=10,
16
  ignore_mismatched_sizes=True
17
  )
18
- self.feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # Set model to evaluation mode
21
  self.model.eval()
22
 
23
- # Define class labels
24
  self.labels = [
25
- "Ladybug", "Butterfly", "Ant", "Beetle", "Spider",
26
- "Grasshopper", "Moth", "Dragonfly", "Bee", "Wasp"
 
27
  ]
28
 
29
- # Species information database
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  self.species_info = {
31
- "Ladybug": """
32
- Ladybugs are small, round beetles known for their distinctive spotted patterns.
33
- They are beneficial insects that feed on plant-damaging pests like aphids.
34
- Fun fact: The number of spots on a ladybug can indicate its species!
 
 
 
 
 
 
 
35
  """,
36
- "Butterfly": """
37
- Butterflies are beautiful insects known for their large, colorful wings.
38
- They play a crucial role in pollination and are indicators of ecosystem health.
39
- They undergo complete metamorphosis from caterpillar to adult.
 
40
  """,
41
- "Ant": """
42
- Ants are social insects that live in colonies. They are incredibly strong
43
- for their size and play vital roles in soil health and ecosystem maintenance.
 
 
44
  """,
45
- # Add more species information for other classes...
46
  }
47
 
48
  except Exception as e:
49
  raise RuntimeError(f"Error initializing BugClassifier: {str(e)}")
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def preprocess_image(self, image):
52
  """Preprocess image for model input"""
53
  try:
@@ -66,30 +143,35 @@ class BugClassifier:
66
  except Exception as e:
67
  raise ValueError(f"Error preprocessing image: {str(e)}")
68
 
69
- def predict(self, image):
70
- """Make a prediction on the input image"""
71
  try:
72
- if not isinstance(image, Image.Image):
73
- raise ValueError("Input must be a PIL Image")
74
-
75
- # Preprocess image
76
- image_tensor = self.preprocess_image(image)
77
-
78
- # Make prediction
79
- with torch.no_grad():
80
- outputs = self.model(image_tensor)
81
- probs = F.softmax(outputs.logits, dim=-1).numpy()[0]
82
- pred_idx = np.argmax(probs)
83
-
84
- # Ensure index is within bounds
85
- if pred_idx >= len(self.labels):
86
- pred_idx = 0
87
-
88
- return self.labels[pred_idx], float(probs[pred_idx] * 100)
 
 
 
 
 
89
 
90
  except Exception as e:
91
- print(f"Prediction error: {str(e)}")
92
- return self.labels[0], 0.0
93
 
94
  def get_species_info(self, species):
95
  """Return information about a species"""
@@ -100,33 +182,16 @@ class BugClassifier:
100
  """
101
  return self.species_info.get(species, default_info)
102
 
103
- def compare_species(self, species1, species2):
104
- """Generate comparison information between two species"""
105
- info1 = self.get_species_info(species1)
106
- info2 = self.get_species_info(species2)
107
-
108
- return f"""
109
- **Comparing {species1} and {species2}:**
110
-
111
- {species1}:
112
- {info1}
113
-
114
- {species2}:
115
- {info2}
116
-
117
- Both species contribute to their ecosystems in unique ways.
118
- """
119
-
120
  def get_gradcam(self, image):
121
  """Generate Grad-CAM visualization for the image"""
122
  try:
123
  # Preprocess image
124
  image_tensor = self.preprocess_image(image)
125
 
126
- # Get model attention weights (using last layer's attention)
127
  with torch.no_grad():
128
  outputs = self.model(image_tensor, output_attentions=True)
129
- attention = outputs.attentions[-1] # Get last layer's attention
130
 
131
  # Convert attention to heatmap
132
  attention_map = attention.mean(dim=1).mean(dim=1).numpy()[0]
@@ -142,7 +207,7 @@ class BugClassifier:
142
 
143
  # Convert original image to RGB numpy array
144
  original_image = np.array(image.resize((224, 224)))
145
- if len(original_image.shape) == 2: # Convert grayscale to RGB
146
  original_image = cv2.cvtColor(original_image, cv2.COLOR_GRAY2RGB)
147
 
148
  # Overlay heatmap on original image
@@ -154,18 +219,37 @@ class BugClassifier:
154
  print(f"Error generating Grad-CAM: {str(e)}")
155
  return image # Return original image if Grad-CAM fails
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def get_severity_prediction(species):
158
  """Predict ecological severity/impact based on species"""
159
  severity_map = {
160
- "Ladybug": "Low",
161
- "Butterfly": "Low",
162
- "Ant": "Medium",
163
- "Beetle": "Medium",
164
- "Spider": "Low",
165
- "Grasshopper": "Medium",
166
- "Moth": "Low",
167
- "Dragonfly": "Low",
168
- "Bee": "Low",
169
- "Wasp": "Medium"
 
 
170
  }
171
- return severity_map.get(species, "Medium")
 
11
  try:
12
  # Initialize model and feature extractor
13
  self.model = ViTForImageClassification.from_pretrained(
14
+ "microsoft/beit-base-patch16-224-pt22k-ft22k",
15
  num_labels=10,
16
  ignore_mismatched_sizes=True
17
  )
18
+
19
+ # Add custom classification head
20
+ self.model.classifier = torch.nn.Sequential(
21
+ torch.nn.Linear(768, 512),
22
+ torch.nn.ReLU(),
23
+ torch.nn.Dropout(0.2),
24
+ torch.nn.Linear(512, 10) # 10 classes
25
+ )
26
+
27
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(
28
+ "microsoft/beit-base-patch16-224-pt22k-ft22k"
29
+ )
30
 
31
  # Set model to evaluation mode
32
  self.model.eval()
33
 
34
+ # Define detailed class labels
35
  self.labels = [
36
+ "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant",
37
+ "Japanese Beetle", "Garden Spider", "Green Grasshopper",
38
+ "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp"
39
  ]
40
 
41
+ # Create a mapping of general categories for better classification
42
+ self.category_mapping = {
43
+ "Seven-spotted Ladybug": ["ladybug", "ladybird", "coccinellidae"],
44
+ "Monarch Butterfly": ["butterfly", "lepidoptera"],
45
+ "Carpenter Ant": ["ant", "formicidae"],
46
+ "Japanese Beetle": ["beetle", "coleoptera"],
47
+ "Garden Spider": ["spider", "arachnid"],
48
+ "Green Grasshopper": ["grasshopper", "orthoptera"],
49
+ "Luna Moth": ["moth", "lepidoptera"],
50
+ "Common Dragonfly": ["dragonfly", "odonata"],
51
+ "Honey Bee": ["bee", "apidae"],
52
+ "Paper Wasp": ["wasp", "vespidae"]
53
+ }
54
+
55
+ # Detailed species information database
56
  self.species_info = {
57
+ "Seven-spotted Ladybug": """
58
+ The Seven-spotted Ladybug (Coccinella septempunctata) is one of the most common ladybug species.
59
+ These beneficial insects are natural predators of garden pests like aphids and scale insects.
60
+ Each ladybug can eat up to 5,000 aphids during its lifetime, making them excellent natural pest controllers.
61
+ Their distinct red coloring with seven black spots serves as a warning to predators.
62
+ """,
63
+ "Monarch Butterfly": """
64
+ The Monarch Butterfly (Danaus plexippus) is known for its spectacular annual migration.
65
+ These butterflies play a crucial role in pollination and are indicators of ecosystem health.
66
+ They have a unique relationship with milkweed plants, which their caterpillars exclusively feed on.
67
+ Their orange and black wings serve as warning colors to predators about their toxicity.
68
  """,
69
+ "Carpenter Ant": """
70
+ Carpenter Ants (Camponotus spp.) are large ants that build nests in wood.
71
+ While they don't eat wood like termites, they can cause structural damage to buildings.
72
+ These social insects live in colonies and play important roles in forest ecosystems,
73
+ helping to break down dead wood and maintain soil health.
74
  """,
75
+ "Japanese Beetle": """
76
+ The Japanese Beetle (Popillia japonica) is recognized by its metallic green body.
77
+ While beautiful, these beetles can be significant garden pests, feeding on many plant species.
78
+ They are most active in summer months and can be managed through various natural control methods.
79
+ Their presence often indicates a healthy soil ecosystem, though their feeding can damage plants.
80
  """,
81
+ # Add other species info here...
82
  }
83
 
84
  except Exception as e:
85
  raise RuntimeError(f"Error initializing BugClassifier: {str(e)}")
86
 
87
+ def predict(self, image):
88
+ """Make a prediction on the input image with improved confidence handling"""
89
+ try:
90
+ if not isinstance(image, Image.Image):
91
+ raise ValueError("Input must be a PIL Image")
92
+
93
+ # Preprocess image
94
+ image_tensor = self.preprocess_image(image)
95
+
96
+ # Make prediction
97
+ with torch.no_grad():
98
+ outputs = self.model(image_tensor)
99
+ probs = F.softmax(outputs.logits, dim=-1).numpy()[0]
100
+
101
+ # Get top 3 predictions
102
+ top3_idx = np.argsort(probs)[-3:][::-1]
103
+ top3_probs = probs[top3_idx]
104
+
105
+ # Use confidence threshold
106
+ CONFIDENCE_THRESHOLD = 0.4 # 40% confidence threshold
107
+
108
+ if top3_probs[0] < CONFIDENCE_THRESHOLD:
109
+ # If confidence is too low, return "Unknown"
110
+ return "Unknown Insect", float(top3_probs[0] * 100)
111
+
112
+ # Check if there's a clear winner (significantly higher than second best)
113
+ if (top3_probs[0] - top3_probs[1]) > 0.2: # 20% margin
114
+ pred_idx = top3_idx[0]
115
+ else:
116
+ # If it's close, consider image quality and features
117
+ image_quality = self.assess_image_quality(image)
118
+ if image_quality < 0.5:
119
+ return "Image Unclear", 0.0
120
+ pred_idx = top3_idx[0]
121
+
122
+ return self.labels[pred_idx], float(probs[pred_idx] * 100)
123
+
124
+ except Exception as e:
125
+ print(f"Prediction error: {str(e)}")
126
+ return "Error Processing Image", 0.0
127
+
128
  def preprocess_image(self, image):
129
  """Preprocess image for model input"""
130
  try:
 
143
  except Exception as e:
144
  raise ValueError(f"Error preprocessing image: {str(e)}")
145
 
146
+ def assess_image_quality(self, image):
147
+ """Assess the quality of the input image"""
148
  try:
149
+ # Convert to numpy array
150
+ img_array = np.array(image)
151
+
152
+ # Check brightness
153
+ brightness = np.mean(img_array)
154
+
155
+ # Check contrast
156
+ contrast = np.std(img_array)
157
+
158
+ # Check blur
159
+ if len(img_array.shape) == 3:
160
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
161
+ else:
162
+ gray = img_array
163
+ blur_score = cv2.Laplacian(gray, cv2.CV_64F).var()
164
+
165
+ # Normalize and combine scores
166
+ brightness_score = 1 - abs(brightness - 128) / 128
167
+ contrast_score = min(contrast / 50, 1)
168
+ blur_score = min(blur_score / 1000, 1)
169
+
170
+ return (brightness_score + contrast_score + blur_score) / 3
171
 
172
  except Exception as e:
173
+ print(f"Error assessing image quality: {str(e)}")
174
+ return 0.5 # Return middle value if assessment fails
175
 
176
  def get_species_info(self, species):
177
  """Return information about a species"""
 
182
  """
183
  return self.species_info.get(species, default_info)
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  def get_gradcam(self, image):
186
  """Generate Grad-CAM visualization for the image"""
187
  try:
188
  # Preprocess image
189
  image_tensor = self.preprocess_image(image)
190
 
191
+ # Get model attention weights
192
  with torch.no_grad():
193
  outputs = self.model(image_tensor, output_attentions=True)
194
+ attention = outputs.attentions[-1]
195
 
196
  # Convert attention to heatmap
197
  attention_map = attention.mean(dim=1).mean(dim=1).numpy()[0]
 
207
 
208
  # Convert original image to RGB numpy array
209
  original_image = np.array(image.resize((224, 224)))
210
+ if len(original_image.shape) == 2:
211
  original_image = cv2.cvtColor(original_image, cv2.COLOR_GRAY2RGB)
212
 
213
  # Overlay heatmap on original image
 
219
  print(f"Error generating Grad-CAM: {str(e)}")
220
  return image # Return original image if Grad-CAM fails
221
 
222
+ def compare_species(self, species1, species2):
223
+ """Generate comparison information between two species"""
224
+ info1 = self.get_species_info(species1)
225
+ info2 = self.get_species_info(species2)
226
+
227
+ return f"""
228
+ **Comparing {species1} and {species2}:**
229
+
230
+ {species1}:
231
+ {info1}
232
+
233
+ {species2}:
234
+ {info2}
235
+
236
+ Both species contribute to their ecosystems in unique ways.
237
+ """
238
+
239
  def get_severity_prediction(species):
240
  """Predict ecological severity/impact based on species"""
241
  severity_map = {
242
+ "Seven-spotted Ladybug": "Low",
243
+ "Monarch Butterfly": "Low",
244
+ "Carpenter Ant": "Medium",
245
+ "Japanese Beetle": "High",
246
+ "Garden Spider": "Low",
247
+ "Green Grasshopper": "Medium",
248
+ "Luna Moth": "Low",
249
+ "Common Dragonfly": "Low",
250
+ "Honey Bee": "Low",
251
+ "Paper Wasp": "Medium",
252
+ "Unknown Insect": "Unknown",
253
+ "Image Unclear": "Unknown"
254
  }
255
+ return severity_map.get(species, "Unknown")