YAML Metadata Warning: empty or missing yaml metadata in repo card (https://huggingface.co/docs/hub/model-cards#model-card-metadata)

🖼️ Mô hình dự đoán chữ số viết tay

📝 Mô tả

Đây là mô hình Vision Transformer (ViT‑Base với patch size 32) được fine-tuned từ openai/clip-vit-base-patch32 để thực hiện phân loại chữ số viết tay (MNIST). Chỉ phần vision encoder được training lại, giữ nguyên text encoder để giữ khả năng zero-shot của CLIP.

📌 Nhiệm vụ

Dự đoán chữ số (0–9) từ ảnh MNIST, dưới dạng phân loại đơn giản gồm 10 lớp.

📥 Đầu vào

Ảnh xám (grayscale) kích thước 28×28, mô hình sẽ tự xử lý chuẩn hóa/chuyển sang 3 kênh nếu cần (qua processor của CLIP). Đầu vào sẽ được đưa vào dưới dạng tensor [batch_size, 3, 224, 224] sau khi qua CLIPProcessor.

📤 Đầu ra

logits có kích thước [batch_size, 10], đại diện xác suất tương ứng với mỗi chữ số từ 0 đến 9.

🧪 Kết quả đánh giá

Giai đoạn Accuracy Pre-trained (chưa fine-tune) 47.6% Sau fine-tune 99.57%

🛠 Yêu cầu thư viện

Cài đặt các thư viện cần thiết:

pip install torch transformers datasets pillow

🚀 Cách sử dụng

🎯 Sử dụng encoder đã fine-tuned

import torch
from transformers import CLIPVisionModel, CLIPProcessor
from PIL import Image

# Tải vision encoder và CLIP processor
vision_model = CLIPVisionModel.from_pretrained("zhaospei/Model_11")
processor = CLIPProcessor.from_pretrained("zhaospei/Model_10")

# Chuẩn bị ảnh MNIST (28×28)
img = Image.open("path_to_mnist_digit.png").convert("L")  # ảnh xám
img = img.resize((224, 224)).convert("RGB")  # mở rộng thành RGB 3 kênh

inputs = processor(images=img, return_tensors="pt")

# Lấy embedding từ ảnh
with torch.no_grad():
    vision_outputs = vision_model(**inputs)

image_embeds = vision_outputs.last_hidden_state[:, 0, :]  # CLS token embedding
print("Image embedding shape:", image_embeds.shape)

🔄 Kết hợp với CLIP để trên nền zero-shot

from transformers import CLIPModel

# Tải CLIP đầy đủ
clip = CLIPModel.from_pretrained("zhaospei/Model_11")
# Thay thành encoder đã fine-tune
clip.vision_model.load_state_dict(vision_model.vision_model.state_dict())

# Ví dụ zero-shot MNIST
from PIL import Image
img = Image.open("path_to_mnist_digit.png").convert("L").resize((224, 224)).convert("RGB")
texts = [str(i) for i in range(10)]

inputs = processor(text=texts, images=img, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = clip(**inputs)

probs = outputs.logits_per_image.softmax(dim=1)[0]
print({texts[i]: float(probs[i]) for i in range(10)})

⚙️ Thông tin huấn luyện

Optimizer: Adam, learning rate = 1e-5 Batch size: 32 Số bước huấn luyện: 4000 Chỉ fine-tune vision encode

Downloads last month
-
Safetensors
Model size
87.5M params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support