A newer version of this model is available: LeafNet75/Leaf-Annotate-v2

🌿 Interactive Leaf Segmentation Model


This is an interactive leaf segmentation model designed to identify a specific leaf in an image based on a user-provided hint.

Model Description

The model uses a U-Net architecture with a lightweight MobileNetV2 backbone, pre-trained on ImageNet. It has been fine-tuned for interactive segmentation.

The model takes a 4-channel input tensor:

  1. Red Channel (from the original image)
  2. Green Channel (from the original image)
  3. Blue Channel (from the original image)
  4. Hint Channel (a single-channel mask containing the user's scribble/dots)

It outputs a single-channel binary segmentation mask of the indicated leaf. This model was trained on the Subh775/leaf-segmentation-dataset.
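The channel layout above can be sketched with NumPy alone. This is a minimal illustration, not the training pipeline: the helper name `build_input` and the dummy arrays are hypothetical, and scaling both image and hint to the range [0, 1] is taken from the inference script in the Usage section.

```python
import numpy as np

def build_input(rgb, hint):
    """Stack an RGB image (H, W, 3) and a hint mask (H, W) into a
    (4, H, W) float32 array scaled to [0, 1]."""
    if rgb.shape[:2] != hint.shape:
        raise ValueError("hint must have the same height and width as the image")
    rgb_chw = rgb.astype(np.float32).transpose(2, 0, 1) / 255.0  # channels 1-3
    hint_chw = hint.astype(np.float32)[None, ...] / 255.0        # channel 4
    return np.concatenate([rgb_chw, hint_chw], axis=0)

# Dummy data standing in for a real photo and a user-drawn dot
rgb = np.zeros((256, 256, 3), dtype=np.uint8)
hint = np.zeros((256, 256), dtype=np.uint8)
hint[120:136, 120:136] = 255  # a small square "dot" on the target leaf

x = build_input(rgb, hint)
print(x.shape, x.dtype)
```

Keeping the hint as the last channel matters because the first convolution of the network was trained with that ordering.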


Intended Use

This model is designed to be the core component of a "human-in-the-loop" application or a smart annotation tool. The primary use case is allowing a user to quickly segment a specific leaf from a complex background by providing a simple, imprecise hint.

Training Performance & Limitations

  • The model was trained on a dataset of 48 images for 50 epochs, reaching a final validation loss of 0.4758 and a final Dice score of 0.8080. Because the training set is small, its ability to generalize to leaf types, lighting conditions, or backgrounds that differ significantly from the training data may be limited.
  • It may produce some false positives (segmenting un-hinted leaves) if the target leaf is in a very cluttered environment. Training on a larger and more diverse dataset would likely improve performance.
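For reference, the Dice score reported above measures the overlap between a predicted mask and a ground-truth mask as 2|A ∩ B| / (|A| + |B|). The sketch below shows one common way to compute it for binary masks. The exact formulation used during training (smoothing term, threshold, per-batch or per-image averaging) is not stated in this card, so the `eps` value and the function name `dice_score` are assumptions.

```python
import numpy as np

def dice_score(pred, target, eps=1e-7):
    """Dice coefficient between two binary masks of the same shape."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()
    # eps avoids division by zero when both masks are empty
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)

# Tiny example: 2 overlapping pixels, 3 positive pixels in each mask
pred = np.array([[1, 1, 0],
                 [0, 1, 0]])
target = np.array([[1, 0, 0],
                   [0, 1, 1]])
score = dice_score(pred, target)
print(round(float(score), 4))
```

A score of 1.0 means the masks match exactly, and a score near 0 means they barely overlap.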

[Figure: training outcomes]


Usage

To use this model, define the architecture with segmentation-models-pytorch, download the weights from this Hub repository, and pass in a properly formatted 4-channel tensor. Here is the inference script:

import torch
import cv2
import numpy as np
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download

# --- 1. DEFINE MODEL AND LOAD WEIGHTS ---
DEVICE = "cpu"
HF_MODEL_REPO_ID = "Subh775/leaf-segmentation-unet-mobilenetv2"

# Define the exact same architecture as used in training
model = smp.Unet(
    encoder_name="mobilenet_v2",
    encoder_weights=None, 
    in_channels=4,
    classes=1,
)

# Download model weights from the Hub
model_weights_path = hf_hub_download(
    repo_id=HF_MODEL_REPO_ID,
    filename="best_model.pth"
)

# Load the weights into the model
model.load_state_dict(torch.load(model_weights_path, map_location=DEVICE))
model.to(DEVICE)
model.eval()

# --- 2. PREPARE INPUT ---
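# Load your image and create a scribble mask
image_path = "path/to/your/image.jpg"
image = cv2.imread(image_path)
if image is None:
    # cv2.imread returns None instead of raising when the file cannot be read
    raise FileNotFoundError(f"Could not read image: {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Create a scribble on a black canvas of the same size
h, w = image.shape[:2]
scribble_mask = np.zeros((h, w), dtype=np.uint8)
# Example scribble: a short horizontal stroke through the image center.
# Replace these endpoints with coordinates that lie on the target leaf.
x1, y1 = w // 2 - 20, h // 2
x2, y2 = w // 2 + 20, h // 2
cv2.line(scribble_mask, (x1, y1), (x2, y2), 255, thickness=15)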

# Preprocess (resize, normalize, concatenate)
IMG_SIZE = 256
image_resized = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
scribble_resized = cv2.resize(scribble_mask, (IMG_SIZE, IMG_SIZE))

image_tensor = torch.tensor(image_resized, dtype=torch.float32).permute(2, 0, 1) / 255.0
scribble_tensor = torch.tensor(scribble_resized, dtype=torch.float32).unsqueeze(0) / 255.0

# Combine into 4-channel input and add batch dimension
input_tensor = torch.cat([image_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)

# --- 3. RUN INFERENCE ---
with torch.no_grad():
    output = model(input_tensor)

# Post-process to get a binary mask
predicted_mask = (torch.sigmoid(output) > 0.5).float().squeeze().cpu().numpy()

print("Inference complete! The 'predicted_mask' is ready.")
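The script above produces a 256×256 mask, while the source image usually has a different resolution. A minimal sketch of mapping the mask back to the original size with nearest-neighbor sampling follows, so the result stays strictly binary. The helper `resize_nearest`, the stand-in `predicted_mask`, and the 480×640 target size are hypothetical and only illustrate the idea.

```python
import numpy as np

def resize_nearest(mask, out_h, out_w):
    """Resize a 2-D mask to (out_h, out_w) by nearest-neighbor index lookup."""
    in_h, in_w = mask.shape
    rows = np.arange(out_h) * in_h // out_h  # source row for each output row
    cols = np.arange(out_w) * in_w // out_w  # source column for each output column
    return mask[rows[:, None], cols[None, :]]

# Stand-in for the model output: a 256x256 mask with one square region set
predicted_mask = np.zeros((256, 256), dtype=np.float32)
predicted_mask[64:128, 64:128] = 1.0

# Hypothetical original image size (height, width)
full = resize_nearest(predicted_mask, 480, 640)
print(full.shape)
```

In the OpenCV-based pipeline, `cv2.resize(predicted_mask, (width, height), interpolation=cv2.INTER_NEAREST)` is an equivalent choice. Interpolating modes such as bilinear would introduce fractional values along the mask edge that need re-thresholding.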