selfit-camera commited on
Commit
7223d40
·
1 Parent(s): 561315f
Files changed (7) hide show
  1. .DS_Store +0 -0
  2. .gitignore +4 -1
  3. app.py +29 -13
  4. labels.json +4 -0
  5. nfsw.py +210 -0
  6. requirements.txt +4 -1
  7. util.py +0 -38
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitignore CHANGED
@@ -1,2 +1,5 @@
1
  *.jpg
2
- *.png
 
 
 
 
1
  *.jpg
2
+ *.png
3
+ hf_cache/
4
+ models/
5
+ __pycache__/
app.py CHANGED
@@ -1,10 +1,19 @@
1
  import gradio as gr
2
  import threading
3
- from util import process_image_edit, check_nsfw, get_country_info_safe
 
4
 
5
  IP_Dict = {}
6
  NSFW_Dict = {} # 记录每个IP的NSFW违规次数
7
 
 
 
 
 
 
 
 
 
8
  def edit_image_interface(input_image, prompt, request: gr.Request, progress=gr.Progress()):
9
  """
10
  Interface function for processing image editing
@@ -24,7 +33,7 @@ def edit_image_interface(input_image, prompt, request: gr.Request, progress=gr.P
24
  # 检查IP是否因NSFW违规过多而被屏蔽 3
25
  if client_ip in NSFW_Dict and NSFW_Dict[client_ip] >= 3:
26
  print(f"❌ IP blocked due to excessive NSFW violations - IP: {client_ip}({country_info}), violations: {NSFW_Dict[client_ip]}")
27
- return None, "❌ NSFW content too much. Access denied due to policy violations"
28
 
29
  if input_image is None:
30
  return None, "Please upload an image first"
@@ -36,14 +45,21 @@ def edit_image_interface(input_image, prompt, request: gr.Request, progress=gr.P
36
  if len(prompt.strip()) <= 3:
37
  return None, "❌ Editing prompt must be more than 3 characters"
38
 
39
- # 检查是否包含NSFW内容
40
- if check_nsfw(prompt.strip()) == 1:
41
- # 记录NSFW违规次数
42
- if client_ip not in NSFW_Dict:
43
- NSFW_Dict[client_ip] = 0
44
- NSFW_Dict[client_ip] += 1
45
- print(f"❌ NSFW content detected - IP: {client_ip}({country_info}), violations: {NSFW_Dict[client_ip]}, prompt: {prompt.strip()}")
46
- return None, "❌ NSFW content detected. Please modify your prompt. NSFW detected {NSFW_Dict[client_ip]}/10 times"
 
 
 
 
 
 
 
47
 
48
  if IP_Dict[client_ip]>8 and country_info.lower() in ["印度", "巴基斯坦"]:
49
  print(f"❌ Content not allowed - IP: {client_ip}({country_info}), count: {IP_Dict[client_ip]}, prompt: {prompt.strip()}")
@@ -67,17 +83,17 @@ def edit_image_interface(input_image, prompt, request: gr.Request, progress=gr.P
67
 
68
  try:
69
  # 打印成功访问的信息
70
- print(f"✅ Processing started - IP: {client_ip}({country_info}), count: {IP_Dict[client_ip]}, prompt: {prompt.strip()}")
71
 
72
  # Call image editing processing function
73
  result_url, message = process_image_edit(input_image, prompt.strip(), progress_callback)
74
 
75
  if result_url:
76
- print(f"✅ Processing completed successfully - IP: {client_ip}({country_info}), result_url: {result_url}")
77
  progress(1.0, desc="Processing completed")
78
  return result_url, "✅ " + message
79
  else:
80
- print(f"❌ Processing failed - IP: {client_ip}({country_info}), error: {message}")
81
  return None, "❌ " + message
82
 
83
  except Exception as e:
 
1
  import gradio as gr
2
  import threading
3
+ from util import process_image_edit, get_country_info_safe
4
+ from nfsw import NSFWDetector
5
 
6
  IP_Dict = {}
7
  NSFW_Dict = {} # 记录每个IP的NSFW违规次数
8
 
9
+ # 初始化NSFW检测器(从Hugging Face下载)
10
+ try:
11
+ nsfw_detector = NSFWDetector() # 自动从Hugging Face下载falconsai_yolov9_nsfw_model_quantized.pt
12
+ print("✅ NSFW检测器初始化成功")
13
+ except Exception as e:
14
+ print(f"❌ NSFW检测器初始化失败: {e}")
15
+ nsfw_detector = None
16
+
17
  def edit_image_interface(input_image, prompt, request: gr.Request, progress=gr.Progress()):
18
  """
