frogleo committed
Commit bb13373 · verified · 1 Parent(s): 9515f65

Update app.py

Files changed (1)
  1. app.py +81 -27
app.py CHANGED
@@ -1,20 +1,14 @@
-import os
 import gc
 import gradio as gr
 import numpy as np
 import torch
-import json
 import spaces
 import random
-import config
 import utils
 import logging
-from PIL import Image, PngImagePlugin
-from datetime import datetime
+from PIL import Image
 from diffusers.models import AutoencoderKL
-from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
-import time
-from typing import List, Dict, Tuple, Optional
+from diffusers import StableDiffusionXLImg2ImgPipeline
 from config import (
     MODEL,
     MIN_IMAGE_SIZE,
@@ -23,7 +17,8 @@ from config import (
     DEFAULT_NEGATIVE_PROMPT,
     scheduler_list,
 )
-import io
+from transformers import AutoProcessor, AutoModelForImageClassification
+

 MAX_SEED = np.iinfo(np.int32).max

@@ -62,6 +57,23 @@ else:
     pipe = None


+# -------------------- NSFW detection model loading --------------------
+try:
+    logger.info("Loading NSFW detector...")
+    from transformers import AutoProcessor, AutoModelForImageClassification
+    nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
+    nsfw_model = AutoModelForImageClassification.from_pretrained(
+        "Falconsai/nsfw_image_detection"
+    ).to(device)
+    logger.info("NSFW detector loaded successfully.")
+except Exception as e:
+    logger.error(f"Failed to load NSFW detector: {e}")
+    nsfw_model = None
+    nsfw_processor = None
+# -----------------------------------------------------------
+
+
+
 class GenerationError(Exception):
     """Custom exception for generation errors"""
     pass
@@ -92,11 +104,20 @@ def validate_dimensions(width: int, height: int) -> None:
         raise GenerationError(f"Height must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")


+def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
+    """Returns True if image is NSFW"""
+    inputs = nsfw_processor(images=image, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = nsfw_model(**inputs)
+    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
+    nsfw_score = probs[0][1].item()  # label 1 = NSFW
+    return nsfw_score > threshold
+

 progress=gr.Progress()

 @spaces.GPU
-def generate(
+def _generate_on_gpu(
     prompt: str,
     negative_prompt: str,
     width: int,
@@ -177,16 +198,33 @@ def generate(
                 callback_on_step_end=callback2
             ).images
             out_img = images[0]
+
+            # NSFW detection
+            if nsfw_model and nsfw_processor:
+                if detect_nsfw(out_img):
+                    msg = "Generated image contains NSFW content and cannot be displayed. Please modify your prompt and try again."
+                    raise Exception(msg)
+
             path = utils.save_image(out_img, "./outputs")
             logger.info(f"output path: {path}")
             progress(1, desc="Complete")
-            return path
+
+            info = {
+                "status": "success"
+            }
+            return path, info
         except GenerationError as e:
-            logger.warning(f"Generation validation error: {str(e)}")
-            raise gr.Error(str(e))
+            error_info = {
+                "error": str(e),
+                "status": "failed",
+            }
+            return None, error_info
         except Exception as e:
-            logger.exception("Unexpected error during generation")
-            raise gr.Error(f"Generation failed: {str(e)}")
+            error_info = {
+                "error": str(e),
+                "status": "failed",
+            }
+            return None, error_info
         finally:
             # Cleanup
             torch.cuda.empty_cache()
@@ -200,27 +238,43 @@ def generate(

             utils.free_memory()

+def generate(
+    prompt: str,
+    negative_prompt: str,
+    width: int,
+    height: int,
+    scheduler: str,
+    opt_strength: float,
+    opt_scale: float,
+    seed: int,
+    randomize_seed: bool,
+    guidance_scale: float,
+    num_inference_steps: int,
+):
+    # Call the GPU function
+    image_path, info = _generate_on_gpu(
+        prompt, negative_prompt,
+        width, height,
+        scheduler,
+        opt_strength, opt_scale,
+        seed, randomize_seed,
+        guidance_scale, num_inference_steps,
+    )

-
+    # If an error occurred, raise an exception
+    if info["status"] == "failed":
+        raise gr.Error(info["error"])

+    # Return the image path
+    return image_path


 title = "# Anime AI Generator"
 description = "Our AI-Powered Anime Generator turns your ideas into breathtaking AI anime art—perfect for art, storytelling, or personal AI anime wallpaper. Experience more at [Anime AI Generator](https://www.animeaigen.com)."

 custom_css = """
-#row-container {
-    align-items: stretch;
-}
-#output-image{
-    flex-grow: 1;
-}
-#output-image *{
-    max-height: none !important;
-}
 """

-
 with gr.Blocks(css=custom_css).queue() as demo:
     gr.Markdown(title)
     gr.Markdown(description)
@@ -327,4 +381,4 @@ with gr.Blocks(css=custom_css).queue() as demo:
     )

 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.launch()
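
For context, the core of this change is an NSFW gate: the new detect_nsfw helper classifies each generated image with Falconsai/nsfw_image_detection, and the GPU worker raises before saving when the score crosses the threshold. The snippet below is a minimal standalone sketch of that check, not code from the commit: it loads the same checkpoint via AutoImageProcessor (the commit itself uses AutoProcessor), reads the "nsfw" label index from the model config rather than hard-coding 1, and "sample.png" is only a placeholder path.

# Minimal standalone sketch of the NSFW gate (assumptions noted above).
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

MODEL_ID = "Falconsai/nsfw_image_detection"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()

def is_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
    # Preprocess the image, run the classifier, and softmax the logits.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # Look up the NSFW class index in the config; fall back to 1, as the commit assumes.
    nsfw_idx = model.config.label2id.get("nsfw", 1)
    return probs[0][nsfw_idx].item() > threshold

if __name__ == "__main__":
    # "sample.png" is a placeholder, not part of the commit.
    print(is_nsfw(Image.open("sample.png")))

In the Space itself the equivalent check runs on device inside _generate_on_gpu, and a positive result aborts the request before utils.save_image is called.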