sravan837 committed
Commit 4cd2e2b · verified · 1 Parent(s): 3ddd8b0

Update app.py

Files changed (1):
  1. app.py +23 -56

app.py CHANGED
@@ -4,13 +4,12 @@ from transformers import CLIPProcessor, CLIPModel
 from PIL import Image
 import os
 import faiss
-from datasets import load_dataset, concatenate_datasets
+from datasets import load_dataset
 import requests
 import io
-import time
 
 # --- Configuration ---
 MODEL_PATH = "clip_finetuned"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 FAISS_INDEX_PATH = "gallery.index"
 
@@ -22,79 +21,47 @@ processor = CLIPProcessor.from_pretrained(MODEL_PATH)
 print("Loading FAISS index...")
 faiss_index = faiss.read_index(FAISS_INDEX_PATH)
 
-# --- Load Full COCO Training and Validation Datasets ---
-print("Connecting to COCO dataset (train and validation splits) on the Hub...")
-train_dataset = load_dataset("phiyodr/coco2017", split="train")
-val_dataset = load_dataset("phiyodr/coco2017", split="validation")
-combined_dataset = concatenate_datasets([train_dataset, val_dataset])
-print(f"Successfully connected to combined dataset with {len(combined_dataset)} images.")
-
-# --- Filter for Child-Friendly Content ---
-def filter_child_friendly(dataset):
-    adult_keywords = ["nude", "violence", "adult", "gun", "blood"]
-    filtered_dataset = []
-    for item in dataset:
-        file_name = item.get('file_name', '').lower()
-        # Exclude images with adult-related keywords in file_name
-        if not any(keyword in file_name for keyword in adult_keywords):
-            filtered_dataset.append(item)
-    return filtered_dataset
-
-filtered_dataset = filter_child_friendly(combined_dataset)
-print(f"Filtered dataset size for child-friendly content: {len(filtered_dataset)} images.")
+# --- Connect to the COCO dataset on the Hub ---
+print("Connecting to COCO dataset on the Hub...")
+val_dataset = load_dataset("phiyodr/coco2017", split="validation", trust_remote_code=True)
+
+print(f"Successfully connected to dataset with {len(val_dataset)} images.")
 
-# --- The Search Function with Metrics ---
+# --- The Search Function (Corrected) ---
 def image_search(query_text: str, top_k: int):
-    start_time = time.time()
     with torch.no_grad():
         inputs = processor(text=query_text, return_tensors="pt").to(DEVICE)
         text_embedding = model.get_text_features(**inputs)
         text_embedding /= text_embedding.norm(p=2, dim=-1, keepdim=True)
 
     distances, indices = faiss_index.search(text_embedding.cpu().numpy(), int(top_k))
-
-    # Process results with metrics
     results = []
-    relevant_count = 0
-    retrieval_time = time.time() - start_time
-    memory_usage = torch.cuda.memory_allocated() / 1024**2 if DEVICE == "cuda" else os.cpu_count() * 10  # Approx. MB
-
     for i in indices[0]:
-        if i < len(filtered_dataset):
-            item = filtered_dataset[int(i)]
-            image_url = item['coco_url']  # Assuming coco_url is available
-            response = requests.get(image_url)
-            image = Image.open(io.BytesIO(response.content)).convert("RGB")
-            results.append(image)
-            # Simple relevance check based on file_name matching query
-            file_name = item.get('file_name', '').lower()
-            if query_text.lower() in file_name:
-                relevant_count += 1
-
-    accuracy = (relevant_count / top_k) * 100 if top_k > 0 else 0
-    metrics = f"Retrieval Time: {retrieval_time:.2f} seconds, Memory Usage: {memory_usage:.2f} MB, Accuracy: {accuracy:.2f}%"
-
-    print(metrics)
-    return results, metrics
-
-# --- Gradio Interface ---
+        item = val_dataset[int(i)]
+        image_url = item['coco_url']
+        response = requests.get(image_url)
+        image = Image.open(io.BytesIO(response.content)).convert("RGB")
+        results.append(image)
+
+    return results
+
+# --- Gradio Interface (No changes needed here) ---
 with gr.Blocks(theme=gr.themes.Soft()) as iface:
     gr.Markdown("# 🖼️ CLIP-Powered Image Search Engine")
-    gr.Markdown("Enter a text description to search for child-friendly images from the COCO dataset.")
+    gr.Markdown("Enter a text description to search for matching images.")
 
     with gr.Row():
-        query_input = gr.Textbox(label="Search Query", placeholder="e.g., a dog playing", scale=4)
+        query_input = gr.Textbox(label="Search Query", placeholder="e.g., a red car parked near a building", scale=4)
         k_slider = gr.Slider(minimum=1, maximum=12, value=4, step=1, label="Number of Results")
         submit_btn = gr.Button("Search", variant="primary")
 
     gallery_output = gr.Gallery(label="Search Results", show_label=False, columns=4, height="auto")
-    metrics_output = gr.Textbox(label="Performance Metrics", interactive=False)
 
-    submit_btn.click(fn=image_search, inputs=[query_input, k_slider], outputs=[gallery_output, metrics_output])
+    submit_btn.click(fn=image_search, inputs=[query_input, k_slider], outputs=gallery_output)
 
     gr.Examples(
-        examples=[["a dog playing", 4], ["children in a park", 8]],
+        examples=[["a dog catching a frisbee", 4], ["two people eating pizza", 8]],
         inputs=[query_input, k_slider]
     )
 
-iface.launch(share=True)
+iface.launch()
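Note on the lookup inside image_search: the new loop maps each FAISS row id straight to val_dataset[int(i)], which is only valid if gallery.index was built by embedding the validation split in dataset order. A minimal sketch of that assumed build step follows; the helper name build_gallery_index and the choice of IndexFlatIP are assumptions, not shown in this commit.

# Assumed build step for "gallery.index" (not part of this commit): embed each
# validation image with the same fine-tuned CLIP model, L2-normalize, and add
# to an inner-product index so that FAISS row i lines up with val_dataset[i]
# and inner product equals cosine similarity against the normalized text query.
import io
import faiss
import requests
import torch
from PIL import Image

def build_gallery_index(dataset, model, processor, index_path="gallery.index"):
    index = faiss.IndexFlatIP(model.config.projection_dim)
    for item in dataset:
        response = requests.get(item["coco_url"], timeout=10)
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
        with torch.no_grad():
            inputs = processor(images=image, return_tensors="pt").to(model.device)
            emb = model.get_image_features(**inputs)
            emb /= emb.norm(p=2, dim=-1, keepdim=True)
        index.add(emb.cpu().numpy())  # insertion order == dataset row order
    faiss.write_index(index, index_path)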
 
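One further observation: the committed loop calls requests.get(image_url) with no timeout or status check, so a single unreachable COCO URL raises inside the Gradio callback. Below is a hedged, more defensive variant of that loop; it is a sketch, not what the commit does, and fetch_images is an illustrative name.

# Defensive variant of the result loop (sketch): time-bound the download,
# check the HTTP status, and skip any image that fails instead of raising.
import io
import requests
from PIL import Image

def fetch_images(indices, val_dataset):
    results = []
    for i in indices[0]:
        item = val_dataset[int(i)]
        try:
            response = requests.get(item["coco_url"], timeout=10)
            response.raise_for_status()
            results.append(Image.open(io.BytesIO(response.content)).convert("RGB"))
        except (requests.RequestException, OSError):
            continue  # PIL's UnidentifiedImageError subclasses OSError
    return results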