Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| from transformers import ViTForImageClassification, AutoFeatureExtractor | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| from scipy.special import softmax | |
class BugClassifier:
    """Classify insect photos into a fixed set of ten species.

    Wraps a Hugging Face image-classification backbone with a small custom
    head and adds helpers for image-quality scoring, an attention-overlay
    visualization and short species descriptions.

    NOTE(review): the custom head built in ``__init__`` is freshly
    initialized and no fine-tuned weights are loaded anywhere in this class,
    so ``predict`` results are not meaningful until trained head weights are
    loaded - confirm where that is expected to happen.
    """

    # Hugging Face checkpoint used for both the model and its preprocessor.
    MODEL_NAME = "microsoft/beit-base-patch16-224-pt22k-ft22k"
    # Side length in pixels of the square model input.
    INPUT_SIZE = 224
    # Top-1 probability below which the result is reported as unknown.
    CONFIDENCE_THRESHOLD = 0.4
    # Top-1 must beat top-2 by more than this to skip the image-quality check.
    MARGIN_THRESHOLD = 0.2
    # Quality score below which a close call is reported as "Image Unclear".
    MIN_IMAGE_QUALITY = 0.5

    def __init__(self):
        """Load the backbone, attach the custom head and build lookup tables.

        Raises:
            RuntimeError: if the model or feature extractor cannot be loaded.
        """
        try:
            # The checkpoint is a BEiT model. Loading it through
            # ViTForImageClassification maps none of its weights (parameter
            # names differ), leaving the backbone random; the Auto class picks
            # the architecture declared in the checkpoint's config.
            from transformers import AutoModelForImageClassification

            # Detailed class labels; list index == classifier output index.
            self.labels = [
                "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant",
                "Japanese Beetle", "Garden Spider", "Green Grasshopper",
                "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp",
            ]
            num_classes = len(self.labels)

            self.model = AutoModelForImageClassification.from_pretrained(
                self.MODEL_NAME,
                num_labels=num_classes,
                # The checkpoint's own head has a different class count.
                ignore_mismatched_sizes=True,
            )
            # Custom classification head sized from the backbone config
            # instead of a hard-coded 768.
            hidden_size = self.model.config.hidden_size
            self.model.classifier = torch.nn.Sequential(
                torch.nn.Linear(hidden_size, 512),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.2),
                torch.nn.Linear(512, num_classes),
            )
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(
                self.MODEL_NAME
            )
            # Evaluation mode disables dropout for inference.
            self.model.eval()

            # Mapping of general categories for better classification
            # (not read by any method in this class).
            self.category_mapping = {
                "Seven-spotted Ladybug": ["ladybug", "ladybird", "coccinellidae"],
                "Monarch Butterfly": ["butterfly", "lepidoptera"],
                "Carpenter Ant": ["ant", "formicidae"],
                "Japanese Beetle": ["beetle", "coleoptera"],
                "Garden Spider": ["spider", "arachnid"],
                "Green Grasshopper": ["grasshopper", "orthoptera"],
                "Luna Moth": ["moth", "lepidoptera"],
                "Common Dragonfly": ["dragonfly", "odonata"],
                "Honey Bee": ["bee", "apidae"],
                "Paper Wasp": ["wasp", "vespidae"],
            }

            # Detailed species descriptions. Built by explicit concatenation
            # so source indentation never leaks into the text.
            self.species_info = {
                "Seven-spotted Ladybug": (
                    "\n"
                    "The Seven-spotted Ladybug (Coccinella septempunctata) is one of "
                    "the most common ladybug species.\n"
                    "These beneficial insects are natural predators of garden pests "
                    "like aphids and scale insects.\n"
                    "Each ladybug can eat up to 5,000 aphids during its lifetime, "
                    "making them excellent natural pest controllers.\n"
                    "Their distinct red coloring with seven black spots serves as a "
                    "warning to predators.\n"
                ),
                "Monarch Butterfly": (
                    "\n"
                    "The Monarch Butterfly (Danaus plexippus) is known for its "
                    "spectacular annual migration.\n"
                    "These butterflies play a crucial role in pollination and are "
                    "indicators of ecosystem health.\n"
                    "They have a unique relationship with milkweed plants, which "
                    "their caterpillars exclusively feed on.\n"
                    "Their orange and black wings serve as warning colors to "
                    "predators about their toxicity.\n"
                ),
                "Carpenter Ant": (
                    "\n"
                    "Carpenter Ants (Camponotus spp.) are large ants that build "
                    "nests in wood.\n"
                    "While they don't eat wood like termites, they can cause "
                    "structural damage to buildings.\n"
                    "These social insects live in colonies and play important roles "
                    "in forest ecosystems,\n"
                    "helping to break down dead wood and maintain soil health.\n"
                ),
                "Japanese Beetle": (
                    "\n"
                    "The Japanese Beetle (Popillia japonica) is recognized by its "
                    "metallic green body.\n"
                    "While beautiful, these beetles can be significant garden pests, "
                    "feeding on many plant species.\n"
                    "They are most active in summer months and can be managed "
                    "through various natural control methods.\n"
                    "Their presence often indicates a healthy soil ecosystem, though "
                    "their feeding can damage plants.\n"
                ),
                # TODO: add descriptions for the remaining species; until then
                # get_species_info falls back to a generic paragraph.
            }
        except Exception as e:
            raise RuntimeError(f"Error initializing BugClassifier: {str(e)}") from e

    def predict(self, image):
        """Classify ``image`` and return ``(label, confidence_percent)``.

        Args:
            image: A PIL image.

        Returns:
            ``(label, percent)``. The label is one of ``self.labels``, or one
            of these sentinels:

            - ``"Unknown Insect"`` when the top probability is below
              ``CONFIDENCE_THRESHOLD``.
            - ``"Image Unclear"`` (with 0.0) when the top two classes are close
              and the image scores poorly on quality.
            - ``"Error Processing Image"`` (with 0.0) on any failure; the error
              is printed.
        """
        try:
            if not isinstance(image, Image.Image):
                raise ValueError("Input must be a PIL Image")
            image_tensor = self.preprocess_image(image)
            with torch.no_grad():
                outputs = self.model(image_tensor)
            # .cpu() keeps this working if the model is moved to a GPU.
            probs = F.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
            return self._select_prediction(probs, image)
        except Exception as e:
            print(f"Prediction error: {str(e)}")
            return "Error Processing Image", 0.0

    def _select_prediction(self, probs, image):
        """Apply the confidence, margin and image-quality rules to ``probs``.

        Args:
            probs: 1-D array of class probabilities aligned with ``self.labels``.
            image: The source image; only scored when the top two are close.

        Returns:
            ``(label, confidence_percent)`` as described in ``predict``.
        """
        # Indices of the (up to) three most probable classes, best first.
        top_idx = np.argsort(probs)[-3:][::-1]
        top_probs = probs[top_idx]

        if top_probs[0] < self.CONFIDENCE_THRESHOLD:
            return "Unknown Insect", float(top_probs[0] * 100)

        runner_up = top_probs[1] if top_probs.size > 1 else 0.0
        clear_winner = (top_probs[0] - runner_up) > self.MARGIN_THRESHOLD
        # A close call is only reported when the image itself looks usable.
        if not clear_winner and self.assess_image_quality(image) < self.MIN_IMAGE_QUALITY:
            return "Image Unclear", 0.0

        best = top_idx[0]
        return self.labels[best], float(probs[best] * 100)

    def preprocess_image(self, image):
        """Convert a PIL image into the model's ``pixel_values`` tensor.

        Raises:
            ValueError: if conversion or feature extraction fails.
        """
        try:
            # Convert every non-RGB mode (RGBA, L, P, CMYK, ...), not just RGBA,
            # so the feature extractor always receives 3-channel RGB pixels.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            target_size = (self.INPUT_SIZE, self.INPUT_SIZE)
            if image.size != target_size:
                image = image.resize(target_size, Image.Resampling.LANCZOS)
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            return inputs.pixel_values
        except Exception as e:
            raise ValueError(f"Error preprocessing image: {str(e)}") from e

    def assess_image_quality(self, image):
        """Return a heuristic quality score (in [0, 1] for 8-bit images).

        The score is the mean of three sub-scores:

        - closeness of mean brightness to mid-grey (128);
        - contrast: pixel std / 50, capped at 1;
        - sharpness: Laplacian variance / 1000, capped at 1.

        Returns 0.5 if the assessment fails.
        """
        try:
            # Normalise PIL modes such as RGBA/P/CMYK to RGB so the array has
            # the 3-channel layout RGB2GRAY needs; 'L' stays single-channel.
            if isinstance(image, Image.Image) and image.mode not in ('RGB', 'L'):
                image = image.convert('RGB')
            img_array = np.array(image)
            brightness = np.mean(img_array)
            contrast = np.std(img_array)
            if len(img_array.shape) == 3:
                gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            else:
                gray = img_array
            # Low Laplacian variance indicates few edges, i.e. a blurry image.
            blur_score = cv2.Laplacian(gray, cv2.CV_64F).var()

            brightness_score = 1 - abs(brightness - 128) / 128
            contrast_score = min(contrast / 50, 1)
            blur_score = min(blur_score / 1000, 1)
            return (brightness_score + contrast_score + blur_score) / 3
        except Exception as e:
            print(f"Error assessing image quality: {str(e)}")
            return 0.5  # Neutral value if assessment fails

    def get_species_info(self, species):
        """Return the stored description of ``species`` or a generic fallback."""
        default_info = (
            f"\nInformation about {species}:\n"
            "This species is part of our insect database. While detailed information\n"
            "is still being compiled, all insects play important roles in their ecosystems.\n"
        )
        return self.species_info.get(species, default_info)

    def get_gradcam(self, image):
        """Return an attention-overlay visualization of ``image``.

        Despite the name this is not gradient-based Grad-CAM: it blends a
        heatmap of the last layer's [CLS]-to-patch attention (averaged over
        heads) onto the resized image. Returns the original image unchanged
        if anything fails.
        """
        try:
            rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
            image_tensor = self.preprocess_image(rgb_image)
            with torch.no_grad():
                outputs = self.model(image_tensor, output_attentions=True)
            attention_map = self._attention_map(outputs.attentions[-1], self.INPUT_SIZE)

            heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
            # applyColorMap returns BGR; the overlay is assembled in RGB.
            heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

            original_image = np.array(rgb_image.resize((self.INPUT_SIZE, self.INPUT_SIZE)))
            overlay = cv2.addWeighted(original_image, 0.7, heatmap, 0.3, 0)
            return Image.fromarray(overlay)
        except Exception as e:
            print(f"Error generating Grad-CAM: {str(e)}")
            return image  # Return original image if the visualization fails

    @staticmethod
    def _attention_map(attention, side=224):
        """Turn one layer's attention weights into a ``side`` x ``side`` map in [0, 1].

        Args:
            attention: Tensor shaped (batch, heads, tokens, tokens), the layout
                Hugging Face uses for ``outputs.attentions``. Token 0 is taken
                as the [CLS] token and the remaining tokens as a square patch
                grid in row-major order.
            side: Output height and width in pixels.

        Raises:
            ValueError: if the patch tokens do not form a non-empty square grid.
        """
        # Attention paid by [CLS] to each patch, averaged over heads (first image).
        cls_attention = attention[0, :, 0, 1:].mean(dim=0)
        num_patches = cls_attention.numel()
        grid = int(round(num_patches ** 0.5))
        if num_patches == 0 or grid * grid != num_patches:
            raise ValueError(f"Cannot arrange {num_patches} patch tokens into a square grid")

        patch_map = cls_attention.reshape(1, 1, grid, grid).float()
        resized = F.interpolate(patch_map, size=(side, side), mode="bilinear", align_corners=False)
        attention_map = resized[0, 0].detach().cpu().numpy()

        low, high = attention_map.min(), attention_map.max()
        if high > low:
            return (attention_map - low) / (high - low)
        # Uniform attention: avoid 0/0 (NaN) and return an empty heatmap.
        return np.zeros_like(attention_map)

    def compare_species(self, species1, species2):
        """Return a markdown-style comparison of two species' descriptions."""
        info1 = self.get_species_info(species1)
        info2 = self.get_species_info(species2)
        return (
            f"\n**Comparing {species1} and {species2}:**\n"
            f"{species1}:\n"
            f"{info1}\n"
            f"{species2}:\n"
            f"{info2}\n"
            "Both species contribute to their ecosystems in unique ways.\n"
        )
def get_severity_prediction(species):
    """Rate the ecological impact of a species label.

    Args:
        species: A species label, e.g. one returned by
            ``BugClassifier.predict`` (including its "Unknown Insect" /
            "Image Unclear" sentinels).

    Returns:
        "Low", "Medium" or "High" for the listed species; "Unknown" for the
        sentinels and for any label not listed.
    """
    # Species grouped by impact level, then inverted into a flat lookup.
    species_by_level = {
        "Low": (
            "Seven-spotted Ladybug",
            "Monarch Butterfly",
            "Garden Spider",
            "Luna Moth",
            "Common Dragonfly",
            "Honey Bee",
        ),
        "Medium": ("Carpenter Ant", "Green Grasshopper", "Paper Wasp"),
        "High": ("Japanese Beetle",),
        "Unknown": ("Unknown Insect", "Image Unclear"),
    }
    level_of = {
        name: level
        for level, names in species_by_level.items()
        for name in names
    }
    return level_of.get(species, "Unknown")