Updated the app to unsqueeze the images
Browse files
app.py
CHANGED
|
@@ -8,22 +8,29 @@ from datetime import datetime
|
|
| 8 |
import torch.nn.functional as F
|
| 9 |
from typing import List
|
| 10 |
|
| 11 |
-
# Load secrets
|
| 12 |
-
openai_api_key = st.secrets.get("OPENAI_API_KEY")
|
| 13 |
-
# You can now use openai_api_key for anything requiring OpenAI access
|
| 14 |
-
|
| 15 |
# Device setup
|
| 16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 17 |
|
| 18 |
-
# Load CLIP model
|
| 19 |
-
model, preprocess = clip.load("ViT-
|
| 20 |
model.eval()
|
| 21 |
|
| 22 |
-
#
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
#
|
| 26 |
-
# --- COPY YOUR FUNCTION DEFINITION BELOW DIRECTLY OR PUT IT IN A SEPARATE FILE ---
|
| 27 |
def few_shot_fault_classification(
|
| 28 |
test_images: List[Image.Image],
|
| 29 |
test_image_filenames: List[str],
|
|
@@ -47,16 +54,16 @@ def few_shot_fault_classification(
|
|
| 47 |
results = []
|
| 48 |
|
| 49 |
with torch.no_grad():
|
| 50 |
-
nominal_features = torch.stack([model.encode_image(img).to(device) for img in nominal_images])
|
| 51 |
nominal_features /= nominal_features.norm(dim=-1, keepdim=True)
|
| 52 |
|
| 53 |
-
defective_features = torch.stack([model.encode_image(img).to(device) for img in defective_images])
|
| 54 |
defective_features /= defective_features.norm(dim=-1, keepdim=True)
|
| 55 |
|
| 56 |
csv_data = []
|
| 57 |
|
| 58 |
for idx, test_img in enumerate(test_images):
|
| 59 |
-
test_features = model.encode_image(test_img).to(device)
|
| 60 |
test_features /= test_features.norm(dim=-1, keepdim=True)
|
| 61 |
|
| 62 |
max_nom_sim, max_def_sim = -float('inf'), -float('inf')
|
|
@@ -110,7 +117,7 @@ def few_shot_fault_classification(
|
|
| 110 |
|
| 111 |
return ""
|
| 112 |
|
| 113 |
-
#
|
| 114 |
if 'nominal_images' not in st.session_state:
|
| 115 |
st.session_state.nominal_images = []
|
| 116 |
if 'defective_images' not in st.session_state:
|
|
@@ -120,16 +127,12 @@ if 'test_images' not in st.session_state:
|
|
| 120 |
if 'results' not in st.session_state:
|
| 121 |
st.session_state.results = []
|
| 122 |
|
| 123 |
-
|
| 124 |
-
st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
|
| 125 |
-
st.markdown("Upload **Nominal Images** (good parts), **Defective Images** (bad parts), and **Test Images** to classify.")
|
| 126 |
-
|
| 127 |
tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"])
|
| 128 |
|
| 129 |
-
#
|
| 130 |
with tab1:
|
| 131 |
st.header("Upload Reference Images")
|
| 132 |
-
|
| 133 |
nominal_files = st.file_uploader("Upload Nominal Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
|
| 134 |
defective_files = st.file_uploader("Upload Defective Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
|
| 135 |
|
|
@@ -143,10 +146,9 @@ with tab1:
|
|
| 143 |
st.session_state.defective_descriptions = [file.name for file in defective_files]
|
| 144 |
st.success(f"Uploaded {len(defective_files)} defective images.")
|
| 145 |
|
| 146 |
-
#
|
| 147 |
with tab2:
|
| 148 |
st.header("Upload Test Image(s)")
|
| 149 |
-
|
| 150 |
test_files = st.file_uploader("Upload Test Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
|
| 151 |
|
| 152 |
if st.button("🔍 Run Classification") and test_files:
|
|
@@ -169,10 +171,9 @@ with tab2:
|
|
| 169 |
st.success("Classification complete!")
|
| 170 |
st.session_state.results = "streamlit_results.csv"
|
| 171 |
|
| 172 |
-
#
|
| 173 |
with tab3:
|
| 174 |
st.header("Classification Results")
|
| 175 |
-
|
| 176 |
if os.path.exists("streamlit_results.csv"):
|
| 177 |
df = pd.read_csv("streamlit_results.csv")
|
| 178 |
st.dataframe(df)
|
|
|
|
| 8 |
import torch.nn.functional as F
|
| 9 |
from typing import List
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
# Device setup
|
| 12 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 13 |
|
| 14 |
+
# Load CLIP model and preprocessor (ViT-B/32 = small model, CPU-friendly)
|
| 15 |
+
model, preprocess = clip.load("ViT-B/32", device=device)
|
| 16 |
model.eval()
|
| 17 |
|
| 18 |
+
# Display app title and information
|
| 19 |
+
st.set_page_config(page_title="Few-Shot Fault Detection", layout="wide")
|
| 20 |
+
st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
|
| 21 |
+
|
| 22 |
+
st.markdown("""
|
| 23 |
+
This demo uses the **smaller `ViT-B/32` encoder** from OpenAI's CLIP model to classify test images as **Nominal** or **Defective**, based on few-shot learning using user-provided reference images.
|
| 24 |
+
|
| 25 |
+
⚠️ **Note**: This app is running on a **free CPU tier** and is meant for demonstration purposes. For more advanced use cases, including GPU acceleration, custom training, and larger models, please refer to:
|
| 26 |
+
|
| 27 |
+
📄 [Megahed et al. (2025)](https://arxiv.org/abs/2501.12596):
|
| 28 |
+
*Adapting OpenAI's CLIP Model for Few-Shot Image Inspection in Manufacturing Quality Control: An Expository Case Study with Multiple Application Examples*
|
| 29 |
+
|
| 30 |
+
🔗 [GitHub & Colab links available in the paper](https://arxiv.org/abs/2501.12596)
|
| 31 |
+
""")
|
| 32 |
|
| 33 |
+
# --- Few-shot classification logic ---
|
|
|
|
| 34 |
def few_shot_fault_classification(
|
| 35 |
test_images: List[Image.Image],
|
| 36 |
test_image_filenames: List[str],
|
|
|
|
| 54 |
results = []
|
| 55 |
|
| 56 |
with torch.no_grad():
|
| 57 |
+
nominal_features = torch.stack([model.encode_image(img.unsqueeze(0)).squeeze(0).to(device) for img in nominal_images])
|
| 58 |
nominal_features /= nominal_features.norm(dim=-1, keepdim=True)
|
| 59 |
|
| 60 |
+
defective_features = torch.stack([model.encode_image(img.unsqueeze(0)).squeeze(0).to(device) for img in defective_images])
|
| 61 |
defective_features /= defective_features.norm(dim=-1, keepdim=True)
|
| 62 |
|
| 63 |
csv_data = []
|
| 64 |
|
| 65 |
for idx, test_img in enumerate(test_images):
|
| 66 |
+
test_features = model.encode_image(test_img.unsqueeze(0)).squeeze(0).to(device)
|
| 67 |
test_features /= test_features.norm(dim=-1, keepdim=True)
|
| 68 |
|
| 69 |
max_nom_sim, max_def_sim = -float('inf'), -float('inf')
|
|
|
|
| 117 |
|
| 118 |
return ""
|
| 119 |
|
| 120 |
+
# --- App state ---
|
| 121 |
if 'nominal_images' not in st.session_state:
|
| 122 |
st.session_state.nominal_images = []
|
| 123 |
if 'defective_images' not in st.session_state:
|
|
|
|
| 127 |
if 'results' not in st.session_state:
|
| 128 |
st.session_state.results = []
|
| 129 |
|
| 130 |
+
# --- Tabs ---
|
|
|
|
|
|
|
|
|
|
| 131 |
tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"])
|
| 132 |
|
| 133 |
+
# Tab 1: Upload Reference Images
|
| 134 |
with tab1:
|
| 135 |
st.header("Upload Reference Images")
|
|
|
|
| 136 |
nominal_files = st.file_uploader("Upload Nominal Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
|
| 137 |
defective_files = st.file_uploader("Upload Defective Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
|
| 138 |
|
|
|
|
| 146 |
st.session_state.defective_descriptions = [file.name for file in defective_files]
|
| 147 |
st.success(f"Uploaded {len(defective_files)} defective images.")
|
| 148 |
|
| 149 |
+
# Tab 2: Test Classification
|
| 150 |
with tab2:
|
| 151 |
st.header("Upload Test Image(s)")
|
|
|
|
| 152 |
test_files = st.file_uploader("Upload Test Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
|
| 153 |
|
| 154 |
if st.button("🔍 Run Classification") and test_files:
|
|
|
|
| 171 |
st.success("Classification complete!")
|
| 172 |
st.session_state.results = "streamlit_results.csv"
|
| 173 |
|
| 174 |
+
# Tab 3: View/Download Results
|
| 175 |
with tab3:
|
| 176 |
st.header("Classification Results")
|
|
|
|
| 177 |
if os.path.exists("streamlit_results.csv"):
|
| 178 |
df = pd.read_csv("streamlit_results.csv")
|
| 179 |
st.dataframe(df)
|