SMB Vision V-JEPA2 ViT-L 384x384x256
SMB Vision V-JEPA2 is a 3D medical imaging model based on the Video Joint Embedding Predictive Architecture (V-JEPA2). This model is specifically designed for processing 3D CT (Computed Tomography) volumes and learning rich spatial representations through self-supervised learning.
The model uses a Vision Transformer Large (ViT-L) architecture adapted for 3D volumes with 384×384 spatial resolution and supports up to 256 slices per CT volume.
Model Details
Model Description
SMB Vision V-JEPA2 is a self-supervised 3D medical imaging model that learns spatial representations by predicting masked regions in CT volumes. The model consists of two main components:
- Encoder: A Vision Transformer adapted for 3D that processes CT volume inputs and generates rich feature representations
- Predictor: A module that learns to predict masked portions of the CT volume from context regions
Key specifications:
- Architecture: Vision Transformer Large (ViT-L) adapted for 3D medical imaging with 24 layers and 1024 hidden dimensions
- Input Resolution: 384×384 pixels per slice
- Volume Capacity: Up to 256 slices per CT volume
- Patch Size: 16×16 pixels with 16-slice depth patches (3D tubelets)
- Parameters: ~307M parameters in the encoder, ~125M in the predictor
- Input Channels: Single channel (Hounsfield Units from CT scans)
- Developed by: StandardModelBio Inc.
- Model type: 3D Medical Imaging / Self-Supervised Learning
- Architecture: Video Joint Embedding Predictive Architecture (V-JEPA2) adapted for medical volumes
- License: [More Information Needed]
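The architecture hyperparameters can also be read directly from the checkpoint's configuration. The snippet below is a minimal sketch; the attribute names (hidden_size, num_hidden_layers) follow the standard V-JEPA2 config in transformers and may differ in this custom implementation:
from transformers import AutoConfig

# Load the configuration shipped with the checkpoint (custom code required)
config = AutoConfig.from_pretrained(
    "standardmodelbio/smb-vision-vjepa2-vitl-384-256",
    trust_remote_code=True,
)

# Attribute names are assumptions based on the standard V-JEPA2 config
print(config.hidden_size)        # expected: 1024
print(config.num_hidden_layers)  # expected: 24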
Uses
Direct Use
The model can be used directly for:
- Feature Extraction: Extract rich 3D spatial features from CT volumes for downstream medical imaging tasks
- Medical Image Analysis: Analyze CT scan content through learned representations
- Self-supervised Learning: Use as a foundation for transfer learning on 3D medical imaging tasks
- Research: Study 3D medical image representation learning and transformer architectures for volumetric data
Downstream Use
The model can be fine-tuned or used as a feature extractor for various medical imaging tasks:
- Disease Detection: Classify pathologies and abnormalities in CT scans
- Organ Segmentation: Segment anatomical structures in 3D volumes
- Volumetric Analysis: Analyze 3D spatial patterns and relationships in medical data
- Medical Image Classification: Categorize CT volumes by anatomy, pathology, or imaging protocol
- Representation Learning: Use as a backbone for other 3D medical imaging models
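As an illustration of downstream use, the sketch below attaches a small linear classifier to mean-pooled encoder features. The classifier head, pooling strategy, and number of classes are hypothetical and not part of the released checkpoint:
import torch
import torch.nn as nn
from transformers import AutoModel

# Hypothetical downstream classifier built on top of the encoder
class CTVolumeClassifier(nn.Module):
    def __init__(self, backbone, hidden_size=1024, num_classes=2):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(hidden_size, num_classes)  # hypothetical task head

    def forward(self, ct_volume):
        # Encoder tokens: (batch_size, sequence_length, hidden_size)
        tokens = self.backbone.encoder(ct_volume).last_hidden_state
        pooled = tokens.mean(dim=1)  # simple mean pooling over 3D tokens
        return self.head(pooled)

backbone = AutoModel.from_pretrained(
    "standardmodelbio/smb-vision-vjepa2-vitl-384-256",
    trust_remote_code=True,
)
classifier = CTVolumeClassifier(backbone)
logits = classifier(torch.randn(1, 128, 1, 384, 384))  # dummy CT volume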
How to Get Started with the Model
Basic Usage
Preprocessing Requirements
Install the required dependencies:
pip install "torch==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1"
pip install transformers monai nibabel
Working with NIfTI Files
Here's a complete example of how to load and process NIfTI files for use with the VJEPA2 model:
import nibabel as nib
import numpy as np
import torch
from monai.transforms import (
    CenterSpatialCropd,
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    MapTransform,
    Orientationd,
    ScaleIntensityRanged,
    Spacingd,
    SpatialPadd,
    ToTensord,
)
from transformers import AutoModel


class PermuteImage(MapTransform):
    """Permute the dimensions of the image for VJEPA2 input format"""

    def __init__(self, keys=["image"], allow_missing_keys=False):
        MapTransform.__init__(self, keys, allow_missing_keys)

    def __call__(self, data):
        # VJEPA2 expects: (depth, num_channels, height, width)
        data["image"] = data["image"].permute(3, 0, 1, 2)
        return data
# Define preprocessing pipeline for CT volumes
ct_transforms = Compose([
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    Orientationd(keys=["image"], axcodes="RAS"),
    Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.5), mode=("bilinear")),
    ScaleIntensityRanged(
        keys=["image"],
        a_min=-1000, a_max=1000,  # CT Hounsfield unit range
        b_min=0.0, b_max=1.0,
        clip=True,
    ),
    SpatialPadd(keys=["image"], spatial_size=[384, 384, 256]),
    CenterSpatialCropd(roi_size=[384, 384, 256], keys=["image"]),
    ToTensord(keys=["image"]),
    PermuteImage(),
])
# Load and preprocess NIfTI file
nifti_path = "path/to/your/ct_scan.nii.gz"
data_dict = {"image": nifti_path}
transformed_data = ct_transforms(data_dict)

# Prepare input tensor
input_tensor = transformed_data["image"].unsqueeze(0)  # Add batch dimension
print(f"Input shape: {input_tensor.shape}")  # Expected: [1, 256, 1, 384, 384]

# Load model and run inference
model = AutoModel.from_pretrained(
    "standardmodelbio/smb-vision-vjepa2-vitl-384-256",
    trust_remote_code=True,
)
model.eval()

# Extract features using the encoder
with torch.no_grad():
    output = model.encoder(input_tensor)
    features = output.last_hidden_state

print(f"Output features shape: {features.shape}")  # Expected: [1, 9216, 1024]
Key Points for NIfTI Processing
- Input Format: The model expects CT volumes in Hounsfield Units (HU)
- Spatial Resolution: Images are resampled to 1.0×1.0×1.5 mm spacing
- Normalization: HU values are clipped to [-1000, 1000] and normalized to [0, 1]
- Volume Size: Input volumes are padded/cropped to 384×384×256 voxels
- Orientation: Volumes are reoriented to RAS (Right-Anterior-Superior) convention
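Before running the pipeline, it can help to sanity-check that a scan roughly matches these assumptions; a minimal sketch using standard nibabel metadata:
import nibabel as nib
import numpy as np

# Inspect a NIfTI CT scan before preprocessing
img = nib.load("path/to/your/ct_scan.nii.gz")
data = img.get_fdata()

print("Voxel grid shape:", data.shape)
print("Voxel spacing (mm):", img.header.get_zooms())          # compare against the 1.0 x 1.0 x 1.5 target
print("Intensity range:", np.min(data), "to", np.max(data))   # should roughly span the CT HU range
print("Orientation:", nib.aff2axcodes(img.affine))            # reoriented to RAS by the pipeline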
For a quick check without NIfTI data, the encoder can also be run directly on a randomly generated volume:
from transformers import AutoModel
import torch

# Load the model
model = AutoModel.from_pretrained(
    "standardmodelbio/smb-vision-vjepa2-vitl-384-256",
    trust_remote_code=True,  # required to load the custom model code
)

# Prepare an input tensor for a CT volume
# Shape: (batch_size, num_slices, num_channels, height, width)
# Note: num_channels=1 for single-channel CT data (Hounsfield Units)
ct_volume = torch.randn(1, 128, 1, 384, 384)

# Get encoder features
output = model.encoder(ct_volume)

# The output contains rich 3D spatial features
print(f"Output shape: {output.last_hidden_state.shape}")
# Expected output shape: torch.Size([1, sequence_length, 1024])
Input Format
The model expects input CT volumes in the following format:
- Shape: (batch_size, num_slices, num_channels, height, width)
- Data type: torch.FloatTensor
- Slices: Up to 256 slices per CT volume
- Resolution: 384×384 pixels per slice
- Channels: Single channel (Hounsfield Units from CT scans)
- Data Range: Typical CT Hounsfield Unit range (-1000 to +3000 HU)
Output Format
The encoder output contains:
- last_hidden_state: Main feature representations with shape (batch_size, sequence_length, hidden_size)
- hidden_states: Intermediate layer outputs (if requested)
- attentions: Attention weights (if requested)
The sequence length depends on the input CT volume dimensions and is calculated as:
sequence_length = (num_slices // tubelet_size) * (height // patch_size) * (width // patch_size)
For a 128×384×384 CT volume with 16×16×16 patches: sequence_length = (128//16) × (384//16) × (384//16) = 8 × 24 × 24 = 4,608
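A quick sketch of this calculation in code, using the patch and tubelet sizes stated above:
# Token count for a 128-slice, 384x384 CT volume with 16x16x16 patches
tubelet_size, patch_size = 16, 16
num_slices, height, width = 128, 384, 384

sequence_length = (num_slices // tubelet_size) * (height // patch_size) * (width // patch_size)
print(sequence_length)  # 8 * 24 * 24 = 4608

# At the full 256-slice capacity the token count doubles: 16 * 24 * 24 = 9216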
Advanced Usage
For more advanced usage including the predictor module and masked modeling:
# Full model forward pass with predictor
output = model(
    pixel_values_videos=ct_volume,
    skip_predictor=False,  # Set to True to skip predictor
    output_hidden_states=True,
    output_attentions=True,
)

# Access different components
encoder_features = output.last_hidden_state
predictor_output = output.predictor_output
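When only encoder features are needed, the predictor pass can be skipped; a brief sketch, assuming the output exposes the same last_hidden_state field:
# Encoder-only forward pass for plain feature extraction
with torch.no_grad():
    encoder_only = model(
        pixel_values_videos=ct_volume,
        skip_predictor=True,  # avoids the predictor computation
    )
    volume_tokens = encoder_only.last_hidden_state  # (batch_size, sequence_length, 1024)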
Technical Specifications
Model Architecture and Objective
Architecture: The model implements the Video Joint Embedding Predictive Architecture (V-JEPA2) adapted for 3D medical imaging with the following components:
- Encoder: Vision Transformer (ViT-L) adapted for 3D with 24 transformer layers, 1024 hidden dimensions, and 16 attention heads
- Predictor: Lightweight transformer with 12 layers, 384 hidden dimensions, and 12 attention heads
- 3D Patch Embedding: Converts CT volume patches into tokens using 16×16×16 spatiotemporal patches (3D tubelets)
- Positional Encoding: RoPE (Rotary Position Embedding) for spatial position encoding across all three dimensions
Objective: Self-supervised learning through masked 3D volume modeling:
- Random 3D regions (patches) of the input CT volume are masked
- The encoder processes the visible context from the remaining volume
- The predictor learns to reconstruct features of masked regions
- This approach learns rich 3D spatial representations without requiring labeled medical data
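The sketch below illustrates the general JEPA-style objective in simplified form; the actual masking strategy, target features, and loss used to train this checkpoint are not specified here, so treat it as a conceptual outline rather than the released training code:
import torch
import torch.nn.functional as F

def jepa_style_loss(predicted_features, target_features, mask):
    """Illustrative masked feature-prediction loss (assumed L1; not the released training code).

    predicted_features: (batch, num_tokens, dim) predictor output
    target_features:    (batch, num_tokens, dim) encoder features of the full volume
    mask:               (batch, num_tokens) boolean, True where the volume was masked
    """
    pred = predicted_features[mask]          # predictions at masked positions
    target = target_features[mask].detach()  # targets without gradient flow
    return F.l1_loss(pred, target)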
Key Features:
- Flash Attention 2.0 for efficient computation on large 3D volumes
- Gradient checkpointing support for memory efficiency during training
- Flexible input resolution and slice count for different CT protocols
- Attentive pooling for volume-level representations
- Optimized for single-channel medical imaging data (Hounsfield Units)
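The attentive pooler's interface is not documented here; as a simple, hedged stand-in, a volume-level embedding can be obtained by mean-pooling the encoder tokens:
import torch

# Volume-level embedding via mean pooling of encoder tokens
# (a simple stand-in for the model's attentive pooler, whose API is not shown here)
with torch.no_grad():
    tokens = model.encoder(ct_volume).last_hidden_state  # (batch_size, sequence_length, 1024)
    volume_embedding = tokens.mean(dim=1)                # (batch_size, 1024)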