SAM2-HIERA-SMALL - ONNX Format for WebGPU

Powered by Segment Anything 2 (SAM2) from Meta Research

This repository contains ONNX-converted models from facebook/sam2-hiera-small, optimized for WebGPU deployment in browsers.

Model Information

  • Original Model: facebook/sam2-hiera-small
  • Version: SAM 2.0
  • Size: 46M parameters
  • Description: Small variant - balanced speed and quality
  • Format: ONNX (encoder + decoder)
  • Optimization: Encoder optimized to .ort format for WebGPU

Files

  • encoder.onnx - Image encoder (ONNX format)
  • encoder.with_runtime_opt.ort - Image encoder (optimized for WebGPU)
  • decoder.onnx - Mask decoder (ONNX format)
  • config.json - Model configuration

Usage

In Browser with ONNX Runtime Web

import * as ort from 'onnxruntime-web/webgpu';

// Load encoder (use optimized .ort version for WebGPU)
const encoderURL = 'https://huggingface.co/SharpAI/sam2-hiera-small-onnx/resolve/main/encoder.with_runtime_opt.ort';
const encoderSession = await ort.InferenceSession.create(encoderURL, {
  executionProviders: ['webgpu'],
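  // graph optimizations are already baked into the pre-optimized .ort file, so they are disabled at load time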
  graphOptimizationLevel: 'disabled'
});

// Load decoder
const decoderURL = 'https://huggingface.co/SharpAI/sam2-hiera-small-onnx/resolve/main/decoder.onnx';
const decoderSession = await ort.InferenceSession.create(decoderURL, {
  executionProviders: ['webgpu']
});

// Run encoder
const imageData = preprocessImage(image); // Your preprocessing (see the sketch below)
const encoderOutputs = await encoderSession.run({ image: imageData });

// Run decoder with a single foreground point (x, y) plus one padding point
const point_coords = new ort.Tensor('float32', new Float32Array([x, y, 0, 0]), [1, 2, 2]);
const point_labels = new ort.Tensor('float32', new Float32Array([1, -1]), [1, 2]);
const mask_input = new ort.Tensor('float32', new Float32Array(256 * 256), [1, 1, 256, 256]);
const has_mask_input = new ort.Tensor('float32', new Float32Array([0]), [1]);

const decoderOutputs = await decoderSession.run({
  image_embed: encoderOutputs.image_embed,
  high_res_feats_0: encoderOutputs.high_res_feats_0,
  high_res_feats_1: encoderOutputs.high_res_feats_1,
  point_coords: point_coords,
  point_labels: point_labels,
  mask_input: mask_input,
  has_mask_input: has_mask_input
});

// Get masks
const masks = decoderOutputs.masks; // Shape: [1, 3, 256, 256] - 3 candidate masks
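
The preprocessImage call above is left to the application. The sketch below is a minimal, unoptimized version that matches the encoder input spec (Float32[1, 3, 1024, 1024]); it assumes the export uses the ImageNet mean/std normalization of the original SAM2 pipeline, so verify the normalization constants against your conversion settings.

// Minimal sketch of preprocessImage, assuming the encoder expects a
// 1024x1024 RGB image in CHW order, normalized with the ImageNet mean/std
// used by the original SAM2 preprocessing (verify against the export).
function preprocessImage(image) {
  const size = 1024;
  const canvas = document.createElement('canvas');
  canvas.width = size;
  canvas.height = size;
  const ctx = canvas.getContext('2d');
  ctx.drawImage(image, 0, 0, size, size);              // stretch to 1024x1024
  const { data } = ctx.getImageData(0, 0, size, size); // RGBA bytes

  const mean = [0.485, 0.456, 0.406];
  const std = [0.229, 0.224, 0.225];
  const chw = new Float32Array(3 * size * size);
  for (let i = 0; i < size * size; i++) {
    for (let c = 0; c < 3; c++) {
      chw[c * size * size + i] = (data[i * 4 + c] / 255 - mean[c]) / std[c];
    }
  }
  return new ort.Tensor('float32', chw, [1, 3, size, size]);
}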

In Python with ONNX Runtime

import onnxruntime as ort
import numpy as np

# Load models
encoder_session = ort.InferenceSession("encoder.onnx")
decoder_session = ort.InferenceSession("decoder.onnx")

# Run encoder (image_tensor: np.float32 array of shape [1, 3, 1024, 1024], normalized RGB)
encoder_outputs = encoder_session.run(None, {"image": image_tensor})

# Build prompts: one foreground point (x, y) plus a padding point
point_coords = np.array([[[x, y], [0, 0]]], dtype=np.float32)  # [1, 2, 2]
point_labels = np.array([[1, -1]], dtype=np.float32)           # [1, 2]
mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
has_mask_input = np.zeros(1, dtype=np.float32)

# Run decoder
decoder_outputs = decoder_session.run(None, {
    "image_embed": encoder_outputs[0],
    "high_res_feats_0": encoder_outputs[1],
    "high_res_feats_1": encoder_outputs[2],
    "point_coords": point_coords,
    "point_labels": point_labels,
    "mask_input": mask_input,
    "has_mask_input": has_mask_input
})

masks = decoder_outputs[0]  # [1, 3, 256, 256]

Input/Output Specifications

Encoder

Input:

  • image: Float32[1, 3, 1024, 1024] - Normalized RGB image

Outputs:

  • image_embed: Float32[1, 256, 64, 64] - Image embeddings
  • high_res_feats_0: Float32[1, 32, 256, 256] - High-res features (level 0)
  • high_res_feats_1: Float32[1, 64, 128, 128] - High-res features (level 1)

Decoder

Inputs:

  • image_embed: Float32[1, 256, 64, 64] - From encoder
  • high_res_feats_0: Float32[1, 32, 256, 256] - From encoder
  • high_res_feats_1: Float32[1, 64, 128, 128] - From encoder
  • point_coords: Float32[1, 2, 2] - Point coordinates [[x, y], [0, 0]] (see the coordinate-scaling sketch after this list)
  • point_labels: Float32[1, 2] - Point labels [1, -1] (1=foreground, -1=padding)
  • mask_input: Float32[1, 1, 256, 256] - Previous mask (zeros if none)
  • has_mask_input: Float32[1] - Flag [0] or [1]
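
Decoders exported with SamExporter-style tools typically expect point_coords in the coordinate space of the 1024x1024 encoder input rather than the original image; treat that convention as an assumption and confirm it against the conversion tool you used. Under that assumption, a click captured on a displayed element can be rescaled with a small helper:

// Hypothetical helper: map a click on the displayed image element to the
// assumed 1024x1024 model input space. Plain proportional scaling is enough
// because the image was stretched directly to 1024x1024 during preprocessing.
function toModelCoords(clickX, clickY, displayWidth, displayHeight) {
  return [clickX * (1024 / displayWidth), clickY * (1024 / displayHeight)];
}

// e.g. const [x, y] = toModelCoords(event.offsetX, event.offsetY, img.clientWidth, img.clientHeight);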

Outputs:

  • masks: Float32[1, 3, 256, 256] - Generated masks (3 candidates)
  • iou_predictions: Float32[1, 3] - IoU scores for each mask (useful for picking the best candidate; see the sketch below)
  • low_res_masks: Float32[1, 3, 256, 256] - Low-resolution masks
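
A common post-processing step is to keep the candidate with the highest predicted IoU and threshold it into a binary mask. The sketch below assumes the masks tensor holds raw logits (threshold at 0); if your export already applies a sigmoid, threshold at 0.5 instead. The resulting 256x256 mask can then be scaled up to the displayed image size, for example by drawing it onto a canvas.

// Keep the candidate mask with the highest predicted IoU and binarize it.
// Assumes masks holds raw logits (threshold at 0); if the export already
// applies a sigmoid, threshold at 0.5 instead.
function selectBestMask(decoderOutputs) {
  const iou = decoderOutputs.iou_predictions.data;   // Float32Array of length 3
  const masks = decoderOutputs.masks;                // dims [1, 3, 256, 256]
  const [, , height, width] = masks.dims;

  let best = 0;
  for (let i = 1; i < iou.length; i++) {
    if (iou[i] > iou[best]) best = i;
  }

  const offset = best * height * width;
  const binary = new Uint8Array(height * width);
  for (let i = 0; i < height * width; i++) {
    binary[i] = masks.data[offset + i] > 0 ? 1 : 0;
  }
  return { mask: binary, width, height, score: iou[best] };
}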

Browser Requirements

  • Chrome 113+ (WebGPU is enabled by default; on earlier versions enable chrome://flags/#enable-unsafe-webgpu)
  • Firefox Nightly with WebGPU enabled
  • Safari Technology Preview with WebGPU enabled
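
When WebGPU is not available, onnxruntime-web can fall back to its WASM execution provider, which is slower but works in any modern browser. A minimal feature-detection sketch, reusing encoderURL from the example above:

// Prefer WebGPU when the browser exposes it; otherwise fall back to the
// (slower) WASM execution provider. The plain encoder.onnx may be a better
// fit than the .ort file when running on the WASM backend.
const providers = navigator.gpu ? ['webgpu'] : ['wasm'];
const session = await ort.InferenceSession.create(encoderURL, {
  executionProviders: providers
});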

Performance

Typical inference times on Chrome with WebGPU:

  • Encoder: 3-5s per image (small variant)
  • Decoder: 0.1-0.5s per point

License

This model is released under the Apache 2.0 license, following the original SAM2 model.

Citation

@article{ravi2024sam2,
  title={SAM 2: Segment Anything in Images and Videos},
  author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
  journal={arXiv preprint arXiv:2408.00714},
  year={2024}
}

Acknowledgments

  • Meta Research for the original SAM2 model
  • Microsoft for ONNX Runtime
  • SamExporter for conversion tools

Converted and optimized by Aegis AI
