dcrey7 committed on
Commit
a4e8b55
·
verified ·
1 Parent(s): 9ee35b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -80
app.py CHANGED
@@ -4,7 +4,6 @@ import torch
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  from PIL import Image
7
- import torchvision.transforms as transforms
8
  import requests
9
  import io
10
  import matplotlib.colors as mcolors
@@ -17,27 +16,49 @@ from rasterio.plot import reshape_as_image
17
  import warnings
18
  warnings.filterwarnings("ignore")
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # Set device
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  print(f"Using device: {device}")
23
 
24
- # Define a custom DeepLabV3+ model that matches your trained model architecture
25
- class DeepLabV3Plus(torch.nn.Module):
26
- def __init__(self, num_classes=2):
27
- super(DeepLabV3Plus, self).__init__()
28
- self.encoder = torch.nn.Sequential() # ResNet backbone
29
- self.decoder = torch.nn.Sequential() # Decoder modules
30
- self.segmentation_head = torch.nn.Conv2d(256, num_classes, kernel_size=1)
31
-
32
- def forward(self, x):
33
- # Forward pass (simplified since we're only using this for loading weights)
34
- features = self.encoder(x)
35
- decoder_output = self.decoder(features)
36
- masks = self.segmentation_head(decoder_output)
37
- return masks
38
-
39
  # Initialize the model
