# STAMP-2B-uni / app.py
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
# --- Import local modules (Ensure these files are uploaded to the Space) ---
try:
    from segment_predictor_cache import GenerativeSegmenter
    from model.segment_anything import sam_model_registry, SamPredictor
    from eval.utils import compute_logits_from_mask, masks_sample_points
except ImportError as e:
    raise ImportError(
        f"Could not import custom modules: {e}. Please ensure the STAMP source code "
        "(model/, eval/, segment_predictor_cache.py) is uploaded to the Space."
    )
# --- Configuration ---
MODEL_PATH = "JiaZL/STAMP-2B-uni"
# Use a specific repo to download SAM weights automatically
SAM_REPO_ID = "HCMUE-Research/SAM-vit-h"
SAM_FILENAME = "sam_vit_h_4b8939.pth"
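# sam_vit_h_4b8939.pth is the official SAM ViT-H checkpoint filename; the repo
# above is assumed to mirror Meta's original release.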
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on {DEVICE}...")
# --- Load Models (Cached globally) ---
def load_models():
    print(f"Loading STAMP model from {MODEL_PATH}...")
    # Adjust min/max pixels if running into OOM on smaller GPUs
    segmenter = GenerativeSegmenter(
        MODEL_PATH,
        device_map=DEVICE,
        min_pixels=1024 * 28 * 28,  # Reduced slightly for Space stability
        max_pixels=1280 * 28 * 28
    )
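    # Assumption: GenerativeSegmenter passes min_pixels/max_pixels through to a
    # Qwen2-VL-style processor, which budgets visual tokens in 28x28-pixel patches.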
print("Downloading and Loading SAM model...")
sam_checkpoint = hf_hub_download(repo_id=SAM_REPO_ID, filename=SAM_FILENAME)
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
sam = sam.to(dtype=torch.float32, device=DEVICE)
predictor = SamPredictor(sam)
return segmenter, predictor
# Initialize models
segmenter, sam_predictor = load_models()
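# Loading once at import time lets every Gradio request reuse the initialized models.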
# --- Core Inference Function ---
def run_inference(image, query, use_sam=True):
    if image is None:
        return None, "Please upload an image."
    if not query:
        return None, "Please enter a query."
    # Convert to RGB PIL Image
    image_pil = Image.fromarray(image).convert("RGB")
    w_ori, h_ori = image_pil.size
    with torch.inference_mode():
        # 1. Set SAM image embedding
        if use_sam:
            sam_predictor.set_image(np.array(image_pil))
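            # set_image runs SAM's heavy image encoder once; the predict() calls
            # below reuse the cached embedding.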
        # 2. Generate coarse mask using STAMP
        print(f"Generating coarse mask for query: {query}")
        segmentation_masks, response_text = segmenter.generate_with_segmentation(
            image_pil, query
        )
        if not segmentation_masks:
            return image, f"No mask generated. Model response: {response_text}"
        # Extract the first mask
        mask = segmentation_masks[0]
        # Resize coarse mask to original image size
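        # Nearest-neighbor upsampling preserves the integer class ids; bilinear
        # interpolation would blend neighboring labels into invalid values.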
        mask_pred = F.interpolate(
            mask.unsqueeze(0).unsqueeze(0).double(),
            size=(h_ori, w_ori),
            mode='nearest'
        ).squeeze(0).squeeze(0)
        # --- SAM Refinement ---
        final_mask = np.zeros((h_ori, w_ori), dtype=np.float32)
        if use_sam:
            print("Refining mask with SAM...")
            unique_classes = torch.unique(mask_pred)
            for class_id in unique_classes:
                if class_id == 0:
                    continue
                # Get binary mask for current class
                binary_mask = (mask_pred == class_id).double().cpu()
                try:
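                    # Assumption: compute_logits_from_mask returns a low-res logit map
                    # in SAM's mask_input format, and masks_sample_points draws positive
                    # point prompts from the binary mask (see eval/utils in the STAMP repo).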
                    logits = compute_logits_from_mask(binary_mask)
                    point_coords, point_labels = masks_sample_points(binary_mask)
                    # First pass
                    sam_mask, _, logit = sam_predictor.predict(
                        point_coords=point_coords,
                        point_labels=point_labels,
                        mask_input=logits,
                        multimask_output=False
                    )
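                    # Feeding SAM's own low-res logits back as the mask prompt
                    # usually tightens the boundary over a couple of passes.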
                    # Iterative refinement
                    for _ in range(2):
                        sam_mask, _, logit = sam_predictor.predict(
                            point_coords=point_coords,
                            point_labels=point_labels,
                            mask_input=logit,
                            multimask_output=False
                        )
                    current_refined_mask = sam_mask[0].astype(np.float32)
                    final_mask = np.maximum(final_mask, current_refined_mask)
                except Exception as e:
                    print(f"SAM Error for class {class_id}: {e}")
                    final_mask = np.maximum(final_mask, binary_mask.numpy().astype(np.float32))
        else:
            final_mask = mask_pred.cpu().numpy().astype(np.float32)
    # --- Visualization ---
    # Binarize the mask to uint8 (0 or 255)
    mask_uint8 = (final_mask > 0).astype(np.uint8) * 255
    # Gradio supplies RGB numpy arrays (not OpenCV-style BGR), so red is (255, 0, 0)
    overlay = image.copy()
    color_mask = np.zeros_like(image)
    color_mask[:, :, 0] = 255  # R channel; G and B stay zero
    # Blend the red mask into the image at 50% opacity
    alpha = 0.5
    mask_indices = mask_uint8 > 0
    overlay[mask_indices] = (alpha * image[mask_indices] + (1 - alpha) * color_mask[mask_indices]).astype(np.uint8)
    return overlay, f"Success! {response_text}"
# --- Gradio Interface ---
with gr.Blocks(title="STAMP Segmentation Demo") as demo:
    gr.Markdown("# STAMP: Better, Stronger, and Faster MLLM Segmentation")
    gr.Markdown("Upload an image and provide a text query to segment objects using STAMP-2B-uni, optionally refined by SAM.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")
            text_query = gr.Textbox(label="Text Prompt", placeholder="e.g., segment the white horse")
            use_sam_checkbox = gr.Checkbox(label="Refine with SAM", value=True)
            submit_btn = gr.Button("Segment", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result")
            status_text = gr.Textbox(label="Status/Response", interactive=False)
    submit_btn.click(
        fn=run_inference,
        inputs=[input_image, text_query, use_sam_checkbox],
        outputs=[output_image, status_text]
    )
    # Add examples
    gr.Examples(
        examples=[
            ["images/horses.png", "segment the white horse", True]
        ],
        inputs=[input_image, text_query, use_sam_checkbox],
        fn=run_inference,  # Dummy fn for cache
        cache_examples=False  # Disable cache if the build machine has no GPU
    )
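    # NOTE: the example assumes images/horses.png is committed to the Space repo.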
if __name__ == "__main__":
    demo.launch()