File size: 6,601 Bytes
580daa1
 
 
 
 
 
 
 
 
f88bb49
9ae2699
f88bb49
9ae2699
f88bb49
 
 
9ae2699
f88bb49
 
 
 
9ae2699
f88bb49
3ed4e4e
 
 
f88bb49
 
9ae2699
f88bb49
3ed4e4e
9ae2699
 
3ed4e4e
 
 
9ae2699
 
 
f88bb49
9ae2699
f88bb49
 
9ae2699
f88bb49
 
9ae2699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ed4e4e
9ae2699
3ed4e4e
 
 
 
 
 
 
 
 
 
 
 
9ae2699
 
 
3ed4e4e
9ae2699
 
 
3ed4e4e
9ae2699
3ed4e4e
 
 
 
 
580daa1
f88bb49
 
1902761
 
 
f88bb49
 
580daa1
9ae2699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f88bb49
9ae2699
f88bb49
9ae2699
f88bb49
 
 
 
9ae2699
 
f88bb49
9ae2699
 
f88bb49
 
9ae2699
f88bb49
 
9ae2699
 
 
f88bb49
9ae2699
 
 
 
f88bb49
 
9ae2699
f88bb49
9ae2699
f88bb49
 
9ae2699
 
3ed4e4e
f88bb49
 
 
3ed4e4e
 
 
 
 
 
 
 
 
 
 
9ae2699
f88bb49
9ae2699
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import torch
import torch.nn.functional as F
from transformers import ViTForImageClassification, AutoFeatureExtractor
import numpy as np
from PIL import Image
import cv2

