Fine-Tuning SigLIP2 for Image Classification

Community Article Published March 5, 2025


  • SigLIP 2: Multilingual Vision-Language Encoders with Improved Semantic Understanding, Localization, and Dense Features

  • SigLIP 2 introduces new multilingual vision-language encoders that build on the success of the original SigLIP. This second iteration extends the original image-text training objective by integrating several independently developed techniques into a unified recipe. These include captioning-based pretraining and self-supervised losses such as self-distillation and masked prediction.

  • The script below fine-tunes SigLIP 2 foundation models on a single-label image classification problem.


Fine-tuning Notebook

Notebook | Task | Description | Link
SigLIP2 Finetune ImageClassification | Image Classification | Fine-tune SigLIP2 on a custom image classification dataset with automated train/test split, preprocessing, and training pipeline setup. | Open Notebook

GitHub Gist (Code Snippet)

GitHub Gist: https://gist.github.com/PRITHIVSAKTHIUR/e3c67b9fbcaf397b6639b018d457fd08

# Fine-Tuning SigLIP2 for Image Classification | Script prepared by: hf.co/prithivMLmods
#
# Dataset with Train & Test Splits
#
# In this configuration, the MNIST train split is divided into separate training and testing subsets (60/40, stratified by label). This setup is ideal for straightforward supervised learning workflows.
#
# Training Phase:
# The model is fine-tuned exclusively on the train split, where each image is paired with its corresponding class label.
#
# Evaluation Phase:
# After training, the model's performance is assessed on the test split to measure generalization accuracy.

# 1. Install the packages

# %%capture
# Install required libraries for fine-tuning SigLIP 2
# 'evaluate' - for computing evaluation metrics like accuracy
# 'datasets==3.2.0' - for handling train/test splits and loading image datasets
# 'accelerate' - for efficient multi-GPU / multi-CPU training
# 'transformers==4.50.0' - for SigLIP 2 and other transformer-based models
# 'torchvision' - for image preprocessing and augmentations
# 'huggingface-hub==0.31.0' - for model and dataset uploads/downloads
# 'hf_xet' - optional Xet storage backend for faster transfers of large files to and from the Hugging Face Hub
# !pip install evaluate datasets==3.2.0 accelerate transformers==4.50.0 torchvision huggingface-hub==0.31.0 hf_xet
#  Hold tight, this will take around 2-3 minutes.

# If you are training outside Google Colab, also install imbalanced-learn:
# !pip install imbalanced-learn

# --------------------------------------------------------------------------

# To demonstrate the fine-tuning process, we will use the MNIST dataset — a classic benchmark for image classification.
# MNIST consists of 28x28 grayscale images of handwritten digits (0–9), making it ideal for testing model training pipelines.
# We will load the dataset directly from the Hugging Face Hub using the 'datasets' library.
# Dataset link: https://huggingface.co/datasets/ylecun/mnist

# --------------------------------------------------------------------------

# 2. Import modules required for data manipulation, model training, and image preprocessing.

import warnings
warnings.filterwarnings("ignore")

import gc
import numpy as np
import pandas as pd
import itertools
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, f1_score
from imblearn.over_sampling import RandomOverSampler
import evaluate
from datasets import Dataset, Image, ClassLabel
from transformers import (
    TrainingArguments,
    Trainer,
    DefaultDataCollator
)

from transformers import AutoImageProcessor
from transformers import SiglipForImageClassification
from transformers.image_utils import load_image

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomRotation,
    RandomResizedCrop,
    RandomHorizontalFlip,
    RandomAdjustSharpness,
    Resize,
    ToTensor
)

from PIL import ImageFile
# Let PIL load truncated (partially corrupted) image files instead of raising an error
ImageFile.LOAD_TRUNCATED_IMAGES = True

# 3. Loading and Preparing the Dataset

from datasets import load_dataset
dataset = load_dataset("ylecun/mnist", split="train")

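# MNIST stores decoded PIL images rather than file paths, so str(example['image']) would only
# produce an object description. Reading the label column directly is also much faster than
# iterating over (and decoding) all 60,000 images.
labels = dataset["label"]
print(len(labels))

# 4. Creating a DataFrame and Balancing the Dataset & Working with a Subset of Labels

# Pair each row index with its label so that resampled rows can be traced back to the dataset.
df = pd.DataFrame({"image_idx": np.arange(len(labels)), "label": labels})
print(df.shape)
print(df.head())
print(df['label'].value_counts().sort_index())

# RandomOverSampler duplicates rows of the smaller classes until every class matches the largest.
# The balanced frame is only inspected here: the steps below train on the original `dataset`,
# and oversampling before the train/test split would place duplicate images in both splits.
ros = RandomOverSampler(random_state=83)
df_resampled, y_resampled = ros.fit_resample(df[['image_idx']], df['label'])
print(y_resampled.value_counts().sort_index())
del df_resampled, y_resampled
gc.collect()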
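To make the balancing step concrete, here is a small pure-Python sketch of random oversampling, the technique `RandomOverSampler` implements: rows from the smaller classes are duplicated at random until every class has as many rows as the largest one. The helper `random_oversample` and the toy labels are hypothetical illustrations, not part of imbalanced-learn.

```python
import random
from collections import Counter


def random_oversample(indices, labels, seed=83):
    """Duplicate randomly chosen rows of the smaller classes until every class is equally large.

    Hypothetical helper for illustration; imblearn's RandomOverSampler is the production tool.
    """
    rng = random.Random(seed)
    rows_by_class = {}
    for idx, label in zip(indices, labels):
        rows_by_class.setdefault(label, []).append(idx)

    target = max(len(rows) for rows in rows_by_class.values())
    out_indices, out_labels = [], []
    for label, rows in rows_by_class.items():
        # Keep every original row, then draw extra rows with replacement up to the target count.
        extra = [rng.choice(rows) for _ in range(target - len(rows))]
        out_indices.extend(rows + extra)
        out_labels.extend([label] * target)
    return out_indices, out_labels


# Toy data: class 0 has 4 rows, class 1 has 2 rows, class 2 has 1 row.
toy_idx = list(range(7))
toy_lab = [0, 0, 0, 0, 1, 1, 2]
new_idx, new_lab = random_oversample(toy_idx, toy_lab)
print(Counter(new_lab))  # each class now has 4 rows
```

Because oversampling only repeats existing rows, it belongs on the training portion after the train/test split; otherwise copies of the same image can land in both splits and inflate test accuracy.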

labels_subset = labels[:5]
print(labels_subset)

# For a custom dataset, list the class names in label-id order, e.g.:
# labels_list = ['example_label_0', 'example_label_1', ..., 'example_label_n-1']
labels_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

label2id, id2label = {}, {}
for i, label in enumerate(labels_list):
    label2id[label] = i
    id2label[i] = label

ClassLabels = ClassLabel(num_classes=len(labels_list), names=labels_list)

print("Mapping of IDs to Labels:", id2label, '\n')
print("Mapping of Labels to IDs:", label2id)
     
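# 5. Mapping and Casting Labels

# With batched=True, examples['label'] is a list; str2int converts each class name to its integer id.
# MNIST labels are already the integers named '0'-'9', so this step matters mainly for string labels.
def map_label2id(examples):
    examples['label'] = ClassLabels.str2int(examples['label'])
    return examples

dataset = dataset.map(map_label2id, batched=True)
# stratify_by_column requires a ClassLabel column, so cast before splitting.
dataset = dataset.cast_column('label', ClassLabels)

# 6. Splitting the Dataset

# Stratified 60/40 split: each digit keeps the same share in the train and test subsets.
dataset = dataset.train_test_split(test_size=0.4, shuffle=True, stratify_by_column="label")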

train_data = dataset['train']
test_data = dataset['test']
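`stratify_by_column="label"` asks `train_test_split` to preserve each class's share in both subsets. A minimal pure-Python sketch of the idea follows; the helper `stratified_split` is hypothetical, and its per-class rounding is an assumption that may differ from how the `datasets` library rounds.

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, test_size=0.4, seed=0):
    """Split row indices so that each class sends about `test_size` of its rows to the test set.

    Illustrative sketch: per-class rounding is an assumption and may differ from `datasets`.
    """
    rng = random.Random(seed)
    rows_by_class = defaultdict(list)
    for i, y in enumerate(labels):
        rows_by_class[y].append(i)

    train_idx, test_idx = [], []
    for rows in rows_by_class.values():
        rng.shuffle(rows)                         # shuffle within the class before cutting
        n_test = round(len(rows) * test_size)
        test_idx.extend(rows[:n_test])
        train_idx.extend(rows[n_test:])
    return sorted(train_idx), sorted(test_idx)


toy_labels = [0] * 50 + [1] * 30 + [2] * 20
train_idx, test_idx = stratified_split(toy_labels)
print(len(train_idx), len(test_idx))                     # 60 40
print(Counter(toy_labels[i] for i in test_idx))          # 20, 12 and 8 rows of classes 0, 1, 2
```

The 50/30/20 class proportions carry over to the test subset (20/12/8), so per-class metrics on the test split stay representative.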

# 7. Setting Up the Model and Processor


model_str = "google/siglip2-base-patch16-224"
processor = AutoImageProcessor.from_pretrained(model_str)

# Extract preprocessing parameters
image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size["height"]

# 8. Defining Data Transformations

# Define training transformations
_train_transforms = Compose([
    Resize((size, size)),
    RandomRotation(90),
    RandomAdjustSharpness(2),
    ToTensor(),
    Normalize(mean=image_mean, std=image_std)
])

# Define validation transformations
_val_transforms = Compose([
    Resize((size, size)),
    ToTensor(),
    Normalize(mean=image_mean, std=image_std)
])
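Here is what `ToTensor` followed by `Normalize` does to a pixel: `ToTensor` scales 8-bit values to [0, 1] in channels-first layout, and `Normalize` then computes (x - mean) / std per channel. The NumPy sketch below assumes `image_mean = image_std = [0.5, 0.5, 0.5]`, values SigLIP image processors commonly use; check `processor.image_mean` and `processor.image_std` for the actual checkpoint. It also mimics the grayscale-to-RGB replication that `image.convert("RGB")` performs on MNIST digits.

```python
import numpy as np

# Assumed preprocessing constants; read the real ones from processor.image_mean / image_std.
image_mean = np.array([0.5, 0.5, 0.5])
image_std = np.array([0.5, 0.5, 0.5])

# A toy 2x2 grayscale "digit" with 8-bit intensities.
gray = np.array([[0, 255], [128, 64]], dtype=np.uint8)

# convert("RGB") on a mode-"L" image copies the single channel into R, G and B.
rgb = np.repeat(gray[None, :, :], 3, axis=0)      # (3, H, W), channels first like ToTensor

# ToTensor: uint8 [0, 255] -> float [0, 1]
scaled = rgb.astype(np.float32) / 255.0

# Normalize: (x - mean) / std, broadcast over each channel
normalized = (scaled - image_mean[:, None, None]) / image_std[:, None, None]

print(normalized[0])
```

With these constants, black (0) maps to -1.0 and white (255) to 1.0, so the encoder sees inputs centered on zero.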

# 9. Applying Transformations to the Dataset


# Apply transformations to dataset
def train_transforms(examples):
    examples['pixel_values'] = [_train_transforms(image.convert("RGB")) for image in examples['image']]
    return examples

def val_transforms(examples):
    examples['pixel_values'] = [_val_transforms(image.convert("RGB")) for image in examples['image']]
    return examples

# set_transform applies the transforms on the fly whenever examples are accessed, so random augmentations are re-sampled every epoch
train_data.set_transform(train_transforms)
test_data.set_transform(val_transforms)

# 10. Creating a Data Collator

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example['label'] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}
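The collator turns a list of per-example dicts into one batch: image tensors of shape (3, H, W) are stacked into (B, 3, H, W), and the integer labels become a 1-D array. The sketch below uses NumPy stand-ins for the torch tensors so the shapes are easy to check without a GPU; `np.stack` plays the role of `torch.stack`, and `collate_fn_numpy` is a hypothetical name.

```python
import numpy as np


def collate_fn_numpy(examples):
    """NumPy analogue of collate_fn (illustration only)."""
    pixel_values = np.stack([ex["pixel_values"] for ex in examples])      # (B, 3, H, W)
    labels = np.array([ex["label"] for ex in examples], dtype=np.int64)   # (B,)
    return {"pixel_values": pixel_values, "labels": labels}


# Four fake examples at the 224x224 resolution of the base checkpoint.
fake_examples = [
    {"pixel_values": np.zeros((3, 224, 224), dtype=np.float32), "label": digit}
    for digit in (3, 1, 4, 1)
]
batch = collate_fn_numpy(fake_examples)
print(batch["pixel_values"].shape, batch["labels"])
```

Note that the collator renames `label` to `labels`, the keyword argument the model's `forward` uses to compute the loss.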

# 11. Initializing the Model

# Passing the label maps to from_pretrained stores them in the model config
model = SiglipForImageClassification.from_pretrained(
    model_str,
    num_labels=len(labels_list),
    id2label=id2label,
    label2id=label2id,
)

# Number of trainable parameters, in millions
print(model.num_parameters(only_trainable=True) / 1e6)
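The classification model adds a small, freshly initialized head on top of the pretrained vision encoder. The NumPy sketch below shows what such a head computes, assuming the patch-token outputs are mean-pooled before a linear layer (our understanding of `SiglipForImageClassification`; check the `transformers` source for your version) and that single-label training minimizes softmax cross-entropy. The sizes are assumptions for a base-size, patch-16, 224 px checkpoint; read the real ones from `model.config.vision_config`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed sizes: base-size encoder (hidden size 768), 224 px input with 16 px patches
# -> 14 x 14 = 196 patch tokens, and 10 digit classes.
num_tokens, hidden_size, num_labels = 196, 768, 10

# Stand-in for the encoder's last hidden state for a batch of 2 images.
last_hidden_state = rng.normal(size=(2, num_tokens, hidden_size))

# Classification head: mean-pool the patch tokens, then apply a linear layer.
W = rng.normal(scale=0.02, size=(hidden_size, num_labels))
b = np.zeros(num_labels)
pooled = last_hidden_state.mean(axis=1)                 # (2, hidden_size)
logits = pooled @ W + b                                  # (2, num_labels)

# Softmax cross-entropy against integer class ids (the single-label training loss).
toy_labels = np.array([7, 2])
shifted = logits - logits.max(axis=1, keepdims=True)     # subtract the max for numerical stability
log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
loss = -log_probs[np.arange(len(toy_labels)), toy_labels].mean()

# The freshly initialized head has hidden_size * num_labels weights plus num_labels biases.
print(logits.shape, float(loss), W.size + b.size)
```

Only the 7,690 head parameters in this sketch start from scratch; everything else in the trainable-parameter count comes from the pretrained encoder and is fine-tuned.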

# 12. Defining Metrics and the Compute Function

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions = eval_pred.predictions
    label_ids = eval_pred.label_ids

    predicted_labels = predictions.argmax(axis=1)
    acc_score = accuracy.compute(predictions=predicted_labels, references=label_ids)['accuracy']

    return {
        "accuracy": acc_score
    }

# 13. Setting Up Training Arguments

args = TrainingArguments(
    output_dir="siglip2-image-classification/",
    logging_dir='./logs',
    eval_strategy="epoch",  # the older name evaluation_strategy is rejected by recent transformers releases
    learning_rate=2e-4,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
    weight_decay=0.02,
    warmup_steps=50,
    remove_unused_columns=False,
    save_strategy='epoch',
    load_best_model_at_end=True,
    save_total_limit=4,
    report_to="none"
)
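These arguments leave `lr_scheduler_type` at its default, `"linear"`: the learning rate rises linearly from 0 to `learning_rate` over `warmup_steps`, then decays linearly to 0 at the final step. The sketch below reproduces that formula in plain Python and works out the step count, assuming a single GPU, no gradient accumulation, and 36,000 training examples (60% of the 60,000-image MNIST train split after `test_size=0.4`). The function name is hypothetical.

```python
import math


def linear_warmup_decay_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate after `step` optimizer updates under linear warmup + linear decay."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Values from the TrainingArguments above; the example count assumes the 60/40 split of MNIST train.
num_train_examples = 36_000
batch_size = 32
epochs = 2
steps_per_epoch = math.ceil(num_train_examples / batch_size)   # 1125
total_steps = steps_per_epoch * epochs                          # 2250

for step in (0, 25, 50, 1150, 2250):
    print(step, linear_warmup_decay_lr(step, 2e-4, 50, total_steps))
```

Under these assumptions the 50 warmup steps cover roughly 2% of training; with more GPUs the effective batch grows and `steps_per_epoch` shrinks proportionally.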

# 14. Initializing the Trainer



trainer = Trainer(
    model,
    args,
    train_dataset=train_data,
    eval_dataset=test_data,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    processing_class=processor,  # replaces the deprecated tokenizer= argument
)
     
# 15. Evaluating, Training, and Predicting


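# Baseline: evaluate the randomly initialized classification head before training
trainer.evaluate()

# Fine-tune on the training split
trainer.train()

# Evaluate again; with load_best_model_at_end=True this is the best checkpoint by eval loss
trainer.evaluate()

# Collect test-split predictions for the metrics below
outputs = trainer.predict(test_data)
print(outputs.metrics)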

# 16. Computing Additional Metrics and Plotting the Confusion Matrix

y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(1)

def plot_confusion_matrix(cm, classes, title='Confusion Matrix', cmap=plt.cm.Reds, figsize=(10, 8)):

    plt.figure(figsize=figsize)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    fmt = '.0f'
    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()

# Use new names so the `accuracy` metric object used by compute_metrics is not overwritten
acc = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='macro')

print(f"Accuracy: {acc:.4f}")
print(f"Macro F1 Score: {f1:.4f}")

if len(labels_list) <= 150:
    cm = confusion_matrix(y_true, y_pred)
    plot_confusion_matrix(cm, labels_list, figsize=(8, 6))

print()
print("Classification report:")
print()
print(classification_report(y_true, y_pred, target_names=labels_list, digits=4))
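For reference, this is how a confusion matrix and a macro-averaged F1 score are computed, written out in NumPy on a tiny hand-checkable example. Rows are true labels and columns are predictions, matching scikit-learn's `confusion_matrix` convention; macro F1 is the unweighted mean of per-class F1 scores. The helper names are hypothetical.

```python
import numpy as np


def confusion_matrix_np(y_true, y_pred, num_classes):
    """cm[i, j] counts examples whose true class is i and predicted class is j."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)
    return cm


def macro_f1_from_cm(cm):
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)     # column sums: how often each class was predicted
    actual = cm.sum(axis=1)        # row sums: how often each class truly occurs
    denom = predicted + actual
    # F1 = 2TP / (2TP + FP + FN) = 2TP / (predicted + actual); treat 0/0 as 0
    f1 = np.divide(2 * tp, denom, out=np.zeros_like(tp), where=denom > 0)
    return f1.mean()


y_true_toy = np.array([0, 0, 1, 1, 2, 2])
y_pred_toy = np.array([0, 1, 1, 1, 2, 0])
cm_toy = confusion_matrix_np(y_true_toy, y_pred_toy, num_classes=3)
print(cm_toy)
print("accuracy:", np.trace(cm_toy) / cm_toy.sum())
print("macro F1:", macro_f1_from_cm(cm_toy))
```

Here the per-class F1 scores are 0.5, 0.8 and 2/3, so macro F1 is 59/90 (about 0.656) while accuracy is 4/6; macro F1 weights every digit equally regardless of how often it occurs.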

# 17. Saving the Model and Uploading to Hugging Face Hub

trainer.save_model() # Trained Model Uploaded to Hugging Face: https://huggingface.co/prithivMLmods/Mnist-Digits-SigLIP2
     
# 18. Login to Hugging Face Hub
#
# Use the Hugging Face Hub API to authenticate your session in the notebook.
# This allows you to push models, datasets, and other assets directly to your account.

#from huggingface_hub import notebook_login, HfApi
#notebook_login()

# 19. Upload the Model to Hugging Face Hub
#
# Once fine-tuning is complete, you can upload the trained SigLIP 2 model to the Hugging Face Hub.
# This enables sharing, versioning, and easy access for future inference or collaboration.

#api = HfApi()
#repo_id = "prithivMLmods/Mnist-Digits-SigLIP2"  # use your own username/repo_name

#api.upload_folder(
#    folder_path="siglip2-image-classification/",  # Local folder containing the fine-tuned model
#    path_in_repo=".",                             # Path inside the Hugging Face repository
#    repo_id=repo_id,                              # Repository ID (username/repo_name)
#    repo_type="model",                            # Specify this is a model repository
#    revision="main"                               # Branch or revision to push to
#)

#print(f"Model uploaded to https://huggingface.co/{repo_id}")

GitHub Gist (Code Snippet)

GitHub Gist (Fine-Tuning ViT for Image Classification): https://gist.github.com/PRITHIVSAKTHIUR/cad219cb0d3d07af573c979667e3afd3

SigLIP2 Paper

Title | Link (Abstract) | Link (PDF)
SigLIP 2: Multilingual Vision-Language Encoders | arXiv:2502.14786 | PDF

Details and Benefits

SigLIP 2 keeps the Vision Transformer architecture of the original SigLIP, which makes it backward compatible: users can swap in the new model weights without overhauling their entire system. Like SigLIP, it replaces the softmax-based contrastive loss used by CLIP with a pairwise sigmoid loss, which treats each image-text pair in a batch as an independent match/non-match decision. The additional objectives described below then add local, dense features on top of this global image-text alignment.
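The pairwise sigmoid loss can be sketched in a few lines of NumPy. Each of the n x n image-text pairs in a batch is scored as a binary decision: matching pairs (the diagonal) should get a high logit, all others a low one. This sketch assumes L2-normalized embeddings and the learnable temperature t and bias b from the original SigLIP paper, fixed here at their initial values t = 10 and b = -10; the function names are illustrative.

```python
import numpy as np


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x))
    return -np.logaddexp(0.0, -x)


def sigmoid_loss(img_emb, txt_emb, t=10.0, b=-10.0):
    """Pairwise sigmoid loss over a batch of n image/text pairs (row i of each matches)."""
    img = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
    txt = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)
    logits = t * img @ txt.T + b                  # (n, n) pair logits
    z = 2.0 * np.eye(len(img)) - 1.0              # +1 on the diagonal, -1 elsewhere
    return -log_sigmoid(z * logits).sum() / len(img)


rng = np.random.default_rng(0)
images = rng.normal(size=(4, 8))
texts_aligned = images + 0.01 * rng.normal(size=(4, 8))   # near-copies: pairs match
texts_shuffled = texts_aligned[[1, 2, 3, 0]]              # same texts, wrong pairing

print(sigmoid_loss(images, texts_aligned), sigmoid_loss(images, texts_shuffled))
```

Because no softmax normalizes across the batch, each pair's term can be computed independently; the negative bias offsets the imbalance of n^2 - n negative pairs against n positives.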

In addition to the sigmoid loss, SigLIP 2 adds a decoder-based loss (following LocCa) that trains the model on image captioning, referring-expression prediction, and grounded captioning. This improves region-specific localization and performance on dense prediction tasks; the decoder is used only during training. The vision encoder also uses a MAP head (multi-head attention pooling), which pools the patch features into a single image representation through a learned attention query.
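Attention pooling replaces a fixed pooling rule with a learned query that attends over all patch tokens. Below is a single-head NumPy sketch of the idea with randomly initialized projections; the real MAP head is multi-head and adds normalization and a residual MLP block, which are omitted here. The function name and sizes are assumptions.

```python
import numpy as np


def attention_pool(tokens, query, w_k, w_v):
    """Pool (N, D) tokens into one (D,) vector using a learned query (single head)."""
    keys = tokens @ w_k                                  # (N, D)
    values = tokens @ w_v                                # (N, D)
    scores = keys @ query / np.sqrt(query.shape[0])      # (N,) scaled dot-product scores
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                             # softmax over tokens
    return weights @ values, weights


rng = np.random.default_rng(0)
num_tokens, dim = 196, 64                                # toy sizes (assumption)
tokens = rng.normal(size=(num_tokens, dim))
query = rng.normal(size=dim)                             # the learned "probe" vector
w_k = rng.normal(scale=dim ** -0.5, size=(dim, dim))
w_v = rng.normal(scale=dim ** -0.5, size=(dim, dim))

pooled, weights = attention_pool(tokens, query, w_k, w_v)
print(pooled.shape, weights.sum())
```

With a zero query every token gets equal weight and the pooling reduces to a plain mean of the values, which shows what the learned query adds: it lets the model emphasize the most informative patches.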

A key innovation in SigLIP 2 is the NaFlex variant, which supports multiple resolutions while preserving each image's native aspect ratio, all with a single checkpoint. Because images are not squashed into a fixed square, their layout stays intact, which makes this variant particularly effective for applications such as document understanding and OCR.
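The core bookkeeping behind native-aspect-ratio inputs is choosing a patch grid that roughly keeps the image's aspect ratio while staying under a maximum sequence length. The helper `choose_patch_grid` below is a hypothetical sketch of that idea, binary-searching a resize scale; it is not the actual NaFlex preprocessing in `transformers`, whose rounding rules may differ.

```python
import math


def choose_patch_grid(height, width, patch_size=16, max_num_patches=256, iters=60):
    """Return (rows, cols) of patches for the largest aspect-preserving resize within the budget.

    Hypothetical helper: binary-searches a scale s so that
    ceil(height * s / patch_size) * ceil(width * s / patch_size) <= max_num_patches.
    """
    def grid(scale):
        rows = max(1, math.ceil(height * scale / patch_size))
        cols = max(1, math.ceil(width * scale / patch_size))
        return rows, cols

    lo = 1e-6                                                    # tiny scale: a 1x1 grid always fits
    hi = 2 * max_num_patches * patch_size / min(height, width)   # large enough to exceed the budget
    for _ in range(iters):
        mid = (lo + hi) / 2
        rows, cols = grid(mid)
        if rows * cols <= max_num_patches:
            lo = mid
        else:
            hi = mid
    return grid(lo)


rows, cols = choose_patch_grid(512, 1024)        # a 1:2 landscape image
print(rows, cols, rows * cols, (rows * 16, cols * 16))
```

For the 512 x 1024 example the search settles on an 11 x 22 grid (242 patches, a 176 x 352 resize), keeping the 1:2 shape, whereas a fixed 16 x 16 square grid would distort it.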

Furthermore, self-distillation and masked prediction improve the quality of local features. By learning to predict the features of masked patches, the model attends to fine-grained details that matter for dense tasks such as segmentation and depth estimation. For the smaller models, the authors additionally apply distillation via active data curation, which further improves their performance.
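A minimal sketch of the masked-prediction setup: a random subset of patch positions is replaced by a mask token in the student's input, and the loss is computed only at those positions against target features from a teacher that sees the full image. The random linear "student", random "teacher" features, 50% masking ratio, and squared-error loss are all illustrative assumptions, not the exact SigLIP 2 objective.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, dim, mask_ratio = 196, 32, 0.5      # toy sizes; the 50% ratio is an assumption

patch_embeddings = rng.normal(size=(num_patches, dim))
mask_token = np.zeros(dim)                        # stand-in for a learned mask embedding

# Hide exactly round(mask_ratio * num_patches) randomly chosen patch positions.
num_masked = round(mask_ratio * num_patches)
is_masked = np.zeros(num_patches, dtype=bool)
is_masked[rng.choice(num_patches, size=num_masked, replace=False)] = True

# The student sees mask tokens at hidden positions; the teacher sees every patch.
student_input = np.where(is_masked[:, None], mask_token, patch_embeddings)

# Stand-ins for the networks: a random linear "student" and random "teacher" features.
w_student = rng.normal(scale=dim ** -0.5, size=(dim, dim))
student_features = student_input @ w_student
teacher_features = rng.normal(size=(num_patches, dim))

# The loss is computed only at masked positions (squared error chosen purely for illustration).
diff = student_features[is_masked] - teacher_features[is_masked]
loss = (diff ** 2).mean()
print(int(is_masked.sum()), student_input.shape, float(loss))
```

Restricting the loss to hidden positions forces the student to infer local content from surrounding context rather than copy it, which is why this objective helps dense features.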

Conclusion

SigLIP 2 is a carefully engineered advance in vision-language models. By combining established techniques (captioning-based pretraining, self-distillation, and masked prediction) with the sigmoid image-text loss, it addresses key challenges such as fine-grained localization, dense prediction, and multilingual support. Moving beyond image-text losses alone, its self-supervised objectives yield representations that capture both global semantics and local detail. The NaFlex variant's handling of native aspect ratios further broadens its use in real-world settings where preserving image layout matters.

Training on multilingual data and applying de-biasing measures extend the model to a wider range of languages and contexts, improving performance across benchmarks while addressing broader fairness concerns. Ultimately, the release of SigLIP 2 is a significant step forward for the vision-language research community: a versatile, backward-compatible family of encoders that can replace SigLIP weights in existing systems, delivers reliable performance across diverse tasks, and sets a strong baseline for future work.

Happy fine-tuning! 🤗