19
  Interface function for processing image editing
 
33
  # 检查IP是否因NSFW违规过多而被屏蔽 3
34
  if client_ip in NSFW_Dict and NSFW_Dict[client_ip] >= 3:
35
  print(f"❌ IP blocked due to excessive NSFW violations - IP: {client_ip}({country_info}), violations: {NSFW_Dict[client_ip]}")
36
+ return None, f"❌ Your ip {client_ip},your region has been blocked"
37
 
38
  if input_image is None:
39
  return None, "Please upload an image first"
 
45
  if len(prompt.strip()) <= 3:
46
  return None, "❌ Editing prompt must be more than 3 characters"
47
 
48
+ # 检查图片是否包含NSFW内容
49
+ nsfw_result = None
50
+ if nsfw_detector is not None:
51
+ try:
52
+ nsfw_result = nsfw_detector.predict_label_only(input_image)
53
+ if nsfw_result.lower() == "nsfw":
54
+ # 记录NSFW违规次数
55
+ if client_ip not in NSFW_Dict:
56
+ NSFW_Dict[client_ip] = 0
57
+ NSFW_Dict[client_ip] += 1
58
+ print(f"❌ NSFW image detected - IP: {client_ip}({country_info}), violations: {NSFW_Dict[client_ip]}")
59
+ return None, f"❌ Your ip {client_ip},your region has been blocked"
60
+ except Exception as e:
61
+ print(f"⚠️ NSFW检测失败: {e}")
62
+ # 检测失败时允许继续处理
63
 
64
  if IP_Dict[client_ip]>8 and country_info.lower() in ["印度", "巴基斯坦"]:
65
  print(f"❌ Content not allowed - IP: {client_ip}({country_info}), count: {IP_Dict[client_ip]}, prompt: {prompt.strip()}")
 
83
 
84
  try:
85
  # 打印成功访问的信息
86
+ print(f"✅ Processing started - IP: {client_ip}({country_info}), count: {IP_Dict[client_ip]}, prompt: {prompt.strip()}", flush=True)
87
 
88
  # Call image editing processing function
89
  result_url, message = process_image_edit(input_image, prompt.strip(), progress_callback)
90
 
91
  if result_url:
92
+ print(f"✅ Processing completed successfully - IP: {client_ip}({country_info}), result_url: {result_url}", flush=True)
93
  progress(1.0, desc="Processing completed")
94
  return result_url, "✅ " + message
95
  else:
96
+ print(f"❌ Processing failed - IP: {client_ip}({country_info}), error: {message}", flush=True)
97
  return None, "❌ " + message
98
 
99
  except Exception as e:
