Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| from transformers import ViTForImageClassification, AutoFeatureExtractor | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
class BugClassifier:
    """Classify insect photos with a ViT model and produce related text/visuals.

    NOTE(review): the model is loaded with ``num_labels=10`` and
    ``ignore_mismatched_sizes=True`` from the public
    ``google/vit-base-patch16-224`` checkpoint, whose head has 1000 ImageNet
    classes. The 10-way head therefore starts from freshly initialized
    weights, and predictions are not meaningful until fine-tuned weights
    are loaded.
    """

    def __init__(self):
        """Load the ViT model and feature extractor, and set up label/info tables.

        Raises:
            RuntimeError: if loading the model or the feature extractor fails
                (the original exception is chained as ``__cause__``).
        """
        try:
            # Use standard ViT model without modifications.
            self.model = ViTForImageClassification.from_pretrained(
                "google/vit-base-patch16-224",
                num_labels=10,
                ignore_mismatched_sizes=True,
            )
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(
                "google/vit-base-patch16-224"
            )
            # Inference only: eval() disables dropout.
            self.model.eval()

            # Position i names logit i of the 10-way classification head.
            self.labels = [
                "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant",
                "Japanese Beetle", "Garden Spider", "Green Grasshopper",
                "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp",
            ]

            # Species descriptions. Species missing here fall back to a generic
            # text in get_species_info. The strings carry no source-code
            # indentation: indented lines would render as code blocks in
            # Markdown.
            self.species_info = {
                "Seven-spotted Ladybug": (
                    "The Seven-spotted Ladybug is one of the most common ladybug species.\n"
                    "These beneficial insects are natural predators of garden pests like aphids.\n"
                    "Their distinct red coloring with seven black spots serves as a warning to predators."
                ),
                "Monarch Butterfly": (
                    "The Monarch Butterfly is known for its spectacular annual migration.\n"
                    "They play a crucial role in pollination and are indicators of ecosystem health.\n"
                    "Their orange and black wings serve as warning colors to predators."
                ),
                # Add other species info as needed...
            }
        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            raise RuntimeError(f"Error initializing BugClassifier: {str(e)}") from e

    def preprocess_image(self, image):
        """Convert a PIL image into the model's pixel-value tensor.

        Args:
            image: PIL image in any mode. Every non-RGB mode (RGBA, L, P,
                CMYK, ...) is converted to RGB, because the model consumes
                3-channel input.

        Returns:
            The ``pixel_values`` tensor produced by the feature extractor,
            which also handles resizing and normalization.

        Raises:
            ValueError: if conversion or feature extraction fails.
        """
        try:
            if image.mode != "RGB":
                image = image.convert("RGB")
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            return inputs.pixel_values
        except Exception as e:
            print(f"Preprocessing error: {str(e)}")
            raise ValueError(f"Error preprocessing image: {str(e)}") from e

    def predict(self, image, confidence_threshold=40.0):
        """Classify ``image`` and return ``(label, confidence_percent)``.

        Args:
            image: PIL image to classify.
            confidence_threshold: minimum top-class probability, in percent,
                needed to report a species. Below it, "Unknown Insect" is
                returned. Defaults to 40, the previously hard-coded value.

        Returns:
            ``(label, confidence)`` with confidence in the range 0-100. On any
            failure, including a non-PIL input, the method returns
            ``("Error Processing Image", 0.0)`` instead of raising.
        """
        try:
            if not isinstance(image, Image.Image):
                raise ValueError("Input must be a PIL Image")
            image_tensor = self.preprocess_image(image)
            with torch.no_grad():
                outputs = self.model(image_tensor)
            # Row 0: preprocess_image builds a batch from a single image.
            probs = F.softmax(outputs.logits, dim=-1).numpy()[0]
            pred_idx = int(np.argmax(probs))
            confidence = float(probs[pred_idx] * 100)
            if confidence < confidence_threshold:
                return "Unknown Insect", confidence
            return self.labels[pred_idx], confidence
        except Exception as e:
            print(f"Prediction error: {str(e)}")
            return "Error Processing Image", 0.0

    def get_species_info(self, species):
        """Return the stored description of ``species``, or a generic fallback text."""
        default_info = (
            f"Information about {species}:\n"
            "This species is part of our insect database. While detailed information\n"
            "is still being compiled, all insects play important roles in their ecosystems."
        )
        return self.species_info.get(species, default_info)

    def compare_species(self, species1, species2):
        """Return a Markdown-formatted comparison of two species' descriptions."""
        info1 = self.get_species_info(species1)
        info2 = self.get_species_info(species2)
        # Blank lines separate the sections into Markdown paragraphs.
        return (
            f"**Comparing {species1} and {species2}:**\n\n"
            f"{species1}:\n{info1}\n\n"
            f"{species2}:\n{info2}\n\n"
            "Both species contribute to their ecosystems in unique ways."
        )

    @staticmethod
    def _attention_grid(attention):
        """Collapse one layer's attention into a normalized 2-D patch map.

        Args:
            attention: tensor of shape (batch, heads, tokens, tokens). Token 0
                is the ViT [CLS] token, and the remaining tokens are image
                patches in row-major order on a square grid. Only batch item 0
                is used.

        Returns:
            float32 array of shape (side, side) scaled to [0, 1]. It is all
            zeros when every patch receives the same attention, so the result
            never contains NaN.

        Raises:
            ValueError: if the patch tokens cannot form a square grid.
        """
        # Average over heads, then over queries: the attention each token receives.
        token_scores = attention.mean(dim=1).mean(dim=1)[0]
        # Drop [CLS] so that only spatial patch tokens remain.
        patch_scores = token_scores[1:]
        num_patches = patch_scores.shape[0]
        side = int(round(num_patches ** 0.5))
        if num_patches == 0 or side * side != num_patches:
            raise ValueError(f"Cannot arrange {num_patches} patch tokens on a square grid")
        grid = patch_scores.reshape(side, side).detach().cpu().numpy().astype(np.float32)
        low, high = grid.min(), grid.max()
        if high > low:
            return (grid - low) / (high - low)
        return np.zeros_like(grid)

    def get_gradcam(self, image):
        """Overlay a last-layer self-attention heatmap on a 224x224 copy of ``image``.

        Despite the name, no gradients are used. This is an attention
        visualization, not Grad-CAM.

        Returns:
            A 224x224 RGB PIL image. On any failure, the input ``image`` is
            returned unchanged.
        """
        try:
            image_tensor = self.preprocess_image(image)
            with torch.no_grad():
                # NOTE(review): some transformers versions only return attentions
                # with the "eager" attention implementation. If attentions is
                # None, this raises, and the except below returns ``image``.
                outputs = self.model(image_tensor, output_attentions=True)
            attention_map = self._attention_grid(outputs.attentions[-1])
            # Bilinear upsampling keeps the values within [0, 1].
            attention_map = cv2.resize(attention_map, (224, 224))
            heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
            # applyColorMap produces BGR; the PIL-derived array below is RGB.
            heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
            # Force 3-channel RGB so the shape matches the heatmap (RGBA/L/P inputs).
            original_array = np.array(image.convert("RGB").resize((224, 224)))
            output = cv2.addWeighted(original_array, 0.7, heatmap, 0.3, 0)
            return Image.fromarray(output)
        except Exception as e:
            print(f"Grad-CAM error: {str(e)}")
            return image
def get_severity_prediction(species):
    """Map a predicted species name to its ecological severity/impact level.

    Args:
        species: label returned by the classifier, including the sentinel
            labels "Unknown Insect" and "Error Processing Image".

    Returns:
        "Low", "Medium", "High" or "Unknown". Any species not listed
        below defaults to "Medium".
    """
    species_by_level = {
        "Low": (
            "Seven-spotted Ladybug",
            "Monarch Butterfly",
            "Garden Spider",
            "Luna Moth",
            "Common Dragonfly",
            "Honey Bee",
        ),
        "Medium": ("Carpenter Ant", "Green Grasshopper", "Paper Wasp"),
        "High": ("Japanese Beetle",),
        # Sentinel labels produced when classification is not possible.
        "Unknown": ("Unknown Insect", "Error Processing Image"),
    }
    level_of = {
        name: level
        for level, names in species_by_level.items()
        for name in names
    }
    return level_of.get(species, "Medium")