# SAM2-HIERA-SMALL - ONNX Format for WebGPU

*Powered by Segment Anything 2 (SAM2) from Meta Research*

This repository contains ONNX-converted models from [facebook/sam2-hiera-small](https://huggingface.co/facebook/sam2-hiera-small), optimized for WebGPU deployment in browsers.
## Model Information
- Original Model: facebook/sam2-hiera-small
- Version: SAM 2.0
- Size: 46M parameters
- Description: Small variant - balanced speed and quality
- Format: ONNX (encoder + decoder)
- Optimization: Encoder optimized to .ort format for WebGPU
## Files

- `encoder.onnx` - Image encoder (ONNX format)
- `encoder.with_runtime_opt.ort` - Image encoder (optimized for WebGPU)
- `decoder.onnx` - Mask decoder (ONNX format)
- `config.json` - Model configuration
## Usage

### In Browser with ONNX Runtime Web

```javascript
import * as ort from 'onnxruntime-web/webgpu';
// Load encoder (use optimized .ort version for WebGPU)
const encoderURL = 'https://huggingface.co/SharpAI/sam2-hiera-small-onnx/resolve/main/encoder.with_runtime_opt.ort';
const encoderSession = await ort.InferenceSession.create(encoderURL, {
executionProviders: ['webgpu'],
graphOptimizationLevel: 'disabled'
});
// Load decoder
const decoderURL = 'https://huggingface.co/SharpAI/sam2-hiera-small-onnx/resolve/main/decoder.onnx';
const decoderSession = await ort.InferenceSession.create(decoderURL, {
executionProviders: ['webgpu']
});
// Run encoder
const imageData = preprocessImage(image); // Your preprocessing
const encoderOutputs = await encoderSession.run({ image: imageData });
// Run decoder with a single foreground point at (x, y)
const point_coords = new ort.Tensor('float32', new Float32Array([x, y, 0, 0]), [1, 2, 2]);
const point_labels = new ort.Tensor('float32', new Float32Array([1, -1]), [1, 2]);
const mask_input = new ort.Tensor('float32', new Float32Array(256 * 256), [1, 1, 256, 256]);
const has_mask_input = new ort.Tensor('float32', new Float32Array([0]), [1]);
const decoderOutputs = await decoderSession.run({
image_embed: encoderOutputs.image_embed,
high_res_feats_0: encoderOutputs.high_res_feats_0,
high_res_feats_1: encoderOutputs.high_res_feats_1,
point_coords: point_coords,
point_labels: point_labels,
mask_input: mask_input,
has_mask_input: has_mask_input
});
// Get masks
const masks = decoderOutputs.masks; // Shape: [1, num_masks, 256, 256]
```
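The `preprocessImage` helper above is left to the integrator. Below is a minimal sketch of one way to implement it, assuming the usual SAM2 pipeline of resizing to 1024x1024 and normalizing with ImageNet mean/std; the resize strategy and normalization constants are assumptions here, so verify them against your own conversion settings.

```javascript
// Hypothetical preprocessing: HTMLImageElement -> Float32[1, 3, 1024, 1024].
// Assumes ImageNet mean/std normalization and a plain (non-letterboxed) resize.
function preprocessImage(image) {
  const size = 1024;
  const canvas = document.createElement('canvas');
  canvas.width = size;
  canvas.height = size;
  const ctx = canvas.getContext('2d');
  ctx.drawImage(image, 0, 0, size, size);
  const { data } = ctx.getImageData(0, 0, size, size); // RGBA, uint8

  const mean = [0.485, 0.456, 0.406];
  const std = [0.229, 0.224, 0.225];
  const chw = new Float32Array(3 * size * size);
  for (let i = 0; i < size * size; i++) {
    for (let c = 0; c < 3; c++) {
      // HWC uint8 -> CHW float32, scaled to [0, 1] then normalized
      chw[c * size * size + i] = (data[i * 4 + c] / 255 - mean[c]) / std[c];
    }
  }
  return new ort.Tensor('float32', chw, [1, 3, size, size]);
}
```

Note that `point_coords` passed to the decoder are typically expected in the coordinate space of this 1024x1024 input, so clicks on a displayed image need to be rescaled accordingly.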
### In Python with ONNX Runtime

```python
import onnxruntime as ort
import numpy as np
# Load models
encoder_session = ort.InferenceSession("encoder.onnx")
decoder_session = ort.InferenceSession("decoder.onnx")
# Run encoder
encoder_outputs = encoder_session.run(None, {"image": image_tensor})
# Run decoder
decoder_outputs = decoder_session.run(None, {
"image_embed": encoder_outputs[0],
"high_res_feats_0": encoder_outputs[1],
"high_res_feats_1": encoder_outputs[2],
"point_coords": point_coords,
"point_labels": point_labels,
"mask_input": mask_input,
"has_mask_input": has_mask_input
})
masks = decoder_outputs[0]
```
## Input/Output Specifications

### Encoder

**Input:**

- `image`: Float32[1, 3, 1024, 1024] - Normalized RGB image

**Outputs:**

- `image_embed`: Float32[1, 256, 64, 64] - Image embeddings
- `high_res_feats_0`: Float32[1, 32, 256, 256] - High-res features (level 0)
- `high_res_feats_1`: Float32[1, 64, 128, 128] - High-res features (level 1)
### Decoder

**Inputs** (a background-point example follows the list):

- `image_embed`: Float32[1, 256, 64, 64] - From encoder
- `high_res_feats_0`: Float32[1, 32, 256, 256] - From encoder
- `high_res_feats_1`: Float32[1, 64, 128, 128] - From encoder
- `point_coords`: Float32[1, 2, 2] - Point coordinates [[x, y], [0, 0]]
- `point_labels`: Float32[1, 2] - Point labels [1, -1] (1=foreground, -1=padding)
- `mask_input`: Float32[1, 1, 256, 256] - Previous mask (zeros if none)
- `has_mask_input`: Float32[1] - Flag [0] or [1]
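The export pads the point inputs to two slots. By the standard SAM/SAM2 labeling convention, label 0 marks a background point; this card does not restate that convention, so treat it as an assumption. Under that assumption, the padding slot can carry a negative click without changing any tensor shapes:

```javascript
// One foreground click at (x1, y1) plus one background click at (x2, y2).
// Assumption: label 0 = background, per the SAM/SAM2 convention; the padding
// slot ([0, 0] with label -1) is simply replaced by the second point.
const point_coords = new ort.Tensor('float32', new Float32Array([x1, y1, x2, y2]), [1, 2, 2]);
const point_labels = new ort.Tensor('float32', new Float32Array([1, 0]), [1, 2]);
```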
**Outputs:**

- `masks`: Float32[1, 3, 256, 256] - Generated masks (3 candidates)
- `iou_predictions`: Float32[1, 3] - IoU scores for each mask
- `low_res_masks`: Float32[1, 3, 256, 256] - Low-resolution masks
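Since the decoder returns three candidate masks plus predicted IoU scores, a typical post-processing step is to keep the highest-scoring candidate and binarize it. A sketch follows; thresholding the mask values at 0 mirrors the reference SAM2 post-processing of logits, but is an assumption about this export:

```javascript
// Pick the best of the 3 candidate masks by predicted IoU, then binarize.
function selectBestMask(decoderOutputs) {
  const iou = decoderOutputs.iou_predictions.data;   // Float32Array, length 3
  const masks = decoderOutputs.masks.data;           // Float32Array, 3 * 256 * 256
  const [, numMasks, h, w] = decoderOutputs.masks.dims;

  let best = 0;
  for (let i = 1; i < numMasks; i++) {
    if (iou[i] > iou[best]) best = i;
  }

  const binary = new Uint8Array(h * w);
  const offset = best * h * w;
  for (let i = 0; i < h * w; i++) {
    binary[i] = masks[offset + i] > 0 ? 255 : 0;     // threshold logits at 0
  }
  return { mask: binary, score: iou[best], width: w, height: h };
}
```

The resulting 256x256 mask can then be upscaled to the original image dimensions, e.g. by drawing it onto a canvas and letting the browser scale it.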
## Browser Requirements

- Chrome 113+ with WebGPU enabled (`chrome://flags/#enable-unsafe-webgpu`)
- Firefox Nightly with WebGPU enabled
- Safari Technology Preview with WebGPU enabled
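For pages that must also run on browsers without WebGPU, ONNX Runtime Web can fall back to its WASM execution provider. A minimal sketch (the adapter probe via `navigator.gpu.requestAdapter()` is standard WebGPU; expect noticeably slower inference on the WASM path):

```javascript
// Create a session on WebGPU when available, otherwise fall back to WASM.
// Assumption: the plain .onnx files load on both backends; the .ort encoder
// is tuned for WebGPU, so encoder.onnx may be the safer choice for WASM-only.
async function createSession(url) {
  const adapter = navigator.gpu ? await navigator.gpu.requestAdapter() : null;
  return ort.InferenceSession.create(url, {
    executionProviders: adapter ? ['webgpu', 'wasm'] : ['wasm'],
  });
}
```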
## Performance

Typical inference times on Chrome with WebGPU:

- Encoder: 3-5s
- Decoder: 0.1-0.5s per point
## License
This model is released under the Apache 2.0 license, following the original SAM2 model.
## Citation

```bibtex
@article{ravi2024sam2,
title={SAM 2: Segment Anything in Images and Videos},
author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
journal={arXiv preprint arXiv:2408.00714},
year={2024}
}
```
## Related Resources
- Original SAM2: facebookresearch/segment-anything-2
- WebGPU Demo: Aegis AI SAM2 WebGPU Demo
- Conversion Tool: SAM2 ONNX Converter
## Acknowledgments
- Meta Research for the original SAM2 model
- Microsoft for ONNX Runtime
- SamExporter for conversion tools
---

*Converted and optimized by Aegis AI*