labels.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "0": "normal",
3
+ "1": "nsfw"
4
+ }
nfsw.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import numpy as np
4
+ import onnxruntime as ort
5
+ import json
6
+ from huggingface_hub import hf_hub_download
7
+
8
+
9
+ class NSFWDetector:
10
+ """
11
+ NSFW检测器类,使用YOLOv9模型进行图像分类
12
+ """
13
+
14
+ def __init__(self, repo_id="Falconsai/nsfw_image_detection",
15
+ model_filename="falconsai_yolov9_nsfw_model_quantized.pt",
16
+ labels_filename="labels.json",
17
+ input_size=(224, 224)):
18
+ """
19
+ 初始化NSFW检测器
20
+
21
+ Args:
22
+ repo_id (str): Hugging Face仓库ID
23
+ model_filename (str): 模型文件名
24
+ labels_filename (str): 标签文件名
25
+ input_size (tuple): 模型输入尺寸 (height, width)
26
+ """
27
+ self.repo_id = repo_id
28
+ self.model_filename = model_filename
29
+ self.labels_filename = labels_filename
30
+ self.input_size = input_size
31
+
32
+ # 从Hugging Face下载文件
33
+ self.model_path = self._download_model()
34
+ self.labels_path = self._download_labels()
35
+
36
+ # 加载标签
37
+ self.labels = self._load_labels()
38
+
39
+ # 加载模型
40
+ self.session = self._load_model()
41
+ self.input_name = self.session.get_inputs()[0].name
42
+ self.output_name = self.session.get_outputs()[0].name
43
+
44
+ def _download_model(self):
45
+ """
46
+ 从Hugging Face下载模型文件
47
+
48
+ Returns:
49
+ str: 下载的模型文件路径
50
+ """
51
+ try:
52
+ print(f"正在从 {self.repo_id} 下载模型文件: {self.model_filename}")
53
+ model_path = hf_hub_download(
54
+ repo_id=self.repo_id,
55
+ filename=self.model_filename,
56
+ cache_dir="./hf_cache"
57
+ )
58
+ print(f"✅ 模型下载成功: {model_path}")
59
+ return model_path
60
+ except Exception as e:
61
+ raise RuntimeError(f"模型下载失败: {e}")
62
+
63
+ def _download_labels(self):
64
+ """
65
+ 从Hugging Face下载标签文件
66
+
67
+ Returns:
68
+ str: 下载的标签文件路径
69
+ """
70
+ try:
71
+ print(f"正在从 {self.repo_id} 下载标签文件: {self.labels_filename}")
72
+ labels_path = hf_hub_download(
73
+ repo_id=self.repo_id,
74
+ filename=self.labels_filename,
75
+ cache_dir="./hf_cache"
76
+ )
77
+ print(f"✅ 标签文件下载成功: {labels_path}")
78
+ return labels_path
79
+ except Exception as e:
80
+ raise RuntimeError(f"标签文件下载失败: {e}")
81
+
82
+ def _load_labels(self):
83
+ """
84
+ 加载类别标签
85
+
86
+ Returns:
87
+ dict: 标签字典
88
+ """
89
+ try:
90
+ with open(self.labels_path, "r") as f:
91
+ return json.load(f)
92
+ except FileNotFoundError:
93
+ raise FileNotFoundError(f"标签文件未找到: {self.labels_path}")
94
+ except json.JSONDecodeError:
95
+ raise ValueError(f"标签文件格式错误: {self.labels_path}")
96
+
97
+ def _load_model(self):
98
+ """
99
+ 加载ONNX模型
100
+
101
+ Returns:
102
+ onnxruntime.InferenceSession: 模型会话
103
+ """
104
+ try:
105
+ return ort.InferenceSession(self.model_path)
106
+ except Exception as e:
107
+ raise RuntimeError(f"模型加载失败: {self.model_path}, 错误: {e}")
108
+
109
+ def _preprocess_image(self, image_path):
110
+ """
111
+ 图像预处理
112
+
113
+ Args:
114
+ image_path (str): 图像文件路径
115
+
116
+ Returns:
117
+ tuple: (预处理后的张量, 原始图像)
118
+ """
119
+ try:
120
+ # 加载并转换图像
121
+ original_image = Image.open(image_path).convert("RGB")
122
+
123
+ # 调整尺寸
124
+ image_resized = original_image.resize(self.input_size, Image.Resampling.BILINEAR)
125
+
126
+ # 转换为numpy数组并归一化
127
+ image_np = np.array(image_resized, dtype=np.float32) / 255.0
128
+
129
+ # 调整维度顺序 [H, W, C] -> [C, H, W]
130
+ image_np = np.transpose(image_np, (2, 0, 1))
131
+
132
+ # 添加批次维度 [C, H, W] -> [1, C, H, W]
133
+ input_tensor = np.expand_dims(image_np, axis=0).astype(np.float32)
134
+
135
+ return input_tensor, original_image
136
+
137
+ except FileNotFoundError:
138
+ raise FileNotFoundError(f"图像文件未找到: {image_path}")
139
+ except Exception as e:
140
+ raise RuntimeError(f"图像预处理失败: {e}")
141
+
142
+ def _postprocess_predictions(self, predictions):
143
+ """
144
+ 后处理预测结果
145
+
146
+ Args:
147
+ predictions: 模型预测输出
148
+
149
+ Returns:
150
+ str: 预测的类别标签
151
+ """
152
+ predicted_index = np.argmax(predictions)
153
+ predicted_label = self.labels[str(predicted_index)]
154
+ return predicted_label
155
+
156
+ def predict(self, image_path):
157
+ """
158
+ 对单张图像进行NSFW检测
159
+
160
+ Args:
161
+ image_path (str): 图像文件路径
162
+
163
+ Returns:
164
+ tuple: (预测标签, 原始图像)
165
+ """
166
+ # 预处理图像
167
+ input_tensor, original_image = self._preprocess_image(image_path)
168
+
169
+ # 运行推理
170
+ outputs = self.session.run([self.output_name], {self.input_name: input_tensor})
171
+ predictions = outputs[0]
172
+
173
+ # 后处理结果
174
+ predicted_label = self._postprocess_predictions(predictions)
175
+
176
+ return predicted_label, original_image
177
+
178
+ def predict_label_only(self, image_path):
179
+ """
180
+ 只返回预测标签(不返回图像)
181
+
182
+ Args:
183
+ image_path (str): 图像文件路径
184
+
185
+ Returns:
186
+ str: 预测的类别标签
187
+ """
188
+ predicted_label, _ = self.predict(image_path)
189
+ return predicted_label
190
+
191
+ # --- 使用示例 ---
192
+ if __name__ == "__main__":
193
+ # 配置参数
194
+ single_image_path = "datas/bad01.jpg"
195
+
196
+ try:
197
+ # 创建检测器实例(自动从Hugging Face下载)
198
+ detector = NSFWDetector()
199
+
200
+ # 检查图像文件是否存在
201
+ if os.path.exists(single_image_path):
202
+ # 进行预测
203
+ predicted_label = detector.predict_label_only(single_image_path)
204
+ print(f"图像文件: {single_image_path}")
205
+ print(f"预测结果: {predicted_label}")
206
+ else:
207
+ print(f"错误: 指定的图像文件不存在: {single_image_path}")
208
+
209
+ except Exception as e:
210
+ print(f"初始化检测器时发生错误: {e}")
requirements.txt CHANGED
@@ -4,4 +4,7 @@ requests>=2.28.0
4
  func-timeout>=4.3.5
