Spaces:
Sleeping
Sleeping
File size: 6,601 Bytes
580daa1 f88bb49 9ae2699 f88bb49 9ae2699 f88bb49 9ae2699 f88bb49 9ae2699 f88bb49 3ed4e4e f88bb49 9ae2699 f88bb49 3ed4e4e 9ae2699 3ed4e4e 9ae2699 f88bb49 9ae2699 f88bb49 9ae2699 f88bb49 9ae2699 3ed4e4e 9ae2699 3ed4e4e 9ae2699 3ed4e4e 9ae2699 3ed4e4e 9ae2699 3ed4e4e 580daa1 f88bb49 1902761 f88bb49 580daa1 9ae2699 f88bb49 9ae2699 f88bb49 9ae2699 f88bb49 9ae2699 f88bb49 9ae2699 f88bb49 9ae2699 f88bb49 9ae2699 f88bb49 9ae2699 f88bb49 9ae2699 f88bb49 9ae2699 f88bb49 9ae2699 3ed4e4e f88bb49 3ed4e4e 9ae2699 f88bb49 9ae2699 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import torch
import torch.nn.functional as F
from transformers import ViTForImageClassification, AutoFeatureExtractor
import numpy as np
from PIL import Image
import cv2
class BugClassifier:
    """Classify an insect photo into one of ten species with a ViT model.

    Wraps a Hugging Face ``ViTForImageClassification`` model and its image
    preprocessor, and adds helpers for species text, species comparisons and
    an attention-overlay visualization.

    NOTE(review): the default checkpoint ``google/vit-base-patch16-224`` is
    loaded with ``num_labels=len(self.labels)`` and
    ``ignore_mismatched_sizes=True``, so its classification head is replaced
    by a freshly initialized layer. Unless ``model_name`` points at a
    checkpoint fine-tuned on ``self.labels``, predictions are not meaningful.
    """

    def __init__(self, model_name="google/vit-base-patch16-224"):
        """Load the model and preprocessor and set up the species tables.

        Args:
            model_name: Hugging Face model id or local path, passed to both
                ``from_pretrained`` calls. Defaults to the original base ViT.

        Raises:
            RuntimeError: if loading the model or the preprocessor fails. The
                underlying exception is chained as ``__cause__``.
        """
        # Class labels, in the order of the classifier's output logits.
        self.labels = [
            "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant",
            "Japanese Beetle", "Garden Spider", "Green Grasshopper",
            "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp",
        ]
        # Species information database. Species missing here fall back to
        # the generic text built in get_species_info.
        self.species_info = {
            "Seven-spotted Ladybug": (
                "\n"
                "The Seven-spotted Ladybug is one of the most common ladybug species.\n"
                "These beneficial insects are natural predators of garden pests like aphids.\n"
                "Their distinct red coloring with seven black spots serves as a warning to predators.\n"
            ),
            "Monarch Butterfly": (
                "\n"
                "The Monarch Butterfly is known for its spectacular annual migration.\n"
                "They play a crucial role in pollination and are indicators of ecosystem health.\n"
                "Their orange and black wings serve as warning colors to predators.\n"
            ),
            # Add other species info as needed...
        }
        try:
            self.model = ViTForImageClassification.from_pretrained(
                model_name,
                # Derive the head size from the label list so the two
                # cannot drift apart.
                num_labels=len(self.labels),
                ignore_mismatched_sizes=True,
            )
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
            # Inference only: disable dropout etc.
            self.model.eval()
        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            raise RuntimeError(f"Error initializing BugClassifier: {str(e)}") from e

    def preprocess_image(self, image):
        """Turn a PIL image into the model's ``pixel_values`` tensor.

        Args:
            image: a PIL image in any mode.

        Returns:
            The ``pixel_values`` tensor produced by the preprocessor.

        Raises:
            ValueError: if conversion or preprocessing fails. The original
                exception is chained as ``__cause__``.
        """
        try:
            # The ViT model consumes 3-channel input. Convert every non-RGB
            # mode (RGBA, palette "P", grayscale "L", "LA", "CMYK", ...),
            # not just RGBA.
            if image.mode != "RGB":
                image = image.convert("RGB")
            # The preprocessor handles resizing and normalization.
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            return inputs.pixel_values
        except Exception as e:
            print(f"Preprocessing error: {str(e)}")
            raise ValueError(f"Error preprocessing image: {str(e)}") from e

    def predict(self, image, confidence_threshold=40.0):
        """Classify ``image`` and return ``(label, confidence_percent)``.

        Args:
            image: a PIL image.
            confidence_threshold: minimum top-class probability, in percent,
                needed to report a species. Below it "Unknown Insect" is
                returned. Defaults to 40, the original hard-coded value.

        Returns:
            ``(label, confidence)``, where confidence is in 0-100. On any
            failure, including a non-PIL input, returns
            ``("Error Processing Image", 0.0)`` instead of raising.
        """
        try:
            if not isinstance(image, Image.Image):
                raise ValueError("Input must be a PIL Image")
            pixel_values = self.preprocess_image(image)
            with torch.no_grad():
                logits = self.model(pixel_values).logits
            probs = F.softmax(logits, dim=-1).numpy()[0]
            pred_idx = int(np.argmax(probs))
            confidence = float(probs[pred_idx] * 100)
            if confidence < confidence_threshold:
                return "Unknown Insect", confidence
            return self.labels[pred_idx], confidence
        except Exception as e:
            # Deliberate best-effort: callers get a sentinel label, not an
            # exception.
            print(f"Prediction error: {str(e)}")
            return "Error Processing Image", 0.0

    def get_species_info(self, species):
        """Return the stored description of ``species``, or a generic one."""
        default_info = (
            "\n"
            f"Information about {species}:\n"
            "This species is part of our insect database. While detailed information\n"
            "is still being compiled, all insects play important roles in their ecosystems.\n"
        )
        return self.species_info.get(species, default_info)

    def compare_species(self, species1, species2):
        """Return a Markdown-ish text that lists both species' descriptions."""
        info1 = self.get_species_info(species1)
        info2 = self.get_species_info(species2)
        return (
            "\n"
            f"**Comparing {species1} and {species2}:**\n"
            f"{species1}:\n"
            f"{info1}\n"
            f"{species2}:\n"
            f"{info2}\n"
            "Both species contribute to their ecosystems in unique ways.\n"
        )

    @staticmethod
    def _cls_attention_map(attentions):
        """Reduce one layer's attention to a normalized 2-D patch-grid map.

        Args:
            attentions: attention probabilities of a single layer, assumed to
                be ``(batch, heads, seq, seq)`` with the CLS token at index 0
                followed by a square grid of patch tokens (ViT layout).

        Returns:
            A float32 numpy array of shape ``(side, side)`` for the first
            batch item, scaled to [0, 1]. It is all zeros when the map is
            constant, which avoids a 0/0 division.

        Raises:
            ValueError: if the number of patch tokens is not a perfect square.
        """
        # Attention the CLS token pays to each patch, averaged over heads.
        cls_to_patches = attentions[0, :, 0, 1:].mean(dim=0)
        num_patches = cls_to_patches.shape[0]
        side = int(round(num_patches ** 0.5))
        if side * side != num_patches:
            raise ValueError(f"Expected a square patch grid, got {num_patches} patches")
        attention_map = cls_to_patches.reshape(side, side).float().numpy()
        low, high = attention_map.min(), attention_map.max()
        if high > low:
            attention_map = (attention_map - low) / (high - low)
        else:
            attention_map = np.zeros_like(attention_map)
        return attention_map.astype(np.float32)

    def get_gradcam(self, image):
        """Overlay a CLS-token attention heatmap on ``image``.

        Despite the name this is not Grad-CAM. It visualizes the last layer's
        attention from the CLS token to each image patch.

        Returns:
            A 224x224 RGB PIL image with the heatmap blended in, or the
            original ``image`` unchanged if anything fails (best-effort).
        """
        try:
            pixel_values = self.preprocess_image(image)
            with torch.no_grad():
                outputs = self.model(pixel_values, output_attentions=True)
            # Guard against a missing/empty attentions tuple.
            if not outputs.attentions:
                raise ValueError("Model did not return attention weights")
            attention_map = self._cls_attention_map(outputs.attentions[-1])
            attention_map = cv2.resize(attention_map, (224, 224))
            heatmap = np.uint8(255 * np.clip(attention_map, 0.0, 1.0))
            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
            # applyColorMap produces BGR; arrays from PIL are RGB.
            heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
            # convert() returns a new image, so the caller's image is untouched.
            # Forcing RGB keeps the shape (224, 224, 3) that addWeighted needs.
            original_array = np.array(image.convert("RGB").resize((224, 224)))
            output = cv2.addWeighted(original_array, 0.7, heatmap, 0.3, 0)
            return Image.fromarray(output)
        except Exception as e:
            print(f"Grad-CAM error: {str(e)}")
            return image
# Ecological severity/impact rating per species label. Built once at import
# time instead of on every call.
_SEVERITY_MAP = {
    "Seven-spotted Ladybug": "Low",
    "Monarch Butterfly": "Low",
    "Carpenter Ant": "Medium",
    "Japanese Beetle": "High",
    "Garden Spider": "Low",
    "Green Grasshopper": "Medium",
    "Luna Moth": "Low",
    "Common Dragonfly": "Low",
    "Honey Bee": "Low",
    "Paper Wasp": "Medium",
    "Unknown Insect": "Unknown",
    "Error Processing Image": "Unknown",
}


def get_severity_prediction(species):
    """Predict ecological severity/impact based on species.

    Args:
        species: a species label, e.g. one returned by
            ``BugClassifier.predict``, including its "Unknown Insect" and
            "Error Processing Image" sentinels.

    Returns:
        "Low", "Medium", "High" or "Unknown". Labels not in the table get
        "Medium".
    """
    return _SEVERITY_MAP.get(species, "Medium")