gabrielchua committed on
Commit bc1321c · verified · 1 parent: 0ef88b1

Upload 5 files

Files changed (5)
  1. LionGuard2.safetensors +3 -0
  2. app.py +203 -0
  3. lionguard2.py +170 -0
  4. requirements.txt +4 -0
  5. utils.py +44 -0
LionGuard2.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:20665c1cde68b57c34444accc4f0fca5a3f58b3483d6bad2d6c6911e431afac9
size 3398496
app.py ADDED
@@ -0,0 +1,203 @@
import os

import gradio as gr
import openai
from safetensors.torch import load_file

from lionguard2 import LionGuard2
from utils import get_embeddings

# Set up the OpenAI client
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Load the LionGuard 2 model
model = LionGuard2()
model.eval()

# Load model weights (LionGuard2.safetensors is uploaded alongside this file)
model_path = os.path.join(os.path.dirname(__file__), "LionGuard2.safetensors")
state_dict = load_file(model_path)
model.load_state_dict(state_dict)


def lionguard_2(message, threshold=0.5):
    """
    Check a message with the LionGuard 2 model.

    Args:
        message: The text message to check.
        threshold: Probability threshold for flagging content as unsafe (default: 0.5).

    Returns:
        bool: True if the content is flagged as unsafe, False otherwise.
    """
    try:
        # Get embeddings for the message
        embeddings = get_embeddings([message])

        # Get predictions from the model
        results = model.predict(embeddings)

        # Check the binary classification result (overall safety)
        binary_prob = results["binary"][0]  # first (and only) message's binary probability

        # Flag as unsafe if the probability exceeds the threshold
        return binary_prob > threshold

    except Exception as e:
        print(f"Error in LionGuard 2: {e}")
        # On error, default to not flagging to avoid blocking legitimate content
        return False


def get_openai_response(message, system_prompt="You are a helpful assistant."):
    """Get a response from the OpenAI API."""
    try:
        response = client.chat.completions.create(
            model="gpt-4.1-nano",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": message},
            ],
            max_tokens=500,
            temperature=0,
            seed=42,
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error: {str(e)}. Please check your OpenAI API key."


def openai_moderation(message):
    """
    Check a message with OpenAI's built-in moderation API.

    Args:
        message: The text message to check.

    Returns:
        bool: True if the content is flagged as unsafe, False otherwise.
    """
    try:
        response = client.moderations.create(input=message)
        return response.results[0].flagged
    except Exception as e:
        print(f"Error in OpenAI moderation: {e}")
        # On error, default to not flagging
        return False


def process_message(message, history_no_mod, history_openai, history_lg):
    """Process a message for all three chatbots."""
    if not message.strip():
        return history_no_mod, history_openai, history_lg, ""

    # gpt-4.1-nano with no moderation
    no_mod_response = get_openai_response(message)
    history_no_mod.append({"role": "user", "content": message})
    history_no_mod.append({"role": "assistant", "content": no_mod_response})

    # gpt-4.1-nano with OpenAI moderation
    openai_flagged = openai_moderation(message)
    history_openai.append({"role": "user", "content": message})
    if openai_flagged:
        openai_response = "🚫 This message has been flagged by OpenAI moderation"
    else:
        openai_response = get_openai_response(message)
    history_openai.append({"role": "assistant", "content": openai_response})

    # gpt-4.1-nano with LionGuard 2
    lg_flagged = lionguard_2(message)
    history_lg.append({"role": "user", "content": message})
    if lg_flagged:
        lg_response = "🚫 This message has been flagged by LionGuard 2"
    else:
        lg_response = get_openai_response(message)
    history_lg.append({"role": "assistant", "content": lg_response})

    return history_no_mod, history_openai, history_lg, ""


def clear_all_chats():
    """Clear all chat histories."""
    return [], [], []


# Create the Gradio interface
with gr.Blocks(title="LionGuard 2", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# EMNLP 2025 System Demonstration: LionGuard 2 🦁")
    gr.Markdown(
        "**LionGuard 2 is a content moderator localised to Singapore - "
        "use it to detect unsafe LLM inputs and outputs**"
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🔵 No Moderation")
            chatbot_no_mod = gr.Chatbot(
                height=800,
                label="No Moderation",
                show_label=False,
                bubble_full_width=False,
                type="messages",
            )

        with gr.Column(scale=1):
            gr.Markdown("## 🟠 OpenAI Moderation")
            chatbot_openai = gr.Chatbot(
                height=800,
                label="OpenAI Moderation",
                show_label=False,
                bubble_full_width=False,
                type="messages",
            )

        with gr.Column(scale=1):
            gr.Markdown("## 🛡️ LionGuard 2")
            chatbot_lg = gr.Chatbot(
                height=800,
                label="LionGuard 2",
                show_label=False,
                bubble_full_width=False,
                type="messages",
            )

    # Single input shared by all three chatbots
    gr.Markdown("### 💬 Send Message to All Models")
    with gr.Row():
        message_input = gr.Textbox(
            placeholder="Type your message to compare responses...",
            show_label=False,
            scale=4,
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)

    # Control buttons
    with gr.Row():
        clear_btn = gr.Button("Clear All Chats", variant="stop")

    # Event handlers
    send_btn.click(
        process_message,
        inputs=[message_input, chatbot_no_mod, chatbot_openai, chatbot_lg],
        outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg, message_input],
    )

    message_input.submit(
        process_message,
        inputs=[message_input, chatbot_no_mod, chatbot_openai, chatbot_lg],
        outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg, message_input],
    )

    # Clear button
    clear_btn.click(
        clear_all_chats,
        outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True, debug=True)
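
As a quick sanity check outside the Gradio UI, the two moderation gates can be called directly. A minimal sketch, assuming `OPENAI_API_KEY` is set in the environment (importing `app` also loads the model weights; the sample text is illustrative only):

# Smoke test for the two moderation gates defined in app.py.
from app import lionguard_2, openai_moderation

text = "Hello, how are you?"
print("LionGuard 2 flagged:", lionguard_2(text, threshold=0.5))
print("OpenAI moderation flagged:", openai_moderation(text))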
lionguard2.py ADDED
@@ -0,0 +1,170 @@
"""
lionguard2.py
"""

import torch
import torch.nn as nn

CATEGORIES = {
    "binary": ["binary"],
    "hateful": ["hateful_l1", "hateful_l2"],
    "insults": ["insults"],
    "sexual": [
        "sexual_l1",
        "sexual_l2",
    ],
    "physical_violence": ["physical_violence"],
    "self_harm": ["self_harm_l1", "self_harm_l2"],
    "all_other_misconduct": [
        "all_other_misconduct_l1",
        "all_other_misconduct_l2",
    ],
}

INPUT_DIMENSION = 3072  # dimension of OpenAI text-embedding-3-large embeddings


class LionGuard2(nn.Module):
    def __init__(
        self,
        input_dim=INPUT_DIMENSION,
        label_names=CATEGORIES.keys(),
        categories=CATEGORIES,
    ):
        """
        LionGuard2 is a localised content moderation model that flags whether text violates the following categories:

        1. `hateful`: Text that discriminates, criticizes, insults, denounces, or dehumanizes a person or group on the basis of a protected identity.

           There are two sub-categories for the `hateful` category:
           a. `level_1_discriminatory`: Text that contains derogatory or generalized negative statements targeting a protected group.
           b. `level_2_hate_speech`: Text that explicitly calls for harm or violence against a protected group; or language praising or justifying violence against them.

        2. `insults`: Text that insults, demeans, humiliates, mocks, or belittles a person or group **without** referencing a legally protected trait.
           For example, this includes personal attacks on attributes such as someone's appearance, intellect, behavior, or other non-protected characteristics.

        3. `sexual`: Text that depicts or indicates sexual interest, activity, or arousal, using direct or indirect references to body parts, sexual acts, or physical traits.
           This includes sexual content that may be inappropriate for certain audiences.

           There are two sub-categories for the `sexual` category:
           a. `level_1_not_appropriate_for_minors`: Text that contains mild-to-moderate sexual content that is generally adult-oriented or potentially unsuitable for those under 16.
              May include matter-of-fact discussions about sex, sexuality, or sexual preferences.
           b. `level_2_not_appropriate_for_all_ages`: Text that contains content aimed at adults and considered explicit, graphic, or otherwise inappropriate for a broad audience.
              May include explicit descriptions of sexual acts, detailed sexual fantasies, or highly sexualized content.

        4. `physical_violence`: Text that includes glorification of violence or threats to inflict physical harm or injury on a person, group, or entity.

        5. `self_harm`: Text that promotes, suggests, or expresses intent to self-harm or commit suicide.

           There are two sub-categories for the `self_harm` category:
           a. `level_1_self_harm_intent`: Text that expresses suicidal thoughts or self-harm intention; or content encouraging someone to self-harm.
           b. `level_2_self_harm_action`: Text that describes or indicates ongoing or imminent self-harm behavior.

        6. `all_other_misconduct`: A catch-all category for any other unsafe text that does not fit into the other categories.
           It includes text that seeks or provides information about engaging in misconduct, wrongdoing, or criminal activity, or that threatens to harm,
           defraud, or exploit others. This includes facilitating illegal acts (under Singapore law) or other forms of socially harmful activity.

           There are two sub-categories for the `all_other_misconduct` category:
           a. `level_1_not_socially_accepted`: Text that advocates or instructs on unethical/immoral activities that may not necessarily be illegal but are socially condemned.
           b. `level_2_illegal_activities`: Text that seeks or provides instructions to carry out clearly illegal activities or serious wrongdoing; includes credible threats of severe harm.

        Lastly, there is an additional `binary` category (#7) which flags whether the text is unsafe in general.

        The model takes text as input, after it has been encoded with OpenAI's `text-embedding-3-large` model.

        The model outputs the probability of each category being true.

        ================================

        Args:
            input_dim: The dimension of the input embeddings. Defaults to 3072, the dimension of embeddings from OpenAI's `text-embedding-3-large` model. This should not be changed.
            label_names: The names of the labels. Defaults to the keys of the CATEGORIES dictionary. This should not be changed.
            categories: The categories of the labels. Defaults to the CATEGORIES dictionary. This should not be changed.

        Returns:
            A LionGuard2 model.
        """
        super().__init__()
        self.label_names = list(label_names)
        self.n_outputs = len(self.label_names)
        self.categories = categories

        # Shared layers
        self.shared_layers = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # Output heads for each label
        self.output_heads = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(128, 32),
                    nn.ReLU(),
                    nn.Linear(32, 2),  # 2 thresholds for ordinal classification
                    nn.Sigmoid(),
                )
                for _ in range(self.n_outputs)
            ]
        )

    def forward(self, x):
        # Pass through shared layers
        h = self.shared_layers(x)
        # Pass through each output head
        return [head(h) for head in self.output_heads]

    def predict(self, embeddings):
        """
        Predict the probability of each label being true.

        Args:
            embeddings: A numpy array of embeddings (N x INPUT_DIMENSION).

        Returns:
            A dictionary mapping each sub-category label to a list of N probabilities.
        """
        # Convert the input to a PyTorch tensor if it is not one already
        if not isinstance(embeddings, torch.Tensor):
            x = torch.tensor(embeddings, dtype=torch.float32)
        else:
            x = embeddings

        # Pass through the model
        with torch.no_grad():
            outputs = self.forward(x)

        # Stack outputs into a single tensor
        raw_predictions = torch.stack(outputs)  # shape: (n_outputs, N, 2)

        # Extract and format probabilities from raw predictions
        output = {}
        for i, main_cat in enumerate(self.label_names):
            sub_categories = self.categories[main_cat]
            for j, sub_cat in enumerate(sub_categories):
                # j=0 uses P(y>0)
                # j=1 uses P(y>1) if an L2 category exists
                output[sub_cat] = raw_predictions[i, :, j]

            # Post-processing step:
            # If an L2 category exists and P(L2) > P(L1),
            # set both P(L1) and P(L2) to their average to maintain ordinal consistency
            if len(sub_categories) > 1:
                l1 = output[sub_categories[0]]
                l2 = output[sub_categories[1]]

                # Update probabilities on samples where P(L2) > P(L1)
                mask = l2 > l1
                mean_prob = (l1 + l2) / 2
                l1[mask] = mean_prob[mask]
                l2[mask] = mean_prob[mask]
                output[sub_categories[0]] = l1
                output[sub_categories[1]] = l2

        for key, value in output.items():
            output[key] = value.numpy().tolist()
        return output
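
To inspect the output format of `predict()` without any OpenAI API calls, here is a minimal sketch that feeds random vectors in place of real embeddings. It assumes `LionGuard2.safetensors` sits in the working directory; the printed scores are meaningless for random input and only illustrate the result structure:

# Sketch: inspect LionGuard2's output format using random vectors
# in place of real embeddings (no API key needed).
import numpy as np
from safetensors.torch import load_file
from lionguard2 import LionGuard2, INPUT_DIMENSION

model = LionGuard2()
model.load_state_dict(load_file("LionGuard2.safetensors"))
model.eval()

fake_embeddings = np.random.rand(2, INPUT_DIMENSION).astype("float32")
scores = model.predict(fake_embeddings)  # dict: label -> list of 2 probabilities
for label, probs in scores.items():
    print(label, probs)

# The post-processing step guarantees ordinal consistency, e.g.
# scores["hateful_l2"][i] <= scores["hateful_l1"][i] for every sample i.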
requirements.txt ADDED
@@ -0,0 +1,4 @@
numpy==2.1.3
openai==1.83.0
safetensors==0.5.3
torch==2.7.0
# Note: app.py also needs gradio, which the Hugging Face Spaces runtime is assumed to provide.
utils.py ADDED
@@ -0,0 +1,44 @@
"""
utils.py
"""

# Standard imports
import os
from typing import List

# Third-party imports
import numpy as np
from openai import OpenAI

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# Maximum input tokens for the embedding model. We don't have access to the
# tokenizer here, so we conservatively assume 1 character = 1 token (a token
# is at least one character, so truncating by characters never overshoots).
MAX_TOKENS = 8191


def get_embeddings(
    texts: List[str], model: str = "text-embedding-3-large"
) -> np.ndarray:
    """
    Generate embeddings for a list of texts using the OpenAI API synchronously.

    Args:
        texts: List of strings to embed.
        model: OpenAI embedding model to use (default: text-embedding-3-large).

    Returns:
        A numpy array of embeddings, one row per input text.

    Raises:
        Exception: If the OpenAI API call fails.
    """

    # Truncate texts to the maximum token limit (counted in characters, see above)
    truncated_texts = [text[:MAX_TOKENS] for text in texts]

    # Make the API call
    response = client.embeddings.create(input=truncated_texts, model=model)

    # Extract embeddings from the response
    embeddings = np.array([data.embedding for data in response.data])
    return embeddings
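
Putting `utils.py` and `lionguard2.py` together, a minimal end-to-end scoring sketch, assuming `OPENAI_API_KEY` is set and `LionGuard2.safetensors` is in the working directory (the two sample texts are illustrative only):

# End-to-end: embed texts with the OpenAI API, then score them with LionGuard 2.
from safetensors.torch import load_file
from lionguard2 import LionGuard2
from utils import get_embeddings

model = LionGuard2()
model.load_state_dict(load_file("LionGuard2.safetensors"))
model.eval()

texts = ["Have a great day!", "I am going to hurt you."]
scores = model.predict(get_embeddings(texts))
for text, prob in zip(texts, scores["binary"]):
    print(f"binary={prob:.3f}  {text!r}")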