import gradio as gr
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image, ImageDraw, ImageFont
import torch
import requests
import os

# SerpAPI key for live Google Shopping prices
SERPAPI_KEY = os.environ.get("SERPAPI_KEY")

# Load model and processor
model_name = "valentinafeve/yolos-fashionpedia"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForObjectDetection.from_pretrained(model_name)

# Fashionpedia categories
CATS = [
    'shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket',
    'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit', 'cape',
    'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove',
    'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe',
    'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette',
    'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead',
    'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel'
]

model.config.id2label = {i: label for i, label in enumerate(CATS)}
model.config.label2id = {label: i for i, label in model.config.id2label.items()}

# Main outfit items only (the remaining categories are garment parts and trims)
main_labels = set(CATS[:27])


def get_price(item_name):
    """Fetch an average price from Google Shopping via SerpAPI."""
    if not SERPAPI_KEY:
        # No API key configured; fall back to a placeholder price
        return 10.0
    try:
        url = "https://serpapi.com/search.json"
        params = {
            "q": f"{item_name} price",
            "tbm": "shop",
            "api_key": SERPAPI_KEY,
            "num": 10,
        }
        response = requests.get(url, params=params, timeout=10)
        response.raise_for_status()
        data = response.json()

        prices = []
        for result in data.get("shopping_results", []):
            if "price" in result:
                price_str = result["price"].replace("$", "").replace(",", "")
                try:
                    prices.append(float(price_str))
                except ValueError:
                    continue
        return round(sum(prices) / len(prices), 2) if prices else 10.0
    except Exception as e:
        print(f"Error fetching price for {item_name}: {e}")
        return 10.0


def detect_fashion_items(image):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Prepare inputs
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process (image.size is (width, height); target_sizes expects (height, width))
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, threshold=0.5, target_sizes=target_sizes
    )[0]

    # Filter to main labels and keep the highest-scoring detection per label
    best_per_label = {}
    for score, label_id, box in zip(results["scores"], results["labels"], results["boxes"]):
        label = model.config.id2label[label_id.item()]
        if label not in main_labels:
            continue
        score_val = score.item()
        if label not in best_per_label or score_val > best_per_label[label]["score"]:
            best_per_label[label] = {
                "score": score_val,
                "box": [round(i, 2) for i in box.tolist()],
                "label": label,
                "price": get_price(label),
            }

    # Draw annotations on a transparent overlay so the semi-transparent price-tag
    # shadow blends with the photo instead of overwriting it
    image = image.convert("RGBA")
    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", 18)  # Bold, 18 px
    except OSError:
        font = ImageFont.load_default()

    for item in best_per_label.values():
        box = item["box"]
        label = item["label"]
        score = item["score"]
        price = item["price"]

        # Draw bounding box
        draw.rectangle(box, outline="blue", width=3)

        # Draw label and confidence score
        label_text = f"{label}: {score:.2f}"
        draw.text((box[0], box[1] - 50), label_text, fill="blue", font=font)

        # Draw price tag (modern: yellow, rounded, shadowed), above the label
        tag_x = box[0]
        tag_y = box[1] - 80
        tag_width = 120
        tag_height = 40
        draw.rounded_rectangle(
            [tag_x + 2, tag_y + 2, tag_x + tag_width + 2, tag_y + tag_height + 2],
            radius=10,
            fill=(0, 0, 0, 64)  # Shadow
        )
        draw.rounded_rectangle(
            [tag_x, tag_y, tag_x + tag_width, tag_y + tag_height],
            radius=10,
            fill="yellow",
            outline="black",
            width=2
        )

        # Center the price text inside the tag
        price_text = f"${price:.2f}"
        text_bbox = draw.textbbox((0, 0), price_text, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        draw.text((tag_x + (tag_width - text_width) // 2, tag_y + 10), price_text, fill="black", font=font)

    # Composite the annotations onto the photo and convert back to RGB
    image = Image.alpha_composite(image, overlay).convert("RGB")

    # Calculate total price of all detected items
    total_price = sum(item["price"] for item in best_per_label.values())

    return image, f"Total Outfit Price: ${total_price:.2f}"


# Gradio interface
with gr.Blocks(title="Fashion Outfit Detector with Live Prices") as iface:
    gr.Markdown(
        "### Fashion Outfit Detector with Live Prices\n"
        "Upload an image to detect unique outfit items with real-time prices from Google Shopping."
    )
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload a fashion image")
        result_image = gr.Image(type="pil", label="Detected Outfits with Prices")
    total_price_output = gr.Textbox(label="Total Price")

    # Submit button
    submit_btn = gr.Button("Detect Outfits")
    submit_btn.click(
        fn=detect_fashion_items,
        inputs=image_input,
        outputs=[result_image, total_price_output],
    )

if __name__ == "__main__":
    iface.launch(share=True)