Vision Transformer for STM Multi-Tip Artifact Detection
This is a fine-tuned Vision Transformer (ViT-B/32) model for classifying Scanning Tunneling Microscopy (STM) images. It is designed to detect the presence of multi-tip artifacts, a common distortion that results in duplicated signals and complicates data interpretation.
This model was developed as part of the NFFA-DI (Nano Foundries and Fine Analysis Digital Infrastructure) project, funded by the European Union's NextGenerationEU program.
Model Description
The model is a ViT-B/32 pre-trained on ImageNet-21k. It was fine-tuned to classify an STM image as either Artifact-Free or Multi-Tip Artifact.
A key feature of this model is its use of a Fast Fourier Transform (FFT) based preprocessing method. The model's input is not a standard image but a 3-channel tensor composed of:
- The grayscale STM image.
- The amplitude of the image's Fourier transform.
- The phase of the image's Fourier transform.
This approach significantly improves the model's ability to identify the subtle patterns characteristic of multi-tip artifacts.
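One way to see why the Fourier channels can help is the shift theorem: adding a translated, scaled copy of an image multiplies its spectrum by a periodic factor, which leaves fringes in the amplitude channel. The sketch below is only an illustration under simplifying assumptions, not the authors' analysis. It uses a random test image, a single hypothetical second tip with offset `(dy, dx)` and relative intensity `alpha`, and a circular (wrap-around) shift so that the identity holds exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
size = 64
clean = rng.random((size, size))  # stand-in for an artifact-free image

# Hypothetical second tip: a copy shifted by (dy, dx) pixels, scaled by alpha.
dy, dx, alpha = 0, 8, 0.5
double_tip = clean + alpha * np.roll(clean, shift=(dy, dx), axis=(0, 1))

F_clean = np.fft.fft2(clean)
F_double = np.fft.fft2(double_tip)

# Shift theorem: the shifted copy contributes the same spectrum times a phase ramp,
# so F_double = F_clean * (1 + alpha * exp(-2*pi*i * (ky*dy + kx*dx))).
ky = np.fft.fftfreq(size)[:, None]  # vertical frequency, cycles per pixel
kx = np.fft.fftfreq(size)[None, :]  # horizontal frequency, cycles per pixel
modulation = 1 + alpha * np.exp(-2j * np.pi * (ky * dy + kx * dx))

print(np.allclose(F_double, F_clean * modulation))
print(np.abs(modulation).min(), np.abs(modulation).max())
```

The modulus of `modulation` oscillates between `1 - alpha` and `1 + alpha` along the shift direction, so the duplicated signal shows up as a regular pattern in frequency space even when it is hard to spot in the real-space image.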
How to Use
The following Python code shows how to load and use the model for inference.
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForImageClassification


def preprocess_for_artifact_detection(image_path):
    """
    Loads an STM image and converts it to the required 3-channel format
    (grayscale, magnitude spectrum, phase) for the model.

    Returns a float32 tensor of shape (1, 3, 224, 224), or None if the
    file is not found.
    """
    try:
        with Image.open(image_path) as img:
            # Convert to 8-bit grayscale, resize, and scale to [0, 1]
            img = img.convert('L').resize((224, 224))
            grayscale_img = np.array(img) / 255.0
    except FileNotFoundError:
        print(f"Error: The file at {image_path} was not found.")
        return None

    # Compute the 2D FFT, center the zero frequency, then take magnitude and phase
    fft_data = np.fft.fft2(grayscale_img)
    fft_shifted = np.fft.fftshift(fft_data)
    magnitude_spectrum = np.abs(fft_shifted)
    phase = np.angle(fft_shifted)

    # Stack channels and convert to PyTorch tensor (C, H, W)
    stacked_channels = np.stack([grayscale_img, magnitude_spectrum, phase], axis=0)

    # Add a batch dimension (B, C, H, W) and return as float tensor
    return torch.tensor(stacked_channels, dtype=torch.float32).unsqueeze(0)


# Load the model from the Hub
model_name = "t0m-R/vit-stm-artifact-fft"
model = AutoModelForImageClassification.from_pretrained(model_name)

# Preprocess
image_path = "path/to/your/stm_image"  # Replace with your image path
preprocessed_image = preprocess_for_artifact_detection(image_path)
if preprocessed_image is None:
    # Stop here: the model cannot be called without an input tensor
    raise SystemExit(1)

# Run inference
with torch.no_grad():
    logits = model(preprocessed_image).logits

predicted_label_id = logits.argmax(-1).item()
predicted_label = model.config.id2label[predicted_label_id]
print(f"Predicted Label: {predicted_label}")
# Example output: "Predicted Label: Multi-Tip Artifact"
Preprocessing
This model will not work with standard image preprocessing. The input must be a 3-channel tensor representing the grayscale image, FFT amplitude, and FFT phase, as implemented in the function provided in the "How to Use" section.
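The three channels live on very different numeric scales, which is worth keeping in mind when inspecting inputs. The sketch below is a NumPy-only illustration: `fft_channels` is a hypothetical helper that mirrors the FFT steps of the function from "How to Use" (without file loading or tensor conversion), and the random array stands in for a real STM image.

```python
import numpy as np


def fft_channels(gray):
    """Stack grayscale, FFT magnitude, and FFT phase into a (3, H, W) array."""
    fft_shifted = np.fft.fftshift(np.fft.fft2(gray))
    return np.stack([gray, np.abs(fft_shifted), np.angle(fft_shifted)], axis=0)


rng = np.random.default_rng(0)
gray = rng.random((224, 224))  # placeholder for a grayscale image in [0, 1]
channels = fft_channels(gray)

# Report the value range of each channel.
for name, channel in zip(["grayscale", "magnitude", "phase"], channels):
    print(f"{name:>9}: min={channel.min():.3f}, max={channel.max():.3f}")
```

For a non-negative image, the magnitude channel peaks at the zero-frequency term, which equals the sum of all pixel values and sits at the center of the shifted spectrum, while the phase channel stays within [-π, π]. The card does not mention any additional rescaling of these channels, so none is applied here.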
Training Data
The model was fine-tuned on a synthetic dataset generated from experimental STM images recorded at CNR-IOM, Trieste. Artifact-free images were transformed into synthetic multi-tip images by summing the clean image with translated and intensity-scaled versions of itself.
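The generation step described above can be sketched as follows. This is a minimal illustration, not the authors' pipeline: the function name `make_synthetic_multi_tip`, the wrap-around translation via `np.roll`, the final rescaling to [0, 1], and the example shift and intensity values are all assumptions, since the card does not specify them.

```python
import numpy as np


def make_synthetic_multi_tip(clean, shifts, intensities):
    """Sum a clean image with translated, intensity-scaled copies of itself.

    clean:       2D array, the artifact-free image.
    shifts:      list of (dy, dx) pixel offsets, one per extra tip.
    intensities: list of scale factors, one per extra tip.
    """
    clean = np.asarray(clean, dtype=np.float64)
    result = clean.copy()
    for (dy, dx), alpha in zip(shifts, intensities):
        # np.roll wraps around the image borders (an assumption of this sketch).
        result += alpha * np.roll(clean, shift=(dy, dx), axis=(0, 1))

    # Rescale to [0, 1] so the output has the same range as a normalized input.
    result -= result.min()
    peak = result.max()
    if peak > 0:
        result /= peak
    return result


# Toy example: a single bright pixel and one extra tip, 5 pixels to the right.
clean = np.zeros((32, 32))
clean[10, 10] = 1.0
double_tip = make_synthetic_multi_tip(clean, shifts=[(0, 5)], intensities=[0.5])
print(double_tip[10, 10], double_tip[10, 15])
```

In the toy example the original feature keeps full intensity and its ghost appears five pixels away at half intensity, which is the duplicated-signal pattern the classifier is trained to recognize.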
Citation
If you use this model in your research, please cite the original work:
@article{rodani2024enhancing,
title={Enhancing Multi-Tip Artifact Detection in STM Images Using Fourier Transform and Vision Transformers},
author={Rodani, Tommaso and Ansuini, Alessio and Cazzaniga, Alberto},
journal={Accepted at the 1st Machine Learning for Life and Material Sciences Workshop at ICML},
year={2024}
}
Base model: google/vit-base-patch32-224-in21k