40
- model = DeepLabV3Plus(num_classes=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # Download model weights from HuggingFace
43
  MODEL_REPO = "dcrey7/wetlands_segmentation_deeplabsv3plus"
@@ -64,31 +85,25 @@ def download_model_weights():
64
  print(f"Error downloading model weights: {e}")
65
  return None
66
 
67
- # Dummy model for testing if model weights can't be loaded
68
- class DummyModel(torch.nn.Module):
69
- def __init__(self):
70
- super(DummyModel, self).__init__()
71
-
72
- def forward(self, x):
73
- # Simply return a random segmentation mask for visualization
74
- batch_size, _, height, width = x.shape
75
- return torch.randint(0, 2, (batch_size, 2, height, width), device=x.device).float()
76
-
77
  # Load the model weights
78
  weights_path = download_model_weights()
79
  if weights_path:
80
  try:
81
- # Try to load using strict=False to allow for partial weight loading
82
  state_dict = torch.load(weights_path, map_location=device)
83
- model.load_state_dict(state_dict, strict=False)
84
- print("Model weights loaded with non-strict mapping")
 
 
 
 
 
 
 
85
  except Exception as e:
86
  print(f"Error loading model weights: {e}")
87
- print("Using dummy model for demo purposes")
88
- model = DummyModel()
89
  else:
90
- print("No weights available. Using dummy model.")
91
- model = DummyModel()
92
 
93
  model.to(device)
94
  model.eval()
@@ -167,8 +182,22 @@ def preprocess_mask(mask, target_size=(128, 128)):
167
  """
168
  Preprocess a ground truth mask
169
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  # Convert to numpy array if PIL image
171
- if isinstance(mask, Image.Image):
172
  mask = np.array(mask)
173
 
174
  # Convert to grayscale if needed
@@ -193,22 +222,18 @@ def predict_segmentation(image_tensor):
193
  with torch.no_grad():
194
  output = model(image_tensor)
195
 
196
- # Get the predicted class (0: background, 1: wetland)
197
- # Handle different output formats
198
- if isinstance(output, dict) and 'out' in output:
199
  output = output['out']
200
-
201
- if output.shape[1] > 1: # If output has multiple channels (classes)
202
  pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
203
- else:
204
- # If output is single channel, threshold it
205
- pred = (output.squeeze(0).squeeze(0) > 0.5).cpu().numpy().astype(np.uint8)
206
-
207
  return pred
208
  except Exception as e:
209
  print(f"Error during prediction: {e}")
210
- # Return random prediction for demo purposes
211
- return np.random.randint(0, 2, (128, 128), dtype=np.uint8)
212
 
213
  def calculate_metrics(pred_mask, gt_mask):
214
  """
@@ -245,29 +270,6 @@ def calculate_metrics(pred_mask, gt_mask):
245
 
246
  return metrics
247
 
248
- def save_uploaded_file(file_obj):
249
- """Save an uploaded file to a temporary location and return the path"""
250
- try:
251
- # Create a temporary file
252
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.tif')
253
- temp_path = temp_file.name
254
-
255
- # Write the content to the file
256
- if hasattr(file_obj, 'name'):
257
- # If it's a FileUpload object from gradio
258
- with open(file_obj.name, 'rb') as f:
259
- content = f.read()
260
- temp_file.write(content)
261
- else:
262
- # If it's binary content
263
- temp_file.write(file_obj)
264
-
265
- temp_file.close()
266
- return temp_path
267
- except Exception as e:
268
- print(f"Error saving uploaded file: {e}")
269
- return None
270
-
271
  def process_images(input_image=None, input_tiff=None, gt_mask=None):
272
  """
273
  Process input images and generate predictions
@@ -279,10 +281,18 @@ def process_images(input_image=None, input_tiff=None, gt_mask=None):
279
 
280
  # Process the input image
281
  if input_tiff is not None:
282
- # Save uploaded TIFF to a temporary file
283
- temp_tiff_path = save_uploaded_file(input_tiff)
284
- if not temp_tiff_path:
285
- return None, "Failed to process the uploaded TIFF file."
 
 
 
 
 
 
 
 
286
 
287
  # Process TIFF file
288
  image_tensor, display_image = preprocess_tiff(temp_tiff_path)
@@ -308,9 +318,28 @@ def process_images(input_image=None, input_tiff=None, gt_mask=None):
308
  metrics_text = ""
309
 
310
  if gt_mask is not None:
311
- gt_mask_processed = preprocess_mask(gt_mask)
312
- metrics = calculate_metrics(pred_mask, gt_mask_processed)
313
- metrics_text = "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
  # Create visualization
316
  fig = plt.figure(figsize=(12, 6))
@@ -362,10 +391,9 @@ def process_images(input_image=None, input_tiff=None, gt_mask=None):
362
  return result_image, result_text
363
 
364
  except Exception as e:
365
- import traceback
366
- trace = traceback.format_exc()
367
  print(f"Error in processing: {e}")
368
- print(trace)
 
369
  return None, f"Error: {str(e)}"
370
 
371
  # Create Gradio interface
@@ -399,7 +427,7 @@ with gr.Blocks(title="Wetlands Segmentation from Satellite Imagery") as demo:
399
  This application uses a DeepLabv3+ model trained to segment wetland areas in satellite imagery.
400
 
401
  **Model Details:**
402
- - Architecture: DeepLabv3+ with ResNet-50 backbone
403
  - Input: RGB satellite imagery
404
  - Output: Binary segmentation mask (Wetland vs Background)
405
  - Resolution: 128×128 pixels
 
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  from PIL import Image
 
7
  import requests
8
  import io
9
  import matplotlib.colors as mcolors
 
16
  import warnings
17
  warnings.filterwarnings("ignore")
18
 
19
+ # Try to import segmentation_models_pytorch
20
+ try:
21
+ import segmentation_models_pytorch as smp
22
+ smp_available = True
23
+ print("Successfully imported segmentation_models_pytorch")
24
+ except ImportError:
25
+ smp_available = False
26
+ print("Warning: segmentation_models_pytorch not available, will try to install it")
27
+ import subprocess
28
+ try:
29
+ subprocess.check_call([
30
+ "pip", "install", "segmentation-models-pytorch"
31
+ ])
32
+ import segmentation_models_pytorch as smp
33
+ smp_available = True
34
+ print("Successfully installed and imported segmentation_models_pytorch")
35
+ except:
36
+ print("Failed to install segmentation_models_pytorch")
37
+
38
  # Set device
39
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
  print(f"Using device: {device}")
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  # Initialize the model
43
+ if smp_available:
44
+ # Define the DeepLabV3+ model using smp
45
+ model = smp.DeepLabV3Plus(
46
+ encoder_name="resnet34", # Using ResNet34 backbone as in your training
47
+ encoder_weights=None, # We'll load your custom weights
48
+ in_channels=3, # RGB input
49
+ classes=1, # Binary segmentation
50
+ )
51
+ else:
52
+ # Fallback to a simple model that won't actually work but allows the UI to load
53
+ print("Warning: Using a placeholder model that won't produce correct predictions.")
54
+ from torch import nn
55
+ class PlaceholderModel(nn.Module):
56
+ def __init__(self):
57
+ super().__init__()
58
+ self.conv = nn.Conv2d(3, 1, 3, padding=1)
59
+ def forward(self, x):
60
+ return self.conv(x)
61
+ model = PlaceholderModel()
62
 
63
  # Download model weights from HuggingFace
64
  MODEL_REPO = "dcrey7/wetlands_segmentation_deeplabsv3plus"
 
85
  print(f"Error downloading model weights: {e}")
86
  return None
87
 
 
 
 
 
 
 
 
 
 
 
88
  # Load the model weights
89
  weights_path = download_model_weights()
90
  if weights_path:
91
  try:
92
+ # Try to load with strict=False to allow for some parameter mismatches
93
  state_dict = torch.load(weights_path, map_location=device)
94
+ # Check if we need to modify the state dict keys
95
+ if all(key.startswith('encoder.') or key.startswith('decoder.') for key in list(state_dict.keys())[:5]):
96
+ print("Model weights use encoder/decoder format, loading directly")
97
+ model.load_state_dict(state_dict, strict=False)
98
+ else:
99
+ print("Attempting to adapt state dict to match model architecture")
100
+ # This is a placeholder for state dict adaptation if needed
101
+ model.load_state_dict(state_dict, strict=False)
102
+ print("Model weights loaded successfully")
103
  except Exception as e:
104
  print(f"Error loading model weights: {e}")
 
 
105
  else:
106
+ print("No weights available. Model will not produce valid predictions.")
 
107
 
108
  model.to(device)
109
  model.eval()
 
182
  """
183
  Preprocess a ground truth mask
184
  """
185
+ # If mask is a file path (string), open it
186
+ if isinstance(mask, str):
187
+ try:
188
+ # Try to open as a TIFF file with rasterio
189
+ with rasterio.open(mask) as src:
190
+ mask_array = src.read(1) # Read first band
191
+ mask = mask_array
192
+ except:
193
+ # Fall back to opening with PIL
194
+ try:
195
+ mask = np.array(Image.open(mask))
196
+ except Exception as e:
197
+ print(f"Error reading mask file: {e}")
198
+ return None
199
  # Convert to numpy array if PIL image
200
+ elif isinstance(mask, Image.Image):
201
  mask = np.array(mask)
202
 
203
  # Convert to grayscale if needed
 
222
  with torch.no_grad():
223
  output = model(image_tensor)
224
 
225
+ # Handle different model output formats
226
+ if isinstance(output, dict):
 
227
  output = output['out']
228
+ if output.shape[1] > 1: # Multi-class output
 
229
  pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
230
+ else: # Binary output (from smp models)
231
+ pred = (torch.sigmoid(output) > 0.5).squeeze().cpu().numpy().astype(np.uint8)
232
+
 
233
  return pred
234
  except Exception as e:
235
  print(f"Error during prediction: {e}")
236
+ return None
 
237
 
238
  def calculate_metrics(pred_mask, gt_mask):
239
  """
 
270
 
271
  return metrics
272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  def process_images(input_image=None, input_tiff=None, gt_mask=None):
274
  """
275
  Process input images and generate predictions
 
281
 
282
  # Process the input image
283
  if input_tiff is not None:
284
+ # Create a temporary file for the uploaded TIFF
285
+ with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as temp_tiff:
286
+ temp_tiff_path = temp_tiff.name
287
+
288
+ # Write the file content to the temporary file
289
+ if isinstance(input_tiff, str):
290
+ # If input_tiff is a path
291
+ with open(input_tiff, 'rb') as f:
292
+ temp_tiff.write(f.read())
293
+ else:
294
+ # If input_tiff is raw bytes (a file-like object would need .read() before writing)
295
+ temp_tiff.write(input_tiff)
296
 
297
  # Process TIFF file
298
  image_tensor, display_image = preprocess_tiff(temp_tiff_path)
 
318
  metrics_text = ""
319
 
320
  if gt_mask is not None:
321
+ # If gt_mask is a file path (str) or raw file content (bytes) from a file upload
322
+ if isinstance(gt_mask, (str, bytes)):
323
+ # Create a temporary file for the mask
324
+ with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as temp_mask:
325
+ temp_mask_path = temp_mask.name
326
+ if isinstance(gt_mask, str):
327
+ with open(gt_mask, 'rb') as f:
328
+ temp_mask.write(f.read())
329
+ else:
330
+ temp_mask.write(gt_mask)
331
+ gt_mask_processed = preprocess_mask(temp_mask_path)
332
+ try:
333
+ os.unlink(temp_mask_path)
334
+ except:
335
+ pass
336
+ else:
337
+ # Normal image upload
338
+ gt_mask_processed = preprocess_mask(gt_mask)
339
+
340
+ if gt_mask_processed is not None:
341
+ metrics = calculate_metrics(pred_mask, gt_mask_processed)
342
+ metrics_text = "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
343
 
344
  # Create visualization
345
  fig = plt.figure(figsize=(12, 6))
 
391
  return result_image, result_text
392
 
393
  except Exception as e:
 
 
394
  print(f"Error in processing: {e}")
395
+ import traceback
396
+ traceback.print_exc()
397
  return None, f"Error: {str(e)}"
398
 
399
  # Create Gradio interface
 
427
  This application uses a DeepLabv3+ model trained to segment wetland areas in satellite imagery.
428
 
429
  **Model Details:**
430
+ - Architecture: DeepLabv3+ with ResNet-34
431
  - Input: RGB satellite imagery
432
  - Output: Binary segmentation mask (Wetland vs Background)
433
  - Resolution: 128×128 pixels