import torch
import torch.nn.functional as F
from transformers import ViTForImageClassification, AutoFeatureExtractor
import numpy as np
from PIL import Image
import cv2


class BugClassifier:
    """Classify insect photos with a ViT model and describe the result.

    The instance loads the ``google/vit-base-patch16-224`` checkpoint with a
    10-way classification head, keeps the matching feature extractor, and
    exposes helpers for prediction, species descriptions and an attention
    overlay.

    NOTE(review): the checkpoint is loaded with ``num_labels=10`` and
    ``ignore_mismatched_sizes=True``; the 10-class head is presumably freshly
    initialised rather than trained on ``labels`` -- confirm that fine-tuned
    weights are loaded elsewhere before trusting the predictions.
    """

    # Predictions below this confidence (in percent) are reported as unknown.
    CONFIDENCE_THRESHOLD = 40.0

    # Side length, in pixels, of the overlay image built by ``get_gradcam``.
    OVERLAY_SIZE = 224

    def __init__(self):
        """Load the model and feature extractor and set up the label tables.

        Raises:
            RuntimeError: if the model or feature extractor cannot be loaded;
                the original exception is attached as ``__cause__``.
        """
        try:
            # Use standard ViT model without modifications
            self.model = ViTForImageClassification.from_pretrained(
                "google/vit-base-patch16-224",
                num_labels=10,
                ignore_mismatched_sizes=True,
            )
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(
                "google/vit-base-patch16-224"
            )

            # Set model to evaluation mode
            self.model.eval()

            # Class labels, indexed by the model's output position.
            self.labels = [
                "Seven-spotted Ladybug",
                "Monarch Butterfly",
                "Carpenter Ant",
                "Japanese Beetle",
                "Garden Spider",
                "Green Grasshopper",
                "Luna Moth",
                "Common Dragonfly",
                "Honey Bee",
                "Paper Wasp",
            ]

            # Species information database
            self.species_info = {
                "Seven-spotted Ladybug": (
                    " The Seven-spotted Ladybug is one of the most common"
                    " ladybug species. These beneficial insects are natural"
                    " predators of garden pests like aphids. Their distinct"
                    " red coloring with seven black spots serves as a warning"
                    " to predators. "
                ),
                "Monarch Butterfly": (
                    " The Monarch Butterfly is known for its spectacular"
                    " annual migration. They play a crucial role in"
                    " pollination and are indicators of ecosystem health."
                    " Their orange and black wings serve as warning colors to"
                    " predators. "
                ),
                # Add other species info as needed...
            }
        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            raise RuntimeError(
                f"Error initializing BugClassifier: {str(e)}"
            ) from e

    def preprocess_image(self, image):
        """Preprocess image for model input.

        Args:
            image: a PIL image in any mode.

        Returns:
            The ``pixel_values`` tensor produced by the feature extractor.

        Raises:
            ValueError: if conversion or feature extraction fails; the
                original exception is attached as ``__cause__``.
        """
        try:
            # The model expects three colour channels, so normalise every
            # non-RGB mode (RGBA, L, P, CMYK, ...) rather than RGBA only.
            if image.mode != "RGB":
                image = image.convert("RGB")

            # Use feature extractor to handle resizing and normalization
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            return inputs.pixel_values
        except Exception as e:
            print(f"Preprocessing error: {str(e)}")
            raise ValueError(f"Error preprocessing image: {str(e)}") from e

    def predict(self, image):
        """Make a prediction on the input image.

        Args:
            image: a PIL image.

        Returns:
            A ``(label, confidence)`` tuple with the confidence in percent.
            The label is ``"Unknown Insect"`` when the confidence is below
            ``CONFIDENCE_THRESHOLD`` and ``"Error Processing Image"`` (with a
            confidence of ``0.0``) when anything fails, including a non-PIL
            input. This method never raises.
        """
        try:
            if not isinstance(image, Image.Image):
                raise ValueError("Input must be a PIL Image")

            image_tensor = self.preprocess_image(image)

            with torch.no_grad():
                outputs = self.model(image_tensor)
                probs = F.softmax(outputs.logits, dim=-1).numpy()[0]

            # Plain int so the index works for any sequence type.
            pred_idx = int(np.argmax(probs))
            confidence = float(probs[pred_idx] * 100)

            if confidence < self.CONFIDENCE_THRESHOLD:
                return "Unknown Insect", confidence

            return self.labels[pred_idx], confidence
        except Exception as e:
            # Best-effort API: callers get a sentinel label instead of an
            # exception.
            print(f"Prediction error: {str(e)}")
            return "Error Processing Image", 0.0

    def get_species_info(self, species):
        """Return information about a species.

        Falls back to a generic paragraph naming ``species`` when the species
        has no entry in ``species_info``.
        """
        default_info = (
            f" Information about {species}: This species is part of our"
            " insect database. While detailed information is still being"
            " compiled, all insects play important roles in their"
            " ecosystems. "
        )
        return self.species_info.get(species, default_info)

    def compare_species(self, species1, species2):
        """Generate comparison information between two species.

        Returns a Markdown-flavoured string that places the description of
        each species (from ``get_species_info``) side by side.
        """
        info1 = self.get_species_info(species1)
        info2 = self.get_species_info(species2)

        return (
            f" **Comparing {species1} and {species2}:**"
            f" {species1}: {info1} {species2}: {info2}"
            " Both species contribute to their ecosystems in unique ways. \n"
        )

    def get_gradcam(self, image):
        """Generate a simple attention visualization.

        Despite the name this is not Grad-CAM: it overlays the last layer's
        attention, averaged over heads and over query tokens, on the image.

        Args:
            image: a PIL image in any mode.

        Returns:
            An RGB PIL image of ``OVERLAY_SIZE`` x ``OVERLAY_SIZE`` pixels
            with the heatmap blended in, or the unmodified input ``image``
            if anything fails. This method never raises.
        """
        try:
            size = (self.OVERLAY_SIZE, self.OVERLAY_SIZE)
            image_tensor = self.preprocess_image(image)

            with torch.no_grad():
                outputs = self.model(image_tensor, output_attentions=True)

            if not outputs.attentions:
                raise ValueError("Model returned no attention weights")

            # Last layer, averaged over heads and then over query tokens:
            # one score per key token for the first (only) batch item.
            token_scores = outputs.attentions[-1].mean(dim=1).mean(dim=1)[0]

            # ViT places the [CLS] token first; the remaining tokens are the
            # image patches in row-major order, so they must be reshaped to
            # their 2-D grid before being upsampled as an image.
            patch_scores = token_scores[1:]
            num_patches = patch_scores.numel()
            grid = int(round(num_patches ** 0.5))
            if grid == 0 or grid * grid != num_patches:
                raise ValueError(
                    f"Cannot arrange {num_patches} patch tokens in a square"
                )

            attention_map = np.ascontiguousarray(
                patch_scores.reshape(grid, grid).numpy(), dtype=np.float32
            )
            attention_map = cv2.resize(attention_map, size)

            # Min-max normalise to [0, 1]; a flat map has no contrast and
            # would otherwise divide by zero.
            low = float(attention_map.min())
            span = float(attention_map.max()) - low
            if span > 0:
                attention_map = (attention_map - low) / span
            else:
                attention_map = np.zeros_like(attention_map)

            heatmap = np.uint8(255 * attention_map)
            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
            # OpenCV colour maps are BGR; PIL arrays are RGB.
            heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

            # Force three channels so the blend shapes always match.
            original_array = np.array(image.convert("RGB").resize(size))

            output = cv2.addWeighted(original_array, 0.7, heatmap, 0.3, 0)
            return Image.fromarray(output)
        except Exception as e:
            # Best-effort visualisation: fall back to the plain image.
            print(f"Grad-CAM error: {str(e)}")
            return image


# Ecological severity per label, including the two sentinel labels that
# ``BugClassifier.predict`` can return.
_SEVERITY_BY_SPECIES = {
    "Seven-spotted Ladybug": "Low",
    "Monarch Butterfly": "Low",
    "Carpenter Ant": "Medium",
    "Japanese Beetle": "High",
    "Garden Spider": "Low",
    "Green Grasshopper": "Medium",
    "Luna Moth": "Low",
    "Common Dragonfly": "Low",
    "Honey Bee": "Low",
    "Paper Wasp": "Medium",
    "Unknown Insect": "Unknown",
    "Error Processing Image": "Unknown",
}


def get_severity_prediction(species):
    """Predict ecological severity/impact based on species.

    Returns ``"Low"``, ``"Medium"``, ``"High"`` or ``"Unknown"`` for the
    labels in the table, and ``"Medium"`` for any species not listed.
    """
    return _SEVERITY_BY_SPECIES.get(species, "Medium")