class BugClassifier:
    """Insect classifier built on a Vision Transformer (ViT) backbone.

    Wraps a Hugging Face ``ViTForImageClassification`` model and its feature
    extractor. It also provides species descriptions, side-by-side species
    comparisons and an attention-overlay visualization.

    NOTE(review): ``google/vit-base-patch16-224`` ships a 1000-class ImageNet
    head. Loading it with ``num_labels=10`` and ``ignore_mismatched_sizes=True``
    drops that head and initializes a fresh 10-way classifier. ``predict`` is
    therefore not meaningful until weights fine-tuned on ``self.labels`` are
    loaded.
    """

    def __init__(self):
        """Load the model and feature extractor and build the label/info tables.

        Raises:
            RuntimeError: If the model or the feature extractor cannot be loaded.
        """
        try:
            # Pretrained ViT backbone whose classification head is resized to the
            # 10 labels below; the mismatched head weights are not loaded.
            self.model = ViTForImageClassification.from_pretrained(
                "google/vit-base-patch16-224",
                num_labels=10,
                ignore_mismatched_sizes=True
            )
            self.feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

            # Inference only (e.g. dropout disabled).
            self.model.eval()

            # Class labels; list index i corresponds to output logit i.
            self.labels = [
                "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant", 
                "Japanese Beetle", "Garden Spider", "Green Grasshopper",
                "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp"
            ]

            # Species information database (species not listed here fall back
            # to the generic text in get_species_info).
            self.species_info = {
                "Seven-spotted Ladybug": """
                The Seven-spotted Ladybug is one of the most common ladybug species.
                These beneficial insects are natural predators of garden pests like aphids.
                Their distinct red coloring with seven black spots serves as a warning to predators.
                """,
                "Monarch Butterfly": """
                The Monarch Butterfly is known for its spectacular annual migration.
                They play a crucial role in pollination and are indicators of ecosystem health.
                Their orange and black wings serve as warning colors to predators.
                """,
                # Add other species info as needed...
            }
        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            raise RuntimeError(f"Error initializing BugClassifier: {str(e)}") from e

    def preprocess_image(self, image):
        """Convert a PIL image into the model's pixel-value tensor.

        Args:
            image: A ``PIL.Image.Image`` in any mode.

        Returns:
            The ``pixel_values`` tensor produced by the feature extractor
            (batched; ``predict`` reads element ``[0]``).

        Raises:
            ValueError: If mode conversion or feature extraction fails.
        """
        try:
            # The extractor expects 3-channel input. Normalize every non-RGB
            # mode (RGBA, L, P, LA, CMYK, ...), not just RGBA.
            if image.mode != "RGB":
                image = image.convert("RGB")

            # The feature extractor handles resizing and normalization.
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            return inputs.pixel_values

        except Exception as e:
            print(f"Preprocessing error: {str(e)}")
            raise ValueError(f"Error preprocessing image: {str(e)}") from e

    def _label_from_probs(self, probs, confidence_threshold):
        """Map a 1-D probability vector to ``(label, confidence_percent)``.

        Returns ``"Unknown Insect"`` when the top probability, in percent, is
        below ``confidence_threshold``.
        """
        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx] * 100)
        if confidence < confidence_threshold:
            return "Unknown Insect", confidence
        return self.labels[pred_idx], confidence

    def predict(self, image, *, confidence_threshold=40.0):
        """Classify an insect image.

        Args:
            image: A ``PIL.Image.Image``.
            confidence_threshold: Minimum top-1 probability, in percent, needed
                to report a species label (default 40).

        Returns:
            A ``(label, confidence_percent)`` tuple. ``label`` is
            ``"Unknown Insect"`` below the threshold. On any failure the result
            is ``("Error Processing Image", 0.0)``; this method never raises.
        """
        try:
            if not isinstance(image, Image.Image):
                raise ValueError("Input must be a PIL Image")

            image_tensor = self.preprocess_image(image)

            with torch.no_grad():
                outputs = self.model(image_tensor)

            # Softmax over the class axis and keep the single batch element.
            probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
            return self._label_from_probs(probs, confidence_threshold)

        except Exception as e:
            print(f"Prediction error: {str(e)}")
            return "Error Processing Image", 0.0

    def get_species_info(self, species):
        """Return the description for ``species``, or generic text if unknown."""
        default_info = f"""
        Information about {species}:
        This species is part of our insect database. While detailed information
        is still being compiled, all insects play important roles in their ecosystems.
        """
        return self.species_info.get(species, default_info)

    def compare_species(self, species1, species2):
        """Return a formatted text block that juxtaposes two species' descriptions."""
        info1 = self.get_species_info(species1)
        info2 = self.get_species_info(species2)
        
        return f"""
        **Comparing {species1} and {species2}:**
        
        {species1}:
        {info1}
        
        {species2}:
        {info2}
        
        Both species contribute to their ecosystems in unique ways.
        """

    @staticmethod
    def _cls_attention_map(layer_attention):
        """Return the CLS token's attention over image patches as a 2-D grid.

        Args:
            layer_attention: One layer's attention tensor, indexed here as
                ``[batch, head, query, key]``. The CLS token is at position 0,
                followed by a square grid of patch tokens. This assumes no extra
                tokens (such as a distillation token).

        Returns:
            A float32 ``numpy.ndarray`` of shape ``(grid, grid)`` scaled to
            [0, 1]. A constant map yields all zeros instead of NaN.

        Raises:
            ValueError: If the number of patch tokens is not a perfect square.
        """
        # Attention from CLS (query 0) to every patch token (keys 1..), for the
        # first batch element, averaged over heads.
        cls_to_patches = layer_attention[0, :, 0, 1:].mean(dim=0)
        num_patches = cls_to_patches.shape[-1]
        grid = int(round(num_patches ** 0.5))
        if grid * grid != num_patches:
            raise ValueError(f"Cannot arrange {num_patches} patch tokens into a square grid")

        attention_map = (
            cls_to_patches.reshape(grid, grid).detach().cpu().numpy().astype(np.float32)
        )

        lo, hi = attention_map.min(), attention_map.max()
        if hi > lo:
            return (attention_map - lo) / (hi - lo)
        # Uniform attention: nothing to highlight; avoid dividing by zero.
        return np.zeros_like(attention_map)

    def get_gradcam(self, image):
        """Overlay last-layer CLS attention on the image as a heatmap.

        Despite the name, this is not gradient-based Grad-CAM. It runs a forward
        pass under ``torch.no_grad`` and visualizes the model's attention weights.

        Args:
            image: A ``PIL.Image.Image`` in any mode.

        Returns:
            A 224x224 RGB ``PIL.Image.Image`` (70% image, 30% heatmap). If
            anything fails, the input ``image`` is returned unchanged.
        """
        try:
            # Work in RGB so the overlay has the same 3 channels as the heatmap.
            rgb_image = image if image.mode == "RGB" else image.convert("RGB")
            image_tensor = self.preprocess_image(rgb_image)

            with torch.no_grad():
                outputs = self.model(image_tensor, output_attentions=True)

            # Map the patch-grid attention back onto pixel space.
            attention_map = self._cls_attention_map(outputs.attentions[-1])
            attention_map = cv2.resize(attention_map, (224, 224))
            attention_map = np.clip(attention_map, 0.0, 1.0)

            heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
            # applyColorMap produces BGR; PIL arrays are RGB.
            heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

            original_array = np.array(rgb_image.resize((224, 224)))

            output = cv2.addWeighted(original_array, 0.7, heatmap, 0.3, 0)
            return Image.fromarray(output)

        except Exception as e:
            print(f"Grad-CAM error: {str(e)}")
            return image

# Ecological impact rating per label; any label not listed is rated "Medium".
_SEVERITY_BY_SPECIES = {
    "Seven-spotted Ladybug": "Low",
    "Monarch Butterfly": "Low",
    "Carpenter Ant": "Medium",
    "Japanese Beetle": "High",
    "Garden Spider": "Low",
    "Green Grasshopper": "Medium",
    "Luna Moth": "Low",
    "Common Dragonfly": "Low",
    "Honey Bee": "Low",
    "Paper Wasp": "Medium",
    "Unknown Insect": "Unknown",
    "Error Processing Image": "Unknown",
}


def get_severity_prediction(species):
    """Look up the ecological severity/impact rating for *species*.

    Args:
        species: A species label. The table also covers the "Unknown Insect"
            and "Error Processing Image" sentinel labels.

    Returns:
        One of "Low", "Medium", "High" or "Unknown". Labels missing from the
        table default to "Medium".
    """
    try:
        return _SEVERITY_BY_SPECIES[species]
    except KeyError:
        return "Medium"