ALBEF Price Prediction Model

This is a multimodal ALBEF (Align Before Fuse) model trained for product price prediction using both images and text descriptions.

Model Description

  • Architecture: ALBEF with ResNet50 vision encoder and BERT text encoder
  • Task: Product price prediction from images and catalog descriptions
  • Training: Cross-modal fusion with contrastive learning
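
The cross-modal fusion step can be illustrated with a small PyTorch module. This is a minimal sketch, not the checkpoint's actual architecture: the class name `FusionPriceHead`, the use of `nn.TransformerDecoder` for text-to-image cross-attention, the head count, and the input feature sizes (2048 for ResNet50 feature maps, 768 for BERT-base hidden states) are all assumptions made for illustration.

```python
import torch
import torch.nn as nn


class FusionPriceHead(nn.Module):
    """Hypothetical fusion head: text tokens cross-attend to image tokens."""

    def __init__(self, image_dim=2048, text_dim=768, hidden_dim=1024,
                 num_layers=6, num_heads=8):
        super().__init__()
        # Project both modalities into a shared hidden space.
        self.image_proj = nn.Linear(image_dim, hidden_dim)
        self.text_proj = nn.Linear(text_dim, hidden_dim)
        # Each decoder layer applies self-attention over the text tokens,
        # then cross-attention from text tokens to image tokens.
        layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim, nhead=num_heads, batch_first=True
        )
        self.fusion = nn.TransformerDecoder(layer, num_layers=num_layers)
        # Regress a single value: the price in log1p space.
        self.regressor = nn.Linear(hidden_dim, 1)

    def forward(self, image_feats, text_feats, attention_mask):
        # image_feats: (B, N_img, image_dim), e.g. a flattened 7x7 feature map
        # text_feats:  (B, L, text_dim), token-level text encoder outputs
        # attention_mask: (B, L), 1 for real tokens and 0 for padding
        memory = self.image_proj(image_feats)
        queries = self.text_proj(text_feats)
        fused = self.fusion(
            queries, memory, tgt_key_padding_mask=(attention_mask == 0)
        )
        # Use the first ([CLS]) position as the fused representation.
        return self.regressor(fused[:, 0]).squeeze(-1)
```

With the defaults above the module mirrors the hidden dimension and layer count listed under Training Configuration; smaller values are convenient for quick shape checks.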

Latest Metrics (Epoch 3)

Metric            Value
Validation Loss   0.8670
RMSE              29.24
MAE               13.00
SMAPE             56.06%
MAPE              73.07%
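
For reference, the four error metrics above can be computed as in the sketch below. The exact formulas used during training are not stated in this card, so the definitions here are assumptions: in particular, SMAPE is taken in its common 0 to 200% form, and the helper name `regression_metrics` and the `eps` guard against division by zero are illustrative.

```python
import numpy as np


def regression_metrics(y_true, y_pred, eps=1e-8):
    """Return RMSE, MAE, SMAPE (%) and MAPE (%) for prices in original units."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_pred - y_true
    abs_err = np.abs(err)
    return {
        "rmse": float(np.sqrt(np.mean(err ** 2))),
        "mae": float(np.mean(abs_err)),
        # Symmetric variant: error relative to the sum of |true| and |pred|.
        "smape": float(100 * np.mean(2 * abs_err / (np.abs(y_true) + np.abs(y_pred) + eps))),
        # Plain variant: error relative to the true price only.
        "mape": float(100 * np.mean(abs_err / (np.abs(y_true) + eps))),
    }
```

Under these definitions MAPE is unbounded while SMAPE is capped at 200%, which is consistent with MAPE being the larger of the two values reported above.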

Training Configuration

  • Vision Encoder: ResNet50 (pretrained)
  • Text Encoder: BERT-base-uncased
  • Hidden Dimension: 1024
  • Cross-modal Layers: 6
  • Optimizer: AdamW with a cosine annealing learning-rate schedule
  • Loss: Combined MSE + SMAPE + Contrastive
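
The optimizer and loss entries above could be wired together roughly as follows. This is a sketch under stated assumptions, not the training code behind this checkpoint: the loss weights, the temperature, the learning rate, the weight decay, and the choice of an InfoNCE-style contrastive term are all hypothetical, as are the function names.

```python
import torch
import torch.nn.functional as F


def smape_loss(pred_price, true_price, eps=1e-8):
    # Symmetric relative error, averaged over the batch (range 0 to 2).
    denom = pred_price.abs() + true_price.abs() + eps
    return (2 * (pred_price - true_price).abs() / denom).mean()


def contrastive_loss(image_emb, text_emb, temperature=0.07):
    # InfoNCE-style loss: matching image/text pairs sit on the diagonal.
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = image_emb @ text_emb.t() / temperature
    targets = torch.arange(logits.size(0), device=logits.device)
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))


def combined_loss(pred_log, true_log, image_emb, text_emb,
                  w_mse=1.0, w_smape=1.0, w_con=1.0):
    # MSE is taken in log1p space; SMAPE is taken on the recovered prices.
    mse = F.mse_loss(pred_log, true_log)
    smape = smape_loss(torch.expm1(pred_log), torch.expm1(true_log))
    con = contrastive_loss(image_emb, text_emb)
    return w_mse * mse + w_smape * smape + w_con * con


def build_optimizer(model, lr=1e-4, weight_decay=0.01, total_steps=1000):
    # AdamW with a cosine schedule that decays the learning rate to 0
    # over total_steps scheduler steps.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
    return optimizer, scheduler
```

In a training loop, `scheduler.step()` would be called after each `optimizer.step()`, once per batch or once per epoch depending on what `total_steps` counts.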

Usage

import torch
from transformers import AutoTokenizer
from PIL import Image
import torchvision.transforms as T
from huggingface_hub import hf_hub_download

# Download checkpoint
checkpoint_path = hf_hub_download(
    repo_id="Rudra12567/albef-price-prediction",
    filename="best_model.pth"
)

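# Load checkpoint onto the CPU so this also works on machines without a GPU
checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Initialize your model (it must match the training architecture) and load state_dict
# model.load_state_dict(checkpoint['model_state_dict'])
# model.eval()  # disable dropout and use running batch-norm statistics for inference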

# Prepare image
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

image = Image.open('product.jpg').convert('RGB')
pixel_values = transform(image).unsqueeze(0)

# Prepare text
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
text_inputs = tokenizer(
    "Product description here",
    truncation=True,
    padding='max_length',
    max_length=128,
    return_tensors='pt'
)

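# Predict (assumes `model` has been initialized and loaded as described above)
with torch.no_grad():
    outputs = model(pixel_values, text_inputs)
    price_log = outputs['price_pred']  # prediction in log space
    price = torch.expm1(price_log)     # invert the log1p transform to recover the price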

Training Details

  • Trained on product images and catalog descriptions
  • Prices are log-transformed for better regression performance; predictions are mapped back to prices with expm1 at inference
  • Multi-task learning with contrastive objectives
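
The log transform can be illustrated with a short round trip. The usage example inverts predictions with `torch.expm1`, so this sketch assumes the training targets were produced with `torch.log1p`; the helper names are hypothetical.

```python
import torch


def to_log_target(price):
    # log1p(x) = log(1 + x): compresses large prices and is defined at 0.
    return torch.log1p(price)


def from_log_pred(price_log):
    # expm1(y) = exp(y) - 1: exact inverse of log1p.
    return torch.expm1(price_log)


prices = torch.tensor([0.0, 4.99, 19.5, 250.0])
targets = to_log_target(prices)      # what the regressor would be trained on
recovered = from_log_pred(targets)   # what is reported to the user
```

Because errors are measured in log space during training, a fixed error in the prediction corresponds to a roughly constant relative error in price, which keeps expensive items from dominating the loss.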

License

Apache 2.0
