gabrielchua commited on
Commit
aa648a7
·
1 Parent(s): 1435ea6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -12
app.py CHANGED
@@ -1,28 +1,47 @@
1
- import gradio as gr
 
 
 
 
 
2
  import os
3
- import openai
4
- import torch
5
  import sys
6
  import uuid
7
  from datetime import datetime
8
- import json
9
-
10
- from safetensors.torch import load_file
11
- from lionguard2 import LionGuard2, CATEGORIES
12
- from utils import get_embeddings
13
 
 
 
 
14
  import gspread
15
  from google.oauth2 import service_account
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # --- OpenAI Setup ---
18
  client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
19
 
20
  # --- Model Loading ---
21
  def load_lionguard2():
22
- model = LionGuard2()
23
- model.eval()
24
- state_dict = load_file('LionGuard2.safetensors')
25
- model.load_state_dict(state_dict)
26
  return model
27
 
28
  model = load_lionguard2()
 
1
+ """
2
+ app.py
3
+ """
4
+
5
+ # Standard imports
6
+ import json
7
  import os
 
 
8
  import sys
9
  import uuid
10
  from datetime import datetime
 
 
 
 
 
11
 
12
+ # Third party imports
13
+ import openai
14
+ import gradio as gr
15
  import gspread
16
  from google.oauth2 import service_account
17
+ from transformers import AutoModel
18
+
19
+ # Local imports
20
+ from utils import get_embeddings
21
+
22
+ # --- Categories
23
+ CATEGORIES = {
24
+ "binary": ["binary"],
25
+ "hateful": ["hateful_l1", "hateful_l2"],
26
+ "insults": ["insults"],
27
+ "sexual": [
28
+ "sexual_l1",
29
+ "sexual_l2",
30
+ ],
31
+ "physical_violence": ["physical_violence"],
32
+ "self_harm": ["self_harm_l1", "self_harm_l2"],
33
+ "all_other_misconduct": [
34
+ "all_other_misconduct_l1",
35
+ "all_other_misconduct_l2",
36
+ ],
37
+ }
38
 
39
  # --- OpenAI Setup ---
40
  client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
41
 
42
  # --- Model Loading ---
43
  def load_lionguard2():
44
+ model = AutoModel.from_pretrained("govtech/lionguard-2", trust_remote_code=True)
 
 
 
45
  return model
46
 
47
  model = load_lionguard2()