5
  numpy>=1.24.0
6
  boto3
7
- botocore
 
 
 
 
4
  func-timeout>=4.3.5
5
  numpy>=1.24.0
6
  boto3
7
+ botocore
8
+ onnxruntime
9
+ huggingface_hub>=0.16.0
10
+ Pillow>=9.0.0
util.py CHANGED
@@ -177,44 +177,6 @@ def get_country_info_safe(ip):
177
  return "Unknown"
178
 
179
 
180
- def check_nsfw(prompt):
181
- """
182
- 检查prompt是否包含NSFW内容,包含返回1,否则返回0
183
- """
184
- try:
185
- response = requests.post(
186
- url="https://openrouter.ai/api/v1/chat/completions",
187
- headers={
188
- "Authorization": f"Bearer {LLMKEY}",
189
- "Content-Type": "application/json",
190
- },
191
- data=json.dumps({
192
- "model": "google/gemini-2.5-flash",
193
- "messages": [
194
- {
195
- "role": "system",
196
- "content": "你是一个nsfw指令判断助手,请判断用户输入的prompt指令是否会导致nsfw内容? 你只需要回答 是 或者 否"
197
- },
198
- {
199
- "role": "user",
200
- "content": prompt
201
- }
202
- ],
203
- })
204
- )
205
- res_json = response.json()
206
- # 兼容不同模型返回格式
207
- if "choices" in res_json and len(res_json["choices"]) > 0:
208
- content = res_json["choices"][0].get("message", {}).get("content", "")
209
- if "是" in content:
210
- return 1
211
- else:
212
- return 0
213
- else:
214
- return 0
215
- except Exception as e:
216
- # 出错时默认返回0
217
- return 0
218
 
219
 
220
  def submit_image_edit_task(user_image_url, prompt):
 
177
  return "Unknown"
178
 
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
 
182
  def submit_image_edit_task(user_image_url, prompt):