nishanth-saka commited on
Commit
e239a1b
·
verified ·
1 Parent(s): 413d5ae

cpu usage error

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -4,37 +4,39 @@ from diffusers import StableDiffusionXLImg2ImgPipeline
4
  from PIL import Image
5
 
6
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
7
- DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
8
 
9
  pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
10
  "stabilityai/stable-diffusion-xl-base-1.0",
11
- dtype=DTYPE,
12
  use_safetensors=True,
13
  )
14
 
15
- pipe.to(DEVICE)
16
 
17
- if DEVICE == "cpu":
18
- pipe.enable_model_cpu_offload()
19
 
20
  def generate(image, prompt):
21
  image = image.resize((1024, 1024))
 
22
  result = pipe(
23
  prompt=prompt,
24
  image=image,
25
  strength=0.2,
26
  guidance_scale=6,
27
- num_inference_steps=20,
28
  )
 
29
  return result.images[0]
30
 
31
  demo = gr.Interface(
32
  fn=generate,
33
  inputs=[
34
- gr.Image(type="pil"),
35
  gr.Textbox(label="Prompt"),
36
  ],
37
  outputs=gr.Image(type="pil"),
 
38
  )
39
 
40
  demo.launch()
 
4
  from PIL import Image
5
 
6
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
7
 
8
  pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
9
  "stabilityai/stable-diffusion-xl-base-1.0",
10
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
11
  use_safetensors=True,
12
  )
13
 
14
+ pipe = pipe.to(DEVICE)
15
 
16
+ # DO NOT enable cpu offload on HF CPU
17
+ # pipe.enable_model_cpu_offload() <-- REMOVE THIS
18
 
19
  def generate(image, prompt):
20
  image = image.resize((1024, 1024))
21
+
22
  result = pipe(
23
  prompt=prompt,
24
  image=image,
25
  strength=0.2,
26
  guidance_scale=6,
27
+ num_inference_steps=20, # keep low for CPU
28
  )
29
+
30
  return result.images[0]
31
 
32
  demo = gr.Interface(
33
  fn=generate,
34
  inputs=[
35
+ gr.Image(type="pil", label="Input Image"),
36
  gr.Textbox(label="Prompt"),
37
  ],
38
  outputs=gr.Image(type="pil"),
39
+ title="SDXL Image-to-Image (CPU Safe)",
40
  )
41
 
42
  demo.launch()