import gc
import numpy as np
import torch
from typing import Optional, Tuple, List, Union
import warnings
import cv2

try:
    from transformers import SamModel, SamProcessor
    from huggingface_hub import hf_hub_download
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False
    warnings.warn("transformers or huggingface_hub not available. HF SAM models will not work.")

# Hugging Face model mapping
HF_MODELS = {
    'vit_b': 'facebook/sam-vit-base',
    'vit_l': 'facebook/sam-vit-large',
    'vit_h': 'facebook/sam-vit-huge'
}


class HFSamPredictor:
    """
    Hugging Face version of SamPredictor that wraps the transformers SAM models.
    This class provides the same interface as the original SamPredictor for seamless integration.
    """

    def __init__(self, model: SamModel, processor: SamProcessor, device: Optional[str] = None):
        """
        Initialize the HF SAM predictor.

        Args:
            model: The SAM model from transformers
            processor: The SAM processor from transformers
            device: Device to run the model on ('cuda', 'cpu', etc.)
        """
        self.model = model
        self.processor = processor
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()
        # Store the current image and its features
        self.original_size = None
        self.input_size = None
        self.features = None
        self.image = None

    @classmethod
    def from_pretrained(cls, model_name: str, device: Optional[str] = None) -> 'HFSamPredictor':
        """
        Load a SAM model from Hugging Face Hub.

        Args:
            model_name: Model name from HF_MODELS or direct HF model path
            device: Device to load the model on

        Returns:
            HFSamPredictor instance
        """
        if not HF_AVAILABLE:
            raise ImportError("transformers and huggingface_hub are required for HF SAM models")
        # Map model type to HF model name if needed
        if model_name in HF_MODELS:
            model_name = HF_MODELS[model_name]
        print(f"Loading SAM model from Hugging Face: {model_name}")
        # Load model and processor
        model = SamModel.from_pretrained(model_name)
        processor = SamProcessor.from_pretrained(model_name)
        return cls(model, processor, device)

    def preprocess(self, image: np.ndarray,
                   input_points: List[List[float]], input_labels: List[int]) -> dict:
        """
        Set the image for prediction. This preprocesses the image and prepares the model inputs.

        Args:
            image: Input image as numpy array (H, W, C) in RGB format
            input_points: Point prompts, forwarded to SamProcessor
            input_labels: Point labels, forwarded to SamProcessor (1 = foreground, 0 = background)

        Returns:
            Dict of model-ready tensors moved to self.device (includes 'pixel_values',
            'original_sizes' and 'reshaped_input_sizes').
        """
        if image.dtype != np.uint8:
            image = (image * 255).astype(np.uint8)
        self.image = image
        self.original_size = image.shape[:2]
        # Run the processor on the image and point prompts; it also returns
        # original_sizes and reshaped_input_sizes, which are needed for post-processing.
        inputs = self.processor(
            images=image,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt"
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        self.input_size = inputs['pixel_values'].shape[-2:]
        self.features = inputs
        return inputs
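

# --- Hedged example: running inference with the preprocessed inputs ---
# The module above stops at preprocessing, so this sketch shows how the dict
# returned by HFSamPredictor.preprocess() could be fed to the underlying
# transformers SamModel and post-processed back to the original resolution.
# `run_sam_inference` is a hypothetical helper, not part of the original module.
def run_sam_inference(predictor: 'HFSamPredictor', inputs: dict) -> Tuple[List[torch.Tensor], torch.Tensor]:
    """Run the SAM forward pass and resize masks to the original image size (sketch)."""
    with torch.no_grad():
        outputs = predictor.model(
            pixel_values=inputs['pixel_values'],
            input_points=inputs['input_points'],
            input_labels=inputs['input_labels'],
            multimask_output=True,
        )
    # post_process_masks upsamples the low-resolution masks back to the
    # original image size using the sizes recorded by the processor.
    masks = predictor.processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs['original_sizes'].cpu(),
        inputs['reshaped_input_sizes'].cpu(),
    )
    return masks, outputs.iou_scores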


def get_hf_sam_predictor(model_type: str = 'vit_h', device: Optional[str] = None,
                         image: Optional[np.ndarray] = None) -> HFSamPredictor:
    """
    Get a Hugging Face SAM predictor with the same interface as the original get_sam_predictor.

    Args:
        model_type: Model type ('vit_b', 'vit_l', 'vit_h') or a direct HF model path
        device: Device to run the model on
        image: Optional image to store on the predictor; full preprocessing happens
            later in preprocess(), which also needs the point prompts

    Returns:
        HFSamPredictor instance
    """
    if not HF_AVAILABLE:
        raise ImportError("transformers and huggingface_hub are required for HF SAM models")
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Load the predictor
    predictor = HFSamPredictor.from_pretrained(model_type, device)
    # Store the image if provided. HFSamPredictor has no set_image(); the actual
    # preprocessing is deferred to preprocess(), which also takes the point prompts.
    if image is not None:
        predictor.image = image
        predictor.original_size = image.shape[:2]
    return predictor
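

# --- Hedged usage sketch ---
# Minimal end-to-end example of how this module would be used; the image path
# and point coordinates below are placeholders, not values from the original code.
if __name__ == "__main__":
    # Load the smallest SAM variant for a quick test
    predictor = get_hf_sam_predictor(model_type='vit_b')

    # Read an RGB image (cv2 loads BGR, so convert); 'example.jpg' is a placeholder path
    bgr = cv2.imread('example.jpg')
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # One foreground point prompt, in the nested format used by the HF SAM docs
    inputs = predictor.preprocess(rgb, input_points=[[[100, 150]]], input_labels=[[1]])

    masks, scores = run_sam_inference(predictor, inputs)
    print(f"Predicted {masks[0].shape[1]} masks, IoU scores: {scores.squeeze().tolist()}")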