MonaHamid committed
Commit 52a31b8 · verified · 1 parent: d06f60f

Update app.py

Files changed (1)
  1. app.py +24 -15
app.py CHANGED
@@ -2,25 +2,34 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import torch
 import gradio as gr
 
-# Load tokenizer and model from your saved folder
+# Load your saved model and tokenizer
 model_dir = "saved_model"
 tokenizer = AutoTokenizer.from_pretrained(model_dir)
+model = AutoModelForSequenceClassification.from_pretrained(model_dir)
 
-# Load the model with explicit label mappings
-model = AutoModelForSequenceClassification.from_pretrained(
-    model_dir,
-    id2label={0: "non-toxic", 1: "toxic"},
-    label2id={"non-toxic": 0, "toxic": 1}
-)
+# Define all 6 labels (Jigsaw-style multi-label toxic comment classification)
+labels = [
+    "toxic",
+    "severe_toxic",
+    "obscene",
+    "threat",
+    "insult",
+    "identity_hate"
+]
 
-# Define classification function
+# Inference function
 def classify(text):
     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
-    outputs = model(**inputs)
-    probs = torch.softmax(outputs.logits, dim=1)
-    labels = ["non-toxic", "toxic"]  # must match id2label order
-    return {labels[i]: float(probs[0][i]) for i in range(len(labels))}
-
-# Launch Gradio app
-gr.Interface(fn=classify, inputs="text", outputs="label").launch()
+    with torch.no_grad():
+        outputs = model(**inputs)
+    probs = torch.sigmoid(outputs.logits)[0]  # Sigmoid for multi-label
+    result = {label: float(probs[i]) for i, label in enumerate(labels)}
+    return result
 
+# Gradio interface
+gr.Interface(
+    fn=classify,
+    inputs=gr.Textbox(placeholder="Enter your comment..."),
+    outputs=gr.Label(num_top_classes=6),
+    title="Toxic Comment Classifier"
+).launch()
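
Below is a minimal sketch of exercising the updated classify function outside Gradio. It mirrors the new app.py logic and assumes "saved_model" holds a 6-output Jigsaw-style checkpoint, as the diff implies; the 0.5 flagging threshold and the sample comments are illustrative, not part of the app. Because each label gets an independent sigmoid score, the six probabilities do not need to sum to 1 (unlike the old two-class softmax).

# Quick sanity check of the new multi-label inference path (a sketch;
# "saved_model" and the label list come from app.py, the 0.5 threshold is an assumption).
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model_dir = "saved_model"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
model.eval()

labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

def classify(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Sigmoid, not softmax: each label is scored independently.
    probs = torch.sigmoid(outputs.logits)[0]
    return {label: float(probs[i]) for i, label in enumerate(labels)}

if __name__ == "__main__":
    for comment in ["Have a nice day!", "You are an idiot."]:
        scores = classify(comment)
        flagged = [name for name, p in scores.items() if p >= 0.5]  # illustrative cutoff
        print(comment, "->", flagged or ["no labels above threshold"])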