# Bug-O-Scope / model_utils.py
# dalybuilds's picture
# Update model_utils.py
# 3ed4e4e verified
# raw
# history blame
# 11 kB
import torch
import torch.nn.functional as F
from transformers import ViTForImageClassification, AutoFeatureExtractor
import numpy as np
from PIL import Image
import cv2
from scipy.special import softmax
class BugClassifier:
    """Image classifier for ten insect/arachnid classes with helper utilities.

    Wraps a Hugging Face vision-transformer model plus its feature extractor
    and exposes prediction, a simple image-quality heuristic, an
    attention-based heat-map overlay and canned species descriptions.
    """

    # Checkpoint used for both the backbone and the feature extractor.
    MODEL_NAME = "microsoft/beit-base-patch16-224-pt22k-ft22k"
    # Spatial size images are resized to before being fed to the model.
    INPUT_SIZE = (224, 224)
    # predict(): below this top-1 probability the result is "Unknown Insect".
    CONFIDENCE_THRESHOLD = 0.4
    # predict(): top-1 must beat top-2 by more than this to skip the quality check.
    CLEAR_WINNER_MARGIN = 0.2
    # predict(): quality scores below this turn a close call into "Image Unclear".
    MIN_IMAGE_QUALITY = 0.5

    def __init__(self):
        """Load the model and feature extractor and build the lookup tables.

        Raises:
            RuntimeError: if any part of the setup fails; the original
                exception is attached as ``__cause__``.
        """
        try:
            # NOTE(review): the checkpoint name is a BEiT model but it is
            # loaded through the ViT class -- confirm the pretrained weights
            # are actually picked up rather than re-initialised.
            self.model = ViTForImageClassification.from_pretrained(
                self.MODEL_NAME,
                num_labels=10,
                ignore_mismatched_sizes=True,
            )
            # Replace the classification head with a small MLP (768 -> 512 -> 10).
            # NOTE(review): no trained weights for this head are loaded in
            # this file, so it starts from random initialisation.
            self.model.classifier = torch.nn.Sequential(
                torch.nn.Linear(768, 512),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.2),
                torch.nn.Linear(512, 10),  # 10 classes
            )
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(
                self.MODEL_NAME
            )
            # Inference only: disables dropout.
            self.model.eval()

            # Class labels; the position in this list is the model output index.
            self.labels = [
                "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant",
                "Japanese Beetle", "Garden Spider", "Green Grasshopper",
                "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp",
            ]

            # General category keywords per label (not read elsewhere in this class).
            self.category_mapping = {
                "Seven-spotted Ladybug": ["ladybug", "ladybird", "coccinellidae"],
                "Monarch Butterfly": ["butterfly", "lepidoptera"],
                "Carpenter Ant": ["ant", "formicidae"],
                "Japanese Beetle": ["beetle", "coleoptera"],
                "Garden Spider": ["spider", "arachnid"],
                "Green Grasshopper": ["grasshopper", "orthoptera"],
                "Luna Moth": ["moth", "lepidoptera"],
                "Common Dragonfly": ["dragonfly", "odonata"],
                "Honey Bee": ["bee", "apidae"],
                "Paper Wasp": ["wasp", "vespidae"],
            }

            # Detailed species descriptions; labels without an entry fall back
            # to the generic text produced by get_species_info().
            self.species_info = {
                "Seven-spotted Ladybug": """
The Seven-spotted Ladybug (Coccinella septempunctata) is one of the most common ladybug species.
These beneficial insects are natural predators of garden pests like aphids and scale insects.
Each ladybug can eat up to 5,000 aphids during its lifetime, making them excellent natural pest controllers.
Their distinct red coloring with seven black spots serves as a warning to predators.
""",
                "Monarch Butterfly": """
The Monarch Butterfly (Danaus plexippus) is known for its spectacular annual migration.
These butterflies play a crucial role in pollination and are indicators of ecosystem health.
They have a unique relationship with milkweed plants, which their caterpillars exclusively feed on.
Their orange and black wings serve as warning colors to predators about their toxicity.
""",
                "Carpenter Ant": """
Carpenter Ants (Camponotus spp.) are large ants that build nests in wood.
While they don't eat wood like termites, they can cause structural damage to buildings.
These social insects live in colonies and play important roles in forest ecosystems,
helping to break down dead wood and maintain soil health.
""",
                "Japanese Beetle": """
The Japanese Beetle (Popillia japonica) is recognized by its metallic green body.
While beautiful, these beetles can be significant garden pests, feeding on many plant species.
They are most active in summer months and can be managed through various natural control methods.
Their presence often indicates a healthy soil ecosystem, though their feeding can damage plants.
""",
                # Add other species info here...
            }
        except Exception as e:
            raise RuntimeError(f"Error initializing BugClassifier: {str(e)}") from e

    def predict(self, image):
        """Classify ``image`` and return ``(label, confidence_percent)``.

        Args:
            image: a ``PIL.Image.Image``.

        Returns:
            A tuple ``(str, float)``:
            * ``("Unknown Insect", top1 * 100)`` when the best probability is
              below ``CONFIDENCE_THRESHOLD``;
            * ``("Image Unclear", 0.0)`` when the top two classes are within
              ``CLEAR_WINNER_MARGIN`` and the image-quality score is below
              ``MIN_IMAGE_QUALITY``;
            * ``("Error Processing Image", 0.0)`` when anything raises
              (including a non-PIL input) -- errors are printed, not raised;
            * otherwise the winning label and its probability in percent.
        """
        try:
            if not isinstance(image, Image.Image):
                raise ValueError("Input must be a PIL Image")

            image_tensor = self.preprocess_image(image)

            with torch.no_grad():
                outputs = self.model(image_tensor)
                # .cpu() keeps this working if the model was moved off the CPU.
                probs = F.softmax(outputs.logits, dim=-1).cpu().numpy()[0]

            # Indices sorted by descending probability; only the top two matter.
            ranking = np.argsort(probs)[::-1]
            best_idx, runner_up_idx = ranking[0], ranking[1]
            best_prob = probs[best_idx]

            if best_prob < self.CONFIDENCE_THRESHOLD:
                return "Unknown Insect", float(best_prob * 100)

            # Only a close call between the top two triggers the quality check.
            clear_winner = (best_prob - probs[runner_up_idx]) > self.CLEAR_WINNER_MARGIN
            if not clear_winner:
                if self.assess_image_quality(image) < self.MIN_IMAGE_QUALITY:
                    return "Image Unclear", 0.0

            return self.labels[best_idx], float(best_prob * 100)
        except Exception as e:
            # Deliberate best-effort boundary: report and return a sentinel.
            print(f"Prediction error: {str(e)}")
            return "Error Processing Image", 0.0

    def preprocess_image(self, image):
        """Convert a PIL image into the model's ``pixel_values`` tensor.

        The image is converted to RGB (whatever its original mode), resized
        to ``INPUT_SIZE`` if needed and passed through the feature extractor.

        Raises:
            ValueError: if any step fails; the original exception is chained.
        """
        try:
            # Any non-RGB mode (RGBA, P, L, LA, CMYK, ...) is normalised so the
            # extractor always receives three colour channels.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            if image.size != self.INPUT_SIZE:
                image = image.resize(self.INPUT_SIZE, Image.Resampling.LANCZOS)
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            return inputs.pixel_values
        except Exception as e:
            raise ValueError(f"Error preprocessing image: {str(e)}") from e

    def assess_image_quality(self, image):
        """Return a heuristic quality score for ``image``.

        The score is the mean of three terms, each capped at 1: closeness of
        the mean pixel value to 128, pixel standard deviation / 50, and the
        variance of the Laplacian / 1000 (a sharpness proxy). If the
        assessment itself fails, 0.5 is returned.
        """
        try:
            # Normalise PIL inputs to RGB so alpha channels or palette indices
            # do not distort the statistics; other array-likes pass through.
            if isinstance(image, Image.Image):
                image = image.convert('RGB')
            img_array = np.array(image)

            brightness = np.mean(img_array)
            contrast = np.std(img_array)

            if len(img_array.shape) == 3:
                gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            else:
                gray = img_array
            sharpness = cv2.Laplacian(gray, cv2.CV_64F).var()

            brightness_score = 1 - abs(brightness - 128) / 128
            contrast_score = min(contrast / 50, 1)
            sharpness_score = min(sharpness / 1000, 1)
            return float((brightness_score + contrast_score + sharpness_score) / 3)
        except Exception as e:
            print(f"Error assessing image quality: {str(e)}")
            return 0.5  # Return middle value if assessment fails

    def get_species_info(self, species):
        """Return the description stored for ``species``, or a generic fallback."""
        default_info = f"""
Information about {species}:
This species is part of our insect database. While detailed information
is still being compiled, all insects play important roles in their ecosystems.
"""
        return self.species_info.get(species, default_info)

    @staticmethod
    def _attention_to_grid(attention):
        """Collapse an attention tensor into a square, [0, 1]-scaled patch grid.

        Args:
            attention: tensor indexed as (batch, heads, query, key); only the
                first batch item is used.

        Returns:
            2-D ``float32`` numpy array of per-patch scores.

        Raises:
            ValueError: if the tokens cannot be arranged as a square grid.
        """
        # Average over heads, then over query positions: one score per key token.
        token_scores = (
            attention.mean(dim=1).mean(dim=1)[0].detach().cpu().numpy().astype(np.float32)
        )
        n_tokens = token_scores.shape[0]

        # Prefer "one leading special token + square patch grid"; fall back to
        # treating every token as a patch if that is what forms a square.
        side = int(round(np.sqrt(max(n_tokens - 1, 0))))
        if side > 0 and side * side == n_tokens - 1:
            patch_scores = token_scores[1:]
        else:
            side = int(round(np.sqrt(n_tokens)))
            if side == 0 or side * side != n_tokens:
                raise ValueError(
                    f"Cannot arrange {n_tokens} attention tokens into a square grid"
                )
            patch_scores = token_scores

        grid = patch_scores.reshape(side, side)
        low = grid.min()
        span = grid.max() - low
        if span > 0:
            return (grid - low) / span
        # Constant map: avoid dividing by zero.
        return np.zeros_like(grid)

    def get_gradcam(self, image):
        """Overlay a last-layer attention heat-map on ``image``.

        Despite the name this uses the model's attention weights, not
        gradients. Returns a 224x224 RGB ``PIL.Image``; if anything fails the
        error is printed and the original ``image`` is returned unchanged.
        """
        try:
            image_tensor = self.preprocess_image(image)

            with torch.no_grad():
                outputs = self.model(image_tensor, output_attentions=True)
            attention = outputs.attentions[-1]

            # Patch grid -> full-resolution map in [0, 1].
            attention_map = cv2.resize(self._attention_to_grid(attention), self.INPUT_SIZE)
            attention_map = np.clip(attention_map, 0.0, 1.0)

            # applyColorMap yields BGR; convert so it matches the RGB photo.
            heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
            heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

            # Force three channels so the blend shapes always agree.
            original_image = np.array(image.convert('RGB').resize(self.INPUT_SIZE))

            overlay = cv2.addWeighted(original_image, 0.7, heatmap, 0.3, 0)
            return Image.fromarray(overlay)
        except Exception as e:
            print(f"Error generating Grad-CAM: {str(e)}")
            return image  # Return original image if Grad-CAM fails

    def compare_species(self, species1, species2):
        """Return a text block placing both species' descriptions side by side."""
        info1 = self.get_species_info(species1)
        info2 = self.get_species_info(species2)
        return f"""
**Comparing {species1} and {species2}:**
{species1}:
{info1}
{species2}:
{info2}
Both species contribute to their ecosystems in unique ways.
"""
def get_severity_prediction(species):
    """Return the ecological severity level recorded for ``species``.

    The result is one of "Low", "Medium", "High" or "Unknown"; any name
    that is not in the table yields "Unknown".
    """
    # Table grouped by level, then flattened into a name -> level lookup.
    species_by_level = {
        "Low": (
            "Seven-spotted Ladybug",
            "Monarch Butterfly",
            "Garden Spider",
            "Luna Moth",
            "Common Dragonfly",
            "Honey Bee",
        ),
        "Medium": (
            "Carpenter Ant",
            "Green Grasshopper",
            "Paper Wasp",
        ),
        "High": (
            "Japanese Beetle",
        ),
        "Unknown": (
            "Unknown Insect",
            "Image Unclear",
        ),
    }
    level_of = {
        name: level
        for level, names in species_by_level.items()
        for name in names
    }
    return level_of.get(species, "Unknown")