KrishnaKapale commited on
Commit
aec8917
·
verified ·
1 Parent(s): f1c7374

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
- from transformers import AutoFeatureExtractor, AutoModelForImageClassification
 
3
  from PIL import Image
4
  import torch
5
 
@@ -7,14 +8,16 @@ import torch
7
  MODEL_NAME = "prithivMLmods/Augmented-Waste-Classifier-SigLIP2"
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
- extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
 
11
  model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device)
12
 
13
  # Inference function
14
  def classify_image(image):
15
  if image.mode != "RGB":
16
  image = image.convert("RGB")
17
- inputs = extractor(images=image, return_tensors="pt")
 
18
  inputs = {k: v.to(device) for k, v in inputs.items()}
19
  outputs = model(**inputs)
20
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
@@ -37,4 +40,4 @@ interface = gr.Interface(
37
  )
38
 
39
  if __name__ == "__main__":
40
- interface.launch()
 
1
  import gradio as gr
2
+ # CHANGED: We now import AutoImageProcessor instead of AutoFeatureExtractor
3
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
4
  from PIL import Image
5
  import torch
6
 
 
8
  MODEL_NAME = "prithivMLmods/Augmented-Waste-Classifier-SigLIP2"
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
+ # CHANGED: Use AutoImageProcessor
12
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
13
  model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device)
14
 
15
  # Inference function
16
  def classify_image(image):
17
  if image.mode != "RGB":
18
  image = image.convert("RGB")
19
+ # CHANGED: Use the new 'processor' variable
20
+ inputs = processor(images=image, return_tensors="pt")
21
  inputs = {k: v.to(device) for k, v in inputs.items()}
22
  outputs = model(**inputs)
23
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
 
40
  )
41
 
42
  if __name__ == "__main__":
43
+ interface.launch()