🌿 Interactive Leaf Segmentation Model
An interactive segmentation model that isolates one specific leaf in an image, guided by a user-provided hint such as a scribble or a few dots.
Model Description
The model uses a U-Net architecture with a lightweight MobileNetV2 backbone, pre-trained on ImageNet. It has been fine-tuned for interactive segmentation.
The model takes a 4-channel input tensor:
- Red Channel (from the original image)
- Green Channel (from the original image)
- Blue Channel (from the original image)
- Hint Channel (a single-channel mask containing the user's scribble/dots)
It outputs a single-channel binary segmentation mask of the indicated leaf. This model was trained on the Subh775/leaf-segmentation-dataset.
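As a rough illustration of the input layout described above, the sketch below builds a hint channel from click points ("dots") and stacks it behind the RGB channels. The helper names `dots_to_hint` and `stack_rgb_hint` and the dot radius are assumptions made for this example and are not part of the model repository. The channel order (R, G, B, hint) and the scaling to [0, 1] follow the inference script in the Usage section.

```python
import numpy as np

def dots_to_hint(height, width, points, radius=7):
    """Rasterize (x, y) click points into a uint8 hint mask (255 inside each dot).

    The radius is an illustrative choice, not a value taken from training.
    """
    hint = np.zeros((height, width), dtype=np.uint8)
    yy, xx = np.mgrid[0:height, 0:width]
    for x, y in points:
        hint[(xx - x) ** 2 + (yy - y) ** 2 <= radius ** 2] = 255
    return hint

def stack_rgb_hint(rgb, hint):
    """Stack an HxWx3 RGB image and an HxW hint mask into a 4xHxW float32 array in [0, 1]."""
    rgb = rgb.astype(np.float32) / 255.0
    hint = hint.astype(np.float32)[None, ...] / 255.0
    return np.concatenate([rgb.transpose(2, 0, 1), hint], axis=0)

rgb = np.zeros((256, 256, 3), dtype=np.uint8)  # placeholder image for the example
hint = dots_to_hint(256, 256, [(100, 120), (140, 130)])
x = stack_rgb_hint(rgb, hint)
print(x.shape)  # (4, 256, 256)
```

Adding a leading batch dimension and converting with `torch.from_numpy` would give a tensor of the shape the model expects.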
Intended Use
This model is designed to be the core component of a "human-in-the-loop" application or a smart annotation tool. The primary use case is allowing a user to quickly segment a specific leaf from a complex background by providing a simple, imprecise hint.
Training Performance & Limitations
- The model was trained on a dataset of 48 images for 50 epochs with a final validation loss of 0.4758 and a final Dice score of 0.8080. While effective, its ability to generalize to leaf types, lighting conditions, or backgrounds that are significantly different from the training data may be limited.
- It may produce false positives (segmenting un-hinted leaves) when the target leaf sits in a very cluttered scene. Training on a larger, more diverse dataset would likely improve performance.
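For readers unfamiliar with the metric, the Dice score measures the overlap between a predicted mask and a ground-truth mask as 2|A ∩ B| / (|A| + |B|). The sketch below uses this standard definition with a small smoothing term `eps`. The exact formulation used during training (smoothing, thresholding, averaging across the batch) is not stated here, so treat this as an assumption rather than the training code.

```python
import numpy as np

def dice_score(pred, target, eps=1e-7):
    """Dice coefficient between two binary masks of the same shape.

    eps avoids division by zero when both masks are empty.
    """
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)

pred = np.array([[1, 1, 0], [0, 1, 0]])
target = np.array([[1, 0, 0], [0, 1, 1]])
print(round(dice_score(pred, target), 4))  # 0.6667
```

A score of 1.0 means the masks match exactly, and 0.0 means they share no foreground pixels.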
Usage
To use this model, define the architecture with segmentation-models-pytorch, download the weights from this Hub repository, and pass in a properly formatted 4-channel tensor. The inference script below walks through each step:
import torch
import cv2
import numpy as np
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download
# --- 1. DEFINE MODEL AND LOAD WEIGHTS ---
DEVICE = "cpu"
HF_MODEL_REPO_ID = "Subh775/leaf-segmentation-unet-mobilenetv2"
# Define the exact same architecture as used in training
model = smp.Unet(
    encoder_name="mobilenet_v2",
    encoder_weights=None, 
    in_channels=4,
    classes=1,
)
# Download model weights from the Hub
model_weights_path = hf_hub_download(
    repo_id=HF_MODEL_REPO_ID,
    filename="best_model.pth"
)
# Load the weights into the model
model.load_state_dict(torch.load(model_weights_path, map_location=DEVICE))
model.to(DEVICE)
model.eval()
# --- 2. PREPARE INPUT ---
# Load your image and create a scribble mask
image_path = "path/to/your/image.jpg"
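image = cv2.imread(image_path)
if image is None:  # cv2.imread returns None instead of raising on a bad path
    raise FileNotFoundError(f"Could not read image: {image_path}")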
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
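# Create a scribble on a black canvas of the same size
scribble_mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
# Example scribble: a horizontal stroke through the image center.
# These coordinates are placeholders; replace them with the user's actual input.
h, w = image.shape[:2]
x1, y1, x2, y2 = w // 4, h // 2, 3 * w // 4, h // 2
cv2.line(scribble_mask, (x1, y1), (x2, y2), 255, thickness=15)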
# Preprocess (resize, normalize, concatenate)
IMG_SIZE = 256
image_resized = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
scribble_resized = cv2.resize(scribble_mask, (IMG_SIZE, IMG_SIZE))
image_tensor = torch.tensor(image_resized, dtype=torch.float32).permute(2, 0, 1) / 255.0
scribble_tensor = torch.tensor(scribble_resized, dtype=torch.float32).unsqueeze(0) / 255.0
# Combine into 4-channel input and add batch dimension
input_tensor = torch.cat([image_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)
# --- 3. RUN INFERENCE ---
with torch.no_grad():
    output = model(input_tensor)
# Post-process to get a binary mask
predicted_mask = (torch.sigmoid(output) > 0.5).float().squeeze().cpu().numpy()
print("Inference complete! The 'predicted_mask' is ready.")
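The predicted mask comes out at the 256×256 working resolution, so an annotation tool would usually need to scale it back to the original image size. The sketch below shows one way to do that with nearest-neighbour sampling in plain NumPy, which keeps the mask strictly binary. The helper name `resize_mask_nearest` and the 480×640 target size are assumptions made for this example.

```python
import numpy as np

def resize_mask_nearest(mask, out_height, out_width):
    """Resize a 2D mask with nearest-neighbour sampling (no new values are introduced)."""
    in_h, in_w = mask.shape
    rows = np.arange(out_height) * in_h // out_height  # source row for each output row
    cols = np.arange(out_width) * in_w // out_width    # source column for each output column
    return mask[rows[:, None], cols[None, :]]

# Stand-in for a 256x256 prediction: a filled square in the middle
small = np.zeros((256, 256), dtype=np.float32)
small[64:192, 64:192] = 1.0

full = resize_mask_nearest(small, 480, 640)
print(full.shape)  # (480, 640)
```

With OpenCV already imported, `cv2.resize(predicted_mask, (width, height), interpolation=cv2.INTER_NEAREST)` should give a comparable result. Note that `cv2.resize` takes the target size as (width, height).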
Model tree for Subh775/leaf-segmentation-unet-mobilenetv2
Base model
google/mobilenet_v2_1.0_224
