fmegahed commited on
Commit
3cdb380
·
verified ·
1 Parent(s): e1e078d

Updated the app to unsqueeze the images

Browse files
Files changed (1) hide show
  1. app.py +25 -24
app.py CHANGED
@@ -8,22 +8,29 @@ from datetime import datetime
8
  import torch.nn.functional as F
9
  from typing import List
10
 
11
- # Load secrets
12
- openai_api_key = st.secrets.get("OPENAI_API_KEY")
13
- # You can now use openai_api_key for anything requiring OpenAI access
14
-
15
  # Device setup
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
- # Load CLIP model + preprocess from OpenAI CLIP
19
- model, preprocess = clip.load("ViT-L/14", device=device)
20
  model.eval()
21
 
22
- # Ensure reproducibility
23
- torch.set_grad_enabled(False)
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- # Import the few-shot classification function
26
- # --- COPY YOUR FUNCTION DEFINITION BELOW DIRECTLY OR PUT IT IN A SEPARATE FILE ---
27
  def few_shot_fault_classification(
28
  test_images: List[Image.Image],
29
  test_image_filenames: List[str],
@@ -47,16 +54,16 @@ def few_shot_fault_classification(
47
  results = []
48
 
49
  with torch.no_grad():
50
- nominal_features = torch.stack([model.encode_image(img).to(device) for img in nominal_images])
51
  nominal_features /= nominal_features.norm(dim=-1, keepdim=True)
52
 
53
- defective_features = torch.stack([model.encode_image(img).to(device) for img in defective_images])
54
  defective_features /= defective_features.norm(dim=-1, keepdim=True)
55
 
56
  csv_data = []
57
 
58
  for idx, test_img in enumerate(test_images):
59
- test_features = model.encode_image(test_img).to(device)
60
  test_features /= test_features.norm(dim=-1, keepdim=True)
61
 
62
  max_nom_sim, max_def_sim = -float('inf'), -float('inf')
@@ -110,7 +117,7 @@ def few_shot_fault_classification(
110
 
111
  return ""
112
 
113
- # Initialize app state
114
  if 'nominal_images' not in st.session_state:
115
  st.session_state.nominal_images = []
116
  if 'defective_images' not in st.session_state:
@@ -120,16 +127,12 @@ if 'test_images' not in st.session_state:
120
  if 'results' not in st.session_state:
121
  st.session_state.results = []
122
 
123
- st.set_page_config(page_title="Few-Shot Fault Detection", layout="wide")
124
- st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
125
- st.markdown("Upload **Nominal Images** (good parts), **Defective Images** (bad parts), and **Test Images** to classify.")
126
-
127
  tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"])
128
 
129
- # --- Tab 1: Upload Reference Images ---
130
  with tab1:
131
  st.header("Upload Reference Images")
132
-
133
  nominal_files = st.file_uploader("Upload Nominal Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
134
  defective_files = st.file_uploader("Upload Defective Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
135
 
@@ -143,10 +146,9 @@ with tab1:
143
  st.session_state.defective_descriptions = [file.name for file in defective_files]
144
  st.success(f"Uploaded {len(defective_files)} defective images.")
145
 
146
- # --- Tab 2: Classify Test Images ---
147
  with tab2:
148
  st.header("Upload Test Image(s)")
149
-
150
  test_files = st.file_uploader("Upload Test Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
151
 
152
  if st.button("🔍 Run Classification") and test_files:
@@ -169,10 +171,9 @@ with tab2:
169
  st.success("Classification complete!")
170
  st.session_state.results = "streamlit_results.csv"
171
 
172
- # --- Tab 3: View/Download Results ---
173
  with tab3:
174
  st.header("Classification Results")
175
-
176
  if os.path.exists("streamlit_results.csv"):
177
  df = pd.read_csv("streamlit_results.csv")
178
  st.dataframe(df)
 
8
  import torch.nn.functional as F
9
  from typing import List
10
 
 
 
 
 
11
  # Device setup
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
+ # Load CLIP model and preprocessor (ViT-B/32 = small model, CPU-friendly)
15
+ model, preprocess = clip.load("ViT-B/32", device=device)
16
  model.eval()
17
 
18
+ # Display app title and information
19
+ st.set_page_config(page_title="Few-Shot Fault Detection", layout="wide")
20
+ st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
21
+
22
+ st.markdown("""
23
+ This demo uses the **smaller `ViT-B/32` encoder** from OpenAI's CLIP model to classify test images as **Nominal** or **Defective**, based on few-shot learning using user-provided reference images.
24
+
25
+ ⚠️ **Note**: This app is running on a **free CPU tier** and is meant for demonstration purposes. For more advanced use cases, including GPU acceleration, custom training, and larger models, please refer to:
26
+
27
+ 📄 [Megahed et al. (2025)](https://arxiv.org/abs/2501.12596):
28
+ *Adapting OpenAI's CLIP Model for Few-Shot Image Inspection in Manufacturing Quality Control: An Expository Case Study with Multiple Application Examples*
29
+
30
+ 🔗 [GitHub & Colab links available in the paper](https://arxiv.org/abs/2501.12596)
31
+ """)
32
 
33
+ # --- Few-shot classification logic ---
 
34
  def few_shot_fault_classification(
35
  test_images: List[Image.Image],
36
  test_image_filenames: List[str],
 
54
  results = []
55
 
56
  with torch.no_grad():
57
+ nominal_features = torch.stack([model.encode_image(img.unsqueeze(0)).squeeze(0).to(device) for img in nominal_images])
58
  nominal_features /= nominal_features.norm(dim=-1, keepdim=True)
59
 
60
+ defective_features = torch.stack([model.encode_image(img.unsqueeze(0)).squeeze(0).to(device) for img in defective_images])
61
  defective_features /= defective_features.norm(dim=-1, keepdim=True)
62
 
63
  csv_data = []
64
 
65
  for idx, test_img in enumerate(test_images):
66
+ test_features = model.encode_image(test_img.unsqueeze(0)).squeeze(0).to(device)
67
  test_features /= test_features.norm(dim=-1, keepdim=True)
68
 
69
  max_nom_sim, max_def_sim = -float('inf'), -float('inf')
 
117
 
118
  return ""
119
 
120
+ # --- App state ---
121
  if 'nominal_images' not in st.session_state:
122
  st.session_state.nominal_images = []
123
  if 'defective_images' not in st.session_state:
 
127
  if 'results' not in st.session_state:
128
  st.session_state.results = []
129
 
130
+ # --- Tabs ---
 
 
 
131
  tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"])
132
 
133
+ # Tab 1: Upload Reference Images
134
  with tab1:
135
  st.header("Upload Reference Images")
 
136
  nominal_files = st.file_uploader("Upload Nominal Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
137
  defective_files = st.file_uploader("Upload Defective Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
138
 
 
146
  st.session_state.defective_descriptions = [file.name for file in defective_files]
147
  st.success(f"Uploaded {len(defective_files)} defective images.")
148
 
149
+ # Tab 2: Test Classification
150
  with tab2:
151
  st.header("Upload Test Image(s)")
 
152
  test_files = st.file_uploader("Upload Test Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
153
 
154
  if st.button("🔍 Run Classification") and test_files:
 
171
  st.success("Classification complete!")
172
  st.session_state.results = "streamlit_results.csv"
173
 
174
+ # Tab 3: View/Download Results
175
  with tab3:
176
  st.header("Classification Results")
 
177
  if os.path.exists("streamlit_results.csv"):
178
  df = pd.read_csv("streamlit_results.csv")
179
  st.dataframe(df)