🖼️ Mô hình dự đoán chữ số viết tay
📝 Mô tả
Đây là mô hình Vision Transformer (ViT‑Base với patch size 32) được fine-tuned từ openai/clip-vit-base-patch32 để thực hiện phân loại chữ số viết tay (MNIST). Chỉ phần vision encoder được training lại, giữ nguyên text encoder để giữ khả năng zero-shot của CLIP.
📌 Nhiệm vụ
Dự đoán chữ số (0–9) từ ảnh MNIST, dưới dạng phân loại đơn giản gồm 10 lớp.
📥 Đầu vào
Ảnh xám (grayscale) kích thước 28×28, mô hình sẽ tự xử lý chuẩn hóa/chuyển sang 3 kênh nếu cần (qua processor của CLIP). Đầu vào sẽ được đưa vào dưới dạng tensor [batch_size, 3, 224, 224] sau khi qua CLIPProcessor.
📤 Đầu ra
logits có kích thước [batch_size, 10], đại diện xác suất tương ứng với mỗi chữ số từ 0 đến 9.
🧪 Kết quả đánh giá
Giai đoạn Accuracy Pre-trained (chưa fine-tune) 47.6% Sau fine-tune 99.57%
🛠 Yêu cầu thư viện
Cài đặt các thư viện cần thiết:
pip install torch transformers datasets pillow
🚀 Cách sử dụng
🎯 Sử dụng encoder đã fine-tuned
import torch
from transformers import CLIPVisionModel, CLIPProcessor
from PIL import Image
# Tải vision encoder và CLIP processor
vision_model = CLIPVisionModel.from_pretrained("zhaospei/Model_11")
processor = CLIPProcessor.from_pretrained("zhaospei/Model_10")
# Chuẩn bị ảnh MNIST (28×28)
img = Image.open("path_to_mnist_digit.png").convert("L") # ảnh xám
img = img.resize((224, 224)).convert("RGB") # mở rộng thành RGB 3 kênh
inputs = processor(images=img, return_tensors="pt")
# Lấy embedding từ ảnh
with torch.no_grad():
vision_outputs = vision_model(**inputs)
image_embeds = vision_outputs.last_hidden_state[:, 0, :] # CLS token embedding
print("Image embedding shape:", image_embeds.shape)
🔄 Kết hợp với CLIP để trên nền zero-shot
from transformers import CLIPModel
# Tải CLIP đầy đủ
clip = CLIPModel.from_pretrained("zhaospei/Model_11")
# Thay thành encoder đã fine-tune
clip.vision_model.load_state_dict(vision_model.vision_model.state_dict())
# Ví dụ zero-shot MNIST
from PIL import Image
img = Image.open("path_to_mnist_digit.png").convert("L").resize((224, 224)).convert("RGB")
texts = [str(i) for i in range(10)]
inputs = processor(text=texts, images=img, return_tensors="pt", padding=True)
with torch.no_grad():
outputs = clip(**inputs)
probs = outputs.logits_per_image.softmax(dim=1)[0]
print({texts[i]: float(probs[i]) for i in range(10)})
⚙️ Thông tin huấn luyện
Optimizer: Adam, learning rate = 1e-5 Batch size: 32 Số bước huấn luyện: 4000 Chỉ fine-tune vision encode
- Downloads last month
- -