SMB Vision V-JEPA2 ViT-L 384x384x256

SMB Vision V-JEPA2 is a 3D medical imaging model based on the Video Joint Embedding Predictive Architecture (V-JEPA2). This model is specifically designed for processing 3D CT (Computed Tomography) volumes and learning rich spatial representations through self-supervised learning.

The model uses a Vision Transformer Large (ViT-L) architecture adapted for 3D volumes with 384×384 spatial resolution and supports up to 256 slices per CT volume.

Model Details

Model Description

SMB Vision V-JEPA2 is a self-supervised 3D medical imaging model that learns spatial representations by predicting masked regions in CT volumes. The model consists of two main components:

  • Encoder: A Vision Transformer adapted for 3D that processes CT volume inputs and generates rich feature representations
  • Predictor: A module that learns to predict masked portions of the CT volume from context regions

Key specifications:

  • Architecture: Vision Transformer Large (ViT-L) adapted for 3D medical imaging with 24 layers and 1024 hidden dimensions

  • Input Resolution: 384×384 pixels per slice

  • Volume Capacity: Up to 256 slices per CT volume

  • Patch Size: 16×16 pixels with 16-slice depth patches (3D tubelets)

  • Parameters: ~307M parameters in the encoder, ~125M in the predictor

  • Input Channels: Single channel (Hounsfield Units from CT scans)

  • Developed by: StandardModelBio Inc.

  • Model type: 3D Medical Imaging / Self-Supervised Learning

  • Architecture: Video Joint Embedding Predictive Architecture (V-JEPA2) adapted for medical volumes

  • License: [More Information Needed]
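
A quick way to inspect the two components and confirm the parameter counts above is sketched below; it assumes the remote code exposes the modules as model.encoder (used elsewhere in this card) and model.predictor (an assumed attribute name).

from transformers import AutoModel

model = AutoModel.from_pretrained(
    "standardmodelbio/smb-vision-vjepa2-vitl-384-256",
    trust_remote_code=True,
)

def count_params(module):
    return sum(p.numel() for p in module.parameters())

# model.encoder is used in the examples below; model.predictor is an assumed name
print(f"Encoder parameters:   {count_params(model.encoder) / 1e6:.0f}M")   # expected ~307M
print(f"Predictor parameters: {count_params(model.predictor) / 1e6:.0f}M") # expected ~125M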

Uses

Direct Use

The model can be used directly for:

  • Feature Extraction: Extract rich 3D spatial features from CT volumes for downstream medical imaging tasks
  • Medical Image Analysis: Analyze CT scan content through learned representations
  • Self-supervised Learning: Use as a foundation for transfer learning on 3D medical imaging tasks
  • Research: Study 3D medical image representation learning and transformer architectures for volumetric data
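
As a minimal feature-extraction sketch (assumptions: mean pooling over tokens as a simple stand-in for the attentive pooler mentioned under Key Features, and a random tensor in place of a preprocessed CT volume):

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "standardmodelbio/smb-vision-vjepa2-vitl-384-256",
    trust_remote_code=True,
)
model.eval()

ct_volume = torch.randn(1, 128, 1, 384, 384)  # placeholder; use a preprocessed volume in practice

with torch.no_grad():
    tokens = model.encoder(ct_volume).last_hidden_state  # [1, seq_len, 1024]
    volume_embedding = tokens.mean(dim=1)                # [1, 1024] volume-level embedding

print(volume_embedding.shape)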

Downstream Use

The model can be fine-tuned or used as a feature extractor for various medical imaging tasks:

  • Disease Detection: Classify pathologies and abnormalities in CT scans
  • Organ Segmentation: Segment anatomical structures in 3D volumes
  • Volumetric Analysis: Analyze 3D spatial patterns and relationships in medical data
  • Medical Image Classification: Categorize CT volumes by anatomy, pathology, or imaging protocol
  • Representation Learning: Use as a backbone for other 3D medical imaging models
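
For example, a frozen-backbone linear probe for a classification task could look like the sketch below; the class count, pooling choice, and training setup are hypothetical placeholders, not part of this release:

import torch
import torch.nn as nn

class LinearProbe(nn.Module):
    """Hypothetical linear probe on top of the frozen encoder (sketch only)."""

    def __init__(self, encoder, hidden_size=1024, num_classes=2):
        super().__init__()
        self.encoder = encoder
        for p in self.encoder.parameters():  # freeze the pretrained backbone
            p.requires_grad = False
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, ct_volume):
        with torch.no_grad():
            tokens = self.encoder(ct_volume).last_hidden_state  # [B, seq_len, 1024]
        pooled = tokens.mean(dim=1)  # simple mean pooling as a placeholder
        return self.head(pooled)

# probe = LinearProbe(model.encoder)
# logits = probe(torch.randn(1, 128, 1, 384, 384))  # [1, num_classes]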

How to Get Started with the Model

Basic Usage

Preprocessing Requirements

Install the required dependencies:

pip install "torch==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1"
pip install transformers monai nibabel

Working with NIfTI Files

Here's a complete example of how to load and process NIfTI files for use with the VJEPA2 model:

import nibabel as nib
import numpy as np
import torch
from monai.transforms import (
    CenterSpatialCropd,
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    MapTransform,
    Orientationd,
    ScaleIntensityRanged,
    Spacingd,
    SpatialPadd,
    ToTensord,
)
from transformers import AutoModel


class PermuteImage(MapTransform):
    """Permute the dimensions of the image for VJEPA2 input format"""

    def __init__(self, keys=["image"], allow_missing_keys=False):
        MapTransform.__init__(self, keys, allow_missing_keys)

    def __call__(self, data):
        # VJEPA2 expects: (depth, num_channels, height, width)
        data["image"] = data["image"].permute(3, 0, 1, 2)
        return data


# Define preprocessing pipeline for CT volumes
ct_transforms = Compose([
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    Orientationd(keys=["image"], axcodes="RAS"),
    Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.5), mode=("bilinear")),
    ScaleIntensityRanged(
        keys=["image"], 
        a_min=-1000, a_max=1000,  # CT Hounsfield unit range
        b_min=0.0, b_max=1.0, 
        clip=True
    ),
    SpatialPadd(keys=["image"], spatial_size=[384, 384, 256]),
    CenterSpatialCropd(roi_size=[384, 384, 256], keys=["image"]),
    ToTensord(keys=["image"]),
    PermuteImage(),
])

# Load and preprocess NIfTI file
nifti_path = "path/to/your/ct_scan.nii.gz"
data_dict = {"image": nifti_path}
transformed_data = ct_transforms(data_dict)

# Prepare input tensor
input_tensor = transformed_data["image"].unsqueeze(0)  # Add batch dimension
print(f"Input shape: {input_tensor.shape}")  # Expected: [1, 256, 1, 384, 384]

# Load model and run inference
model = AutoModel.from_pretrained(
    "standardmodelbio/smb-vision-vjepa2-vitl-384-256", 
    trust_remote_code=True
)
model.eval()

# Extract features using the encoder
with torch.no_grad():
    output = model.encoder(input_tensor)
    features = output.last_hidden_state

print(f"Output features shape: {features.shape}")  # Expected: [1, 9216, 1024]

Key Points for NIfTI Processing

  • Input Format: The model expects CT volumes in Hounsfield Units (HU)
  • Spatial Resolution: Images are resampled to 1.0×1.0×1.5 mm spacing
  • Normalization: HU values are clipped to [-1000, 1000] and normalized to [0, 1]
  • Volume Size: Input volumes are padded/cropped to 384×384×256 voxels
  • Orientation: Volumes are reoriented to RAS (Right-Anterior-Superior) convention
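
Before running the pipeline above, it can help to sanity-check the raw file's spacing, orientation, and intensity range with nibabel (a sketch; the path is the same placeholder as above):

import nibabel as nib

img = nib.load("path/to/your/ct_scan.nii.gz")
data = img.get_fdata()

print("Voxel spacing (mm):", img.header.get_zooms())  # Spacingd resamples to 1.0 x 1.0 x 1.5
print("Orientation:", nib.aff2axcodes(img.affine))    # Orientationd reorients to RAS
print("HU range:", data.min(), data.max())            # ScaleIntensityRanged clips to [-1000, 1000]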

Minimal Example with Synthetic Input

To verify that the model loads and runs, you can pass a random tensor through the encoder:

from transformers import AutoModel
import torch

# Load the model
model = AutoModel.from_pretrained(
    "standardmodelbio/smb-vision-vjepa2-vitl-384-256", 
    trust_remote_code=True  # required to load the model's custom code
)

# Prepare input tensor for CT volume
# Shape: (batch_size, num_slices, num_channels, height, width)
# Note: num_channels=1 for single-channel CT data (Hounsfield Units)
ct_volume = torch.randn(1, 128, 1, 384, 384)

# Get encoder features (inference only, so disable gradient tracking)
model.eval()
with torch.no_grad():
    output = model.encoder(ct_volume)

# The output contains rich 3D spatial features
print(f"Output shape: {output.last_hidden_state.shape}")
# Expected output shape: torch.Size([1, sequence_length, 1024])

Input Format

The model expects input CT volumes in the following format:

  • Shape: (batch_size, num_slices, num_channels, height, width)
  • Data type: torch.FloatTensor
  • Slices: Up to 256 slices per CT volume
  • Resolution: 384×384 pixels per slice
  • Channels: Single channel (Hounsfield Units from CT scans)
  • Data Range: Raw CT values are in Hounsfield Units (roughly -1000 to +3000 HU); the preprocessing pipeline above clips them to [-1000, 1000] and rescales to [0, 1] before inference
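
A minimal sketch of building a model-ready tensor from a raw HU array (assuming the volume is already resampled to 128 slices of 384×384; the array here is random for illustration):

import numpy as np
import torch

hu_volume = np.random.randint(-1000, 3000, size=(128, 384, 384)).astype(np.float32)  # placeholder HU data

normalized = np.clip(hu_volume, -1000, 1000)  # clip to the CT window used above
normalized = (normalized + 1000.0) / 2000.0   # map [-1000, 1000] -> [0, 1]

ct_volume = torch.from_numpy(normalized).unsqueeze(1).unsqueeze(0)  # (slices, H, W) -> (1, slices, 1, H, W)
print(ct_volume.shape)  # torch.Size([1, 128, 1, 384, 384])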

Output Format

The encoder output contains:

  • last_hidden_state: Main feature representations with shape (batch_size, sequence_length, hidden_size)
  • hidden_states: Intermediate layer outputs (if requested)
  • attentions: Attention weights (if requested)

The sequence length depends on the input CT volume dimensions and is calculated as:

sequence_length = (num_slices // tubelet_size) * (height // patch_size) * (width // patch_size)

For a 128×384×384 CT volume with 16×16×16 patches: sequence_length = (128//16) × (384//16) × (384//16) = 8 × 24 × 24 = 4,608
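
The same calculation as a small helper (tubelet and patch sizes taken from this card):

def vjepa2_sequence_length(num_slices, height, width, tubelet_size=16, patch_size=16):
    """Number of tokens produced by the encoder for a given input volume."""
    return (num_slices // tubelet_size) * (height // patch_size) * (width // patch_size)

print(vjepa2_sequence_length(128, 384, 384))  # 4608
print(vjepa2_sequence_length(256, 384, 384))  # 9216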

Advanced Usage

For more advanced usage including the predictor module and masked modeling:

# Full model forward pass with predictor
output = model(
    pixel_values_videos=ct_volume,
    skip_predictor=False,  # Set to True to skip predictor
    output_hidden_states=True,
    output_attentions=True
)

# Access different components
encoder_features = output.last_hidden_state
predictor_output = output.predictor_output

Technical Specifications

Model Architecture and Objective

Architecture: The model implements the Video Joint Embedding Predictive Architecture (V-JEPA2) adapted for 3D medical imaging with the following components:

  • Encoder: Vision Transformer (ViT-L) adapted for 3D with 24 transformer layers, 1024 hidden dimensions, and 16 attention heads
  • Predictor: Lightweight transformer with 12 layers, 384 hidden dimensions, and 12 attention heads
  • 3D Patch Embedding: Converts CT volume patches into tokens using 16×16×16 spatiotemporal patches (3D tubelets)
  • Positional Encoding: RoPE (Rotary Position Embedding) for spatial position encoding across all three dimensions
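
To check these architectural settings on the released checkpoint, you can load its configuration; printing the whole config avoids guessing attribute names, which depend on the remote code:

from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "standardmodelbio/smb-vision-vjepa2-vitl-384-256",
    trust_remote_code=True,
)
print(config)  # hidden size, layer count, patch/tubelet size, image size, etc.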

Objective: Self-supervised learning through masked 3D volume modeling:

  • Random 3D regions (patches) of the input CT volume are masked
  • The encoder processes the visible context from the remaining volume
  • The predictor learns to reconstruct features of masked regions
  • This approach learns rich 3D spatial representations without requiring labeled medical data
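
The snippet below is a conceptual illustration of this objective only, not the actual training code: random tensors stand in for the predictor output and the target features of masked regions, and the 75% mask ratio and L1 loss are assumptions:

import torch
import torch.nn.functional as F

num_tokens, hidden = 9216, 1024       # token count for a full 384x384x256 volume
mask = torch.rand(num_tokens) < 0.75  # hypothetical mask ratio
num_masked = int(mask.sum())

predicted_features = torch.randn(1, num_masked, hidden)  # stand-in for predictor output
target_features = torch.randn(1, num_masked, hidden)     # stand-in for features of masked regions

loss = F.l1_loss(predicted_features, target_features)    # JEPA-style feature regression (loss choice assumed)
print(f"masked tokens: {num_masked}, loss: {loss.item():.4f}")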

Key Features:

  • Flash Attention 2.0 for efficient computation on large 3D volumes
  • Gradient checkpointing support for memory efficiency during training
  • Flexible input resolution and slice count for different CT protocols
  • Attentive pooling for volume-level representations
  • Optimized for single-channel medical imaging data (Hounsfield Units)
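
These efficiency features can typically be requested through the standard transformers hooks, as sketched below; whether the remote code honors each option (in particular attn_implementation="flash_attention_2", which also requires the flash-attn package) is an assumption:

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "standardmodelbio/smb-vision-vjepa2-vitl-384-256",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,               # lower memory for large 3D volumes
    attn_implementation="flash_attention_2",  # assumption: supported by the remote code
)
model.gradient_checkpointing_enable()  # trade compute for memory during fine-tuning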