Vision Transformer for STM Multi-Tip Artifact Detection

This is a fine-tuned Vision Transformer (ViT-B/32) for classifying Scanning Tunneling Microscopy (STM) images. It detects multi-tip artifacts. This common distortion arises when more than one tip apex contributes to the tunneling current, and the resulting duplicated features complicate data interpretation.

This model was developed as part of the NFFA-DI (Nano Foundries and Fine Analysis Digital Infrastructure) project, funded by the European Union's NextGenerationEU program.

Model Description

The model is a ViT-B/32 pre-trained on ImageNet-21k. It was fine-tuned to classify an STM image as either Artifact-Free or Multi-Tip Artifact.
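The worked arithmetic below covers the ViT-B/32 tokenization at the 224 × 224 input size used in the "How to Use" code. The hidden size of 768 is the standard ViT-Base value. It is assumed here, not read from this checkpoint's configuration.

```latex
\begin{aligned}
N_{\text{patch}} &= \left(\tfrac{224}{32}\right)^{2} = 7^{2} = 49, \\
L &= N_{\text{patch}} + 1 = 50 \quad \text{(patch tokens plus [CLS])}, \\
d_{\text{in}} &= 32 \times 32 \times 3 = 3072 \;\xrightarrow{\text{linear projection}}\; d_{\text{model}} = 768 .
\end{aligned}
```

The patch projection spans all three input channels. Each token therefore mixes the grayscale, amplitude, and phase values of the same 32 × 32 pixel block of each channel.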

A key feature of this model is its Fast Fourier Transform (FFT)-based preprocessing. The model's input is not a standard RGB image. It is a 3-channel tensor composed of:

  1. The grayscale STM image.
  2. The amplitude of the image's Fourier transform.
  3. The phase of the image's Fourier transform.

The accompanying paper (see Citation) reports that this FFT-based input improves the model's ability to identify the subtle patterns characteristic of multi-tip artifacts.
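The derivation below gives one way to see why the Fourier channels can help. It is idealized and offered as intuition, not taken from the paper. It models a single ghost tip as in the synthetic data described under Training Data: a copy of the image I, shifted by d and scaled by a ≥ 0. It also assumes an N × N grid with periodic (circular) boundaries. Under those assumptions, the Fourier shift theorem shows that the amplitude of the artifact image equals the clean amplitude multiplied by a fringe pattern.

```latex
\begin{aligned}
I'(\mathbf{x}) &= I(\mathbf{x}) + a\, I(\mathbf{x} - \mathbf{d}) \\
\hat{I}'(\mathbf{k}) &= \hat{I}(\mathbf{k})\left(1 + a\, e^{-2\pi i\, \mathbf{k}\cdot\mathbf{d}/N}\right) \\
\lvert \hat{I}'(\mathbf{k}) \rvert &= \lvert \hat{I}(\mathbf{k}) \rvert \sqrt{1 + a^{2} + 2a \cos\!\left(2\pi\, \mathbf{k}\cdot\mathbf{d}/N\right)}
\end{aligned}
```

The cosine term produces straight fringes perpendicular to d, spaced N/|d| apart in frequency. This gives a global signature spread across the whole amplitude channel, even when the duplicated features overlap in the spatial image.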

How to Use

The following Python code shows how to load and use the model for inference.

import torch
import numpy as np
from PIL import Image
from transformers import AutoModelForImageClassification

def preprocess_for_artifact_detection(image_path):
    """
    Loads an STM image and converts it to the required 3-channel format
    (grayscale, magnitude spectrum, phase) for the model.
    """
    # Load as single-channel grayscale, resize to the 224 x 224 ViT input
    # size, and scale pixel values to [0, 1]. A missing file raises
    # FileNotFoundError here rather than returning None, which would only
    # fail later inside the model call.
    with Image.open(image_path) as img:
        img = img.convert('L').resize((224, 224))
        grayscale_img = np.array(img) / 255.0

    # Compute the 2D FFT and move the zero-frequency component to the center
    fft_data = np.fft.fft2(grayscale_img)
    fft_shifted = np.fft.fftshift(fft_data)

    # Amplitude (magnitude) and phase of the shifted spectrum
    magnitude_spectrum = np.abs(fft_shifted)
    phase = np.angle(fft_shifted)

    # Stack the three channels into a (C, H, W) array
    stacked_channels = np.stack([grayscale_img, magnitude_spectrum, phase], axis=0)

    # Convert to a float32 tensor and add a batch dimension: (B, C, H, W)
    return torch.tensor(stacked_channels, dtype=torch.float32).unsqueeze(0)

# Load the model from the Hub
model_name = "t0m-R/vit-stm-artifact-fft"
model = AutoModelForImageClassification.from_pretrained(model_name)

# Preprocess 
image_path = "path/to/your/stm_image" # Replace with your image path
preprocessed_image = preprocess_for_artifact_detection(image_path)

# Run inference
with torch.no_grad():
    logits = model(pixel_values=preprocessed_image).logits
    predicted_label_id = logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_label_id]

print(f"Predicted Label: {predicted_label}")
# Example output for an image that contains a multi-tip artifact:
# Predicted Label: Multi-Tip Artifact
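Your scans may be available as NumPy arrays (for example, height maps exported from acquisition software) rather than image files. In that case, the sketch below builds the same grayscale/amplitude/phase stack from an in-memory array. The function name fft_channels_from_array is hypothetical. Min-max scaling to [0, 1] is an assumption that stands in for the 8-bit grayscale conversion, so results are not bit-identical to the file-based path. The array is also assumed to be 224 × 224 already.

```python
import numpy as np


def fft_channels_from_array(height_map):
    """Hypothetical helper: build the (3, H, W) grayscale/amplitude/phase
    stack from an STM height map already held in memory.

    Assumption: min-max scaling to [0, 1] replaces the 8-bit grayscale
    conversion used for image files; resizing to 224 x 224 is left to
    the caller.
    """
    data = np.asarray(height_map, dtype=np.float64)
    span = data.max() - data.min()
    # Min-max scale to [0, 1]; a constant image maps to all zeros.
    gray = (data - data.min()) / span if span > 0 else np.zeros_like(data)

    # Same FFT steps as preprocess_for_artifact_detection.
    fft_shifted = np.fft.fftshift(np.fft.fft2(gray))
    return np.stack([gray, np.abs(fft_shifted), np.angle(fft_shifted)], axis=0)


# Usage: a random 224 x 224 array standing in for a real scan.
rng = np.random.default_rng(0)
stack = fft_channels_from_array(rng.normal(size=(224, 224)))
print(stack.shape)  # (3, 224, 224)
```

To feed the model, convert the result with torch.tensor(stack, dtype=torch.float32).unsqueeze(0), as in the last line of preprocess_for_artifact_detection.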

Preprocessing

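This model will not work with standard image preprocessing, such as an image processor that produces normalized RGB channels. Its input must be a float tensor of shape (batch, 3, 224, 224) whose channels are, in order, the grayscale image, the FFT amplitude, and the FFT phase. This is the output of preprocess_for_artifact_detection in the "How to Use" section, which applies no further mean/std normalization.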
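A quick check can catch inputs built with the wrong pipeline before they reach the model. The sketch below is hypothetical: check_model_input is not part of the model or of transformers. It assumes the channel order and value ranges produced by preprocess_for_artifact_detection: grayscale in [0, 1], non-negative amplitude, and phase in [-π, π].

```python
import numpy as np


def check_model_input(x, size=224):
    """Hypothetical sanity check for a (B, 3, size, size) input batch.

    Returns a list of problems; an empty list means the array matches the
    assumed layout: grayscale, amplitude, phase channels, in that order.
    """
    problems = []
    if x.ndim != 4 or x.shape[1:] != (3, size, size):
        problems.append(f"expected shape (B, 3, {size}, {size}), got {x.shape}")
        return problems
    gray, amplitude, phase = x[:, 0], x[:, 1], x[:, 2]
    if gray.min() < 0 or gray.max() > 1:
        problems.append("grayscale channel outside [0, 1]")
    if amplitude.min() < 0:
        problems.append("amplitude channel has negative values")
    # Small tolerance: float32(pi) is slightly larger than np.pi.
    if np.abs(phase).max() > np.pi + 1e-6:
        problems.append("phase channel outside [-pi, pi]")
    return problems
```

For example, call check_model_input(preprocessed_image.numpy()). A tensor from a standard image processor would typically be flagged, because its mean/std normalization produces negative values.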

Training Data

The model was fine-tuned on a synthetic dataset generated from experimental STM images recorded at CNR-IOM, Trieste. Artifact-free images were transformed into synthetic multi-tip images by summing the clean image with translated and intensity-scaled versions of itself.
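The augmentation described above can be sketched as follows. This is an illustration under stated assumptions, not the authors' generation code. The name make_multitip is hypothetical. Translations are assumed to be integer-pixel circular shifts via np.roll, and the result is min-max rescaled to [0, 1]. The shift and scale distributions actually used for training are not given here.

```python
import numpy as np


def make_multitip(clean, shifts, scales):
    """Hypothetical sketch: add translated, intensity-scaled copies of a
    clean image to itself to mimic a multi-tip artifact.

    Assumptions: integer-pixel circular shifts (np.roll) and min-max
    rescaling of the sum back to [0, 1].
    """
    out = clean.astype(np.float64)
    for (dy, dx), a in zip(shifts, scales):
        # Each "ghost tip" contributes a shifted copy weighted by a.
        out += a * np.roll(clean, shift=(dy, dx), axis=(0, 1))
    span = out.max() - out.min()
    return (out - out.min()) / span if span > 0 else np.zeros_like(out)


rng = np.random.default_rng(42)
clean = rng.random((224, 224))
ghosted = make_multitip(clean, shifts=[(5, 12)], scales=[0.6])
print(ghosted.shape)  # (224, 224)
```

np.roll wraps pixels around the image edges. A padded, non-circular shift would avoid wrap-around seams at the cost of empty borders.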

Citation

If you use this model in your research, please cite the original work:

@article{rodani2024enhancing,
  title={Enhancing Multi-Tip Artifact Detection in STM Images Using Fourier Transform and Vision Transformers},
  author={Rodani, Tommaso and Ansuini, Alessio and Cazzaniga, Alberto},
  journal={Accepted at the 1st Machine Learning for Life and Material Sciences Workshop at ICML},
  year={2024}
}
Model size: 87.5M parameters, stored as F32 tensors in Safetensors format.