# Bug-O-Scope / model_utils.py
# Last commit by dalybuilds: "Update model_utils.py" (9ae2699, verified)
# File size: 6.6 kB
import torch
import torch.nn.functional as F
from transformers import ViTForImageClassification, AutoFeatureExtractor
import numpy as np
from PIL import Image
import cv2
class BugClassifier:
    """Insect image classifier and attention visualiser built on a ViT model.

    The instance loads a Hugging Face ``ViTForImageClassification`` model and
    its matching feature extractor, and exposes helpers to classify a PIL
    image, look up descriptive text for a species, compare two species, and
    render an attention heatmap over the input image.

    NOTE(review): the checkpoint is loaded with a different ``num_labels`` and
    ``ignore_mismatched_sizes=True``; per the transformers documentation this
    re-initialises the classification head, so predictions are presumably only
    meaningful once fine-tuned weights are loaded - confirm with the owner.
    """

    # Hugging Face checkpoint used for both the model and its preprocessor.
    MODEL_NAME = "google/vit-base-patch16-224"
    # Minimum top-class confidence (in percent) required to report a label.
    CONFIDENCE_THRESHOLD = 40.0
    # (width, height) of the attention overlay returned by ``get_gradcam``.
    HEATMAP_SIZE = (224, 224)

    def __init__(self):
        """Load the model and feature extractor and set up label metadata.

        Raises:
            RuntimeError: if the model or feature extractor cannot be loaded.
        """
        # Class labels, indexed by the model's output position.
        self.labels = [
            "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant",
            "Japanese Beetle", "Garden Spider", "Green Grasshopper",
            "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp",
        ]
        # Species information database; species without an entry fall back
        # to the generic text built in ``get_species_info``.
        self.species_info = {
            "Seven-spotted Ladybug": (
                "\n"
                "The Seven-spotted Ladybug is one of the most common ladybug species.\n"
                "These beneficial insects are natural predators of garden pests like aphids.\n"
                "Their distinct red coloring with seven black spots serves as a warning to predators.\n"
            ),
            "Monarch Butterfly": (
                "\n"
                "The Monarch Butterfly is known for its spectacular annual migration.\n"
                "They play a crucial role in pollination and are indicators of ecosystem health.\n"
                "Their orange and black wings serve as warning colors to predators.\n"
            ),
            # Add other species info as needed...
        }
        # Only the loading calls can realistically fail, so only they are
        # wrapped; the original cause is chained for easier debugging.
        try:
            self.model = ViTForImageClassification.from_pretrained(
                self.MODEL_NAME,
                # Derived from the label list so the two cannot drift apart.
                num_labels=len(self.labels),
                ignore_mismatched_sizes=True,
            )
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.MODEL_NAME)
            # Set model to evaluation mode
            self.model.eval()
        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            raise RuntimeError(f"Error initializing BugClassifier: {str(e)}") from e

    def preprocess_image(self, image):
        """Convert a PIL image into the ``pixel_values`` tensor for the model.

        Args:
            image: PIL image in any mode; non-RGB modes are converted to RGB.

        Returns:
            The ``pixel_values`` tensor produced by the feature extractor.

        Raises:
            ValueError: if conversion or feature extraction fails.
        """
        try:
            # Normalise every non-RGB mode (RGBA, L, P, ...), not just RGBA,
            # so the extractor always receives three colour channels.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            # Use feature extractor to handle resizing and normalization
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            return inputs.pixel_values
        except Exception as e:
            print(f"Preprocessing error: {str(e)}")
            raise ValueError(f"Error preprocessing image: {str(e)}") from e

    def predict(self, image):
        """Classify ``image`` and return ``(label, confidence_percent)``.

        Returns ``("Unknown Insect", confidence)`` when the best class is
        below ``CONFIDENCE_THRESHOLD`` percent. Any failure (including a
        non-PIL input) is reported as ``("Error Processing Image", 0.0)``
        rather than raised, so callers always receive a tuple.
        """
        try:
            if not isinstance(image, Image.Image):
                raise ValueError("Input must be a PIL Image")
            image_tensor = self.preprocess_image(image)
            with torch.no_grad():
                outputs = self.model(image_tensor)
                probs = F.softmax(outputs.logits, dim=-1).numpy()[0]
            # Get prediction with highest confidence
            pred_idx = int(np.argmax(probs))
            confidence = float(probs[pred_idx] * 100)
            if confidence < self.CONFIDENCE_THRESHOLD:
                return "Unknown Insect", confidence
            return self.labels[pred_idx], confidence
        except Exception as e:
            print(f"Prediction error: {str(e)}")
            return "Error Processing Image", 0.0

    def get_species_info(self, species):
        """Return the stored description for ``species`` or a generic fallback."""
        default_info = (
            "\n"
            f"Information about {species}:\n"
            "This species is part of our insect database. While detailed information\n"
            "is still being compiled, all insects play important roles in their ecosystems.\n"
        )
        return self.species_info.get(species, default_info)

    def compare_species(self, species1, species2):
        """Return a markdown-style text block comparing two species."""
        info1 = self.get_species_info(species1)
        info2 = self.get_species_info(species2)
        return (
            "\n"
            f"**Comparing {species1} and {species2}:**\n"
            f"{species1}:\n"
            f"{info1}\n"
            f"{species2}:\n"
            f"{info2}\n"
            "Both species contribute to their ecosystems in unique ways.\n"
        )

    @staticmethod
    def _attention_to_grid(attention):
        """Turn one layer's attention weights into a normalised 2-D patch grid.

        Args:
            attention: array-like of shape ``(heads, tokens, tokens)`` where
                token 0 is the class token and the remaining tokens are image
                patches laid out row by row on a square grid.

        Returns:
            ``float32`` array of shape ``(side, side)`` scaled to ``[0, 1]``;
            all zeros when the attention is perfectly flat.

        Raises:
            ValueError: if the input is not ``(heads, n, n)`` or the number of
                patch tokens is not a positive perfect square.
        """
        attention = np.asarray(attention, dtype=np.float32)
        if attention.ndim != 3 or attention.shape[1] != attention.shape[2]:
            raise ValueError(
                f"Expected attention of shape (heads, tokens, tokens), got {attention.shape}"
            )
        # Average over heads and query positions: one score per key token.
        token_scores = attention.mean(axis=(0, 1))
        # Drop the class token; only patch tokens have a spatial position.
        patch_scores = token_scores[1:]
        side = int(round(float(np.sqrt(patch_scores.size))))
        if side == 0 or side * side != patch_scores.size:
            raise ValueError(
                f"Cannot arrange {patch_scores.size} patch tokens on a square grid"
            )
        grid = patch_scores.reshape(side, side)
        low = float(grid.min())
        span = float(grid.max()) - low
        # A flat map would divide by zero and produce NaNs; return zeros.
        if not np.isfinite(span) or span <= 0.0:
            return np.zeros_like(grid)
        return (grid - low) / span

    def get_gradcam(self, image):
        """Overlay a last-layer attention heatmap on ``image``.

        Despite the name this is an attention visualisation, not Grad-CAM.
        On any failure the error is printed and the unmodified input image is
        returned (best-effort behaviour).

        NOTE(review): some transformers versions/attention backends may not
        return attention weights for ``output_attentions=True``; in that case
        this method falls back to returning the input image - confirm the
        installed version populates ``outputs.attentions``.
        """
        try:
            image_tensor = self.preprocess_image(image)
            with torch.no_grad():
                outputs = self.model(image_tensor, output_attentions=True)
            attentions = outputs.attentions
            if not attentions or attentions[-1] is None:
                raise ValueError("Model returned no attention weights")
            # Last layer, first (only) batch item -> (heads, tokens, tokens).
            last_layer = attentions[-1][0].detach().cpu().numpy()
            attention_grid = self._attention_to_grid(last_layer)
            # Upsample the patch grid to the overlay size.
            attention_map = cv2.resize(attention_grid, self.HEATMAP_SIZE)
            attention_map = np.clip(attention_map, 0.0, 1.0)
            # Create heatmap; OpenCV colour maps are BGR, so convert to RGB
            # before blending with the RGB pixels that come from PIL.
            heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
            heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
            # Force three channels so the shapes match for any input mode.
            original_image = image.convert('RGB').resize(self.HEATMAP_SIZE)
            original_array = np.array(original_image)
            # Overlay heatmap on original image
            output = cv2.addWeighted(original_array, 0.7, heatmap, 0.3, 0)
            return Image.fromarray(output)
        except Exception as e:
            print(f"Grad-CAM error: {str(e)}")
            return image
def get_severity_prediction(species):
    """Map a species label to its ecological severity/impact rating.

    Known labels resolve to "Low", "Medium", "High" or "Unknown"; any label
    that is not listed falls back to "Medium".
    """
    # Labels grouped by the rating they map to.
    labels_by_rating = {
        "Low": (
            "Seven-spotted Ladybug",
            "Monarch Butterfly",
            "Garden Spider",
            "Luna Moth",
            "Common Dragonfly",
            "Honey Bee",
        ),
        "Medium": (
            "Carpenter Ant",
            "Green Grasshopper",
            "Paper Wasp",
        ),
        "High": (
            "Japanese Beetle",
        ),
        "Unknown": (
            "Unknown Insect",
            "Error Processing Image",
        ),
    }
    # Invert the grouping into a direct label -> rating lookup.
    rating_of = {
        label: rating
        for rating, labels in labels_by_rating.items()
        for label in labels
    }
    try:
        return rating_of[species]
    except KeyError:
        # Unlisted species default to a middle rating.
        return "Medium"